Skip to content
Snippets Groups Projects
Commit e6dee03f authored by Alex Laptop's avatar Alex Laptop
Browse files

Refactor shader

parent bdfb811c
No related branches found
No related tags found
1 merge request!101Path tracing
...@@ -59,88 +59,98 @@ layout( push_constant ) uniform constants{ ...@@ -59,88 +59,98 @@ layout( push_constant ) uniform constants{
int planeCount; int planeCount;
}; };
bool raySphereIntersect(const vec3 origin, const vec3 dir, out float t0, const int id){ struct Ray{
vec3 L = inSpheres[id].center - origin; vec3 origin;
float tca = dot(L, dir); vec3 direction;
float d2 = dot(L, L) - tca * tca; };
if (d2 > inSpheres[id].radius * inSpheres[id].radius){
return false; struct Intersection{
bool hit;
float distance;
vec3 pos;
vec3 N;
Material material;
};
Intersection raySphereIntersect(Ray ray, Sphere sphere){
Intersection intersection;
intersection.hit = false;
vec3 L = sphere.center - ray.origin;
float tca = dot(L, ray.direction);
float d2 = dot(L, L) - tca * tca;
if (d2 > sphere.radius * sphere.radius){
return intersection;
} }
float thc = float(sqrt(inSpheres[id].radius * inSpheres[id].radius - d2)); float thc = float(sqrt(sphere.radius * sphere.radius - d2));
t0 = tca - thc; float t0 = tca - thc;
float t1 = tca + thc; float t1 = tca + thc;
if (t0 < 0) { if (t0 < 0) {
t0 = t1; t0 = t1;
} }
if (t0 < 0){ if (t0 < 0){
return false; return intersection;
} }
return true;
intersection.hit = true;
intersection.distance = t0;
intersection.pos = ray.origin + ray.direction * intersection.distance;
intersection.N = normalize(intersection.pos - sphere.center);
intersection.material = inMaterials[sphere.materialIndex];
return intersection;
} }
// see: https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-plane-and-ray-disk-intersection // see: https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-plane-and-ray-disk-intersection
bool rayPlaneIntersect(const vec3 origin, vec3 direction, out float intersectionDistance, const int id){ Intersection rayPlaneIntersect(Ray ray, Plane plane){
Plane plane = inPlanes[id];
vec3 toPlane = plane.center - origin; Intersection intersection;
float denom = dot(direction, plane.N); intersection.hit = false;
vec3 toPlane = plane.center - ray.origin;
float denom = dot(ray.direction, plane.N);
if(abs(denom) < 0.001) if(abs(denom) < 0.001)
return false; return intersection;
intersectionDistance = dot(toPlane, plane.N) / denom; intersection.distance = dot(toPlane, plane.N) / denom;
if(intersectionDistance < 0) if(intersection.distance < 0)
return false; return intersection;
vec3 intersection = origin + direction * intersectionDistance; intersection.pos = ray.origin + ray.direction * intersection.distance;
vec3 right = abs(plane.N.x) < 0.99 ? vec3(1, 0, 0) : vec3(0, 0, 1); vec3 right = abs(plane.N.x) < 0.99 ? vec3(1, 0, 0) : vec3(0, 0, 1);
vec3 up = normalize(cross(plane.N, right)); vec3 up = normalize(cross(plane.N, right));
right = cross(plane.N, up); right = cross(plane.N, up);
vec3 centerToIntersection = intersection - plane.center; vec3 centerToIntersection = intersection.pos - plane.center;
float projectedRight = dot(centerToIntersection, right); float projectedRight = dot(centerToIntersection, right);
float projectedUp = dot(centerToIntersection, up); float projectedUp = dot(centerToIntersection, up);
return abs(projectedRight) <= plane.extent.x && abs(projectedUp) <= plane.extent.y; intersection.hit = abs(projectedRight) <= plane.extent.x && abs(projectedUp) <= plane.extent.y;
intersection.N = plane.N;
intersection.material = inMaterials[plane.materialIndex];
return intersection;
} }
struct Intersection{ Intersection sceneIntersect(Ray ray) {
bool hit; float minDistance = 100000; // lets start with something big
vec3 pos;
vec3 N;
Material material;
};
Intersection sceneIntersect(const vec3 rayOrigin, const vec3 rayDirection) {
float min_d = 100000; // lets start with something big
Intersection intersection; Intersection intersection;
intersection.hit = false; intersection.hit = false;
for (int i = 0; i < sphereCount; i++) { for (int i = 0; i < sphereCount; i++) {
float d; Intersection sphereIntersection = raySphereIntersect(ray, inSpheres[i]);
if (raySphereIntersect(rayOrigin, rayDirection, d, i)) { if (sphereIntersection.hit && sphereIntersection.distance < minDistance) {
intersection = sphereIntersection;
intersection.hit = true; minDistance = intersection.distance;
if(d < min_d){
min_d = d;
intersection.pos = rayOrigin + rayDirection * d;
intersection.N = normalize(intersection.pos - inSpheres[i].center);
intersection.material = inMaterials[inSpheres[i].materialIndex];
}
} }
} }
for (int i = 0; i < planeCount; i++){ for (int i = 0; i < planeCount; i++){
float d; Intersection planeIntersection = rayPlaneIntersect(ray, inPlanes[i]);
if (rayPlaneIntersect(rayOrigin, rayDirection, d, i)) { if (planeIntersection.hit && planeIntersection.distance < minDistance) {
intersection = planeIntersection;
intersection.hit = true; minDistance = intersection.distance;
if(d < min_d){
min_d = d;
intersection.pos = rayOrigin + rayDirection * d;
intersection.N = inPlanes[i].N;
intersection.material = inMaterials[inPlanes[i].materialIndex];
}
} }
} }
return intersection; return intersection;
...@@ -160,14 +170,16 @@ vec3 computeHitLighting(Intersection intersection, vec3 V, out float outReflecti ...@@ -160,14 +170,16 @@ vec3 computeHitLighting(Intersection intersection, vec3 V, out float outReflecti
vec3 L = normalize(inLights[i].position - intersection.pos); vec3 L = normalize(inLights[i].position - intersection.pos);
float d = distance(inLights[i].position, intersection.pos); float d = distance(inLights[i].position, intersection.pos);
vec3 shadowOrigin = biasHitPosition(intersection.pos, L, intersection.N); Ray shadowRay;
shadowRay.origin = biasHitPosition(intersection.pos, L, intersection.N);
shadowRay.direction = L;
Intersection shadowIntersection = sceneIntersect(shadowOrigin, L); Intersection shadowIntersection = sceneIntersect(shadowRay);
bool isShadowed = false; bool isShadowed = false;
if(shadowIntersection.hit) if(shadowIntersection.hit)
isShadowed = distance(shadowIntersection.pos, shadowOrigin) < d; isShadowed = distance(shadowIntersection.pos, shadowRay.origin) < d;
if(isShadowed) if(isShadowed)
continue; continue;
...@@ -180,24 +192,22 @@ vec3 computeHitLighting(Intersection intersection, vec3 V, out float outReflecti ...@@ -180,24 +192,22 @@ vec3 computeHitLighting(Intersection intersection, vec3 V, out float outReflecti
return intersection.material.diffuse_color * lightIntensityDiffuse * intersection.material.albedo[0] + lightIntensitySpecular * intersection.material.albedo[1]; return intersection.material.diffuse_color * lightIntensityDiffuse * intersection.material.albedo[0] + lightIntensitySpecular * intersection.material.albedo[1];
} }
vec3 castRay(const vec3 initialOrigin, const vec3 initialDirection, int max_depth) { vec3 castRay(Ray ray, int max_depth) {
vec3 skyColor = vec3(0.2, 0.7, 0.8); vec3 skyColor = vec3(0.2, 0.7, 0.8);
vec3 rayOrigin = initialOrigin;
vec3 rayDirection = initialDirection;
float reflectionThroughput = 1; float reflectionThroughput = 1;
vec3 color = vec3(0); vec3 color = vec3(0);
for(int i = 0; i < max_depth; i++){ for(int i = 0; i < max_depth; i++){
Intersection intersection = sceneIntersect(rayOrigin, rayDirection); Intersection intersection = sceneIntersect(ray);
vec3 hitColor; vec3 hitColor;
float hitReflectionThroughput; float hitReflectionThroughput;
if(intersection.hit){ if(intersection.hit){
hitColor = computeHitLighting(intersection, rayDirection, hitReflectionThroughput); hitColor = computeHitLighting(intersection, ray.direction, hitReflectionThroughput);
} }
else{ else{
hitColor = skyColor; hitColor = skyColor;
...@@ -209,8 +219,8 @@ vec3 castRay(const vec3 initialOrigin, const vec3 initialDirection, int max_dept ...@@ -209,8 +219,8 @@ vec3 castRay(const vec3 initialOrigin, const vec3 initialDirection, int max_dept
if(!intersection.hit) if(!intersection.hit)
break; break;
rayDirection = normalize(reflect(rayDirection, intersection.N)); ray.direction = normalize(reflect(ray.direction, intersection.N));
rayOrigin = biasHitPosition(intersection.pos, rayDirection, intersection.N); ray.origin = biasHitPosition(intersection.pos, ray.direction, intersection.N);
} }
return color; return color;
...@@ -240,9 +250,10 @@ vec3 computeDirection(ivec2 coord){ ...@@ -240,9 +250,10 @@ vec3 computeDirection(ivec2 coord){
void main(){ void main(){
ivec2 coord = ivec2(gl_GlobalInvocationID.xy); ivec2 coord = ivec2(gl_GlobalInvocationID.xy);
int max_depth = 4; int max_depth = 4;
vec3 direction = computeDirection(coord); Ray cameraRay;
vec3 cameraPos = viewToWorld[3].xyz; cameraRay.direction = computeDirection(coord);
vec3 color = castRay(cameraPos, direction, max_depth); cameraRay.origin = viewToWorld[3].xyz;
vec3 color = castRay(cameraRay, max_depth);
imageStore(outImage, coord, vec4(color, 0.f)); imageStore(outImage, coord, vec4(color, 0.f));
} }
\ No newline at end of file
...@@ -85,7 +85,7 @@ int main(int argc, const char** argv) { ...@@ -85,7 +85,7 @@ int main(int argc, const char** argv) {
const uint32_t greenIndex = 2; const uint32_t greenIndex = 2;
std::vector<temp::Sphere> spheres; std::vector<temp::Sphere> spheres;
spheres.emplace_back(temp::Sphere(glm::vec3(0, -1, 0), 1, whiteIndex)); spheres.emplace_back(temp::Sphere(glm::vec3(0, -1.5, 0), 0.5, whiteIndex));
std::vector<temp::Plane> planes; std::vector<temp::Plane> planes;
planes.emplace_back(temp::Plane(glm::vec3( 0, -2, 0), glm::vec3( 0, 1, 0), glm::vec2(2), whiteIndex)); planes.emplace_back(temp::Plane(glm::vec3( 0, -2, 0), glm::vec3( 0, 1, 0), glm::vec2(2), whiteIndex));
...@@ -179,7 +179,7 @@ int main(int argc, const char** argv) { ...@@ -179,7 +179,7 @@ int main(int argc, const char** argv) {
raytracingPushData.viewToWorld = glm::inverse(cameraManager.getActiveCamera().getView()); raytracingPushData.viewToWorld = glm::inverse(cameraManager.getActiveCamera().getView());
raytracingPushData.lightCount = lights.size(); raytracingPushData.lightCount = lights.size();
raytracingPushData.sphereCount = spheres.size(); raytracingPushData.sphereCount = spheres.size();
raytracingPushData.planeCount = planes.size(); raytracingPushData.planeCount = planes.size();
vkcv::PushConstants pushConstantsCompute(sizeof(RaytracingPushConstantData)); vkcv::PushConstants pushConstantsCompute(sizeof(RaytracingPushConstantData));
pushConstantsCompute.appendDrawcall(raytracingPushData); pushConstantsCompute.appendDrawcall(raytracingPushData);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment