字节对编码(Byte Pair Encoding)

子词嵌入

在自然语言处理中,如果将整个单词作为词表中的一个词
存在词表过大的问题
同时一个新的词语如果不是常见词,也会出现OOV问题
同时不同词性的词,如果将它们认为成不同的词,会失去这部分的语义信息
因此便有了字节对编码(BPE)技术

字节的二面问到了这一算法
要求讲出具体过程并用代码实现,当时没能写出来
现对这一部分的内容进行补漏

BPE代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def get_max_freq(token_freqs): # 获得最大频率的bigram
pairs = collections.defaultdict(int)
for k, v in token_freqs.items():
symbols = k.split()
for i in range(len(symbols) - 1):
pairs[symbols[i], symbols[i + 1]] += v
return max(pairs, key=pairs.get)

def merge_vocab(max_pair, token_freqs, vocab): # 构建新词表
vocab.append(''.join(max_pair))
v_out = {}
bigram = re.escape(' '.join(max_pair)) # 原pair进行转义
p = re.compile('(?<!\S)' + bigram + '(?!\S)') # 前后都不能有空格
for word in token_freqs:
new_token = p.sub(''.join(max_pair), word) # 进行替换
v_out[new_token] = token_freqs[word]
return v_out

def segement_BPE(tokens, vocab): # 计算
outputs = []
for token in tokens:
start, end = 0, len(token)
cur_output = []
while start < len(token) and start < end:
if token[start: end] in vocab:
cur_output.append(token[start: end])
start = end
end = len(token)
else:
end -= 1
if start < len(token):
cur_output.append('[UNK]')
outputs.append(' '.join(cur_output))
return outputs

子词嵌入迭代过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
raw_token_freqs = {'low_': 5, 'lower_': 2, 'newest_': 6, 'widest_': 3}
token_freqs = {}
for k, v in raw_token_freqs.items():
token_freqs[' '.join(k)] = v
print(token_freqs)
vocab = [chr(ord('a') + i) for i in range(26)]
vocab.append('_')
num_merge = 10
for i in range(num_merge):
max_pair = get_max_freq(token_freqs)
token_freqs = merge_vocab(max_pair, token_freqs, vocab)
print(max_pair)
print(vocab)
print(token_freqs)

输出结果

1
2
3
4
5
6
7
8
9
10
11
12
('e', 's')
('es', 't')
('est', '_')
('l', 'o')
('lo', 'w')
('n', 'e')
('ne', 'w')
('new', 'est_')
('low', '_')
('w', 'i')
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '_', 'es', 'est', 'est_', 'lo', 'low', 'ne', 'new', 'newest_', 'low_', 'wi']
{'low_': 5, 'low e r _': 2, 'newest_': 6, 'wi d est_': 3}

切分数据集单词

我们还可以使用从一个数据集学习的子词来切分另一个数据集的单词

1
2
3
tokens = 'old world is lower, new world is the best'.split()
tokens = [i + '_' for i in tokens]
print(segement_BPE(tokens, vocab))

输出结果

1
['o l d _', 'w o r l d _', 'i s _', 'low e r [UNK]', 'new _', 'w o r l d _', 'i s _', 't h e _', 'b est_']