Rekursiver Sudoku-Löser mit Python
Ein Sudoku-Löser, der rekursiv arbeitet. Ich würde mich über Ihre Kommentare zu Codierungsstil, -struktur und Verbesserungsmöglichkeiten freuen. Vielen Dank für Ihre Zeit.
Codestruktur
Der Solver akzeptiert 81 Zeichen für die Eingabe des Sudoku-Puzzles. Nullen werden als leere Zellen genommen. Es analysiert es in ein 9x9 Numpy-Array.
Die get_candidatesFunktion erstellt Listen möglicher Ziffern, um jede Zelle gemäß den Sudoku-Regeln zu füllen (keine Wiederholung von 1-9 Ziffern entlang von Zeilen, Spalten und 3x3-Teilgittern).
Die Hauptlöserfunktion ist solve. Erstens werden falsche Kandidaten mit der filter-candidatesFunktion verworfen . "Falsche Kandidaten" sind solche, die, wenn sie in eine leere Zelle gefüllt werden, dazu führen, dass eine andere Zelle keine Kandidaten mehr an anderer Stelle im Sudoku-Raster hat.
Wird nach dem Filtern von Kandidaten fill_singlesaufgerufen, leere Zellen zu füllen, die nur noch einen Kandidaten haben. Wenn dieser Prozess zu einem vollständig gefüllten Sudoku-Gitter führt, wird es als Lösung zurückgegeben. Es gibt eine zurückzugebende Klausel, mit Noneder Änderungen durch die make_guessFunktion zurückverfolgt werden . Diese Funktion füllt die nächste leere Zelle mit der geringsten Anzahl von Kandidaten mit einem ihrer Kandidaten, einem "Schätzwert". Es wird dann rekursiv aufgerufen solve, entweder eine Lösung zu finden oder ein Raster ohne Lösung zu erreichen (in diesem Fall wird solvezurückgegeben Noneund die letzten Vermutungsänderungen werden zurückgesetzt).
from copy import deepcopy
import numpy as np
def create_grid(puzzle_str: str) -> np.ndarray:
"""Create a 9x9 Sudoku grid from a string of digits"""
# Deleting whitespaces and newlines (\n)
lines = puzzle_str.replace(' ','').replace('\n','')
digits = list(map(int, lines))
# Turning it to a 9x9 numpy array
grid = np.array(digits).reshape(9,9)
return grid
def get_subgrids(grid: np.ndarray) -> np.ndarray:
"""Divide the input grid into 9 3x3 sub-grids"""
subgrids = []
for box_i in range(3):
for box_j in range(3):
subgrid = []
for i in range(3):
for j in range(3):
subgrid.append(grid[3*box_i + i][3*box_j + j])
subgrids.append(subgrid)
return np.array(subgrids)
def get_candidates(grid : np.ndarray) -> list:
"""Get a list of candidates to fill empty cells of the input grid"""
def subgrid_index(i, j):
return (i//3) * 3 + j // 3
subgrids = get_subgrids(grid)
grid_candidates = []
for i in range(9):
row_candidates = []
for j in range(9):
# Row, column and subgrid digits
row = set(grid[i])
col = set(grid[:, j])
sub = set(subgrids[subgrid_index(i, j)])
common = row | col | sub
candidates = set(range(10)) - common
# If the case is filled take its value as the only candidate
if not grid[i][j]:
row_candidates.append(list(candidates))
else:
row_candidates.append([grid[i][j]])
grid_candidates.append(row_candidates)
return grid_candidates
def is_valid_grid(grid : np.ndarray) -> bool:
"""Verify the input grid has a possible solution"""
candidates = get_candidates(grid)
for i in range(9):
for j in range(9):
if len(candidates[i][j]) == 0:
return False
return True
def is_solution(grid : np.ndarray) -> bool:
"""Verify if the input grid is a solution"""
if np.all(np.sum(grid, axis=1) == 45) and \
np.all(np.sum(grid, axis=0) == 45) and \
np.all(np.sum(get_subgrids(grid), axis=1) == 45):
return True
return False
def filter_candidates(grid : np.ndarray) -> list:
"""Filter input grid's list of candidates"""
test_grid = grid.copy()
candidates = get_candidates(grid)
filtered_candidates = deepcopy(candidates)
for i in range(9):
for j in range(9):
# Check for empty cells
if grid[i][j] == 0:
for candidate in candidates[i][j]:
# Use test candidate
test_grid[i][j] = candidate
# Remove candidate if it produces an invalid grid
if not is_valid_grid(fill_singles(test_grid)):
filtered_candidates[i][j].remove(candidate)
# Revert changes
test_grid[i][j] = 0
return filtered_candidates
def merge(candidates_1 : list, candidates_2 : list) -> list:
"""Take shortest candidate list from inputs for each cell"""
candidates_min = []
for i in range(9):
row = []
for j in range(9):
if len(candidates_1[i][j]) < len(candidates_2[i][j]):
row.append(candidates_1[i][j][:])
else:
row.append(candidates_2[i][j][:])
candidates_min.append(row)
return candidates_min
def fill_singles(grid : np.ndarray, candidates=None) -> np.ndarray:
"""Fill input grid's cells with single candidates"""
grid = grid.copy()
if not candidates:
candidates = get_candidates(grid)
any_fill = True
while any_fill:
any_fill = False
for i in range(9):
for j in range(9):
if len(candidates[i][j]) == 1 and grid[i][j] == 0:
grid[i][j] = candidates[i][j][0]
candidates = merge(get_candidates(grid), candidates)
any_fill = True
return grid
def make_guess(grid : np.ndarray, candidates=None) -> np.ndarray:
"""Fill next empty cell with least candidates with first candidate"""
grid = grid.copy()
if not candidates:
candidates = get_candidates(grid)
# Getting the shortest number of candidates > 1:
min_len = sorted(list(set(map(
len, np.array(candidates).reshape(1,81)[0]))))[1]
for i in range(9):
for j in range(9):
if len(candidates[i][j]) == min_len:
for guess in candidates[i][j]:
grid[i][j] = guess
solution = solve(grid)
if solution is not None:
return solution
# Discarding a wrong guess
grid[i][j] = 0
def solve(grid : np.ndarray) -> np.ndarray:
"""Recursively find a solution filtering candidates and guessing values"""
candidates = filter_candidates(grid)
grid = fill_singles(grid, candidates)
if is_solution(grid):
return grid
if not is_valid_grid(grid):
return None
return make_guess(grid, candidates)
# # Example usage
# puzzle = """100920000
# 524010000
# 000000070
# 050008102
# 000000000
# 402700090
# 060000000
# 000030945
# 000071006"""
# grid = create_grid(puzzle)
# solve(grid)
```
Antworten
Ich konnte die Leistung des Programms um ungefähr 900% verbessern, ohne einen Großteil des Algorithmus in ungefähr einer Stunde zu verstehen oder zu ändern. Folgendes habe ich getan:
Zunächst benötigen Sie einen Benchmark. Es ist sehr einfach, nur mal dein Programm
start = time.time()
solve(grid)
print(time.time()-start)
Auf meinem Computer dauerte es ungefähr 4,5 Sekunden. Dies ist unsere Basis.
Das nächste ist das Profil. Das Tool, das ich ausgewählt habe, ist VizTracer, das von mir selbst entwickelt wurde :)https://github.com/gaogaotiantian/viztracer
VizTracer generiert einen HTML-Bericht (oder json, der durch chrome :: // tracing geladen werden könnte) der Zeitachse Ihrer Codeausführung. In Ihrer Originalversion sieht es so aus:
Wie Sie sehen, gibt es dort viele Anrufe. Wir müssen herausfinden, was hier der Engpass ist. Die Struktur ist nicht kompliziert, viele fill_singleswerden aufgerufen, und wir müssen zoomen, um zu überprüfen, was dort drin ist.
Es ist sehr klar, dass dies get_candidatesdie Funktion ist, die die meiste Zeit in verursacht hat fill_singlesund die den größten Teil der Zeitachse belegt. Das ist also die Funktion, die wir uns zuerst ansehen wollen.
def get_candidates(grid : np.ndarray) -> list:
"""Get a list of candidates to fill empty cells of the input grid"""
def subgrid_index(i, j):
return (i//3) * 3 + j // 3
subgrids = get_subgrids(grid)
grid_candidates = []
for i in range(9):
row_candidates = []
for j in range(9):
# Row, column and subgrid digits
row = set(grid[i])
col = set(grid[:, j])
sub = set(subgrids[subgrid_index(i, j)])
common = row | col | sub
candidates = set(range(10)) - common
# If the case is filled take its value as the only candidate
if not grid[i][j]:
row_candidates.append(list(candidates))
else:
row_candidates.append([grid[i][j]])
grid_candidates.append(row_candidates)
return grid_candidates
Das, was mir zuerst aufgefallen ist, war das Ende Ihrer verschachtelten for-Schleife. Sie haben geprüft, ob grid[i][j]gefüllt ist. Wenn ja, dann ist das der einzige Kandidat. Wenn es jedoch gefüllt ist, hat es nichts damit zu tun candidates, was Sie in Ihrer verschachtelten for-Schleife sehr hart berechnet haben.
Als erstes habe ich den Scheck an den Anfang der for-Schleife verschoben.
for i in range(9):
row_candidates = []
for j in range(9):
if grid[i][j]:
row_candidates.append([grid[i][j]])
continue
# Row, column and subgrid digits
row = set(grid[i])
col = set(grid[:, j])
sub = set(subgrids[subgrid_index(i, j)])
common = row | col | sub
candidates = set(range(10)) - common
row_candidates.append(list(candidates))
Allein diese Optimierung hat die Laufzeit halbiert, wir sind jetzt bei ca. 2,3s.
Dann habe ich festgestellt, dass Sie in Ihrer verschachtelten for-Schleife viele redundante Set-Operationen ausführen. Sogar row / col / sub muss nur 9 Mal berechnet werden, Sie berechnen es 81 Mal, was ziemlich schlecht ist. Also habe ich die Berechnung aus der for-Schleife verschoben.
def get_candidates(grid : np.ndarray) -> list:
"""Get a list of candidates to fill empty cells of the input grid"""
def subgrid_index(i, j):
return (i//3) * 3 + j // 3
subgrids = get_subgrids(grid)
grid_candidates = []
row_sets = [set(grid[i]) for i in range(9)]
col_sets = [set(grid[:, j]) for j in range(9)]
subgrid_sets = [set(subgrids[i]) for i in range(9)]
total_sets = set(range(10))
for i in range(9):
row_candidates = []
for j in range(9):
if grid[i][j]:
row_candidates.append([grid[i][j]])
continue
# Row, column and subgrid digits
row = row_sets[i]
col = col_sets[j]
sub = subgrid_sets[subgrid_index(i, j)]
common = row | col | sub
candidates = total_sets - common
# If the case is filled take its value as the only candidate
row_candidates.append(list(candidates))
grid_candidates.append(row_candidates)
return grid_candidates
Dies verkürzte die Laufzeit auf ca. 1,5 s. Beachten Sie, dass ich noch nicht versucht habe, Ihren Algorithmus zu verstehen. Ich habe nur VizTracer verwendet, um die zu optimierende Funktion zu finden und eine Transformation mit derselben Logik durchzuführen. Ich habe die Leistung in nur 15 Minuten um etwa 300% verbessert.
Bis zu diesem Punkt ist der Overhead von VizTracer in der WSL erheblich, daher habe ich den C-Funktions-Trace deaktiviert. Es waren nur noch Python-Funktionen übrig und der Overhead betrug ca. 10%.
Jetzt get_candidateswurde das verbessert (obwohl es besser gemacht werden kann), wir müssen ein größeres Bild davon machen. Was ich am Ergebnis von VizTracer beobachten kann, war, dass sehr häufig fill_singlesangerufen wurde get_candidates, einfach zu viele Anrufe. (Dies ist etwas, das auf cProfiler schwer zu bemerken ist)
Der nächste Schritt war also herauszufinden, ob wir weniger oft fill_singlestelefonieren können get_candidates. Hier erfordert es ein gewisses Maß an Algorithmusverständnis.
while any_fill:
any_fill = False
for i in range(9):
for j in range(9):
if len(candidates[i][j]) == 1 and grid[i][j] == 0:
grid[i][j] = candidates[i][j][0]
candidates = merge(get_candidates(grid), candidates)
any_fill = True
Es sieht so aus, als hätten Sie hier versucht, eine Lücke mit nur einem Kandidaten auszufüllen, die Kandidaten des gesamten Rasters neu zu berechnen und dann die nächste Lücke mit einem Kandidaten zu finden. Dies ist eine gültige Methode, die jedoch zu viele Aufrufe verursacht hat get_candidates. Wenn Sie darüber nachdenken, wenn wir eine Lücke mit einer Nummer ausfüllen n, sind alle anderen Lücken mit nur einem Kandidaten, der nicht nbetroffen ist, nicht betroffen. Während eines Durchgangs des Gitters könnten wir also tatsächlich versuchen, mehr Lücken auszufüllen, solange wir nicht zweimal dieselbe Zahl ausfüllen. Auf diese Weise können wir get_candidatesweniger häufig anrufen , was ein großer Zeitverbraucher ist. Ich habe dazu ein Set benutzt.
filled_number = set()
for i in range(9):
for j in range(9):
if len(candidates[i][j]) == 1 and grid[i][j] == 0 and candidates[i][j][0] not in filled_number:
grid[i][j] = candidates[i][j][0]
filled_number.add(candidates[i][j][0])
any_fill = True
candidates = merge(get_candidates(grid), candidates)
Dies brachte die Laufzeit auf 0,9 s.
Dann habe ich mir den VizTracer-Bericht angesehen und festgestellt, dass er fill_singlesfast immer von aufgerufen wird. filter_candidatesDas einzige, was mich filter_candidatesinteressiert, ist, ob fill_singlesein gültiges Raster zurückgegeben wird. Dies ist eine Information, die wir möglicherweise frühzeitig kennen, solange wir fill_singleseine Stelle ohne Kandidaten finden. Wenn wir früh zurückkehren, müssen wir nicht so get_candidatesoft rechnen .
Also habe ich die Codestruktur ein wenig geändert und fill_singleszurückgegeben, Nonewenn kein gültiges Raster gefunden werden kann.
Endlich konnte ich die Laufzeit auf 0,5 s einstellen, was 900% schneller ist als die Originalversion.
Es war tatsächlich ein lustiges Abenteuer, weil ich mein Projekt VizTracer getestet und versucht habe herauszufinden, ob es hilfreich ist, den zeitaufwändigen Teil zu finden. Es hat gut funktioniert :)
Numpyifizierung
get_subgridsordnet im Wesentlichen ein Numpy-Array mit einem Minimum an Numpy neu an. Es könnte mit numpy selbst gemacht werden, zum Beispiel:
def get_subgrids(grid: np.ndarray) -> np.ndarray:
"""Divide the input grid into 9 3x3 sub-grids"""
swapped = np.swapaxes(np.reshape(grid, (3, 3, 3, 3)), 1, 2)
return np.reshape(swapped, (9, 9))
Der Nachteil, den ich vermute, ist, dass das Vertauschen der beiden mittleren Achsen eines 4D-Arrays etwas umwerfend ist.
Performance
Fast die gesamte Zeit wird in verbracht get_candidates. Ich denke, die Gründe dafür sind hauptsächlich:
- Es wird zu oft aufgerufen. Wenn Sie beispielsweise eine Zelle (z. B. in
fill_singles) ausgefüllt haben, anstatt die Kandidaten von Grund auf neu zu berechnen, ist es schneller, den neuen Wert lediglich aus den Kandidaten in derselben Zeile / Spalte / Haus zu entfernen. - Wenn eine Zelle gefüllt ist, ist die Liste der Kandidaten nur der ausgefüllte Wert, aber die teure Mengenberechnung wird trotzdem durchgeführt. Das lässt sich leicht vermeiden, indem Sie diese Aussage in das verschieben
if.
Algorithmische Leistung
Dieser Solver verwendet Naked Singles nur als "Propagationstechnik". Das Hinzufügen von Hidden Singles ist meiner Erfahrung nach ein sehr großer Schritt in Richtung eines effizienten Solvers.