From ba27891a52c005e145c01c920cb46df9d248959c Mon Sep 17 00:00:00 2001
From: Alexander Gauggel <agauggel@uni-koblenz.de>
Date: Sat, 3 Jul 2021 12:37:18 +0200
Subject: [PATCH] [#87] Create empty mesh shader pipeline

---
 include/vkcv/Context.hpp                      |  11 +-
 include/vkcv/ShaderStage.hpp                  |   6 +-
 .../src/vkcv/shader/GLSLCompiler.cpp          |   4 +
 .../mesh_shader/resources/shaders/shader.mesh |  19 ++
 .../mesh_shader/resources/shaders/shader.task |  15 ++
 projects/mesh_shader/src/main.cpp             |  45 +++-
 src/vkcv/Context.cpp                          |  39 +++-
 src/vkcv/PipelineManager.cpp                  | 217 ++++++++++++------
 8 files changed, 271 insertions(+), 85 deletions(-)
 create mode 100644 projects/mesh_shader/resources/shaders/shader.mesh
 create mode 100644 projects/mesh_shader/resources/shaders/shader.task

diff --git a/include/vkcv/Context.hpp b/include/vkcv/Context.hpp
index 1c01a613..be693399 100644
--- a/include/vkcv/Context.hpp
+++ b/include/vkcv/Context.hpp
@@ -33,11 +33,12 @@ namespace vkcv
         [[nodiscard]]
         const QueueManager& getQueueManager() const;
         
-        static Context create(const char *applicationName,
-							  uint32_t applicationVersion,
-							  std::vector<vk::QueueFlagBits> queueFlags,
-							  std::vector<const char *> instanceExtensions,
-							  std::vector<const char *> deviceExtensions);
+        static Context create(
+			const char*                     applicationName,
+			uint32_t                        applicationVersion,
+			std::vector<vk::QueueFlagBits>  queueFlags,
+			std::vector<const char *>       instanceExtensions,
+			std::vector<const char *>       deviceExtensions);
 
     private:
         /**
diff --git a/include/vkcv/ShaderStage.hpp b/include/vkcv/ShaderStage.hpp
index dca395bd..3893bdf5 100644
--- a/include/vkcv/ShaderStage.hpp
+++ b/include/vkcv/ShaderStage.hpp
@@ -9,7 +9,11 @@ namespace vkcv {
 		TESS_EVAL,
 		GEOMETRY,
 		FRAGMENT,
-		COMPUTE
+		COMPUTE,
+		TASK,
+		MESH
 	};
 
+
+
 }
diff --git a/modules/shader_compiler/src/vkcv/shader/GLSLCompiler.cpp b/modules/shader_compiler/src/vkcv/shader/GLSLCompiler.cpp
index ec358188..4a351f41 100644
--- a/modules/shader_compiler/src/vkcv/shader/GLSLCompiler.cpp
+++ b/modules/shader_compiler/src/vkcv/shader/GLSLCompiler.cpp
@@ -50,6 +50,10 @@ namespace vkcv::shader {
 				return EShLangFragment;
 			case ShaderStage::COMPUTE:
 				return EShLangCompute;
+			case ShaderStage::TASK:
+				return EShLangTaskNV;
+			case ShaderStage::MESH:
+				return EShLangMeshNV;
 			default:
 				return EShLangCount;
 		}
diff --git a/projects/mesh_shader/resources/shaders/shader.mesh b/projects/mesh_shader/resources/shaders/shader.mesh
new file mode 100644
index 00000000..f4bddd60
--- /dev/null
+++ b/projects/mesh_shader/resources/shaders/shader.mesh
@@ -0,0 +1,19 @@
+#version 460
+#extension GL_ARB_separate_shader_objects   : enable
+#extension GL_NV_mesh_shader                : require
+
+layout(local_size_x=32) in;
+
+layout(triangles) out;
+layout(max_vertices=64, max_primitives=126) out;
+
+// out uint gl_PrimitiveCountNV;
+// out uint gl_PrimitiveIndicesNV[];
+
+out gl_MeshPerVertexNV {
+    vec4 gl_Position;
+} gl_MeshVerticesNV[];
+
+void main()	{
+    
+}
\ No newline at end of file
diff --git a/projects/mesh_shader/resources/shaders/shader.task b/projects/mesh_shader/resources/shaders/shader.task
new file mode 100644
index 00000000..3308320f
--- /dev/null
+++ b/projects/mesh_shader/resources/shaders/shader.task
@@ -0,0 +1,15 @@
+#version 460
+#extension GL_ARB_separate_shader_objects   : enable
+#extension GL_NV_mesh_shader                : require
+
+layout(local_size_x=32) in;
+
+taskNV out Task {
+  uint baseID;
+  uint subIDs[32];
+} OUT;
+
+
+void main() {
+
+}
\ No newline at end of file
diff --git a/projects/mesh_shader/src/main.cpp b/projects/mesh_shader/src/main.cpp
index 925a6308..e9f48dcd 100644
--- a/projects/mesh_shader/src/main.cpp
+++ b/projects/mesh_shader/src/main.cpp
@@ -25,7 +25,7 @@ int main(int argc, const char** argv) {
 		VK_MAKE_VERSION(0, 0, 1),
 		{ vk::QueueFlagBits::eTransfer,vk::QueueFlagBits::eGraphics, vk::QueueFlagBits::eCompute },
 		{},
-		{ "VK_KHR_swapchain" }
+		{ "VK_KHR_swapchain", VK_NV_MESH_SHADER_EXTENSION_NAME }
 	);
 	
 	vkcv::gui::GUI gui (core, window);
@@ -41,9 +41,9 @@ int main(int argc, const char** argv) {
 		core.getSwapchain().getFormat());
 
 	vkcv::PassConfig trianglePassDefinition({ present_color_attachment });
-	vkcv::PassHandle trianglePass = core.createPass(trianglePassDefinition);
+	vkcv::PassHandle renderPass = core.createPass(trianglePassDefinition);
 
-	if (!trianglePass)
+	if (!renderPass)
 	{
 		std::cout << "Error. Could not create renderpass. Exiting." << std::endl;
 		return EXIT_FAILURE;
@@ -66,7 +66,7 @@ int main(int argc, const char** argv) {
 		triangleShaderProgram,
 		(uint32_t)windowWidth,
 		(uint32_t)windowHeight,
-		trianglePass,
+		renderPass,
 		{},
 		{},
 		false
@@ -80,6 +80,41 @@ int main(int argc, const char** argv) {
 		return EXIT_FAILURE;
 	}
 
+	// mesh shader
+	vkcv::ShaderProgram meshShaderProgram;
+	compiler.compile(vkcv::ShaderStage::TASK, std::filesystem::path("resources/shaders/shader.task"),
+		[&meshShaderProgram](vkcv::ShaderStage shaderStage, const std::filesystem::path& path) {
+		meshShaderProgram.addShader(shaderStage, path);
+	});
+
+	compiler.compile(vkcv::ShaderStage::MESH, std::filesystem::path("resources/shaders/shader.mesh"),
+		[&meshShaderProgram](vkcv::ShaderStage shaderStage, const std::filesystem::path& path) {
+		meshShaderProgram.addShader(shaderStage, path);
+	});
+
+	compiler.compile(vkcv::ShaderStage::FRAGMENT, std::filesystem::path("resources/shaders/shader.frag"),
+		[&meshShaderProgram](vkcv::ShaderStage shaderStage, const std::filesystem::path& path) {
+		meshShaderProgram.addShader(shaderStage, path);
+	});
+
+	const vkcv::PipelineConfig meshShaderPipelineDefinition{
+		meshShaderProgram,
+		(uint32_t)windowWidth,
+		(uint32_t)windowHeight,
+		renderPass,
+		{},
+		{},
+		false
+	};
+
+	vkcv::PipelineHandle meshShaderPipeline = core.createGraphicsPipeline(meshShaderPipelineDefinition);
+
+	if (!meshShaderPipeline)
+	{
+		std::cout << "Error. Could not create mesh shader pipeline. Exiting." << std::endl;
+		return EXIT_FAILURE;
+	}
+
 	auto start = std::chrono::system_clock::now();
 
 	vkcv::ImageHandle swapchainImageHandle = vkcv::ImageHandle::createSwapchainImageHandle();
@@ -115,7 +150,7 @@ int main(int argc, const char** argv) {
 
 		core.recordDrawcallsToCmdStream(
 			cmdStream,
-			trianglePass,
+			renderPass,
 			trianglePipeline,
 			pushConstantData,
 			{ drawcall },
diff --git a/src/vkcv/Context.cpp b/src/vkcv/Context.cpp
index ac133d1a..49aea650 100644
--- a/src/vkcv/Context.cpp
+++ b/src/vkcv/Context.cpp
@@ -168,12 +168,22 @@ namespace vkcv
 		
 		return extensions;
 	}
-	
-	Context Context::create(const char *applicationName,
-							uint32_t applicationVersion,
-							std::vector<vk::QueueFlagBits> queueFlags,
-							std::vector<const char *> instanceExtensions,
-							std::vector<const char *> deviceExtensions) {
+
+	bool isPresentInCharPtrVector(const std::vector<const char*>& v, const char* term){
+		for (const auto& entry : v) {
+			if (strcmp(entry, term) != 0) {
+				return true;
+			}
+		}
+		return false;
+	}
+
+	Context Context::create(
+        const char* applicationName,
+        uint32_t                        applicationVersion,
+        std::vector<vk::QueueFlagBits>  queueFlags,
+        std::vector<const char*>        instanceExtensions,
+        std::vector<const char*>        deviceExtensions) {
 		// check for layer support
 		
 		const std::vector<vk::LayerProperties>& layerProperties = vk::enumerateInstanceLayerProperties();
@@ -277,10 +287,19 @@ namespace vkcv
 #endif
 
 		// FIXME: check if device feature is supported
-		vk::PhysicalDeviceFeatures deviceFeatures;
-		deviceFeatures.fragmentStoresAndAtomics = true;
-		deviceFeatures.geometryShader = true;
-		deviceCreateInfo.pEnabledFeatures = &deviceFeatures;
+		vk::PhysicalDeviceFeatures2 deviceFeatures;
+		deviceFeatures.features.fragmentStoresAndAtomics    = true;
+		deviceFeatures.features.geometryShader              = true;
+
+		const bool usingMeshShaders = isPresentInCharPtrVector(deviceExtensions, VK_NV_MESH_SHADER_EXTENSION_NAME);
+		vk::PhysicalDeviceMeshShaderFeaturesNV meshShading;
+		if (usingMeshShaders) {
+			meshShading.taskShader = true;
+			meshShading.meshShader = true;
+			deviceFeatures.pNext = &meshShading;
+		}
+
+		deviceCreateInfo.pNext = &deviceFeatures;
 
 		// Ablauf
 		// qCreateInfos erstellen --> braucht das Device
diff --git a/src/vkcv/PipelineManager.cpp b/src/vkcv/PipelineManager.cpp
index df36442e..fc77e37d 100644
--- a/src/vkcv/PipelineManager.cpp
+++ b/src/vkcv/PipelineManager.cpp
@@ -51,51 +51,143 @@ namespace vkcv
         }
     }
 
+	vk::ShaderStageFlagBits shaderStageToVkShaderStage(vkcv::ShaderStage stage) {
+		switch (stage) {
+		case vkcv::ShaderStage::VERTEX:         return vk::ShaderStageFlagBits::eVertex;
+		case vkcv::ShaderStage::FRAGMENT:       return vk::ShaderStageFlagBits::eFragment;
+		case vkcv::ShaderStage::GEOMETRY:       return vk::ShaderStageFlagBits::eGeometry;
+		case vkcv::ShaderStage::TESS_CONTROL:   return vk::ShaderStageFlagBits::eTessellationControl;
+		case vkcv::ShaderStage::TESS_EVAL:      return vk::ShaderStageFlagBits::eTessellationEvaluation;
+		case vkcv::ShaderStage::COMPUTE:        return vk::ShaderStageFlagBits::eCompute;
+		case vkcv::ShaderStage::TASK:           return vk::ShaderStageFlagBits::eTaskNV;
+		case vkcv::ShaderStage::MESH:           return vk::ShaderStageFlagBits::eMeshNV;
+		default: vkcv_log(vkcv::LogLevel::ERROR, "Unknown shader stage"); return vk::ShaderStageFlagBits::eAll;
+		}
+	}
+
+    bool createPipelineShaderStageCreateInfo(
+        const vkcv::ShaderProgram&          shaderProgram, 
+        ShaderStage                         stage,
+        vk::Device                          device,
+        vk::PipelineShaderStageCreateInfo*  outCreateInfo) {
+
+        assert(outCreateInfo);
+        std::vector<char>           code = shaderProgram.getShader(stage).shaderCode;
+        vk::ShaderModuleCreateInfo  vertexModuleInfo({}, code.size(), reinterpret_cast<uint32_t*>(code.data()));
+        vk::ShaderModule            shaderModule;
+        if (device.createShaderModule(&vertexModuleInfo, nullptr, &shaderModule) != vk::Result::eSuccess)
+            return false;
+
+        const static auto entryName = "main";
+
+        *outCreateInfo = vk::PipelineShaderStageCreateInfo(
+            {},
+            shaderStageToVkShaderStage(stage),
+            shaderModule,
+            entryName,
+            nullptr);
+        return true;
+    }
+
     PipelineHandle PipelineManager::createPipeline(const PipelineConfig &config, PassManager& passManager)
     {
 		const vk::RenderPass &pass = passManager.getVkPass(config.m_PassHandle);
     	
+		const bool existsTaskShader = config.m_ShaderProgram.existsShader(ShaderStage::TASK);
+		const bool existsMeshShader = config.m_ShaderProgram.existsShader(ShaderStage::MESH);
+
         const bool existsVertexShader = config.m_ShaderProgram.existsShader(ShaderStage::VERTEX);
+
+        const bool validGeometryStages = existsVertexShader || (existsTaskShader && existsMeshShader);
+
         const bool existsFragmentShader = config.m_ShaderProgram.existsShader(ShaderStage::FRAGMENT);
-        if (!(existsVertexShader && existsFragmentShader))
+        if (!validGeometryStages)
         {
-			vkcv_log(LogLevel::ERROR, "Requires vertex and fragment shader code");
+			vkcv_log(LogLevel::ERROR, "Requires vertex or task and mesh shader");
             return PipelineHandle();
         }
-
-        // vertex shader stage
-        std::vector<char> vertexCode = config.m_ShaderProgram.getShader(ShaderStage::VERTEX).shaderCode;
-        vk::ShaderModuleCreateInfo vertexModuleInfo({}, vertexCode.size(), reinterpret_cast<uint32_t*>(vertexCode.data()));
-        vk::ShaderModule vertexModule{};
-        if (m_Device.createShaderModule(&vertexModuleInfo, nullptr, &vertexModule) != vk::Result::eSuccess)
+        if (!existsFragmentShader) {
+            vkcv_log(LogLevel::ERROR, "Requires fragment shader code");
             return PipelineHandle();
+        }
 
-        vk::PipelineShaderStageCreateInfo pipelineVertexShaderStageInfo(
-                {},
-                vk::ShaderStageFlagBits::eVertex,
-                vertexModule,
-                "main",
-                nullptr
-        );
+        std::vector<vk::PipelineShaderStageCreateInfo> shaderStages;
+        auto destroyShaderModules = [&shaderStages, this] {
+            for (auto stage : shaderStages) {
+                m_Device.destroyShaderModule(stage.module);
+            }
+            shaderStages.clear();
+        };
+
+        if (existsVertexShader) {
+            vk::PipelineShaderStageCreateInfo createInfo;
+            const bool success = createPipelineShaderStageCreateInfo(
+                config.m_ShaderProgram, 
+                vkcv::ShaderStage::VERTEX, 
+                m_Device, 
+                &createInfo);
+
+            if (success) {
+                shaderStages.push_back(createInfo);
+            }
+            else {
+                destroyShaderModules();
+                return PipelineHandle();
+            }
+        }
+
+        if (existsTaskShader) {
+            vk::PipelineShaderStageCreateInfo createInfo;
+            const bool success = createPipelineShaderStageCreateInfo(
+                config.m_ShaderProgram,
+                vkcv::ShaderStage::TASK,
+                m_Device,
+                &createInfo);
+
+            if (success) {
+                shaderStages.push_back(createInfo);
+            }
+            else {
+                destroyShaderModules();
+                return PipelineHandle();
+            }
+        }
+
+        if (existsMeshShader) {
+            vk::PipelineShaderStageCreateInfo createInfo;
+            const bool success = createPipelineShaderStageCreateInfo(
+                config.m_ShaderProgram,
+                vkcv::ShaderStage::MESH,
+                m_Device,
+                &createInfo);
+
+            if (success) {
+                shaderStages.push_back(createInfo);
+            }
+            else {
+                destroyShaderModules();
+                return PipelineHandle();
+            }
+        }
 
         // fragment shader stage
-        std::vector<char> fragCode = config.m_ShaderProgram.getShader(ShaderStage::FRAGMENT).shaderCode;
-        vk::ShaderModuleCreateInfo fragmentModuleInfo({}, fragCode.size(), reinterpret_cast<uint32_t*>(fragCode.data()));
-        vk::ShaderModule fragmentModule{};
-        if (m_Device.createShaderModule(&fragmentModuleInfo, nullptr, &fragmentModule) != vk::Result::eSuccess)
         {
-            m_Device.destroy(vertexModule);
-            return PipelineHandle();
+            vk::PipelineShaderStageCreateInfo createInfo;
+            const bool success = createPipelineShaderStageCreateInfo(
+                config.m_ShaderProgram,
+                vkcv::ShaderStage::FRAGMENT,
+                m_Device,
+                &createInfo);
+
+            if (success) {
+                shaderStages.push_back(createInfo);
+            }
+            else {
+                destroyShaderModules();
+                return PipelineHandle();
+            }
         }
 
-        vk::PipelineShaderStageCreateInfo pipelineFragmentShaderStageInfo(
-                {},
-                vk::ShaderStageFlagBits::eFragment,
-                fragmentModule,
-                "main",
-                nullptr
-        );
-
         // vertex input state
 
         // Fill up VertexInputBindingDescription and VertexInputAttributeDescription Containers
@@ -201,20 +293,23 @@ namespace vkcv
                 { 1.f,1.f,1.f,1.f }
         );
 
-		const size_t matrixPushConstantSize = config.m_ShaderProgram.getPushConstantSize();
-		const vk::PushConstantRange pushConstantRange(vk::ShaderStageFlagBits::eAll, 0, matrixPushConstantSize);
+		const size_t pushConstantSize = config.m_ShaderProgram.getPushConstantSize();
+		const vk::PushConstantRange pushConstantRange(vk::ShaderStageFlagBits::eAll, 0, pushConstantSize);
 
         // pipeline layout
         vk::PipelineLayoutCreateInfo pipelineLayoutCreateInfo(
 			{},
 			(config.m_DescriptorLayouts),
-			(pushConstantRange));
+			pushConstantRange);
+
+		if (pushConstantSize <= 0) {
+			pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
+		}
 
         vk::PipelineLayout vkPipelineLayout{};
         if (m_Device.createPipelineLayout(&pipelineLayoutCreateInfo, nullptr, &vkPipelineLayout) != vk::Result::eSuccess)
         {
-            m_Device.destroy(vertexModule);
-            m_Device.destroy(fragmentModule);
+            destroyShaderModules();
             return PipelineHandle();
         }
 	
@@ -249,25 +344,28 @@ namespace vkcv
 		    dynamicStates.push_back(vk::DynamicState::eScissor);
         }
 
-        vk::PipelineDynamicStateCreateInfo dynamicStateCreateInfo({},
-                                                            static_cast<uint32_t>(dynamicStates.size()),
-                                                            dynamicStates.data());
-
-        // graphics pipeline create
-        std::vector<vk::PipelineShaderStageCreateInfo> shaderStages = { pipelineVertexShaderStageInfo, pipelineFragmentShaderStageInfo };
-
-		const char *geometryShaderName = "main";	// outside of if to make sure it stays in scope
-		vk::ShaderModule geometryModule;
-		if (config.m_ShaderProgram.existsShader(ShaderStage::GEOMETRY)) {
-			const vkcv::Shader geometryShader = config.m_ShaderProgram.getShader(ShaderStage::GEOMETRY);
-			const auto& geometryCode = geometryShader.shaderCode;
-			const vk::ShaderModuleCreateInfo geometryModuleInfo({}, geometryCode.size(), reinterpret_cast<const uint32_t*>(geometryCode.data()));
-			if (m_Device.createShaderModule(&geometryModuleInfo, nullptr, &geometryModule) != vk::Result::eSuccess) {
-				return PipelineHandle();
-			}
-			vk::PipelineShaderStageCreateInfo geometryStage({}, vk::ShaderStageFlagBits::eGeometry, geometryModule, geometryShaderName);
-			shaderStages.push_back(geometryStage);
-		}
+        vk::PipelineDynamicStateCreateInfo dynamicStateCreateInfo(
+            {},
+            static_cast<uint32_t>(dynamicStates.size()),
+            dynamicStates.data());
+
+        const bool existsGeometryShader = config.m_ShaderProgram.existsShader(vkcv::ShaderStage::GEOMETRY);
+        if (existsGeometryShader) {
+            vk::PipelineShaderStageCreateInfo createInfo;
+            const bool success = createPipelineShaderStageCreateInfo(
+                config.m_ShaderProgram,
+                vkcv::ShaderStage::GEOMETRY,
+                m_Device,
+                &createInfo);
+
+            if (success) {
+                shaderStages.push_back(createInfo);
+            }
+            else {
+                destroyShaderModules();
+                return PipelineHandle();
+            }
+        }
 
         const vk::GraphicsPipelineCreateInfo graphicsPipelineCreateInfo(
                 {},
@@ -292,20 +390,11 @@ namespace vkcv
         vk::Pipeline vkPipeline{};
         if (m_Device.createGraphicsPipelines(nullptr, 1, &graphicsPipelineCreateInfo, nullptr, &vkPipeline) != vk::Result::eSuccess)
         {
-            m_Device.destroy(vertexModule);
-            m_Device.destroy(fragmentModule);
-            if (geometryModule) {
-                m_Device.destroy(geometryModule);
-            }
-            m_Device.destroy();
+            destroyShaderModules();
             return PipelineHandle();
         }
 
-        m_Device.destroy(vertexModule);
-        m_Device.destroy(fragmentModule);
-        if (geometryModule) {
-            m_Device.destroy(geometryModule);
-        }
+        destroyShaderModules();
         
         const uint64_t id = m_Pipelines.size();
         m_Pipelines.push_back({ vkPipeline, vkPipelineLayout, config });
-- 
GitLab