diff --git a/include/vkcv/DescriptorConfig.hpp b/include/vkcv/DescriptorConfig.hpp index 29dc81240367d33e0bb96ef6146e3540d80ac50f..78e111f40144a825bb790149f02a91039eb517b9 100644 --- a/include/vkcv/DescriptorConfig.hpp +++ b/include/vkcv/DescriptorConfig.hpp @@ -38,12 +38,12 @@ namespace vkcv uint32_t bindingID, DescriptorType descriptorType, uint32_t descriptorCount, - vk::ShaderStageFlags shaderStages + ShaderStages shaderStages ) noexcept; uint32_t bindingID; DescriptorType descriptorType; uint32_t descriptorCount; - vk::ShaderStageFlags shaderStages; + ShaderStages shaderStages; }; } diff --git a/include/vkcv/ShaderProgram.hpp b/include/vkcv/ShaderProgram.hpp index 707d72a864cfd95936a476a11675818e48cc6a9d..c7d67b19148b3c9ec19ce1b539f9661797d1b38f 100644 --- a/include/vkcv/ShaderProgram.hpp +++ b/include/vkcv/ShaderProgram.hpp @@ -14,13 +14,14 @@ #include <spirv_cross.hpp> #include "VertexLayout.hpp" #include "DescriptorConfig.hpp" +#include "ShaderStage.hpp" namespace vkcv { struct Shader { std::vector<char> shaderCode; - vk::ShaderStageFlagBits shaderStage; + ShaderStage shaderStage; }; class ShaderProgram @@ -36,16 +37,16 @@ namespace vkcv { * @param[in] flag that signals the respective shaderStage (e.g. VK_SHADER_STAGE_VERTEX_BIT) * @param[in] relative path to the shader code (e.g. "../../../../../shaders/vert.spv") */ - bool addShader(vk::ShaderStageFlagBits shaderStage, const std::filesystem::path &shaderPath); + bool addShader(ShaderStage shaderStage, const std::filesystem::path &shaderPath); /** * Returns the shader program's shader of the specified shader. * Needed for the transfer to the pipeline. * @return Shader object consisting of buffer with shader code and shader stage enum */ - const Shader &getShader(vk::ShaderStageFlagBits shaderStage) const; + const Shader &getShader(ShaderStage shaderStage) const; - bool existsShader(vk::ShaderStageFlagBits shaderStage) const; + bool existsShader(ShaderStage shaderStage) const; const std::vector<VertexAttachment> &getVertexAttachments() const; size_t getPushConstantSize() const; @@ -58,9 +59,9 @@ namespace vkcv { * Fills vertex input attachments and descriptor sets (if present). * @param shaderStage the stage to reflect data from */ - void reflectShader(vk::ShaderStageFlagBits shaderStage); + void reflectShader(ShaderStage shaderStage); - std::unordered_map<vk::ShaderStageFlagBits, Shader> m_Shaders; + std::unordered_map<ShaderStage, Shader> m_Shaders; // contains all vertex input attachments used in the vertex buffer std::vector<VertexAttachment> m_VertexAttachments; diff --git a/include/vkcv/ShaderStage.hpp b/include/vkcv/ShaderStage.hpp index 488fd92fb081a4671055daa5bde4c0bf8116b71a..d671b87b55ac7a5a8926e479c77fa991dd90c665 100644 --- a/include/vkcv/ShaderStage.hpp +++ b/include/vkcv/ShaderStage.hpp @@ -1,5 +1,38 @@ #pragma once -namespace vkcv { +#include <vulkan/vulkan.hpp> +namespace vkcv { + + enum class ShaderStage : VkShaderStageFlags { + VERTEX = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eVertex), + TESS_CONTROL = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eTessellationControl), + TESS_EVAL = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eTessellationEvaluation), + GEOMETRY = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eGeometry), + FRAGMENT = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eFragment), + COMPUTE = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eCompute) + }; + + using ShaderStages = vk::Flags<ShaderStage>; + + constexpr vk::ShaderStageFlags getShaderStageFlags(ShaderStages shaderStages) noexcept { + return vk::ShaderStageFlags(static_cast<VkShaderStageFlags>(shaderStages)); + } + + constexpr ShaderStages operator|(ShaderStage stage0, ShaderStage stage1) noexcept { + return ShaderStages(stage0) | stage1; + } + + constexpr ShaderStages operator&(ShaderStage stage0, ShaderStage stage1) noexcept { + return ShaderStages(stage0) & stage1; + } + + constexpr ShaderStages operator^(ShaderStage stage0, ShaderStage stage1) noexcept { + return ShaderStages(stage0) ^ stage1; + } + + constexpr ShaderStages operator~(ShaderStage stage) noexcept { + return ~(ShaderStages(stage)); + } + } diff --git a/projects/first_mesh/src/main.cpp b/projects/first_mesh/src/main.cpp index fabea47fc06fc15838bd390525effd32d62de8bc..dc43c905784525a34732bc0e66343fbdcc17a639 100644 --- a/projects/first_mesh/src/main.cpp +++ b/projects/first_mesh/src/main.cpp @@ -79,8 +79,8 @@ int main(int argc, const char** argv) { } vkcv::ShaderProgram firstMeshProgram{}; - firstMeshProgram.addShader(vk::ShaderStageFlagBits::eVertex, std::filesystem::path("resources/shaders/vert.spv")); - firstMeshProgram.addShader(vk::ShaderStageFlagBits::eFragment, std::filesystem::path("resources/shaders/frag.spv")); + firstMeshProgram.addShader(vkcv::ShaderStage::VERTEX, std::filesystem::path("resources/shaders/vert.spv")); + firstMeshProgram.addShader(vkcv::ShaderStage::FRAGMENT, std::filesystem::path("resources/shaders/frag.spv")); auto& attributes = mesh.vertexGroups[0].vertexBuffer.attributes; diff --git a/src/vkcv/DescriptorConfig.cpp b/src/vkcv/DescriptorConfig.cpp index 1fc075ced91aebe4a5d3596c635208368db08822..a9a127fe608472682c1cbc8d32ca466fba860c72 100644 --- a/src/vkcv/DescriptorConfig.cpp +++ b/src/vkcv/DescriptorConfig.cpp @@ -5,7 +5,7 @@ namespace vkcv { uint32_t bindingID, DescriptorType descriptorType, uint32_t descriptorCount, - vk::ShaderStageFlags shaderStages) noexcept + ShaderStages shaderStages) noexcept : bindingID(bindingID), descriptorType(descriptorType), diff --git a/src/vkcv/DescriptorManager.cpp b/src/vkcv/DescriptorManager.cpp index 676e840a1b7d28548e83cbb53ccbd62992a00e2c..d28dd9d137240ba923b55c9be9da9059d3a9ab31 100644 --- a/src/vkcv/DescriptorManager.cpp +++ b/src/vkcv/DescriptorManager.cpp @@ -45,7 +45,7 @@ namespace vkcv bindings[i].bindingID, convertDescriptorTypeFlag(bindings[i].descriptorType), bindings[i].descriptorCount, - bindings[i].shaderStages); + getShaderStageFlags(bindings[i].shaderStages)); setBindings.push_back(descriptorSetLayoutBinding); } diff --git a/src/vkcv/PipelineManager.cpp b/src/vkcv/PipelineManager.cpp index c8f885dffe5593c8f47f702bfdb99128a9d5a12d..df36442efc2992bf16b6e82245ef9753dad95e5d 100644 --- a/src/vkcv/PipelineManager.cpp +++ b/src/vkcv/PipelineManager.cpp @@ -55,8 +55,8 @@ namespace vkcv { const vk::RenderPass &pass = passManager.getVkPass(config.m_PassHandle); - const bool existsVertexShader = config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eVertex); - const bool existsFragmentShader = config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eFragment); + const bool existsVertexShader = config.m_ShaderProgram.existsShader(ShaderStage::VERTEX); + const bool existsFragmentShader = config.m_ShaderProgram.existsShader(ShaderStage::FRAGMENT); if (!(existsVertexShader && existsFragmentShader)) { vkcv_log(LogLevel::ERROR, "Requires vertex and fragment shader code"); @@ -64,7 +64,7 @@ namespace vkcv } // vertex shader stage - std::vector<char> vertexCode = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eVertex).shaderCode; + std::vector<char> vertexCode = config.m_ShaderProgram.getShader(ShaderStage::VERTEX).shaderCode; vk::ShaderModuleCreateInfo vertexModuleInfo({}, vertexCode.size(), reinterpret_cast<uint32_t*>(vertexCode.data())); vk::ShaderModule vertexModule{}; if (m_Device.createShaderModule(&vertexModuleInfo, nullptr, &vertexModule) != vk::Result::eSuccess) @@ -79,7 +79,7 @@ namespace vkcv ); // fragment shader stage - std::vector<char> fragCode = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eFragment).shaderCode; + std::vector<char> fragCode = config.m_ShaderProgram.getShader(ShaderStage::FRAGMENT).shaderCode; vk::ShaderModuleCreateInfo fragmentModuleInfo({}, fragCode.size(), reinterpret_cast<uint32_t*>(fragCode.data())); vk::ShaderModule fragmentModule{}; if (m_Device.createShaderModule(&fragmentModuleInfo, nullptr, &fragmentModule) != vk::Result::eSuccess) @@ -258,8 +258,8 @@ namespace vkcv const char *geometryShaderName = "main"; // outside of if to make sure it stays in scope vk::ShaderModule geometryModule; - if (config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eGeometry)) { - const vkcv::Shader geometryShader = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eGeometry); + if (config.m_ShaderProgram.existsShader(ShaderStage::GEOMETRY)) { + const vkcv::Shader geometryShader = config.m_ShaderProgram.getShader(ShaderStage::GEOMETRY); const auto& geometryCode = geometryShader.shaderCode; const vk::ShaderModuleCreateInfo geometryModuleInfo({}, geometryCode.size(), reinterpret_cast<const uint32_t*>(geometryCode.data())); if (m_Device.createShaderModule(&geometryModuleInfo, nullptr, &geometryModule) != vk::Result::eSuccess) { @@ -375,7 +375,7 @@ namespace vkcv // Temporally handing over the Shader Program instead of a pipeline config vk::ShaderModule computeModule{}; - if (createShaderModule(computeModule, shaderProgram, vk::ShaderStageFlagBits::eCompute) != vk::Result::eSuccess) + if (createShaderModule(computeModule, shaderProgram, ShaderStage::COMPUTE) != vk::Result::eSuccess) return PipelineHandle(); vk::PipelineShaderStageCreateInfo pipelineComputeShaderStageInfo( @@ -424,7 +424,7 @@ namespace vkcv // There is an issue for refactoring the Pipeline Manager. // While including Compute Pipeline Creation, some private helper functions where introduced: - vk::Result PipelineManager::createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, const vk::ShaderStageFlagBits stage) + vk::Result PipelineManager::createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, const ShaderStage stage) { std::vector<char> code = shaderProgram.getShader(stage).shaderCode; vk::ShaderModuleCreateInfo moduleInfo({}, code.size(), reinterpret_cast<uint32_t*>(code.data())); diff --git a/src/vkcv/PipelineManager.hpp b/src/vkcv/PipelineManager.hpp index b48e06f07de6b863ae79a810482a355f5e4e280a..b153eb4632b844e84b92953fe8abf6666a13e0c9 100644 --- a/src/vkcv/PipelineManager.hpp +++ b/src/vkcv/PipelineManager.hpp @@ -22,7 +22,7 @@ namespace vkcv void destroyPipelineById(uint64_t id); - vk::Result createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, vk::ShaderStageFlagBits stage); + vk::Result createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, ShaderStage stage); public: PipelineManager() = delete; // no default ctor diff --git a/src/vkcv/ShaderProgram.cpp b/src/vkcv/ShaderProgram.cpp index 4c1508ad14ded8bd33a1444910defcc361599f79..59bc70ea76c1aebcb1f41a5583ec6938860bc918 100644 --- a/src/vkcv/ShaderProgram.cpp +++ b/src/vkcv/ShaderProgram.cpp @@ -76,7 +76,7 @@ namespace vkcv { m_DescriptorSets{} {} - bool ShaderProgram::addShader(vk::ShaderStageFlagBits shaderStage, const std::filesystem::path &shaderPath) + bool ShaderProgram::addShader(ShaderStage shaderStage, const std::filesystem::path &shaderPath) { if(m_Shaders.find(shaderStage) != m_Shaders.end()) { vkcv_log(LogLevel::WARNING, "Overwriting existing shader stage"); @@ -94,12 +94,12 @@ namespace vkcv { } } - const Shader &ShaderProgram::getShader(vk::ShaderStageFlagBits shaderStage) const + const Shader &ShaderProgram::getShader(ShaderStage shaderStage) const { return m_Shaders.at(shaderStage); } - bool ShaderProgram::existsShader(vk::ShaderStageFlagBits shaderStage) const + bool ShaderProgram::existsShader(ShaderStage shaderStage) const { if(m_Shaders.find(shaderStage) == m_Shaders.end()) return false; @@ -107,7 +107,7 @@ namespace vkcv { return true; } - void ShaderProgram::reflectShader(vk::ShaderStageFlagBits shaderStage) + void ShaderProgram::reflectShader(ShaderStage shaderStage) { auto shaderCodeChar = m_Shaders.at(shaderStage).shaderCode; std::vector<uint32_t> shaderCode; @@ -119,7 +119,7 @@ namespace vkcv { spirv_cross::ShaderResources resources = comp.get_shader_resources(); //reflect vertex input - if (shaderStage == vk::ShaderStageFlagBits::eVertex) + if (shaderStage == ShaderStage::VERTEX) { // spirv-cross API (hopefully) returns the stage_inputs in order for (uint32_t i = 0; i < resources.stage_inputs.size(); i++) @@ -138,6 +138,12 @@ namespace vkcv { m_VertexAttachments.emplace_back(attachment_loc, attachment_name, attachment_format); } } + + ShaderStages stages; + stages |= ShaderStage::VERTEX; + stages |= ShaderStage::FRAGMENT; + + vk::ShaderStageFlags flags = vk::ShaderStageFlagBits::eVertex | vk::ShaderStageFlagBits::eFragment; //reflect descriptor sets (uniform buffer, storage buffer, sampler, sampled image, storage image) std::vector<std::pair<uint32_t, DescriptorBinding>> bindings;