random-selector
This commit is contained in:
@@ -26,10 +26,6 @@ if(ND_WFC_BUILD_TESTS)
|
|||||||
add_subdirectory(tests)
|
add_subdirectory(tests)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(ND_WFC_BUILD_EXAMPLES)
|
|
||||||
add_subdirectory(examples)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Install configuration temporarily disabled
|
# Install configuration temporarily disabled
|
||||||
# TODO: Fix install configuration
|
# TODO: Fix install configuration
|
||||||
# include(GNUInstallDirs)
|
# include(GNUInstallDirs)
|
||||||
|
|||||||
@@ -638,6 +638,7 @@ concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, Worl
|
|||||||
template <typename T, typename VarT>
|
template <typename T, typename VarT>
|
||||||
concept RandomSelectorFunction = requires(T func, std::span<const VarT> possibleValues) {
|
concept RandomSelectorFunction = requires(T func, std::span<const VarT> possibleValues) {
|
||||||
{ func(possibleValues) } -> std::convertible_to<size_t>;
|
{ func(possibleValues) } -> std::convertible_to<size_t>;
|
||||||
|
{ func.rng(static_cast<uint32_t>(1)) } -> std::convertible_to<uint32_t>;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -653,11 +654,15 @@ public:
|
|||||||
constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {}
|
constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {}
|
||||||
|
|
||||||
constexpr size_t operator()(std::span<const VarT> possibleValues) const {
|
constexpr size_t operator()(std::span<const VarT> possibleValues) const {
|
||||||
if (possibleValues.empty()) return 0;
|
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||||
|
|
||||||
// Simple linear congruential generator for constexpr compatibility
|
// Simple linear congruential generator for constexpr compatibility
|
||||||
|
return static_cast<size_t>(rng(possibleValues.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr uint32_t rng(uint32_t max) {
|
||||||
m_seed = m_seed * 1103515245 + 12345;
|
m_seed = m_seed * 1103515245 + 12345;
|
||||||
return static_cast<size_t>(m_seed) % possibleValues.size();
|
return m_seed % max;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -674,13 +679,141 @@ public:
|
|||||||
explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {}
|
explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {}
|
||||||
|
|
||||||
size_t operator()(std::span<const VarT> possibleValues) const {
|
size_t operator()(std::span<const VarT> possibleValues) const {
|
||||||
if (possibleValues.empty()) return 0;
|
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||||
|
|
||||||
std::uniform_int_distribution<size_t> dist(0, possibleValues.size() - 1);
|
return rng(possibleValues.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t rng(uint32_t max) {
|
||||||
|
std::uniform_int_distribution<uint32_t> dist(0, max);
|
||||||
return dist(m_rng);
|
return dist(m_rng);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Weight specification for a specific value
|
||||||
|
* @tparam Value The variable value
|
||||||
|
* @tparam Weight The 16-bit weight for this value
|
||||||
|
*/
|
||||||
|
template <typename VarT, VarT Value, uint16_t WeightValue>
|
||||||
|
struct Weight {
|
||||||
|
static constexpr VarT value = Value;
|
||||||
|
static constexpr uint16_t weight = WeightValue;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compile-time weights storage for weighted random selection
|
||||||
|
* @tparam VarT The variable type
|
||||||
|
* @tparam VariableIDMapT The variable ID map type
|
||||||
|
* @tparam DefaultWeight The default weight for values not explicitly specified
|
||||||
|
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
|
||||||
|
*/
|
||||||
|
template <typename VarT, typename VariableIDMapT, uint16_t DefaultWeight, typename... WeightSpecs>
|
||||||
|
class WeightsMap {
|
||||||
|
private:
|
||||||
|
static constexpr size_t NumWeights = sizeof...(WeightSpecs);
|
||||||
|
|
||||||
|
// Helper to get weight for a specific value
|
||||||
|
static consteval uint16_t GetWeightForValue(VarT targetValue) {
|
||||||
|
// Check each weight spec to find the target value
|
||||||
|
uint16_t weight = DefaultWeight;
|
||||||
|
((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...);
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Get the weight for a specific value at compile time
|
||||||
|
* @tparam TargetValue The value to get weight for
|
||||||
|
* @return The weight for the value
|
||||||
|
*/
|
||||||
|
template <VarT TargetValue>
|
||||||
|
static consteval uint16_t GetWeight() {
|
||||||
|
return GetWeightForValue(TargetValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get weights array for all registered values
|
||||||
|
* @return Array of weights corresponding to all registered values
|
||||||
|
*/
|
||||||
|
static consteval std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> GetWeightsArray() {
|
||||||
|
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> weights{};
|
||||||
|
|
||||||
|
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||||
|
weights[i] = GetWeightForValue(VariableIDMapT::GetValueConsteval(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
static consteval uint32_t GetTotalWeight() {
|
||||||
|
uint32_t totalWeight = 0;
|
||||||
|
auto weights = GetWeightsArray();
|
||||||
|
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||||
|
totalWeight += weights[i];
|
||||||
|
}
|
||||||
|
return totalWeight;
|
||||||
|
}
|
||||||
|
|
||||||
|
static consteval std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> GetCumulativeWeightsArray() {
|
||||||
|
auto weights = GetWeightsArray();
|
||||||
|
uint32_t totalWeight = 0;
|
||||||
|
std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> cumulativeWeights{};
|
||||||
|
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||||
|
totalWeight += weights[i];
|
||||||
|
cumulativeWeights[i] = totalWeight;
|
||||||
|
}
|
||||||
|
return cumulativeWeights;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Weighted random selector that uses another random selector as backend
|
||||||
|
* @tparam VarT The variable type
|
||||||
|
* @tparam VariableIDMapT The variable ID map type
|
||||||
|
* @tparam BackendSelectorT The backend random selector type
|
||||||
|
* @tparam WeightsMapT The weights map type containing weight specifications
|
||||||
|
*/
|
||||||
|
template <typename VarT, typename VariableIDMapT, typename BackendSelectorT, typename WeightsMapT>
|
||||||
|
class WeightedSelector {
|
||||||
|
private:
|
||||||
|
BackendSelectorT m_backendSelector;
|
||||||
|
const std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> m_weights;
|
||||||
|
const std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> m_cumulativeWeights;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit WeightedSelector(BackendSelectorT backendSelector)
|
||||||
|
: m_backendSelector(backendSelector)
|
||||||
|
, m_weights(WeightsMapT::GetWeightsArray())
|
||||||
|
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
|
||||||
|
{}
|
||||||
|
|
||||||
|
explicit WeightedSelector(uint32_t seed)
|
||||||
|
requires std::is_same_v<BackendSelectorT, DefaultRandomSelector<VarT>>
|
||||||
|
: m_backendSelector(seed)
|
||||||
|
, m_weights(WeightsMapT::GetWeightsArray())
|
||||||
|
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
|
||||||
|
{}
|
||||||
|
|
||||||
|
size_t operator()(std::span<const VarT> possibleValues) const {
|
||||||
|
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||||
|
constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value");
|
||||||
|
|
||||||
|
// Use backend selector to pick a random number in range [0, totalWeight)
|
||||||
|
uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back());
|
||||||
|
|
||||||
|
// Find which value this random value corresponds to
|
||||||
|
for (size_t i = 0; i < possibleValues.size(); ++i) {
|
||||||
|
if (randomValue <= m_cumulativeWeights[i]) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback (should not reach here)
|
||||||
|
return possibleValues.size() - 1;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Builder class for creating WFC instances
|
* @brief Builder class for creating WFC instances
|
||||||
*/
|
*/
|
||||||
@@ -714,6 +847,10 @@ public:
|
|||||||
requires RandomSelectorFunction<NewRandomSelectorT, VarT>
|
requires RandomSelectorFunction<NewRandomSelectorT, VarT>
|
||||||
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
|
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
|
||||||
|
|
||||||
|
template <uint16_t DefaultWeight, typename... WeightSpecs>
|
||||||
|
using Weights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, WeightedSelector<VarT, VariableIDMapT, RandomSelectorT, WeightsMap<VarT, VariableIDMapT, DefaultWeight, WeightSpecs...>>>;
|
||||||
|
|
||||||
|
|
||||||
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
|
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user