Skip to content
Snippets Groups Projects
SlangCompiler.cpp 4.89 KiB

#include "vkcv/shader/SlangCompiler.hpp"

#include <cstdint>
#include <cstring>
#include <vkcv/File.hpp>
#include <vkcv/Logger.hpp>

#include <slang.h>
#include <slang-com-ptr.h>
#include <slang-com-helper.h>
#include <vkcv/ShaderStage.hpp>

namespace vkcv::shader {
	
	// Bookkeeping counter shared by all SlangCompiler objects in this file:
	// the special member functions below adjust it, and the constructor
	// creates the global session while it is zero.
	static uint32_t s_CompilerCount = 0;
  // Slang global session shared by all SlangCompiler objects. It is filled in
  // by the SlangCompiler constructor and used by compileSource() to look up
  // profiles and to create compilation sessions.
  static Slang::ComPtr<slang::IGlobalSession> s_GlobalSession;
	
	/**
	 * @brief Creates a Slang compiler for the given compile profile.
	 *
	 * The Slang global session is shared between all compiler instances. It is
	 * created when no instance is registered yet, or when an earlier attempt
	 * to create it failed and the shared pointer is still null.
	 *
	 * @param[in] profile Compile profile that compileSource() uses to pick the
	 * target profile and matrix layout
	 */
	SlangCompiler::SlangCompiler(SlangCompileProfile profile)
	: Compiler(), m_profile(profile) {
		if ((s_CompilerCount == 0) || (!s_GlobalSession.get())) {
			// The result must not be ignored: compileSource() relies on the
			// global session, so a failure has to be visible in the log.
			if (SLANG_FAILED(slang::createGlobalSession(s_GlobalSession.writeRef()))) {
				vkcv_log(LogLevel::ERROR, "Global compiler session could not be created");
			}
		}
		
		s_CompilerCount++;
	}
	
	/**
	 * @brief Copy-constructs a compiler that uses the same profile as another one.
	 *
	 * The base class is copy-constructed from the other compiler and the new
	 * object is added to the shared instance counter.
	 *
	 * @param[in] other Compiler to copy from
	 */
	SlangCompiler::SlangCompiler(const SlangCompiler &other)
	: Compiler(other),
	  m_profile(other.m_profile) {
		s_CompilerCount += 1;
	}
	
	/**
	 * @brief Unregisters the compiler from the shared instance counter and
	 * releases the shared global session once the counter reaches zero.
	 */
	SlangCompiler::~SlangCompiler() {
		// Guard against wrap-around of the unsigned counter.
		if (s_CompilerCount > 0) {
			s_CompilerCount--;
		}
		
		if (s_CompilerCount == 0) {
			// The constructor creates a fresh session whenever the counter is
			// zero, so keeping the old one alive would only hold its resources.
			s_GlobalSession.setNull();
		}
	}
	
	/**
	 * @brief Copy-assigns the compile profile of another compiler.
	 *
	 * Assignment overwrites an already existing object and creates no new
	 * instance, so the shared instance counter is intentionally left untouched.
	 * Incrementing it here would never be matched by a destructor call.
	 *
	 * @param[in] other Compiler to copy the profile from
	 * @return Reference to this compiler
	 */
	SlangCompiler &SlangCompiler::operator=(const SlangCompiler &other) {
		// NOTE(review): the copy constructor also copies the Compiler base
		// class, while this operator only copies the profile - confirm whether
		// the base state should be assigned here as well.
		m_profile = other.m_profile;
		return *this;
	}

	/**
	 * @brief Translates a vkcv shader stage into the matching Slang stage.
	 *
	 * @param[in] shaderStage Shader stage to translate
	 * @return Matching Slang stage, or SLANG_STAGE_NONE for any stage that is
	 * not listed below
	 */
	constexpr SlangStage findShaderLanguage(ShaderStage shaderStage) {
		// Fallback for every stage without an explicit mapping.
		SlangStage slangStage = SlangStage::SLANG_STAGE_NONE;

		switch (shaderStage) {
			// Rasterization pipeline stages
			case ShaderStage::VERTEX:
				slangStage = SlangStage::SLANG_STAGE_VERTEX;
				break;
			case ShaderStage::TESS_CONTROL:
				slangStage = SlangStage::SLANG_STAGE_HULL;
				break;
			case ShaderStage::TESS_EVAL:
				slangStage = SlangStage::SLANG_STAGE_DOMAIN;
				break;
			case ShaderStage::GEOMETRY:
				slangStage = SlangStage::SLANG_STAGE_GEOMETRY;
				break;
			case ShaderStage::FRAGMENT:
				slangStage = SlangStage::SLANG_STAGE_FRAGMENT;
				break;
			// Compute stage
			case ShaderStage::COMPUTE:
				slangStage = SlangStage::SLANG_STAGE_COMPUTE;
				break;
			// Mesh shading stages
			case ShaderStage::TASK:
				slangStage = SlangStage::SLANG_STAGE_AMPLIFICATION;
				break;
			case ShaderStage::MESH:
				slangStage = SlangStage::SLANG_STAGE_MESH;
				break;
			// Ray tracing stages
			case ShaderStage::RAY_GEN:
				slangStage = SlangStage::SLANG_STAGE_RAY_GENERATION;
				break;
			case ShaderStage::RAY_CLOSEST_HIT:
				slangStage = SlangStage::SLANG_STAGE_CLOSEST_HIT;
				break;
			case ShaderStage::RAY_MISS:
				slangStage = SlangStage::SLANG_STAGE_MISS;
				break;
			case ShaderStage::RAY_INTERSECTION:
				slangStage = SlangStage::SLANG_STAGE_INTERSECTION;
				break;
			case ShaderStage::RAY_ANY_HIT:
				slangStage = SlangStage::SLANG_STAGE_ANY_HIT;
				break;
			case ShaderStage::RAY_CALLABLE:
				slangStage = SlangStage::SLANG_STAGE_CALLABLE;
				break;
			default:
				break;
		}

		return slangStage;
	}

	/**
	 * @brief Compiles shader source code to SPIR-V and passes the result to a
	 * callback.
	 *
	 * The source is compiled with the entry point "main" for the given stage.
	 * The generated SPIR-V is written to a temporary file whose path is handed
	 * to the callback; the file is removed again after the callback returned.
	 *
	 * @param[in] shaderStage Shader stage of the entry point
	 * @param[in] shaderSource Source code of the shader
	 * @param[in] compiled Callback receiving the stage and the path of the
	 * temporary SPIR-V file (may be empty, in which case it is not called)
	 * @param[in] includePath Directory used as search path by the compiler
	 * @return True if the source was compiled and the SPIR-V was written to
	 * disk, false if any step failed (an error is logged in that case)
	 */
	bool SlangCompiler::compileSource(ShaderStage shaderStage,
																		const std::string& shaderSource,
																		const ShaderCompiledFunction& compiled,
																		const std::filesystem::path& includePath) {
		// The global session is created by the constructor; bail out instead
		// of dereferencing a null pointer if that creation failed.
		if (!s_GlobalSession.get()) {
			vkcv_log(LogLevel::ERROR, "Global compiler session is not available");
			return false;
		}

		slang::SessionDesc sessionDesc = {};
		slang::TargetDesc targetDesc = {};

		// NOTE(review): the source language is derived from the compile profile
		// (GLSL/HLSL profiles read GLSL/HLSL input, everything else is treated
		// as Slang source) - confirm this matches the intended usage.
		SlangSourceLanguage sourceLanguage = SLANG_SOURCE_LANGUAGE_SLANG;

		targetDesc.format = SLANG_SPIRV;

		switch (m_profile) {
			case SlangCompileProfile::GLSL:
				targetDesc.profile = s_GlobalSession->findProfile("glsl_460");
				sessionDesc.defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
				sourceLanguage = SLANG_SOURCE_LANGUAGE_GLSL;
				break;
			case SlangCompileProfile::HLSL:
				targetDesc.profile = s_GlobalSession->findProfile("sm_5_0");
				sessionDesc.defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
				sourceLanguage = SLANG_SOURCE_LANGUAGE_HLSL;
				break;
			case SlangCompileProfile::SPIRV:
				targetDesc.profile = s_GlobalSession->findProfile("spirv_1_5");
				targetDesc.flags = SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY;
				break;
			default:
				break;
		}

		sessionDesc.targets = &targetDesc;
		sessionDesc.targetCount = 1;

		// path::c_str() returns a wide string on platforms with a wchar_t
		// path type, so convert explicitly. The string has to stay alive
		// until the session has been created.
		const std::string searchPathString = includePath.string();
		const char *searchPath = searchPathString.c_str();
		sessionDesc.searchPaths = &searchPath;
		sessionDesc.searchPathCount = 1;

		// The macro descriptions only borrow the strings stored in m_defines.
		std::vector<slang::PreprocessorMacroDesc> macros;
		macros.reserve(m_defines.size());

		for (const auto& define : m_defines) {
			const slang::PreprocessorMacroDesc macro = {
				define.first.c_str(),
				define.second.c_str()
			};

			macros.push_back(macro);
		}

		sessionDesc.preprocessorMacros = macros.data();
		sessionDesc.preprocessorMacroCount = macros.size();

		Slang::ComPtr<slang::ISession> session;
		if (SLANG_FAILED(s_GlobalSession->createSession(sessionDesc, session.writeRef()))) {
			vkcv_log(LogLevel::ERROR, "Compiler session could not be created");
			return false;
		}

		Slang::ComPtr<slang::ICompileRequest> request;
		if (SLANG_FAILED(session->createCompileRequest(request.writeRef()))) {
			vkcv_log(LogLevel::ERROR, "Compilation request could not be created");
			return false;
		}

		// Hand the actual source code to the request. Without a translation
		// unit the entry point below would refer to nothing.
		const int translationUnit = request->addTranslationUnit(sourceLanguage, nullptr);
		request->addTranslationUnitSourceString(translationUnit, "", shaderSource.c_str());

		const int entryPoint = request->addEntryPoint(
			translationUnit, "main", findShaderLanguage(shaderStage)
		);

		if (SLANG_FAILED(request->compile())) {
			vkcv_log(LogLevel::ERROR, "Compilation process failed");

			// Forward the compiler diagnostics so the failure can be traced
			// back to the shader source.
			const char *diagnostics = request->getDiagnosticOutput();
			if ((diagnostics) && (diagnostics[0] != '\0')) {
				vkcv_log(LogLevel::ERROR, "%s", diagnostics);
			}

			return false;
		}

		size_t size = 0;
		const void *code = request->getEntryPointCode(entryPoint, &size);

		if ((size == 0) || (!code)) {
			vkcv_log(LogLevel::ERROR, "Entry point could not be found");
			return false;
		}

		// Copy the generated code into 32 bit words; memcpy avoids any
		// alignment assumptions about the buffer returned by the compiler.
		std::vector<uint32_t> spirv (size / sizeof(uint32_t));
		std::memcpy(spirv.data(), code, spirv.size() * sizeof(uint32_t));

		const std::filesystem::path tmp_path = generateTemporaryFilePath();

		if (!writeBinaryToFile(tmp_path, spirv)) {
			vkcv_log(LogLevel::ERROR, "Spir-V could not be written to disk");
			return false;
		}
		
		if (compiled) {
			compiled(shaderStage, tmp_path);
		}
		
		// The temporary file is only valid for the duration of the callback.
		std::filesystem::remove(tmp_path);
		return true;
	}
	
}