From f005b1a90ff83b1f1d0732a8b58ae76d5060405a Mon Sep 17 00:00:00 2001
From: Jacki <jacki@thejackimonster.de>
Date: Mon, 16 Sep 2024 20:32:02 +0200
Subject: [PATCH] Implement slang compiler

Signed-off-by: Jacki <jacki@thejackimonster.de>
---
 .../include/vkcv/shader/Compiler.hpp          |   6 +-
 .../include/vkcv/shader/GlslangCompiler.hpp   |  19 ---
 .../include/vkcv/shader/ShadyCompiler.hpp     |  17 --
 .../include/vkcv/shader/SlangCompiler.hpp     |  44 +++--
 .../src/vkcv/shader/Compiler.cpp              |  27 ++++
 .../src/vkcv/shader/GlslangCompiler.cpp       |  27 ----
 .../src/vkcv/shader/ShadyCompiler.cpp         |  27 ----
 .../src/vkcv/shader/SlangCompiler.cpp         | 151 ++++++++++++++----
 .../src/vkcv/shader/SlimCompiler.cpp          |  38 ++---
 9 files changed, 197 insertions(+), 159 deletions(-)

diff --git a/modules/shader_compiler/include/vkcv/shader/Compiler.hpp b/modules/shader_compiler/include/vkcv/shader/Compiler.hpp
index 873c7ba0..67d63a28 100644
--- a/modules/shader_compiler/include/vkcv/shader/Compiler.hpp
+++ b/modules/shader_compiler/include/vkcv/shader/Compiler.hpp
@@ -84,11 +84,11 @@ namespace vkcv::shader {
          * @param[in] includePath Include path for shaders
          * @param[in] update Flag to update shaders during runtime
          */
-		virtual void compile(ShaderStage shaderStage,
+	     void compile(ShaderStage shaderStage,
 							 const std::filesystem::path& shaderPath,
 							 const ShaderCompiledFunction& compiled,
-							 const std::filesystem::path& includePath,
-							 bool update) = 0;
+							 const std::filesystem::path& includePath = "",
+							 bool update = false);
 		
 		/**
          * Compile a shader program from a specific map of given file paths for
diff --git a/modules/shader_compiler/include/vkcv/shader/GlslangCompiler.hpp b/modules/shader_compiler/include/vkcv/shader/GlslangCompiler.hpp
index 161c796e..443290de 100644
--- a/modules/shader_compiler/include/vkcv/shader/GlslangCompiler.hpp
+++ b/modules/shader_compiler/include/vkcv/shader/GlslangCompiler.hpp
@@ -1,7 +1,5 @@
 #pragma once
 
-#include <filesystem>
-
 #include <vkcv/ShaderStage.hpp>
 #include "Compiler.hpp"
 
@@ -57,23 +55,6 @@ namespace vkcv::shader {
 		 */
 		GlslangCompiler& operator=(GlslangCompiler&& other) = default;
 		
-		/**
-         * Compile a shader from a specific file path for a target stage with
-         * a custom shader include path and an event function called if the
-         * compilation completes.
-         *
-         * @param[in] shaderStage Shader pipeline stage
-         * @param[in] shaderPath Filepath of shader
-         * @param[in] compiled Shader compilation event
-         * @param[in] includePath Include path for shaders
-         * @param[in] update Flag to update shaders during runtime
-         */
-		void compile(ShaderStage shaderStage,
-					 const std::filesystem::path& shaderPath,
-					 const ShaderCompiledFunction& compiled,
-					 const std::filesystem::path& includePath = "",
-					 bool update = false) override;
-		
 	};
 	
 	/** @} */
diff --git a/modules/shader_compiler/include/vkcv/shader/ShadyCompiler.hpp b/modules/shader_compiler/include/vkcv/shader/ShadyCompiler.hpp
index 65e732f0..f9278329 100644
--- a/modules/shader_compiler/include/vkcv/shader/ShadyCompiler.hpp
+++ b/modules/shader_compiler/include/vkcv/shader/ShadyCompiler.hpp
@@ -57,23 +57,6 @@ namespace vkcv::shader {
 		 */
         ShadyCompiler& operator=(ShadyCompiler&& other) = default;
 
-        /**
-         * Compile a shader from a specific file path for a target stage with
-         * a custom shader include path and an event function called if the
-         * compilation completes.
-         *
-         * @param[in] shaderStage Shader pipeline stage
-         * @param[in] shaderPath Filepath of shader
-         * @param[in] compiled Shader compilation event
-         * @param[in] includePath Include path for shaders
-         * @param[in] update Flag to update shaders during runtime
-         */
-        void compile(ShaderStage shaderStage,
-					 const std::filesystem::path& shaderPath,
-					 const ShaderCompiledFunction& compiled,
-					 const std::filesystem::path& includePath = "",
-					 bool update = false) override;
-
     };
 
     /** @} */
diff --git a/modules/shader_compiler/include/vkcv/shader/SlangCompiler.hpp b/modules/shader_compiler/include/vkcv/shader/SlangCompiler.hpp
index 3c13bbd1..556627e9 100644
--- a/modules/shader_compiler/include/vkcv/shader/SlangCompiler.hpp
+++ b/modules/shader_compiler/include/vkcv/shader/SlangCompiler.hpp
@@ -12,15 +12,27 @@ namespace vkcv::shader {
     * @{
     */
 	
+	enum class SlangCompileProfile {
+		GLSL,
+		HLSL,
+		SPIRV,
+		UNKNOWN
+	};
+	
 	/**
 	 * An abstract class to handle Slang runtime shader compilation.
 	 */
 	class SlangCompiler : public Compiler {
+	private:
+		SlangCompileProfile m_profile;
+	
 	public:
 		/**
 		 * The constructor of a runtime Slang shader compiler instance.
+		 *
+		 * @param[in] profile Compile profile (optional)
 		 */
-		SlangCompiler();
+		SlangCompiler(SlangCompileProfile profile = SlangCompileProfile::UNKNOWN);
 		
 		/**
 		 * The copy-constructor of a runtime Slang shader compiler instance.
@@ -56,23 +68,21 @@ namespace vkcv::shader {
 		 * @return Reference to this instance
 		 */
 		SlangCompiler& operator=(SlangCompiler&& other) = default;
-		
+
 		/**
-         * Compile a shader from a specific file path for a target stage with
-         * a custom shader include path and an event function called if the
-         * compilation completes.
-         *
-         * @param[in] shaderStage Shader pipeline stage
-         * @param[in] shaderPath Filepath of shader
-         * @param[in] compiled Shader compilation event
-         * @param[in] includePath Include path for shaders
-         * @param[in] update Flag to update shaders during runtime
-         */
-		void compile(ShaderStage shaderStage,
-					 const std::filesystem::path& shaderPath,
-					 const ShaderCompiledFunction& compiled,
-					 const std::filesystem::path& includePath = "",
-					 bool update = false) override;
+		 * Compile a shader from source for a target stage with a custom shader
+		 * include path and an event function called if the compilation completes.
+		 *
+		 * @param[in] shaderStage Shader pipeline stage
+		 * @param[in] shaderSource Source of shader
+		 * @param[in] compiled Shader compilation event
+		 * @param[in] includePath Include path for shaders
+		 * @return Result if the compilation succeeds
+		 */
+		bool compileSource(ShaderStage shaderStage,
+											 const std::string& shaderSource,
+											 const ShaderCompiledFunction& compiled,
+											 const std::filesystem::path& includePath) override;
 		
 	};
 	
diff --git a/modules/shader_compiler/src/vkcv/shader/Compiler.cpp b/modules/shader_compiler/src/vkcv/shader/Compiler.cpp
index 194faee9..842ad6ab 100644
--- a/modules/shader_compiler/src/vkcv/shader/Compiler.cpp
+++ b/modules/shader_compiler/src/vkcv/shader/Compiler.cpp
@@ -41,6 +41,33 @@ namespace vkcv::shader {
 				}, directory
 		);
 	}
+
+	void Compiler::compile(ShaderStage shaderStage,
+												 const std::filesystem::path &shaderPath,
+												 const ShaderCompiledFunction &compiled,
+												 const std::filesystem::path &includePath,
+												 bool update) {
+		std::string shaderCode;
+		bool result = readTextFromFile(shaderPath, shaderCode);
+		
+		if (!result) {
+			vkcv_log(LogLevel::ERROR, "Loading shader failed: (%s)", shaderPath.string().c_str());
+		}
+		
+		if (!includePath.empty()) {
+			result = compileSource(shaderStage, shaderCode, compiled, includePath);
+		} else {
+			result = compileSource(shaderStage, shaderCode, compiled, shaderPath.parent_path());
+		}
+		
+		if (!result) {
+			vkcv_log(LogLevel::ERROR, "Shader compilation failed: (%s)", shaderPath.string().c_str());
+		}
+		
+		if (update) {
+			// TODO: Shader hot compilation during runtime
+		}
+	}
 	
 	void Compiler::compileProgram(ShaderProgram& program,
 								  const Dictionary<ShaderStage, const std::filesystem::path>& stages,
diff --git a/modules/shader_compiler/src/vkcv/shader/GlslangCompiler.cpp b/modules/shader_compiler/src/vkcv/shader/GlslangCompiler.cpp
index 2a2847ad..b749f950 100644
--- a/modules/shader_compiler/src/vkcv/shader/GlslangCompiler.cpp
+++ b/modules/shader_compiler/src/vkcv/shader/GlslangCompiler.cpp
@@ -36,31 +36,4 @@ namespace vkcv::shader {
 		return *this;
 	}
 	
-	void GlslangCompiler::compile(ShaderStage shaderStage,
-								  const std::filesystem::path &shaderPath,
-								  const ShaderCompiledFunction &compiled,
-								  const std::filesystem::path &includePath,
-								  bool update) {
-		std::string shaderCode;
-		bool result = readTextFromFile(shaderPath, shaderCode);
-		
-		if (!result) {
-			vkcv_log(LogLevel::ERROR, "Loading shader failed: (%s)", shaderPath.string().c_str());
-		}
-		
-		if (!includePath.empty()) {
-			result = compileSource(shaderStage, shaderCode, compiled, includePath);
-		} else {
-			result = compileSource(shaderStage, shaderCode, compiled, shaderPath.parent_path());
-		}
-		
-		if (!result) {
-			vkcv_log(LogLevel::ERROR, "Shader compilation failed: (%s)", shaderPath.string().c_str());
-		}
-		
-		if (update) {
-			// TODO: Shader hot compilation during runtime
-		}
-	}
-	
 }
diff --git a/modules/shader_compiler/src/vkcv/shader/ShadyCompiler.cpp b/modules/shader_compiler/src/vkcv/shader/ShadyCompiler.cpp
index 6168fcfa..80b44a36 100644
--- a/modules/shader_compiler/src/vkcv/shader/ShadyCompiler.cpp
+++ b/modules/shader_compiler/src/vkcv/shader/ShadyCompiler.cpp
@@ -9,31 +9,4 @@ namespace vkcv::shader {
 	ShadyCompiler::ShadyCompiler()
 	: Compiler() {}
 	
-	void ShadyCompiler::compile(ShaderStage shaderStage,
-								const std::filesystem::path &shaderPath,
-								const ShaderCompiledFunction &compiled,
-								const std::filesystem::path &includePath,
-								bool update) {
-		std::string shaderCode;
-		bool result = readTextFromFile(shaderPath, shaderCode);
-		
-		if (!result) {
-			vkcv_log(LogLevel::ERROR, "Loading shader failed: (%s)", shaderPath.string().c_str());
-		}
-		
-		if (!includePath.empty()) {
-			result = compileSource(shaderStage, shaderCode, compiled, includePath);
-		} else {
-			result = compileSource(shaderStage, shaderCode, compiled, shaderPath.parent_path());
-		}
-		
-		if (!result) {
-			vkcv_log(LogLevel::ERROR, "Shader compilation failed: (%s)", shaderPath.string().c_str());
-		}
-		
-		if (update) {
-			// TODO: Shader hot compilation during runtime
-		}
-	}
-	
 }
diff --git a/modules/shader_compiler/src/vkcv/shader/SlangCompiler.cpp b/modules/shader_compiler/src/vkcv/shader/SlangCompiler.cpp
index ffc295b7..a54d17bd 100644
--- a/modules/shader_compiler/src/vkcv/shader/SlangCompiler.cpp
+++ b/modules/shader_compiler/src/vkcv/shader/SlangCompiler.cpp
@@ -1,67 +1,158 @@
 
 #include "vkcv/shader/SlangCompiler.hpp"
 
+#include <cstdint>
 #include <vkcv/File.hpp>
 #include <vkcv/Logger.hpp>
 
 #include <slang.h>
+#include <slang-com-ptr.h>
+#include <slang-com-helper.h>
+#include <vkcv/ShaderStage.hpp>
 
 namespace vkcv::shader {
 	
 	static uint32_t s_CompilerCount = 0;
-  static slang::IGlobalSession* s_GlobalSession = nullptr;
+  static Slang::ComPtr<slang::IGlobalSession> s_GlobalSession;
 	
-	SlangCompiler::SlangCompiler() : Compiler() {
+	SlangCompiler::SlangCompiler(SlangCompileProfile profile)
+	: Compiler(), m_profile(profile) {
 		if (s_CompilerCount == 0) {
-      slang::createGlobalSession(&s_GlobalSession);
+      slang::createGlobalSession(s_GlobalSession.writeRef());
 		}
 		
 		s_CompilerCount++;
 	}
 	
-	SlangCompiler::SlangCompiler(const SlangCompiler &other) : Compiler(other) {
+	SlangCompiler::SlangCompiler(const SlangCompiler &other)
+	: Compiler(other), m_profile(other.m_profile) {
 		s_CompilerCount++;
 	}
 	
 	SlangCompiler::~SlangCompiler() {
 		s_CompilerCount--;
-
-    if ((s_CompilerCount == 0) && (s_GlobalSession != nullptr)) {
-      spDestroySession(s_GlobalSession);
-      s_GlobalSession = nullptr;
-    }
 	}
 	
 	SlangCompiler &SlangCompiler::operator=(const SlangCompiler &other) {
+		m_profile = other.m_profile;
 		s_CompilerCount++;
 		return *this;
 	}
-	
-	void SlangCompiler::compile(ShaderStage shaderStage,
-								  const std::filesystem::path &shaderPath,
-								  const ShaderCompiledFunction &compiled,
-								  const std::filesystem::path &includePath,
-								  bool update) {
-		std::string shaderCode;
-		bool result = readTextFromFile(shaderPath, shaderCode);
-		
-		if (!result) {
-			vkcv_log(LogLevel::ERROR, "Loading shader failed: (%s)", shaderPath.string().c_str());
+
+	constexpr SlangStage findShaderLanguage(ShaderStage shaderStage) {
+		switch (shaderStage) {
+			case ShaderStage::VERTEX:
+				return SlangStage::SLANG_STAGE_VERTEX;
+			case ShaderStage::TESS_CONTROL:
+				return SlangStage::SLANG_STAGE_HULL;
+			case ShaderStage::TESS_EVAL:
+				return SlangStage::SLANG_STAGE_DOMAIN;
+			case ShaderStage::GEOMETRY:
+				return SlangStage::SLANG_STAGE_GEOMETRY;
+			case ShaderStage::FRAGMENT:
+				return SlangStage::SLANG_STAGE_FRAGMENT;
+			case ShaderStage::COMPUTE:
+				return SlangStage::SLANG_STAGE_COMPUTE;
+			case ShaderStage::TASK:
+				return SlangStage::SLANG_STAGE_AMPLIFICATION;
+			case ShaderStage::MESH:
+				return SlangStage::SLANG_STAGE_MESH;
+			case ShaderStage::RAY_GEN:
+			    return SlangStage::SLANG_STAGE_RAY_GENERATION;
+			case ShaderStage::RAY_CLOSEST_HIT:
+			    return SlangStage::SLANG_STAGE_CLOSEST_HIT;
+			case ShaderStage::RAY_MISS:
+			    return SlangStage::SLANG_STAGE_MISS;
+			case ShaderStage::RAY_INTERSECTION:
+				return SlangStage::SLANG_STAGE_INTERSECTION;
+			case ShaderStage::RAY_ANY_HIT:
+				return SlangStage::SLANG_STAGE_ANY_HIT;
+			case ShaderStage::RAY_CALLABLE:
+				return SlangStage::SLANG_STAGE_CALLABLE;
+			default:
+				return SlangStage::SLANG_STAGE_NONE;
 		}
-		
-		if (!includePath.empty()) {
-			result = compileSource(shaderStage, shaderCode, compiled, includePath);
-		} else {
-			result = compileSource(shaderStage, shaderCode, compiled, shaderPath.parent_path());
+	}
+
+	bool SlangCompiler::compileSource(ShaderStage shaderStage,
+																		const std::string& shaderSource,
+																		const ShaderCompiledFunction& compiled,
+																		const std::filesystem::path& includePath) {
+		slang::SessionDesc sessionDesc = {};
+    slang::TargetDesc targetDesc = {};
+
+		targetDesc.format = SLANG_SPIRV;
+
+		switch (m_profile) {
+			case SlangCompileProfile::GLSL:
+				targetDesc.profile = s_GlobalSession->findProfile("glsl_460");
+				sessionDesc.defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
+				break;
+			case SlangCompileProfile::HLSL:
+				targetDesc.profile = s_GlobalSession->findProfile("sm_5_0");
+				sessionDesc.defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+				break;
+			case SlangCompileProfile::SPIRV:
+				targetDesc.profile = s_GlobalSession->findProfile("spirv_1_5");
+				targetDesc.flags = SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY;
+				break;
+			default:
+				break;
 		}
-		
-		if (!result) {
-			vkcv_log(LogLevel::ERROR, "Shader compilation failed: (%s)", shaderPath.string().c_str());
+
+		sessionDesc.targets = &targetDesc;
+		sessionDesc.targetCount = 1;
+
+		const char *searchPath = includePath.c_str();
+		sessionDesc.searchPaths = &searchPath;
+		sessionDesc.searchPathCount = 1;
+
+		Slang::ComPtr<slang::ISession> session;
+		if (SLANG_FAILED(s_GlobalSession->createSession(sessionDesc, session.writeRef()))) {
+			vkcv_log(LogLevel::ERROR, "Compiler session could not be created");
+			return false;
+		}
+
+		Slang::ComPtr<slang::ICompileRequest> request;
+		if (SLANG_FAILED(session->createCompileRequest(request.writeRef()))) {
+			vkcv_log(LogLevel::ERROR, "Compilation request could not be created");
+			return false;
+		}
+
+		const int entryPoint = request->addEntryPoint(
+			0, "main", findShaderLanguage(shaderStage)
+		);
+
+		if (SLANG_FAILED(request->compile())) {
+			vkcv_log(LogLevel::ERROR, "Compilation process failed");
+			return false;
+		}
+
+		size_t size;
+		const void *code = request->getEntryPointCode(entryPoint, &size);
+
+		if ((size <= 0) || (!code)) {
+			vkcv_log(LogLevel::ERROR, "Entry point could not be found");
+			return false;
+		}
+
+		std::vector<uint32_t> spirv;
+		spirv.resize(size / sizeof(uint32_t));
+		memcpy(spirv.data(), code, spirv.size() * sizeof(uint32_t));
+
+		const std::filesystem::path tmp_path = generateTemporaryFilePath();
+
+		if (!writeBinaryToFile(tmp_path, spirv)) {
+			vkcv_log(LogLevel::ERROR, "Spir-V could not be written to disk");
+			return false;
 		}
 		
-		if (update) {
-			// TODO: Shader hot compilation during runtime
+		if (compiled) {
+			compiled(shaderStage, tmp_path);
 		}
+		
+		std::filesystem::remove(tmp_path);
+		return true;
 	}
 	
 }
diff --git a/modules/shader_compiler/src/vkcv/shader/SlimCompiler.cpp b/modules/shader_compiler/src/vkcv/shader/SlimCompiler.cpp
index 028d0618..6270bb8e 100644
--- a/modules/shader_compiler/src/vkcv/shader/SlimCompiler.cpp
+++ b/modules/shader_compiler/src/vkcv/shader/SlimCompiler.cpp
@@ -14,15 +14,15 @@ namespace vkcv::shader {
     : ShadyCompiler(), m_target(target) {}
 
     static bool shadyCompileModule(Module* module,
-								   ShaderStage shaderStage,
-								   const std::string& shaderSource,
-								   const ShaderCompiledFunction &compiled,
-								   const std::filesystem::path &includePath) {
+																	 ShaderStage shaderStage,
+																	 const std::string& shaderSource,
+																	 const ShaderCompiledFunction &compiled,
+																	 const std::filesystem::path &includePath) {
 		ShadyErrorCodes codes = driver_load_source_file(
-            SrcSlim,
-            shaderSource.length(),
+			SrcSlim,
+			shaderSource.length(),
 			shaderSource.c_str(),
-            module
+			module
 		);
 
 		switch (codes) {
@@ -46,7 +46,7 @@ namespace vkcv::shader {
 
 		DriverConfig config = default_driver_config();
 
-        config.target = TgtSPV;
+		config.target = TgtSPV;
 		config.output_filename = tmp_path.string().c_str();
 
 		codes = driver_compile(&config, module);
@@ -78,10 +78,10 @@ namespace vkcv::shader {
 	}
 
     static bool shadyCompileArena(IrArena* arena,
-								  ShaderStage shaderStage,
-								  const std::string& shaderSource,
-								  const ShaderCompiledFunction &compiled,
-								  const std::filesystem::path &includePath) {
+																	ShaderStage shaderStage,
+																	const std::string& shaderSource,
+																	const ShaderCompiledFunction &compiled,
+																	const std::filesystem::path &includePath) {
 		Module* module = new_module(arena, "slim_module");
 
 		if (nullptr == module) {
@@ -92,11 +92,11 @@ namespace vkcv::shader {
 		return shadyCompileModule(module, shaderStage, shaderSource, compiled, includePath);
 	}
 
-    bool SlimCompiler::compileSource(ShaderStage shaderStage,
-                                     const std::string& shaderSource,
-						             const ShaderCompiledFunction& compiled,
-						             const std::filesystem::path& includePath) {
-        if (ShaderStage::COMPUTE != shaderStage) {
+	bool SlimCompiler::compileSource(ShaderStage shaderStage,
+																	 const std::string& shaderSource,
+																	 const ShaderCompiledFunction& compiled,
+																	 const std::filesystem::path& includePath) {
+		if (ShaderStage::COMPUTE != shaderStage) {
 			vkcv_log(LogLevel::ERROR, "Shader stage not supported");
 			return false;
 		}
@@ -112,7 +112,7 @@ namespace vkcv::shader {
 		bool result = shadyCompileArena(arena, shaderStage, shaderSource, compiled, includePath);
 
 		destroy_ir_arena(arena);
-        return result;
-    }
+		return result;
+	}
 
 }
-- 
GitLab