// Description: Compute shader for performing forward Discrete Cosine Transform (DCT) on an image block.
// Pitch, in floats, between consecutive rows of the group-shared tile: one element more
// than GROUP_SIZE_X. GROUP_SIZE_X is not defined in the visible part of this file;
// presumably it is supplied by the build (e.g. a compiler define) - TODO confirm.
// NOTE(review): the extra element looks like padding against shared-memory bank
// conflicts - confirm.
#define SHARED_MEMORY_STRIDE (GROUP_SIZE_X + 1)
// Cosine constants for the 8-point DCT. Each value equals sqrt(2) * cos(k * pi / 16)
// for the k noted beside it.
static const float COSINE_A = 1.387039845322148f; // k = 1
static const float COSINE_B = 1.306562964876377f; // k = 2
static const float COSINE_C = 1.175875602419359f; // k = 3
static const float COSINE_D = 0.785694958387102f; // k = 5
static const float COSINE_E = 0.541196100146197f; // k = 6
static const float COSINE_F = 0.275899379282943f; // k = 7
// Equals 1 / sqrt(8). InPlaceDCT multiplies every output coefficient by this factor.
static const float NORMALIZATION_FACTOR = 0.3535533905932737f;
// Shader resources.
// TextureSource: single-channel float texture bound to slot t0. PerformForwardDCT reads
// it with Load() at mip level 0 using integer texel coordinates.
Texture2D<float> TextureSource : register(t0);
// OutputBuffer: read/write structured buffer of floats bound to slot u0. PerformForwardDCT
// writes its results here through a linear index.
RWStructuredBuffer<float> OutputBuffer : register(u0);
// Performs an in-place forward 8-point DCT (DCT-II with orthonormal scaling) on one
// vector stored in the group-shared tile.
//
//   sharedBlock - group-shared tile that holds the vector.
//   index       - offset of the vector's first element. Element k of the vector is read
//                 from, and coefficient k is written back to,
//                 index + k * SHARED_MEMORY_STRIDE for k = 0..7.
//
// Output k is
//   X[k] = c(k) / 2 * sum over n of x[n] * cos((2n + 1) * k * pi / 16),
// with c(0) = 1 / sqrt(2) and c(k) = 1 otherwise. The COSINE_* constants carry a factor
// of sqrt(2), which together with NORMALIZATION_FACTOR = 1 / sqrt(8) gives the 1/2.
//
// NOTE(review): the stores below only reach the caller's tile if the compiler passes this
// groupshared array by reference - confirm for the targeted compiler / HLSL version.
void InPlaceDCT(groupshared float sharedBlock[], uint index)
{
    // Read all eight inputs before storing anything: coefficient k overwrites input k.
    const float x0 = sharedBlock[index];
    const float x1 = sharedBlock[index + 1 * SHARED_MEMORY_STRIDE];
    const float x2 = sharedBlock[index + 2 * SHARED_MEMORY_STRIDE];
    const float x3 = sharedBlock[index + 3 * SHARED_MEMORY_STRIDE];
    const float x4 = sharedBlock[index + 4 * SHARED_MEMORY_STRIDE];
    const float x5 = sharedBlock[index + 5 * SHARED_MEMORY_STRIDE];
    const float x6 = sharedBlock[index + 6 * SHARED_MEMORY_STRIDE];
    const float x7 = sharedBlock[index + 7 * SHARED_MEMORY_STRIDE];

    // Mirrored sums feed the even coefficients.
    const float sum07 = x0 + x7;
    const float sum16 = x1 + x6;
    const float sum25 = x2 + x5;
    const float sum34 = x3 + x4;

    // Mirrored differences feed the odd coefficients. The operand order matters: the
    // odd-coefficient formulas below are written for (x6 - x1) and (x4 - x3). Taking
    // them as (x1 - x6) and (x3 - x4) flips the sign of every term that uses them and
    // yields values that are not DCT coefficients.
    const float diff07 = x0 - x7;
    const float diff61 = x6 - x1;
    const float diff25 = x2 - x5;
    const float diff43 = x4 - x3;

    // Second butterfly stage for the even half.
    const float outerSum  = sum07 + sum34;
    const float outerDiff = sum07 - sum34;
    const float innerSum  = sum16 + sum25;
    const float innerDiff = sum16 - sum25;

    // Even coefficients.
    sharedBlock[index] = NORMALIZATION_FACTOR * (outerSum + innerSum);
    sharedBlock[index + 2 * SHARED_MEMORY_STRIDE] = NORMALIZATION_FACTOR * (COSINE_B * outerDiff + COSINE_E * innerDiff);
    sharedBlock[index + 4 * SHARED_MEMORY_STRIDE] = NORMALIZATION_FACTOR * (outerSum - innerSum);
    sharedBlock[index + 6 * SHARED_MEMORY_STRIDE] = NORMALIZATION_FACTOR * (COSINE_E * outerDiff - COSINE_B * innerDiff);

    // Odd coefficients.
    sharedBlock[index + 1 * SHARED_MEMORY_STRIDE] = NORMALIZATION_FACTOR * (COSINE_A * diff07 - COSINE_C * diff61 + COSINE_D * diff25 - COSINE_F * diff43);
    sharedBlock[index + 3 * SHARED_MEMORY_STRIDE] = NORMALIZATION_FACTOR * (COSINE_C * diff07 + COSINE_F * diff61 - COSINE_A * diff25 + COSINE_D * diff43);
    sharedBlock[index + 5 * SHARED_MEMORY_STRIDE] = NORMALIZATION_FACTOR * (COSINE_D * diff07 + COSINE_A * diff61 + COSINE_F * diff25 - COSINE_C * diff43);
    sharedBlock[index + 7 * SHARED_MEMORY_STRIDE] = NORMALIZATION_FACTOR * (COSINE_F * diff07 + COSINE_D * diff61 + COSINE_C * diff25 + COSINE_A * diff43);
}
// Group-shared tile used by PerformForwardDCT: GROUP_SIZE_Y rows of SHARED_MEMORY_STRIDE
// floats. Declared at file scope, where HLSL expects groupshared storage to live.
groupshared float sharedBlock[GROUP_SIZE_Y * SHARED_MEMORY_STRIDE];

// Compute shader entry point for the forward DCT.
//
// Each thread copies BLOCK_SIZE vertically consecutive texels from TextureSource into the
// group-shared tile, the group runs two InPlaceDCT passes over the tile, and each thread
// then copies its BLOCK_SIZE tile elements to OutputBuffer.
//
//   groupID          - thread-group index; selects the texture region that is loaded.
//   groupThreadID    - thread index inside the group; selects the tile column and, for
//                      the DCT passes, the vector a thread transforms.
//   dispatchThreadID - global thread index; used to form the output index.
//
// GROUP_SIZE_X, GROUP_SIZE_Y and BLOCK_SIZE are not defined in the visible part of this
// file; presumably they are supplied by the build - TODO confirm.
//
// NOTE(review), not changed here because the intended layout cannot be determined from
// this file alone:
//  - Both InPlaceDCT calls step through the tile by SHARED_MEMORY_STRIDE (the helper has
//    no stride parameter) and differ only in their start offset. Verify that these offsets
//    really address the rows and then the columns of each block, and that no two threads
//    transform the same vector concurrently.
//  - outputIndex multiplies its row term by dispatchThreadID.z. With a thread-group depth
//    of 1 that value equals groupID.z, i.e. 0 unless the dispatch uses several groups
//    in Z, in which case the row term vanishes.
//  - Tile addressing assumes threadRowOffset < SHARED_MEMORY_STRIDE and
//    BLOCK_SIZE <= GROUP_SIZE_Y; nothing here enforces either.
[numthreads(GROUP_SIZE_X, GROUP_SIZE_Y, 1)]
void PerformForwardDCT(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID, uint3 dispatchThreadID : SV_DispatchThreadID)
{
    // Position of this thread's column inside the tile. groupThreadID.z is always 0
    // because the group depth is 1, so threadColOffset is 0 as well.
    const uint threadRowOffset = groupThreadID.y * BLOCK_SIZE + groupThreadID.x;
    const uint threadColOffset = groupThreadID.z * BLOCK_SIZE;
    const uint tileBase = threadColOffset * SHARED_MEMORY_STRIDE + threadRowOffset;

    // Top texel of the column this thread loads.
    const uint textureCoordX = groupID.x * GROUP_SIZE_X + threadRowOffset;
    const uint textureCoordY = groupID.y * GROUP_SIZE_Y + threadColOffset;

    // Copy BLOCK_SIZE texels into consecutive tile rows. The tile row must advance with
    // the texel row; otherwise every iteration overwrites the same element and only the
    // last texel survives.
    for (uint loadRow = 0; loadRow < BLOCK_SIZE; loadRow++)
    {
        sharedBlock[tileBase + loadRow * SHARED_MEMORY_STRIDE] = TextureSource.Load(int3(textureCoordX, textureCoordY + loadRow, 0));
    }
    GroupMemoryBarrierWithGroupSync(); // The whole tile must be loaded before any DCT pass reads it.

    // First DCT pass ("rows" in the original comments).
    if (groupThreadID.x < BLOCK_SIZE)
    {
        InPlaceDCT(sharedBlock, threadColOffset * SHARED_MEMORY_STRIDE + groupThreadID.x);
    }
    GroupMemoryBarrierWithGroupSync(); // The second pass consumes results of the first.

    // Second DCT pass ("columns" in the original comments).
    if (groupThreadID.y < BLOCK_SIZE)
    {
        InPlaceDCT(sharedBlock, threadRowOffset * SHARED_MEMORY_STRIDE + groupThreadID.y);
    }
    // The write-back reads tile elements that other threads may have produced in the
    // second pass, so wait for the whole group first.
    GroupMemoryBarrierWithGroupSync();

    // Copy this thread's BLOCK_SIZE tile elements to the output buffer, one tile row per
    // output row.
    const uint outputIndex = dispatchThreadID.y * GROUP_SIZE_X * dispatchThreadID.z + dispatchThreadID.x;
    for (uint storeRow = 0; storeRow < BLOCK_SIZE; storeRow++)
    {
        OutputBuffer[outputIndex + storeRow * GROUP_SIZE_X] = sharedBlock[tileBase + storeRow * SHARED_MEMORY_STRIDE];
    }
}