從零實現一個17M參數的GPT預訓練模型

大家好,我是寫代碼的中年人!
今天我們使用開源的的中文數據進行模型的預訓練,下面跟著我的步驟,從零實現你的預訓練模型。
本文所有代碼和數據資源位置:
https://github.com/ColinAIAPP/MoiraiLM
01、預訓練模型的概念
預訓練模型(Pretrained Model)就是一個已經在海量數據上訓練過的模型,它學會了語言的基本規律、結構和語義,然后可以拿來做各種下游任務,比如寫作、翻譯、問答、分類、生成代碼等。
那“預訓練”到底在學什么?以語言模型(LLM)為例:預訓練階段的任務通常是預測下一個詞(token)。
接下來我們就一步一步實現一個17M參數的預訓練模型。
02、數據準備
構建語言模型的第一要義是高質量的數據源。對于中文任務,選擇維基百科開源中文數據集是一個理想起點。這個數據集包含數百萬條中文百科條目,涵蓋歷史、文化、科技等領域,總量約數GB的純文本數據。它開源且免費,可通過維基百科的官方轉儲頁面下載最新版本的XML格式文件。
要解壓處理這個文件我們要使用wikiextractor工具進行數據解壓。
安裝解壓命令:
pip install wikiextractor解壓命令:
python -m wikiextractor.WikiExtractor -b 1G -o extracted_wiki_zh zhwiki-20250920-pages-articles-multistream.xml.bz2 --json
zhwiki-20250920-pages-articles-multistream.xml.bz2:為文件名INFO: Preprocessing 'zhwiki-20250920-pages-articles-multistream.xml.bz2' to collect template definitions: this may take some time.
INFO: Preprocessed 100000 pages
INFO: Preprocessed 200000 pages
INFO: Preprocessed 300000 pages
INFO: Preprocessed 400000 pages
INFO: Preprocessed 500000 pages
INFO: Preprocessed 600000 pages
INFO: Preprocessed 700000 pages
INFO: Preprocessed 800000 pages
INFO: Preprocessed 900000 pages
INFO: Preprocessed 1000000 pages
INFO: Preprocessed 1100000 pages
INFO: Preprocessed 1200000 pages
INFO: Preprocessed 1300000 pages
INFO: Preprocessed 1400000 pages
INFO: Preprocessed 1500000 pages
INFO: Preprocessed 1600000 pages
INFO: Preprocessed 1700000 pages
INFO: Preprocessed 1800000 pages
INFO: Preprocessed 1900000 pages
INFO: Preprocessed 2000000 pages
INFO: Preprocessed 2100000 pages
INFO: Preprocessed 2200000 pages
INFO: Preprocessed 2300000 pages
INFO: Preprocessed 2400000 pages
INFO: Preprocessed 2500000 pages
INFO: Preprocessed 2600000 pages
INFO: Preprocessed 2700000 pages
INFO: Preprocessed 2800000 pages
INFO: Preprocessed 2900000 pages
INFO: Preprocessed 3000000 pages
INFO: Preprocessed 3100000 pages
INFO: Preprocessed 3200000 pages
INFO: Preprocessed 3300000 pages
INFO: Preprocessed 3400000 pages
INFO: Preprocessed 3500000 pages
INFO: Preprocessed 3600000 pages
INFO: Preprocessed 3700000 pages
INFO: Preprocessed 3800000 pages
INFO: Preprocessed 3900000 pages
INFO: Preprocessed 4000000 pages
INFO: Preprocessed 4100000 pages
INFO: Preprocessed 4200000 pages
INFO: Preprocessed 4300000 pages
INFO: Preprocessed 4400000 pages
INFO: Preprocessed 4500000 pages
INFO: Preprocessed 4600000 pages
INFO: Preprocessed 4700000 pages
INFO: Loaded 1036734 templates in 704.2s
INFO: Starting page extraction from zhwiki-20250920-pages-articles-multistream.xml.bz2.
INFO: Using 127 extract processes.
INFO: Extracted 100000 articles (1209.6 art/s)
INFO: Extracted 200000 articles (1947.8 art/s)
INFO: Extracted 300000 articles (2325.1 art/s)
INFO: Extracted 400000 articles (3471.3 art/s)
INFO: Extracted 500000 articles (2551.1 art/s)
INFO: Extracted 600000 articles (2239.4 art/s)
INFO: Extracted 700000 articles (2299.3 art/s)
INFO: Extracted 800000 articles (1525.2 art/s)
INFO: Extracted 900000 articles (3256.1 art/s)
INFO: Extracted 1000000 articles (3485.9 art/s)
INFO: Extracted 1100000 articles (3495.0 art/s)
INFO: Extracted 1200000 articles (3330.4 art/s)
INFO: Extracted 1300000 articles (3555.6 art/s)
INFO: Extracted 1400000 articles (3456.3 art/s)
INFO: Extracted 1500000 articles (2476.1 art/s)
INFO: Extracted 1600000 articles (2268.6 art/s)
INFO: Extracted 1700000 articles (2473.5 art/s)
INFO: Extracted 1800000 articles (2305.9 art/s)
INFO: Extracted 1900000 articles (2263.9 art/s)
INFO: Extracted 2000000 articles (2136.4 art/s)
INFO: Extracted 2100000 articles (2363.0 art/s)
INFO: Extracted 2200000 articles (2601.9 art/s)
INFO: Extracted 2300000 articles (3709.0 art/s)
INFO: Extracted 2400000 articles (2723.9 art/s)
INFO: Extracted 2500000 articles (2487.1 art/s)
INFO: Extracted 2600000 articles (2621.3 art/s)
INFO: Extracted 2700000 articles (2525.4 art/s)
INFO: Extracted 2800000 articles (2666.4 art/s)
INFO: Finished 127-process extraction of 2893023 articles in 1156.5s (2501.5 art/s)03、清洗數據
我們解壓后的數據如下圖,下面我們要把數據清洗出來。
注:
我們本步驟生成的文件為 data/cleaned_wiki_full.txt
import os
import json
import logging
import argparse
import re
from tqdm import tqdm
# 配置日志記錄
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s')
# python scripts/clean_wiki_text.py data/extracted_wiki_zh data/cleaned_wiki_full.txt --min_line_length 20 --min_article_length 300
def clean_text(text: str) -> str:
"""
對文本進行深度清洗。
移除維基百科特有的格式標記、參考文獻、HTML標簽、日期和數字等。
"""
# 移除維基鏈接 [[link|display]] 或 [[link]]
text = re.sub(r'\[\[([^\]|]+\|)?([^\]]+)\]\]', r'\2', text)
# 移除參考文獻標記 [1], [2], [ref], 等
text = re.sub(r'\[\d+\]|\[ref\]|\[/ref\]|\[citation needed\]', '', text)
# 移除HTML標簽
text = re.sub(r'<[^>]+>', '', text)
# 移除日期格式 (yyyy-mm-dd, yyyy/mm/dd, mm/dd/yyyy 等)
text = re.sub(r'\d{1,4}[-/]\d{1,2}[-/]\d{1,4}', '', text)
# 移除年份 (1000-2999)
text = re.sub(r'\b[12]\d{3}\b', '', text)
# 移除純數字(包括小數)
text = re.sub(r'\b\d+\.?\d*\b', '', text)
# 移除重復的空白字符(但保留單個空格)
text = re.sub(r' +', ' ', text)
# 移除行首尾空白
text = text.strip()
return text
def process_extracted_wiki(extracted_dir: str,
output_file: str,
min_line_length: int = 20,
min_article_length: int = 200):
"""
讀取WikiExtractor輸出的JSON文件,提取、清洗文本并保存到單個文件中。
參數:
extracted_dir: WikiExtractor輸出的目錄路徑
output_file: 最終合并的純文本文件路徑
min_line_length: 單行文本最小長度,用于過濾噪音(默認: 20)
min_article_length: 文章最小長度,用于過濾短文章(默認: 200)
"""
if not os.path.isdir(extracted_dir):
logging.error(f"輸入的目錄不存在: {extracted_dir}")
return
total_articles = 0
skipped_articles = 0
# 第一次遍歷:獲取所有需要處理的文件列表
file_list = []
for root, dirs, files in os.walk(extracted_dir):
for file_name in files:
# 僅處理 WikiExtractor 生成的以 'wiki_' 開頭的文件
if file_name.startswith('wiki_'):
file_list.append(os.path.join(root, file_name))
total_files = len(file_list)
logging.info(f"找到 {total_files} 個文件等待處理。")
if total_files == 0:
logging.warning(f"目錄 {extracted_dir} 中未找到任何 'wiki_' 文件。請檢查路徑。")
return
# 第二次遍歷:處理文件并寫入輸出
with open(output_file, 'w', encoding='utf-8') as f_out:
# 使用 tqdm 包裝文件列表,顯示處理進度
for file_path in tqdm(file_list, desc="?? 正在提取維基文本"):
try:
with open(file_path, 'r', encoding='utf-8') as f_in:
for line_num, line in enumerate(f_in, 1):
try:
article = json.loads(line)
text_content = article.get('text', '').strip()
# --- 文本清洗和過濾 ---
# 1. 過濾掉過短的文章,它們通常是噪音或重定向頁
if len(text_content) < min_article_length:
skipped_articles += 1
continue
# 2. 按行處理文本,過濾短行和額外的空白
# 保留行結構,而不是將所有行連接成一個長句子
cleaned_lines = []
for text_line in text_content.split('\n'):
text_line = clean_text(text_line)
# 只保留足夠長的行
if len(text_line) >= min_line_length:
cleaned_lines.append(text_line)
# 使用換行符連接各行,保留段落結構
final_text = '\n'.join(cleaned_lines)
# 最終檢查:確保清洗后的文本仍然足夠長
if final_text and len(final_text) >= min_article_length:
# 文章之間用兩個換行符分隔
f_out.write(final_text + '\n\n')
total_articles += 1
else:
skipped_articles += 1
except json.JSONDecodeError:
logging.warning(f"無法解析 JSON,文件: {file_path},行號: {line_num}")
except Exception as e:
logging.error(f"處理文件 {file_path} 第 {line_num} 行時出錯: {e}")
except Exception as e:
logging.error(f"打開文件 {file_path} 時出錯: {e}")
logging.info(f" 所有維基百科文本已成功提取并清洗。")
logging.info(f" 總文章數: {total_articles}")
logging.info(f" 跳過文章數: {skipped_articles}")
logging.info(f" 文件已保存到: {output_file}")
def main():
parser = argparse.ArgumentParser(
descriptinotallow="從 WikiExtractor 輸出的 JSON 文件中提取并清洗純文本。",
formatter_class=argparse.RawTextHelpFormatter
)
# 位置參數 1: 輸入目錄
parser.add_argument(
"extracted_directory",
type=str,
help="WikiExtractor 輸出的目錄路徑 (e.g., extracted_wiki_zh)"
)
# 位置參數 2: 輸出文件
parser.add_argument(
"output_filename",
type=str,
help="最終合并的純文本文件路徑 (e.g., cleaned_wiki.txt)"
)
# 可選參數: 最小行長
parser.add_argument(
"--min_line_length",
type=int,
default=20,
help="文章中單行文本必須達到的最小長度,用于過濾噪音。默認值: 20"
)
# 可選參數: 最小文章長度
parser.add_argument(
"--min_article_length",
type=int,
default=200,
help="文章最小長度,用于過濾短文章和重定向頁。默認值: 200"
)
args = parser.parse_args()
process_extracted_wiki(
args.extracted_directory,
args.output_filename,
args.min_line_length,
args.min_article_length
)
if __name__ == "__main__":
main()2025-10-01 11:10:58,772 - INFO - 找到 5 個文件等待處理。
正在提取維基文本: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:33<00:00, 6.78s/it]
2025-10-01 11:11:32,681 - INFO - 所有維基百科文本已成功提取。總文章數: 628093。文件已保存到 data/cleaned_wiki_full.txt04、訓練分詞器
我們使用SentencePiece訓練分詞器,本次我們訓練的分詞庫大小為16k,你也可以訓練32k的分詞庫。相關代碼及過程如下:
注:
我們本步驟生成的文件為
workdir/spm_wiki_16k.model
workdir/spm_wiki_16k.vocab
import sys
import sentencepiece as spm
import argparse
import os
from tqdm import tqdm
# python scripts/train_tokenizer.py data/cleaned_wiki_full.txt workdir/spm_wiki 32000
def get_corpus_size(input_file: str) -> int:
"""計算語料的總行數和文件大小"""
try:
file_size_bytes = os.path.getsize(input_file)
file_size_mb = file_size_bytes / (1024 * 1024)
print(f"語料文件大小: {file_size_mb:.2f} MB")
# 計算行數和總字符數
line_count = 0
total_chars = 0
with open(input_file, 'r', encoding='utf-8') as f:
for line in tqdm(f, desc="統計語料信息"):
line_count += 1
total_chars += len(line)
print(f"語料總行數 (文章數): {line_count}")
print(f"總字符數: {total_chars:,}")
print(f"平均每行字符數: {total_chars / line_count:.1f}")
return file_size_bytes
except Exception as e:
print(f"警告:無法計算文件大小或行數:{e}")
return 0
def train_spm_model(input_file: str,
model_prefix: str,
vocab_size: int,
model_type: str = 'bpe',
character_coverage: float = 0.9995):
"""
訓練一個SentencePiece分詞器模型。
參數:
input_file: 訓練語料文件路徑
model_prefix: 輸出模型文件的前綴
vocab_size: 詞匯表大小
model_type: 分詞算法類型 ('bpe', 'unigram', 'char', 'word')
character_coverage: 字符覆蓋率 (0-1,通常 0.995-0.9995)
"""
if not os.path.exists(input_file):
print(f"錯誤:輸入語料文件未找到:{input_file}")
sys.exit(1)
# 確保輸出目錄存在
output_dir = os.path.dirname(model_prefix)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
print(f"已創建輸出目錄: {output_dir}")
# 打印語料規模信息
print("\n=== 語料分析 ===")
get_corpus_size(input_file)
# 構建訓練參數
# 對于 1.5GB 語料,建議啟用 train_extremely_large_corpus=True 加速
train_params = {
'input': input_file,
'model_prefix': model_prefix,
'vocab_size': vocab_size,
'model_type': model_type,
'character_coverage': character_coverage,
'num_threads': 32, # 增加到32(最大化CPU利用)
'bos_id': 0,
'eos_id': 1,
'unk_id': 2,
'pad_id': -1,
'normalization_rule_name': 'identity',
'input_sentence_size': 2000000, # 5000000, # 增加到500萬句子采樣
'train_extremely_large_corpus': True, # 必須啟用
'shuffle_input_sentence': True,
'seed_sentencepiece_size': 2000000, # 添加種子句子大小
'hard_vocab_limit': False, # 允許超過目標詞匯量以獲得更好質量
}
print("\n=== SentencePiece 訓練參數 ===")
for key, value in train_params.items():
print(f" {key}: {value}")
print("=" * 35)
print("\n正在訓練 SentencePiece 模型...")
print(" (請稍候,進度由 SentencePiece 輸出)\n")
try:
# 執行訓練
spm.SentencePieceTrainer.train(**train_params)
print("\n分詞器模型訓練完成!")
print(f" 模型文件: {model_prefix}.model")
print(f" 詞匯表文件: {model_prefix}.vocab")
# 驗證模型是否成功創建
if os.path.exists(f"{model_prefix}.model") and os.path.exists(f"{model_prefix}.vocab"):
model_size_kb = os.path.getsize(f"{model_prefix}.model") / 1024
print(f"\n模型文件大小: {model_size_kb:.2f} KB")
# 加載模型進行快速測試
print("\n進行快速測試...")
sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model")
test_text = "這是一個分詞測試句子。"
tokens = sp.encode(test_text, out_type=str)
ids = sp.encode(test_text, out_type=int)
print(f" 測試文本: {test_text}")
print(f" 分詞結果: {tokens}")
print(f" Token IDs: {ids}")
else:
print("\n警告:模型文件生成失敗,請檢查輸入數據或參數")
except Exception as e:
print(f"\n訓練過程出錯: {e}")
sys.exit(1)
def main():
parser = argparse.ArgumentParser(
descriptinotallow="使用 SentencePiece 訓練分詞器模型。",
formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"input_file",
type=str,
help="訓練語料的路徑 (e.g., data/cleaned_wiki_full.txt)"
)
parser.add_argument(
"model_prefix",
type=str,
help="訓練模型文件的輸出前綴 (e.g., workdir/spm_wiki)"
)
parser.add_argument(
"vocab_size",
type=int,
help="詞匯表大小 (e.g., 32000)"
)
parser.add_argument(
"--model_type",
type=str,
default='bpe',
choices=['bpe', 'unigram', 'char', 'word'],
help="分詞算法類型 (默認: bpe)"
)
parser.add_argument(
"--character_coverage",
type=float,
default=0.9995,
help="字符覆蓋率,范圍 [0-1]。對于小詞表(8K),建議用0.99或更小"
)
args = parser.parse_args()
print("\n" + "="*50)
print("SentencePiece 分詞器訓練程序")
print("="*50)
print(f"輸入語料: {args.input_file}")
print(f"輸出模型前綴: {args.model_prefix}")
print(f"詞匯表大小: {args.vocab_size}")
print(f"分詞算法: {args.model_type}")
print(f"字符覆蓋率: {args.character_coverage}")
print("="*50 + "\n")
train_spm_model(
args.input_file,
args.model_prefix,
args.vocab_size,
args.model_type,
args.character_coverage
)
if __name__ == "__main__":
main()開始訓練SentencePiece分詞器...
輸入語料: data/cleaned_wiki_full.txt
輸出模型前綴: workdir/spm_wiki_16k
詞匯表大小: 16000
語料文件大小: 1697.54 MB
Counting lines: 1256186it [00:05, 230354.42it/s]
語料總行數 (文章數): 1256186
--- SentencePiece 訓練參數 ---
--input=data/cleaned_wiki_full.txt
--model_prefix=workdir/spm_wiki_16k
--vocab_size=16000
--model_type=bpe
--character_coverage=0.9995
--num_threads=16
--bos_id=0
--eos_id=1
--unk_id=2
--pad_id=-1
------------------------------
? 正在啟動訓練... 請注意觀察 SentencePiece 自身的進度輸出。
sentencepiece_trainer.cc(178) LOG(INFO) Running command: --input=data/cleaned_wiki_full.txt --model_prefix=workdir/spm_colinai_16000 --vocab_size=16000 --model_type=bpe --character_coverage=0.9995 --num_threads=16 --bos_id=0 --eos_id=1 --unk_id=2 --pad_id=-1
sentencepiece_trainer.cc(78) LOG(INFO) Starts training with :
trainer_spec {
input: data/cleaned_wiki_full.txt
input_format:
model_prefix: workdir/spm_colinai_16000
model_type: BPE
vocab_size: 16000
self_test_sample_size: 0
character_coverage: 0.9995
input_sentence_size: 0
shuffle_input_sentence: 1
seed_sentencepiece_size: 1000000
shrinking_factor: 0.75
max_sentence_length: 4192
num_threads: 16
num_sub_iterations: 2
max_sentencepiece_length: 16
split_by_unicode_script: 1
split_by_number: 1
split_by_whitespace: 1
split_digits: 0
pretokenization_delimiter:
treat_whitespace_as_suffix: 0
allow_whitespace_only_pieces: 0
required_chars:
byte_fallback: 0
vocabulary_output_piece_score: 1
train_extremely_large_corpus: 0
seed_sentencepieces_file:
hard_vocab_limit: 1
use_all_vocab: 0
unk_id: 2
bos_id: 0
eos_id: 1
pad_id: -1
unk_piece: <unk>
bos_piece: <s>
eos_piece: </s>
pad_piece: <pad>
unk_surface: ?
enable_differential_privacy: 0
differential_privacy_noise_level: 0
differential_privacy_clipping_threshold: 0
}
normalizer_spec {
name: nmt_nfkc
add_dummy_prefix: 1
remove_extra_whitespaces: 1
escape_whitespaces: 1
normalization_rule_tsv:
}
denormalizer_spec {}
trainer_interface.cc(355) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.
trainer_interface.cc(186) LOG(INFO) Loading corpus: data/cleaned_wiki_full.txt
trainer_interface.cc(382) LOG(WARNING) Found too long line (18615 > 4192).
trainer_interface.cc(384) LOG(WARNING) Too long lines are skipped in the training.
trainer_interface.cc(385) LOG(WARNING) The maximum length can be changed with --max_sentence_length=<size> flag.
trainer_interface.cc(411) LOG(INFO) Loaded all 528882 sentences
trainer_interface.cc(418) LOG(INFO) Skipped 99211 too long sentences.
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: </s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <unk>
trainer_interface.cc(432) LOG(INFO) Normalizing sentences...
trainer_interface.cc(541) LOG(INFO) all chars count=281809036
trainer_interface.cc(552) LOG(INFO) Done: 99.95% characters are covered.
trainer_interface.cc(562) LOG(INFO) Alphabet size=8686
trainer_interface.cc(563) LOG(INFO) Final character coverage=0.9995
trainer_interface.cc(594) LOG(INFO) Done! preprocessed 528882 sentences.
trainer_interface.cc(600) LOG(INFO) Tokenizing input sentences with whitespace: 528882
trainer_interface.cc(611) LOG(INFO) Done! 3885388
.....05、原始文本轉為Token ID 序列
在訓練大型語言模型的準備階段,將海量文本語料轉化為模型可處理的數字格式至關重要。本次將原始文本語料編碼為整數 Token ID 序列。為了克服單次加載大文件的內存限制,腳本采用了分塊讀取機制,支持以自定義大小逐塊處理語料。所有 Token ID 最終被匯總并轉化為高效率的 torch.int32 PyTorch 張量,直接存儲為 .pt 文件。這不僅優化了數據格式,方便后續 PyTorch DataLoader 快速讀取,同時也提供了關鍵的統計信息和完整性驗證,是構建 LLM 數據集的穩定且高性能的預處理方案。
import sys
import torch
import sentencepiece as spm
import argparse
from tqdm import tqdm
import os
import numpy as np
# python scripts/preprocess_data.py workdir/spm_wiki.model data/cleaned_wiki_full.txt workdir/wiki_tokens.pt
def preprocess(sp_model_path: str,
corpus_path: str,
output_path: str,
chunk_size_mb: int = 50):
"""
分塊讀取語料,編碼為 Token ID,并保存為 PyTorch 文件。
參數:
sp_model_path: SentencePiece 模型文件路徑
corpus_path: 輸入語料文件路徑
output_path: 輸出 .pt 文件路徑
chunk_size_mb: 每次處理的文本大小(MB),默認 50MB
"""
# 驗證文件存在
if not os.path.exists(sp_model_path):
print(f"錯誤:分詞器模型文件未找到: {sp_model_path}")
sys.exit(1)
if not os.path.exists(corpus_path):
print(f"錯誤:語料文件未找到: {corpus_path}")
sys.exit(1)
# 加載分詞器
try:
sp = spm.SentencePieceProcessor(model_file=sp_model_path)
vocab_size = sp.get_piece_size()
print(f" 分詞器加載成功")
print(f" 詞匯表大小: {vocab_size}")
print(f" 特殊 Token: BOS={sp.bos_id()}, EOS={sp.eos_id()}, UNK={sp.unk_id()}, PAD={sp.pad_id()}")
except Exception as e:
print(f"加載分詞器失敗: {e}")
sys.exit(1)
# 確保輸出目錄存在
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
print(f"\n 開始處理語料...")
print(f" 輸入文件: {corpus_path}")
print(f" 輸出文件: {output_path}")
print(f" 塊大小: {chunk_size_mb} MB\n")
# 計算總大小用于進度條
total_bytes = os.path.getsize(corpus_path)
chunk_size_bytes = chunk_size_mb * 1024 * 1024
token_ids = []
tokens_processed = 0
chunks_processed = 0
try:
with open(corpus_path, 'r', encoding='utf-8') as f:
with tqdm(total=total_bytes, unit='B', unit_scale=True, desc="? 編碼語料") as pbar:
while True:
chunk = f.read(chunk_size_bytes)
if not chunk:
break
# 直接編碼(cleaned_wiki_full.txt 已經過清洗)
ids = sp.encode(chunk, out_type=int)
token_ids.extend(ids)
# 更新進度條
bytes_read = len(chunk.encode('utf-8'))
pbar.update(bytes_read)
tokens_processed += len(ids)
chunks_processed += 1
# 定期顯示進度信息
if chunks_processed % 10 == 0:
pbar.set_postfix({
'chunks': chunks_processed,
'tokens': f'{tokens_processed:,}'
})
print(f"\n 編碼完成")
print(f" 處理塊數: {chunks_processed}")
print(f" 總 Token 數: {tokens_processed:,}")
# 轉換為 PyTorch 張量
print(f"\n轉換為張量并保存...")
final_tensor = torch.tensor(token_ids, dtype=torch.int32)
print(f" 張量形狀: {final_tensor.shape}")
print(f" 張量大小: {final_tensor.numel():,}")
print(f" 數據類型: {final_tensor.dtype}")
print(f" 占用內存: {final_tensor.numel() * 4 / (1024**3):.2f} GB")
# 驗證 Token ID 范圍
min_id = final_tensor.min().item()
max_id = final_tensor.max().item()
print(f" Token ID 范圍: [{min_id}, {max_id}]")
if max_id >= vocab_size or min_id < 0:
print(f" 警告: 檢測到越界 Token ID!")
print(f" 詞匯表大小: {vocab_size}")
print(f" 最大 ID: {max_id}")
# 保存張量
torch.save(final_tensor, output_path)
file_size_mb = os.path.getsize(output_path) / (1024 ** 2)
print(f"\nToken ID 已保存到 {output_path}")
print(f" 文件大小: {file_size_mb:.2f} MB")
# 驗證保存的文件
print(f"\n驗證保存的文件...")
loaded_tensor = torch.load(output_path)
print(f" 加載成功,形狀: {loaded_tensor.shape}")
print(f" 是否相同: {torch.equal(final_tensor, loaded_tensor)}")
print(f"\n? 預處理完成!")
except Exception as e:
print(f"\n處理過程中出錯: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
def main():
parser = argparse.ArgumentParser(
descriptinotallow="將清洗后的文本語料轉換為 Token ID 二進制文件。",
formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"model_path",
type=str,
help="SentencePiece 模型文件路徑 (e.g., workdir/spm_wiki.model)"
)
parser.add_argument(
"corpus_path",
type=str,
help="輸入語料文件路徑 (e.g., data/cleaned_wiki_full.txt)"
)
parser.add_argument(
"output_path",
type=str,
help="輸出 Token ID 文件路徑 (e.g., workdir/wiki_tokens.pt)"
)
parser.add_argument(
"--chunk_size",
type=int,
default=50,
help="每次處理的文本大小(MB),默認 50MB。更大的塊更快,但占用更多內存。"
)
args = parser.parse_args()
print("\n" + "="*60)
print("數據預處理程序 - 文本到 Token ID")
print("="*60)
print(f"SentencePiece 模型: {args.model_path}")
print(f"輸入語料: {args.corpus_path}")
print(f"輸出文件: {args.output_path}")
print(f"塊大小: {args.chunk_size} MB")
print("="*60 + "\n")
preprocess(
args.model_path,
args.corpus_path,
args.output_path,
args.chunk_size
)
if __name__ == "__main__":
main()06、進行模型預訓練
"""
GPT 高性能訓練腳本
"""
from __future__ import annotations
import sys
import os
import math
import json
from datetime import datetime
from typing import Optional
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
from tqdm import tqdm
# ==================== 配置參數 ====================
class Config:
BLOCK_SIZE = 512 #256
BATCH_SIZE = 32 #64
GRAD_ACCUM_STEPS = 4 #1
MODEL_DIM = 384 #256
N_LAYERS = 5 #2
NUM_HEADS = 6 #4
HEAD_DIM = MODEL_DIM // NUM_HEADS
FFN_DIM = MODEL_DIM * 4
VOCAB_SIZE = None
EPOCHS = 1
MAX_STEPS = 10000 # 此處根據自己的硬件和時間定義步數
WARMUP_STEPS = 500
LR = 1e-4
MIN_LR = 1e-5
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
DROPOUT = 0.1
CHECKPOINT_EVERY = 5000
LOG_EVERY = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_DIR = "./checkpoints"
LATEST_CHECKPOINT = "latest_checkpoint.pth"
NUM_WORKERS = 8
SEED = 42
# 啟用 bfloat16 (推薦用于現代 GPU)
DTYPE = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16
CFG = Config()
if CFG.DEVICE == 'cuda':
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()
# 檢查是否使用了 bfloat16
if CFG.DTYPE == torch.bfloat16:
print("使用 bfloat16 混合精度 (推薦)")
else:
print("使用 float16 混合精度")
# ==================== 工具函數 ====================
def print_gpu_memory():
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / (1024**3)
reserved = torch.cuda.memory_reserved() / (1024**3)
print(f"GPU顯存: {allocated:.2f}GB / {reserved:.2f}GB")
def set_seed(seed: int):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
set_seed(CFG.SEED)
# ==================== 數據集 ====================
class TextDataset(Dataset):
def __init__(self, token_ids: torch.Tensor, block_size: int):
self.ids = token_ids.long()
self.block_size = block_size
def __len__(self):
return max(0, self.ids.size(0) - self.block_size)
def __getitem__(self, idx):
x = self.ids[idx: idx + self.block_size]
y = self.ids[idx + 1: idx + 1 + self.block_size]
return x, y
# ==================== RoPE 位置編碼 ====================
class RotaryPositionalEmbedding(nn.Module):
"""RoPE 實現"""
def __init__(self, head_dim: int, max_seq_len: int = 2048):
super().__init__()
self.head_dim = head_dim
assert head_dim % 2 == 0, "head_dim must be even"
# 基頻:theta_i = 10000^(-2i/d)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq)
self.max_seq_len = max_seq_len
self._seq_len_cached = max_seq_len
self._cos_cached = None
self._sin_cached = None
self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)
def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
if seq_len == self._seq_len_cached and self._cos_cached is not None:
return
# m: (seq_len,), theta_i: (head_dim//2,)
m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", m, self.inv_freq) # (seq_len, head_dim//2)
# 構造完整的旋轉矩陣(每個復數對重復)
emb = torch.cat([freqs, freqs], dim=-1) # (seq_len, head_dim)
cos = emb.cos()[None, None, :, :] # (1, 1, seq_len, head_dim)
sin = emb.sin()[None, None, :, :] # (1, 1, seq_len, head_dim)
self._cos_cached = cos
self._sin_cached = sin
self._seq_len_cached = seq_len
def forward(self, seq_len: int, device: Optional[torch.device] = None):
if device is None:
device = self.inv_freq.device
self._update_cos_sin_cache(seq_len, device=device)
return self._cos_cached.to(device), self._sin_cached.to(device)
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""應用RoPE旋轉"""
# x: (B, H, T, D), cos/sin: (1, 1, T, D)
# 使用(x, y) -> (x*cos-y*sin, x*sin+y*cos)
return (x * cos) + (_rotate_half(x) * sin)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""將向量旋轉90度"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# ==================== Flash Attention ====================
class FlashAttention(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
assert embed_dim % num_heads == 0
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.attn_dropout = nn.Dropout(attn_dropout)
self.rope = RotaryPositionalEmbedding(self.head_dim)
def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, T, C = x.shape
assert T <= self.rope.max_seq_len, f"Seq len {T} exceeds max {self.rope.max_seq_len}"
qkv = self.qkv(x)
qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q = q.permute(0, 2, 1, 3) # (B, H, T, D)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
# 應用RoPE
cos, sin = self.rope(T, device=x.device)
q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
# 注意力計算
# 注意:這里如果使用 torch.nn.functional.scaled_dot_product_attention 配合 torch.compile 會更快
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if causal_mask is not None:
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
attn = self.attn_dropout(attn)
out = torch.matmul(attn, v)
out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
return self.out_proj(out)
# ==================== 前饋網絡 ====================
class GLU(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.linear = nn.Linear(in_dim, out_dim * 2)
def forward(self, x):
x, gates = self.linear(x).chunk(2, dim=-1)
return x * torch.nn.functional.silu(gates)
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
super().__init__()
self.net = nn.Sequential(
GLU(dim, hidden_dim),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
# ==================== Transformer Block ====================
class TransformerBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
super().__init__()
self.ln1 = nn.LayerNorm(dim)
self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
self.ln2 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, ffn_dim, dropout)
def forward(self, x, causal_mask=None):
x = x + self.attn(self.ln1(x), causal_mask)
x = x + self.ff(self.ln2(x))
return x
# ==================== GPT 模型(已移除 pos_emb) ====================
class GPTModel(nn.Module):
def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
ffn_dim: int = CFG.FFN_DIM, dropout: float = CFG.DROPOUT,
tie_weights: bool = True):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, dim)
# self.pos_emb = nn.Embedding(block_size, dim) # 移除:與 RoPE 沖突
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
])
self.ln_final = nn.LayerNorm(dim)
self.lm_head = nn.Linear(dim, vocab_size, bias=False)
if tie_weights:
self.lm_head.weight = self.token_emb.weight
self.block_size = block_size
self.apply(self._init_weights)
n_params = sum(p.numel() for p in self.parameters())
print(f"模型參數: {n_params/1e6:.2f}M")
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
def forward(self, idx):
B, T = idx.shape
assert T <= self.block_size, f"Seq len {T} exceeds block_size {self.block_size}"
token_emb = self.token_emb(idx)
x = self.dropout(token_emb) # token embedding
causal_mask = torch.tril(torch.ones(T, T, device=idx.device, dtype=torch.bool))[None, None, :, :]
for block in self.blocks:
x = block(x, causal_mask)
x = self.ln_final(x)
logits = self.lm_head(x)
return logits
# ==================== 檢查點管理 ====================
def save_checkpoint(model, optimizer, scaler, lr_scheduler, step: int, loss: float, config_dict: dict):
os.makedirs(CFG.CHECKPOINT_DIR, exist_ok=True)
checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
state = {
'step': step,
'loss': loss,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'config': config_dict,
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
}
if scaler is not None and hasattr(scaler, "state_dict"):
state['scaler_state_dict'] = scaler.state_dict()
if lr_scheduler is not None:
state['lr_scheduler_state_dict'] = {
'current_step': lr_scheduler.current_step,
'warmup_steps': lr_scheduler.warmup_steps,
'total_steps': lr_scheduler.total_steps,
'base_lr': lr_scheduler.base_lr,
'min_lr': lr_scheduler.min_lr,
}
torch.save(state, checkpoint_path)
try:
with open(os.path.join(CFG.CHECKPOINT_DIR, "config.json"), "w", encoding="utf-8") as f:
json.dump(config_dict, f, indent=2)
except Exception:
pass
print(f" 檢查點已保存: {checkpoint_path} (step {step}, loss {loss:.4f})")
def load_checkpoint(checkpoint_path: str, model, optimizer, scaler, lr_scheduler):
if not os.path.exists(checkpoint_path):
return None
checkpoint = torch.load(checkpoint_path, map_locatinotallow=CFG.DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if checkpoint.get('scaler_state_dict') is not None and scaler is not None:
try:
scaler.load_state_dict(checkpoint['scaler_state_dict'])
except Exception as e:
print(f"無法恢復scaler: {e}")
if checkpoint.get('lr_scheduler_state_dict') is not None and lr_scheduler is not None:
try:
sched_state = checkpoint['lr_scheduler_state_dict']
lr_scheduler.current_step = sched_state['current_step']
lr_scheduler.warmup_steps = sched_state['warmup_steps']
lr_scheduler.total_steps = sched_state['total_steps']
lr_scheduler.base_lr = sched_state['base_lr']
lr_scheduler.min_lr = sched_state['min_lr']
except Exception as e:
print(f"無法恢復lr_scheduler: {e}")
torch.set_rng_state(checkpoint['torch_rng_state'])
if torch.cuda.is_available() and checkpoint.get('cuda_rng_state') is not None:
torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
print(f"檢查點已加載: {checkpoint_path}")
print(f" Step: {checkpoint['step']}, Loss: {checkpoint['loss']:.4f}")
return checkpoint['step']
# ==================== 學習率調度器 ====================
class WarmupCosineScheduler:
def __init__(self, optimizer, warmup_steps: int, total_steps: int, base_lr: float, min_lr: float):
self.optimizer = optimizer
self.warmup_steps = max(0, int(warmup_steps))
self.total_steps = max(1, int(total_steps))
self.base_lr = base_lr
self.min_lr = min_lr
self.current_step = 0
def get_lr(self, step: int = None) -> float:
"""計算給定step的學習率(不修改optimizer)"""
if step is None:
step = self.current_step
if step < self.warmup_steps and self.warmup_steps > 0:
return self.base_lr * (step / float(self.warmup_steps))
else:
denom = max(1, (self.total_steps - self.warmup_steps))
progress = (step - self.warmup_steps) / denom
progress = min(1.0, max(0.0, progress))
return self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
def step(self):
"""執行一次步長更新"""
lr = self.get_lr(self.current_step)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.current_step += 1
return lr
# ==================== 訓練循環 ====================
def train(model: nn.Module, train_loader: DataLoader, epochs: int = CFG.EPOCHS, resume: bool = False):
# 檢測fused優化器支持
fused = False
try:
fused = torch.cuda.is_available() and ("fused" in torch.optim.AdamW.__init__.__code__.co_varnames)
except Exception:
fused = False
optimizer = torch.optim.AdamW(
model.parameters(),
lr=CFG.LR,
betas=(0.9, 0.95),
weight_decay=CFG.WEIGHT_DECAY,
fused=fused
)
# 使用配置中的 DTYPE
scaler = torch.cuda.amp.GradScaler(enabled=(CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16))
loss_fn = nn.CrossEntropyLoss()
total_steps = CFG.MAX_STEPS if CFG.MAX_STEPS else len(train_loader) * epochs
lr_scheduler = WarmupCosineScheduler(optimizer, CFG.WARMUP_STEPS, total_steps, CFG.LR, CFG.MIN_LR)
model.train()
start_step = 0
best_loss = float('inf')
checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
if resume and os.path.exists(checkpoint_path):
loaded_step = load_checkpoint(checkpoint_path, model, optimizer, scaler, lr_scheduler)
if loaded_step is not None:
start_step = loaded_step
global_step = start_step
grad_accum_counter = 0
accumulated_loss = 0.0
print("\n" + "="*60)
print("開始訓練...")
print("="*60)
print_gpu_memory()
print()
# 自動選擇是否需要 scaler.scale()
use_scaler = (CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16)
for epoch in range(epochs):
pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", initial=global_step % len(train_loader) if epoch == 0 else 0)
num_batches = 0
last_lr = None
for batch_idx, (xb, yb) in enumerate(pbar):
# 跳過已訓練的批次 (如果從中間恢復)
if global_step > start_step and batch_idx < (start_step % len(train_loader)):
continue
xb = xb.to(CFG.DEVICE, non_blocking=True)
yb = yb.to(CFG.DEVICE, non_blocking=True)
with torch.cuda.amp.autocast(enabled=(CFG.DEVICE == "cuda"), dtype=CFG.DTYPE):
logits = model(xb)
loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))
loss_item = loss.item()
loss = loss / CFG.GRAD_ACCUM_STEPS
if use_scaler:
scaler.scale(loss).backward()
else:
loss.backward()
grad_accum_counter += 1
accumulated_loss += loss_item
num_batches += 1
# 這里的 global_step 計數是基于數據批次的,而不是優化器步數,用于日志和檢查點
# 真正的優化器步數會在下面更新
# 梯度累積:達到閾值時執行優化步驟
if grad_accum_counter >= CFG.GRAD_ACCUM_STEPS:
# 優化器步進 (這是真正的 global_step 增長點)
lr_scheduler.step() # 先更新 LR
global_step += 1 # 只有進行了一次優化器步進,才算一個 global_step
if use_scaler:
scaler.unscale_(optimizer)
# 梯度裁剪 (在 unscale 后或非 AMP 模式下)
torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
if use_scaler:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad()
grad_accum_counter = 0
last_lr = lr_scheduler.get_lr(global_step) # 獲取當前步的LR
# 日志輸出
if global_step % CFG.LOG_EVERY == 0 or (global_step == 1):
# accumulated_loss 是累積的原始損失, num_batches 是累積的批次數
avg_loss = accumulated_loss / num_batches if num_batches > 0 else 0.0
pbar.set_postfix({
'step': global_step,
'loss': f'{avg_loss:.4f}',
'lr': f'{last_lr:.2e}' if last_lr is not None else 'N/A'
})
# 重置累積值以便計算下一個 LOG_EVERY 間隔的平均損失
accumulated_loss = 0.0
num_batches = 0
# 保存檢查點
if global_step > start_step and global_step % CFG.CHECKPOINT_EVERY == 0:
# 使用上一個日志點計算的 avg_loss
current_avg_loss = accumulated_loss / num_batches if num_batches > 0 else loss_item
config_dict = {
'vocab_size': CFG.VOCAB_SIZE,
'block_size': CFG.BLOCK_SIZE,
'model_dim': CFG.MODEL_DIM,
'n_layers': CFG.N_LAYERS,
'num_heads': CFG.NUM_HEADS,
'created_at': datetime.now().isoformat()
}
save_checkpoint(model, optimizer, scaler, lr_scheduler, global_step, current_avg_loss, config_dict)
torch.cuda.empty_cache()
if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
break
# 處理 epoch 結束時剩余的梯度 (如果 grad_accum_counter > 0)
if grad_accum_counter > 0:
if use_scaler:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
if use_scaler:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
global_step += 1
grad_accum_counter = 0
# 此時 pbar.total_loss 已累積
if num_batches > 0:
final_avg_loss = accumulated_loss / num_batches
else:
final_avg_loss = float('inf')
if final_avg_loss < best_loss:
best_loss = final_avg_loss
best_path = os.path.join(CFG.CHECKPOINT_DIR, "best_model.pth")
torch.save(model.state_dict(), best_path)
print(f"最佳模型已保存 (loss: {best_loss:.4f})")
print(f"\n[Epoch {epoch+1}] Avg Loss: {final_avg_loss:.4f}")
if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
break
print("\n訓練完成!")
# ==================== 主函數 ====================
def main():
if len(sys.argv) < 4:
print("用法: python train_20251012_v1.py workdir/spm_wiki_16k.model workdir/wiki_tokens_16k.pt models/gpt_wiki.pth [--resume]")
sys.exit(1)
sp_model_path, token_file_path, out_path = sys.argv[1:4]
resume = "--resume" in sys.argv
if not os.path.exists(token_file_path):
print(f" Token文件不存在: {token_file_path}")
sys.exit(1)
# 檢查 CFG.DTYPE 是否為 bfloat16 但環境不支持
if CFG.DTYPE == torch.bfloat16 and not torch.cuda.is_bf16_supported():
print("警告: bfloat16 不受當前 CUDA 設備支持,自動回退到 float16。")
CFG.DTYPE = torch.float16
sp = spm.SentencePieceProcessor(model_file=sp_model_path)
CFG.VOCAB_SIZE = sp.get_piece_size()
print("="*60)
print("GPT 語言模型訓練")
print("="*60)
print(f"分詞器: {sp_model_path}")
print(f"Token文件: {token_file_path}")
print(f"輸出模型: {out_path}")
print(f"設備: {CFG.DEVICE}")
print(f"\n模型配置:")
print(f" - VOCAB_SIZE: {CFG.VOCAB_SIZE}")
print(f" - BLOCK_SIZE: {CFG.BLOCK_SIZE}")
print(f" - MODEL_DIM: {CFG.MODEL_DIM}")
print(f" - N_LAYERS: {CFG.N_LAYERS}")
print(f" - NUM_HEADS: {CFG.NUM_HEADS}")
print(f"\n訓練配置:")
print(f" - BATCH_SIZE: {CFG.BATCH_SIZE}")
print(f" - GRAD_ACCUM_STEPS: {CFG.GRAD_ACCUM_STEPS}")
print(f" - 有效BATCH_SIZE: {CFG.BATCH_SIZE * CFG.GRAD_ACCUM_STEPS}")
print(f" - LR: {CFG.LR}, WARMUP_STEPS: {CFG.WARMUP_STEPS}")
print("="*60)
print(f"\n加載Token文件: {token_file_path}")
ids = torch.load(token_file_path)
print(f"已加載 {ids.numel():,} tokens ({ids.numel() * ids.element_size() / (1024**3):.2f} GB)")
dataset = TextDataset(ids, CFG.BLOCK_SIZE)
del ids
torch.cuda.empty_cache()
# 改進:啟用 shuffle=True 進行預訓練
num_workers = CFG.NUM_WORKERS
try:
train_loader = DataLoader(
dataset,
batch_size=CFG.BATCH_SIZE,
shuffle=True, # 啟用 Shuffle
pin_memory=(CFG.DEVICE == "cuda"),
num_workers=num_workers,
persistent_workers=True if num_workers > 0 else False
)
except Exception as e:
print(f"DataLoader錯誤: {e}, 改用num_workers=0")
train_loader = DataLoader(
dataset,
batch_size=CFG.BATCH_SIZE,
shuffle=True,
pin_memory=(CFG.DEVICE == "cuda"),
num_workers=0
)
model = GPTModel(
CFG.VOCAB_SIZE,
CFG.BLOCK_SIZE,
dim=CFG.MODEL_DIM,
num_layers=CFG.N_LAYERS,
num_heads=CFG.NUM_HEADS,
ffn_dim=CFG.FFN_DIM,
dropout=CFG.DROPOUT
).to(CFG.DEVICE)
# 嘗試編譯(容錯)
try:
model = torch.compile(model, mode='reduce-overhead')
print("已啟用 torch.compile() 加速")
except Exception as e:
print(f"跳過 torch.compile(): {e}")
train(model, train_loader, epochs=CFG.EPOCHS, resume=resume)
torch.save(model.state_dict(), out_path)
print(f"\n最終模型已保存到 {out_path}")
print_gpu_memory()
if __name__ == "__main__":
main()07、進行模型推理測試
import torch
from torch import nn
import sentencepiece as spm
from typing import Optional
# ==================== 配置參數 (必須與訓練時一致) ====================
# 使用與訓練腳本中完全相同的配置
class Config:
BLOCK_SIZE = 512
# 模型尺寸參數 (必須與訓練時一致)
MODEL_DIM = 384
N_LAYERS = 5
NUM_HEADS = 6
HEAD_DIM = MODEL_DIM // NUM_HEADS
FFN_DIM = MODEL_DIM * 4
VOCAB_SIZE = None
# 推理設置
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 推理通常使用 float32 獲得最佳兼容性和精度
DTYPE = torch.float32
CFG = Config()
# ==================== RoPE 位置編碼 (與訓練腳本保持一致) ====================
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, head_dim: int, max_seq_len: int = 2048):
super().__init__()
self.head_dim = head_dim
assert head_dim % 2 == 0, "head_dim must be even"
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq)
self.max_seq_len = max_seq_len
self._seq_len_cached = max_seq_len
self._cos_cached = None
self._sin_cached = None
self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)
def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
if seq_len == self._seq_len_cached and self._cos_cached is not None:
return
m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", m, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos()[None, None, :, :]
sin = emb.sin()[None, None, :, :]
self._cos_cached = cos
self._sin_cached = sin
self._seq_len_cached = seq_len
def forward(self, seq_len: int, device: Optional[torch.device] = None):
if device is None:
device = self.inv_freq.device
self._update_cos_sin_cache(seq_len, device=device)
return self._cos_cached.to(device), self._sin_cached.to(device)
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
return (x * cos) + (_rotate_half(x) * sin)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# ==================== Attention, FFN, Block, Model (與訓練腳本保持一致) ====================
class FlashAttention(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
assert embed_dim % num_heads == 0
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
# 推理時通常不使用 Dropout,但模型結構需要保持一致
self.attn_dropout = nn.Dropout(attn_dropout)
self.rope = RotaryPositionalEmbedding(self.head_dim)
def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, T, C = x.shape
qkv = self.qkv(x)
qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q = q.permute(0, 2, 1, 3) # (B, H, T, D)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
cos, sin = self.rope(T, device=x.device)
q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# 注意:在推理時,通常使用 KV-Cache,這里簡化為完整計算
if T > 1: # 僅在序列長度大于 1 時應用 mask
causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))[None, None, :, :]
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
# 推理時禁用 dropout
# attn = self.attn_dropout(attn)
out = torch.matmul(attn, v)
out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
return self.out_proj(out)
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
super().__init__()
# 必須保持與訓練腳本中完全相同的 nn.Sequential 結構
self.net = nn.Sequential(
GLU(dim, hidden_dim),
nn.Dropout(dropout), # net.1: Dropout (必須保留,占位)
nn.Linear(hidden_dim, dim), # net.2: Linear (與訓練時一致)
nn.Dropout(dropout), # net.3: Dropout (必須保留,占位)
)
def forward(self, x):
# 在推理時, model.eval() 會自動禁用所有 nn.Dropout 層,但結構不變
return self.net(x)
# 確保 GLU 的定義如下(與訓練時一致):
class GLU(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
# GLU 內部只有一個 nn.Linear
self.linear = nn.Linear(in_dim, out_dim * 2)
def forward(self, x):
x, gates = self.linear(x).chunk(2, dim=-1)
return x * torch.nn.functional.silu(gates)
class TransformerBlock(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
super().__init__()
self.ln1 = nn.LayerNorm(dim)
self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
self.ln2 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, ffn_dim, dropout)
def forward(self, x, causal_mask=None):
x = x + self.attn(self.ln1(x), causal_mask)
x = x + self.ff(self.ln2(x))
return x
class GPTModel(nn.Module):
def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
ffn_dim: int = CFG.FFN_DIM, dropout: float = 0.0, # 推理時 dropout=0
tie_weights: bool = True):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, dim)
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
])
self.ln_final = nn.LayerNorm(dim)
self.lm_head = nn.Linear(dim, vocab_size, bias=False)
if tie_weights:
self.lm_head.weight = self.token_emb.weight
self.block_size = block_size
def forward(self, idx):
B, T = idx.shape
token_emb = self.token_emb(idx)
x = token_emb # 推理時不使用 dropout
causal_mask = None # Attention 模塊內部處理 Causal Mask
for block in self.blocks:
x = block(x, causal_mask)
x = self.ln_final(x)
logits = self.lm_head(x)
return logits
# ==================== 推理和生成函數 ====================
@torch.no_grad()
def generate_text(model: GPTModel, sp: spm.SentencePieceProcessor,
prompt: str, max_new_tokens: int, temperature: float = 0.8,
top_k: int = 50):
model.eval()
device = CFG.DEVICE
# 1. 編碼輸入
input_ids = sp.encode_as_ids(prompt)
if not input_ids:
return "無法編碼輸入。"
# 將輸入轉換為模型期望的格式 (B, T)
x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
# 2. 循環生成
for _ in range(max_new_tokens):
# 裁剪輸入以適應模型的 BLOCK_SIZE
# 在實際部署中,這里應該使用 KV Cache,但此處簡化為完整前向傳播
idx_cond = x if x.size(1) <= CFG.BLOCK_SIZE else x[:, -CFG.BLOCK_SIZE:]
# 獲取 logits
logits = model(idx_cond)
# 只取最后一個時間步的 logits
logits = logits[:, -1, :]
# 應用溫度縮放
logits = logits / temperature
# 3. Top-K 采樣
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
# 計算概率并采樣
probs = torch.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# 4. 停止條件
# 檢查是否生成了 EOS token (假設 </s> 是 ID 3, 請根據您的分詞器調整)
# 默認使用 SentencePiece 的 <eos> ID
if idx_next.item() == sp.eos_id():
break
# 將新生成的 token 添加到序列中
x = torch.cat((x, idx_next), dim=1)
# 檢查是否達到最大序列長度 (防止溢出)
if x.size(1) >= CFG.BLOCK_SIZE + max_new_tokens:
break
# 5. 解碼輸出
output_ids = x[0].tolist()
# 查找輸入 prompt 的長度,只解碼新生成的 token
start_index = len(input_ids)
return sp.decode_ids(output_ids[start_index:])
# ==================== 主執行函數 ====================
def main_infer(sp_model_path: str, model_weights_path: str):
print("="*50)
print(f"GPT 模型推理模式")
print(f"設備: {CFG.DEVICE}, DTYPE: {CFG.DTYPE}")
print("="*50)
# 1. 加載分詞器
try:
sp = spm.SentencePieceProcessor(model_file=sp_model_path)
CFG.VOCAB_SIZE = sp.get_piece_size()
print(f"加載分詞器成功,VOCAB_SIZE: {CFG.VOCAB_SIZE}")
except Exception as e:
print(f"無法加載分詞器模型 {sp_model_path}: {e}")
return
# 2. 實例化模型
model = GPTModel(
vocab_size=CFG.VOCAB_SIZE,
block_size=CFG.BLOCK_SIZE,
dim=CFG.MODEL_DIM,
num_layers=CFG.N_LAYERS,
num_heads=CFG.NUM_HEADS,
ffn_dim=CFG.FFN_DIM,
dropout=0.0 # 推理時設置 dropout 為 0
).to(CFG.DEVICE).to(CFG.DTYPE)
# 3. 加載權重
try:
# 檢查是否是 torch.compile 后的狀態字典
weights = torch.load(model_weights_path, map_locatinotallow=CFG.DEVICE)
# 如果權重是 DDP 或 torch.compile 包裝后的,需要解包
if any(k.startswith('_orig_mod.') for k in weights.keys()):
weights = {k.replace('_orig_mod.', ''): v for k, v in weights.items()}
model.load_state_dict(weights, strict=True)
print(f"成功加載模型權重: {model_weights_path}")
except Exception as e:
print(f"無法加載或匹配模型權重: {e}")
# 如果加載失敗,打印預期鍵和實際鍵,方便調試
# print("\n--- 預期模型鍵 (部分) ---")
# print(list(model.state_dict().keys())[:5])
# print("\n--- 載入權重鍵 (部分) ---")
# print(list(weights.keys())[:5])
return
# 4. 進入交互循環
print("\n--- 進入交互模式 ---")
print(f"輸入 'exit' 或 'quit' 退出。")
print(f"輸入 'config' 查看當前生成參數。")
print("----------------------")
max_tokens = 100
temperature = 0.8
top_k = 50
while True:
try:
prompt = input(">>> 輸入提示詞: ")
if prompt.lower() in ['exit', 'quit']:
break
if prompt.lower() == 'config':
print(f" Max Tokens: {max_tokens}, Temp: {temperature}, Top K: {top_k}")
new_max = input(" 設置 Max Tokens (回車跳過): ")
new_temp = input(" 設置 Temperature (回車跳過): ")
new_k = input(" 設置 Top K (回車跳過): ")
if new_max: max_tokens = int(new_max)
if new_temp: temperature = float(new_temp)
if new_k: top_k = int(new_k)
continue
if not prompt.strip():
continue
print("生成中...")
# 執行生成
output = generate_text(model, sp, prompt, max_tokens, temperature, top_k)
print(f"--- 模型回復 ---\n{output.strip()}")
print("----------------")
except KeyboardInterrupt:
print("\n退出生成...")
break
except Exception as e:
print(f"發生錯誤: {e}")
if __name__ == "__main__":
import sys
if len(sys.argv) != 3:
print("用法: python infer.py <spm模型路徑> <模型權重文件路徑>")
# 示例用法 (請根據您的實際文件路徑修改):
# python infer.py tokenizer.model final_model.pth
sys.exit(1)
sp_model_path = sys.argv[1]
model_weights_path = sys.argv[2]
main_infer(sp_model_path, model_weights_path)我們看到模型大概可以預測我們輸入的下一個詞,因我們訓練的參數和步數很低,模型輸出的亂七八糟!
本次總結
本次我們做了數據準備、數據清洗、分詞器訓練、模型訓練、推理等,請根據步驟進行執行代碼,你便可以得到一個17M參數的小模型。后面我們再加大參數進行訓練,再進行監督微調。

























