1pub mod shader {
6 use crabslab::{Id, Slab, SlabItem};
8 use glam::{Vec2, Vec3, Vec4, Vec4Swizzles};
9 use spirv_std::{
10 image::{Cubemap, Image2d},
11 num_traits::Zero,
12 spirv, Sampler,
13 };
14
15 #[allow(unused_imports)]
16 use spirv_std::num_traits::Float;
17
18 use crate::{camera::shader::CameraDescriptor, math::IsVector};
19
/// Van der Corput radical inverse in base 2.
///
/// The 32 bits of `bits` are mirrored (bit 0 becomes bit 31 and so on) and
/// the mirrored integer is scaled by roughly `1 / 2^32`, so small inputs map
/// to `0.0, 0.5, 0.25, 0.75, ...`.
// The 16-bit swap below is written out by hand instead of as a rotate, hence
// the lint exemption.
#[allow(clippy::manual_rotate)]
fn radical_inverse_vdc(bits: u32) -> f32 {
    // Scale factor applied to the mirrored integer (approximately 2^-32).
    const SCALE: f32 = 2.328_306_4e-10;

    // Exchange the two 16-bit halves first ...
    let halves = (bits << 16u32) | (bits >> 16u32);
    // ... then neighbouring bits, bit pairs, nibbles and finally bytes.
    let singles = ((halves & 0x5555_5555u32) << 1u32) | ((halves & 0xAAAA_AAAAu32) >> 1u32);
    let pairs = ((singles & 0x3333_3333u32) << 2u32) | ((singles & 0xCCCC_CCCCu32) >> 2u32);
    let nibbles = ((pairs & 0x0F0F_0F0Fu32) << 4u32) | ((pairs & 0xF0F0_F0F0u32) >> 4u32);
    let mirrored = ((nibbles & 0x00FF_00FFu32) << 8u32) | ((nibbles & 0xFF00_FF00u32) >> 8u32);

    (mirrored as f32) * SCALE
}
30
31 fn hammersley(i: u32, n: u32) -> Vec2 {
32 Vec2::new(i as f32 / n as f32, radical_inverse_vdc(i))
33 }
34
35 fn importance_sample_ggx(xi: Vec2, n: Vec3, roughness: f32) -> Vec3 {
36 let a = roughness * roughness;
37
38 let phi = 2.0 * core::f32::consts::PI * xi.x;
39 let cos_theta = f32::sqrt((1.0 - xi.y) / (1.0 + (a * a - 1.0) * xi.y));
40 let sin_theta = f32::sqrt(1.0 - cos_theta * cos_theta);
41
42 let h = Vec3::new(phi.cos() * sin_theta, phi.sin() * sin_theta, cos_theta);
44
45 let up = if n.z.abs() < 0.999 {
47 Vec3::new(0.0, 0.0, 1.0)
48 } else {
49 Vec3::new(1.0, 0.0, 0.0)
50 };
51 let tangent = up.cross(n).alt_norm_or_zero();
52 let bitangent = n.cross(tangent);
53
54 let result = tangent * h.x + bitangent * h.y + n * h.z;
55 result.alt_norm_or_zero()
56 }
57
58 fn geometry_schlick_ggx(n_dot_v: f32, roughness: f32) -> f32 {
59 let r = roughness;
60 let k = (r * r) / 2.0;
61
62 let nom = n_dot_v;
63 let denom = n_dot_v * (1.0 - k) + k;
64
65 if denom.is_zero() {
66 0.0
67 } else {
68 nom / denom
69 }
70 }
71
72 fn geometry_smith(normal: Vec3, view_dir: Vec3, light_dir: Vec3, roughness: f32) -> f32 {
73 let n_dot_v = normal.dot(view_dir).max(0.0);
74 let n_dot_l = normal.dot(light_dir).max(0.0);
75 let ggx1 = geometry_schlick_ggx(n_dot_v, roughness);
76 let ggx2 = geometry_schlick_ggx(n_dot_l, roughness);
77
78 ggx1 * ggx2
79 }
80
/// Size of the sample set used by the integration loops in this module.
///
/// It bounds the per-pixel sample loops, is passed to `hammersley` as the set
/// size, and appears in the level-of-detail estimates.
const SAMPLE_COUNT: u32 = 1024;
82
83 pub fn integrate_brdf(mut n_dot_v: f32, roughness: f32) -> Vec2 {
84 n_dot_v = n_dot_v.max(f32::EPSILON);
85 let v = Vec3::new(f32::sqrt(1.0 - n_dot_v * n_dot_v), 0.0, n_dot_v);
86
87 let mut a = 0.0f32;
88 let mut b = 0.0f32;
89
90 let n = Vec3::Z;
91
92 for i in 1..SAMPLE_COUNT {
93 let xi = hammersley(i, SAMPLE_COUNT);
94 let h = importance_sample_ggx(xi, n, roughness);
95 let l = (2.0 * v.dot(h) * h - v).alt_norm_or_zero();
96
97 let n_dot_l = l.z.max(0.0);
98 let n_dot_h = h.z.max(0.0);
99 let v_dot_h = v.dot(h).max(0.0);
100
101 if n_dot_l > 0.0 {
102 let g = geometry_smith(n, v, l, roughness);
103 let denom = n_dot_h * n_dot_v;
104 let g_vis = (g * v_dot_h) / denom;
105 let f_c = (1.0 - v_dot_h).powf(5.0);
106
107 a += (1.0 - f_c) * g_vis;
108 b += f_c * g_vis;
109 }
110 }
111
112 a /= SAMPLE_COUNT as f32;
113 b /= SAMPLE_COUNT as f32;
114
115 Vec2::new(a, b)
116 }
117
/// Variant of `integrate_brdf` that drives the sample loop with a manual
/// `while` counter instead of a range-based `for`.
///
/// NOTE(review): the name marks this as a known-bad variant, but nothing in
/// this file says where or why it fails (presumably on the shader target;
/// confirm). The statement order is left exactly as found for that reason.
///
/// Returns `Vec2::new(a, b)`, the two accumulated sums divided by
/// `SAMPLE_COUNT`.
pub fn integrate_brdf_doesnt_work(mut n_dot_v: f32, roughness: f32) -> Vec2 {
    // Raise the cosine to at least f32::EPSILON before using it as a divisor.
    n_dot_v = n_dot_v.max(f32::EPSILON);
    // View vector in the frame where the normal is +Z.
    let v = Vec3::new(f32::sqrt(1.0 - n_dot_v * n_dot_v), 0.0, n_dot_v);

    let mut a = 0.0f32;
    let mut b = 0.0f32;

    let n = Vec3::Z;

    let mut i = 0u32;
    while i < SAMPLE_COUNT {
        // The counter is bumped before use, so the body sees
        // i = 1..=SAMPLE_COUNT (index 0 is never sampled).
        i += 1;

        let xi = hammersley(i, SAMPLE_COUNT);
        let h = importance_sample_ggx(xi, n, roughness);
        // Reflect the view vector about the sampled half vector.
        let l = (2.0 * v.dot(h) * h - v).alt_norm_or_zero();

        let n_dot_l = l.z.max(0.0);
        let n_dot_h = h.z.max(0.0);
        let v_dot_h = v.dot(h).max(0.0);

        // Only light directions with a positive z component contribute.
        if n_dot_l > 0.0 {
            let g = geometry_smith(n, v, l, roughness);
            let denom = n_dot_h * n_dot_v;
            let g_vis = (g * v_dot_h) / denom;
            let f_c = (1.0 - v_dot_h).powf(5.0);

            a += (1.0 - f_c) * g_vis;
            b += f_c * g_vis;
        }
    }

    a /= SAMPLE_COUNT as f32;
    b /= SAMPLE_COUNT as f32;

    Vec2::new(a, b)
}
156
/// Slab ids consumed by `prefilter_environment_cubemap_vertex`.
///
/// The vertex shader reads one of these from the slab using the instance
/// index as its id.
#[derive(Clone, Copy, Default, SlabItem)]
pub struct VertexPrefilterEnvironmentCubemapIds {
    /// Id of the camera whose `view_projection()` transforms the cube vertex.
    pub camera: Id<CameraDescriptor>,
    /// Id of the roughness value forwarded to the fragment stage.
    pub roughness: Id<f32>,
}
165
166 #[spirv(vertex)]
168 pub fn prefilter_environment_cubemap_vertex(
169 #[spirv(instance_index)] prefilter_id: Id<VertexPrefilterEnvironmentCubemapIds>,
170 #[spirv(vertex_index)] vertex_id: u32,
171 #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slab: &[u32],
172 out_pos: &mut Vec3,
173 out_roughness: &mut f32,
174 #[spirv(position)] gl_pos: &mut Vec4,
175 ) {
176 let in_pos = crate::math::CUBE[vertex_id as usize];
177 let VertexPrefilterEnvironmentCubemapIds { camera, roughness } = slab.read(prefilter_id);
178 let camera = slab.read(camera);
179 *out_roughness = slab.read(roughness);
180 *out_pos = in_pos;
181 *gl_pos = camera.view_projection() * in_pos.extend(1.0);
182 }
183
184 #[spirv(fragment)]
188 pub fn prefilter_environment_cubemap_fragment(
189 #[spirv(descriptor_set = 0, binding = 1)] environment_cubemap: &Cubemap,
190 #[spirv(descriptor_set = 0, binding = 2)] sampler: &Sampler,
191 in_pos: Vec3,
192 in_roughness: f32,
193 frag_color: &mut Vec4,
194 ) {
195 let mut n = in_pos.alt_norm_or_zero();
196 n.y *= -1.0;
198 let r = n;
199 let v = r;
200
201 let mut total_weight = 0.0f32;
202 let mut prefiltered_color = Vec3::ZERO;
203
204 for i in 0..SAMPLE_COUNT {
205 let xi = hammersley(i, SAMPLE_COUNT);
206 let h = importance_sample_ggx(xi, n, in_roughness);
207 let l = (2.0 * v.dot(h) * h - v).alt_norm_or_zero();
208
209 let n_dot_l = n.dot(l).max(0.0);
210 if n_dot_l > 0.0 {
211 let mip_level = if in_roughness == 0.0 {
212 0.0
213 } else {
214 calc_lod(n_dot_l)
215 };
216 prefiltered_color += environment_cubemap
217 .sample_by_lod(*sampler, l, mip_level)
218 .xyz()
219 * n_dot_l;
220 total_weight += n_dot_l;
221 }
222 }
223
224 prefiltered_color /= total_weight;
225 *frag_color = prefiltered_color.extend(1.0);
226 }
227
228 pub fn calc_lod_old(n: Vec3, v: Vec3, h: Vec3, roughness: f32) -> f32 {
229 let d = crate::pbr::shader::normal_distribution_ggx(n, h, roughness);
231 let n_dot_h = n.dot(h).max(0.0);
232 let h_dot_v = h.dot(v).max(0.0);
233 let pdf = (d * n_dot_h / (4.0 * h_dot_v)).max(f32::EPSILON);
234
235 let resolution = 512.0; let sa_texel = 4.0 * core::f32::consts::PI / (6.0 * resolution * resolution);
237 let sa_sample = 1.0 / (SAMPLE_COUNT as f32 * pdf + f32::EPSILON);
238
239 0.5 * (sa_sample / sa_texel).log2()
240 }
241
242 pub fn calc_lod(n_dot_l: f32) -> f32 {
243 let cube_width = 512.0;
244 let pdf = (n_dot_l * core::f32::consts::FRAC_1_PI).max(0.0);
245 0.5 * (6.0 * cube_width * cube_width / (SAMPLE_COUNT as f32 * pdf).max(f32::EPSILON)).log2()
246 }
247
248 #[spirv(vertex)]
249 pub fn generate_mipmap_vertex(
251 #[spirv(vertex_index)] vertex_id: u32,
252 out_uv: &mut Vec2,
253 #[spirv(position)] gl_pos: &mut Vec4,
254 ) {
255 let i = vertex_id as usize;
256 *out_uv = crate::math::UV_COORD_QUAD_CCW[i];
257 *gl_pos = crate::math::CLIP_SPACE_COORD_QUAD_CCW[i];
258 }
259
260 #[spirv(fragment)]
261 pub fn generate_mipmap_fragment(
263 #[spirv(descriptor_set = 0, binding = 0)] texture: &Image2d,
264 #[spirv(descriptor_set = 0, binding = 1)] sampler: &Sampler,
265 in_uv: Vec2,
266 frag_color: &mut Vec4,
267 ) {
268 *frag_color = texture.sample(*sampler, in_uv);
269 }
270
/// One corner of the quad drawn by the BRDF LUT pass.
#[repr(C)]
#[derive(Clone, Copy)]
struct Vert {
    /// Position written to the clip-space output (w is added by the shader).
    pos: [f32; 3],
    /// Texture coordinate forwarded to the fragment stage.
    uv: [f32; 2],
}

/// Six vertices (two triangles) covering the square from (-1, -1) to (1, 1).
///
/// The uv's v component is 1 at the bottom edge and 0 at the top edge.
const BRDF_VERTS: [Vert; 6] = {
    const BOTTOM_LEFT: Vert = Vert {
        pos: [-1.0, -1.0, 0.0],
        uv: [0.0, 1.0],
    };
    const BOTTOM_RIGHT: Vert = Vert {
        pos: [1.0, -1.0, 0.0],
        uv: [1.0, 1.0],
    };
    const TOP_LEFT: Vert = Vert {
        pos: [-1.0, 1.0, 0.0],
        uv: [0.0, 0.0],
    };
    const TOP_RIGHT: Vert = Vert {
        pos: [1.0, 1.0, 0.0],
        uv: [1.0, 0.0],
    };

    [
        // First triangle.
        BOTTOM_LEFT,
        BOTTOM_RIGHT,
        TOP_RIGHT,
        // Second triangle.
        BOTTOM_LEFT,
        TOP_RIGHT,
        TOP_LEFT,
    ]
};
299
300 #[spirv(vertex)]
301 pub fn brdf_lut_convolution_vertex(
303 #[spirv(vertex_index)] vertex_id: u32,
304 out_uv: &mut glam::Vec2,
305 #[spirv(position)] gl_pos: &mut glam::Vec4,
306 ) {
307 let Vert { pos, uv } = BRDF_VERTS[vertex_id as usize];
308 *out_uv = Vec2::from(uv);
309 *gl_pos = Vec3::from(pos).extend(1.0);
310 }
311
312 #[spirv(fragment)]
313 pub fn brdf_lut_convolution_fragment(in_uv: glam::Vec2, out_color: &mut glam::Vec2) {
315 *out_color = integrate_brdf(in_uv.x, in_uv.y);
316 }
317}
318
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that `integrate_brdf` is NaN-free at the four corners of its
    /// domain, then renders a 32x32 table on the CPU and compares it with the
    /// stored reference image.
    #[test]
    fn integrate_brdf_sanity() {
        // Corner order: (0,0), (1,0), (0,1), (1,1).
        for y in [0.0, 1.0] {
            for x in [0.0, 1.0] {
                assert!(
                    !shader::integrate_brdf(x, y).is_nan(),
                    "brdf is NaN at {x},{y}"
                );
            }
        }

        let size = 32;
        let mut img = image::RgbaImage::new(size, size);
        for (x, y, pixel) in img.enumerate_pixels_mut() {
            let u = x as f32 / size as f32;
            let v = y as f32 / size as f32;
            let brdf = shader::integrate_brdf(u, v);

            // Red and green carry the two results; blue keeps its initial
            // value and alpha is opaque.
            pixel.0[0] = (brdf.x * 255.0) as u8;
            pixel.0[1] = (brdf.y * 255.0) as u8;
            pixel.0[3] = 255;
        }
        img_diff::assert_img_eq("skybox/brdf_cpu.png", img);
    }
}
344}