CUDAでLambda関数
消えたCUDA関連の旧ブログ記事を復元するひとり Advent Calendar 2024の記事です。
CUDAではCUDA 8からカーネル関数内でLambda関数が使えます. で,PTXコードがどうなるのか気になったので調べてみました.
検証コード
このコードには計3つのカーネル関数があります.
- kernel_0
カーネル関数内に直にLambda関数を書くパターン - kernel_1
カーネル関数の引数にLambda関数を書くパターン - kernel_1
カーネル関数の引数にホストのデータをキャプチャするLambda関数を書くパターン
PTXコード
nvcc lambda.cu --ptx --expt-extended-lambda -arch=sm_75
でPTXコードを出力させます.
.version 6.4 .target sm_75 .address_size 64 // .globl _Z8kernel_0v .extern .func (.param .b32 func_retval0) vprintf ( .param .b64 vprintf_param_0, .param .b64 vprintf_param_1 ) ; .global .align 1 .b8 $str[7] = {112, 111, 105, 32, 48, 10, 0}; .global .align 1 .b8 $str1[7] = {112, 111, 105, 32, 49, 10, 0}; .global .align 1 .b8 $str2[10] = {112, 111, 105, 32, 50, 32, 37, 100, 10, 0}; .visible .entry _Z8kernel_0v( ) { .reg .b32 %r<2>; .reg .b64 %rd<4>; mov.u64 %rd1, $str; cvta.global.u64 %rd2, %rd1; mov.u64 %rd3, 0; // Callseq Start 0 { .reg .b32 temp_param_reg; .param .b64 param0; st.param.b64 [param0+0], %rd2; .param .b64 param1; st.param.b64 [param1+0], %rd3; .param .b32 retval0; call.uni (retval0), vprintf, ( param0, param1 ); ld.param.b32 %r1, [retval0+0]; }// Callseq End 0 ret; } // .globl _Z8kernel_1IZ4mainEUlvE_EvT_ .visible .entry _Z8kernel_1IZ4mainEUlvE_EvT_( .param .align 1 .b8 _Z8kernel_1IZ4mainEUlvE_EvT__param_0[1] ) { .reg .b32 %r<2>; .reg .b64 %rd<4>; mov.u64 %rd1, $str1; cvta.global.u64 %rd2, %rd1; mov.u64 %rd3, 0; // Callseq Start 1 { .reg .b32 temp_param_reg; .param .b64 param0; st.param.b64 [param0+0], %rd2; .param .b64 param1; st.param.b64 [param1+0], %rd3; .param .b32 retval0; call.uni (retval0), vprintf, ( param0, param1 ); ld.param.b32 %r1, [retval0+0]; }// Callseq End 1 ret; } // .globl _Z8kernel_1IZ4mainEUlvE0_EvT_ .visible .entry _Z8kernel_1IZ4mainEUlvE0_EvT_( .param .align 4 .b8 _Z8kernel_1IZ4mainEUlvE0_EvT__param_0[4] ) { .local .align 8 .b8 __local_depot2[8]; .reg .b64 %SP; .reg .b64 %SPL; .reg .b32 %r<3>; .reg .b64 %rd<5>; mov.u64 %SPL, __local_depot2; cvta.local.u64 %SP, %SPL; ld.param.u32 %r1, [_Z8kernel_1IZ4mainEUlvE0_EvT__param_0]; add.u64 %rd1, %SP, 0; add.u64 %rd2, %SPL, 0; st.local.u32 [%rd2], %r1; mov.u64 %rd3, $str2; cvta.global.u64 %rd4, %rd3; // Callseq Start 2 { .reg .b32 temp_param_reg; .param .b64 param0; st.param.b64 [param0+0], %rd4; .param .b64 param1; st.param.b64 [param1+0], %rd1; .param .b32 retval0; call.uni (retval0), vprintf, ( param0, param1 ); ld.param.b32 %r2, [retval0+0]; }// Callseq End 2 ret; }
マングリングされた
- _Z8kernel_0v
- _Z8kernel_1IZ4mainEUlvE_EvT_
- _Z8kernel_1IZ4mainEUlvE0_EvT_
の計3つのカーネル関数が見えます.
すべてinline展開されています.
このコードでのLambda関数の中身はコンパイル時には決定しているものなので普通の__device__関数と扱いは同じなようです.
特筆すべきは3つめのホストのデータをキャプチャするLambda関数でしょうか.
.visible .entry _Z8kernel_1IZ4mainEUlvE0_EvT_( .param .align 4 .b8 _Z8kernel_1IZ4mainEUlvE0_EvT__param_0[4] ) ... ld.param.u32 %r1, [_Z8kernel_1IZ4mainEUlvE0_EvT__param_0]; ... st.local.u32 [%rd2], %r1; ...
とあるので,カーネル関数の引数で配列としてとしてキャプチャするデータが渡されているようです.
おまけ 0
カーネル関数内のprintfですが,出力する文字列は
.global .align 1 .b8 $str[7] = {112, 111, 105, 32, 48, 10, 0}; .global .align 1 .b8 $str1[7] = {112, 111, 105, 32, 49, 10, 0}; .global .align 1 .b8 $str2[10] = {112, 111, 105, 32, 50, 32, 37, 100, 10, 0};
のようにglobalな空間に定義しておいて,%dなどで変数を表示する場合は
st.local.u32 [%rd2], %r1;
のようにローカルメモリの__local_depot2に配置してvprintfを呼ぶんですね.
おまけ 1
C++のstd::functionのCUDA版,nvstd::functionがあるそうです.
詳しくはNew Compiler Features in CUDA 8をどうぞ.
余談ですが,NVIDIAはWMMA APIではnvcuda名前空間を使いこちらではnvstd名前空間をつかっているのでちょっぴり統一してほしいです.