在TensorFlow中使用预先训练的单词embedded(word2vec或Glove)

我最近回顾了卷积文本分类的一个有趣的实现。 然而,我所检查的所有TensorFlow代码都使用了如下的随机(未经过预先训练的)embedded向量:

with tf.device('/cpu:0'), tf.name_scope("embedding"): W = tf.Variable( tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0), name="W") self.embedded_chars = tf.nn.embedding_lookup(W, self.input_x) self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1) 

有人知道如何使用Word2vec的结果或GloVe预先训练的词embedded,而不是随机的吗?

有几种方法可以在TensorFlow中使用预先训练的embedded。 假设您已经embedded了名为embedding的NumPy数组,并且使用了vocab_size rows和embedding_dim列,并且您想要创build一个可用于调用tf.nn.embedding_lookup()的张量W

  1. 简单地创buildW作为tf.constant() ,它将embedding作为其值:

     W = tf.constant(embedding, name="W") 

    这是最简单的方法,但它不是有效的内存,因为tf.constant()的值在内存中存储多次。 由于embedding可以是非常大的,你应该只使用这种方法玩具的例子。

  2. 创buildW作为tf.Variable并通过tf.placeholder()从NumPy数组中初始化它:

     W = tf.Variable(tf.constant(0.0, shape=[vocab_size, embedding_dim]), trainable=False, name="W") embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, embedding_dim]) embedding_init = W.assign(embedding_placeholder) # ... sess = tf.Session() sess.run(embedding_init, feed_dict={embedding_placeholder: embedding}) 

    这避免了在graphics中存储embedding的副本,但是它确实需要足够的存储器来一次在存储器中保存两个matrix的副本(一个用于NumPy数组,一个用于tf.Variable )。 请注意,我假设你想在训练过程中保持embeddedmatrix不变,所以W被创build为trainable=False

  3. 如果将embedded作为另一个TensorFlow模型的一部分进行训练,则可以使用tf.train.Saver从另一个模型的检查点文件加载值。 这意味着embeddedmatrix可以完全绕过Python。 按照选项2创buildW ,然后执行以下操作:

     W = tf.Variable(...) embedding_saver = tf.train.Saver({"name_of_variable_in_other_model": W}) # ... sess = tf.Session() embedding_saver.restore(sess, "checkpoint_filename.ckpt") 

我使用这种方法来加载和共享embedded。

 W = tf.get_variable(name="W", shape=embedding.shape, initializer=tf.constant_initializer(embedding), trainable=False) 

@mrry的答案是不正确的,因为它挑战了每一个networking运行的embedded权重的覆盖,所以如果你遵循一个minibatch的方法来训练你的networking,你将覆盖embedded的权重。 所以,从我的观点来看,预先训练embedded的正确方法是:

 embeddings = tf.get_variable("embeddings", shape=[dim1, dim2], initializer=tf.constant_initializer(np.array(embeddings_matrix))