Example #1
0
  def check(self, op, args, error, correct=None):
    """Build and run `op` under a series of GraphDef producer versions.

    Under each strict version, running the op must raise an op error checked
    by `assertRaisesOpError(error)`. Under each lenient version the op must
    run and, when `correct` is given, its result must equal `correct`.

    Args:
      op: Callable that builds the op under test from the placeholder inputs.
      args: Inputs to feed. A tuple (possibly nested) is converted element by
        element; any other value becomes a single placeholder.
      error: Passed to `assertRaisesOpError` for strict versions.
      correct: Optional expected result for lenient versions.
    """
    # Within Google, the switch to scalar strict occurred at version 6.
    strict_versions = [5, 6]
    lenient_versions = []

    # Use placeholders to bypass shape inference, since only the C++
    # GraphDef level is ever scalar lenient.
    def to_placeholders(value, feed_dict):
      if not isinstance(value, tuple):
        concrete = ops.convert_to_tensor(value).eval()
        placeholder = array_ops.placeholder(np.asarray(concrete).dtype)
        # Remember the concrete value so it is fed at run time.
        feed_dict[placeholder] = concrete
        return placeholder
      return [to_placeholders(item, feed_dict) for item in value]

    # Test various GraphDef versions
    for version in strict_versions + lenient_versions:
      with ops.Graph().as_default() as graph:
        test_util.set_producer_version(graph, version)
        with self.test_session(graph=graph) as session:
          feed_dict = {}
          inputs = to_placeholders(args, feed_dict)
          output = op(*inputs)
          if version not in strict_versions:
            result = session.run(output, feed_dict=feed_dict)
            if correct is not None:
              self.assertAllEqual(result, correct)
          else:
            with self.assertRaisesOpError(error):
              session.run(output, feed_dict=feed_dict)
Example #2
0
    def check(self, op, args, error, correct=None):
        """Exercise `op` on `args` for strict and lenient GraphDef versions.

        Strict versions must make the op fail with an op error checked by
        `assertRaisesOpError(error)`. Lenient versions must run the op
        successfully and, if `correct` is not None, return exactly `correct`.

        Args:
            op: Callable that builds the op under test from placeholders.
            args: Values to feed. Tuples (possibly nested) are converted
                element-wise; any other value becomes one placeholder.
            error: Passed to `assertRaisesOpError` for strict versions.
            correct: Optional expected output for lenient versions.
        """
        # Within Google, the switch to scalar strict occurred at version 6.
        strict_set = [5, 6]
        lenient_set = []
        all_versions = strict_set + lenient_set

        # Use placeholders to bypass shape inference, since only the C++
        # GraphDef level is ever scalar lenient.
        def make_placeholders(arg_tree, feeds):
            if isinstance(arg_tree, tuple):
                return [make_placeholders(leaf, feeds) for leaf in arg_tree]
            evaluated = ops.convert_to_tensor(arg_tree).eval()
            holder = array_ops.placeholder(np.asarray(evaluated).dtype)
            feeds[holder] = evaluated
            return holder

        # Test various GraphDef versions
        for graph_version in all_versions:
            with ops.Graph().as_default() as graph:
                test_util.set_producer_version(graph, graph_version)
                with self.test_session(graph=graph) as session:
                    feeds = {}
                    op_inputs = make_placeholders(args, feeds)
                    op_output = op(*op_inputs)
                    is_strict = graph_version in strict_set
                    if is_strict:
                        with self.assertRaisesOpError(error):
                            session.run(op_output, feed_dict=feeds)
                    else:
                        outcome = session.run(op_output, feed_dict=feeds)
                        if correct is not None:
                            self.assertAllEqual(outcome, correct)
Example #3
0
def batch_norm_op(tensor, mean, variance, beta, gamma, scale):
    """Build the legacy fused batch-normalization kernel.

    `_batch_norm_with_global_normalization` is deprecated from GraphDef
    version 9. This function therefore first sets the default graph's
    producer version to 8 and then constructs the op with a fixed variance
    epsilon of 0.001.

    Args:
        tensor, mean, variance, beta, gamma: Forwarded positionally to the
            kernel.
        scale: Forwarded as the kernel's scale-after-normalization flag.

    Returns:
        The output of `gen_nn_ops._batch_norm_with_global_normalization`.
    """
    default_graph = ops.get_default_graph()
    test_util.set_producer_version(default_graph, 8)
    variance_epsilon = 0.001
    # pylint: disable=protected-access
    fused_kernel = gen_nn_ops._batch_norm_with_global_normalization
    # pylint: enable=protected-access
    return fused_kernel(tensor, mean, variance, beta, gamma,
                        variance_epsilon, scale)
def batch_norm_op(tensor, mean, variance, beta, gamma, scale):
  """Return the output of the deprecated fused batch-norm kernel.

  The kernel is deprecated in GraphDef v9, so the default graph's producer
  version is set to 8 before the op is built. The variance epsilon is fixed
  at 0.001, and `scale` is forwarded as the scale-after-normalization flag.
  """
  test_util.set_producer_version(ops.get_default_graph(), 8)
  # pylint: disable=protected-access
  output = gen_nn_ops._batch_norm_with_global_normalization(
      tensor,
      mean,
      variance,
      beta,
      gamma,
      0.001,
      scale,
  )
  # pylint: enable=protected-access
  return output
Example #5
0
    def testFoldBatchNorms(self):
        """Checks that fold_batch_norms removes the fused batch-norm op.

        Builds a conv2d followed by `_batch_norm_with_global_normalization`
        and runs `optimize_for_inference_lib.fold_batch_norms` on the
        resulting GraphDef. It then verifies that the optimized graph
        produces the same output and contains no
        `BatchNormWithGlobalNormalization` node.
        """
        # Use cached_session() rather than the deprecated test_session().
        with self.cached_session() as sess:
            inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
            input_op = constant_op.constant(np.array(inputs),
                                            shape=[1, 1, 6, 2],
                                            dtype=dtypes.float32)
            weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
            weights_op = constant_op.constant(np.array(weights),
                                              shape=[1, 2, 2, 2],
                                              dtype=dtypes.float32)
            conv_op = nn_ops.conv2d(input_op,
                                    weights_op, [1, 1, 1, 1],
                                    padding="SAME",
                                    name="conv_op")
            mean_op = constant_op.constant(np.array([10, 20]),
                                           shape=[2],
                                           dtype=dtypes.float32)
            variance_op = constant_op.constant(np.array([0.25, 0.5]),
                                               shape=[2],
                                               dtype=dtypes.float32)
            beta_op = constant_op.constant(np.array([0.1, 0.6]),
                                           shape=[2],
                                           dtype=dtypes.float32)
            gamma_op = constant_op.constant(np.array([1.0, 2.0]),
                                            shape=[2],
                                            dtype=dtypes.float32)
            # _batch_norm_with_global_normalization is deprecated in v9, so
            # the graph's producer version is set to 8 before building it.
            test_util.set_producer_version(ops.get_default_graph(), 8)
            # pylint: disable=protected-access
            gen_nn_ops._batch_norm_with_global_normalization(conv_op,
                                                             mean_op,
                                                             variance_op,
                                                             beta_op,
                                                             gamma_op,
                                                             0.00001,
                                                             False,
                                                             name="output")
            # pylint: enable=protected-access
            original_graph_def = sess.graph_def
            original_result = sess.run(["output:0"])
        optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
            original_graph_def)

        with self.cached_session() as sess:
            importer.import_graph_def(optimized_graph_def,
                                      input_map={},
                                      name="optimized")
            optimized_result = sess.run(["optimized/output:0"])

        self.assertAllClose(original_result, optimized_result)

        # The folded graph must not contain the fused op any more.
        for node in optimized_graph_def.node:
            self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
    def testBatchNormGradImpl(self):
        """Compares the fused batch-norm gradient op with graph gradients.

        For each `use_gpu` setting and each `scale_after_normalization`
        value, this test compares two sets of gradients with
        `assertAllClose(atol=1e-6)`:

        * the outputs of `gen_nn_ops.batch_norm_with_global_normalization_grad`;
        * the gradients that `gradients_impl.gradients` computes through
          `self._opsBatchNorm`.

        The `dg` gradient is only compared when `scale_after_normalization`
        is True.
        """
        x_shape = [7, 5, 4, 6]
        param_shape = [6]
        np.random.seed(1)  # Make it reproducible.
        x_val = np.random.random_sample(x_shape).astype(np.float32)
        m_val = np.random.random_sample(param_shape).astype(np.float32)
        v_val = np.random.random_sample(param_shape).astype(np.float32)
        beta_val = np.random.random_sample(param_shape).astype(np.float32)
        gamma_val = np.random.random_sample(param_shape).astype(np.float32)
        backprop_val = np.random.random_sample(x_shape).astype(np.float32)
        for use_gpu in [False, True]:
            # Values are fetched with self.evaluate, so the session object is
            # not bound.
            with self.cached_session(use_gpu=use_gpu):
                x = constant_op.constant(x_val, name="x")
                m = constant_op.constant(m_val, name="m")
                v = constant_op.constant(v_val, name="v")
                beta = constant_op.constant(beta_val, name="beta")
                gamma = constant_op.constant(gamma_val, name="gamma")
                backprop = constant_op.constant(backprop_val, name="backprop")
                epsilon = 0.001
                for scale_after_normalization in [True, False]:
                    # _batch_norm_with_global_normalization_grad is deprecated in v9
                    test_util.set_producer_version(ops.get_default_graph(), 8)
                    grad = gen_nn_ops.batch_norm_with_global_normalization_grad(
                        x, m, v, gamma, backprop, epsilon,
                        scale_after_normalization)
                    dx, dm, dv, db, dg = grad
                    # Named outputs must match positional unpacking.
                    self.assertEqual(grad.dx, dx)
                    self.assertEqual(grad.dm, dm)
                    self.assertEqual(grad.dv, dv)
                    self.assertEqual(grad.db, db)
                    self.assertEqual(grad.dg, dg)

                    on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                            scale_after_normalization, True)
                    odx, odm, odv, odb, odg = gradients_impl.gradients(
                        [on], [x, m, v, beta, gamma], [backprop])
                    if scale_after_normalization:
                        all_grads = self.evaluate(
                            [dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
                        to_check = ["dx", "dm", "dv", "db", "dg"]
                    else:
                        all_grads = self.evaluate(
                            [dx, dm, dv, db, odx, odm, odv, odb])
                        to_check = ["dx", "dm", "dv", "db"]
                    # all_grads holds the grad-op results first, followed by
                    # the gradients_impl results in the same order.
                    num_checked = len(to_check)
                    for i, name in enumerate(to_check):
                        self.assertAllClose(
                            all_grads[i + num_checked],
                            all_grads[i],
                            atol=0.000001,
                            msg="%s (use_gpu=%s, scale_after_normalization=%s)"
                            % (name, use_gpu, scale_after_normalization))
  def testBatchNormGradImpl(self):
    """Compares the fused batch-norm gradient op with graph gradients.

    For each `use_gpu` setting and each `scale_after_normalization` value,
    this test compares two sets of gradients with `assertAllClose(atol=1e-6)`:

    * the outputs of `gen_nn_ops.batch_norm_with_global_normalization_grad`;
    * the gradients that `gradients_impl.gradients` computes through
      `self._opsBatchNorm`.

    The `dg` gradient is only compared when `scale_after_normalization` is
    True.
    """
    x_shape = [7, 5, 4, 6]
    param_shape = [6]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    backprop_val = np.random.random_sample(x_shape).astype(np.float32)
    for use_gpu in [False, True]:
      # Values are fetched with self.evaluate, so the session object is not
      # bound.
      with self.cached_session(use_gpu=use_gpu):
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        backprop = constant_op.constant(backprop_val, name="backprop")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          # _batch_norm_with_global_normalization_grad is deprecated in v9
          test_util.set_producer_version(ops.get_default_graph(), 8)
          grad = gen_nn_ops.batch_norm_with_global_normalization_grad(
              x, m, v, gamma, backprop, epsilon, scale_after_normalization)
          dx, dm, dv, db, dg = grad
          # Named outputs must match positional unpacking.
          self.assertEqual(grad.dx, dx)
          self.assertEqual(grad.dm, dm)
          self.assertEqual(grad.dv, dv)
          self.assertEqual(grad.db, db)
          self.assertEqual(grad.dg, dg)

          on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                  scale_after_normalization, True)
          odx, odm, odv, odb, odg = gradients_impl.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
          if scale_after_normalization:
            all_grads = self.evaluate(
                [dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
            to_check = ["dx", "dm", "dv", "db", "dg"]
          else:
            all_grads = self.evaluate([dx, dm, dv, db, odx, odm, odv, odb])
            to_check = ["dx", "dm", "dv", "db"]
          # all_grads holds the grad-op results first, followed by the
          # gradients_impl results in the same order.
          num_checked = len(to_check)
          for i, name in enumerate(to_check):
            self.assertAllClose(
                all_grads[i + num_checked],
                all_grads[i],
                atol=0.000001,
                msg="%s (use_gpu=%s, scale_after_normalization=%s)" %
                (name, use_gpu, scale_after_normalization))
  def testFoldBatchNorms(self):
    """Folding batch norms must keep the output and drop the fused op.

    A conv2d feeding `_batch_norm_with_global_normalization` is run, and the
    GraphDef is passed through `optimize_for_inference_lib.fold_batch_norms`.
    The re-imported optimized graph must give a close result and contain no
    `BatchNormWithGlobalNormalization` node.
    """

    def float_const(values, shape):
      # Every constant in this test is a float32 tensor built from a list.
      return constant_op.constant(
          np.array(values), shape=shape, dtype=dtypes.float32)

    with self.cached_session() as sess:
      input_op = float_const([1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
                             [1, 1, 6, 2])
      weights_op = float_const([1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4], [1, 2, 2, 2])
      conv_op = nn_ops.conv2d(
          input_op, weights_op, [1, 1, 1, 1], padding="SAME", name="conv_op")
      mean_op = float_const([10, 20], [2])
      variance_op = float_const([0.25, 0.5], [2])
      beta_op = float_const([0.1, 0.6], [2])
      gamma_op = float_const([1.0, 2.0], [2])
      # The fused op is deprecated in GraphDef v9, so the graph's producer
      # version is set to 8 before building it.
      test_util.set_producer_version(ops.get_default_graph(), 8)
      # pylint: disable=protected-access
      gen_nn_ops._batch_norm_with_global_normalization(
          conv_op, mean_op, variance_op, beta_op, gamma_op, 0.00001, False,
          name="output")
      # pylint: enable=protected-access
      original_graph_def = sess.graph_def
      original_result = sess.run(["output:0"])

    optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
        original_graph_def)

    with self.cached_session() as sess:
      importer.import_graph_def(
          optimized_graph_def, input_map={}, name="optimized")
      optimized_result = sess.run(["optimized/output:0"])

    self.assertAllClose(original_result, optimized_result)

    for graph_node in optimized_graph_def.node:
      self.assertNotEqual("BatchNormWithGlobalNormalization", graph_node.op)
 def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization):
     """Original implementation.

     Builds the legacy fused `_batch_norm_with_global_normalization` op after
     setting the default graph's producer version to 8. The op is deprecated
     from GraphDef v9. All arguments are forwarded positionally.
     """
     default_graph = ops.get_default_graph()
     test_util.set_producer_version(default_graph, 8)
     # pylint: disable=protected-access
     legacy_op = gen_nn_ops._batch_norm_with_global_normalization
     # pylint: enable=protected-access
     return legacy_op(x, m, v, beta, gamma, epsilon, scale_after_normalization)
 def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization):
   """Original implementation.

   Sets the default graph's producer version to 8 (the fused op is
   deprecated from GraphDef v9), then returns the output of
   `gen_nn_ops._batch_norm_with_global_normalization` applied to the
   arguments in order.
   """
   test_util.set_producer_version(ops.get_default_graph(), 8)
   fused_args = (x, m, v, beta, gamma, epsilon, scale_after_normalization)
   # pylint: disable=protected-access
   return gen_nn_ops._batch_norm_with_global_normalization(*fused_args)