Exemple #1
0
  def testIndexedSlicesGradientInCondInWhileLoop(self):
    """Checks While+cond embedding gradients against an unrolled graph.

    The loop runs for it = 0..4. On iteration 3 the cost is squared; on
    every other iteration the sum of the looked-up embedding row is added
    to it. The "static" expression spells out the same arithmetic without
    any control flow. The gradients of both costs with respect to
    `embedding_matrix` are summed per index and must be exactly equal.
    """
    with ops.Graph().as_default():
      embedding_matrix = tf.get_variable(
          "embedding_matrix", [5, 5],
          initializer=tf.random_normal_initializer())

      def keep_looping(step, _):
        return step < 5

      def loop_body(step, cost):
        # The lookup is built outside the cond but only consumed by the
        # "add" branch below.
        row = embedding_ops.embedding_lookup(embedding_matrix, [0])

        def square_cost():
          return tf.square(cost)

        def add_row_sum():
          return cost + tf.reduce_sum(row)

        next_cost = tf.cond(tf.equal(step, 3), square_cost, add_row_sum)
        return step + 1, next_cost

      initial_state = [tf.constant(0), tf.constant(0.0)]
      _, loop_cost = control_flow_ops.While(
          keep_looping, loop_body, initial_state)

      # Gradient through the loop; `.values` / `.indices` are reduced with
      # segment_sum so it can be compared with the static gradient.
      loop_grad = tf.gradients(loop_cost, [embedding_matrix])[0]
      dynamic_grads = tf.segment_sum(loop_grad.values, loop_grad.indices)

      # Hand-unrolled equivalent: three additions, one square, one addition.
      # Each reduce_sum is a separate op on purpose, mirroring the loop.
      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      sum_before_square = (
          tf.reduce_sum(embedding) +
          tf.reduce_sum(embedding) +
          tf.reduce_sum(embedding))
      unrolled_cost = tf.square(sum_before_square) + tf.reduce_sum(embedding)
      unrolled_grad = tf.gradients(unrolled_cost, [embedding_matrix])[0]
      static_grads = tf.segment_sum(
          unrolled_grad.values, unrolled_grad.indices)

      with self.test_session() as sess:
        sess.run(tf.initialize_all_variables())
        static_value, dynamic_value = sess.run([static_grads, dynamic_grads])
        self.assertAllEqual(static_value, dynamic_value)
Exemple #2
0
 def testIndexedSlicesGradient(self):
   """Runs a momentum train op on a cost accumulated inside a While loop.

   The loop body looks up row 0 of `embedding_matrix + 0.0` and adds the
   sum of that row to the running cost, for it = 0..4. The test only
   checks that building the train op and running it ten times succeeds;
   no numeric result is asserted.
   """
   with ops.Graph().as_default():
     embedding_matrix = tf.get_variable(
         "embedding_matrix", [5, 5],
         initializer=tf.random_normal_initializer())

     def keep_looping(step, _):
       return step < 5

     def loop_body(step, cost):
       # NOTE(review): the `+ 0.0` presumably makes the lookup read a tensor
       # computed inside the loop instead of the variable itself -- confirm.
       row = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
       updated_cost = cost + tf.reduce_sum(row)
       return step + 1, updated_cost

     initial_state = [tf.constant(0), tf.constant(0.0)]
     _, total_cost = control_flow_ops.While(
         keep_looping, loop_body, initial_state)

     train_op = momentum.MomentumOptimizer(0.1, 0.9).minimize(total_cost)

     with self.test_session() as sess:
       sess.run(tf.initialize_all_variables())
       num_steps = 10
       for _ in range(num_steps):
         sess.run([train_op])