揭秘大模型的魔法:實(shí)現(xiàn)帶可訓(xùn)練權(quán)重的自注意力機(jī)制

大家好,我是寫代碼的中年人。
上一篇我們實(shí)現(xiàn)了一個(gè)“無可訓(xùn)練參數(shù)”的注意力機(jī)制,讓每個(gè)詞都能“看看別人”,計(jì)算出自己的上下文理解。
雖然實(shí)現(xiàn)起來不難,但它只是個(gè)“玩具級(jí)”的注意力,離真正的大模型還差了幾個(gè)“億”個(gè)參數(shù)。今天,我們來實(shí)現(xiàn)一個(gè)可訓(xùn)練版本的自注意力機(jī)制,這可是 Transformer 的核心!
01、什么叫“可訓(xùn)練”的注意力?
在大模型里,注意力機(jī)制不是寫死的,而是學(xué)出來的。
為了讓每個(gè)詞都能“智能提問、精準(zhǔn)關(guān)注”,我們需要三個(gè)可訓(xùn)練的權(quán)重矩陣:

每個(gè)詞自己造問題,然后去問別的詞,看看誰最“對(duì)味”,然后決定聽誰的意見。
為什么自注意力機(jī)制(Self-Attention)中需要三個(gè)可訓(xùn)練的權(quán)重矩陣,也就是常說的:
Wq:Query 權(quán)重矩陣
Wk:Key 權(quán)重矩陣
Wv:Value 權(quán)重矩陣
這個(gè)設(shè)計(jì)最早出現(xiàn)在 2017 年 Google 的論文《Attention is All You Need》中,也就是Transformer架構(gòu)的原始論文。這三個(gè)矩陣的引入不是隨便“拍腦袋”的設(shè)計(jì),而是有明確動(dòng)機(jī)的:

# ONE
這段論文奠定了 Transformer 的注意力計(jì)算基礎(chǔ)。Transformer 后續(xù)所有的 Multi-Head Attention、Encoder-Decoder Attention,都是基于這個(gè) Scaled Dot-Product Attention 構(gòu)建的。
02、我是誰?我在哪?我要關(guān)注誰?
其實(shí)自注意力就是一種帶可訓(xùn)練權(quán)重的加權(quán)平均機(jī)制,它做了三件事:
把每個(gè)詞向量分別變成三個(gè)形態(tài):Query(查詢)、Key(鍵)、Value(值);
計(jì)算 Query 和所有 Key 的相似度(注意力權(quán)重);
用這個(gè)權(quán)重加權(quán) Value 向量,得出最終的輸出向量。每個(gè)詞都在用“自己的 Query”去看“別人的 Key”,然后決定“我到底該關(guān)注誰”。
如果我們想理解這些內(nèi)容,最好以代碼的形式來逐步理解:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# ------- 定義可訓(xùn)練的自注意力模塊 -------
class SelfAttention(nn.Module):
def __init__(self, embed_dim, dropout=0.1):
super().__init__()
self.embed_dim = embed_dim
self.dropout = dropout
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.last_attn_weights = None
def forward(self, x):
B, T, C = x.size()
Q = self.q_proj(x)
K = self.k_proj(x)
V = self.v_proj(x)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_dim ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
self.last_attn_weights = attn_weights.detach()
out = torch.matmul(attn_weights, V)
out = self.out_proj(out)
return out
# ------- Create a Simulated Dataset -------
# Simulate a small vocabulary and word embeddings
vocab = {"寫": 0, "代碼": 1, "的": 2, "中年人": 3, "天天": 4, "<PAD>": 5}
embed_dim = 16
vocab_size = len(vocab)
embedding = nn.Embedding(vocab_size, embed_dim) # Randomly initialized word embeddings
# Sentence data
sentences = [
["寫", "代碼", "的", "中年人"],
["天天", "寫", "代碼", "<PAD>"] # Pad the second sentence to match length
]
batch_size = len(sentences)
seq_len = len(sentences[0]) # Sentences have the same length (4)
# Convert sentences to indices
input_ids = torch.tensor([[vocab[word] for word in sent] for sent in sentences]) # (batch_size, seq_len)
# ------- Parameter Settings -------
epochs = 200
dropout = 0.1
model = SelfAttention(embed_dim, dropout)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# ------- Train the Model -------
for epoch in range(epochs):
model.train()
# Compute input inside the loop to create a fresh computation graph
x = embedding(input_ids) # (batch_size, seq_len, embed_dim)
target = x.clone() # Target is the same as input for this task
out = model(x)
loss = criterion(out, target)
optimizer.zero_grad()
loss.backward() # Compute gradients
optimizer.step() # Update model parameters
if (epoch + 1) % 20 == 0:
print(f"Epoch {epoch+1:3d}, Loss: {loss.item():.6f}")
# ------- Visualize Attention Weights -------
# Visualize attention matrix for the first sentence
attention = model.last_attn_weights[0].numpy() # (seq_len, seq_len)
sentence = sentences[0] # ["寫", "代碼", "的", "中年人"]
plt.figure(figsize=(8, 6))
plt.imshow(attention, cmap='viridis')
plt.title(f"Attention Matrix for Sentence: {' '.join(sentence)}")
plt.xticks(ticks=np.arange(seq_len), labels=sentence)
plt.yticks(ticks=np.arange(seq_len), labels=sentence)
plt.xlabel("Key (Word)")
plt.ylabel("Query (Word)")
plt.colorbar(label="Attention Strength")
for i in range(seq_len):
for j in range(seq_len):
plt.text(j, i, f"{attention[i,j]:.2f}", ha="center", va="center", color="white")
plt.tight_layout()
plt.savefig("attention_matrix_sentence1.png")
plt.show()
# Visualize attention matrix for the second sentence
attention = model.last_attn_weights[1].numpy()
sentence = sentences[1] # ["天天", "寫", "代碼", "<PAD>"]
plt.figure(figsize=(8, 6))
plt.imshow(attention, cmap='viridis')
plt.title(f"Attention Matrix for Sentence: {' '.join(sentence)}")
plt.xticks(ticks=np.arange(seq_len), labels=sentence)
plt.yticks(ticks=np.arange(seq_len), labels=sentence)
plt.xlabel("Key (Word)")
plt.ylabel("Query (Word)")
plt.colorbar(label="Attention Strength")
for i in range(seq_len):
for j in range(seq_len):
plt.text(j, i, f"{attention[i,j]:.2f}", ha="center", va="center", color="white")
plt.tight_layout()
plt.savefig("attention_matrix_sentence2.png")
plt.show()上面的代碼執(zhí)行后輸出:


代碼詳解:
這段代碼實(shí)現(xiàn)了一個(gè)簡(jiǎn)單的自注意力(Self-Attention)模型,并通過一個(gè)模擬的中文數(shù)據(jù)集進(jìn)行訓(xùn)練,展示自注意力機(jī)制如何捕捉句子中詞語之間的關(guān)系。以下是代碼的詳細(xì)解釋,以及對(duì)自注意力機(jī)制的深入分析。
這段代碼的核心目標(biāo)是:實(shí)現(xiàn)自注意力模塊:通過定義一個(gè)SelfAttention類,實(shí)現(xiàn)自注意力機(jī)制,模擬Transformer模型中的核心組件。訓(xùn)練模型:使用一個(gè)簡(jiǎn)單的中文詞匯數(shù)據(jù)集,訓(xùn)練自注意力模型,使其學(xué)習(xí)詞語之間的注意力分布。可視化注意力權(quán)重:通過繪制注意力矩陣,直觀展示模型如何關(guān)注句子中不同詞語之間的關(guān)系。
代碼主要分為以下幾個(gè)部分:數(shù)據(jù)集構(gòu)建:構(gòu)造一個(gè)小型中文詞匯表和兩個(gè)短句,模擬自然語言處理任務(wù)。模型定義:實(shí)現(xiàn)自注意力模塊,包含查詢(Query)、鍵(Key)、值(Value)的線性變換和注意力計(jì)算。訓(xùn)練過程:通過優(yōu)化模型,使其輸出盡可能接近輸入(一種簡(jiǎn)單的自監(jiān)督學(xué)習(xí)任務(wù))。可視化:繪制注意力矩陣,展示模型對(duì)不同詞語的關(guān)注程度。
03、自注意力機(jī)制詳解
自注意力機(jī)制的核心思想
自注意力是Transformer模型的核心組件,用于捕捉序列中元素(詞、字符等)之間的關(guān)系。
其核心思想是:
每個(gè)輸入元素(如詞)同時(shí)扮演查詢(Query)、鍵(Key)和值(Value)三個(gè)角色。通過計(jì)算查詢與鍵的相似度,生成注意力權(quán)重,決定每個(gè)元素對(duì)其他元素的關(guān)注程度。使用注意力權(quán)重對(duì)值進(jìn)行加權(quán)求和,生成上下文感知的表示。
數(shù)學(xué)公式:

# ONE
訓(xùn)練權(quán)重的作用:
在訓(xùn)練過程中,自注意力機(jī)制的權(quán)重(W_q, W_k, W_v, W_out)通過優(yōu)化器更新,目標(biāo)是使模型輸出盡可能接近輸入(MSE損失)。
具體作用:
學(xué)習(xí)語義關(guān)系:通過調(diào)整W_q和W_k,模型學(xué)習(xí)詞之間的語義關(guān)聯(lián)。例如,“寫”和“代碼”可能有較高的注意力權(quán)重,因?yàn)樗鼈冊(cè)谡Z義上相關(guān)。
增強(qiáng)表示:通過W_v和W_out,模型生成更豐富的上下文表示,捕捉句子中詞語的相互影響。
動(dòng)態(tài)關(guān)注:注意力權(quán)重是動(dòng)態(tài)計(jì)算的,允許模型根據(jù)輸入內(nèi)容靈活調(diào)整關(guān)注重點(diǎn)。
通過深入剖析自注意力機(jī)制及其可訓(xùn)練權(quán)重的核心作用,我們揭開了大模型處理復(fù)雜任務(wù)時(shí)那份“魔力”的關(guān)鍵一角。自注意力以其獨(dú)特的方式,讓模型能靈活聚焦于輸入序列中的重要信息,大幅提升了上下文理解的能力。但這只是開端。在下一章,我們將進(jìn)一步探討多頭注意力機(jī)制,看它如何通過并行處理多組注意力,為模型帶來更強(qiáng)的表達(dá)力和靈活性。


































