renderling/primitive/shader.rs

1//! Shader support for rendering primitives.
2use crabslab::{Array, Id, Slab, SlabItem};
3use glam::{Mat4, Vec2, Vec3, Vec4, Vec4Swizzles};
4use spirv_std::{
5    image::{Cubemap, Image2d, Image2dArray},
6    spirv, Image, Sampler,
7};
8
9// use glam::Mat4;
10// #[cfg(not(target_arch = "spirv"))]
11// use glam::UVec2;
12
13// #[allow(unused_imports)]
14// use spirv_std::num_traits::Float;
15
16use crate::{
17    bvol::BoundingSphere,
18    geometry::{
19        shader::{GeometryDescriptor, SkinDescriptor},
20        MorphTarget, Vertex,
21    },
22    material::shader::MaterialDescriptor,
23    math::IsVector,
24    transform::shader::TransformDescriptor,
25};
26
27#[allow(unused_imports)]
28use spirv_std::num_traits::Float;
29
30/// Returned by [`PrimitiveDescriptor::get_vertex_info`].
31pub struct VertexInfo {
32    pub vertex: Vertex,
33    pub transform: TransformDescriptor,
34    pub model_matrix: Mat4,
35    pub world_pos: Vec3,
36}
37
38/// A draw call used to render some geometry.
39#[derive(Clone, Copy, PartialEq, SlabItem, Debug)]
40#[offsets]
41pub struct PrimitiveDescriptor {
42    pub visible: bool,
43    pub vertices_array: Array<Vertex>,
44    /// Bounding sphere of the entire primitive, in local space.
45    pub bounds: BoundingSphere,
46    pub indices_array: Array<u32>,
47    pub transform_id: Id<TransformDescriptor>,
48    pub material_id: Id<MaterialDescriptor>,
49    pub skin_id: Id<SkinDescriptor>,
50    pub morph_targets: Array<Array<MorphTarget>>,
51    pub morph_weights: Array<f32>,
52    pub geometry_descriptor_id: Id<GeometryDescriptor>,
53}
54
55impl Default for PrimitiveDescriptor {
56    fn default() -> Self {
57        PrimitiveDescriptor {
58            visible: true,
59            vertices_array: Array::default(),
60            bounds: BoundingSphere::default(),
61            indices_array: Array::default(),
62            transform_id: Id::NONE,
63            material_id: Id::NONE,
64            skin_id: Id::NONE,
65            morph_targets: Array::default(),
66            morph_weights: Array::default(),
67            geometry_descriptor_id: Id::new(0),
68        }
69    }
70}
71
72impl PrimitiveDescriptor {
73    /// Returns the vertex at the given index and its related values.
74    ///
75    /// These values are often used in shaders, so they are grouped together.
76    pub fn get_vertex_info(&self, vertex_index: u32, geometry_slab: &[u32]) -> VertexInfo {
77        let vertex = self.get_vertex(vertex_index, geometry_slab);
78        let transform = self.get_transform(vertex, geometry_slab);
79        let model_matrix = Mat4::from(transform);
80        let world_pos = model_matrix.transform_point3(vertex.position);
81        VertexInfo {
82            vertex,
83            transform,
84            model_matrix,
85            world_pos,
86        }
87    }
88    /// Retrieve the transform of this `primitive`.
89    ///
90    /// This takes into consideration all skinning matrices.
91    pub fn get_transform(&self, vertex: Vertex, slab: &[u32]) -> TransformDescriptor {
92        let config = slab.read_unchecked(self.geometry_descriptor_id);
93        if config.has_skinning && self.skin_id.is_some() {
94            let skin = slab.read(self.skin_id);
95            TransformDescriptor::from(skin.get_skinning_matrix(vertex, slab))
96        } else {
97            slab.read(self.transform_id)
98        }
99    }
100
101    /// Retrieve the vertex from the slab, calculating any displacement due to
102    /// morph targets.
103    pub fn get_vertex(&self, vertex_index: u32, slab: &[u32]) -> Vertex {
104        let index = if self.indices_array.is_null() {
105            vertex_index as usize
106        } else {
107            slab.read(self.indices_array.at(vertex_index as usize)) as usize
108        };
109        let vertex_id = self.vertices_array.at(index);
110        let mut vertex = slab.read_unchecked(vertex_id);
111        for i in 0..self.morph_targets.len() {
112            let morph_target_array = slab.read(self.morph_targets.at(i));
113            let morph_target = slab.read(morph_target_array.at(index));
114            let weight = slab.read(self.morph_weights.at(i));
115            vertex.position += weight * morph_target.position;
116            vertex.normal += weight * morph_target.normal;
117            vertex.tangent += weight * morph_target.tangent.extend(0.0);
118        }
119        vertex
120    }
121
122    pub fn get_vertex_count(&self) -> u32 {
123        if self.indices_array.is_null() {
124            self.vertices_array.len() as u32
125        } else {
126            self.indices_array.len() as u32
127        }
128    }
129}
130
/// Every output of the primitive PBR vertex shader, plus the intermediate
/// values used to compute them. Only available in tests.
#[cfg(test)]
#[derive(Default, Debug, Clone, Copy, PartialEq)]
pub struct PrimitivePbrVertexInfo {
    /// The primitive that was read from the slab.
    pub primitive: PrimitiveDescriptor,
    /// Id of that primitive (the instance index).
    pub primitive_id: Id<PrimitiveDescriptor>,
    /// The vertex index the shader was invoked with.
    pub vertex_index: u32,
    /// The resolved (morphed) vertex.
    pub vertex: Vertex,
    /// The transform applied to the vertex.
    pub transform: TransformDescriptor,
    /// Model matrix built from `transform`.
    pub model_matrix: Mat4,
    /// The camera's view-projection matrix.
    pub view_projection: Mat4,
    /// Vertex color output.
    pub out_color: Vec4,
    /// First UV set output.
    pub out_uv0: Vec2,
    /// Second UV set output.
    pub out_uv1: Vec2,
    /// World-space normal output.
    pub out_norm: Vec3,
    /// World-space tangent output.
    pub out_tangent: Vec3,
    /// World-space bitangent output.
    pub out_bitangent: Vec3,
    /// World-space position output.
    pub out_pos: Vec3,
    /// Clip-space position output.
    pub out_clip_pos: Vec4,
}
151
152/// primitive vertex shader.
153#[spirv(vertex)]
154#[allow(clippy::too_many_arguments)]
155pub fn primitive_vertex(
156    // Points at a `primitive`
157    #[spirv(instance_index)] primitive_id: Id<PrimitiveDescriptor>,
158    // Which vertex within the primitive are we rendering
159    #[spirv(vertex_index)] vertex_index: u32,
160    #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] geometry_slab: &[u32],
161
162    #[spirv(flat)] out_primitive: &mut Id<PrimitiveDescriptor>,
163    // TODO: Think about placing all these out values in a G-Buffer
164    // But do we have enough buffers + enough space on web?
165    // ...and can we write to buffers from vertex shaders on web?
166    out_color: &mut Vec4,
167    out_uv0: &mut Vec2,
168    out_uv1: &mut Vec2,
169    out_norm: &mut Vec3,
170    out_tangent: &mut Vec3,
171    out_bitangent: &mut Vec3,
172    out_world_pos: &mut Vec3,
173    #[spirv(position)] out_clip_pos: &mut Vec4,
174    // test-only info struct
175    #[cfg(test)] out_info: &mut PrimitivePbrVertexInfo,
176) {
177    let primitive = geometry_slab.read_unchecked(primitive_id);
178    if !primitive.visible {
179        // put it outside the clipping frustum
180        *out_clip_pos = Vec4::new(10.0, 10.0, 10.0, 1.0);
181        return;
182    }
183
184    *out_primitive = primitive_id;
185
186    let VertexInfo {
187        vertex,
188        transform,
189        model_matrix,
190        world_pos,
191    } = primitive.get_vertex_info(vertex_index, geometry_slab);
192    *out_color = vertex.color;
193    *out_uv0 = vertex.uv0;
194    *out_uv1 = vertex.uv1;
195    *out_world_pos = world_pos;
196
197    let scale2 = transform.scale * transform.scale;
198    let normal = vertex.normal.alt_norm_or_zero();
199    let tangent = vertex.tangent.xyz().alt_norm_or_zero();
200    let normal_w: Vec3 = (model_matrix * (normal / scale2).extend(0.0))
201        .xyz()
202        .alt_norm_or_zero();
203    *out_norm = normal_w;
204
205    let tangent_w: Vec3 = (model_matrix * tangent.extend(0.0))
206        .xyz()
207        .alt_norm_or_zero();
208    *out_tangent = tangent_w;
209
210    let bitangent_w = normal_w.cross(tangent_w) * if vertex.tangent.w >= 0.0 { 1.0 } else { -1.0 };
211    *out_bitangent = bitangent_w;
212
213    let camera_id = geometry_slab
214        .read_unchecked(primitive.geometry_descriptor_id + GeometryDescriptor::OFFSET_OF_CAMERA_ID);
215    let camera = geometry_slab.read(camera_id);
216    let clip_pos = camera.view_projection() * world_pos.extend(1.0);
217    *out_clip_pos = clip_pos;
218    #[cfg(test)]
219    {
220        *out_info = PrimitivePbrVertexInfo {
221            primitive_id,
222            vertex_index,
223            vertex,
224            transform,
225            model_matrix,
226            view_projection: camera.view_projection(),
227            out_clip_pos: clip_pos,
228            primitive,
229            out_color: *out_color,
230            out_uv0: *out_uv0,
231            out_uv1: *out_uv1,
232            out_norm: *out_norm,
233            out_tangent: *out_tangent,
234            out_bitangent: *out_bitangent,
235            out_pos: *out_world_pos,
236        };
237    }
238}
239
240/// primitive fragment shader
241#[allow(clippy::too_many_arguments, dead_code)]
242#[spirv(fragment)]
243pub fn primitive_fragment(
244    #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] geometry_slab: &[u32],
245    #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] material_slab: &[u32],
246    #[spirv(descriptor_set = 0, binding = 2)] atlas: &Image2dArray,
247    #[spirv(descriptor_set = 0, binding = 3)] atlas_sampler: &Sampler,
248    #[spirv(descriptor_set = 0, binding = 4)] irradiance: &Cubemap,
249    #[spirv(descriptor_set = 0, binding = 5)] irradiance_sampler: &Sampler,
250    #[spirv(descriptor_set = 0, binding = 6)] prefiltered: &Cubemap,
251    #[spirv(descriptor_set = 0, binding = 7)] prefiltered_sampler: &Sampler,
252    #[spirv(descriptor_set = 0, binding = 8)] brdf: &Image2d,
253    #[spirv(descriptor_set = 0, binding = 9)] brdf_sampler: &Sampler,
254    #[spirv(storage_buffer, descriptor_set = 0, binding = 10)] light_slab: &[u32],
255    #[spirv(descriptor_set = 0, binding = 11)] shadow_map: &Image!(2D, type=f32, sampled, arrayed),
256    #[spirv(descriptor_set = 0, binding = 12)] shadow_map_sampler: &Sampler,
257    #[cfg(feature = "debug-slab")]
258    #[spirv(storage_buffer, descriptor_set = 0, binding = 13)]
259    debug_slab: &mut [u32],
260
261    #[spirv(flat)] primitive_id: Id<PrimitiveDescriptor>,
262    #[spirv(frag_coord)] frag_coord: Vec4,
263    in_color: Vec4,
264    in_uv0: Vec2,
265    in_uv1: Vec2,
266    in_norm: Vec3,
267    in_tangent: Vec3,
268    in_bitangent: Vec3,
269    world_pos: Vec3,
270    output: &mut Vec4,
271) {
272    // proxy to a separate impl that allows us to test on CPU
273    crate::pbr::shader::fragment_impl(
274        atlas,
275        atlas_sampler,
276        irradiance,
277        irradiance_sampler,
278        prefiltered,
279        prefiltered_sampler,
280        brdf,
281        brdf_sampler,
282        shadow_map,
283        shadow_map_sampler,
284        geometry_slab,
285        material_slab,
286        light_slab,
287        primitive_id,
288        frag_coord,
289        in_color,
290        in_uv0,
291        in_uv1,
292        in_norm,
293        in_tangent,
294        in_bitangent,
295        world_pos,
296        output,
297    );
298}
299
/// A shader to ensure that we can extract i8 and i16 values from a storage
/// buffer.
///
/// Each invocation uses its global x id as `index`. It extracts an i8 and
/// then an i16 with `crate::bits`, and writes each strictly positive result
/// back to `slab[index]`.
// NOTE(review): the feature is spelled `test_i8_16_extraction` while the
// function says `i8_i16`; confirm it matches the feature in Cargo.toml.
#[cfg(feature = "test_i8_16_extraction")]
#[spirv(compute(threads(32)))]
pub fn test_i8_i16_extraction(
    #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slab: &mut [u32],
    // `UVec3` is not among this file's `glam` imports, so it is spelled out
    // here; a bare `UVec3` fails to resolve when the feature is enabled.
    #[spirv(global_invocation_id)] global_id: glam::UVec3,
) {
    let index = global_id.x as usize;
    let (value, _, _) = crate::bits::extract_i8(index, 2, slab);
    if value > 0 {
        slab[index] = value as u32;
    }
    let (value, _, _) = crate::bits::extract_i16(index, 2, slab);
    if value > 0 {
        slab[index] = value as u32;
    }
}