diff --git a/include/vkcv/Core.hpp b/include/vkcv/Core.hpp index 6f5053e97963a8ebcdaf2ee97c4dd27ba24bf4b0..01d1455c20bc99d1956bd73a8a65eaef3350523c 100644 --- a/include/vkcv/Core.hpp +++ b/include/vkcv/Core.hpp @@ -342,13 +342,15 @@ namespace vkcv /** - * Prepares the @p rtxPipeline for ray generation by recording the @p shaderBindingTable to the @p cmdStreamHandle. + * Records the rtx ray generation to the @p cmdStreamHandle. * Currently only supports @p closestHit, @p rayGen and @c miss shaderstages @c. * @param cmdStreamHandle The command stream handle which receives relevant commands for drawing. * @param rtxPipeline The raytracing pipeline from the RTXModule. * @param rtxPipelineLayout The raytracing pipeline layout from the RTXModule. - * @param shaderBindingTable The shader binding table from the RTXModule. - * @param shaderGroupBaseAlignment The shader group base alignment from the RTXModule. + * @param rgenRegion The shader binding table region for ray generation shaders. + * @param rmissRegion The shader binding table region for ray miss shaders. + * @param rchitRegion The shader binding table region for ray closest hit shaders. + * @param rcallRegion The shader binding table region for callable shaders. * @param descriptorSetUsages The descriptor set usages. * @param pushConstants The push constants. * @param windowHandle The window handle defining in which window to render. @@ -357,8 +359,10 @@ namespace vkcv CommandStreamHandle cmdStreamHandle, vk::Pipeline rtxPipeline, vk::PipelineLayout rtxPipelineLayout, - vk::Buffer shaderBindingTable, - vk::DeviceSize shaderGroupBaseAlignment, + vk::StridedDeviceAddressRegionKHR rgenRegion, + vk::StridedDeviceAddressRegionKHR rmissRegion, + vk::StridedDeviceAddressRegionKHR rchitRegion, + vk::StridedDeviceAddressRegionKHR rcallRegion, const std::vector<DescriptorSetUsage>& descriptorSetUsages, const PushConstants& pushConstants, const WindowHandle windowHandle); diff --git a/projects/rtx_ambient_occlusion/src/RTX/RTX.cpp b/projects/rtx_ambient_occlusion/src/RTX/RTX.cpp index 6e380b79dcb0ac4cca634c6586c1e0aa727d4065..59f3da21409993bf6a460000a49c2ed67cb6d178 100644 --- a/projects/rtx_ambient_occlusion/src/RTX/RTX.cpp +++ b/projects/rtx_ambient_occlusion/src/RTX/RTX.cpp @@ -59,6 +59,39 @@ namespace vkcv::rtx { free(shaderHandleStorage); } + ShaderBindingTableRegions RTXModule::createRegions() { + // Define offsets for the RTX shaders. RayGen is the first allocated shader. Each following shader is + // shifted by shaderGroupBaseAlignment. + // Offset Calculation: offset = count of previous shaders * m_shaderGroupBaseAlignment + // Regions are hard coded + vk::DeviceSize rayGenOffset = 0; //First Shader group -> offset 0 * m_shaderGroupBaseAlignment =0 + vk::DeviceSize missOffset = m_shaderGroupBaseAlignment;//Second group, offset = 1 * m_shaderGroupBaseAlignment + vk::DeviceSize closestHitOffset = 2 * m_shaderGroupBaseAlignment; //Third group, offset = 2 * m_shaderGroupBaseAlignment + vk::DeviceSize shaderBindingTableSize = m_shaderGroupBaseAlignment * 3; // 3 hardcoded to rtx-shader count + + auto m_rtxDispatcher = vk::DispatchLoaderDynamic((PFN_vkGetInstanceProcAddr)m_core->getContext().getInstance().getProcAddr("vkGetInstanceProcAddr")); + m_rtxDispatcher.init(m_core->getContext().getInstance()); + + + // Create regions for the shader binding table buffer which are used for vk::CommandBuffer::traceRaysKHR + vk::StridedDeviceAddressRegionKHR rgenRegion; + vk::BufferDeviceAddressInfoKHR shaderBindingTableAddressInfo(m_shaderBindingTableBuffer.vulkanHandle); + rgenRegion.deviceAddress = m_core->getContext().getDevice().getBufferAddressKHR(shaderBindingTableAddressInfo, m_rtxDispatcher) + rayGenOffset; + rgenRegion.setStride(shaderBindingTableSize); + rgenRegion.setSize(shaderBindingTableSize); + vk::StridedDeviceAddressRegionKHR rmissRegion; + rmissRegion.deviceAddress = m_core->getContext().getDevice().getBufferAddressKHR(shaderBindingTableAddressInfo, m_rtxDispatcher) + missOffset; + rmissRegion.setStride(shaderBindingTableSize); + rmissRegion.setSize(shaderBindingTableSize); + vk::StridedDeviceAddressRegionKHR rchitRegion; + rchitRegion.deviceAddress = m_core->getContext().getDevice().getBufferAddressKHR(shaderBindingTableAddressInfo, m_rtxDispatcher) + closestHitOffset; + rchitRegion.setStride(shaderBindingTableSize); + rchitRegion.setSize(shaderBindingTableSize); + vk::StridedDeviceAddressRegionKHR rcallRegion = {}; + + return ShaderBindingTableRegions{ rgenRegion, rmissRegion, rchitRegion, rcallRegion }; + } + void RTXModule::RTXDescriptors(std::vector<vkcv::DescriptorSetHandle>& descriptorSetHandles) { diff --git a/projects/rtx_ambient_occlusion/src/RTX/RTX.hpp b/projects/rtx_ambient_occlusion/src/RTX/RTX.hpp index 77dd437ec776acf0b9aabb290ff4f0411748bf18..ece4ac8e6707039ab3fe166750140a8040aed924 100644 --- a/projects/rtx_ambient_occlusion/src/RTX/RTX.hpp +++ b/projects/rtx_ambient_occlusion/src/RTX/RTX.hpp @@ -7,6 +7,14 @@ namespace vkcv::rtx { + //struct that holds all shader binding table regions + struct ShaderBindingTableRegions { + vk::StridedDeviceAddressRegionKHR rgenRegion; + vk::StridedDeviceAddressRegionKHR rmissRegion; + vk::StridedDeviceAddressRegionKHR rchitRegion; + vk::StridedDeviceAddressRegionKHR rcallRegion; + }; + class RTXModule { private: @@ -66,6 +74,13 @@ namespace vkcv::rtx { */ void createShaderBindingTable(uint32_t shaderCount); + /** + * @brief Divides the shader binding table into regions for each shader type + * (ray generation, ray miss, ray closest hit, callable) and returns them as a struct. + * @return The struct holding all four regions of type vk::StridedDeviceAddressRegionKHR. + */ + ShaderBindingTableRegions createRegions(); + /** * @brief Creates Descriptor-Writes for RTX * @param descriptorSetHandles The descriptorSetHandles for RTX. diff --git a/projects/rtx_ambient_occlusion/src/main.cpp b/projects/rtx_ambient_occlusion/src/main.cpp index 50786b61f985fc126a7429272352d9f8972d0b98..dcabe95a2977afe4cdc6fb24cc8499c4b459785c 100644 --- a/projects/rtx_ambient_occlusion/src/main.cpp +++ b/projects/rtx_ambient_occlusion/src/main.cpp @@ -89,6 +89,8 @@ int main(int argc, const char** argv) { vk::Pipeline rtxPipeline = rtxModule.getPipeline(); vk::PipelineLayout rtxPipelineLayout = rtxModule.getPipelineLayout(); + vkcv::rtx::ShaderBindingTableRegions rtxRegions = rtxModule.createRegions(); + vkcv::ImageHandle depthBuffer = core.createImage(vk::Format::eD32Sfloat, windowWidth, windowHeight).getHandle(); const vkcv::ImageHandle swapchainInput = vkcv::ImageHandle::createSwapchainImageHandle(); @@ -142,8 +144,10 @@ int main(int argc, const char** argv) { cmdStream, rtxPipeline, rtxPipelineLayout, - rtxModule.getShaderBindingTableBuffer(), - rtxModule.getShaderGroupBaseAlignment(), + rtxRegions.rgenRegion, + rtxRegions.rmissRegion, + rtxRegions.rchitRegion, + rtxRegions.rcallRegion, { vkcv::DescriptorSetUsage(0, core.getDescriptorSet(rtxShaderDescriptorSet).vulkanHandle)}, pushConstantsRTX, windowHandle); diff --git a/src/vkcv/Core.cpp b/src/vkcv/Core.cpp index 7b097c018862f3d8669b3104149208b7721b15e2..de01b590c13e4941bd99619de05f2da144d6a2ec 100644 --- a/src/vkcv/Core.cpp +++ b/src/vkcv/Core.cpp @@ -407,11 +407,13 @@ namespace vkcv void Core::recordRayGenerationToCmdStream( - CommandStreamHandle cmdStreamHandle, - vk::Pipeline rtxPipeline, - vk::PipelineLayout rtxPipelineLayout, - vk::Buffer shaderBindingTable, - vk::DeviceSize shaderGroupBaseAlignment, + CommandStreamHandle cmdStreamHandle, + vk::Pipeline rtxPipeline, + vk::PipelineLayout rtxPipelineLayout, + vk::StridedDeviceAddressRegionKHR rgenRegion, + vk::StridedDeviceAddressRegionKHR rmissRegion, + vk::StridedDeviceAddressRegionKHR rchitRegion, + vk::StridedDeviceAddressRegionKHR rcallRegion, const std::vector<DescriptorSetUsage>& descriptorSetUsages, const PushConstants& pushConstants, const WindowHandle windowHandle) { @@ -431,37 +433,15 @@ namespace vkcv if (pushConstants.getSizePerDrawcall() > 0) { cmdBuffer.pushConstants( rtxPipelineLayout, - (vk::ShaderStageFlagBits::eClosestHitKHR | vk::ShaderStageFlagBits::eMissKHR | vk::ShaderStageFlagBits::eRaygenKHR), + (vk::ShaderStageFlagBits::eClosestHitKHR | vk::ShaderStageFlagBits::eMissKHR | vk::ShaderStageFlagBits::eRaygenKHR), // TODO: add Support for eAnyHitKHR, eCallableKHR, eIntersectionKHR 0, pushConstants.getSizePerDrawcall(), pushConstants.getData()); } - // Define offsets for the RTX shaders. RayGen is the first allocated shader. Each following shader is - // shifted by shaderGroupBaseAlignment. - vk::DeviceSize rayGenOffset = 0; - vk::DeviceSize missOffset = shaderGroupBaseAlignment; - vk::DeviceSize closestHitOffset = 2 * shaderGroupBaseAlignment; - vk::DeviceSize shaderBindingTableSize = shaderGroupBaseAlignment * 3; // 3 hardcoded to rtx-shader count - + auto m_rtxDispatcher = vk::DispatchLoaderDynamic((PFN_vkGetInstanceProcAddr)m_Context.getInstance().getProcAddr("vkGetInstanceProcAddr")); m_rtxDispatcher.init(m_Context.getInstance()); - // Create regions for the shader binding table buffer which are used for vk::CommandBuffer::traceRaysKHR - vk::StridedDeviceAddressRegionKHR rgenRegion; - vk::BufferDeviceAddressInfoKHR shaderBindingTableAddressInfo(shaderBindingTable); - rgenRegion.deviceAddress = m_Context.getDevice().getBufferAddressKHR(shaderBindingTableAddressInfo, m_rtxDispatcher) + rayGenOffset; - rgenRegion.setStride(shaderBindingTableSize); - rgenRegion.setSize(shaderBindingTableSize); - vk::StridedDeviceAddressRegionKHR rmissRegion; - rmissRegion.deviceAddress = m_Context.getDevice().getBufferAddressKHR(shaderBindingTableAddressInfo, m_rtxDispatcher) + missOffset; - rmissRegion.setStride(shaderBindingTableSize); - rmissRegion.setSize(shaderBindingTableSize); - vk::StridedDeviceAddressRegionKHR rchitRegion; - rchitRegion.deviceAddress = m_Context.getDevice().getBufferAddressKHR(shaderBindingTableAddressInfo, m_rtxDispatcher) + closestHitOffset; - rchitRegion.setStride(shaderBindingTableSize); - rchitRegion.setSize(shaderBindingTableSize); - vk::StridedDeviceAddressRegionKHR rcallRegion = {}; - cmdBuffer.traceRaysKHR(&rgenRegion,&rmissRegion,&rchitRegion,&rcallRegion, getWindow(windowHandle).getWidth(), getWindow(windowHandle).getHeight(),1, m_rtxDispatcher);