migration 24/12/2025
This commit is contained in:
@@ -8,6 +8,8 @@
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include <nd-wfc/wfc.h>
|
||||
|
||||
// Forward declarations
|
||||
struct NonogramHints;
|
||||
struct NonogramSolution;
|
||||
@@ -178,3 +180,7 @@ std::string trim(const std::string& str);
|
||||
bool startsWith(const std::string& str, const std::string& prefix);
|
||||
std::vector<std::string> split(const std::string& str, char delimiter);
|
||||
std::vector<uint8_t> parseNumberSequence(const std::string& str, char delimiter);
|
||||
|
||||
using NonogramWFC = WFC::Builder<Nonogram, bool>
|
||||
::Define<false, true>
|
||||
::Build;
|
||||
@@ -279,15 +279,15 @@ private:
|
||||
public: // WFC Support
|
||||
using ValueType = uint8_t;
|
||||
|
||||
constexpr inline ValueType getValue(size_t index) const {
|
||||
constexpr inline ValueType getValue(uint8_t index) const {
|
||||
return board_.get(static_cast<int>(index));
|
||||
}
|
||||
|
||||
constexpr inline void setValue(size_t index, ValueType value) {
|
||||
constexpr inline void setValue(uint8_t index, ValueType value) {
|
||||
board_.set(static_cast<int>(index), value);
|
||||
}
|
||||
|
||||
constexpr inline size_t size() const {
|
||||
constexpr inline uint8_t size() const {
|
||||
return 81;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <type_traits>
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
#include <ranges>
|
||||
#include <concepts>
|
||||
#include <bit>
|
||||
#include <span>
|
||||
@@ -22,60 +23,50 @@
|
||||
#include "wfc_callbacks.hpp"
|
||||
#include "wfc_random.hpp"
|
||||
#include "wfc_queue.hpp"
|
||||
#include "wfc_weights.hpp"
|
||||
|
||||
namespace WFC {
|
||||
|
||||
template<typename T>
|
||||
concept WorldType = requires(T world, size_t id, typename T::ValueType value) {
|
||||
{ world.size() } -> std::convertible_to<size_t>;
|
||||
{ world.setValue(id, value) };
|
||||
{ world.getValue(id) } -> std::convertible_to<typename T::ValueType>;
|
||||
concept WorldType = requires(T world, typename T::ValueType value) {
|
||||
{ world.size() } -> std::is_integral;
|
||||
{ world.setValue(static_cast<decltype(world.size())>(0), value) };
|
||||
{ world.getValue(static_cast<decltype(world.size())>(0)) } -> std::convertible_to<typename T::ValueType>;
|
||||
typename T::ValueType;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Concept to validate constrainer function signature
|
||||
* The function must be callable with parameters: (WorldT&, size_t, WorldValue<VarT>, Constrainer<VariableIDMapT>&)
|
||||
*/
|
||||
template <typename T, typename WorldT, typename VarT, typename VariableIDMapT, typename PropagationQueueType>
|
||||
concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, WorldValue<VarT> value, Constrainer<VariableIDMapT, PropagationQueueType>& constrainer) {
|
||||
func(world, index, value, constrainer);
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Concept to validate random selector function signature
|
||||
* The function must be callable with parameters: (std::span<const VarT>) and return size_t
|
||||
*/
|
||||
template <typename T, typename VarT>
|
||||
concept RandomSelectorFunction = requires(const T& func, std::span<const VarT> possibleValues) {
|
||||
{ func(possibleValues) } -> std::convertible_to<size_t>;
|
||||
{ func.rng(static_cast<uint32_t>(1)) } -> std::convertible_to<uint32_t>;
|
||||
};
|
||||
|
||||
template <typename WorldT>
|
||||
concept HasConstexprSize = requires {
|
||||
{ []() constexpr -> std::size_t { return WorldT{}.size(); }() };
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Main WFC class implementing the Wave Function Collapse algorithm
|
||||
*/
|
||||
template<typename WorldT, typename VarT,
|
||||
typename VariableIDMapT = VariableIDMap<VarT>,
|
||||
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
|
||||
typename CallbacksT = Callbacks<WorldT>,
|
||||
typename RandomSelectorT = DefaultRandomSelector<VarT>>
|
||||
typename RandomSelectorT = DefaultRandomSelector<VarT>,
|
||||
typename WeightsMapT = WeightsMap<VariableIDMapT>
|
||||
>
|
||||
class WFC {
|
||||
public:
|
||||
static_assert(WorldType<WorldT>, "WorldT must satisfy World type requirements");
|
||||
|
||||
// Try getting the world size, which is only available if the world type has a constexpr size() method
|
||||
constexpr static size_t WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
|
||||
using WorldSizeT = decltype(WorldT{}.size());
|
||||
|
||||
using WaveType = Wave<VariableIDMapT, WorldSize>;
|
||||
using PropagationQueueType = WFCQueue<WorldSize>;
|
||||
// Try getting the world size, which is only available if the world type has a constexpr size() method
|
||||
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
|
||||
|
||||
using WaveType = Wave<VariableIDMapT, WeightsMapT, WorldSize>;
|
||||
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
|
||||
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
|
||||
using MaskType = typename WaveType::ElementT;
|
||||
using VariableIDT = typename WaveType::VariableIDT;
|
||||
using WeightsBufferType = BitContainer<WeightsMapT::GetMaxPrecision(), VariableIDMapT::size()>;
|
||||
|
||||
public:
|
||||
struct SolverState
|
||||
@@ -88,7 +79,7 @@ public:
|
||||
|
||||
SolverState(WorldT& world, uint32_t seed)
|
||||
: m_world(world)
|
||||
, m_propagationQueue{ WorldSize ? WorldSize : static_cast<size_t>(world.size()) }
|
||||
, m_propagationQueue{ WorldSize ? WorldSize : static_cast<WorldSizeT>(world.size()) }
|
||||
, m_randomSelector(seed)
|
||||
{}
|
||||
|
||||
@@ -114,7 +105,7 @@ public:
|
||||
*/
|
||||
static bool Run(SolverState& state)
|
||||
{
|
||||
WaveType wave{ WorldSize, VariableIDMapT::ValuesRegisteredAmount, state.m_allocator };
|
||||
WaveType wave{ WorldSize, VariableIDMapT::size(), state.m_allocator };
|
||||
|
||||
PropogateInitialValues(state, wave);
|
||||
|
||||
@@ -188,7 +179,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
static void CollapseCell(SolverState& state, WaveType& wave, size_t cellId, uint16_t value)
|
||||
static void CollapseCell(SolverState& state, WaveType& wave, WorldSizeT cellId, VariableIDT value)
|
||||
{
|
||||
constexpr_assert(!wave.IsCollapsed(cellId) || wave.GetMask(cellId) == (MaskType(1) << value));
|
||||
wave.Collapse(cellId, 1 << value);
|
||||
@@ -201,51 +192,111 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
static bool Branch(SolverState& state, WaveType& wave)
|
||||
static WorldSizeT FindMinimumEntropyCells(std::span<VariableIDT>& buffer, WaveType& wave)
|
||||
{
|
||||
constexpr_assert(state.m_propagationQueue.empty());
|
||||
|
||||
// Find cell with minimum entropy > 1
|
||||
size_t minEntropyCell = static_cast<size_t>(-1);
|
||||
size_t minEntropy = static_cast<size_t>(-1);
|
||||
|
||||
for (size_t i = 0; i < wave.size(); ++i) {
|
||||
size_t entropy = wave.Entropy(i);
|
||||
if (entropy > 1 && entropy < minEntropy) {
|
||||
minEntropy = entropy;
|
||||
minEntropyCell = i;
|
||||
}
|
||||
}
|
||||
if (minEntropyCell == static_cast<size_t>(-1)) return false;
|
||||
auto entropyGetter = [&wave](size_t index) -> size_t { return wave.Entropy(index); };
|
||||
auto entropyFilter = [&wave](size_t entropy) -> bool { return entropy > 1; };
|
||||
auto minEntropyCell = *std::ranges::min_element(std::views::iota(0, wave.size()) | std::views::transform(entropyGetter) | std::views::filter(entropyFilter));
|
||||
|
||||
constexpr_assert(!wave.IsCollapsed(minEntropyCell));
|
||||
|
||||
// create a list of possible values
|
||||
uint16_t availableValues = static_cast<uint16_t>(wave.Entropy(minEntropyCell));
|
||||
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> possibleValues; // inplace vector
|
||||
VariableIDT availableValues = wave.Entropy(minEntropyCell);
|
||||
MaskType mask = wave.GetMask(minEntropyCell);
|
||||
for (size_t i = 0; i < availableValues; ++i)
|
||||
{
|
||||
uint16_t index = static_cast<uint16_t>(std::countr_zero(mask)); // get the index of the lowest set bit
|
||||
constexpr_assert(index < VariableIDMapT::ValuesRegisteredAmount, "Possible value went outside bounds");
|
||||
VariableIDT index = static_cast<VariableIDT>(std::countr_zero(mask)); // get the index of the lowest set bit
|
||||
constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds");
|
||||
|
||||
possibleValues[i] = index;
|
||||
buffer[i] = index;
|
||||
constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set");
|
||||
|
||||
mask = mask & (mask - 1); // turn off lowest set bit
|
||||
}
|
||||
|
||||
// randomly select a value from possible values
|
||||
while (availableValues)
|
||||
return minEntropyCell;
|
||||
}
|
||||
|
||||
using WeightsType = MinimumIntegerType<WeightsMapT::GetMaxValue()>;
|
||||
using RandomGeneratorReturnType = decltype(RandomSelectorT{}.rng(static_cast<uint32_t>(1)));
|
||||
|
||||
static WorldSizeT FindMinimumEntropyCellsWeighted(std::span<VariableIDT>& buffer, std::span<WeightsType>& weights, WaveType& wave)
|
||||
{
|
||||
constexpr size_t ElementsMaxWeight = std::min(WeightsMapT::GetMaxValue() * VariableIDMapT::size(), std::numeric_limits<RandomGeneratorReturnType>::max()) / VariableIDMapT::size();
|
||||
|
||||
auto accumulatedWeightedEntropyGetter = [&wave](size_t index) -> uint64_t
|
||||
{
|
||||
// Create a span of the actual variable values for the random selector
|
||||
std::array<VarT, VariableIDMapT::ValuesRegisteredAmount> valueArray;
|
||||
for (size_t i = 0; i < availableValues; ++i) {
|
||||
valueArray[i] = VariableIDMapT::GetValue(possibleValues[i]);
|
||||
auto entropyFilter = [&wave](size_t index) -> bool { return wave.Entropy(index) > 1; };
|
||||
auto weightedEntropyGetter = [&wave](size_t index) -> uint64_t { return wave.template GetWeight<ElementsMaxWeight>(index); };
|
||||
|
||||
auto view = std::views::iota(0, VariableIDMapT::size()) | std::views::filter(entropyFilter) | std::views::transform(weightedEntropyGetter);
|
||||
return std::accumulate(view.begin(), view.end(), 0);
|
||||
};
|
||||
|
||||
auto minEntropyCell = *std::ranges::min_element(std::views::iota(0, wave.size()) | std::views::transform(accumulatedWeightedEntropyGetter));
|
||||
|
||||
VariableIDT availableValues = wave.Entropy(minEntropyCell);
|
||||
MaskType mask = wave.GetMask(minEntropyCell);
|
||||
for (size_t i = 0; i < availableValues; ++i)
|
||||
{
|
||||
VariableIDT index = static_cast<VariableIDT>(std::countr_zero(mask)); // get the index of the lowest set bit
|
||||
constexpr_assert(index < VariableIDMapT::size(), "Possible value went outside bounds");
|
||||
|
||||
buffer[i] = index;
|
||||
constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set");
|
||||
|
||||
weights[i] = wave.template GetWeight<ElementsMaxWeight>(index);
|
||||
|
||||
mask = mask & (mask - 1); // turn off lowest set bit
|
||||
}
|
||||
|
||||
return minEntropyCell;
|
||||
}
|
||||
|
||||
static bool Branch(SolverState& state, WaveType& wave)
|
||||
{
|
||||
constexpr_assert(state.m_propagationQueue.empty());
|
||||
|
||||
std::array<VariableIDT, VariableIDMapT::size()> Buffer{};
|
||||
std::array<WeightsType, VariableIDMapT::size()> WeightsBuffer{};
|
||||
uint64_t accumulatedWeights = 0;
|
||||
WorldSizeT minEntropyCell{};
|
||||
|
||||
if constexpr (WeightsMapT::HasWeights())
|
||||
{
|
||||
minEntropyCell = FindMinimumEntropyCellsWeighted(Buffer, WeightsBuffer, wave);
|
||||
accumulatedWeights = std::accumulate(WeightsBuffer.begin(), WeightsBuffer.end(), 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
minEntropyCell = FindMinimumEntropyCells(Buffer, wave);
|
||||
}
|
||||
|
||||
// randomly select a value from possible values
|
||||
while (Buffer.size())
|
||||
{
|
||||
size_t randomIndex;
|
||||
VariableIDT selectedValue;
|
||||
if constexpr (WeightsMapT::HasWeights())
|
||||
{
|
||||
auto randomWeight = state.m_randomSelector.rng(accumulatedWeights);
|
||||
for (size_t i = 0; i < WeightsBuffer.size(); ++i)
|
||||
{
|
||||
if (randomWeight < WeightsBuffer[i])
|
||||
{
|
||||
randomIndex = i;
|
||||
selectedValue = Buffer[i];
|
||||
break;
|
||||
}
|
||||
randomWeight -= WeightsBuffer[i];
|
||||
}
|
||||
accumulatedWeights -= WeightsBuffer[randomIndex];
|
||||
}
|
||||
else
|
||||
{
|
||||
randomIndex = state.m_randomSelector.rng(Buffer.size());
|
||||
selectedValue = Buffer[randomIndex];
|
||||
}
|
||||
std::span<const VarT> currentPossibleValues(valueArray.data(), availableValues);
|
||||
size_t randomIndex = state.m_randomSelector(currentPossibleValues);
|
||||
size_t selectedValue = possibleValues[randomIndex];
|
||||
|
||||
{
|
||||
// copy the state and branch out
|
||||
@@ -253,12 +304,12 @@ private:
|
||||
auto queueFrame = state.m_propagationQueue.createBranchPoint();
|
||||
|
||||
auto newWave = wave;
|
||||
CollapseCell(state, newWave, minEntropyCell, static_cast<uint16_t>(selectedValue));
|
||||
CollapseCell(state, newWave, minEntropyCell, selectedValue);
|
||||
state.m_propagationQueue.push(minEntropyCell);
|
||||
|
||||
if (RunLoop(state, newWave))
|
||||
{
|
||||
// copy the solution to the original state
|
||||
// move the solution to the original state
|
||||
wave = newWave;
|
||||
|
||||
return true;
|
||||
@@ -271,7 +322,8 @@ private:
|
||||
constexpr_assert((wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) == 0, "Wave was not collapsed correctly");
|
||||
|
||||
// swap replacement value with the last value
|
||||
std::swap(possibleValues[randomIndex], possibleValues[--availableValues]);
|
||||
std::swap(Buffer[randomIndex], Buffer[Buffer.size() - 1]);
|
||||
Buffer = Buffer.subspan(0, Buffer.size() - 1);
|
||||
}
|
||||
|
||||
return false;
|
||||
@@ -281,16 +333,16 @@ private:
|
||||
{
|
||||
while (!state.m_propagationQueue.empty())
|
||||
{
|
||||
size_t cellId = state.m_propagationQueue.pop();
|
||||
WorldSizeT cellId = state.m_propagationQueue.pop();
|
||||
|
||||
if (wave.IsContradicted(cellId)) return false;
|
||||
|
||||
constexpr_assert(wave.IsCollapsed(cellId), "Cell was not collapsed");
|
||||
|
||||
uint16_t variableID = wave.GetVariableID(cellId);
|
||||
VariableIDT variableID = wave.GetVariableID(cellId);
|
||||
ConstrainerType constrainer(wave, state.m_propagationQueue);
|
||||
|
||||
using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue<VarT>, ConstrainerType&);
|
||||
using ConstrainerFunctionPtrT = void(*)(WorldT&, WorldSizeT, WorldValue<VarT>, ConstrainerType&);
|
||||
|
||||
ConstrainerFunctionMapT::template GetFunction<ConstrainerFunctionPtrT>(variableID)(state.m_world, cellId, WorldValue<VarT>{VariableIDMapT::GetValue(variableID), variableID}, constrainer);
|
||||
}
|
||||
@@ -314,7 +366,7 @@ private:
|
||||
{
|
||||
if (state.m_world.getValue(i) == VariableIDMapT::GetValue(j))
|
||||
{
|
||||
CollapseCell(state, wave, static_cast<uint16_t>(i), static_cast<uint16_t>(j));
|
||||
CollapseCell(state, wave, static_cast<WorldSizeT>(i), static_cast<VariableIDT>(j));
|
||||
state.m_propagationQueue.push(i);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -60,6 +60,7 @@ public:
|
||||
static constexpr bool IsMultiElement = (StorageBits > 64);
|
||||
static constexpr bool IsSubByte = (StorageBits < 8);
|
||||
static constexpr size_t ElementsPerByte = sizeof(StorageType) * 8 / std::max<size_t>(1u, StorageBits);
|
||||
static constexpr size_t MaxValue = (StorageType{1} << BitsPerElement) - 1;
|
||||
|
||||
using ContainerType =
|
||||
std::conditional_t<Bits == 0,
|
||||
@@ -110,9 +111,9 @@ public:
|
||||
|
||||
BitContainer() = default;
|
||||
BitContainer(const AllocatorT& allocator) : AllocatorT(allocator) {};
|
||||
explicit BitContainer(size_t initial_size, const AllocatorT& allocator) requires (IsResizable)
|
||||
explicit BitContainer(size_t size, const AllocatorT& allocator) requires (IsResizable)
|
||||
: AllocatorT(allocator)
|
||||
, m_container(initial_size, allocator)
|
||||
, m_container(size, allocator)
|
||||
{};
|
||||
explicit BitContainer(size_t, const AllocatorT& allocator) requires (!IsResizable)
|
||||
: AllocatorT(allocator)
|
||||
|
||||
@@ -7,6 +7,7 @@ namespace WFC {
|
||||
#include "wfc_constrainer.hpp"
|
||||
#include "wfc_callbacks.hpp"
|
||||
#include "wfc_random.hpp"
|
||||
#include "wfc_weights.hpp"
|
||||
#include "wfc.hpp"
|
||||
|
||||
/**
|
||||
@@ -19,33 +20,37 @@ template<
|
||||
typename ConstrainerFunctionMapT = ConstrainerFunctionMap<void*>,
|
||||
typename CallbacksT = Callbacks<WorldT>,
|
||||
typename RandomSelectorT = DefaultRandomSelector<VarT>,
|
||||
typename WeightsMapT = WeightsMap<VariableIDMapT>,
|
||||
typename SelectedValueT = void>
|
||||
class Builder {
|
||||
public:
|
||||
constexpr static size_t WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
|
||||
using WorldSizeT = decltype(WorldT{}.size());
|
||||
constexpr static WorldSizeT WorldSize = HasConstexprSize<WorldT> ? WorldT{}.size() : 0;
|
||||
|
||||
using WaveType = Wave<VariableIDMapT, WorldSize>;
|
||||
using PropagationQueueType = WFCQueue<WorldSize>;
|
||||
using WaveType = Wave<VariableIDMapT, WeightsMapT, WorldSize>;
|
||||
using PropagationQueueType = WFCQueue<WorldSize, WorldSizeT>;
|
||||
using ConstrainerType = Constrainer<WaveType, PropagationQueueType>;
|
||||
|
||||
|
||||
template <VarT ... Values>
|
||||
using DefineIDs = Builder<WorldT, VarT, VariableIDMap<VarT, Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap<VarT, Values...>>;
|
||||
using DefineIDs = Builder<WorldT, VarT, VariableIDMap<VarT, Values...>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap<VarT, Values...>>;
|
||||
|
||||
template <size_t RangeStart, size_t RangeEnd>
|
||||
using DefineRange = Builder<WorldT, VarT, VariableIDRange<VarT, RangeStart, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
|
||||
using DefineRange = Builder<WorldT, VarT, VariableIDRange<VarT, RangeStart, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
|
||||
|
||||
template <size_t RangeEnd>
|
||||
using DefineRange0 = Builder<WorldT, VarT, VariableIDRange<VarT, 0, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, 0, RangeEnd>>;
|
||||
using DefineRange0 = Builder<WorldT, VarT, VariableIDRange<VarT, 0, RangeEnd>, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, 0, RangeEnd>>;
|
||||
|
||||
|
||||
template <VarT ... Values>
|
||||
using Variable = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDMap<VarT, Values...>>;
|
||||
using Variable = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDMap<VarT, Values...>>;
|
||||
|
||||
template <size_t RangeStart, size_t RangeEnd>
|
||||
using VariableRange = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
|
||||
using VariableRange = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT, VariableIDRange<VarT, RangeStart, RangeEnd>>;
|
||||
|
||||
|
||||
using EmptyConstrainerFunctionT = EmptyConstrainerFunction<WorldT, WorldSizeT, VarT, ConstrainerType>;
|
||||
|
||||
template <typename ConstrainerFunctionT>
|
||||
requires ConstrainerFunction<ConstrainerFunctionT, WorldT, VarT, WaveType, PropagationQueueType>
|
||||
using Constrain = Builder<WorldT, VarT, VariableIDMapT,
|
||||
@@ -54,8 +59,8 @@ public:
|
||||
ConstrainerFunctionMapT,
|
||||
ConstrainerFunctionT,
|
||||
SelectedValueT,
|
||||
decltype([](WorldT&, size_t, WorldValue<VarT>, ConstrainerType&) {})
|
||||
>, CallbacksT, RandomSelectorT, SelectedValueT
|
||||
EmptyConstrainerFunctionT
|
||||
>, CallbacksT, RandomSelectorT, WeightsMapT, SelectedValueT
|
||||
>;
|
||||
|
||||
template <typename ConstrainerFunctionT>
|
||||
@@ -66,27 +71,28 @@ public:
|
||||
ConstrainerFunctionMapT,
|
||||
ConstrainerFunctionT,
|
||||
VariableIDMapT,
|
||||
decltype([](WorldT&, size_t, WorldValue<VarT>, ConstrainerType&) {})
|
||||
>, CallbacksT, RandomSelectorT
|
||||
EmptyConstrainerFunctionT
|
||||
>, CallbacksT, RandomSelectorT, WeightsMapT
|
||||
>;
|
||||
|
||||
|
||||
template <typename NewCellCollapsedCallbackT>
|
||||
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT>;
|
||||
using SetCellCollapsedCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetCellCollapsedCallbackT<NewCellCollapsedCallbackT>, RandomSelectorT, WeightsMapT>;
|
||||
template <typename NewContradictionCallbackT>
|
||||
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT>;
|
||||
using SetContradictionCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetContradictionCallbackT<NewContradictionCallbackT>, RandomSelectorT, WeightsMapT>;
|
||||
template <typename NewBranchCallbackT>
|
||||
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT>;
|
||||
using SetBranchCallback = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, typename CallbacksT::template SetBranchCallbackT<NewBranchCallbackT>, RandomSelectorT, WeightsMapT>;
|
||||
|
||||
|
||||
template <typename NewRandomSelectorT>
|
||||
requires RandomSelectorFunction<NewRandomSelectorT, VarT>
|
||||
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
|
||||
|
||||
template <uint16_t DefaultWeight, typename... WeightSpecs>
|
||||
using Weights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, WeightedSelector<VarT, VariableIDMapT, RandomSelectorT, WeightsMap<VarT, VariableIDMapT, DefaultWeight, WeightSpecs...>>>;
|
||||
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT, WeightsMapT>;
|
||||
|
||||
|
||||
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
|
||||
template <EPrecision Precision>
|
||||
using SetWeights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, typename WeightsMapT::template Merge<PrecisionEntry<SelectedValueT, static_cast<uint8_t>(Precision)>>, SelectedValueT>;
|
||||
|
||||
|
||||
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT, WeightsMapT>;
|
||||
};
|
||||
|
||||
}
|
||||
@@ -5,6 +5,12 @@
|
||||
|
||||
namespace WFC {
|
||||
|
||||
template <typename WorldT, typename WorldSizeT, typename VarT, typename ConstainerType>
|
||||
struct EmptyConstrainerFunction
|
||||
{
|
||||
void operator()(WorldT&, WorldSizeT, WorldValue<VarT>, ConstainerType&) const {}
|
||||
};
|
||||
|
||||
template <typename ... ConstrainerFunctions>
|
||||
struct ConstrainerFunctionMap {
|
||||
public:
|
||||
@@ -56,7 +62,7 @@ template<typename VariableIDMapT,
|
||||
typename SelectedIDsVariableIDMapT,
|
||||
typename EmptyFunctionT>
|
||||
using MergedConstrainerFunctionMap = decltype(
|
||||
MakeMergedConstrainerIDMap(std::make_index_sequence<VariableIDMapT::ValuesRegisteredAmount>{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr)
|
||||
MakeMergedConstrainerIDMap(std::make_index_sequence<VariableIDMapT::size()>{}, (VariableIDMapT*)nullptr, (ConstrainerFunctionMapT*)nullptr, (NewConstrainerFunctionT*)nullptr, (SelectedIDsVariableIDMapT*)nullptr, (EmptyFunctionT*)nullptr)
|
||||
);
|
||||
|
||||
/**
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <span>
|
||||
#include <algorithm>
|
||||
|
||||
#include "nd-wfc/wfc_utils.hpp"
|
||||
#include "wfc_utils.hpp"
|
||||
|
||||
namespace WFC
|
||||
{
|
||||
|
||||
@@ -14,13 +14,6 @@ private:
|
||||
public:
|
||||
constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {}
|
||||
|
||||
constexpr size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
|
||||
// Simple linear congruential generator for constexpr compatibility
|
||||
return static_cast<size_t>(rng(possibleValues.size()));
|
||||
}
|
||||
|
||||
constexpr uint32_t rng(uint32_t max) const {
|
||||
m_seed = m_seed * 1103515245 + 12345;
|
||||
return m_seed % max;
|
||||
@@ -39,140 +32,10 @@ private:
|
||||
public:
|
||||
explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {}
|
||||
|
||||
size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
|
||||
return rng(possibleValues.size());
|
||||
}
|
||||
|
||||
uint32_t rng(uint32_t max) const {
|
||||
std::uniform_int_distribution<uint32_t> dist(0, max);
|
||||
return dist(m_rng);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Weight specification for a specific value
|
||||
* @tparam Value The variable value
|
||||
* @tparam Weight The 16-bit weight for this value
|
||||
*/
|
||||
template <typename VarT, VarT Value, uint16_t WeightValue>
|
||||
struct Weight {
|
||||
static constexpr VarT value = Value;
|
||||
static constexpr uint16_t weight = WeightValue;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compile-time weights storage for weighted random selection
|
||||
* @tparam VarT The variable type
|
||||
* @tparam VariableIDMapT The variable ID map type
|
||||
* @tparam DefaultWeight The default weight for values not explicitly specified
|
||||
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
|
||||
*/
|
||||
template <typename VarT, typename VariableIDMapT, uint16_t DefaultWeight, typename... WeightSpecs>
|
||||
class WeightsMap {
|
||||
private:
|
||||
static constexpr size_t NumWeights = sizeof...(WeightSpecs);
|
||||
|
||||
// Helper to get weight for a specific value
|
||||
static consteval uint16_t GetWeightForValue(VarT targetValue) {
|
||||
// Check each weight spec to find the target value
|
||||
uint16_t weight = DefaultWeight;
|
||||
((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...);
|
||||
return weight;
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Get the weight for a specific value at compile time
|
||||
* @tparam TargetValue The value to get weight for
|
||||
* @return The weight for the value
|
||||
*/
|
||||
template <VarT TargetValue>
|
||||
static consteval uint16_t GetWeight() {
|
||||
return GetWeightForValue(TargetValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get weights array for all registered values
|
||||
* @return Array of weights corresponding to all registered values
|
||||
*/
|
||||
static consteval std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> GetWeightsArray() {
|
||||
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> weights{};
|
||||
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
weights[i] = GetWeightForValue(VariableIDMapT::GetValue(i));
|
||||
}
|
||||
|
||||
return weights;
|
||||
}
|
||||
|
||||
static consteval uint32_t GetTotalWeight() {
|
||||
uint32_t totalWeight = 0;
|
||||
auto weights = GetWeightsArray();
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
totalWeight += weights[i];
|
||||
}
|
||||
return totalWeight;
|
||||
}
|
||||
|
||||
static consteval std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> GetCumulativeWeightsArray() {
|
||||
auto weights = GetWeightsArray();
|
||||
uint32_t totalWeight = 0;
|
||||
std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> cumulativeWeights{};
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
totalWeight += weights[i];
|
||||
cumulativeWeights[i] = totalWeight;
|
||||
}
|
||||
return cumulativeWeights;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Weighted random selector that uses another random selector as backend
|
||||
* @tparam VarT The variable type
|
||||
* @tparam VariableIDMapT The variable ID map type
|
||||
* @tparam BackendSelectorT The backend random selector type
|
||||
* @tparam WeightsMapT The weights map type containing weight specifications
|
||||
*/
|
||||
template <typename VarT, typename VariableIDMapT, typename BackendSelectorT, typename WeightsMapT>
|
||||
class WeightedSelector {
|
||||
private:
|
||||
BackendSelectorT m_backendSelector;
|
||||
const std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> m_weights;
|
||||
const std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> m_cumulativeWeights;
|
||||
|
||||
public:
|
||||
explicit WeightedSelector(BackendSelectorT backendSelector)
|
||||
: m_backendSelector(backendSelector)
|
||||
, m_weights(WeightsMapT::GetWeightsArray())
|
||||
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
|
||||
{}
|
||||
|
||||
explicit WeightedSelector(uint32_t seed)
|
||||
requires std::is_same_v<BackendSelectorT, DefaultRandomSelector<VarT>>
|
||||
: m_backendSelector(seed)
|
||||
, m_weights(WeightsMapT::GetWeightsArray())
|
||||
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
|
||||
{}
|
||||
|
||||
size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value");
|
||||
|
||||
// Use backend selector to pick a random number in range [0, totalWeight)
|
||||
uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back());
|
||||
|
||||
// Find which value this random value corresponds to
|
||||
for (size_t i = 0; i < possibleValues.size(); ++i) {
|
||||
if (randomValue <= m_cumulativeWeights[i]) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback (should not reach here)
|
||||
return possibleValues.size() - 1;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
@@ -6,10 +6,32 @@ namespace WFC
|
||||
|
||||
|
||||
|
||||
|
||||
inline constexpr void constexpr_assert(bool condition, const char* message = "") {
|
||||
# ifdef _DEBUG
|
||||
inline constexpr void constexpr_assert(bool condition, const char* message = "")
|
||||
{
|
||||
if (!condition) throw message;
|
||||
}
|
||||
#else
|
||||
inline constexpr void constexpr_assert(bool condition, const char* message = "")
|
||||
{
|
||||
(void)condition;
|
||||
(void)message;
|
||||
}
|
||||
# endif
|
||||
|
||||
template <size_t Size>
|
||||
using MinimumIntegerType = std::conditional_t<Size <= std::numeric_limits<uint8_t>::max(), uint8_t,
|
||||
std::conditional_t<Size <= std::numeric_limits<uint16_t>::max(), uint16_t,
|
||||
std::conditional_t<Size <= std::numeric_limits<uint32_t>::max(), uint32_t,
|
||||
uint64_t>>>;
|
||||
|
||||
template <uint8_t bits>
|
||||
using MinimumBitsType = std::conditional_t<bits <= 8, uint8_t,
|
||||
std::conditional_t<bits <= 16, uint16_t,
|
||||
std::conditional_t<bits <= 32, uint32_t,
|
||||
std::conditional_t<bits <= 64, uint64_t,
|
||||
void>>>>;
|
||||
|
||||
|
||||
inline int FindNthSetBit(size_t num, int n) {
|
||||
constexpr_assert(n < std::popcount(num), "index is out of range");
|
||||
|
||||
@@ -12,23 +12,25 @@ namespace WFC {
|
||||
* It is a compile-time map of variable values to indices.
|
||||
*/
|
||||
|
||||
template <size_t VariablesAmount>
|
||||
using VariableIDType = std::conditional_t<VariablesAmount <= std::numeric_limits<uint8_t>::max(), uint8_t, uint16_t>;
|
||||
|
||||
|
||||
template <typename VarT, VarT ... Values>
|
||||
class VariableIDMap {
|
||||
public:
|
||||
|
||||
using Type = VarT;
|
||||
static constexpr size_t ValuesRegisteredAmount = sizeof...(Values);
|
||||
|
||||
template <VarT ... AdditionalValues>
|
||||
using Merge = VariableIDMap<VarT, Values..., AdditionalValues...>;
|
||||
|
||||
using VariableIDT = VariableIDType<sizeof...(Values)>;
|
||||
|
||||
template <VarT Value>
|
||||
static consteval bool HasValue()
|
||||
{
|
||||
constexpr VarT arr[] = {Values...};
|
||||
constexpr size_t size = sizeof...(Values);
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
for (size_t i = 0; i < size(); ++i)
|
||||
if (arr[i] == Value)
|
||||
return true;
|
||||
return false;
|
||||
@@ -39,9 +41,8 @@ public:
|
||||
{
|
||||
static_assert(HasValue<Value>(), "Value was not defined");
|
||||
constexpr VarT arr[] = {Values...};
|
||||
constexpr size_t size = ValuesRegisteredAmount;
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
for (size_t i = 0; i < size(); ++i)
|
||||
if (arr[i] == Value)
|
||||
return i;
|
||||
|
||||
@@ -54,15 +55,15 @@ public:
|
||||
{
|
||||
Values...
|
||||
};
|
||||
return std::span<const VarT>{ allValues, ValuesRegisteredAmount };
|
||||
return std::span<const VarT>{ allValues, size() };
|
||||
}
|
||||
|
||||
static constexpr VarT GetValue(size_t index) {
|
||||
constexpr_assert(index < ValuesRegisteredAmount);
|
||||
constexpr_assert(index < size());
|
||||
return GetAllValues()[index];
|
||||
}
|
||||
|
||||
static consteval size_t size() { return ValuesRegisteredAmount; }
|
||||
static consteval size_t size() { return sizeof...(Values); }
|
||||
|
||||
template <VarT ... ValuesSlice>
|
||||
static constexpr auto ValuesToIndices() -> std::array<size_t, sizeof...(ValuesSlice)> {
|
||||
@@ -76,12 +77,13 @@ class VariableIDRange
|
||||
{
|
||||
public:
|
||||
using Type = VarT;
|
||||
using VariableIDT = VariableIDType<End - Start>;
|
||||
|
||||
static_assert(Start < End, "Start must be less than End");
|
||||
static_assert(std::numeric_limits<VarT>::min() <= Start, "VarT must be able to represent all values in the range");
|
||||
static_assert(std::numeric_limits<VarT>::max() >= End, "VarT must be able to represent all values in the range");
|
||||
|
||||
static constexpr size_t ValuesRegisteredAmount = End - Start;
|
||||
static constexpr size_t size() { return End - Start; }
|
||||
|
||||
template <VarT Value>
|
||||
static consteval bool HasValue()
|
||||
@@ -100,8 +102,6 @@ public:
|
||||
return Start + index;
|
||||
}
|
||||
|
||||
static consteval size_t size() { return End - Start; }
|
||||
|
||||
template <VarT ... ValuesSlice>
|
||||
static constexpr auto ValuesToIndices() -> std::array<size_t, sizeof...(ValuesSlice)> {
|
||||
std::array<size_t, sizeof...(ValuesSlice)> indices = {GetIndex<ValuesSlice>()...};
|
||||
|
||||
@@ -6,12 +6,16 @@
|
||||
|
||||
namespace WFC {
|
||||
|
||||
template <typename VariableIDMapT, size_t Size = 0>
|
||||
template <typename VariableIDMapT, typename WeightsMapT, size_t Size = 0>
|
||||
class Wave {
|
||||
public:
|
||||
using BitContainerT = BitContainer<VariableIDMapT::ValuesRegisteredAmount, Size>;
|
||||
using BitContainerT = BitContainer<VariableIDMapT::size(), Size>;
|
||||
using ElementT = typename BitContainerT::StorageType;
|
||||
using IDMapT = VariableIDMapT;
|
||||
using WeightContainersT = typename WeightsMapT::template WeightContainersT<Size, WFCStackAllocator>;
|
||||
using VariableIDT = typename VariableIDMapT::VariableIDT;
|
||||
using WeightT = typename WeightsMapT::WeightT;
|
||||
|
||||
static constexpr size_t ElementsAmount = Size;
|
||||
|
||||
public:
|
||||
@@ -34,8 +38,15 @@ public:
|
||||
uint16_t GetVariableID(size_t index) const { return static_cast<uint16_t>(std::countr_zero(m_data[index])); }
|
||||
ElementT GetMask(size_t index) const { return m_data[index]; }
|
||||
|
||||
void SetWeight(VariableIDT containerIndex, size_t elementIndex, double weight) { m_weights.SetValueFloat(containerIndex, elementIndex, weight); }
|
||||
|
||||
template <size_t MaxWeight>
|
||||
WeightT GetWeight(VariableIDT containerIndex, size_t elementIndex) const { return m_weights.template GetValue<MaxWeight>(containerIndex, elementIndex); }
|
||||
|
||||
|
||||
private:
|
||||
BitContainerT m_data;
|
||||
WeightContainersT m_weights;
|
||||
};
|
||||
|
||||
}
|
||||
234
include/nd-wfc/wfc_weights.hpp
Normal file
234
include/nd-wfc/wfc_weights.hpp
Normal file
@@ -0,0 +1,234 @@
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <span>
|
||||
#include <tuple>
|
||||
|
||||
#include "wfc_bit_container.hpp"
|
||||
#include "wfc_utils.hpp"
|
||||
|
||||
namespace WFC {
|
||||
|
||||
template <typename VariableMap, uint8_t Precision>
|
||||
struct PrecisionEntry
|
||||
{
|
||||
|
||||
constexpr static uint8_t PrecisionValue = Precision;
|
||||
|
||||
template <typename MainVariableMap>
|
||||
constexpr static bool UpdatePrecisions(std::span<uint8_t> precisions)
|
||||
{
|
||||
constexpr auto SelectedEntries = VariableMap::GetAllValues();
|
||||
for (auto entry : SelectedEntries)
|
||||
{
|
||||
precisions[MainVariableMap::GetIndex(entry)] = Precision;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
enum class EPrecision : uint8_t
|
||||
{
|
||||
Precision_0 = 0,
|
||||
Precision_2 = 2,
|
||||
Precision_4 = 4,
|
||||
Precision_8 = 8,
|
||||
Precision_16 = 16,
|
||||
Precision_32 = 32,
|
||||
Precision_64 = 64,
|
||||
};
|
||||
|
||||
template <size_t Size, typename AllocatorT, EPrecision ... Precisions>
|
||||
class WeightContainers
|
||||
{
|
||||
private:
|
||||
template <EPrecision Precision>
|
||||
using BitContainerT = BitContainer<static_cast<uint8_t>(Precision), Size, AllocatorT>;
|
||||
|
||||
using TupleT = std::tuple<BitContainerT<Precisions>...>;
|
||||
TupleT m_WeightContainers;
|
||||
|
||||
static_assert(((static_cast<uint8_t>(Precisions) <= static_cast<uint8_t>(EPrecision::Precision_64)) && ...), "Cannot have precision larger than 64 (double precision)");
|
||||
|
||||
public:
|
||||
WeightContainers() = default;
|
||||
WeightContainers(size_t size)
|
||||
: m_WeightContainers{ BitContainerT<Precisions>(size, AllocatorT()) ... }
|
||||
{}
|
||||
WeightContainers(size_t size, AllocatorT& allocator)
|
||||
: m_WeightContainers{ BitContainerT<Precisions>(size, allocator) ... }
|
||||
{}
|
||||
|
||||
public:
|
||||
static constexpr size_t size()
|
||||
{
|
||||
return sizeof...(Precisions);
|
||||
}
|
||||
|
||||
/*
|
||||
template <typename ValueT>
|
||||
void SetValue(size_t containerIndex, size_t index, ValueT value)
|
||||
{
|
||||
SetValueFunctions<ValueT>()[containerIndex](*this, index, value);
|
||||
}
|
||||
*/
|
||||
void SetValueFloat(size_t containerIndex, size_t index, double value)
|
||||
{
|
||||
SetFloatValueFunctions()[containerIndex](*this, index, value);
|
||||
}
|
||||
|
||||
template <size_t MaxWeight>
|
||||
uint64_t GetValue(size_t containerIndex, size_t index)
|
||||
{
|
||||
return GetValueFunctions<MaxWeight>()[containerIndex](*this, index);
|
||||
}
|
||||
|
||||
private:
|
||||
/*
|
||||
template <typename ValueT>
|
||||
static constexpr auto& SetValueFunctions()
|
||||
{
|
||||
return SetValueFunctions<ValueT>(std::make_index_sequence<size()>());
|
||||
}
|
||||
|
||||
template <typename ValueT, size_t ... Is>
|
||||
static constexpr auto& SetValueFunctions(std::index_sequence<Is...>)
|
||||
{
|
||||
static constexpr std::array<void(*)(WeightContainers& weightContainers, size_t index, ValueT value), VariableIDMapT::size()> setValueFunctions =
|
||||
{
|
||||
[] (WeightContainers& weightContainers, size_t index, ValueT value) {
|
||||
std::get<Is>(weightContainers.m_WeightContainers)[index] = value;
|
||||
},
|
||||
...
|
||||
};
|
||||
return setValueFunctions;
|
||||
}
|
||||
*/
|
||||
static constexpr auto& SetFloatValueFunctions()
|
||||
{
|
||||
return SetFloatValueFunctions(std::make_index_sequence<size()>());
|
||||
}
|
||||
|
||||
template <size_t ... Is>
|
||||
static constexpr auto& SetFloatValueFunctions(std::index_sequence<Is...>)
|
||||
{
|
||||
using FunctionT = void(*)(WeightContainers& weightContainers, size_t index, double value);
|
||||
constexpr std::array<FunctionT, size()> setFloatValueFunctions
|
||||
{
|
||||
[](WeightContainers& weightContainers, size_t index, double value) -> FunctionT {
|
||||
|
||||
using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
|
||||
if constexpr (!std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
|
||||
{
|
||||
constexpr_assert(value >= 0.0 && value <= 1.0, "Value must be between 0.0 and 1.0");
|
||||
std::get<Is>(weightContainers.m_WeightContainers)[index] = static_cast<BitContainerEntryT::StorageType>(value * BitContainerEntryT::MaxValue);
|
||||
}
|
||||
}
|
||||
...
|
||||
};
|
||||
return setFloatValueFunctions;
|
||||
}
|
||||
|
||||
template <size_t MaxWeight>
|
||||
static constexpr auto& GetValueFunctions()
|
||||
{
|
||||
return GetValueFunctions<MaxWeight>(std::make_index_sequence<size()>());
|
||||
}
|
||||
|
||||
template <size_t MaxWeight, size_t ... Is>
|
||||
static constexpr auto& GetValueFunctions(std::index_sequence<Is...>)
|
||||
{
|
||||
using FunctionT = uint64_t(*)(WeightContainers& weightContainers, size_t index);
|
||||
constexpr std::array<FunctionT, size()> getValueFunctions =
|
||||
{
|
||||
[] (WeightContainers& weightContainers, size_t index) -> FunctionT {
|
||||
using BitContainerEntryT = typename WeightContainers::TupleT::template tuple_element<Is>::type;
|
||||
|
||||
if constexpr (std::is_same_v<BitContainerEntryT::StorageType, detail::Empty>)
|
||||
{
|
||||
return MaxWeight / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr size_t maxValue = BitContainerEntryT::MaxValue;
|
||||
if constexpr (maxValue <= MaxWeight)
|
||||
{
|
||||
return std::get<Is>(weightContainers.m_WeightContainers)[index];
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<uint64_t>(std::get<Is>(weightContainers.m_WeightContainers)[index]) * MaxWeight / maxValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
...
|
||||
};
|
||||
return getValueFunctions;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compile-time weights storage for weighted random selection
|
||||
* @tparam VarT The variable type
|
||||
* @tparam VariableIDMapT The variable ID map type
|
||||
* @tparam DefaultWeight The default weight for values not explicitly specified
|
||||
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
|
||||
*/
|
||||
template <typename VariableIDMapT, typename ... PrecisionEntries>
|
||||
class WeightsMap {
|
||||
public:
|
||||
static constexpr std::array<uint8_t, VariableIDMapT::size()> GeneratePrecisionArray()
|
||||
{
|
||||
std::array<uint8_t, VariableIDMapT::size()> precisionArray{};
|
||||
|
||||
(PrecisionEntries::template UpdatePrecisions<VariableIDMapT>(precisionArray) && ...);
|
||||
|
||||
return precisionArray;
|
||||
}
|
||||
|
||||
static constexpr std::array<uint8_t, VariableIDMapT::size()> GetPrecisionArray()
|
||||
{
|
||||
constexpr std::array<uint8_t, VariableIDMapT::size()> precisionArray = GeneratePrecisionArray();
|
||||
return precisionArray;
|
||||
}
|
||||
|
||||
static constexpr size_t GetPrecision(size_t index)
|
||||
{
|
||||
return GetPrecisionArray()[index];
|
||||
}
|
||||
|
||||
static constexpr uint8_t GetMaxPrecision()
|
||||
{
|
||||
return std::max<uint8_t>({PrecisionEntries::PrecisionValue ...});
|
||||
}
|
||||
|
||||
static constexpr uint8_t GetMaxValue()
|
||||
{
|
||||
return (1 << GetMaxPrecision()) - 1;
|
||||
}
|
||||
|
||||
static constexpr bool HasWeights()
|
||||
{
|
||||
return sizeof...(PrecisionEntries) > 0;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
using VariablesT = VariableIDMapT;
|
||||
|
||||
template<size_t Size, typename AllocatorT, size_t... Is>
|
||||
auto MakeWeightContainersT(AllocatorT*, std::index_sequence<Is...>)
|
||||
-> WeightContainers<Size, AllocatorT, GetPrecision(Is) ...>;
|
||||
|
||||
template <size_t Size, typename AllocatorT>
|
||||
using WeightContainersT = decltype(
|
||||
MakeWeightContainersT<Size>(static_cast<AllocatorT*>(nullptr), std::make_index_sequence<VariableIDMapT::size()>{})
|
||||
);
|
||||
|
||||
template <typename PrecisionEntryT>
|
||||
using Merge = WeightsMap<VariableIDMapT, PrecisionEntries..., PrecisionEntryT>;
|
||||
|
||||
using WeightT = typename BitContainer<GetMaxPrecision(), 0, std::allocator<uint8_t>>::StorageType;
|
||||
};
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user