From 8a6af65b1527eb4a5291f1b4736de2b34528beb7 Mon Sep 17 00:00:00 2001
From: Tobias Frisch <tfrisch@uni-koblenz.de>
Date: Tue, 29 Nov 2022 14:51:30 +0100
Subject: [PATCH] Add stride to vertex buffer bindings and fix issues with
 non-interleaved usage in acceleration structures

Signed-off-by: Tobias Frisch <tfrisch@uni-koblenz.de>
---
 include/vkcv/Buffer.hpp                       | 18 ++++++++++++++++
 include/vkcv/GeometryData.hpp                 | 11 ++++++++--
 include/vkcv/VertexData.hpp                   | 18 +++++++++++++---
 .../src/vkcv/asset/asset_loader.cpp           |  6 +++++-
 .../include/vkcv/geometry/Geometry.hpp        |  3 ++-
 modules/geometry/src/vkcv/geometry/Cuboid.cpp |  8 +++----
 .../geometry/src/vkcv/geometry/Cylinder.cpp   |  8 +++----
 .../geometry/src/vkcv/geometry/Geometry.cpp   | 20 +++++++++++++++---
 modules/geometry/src/vkcv/geometry/Sphere.cpp |  8 +++----
 modules/geometry/src/vkcv/geometry/Teapot.cpp |  8 +++----
 modules/scene/src/vkcv/scene/MeshPart.cpp     | 13 ++----------
 projects/fire_works/src/main.cpp              |  4 ++--
 projects/indirect_draw/src/main.cpp           |  2 +-
 projects/mpm/src/main.cpp                     |  4 ++--
 projects/particle_simulation/src/main.cpp     |  2 +-
 .../resources/shaders/ambientOcclusion.rchit  |  8 +++++++
 .../resources/shaders/ambientOcclusion.rgen   | 20 +++++++++++++++---
 .../resources/shaders/ambientOcclusion.rmiss  |  1 +
 projects/rt_ambient_occlusion/src/main.cpp    |  4 ++--
 projects/sph/src/main.cpp                     |  2 +-
 src/vkcv/AccelerationStructureManager.cpp     | 12 ++++-------
 src/vkcv/Core.cpp                             | 21 +++++++++++++------
 src/vkcv/GeometryData.cpp                     | 12 +++++++----
 src/vkcv/VertexData.cpp                       |  9 +++++---
 24 files changed, 152 insertions(+), 70 deletions(-)

diff --git a/include/vkcv/Buffer.hpp b/include/vkcv/Buffer.hpp
index 37bab6b8..a7dba97e 100644
--- a/include/vkcv/Buffer.hpp
+++ b/include/vkcv/Buffer.hpp
@@ -54,6 +54,19 @@ namespace vkcv {
 		[[nodiscard]] BufferMemoryType getMemoryType() const {
 			return m_core->getBufferMemoryType(m_handle);
 		}
+		
+		/**
+		 * @brief Returns the stride of elements in the buffer.
+		 *
+		 * Beware that this returned value is only the correct
+		 * stride for this buffer if it is used tightly packed
+		 * storing elements of type T.
+		 *
+		 * @return The likely stride of the #Buffer using type T
+		 */
+		[[nodiscard]] size_t getStride() const {
+			return sizeof(T);
+		}
 
 		/**
 		 * @brief Returns the count of elements in the buffer.
@@ -168,5 +181,10 @@ namespace vkcv {
 		return Buffer<T>(&core,
 						 core.createBuffer(type, typeGuard<T>(), count, memoryType, readable));
 	}
+	
+	template <typename T>
+	VertexBufferBinding vertexBufferBinding(const Buffer<T>& buffer) {
+		return vertexBufferBinding(buffer.getHandle(), buffer.getStride());
+	}
 
 } // namespace vkcv
diff --git a/include/vkcv/GeometryData.hpp b/include/vkcv/GeometryData.hpp
index 0dcea5be..bc8ab4eb 100644
--- a/include/vkcv/GeometryData.hpp
+++ b/include/vkcv/GeometryData.hpp
@@ -23,7 +23,7 @@ namespace vkcv {
 	class GeometryData {
 	private:
 		VertexBufferBinding m_vertexBinding;
-		uint32_t m_vertexStride;
+		size_t m_maxVertexIndex;
 		GeometryVertexType m_vertexType;
 		BufferHandle m_indices;
 		IndexBitCount m_indexBitCount;
@@ -44,7 +44,7 @@ namespace vkcv {
 		 * @param[in] geometryVertexType Geometry vertex type
 		 */
 		explicit GeometryData(const VertexBufferBinding &binding,
-							  uint32_t stride = sizeof(float) * 3,
+							  size_t maxVertexIndex = 0,
 							  GeometryVertexType geometryVertexType =
 									  GeometryVertexType::POSITION_FLOAT3);
 		
@@ -77,6 +77,13 @@ namespace vkcv {
 		 */
 		[[nodiscard]] uint32_t getVertexStride() const;
 		
+		/**
+		 * @brief Return the maximal index from vertex elements of the geometry data.
+		 *
+		 * @return Maximal vertex index
+		 */
+		[[nodiscard]] size_t getMaxVertexIndex() const;
+		
 		/**
 		 * @brief Return the geometry vertex type of the geometry data.
 		 *
diff --git a/include/vkcv/VertexData.hpp b/include/vkcv/VertexData.hpp
index 845b9fba..a9928a4d 100644
--- a/include/vkcv/VertexData.hpp
+++ b/include/vkcv/VertexData.hpp
@@ -15,11 +15,23 @@ namespace vkcv {
 	 * @brief Structure to store details about a vertex buffer binding.
 	 */
 	struct VertexBufferBinding {
-		BufferHandle buffer;
-		size_t offset;
+		BufferHandle m_buffer;
+		size_t m_stride;
+		size_t m_offset;
 	};
 
-	VertexBufferBinding vertexBufferBinding(const BufferHandle &buffer, size_t offset = 0);
+	/**
+	 * Create a vertex buffer binding using a given buffer handle and
+	 * its stride in bytes.
+	 *
+	 * @param[in] buffer Vertex buffer
+	 * @param[in] stride Stride in bytes
+	 * @param[in] offset (Optional) Offset in bytes
+	 * @return Vertex buffer binding
+	 */
+	VertexBufferBinding vertexBufferBinding(const BufferHandle &buffer,
+											size_t stride,
+											size_t offset = 0);
 
 	typedef std::vector<VertexBufferBinding> VertexBufferBindings;
 
diff --git a/modules/asset_loader/src/vkcv/asset/asset_loader.cpp b/modules/asset_loader/src/vkcv/asset/asset_loader.cpp
index ef83785e..f924ea73 100644
--- a/modules/asset_loader/src/vkcv/asset/asset_loader.cpp
+++ b/modules/asset_loader/src/vkcv/asset/asset_loader.cpp
@@ -885,7 +885,11 @@ namespace vkcv::asset {
 				break;
 			}
 			
-			bindings.push_back(vkcv::vertexBufferBinding(buffer, attribute->offset));
+			bindings.push_back(vkcv::vertexBufferBinding(
+					buffer,
+					attribute->stride,
+					attribute->offset
+			));
 		}
 		
 		return bindings;
diff --git a/modules/geometry/include/vkcv/geometry/Geometry.hpp b/modules/geometry/include/vkcv/geometry/Geometry.hpp
index ea006b00..b7be4eb8 100644
--- a/modules/geometry/include/vkcv/geometry/Geometry.hpp
+++ b/modules/geometry/include/vkcv/geometry/Geometry.hpp
@@ -116,11 +116,12 @@ namespace vkcv::geometry {
 		 * generated vertex data, which can be used for
 		 * building bottom level acceleration structures.
 		 *
+		 * @param[in,out] core Core instance
 		 * @param[in, out] vertexData Vertex data with generated geometry
 		 * @return Geometry data from generated vertex data
 		 */
 		[[nodiscard]]
-		virtual GeometryData extractGeometryData(const VertexData &vertexData) const;
+		virtual GeometryData extractGeometryData(Core& core, const VertexData &vertexData) const;
 		
 	};
 	
diff --git a/modules/geometry/src/vkcv/geometry/Cuboid.cpp b/modules/geometry/src/vkcv/geometry/Cuboid.cpp
index 73a701b2..a5b62525 100644
--- a/modules/geometry/src/vkcv/geometry/Cuboid.cpp
+++ b/modules/geometry/src/vkcv/geometry/Cuboid.cpp
@@ -219,10 +219,10 @@ namespace vkcv::geometry {
 		tangentBuffer.fill(cuboidTangents);
 		
 		VertexData data ({
-			vkcv::vertexBufferBinding(positionBuffer.getHandle()),
-			vkcv::vertexBufferBinding(normalBuffer.getHandle()),
-			vkcv::vertexBufferBinding(uvBuffer.getHandle()),
-			vkcv::vertexBufferBinding(tangentBuffer.getHandle())
+			vkcv::vertexBufferBinding(positionBuffer.getHandle(), sizeof(float) * 3),
+			vkcv::vertexBufferBinding(normalBuffer.getHandle(), sizeof(float) * 3),
+			vkcv::vertexBufferBinding(uvBuffer.getHandle(), sizeof(float) * 2),
+			vkcv::vertexBufferBinding(tangentBuffer)
 		});
 		
 		const auto& featureManager = core.getContext().getFeatureManager();
diff --git a/modules/geometry/src/vkcv/geometry/Cylinder.cpp b/modules/geometry/src/vkcv/geometry/Cylinder.cpp
index 0583c4c1..57052199 100644
--- a/modules/geometry/src/vkcv/geometry/Cylinder.cpp
+++ b/modules/geometry/src/vkcv/geometry/Cylinder.cpp
@@ -177,10 +177,10 @@ namespace vkcv::geometry {
 		indexBuffer.fill(cylinderIndices);
 		
 		VertexData data ({
-			vkcv::vertexBufferBinding(positionBuffer.getHandle()),
-			vkcv::vertexBufferBinding(normalBuffer.getHandle()),
-			vkcv::vertexBufferBinding(uvBuffer.getHandle()),
-			vkcv::vertexBufferBinding(tangentBuffer.getHandle())
+			vkcv::vertexBufferBinding(positionBuffer),
+			vkcv::vertexBufferBinding(normalBuffer),
+			vkcv::vertexBufferBinding(uvBuffer),
+			vkcv::vertexBufferBinding(tangentBuffer)
 		});
 		
 		data.setIndexBuffer(indexBuffer.getHandle(), IndexBitCount::Bit32);
diff --git a/modules/geometry/src/vkcv/geometry/Geometry.cpp b/modules/geometry/src/vkcv/geometry/Geometry.cpp
index ec0d56e0..82e1b6b3 100644
--- a/modules/geometry/src/vkcv/geometry/Geometry.cpp
+++ b/modules/geometry/src/vkcv/geometry/Geometry.cpp
@@ -28,10 +28,24 @@ namespace vkcv::geometry {
 		));
 	}
 	
-	GeometryData Geometry::extractGeometryData(const vkcv::VertexData &vertexData) const {
+	GeometryData Geometry::extractGeometryData(Core& core,
+											   const vkcv::VertexData &vertexData) const {
+		const VertexBufferBinding positionBufferBinding = vertexData.getVertexBufferBindings()[0];
+		const size_t bufferSize = core.getBufferSize(positionBufferBinding.m_buffer);
+		
+		if (positionBufferBinding.m_stride < sizeof(float) * 3) {
+			return {};
+		}
+		
+		const size_t vertexCount = (bufferSize / positionBufferBinding.m_stride);
+		
+		if (vertexCount < 3) {
+			return {};
+		}
+		
 		GeometryData data (
-				vertexData.getVertexBufferBindings()[0],
-				sizeof(glm::vec3),
+				positionBufferBinding,
+				vertexCount - 1,
 				GeometryVertexType::POSITION_FLOAT3
 		);
 		
diff --git a/modules/geometry/src/vkcv/geometry/Sphere.cpp b/modules/geometry/src/vkcv/geometry/Sphere.cpp
index d516a1da..f6bd894a 100644
--- a/modules/geometry/src/vkcv/geometry/Sphere.cpp
+++ b/modules/geometry/src/vkcv/geometry/Sphere.cpp
@@ -137,10 +137,10 @@ namespace vkcv::geometry {
 		indexBuffer.fill(sphereIndices);
 		
 		VertexData data ({
-			vkcv::vertexBufferBinding(positionBuffer.getHandle()),
-			vkcv::vertexBufferBinding(normalBuffer.getHandle()),
-			vkcv::vertexBufferBinding(uvBuffer.getHandle()),
-			vkcv::vertexBufferBinding(tangentBuffer.getHandle())
+			vkcv::vertexBufferBinding(positionBuffer),
+			vkcv::vertexBufferBinding(normalBuffer),
+			vkcv::vertexBufferBinding(uvBuffer),
+			vkcv::vertexBufferBinding(tangentBuffer)
 		});
 		
 		data.setIndexBuffer(indexBuffer.getHandle(), IndexBitCount::Bit32);
diff --git a/modules/geometry/src/vkcv/geometry/Teapot.cpp b/modules/geometry/src/vkcv/geometry/Teapot.cpp
index 7becdef5..8d8b4624 100644
--- a/modules/geometry/src/vkcv/geometry/Teapot.cpp
+++ b/modules/geometry/src/vkcv/geometry/Teapot.cpp
@@ -14934,10 +14934,10 @@ namespace vkcv::geometry {
 		indexBuffer.fill(teapotIndices);
 		
 		VertexData data ({
-			vkcv::vertexBufferBinding(positionBuffer.getHandle()),
-			vkcv::vertexBufferBinding(normalBuffer.getHandle()),
-			vkcv::vertexBufferBinding(uvBuffer.getHandle()),
-			vkcv::vertexBufferBinding(tangentBuffer.getHandle())
+			vkcv::vertexBufferBinding(positionBuffer.getHandle(), sizeof(float) * 3),
+			vkcv::vertexBufferBinding(normalBuffer.getHandle(), sizeof(float) * 3),
+			vkcv::vertexBufferBinding(uvBuffer.getHandle(), sizeof(float) * 2),
+			vkcv::vertexBufferBinding(tangentBuffer)
 		});
 		
 		data.setIndexBuffer(indexBuffer.getHandle());
diff --git a/modules/scene/src/vkcv/scene/MeshPart.cpp b/modules/scene/src/vkcv/scene/MeshPart.cpp
index aa82a27d..ec5a5182 100644
--- a/modules/scene/src/vkcv/scene/MeshPart.cpp
+++ b/modules/scene/src/vkcv/scene/MeshPart.cpp
@@ -41,19 +41,10 @@ namespace vkcv::scene {
 			}
 		}
 		
-		uint32_t stride = 0;
-		for (const auto& attr : vertexGroup.vertexBuffer.attributes) {
-			if (attr.type == asset::PrimitiveType::POSITION) {
-				stride = attr.stride;
-				break;
-			}
-		}
-		
-		if ((positionAttributeIndex < m_data.getVertexBufferBindings().size()) &&
-			(stride > 0)) {
+		if (positionAttributeIndex < m_data.getVertexBufferBindings().size()) {
 			m_geometry = GeometryData(
 					m_data.getVertexBufferBindings()[positionAttributeIndex],
-					stride,
+					vertexGroup.numVertices - 1,
 					GeometryVertexType::POSITION_FLOAT3
 			);
 		}
diff --git a/projects/fire_works/src/main.cpp b/projects/fire_works/src/main.cpp
index 51e29ed5..2bec7871 100644
--- a/projects/fire_works/src/main.cpp
+++ b/projects/fire_works/src/main.cpp
@@ -580,7 +580,7 @@ int main(int argc, const char **argv) {
 		1, 4, 0
 	});
 	
-	vkcv::VertexData cubeData ({ vkcv::vertexBufferBinding(cubePositions.getHandle()) });
+	vkcv::VertexData cubeData ({ vkcv::vertexBufferBinding(cubePositions) });
 	cubeData.setIndexBuffer(cubeIndices.getHandle());
 	cubeData.setCount(cubeIndices.getCount());
 	
@@ -653,7 +653,7 @@ int main(int argc, const char **argv) {
 		0, 1, 2
 	});
 	
-	vkcv::VertexData triangleData ({ vkcv::vertexBufferBinding(trianglePositions.getHandle()) });
+	vkcv::VertexData triangleData ({ vkcv::vertexBufferBinding(trianglePositions) });
 	triangleData.setIndexBuffer(triangleIndices.getHandle());
 	triangleData.setCount(triangleIndices.getCount());
 	
diff --git a/projects/indirect_draw/src/main.cpp b/projects/indirect_draw/src/main.cpp
index f90fa089..5f40732f 100644
--- a/projects/indirect_draw/src/main.cpp
+++ b/projects/indirect_draw/src/main.cpp
@@ -436,7 +436,7 @@ int main(int argc, const char** argv) {
 	modelBuffer.fill(modelMatrix);
 
 	const std::vector<vkcv::VertexBufferBinding> vertexBufferBindings = {
-			vkcv::vertexBufferBinding(vkCompiledVertexBuffer.getHandle())
+			vkcv::vertexBufferBinding(vkCompiledVertexBuffer)
 	};
 	
 	vkcv::VertexData vertexData (vertexBufferBindings);
diff --git a/projects/mpm/src/main.cpp b/projects/mpm/src/main.cpp
index 61f27daa..fa9c908c 100644
--- a/projects/mpm/src/main.cpp
+++ b/projects/mpm/src/main.cpp
@@ -550,7 +550,7 @@ int main(int argc, const char **argv) {
 		glm::vec2(+1.0f, -1.0f)
 	});
 	
-	vkcv::VertexData triangleData ({ vkcv::vertexBufferBinding(trianglePositions.getHandle()) });
+	vkcv::VertexData triangleData ({ vkcv::vertexBufferBinding(trianglePositions) });
 	triangleData.setCount(trianglePositions.getCount());
 	
 	vkcv::Buffer<glm::vec3> linesPositions = vkcv::buffer<glm::vec3>(core, vkcv::BufferType::VERTEX, 8);
@@ -583,7 +583,7 @@ int main(int argc, const char **argv) {
 		3, 7
 	});
 	
-	vkcv::VertexData linesData ({ vkcv::vertexBufferBinding(linesPositions.getHandle()) });
+	vkcv::VertexData linesData ({ vkcv::vertexBufferBinding(linesPositions) });
 	linesData.setIndexBuffer(linesIndices.getHandle());
 	linesData.setCount(linesIndices.getCount());
 	
diff --git a/projects/particle_simulation/src/main.cpp b/projects/particle_simulation/src/main.cpp
index 72249ee3..7f270706 100644
--- a/projects/particle_simulation/src/main.cpp
+++ b/projects/particle_simulation/src/main.cpp
@@ -98,7 +98,7 @@ int main(int argc, const char **argv) {
     const std::vector<vkcv::VertexAttachment> vertexAttachments = particleShaderProgram.getVertexAttachments();
 
     const std::vector<vkcv::VertexBufferBinding> vertexBufferBindings = {
-            vkcv::vertexBufferBinding(vertexBuffer.getHandle())
+            vkcv::vertexBufferBinding(vertexBuffer)
 	};
 
     std::vector<vkcv::VertexBinding> bindings;
diff --git a/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rchit b/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rchit
index 79636ffc..9bc4f664 100644
--- a/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rchit
+++ b/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rchit
@@ -9,6 +9,7 @@ layout(location = 0) rayPayloadInEXT Payload {
   float hitSky;
   vec3 worldPosition;
   vec3 worldNormal;
+  uvec4 hit;
 } payload;
 
 layout(binding = 2, set = 0, scalar) buffer rtVertices
@@ -24,6 +25,13 @@ layout(binding = 3, set = 0, scalar) buffer rtIndices
 void main() {
     payload.worldPosition = vec3(1.0, 0.0, 0.5);
 
+    payload.hit = uvec4(
+        gl_PrimitiveID,
+        gl_InstanceID,
+        gl_InstanceCustomIndexEXT,
+        gl_GeometryIndexEXT
+    );
+
     ivec3 indicesVec = ivec3(indices[3 * gl_PrimitiveID + 0], indices[3 * gl_PrimitiveID + 1], indices[3 * gl_PrimitiveID + 2]);
 
     // current triangle
diff --git a/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rgen b/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rgen
index 711070fc..41061e0e 100644
--- a/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rgen
+++ b/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rgen
@@ -8,6 +8,7 @@ layout(location = 0) rayPayloadEXT Payload {
   float hitSky;
   vec3 worldPosition;
   vec3 worldNormal;
+  uvec4 hit;
 } payload;
 
 layout(binding = 0, set = 0, rgba16) uniform image2D outImg;            // the output image -> maybe use 16 bit values?
@@ -49,14 +50,18 @@ vec2 random(){
  * @param[in,out] pos The position of intersection
  * @param[in,out] norm The normal at the position of intersection
  */
-void TraceCameraRay(out bool hitSky, out vec3 pos, out vec3 norm){
+void TraceCameraRay(out bool hitSky, out vec3 pos, out vec3 norm, out uvec4 hit){
   // Use a camera model to generate a ray for this pixel.
   vec2 uv = gl_LaunchIDEXT.xy + vec2(random()); // random breaks up aliasing
   uv /= vec2(gl_LaunchSizeEXT.xy);
   uv = (uv * 2.0 - 1.0) // normalize uv coordinates into Vulkan viewport space
     * vec2(1.0, -1.0);  // flips y-axis
   const vec3 orig   = camera.camera_position.xyz;
-  const vec3 dir    = normalize(uv.x * camera.camera_right + uv.y * camera.camera_up + camera.camera_forward).xyz;
+  const vec3 dir    = normalize(
+    uv.x * camera.camera_right +
+    uv.y * camera.camera_up +
+    camera.camera_forward
+  ).xyz;
 
   // Trace a ray into the scene; get back data in the payload.
   traceRayEXT(tlas,  // Acceleration structure
@@ -75,6 +80,7 @@ void TraceCameraRay(out bool hitSky, out vec3 pos, out vec3 norm){
   hitSky    = (payload.hitSky > 0.0);
   pos       = payload.worldPosition;
   norm      = payload.worldNormal;
+  hit       = payload.hit;
 }
 
 /**
@@ -137,7 +143,8 @@ void main(){
     uvec2 pixel = gl_LaunchIDEXT.xy;
     bool pixelIsSky; // Does the pixel show the sky (not an object)?
     vec3 pos, norm;  // AO rays from where?
-    TraceCameraRay(pixelIsSky, pos, norm);
+    uvec4 hit;
+    TraceCameraRay(pixelIsSky, pos, norm, hit);
     
     if(pixelIsSky){
         // Don't compute ambient occlusion for the sky
@@ -145,6 +152,12 @@ void main(){
         return;
     }
 
+    imageStore(outImg, ivec2(pixel), vec4(vec3(
+        0,//float(hit.x) / 0xFFFF,
+        0,//float(hit.y) / 381,
+        hit.z > 0 || hit.w > 0? 1 : 0
+    ), 1));
+/*
     // Compute ambient occlusion
     float aoValue = 0.0;
     for(uint i = 0; i < rayCount; i++){
@@ -155,4 +168,5 @@ void main(){
     aoValue /= rayCount;
     
     imageStore(outImg, ivec2(pixel), vec4(vec3(aoValue), 1));
+*/
 }
diff --git a/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rmiss b/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rmiss
index c107dbd0..5c97ae90 100644
--- a/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rmiss
+++ b/projects/rt_ambient_occlusion/resources/shaders/ambientOcclusion.rmiss
@@ -5,6 +5,7 @@ layout(location = 0) rayPayloadInEXT Payload {
   float hitSky;
   vec3 worldPosition;
   vec3 worldNormal;
+  uvec4 hit;
 } payload;
 
 void main() {
diff --git a/projects/rt_ambient_occlusion/src/main.cpp b/projects/rt_ambient_occlusion/src/main.cpp
index ce6c89a3..7c77334b 100644
--- a/projects/rt_ambient_occlusion/src/main.cpp
+++ b/projects/rt_ambient_occlusion/src/main.cpp
@@ -61,7 +61,7 @@ int main(int argc, const char** argv) {
 	
 	vkcv::geometry::Teapot teapot (glm::vec3(0.0f), 1.0f);
 	vkcv::VertexData vertexData = teapot.generateVertexData(core);
-	vkcv::GeometryData geometryData = teapot.extractGeometryData(vertexData);
+	vkcv::GeometryData geometryData = teapot.extractGeometryData(core, vertexData);
 
 	vkcv::camera::CameraManager cameraManager(core.getWindow(windowHandle));
 	auto camHandle = cameraManager.addCamera(vkcv::camera::ControllerType::TRACKBALL);
@@ -104,7 +104,7 @@ int main(int argc, const char** argv) {
 	{
 		vkcv::DescriptorWrites writes;
 		writes.writeAcceleration(1, { core.getVulkanAccelerationStructure(scene_tlas) });
-		writes.writeStorageBuffer(2, geometryData.getVertexBufferBinding().buffer);
+		writes.writeStorageBuffer(2, geometryData.getVertexBufferBinding().m_buffer);
 		writes.writeStorageBuffer(3, geometryData.getIndexBuffer());
 		core.writeDescriptorSet(descriptorSetHandles[0], writes);
 	}
diff --git a/projects/sph/src/main.cpp b/projects/sph/src/main.cpp
index 989201a9..904d9e89 100644
--- a/projects/sph/src/main.cpp
+++ b/projects/sph/src/main.cpp
@@ -96,7 +96,7 @@ int main(int argc, const char **argv) {
     const std::vector<vkcv::VertexAttachment> vertexAttachments = particleShaderProgram.getVertexAttachments();
 
     const std::vector<vkcv::VertexBufferBinding> vertexBufferBindings = {
-            vkcv::vertexBufferBinding(vertexBuffer.getHandle())
+            vkcv::vertexBufferBinding(vertexBuffer)
 	};
 
     std::vector<vkcv::VertexBinding> bindings;
diff --git a/src/vkcv/AccelerationStructureManager.cpp b/src/vkcv/AccelerationStructureManager.cpp
index 6b08cfd9..5a31fbd0 100644
--- a/src/vkcv/AccelerationStructureManager.cpp
+++ b/src/vkcv/AccelerationStructureManager.cpp
@@ -270,19 +270,15 @@ namespace vkcv {
 		
 		for (const GeometryData &data : geometryData) {
 			const auto vertexBufferAddress = bufferManager.getBufferDeviceAddress(
-					data.getVertexBufferBinding().buffer
-			) + data.getVertexBufferBinding().offset;
+					data.getVertexBufferBinding().m_buffer
+			) + data.getVertexBufferBinding().m_offset;
 			
 			const auto indexBufferAddress = bufferManager.getBufferDeviceAddress(
 					data.getIndexBuffer()
 			);
 			
 			const auto vertexStride = data.getVertexStride();
-			const auto vertexBufferSize = bufferManager.getBufferSize(
-					data.getVertexBufferBinding().buffer
-			);
-			
-			const auto vertexCount = (vertexBufferSize / vertexStride);
+			const auto maxVertex = data.getMaxVertexIndex();
 			
 			const vk::Format vertexFormat = getVertexFormat(data.getGeometryVertexType());
 			const vk::IndexType indexType = getIndexType(data.getIndexBitCount());
@@ -291,7 +287,7 @@ namespace vkcv {
 					vertexFormat,
 					vertexBufferAddress,
 					vertexStride,
-					static_cast<uint32_t>(vertexCount - 1),
+					static_cast<uint32_t>(maxVertex),
 					indexType,
 					indexBufferAddress,
 					transformBufferAddress
diff --git a/src/vkcv/Core.cpp b/src/vkcv/Core.cpp
index 4f37b9b2..142f4fff 100644
--- a/src/vkcv/Core.cpp
+++ b/src/vkcv/Core.cpp
@@ -382,15 +382,21 @@ namespace vkcv {
 		for (uint32_t i = 0; i < vertexData.getVertexBufferBindings().size(); i++) {
 			const auto &vertexBinding = vertexData.getVertexBufferBindings() [i];
 
-			cmdBuffer.bindVertexBuffers(i, bufferManager.getBuffer(vertexBinding.buffer),
-										vertexBinding.offset);
+			cmdBuffer.bindVertexBuffers(
+					i,
+					bufferManager.getBuffer(vertexBinding.m_buffer),
+					vertexBinding.m_offset
+			);
 		}
 
 		for (const auto &usage : drawcall.getDescriptorSetUsages()) {
 			cmdBuffer.bindDescriptorSets(
-				vk::PipelineBindPoint::eGraphics, pipelineLayout, usage.location,
+				vk::PipelineBindPoint::eGraphics,
+				pipelineLayout,
+				usage.location,
 				descriptorSetManager.getDescriptorSet(usage.descriptorSet).vulkanHandle,
-				usage.dynamicOffsets);
+				usage.dynamicOffsets
+			);
 		}
 
 		if (pushConstants.getSizePerDrawcall() > 0) {
@@ -534,8 +540,11 @@ namespace vkcv {
 		for (uint32_t i = 0; i < vertexData.getVertexBufferBindings().size(); i++) {
 			const auto &vertexBinding = vertexData.getVertexBufferBindings() [i];
 
-			cmdBuffer.bindVertexBuffers(i, bufferManager.getBuffer(vertexBinding.buffer),
-										vertexBinding.offset);
+			cmdBuffer.bindVertexBuffers(
+					i,
+					bufferManager.getBuffer(vertexBinding.m_buffer),
+					vertexBinding.m_offset
+			);
 		}
 
 		if (pushConstantData.getSizePerDrawcall() > 0) {
diff --git a/src/vkcv/GeometryData.cpp b/src/vkcv/GeometryData.cpp
index 35d29128..1274f752 100644
--- a/src/vkcv/GeometryData.cpp
+++ b/src/vkcv/GeometryData.cpp
@@ -5,17 +5,17 @@ namespace vkcv {
 	
 	GeometryData::GeometryData() :
 		m_vertexBinding({}),
-		m_vertexStride(0),
+		m_maxVertexIndex(0),
 		m_vertexType(GeometryVertexType::UNDEFINED),
 		m_indices(),
 		m_indexBitCount(IndexBitCount::Bit16),
 		m_count(0) {}
 	
 	GeometryData::GeometryData(const VertexBufferBinding &binding,
-							   uint32_t stride,
+							   size_t maxVertexIndex,
 							   GeometryVertexType geometryVertexType) :
 		m_vertexBinding(binding),
-		m_vertexStride(stride),
+		m_maxVertexIndex(maxVertexIndex),
 		m_vertexType(geometryVertexType),
 		m_indices(),
 		m_indexBitCount(IndexBitCount::Bit16),
@@ -30,7 +30,11 @@ namespace vkcv {
 	}
 	
 	uint32_t GeometryData::getVertexStride() const {
-		return m_vertexStride;
+		return m_vertexBinding.m_stride;
+	}
+	
+	size_t GeometryData::getMaxVertexIndex() const {
+		return m_maxVertexIndex;
 	}
 	
 	GeometryVertexType GeometryData::getGeometryVertexType() const {
diff --git a/src/vkcv/VertexData.cpp b/src/vkcv/VertexData.cpp
index a1fcfb36..70c99059 100644
--- a/src/vkcv/VertexData.cpp
+++ b/src/vkcv/VertexData.cpp
@@ -3,10 +3,13 @@
 
 namespace vkcv {
 
-	VertexBufferBinding vertexBufferBinding(const BufferHandle &buffer, size_t offset) {
+	VertexBufferBinding vertexBufferBinding(const BufferHandle &buffer,
+											size_t stride,
+											size_t offset) {
 		VertexBufferBinding binding;
-		binding.buffer = buffer;
-		binding.offset = offset;
+		binding.m_buffer = buffer;
+		binding.m_stride = stride;
+		binding.m_offset = offset;
 		return binding;
 	}
 
-- 
GitLab