migration 24/12/2025

This commit is contained in:
cdemeyer-teachx
2025-12-24 16:58:29 +09:00
parent 51f4936aff
commit 860bb2d35a
12 changed files with 449 additions and 248 deletions

View File

@@ -8,6 +8,8 @@
#include <memory>
#include <utility>
#include <nd-wfc/wfc.h>
// Forward declarations
struct NonogramHints;
struct NonogramSolution;
@@ -178,3 +180,7 @@ std::string trim(const std::string& str);
bool startsWith(const std::string& str, const std::string& prefix);
std::vector<std::string> split(const std::string& str, char delimiter);
std::vector<uint8_t> parseNumberSequence(const std::string& str, char delimiter);
using NonogramWFC = WFC::Builder<Nonogram, bool>
::Define<false, true>
::Build;

View File

@@ -279,15 +279,15 @@ private:
public: // WFC Support
using ValueType = uint8_t;
constexpr inline ValueType getValue(size_t index) const {
constexpr inline ValueType getValue(uint8_t index) const {
return board_.get(static_cast<int>(index));
}
constexpr inline void setValue(size_t index, ValueType value) {
constexpr inline void setValue(uint8_t index, ValueType value) {
board_.set(static_cast<int>(index), value);
}
constexpr inline size_t size() const {
constexpr inline uint8_t size() const {
return 81;
}
};

View File

@@ -8,6 +8,7 @@
#include <type_traits>
#include <cassert>
#include <algorithm>
#include <ranges>
#include <concepts>
#include <bit>
#include <span>
@@ -22,60 +23,50 @@
#include "wfc_callbacks.hpp"
#include "wfc_random.hpp"
#include "wfc_queue.hpp"
#include "wfc_weights.hpp"
namespace WFC {
template<typename T>
concept WorldType = requires(T world, size_t id, typename T::ValueType value) {
{ world.size() } -> std::convertible_to<size_t>;
{ world.setValue(id, value) };
{ world.getValue(id) } -> std::convertible_to<typename T::ValueType>;
concept WorldType = requires(T world, typename T::ValueType value) {
{ world.size() } -> std::is_integral;
{ world.setValue(static_cast<decltype(world.size())>(0), value) };
{ world.getValue(static_cast<decltype(world.size())>(0)) } -> std::convertible_to<typename T::ValueType>;
typename T::ValueType;
};
/**
* @brief Concept to validate constrainer function signature
* The function must be callable with parameters: (WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)
*/
template <typename T, typename WorldT, typename VarT, typename VariableIDMapT, typename PropagationQueueType>
concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue<VarT> value, Constrainer<VariableIDMapT, PropagationQueueType>& constrainer) {
func(world, index, value, constrainer);
};
/**
* @brief Concept to validate random selector function signature
* The function must be callable with parameters: (std::span<const VarT>) and return size_t
*/
template <typename T, typename VarT>
concept RandomSelectorFunction = requires(const T& func, std::span<const VarT> possibleValues) {
{ func(possibleValues) } -> std::convertible_to<size_t>;
{ func.rng(static_cast<uint32_t>(1)) } -> std::convertible_to<uint32_t>;
};
template <typename WorldT>
concept HasConstexprSize = requires {
{ []() constexpr -> std::size_t { return WorldT{}.size(); }() };
};
/**
* @brief Main WFC class implementing the Wave Function Collapse algorithm
*/
template<typename WorldT, typename VarT,
typename VariableIDMapT = VariableIDMap<VarT>,
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
typename CallbacksT = Callbacks<WorldT>,
typename RandomSelectorT = DefaultRandomSelector<VarT>>
typename RandomSelectorT = DefaultRandomSelector<VarT>,
typename WeightsMapT = WeightsMap<VariableIDMapT>
>
class WFC {
public:
static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
// Try getting the world size, which is only available if the world type has a constexpr size() method
constexpr static size_t WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using WorldSizeT = decltype(WorldT{}.size());
using WaveType = Wave<VariableIDMapT, WorldSize>;
using PropagationQueueType = WFCQueue<WorldSize>;
// Try getting the world size, which is only available if the world type has a constexpr size() method
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using WaveType = Wave<VariableIDMapT, WeightsMapT, WorldSize>;
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
using MaskType = typename WaveType::ElementT;
using VariableIDT = typename WaveType::VariableIDT;
using WeightsBufferType = BitContainer<WeightsMapT::GetMaxPrecision(), VariableIDMapT::size()>;
public:
struct SolverState
@@ -88,7 +79,7 @@ public:
SolverState(WorldT& world, uint32_t seed)
: m_world(world)
, m_propagationQueue{ WorldSize ? WorldSize : static_cast<size_t>(world.size()) }
, m_propagationQueue{ WorldSize ? WorldSize : static_cast<WorldSizeT>(world.size()) }
, m_randomSelector(seed)
{}
@@ -114,7 +105,7 @@ public:
*/
static bool Run(SolverState& state)
{
WaveType wave{ WorldSize, VariableIDMapT::ValuesRegisteredAmount, state.m_allocator };
WaveType wave{ WorldSize, VariableIDMapT::size(), state.m_allocator };
PropogateInitialValues(state, wave);
@@ -188,7 +179,7 @@ public:
}
private:
static void CollapseCell(SolverState& state, WaveType& wave, size_t cellId, uint16_t value)
static void CollapseCell(SolverState& state, WaveType& wave, WorldSizeT cellId, VariableIDT value)
{
constexpr_assert(!wave.IsCollapsed(cellId) || wave.GetMask(cellId) == (MaskType(1) << value));
wave.Collapse(cellId, 1 << value);
@@ -201,51 +192,111 @@ private:
}
}
static bool Branch(SolverState& state, WaveType& wave)
static WorldSizeT FindMinimumEntropyCells(std::span<VariableIDT>& buffer, WaveType& wave)
{
constexpr_assert(state.m_propagationQueue.empty());
// Find cell with minimum entropy > 1
size_t minEntropyCell = static_cast<size_t>(-1);
size_t minEntropy = static_cast<size_t>(-1);
for (size_t i = 0; i < wave.size(); ++i) {
size_t entropy = wave.Entropy(i);
if (entropy > 1 && entropy < minEntropy) {
minEntropy = entropy;
minEntropyCell = i;
}
}
if (minEntropyCell == static_cast<size_t>(-1)) return false;
auto entropyGetter = [&wave](size_t index) -> size_t { return wave.Entropy(index); };
auto entropyFilter = [&wave](size_t entropy) -> bool { return entropy > 1; };
auto minEntropyCell = *std::ranges::min_element(std::views::iota(0, wave.size()) | std::views::transform(entropyGetter) | std::views::filter(entropyFilter));
constexpr_assert(!wave.IsCollapsed(minEntropyCell));
// create a list of possible values
uint16_t availableValues = static_cast<uint16_t>(wave.Entropy(minEntropyCell));
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> possibleValues; // inplace vector
VariableIDT availableValues = wave.Entropy(minEntropyCell);
MaskType mask = wave.GetMask(minEntropyCell);
for (size_t i = 0; i < availableValues; ++i)
{
uint16_t index = static_cast<uint16_t>(std::countr_zero(mask)); // get the index of the lowest set bit
constexpr_assert(index < VariableIDMapT::ValuesRegisteredAmount, "Possible value went outside bounds");
VariableIDT index = static_cast<VariableIDT>(std::countr_zero(mask)); // get the index of the lowest set bit
constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds");
possibleValues[i] = index;
buffer[i] = index;
constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set");
mask = mask & (mask - 1); // turn off lowest set bit
}
// randomly select a value from possible values
while (availableValues)
return minEntropyCell;
}
using WeightsType = MinimumIntegerType<WeightsMapT::GetMaxValue()>;
using RandomGeneratorReturnType = decltype(RandomSelectorT{}.rng(static_cast<uint32_t>(1)));
static WorldSizeT FindMinimumEntropyCellsWeighted(std::span<VariableIDT>& buffer, std::span<WeightsType>& weights, WaveType& wave)
{
constexpr size_t ElementsMaxWeight = std::min(WeightsMapT::GetMaxValue() * VariableIDMapT::size(), std::numeric_limits<RandomGeneratorReturnType>::max()) / VariableIDMapT::size();
auto accumulatedWeightedEntropyGetter = [&wave](size_t index) -> uint64_t
{
auto entropyFilter = [&wave](size_t index) -> bool { return wave.Entropy(index) > 1; };
auto weightedEntropyGetter = [&wave](size_t index) -> uint64_t { return wave.template GetWeight<ElementsMaxWeight>(index); };
auto view = std::views::iota(0, VariableIDMapT::size()) | std::views::filter(entropyFilter) | std::views::transform(weightedEntropyGetter);
return std::accumulate(view.begin(), view.end(), 0);
};
auto minEntropyCell = *std::ranges::min_element(std::views::iota(0, wave.size()) | std::views::transform(accumulatedWeightedEntropyGetter));
VariableIDT availableValues = wave.Entropy(minEntropyCell);
MaskType mask = wave.GetMask(minEntropyCell);
for (size_t i = 0; i < availableValues; ++i)
{
// Create a span of the actual variable values for the random selector
std::array<VarT, VariableIDMapT::ValuesRegisteredAmount> valueArray;
for (size_t i = 0; i < availableValues; ++i) {
valueArray[i] = VariableIDMapT::GetValue(possibleValues[i]);
VariableIDT index = static_cast<VariableIDT>(std::countr_zero(mask)); // get the index of the lowest set bit
constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds");
buffer[i] = index;
constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set");
weights[i] = wave.template GetWeight<ElementsMaxWeight>(index);
mask = mask & (mask - 1); // turn off lowest set bit
}
return minEntropyCell;
}
static bool Branch(SolverState& state, WaveType& wave)
{
constexpr_assert(state.m_propagationQueue.empty());
std::array<VariableIDT, VariableIDMapT::size()> Buffer{};
std::array<WeightsType, VariableIDMapT::size()> WeightsBuffer{};
uint64_t accumulatedWeights = 0;
WorldSizeT minEntropyCell{};
if constexpr (WeightsMapT::HasWeights())
{
minEntropyCell = FindMinimumEntropyCellsWeighted(Buffer, WeightsBuffer, wave);
accumulatedWeights = std::accumulate(WeightsBuffer.begin(), WeightsBuffer.end(), 0);
}
else
{
minEntropyCell = FindMinimumEntropyCells(Buffer, wave);
}
// randomly select a value from possible values
while (Buffer.size())
{
size_t randomIndex;
VariableIDT selectedValue;
if constexpr (WeightsMapT::HasWeights())
{
auto randomWeight = state.m_randomSelector.rng(accumulatedWeights);
for (size_t i = 0; i < WeightsBuffer.size(); ++i)
{
if (randomWeight < WeightsBuffer[i])
{
randomIndex = i;
selectedValue = Buffer[i];
break;
}
randomWeight -= WeightsBuffer[i];
}
accumulatedWeights -= WeightsBuffer[randomIndex];
}
else
{
randomIndex = state.m_randomSelector.rng(Buffer.size());
selectedValue = Buffer[randomIndex];
}
std::span<const VarT> currentPossibleValues(valueArray.data(), availableValues);
size_t randomIndex = state.m_randomSelector(currentPossibleValues);
size_t selectedValue = possibleValues[randomIndex];
{
// copy the state and branch out
@@ -253,12 +304,12 @@ private:
auto queueFrame = state.m_propagationQueue.createBranchPoint();
auto newWave = wave;
CollapseCell(state, newWave, minEntropyCell, static_cast<uint16_t>(selectedValue));
CollapseCell(state, newWave, minEntropyCell, selectedValue);
state.m_propagationQueue.push(minEntropyCell);
if (RunLoop(state, newWave))
{
// copy the solution to the original state
// move the solution to the original state
wave = newWave;
return true;
@@ -271,7 +322,8 @@ private:
constexpr_assert((wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) == 0, "Wave was not collapsed correctly");
// swap replacement value with the last value
std::swap(possibleValues[randomIndex], possibleValues[--availableValues]);
std::swap(Buffer[randomIndex], Buffer[Buffer.size() - 1]);
Buffer = Buffer.subspan(0, Buffer.size() - 1);
}
return false;
@@ -281,16 +333,16 @@ private:
{
while (!state.m_propagationQueue.empty())
{
size_t cellId = state.m_propagationQueue.pop();
WorldSizeT cellId = state.m_propagationQueue.pop();
if (wave.IsContradicted(cellId)) return false;
constexpr_assert(wave.IsCollapsed(cellId), "Cell was not collapsed");
uint16_t variableID = wave.GetVariableID(cellId);
VariableIDT variableID = wave.GetVariableID(cellId);
ConstrainerType constrainer(wave, state.m_propagationQueue);
using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue<VarT>, ConstrainerType&);
using ConstrainerFunctionPtrT = void(*)(WorldT&, WorldSizeT, WorldValue<VarT>, ConstrainerType&);
ConstrainerFunctionMapT::template GetFunction<ConstrainerFunctionPtrT>(variableID)(state.m_world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
}
@@ -314,7 +366,7 @@ private:
{
if (state.m_world.getValue(i) == VariableIDMapT::GetValue(j))
{
CollapseCell(state, wave, static_cast<uint16_t>(i), static_cast<uint16_t>(j));
CollapseCell(state, wave, static_cast<WorldSizeT>(i), static_cast<VariableIDT>(j));
state.m_propagationQueue.push(i);
break;
}

View File

@@ -60,6 +60,7 @@ public:
static constexpr bool IsMultiElement = (StorageBits > 64);
static constexpr bool IsSubByte = (StorageBits < 8);
static constexpr size_t ElementsPerByte = sizeof(StorageType) * 8 / std::max<size_t>(1u, StorageBits);
static constexpr size_t MaxValue = (StorageType{1} << BitsPerElement) - 1;
using ContainerType =
std::conditional_t<Bits == 0,
@@ -110,9 +111,9 @@ public:
BitContainer() = default;
BitContainer(const AllocatorT& allocator) : AllocatorT(allocator) {};
explicit BitContainer(size_t initial_size, const AllocatorT& allocator) requires (IsResizable)
explicit BitContainer(size_t size, const AllocatorT& allocator) requires (IsResizable)
: AllocatorT(allocator)
, m_container(initial_size, allocator)
, m_container(size, allocator)
{};
explicit BitContainer(size_t, const AllocatorT& allocator) requires (!IsResizable)
: AllocatorT(allocator)

View File

@@ -7,6 +7,7 @@ namespace WFC {
#include "wfc_constrainer.hpp"
#include "wfc_callbacks.hpp"
#include "wfc_random.hpp"
#include "wfc_weights.hpp"
#include "wfc.hpp"
/**
@@ -19,33 +20,37 @@ template<
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
typename CallbacksT = Callbacks<WorldT>,
typename RandomSelectorT = DefaultRandomSelector<VarT>,
typename WeightsMapT = WeightsMap<VariableIDMapT>,
typename SelectedValueT = void>
class Builder {
public:
constexpr static size_t WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using WorldSizeT = decltype(WorldT{}.size());
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
using WaveType = Wave<VariableIDMapT, WorldSize>;
using PropagationQueueType = WFCQueue<WorldSize>;
using WaveType = Wave<VariableIDMapT, WeightsMapT, WorldSize>;
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
template <VarT ... Values>
using DefineIDs = Builder<WorldT, VarT, VariableIDMap<VarT, Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap<VarT, Values...>>;
using DefineIDs = Builder<WorldT, VarT, VariableIDMap<VarT, Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap<VarT, Values...>>;
template <size_t RangeStart, size_t RangeEnd>
using DefineRange = Builder<WorldT, VarT, VariableIDRange<VarT, RangeStart, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
using DefineRange = Builder<WorldT, VarT, VariableIDRange<VarT, RangeStart, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
template <size_t RangeEnd>
using DefineRange0 = Builder<WorldT, VarT, VariableIDRange<VarT, 0, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, 0, RangeEnd>>;
using DefineRange0 = Builder<WorldT, VarT, VariableIDRange<VarT, 0, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, 0, RangeEnd>>;
template <VarT ... Values>
using Variable = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap<VarT, Values...>>;
using Variable = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap<VarT, Values...>>;
template <size_t RangeStart, size_t RangeEnd>
using VariableRange = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
using VariableRange = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
using EmptyConstrainerFunctionT = EmptyConstrainerFunction<WorldT, WorldSizeT, VarT, ConstrainerType>;
template <typename ConstrainerFunctionT>
requires ConstrainerFunction<ConstrainerFunctionT, WorldT, VarT, WaveType, PropagationQueueType>
using Constrain = Builder<WorldT, VarT, VariableIDMapT,
@@ -54,8 +59,8 @@ public:
ConstrainerFunctionMapT,
ConstrainerFunctionT,
SelectedValueT,
decltype([](WorldT&, size_t, WorldValue<VarT>, ConstrainerType&) {})
>, CallbacksT, RandomSelectorT, SelectedValueT
EmptyConstrainerFunctionT
>, CallbacksT, RandomSelectorT, WeightsMapT, SelectedValueT
>;
template <typename ConstrainerFunctionT>
@@ -66,27 +71,28 @@ public:
ConstrainerFunctionMapT,
ConstrainerFunctionT,
VariableIDMapT,
decltype([](WorldT&, size_t, WorldValue<VarT>, ConstrainerType&) {})
>, CallbacksT, RandomSelectorT
EmptyConstrainerFunctionT
>, CallbacksT, RandomSelectorT, WeightsMapT
>;
template <typename NewCellCollapsedCallbackT>
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT>;
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT, WeightsMapT>;
template <typename NewContradictionCallbackT>
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT>;
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT, WeightsMapT>;
template <typename NewBranchCallbackT>
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT>;
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT, WeightsMapT>;
template <typename NewRandomSelectorT>
requires RandomSelectorFunction<NewRandomSelectorT, VarT>
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
template <uint16_t DefaultWeight, typename... WeightSpecs>
using Weights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, WeightedSelector<VarT, VariableIDMapT, RandomSelectorT, WeightsMap<VarT, VariableIDMapT, DefaultWeight, WeightSpecs...>>>;
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT, WeightsMapT>;
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
template <EPrecision Precision>
using SetWeights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, typename WeightsMapT::template Merge<PrecisionEntry<SelectedValueT, static_cast<uint8_t>(Precision)>>, SelectedValueT>;
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT>;
};
}

View File

@@ -5,6 +5,12 @@
namespace WFC {
template <typename WorldT, typename WorldSizeT, typename VarT, typename ConstainerType>
struct EmptyConstrainerFunction
{
void operator()(WorldT&, WorldSizeT, WorldValue<VarT>, ConstainerType&) const {}
};
template <typename ... ConstrainerFunctions>
struct ConstrainerFunctionMap {
public:
@@ -56,7 +62,7 @@ template<typename VariableIDMapT,
typename SelectedIDsVariableIDMapT,
typename EmptyFunctionT>
using MergedConstrainerFunctionMap = decltype(
MakeMergedConstrainerIDMap(std::make_index_sequence<VariableIDMapT::ValuesRegisteredAmount>{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr)
MakeMergedConstrainerIDMap(std::make_index_sequence<VariableIDMapT::size()>{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr)
);
/**

View File

@@ -7,7 +7,7 @@
#include <span>
#include <algorithm>
#include "nd-wfc/wfc_utils.hpp"
#include "wfc_utils.hpp"
namespace WFC
{

View File

@@ -14,13 +14,6 @@ private:
public:
constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {}
constexpr size_t operator()(std::span<const VarT> possibleValues) const {
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
// Simple linear congruential generator for constexpr compatibility
return static_cast<size_t>(rng(possibleValues.size()));
}
constexpr uint32_t rng(uint32_t max) const {
m_seed = m_seed * 1103515245 + 12345;
return m_seed % max;
@@ -39,140 +32,10 @@ private:
public:
explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {}
size_t operator()(std::span<const VarT> possibleValues) const {
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
return rng(possibleValues.size());
}
uint32_t rng(uint32_t max) const {
std::uniform_int_distribution<uint32_t> dist(0, max);
return dist(m_rng);
}
};
/**
* @brief Weight specification for a specific value
* @tparam Value The variable value
* @tparam Weight The 16-bit weight for this value
*/
template <typename VarT, VarT Value, uint16_t WeightValue>
struct Weight {
static constexpr VarT value = Value;
static constexpr uint16_t weight = WeightValue;
};
/**
* @brief Compile-time weights storage for weighted random selection
* @tparam VarT The variable type
* @tparam VariableIDMapT The variable ID map type
* @tparam DefaultWeight The default weight for values not explicitly specified
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
*/
template <typename VarT, typename VariableIDMapT, uint16_t DefaultWeight, typename... WeightSpecs>
class WeightsMap {
private:
static constexpr size_t NumWeights = sizeof...(WeightSpecs);
// Helper to get weight for a specific value
static consteval uint16_t GetWeightForValue(VarT targetValue) {
// Check each weight spec to find the target value
uint16_t weight = DefaultWeight;
((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...);
return weight;
}
public:
/**
* @brief Get the weight for a specific value at compile time
* @tparam TargetValue The value to get weight for
* @return The weight for the value
*/
template <VarT TargetValue>
static consteval uint16_t GetWeight() {
return GetWeightForValue(TargetValue);
}
/**
* @brief Get weights array for all registered values
* @return Array of weights corresponding to all registered values
*/
static consteval std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> GetWeightsArray() {
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> weights{};
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
weights[i] = GetWeightForValue(VariableIDMapT::GetValue(i));
}
return weights;
}
static consteval uint32_t GetTotalWeight() {
uint32_t totalWeight = 0;
auto weights = GetWeightsArray();
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
totalWeight += weights[i];
}
return totalWeight;
}
static consteval std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> GetCumulativeWeightsArray() {
auto weights = GetWeightsArray();
uint32_t totalWeight = 0;
std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> cumulativeWeights{};
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
totalWeight += weights[i];
cumulativeWeights[i] = totalWeight;
}
return cumulativeWeights;
}
};
/**
* @brief Weighted random selector that uses another random selector as backend
* @tparam VarT The variable type
* @tparam VariableIDMapT The variable ID map type
* @tparam BackendSelectorT The backend random selector type
* @tparam WeightsMapT The weights map type containing weight specifications
*/
template <typename VarT, typename VariableIDMapT, typename BackendSelectorT, typename WeightsMapT>
class WeightedSelector {
private:
BackendSelectorT m_backendSelector;
const std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> m_weights;
const std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> m_cumulativeWeights;
public:
explicit WeightedSelector(BackendSelectorT backendSelector)
: m_backendSelector(backendSelector)
, m_weights(WeightsMapT::GetWeightsArray())
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
{}
explicit WeightedSelector(uint32_t seed)
requires std::is_same_v<BackendSelectorT, DefaultRandomSelector<VarT>>
: m_backendSelector(seed)
, m_weights(WeightsMapT::GetWeightsArray())
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
{}
size_t operator()(std::span<const VarT> possibleValues) const {
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value");
// Use backend selector to pick a random number in range [0, totalWeight)
uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back());
// Find which value this random value corresponds to
for (size_t i = 0; i < possibleValues.size(); ++i) {
if (randomValue <= m_cumulativeWeights[i]) {
return i;
}
}
// Fallback (should not reach here)
return possibleValues.size() - 1;
}
};
}

View File

@@ -6,10 +6,32 @@ namespace WFC
inline constexpr void constexpr_assert(bool condition, const char* message = "") {
# ifdef _DEBUG
inline constexpr void constexpr_assert(bool condition, const char* message = "")
{
if (!condition) throw message;
}
#else
inline constexpr void constexpr_assert(bool condition, const char* message = "")
{
(void)condition;
(void)message;
}
# endif
template <size_t Size>
using MinimumIntegerType = std::conditional_t<Size <= std::numeric_limits<uint8_t>::max(), uint8_t,
std::conditional_t<Size <= std::numeric_limits<uint16_t>::max(), uint16_t,
std::conditional_t<Size <= std::numeric_limits<uint32_t>::max(), uint32_t,
uint64_t>>>;
template <uint8_t bits>
using MinimumBitsType = std::conditional_t<bits <= 8, uint8_t,
std::conditional_t<bits <= 16, uint16_t,
std::conditional_t<bits <= 32, uint32_t,
std::conditional_t<bits <= 64, uint64_t,
void>>>>;
inline int FindNthSetBit(size_t num, int n) {
constexpr_assert(n < std::popcount(num), "index is out of range");

View File

@@ -12,23 +12,25 @@ namespace WFC {
* It is a compile-time map of variable values to indices.
*/
template <size_t VariablesAmount>
using VariableIDType = std::conditional_t<VariablesAmount <= std::numeric_limits<uint8_t>::max(), uint8_t, uint16_t>;
template <typename VarT, VarT ... Values>
class VariableIDMap {
public:
using Type = VarT;
static constexpr size_t ValuesRegisteredAmount = sizeof...(Values);
template <VarT ... AdditionalValues>
using Merge = VariableIDMap<VarT, Values..., AdditionalValues...>;
using VariableIDT = VariableIDType<sizeof...(Values)>;
template <VarT Value>
static consteval bool HasValue()
{
constexpr VarT arr[] = {Values...};
constexpr size_t size = sizeof...(Values);
for (size_t i = 0; i < size; ++i)
for (size_t i = 0; i < size(); ++i)
if (arr[i] == Value)
return true;
return false;
@@ -39,9 +41,8 @@ public:
{
static_assert(HasValue<Value>(), "Value was not defined");
constexpr VarT arr[] = {Values...};
constexpr size_t size = ValuesRegisteredAmount;
for (size_t i = 0; i < size; ++i)
for (size_t i = 0; i < size(); ++i)
if (arr[i] == Value)
return i;
@@ -54,15 +55,15 @@ public:
{
Values...
};
return std::span<const VarT>{ allValues, ValuesRegisteredAmount };
return std::span<const VarT>{ allValues, size() };
}
static constexpr VarT GetValue(size_t index) {
constexpr_assert(index < ValuesRegisteredAmount);
constexpr_assert(index < size());
return GetAllValues()[index];
}
static consteval size_t size() { return ValuesRegisteredAmount; }
static consteval size_t size() { return sizeof...(Values); }
template <VarT ... ValuesSlice>
static constexpr auto ValuesToIndices() -> std::array<size_t, sizeof...(ValuesSlice)> {
@@ -76,12 +77,13 @@ class VariableIDRange
{
public:
using Type = VarT;
using VariableIDT = VariableIDType<End - Start>;
static_assert(Start < End, "Start must be less than End");
static_assert(std::numeric_limits<VarT>::min() <= Start, "VarT must be able to represent all values in the range");
static_assert(std::numeric_limits<VarT>::max() >= End, "VarT must be able to represent all values in the range");
static constexpr size_t ValuesRegisteredAmount = End - Start;
static constexpr size_t size() { return End - Start; }
template <VarT Value>
static consteval bool HasValue()
@@ -100,8 +102,6 @@ public:
return Start + index;
}
static consteval size_t size() { return End - Start; }
template <VarT ... ValuesSlice>
static constexpr auto ValuesToIndices() -> std::array<size_t, sizeof...(ValuesSlice)> {
std::array<size_t, sizeof...(ValuesSlice)> indices = {GetIndex<ValuesSlice>()...};

View File

@@ -6,12 +6,16 @@
namespace WFC {
template <typename VariableIDMapT, size_t Size = 0>
template <typename VariableIDMapT, typename WeightsMapT, size_t Size = 0>
class Wave {
public:
using BitContainerT = BitContainer<VariableIDMapT::ValuesRegisteredAmount, Size>;
using BitContainerT = BitContainer<VariableIDMapT::size(), Size>;
using ElementT = typename BitContainerT::StorageType;
using IDMapT = VariableIDMapT;
using WeightContainersT = typename WeightsMapT::template WeightContainersT<Size, WFCStackAllocator>;
using VariableIDT = typename VariableIDMapT::VariableIDT;
using WeightT = typename WeightsMapT::WeightT;
static constexpr size_t ElementsAmount = Size;
public:
@@ -34,8 +38,15 @@ public:
uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); }
ElementT GetMask(size_t index) const { return m_data[index]; }
void SetWeight(VariableIDT containerIndex, size_t elementIndex, double weight) { m_weights.SetValueFloat(containerIndex, elementIndex, weight); }
template <size_t MaxWeight>
WeightT GetWeight(VariableIDT containerIndex, size_t elementIndex) const { return m_weights.template GetValue<MaxWeight>(containerIndex, elementIndex); }
private:
BitContainerT m_data;
WeightContainersT m_weights;
};
}

View File

@@ -0,0 +1,234 @@
#pragma once
#include <array>
#include <span>
#include <tuple>
#include "wfc_bit_container.hpp"
#include "wfc_utils.hpp"
namespace WFC {
template <typename VariableMap, uint8_t Precision>
struct PrecisionEntry
{
constexpr static uint8_t PrecisionValue = Precision;
template <typename MainVariableMap>
constexpr static bool UpdatePrecisions(std::span<uint8_t> precisions)
{
constexpr auto SelectedEntries = VariableMap::GetAllValues();
for (auto entry : SelectedEntries)
{
precisions[MainVariableMap::GetIndex(entry)] = Precision;
}
return true;
}
};
enum class EPrecision : uint8_t
{
Precision_0 = 0,
Precision_2 = 2,
Precision_4 = 4,
Precision_8 = 8,
Precision_16 = 16,
Precision_32 = 32,
Precision_64 = 64,
};
template <size_t Size, typename AllocatorT, EPrecision ... Precisions>
class WeightContainers
{
private:
template <EPrecision Precision>
using BitContainerT = BitContainer<static_cast<uint8_t>(Precision), Size, AllocatorT>;
using TupleT = std::tuple<BitContainerT<Precisions>...>;
TupleT m_WeightContainers;
static_assert(((static_cast<uint8_t>(Precisions) <= static_cast<uint8_t>(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)");
public:
WeightContainers() = default;
WeightContainers(size_t size)
: m_WeightContainers{ BitContainerT<Precisions>(size, AllocatorT()) ... }
{}
WeightContainers(size_t size, AllocatorT& allocator)
: m_WeightContainers{ BitContainerT<Precisions>(size, allocator) ... }
{}
public:
static constexpr size_t size()
{
return sizeof...(Precisions);
}
/*
template <typename ValueT>
void SetValue(size_t containerIndex, size_t index, ValueT value)
{
SetValueFunctions<ValueT>()[containerIndex](*this, index, value);
}
*/
void SetValueFloat(size_t containerIndex, size_t index, double value)
{
SetFloatValueFunctions()[containerIndex](*this, index, value);
}
template <size_t MaxWeight>
uint64_t GetValue(size_t containerIndex, size_t index)
{
return GetValueFunctions<MaxWeight>()[containerIndex](*this, index);
}
private:
/*
template <typename ValueT>
static constexpr auto& SetValueFunctions()
{
return SetValueFunctions<ValueT>(std::make_index_sequence<size()>());
}
template <typename ValueT, size_t ... Is>
static constexpr auto& SetValueFunctions(std::index_sequence<Is...>)
{
static constexpr std::array<void(*)(WeightContainers& weightContainers, size_t index, ValueT value), VariableIDMapT::size()> setValueFunctions =
{
[] (WeightContainers& weightContainers, size_t index, ValueT value) {
std::get<Is>(weightContainers.m_WeightContainers)[index] = value;
},
...
};
return setValueFunctions;
}
*/
static constexpr auto& SetFloatValueFunctions()
{
return SetFloatValueFunctions(std::make_index_sequence<size()>());
}
template <size_t ... Is>
static constexpr auto& SetFloatValueFunctions(std::index_sequence<Is...>)
{
using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value);
constexpr std::array<FunctionT, size()> setFloatValueFunctions
{
[](WeightContainers& weightContainers, size_t index, double value) -> FunctionT {
using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
if constexpr (!std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
{
constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0");
std::get<Is>(weightContainers.m_WeightContainers)[index] = static_cast<BitContainerEntryT::StorageType>(value * BitContainerEntryT::MaxValue);
}
}
...
};
return setFloatValueFunctions;
}
template <size_t MaxWeight>
static constexpr auto& GetValueFunctions()
{
return GetValueFunctions<MaxWeight>(std::make_index_sequence<size()>());
}
template <size_t MaxWeight, size_t ... Is>
static constexpr auto& GetValueFunctions(std::index_sequence<Is...>)
{
using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index);
constexpr std::array<FunctionT, size()> getValueFunctions =
{
[] (WeightContainers& weightContainers, size_t index) -> FunctionT {
using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
if constexpr (std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
{
return MaxWeight / 2;
}
else
{
constexpr size_t maxValue = BitContainerEntryT::MaxValue;
if constexpr (maxValue <= MaxWeight)
{
return std::get<Is>(weightContainers.m_WeightContainers)[index];
}
else
{
return static_cast<uint64_t>(std::get<Is>(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue;
}
}
}
...
};
return getValueFunctions;
}
};
/**
* @brief Compile-time weights storage for weighted random selection
* @tparam VarT The variable type
* @tparam VariableIDMapT The variable ID map type
* @tparam DefaultWeight The default weight for values not explicitly specified
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
*/
template <typename VariableIDMapT, typename ... PrecisionEntries>
class WeightsMap {
public:
static constexpr std::array<uint8_t, VariableIDMapT::size()> GeneratePrecisionArray()
{
std::array<uint8_t, VariableIDMapT::size()> precisionArray{};
(PrecisionEntries::template UpdatePrecisions<VariableIDMapT>(precisionArray) && ...);
return precisionArray;
}
static constexpr std::array<uint8_t, VariableIDMapT::size()> GetPrecisionArray()
{
constexpr std::array<uint8_t, VariableIDMapT::size()> precisionArray = GeneratePrecisionArray();
return precisionArray;
}
static constexpr size_t GetPrecision(size_t index)
{
return GetPrecisionArray()[index];
}
static constexpr uint8_t GetMaxPrecision()
{
return std::max<uint8_t>({PrecisionEntries::PrecisionValue ...});
}
static constexpr uint8_t GetMaxValue()
{
return (1 << GetMaxPrecision()) - 1;
}
static constexpr bool HasWeights()
{
return sizeof...(PrecisionEntries) > 0;
}
public:
using VariablesT = VariableIDMapT;
template<size_t Size, typename AllocatorT, size_t... Is>
auto MakeWeightContainersT(AllocatorT*, std::index_sequence<Is...>)
-> WeightContainers<Size, AllocatorT, GetPrecision(Is) ...>;
template <size_t Size, typename AllocatorT>
using WeightContainersT = decltype(
MakeWeightContainersT<Size>(static_cast<AllocatorT*>(nullptr), std::make_index_sequence<VariableIDMapT::size()>{})
);
template <typename PrecisionEntryT>
using Merge = WeightsMap<VariableIDMapT, PrecisionEntries..., PrecisionEntryT>;
using WeightT = typename BitContainer<GetMaxPrecision(), 0, std::allocator<uint8_t>>::StorageType;
};
}