亚欧色一区w666天堂,色情一区二区三区免费看,少妇特黄A片一区二区三区,亚洲人成网站999久久久综合,国产av熟女一区二区三区

  • 發布文章
  • 消息中心
點贊
收藏
評論
分享
原創

flash attention代碼解讀

2024-06-26 09:44:34
84
0
代碼地址:

github.com/tspeterkim/flash-attention-minimal

該代碼符號定義基本(ben)與(yu)論文一致。但是有Br = Bc的(de)(de)隱(yin)含假設,不適合實際復雜的(de)(de)情況。

 

一、  輸入(ru)輸出定義

1.輸入

B: batch size

nh: number of heads

N: sequence length

d: head embedding dims

Q 尺(chi)寸 B x nh x N x d

K 尺寸 B x nh x N x d

V 尺寸(cun) B x nh x N x d

Br: 每(mei)個切分的Q的大(da)小

Bc: 每個切分的(de)K,V的(de)大小

Tr = ceil(N / Br)

Tc = ceil(N/ Bc)

2.  中間變量(liang)

l尺寸(cun)B x nh x N  ,初始化為0

m 尺寸B x nh x N ,初始化(hua)為負無窮(qiong)

l和(he)m開辟在Global memory上(shang),使(shi)用時加載到寄存器上(shang)

3.  輸(shu)出

O 尺寸 B x nh x N x d

O開辟在Global memory上,使用(yong)時加載到shared memory上

二、  kernel函(han)數調用(yong)

grid_dim尺寸 B x nh,即第一個(ge)維(wei)度對應(ying)batch, 第二個(ge)維(wei)度對應(ying)head數(shu)(shu)目。則每個(ge)block處理的(de)是一個(ge)head的(de)運算。QKV的(de)數(shu)(shu)據尺寸(cun)都(dou)為N x d

 

dim3 grid_dim(B, nh); // batch_size x num_heads

 

block中的thread數(shu)目是(shi)(shi)Bc(實際應該是(shi)(shi)Br).

 

dim3 block_dim(Bc); // Bc threads per block

 

三、  核(he)函數(shu)實現(xian)

這里代碼假設(she)了Br = Bc。實(shi)際(ji)可(ke)能不是,有問題(ti)

1.  共(gong)享內存(cun)開辟

共(gong)享(xiang)內(nei)存(cun)大小為Bc x d x 3 + Br x Bc,分別(bie)存儲Qi, Kj, Vj和中(zhong)間變量Sij

2.  核(he)函(han)數結構

● 每個(ge)(ge)block處理一(yi)個(ge)(ge)head。

● 每個(ge)block單次(ci)計(ji)算(suan)處理的(de)是一個(ge)分(fen)塊(kuai)Qi, Kj, Vj的(de)計(ji)算(suan)。故block內部(bu)有Tr*Tc次(ci)循(xun)(xun)環(huan)(huan)(huan)。循(xun)(xun)環(huan)(huan)(huan)結構與論文定義相同。即外循(xun)(xun)環(huan)(huan)(huan)為(wei)Kj,Vj的(de)循(xun)(xun)環(huan)(huan)(huan),內循(xun)(xun)環(huan)(huan)(huan)為(wei)Qi的(de)循(xun)(xun)環(huan)(huan)(huan)。

● 一(yi)個(ge)thread單次處理(li)的(de)(de)(de)是Qi中(zhong)(zhong)一(yi)行的(de)(de)(de)數據。由于(yu)Qi中(zhong)(zhong)一(yi)行會和Kj中(zhong)(zhong)的(de)(de)(de)每一(yi)行計(ji)算(suan)。故(gu)在計(ji)算(suan)時,通過一(yi)個(ge)長度為(wei)Bc的(de)(de)(de)循環實現對Kj每行的(de)(de)(de)遍(bian)歷。對Vj的(de)(de)(de)遍(bian)歷同理(li)。

● 由(you)于(yu)一個thread對應的是Qi的一行(xing),所以該(gai)行(xing)m和l的狀態更新(xin)只需要(yao)寄存器保存。完成Q的遍歷后(hou)再(zai)寫入(ru)HBM即可。

3.  核函(han)數流程

  1. 定義外層(ceng)循 環,從HBM將Kj和Vj加(jia)載到(dao)shared memory
     
    // 定義(yi)外(wai)層循(xun)環(huan),從HBM將Kj和Vj加載(zai)到(dao)shared memory
    for (int j = 0; j < Tc; j++) {

    // Load Kj, Vj to SRAM
    // 將Kj和Vj加載到shared memory
    for (int x = 0; x < d; x++) {
    Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
    Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
    }
    __syncthreads();

 

2.  定義(yi)內層循環,從HBM將(jiang)Qi加(jia)載(zai)(zai)到shared_memory。從HBM將(jiang)mi和li加(jia)載(zai)(zai)到寄存器。

// 定義(yi)內(nei)層(ceng)循環(huan),從(cong)HBM將Qi加載到shared_memory。從(cong)HBM將mi和(he)li加載到寄存器。
for (int i = 0; i < Tr; i++) {

// Load Qi to SRAM, l and m to registers
// 將Qi加載(zai)到(dao)shared memory,l和m加載(zai)到(dao)寄存器
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];

// S = QK^T, row_m = rowmax(S)

3 . 執行QK^T計算。由于一個線程代表(biao)Qi里面的(de)一行,對Kj的(de)遍歷通(tong)過一個長度為(wei)Bc的(de)循環實現。row_m代表(biao)文章里的(de)mij

// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;

if (sum > row_m)
row_m = sum;
}

4.  按(an)行累加(jia)求和,計(ji)算lij

// P = exp(S - row_m), row_l = rowsum(P)
// 計(ji)算(suan)l_ij
float row_l = 0;
for (int y = 0; y < Bc; y++) {
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}

 

5.  更新m_new和l_new。此時寄存器內(nei)有從HBM中加載(zai)的(de)上一時刻(ke)的(de)m_prev和l_prev。將當前線(xian)程計(ji)算的(de)m和l代入公(gong)式。

// Compute new m and l
// 更新m_i
float row_m_new = max(row_m_prev, row_m);
// 更新l_i
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

 

6.  將O,l,m寫入HBM。

for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
// 獲取pv相乘的(de)結果,這里p
for (int y = 0; y < Bc; y++) {
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
+ (__expf(row_m - row_m_new) * pv));
}
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;

 

7.  在每(mei)個內循環(huan)結(jie)束后,即Q遍歷結(jie)束后,需要進(jin)行一個同步。因為接下來(lai)外循環(huan)會更新Kj和(he)Vj,不同步會使得Qi使用錯誤的Kj和(he)Vj值。

__syncthreads();  // otherwise, thread can use the wrong Kj, Vj in inner loop

 

0條評論
0 / 1000
王****鵬
4文章數
0粉絲數
王****鵬
4 文章(zhang) | 0 粉(fen)絲
王****鵬
4文章數
0粉絲(si)數
王****鵬
4 文章 | 0 粉(fen)絲
原創

flash attention代碼解讀

2024-06-26 09:44:34
84
0
代碼地址:

github.com/tspeterkim/flash-attention-minimal

該代碼符(fu)號定義(yi)基(ji)本(ben)與論文一(yi)致。但是有Br = Bc的(de)隱含(han)假設,不適合(he)實際復雜(za)的(de)情況。

 

一、  輸(shu)(shu)入輸(shu)(shu)出定義

1.輸入(ru)

B: batch size

nh: number of heads

N: sequence length

d: head embedding dims

Q 尺寸(cun) B x nh x N x d

K 尺寸 B x nh x N x d

V 尺寸 B x nh x N x d

Br: 每個切分的Q的大小

Bc: 每個切分的K,V的大小

Tr = ceil(N / Br)

Tc = ceil(N/ Bc)

2.  中間變量

l尺寸B x nh x N  ,初始化為0

m 尺寸(cun)B x nh x N ,初始(shi)化為負無窮

l和(he)m開辟(pi)在Global memory上,使用時加載到(dao)寄(ji)存器上

3.  輸出(chu)

O 尺寸 B x nh x N x d

O開辟在(zai)Global memory上(shang),使用(yong)時(shi)加載到(dao)shared memory上(shang)

二、  kernel函數(shu)調用(yong)

grid_dim尺寸 B x nh,即第(di)一(yi)個維(wei)度對應batch, 第(di)二個維(wei)度對應head數目。則每個block處(chu)理的(de)是一(yi)個head的(de)運算。QKV的(de)數據(ju)尺寸都為N x d

 

dim3 grid_dim(B, nh); // batch_size x num_heads

 

block中(zhong)的thread數目是Bc(實際應該是Br).

 

dim3 block_dim(Bc); // Bc threads per block

 

三、  核函數(shu)實現

這里代碼假設了Br = Bc。實際(ji)可能不是,有問題(ti)

1.  共享(xiang)內存開(kai)辟

共享內存大小(xiao)為Bc x d x 3 + Br x Bc,分別(bie)存儲Qi, Kj, Vj和中間變(bian)量Sij

2.  核函數結構

● 每個block處理一個head。

● 每個block單次(ci)計算處理的是一個分塊Qi, Kj, Vj的計算。故block內部(bu)有Tr*Tc次(ci)循(xun)(xun)(xun)環(huan)。循(xun)(xun)(xun)環(huan)結構與(yu)論文定義相同(tong)。即外循(xun)(xun)(xun)環(huan)為(wei)Kj,Vj的循(xun)(xun)(xun)環(huan),內循(xun)(xun)(xun)環(huan)為(wei)Qi的循(xun)(xun)(xun)環(huan)。

● 一個(ge)thread單次處理的(de)(de)(de)是Qi中一行(xing)(xing)的(de)(de)(de)數據。由(you)于Qi中一行(xing)(xing)會和Kj中的(de)(de)(de)每(mei)一行(xing)(xing)計(ji)算(suan)。故在計(ji)算(suan)時,通(tong)過一個(ge)長度為(wei)Bc的(de)(de)(de)循(xun)環實現對(dui)(dui)Kj每(mei)行(xing)(xing)的(de)(de)(de)遍歷。對(dui)(dui)Vj的(de)(de)(de)遍歷同理。

● 由于一個thread對應的是Qi的一行,所以該行m和l的狀(zhuang)態更新只需要寄(ji)存器保存。完成Q的遍歷(li)后(hou)再寫入HBM即可。

3.  核函數(shu)流程

  1. 定義外層(ceng)循 環,從HBM將Kj和(he)Vj加載到(dao)shared memory
     
    // 定義(yi)外層循環,從HBM將Kj和Vj加載到shared memory
    for (int j = 0; j < Tc; j++) {

    // Load Kj, Vj to SRAM
    // 將Kj和Vj加載到shared memory
    for (int x = 0; x < d; x++) {
    Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
    Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
    }
    __syncthreads();

 

2.  定義內層循環,從HBM將Qi加載到shared_memory。從HBM將mi和(he)li加載到寄存器(qi)。

// 定義內層循環,從HBM將Qi加(jia)載(zai)到shared_memory。從HBM將mi和li加(jia)載(zai)到寄存器。
for (int i = 0; i < Tr; i++) {

// Load Qi to SRAM, l and m to registers
// 將(jiang)Qi加載到(dao)shared memory,l和m加載到(dao)寄存器(qi)
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];

// S = QK^T, row_m = rowmax(S)

3 . 執行(xing)QK^T計算。由于一(yi)個(ge)線程代表(biao)Qi里面(mian)的(de)一(yi)行(xing),對Kj的(de)遍歷通(tong)過一(yi)個(ge)長度為Bc的(de)循環(huan)實(shi)現。row_m代表(biao)文(wen)章里的(de)mij

// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;

if (sum > row_m)
row_m = sum;
}

4.  按行(xing)累加求和,計算lij

// P = exp(S - row_m), row_l = rowsum(P)
// 計(ji)算l_ij
float row_l = 0;
for (int y = 0; y < Bc; y++) {
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}

 

5.  更(geng)新m_new和l_new。此時寄存器內有從HBM中加(jia)載的上(shang)一時刻的m_prev和l_prev。將當前線(xian)程(cheng)計(ji)算(suan)的m和l代入公式。

// Compute new m and l
// 更(geng)新m_i
float row_m_new = max(row_m_prev, row_m);
// 更新(xin)l_i
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

 

6.  將O,l,m寫入HBM。

for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
// 獲取(qu)pv相乘的(de)結果,這里p
for (int y = 0; y < Bc; y++) {
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
+ (__expf(row_m - row_m_new) * pv));
}
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;

 

7.  在(zai)每個(ge)內循(xun)環結束(shu)后(hou),即(ji)Q遍(bian)歷結束(shu)后(hou),需要(yao)進行一個(ge)同步。因為接下(xia)來外循(xun)環會更新(xin)Kj和Vj,不同步會使(shi)得Qi使(shi)用(yong)錯誤的Kj和Vj值。

__syncthreads();  // otherwise, thread can use the wrong Kj, Vj in inner loop

 

文章來自個人專欄
文章 | 訂閱
0條評論
0 / 1000
請輸入你的評論
0
0