3 min read

Solving Squares Sudoku using Z3

Table of Contents

Squares Sudoku

In addition to the normal Sudoku rules, a Squares Sudoku puzzle has one extra rule: the sum of the numbers in each cage must be a perfect square.

Squares Sudoku puzzle

Here is a hard Squares Sudoku puzzle:

Solution using Z3

from z3 import Solver, And, Int, Distinct, sat, If, Or

# Each cage is a list of zero-based (row, column) cells; the digits in a
# cage must add up to a perfect square.  Cells that do not appear in any
# cage are constrained only by the ordinary Sudoku rules.
puzzle = [
    [(0, 0), (0, 1)],
    [(0, 2), (0, 3)],
    [(0, 4), (1, 3), (1, 4)],
    [(0, 5), (1, 5), (0, 6)],
    [(0, 7), (1, 7), (1, 6), (2, 6)],
    [(0, 8), (1, 8), (2, 8)],
    [(1, 0), (2, 0), (3, 0)],
    [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2)],
    [(2, 3), (2, 4), (2, 5)],
    [(3, 6), (3, 7), (2, 7)],
    [(4, 0), (4, 1), (4, 2), (5, 0)],
    [(4, 6), (5, 6), (5, 7)],
    [(4, 7), (4, 8), (3, 8)],
    [(5, 8), (6, 8)],
    [(6, 0), (7, 0)],
    [(6, 1), (6, 2)],
    [(6, 3), (6, 4), (7, 3), (7, 4)],
    [(6, 5), (7, 5), (8, 5)],
    [(6, 6), (6, 7)],
    [(7, 1), (7, 2)],
    [(7, 6), (7, 7), (7, 8)],
    [(8, 0), (8, 1), (8, 2)],
    [(8, 3), (8, 4)],
    [(8, 6), (8, 7), (8, 8)],
]

def print_grid(mod, x, rows, cols):
    """Print a ``rows`` x ``cols`` grid of values evaluated in a solver model.

    Args:
        mod: object exposing ``eval`` (in this script, a z3 model) that is
            applied to every cell before printing.
        x: grid of expressions indexed as ``x[row][col]``.
        rows: number of rows to print.
        cols: number of columns to print in each row.

    Each row is printed on its own line, with cells separated by two spaces.
    """
    for r in range(rows):
        row = x[r]
        line = "  ".join(str(mod.eval(row[c])) for c in range(cols))
        print(line)

def solveSqudoku(puzzle, n):
    """Solve a Squares Sudoku with Z3, print the solution and return it.

    On top of the ordinary Sudoku rules (each row, column and box holds
    every digit 1..n exactly once), the digits in every cage must add up
    to a perfect square.

    Args:
        puzzle: list of cages; each cage is a list of zero-based
            ``(row, col)`` cells.
        n: side length of the grid. It must be a positive perfect square
            (4, 9, 16, ...) so the grid splits into sqrt(n) x sqrt(n) boxes.

    Returns:
        The solved grid as a list of ``n`` lists of ``n`` ints, or ``None``
        when the solver does not report ``sat``.

    Raises:
        ValueError: if ``n`` is not a positive perfect square, or a cage
            names a cell outside the grid.
    """
    # Box side length; n ** 0.5 is only taken for positive n (it would be
    # complex for negative n).
    box = int(round(n ** 0.5)) if n > 0 else 0
    if box < 1 or box * box != n:
        raise ValueError("grid size must be a positive perfect square, got %r" % (n,))

    X = [[Int("x_%s_%s" % (i + 1, j + 1)) for j in range(n)] for i in range(n)]

    # each cell contains a value in {1, ..., n}
    cells_c = [And(1 <= X[i][j], X[i][j] <= n) for i in range(n) for j in range(n)]

    # each row contains a digit at most once
    rows_c = [Distinct(X[i]) for i in range(n)]

    # each column contains a digit at most once
    cols_c = [Distinct([X[i][j] for i in range(n)]) for j in range(n)]

    # each box x box square contains a digit at most once
    sq_c = [Distinct([X[box * i0 + i][box * j0 + j]
                      for i in range(box) for j in range(box)])
            for i0 in range(box) for j0 in range(box)]

    # sum of numbers in each cage is a perfect square
    puzz_c = []
    for cage in puzzle:
        if not cage:
            # An empty cage sums to 0 = 0**2, so it adds no constraint.
            continue
        for i, j in cage:
            # Reject out-of-range cells explicitly: negative indices would
            # otherwise silently wrap around to the other side of the grid.
            if not (0 <= i < n and 0 <= j < n):
                raise ValueError("cage cell %r is outside the %dx%d grid" % ((i, j), n, n))
        cs = sum([X[i][j] for i, j in cage])
        # Every cell is at most n, so the cage sum cannot exceed
        # len(cage) * n; allow every perfect square up to that bound.
        max_sum = len(cage) * n
        squares = [k * k for k in range(1, max_sum + 1) if k * k <= max_sum]
        puzz_c.append(Or([cs == sq for sq in squares]))

    squdoku_c = cells_c + rows_c + cols_c + [And(puzz_c)] + sq_c

    s = Solver()
    s.add(squdoku_c)
    if s.check() != sat:
        print("Failed to solve the puzzle")
        return None

    m = s.model()
    print("Here is the solution")
    print_grid(m, X, n, n)
    return [[m.eval(X[i][j], model_completion=True).as_long() for j in range(n)]
            for i in range(n)]

# Solve the 9x9 puzzle defined above and print the resulting grid.
solveSqudoku(puzzle=puzzle, n=9)

Here is the solution:

6 3 4 5 9 1 8 7 2
8 2 1 3 4 7 5 9 6
5 7 9 8 2 6 4 3 1
3 6 7 2 1 8 9 4 5
1 9 2 7 5 4 6 8 3
4 5 8 9 6 3 1 2 7
7 4 5 6 8 2 3 1 9
9 1 3 4 7 5 2 6 8
2 8 6 1 3 9 7 5 4