diff --git a/projects/rtx_ambient_occlusion/resources/shaders/ambientOcclusion.rgen b/projects/rtx_ambient_occlusion/resources/shaders/ambientOcclusion.rgen
index ca74919da8e4b8f49d34f1861c0923c8af03f47f..711070fcf1eec18253f331cbd133330791fa6be6 100644
--- a/projects/rtx_ambient_occlusion/resources/shaders/ambientOcclusion.rgen
+++ b/projects/rtx_ambient_occlusion/resources/shaders/ambientOcclusion.rgen
@@ -3,8 +3,6 @@
 
 #define M_PI 3.1415926535897932384626433832795
 
-// TODO: credits!!!
-
 // A location for a ray payload (we can have multiple of these)
 layout(location = 0) rayPayloadEXT Payload {
   float hitSky;
@@ -22,44 +20,61 @@ layout( push_constant ) uniform constants {
     vec4 camera_forward;    // for computing ray direction
 } camera;
 
-uint rngState = gl_LaunchIDEXT.x * 2000 + gl_LaunchIDEXT.y;     // each shader call has its own rngState
+// random() and helpers from: https://www.shadertoy.com/view/XlycWh
+float g_seed = 0;
+
+uint base_hash(uvec2 p) {
+    p = 1103515245U*((p >> 1U)^(p.yx));
+    uint h32 = 1103515245U*((p.x)^(p.y>>3U));
+    return h32^(h32 >> 16);
+}
+
+vec2 hash2(inout float seed) {
+    uint n = base_hash(floatBitsToUint(vec2(seed+=.1,seed+=.1)));
+    uvec2 rz = uvec2(n, n*48271U);
+    return vec2(rz.xy & uvec2(0x7fffffffU))/float(0x7fffffff);
+}
+
+void initRandom(uvec2 coord){
+	g_seed = float(base_hash(coord)/float(0xffffffffU));
+}
 
-float random(vec2 uv, float seed) {
-  return fract(sin(mod(dot(uv, vec2(12.9898, 78.233)) + 1113.1 * seed, M_PI)) * 43758.5453);;
+vec2 random(){
+	return hash2(g_seed);
 }
 
 /**
- * Retrieves pixel information.
+ * Traces the ray from the camera and provides the intersection information.
  * @param[in,out] hitSky Defines if the ray has hit the sky
  * @param[in,out] pos The position of intersection
  * @param[in,out] norm The normal at the position of intersection
  */
-void GetPixelInfo(out bool hitSky, out vec3 pos, out vec3 norm){
+void TraceCameraRay(out bool hitSky, out vec3 pos, out vec3 norm){
   // Use a camera model to generate a ray for this pixel.
-  vec2 uv = gl_LaunchIDEXT.xy + vec2(random(gl_LaunchIDEXT.xy, 0), random(gl_LaunchIDEXT.xy, 1));
+  vec2 uv = gl_LaunchIDEXT.xy + vec2(random()); // random breaks up aliasing
   uv /= vec2(gl_LaunchSizeEXT.xy);
   uv = (uv * 2.0 - 1.0) // normalize uv coordinates into Vulkan viewport space
     * vec2(1.0, -1.0);  // flips y-axis
-  const vec3 orig              = camera.camera_position.xyz;
-  const vec3 dir               = normalize(uv.x * camera.camera_right + uv.y * camera.camera_up + camera.camera_forward).xyz;
+  const vec3 orig   = camera.camera_position.xyz;
+  const vec3 dir    = normalize(uv.x * camera.camera_right + uv.y * camera.camera_up + camera.camera_forward).xyz;
 
   // Trace a ray into the scene; get back data in the payload.
   traceRayEXT(tlas,  // Acceleration structure
-              gl_RayFlagsOpaqueEXT,   // Ray flags, here saying "ignore intersection shaders"
-              0xFF,   // 8-bit instance mask, here saying "trace against all instances"
-              0,      // SBT record offset
-              0,      // SBT record stride for offset
-              0,      // Miss index
-              orig,   // Ray origin
-              0.0,    // Minimum t-value
-              dir,    // Ray direction
-              1000.0, // Maximum t-value
-              0);     // Location of payload
+              gl_RayFlagsOpaqueEXT, // Ray flags, here saying "ignore intersection shaders"
+              0xFF,                 // 8-bit instance mask, here saying "trace against all instances"
+              0,                    // SBT record offset
+              0,                    // SBT record stride for offset
+              0,                    // Miss index
+              orig,                 // Ray origin
+              0.0,                  // Minimum t-value
+              dir,                  // Ray direction
+              1000.0,               // Maximum t-value
+              0);                   // Location of payload
 
   // Read the values from the payload:
-  hitSky = (payload.hitSky > 0.0);
-  pos  = payload.worldPosition;
-  norm = payload.worldNormal;
+  hitSky    = (payload.hitSky > 0.0);
+  pos       = payload.worldPosition;
+  norm      = payload.worldNormal;
 }
 
 /**
@@ -67,7 +82,7 @@ void GetPixelInfo(out bool hitSky, out vec3 pos, out vec3 norm){
  * @param[in] orig The point of origin of the shadow ray.
  * @param[in] dir The direction of the shadow ray.
  */
-float ShadowRay(vec3 orig, vec3 dir){
+float CastShadowRay(vec3 orig, vec3 dir){
   payload.hitSky = 0.0f;  // Assume ray is occluded
   traceRayEXT(tlas,   // Acceleration structure
               gl_RayFlagsOpaqueEXT | gl_RayFlagsSkipClosestHitShaderEXT | gl_RayFlagsTerminateOnFirstHitEXT, // Ray flags, here saying "ignore any hit shaders and closest hit shaders, and terminate the ray on the first found intersection"
@@ -76,85 +91,68 @@ float ShadowRay(vec3 orig, vec3 dir){
               0,       // SBT record stride for offset
               0,       // Miss index
               orig,    // Ray origin
-              0.0,     // Minimum t-value
+              0.0001,  // Minimum t-value - avoid self intersection
               dir,     // Ray direction
               1000.0,  // Maximum t-value
               0);      // Location of payload
   return payload.hitSky;
 }
 
-/**
- * @brief Computes the offset position at @p worldPosition and its @p normal to avoid self-intersection.
- * @param[in] worldPosition The point of intersection.
- * @param[in] normal The normal at the point of intersection.
- */
-vec3 OffsetPositionAlongNormal(vec3 worldPosition, vec3 normal){
-  // Convert the normal to an integer offset.
-  const float int_scale = 256.0f;
-  const ivec3 of_i      = ivec3(int_scale * normal);
-
-  // Offset each component of worldPosition using its binary representation.
-  // Handle the sign bits correctly.
-  const vec3 p_i = vec3(  //
-      intBitsToFloat(floatBitsToInt(worldPosition.x) + ((worldPosition.x < 0) ? -of_i.x : of_i.x)),
-      intBitsToFloat(floatBitsToInt(worldPosition.y) + ((worldPosition.y < 0) ? -of_i.y : of_i.y)),
-      intBitsToFloat(floatBitsToInt(worldPosition.z) + ((worldPosition.z < 0) ? -of_i.z : of_i.z)));
-
-  // Use a floating-point offset instead for points near (0,0,0), the origin.
-  const float origin     = 1.0f / 32.0f;
-  const float floatScale = 1.0f / 65536.0f;
-  return vec3(  //
-      abs(worldPosition.x) < origin ? worldPosition.x + floatScale * normal.x : p_i.x,
-      abs(worldPosition.y) < origin ? worldPosition.y + floatScale * normal.y : p_i.y,
-      abs(worldPosition.z) < origin ? worldPosition.z + floatScale * normal.z : p_i.z);
+vec3 sampleCosineDistribution(vec2 xi){
+	float phi = 2 * M_PI * xi.y;
+	return vec3(
+		sqrt(xi.x) * cos(phi),
+		sqrt(1 - xi.x),
+		sqrt(xi.x) * sin(phi));
 }
 
-/**
- * @brief Used for creating random float numbers.
- */
-float StepAndOutputRNGFloat(){
-  // Condensed version of pcg_output_rxs_m_xs_32_32, with simple conversion to floating-point [0,1].
-  rngState  = rngState * 747796405 + 1;
-  uint word = ((rngState >> ((rngState >> 28) + 4)) ^ rngState) * 277803737;
-  word      = (word >> 22) ^ word;
-  return float(word) / 4294967295.0f;
+struct Basis{
+	vec3 right;
+	vec3 up;
+	vec3 forward;
+};
+
+Basis buildBasisAroundNormal(vec3 N){
+	Basis 	basis;
+	basis.up 		= N;
+	basis.right 	= abs(basis.up.x) < 0.99 ?  vec3(1, 0, 0) : vec3(0, 0, 1);
+	basis.forward 	= normalize(cross(basis.up, basis.right));
+	basis.right 	= cross(basis.up, basis.forward);
+	return basis;
 }
 
-
-/**
- * @brief Gets a randomly chosen cosine-weighted direction within the unit hemisphere defined by @p norm.
- * @param[in] norm The surface normal.
- */
-vec3 GetRandCosDir(vec3 norm){
-  // To generate a cosine-weighted normal, generate a random point on a sphere:
-  float theta      = 6.2831853 * StepAndOutputRNGFloat();  // Random in [0, 2pi]
-  float z          = 2 * StepAndOutputRNGFloat() - 1.0;    // Random in [-1, 1]
-  float r          = sqrt(1.0 - z * z);
-  vec3  ptOnSphere = vec3(r * cos(theta), r * sin(theta), z);
-  // Then add the normal to it and normalize to make it cosine-weighted on a hemisphere:
-  return normalize(ptOnSphere + norm);
+vec3 sampleTangentToWorldSpace(vec3 tangentSpaceSample, vec3 N){
+	Basis tangentBasis = buildBasisAroundNormal(N);
+	return
+		tangentBasis.right		* tangentSpaceSample.x +
+		tangentBasis.up			* tangentSpaceSample.y +
+		tangentBasis.forward 	* tangentSpaceSample.z;
 }
 
-
 void main(){
-     uint rays = 64;    // the amount of rays to be casted
+    uint rayCount = 64;    // the amount of rays to be casted
+
+    initRandom(gl_LaunchIDEXT.xy);
 
-     uvec2 pixel = gl_LaunchIDEXT.xy;
-     bool pixelIsSky; // Does the pixel show the sky (not an object)?
-     vec3 pos, norm;  // AO rays from where?
-     GetPixelInfo(pixelIsSky, pos, norm);
-     if(pixelIsSky){
+    uvec2 pixel = gl_LaunchIDEXT.xy;
+    bool pixelIsSky; // Does the pixel show the sky (not an object)?
+    vec3 pos, norm;  // AO rays from where?
+    TraceCameraRay(pixelIsSky, pos, norm);
+    
+    if(pixelIsSky){
         // Don't compute ambient occlusion for the sky
         imageStore(outImg, ivec2(pixel), vec4(0.8,0.8,0.8,1.0));
         return;
-     }
-
-      // Compute ambient occlusion
-     pos = OffsetPositionAlongNormal(pos, norm); // Avoid self-intersection
-     float aoColor = 0.0;
-     for(uint i = 0; i < rays; i++){
-        aoColor += ShadowRay(pos, GetRandCosDir(norm)) / rays;
-     }
-     vec4 aoColorVec = vec4(aoColor);
-     imageStore(outImg, ivec2(pixel), aoColorVec);
+    }
+
+    // Compute ambient occlusion
+    float aoValue = 0.0;
+    for(uint i = 0; i < rayCount; i++){
+        vec3 sampleTangentSpace = sampleCosineDistribution(random());
+        vec3 sampleWorldSpace   = sampleTangentToWorldSpace(sampleTangentSpace, norm);
+        aoValue                 += CastShadowRay(pos, sampleWorldSpace);
+    }
+    aoValue /= rayCount;
+    
+    imageStore(outImg, ivec2(pixel), vec4(vec3(aoValue), 1));
 }