WFC class refactor

This commit is contained in:
Connor
2026-02-06 20:05:16 +09:00
parent 0816546550
commit 414ded7e09
5 changed files with 282 additions and 259 deletions

View File

@@ -64,7 +64,7 @@ int main()
Sudoku sudokuWorld = Sudoku{ "6......3.......7....7463....7.8...2.4...9...1.9...7.8....9851....6.......1......9" };
bool success = SudokuSolverCallback::Run(sudokuWorld, true);
bool success = WFC::Run<SudokuSolverCallback>(sudokuWorld, true);
bool solved = sudokuWorld.isSolved();

View File

@@ -33,7 +33,7 @@ protected:
// Helper function to solve a puzzle using WFC
void solvePuzzle(Sudoku& sudoku) {
SudokuSolver::Run(sudoku, true);
WFC::Run<SudokuSolver>(sudoku, true);
}
};
@@ -286,7 +286,7 @@ void testPuzzleSolving(const std::string& difficulty, const std::string& filenam
Sudoku& sudoku = puzzles[i];
EXPECT_TRUE(sudoku.isValid()) << difficulty << " puzzle " << i << " is not valid";
SudokuSolver::Run(sudoku);
WFC::Run<SudokuSolver>(sudoku);
EXPECT_TRUE(sudoku.isSolved()) << difficulty << " puzzle " << i << " was not solved. Puzzle string: " << sudoku.toString();

View File

@@ -44,30 +44,14 @@ concept HasConstexprSize = requires {
{ []() constexpr -> std::size_t { return WorldT{}.size(); }() };
};
template<typename WorldT, typename VarT,
typename VariableIDMapT = VariableIDMap<VarT>,
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
typename CallbacksT = Callbacks<WorldT>,
typename RandomSelectorT = DefaultRandomSelector<VarT>
>
class WFC {
public:
static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
// Standalone SolverState struct
template <typename WorldT, typename RandomSelectorT = DefaultRandomSelector<typename WorldT::ValueType>>
struct SolverState {
using WorldType = WorldT;
using WorldSizeT = decltype(WorldT{}.size());
// Try getting the world size, which is only available if the world type has a constexpr size() method
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using WaveType = Wave<VariableIDMapT, WorldSize>;
static constexpr WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
using MaskType = typename WaveType::ElementT;
using VariableIDT = typename WaveType::VariableIDT;
public:
struct SolverState
{
WorldT& m_world;
PropagationQueueType m_propagationQueue{};
RandomSelectorT m_randomSelector{};
@@ -83,101 +67,44 @@ public:
SolverState(const SolverState& other) = default;
};
public:
WFC() = delete; // dont make an instance of this class, only use the static methods.
// Types-only config struct produced by Builder
template <typename WorldT, typename VarT, typename VariableIDMapT,
typename ConstrainerFunctionMapT, typename CallbacksT, typename RandomSelectorT>
struct WFCConfig {
static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
public:
using WorldSizeT = decltype(WorldT{}.size());
static constexpr WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using SolverStateType = SolverState<WorldT, RandomSelectorT>;
using WaveType = Wave<VariableIDMapT, WorldSize>;
using CallbacksType = CallbacksT;
using ConstrainerFunctionMapType = ConstrainerFunctionMapT;
};
static bool Run(WorldT& world, uint32_t seed = std::random_device{}())
// Forward declarations for mutually recursive functions
template <typename CallbacksT, typename ConstrainerFunctionMapT, typename StateT, typename WaveT>
bool RunLoop(StateT& state, WaveT& wave);
template <typename CallbacksT, typename ConstrainerFunctionMapT, typename StateT, typename WaveT>
bool Branch(StateT& state, WaveT& wave);
namespace detail {
template <typename StateT, typename WaveT>
void PopulateWorld(StateT& state, WaveT& wave)
{
SolverState state{ world, seed };
bool result = Run(state);
return result;
}
/**
* @brief Run the WFC algorithm to generate a solution
* @return true if a solution was found, false if contradiction occurred
*/
static bool Run(SolverState& state)
using VariableIDMapT = typename WaveT::IDMapT;
for (size_t i = 0; i < wave.size(); ++i)
{
WaveType wave{ WorldSize, VariableIDMapT::size(), state.m_allocator };
PropogateInitialValues(state, wave);
if (RunLoop(state, wave)) {
PopulateWorld(state, wave);
return true;
if (wave.IsCollapsed(i))
state.m_world.setValue(i, VariableIDMapT::GetValue(wave.GetVariableID(i)));
}
return false;
}
static bool RunLoop(SolverState& state, WaveType& wave)
{
static constexpr size_t MaxIterations = 1024 * 8;
for (; state.m_iterations < MaxIterations; ++state.m_iterations)
{
if (!Propagate(state, wave))
return false;
if (wave.HasContradiction())
{
if constexpr (CallbacksT::HasContradictionCallback())
{
PopulateWorld(state, wave);
typename CallbacksT::ContradictionCallback{}(state.m_world);
}
return false;
}
if (wave.IsFullyCollapsed())
return true;
if constexpr (CallbacksT::HasBranchCallback())
{
PopulateWorld(state, wave);
typename CallbacksT::BranchCallback{}(state.m_world);
}
if (Branch(state, wave))
return true;
}
return false;
}
/**
* @brief Get the value at a specific cell
* @param cellId The cell ID
* @return The value if collapsed, std::nullopt otherwise
*/
static std::optional<VarT> GetValue(WaveType& wave, int cellId) {
if (wave.IsCollapsed(cellId)) {
auto variableId = wave.GetVariableID(cellId);
return VariableIDMapT::GetValue(variableId);
}
return std::nullopt;
}
/**
* @brief Get all possible values for a cell
* @param cellId The cell ID
* @return Set of possible values
*/
static const std::vector<VarT> GetPossibleValues(WaveType& wave, int cellId)
{
std::vector<VarT> possibleValues;
MaskType mask = wave.GetMask(cellId);
for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) {
if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i));
}
return possibleValues;
}
private:
static void CollapseCell(SolverState& state, WaveType& wave, WorldSizeT cellId, VariableIDT value)
template <typename CallbacksT, typename StateT, typename WaveT>
void CollapseCell(StateT& state, WaveT& wave, typename StateT::WorldSizeT cellId, typename WaveT::VariableIDT value)
{
using MaskType = typename WaveT::ElementT;
constexpr_assert(!wave.IsCollapsed(cellId) || wave.GetMask(cellId) == (MaskType(1) << value));
wave.Collapse(cellId, 1 << value);
constexpr_assert(wave.IsCollapsed(cellId));
@@ -189,8 +116,65 @@ private:
}
}
static bool Branch(SolverState& state, WaveType& wave)
template <typename CallbacksT, typename StateT, typename WaveT>
void PropogateInitialValues(StateT& state, WaveT& wave)
{
using VariableIDMapT = typename WaveT::IDMapT;
using WorldSizeT = typename StateT::WorldSizeT;
using VariableIDT = typename WaveT::VariableIDT;
for (size_t i = 0; i < wave.size(); ++i)
{
for (size_t j = 0; j < VariableIDMapT::size(); ++j)
{
if (state.m_world.getValue(i) == VariableIDMapT::GetValue(j))
{
CollapseCell<CallbacksT>(state, wave, static_cast<WorldSizeT>(i), static_cast<VariableIDT>(j));
state.m_propagationQueue.push(i);
break;
}
}
}
}
template <typename ConstrainerFunctionMapT, typename StateT, typename WaveT>
bool Propagate(StateT& state, WaveT& wave)
{
using VariableIDMapT = typename WaveT::IDMapT;
using VarT = typename VariableIDMapT::Type;
using WorldSizeT = typename StateT::WorldSizeT;
using VariableIDT = typename WaveT::VariableIDT;
using PropagationQueueType = typename StateT::PropagationQueueType;
using ConstrainerType = Constrainer<WaveT, PropagationQueueType>;
while (!state.m_propagationQueue.empty())
{
WorldSizeT cellId = state.m_propagationQueue.pop();
if (wave.IsContradicted(cellId)) return false;
constexpr_assert(wave.IsCollapsed(cellId), "Cell was not collapsed");
VariableIDT variableID = wave.GetVariableID(cellId);
ConstrainerType constrainer(wave, state.m_propagationQueue);
using WorldT = typename StateT::WorldType;
using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue<VarT>, ConstrainerType&);
ConstrainerFunctionMapT::template GetFunction<ConstrainerFunctionPtrT>(variableID)(state.m_world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
}
return true;
}
} // namespace detail
template <typename CallbacksT, typename ConstrainerFunctionMapT, typename StateT, typename WaveT>
bool Branch(StateT& state, WaveT& wave)
{
using VariableIDMapT = typename WaveT::IDMapT;
using MaskType = typename WaveT::ElementT;
using WorldSizeT = typename StateT::WorldSizeT;
using VariableIDT = typename WaveT::VariableIDT;
constexpr_assert(state.m_propagationQueue.empty());
// Find cell with minimum entropy > 1
@@ -235,10 +219,10 @@ private:
auto queueFrame = state.m_propagationQueue.createBranchPoint();
auto newWave = wave;
CollapseCell(state, newWave, minEntropyCell, selectedValue);
detail::CollapseCell<CallbacksT>(state, newWave, minEntropyCell, selectedValue);
state.m_propagationQueue.push(minEntropyCell);
if (RunLoop(state, newWave))
if (RunLoop<CallbacksT, ConstrainerFunctionMapT>(state, newWave))
{
// move the solution to the original state
wave = newWave;
@@ -259,50 +243,88 @@ private:
return false;
}
static bool Propagate(SolverState& state, WaveType& wave)
template <typename CallbacksT, typename ConstrainerFunctionMapT, typename StateT, typename WaveT>
bool RunLoop(StateT& state, WaveT& wave)
{
while (!state.m_propagationQueue.empty())
static constexpr size_t MaxIterations = 1024 * 8;
for (; state.m_iterations < MaxIterations; ++state.m_iterations)
{
WorldSizeT cellId = state.m_propagationQueue.pop();
if (!detail::Propagate<ConstrainerFunctionMapT>(state, wave))
return false;
if (wave.IsContradicted(cellId)) return false;
constexpr_assert(wave.IsCollapsed(cellId), "Cell was not collapsed");
VariableIDT variableID = wave.GetVariableID(cellId);
ConstrainerType constrainer(wave, state.m_propagationQueue);
using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue<VarT>, ConstrainerType&);
ConstrainerFunctionMapT::template GetFunction<ConstrainerFunctionPtrT>(variableID)(state.m_world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
if (wave.HasContradiction())
{
if constexpr (CallbacksT::HasContradictionCallback())
{
detail::PopulateWorld(state, wave);
typename CallbacksT::ContradictionCallback{}(state.m_world);
}
return false;
}
if (wave.IsFullyCollapsed())
return true;
if constexpr (CallbacksT::HasBranchCallback())
{
detail::PopulateWorld(state, wave);
typename CallbacksT::BranchCallback{}(state.m_world);
}
if (Branch<CallbacksT, ConstrainerFunctionMapT>(state, wave))
return true;
}
static void PopulateWorld(SolverState& state, WaveType& wave)
{
for (size_t i = 0; i < wave.size(); ++i)
{
if (wave.IsCollapsed(i))
state.m_world.setValue(i, VariableIDMapT::GetValue(wave.GetVariableID(i)));
}
return false;
}
static void PropogateInitialValues(SolverState& state, WaveType& wave)
template <typename ConfigT>
bool Run(typename ConfigT::SolverStateType& state)
{
for (size_t i = 0; i < wave.size(); ++i)
using CallbacksT = typename ConfigT::CallbacksType;
using ConstrainerFunctionMapT = typename ConfigT::ConstrainerFunctionMapType;
using WaveType = typename ConfigT::WaveType;
using VariableIDMapT = typename WaveType::IDMapT;
WaveType wave{ ConfigT::WorldSize, VariableIDMapT::size(), state.m_allocator };
detail::PropogateInitialValues<CallbacksT>(state, wave);
if (RunLoop<CallbacksT, ConstrainerFunctionMapT>(state, wave)) {
detail::PopulateWorld(state, wave);
return true;
}
return false;
}
template <typename ConfigT, typename WorldT>
bool Run(WorldT& world, uint32_t seed = std::random_device{}())
{
for (size_t j = 0; j < VariableIDMapT::size(); ++j)
typename ConfigT::SolverStateType state{ world, seed };
return Run<ConfigT>(state);
}
template <typename WaveT>
std::optional<typename WaveT::IDMapT::Type> GetValue(WaveT& wave, int cellId) {
using VariableIDMapT = typename WaveT::IDMapT;
if (wave.IsCollapsed(cellId)) {
auto variableId = wave.GetVariableID(cellId);
return VariableIDMapT::GetValue(variableId);
}
return std::nullopt;
}
template <typename ConstrainerFunctionMapT, typename WaveT>
const std::vector<typename WaveT::IDMapT::Type> GetPossibleValues(WaveT& wave, int cellId)
{
if (state.m_world.getValue(i) == VariableIDMapT::GetValue(j))
{
CollapseCell(state, wave, static_cast<WorldSizeT>(i), static_cast<VariableIDT>(j));
state.m_propagationQueue.push(i);
break;
using VariableIDMapT = typename WaveT::IDMapT;
using VarT = typename VariableIDMapT::Type;
using MaskType = typename WaveT::ElementT;
std::vector<VarT> possibleValues;
MaskType mask = wave.GetMask(cellId);
for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) {
if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i));
}
return possibleValues;
}
}
}
};
} // namespace WFC

View File

@@ -85,7 +85,7 @@ public:
template <typename NewRandomSelectorT>
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
using Build = WFCConfig<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
};
}

View File

@@ -19,6 +19,7 @@ using VariableIDType = std::conditional_t<VariablesAmount <= std::numeric_limits
template <typename VarT, VarT ... Values>
class VariableIDMap {
public:
using Type = VarT;
template <VarT ... AdditionalValues>
using Merge = VariableIDMap<VarT, Values..., AdditionalValues...>;