From df985b6d67526b13481603716d07aaccb46b9289 Mon Sep 17 00:00:00 2001
From: Tobias Frisch <tfrisch@uni-koblenz.de>
Date: Tue, 24 Jan 2023 23:37:20 +0100
Subject: [PATCH] Adjusted usage of mesh shader extension

Signed-off-by: Tobias Frisch <tfrisch@uni-koblenz.de>
---
 include/vkcv/Drawcall.hpp                     |  7 +++--
 .../include/vkcv/shader/GLSLCompiler.hpp      |  1 +
 .../src/vkcv/shader/GLSLCompiler.cpp          | 12 ++++++++
 .../mesh_shader/assets/shaders/shader.mesh    | 30 +++++++++----------
 .../mesh_shader/assets/shaders/shader.task    | 25 ++++++++--------
 projects/mesh_shader/src/main.cpp             |  4 +--
 src/vkcv/Core.cpp                             | 18 +++++++----
 src/vkcv/Drawcall.cpp                         |  8 ++---
 src/vkcv/GraphicsPipelineManager.cpp          | 24 ++++++++++-----
 src/vkcv/SwapchainManager.cpp                 |  6 ++++
 src/vkcv/WindowManager.cpp                    |  5 ++++
 11 files changed, 91 insertions(+), 49 deletions(-)

diff --git a/include/vkcv/Drawcall.hpp b/include/vkcv/Drawcall.hpp
index 08e436ae..5b229814 100644
--- a/include/vkcv/Drawcall.hpp
+++ b/include/vkcv/Drawcall.hpp
@@ -8,6 +8,7 @@
 #include <vector>
 
 #include "DescriptorSetUsage.hpp"
+#include "DispatchSize.hpp"
 #include "Handles.hpp"
 #include "VertexData.hpp"
 
@@ -78,12 +79,12 @@ namespace vkcv {
 	 */
 	class TaskDrawcall : public Drawcall {
 	private:
-		uint32_t m_taskCount;
+		DispatchSize m_taskSize;
 
 	public:
-		explicit TaskDrawcall(uint32_t taskCount = 1);
+		explicit TaskDrawcall(const DispatchSize& taskSize = DispatchSize(1));
 
-		[[nodiscard]] uint32_t getTaskCount() const;
+		[[nodiscard]] DispatchSize getTaskSize() const;
 	};
 
 } // namespace vkcv
diff --git a/modules/shader_compiler/include/vkcv/shader/GLSLCompiler.hpp b/modules/shader_compiler/include/vkcv/shader/GLSLCompiler.hpp
index cebc8bb7..e3fa7b85 100644
--- a/modules/shader_compiler/include/vkcv/shader/GLSLCompiler.hpp
+++ b/modules/shader_compiler/include/vkcv/shader/GLSLCompiler.hpp
@@ -15,6 +15,7 @@ namespace vkcv::shader {
 	enum class GLSLCompileTarget {
 		SUBGROUP_OP,
 		RAY_TRACING,
+		MESH_SHADING,
 		
 		UNKNOWN
 	};
diff --git a/modules/shader_compiler/src/vkcv/shader/GLSLCompiler.cpp b/modules/shader_compiler/src/vkcv/shader/GLSLCompiler.cpp
index da515315..e469152f 100644
--- a/modules/shader_compiler/src/vkcv/shader/GLSLCompiler.cpp
+++ b/modules/shader_compiler/src/vkcv/shader/GLSLCompiler.cpp
@@ -132,6 +132,14 @@ namespace vkcv::shader {
 		resources.maxCullDistances = 8;
 		resources.maxCombinedClipAndCullDistances = 8;
 		resources.maxSamples = 4;
+		resources.maxMeshOutputVerticesNV = 256;
+		resources.maxMeshOutputPrimitivesNV = 512;
+		resources.maxMeshWorkGroupSizeX_NV = 32;
+		resources.maxMeshWorkGroupSizeY_NV = 1;
+		resources.maxMeshWorkGroupSizeZ_NV = 1;
+		resources.maxTaskWorkGroupSizeX_NV = 32;
+		resources.maxTaskWorkGroupSizeY_NV = 1;
+		resources.maxTaskWorkGroupSizeZ_NV = 1;
 		resources.maxMeshOutputVerticesEXT = 256;
 		resources.maxMeshOutputPrimitivesEXT = 512;
 		resources.maxMeshWorkGroupSizeX_EXT = 32;
@@ -173,6 +181,10 @@ namespace vkcv::shader {
 				shader.setEnvClient(glslang::EShClientVulkan, glslang::EShTargetVulkan_1_2);
 				shader.setEnvTarget(glslang::EShTargetSpv, glslang::EShTargetSpv_1_4);
 				break;
+			case GLSLCompileTarget::MESH_SHADING:
+				shader.setEnvClient(glslang::EShClientVulkan, glslang::EShTargetVulkan_1_1);
+				shader.setEnvTarget(glslang::EShTargetSpv, glslang::EShTargetSpv_1_4);
+				break;
 			default:
 				break;
 		}
diff --git a/projects/mesh_shader/assets/shaders/shader.mesh b/projects/mesh_shader/assets/shaders/shader.mesh
index 30c98610..a6b0cb62 100644
--- a/projects/mesh_shader/assets/shaders/shader.mesh
+++ b/projects/mesh_shader/assets/shaders/shader.mesh
@@ -1,11 +1,11 @@
 #version 460
 #extension GL_ARB_separate_shader_objects   : enable
 #extension GL_GOOGLE_include_directive      : enable
-#extension GL_NV_mesh_shader                : require
+#extension GL_EXT_mesh_shader               : require
 
 #include "meshlet.inc"
 
-layout(local_size_x=32) in;
+layout(local_size_x=32, local_size_y=1, local_size_z=1) in;
 
 layout(triangles) out;
 layout(max_vertices=64, max_primitives=126) out;
@@ -34,13 +34,14 @@ layout(std430, binding = 2) readonly buffer meshletBuffer
     Meshlet meshlets[];
 };
 
-taskNV in Task {
-  uint meshletIndices[32];
-  mat4 mvp;
-} IN;
+struct Task {
+    uint meshletIndices[32];
+    mat4 mvp;
+};
+
+taskPayloadSharedEXT Task IN;
 
 void main()	{
-    
     uint meshletIndex = IN.meshletIndices[gl_WorkGroupID.x];
     Meshlet meshlet = meshlets[meshletIndex];
     
@@ -55,24 +56,23 @@ void main()	{
         uint vertexIndex    = meshlet.vertexOffset + workIndex;
         Vertex vertex       = vertices[vertexIndex];
     
-        gl_MeshVerticesNV[workIndex].gl_Position    = IN.mvp * vec4(vertex.position, 1);
+        gl_MeshVerticesEXT[workIndex].gl_Position    = IN.mvp * vec4(vertex.position, 1);
         passNormal[workIndex]                       = vertex.normal;
         passTaskIndex[workIndex]                    = meshletIndex;
     }
     
     // set local indices
-    for(uint i = 0; i < 12; i++){
-    
+    for (uint i = 0; i < 12; i++) {
         uint workIndex = gl_LocalInvocationID.x + i * 32;
         if(workIndex >= meshlet.indexCount){
             break;
-        }    
+        }
         
-        uint indexBufferIndex               = meshlet.indexOffset + workIndex;
-        gl_PrimitiveIndicesNV[workIndex]    = localIndices[indexBufferIndex];
+        uint indexBufferIndex = meshlet.indexOffset + workIndex;
+        gl_PrimitiveTriangleIndicesEXT[workIndex] = uvec3(localIndices[indexBufferIndex]);
     }
     
-    if(gl_LocalInvocationID.x == 0){
-        gl_PrimitiveCountNV = meshlet.indexCount / 3;
+    if (gl_LocalInvocationID.x == 0) {
+        SetMeshOutputsEXT(64, meshlet.indexCount / 3);
     }
 }
\ No newline at end of file
diff --git a/projects/mesh_shader/assets/shaders/shader.task b/projects/mesh_shader/assets/shaders/shader.task
index 7a692e98..a6b1ad61 100644
--- a/projects/mesh_shader/assets/shaders/shader.task
+++ b/projects/mesh_shader/assets/shaders/shader.task
@@ -1,17 +1,19 @@
 #version 460
 #extension GL_ARB_separate_shader_objects   : enable
-#extension GL_NV_mesh_shader                : require
+#extension GL_EXT_mesh_shader               : require
 #extension GL_GOOGLE_include_directive      : enable
 
 #include "meshlet.inc"
 #include "common.inc"
 
-layout(local_size_x=32) in;
+layout(local_size_x=32, local_size_y=1, local_size_z=1) in;
 
-taskNV out Task {
-  uint meshletIndices[32];
-  mat4 mvp;
-} OUT;
+struct Task {
+    uint meshletIndices[32];
+    mat4 mvp;
+};
+
+taskPayloadSharedEXT Task OUT;
 
 layout( push_constant ) uniform constants{
     uint matrixIndex;
@@ -52,27 +54,26 @@ bool isSphereInsideFrustum(vec3 spherePos, float sphereRadius, Plane cameraPlane
 }
 
 void main() {
-
-    if(gl_LocalInvocationID.x >= meshletCount){
+    if (gl_LocalInvocationID.x >= meshletCount) {
         return;
     }
     
     uint meshletIndex   = gl_GlobalInvocationID.x;
     Meshlet meshlet     = meshlets[meshletIndex]; 
     
-    if(gl_LocalInvocationID.x == 0){
+    if (gl_LocalInvocationID.x == 0) {
         taskCount = 0;
     }
     
     // TODO: scaling support
     vec3 meshletPositionWorld = (vec4(meshlet.meanPosition, 1) * objectMatrices[matrixIndex].model).xyz;
-    if(isSphereInsideFrustum(meshletPositionWorld, meshlet.boundingSphereRadius, cameraPlanes)){
+    if (isSphereInsideFrustum(meshletPositionWorld, meshlet.boundingSphereRadius, cameraPlanes)) {
         uint outIndex = atomicAdd(taskCount, 1);
         OUT.meshletIndices[outIndex] = gl_GlobalInvocationID.x;
     }
 
-    if(gl_LocalInvocationID.x == 0){
-        gl_TaskCountNV              = taskCount;
+    if (gl_LocalInvocationID.x == 0) {
         OUT.mvp = objectMatrices[matrixIndex].mvp;
+        EmitMeshTasksEXT(taskCount, 1, 1);
     }
 }
\ No newline at end of file
diff --git a/projects/mesh_shader/src/main.cpp b/projects/mesh_shader/src/main.cpp
index df2a0ec0..8afbbb26 100644
--- a/projects/mesh_shader/src/main.cpp
+++ b/projects/mesh_shader/src/main.cpp
@@ -81,7 +81,7 @@ int main(int argc, const char** argv) {
 	const std::string applicationName = "Mesh shader";
 	
 	vkcv::Features features;
-	features.requireExtension(VK_EXT_MESH_SHADER_EXTENSION_NAME);
+	features.requireExtension(VK_KHR_SWAPCHAIN_EXTENSION_NAME);
 	features.requireExtensionFeature<vk::PhysicalDeviceMeshShaderFeaturesEXT>(
 			VK_EXT_MESH_SHADER_EXTENSION_NAME,
 			[](vk::PhysicalDeviceMeshShaderFeaturesEXT& features) {
@@ -199,7 +199,7 @@ int main(int argc, const char** argv) {
 	}
 
 	vkcv::ShaderProgram bunnyShaderProgram{};
-	vkcv::shader::GLSLCompiler compiler;
+	vkcv::shader::GLSLCompiler compiler (vkcv::shader::GLSLCompileTarget::MESH_SHADING);
 	
 	compiler.compile(vkcv::ShaderStage::VERTEX, std::filesystem::path("assets/shaders/shader.vert"),
 					 [&bunnyShaderProgram](vkcv::ShaderStage shaderStage, const std::filesystem::path& path) {
diff --git a/src/vkcv/Core.cpp b/src/vkcv/Core.cpp
index d28f5669..d31a9ff3 100644
--- a/src/vkcv/Core.cpp
+++ b/src/vkcv/Core.cpp
@@ -598,9 +598,9 @@ namespace vkcv {
 										 const PushConstants &pushConstantData,
 										 size_t drawcallIndex, const TaskDrawcall &drawcall) {
 
-		static PFN_vkCmdDrawMeshTasksNV cmdDrawMeshTasks =
-			reinterpret_cast<PFN_vkCmdDrawMeshTasksNV>(
-				core.getContext().getDevice().getProcAddr("vkCmdDrawMeshTasksNV"));
+		static PFN_vkCmdDrawMeshTasksEXT cmdDrawMeshTasks =
+			reinterpret_cast<PFN_vkCmdDrawMeshTasksEXT>(
+				core.getContext().getDevice().getProcAddr("vkCmdDrawMeshTasksEXT"));
 
 		if (!cmdDrawMeshTasks) {
 			vkcv_log(LogLevel::ERROR, "Mesh shader drawcalls are not supported");
@@ -619,8 +619,14 @@ namespace vkcv {
 									pushConstantData.getSizePerDrawcall(),
 									pushConstantData.getDrawcallData(drawcallIndex));
 		}
-
-		cmdDrawMeshTasks(VkCommandBuffer(cmdBuffer), drawcall.getTaskCount(), 0);
+		
+		const auto& groupSize = drawcall.getTaskSize();
+		cmdDrawMeshTasks(
+				VkCommandBuffer(cmdBuffer),
+				groupSize.x(),
+				groupSize.y(),
+				groupSize.z()
+		);
 	}
 
 	void Core::recordMeshShaderDrawcalls(const CommandStreamHandle &cmdStreamHandle,
@@ -640,7 +646,7 @@ namespace vkcv {
 		auto recordFunction = [&](const vk::CommandBuffer &cmdBuffer) {
 			for (size_t i = 0; i < drawcalls.size(); i++) {
 				recordMeshShaderDrawcall(*this, *m_DescriptorSetManager, cmdBuffer, pipelineLayout,
-										 pushConstantData, i, drawcalls [i]);
+										 pushConstantData, i, drawcalls[i]);
 			}
 		};
 
diff --git a/src/vkcv/Drawcall.cpp b/src/vkcv/Drawcall.cpp
index b45cb6ec..73142b55 100644
--- a/src/vkcv/Drawcall.cpp
+++ b/src/vkcv/Drawcall.cpp
@@ -55,10 +55,10 @@ namespace vkcv {
 		return m_drawCount;
 	}
 
-	TaskDrawcall::TaskDrawcall(uint32_t taskCount) : Drawcall(), m_taskCount(taskCount) {}
-
-	uint32_t TaskDrawcall::getTaskCount() const {
-		return m_taskCount;
+	TaskDrawcall::TaskDrawcall(const DispatchSize& taskSize) : Drawcall(), m_taskSize(taskSize) {}
+	
+	DispatchSize TaskDrawcall::getTaskSize() const {
+		return m_taskSize;
 	}
 
 } // namespace vkcv
diff --git a/src/vkcv/GraphicsPipelineManager.cpp b/src/vkcv/GraphicsPipelineManager.cpp
index 4ac862e4..8747a5e9 100644
--- a/src/vkcv/GraphicsPipelineManager.cpp
+++ b/src/vkcv/GraphicsPipelineManager.cpp
@@ -115,9 +115,9 @@ namespace vkcv {
 		case ShaderStage::COMPUTE:
 			return vk::ShaderStageFlagBits::eCompute;
 		case ShaderStage::TASK:
-			return vk::ShaderStageFlagBits::eTaskNV;
+			return vk::ShaderStageFlagBits::eTaskEXT;
 		case ShaderStage::MESH:
-			return vk::ShaderStageFlagBits::eMeshNV;
+			return vk::ShaderStageFlagBits::eMeshEXT;
 		default:
 			vkcv_log(LogLevel::ERROR, "Unknown shader stage");
 			return vk::ShaderStageFlagBits::eAll;
@@ -641,11 +641,21 @@ namespace vkcv {
 		// Get all setting structs together and create the Pipeline
 		const vk::GraphicsPipelineCreateInfo graphicsPipelineCreateInfo(
 			{}, static_cast<uint32_t>(shaderStages.size()), shaderStages.data(),
-			&pipelineVertexInputStateCreateInfo, &pipelineInputAssemblyStateCreateInfo,
-			&pipelineTessellationStateCreateInfo, &pipelineViewportStateCreateInfo,
-			&pipelineRasterizationStateCreateInfo, &pipelineMultisampleStateCreateInfo,
-			p_depthStencilCreateInfo, &pipelineColorBlendStateCreateInfo, &dynamicStateCreateInfo,
-			vkPipelineLayout, pass, 0, {}, 0);
+			&pipelineVertexInputStateCreateInfo,
+			&pipelineInputAssemblyStateCreateInfo,
+			&pipelineTessellationStateCreateInfo,
+			&pipelineViewportStateCreateInfo,
+			&pipelineRasterizationStateCreateInfo,
+			&pipelineMultisampleStateCreateInfo,
+			p_depthStencilCreateInfo,
+			&pipelineColorBlendStateCreateInfo,
+			&dynamicStateCreateInfo,
+			vkPipelineLayout,
+			pass,
+			0,
+			{},
+			0
+		);
 
 		vk::Pipeline vkPipeline {};
 		if (getCore().getContext().getDevice().createGraphicsPipelines(
diff --git a/src/vkcv/SwapchainManager.cpp b/src/vkcv/SwapchainManager.cpp
index 27efa503..7586f9ad 100644
--- a/src/vkcv/SwapchainManager.cpp
+++ b/src/vkcv/SwapchainManager.cpp
@@ -175,6 +175,12 @@ namespace vkcv {
 
 	static bool createVulkanSwapchain(const Context &context, const Window &window,
 									  SwapchainEntry &entry) {
+		if (!context.getFeatureManager().isExtensionActive(VK_KHR_SWAPCHAIN_EXTENSION_NAME)) {
+			vkcv_log(LogLevel::WARNING, "Extension required to create a swapchain: '%s'",
+					 VK_KHR_SWAPCHAIN_EXTENSION_NAME);
+			return false;
+		}
+		
 		const vk::PhysicalDevice &physicalDevice = context.getPhysicalDevice();
 		const vk::Device &device = context.getDevice();
 
diff --git a/src/vkcv/WindowManager.cpp b/src/vkcv/WindowManager.cpp
index 058ee219..077b9ec3 100644
--- a/src/vkcv/WindowManager.cpp
+++ b/src/vkcv/WindowManager.cpp
@@ -33,6 +33,11 @@ namespace vkcv {
 								 static_cast<int>(windowHeight), resizeable);
 
 		SwapchainHandle swapchainHandle = swapchainManager.createSwapchain(*window);
+		
+		if (!swapchainHandle) {
+			delete window;
+			return {};
+		}
 
 		if (resizeable) {
 			const event_handle<int, int> &resizeHandle =
-- 
GitLab