一、前言
Bert源码解读完了,具体怎么用于自己的项目呢?在Bert系列(四)——源码解读之Fine-tune中,我说只要修改两个地方。
重要的是明白根据不同任务调整输入格式和对loss的构建,这两个知识点学会之后,基本上也可以依葫芦画瓢做一些自己的任务了。
那么这一次,我们就来依葫芦画瓢,作一个中文分词项目。
二、数据准备
数据是我经词性标注@人民日报199801.txt 加工过后形成的2个样本文件,分别用于训练和验证,2份数据的格式一摸一样。
test.txt
train.txt
数据格式如下所示:
黄土地上的蒲公英(外一首) bmessbmesssss
同胞们、朋友们、女士们、先生们: bessbessbessbess
文字和标签用 \t 分隔开。相信做过分词的朋友都知道后面的标签是什么意思,不知道的我这里解释一下,分词是一个典型的基于字的标注任务,标签有四个分别是:b代表begin,m代表middle,e代表end,s代表single。
例如:黄土地上的蒲公英(外一首)按照自然切分呈现的是 黄土地 上 的 蒲公英 ( 外 一 首 )。黄土地的标签是bme 说明黄是一个词的开头begin, 土是一个词的中间middle,地是一个词的结尾end,上、的是s说明这是一个单字词single,于是任意的字都可以有一个标签去对应。
三、代码
1.转换输入格式
def get_labels(self):
return ["s", "b", "m", "e", "X", "[CLS]", "[SEP]"]
把Processor里的get_labels函数换成以下形式,其中"X"代表其他标签(虽然原则上每一个字符都能对应上面四种标签的任一个,但是因为WordpieceTokenizer的操作会分解个别字段,被分解的字符用"X"标记)
def _create_example(self, lines, set_type):
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text = tokenization.convert_to_unicode(line[0])
label = tokenization.convert_to_unicode(line[1])
assert len(text) == len(label)
examples.append(InputExample(guid=guid, text=text, label=label))
return examples
构造样本,没什么好说的
def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer,mode):
label_map = {}
for (i, label) in enumerate(label_list,1):
label_map[label] = i
label2idpath = './output/label2id.pkl'
if not os.path.exists(label2idpath):
with open(label2idpath,'wb') as w:
pickle.dump(label_map, w)
textlist = list(example.text)
labellist = list(example.label)
tokens = []
labels = []
unknow_index = []
for i, word in enumerate(textlist):
token = tokenizer.tokenize(word)
tokens.extend(token)
label_1 = labellist[i]
for m in range(len(token)):
if m == 0:
labels.append(label_1)
else:
labels.append("X")
if token[m] == "[UNK]":
unknow_index.append(i)
assert len(tokens) == len(labels)
if len(tokens) >= max_seq_length - 1:
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
ntokens = []
segment_ids = []
label_ids = []
ntokens.append("[CLS]")
segment_ids.append(0)
label_ids.append(label_map["[CLS]"])
for i, token in enumerate(tokens):
ntokens.append(token)
segment_ids.append(0)
label_ids.append(label_map[labels[i]])
ntokens.append("[SEP]")
segment_ids.append(0)
label_ids.append(label_map["[SEP]"])
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
label_ids.append(0)
ntokens.append("**NULL**")
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(label_ids) == max_seq_length
if ex_index < 5:
tf.logging.info("*** Example ***")
tf.logging.info("guid: %s" % (example.guid))
tf.logging.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in tokens]))
tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids]))
feature = InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_ids=label_ids,
)
output_tokens = []
for i, each in enumerate(ntokens):
if each != "[UNK]":
output_tokens.append(each)
else:
index = unknow_index[0]
output_tokens.append(textlist[index])
unknow_index = unknow_index[1:]
write_tokens(output_tokens, mode)
return feature
和run_squad.py以及run_classifier.py一样,将example转换成feature,每一个ntokens的格式:[CLS]句子[SEP],超出指定长度的truncate,不足的用0补齐。
2.构造loss
def create_model(bert_config, is_training, input_ids, input_mask,
segment_ids, labels, num_labels, use_one_hot_embeddings):
model = modeling.BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings
)
output_layer = model.get_sequence_output()
hidden_size = output_layer.shape[-1].value
output_weight = tf.get_variable(
"output_weights", [num_labels, hidden_size],
initializer=tf.truncated_normal_initializer(stddev=0.02)
)
output_bias = tf.get_variable(
"output_bias", [num_labels], initializer=tf.zeros_initializer()
)
with tf.variable_scope("loss"):
if is_training:
output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
output_layer = tf.reshape(output_layer, [-1, hidden_size])
logits = tf.matmul(output_layer, output_weight, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
logits = tf.reshape(logits, [-1, FLAGS.max_seq_length, num_labels])
log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
loss = tf.reduce_sum(per_example_loss)
probabilities = tf.nn.softmax(logits, axis=-1)
predict = tf.argmax(probabilities, axis=-1)
return (loss, per_example_loss, logits, predict)
取模型最后一层输出的sequence_output(shape [batch_size, seq_length, hidden_size]),然后线性投影到标签上(shape [batch_size, seq_length, num_labels]) ,这里num_labels=8,7个标签+补齐用的0;
再计算交叉熵损失loss 和 预测的标签predict
四、运行&&评估
训练集有39404个,GPU运行3个epochs,耗时大概在28分钟
python3 run_cut.py --task_name="people" --do_train=True --do_predict=True --data_dir=$PEOPLEcut --vocab_file=$BERT_CHINESE_DIR/vocab.txt --bert_config_file=$BERT_CHINESE_DIR/bert_config.json --init_checkpoint=$BERT_CHINESE_DIR/bert_model.ckpt --max_seq_length=128 --train_batch_size=32 --learning_rate=2e-5 --num_train_epochs=3.0 --output_dir=./output/result_cut/
评估方法见中文分词器分词效果的评测方法,以下是采用测试集的评估结果:
INFO:tensorflow:***** Eval results *****
INFO:tensorflow: count = 9925
INFO:tensorflow: precision_avg = 0.9794
INFO:tensorflow: recall_avg = 0.9780
INFO:tensorflow: f1_avg = 0.9783
INFO:tensorflow: error_avg = 0.0213
测试数据用于测试效果,其中第一行为原文本,第二行为标签分词、第三行为预测分词:
国民经济保持了“高增长、低通胀”的良好发展态势
国民经济 保持 了 “ 高 增长 、 低 通胀 ” 的 良好 发展 态势
国民经济 保持 了 “ 高 增长 、 低 通胀 ” 的 良好 发展 态势农业生产再次获得好的收成,企业改革继续深化,人民生活进一步改善
农业 生产 再次 获得 好 的 收成 , 企业 改革 继续 深化 , 人民 生活 进一步 改善
农业 生产 再次 获得 好 的 收成 , 企业 改革 继续 深化 , 人民 生活 进一步 改善这些外交活动,符合和平与发展的时代主题,顺应世界走向多极化的趋势,对于促进国际社会的友好合作和共同发展作出了积极的贡献
这些 外交 活动 , 符合 和平 与 发展 的 时代 主题 , 顺应 世界 走向 多极化 的 趋势 , 对于 促进 国际 社会 的 友好 合作 和 共同 发展 作出 了 积极 的 贡献
这些 外交 活动 , 符合 和平 与 发展 的 时代 主题 , 顺应 世界 走向 多极化 的 趋势 , 对于 促进 国际 社会 的 友好 合作 和 共同 发展 作出 了 积极 的 贡献1998年,中国人民将满怀信心地开创新的业绩
1998年 , 中国 人民 将 满怀信心 地 开创 新 的 业绩
1998年 , 中国 人民 将 满怀信心 地 开创 新 的 业绩
五、总结
1.我们只需要修改两个地方就可以实现对bert的利用,输入格式和构造loss,其他任务也类似;
2.bert的中文输入格式必须是基于字的,所以在进行词性标注或者NER等任务时,也务必要以字而不是词作为基本输入单元,如果要基于词也不是不可以,那就需要自己重新预训练中文模型了;
想尝试的同学可以照这个思路修改下,我也会尽快开源代码和数据
更新: 我已经上传了代码 Bert_ChinesewordSegment
本文系列
Bert系列(一)——demo运行
Bert系列(二)——模型主体源码解读
Bert系列(三)——源码解读之Pre-train
Bert系列(四)——源码解读之Fine-tune
Reference
1.https://github.com/google-research/bert