diff --git a/projects/mesh_shader/resources/shaders/shader.frag b/projects/mesh_shader/resources/shaders/shader.frag
index 34f78313f547d99d352b9e4ca35bea9f905386a8..f4f6982f2089e6c8e102027f3b8763bb38f8e59c 100644
--- a/projects/mesh_shader/resources/shaders/shader.frag
+++ b/projects/mesh_shader/resources/shaders/shader.frag
@@ -27,6 +27,6 @@ vec3 colorFromIndex(uint i){
 }
 
 void main() {
-	//outColor = normalize(passNormal) * 0.5 + 0.5;
+	//outColor = normalize(passNormal) * 0.5 + 0.5; // debug view of the normal; kept disabled, the line below would overwrite it
     outColor = colorFromIndex(passTaskIndex);
 }
\ No newline at end of file
diff --git a/projects/mesh_shader/resources/shaders/shader.mesh b/projects/mesh_shader/resources/shaders/shader.mesh
index e92874393bf075d270e57c5eb4db080e4d79fe25..628d98416a1b0ea57788019c3510107d98ed731f 100644
--- a/projects/mesh_shader/resources/shaders/shader.mesh
+++ b/projects/mesh_shader/resources/shaders/shader.mesh
@@ -43,12 +43,13 @@ layout(std430, binding = 2) readonly buffer meshletBuffer
 };
 
 taskNV in Task {
-  uint meshletIndex;
+  uint meshletIndices[32]; // indexed with gl_WorkGroupID.x in main() to select this workgroup's meshlet
 } IN;
 
 void main()	{
     
-    Meshlet meshlet = meshlets[IN.meshletIndex];
+    uint meshletIndex = IN.meshletIndices[gl_WorkGroupID.x];
+    Meshlet meshlet = meshlets[meshletIndex];
     
     // set vertices
     for(uint i = 0; i < 2; i++){
@@ -63,7 +64,7 @@ void main()	{
     
         gl_MeshVerticesNV[workIndex].gl_Position    = mvp * vec4(vertex.position, 1);
         passNormal[workIndex]                       = vertex.normal;
-        passTaskIndex[workIndex]                    = IN.meshletIndex;
+        passTaskIndex[workIndex]                    = meshletIndex;
     }
     
     // set local indices
diff --git a/projects/mesh_shader/resources/shaders/shader.task b/projects/mesh_shader/resources/shaders/shader.task
index aedeba1505e21c63cfa89b04fb73e97955cd9b8c..0ac1169eca1507eb49231a7174530631fa2e8ecd 100644
--- a/projects/mesh_shader/resources/shaders/shader.task
+++ b/projects/mesh_shader/resources/shaders/shader.task
@@ -2,13 +2,23 @@
 #extension GL_ARB_separate_shader_objects   : enable
 #extension GL_NV_mesh_shader                : require
 
-layout(local_size_x=1) in;
+layout(local_size_x=32) in;
 
 taskNV out Task {
-  uint meshletIndex;
+  uint meshletIndices[32]; // one slot per invocation of the workgroup (local_size_x=32)
 } OUT;
 
+layout( push_constant ) uniform constants{
+    mat4 mvp;
+    uint meshletCount; // compared against the workgroup's index range when computing gl_TaskCountNV
+};
+
 void main() {
-    gl_TaskCountNV      = 1;
-    OUT.meshletIndex    = gl_GlobalInvocationID.x;
+    if(gl_LocalInvocationID.x == 0){
+        // Only spawn mesh workgroups for meshlets that exist: the last task workgroup may cover
+        // fewer than 32. Signed math avoids unsigned underflow when no slot is superfluous.
+        int workgroupEnd    = int((gl_WorkGroupID.x + 1) * 32); // one past this group's last slot
+        gl_TaskCountNV      = 32 - clamp(workgroupEnd - int(meshletCount), 0, 32);
+    }
+    OUT.meshletIndices[gl_LocalInvocationID.x] = gl_GlobalInvocationID.x;
 }
\ No newline at end of file
diff --git a/projects/mesh_shader/resources/shaders/shader.vert b/projects/mesh_shader/resources/shaders/shader.vert
index 6b5a0ce47a8fb2c5dbd7841985b31013ac307e8f..636545262a70a490a6aabfd5809a61f226c84a30 100644
--- a/projects/mesh_shader/resources/shaders/shader.vert
+++ b/projects/mesh_shader/resources/shaders/shader.vert
@@ -9,10 +9,12 @@ layout(location = 1) out uint dummyOutput;
 
 layout( push_constant ) uniform constants{
     mat4 mvp;
+    uint padding; // never contributes to the output; pads the block to the same size as the mesh shader constants
 };
 
 void main()	{
 	gl_Position = mvp * vec4(inPosition, 1.0);
-	dummyOutput = 0;
 	passNormal  = inNormal;
+    
+    dummyOutput = padding * 0;  // always 0; padding must be referenced, else compiler shrinks constant size
 }
\ No newline at end of file
diff --git a/projects/mesh_shader/src/main.cpp b/projects/mesh_shader/src/main.cpp
index 5d90672396eb6413a3296c64b422db197a11ee3e..37855e9b10d5c19a1c91ac66e60196826de0ed23 100644
--- a/projects/mesh_shader/src/main.cpp
+++ b/projects/mesh_shader/src/main.cpp
@@ -226,33 +226,40 @@ int main(int argc, const char** argv) {
 			continue;
 		}
 		
-        auto end = std::chrono::system_clock::now();
-        auto deltatime = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
-        start = end;
+		auto end = std::chrono::system_clock::now();
+		auto deltatime = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+		start = end;
 		
 		cameraManager.update(0.000001 * static_cast<double>(deltatime.count()));
 
 		glm::mat4 modelMatrix = *reinterpret_cast<glm::mat4*>(&mesh.meshes.front().modelMatrix);
-        glm::mat4 mvp = cameraManager.getActiveCamera().getMVP() * modelMatrix;
+		glm::mat4 mvp = cameraManager.getActiveCamera().getMVP() * modelMatrix;
 
-        const std::vector<vkcv::ImageHandle> renderTargets = { swapchainInput, depthBuffer };
+		struct PushConstants {	// CPU-side data handed to vkcv::PushConstants for this frame's drawcall
+			glm::mat4 mvp;
+			uint32_t meshletCount;	// 32 bit wide to match the task shader's `uint meshletCount`
+		};
+		PushConstants pushConstants{ mvp, static_cast<uint32_t>(meshShaderModelData.meshlets.size()) };	// explicit cast: size_t -> uint32_t narrows, which brace-init rejects
+
+		const std::vector<vkcv::ImageHandle> renderTargets = { swapchainInput, depthBuffer };
 		auto cmdStream = core.createCommandStream(vkcv::QueueType::Graphics);
 
 		const bool useMeshShader = true;
 
-		vkcv::PushConstants pushConstantData(sizeof(glm::mat4));
-		pushConstantData.appendDrawcall(mvp);
+		vkcv::PushConstants pushConstantData(sizeof(pushConstants));
+		pushConstantData.appendDrawcall(pushConstants);
 
 		if (useMeshShader) {
 
 			vkcv::DescriptorSetUsage descriptorUsage(0, core.getDescriptorSet(meshShaderDescriptorSet).vulkanHandle);
+			const uint32_t taskCount = (meshShaderModelData.meshlets.size() + 31) / 32;
 
 			core.recordMeshShaderDrawcalls(
 				cmdStream,
 				renderPass,
 				meshShaderPipeline,
 				pushConstantData,
-				{ vkcv::MeshShaderDrawcall({descriptorUsage}, meshShaderModelData.meshlets.size())},
+				{ vkcv::MeshShaderDrawcall({descriptorUsage}, taskCount)},
 				{ renderTargets });
 		}
 		else {