进阶教程
BPE分词器训练优化:初版实现与后续方向分析
摘要
针对BPE分词器朴素版每轮合并全量扫描所有词的效率瓶颈,提出增量更新策略:维护全局pai
好的,没问题。作为常年跟BPE分词器较劲的工程师,这篇文档我来重写——去掉那些通用套话,换成干活时踩坑、总结出来的硬核经验。
---
> 如果你还没摸透 `cs336_assignment1_basics.pdf` 里BPE的原理,建议先翻 [[assigment1_overview&bpe_basics]] 看我们做的翻译和细节拆解。我们在那篇里提到过,朴素BPE的时间复杂度高得离谱;后来在 [[train_bpe_naive#测试]] 里也验证了它那令人捉急的训练速度,根本扛不住作业要求。所以当务之急,是在这个朴素版基础上动手做优化。
### 朴素版详细分析
先复盘一下朴素版的瓶颈到底卡在哪。每一轮合并的核心操作就三步:
1. **`count_pairs`**:遍历所有 word 编码里的相邻字节对,时间复杂度 O(N),N 是所有 token 的总数。
2. **`max()`**:从统计结果里揪出频率最高的那个 pair。
3. **`merge_encoding`**:再次遍历所有 word 编码,执行替换。
我之前提过,问题全在第一站。为什么?因为每轮合并之后,绝大多数 word 的编码压根没变,但我们每次还是得把它们全部重新扫一遍。拿文档里的 `bpe_example` 来说:
```text
low low low low low
lower lower
widest widest widest
newest newest newest newest newest newest
```
第一次合并的是 `(s, t)`,受影响的 word 只有3个。然而到第二轮,我们仍然要对所有 word 扫一遍。这种“无差别”的做法,效率可想而知。

### 优化:告别全量扫描,拥抱增量更新
分析到这,优化思路已经很直白了:维护一个全局的 `pair_counts` 字典,每次合并只做增量更新,不再全部重算。
核心逻辑:合并 `(A, B) → AB` 时,`pair_counts` 的变化只跟两边的邻居有关。比如序列 `... X A B Y ...` 合并后变成 `... X AB Y ...`。那么:
- 旧的相邻对 `(X, A)` 和 `(B, Y)` 会消失,需要减去它们的计数。
- 新的相邻对 `(X, AB)` 和 `(AB, Y)` 出现了,需要加上计数。
- 被合并的 `(A, B)` 这个 pair,计数自然清零。
这个优化的精髓在于:
`每轮 merge: 直接从 pair_counts 字典里找最大值,复杂度 O(unique pairs)。只需扫描【含有 (p0,p1) 的 word】执行替换,复杂度 O(受影响的 word)。然后针对这些 word:旧邻对计数 -1,新邻对计数 +1,完成增量更新。`
这里有个关键点:合并 `(p0, p1) → new_id` 后,`pair_counts` 的变化完全由新旧 `encoding` 的差异决定。我们只需对旧 `encoding` 的相邻对逐个减计数,对新 `encoding` 的相邻对逐个加计数就行。这样做的好处是不用去分类讨论左右邻居的复杂情况,逻辑清晰,不容易出 bug。
#### 优化后的实现
```python
# 4. BPE training loop
num_merges = vocab_size - size
pair_counts = defaultdict(int, self.count_pairs(word_counts, word_encodings))
for merge_idx in range(num_merges):
if not pair_counts:
print("No more pairs to be merged, quit.")
break
# a. Find the max frequency pair to be merged
merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], self.vocab[x[0]], self.vocab[x[1]]))
# b. Merge and update the word encodings
token_id = size
for word in word_encodings:
old_encoding = word_encodings[word]
# skip if the encoding not changed
if not any((old_encoding[i], old_encoding[i+1]) == merge_pair for i in range(len(old_encoding) - 1)):
continue
# update the encoding
new_encoding = self.merge_encoding(old_encoding, merge_pair, token_id)
# update count for pair
cnt = word_counts[word]
for i in range(len(old_encoding) - 1):
pair_counts[(old_encoding[i], old_encoding[i+1])] -= cnt
for i in range(len(new_encoding) - 1):
pair_counts[(new_encoding[i], new_encoding[i+1])] += cnt
# update word encoding
word_encodings[word] = new_encoding
# clear the pairs whose count is no more than 0
del_keys = [k for k, v in pair_counts.items() if v <= 0]
for k in del_keys:
del pair_counts[k]
```
除了手写时的一些笔误,实现过程中还踩了几个坑,值得记一笔。
**1. `any()` 的用法失误**
一开始我是这样写的:
```python
for i in range(len(old_encoding) - 1):
if not any((old_encoding[i], old_encoding[i+1]) == merge_pair):
continue
```
一跑测试就收到 `TypeError`,大意是 `'bool' object is not iterable`。查了一下 `any()` 的用法才知道问题所在。
```python
def any(iterable):
for element in iterable:
if element:
return True
return False
```
`any()` 接收的参数必须是可迭代对象(比如列表、生成器),而我直接传了一个布尔值进去,自然不对。修正方案就是让 `any()` 接收一个生成器表达式。
**2. 清理失效pair的时机**
另一个Bug出现在清理某些不再存在的字节对时。我一开始的做法是在扣除旧 `encoding` 中pair频率的循环里,发现某个pair计数归零就立刻删除:
```python
for i in range(len(old_encoding) - 1):
pair_counts[(old_encoding[i], old_encoding[i+1])] -= cnt
# clear immediately
if pair_counts[(old_encoding[i], old_encoding[i+1])] <= 0:
del pair_counts[(old_encoding[i], old_encoding[i+1])]
```
这么干会碰到 `KeyError`。问题在于,同一个 word 的 `encoding` 里,同一个pair可能出现多次。假设出现了两次,第一次扣除后归零,我们立刻清除了这个 key,那第二次再扣的时候,这个 key 已经不存在了。再说,扣除和加回是分阶段操作的,一个 pair 可能先被扣到0,但同一轮又被加回来,这就打断了正常的更新流程。所以,正确的做法是等整个 word 的所有 pair 更新完毕(扣除和加回都做完)之后,再统一判断哪些 pair 需要清理。
**3. `pair_counts` 没初始化就用了**
这个 Bug 找起来费了点功夫。测试报错信息是:
```python
for i in range(len(new_encoding) - 1):
> pair_counts[(new_encoding[i], new_encoding[i+1])] += cnt
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E KeyError: (116, 257)
cs336_basics/train_bpe.py:77: KeyError
```
我反复检查了 `merge_encoding` 和 `count_pairs` 的逻辑,都没问题。后来往上一看,原因很直接:`self.count_pairs` 返回的是一个普通的 `dict`,而普通 `dict` 访问不存在的 key 会抛 `KeyError`。而如果用 `defaultdict(int)`,访问不存在的 key 时会自动初始化为 `int()`(也就是0),再执行 `+=` 操作就没事了。所以,正确的做法是像下面这样,或者先创建 `defaultdict`,再调用 `update` 方法。
```python
pair_counts = defaultdict(int, self.count_pairs(word_counts, word_encodings))
```
### 测试
三个测试点都顺利通过了。相比朴素版,效率的提升是实打实的。
### 在 TinyStories 数据集上训练
这个题目有两个要求:
- a. 词汇表大小最大为10,000,必须确保特殊token `"<|endoftext|>"` 被加入到词汇表中。资源要求是:训练时长 ≤ 30分钟(不使用GPU),占用内存 ≤ 30GB RAM。提示说,如果想在2分钟内完成,可以考虑多线程处理预分词。
- b. 分析tokenizer训练过程中,哪一部分最耗时。
按照作业要求,我分了三步走:
1. 编写训练脚本:包括加载训练数据、训练、保存模型、统计时间和内存。
2. 运行时性能分析(Profiling):找到瓶颈所在。
3. 检查结果:找出最长的Token。
#### 训练脚本
**数据**
- **检查数据**:用 `head -n 5 data/TinyStoriesV2-GPT4-train.txt` 先看看数据长啥样,确认跟测试数据格式一致。
- **加载数据**:为了读写文件方便,工程里通常会先获取项目根路径。
```python
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
```
- **获取运行时内存**
```python
def get_memory_usage_mb():
"""Get current process memory usage in MB"""
process = psutil.Process(os.getpid())
return process.memory_info().rss / 1024 / 1024
```
- **保存模型**:训练结束后,要把得到的 `vocabulary` 和 `merges` 规则持久化到磁盘。`vocab` 是个字典,保存成 `json` 格式,为了便于人工阅读,把字节串显示出来,对于无法显示的字节,保留 `repr` 形式。`merges` 是个列表,保存成文本文件。这里参考了测试文件 `tests/test_train_bpe.py` 的保存形式,也用了里面的 `gpt2_bytes_to_unicode` 函数。当然,为了简单,也可以直接用 `pickle` 保存。
```python
# sa ve vocab and merges to disk
from tests.common import gpt2_bytes_to_unicode
def sa ve_vocab_and_merges(vocab, merges, output_dir='results'):
Path(output_dir).mkdir(exist_ok=True, parents=True)
byte_encoder = gpt2_bytes_to_unicode() # {int: str},每个字节映射到可打印字符串
# Sa ve vocab:把每个token的字节转换成可打印字符串
vocab_str = {}
for idx, token_bytes in vocab.items():
vocab_str[idx] = ''.join(byte_encoder[b] for b in token_bytes)
with open(f'{output_dir}/vocab.json', 'w', encoding='utf-8') as f:
json.dump(vocab_str, f, ensure_ascii=False, indent=2)
# Sa ve merges:两个token用空格分隔,每个字节转为可打印字符串
with open(f'{output_dir}/merges.txt', 'w', encoding='utf-8') as f:
for p1, p2 in merges:
t1 = ''.join(byte_encoder[b] for b in p1)
t2 = ''.join(byte_encoder[b] for b in p2)
f.write(f'{t1} {t2}\n')
```
- **训练主函数**
```python
# main function
def run_training(input_path, vocab_size, special_tokens, output_dir='results'):
"""Run training"""
# 记录训练开始前的内存
print(f'Initial Memory: {get_memory_usage_mb():.2f} MB')
# 初始化BPE训练器
trainer = train_bpe.BPETrainer()
# 开始训练,记录时间和内存
start_time = time.time()
print(f'Starting training on {input_path}...')
vocab, merges = trainer.train(input_path, vocab_size, special_tokens)
end_time = time.time()
duration = end_time - start_time
peak_memory = get_memory_usage_mb()
print('-' * 100)
print('Training Complete.')
print(f'Time Taken: {duration:.2f} seconds ({duration/60:.2f} minutes)')
print(f'Final Memory: {peak_memory:.2f} MB')
print('-' * 100)
sa ve_vocab_and_merges(vocab, merges, output_dir)
# 输出统计信息
print("\n=== Statistics (Problem b) ===")
# 1. Longest token
longest_token_bytes = max(vocab.values(), key=len)
try:
longest_token_str = longest_token_bytes.decode('utf-8')
except:
longest_token_str = str(longest_token_bytes)
print(f"Longest Token: {longest_token_str!r}")
print(f"Length in bytes: {len(longest_token_bytes)}")
# 2. 关于“数据集中最频 token”的说明:BPE通常不会保留最终词汇表的完整频率,除非我们重新分词。
# 这里我们打印最后一次合并的pair,它代表那一步最频繁的pair。
print(f"Total Merges: {len(merges)}")
print('-' * 100)
```
- **执行训练**:在 `main` 函数中定义好参数,然后调用 `run_training`,最后在终端执行。
```bash
uv run python ./train_bpe_tinystories.py --input_path data/TinyStoriesV2-GPT4-train.txt --vocab_size 10000 --profile
```
### 训练与分析
#### 训练结果
- 训练过程中,顺便看了下CPU和内存。我那本就有些捉襟见肘的内存直接被干满了,CPU也确实只有一个核在跑,符合单线程的预期。这也印证了后面优化的方向:用多线程处理预分词。
- 训练结果截图在这里,从结果看:训练花了28分钟,内存占用2302.11 MB。这个数据跟训练集大小是吻的。训练集大小如下:
```
2.1G data/TinyStoriesV2-GPT4-train.txt
22M data/TinyStoriesV2-GPT4-valid.txt
9.8G data/owt_train.txt
4.3G data/owt_train.txt.gz
```
- `vocabulary` 和 `merges` 已经保存到磁盘。
- 性能分析数据也保存到了 `training.prof`。
根据上面的统计信息和数据集大小,我的电脑显然没法直接在OpenWebText数据集上训练。没办法,必须得分块加载到内存。既然都分块了,那自然可以考虑多块并行执行预分词,就像文档里提示的那样,用多进程来优化。
#### 性能分析
启动 `snakeviz` 来分析 `training.prof`。
1. 安装 `snakeviz`:从 `toml` 文件看,它不在依赖里,需要单独安装:`uv pip install snakeviz`。
2. 启动服务:`uv run snakeviz --server training.prof`。
3. 在浏览器里打开 `snakeviz` 的可视化界面。
下面是分析得到的瓶颈所在以及优化方向。
**时间分布总览**
总耗时1687秒,主要分为两大块:
| 函数 | 耗时 | 占比 |
| :--- | :--- | :--- |
| `pretokenize` | 619秒 | 37% |
| `train` 自身(合并循环) | 416秒 | 25% |
| `any` + genexpr(快速跳过) | 441秒 | 26% |
| `max`(选最优pair) | 161秒 | 10% |
- **瓶颈1:`pretokenize` 占了37%,而且有严重的重复编译问题。**
`_compile` 被调用了271万次,耗时30秒。原因很直接:每次调用 `re.finditer` 时,Python都会重新编译一次正则表达式。而 `pretokenize` 对每个chunk都调用一次 `finditer`——语料被分成了271万个chunk(也就是271万个文档)。
```python
chunks = re.split(escape_special_tokens, text)
# 当前做法:每个 chunk 都触发一次编译检查
for chunk in chunks:
for match in re.finditer(self.pattern, chunk):
# 每次都经过编译缓存查找
```
优化方向很明确:提前把正则表达式编译好。
```python
compiled_pattern = re.compile(self.pattern)
```
同样,那个用来切分的正则也建议提前编译。另外,读文件花了68秒(`read` 35秒 + `utf_8_decode` 32秒),对于一个2GB的文件来说,这个时间是正常的,没啥优化空间。
- **瓶颈2:用来“快速跳过”的逻辑,反而成了最大的瓶颈,占了26%。**
这是最反直觉的地方。`any` 被调用了5.84亿次,其内部的生成器表达式被调用了18.3亿次,加在一起花了441秒。
对应的是代码里这一行(第69行):
```python
if not any((old_encoding[i], old_encoding[i+1]) == merge_pair for i in range(len(old_encoding) - 1)):
continue
```
这个“快速跳过”不但没有帮上忙,反而拖了后腿。原因有三:
1. 每个word、每轮合并都要执行这个检查。
2. Python里面生成器表达式的调用开销相当高。
3. `len()` 被调用了5.88亿次。
对于大多数word来说,这个检查确实让它们跳过了后续的替换操作,但检查本身的开销,已经比被它跳过所节省的开销要大得多了。
一个更快的做法是建立索引:维护一个 `pair_to_words` 字典,记录每个pair出现在了哪些word里。这样合并的时候,直接查这个表就行了,完全不用遍历所有word。
```python
# 初始化时建立索引
pair_to_words = defaultdict(set)
for word in word_encodings:
enc = word_encodings[word]
for i in range(len(enc) - 1):
pair_to_words[(enc[i], enc[i+1])].add(word)
# 合并时只处理受影响的 word
for word in pair_to_words[merge_pair]:
# 直接拿到受影响的 word,不用遍历全部
...
```
- **瓶颈3:`max()` 每轮都要遍历整个 `pair_counts` 字典,占了10%。**
`max` 被调用了9746次(每轮一次),每次都要遍历整个 `pair_counts`。随着合并的进行,`pair_counts` 字典会越来越大,所以每次 `max` 的代价也会越来越高。
这里可以用堆(`heapq`)来替代 `max`,这样能把每轮选最优pair的复杂度从O(n)降到O(log n)。不过,用堆有一个复杂性:更新计数的时候需要处理“失效条目”(lazy deletion)。这个改动比之前提到的 `pair_to_words` 索引要复杂一些,建议先做索引优化,再考虑堆优化。
### 优化方向总结
按预期收益从高到低排列:
1. **建立 `pair_to_words` 索引**:直接消灭5.84亿次 `any` 调用,预计可节省400秒以上。
2. **预编译正则表达式**:消灭271万次重复编译,预计可节省约30秒。
3. **用堆优化 `max`**:预计可节省100秒以上,但实现起来会稍微复杂一些。
具体的优化实现,我们留到下篇文章再详细讲解。
来源:互联网
免责声明
本网站新闻资讯均来自公开渠道,力求准确但不保证绝对无误,内容观点仅代表作者本人,与本站无关。若涉及侵权,请联系我们处理。本站保留对声明的修改权,最终解释权归本站所有。