Files
nd-wfc/demos/sudoku/sudoku.h
2025-08-25 13:08:32 +09:00

236 lines
7.7 KiB
C++

#pragma once
#include <string>
#include <vector>
#include <optional>
#include <cstdint>
#include <array>
#include <cassert>
#include <chrono>
// 4-bit packed Sudoku board storage - optimal packing
// 81 cells * 4 bits = 324 bits
// Each byte holds 2 cells (8 bits / 4 bits per cell = 2)
// 81 cells / 2 = 40.5 bytes -> 41 bytes total (with 4 bits unused)
class SudokuBoardStorage {
public:
    /// Number of addressable cells (positions 0-80).
    static constexpr int kCellCount = 81;

    /// Packed nibbles: even positions live in the HIGH nibble of byte pos/2,
    /// odd positions in the LOW nibble. Zero-initialized so a freshly
    /// constructed storage reads as all zeros instead of indeterminate bytes.
    std::array<uint8_t, (kCellCount + 1) / 2> data{};

    /// Returns the 4-bit value stored at `pos` (0-80).
    /// Only shifts and masks are used: pos >> 1 selects the byte, and the
    /// nibble shift is 4 for even positions and 0 for odd positions.
    uint8_t get(int pos) const {
        assert(isValidPos(pos) && "SudokuBoardStorage position must be between 0-80");
        const uint8_t result =
            static_cast<uint8_t>((data[byteIndex(pos)] >> nibbleShift(pos)) & 0xF);
        // Debug check: only 0-9 should ever have been written.
        assert(result <= 9 && "Sudoku cell value must be between 0-9");
        return result;
    }

    /// Stores `value` (0-9) at `pos` (0-80), leaving the other cell that
    /// shares the same byte untouched.
    void set(int pos, uint8_t value) {
        assert(isValidPos(pos) && "SudokuBoardStorage position must be between 0-80");
        assert(value <= 9 && "Sudoku cell value must be between 0-9");
        const int shift = nibbleShift(pos);
        uint8_t& cellByte = data[byteIndex(pos)];
        // Clear the target nibble, then insert the new one. Masking `value`
        // with 0xF guarantees that an out-of-range value (possible when
        // asserts are compiled out) can never bleed into the adjacent nibble.
        const uint8_t keepMask = static_cast<uint8_t>(~(0xF << shift));
        cellByte = static_cast<uint8_t>((cellByte & keepMask) | ((value & 0xF) << shift));
    }

    /// Resets every byte (and therefore every cell) to 0.
    void clear() {
        data.fill(0);
    }

private:
    /// True for positions 0-80.
    static constexpr bool isValidPos(int pos) { return pos >= 0 && pos < kCellCount; }
    /// Byte holding `pos` (pos / 2 via right shift; `pos` is non-negative).
    static constexpr int byteIndex(int pos) { return pos >> 1; }
    /// 4 for even positions (high nibble), 0 for odd positions (low nibble).
    static constexpr int nibbleShift(int pos) { return 4 - ((pos & 1) << 2); }
};
// Ultra-memory-efficient Sudoku class: exactly 41 bytes
// The only data member is the 4-bit-packed SudokuBoardStorage; all helpers
// below are non-virtual members and add no per-instance storage.
class Sudoku {
public:
    Sudoku();
    explicit Sudoku(const std::string& puzzle_str);

    // Load from various formats
    bool loadFromString(const std::string& puzzle_str);
    bool loadFromFile(const std::string& filename);

    // Board access (defined in-class, hence implicitly inline)

    /// Returns the value stored at (row, col); 0 means the cell is empty.
    /// @pre 0 <= row < 9 and 0 <= col < 9 (asserted in debug builds).
    uint8_t get(int row, int col) const {
        assert(isValidPosition(row, col) &&
               "Sudoku::get() called with invalid position - row and col must be 0-8");
        return board_.get(getLinearIndex(row, col));
    }

    /// Writes `value` at (row, col) if it keeps the board consistent.
    /// @param value 0 clears the cell; 1-9 places a digit.
    /// @return false if value > 9, or if the digit already appears elsewhere
    ///         in the same row, column, or 3x3 box; true otherwise (including
    ///         when the cell already held `value`).
    /// @pre 0 <= row < 9 and 0 <= col < 9 (asserted in debug builds).
    bool set(int row, int col, uint8_t value) {
        assert(isValidPosition(row, col) &&
               "Sudoku::set() called with invalid position - row and col must be 0-8");
        // Value range is a runtime check (not just an assert) so release
        // builds can never store an out-of-range digit through this API.
        if (value > 9) return false;
        const int linearIndex = getLinearIndex(row, col);
        // Writing the value the cell already holds is a no-op.
        if (board_.get(linearIndex) == value) return true;
        // Clearing (0) is always allowed; a digit must not conflict.
        if (value != 0 && !isValidMove(row, col, value)) {
            return false;
        }
        board_.set(linearIndex, value);
        return true;
    }

    void clear();

    // Validation
    bool isValid() const;
    bool isSolved() const;

    /// True if digit `value` (1-9) could sit at (row, col) without repeating
    /// in its row, column, or 3x3 box. The cell (row, col) itself is skipped,
    /// so its current content never conflicts with itself. Returns false for
    /// 0 and for values above 9.
    bool isValidMove(int row, int col, uint8_t value) const {
        if (value == 0 || value > 9) return false;
        return !hasRowConflictExcluding(row, col, value) &&
               !hasColConflictExcluding(col, row, value) &&
               !hasBoxConflictExcluding(getBoxIndex(row, col), row, col, value);
    }

    // Utility
    void print() const;
    std::string toString() const;

    // Convert to standard board format for external use
    std::array<uint8_t, 81> getBoard() const;

private:
    SudokuBoardStorage board_;

    // Row-major cell index: row * 9 + col.
    int getLinearIndex(int row, int col) const {
        return row * 9 + col;
    }
    // 3x3 box index 0-8, numbered left-to-right, top-to-bottom.
    int getBoxIndex(int row, int col) const {
        return (row / 3) * 3 + (col / 3);
    }
    bool isValidPosition(int row, int col) const {
        return row >= 0 && row < 9 && col >= 0 && col < 9;
    }

    // Validation helpers: plain linear scans over the packed board via get().
    bool hasRowConflict(int row, uint8_t value) const {
        for (int col = 0; col < 9; ++col) {
            if (get(row, col) == value) return true;
        }
        return false;
    }
    bool hasColConflict(int col, uint8_t value) const {
        for (int row = 0; row < 9; ++row) {
            if (get(row, col) == value) return true;
        }
        return false;
    }
    bool hasBoxConflict(int box, uint8_t value) const {
        const int startRow = (box / 3) * 3;
        const int startCol = (box % 3) * 3;
        for (int row = 0; row < 3; ++row) {
            for (int col = 0; col < 3; ++col) {
                if (get(startRow + row, startCol + col) == value) return true;
            }
        }
        return false;
    }

    // Same scans, but skipping the cell being tested (for move validation).
    bool hasRowConflictExcluding(int row, int excludeCol, uint8_t value) const {
        for (int col = 0; col < 9; ++col) {
            if (col != excludeCol && get(row, col) == value) return true;
        }
        return false;
    }
    bool hasColConflictExcluding(int col, int excludeRow, uint8_t value) const {
        for (int row = 0; row < 9; ++row) {
            if (row != excludeRow && get(row, col) == value) return true;
        }
        return false;
    }
    bool hasBoxConflictExcluding(int box, int excludeRow, int excludeCol, uint8_t value) const {
        const int startRow = (box / 3) * 3;
        const int startCol = (box % 3) * 3;
        for (int row = 0; row < 3; ++row) {
            for (int col = 0; col < 3; ++col) {
                const int r = startRow + row;
                const int c = startCol + col;
                if ((r != excludeRow || c != excludeCol) && get(r, c) == value) return true;
            }
        }
        return false;
    }

public:  // WFC Support
    using ValueType = uint8_t;

    /// Raw read of cell `index` (row-major, index = row * 9 + col).
    /// @pre index < size() (asserted in debug builds).
    ValueType getValue(size_t index) const {
        assert(index < size() && "Sudoku::getValue() index must be 0-80");
        return board_.get(static_cast<int>(index));
    }

    /// Raw write of cell `index`. Unlike set(), this performs NO Sudoku-rule
    /// (row/column/box) validation; the caller is responsible for consistency.
    /// @pre index < size() (asserted in debug builds).
    void setValue(size_t index, ValueType value) {
        assert(index < size() && "Sudoku::setValue() index must be 0-80");
        board_.set(static_cast<int>(index), value);
    }

    /// Number of cells on the board.
    constexpr size_t size() const {
        return 81;
    }
};
// Compile-time guard on the packed layout: Sudoku must stay exactly 41 bytes
// (81 cells x 4 bits = 324 bits, rounded up to whole bytes). Adding another
// data member to Sudoku would trip this check.
static_assert(sizeof(Sudoku) == 41, "Sudoku class must be exactly 41 bytes");
// Stateless validator: a set of static helpers that inspect a flat 81-cell
// board. All definitions live out of line; the intent notes below are
// inferred from the function names — NOTE(review): confirm against the
// implementations.
class SudokuValidator {
    // Private shorthand for the flat board type. It is the very same type as
    // std::array<uint8_t, 81>, so out-of-line definitions may spell either.
    using Board = std::array<uint8_t, 81>;

public:
    // Presumably: complete board with no rule violations.
    static bool isValidSolution(const Board& board);
    // Presumably: possibly incomplete board with no rule violations so far.
    static bool isValidPartial(const Board& board);
    // Presumably: true when any row/column/box repeats a digit.
    static bool hasConflicts(const Board& board);
    // Presumably: coordinate pairs of the conflicting cells.
    static std::vector<std::pair<int, int>> findConflicts(const Board& board);
};
// Puzzle loader: static factories that build Sudoku objects from text
// sources (a string, a single file, or every matching file in a directory).
// Definitions live out of line.
class SudokuLoader {
public:
    // Returns std::nullopt when no puzzle could be produced.
    static std::optional<Sudoku> fromString(const std::string& puzzle_str);
    static std::optional<Sudoku> fromFile(const std::string& filename);
    // `extension` filters which directory entries are considered.
    static std::vector<Sudoku> fromDirectory(const std::string& dirname,
                                             const std::string& extension = ".txt");

private:
    // Same type as std::array<uint8_t, 81>; shorthand for the parse target.
    using Board = std::array<uint8_t, 81>;

    static bool parseLine(const std::string& line, Board& board);
};