From 8d08758b7a5424eff93b1a377fbb92fb48093681 Mon Sep 17 00:00:00 2001
From: Tobias Frisch <tfrisch@uni-koblenz.de>
Date: Mon, 30 Jan 2023 01:23:38 +0100
Subject: [PATCH] Fix breaking issues in task shader

Signed-off-by: Tobias Frisch <tfrisch@uni-koblenz.de>
---
 .../mesh_shader/assets/shaders/shader.mesh    | 18 +++----
 .../mesh_shader/assets/shaders/shader.task    | 25 ++++-----
 .../mesh_shader/assets/shaders/shader.vert    |  8 ++-
 projects/mesh_shader/src/main.cpp             | 54 +++++++++----------
 4 files changed, 46 insertions(+), 59 deletions(-)

diff --git a/projects/mesh_shader/assets/shaders/shader.mesh b/projects/mesh_shader/assets/shaders/shader.mesh
index 0d97389b..8934bdf5 100644
--- a/projects/mesh_shader/assets/shaders/shader.mesh
+++ b/projects/mesh_shader/assets/shaders/shader.mesh
@@ -20,23 +20,25 @@ struct Vertex {
     float padding1;
 };
 
-layout(std430, binding = 0) readonly buffer vertexBuffer {
+layout(std430, set=0, binding = 0) readonly buffer vertexBuffer {
     Vertex vertices[];
 };
 
-layout(std430, binding = 1) readonly buffer indexBuffer {
+layout(std430, set=0, binding = 1) readonly buffer indexBuffer {
     uint localIndices[]; // breaks for 16 bit indices
 };
 
-layout(std430, binding = 2) readonly buffer meshletBuffer {
+layout(std430, set=0, binding = 2) readonly buffer meshletBuffer {
     Meshlet meshlets[];
 };
 
 taskPayloadSharedEXT Task IN;
 
 void main()	{
-    uint meshletIndex = IN.meshletIndices[gl_WorkGroupID.x];
+    const uint meshletIndex = IN.meshletIndices[gl_WorkGroupID.x];
     Meshlet meshlet = meshlets[meshletIndex];
+
+    SetMeshOutputsEXT(meshlet.vertexCount, meshlet.indexCount / 3);
     
     // set vertices
     for (uint i = 0; i < 2; i++) {
@@ -55,16 +57,12 @@ void main()	{
     
     // set local indices
     for (uint i = 0; i < 12; i++) {
-        uint workIndex = gl_LocalInvocationID.x + i * 32;
+        const uint workIndex = gl_LocalInvocationID.x + i * 32;
         if (workIndex >= meshlet.indexCount) {
             break;
         }
         
-        uint indexBufferIndex = meshlet.indexOffset + workIndex;
+        const uint indexBufferIndex = meshlet.indexOffset + workIndex;
         gl_PrimitiveTriangleIndicesEXT[workIndex] = uvec3(localIndices[indexBufferIndex]);
     }
-    
-    if (gl_LocalInvocationID.x == 0) {
-        SetMeshOutputsEXT(64, meshlet.indexCount / 3);
-    }
 }
\ No newline at end of file
diff --git a/projects/mesh_shader/assets/shaders/shader.task b/projects/mesh_shader/assets/shaders/shader.task
index be09706c..516e6cd5 100644
--- a/projects/mesh_shader/assets/shaders/shader.task
+++ b/projects/mesh_shader/assets/shaders/shader.task
@@ -11,12 +11,11 @@ layout(local_size_x=32, local_size_y=1, local_size_z=1) in;
 taskPayloadSharedEXT Task OUT;
 
 layout( push_constant ) uniform constants {
-    uint meshletCount;
     uint matrixIndex;
 };
 
 // TODO: reuse mesh stage binding at location 2 after required fix in framework
-layout(std430, binding = 5) readonly buffer meshletBuffer {
+layout(std430, set=0, binding = 5) readonly buffer meshletBuffer {
     Meshlet meshlets[];
 };
 
@@ -31,7 +30,7 @@ layout(set=0, binding = 3, std140) uniform cameraPlaneBuffer {
     Plane cameraPlanes[6];
 };
 
-layout(std430, binding = 4) readonly buffer matrixBuffer {
+layout(std430, set=0, binding = 4) readonly buffer matrixBuffer {
     ObjectMatrices objectMatrices[];
 };
 
@@ -47,13 +46,13 @@ bool isSphereInsideFrustum(vec3 spherePos, float sphereRadius, Plane cameraPlane
 }
 
 void main() {
-    if (gl_LocalInvocationID.x >= meshletCount) {
-        return;
+    const uint meshletIndex = gl_GlobalInvocationID.x;
+    Meshlet meshlet;
+
+    if (meshletIndex < meshlets.length()) {
+        meshlet = meshlets[meshletIndex];
     }
     
-    uint meshletIndex   = gl_GlobalInvocationID.x;
-    Meshlet meshlet     = meshlets[meshletIndex]; 
-    
     if (gl_LocalInvocationID.x == 0) {
         taskCount = 0;
     }
@@ -63,14 +62,18 @@ void main() {
     // TODO: scaling support
     vec3 meshletPositionWorld = (vec4(meshlet.meanPosition, 1) * objectMatrices[matrixIndex].model).xyz;
-    if (isSphereInsideFrustum(meshletPositionWorld, meshlet.boundingSphereRadius, cameraPlanes)) {
-        uint outIndex = atomicAdd(taskCount, 1);
-        OUT.meshletIndices[outIndex] = gl_GlobalInvocationID.x;
+    // Invocations past the end of the meshlet buffer hold an uninitialized
+    // meshlet; they must never append an index to the payload.
+    if ((meshletIndex < meshlets.length()) &&
+        isSphereInsideFrustum(meshletPositionWorld, meshlet.boundingSphereRadius, cameraPlanes)) {
+        const uint outIndex = atomicAdd(taskCount, 1);
+        OUT.meshletIndices[outIndex] = meshletIndex;
     }
 
     memoryBarrierShared();
+    // memoryBarrierShared() only orders memory; barrier() makes every
+    // invocation wait so taskCount is final before it is read below.
+    barrier();
 
-    if (gl_LocalInvocationID.x == 0) {
-        OUT.mvp = objectMatrices[matrixIndex].mvp;
-        EmitMeshTasksEXT(taskCount, 1, 1);
-    }
+    OUT.mvp = objectMatrices[matrixIndex].mvp;
+    EmitMeshTasksEXT(taskCount, 1, 1);
 }
\ No newline at end of file
diff --git a/projects/mesh_shader/assets/shaders/shader.vert b/projects/mesh_shader/assets/shaders/shader.vert
index 0da4f058..0f1c42d6 100644
--- a/projects/mesh_shader/assets/shaders/shader.vert
+++ b/projects/mesh_shader/assets/shaders/shader.vert
@@ -10,18 +10,16 @@ layout(location = 1) in vec3 inNormal;
 layout(location = 0) out vec3 passNormal;
 layout(location = 1) out flat uint passTaskIndex;
 
-layout(std430, binding = 0) readonly buffer matrixBuffer {
+layout(std430, set=0, binding = 0) readonly buffer matrixBuffer {
     ObjectMatrices objectMatrices[];
 };
 
 layout( push_constant ) uniform constants {
-    uint padding; // pad to same size as mesh shader constants
     uint matrixIndex;
 };
 
 void main()	{
-	passNormal = inNormal;
-    passTaskIndex = 0;
-
     gl_Position = objectMatrices[matrixIndex].mvp * vec4(inPosition, 1.0);
+    passNormal = inNormal;
+    passTaskIndex = 0;
 }
\ No newline at end of file
diff --git a/projects/mesh_shader/src/main.cpp b/projects/mesh_shader/src/main.cpp
index 23d8b028..72b15e31 100644
--- a/projects/mesh_shader/src/main.cpp
+++ b/projects/mesh_shader/src/main.cpp
@@ -94,8 +94,7 @@ int main(int argc, const char** argv) {
 		applicationName,
 		VK_MAKE_VERSION(0, 0, 1),
 		{ vk::QueueFlagBits::eTransfer,vk::QueueFlagBits::eGraphics, vk::QueueFlagBits::eCompute },
-		features,
-		{ VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME }
+		features
 	);
 	
 	vkcv::WindowHandle windowHandle = core.createWindow(applicationName, 1280, 720, true);
@@ -165,7 +164,11 @@ int main(int argc, const char** argv) {
     vkcv::meshlet::VertexCacheReorderResult tipsifyResult = vkcv::meshlet::tipsifyMesh(indexBuffer32Bit, interleavedVertices.size());
     vkcv::meshlet::VertexCacheReorderResult forsythResult = vkcv::meshlet::forsythReorder(indexBuffer32Bit, interleavedVertices.size());
 
-    const auto meshShaderModelData = createMeshShaderModelData(interleavedVertices, forsythResult.indexBuffer, forsythResult.skippedIndices);
+    const auto meshShaderModelData = createMeshShaderModelData(
+			interleavedVertices,
+			forsythResult.indexBuffer,
+			forsythResult.skippedIndices
+	);
 
 	auto meshShaderVertexBuffer = vkcv::buffer<vkcv::meshlet::Vertex>(
 		core,
@@ -182,9 +185,7 @@ int main(int argc, const char** argv) {
 	auto meshletBuffer = vkcv::buffer<vkcv::meshlet::Meshlet>(
 		core,
 		vkcv::BufferType::STORAGE,
-		meshShaderModelData.meshlets.size(),
-		vkcv::BufferMemoryType::DEVICE_LOCAL
-		);
+		meshShaderModelData.meshlets.size());
 	meshletBuffer.fill(meshShaderModelData.meshlets);
 	
 	vkcv::PassHandle renderPass = vkcv::passSwapchain(
@@ -227,7 +228,9 @@ int main(int argc, const char** argv) {
 		glm::mat4 mvp;
 	};
 	const size_t objectCount = 1;
-	vkcv::Buffer<ObjectMatrices> matrixBuffer = vkcv::buffer<ObjectMatrices>(core, vkcv::BufferType::STORAGE, objectCount);
+	vkcv::Buffer<ObjectMatrices> matrixBuffer = vkcv::buffer<ObjectMatrices>(
+			core, vkcv::BufferType::STORAGE, objectCount
+	);
 
 	vkcv::DescriptorWrites vertexShaderDescriptorWrites;
 	vertexShaderDescriptorWrites.writeStorageBuffer(0, matrixBuffer.getHandle());
@@ -286,7 +289,9 @@ int main(int argc, const char** argv) {
 		return EXIT_FAILURE;
 	}
 
-	vkcv::Buffer<CameraPlanes> cameraPlaneBuffer = vkcv::buffer<CameraPlanes>(core, vkcv::BufferType::UNIFORM, 1);
+	vkcv::Buffer<CameraPlanes> cameraPlaneBuffer = vkcv::buffer<CameraPlanes>(
+			core, vkcv::BufferType::UNIFORM, 1
+	);
 
 	vkcv::DescriptorWrites meshShaderWrites;
 	meshShaderWrites.writeStorageBuffer(
@@ -317,10 +322,10 @@ int main(int argc, const char** argv) {
 	vkcv::camera::CameraManager cameraManager(window);
 	auto camHandle = cameraManager.addCamera(vkcv::camera::ControllerType::PILOT);
 	
-	cameraManager.getCamera(camHandle).setPosition(glm::vec3(0, 0, -2));
+	cameraManager.getCamera(camHandle).setPosition(glm::vec3(0, 2.5f, -2));
 
-	bool useMeshShader          = true;
-	bool updateFrustumPlanes    = true;
+	bool useMeshShader       = true;
+	bool updateFrustumPlanes = true;
 	
 	core.run([&](const vkcv::WindowHandle &windowHandle, double t, double dt,
 				 uint32_t swapchainWidth, uint32_t swapchainHeight) {
@@ -341,21 +346,11 @@ int main(int argc, const char** argv) {
 		const vkcv::camera::Camera& camera = cameraManager.getActiveCamera();
 
 		ObjectMatrices objectMatrices;
-		objectMatrices.model    = *reinterpret_cast<glm::mat4*>(&mesh.meshes.front().modelMatrix);
-		objectMatrices.mvp      = camera.getMVP() * objectMatrices.model;
+		objectMatrices.model = *reinterpret_cast<glm::mat4*>(&mesh.meshes.front().modelMatrix);
+		objectMatrices.mvp   = camera.getMVP() * objectMatrices.model;
 
 		matrixBuffer.fill({ objectMatrices });
 
-		struct MeshletPushConstants {
-			uint32_t meshletCount;
-			uint32_t matrixIndex;
-		};
-		
-		MeshletPushConstants pushConstants {
-			static_cast<uint32_t>(meshShaderModelData.meshlets.size()),
-			0
-		};
-
 		if (updateFrustumPlanes) {
 			const CameraPlanes cameraPlanes = computeCameraPlanes(camera);
 			cameraPlaneBuffer.fill({ cameraPlanes });
@@ -364,15 +359,14 @@ int main(int argc, const char** argv) {
 		const std::vector<vkcv::ImageHandle> renderTargets = { swapchainInput, depthBuffer };
 		auto cmdStream = core.createCommandStream(vkcv::QueueType::Graphics);
 
-		vkcv::PushConstants pushConstantData = vkcv::pushConstants<MeshletPushConstants>();
-		pushConstantData.appendDrawcall(pushConstants);
+		vkcv::PushConstants pushConstantData = vkcv::pushConstants<uint32_t>(0);
 
 		if (useMeshShader) {
-			const uint32_t taskCount = (meshShaderModelData.meshlets.size() + 31) / 32;
+			vkcv::TaskDrawcall drawcall (vkcv::dispatchInvocations(
+					meshShaderModelData.meshlets.size(), 32
+			));
 			
-			vkcv::TaskDrawcall drawcall (taskCount);
 			drawcall.useDescriptorSet(0, meshShaderDescriptorSet);
-
 			core.recordMeshShaderDrawcalls(
 				cmdStream,
 				meshShaderPipeline,
@@ -385,14 +379,14 @@ int main(int argc, const char** argv) {
 			vkcv::InstanceDrawcall drawcall (vertexData);
 			drawcall.useDescriptorSet(0, vertexShaderDescriptorSet);
 
-			/*core.recordDrawcallsToCmdStream(
+			core.recordDrawcallsToCmdStream(
 				cmdStream,
 				bunnyPipeline,
 				pushConstantData,
 				{ drawcall },
 				{ renderTargets },
 				windowHandle
-			);*/
+			);
 		}
 
 		core.prepareSwapchainImageForPresent(cmdStream);
-- 
GitLab