From bc9d7e3b9bbfb63fe3ef359a7dc09c1011e517aa Mon Sep 17 00:00:00 2001 From: cdemeyer-teachx Date: Wed, 10 Sep 2025 12:21:31 +0900 Subject: [PATCH] implementation + tests pass --- demos/sudoku/sudoku.h | 7 +- demos/sudoku/sudoku_wfc.cpp | 7 +- demos/sudoku/test_sudoku.cpp | 2 - include/nd-wfc/wfc.hpp | 58 ++- include/nd-wfc/wfc_bit_container.hpp | 123 +++++-- include/nd-wfc/wfc_builder.hpp | 16 +- include/nd-wfc/wfc_callbacks.hpp | 5 +- include/nd-wfc/wfc_constrainer.hpp | 24 +- include/nd-wfc/wfc_large_integers.hpp | 505 +++++++++++++++++++++++++- include/nd-wfc/wfc_random.hpp | 4 +- include/nd-wfc/wfc_utils.hpp | 17 + include/nd-wfc/wfc_variable_map.hpp | 6 + include/nd-wfc/wfc_wave.hpp | 4 +- 13 files changed, 693 insertions(+), 85 deletions(-) diff --git a/demos/sudoku/sudoku.h b/demos/sudoku/sudoku.h index b15bd8d..8d0ffd2 100644 --- a/demos/sudoku/sudoku.h +++ b/demos/sudoku/sudoku.h @@ -11,7 +11,7 @@ #include #include -#include +#include // 4-bit packed Sudoku board storage - optimal packing // 81 cells * 4 bits = 324 bits @@ -38,7 +38,7 @@ public: uint8_t result = (data[byteIndex] >> shiftAmount) & 0xF; // Debug assertion: ensure result is in valid range - WFC::constexpr_assert(result >= 0 && result <= 9, "Sudoku cell value must be between 0-9"); + WFC::constexpr_assert(result <= 9, "Sudoku cell value must be between 0-9"); return result; } @@ -49,7 +49,7 @@ public: // Optimization: (pos & 1) << 2 instead of (pos % 2) * 4 constexpr inline void set(int pos, uint8_t value) { // Assert that value is in valid Sudoku range (0-9) - WFC::constexpr_assert(value >= 0 && value <= 9, "Sudoku cell value must be between 0-9"); + WFC::constexpr_assert(value <= 9, "Sudoku cell value must be between 0-9"); int byteIndex = pos >> 1; // pos / 2 using right shift @@ -294,6 +294,7 @@ public: // WFC Support // Static assert to ensure correct size (now 56 bytes with solver additions) static_assert(sizeof(Sudoku) == 41, "Sudoku class must be exactly 41 bytes"); +static_assert(WFC::HasConstexprSize, "Sudoku class must have a constexpr size() method"); // Fast solution validator (stateless) class SudokuValidator { diff --git a/demos/sudoku/sudoku_wfc.cpp b/demos/sudoku/sudoku_wfc.cpp index 2bd8827..169dae5 100644 --- a/demos/sudoku/sudoku_wfc.cpp +++ b/demos/sudoku/sudoku_wfc.cpp @@ -58,16 +58,11 @@ using SudokuSolverCallback = SudokuSolverBuilder::SetCellCollapsedCallback ::Build; -Sudoku GetWorldConsteval() -{ - return Sudoku{ "6......3.......7....7463....7.8...2.4...9...1.9...7.8....9851....6.......1......9" }; -} - int main() { std::cout << "Running Sudoku WFC" << std::endl; - Sudoku sudokuWorld = GetWorldConsteval(); + Sudoku sudokuWorld = Sudoku{ "6......3.......7....7463....7.8...2.4...9...1.9...7.8....9851....6.......1......9" }; bool success = SudokuSolverCallback::Run(sudokuWorld, true); diff --git a/demos/sudoku/test_sudoku.cpp b/demos/sudoku/test_sudoku.cpp index 8d238ca..7713813 100644 --- a/demos/sudoku/test_sudoku.cpp +++ b/demos/sudoku/test_sudoku.cpp @@ -288,9 +288,7 @@ void testPuzzleSolving(const std::string& difficulty, const std::string& filenam Sudoku& sudoku = puzzles[i]; EXPECT_TRUE(sudoku.isValid()) << difficulty << " puzzle " << i << " is not valid"; - auto puzzleStart = std::chrono::high_resolution_clock::now(); SudokuSolver::Run(sudoku, allocator); - auto puzzleEnd = std::chrono::high_resolution_clock::now(); EXPECT_TRUE(sudoku.isSolved()) << difficulty << " puzzle " << i << " was not solved. Puzzle string: " << sudoku.toString(); diff --git a/include/nd-wfc/wfc.hpp b/include/nd-wfc/wfc.hpp index 3d845e7..daed5e3 100644 --- a/include/nd-wfc/wfc.hpp +++ b/include/nd-wfc/wfc.hpp @@ -25,23 +25,6 @@ namespace WFC { -template -struct WorldValue -{ -public: - WorldValue() = default; - WorldValue(VarT value, uint16_t internalIndex) - : Value(value) - , InternalIndex(internalIndex) - {} -public: - operator VarT() const { return Value; } - -public: - VarT Value{}; - uint16_t InternalIndex{}; -}; - template concept WorldType = requires(T world, size_t id, typename T::ValueType value) { { world.size() } -> std::convertible_to; @@ -64,15 +47,20 @@ concept ConstrainerFunction = requires(T func, WorldT& world, size_t index, Worl * The function must be callable with parameters: (std::span) and return size_t */ template -concept RandomSelectorFunction = requires(T func, std::span possibleValues) { +concept RandomSelectorFunction = requires(const T& func, std::span possibleValues) { { func(possibleValues) } -> std::convertible_to; { func.rng(static_cast(1)) } -> std::convertible_to; }; +template +concept HasConstexprSize = requires { + { []() constexpr -> std::size_t { return WorldT{}.size(); }() }; +}; + /** * @brief Main WFC class implementing the Wave Function Collapse algorithm */ -template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, @@ -81,14 +69,19 @@ class WFC { public: static_assert(WorldType, "WorldT must satisfy World type requirements"); - using ElementT = typename VariableIDMapT::ElementT; + // Try getting the world size, which is only available if the world type has a constexpr size() method + constexpr static size_t WorldSize = HasConstexprSize ? WorldT{}.size() : 0; + + using WaveType = Wave; + using ConstrainerType = Constrainer; + using MaskType = typename WaveType::ElementT; public: struct SolverState { WorldT& world; WFCQueue propagationQueue; - Wave wave; + WaveType wave; std::mt19937& rng; RandomSelectorT& randomSelector; WFCStackAllocator& allocator; @@ -97,7 +90,7 @@ public: SolverState(WorldT& world, size_t variableAmount, std::mt19937& rng, RandomSelectorT& randomSelector, WFCStackAllocator& allocator, size_t& iterations) : world(world) , propagationQueue{ WFCStackAllocatorAdapter(allocator) } - , wave{ world.size(), variableAmount, allocator } + , wave{ WorldSize, variableAmount, allocator } , rng(rng) , randomSelector(randomSelector) , allocator(allocator) @@ -111,6 +104,7 @@ public: WFC() = delete; // dont make an instance of this class, only use the static methods. public: + static bool Run(WorldT& world, uint32_t seed = std::random_device{}()) { WFCStackAllocator allocator{}; @@ -134,10 +128,12 @@ public: allocator, iterations }; - return Run(state); + bool result = Run(state); allocator.reset(); constexpr_assert(allocator.getUsed() == 0, "Allocator must be empty"); + + return result; } /** @@ -209,7 +205,7 @@ public: static const std::vector GetPossibleValues(SolverState& state, int cellId) { std::vector possibleValues; - ElementT mask = state.wave.GetMask(cellId); + MaskType mask = state.wave.GetMask(cellId); for (size_t i = 0; i < ConstrainerFunctionMapT::size(); ++i) { if (mask & (1 << i)) possibleValues.push_back(VariableIDMapT::GetValue(i)); } @@ -219,7 +215,7 @@ public: private: static void CollapseCell(SolverState& state, size_t cellId, uint16_t value) { - constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (ElementT(1) << value)); + constexpr_assert(!state.wave.IsCollapsed(cellId) || state.wave.GetMask(cellId) == (MaskType(1) << value)); state.wave.Collapse(cellId, 1 << value); constexpr_assert(state.wave.IsCollapsed(cellId)); @@ -252,14 +248,14 @@ private: // create a list of possible values uint16_t availableValues = static_cast(state.wave.Entropy(minEntropyCell)); std::array possibleValues; // inplace vector - ElementT mask = state.wave.GetMask(minEntropyCell); + MaskType mask = state.wave.GetMask(minEntropyCell); for (size_t i = 0; i < availableValues; ++i) { uint16_t index = static_cast(std::countr_zero(mask)); // get the index of the lowest set bit constexpr_assert(index < VariableIDMapT::ValuesRegisteredAmount, "Possible value went outside bounds"); possibleValues[i] = index; - constexpr_assert(((mask & (ElementT(1) << index)) != 0), "Possible value was not set"); + constexpr_assert(((mask & (MaskType(1) << index)) != 0), "Possible value was not set"); mask = mask & (mask - 1); // turn off lowest set bit } @@ -293,9 +289,9 @@ private: } // remove the failure state from the wave - constexpr_assert((state.wave.GetMask(minEntropyCell) & (ElementT(1) << selectedValue)) != 0, "Possible value was not set"); + constexpr_assert((state.wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) != 0, "Possible value was not set"); state.wave.Collapse(minEntropyCell, ~(1 << selectedValue)); - constexpr_assert((state.wave.GetMask(minEntropyCell) & (ElementT(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); + constexpr_assert((state.wave.GetMask(minEntropyCell) & (MaskType(1) << selectedValue)) == 0, "Wave was not collapsed correctly"); // swap replacement value with the last value std::swap(possibleValues[randomIndex], possibleValues[--availableValues]); @@ -316,9 +312,9 @@ private: constexpr_assert(state.wave.IsCollapsed(cellId), "Cell was not collapsed"); uint16_t variableID = state.wave.GetVariableID(cellId); - Constrainer constrainer(state.wave, state.propagationQueue); + ConstrainerType constrainer(state.wave, state.propagationQueue); - using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue, Constrainer&); + using ConstrainerFunctionPtrT = void(*)(WorldT&, size_t, WorldValue, ConstrainerType&); ConstrainerFunctionMapT::template GetFunction(variableID)(state.world, cellId, WorldValue{VariableIDMapT::GetValue(variableID), variableID}, constrainer); } diff --git a/include/nd-wfc/wfc_bit_container.hpp b/include/nd-wfc/wfc_bit_container.hpp index 60b322f..509feb4 100644 --- a/include/nd-wfc/wfc_bit_container.hpp +++ b/include/nd-wfc/wfc_bit_container.hpp @@ -6,9 +6,11 @@ #include #include #include +#include #include "wfc_utils.hpp" #include "wfc_allocator.hpp" +#include "wfc_large_integers.hpp" namespace WFC { @@ -38,7 +40,7 @@ namespace detail { static constexpr size_t StorageBits = OptimalStorageType::bits_needed; static constexpr size_t ArraySize = StorageBits > 64 ? (StorageBits / 64) : 1; using element_type = std::conditional_t::type, uint64_t>; - using type = std::conditional_t>; + using type = std::conditional_t>; }; struct Empty{}; @@ -57,7 +59,6 @@ public: static constexpr bool IsResizable = (Size == 0); static constexpr bool IsMultiElement = (StorageBits > 64); static constexpr bool IsSubByte = (StorageBits < 8); - static constexpr bool IsDefaultByteLayout = !IsMultiElement && !IsSubByte; static constexpr size_t ElementsPerByte = sizeof(StorageType) * 8 / std::max(1u, StorageBits); using ContainerType = @@ -90,15 +91,30 @@ private: static constexpr uint64_t Mask = get_Mask(); +public: + static constexpr StorageType GetWaveMask() + { + return (StorageType{1} << BitsPerElement) - 1; + } + + static constexpr StorageType GetMask(std::span indices) + { + StorageType mask = 0; + for (const auto& index : indices) { + mask |= (StorageType{1} << index); + } + return mask; + } + public: BitContainer() = default; - BitContainer(AllocatorT& allocator) : AllocatorT(allocator) {}; - explicit BitContainer(size_t initial_size, AllocatorT& allocator) requires (IsResizable) + BitContainer(const AllocatorT& allocator) : AllocatorT(allocator) {}; + explicit BitContainer(size_t initial_size, const AllocatorT& allocator) requires (IsResizable) : AllocatorT(allocator) , m_container(initial_size, allocator) {}; - explicit BitContainer(size_t, AllocatorT& allocator) requires (!IsResizable) + explicit BitContainer(size_t, const AllocatorT& allocator) requires (!IsResizable) : AllocatorT(allocator) , m_container() {}; @@ -131,13 +147,15 @@ public: // Sub byte constexpr uint8_t SetValue(uint8_t val) { Clear(); return Data |= ((val & Mask) << Shift); } constexpr void Clear() { Data &= ~Mask; } + + constexpr SubTypeAccess& operator=(uint8_t other) { return SetValue(other); } constexpr operator uint8_t() const { return GetValue(); } - template constexpr uint8_t operator&=(T other) { return SetValue(GetValue() & other); } - template constexpr uint8_t operator|=(T other) { return SetValue(GetValue() | other); } - template constexpr uint8_t operator^=(T other) { return SetValue(GetValue() ^ other); } - template constexpr uint8_t operator<<=(T other) { return SetValue(GetValue() << other); } - template constexpr uint8_t operator>>=(T other) { return SetValue(GetValue() >> other); } + constexpr SubTypeAccess& operator&=(uint8_t other) { return SetValue(GetValue() & other); } + constexpr SubTypeAccess& operator|=(uint8_t other) { return SetValue(GetValue() | other); } + constexpr SubTypeAccess& operator^=(uint8_t other) { return SetValue(GetValue() ^ other); } + constexpr SubTypeAccess& operator<<=(uint8_t other) { return SetValue(GetValue() << other); } + constexpr SubTypeAccess& operator>>=(uint8_t other) { return SetValue(GetValue() >> other); } uint8_t& Data; uint8_t Shift; @@ -146,23 +164,86 @@ public: // Sub byte constexpr const SubTypeAccess operator[](size_t index) const requires(IsSubByte) { return SubTypeAccess{data()[index / ElementsPerByte], index & ElementsPerByte }; } constexpr SubTypeAccess operator[](size_t index) requires(IsSubByte) { return SubTypeAccess{data()[index / ElementsPerByte], index & ElementsPerByte }; } -public: // MultiElement - struct MultiElementAccess - { - constexpr MultiElementAccess(StorageType& data) : Data{ data } {}; +public: // default + constexpr const StorageType& operator[](size_t index) const requires(!IsSubByte) { return data()[index]; } + constexpr StorageType& operator[](size_t index) requires(!IsSubByte) { return data()[index]; } - StorageType& Data; +public: // iterators + template + class BitIterator { + public: + // Iterator traits + using iterator_category = std::random_access_iterator_tag; + using value_type = StorageType; + using difference_type = std::ptrdiff_t; + using pointer = std::conditional_t; + using reference = std::conditional_t; + + private: + using ContainerType = std::conditional_t; + + ContainerType* m_container{}; + size_t m_index{}; + + public: + // Constructor + constexpr BitIterator() = default; + constexpr BitIterator(ContainerType& container, size_t index) : m_container(&container), m_index(index) {} + + // Dereference + constexpr reference operator*() const { return (*m_container)[m_index]; } + constexpr pointer operator->() const { return &(*m_container)[m_index]; } + + // Element access + constexpr reference operator[](difference_type n) const { return (*m_container)[m_index + n]; } + + // Increment / Decrement + constexpr BitIterator& operator++() { ++m_index; return *this; } + constexpr BitIterator operator++(int) { BitIterator tmp = *this; ++m_index; return tmp; } + constexpr BitIterator& operator--() { --m_index; return *this; } + constexpr BitIterator operator--(int) { BitIterator tmp = *this; --m_index; return tmp; } + + // Arithmetic + constexpr BitIterator operator+(difference_type n) const { return BitIterator(*m_container, m_index + n); } + constexpr BitIterator operator-(difference_type n) const { return BitIterator(*m_container, m_index - n); } + constexpr difference_type operator-(const BitIterator& other) const { return static_cast(m_index) - static_cast(other.m_index); } + + // Assignment + constexpr BitIterator& operator+=(difference_type n) { m_index += n; return *this; } + constexpr BitIterator& operator-=(difference_type n) { m_index -= n; return *this; } + + // Comparison + constexpr bool operator==(const BitIterator& other) const { return m_index == other.m_index; } + constexpr bool operator!=(const BitIterator& other) const { return m_index != other.m_index; } + constexpr bool operator<(const BitIterator& other) const { return m_index < other.m_index; } + constexpr bool operator>(const BitIterator& other) const { return m_index > other.m_index; } + constexpr bool operator<=(const BitIterator& other) const { return m_index <= other.m_index; } + constexpr bool operator>=(const BitIterator& other) const { return m_index >= other.m_index; } + + // Conversion from non-const to const iterator + constexpr operator BitIterator() const { + return BitIterator(*m_container, m_index); + } }; - constexpr const MultiElementAccess operator[](size_t index) const requires(IsMultiElement) { return MultiElementAccess{data()[index]}; } - constexpr MultiElementAccess operator[](size_t index) requires(IsMultiElement) { return MultiElementAccess{data()[index]}; } - -public: // default - constexpr const StorageType& operator[](size_t index) const requires(IsDefaultByteLayout) { return data()[index]; } - constexpr StorageType& operator[](size_t index) requires(IsDefaultByteLayout) { return data()[index]; } + // Type aliases for convenience + using ConstIterator = BitIterator; + using Iterator = BitIterator; + constexpr Iterator begin() { return Iterator{*this, 0}; } + constexpr Iterator end() { return Iterator{*this, size()}; } + constexpr const ConstIterator begin() const { return ConstIterator{*this, 0}; } + constexpr const ConstIterator end() const { return ConstIterator{*this, size()}; } }; +// Free function for iterator addition +template ::type>, bool IsConst> +BitContainer::BitIterator operator+( + typename BitContainer::template BitIterator::difference_type n, + const typename BitContainer::template BitIterator& it) { + return it + n; +} + static_assert(BitContainer<1, 10>::ElementsPerByte == 8); static_assert(BitContainer<2, 10>::ElementsPerByte == 4); static_assert(BitContainer<4, 10>::ElementsPerByte == 2); diff --git a/include/nd-wfc/wfc_builder.hpp b/include/nd-wfc/wfc_builder.hpp index b30697f..b85366e 100644 --- a/include/nd-wfc/wfc_builder.hpp +++ b/include/nd-wfc/wfc_builder.hpp @@ -12,22 +12,32 @@ namespace WFC { /** * @brief Builder class for creating WFC instances */ -template, typename ConstrainerFunctionMapT = ConstrainerFunctionMap, typename CallbacksT = Callbacks, typename RandomSelectorT = DefaultRandomSelector> +template< + typename WorldT, + typename VarT = typename WorldT::ValueType, + typename VariableIDMapT = VariableIDMap, + typename ConstrainerFunctionMapT = ConstrainerFunctionMap, + typename CallbacksT = Callbacks, + typename RandomSelectorT = DefaultRandomSelector> class Builder { public: + constexpr static size_t WorldSize = HasConstexprSize ? WorldT{}.size() : 0; + + using WaveType = Wave; + using ConstrainerType = Constrainer; template using DefineIDs = Builder, ConstrainerFunctionMapT, CallbacksT, RandomSelectorT>; template - requires ConstrainerFunction + requires ConstrainerFunction using DefineConstrainer = Builder, - decltype([](WorldT&, size_t, WorldValue, Constrainer&) {}) + decltype([](WorldT&, size_t, WorldValue, ConstrainerType&) {}) >, CallbacksT, RandomSelectorT >; diff --git a/include/nd-wfc/wfc_callbacks.hpp b/include/nd-wfc/wfc_callbacks.hpp index f3a02ce..8965b58 100644 --- a/include/nd-wfc/wfc_callbacks.hpp +++ b/include/nd-wfc/wfc_callbacks.hpp @@ -7,7 +7,10 @@ namespace WFC { * @param WorldT The world type */ template -using EmptyCallback = decltype([](WorldT&){}); +struct EmptyCallback +{ + void operator()(WorldT&) const {} +}; /** * @brief Callback struct diff --git a/include/nd-wfc/wfc_constrainer.hpp b/include/nd-wfc/wfc_constrainer.hpp index 100ae59..d40635e 100644 --- a/include/nd-wfc/wfc_constrainer.hpp +++ b/include/nd-wfc/wfc_constrainer.hpp @@ -61,13 +61,15 @@ using MergedConstrainerFunctionMap = decltype( /** * @brief Constrainer class used in constraint functions to limit possible values for other cells */ -template +template class Constrainer { public: - using MaskType = typename VariableIDMapT::MaskType; + using IDMapT = typename WaveT::IDMapT; + using BitContainerT = typename WaveT::BitContainerT; + using MaskType = typename BitContainerT::StorageType; public: - Constrainer(Wave& wave, WFCQueue& propagationQueue) + Constrainer(WaveT& wave, WFCQueue& propagationQueue) : m_wave(wave) , m_propagationQueue(propagationQueue) {} @@ -77,13 +79,14 @@ public: * @param cellId The ID of the cell to constrain * @param forbiddenValues The set of forbidden values for this cell */ - template + template void Exclude(size_t cellId) { static_assert(sizeof...(ExcludedValues) > 0, "At least one excluded value must be provided"); - ApplyMask(cellId, ~VariableIDMapT::template GetMask()); + auto indices = IDMapT::template ValuesToIndices(); + ApplyMask(cellId, ~BitContainerT::GetMask(indices)); } - void Exclude(WorldValue value, size_t cellId) { + void Exclude(WorldValue value, size_t cellId) { ApplyMask(cellId, ~(1 << value.InternalIndex)); } @@ -92,13 +95,14 @@ public: * @param cellId The ID of the cell to constrain * @param value The only allowed value for this cell */ - template + template void Only(size_t cellId) { static_assert(sizeof...(AllowedValues) > 0, "At least one allowed value must be provided"); - ApplyMask(cellId, VariableIDMapT::template GetMask()); + auto indices = IDMapT::template ValuesToIndices(); + ApplyMask(cellId, BitContainerT::GetMask(indices)); } - void Only(WorldValue value, size_t cellId) { + void Only(WorldValue value, size_t cellId) { ApplyMask(cellId, 1 << value.InternalIndex); } @@ -115,7 +119,7 @@ private: } private: - Wave& m_wave; + WaveT& m_wave; WFCQueue& m_propagationQueue; }; diff --git a/include/nd-wfc/wfc_large_integers.hpp b/include/nd-wfc/wfc_large_integers.hpp index eca0afa..6a2484d 100644 --- a/include/nd-wfc/wfc_large_integers.hpp +++ b/include/nd-wfc/wfc_large_integers.hpp @@ -1,22 +1,517 @@ #pragma once #include +#include +#include +#include +#include +#include + +// Detect __uint128_t support +#if (defined(__SIZEOF_INT128__) || defined(__INTEL_COMPILER) || (defined(__GNUC__) && __GNUC__ >= 4)) && !defined(_MSC_VER) +#define WFC_HAS_UINT128 1 +#else +#define WFC_HAS_UINT128 0 +#endif namespace WFC { template struct LargeInteger { + static_assert(Size > 0, "Size must be greater than 0"); + std::array m_data; + // Constructors + constexpr LargeInteger() = default; + constexpr LargeInteger(const LargeInteger&) = default; + constexpr LargeInteger(LargeInteger&&) = default; + constexpr LargeInteger& operator=(const LargeInteger&) = default; + constexpr LargeInteger& operator=(LargeInteger&&) = default; + + // Constructor from uint64_t (for small values) + template && std::is_unsigned_v>> + constexpr explicit LargeInteger(T value) { + m_data.fill(0); + if constexpr (sizeof(T) <= sizeof(uint64_t)) { + m_data[0] = static_cast(value); + } else { + // Handle larger types if needed + static_assert(sizeof(T) <= sizeof(uint64_t), "Type too large for LargeInteger"); + } + } + + // Access operators + constexpr uint64_t& operator[](size_t index) { return m_data[index]; } + constexpr const uint64_t& operator[](size_t index) const { return m_data[index]; } + + // Helper function to get the larger size type template - constexpr LargeInteger operator+(const LargeInteger& other) const { - LargeInteger result; - for (size_t i = 0; i < std::max(Size, OtherSize); i++) { - result[i] = m_data[i] + other[i]; + using LargerType = LargeInteger; + + // Helper function to promote operands to the same size + template + constexpr auto promote(const LargeInteger& other) const { + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger lhs_promoted{}; + LargeInteger rhs_promoted{}; + + // Copy data, padding with zeros + for (size_t i = 0; i < Size; ++i) { + lhs_promoted[i] = m_data[i]; + } + for (size_t i = 0; i < OtherSize; ++i) { + rhs_promoted[i] = other[i]; + } + + return std::make_pair(lhs_promoted, rhs_promoted); + } + + // Arithmetic operators + template + constexpr LargerType operator+(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result{}; + + uint64_t carry = 0; + for (size_t i = 0; i < ResultSize; ++i) { + uint64_t sum = lhs[i] + rhs[i] + carry; + result[i] = sum; + carry = (sum < lhs[i] || (carry && sum == lhs[i])) ? 1 : 0; + } + + return result; + } + + template + constexpr LargeInteger& operator+=(const LargeInteger& other) { + *this = *this + other; + return *this; + } + + template + constexpr LargerType operator-(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result{}; + + uint64_t borrow = 0; + for (size_t i = 0; i < ResultSize; ++i) { + uint64_t diff = lhs[i] - rhs[i] - borrow; + result[i] = diff; + borrow = (lhs[i] < rhs[i] + borrow) ? 1 : 0; + } + + return result; + } + + template + constexpr LargeInteger& operator-=(const LargeInteger& other) { + *this = *this - other; + return *this; + } + + template + constexpr LargerType operator*(const LargeInteger& other) const { +#if WFC_HAS_UINT128 + auto [lhs, rhs] = promote(other); + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result{}; // Multiplication can double the size + + for (size_t i = 0; i < ResultSize; ++i) { + uint64_t carry = 0; + for (size_t j = 0; j < ResultSize; ++j) { + __uint128_t product = static_cast<__uint128_t>(lhs[i]) * rhs[j] + result[i + j] + carry; + result[i + j] = static_cast(product); + carry = product >> 64; + } + size_t k = i + ResultSize; + while (carry && k < ResultSize * 2) { + __uint128_t sum = result[k] + carry; + result[k] = static_cast(sum); + carry = sum >> 64; + ++k; + } + } + + // Truncate to the larger of the original sizes + LargeInteger final_result{}; + for (size_t i = 0; i < ResultSize; ++i) { + final_result[i] = result[i]; + } + return final_result; +#else + throw std::runtime_error("LargeInteger multiplication requires __uint128_t support, which is not available on this compiler/platform"); +#endif + } + + template + constexpr LargeInteger& operator*=(const LargeInteger& other) { + *this = *this * other; + return *this; + } + + // Division and modulo (simplified implementation) + template + constexpr LargerType operator/(const LargeInteger& other) const { + // Simplified division - assumes other is not zero and result fits + auto [lhs, rhs] = promote(other); + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result{}; + + // This is a very basic division implementation + // For a full implementation, you'd need proper long division + LargeInteger temp = lhs; + while (temp >= rhs) { + temp = temp - rhs; + result = result + LargeInteger{1}; + } + + return result; + } + + template + constexpr LargerType operator%(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger temp = lhs; + while (temp >= rhs) { + temp = temp - rhs; + } + return temp; + } + + // Unary operators + constexpr LargeInteger operator-() const { + LargeInteger result{}; + for (size_t i = 0; i < Size; ++i) { + result[i] = ~m_data[i] + 1; // Two's complement + } + return result; + } + + constexpr LargeInteger operator~() const { + LargeInteger result{}; + for (size_t i = 0; i < Size; ++i) { + result[i] = ~m_data[i]; + } + return result; + } + + // Bit operations + template + constexpr LargerType operator&(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + return lhs.bitwise_op(rhs, std::bit_and{}); + } + + template + constexpr LargeInteger& operator&=(const LargeInteger& other) { + *this = *this & other; + return *this; + } + + template + constexpr LargerType operator|(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + return lhs.bitwise_op(rhs, std::bit_or{}); + } + + template + constexpr LargeInteger& operator|=(const LargeInteger& other) { + *this = *this | other; + return *this; + } + + template + constexpr LargerType operator^(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + return lhs.bitwise_op(rhs, std::bit_xor{}); + } + + template + constexpr LargeInteger& operator^=(const LargeInteger& other) { + *this = *this ^ other; + return *this; + } + + template + constexpr LargerType operator<<(size_t shift) const { + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result = *this; + + size_t word_shift = shift / 64; + size_t bit_shift = shift % 64; + + if (word_shift >= ResultSize) { + result.m_data.fill(0); + return result; + } + + // Shift words + for (size_t i = ResultSize - 1; i >= word_shift; --i) { + result[i] = result[i - word_shift]; + } + for (size_t i = 0; i < word_shift; ++i) { + result[i] = 0; + } + + // Shift bits + if (bit_shift > 0) { + uint64_t carry = 0; + for (size_t i = word_shift; i < ResultSize; ++i) { + uint64_t new_carry = result[i] >> (64 - bit_shift); + result[i] = (result[i] << bit_shift) | carry; + carry = new_carry; + } + } + + return result; + } + + template + constexpr LargeInteger& operator<<=(size_t shift) { + *this = *this << shift; + return *this; + } + + template + constexpr LargerType operator>>(size_t shift) const { + constexpr size_t ResultSize = std::max(Size, OtherSize); + LargeInteger result = *this; + + size_t word_shift = shift / 64; + size_t bit_shift = shift % 64; + + if (word_shift >= ResultSize) { + result.m_data.fill(0); + return result; + } + + // Shift words + for (size_t i = 0; i < ResultSize - word_shift; ++i) { + result[i] = result[i + word_shift]; + } + for (size_t i = ResultSize - word_shift; i < ResultSize; ++i) { + result[i] = 0; + } + + // Shift bits + if (bit_shift > 0) { + uint64_t carry = 0; + for (size_t i = ResultSize - word_shift - 1; i < ResultSize; --i) { + uint64_t new_carry = result[i] << (64 - bit_shift); + result[i] = (result[i] >> bit_shift) | carry; + carry = new_carry; + if (i == 0) break; + } + } + + return result; + } + + template + constexpr LargeInteger& operator>>=(size_t shift) { + *this = *this >> shift; + return *this; + } + + // Comparison operators + template + constexpr bool operator==(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + return lhs.m_data == rhs.m_data; + } + + template + constexpr bool operator!=(const LargeInteger& other) const { + return !(*this == other); + } + + template + constexpr bool operator<(const LargeInteger& other) const { + auto [lhs, rhs] = promote(other); + for (size_t i = lhs.m_data.size(); i > 0; --i) { + if (lhs.m_data[i-1] != rhs.m_data[i-1]) { + return lhs.m_data[i-1] < rhs.m_data[i-1]; + } + } + return false; + } + + template + constexpr bool operator<=(const LargeInteger& other) const { + return *this < other || *this == other; + } + + template + constexpr bool operator>(const LargeInteger& other) const { + return other < *this; + } + + template + constexpr bool operator>=(const LargeInteger& other) const { + return other <= *this; + } + + // std::bit library functions + constexpr int countl_zero() const { + for (size_t i = Size; i > 0; --i) { + if (m_data[i-1] != 0) { + return std::countl_zero(m_data[i-1]) + (Size - i) * 64; + } + } + return Size * 64; + } + + constexpr int countl_one() const { + for (size_t i = Size; i > 0; --i) { + if (m_data[i-1] != std::numeric_limits::max()) { + return std::countl_one(m_data[i-1]) + (Size - i) * 64; + } + } + return Size * 64; + } + + constexpr int countr_zero() const { + for (size_t i = 0; i < Size; ++i) { + if (m_data[i] != 0) { + return std::countr_zero(m_data[i]) + i * 64; + } + } + return Size * 64; + } + + constexpr int countr_one() const { + for (size_t i = 0; i < Size; ++i) { + if (m_data[i] != std::numeric_limits::max()) { + return std::countr_one(m_data[i]) + i * 64; + } + } + return Size * 64; + } + + constexpr int popcount() const { + int count = 0; + for (size_t i = 0; i < Size; ++i) { + count += std::popcount(m_data[i]); + } + return count; + } + + template + constexpr LargerType rotl(size_t shift) const { + shift %= (Size * 64); + return (*this << shift) | (*this >> ((Size * 64) - shift)); + } + + template + constexpr LargerType rotr(size_t shift) const { + shift %= (Size * 64); + return (*this >> shift) | (*this << ((Size * 64) - shift)); + } + + constexpr bool has_single_bit() const { + return popcount() == 1; + } + + constexpr LargeInteger bit_ceil() const { + if (*this == LargeInteger{0}) return LargeInteger{1}; + + LargeInteger result = *this; + result -= LargeInteger{1}; + result |= result >> 1; + result |= result >> 2; + result |= result >> 4; + result |= result >> 8; + result |= result >> 16; + result |= result >> 32; + + // Handle multi-word case + for (size_t i = 1; i < Size; ++i) { + if (result[i] != 0) { + // Find the highest set bit in the higher words + size_t highest_word = Size - 1; + for (size_t j = Size - 1; j > 0; --j) { + if (result[j] != 0) { + highest_word = j; + break; + } + } + // Set all lower words to 0 and the highest word to the power of 2 + for (size_t j = 0; j < highest_word; ++j) { + result[j] = 0; + } + result[highest_word] = uint64_t(1) << (63 - std::countl_zero(result[highest_word])); + break; + } + } + + result += LargeInteger{1}; + return result; + } + + constexpr LargeInteger bit_floor() const { + if (*this == LargeInteger{0}) return LargeInteger{0}; + + LargeInteger result = *this; + result |= result >> 1; + result |= result >> 2; + result |= result >> 4; + result |= result >> 8; + result |= result >> 16; + result |= result >> 32; + + // Handle multi-word case + for (size_t i = 1; i < Size; ++i) { + if (result[i] != 0) { + size_t highest_word = Size - 1; + for (size_t j = Size - 1; j > 0; --j) { + if (result[j] != 0) { + highest_word = j; + break; + } + } + for (size_t j = 0; j < highest_word; ++j) { + result[j] = 0; + } + result[highest_word] = uint64_t(1) << (63 - std::countl_zero(result[highest_word])); + return result; + } + } + + // Single word case + result = LargeInteger{uint64_t(1) << (63 - std::countl_zero(result[0]))}; + return result; + } + + constexpr int bit_width() const { + if (*this == LargeInteger{0}) return 0; + + for (size_t i = Size; i > 0; --i) { + if (m_data[i-1] != 0) { + return (i - 1) * 64 + 64 - std::countl_zero(m_data[i-1]); + } + } + return 0; + } + +private: + // Helper function for bitwise operations + template + constexpr LargeInteger bitwise_op(const LargeInteger& other, Op op) const { + LargeInteger result{}; + for (size_t i = 0; i < Size; ++i) { + result[i] = op(m_data[i], other[i]); } return result; } }; -} \ No newline at end of file +// Deduction guide for constructor from integral types +template +LargeInteger(T) -> LargeInteger<1>; + +} // namespace WFC \ No newline at end of file diff --git a/include/nd-wfc/wfc_random.hpp b/include/nd-wfc/wfc_random.hpp index 92cada7..fb713ba 100644 --- a/include/nd-wfc/wfc_random.hpp +++ b/include/nd-wfc/wfc_random.hpp @@ -21,7 +21,7 @@ public: return static_cast(rng(possibleValues.size())); } - constexpr uint32_t rng(uint32_t max) { + constexpr uint32_t rng(uint32_t max) const { m_seed = m_seed * 1103515245 + 12345; return m_seed % max; } @@ -45,7 +45,7 @@ public: return rng(possibleValues.size()); } - uint32_t rng(uint32_t max) { + uint32_t rng(uint32_t max) const { std::uniform_int_distribution dist(0, max); return dist(m_rng); } diff --git a/include/nd-wfc/wfc_utils.hpp b/include/nd-wfc/wfc_utils.hpp index 191020c..26a955d 100644 --- a/include/nd-wfc/wfc_utils.hpp +++ b/include/nd-wfc/wfc_utils.hpp @@ -24,4 +24,21 @@ inline int FindNthSetBit(size_t num, int n) { return bitCount; } +template +struct WorldValue +{ +public: + WorldValue() = default; + WorldValue(VarT value, uint16_t internalIndex) + : Value(value) + , InternalIndex(internalIndex) + {} +public: + operator VarT() const { return Value; } + +public: + VarT Value{}; + uint16_t InternalIndex{}; +}; + } \ No newline at end of file diff --git a/include/nd-wfc/wfc_variable_map.hpp b/include/nd-wfc/wfc_variable_map.hpp index ab998c0..a4d553c 100644 --- a/include/nd-wfc/wfc_variable_map.hpp +++ b/include/nd-wfc/wfc_variable_map.hpp @@ -67,6 +67,12 @@ public: } static consteval size_t size() { return ValuesRegisteredAmount; } + + template + static constexpr auto ValuesToIndices() -> std::array { + std::array indices = {GetIndex()...}; + return indices; + } }; diff --git a/include/nd-wfc/wfc_wave.hpp b/include/nd-wfc/wfc_wave.hpp index 7456244..7ace7b2 100644 --- a/include/nd-wfc/wfc_wave.hpp +++ b/include/nd-wfc/wfc_wave.hpp @@ -11,10 +11,12 @@ class Wave { public: using BitContainerT = BitContainer; using ElementT = typename BitContainerT::StorageType; + using IDMapT = VariableIDMapT; + static constexpr size_t ElementsAmount = Size; public: Wave() = default; - Wave(size_t size, size_t variableAmount, WFCStackAllocator& allocator) : m_data(size, WFCStackAllocatorAdapter(allocator)) + Wave(size_t size, size_t variableAmount, WFCStackAllocator& allocator) : m_data(size, allocator) { for (auto& wave : m_data) wave = (1 << variableAmount) - 1; }