Commit b728b1f8 authored by Artur Wasmut's avatar Artur Wasmut
Browse files

refactor ShaderProgram class.

parent 1c84af84
......@@ -33,10 +33,10 @@ namespace vkcv {
*/
PipelineConfig(const ShaderProgram& shaderProgram, uint32_t width, uint32_t height, PassHandle &passHandle);
ShaderProgram m_shaderProgram;
uint32_t m_height;
uint32_t m_width;
PassHandle m_passHandle;
ShaderProgram m_ShaderProgram;
uint32_t m_Height;
uint32_t m_Width;
PassHandle m_PassHandle;
};
}
......
......@@ -5,35 +5,35 @@
* @brief ShaderProgram class to handle and prepare the shader stages for a graphics pipeline
*/
#define GLFW_INCLUDE_VULKAN
#include <vector>
#include <unordered_map>
#include <fstream>
#include <iostream>
#include <filesystem>
#include <vulkan/vulkan.hpp>
namespace vkcv {
class ShaderProgram final {
enum class ShaderStage
{
VERTEX,
TESS_CONTROL,
TESS_EVAL,
GEOMETRY,
FRAGMENT,
COMPUTE
};
struct Shader
{
std::vector<char> shaderCode;
ShaderStage shaderStage;
};
class ShaderProgram
{
public:
enum class ShaderStage {
VERTEX,
FRAGMENT,
COMPUTE
};
/**
* destructor of ShaderProgram, does nothing so far
*/
~ShaderProgram();
/**
* Creates a shader program.
* So far it only calls the constructor.
* @param[in] context of the app
*/
static ShaderProgram create();
ShaderProgram() noexcept; // ctor
~ShaderProgram() = default; // dtor
/**
* Adds a shader into the shader program.
......@@ -42,92 +42,19 @@ namespace vkcv {
* @param[in] flag that signals the respective shaderStage (e.g. VK_SHADER_STAGE_VERTEX_BIT)
* @param[in] relative path to the shader code (e.g. "../../../../../shaders/vert.spv")
*/
void addShader(ShaderProgram::ShaderStage shaderStage, const std::string& filepath);
/**
* Tests if the shader program contains a certain shader stage.
* @param[in] flag that signals the respective shader stage (e.g. VK_SHADER_STAGE_VERTEX_BIT)
* @return boolean that is true if the shader program contains the shader stage
*/
bool containsShaderStage(ShaderProgram::ShaderStage shaderStage) const;
/**
* Deletes the given shader stage in the shader program.
* @param[in] flag that signals the respective shader stage (e.g. VK_SHADER_STAGE_VERTEX_BIT)
* @return boolean that is false if the shader stage was not found in the shader program
*/
bool deleteShaderStage(ShaderProgram::ShaderStage shaderStage);
/**
* Returns a list with all the shader stages in the shader program.
* Needed for the transfer to the pipeline.
* @return vector list with all shader stage info structs
*/
std::vector<vk::ShaderStageFlagBits> getShaderStages() const;
/**
* Returns a list with all the shader code in the shader program.
* Needed for the transfer to the pipeline.
* @return vector list with all shader code char vecs
*/
std::vector<std::vector<char>> getShaderCode() const;
bool addShader(ShaderStage shaderStage, const std::filesystem::path &shaderPath);
/**
* Returns the number of shader stages in the shader program.
* Returns the shader program's shader of the specified shader.
* Needed for the transfer to the pipeline.
* @return integer with the number of stages
* @return Shader object consisting of buffer with shader code and shader stage enum
*/
int getShaderStagesCount() const;
const Shader &getShader(ShaderStage shaderStage) const;
bool existsShader(ShaderStage shaderStage) const;
private:
struct ShaderStages {
std::vector<std::vector<char>> shaderCode;
std::vector<vk::ShaderStageFlagBits> shaderStageFlag;
};
ShaderStages m_shaderStages;
/**
* Constructor of ShaderProgram requires a context for the logical device.
* @param context of the app
*/
ShaderProgram();
/**
* Reads the file of a given shader code.
* Only used within the class.
* @param[in] relative path to the shader code
* @return vector of chars as a buffer for the code
*/
std::vector<char> readFile(const std::string& filepath);
/**
* Converts ShaderStage Enum into vk::ShaderStageFlagBits
* @param[in] ShaderStage enum
* @return vk::ShaderStageFlagBits
*/
vk::ShaderStageFlagBits convertToShaderStageFlagBits(ShaderProgram::ShaderStage shaderStage) const;
/**
* Creates a shader module that encapsulates the read shader code.
* Only used within the class.
* Shader modules are destroyed after respective shader stages are created.
* @param[in] a vector of chars as a buffer for the code
* @return shader module
*/
//vk::ShaderModule createShaderModule(const std::vector<char>& shaderCode); -> Core
/**
* Creates a shader stage (info struct) for the to be added shader.
* Only used within the class.
* @param[in] Shader module that encapsulates the shader code
* @param[in] flag that signals the respective shaderStage
* @return pipeline shader stage info struct
*/
//vk::PipelineShaderStageCreateInfo createShaderStage(vk::ShaderModule& shaderModule, vk::ShaderStageFlagBits shaderStage); -> Core
std::unordered_map<ShaderStage, Shader> m_Shaders;
};
}
......@@ -57,9 +57,9 @@ int main(int argc, const char** argv) {
return EXIT_FAILURE;
}
vkcv::ShaderProgram triangleShaderProgram = vkcv::ShaderProgram::create();
triangleShaderProgram.addShader(vkcv::ShaderProgram::ShaderStage::VERTEX, "shaders/vert.spv");
triangleShaderProgram.addShader(vkcv::ShaderProgram::ShaderStage::FRAGMENT, "shaders/frag.spv");
vkcv::ShaderProgram triangleShaderProgram{};
triangleShaderProgram.addShader(vkcv::ShaderStage::VERTEX, std::filesystem::path("shaders/vert.spv"));
triangleShaderProgram.addShader(vkcv::ShaderStage::FRAGMENT, std::filesystem::path("shaders/frag.spv"));
const vkcv::PipelineConfig trianglePipelineDefinition(triangleShaderProgram, windowWidth, windowHeight, trianglePass);
vkcv::PipelineHandle trianglePipeline = core.createGraphicsPipeline(trianglePipelineDefinition);
......
......@@ -497,7 +497,7 @@ namespace vkcv
PipelineHandle Core::createGraphicsPipeline(const PipelineConfig &config)
{
const vk::RenderPass &pass = m_PassManager->getVkPass(config.m_passHandle);
const vk::RenderPass &pass = m_PassManager->getVkPass(config.m_PassHandle);
return m_PipelineManager->createPipeline(config, pass);
}
......
......@@ -9,5 +9,9 @@
namespace vkcv {
PipelineConfig::PipelineConfig(const ShaderProgram& shaderProgram, uint32_t width, uint32_t height, PassHandle &passHandle):
m_shaderProgram(shaderProgram), m_height(height), m_width(width), m_passHandle(passHandle) {}
m_ShaderProgram(shaderProgram),
m_Height(height),
m_Width(width),
m_PassHandle(passHandle)
{}
}
......@@ -25,32 +25,16 @@ namespace vkcv
PipelineHandle PipelineManager::createPipeline(const PipelineConfig &config, const vk::RenderPass &pass)
{
// TODO: this search could be avoided if ShaderProgram could be queried for a specific stage
const auto shaderStageFlags = config.m_shaderProgram.getShaderStages();
const auto shaderCode = config.m_shaderProgram.getShaderCode();
std::vector<char> vertexCode;
std::vector<char> fragCode;
assert(shaderStageFlags.size() == shaderCode.size());
for (int i = 0; i < shaderStageFlags.size(); i++) {
switch (shaderStageFlags[i]) {
case vk::ShaderStageFlagBits::eVertex: vertexCode = shaderCode[i]; break;
case vk::ShaderStageFlagBits::eFragment: fragCode = shaderCode[i]; break;
default: std::cout << "Core::createGraphicsPipeline encountered unknown shader stage" << std::endl;
return PipelineHandle{0};
}
}
const bool foundVertexCode = !vertexCode.empty();
const bool foundFragCode = !fragCode.empty();
const bool foundRequiredShaderCode = foundVertexCode && foundFragCode;
if (!foundRequiredShaderCode) {
const bool existsVertexShader = config.m_ShaderProgram.existsShader(ShaderStage::VERTEX);
const bool existsFragmentShader = config.m_ShaderProgram.existsShader(ShaderStage::FRAGMENT);
if (!(existsVertexShader && existsFragmentShader))
{
std::cout << "Core::createGraphicsPipeline requires vertex and fragment shader code" << std::endl;
return PipelineHandle{0};
}
// vertex shader stage
// TODO: store shader code as uint32_t in ShaderProgram to avoid pointer cast
std::vector<char> vertexCode = config.m_ShaderProgram.getShader(ShaderStage::VERTEX).shaderCode;
vk::ShaderModuleCreateInfo vertexModuleInfo({}, vertexCode.size(), reinterpret_cast<uint32_t*>(vertexCode.data()));
vk::ShaderModule vertexModule{};
if (m_Device.createShaderModule(&vertexModuleInfo, nullptr, &vertexModule) != vk::Result::eSuccess)
......@@ -65,6 +49,7 @@ namespace vkcv
);
// fragment shader stage
std::vector<char> fragCode = config.m_ShaderProgram.getShader(ShaderStage::FRAGMENT).shaderCode;
vk::ShaderModuleCreateInfo fragmentModuleInfo({}, fragCode.size(), reinterpret_cast<uint32_t*>(fragCode.data()));
vk::ShaderModule fragmentModule{};
if (m_Device.createShaderModule(&fragmentModuleInfo, nullptr, &fragmentModule) != vk::Result::eSuccess)
......@@ -101,8 +86,8 @@ namespace vkcv
);
// viewport state
vk::Viewport viewport(0.f, 0.f, static_cast<float>(config.m_width), static_cast<float>(config.m_height), 0.f, 1.f);
vk::Rect2D scissor({ 0,0 }, { config.m_width, config.m_height });
vk::Viewport viewport(0.f, 0.f, static_cast<float>(config.m_Width), static_cast<float>(config.m_Height), 0.f, 1.f);
vk::Rect2D scissor({ 0,0 }, { config.m_Width, config.m_Height });
vk::PipelineViewportStateCreateInfo pipelineViewportStateCreateInfo({}, 1, &viewport, 1, &scissor);
// rasterization state
......@@ -211,11 +196,11 @@ namespace vkcv
vk::Pipeline PipelineManager::getVkPipeline(const PipelineHandle &handle) const
{
return m_Pipelines[handle.id -1];
return m_Pipelines.at(handle.id -1);
}
vk::PipelineLayout PipelineManager::getVkPipelineLayout(const PipelineHandle &handle) const
{
return m_PipelineLayouts[handle.id - 1];
return m_PipelineLayouts.at(handle.id - 1);
}
}
\ No newline at end of file
......@@ -6,114 +6,57 @@
#include "vkcv/ShaderProgram.hpp"
std::vector<const char*> validationLayers = {
"VK_LAYER_KHRONOS_validation"
};
namespace vkcv {
ShaderProgram::ShaderProgram(){
ShaderStages m_shaderStages{};
m_shaderStages.shaderCode = std::vector<std::vector<char>> ();
m_shaderStages.shaderStageFlag = std::vector<vk::ShaderStageFlagBits> ();
}
std::vector<char> ShaderProgram::readFile(const std::string& filepath) {
std::ifstream file(filepath, std::ios::ate | std::ios::binary);
/**
* Reads the file of a given shader code.
* Only used within the class.
* @param[in] relative path to the shader code
* @return vector of chars as a buffer for the code
*/
std::vector<char> readShaderCode(const std::filesystem::path &shaderPath)
{
std::ifstream file(shaderPath.string(), std::ios::ate | std::ios::binary);
if (!file.is_open()) {
throw std::runtime_error("The file could not be opened.");
std::cout << "The file could not be opened." << std::endl;
return std::vector<char>{};
}
size_t fileSize = (size_t)file.tellg();
std::vector<char> buffer(fileSize);
file.seekg(0);
file.read(buffer.data(), fileSize);
return buffer;
}
vk::ShaderStageFlagBits ShaderProgram::convertToShaderStageFlagBits(ShaderProgram::ShaderStage shaderStage) const{
switch (shaderStage) {
case ShaderStage::VERTEX:
return vk::ShaderStageFlagBits::eVertex;
case ShaderStage::FRAGMENT:
return vk::ShaderStageFlagBits::eFragment;
case ShaderStage::COMPUTE:
return vk::ShaderStageFlagBits::eCompute;
return buffer;
}
ShaderProgram::ShaderProgram() noexcept :
m_Shaders{}
{}
bool ShaderProgram::addShader(ShaderStage shaderStage, const std::filesystem::path &shaderPath)
{
if(m_Shaders.find(shaderStage) != m_Shaders.end())
std::cout << "Found existing shader stage. Overwriting." << std::endl;
const std::vector<char> shaderCode = readShaderCode(shaderPath);
if (shaderCode.empty())
return false;
else
{
Shader shader{shaderCode, shaderStage};
m_Shaders.insert(std::make_pair(shaderStage, shader));
return true;
}
throw std::runtime_error("Shader Type not yet implemented.");
}
/*vk::ShaderModule ShaderProgram::createShaderModule(const std::vector<char>& shaderCode) {
vk::ShaderModuleCreateInfo createInfo({}, shaderCode.size(), reinterpret_cast<const uint32_t*>(shaderCode.data()));
vk::ShaderModule shaderModule;
if ((m_context.getDevice().createShaderModule(&createInfo, nullptr, &shaderModule)) != vk::Result::eSuccess) {
throw std::runtime_error("Failed to create shader module!");
}
return shaderModule;
}*/
/*vk::PipelineShaderStageCreateInfo ShaderProgram::createShaderStage(vk::ShaderModule& shaderModule, vk::ShaderStageFlagBits shaderStage) {
vk::PipelineShaderStageCreateInfo shaderStageInfo({}, shaderStage, shaderModule, "main", {});
shaderStageInfo.stage = shaderStage;
shaderStageInfo.module = shaderModule;
shaderStageInfo.pName = "main";
return shaderStageInfo;
}*/
ShaderProgram::~ShaderProgram() {
}
ShaderProgram ShaderProgram::create() {
return ShaderProgram();
}
void ShaderProgram::addShader(ShaderProgram::ShaderStage shaderStage, const std::string& filepath) {
if (containsShaderStage(shaderStage)) {
throw std::runtime_error("Shader program already contains this particular shader stage.");
}
else {
auto shaderCode = readFile(filepath);
vk::ShaderStageFlagBits convertedShaderStage = convertToShaderStageFlagBits(shaderStage);
//vk::ShaderModule shaderModule = createShaderModule(shaderCode);
//vk::PipelineShaderStageCreateInfo shaderInfo = createShaderStage(shaderModule, shaderStage);
//m_shaderStagesList.push_back(shaderInfo);
//m_context.getDevice().destroyShaderModule(shaderModule, nullptr);
m_shaderStages.shaderCode.push_back(shaderCode);
m_shaderStages.shaderStageFlag.push_back(convertedShaderStage);
}
const Shader &ShaderProgram::getShader(ShaderStage shaderStage) const
{
return m_Shaders.at(shaderStage);
}
bool ShaderProgram::containsShaderStage(ShaderProgram::ShaderStage shaderStage) const{
vk::ShaderStageFlagBits convertedShaderStage = convertToShaderStageFlagBits(shaderStage);
for (int i = 0; i < m_shaderStages.shaderStageFlag.size(); i++) {
if (m_shaderStages.shaderStageFlag[i] == convertedShaderStage) {
return true;
}
}
return false;
}
bool ShaderProgram::deleteShaderStage(ShaderProgram::ShaderStage shaderStage) {
vk::ShaderStageFlagBits convertedShaderStage = convertToShaderStageFlagBits(shaderStage);
for (int i = 0; i < m_shaderStages.shaderStageFlag.size() - 1; i++) {
if (m_shaderStages.shaderStageFlag[i] == convertedShaderStage) {
m_shaderStages.shaderStageFlag.erase(m_shaderStages.shaderStageFlag.begin() + i);
m_shaderStages.shaderCode.erase(m_shaderStages.shaderCode.begin() + i);
return true;
}
}
return false;
}
std::vector<vk::ShaderStageFlagBits> ShaderProgram::getShaderStages() const{
return m_shaderStages.shaderStageFlag;
}
std::vector<std::vector<char>> ShaderProgram::getShaderCode() const {
return m_shaderStages.shaderCode;
}
int ShaderProgram::getShaderStagesCount() const {
return m_shaderStages.shaderStageFlag.size();
}
bool ShaderProgram::existsShader(ShaderStage shaderStage) const
{
if(m_Shaders.find(shaderStage) == m_Shaders.end())
return false;
else
return true;
}
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment