筆者在學習 BERT 架構技術時,看到書中提到了 BERT 沒有采用原始 Transformer 中的正弦-余弦位置編碼,但是沒講原因。
于是筆者到網上查了一番資料進行了學習。
在機器學習和深度學習的領域中,BERT 是一種強大的預訓練語言模型。它采用了許多優化策略,其中一個關鍵設計差異是其位置編碼方法。與原始 Transformer 中的正弦-余弦位置編碼方法不同,BERT 使用了基于可學習參數的嵌入方式來表示位置。
正弦-余弦位置編碼方法回顧
原始 Transformer 論文中提出的正弦-余弦位置編碼方法是一種固定的數學方法。它通過以下公式生成位置編碼:
import numpy as np
import torch
def get_sinusoidal_positional_encoding(seq_len, d_model):
position = np.arange(seq_len)[:, np.newaxis]
div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
pe = np.zeros((seq_len, d_model))
pe[:, 0::2] = np.sin(position * div_term)
pe[:, 1::2] = np.cos(position * div_term)
return torch.tensor(pe, dtype=torch.float32)
positional_encoding = get_sinusoidal_positional_encoding(seq_len=10, d_model=16)
print(positional_encoding)
上述方法利用正弦和余弦函數的周期性,使模型能夠感知相對位置。
優點
- 固定性:位置編碼是確定的,不會隨訓練過程變化,具有解析性。
- 相對性:編碼的相對位置信息通過周期性自然體現。
局限性
- 表達能力有限:正弦和余弦函數的固定模式對復雜語言結構的捕捉能力不足。
- 靈活性不足:無法根據任務或數據分布自適應優化。
- 難以擴展:在超長序列情況下,固定編碼可能無法很好地適應。
BERT 的位置編碼方法
BERT 選擇了一種基于可學習參數的嵌入方式,用于位置編碼。這種方法將每個位置作為索引輸入一個可訓練的嵌入矩陣,從而得到對應的向量表示。代碼如下:
import torch
import torch.nn as nn
class LearnedPositionalEncoding(nn.Module):
def __init__(self, seq_len, d_model):
super(LearnedPositionalEncoding, self).__init__()
self.position_embeddings = nn.Embedding(seq_len, d_model)
def forward(self, input_tensor):
seq_length = input_tensor.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_tensor.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_tensor[:, :, 0])
return self.position_embeddings(position_ids)
input_tensor = torch.randn(2, 10, 16) # Batch size 2, sequence length 10, hidden size 16
learned_pe = LearnedPositionalEncoding(seq_len=10, d_model=16)
output = learned_pe(input_tensor)
print(output)
通過這種方式,位置編碼可以在訓練過程中與其他模型參數一起更新,以適應具體任務的需求。
優勢分析
靈活性
與正弦-余弦位置編碼相比,可學習的嵌入能夠根據任務數據分布自動調整編碼模式。例如,在涉及句法或語義分析的任務中,不同的語言結構對位置信息的需求可能有顯著差異。固定編碼可能無法有效捕捉這些微妙的關系。
表達能力
從數學角度來看,正弦-余弦方法的維度分解固定,受限于頻率函數的形態。而 BERT 的可學習嵌入可以在高維空間中自由調整,能夠更好地擬合復雜的語言分布。
實驗驗證
研究顯示,BERT 在許多下游任務中的表現優于基于正弦-余弦位置編碼的模型。這表明可學習位置編碼在實際場景中具有更強的適應能力。
舉例說明
假設我們有一段文本 "機器學習改變了世界",其位置編碼的作用在于幫助模型理解 "機器學習" 和 "改變了世界" 的相對關系。
- 如果使用正弦-余弦編碼,這種關系通過函數周期性體現。然而,當文本長度增加時,相對位置關系的周期性可能變得模糊。
- 如果使用可學習位置嵌入,模型可以動態調整每個位置的表示。例如,它可能在訓練過程中學會對謂語動詞和賓語的位置關系賦予更高權重,從而增強理解能力。
以下代碼通過簡單示例模擬這一過程:
from transformers import BertModel, BertTokenizer
text = "機器學習改變了世界"
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
inputs = tokenizer(text, return_tensors="pt")
model = BertModel.from_pretrained("bert-base-chinese")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
在 BERT 中,位置編碼已融入模型的嵌入層中。通過分析輸出隱藏層狀態,可以發現不同位置上的表征逐步捕捉了句法和語義信息。
為什么選擇動態優化
真實案例
在工業應用中,例如機器翻譯,文本長度往往不可控。如果采用固定位置編碼,長文本的效果可能顯著下降。反之,BERT 的可學習位置編碼能夠在訓練過程中優化長短文本的表現,適應不同的上下文需求。
例如,Google 的搜索引擎使用 BERT 優化搜索結果。可學習的位置編碼使模型更好地理解查詢中重要詞匯的位置關系,從而提高相關性排序。
實驗比較
通過比較使用正弦-余弦和可學習位置編碼的 Transformer 模型,可以觀察到以下差異:
| 位置編碼方式 | 訓練靈活性 | 長文本性能 | 短文本性能 |
|---|---|---|---|
| 正弦-余弦編碼 | 低 | 一般 | 良好 |
| 可學習嵌入 | 高 | 優秀 | 優秀 |
實驗表明,可學習嵌入在多種任務中表現更加穩定,尤其在涉及長文本或復雜上下文的情況下優勢顯著。
小結
BERT 不采用正弦-余弦位置編碼的主要原因在于其靈活性和表達能力的局限。通過引入可學習的位置嵌入,BERT 能夠更好地適應不同任務的需求,從而在多種自然語言處理任務中實現更高的性能。這一設計選擇為語言模型的發展奠定了新的基準,也為后續模型優化提供了重要的啟發。