WGSL Shaders
The heart of our simulation lies in the shaders. We have two shader files:
- compute_shader.wgsl: Updates particle physics (position, velocity).
- draw_shader.wgsl: Draws the particles to the screen.
1. Compute Shader
Create compute_shader.wgsl in your project root.
📝 Note
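The main entry point of this shader runs one invocation per particle. Each invocation helps count the particles that are currently alive, respawns its particle if it is dead and scheduled for regeneration, ages and kills old particles, integrates velocity and position, and finally computes the attraction force towards the target.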
// Simulation parameters read by every entry point through the `dimensions` uniform.
// Field order defines the uniform's memory layout; do not reorder.
struct Dimensions {
// Not read anywhere in this shader. NOTE(review): presumably the surface size in pixels - confirm against the host code.
width: u32,
height: u32,
// Multiplied by num_of_particles_to_generate_per_second in main() to get the first index of the current revive window.
generation_offset: u32,
// Number of valid entries in `particles`; invocations at or above this index are skipped by the bounds checks.
num_of_particles: u32,
// Time step used for integration and for ageing particles. NOTE(review): units are presumably seconds - confirm.
frame_time: f32,
// 0 disables the attraction force in main(); any other value enables it.
is_gravity_on: u32,
// Age at which a particle is killed; values <= 0.0 disable killing.
time_to_die: f32,
// Size of the index window in which dead particles may be revived during one dispatch of main().
num_of_particles_to_generate_per_second: u32,
// Attraction target; only xyz is read.
target_pos: vec4<f32>,
// Not read in this shader; the draw shader multiplies world positions by it.
proj_view: mat4x4<f32>,
// Spawn shape used when reviving: 0 selects the cube, any other value the sphere.
init_type: u32,
}
// State of one particle. The w lanes carry bookkeeping rather than spatial data.
struct particle {
// xyz: position. w: overwritten with 0.0 by main() on every dispatch.
pos: vec4<f32>,
// xyz: velocity. w: age, increased by frame_time in main() and reset to 0 on revive.
speed: vec4<f32>,
// xyz: acceleration. w: alive flag (>= 1.0 is counted as alive, <= 0.0 is treated as dead).
accel: vec4<f32>,
}
// Wrapper for the alive-particle counter so it can be bound as a storage buffer.
struct number_of_alive_particles {
// Updated with atomicAdd from main(), one addition per workgroup.
count: atomic<u32>,
}
// Per-workgroup scratch space for the alive-count reduction in main(); 256 slots match @workgroup_size(256).
var<workgroup> shared_reduce: array<u32, 256>;
// Particle state, read and written by all entry points in this shader.
@group(0) @binding(1)
var<storage, read_write> particles: array<particle>;
// Simulation parameters (see Dimensions).
@group(0) @binding(2)
var<uniform> dimensions: Dimensions;
// Running total of alive particles. Nothing in this shader resets it.
// NOTE(review): presumably the host clears it before each dispatch - confirm.
@group(0) @binding(3)
var<storage, read_write> alive_particles: number_of_alive_particles;
// Scrambles a 32-bit integer into another 32-bit integer.
// One XOR with a fixed mask, then three wrapping multiplies by the same odd
// constant with an xor-shift by 16 bits between them.
fn hash(value: u32) -> u32 {
    let multiplier = 2654435769u;
    var bits = value ^ 2747636419u;
    bits *= multiplier;
    bits ^= bits >> 16u;
    bits *= multiplier;
    bits ^= bits >> 16u;
    bits *= multiplier;
    return bits;
}
// Scale factor that maps a u32 onto the unit interval. Despite its name this
// is the reciprocal of the largest u32 value, not the value itself.
const INT_max_u32 = 1.0 / 4294967295.0;
// Returns a pseudo-random float in [0, 1] derived deterministically from `value`.
fn randomFloat(value: u32) -> f32 {
    let scrambled: u32 = hash(value);
    return f32(scrambled) * INT_max_u32;
}
// Circle constants.
const PI: f32 = 3.14159265359;
const TWO_PI: f32 = 6.28318530718;
// Exponent for taking a cube root with pow().
const ONE_DIV_3: f32 = 1.0 / 3.0;
// Returns a pseudo-random spawn position inside the axis-aligned cube
// [-25, 25]^3, derived deterministically from the particle index. w is 0.
fn get_cube_pos(id: u32) -> vec4<f32> {
    // One random value per axis, each seeded from a different multiple of id.
    let rx = randomFloat(id);
    let ry = randomFloat(id * 2u);
    let rz = randomFloat(id * 3u);

    // Map [0, 1] onto [-25, 25]; the z axis runs in the opposite direction.
    let x = fma(rx, 50.0, -25.0);
    let y = fma(ry, 50.0, -25.0);
    let z = fma(rz, -50.0, 25.0);

    return vec4<f32>(x, y, z, 0.0);
}
// Returns a pseudo-random spawn position inside a sphere of radius 25 centred
// on the origin, derived deterministically from the particle index. w is 0.
//
// The point is distributed uniformly over the sphere's volume:
//   - the azimuth theta is uniform in [0, 2*PI],
//   - cos(phi) (not phi itself) is uniform in [-1, 1], which avoids the
//     clustering at the poles that a uniformly sampled polar angle produces,
//   - the radius is R * cbrt(u), which compensates for volume growing with r^3.
fn get_sphere_pos(id: u32) -> vec4<f32> {
    // 1. Three pseudo-random numbers in [0, 1]. They are seeded from multiples
    //    of the same id, so they are not strictly independent (for id == 0 all
    //    three are equal).
    let u1 = randomFloat(id);
    let u2 = randomFloat(id * 2u);
    let u3 = randomFloat(id * 3u);
    let R = 25.0; // Sphere radius

    // 2. Map to spherical coordinates.
    let theta = TWO_PI * u1;
    let cos_phi = fma(u2, -2.0, 1.0); // 1 - 2*u2, uniform in [-1, 1]
    // max() guards against a tiny negative argument caused by rounding.
    let sin_phi = sqrt(max(1.0 - cos_phi * cos_phi, 0.0));
    let radius = R * pow(u3, ONE_DIV_3);

    // 3. Convert spherical to Cartesian (y is the polar axis).
    let xyz = vec3<f32>(
        radius * sin_phi * cos(theta),
        radius * cos_phi,
        radius * sin_phi * sin(theta)
    );
    return vec4<f32>(xyz, 0.0);
}
// Per-frame simulation step; one invocation per particle slot.
//
// Phase 1 (all invocations of the workgroup): count the particles that are
// alive at the start of the step with a shared-memory tree reduction and add
// the workgroup total to alive_particles.count.
// Phase 2 (in-range invocations only): revive, age, kill, integrate and
// compute the attraction force for particle global_id.x.
//
// Invocations with global_id.x >= dimensions.num_of_particles (the padding of
// the last workgroup) must still execute every workgroupBarrier(), so they
// take part in phase 1 with a contribution of 0 and leave before phase 2.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>) {
    let index = global_id.x;
    let lane = local_id.x;
    let in_range = index < dimensions.num_of_particles;

    // Parallel Reduction to count alive particles
    var local_val = 0u;
    if (in_range) {
        particles[index].pos.w = 0.0;
        local_val = select(0u, 1u, particles[index].accel.w >= 1.0);
    }
    shared_reduce[lane] = local_val;
    workgroupBarrier();

    // Halve the number of active lanes each round; lane i accumulates lane i + s.
    for (var s = 128u; s > 0u; s >>= 1u) {
        if (lane < s) {
            shared_reduce[lane] += shared_reduce[lane + s];
        }
        workgroupBarrier();
    }

    // One atomic per workgroup instead of one per particle.
    if (lane == 0u) {
        atomicAdd(&alive_particles.count, shared_reduce[0]);
    }

    // No barriers follow, so padding invocations can leave here. Without this
    // guard they would read and write particle slots outside the valid range.
    if (!in_range) {
        return;
    }

    // Physics Update
    var p = particles[index];

    // Dead particles inside the current generation window are respawned.
    let start_gen = dimensions.generation_offset * dimensions.num_of_particles_to_generate_per_second;
    let end_gen = start_gen + dimensions.num_of_particles_to_generate_per_second;
    let should_revive = (p.accel.w <= 0.0) && (index >= start_gen) && (index < end_gen);
    if (should_revive) {
        p.accel = vec4<f32>(0.0, 0.0, 0.0, 1.0);
        p.speed = vec4<f32>(0.0, 0.0, 0.0, 0.0);
        if (dimensions.init_type == 0u) {
            p.pos = get_cube_pos(index);
        } else {
            p.pos = get_sphere_pos(index);
        }
    }

    let dt = vec3<f32>(dimensions.frame_time);

    // speed.w holds the particle's age.
    p.speed.w = p.speed.w + dt.x;

    // Kill old particles (only when a positive lifetime is configured)
    p.accel.w = select(p.accel.w, 0.0, (p.speed.w >= dimensions.time_to_die) && (dimensions.time_to_die > 0.0));

    // Update Speed and Position using FMA (Fused Multiply-Add)
    p.speed = vec4<f32>(fma(p.accel.xyz, dt, p.speed.xyz), p.speed.w);
    p.pos = vec4<f32>(fma(p.speed.xyz, dt, p.pos.xyz), p.pos.w);

    // Apply Damping: speed *= (1 - 0.7 * dt)
    p.speed = vec4<f32>(fma(p.speed.xyz, -dt * 0.7, p.speed.xyz), p.speed.w);

    if (dimensions.is_gravity_on == 0u) {
        // No attraction: clear the acceleration but keep the alive flag.
        p.accel = vec4<f32>(0.0, 0.0, 0.0, p.accel.w);
        particles[index] = p;
        return;
    }

    // Gravity / Attraction Logic
    let target_pos = dimensions.target_pos.xyz;
    let diff = target_pos - p.pos.xyz;
    let dist = length(diff);
    // Lower bound avoids a division by zero when the particle sits on the target.
    let r_inv = 1.0 / max(dist, 0.001);
    // Softened inverse-square law: the +100 keeps the force finite near the target.
    let force_mag = 30000.0 / (dist * dist + 100.0);
    // diff * r_inv is the unit direction towards the target.
    p.accel = vec4<f32>(diff * (r_inv * force_mag), p.accel.w);
    particles[index] = p;
}
// Separate entry points for re-initialization
// Re-initialises every particle to a pseudo-random position inside the spawn
// cube. Velocity and acceleration are zeroed and the particle's age (speed.w)
// is carried over. A particle ends up alive if it was already alive or if
// particle death is disabled (time_to_die <= 0); otherwise it ends up dead.
@compute @workgroup_size(256)
fn init_cube(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= dimensions.num_of_particles) {
        return;
    }

    let previous = particles[index];
    let never_expires = dimensions.time_to_die <= 0.0;
    let was_alive = previous.accel.w >= 1.0;
    let alive_flag = select(0.0, 1.0, never_expires || was_alive);

    particles[index] = particle(
        get_cube_pos(index),
        vec4<f32>(0.0, 0.0, 0.0, previous.speed.w),
        vec4<f32>(0.0, 0.0, 0.0, alive_flag)
    );
}
// Re-initialises every particle to a pseudo-random position inside the spawn
// sphere. Velocity and acceleration are zeroed and the particle's age
// (speed.w) is carried over. A particle ends up alive if it was already alive
// or if particle death is disabled (time_to_die <= 0); otherwise it ends up dead.
@compute @workgroup_size(256)
fn init_sphere(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= dimensions.num_of_particles) {
        return;
    }

    let previous = particles[index];
    let never_expires = dimensions.time_to_die <= 0.0;
    let was_alive = previous.accel.w >= 1.0;
    let alive_flag = select(0.0, 1.0, never_expires || was_alive);

    particles[index] = particle(
        get_sphere_pos(index),
        vec4<f32>(0.0, 0.0, 0.0, previous.speed.w),
        vec4<f32>(0.0, 0.0, 0.0, alive_flag)
    );
}
Key Concepts Explanation
- @group(0) @binding(X): Maps to the data bindings we will set up in Rust. Binding 1 is our particle storage buffer.
- vec4<f32>: We use vec4 for position/velocity because GPUs handle 4-component vectors efficiently, and it ensures 16-byte alignment.
- fma(a, b, c): Computes a * b + c. On hardware with a fused multiply-add instruction this is a single operation with a single rounding step, which can be faster and more precise than a separate multiply and add.
- workgroupBarrier(): Ensures all threads in a workgroup have finished writing to shared memory before any of them reads from it. Essential for the reduction step.
- atomicAdd: Safely increments a shared counter from multiple threads simultaneously without race conditions.
2. Draw Shader
Create draw_shader.wgsl. This handles rendering.
// Simulation parameters as seen by the draw shader. Only target_pos and
// proj_view are read here; the leading fields keep those two at their offsets.
// NOTE(review): the compute shader's Dimensions declares one more trailing
// field (init_type: u32). The shared fields are in the same order, presumably
// so both shaders can read the same uniform buffer - confirm, and consider
// keeping the two declarations identical.
struct Dimensions {
width: u32,
height: u32,
generation_offset: u32,
num_of_particles: u32,
frame_time: f32,
is_gravity_on: u32,
time_to_die: f32,
num_of_particles_to_generate_per_second: u32,
// Only xyz is read; used as the reference point for the distance-based colouring.
target_pos: vec4<f32>,
// Combined projection * view matrix applied to world-space positions.
proj_view: mat4x4<f32>,
}
// One particle as stored in the storage buffer.
struct particle {
// xyz: world-space position used for rendering.
pos: vec4<f32>,
// Not read by this shader.
speed: vec4<f32>,
// Only w is read by this shader: a value <= 0.0 marks the particle as dead.
accel: vec4<f32>,
}
// Particle buffer, read-only in the draw pass; indexed by vertex_index in vs_main.
@group(0) @binding(1)
var<storage, read> particles: array<particle>;
// Uniform holding the projection-view matrix and the attraction target.
@group(0) @binding(2)
var<uniform> dimensions: Dimensions;
// Data passed from vs_main to fs_main.
struct VertexOutput {
// Clip-space position of the vertex.
@builtin(position) clip_position: vec4<f32>,
// RGBA colour computed per particle; fs_main returns it unchanged.
@location(0) color: vec4<f32>,
}
// Vertex stage: one invocation per particle (vertex_index is the particle index).
// Alive particles are projected with proj_view and tinted by their distance to
// the attraction target; dead particles are moved outside the clip volume so
// that they are never rasterised.
@vertex
fn vs_main(@builtin(vertex_index) in_vertex_index: u32) -> VertexOutput {
    var out: VertexOutput;
    let p = particles[in_vertex_index];

    // Dead particles: a position with x, y and z all greater than w lies
    // outside the clip volume and is discarded. (A position of all zeros has
    // w == 0, which passes the clip test and then divides by zero.) The
    // colour stays fully transparent as a second line of defence.
    if (p.accel.w <= 0.0) {
        out.clip_position = vec4<f32>(2.0, 2.0, 2.0, 1.0);
        out.color = vec4<f32>(1.0, 1.0, 1.0, 0.0);
        return out;
    }

    // Transform position: View/Projection Matrix * World Position
    out.clip_position = dimensions.proj_view * vec4<f32>(p.pos.xyz, 1.0);

    // Coloring based on distance from target: white at the target, with red
    // fading fastest, then green, and blue almost not at all.
    let scaled_dist = length(dimensions.target_pos.xyz - p.pos.xyz) / 15.0;
    let falloff = vec3<f32>(0.9, 0.2, 0.01);
    // Clamp so that far-away particles do not emit negative channel values.
    let rgb = clamp(vec3<f32>(1.0) - scaled_dist * falloff, vec3<f32>(0.0), vec3<f32>(1.0));
    out.color = vec4<f32>(rgb, 1.0);
    return out;
}
// Fragment stage: outputs the colour computed by the vertex stage unchanged.
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
    let shaded: vec4<f32> = in.color;
    return shaded;
}
Draw Shader Logic
- @vertex: The entry point for the vertex shader. It runs for each vertex. Since we draw POINTS, it runs once per particle.
- @builtin(vertex_index): The ID of the current vertex (particle). We use this to look up the particle's position in the storage buffer.
- @builtin(position): The clip-space position of the vertex. The GPU divides it by w and maps it to the viewport to determine where on screen the point is drawn.
- @fragment: Runs for every pixel the point covers. It just outputs the calculated color.