Merge commit '0b4923d649426448f784224ad2505ae0b312757c' into prompt/weighted-random-selector

This commit is contained in:
cdemeyer-teachx
2025-09-04 08:01:04 +09:00
2 changed files with 142 additions and 9 deletions

View File

@@ -26,10 +26,6 @@ if(ND_WFC_BUILD_TESTS)
add_subdirectory(tests)
endif()
if(ND_WFC_BUILD_EXAMPLES)
add_subdirectory(examples)
endif()
# Install configuration temporarily disabled
# TODO: Fix install configuration
# include(GNUInstallDirs)

View File

@@ -638,6 +638,7 @@ concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, Worl
template <typename T, typename VarT>
concept RandomSelectorFunction = requires(T func, std::span<const VarT> possibleValues) {
{ func(possibleValues) } -> std::convertible_to<size_t>;
{ func.rng(static_cast<uint32_t>(1)) } -> std::convertible_to<uint32_t>;
};
/**
@@ -653,11 +654,15 @@ public:
constexpr explicit DefaultRandomSelector(uint32_t seed = 0x12345678) : m_seed(seed) {}
constexpr size_t operator()(std::span<const VarT> possibleValues) const {
if (possibleValues.empty()) return 0;
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
// Simple linear congruential generator for constexpr compatibility
return static_cast<size_t>(rng(possibleValues.size()));
}
constexpr uint32_t rng(uint32_t max) {
m_seed = m_seed * 1103515245 + 12345;
return static_cast<size_t>(m_seed) % possibleValues.size();
return m_seed % max;
}
};
@@ -674,13 +679,141 @@ public:
explicit AdvancedRandomSelector(std::mt19937& rng) : m_rng(rng) {}
size_t operator()(std::span<const VarT> possibleValues) const {
if (possibleValues.empty()) return 0;
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
std::uniform_int_distribution<size_t> dist(0, possibleValues.size() - 1);
return rng(possibleValues.size());
}
uint32_t rng(uint32_t max) {
std::uniform_int_distribution<uint32_t> dist(0, max);
return dist(m_rng);
}
};
/**
* @brief Weight specification for a specific value
* @tparam Value The variable value
* @tparam Weight The 16-bit weight for this value
*/
template <typename VarT, VarT Value, uint16_t WeightValue>
struct Weight {
static constexpr VarT value = Value;
static constexpr uint16_t weight = WeightValue;
};
/**
* @brief Compile-time weights storage for weighted random selection
* @tparam VarT The variable type
* @tparam VariableIDMapT The variable ID map type
* @tparam DefaultWeight The default weight for values not explicitly specified
* @tparam WeightSpecs Variadic template parameters of Weight<VarT, Value, Weight> specifications
*/
template <typename VarT, typename VariableIDMapT, uint16_t DefaultWeight, typename... WeightSpecs>
class WeightsMap {
private:
static constexpr size_t NumWeights = sizeof...(WeightSpecs);
// Helper to get weight for a specific value
static consteval uint16_t GetWeightForValue(VarT targetValue) {
// Check each weight spec to find the target value
uint16_t weight = DefaultWeight;
((WeightSpecs::value == targetValue ? weight = WeightSpecs::weight : weight), ...);
return weight;
}
public:
/**
* @brief Get the weight for a specific value at compile time
* @tparam TargetValue The value to get weight for
* @return The weight for the value
*/
template <VarT TargetValue>
static consteval uint16_t GetWeight() {
return GetWeightForValue(TargetValue);
}
/**
* @brief Get weights array for all registered values
* @return Array of weights corresponding to all registered values
*/
static consteval std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> GetWeightsArray() {
std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> weights{};
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
weights[i] = GetWeightForValue(VariableIDMapT::GetValueConsteval(i));
}
return weights;
}
static consteval uint32_t GetTotalWeight() {
uint32_t totalWeight = 0;
auto weights = GetWeightsArray();
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
totalWeight += weights[i];
}
return totalWeight;
}
static consteval std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> GetCumulativeWeightsArray() {
auto weights = GetWeightsArray();
uint32_t totalWeight = 0;
std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> cumulativeWeights{};
for (size_t i = 0; i < VariableIDMapT::ValuesRegisteredAmount; ++i) {
totalWeight += weights[i];
cumulativeWeights[i] = totalWeight;
}
return cumulativeWeights;
}
};
/**
* @brief Weighted random selector that uses another random selector as backend
* @tparam VarT The variable type
* @tparam VariableIDMapT The variable ID map type
* @tparam BackendSelectorT The backend random selector type
* @tparam WeightsMapT The weights map type containing weight specifications
*/
template <typename VarT, typename VariableIDMapT, typename BackendSelectorT, typename WeightsMapT>
class WeightedSelector {
private:
BackendSelectorT m_backendSelector;
const std::array<uint16_t, VariableIDMapT::ValuesRegisteredAmount> m_weights;
const std::array<uint32_t, VariableIDMapT::ValuesRegisteredAmount> m_cumulativeWeights;
public:
explicit WeightedSelector(BackendSelectorT backendSelector)
: m_backendSelector(backendSelector)
, m_weights(WeightsMapT::GetWeightsArray())
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
{}
explicit WeightedSelector(uint32_t seed)
requires std::is_same_v<BackendSelectorT, DefaultRandomSelector<VarT>>
: m_backendSelector(seed)
, m_weights(WeightsMapT::GetWeightsArray())
, m_cumulativeWeights(WeightsMapT::GetCumulativeWeightsArray())
{}
size_t operator()(std::span<const VarT> possibleValues) const {
constexpr_assert(!possibleValues.empty(), "possibleValues must not be empty");
constexpr_assert(possibleValues.size() == 1, "possibleValues must be a single value");
// Use backend selector to pick a random number in range [0, totalWeight)
uint32_t randomValue = m_backendSelector.rng(m_cumulativeWeights.back());
// Find which value this random value corresponds to
for (size_t i = 0; i < possibleValues.size(); ++i) {
if (randomValue <= m_cumulativeWeights[i]) {
return i;
}
}
// Fallback (should not reach here)
return possibleValues.size() - 1;
}
};
/**
* @brief Builder class for creating WFC instances
*/
@@ -714,6 +847,10 @@ public:
requires RandomSelectorFunction<NewRandomSelectorT, VarT>
using SetRandomSelector = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, NewRandomSelectorT>;
template <uint16_t DefaultWeight, typename... WeightSpecs>
using Weights = Builder<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, WeightedSelector<VarT, VariableIDMapT, RandomSelectorT, WeightsMap<VarT, VariableIDMapT, DefaultWeight, WeightSpecs...>>>;
using Build = WFC<WorldT, VarT, VariableIDMapT, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>;
};