#pragma kernel Preprocess
#pragma kernel Postprocess
#include "./ShaderConditionals.hlsl"
#include "./ShaderUtils.hlsl"
// Dimensions of the 3D texture (non-cubic supported)
uint3 dispatchDims; // (Nx, Ny, Nz)

// JFA jump offset in voxels (scalar, CPU halves it each pass); read by the JFA kernel.
uint samplingOffset;

// World-space mapping for anisotropic voxels
float3 minCornerWorld; // world-space origin of the voxel volume (same space as your mesh bounds)
float3 voxelWorldSize; // world units per voxel step along x/y/z (can be different)

// ---- Postprocess shaping parameters ----
// Thickness offset subtracted from the distance, in *normalized* units:
// world distance divided by the largest world extent of the volume (see Postprocess).
float postProcessThickness;
// Half-width of the smoothstep band centred on 0.5 that produces the mask in
// Postprocess; note that larger values give a softer (wider) transition.
float transitionSharpness;
// Passed as the second argument of remap() in Postprocess.
// NOTE(review): exact meaning depends on remap()'s signature in ShaderUtils.hlsl — confirm.
float cutOff;
// Number of repeating bands per normalized distance unit in Postprocess.
float fracRepetitions;
// NOTE(review): samplingOffset, cutOff and fracRepetitions are used by the kernels but
// had no declaration in this file; if an included header already declares one of them,
// remove the duplicate here.

// Working volume. Preprocess reads a seed marker from .x and rewrites each voxel as
// (seed world position, seed flag); JFA passes propagate the nearest seed; Postprocess
// overwrites it with the final distance outputs.
RWTexture3D<float4> Voxels;
// Convert a voxel index to the world-space position of that voxel's center.
// Sampling at the center (+0.5) rather than the corner is recommended for SDF stability.
float3 VoxelIndexToWorld(uint3 id)
{
    return minCornerWorld + ((float3)id + 0.5f) * voxelWorldSize;
}
// Largest world-space extent of the voxel volume along any axis.
// Used to normalize distances without assuming a cubic volume.
float MaxWorldExtent()
{
    float3 volumeWorldSize = (float3)dispatchDims * voxelWorldSize;
    return max(volumeWorldSize.x, max(volumeWorldSize.y, volumeWorldSize.z));
}
// Kernel: clear every voxel of the volume to zero.
// NOTE(review): no `#pragma kernel Zero` is visible in this file; add one if the CPU
// side dispatches this kernel.
// NOTE(review): the thread-group size is not present in this file; 8x8x8 is assumed
// and must match the group counts computed by the CPU-side dispatch.
[numthreads(8, 8, 8)]
void Zero(uint3 id : SV_DispatchThreadID)
{
    // The dispatch may overshoot the volume when a dimension is not a multiple of
    // the group size; those threads do nothing.
    if (any(id >= dispatchDims)) return;

    Voxels[id] = float4(0, 0, 0, 0);
}
// One jump-flooding step for voxel `id`.
// Looks at the 3x3x3 neighbourhood spaced `offset` voxels apart (including the voxel
// itself) and keeps whichever assigned seed is closest in world space.
// Voxel layout (written by Preprocess): xyz = seed WORLD position, w != 0 once a seed
// has been assigned; w == 0 means "no seed yet" and such voxels are skipped.
// NOTE(review): neighbours are read from the same texture this pass writes, so a thread
// may observe a neighbour before or after that neighbour's update in this pass and the
// result can vary with scheduling. Reading from a separate copy (ping-pong) would make
// each pass deterministic.
void JFAIter(uint offset, uint3 id)
{
    if (any(id >= dispatchDims)) return;

    // World position of the current voxel center
    float3 idWorld = VoxelIndexToWorld(id);

    // Default to the current value so a voxel with no seeded neighbour is left unchanged.
    float4 closest = Voxels[id];
    float closestDist = 1e30;

    // Signed copies so that negative neighbour coordinates can be detected.
    int3 idI = (int3)id;
    int3 boundsI = (int3)dispatchDims;

    // Search 3x3x3 neighborhood at jump distance "offset"
    [unroll] for (int i = -1; i <= 1; i++)
    {
        [unroll] for (int j = -1; j <= 1; j++)
        {
            [unroll] for (int k = -1; k <= 1; k++)
            {
                int3 atI = idI + int3(i, j, k) * (int)offset;

                // Bounds test must be done in signed space; skip neighbours outside the volume.
                bool outOfBounds = any(atI < 0) || any(atI >= boundsI);
                if (!outOfBounds)
                {
                    float4 voxel = Voxels[(uint3)atI];

                    // Skip voxels that are not a seed / have not been assigned a seed yet.
                    if (voxel.w != 0)
                    {
                        // voxel.xyz stores seed WORLD position (set in Preprocess)
                        float d = distance(idWorld, voxel.xyz);
                        if (d < closestDist)
                        {
                            closestDist = d;
                            closest = voxel;
                        }
                    }
                }
            }
        }
    }

    Voxels[id] = closest;
}
// Kernel: convert the seed mask into the JFA working layout.
// Input:  Voxels[id].x is the seed marker (x > 0 => seed).
// Output: xyz = world-space center of this voxel, w = 1 for seeds and 0 otherwise.
//         Non-seeds keep w == 0 so they get skipped by JFA until a seed is assigned.
// NOTE(review): the thread-group size is not present in this file; 8x8x8 is assumed
// and must match the group counts computed by the CPU-side dispatch.
[numthreads(8, 8, 8)]
void Preprocess(uint3 id : SV_DispatchThreadID)
{
    if (any(id >= dispatchDims)) return;

    // Normalize the marker to a strict 0/1 flag. Storing the raw x would make any
    // non-zero value (including negatives) count as a seed under the "w == 0 => not a
    // seed" rule, contradicting the x > 0 convention.
    float isSeed = (Voxels[id].x > 0.0f) ? 1.0f : 0.0f;

    float3 seedWorld = VoxelIndexToWorld(id);

    // Store seed world position in xyz and the seed flag in w.
    Voxels[id] = float4(seedWorld, isSeed);
}
// Kernel: one jump-flooding pass at the current jump distance `samplingOffset`.
// The CPU dispatches this repeatedly, halving samplingOffset between passes.
// NOTE(review): no `#pragma kernel JFA` is visible in this file; add one if the CPU
// side dispatches this kernel. The 8x8x8 group size is assumed and must match the
// CPU-side dispatch.
[numthreads(8, 8, 8)]
void JFA(uint3 id : SV_DispatchThreadID)
{
    JFAIter(samplingOffset, id);
}
// Kernel: turn the JFA result (nearest seed world position in xyz, seed flag in w)
// into the final distance-field outputs.
// NOTE(review): the thread-group size is not present in this file; 8x8x8 is assumed
// and must match the group counts computed by the CPU-side dispatch.
[numthreads(8, 8, 8)]
void Postprocess(uint3 id : SV_DispatchThreadID)
{
    if (any(id >= dispatchDims)) return;

    float4 v = Voxels[id];

    // A voxel that never got a seed (w == 0) has no meaningful distance.
    // Here we just output 0 and stop; without the early return this zero would be
    // overwritten by the computation below.
    if (v.w == 0)
    {
        Voxels[id] = float4(0, 0, 0, 0);
        return;
    }

    float3 seedWorld = v.xyz;
    float3 idWorld = VoxelIndexToWorld(id);
    float maxExtent = MaxWorldExtent();

    // World distance normalized by max world extent, shifted by the thickness offset.
    float dist = (distance(seedWorld, idWorld) / maxExtent) - postProcessThickness;

    float ts = transitionSharpness;

    // Triangle wave in [0, 1]: 1 at every 1/fracRepetitions step of |dist|, 0 halfway between.
    float reps = frac(abs(dist) * fracRepetitions);
    reps = 2 * abs(reps - 0.5);

    // Normalize SDF (remap() comes from ShaderUtils.hlsl)
    float normalized = saturate(remap(0.0, cutOff, 1.0, 0.0, 2.0 * dist));
    float mask = smoothstep(0.5 - ts, 0.5 + ts, normalized);

    // Apply cutoff / repetitions
    float repDist = reps * mask;

    // repDist, mask, dist, 1
    Voxels[id] = float4(repDist, mask, dist, 1.0);
    // NOTE(review): this second store unconditionally overwrites the one above, so
    // repDist and mask never reach the texture. If the two stores were meant to be
    // alternatives selected by a preprocessor conditional, restore that #if/#else;
    // otherwise drop whichever output is unwanted.
    Voxels[id] = float4(dist, dist, dist, dist);
}