From 4206b372f2d90b2138c8030a2b5e68840917d86f Mon Sep 17 00:00:00 2001
From: Alexander Gauggel <agauggel@uni-koblenz.de>
Date: Sun, 16 May 2021 11:57:58 +0200
Subject: [PATCH] [18]adjust ShaderProgram and Pipeline so the latter can use
 the former

---
 include/vkcv/Pipeline.hpp      |  9 ++++-----
 include/vkcv/ShaderProgram.hpp | 10 +++++-----
 src/vkcv/Core.cpp              | 27 +++++++++++++++++++++++++--
 src/vkcv/Pipeline.cpp          |  4 ++--
 src/vkcv/ShaderProgram.cpp     | 10 +++++-----
 5 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/include/vkcv/Pipeline.hpp b/include/vkcv/Pipeline.hpp
index 574bd526..83dd6e4f 100644
--- a/include/vkcv/Pipeline.hpp
+++ b/include/vkcv/Pipeline.hpp
@@ -10,6 +10,7 @@
 #include <vector>
 #include <cstdint>
 #include "vkcv/Handles.hpp"
+#include "ShaderProgram.hpp"
 
 namespace vkcv {
 
@@ -25,16 +26,14 @@ namespace vkcv {
          *  Constructor for the pipeline. Creates a pipeline using @p vertexCode, @p fragmentCode as well as the
          *  dimensions of the application window @p width and @p height. A handle for the Render Pass is also needed, @p passHandle.
          *
-         * @param vertexCode Spir-V of Vertex Shader
-         * @param fragCode Spir-V of Fragment Shader
+         * @param shaderProgram shaders of the pipeline
          * @param height height of the application window
          * @param width width of the application window
          * @param passHandle handle for Render Pass
          */
-        Pipeline(const std::vector<uint32_t> &vertexCode, const std::vector<uint32_t> &fragCode, uint32_t height, uint32_t width, RenderpassHandle &passHandle);
+        Pipeline(const ShaderProgram& shaderProgram, uint32_t width, uint32_t height, RenderpassHandle &passHandle);
 
-        std::vector<uint32_t> m_vertexCode;
-        std::vector<uint32_t> m_fragCode;
+		ShaderProgram m_shaderProgram;
         uint32_t m_height;
         uint32_t m_width;
         RenderpassHandle m_passHandle;
diff --git a/include/vkcv/ShaderProgram.hpp b/include/vkcv/ShaderProgram.hpp
index d9f60f17..ea0cb0f1 100644
--- a/include/vkcv/ShaderProgram.hpp
+++ b/include/vkcv/ShaderProgram.hpp
@@ -49,7 +49,7 @@ namespace vkcv {
         * @param[in] flag that signals the respective shader stage (e.g. VK_SHADER_STAGE_VERTEX_BIT)
         * @return boolean that is true if the shader program contains the shader stage
         */
-        bool containsShaderStage(ShaderProgram::ShaderStage shaderStage);
+        bool containsShaderStage(ShaderProgram::ShaderStage shaderStage) const;
 
         /**
         * Deletes the given shader stage in the shader program.
@@ -63,21 +63,21 @@ namespace vkcv {
         * Needed for the transfer to the pipeline.
         * @return vector list with all shader stage info structs
         */
-        std::vector<vk::ShaderStageFlagBits> getShaderStages();
+        std::vector<vk::ShaderStageFlagBits> getShaderStages() const;
 
         /**
         * Returns a list with all the shader code in the shader program.
         * Needed for the transfer to the pipeline.
         * @return vector list with all shader code char vecs
         */
-        std::vector<std::vector<char>> getShaderCode();
+        std::vector<std::vector<char>> getShaderCode() const;
 
         /**
         * Returns the number of shader stages in the shader program.
         * Needed for the transfer to the pipeline.
         * @return integer with the number of stages
         */
-        int getShaderStagesCount();
+        int getShaderStagesCount() const;
 
 
 
@@ -109,7 +109,7 @@ namespace vkcv {
 		* @param[in] ShaderStage enum
 		* @return vk::ShaderStageFlagBits
 		*/
-        vk::ShaderStageFlagBits convertToShaderStageFlagBits(ShaderProgram::ShaderStage shaderStage);
+        vk::ShaderStageFlagBits convertToShaderStageFlagBits(ShaderProgram::ShaderStage shaderStage) const;
 
 		/**
 		* Creates a shader module that encapsulates the read shader code. 
diff --git a/src/vkcv/Core.cpp b/src/vkcv/Core.cpp
index 4dd1da74..9635f2a4 100644
--- a/src/vkcv/Core.cpp
+++ b/src/vkcv/Core.cpp
@@ -421,8 +421,31 @@ namespace vkcv
 
 	bool Core::createGraphicsPipeline(const Pipeline& pipeline, PipelineHandle& handle) {
 
+		// TODO: this search could be avoided if ShaderProgram could be queried for a specific stage
+		const auto shaderStageFlags = pipeline.m_shaderProgram.getShaderStages();
+		const auto shaderCode = pipeline.m_shaderProgram.getShaderCode();
+		std::vector<char> vertexCode;
+		std::vector<char> fragCode;
+		assert(shaderStageFlags.size() == shaderCode.size());
+		for (int i = 0; i < shaderStageFlags.size(); i++) {
+			switch (shaderStageFlags[i]) {
+				case vk::ShaderStageFlagBits::eVertex: vertexCode = shaderCode[i]; break;
+				case vk::ShaderStageFlagBits::eFragment: fragCode = shaderCode[i]; break;
+				default: std::cout << "Core::createGraphicsPipeline encountered unknown shader stage" << std::endl; return false;
+			}
+		}
+
+		const bool foundVertexCode = vertexCode.size() > 0;
+		const bool foundFragCode = fragCode.size() > 0;
+		const bool foundRequiredShaderCode = foundVertexCode && foundFragCode;
+		if (!foundRequiredShaderCode) {
+			std::cout << "Core::createGraphicsPipeline requires vertex and fragment shader code" << std::endl; 
+			return false;
+		}
+
 		// vertex shader stage
-		vk::ShaderModuleCreateInfo vertexModuleInfo({}, pipeline.m_vertexCode.size(), pipeline.m_vertexCode.data());
+		// TODO: store shader code as uint32_t in ShaderProgram to avoid pointer cast
+		vk::ShaderModuleCreateInfo vertexModuleInfo({}, vertexCode.size(), reinterpret_cast<uint32_t*>(vertexCode.data()));
 		vk::ShaderModule vertexModule{};
 		if (m_Context.m_Device.createShaderModule(&vertexModuleInfo, nullptr, &vertexModule) != vk::Result::eSuccess)
 			return false;
@@ -436,7 +459,7 @@ namespace vkcv
 		);
 
 		// fragment shader stage
-		vk::ShaderModuleCreateInfo fragmentModuleInfo({}, pipeline.m_fragCode.size(), pipeline.m_fragCode.data());
+		vk::ShaderModuleCreateInfo fragmentModuleInfo({}, fragCode.size(), reinterpret_cast<uint32_t*>(fragCode.data()));
 		vk::ShaderModule fragmentModule{};
 		if (m_Context.m_Device.createShaderModule(&fragmentModuleInfo, nullptr, &fragmentModule) != vk::Result::eSuccess)
 			return false;
diff --git a/src/vkcv/Pipeline.cpp b/src/vkcv/Pipeline.cpp
index 6fee4d2b..42a6b963 100644
--- a/src/vkcv/Pipeline.cpp
+++ b/src/vkcv/Pipeline.cpp
@@ -8,6 +8,6 @@
 
 namespace vkcv {
 
-    Pipeline::Pipeline(const std::vector<uint32_t> &vertexCode, const std::vector<uint32_t> &fragCode, uint32_t height, uint32_t width, RenderpassHandle &passHandle):
-        m_vertexCode(vertexCode), m_fragCode(fragCode), m_height(height), m_width(width), m_passHandle(passHandle) {}
+    Pipeline::Pipeline(const ShaderProgram& shaderProgram, uint32_t width, uint32_t height, RenderpassHandle &passHandle):
+		m_shaderProgram(shaderProgram), m_height(height), m_width(width), m_passHandle(passHandle) {}
 }
diff --git a/src/vkcv/ShaderProgram.cpp b/src/vkcv/ShaderProgram.cpp
index 7659b1a7..b995dde0 100644
--- a/src/vkcv/ShaderProgram.cpp
+++ b/src/vkcv/ShaderProgram.cpp
@@ -31,7 +31,7 @@ namespace vkcv {
 		return buffer;
 	}
 
-    vk::ShaderStageFlagBits ShaderProgram::convertToShaderStageFlagBits(ShaderProgram::ShaderStage shaderStage){
+    vk::ShaderStageFlagBits ShaderProgram::convertToShaderStageFlagBits(ShaderProgram::ShaderStage shaderStage) const{
         switch (shaderStage) {
             case ShaderStage::VERTEX:
                 return vk::ShaderStageFlagBits::eVertex;
@@ -83,7 +83,7 @@ namespace vkcv {
 		}
 	}
 
-	bool ShaderProgram::containsShaderStage(ShaderProgram::ShaderStage shaderStage) {
+	bool ShaderProgram::containsShaderStage(ShaderProgram::ShaderStage shaderStage) const{
         vk::ShaderStageFlagBits convertedShaderStage = convertToShaderStageFlagBits(shaderStage);
 		for (int i = 0; i < m_shaderStages.shaderStageFlag.size(); i++) {
 			if (m_shaderStages.shaderStageFlag[i] == convertedShaderStage) {
@@ -105,15 +105,15 @@ namespace vkcv {
 		return false;
 	}
 
-	std::vector<vk::ShaderStageFlagBits> ShaderProgram::getShaderStages() {
+	std::vector<vk::ShaderStageFlagBits> ShaderProgram::getShaderStages() const{
 		return m_shaderStages.shaderStageFlag;
 	}
 
-    std::vector<std::vector<char>> ShaderProgram::getShaderCode() {
+    std::vector<std::vector<char>> ShaderProgram::getShaderCode() const {
 	    return m_shaderStages.shaderCode;
 	}
 
-	int ShaderProgram::getShaderStagesCount() {
+	int ShaderProgram::getShaderStagesCount() const {
 		return m_shaderStages.shaderStageFlag.size();
 	}
 }
-- 
GitLab