15 struct RayTracingPipelineDesc {
16 uint32_t maxRecursionDepth;
17 uint32_t missGroupCount;
18 uint32_t hitGroupCount;
19 std::vector<VkRayTracingShaderGroupCreateInfoKHR> shaderGroups;
21 MANDRILL_API RayTracingPipelineDesc(uint32_t missGroupCount = 1, uint32_t hitGroupCount = 1,
22 uint32_t maxRecursionDepth = 1)
23 : maxRecursionDepth(maxRecursionDepth), missGroupCount(missGroupCount), hitGroupCount(hitGroupCount)
25 shaderGroups.resize(1 + missGroupCount + hitGroupCount);
28 MANDRILL_API
void setRayGen(uint32_t stage)
30 VkRayTracingShaderGroupCreateInfoKHR ci = {
31 .sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
32 .type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
33 .generalShader = stage,
34 .closestHitShader = VK_SHADER_UNUSED_KHR,
35 .anyHitShader = VK_SHADER_UNUSED_KHR,
36 .intersectionShader = VK_SHADER_UNUSED_KHR,
41 MANDRILL_API
void setMissGroup(uint32_t missGroup, uint32_t stage)
43 if (missGroup >= missGroupCount) {
44 Log::Error(
"Miss group {} exceeds hitGroupCount {}", missGroup, missGroupCount);
48 VkRayTracingShaderGroupCreateInfoKHR ci = {
49 .sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
50 .type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
51 .generalShader = stage,
52 .closestHitShader = VK_SHADER_UNUSED_KHR,
53 .anyHitShader = VK_SHADER_UNUSED_KHR,
54 .intersectionShader = VK_SHADER_UNUSED_KHR,
56 shaderGroups[1 + missGroup] = ci;
60 setHitGroup(uint32_t hitGroup, uint32_t closestHitStage, uint32_t anyHitStage = VK_SHADER_UNUSED_KHR,
61 uint32_t intersectionStage = VK_SHADER_UNUSED_KHR,
62 VkRayTracingShaderGroupTypeKHR type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR)
64 if (hitGroup >= hitGroupCount) {
65 Log::Error(
"Hit group {} exceeds hitGroupCount {}", hitGroup, hitGroupCount);
69 VkRayTracingShaderGroupCreateInfoKHR ci = {
70 .sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
71 .type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
72 .generalShader = VK_SHADER_UNUSED_KHR,
73 .closestHitShader = closestHitStage,
74 .anyHitShader = anyHitStage,
75 .intersectionShader = intersectionStage,
77 shaderGroups[1 + missGroupCount + hitGroup] = ci;
96 MANDRILL_API
RayTracingPipeline(ptr<Device> pDevice, ptr<Layout> pLayout, ptr<Shader> pShader,
97 const RayTracingPipelineDesc& desc = RayTracingPipelineDesc());
104 MANDRILL_API
void bind(VkCommandBuffer cmd);
112 MANDRILL_API
void write(VkCommandBuffer cmd, VkImage image);
120 MANDRILL_API
void read(VkCommandBuffer cmd, VkImage image);
134 VkDeviceAddress address = mpShaderBindingTableBuffer->getDeviceAddress();
135 VkStridedDeviceAddressRegionKHR region = {
136 .deviceAddress = address,
137 .stride = mGroupSizeAligned,
138 .size = mGroupSizeAligned,
147 MANDRILL_API VkStridedDeviceAddressRegionKHR
getMissSBT()
const
149 VkDeviceAddress address = mpShaderBindingTableBuffer->getDeviceAddress();
150 VkStridedDeviceAddressRegionKHR region = {
151 .deviceAddress = address + mGroupSizeAligned,
152 .stride = mGroupSizeAligned,
153 .size = mGroupSizeAligned,
162 MANDRILL_API VkStridedDeviceAddressRegionKHR
getHitSBT()
const
164 VkDeviceAddress address = mpShaderBindingTableBuffer->getDeviceAddress();
165 VkStridedDeviceAddressRegionKHR region = {
166 .deviceAddress = address + mGroupSizeAligned * (1 + mMissGroupCount),
167 .stride = mGroupSizeAligned,
168 .size = mGroupSizeAligned,
177 MANDRILL_API VkStridedDeviceAddressRegionKHR
getCallSBT()
const
180 VkStridedDeviceAddressRegionKHR region = {0};
185 void createPipeline();
186 void createShaderBindingTable();
188 std::vector<VkRayTracingShaderGroupCreateInfoKHR> mShaderGroups;
190 ptr<Buffer> mpShaderBindingTableBuffer;
191 uint32_t mGroupSizeAligned;
193 uint32_t mMaxRecursionDepth;
194 uint32_t mMissGroupCount;
195 uint32_t mHitGroupCount;
197 ptr<Descriptor> mpStorageImageDescriptor;
Pipeline class for managing Vulkan graphics pipelines.
Definition Pipeline.h:76
Ray tracing pipeline class that manages the creation and usage of a ray tracing pipeline in Vulkan.
Definition RayTracingPipeline.h:85
MANDRILL_API VkStridedDeviceAddressRegionKHR getHitSBT() const
Get the hit group shader binding table record.
Definition RayTracingPipeline.h:162
MANDRILL_API void bind(VkCommandBuffer cmd)
Bind pipeline for rendering.
Definition RayTracingPipeline.cpp:24
MANDRILL_API void read(VkCommandBuffer cmd, VkImage image)
Transition an image for reading.
Definition RayTracingPipeline.cpp:36
MANDRILL_API VkStridedDeviceAddressRegionKHR getMissSBT() const
Get the miss group shader binding table record.
Definition RayTracingPipeline.h:147
MANDRILL_API VkStridedDeviceAddressRegionKHR getRayGenSBT() const
Get the raygen group shader binding table record.
Definition RayTracingPipeline.h:132
MANDRILL_API void write(VkCommandBuffer cmd, VkImage image)
Transition an image for writing.
Definition RayTracingPipeline.cpp:29
MANDRILL_API VkStridedDeviceAddressRegionKHR getCallSBT() const
[NOT IMPLEMENTED] Get the call group shader binding table record.
Definition RayTracingPipeline.h:177
MANDRILL_API void recreate()
Recreate a pipeline. Call this if shader source code has changed and should be reloaded.
Definition RayTracingPipeline.cpp:43