From c420f75a09bedfb9c72de943b7df624e9404cad0 Mon Sep 17 00:00:00 2001
From: Simeon Hermann <shermann04@uni-koblenz.de>
Date: Wed, 30 Jun 2021 19:21:49 +0200
Subject: [PATCH] [#76] replace vkcv shader flags by vk flags for first_mesh to
 run

---
 include/vkcv/DescriptorConfig.hpp |  4 ++--
 include/vkcv/ShaderProgram.hpp    | 13 ++++++-------
 include/vkcv/ShaderStage.hpp      | 11 -----------
 projects/first_mesh/src/main.cpp  |  4 ++--
 src/vkcv/DescriptorConfig.cpp     |  4 ++--
 src/vkcv/DescriptorManager.cpp    | 22 +---------------------
 src/vkcv/DescriptorManager.hpp    |  6 ------
 src/vkcv/PipelineManager.cpp      | 16 ++++++++--------
 src/vkcv/PipelineManager.hpp      |  2 +-
 src/vkcv/ShaderProgram.cpp        | 10 +++++-----
 10 files changed, 27 insertions(+), 65 deletions(-)

diff --git a/include/vkcv/DescriptorConfig.hpp b/include/vkcv/DescriptorConfig.hpp
index c6d0dfd1..29dc8124 100644
--- a/include/vkcv/DescriptorConfig.hpp
+++ b/include/vkcv/DescriptorConfig.hpp
@@ -38,12 +38,12 @@ namespace vkcv
             uint32_t bindingID,
             DescriptorType descriptorType,
             uint32_t descriptorCount,
-            ShaderStage shaderStage
+            vk::ShaderStageFlags shaderStages
         ) noexcept;
         
         uint32_t bindingID;
         DescriptorType descriptorType;
         uint32_t descriptorCount;
-        ShaderStage shaderStage;
+        vk::ShaderStageFlags shaderStages;
     };
 }
diff --git a/include/vkcv/ShaderProgram.hpp b/include/vkcv/ShaderProgram.hpp
index 78b1f021..707d72a8 100644
--- a/include/vkcv/ShaderProgram.hpp
+++ b/include/vkcv/ShaderProgram.hpp
@@ -13,7 +13,6 @@
 #include <vulkan/vulkan.hpp>
 #include <spirv_cross.hpp>
 #include "VertexLayout.hpp"
-#include "ShaderStage.hpp"
 #include "DescriptorConfig.hpp"
 
 namespace vkcv {
@@ -21,7 +20,7 @@ namespace vkcv {
     struct Shader
     {
         std::vector<char> shaderCode;
-        ShaderStage shaderStage;
+        vk::ShaderStageFlagBits shaderStage;
     };
 
 	class ShaderProgram
@@ -37,16 +36,16 @@ namespace vkcv {
         * @param[in] flag that signals the respective shaderStage (e.g. VK_SHADER_STAGE_VERTEX_BIT)
         * @param[in] relative path to the shader code (e.g. "../../../../../shaders/vert.spv")
         */
-        bool addShader(ShaderStage shaderStage, const std::filesystem::path &shaderPath);
+        bool addShader(vk::ShaderStageFlagBits shaderStage, const std::filesystem::path &shaderPath);
 
         /**
         * Returns the shader program's shader of the specified shader.
         * Needed for the transfer to the pipeline.
         * @return Shader object consisting of buffer with shader code and shader stage enum
         */
-        const Shader &getShader(ShaderStage shaderStage) const;
+        const Shader &getShader(vk::ShaderStageFlagBits shaderStage) const;
 
-        bool existsShader(ShaderStage shaderStage) const;
+        bool existsShader(vk::ShaderStageFlagBits shaderStage) const;
 
         const std::vector<VertexAttachment> &getVertexAttachments() const;
 		size_t getPushConstantSize() const;
@@ -59,9 +58,9 @@ namespace vkcv {
 	     * Fills vertex input attachments and descriptor sets (if present).
 	     * @param shaderStage the stage to reflect data from
 	     */
-        void reflectShader(ShaderStage shaderStage);
+        void reflectShader(vk::ShaderStageFlagBits shaderStage);
 
-        std::unordered_map<ShaderStage, Shader> m_Shaders;
+        std::unordered_map<vk::ShaderStageFlagBits, Shader> m_Shaders;
 
         // contains all vertex input attachments used in the vertex buffer
         std::vector<VertexAttachment> m_VertexAttachments;
diff --git a/include/vkcv/ShaderStage.hpp b/include/vkcv/ShaderStage.hpp
index a52db5d8..488fd92f 100644
--- a/include/vkcv/ShaderStage.hpp
+++ b/include/vkcv/ShaderStage.hpp
@@ -2,15 +2,4 @@
 
 namespace vkcv {
 
-	enum class ShaderStage : VkShaderStageFlags
-	{
-		VERTEX = VK_SHADER_STAGE_VERTEX_BIT,
-		TESS_CONTROL = VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
-		TESS_EVAL = VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
-		GEOMETRY = VK_SHADER_STAGE_GEOMETRY_BIT,
-		FRAGMENT = VK_SHADER_STAGE_FRAGMENT_BIT,
-		COMPUTE = VK_SHADER_STAGE_COMPUTE_BIT
-	};
-
-	using ShaderStages = vk::Flags<ShaderStage>;
 }
diff --git a/projects/first_mesh/src/main.cpp b/projects/first_mesh/src/main.cpp
index dc43c905..fabea47f 100644
--- a/projects/first_mesh/src/main.cpp
+++ b/projects/first_mesh/src/main.cpp
@@ -79,8 +79,8 @@ int main(int argc, const char** argv) {
 	}
 
 	vkcv::ShaderProgram firstMeshProgram{};
-    firstMeshProgram.addShader(vkcv::ShaderStage::VERTEX, std::filesystem::path("resources/shaders/vert.spv"));
-    firstMeshProgram.addShader(vkcv::ShaderStage::FRAGMENT, std::filesystem::path("resources/shaders/frag.spv"));
+    firstMeshProgram.addShader(vk::ShaderStageFlagBits::eVertex, std::filesystem::path("resources/shaders/vert.spv"));
+    firstMeshProgram.addShader(vk::ShaderStageFlagBits::eFragment, std::filesystem::path("resources/shaders/frag.spv"));
 	
 	auto& attributes = mesh.vertexGroups[0].vertexBuffer.attributes;
 	
diff --git a/src/vkcv/DescriptorConfig.cpp b/src/vkcv/DescriptorConfig.cpp
index 54e879ac..1fc075ce 100644
--- a/src/vkcv/DescriptorConfig.cpp
+++ b/src/vkcv/DescriptorConfig.cpp
@@ -5,11 +5,11 @@ namespace vkcv {
 		uint32_t bindingID,
 		DescriptorType descriptorType,
 		uint32_t descriptorCount,
-		ShaderStage shaderStage) noexcept
+		vk::ShaderStageFlags shaderStages) noexcept
 		:
 		bindingID(bindingID),
 		descriptorType(descriptorType),
 		descriptorCount(descriptorCount),
-		shaderStage(shaderStage) {}
+		shaderStages(shaderStages) {}
 	
 }
diff --git a/src/vkcv/DescriptorManager.cpp b/src/vkcv/DescriptorManager.cpp
index 8e565a76..676e840a 100644
--- a/src/vkcv/DescriptorManager.cpp
+++ b/src/vkcv/DescriptorManager.cpp
@@ -45,7 +45,7 @@ namespace vkcv
                 bindings[i].bindingID,
                 convertDescriptorTypeFlag(bindings[i].descriptorType),
                 bindings[i].descriptorCount,
-                convertShaderStageFlag(bindings[i].shaderStage));
+                bindings[i].shaderStages);
             setBindings.push_back(descriptorSetLayoutBinding);
         }
 
@@ -245,26 +245,6 @@ namespace vkcv
                 return vk::DescriptorType::eUniformBuffer;
         }
     }
-
-    vk::ShaderStageFlagBits DescriptorManager::convertShaderStageFlag(ShaderStage stage) {
-        switch (stage) 
-        {
-            case ShaderStage::VERTEX:
-                return vk::ShaderStageFlagBits::eVertex;
-            case ShaderStage::FRAGMENT:
-                return vk::ShaderStageFlagBits::eFragment;
-            case ShaderStage::TESS_CONTROL:
-                return vk::ShaderStageFlagBits::eTessellationControl;
-            case ShaderStage::TESS_EVAL:
-                return vk::ShaderStageFlagBits::eTessellationEvaluation;
-            case ShaderStage::GEOMETRY:
-                return vk::ShaderStageFlagBits::eGeometry;
-            case ShaderStage::COMPUTE:
-                return vk::ShaderStageFlagBits::eCompute;
-            default:
-                return vk::ShaderStageFlagBits::eAll;
-        }
-    }
     
     void DescriptorManager::destroyDescriptorSetById(uint64_t id) {
 		if (id >= m_DescriptorSets.size()) {
diff --git a/src/vkcv/DescriptorManager.hpp b/src/vkcv/DescriptorManager.hpp
index d18be64f..df58daa5 100644
--- a/src/vkcv/DescriptorManager.hpp
+++ b/src/vkcv/DescriptorManager.hpp
@@ -51,12 +51,6 @@ namespace vkcv
 		* @return vk flag of the DescriptorType
 		*/
 		static vk::DescriptorType convertDescriptorTypeFlag(DescriptorType type);
-		/**
-		* Converts the flags of the shader stages from VulkanCV (vkcv) to Vulkan (vk).
-		* @param[in] vkcv flag of the ShaderStage (see ShaderProgram.hpp)
-		* @return vk flag of the ShaderStage
-		*/
-		static vk::ShaderStageFlagBits convertShaderStageFlag(ShaderStage stage);
 		
 		/**
 		* Destroys a specific resource description
diff --git a/src/vkcv/PipelineManager.cpp b/src/vkcv/PipelineManager.cpp
index df36442e..c8f885df 100644
--- a/src/vkcv/PipelineManager.cpp
+++ b/src/vkcv/PipelineManager.cpp
@@ -55,8 +55,8 @@ namespace vkcv
     {
 		const vk::RenderPass &pass = passManager.getVkPass(config.m_PassHandle);
     	
-        const bool existsVertexShader = config.m_ShaderProgram.existsShader(ShaderStage::VERTEX);
-        const bool existsFragmentShader = config.m_ShaderProgram.existsShader(ShaderStage::FRAGMENT);
+        const bool existsVertexShader = config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eVertex);
+        const bool existsFragmentShader = config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eFragment);
         if (!(existsVertexShader && existsFragmentShader))
         {
 			vkcv_log(LogLevel::ERROR, "Requires vertex and fragment shader code");
@@ -64,7 +64,7 @@ namespace vkcv
         }
 
         // vertex shader stage
-        std::vector<char> vertexCode = config.m_ShaderProgram.getShader(ShaderStage::VERTEX).shaderCode;
+        std::vector<char> vertexCode = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eVertex).shaderCode;
         vk::ShaderModuleCreateInfo vertexModuleInfo({}, vertexCode.size(), reinterpret_cast<uint32_t*>(vertexCode.data()));
         vk::ShaderModule vertexModule{};
         if (m_Device.createShaderModule(&vertexModuleInfo, nullptr, &vertexModule) != vk::Result::eSuccess)
@@ -79,7 +79,7 @@ namespace vkcv
         );
 
         // fragment shader stage
-        std::vector<char> fragCode = config.m_ShaderProgram.getShader(ShaderStage::FRAGMENT).shaderCode;
+        std::vector<char> fragCode = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eFragment).shaderCode;
         vk::ShaderModuleCreateInfo fragmentModuleInfo({}, fragCode.size(), reinterpret_cast<uint32_t*>(fragCode.data()));
         vk::ShaderModule fragmentModule{};
         if (m_Device.createShaderModule(&fragmentModuleInfo, nullptr, &fragmentModule) != vk::Result::eSuccess)
@@ -258,8 +258,8 @@ namespace vkcv
 
 		const char *geometryShaderName = "main";	// outside of if to make sure it stays in scope
 		vk::ShaderModule geometryModule;
-		if (config.m_ShaderProgram.existsShader(ShaderStage::GEOMETRY)) {
-			const vkcv::Shader geometryShader = config.m_ShaderProgram.getShader(ShaderStage::GEOMETRY);
+		if (config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eGeometry)) {
+			const vkcv::Shader geometryShader = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eGeometry);
 			const auto& geometryCode = geometryShader.shaderCode;
 			const vk::ShaderModuleCreateInfo geometryModuleInfo({}, geometryCode.size(), reinterpret_cast<const uint32_t*>(geometryCode.data()));
 			if (m_Device.createShaderModule(&geometryModuleInfo, nullptr, &geometryModule) != vk::Result::eSuccess) {
@@ -375,7 +375,7 @@ namespace vkcv
 
         // Temporally handing over the Shader Program instead of a pipeline config
         vk::ShaderModule computeModule{};
-        if (createShaderModule(computeModule, shaderProgram, ShaderStage::COMPUTE) != vk::Result::eSuccess)
+        if (createShaderModule(computeModule, shaderProgram, vk::ShaderStageFlagBits::eCompute) != vk::Result::eSuccess)
             return PipelineHandle();
 
         vk::PipelineShaderStageCreateInfo pipelineComputeShaderStageInfo(
@@ -424,7 +424,7 @@ namespace vkcv
     // There is an issue for refactoring the Pipeline Manager.
     // While including Compute Pipeline Creation, some private helper functions where introduced:
 
-    vk::Result PipelineManager::createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, const ShaderStage stage)
+    vk::Result PipelineManager::createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, const vk::ShaderStageFlagBits stage)
     {
         std::vector<char> code = shaderProgram.getShader(stage).shaderCode;
         vk::ShaderModuleCreateInfo moduleInfo({}, code.size(), reinterpret_cast<uint32_t*>(code.data()));
diff --git a/src/vkcv/PipelineManager.hpp b/src/vkcv/PipelineManager.hpp
index b153eb46..b48e06f0 100644
--- a/src/vkcv/PipelineManager.hpp
+++ b/src/vkcv/PipelineManager.hpp
@@ -22,7 +22,7 @@ namespace vkcv
         
         void destroyPipelineById(uint64_t id);
 
-        vk::Result createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, ShaderStage stage);
+        vk::Result createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, vk::ShaderStageFlagBits stage);
 
     public:
         PipelineManager() = delete; // no default ctor
diff --git a/src/vkcv/ShaderProgram.cpp b/src/vkcv/ShaderProgram.cpp
index 971797d9..4c1508ad 100644
--- a/src/vkcv/ShaderProgram.cpp
+++ b/src/vkcv/ShaderProgram.cpp
@@ -76,7 +76,7 @@ namespace vkcv {
     m_DescriptorSets{}
 	{}
 
-	bool ShaderProgram::addShader(ShaderStage shaderStage, const std::filesystem::path &shaderPath)
+	bool ShaderProgram::addShader(vk::ShaderStageFlagBits shaderStage, const std::filesystem::path &shaderPath)
 	{
 	    if(m_Shaders.find(shaderStage) != m_Shaders.end()) {
 			vkcv_log(LogLevel::WARNING, "Overwriting existing shader stage");
@@ -94,12 +94,12 @@ namespace vkcv {
         }
 	}
 
-    const Shader &ShaderProgram::getShader(ShaderStage shaderStage) const
+    const Shader &ShaderProgram::getShader(vk::ShaderStageFlagBits shaderStage) const
     {
 	    return m_Shaders.at(shaderStage);
 	}
 
-    bool ShaderProgram::existsShader(ShaderStage shaderStage) const
+    bool ShaderProgram::existsShader(vk::ShaderStageFlagBits shaderStage) const
     {
 	    if(m_Shaders.find(shaderStage) == m_Shaders.end())
 	        return false;
@@ -107,7 +107,7 @@ namespace vkcv {
 	        return true;
     }
 
-    void ShaderProgram::reflectShader(ShaderStage shaderStage)
+    void ShaderProgram::reflectShader(vk::ShaderStageFlagBits shaderStage)
     {
         auto shaderCodeChar = m_Shaders.at(shaderStage).shaderCode;
         std::vector<uint32_t> shaderCode;
@@ -119,7 +119,7 @@ namespace vkcv {
         spirv_cross::ShaderResources resources = comp.get_shader_resources();
 
         //reflect vertex input
-		if (shaderStage == ShaderStage::VERTEX)
+		if (shaderStage == vk::ShaderStageFlagBits::eVertex)
 		{
 			// spirv-cross API (hopefully) returns the stage_inputs in order
 			for (uint32_t i = 0; i < resources.stage_inputs.size(); i++)
-- 
GitLab