446 lines
14 KiB
C++
446 lines
14 KiB
C++
#pragma once
|
|
|
|
#include <algorithm>
#include <array>
#include <bit>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <queue>
#include <random>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
|
|
|
|
namespace WFC {
|
|
|
|
inline int FindNthSetBit(size_t num, int n) {
|
|
auto popCount = std::popcount(num);
|
|
assert(n < std::popcount(num) && "index is out of range");
|
|
int bitCount = 0;
|
|
while (num) {
|
|
if (bitCount == n) {
|
|
return std::countr_zero(num); // Index of the current set bit
|
|
}
|
|
bitCount++;
|
|
num &= (num - 1); // turn of lowest set bit
|
|
}
|
|
assert(bitCount < popCount && "out of bounds");
|
|
return bitCount;
|
|
}
|
|
|
|
template<typename T>
|
|
concept WorldType = requires(T world, size_t id, typename T::ValueType value) {
|
|
{ world.size() } -> std::convertible_to<size_t>;
|
|
{ world.setValue(id, value) };
|
|
typename T::ValueType;
|
|
};
|
|
|
|
// Forward declarations so the templates below can refer to each other.
// NOTE(review): no definition of the `Variable` class template is visible in
// this header; it appears to be declared only.
template <class VarT>
struct WorldValue;

template <class MaskType>
class Wave;

template <class VariableIDMapT>
class Constrainer;

template <class WorldT, class VarT, class VariableIDMapT>
class Variable;

template <class WorldT, class VarT, class VariableIDMapT>
class WFC;
|
|
|
|
/**
|
|
* @brief Class to map variable values to indices at compile time
|
|
*
|
|
* This class is used to map variable values to indices at compile time.
|
|
* It is a compile-time map of variable values to indices.
|
|
*/
|
|
template <typename VarT, VarT ... Values>
|
|
class VariableIDMap {
|
|
public:
|
|
|
|
using Type = VarT;
|
|
static constexpr size_t ValuesRegisteredAmount = sizeof...(Values);
|
|
|
|
using MaskType = typename std::conditional<
|
|
ValuesRegisteredAmount <= 8,
|
|
uint8_t,
|
|
typename std::conditional<
|
|
ValuesRegisteredAmount <= 16,
|
|
uint16_t,
|
|
typename std::conditional<
|
|
ValuesRegisteredAmount <= 32,
|
|
uint32_t,
|
|
uint64_t
|
|
>::type
|
|
>::type
|
|
>::type;
|
|
|
|
template <VarT ... AdditionalValues>
|
|
using Merge = VariableIDMap<VarT, Values..., AdditionalValues...>;
|
|
|
|
template <VarT Value>
|
|
static consteval bool HasValue()
|
|
{
|
|
constexpr VarT arr[] = {Values...};
|
|
constexpr size_t size = sizeof...(Values);
|
|
|
|
for (size_t i = 0; i < size; ++i)
|
|
if (arr[i] == Value)
|
|
return true;
|
|
return false;
|
|
}
|
|
|
|
template <VarT Value>
|
|
static consteval size_t GetIndex()
|
|
{
|
|
static_assert(HasValue<Value>(), "Value was not defined");
|
|
constexpr VarT arr[] = {Values...};
|
|
constexpr size_t size = sizeof...(Values);
|
|
|
|
for (size_t i = 0; i < size; ++i)
|
|
if (arr[i] == Value)
|
|
return i;
|
|
|
|
return static_cast<size_t>(-1); // This line is unreachable if value is found
|
|
}
|
|
|
|
static VarT GetValue(size_t index) {
|
|
assert(index < sizeof...(Values));
|
|
constexpr VarT arr[] = {Values...};
|
|
return arr[index];
|
|
}
|
|
|
|
template <VarT ... MaskValues>
|
|
static consteval MaskType GetMask()
|
|
{
|
|
return (0 | ... | (1 << GetIndex<MaskValues>()));
|
|
}
|
|
|
|
static consteval size_t size() { return sizeof...(Values); }
|
|
};
|
|
|
|
/**
 * @brief A value placed in the world, paired with its index in the VariableIDMap.
 *
 * Converts implicitly to the plain value, so it can be used wherever a VarT is expected.
 */
template <typename VarT>
struct WorldValue
{
    VarT Value{};               // the value itself
    uint16_t InternalIndex{};   // its index (bit position) in the VariableIDMap

    WorldValue() = default;

    WorldValue(VarT value, uint16_t internalIndex) : Value(value), InternalIndex(internalIndex) {}

    operator VarT() const
    {
        return Value;
    }
};
|
|
|
|
template <typename MaskType>
|
|
class Wave {
|
|
public:
|
|
Wave() = default;
|
|
Wave(size_t size, size_t variableAmount) : m_data(size)
|
|
{
|
|
for (auto& wave : m_data) wave = (1 << variableAmount) - 1;
|
|
}
|
|
|
|
void Collapse(size_t index, MaskType mask) { m_data[index] &= mask; }
|
|
size_t size() const { return m_data.size(); }
|
|
size_t Entropy(size_t index) const { return std::popcount(m_data[index]); }
|
|
bool IsCollapsed(size_t index) const { return Entropy(index) == 1; }
|
|
bool IsFullyCollapsed() const { return std::all_of(m_data.begin(), m_data.end(), [](MaskType value) { return std::popcount(value) == 1; }); }
|
|
bool HasContradiction() const { return std::any_of(m_data.begin(), m_data.end(), [](MaskType value) { return value == 0; }); }
|
|
bool IsContradicted(size_t index) const { return m_data[index] == 0; }
|
|
uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); }
|
|
MaskType GetMask(size_t index) const { return m_data[index]; }
|
|
|
|
private:
|
|
std::vector<MaskType> m_data;
|
|
};
|
|
|
|
/**
|
|
* @brief Constrainer class used in constraint functions to limit possible values for other cells
|
|
*/
|
|
template <typename VariableIDMapT>
|
|
class Constrainer {
|
|
public:
|
|
using MaskType = typename VariableIDMapT::MaskType;
|
|
|
|
public:
|
|
Constrainer(Wave<MaskType>& wave, std::queue<size_t>& propagationQueue)
|
|
: m_wave(wave)
|
|
, m_propagationQueue(propagationQueue)
|
|
{}
|
|
|
|
/**
|
|
* @brief Constrain a cell to exclude specific values
|
|
* @param cellId The ID of the cell to constrain
|
|
* @param forbiddenValues The set of forbidden values for this cell
|
|
*/
|
|
template <typename VariableIDMapT::Type ... ExcludedValues>
|
|
void Exclude(size_t cellId) {
|
|
static_assert(sizeof...(ExcludedValues) > 0, "At least one excluded value must be provided");
|
|
ApplyMask(cellId, ~VariableIDMapT::template GetMask<ExcludedValues...>());
|
|
}
|
|
|
|
void Exclude(WorldValue<typename VariableIDMapT::Type> value, size_t cellId) {
|
|
ApplyMask(cellId, ~(1 << value.InternalIndex));
|
|
}
|
|
|
|
/**
|
|
* @brief Constrain a cell to only allow one specific value
|
|
* @param cellId The ID of the cell to constrain
|
|
* @param value The only allowed value for this cell
|
|
*/
|
|
template <typename VariableIDMapT::Type ... AllowedValues>
|
|
void Only(size_t cellId) {
|
|
static_assert(sizeof...(AllowedValues) > 0, "At least one allowed value must be provided");
|
|
ApplyMask(cellId, VariableIDMapT::template GetMask<AllowedValues...>());
|
|
}
|
|
|
|
void Only(WorldValue<typename VariableIDMapT::Type> value, size_t cellId) {
|
|
ApplyMask(cellId, 1 << value.InternalIndex);
|
|
}
|
|
|
|
private:
|
|
void ApplyMask(size_t cellId, MaskType mask) {
|
|
if (m_wave.IsCollapsed(cellId)) return;
|
|
|
|
m_wave.Collapse(cellId, mask);
|
|
assert(!m_wave.HasContradiction() && "Contradiction found");
|
|
|
|
if (m_wave.IsCollapsed(cellId)) {
|
|
m_propagationQueue.push(cellId);
|
|
}
|
|
}
|
|
|
|
private:
|
|
Wave<MaskType>& m_wave;
|
|
std::queue<size_t>& m_propagationQueue;
|
|
};
|
|
|
|
/**
|
|
* @brief Variable definition with its constraint function
|
|
*/
|
|
template<typename WorldT, typename VarT, typename VariableIDMapT>
|
|
struct VariableData {
|
|
VarT value{};
|
|
std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc{};
|
|
|
|
VariableData() = default;
|
|
VariableData(VarT value, std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc)
|
|
: value(value)
|
|
, constraintFunc(constraintFunc)
|
|
{}
|
|
};
|
|
|
|
/**
|
|
* @brief Main WFC class implementing the Wave Function Collapse algorithm
|
|
*/
|
|
template<typename WorldT, typename VarT, typename VariableIDMapT = VariableIDMap<VarT>>
|
|
class WFC {
|
|
public:
|
|
static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
|
|
|
|
using MaskType = typename VariableIDMapT::MaskType;
|
|
|
|
public:
|
|
struct WorldSolver {
|
|
WorldT& world;
|
|
std::queue<size_t> propagationQueue;
|
|
Wave<MaskType> wave;
|
|
std::mt19937 rng;
|
|
|
|
WorldSolver(WorldT& world, const std::vector<VariableData<WorldT, VarT, VariableIDMapT>>& variables)
|
|
: world(world)
|
|
, propagationQueue()
|
|
, wave(world.size(), variables.size())
|
|
, rng(std::random_device{}())
|
|
{}
|
|
};
|
|
|
|
public:
|
|
WFC(std::vector<VariableData<WorldT, VarT, VariableIDMapT>>&& variables)
|
|
: m_variables(std::move(variables))
|
|
{}
|
|
|
|
public:
|
|
bool Run(WorldT& world)
|
|
{
|
|
WorldSolver worldSolver(world, m_variables);
|
|
return Run(worldSolver);
|
|
}
|
|
|
|
/**
|
|
* @brief Run the WFC algorithm to generate a solution
|
|
* @return true if a solution was found, false if contradiction occurred
|
|
*/
|
|
bool Run(WorldSolver& worldSolver)
|
|
{
|
|
for (size_t i = 0; i < 1024; ++i)
|
|
{
|
|
Propagate(worldSolver);
|
|
|
|
if (worldSolver.wave.IsFullyCollapsed()) {
|
|
PopulateWorld(worldSolver);
|
|
return true;
|
|
} else if (worldSolver.wave.HasContradiction()) {
|
|
return false;
|
|
} else {
|
|
GetMinEntropyCell(worldSolver);
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/**
|
|
* @brief Get the value at a specific cell
|
|
* @param cellId The cell ID
|
|
* @return The value if collapsed, std::nullopt otherwise
|
|
*/
|
|
std::optional<VarT> GetValue(WorldSolver& worldSolver, int cellId) const {
|
|
if (worldSolver.wave.IsCollapsed(cellId)) {
|
|
auto variableId = worldSolver.wave.GetVariableID(cellId);
|
|
return VariableIDMapT::GetValue(variableId);
|
|
}
|
|
return std::nullopt;
|
|
}
|
|
|
|
/**
|
|
* @brief Get all possible values for a cell
|
|
* @param cellId The cell ID
|
|
* @return Set of possible values
|
|
*/
|
|
const std::vector<VarT> GetPossibleValues(WorldSolver& worldSolver, int cellId) const
|
|
{
|
|
std::vector<VarT> possibleValues;
|
|
MaskType mask = worldSolver.wave.GetMask(cellId);
|
|
for (size_t i = 0; i < m_variables.size(); ++i) {
|
|
if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i));
|
|
}
|
|
return possibleValues;
|
|
}
|
|
|
|
private:
|
|
bool GetMinEntropyCell(WorldSolver& worldSolver)
|
|
{
|
|
assert(worldSolver.propagationQueue.empty());
|
|
|
|
// Find cell with minimum entropy > 1
|
|
size_t minEntropyCell = static_cast<size_t>(-1);
|
|
size_t minEntropy = static_cast<size_t>(-1);
|
|
|
|
for (size_t i = 0; i < worldSolver.wave.size(); ++i) {
|
|
size_t entropy = worldSolver.wave.Entropy(i);
|
|
if (entropy > 1 && entropy < minEntropy) {
|
|
minEntropy = entropy;
|
|
minEntropyCell = i;
|
|
}
|
|
}
|
|
assert(!worldSolver.wave.IsCollapsed(minEntropyCell));
|
|
|
|
// Randomly select a value from possible values
|
|
size_t availableValues = worldSolver.wave.Entropy(minEntropyCell);
|
|
std::uniform_int_distribution<size_t> dist(0, availableValues - 1);
|
|
size_t selectedValue = FindNthSetBit(worldSolver.wave.GetMask(minEntropyCell), dist(worldSolver.rng));
|
|
assert(selectedValue < VariableIDMapT::ValuesRegisteredAmount && "Selected Value went outside bounds");
|
|
|
|
// Collapse the cell to the selected value
|
|
worldSolver.wave.Collapse(minEntropyCell, 1 << selectedValue);
|
|
assert(worldSolver.wave.IsCollapsed(minEntropyCell) && "Cell was not collapsed correctly");
|
|
|
|
worldSolver.propagationQueue.push(minEntropyCell);
|
|
|
|
return true;
|
|
}
|
|
|
|
void Propagate(WorldSolver& worldSolver)
|
|
{
|
|
while (!worldSolver.propagationQueue.empty())
|
|
{
|
|
size_t cellId = worldSolver.propagationQueue.front();
|
|
worldSolver.propagationQueue.pop();
|
|
|
|
assert(worldSolver.wave.IsCollapsed(cellId) && "Cell was not collapsed");
|
|
|
|
uint16_t variableID = worldSolver.wave.GetVariableID(cellId);
|
|
Constrainer<VariableIDMapT> constrainer(worldSolver.wave, worldSolver.propagationQueue);
|
|
m_variables[variableID].constraintFunc(worldSolver.world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
|
|
}
|
|
}
|
|
|
|
void PopulateWorld(WorldSolver& worldSolver)
|
|
{
|
|
for (size_t i = 0; i < worldSolver.wave.size(); ++i)
|
|
{
|
|
worldSolver.world.setValue(i, VariableIDMapT::GetValue(worldSolver.wave.GetVariableID(i)));
|
|
}
|
|
}
|
|
|
|
std::vector<VariableData<WorldT, VarT, VariableIDMapT>> m_variables {};
|
|
};
|
|
|
|
/**
|
|
* @brief Builder class for creating WFC instances
|
|
*/
|
|
template<typename WorldT, typename VarT, typename VariableIDMapT = VariableIDMap<VarT>>
|
|
class Builder {
|
|
public:
|
|
Builder() = default;
|
|
Builder(std::vector<VariableData<WorldT, VarT, VariableIDMapT>>&& variables)
|
|
: m_variables(std::move(variables))
|
|
{}
|
|
|
|
public:
|
|
template <VarT ... Values>
|
|
auto DefineIDs()
|
|
{
|
|
using NewVariableIDMapT = typename VariableIDMapT::template Merge<Values...>;
|
|
// reinterpret_cast is used to be able to move the variables with an outdated VariableIDMap to the new VariableIDMap. The previous indices still work.
|
|
return Builder<WorldT, VarT, NewVariableIDMapT>(std::move(reinterpret_cast<std::vector<VariableData<WorldT, VarT, NewVariableIDMapT>>&>(m_variables)));
|
|
}
|
|
|
|
/**
|
|
* @brief Add a variable with its constraint function
|
|
* @param value The variable value
|
|
* @param constraintFunc Function that defines constraints when this variable is placed
|
|
* @return Reference to this builder for method chaining
|
|
*/
|
|
template <VarT ... Values>
|
|
Builder& Variable(const std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc) {
|
|
m_variables.resize(VariableIDMapT::ValuesRegisteredAmount);
|
|
|
|
Variable_Internal<Values...>(constraintFunc);
|
|
return *this;
|
|
}
|
|
|
|
/**
|
|
* @brief Build the WFC instance
|
|
* @param world The world instance to work with
|
|
* @return A unique_ptr to the created WFC instance
|
|
*/
|
|
auto build() {
|
|
return WFC<WorldT, VarT, VariableIDMapT>(std::move(m_variables));
|
|
}
|
|
|
|
private:
|
|
template <VarT Value, VarT ... Values>
|
|
void Variable_Internal(const std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc)
|
|
{
|
|
m_variables[VariableIDMapT::template GetIndex<Value>()] = VariableData<WorldT, VarT, VariableIDMapT>{
|
|
Value,
|
|
constraintFunc
|
|
};
|
|
if constexpr (sizeof...(Values) > 0) {
|
|
Variable_Internal<Values...>(constraintFunc);
|
|
}
|
|
}
|
|
|
|
private:
|
|
std::vector<VariableData<WorldT, VarT, VariableIDMapT>> m_variables;
|
|
};
|
|
|
|
} // namespace WFC
|