From 5736b55bb3e36ca8bf51a38ccdf64f08f4ad511a Mon Sep 17 00:00:00 2001
From: Lars Hoerttrich <larshoerttrich@uni-koblenz.de>
Date: Sat, 11 Sep 2021 16:44:33 +0200
Subject: [PATCH] [#92] descriptorWrites for Acceleration structure and index
 and vertexbuffer

---
 modules/rtx/include/vkcv/rtx/ASManager.hpp    |  8 +++
 modules/rtx/include/vkcv/rtx/RTX.hpp          | 13 +++-
 modules/rtx/src/vkcv/rtx/ASManager.cpp        | 10 +++
 modules/rtx/src/vkcv/rtx/RTX.cpp              | 72 +++++++++++++++++--
 modules/scene/src/vkcv/scene/Scene.cpp        |  2 +-
 projects/rtx/resources/shaders/raytrace.rchit |  9 ++-
 projects/rtx/resources/shaders/raytrace.rgen  |  9 +--
 projects/rtx/src/main.cpp                     | 39 +++++-----
 src/vkcv/ShaderProgram.cpp                    | 22 ++++++
 9 files changed, 150 insertions(+), 34 deletions(-)

diff --git a/modules/rtx/include/vkcv/rtx/ASManager.hpp b/modules/rtx/include/vkcv/rtx/ASManager.hpp
index 688f40a2..a8b34576 100644
--- a/modules/rtx/include/vkcv/rtx/ASManager.hpp
+++ b/modules/rtx/include/vkcv/rtx/ASManager.hpp
@@ -91,5 +91,13 @@ namespace vkcv::rtx {
          * #ASManager::m_accelerationStructures objects.
          */
         void buildTLAS();
+
+        /**
+        * @brief Returns the top level acceleration structure buffer
+        * @return A #TopLevelAccelerationStructure object holding the tlas.
+        */
+        TopLevelAccelerationStructure getTLAS();
+
+        BottomLevelAccelerationStructure getBLAS(uint32_t id);
     };
 }
\ No newline at end of file
diff --git a/modules/rtx/include/vkcv/rtx/RTX.hpp b/modules/rtx/include/vkcv/rtx/RTX.hpp
index 9c2d0026..07182efa 100644
--- a/modules/rtx/include/vkcv/rtx/RTX.hpp
+++ b/modules/rtx/include/vkcv/rtx/RTX.hpp
@@ -3,6 +3,8 @@
 #include <vector>
 #include "vulkan/vulkan.hpp"
 #include "vkcv/Core.hpp"
+#include "RTXDescriptorWrites.hpp"
+#include "ASManager.hpp"
 
 namespace vkcv::rtx {
 
@@ -12,6 +14,7 @@ namespace vkcv::rtx {
         std::vector<const char*> m_instanceExtensions;  // the instance extensions needed for using RTX
         std::vector<const char*> m_deviceExtensions;    // the device extensions needed for using RTX
         vkcv::Features m_features;                      // the features needed to be enabled for using RTX
+        ASManager* m_asManager;
 
     public:
 
@@ -48,8 +51,16 @@ namespace vkcv::rtx {
          * @param core The reference to the #Core.
          * @param vertices The scene vertex data of type uint8_t.
          * @param indices The scene index data of type uint8_t.
+         * @param descriptorSetHandles The descriptorSetHandles for RTX
          */
-        void init(Core* core, std::vector<uint8_t> &vertices, std::vector<uint8_t> &indices);
+        void init(Core* core, std::vector<uint8_t> &vertices, std::vector<uint8_t> &indices, std::vector<vkcv::DescriptorSetHandle> &descriptorSetHandles);
+
+        /**
+         * @brief Creates Descriptor-Writes for RTX
+         * @param asManager The ASManager of RTX
+         * @param descriptorSetHandles The descriptorSetHandles for RTX
+         */
+        void RTXDescriptors(ASManager* asManager, Core* core, std::vector<vkcv::DescriptorSetHandle>& descriptorSetHandles);
     };
 
 }
diff --git a/modules/rtx/src/vkcv/rtx/ASManager.cpp b/modules/rtx/src/vkcv/rtx/ASManager.cpp
index 72c58a9d..29b5b0e7 100644
--- a/modules/rtx/src/vkcv/rtx/ASManager.cpp
+++ b/modules/rtx/src/vkcv/rtx/ASManager.cpp
@@ -220,6 +220,16 @@ namespace vkcv::rtx {
         };
     }
 
+    TopLevelAccelerationStructure ASManager::getTLAS()
+    {
+        return m_topLevelAccelerationStructure;
+    }
+
+    BottomLevelAccelerationStructure ASManager::getBLAS(uint32_t id)
+    {
+        return m_bottomLevelAccelerationStructures[id];
+    }
+
 
     void ASManager::buildBLAS(std::vector<uint8_t> &vertices, std::vector<uint8_t> &indices) {
         uint32_t vertexCount = vertices.size();
diff --git a/modules/rtx/src/vkcv/rtx/RTX.cpp b/modules/rtx/src/vkcv/rtx/RTX.cpp
index a9e73a87..38144688 100644
--- a/modules/rtx/src/vkcv/rtx/RTX.cpp
+++ b/modules/rtx/src/vkcv/rtx/RTX.cpp
@@ -1,9 +1,8 @@
 #include "vkcv/rtx/RTX.hpp"
-#include "vkcv/rtx/ASManager.hpp"
 
 namespace vkcv::rtx {
 
-    RTXModule::RTXModule() {
+    RTXModule::RTXModule(){
 
         // prepare needed raytracing extensions
         m_instanceExtensions = {
@@ -112,11 +111,72 @@ namespace vkcv::rtx {
                 });
     }
 
-    void RTXModule::init(Core* core, std::vector<uint8_t> &vertices, std::vector<uint8_t> &indices) {
+    void RTXModule::init(Core* core, std::vector<uint8_t>& vertices,
+        std::vector<uint8_t>& indices, std::vector<vkcv::DescriptorSetHandle>& descriptorSetHandles)
+    {
         // build acceleration structures BLAS then TLAS --> see ASManager
-        ASManager asManager(core);
-        asManager.buildBLAS(vertices, indices);
-        asManager.buildTLAS();
+        //asManager(core);
+        ASManager temp(core);
+        m_asManager = &temp;
+        m_asManager->buildBLAS(vertices, indices);
+        m_asManager->buildTLAS();
+        RTXDescriptors(m_asManager, core, descriptorSetHandles);
+        
+    }
+
+    void RTXModule::RTXDescriptors(ASManager* asManager,Core* core, std::vector<vkcv::DescriptorSetHandle>& descriptorSetHandles)
+    {
+        //TLAS-Descriptor-Write
+        TopLevelAccelerationStructure tlas = asManager->getTLAS();
+        RTXBuffer tlasBuffer = tlas.tlasBuffer;
+        vk::WriteDescriptorSetAccelerationStructureKHR AccelerationDescriptor = {};
+        AccelerationDescriptor.accelerationStructureCount = 1;
+        const TopLevelAccelerationStructure constTLAS = tlas;
+        AccelerationDescriptor.pAccelerationStructures = &constTLAS.vulkanHandle;
+
+        vk::WriteDescriptorSet tlasWrite;
+        tlasWrite.setPNext(&AccelerationDescriptor);
+        tlasWrite.setDstSet(core->getDescriptorSet(descriptorSetHandles[0]).vulkanHandle);
+        tlasWrite.setDstBinding(1);
+        tlasWrite.setDstArrayElement(0);
+        tlasWrite.setDescriptorCount(1);
+        tlasWrite.setDescriptorType(vk::DescriptorType::eAccelerationStructureKHR);
+        core->getContext().getDevice().updateDescriptorSets(tlasWrite, nullptr);
+
+        //INDEX & VERTEX BUFFER
+        BottomLevelAccelerationStructure blas = asManager->getBLAS(0);//HARD CODED
+
+        //VERTEX BUFFER
+
+        vk::DescriptorBufferInfo vertexInfo = {};
+        vertexInfo.setBuffer(blas.vertexBuffer.vulkanHandle);
+        vertexInfo.setOffset(0);
+        vertexInfo.setRange(blas.vertexBuffer.deviceSize); //maybe check if size is correct
+
+        vk::WriteDescriptorSet vertexWrite;
+        vertexWrite.setDstSet(core->getDescriptorSet(descriptorSetHandles[1]).vulkanHandle);
+        vertexWrite.setDstBinding(3);
+        vertexWrite.setDescriptorCount(1);
+        vertexWrite.setDescriptorType(vk::DescriptorType::eStorageBuffer);
+        vertexWrite.setPBufferInfo(&vertexInfo);
+        core->getContext().getDevice().updateDescriptorSets(vertexWrite, nullptr);
+
+        //INDEXBUFFER
+        vk::DescriptorBufferInfo indexInfo = {};
+        indexInfo.setBuffer(blas.indexBuffer.vulkanHandle);
+        indexInfo.setOffset(0);
+        indexInfo.setRange(blas.indexBuffer.deviceSize); //maybe check if size is correct
+
+        vk::WriteDescriptorSet indexWrite;
+        indexWrite.setDstSet(core->getDescriptorSet(descriptorSetHandles[1]).vulkanHandle);
+        indexWrite.setDstBinding(4);
+        indexWrite.setDescriptorCount(1);
+        indexWrite.setDescriptorType(vk::DescriptorType::eStorageBuffer);
+        indexWrite.setPBufferInfo(&indexInfo);
+        core->getContext().getDevice().updateDescriptorSets(indexWrite, nullptr);
+
+          
+
     }
 
     std::vector<const char*> RTXModule::getInstanceExtensions() {
diff --git a/modules/scene/src/vkcv/scene/Scene.cpp b/modules/scene/src/vkcv/scene/Scene.cpp
index c0065af5..8aeccd8d 100644
--- a/modules/scene/src/vkcv/scene/Scene.cpp
+++ b/modules/scene/src/vkcv/scene/Scene.cpp
@@ -131,7 +131,7 @@ namespace vkcv::scene {
 			node.recordDrawcalls(viewProjection, pushConstants, drawcalls, record);
 		}
 		
-		vkcv_log(LogLevel::RAW_INFO, "Frustum culling: %lu / %lu", drawcalls.size(), count);
+		//vkcv_log(LogLevel::RAW_INFO, "Frustum culling: %lu / %lu", drawcalls.size(), count);
 		
 		m_core->recordDrawcallsToCmdStream(
 				cmdStream,
diff --git a/projects/rtx/resources/shaders/raytrace.rchit b/projects/rtx/resources/shaders/raytrace.rchit
index 836edc70..26c21636 100644
--- a/projects/rtx/resources/shaders/raytrace.rchit
+++ b/projects/rtx/resources/shaders/raytrace.rchit
@@ -1,18 +1,17 @@
 #version 460
 #extension GL_EXT_ray_tracing : require
 
-/*
-layout(binding = 3, set 0) buffer vertices
+
+layout(binding = 3, set = 0) buffer rtxVertices
 {
     vec3 vertices[];
 };
 
-layout(binding = 4, set 0) buffer indices
+layout(binding = 4, set = 0) buffer rtxIndices
 {
     uint indices[];
 };
-*/
 
 void main() {
-
+    int b = 42;
 }
diff --git a/projects/rtx/resources/shaders/raytrace.rgen b/projects/rtx/resources/shaders/raytrace.rgen
index 58b0f0ab..ac5a8355 100644
--- a/projects/rtx/resources/shaders/raytrace.rgen
+++ b/projects/rtx/resources/shaders/raytrace.rgen
@@ -1,13 +1,13 @@
 #version 460
 #extension GL_EXT_ray_tracing : require
 
-/*
+
 // A location for a ray payload (we can have multiple of these)
-layout(location = 0) rayPayloadEXT RayPayload pay;
+//layout(location = 0) rayPayloadEXT RayPayload pay;
 
-layout(binding = 0, set = 0, rgba32f) uniform image2D outImg;           // the output image -> maybe use 16 bit values?
+//layout(binding = 0, set = 0, rgba32f) uniform image2D outImg;           // the output image -> maybe use 16 bit values?
 layout(binding = 1, set = 0) uniform accelerationStructureEXT tlas;     // top level acceleration structure (for the noobs here (you!))
-
+/*
 layout( push_constant ) uniform constants {     // TODO: add push_constants in main.cpp!
     vec4 camera_position;   // as origin for ray generation
     vec4 camera_right;      // for computing ray direction
@@ -20,4 +20,5 @@ layout( push_constant ) uniform constants {     // TODO: add push_constants in m
 
 void main() {
     // TODO
+    int a = 42;
 }
\ No newline at end of file
diff --git a/projects/rtx/src/main.cpp b/projects/rtx/src/main.cpp
index b518d26f..510a8685 100644
--- a/projects/rtx/src/main.cpp
+++ b/projects/rtx/src/main.cpp
@@ -49,7 +49,7 @@ int main(int argc, const char** argv) {
 	);
 
 	vkcv::scene::Scene scene = vkcv::scene::Scene::load(core, std::filesystem::path(
-			argc > 1 ? argv[1] : "resources/Sponza/Sponza.gltf"
+			argc > 1 ? argv[1] : "resources/Cube/cube.gltf"
 	));
 
     // TODO: replace by bigger scene
@@ -77,9 +77,6 @@ int main(int argc, const char** argv) {
 	    indices.push_back(mesh.vertexGroups[0].indexBuffer.data[i]);
 	}
 
-	// init RTXModule
-    rtxModule.init(&core, vertices, indices);
-
 	const vkcv::AttachmentDescription present_color_attachment(
 		vkcv::AttachmentOperation::STORE,
 		vkcv::AttachmentOperation::CLEAR,
@@ -129,29 +126,37 @@ int main(int argc, const char** argv) {
 		[&rayGenShaderProgram](vkcv::ShaderStage shaderStage, const std::filesystem::path& path) {
 			rayGenShaderProgram.addShader(shaderStage, path);
 		});
-
+	
 	vkcv::ShaderProgram rayClosestHitShaderProgram;
 	compiler.compile(vkcv::ShaderStage::RAY_CLOSEST_HIT, std::filesystem::path("resources/shaders/raytrace.rchit"),
 		[&rayClosestHitShaderProgram](vkcv::ShaderStage shaderStage, const std::filesystem::path& path) {
 			rayClosestHitShaderProgram.addShader(shaderStage, path);
 		});
-
+	/*
 	vkcv::ShaderProgram rayMissShaderProgram;
 	compiler.compile(vkcv::ShaderStage::RAY_MISS, std::filesystem::path("resources/shaders/raytrace.rmiss"),
 		[&rayMissShaderProgram](vkcv::ShaderStage shaderStage, const std::filesystem::path& path) {
 			rayMissShaderProgram.addShader(shaderStage, path);
 		});
-
-
+	*/
+	std::vector<vkcv::DescriptorSetHandle> descriptorSetHandles;
 	//TODO
-//	vkcv::DescriptorSetLayoutHandle rayGenShaderDescriptorSetLayout = core.createDescriptorSetLayout(rayGenShaderProgram.getReflectedDescriptors().at(0));
-//	vkcv::DescriptorSetHandle rayGenShaderDescriptorSet = core.createDescriptorSet(rayGenShaderDescriptorSetLayout);
+	vkcv::DescriptorSetLayoutHandle rayGenShaderDescriptorSetLayout = core.createDescriptorSetLayout(rayGenShaderProgram.getReflectedDescriptors().at(0));
+	vkcv::DescriptorSetHandle rayGenShaderDescriptorSet = core.createDescriptorSet(rayGenShaderDescriptorSetLayout);//
+	descriptorSetHandles.push_back(rayGenShaderDescriptorSet);
 
-//	vkcv::DescriptorSetLayoutHandle rayClosestHitShaderDescriptorSetLayout = core.createDescriptorSetLayout(rayClosestHitShaderProgram.getReflectedDescriptors().at(0));
-//	vkcv::DescriptorSetHandle rayGenShaderDescriptorSet = core.createDescriptorSet(rayClosestHitShaderDescriptorSetLayout);
+	
+	vkcv::DescriptorSetLayoutHandle rayClosestHitShaderDescriptorSetLayout = core.createDescriptorSetLayout(rayClosestHitShaderProgram.getReflectedDescriptors().at(0));
+	vkcv::DescriptorSetHandle rayCHITShaderDescriptorSet = core.createDescriptorSet(rayClosestHitShaderDescriptorSetLayout);
+	descriptorSetHandles.push_back(rayCHITShaderDescriptorSet);
+	/*
+	vkcv::DescriptorSetLayoutHandle rayMissShaderDescriptorSetLayout = core.createDescriptorSetLayout(rayMissShaderProgram.getReflectedDescriptors().at(0));
+	vkcv::DescriptorSetHandle rayMissShaderDescriptorSet = core.createDescriptorSet(rayMissShaderDescriptorSetLayout);
+	descriptorSetHandles.push_back(rayMissShaderDescriptorSet);
+	*/
 
-//	vkcv::DescriptorSetLayoutHandle rayMissShaderDescriptorSetLayout = core.createDescriptorSetLayout(rayMissShaderProgram.getReflectedDescriptors().at(0));
-//	vkcv::DescriptorSetHandle rayGenShaderDescriptorSet = core.createDescriptorSet(rayMissShaderDescriptorSetLayout);
+	// init RTXModule
+	rtxModule.init(&core, vertices, indices,descriptorSetHandles);
 
 	const vkcv::PipelineConfig scenePipelineDefinition{
 		sceneShaderProgram,
@@ -164,9 +169,9 @@ int main(int argc, const char** argv) {
 	vkcv::PipelineHandle scenePipeline = core.createGraphicsPipeline(scenePipelineDefinition);
 
 	// TODO
-//	vkcv::DescriptorWrites vertexShaderDescriptorWrites;
-//	vertexShaderDescriptorWrites.storageBufferWrites = { vkcv::BufferDescriptorWrite(0, matrixBuffer.getHandle()) };
-//	core.writeDescriptorSet(vertexShaderDescriptorSet, vertexShaderDescriptorWrites);
+	//vkcv::RTXDescriptorWrites vertexShaderDescriptorWrites;
+	//vertexShaderDescriptorWrites.storageBufferWrites = { vkcv::BufferDescriptorWrite(0, matrixBuffer.getHandle()) };
+	//core.writeDescriptorSet(vertexShaderDescriptorSet, vertexShaderDescriptorWrites);
 
 	if (!scenePipeline) {
 		std::cout << "Error. Could not create graphics pipeline. Exiting." << std::endl;
diff --git a/src/vkcv/ShaderProgram.cpp b/src/vkcv/ShaderProgram.cpp
index a1634c23..65e4cab5 100644
--- a/src/vkcv/ShaderProgram.cpp
+++ b/src/vkcv/ShaderProgram.cpp
@@ -254,6 +254,28 @@ namespace vkcv {
             }
         }
 
+        for (uint32_t i = 0; i < resources.acceleration_structures.size(); i++) {
+            auto& u = resources.acceleration_structures[i];
+            const spirv_cross::SPIRType& base_type = comp.get_type(u.base_type_id);
+
+            uint32_t setID = comp.get_decoration(u.id, spv::DecorationDescriptorSet);
+            uint32_t bindingID = comp.get_decoration(u.id, spv::DecorationBinding);
+            auto binding = DescriptorBinding(
+                bindingID,
+                DescriptorType::ACCELERATION_STRUCTURE_KHR,
+                base_type.vecsize,
+                shaderStage);
+
+            auto insertionResult = m_DescriptorSets[setID].insert(std::make_pair(bindingID, binding));
+            if (!insertionResult.second)
+            {
+                vkcv_log(LogLevel::WARNING,
+                    "Attempting to overwrite already existing binding %u at set ID %u.",
+                    bindingID,
+                    setID);
+            }
+        }
+
         //reflect push constants
 		for (const auto &pushConstantBuffer : resources.push_constant_buffers)
 		{
-- 
GitLab