Mandrill 2025.6.0
Loading...
Searching...
No Matches
RayTracingPipeline.h
1#pragma once
2
3#include "Common.h"
4
5#include "Buffer.h"
6#include "Descriptor.h"
7#include "Device.h"
8#include "Layout.h"
9#include "Pipeline.h"
10#include "Shader.h"
11
12
13namespace Mandrill
14{
15 struct RayTracingPipelineDesc {
16 uint32_t maxRecursionDepth;
17 uint32_t missGroupCount;
18 uint32_t hitGroupCount;
19 std::vector<VkRayTracingShaderGroupCreateInfoKHR> shaderGroups;
20
21 MANDRILL_API RayTracingPipelineDesc(uint32_t missGroupCount = 1, uint32_t hitGroupCount = 1,
22 uint32_t maxRecursionDepth = 1)
23 : maxRecursionDepth(maxRecursionDepth), missGroupCount(missGroupCount), hitGroupCount(hitGroupCount)
24 {
25 shaderGroups.resize(1 + missGroupCount + hitGroupCount);
26 }
27
28 MANDRILL_API void setRayGen(uint32_t stage)
29 {
30 VkRayTracingShaderGroupCreateInfoKHR ci = {
31 .sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
32 .type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
33 .generalShader = stage,
34 .closestHitShader = VK_SHADER_UNUSED_KHR,
35 .anyHitShader = VK_SHADER_UNUSED_KHR,
36 .intersectionShader = VK_SHADER_UNUSED_KHR,
37 };
38 shaderGroups[0] = ci;
39 }
40
41 MANDRILL_API void setMissGroup(uint32_t missGroup, uint32_t stage)
42 {
43 if (missGroup >= missGroupCount) {
44 Log::Error("Miss group {} exceeds hitGroupCount {}", missGroup, missGroupCount);
45 return;
46 }
47
48 VkRayTracingShaderGroupCreateInfoKHR ci = {
49 .sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
50 .type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
51 .generalShader = stage,
52 .closestHitShader = VK_SHADER_UNUSED_KHR,
53 .anyHitShader = VK_SHADER_UNUSED_KHR,
54 .intersectionShader = VK_SHADER_UNUSED_KHR,
55 };
56 shaderGroups[1 + missGroup] = ci;
57 }
58
59 MANDRILL_API void
60 setHitGroup(uint32_t hitGroup, uint32_t closestHitStage, uint32_t anyHitStage = VK_SHADER_UNUSED_KHR,
61 uint32_t intersectionStage = VK_SHADER_UNUSED_KHR,
62 VkRayTracingShaderGroupTypeKHR type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR)
63 {
64 if (hitGroup >= hitGroupCount) {
65 Log::Error("Hit group {} exceeds hitGroupCount {}", hitGroup, hitGroupCount);
66 return;
67 }
68
69 VkRayTracingShaderGroupCreateInfoKHR ci = {
70 .sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
71 .type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
72 .generalShader = VK_SHADER_UNUSED_KHR,
73 .closestHitShader = closestHitStage,
74 .anyHitShader = anyHitStage,
75 .intersectionShader = intersectionStage,
76 };
77 shaderGroups[1 + missGroupCount + hitGroup] = ci;
78 }
79 };
80
85 {
86 public:
87 MANDRILL_NON_COPYABLE(RayTracingPipeline)
88
89
96 MANDRILL_API RayTracingPipeline(ptr<Device> pDevice, ptr<Layout> pLayout, ptr<Shader> pShader,
97 const RayTracingPipelineDesc& desc = RayTracingPipelineDesc());
98
104 MANDRILL_API void bind(VkCommandBuffer cmd);
105
112 MANDRILL_API void write(VkCommandBuffer cmd, VkImage image);
113
120 MANDRILL_API void read(VkCommandBuffer cmd, VkImage image);
121
126 MANDRILL_API void recreate();
127
132 MANDRILL_API VkStridedDeviceAddressRegionKHR getRayGenSBT() const
133 {
134 VkDeviceAddress address = mpShaderBindingTableBuffer->getDeviceAddress();
135 VkStridedDeviceAddressRegionKHR region = {
136 .deviceAddress = address,
137 .stride = mGroupSizeAligned,
138 .size = mGroupSizeAligned,
139 };
140 return region;
141 }
142
147 MANDRILL_API VkStridedDeviceAddressRegionKHR getMissSBT() const
148 {
149 VkDeviceAddress address = mpShaderBindingTableBuffer->getDeviceAddress();
150 VkStridedDeviceAddressRegionKHR region = {
151 .deviceAddress = address + mGroupSizeAligned,
152 .stride = mGroupSizeAligned,
153 .size = mGroupSizeAligned,
154 };
155 return region;
156 }
157
162 MANDRILL_API VkStridedDeviceAddressRegionKHR getHitSBT() const
163 {
164 VkDeviceAddress address = mpShaderBindingTableBuffer->getDeviceAddress();
165 VkStridedDeviceAddressRegionKHR region = {
166 .deviceAddress = address + mGroupSizeAligned * (1 + mMissGroupCount),
167 .stride = mGroupSizeAligned,
168 .size = mGroupSizeAligned,
169 };
170 return region;
171 }
172
177 MANDRILL_API VkStridedDeviceAddressRegionKHR getCallSBT() const
178 {
179 // Not implemented
180 VkStridedDeviceAddressRegionKHR region = {0};
181 return region;
182 }
183
184 private:
185 void createPipeline();
186 void createShaderBindingTable();
187
188 std::vector<VkRayTracingShaderGroupCreateInfoKHR> mShaderGroups;
189
190 ptr<Buffer> mpShaderBindingTableBuffer;
191 uint32_t mGroupSizeAligned;
192
193 uint32_t mMaxRecursionDepth;
194 uint32_t mMissGroupCount;
195 uint32_t mHitGroupCount;
196
197 ptr<Descriptor> mpStorageImageDescriptor;
198 };
199} // namespace Mandrill
Pipeline class for managing Vulkan graphics pipelines.
Definition Pipeline.h:76
Ray tracing pipeline class that manages the creation and usage of a ray tracing pipeline in Vulkan.
Definition RayTracingPipeline.h:85
MANDRILL_API VkStridedDeviceAddressRegionKHR getHitSBT() const
Get the hit group shader binding table record.
Definition RayTracingPipeline.h:162
MANDRILL_API void bind(VkCommandBuffer cmd)
Bind pipeline for rendering.
Definition RayTracingPipeline.cpp:24
MANDRILL_API void read(VkCommandBuffer cmd, VkImage image)
Transition an image for reading from.
Definition RayTracingPipeline.cpp:36
MANDRILL_API VkStridedDeviceAddressRegionKHR getMissSBT() const
Get the miss group shader binding table record.
Definition RayTracingPipeline.h:147
MANDRILL_API VkStridedDeviceAddressRegionKHR getRayGenSBT() const
Get the raygen group shader binding table record.
Definition RayTracingPipeline.h:132
MANDRILL_API void write(VkCommandBuffer cmd, VkImage image)
Transition an image for writing to.
Definition RayTracingPipeline.cpp:29
MANDRILL_API VkStridedDeviceAddressRegionKHR getCallSBT() const
[NOT IMPLEMENTED] Get the call group shader binding table record.
Definition RayTracingPipeline.h:177
MANDRILL_API void recreate()
Recreate a pipeline. Call this if shader source code has changed and should be reloaded.
Definition RayTracingPipeline.cpp:43