Sudoku solver in OCaml

Sudoku is a number placement puzzle. The objective is to fill a 9*9 grid so that each row, each column, and each of the nine 3*3 boxes contains all the numbers from 1 to 9. In other words, no digit may appear twice in any row, column, or 3*3 box.
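
The rule above can be written down as a small checker. This is only a sketch: the names is_permutation and valid_solution are hypothetical helpers of mine, and the completed grid is assumed to be a plain int array array rather than the candidate-list matrix used later on this page.

```ocaml
(* Hypothetical helper: a group is valid if it holds 1..9 exactly once. *)
let is_permutation (group : int list) : bool =
  List.sort compare group = [1; 2; 3; 4; 5; 6; 7; 8; 9]

(* Hypothetical helper: check every row, column and 3*3 box of a
   completed grid (assumed to be a 9*9 array of ints). *)
let valid_solution (g : int array array) : bool =
  let idx = [0; 1; 2; 3; 4; 5; 6; 7; 8] in
  let row i = List.map (fun j -> g.(i).(j)) idx in
  let col j = List.map (fun i -> g.(i).(j)) idx in
  let box n =
    List.map (fun k -> g.(n / 3 * 3 + k / 3).(n mod 3 * 3 + k mod 3)) idx in
  List.for_all
    (fun n ->
       is_permutation (row n) && is_permutation (col n) && is_permutation (box n))
    idx
```

Sorting each group and comparing it with 1..9 checks "all numbers present" and "no digit twice" at the same time.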

(Figure: a Sudoku puzzle)

This page describes how to solve Sudoku in the OCaml programming language.

Strategy for solving Sudoku

There are several strategies for solving Sudoku. Here, I adopt a simple one:

  1. Fill each empty cell with the candidate numerals 1 to 9.
  2. Eliminate from each cell the candidates that are already decided in the same row, column, or box.
  3. Repeat step 2 until every cell has exactly one candidate numeral left.

Programming in OCaml

Definition of Sudoku grid

Each cell is represented as an int list holding its candidate numerals. A 9*9 grid is represented as a 2-d array of references to cells, so that a cell can be updated in place.

type cell = int list
type matrix = cell ref array array
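
As a small illustration of these types, the sketch below builds two cells by hand. The names all_candidates and decided are hypothetical helpers that do not appear in the solver itself.

```ocaml
type cell = int list
type matrix = cell ref array array

(* Hypothetical helpers for illustration only. *)
let all_candidates : cell = [1; 2; 3; 4; 5; 6; 7; 8; 9]
let decided (c : cell) : bool = List.length c = 1

let () =
  let open_cell = ref all_candidates in  (* an empty cell: anything goes *)
  let given = ref [5] in                 (* a cell whose numeral is known *)
  assert (not (decided !open_cell));
  assert (decided !given);
  (* Eliminating a candidate updates the cell in place through the ref. *)
  open_cell := List.filter (fun n -> n <> 5) !open_cell;
  assert (List.length !open_cell = 8)
```

Storing refs in the matrix is what lets a row, column, or box share its cells with the grid later on.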

This representation is convenient for the program but awkward for input, so we need a simpler representation for Sudoku problems. One solution is a list of strings, one string per row. For example,

let p1 = 
  [
    "58??26???";
    "?6???7??2";
    "?7?????61";
    "2???1935?";
    "???3?4???";
    "?5386???4";
    "89?????4?";
    "7??4???1?";
    "???68??37"
  ]

where '?' means an empty cell. Now, we want to construct a Sudoku matrix from this simpler representation. A constructor for the matrix is defined as follows:

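let create s =
  (* the shared placeholder ref is overwritten for every cell below *)
  let m = Array.make_matrix 9 9 (ref []) in
    for i = 0 to 8 do
      for j = 0 to 8 do
        (* digit value of the character; '?' falls outside 1..9 *)
        let n = Char.code (List.nth s i).[j] - Char.code '0' in
          m.(i).(j) <- if n <= 9 && n >= 1 then ref [n] else ref [1;2;3;4;5;6;7;8;9]
      done
    done;
    m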
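
For comparison, here is a sketch of the same constructor written with Array.init. It is a variant of mine, not the code used on this page: candidates_of_char and create_init are hypothetical names, and the input is assumed to be nine strings of nine characters each.

```ocaml
(* Hypothetical helper: the candidate list for one input character.
   Digits '1'..'9' are givens; any other character is an empty cell. *)
let candidates_of_char (ch : char) : int list =
  let n = Char.code ch - Char.code '0' in
  if n >= 1 && n <= 9 then [n] else [1; 2; 3; 4; 5; 6; 7; 8; 9]

(* Variant constructor (an assumption, not the original code): Array.init
   allocates a fresh ref per cell, so no placeholder ref is ever shared. *)
let create_init (s : string list) : int list ref array array =
  let rows = Array.of_list s in
  Array.init 9 (fun i ->
    Array.init 9 (fun j -> ref (candidates_of_char rows.(i).[j])))

let () =
  let m = create_init
      [ "58??26???"; "?6???7??2"; "?7?????61";
        "2???1935?"; "???3?4???"; "?5386???4";
        "89?????4?"; "7??4???1?"; "???68??37" ] in
  assert (!(m.(0).(0)) = [5]);              (* a given numeral *)
  assert (List.length !(m.(0).(2)) = 9)     (* an empty cell *)
```

Converting the string list to an array first also avoids calling List.nth once per cell.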

The pretty-printer for the matrix is defined as follows:

let print matrix =
  for i = 0 to Array.length matrix - 1 do
    (* horizontal border above every third row *)
    if i mod 3 = 0 then
      print_string "+-----+-----+-----+\n";
    for j = 0 to Array.length matrix.(i) - 1 do
      print_string (if j mod 3 = 0 then "|" else " ");
      (* a decided cell prints its numeral, any other cell prints '?' *)
      ( match !(matrix.(i).(j)) with
            [n] -> print_int n
          | _ -> print_string "?" )
    done;
    print_string "|\n"
  done;
  print_string "+-----+-----+-----+\n"

Accessor to get each row, column, and box

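Next, we need accessor functions to get each group of cells (row, column, and box) from the matrix. Each function returns an array of the same ref cells that the matrix holds, so updating a cell through a group also updates the matrix.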

let row m nth = let r = Array.make 9 (ref []) in
  for i = 0 to 8 do
    r.(i) <- m.(nth).(i)
  done;
  r

let col m nth = let c = Array.make 9 (ref []) in
  for i = 0 to 8 do
    c.(i) <- m.(i).(nth)
  done;
  c

let box m nth =
  let (r,c) = (nth / 3, nth mod 3) in
  let b = Array.make 9 (ref []) in
    for i = 0 to 2 do
      for j = 0 to 2 do
        b.(i*3+j) <- m.(r*3+i).(c*3+j)
      done
    done;
    b
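
Two details of the accessors may be easier to see in a small sketch: which coordinates a box number covers, and the fact that a group shares its ref cells with the matrix. The helper box_coords is hypothetical, and row is rewritten here with Array.init only to keep the example short.

```ocaml
(* Hypothetical helper: the (row, column) coordinates covered by box nth,
   in the same order in which `box` fills its array. *)
let box_coords (nth : int) : (int * int) list =
  let (r, c) = (nth / 3, nth mod 3) in
  List.init 9 (fun k -> (r * 3 + k / 3, c * 3 + k mod 3))

(* Row accessor rewritten with Array.init; it copies the refs, not the lists. *)
let row m nth = Array.init 9 (fun i -> m.(nth).(i))

let () =
  (* box 4 is the centre box, so it starts at row 3, column 3 *)
  assert (List.hd (box_coords 4) = (3, 3));
  let m = Array.init 9 (fun _ -> Array.init 9 (fun _ -> ref [1; 2; 3])) in
  let r = row m 0 in
  r.(4) := [7];                 (* write through the group ... *)
  assert (!(m.(0).(4)) = [7])   (* ... and the matrix sees the change *)
```

This sharing is what allows the elimination step to work on one group at a time without copying results back into the grid.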

Judgement of grid

We need a judgement function that checks whether every cell in the grid holds exactly one numeral. If so, the puzzle has been solved.

The function 'unsolved' returns the number of cells that do not yet hold exactly one candidate numeral; the grid is solved when it returns 0.

let unsolved m = let n = ref 0 in
  for i = 0 to 8 do
    for j = 0 to 8 do
      (* count every cell that does not hold exactly one candidate *)
      if (List.length !(m.(i).(j))) <> 1 then incr n
    done
  done;
  !n
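
If only a yes/no answer is needed, the same test can be sketched as a boolean predicate. The name is_solved is hypothetical, and Array.for_all is assumed to be available (OCaml 4.03 or later).

```ocaml
(* Hypothetical variant of the judgement function: true when every cell
   of the matrix holds exactly one candidate numeral. *)
let is_solved (m : int list ref array array) : bool =
  Array.for_all
    (fun row -> Array.for_all (fun c -> List.length !c = 1) row)
    m

let () =
  let m = Array.init 9 (fun _ -> Array.init 9 (fun _ -> ref [1])) in
  assert (is_solved m);
  m.(3).(3) := [1; 2];          (* one cell becomes undecided again *)
  assert (not (is_solved m))
```

Unlike a counter, this version stops at the first undecided cell it finds.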

Candidate elimination

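Candidate elimination removes from each undecided cell of a group every numeral that is already decided in another cell of the same group. If a cell is left with no candidate at all, the puzzle cannot be solved and the exception Unsolvable is raised.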

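exception Unsolvable

let eliminate s =
  (* collect the numerals of all decided (single-candidate) cells *)
  let fixed s =
    Array.fold_left
      (fun ls n -> if List.length !n = 1 then !n @ ls else ls)
      []
      s
  in
  (* elements of ls1 that do not occur in ls2 (in reverse order) *)
  let remove ls1 ls2 =
      List.fold_left (fun ls n -> if List.mem n ls2 then ls else n::ls) [] ls1
  in
  let f = ref (fixed s) in
    for i = 0 to 8 do
      if (List.length !(s.(i))) <> 1 then s.(i) := remove !(s.(i)) !f;
      match (List.length !(s.(i))) with
          0 -> raise Unsolvable     (* no candidate left: contradiction *)
        | 1 -> f := !f @ !(s.(i))   (* remember the decided numeral *)
        | _ -> ()
    done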
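
To see what one elimination pass does, here is a simplified, purely functional sketch on a single group. It is an assumption of mine, not the function above: eliminate_group is a hypothetical name, it works on plain lists instead of ref cells, and it does not take cells decided during the pass into account.

```ocaml
(* Hypothetical pure sketch of one elimination pass over a group. *)
let eliminate_group (group : int list array) : int list array =
  (* numerals that are already decided somewhere in the group *)
  let fixed =
    Array.fold_left
      (fun acc c -> match c with [n] -> n :: acc | _ -> acc) [] group in
  Array.map
    (fun c ->
       match c with
       | [_] -> c                                    (* decided: keep it *)
       | _ -> List.filter (fun n -> not (List.mem n fixed)) c)
    group

let () =
  let g = [| [1]; [2]; [3]; [4]; [5]; [6]; [7]; [8; 9; 1]; [1; 2; 9] |] in
  let g1 = eliminate_group g in
  assert (g1.(7) = [8; 9]);     (* 1 was removed *)
  assert (g1.(8) = [9]);        (* 1 and 2 were removed: now decided *)
  let g2 = eliminate_group g1 in
  assert (g2.(7) = [8])         (* the new 9 removes the last doubt *)
```

The second pass shows why elimination has to be repeated: a cell decided in one pass can unlock further eliminations in the next.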

Solver

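Finally, we can define a solver function as follows. Note that the loop ends only when elimination alone decides every cell; on a puzzle that needs guessing, it would run forever.
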
let solve m = 
  let iter = ref 0 in
  let solve0 m =
    for i = 0 to 8 do
      eliminate (row m i);
      eliminate (col m i);
      eliminate (box m i)
    done
  in
    while (unsolved m) > 0 do
      solve0 m;
      incr iter
    done;
    print m;
    Printf.printf "iter: %d\n" !iter
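
One possible way to make the loop stop on puzzles that elimination cannot finish is to watch the total number of candidates. The following is a self-contained sketch under my own assumptions, not the solver above: groups, total_candidates and solve_until_stuck are hypothetical names, and create and eliminate are compact rewrites of the functions on this page.

```ocaml
exception Unsolvable

type matrix = int list ref array array

(* Build a matrix from nine strings; non-digit characters are empty cells. *)
let create (s : string list) : matrix =
  let rows = Array.of_list s in
  Array.init 9 (fun i -> Array.init 9 (fun j ->
    let n = Char.code rows.(i).[j] - Char.code '0' in
    ref (if n >= 1 && n <= 9 then [n] else [1; 2; 3; 4; 5; 6; 7; 8; 9])))

(* All 27 groups (rows, columns, boxes), sharing their refs with m. *)
let groups (m : matrix) : int list ref array list =
  List.concat (List.map (fun n ->
    [ Array.init 9 (fun k -> m.(n).(k));
      Array.init 9 (fun k -> m.(k).(n));
      Array.init 9 (fun k -> m.(n / 3 * 3 + k / 3).(n mod 3 * 3 + k mod 3)) ])
    (List.init 9 (fun i -> i)))

(* Remove decided numerals from every undecided cell of one group. *)
let eliminate (g : int list ref array) : unit =
  Array.iter
    (fun c ->
       if List.length !c <> 1 then begin
         let fixed =
           Array.fold_left
             (fun acc d -> match !d with [n] -> n :: acc | _ -> acc) [] g in
         c := List.filter (fun n -> not (List.mem n fixed)) !c;
         if !c = [] then raise Unsolvable
       end)
    g

let total_candidates (m : matrix) : int =
  Array.fold_left
    (fun acc row -> Array.fold_left (fun a c -> a + List.length !c) acc row)
    0 m

(* true if every cell gets decided, false if a full pass removes nothing. *)
let solve_until_stuck (m : matrix) : bool =
  let rec loop before =
    List.iter eliminate (groups m);
    let after = total_candidates m in
    if after = 81 then true
    else if after = before then false
    else loop after
  in
  loop (total_candidates m)
```

For example, solve_until_stuck (create p1) could replace the while loop; a false result means that elimination alone is not enough and a stronger strategy, such as guessing with backtracking, would be needed.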

Copyright (c) 2008 by Kenta Hattori