Initial changes
This commit is contained in:
@@ -9,4 +9,5 @@
|
||||
|
||||
#include "wfc.hpp"
|
||||
#include "worlds.hpp"
|
||||
#include "wfc_builder.hpp"
|
||||
|
||||
|
||||
@@ -14,174 +14,17 @@
|
||||
#include <span>
|
||||
#include <tuple>
|
||||
|
||||
#include "wfc_utils.hpp"
|
||||
#include "wfc_variable_map.hpp"
|
||||
#include "wfc_allocator.hpp"
|
||||
#include "wfc_bit_container.hpp"
|
||||
#include "wfc_wave.hpp"
|
||||
#include "wfc_constrainer.hpp"
|
||||
#include "wfc_callbacks.hpp"
|
||||
#include "wfc_random.hpp"
|
||||
|
||||
namespace WFC {
|
||||
|
||||
inline constexpr void constexpr_assert(bool condition, const char* message = "") {
|
||||
if (!condition) throw message;
|
||||
}
|
||||
|
||||
inline int FindNthSetBit(size_t num, int n) {
|
||||
constexpr_assert(n < std::popcount(num), "index is out of range");
|
||||
int bitCount = 0;
|
||||
while (num) {
|
||||
if (bitCount == n) {
|
||||
return std::countr_zero(num); // Index of the current set bit
|
||||
}
|
||||
bitCount++;
|
||||
num &= (num - 1); // turn of lowest set bit
|
||||
}
|
||||
return bitCount;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
concept WorldType = requires(T world, size_t id, typename T::ValueType value) {
|
||||
{ world.size() } -> std::convertible_to<size_t>;
|
||||
{ world.setValue(id, value) };
|
||||
{ world.getValue(id) } -> std::convertible_to<typename T::ValueType>;
|
||||
typename T::ValueType;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Class to map variable values to indices at compile time
|
||||
*
|
||||
* This class is used to map variable values to indices at compile time.
|
||||
* It is a compile-time map of variable values to indices.
|
||||
*/
|
||||
template <typename VarT, VarT ... Values>
|
||||
class VariableIDMap {
|
||||
public:
|
||||
|
||||
using Type = VarT;
|
||||
static constexpr size_t ValuesRegisteredAmount = sizeof...(Values);
|
||||
|
||||
using MaskType = typename std::conditional<
|
||||
ValuesRegisteredAmount <= 8,
|
||||
uint8_t,
|
||||
typename std::conditional<
|
||||
ValuesRegisteredAmount <= 16,
|
||||
uint16_t,
|
||||
typename std::conditional<
|
||||
ValuesRegisteredAmount <= 32,
|
||||
uint32_t,
|
||||
uint64_t
|
||||
>::type
|
||||
>::type
|
||||
>::type;
|
||||
|
||||
template <VarT ... AdditionalValues>
|
||||
using Merge = VariableIDMap<VarT, Values..., AdditionalValues...>;
|
||||
|
||||
template <VarT Value>
|
||||
static consteval bool HasValue()
|
||||
{
|
||||
constexpr VarT arr[] = {Values...};
|
||||
constexpr size_t size = sizeof...(Values);
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if (arr[i] == Value)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
template <VarT Value>
|
||||
static consteval size_t GetIndex()
|
||||
{
|
||||
static_assert(HasValue<Value>(), "Value was not defined");
|
||||
constexpr VarT arr[] = {Values...};
|
||||
constexpr size_t size = ValuesRegisteredAmount;
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if (arr[i] == Value)
|
||||
return i;
|
||||
|
||||
return static_cast<size_t>(-1); // This line is unreachable if value is found
|
||||
}
|
||||
|
||||
template <VarT ... MaskValues>
|
||||
static consteval MaskType GetMask()
|
||||
{
|
||||
return (0 | ... | (1 << GetIndex<MaskValues>()));
|
||||
}
|
||||
|
||||
static std::span<const VarT> GetAllValues()
|
||||
{
|
||||
static const VarT allValues[]
|
||||
{
|
||||
Values...
|
||||
};
|
||||
return std::span<const VarT>{ allValues, ValuesRegisteredAmount };
|
||||
}
|
||||
|
||||
static VarT GetValue(size_t index) {
|
||||
constexpr_assert(index < ValuesRegisteredAmount);
|
||||
return GetAllValues()[index];
|
||||
}
|
||||
|
||||
static consteval VarT GetValueConsteval(size_t index)
|
||||
{
|
||||
constexpr VarT arr[] = {Values...};
|
||||
return arr[index];
|
||||
}
|
||||
|
||||
static consteval size_t size() { return ValuesRegisteredAmount; }
|
||||
};
|
||||
|
||||
template <typename ... ConstrainerFunctions>
|
||||
struct ConstrainerFunctionMap {
|
||||
public:
|
||||
static consteval size_t size() { return sizeof...(ConstrainerFunctions); }
|
||||
|
||||
using TupleType = std::tuple<ConstrainerFunctions...>;
|
||||
|
||||
template <typename ConstrainerFunctionPtrT>
|
||||
static ConstrainerFunctionPtrT GetFunction(size_t index)
|
||||
{
|
||||
static_assert((std::is_empty_v<ConstrainerFunctions> && ...), "Lambdas must not have any captures");
|
||||
static ConstrainerFunctionPtrT functions[] = {
|
||||
static_cast<ConstrainerFunctionPtrT>(ConstrainerFunctions{}) ...
|
||||
};
|
||||
return functions[index];
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to select the correct constrainer function based on the index and the value
|
||||
template<std::size_t I,
|
||||
typename VariableIDMapT,
|
||||
typename ConstrainerFunctionMapT,
|
||||
typename NewConstrainerFunctionT,
|
||||
typename SelectedIDsVariableIDMapT,
|
||||
typename EmptyFunctionT>
|
||||
using MergedConstrainerElementSelector =
|
||||
std::conditional_t<SelectedIDsVariableIDMapT::template HasValue<VariableIDMapT::GetValueConsteval(I)>(), // if the value is in the selected IDs
|
||||
NewConstrainerFunctionT,
|
||||
std::conditional_t<(I < ConstrainerFunctionMapT::size()), // if the index is within the size of the tuple
|
||||
std::tuple_element_t<std::min(I, ConstrainerFunctionMapT::size() - 1), typename ConstrainerFunctionMapT::TupleType>,
|
||||
EmptyFunctionT
|
||||
>
|
||||
>;
|
||||
|
||||
// Helper to make a merged constrainer function map
|
||||
template<typename VariableIDMapT,
|
||||
typename ConstrainerFunctionMapT,
|
||||
typename NewConstrainerFunctionT,
|
||||
typename SelectedIDsVariableIDMapT,
|
||||
typename EmptyFunctionT,
|
||||
std::size_t... Is>
|
||||
auto MakeMergedConstrainerIDMap(std::index_sequence<Is...>,VariableIDMapT*, ConstrainerFunctionMapT*, NewConstrainerFunctionT*, SelectedIDsVariableIDMapT*, EmptyFunctionT*)
|
||||
-> ConstrainerFunctionMap<MergedConstrainerElementSelector<Is, VariableIDMapT, ConstrainerFunctionMapT, NewConstrainerFunctionT, SelectedIDsVariableIDMapT, EmptyFunctionT>...>;
|
||||
|
||||
// Main alias for the merged constrainer function map
|
||||
template<typename VariableIDMapT,
|
||||
typename ConstrainerFunctionMapT,
|
||||
typename NewConstrainerFunctionT,
|
||||
typename SelectedIDsVariableIDMapT,
|
||||
typename EmptyFunctionT>
|
||||
using MergedConstrainerFunctionMap = decltype(
|
||||
MakeMergedConstrainerIDMap(std::make_index_sequence<VariableIDMapT::ValuesRegisteredAmount>{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr)
|
||||
);
|
||||
|
||||
template <typename VarT>
|
||||
struct WorldValue
|
||||
{
|
||||
@@ -199,150 +42,37 @@ public:
|
||||
uint16_t InternalIndex{};
|
||||
};
|
||||
|
||||
template <typename MaskType>
|
||||
class Wave {
|
||||
public:
|
||||
Wave() = default;
|
||||
Wave(size_t size, size_t variableAmount, WFCStackAllocator& allocator) : m_data(size, WFCStackAllocatorAdapter<MaskType>(allocator))
|
||||
{
|
||||
for (auto& wave : m_data) wave = (1 << variableAmount) - 1;
|
||||
}
|
||||
|
||||
Wave(const Wave& other) = default;
|
||||
|
||||
public:
|
||||
void Collapse(size_t index, MaskType mask) { m_data[index] &= mask; }
|
||||
size_t size() const { return m_data.size(); }
|
||||
size_t Entropy(size_t index) const { return std::popcount(m_data[index]); }
|
||||
bool IsCollapsed(size_t index) const { return Entropy(index) == 1; }
|
||||
bool IsFullyCollapsed() const { return std::all_of(m_data.begin(), m_data.end(), [](MaskType value) { return std::popcount(value) == 1; }); }
|
||||
bool HasContradiction() const { return std::any_of(m_data.begin(), m_data.end(), [](MaskType value) { return value == 0; }); }
|
||||
bool IsContradicted(size_t index) const { return m_data[index] == 0; }
|
||||
uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); }
|
||||
MaskType GetMask(size_t index) const { return m_data[index]; }
|
||||
|
||||
private:
|
||||
WFCVector<MaskType> m_data;
|
||||
template<typename T>
|
||||
concept WorldType = requires(T world, size_t id, typename T::ValueType value) {
|
||||
{ world.size() } -> std::convertible_to<size_t>;
|
||||
{ world.setValue(id, value) };
|
||||
{ world.getValue(id) } -> std::convertible_to<typename T::ValueType>;
|
||||
typename T::ValueType;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Constrainer class used in constraint functions to limit possible values for other cells
|
||||
*/
|
||||
template <typename VariableIDMapT>
|
||||
class Constrainer {
|
||||
public:
|
||||
using MaskType = typename VariableIDMapT::MaskType;
|
||||
|
||||
public:
|
||||
Constrainer(Wave<MaskType>& wave, WFCQueue<size_t>& propagationQueue)
|
||||
: m_wave(wave)
|
||||
, m_propagationQueue(propagationQueue)
|
||||
{}
|
||||
|
||||
/**
|
||||
* @brief Constrain a cell to exclude specific values
|
||||
* @param cellId The ID of the cell to constrain
|
||||
* @param forbiddenValues The set of forbidden values for this cell
|
||||
*/
|
||||
template <typename VariableIDMapT::Type ... ExcludedValues>
|
||||
void Exclude(size_t cellId) {
|
||||
static_assert(sizeof...(ExcludedValues) > 0, "At least one excluded value must be provided");
|
||||
ApplyMask(cellId, ~VariableIDMapT::template GetMask<ExcludedValues...>());
|
||||
}
|
||||
|
||||
void Exclude(WorldValue<typename VariableIDMapT::Type> value, size_t cellId) {
|
||||
ApplyMask(cellId, ~(1 << value.InternalIndex));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Constrain a cell to only allow one specific value
|
||||
* @param cellId The ID of the cell to constrain
|
||||
* @param value The only allowed value for this cell
|
||||
*/
|
||||
template <typename VariableIDMapT::Type ... AllowedValues>
|
||||
void Only(size_t cellId) {
|
||||
static_assert(sizeof...(AllowedValues) > 0, "At least one allowed value must be provided");
|
||||
ApplyMask(cellId, VariableIDMapT::template GetMask<AllowedValues...>());
|
||||
}
|
||||
|
||||
void Only(WorldValue<typename VariableIDMapT::Type> value, size_t cellId) {
|
||||
ApplyMask(cellId, 1 << value.InternalIndex);
|
||||
}
|
||||
|
||||
private:
|
||||
void ApplyMask(size_t cellId, MaskType mask) {
|
||||
bool wasCollapsed = m_wave.IsCollapsed(cellId);
|
||||
|
||||
m_wave.Collapse(cellId, mask);
|
||||
|
||||
bool collapsed = m_wave.IsCollapsed(cellId);
|
||||
if (!wasCollapsed && collapsed) {
|
||||
m_propagationQueue.push(cellId);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Wave<MaskType>& m_wave;
|
||||
WFCQueue<size_t>& m_propagationQueue;
|
||||
* @brief Concept to validate constrainer function signature
|
||||
* The function must be callable with parameters: (WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)
|
||||
*/
|
||||
template <typename T, typename WorldT, typename VarT, typename VariableIDMapT>
|
||||
concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue<VarT> value, Constrainer<VariableIDMapT>& constrainer) {
|
||||
func(world, index, value, constrainer);
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Variable definition with its constraint function
|
||||
*/
|
||||
template<typename WorldT, typename VarT, typename VariableIDMapT>
|
||||
struct VariableData {
|
||||
VarT value{};
|
||||
std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc{};
|
||||
|
||||
VariableData() = default;
|
||||
VariableData(VarT value, std::function<void(WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)> constraintFunc)
|
||||
: value(value)
|
||||
, constraintFunc(constraintFunc)
|
||||
{}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Empty callback function
|
||||
* @param World The world type
|
||||
*/
|
||||
template <typename World>
|
||||
using EmptyCallback = decltype([](World&){});
|
||||
|
||||
/**
|
||||
* @brief Callback struct
|
||||
* @param WorldT The world type
|
||||
* @param AllCellsCollapsedCallbackT The all cells collapsed callback type
|
||||
* @param CellCollapsedCallbackT The cell collapsed callback type
|
||||
* @param ContradictionCallbackT The contradiction callback type
|
||||
* @param BranchCallbackT The branch callback type
|
||||
*/
|
||||
template <typename WorldT,
|
||||
typename CellCollapsedCallbackT = EmptyCallback<WorldT>,
|
||||
typename ContradictionCallbackT = EmptyCallback<WorldT>,
|
||||
typename BranchCallbackT = EmptyCallback<WorldT>
|
||||
>
|
||||
struct Callbacks
|
||||
{
|
||||
using CellCollapsedCallback = CellCollapsedCallbackT;
|
||||
using ContradictionCallback = ContradictionCallbackT;
|
||||
using BranchCallback = BranchCallbackT;
|
||||
|
||||
template <typename NewCellCollapsedCallbackT>
|
||||
using SetCellCollapsedCallbackT = Callbacks<WorldT, NewCellCollapsedCallbackT, ContradictionCallbackT, BranchCallbackT>;
|
||||
template <typename NewContradictionCallbackT>
|
||||
using SetContradictionCallbackT = Callbacks<WorldT, CellCollapsedCallbackT, NewContradictionCallbackT, BranchCallbackT>;
|
||||
template <typename NewBranchCallbackT>
|
||||
using SetBranchCallbackT = Callbacks<WorldT, CellCollapsedCallbackT, ContradictionCallbackT, NewBranchCallbackT>;
|
||||
|
||||
static consteval bool HasCellCollapsedCallback() { return !std::is_same_v<CellCollapsedCallbackT, EmptyCallback<WorldT>>; }
|
||||
static consteval bool HasContradictionCallback() { return !std::is_same_v<ContradictionCallbackT, EmptyCallback<WorldT>>; }
|
||||
static consteval bool HasBranchCallback() { return !std::is_same_v<BranchCallbackT, EmptyCallback<WorldT>>; }
|
||||
* @brief Concept to validate random selector function signature
|
||||
* The function must be callable with parameters: (std::span<const VarT>) and return size_t
|
||||
*/
|
||||
template <typename T, typename VarT>
|
||||
concept RandomSelectorFunction = requires(T func, std::span<const VarT> possibleValues) {
|
||||
{ func(possibleValues) } -> std::convertible_to<size_t>;
|
||||
{ func.rng(static_cast<uint32_t>(1)) } -> std::convertible_to<uint32_t>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Main WFC class implementing the Wave Function Collapse algorithm
|
||||
*/
|
||||
template<typename WorldT, typename VarT,
|
||||
template<typename WorldT, typename VarT, size_t WorldSize = 0,
|
||||
typename VariableIDMapT = VariableIDMap<VarT>,
|
||||
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
|
||||
typename CallbacksT = Callbacks<WorldT>,
|
||||
@@ -351,14 +81,14 @@ class WFC {
|
||||
public:
|
||||
static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
|
||||
|
||||
using MaskType = typename VariableIDMapT::MaskType;
|
||||
using ElementT = typename VariableIDMapT::ElementT;
|
||||
|
||||
public:
|
||||
struct SolverState
|
||||
{
|
||||
WorldT& world;
|
||||
WFCQueue<size_t> propagationQueue;
|
||||
Wave<MaskType> wave;
|
||||
Wave<VariableIDMapT, WorldSize> wave;
|
||||
std::mt19937& rng;
|
||||
RandomSelectorT& randomSelector;
|
||||
WFCStackAllocator& allocator;
|
||||
@@ -479,7 +209,7 @@ public:
|
||||
static const std::vector<VarT> GetPossibleValues(SolverState& state, int cellId)
|
||||
{
|
||||
std::vector<VarT> possibleValues;
|
||||
MaskType mask = state.wave.GetMask(cellId);
|
||||
ElementT mask = state.wave.GetMask(cellId);
|
||||
for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) {
|
||||
if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i));
|
||||
}
|
||||
@@ -489,7 +219,7 @@ public:
|
||||
private:
|
||||
static void CollapseCell(SolverState& state, size_t cellId, uint16_t value)
|
||||
{
|
||||
constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (1 << value));
|
||||
constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (ElementT(1) << value));
|
||||
state.wave.Collapse(cellId, 1 << value);
|
||||
constexpr_assert(state.wave.IsCollapsed(cellId));
|
||||
|
||||
@@ -522,14 +252,14 @@ private:
|
||||
// create a list of possible values
|
||||
uint16_t availableValues = static_cast<uint16_t>(state.wave.Entropy(minEntropyCell));
|
||||
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> possibleValues; // inplace vector
|
||||
MaskType mask = state.wave.GetMask(minEntropyCell);
|
||||
ElementT mask = state.wave.GetMask(minEntropyCell);
|
||||
for (size_t i = 0; i < availableValues; ++i)
|
||||
{
|
||||
uint16_t index = static_cast<uint16_t>(std::countr_zero(mask)); // get the index of the lowest set bit
|
||||
constexpr_assert(index < VariableIDMapT::ValuesRegisteredAmount, "Possible value went outside bounds");
|
||||
|
||||
possibleValues[i] = index;
|
||||
constexpr_assert(((mask & (1 << index)) != 0), "Possible value was not set");
|
||||
constexpr_assert(((mask & (ElementT(1) << index)) != 0), "Possible value was not set");
|
||||
|
||||
mask = mask & (mask - 1); // turn off lowest set bit
|
||||
}
|
||||
@@ -563,9 +293,9 @@ private:
|
||||
}
|
||||
|
||||
// remove the failure state from the wave
|
||||
constexpr_assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) != 0, "Possible value was not set");
|
||||
constexpr_assert((state.wave.GetMask(minEntropyCell) & (ElementT(1) << selectedValue)) != 0, "Possible value was not set");
|
||||
state.wave.Collapse(minEntropyCell, ~(1 << selectedValue));
|
||||
constexpr_assert((state.wave.GetMask(minEntropyCell) & (1 << selectedValue)) == 0, "Wave was not collapsed correctly");
|
||||
constexpr_assert((state.wave.GetMask(minEntropyCell) & (ElementT(1) << selectedValue)) == 0, "Wave was not collapsed correctly");
|
||||
|
||||
// swap replacement value with the last value
|
||||
std::swap(possibleValues[randomIndex], possibleValues[--availableValues]);
|
||||
@@ -622,236 +352,4 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Concept to validate constrainer function signature
|
||||
* The function must be callable with parameters: (WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)
|
||||
*/
|
||||
template <typename T, typename WorldT, typename VarT, typename VariableIDMapT>
|
||||
concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue<VarT> value, Constrainer<VariableIDMapT>& constrainer) {
|
||||
func(world, index, value, constrainer);
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Concept to validate random selector function signature
|
||||
* The function must be callable with parameters: (std::span<const VarT>) and return size_t
|
||||
*/
|
||||
template <typename T, typename VarT>
|
||||
concept RandomSelectorFunction = requires(T func, std::span<const VarT> possibleValues) {
|
||||
{ func(possibleValues) } -> std::convertible_to<size_t>;
|
||||
{ func.rng(static_cast<uint32_t>(1)) } -> std::convertible_to<uint32_t>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Default constexpr random selector using a simple seed-based algorithm
|
||||
* This provides a compile-time random selection that maintains state between calls
|
||||
*/
|
||||
template <typename VarT>
|
||||
class DefaultRandomSelector {
|
||||
private:
|
||||
mutable uint32_t m_seed;
|
||||
|
||||
public:
|
||||
constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {}
|
||||
|
||||
constexpr size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
|
||||
// Simple linear congruential generator for constexpr compatibility
|
||||
return static_cast<size_t>(rng(possibleValues.size()));
|
||||
}
|
||||
|
||||
constexpr uint32_t rng(uint32_t max) {
|
||||
m_seed = m_seed * 1103515245 + 12345;
|
||||
return m_seed % max;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Advanced random selector using std::mt19937 and std::uniform_int_distribution
|
||||
* This provides high-quality randomization for runtime use
|
||||
*/
|
||||
template <typename VarT>
|
||||
class AdvancedRandomSelector {
|
||||
private:
|
||||
std::mt19937& m_rng;
|
||||
|
||||
public:
|
||||
explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {}
|
||||
|
||||
size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
|
||||
return rng(possibleValues.size());
|
||||
}
|
||||
|
||||
uint32_t rng(uint32_t max) {
|
||||
std::uniform_int_distribution<uint32_t> dist(0, max);
|
||||
return dist(m_rng);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Weight specification for a specific value
|
||||
* @tparam Value The variable value
|
||||
* @tparam Weight The 16-bit weight for this value
|
||||
*/
|
||||
template <typename VarT, VarT Value, uint16_t WeightValue>
|
||||
struct Weight {
|
||||
static constexpr VarT value = Value;
|
||||
static constexpr uint16_t weight = WeightValue;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compile-time weights storage for weighted random selection
|
||||
* @tparam VarT The variable type
|
||||
* @tparam VariableIDMapT The variable ID map type
|
||||
* @tparam DefaultWeight The default weight for values not explicitly specified
|
||||
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
|
||||
*/
|
||||
template <typename VarT, typename VariableIDMapT, uint16_t DefaultWeight, typename... WeightSpecs>
|
||||
class WeightsMap {
|
||||
private:
|
||||
static constexpr size_t NumWeights = sizeof...(WeightSpecs);
|
||||
|
||||
// Helper to get weight for a specific value
|
||||
static consteval uint16_t GetWeightForValue(VarT targetValue) {
|
||||
// Check each weight spec to find the target value
|
||||
uint16_t weight = DefaultWeight;
|
||||
((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...);
|
||||
return weight;
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Get the weight for a specific value at compile time
|
||||
* @tparam TargetValue The value to get weight for
|
||||
* @return The weight for the value
|
||||
*/
|
||||
template <VarT TargetValue>
|
||||
static consteval uint16_t GetWeight() {
|
||||
return GetWeightForValue(TargetValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get weights array for all registered values
|
||||
* @return Array of weights corresponding to all registered values
|
||||
*/
|
||||
static consteval std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> GetWeightsArray() {
|
||||
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> weights{};
|
||||
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
weights[i] = GetWeightForValue(VariableIDMapT::GetValueConsteval(i));
|
||||
}
|
||||
|
||||
return weights;
|
||||
}
|
||||
|
||||
static consteval uint32_t GetTotalWeight() {
|
||||
uint32_t totalWeight = 0;
|
||||
auto weights = GetWeightsArray();
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
totalWeight += weights[i];
|
||||
}
|
||||
return totalWeight;
|
||||
}
|
||||
|
||||
static consteval std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> GetCumulativeWeightsArray() {
|
||||
auto weights = GetWeightsArray();
|
||||
uint32_t totalWeight = 0;
|
||||
std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> cumulativeWeights{};
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
totalWeight += weights[i];
|
||||
cumulativeWeights[i] = totalWeight;
|
||||
}
|
||||
return cumulativeWeights;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Weighted random selector that uses another random selector as backend
|
||||
* @tparam VarT The variable type
|
||||
* @tparam VariableIDMapT The variable ID map type
|
||||
* @tparam BackendSelectorT The backend random selector type
|
||||
* @tparam WeightsMapT The weights map type containing weight specifications
|
||||
*/
|
||||
template <typename VarT, typename VariableIDMapT, typename BackendSelectorT, typename WeightsMapT>
|
||||
class WeightedSelector {
|
||||
private:
|
||||
BackendSelectorT m_backendSelector;
|
||||
const std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> m_weights;
|
||||
const std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> m_cumulativeWeights;
|
||||
|
||||
public:
|
||||
explicit WeightedSelector(BackendSelectorT backendSelector)
|
||||
: m_backendSelector(backendSelector)
|
||||
, m_weights(WeightsMapT::GetWeightsArray())
|
||||
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
|
||||
{}
|
||||
|
||||
explicit WeightedSelector(uint32_t seed)
|
||||
requires std::is_same_v<BackendSelectorT, DefaultRandomSelector<VarT>>
|
||||
: m_backendSelector(seed)
|
||||
, m_weights(WeightsMapT::GetWeightsArray())
|
||||
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
|
||||
{}
|
||||
|
||||
size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value");
|
||||
|
||||
// Use backend selector to pick a random number in range [0, totalWeight)
|
||||
uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back());
|
||||
|
||||
// Find which value this random value corresponds to
|
||||
for (size_t i = 0; i < possibleValues.size(); ++i) {
|
||||
if (randomValue <= m_cumulativeWeights[i]) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback (should not reach here)
|
||||
return possibleValues.size() - 1;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Builder class for creating WFC instances
|
||||
*/
|
||||
template<typename WorldT, typename VarT = typename WorldT::ValueType, typename VariableIDMapT = VariableIDMap<VarT>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>, typename CallbacksT = Callbacks<WorldT>, typename RandomSelectorT = DefaultRandomSelector<VarT>>
|
||||
class Builder {
|
||||
public:
|
||||
|
||||
template <VarT ... Values>
|
||||
using DefineIDs = Builder<WorldT, VarT, typename VariableIDMapT::template Merge<Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
|
||||
|
||||
template <typename ConstrainerFunctionT, VarT ... CorrespondingValues>
|
||||
requires ConstrainerFunction<ConstrainerFunctionT, WorldT, VarT, VariableIDMapT>
|
||||
using DefineConstrainer = Builder<WorldT, VarT, VariableIDMapT,
|
||||
MergedConstrainerFunctionMap<
|
||||
VariableIDMapT,
|
||||
ConstrainerFunctionMapT,
|
||||
ConstrainerFunctionT,
|
||||
VariableIDMap<VarT, CorrespondingValues...>,
|
||||
decltype([](WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&) {})
|
||||
>, CallbacksT, RandomSelectorT
|
||||
>;
|
||||
|
||||
template <typename NewCellCollapsedCallbackT>
|
||||
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT>;
|
||||
template <typename NewContradictionCallbackT>
|
||||
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT>;
|
||||
template <typename NewBranchCallbackT>
|
||||
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT>;
|
||||
|
||||
template <typename NewRandomSelectorT>
|
||||
requires RandomSelectorFunction<NewRandomSelectorT, VarT>
|
||||
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
|
||||
|
||||
template <uint16_t DefaultWeight, typename... WeightSpecs>
|
||||
using Weights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, WeightedSelector<VarT, VariableIDMapT, RandomSelectorT, WeightsMap<VarT, VariableIDMapT, DefaultWeight, WeightSpecs...>>>;
|
||||
|
||||
|
||||
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
|
||||
};
|
||||
|
||||
} // namespace WFC
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
|
||||
#include "wfc_utils.hpp"
|
||||
|
||||
#define WFC_USE_STACK_ALLOCATOR
|
||||
|
||||
inline void* allocate_aligned_memory(size_t alignment, size_t size) {
|
||||
|
||||
172
include/nd-wfc/wfc_bit_container.hpp
Normal file
172
include/nd-wfc/wfc_bit_container.hpp
Normal file
@@ -0,0 +1,172 @@
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
#include <cassert>
|
||||
#include <bit>
|
||||
#include <type_traits>
|
||||
|
||||
#include "wfc_utils.hpp"
|
||||
#include "wfc_allocator.hpp"
|
||||
|
||||
namespace WFC {
|
||||
|
||||
namespace detail {
|
||||
// Helper to determine the optimal storage type based on bits needed
|
||||
template<size_t Bits>
|
||||
struct OptimalStorageType {
|
||||
static constexpr size_t bits_needed = Bits == 0 ? 0 :
|
||||
(Bits <= 1) ? 1 :
|
||||
(Bits <= 2) ? 2 :
|
||||
(Bits <= 4) ? 4 :
|
||||
(Bits <= 8) ? 8 :
|
||||
(Bits <= 16) ? 16 :
|
||||
(Bits <= 32) ? 32 :
|
||||
(Bits <= 64) ? 64 :
|
||||
((Bits + 63) / 64) * 64; // Round up to multiple of 64 for >64 bits
|
||||
|
||||
using type = std::conditional_t<bits_needed <= 8, uint8_t,
|
||||
std::conditional_t<bits_needed <= 16, uint16_t,
|
||||
std::conditional_t<bits_needed <= 32, uint32_t,
|
||||
uint64_t>>>;
|
||||
};
|
||||
|
||||
// Helper for multi-element storage (>64 bits)
|
||||
template<size_t Bits>
|
||||
struct StorageArray {
|
||||
static constexpr size_t StorageBits = OptimalStorageType<Bits>::bits_needed;
|
||||
static constexpr size_t ArraySize = StorageBits > 64 ? (StorageBits / 64) : 1;
|
||||
using element_type = std::conditional_t<StorageBits <= 64, typename OptimalStorageType<Bits>::type, uint64_t>;
|
||||
using type = std::conditional_t<ArraySize == 1, element_type, std::array<element_type, ArraySize>>;
|
||||
};
|
||||
|
||||
struct Empty{};
|
||||
}
|
||||
|
||||
template<size_t Bits, size_t Size = 0, typename AllocatorT = WFCStackAllocatorAdapter<typename detail::StorageArray<Bits>::type>>
|
||||
class BitContainer : private AllocatorT{
|
||||
public:
|
||||
using StorageInfo = detail::OptimalStorageType<Bits>;
|
||||
using StorageArrayInfo = detail::StorageArray<Bits>;
|
||||
using StorageType = typename StorageArrayInfo::type;
|
||||
using AllocatorType = AllocatorT;
|
||||
|
||||
static constexpr size_t BitsPerElement = Bits;
|
||||
static constexpr size_t StorageBits = StorageInfo::bits_needed;
|
||||
static constexpr bool IsResizable = (Size == 0);
|
||||
static constexpr bool IsMultiElement = (StorageBits > 64);
|
||||
static constexpr bool IsSubByte = (StorageBits < 8);
|
||||
static constexpr bool IsDefaultByteLayout = !IsMultiElement && !IsSubByte;
|
||||
static constexpr size_t ElementsPerByte = sizeof(StorageType) * 8 / std::max<size_t>(1u, StorageBits);
|
||||
|
||||
using ContainerType =
|
||||
std::conditional_t<Bits == 0,
|
||||
detail::Empty,
|
||||
std::conditional_t<IsResizable,
|
||||
std::vector<StorageType, AllocatorType>,
|
||||
std::array<StorageType, Size>>>;
|
||||
|
||||
private:
|
||||
ContainerType m_container;
|
||||
|
||||
private:
|
||||
// Mask for extracting bits
|
||||
static constexpr auto get_Mask()
|
||||
{
|
||||
if constexpr (BitsPerElement == 0)
|
||||
{
|
||||
return uint64_t{0};
|
||||
}
|
||||
else if constexpr (BitsPerElement >= 64)
|
||||
{
|
||||
return ~uint64_t{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return (uint64_t{1} << BitsPerElement) - 1;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr uint64_t Mask = get_Mask();
|
||||
|
||||
public:
|
||||
|
||||
BitContainer() = default;
|
||||
BitContainer(AllocatorT& allocator) : AllocatorT(allocator) {};
|
||||
explicit BitContainer(size_t initial_size, AllocatorT& allocator) requires (IsResizable)
|
||||
: AllocatorT(allocator)
|
||||
, m_container(initial_size, allocator)
|
||||
{};
|
||||
explicit BitContainer(size_t, AllocatorT& allocator) requires (!IsResizable)
|
||||
: AllocatorT(allocator)
|
||||
, m_container()
|
||||
{};
|
||||
|
||||
public:
|
||||
// Size operations
|
||||
constexpr size_t size() const noexcept
|
||||
{
|
||||
if constexpr (IsResizable)
|
||||
{
|
||||
return m_container.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
return Size;
|
||||
}
|
||||
}
|
||||
constexpr std::span<const StorageType> data() const { return std::span<const StorageType>(m_container); }
|
||||
constexpr std::span<StorageType> data() { return std::span<StorageType>(m_container); }
|
||||
|
||||
constexpr void resize(size_t new_size) requires (IsResizable) { m_container.resize(new_size); }
|
||||
constexpr void reserve(size_t capacity) requires (IsResizable) { m_container.reserve(capacity); }
|
||||
|
||||
public: // Sub byte
|
||||
struct SubTypeAccess
|
||||
{
|
||||
constexpr SubTypeAccess(uint8_t& data, uint8_t subIndex) : Data{ data }, Shift{ StorageBits * subIndex } {};
|
||||
|
||||
constexpr uint8_t GetValue() const { return ((Data >> Shift) & Mask); }
|
||||
constexpr uint8_t SetValue(uint8_t val) { Clear(); return Data |= ((val & Mask) << Shift); }
|
||||
constexpr void Clear() { Data &= ~Mask; }
|
||||
|
||||
constexpr operator uint8_t() const { return GetValue(); }
|
||||
|
||||
template <typename T> constexpr uint8_t operator&=(T other) { return SetValue(GetValue() & other); }
|
||||
template <typename T> constexpr uint8_t operator|=(T other) { return SetValue(GetValue() | other); }
|
||||
template <typename T> constexpr uint8_t operator^=(T other) { return SetValue(GetValue() ^ other); }
|
||||
template <typename T> constexpr uint8_t operator<<=(T other) { return SetValue(GetValue() << other); }
|
||||
template <typename T> constexpr uint8_t operator>>=(T other) { return SetValue(GetValue() >> other); }
|
||||
|
||||
uint8_t& Data;
|
||||
uint8_t Shift;
|
||||
};
|
||||
|
||||
constexpr const SubTypeAccess operator[](size_t index) const requires(IsSubByte) { return SubTypeAccess{data()[index / ElementsPerByte], index & ElementsPerByte }; }
|
||||
constexpr SubTypeAccess operator[](size_t index) requires(IsSubByte) { return SubTypeAccess{data()[index / ElementsPerByte], index & ElementsPerByte }; }
|
||||
|
||||
public: // MultiElement
|
||||
struct MultiElementAccess
|
||||
{
|
||||
constexpr MultiElementAccess(StorageType& data) : Data{ data } {};
|
||||
|
||||
StorageType& Data;
|
||||
};
|
||||
|
||||
constexpr const MultiElementAccess operator[](size_t index) const requires(IsMultiElement) { return MultiElementAccess{data()[index]}; }
|
||||
constexpr MultiElementAccess operator[](size_t index) requires(IsMultiElement) { return MultiElementAccess{data()[index]}; }
|
||||
|
||||
public: // default
|
||||
constexpr const StorageType& operator[](size_t index) const requires(IsDefaultByteLayout) { return data()[index]; }
|
||||
constexpr StorageType& operator[](size_t index) requires(IsDefaultByteLayout) { return data()[index]; }
|
||||
|
||||
};
|
||||
|
||||
static_assert(BitContainer<1, 10>::ElementsPerByte == 8);
|
||||
static_assert(BitContainer<2, 10>::ElementsPerByte == 4);
|
||||
static_assert(BitContainer<4, 10>::ElementsPerByte == 2);
|
||||
static_assert(BitContainer<8, 10>::ElementsPerByte == 1);
|
||||
|
||||
|
||||
} // namespace WFC
|
||||
52
include/nd-wfc/wfc_builder.hpp
Normal file
52
include/nd-wfc/wfc_builder.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#pragma once
|
||||
|
||||
namespace WFC {
|
||||
|
||||
#include "wfc_utils.hpp"
|
||||
#include "wfc_variable_map.hpp"
|
||||
#include "wfc_constrainer.hpp"
|
||||
#include "wfc_callbacks.hpp"
|
||||
#include "wfc_random.hpp"
|
||||
#include "wfc.hpp"
|
||||
|
||||
/**
|
||||
* @brief Builder class for creating WFC instances
|
||||
*/
|
||||
template<typename WorldT, typename VarT = typename WorldT::ValueType, typename VariableIDMapT = VariableIDMap<VarT>, typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>, typename CallbacksT = Callbacks<WorldT>, typename RandomSelectorT = DefaultRandomSelector<VarT>>
|
||||
class Builder {
|
||||
public:
|
||||
|
||||
template <VarT ... Values>
|
||||
using DefineIDs = Builder<WorldT, VarT, typename VariableIDMapT::template Merge<Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
|
||||
|
||||
template <typename ConstrainerFunctionT, VarT ... CorrespondingValues>
|
||||
requires ConstrainerFunction<ConstrainerFunctionT, WorldT, VarT, VariableIDMapT>
|
||||
using DefineConstrainer = Builder<WorldT, VarT, VariableIDMapT,
|
||||
MergedConstrainerFunctionMap<
|
||||
VariableIDMapT,
|
||||
ConstrainerFunctionMapT,
|
||||
ConstrainerFunctionT,
|
||||
VariableIDMap<VarT, CorrespondingValues...>,
|
||||
decltype([](WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&) {})
|
||||
>, CallbacksT, RandomSelectorT
|
||||
>;
|
||||
|
||||
template <typename NewCellCollapsedCallbackT>
|
||||
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT>;
|
||||
template <typename NewContradictionCallbackT>
|
||||
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT>;
|
||||
template <typename NewBranchCallbackT>
|
||||
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT>;
|
||||
|
||||
template <typename NewRandomSelectorT>
|
||||
requires RandomSelectorFunction<NewRandomSelectorT, VarT>
|
||||
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
|
||||
|
||||
template <uint16_t DefaultWeight, typename... WeightSpecs>
|
||||
using Weights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, WeightedSelector<VarT, VariableIDMapT, RandomSelectorT, WeightsMap<VarT, VariableIDMapT, DefaultWeight, WeightSpecs...>>>;
|
||||
|
||||
|
||||
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
|
||||
};
|
||||
|
||||
}
|
||||
44
include/nd-wfc/wfc_callbacks.hpp
Normal file
44
include/nd-wfc/wfc_callbacks.hpp
Normal file
@@ -0,0 +1,44 @@
|
||||
#pragma once
|
||||
|
||||
namespace WFC {
|
||||
|
||||
/**
|
||||
* @brief Empty callback function
|
||||
* @param WorldT The world type
|
||||
*/
|
||||
template <typename WorldT>
|
||||
using EmptyCallback = decltype([](WorldT&){});
|
||||
|
||||
/**
|
||||
* @brief Callback struct
|
||||
* @param WorldT The world type
|
||||
* @param AllCellsCollapsedCallbackT The all cells collapsed callback type
|
||||
* @param CellCollapsedCallbackT The cell collapsed callback type
|
||||
* @param ContradictionCallbackT The contradiction callback type
|
||||
* @param BranchCallbackT The branch callback type
|
||||
*/
|
||||
template <typename WorldT,
|
||||
typename CellCollapsedCallbackT = EmptyCallback<WorldT>,
|
||||
typename ContradictionCallbackT = EmptyCallback<WorldT>,
|
||||
typename BranchCallbackT = EmptyCallback<WorldT>
|
||||
>
|
||||
struct Callbacks
|
||||
{
|
||||
using CellCollapsedCallback = CellCollapsedCallbackT;
|
||||
using ContradictionCallback = ContradictionCallbackT;
|
||||
using BranchCallback = BranchCallbackT;
|
||||
|
||||
template <typename NewCellCollapsedCallbackT>
|
||||
using SetCellCollapsedCallbackT = Callbacks<WorldT, NewCellCollapsedCallbackT, ContradictionCallbackT, BranchCallbackT>;
|
||||
template <typename NewContradictionCallbackT>
|
||||
using SetContradictionCallbackT = Callbacks<WorldT, CellCollapsedCallbackT, NewContradictionCallbackT, BranchCallbackT>;
|
||||
template <typename NewBranchCallbackT>
|
||||
using SetBranchCallbackT = Callbacks<WorldT, CellCollapsedCallbackT, ContradictionCallbackT, NewBranchCallbackT>;
|
||||
|
||||
static consteval bool HasCellCollapsedCallback() { return !std::is_same_v<CellCollapsedCallbackT, EmptyCallback<WorldT>>; }
|
||||
static consteval bool HasContradictionCallback() { return !std::is_same_v<ContradictionCallbackT, EmptyCallback<WorldT>>; }
|
||||
static consteval bool HasBranchCallback() { return !std::is_same_v<BranchCallbackT, EmptyCallback<WorldT>>; }
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
122
include/nd-wfc/wfc_constrainer.hpp
Normal file
122
include/nd-wfc/wfc_constrainer.hpp
Normal file
@@ -0,0 +1,122 @@
|
||||
#pragma once
|
||||
|
||||
#include "wfc_variable_map.hpp"
|
||||
|
||||
namespace WFC {
|
||||
|
||||
template <typename ... ConstrainerFunctions>
|
||||
struct ConstrainerFunctionMap {
|
||||
public:
|
||||
static consteval size_t size() { return sizeof...(ConstrainerFunctions); }
|
||||
|
||||
using TupleType = std::tuple<ConstrainerFunctions...>;
|
||||
|
||||
template <typename ConstrainerFunctionPtrT>
|
||||
static ConstrainerFunctionPtrT GetFunction(size_t index)
|
||||
{
|
||||
static_assert((std::is_empty_v<ConstrainerFunctions> && ...), "Lambdas must not have any captures");
|
||||
static ConstrainerFunctionPtrT functions[] = {
|
||||
static_cast<ConstrainerFunctionPtrT>(ConstrainerFunctions{}) ...
|
||||
};
|
||||
return functions[index];
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to select the correct constrainer function based on the index and the value
|
||||
template<std::size_t I,
|
||||
typename VariableIDMapT,
|
||||
typename ConstrainerFunctionMapT,
|
||||
typename NewConstrainerFunctionT,
|
||||
typename SelectedIDsVariableIDMapT,
|
||||
typename EmptyFunctionT>
|
||||
using MergedConstrainerElementSelector =
|
||||
std::conditional_t<SelectedIDsVariableIDMapT::template HasValue<VariableIDMapT::GetValueConsteval(I)>(), // if the value is in the selected IDs
|
||||
NewConstrainerFunctionT,
|
||||
std::conditional_t<(I < ConstrainerFunctionMapT::size()), // if the index is within the size of the tuple
|
||||
std::tuple_element_t<std::min(I, ConstrainerFunctionMapT::size() - 1), typename ConstrainerFunctionMapT::TupleType>,
|
||||
EmptyFunctionT
|
||||
>
|
||||
>;
|
||||
|
||||
// Helper to make a merged constrainer function map
|
||||
template<typename VariableIDMapT,
|
||||
typename ConstrainerFunctionMapT,
|
||||
typename NewConstrainerFunctionT,
|
||||
typename SelectedIDsVariableIDMapT,
|
||||
typename EmptyFunctionT,
|
||||
std::size_t... Is>
|
||||
auto MakeMergedConstrainerIDMap(std::index_sequence<Is...>,VariableIDMapT*, ConstrainerFunctionMapT*, NewConstrainerFunctionT*, SelectedIDsVariableIDMapT*, EmptyFunctionT*)
|
||||
-> ConstrainerFunctionMap<MergedConstrainerElementSelector<Is, VariableIDMapT, ConstrainerFunctionMapT, NewConstrainerFunctionT, SelectedIDsVariableIDMapT, EmptyFunctionT>...>;
|
||||
|
||||
// Main alias for the merged constrainer function map
|
||||
template<typename VariableIDMapT,
|
||||
typename ConstrainerFunctionMapT,
|
||||
typename NewConstrainerFunctionT,
|
||||
typename SelectedIDsVariableIDMapT,
|
||||
typename EmptyFunctionT>
|
||||
using MergedConstrainerFunctionMap = decltype(
|
||||
MakeMergedConstrainerIDMap(std::make_index_sequence<VariableIDMapT::ValuesRegisteredAmount>{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr)
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Constrainer class used in constraint functions to limit possible values for other cells
|
||||
*/
|
||||
template <typename VariableIDMapT>
|
||||
class Constrainer {
|
||||
public:
|
||||
using MaskType = typename VariableIDMapT::MaskType;
|
||||
|
||||
public:
|
||||
Constrainer(Wave<MaskType>& wave, WFCQueue<size_t>& propagationQueue)
|
||||
: m_wave(wave)
|
||||
, m_propagationQueue(propagationQueue)
|
||||
{}
|
||||
|
||||
/**
|
||||
* @brief Constrain a cell to exclude specific values
|
||||
* @param cellId The ID of the cell to constrain
|
||||
* @param forbiddenValues The set of forbidden values for this cell
|
||||
*/
|
||||
template <typename VariableIDMapT::Type ... ExcludedValues>
|
||||
void Exclude(size_t cellId) {
|
||||
static_assert(sizeof...(ExcludedValues) > 0, "At least one excluded value must be provided");
|
||||
ApplyMask(cellId, ~VariableIDMapT::template GetMask<ExcludedValues...>());
|
||||
}
|
||||
|
||||
void Exclude(WorldValue<typename VariableIDMapT::Type> value, size_t cellId) {
|
||||
ApplyMask(cellId, ~(1 << value.InternalIndex));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Constrain a cell to only allow one specific value
|
||||
* @param cellId The ID of the cell to constrain
|
||||
* @param value The only allowed value for this cell
|
||||
*/
|
||||
template <typename VariableIDMapT::Type ... AllowedValues>
|
||||
void Only(size_t cellId) {
|
||||
static_assert(sizeof...(AllowedValues) > 0, "At least one allowed value must be provided");
|
||||
ApplyMask(cellId, VariableIDMapT::template GetMask<AllowedValues...>());
|
||||
}
|
||||
|
||||
void Only(WorldValue<typename VariableIDMapT::Type> value, size_t cellId) {
|
||||
ApplyMask(cellId, 1 << value.InternalIndex);
|
||||
}
|
||||
|
||||
private:
|
||||
void ApplyMask(size_t cellId, MaskType mask) {
|
||||
bool wasCollapsed = m_wave.IsCollapsed(cellId);
|
||||
|
||||
m_wave.Collapse(cellId, mask);
|
||||
|
||||
bool collapsed = m_wave.IsCollapsed(cellId);
|
||||
if (!wasCollapsed && collapsed) {
|
||||
m_propagationQueue.push(cellId);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Wave<MaskType>& m_wave;
|
||||
WFCQueue<size_t>& m_propagationQueue;
|
||||
};
|
||||
|
||||
}
|
||||
22
include/nd-wfc/wfc_large_integers.hpp
Normal file
22
include/nd-wfc/wfc_large_integers.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
namespace WFC {
|
||||
|
||||
template <size_t Size>
|
||||
struct LargeInteger
|
||||
{
|
||||
std::array<uint64_t, Size> m_data;
|
||||
|
||||
template <size_t OtherSize>
|
||||
constexpr LargeInteger<std::max(Size, OtherSize)> operator+(const LargeInteger<OtherSize>& other) const {
|
||||
LargeInteger<std::max(Size, OtherSize)> result;
|
||||
for (size_t i = 0; i < std::max(Size, OtherSize); i++) {
|
||||
result[i] = m_data[i] + other[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
178
include/nd-wfc/wfc_random.hpp
Normal file
178
include/nd-wfc/wfc_random.hpp
Normal file
@@ -0,0 +1,178 @@
|
||||
#pragma once
|
||||
|
||||
namespace WFC {
|
||||
|
||||
/**
|
||||
* @brief Default constexpr random selector using a simple seed-based algorithm
|
||||
* This provides a compile-time random selection that maintains state between calls
|
||||
*/
|
||||
template <typename VarT>
|
||||
class DefaultRandomSelector {
|
||||
private:
|
||||
mutable uint32_t m_seed;
|
||||
|
||||
public:
|
||||
constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {}
|
||||
|
||||
constexpr size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
|
||||
// Simple linear congruential generator for constexpr compatibility
|
||||
return static_cast<size_t>(rng(possibleValues.size()));
|
||||
}
|
||||
|
||||
constexpr uint32_t rng(uint32_t max) {
|
||||
m_seed = m_seed * 1103515245 + 12345;
|
||||
return m_seed % max;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Advanced random selector using std::mt19937 and std::uniform_int_distribution
|
||||
* This provides high-quality randomization for runtime use
|
||||
*/
|
||||
template <typename VarT>
|
||||
class AdvancedRandomSelector {
|
||||
private:
|
||||
std::mt19937& m_rng;
|
||||
|
||||
public:
|
||||
explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {}
|
||||
|
||||
size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
|
||||
return rng(possibleValues.size());
|
||||
}
|
||||
|
||||
uint32_t rng(uint32_t max) {
|
||||
std::uniform_int_distribution<uint32_t> dist(0, max);
|
||||
return dist(m_rng);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Weight specification for a specific value
|
||||
* @tparam Value The variable value
|
||||
* @tparam Weight The 16-bit weight for this value
|
||||
*/
|
||||
template <typename VarT, VarT Value, uint16_t WeightValue>
|
||||
struct Weight {
|
||||
static constexpr VarT value = Value;
|
||||
static constexpr uint16_t weight = WeightValue;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compile-time weights storage for weighted random selection
|
||||
* @tparam VarT The variable type
|
||||
* @tparam VariableIDMapT The variable ID map type
|
||||
* @tparam DefaultWeight The default weight for values not explicitly specified
|
||||
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
|
||||
*/
|
||||
template <typename VarT, typename VariableIDMapT, uint16_t DefaultWeight, typename... WeightSpecs>
|
||||
class WeightsMap {
|
||||
private:
|
||||
static constexpr size_t NumWeights = sizeof...(WeightSpecs);
|
||||
|
||||
// Helper to get weight for a specific value
|
||||
static consteval uint16_t GetWeightForValue(VarT targetValue) {
|
||||
// Check each weight spec to find the target value
|
||||
uint16_t weight = DefaultWeight;
|
||||
((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...);
|
||||
return weight;
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Get the weight for a specific value at compile time
|
||||
* @tparam TargetValue The value to get weight for
|
||||
* @return The weight for the value
|
||||
*/
|
||||
template <VarT TargetValue>
|
||||
static consteval uint16_t GetWeight() {
|
||||
return GetWeightForValue(TargetValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get weights array for all registered values
|
||||
* @return Array of weights corresponding to all registered values
|
||||
*/
|
||||
static consteval std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> GetWeightsArray() {
|
||||
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> weights{};
|
||||
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
weights[i] = GetWeightForValue(VariableIDMapT::GetValueConsteval(i));
|
||||
}
|
||||
|
||||
return weights;
|
||||
}
|
||||
|
||||
static consteval uint32_t GetTotalWeight() {
|
||||
uint32_t totalWeight = 0;
|
||||
auto weights = GetWeightsArray();
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
totalWeight += weights[i];
|
||||
}
|
||||
return totalWeight;
|
||||
}
|
||||
|
||||
static consteval std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> GetCumulativeWeightsArray() {
|
||||
auto weights = GetWeightsArray();
|
||||
uint32_t totalWeight = 0;
|
||||
std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> cumulativeWeights{};
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
totalWeight += weights[i];
|
||||
cumulativeWeights[i] = totalWeight;
|
||||
}
|
||||
return cumulativeWeights;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Weighted random selector that uses another random selector as backend
|
||||
* @tparam VarT The variable type
|
||||
* @tparam VariableIDMapT The variable ID map type
|
||||
* @tparam BackendSelectorT The backend random selector type
|
||||
* @tparam WeightsMapT The weights map type containing weight specifications
|
||||
*/
|
||||
template <typename VarT, typename VariableIDMapT, typename BackendSelectorT, typename WeightsMapT>
|
||||
class WeightedSelector {
|
||||
private:
|
||||
BackendSelectorT m_backendSelector;
|
||||
const std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> m_weights;
|
||||
const std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> m_cumulativeWeights;
|
||||
|
||||
public:
|
||||
explicit WeightedSelector(BackendSelectorT backendSelector)
|
||||
: m_backendSelector(backendSelector)
|
||||
, m_weights(WeightsMapT::GetWeightsArray())
|
||||
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
|
||||
{}
|
||||
|
||||
explicit WeightedSelector(uint32_t seed)
|
||||
requires std::is_same_v<BackendSelectorT, DefaultRandomSelector<VarT>>
|
||||
: m_backendSelector(seed)
|
||||
, m_weights(WeightsMapT::GetWeightsArray())
|
||||
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
|
||||
{}
|
||||
|
||||
size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value");
|
||||
|
||||
// Use backend selector to pick a random number in range [0, totalWeight)
|
||||
uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back());
|
||||
|
||||
// Find which value this random value corresponds to
|
||||
for (size_t i = 0; i < possibleValues.size(); ++i) {
|
||||
if (randomValue <= m_cumulativeWeights[i]) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback (should not reach here)
|
||||
return possibleValues.size() - 1;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
27
include/nd-wfc/wfc_utils.hpp
Normal file
27
include/nd-wfc/wfc_utils.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
namespace WFC
|
||||
{
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
inline constexpr void constexpr_assert(bool condition, const char* message = "") {
|
||||
if (!condition) throw message;
|
||||
}
|
||||
|
||||
inline int FindNthSetBit(size_t num, int n) {
|
||||
constexpr_assert(n < std::popcount(num), "index is out of range");
|
||||
int bitCount = 0;
|
||||
while (num) {
|
||||
if (bitCount == n) {
|
||||
return std::countr_zero(num); // Index of the current set bit
|
||||
}
|
||||
bitCount++;
|
||||
num &= (num - 1); // turn of lowest set bit
|
||||
}
|
||||
return bitCount;
|
||||
}
|
||||
|
||||
}
|
||||
73
include/nd-wfc/wfc_variable_map.hpp
Normal file
73
include/nd-wfc/wfc_variable_map.hpp
Normal file
@@ -0,0 +1,73 @@
|
||||
#pragma once
|
||||
|
||||
#include "wfc_utils.hpp"
|
||||
|
||||
namespace WFC {
|
||||
|
||||
/**
|
||||
* @brief Class to map variable values to indices at compile time
|
||||
*
|
||||
* This class is used to map variable values to indices at compile time.
|
||||
* It is a compile-time map of variable values to indices.
|
||||
*/
|
||||
template <typename VarT, VarT ... Values>
|
||||
class VariableIDMap {
|
||||
public:
|
||||
|
||||
using Type = VarT;
|
||||
static constexpr size_t ValuesRegisteredAmount = sizeof...(Values);
|
||||
|
||||
template <VarT ... AdditionalValues>
|
||||
using Merge = VariableIDMap<VarT, Values..., AdditionalValues...>;
|
||||
|
||||
template <VarT Value>
|
||||
static consteval bool HasValue()
|
||||
{
|
||||
constexpr VarT arr[] = {Values...};
|
||||
constexpr size_t size = sizeof...(Values);
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if (arr[i] == Value)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
template <VarT Value>
|
||||
static consteval size_t GetIndex()
|
||||
{
|
||||
static_assert(HasValue<Value>(), "Value was not defined");
|
||||
constexpr VarT arr[] = {Values...};
|
||||
constexpr size_t size = ValuesRegisteredAmount;
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if (arr[i] == Value)
|
||||
return i;
|
||||
|
||||
return static_cast<size_t>(-1); // This line is unreachable if value is found
|
||||
}
|
||||
|
||||
static std::span<const VarT> GetAllValues()
|
||||
{
|
||||
static const VarT allValues[]
|
||||
{
|
||||
Values...
|
||||
};
|
||||
return std::span<const VarT>{ allValues, ValuesRegisteredAmount };
|
||||
}
|
||||
|
||||
static VarT GetValue(size_t index) {
|
||||
constexpr_assert(index < ValuesRegisteredAmount);
|
||||
return GetAllValues()[index];
|
||||
}
|
||||
|
||||
static consteval VarT GetValueConsteval(size_t index)
|
||||
{
|
||||
constexpr VarT arr[] = {Values...};
|
||||
return arr[index];
|
||||
}
|
||||
|
||||
static consteval size_t size() { return ValuesRegisteredAmount; }
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
39
include/nd-wfc/wfc_wave.hpp
Normal file
39
include/nd-wfc/wfc_wave.hpp
Normal file
@@ -0,0 +1,39 @@
|
||||
#pragma once
|
||||
|
||||
#include "wfc_bit_container.hpp"
|
||||
#include "wfc_variable_map.hpp"
|
||||
#include "wfc_allocator.hpp"
|
||||
|
||||
namespace WFC {
|
||||
|
||||
template <typename VariableIDMapT, size_t Size = 0>
|
||||
class Wave {
|
||||
public:
|
||||
using BitContainerT = BitContainer<VariableIDMapT::ValuesRegisteredAmount, Size>;
|
||||
using ElementT = typename BitContainerT::StorageType;
|
||||
|
||||
public:
|
||||
Wave() = default;
|
||||
Wave(size_t size, size_t variableAmount, WFCStackAllocator& allocator) : m_data(size, WFCStackAllocatorAdapter<ElementT>(allocator))
|
||||
{
|
||||
for (auto& wave : m_data) wave = (1 << variableAmount) - 1;
|
||||
}
|
||||
|
||||
Wave(const Wave& other) = default;
|
||||
|
||||
public:
|
||||
void Collapse(size_t index, ElementT mask) { m_data[index] &= mask; }
|
||||
size_t size() const { return m_data.size(); }
|
||||
size_t Entropy(size_t index) const { return std::popcount(m_data[index]); }
|
||||
bool IsCollapsed(size_t index) const { return Entropy(index) == 1; }
|
||||
bool IsFullyCollapsed() const { return std::all_of(m_data.begin(), m_data.end(), [](ElementT value) { return std::popcount(value) == 1; }); }
|
||||
bool HasContradiction() const { return std::any_of(m_data.begin(), m_data.end(), [](ElementT value) { return value == 0; }); }
|
||||
bool IsContradicted(size_t index) const { return m_data[index] == 0; }
|
||||
uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); }
|
||||
ElementT GetMask(size_t index) const { return m_data[index]; }
|
||||
|
||||
private:
|
||||
BitContainerT m_data;
|
||||
};
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user