Skip to content
Snippets Groups Projects
Verified Commit 8d08758b authored by Tobias Frisch's avatar Tobias Frisch
Browse files

Fix breaking issues in task shader

parent 9f3c4023
No related branches found
No related tags found
1 merge request!111Resolve "Cross vendor mesh shader support"
...@@ -20,23 +20,25 @@ struct Vertex { ...@@ -20,23 +20,25 @@ struct Vertex {
float padding1; float padding1;
}; };
layout(std430, binding = 0) readonly buffer vertexBuffer { layout(std430, set=0, binding = 0) readonly buffer vertexBuffer {
Vertex vertices[]; Vertex vertices[];
}; };
layout(std430, binding = 1) readonly buffer indexBuffer { layout(std430, set=0, binding = 1) readonly buffer indexBuffer {
uint localIndices[]; // breaks for 16 bit indices uint localIndices[]; // breaks for 16 bit indices
}; };
layout(std430, binding = 2) readonly buffer meshletBuffer { layout(std430, set=0, binding = 2) readonly buffer meshletBuffer {
Meshlet meshlets[]; Meshlet meshlets[];
}; };
taskPayloadSharedEXT Task IN; taskPayloadSharedEXT Task IN;
void main() { void main() {
uint meshletIndex = IN.meshletIndices[gl_WorkGroupID.x]; const uint meshletIndex = IN.meshletIndices[gl_WorkGroupID.x];
Meshlet meshlet = meshlets[meshletIndex]; Meshlet meshlet = meshlets[meshletIndex];
SetMeshOutputsEXT(meshlet.vertexCount, meshlet.indexCount / 3);
// set vertices // set vertices
for (uint i = 0; i < 2; i++) { for (uint i = 0; i < 2; i++) {
...@@ -55,16 +57,12 @@ void main() { ...@@ -55,16 +57,12 @@ void main() {
// set local indices // set local indices
for (uint i = 0; i < 12; i++) { for (uint i = 0; i < 12; i++) {
uint workIndex = gl_LocalInvocationID.x + i * 32; const uint workIndex = gl_LocalInvocationID.x + i * 32;
if (workIndex >= meshlet.indexCount) { if (workIndex >= meshlet.indexCount) {
break; break;
} }
uint indexBufferIndex = meshlet.indexOffset + workIndex; const uint indexBufferIndex = meshlet.indexOffset + workIndex;
gl_PrimitiveTriangleIndicesEXT[workIndex] = uvec3(localIndices[indexBufferIndex]); gl_PrimitiveTriangleIndicesEXT[workIndex] = uvec3(localIndices[indexBufferIndex]);
} }
if (gl_LocalInvocationID.x == 0) {
SetMeshOutputsEXT(64, meshlet.indexCount / 3);
}
} }
\ No newline at end of file
...@@ -11,12 +11,11 @@ layout(local_size_x=32, local_size_y=1, local_size_z=1) in; ...@@ -11,12 +11,11 @@ layout(local_size_x=32, local_size_y=1, local_size_z=1) in;
taskPayloadSharedEXT Task OUT; taskPayloadSharedEXT Task OUT;
layout( push_constant ) uniform constants { layout( push_constant ) uniform constants {
uint meshletCount;
uint matrixIndex; uint matrixIndex;
}; };
// TODO: reuse mesh stage binding at location 2 after required fix in framework // TODO: reuse mesh stage binding at location 2 after required fix in framework
layout(std430, binding = 5) readonly buffer meshletBuffer { layout(std430, set=0, binding = 5) readonly buffer meshletBuffer {
Meshlet meshlets[]; Meshlet meshlets[];
}; };
...@@ -31,7 +30,7 @@ layout(set=0, binding = 3, std140) uniform cameraPlaneBuffer { ...@@ -31,7 +30,7 @@ layout(set=0, binding = 3, std140) uniform cameraPlaneBuffer {
Plane cameraPlanes[6]; Plane cameraPlanes[6];
}; };
layout(std430, binding = 4) readonly buffer matrixBuffer { layout(std430, set=0, binding = 4) readonly buffer matrixBuffer {
ObjectMatrices objectMatrices[]; ObjectMatrices objectMatrices[];
}; };
...@@ -47,13 +46,13 @@ bool isSphereInsideFrustum(vec3 spherePos, float sphereRadius, Plane cameraPlane ...@@ -47,13 +46,13 @@ bool isSphereInsideFrustum(vec3 spherePos, float sphereRadius, Plane cameraPlane
} }
void main() { void main() {
if (gl_LocalInvocationID.x >= meshletCount) { const uint meshletIndex = gl_GlobalInvocationID.x;
return; Meshlet meshlet;
if (meshletIndex < meshlets.length()) {
meshlet = meshlets[meshletIndex];
} }
uint meshletIndex = gl_GlobalInvocationID.x;
Meshlet meshlet = meshlets[meshletIndex];
if (gl_LocalInvocationID.x == 0) { if (gl_LocalInvocationID.x == 0) {
taskCount = 0; taskCount = 0;
} }
...@@ -63,14 +62,12 @@ void main() { ...@@ -63,14 +62,12 @@ void main() {
// TODO: scaling support // TODO: scaling support
vec3 meshletPositionWorld = (vec4(meshlet.meanPosition, 1) * objectMatrices[matrixIndex].model).xyz; vec3 meshletPositionWorld = (vec4(meshlet.meanPosition, 1) * objectMatrices[matrixIndex].model).xyz;
if (isSphereInsideFrustum(meshletPositionWorld, meshlet.boundingSphereRadius, cameraPlanes)) { if (isSphereInsideFrustum(meshletPositionWorld, meshlet.boundingSphereRadius, cameraPlanes)) {
uint outIndex = atomicAdd(taskCount, 1); const uint outIndex = atomicAdd(taskCount, 1);
OUT.meshletIndices[outIndex] = gl_GlobalInvocationID.x; OUT.meshletIndices[outIndex] = meshletIndex;
} }
memoryBarrierShared(); memoryBarrierShared();
if (gl_LocalInvocationID.x == 0) { OUT.mvp = objectMatrices[matrixIndex].mvp;
OUT.mvp = objectMatrices[matrixIndex].mvp; EmitMeshTasksEXT(taskCount, 1, 1);
EmitMeshTasksEXT(taskCount, 1, 1);
}
} }
\ No newline at end of file
...@@ -10,18 +10,16 @@ layout(location = 1) in vec3 inNormal; ...@@ -10,18 +10,16 @@ layout(location = 1) in vec3 inNormal;
layout(location = 0) out vec3 passNormal; layout(location = 0) out vec3 passNormal;
layout(location = 1) out flat uint passTaskIndex; layout(location = 1) out flat uint passTaskIndex;
layout(std430, binding = 0) readonly buffer matrixBuffer { layout(std430, set=0, binding = 0) readonly buffer matrixBuffer {
ObjectMatrices objectMatrices[]; ObjectMatrices objectMatrices[];
}; };
layout( push_constant ) uniform constants { layout( push_constant ) uniform constants {
uint padding; // pad to same size as mesh shader constants
uint matrixIndex; uint matrixIndex;
}; };
void main() { void main() {
passNormal = inNormal;
passTaskIndex = 0;
gl_Position = objectMatrices[matrixIndex].mvp * vec4(inPosition, 1.0); gl_Position = objectMatrices[matrixIndex].mvp * vec4(inPosition, 1.0);
passNormal = inNormal;
passTaskIndex = 0;
} }
\ No newline at end of file
...@@ -94,8 +94,7 @@ int main(int argc, const char** argv) { ...@@ -94,8 +94,7 @@ int main(int argc, const char** argv) {
applicationName, applicationName,
VK_MAKE_VERSION(0, 0, 1), VK_MAKE_VERSION(0, 0, 1),
{ vk::QueueFlagBits::eTransfer,vk::QueueFlagBits::eGraphics, vk::QueueFlagBits::eCompute }, { vk::QueueFlagBits::eTransfer,vk::QueueFlagBits::eGraphics, vk::QueueFlagBits::eCompute },
features, features
{ VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME }
); );
vkcv::WindowHandle windowHandle = core.createWindow(applicationName, 1280, 720, true); vkcv::WindowHandle windowHandle = core.createWindow(applicationName, 1280, 720, true);
...@@ -165,7 +164,11 @@ int main(int argc, const char** argv) { ...@@ -165,7 +164,11 @@ int main(int argc, const char** argv) {
vkcv::meshlet::VertexCacheReorderResult tipsifyResult = vkcv::meshlet::tipsifyMesh(indexBuffer32Bit, interleavedVertices.size()); vkcv::meshlet::VertexCacheReorderResult tipsifyResult = vkcv::meshlet::tipsifyMesh(indexBuffer32Bit, interleavedVertices.size());
vkcv::meshlet::VertexCacheReorderResult forsythResult = vkcv::meshlet::forsythReorder(indexBuffer32Bit, interleavedVertices.size()); vkcv::meshlet::VertexCacheReorderResult forsythResult = vkcv::meshlet::forsythReorder(indexBuffer32Bit, interleavedVertices.size());
const auto meshShaderModelData = createMeshShaderModelData(interleavedVertices, forsythResult.indexBuffer, forsythResult.skippedIndices); const auto meshShaderModelData = createMeshShaderModelData(
interleavedVertices,
forsythResult.indexBuffer,
forsythResult.skippedIndices
);
auto meshShaderVertexBuffer = vkcv::buffer<vkcv::meshlet::Vertex>( auto meshShaderVertexBuffer = vkcv::buffer<vkcv::meshlet::Vertex>(
core, core,
...@@ -182,9 +185,7 @@ int main(int argc, const char** argv) { ...@@ -182,9 +185,7 @@ int main(int argc, const char** argv) {
auto meshletBuffer = vkcv::buffer<vkcv::meshlet::Meshlet>( auto meshletBuffer = vkcv::buffer<vkcv::meshlet::Meshlet>(
core, core,
vkcv::BufferType::STORAGE, vkcv::BufferType::STORAGE,
meshShaderModelData.meshlets.size(), meshShaderModelData.meshlets.size());
vkcv::BufferMemoryType::DEVICE_LOCAL
);
meshletBuffer.fill(meshShaderModelData.meshlets); meshletBuffer.fill(meshShaderModelData.meshlets);
vkcv::PassHandle renderPass = vkcv::passSwapchain( vkcv::PassHandle renderPass = vkcv::passSwapchain(
...@@ -227,7 +228,9 @@ int main(int argc, const char** argv) { ...@@ -227,7 +228,9 @@ int main(int argc, const char** argv) {
glm::mat4 mvp; glm::mat4 mvp;
}; };
const size_t objectCount = 1; const size_t objectCount = 1;
vkcv::Buffer<ObjectMatrices> matrixBuffer = vkcv::buffer<ObjectMatrices>(core, vkcv::BufferType::STORAGE, objectCount); vkcv::Buffer<ObjectMatrices> matrixBuffer = vkcv::buffer<ObjectMatrices>(
core, vkcv::BufferType::STORAGE, objectCount
);
vkcv::DescriptorWrites vertexShaderDescriptorWrites; vkcv::DescriptorWrites vertexShaderDescriptorWrites;
vertexShaderDescriptorWrites.writeStorageBuffer(0, matrixBuffer.getHandle()); vertexShaderDescriptorWrites.writeStorageBuffer(0, matrixBuffer.getHandle());
...@@ -286,7 +289,9 @@ int main(int argc, const char** argv) { ...@@ -286,7 +289,9 @@ int main(int argc, const char** argv) {
return EXIT_FAILURE; return EXIT_FAILURE;
} }
vkcv::Buffer<CameraPlanes> cameraPlaneBuffer = vkcv::buffer<CameraPlanes>(core, vkcv::BufferType::UNIFORM, 1); vkcv::Buffer<CameraPlanes> cameraPlaneBuffer = vkcv::buffer<CameraPlanes>(
core, vkcv::BufferType::UNIFORM, 1
);
vkcv::DescriptorWrites meshShaderWrites; vkcv::DescriptorWrites meshShaderWrites;
meshShaderWrites.writeStorageBuffer( meshShaderWrites.writeStorageBuffer(
...@@ -317,10 +322,10 @@ int main(int argc, const char** argv) { ...@@ -317,10 +322,10 @@ int main(int argc, const char** argv) {
vkcv::camera::CameraManager cameraManager(window); vkcv::camera::CameraManager cameraManager(window);
auto camHandle = cameraManager.addCamera(vkcv::camera::ControllerType::PILOT); auto camHandle = cameraManager.addCamera(vkcv::camera::ControllerType::PILOT);
cameraManager.getCamera(camHandle).setPosition(glm::vec3(0, 0, -2)); cameraManager.getCamera(camHandle).setPosition(glm::vec3(0, 2.5f, -2));
bool useMeshShader = true; bool useMeshShader = true;
bool updateFrustumPlanes = true; bool updateFrustumPlanes = true;
core.run([&](const vkcv::WindowHandle &windowHandle, double t, double dt, core.run([&](const vkcv::WindowHandle &windowHandle, double t, double dt,
uint32_t swapchainWidth, uint32_t swapchainHeight) { uint32_t swapchainWidth, uint32_t swapchainHeight) {
...@@ -341,21 +346,11 @@ int main(int argc, const char** argv) { ...@@ -341,21 +346,11 @@ int main(int argc, const char** argv) {
const vkcv::camera::Camera& camera = cameraManager.getActiveCamera(); const vkcv::camera::Camera& camera = cameraManager.getActiveCamera();
ObjectMatrices objectMatrices; ObjectMatrices objectMatrices;
objectMatrices.model = *reinterpret_cast<glm::mat4*>(&mesh.meshes.front().modelMatrix); objectMatrices.model = *reinterpret_cast<glm::mat4*>(&mesh.meshes.front().modelMatrix);
objectMatrices.mvp = camera.getMVP() * objectMatrices.model; objectMatrices.mvp = camera.getMVP() * objectMatrices.model;
matrixBuffer.fill({ objectMatrices }); matrixBuffer.fill({ objectMatrices });
struct MeshletPushConstants {
uint32_t meshletCount;
uint32_t matrixIndex;
};
MeshletPushConstants pushConstants {
static_cast<uint32_t>(meshShaderModelData.meshlets.size()),
0
};
if (updateFrustumPlanes) { if (updateFrustumPlanes) {
const CameraPlanes cameraPlanes = computeCameraPlanes(camera); const CameraPlanes cameraPlanes = computeCameraPlanes(camera);
cameraPlaneBuffer.fill({ cameraPlanes }); cameraPlaneBuffer.fill({ cameraPlanes });
...@@ -364,15 +359,14 @@ int main(int argc, const char** argv) { ...@@ -364,15 +359,14 @@ int main(int argc, const char** argv) {
const std::vector<vkcv::ImageHandle> renderTargets = { swapchainInput, depthBuffer }; const std::vector<vkcv::ImageHandle> renderTargets = { swapchainInput, depthBuffer };
auto cmdStream = core.createCommandStream(vkcv::QueueType::Graphics); auto cmdStream = core.createCommandStream(vkcv::QueueType::Graphics);
vkcv::PushConstants pushConstantData = vkcv::pushConstants<MeshletPushConstants>(); vkcv::PushConstants pushConstantData = vkcv::pushConstants<uint32_t>(0);
pushConstantData.appendDrawcall(pushConstants);
if (useMeshShader) { if (useMeshShader) {
const uint32_t taskCount = (meshShaderModelData.meshlets.size() + 31) / 32; vkcv::TaskDrawcall drawcall (vkcv::dispatchInvocations(
meshShaderModelData.meshlets.size(), 32
));
vkcv::TaskDrawcall drawcall (taskCount);
drawcall.useDescriptorSet(0, meshShaderDescriptorSet); drawcall.useDescriptorSet(0, meshShaderDescriptorSet);
core.recordMeshShaderDrawcalls( core.recordMeshShaderDrawcalls(
cmdStream, cmdStream,
meshShaderPipeline, meshShaderPipeline,
...@@ -385,14 +379,14 @@ int main(int argc, const char** argv) { ...@@ -385,14 +379,14 @@ int main(int argc, const char** argv) {
vkcv::InstanceDrawcall drawcall (vertexData); vkcv::InstanceDrawcall drawcall (vertexData);
drawcall.useDescriptorSet(0, vertexShaderDescriptorSet); drawcall.useDescriptorSet(0, vertexShaderDescriptorSet);
/*core.recordDrawcallsToCmdStream( core.recordDrawcallsToCmdStream(
cmdStream, cmdStream,
bunnyPipeline, bunnyPipeline,
pushConstantData, pushConstantData,
{ drawcall }, { drawcall },
{ renderTargets }, { renderTargets },
windowHandle windowHandle
);*/ );
} }
core.prepareSwapchainImageForPresent(cmdStream); core.prepareSwapchainImageForPresent(cmdStream);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment