From 24774006e7132391b5579c6fdadcc7ab33b31d5f Mon Sep 17 00:00:00 2001
From: Tobias Frisch <tfrisch@uni-koblenz.de>
Date: Tue, 29 Nov 2022 22:26:52 +0100
Subject: [PATCH] Allow compaction of bottom-level acceleration structures to
 improve memory usage

Signed-off-by: Tobias Frisch <tfrisch@uni-koblenz.de>
---
 include/vkcv/Core.hpp                     |   3 +-
 modules/scene/src/vkcv/scene/Mesh.cpp     |   3 +-
 src/vkcv/AccelerationStructureManager.cpp | 194 ++++++++++++++++++++--
 src/vkcv/AccelerationStructureManager.hpp |   7 +-
 src/vkcv/Core.cpp                         |   6 +-
 5 files changed, 191 insertions(+), 22 deletions(-)

diff --git a/include/vkcv/Core.hpp b/include/vkcv/Core.hpp
index 985c0fc6..00f38864 100644
--- a/include/vkcv/Core.hpp
+++ b/include/vkcv/Core.hpp
@@ -969,7 +969,8 @@ namespace vkcv {
 		 */
 		AccelerationStructureHandle createAccelerationStructure(
 				const std::vector<GeometryData> &geometryData,
-				const BufferHandle &transformBuffer = {});
+				const BufferHandle &transformBuffer = {},
+				bool compaction = false);
 		
 		/**
 		 * @brief Creates an acceleration structure handle built with a given list of
diff --git a/modules/scene/src/vkcv/scene/Mesh.cpp b/modules/scene/src/vkcv/scene/Mesh.cpp
index 03eb8a59..aa31bb0b 100644
--- a/modules/scene/src/vkcv/scene/Mesh.cpp
+++ b/modules/scene/src/vkcv/scene/Mesh.cpp
@@ -146,7 +146,8 @@ namespace vkcv::scene {
 		
 		const AccelerationStructureHandle handle = core.createAccelerationStructure(
 				geometryData,
-				transformBuffer.getHandle()
+				transformBuffer.getHandle(),
+				true
 		);
 		
 		if (handle) {
diff --git a/src/vkcv/AccelerationStructureManager.cpp b/src/vkcv/AccelerationStructureManager.cpp
index 8c51a9ad..18ab457e 100644
--- a/src/vkcv/AccelerationStructureManager.cpp
+++ b/src/vkcv/AccelerationStructureManager.cpp
@@ -28,7 +28,7 @@ namespace vkcv {
 		auto &accelerationStructure = getById(id);
 		
 		if (accelerationStructure.m_accelerationStructure) {
-			getCore().getContext().getDevice().destroyAccelerationStructureKHR(
+			getCore().getContext().getDevice().destroy(
 					accelerationStructure.m_accelerationStructure,
 					nullptr,
 					getCore().getContext().getDispatchLoaderDynamic()
@@ -37,13 +37,13 @@ namespace vkcv {
 			accelerationStructure.m_accelerationStructure = nullptr;
 		}
 		
-		if (accelerationStructure.m_storageBuffer) {
-			accelerationStructure.m_storageBuffer = BufferHandle();
-		}
-		
 		if (!accelerationStructure.m_children.empty()) {
 			accelerationStructure.m_children.clear();
 		}
+		
+		if (accelerationStructure.m_storageBuffer) {
+			accelerationStructure.m_storageBuffer = BufferHandle();
+		}
 	}
 	
 	const BufferManager &AccelerationStructureManager::getBufferManager() const {
@@ -120,9 +120,10 @@ namespace vkcv {
 			BufferManager& bufferManager,
 			std::vector<vk::AccelerationStructureBuildGeometryInfoKHR> &geometryInfos,
 			const std::vector<vk::AccelerationStructureBuildRangeInfoKHR> &rangeInfos,
+			vk::AccelerationStructureTypeKHR accelerationStructureType,
 			size_t accelerationStructureSize,
 			size_t scratchBufferSize,
-			vk::AccelerationStructureTypeKHR accelerationStructureType) {
+			const vk::QueryPool &compactionQueryPool) {
 		const auto &dynamicDispatch = core.getContext().getDispatchLoaderDynamic();
 		const vk::PhysicalDevice &physicalDevice = core.getContext().getPhysicalDevice();
 		
@@ -197,13 +198,21 @@ namespace vkcv {
 		
 		core.recordCommandsToStream(
 				cmdStream,
-				[&geometryInfos, &pRangeInfos, &dynamicDispatch](
+				[&geometryInfos, &pRangeInfos, &compactionQueryPool, &dynamicDispatch](
 						const vk::CommandBuffer &cmdBuffer) {
 					const vk::MemoryBarrier barrier (
 							vk::AccessFlagBits::eAccelerationStructureWriteKHR,
 							vk::AccessFlagBits::eAccelerationStructureReadKHR
 					);
 					
+					if (compactionQueryPool) {
+						cmdBuffer.resetQueryPool(
+								compactionQueryPool,
+								0,
+								geometryInfos.size()
+						);
+					}
+					
 					for (size_t i = 0; i < geometryInfos.size(); i++) {
 						cmdBuffer.buildAccelerationStructuresKHR(
 								1,
@@ -220,6 +229,16 @@ namespace vkcv {
 								nullptr,
 								nullptr
 						);
+						
+						if (compactionQueryPool) {
+							cmdBuffer.writeAccelerationStructuresPropertiesKHR(
+									geometryInfos[i].dstAccelerationStructure,
+									vk::QueryType::eAccelerationStructureCompactedSizeKHR,
+									compactionQueryPool,
+									i,
+									dynamicDispatch
+							);
+						}
 					}
 				},
 				nullptr
@@ -228,15 +247,98 @@ namespace vkcv {
 		core.submitCommandStream(cmdStream, false);
 		
 		return {
+			accelerationStructureType,
+			accelerationStructureSize,
 			accelerationStructure,
-			asStorageBuffer,
-			{}
+			{},
+			asStorageBuffer
 		};
 	}
 	
+	static void compactAccelerationStructure(Core& core,
+											 BufferManager& bufferManager,
+											 size_t compactSize,
+											 AccelerationStructureEntry& entry) {
+		const auto &dynamicDispatch = core.getContext().getDispatchLoaderDynamic();
+		
+		if ((compactSize <= 0) || (compactSize >= entry.m_size)) {
+			vkcv_log(LogLevel::WARNING, "Skip compaction because it will not improve memory usage");
+			return;
+		}
+		
+		const BufferHandle &compactStorageBuffer = bufferManager.createBuffer(
+				typeGuard<uint8_t>(),
+				BufferType::ACCELERATION_STRUCTURE_STORAGE,
+				BufferMemoryType::DEVICE_LOCAL,
+				compactSize,
+				false
+		);
+		
+		if (!compactStorageBuffer) {
+			return;
+		}
+		
+		const vk::AccelerationStructureCreateInfoKHR asCreateInfo (
+				vk::AccelerationStructureCreateFlagsKHR(),
+				bufferManager.getBuffer(compactStorageBuffer),
+				0,
+				compactSize,
+				entry.m_type
+		);
+		
+		vk::AccelerationStructureKHR accelerationStructure;
+		const vk::Result result = core.getContext().getDevice().createAccelerationStructureKHR(
+				&asCreateInfo,
+				nullptr,
+				&accelerationStructure,
+				dynamicDispatch
+		);
+		
+		if (result != vk::Result::eSuccess) {
+			return;
+		}
+		
+		auto cmdStream = core.createCommandStream(vkcv::QueueType::Compute);
+		
+		core.recordCommandsToStream(
+				cmdStream,
+				[&entry, &accelerationStructure, &dynamicDispatch](const vk::CommandBuffer &cmdBuffer) {
+					const vk::CopyAccelerationStructureInfoKHR copyAccelerationStructureInfo (
+							entry.m_accelerationStructure,
+							accelerationStructure,
+							vk::CopyAccelerationStructureModeKHR::eCompact
+					);
+					
+					cmdBuffer.copyAccelerationStructureKHR(
+							copyAccelerationStructureInfo,
+							dynamicDispatch
+					);
+				},
+				[&core,
+				 &entry,
+				 compactSize,
+				 accelerationStructure,
+				 &compactStorageBuffer,
+				 &dynamicDispatch]() {
+					core.getContext().getDevice().destroy(
+							entry.m_accelerationStructure,
+							nullptr,
+							dynamicDispatch
+					);
+					
+					entry.m_size = compactSize;
+					entry.m_accelerationStructure = accelerationStructure;
+					entry.m_storageBuffer = compactStorageBuffer;
+				}
+		);
+		
+		core.submitCommandStream(cmdStream, false);
+	}
+	
 	AccelerationStructureHandle AccelerationStructureManager::createAccelerationStructure(
 			const std::vector<GeometryData> &geometryData,
-			const BufferHandle &transformBuffer) {
+			const BufferHandle &transformBuffer,
+			bool compaction) {
 		std::vector<vk::AccelerationStructureGeometryKHR> geometries;
 		std::vector<vk::AccelerationStructureBuildGeometryInfoKHR> geometryInfos;
 		std::vector<vk::AccelerationStructureBuildRangeInfoKHR> rangeInfos;
@@ -318,9 +420,17 @@ namespace vkcv {
 		}
 		
 		{
+			vk::BuildAccelerationStructureFlagsKHR buildFlags (
+					vk::BuildAccelerationStructureFlagBitsKHR::ePreferFastTrace
+			);
+			
+			if (compaction) {
+				buildFlags |= vk::BuildAccelerationStructureFlagBitsKHR::eAllowCompaction;
+			}
+			
 			const vk::AccelerationStructureBuildGeometryInfoKHR asBuildGeometryInfo(
 					vk::AccelerationStructureTypeKHR::eBottomLevel,
-					vk::BuildAccelerationStructureFlagBitsKHR::ePreferFastTrace,
+					buildFlags,
 					vk::BuildAccelerationStructureModeKHR::eBuild,
 					{},
 					{},
@@ -340,23 +450,74 @@ namespace vkcv {
 				dynamicDispatch
 		);
 		
-		accelerationStructureSize += asBuildSizesInfo.accelerationStructureSize;
-		scratchBufferSize = std::max(scratchBufferSize, asBuildSizesInfo.buildScratchSize);
+		{
+			accelerationStructureSize += asBuildSizesInfo.accelerationStructureSize;
+			scratchBufferSize = std::max(scratchBufferSize, asBuildSizesInfo.buildScratchSize);
+		}
+		
+		vk::QueryPool compactionQueryPool;
+		
+		if (compaction) {
+			const vk::QueryPoolCreateInfo queryPoolCreateInfo (
+					vk::QueryPoolCreateFlags(),
+					vk::QueryType::eAccelerationStructureCompactedSizeKHR,
+					static_cast<uint32_t>(geometryInfos.size())
+			);
+			
+			compactionQueryPool = getCore().getContext().getDevice().createQueryPool(
+					queryPoolCreateInfo
+			);
+		}
 		
-		const auto entry = buildAccelerationStructure(
+		auto entry = buildAccelerationStructure(
 				getCore(),
 				bufferManager,
 				geometryInfos,
 				rangeInfos,
+				vk::AccelerationStructureTypeKHR::eBottomLevel,
 				accelerationStructureSize,
 				scratchBufferSize,
-				vk::AccelerationStructureTypeKHR::eBottomLevel
+				compactionQueryPool
 		);
 		
 		if ((!entry.m_accelerationStructure) || (!entry.m_storageBuffer)) {
+			if (compactionQueryPool) {
+				getCore().getContext().getDevice().destroy(compactionQueryPool);
+			}
+			
 			return {};
 		}
 		
+		if (compactionQueryPool) {
+			const auto compactSizes = (
+					getCore().getContext().getDevice().getQueryPoolResults<vk::DeviceSize>(
+							compactionQueryPool,
+							0,
+							geometryInfos.size(),
+							geometryInfos.size() * sizeof(vk::DeviceSize),
+							sizeof(vk::DeviceSize),
+							vk::QueryResultFlagBits::eWait
+					)
+			);
+			
+			if (compactSizes.result == vk::Result::eSuccess) {
+				accelerationStructureSize = 0;
+				
+				for (const auto& compactSize : compactSizes.value) {
+					accelerationStructureSize += compactSize;
+				}
+				
+				compactAccelerationStructure(
+						getCore(),
+						bufferManager,
+						accelerationStructureSize,
+						entry
+				);
+			}
+			
+			getCore().getContext().getDevice().destroy(compactionQueryPool);
+		}
+		
 		return add(entry);
 	}
 	
@@ -489,9 +650,10 @@ namespace vkcv {
 				bufferManager,
 				asBuildGeometryInfos,
 				{ asBuildRangeInfo },
+				vk::AccelerationStructureTypeKHR::eTopLevel,
 				asBuildSizesInfo.accelerationStructureSize,
 				asBuildSizesInfo.buildScratchSize,
-				vk::AccelerationStructureTypeKHR::eTopLevel
+				nullptr
 		);
 		
 		if ((!entry.m_accelerationStructure) || (!entry.m_storageBuffer)) {
diff --git a/src/vkcv/AccelerationStructureManager.hpp b/src/vkcv/AccelerationStructureManager.hpp
index d5afaadd..041b4acf 100644
--- a/src/vkcv/AccelerationStructureManager.hpp
+++ b/src/vkcv/AccelerationStructureManager.hpp
@@ -19,9 +19,11 @@
 namespace vkcv {
 	
 	struct AccelerationStructureEntry {
+		vk::AccelerationStructureTypeKHR m_type;
+		vk::DeviceSize m_size;
 		vk::AccelerationStructureKHR m_accelerationStructure;
-		BufferHandle m_storageBuffer;
 		std::vector<AccelerationStructureHandle> m_children;
+		BufferHandle m_storageBuffer;
 	};
 	
 	/**
@@ -78,7 +80,8 @@ namespace vkcv {
 		
 		[[nodiscard]] AccelerationStructureHandle createAccelerationStructure(
 				const std::vector<GeometryData> &geometryData,
-				const BufferHandle &transformBuffer);
+				const BufferHandle &transformBuffer,
+				bool compaction);
 		
 		[[nodiscard]] AccelerationStructureHandle createAccelerationStructure(
 				const std::vector<AccelerationStructureHandle> &accelerationStructures);
diff --git a/src/vkcv/Core.cpp b/src/vkcv/Core.cpp
index 142f4fff..d28f5669 100644
--- a/src/vkcv/Core.cpp
+++ b/src/vkcv/Core.cpp
@@ -1362,10 +1362,12 @@ namespace vkcv {
 	
 	AccelerationStructureHandle Core::createAccelerationStructure(
 			const std::vector<GeometryData> &geometryData,
-			const BufferHandle &transformBuffer) {
+			const BufferHandle &transformBuffer,
+			bool compaction) {
 		return m_AccelerationStructureManager->createAccelerationStructure(
 				geometryData,
-				transformBuffer
+				transformBuffer,
+				compaction
 		);
 	}
 	
-- 
GitLab