renderling/cull/
cpu.rs

1//! CPU side of compute culling.
2
3use craballoc::{
4    prelude::{GpuArray, Hybrid, SlabAllocator, SlabAllocatorError},
5    runtime::WgpuRuntime,
6    slab::SlabBuffer,
7};
8use crabslab::{Array, Slab};
9use glam::UVec2;
10use snafu::{OptionExt, Snafu};
11
12use crate::{bindgroup::ManagedBindGroup, texture::Texture};
13
14use super::shader::DepthPyramidDescriptor;
15
/// Errors reported by the compute-culling module.
#[derive(Debug, Snafu)]
pub enum CullingError {
    /// A texture did not have the format [`Texture::DEPTH_FORMAT`]; `seen` is
    /// the format that was found instead.
    #[snafu(display(
        "Texture is not a depth texture, expected '{:?}' but saw '{seen:?}'",
        Texture::DEPTH_FORMAT
    ))]
    NotADepthTexture { seen: wgpu::TextureFormat },

    /// The depth pyramid has no mip at `index`.
    #[snafu(display("Missing depth pyramid mip {index}"))]
    MissingMip { index: usize },

    /// An error from the slab allocator, displayed verbatim.
    #[snafu(display("{source}"))]
    SlabError { source: SlabAllocatorError },

    /// The data read back for mip `index` could not be turned into an image.
    #[snafu(display("Could not read mip {index}"))]
    ReadMip { index: usize },
}
33
34impl From<SlabAllocatorError> for CullingError {
35    fn from(source: SlabAllocatorError) -> Self {
36        CullingError::SlabError { source }
37    }
38}
39
/// Computes frustum and occlusion culling on the GPU.
///
/// Holds the culling compute pipeline, handles to the three slab buffers it
/// binds, and the [`ComputeDepthPyramid`] that is run at the start of every
/// culling pass.
pub struct ComputeCulling {
    /// Compute pipeline built from `crate::linkage::compute_culling`.
    pipeline: wgpu::ComputePipeline,

    /// Depth pyramid slab buffer, bound at binding 1.
    pyramid_slab_buffer: SlabBuffer<wgpu::Buffer>,
    /// Stage slab buffer, bound at binding 0.
    stage_slab_buffer: SlabBuffer<wgpu::Buffer>,
    /// Indirect draw buffer, bound at binding 2 (the only writable binding).
    indirect_slab_buffer: SlabBuffer<wgpu::Buffer>,

    /// Layout used for every bindgroup this struct creates.
    bindgroup_layout: wgpu::BindGroupLayout,
    /// Current bindgroup; rebuilt in `run` when a bound slab buffer was
    /// invalidated.
    bindgroup: ManagedBindGroup,

    /// Builds the depth pyramid from the depth texture at the start of `run`.
    pub(crate) compute_depth_pyramid: ComputeDepthPyramid,
}
53
54impl ComputeCulling {
55    const LABEL: Option<&'static str> = Some("compute-culling");
56
57    fn new_bindgroup(
58        stage_slab_buffer: &wgpu::Buffer,
59        hzb_slab_buffer: &wgpu::Buffer,
60        indirect_buffer: &wgpu::Buffer,
61        layout: &wgpu::BindGroupLayout,
62        device: &wgpu::Device,
63    ) -> wgpu::BindGroup {
64        device.create_bind_group(&wgpu::BindGroupDescriptor {
65            label: Self::LABEL,
66            layout,
67            entries: &[
68                wgpu::BindGroupEntry {
69                    binding: 0,
70                    resource: wgpu::BindingResource::Buffer(
71                        stage_slab_buffer.as_entire_buffer_binding(),
72                    ),
73                },
74                wgpu::BindGroupEntry {
75                    binding: 1,
76                    resource: wgpu::BindingResource::Buffer(
77                        hzb_slab_buffer.as_entire_buffer_binding(),
78                    ),
79                },
80                wgpu::BindGroupEntry {
81                    binding: 2,
82                    resource: wgpu::BindingResource::Buffer(
83                        indirect_buffer.as_entire_buffer_binding(),
84                    ),
85                },
86            ],
87        })
88    }
89
90    pub fn new(
91        runtime: impl AsRef<WgpuRuntime>,
92        stage_slab_buffer: &SlabBuffer<wgpu::Buffer>,
93        indirect_slab_buffer: &SlabBuffer<wgpu::Buffer>,
94        depth_texture: &Texture,
95    ) -> Self {
96        let runtime = runtime.as_ref();
97        let device = &runtime.device;
98        let bindgroup_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
99            label: Self::LABEL,
100            entries: &[
101                wgpu::BindGroupLayoutEntry {
102                    binding: 0,
103                    visibility: wgpu::ShaderStages::COMPUTE,
104                    ty: wgpu::BindingType::Buffer {
105                        ty: wgpu::BufferBindingType::Storage { read_only: true },
106                        has_dynamic_offset: false,
107                        min_binding_size: None,
108                    },
109                    count: None,
110                },
111                wgpu::BindGroupLayoutEntry {
112                    binding: 1,
113                    visibility: wgpu::ShaderStages::COMPUTE,
114                    ty: wgpu::BindingType::Buffer {
115                        ty: wgpu::BufferBindingType::Storage { read_only: true },
116                        has_dynamic_offset: false,
117                        min_binding_size: None,
118                    },
119                    count: None,
120                },
121                wgpu::BindGroupLayoutEntry {
122                    binding: 2,
123                    visibility: wgpu::ShaderStages::COMPUTE,
124                    ty: wgpu::BindingType::Buffer {
125                        ty: wgpu::BufferBindingType::Storage { read_only: false },
126                        has_dynamic_offset: false,
127                        min_binding_size: None,
128                    },
129                    count: None,
130                },
131            ],
132        });
133        let linkage = crate::linkage::compute_culling::linkage(device);
134        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
135            label: Self::LABEL,
136            layout: Some(
137                &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
138                    label: Self::LABEL,
139                    bind_group_layouts: &[&bindgroup_layout],
140                    push_constant_ranges: &[],
141                }),
142            ),
143            module: &linkage.module,
144            entry_point: Some(linkage.entry_point),
145            compilation_options: wgpu::PipelineCompilationOptions::default(),
146            cache: None,
147        });
148        let compute_depth_pyramid = ComputeDepthPyramid::new(runtime, depth_texture);
149        let pyramid_slab_buffer = compute_depth_pyramid
150            .compute_copy_depth
151            .pyramid_slab_buffer
152            .clone();
153        let bindgroup = Self::new_bindgroup(
154            stage_slab_buffer,
155            &pyramid_slab_buffer,
156            indirect_slab_buffer,
157            &bindgroup_layout,
158            device,
159        );
160        Self {
161            pipeline,
162            bindgroup_layout,
163            bindgroup: ManagedBindGroup::from(bindgroup),
164            compute_depth_pyramid,
165            pyramid_slab_buffer,
166            stage_slab_buffer: stage_slab_buffer.clone(),
167            indirect_slab_buffer: indirect_slab_buffer.clone(),
168        }
169    }
170
171    fn runtime(&self) -> &WgpuRuntime {
172        self.compute_depth_pyramid.depth_pyramid.slab.runtime()
173    }
174
175    pub fn run(&mut self, indirect_draw_count: u32, depth_texture: &Texture) {
176        log::trace!(
177            "indirect_draw_count: {indirect_draw_count}, sample_count: {}",
178            depth_texture.texture.sample_count()
179        );
180        // Compute the depth pyramid from last frame's depth buffer
181        self.compute_depth_pyramid.run(depth_texture);
182
183        let stage_slab_invalid = self.stage_slab_buffer.update_if_invalid();
184        let indirect_slab_invalid = self.indirect_slab_buffer.update_if_invalid();
185        let pyramid_slab_invalid = self.pyramid_slab_buffer.update_if_invalid();
186        let should_recreate_bindgroup =
187            stage_slab_invalid || indirect_slab_invalid || pyramid_slab_invalid;
188        log::trace!("stage_slab_invalid: {stage_slab_invalid}");
189        log::trace!("indirect_slab_invalid: {indirect_slab_invalid}");
190        log::trace!("pyramid_slab_invalid: {pyramid_slab_invalid}");
191        let bindgroup = self.bindgroup.get(should_recreate_bindgroup, || {
192            log::debug!("recreating compute-culling bindgroup");
193            Self::new_bindgroup(
194                &self.stage_slab_buffer,
195                &self.pyramid_slab_buffer,
196                &self.indirect_slab_buffer,
197                &self.bindgroup_layout,
198                self.compute_depth_pyramid.depth_pyramid.slab.device(),
199            )
200        });
201        let runtime = self.runtime();
202        let mut encoder = runtime
203            .device
204            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Self::LABEL });
205        {
206            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
207                label: Self::LABEL,
208                timestamp_writes: None,
209            });
210            compute_pass.set_pipeline(&self.pipeline);
211            compute_pass.set_bind_group(0, Some(bindgroup.as_ref()), &[]);
212            compute_pass.dispatch_workgroups(indirect_draw_count / 16 + 1, 1, 1);
213        }
214        runtime.queue.submit(Some(encoder.finish()));
215    }
216}
217
/// A chain of progressively smaller depth images ("mips") stored in a slab.
pub struct DepthPyramid {
    /// Allocator that owns the pyramid's descriptor and mip data.
    slab: SlabAllocator<WgpuRuntime>,
    /// Descriptor holding the pyramid size, the mip level currently being
    /// operated on, and the array of mips.
    desc: Hybrid<DepthPyramidDescriptor>,
    /// GPU-side array whose elements are the per-mip depth arrays.
    mip: GpuArray<Array<f32>>,
    /// One GPU-only `f32` array per mip level, in order of increasing level.
    mip_data: Vec<GpuArray<f32>>,
}
224
225impl DepthPyramid {
226    const LABEL: &str = "depth-pyramid";
227
228    fn allocate(
229        size: UVec2,
230        desc: &Hybrid<DepthPyramidDescriptor>,
231        slab: &SlabAllocator<WgpuRuntime>,
232    ) -> (Vec<GpuArray<f32>>, GpuArray<Array<f32>>) {
233        let mip_levels = size.min_element().ilog2();
234        let mip_data = (0..mip_levels)
235            .map(|i| {
236                let width = size.x >> i;
237                let height = size.y >> i;
238                slab.new_array(vec![0f32; (width * height) as usize])
239                    .into_gpu_only()
240            })
241            .collect::<Vec<_>>();
242        let mip = slab.new_array(mip_data.iter().map(|m| m.array()));
243        desc.set(DepthPyramidDescriptor {
244            size,
245            mip_level: 0,
246            mip: mip.array(),
247        });
248        (mip_data, mip.into_gpu_only())
249    }
250
251    pub fn new(runtime: impl AsRef<WgpuRuntime>, size: UVec2) -> Self {
252        let slab = SlabAllocator::new(runtime, Self::LABEL, wgpu::BufferUsages::empty());
253        let desc = slab.new_value(DepthPyramidDescriptor::default());
254        let (mip_data, mip) = Self::allocate(size, &desc, &slab);
255
256        Self {
257            slab,
258            desc,
259            mip_data,
260            mip,
261        }
262    }
263
    /// Replaces the pyramid's storage with freshly allocated mips for `size`.
    ///
    /// The existing mip arrays are dropped and the slab is committed before
    /// the new ones are allocated; the slab is committed again afterwards.
    pub fn resize(&mut self, size: UVec2) {
        log::trace!("resizing depth pyramid to {size}");
        // drop the buffers
        // Point the descriptor at an empty array first so it does not refer to
        // the mips that are about to be dropped.
        let mip = self.slab.new_array(vec![]);
        self.mip_data = vec![];
        self.desc.modify(|desc| desc.mip = mip.array());
        self.mip = mip.into_gpu_only();

        // Reclaim the dropped buffer slots
        self.slab.commit();

        // Reallocate
        let (mip_data, mip) = Self::allocate(size, &self.desc, &self.slab);
        self.mip_data = mip_data;
        self.mip = mip;

        // Run upkeep one more time to sync the resize
        self.slab.commit();
    }
283
284    pub fn size(&self) -> UVec2 {
285        self.desc.get().size
286    }
287
288    pub async fn read_images(&self) -> Result<Vec<image::GrayImage>, CullingError> {
289        let size = self.size();
290        let slab_data = self.slab.read(0..).await?;
291        let mut images = vec![];
292        let mut min = f32::MAX;
293        let mut max = f32::MIN;
294        for (i, mip) in self.mip_data.iter().enumerate() {
295            let depth_data: Vec<u8> = slab_data
296                .read_vec(mip.array())
297                .into_iter()
298                .map(|depth: f32| {
299                    if i == 0 {
300                        min = min.min(depth);
301                        max = max.max(depth);
302                    }
303                    crate::color::f32_to_u8(depth)
304                })
305                .collect();
306            log::trace!("min: {min}");
307            log::trace!("max: {max}");
308            let width = size.x >> i;
309            let height = size.y >> i;
310            let image = image::GrayImage::from_raw(width, height, depth_data)
311                .context(ReadMipSnafu { index: i })?;
312            images.push(image);
313        }
314        Ok(images)
315    }
316}
317
/// Copies the depth texture to the top of the depth pyramid.
struct ComputeCopyDepth {
    /// Copy pipeline; either the multisampled or the single-sampled variant.
    pipeline: wgpu::ComputePipeline,
    /// Layout matching the current `sample_count`.
    bindgroup_layout: wgpu::BindGroupLayout,
    /// Sample count of the depth texture the layout and pipeline were built
    /// for.
    sample_count: u32,
    /// Handle to the depth pyramid's slab buffer, bound at binding 0.
    pyramid_slab_buffer: SlabBuffer<wgpu::Buffer>,
    /// Bindgroup holding the pyramid slab and the depth texture view.
    bindgroup: ManagedBindGroup,
}
326
327impl ComputeCopyDepth {
328    const LABEL: Option<&'static str> = Some("compute-occlusion-copy-depth-to-pyramid");
329
330    fn create_bindgroup_layout(device: &wgpu::Device, sample_count: u32) -> wgpu::BindGroupLayout {
331        if sample_count > 1 {
332            log::trace!(
333                "creating bindgroup layout with {sample_count} multisampled depth for {}",
334                Self::LABEL.unwrap()
335            );
336        } else {
337            log::trace!(
338                "creating bindgroup layout without multisampling for {}",
339                Self::LABEL.unwrap()
340            );
341        }
342        device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
343            label: Self::LABEL,
344            entries: &[
345                // slab
346                wgpu::BindGroupLayoutEntry {
347                    binding: 0,
348                    visibility: wgpu::ShaderStages::COMPUTE,
349                    ty: wgpu::BindingType::Buffer {
350                        ty: wgpu::BufferBindingType::Storage { read_only: false },
351                        has_dynamic_offset: false,
352                        min_binding_size: None,
353                    },
354                    count: None,
355                },
356                // previous_mip: DepthPyramidImage
357                wgpu::BindGroupLayoutEntry {
358                    binding: 1,
359                    visibility: wgpu::ShaderStages::COMPUTE,
360                    ty: wgpu::BindingType::Texture {
361                        sample_type: wgpu::TextureSampleType::Depth,
362                        view_dimension: wgpu::TextureViewDimension::D2,
363                        multisampled: sample_count > 1,
364                    },
365                    count: None,
366                },
367            ],
368        })
369    }
370
371    fn create_pipeline(
372        device: &wgpu::Device,
373        bindgroup_layout: &wgpu::BindGroupLayout,
374        multisampled: bool,
375    ) -> wgpu::ComputePipeline {
376        let linkage = if multisampled {
377            log::trace!("creating multisampled shader for {}", Self::LABEL.unwrap());
378            crate::linkage::compute_copy_depth_to_pyramid_multisampled::linkage(device)
379        } else {
380            log::trace!(
381                "creating shader without multisampling for {}",
382                Self::LABEL.unwrap()
383            );
384            crate::linkage::compute_copy_depth_to_pyramid::linkage(device)
385        };
386        device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
387            label: Self::LABEL,
388            layout: Some(
389                &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
390                    label: Self::LABEL,
391                    bind_group_layouts: &[bindgroup_layout],
392                    push_constant_ranges: &[],
393                }),
394            ),
395            module: &linkage.module,
396            entry_point: Some(linkage.entry_point),
397            compilation_options: wgpu::PipelineCompilationOptions::default(),
398            cache: None,
399        })
400    }
401
402    fn create_bindgroup(
403        device: &wgpu::Device,
404        layout: &wgpu::BindGroupLayout,
405        pyramid_desc_buffer: &wgpu::Buffer,
406        depth_texture_view: &wgpu::TextureView,
407    ) -> wgpu::BindGroup {
408        device.create_bind_group(&wgpu::BindGroupDescriptor {
409            label: Self::LABEL,
410            layout,
411            entries: &[
412                wgpu::BindGroupEntry {
413                    binding: 0,
414                    resource: wgpu::BindingResource::Buffer(
415                        pyramid_desc_buffer.as_entire_buffer_binding(),
416                    ),
417                },
418                wgpu::BindGroupEntry {
419                    binding: 1,
420                    resource: wgpu::BindingResource::TextureView(depth_texture_view),
421                },
422            ],
423        })
424    }
425
426    pub fn new(depth_pyramid: &DepthPyramid, depth_texture: &Texture) -> Self {
427        let device = depth_pyramid.slab.device();
428        let sample_count = depth_texture.texture.sample_count();
429        let bindgroup_layout = Self::create_bindgroup_layout(device, sample_count);
430        let pipeline = Self::create_pipeline(device, &bindgroup_layout, sample_count > 1);
431        let pyramid_slab_buffer = depth_pyramid.slab.commit();
432        let buffer = Self::create_bindgroup(
433            device,
434            &bindgroup_layout,
435            &pyramid_slab_buffer,
436            &depth_texture.view,
437        );
438        Self {
439            pipeline,
440            bindgroup: ManagedBindGroup::from(buffer),
441            bindgroup_layout,
442            pyramid_slab_buffer,
443            sample_count,
444        }
445    }
446
447    pub fn run(&mut self, pyramid: &mut DepthPyramid, depth_texture: &Texture) {
448        let _ = pyramid.desc.modify(|desc| {
449            desc.mip_level = 0;
450            desc.size
451        });
452
453        let runtime = pyramid.slab.runtime().clone();
454        let sample_count = depth_texture.texture.sample_count();
455        let sample_count_mismatch = sample_count != self.sample_count;
456        if sample_count_mismatch {
457            log::debug!(
458                "sample count changed from {} to {}, updating {} bindgroup layout and pipeline",
459                self.sample_count,
460                sample_count,
461                Self::LABEL.unwrap()
462            );
463            self.sample_count = sample_count;
464            self.bindgroup_layout = Self::create_bindgroup_layout(&runtime.device, sample_count);
465            self.pipeline =
466                Self::create_pipeline(&runtime.device, &self.bindgroup_layout, sample_count > 1);
467        }
468
469        let extent = depth_texture.texture.size();
470        let size = UVec2::new(extent.width, extent.height);
471        let size_changed = size != pyramid.size();
472        if size_changed {
473            pyramid.resize(size);
474        }
475
476        // TODO: check if we need to upkeep the depth pyramid slab here.
477        let _ = pyramid.slab.commit();
478        let should_recreate_bindgroup =
479            self.pyramid_slab_buffer.update_if_invalid() || sample_count_mismatch || size_changed;
480        let bindgroup = self.bindgroup.get(should_recreate_bindgroup, || {
481            Self::create_bindgroup(
482                &runtime.device,
483                &self.bindgroup_layout,
484                &self.pyramid_slab_buffer,
485                &depth_texture.view,
486            )
487        });
488
489        let mut encoder = runtime
490            .device
491            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Self::LABEL });
492        {
493            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
494                label: Self::LABEL,
495                ..Default::default()
496            });
497            compute_pass.set_pipeline(&self.pipeline);
498            compute_pass.set_bind_group(0, Some(bindgroup.as_ref()), &[]);
499            let x = size.x / 16 + 1;
500            let y = size.y / 16 + 1;
501            let z = 1;
502            compute_pass.dispatch_workgroups(x, y, z);
503        }
504        pyramid.slab.queue().submit(Some(encoder.finish()));
505    }
506}
507
/// Downsamples the depth texture from one mip to the next.
struct ComputeDownsampleDepth {
    /// Pipeline built from `crate::linkage::compute_downsample_depth_pyramid`.
    pipeline: wgpu::ComputePipeline,
    /// Handle to the depth pyramid's slab buffer, the only bound resource.
    pyramid_slab_buffer: SlabBuffer<wgpu::Buffer>,
    /// Bindgroup over `pyramid_slab_buffer`; rebuilt in `run` when that buffer
    /// was invalidated.
    bindgroup: wgpu::BindGroup,
    /// Layout with a single writable storage buffer at binding 0.
    bindgroup_layout: wgpu::BindGroupLayout,
}
515
516impl ComputeDownsampleDepth {
517    const LABEL: Option<&'static str> = Some("compute-occlusion-downsample-depth");
518
519    fn create_bindgroup_layout(device: &wgpu::Device) -> wgpu::BindGroupLayout {
520        device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
521            label: Self::LABEL,
522            entries: &[
523                // slab
524                wgpu::BindGroupLayoutEntry {
525                    binding: 0,
526                    visibility: wgpu::ShaderStages::COMPUTE,
527                    ty: wgpu::BindingType::Buffer {
528                        ty: wgpu::BufferBindingType::Storage { read_only: false },
529                        has_dynamic_offset: false,
530                        min_binding_size: None,
531                    },
532                    count: None,
533                },
534            ],
535        })
536    }
537
538    fn create_pipeline(
539        device: &wgpu::Device,
540        bindgroup_layout: &wgpu::BindGroupLayout,
541    ) -> wgpu::ComputePipeline {
542        let linkage = crate::linkage::compute_downsample_depth_pyramid::linkage(device);
543        device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
544            label: Self::LABEL,
545            layout: Some(
546                &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
547                    label: Self::LABEL,
548                    bind_group_layouts: &[bindgroup_layout],
549                    push_constant_ranges: &[],
550                }),
551            ),
552            module: &linkage.module,
553            entry_point: Some(linkage.entry_point),
554            compilation_options: wgpu::PipelineCompilationOptions::default(),
555            cache: None,
556        })
557    }
558
559    fn create_bindgroup(
560        device: &wgpu::Device,
561        layout: &wgpu::BindGroupLayout,
562        pyramid_desc_buffer: &wgpu::Buffer,
563    ) -> wgpu::BindGroup {
564        device.create_bind_group(&wgpu::BindGroupDescriptor {
565            label: Self::LABEL,
566            layout,
567            entries: &[wgpu::BindGroupEntry {
568                binding: 0,
569                resource: wgpu::BindingResource::Buffer(
570                    pyramid_desc_buffer.as_entire_buffer_binding(),
571                ),
572            }],
573        })
574    }
575
576    pub fn new(pyramid: &DepthPyramid) -> Self {
577        let device = pyramid.slab.device();
578        let bindgroup_layout = Self::create_bindgroup_layout(device);
579        let pipeline = Self::create_pipeline(device, &bindgroup_layout);
580        let pyramid_slab_buffer = pyramid.slab.commit();
581        let bindgroup = Self::create_bindgroup(device, &bindgroup_layout, &pyramid_slab_buffer);
582        Self {
583            pipeline,
584            bindgroup,
585            bindgroup_layout,
586            pyramid_slab_buffer,
587        }
588    }
589
590    pub fn run(&mut self, pyramid: &DepthPyramid) {
591        let device = pyramid.slab.device();
592
593        if self.pyramid_slab_buffer.update_if_invalid() {
594            self.bindgroup =
595                Self::create_bindgroup(device, &self.bindgroup_layout, &self.pyramid_slab_buffer);
596        }
597
598        for i in 1..pyramid.mip_data.len() {
599            log::trace!("downsampling to mip {i}..{}", pyramid.mip_data.len());
600            // Update the mip_level we're operating on.
601            let size = pyramid.desc.modify(|desc| {
602                desc.mip_level = i as u32;
603                desc.size
604            });
605            // Sync the change.
606            pyramid.slab.commit();
607            debug_assert!(
608                self.pyramid_slab_buffer.is_valid(),
609                "pyramid slab should never resize here"
610            );
611
612            let mut encoder = device
613                .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Self::LABEL });
614            {
615                let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
616                    label: Self::LABEL,
617                    ..Default::default()
618                });
619                compute_pass.set_pipeline(&self.pipeline);
620                compute_pass.set_bind_group(0, Some(&self.bindgroup), &[]);
621                let w = size.x >> i;
622                let h = size.y >> i;
623                let x = w / 16 + 1;
624                let y = h / 16 + 1;
625                let z = 1;
626                compute_pass.dispatch_workgroups(x, y, z);
627            }
628            pyramid.slab.queue().submit(Some(encoder.finish()));
629        }
630    }
631}
632
/// Builds the depth pyramid on the GPU.
///
/// Owns the [`DepthPyramid`] together with the two compute steps that fill
/// it: a copy of the depth texture into mip 0, followed by downsampling into
/// the remaining mips.
pub struct ComputeDepthPyramid {
    /// The pyramid storage that both steps operate on.
    pub(crate) depth_pyramid: DepthPyramid,
    /// Copies the depth texture into the top mip.
    compute_copy_depth: ComputeCopyDepth,
    /// Fills every mip after the first.
    compute_downsample_depth: ComputeDownsampleDepth,
}
639
640impl ComputeDepthPyramid {
641    const _LABEL: Option<&'static str> = Some("compute-depth-pyramid");
642
643    pub fn new(runtime: impl AsRef<WgpuRuntime>, depth_texture: &Texture) -> Self {
644        let runtime = runtime.as_ref();
645        let depth_pyramid = DepthPyramid::new(runtime, depth_texture.size());
646        let compute_copy_depth = ComputeCopyDepth::new(&depth_pyramid, depth_texture);
647        let compute_downsample_depth = ComputeDownsampleDepth::new(&depth_pyramid);
648        Self {
649            depth_pyramid,
650            compute_copy_depth,
651            compute_downsample_depth,
652        }
653    }
654
655    /// Run depth pyramid copy and downsampling, then return the updated HZB buffer.
656    pub fn run(&mut self, depth_texture: &Texture) {
657        let extent = depth_texture.texture.size();
658        let size = UVec2::new(extent.width, extent.height);
659        if size != self.depth_pyramid.size() {
660            log::debug!("depth texture size changed");
661            self.depth_pyramid.resize(size);
662        }
663
664        self.compute_copy_depth
665            .run(&mut self.depth_pyramid, depth_texture);
666
667        self.compute_downsample_depth.run(&self.depth_pyramid);
668    }
669}
670
671#[cfg(test)]
672mod test {
673    use std::collections::HashMap;
674
675    use crate::{
676        bvol::BoundingSphere,
677        context::Context,
678        cull::shader::DepthPyramidDescriptor,
679        draw::DrawIndirectArgs,
680        geometry::{Geometry, Vertex},
681        math::hex_to_vec4,
682        primitive::shader::PrimitiveDescriptor,
683        test::BlockOnFuture,
684    };
685    use crabslab::{Array, GrowableSlab, Id, Slab};
686    use glam::{Mat4, Quat, UVec2, UVec3, Vec2, Vec3, Vec4};
687
    /// Renders a single cube twice and saves the frame, the depth texture and
    /// every depth pyramid mip as images under `cull/pyramid/`.
    #[test]
    fn occlusion_culling_sanity() {
        let ctx = Context::headless(100, 100).block();
        let stage = ctx.new_stage().with_background_color(Vec4::splat(1.0));
        let camera_position = Vec3::new(0.0, 9.0, 9.0);
        let _camera = stage.new_camera().with_projection_and_view(
            Mat4::perspective_rh(std::f32::consts::PI / 4.0, 1.0, 1.0, 24.0),
            Mat4::look_at_rh(camera_position, Vec3::ZERO, Vec3::Y),
        );
        // A scaled, rotated cube at the origin is the only geometry.
        let _rez = stage
            .new_primitive()
            .with_vertices(stage.new_vertices(crate::test::gpu_cube_vertices()))
            .with_transform(
                stage
                    .new_transform()
                    .with_scale(Vec3::new(6.0, 6.0, 6.0))
                    .with_rotation(Quat::from_axis_angle(Vec3::Y, -std::f32::consts::FRAC_PI_4)),
            );

        // First frame: rendered and presented without being saved.
        let frame = ctx.get_next_frame().unwrap();
        stage.render(&frame.view());
        frame.present();

        // Second frame: this is the one saved to disk.
        let frame = ctx.get_next_frame().unwrap();
        stage.render(&frame.view());
        let img = frame.read_image().block().unwrap();
        img_diff::save("cull/pyramid/frame.png", img);
        frame.present();

        let depth_texture = stage.get_depth_texture();
        let depth_img = depth_texture.read_image().block().unwrap().unwrap();
        img_diff::save("cull/pyramid/depth.png", depth_img);

        // Read the pyramid mips back through the indirect drawing strategy.
        let pyramid_images = futures_lite::future::block_on(
            stage
                .draw_calls
                .read()
                .unwrap()
                .drawing_strategy
                .as_indirect()
                .unwrap()
                .read_hzb_images(),
        )
        .unwrap();
        for (i, img) in pyramid_images.into_iter().enumerate() {
            img_diff::save(format!("cull/pyramid/mip_{i}.png"), img);
        }
    }
736
737    #[test]
738    fn depth_pyramid_descriptor_sanity() {
739        let mut slab = vec![];
740        let size = UVec2::new(64, 32);
741        let mip_levels = size.min_element().ilog2();
742        let desc_id = slab.allocate::<DepthPyramidDescriptor>();
743        let mips_array = slab.allocate_array::<Array<f32>>(mip_levels as usize);
744        let mip_data_arrays = (0..mip_levels)
745            .map(|i| {
746                let w = size.x >> i;
747                let h = size.y >> i;
748                let len = (w * h) as usize;
749                log::info!("allocating {len} f32s for mip {i} of size {w}x{h}");
750                let array = slab.allocate_array::<f32>(len);
751                let mut data: Vec<f32> = vec![];
752                for _y in 0..h {
753                    for x in 0..w {
754                        data.push(x as f32);
755                    }
756                }
757                slab.write_array(array, &data);
758                array
759            })
760            .collect::<Vec<_>>();
761        slab.write_array(mips_array, &mip_data_arrays);
762        let desc = DepthPyramidDescriptor {
763            size: UVec2::new(64, 32),
764            mip_level: 0,
765            mip: mips_array,
766        };
767        slab.write(desc_id, &desc);
768
769        // Test that `id_of_depth` returns the correct value.
770        for mip_level in 0..mip_levels {
771            let size = desc.size_at(mip_level);
772            log::info!("mip {mip_level} is size {size}");
773            for y in 0..size.y {
774                for x in 0..size.x {
775                    let id = desc.id_of_depth(mip_level, UVec2::new(x, y), &slab);
776                    let depth = slab.read(id);
777                    assert_eq!(x as f32, depth, "depth should be x value");
778                }
779            }
780        }
781    }
782
783    #[test]
784    fn occlusion_culling_debugging() {
785        let ctx = Context::headless(128, 128).block();
786        let stage = ctx
787            .new_stage()
788            .with_lighting(false)
789            .with_bloom(false)
790            .with_background_color(Vec4::splat(1.0));
791        let _camera = {
792            let fovy = std::f32::consts::FRAC_PI_4;
793            let aspect = 1.0;
794            let znear = 0.1;
795            let zfar = 100.0;
796            let projection = Mat4::perspective_rh(fovy, aspect, znear, zfar);
797            // Camera is looking straight down Z, towards the origin with Y up
798            let view = Mat4::look_at_rh(Vec3::new(0.0, 0.0, 10.0), Vec3::ZERO, Vec3::Y);
799            stage
800                .new_camera()
801                .with_projection_and_view(projection, view)
802        };
803
804        let save_render = |s: &str| {
805            let frame = ctx.get_next_frame().unwrap();
806            stage.render(&frame.view());
807            let img = frame.read_image().block().unwrap();
808            img_diff::save(format!("cull/debugging_{s}.png"), img);
809            frame.present();
810        };
811
812        // A hashmap to hold renderlet ids to their names.
813        let mut names = HashMap::<Id<PrimitiveDescriptor>, String>::default();
814
815        // Add four yellow cubes in each corner
816        let _ycubes = [
817            (Vec2::new(-1.0, 1.0), "top_left"),
818            (Vec2::new(1.0, 1.0), "top_right"),
819            (Vec2::new(1.0, -1.0), "bottom_right"),
820            (Vec2::new(-1.0, -1.0), "bottom_left"),
821        ]
822        .map(|(offset, suffix)| {
823            let yellow = hex_to_vec4(0xFFE6A5FF);
824            let renderlet = stage
825                .new_primitive()
826                .with_transform(
827                    stage
828                        .new_transform()
829                        // move it back behind the purple cube
830                        .with_translation((offset * 10.0).extend(-20.0))
831                        // scale it up since it's a unit cube
832                        .with_scale(Vec3::splat(2.0)),
833                )
834                .with_vertices(stage.new_vertices(crate::math::unit_cube().into_iter().map(
835                    |(p, n)| {
836                        Vertex::default()
837                            .with_position(p)
838                            .with_normal(n)
839                            .with_color(yellow)
840                    },
841                )))
842                .with_bounds(BoundingSphere::new(Vec3::ZERO, Vec3::splat(0.5).length()));
843            names.insert(renderlet.id(), format!("yellow_cube_{suffix}"));
844            renderlet
845        });
846
847        save_render("0_yellow_cubes");
848
849        // We'll add a golden floor
850        let _floor = {
851            let golden = hex_to_vec4(0xFFBF61FF);
852            let renderlet = stage
853                .new_primitive()
854                .with_transform(
855                    stage
856                        .new_transform()
857                        // flip it so it's facing up, like a floor should be
858                        .with_rotation(Quat::from_rotation_x(std::f32::consts::FRAC_PI_2))
859                        // move it down and back a bit
860                        .with_translation(Vec3::new(0.0, -5.0, -10.0))
861                        // scale it up, since it's a unit quad
862                        .with_scale(Vec3::new(100.0, 100.0, 1.0)),
863                )
864                .with_vertices(
865                    stage.new_vertices(
866                        crate::math::UNIT_QUAD_CCW
867                            .map(|p| Vertex::default().with_position(p).with_color(golden)),
868                    ),
869                )
870                .with_bounds(BoundingSphere::new(Vec3::ZERO, Vec2::splat(0.5).length()));
871            names.insert(renderlet.id(), "floor".into());
872            renderlet
873        };
874
875        save_render("1_floor");
876
877        // Add a green cube
878        let _gcube = {
879            let green = hex_to_vec4(0x8ABFA3FF);
880            let renderlet = stage
881                .new_primitive()
882                .with_transform(
883                    stage
884                        .new_transform()
885                        // move it back behind the purple cube
886                        .with_translation(Vec3::new(0.0, 0.0, -10.0))
887                        // scale it up since it's a unit cube
888                        .with_scale(Vec3::splat(5.0)),
889                )
890                .with_vertices(stage.new_vertices(crate::math::unit_cube().into_iter().map(
891                    |(p, n)| {
892                        Vertex::default()
893                            .with_position(p)
894                            .with_normal(n)
895                            .with_color(green)
896                    },
897                )))
898                .with_bounds(BoundingSphere::new(Vec3::ZERO, Vec3::splat(0.5).length()));
899            stage.add_primitive(&renderlet);
900            names.insert(renderlet.id(), "green_cube".into());
901            renderlet
902        };
903
904        save_render("2_green_cube");
905
906        // And a purple cube
907        let _pcube = {
908            let purple = hex_to_vec4(0x605678FF);
909            let renderlet = stage
910                .new_primitive()
911                .with_transform(
912                    stage
913                        .new_transform()
914                        // move it back a bit
915                        .with_translation(Vec3::new(0.0, 0.0, -3.0))
916                        // scale it up since it's a unit cube
917                        .with_scale(Vec3::splat(5.0)),
918                )
919                .with_vertices(stage.new_vertices(crate::math::unit_cube().into_iter().map(
920                    |(p, n)| {
921                        Vertex::default()
922                            .with_position(p)
923                            .with_normal(n)
924                            .with_color(purple)
925                    },
926                )))
927                .with_bounds(BoundingSphere::new(Vec3::ZERO, Vec3::splat(0.5).length()));
928            names.insert(renderlet.id(), "purple_cube".into());
929            renderlet
930        };
931
932        // Do two renders, because depth pyramid operates on depth data one frame
933        // behind
934        save_render("3_purple_cube");
935        save_render("3_purple_cube");
936
937        // save the normalized depth image
938        let mut depth_img = stage
939            .get_depth_texture()
940            .read_image()
941            .block()
942            .unwrap()
943            .unwrap();
944        img_diff::normalize_gray_img(&mut depth_img);
945        img_diff::save("cull/debugging_4_depth.png", depth_img);
946
947        // save the normalized pyramid images
948        let pyramid_images = futures_lite::future::block_on(
949            stage
950                .draw_calls
951                .read()
952                .unwrap()
953                .drawing_strategy
954                .as_indirect()
955                .unwrap()
956                .read_hzb_images(),
957        )
958        .unwrap();
959        for (i, mut img) in pyramid_images.into_iter().enumerate() {
960            img_diff::normalize_gray_img(&mut img);
961            img_diff::save(format!("cull/debugging_pyramid_mip_{i}.png"), img);
962        }
963
964        // The stage's slab, which contains the `Renderlet`s and their `BoundingSphere`s
965        let stage_slab = futures_lite::future::block_on({
966            let geometry: &Geometry = stage.as_ref();
967            geometry.slab_allocator().read(..)
968        })
969        .unwrap();
970        let draw_calls = stage.draw_calls.read().unwrap();
971        let indirect_draws = draw_calls.drawing_strategy.as_indirect().unwrap();
972        // The HZB slab, which contains a `DepthPyramidDescriptor` at index 0, and all the
973        // pyramid's mips
974        let depth_pyramid_slab = futures_lite::future::block_on(
975            indirect_draws
976                .compute_culling
977                .compute_depth_pyramid
978                .depth_pyramid
979                .slab
980                .read(..),
981        )
982        .unwrap();
983        // The indirect draw buffer
984        let mut args_slab = futures_lite::future::block_on(indirect_draws.slab.read(..)).unwrap();
985        let args: &mut [DrawIndirectArgs] = bytemuck::cast_slice_mut(&mut args_slab);
986        // Number of `DrawIndirectArgs` in the `args` buffer.
987        let num_draw_calls = draw_calls.draw_count();
988
989        // Print our names so we know what we're working with
990        let mut pnames = names.iter().collect::<Vec<_>>();
991        pnames.sort();
992        for (id, name) in pnames.into_iter() {
993            log::info!("id: {id:?}, name: {name}");
994        }
995
996        for i in 0..num_draw_calls as u32 {
997            let renderlet_id = Id::<PrimitiveDescriptor>::new(args[i as usize].first_instance);
998            let name = names.get(&renderlet_id).unwrap();
999            if name != "green_cube" {
1000                continue;
1001            }
1002            log::info!("");
1003            log::info!("name: {name}");
1004            crate::cull::shader::compute_culling(
1005                &stage_slab,
1006                &depth_pyramid_slab,
1007                args,
1008                UVec3::new(i, 0, 0),
1009            );
1010        }
1011    }
1012}