From 8d08758b7a5424eff93b1a377fbb92fb48093681 Mon Sep 17 00:00:00 2001 From: Tobias Frisch <tfrisch@uni-koblenz.de> Date: Mon, 30 Jan 2023 01:23:38 +0100 Subject: [PATCH] Fix breaking issues in task shader Signed-off-by: Tobias Frisch <tfrisch@uni-koblenz.de> --- .../mesh_shader/assets/shaders/shader.mesh | 18 +++---- .../mesh_shader/assets/shaders/shader.task | 25 ++++----- .../mesh_shader/assets/shaders/shader.vert | 8 ++- projects/mesh_shader/src/main.cpp | 54 +++++++++---------- 4 files changed, 46 insertions(+), 59 deletions(-) diff --git a/projects/mesh_shader/assets/shaders/shader.mesh b/projects/mesh_shader/assets/shaders/shader.mesh index 0d97389b..8934bdf5 100644 --- a/projects/mesh_shader/assets/shaders/shader.mesh +++ b/projects/mesh_shader/assets/shaders/shader.mesh @@ -20,23 +20,25 @@ struct Vertex { float padding1; }; -layout(std430, binding = 0) readonly buffer vertexBuffer { +layout(std430, set=0, binding = 0) readonly buffer vertexBuffer { Vertex vertices[]; }; -layout(std430, binding = 1) readonly buffer indexBuffer { +layout(std430, set=0, binding = 1) readonly buffer indexBuffer { uint localIndices[]; // breaks for 16 bit indices }; -layout(std430, binding = 2) readonly buffer meshletBuffer { +layout(std430, set=0, binding = 2) readonly buffer meshletBuffer { Meshlet meshlets[]; }; taskPayloadSharedEXT Task IN; void main() { - uint meshletIndex = IN.meshletIndices[gl_WorkGroupID.x]; + const uint meshletIndex = IN.meshletIndices[gl_WorkGroupID.x]; Meshlet meshlet = meshlets[meshletIndex]; + + SetMeshOutputsEXT(meshlet.vertexCount, meshlet.indexCount / 3); // set vertices for (uint i = 0; i < 2; i++) { @@ -55,16 +57,12 @@ void main() { // set local indices for (uint i = 0; i < 12; i++) { - uint workIndex = gl_LocalInvocationID.x + i * 32; + const uint workIndex = gl_LocalInvocationID.x + i * 32; if (workIndex >= meshlet.indexCount) { break; } - uint indexBufferIndex = meshlet.indexOffset + workIndex; + const uint indexBufferIndex = meshlet.indexOffset + workIndex; gl_PrimitiveTriangleIndicesEXT[workIndex] = uvec3(localIndices[indexBufferIndex]); } - - if (gl_LocalInvocationID.x == 0) { - SetMeshOutputsEXT(64, meshlet.indexCount / 3); - } } \ No newline at end of file diff --git a/projects/mesh_shader/assets/shaders/shader.task b/projects/mesh_shader/assets/shaders/shader.task index be09706c..516e6cd5 100644 --- a/projects/mesh_shader/assets/shaders/shader.task +++ b/projects/mesh_shader/assets/shaders/shader.task @@ -11,12 +11,11 @@ layout(local_size_x=32, local_size_y=1, local_size_z=1) in; taskPayloadSharedEXT Task OUT; layout( push_constant ) uniform constants { - uint meshletCount; uint matrixIndex; }; // TODO: reuse mesh stage binding at location 2 after required fix in framework -layout(std430, binding = 5) readonly buffer meshletBuffer { +layout(std430, set=0, binding = 5) readonly buffer meshletBuffer { Meshlet meshlets[]; }; @@ -31,7 +30,7 @@ layout(set=0, binding = 3, std140) uniform cameraPlaneBuffer { Plane cameraPlanes[6]; }; -layout(std430, binding = 4) readonly buffer matrixBuffer { +layout(std430, set=0, binding = 4) readonly buffer matrixBuffer { ObjectMatrices objectMatrices[]; }; @@ -47,13 +46,13 @@ bool isSphereInsideFrustum(vec3 spherePos, float sphereRadius, Plane cameraPlane } void main() { - if (gl_LocalInvocationID.x >= meshletCount) { - return; + const uint meshletIndex = gl_GlobalInvocationID.x; + Meshlet meshlet; + + if (meshletIndex < meshlets.length()) { + meshlet = meshlets[meshletIndex]; } - uint meshletIndex = gl_GlobalInvocationID.x; - Meshlet meshlet = meshlets[meshletIndex]; - if (gl_LocalInvocationID.x == 0) { taskCount = 0; } @@ -63,14 +62,12 @@ void main() { // TODO: scaling support vec3 meshletPositionWorld = (vec4(meshlet.meanPosition, 1) * objectMatrices[matrixIndex].model).xyz; if (isSphereInsideFrustum(meshletPositionWorld, meshlet.boundingSphereRadius, cameraPlanes)) { - uint outIndex = atomicAdd(taskCount, 1); - OUT.meshletIndices[outIndex] = gl_GlobalInvocationID.x; + const uint outIndex = atomicAdd(taskCount, 1); + OUT.meshletIndices[outIndex] = meshletIndex; } memoryBarrierShared(); - if (gl_LocalInvocationID.x == 0) { - OUT.mvp = objectMatrices[matrixIndex].mvp; - EmitMeshTasksEXT(taskCount, 1, 1); - } + OUT.mvp = objectMatrices[matrixIndex].mvp; + EmitMeshTasksEXT(taskCount, 1, 1); } \ No newline at end of file diff --git a/projects/mesh_shader/assets/shaders/shader.vert b/projects/mesh_shader/assets/shaders/shader.vert index 0da4f058..0f1c42d6 100644 --- a/projects/mesh_shader/assets/shaders/shader.vert +++ b/projects/mesh_shader/assets/shaders/shader.vert @@ -10,18 +10,16 @@ layout(location = 1) in vec3 inNormal; layout(location = 0) out vec3 passNormal; layout(location = 1) out flat uint passTaskIndex; -layout(std430, binding = 0) readonly buffer matrixBuffer { +layout(std430, set=0, binding = 0) readonly buffer matrixBuffer { ObjectMatrices objectMatrices[]; }; layout( push_constant ) uniform constants { - uint padding; // pad to same size as mesh shader constants uint matrixIndex; }; void main() { - passNormal = inNormal; - passTaskIndex = 0; - gl_Position = objectMatrices[matrixIndex].mvp * vec4(inPosition, 1.0); + passNormal = inNormal; + passTaskIndex = 0; } \ No newline at end of file diff --git a/projects/mesh_shader/src/main.cpp b/projects/mesh_shader/src/main.cpp index 23d8b028..72b15e31 100644 --- a/projects/mesh_shader/src/main.cpp +++ b/projects/mesh_shader/src/main.cpp @@ -94,8 +94,7 @@ int main(int argc, const char** argv) { applicationName, VK_MAKE_VERSION(0, 0, 1), { vk::QueueFlagBits::eTransfer,vk::QueueFlagBits::eGraphics, vk::QueueFlagBits::eCompute }, - features, - { VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME } + features ); vkcv::WindowHandle windowHandle = core.createWindow(applicationName, 1280, 720, true); @@ -165,7 +164,11 @@ int main(int argc, const char** argv) { vkcv::meshlet::VertexCacheReorderResult tipsifyResult = vkcv::meshlet::tipsifyMesh(indexBuffer32Bit, interleavedVertices.size()); vkcv::meshlet::VertexCacheReorderResult forsythResult = vkcv::meshlet::forsythReorder(indexBuffer32Bit, interleavedVertices.size()); - const auto meshShaderModelData = createMeshShaderModelData(interleavedVertices, forsythResult.indexBuffer, forsythResult.skippedIndices); + const auto meshShaderModelData = createMeshShaderModelData( + interleavedVertices, + forsythResult.indexBuffer, + forsythResult.skippedIndices + ); auto meshShaderVertexBuffer = vkcv::buffer<vkcv::meshlet::Vertex>( core, @@ -182,9 +185,7 @@ int main(int argc, const char** argv) { auto meshletBuffer = vkcv::buffer<vkcv::meshlet::Meshlet>( core, vkcv::BufferType::STORAGE, - meshShaderModelData.meshlets.size(), - vkcv::BufferMemoryType::DEVICE_LOCAL - ); + meshShaderModelData.meshlets.size()); meshletBuffer.fill(meshShaderModelData.meshlets); vkcv::PassHandle renderPass = vkcv::passSwapchain( @@ -227,7 +228,9 @@ int main(int argc, const char** argv) { glm::mat4 mvp; }; const size_t objectCount = 1; - vkcv::Buffer<ObjectMatrices> matrixBuffer = vkcv::buffer<ObjectMatrices>(core, vkcv::BufferType::STORAGE, objectCount); + vkcv::Buffer<ObjectMatrices> matrixBuffer = vkcv::buffer<ObjectMatrices>( + core, vkcv::BufferType::STORAGE, objectCount + ); vkcv::DescriptorWrites vertexShaderDescriptorWrites; vertexShaderDescriptorWrites.writeStorageBuffer(0, matrixBuffer.getHandle()); @@ -286,7 +289,9 @@ int main(int argc, const char** argv) { return EXIT_FAILURE; } - vkcv::Buffer<CameraPlanes> cameraPlaneBuffer = vkcv::buffer<CameraPlanes>(core, vkcv::BufferType::UNIFORM, 1); + vkcv::Buffer<CameraPlanes> cameraPlaneBuffer = vkcv::buffer<CameraPlanes>( + core, vkcv::BufferType::UNIFORM, 1 + ); vkcv::DescriptorWrites meshShaderWrites; meshShaderWrites.writeStorageBuffer( @@ -317,10 +322,10 @@ int main(int argc, const char** argv) { vkcv::camera::CameraManager cameraManager(window); auto camHandle = cameraManager.addCamera(vkcv::camera::ControllerType::PILOT); - cameraManager.getCamera(camHandle).setPosition(glm::vec3(0, 0, -2)); + cameraManager.getCamera(camHandle).setPosition(glm::vec3(0, 2.5f, -2)); - bool useMeshShader = true; - bool updateFrustumPlanes = true; + bool useMeshShader = true; + bool updateFrustumPlanes = true; core.run([&](const vkcv::WindowHandle &windowHandle, double t, double dt, uint32_t swapchainWidth, uint32_t swapchainHeight) { @@ -341,21 +346,11 @@ int main(int argc, const char** argv) { const vkcv::camera::Camera& camera = cameraManager.getActiveCamera(); ObjectMatrices objectMatrices; - objectMatrices.model = *reinterpret_cast<glm::mat4*>(&mesh.meshes.front().modelMatrix); - objectMatrices.mvp = camera.getMVP() * objectMatrices.model; + objectMatrices.model = *reinterpret_cast<glm::mat4*>(&mesh.meshes.front().modelMatrix); + objectMatrices.mvp = camera.getMVP() * objectMatrices.model; matrixBuffer.fill({ objectMatrices }); - struct MeshletPushConstants { - uint32_t meshletCount; - uint32_t matrixIndex; - }; - - MeshletPushConstants pushConstants { - static_cast<uint32_t>(meshShaderModelData.meshlets.size()), - 0 - }; - if (updateFrustumPlanes) { const CameraPlanes cameraPlanes = computeCameraPlanes(camera); cameraPlaneBuffer.fill({ cameraPlanes }); @@ -364,15 +359,14 @@ int main(int argc, const char** argv) { const std::vector<vkcv::ImageHandle> renderTargets = { swapchainInput, depthBuffer }; auto cmdStream = core.createCommandStream(vkcv::QueueType::Graphics); - vkcv::PushConstants pushConstantData = vkcv::pushConstants<MeshletPushConstants>(); - pushConstantData.appendDrawcall(pushConstants); + vkcv::PushConstants pushConstantData = vkcv::pushConstants<uint32_t>(0); if (useMeshShader) { - const uint32_t taskCount = (meshShaderModelData.meshlets.size() + 31) / 32; + vkcv::TaskDrawcall drawcall (vkcv::dispatchInvocations( + meshShaderModelData.meshlets.size(), 32 + )); - vkcv::TaskDrawcall drawcall (taskCount); drawcall.useDescriptorSet(0, meshShaderDescriptorSet); - core.recordMeshShaderDrawcalls( cmdStream, meshShaderPipeline, @@ -385,14 +379,14 @@ int main(int argc, const char** argv) { vkcv::InstanceDrawcall drawcall (vertexData); drawcall.useDescriptorSet(0, vertexShaderDescriptorSet); - /*core.recordDrawcallsToCmdStream( + core.recordDrawcallsToCmdStream( cmdStream, bunnyPipeline, pushConstantData, { drawcall }, { renderTargets }, windowHandle - );*/ + ); } core.prepareSwapchainImageForPresent(cmdStream); -- GitLab