“冻结”一些variables/范围在tensorflow:stop_gradient与传递variables最小化
我试图实施敌对neural network ,它需要在交替训练minibatches期间“冻结”图中的一个或另一个部分。 即有两个子networking:G和D.
G( Z ) -> Xz D( X ) -> Y
G
损失函数依赖于D[G(Z)], D[X]
。
首先,我需要在所有G参数固定的情况下对D中的参数进行训练,然后使用D中的参数固定G中的参数。 第一种情况下的损失函数将是第二种情况下的负损失函数,并且更新将必须应用于第一或第二子networking的参数。
我看到tensorflow有tf.stop_gradient
函数。 为了训练D(下游)子networking,我可以使用这个function来阻止梯度stream向
Z -> [ G ] -> tf.stop_gradient(Xz) -> [ D ] -> Y
tf.stop_gradient
非常简洁,没有内联示例(例如seq2seq.py
太长而且不容易阅读),但看起来像在图创build期间必须调用它。 这是否意味着,如果我想要交替批量阻止/取消阻止梯度stream,我需要重新创build并重新初始化图模型?
此外,似乎tf.stop_gradient
阻止stream经G(上游)networking的tf.stop_gradient
,对吧?
作为替代scheme,我看到可以将优化器调用的variables列表作为opt_op = opt.minimize(cost, <list of variables>)
如果可以获得每个variables的范围中的所有variables子网。 一个人可以得到一个<list of variables>
为一个tf.scope?
在你的问题中提到的最简单的方法就是创build两个优化器操作,分别调用opt.minimize(cost, ...)
。 默认情况下,优化器将使用tf.trainable_variables()
所有variables。 如果要将variables过滤到特定范围,可以使用可选的scope
参数tf.get_collection()
,如下所示:
optimizer = tf.train.AdagradOptimzer(0.01) first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "scope/prefix/for/first/vars") first_train_op = optimizer.minimize(cost, var_list=first_train_vars) second_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "scope/prefix/for/second/vars") second_train_op = optimizer.minimize(cost, var_list=second_train_vars)
你可能想要考虑的另一个select是你可以在一个variables上设置trainable = False。 这意味着它不会被训练修改。
tf.Variable(my_weights, trainable=False)
@ mrry的回答是完全正确的,也许比我想提出的更一般化。 但我认为一个简单的方法来完成它只是将python引用直接传递给var_list
:
W = tf.Variable(...) C = tf.Variable(...) Y_est = tf.matmul(W,C) loss = tf.reduce_sum((data-Y_est)**2) optimizer = tf.train.AdamOptimizer(0.001) # You can pass the python object directly train_W = optimizer.minimize(loss, var_list=[W]) train_C = optimizer.minimize(loss, var_list=[C])
我在这里有一个独立的例子: https : //gist.github.com/ahwillia/8cedc710352eb919b684d8848bc2df3a
我不知道我的方法是否有缺陷,但是我用这个结构解决了这个问题:
do_gradient = <Tensor that evaluates to 0 or 1> no_gradient = 1 - do_gradient wrapped_op = do_gradient * original + no_gradient * tf.stop_gradient(original)
所以,如果do_gradient = 1
,那么值和渐变将会stream过,但是如果do_gradient = 0
,那么值将只stream过stop_gradient op,这将停止梯度回stream。
对于我的场景,将do_gradient挂钩到random_shuffle张量的索引让我随机训练不同的networking片段。