branching
This commit is contained in:
@@ -87,10 +87,15 @@ if(HAS_GTEST)
|
||||
# Set test output directory
|
||||
set_target_properties(sudoku_tests PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
|
||||
CXX_STANDARD 20
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
)
|
||||
|
||||
# Include directories for tests
|
||||
target_include_directories(sudoku_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_include_directories(sudoku_tests PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||
)
|
||||
|
||||
# Add test to CTest
|
||||
add_test(NAME sudoku_tests COMMAND sudoku_tests)
|
||||
@@ -108,10 +113,15 @@ if(HAS_BENCHMARK)
|
||||
# Set benchmark output directory
|
||||
set_target_properties(sudoku_benchmarks PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
|
||||
CXX_STANDARD 20
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
)
|
||||
|
||||
# Include directories for benchmarks
|
||||
target_include_directories(sudoku_benchmarks PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_include_directories(sudoku_benchmarks PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||
)
|
||||
endif()
|
||||
|
||||
# Installation (optional)
|
||||
|
||||
@@ -4,13 +4,13 @@
|
||||
// Benchmark fixture for Sudoku benchmarks
|
||||
class SudokuBenchmark : public benchmark::Fixture {
|
||||
public:
|
||||
void SetUp(const ::benchmark::State& state) override {
|
||||
void SetUp(const ::benchmark::State&) override {
|
||||
// Create test puzzle
|
||||
testPuzzle = "530070000600195000098000060800060003400803001700020006060000280000419005000080079";
|
||||
sudoku.loadFromString(testPuzzle);
|
||||
}
|
||||
|
||||
void TearDown(const ::benchmark::State& state) override {
|
||||
void TearDown(const ::benchmark::State&) override {
|
||||
// Cleanup if needed
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <cstdint>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
|
||||
// 4-bit packed Sudoku board storage - optimal packing
|
||||
// 81 cells * 4 bits = 324 bits
|
||||
@@ -209,10 +210,30 @@ public: // WFC Support
|
||||
return 81;
|
||||
}
|
||||
|
||||
public: // Solver Interface
|
||||
// Solve the puzzle using WFC algorithm
|
||||
bool solve();
|
||||
|
||||
// Solve with custom initial constraints (for testing)
|
||||
bool solveWithConstraints(const std::vector<std::pair<int, int>>& constraints);
|
||||
|
||||
// Get the number of attempts made during solving
|
||||
int getSolveAttempts() const { return solve_attempts_; }
|
||||
|
||||
// Get the time taken for the last solve operation (in microseconds)
|
||||
long long getSolveTimeMicroseconds() const { return solve_time_us_; }
|
||||
|
||||
private:
|
||||
mutable int solve_attempts_ = 0;
|
||||
mutable long long solve_time_us_ = 0;
|
||||
|
||||
// Helper method for backtracking solver
|
||||
bool solveBacktracking(size_t index);
|
||||
|
||||
};
|
||||
|
||||
// Static assert to ensure exactly 41 bytes
|
||||
static_assert(sizeof(Sudoku) == 41, "Sudoku class must be exactly 41 bytes");
|
||||
// Static assert to ensure correct size (now 56 bytes with solver additions)
|
||||
static_assert(sizeof(Sudoku) == 56, "Sudoku class must be exactly 56 bytes");
|
||||
|
||||
// Fast solution validator (stateless)
|
||||
class SudokuValidator {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "sudoku.h"
|
||||
#include <chrono>
|
||||
|
||||
// Test fixture for Sudoku tests
|
||||
class SudokuTest : public ::testing::Test {
|
||||
@@ -96,7 +97,7 @@ TEST_F(SudokuTest, Clear) {
|
||||
}
|
||||
|
||||
TEST_F(SudokuTest, MemorySize) {
|
||||
EXPECT_EQ(sizeof(Sudoku), 41);
|
||||
EXPECT_EQ(sizeof(Sudoku), 56); // Updated to include solver members
|
||||
}
|
||||
|
||||
TEST_F(SudokuTest, SetInvalidValue) {
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <algorithm>
|
||||
#include <concepts>
|
||||
#include <bit>
|
||||
#include <iostream>
|
||||
|
||||
namespace WFC {
|
||||
|
||||
@@ -148,6 +149,9 @@ public:
|
||||
for (auto& wave : m_data) wave = (1 << variableAmount) - 1;
|
||||
}
|
||||
|
||||
Wave(const Wave& other) = default;
|
||||
|
||||
public:
|
||||
void Collapse(size_t index, MaskType mask) { m_data[index] &= mask; }
|
||||
size_t size() const { return m_data.size(); }
|
||||
size_t Entropy(size_t index) const { return std::popcount(m_data[index]); }
|
||||
@@ -249,18 +253,19 @@ public:
|
||||
using MaskType = typename VariableIDMapT::MaskType;
|
||||
|
||||
public:
|
||||
struct WorldSolver {
|
||||
struct SolverState {
|
||||
WorldT& world;
|
||||
std::queue<size_t> propagationQueue;
|
||||
Wave<MaskType> wave;
|
||||
std::mt19937 rng;
|
||||
std::queue<size_t> propagationQueue{};
|
||||
Wave<MaskType> wave{};
|
||||
std::mt19937& rng;
|
||||
|
||||
WorldSolver(WorldT& world, const std::vector<VariableData<WorldT, VarT, VariableIDMapT>>& variables)
|
||||
SolverState(WorldT& world, size_t variableAmount, std::mt19937& rng)
|
||||
: world(world)
|
||||
, propagationQueue()
|
||||
, wave(world.size(), variables.size())
|
||||
, rng(std::random_device{}())
|
||||
, wave(world.size(), variableAmount)
|
||||
, rng(rng)
|
||||
{}
|
||||
|
||||
SolverState(const SolverState& other) = default;
|
||||
};
|
||||
|
||||
public:
|
||||
@@ -271,32 +276,45 @@ public:
|
||||
public:
|
||||
bool Run(WorldT& world, bool propagateInitialValues = false)
|
||||
{
|
||||
WorldSolver worldSolver(world, m_variables);
|
||||
return Run(worldSolver, propagateInitialValues);
|
||||
//auto seed = std::random_device{}();
|
||||
auto seed = 1844803044ull;
|
||||
std::mt19937 random{ seed };
|
||||
SolverState state(world, m_variables.size(), random);
|
||||
return Run(state, propagateInitialValues);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Run the WFC algorithm to generate a solution
|
||||
* @return true if a solution was found, false if contradiction occurred
|
||||
*/
|
||||
bool Run(WorldSolver& worldSolver, bool propagateInitialValues = false)
|
||||
bool Run(SolverState& state, bool propagateInitialValues = false)
|
||||
{
|
||||
if (propagateInitialValues)
|
||||
{
|
||||
PropogateInitialValues(worldSolver);
|
||||
PropogateInitialValues(state);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < 1024; ++i)
|
||||
{
|
||||
Propagate(worldSolver);
|
||||
if (RunLoop(state)) {
|
||||
|
||||
if (worldSolver.wave.IsFullyCollapsed()) {
|
||||
PopulateWorld(worldSolver);
|
||||
PopulateWorld(state);
|
||||
return true;
|
||||
} else if (worldSolver.wave.HasContradiction()) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RunLoop(SolverState& state)
|
||||
{
|
||||
constexpr size_t maxIterations = 1024;
|
||||
for (size_t i = 0; i < maxIterations; ++i)
|
||||
{
|
||||
Propagate(state);
|
||||
|
||||
if (state.wave.IsFullyCollapsed()) {
|
||||
return true;
|
||||
} else if (state.wave.HasContradiction()) {
|
||||
return false;
|
||||
} else {
|
||||
GetMinEntropyCell(worldSolver);
|
||||
Branch(state);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
@@ -307,9 +325,9 @@ public:
|
||||
* @param cellId The cell ID
|
||||
* @return The value if collapsed, std::nullopt otherwise
|
||||
*/
|
||||
std::optional<VarT> GetValue(WorldSolver& worldSolver, int cellId) const {
|
||||
if (worldSolver.wave.IsCollapsed(cellId)) {
|
||||
auto variableId = worldSolver.wave.GetVariableID(cellId);
|
||||
std::optional<VarT> GetValue(SolverState& state, int cellId) const {
|
||||
if (state.wave.IsCollapsed(cellId)) {
|
||||
auto variableId = state.wave.GetVariableID(cellId);
|
||||
return VariableIDMapT::GetValue(variableId);
|
||||
}
|
||||
return std::nullopt;
|
||||
@@ -320,10 +338,10 @@ public:
|
||||
* @param cellId The cell ID
|
||||
* @return Set of possible values
|
||||
*/
|
||||
const std::vector<VarT> GetPossibleValues(WorldSolver& worldSolver, int cellId) const
|
||||
const std::vector<VarT> GetPossibleValues(SolverState& state, int cellId) const
|
||||
{
|
||||
std::vector<VarT> possibleValues;
|
||||
MaskType mask = worldSolver.wave.GetMask(cellId);
|
||||
MaskType mask = state.wave.GetMask(cellId);
|
||||
for (size_t i = 0; i < m_variables.size(); ++i) {
|
||||
if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i));
|
||||
}
|
||||
@@ -331,73 +349,106 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
bool GetMinEntropyCell(WorldSolver& worldSolver)
|
||||
bool Branch(SolverState& state)
|
||||
{
|
||||
assert(worldSolver.propagationQueue.empty());
|
||||
assert(state.propagationQueue.empty());
|
||||
|
||||
// Find cell with minimum entropy > 1
|
||||
size_t minEntropyCell = static_cast<size_t>(-1);
|
||||
size_t minEntropy = static_cast<size_t>(-1);
|
||||
|
||||
for (size_t i = 0; i < worldSolver.wave.size(); ++i) {
|
||||
size_t entropy = worldSolver.wave.Entropy(i);
|
||||
for (size_t i = 0; i < state.wave.size(); ++i) {
|
||||
size_t entropy = state.wave.Entropy(i);
|
||||
if (entropy > 1 && entropy < minEntropy) {
|
||||
minEntropy = entropy;
|
||||
minEntropyCell = i;
|
||||
}
|
||||
}
|
||||
assert(!worldSolver.wave.IsCollapsed(minEntropyCell));
|
||||
assert(!state.wave.IsCollapsed(minEntropyCell));
|
||||
|
||||
// Randomly select a value from possible values
|
||||
size_t availableValues = worldSolver.wave.Entropy(minEntropyCell);
|
||||
std::uniform_int_distribution<size_t> dist(0, availableValues - 1);
|
||||
size_t selectedValue = FindNthSetBit(worldSolver.wave.GetMask(minEntropyCell), dist(worldSolver.rng));
|
||||
assert(selectedValue < VariableIDMapT::ValuesRegisteredAmount && "Selected Value went outside bounds");
|
||||
// create a list of possible values
|
||||
size_t availableValues = state.wave.Entropy(minEntropyCell);
|
||||
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> possibleValues; // inplace vector
|
||||
MaskType mask = state.wave.GetMask(minEntropyCell);
|
||||
for (size_t i = 0; i < availableValues; ++i)
|
||||
{
|
||||
uint16_t index = std::countr_zero(mask); // get the index of the lowest set bit
|
||||
assert(index < VariableIDMapT::ValuesRegisteredAmount && "Possible value went outside bounds");
|
||||
|
||||
// Collapse the cell to the selected value
|
||||
worldSolver.wave.Collapse(minEntropyCell, 1 << selectedValue);
|
||||
assert(worldSolver.wave.IsCollapsed(minEntropyCell) && "Cell was not collapsed correctly");
|
||||
possibleValues[i] = index;
|
||||
assert(((mask & (1 << index)) != 0) && "Possible value was not set");
|
||||
|
||||
worldSolver.propagationQueue.push(minEntropyCell);
|
||||
mask = mask & (mask - 1); // turn off lowest set bit
|
||||
}
|
||||
|
||||
// randomly select a value from possible values
|
||||
for (size_t i = 0; i < availableValues; ++i)
|
||||
{
|
||||
std::uniform_int_distribution<uint16_t> dist(0, availableValues - 1);
|
||||
uint16_t selectedValue = possibleValues[dist(state.rng)];
|
||||
|
||||
{
|
||||
// copy the state and branch out
|
||||
SolverState newState(state);
|
||||
newState.wave.Collapse(minEntropyCell, 1 << selectedValue);
|
||||
newState.propagationQueue.push(minEntropyCell);
|
||||
|
||||
if (RunLoop(newState))
|
||||
{
|
||||
// copy the solution to the original state
|
||||
state.wave = newState.wave;
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
void Propagate(WorldSolver& worldSolver)
|
||||
// remove the failure state from the wave
|
||||
assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) != 0 && "Possible value was not set");
|
||||
state.wave.Collapse(minEntropyCell, ~(1 << selectedValue));
|
||||
assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) == 0 && "Wave was not collapsed correctly");
|
||||
|
||||
// swap replacement value with the last value
|
||||
std::swap(possibleValues[i], possibleValues[--availableValues]);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void Propagate(SolverState& state)
|
||||
{
|
||||
while (!worldSolver.propagationQueue.empty())
|
||||
while (!state.propagationQueue.empty())
|
||||
{
|
||||
size_t cellId = worldSolver.propagationQueue.front();
|
||||
worldSolver.propagationQueue.pop();
|
||||
size_t cellId = state.propagationQueue.front();
|
||||
state.propagationQueue.pop();
|
||||
|
||||
assert(worldSolver.wave.IsCollapsed(cellId) && "Cell was not collapsed");
|
||||
assert(state.wave.IsCollapsed(cellId) && "Cell was not collapsed");
|
||||
|
||||
uint16_t variableID = worldSolver.wave.GetVariableID(cellId);
|
||||
Constrainer<VariableIDMapT> constrainer(worldSolver.wave, worldSolver.propagationQueue);
|
||||
m_variables[variableID].constraintFunc(worldSolver.world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
|
||||
uint16_t variableID = state.wave.GetVariableID(cellId);
|
||||
Constrainer<VariableIDMapT> constrainer(state.wave, state.propagationQueue);
|
||||
m_variables[variableID].constraintFunc(state.world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
|
||||
}
|
||||
}
|
||||
|
||||
void PopulateWorld(WorldSolver& worldSolver)
|
||||
void PopulateWorld(SolverState& state)
|
||||
{
|
||||
for (size_t i = 0; i < worldSolver.wave.size(); ++i)
|
||||
for (size_t i = 0; i < state.wave.size(); ++i)
|
||||
{
|
||||
worldSolver.world.setValue(i, VariableIDMapT::GetValue(worldSolver.wave.GetVariableID(i)));
|
||||
state.world.setValue(i, VariableIDMapT::GetValue(state.wave.GetVariableID(i)));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void PropogateInitialValues(WorldSolver& worldSolver)
|
||||
void PropogateInitialValues(SolverState& state)
|
||||
{
|
||||
auto allValues = VariableIDMapT::GetAllValues();
|
||||
for (size_t i = 0; i < worldSolver.wave.size(); ++i)
|
||||
for (size_t i = 0; i < state.wave.size(); ++i)
|
||||
{
|
||||
for (size_t j = 0; j < allValues.size(); ++j)
|
||||
{
|
||||
if (worldSolver.world.getValue(i) == allValues[j])
|
||||
if (state.world.getValue(i) == allValues[j])
|
||||
{
|
||||
worldSolver.wave.Collapse(i, 1 << j);
|
||||
worldSolver.propagationQueue.push(i);
|
||||
state.wave.Collapse(i, 1 << j);
|
||||
state.propagationQueue.push(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user