diff --git a/projects/fire_works/shaders/generation.comp b/projects/fire_works/shaders/generation.comp
index 83b0167779bc6bc1836c52642359bc8419be5152..864a9e44e5806394225cec7a769049eaea4eda35 100644
--- a/projects/fire_works/shaders/generation.comp
+++ b/projects/fire_works/shaders/generation.comp
@@ -148,7 +148,7 @@ void main() {
 
     {
         const uint tid = atomicAdd(trailIndex, 1) % trails.length();
-        const uint trailLen = 16; // 64 + int(randomData[(tid + id) % randomData.length()] * 32);
+        const uint trailLen = 96 + int(randomData[(tid + id) % randomData.length()] * 32);
 
         const uint startIndex = atomicAdd(pointIndex, trailLen) % points.length();
 
diff --git a/projects/fire_works/shaders/smoke.frag b/projects/fire_works/shaders/smoke.frag
index 7339c6ee64d321267f0cafb2f0169c9189d15c42..ac5740d8d774ce4d3d6320c35e208b83a0060694 100644
--- a/projects/fire_works/shaders/smoke.frag
+++ b/projects/fire_works/shaders/smoke.frag
@@ -5,7 +5,7 @@
 #include "physics.inc"
 
 layout(location = 0) in vec3 passPos;
-layout(location = 1) in vec3 passView;
+layout(location = 1) in vec3 passDir;
 layout(location = 2) in vec3 passColor;
 layout(location = 3) in float passDensity;
 layout(location = 4) in flat int passSmokeIndex;
@@ -16,11 +16,6 @@ layout(set=1, binding=0, std430) readonly buffer randomBuffer {
     float randomData [];
 };
 
-layout( push_constant ) uniform constants{
-    mat4 view;
-    mat4 projection;
-};
-
 #define NUM_SMOKE_SAMPLES 16
 
 void main()	{
@@ -28,10 +23,8 @@ void main()	{
         discard;
     }
 
-    vec3 dir = -normalize((inverse(view) * vec4(passView, 0)).xyz);
-
     vec3 start = passPos;
-    vec3 end = start + dir * 3.5f;
+    vec3 end = start + normalize(passDir) * 3.5f;
 
     vec4 result = vec4(0);
 
diff --git a/projects/fire_works/shaders/smoke.vert b/projects/fire_works/shaders/smoke.vert
index aa17b3d6f2c3bd0387f139d741c6479c5a6abd1c..0392deca4bbf2a174c8a9bd80fad86ee1fd9fcee 100644
--- a/projects/fire_works/shaders/smoke.vert
+++ b/projects/fire_works/shaders/smoke.vert
@@ -10,14 +10,14 @@ layout(set=0, binding=0, std430) readonly buffer smokeBuffer {
 layout(location = 0) in vec3 vertexPos;
 
 layout(location = 0) out vec3 passPos;
-layout(location = 1) out vec3 passView;
+layout(location = 1) out vec3 passDir;
 layout(location = 2) out vec3 passColor;
 layout(location = 3) out float passDensity;
 layout(location = 4) out flat int passSmokeIndex;
 
 layout( push_constant ) uniform constants{
-    mat4 view;
-    mat4 projection;
+    mat4 mvp;
+    vec3 camera;
 };
 
 void main()	{
@@ -25,10 +25,10 @@ void main()	{
     float size = smokes[gl_InstanceIndex].size;
     vec3 color = smokes[gl_InstanceIndex].color;
 
-    vec4 viewPos = view * vec4(position + vertexPos * size, 1);
+    vec3 pos = position + vertexPos * size;
 
     passPos = vertexPos;
-    passView = viewPos.xyz;
+    passDir = pos - camera;
     passColor = color;
 
     if (size > 0.0f) {
@@ -40,5 +40,5 @@ void main()	{
     passSmokeIndex = gl_InstanceIndex;
 
     // transform position into projected view space
-    gl_Position = projection * viewPos;
+    gl_Position = mvp * vec4(pos, 1);
 }
\ No newline at end of file
diff --git a/projects/fire_works/shaders/trail.geom b/projects/fire_works/shaders/trail.geom
index e6872cda6c76b6626b8ea3083dd50781ba72b537..027943473d05cddc9db6019c7ee36771fcb91d2e 100644
--- a/projects/fire_works/shaders/trail.geom
+++ b/projects/fire_works/shaders/trail.geom
@@ -22,14 +22,14 @@ layout(location = 3) in uint geomStartIndex [1];
 layout(location = 4) in uint geomUseCount [1];
 
 layout(location = 0) out vec3 passPos;
-layout(location = 1) out vec3 passView;
+layout(location = 1) out vec3 passDir;
 layout(location = 2) out vec3 passColor;
 layout(location = 3) out float passDensity;
 layout(location = 4) out flat int passSmokeIndex;
 
 layout( push_constant ) uniform constants{
-    mat4 view;
-    mat4 projection;
+    mat4 mvp;
+    vec3 camera;
 };
 
 void main() {
@@ -41,58 +41,54 @@ void main() {
     const uint startIndex = geomStartIndex[0];
     const uint useCount = geomUseCount[0];
 
-    if (useCount <= 1) {
-        return;
-    }
-
-    vec4 viewPositions [2];
+    const uint indexOffset = (gl_InvocationID * (INSTANCE_LEN - 1));
+    const uint instanceIndex = startIndex + indexOffset;
 
-    for (uint i = 0; i < 2; i++) {
-        const vec3 position = points[startIndex + i].position;
+    uint count = min(INSTANCE_LEN, useCount);
 
-        viewPositions[i] = view * vec4(position, 1);
+    if ((indexOffset >= useCount) && (indexOffset + INSTANCE_LEN > useCount)) {
+        count = indexOffset - useCount;
     }
 
-    vec3 pos = viewPositions[0].xyz;
-    vec3 dir = normalize(cross(viewPositions[1].xyz - pos, viewPositions[0].xyz));
+    if (count <= 1) {
+        return;
+    }
 
     const float trailFactor = mediumDensity / friction;
 
-    for (uint i = 0; i < useCount; i++) {
-        const float u = float(i + 1) / float(useCount);
+    for (uint i = 0; i < count; i++) {
+        const float u = float(indexOffset + i + 1) / float(useCount);
 
-        const vec3 position = points[startIndex + i].position;
-        const float size = points[startIndex + i].size;
+        const uint index = (instanceIndex + i) % points.length();
 
-        vec4 viewPos = view * vec4(position, 1);
+        const vec3 position = points[index].position;
+        const float size = points[index].size;
+        const vec3 velocity = points[index].velocity;
 
-        if (i > 0) {
-            dir = normalize(cross(viewPos.xyz - pos, viewPos.xyz));
-            pos = viewPos.xyz;
-        }
+        const vec3 dir = normalize(cross(abs(velocity), position - camera));
 
         vec3 offset = dir * size;
         float density = trailFactor * (1.0f - u * u) / size;
 
-        vec4 v0 = viewPos - vec4(offset, 0);
-        vec4 v1 = viewPos + vec4(offset, 0);
+        const vec3 p0 = position - offset;
+        const vec3 p1 = position + offset;
 
         passPos = vec3(u, -1.0f, -1.0f);
-        passView = v0.xyz;
+        passDir = vec3(-0.1f * u, +0.2f, 2.0f);
         passColor = mix(color, trailColor, u);
         passDensity = density;
         passSmokeIndex = int(id);
 
-        gl_Position = projection * v0;
+        gl_Position = mvp * vec4(p0, 1);
         EmitVertex();
 
         passPos = vec3(u, +1.0f, -1.0f);
-        passView = v1.xyz;
+        passDir = vec3(-0.1f * u, -0.2f, 2.0f);
         passColor = mix(color, trailColor, u);
         passDensity = density;
         passSmokeIndex = int(id);
 
-        gl_Position = projection * v1;
+        gl_Position = mvp * vec4(p1, 1);
         EmitVertex();
     }
 
diff --git a/projects/fire_works/src/main.cpp b/projects/fire_works/src/main.cpp
index 3d894904d81fada1d3a93c7c76a160459bab9fab..c4de5a7ac88c89db3fa2ca419a6503414926cbfe 100644
--- a/projects/fire_works/src/main.cpp
+++ b/projects/fire_works/src/main.cpp
@@ -68,6 +68,11 @@ struct draw_particles_t {
 	uint32_t height;
 };
 
+struct draw_smoke_t {
+	glm::mat4 mvp;
+	glm::vec3 camera;
+};
+
 #define PARTICLE_COUNT (1024)
 #define SMOKE_COUNT (512)
 #define TRAIL_COUNT (2048)
@@ -838,13 +843,14 @@ int main(int argc, const char **argv) {
 		
 		core.recordBufferMemoryBarrier(cmdStream, smokeBuffer.getHandle());
 		
-		glm::mat4 smokeMatrices [2];
-		smokeMatrices[0] = camera.getView();
-		smokeMatrices[1] = camera.getProjection();
+		draw_smoke_t draw_smoke {
+			camera.getMVP(),
+			camera.getPosition()
+		};
 		
 		core.recordBeginDebugLabel(cmdStream, "Draw smoke", { 1.0f, 0.5f, 1.0f, 1.0f });
-		vkcv::PushConstants pushConstantsDraw1 (sizeof(glm::mat4) * 2);
-		pushConstantsDraw1.appendDrawcall(smokeMatrices);
+		vkcv::PushConstants pushConstantsDraw1 (sizeof(draw_smoke_t));
+		pushConstantsDraw1.appendDrawcall(draw_smoke);
 		
 		core.recordDrawcallsToCmdStream(
 			cmdStream,