renderling/
convolution.rs

1//! Convolution shaders.
2//!
3//! These shaders convolve various functions to produce cached maps.
4
5pub mod shader {
6    //! Shader side of convolution.
7    use crabslab::{Id, Slab, SlabItem};
8    use glam::{Vec2, Vec3, Vec4, Vec4Swizzles};
9    use spirv_std::{
10        image::{Cubemap, Image2d},
11        num_traits::Zero,
12        spirv, Sampler,
13    };
14
15    #[allow(unused_imports)]
16    use spirv_std::num_traits::Float;
17
18    use crate::{camera::shader::CameraDescriptor, math::IsVector};
19
20    // Allow manual bit rotation because this code is `no_std`.
21    #[allow(clippy::manual_rotate)]
22    fn radical_inverse_vdc(mut bits: u32) -> f32 {
23        bits = (bits << 16u32) | (bits >> 16u32);
24        bits = ((bits & 0x55555555u32) << 1u32) | ((bits & 0xAAAAAAAAu32) >> 1u32);
25        bits = ((bits & 0x33333333u32) << 2u32) | ((bits & 0xCCCCCCCCu32) >> 2u32);
26        bits = ((bits & 0x0F0F0F0Fu32) << 4u32) | ((bits & 0xF0F0F0F0u32) >> 4u32);
27        bits = ((bits & 0x00FF00FFu32) << 8u32) | ((bits & 0xFF00FF00u32) >> 8u32);
28        (bits as f32) * 2.328_306_4e-10 // / 0x100000000
29    }
30
31    fn hammersley(i: u32, n: u32) -> Vec2 {
32        Vec2::new(i as f32 / n as f32, radical_inverse_vdc(i))
33    }
34
35    fn importance_sample_ggx(xi: Vec2, n: Vec3, roughness: f32) -> Vec3 {
36        let a = roughness * roughness;
37
38        let phi = 2.0 * core::f32::consts::PI * xi.x;
39        let cos_theta = f32::sqrt((1.0 - xi.y) / (1.0 + (a * a - 1.0) * xi.y));
40        let sin_theta = f32::sqrt(1.0 - cos_theta * cos_theta);
41
42        // Convert spherical to cartesian coordinates
43        let h = Vec3::new(phi.cos() * sin_theta, phi.sin() * sin_theta, cos_theta);
44
45        // Convert tangent-space vector to world-space vector
46        let up = if n.z.abs() < 0.999 {
47            Vec3::new(0.0, 0.0, 1.0)
48        } else {
49            Vec3::new(1.0, 0.0, 0.0)
50        };
51        let tangent = up.cross(n).alt_norm_or_zero();
52        let bitangent = n.cross(tangent);
53
54        let result = tangent * h.x + bitangent * h.y + n * h.z;
55        result.alt_norm_or_zero()
56    }
57
58    fn geometry_schlick_ggx(n_dot_v: f32, roughness: f32) -> f32 {
59        let r = roughness;
60        let k = (r * r) / 2.0;
61
62        let nom = n_dot_v;
63        let denom = n_dot_v * (1.0 - k) + k;
64
65        if denom.is_zero() {
66            0.0
67        } else {
68            nom / denom
69        }
70    }
71
72    fn geometry_smith(normal: Vec3, view_dir: Vec3, light_dir: Vec3, roughness: f32) -> f32 {
73        let n_dot_v = normal.dot(view_dir).max(0.0);
74        let n_dot_l = normal.dot(light_dir).max(0.0);
75        let ggx1 = geometry_schlick_ggx(n_dot_v, roughness);
76        let ggx2 = geometry_schlick_ggx(n_dot_l, roughness);
77
78        ggx1 * ggx2
79    }
80
    /// Sample count shared by the convolution routines in this module: it
    /// bounds their sampling loops, is passed to `hammersley` as the set
    /// size, and enters the LOD calculations.
    const SAMPLE_COUNT: u32 = 1024;
82
83    pub fn integrate_brdf(mut n_dot_v: f32, roughness: f32) -> Vec2 {
84        n_dot_v = n_dot_v.max(f32::EPSILON);
85        let v = Vec3::new(f32::sqrt(1.0 - n_dot_v * n_dot_v), 0.0, n_dot_v);
86
87        let mut a = 0.0f32;
88        let mut b = 0.0f32;
89
90        let n = Vec3::Z;
91
92        for i in 1..SAMPLE_COUNT {
93            let xi = hammersley(i, SAMPLE_COUNT);
94            let h = importance_sample_ggx(xi, n, roughness);
95            let l = (2.0 * v.dot(h) * h - v).alt_norm_or_zero();
96
97            let n_dot_l = l.z.max(0.0);
98            let n_dot_h = h.z.max(0.0);
99            let v_dot_h = v.dot(h).max(0.0);
100
101            if n_dot_l > 0.0 {
102                let g = geometry_smith(n, v, l, roughness);
103                let denom = n_dot_h * n_dot_v;
104                let g_vis = (g * v_dot_h) / denom;
105                let f_c = (1.0 - v_dot_h).powf(5.0);
106
107                a += (1.0 - f_c) * g_vis;
108                b += f_c * g_vis;
109            }
110        }
111
112        a /= SAMPLE_COUNT as f32;
113        b /= SAMPLE_COUNT as f32;
114
115        Vec2::new(a, b)
116    }
117
    /// This function doesn't work on rust-gpu, presumably because of the loop.
    ///
    /// Performs the same per-sample accumulation as [`integrate_brdf`], but
    /// driven by a manual `while` loop. Because `i` is incremented before it
    /// is used, the Hammersley indices visited are `1..=SAMPLE_COUNT`.
    ///
    /// Returns the two accumulated sums, each divided by `SAMPLE_COUNT`.
    pub fn integrate_brdf_doesnt_work(mut n_dot_v: f32, roughness: f32) -> Vec2 {
        // Keep `n_dot_v` strictly positive; it is used as a divisor below.
        n_dot_v = n_dot_v.max(f32::EPSILON);
        // View vector in a frame where the normal is +Z.
        let v = Vec3::new(f32::sqrt(1.0 - n_dot_v * n_dot_v), 0.0, n_dot_v);

        let mut a = 0.0f32;
        let mut b = 0.0f32;

        let n = Vec3::Z;

        let mut i = 0u32;
        while i < SAMPLE_COUNT {
            // Incremented up front, so the first index used is 1 and the last
            // is `SAMPLE_COUNT`.
            i += 1;

            let xi = hammersley(i, SAMPLE_COUNT);
            let h = importance_sample_ggx(xi, n, roughness);
            // Reflect the view vector about the sampled half vector.
            let l = (2.0 * v.dot(h) * h - v).alt_norm_or_zero();

            // `n` is +Z, so the cosines against it are just the z components.
            let n_dot_l = l.z.max(0.0);
            let n_dot_h = h.z.max(0.0);
            let v_dot_h = v.dot(h).max(0.0);

            // Only light directions above the horizon contribute.
            if n_dot_l > 0.0 {
                let g = geometry_smith(n, v, l, roughness);
                let denom = n_dot_h * n_dot_v;
                let g_vis = (g * v_dot_h) / denom;
                let f_c = (1.0 - v_dot_h).powf(5.0);

                a += (1.0 - f_c) * g_vis;
                b += f_c * g_vis;
            }
        }

        a /= SAMPLE_COUNT as f32;
        b /= SAMPLE_COUNT as f32;

        Vec2::new(a, b)
    }
156
    /// Used by [`prefilter_environment_cubemap_vertex`] to read the camera and
    /// roughness values from the slab.
    #[derive(Clone, Copy, Default, SlabItem)]
    pub struct VertexPrefilterEnvironmentCubemapIds {
        /// Slab id of the camera whose view-projection transforms the cube
        /// vertices.
        pub camera: Id<CameraDescriptor>,
        // TODO: does this have to be an Id? Pretty sure it can be inline
        /// Slab id of the roughness value handed to the fragment stage.
        pub roughness: Id<f32>,
    }
165
166    /// Vertex shader for rendering a "prefilter environment" cubemap.
167    #[spirv(vertex)]
168    pub fn prefilter_environment_cubemap_vertex(
169        #[spirv(instance_index)] prefilter_id: Id<VertexPrefilterEnvironmentCubemapIds>,
170        #[spirv(vertex_index)] vertex_id: u32,
171        #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slab: &[u32],
172        out_pos: &mut Vec3,
173        out_roughness: &mut f32,
174        #[spirv(position)] gl_pos: &mut Vec4,
175    ) {
176        let in_pos = crate::math::CUBE[vertex_id as usize];
177        let VertexPrefilterEnvironmentCubemapIds { camera, roughness } = slab.read(prefilter_id);
178        let camera = slab.read(camera);
179        *out_roughness = slab.read(roughness);
180        *out_pos = in_pos;
181        *gl_pos = camera.view_projection() * in_pos.extend(1.0);
182    }
183
184    /// Fragment shader for rendering a "prefilter environment" cubemap.
185    ///
186    /// Lambertian prefilter.
187    #[spirv(fragment)]
188    pub fn prefilter_environment_cubemap_fragment(
189        #[spirv(descriptor_set = 0, binding = 1)] environment_cubemap: &Cubemap,
190        #[spirv(descriptor_set = 0, binding = 2)] sampler: &Sampler,
191        in_pos: Vec3,
192        in_roughness: f32,
193        frag_color: &mut Vec4,
194    ) {
195        let mut n = in_pos.alt_norm_or_zero();
196        // `wgpu` and vulkan's y coords are flipped from opengl
197        n.y *= -1.0;
198        let r = n;
199        let v = r;
200
201        let mut total_weight = 0.0f32;
202        let mut prefiltered_color = Vec3::ZERO;
203
204        for i in 0..SAMPLE_COUNT {
205            let xi = hammersley(i, SAMPLE_COUNT);
206            let h = importance_sample_ggx(xi, n, in_roughness);
207            let l = (2.0 * v.dot(h) * h - v).alt_norm_or_zero();
208
209            let n_dot_l = n.dot(l).max(0.0);
210            if n_dot_l > 0.0 {
211                let mip_level = if in_roughness == 0.0 {
212                    0.0
213                } else {
214                    calc_lod(n_dot_l)
215                };
216                prefiltered_color += environment_cubemap
217                    .sample_by_lod(*sampler, l, mip_level)
218                    .xyz()
219                    * n_dot_l;
220                total_weight += n_dot_l;
221            }
222        }
223
224        prefiltered_color /= total_weight;
225        *frag_color = prefiltered_color.extend(1.0);
226    }
227
228    pub fn calc_lod_old(n: Vec3, v: Vec3, h: Vec3, roughness: f32) -> f32 {
229        // sample from the environment's mip level based on roughness/pdf
230        let d = crate::pbr::shader::normal_distribution_ggx(n, h, roughness);
231        let n_dot_h = n.dot(h).max(0.0);
232        let h_dot_v = h.dot(v).max(0.0);
233        let pdf = (d * n_dot_h / (4.0 * h_dot_v)).max(f32::EPSILON);
234
235        let resolution = 512.0; // resolution of source cubemap (per face)
236        let sa_texel = 4.0 * core::f32::consts::PI / (6.0 * resolution * resolution);
237        let sa_sample = 1.0 / (SAMPLE_COUNT as f32 * pdf + f32::EPSILON);
238
239        0.5 * (sa_sample / sa_texel).log2()
240    }
241
242    pub fn calc_lod(n_dot_l: f32) -> f32 {
243        let cube_width = 512.0;
244        let pdf = (n_dot_l * core::f32::consts::FRAC_1_PI).max(0.0);
245        0.5 * (6.0 * cube_width * cube_width / (SAMPLE_COUNT as f32 * pdf).max(f32::EPSILON)).log2()
246    }
247
248    #[spirv(vertex)]
249    /// Vertex shader for generating texture mips.
250    pub fn generate_mipmap_vertex(
251        #[spirv(vertex_index)] vertex_id: u32,
252        out_uv: &mut Vec2,
253        #[spirv(position)] gl_pos: &mut Vec4,
254    ) {
255        let i = vertex_id as usize;
256        *out_uv = crate::math::UV_COORD_QUAD_CCW[i];
257        *gl_pos = crate::math::CLIP_SPACE_COORD_QUAD_CCW[i];
258    }
259
260    #[spirv(fragment)]
261    /// Fragment shader for generating texture mips.
262    pub fn generate_mipmap_fragment(
263        #[spirv(descriptor_set = 0, binding = 0)] texture: &Image2d,
264        #[spirv(descriptor_set = 0, binding = 1)] sampler: &Sampler,
265        in_uv: Vec2,
266        frag_color: &mut Vec4,
267    ) {
268        *frag_color = texture.sample(*sampler, in_uv);
269    }
270
    /// A vertex of the BRDF LUT quad: a position and a UV coordinate.
    #[repr(C)]
    #[derive(Clone, Copy)]
    struct Vert {
        /// Position; extended with `w = 1.0` to form the vertex shader's
        /// position output.
        pos: [f32; 3],
        /// UV coordinate handed to the fragment stage.
        uv: [f32; 2],
    }
277
278    /// A screen-space quad.
279    const BRDF_VERTS: [Vert; 6] = {
280        let bl = Vert {
281            pos: [-1.0, -1.0, 0.0],
282            uv: [0.0, 1.0],
283        };
284        let br = Vert {
285            pos: [1.0, -1.0, 0.0],
286            uv: [1.0, 1.0],
287        };
288        let tl = Vert {
289            pos: [-1.0, 1.0, 0.0],
290            uv: [0.0, 0.0],
291        };
292        let tr = Vert {
293            pos: [1.0, 1.0, 0.0],
294            uv: [1.0, 0.0],
295        };
296
297        [bl, br, tr, bl, tr, tl]
298    };
299
300    #[spirv(vertex)]
301    /// Vertex shader for creating a BRDF LUT.
302    pub fn brdf_lut_convolution_vertex(
303        #[spirv(vertex_index)] vertex_id: u32,
304        out_uv: &mut glam::Vec2,
305        #[spirv(position)] gl_pos: &mut glam::Vec4,
306    ) {
307        let Vert { pos, uv } = BRDF_VERTS[vertex_id as usize];
308        *out_uv = Vec2::from(uv);
309        *gl_pos = Vec3::from(pos).extend(1.0);
310    }
311
312    #[spirv(fragment)]
313    /// Fragment shader for creating a BRDF LUT.
314    pub fn brdf_lut_convolution_fragment(in_uv: glam::Vec2, out_color: &mut glam::Vec2) {
315        *out_color = integrate_brdf(in_uv.x, in_uv.y);
316    }
317}
318
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that the CPU BRDF integration is NaN-free at the four corners
    /// of the unit square, then renders a 32x32 image of it and compares that
    /// against the `skybox/brdf_cpu.png` reference.
    #[test]
    fn integrate_brdf_sanity() {
        for (x, y) in [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)] {
            assert!(
                !shader::integrate_brdf(x, y).is_nan(),
                "brdf is NaN at {x},{y}"
            );
        }

        let size = 32;
        let mut img = image::RgbaImage::new(size, size);
        for (x, y, pixel) in img.enumerate_pixels_mut() {
            let u = x as f32 / size as f32;
            let v = y as f32 / size as f32;
            let brdf = shader::integrate_brdf(u, v);
            // Red and green carry the two BRDF terms; blue keeps the value
            // the image was created with and alpha is fully opaque.
            pixel.0[0] = (brdf.x * 255.0) as u8;
            pixel.0[1] = (brdf.y * 255.0) as u8;
            pixel.0[3] = 255;
        }
        img_diff::assert_img_eq("skybox/brdf_cpu.png", img);
    }
}