TensorFlow,为什么保存模型后有3个文件?
读过文档后 ,我在TensorFlow
保存了一个模型,这里是我的演示代码:
# Create some variables. v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") ... # Add an op to initialize the variables. init_op = tf.global_variables_initializer() # Add ops to save and restore all the variables. saver = tf.train.Saver() # Later, launch the model, initialize the variables, do some work, save the # variables to disk. with tf.Session() as sess: sess.run(init_op) # Do some work with the model. .. # Save the variables to disk. save_path = saver.save(sess, "/tmp/model.ckpt") print("Model saved in file: %s" % save_path)
但之后,我发现有3个文件
model.ckpt.data-00000-of-00001 model.ckpt.index model.ckpt.meta
而且我不能通过恢复model.ckpt
文件来恢复模型,因为没有这样的文件。 这是我的代码
with tf.Session() as sess: # Restore variables from disk. saver.restore(sess, "/tmp/model.ckpt")
那么,为什么有3个文件?
尝试这个:
with tf.Session() as sess: saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta') saver.restore(sess, "/tmp/model.ckpt")
TensorFlow保存方法保存三种文件,因为它将图结构与variables值分开存储。 .meta
文件描述了保存的graphics结构,因此在恢复检查点之前需要导入它(否则它不知道保存的检查点值对应的variables)。
或者,你可以这样做:
# Recreate the EXACT SAME variables v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") ... # Now load the checkpoint variable values with tf.Session() as sess: saver = tf.train.Saver() saver.restore(sess, "/tmp/model.ckpt")
即使没有名为model.ckpt
文件,在恢复时仍然会通过该名称引用保存的检查点。 来自saver.py
源代码 :“用户只需要与用户指定的前缀进行交互…而不是任何物理path名称”。
-
元文件 :描述保存的graphics结构,包括GraphDef,SaverDef等等; 然后应用
tf.train.import_meta_graph('/tmp/model.ckpt.meta')
,将恢复Saver
和Graph
。 -
索引文件 :它是一个string不可变表(tensorflow :: table :: Table)。 每个键是一个张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:哪个“数据”文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等
-
数据文件 :是TensorBundle集合,保存所有variables的值。
我正在从Word2Vec tensorflow教程恢复受过训练的单词embedded。
如果您创build了多个检查点:
例如创build的文件看起来像这样
model.ckpt-55695.data 00000-的-00001
model.ckpt-55695.index
model.ckpt-55695.meta
尝试这个
def restore_session(self, session): saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta') saver.restore(session, './tmp/model.ckpt-55695')
当调用restore_session()时:
def test_word2vec(): opts = Options() with tf.Graph().as_default(), tf.Session() as session: with tf.device("/cpu:0"): model = Word2Vec(opts, session) model.restore_session(session) model.get_embedding("assistance")