我又一次开始了“看不懂你掐死我系列”。标题名称是仿照知乎的一篇介绍傅里叶变换的文章起的。当时看完了觉得还真看懂了。可是关上网页再自己想的时候,就有想掐死博主的冲动~~ 为了致敬,这里贴出原文章,大家共勉。
抄袭标题:看不懂傅里叶变换就掐死他
这段时间做训练的时候需要分步训练不同的网络结构,最后把所有训练好的graph合并成一个大graph,前后接起来并且重新定义输入和输出再继续训练,这样分步先训练小网络再合成大网络的话效果会好一点,收敛的也会快一些。那么有个问题,怎么把训练好的好几个graph恢复训练参数再合并到一起呢?Tensorflow到底能不能这么做?如果能,那应该怎么做?在读了这篇,这篇知乎,和搜了无数个stackoverflow上的例子之后,终于有了答案。
要知道我们需要把每个pretrain的网络的结构和参数全都读进去,再把它们合并在一起。先不说合并的事,读取参数和结构就是个问题。比如下边这几个stackoverflow的帖子。1,2,3,4,5。他们都用了不同的读取方法。但是到底读取的是什么?有没有达到我们预期的目的却不清楚。所以我意识到先要把tensorflow的内部结构搞清楚,看看存有什么东西,再看看存储和读取的方式。先来看结构。
Tensorflow的内部结构:
我们都知道tensorflow里有graph,graph的节点就是运算operation。这个用tensorboard可视化可以看到。比如下面这就是个简单的graph。
这个graph在tensorflow里实际的存储方式是被序列化以后,以Protocol Buffer的形式存储的。这里有中文的对protobuf的介绍,是google开发的。
graph序列化的protobuf叫做graphDef,就是define graph的意思,一个graph的定义。这个graphDef可以用tf.train.write_graph()/tf.Import_graph_def()来写入和导出。上面stackoverflow里就有人用这个方法。然而graphDef里面其实是没有存储变量的,但是可以存常量,就是constant??梢杂靡恢纸衒reeze_graph的工具把变量替换成常量,这里有官方的介绍。一般来说没有必要这么做,因为既然存了网络,肯定有变量的信息,虽然不在graphDef里面,但是肯定在别的地方。其实它存在collectionDef里?;褂幸恍┢渌腄ef,所以干脆归纳一下:
MetaGraph - MetaInfoDef 这个是存metadata的,像版本信息啊,用户信息啥的
? ? ? ? ? ? ? ? ? ? - GraphDef 上面说的就是这个GraphDef
? ? ? ? ? ? ? ? ? ? - SaverDef 这个就是tf.train.Saver的saver
? ? ? ? ? ? ? ? ? ? - CollectionDef
这些Def的数据都存在一个叫MetaGraph的文件里。这个MetaGraph有官方介绍。
最后面的collectionDef就是各种集合。每个集合里都是1对多的key/value pairs。你也可以把你想要的变量存进某个即合理,用tf.add_to_collection(collection_name,变量)就行。然后再用tf.get_collection()取出来。比如我有loss和train_op,就可以:
tf.add_to_collection("training_collection",loss)
tf.add_to_collection("training_collection",train_op)
然后再用
Train_collect = tf.get_collection(“training_collection”)? #得到一个python list
list里面就是你之前存的东西。所以collection我的理解就是为了方便管理变量用的。
metagraph可以用export_meta_graph/Import_meta_graph来导入导出。
这里注意了,如果你用tf.Import_graph_def()导入graphDef的话,导入的东西一般是不能训练的。但是用Import_meta_graph来导入metagraph之后,就是导入了一个完整的结构,这时候是可以训练的。
虽然能训练,metagraph里也有变量,但是都是起始值。也就是说我们之前训练的参数是没有导入的。这里训练等于是从头训练。实际的训练参数没有存在metagraph里,而是在data文件里。这个下面会提到。
说完了tensorflow的结构,再说说存储的方式??赐暾饨?,你应该完全知道什么api是用来读什么的了。
存储与读取:
上面那篇中文知乎恰好总结了这些。一般存读有3个API:
tf.train.Saver()/saver.restore()
export_meta_graph/Import_meta_graph
tf.train.write_graph()/tf.Import_graph_def()
后两个上一节都见过了。现在说说第一个。
我平时常用的只有第一个tf.train.Saver()和saver.restore()。我也看到很多代码里这么写。但是有一点很坑爹的是tf.train.saver.save() 什么都保存。但是在恢复图时,tf.train.saver.restore() 只恢复 Variable,如果要从MetaGraph恢复图,需要使用 import_meta_graph。看明白了吗?saver.save()和saver.restore()保存和读取的东西不!一!样!也就是说如果我想重组graph,要么用Import_meta_graph来导入graph,之后再saver.restore();要么就从新建立graph,把tensor传入结构的过程再写一遍,然后再saver.restore()。不然连变量名都找不到肯定会报错。
说道存储,我们必须得看看存储文件的格式。如果你用saver.save()保存的话(好像也只有这一种方法),打开你的保存文件夹,你会看到4种后缀名的文件(events开头的不算,那是tf.summary生成给tensorboard用的),分别是:
checkpoint?- 就是一个账本文件,可以使用高级帮助程序来加载不同的时间保存的chkp文件。没什么用
.meta?- 保存压缩后的Metagraph的protobufs,其实就是Metagraph。
.index -?包含一个不可变的键值表,用于链接序列化的张量名称以及在chkp.data文件中查找其数据的位置,也没存什么实际东西
.data - 这个里面才是存了训练后的参数。通常比.meta要大。有的时候有多个data文件用于共享或创建多个训练的时间戳。
其中.data文件的名字一般都是这种格式的:
<prefix>-<global_step>.data-<shard_index>-of-<number_of_shards>.
比如:
所以saver.restore()的时候其实是restore的.data文件。当然在restore之前可以用tf.train.latest_checkpoint()来得到最后一次存储点?;褂幸坏闶窃趕aver.save()和restore的时候,那个文件对象是xxx.ckpt。但实际上在存储文件夹里你找不到xxx.ckpt文件。这个也是正常的。官方文档有说.ckpt文件其实是隐性的的。所以除非你文件名字输入错了,不然不用担心读错文件。
下面结合我的实例再看看怎么合并graph。
实例:
先稍微介绍一下网络的结构。我有四个网络结构。其中3个网络是平行的,这里就叫p1,p2和p3吧。最后一个网络是微调用的,就叫m吧。这个m会得到3个网络的输出,合并在一起作为m的输入,输入到m,最后得到最终结果。为了方便理解我画了个图。
如果直接训练这么大的网络,收敛起来一定很费劲,有可能某一个网络落到一个local minimum就出不去了。所以我们把p1,p2,p3拿出来单独训练,每次只训练一个。
我分别用数据训练这3个网络。这个训练阶段算是pretrain。待到三个网络都稳定的时候,我把它们的输出结果加在一起,输入到第四个网络里训练整个网络。
官方文件称feed_dicts是效率最低的方法,所以我们改用的tfrecord和dataset api来读取文件。如果你不清楚这是啥,可以参看我们办公室博导的简书,这家伙可厉害了~
现在有两个问题,1是用Import_meta_graph导入metagraph的方法没法合并graph,因为我写的数据导入之后拿不出来(或者说我不知道怎么拿出来,可能有api可以取出来)。p1,p2,p3的输出数据是要手动连接的。import_graph_def()也可以设置input,output mapping,但是我这里没有tf.placeholder。我必须拿到一个从p1,p2,p3合成出来的tensor,再塞到m里去。所以我选择了用重建graph的方法。用
traindata, label = data_iterator(tfrecord_path).get_next()?
得到数据,再把traindata分别放入p1,p2,p3的架构中:
out_p1 = networkp1(trandata_p1)
网络结构有了,再restore参数:
full_path = tf.train.latest_checkpoint(model_ckp)
saver.restore(sess, full_path)
p2和p3也这么做。
三个全恢复了会得到三个output,再合并
m_data = out_p1 + out_p2 + out_p3
再输入m中:
output_m = networkm(m_data)
之后再做loss,bp,summary啥的,就可以训练了。
需要注意的是,别恢复错了graph。不要建3个session下分别用3个graph恢复,因为那样到
m_data = out_p1 + out_p2 + out_p3 #如果三个out是不同的graph,这里会报错
这一步会报错。说不同的graph出来的结果是不能相互运算的。大家必须是在同一个graph里才行。所以要建一个session,在这个session下挨个恢复:
with tf.session as sess: # 下面每个restore里不要单建 with tf.graph():...?
? ? # restore p1
? ? # restore p2
? ? # retore p3
? ? # ....
等于是把大家依次放进default graph里。再填上最后的m就ok了。
references:
https://blog.metaflow.fr/tensorflow-saving-restoring-and-mixing-multiple-models-c4c94d5d7125
https://zhuanlan.zhihu.com/p/31308381
https://www.tensorflow.org/api_guides/python/meta_graph#What_s_in_a_MetaGraph
http://08643.cn/p/0f9f2bb962f4
stackoverflow:
https://stackoverflow.com/questions/41990014/load-multiple-models-in-tensorflow
https://stackoverflow.com/questions/45093688/how-to-understand-sess-as-default-and-sess-graph-as-default
https://stackoverflow.com/questions/49864234/tensorflow-restoring-variables-from-two-checkpoints-after-combining-two-graphs
https://stackoverflow.com/questions/49490262/combining-graphs-is-there-a-tensorflow-import-graph-def-equivalent-for-c
https://stackoverflow.com/questions/41607144/loading-two-models-from-saver-in-the-same-tensorflow-session