Skip to content
Snippets Groups Projects
Verified Commit e4e646f3 authored by Tobias Frisch's avatar Tobias Frisch
Browse files

[#76] Fixed problems with custom shader stages

parent c420f75a
No related branches found
No related tags found
1 merge request!71Resolve "Descriptor in multiple shader stages"
Pipeline #26155 failed
This commit is part of merge request !71. Comments created here will be created in the context of that merge request.
......@@ -38,12 +38,12 @@ namespace vkcv
uint32_t bindingID,
DescriptorType descriptorType,
uint32_t descriptorCount,
vk::ShaderStageFlags shaderStages
ShaderStages shaderStages
) noexcept;
uint32_t bindingID;
DescriptorType descriptorType;
uint32_t descriptorCount;
vk::ShaderStageFlags shaderStages;
ShaderStages shaderStages;
};
}
......@@ -14,13 +14,14 @@
#include <spirv_cross.hpp>
#include "VertexLayout.hpp"
#include "DescriptorConfig.hpp"
#include "ShaderStage.hpp"
namespace vkcv {
struct Shader
{
std::vector<char> shaderCode;
vk::ShaderStageFlagBits shaderStage;
ShaderStage shaderStage;
};
class ShaderProgram
......@@ -36,16 +37,16 @@ namespace vkcv {
* @param[in] flag that signals the respective shaderStage (e.g. VK_SHADER_STAGE_VERTEX_BIT)
* @param[in] relative path to the shader code (e.g. "../../../../../shaders/vert.spv")
*/
bool addShader(vk::ShaderStageFlagBits shaderStage, const std::filesystem::path &shaderPath);
bool addShader(ShaderStage shaderStage, const std::filesystem::path &shaderPath);
/**
* Returns the shader program's shader of the specified shader.
* Needed for the transfer to the pipeline.
* @return Shader object consisting of buffer with shader code and shader stage enum
*/
const Shader &getShader(vk::ShaderStageFlagBits shaderStage) const;
const Shader &getShader(ShaderStage shaderStage) const;
bool existsShader(vk::ShaderStageFlagBits shaderStage) const;
bool existsShader(ShaderStage shaderStage) const;
const std::vector<VertexAttachment> &getVertexAttachments() const;
size_t getPushConstantSize() const;
......@@ -58,9 +59,9 @@ namespace vkcv {
* Fills vertex input attachments and descriptor sets (if present).
* @param shaderStage the stage to reflect data from
*/
void reflectShader(vk::ShaderStageFlagBits shaderStage);
void reflectShader(ShaderStage shaderStage);
std::unordered_map<vk::ShaderStageFlagBits, Shader> m_Shaders;
std::unordered_map<ShaderStage, Shader> m_Shaders;
// contains all vertex input attachments used in the vertex buffer
std::vector<VertexAttachment> m_VertexAttachments;
......
#pragma once
namespace vkcv {
#include <vulkan/vulkan.hpp>
namespace vkcv {
enum class ShaderStage : VkShaderStageFlags {
VERTEX = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eVertex),
TESS_CONTROL = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eTessellationControl),
TESS_EVAL = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eTessellationEvaluation),
GEOMETRY = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eGeometry),
FRAGMENT = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eFragment),
COMPUTE = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eCompute)
};
using ShaderStages = vk::Flags<ShaderStage>;
constexpr vk::ShaderStageFlags getShaderStageFlags(ShaderStages shaderStages) noexcept {
return vk::ShaderStageFlags(static_cast<VkShaderStageFlags>(shaderStages));
}
constexpr ShaderStages operator|(ShaderStage stage0, ShaderStage stage1) noexcept {
return ShaderStages(stage0) | stage1;
}
constexpr ShaderStages operator&(ShaderStage stage0, ShaderStage stage1) noexcept {
return ShaderStages(stage0) & stage1;
}
constexpr ShaderStages operator^(ShaderStage stage0, ShaderStage stage1) noexcept {
return ShaderStages(stage0) ^ stage1;
}
constexpr ShaderStages operator~(ShaderStage stage) noexcept {
return ~(ShaderStages(stage));
}
}
......@@ -79,8 +79,8 @@ int main(int argc, const char** argv) {
}
vkcv::ShaderProgram firstMeshProgram{};
firstMeshProgram.addShader(vk::ShaderStageFlagBits::eVertex, std::filesystem::path("resources/shaders/vert.spv"));
firstMeshProgram.addShader(vk::ShaderStageFlagBits::eFragment, std::filesystem::path("resources/shaders/frag.spv"));
firstMeshProgram.addShader(vkcv::ShaderStage::VERTEX, std::filesystem::path("resources/shaders/vert.spv"));
firstMeshProgram.addShader(vkcv::ShaderStage::FRAGMENT, std::filesystem::path("resources/shaders/frag.spv"));
auto& attributes = mesh.vertexGroups[0].vertexBuffer.attributes;
......
......@@ -5,7 +5,7 @@ namespace vkcv {
uint32_t bindingID,
DescriptorType descriptorType,
uint32_t descriptorCount,
vk::ShaderStageFlags shaderStages) noexcept
ShaderStages shaderStages) noexcept
:
bindingID(bindingID),
descriptorType(descriptorType),
......
......@@ -45,7 +45,7 @@ namespace vkcv
bindings[i].bindingID,
convertDescriptorTypeFlag(bindings[i].descriptorType),
bindings[i].descriptorCount,
bindings[i].shaderStages);
getShaderStageFlags(bindings[i].shaderStages));
setBindings.push_back(descriptorSetLayoutBinding);
}
......
......@@ -55,8 +55,8 @@ namespace vkcv
{
const vk::RenderPass &pass = passManager.getVkPass(config.m_PassHandle);
const bool existsVertexShader = config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eVertex);
const bool existsFragmentShader = config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eFragment);
const bool existsVertexShader = config.m_ShaderProgram.existsShader(ShaderStage::VERTEX);
const bool existsFragmentShader = config.m_ShaderProgram.existsShader(ShaderStage::FRAGMENT);
if (!(existsVertexShader && existsFragmentShader))
{
vkcv_log(LogLevel::ERROR, "Requires vertex and fragment shader code");
......@@ -64,7 +64,7 @@ namespace vkcv
}
// vertex shader stage
std::vector<char> vertexCode = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eVertex).shaderCode;
std::vector<char> vertexCode = config.m_ShaderProgram.getShader(ShaderStage::VERTEX).shaderCode;
vk::ShaderModuleCreateInfo vertexModuleInfo({}, vertexCode.size(), reinterpret_cast<uint32_t*>(vertexCode.data()));
vk::ShaderModule vertexModule{};
if (m_Device.createShaderModule(&vertexModuleInfo, nullptr, &vertexModule) != vk::Result::eSuccess)
......@@ -79,7 +79,7 @@ namespace vkcv
);
// fragment shader stage
std::vector<char> fragCode = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eFragment).shaderCode;
std::vector<char> fragCode = config.m_ShaderProgram.getShader(ShaderStage::FRAGMENT).shaderCode;
vk::ShaderModuleCreateInfo fragmentModuleInfo({}, fragCode.size(), reinterpret_cast<uint32_t*>(fragCode.data()));
vk::ShaderModule fragmentModule{};
if (m_Device.createShaderModule(&fragmentModuleInfo, nullptr, &fragmentModule) != vk::Result::eSuccess)
......@@ -258,8 +258,8 @@ namespace vkcv
const char *geometryShaderName = "main"; // outside of if to make sure it stays in scope
vk::ShaderModule geometryModule;
if (config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eGeometry)) {
const vkcv::Shader geometryShader = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eGeometry);
if (config.m_ShaderProgram.existsShader(ShaderStage::GEOMETRY)) {
const vkcv::Shader geometryShader = config.m_ShaderProgram.getShader(ShaderStage::GEOMETRY);
const auto& geometryCode = geometryShader.shaderCode;
const vk::ShaderModuleCreateInfo geometryModuleInfo({}, geometryCode.size(), reinterpret_cast<const uint32_t*>(geometryCode.data()));
if (m_Device.createShaderModule(&geometryModuleInfo, nullptr, &geometryModule) != vk::Result::eSuccess) {
......@@ -375,7 +375,7 @@ namespace vkcv
// Temporally handing over the Shader Program instead of a pipeline config
vk::ShaderModule computeModule{};
if (createShaderModule(computeModule, shaderProgram, vk::ShaderStageFlagBits::eCompute) != vk::Result::eSuccess)
if (createShaderModule(computeModule, shaderProgram, ShaderStage::COMPUTE) != vk::Result::eSuccess)
return PipelineHandle();
vk::PipelineShaderStageCreateInfo pipelineComputeShaderStageInfo(
......@@ -424,7 +424,7 @@ namespace vkcv
// There is an issue for refactoring the Pipeline Manager.
// While including Compute Pipeline Creation, some private helper functions where introduced:
vk::Result PipelineManager::createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, const vk::ShaderStageFlagBits stage)
vk::Result PipelineManager::createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, const ShaderStage stage)
{
std::vector<char> code = shaderProgram.getShader(stage).shaderCode;
vk::ShaderModuleCreateInfo moduleInfo({}, code.size(), reinterpret_cast<uint32_t*>(code.data()));
......
......@@ -22,7 +22,7 @@ namespace vkcv
void destroyPipelineById(uint64_t id);
vk::Result createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, vk::ShaderStageFlagBits stage);
vk::Result createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, ShaderStage stage);
public:
PipelineManager() = delete; // no default ctor
......
......@@ -76,7 +76,7 @@ namespace vkcv {
m_DescriptorSets{}
{}
bool ShaderProgram::addShader(vk::ShaderStageFlagBits shaderStage, const std::filesystem::path &shaderPath)
bool ShaderProgram::addShader(ShaderStage shaderStage, const std::filesystem::path &shaderPath)
{
if(m_Shaders.find(shaderStage) != m_Shaders.end()) {
vkcv_log(LogLevel::WARNING, "Overwriting existing shader stage");
......@@ -94,12 +94,12 @@ namespace vkcv {
}
}
const Shader &ShaderProgram::getShader(vk::ShaderStageFlagBits shaderStage) const
const Shader &ShaderProgram::getShader(ShaderStage shaderStage) const
{
return m_Shaders.at(shaderStage);
}
bool ShaderProgram::existsShader(vk::ShaderStageFlagBits shaderStage) const
bool ShaderProgram::existsShader(ShaderStage shaderStage) const
{
if(m_Shaders.find(shaderStage) == m_Shaders.end())
return false;
......@@ -107,7 +107,7 @@ namespace vkcv {
return true;
}
void ShaderProgram::reflectShader(vk::ShaderStageFlagBits shaderStage)
void ShaderProgram::reflectShader(ShaderStage shaderStage)
{
auto shaderCodeChar = m_Shaders.at(shaderStage).shaderCode;
std::vector<uint32_t> shaderCode;
......@@ -119,7 +119,7 @@ namespace vkcv {
spirv_cross::ShaderResources resources = comp.get_shader_resources();
//reflect vertex input
if (shaderStage == vk::ShaderStageFlagBits::eVertex)
if (shaderStage == ShaderStage::VERTEX)
{
// spirv-cross API (hopefully) returns the stage_inputs in order
for (uint32_t i = 0; i < resources.stage_inputs.size(); i++)
......@@ -138,6 +138,12 @@ namespace vkcv {
m_VertexAttachments.emplace_back(attachment_loc, attachment_name, attachment_format);
}
}
ShaderStages stages;
stages |= ShaderStage::VERTEX;
stages |= ShaderStage::FRAGMENT;
vk::ShaderStageFlags flags = vk::ShaderStageFlagBits::eVertex | vk::ShaderStageFlagBits::eFragment;
//reflect descriptor sets (uniform buffer, storage buffer, sampler, sampled image, storage image)
std::vector<std::pair<uint32_t, DescriptorBinding>> bindings;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment