sudoku class integration
This commit is contained in:
@@ -42,6 +42,7 @@ add_executable(sudoku_demo
|
||||
# Create WFC demo executable
|
||||
add_executable(sudoku_wfc_demo
|
||||
sudoku_wfc.cpp
|
||||
sudoku.cpp
|
||||
)
|
||||
|
||||
# Set output directory for sudoku_demo
|
||||
|
||||
@@ -193,6 +193,22 @@ private:
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public: // WFC Support
|
||||
using ValueType = uint8_t;
|
||||
|
||||
ValueType getValue(size_t index) const {
|
||||
return board_.get(index);
|
||||
}
|
||||
|
||||
void setValue(size_t index, ValueType value) {
|
||||
board_.set(index, value);
|
||||
}
|
||||
|
||||
constexpr size_t size() const {
|
||||
return 81;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// Static assert to ensure exactly 41 bytes
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
#include <nd-wfc/wfc.hpp>
|
||||
#include <nd-wfc/worlds.hpp>
|
||||
#include "sudoku.h"
|
||||
#include <iostream>
|
||||
|
||||
int main()
|
||||
{
|
||||
std::cout << "Running Sudoku WFC" << std::endl;
|
||||
|
||||
auto sudokuSolver = WFC::Builder<WFC::Array2D<uint8_t, 9, 9>, uint8_t>()
|
||||
auto sudokuSolver = WFC::Builder<Sudoku, uint8_t>()
|
||||
.DefineIDs<1, 2, 3, 4, 5, 6, 7, 8, 9>()
|
||||
.Variable<1, 2, 3, 4, 5, 6, 7, 8, 9>([](WFC::Array2D<uint8_t, 9, 9>&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
|
||||
.Variable<1, 2, 3, 4, 5, 6, 7, 8, 9>([](Sudoku&, size_t index, WFC::WorldValue<uint8_t> val, auto& constrainer) {
|
||||
size_t x = index % 9;
|
||||
size_t y = index / 9;
|
||||
|
||||
@@ -38,15 +39,17 @@ int main()
|
||||
})
|
||||
.build();
|
||||
|
||||
WFC::Array2D<uint8_t, 9, 9> sudokuWorld;
|
||||
bool success = sudokuSolver.Run(sudokuWorld);
|
||||
Sudoku sudokuWorld;
|
||||
sudokuWorld.setValue(0, 5);
|
||||
sudokuWorld.setValue(80, 1);
|
||||
bool success = sudokuSolver.Run(sudokuWorld, true);
|
||||
|
||||
if (success) {
|
||||
std::cout << "Sudoku solved successfully!" << std::endl;
|
||||
// Print the solved sudoku
|
||||
for (size_t y = 0; y < 9; ++y) {
|
||||
for (size_t x = 0; x < 9; ++x) {
|
||||
std::cout << static_cast<int>(sudokuWorld.at(static_cast<int>(x), static_cast<int>(y))) << " ";
|
||||
std::cout << static_cast<int>(sudokuWorld.getValue(x + y * 9)) << " ";
|
||||
if (x == 2 || x == 5) std::cout << "| ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <queue>
|
||||
#include <random>
|
||||
#include <optional>
|
||||
@@ -17,7 +15,6 @@
|
||||
namespace WFC {
|
||||
|
||||
inline int FindNthSetBit(size_t num, int n) {
|
||||
auto popCount = std::popcount(num);
|
||||
assert(n < std::popcount(num) && "index is out of range");
|
||||
int bitCount = 0;
|
||||
while (num) {
|
||||
@@ -27,7 +24,6 @@ inline int FindNthSetBit(size_t num, int n) {
|
||||
bitCount++;
|
||||
num &= (num - 1); // turn of lowest set bit
|
||||
}
|
||||
assert(bitCount < popCount && "out of bounds");
|
||||
return bitCount;
|
||||
}
|
||||
|
||||
@@ -35,6 +31,7 @@ template<typename T>
|
||||
concept WorldType = requires(T world, size_t id, typename T::ValueType value) {
|
||||
{ world.size() } -> std::convertible_to<size_t>;
|
||||
{ world.setValue(id, value) };
|
||||
{ world.getValue(id) } -> std::convertible_to<typename T::ValueType>;
|
||||
typename T::ValueType;
|
||||
};
|
||||
|
||||
@@ -96,7 +93,7 @@ public:
|
||||
{
|
||||
static_assert(HasValue<Value>(), "Value was not defined");
|
||||
constexpr VarT arr[] = {Values...};
|
||||
constexpr size_t size = sizeof...(Values);
|
||||
constexpr size_t size = ValuesRegisteredAmount;
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if (arr[i] == Value)
|
||||
@@ -106,7 +103,7 @@ public:
|
||||
}
|
||||
|
||||
static VarT GetValue(size_t index) {
|
||||
assert(index < sizeof...(Values));
|
||||
assert(index < ValuesRegisteredAmount);
|
||||
constexpr VarT arr[] = {Values...};
|
||||
return arr[index];
|
||||
}
|
||||
@@ -117,7 +114,12 @@ public:
|
||||
return (0 | ... | (1 << GetIndex<MaskValues>()));
|
||||
}
|
||||
|
||||
static consteval size_t size() { return sizeof...(Values); }
|
||||
static consteval std::array<VarT, ValuesRegisteredAmount> GetAllValues()
|
||||
{
|
||||
return {Values...};
|
||||
}
|
||||
|
||||
static consteval size_t size() { return ValuesRegisteredAmount; }
|
||||
};
|
||||
|
||||
template <typename VarT>
|
||||
@@ -267,18 +269,23 @@ public:
|
||||
{}
|
||||
|
||||
public:
|
||||
bool Run(WorldT& world)
|
||||
bool Run(WorldT& world, bool propagateInitialValues = false)
|
||||
{
|
||||
WorldSolver worldSolver(world, m_variables);
|
||||
return Run(worldSolver);
|
||||
return Run(worldSolver, propagateInitialValues);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Run the WFC algorithm to generate a solution
|
||||
* @return true if a solution was found, false if contradiction occurred
|
||||
*/
|
||||
bool Run(WorldSolver& worldSolver)
|
||||
bool Run(WorldSolver& worldSolver, bool propagateInitialValues = false)
|
||||
{
|
||||
if (propagateInitialValues)
|
||||
{
|
||||
PropogateInitialValues(worldSolver);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < 1024; ++i)
|
||||
{
|
||||
Propagate(worldSolver);
|
||||
@@ -379,6 +386,25 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void PropogateInitialValues(WorldSolver& worldSolver)
|
||||
{
|
||||
auto allValues = VariableIDMapT::GetAllValues();
|
||||
for (size_t i = 0; i < worldSolver.wave.size(); ++i)
|
||||
{
|
||||
for (size_t j = 0; j < allValues.size(); ++j)
|
||||
{
|
||||
if (worldSolver.world.getValue(i) == allValues[j])
|
||||
{
|
||||
worldSolver.wave.Collapse(i, 1 << j);
|
||||
worldSolver.propagationQueue.push(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<VariableData<WorldT, VarT, VariableIDMapT>> m_variables {};
|
||||
};
|
||||
|
||||
|
||||
@@ -86,6 +86,13 @@ public:
|
||||
data_[index] = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get value at specific index
|
||||
*/
|
||||
T getValue(size_t index) const {
|
||||
return data_[index];
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<T, Width * Height> data_;
|
||||
};
|
||||
@@ -161,6 +168,13 @@ public:
|
||||
data_[index] = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get value at specific index
|
||||
*/
|
||||
T getValue(size_t index) const {
|
||||
return data_[index];
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<T, Width * Height * Depth> data_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user