From 682504ad934fd323b7f5497cdbf81e355e0c7607 Mon Sep 17 00:00:00 2001
From: Alexander Gauggel <agauggel@uni-koblenz.de>
Date: Sat, 3 Jul 2021 13:28:36 +0200
Subject: [PATCH] [#87] WIP: add mesh shader drawcall, still need to obtain
 extension functions

---
 include/vkcv/Core.hpp              |  10 +-
 include/vkcv/DrawcallRecording.hpp |  15 ++
 src/vkcv/Core.cpp                  | 306 ++++++++++++++++++++---------
 src/vkcv/DrawcallRecording.cpp     |  31 +++
 4 files changed, 264 insertions(+), 98 deletions(-)

diff --git a/include/vkcv/Core.hpp b/include/vkcv/Core.hpp
index c7512346..1a1a8f55 100644
--- a/include/vkcv/Core.hpp
+++ b/include/vkcv/Core.hpp
@@ -239,13 +239,21 @@ namespace vkcv
 		bool beginFrame(uint32_t& width, uint32_t& height);
 
 		void recordDrawcallsToCmdStream(
-            const CommandStreamHandle       cmdStreamHandle,
+			const CommandStreamHandle       cmdStreamHandle,
 			const PassHandle                renderpassHandle, 
 			const PipelineHandle            pipelineHandle,
 			const PushConstantData          &pushConstantData,
 			const std::vector<DrawcallInfo> &drawcalls,
 			const std::vector<ImageHandle>  &renderTargets);
 
+		void recordMeshShaderDrawcalls(
+			const CommandStreamHandle               cmdStreamHandle,
+			const PassHandle                        renderpassHandle,
+			const PipelineHandle                    pipelineHandle,
+			const PushConstantData&                 pushConstantData,
+            const std::vector<MeshShaderDrawcall>&  drawcalls,
+			const std::vector<ImageHandle>&         renderTargets);
+
 		void recordComputeDispatchToCmdStream(
 			CommandStreamHandle cmdStream,
 			PipelineHandle computePipeline,
diff --git a/include/vkcv/DrawcallRecording.hpp b/include/vkcv/DrawcallRecording.hpp
index 0929ad03..e1526689 100644
--- a/include/vkcv/DrawcallRecording.hpp
+++ b/include/vkcv/DrawcallRecording.hpp
@@ -51,4 +51,19 @@ namespace vkcv {
         const PushConstantData  &pushConstantData,
         const size_t            drawcallIndex);
 
+    struct MeshShaderDrawcall {
+        inline MeshShaderDrawcall(const std::vector<DescriptorSetUsage> descriptorSets, uint32_t taskCout)
+            : descriptorSets(descriptorSets), taskCount(taskCount) {}
+
+        std::vector<DescriptorSetUsage> descriptorSets;
+        uint32_t                        taskCount;
+    };
+
+    void recordMeshShaderDrawcall(
+        vk::CommandBuffer                       cmdBuffer,
+        vk::PipelineLayout                      pipelineLayout,
+        const PushConstantData&                 pushConstantData,
+        const uint32_t                          pushConstantOffset,
+        const MeshShaderDrawcall&               drawcall,
+        const uint32_t                          firstTask);
 }
\ No newline at end of file
diff --git a/src/vkcv/Core.cpp b/src/vkcv/Core.cpp
index 1492b1af..194d68ce 100644
--- a/src/vkcv/Core.cpp
+++ b/src/vkcv/Core.cpp
@@ -177,133 +177,245 @@ namespace vkcv
 		return (m_currentSwapchainImageIndex != std::numeric_limits<uint32_t>::max());
 	}
 
-	void Core::recordDrawcallsToCmdStream(
-		const CommandStreamHandle       cmdStreamHandle,
-		const PassHandle                renderpassHandle, 
-		const PipelineHandle            pipelineHandle, 
-        const PushConstantData          &pushConstantData,
-        const std::vector<DrawcallInfo> &drawcalls,
-		const std::vector<ImageHandle>  &renderTargets) {
+	std::array<uint32_t, 2> getWidthHeightFromRenderTargets(
+		const std::vector<ImageHandle>& renderTargets,
+		const Swapchain& swapchain,
+		const ImageManager& imageManager) {
 
-		if (m_currentSwapchainImageIndex == std::numeric_limits<uint32_t>::max()) {
-			return;
-		}
+		std::array<uint32_t, 2> widthHeight;
 
-		uint32_t width;
-		uint32_t height;
 		if (renderTargets.size() > 0) {
 			const vkcv::ImageHandle firstImage = renderTargets[0];
 			if (firstImage.isSwapchainImage()) {
-				const auto& swapchainExtent = m_swapchain.getExtent();
-				width = swapchainExtent.width;
-				height = swapchainExtent.height;
+				const auto& swapchainExtent = swapchain.getExtent();
+				widthHeight[0] = swapchainExtent.width;
+				widthHeight[1] = swapchainExtent.height;
 			}
 			else {
-				width = m_ImageManager->getImageWidth(firstImage);
-				height = m_ImageManager->getImageHeight(firstImage);
+				widthHeight[0] = imageManager.getImageWidth(firstImage);
+				widthHeight[1] = imageManager.getImageHeight(firstImage);
 			}
 		}
 		else {
-			width = 1;
-			height = 1;
+			widthHeight[0] = 1;
+			widthHeight[1] = 1;
 		}
 		// TODO: validate that width/height match for all attachments
+		return widthHeight;
+	}
 
-		const vk::RenderPass renderpass = m_PassManager->getVkPass(renderpassHandle);
-		const PassConfig passConfig = m_PassManager->getPassConfig(renderpassHandle);
-
-		const vk::Pipeline pipeline		= m_PipelineManager->getVkPipeline(pipelineHandle);
-		const vk::PipelineLayout pipelineLayout = m_PipelineManager->getVkPipelineLayout(pipelineHandle);
-		const vk::Rect2D renderArea(vk::Offset2D(0, 0), vk::Extent2D(width, height));
+	vk::Framebuffer createFramebuffer(
+		const std::vector<ImageHandle>& renderTargets,
+		const ImageManager&             imageManager,
+		const Swapchain&                swapchain,
+		vk::RenderPass                  renderpass,
+		vk::Device                      device) {
 
 		std::vector<vk::ImageView> attachmentsViews;
 		for (const ImageHandle handle : renderTargets) {
-			vk::ImageView targetHandle;
-			const auto cmdBuffer = m_CommandStreamManager->getStreamCommandBuffer(cmdStreamHandle);
+			vk::ImageView targetHandle = imageManager.getVulkanImageView(handle);
+			attachmentsViews.push_back(targetHandle);
+		}
+
+		const std::array<uint32_t, 2> widthHeight = getWidthHeightFromRenderTargets(renderTargets, swapchain, imageManager);
+
+		const vk::FramebufferCreateInfo createInfo(
+			{},
+			renderpass,
+			static_cast<uint32_t>(attachmentsViews.size()),
+			attachmentsViews.data(),
+			widthHeight[0],
+			widthHeight[1],
+			1);
+
+		return device.createFramebuffer(createInfo);
+	}
+
+	void transitionRendertargetsToAttachmentLayout(
+		const std::vector<ImageHandle>& renderTargets,
+		ImageManager&                   imageManager,
+		const vk::CommandBuffer         cmdBuffer) {
 
-			targetHandle = m_ImageManager->getVulkanImageView(handle);
-			const bool isDepthImage = isDepthFormat(m_ImageManager->getImageFormat(handle));
-			const vk::ImageLayout targetLayout = 
+		for (const ImageHandle handle : renderTargets) {
+			vk::ImageView targetHandle = imageManager.getVulkanImageView(handle);
+			const bool isDepthImage = isDepthFormat(imageManager.getImageFormat(handle));
+			const vk::ImageLayout targetLayout =
 				isDepthImage ? vk::ImageLayout::eDepthStencilAttachmentOptimal : vk::ImageLayout::eColorAttachmentOptimal;
-			m_ImageManager->recordImageLayoutTransition(handle, targetLayout, cmdBuffer);
-			attachmentsViews.push_back(targetHandle);
+			imageManager.recordImageLayoutTransition(handle, targetLayout, cmdBuffer);
 		}
-		
-        const vk::FramebufferCreateInfo createInfo(
-            {},
-            renderpass,
-            static_cast<uint32_t>(attachmentsViews.size()),
-            attachmentsViews.data(),
-            width,
-            height,
-            1
+	}
+
+	std::vector<vk::ClearValue> createAttachmentClearValues(const std::vector<AttachmentDescription>& attachments) {
+		std::vector<vk::ClearValue> clearValues;
+		for (const auto& attachment : attachments) {
+			if (attachment.load_operation == AttachmentOperation::CLEAR) {
+				float clear = 0.0f;
+
+				if (isDepthFormat(attachment.format)) {
+					clear = 1.0f;
+				}
+
+				clearValues.emplace_back(std::array<float, 4>{
+					clear,
+						clear,
+						clear,
+						1.f
+				});
+			}
+		}
+		return clearValues;
+	}
+
+	void recordDynamicViewport(vk::CommandBuffer cmdBuffer, uint32_t width, uint32_t height) {
+		vk::Viewport dynamicViewport(
+			0.0f, 0.0f,
+			static_cast<float>(width), static_cast<float>(height),
+			0.0f, 1.0f
 		);
-		
-		vk::Framebuffer framebuffer = m_Context.m_Device.createFramebuffer(createInfo);
-        
-        if (!framebuffer) {
+
+		vk::Rect2D dynamicScissor({ 0, 0 }, { width, height });
+
+		cmdBuffer.setViewport(0, 1, &dynamicViewport);
+		cmdBuffer.setScissor(0, 1, &dynamicScissor);
+	}
+
+	void Core::recordDrawcallsToCmdStream(
+		const CommandStreamHandle       cmdStreamHandle,
+		const PassHandle                renderpassHandle, 
+		const PipelineHandle            pipelineHandle, 
+        const PushConstantData          &pushConstantData,
+        const std::vector<DrawcallInfo> &drawcalls,
+		const std::vector<ImageHandle>  &renderTargets) {
+
+		if (m_currentSwapchainImageIndex == std::numeric_limits<uint32_t>::max()) {
+			return;
+		}
+
+		const std::array<uint32_t, 2> widthHeight = getWidthHeightFromRenderTargets(renderTargets, m_swapchain, *m_ImageManager);
+		const auto width  = widthHeight[0];
+		const auto height = widthHeight[1];
+
+		const vk::RenderPass        renderpass      = m_PassManager->getVkPass(renderpassHandle);
+		const PassConfig            passConfig      = m_PassManager->getPassConfig(renderpassHandle);
+
+		const vk::Pipeline          pipeline        = m_PipelineManager->getVkPipeline(pipelineHandle);
+		const vk::PipelineLayout    pipelineLayout  = m_PipelineManager->getVkPipelineLayout(pipelineHandle);
+		const vk::Rect2D            renderArea(vk::Offset2D(0, 0), vk::Extent2D(width, height));
+
+		vk::CommandBuffer cmdBuffer = m_CommandStreamManager->getStreamCommandBuffer(cmdStreamHandle);
+		transitionRendertargetsToAttachmentLayout(renderTargets, *m_ImageManager, cmdBuffer);
+
+		const vk::Framebuffer framebuffer = createFramebuffer(renderTargets, *m_ImageManager, m_swapchain, renderpass, m_Context.m_Device);
+
+		if (!framebuffer) {
 			vkcv_log(LogLevel::ERROR, "Failed to create temporary framebuffer");
-            return;
-        }
+			return;
+		}
 
-        vk::Viewport dynamicViewport(
-        		0.0f, 0.0f,
-            	static_cast<float>(width), static_cast<float>(height),
-            0.0f, 1.0f
-		);
+		SubmitInfo submitInfo;
+		submitInfo.queueType = QueueType::Graphics;
+		submitInfo.signalSemaphores = { m_SyncResources.renderFinished };
+
+		auto submitFunction = [&](const vk::CommandBuffer& cmdBuffer) {
+
+			const std::vector<vk::ClearValue> clearValues = createAttachmentClearValues(passConfig.attachments);
 
-        vk::Rect2D dynamicScissor({0, 0}, {width, height});
+			const vk::RenderPassBeginInfo beginInfo(renderpass, framebuffer, renderArea, clearValues.size(), clearValues.data());
+			cmdBuffer.beginRenderPass(beginInfo, {}, {});
 
-		auto &bufferManager = m_BufferManager;
+			cmdBuffer.bindPipeline(vk::PipelineBindPoint::eGraphics, pipeline, {});
+
+			const PipelineConfig &pipeConfig = m_PipelineManager->getPipelineConfig(pipelineHandle);
+			if(pipeConfig.m_UseDynamicViewport)
+			{
+				recordDynamicViewport(cmdBuffer, width, height);
+			}
+
+			for (int i = 0; i < drawcalls.size(); i++) {
+				recordDrawcall(drawcalls[i], cmdBuffer, pipelineLayout, pushConstantData, i);
+			}
+
+			cmdBuffer.endRenderPass();
+		};
+
+		auto finishFunction = [framebuffer, this]()
+		{
+			m_Context.m_Device.destroy(framebuffer);
+		};
+
+		recordCommandsToStream(cmdStreamHandle, submitFunction, finishFunction);
+	}
+
+	void Core::recordMeshShaderDrawcalls(
+		const CommandStreamHandle                           cmdStreamHandle,
+		const PassHandle                                    renderpassHandle,
+		const PipelineHandle                                pipelineHandle,
+		const PushConstantData&                             pushConstantData,
+		const std::vector<MeshShaderDrawcall>&              drawcalls,
+		const std::vector<ImageHandle>&                     renderTargets) {
+
+		if (m_currentSwapchainImageIndex == std::numeric_limits<uint32_t>::max()) {
+			return;
+		}
+
+		const std::array<uint32_t, 2> widthHeight = getWidthHeightFromRenderTargets(renderTargets, m_swapchain, *m_ImageManager);
+		const auto width  = widthHeight[0];
+		const auto height = widthHeight[1];
+
+		const vk::RenderPass        renderpass = m_PassManager->getVkPass(renderpassHandle);
+		const PassConfig            passConfig = m_PassManager->getPassConfig(renderpassHandle);
+
+		const vk::Pipeline          pipeline = m_PipelineManager->getVkPipeline(pipelineHandle);
+		const vk::PipelineLayout    pipelineLayout = m_PipelineManager->getVkPipelineLayout(pipelineHandle);
+		const vk::Rect2D            renderArea(vk::Offset2D(0, 0), vk::Extent2D(width, height));
+
+		vk::CommandBuffer cmdBuffer = m_CommandStreamManager->getStreamCommandBuffer(cmdStreamHandle);
+		transitionRendertargetsToAttachmentLayout(renderTargets, *m_ImageManager, cmdBuffer);
+
+		const vk::Framebuffer framebuffer = createFramebuffer(renderTargets, *m_ImageManager, m_swapchain, renderpass, m_Context.m_Device);
+
+		if (!framebuffer) {
+			vkcv_log(LogLevel::ERROR, "Failed to create temporary framebuffer");
+			return;
+		}
 
 		SubmitInfo submitInfo;
 		submitInfo.queueType = QueueType::Graphics;
 		submitInfo.signalSemaphores = { m_SyncResources.renderFinished };
 
 		auto submitFunction = [&](const vk::CommandBuffer& cmdBuffer) {
-            std::vector<vk::ClearValue> clearValues;
-
-            for (const auto& attachment : passConfig.attachments) {
-                if (attachment.load_operation == AttachmentOperation::CLEAR) {
-                    float clear = 0.0f;
-
-                    if (isDepthFormat(attachment.format)) {
-                        clear = 1.0f;
-                    }
-
-                    clearValues.emplace_back(std::array<float, 4>{
-                            clear,
-                            clear,
-                            clear,
-                            1.f
-                    });
-                }
-            }
-
-            const vk::RenderPassBeginInfo beginInfo(renderpass, framebuffer, renderArea, clearValues.size(), clearValues.data());
-            const vk::SubpassContents subpassContents = {};
-            cmdBuffer.beginRenderPass(beginInfo, subpassContents, {});
-
-            cmdBuffer.bindPipeline(vk::PipelineBindPoint::eGraphics, pipeline, {});
-
-            const PipelineConfig &pipeConfig = m_PipelineManager->getPipelineConfig(pipelineHandle);
-            if(pipeConfig.m_UseDynamicViewport)
-            {
-                cmdBuffer.setViewport(0, 1, &dynamicViewport);
-                cmdBuffer.setScissor(0, 1, &dynamicScissor);
-            }
-
-            for (int i = 0; i < drawcalls.size(); i++) {
-                recordDrawcall(drawcalls[i], cmdBuffer, pipelineLayout, pushConstantData, i);
-            }
-
-            cmdBuffer.endRenderPass();
-        };
-
-        auto finishFunction = [framebuffer, this]()
-        {
-            m_Context.m_Device.destroy(framebuffer);
-        };
+
+			const std::vector<vk::ClearValue> clearValues = createAttachmentClearValues(passConfig.attachments);
+
+			const vk::RenderPassBeginInfo beginInfo(renderpass, framebuffer, renderArea, clearValues.size(), clearValues.data());
+			cmdBuffer.beginRenderPass(beginInfo, {}, {});
+
+			cmdBuffer.bindPipeline(vk::PipelineBindPoint::eGraphics, pipeline, {});
+
+			const PipelineConfig& pipeConfig = m_PipelineManager->getPipelineConfig(pipelineHandle);
+			if (pipeConfig.m_UseDynamicViewport)
+			{
+				recordDynamicViewport(cmdBuffer, width, height);
+			}
+
+			for (int i = 0; i < drawcalls.size(); i++) {
+                const uint32_t pushConstantOffset = i * pushConstantData.sizePerDrawcall;
+                recordMeshShaderDrawcall(
+                    cmdBuffer,
+                    pipelineLayout,
+                    pushConstantData,
+                    pushConstantOffset,
+                    drawcalls[i],
+                    0);
+			}
+
+			cmdBuffer.endRenderPass();
+		};
+
+		auto finishFunction = [framebuffer, this]()
+		{
+			m_Context.m_Device.destroy(framebuffer);
+		};
 
 		recordCommandsToStream(cmdStreamHandle, submitFunction, finishFunction);
 	}
diff --git a/src/vkcv/DrawcallRecording.cpp b/src/vkcv/DrawcallRecording.cpp
index df7b7bbc..18eb896b 100644
--- a/src/vkcv/DrawcallRecording.cpp
+++ b/src/vkcv/DrawcallRecording.cpp
@@ -42,4 +42,35 @@ namespace vkcv {
             cmdBuffer.draw(drawcall.mesh.indexCount, 1, 0, 0, {});
         }
     }
+
+    void recordMeshShaderDrawcall(
+        vk::CommandBuffer                       cmdBuffer,
+        vk::PipelineLayout                      pipelineLayout,
+        const PushConstantData&                 pushConstantData,
+        const uint32_t                          pushConstantOffset,
+        const MeshShaderDrawcall&               drawcall,
+        const uint32_t                          firstTask) {
+
+        for (const auto& descriptorUsage : drawcall.descriptorSets) {
+            cmdBuffer.bindDescriptorSets(
+                vk::PipelineBindPoint::eGraphics,
+                pipelineLayout,
+                descriptorUsage.setLocation,
+                descriptorUsage.vulkanHandle,
+                nullptr);
+        }
+
+        const size_t drawcallPushConstantOffset = pushConstantOffset;
+        // char* cast because void* does not support pointer arithmetic
+        const void* drawcallPushConstantData = drawcallPushConstantOffset + (char*)pushConstantData.data;
+
+        cmdBuffer.pushConstants(
+            pipelineLayout,
+            vk::ShaderStageFlagBits::eAll,
+            0,
+            pushConstantData.sizePerDrawcall,
+            drawcallPushConstantData);
+
+        cmdBuffer.drawMeshTasksNV(drawcall.taskCount, firstTask);
+    }
 }
\ No newline at end of file
-- 
GitLab