Skip to content

Instantly share code, notes, and snippets.

@VictorTaelin
Created April 9, 2024 21:22
Show Gist options
  • Save VictorTaelin/3095032b157cbe79ec368347690fd893 to your computer and use it in GitHub Desktop.
Save VictorTaelin/3095032b157cbe79ec368347690fd893 to your computer and use it in GitHub Desktop.
Fast CUDA block-local prefix sum (scansum) using warp sync primitives (__shfl_up_sync)
// Fast block-local prefix-sum on CUDA, using warp-syncs.
// The input is an array of u32. It is mutated in place. Example:
// arr = [1,1,1,1,...]
// Becomes:
// arr = [1,2,3,4,...]
// The number of elements must be equal to threads per block (TPB).
#include <stdio.h>
#include <cuda_runtime.h>
// Unsigned integer type used for every array element and index in this file.
typedef unsigned int u32;
// Threads Per Block.
// TPB_L2 is the base-2 logarithm of the block size; TPB is the block size
// itself (1 << 8 = 256). Because TPB is derived by a shift it is always a
// power of two.
#define TPB_L2 8
#define TPB (1 << TPB_L2)
// Amount of times to repeat, for benchmark (32 * 256 * 256 = 2097152).
// The kernel repeats the scan this many times and the host divides the
// measured time by it to report a per-scan figure.
#define TIMES (32 * 256 * 256)
// OLD SCANSUM ("work-efficient" algorithm) - exclusive
//
// In-place exclusive prefix sum over arr[0 .. TPB-1], done as an up-sweep
// followed by a down-sweep. The thread index (threadIdx.x) selects which pair
// of elements a thread combines at each level.
//
// The function executes block barriers, so every thread of the block has to
// call it. Returns the total of all TPB inputs, read from arr[TPB-1] once the
// up-sweep has finished.
__device__ u32 scansum_0(u32* arr) {
  const u32 lane = threadIdx.x;

  // Up-sweep: at each level, the first thread of every `span`-wide group folds
  // the left half's partial sum into the group's last element.
  for (u32 half = 1; half < TPB; half <<= 1) {
    const u32 span = half << 1;
    // span is a power of two, so masking is the same test as lane % span == 0.
    if ((lane & (span - 1)) == 0) {
      arr[lane + span - 1] += arr[lane + half - 1];
    }
    __syncthreads();
  }

  // After the up-sweep the last element holds the sum of the whole array.
  // Everyone reads it before thread 0 overwrites it with the scan's seed (0).
  const u32 total = arr[TPB - 1];
  __syncthreads();
  if (lane == 0) {
    arr[TPB - 1] = 0;
  }
  __syncthreads();

  // Down-sweep: walk the levels from widest to narrowest. The group's last
  // element carries the prefix for the group; the left half inherits it and
  // the right half receives it plus the left half's partial sum.
  for (u32 half = TPB >> 1; half != 0; half >>= 1) {
    const u32 span = half << 1;
    if ((lane & (span - 1)) == 0) {
      const u32 lo = lane + half - 1;
      const u32 hi = lane + span - 1;
      const u32 left = arr[lo];
      const u32 carry = arr[hi];
      arr[lo] = carry;
      arr[hi] = carry + left;
    }
    __syncthreads();
  }

  return total;
}
// NEW SCANSUM (using warp syncs) - inclusive
//
// In-place inclusive prefix sum over arr[0 .. TPB-1]. Thread `tid` owns element
// arr[tid]. The scan runs in three steps:
//   1. each warp scans its own 32 elements with register shuffles;
//   2. the last lane of each warp publishes the warp total to shared memory,
//      and the first warp scans those totals;
//   3. every warp except the first adds the total of all preceding warps.
//
// Preconditions: the block is one-dimensional with exactly TPB threads, and
// every thread calls this function (it contains block barriers).
//
// No barrier is executed after step 3. The caller must __syncthreads() before
// any thread reads an element written by another thread, and before calling
// this function again (the next call overwrites the shared warp totals).
//
// Returns this thread's warp-local inclusive prefix, i.e. the value computed in
// step 1 *before* the preceding warps' totals are added. Note that this differs
// from scansum_0, which returns the block total.
__device__ u32 scansum_1(u32* arr) {
  // One slot per warp (at least one, so the array is never zero-sized).
  __shared__ u32 wsum[(TPB + 31) / 32];

  const u32 tid = threadIdx.x; // thread id
  const u32 wid = tid / 32;    // warp id
  const u32 lid = tid % 32;    // lane id within the warp
  const u32 nwarps = TPB / 32; // number of full warps in the block

  // Explicit participation masks for the shuffles. __activemask() only reports
  // which lanes happen to be converged at that instant, which is not
  // guaranteed to be the set of lanes that must exchange data; naming the
  // lanes explicitly makes the *_sync intrinsics wait for all of them.
  // TPB is a power of two, so a block is either a single partial warp
  // (TPB < 32) or a whole number of full warps. The `% 32` only keeps the
  // shift in range on the branch that is not selected.
  const u32 warp_mask = TPB >= 32 ? 0xffffffffu : (1u << (TPB % 32)) - 1u;
  const u32 head_mask = nwarps >= 32 ? 0xffffffffu : (1u << (nwarps % 32)) - 1u;

  // Step 1: warp-level inclusive scan. Lanes below `k` have no
  // source lane `k` positions down and keep their value.
  u32 sum = arr[tid];
  for (u32 k = 1; k < 32; k *= 2) {
    const u32 num = __shfl_up_sync(warp_mask, sum, k);
    if (lid >= k) {
      sum += num;
    }
  }
  arr[tid] = sum;

  // Step 2a: the last lane holds the warp's total; publish it.
  if (lid == 31) {
    wsum[wid] = sum;
  }
  __syncthreads();

  // Step 2b: the first `nwarps` lanes of warp 0 scan the warp totals. Exactly
  // the lanes named in head_mask enter this branch.
  if (wid == 0 && lid < nwarps) {
    u32 ssum = wsum[lid];
    for (u32 k = 1; k < nwarps; k *= 2) {
      const u32 snum = __shfl_up_sync(head_mask, ssum, k);
      if (lid >= k) {
        ssum += snum;
      }
    }
    wsum[lid] = ssum;
  }
  __syncthreads();

  // Step 3: add the combined total of all warps before this one.
  if (wid > 0) {
    arr[tid] += wsum[wid - 1];
  }

  return sum;
}
// Benchmark kernel: repeats a block-local scan TIMES times over a shared-memory
// buffer and copies the final result to global memory.
//
// Launch layout: a single one-dimensional block of exactly TPB threads (only
// threadIdx.x is used for indexing). Uses TPB * sizeof(u32) bytes of static
// shared memory, plus whatever scansum_1 allocates.
//
// arr: device buffer with room for at least TPB elements. On return arr[i]
//      holds the scan result for input element i, where the input is the
//      sequence 0, 1, 2, ..., TPB-1.
__global__ void scansum_kernel(u32* arr) {
  // Only TPB elements are ever initialised and scanned, so that is all the
  // buffer needs to hold.
  __shared__ u32 smem[TPB];
  const u32 tid = threadIdx.x;

  for (u32 i = 0; i < TIMES; ++i) {
    // Re-seed the input each round so every iteration scans the same data.
    smem[tid] = tid;
    __syncthreads();
    scansum_1(smem);
    // Barrier required before the buffer is re-seeded / scanned again.
    __syncthreads();
  }

  // Write back exactly the TPB scanned elements. Copying a second TPB-sized
  // half would read uninitialised shared memory and write past the end of a
  // TPB-element output buffer.
  arr[tid] = smem[tid];
}
// Reports a failed CUDA runtime call on stderr.
// Returns true when `err` is cudaSuccess, false otherwise.
static bool cuda_ok(cudaError_t err, const char* what) {
  if (err != cudaSuccess) {
    fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
    return false;
  }
  return true;
}

// Host driver: runs the scan benchmark kernel on one block of TPB threads,
// prints the average time per scan in microseconds, then prints the TPB
// resulting values. Returns 0 on success and 1 if any CUDA call fails.
int main() {
  const size_t bytes = TPB * sizeof(u32);

  u32 h_arr[TPB] = {0};
  u32* d_arr = NULL;
  cudaEvent_t start = NULL;
  cudaEvent_t stop = NULL;
  float milliseconds = 0;

  // Each step only runs if every previous one succeeded.
  bool ok = cuda_ok(cudaMalloc(&d_arr, bytes), "cudaMalloc")
         && cuda_ok(cudaMemcpy(d_arr, h_arr, bytes, cudaMemcpyHostToDevice),
                    "cudaMemcpy (host to device)")
         && cuda_ok(cudaEventCreate(&start), "cudaEventCreate (start)")
         && cuda_ok(cudaEventCreate(&stop), "cudaEventCreate (stop)")
         && cuda_ok(cudaEventRecord(start), "cudaEventRecord (start)");

  if (ok) {
    scansum_kernel<<<1, TPB>>>(d_arr);
    // A kernel launch returns nothing; configuration errors are picked up
    // here, and execution errors surface at the synchronising calls below.
    ok = cuda_ok(cudaGetLastError(), "scansum_kernel launch")
      && cuda_ok(cudaEventRecord(stop), "cudaEventRecord (stop)")
      && cuda_ok(cudaEventSynchronize(stop), "cudaEventSynchronize")
      && cuda_ok(cudaMemcpy(h_arr, d_arr, bytes, cudaMemcpyDeviceToHost),
                 "cudaMemcpy (device to host)")
      && cuda_ok(cudaEventElapsedTime(&milliseconds, start, stop),
                 "cudaEventElapsedTime");
  }

  if (ok) {
    // Elapsed time is in milliseconds; convert to microseconds per scan.
    printf("Scansum time: %f us\n", milliseconds * 1000.0 / (float)TIMES);
    for (int i = 0; i < TPB; ++i) {
      printf("%u ", h_arr[i]);
    }
    printf("\n");
  }

  // Release whatever was acquired, on success and on failure alike.
  if (start != NULL) {
    cudaEventDestroy(start);
  }
  if (stop != NULL) {
    cudaEventDestroy(stop);
  }
  cudaFree(d_arr);

  return ok ? 0 : 1;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment