diff --git a/projects/mesh_shader/assets/shaders/shader.mesh b/projects/mesh_shader/assets/shaders/shader.mesh
index 8934bdf5b4e550439685ca1fa57a25c4e922a018..672f962ee9788e85a67e038c9c64da30fd7994e6 100644
--- a/projects/mesh_shader/assets/shaders/shader.mesh
+++ b/projects/mesh_shader/assets/shaders/shader.mesh
@@ -11,7 +11,7 @@ layout(triangles) out;
 layout(max_vertices=64, max_primitives=126) out;
 
 layout(location = 0) out vec3 passNormal[];
-layout(location = 1) out uint passTaskIndex[];
+layout(location = 1) out flat uint passTaskIndex[];
 
 struct Vertex {
     vec3  position;
@@ -34,35 +34,50 @@ layout(std430, set=0, binding = 2) readonly buffer meshletBuffer {
 
 taskPayloadSharedEXT Task IN;
 
+void pass_vertex(uint meshletIndex, uint workIndex) {
+    const Meshlet meshlet = meshlets[meshletIndex];
+
+    if (workIndex >= meshlet.vertexCount) {
+        return;
+    }
+
+    const uint vertexIndex = meshlet.vertexOffset + workIndex;
+    const Vertex vertex = vertices[vertexIndex];
+
+    gl_MeshVerticesEXT[workIndex].gl_Position = IN.mvp * vec4(vertex.position, 1);
+    passNormal[workIndex]                     = vertex.normal;
+    passTaskIndex[workIndex]                  = meshletIndex;
+}
+
+void pass_index(uint meshletIndex, uint workIndex) {
+    const Meshlet meshlet = meshlets[meshletIndex];
+
+    if (workIndex * 3 + 2 >= meshlet.indexCount) {
+        return;
+    }
+
+    const uint indexBufferIndex = meshlet.indexOffset + workIndex * 3;
+
+    gl_PrimitiveTriangleIndicesEXT[workIndex] = uvec3(
+        localIndices[indexBufferIndex],
+        localIndices[indexBufferIndex + 1],
+        localIndices[indexBufferIndex + 2]
+    );
+}
+
 void main()	{
     const uint meshletIndex = IN.meshletIndices[gl_WorkGroupID.x];
     Meshlet meshlet = meshlets[meshletIndex];
 
     SetMeshOutputsEXT(meshlet.vertexCount, meshlet.indexCount / 3);
-    
+
     // set vertices
     for (uint i = 0; i < 2; i++) {
-        uint workIndex = gl_LocalInvocationID.x + 32 * i;
-        if (workIndex >= meshlet.vertexCount) {
-            break;
-        }
-    
-        uint vertexIndex    = meshlet.vertexOffset + workIndex;
-        Vertex vertex       = vertices[vertexIndex];
-    
-        gl_MeshVerticesEXT[workIndex].gl_Position   = IN.mvp * vec4(vertex.position, 1);
-        passNormal[workIndex]                       = vertex.normal;
-        passTaskIndex[workIndex]                    = meshletIndex;
+        pass_vertex(meshletIndex, gl_LocalInvocationIndex + i * 32);
     }
-    
+
     // set local indices
-    for (uint i = 0; i < 12; i++) {
-        const uint workIndex = gl_LocalInvocationID.x + i * 32;
-        if (workIndex >= meshlet.indexCount) {
-            break;
-        }
-        
-        const uint indexBufferIndex = meshlet.indexOffset + workIndex;
-        gl_PrimitiveTriangleIndicesEXT[workIndex] = uvec3(localIndices[indexBufferIndex]);
+    for (uint i = 0; i < 4; i++) {
+        pass_index(meshletIndex, gl_LocalInvocationIndex + i * 32);
     }
 }
\ No newline at end of file