Example #1
  def _test_gradient(self,
                     x_shape,
                     x_dtype,
                     scale_shape,
                     scale_dtype,
                     use_gpu=True,
                     data_format='NHWC',
                     is_training=True):
    np.random.seed(1)
    x_val = np.random.random_sample(x_shape).astype(x_dtype)
    scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)

    with self.cached_session(use_gpu=use_gpu):
      x = constant_op.constant(x_val, name='x')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
      if is_training:
        pop_mean = None
        pop_var = None
      else:
        pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
        pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
      y, _, _ = nn_impl.fused_batch_norm(
          x,
          scale,
          offset,
          mean=pop_mean,
          variance=pop_var,
          data_format=data_format,
          is_training=is_training)
      if x_dtype != np.float16:
        err_x = gradient_checker.compute_gradient_error(x, x_shape, y, x_shape)
        err_scale = gradient_checker.compute_gradient_error(
            scale, scale_shape, y, x_shape)
        err_offset = gradient_checker.compute_gradient_error(
            offset, scale_shape, y, x_shape)
      else:
        x32 = constant_op.constant(x_val, name='x32', dtype=dtypes.float32)
        y32, _, _ = nn_impl.fused_batch_norm(
            x32,
            scale,
            offset,
            mean=pop_mean,
            variance=pop_var,
            data_format=data_format,
            is_training=is_training)
        err_x = self._compute_gradient_error_float16(x, x32, x_shape, y, y32,
                                                     x_shape)
        err_scale = self._compute_gradient_error_float16(
            scale, scale, scale_shape, y, y32, x_shape)
        err_offset = self._compute_gradient_error_float16(
            offset, offset, scale_shape, y, y32, x_shape)

    x_err_tolerance = 2e-3 if x_dtype == np.float16 else 1e-3
    scale_err_tolerance = 1e-3
    self.assertLess(err_x, x_err_tolerance)
    self.assertLess(err_scale, scale_err_tolerance)
    self.assertLess(err_offset, scale_err_tolerance)
  def _test_gradient(self,
                     x_shape,
                     x_dtype,
                     scale_shape,
                     scale_dtype,
                     use_gpu=True,
                     data_format='NHWC',
                     is_training=True):
    np.random.seed(1)
    x_val = np.random.random_sample(x_shape).astype(x_dtype)
    scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)

    with self.test_session(use_gpu=use_gpu):
      x = constant_op.constant(x_val, name='x')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
      if is_training:
        pop_mean = None
        pop_var = None
      else:
        pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
        pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
      y, _, _ = nn_impl.fused_batch_norm(
          x,
          scale,
          offset,
          mean=pop_mean,
          variance=pop_var,
          data_format=data_format,
          is_training=is_training)
      if x_dtype != np.float16:
        err_x = gradient_checker.compute_gradient_error(x, x_shape, y, x_shape)
        err_scale = gradient_checker.compute_gradient_error(
            scale, scale_shape, y, x_shape)
        err_offset = gradient_checker.compute_gradient_error(
            offset, scale_shape, y, x_shape)
      else:
        x32 = constant_op.constant(x_val, name='x32', dtype=dtypes.float32)
        y32, _, _ = nn_impl.fused_batch_norm(
            x32,
            scale,
            offset,
            mean=pop_mean,
            variance=pop_var,
            data_format=data_format,
            is_training=is_training)
        err_x = self._compute_gradient_error_float16(x, x32, x_shape, y, y32,
                                                     x_shape)
        err_scale = self._compute_gradient_error_float16(
            scale, scale, scale_shape, y, y32, x_shape)
        err_offset = self._compute_gradient_error_float16(
            offset, offset, scale_shape, y, y32, x_shape)

    x_err_tolerance = 2e-3 if x_dtype == np.float16 else 1e-3
    scale_err_tolerance = 1e-3
    self.assertLess(err_x, x_err_tolerance)
    self.assertLess(err_scale, scale_err_tolerance)
    self.assertLess(err_offset, scale_err_tolerance)
 def call_train_deterministic(self, x_in, enable_ema_updates):
     if x_in.shape.rank == 4:
         # Fused batch norm is way faster than our own implementation
         x_out, batch_mean, batch_var = fused_batch_norm(
                 x_in,
                 self.gamma,
                 self.beta,
                 epsilon=self.epsilon,
                 data_format='NHWC',
                 is_training=True)
         batch_inv_std = tf.math.rsqrt(batch_var + self.epsilon)
     else:
         reduce_dims = list(range(x_in.shape.rank - 1)) # reduce all but the last dimension
         n_reduce_inv = 1.0 / (tf.cast(tf.reduce_prod(x_in.shape[:-1]), tf.float32) - 1.0)
         
         batch_mean = tf.reduce_mean(x_in, axis=reduce_dims)
         x_centered = x_in - batch_mean
         batch_var = tf.reduce_sum(tf.square(x_centered), axis=reduce_dims) * n_reduce_inv
         batch_inv_std = tf.math.rsqrt(batch_var + self.epsilon)
         scaling_factor = batch_inv_std * self.gamma
         x_scaled = x_centered * scaling_factor
         x_out = x_scaled + self.beta
     
     if enable_ema_updates:
         self.add_update(assign_sub(self.ema_batch_mean, self.decay * (self.ema_batch_mean - batch_mean)))
         self.add_update(assign_sub(self.ema_batch_inv_std, self.decay * (self.ema_batch_inv_std - batch_inv_std)))
         # The following code is equivalent, with momentum defined as 1.0 - decay:
         # self.add_update(assign(self.ema_batch_mean, self.momentum * self.ema_batch_mean + (1.0 - self.momentum) * batch_mean))
         # self.add_update(assign(self.ema_batch_inv_std, self.momentum * self.ema_batch_inv_std + (1.0 - self.momentum) * batch_inv_std))
             
     return x_out
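
The commented-out form above is only equivalent to the assign_sub update when momentum is defined as 1.0 - decay. A small stand-alone NumPy check of that identity (decay, ema and batch below are illustrative values, not taken from the layer):

# Illustrative check that the subtractive EMA update used in
# call_train_deterministic matches the momentum form when momentum == 1.0 - decay.
import numpy as np

decay = 0.01
momentum = 1.0 - decay
ema, batch = 0.8, 0.3

updated_sub = ema - decay * (ema - batch)                # assign_sub form
updated_mom = momentum * ema + (1.0 - momentum) * batch  # assign form
assert np.isclose(updated_sub, updated_mom)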
Example #4
    def _test_inference(self,
                        x_shape,
                        scale_shape,
                        use_gpu=True,
                        data_format='NHWC'):
        np.random.seed(1)
        x_val = np.random.random_sample(x_shape).astype(np.float32)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        offset_val = np.random.random_sample(scale_shape).astype(np.float32)
        mean_val = np.random.random_sample(scale_shape).astype(np.float32)
        var_val = np.random.random_sample(scale_shape).astype(np.float32)

        with self.test_session(use_gpu=use_gpu) as sess:
            x = constant_op.constant(x_val, name='x')
            scale = constant_op.constant(scale_val, name='scale')
            offset = constant_op.constant(offset_val, name='offset')
            mean = constant_op.constant(mean_val, name='mean')
            var = constant_op.constant(var_val, name='variance')
            epsilon = 0.001
            y, _, _ = nn_impl.fused_batch_norm(x,
                                               scale,
                                               offset,
                                               mean=mean,
                                               variance=var,
                                               epsilon=epsilon,
                                               data_format=data_format,
                                               is_training=False)
            y_val = sess.run(y)
            y_ref = self._inference_ref(x, scale, offset, mean, var, epsilon,
                                        data_format)
        self.assertAllClose(y_ref, y_val, atol=1e-3)
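
The _inference_ref helper used throughout these tests is not reproduced on this page. As a hypothetical sketch (the name and signature below are assumed, not copied from the test class), an inference-mode reference simply applies y = scale * (x - mean) / sqrt(var + epsilon) + offset, broadcasting the per-channel vectors according to data_format:

# Hypothetical NumPy reference in the spirit of _inference_ref; the real helper
# is not shown in these examples.
import numpy as np

def _inference_ref_sketch(x, scale, offset, mean, var, epsilon,
                          data_format='NHWC'):
  x = np.asarray(x, dtype=np.float32)
  if data_format.startswith('NC'):
    # Channel axis is axis 1 (NCHW/NCDHW): reshape the per-channel vectors so
    # they broadcast over that axis.
    shape = [1, -1] + [1] * (x.ndim - 2)
    scale, offset, mean, var = (np.reshape(np.asarray(v, np.float32), shape)
                                for v in (scale, offset, mean, var))
  # For NHWC the channel axis is last, so plain broadcasting already applies.
  return scale * (x - mean) / np.sqrt(var + epsilon) + offset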
    def test5dBatchNormFollowedByRelu(self):
        # The remapper grappler pass previously did not properly handle a 5D
        # inference FusedBatchNorm followed by Relu. This asserts that this case is
        # correctly handled.
        np.random.seed(1)
        x = np.random.random_sample((2, 3, 2, 2, 3)).astype(np.float32)
        scale = np.random.random_sample((3, )).astype(np.float32)
        offset = np.random.random_sample((3, )).astype(np.float32)
        mean = np.random.random_sample((3, )).astype(np.float32)
        var = np.random.random_sample((3, )).astype(np.float32)

        epsilon = 0.001
        y, _, _ = nn_impl.fused_batch_norm(x,
                                           scale,
                                           offset,
                                           mean=mean,
                                           variance=var,
                                           epsilon=epsilon,
                                           data_format='NCDHW',
                                           is_training=False)
        y = nn_ops.relu(y)
        y_val = self.evaluate(y)
        y_ref = self._inference_ref(x, scale, offset, mean, var, epsilon,
                                    'NCDHW')
        y_ref = np.maximum(y_ref, 0.)
        self.assertAllClose(y_ref, y_val, atol=1e-3)
Example #6
 def _test_training(self,
                    x_shape,
                    scale_shape,
                    use_gpu=True,
                    data_format='NHWC'):
     np.random.seed(1)
     x_val = np.random.random_sample(x_shape).astype(np.float32)
     scale_val = np.random.random_sample(scale_shape).astype(np.float32)
     offset_val = np.random.random_sample(scale_shape).astype(np.float32)
     with self.test_session(use_gpu=use_gpu) as sess:
         x = constant_op.constant(x_val, name='x')
         scale = constant_op.constant(scale_val, name='scale')
         offset = constant_op.constant(offset_val, name='offset')
         epsilon = 0.001
         y, mean, var = nn_impl.fused_batch_norm(x,
                                                 scale,
                                                 offset,
                                                 epsilon=epsilon,
                                                 data_format=data_format,
                                                 is_training=True)
         y_val, mean_val, var_val = sess.run([y, mean, var])
         y_ref, mean_ref, var_ref = self._training_ref(
             x, scale, offset, epsilon, data_format)
     self.assertAllClose(y_ref, y_val, atol=1e-3)
     self.assertAllClose(mean_ref, mean_val, atol=1e-3)
     # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
     # the denominator in the formula to calculate variance, while
     # tf.nn.fused_batch_norm has Bessel's correction built in.
     sample_size = x_val.size / scale_val.size
     var_ref = var_ref * sample_size / (max(sample_size - 1.0, 1.0))
     self.assertAllClose(var_ref, var_val, atol=1e-3)
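
The Bessel's-correction comment above amounts to a single rescaling factor, sample_size / (sample_size - 1), between the divide-by-n variance of tf.nn.moments and the divide-by-(n-1) variance returned by the fused op. A short NumPy illustration of that identity (the array shape here is arbitrary):

# Multiplying the biased (divide-by-n) variance by n / (n - 1) yields the
# unbiased (divide-by-(n-1)) variance, which is the adjustment applied to
# var_ref above.
import numpy as np

data = np.random.random_sample((8, 4)).astype(np.float32)
n = data.shape[0]                      # sample size per channel
biased = data.var(axis=0)              # ddof=0, like tf.nn.moments
unbiased = data.var(axis=0, ddof=1)    # ddof=1, Bessel-corrected
np.testing.assert_allclose(biased * n / (n - 1), unbiased, rtol=1e-5)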
 def _test_training(self,
                    x_shape,
                    x_dtype,
                    scale_shape,
                    scale_dtype,
                    use_gpu=True,
                    data_format='NHWC'):
   np.random.seed(1)
   x_val = np.random.random_sample(x_shape).astype(x_dtype)
   scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
   offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
   with self.test_session(use_gpu=use_gpu) as sess:
     x = constant_op.constant(x_val, name='x')
     scale = constant_op.constant(scale_val, name='scale')
     offset = constant_op.constant(offset_val, name='offset')
     epsilon = 0.001
     y, mean, var = nn_impl.fused_batch_norm(
         x,
         scale,
         offset,
         epsilon=epsilon,
         data_format=data_format,
         is_training=True)
     y_val, mean_val, var_val = sess.run([y, mean, var])
     y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, epsilon,
                                                   data_format)
   y_atol = 2e-3 if x_dtype == np.float16 else 1e-3
   self.assertAllClose(y_ref, y_val, atol=y_atol)
   self.assertAllClose(mean_ref, mean_val, atol=1e-3)
   # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
   # the denominator in the formula to calculate variance, while
   # tf.nn.fused_batch_norm has Bessel's correction built in.
   sample_size = x_val.size / scale_val.size
   var_ref = var_ref * sample_size / (max(sample_size - 1.0, 1.0))
   self.assertAllClose(var_ref, var_val, atol=1e-3)
Example #8
 def GraphFn(self, x):
   dtype = x.dtype
   x, _, _ = nn_impl.fused_batch_norm(
       x, [1.0, 1.0], [0.0, 0.0],
       mean=[0.5, 0.5],
       variance=[1.0, 1.0],
       data_format="NCHW",
       is_training=False)
   e = constant_op.constant(
       np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype)
   conv = nn.conv2d(
       input=x,
       filter=e,
       data_format="NCHW",
       strides=[1, 1, 2, 2],
       padding="SAME",
       name="conv")
   b = constant_op.constant(np.random.randn(6), name="bias", dtype=dtype)
   t = nn.bias_add(conv, b, data_format="NCHW", name="biasAdd")
   relu = nn.relu(t, "relu")
   idty = array_ops.identity(relu, "ID")
   v = nn_ops.max_pool(
       idty, [1, 1, 2, 2], [1, 1, 2, 2],
       "VALID",
       data_format="NCHW",
       name="max_pool")
   return array_ops.squeeze(v, name="output_0")
 def run_test(sess):
     inp = array_ops.placeholder(dtypes.float32)
     filt = array_ops.placeholder(dtypes.float32)
     scale = array_ops.placeholder(dtypes.float32)
     offset = array_ops.placeholder(dtypes.float32)
     mean = array_ops.placeholder(dtypes.float32)
     variance = array_ops.placeholder(dtypes.float32)
     relu_op = self.get_relu_op(relutype)
     bn, _, _ = nn_impl.fused_batch_norm(nn_ops.conv2d(
         inp, filt, strides=[1, 1, 1, 1], padding="SAME"),
                                         scale,
                                         offset,
                                         mean,
                                         variance,
                                         epsilon=0.02,
                                         is_training=False)
     return sess.run(
         relu_op(bn), {
             inp: inp_values,
             filt: filt_values,
             scale: scale_values,
             offset: offset_values,
             mean: mean_values,
             variance: variance_values,
         })
  def _test_inference(self,
                      x_shape,
                      scale_shape,
                      use_gpu=True,
                      data_format='NHWC'):
    np.random.seed(1)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)

    with self.test_session(use_gpu=use_gpu) as sess:
      x = constant_op.constant(x_val, name='x')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
      mean = constant_op.constant(mean_val, name='mean')
      var = constant_op.constant(var_val, name='variance')
      epsilon = 0.001
      y, _, _ = nn_impl.fused_batch_norm(
          x,
          scale,
          offset,
          mean=mean,
          variance=var,
          epsilon=epsilon,
          data_format=data_format,
          is_training=False)
      y_val = sess.run(y)
      y_ref = self._inference_ref(x, scale, offset, mean, var, epsilon,
                                  data_format)
    self.assertAllClose(y_ref, y_val, atol=1e-3)
Example #11
    def _test_gradient(self,
                       x_shape,
                       scale_shape,
                       use_gpu=True,
                       data_format='NHWC'):
        np.random.seed(1)
        x_val = np.random.random_sample(x_shape).astype(np.float32)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        offset_val = np.random.random_sample(scale_shape).astype(np.float32)

        with self.test_session(use_gpu=use_gpu):
            x = constant_op.constant(x_val, name='x')
            scale = constant_op.constant(scale_val, name='scale')
            offset = constant_op.constant(offset_val, name='offset')
            y, _, _ = nn_impl.fused_batch_norm(x,
                                               scale,
                                               offset,
                                               data_format=data_format)
            err_x = gradient_checker.compute_gradient_error(
                x, x_shape, y, x_shape)
            err_scale = gradient_checker.compute_gradient_error(
                scale, scale_shape, y, x_shape)
            err_offset = gradient_checker.compute_gradient_error(
                offset, scale_shape, y, x_shape)
        err_tolerance = 1e-3
        self.assertLess(err_x, err_tolerance)
        self.assertLess(err_scale, err_tolerance)
        self.assertLess(err_offset, err_tolerance)
Example #12
 def _bn_fused(x_arg, scale_arg, offset_arg):
     return nn_impl.fused_batch_norm(x_arg,
                                     scale_arg,
                                     offset_arg,
                                     epsilon=epsilon,
                                     is_training=True,
                                     data_format=data_format)[0]
  def _test_gradient(self,
                     x_shape,
                     scale_shape,
                     use_gpu=True,
                     data_format='NHWC'):
    np.random.seed(1)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)

    with self.test_session(use_gpu=use_gpu):
      x = constant_op.constant(x_val, name='x')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
      y, _, _ = nn_impl.fused_batch_norm(
          x, scale, offset, data_format=data_format)
      err_x = gradient_checker.compute_gradient_error(x, x_shape, y, x_shape)
      err_scale = gradient_checker.compute_gradient_error(scale, scale_shape, y,
                                                          x_shape)
      err_offset = gradient_checker.compute_gradient_error(offset, scale_shape,
                                                           y, x_shape)
    err_tolerance = 1e-3
    self.assertLess(err_x, err_tolerance)
    self.assertLess(err_scale, err_tolerance)
    self.assertLess(err_offset, err_tolerance)
Example #14
 def GetParams(self):
   """Single vgg layer test in TF-TRT conversion."""
   dtype = dtypes.float32
   input_name = "input"
   input_dims = [5, 8, 8, 2]
   output_name = "output"
   g = ops.Graph()
   with g.as_default():
     x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
     x, _, _ = nn_impl.fused_batch_norm(
         x, [1.0, 1.0], [0.0, 0.0],
         mean=[0.5, 0.5],
         variance=[1.0, 1.0],
         is_training=False)
     e = constant_op.constant(
         np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype)
     conv = nn.conv2d(
         input=x, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
     b = constant_op.constant(np.random.randn(6), name="bias", dtype=dtype)
     t = nn.bias_add(conv, b, name="biasAdd")
     relu = nn.relu(t, "relu")
     idty = array_ops.identity(relu, "ID")
     v = nn_ops.max_pool(
         idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
     array_ops.squeeze(v, name=output_name)
   return trt_test.TfTrtIntegrationTestParams(
       gdef=g.as_graph_def(),
       input_names=[input_name],
       input_dims=[input_dims],
       output_names=[output_name],
       expected_output_dims=[(5, 2, 2, 6)])
Example #15
 def _bn_fused(x_arg, scale_arg, offset_arg):
     return nn_impl.fused_batch_norm(x_arg,
                                     scale_arg,
                                     offset_arg,
                                     mean,
                                     variance,
                                     epsilon=epsilon,
                                     is_training=False)[0]
 def testForward(self):
     with self.cached_session():
         for data_format in ['NHWC', 'NCHW']:
             for large_batch in [False, True]:
                 for x_dtype in [dtypes.float16,
                                 dtypes.float32]:  # skipping bfloat16
                     x, scale, offset, mean, variance, _ = self._genParams(
                         data_format, x_dtype, large_batch)
                     for is_training in [False, True]:
                         op_output = nn_impl.fused_batch_norm(
                             x,
                             scale,
                             offset,
                             mean,
                             variance,
                             data_format=data_format,
                             is_training=is_training,
                             exponential_avg_factor=1.01)
                         y_a, running_mean_a, running_var_a = op_output
                         y_a = self.evaluate(y_a)
                         if is_training:
                             running_mean_a = self.evaluate(running_mean_a)
                             running_var_a = self.evaluate(running_var_a)
                         for _ in range(5):
                             op_output_b = nn_impl.fused_batch_norm(
                                 x,
                                 scale,
                                 offset,
                                 mean,
                                 variance,
                                 data_format=data_format,
                                 is_training=is_training,
                                 exponential_avg_factor=1.01)
                             y_b, running_mean_b, running_var_b = op_output_b
                             y_b = self.evaluate(y_b)
                             self.assertAllEqual(y_a, y_b)
                             if is_training:
                                 running_mean_b = self.evaluate(
                                     running_mean_b)
                                 running_var_b = self.evaluate(
                                     running_var_b)
                                 self.assertAllEqual(
                                     running_mean_a, running_mean_b)
                                 self.assertAllEqual(
                                     running_var_a, running_var_b)
 def testBackward(self):
     with self.cached_session():
         for data_format in ['NHWC', 'NCHW']:
             for large_batch in [False, True]:
                 for x_dtype in [dtypes.float16,
                                 dtypes.float32]:  # skipping bfloat16
                     params = self._genParams(data_format, x_dtype,
                                              large_batch)
                     x, scale, offset, mean, variance, upstream_gradients = params
                     for is_training in [False, True]:
                         for backprop_to in [x, scale, offset]:
                             with backprop.GradientTape(
                                     persistent=True) as tape:
                                 tape.watch(backprop_to)
                                 op_output = nn_impl.fused_batch_norm(
                                     x,
                                     scale,
                                     offset,
                                     mean,
                                     variance,
                                     data_format=data_format,
                                     is_training=is_training,
                                     exponential_avg_factor=0.99)
                                 gradient_injector_output = op_output[
                                     0] * upstream_gradients
                             if (len(config.list_physical_devices('GPU'))
                                     and not is_training):
                                 # Only backprop to offset is nondeterministic (on GPU, when
                                 # is_training=False), but backprop to the other parameters is
                                 # calculated using the same kernel.
                                 with self.assertRaisesRegex(
                                         errors_impl.UnimplementedError,
                                         'A deterministic GPU implementation of fused batch-norm'
                                         +
                                         ' backprop, when training is disabled, is not currently'
                                         + ' available.'):
                                     grad = tape.gradient(
                                         gradient_injector_output,
                                         backprop_to)
                                     self.evaluate(grad)
                             else:
                                 grad_a = tape.gradient(
                                     gradient_injector_output, backprop_to)
                                 grad_a = self.evaluate(grad_a)
                                 for _ in range(5):
                                     grad_b = tape.gradient(
                                         gradient_injector_output,
                                         backprop_to)
                                     grad_b = self.evaluate(grad_b)
                                     self.assertAllEqual(grad_a, grad_b)
Example #18
 def GetParams(self):
     """Single vgg layer in NCHW unit tests in TF-TRT."""
     dtype = dtypes.float32
     input_name = "input"
     input_dims = [5, 2, 8, 8]
     g = ops.Graph()
     with g.as_default():
         x = array_ops.placeholder(dtype=dtype,
                                   shape=input_dims,
                                   name=input_name)
         x, _, _ = nn_impl.fused_batch_norm(
             x,
             np.random.randn(2).astype(np.float32),
             np.random.randn(2).astype(np.float32),
             mean=np.random.randn(2).astype(np.float32),
             variance=np.random.randn(2).astype(np.float32),
             data_format="NCHW",
             is_training=False)
         e = constant_op.constant(np.random.randn(1, 1, 2, 6),
                                  name="weights",
                                  dtype=dtype)
         conv = nn.conv2d(input=x,
                          filter=e,
                          data_format="NCHW",
                          strides=[1, 1, 2, 2],
                          padding="SAME",
                          name="conv")
         b = constant_op.constant(np.random.randn(6),
                                  name="bias",
                                  dtype=dtype)
         t = nn.bias_add(conv, b, data_format="NCHW", name="biasAdd")
         relu = nn.relu(t, "relu")
         idty = array_ops.identity(relu, "ID")
         v = nn_ops.max_pool(idty, [1, 1, 2, 2], [1, 1, 2, 2],
                             "VALID",
                             data_format="NCHW",
                             name="max_pool")
         array_ops.squeeze(v, name="output")
     return trt_test.TfTrtIntegrationTestParams(gdef=g.as_graph_def(),
                                                input_names=[input_name],
                                                input_dims=[input_dims],
                                                num_expected_engines=1,
                                                expected_output_dims=(5, 6,
                                                                      2, 2),
                                                allclose_atol=1.e-03,
                                                allclose_rtol=1.e-03)
Example #19
    def _test01(self, dtype):
        with test_util.device(True):

            x_val = [5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15]
            shape = [1, 1, 6, 2]

            x = constant_op.constant(x_val, dtype=dtype, shape=shape)
            scale = constant_op.constant([4, 5], dtype=float)
            offset = constant_op.constant([2, 3], dtype=float)

            batch_norm, batch_mean, batch_var = nn_impl.fused_batch_norm(
                x, scale, offset, is_training=True)
            relu = nn_ops.relu(batch_norm)
            grad = gradients_impl.gradients(relu, x)

            y1 = array_ops.identity(grad)

            return (y1, )
    def _test_training(self,
                       x_shape,
                       x_dtype,
                       scale_shape,
                       scale_dtype,
                       use_gpu=True,
                       exponential_avg_factor=1.0,
                       data_format='NHWC'):
        np.random.seed(1)
        x_val = np.random.random_sample(x_shape).astype(x_dtype)
        scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        if exponential_avg_factor == 1.0:
            old_mean_val = None
            old_var_val = None
        else:
            old_mean_val = np.random.random_sample(scale_shape).astype(
                scale_dtype)
            old_var_val = np.random.random_sample(scale_shape).astype(
                scale_dtype)

        with self.cached_session(use_gpu=use_gpu) as sess:
            x = constant_op.constant(x_val, name='x')
            scale = constant_op.constant(scale_val, name='scale')
            offset = constant_op.constant(offset_val, name='offset')
            epsilon = 0.001
            y, mean, var = nn_impl.fused_batch_norm(
                x,
                scale,
                offset,
                mean=old_mean_val,
                variance=old_var_val,
                epsilon=epsilon,
                exponential_avg_factor=exponential_avg_factor,
                data_format=data_format,
                is_training=True)
            y_val, mean_val, var_val = self.evaluate([y, mean, var])
            y_ref, mean_ref, var_ref = self._training_ref(
                x, scale, offset, old_mean_val, old_var_val,
                exponential_avg_factor, epsilon, data_format)
        y_atol = 2e-3 if x_dtype == np.float16 else 1e-3
        self.assertAllClose(y_ref, y_val, atol=y_atol)
        self.assertAllClose(mean_ref, mean_val, atol=1e-3)
        self.assertAllClose(var_ref, var_val, atol=1e-3)
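
Like _inference_ref, the _training_ref helper is not shown here. The sketch below is a hypothetical NumPy approximation for NHWC input only, assuming TF's documented exponential_avg_factor semantics (running statistic = (1 - factor) * old + factor * batch statistic) and feeding the Bessel-corrected batch variance into the running update, consistent with the Bessel's-correction comments earlier on this page:

# Hypothetical sketch of a training-mode reference with exponential_avg_factor;
# the real _training_ref helper is not reproduced in these examples.
import numpy as np

def _training_ref_sketch(x, scale, offset, old_mean, old_var,
                         exponential_avg_factor, epsilon):
  x = np.asarray(x, dtype=np.float32)
  reduce_axes = tuple(range(x.ndim - 1))        # NHWC: all but the channel axis
  batch_mean = x.mean(axis=reduce_axes)
  batch_var = x.var(axis=reduce_axes)           # biased, divide by n
  y = scale * (x - batch_mean) / np.sqrt(batch_var + epsilon) + offset

  n = x.size // batch_mean.size
  unbiased_var = batch_var * n / max(n - 1, 1)  # Bessel's correction
  if old_mean is None or exponential_avg_factor == 1.0:
    running_mean, running_var = batch_mean, unbiased_var
  else:
    f = exponential_avg_factor
    running_mean = (1.0 - f) * old_mean + f * batch_mean
    running_var = (1.0 - f) * old_var + f * unbiased_var
  return y, running_mean, running_var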
 def GetParams(self):
   """Single vgg layer in NCHW unit tests in TF-TRT."""
   dtype = dtypes.float32
   input_name = "input"
   input_dims = [5, 2, 8, 8]
   g = ops.Graph()
   with g.as_default():
     x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
     x, _, _ = nn_impl.fused_batch_norm(
         x,
         np.random.randn(2).astype(np.float32),
         np.random.randn(2).astype(np.float32),
         mean=np.random.randn(2).astype(np.float32),
         variance=np.random.randn(2).astype(np.float32),
         data_format="NCHW",
         is_training=False)
     e = constant_op.constant(
         np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype)
     conv = nn.conv2d(
         input=x,
         filter=e,
         data_format="NCHW",
         strides=[1, 1, 2, 2],
         padding="SAME",
         name="conv")
     b = constant_op.constant(np.random.randn(6), name="bias", dtype=dtype)
     t = nn.bias_add(conv, b, data_format="NCHW", name="biasAdd")
     relu = nn.relu(t, "relu")
     idty = array_ops.identity(relu, "ID")
     v = nn_ops.max_pool(
         idty, [1, 1, 2, 2], [1, 1, 2, 2],
         "VALID",
         data_format="NCHW",
         name="max_pool")
     array_ops.squeeze(v, name="output")
   return trt_test.TfTrtIntegrationTestParams(
       gdef=g.as_graph_def(),
       input_names=[input_name],
       input_dims=[input_dims],
       num_expected_engines=1,
       expected_output_dims=(5, 6, 2, 2),
       allclose_atol=1.e-03,
       allclose_rtol=1.e-03)
Example #22
    def _test01(self, dtype):
        with test_util.device(True):

            x = constant_op.constant(
                [5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15],
                dtype=dtype,
                shape=[1, 1, 6, 2])
            scale = constant_op.constant([4, 5], dtype=float)
            offset = constant_op.constant([2, 3], dtype=float)

            batch_norm, batch_mean, batch_var = nn_impl.fused_batch_norm(
                x, scale, offset, is_training=True)
            relu = nn_ops.relu(batch_norm)

            y1 = array_ops.identity(relu)
            y2 = array_ops.identity(batch_mean)
            y3 = array_ops.identity(batch_var)

            return (y1, y2, y3)
    def _test_inference(self,
                        x_shape,
                        x_dtype,
                        scale_shape,
                        scale_dtype,
                        use_gpu=True,
                        exponential_avg_factor=1.0,
                        data_format='NHWC'):
        np.random.seed(1)
        x_val = np.random.random_sample(x_shape).astype(x_dtype)
        scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        var_val = np.random.random_sample(scale_shape).astype(scale_dtype)

        with self.cached_session(use_gpu=use_gpu) as sess:
            x = constant_op.constant(x_val, name='x')
            scale = constant_op.constant(scale_val, name='scale')
            offset = constant_op.constant(offset_val, name='offset')
            mean = constant_op.constant(mean_val, name='mean')
            var = constant_op.constant(var_val, name='variance')
            epsilon = 0.001
            y, _, _ = nn_impl.fused_batch_norm(
                x,
                scale,
                offset,
                mean=mean,
                variance=var,
                epsilon=epsilon,
                exponential_avg_factor=exponential_avg_factor,
                data_format=data_format,
                is_training=False)
            y_val = self.evaluate(y)
            y_ref = self._inference_ref(x, scale, offset, mean, var, epsilon,
                                        data_format)
        # An atol value of 1e-3 is too small for float16s, because some adjacent
        # float16 values that y_val can take are greater than 1e-3 apart, e.g.
        # 2.16602 and 2.16797.
        atol = 2e-3 if x_dtype == np.float16 else 1e-3
        self.assertAllClose(y_ref, y_val, atol=atol)
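
The atol comment above follows from float16 precision: between 2 and 4 the representable values are spaced 2**-9 (about 0.00195) apart, so neighbouring outputs such as 2.16602 and 2.16797 necessarily differ by more than 1e-3. A quick NumPy check:

# Spacing of adjacent float16 values near 2.17 is ~0.00195, larger than 1e-3.
import numpy as np

print(np.spacing(np.float16(2.166)))             # ~0.001953
print(np.float16(2.16602), np.float16(2.16797))  # two adjacent float16 values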
Example #24
    def _test01(self, dtype):
        with test_util.device(True):
            x = constant_op.constant(
                [5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15],
                dtype=dtype,
                shape=[1, 1, 6, 2])
            scale = constant_op.constant([4, 5], dtype=float)
            offset = constant_op.constant([2, 3], dtype=float)
            batch_mean = constant_op.constant([10, 10], dtype=float)
            batch_var = constant_op.constant([14, 14], dtype=float)

            batch_norm, _, _ = nn_impl.fused_batch_norm(x,
                                                        scale,
                                                        offset,
                                                        mean=batch_mean,
                                                        variance=batch_var,
                                                        is_training=False)
            relu = nn_ops.relu(batch_norm)

            y1 = array_ops.identity(relu)

            return (y1, )
  def _test_inference(self,
                      x_shape,
                      x_dtype,
                      scale_shape,
                      scale_dtype,
                      use_gpu=True,
                      data_format='NHWC'):
    np.random.seed(1)
    x_val = np.random.random_sample(x_shape).astype(x_dtype)
    scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    var_val = np.random.random_sample(scale_shape).astype(scale_dtype)

    with self.test_session(use_gpu=use_gpu) as sess:
      x = constant_op.constant(x_val, name='x')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
      mean = constant_op.constant(mean_val, name='mean')
      var = constant_op.constant(var_val, name='variance')
      epsilon = 0.001
      y, _, _ = nn_impl.fused_batch_norm(
          x,
          scale,
          offset,
          mean=mean,
          variance=var,
          epsilon=epsilon,
          data_format=data_format,
          is_training=False)
      y_val = sess.run(y)
      y_ref = self._inference_ref(x, scale, offset, mean, var, epsilon,
                                  data_format)
    # An atol value of 1e-3 is too small for float16s, because some adjacent
    # float16 values that y_val can take are greater than 1e-3 apart, e.g.
    # 2.16602 and 2.16797.
    atol = 2e-3 if x_dtype == np.float16 else 1e-3
    self.assertAllClose(y_ref, y_val, atol=atol)
    def _test_grad_grad(self,
                        x_shape,
                        x_dtype,
                        scale_shape,
                        scale_dtype,
                        use_gpu=True,
                        exponential_avg_factor=1.0,
                        data_format='NHWC',
                        is_training=True,
                        err_tolerance=1e-3):
        np.random.seed(1)
        x_val = np.random.random_sample(x_shape).astype(x_dtype)
        grad_y_val = np.random.random_sample(x_shape).astype(x_dtype)
        scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)

        with self.cached_session(use_gpu=use_gpu) as sess:
            x = constant_op.constant(x_val, name='x')
            grad_y = constant_op.constant(grad_y_val, name='grad_y')
            scale = constant_op.constant(scale_val, name='scale')
            offset = constant_op.constant(offset_val, name='offset')
            if is_training and exponential_avg_factor == 1.0:
                pop_mean = None
                pop_var = None
            else:
                pop_mean = np.random.random_sample(scale_shape).astype(
                    scale_dtype)
                pop_var = np.random.random_sample(scale_shape).astype(
                    scale_dtype)
            y, _, _ = nn_impl.fused_batch_norm(
                x,
                scale,
                offset,
                mean=pop_mean,
                variance=pop_var,
                exponential_avg_factor=exponential_avg_factor,
                data_format=data_format,
                is_training=is_training)
            grad_x, grad_scale, grad_offset = gradients_impl.gradients(
                y, [x, scale, offset], grad_y)

            if is_training:
                epsilon = y.op.get_attr('epsilon')
                data_format = y.op.get_attr('data_format')
                grad_vals = self.evaluate([grad_x, grad_scale, grad_offset])
                grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale,
                                                       pop_mean, pop_var,
                                                       epsilon, data_format)
                grad_internal_vals = self.evaluate(list(grad_internal))
                for grad_val, grad_internal_val in zip(grad_vals,
                                                       grad_internal_vals):
                    self.assertAllClose(grad_val,
                                        grad_internal_val,
                                        atol=err_tolerance)

            if x_dtype != np.float16:
                err_grad_grad_y_1 = gradient_checker.compute_gradient_error(
                    grad_y, x_shape, grad_x, x_shape)
                err_grad_grad_y_2 = gradient_checker.compute_gradient_error(
                    grad_y, x_shape, grad_scale, scale_shape)
                err_grad_grad_y_3 = gradient_checker.compute_gradient_error(
                    grad_y, x_shape, grad_offset, scale_shape)
                # In freeze mode, grad_x is not a function of x.
                if is_training:
                    err_grad_x_1 = gradient_checker.compute_gradient_error(
                        x, x_shape, grad_x, x_shape)
                err_grad_x_2 = gradient_checker.compute_gradient_error(
                    x, x_shape, grad_scale, scale_shape)

                err_grad_scale = gradient_checker.compute_gradient_error(
                    scale, scale_shape, grad_x, x_shape)
            else:
                x32 = constant_op.constant(x_val,
                                           dtype=dtypes.float32,
                                           name='x32')
                grad_y32 = constant_op.constant(grad_y_val,
                                                dtype=dtypes.float32,
                                                name='grad_y32')
                y32, _, _ = nn_impl.fused_batch_norm(
                    x32,
                    scale,
                    offset,
                    mean=pop_mean,
                    variance=pop_var,
                    exponential_avg_factor=exponential_avg_factor,
                    data_format=data_format,
                    is_training=is_training)
                grad_x32, grad_scale32, grad_offset32 = gradients_impl.gradients(
                    y32, [x32, scale, offset], grad_y32)
                err_grad_grad_y_1 = self._compute_gradient_error_float16(
                    grad_y, grad_y32, x_shape, grad_x, grad_x32, x_shape)
                err_grad_grad_y_2 = self._compute_gradient_error_float16(
                    grad_y, grad_y32, x_shape, grad_scale, grad_scale32,
                    scale_shape)
                err_grad_grad_y_3 = self._compute_gradient_error_float16(
                    grad_y, grad_y32, x_shape, grad_offset, grad_offset32,
                    scale_shape)
                # In freeze mode, grad_x is not a function of x.
                if is_training:
                    err_grad_x_1 = self._compute_gradient_error_float16(
                        x, x32, x_shape, grad_x, grad_x32, x_shape)
                err_grad_x_2 = self._compute_gradient_error_float16(
                    x, x32, x_shape, grad_scale, grad_scale32, scale_shape)

                err_grad_scale = self._compute_gradient_error_float16(
                    scale, scale, scale_shape, grad_x, grad_x32, x_shape)

        self.assertLess(err_grad_grad_y_1, err_tolerance)
        self.assertLess(err_grad_grad_y_2, err_tolerance)
        self.assertLess(err_grad_grad_y_3, err_tolerance)
        if is_training:
            self.assertLess(err_grad_x_1, err_tolerance)
        self.assertLess(err_grad_x_2, err_tolerance)
        self.assertLess(err_grad_scale, err_tolerance)
def _fused_batchnorm(x, scale, offset):
  """Batchnorm."""
  return nn_impl.fused_batch_norm(
      x, scale=scale, offset=offset, is_training=True)
    def testEagerShapeErrors(self):
        with context.eager_mode():
            x = array_ops.ones((2, 2, 2, 2))
            scale = array_ops.ones((3, ))
            offset = array_ops.ones((2, ))
            with self.assertRaisesRegex(
                    errors_impl.InvalidArgumentError,
                    'scale must have the same number of elements'):
                nn_impl.fused_batch_norm(x, scale, offset)

            x = array_ops.ones((2, 2, 2, 2))
            scale = array_ops.ones((2, ))
            offset = array_ops.ones((3, ))
            with self.assertRaisesRegex(
                    errors_impl.InvalidArgumentError,
                    'offset must have the same number of elements'):
                nn_impl.fused_batch_norm(x, scale, offset)

            x = array_ops.ones((2, 2, 2, 2))
            scale = array_ops.ones((2, ))
            offset = array_ops.ones((2, ))
            mean = array_ops.ones((0, ))
            variance = array_ops.ones((2, ))
            with self.assertRaisesRegex(
                    errors_impl.InvalidArgumentError,
                    'When is_training=false, mean must have the same number of elements'
            ):
                nn_impl.fused_batch_norm(x,
                                         scale,
                                         offset,
                                         mean=mean,
                                         variance=variance,
                                         is_training=False)

            x = array_ops.ones((2, 2, 2, 2))
            scale = array_ops.ones((2, ))
            offset = array_ops.ones((2, ))
            mean = array_ops.ones((2, ))
            variance = array_ops.ones((0, ))
            with self.assertRaisesRegex(
                    errors_impl.InvalidArgumentError,
                    'When is_training=false, variance must have the same number of '
                    'elements'):
                nn_impl.fused_batch_norm(x,
                                         scale,
                                         offset,
                                         mean=mean,
                                         variance=variance,
                                         is_training=False)

            x = array_ops.ones((2, 2, 2, 2))
            scale = array_ops.ones((2, ))
            offset = array_ops.ones((2, ))
            mean = array_ops.ones((0, ))
            variance = array_ops.ones((2, ))
            with self.assertRaisesRegex(
                    errors_impl.InvalidArgumentError,
                    'When exponential_avg_factor != 1, mean must have the same number of '
                    'elements'):
                nn_impl.fused_batch_norm(x,
                                         scale,
                                         offset,
                                         mean=mean,
                                         variance=variance,
                                         exponential_avg_factor=0.5)

            x = array_ops.ones((2, 2, 2, 2))
            scale = array_ops.ones((2, ))
            offset = array_ops.ones((2, ))
            mean = array_ops.ones((2, ))
            variance = array_ops.ones((0, ))
            with self.assertRaisesRegex(
                    errors_impl.InvalidArgumentError,
                    'When exponential_avg_factor != 1, variance must have the same '
                    'number of elements'):
                nn_impl.fused_batch_norm(x,
                                         scale,
                                         offset,
                                         mean=mean,
                                         variance=variance,
                                         exponential_avg_factor=0.5)
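
For contrast with the error cases above, a shape-consistent eager call succeeds when scale, offset, mean and variance all have length equal to the channel dimension of x (2 for an NHWC input of shape (2, 2, 2, 2)). A minimal sketch using the same internal modules as the tests:

# Shape-consistent counterpart to the error cases in testEagerShapeErrors.
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_impl

x = array_ops.ones((2, 2, 2, 2))
scale = array_ops.ones((2,))
offset = array_ops.ones((2,))
mean = array_ops.ones((2,))
variance = array_ops.ones((2,))
y, running_mean, running_var = nn_impl.fused_batch_norm(
    x, scale, offset, mean=mean, variance=variance, is_training=False)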
    def _test_grad_grad(self,
                        x_shape,
                        scale_shape,
                        use_gpu=True,
                        data_format='NHWC',
                        is_training=True,
                        err_tolerance=1e-3):
        np.random.seed(1)
        x_val = np.random.random_sample(x_shape).astype(np.float32)
        grad_y_val = np.random.random_sample(x_shape).astype(np.float32)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        offset_val = np.random.random_sample(scale_shape).astype(np.float32)

        with self.test_session(use_gpu=use_gpu) as sess:
            x = constant_op.constant(x_val, name='x')
            grad_y = constant_op.constant(grad_y_val, name='grad_y')
            scale = constant_op.constant(scale_val, name='scale')
            offset = constant_op.constant(offset_val, name='offset')
            if is_training:
                pop_mean = None
                pop_var = None
            else:
                pop_mean = np.random.random_sample(scale_shape).astype(
                    np.float32)
                pop_var = np.random.random_sample(scale_shape).astype(
                    np.float32)
            y, _, _ = nn_impl.fused_batch_norm(x,
                                               scale,
                                               offset,
                                               mean=pop_mean,
                                               variance=pop_var,
                                               data_format=data_format,
                                               is_training=is_training)
            grad_x, grad_scale, grad_offset = gradients_impl.gradients(
                y, [x, scale, offset], grad_y)

            if is_training:
                epsilon = y.op.get_attr('epsilon')
                data_format = y.op.get_attr('data_format')
                grad_vals = sess.run([grad_x, grad_scale, grad_offset])
                grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale,
                                                       epsilon, data_format)
                grad_internal_vals = sess.run(list(grad_internal))
                for grad_val, grad_internal_val in zip(grad_vals,
                                                       grad_internal_vals):
                    self.assertAllClose(grad_val,
                                        grad_internal_val,
                                        atol=err_tolerance)

            err_grad_grad_y_1 = gradient_checker.compute_gradient_error(
                grad_y, x_shape, grad_x, x_shape)
            err_grad_grad_y_2 = gradient_checker.compute_gradient_error(
                grad_y, x_shape, grad_scale, scale_shape)
            err_grad_grad_y_3 = gradient_checker.compute_gradient_error(
                grad_y, x_shape, grad_offset, scale_shape)
            # In freeze mode, grad_x is not a function of x.
            if is_training:
                err_grad_x_1 = gradient_checker.compute_gradient_error(
                    x, x_shape, grad_x, x_shape)
            err_grad_x_2 = gradient_checker.compute_gradient_error(
                x, x_shape, grad_scale, scale_shape)

            err_grad_scale = gradient_checker.compute_gradient_error(
                scale, scale_shape, grad_x, x_shape)

        self.assertLess(err_grad_grad_y_1, err_tolerance)
        self.assertLess(err_grad_grad_y_2, err_tolerance)
        self.assertLess(err_grad_grad_y_3, err_tolerance)
        if is_training:
            self.assertLess(err_grad_x_1, err_tolerance)
        self.assertLess(err_grad_x_2, err_tolerance)
        self.assertLess(err_grad_scale, err_tolerance)
  def _test_grad_grad(self,
                      x_shape,
                      x_dtype,
                      scale_shape,
                      scale_dtype,
                      use_gpu=True,
                      data_format='NHWC',
                      is_training=True,
                      err_tolerance=1e-3):
    np.random.seed(1)
    x_val = np.random.random_sample(x_shape).astype(x_dtype)
    grad_y_val = np.random.random_sample(x_shape).astype(x_dtype)
    scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)

    with self.test_session(use_gpu=use_gpu) as sess:
      x = constant_op.constant(x_val, name='x')
      grad_y = constant_op.constant(grad_y_val, name='grad_y')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
      if is_training:
        pop_mean = None
        pop_var = None
      else:
        pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
        pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
      y, _, _ = nn_impl.fused_batch_norm(
          x,
          scale,
          offset,
          mean=pop_mean,
          variance=pop_var,
          data_format=data_format,
          is_training=is_training)
      grad_x, grad_scale, grad_offset = gradients_impl.gradients(
          y, [x, scale, offset], grad_y)

      if is_training:
        epsilon = y.op.get_attr('epsilon')
        data_format = y.op.get_attr('data_format')
        grad_vals = sess.run([grad_x, grad_scale, grad_offset])
        grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale, pop_mean,
                                               pop_var, epsilon, data_format)
        grad_internal_vals = sess.run(list(grad_internal))
        for grad_val, grad_internal_val in zip(grad_vals, grad_internal_vals):
          self.assertAllClose(grad_val, grad_internal_val, atol=err_tolerance)

      if x_dtype != np.float16:
        err_grad_grad_y_1 = gradient_checker.compute_gradient_error(
            grad_y, x_shape, grad_x, x_shape)
        err_grad_grad_y_2 = gradient_checker.compute_gradient_error(
            grad_y, x_shape, grad_scale, scale_shape)
        err_grad_grad_y_3 = gradient_checker.compute_gradient_error(
            grad_y, x_shape, grad_offset, scale_shape)
        # In freeze mode, grad_x is not a function of x.
        if is_training:
          err_grad_x_1 = gradient_checker.compute_gradient_error(
              x, x_shape, grad_x, x_shape)
        err_grad_x_2 = gradient_checker.compute_gradient_error(
            x, x_shape, grad_scale, scale_shape)

        err_grad_scale = gradient_checker.compute_gradient_error(
            scale, scale_shape, grad_x, x_shape)
      else:
        x32 = constant_op.constant(x_val, dtype=dtypes.float32, name='x32')
        grad_y32 = constant_op.constant(
            grad_y_val, dtype=dtypes.float32, name='grad_y32')
        y32, _, _ = nn_impl.fused_batch_norm(
            x32,
            scale,
            offset,
            mean=pop_mean,
            variance=pop_var,
            data_format=data_format,
            is_training=is_training)
        grad_x32, grad_scale32, grad_offset32 = gradients_impl.gradients(
            y32, [x32, scale, offset], grad_y32)
        err_grad_grad_y_1 = self._compute_gradient_error_float16(
            grad_y, grad_y32, x_shape, grad_x, grad_x32, x_shape)
        err_grad_grad_y_2 = self._compute_gradient_error_float16(
            grad_y, grad_y32, x_shape, grad_scale, grad_scale32, scale_shape)
        err_grad_grad_y_3 = self._compute_gradient_error_float16(
            grad_y, grad_y32, x_shape, grad_offset, grad_offset32, scale_shape)
        # In freeze mode, grad_x is not a function of x.
        if is_training:
          err_grad_x_1 = self._compute_gradient_error_float16(
              x, x32, x_shape, grad_x, grad_x32, x_shape)
        err_grad_x_2 = self._compute_gradient_error_float16(
            x, x32, x_shape, grad_scale, grad_scale32, scale_shape)

        err_grad_scale = self._compute_gradient_error_float16(
            scale, scale, scale_shape, grad_x, grad_x32, x_shape)

    self.assertLess(err_grad_grad_y_1, err_tolerance)
    self.assertLess(err_grad_grad_y_2, err_tolerance)
    self.assertLess(err_grad_grad_y_3, err_tolerance)
    if is_training:
      self.assertLess(err_grad_x_1, err_tolerance)
    self.assertLess(err_grad_x_2, err_tolerance)
    self.assertLess(err_grad_scale, err_tolerance)