diff --git a/src/vkcv/DrawcallRecording.cpp b/src/vkcv/DrawcallRecording.cpp
index d2112a5a330d5373249ae4999ea8b2d84b3730b4..bfea043d4da11293fc7ca445efd3565ee8b385e1 100644
--- a/src/vkcv/DrawcallRecording.cpp
+++ b/src/vkcv/DrawcallRecording.cpp
@@ -54,9 +54,9 @@ namespace vkcv {
 
     void InitMeshShaderDrawFunctions(vk::Device device)
     {
-        MeshShaderFunctions.cmdDrawMeshTasks = reinterpret_cast<PFN_vkCmdDrawMeshTasksNV>(vkGetDeviceProcAddr(VkDevice(device), "vkCmdDrawMeshTasksNV"));
-        MeshShaderFunctions.cmdDrawMeshTasksIndirect = reinterpret_cast<PFN_vkCmdDrawMeshTasksIndirectNV>(vkGetDeviceProcAddr(VkDevice(device), "vkCmdDrawMeshTasksIndirectNV"));
-        MeshShaderFunctions.cmdDrawMeshTasksIndirectCount = reinterpret_cast<PFN_vkCmdDrawMeshTasksIndirectCountNV>(vkGetDeviceProcAddr(VkDevice(device), "vkCmdDrawMeshTasksIndirectCountNV"));
+        MeshShaderFunctions.cmdDrawMeshTasks = PFN_vkCmdDrawMeshTasksNV(device.getProcAddr("vkCmdDrawMeshTasksNV"));
+        MeshShaderFunctions.cmdDrawMeshTasksIndirect = PFN_vkCmdDrawMeshTasksIndirectNV(device.getProcAddr("vkCmdDrawMeshTasksIndirectNV"));
+        MeshShaderFunctions.cmdDrawMeshTasksIndirectCount = PFN_vkCmdDrawMeshTasksIndirectCountNV (device.getProcAddr( "vkCmdDrawMeshTasksIndirectCountNV"));
     }
 
     void recordMeshShaderDrawcall(