// renderling/cull/shader.rs

1use crabslab::{Array, Id, Slab, SlabItem};
2use glam::{UVec2, UVec3, Vec2, Vec3Swizzles};
3#[allow(unused_imports)]
4use spirv_std::num_traits::Float;
5use spirv_std::{
6    arch::IndexUnchecked,
7    image::{sample_with, Image, ImageWithMethods},
8    spirv,
9};
10
11use crate::{
12    draw::DrawIndirectArgs, geometry::shader::GeometryDescriptor,
13    primitive::shader::PrimitiveDescriptor,
14};
15
16#[spirv(compute(threads(16)))]
17pub fn compute_culling(
18    #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] stage_slab: &[u32],
19    #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] depth_pyramid_slab: &[u32],
20    #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] args: &mut [DrawIndirectArgs],
21    #[spirv(global_invocation_id)] global_id: UVec3,
22) {
23    let gid = global_id.x as usize;
24    if gid >= args.len() {
25        return;
26    }
27
28    crate::println!("gid: {gid}");
29    // Get the draw arg
30    let arg = unsafe { args.index_unchecked_mut(gid) };
31    // Get the primitive using the draw arg's primitive id
32    let primitive_id = Id::<PrimitiveDescriptor>::new(arg.first_instance);
33    let primitive = stage_slab.read_unchecked(primitive_id);
34    crate::println!("primitive: {primitive_id:?}");
35
36    arg.vertex_count = primitive.get_vertex_count();
37    arg.instance_count = if primitive.visible { 1 } else { 0 };
38
39    if primitive.bounds.radius == 0.0 {
40        crate::println!("primitive bounding radius is zero, cannot cull");
41        return;
42    }
43
44    let config: GeometryDescriptor = stage_slab.read(Id::new(0));
45    if !config.perform_frustum_culling {
46        return;
47    }
48
49    let camera = stage_slab.read(config.camera_id);
50    let model = stage_slab.read(primitive.transform_id);
51    // Compute frustum culling, and then occlusion culling, if need be
52    let (primitive_is_inside_frustum, sphere_in_world_coords) =
53        primitive.bounds.is_inside_camera_view(&camera, model);
54
55    if primitive_is_inside_frustum {
56        arg.instance_count = 1;
57        crate::println!("primitive is inside frustum");
58        crate::println!("znear: {}", camera.frustum().planes[0]);
59        crate::println!(" zfar: {}", camera.frustum().planes[5]);
60        if !config.perform_occlusion_culling {
61            return;
62        }
63
64        // Compute occlusion culling using the hierachical z-buffer.
65        let hzb_desc = depth_pyramid_slab.read_unchecked::<DepthPyramidDescriptor>(0u32.into());
66        let viewport_size = Vec2::new(hzb_desc.size.x as f32, hzb_desc.size.y as f32);
67        let sphere_aabb = sphere_in_world_coords.project_onto_viewport(&camera, viewport_size);
68        crate::println!("sphere_aabb: {sphere_aabb:#?}");
69
70        let size_in_pixels = sphere_aabb.max.xy() - sphere_aabb.min.xy();
71        let size_in_pixels = if size_in_pixels.x > size_in_pixels.y {
72            size_in_pixels.x
73        } else {
74            size_in_pixels.y
75        };
76        crate::println!("primitive size in pixels: {size_in_pixels}");
77
78        let mip_level = size_in_pixels.log2().floor() as u32;
79        let max_mip_level = hzb_desc.mip.len() as u32 - 1;
80        let mip_level = if mip_level > max_mip_level {
81            crate::println!("mip_level maxed out at {mip_level}, setting to {max_mip_level}");
82            max_mip_level
83        } else {
84            mip_level
85        };
86        crate::println!(
87            "selected mip level: {mip_level} {}x{}",
88            viewport_size.x as u32 >> mip_level,
89            viewport_size.y as u32 >> mip_level
90        );
91
92        let center = sphere_aabb.center().xy();
93        crate::println!("center: {center}");
94
95        let x = center.x.round() as u32 >> mip_level;
96        let y = center.y.round() as u32 >> mip_level;
97        crate::println!("mip (x, y): ({x}, {y})");
98
99        let depth_id = hzb_desc.id_of_depth(mip_level, UVec2::new(x, y), depth_pyramid_slab);
100        let depth_in_hzb = depth_pyramid_slab.read_unchecked(depth_id);
101        crate::println!("depth_in_hzb: {depth_in_hzb}");
102
103        let depth_of_sphere = sphere_aabb.min.z;
104        crate::println!("depth_of_sphere: {depth_of_sphere}");
105
106        let primitive_is_behind_something = depth_of_sphere > depth_in_hzb;
107        let primitive_surrounds_camera = depth_of_sphere > 1.0;
108
109        if primitive_is_behind_something || primitive_surrounds_camera {
110            crate::println!("CULLED");
111            arg.instance_count = 0;
112        }
113    } else {
114        arg.instance_count = 0;
115    }
116}
117
118/// A hierarchichal depth buffer.
119///
120/// AKA HZB
121#[derive(Clone, Copy, Default, SlabItem)]
122pub struct DepthPyramidDescriptor {
123    /// Size of the top layer mip.
124    pub size: UVec2,
125    /// Current mip level.
126    ///
127    /// This will be updated for each run of the downsample compute shader.
128    pub mip_level: u32,
129    /// Pointer to the mip data.
130    ///
131    /// This points to the depth data at each mip level.
132    ///
133    /// The depth data itself is somewhere else in the slab.
134    pub mip: Array<Array<f32>>,
135}
136
137impl DepthPyramidDescriptor {
138    fn should_skip_invocation(&self, global_invocation: UVec3) -> bool {
139        let current_size = self.size >> self.mip_level;
140        !(global_invocation.x < current_size.x && global_invocation.y < current_size.y)
141    }
142
143    #[cfg(test)]
144    pub fn size_at(&self, mip_level: u32) -> UVec2 {
145        UVec2::new(self.size.x >> mip_level, self.size.y >> mip_level)
146    }
147
148    /// Return the [`Id`] of the depth at the given `mip_level` and coordinate.
149    pub fn id_of_depth(&self, mip_level: u32, coord: UVec2, slab: &[u32]) -> Id<f32> {
150        let mip_array = slab.read(self.mip.at(mip_level as usize));
151        let width_at_mip = self.size.x >> mip_level;
152        let index = coord.y * width_at_mip + coord.x;
153        mip_array.at(index as usize)
154    }
155}
156
157pub type DepthImage2d = Image!(2D, type=f32, sampled, depth);
158pub type DepthImage2dMultisampled = Image!(2D, type=f32, sampled, depth, multisampled);
159
160/// Copies a depth texture to the top mip of a pyramid of mips.
161///
162/// It is assumed that a [`DepthPyramidDescriptor`] is stored at index `0` in
163/// the given slab.
164#[spirv(compute(threads(16, 16, 1)))]
165pub fn compute_copy_depth_to_pyramid(
166    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] slab: &mut [u32],
167    #[spirv(descriptor_set = 0, binding = 1)] depth_texture: &DepthImage2d,
168    #[spirv(global_invocation_id)] global_id: UVec3,
169) {
170    let desc = slab.read_unchecked::<DepthPyramidDescriptor>(0u32.into());
171    if desc.should_skip_invocation(global_id) {
172        return;
173    }
174
175    let depth = depth_texture
176        .fetch_with(global_id.xy(), sample_with::lod(0))
177        .x;
178    let dest_id = desc.id_of_depth(0, global_id.xy(), slab);
179    slab.write(dest_id, &depth);
180}
181
182/// Copies a depth texture to the top mip of a pyramid of mips.
183///
184/// It is assumed that a [`DepthPyramidDescriptor`] is stored at index `0` in
185/// the given slab.
186#[spirv(compute(threads(16, 16, 1)))]
187pub fn compute_copy_depth_to_pyramid_multisampled(
188    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] slab: &mut [u32],
189    #[spirv(descriptor_set = 0, binding = 1)] depth_texture: &DepthImage2dMultisampled,
190    #[spirv(global_invocation_id)] global_id: UVec3,
191) {
192    let desc = slab.read_unchecked::<DepthPyramidDescriptor>(0u32.into());
193    if desc.should_skip_invocation(global_id) {
194        return;
195    }
196
197    let depth = depth_texture
198        .fetch_with(global_id.xy(), sample_with::sample_index(0))
199        .x;
200    let dest_id = desc.id_of_depth(0, global_id.xy(), slab);
201    slab.write(dest_id, &depth);
202}
203
204/// Downsample from `DepthPyramidDescriptor::mip_level-1` into
205/// `DepthPyramidDescriptor::mip_level`.
206///
207/// It is assumed that a [`DepthPyramidDescriptor`] is stored at index `0` in
208/// the given slab.
209///
210/// The `DepthPyramidDescriptor`'s `mip_level` field will point to that of the
211/// mip level being downsampled to (the mip level being written into).
212///
213/// This shader should be called in a loop from from `1..mip_count`.
214#[spirv(compute(threads(16, 16, 1)))]
215pub fn compute_downsample_depth_pyramid(
216    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] slab: &mut [u32],
217    #[spirv(global_invocation_id)] global_id: UVec3,
218) {
219    let desc = slab.read_unchecked::<DepthPyramidDescriptor>(0u32.into());
220    if desc.should_skip_invocation(global_id) {
221        return;
222    }
223    // Sample the texel.
224    //
225    // The texel will look like this:
226    //
227    // a b
228    // c d
229    let a_coord = global_id.xy() * 2;
230    let a = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord, slab));
231    let b = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord + UVec2::new(1, 0), slab));
232    let c = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord + UVec2::new(0, 1), slab));
233    let d = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord + UVec2::new(1, 1), slab));
234    // Take the maximum depth of the region (max depth means furthest away)
235    let depth_value = a.max(b).max(c).max(d);
236    // Write the texel in the next mip
237    let depth_id = desc.id_of_depth(desc.mip_level, global_id.xy(), slab);
238    slab.write(depth_id, &depth_value);
239}