random-selector
This commit is contained in:
@@ -26,10 +26,6 @@ if(ND_WFC_BUILD_TESTS)
|
||||
add_subdirectory(tests)
|
||||
endif()
|
||||
|
||||
if(ND_WFC_BUILD_EXAMPLES)
|
||||
add_subdirectory(examples)
|
||||
endif()
|
||||
|
||||
# Install configuration temporarily disabled
|
||||
# TODO: Fix install configuration
|
||||
# include(GNUInstallDirs)
|
||||
|
||||
@@ -638,6 +638,7 @@ concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, Worl
|
||||
template <typename T, typename VarT>
|
||||
concept RandomSelectorFunction = requires(T func, std::span<const VarT> possibleValues) {
|
||||
{ func(possibleValues) } -> std::convertible_to<size_t>;
|
||||
{ func.rng(static_cast<uint32_t>(1)) } -> std::convertible_to<uint32_t>;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -653,11 +654,15 @@ public:
|
||||
constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {}
|
||||
|
||||
constexpr size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
if (possibleValues.empty()) return 0;
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
|
||||
// Simple linear congruential generator for constexpr compatibility
|
||||
return static_cast<size_t>(rng(possibleValues.size()));
|
||||
}
|
||||
|
||||
constexpr uint32_t rng(uint32_t max) {
|
||||
m_seed = m_seed * 1103515245 + 12345;
|
||||
return static_cast<size_t>(m_seed) % possibleValues.size();
|
||||
return m_seed % max;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -674,13 +679,141 @@ public:
|
||||
explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {}
|
||||
|
||||
size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
if (possibleValues.empty()) return 0;
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
|
||||
std::uniform_int_distribution<size_t> dist(0, possibleValues.size() - 1);
|
||||
return dist(m_rng);
|
||||
return rng(possibleValues.size());
|
||||
}
|
||||
|
||||
uint32_t rng(uint32_t max) {
|
||||
std::uniform_int_distribution<uint32_t> dist(0, max);
|
||||
return dist(m_rng);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Weight specification for a specific value
|
||||
* @tparam Value The variable value
|
||||
* @tparam Weight The 16-bit weight for this value
|
||||
*/
|
||||
template <typename VarT, VarT Value, uint16_t WeightValue>
|
||||
struct Weight {
|
||||
static constexpr VarT value = Value;
|
||||
static constexpr uint16_t weight = WeightValue;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compile-time weights storage for weighted random selection
|
||||
* @tparam VarT The variable type
|
||||
* @tparam VariableIDMapT The variable ID map type
|
||||
* @tparam DefaultWeight The default weight for values not explicitly specified
|
||||
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
|
||||
*/
|
||||
template <typename VarT, typename VariableIDMapT, uint16_t DefaultWeight, typename... WeightSpecs>
|
||||
class WeightsMap {
|
||||
private:
|
||||
static constexpr size_t NumWeights = sizeof...(WeightSpecs);
|
||||
|
||||
// Helper to get weight for a specific value
|
||||
static consteval uint16_t GetWeightForValue(VarT targetValue) {
|
||||
// Check each weight spec to find the target value
|
||||
uint16_t weight = DefaultWeight;
|
||||
((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...);
|
||||
return weight;
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Get the weight for a specific value at compile time
|
||||
* @tparam TargetValue The value to get weight for
|
||||
* @return The weight for the value
|
||||
*/
|
||||
template <VarT TargetValue>
|
||||
static consteval uint16_t GetWeight() {
|
||||
return GetWeightForValue(TargetValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get weights array for all registered values
|
||||
* @return Array of weights corresponding to all registered values
|
||||
*/
|
||||
static consteval std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> GetWeightsArray() {
|
||||
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> weights{};
|
||||
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
weights[i] = GetWeightForValue(VariableIDMapT::GetValueConsteval(i));
|
||||
}
|
||||
|
||||
return weights;
|
||||
}
|
||||
|
||||
static consteval uint32_t GetTotalWeight() {
|
||||
uint32_t totalWeight = 0;
|
||||
auto weights = GetWeightsArray();
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
totalWeight += weights[i];
|
||||
}
|
||||
return totalWeight;
|
||||
}
|
||||
|
||||
static consteval std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> GetCumulativeWeightsArray() {
|
||||
auto weights = GetWeightsArray();
|
||||
uint32_t totalWeight = 0;
|
||||
std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> cumulativeWeights{};
|
||||
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
|
||||
totalWeight += weights[i];
|
||||
cumulativeWeights[i] = totalWeight;
|
||||
}
|
||||
return cumulativeWeights;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Weighted random selector that uses another random selector as backend
|
||||
* @tparam VarT The variable type
|
||||
* @tparam VariableIDMapT The variable ID map type
|
||||
* @tparam BackendSelectorT The backend random selector type
|
||||
* @tparam WeightsMapT The weights map type containing weight specifications
|
||||
*/
|
||||
template <typename VarT, typename VariableIDMapT, typename BackendSelectorT, typename WeightsMapT>
|
||||
class WeightedSelector {
|
||||
private:
|
||||
BackendSelectorT m_backendSelector;
|
||||
const std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> m_weights;
|
||||
const std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> m_cumulativeWeights;
|
||||
|
||||
public:
|
||||
explicit WeightedSelector(BackendSelectorT backendSelector)
|
||||
: m_backendSelector(backendSelector)
|
||||
, m_weights(WeightsMapT::GetWeightsArray())
|
||||
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
|
||||
{}
|
||||
|
||||
explicit WeightedSelector(uint32_t seed)
|
||||
requires std::is_same_v<BackendSelectorT, DefaultRandomSelector<VarT>>
|
||||
: m_backendSelector(seed)
|
||||
, m_weights(WeightsMapT::GetWeightsArray())
|
||||
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
|
||||
{}
|
||||
|
||||
size_t operator()(std::span<const VarT> possibleValues) const {
|
||||
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
|
||||
constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value");
|
||||
|
||||
// Use backend selector to pick a random number in range [0, totalWeight)
|
||||
uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back());
|
||||
|
||||
// Find which value this random value corresponds to
|
||||
for (size_t i = 0; i < possibleValues.size(); ++i) {
|
||||
if (randomValue <= m_cumulativeWeights[i]) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback (should not reach here)
|
||||
return possibleValues.size() - 1;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Builder class for creating WFC instances
|
||||
*/
|
||||
@@ -714,6 +847,10 @@ public:
|
||||
requires RandomSelectorFunction<NewRandomSelectorT, VarT>
|
||||
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
|
||||
|
||||
template <uint16_t DefaultWeight, typename... WeightSpecs>
|
||||
using Weights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, WeightedSelector<VarT, VariableIDMapT, RandomSelectorT, WeightsMap<VarT, VariableIDMapT, DefaultWeight, WeightSpecs...>>>;
|
||||
|
||||
|
||||
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user