diff --git a/include/vkcv/DescriptorConfig.hpp b/include/vkcv/DescriptorConfig.hpp
index 29dc81240367d33e0bb96ef6146e3540d80ac50f..78e111f40144a825bb790149f02a91039eb517b9 100644
--- a/include/vkcv/DescriptorConfig.hpp
+++ b/include/vkcv/DescriptorConfig.hpp
@@ -38,12 +38,12 @@ namespace vkcv
             uint32_t bindingID,
             DescriptorType descriptorType,
             uint32_t descriptorCount,
-            vk::ShaderStageFlags shaderStages
+            ShaderStages shaderStages
         ) noexcept;
         
         uint32_t bindingID;
         DescriptorType descriptorType;
         uint32_t descriptorCount;
-        vk::ShaderStageFlags shaderStages;
+        ShaderStages shaderStages;
     };
 }
diff --git a/include/vkcv/ShaderProgram.hpp b/include/vkcv/ShaderProgram.hpp
index 707d72a864cfd95936a476a11675818e48cc6a9d..c7d67b19148b3c9ec19ce1b539f9661797d1b38f 100644
--- a/include/vkcv/ShaderProgram.hpp
+++ b/include/vkcv/ShaderProgram.hpp
@@ -14,13 +14,14 @@
 #include <spirv_cross.hpp>
 #include "VertexLayout.hpp"
 #include "DescriptorConfig.hpp"
+#include "ShaderStage.hpp"
 
 namespace vkcv {
 
     struct Shader
     {
         std::vector<char> shaderCode;
-        vk::ShaderStageFlagBits shaderStage;
+        ShaderStage shaderStage;
     };
 
 	class ShaderProgram
@@ -36,16 +37,16 @@ namespace vkcv {
         * @param[in] flag that signals the respective shaderStage (e.g. VK_SHADER_STAGE_VERTEX_BIT)
         * @param[in] relative path to the shader code (e.g. "../../../../../shaders/vert.spv")
         */
-        bool addShader(vk::ShaderStageFlagBits shaderStage, const std::filesystem::path &shaderPath);
+        bool addShader(ShaderStage shaderStage, const std::filesystem::path &shaderPath);
 
         /**
         * Returns the shader program's shader of the specified shader.
         * Needed for the transfer to the pipeline.
         * @return Shader object consisting of buffer with shader code and shader stage enum
         */
-        const Shader &getShader(vk::ShaderStageFlagBits shaderStage) const;
+        const Shader &getShader(ShaderStage shaderStage) const;
 
-        bool existsShader(vk::ShaderStageFlagBits shaderStage) const;
+        bool existsShader(ShaderStage shaderStage) const;
 
         const std::vector<VertexAttachment> &getVertexAttachments() const;
 		size_t getPushConstantSize() const;
@@ -58,9 +59,9 @@ namespace vkcv {
 	     * Fills vertex input attachments and descriptor sets (if present).
 	     * @param shaderStage the stage to reflect data from
 	     */
-        void reflectShader(vk::ShaderStageFlagBits shaderStage);
+        void reflectShader(ShaderStage shaderStage);
 
-        std::unordered_map<vk::ShaderStageFlagBits, Shader> m_Shaders;
+        std::unordered_map<ShaderStage, Shader> m_Shaders;
 
         // contains all vertex input attachments used in the vertex buffer
         std::vector<VertexAttachment> m_VertexAttachments;
diff --git a/include/vkcv/ShaderStage.hpp b/include/vkcv/ShaderStage.hpp
index 488fd92fb081a4671055daa5bde4c0bf8116b71a..d671b87b55ac7a5a8926e479c77fa991dd90c665 100644
--- a/include/vkcv/ShaderStage.hpp
+++ b/include/vkcv/ShaderStage.hpp
@@ -1,5 +1,38 @@
 #pragma once
 
-namespace vkcv {
+#include <vulkan/vulkan.hpp>
 
+namespace vkcv {
+	
+	enum class ShaderStage : VkShaderStageFlags {
+		VERTEX = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eVertex),
+		TESS_CONTROL = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eTessellationControl),
+		TESS_EVAL = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eTessellationEvaluation),
+		GEOMETRY = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eGeometry),
+		FRAGMENT = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eFragment),
+		COMPUTE = static_cast<VkShaderStageFlags>(vk::ShaderStageFlagBits::eCompute)
+	};
+	
+	using ShaderStages = vk::Flags<ShaderStage>;
+	
+	constexpr vk::ShaderStageFlags getShaderStageFlags(ShaderStages shaderStages) noexcept {
+		return vk::ShaderStageFlags(static_cast<VkShaderStageFlags>(shaderStages));
+	}
+	
+	constexpr ShaderStages operator|(ShaderStage stage0, ShaderStage stage1) noexcept {
+		return ShaderStages(stage0) | stage1;
+	}
+	
+	constexpr ShaderStages operator&(ShaderStage stage0, ShaderStage stage1) noexcept {
+		return ShaderStages(stage0) & stage1;
+	}
+	
+	constexpr ShaderStages operator^(ShaderStage stage0, ShaderStage stage1) noexcept {
+		return ShaderStages(stage0) ^ stage1;
+	}
+	
+	constexpr ShaderStages operator~(ShaderStage stage) noexcept {
+		return ~(ShaderStages(stage));
+	}
+	
 }
diff --git a/projects/first_mesh/src/main.cpp b/projects/first_mesh/src/main.cpp
index fabea47fc06fc15838bd390525effd32d62de8bc..dc43c905784525a34732bc0e66343fbdcc17a639 100644
--- a/projects/first_mesh/src/main.cpp
+++ b/projects/first_mesh/src/main.cpp
@@ -79,8 +79,8 @@ int main(int argc, const char** argv) {
 	}
 
 	vkcv::ShaderProgram firstMeshProgram{};
-    firstMeshProgram.addShader(vk::ShaderStageFlagBits::eVertex, std::filesystem::path("resources/shaders/vert.spv"));
-    firstMeshProgram.addShader(vk::ShaderStageFlagBits::eFragment, std::filesystem::path("resources/shaders/frag.spv"));
+    firstMeshProgram.addShader(vkcv::ShaderStage::VERTEX, std::filesystem::path("resources/shaders/vert.spv"));
+    firstMeshProgram.addShader(vkcv::ShaderStage::FRAGMENT, std::filesystem::path("resources/shaders/frag.spv"));
 	
 	auto& attributes = mesh.vertexGroups[0].vertexBuffer.attributes;
 	
diff --git a/src/vkcv/DescriptorConfig.cpp b/src/vkcv/DescriptorConfig.cpp
index 1fc075ced91aebe4a5d3596c635208368db08822..a9a127fe608472682c1cbc8d32ca466fba860c72 100644
--- a/src/vkcv/DescriptorConfig.cpp
+++ b/src/vkcv/DescriptorConfig.cpp
@@ -5,7 +5,7 @@ namespace vkcv {
 		uint32_t bindingID,
 		DescriptorType descriptorType,
 		uint32_t descriptorCount,
-		vk::ShaderStageFlags shaderStages) noexcept
+		ShaderStages shaderStages) noexcept
 		:
 		bindingID(bindingID),
 		descriptorType(descriptorType),
diff --git a/src/vkcv/DescriptorManager.cpp b/src/vkcv/DescriptorManager.cpp
index 676e840a1b7d28548e83cbb53ccbd62992a00e2c..d28dd9d137240ba923b55c9be9da9059d3a9ab31 100644
--- a/src/vkcv/DescriptorManager.cpp
+++ b/src/vkcv/DescriptorManager.cpp
@@ -45,7 +45,7 @@ namespace vkcv
                 bindings[i].bindingID,
                 convertDescriptorTypeFlag(bindings[i].descriptorType),
                 bindings[i].descriptorCount,
-                bindings[i].shaderStages);
+				getShaderStageFlags(bindings[i].shaderStages));
             setBindings.push_back(descriptorSetLayoutBinding);
         }
 
diff --git a/src/vkcv/PipelineManager.cpp b/src/vkcv/PipelineManager.cpp
index c8f885dffe5593c8f47f702bfdb99128a9d5a12d..df36442efc2992bf16b6e82245ef9753dad95e5d 100644
--- a/src/vkcv/PipelineManager.cpp
+++ b/src/vkcv/PipelineManager.cpp
@@ -55,8 +55,8 @@ namespace vkcv
     {
 		const vk::RenderPass &pass = passManager.getVkPass(config.m_PassHandle);
     	
-        const bool existsVertexShader = config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eVertex);
-        const bool existsFragmentShader = config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eFragment);
+        const bool existsVertexShader = config.m_ShaderProgram.existsShader(ShaderStage::VERTEX);
+        const bool existsFragmentShader = config.m_ShaderProgram.existsShader(ShaderStage::FRAGMENT);
         if (!(existsVertexShader && existsFragmentShader))
         {
 			vkcv_log(LogLevel::ERROR, "Requires vertex and fragment shader code");
@@ -64,7 +64,7 @@ namespace vkcv
         }
 
         // vertex shader stage
-        std::vector<char> vertexCode = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eVertex).shaderCode;
+        std::vector<char> vertexCode = config.m_ShaderProgram.getShader(ShaderStage::VERTEX).shaderCode;
         vk::ShaderModuleCreateInfo vertexModuleInfo({}, vertexCode.size(), reinterpret_cast<uint32_t*>(vertexCode.data()));
         vk::ShaderModule vertexModule{};
         if (m_Device.createShaderModule(&vertexModuleInfo, nullptr, &vertexModule) != vk::Result::eSuccess)
@@ -79,7 +79,7 @@ namespace vkcv
         );
 
         // fragment shader stage
-        std::vector<char> fragCode = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eFragment).shaderCode;
+        std::vector<char> fragCode = config.m_ShaderProgram.getShader(ShaderStage::FRAGMENT).shaderCode;
         vk::ShaderModuleCreateInfo fragmentModuleInfo({}, fragCode.size(), reinterpret_cast<uint32_t*>(fragCode.data()));
         vk::ShaderModule fragmentModule{};
         if (m_Device.createShaderModule(&fragmentModuleInfo, nullptr, &fragmentModule) != vk::Result::eSuccess)
@@ -258,8 +258,8 @@ namespace vkcv
 
 		const char *geometryShaderName = "main";	// outside of if to make sure it stays in scope
 		vk::ShaderModule geometryModule;
-		if (config.m_ShaderProgram.existsShader(vk::ShaderStageFlagBits::eGeometry)) {
-			const vkcv::Shader geometryShader = config.m_ShaderProgram.getShader(vk::ShaderStageFlagBits::eGeometry);
+		if (config.m_ShaderProgram.existsShader(ShaderStage::GEOMETRY)) {
+			const vkcv::Shader geometryShader = config.m_ShaderProgram.getShader(ShaderStage::GEOMETRY);
 			const auto& geometryCode = geometryShader.shaderCode;
 			const vk::ShaderModuleCreateInfo geometryModuleInfo({}, geometryCode.size(), reinterpret_cast<const uint32_t*>(geometryCode.data()));
 			if (m_Device.createShaderModule(&geometryModuleInfo, nullptr, &geometryModule) != vk::Result::eSuccess) {
@@ -375,7 +375,7 @@ namespace vkcv
 
         // Temporally handing over the Shader Program instead of a pipeline config
         vk::ShaderModule computeModule{};
-        if (createShaderModule(computeModule, shaderProgram, vk::ShaderStageFlagBits::eCompute) != vk::Result::eSuccess)
+        if (createShaderModule(computeModule, shaderProgram, ShaderStage::COMPUTE) != vk::Result::eSuccess)
             return PipelineHandle();
 
         vk::PipelineShaderStageCreateInfo pipelineComputeShaderStageInfo(
@@ -424,7 +424,7 @@ namespace vkcv
     // There is an issue for refactoring the Pipeline Manager.
     // While including Compute Pipeline Creation, some private helper functions where introduced:
 
-    vk::Result PipelineManager::createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, const vk::ShaderStageFlagBits stage)
+    vk::Result PipelineManager::createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, const ShaderStage stage)
     {
         std::vector<char> code = shaderProgram.getShader(stage).shaderCode;
         vk::ShaderModuleCreateInfo moduleInfo({}, code.size(), reinterpret_cast<uint32_t*>(code.data()));
diff --git a/src/vkcv/PipelineManager.hpp b/src/vkcv/PipelineManager.hpp
index b48e06f07de6b863ae79a810482a355f5e4e280a..b153eb4632b844e84b92953fe8abf6666a13e0c9 100644
--- a/src/vkcv/PipelineManager.hpp
+++ b/src/vkcv/PipelineManager.hpp
@@ -22,7 +22,7 @@ namespace vkcv
         
         void destroyPipelineById(uint64_t id);
 
-        vk::Result createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, vk::ShaderStageFlagBits stage);
+        vk::Result createShaderModule(vk::ShaderModule &module, const ShaderProgram &shaderProgram, ShaderStage stage);
 
     public:
         PipelineManager() = delete; // no default ctor
diff --git a/src/vkcv/ShaderProgram.cpp b/src/vkcv/ShaderProgram.cpp
index 4c1508ad14ded8bd33a1444910defcc361599f79..59bc70ea76c1aebcb1f41a5583ec6938860bc918 100644
--- a/src/vkcv/ShaderProgram.cpp
+++ b/src/vkcv/ShaderProgram.cpp
@@ -76,7 +76,7 @@ namespace vkcv {
     m_DescriptorSets{}
 	{}
 
-	bool ShaderProgram::addShader(vk::ShaderStageFlagBits shaderStage, const std::filesystem::path &shaderPath)
+	bool ShaderProgram::addShader(ShaderStage shaderStage, const std::filesystem::path &shaderPath)
 	{
 	    if(m_Shaders.find(shaderStage) != m_Shaders.end()) {
 			vkcv_log(LogLevel::WARNING, "Overwriting existing shader stage");
@@ -94,12 +94,12 @@ namespace vkcv {
         }
 	}
 
-    const Shader &ShaderProgram::getShader(vk::ShaderStageFlagBits shaderStage) const
+    const Shader &ShaderProgram::getShader(ShaderStage shaderStage) const
     {
 	    return m_Shaders.at(shaderStage);
 	}
 
-    bool ShaderProgram::existsShader(vk::ShaderStageFlagBits shaderStage) const
+    bool ShaderProgram::existsShader(ShaderStage shaderStage) const
     {
 	    if(m_Shaders.find(shaderStage) == m_Shaders.end())
 	        return false;
@@ -107,7 +107,7 @@ namespace vkcv {
 	        return true;
     }
 
-    void ShaderProgram::reflectShader(vk::ShaderStageFlagBits shaderStage)
+    void ShaderProgram::reflectShader(ShaderStage shaderStage)
     {
         auto shaderCodeChar = m_Shaders.at(shaderStage).shaderCode;
         std::vector<uint32_t> shaderCode;
@@ -119,7 +119,7 @@ namespace vkcv {
         spirv_cross::ShaderResources resources = comp.get_shader_resources();
 
         //reflect vertex input
-		if (shaderStage == vk::ShaderStageFlagBits::eVertex)
+		if (shaderStage == ShaderStage::VERTEX)
 		{
 			// spirv-cross API (hopefully) returns the stage_inputs in order
 			for (uint32_t i = 0; i < resources.stage_inputs.size(); i++)
@@ -138,6 +138,12 @@ namespace vkcv {
                 m_VertexAttachments.emplace_back(attachment_loc, attachment_name, attachment_format);
             }
 		}
+		
+		ShaderStages stages;
+		stages |= ShaderStage::VERTEX;
+		stages |= ShaderStage::FRAGMENT;
+		
+		vk::ShaderStageFlags flags = vk::ShaderStageFlagBits::eVertex | vk::ShaderStageFlagBits::eFragment;
 
 		//reflect descriptor sets (uniform buffer, storage buffer, sampler, sampled image, storage image)
         std::vector<std::pair<uint32_t, DescriptorBinding>> bindings;