branching

This commit is contained in:
cdemeyer-teachx
2025-08-24 22:40:36 +09:00
parent 6fce648b01
commit 2d4336fc8d
5 changed files with 146 additions and 63 deletions

View File

@@ -87,10 +87,15 @@ if(HAS_GTEST)
# Set test output directory
set_target_properties(sudoku_tests PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
CXX_STANDARD 20
CXX_STANDARD_REQUIRED ON
)
# Include directories for tests
target_include_directories(sudoku_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_include_directories(sudoku_tests PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../../include
)
# Add test to CTest
add_test(NAME sudoku_tests COMMAND sudoku_tests)
@@ -108,10 +113,15 @@ if(HAS_BENCHMARK)
# Set benchmark output directory
set_target_properties(sudoku_benchmarks PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
CXX_STANDARD 20
CXX_STANDARD_REQUIRED ON
)
# Include directories for benchmarks
target_include_directories(sudoku_benchmarks PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_include_directories(sudoku_benchmarks PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../../include
)
endif()
# Installation (optional)

View File

@@ -4,13 +4,13 @@
// Benchmark fixture for Sudoku benchmarks
class SudokuBenchmark : public benchmark::Fixture {
public:
void SetUp(const ::benchmark::State& state) override {
void SetUp(const ::benchmark::State&) override {
// Create test puzzle
testPuzzle = "530070000600195000098000060800060003400803001700020006060000280000419005000080079";
sudoku.loadFromString(testPuzzle);
}
void TearDown(const ::benchmark::State& state) override {
void TearDown(const ::benchmark::State&) override {
// Cleanup if needed
}

View File

@@ -6,6 +6,7 @@
#include <cstdint>
#include <array>
#include <cassert>
#include <chrono>
// 4-bit packed Sudoku board storage - optimal packing
// 81 cells * 4 bits = 324 bits
@@ -209,10 +210,30 @@ public: // WFC Support
return 81;
}
public: // Solver Interface
// Solve the puzzle using WFC algorithm
bool solve();
// Solve with custom initial constraints (for testing)
bool solveWithConstraints(const std::vector<std::pair<int, int>>& constraints);
// Get the number of attempts made during solving
int getSolveAttempts() const { return solve_attempts_; }
// Get the time taken for the last solve operation (in microseconds)
long long getSolveTimeMicroseconds() const { return solve_time_us_; }
private:
mutable int solve_attempts_ = 0;
mutable long long solve_time_us_ = 0;
// Helper method for backtracking solver
bool solveBacktracking(size_t index);
};
// Static assert to ensure exactly 41 bytes
static_assert(sizeof(Sudoku) == 41, "Sudoku class must be exactly 41 bytes");
// Static assert to ensure correct size (now 56 bytes with solver additions)
static_assert(sizeof(Sudoku) == 56, "Sudoku class must be exactly 56 bytes");
// Fast solution validator (stateless)
class SudokuValidator {

View File

@@ -1,5 +1,6 @@
#include <gtest/gtest.h>
#include "sudoku.h"
#include <chrono>
// Test fixture for Sudoku tests
class SudokuTest : public ::testing::Test {
@@ -96,7 +97,7 @@ TEST_F(SudokuTest, Clear) {
}
TEST_F(SudokuTest, MemorySize) {
EXPECT_EQ(sizeof(Sudoku), 41);
EXPECT_EQ(sizeof(Sudoku), 56); // Updated to include solver members
}
TEST_F(SudokuTest, SetInvalidValue) {

View File

@@ -11,6 +11,7 @@
#include <algorithm>
#include <concepts>
#include <bit>
#include <iostream>
namespace WFC {
@@ -148,6 +149,9 @@ public:
for (auto& wave : m_data) wave = (1 << variableAmount) - 1;
}
Wave(const Wave& other) = default;
public:
void Collapse(size_t index, MaskType mask) { m_data[index] &= mask; }
size_t size() const { return m_data.size(); }
size_t Entropy(size_t index) const { return std::popcount(m_data[index]); }
@@ -249,18 +253,19 @@ public:
using MaskType = typename VariableIDMapT::MaskType;
public:
struct WorldSolver {
struct SolverState {
WorldT& world;
std::queue<size_t> propagationQueue;
Wave<MaskType> wave;
std::mt19937 rng;
std::queue<size_t> propagationQueue{};
Wave<MaskType> wave{};
std::mt19937& rng;
WorldSolver(WorldT& world, const std::vector<VariableData<WorldT, VarT, VariableIDMapT>>& variables)
SolverState(WorldT& world, size_t variableAmount, std::mt19937& rng)
: world(world)
, propagationQueue()
, wave(world.size(), variables.size())
, rng(std::random_device{}())
, wave(world.size(), variableAmount)
, rng(rng)
{}
SolverState(const SolverState& other) = default;
};
public:
@@ -271,32 +276,45 @@ public:
public:
bool Run(WorldT& world, bool propagateInitialValues = false)
{
WorldSolver worldSolver(world, m_variables);
return Run(worldSolver, propagateInitialValues);
//auto seed = std::random_device{}();
auto seed = 1844803044ull;
std::mt19937 random{ seed };
SolverState state(world, m_variables.size(), random);
return Run(state, propagateInitialValues);
}
/**
* @brief Run the WFC algorithm to generate a solution
* @return true if a solution was found, false if contradiction occurred
*/
bool Run(WorldSolver& worldSolver, bool propagateInitialValues = false)
bool Run(SolverState& state, bool propagateInitialValues = false)
{
if (propagateInitialValues)
{
PropogateInitialValues(worldSolver);
PropogateInitialValues(state);
}
for (size_t i = 0; i < 1024; ++i)
{
Propagate(worldSolver);
if (RunLoop(state)) {
if (worldSolver.wave.IsFullyCollapsed()) {
PopulateWorld(worldSolver);
PopulateWorld(state);
return true;
} else if (worldSolver.wave.HasContradiction()) {
}
return false;
}
bool RunLoop(SolverState& state)
{
constexpr size_t maxIterations = 1024;
for (size_t i = 0; i < maxIterations; ++i)
{
Propagate(state);
if (state.wave.IsFullyCollapsed()) {
return true;
} else if (state.wave.HasContradiction()) {
return false;
} else {
GetMinEntropyCell(worldSolver);
Branch(state);
}
}
return false;
@@ -307,9 +325,9 @@ public:
* @param cellId The cell ID
* @return The value if collapsed, std::nullopt otherwise
*/
std::optional<VarT> GetValue(WorldSolver& worldSolver, int cellId) const {
if (worldSolver.wave.IsCollapsed(cellId)) {
auto variableId = worldSolver.wave.GetVariableID(cellId);
std::optional<VarT> GetValue(SolverState& state, int cellId) const {
if (state.wave.IsCollapsed(cellId)) {
auto variableId = state.wave.GetVariableID(cellId);
return VariableIDMapT::GetValue(variableId);
}
return std::nullopt;
@@ -320,10 +338,10 @@ public:
* @param cellId The cell ID
* @return Set of possible values
*/
const std::vector<VarT> GetPossibleValues(WorldSolver& worldSolver, int cellId) const
const std::vector<VarT> GetPossibleValues(SolverState& state, int cellId) const
{
std::vector<VarT> possibleValues;
MaskType mask = worldSolver.wave.GetMask(cellId);
MaskType mask = state.wave.GetMask(cellId);
for (size_t i = 0; i < m_variables.size(); ++i) {
if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i));
}
@@ -331,73 +349,106 @@ public:
}
private:
bool GetMinEntropyCell(WorldSolver& worldSolver)
bool Branch(SolverState& state)
{
assert(worldSolver.propagationQueue.empty());
assert(state.propagationQueue.empty());
// Find cell with minimum entropy > 1
size_t minEntropyCell = static_cast<size_t>(-1);
size_t minEntropy = static_cast<size_t>(-1);
for (size_t i = 0; i < worldSolver.wave.size(); ++i) {
size_t entropy = worldSolver.wave.Entropy(i);
for (size_t i = 0; i < state.wave.size(); ++i) {
size_t entropy = state.wave.Entropy(i);
if (entropy > 1 && entropy < minEntropy) {
minEntropy = entropy;
minEntropyCell = i;
}
}
assert(!worldSolver.wave.IsCollapsed(minEntropyCell));
assert(!state.wave.IsCollapsed(minEntropyCell));
// Randomly select a value from possible values
size_t availableValues = worldSolver.wave.Entropy(minEntropyCell);
std::uniform_int_distribution<size_t> dist(0, availableValues - 1);
size_t selectedValue = FindNthSetBit(worldSolver.wave.GetMask(minEntropyCell), dist(worldSolver.rng));
assert(selectedValue < VariableIDMapT::ValuesRegisteredAmount && "Selected Value went outside bounds");
// create a list of possible values
size_t availableValues = state.wave.Entropy(minEntropyCell);
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> possibleValues; // inplace vector
MaskType mask = state.wave.GetMask(minEntropyCell);
for (size_t i = 0; i < availableValues; ++i)
{
uint16_t index = std::countr_zero(mask); // get the index of the lowest set bit
assert(index < VariableIDMapT::ValuesRegisteredAmount && "Possible value went outside bounds");
// Collapse the cell to the selected value
worldSolver.wave.Collapse(minEntropyCell, 1 << selectedValue);
assert(worldSolver.wave.IsCollapsed(minEntropyCell) && "Cell was not collapsed correctly");
possibleValues[i] = index;
assert(((mask & (1 << index)) != 0) && "Possible value was not set");
worldSolver.propagationQueue.push(minEntropyCell);
mask = mask & (mask - 1); // turn off lowest set bit
}
// randomly select a value from possible values
for (size_t i = 0; i < availableValues; ++i)
{
std::uniform_int_distribution<uint16_t> dist(0, availableValues - 1);
uint16_t selectedValue = possibleValues[dist(state.rng)];
{
// copy the state and branch out
SolverState newState(state);
newState.wave.Collapse(minEntropyCell, 1 << selectedValue);
newState.propagationQueue.push(minEntropyCell);
if (RunLoop(newState))
{
// copy the solution to the original state
state.wave = newState.wave;
return true;
}
}
void Propagate(WorldSolver& worldSolver)
// remove the failure state from the wave
assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) != 0 && "Possible value was not set");
state.wave.Collapse(minEntropyCell, ~(1 << selectedValue));
assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) == 0 && "Wave was not collapsed correctly");
// swap replacement value with the last value
std::swap(possibleValues[i], possibleValues[--availableValues]);
}
return false;
}
void Propagate(SolverState& state)
{
while (!worldSolver.propagationQueue.empty())
while (!state.propagationQueue.empty())
{
size_t cellId = worldSolver.propagationQueue.front();
worldSolver.propagationQueue.pop();
size_t cellId = state.propagationQueue.front();
state.propagationQueue.pop();
assert(worldSolver.wave.IsCollapsed(cellId) && "Cell was not collapsed");
assert(state.wave.IsCollapsed(cellId) && "Cell was not collapsed");
uint16_t variableID = worldSolver.wave.GetVariableID(cellId);
Constrainer<VariableIDMapT> constrainer(worldSolver.wave, worldSolver.propagationQueue);
m_variables[variableID].constraintFunc(worldSolver.world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
uint16_t variableID = state.wave.GetVariableID(cellId);
Constrainer<VariableIDMapT> constrainer(state.wave, state.propagationQueue);
m_variables[variableID].constraintFunc(state.world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
}
}
void PopulateWorld(WorldSolver& worldSolver)
void PopulateWorld(SolverState& state)
{
for (size_t i = 0; i < worldSolver.wave.size(); ++i)
for (size_t i = 0; i < state.wave.size(); ++i)
{
worldSolver.world.setValue(i, VariableIDMapT::GetValue(worldSolver.wave.GetVariableID(i)));
state.world.setValue(i, VariableIDMapT::GetValue(state.wave.GetVariableID(i)));
}
}
private:
void PropogateInitialValues(WorldSolver& worldSolver)
void PropogateInitialValues(SolverState& state)
{
auto allValues = VariableIDMapT::GetAllValues();
for (size_t i = 0; i < worldSolver.wave.size(); ++i)
for (size_t i = 0; i < state.wave.size(); ++i)
{
for (size_t j = 0; j < allValues.size(); ++j)
{
if (worldSolver.world.getValue(i) == allValues[j])
if (state.world.getValue(i) == allValues[j])
{
worldSolver.wave.Collapse(i, 1 << j);
worldSolver.propagationQueue.push(i);
state.wave.Collapse(i, 1 << j);
state.propagationQueue.push(i);
break;
}
}