diff --git a/projects/mesh_shader/src/main.cpp b/projects/mesh_shader/src/main.cpp index 72b15e3176a9a5e3614a9238fefe618a19070ba0..998162be24bc148aff3651a31d65157d94d6c351 100644 --- a/projects/mesh_shader/src/main.cpp +++ b/projects/mesh_shader/src/main.cpp @@ -90,6 +90,14 @@ int main(int argc, const char** argv) { } ); + features.requireExtensionFeature<vk::PhysicalDeviceDescriptorIndexingFeatures>( + VK_EXT_DESCRIPTOR_INDEXING_EXTENSION_NAME, + [](vk::PhysicalDeviceDescriptorIndexingFeatures& features) { + features.setDescriptorBindingPartiallyBound(true); + features.setDescriptorBindingVariableDescriptorCount(true); + } + ); + vkcv::Core core = vkcv::Core::create( applicationName, VK_MAKE_VERSION(0, 0, 1), diff --git a/src/vkcv/ShaderProgram.cpp b/src/vkcv/ShaderProgram.cpp index 5ad185d950de7e2e75ccb4cb34804251b0d97c9b..4a62dec4fdf685d872dd8db6aadf4da7414f515b 100644 --- a/src/vkcv/ShaderProgram.cpp +++ b/src/vkcv/ShaderProgram.cpp @@ -8,6 +8,9 @@ #include "vkcv/File.hpp" #include "vkcv/Logger.hpp" +#include <cstddef> +#include <cstdint> +#include <limits> namespace vkcv { @@ -79,73 +82,48 @@ namespace vkcv { return true; } - void ShaderProgram::reflectShader(ShaderStage shaderStage) { - auto shaderCode = m_Shaders.at(shaderStage); - - spirv_cross::Compiler comp(shaderCode); - spirv_cross::ShaderResources resources = comp.get_shader_resources(); - - // reflect vertex input - if (shaderStage == ShaderStage::VERTEX) { - // spirv-cross API (hopefully) returns the stage_inputs in order - for (uint32_t i = 0; i < resources.stage_inputs.size(); i++) { - // spirv-cross specific objects - auto &stage_input = resources.stage_inputs [i]; - const spirv_cross::SPIRType &base_type = comp.get_type(stage_input.base_type_id); - - // vertex input location - const uint32_t attachment_loc = - comp.get_decoration(stage_input.id, spv::DecorationLocation); - // vertex input name - const std::string attachment_name = stage_input.name; - // vertex input format (implies its size) - const VertexAttachmentFormat attachment_format = - convertFormat(base_type.basetype, base_type.vecsize); - - m_VertexAttachments.push_back( - { attachment_loc, attachment_name, attachment_format, 0 }); - } + static void reflectShaderDescriptorSets(Dictionary<uint32_t, DescriptorBindings> &descriptorSets, + ShaderStage shaderStage, + DescriptorType descriptorType, + const spirv_cross::Compiler &comp, + const spirv_cross::ShaderResources &resources) { + const spirv_cross::SmallVector<spirv_cross::Resource> *res = nullptr; + + switch (descriptorType) { + case DescriptorType::UNIFORM_BUFFER: + res = &(resources.uniform_buffers); + break; + case DescriptorType::STORAGE_BUFFER: + res = &(resources.storage_buffers); + break; + case DescriptorType::SAMPLER: + res = &(resources.separate_samplers); + break; + case DescriptorType::IMAGE_SAMPLED: + res = &(resources.separate_images); + break; + case DescriptorType::IMAGE_STORAGE: + res = &(resources.storage_images); + break; + case DescriptorType::UNIFORM_BUFFER_DYNAMIC: + res = &(resources.uniform_buffers); + break; + case DescriptorType::STORAGE_BUFFER_DYNAMIC: + res = &(resources.storage_buffers); + break; + case DescriptorType::ACCELERATION_STRUCTURE_KHR: + res = &(resources.acceleration_structures); + break; + default: + break; } - // reflect descriptor sets (uniform buffer, storage buffer, sampler, sampled image, storage - // image) - Vector<std::pair<uint32_t, DescriptorBinding>> bindings; - - for (uint32_t i = 0; i < resources.uniform_buffers.size(); i++) { - auto &u = resources.uniform_buffers [i]; - const spirv_cross::SPIRType &base_type = comp.get_type(u.base_type_id); - const spirv_cross::SPIRType &type = comp.get_type(u.type_id); - - uint32_t setID = comp.get_decoration(u.id, spv::DecorationDescriptorSet); - uint32_t bindingID = comp.get_decoration(u.id, spv::DecorationBinding); - - uint32_t descriptorCount = base_type.vecsize; - bool variableCount = false; - // query whether reflected resources are qualified as one-dimensional array - if (type.array_size_literal [0]) { - if (type.array [0] == 0) - variableCount = true; - } - - DescriptorBinding binding { - bindingID, DescriptorType::UNIFORM_BUFFER, descriptorCount, shaderStage, - variableCount, - variableCount // partialBinding == variableCount - }; - - auto insertionResult = - m_DescriptorSets [setID].insert(std::make_pair(bindingID, binding)); - if (!insertionResult.second) { - insertionResult.first->second.shaderStages |= shaderStage; - - vkcv_log(LogLevel::WARNING, - "Attempting to overwrite already existing binding %u at set ID %u.", - bindingID, setID); - } + if (nullptr == res) { + return; } - for (uint32_t i = 0; i < resources.storage_buffers.size(); i++) { - auto &u = resources.storage_buffers [i]; + for (uint32_t i = 0; i < res->size(); i++) { + const spirv_cross::Resource &u = (*res)[i]; const spirv_cross::SPIRType &base_type = comp.get_type(u.base_type_id); const spirv_cross::SPIRType &type = comp.get_type(u.type_id); @@ -153,86 +131,23 @@ namespace vkcv { uint32_t bindingID = comp.get_decoration(u.id, spv::DecorationBinding); uint32_t descriptorCount = base_type.vecsize; - bool variableCount = false; - // query whether reflected resources are qualified as one-dimensional array - if (type.array_size_literal [0]) { - if (type.array [0] == 0) - variableCount = true; - } - DescriptorBinding binding { - bindingID, DescriptorType::STORAGE_BUFFER, descriptorCount, shaderStage, - variableCount, - variableCount // partialBinding == variableCount - }; - - auto insertionResult = - m_DescriptorSets [setID].insert(std::make_pair(bindingID, binding)); - if (!insertionResult.second) { - insertionResult.first->second.shaderStages |= shaderStage; - - vkcv_log(LogLevel::WARNING, - "Attempting to overwrite already existing binding %u at set ID %u.", - bindingID, setID); - } - } - - for (uint32_t i = 0; i < resources.separate_samplers.size(); i++) { - auto &u = resources.separate_samplers [i]; - const spirv_cross::SPIRType &base_type = comp.get_type(u.base_type_id); - const spirv_cross::SPIRType &type = comp.get_type(u.type_id); - - uint32_t setID = comp.get_decoration(u.id, spv::DecorationDescriptorSet); - uint32_t bindingID = comp.get_decoration(u.id, spv::DecorationBinding); - - uint32_t descriptorCount = base_type.vecsize; bool variableCount = false; // query whether reflected resources are qualified as one-dimensional array - if (type.array_size_literal [0]) { - if (type.array [0] == 0) - variableCount = true; + if (descriptorCount == 0) { + variableCount = true; } DescriptorBinding binding { - bindingID, DescriptorType::SAMPLER, descriptorCount, shaderStage, variableCount, - variableCount // partialBinding == variableCount - }; - - auto insertionResult = - m_DescriptorSets [setID].insert(std::make_pair(bindingID, binding)); - if (!insertionResult.second) { - insertionResult.first->second.shaderStages |= shaderStage; - - vkcv_log(LogLevel::WARNING, - "Attempting to overwrite already existing binding %u at set ID %u.", - bindingID, setID); - } - } - - for (uint32_t i = 0; i < resources.separate_images.size(); i++) { - auto &u = resources.separate_images [i]; - const spirv_cross::SPIRType &base_type = comp.get_type(u.base_type_id); - const spirv_cross::SPIRType &type = comp.get_type(u.type_id); - - uint32_t setID = comp.get_decoration(u.id, spv::DecorationDescriptorSet); - uint32_t bindingID = comp.get_decoration(u.id, spv::DecorationBinding); - - uint32_t descriptorCount = base_type.vecsize; - bool variableCount = false; - // query whether reflected resources are qualified as one-dimensional array - if (type.array_size_literal [0]) { - if (type.array [0] == 0) - variableCount = true; - } - - DescriptorBinding binding { - bindingID, DescriptorType::IMAGE_SAMPLED, descriptorCount, shaderStage, + bindingID, + descriptorType, + descriptorCount, + shaderStage, variableCount, variableCount // partialBinding == variableCount }; - auto insertionResult = - m_DescriptorSets [setID].insert(std::make_pair(bindingID, binding)); + auto insertionResult = descriptorSets[setID].insert(std::make_pair(bindingID, binding)); if (!insertionResult.second) { insertionResult.first->second.shaderStages |= shaderStage; @@ -241,62 +156,96 @@ namespace vkcv { bindingID, setID); } } + } - for (uint32_t i = 0; i < resources.storage_images.size(); i++) { - auto &u = resources.storage_images [i]; - const spirv_cross::SPIRType &base_type = comp.get_type(u.base_type_id); - const spirv_cross::SPIRType &type = comp.get_type(u.type_id); + void ShaderProgram::reflectShader(ShaderStage shaderStage) { + auto shaderCode = m_Shaders.at(shaderStage); - uint32_t setID = comp.get_decoration(u.id, spv::DecorationDescriptorSet); - uint32_t bindingID = comp.get_decoration(u.id, spv::DecorationBinding); + spirv_cross::Compiler comp(shaderCode); + spirv_cross::ShaderResources resources = comp.get_shader_resources(); - uint32_t descriptorCount = base_type.vecsize; - bool variableCount = false; - // query whether reflected resources are qualified as one-dimensional array - if (type.array_size_literal [0]) { - if (type.array [0] == 0) - variableCount = true; - } + // reflect vertex input + if (shaderStage == ShaderStage::VERTEX) { + // spirv-cross API (hopefully) returns the stage_inputs in order + for (uint32_t i = 0; i < resources.stage_inputs.size(); i++) { + // spirv-cross specific objects + auto &stage_input = resources.stage_inputs [i]; + const spirv_cross::SPIRType &base_type = comp.get_type(stage_input.base_type_id); - DescriptorBinding binding { - bindingID, DescriptorType::IMAGE_STORAGE, descriptorCount, shaderStage, - variableCount, - variableCount // partialBinding == variableCount - }; + // vertex input location + const uint32_t attachment_loc = + comp.get_decoration(stage_input.id, spv::DecorationLocation); + // vertex input name + const std::string attachment_name = stage_input.name; + // vertex input format (implies its size) + const VertexAttachmentFormat attachment_format = + convertFormat(base_type.basetype, base_type.vecsize); - auto insertionResult = - m_DescriptorSets [setID].insert(std::make_pair(bindingID, binding)); - if (!insertionResult.second) { - insertionResult.first->second.shaderStages |= shaderStage; - - vkcv_log(LogLevel::WARNING, - "Attempting to overwrite already existing binding %u at set ID %u.", - bindingID, setID); + m_VertexAttachments.push_back( + { attachment_loc, attachment_name, attachment_format, 0 }); } } - // Used to reflect acceleration structure bindings for RTX. - for (uint32_t i = 0; i < resources.acceleration_structures.size(); i++) { - auto &u = resources.acceleration_structures [i]; - const spirv_cross::SPIRType &base_type = comp.get_type(u.base_type_id); + reflectShaderDescriptorSets( + m_DescriptorSets, + shaderStage, + DescriptorType::UNIFORM_BUFFER, + comp, + resources + ); + + reflectShaderDescriptorSets( + m_DescriptorSets, + shaderStage, + DescriptorType::STORAGE_BUFFER, + comp, + resources + ); + + reflectShaderDescriptorSets( + m_DescriptorSets, + shaderStage, + DescriptorType::SAMPLER, + comp, + resources + ); + + reflectShaderDescriptorSets( + m_DescriptorSets, + shaderStage, + DescriptorType::IMAGE_SAMPLED, + comp, + resources + ); + + reflectShaderDescriptorSets( + m_DescriptorSets, + shaderStage, + DescriptorType::IMAGE_STORAGE, + comp, + resources + ); + + reflectShaderDescriptorSets( + m_DescriptorSets, + shaderStage, + DescriptorType::ACCELERATION_STRUCTURE_KHR, + comp, + resources + ); + + for (auto &descriptorSet : m_DescriptorSets) { + uint32_t maxVariableBindingID = 0; + + for (const auto &binding : descriptorSet.second) { + maxVariableBindingID = std::max(maxVariableBindingID, binding.first); + } - uint32_t setID = comp.get_decoration(u.id, spv::DecorationDescriptorSet); - uint32_t bindingID = comp.get_decoration(u.id, spv::DecorationBinding); - auto binding = DescriptorBinding { bindingID, - DescriptorType::ACCELERATION_STRUCTURE_KHR, - base_type.vecsize, - shaderStage, - false, - false }; - - auto insertionResult = - m_DescriptorSets [setID].insert(std::make_pair(bindingID, binding)); - if (!insertionResult.second) { - insertionResult.first->second.shaderStages |= shaderStage; - - vkcv_log(LogLevel::WARNING, - "Attempting to overwrite already existing binding %u at set ID %u.", - bindingID, setID); + for (auto &binding : descriptorSet.second) { + if (binding.first < maxVariableBindingID) { + binding.second.variableCount &= false; + binding.second.partialBinding &= false; + } } }