CUDAでShared memoryを48KiB以上使うには
消えたCUDA関連の旧ブログ記事を復元するひとり Advent Calendar 2024の記事です。
何の話か
例えばGeForce RTX 3080 (Shared memory/L1 Cache: 128KB)で走らせることを想定した以下のコードがあります。
このコードは64KiB分のShared memoryのデータをGlobal memoryに書き出すだけのコードです。
// 64 KiB
constexpr unsigned shared_memory_size = 64 * 1024;
__global__ void kernel(float* const ptr) {
__shared__ float smem[shared_memory_size / sizeof(float)];
const unsigned long tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= shared_memory_size / sizeof(float)) return;
ptr[tid] = smem[tid];
}
int main() {
float *d_array;
cudaMalloc(&d_array, shared_memory_size);
kernel<<<shared_memory_size / 256, 256>>>(d_array);
cudaDeviceSynchronize();
}このコードをnvccでコンパイルすると、アクセスするShared memoryのアドレスは搭載されているShared memoryの大きさを超えていないですが、エラーとなります。
ptxas error : Entry function '_Z6kernelPf' uses too much shared data (0x10000 bytes, 0xc000 max)
0xc000 bytesは48KiBです。
ではどう書いたら48KiB以上のSharedメモリを使えるようになるかというのがこの記事です。
48KiBを超えるSharedメモリの確保の仕方
48KiB以上のShared memoryを確保するために行うことは3つです。
- カーネル関数内でのShared memoryの宣言を修正
- cudaFuncSetAttribute関数で、立ち上げるカーネル関数が必要とするShared memoryの大きさを設定
- カーネル関数の立ち上げ時に確保するShared memoryのサイズを指定
この3つを行うよう上記のコードを書き換えると以下のようになります。
// 64 KiB
constexpr unsigned shared_memory_size = 64 * 1024;
__global__ void kernel(float* const ptr) {
// 1.
extern __shared__ float smem[];
const unsigned long tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= shared_memory_size / sizeof(float)) return;
ptr[tid] = smem[tid];
}
int main() {
float *d_array;
cudaMalloc(&d_array, shared_memory_size);
// 2.
cudaFuncSetAttribute(&kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size);
// 3.
kernel<<<shared_memory_size / 256, 256, shared_memory_size>>>(d_array);
cudaDeviceSynchronize();
}終わりに
はじめのコードでは、コンパイルの時点ではどのアーキで実行されるかは判定できないため、サポートしているGPUの最小Shared memoryサイズを上限としてエラーを出しているということですかね。
ptxasでエラーが出ていることからも分かるとおり、ptxからカーネルイメージに落とす際にエラーが出るのですが、cuからptxへの変換はエラーなく行われます。
ですので、nvcc -ptx main.error.cuでptxを見てみると、64KiBのShared memoryをとろうとしていることが確認できます。
.version 7.2
.target sm_52
.address_size 64
.visible .entry _Z6kernelPf(
.param .u64 _Z6kernelPf_param_0
)
{
// iroiro
// demoted variable
.shared .align 4 .b8 _ZZ6kernelPfE4smem[65536];
// iroiro
}