#version 450 core
#extension GL_GOOGLE_include_directive : enable
#extension GL_ARB_separate_shader_objects : enable

// Work-group size: 256 invocations along x. main() handles one particle per
// invocation, indexed by gl_GlobalInvocationID.x.
layout(local_size_x = 256) in;

#include "physics.inc"
#include "particle.inc"

// Particle storage that main() reads (lifetime, position) and rewrites
// (position, lifetime, velocity, size, color, mass, eventId) when a slot
// is reused for a new particle.
layout(set=0, binding=0, std430) buffer particleBuffer {
    particle_t particles [];
};

// Read-only particle array from which main() takes the position, velocity and
// size of a parent event's particle.
// NOTE(review): presumably a snapshot of `particles` taken before this pass —
// confirm against the host code.
layout(set=0, binding=1, std430) readonly buffer particleBufferCopy {
    particle_t particlesCopy [];
};

// Read-only float values used as a random source. main() indexes it modulo its
// length to build a spawn direction and to vary the trail length.
// NOTE(review): the value range is not visible here — confirm with the producer.
layout(set=1, binding=0, std430) readonly buffer randomBuffer {
    float randomData [];
};

#include "event.inc"

// Spawn events. main() reads startTime, continuous, count, lifetime, direction,
// color, velocity, size, parent and mass, and modifies `index` and `contCount`
// (mostly through atomicAdd) while claiming spawn slots.
layout(set=1, binding=1, std430) buffer eventBuffer {
    event_t events [];
};

// One uint per event, indexed by event id. main() adds startIndex[parent] to
// (id % parent count) to pick the entry of `particlesCopy` to inherit from.
// It is only read by the code visible in this shader.
layout(set=1, binding=2, std430) buffer startIndexBuffer {
    uint startIndex [];
};

#include "smoke.inc"

// Write-only smoke entries. main() fills position, size, velocity, scaling,
// color and eventID at slot (smokeIndex % smokes.length()).
layout(set=2, binding=0, std430) writeonly buffer smokeBuffer {
    smoke_t smokes [];
};

// Shared counters advanced with atomicAdd in main(); each is taken modulo the
// length of the buffer it indexes.
layout(set=2, binding=1, std430) buffer smokeIndexBuffer {
    uint smokeIndex; // next slot in `smokes`
    uint trailIndex; // next slot in `trails`
    uint pointIndex; // advanced by a trail's length to reserve a range of `points`
};

#include "trail.inc"

// Write-only trail entries. main() writes particleIndex, startIndex, endIndex,
// useCount, color and lifetime for every particle it spawns.
layout(set=3, binding=0, std430) writeonly buffer trailBuffer {
    trail_t trails [];
};

#include "point.inc"

// Point storage referenced by the trails. In this shader only points.length()
// is used, to wrap the start/end indices of a trail's point range.
layout(set=3, binding=1, std430) readonly buffer pointBuffer {
    point_t points [];
};

// Push constants supplied per dispatch.
// NOTE(review): names suggest current time and time step; units are not
// visible here — confirm with the host code.
layout( push_constant ) uniform constants{
    float t;  // compared against an event's startTime to decide whether it may fire
    float dt; // scaled by the trail length when computing a trail's lifetime
};

// Re-uses expired particle slots to spawn the particles requested by events.
//
// Each invocation owns particle `gl_GlobalInvocationID.x`. If that particle's
// lifetime has run out, the invocation walks the event list, tries to claim a
// spawn slot from the first event that accepts it and, on success:
//   - inherits position/velocity/size from a particle of the parent event
//     (if the event has a valid parent),
//   - emits one smoke entry for the first slot of a non-continuous event,
//   - writes the new particle state, and
//   - registers a trail with a reserved range of points.
void main() {
    uint id = gl_GlobalInvocationID.x;

    // Invocations past the end of the particle buffer have nothing to do.
    if (id >= particles.length()) {
        return;
    }

    float lifetime = particles[id].lifetime;

    // Only particles whose lifetime has run out are recycled.
    if (lifetime > 0.0f) {
        return;
    }

    const uint eventCount = uint(events.length());

    // `eventCount` doubles as the "no event found" marker.
    uint event_id = eventCount;
    uint index = 0;

    for (uint i = 0; i < eventCount; i++) {
        // Events with `continuous < 1` spawn `count` particles once;
        // the others hand out `continuous` slots per batch.
        const bool oneShot = (events[i].continuous < 1);

        // One-shot events are skipped until their start time is reached.
        if (oneShot && (t < events[i].startTime)) {
            continue;
        }

        // Tentatively claim the next slot of this event.
        index = atomicAdd(events[i].index, 1);

        if (oneShot) {
            if (events[i].count > index) {
                event_id = i;
                break;
            }

            // No slot left: undo the tentative claim.
            atomicAdd(events[i].index, -1);
        } else {
            if (events[i].continuous > index) {
                event_id = i;
                break;
            }

            // Batch exhausted: while contCount is positive, decrement it and
            // reset the slot counter.
            // NOTE(review): the reset is a plain store racing with the
            // atomicAdd calls of other invocations, and the decrement below
            // is applied after the reset — confirm this is intended.
            if (events[i].contCount > 0) {
                atomicAdd(events[i].contCount, -1);
                events[i].index = 0;
            }

            atomicAdd(events[i].index, -1);
        }
    }

    // No event accepted this particle.
    if (event_id >= eventCount) {
        return;
    }

    lifetime = events[event_id].lifetime;

    // An event without a direction (zero vector) gets one from the random
    // buffer; three consecutive values per particle, wrapped to the buffer.
    vec3 direction;
    if (dot(events[event_id].direction, events[event_id].direction) <= 0.0f) {
        direction = vec3(
            randomData[(id * 3 + 0) % randomData.length()],
            randomData[(id * 3 + 1) % randomData.length()],
            randomData[(id * 3 + 2) % randomData.length()]
        );
    } else {
        direction = events[event_id].direction;
    }

    vec3 color = normalize(events[event_id].color);
    const float v = events[event_id].velocity;

    vec3 velocity = vec3(0.0f);
    float size = events[event_id].size;

    const uint pid = events[event_id].parent;

    // A parent index outside the event list means "no parent".
    if (pid < eventCount) {
        const uint spawnCount = events[pid].count;

        // Integer modulo by zero has an undefined result in GLSL, so a parent
        // without particles is treated like a missing parent.
        if (spawnCount > 0) {
            const uint spawnId = startIndex[pid] + (id % spawnCount);

            if (spawnId < particlesCopy.length()) {
                particles[id].position = particlesCopy[spawnId].position;
                velocity += particlesCopy[spawnId].velocity;
                size = particlesCopy[spawnId].size;
            }
        }
    }

    // The invocation that claimed slot 0 of a one-shot event also emits the
    // smoke entry for that event.
    if ((0 == index) && (events[event_id].continuous < 1)) {
        const uint sid = atomicAdd(smokeIndex, 1) % smokes.length();

        smokes[sid].position = particles[id].position;
        smokes[sid].size = size * (1.0f + friction);
        smokes[sid].velocity = velocity;
        smokes[sid].scaling = v;
        smokes[sid].color = mix(color, vec3(1.0f), 0.5f);
        smokes[sid].eventID = event_id;
    }

    velocity += normalize(direction) * v;

    // Size and mass are shared between the particles of the event. The count
    // is clamped to at least 1 so an event with count 0 cannot produce inf/NaN.
    const float shareCount = max(float(events[event_id].count), 1.0f);

    // Cube root of the share, so the size scales like a radius of a split volume.
    const float split = pow(1.0f / shareCount, 1.0f / 3.0f);

    particles[id].lifetime = lifetime;
    particles[id].velocity = velocity;
    particles[id].size = size * split;
    particles[id].color = color;
    particles[id].mass = events[event_id].mass / shareCount;
    particles[id].eventId = event_id;

    {
        const uint tid = atomicAdd(trailIndex, 1) % trails.length();

        // Base length of 96 points plus a random offset scaled by 32.
        const uint trailLen = 96 + int(randomData[(tid + id) % randomData.length()] * 32);

        // Reserve `trailLen` consecutive points (wrapping around the buffer).
        // Named `pointStart` so it does not shadow the `startIndex[]` buffer.
        const uint pointStart = atomicAdd(pointIndex, trailLen) % points.length();

        trails[tid].particleIndex = id;
        trails[tid].startIndex = pointStart;
        trails[tid].endIndex = (pointStart + trailLen - 1) % points.length();
        trails[tid].useCount = 0;
        trails[tid].color = mix(color, vec3(1.0f), 0.75f);
        trails[tid].lifetime = lifetime + (dt * trailLen) * 0.5f;
    }
}