消えたCUDA関連の旧ブログ記事を復元するひとり Advent Calendar 2024の記事です。

何の話か

例えばGeForce RTX 3080 (Shared memory/L1 Cache: 128KB)で走らせることを想定した以下のコードがあります。
このコードは64KiB分のShared memoryのデータをGlobal memoryに書き出すだけのコードです。

// 64 KiB
constexpr unsigned shared_memory_size = 64 * 1024;

__global__ void kernel(float* const ptr) {
  __shared__ float smem[shared_memory_size / sizeof(float)];

  const unsigned long tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= shared_memory_size / sizeof(float)) return;

  ptr[tid] = smem[tid];
}

int main() {
  float *d_array;
  cudaMalloc(&d_array, shared_memory_size);

  kernel<<<shared_memory_size / 256, 256>>>(d_array);
  cudaDeviceSynchronize();
}

このコードをnvccでコンパイルすると、アクセスするShared memoryのアドレスは搭載されているShared memoryの大きさを超えていないですが、エラーとなります。

ptxas error   : Entry function '_Z6kernelPf' uses too much shared data (0x10000 bytes, 0xc000 max)

0xc000 bytesは48KiBです。
ではどう書いたら48KiB以上のSharedメモリを使えるようになるかというのがこの記事です。


48KiBを超えるSharedメモリの確保の仕方

48KiB以上のShared memoryを確保するために行うことは3つです。

  1. カーネル関数内でのShared memoryの宣言を修正
  2. cudaFuncSetAttribute関数で、立ち上げるカーネル関数が必要とするShared memoryの大きさを設定
  3. カーネル関数の立ち上げ時に確保するShared memoryのサイズを指定

この3つを行うよう上記のコードを書き換えると以下のようになります。

// 64 KiB
constexpr unsigned shared_memory_size = 64 * 1024;

__global__ void kernel(float* const ptr) {
  // 1.
  extern __shared__ float smem[];

  const unsigned long tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= shared_memory_size / sizeof(float)) return;

  ptr[tid] = smem[tid];
}

int main() {
  float *d_array;
  cudaMalloc(&d_array, shared_memory_size);

  // 2.
  cudaFuncSetAttribute(&kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size);
  // 3.
  kernel<<<shared_memory_size / 256, 256, shared_memory_size>>>(d_array);
  cudaDeviceSynchronize();
}

終わりに

はじめのコードでは、コンパイルの時点ではどのアーキで実行されるかは判定できないため、サポートしているGPUの最小Shared memoryサイズを上限としてエラーを出しているということですかね。
ptxasでエラーが出ていることからも分かるとおり、ptxからカーネルイメージに落とす際にエラーが出るのですが、cuからptxへの変換はエラーなく行われます。
ですので、nvcc -ptx main.error.cuでptxを見てみると、64KiBのShared memoryをとろうとしていることが確認できます。

.version 7.2
.target sm_52
.address_size 64

.visible .entry _Z6kernelPf(
  .param .u64 _Z6kernelPf_param_0
)
{
  // iroiro

  // demoted variable
  .shared .align 4 .b8 _ZZ6kernelPfE4smem[65536];

  // iroiro

}