1use crabslab::{Array, Id, Slab, SlabItem};
2use glam::{UVec2, UVec3, Vec2, Vec3Swizzles};
3#[allow(unused_imports)]
4use spirv_std::num_traits::Float;
5use spirv_std::{
6 arch::IndexUnchecked,
7 image::{sample_with, Image, ImageWithMethods},
8 spirv,
9};
10
11use crate::{
12 draw::DrawIndirectArgs, geometry::shader::GeometryDescriptor,
13 primitive::shader::PrimitiveDescriptor,
14};
15
16#[spirv(compute(threads(16)))]
17pub fn compute_culling(
18 #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] stage_slab: &[u32],
19 #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] depth_pyramid_slab: &[u32],
20 #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] args: &mut [DrawIndirectArgs],
21 #[spirv(global_invocation_id)] global_id: UVec3,
22) {
23 let gid = global_id.x as usize;
24 if gid >= args.len() {
25 return;
26 }
27
28 crate::println!("gid: {gid}");
29 let arg = unsafe { args.index_unchecked_mut(gid) };
31 let primitive_id = Id::<PrimitiveDescriptor>::new(arg.first_instance);
33 let primitive = stage_slab.read_unchecked(primitive_id);
34 crate::println!("primitive: {primitive_id:?}");
35
36 arg.vertex_count = primitive.get_vertex_count();
37 arg.instance_count = if primitive.visible { 1 } else { 0 };
38
39 if primitive.bounds.radius == 0.0 {
40 crate::println!("primitive bounding radius is zero, cannot cull");
41 return;
42 }
43
44 let config: GeometryDescriptor = stage_slab.read(Id::new(0));
45 if !config.perform_frustum_culling {
46 return;
47 }
48
49 let camera = stage_slab.read(config.camera_id);
50 let model = stage_slab.read(primitive.transform_id);
51 let (primitive_is_inside_frustum, sphere_in_world_coords) =
53 primitive.bounds.is_inside_camera_view(&camera, model);
54
55 if primitive_is_inside_frustum {
56 arg.instance_count = 1;
57 crate::println!("primitive is inside frustum");
58 crate::println!("znear: {}", camera.frustum().planes[0]);
59 crate::println!(" zfar: {}", camera.frustum().planes[5]);
60 if !config.perform_occlusion_culling {
61 return;
62 }
63
64 let hzb_desc = depth_pyramid_slab.read_unchecked::<DepthPyramidDescriptor>(0u32.into());
66 let viewport_size = Vec2::new(hzb_desc.size.x as f32, hzb_desc.size.y as f32);
67 let sphere_aabb = sphere_in_world_coords.project_onto_viewport(&camera, viewport_size);
68 crate::println!("sphere_aabb: {sphere_aabb:#?}");
69
70 let size_in_pixels = sphere_aabb.max.xy() - sphere_aabb.min.xy();
71 let size_in_pixels = if size_in_pixels.x > size_in_pixels.y {
72 size_in_pixels.x
73 } else {
74 size_in_pixels.y
75 };
76 crate::println!("primitive size in pixels: {size_in_pixels}");
77
78 let mip_level = size_in_pixels.log2().floor() as u32;
79 let max_mip_level = hzb_desc.mip.len() as u32 - 1;
80 let mip_level = if mip_level > max_mip_level {
81 crate::println!("mip_level maxed out at {mip_level}, setting to {max_mip_level}");
82 max_mip_level
83 } else {
84 mip_level
85 };
86 crate::println!(
87 "selected mip level: {mip_level} {}x{}",
88 viewport_size.x as u32 >> mip_level,
89 viewport_size.y as u32 >> mip_level
90 );
91
92 let center = sphere_aabb.center().xy();
93 crate::println!("center: {center}");
94
95 let x = center.x.round() as u32 >> mip_level;
96 let y = center.y.round() as u32 >> mip_level;
97 crate::println!("mip (x, y): ({x}, {y})");
98
99 let depth_id = hzb_desc.id_of_depth(mip_level, UVec2::new(x, y), depth_pyramid_slab);
100 let depth_in_hzb = depth_pyramid_slab.read_unchecked(depth_id);
101 crate::println!("depth_in_hzb: {depth_in_hzb}");
102
103 let depth_of_sphere = sphere_aabb.min.z;
104 crate::println!("depth_of_sphere: {depth_of_sphere}");
105
106 let primitive_is_behind_something = depth_of_sphere > depth_in_hzb;
107 let primitive_surrounds_camera = depth_of_sphere > 1.0;
108
109 if primitive_is_behind_something || primitive_surrounds_camera {
110 crate::println!("CULLED");
111 arg.instance_count = 0;
112 }
113 } else {
114 arg.instance_count = 0;
115 }
116}
117
/// Describes a depth pyramid: a chain of successively halved depth buffers
/// stored in a slab.
///
/// The compute shaders in this file read this descriptor from `Id(0)` of the
/// pyramid's slab.
#[derive(Clone, Copy, Default, SlabItem)]
pub struct DepthPyramidDescriptor {
    /// Width and height of mip level 0; level `n` is treated as `size >> n`.
    pub size: UVec2,
    /// The mip level the current compute pass operates on.
    ///
    /// Invocations outside `size >> mip_level` are skipped, and the downsample
    /// pass writes this level from level `mip_level - 1`.
    pub mip_level: u32,
    /// One array of depth values per mip level, indexed by mip level.
    ///
    /// Each level is addressed row-major with a row width of `size.x >> level`.
    pub mip: Array<Array<f32>>,
}
136
137impl DepthPyramidDescriptor {
138 fn should_skip_invocation(&self, global_invocation: UVec3) -> bool {
139 let current_size = self.size >> self.mip_level;
140 !(global_invocation.x < current_size.x && global_invocation.y < current_size.y)
141 }
142
143 #[cfg(test)]
144 pub fn size_at(&self, mip_level: u32) -> UVec2 {
145 UVec2::new(self.size.x >> mip_level, self.size.y >> mip_level)
146 }
147
148 pub fn id_of_depth(&self, mip_level: u32, coord: UVec2, slab: &[u32]) -> Id<f32> {
150 let mip_array = slab.read(self.mip.at(mip_level as usize));
151 let width_at_mip = self.size.x >> mip_level;
152 let index = coord.y * width_at_mip + coord.x;
153 mip_array.at(index as usize)
154 }
155}
156
/// A sampled, non-multisampled 2D depth image with `f32` texels; the source of
/// mip 0 for `compute_copy_depth_to_pyramid`.
pub type DepthImage2d = Image!(2D, type=f32, sampled, depth);
/// A sampled, multisampled 2D depth image with `f32` texels; the source of
/// mip 0 for `compute_copy_depth_to_pyramid_multisampled`.
pub type DepthImage2dMultisampled = Image!(2D, type=f32, sampled, depth, multisampled);
159
160#[spirv(compute(threads(16, 16, 1)))]
165pub fn compute_copy_depth_to_pyramid(
166 #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] slab: &mut [u32],
167 #[spirv(descriptor_set = 0, binding = 1)] depth_texture: &DepthImage2d,
168 #[spirv(global_invocation_id)] global_id: UVec3,
169) {
170 let desc = slab.read_unchecked::<DepthPyramidDescriptor>(0u32.into());
171 if desc.should_skip_invocation(global_id) {
172 return;
173 }
174
175 let depth = depth_texture
176 .fetch_with(global_id.xy(), sample_with::lod(0))
177 .x;
178 let dest_id = desc.id_of_depth(0, global_id.xy(), slab);
179 slab.write(dest_id, &depth);
180}
181
182#[spirv(compute(threads(16, 16, 1)))]
187pub fn compute_copy_depth_to_pyramid_multisampled(
188 #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] slab: &mut [u32],
189 #[spirv(descriptor_set = 0, binding = 1)] depth_texture: &DepthImage2dMultisampled,
190 #[spirv(global_invocation_id)] global_id: UVec3,
191) {
192 let desc = slab.read_unchecked::<DepthPyramidDescriptor>(0u32.into());
193 if desc.should_skip_invocation(global_id) {
194 return;
195 }
196
197 let depth = depth_texture
198 .fetch_with(global_id.xy(), sample_with::sample_index(0))
199 .x;
200 let dest_id = desc.id_of_depth(0, global_id.xy(), slab);
201 slab.write(dest_id, &depth);
202}
203
204#[spirv(compute(threads(16, 16, 1)))]
215pub fn compute_downsample_depth_pyramid(
216 #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] slab: &mut [u32],
217 #[spirv(global_invocation_id)] global_id: UVec3,
218) {
219 let desc = slab.read_unchecked::<DepthPyramidDescriptor>(0u32.into());
220 if desc.should_skip_invocation(global_id) {
221 return;
222 }
223 let a_coord = global_id.xy() * 2;
230 let a = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord, slab));
231 let b = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord + UVec2::new(1, 0), slab));
232 let c = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord + UVec2::new(0, 1), slab));
233 let d = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord + UVec2::new(1, 1), slab));
234 let depth_value = a.max(b).max(c).max(d);
236 let depth_id = desc.id_of_depth(desc.mip_level, global_id.xy(), slab);
238 slab.write(depth_id, &depth_value);
239}