Files
nd-wfc/include/nd-wfc/wfc.hpp
2025-08-24 19:05:16 +09:00

446 lines
14 KiB
C++

#pragma once
#include <algorithm>
#include <bit>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <queue>
#include <random>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace WFC {
/**
 * @brief Return the bit position of the n-th (0-based) set bit of @p num, counting from the least significant bit.
 * @param num Bit mask to search.
 * @param n   Zero-based rank of the wanted set bit; must satisfy 0 <= n < popcount(num).
 * @return Bit index of that set bit, or -1 if @p n is out of range (also trips an assert in debug builds).
 */
inline int FindNthSetBit(size_t num, int n) {
    const int setBits = std::popcount(num);
    assert(n >= 0 && n < setBits && "index is out of range");
    if (n < 0 || n >= setBits) {
        return -1;
    }
    // Clear the n lowest set bits; the lowest remaining set bit is the requested one.
    for (int cleared = 0; cleared < n; ++cleared) {
        num &= (num - 1); // turn off lowest set bit
    }
    return std::countr_zero(num);
}
/**
 * @brief Requirements on a world type that WFC can solve and write results into.
 *
 * The type must expose a nested `ValueType`, report its number of cells through
 * `size()` (convertible to size_t), and accept `setValue(cellIndex, value)`.
 */
template <class T>
concept WorldType = requires(T w, size_t cell, typename T::ValueType v) {
    typename T::ValueType;
    { w.size() } -> std::convertible_to<size_t>;
    { w.setValue(cell, v) };
};
// Forward declarations of the templates used throughout this header.
// NOTE(review): no definition of `Variable` is visible in this file; confirm it is still needed.
template <class VarT>
struct WorldValue;

template <class MaskType>
class Wave;

template <class VariableIDMapT>
class Constrainer;

template <class WorldT, class VarT, class VariableIDMapT>
class Variable;

template <class WorldT, class VarT, class VariableIDMapT>
class WFC;
/**
 * @brief Compile-time mapping between a fixed list of variable values and dense indices.
 *
 * The i-th value in @p Values gets index i. Indices are used as bit positions in MaskType,
 * the smallest unsigned integer type with at least one bit per registered value.
 *
 * @tparam VarT   Type of the values (must be usable as a non-type template parameter).
 * @tparam Values The registered values, in index order.
 *                NOTE(review): duplicates are not rejected; GetIndex() returns the first match.
 */
template <typename VarT, VarT ... Values>
class VariableIDMap {
public:
    using Type = VarT;
    static constexpr size_t ValuesRegisteredAmount = sizeof...(Values);
    static_assert(ValuesRegisteredAmount <= 64,
                  "VariableIDMap supports at most 64 values (one bit each in a uint64_t mask)");

    /// Smallest unsigned integer type with one bit per registered value.
    using MaskType = std::conditional_t<ValuesRegisteredAmount <= 8, uint8_t,
                     std::conditional_t<ValuesRegisteredAmount <= 16, uint16_t,
                     std::conditional_t<ValuesRegisteredAmount <= 32, uint32_t,
                                                                      uint64_t>>>;

    /// Map with @p AdditionalValues appended; existing values keep their indices.
    template <VarT ... AdditionalValues>
    using Merge = VariableIDMap<VarT, Values..., AdditionalValues...>;

    /// True if @p Value is registered. Written as a fold so it also works on an empty map.
    template <VarT Value>
    static consteval bool HasValue()
    {
        return ((Values == Value) || ...);
    }

    /// Index of @p Value; fails to compile if the value is not registered.
    template <VarT Value>
    static consteval size_t GetIndex()
    {
        static_assert(HasValue<Value>(), "Value was not defined");
        constexpr VarT arr[] = {Values...};
        for (size_t i = 0; i < ValuesRegisteredAmount; ++i)
            if (arr[i] == Value)
                return i;
        return static_cast<size_t>(-1); // unreachable: the static_assert guarantees a match
    }

    /// Value stored at @p index (asserts index < size()).
    static VarT GetValue(size_t index) {
        assert(index < ValuesRegisteredAmount);
        constexpr VarT arr[] = {Values...};
        return arr[index];
    }

    /// Mask with the bit of every value in @p MaskValues set.
    template <VarT ... MaskValues>
    static consteval MaskType GetMask()
    {
        // Shift a MaskType-wide 1, not a plain int: `1 << 31` would sign-extend into a
        // 64-bit mask, and `1 << 32` or more is not a constant expression at all.
        return static_cast<MaskType>((MaskType{0} | ... | (MaskType{1} << GetIndex<MaskValues>())));
    }

    static consteval size_t size() { return sizeof...(Values); }
};
/**
 * @brief A variable value paired with its index in the VariableIDMap.
 *
 * This is what constraint functions receive for the collapsed cell. It converts
 * implicitly to VarT so it can be compared or switched on directly.
 */
template <typename VarT>
struct WorldValue
{
    VarT Value{};             // the user-facing value
    uint16_t InternalIndex{}; // index of Value in the VariableIDMap (its bit position in masks)

    WorldValue() = default;

    WorldValue(VarT value, uint16_t internalIndex)
        : Value(value)
        , InternalIndex(internalIndex)
    {
    }

    operator VarT() const
    {
        return this->Value;
    }
};
/**
 * @brief Per-cell bit masks of the values that are still possible.
 *
 * Bit i of a cell's mask is set while the value with index i is still allowed there.
 * A cell with exactly one bit set is collapsed; a cell with no bits set is a contradiction.
 * @tparam MaskType Unsigned integer type with at least one bit per variable.
 */
template <typename MaskType>
class Wave {
public:
    Wave() = default;

    /**
     * @brief Create @p size cells, each allowing all of the first @p variableAmount values.
     * @param size           Number of cells.
     * @param variableAmount Number of values; must not exceed the bit width of MaskType.
     */
    Wave(size_t size, size_t variableAmount)
        : m_data(size, AllowAllMask(variableAmount))
    {}

    /// Intersect a cell's possibilities with @p mask (bits can only be cleared, never set).
    void Collapse(size_t index, MaskType mask) { m_data[index] &= mask; }
    size_t size() const { return m_data.size(); }
    /// Number of values still possible in the cell.
    size_t Entropy(size_t index) const { return std::popcount(m_data[index]); }
    bool IsCollapsed(size_t index) const { return Entropy(index) == 1; }
    bool IsFullyCollapsed() const { return std::all_of(m_data.begin(), m_data.end(), [](MaskType value) { return std::popcount(value) == 1; }); }
    bool HasContradiction() const { return std::any_of(m_data.begin(), m_data.end(), [](MaskType value) { return value == 0; }); }
    bool IsContradicted(size_t index) const { return m_data[index] == 0; }
    /// Index of the lowest set bit; meaningful only for a collapsed cell.
    uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); }
    MaskType GetMask(size_t index) const { return m_data[index]; }

private:
    /**
     * @brief Mask with the low @p variableAmount bits set.
     *
     * `(1 << n) - 1` on a plain int is undefined for n >= 32, which broke 32-value uint32_t
     * masks and every uint64_t mask. A full-width request is answered with max() directly.
     */
    static MaskType AllowAllMask(size_t variableAmount)
    {
        constexpr size_t maskBits = std::numeric_limits<MaskType>::digits;
        assert(variableAmount <= maskBits && "More variables than bits in MaskType");
        if (variableAmount >= maskBits) {
            return std::numeric_limits<MaskType>::max();
        }
        return static_cast<MaskType>((MaskType{1} << variableAmount) - 1);
    }

    std::vector<MaskType> m_data;
};
/**
* @brief Constrainer class used in constraint functions to limit possible values for other cells
*/
template <typename VariableIDMapT>
class Constrainer {
public:
using MaskType = typename VariableIDMapT::MaskType;
public:
Constrainer(Wave<MaskType>& wave, std::queue<size_t>& propagationQueue)
: m_wave(wave)
, m_propagationQueue(propagationQueue)
{}
/**
* @brief Constrain a cell to exclude specific values
* @param cellId The ID of the cell to constrain
* @param forbiddenValues The set of forbidden values for this cell
*/
template <typename VariableIDMapT::Type ... ExcludedValues>
void Exclude(size_t cellId) {
static_assert(sizeof...(ExcludedValues) > 0, "At least one excluded value must be provided");
ApplyMask(cellId, ~VariableIDMapT::template GetMask<ExcludedValues...>());
}
void Exclude(WorldValue<typename VariableIDMapT::Type> value, size_t cellId) {
ApplyMask(cellId, ~(1 << value.InternalIndex));
}
/**
* @brief Constrain a cell to only allow one specific value
* @param cellId The ID of the cell to constrain
* @param value The only allowed value for this cell
*/
template <typename VariableIDMapT::Type ... AllowedValues>
void Only(size_t cellId) {
static_assert(sizeof...(AllowedValues) > 0, "At least one allowed value must be provided");
ApplyMask(cellId, VariableIDMapT::template GetMask<AllowedValues...>());
}
void Only(WorldValue<typename VariableIDMapT::Type> value, size_t cellId) {
ApplyMask(cellId, 1 << value.InternalIndex);
}
private:
void ApplyMask(size_t cellId, MaskType mask) {
if (m_wave.IsCollapsed(cellId)) return;
m_wave.Collapse(cellId, mask);
assert(!m_wave.HasContradiction() && "Contradiction found");
if (m_wave.IsCollapsed(cellId)) {
m_propagationQueue.push(cellId);
}
}
private:
Wave<MaskType>& m_wave;
std::queue<size_t>& m_propagationQueue;
};
/**
* @brief Variable definition with its constraint function
*/
template<typename WorldT, typename VarT, typename VariableIDMapT>
struct VariableData {
VarT value{};
std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc{};
VariableData() = default;
VariableData(VarT value, std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc)
: value(value)
, constraintFunc(constraintFunc)
{}
};
/**
 * @brief Main WFC class implementing the Wave Function Collapse algorithm.
 *
 * m_variables[i] holds the constraint function for the value whose bit index in the wave is i.
 * Run() alternates between two steps until every cell is collapsed (success) or some cell has
 * no possibilities left (failure):
 *  - propagating constraints from newly collapsed cells;
 *  - collapsing the undecided cell with the fewest remaining possibilities to a random allowed value.
 */
template<typename WorldT, typename VarT, typename VariableIDMapT = VariableIDMap<VarT>>
class WFC {
public:
    static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
    using MaskType = typename VariableIDMapT::MaskType;

public:
    /**
     * @brief State of a single solve: target world, propagation queue, wave and RNG.
     *
     * The wave starts with every cell allowing all `variables.size()` values. The RNG is seeded
     * from std::random_device, so results generally differ between runs.
     */
    struct WorldSolver {
        WorldT& world;
        std::queue<size_t> propagationQueue;
        Wave<MaskType> wave;
        std::mt19937 rng;

        WorldSolver(WorldT& world, const std::vector<VariableData<WorldT, VarT, VariableIDMapT>>& variables)
            : world(world)
            , propagationQueue()
            , wave(world.size(), variables.size())
            , rng(std::random_device{}())
        {}
    };

public:
    WFC(std::vector<VariableData<WorldT, VarT, VariableIDMapT>>&& variables)
        : m_variables(std::move(variables))
    {}

public:
    /**
     * @brief Solve @p world with fresh solver state.
     * @return true if a solution was found (and written into @p world), false otherwise.
     */
    bool Run(WorldT& world)
    {
        WorldSolver worldSolver(world, m_variables);
        return Run(worldSolver);
    }

    /**
     * @brief Run the WFC algorithm on existing solver state.
     * @return true if every cell collapsed (the world is then populated), false if a contradiction occurred.
     */
    bool Run(WorldSolver& worldSolver)
    {
        // Each pass that neither succeeds nor fails collapses one more undecided cell, and masks
        // only ever lose bits, so size() + 1 passes always reach a verdict. A fixed cap would
        // make large worlds fail without any contradiction.
        const size_t maxPasses = worldSolver.wave.size() + 1;
        for (size_t pass = 0; pass < maxPasses; ++pass)
        {
            Propagate(worldSolver);
            if (worldSolver.wave.IsFullyCollapsed()) {
                PopulateWorld(worldSolver);
                return true;
            }
            if (worldSolver.wave.HasContradiction()) {
                return false;
            }
            GetMinEntropyCell(worldSolver);
        }
        return false;
    }

    /**
     * @brief Get the value at a specific cell
     * @param cellId The cell ID
     * @return The value if collapsed, std::nullopt otherwise
     */
    std::optional<VarT> GetValue(const WorldSolver& worldSolver, int cellId) const {
        if (worldSolver.wave.IsCollapsed(cellId)) {
            auto variableId = worldSolver.wave.GetVariableID(cellId);
            return VariableIDMapT::GetValue(variableId);
        }
        return std::nullopt;
    }

    /**
     * @brief Get all values still possible for a cell, in index order.
     * @param cellId The cell ID
     * @return The possible values (returned non-const so callers can move from it).
     */
    std::vector<VarT> GetPossibleValues(const WorldSolver& worldSolver, int cellId) const
    {
        std::vector<VarT> possibleValues;
        possibleValues.reserve(worldSolver.wave.Entropy(cellId));
        // Visit set bits from lowest to highest, clearing each in turn. This avoids `1 << i`,
        // an int that sign-extends at i == 31 and overflows beyond it for 64-bit masks.
        for (MaskType remaining = worldSolver.wave.GetMask(cellId); remaining != 0;
             remaining = static_cast<MaskType>(remaining & (remaining - 1)))
        {
            possibleValues.push_back(VariableIDMapT::GetValue(static_cast<size_t>(std::countr_zero(remaining))));
        }
        return possibleValues;
    }

private:
    /**
     * @brief Collapse the undecided cell with the fewest possibilities to one random allowed value.
     *
     * Ties go to the lowest cell index. The collapsed cell is queued for propagation.
     * @return true if a cell was collapsed, false if no undecided cell exists.
     */
    bool GetMinEntropyCell(WorldSolver& worldSolver)
    {
        assert(worldSolver.propagationQueue.empty());
        constexpr size_t noCell = static_cast<size_t>(-1);
        size_t minEntropyCell = noCell;
        size_t minEntropy = static_cast<size_t>(-1);
        for (size_t i = 0; i < worldSolver.wave.size(); ++i) {
            const size_t entropy = worldSolver.wave.Entropy(i);
            if (entropy > 1 && entropy < minEntropy) {
                minEntropy = entropy;
                minEntropyCell = i;
            }
        }
        assert(minEntropyCell != noCell && "No undecided cell left to collapse");
        if (minEntropyCell == noCell) {
            return false;
        }
        // Pick uniformly among the remaining values: draw a rank, then map it to its bit index.
        std::uniform_int_distribution<size_t> dist(0, minEntropy - 1);
        const int selectedValue = FindNthSetBit(worldSolver.wave.GetMask(minEntropyCell), static_cast<int>(dist(worldSolver.rng)));
        assert(selectedValue >= 0 && static_cast<size_t>(selectedValue) < VariableIDMapT::ValuesRegisteredAmount && "Selected Value went outside bounds");
        // Shift a MaskType-wide 1; an int `1 << selectedValue` is wrong for bit 31+ of 64-bit masks.
        worldSolver.wave.Collapse(minEntropyCell, static_cast<MaskType>(MaskType{1} << selectedValue));
        assert(worldSolver.wave.IsCollapsed(minEntropyCell) && "Cell was not collapsed correctly");
        worldSolver.propagationQueue.push(minEntropyCell);
        return true;
    }

    /// Run the constraint function of every queued (collapsed) cell until the queue drains.
    void Propagate(WorldSolver& worldSolver)
    {
        while (!worldSolver.propagationQueue.empty())
        {
            const size_t cellId = worldSolver.propagationQueue.front();
            worldSolver.propagationQueue.pop();
            assert(worldSolver.wave.IsCollapsed(cellId) && "Cell was not collapsed");
            const uint16_t variableID = worldSolver.wave.GetVariableID(cellId);
            assert(variableID < m_variables.size() && "No VariableData for this value");
            const auto& constraintFunc = m_variables[variableID].constraintFunc;
            // A value that was registered but never given a constraint has an empty
            // std::function; calling it would throw std::bad_function_call, so skip it.
            if (!constraintFunc) {
                continue;
            }
            Constrainer<VariableIDMapT> constrainer(worldSolver.wave, worldSolver.propagationQueue);
            constraintFunc(worldSolver.world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
        }
    }

    /// Write every cell's collapsed value into the world (call only when fully collapsed).
    void PopulateWorld(WorldSolver& worldSolver)
    {
        for (size_t i = 0; i < worldSolver.wave.size(); ++i)
        {
            worldSolver.world.setValue(i, VariableIDMapT::GetValue(worldSolver.wave.GetVariableID(i)));
        }
    }

    std::vector<VariableData<WorldT, VarT, VariableIDMapT>> m_variables {};
};
/**
* @brief Builder class for creating WFC instances
*/
template<typename WorldT, typename VarT, typename VariableIDMapT = VariableIDMap<VarT>>
class Builder {
public:
Builder() = default;
Builder(std::vector<VariableData<WorldT, VarT, VariableIDMapT>>&& variables)
: m_variables(std::move(variables))
{}
public:
template <VarT ... Values>
auto DefineIDs()
{
using NewVariableIDMapT = typename VariableIDMapT::template Merge<Values...>;
// reinterpret_cast is used to be able to move the variables with an outdated VariableIDMap to the new VariableIDMap. The previous indices still work.
return Builder<WorldT, VarT, NewVariableIDMapT>(std::move(reinterpret_cast<std::vector<VariableData<WorldT, VarT, NewVariableIDMapT>>&>(m_variables)));
}
/**
* @brief Add a variable with its constraint function
* @param value The variable value
* @param constraintFunc Function that defines constraints when this variable is placed
* @return Reference to this builder for method chaining
*/
template <VarT ... Values>
Builder& Variable(const std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc) {
m_variables.resize(VariableIDMapT::ValuesRegisteredAmount);
Variable_Internal<Values...>(constraintFunc);
return *this;
}
/**
* @brief Build the WFC instance
* @param world The world instance to work with
* @return A unique_ptr to the created WFC instance
*/
auto build() {
return WFC<WorldT, VarT, VariableIDMapT>(std::move(m_variables));
}
private:
template <VarT Value, VarT ... Values>
void Variable_Internal(const std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc)
{
m_variables[VariableIDMapT::template GetIndex<Value>()] = VariableData<WorldT, VarT, VariableIDMapT>{
Value,
constraintFunc
};
if constexpr (sizeof...(Values) > 0) {
Variable_Internal<Values...>(constraintFunc);
}
}
private:
std::vector<VariableData<WorldT, VarT, VariableIDMapT>> m_variables;
};
} // namespace WFC