#version 450
#extension GL_GOOGLE_include_directive : enable
#extension GL_EXT_control_flow_attributes : enable

// Workgroup size: 64 invocations along x, 1 along y and z. main() only uses
// gl_GlobalInvocationID.x, which it treats as the index of the particle to update.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

#include "particle.inc"

// Storage buffer with all particles (std430 layout, descriptor set 0, binding 0).
// main() reads the entry at gl_GlobalInvocationID.x and writes the position,
// velocity, deformation and mls fields of that same entry.
// The Particle type is not declared in this file (presumably in particle.inc).
layout(set=0, binding=0, std430) restrict buffer particleBuffer {
    Particle particles [];
};

// 3D grid texture; main() reads the xyz channels of its samples as grid velocity.
layout(set=0, binding=1) uniform texture3D gridImage;
// Second 3D grid texture; main() subtracts its xyz samples from those of
// gridImage to obtain a velocity change (presumably the grid state before the
// grid update - TODO confirm against the host code).
layout(set=0, binding=2) uniform texture3D gridOldImage;
// Sampler combined with both grid textures; texture coordinates are normalized.
layout(set=0, binding=3) uniform sampler gridSampler;

// Per-dispatch parameters supplied as push constants.
layout( push_constant ) uniform constants {
    // Material coefficient that scales the log(det F) stress term in main().
    float lame1;
    // Material coefficient that scales the (F * F^T - I) stress term in main().
    float lame2;
    // Blend factor passed to mix(velocity_pic, velocity_flip, alpha):
    // 0 selects the PIC velocity, 1 selects the FLIP velocity.
    float alpha;
    // Scale factor applied to both contributions accumulated into the mls matrix.
    float beta;
    // Not read by main() in this shader (presumably the simulation time).
    float t;
    // Time step used to advance the position and the deformation matrix.
    float dt;
};

// Particle update pass: one invocation per particle, indexed by
// gl_GlobalInvocationID.x. Invocations past the end of the buffer do nothing.
//
// For its particle the shader
//   1. samples gridImage / gridOldImage over a window of voxels around the
//      particle position and weights each sample with voxel_particle_weight(),
//   2. accumulates the affine matrices D and B and the PIC / FLIP velocity
//      estimates from those samples,
//   3. derives C = B * D^-1 and the matrix stored in the particle's `mls` field
//      (stress contribution plus beta * mass * C),
//   4. advances the deformation matrix F and the position by dt,
//   5. reflects the particle at the faces of the unit cube [0, 1]^3,
//   6. writes position, velocity, deformation and mls back to the buffer.
//
// The shader uses no shared memory and every invocation reads and writes only
// its own particle entry, so no execution barrier is needed between the steps.
void main() {
    memoryBarrierBuffer();
    memoryBarrierImage();

    // NOTE: barrier() must only be called from uniform control flow. This branch
    // is skipped by the tail invocations of the last workgroup whenever the
    // particle count is not a multiple of the workgroup size, so no barrier()
    // may appear inside it (doing so is undefined behaviour and can hang).
    if (gl_GlobalInvocationID.x < particles.length()) {
        Particle particle = particles[gl_GlobalInvocationID.x];

        vec3 position = particle.minimal.position;
        float size = particle.minimal.size;
        float volume = sphere_volume(size);
        float mass = particle.minimal.mass;

        // Grid resolution in voxels, and the half-extent (in voxels) of the
        // window visited around the particle: 2 * size scaled to voxel units.
        ivec3 gridResolution = textureSize(sampler3D(gridImage, gridSampler), 0);
        vec3 gridScale = vec3(gridResolution);
        ivec3 gridWindow = ivec3(size * 2.0f * gridScale);

        mat3 affine_D = mat3(0.0f);
        mat3 affine_B = mat3(0.0f);

        vec3 velocity_pic = vec3(0.0f);
        // FLIP starts from the particle's own velocity and adds the grid delta.
        vec3 velocity_flip = vec3(particle.minimal.velocity);

        for (int i = -gridWindow.x; i <= gridWindow.x; i++) {
            for (int j = -gridWindow.y; j <= gridWindow.y; j++) {
                for (int k = -gridWindow.z; k <= gridWindow.z; k++) {
                    // Offset from the particle in normalized texture coordinates.
                    vec3 offset = vec3(i, j, k) / gridScale;
                    vec3 voxel = position + offset;

                    vec4 gridSample = texture(sampler3D(gridImage, gridSampler), voxel);
                    vec4 gridOldSample = texture(sampler3D(gridOldImage, gridSampler), voxel);

                    float weight = voxel_particle_weight(voxel, particle.minimal);
                    vec3 weightedVelocity = gridSample.xyz * weight;

                    affine_D += outerProduct(weight * offset, offset);
                    affine_B += outerProduct(weightedVelocity, offset);

                    velocity_pic += weightedVelocity;
                    velocity_flip += (gridSample.xyz - gridOldSample.xyz) * weight;
                }
            }
        }

        // Memory barrier only (no execution barrier): legal in divergent code.
        memoryBarrierBuffer();

        mat3 mls_Q = mat3(0.0f);
        mat3 affine_C = mat3(0.0f);

        // The deformation matrix is stored as a mat4; only its upper-left 3x3 is used.
        mat3 F = mat3(particle.deformation);

        // Skip the affine terms entirely when D is singular and cannot be inverted.
        if (abs(determinant(affine_D)) > 0.0f) {
            mat3 D_inv = inverse(affine_D);
            float J = determinant(F);

            // log(J) is only defined for J > 0; massless particles carry no stress.
            if ((J > 0.0f) && (mass > 0.0f)) {
                mat3 F_T = transpose(F);

                // Stress of the form lame2 * (F F^T - I) + lame1 * log(J) * I.
                // mat3(x) builds x * identity. Adding the bare scalar to a matrix
                // would add it to every component, off-diagonals included, which
                // is not an isotropic (pressure-like) term.
                mat3 delta = lame2 * (F * F_T - mat3(1.0f)) + mat3(lame1 * log(J));

                mls_Q += beta * dt * volume * delta * D_inv;
            }

            affine_C = affine_B * D_inv;
            mls_Q += beta * affine_C * mass;
        }

        // First-order update of the deformation matrix: F <- (I + dt * C) * F.
        F = (mat3(1.0f) + dt * affine_C) * F;

        vec3 velocity = mix(velocity_pic, velocity_flip, alpha);

        position = position + velocity * dt;

        // Keep the particle sphere inside the unit cube: mirror the position at
        // the violated face and flip the matching velocity component.
        for (int axis = 0; axis < 3; axis++) {
            if (position[axis] - size < 0.0f) {
                position[axis] = -position[axis] + 2.0f * size;
                velocity[axis] *= -1.0f;
            } else if (position[axis] + size > 1.0f) {
                position[axis] = 2.0f * (1.0f - size) - position[axis];
                velocity[axis] *= -1.0f;
            }
        }

        // Memory barrier only (no execution barrier): legal in divergent code.
        memoryBarrierBuffer();

        particles[gl_GlobalInvocationID.x].minimal.position = position;
        particles[gl_GlobalInvocationID.x].minimal.velocity = velocity;
        particles[gl_GlobalInvocationID.x].deformation = mat4(F);
        particles[gl_GlobalInvocationID.x].mls = mat4(mls_Q);
    }

    // Reached by every invocation of the workgroup, so barrier() is valid here.
    barrier();
    memoryBarrierBuffer();
}