diff --git a/projects/saf_r/shaders/raytracing.comp b/projects/saf_r/shaders/raytracing.comp
index 1af6808f4009452205ec6a14f523a5e9d3772eeb..e43d6ed1a0317f851323b4f562906f38f4af93fd 100644
--- a/projects/saf_r/shaders/raytracing.comp
+++ b/projects/saf_r/shaders/raytracing.comp
@@ -1,7 +1,8 @@
 #version 450 core
 #extension GL_ARB_separate_shader_objects : enable
 
-#define M_PI 3.1415926535897932384626433832795
+const float pi      = 3.1415926535897932384626433832795;
+const float hitBias = 0.0001;   // used to offset hits to avoid self intersection
 
 layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
 
@@ -26,15 +27,11 @@ layout(std430, binding = 0) coherent buffer lights{
     Light inLights[];
 };
 
-layout(std430, binding = 1) coherent buffer materials{
-    Material inMaterials[];
-};
-
-layout(std430, binding = 2) coherent buffer spheres{
+layout(std430, binding = 1) coherent buffer spheres{
     Sphere inSpheres[];
 };
 
-layout(set=0, binding=3, rgba8) uniform image2D outImage;
+layout(set=0, binding = 2, rgba8) uniform image2D outImage;
 
 layout( push_constant ) uniform constants{
     float lightCount;
@@ -42,116 +39,141 @@ layout( push_constant ) uniform constants{
     float sphereCount;
 };
 
-
-vec3 safr_reflect(const vec3 dir, const vec3 hit_center) {
-    return dir - hit_center * 2.f * (dot(dir, hit_center));
-}
-
 bool ray_intersect(const vec3 origin, const vec3 dir, out float t0, const int id){
-        vec3 L = inSpheres[id].center - origin;
-        float tca = dot(L, dir);
-        float d2 = dot(L, L) - tca * tca;
-        if (d2 > inSpheres[id].radius * inSpheres[id].radius){
-            return false;
-        }
-        float thc = float(sqrt(inSpheres[id].radius * inSpheres[id].radius - d2));
-        t0 = tca - thc;
-        float t1 = tca + thc;
-        if (t0 < 0) {
-            t0 = t1;
-        }
-        if (t0 < 0){
-            return false;
-        }
-        return true;
+    vec3 L = inSpheres[id].center - origin;
+    float tca = dot(L, dir);
+    float d2 = dot(L, L) - tca * tca;
+    if (d2 > inSpheres[id].radius * inSpheres[id].radius){
+        return false;
+    }
+    float thc = float(sqrt(inSpheres[id].radius * inSpheres[id].radius - d2));
+    t0 = tca - thc;
+    float t1 = tca + thc;
+    if (t0 < 0) {
+        t0 = t1;
+    }
+    if (t0 < 0){
+        return false;
+    }
+    return true;
 }
 
-bool sceneIntersect(const vec3 orig, const vec3 dir, out vec3 hit, out vec3 hit_center, out Material material) {
-    float spheres_dist = 1.0 / 0.0;
+struct Intersection{
+    bool hit;
+    vec3 pos;
+    vec3 N;
+    Material material;
+};
+
+Intersection sceneIntersect(const vec3 rayOrigin, const vec3 rayDirection) {
+    float   min_d    = 100000;  // lets start with something big
+    
+    Intersection intersection;
+    intersection.hit = false;
+    
     for (int i = 0; i < sphereCount; i++) {
-        float dist_i;
-        if (ray_intersect(orig, dir, dist_i, i) && dist_i < spheres_dist) {
-            spheres_dist = dist_i;
-            hit = orig + dir * dist_i;
-            hit_center = normalize(hit - inSpheres[i].center);
-            material = inSpheres[i].material;
-            break;
+        float d;
+        if (ray_intersect(rayOrigin, rayDirection, d, i)) {
+            
+            intersection.hit = true;
+            
+            if(d < min_d){
+                min_d = d;
+                intersection.pos        = rayOrigin + rayDirection * d;
+                intersection.N          = normalize(intersection.pos - inSpheres[i].center);
+                intersection.material   = inSpheres[i].material;
+            }
         }
     }
-    return spheres_dist < 1000;
+    return intersection;
 }
 
+vec3 biasHitPosition(vec3 hitPos, vec3 rayDirection, vec3 N){
+    // return hitPos + N * hitBias; // works as long as no refraction/transmission is used and camera is outside sphere
+    return hitPos + sign(dot(rayDirection, N)) * N * hitBias;
+}
 
-vec3 castRay(const vec3 orig, const vec3 dir, int max_depth) {
+vec3 computeHitLighting(Intersection intersection, vec3 V, out float outReflectionThroughput){
+    
+    float lightIntensityDiffuse  = 0;
+    float lightIntensitySpecular = 0;
+
+    for (int i = 0; i < lightCount; i++) {
+        
+        vec3   L = normalize(inLights[i].position - intersection.pos);
+        float  d = distance(inLights[i].position, intersection.pos);
+
+        vec3 shadowOrigin = biasHitPosition(intersection.pos, L, intersection.N);
+        
+        Intersection shadowIntersection = sceneIntersect(shadowOrigin, L);
+        
+        bool isShadowed = false;
+        if(shadowIntersection.hit)
+            isShadowed = distance(shadowIntersection.pos, shadowOrigin) < d;
+		
+        if(isShadowed)
+            continue;        
+        
+        lightIntensityDiffuse  += inLights[i].intensity * max(0.f, dot(L, intersection.N));
+        lightIntensitySpecular += pow(max(0.f, dot(reflect(V, intersection.N), L)), intersection.material.specular_exponent) * inLights[i].intensity;
+    }
 
-    int depth = 0;
-    vec3 point, hit_center;
-    Material material;
-    vec3 result = vec3(0.2, 0.7, 0.8);
-    bool intersect;
-    vec3 direction = dir;
-    vec3 reflect_dir = direction;
-    vec3 reflect_orig = orig;
+    outReflectionThroughput = intersection.material.albedo[2];
+    return intersection.material.diffuse_color * lightIntensityDiffuse * intersection.material.albedo[0] + lightIntensitySpecular * intersection.material.albedo[1];
+}
+
+vec3 castRay(const vec3 initialOrigin, const vec3 initialDirection, int max_depth) {
+    
+    vec3 skyColor = vec3(0.2, 0.7, 0.8);
 
+    vec3 rayOrigin    = initialOrigin;
+    vec3 rayDirection = initialDirection;
+    
+    float   reflectionThroughput    = 1;
+    vec3    color                   = vec3(0);
+    
     for(int i = 0; i < max_depth; i++){
-        depth++;
-        intersect = sceneIntersect(reflect_orig, reflect_dir, point, hit_center, material);
-        if(!intersect){
-            break;
+
+        Intersection intersection = sceneIntersect(rayOrigin, rayDirection);
+        
+        vec3 hitColor;
+        float hitReflectionThroughput;
+        
+        if(intersection.hit){
+            hitColor = computeHitLighting(intersection, rayDirection, hitReflectionThroughput);
         }
-        //compute recursive directions and origins of rays and then call the function
-        reflect_dir = normalize(safr_reflect(direction, hit_center));
-        reflect_orig = (dot(reflect_dir, hit_center) < 0) ? point - hit_center * float(1e-3) : point + hit_center * float(1e-3);// offset the original point to avoid occlusion by the object itself
-        direction = reflect_dir;
-    }
-    if (depth == 1){
-        return result;
-    }
+        else{
+            hitColor = skyColor;
+        }
+        
+        color                   += hitColor * reflectionThroughput;
+        reflectionThroughput    *= hitReflectionThroughput;
+        
+        if(!intersection.hit)
+            break;
 
-    vec3 reflect_color = result;
-    for(int i = 0; i < depth; i++){
-
-         //compute shadows and other light properties for the returned ray color
-         float diffuse_light_intensity = 0, specular_light_intensity = 0;
-
-         for (int i = 0; i < lightCount; i++) {
-             vec3 light_dir = normalize(inLights[i].position - point);
-             float light_distance = distance(inLights[i].position, point);
-
-             vec3 shadow_orig = (dot(light_dir, hit_center) < 0) ? point - hit_center * float(1e-3) :
-             point + hit_center * float(1e-3);// checking if the point lies in the shadow of the lights[i]
-             vec3 shadow_pt, shadow_hit_center;
-             Material tmpmaterial;
-             if (sceneIntersect(shadow_orig, light_dir, shadow_pt, shadow_hit_center, tmpmaterial)
-             && distance(shadow_pt, shadow_orig) < light_distance){
-                 continue;
-             }
-             diffuse_light_intensity += inLights[i].intensity * max(0.f, dot(light_dir, hit_center));
-             specular_light_intensity += pow(max(0.f, dot(safr_reflect(light_dir, hit_center), dir)), material.specular_exponent) * inLights[i].intensity;
-         }
-         return result = material.diffuse_color * diffuse_light_intensity * material.albedo[0] +
-         vec3(1., 1., 1.) * specular_light_intensity * material.albedo[1] + reflect_color * material.albedo[2];
+        rayDirection    = normalize(reflect(rayDirection, intersection.N));
+        rayOrigin       = biasHitPosition(intersection.pos, rayDirection, intersection.N);
     }
-    return result;
+
+    return color;
 }
 
 vec3 computeDirection(ivec2 coord){
 
-    ivec2 outImageRes = imageSize(outImage);
-    float fov = M_PI / 2.f;
-    //float x = (2 * (i + 0.5f) / (float)width - 1) * tan(fov / 2.f) * width / (float)height;
-    float x = (2 * (float(coord.x) + 0.5f) / float(outImageRes.x) - 1) * tan(fov / 2.f) * outImageRes.x / float(outImageRes.y);
-    //float y = -(2 * (j + 0.5f) / (float)height - 1) * tan(fov / 2.f);
-    float y = -(2 * (float(coord.y) + 0.5f) / float(outImageRes.y) - 1) * tan(fov / 2.f);
-    vec3 dir = normalize(vec3(x, y, -1));
-    return dir;
+    ivec2 outImageRes   = imageSize(outImage);
+    float fov           = pi / 2.f;
+    float x             =  (2 * (float(coord.x) + 0.5f) / float(outImageRes.x) - 1) * tan(fov / 2.f) * outImageRes.x / float(outImageRes.y);
+    float y             = -(2 * (float(coord.y) + 0.5f) / float(outImageRes.y) - 1) * tan(fov / 2.f);
+    return normalize(vec3(x, y, -1));
 }
 
 
 void main(){
-    ivec2 coord = ivec2(gl_GlobalInvocationID.xy);
-    int max_depth = 4;
-    vec3 direction = computeDirection(coord);
-    vec3 color = castRay(vec3(0,0,0), direction, max_depth);
+    ivec2 coord     = ivec2(gl_GlobalInvocationID.xy);
+    int max_depth   = 4;
+    vec3 direction  = computeDirection(coord);
+    vec3 color      = castRay(vec3(0,0,0), direction, max_depth);
+    
     imageStore(outImage, coord, vec4(color, 0.f));
 }
\ No newline at end of file
diff --git a/projects/saf_r/src/main.cpp b/projects/saf_r/src/main.cpp
index 2ac638b4c373d7b69cb85bc4ab08e81a481746f2..27dfafbd5e90c7774fb6d5d23a55f357f0160b33 100644
--- a/projects/saf_r/src/main.cpp
+++ b/projects/saf_r/src/main.cpp
@@ -133,12 +133,6 @@ int main(int argc, const char** argv) {
 	);
 	sphereBuffer.fill(spheres);
 
-	vkcv::Buffer<safrScene::Material> materialBuffer = core.createBuffer<safrScene::Material>(
-		vkcv::BufferType::STORAGE,
-		materials.size()
-	);
-	materialBuffer.fill(materials);
-
 	glm::vec3 pushData = glm::vec3((lights.size()), (materials.size()), (spheres.size()));
 	vkcv::PushConstants pushConstantsCompute(sizeof(glm::vec3));
 	pushConstantsCompute.appendDrawcall(pushData);
@@ -150,8 +144,7 @@ int main(int argc, const char** argv) {
 
 	vkcv::DescriptorWrites computeWrites;
 	computeWrites.storageBufferWrites = { vkcv::BufferDescriptorWrite(0,lightsBuffer.getHandle()),
-                                          vkcv::BufferDescriptorWrite(1,materialBuffer.getHandle()),
-                                          vkcv::BufferDescriptorWrite(2,sphereBuffer.getHandle())};
+                                          vkcv::BufferDescriptorWrite(1,sphereBuffer.getHandle())};
     core.writeDescriptorSet(computeDescriptorSet, computeWrites);
 
 	const auto& context = core.getContext();
@@ -222,7 +215,6 @@ int main(int argc, const char** argv) {
 			continue;
 		}
 
-
 		auto end = std::chrono::system_clock::now();
 		auto deltatime = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
 		start = end;
@@ -236,7 +228,7 @@ int main(int argc, const char** argv) {
 
 		auto cmdStream = core.createCommandStream(vkcv::QueueType::Graphics);
 
-        computeWrites.storageImageWrites = { vkcv::StorageImageDescriptorWrite(3, swapchainInput)};
+        computeWrites.storageImageWrites = { vkcv::StorageImageDescriptorWrite(2, swapchainInput)};
         core.writeDescriptorSet(computeDescriptorSet, computeWrites);
 
         core.prepareImageForStorage (cmdStream, swapchainInput);
@@ -252,14 +244,6 @@ int main(int argc, const char** argv) {
 
 		core.recordBufferMemoryBarrier(cmdStream, lightsBuffer.getHandle());
 
-		/*core.recordDrawcallsToCmdStream(
-			cmdStream,
-			safrPass,
-			safrPipeline,
-			pushConstants,
-			{ drawcall },
-			{ swapchainInput });*/
-
 		core.prepareSwapchainImageForPresent(cmdStream);
 		core.submitCommandStream(cmdStream);