github.com/tspeterkim/flash-attention-minimal
該代碼符號定義基本(ben)與(yu)論文一致。但是有Br = Bc的(de)(de)隱(yin)含假設,不適合實際復雜的(de)(de)情況。
一、 輸入(ru)輸出定義
1.輸入
B: batch size
nh: number of heads
N: sequence length
d: head embedding dims
Q 尺(chi)寸 B x nh x N x d
K 尺寸 B x nh x N x d
V 尺寸(cun) B x nh x N x d
Br: 每(mei)個切分的Q的大(da)小
Bc: 每個切分的(de)K,V的(de)大小
Tr = ceil(N / Br)
Tc = ceil(N/ Bc)
2. 中間變量(liang)
l尺寸(cun)B x nh x N ,初始化為0
m 尺寸B x nh x N ,初始化(hua)為負無窮(qiong)
l和(he)m開辟在Global memory上(shang),使(shi)用時加載到寄存器上(shang)
3. 輸(shu)出
O 尺寸 B x nh x N x d
O開辟在Global memory上,使用(yong)時加載到shared memory上
二、 kernel函(han)數調用(yong)
grid_dim尺寸 B x nh,即第一個(ge)維(wei)度對應(ying)batch, 第二個(ge)維(wei)度對應(ying)head數(shu)(shu)目。則每個(ge)block處理的(de)是一個(ge)head的(de)運算。QKV的(de)數(shu)(shu)據尺寸(cun)都(dou)為N x d
dim3 grid_dim(B, nh); // batch_size x num_heads
block中的thread數(shu)目是(shi)(shi)Bc(實際應該是(shi)(shi)Br).
dim3 block_dim(Bc); // Bc threads per block
三、 核(he)函數(shu)實現(xian)
這里代碼假設(she)了Br = Bc。實(shi)際(ji)可(ke)能不是,有問題(ti)
1. 共(gong)享內存(cun)開辟
共(gong)享(xiang)內(nei)存(cun)大小為Bc x d x 3 + Br x Bc,分別(bie)存儲Qi, Kj, Vj和中(zhong)間變量Sij
2. 核(he)函(han)數結構
● 每個(ge)(ge)block處理一(yi)個(ge)(ge)head。
● 每個(ge)block單次(ci)計(ji)算(suan)處理的(de)是一個(ge)分(fen)塊(kuai)Qi, Kj, Vj的(de)計(ji)算(suan)。故block內部(bu)有Tr*Tc次(ci)循(xun)(xun)環(huan)(huan)(huan)。循(xun)(xun)環(huan)(huan)(huan)結構與論文定義相同。即外循(xun)(xun)環(huan)(huan)(huan)為(wei)Kj,Vj的(de)循(xun)(xun)環(huan)(huan)(huan),內循(xun)(xun)環(huan)(huan)(huan)為(wei)Qi的(de)循(xun)(xun)環(huan)(huan)(huan)。
● 一(yi)個(ge)thread單次處理(li)的(de)(de)(de)是Qi中(zhong)(zhong)一(yi)行的(de)(de)(de)數據。由于(yu)Qi中(zhong)(zhong)一(yi)行會和Kj中(zhong)(zhong)的(de)(de)(de)每一(yi)行計(ji)算(suan)。故(gu)在計(ji)算(suan)時,通過一(yi)個(ge)長度為(wei)Bc的(de)(de)(de)循環實現對Kj每行的(de)(de)(de)遍(bian)歷。對Vj的(de)(de)(de)遍(bian)歷同理(li)。
● 由(you)于(yu)一個thread對應的是Qi的一行(xing),所以該(gai)行(xing)m和l的狀態更新(xin)只需要(yao)寄存器保存。完成Q的遍歷后(hou)再(zai)寫入(ru)HBM即可。
3. 核函(han)數流程
- 定義外層(ceng)循 環,從HBM將Kj和Vj加(jia)載到(dao)shared memory
// 定義(yi)外(wai)層循(xun)環(huan),從HBM將Kj和Vj加載(zai)到(dao)shared memory
for (int j = 0; j < Tc; j++) {
// Load Kj, Vj to SRAM
// 將Kj和Vj加載到shared memory
for (int x = 0; x < d; x++) {
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
}
__syncthreads();
2. 定義(yi)內層循環,從HBM將(jiang)Qi加(jia)載(zai)(zai)到shared_memory。從HBM將(jiang)mi和li加(jia)載(zai)(zai)到寄存器。
// 定義(yi)內(nei)層(ceng)循環(huan),從(cong)HBM將Qi加載到shared_memory。從(cong)HBM將mi和(he)li加載到寄存器。
for (int i = 0; i < Tr; i++) {
// Load Qi to SRAM, l and m to registers
// 將Qi加載(zai)到(dao)shared memory,l和m加載(zai)到(dao)寄存器
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];
// S = QK^T, row_m = rowmax(S)
3 . 執行QK^T計算。由于一個線程代表(biao)Qi里面的(de)一行,對Kj的(de)遍歷通(tong)過一個長度為(wei)Bc的(de)循環實現。row_m代表(biao)文章里的(de)mij。
// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;
if (sum > row_m)
row_m = sum;
}
4. 按(an)行累加(jia)求和,計(ji)算lij。
// P = exp(S - row_m), row_l = rowsum(P)
// 計(ji)算(suan)l_ij
float row_l = 0;
for (int y = 0; y < Bc; y++) {
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}
5. 更新m_new和l_new。此時寄存器內(nei)有從HBM中加載(zai)的(de)上一時刻(ke)的(de)m_prev和l_prev。將當前線(xian)程計(ji)算的(de)m和l代入公(gong)式。
// Compute new m and l
// 更新m_i
float row_m_new = max(row_m_prev, row_m);
// 更新l_i
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);
6. 將O,l,m寫入HBM。
for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
// 獲取pv相乘的(de)結果,這里p
for (int y = 0; y < Bc; y++) {
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
+ (__expf(row_m - row_m_new) * pv));
}
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;
7. 在每(mei)個內循環(huan)結(jie)束后,即Q遍歷結(jie)束后,需要進(jin)行一個同步。因為接下來(lai)外循環(huan)會更新Kj和(he)Vj,不同步會使得Qi使用錯誤的Kj和(he)Vj值。
__syncthreads(); // otherwise, thread can use the wrong Kj, Vj in inner loop