コード例 #1
0
  def testFoldBatchNorms(self):
    """Folding BatchNormWithGlobalNormalization must not change the output.

    Builds a conv2d followed by BatchNormWithGlobalNormalization, runs
    optimize_for_inference_lib.fold_batch_norms over the resulting GraphDef,
    then checks that the folded graph yields the same "output" values and no
    longer contains any BatchNormWithGlobalNormalization node.
    """

    def float32_constant(values, shape):
      # Every graph input below is a float32 constant built from a list.
      return tf.constant(np.array(values), shape=shape, dtype=tf.float32)

    with self.test_session() as sess:
      image = float32_constant(
          [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6], [1, 1, 6, 2])
      kernel = float32_constant(
          [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4], [1, 2, 2, 2])
      conv = tf.nn.conv2d(
          image, kernel, [1, 1, 1, 1], padding="SAME", name="conv_op")
      moving_mean = float32_constant([10, 20], [2])
      moving_variance = float32_constant([0.25, 0.5], [2])
      offset = float32_constant([0.1, 0.6], [2])
      scale = float32_constant([1.0, 2.0], [2])
      # NOTE(review): producer is pinned to 8, presumably because this legacy
      # op is not accepted at newer GraphDef versions -- confirm.
      tf.get_default_graph().graph_def_versions.producer = 8
      gen_nn_ops._batch_norm_with_global_normalization(
          conv, moving_mean, moving_variance, offset, scale, 0.00001, False,
          name="output")
      before_def = sess.graph_def
      before = sess.run(["output:0"])

    after_def = optimize_for_inference_lib.fold_batch_norms(before_def)

    with self.test_session() as sess:
      tf.import_graph_def(after_def, input_map={}, name="optimized")
      after = sess.run(["optimized/output:0"])

    self.assertAllClose(before, after)

    for node in after_def.node:
      self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
コード例 #2
0
  def testFoldFusedBatchNorms(self):
    """Folding FusedBatchNorm into the preceding convolution keeps the output.

    Covers NHWC and NCHW layouts for both nn_ops.conv2d and
    nn_ops.depthwise_conv2d_native. For each configuration it builds
    conv -> FusedBatchNorm (is_training=False), runs
    optimize_for_inference_lib.fold_batch_norms on the GraphDef, asserts that
    no FusedBatchNorm node survives, and compares the outputs of the original
    and folded graphs.
    """
    # pylint: disable=g-import-not-at-top
    from tensorflow.python.platform import test as platform_test
    # pylint: enable=g-import-not-at-top

    # The NCHW cases are paired with use_gpu=True, presumably because NCHW
    # convolution kernels may be unavailable on CPU. Without a GPU those
    # graphs are still built and folded, but not executed.
    gpu_available = platform_test.is_gpu_available()

    for data_format, use_gpu, conv2d_func in [
        ("NHWC", False, nn_ops.conv2d), ("NCHW", True, nn_ops.conv2d),
        ("NHWC", False, nn_ops.depthwise_conv2d_native),
        ("NCHW", True, nn_ops.depthwise_conv2d_native)
    ]:
      can_execute = gpu_available or not use_gpu
      with self.cached_session(use_gpu=use_gpu) as sess:
        inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
        input_op = constant_op.constant(
            np.array(inputs),
            shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
            dtype=dtypes.float32)
        if conv2d_func is nn_ops.conv2d:
          weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
          weights_op = constant_op.constant(
              np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
        else:
          weights = [1, 2, 0.3, 0.4]
          weights_op = constant_op.constant(
              np.array(weights), shape=[1, 2, 2, 1], dtype=dtypes.float32)
        conv_op = conv2d_func(
            input_op,
            weights_op, [1, 1, 1, 1],
            padding="SAME",
            data_format=data_format,
            name="conv_op")
        mean_op = constant_op.constant(
            np.array([10, 20]), shape=[2], dtype=dtypes.float32)
        variance_op = constant_op.constant(
            np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
        beta_op = constant_op.constant(
            np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
        gamma_op = constant_op.constant(
            np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
        ops.get_default_graph().graph_def_versions.producer = 9
        bn_outputs = gen_nn_ops._fused_batch_norm(
            conv_op,
            gamma_op,
            beta_op,
            mean_op,
            variance_op,
            0.00001,
            is_training=False,
            data_format=data_format,
            name="output")
        # cached_session() reuses one session (and graph) for the whole test,
        # so after the first iteration the node requested as "output" is
        # uniquified ("output_1", ...). Hard-coding "output:0" would re-check
        # the first configuration every time; use the name actually assigned.
        output_name = bn_outputs[0].name
        original_graph_def = sess.graph_def
        if can_execute:
          original_result = sess.run([output_name])
      optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
          original_graph_def)

      for node in optimized_graph_def.node:
        self.assertNotEqual("FusedBatchNorm", node.op)

      if not can_execute:
        continue

      # The "optimized" import scope is uniquified the same way on later
      # iterations, so fetch the folded output tensor via return_elements.
      [optimized_output] = importer.import_graph_def(
          optimized_graph_def,
          input_map={},
          return_elements=[output_name],
          name="optimized")
      optimized_result = sess.run([optimized_output])

      self.assertAllClose(
          original_result, optimized_result, rtol=1e-04, atol=1e-06)
コード例 #3
0
  def testFoldFusedBatchNorms(self):
    """Folding FusedBatchNorm into the preceding conv2d keeps the output.

    Covers NHWC and NCHW layouts. For each layout it builds
    conv2d -> FusedBatchNorm (is_training=False), runs
    optimize_for_inference_lib.fold_batch_norms on the GraphDef, asserts that
    no FusedBatchNorm node survives, and compares the outputs of the original
    and folded graphs.
    """
    # pylint: disable=g-import-not-at-top
    from tensorflow.python.platform import test as platform_test
    # pylint: enable=g-import-not-at-top

    # NCHW is paired with use_gpu=True, presumably because NCHW convolution
    # kernels may be unavailable on CPU. Without a GPU that graph is still
    # built and folded, but not executed.
    gpu_available = platform_test.is_gpu_available()

    for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]:
      can_execute = gpu_available or not use_gpu
      with self.test_session(use_gpu=use_gpu) as sess:
        inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
        input_op = constant_op.constant(
            np.array(inputs),
            shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
            dtype=dtypes.float32)
        weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
        weights_op = constant_op.constant(
            np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
        conv_op = nn_ops.conv2d(
            input_op,
            weights_op, [1, 1, 1, 1],
            padding="SAME",
            data_format=data_format,
            name="conv_op")
        mean_op = constant_op.constant(
            np.array([10, 20]), shape=[2], dtype=dtypes.float32)
        variance_op = constant_op.constant(
            np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
        beta_op = constant_op.constant(
            np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
        gamma_op = constant_op.constant(
            np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
        ops.get_default_graph().graph_def_versions.producer = 9
        bn_outputs = gen_nn_ops._fused_batch_norm(
            conv_op,
            gamma_op,
            beta_op,
            mean_op,
            variance_op,
            0.00001,
            is_training=False,
            data_format=data_format,
            name="output")
        # test_session() without a graph argument reuses one cached session
        # (and graph) for the whole test, so after the first iteration the
        # node requested as "output" is uniquified ("output_1"). Hard-coding
        # "output:0" would re-check the NHWC case; use the assigned name.
        output_name = bn_outputs[0].name
        original_graph_def = sess.graph_def
        if can_execute:
          original_result = sess.run([output_name])
      optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
          original_graph_def)

      for node in optimized_graph_def.node:
        self.assertNotEqual("FusedBatchNorm", node.op)

      if not can_execute:
        continue

      with self.test_session(use_gpu=use_gpu) as sess:
        # The "optimized" import scope is uniquified the same way on later
        # iterations, so fetch the folded output tensor via return_elements.
        [optimized_output] = importer.import_graph_def(
            optimized_graph_def,
            input_map={},
            return_elements=[output_name],
            name="optimized")
        optimized_result = sess.run([optimized_output])

      self.assertAllClose(
          original_result, optimized_result, rtol=1e-04, atol=1e-06)