Example #1
0
    def _check_grad_angle_combined(self, grads, grads_true):
        """Check that the reconstructed gradients point the same way as the true ones.

        Numerical imprecision can make the reconstructed magnitudes drift
        slightly. According to the paper, however, the direction should stay
        roughly the same. So rather than comparing values, each list is turned
        into one flat vector, and the angle between the two vectors (as
        reported by ``blocks_test.compute_degree``) must be at most 1.

        Args:
          grads: list of gradients from reconstruction
          grads_true: list of true gradients
        """

        def _flatten_and_join(tensors):
            # Flatten every tensor to 1-D, then concatenate into one vector.
            flat_parts = [tf.reshape(t, [-1]) for t in tensors]
            return tf.concat(flat_parts, axis=0)

        recon_vec = _flatten_and_join(grads)
        true_vec = _flatten_and_join(grads_true)

        # Both combined gradients must be rank-1 before the angle is taken.
        for vec in (recon_vec, true_vec):
            self.assertEqual(len(vec.shape), 1)

        # Tolerance of 1; presumably degrees, going by the helper's name.
        # NOTE(review): confirm the unit returned by compute_degree.
        angle = blocks_test.compute_degree(recon_vec, true_vec)
        self.assertLessEqual(angle, 1e0)
  def _check_grad_angle_combined(self, grads, grads_true):
    """Check that the reconstructed gradients point the same way as the true ones.

    Numerical imprecision can make the reconstructed magnitudes drift
    slightly. According to the paper, however, the direction should stay
    roughly the same. So rather than comparing values, each list is turned
    into one flat vector, and the angle between the two vectors (as reported
    by ``blocks_test.compute_degree``) must be at most 1.

    Args:
      grads: list of gradients from reconstruction
      grads_true: list of true gradients
    """

    def _flatten_and_join(tensors):
      # Flatten every tensor to 1-D, then concatenate into one vector.
      flat_parts = [tf.reshape(t, [-1]) for t in tensors]
      return tf.concat(flat_parts, axis=0)

    recon_vec = _flatten_and_join(grads)
    true_vec = _flatten_and_join(grads_true)

    # Both combined gradients must be rank-1 before the angle is taken.
    for vec in (recon_vec, true_vec):
      self.assertEqual(len(vec.shape), 1)

    # Tolerance of 1; presumably degrees, going by the helper's name.
    # NOTE(review): confirm the unit returned by compute_degree.
    angle = blocks_test.compute_degree(recon_vec, true_vec)
    self.assertLessEqual(angle, 1e0)