Code Example #1
  def testSplitWithNonConstAxis(self):
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = random_ops.truncated_normal([1, 784], seed=0)
      conv = _two_layer_model(x)
      dim = array_ops.placeholder(dtype='int32')
      split = array_ops.split(conv, 2, axis=dim)
      scale = constant_op.constant(0.1, shape=[32])
      offset = constant_op.constant(0.3, shape=[32])
      bn0 = nn.fused_batch_norm(split[0], scale, offset)
      bn1 = nn.fused_batch_norm(split[1], scale, offset)
      add = bn0[0] + bn1[0]
      output = array_ops.identity(add)

      with session.Session() as sess:
        output_val_ref = sess.run(output, feed_dict={dim: 3})

      with session.Session(config=_get_config()) as sess:
        metadata = config_pb2.RunMetadata()
        output_val = sess.run(output, run_metadata=metadata, feed_dict={dim: 3})

      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('add_2-0-0', nodes)
      self._assert_map_nhwc_to_nchw('split-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
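Note: the op exercised throughout these examples is the public `tf.nn.fused_batch_norm`, which returns the normalized output together with a batch mean and batch variance. A minimal, self-contained sketch of its two modes (a sketch only, assuming TensorFlow 2.x eager execution rather than the graph-mode test harness above):

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.random_sample([2, 4, 4, 3]).astype(np.float32))  # NHWC
scale = tf.ones([3])    # gamma, one value per channel
offset = tf.zeros([3])  # beta, one value per channel

# Training mode: batch statistics are computed internally and returned.
y, batch_mean, batch_var = tf.nn.fused_batch_norm(
    x, scale, offset, epsilon=1e-3, data_format='NHWC', is_training=True)

# Inference mode: previously accumulated statistics must be supplied.
y_eval, _, _ = tf.nn.fused_batch_norm(
    x, scale, offset, mean=batch_mean, variance=batch_var,
    epsilon=1e-3, data_format='NHWC', is_training=False)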
Code Example #2
    def call(self, inputs, training=None):
        training = self._get_training_value(training)
        inputs = tf.cast(
            inputs,
            tf.keras.mixed_precision.global_policy().variable_dtype)

        if training:
            outputs, mean, variance = nn.fused_batch_norm(inputs,
                                                          self.gamma,
                                                          self.beta,
                                                          epsilon=self.epsilon)
        else:
            outputs, mean, variance = nn.fused_batch_norm(
                inputs,
                self.gamma,
                self.beta,
                mean=self.moving_mean,
                variance=self.moving_variance,
                epsilon=self.epsilon,
                is_training=False)

        outputs = tf.cast(
            outputs,
            tf.keras.mixed_precision.global_policy().compute_dtype)

        @tf.custom_gradient
        def moving_avg_updates(x, moving_m, moving_v):
            def bw(dx):
                return dx, moving_m - mean, moving_v - variance

            return x, bw

        return moving_avg_updates(outputs, self.moving_mean,
                                  self.moving_variance)
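The `tf.custom_gradient` wrapper above reports `moving_m - mean` and `moving_v - variance` as the "gradients" of the moving statistics, so an optimizer step on them behaves like a moving-average update. A hedged sketch of that effect (hypothetical values, plain SGD assumed):

# One SGD step of size f on the pseudo-gradient (moving - batch) yields an
# exponential moving average of the batch statistic.
moving_mean, batch_mean, f = 0.0, 10.0, 0.1
pseudo_grad = moving_mean - batch_mean   # what bw() returns for moving_m
moving_mean -= f * pseudo_grad           # optimizer step
print(moving_mean)                       # 1.0 == (1 - f) * 0.0 + f * 10.0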
Code Example #3
  def testInference(self):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)

    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    with self.test_session() as sess, self.test_scope():
      # To avoid constant folding
      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
      epsilon = 0.001
      y_ref, mean_ref, var_ref = self._reference_training(
          x_val, scale_val, offset_val, epsilon, data_format)
      y, mean, variance = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=mean_ref,
          variance=var_ref,
          epsilon=epsilon,
          data_format=data_format,
          is_training=False)

      y_val, _, _ = sess.run(
          [y, mean,
           variance], {t_val: x_val,
                       scale: scale_val,
                       offset: offset_val})
      self.assertAllClose(y_val, y_ref, atol=1e-3)
Code Example #4
  def testBasic(self):
    x_shape = [2, 2, 6, 2]
    scale_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)

    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    with self.test_session() as sess, self.test_scope():
      # To avoid constant folding
      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
      offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
      epsilon = 0.001
      y, mean, var = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=None,
          variance=None,
          epsilon=epsilon,
          data_format=data_format,
          is_training=True)
      y_val, mean_val, var_val = sess.run(
          [y, mean, var], {t_val: x_val,
                           scale: scale_val,
                           offset: offset_val})
      y_ref, mean_ref, var_ref = self._reference_training(
          x_val, scale_val, offset_val, epsilon, data_format)
      self.assertAllClose(mean_val, mean_ref, atol=1e-3)
      self.assertAllClose(y_val, y_ref, atol=1e-3)
      self.assertAllClose(var_val, var_ref, atol=1e-3)
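`_reference_training` itself is not shown in these snippets. A hedged NumPy sketch of what such a reference typically computes for NHWC input in training mode (normalization uses the biased, population variance; a later example notes that the variance the op reports may additionally carry Bessel's correction):

import numpy as np

def reference_training(x, scale, offset, epsilon):
    # Per-channel statistics over the batch and spatial dimensions (N, H, W).
    mean = x.mean(axis=(0, 1, 2))
    var = x.var(axis=(0, 1, 2))
    y = (x - mean) / np.sqrt(var + epsilon) * scale + offset
    return y, mean, var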
Code Example #5
    def _testLearning(self, use_gradient_checker, data_format):
        channel = 3
        x_shape = [2, 2, 6, channel]
        scale_shape = [channel]
        x_val = np.random.random_sample(x_shape).astype(np.float32)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        offset_val = np.random.random_sample(scale_shape).astype(np.float32)
        mean_val = np.random.random_sample(scale_shape).astype(np.float32)
        var_val = np.random.random_sample(scale_shape).astype(np.float32)
        epsilon = 0.001
        data_format_src = "NHWC"
        y_ref, mean_ref, var_ref = self._reference_training(
            x_val, scale_val, offset_val, epsilon, data_format_src)

        with self.cached_session() as sess, self.test_scope():
            # To avoid constant folding
            x_val_converted = test_utils.ConvertBetweenDataFormats(
                x_val, data_format_src, data_format)
            y_ref_converted = test_utils.ConvertBetweenDataFormats(
                y_ref, data_format_src, data_format)

            t_val = array_ops.placeholder(np.float32,
                                          shape=x_val_converted.shape,
                                          name="x")
            scale = array_ops.placeholder(np.float32,
                                          shape=scale_shape,
                                          name="scale")
            offset = array_ops.placeholder(np.float32,
                                           shape=scale_shape,
                                           name="offset")
            y, mean, var = nn.fused_batch_norm(t_val,
                                               scale,
                                               offset,
                                               mean=None,
                                               variance=None,
                                               epsilon=epsilon,
                                               data_format=data_format,
                                               is_training=True)
            # Check gradient.
            if use_gradient_checker:
                err = gradient_checker.compute_gradient_error(
                    t_val,
                    x_val_converted.shape,
                    y,
                    x_val_converted.shape,
                    extra_feed_dict={
                        t_val: x_val_converted,
                        scale: scale_val,
                        offset: offset_val
                    })
                self.assertLess(err, 1e-3)

            y_val, mean_val, var_val = sess.run([y, mean, var], {
                t_val: x_val_converted,
                scale: scale_val,
                offset: offset_val
            })
            self.assertAllClose(mean_val, mean_ref, atol=1e-3)
            self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
            self.assertAllClose(var_val, var_ref, atol=1e-3)
Code Example #6
 def loop_fn(i):
   with g:
     x1 = array_ops.gather(x, i)
     outputs = nn.fused_batch_norm(
         x1,
         scale,
         offset,
         mean=mean,
         variance=variance,
         epsilon=0.01,
         data_format=data_format,
         is_training=is_training)
     outputs = list(outputs)
     # We only test the first value of outputs when is_training is
     # False. It looks like CPU and GPU have different outputs for
     # batch_mean and batch_variance for this case.
     if not is_training:
       outputs[1] = constant_op.constant(0.)
       outputs[2] = constant_op.constant(0.)
     loss = nn.l2_loss(outputs[0])
   if is_training:
     gradients = g.gradient(loss, [x1, scale, offset])
   else:
     gradients = [constant_op.constant(0.)] * 3
   return outputs + gradients
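`loop_fn` above is written to be mapped over the batch index `i`, as the pfor/vectorized-map test utilities in TensorFlow typically do. A hedged sketch of that driving pattern with the public API (the per-example function here is only illustrative):

import tensorflow as tf

x = tf.random.normal([8, 4])

def per_example(xi):
    return tf.nn.l2_loss(xi)                 # any per-example computation

losses = tf.vectorized_map(per_example, x)   # shape [8], one value per example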
Code Example #7
            def my_graph(a):
                with ops.device("/device:IPU:0"):
                    with variable_scope.variable_scope("", use_resource=True):

                        beta = variable_scope.get_variable(
                            "x",
                            dtype=np.float32,
                            shape=[4],
                            initializer=init_ops.constant_initializer(0.0))
                        gamma = variable_scope.get_variable(
                            "y",
                            dtype=np.float32,
                            shape=[4],
                            initializer=init_ops.constant_initializer(1.0))

                        b_mean, b_var = nn.moments(a, [0, 1, 2],
                                                   name='moments')

                        normed = nn.fused_batch_norm(a,
                                                     gamma,
                                                     beta,
                                                     b_mean,
                                                     b_var,
                                                     is_training=False)
                        return normed
Code Example #8
 def _ComputeBN(self, inputs, paddings, gamma, beta, norm_mean,
                norm_variance):
     p = self.params
     with tf.control_dependencies([
             py_utils.assert_greater_equal(norm_variance,
                                           tf.zeros_like(norm_variance)),
             py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                         tf.shape(norm_mean)),
             py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                         tf.shape(norm_variance)),
     ]):
         if p.use_fused_batch_norm_for_eval and (self.do_eval
                                                 or p.freeze_bn_stats):
             bn_output, _, _ = nn.fused_batch_norm(inputs,
                                                   gamma,
                                                   beta,
                                                   norm_mean,
                                                   norm_variance,
                                                   self._epsilon,
                                                   is_training=False)
         else:
             bn_output = tf.nn.batch_normalization(inputs, norm_mean,
                                                   norm_variance, beta,
                                                   gamma, self._epsilon)
         if p.set_padded_output_to_zero:
             bn_output = py_utils.ApplyPadding(paddings, bn_output)
     return bn_output
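The branch above assumes that the fused inference path and plain `tf.nn.batch_normalization` agree when given the same statistics. A minimal sketch of that equivalence (assumes TensorFlow 2.x eager execution):

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.random_sample([2, 4, 4, 3]).astype(np.float32))
gamma, beta, eps = tf.ones([3]), tf.zeros([3]), 1e-3
mean = tf.reduce_mean(x, axis=[0, 1, 2])
var = tf.math.reduce_variance(x, axis=[0, 1, 2])

fused, _, _ = tf.nn.fused_batch_norm(
    x, gamma, beta, mean=mean, variance=var, epsilon=eps, is_training=False)
unfused = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
print(np.allclose(fused.numpy(), unfused.numpy(), atol=1e-5))  # True, up to tolerance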
Code Example #9
 def _fused_batch_norm_training():
   return nn.fused_batch_norm(
       inputs,
       gamma,
       beta,
       epsilon=self.epsilon,
       data_format=self._data_format)
Code Example #10
  def _testLearning(self, use_gradient_checker, data_format):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001
    data_format_src = "NHWC"
    # When in training mode, fused_batchnorm applies an implicit Bessel's
    # correction. So we have to use the corrected variance here, as well.
    y_ref, mean_ref, _, var_ref_corr = self._reference_training(
        x_val, scale_val, offset_val, epsilon, data_format_src)

    with self.cached_session() as sess, self.test_scope():
      # To avoid constant folding
      x_val_converted = test_utils.ConvertBetweenDataFormats(
          x_val, data_format_src, data_format)
      y_ref_converted = test_utils.ConvertBetweenDataFormats(
          y_ref, data_format_src, data_format)

      t_val = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      y, mean, var = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=None,
          variance=None,
          epsilon=epsilon,
          data_format=data_format,
          is_training=True)
      # Check gradient.
      if use_gradient_checker:
        err = gradient_checker.compute_gradient_error(
            t_val,
            x_val_converted.shape,
            y,
            x_val_converted.shape,
            extra_feed_dict={
                t_val: x_val_converted,
                scale: scale_val,
                offset: offset_val
            })
        self.assertLess(err, 1e-3)

      y_val, mean_val, var_val = sess.run([y, mean, var], {
          t_val: x_val_converted,
          scale: scale_val,
          offset: offset_val
      })
      self.assertAllClose(mean_val, mean_ref, atol=1e-3)
      self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
      self.assertAllClose(var_val, var_ref_corr, atol=1e-3)
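A hedged NumPy sketch of the Bessel's correction mentioned in the comment above: the batch variance reported in training mode is the unbiased (sample) variance, i.e. the population variance rescaled by n / (n - 1), where n is the number of elements reduced per channel:

import numpy as np

x = np.random.random_sample([2, 2, 6, 3]).astype(np.float32)   # NHWC
n = x.shape[0] * x.shape[1] * x.shape[2]                        # elements per channel
population_var = x.var(axis=(0, 1, 2))                          # used for normalization
corrected_var = population_var * n / (n - 1)                    # what the test expects back
print(np.allclose(corrected_var, x.var(axis=(0, 1, 2), ddof=1)))  # True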
Code Example #11
def _model_with_second_port():
    random_seed.set_random_seed(0)
    x = random_ops.truncated_normal([2, 5, 5, 4], seed=0)
    scale = constant_op.constant(0.1, shape=[4])
    offset = constant_op.constant(0.3, shape=[4])
    y, mean, _ = nn.fused_batch_norm(x, scale, offset)
    mul = math_ops.add(y, mean)
    output = array_ops.identity(mul)
    return output
Code Example #12
 def _fused_batch_norm_inference():
     return nn.fused_batch_norm(inputs,
                                gamma,
                                beta,
                                mean=self.moving_mean,
                                variance=self.moving_variance,
                                epsilon=self.epsilon,
                                is_training=False,
                                data_format=self._data_format)
Code Example #13
def my_batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
    """Applies batch normalization on x given mean, var, beta and gamma.
    I.e. returns:
    `output = (x - mean) / sqrt(var + epsilon) * gamma + beta`
    Arguments:
        x: Input tensor or variable.
        mean: Mean of batch.
        var: Variance of batch.
        beta: Tensor with which to center the input.
        gamma: Tensor by which to scale the input.
        axis: Integer, the axis that should be normalized.
            (typically the features axis).
        epsilon: Fuzz factor.
    Returns:
        A tensor.
    """

    if K.ndim(x) == 4:
        # The CPU implementation of `fused_batch_norm` only supports NHWC
        if axis == 1 or axis == -3:
            tf_data_format = 'NCHW'
        elif axis == 3 or axis == -1:
            tf_data_format = 'NHWC'
        else:
            tf_data_format = None

        if (tf_data_format == 'NHWC'
                or tf_data_format == 'NCHW' and _has_nchw_support()):
            # The mean / var / beta / gamma tensors may be broadcasted
            # so they may have extra axes of size 1, which should be squeezed.
            if K.ndim(mean) > 1:
                mean = array_ops.reshape(mean, [-1])
            if K.ndim(var) > 1:
                var = array_ops.reshape(var, [-1])
            if beta is None:
                beta = zeros_like(mean)
            elif K.ndim(beta) > 1:
                beta = array_ops.reshape(beta, [-1])
            if gamma is None:
                gamma = ones_like(mean)
            elif K.ndim(gamma) > 1:
                gamma = array_ops.reshape(gamma, [-1])
            y, _, _ = nn.fused_batch_norm(x,
                                          gamma,
                                          beta,
                                          epsilon=epsilon,
                                          mean=mean,
                                          variance=var,
                                          data_format=tf_data_format,
                                          is_training=False)
            return y

    return tf.map_fn(
        lambda xx: nn.batch_normalization(xx, mean, var, beta, gamma, epsilon),
        x)
Code Example #14
def template(x_shape=[2, 3, 4, 5], data_format="NHWC", description: str = ""):
    from tensorflow.python.ops import nn
    x = tf.placeholder(np.float32, x_shape)
    scale = tf.placeholder(
        np.float32, x_shape[-1] if data_format == "NHWC" else x_shape[1])
    bias = tf.placeholder(np.float32,
                          x_shape[-1] if data_format == "NHWC" else x_shape[1])
    mean = tf.placeholder(np.float32,
                          x_shape[-1] if data_format == "NHWC" else x_shape[1])
    variance = tf.placeholder(
        np.float32, x_shape[-1] if data_format == "NHWC" else x_shape[1])
    y, _, _ = nn.fused_batch_norm(x,
                                  scale,
                                  bias,
                                  mean,
                                  variance,
                                  data_format=data_format,
                                  is_training=False)

    vx = np.random.rand(*x_shape).astype(np.float32)
    vs = np.random.rand(
        *[x_shape[-1] if data_format == "NHWC" else x_shape[1]]).astype(
            np.float32)
    vb = np.random.rand(
        *[x_shape[-1] if data_format == "NHWC" else x_shape[1]]).astype(
            np.float32)
    vm = np.random.rand(
        *[x_shape[-1] if data_format == "NHWC" else x_shape[1]]).astype(
            np.float32)
    vv = np.random.rand(
        *[x_shape[-1] if data_format == "NHWC" else x_shape[1]]).astype(
            np.float32)
    with tf.Session() as sess:
        vy, = sess.run([y], {
            x: vx,
            scale: vs,
            bias: vb,
            mean: vm,
            variance: vv
        })
        graph = TensorFlowConverter(sess, batch_size=2).convert(
            [x, scale, bias, mean, variance], [y])

    generate_kernel_test_case(
        description=f"[TensorFlow] FusedBatchNorm {description}",
        graph=graph,
        inputs={
            graph.inputs[0]: vx,
            graph.inputs[1]: vs,
            graph.inputs[2]: vb,
            graph.inputs[3]: vm,
            graph.inputs[4]: vv
        },
        expected={graph.outputs[0]: vy},
    )
Code Example #15
File: normalization.py  Project: zys-123/tensorflow
 def _fused_batch_norm_training():
   return nn.fused_batch_norm(
       inputs,
       gamma,
       beta,
       mean=self.moving_mean,
       variance=_maybe_add_or_remove_bessels_correction(
           self.moving_variance, remove=False),
       epsilon=self.epsilon,
       is_training=True,
       data_format=self._data_format,
       exponential_avg_factor=exponential_avg_factor)
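`exponential_avg_factor`, used above, folds the running-average update into the op: with prior statistics supplied and `is_training=True`, the returned mean/variance are a running average rather than the raw batch statistics. A small sketch of the documented update rule (illustrative numbers):

f = 0.05                                   # exponential_avg_factor
old_mean, batch_mean = 0.0, 10.0
running_mean = (1.0 - f) * old_mean + f * batch_mean
print(running_mean)                        # 0.5; f = 1.0 reduces to the plain batch mean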
Code Example #16
            def model(x, y, z):
                scale = gen_array_ops.broadcast_to(z, shape=[65536])
                offset = scale
                b_mean, b_var = nn.moments(x, [0, 1, 2], name='moments')
                a = nn.fused_batch_norm(x,
                                        scale,
                                        offset,
                                        b_mean,
                                        b_var,
                                        1e-3,
                                        is_training=False,
                                        name="a")
                b = nn.fused_batch_norm(y,
                                        scale,
                                        offset,
                                        b_mean,
                                        b_var,
                                        1e-3,
                                        is_training=False,
                                        name="b")

                return a[0] + b[0]
Code Example #17
 def _fused_batch_norm_training():
     outputs, mean, variance = nn.fused_batch_norm(
             inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
     if renorm:
         moving_inv = math_ops.rsqrt(moving_variance + epsilon)
         r = tf.stop_gradient(tf.clip_by_value(tf.sqrt(variance + epsilon) * moving_inv,
                                                 1/RMAX,
                                                 RMAX))
         d = tf.stop_gradient(tf.clip_by_value((mean - moving_mean) * moving_inv,
                                                 -DMAX,
                                                 DMAX))
         outputs = outputs * r + d
     return outputs, mean, variance
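A hedged NumPy sketch of the batch renormalization correction applied above (Ioffe, 2017): the normalized output is rescaled by r and shifted by d, both derived from the ratio between batch and moving statistics, clipped, and treated as constants in the backward pass (illustrative values):

import numpy as np

RMAX, DMAX, eps = 3.0, 5.0, 1e-3           # hypothetical clipping bounds
batch_mean, batch_var = 0.5, 2.0           # hypothetical batch statistics
moving_mean, moving_var = 0.4, 1.5         # hypothetical moving statistics
moving_inv = 1.0 / np.sqrt(moving_var + eps)
r = np.clip(np.sqrt(batch_var + eps) * moving_inv, 1.0 / RMAX, RMAX)
d = np.clip((batch_mean - moving_mean) * moving_inv, -DMAX, DMAX)
# outputs = normalized * r + d   (no gradient flows through r and d)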
Code Example #18
  def _testLearning(self, use_gradient_checker):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)

    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    with self.test_session() as sess, self.test_scope():
      # To avoid constant folding
      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      epsilon = 0.001
      y, mean, var = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=None,
          variance=None,
          epsilon=epsilon,
          data_format=data_format,
          is_training=True)
      # Check gradient.
      if use_gradient_checker:
        err = gradient_checker.compute_gradient_error(
            t_val,
            x_shape,
            y,
            x_shape,
            extra_feed_dict={
                t_val: x_val,
                scale: scale_val,
                offset: offset_val
            })
        self.assertLess(err, 1e-3)

      y_val, mean_val, var_val = sess.run(
          [y, mean, var], {t_val: x_val,
                           scale: scale_val,
                           offset: offset_val})
      y_ref, mean_ref, var_ref = self._reference_training(
          x_val, scale_val, offset_val, epsilon, data_format)
      self.assertAllClose(mean_val, mean_ref, atol=1e-3)
      self.assertAllClose(y_val, y_ref, atol=1e-3)
      self.assertAllClose(var_val, var_ref, atol=1e-3)
Code Example #19
File: bn_layers.py  Project: UofT-EcoSystem/lingvo
    def FProp(self, theta, inputs, paddings=None):
        """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
        input tensor.

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
        p = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)
        with tf.name_scope(p.name):
            norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
                theta, inputs, paddings)
            with tf.control_dependencies([
                    py_utils.assert_greater_equal(
                        norm_variance, tf.zeros_like(norm_variance)),
                    py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                                tf.shape(norm_mean)),
                    py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                                tf.shape(norm_variance)),
            ]):
                if p.use_fused_batch_norm_for_eval and self.do_eval:
                    bn_output, _, _ = nn.fused_batch_norm(inputs,
                                                          gamma,
                                                          beta,
                                                          norm_mean,
                                                          norm_variance,
                                                          self._epsilon,
                                                          is_training=False)
                else:
                    bn_output = tf.nn.batch_normalization(
                        inputs, norm_mean, norm_variance, beta, gamma,
                        self._epsilon)

                if p.set_padded_output_to_zero:
                    bn_output *= 1.0 - paddings

            return bn_output
Code Example #20
    def testInference(self, data_format):
        channel = 3
        x_shape = [2, 2, 6, channel]
        scale_shape = [channel]
        x_val = np.random.random_sample(x_shape).astype(np.float32)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        offset_val = np.random.random_sample(scale_shape).astype(np.float32)
        epsilon = 0.001
        exponential_avg_factor = 1.0
        data_format_src = "NHWC"
        y_ref, mean_ref, var_ref, _ = self._reference_training(
            x_val, scale_val, offset_val, None, None, epsilon,
            exponential_avg_factor, data_format_src)

        with self.session() as sess, self.test_scope():
            # To avoid constant folding
            x_val_converted = test_utils.ConvertBetweenDataFormats(
                x_val, data_format_src, data_format)
            y_ref_converted = test_utils.ConvertBetweenDataFormats(
                y_ref, data_format_src, data_format)

            t_val = array_ops.placeholder(np.float32,
                                          shape=x_val_converted.shape,
                                          name="x")
            scale = array_ops.placeholder(np.float32,
                                          shape=scale_shape,
                                          name="scale")
            offset = array_ops.placeholder(np.float32,
                                           shape=scale_shape,
                                           name="offset")
            y, mean, variance = nn.fused_batch_norm(t_val,
                                                    scale,
                                                    offset,
                                                    mean=mean_ref,
                                                    variance=var_ref,
                                                    epsilon=epsilon,
                                                    data_format=data_format,
                                                    is_training=False)

            y_val, _, _ = sess.run([y, mean, variance], {
                t_val: x_val_converted,
                scale: scale_val,
                offset: offset_val
            })
            self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
Code Example #21
    def testBatchNormalizeFused(self):
        x = array_ops.placeholder(np.float32, [4, 64, 64, 4], name="a")

        with ops.device("/device:IPU:0"):
            with variable_scope.variable_scope("", use_resource=True):

                beta = variable_scope.get_variable(
                    "x",
                    dtype=np.float32,
                    shape=[4],
                    initializer=init_ops.constant_initializer(0.0))
                gamma = variable_scope.get_variable(
                    "y",
                    dtype=np.float32,
                    shape=[4],
                    initializer=init_ops.constant_initializer(1.0))

                b_mean, b_var = nn.moments(x, [0, 1, 2], name='moments')

                normed = nn.fused_batch_norm(x,
                                             gamma,
                                             beta,
                                             b_mean,
                                             b_var,
                                             is_training=False)

        with ops.device('cpu'):
            report = gen_ipu_ops.ipu_event_trace()

        tu.configure_ipu_system()

        with tu.ipu_session() as sess:
            sess.run(report)

            sess.run(variables.global_variables_initializer())
            result, _, _ = sess.run(normed, {x: np.zeros([4, 64, 64, 4])})
            self.assertAllClose(result, np.zeros([4, 64, 64, 4]))

            rep = sess.run(report)
            s = tu.extract_all_strings_from_event_trace(rep)
            cs = tu.get_compute_sets_from_report(s)

            bl = ['*convert*/Cast*']
            self.assertTrue(tu.check_compute_sets_not_in_blacklist(cs, bl))
Code Example #22
 def call(self, inputs, training=None):
     training = self._get_training_value(training)
     if self.subdivisions <= 1 or self.subdivisions is None:
         return super().call(inputs, training=training)
     else:
         if self.renorm is False and training is False and self.fused:
             # outputs = self._fused_batch_norm(inputs, training=False)
             beta = self.beta if self.center else self._beta_const
             gamma = self.gamma if self.scale else self._gamma_const
             outputs, mean, variance = nn.fused_batch_norm(
                 inputs,
                 gamma,
                 beta,
                 mean=self.moving_mean,
                 variance=self.moving_variance,
                 epsilon=self.epsilon,
                 is_training=False,
                 data_format=self._data_format)
             return outputs
         return self._subdiv_batch_norm(inputs, training=training)
Code Example #23
                    def my_graph(a):
                        beta = variable_scope.get_variable(
                            "x",
                            dtype=np.float16,
                            shape=[4],
                            initializer=init_ops.constant_initializer(0.0))
                        gamma = variable_scope.get_variable(
                            "y",
                            dtype=np.float16,
                            shape=[4],
                            initializer=init_ops.constant_initializer(1.0))

                        b_mean, b_var = nn.moments(a, [0, 1, 2],
                                                   name='moments')

                        normed = nn.fused_batch_norm(a,
                                                     gamma,
                                                     beta,
                                                     b_mean,
                                                     b_var,
                                                     is_training=False)
                        return normed
Code Example #24
File: normalization.py  Project: zys-123/tensorflow
  def call(self, inputs):
    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)

    # Broadcasting only necessary for norm where the axis is not just
    # the last dimension
    broadcast_shape = [1] * ndims
    for dim in self.axis:
      broadcast_shape[dim] = input_shape.dims[dim].value
    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          self.axis != [ndims - 1]):
        return array_ops.reshape(v, broadcast_shape)
      return v

    if not self._fused:
      input_dtype = inputs.dtype
      if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32':
        # If mixed precision is used, cast inputs to float32 so that this is at
        # least as numerically stable as the fused version.
        inputs = math_ops.cast(inputs, 'float32')

      # Calculate the moments on the last axis (layer activations).
      mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      # Compute layer normalization using the batch_normalization function.
      outputs = nn.batch_normalization(
          inputs,
          mean,
          variance,
          offset=offset,
          scale=scale,
          variance_epsilon=self.epsilon)
      outputs = math_ops.cast(outputs, input_dtype)
    else:
      # Collapse dims before self.axis, and dims in self.axis
      pre_dim, in_dim = (1, 1)
      axis = sorted(self.axis)
      tensor_shape = array_ops.shape(inputs)
      for dim in range(0, ndims):
        dim_tensor = tensor_shape[dim]
        if dim < axis[0]:
          pre_dim = pre_dim * dim_tensor
        else:
          assert dim in axis
          in_dim = in_dim * dim_tensor

      squeezed_shape = [1, pre_dim, in_dim, 1]
      # This fused operation requires reshaped inputs to be NCHW.
      data_format = 'NCHW'

      inputs = array_ops.reshape(inputs, squeezed_shape)

      def _set_const_tensor(val, dtype, shape):
        return array_ops.fill(shape, constant_op.constant(val, dtype=dtype))

      # self.gamma and self.beta have the wrong shape for fused_batch_norm, so
      # we cannot pass them as the scale and offset parameters. Therefore, we
      # create two constant tensors in correct shapes for fused_batch_norm and
      # later construct a separate calculation on the scale and offset.
      scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
      offset = _set_const_tensor(0.0, self.dtype, [pre_dim])

      # Compute layer normalization using the fused_batch_norm function.
      outputs, _, _ = nn.fused_batch_norm(
          inputs,
          scale=scale,
          offset=offset,
          epsilon=self.epsilon,
          data_format=data_format)

      outputs = array_ops.reshape(outputs, tensor_shape)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      if scale is not None:
        outputs = outputs * math_ops.cast(scale, outputs.dtype)
      if offset is not None:
        outputs = outputs + math_ops.cast(offset, outputs.dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    return outputs
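A hedged NumPy sketch of the reshape trick above: after collapsing the leading dimensions into a "channel" axis and the normalized dimensions into a "spatial" axis, per-channel batch-norm statistics in NCHW layout coincide with per-row layer-norm statistics:

import numpy as np

x = np.random.random_sample([2, 5]).astype(np.float32)   # [pre_dim, in_dim]
eps = 1e-3
layer_norm = (x - x.mean(1, keepdims=True)) / np.sqrt(x.var(1, keepdims=True) + eps)

nchw = x.reshape(1, 2, 5, 1)                              # N=1, C=pre_dim, H=in_dim, W=1
mean = nchw.mean(axis=(0, 2, 3), keepdims=True)           # per-channel (batch norm) stats
var = nchw.var(axis=(0, 2, 3), keepdims=True)
batch_norm = (nchw - mean) / np.sqrt(var + eps)
print(np.allclose(batch_norm.reshape(2, 5), layer_norm))  # True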
Code Example #25
def fused_instance_norm(inputs,
                        center=True,
                        scale=True,
                        epsilon=1e-6,
                        activation_fn=None,
                        param_initializers=None,
                        reuse=None,
                        variables_collections=None,
                        outputs_collections=None,
                        trainable=True,
                        data_format=DATA_FORMAT_NHWC,
                        scope=None):
  """Functional interface for the instance normalization layer.

  Reference: https://arxiv.org/abs/1607.08022.

    "Instance Normalization: The Missing Ingredient for Fast Stylization"
    Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky

  Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional initializers for beta, gamma, moving mean and
      moving variance.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If rank or channels dimension of `inputs` is undefined.
  """
  inputs = ops.convert_to_tensor(inputs)
  inputs_shape = inputs.shape
  inputs_rank = inputs.shape.ndims

  if inputs_rank is None:
    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  with variable_scope.variable_scope(
      scope, 'InstanceNorm', [inputs], reuse=reuse) as sc:
    if data_format == DATA_FORMAT_NCHW:
      reduction_axis = 1
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      reduction_axis = inputs_rank - 1
      params_shape_broadcast = None
    moments_axes = list(range(inputs_rank))
    del moments_axes[reduction_axis]
    del moments_axes[0]
    params_shape = inputs_shape[reduction_axis:reduction_axis + 1]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    dtype = inputs.dtype.base_dtype
    if param_initializers is None:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta_initializer = param_initializers.get(
          'beta', init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
      if params_shape_broadcast:
        beta = array_ops.reshape(beta, params_shape_broadcast)
    if scale:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma_initializer = param_initializers.get(
          'gamma', init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)
      if params_shape_broadcast:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    if data_format == DATA_FORMAT_NHWC:
      inputs = array_ops.transpose(inputs, list(range(1, reduction_axis)) + [0, reduction_axis])
    if data_format == DATA_FORMAT_NCHW:
      inputs = array_ops.transpose(inputs, list(range(2, inputs_rank)) + [0, reduction_axis])
    hw, n, c = inputs.shape.as_list()[:-2], inputs.shape[-2].value, inputs.shape[-1].value
    inputs = array_ops.reshape(inputs, [1] + hw + [n * c])
    if inputs.shape.ndims != 4:
        # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [DH, W]
        if inputs.shape.ndims > 4:
            inputs_ndims4_shape = [1, hw[0], -1, n * c]
        else:
            inputs_ndims4_shape = [1, 1, -1, n * c]
        inputs = array_ops.reshape(inputs, inputs_ndims4_shape)
    beta = array_ops.reshape(array_ops.tile(beta[None, :], [n, 1]), [-1])
    gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [n, 1]), [-1])

    outputs, _, _ = nn.fused_batch_norm(
        inputs, gamma, beta, epsilon=epsilon,
        data_format=DATA_FORMAT_NHWC, name='instancenorm')

    outputs = array_ops.reshape(outputs, hw + [n, c])
    if data_format == DATA_FORMAT_NHWC:
      outputs = array_ops.transpose(outputs, [inputs_rank - 2] + list(range(inputs_rank - 2)) + [inputs_rank - 1])
    if data_format == DATA_FORMAT_NCHW:
      outputs = array_ops.transpose(outputs, [inputs_rank - 2, inputs_rank - 1] + list(range(inputs_rank - 2)))

    # if data_format == DATA_FORMAT_NHWC:
    #   inputs = array_ops.transpose(inputs, [0, reduction_axis] + list(range(1, reduction_axis)))
    # inputs_nchw_shape = inputs.shape
    # inputs = array_ops.reshape(inputs, [1, -1] + inputs_nchw_shape.as_list()[2:])
    # if inputs.shape.ndims != 4:
    #     # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [DH, W]
    #     if inputs.shape.ndims > 4:
    #         inputs_ndims4_shape = inputs.shape.as_list()[:2] + [-1, inputs_nchw_shape.as_list()[-1]]
    #     else:
    #         inputs_ndims4_shape = inputs.shape.as_list()[:2] + [1, -1]
    #     inputs = array_ops.reshape(inputs, inputs_ndims4_shape)
    # beta = array_ops.reshape(array_ops.tile(beta[None, :], [inputs_nchw_shape[0].value, 1]), [-1])
    # gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [inputs_nchw_shape[0].value, 1]), [-1])
    #
    # outputs, _, _ = nn.fused_batch_norm(
    #     inputs, gamma, beta, epsilon=epsilon,
    #     data_format=DATA_FORMAT_NCHW, name='instancenorm')
    #
    # outputs = array_ops.reshape(outputs, inputs_nchw_shape)
    # if data_format == DATA_FORMAT_NHWC:
    #   outputs = array_ops.transpose(outputs, [0] + list(range(2, inputs_rank)) + [1])

    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
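A hedged NumPy sketch of the reshape used above: moving the batch axis next to the channel axis and folding the two together turns per-(sample, channel) instance statistics into per-"channel" batch-norm statistics over the spatial dimensions:

import numpy as np

x = np.random.random_sample([2, 4, 4, 3]).astype(np.float32)   # NHWC
instance_mean = x.mean(axis=(1, 2))                             # shape [N, C]
folded = x.transpose(1, 2, 0, 3).reshape(1, 4, 4, 2 * 3)        # [1, H, W, N*C]
batch_mean = folded.mean(axis=(0, 1, 2)).reshape(2, 3)          # per-"channel" statistics
print(np.allclose(instance_mean, batch_mean))                   # True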
Code Example #26
    def _testLearning(self, use_gradient_checker, data_format,
                      exponential_avg_factor):
        channel = 3
        x_shape = [2, 2, 6, channel]
        scale_shape = [channel]
        x_val = np.random.random_sample(x_shape).astype(np.float32)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        offset_val = np.random.random_sample(scale_shape).astype(np.float32)
        mean_val = np.random.random_sample(scale_shape).astype(np.float32)
        var_val_corr = np.random.random_sample(scale_shape).astype(np.float32)
        epsilon = 0.001
        data_format_src = "NHWC"
        # When in training mode, fused_batchnorm applies an implicit Bessel's
        # correction. So we have to use the corrected variance here, as well.
        y_ref, mean_ref, _, var_ref_corr = self._reference_training(
            x_val, scale_val, offset_val, mean_val, var_val_corr, epsilon,
            exponential_avg_factor, data_format_src)

        with self.session() as sess, self.test_scope():
            # To avoid constant folding
            x_val_converted = test_utils.ConvertBetweenDataFormats(
                x_val, data_format_src, data_format)
            y_ref_converted = test_utils.ConvertBetweenDataFormats(
                y_ref, data_format_src, data_format)

            t_val = array_ops.placeholder(np.float32,
                                          shape=x_val_converted.shape,
                                          name="x")
            scale = array_ops.placeholder(np.float32,
                                          shape=scale_shape,
                                          name="scale")
            offset = array_ops.placeholder(np.float32,
                                           shape=scale_shape,
                                           name="offset")
            if exponential_avg_factor == 1.0:
                old_mean = None
                old_var = None
            else:
                old_mean = array_ops.placeholder(np.float32,
                                                 shape=scale_shape,
                                                 name="old_mean")
                old_var = array_ops.placeholder(np.float32,
                                                shape=scale_shape,
                                                name="old_var")
            y, mean, var = nn.fused_batch_norm(
                t_val,
                scale,
                offset,
                mean=old_mean,
                variance=old_var,
                epsilon=epsilon,
                exponential_avg_factor=exponential_avg_factor,
                data_format=data_format,
                is_training=True)
            if exponential_avg_factor == 1.0:
                feed_dict = {
                    t_val: x_val_converted,
                    scale: scale_val,
                    offset: offset_val,
                }
            else:
                feed_dict = {
                    t_val: x_val_converted,
                    scale: scale_val,
                    offset: offset_val,
                    old_mean: mean_val,
                    old_var: var_val_corr
                }
            # Check gradient.
            if use_gradient_checker:
                err = gradient_checker.compute_gradient_error(
                    t_val,
                    x_val_converted.shape,
                    y,
                    x_val_converted.shape,
                    extra_feed_dict=feed_dict)
                self.assertLess(err, 1e-3)

            y_tf, mean_tf, var_tf = sess.run([y, mean, var], feed_dict)
            self.assertAllClose(y_tf, y_ref_converted, atol=1e-3)
            self.assertAllClose(mean_tf, mean_ref, atol=1e-3)
            self.assertAllClose(var_tf, var_ref_corr, atol=1e-3)
Code Example #27
  def _testLearning(self, use_gradient_checker, data_format):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001
    data_format_src = "NHWC"
    y_ref, mean_ref, var_ref = self._reference_training(
        x_val, scale_val, offset_val, epsilon, data_format_src)

    # TODO(b/110530713): Support data format HWCN on GPU
    if self.device == "XLA_GPU" and data_format == "HWCN":
      self.skipTest("GPU does not support data format HWCN.")

    with self.test_session() as sess, self.test_scope():
      # To avoid constant folding
      x_val_converted = test_utils.ConvertBetweenDataFormats(
          x_val, data_format_src, data_format)
      y_ref_converted = test_utils.ConvertBetweenDataFormats(
          y_ref, data_format_src, data_format)

      t_val = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      y, mean, var = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=None,
          variance=None,
          epsilon=epsilon,
          data_format=data_format,
          is_training=True)
      # Check gradient.
      if use_gradient_checker:
        err = gradient_checker.compute_gradient_error(
            t_val,
            x_val_converted.shape,
            y,
            x_val_converted.shape,
            extra_feed_dict={
                t_val: x_val_converted,
                scale: scale_val,
                offset: offset_val
            })
        self.assertLess(err, 1e-3)

      y_val, mean_val, var_val = sess.run([y, mean, var], {
          t_val: x_val_converted,
          scale: scale_val,
          offset: offset_val
      })
      self.assertAllClose(mean_val, mean_ref, atol=1e-3)
      self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
      self.assertAllClose(var_val, var_ref, atol=1e-3)
Code Example #28
def fused_layer_norm(inputs,
                     center=True,
                     scale=True,
                     activation_fn=None,
                     reuse=None,
                     variables_collections=None,
                     outputs_collections=None,
                     trainable=True,
                     begin_norm_axis=1,
                     begin_params_axis=-1,
                     scope=None,
                     use_fused_batch_norm=False):
    with tf.compat.v1.variable_scope(scope, 'LayerNorm', [inputs],
                                     reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        inputs_shape = inputs.shape
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        dtype = inputs.dtype.base_dtype
        if begin_norm_axis < 0:
            begin_norm_axis = inputs_rank + begin_norm_axis
        if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
            raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
                             'must be < rank(inputs) (%d)' %
                             (begin_params_axis, begin_norm_axis, inputs_rank))
        params_shape = inputs_shape[begin_params_axis:]
        if not params_shape.is_fully_defined():
            raise ValueError(
                'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
                (inputs.name, begin_params_axis, inputs_shape))
        # Allocate parameters for the beta and gamma of the normalization.
        beta, gamma = None, None
        if center:
            beta_collections = utils.get_variable_collections(
                variables_collections, 'beta')
            beta = variables.model_variable(
                'beta',
                shape=params_shape,
                dtype=dtype,
                initializer=init_ops.zeros_initializer(),
                collections=beta_collections,
                trainable=trainable)
        if scale:
            gamma_collections = utils.get_variable_collections(
                variables_collections, 'gamma')
            gamma = variables.model_variable(
                'gamma',
                shape=params_shape,
                dtype=dtype,
                initializer=init_ops.ones_initializer(),
                collections=gamma_collections,
                trainable=trainable)
        if use_fused_batch_norm:
            # get static TensorShape if fully defined,
            # otherwise retrieve shape tensor
            norm_shape = inputs.shape[begin_norm_axis:]
            if norm_shape.is_fully_defined():
                bn_shape = [1, -1, 1, numpy.prod(norm_shape.as_list())]
            else:
                norm_shape = tf.shape(input=inputs)[begin_norm_axis:]
                bn_shape = [1, -1, 1, tf.reduce_prod(input_tensor=norm_shape)]
            if inputs.get_shape().is_fully_defined():
                outputs_shape = inputs.get_shape()
            else:
                outputs_shape = tf.shape(input=inputs)
            inputs = array_ops.reshape(inputs, bn_shape)
            if inputs.get_shape().is_fully_defined():
                # static inputs TensorShape fully defined after reshape.
                ones = array_ops.ones(inputs.get_shape()[1],
                                      dtype=dtypes.float32)
                zeros = array_ops.zeros(inputs.get_shape()[1],
                                        dtype=dtypes.float32)
            else:
                # static inputs TensorShape NOT fully defined after reshape.
                # must use dynamic shape, which means these input tensors
                # have to be created at runtime, which causes a slowdown.
                scale_shape = tf.shape(input=inputs)[1]
                ones = array_ops.ones(scale_shape, dtype=dtypes.float32)
                zeros = array_ops.zeros(scale_shape, dtype=dtypes.float32)
            outputs, mean, variance = nn.fused_batch_norm(inputs,
                                                          ones,
                                                          zeros,
                                                          epsilon=1e-4,
                                                          data_format="NCHW")
            outputs = array_ops.reshape(outputs, outputs_shape)
            if center and scale:
                outputs = outputs * gamma + beta
            elif center:
                outputs = outputs + beta
            elif scale:
                outputs = outputs * gamma
        else:
            # Calculate the moments on the last axis (layer activations).
            norm_axes = list(range(begin_norm_axis, inputs_rank))
            mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)
            # Compute layer normalization using the batch_normalization function.
            variance_epsilon = 1e-4
            outputs = nn.batch_normalization(inputs,
                                             mean,
                                             variance,
                                             offset=beta,
                                             scale=gamma,
                                             variance_epsilon=variance_epsilon)
            outputs.set_shape(inputs_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections, sc.name,
                                           outputs)