diff --git a/include/vkcv/Pipeline.hpp b/include/vkcv/Pipeline.hpp index 574bd5268ddfd5a0efcd4aa4c2388bd366ed03e0..83dd6e4f70314ade3207e6c5fdb98d4b2bab2b9a 100644 --- a/include/vkcv/Pipeline.hpp +++ b/include/vkcv/Pipeline.hpp @@ -10,6 +10,7 @@ #include <vector> #include <cstdint> #include "vkcv/Handles.hpp" +#include "ShaderProgram.hpp" namespace vkcv { @@ -25,16 +26,14 @@ namespace vkcv { * Constructor for the pipeline. Creates a pipeline using @p vertexCode, @p fragmentCode as well as the * dimensions of the application window @p width and @p height. A handle for the Render Pass is also needed, @p passHandle. * - * @param vertexCode Spir-V of Vertex Shader - * @param fragCode Spir-V of Fragment Shader + * @param shaderProgram shaders of the pipeline * @param height height of the application window * @param width width of the application window * @param passHandle handle for Render Pass */ - Pipeline(const std::vector<uint32_t> &vertexCode, const std::vector<uint32_t> &fragCode, uint32_t height, uint32_t width, RenderpassHandle &passHandle); + Pipeline(const ShaderProgram& shaderProgram, uint32_t width, uint32_t height, RenderpassHandle &passHandle); - std::vector<uint32_t> m_vertexCode; - std::vector<uint32_t> m_fragCode; + ShaderProgram m_shaderProgram; uint32_t m_height; uint32_t m_width; RenderpassHandle m_passHandle; diff --git a/include/vkcv/ShaderProgram.hpp b/include/vkcv/ShaderProgram.hpp index d9f60f17c60a6945ab70b0554e3bda7b81881cd8..ea0cb0f17c1ac5a292c51365fe3d0007c2e7ebfc 100644 --- a/include/vkcv/ShaderProgram.hpp +++ b/include/vkcv/ShaderProgram.hpp @@ -49,7 +49,7 @@ namespace vkcv { * @param[in] flag that signals the respective shader stage (e.g. VK_SHADER_STAGE_VERTEX_BIT) * @return boolean that is true if the shader program contains the shader stage */ - bool containsShaderStage(ShaderProgram::ShaderStage shaderStage); + bool containsShaderStage(ShaderProgram::ShaderStage shaderStage) const; /** * Deletes the given shader stage in the shader program. @@ -63,21 +63,21 @@ namespace vkcv { * Needed for the transfer to the pipeline. * @return vector list with all shader stage info structs */ - std::vector<vk::ShaderStageFlagBits> getShaderStages(); + std::vector<vk::ShaderStageFlagBits> getShaderStages() const; /** * Returns a list with all the shader code in the shader program. * Needed for the transfer to the pipeline. * @return vector list with all shader code char vecs */ - std::vector<std::vector<char>> getShaderCode(); + std::vector<std::vector<char>> getShaderCode() const; /** * Returns the number of shader stages in the shader program. * Needed for the transfer to the pipeline. * @return integer with the number of stages */ - int getShaderStagesCount(); + int getShaderStagesCount() const; @@ -109,7 +109,7 @@ namespace vkcv { * @param[in] ShaderStage enum * @return vk::ShaderStageFlagBits */ - vk::ShaderStageFlagBits convertToShaderStageFlagBits(ShaderProgram::ShaderStage shaderStage); + vk::ShaderStageFlagBits convertToShaderStageFlagBits(ShaderProgram::ShaderStage shaderStage) const; /** * Creates a shader module that encapsulates the read shader code. diff --git a/src/vkcv/Core.cpp b/src/vkcv/Core.cpp index 4dd1da746226c124707929483a8f0ea4c688a7b1..9635f2a4e1ae272f25ff4b642a43d955ae8cdadc 100644 --- a/src/vkcv/Core.cpp +++ b/src/vkcv/Core.cpp @@ -421,8 +421,31 @@ namespace vkcv bool Core::createGraphicsPipeline(const Pipeline& pipeline, PipelineHandle& handle) { + // TODO: this search could be avoided if ShaderProgram could be queried for a specific stage + const auto shaderStageFlags = pipeline.m_shaderProgram.getShaderStages(); + const auto shaderCode = pipeline.m_shaderProgram.getShaderCode(); + std::vector<char> vertexCode; + std::vector<char> fragCode; + assert(shaderStageFlags.size() == shaderCode.size()); + for (int i = 0; i < shaderStageFlags.size(); i++) { + switch (shaderStageFlags[i]) { + case vk::ShaderStageFlagBits::eVertex: vertexCode = shaderCode[i]; break; + case vk::ShaderStageFlagBits::eFragment: fragCode = shaderCode[i]; break; + default: std::cout << "Core::createGraphicsPipeline encountered unknown shader stage" << std::endl; return false; + } + } + + const bool foundVertexCode = vertexCode.size() > 0; + const bool foundFragCode = fragCode.size() > 0; + const bool foundRequiredShaderCode = foundVertexCode && foundFragCode; + if (!foundRequiredShaderCode) { + std::cout << "Core::createGraphicsPipeline requires vertex and fragment shader code" << std::endl; + return false; + } + // vertex shader stage - vk::ShaderModuleCreateInfo vertexModuleInfo({}, pipeline.m_vertexCode.size(), pipeline.m_vertexCode.data()); + // TODO: store shader code as uint32_t in ShaderProgram to avoid pointer cast + vk::ShaderModuleCreateInfo vertexModuleInfo({}, vertexCode.size(), reinterpret_cast<uint32_t*>(vertexCode.data())); vk::ShaderModule vertexModule{}; if (m_Context.m_Device.createShaderModule(&vertexModuleInfo, nullptr, &vertexModule) != vk::Result::eSuccess) return false; @@ -436,7 +459,7 @@ namespace vkcv ); // fragment shader stage - vk::ShaderModuleCreateInfo fragmentModuleInfo({}, pipeline.m_fragCode.size(), pipeline.m_fragCode.data()); + vk::ShaderModuleCreateInfo fragmentModuleInfo({}, fragCode.size(), reinterpret_cast<uint32_t*>(fragCode.data())); vk::ShaderModule fragmentModule{}; if (m_Context.m_Device.createShaderModule(&fragmentModuleInfo, nullptr, &fragmentModule) != vk::Result::eSuccess) return false; diff --git a/src/vkcv/Pipeline.cpp b/src/vkcv/Pipeline.cpp index 6fee4d2b774d4b1016894f15134d5b32634236c1..42a6b963e31286bfaedb1d27e4fa679be1fc71e6 100644 --- a/src/vkcv/Pipeline.cpp +++ b/src/vkcv/Pipeline.cpp @@ -8,6 +8,6 @@ namespace vkcv { - Pipeline::Pipeline(const std::vector<uint32_t> &vertexCode, const std::vector<uint32_t> &fragCode, uint32_t height, uint32_t width, RenderpassHandle &passHandle): - m_vertexCode(vertexCode), m_fragCode(fragCode), m_height(height), m_width(width), m_passHandle(passHandle) {} + Pipeline::Pipeline(const ShaderProgram& shaderProgram, uint32_t width, uint32_t height, RenderpassHandle &passHandle): + m_shaderProgram(shaderProgram), m_height(height), m_width(width), m_passHandle(passHandle) {} } diff --git a/src/vkcv/ShaderProgram.cpp b/src/vkcv/ShaderProgram.cpp index 7659b1a7d94a3968c9be424e8a46945f7621cdb6..b995dde0916a13689a025f234aae1d27a1249204 100644 --- a/src/vkcv/ShaderProgram.cpp +++ b/src/vkcv/ShaderProgram.cpp @@ -31,7 +31,7 @@ namespace vkcv { return buffer; } - vk::ShaderStageFlagBits ShaderProgram::convertToShaderStageFlagBits(ShaderProgram::ShaderStage shaderStage){ + vk::ShaderStageFlagBits ShaderProgram::convertToShaderStageFlagBits(ShaderProgram::ShaderStage shaderStage) const{ switch (shaderStage) { case ShaderStage::VERTEX: return vk::ShaderStageFlagBits::eVertex; @@ -83,7 +83,7 @@ namespace vkcv { } } - bool ShaderProgram::containsShaderStage(ShaderProgram::ShaderStage shaderStage) { + bool ShaderProgram::containsShaderStage(ShaderProgram::ShaderStage shaderStage) const{ vk::ShaderStageFlagBits convertedShaderStage = convertToShaderStageFlagBits(shaderStage); for (int i = 0; i < m_shaderStages.shaderStageFlag.size(); i++) { if (m_shaderStages.shaderStageFlag[i] == convertedShaderStage) { @@ -105,15 +105,15 @@ namespace vkcv { return false; } - std::vector<vk::ShaderStageFlagBits> ShaderProgram::getShaderStages() { + std::vector<vk::ShaderStageFlagBits> ShaderProgram::getShaderStages() const{ return m_shaderStages.shaderStageFlag; } - std::vector<std::vector<char>> ShaderProgram::getShaderCode() { + std::vector<std::vector<char>> ShaderProgram::getShaderCode() const { return m_shaderStages.shaderCode; } - int ShaderProgram::getShaderStagesCount() { + int ShaderProgram::getShaderStagesCount() const { return m_shaderStages.shaderStageFlag.size(); } }