WGSL Shaders

The heart of our simulation lies in the shaders. We have two shader files:

  1. compute_shader.wgsl: Updates particle physics (position, velocity).
  2. draw_shader.wgsl: Draws the particles to the screen.

1. Compute Shader

Create compute_shader.wgsl in your project root.

📝 Note

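The main entry point runs once for every particle in the simulation. It counts the living particles, revives dead ones, ages and kills old ones, applies the attraction toward the target, and integrates velocity and position. Two extra entry points, init_cube and init_sphere, re-initialize particle positions.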

// Per-frame simulation parameters, bound as a uniform buffer at binding 2.
// NOTE(review): field order, types and padding must match the host-side (Rust)
// struct that is written into this buffer — TODO confirm against that code.
struct Dimensions {
    // Not read by this compute shader.
    width: u32,
    // Not read by this compute shader.
    height: u32,
    // Selects which batch of indices may be revived this frame:
    // [generation_offset * N, generation_offset * N + N),
    // where N = num_of_particles_to_generate_per_second.
    generation_offset: u32,
    // Number of active particles. Only indices below it are counted as alive
    // and re-initialized by init_cube / init_sphere.
    num_of_particles: u32,
    // Time step: added to each particle's age and used for integration.
    frame_time: f32,
    // 0 disables the attraction toward target_pos; any other value enables it.
    is_gravity_on: u32,
    // Maximum particle age. A value <= 0 disables age-based killing.
    time_to_die: f32,
    // Size of the revive window (see generation_offset).
    num_of_particles_to_generate_per_second: u32,
    // Attraction target; only xyz is read.
    target_pos: vec4<f32>,
    // Projection * view matrix; not read here (the draw shader uses it).
    proj_view: mat4x4<f32>,
    // Spawn shape for revived particles: 0 = cube, anything else = sphere.
    init_type: u32,
}

// One simulated particle. The w components carry extra per-particle state.
struct particle {
    // xyz: position. w: reset to 0.0 by `main` every dispatch.
    pos: vec4<f32>,
    // xyz: velocity. w: age, increased by frame_time on every `main` dispatch.
    speed: vec4<f32>,
    // xyz: acceleration. w: alive flag. A value >= 1.0 is counted as alive;
    // a value <= 0.0 is dead and may be revived. This shader only writes 0.0 or 1.0.
    accel: vec4<f32>,
}

// Counter of alive particles, accumulated by `main` with one atomicAdd per workgroup.
// NOTE(review): this shader never resets it; presumably the host clears it before
// each dispatch — confirm.
struct number_of_alive_particles {
    count: atomic<u32>,
}

// Scratch space for the per-workgroup alive-count reduction in `main`.
// One slot per invocation, sized to match @workgroup_size(256).
var<workgroup> shared_reduce: array<u32, 256>;

// Particle state, read and written by every entry point in this file.
@group(0) @binding(1)
var<storage, read_write> particles: array<particle>;

// Per-frame simulation parameters.
@group(0) @binding(2)
var<uniform> dimensions: Dimensions;

// Alive-particle counter incremented by `main`.
@group(0) @binding(3)
var<storage, read_write> alive_particles: number_of_alive_particles;

// Stateless integer hash used as a pseudo-random generator: mixes a u32 seed into
// a well-scrambled u32. All runtime u32 multiplications wrap modulo 2^32.
fn hash(value: u32) -> u32 {
    let mul = 2654435769u;
    var h: u32 = (value ^ 2747636419u) * mul;
    h = (h ^ (h >> 16u)) * mul;
    h = (h ^ (h >> 16u)) * mul;
    return h;
}

// Scale factor 1 / (2^32 - 1): maps the full u32 range onto [0, 1].
const INT_max_u32 = 1.0 / 4294967295.0;

// Deterministic pseudo-random float in [0, 1] derived from `value`.
fn randomFloat(value: u32) -> f32 {
    let bits = hash(value);
    return INT_max_u32 * f32(bits);
}

// Constants for spawn-position sampling. ONE_DIV_3 is the cube-root exponent.
const PI: f32 = 3.14159265359;
const TWO_PI: f32 = 2.0 * PI;
const ONE_DIV_3: f32 = 1.0 / 3.0;

// Spawn position inside the axis-aligned cube [-25, 25]^3, with w = 0.
fn get_cube_pos(id: u32) -> vec4<f32> {
    let r = vec3<f32>(randomFloat(id), randomFloat(id * 2u), randomFloat(id * 3u));
    // x and y use r -> 50r - 25; z uses the mirrored map r -> 25 - 50r.
    let p = fma(r, vec3<f32>(50.0, 50.0, -50.0), vec3<f32>(-25.0, -25.0, 25.0));
    return vec4<f32>(p, 0.0);
}

// Spawn position uniformly distributed inside a ball of radius 25 centred at the
// origin, with w = 0.
//
// It uses three uniform samples in [0, 1]:
//   theta    = 2*pi*u1     azimuth, uniform in [0, 2*pi]
//   cos(phi) = 1 - 2*u2    uniform in [-1, 1], giving a uniform direction
//   r        = R*cbrt(u3)  cube root, giving uniform density over the volume
// Sampling phi itself uniformly (phi = pi*u2) would crowd points toward the poles.
fn get_sphere_pos(id: u32) -> vec4<f32> {
    let u1 = randomFloat(id);
    let u2 = randomFloat(id * 2u);
    let u3 = randomFloat(id * 3u);

    let R = 25.0; // Sphere radius

    let theta = 2.0 * PI * u1;
    let cos_phi = 1.0 - 2.0 * u2;
    // max() guards against tiny negative values from rounding when |cos_phi| ~ 1.
    let sin_phi = sqrt(max(0.0, 1.0 - cos_phi * cos_phi));
    let radius = R * pow(u3, ONE_DIV_3);

    // Spherical -> Cartesian, with the polar axis along y.
    return vec4<f32>(
        radius * sin_phi * cos(theta),
        radius * cos_phi,
        radius * sin_phi * sin(theta),
        0.0,
    );
}

// Main simulation kernel, one invocation per particle (workgroups of 256).
//  1. Workgroup tree reduction counting particles with accel.w >= 1.0. The result
//     is added to alive_particles.count with one atomicAdd per workgroup.
//  2. Dead particles whose index lies in the current generation window are revived
//     at a cube or sphere spawn position.
//  3. Age, kill expired particles, integrate velocity/position, damp, and (when
//     enabled) compute the attraction toward target_pos for the next step.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>) {
    let idx = global_id.x;
    let in_range = idx < dimensions.num_of_particles;

    // Parallel reduction to count alive particles. Every invocation, including the
    // padding ones past the end, must reach the barriers below. Out-of-range
    // invocations therefore contribute 0 instead of returning early.
    var local_val = 0u;
    if (in_range) {
        particles[idx].pos.w = 0.0;
        local_val = select(0u, 1u, particles[idx].accel.w >= 1.0);
    }
    shared_reduce[local_id.x] = local_val;
    workgroupBarrier();

    // Halve the active range each step: 128, 64, ..., 1 (matches workgroup_size 256).
    for (var s = 128u; s > 0u; s >>= 1u) {
        if (local_id.x < s) {
            shared_reduce[local_id.x] += shared_reduce[local_id.x + s];
        }
        workgroupBarrier();
    }

    if (local_id.x == 0u) {
        atomicAdd(&alive_particles.count, shared_reduce[0]);
    }

    // No barriers follow, so padding invocations can stop now. Without this guard
    // they would read and write particles[idx] out of bounds. WGSL allows such
    // stores to land on an arbitrary in-bounds element, or they would touch
    // particles beyond the active count.
    if (!in_range) {
        return;
    }

    // Physics Update
    var p = particles[idx];

    // Revive dead particles whose index falls in this dispatch's generation window.
    let batch = dimensions.num_of_particles_to_generate_per_second;
    let start_gen = dimensions.generation_offset * batch;
    let end_gen = start_gen + batch;
    let should_revive = (p.accel.w <= 0.0) && (idx >= start_gen) && (idx < end_gen);

    if (should_revive) {
        p.accel = vec4<f32>(0.0, 0.0, 0.0, 1.0);
        p.speed = vec4<f32>(0.0, 0.0, 0.0, 0.0);
        if (dimensions.init_type == 0u) {
            p.pos = get_cube_pos(idx);
        } else {
            p.pos = get_sphere_pos(idx);
        }
    }

    let dt = vec3<f32>(dimensions.frame_time);
    // speed.w is the particle's age, in the same units as frame_time.
    p.speed.w = p.speed.w + dt.x;

    // Kill old particles; a non-positive time_to_die disables age-based killing.
    p.accel.w = select(p.accel.w, 0.0, (p.speed.w >= dimensions.time_to_die) && (dimensions.time_to_die > 0.0));

    // Semi-implicit Euler: velocity first, then position using the new velocity.
    p.speed = vec4<f32>(fma(p.accel.xyz, dt, p.speed.xyz), p.speed.w);
    p.pos = vec4<f32>(fma(p.speed.xyz, dt, p.pos.xyz), p.pos.w);

    // Damping: v *= (1 - 0.7 * dt).
    p.speed = vec4<f32>(fma(p.speed.xyz, -dt * 0.7, p.speed.xyz), p.speed.w);

    if (dimensions.is_gravity_on == 0u) {
        p.accel = vec4<f32>(0.0, 0.0, 0.0, p.accel.w);
        particles[idx] = p;
        return;
    }

    // Attraction toward target_pos: unit direction times a softened inverse-square
    // magnitude 30000 / (d^2 + 100), which stays finite at the target.
    let diff = dimensions.target_pos.xyz - p.pos.xyz;
    let dist = length(diff);
    let r_inv = 1.0 / max(dist, 0.001);
    let force_mag = 30000.0 / (dist * dist + 100.0);

    p.accel = vec4<f32>(diff * (r_inv * force_mag), p.accel.w);
    particles[idx] = p;
}

// Re-initialization entry point: scatter every active particle inside the cube.
// Velocity is zeroed and age (speed.w) is kept. A particle ends up alive if aging
// is disabled (time_to_die <= 0) or if it was already alive.
@compute @workgroup_size(256)
fn init_cube(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i >= dimensions.num_of_particles) {
        return;
    }

    let prev = particles[i];
    let keep_alive = (dimensions.time_to_die <= 0.0) || (prev.accel.w >= 1.0);
    let alive = select(0.0, 1.0, keep_alive);

    particles[i] = particle(
        get_cube_pos(i),
        vec4<f32>(vec3<f32>(0.0), prev.speed.w),
        vec4<f32>(vec3<f32>(0.0), alive),
    );
}

// Re-initialization entry point: scatter every active particle inside the sphere.
// Velocity is zeroed and age (speed.w) is kept. A particle ends up alive if aging
// is disabled (time_to_die <= 0) or if it was already alive.
@compute @workgroup_size(256)
fn init_sphere(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i >= dimensions.num_of_particles) {
        return;
    }

    let prev = particles[i];
    let keep_alive = (dimensions.time_to_die <= 0.0) || (prev.accel.w >= 1.0);
    let alive = select(0.0, 1.0, keep_alive);

    particles[i] = particle(
        get_sphere_pos(i),
        vec4<f32>(vec3<f32>(0.0), prev.speed.w),
        vec4<f32>(vec3<f32>(0.0), alive),
    );
}

Key Concepts Explanation

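  • @group(0) @binding(X): Maps each resource to a slot in the bind group we will set up in Rust. Binding 1 is the particle storage buffer, binding 2 is the Dimensions uniform buffer, and binding 3 is the alive-particle counter.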
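  • vec4<f32>: We store position, velocity, and acceleration as vec4 rather than vec3. In WGSL a vec3<f32> is 16-byte aligned but only 12 bytes in size, so a struct of vec3 fields contains hidden padding that the Rust side would have to reproduce. vec4 fields pack with no gaps. The spare w components also carry per-particle state: speed.w is the particle's age and accel.w is its alive flag.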
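  • fma(a, b, c): Computes a * b + c. On hardware with a fused multiply-add instruction this can be a single operation. However, WGSL does not guarantee that the result is actually fused, so don't rely on it for extra precision.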
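  • workgroupBarrier(): Makes every invocation in the workgroup wait until all of them reach the barrier, and makes their earlier writes to workgroup (shared) memory visible to each other. The reduction needs one barrier after the initial store and one after every halving step. Barriers must be reached in uniform control flow, which is why out-of-range invocations store 0 instead of returning early.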
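  • atomicAdd: Adds to a counter in a storage buffer without races between concurrent invocations. Here only the first invocation of each workgroup calls it, after the reduction, so there is one atomic operation per workgroup instead of one per particle.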

2. Draw Shader

Create draw_shader.wgsl. This handles rendering.

// Per-frame parameters bound as a uniform buffer at binding 2 (presumably the same
// buffer the compute shader reads). This shader uses only target_pos and proj_view.
// NOTE(review): unlike compute_shader.wgsl, this declaration omits the trailing
// `init_type` field. The leading fields must keep identical offsets — confirm the
// difference is intentional.
struct Dimensions {
    width: u32,
    height: u32,
    generation_offset: u32,
    num_of_particles: u32,
    frame_time: f32,
    is_gravity_on: u32,
    time_to_die: f32,
    num_of_particles_to_generate_per_second: u32,
    // Point used to color particles by distance.
    target_pos: vec4<f32>,
    // Projection * view matrix applied to particle positions.
    proj_view: mat4x4<f32>,
}

// Particle layout; must match the compute shader's `particle` struct.
// pos.xyz is the world position, and accel.w <= 0.0 marks a dead particle.
struct particle {
    pos: vec4<f32>,
    speed: vec4<f32>,
    accel: vec4<f32>,
}

// Particle buffer, read-only here (presumably the buffer the compute pass writes).
@group(0) @binding(1)
var<storage, read> particles: array<particle>;

// Per-frame parameters (target_pos and proj_view are read).
@group(0) @binding(2)
var<uniform> dimensions: Dimensions;

// Vertex-to-fragment interface.
struct VertexOutput {
    // Clip-space position consumed by the rasterizer.
    @builtin(position) clip_position: vec4<f32>,
    // Per-particle RGBA color; alpha is 0 for dead particles.
    @location(0) color: vec4<f32>,
}

// Vertex stage: one invocation per particle; vertex_index selects the particle.
// Live particles are projected with proj_view and colored by their distance to
// target_pos (white at the target, fading red first, then green).
@vertex
fn vs_main(@builtin(vertex_index) in_vertex_index: u32) -> VertexOutput {
    var out: VertexOutput;

    let particle = particles[in_vertex_index];

    // Dead particles are placed outside the clip volume (x, y, z > w) so the
    // rasterizer discards them. (0, 0, 0, 0) is not safe, because w = 0 lies on the
    // clip-volume boundary and makes the perspective divide 0 / 0.
    if (particle.accel.w <= 0.0f) {
        out.clip_position = vec4<f32>(2.0, 2.0, 2.0, 1.0);
        out.color = vec4<f32>(1.0, 1.0, 1.0, 0.0);
        return out;
    }

    // Transform position: View/Projection Matrix * World Position
    out.clip_position = dimensions.proj_view * vec4<f32>(particle.pos.xyz, 1.0);

    // Distance to the target in units of 15. Each channel fades at its own rate.
    // saturate() keeps far particles from producing negative color values.
    let dist = length(dimensions.target_pos.xyz - particle.pos.xyz) / 15.0;
    let rgb = vec3<f32>(1.0) - dist * vec3<f32>(0.9, 0.2, 0.01);
    out.color = vec4<f32>(saturate(rgb), 1.0);

    return out;
}

// Fragment stage: output the interpolated per-vertex color unchanged.
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
    let color = in.color;
    return color;
}

Draw Shader Logic

  • @vertex: The entry point for the vertex shader. It runs for each vertex. Since we draw POINTS, it runs once per particle.
  • @builtin(vertex_index): The ID of the current vertex (particle). We use this to look up the particle's position in the storage buffer.
  • @builtin(position): The clip-space position of the vertex. The GPU divides it by w and maps the result to the viewport to find the pixel to draw. Every WebGPU vertex shader must output it.
  • @fragment: Runs for every pixel the point covers. It just outputs the calculated color.