sudoku class integration

This commit is contained in:
cdemeyer-teachx
2025-08-24 20:03:11 +09:00
parent d2d4e8882d
commit 6fce648b01
5 changed files with 75 additions and 15 deletions

View File

@@ -42,6 +42,7 @@ add_executable(sudoku_demo
# Create WFC demo executable
add_executable(sudoku_wfc_demo
sudoku_wfc.cpp
sudoku.cpp
)
# Set output directory for sudoku_demo

View File

@@ -193,6 +193,22 @@ private:
}
return false;
}
public: // WFC Support
using ValueType = uint8_t;
ValueType getValue(size_t index) const {
return board_.get(index);
}
void setValue(size_t index, ValueType value) {
board_.set(index, value);
}
constexpr size_t size() const {
return 81;
}
};
// Static assert to ensure exactly 41 bytes

View File

@@ -1,14 +1,15 @@
#include <nd-wfc/wfc.hpp>
#include <nd-wfc/worlds.hpp>
#include "sudoku.h"
#include <iostream>
int main()
{
std::cout << "Running Sudoku WFC" << std::endl;
auto sudokuSolver = WFC::Builder<WFC::Array2D<uint8_t, 9, 9>, uint8_t>()
auto sudokuSolver = WFC::Builder<Sudoku, uint8_t>()
.DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>()
.Variable<1, 2, 3, 4, 5, 6, 7, 8, 9>([](WFC::Array2D<uint8_t, 9, 9>&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
.Variable<1, 2, 3, 4, 5, 6, 7, 8, 9>([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
size_t x = index % 9;
size_t y = index / 9;
@@ -38,15 +39,17 @@ int main()
})
.build();
WFC::Array2D<uint8_t, 9, 9> sudokuWorld;
bool success = sudokuSolver.Run(sudokuWorld);
Sudoku sudokuWorld;
sudokuWorld.setValue(0, 5);
sudokuWorld.setValue(80, 1);
bool success = sudokuSolver.Run(sudokuWorld, true);
if (success) {
std::cout << "Sudoku solved successfully!" << std::endl;
// Print the solved sudoku
for (size_t y = 0; y < 9; ++y) {
for (size_t x = 0; x < 9; ++x) {
std::cout << static_cast<int>(sudokuWorld.at(static_cast<int>(x), static_cast<int>(y))) << " ";
std::cout << static_cast<int>(sudokuWorld.getValue(x + y * 9)) << " ";
if (x == 2 || x == 5) std::cout << "| ";
}
std::cout << std::endl;

View File

@@ -3,8 +3,6 @@
#include <vector>
#include <functional>
#include <memory>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <random>
#include <optional>
@@ -17,7 +15,6 @@
namespace WFC {
inline int FindNthSetBit(size_t num, int n) {
auto popCount = std::popcount(num);
assert(n < std::popcount(num) && "index is out of range");
int bitCount = 0;
while (num) {
@@ -27,7 +24,6 @@ inline int FindNthSetBit(size_t num, int n) {
bitCount++;
num &= (num - 1); // turn of lowest set bit
}
assert(bitCount < popCount && "out of bounds");
return bitCount;
}
@@ -35,6 +31,7 @@ template<typename T>
concept WorldType = requires(T world, size_t id, typename T::ValueType value) {
{ world.size() } -> std::convertible_to<size_t>;
{ world.setValue(id, value) };
{ world.getValue(id) } -> std::convertible_to<typename T::ValueType>;
typename T::ValueType;
};
@@ -96,7 +93,7 @@ public:
{
static_assert(HasValue<Value>(), "Value was not defined");
constexpr VarT arr[] = {Values...};
constexpr size_t size = sizeof...(Values);
constexpr size_t size = ValuesRegisteredAmount;
for (size_t i = 0; i < size; ++i)
if (arr[i] == Value)
@@ -106,7 +103,7 @@ public:
}
static VarT GetValue(size_t index) {
assert(index < sizeof...(Values));
assert(index < ValuesRegisteredAmount);
constexpr VarT arr[] = {Values...};
return arr[index];
}
@@ -117,7 +114,12 @@ public:
return (0 | ... | (1 << GetIndex<MaskValues>()));
}
static consteval size_t size() { return sizeof...(Values); }
static consteval std::array<VarT, ValuesRegisteredAmount> GetAllValues()
{
return {Values...};
}
static consteval size_t size() { return ValuesRegisteredAmount; }
};
template <typename VarT>
@@ -267,18 +269,23 @@ public:
{}
public:
bool Run(WorldT& world)
bool Run(WorldT& world, bool propagateInitialValues = false)
{
WorldSolver worldSolver(world, m_variables);
return Run(worldSolver);
return Run(worldSolver, propagateInitialValues);
}
/**
* @brief Run the WFC algorithm to generate a solution
* @return true if a solution was found, false if contradiction occurred
*/
bool Run(WorldSolver& worldSolver)
bool Run(WorldSolver& worldSolver, bool propagateInitialValues = false)
{
if (propagateInitialValues)
{
PropogateInitialValues(worldSolver);
}
for (size_t i = 0; i < 1024; ++i)
{
Propagate(worldSolver);
@@ -379,6 +386,25 @@ private:
}
}
private:
void PropogateInitialValues(WorldSolver& worldSolver)
{
auto allValues = VariableIDMapT::GetAllValues();
for (size_t i = 0; i < worldSolver.wave.size(); ++i)
{
for (size_t j = 0; j < allValues.size(); ++j)
{
if (worldSolver.world.getValue(i) == allValues[j])
{
worldSolver.wave.Collapse(i, 1 << j);
worldSolver.propagationQueue.push(i);
break;
}
}
}
}
private:
std::vector<VariableData<WorldT, VarT, VariableIDMapT>> m_variables {};
};

View File

@@ -86,6 +86,13 @@ public:
data_[index] = value;
}
/**
* @brief Get value at specific index
*/
T getValue(size_t index) const {
return data_[index];
}
private:
std::array<T, Width * Height> data_;
};
@@ -161,6 +168,13 @@ public:
data_[index] = value;
}
/**
* @brief Get value at specific index
*/
T getValue(size_t index) const {
return data_[index];
}
private:
std::array<T, Width * Height * Depth> data_;
};