Example no. 1
    def _build(self, inputs):
        """Connects the LayerNorm module into the graph.

        Args:
          inputs: a Tensor of shape `[batch_size, layer_dim]`.

        Returns:
          normalized: layer normalized outputs with same shape as inputs.

        Raises:
          base.NotSupportedError: If `inputs` has a data type of `tf.float16`.
        """

        if inputs.dtype == tf.float16:
            raise base.NotSupportedError(
                "LayerNorm does not support `tf.float16`, insufficient "
                "precision for calculating sufficient statistics."
            )

        if inputs.get_shape().ndims != 2:
            raise base.NotSupportedError(
                "Layer normalization expects inputs of rank 2."
                " Got inputs of rank {}.".format(inputs.get_shape().ndims)
            )

        hidden_size = inputs.get_shape()[1].value

        if self.GAMMA not in self._initializers:
            self._initializers[self.GAMMA] = create_gamma_initializer()
        self._gamma = tf.get_variable(
            self.GAMMA,
            shape=[hidden_size],
            dtype=inputs.dtype,
            initializer=self._initializers[self.GAMMA],
            partitioner=self._partitioners.get(self.GAMMA),
            regularizer=self._regularizers.get(self.GAMMA),
        )

        if self.BETA not in self._initializers:
            self._initializers[self.BETA] = create_beta_initializer()
        self._beta = tf.get_variable(
            self.BETA,
            shape=[hidden_size],
            dtype=inputs.dtype,
            initializer=self._initializers[self.BETA],
            partitioner=self._partitioners.get(self.BETA),
            regularizer=self._regularizers.get(self.BETA),
        )

        mean, var = tf.nn.moments(inputs, [1], keep_dims=True)

        normalized = tf.nn.batch_normalization(inputs, mean, var, self._beta,
                                               self._gamma, self._eps)
        return normalized
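
A minimal standalone sketch of what this `_build` computes, using plain TF1-style ops with gamma fixed to 1 and beta to 0 (the `tensorflow.compat.v1` import and the toy input are assumptions, not part of the example above):

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Hypothetical rank-2 input of shape [batch_size, layer_dim].
x = tf.constant(np.random.randn(4, 16), dtype=tf.float32)

# Same statistics as the module: per-example mean/variance over the layer dim.
mean, var = tf.nn.moments(x, [1], keepdims=True)
eps = 1e-5

# With gamma=1 and beta=0, tf.nn.batch_normalization reduces to this expression.
normalized = (x - mean) * tf.rsqrt(var + eps)

with tf.Session() as sess:
    out = sess.run(normalized)
    print(out.mean(axis=1), out.std(axis=1))  # each row: ~0 mean, ~unit std
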
Example no. 2
    def _build(self, input_batch, is_training, test_local_stats=False):
        """Connects the BatchNormV2 module into the graph.

        Args:
          input_batch: A Tensor whose rank equals `len(data_format)`.
          is_training: A boolean to indicate if the module should be connected in
            training mode, meaning the moving averages are updated. Can be a Tensor.
          test_local_stats: A boolean to indicate if local batch statistics should
            be used when `is_training=False`. If not, moving averages are used.
            By default `False`. Can be a Tensor.

        Returns:
          A tensor with the same shape as `input_batch`.

        Raises:
          base.IncompatibleShapeError: If `data_format` is not valid for the
            input shape.
          base.NotSupportedError: If `input_batch` has a data type of `tf.bfloat16`.
        """
        input_shape = input_batch.get_shape()

        if not self._data_format:
            if len(input_shape) == 2:
                self._data_format = "NC"
            elif len(input_shape) == 3:
                self._data_format = "NWC"
            elif len(input_shape) == 4:
                self._data_format = "NHWC"
            elif len(input_shape) == 5:
                self._data_format = "NDHWC"
            else:
                raise base.IncompatibleShapeError(
                    "Input shape {} has too many or too few dimensions.".
                    format(input_shape))

        self._channel_index = self._data_format.index("C")
        # Materialize the range as a list so the channel axis can be deleted below.
        self._axis = list(range(len(self._data_format)))
        del self._axis[self._channel_index]

        if len(self._data_format) != len(input_shape):
            raise base.IncompatibleShapeError(
                "Incorrect data format {} for input shape {}.".format(
                    self._data_format, input_shape))

        dtype = input_batch.dtype.base_dtype
        if self._fused and dtype == tf.bfloat16:
            raise base.NotSupportedError(
                "Fused batch norm does not support tf.bfloat16.")
        # Maintain moving averages at a minimum precision of tf.float32.
        stat_dtype = (
            tf.float32 if dtype in [tf.float16, tf.bfloat16] else dtype)

        self._num_channels = int(input_shape[self._channel_index])
        if self._channel_index == 1:
            self._image_shape = [int(x) for x in input_shape[2:]]
        else:
            self._image_shape = [int(x) for x in input_shape[1:-1]]

        self._expanded_mean_shape = [1] * len(input_shape)
        self._expanded_mean_shape[self._channel_index] = self._num_channels

        use_batch_stats = is_training | test_local_stats

        mean, variance = self._build_statistics(input_batch, use_batch_stats,
                                                stat_dtype)

        # Sets up optional gamma and beta parameters
        self._build_scale_offset(dtype)
        # Sets up the batch normalization op.
        out, mean, variance = self._batch_norm_op(input_batch, mean, variance,
                                                  use_batch_stats, stat_dtype)
        # Sets up the update op.
        update_ops = self._build_update_ops(mean, variance, is_training)

        # Put update ops in the update ops collection if given, otherwise add as
        # control dependencies of the output.
        if update_ops:
            if self._update_ops_collection:
                for update_op in update_ops:
                    tf.add_to_collection(self._update_ops_collection,
                                         update_op)
            else:
                with tf.control_dependencies(update_ops):
                    out = tf.identity(out)

        return out
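
The `use_batch_stats = is_training | test_local_stats` line above relies on `|` dispatching to `tf.logical_or` when either operand is a boolean Tensor, so the same expression works for static Python booleans and for graph inputs. A small standalone sketch (the placeholder and session setup are illustrative assumptions):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

is_training = tf.placeholder(tf.bool, shape=[])  # graph-level flag
test_local_stats = False                         # static Python flag

# `|` becomes tf.logical_or because `is_training` is a bool Tensor.
use_batch_stats = is_training | test_local_stats

with tf.Session() as sess:
    print(sess.run(use_batch_stats, {is_training: True}))   # True
    print(sess.run(use_batch_stats, {is_training: False}))  # False
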
Example no. 3
  def _build(self, input_batch, is_training, test_local_stats=True):
    """Connects the BatchNorm module into the graph.

    Args:
      input_batch: A Tensor of arbitrary dimension. By default, the final
        dimension is not reduced over when computing the minibatch statistics.
      is_training: A boolean to indicate if the module should be connected in
        training mode, meaning the moving averages are updated. Can be a Tensor.
      test_local_stats: A boolean to indicate if local batch statistics should
        be used when `is_training=False`. If not, moving averages are used.
        By default `True`. Can be a Tensor.

    Returns:
      A tensor with the same shape as `input_batch`.

    Raises:
      base.IncompatibleShapeError: If `axis` is not valid for the
        input shape or has negative entries.
      base.NotSupportedError: If `input_batch` has a data type of `tf.float16`.
    """
    input_shape = input_batch.get_shape()

    if self._axis is not None:
      if len(self._axis) > len(input_shape):
        raise base.IncompatibleShapeError(
            "Too many indices specified in axis: len({}) > len({}).".format(
                self._axis, input_shape))

      if max(self._axis) >= len(input_shape):
        raise base.IncompatibleShapeError(
            "One or more index in axis is too large for "
            "input shape: {} >= {:d}.".format(self._axis, len(input_shape)))

      if min(self._axis) < 0:
        raise base.IncompatibleShapeError(
            "Indices in axis must be non-negative: {} < 0.".format(
                self._axis))

      axis = self._axis
    else:
      # Reduce over all dimensions except the last.
      axis = tuple(range(len(input_shape))[:-1])

    # See following for important note on accuracy for dtype=tf.float16
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_impl.py#L63
    dtype = input_batch.dtype
    if dtype == tf.float16:
      raise base.NotSupportedError(
          "BatchNorm does not support `tf.float16`, insufficient "
          "precision for calculating sufficient statistics.")

    self._mean_shape = input_batch.get_shape().as_list()
    for index in axis:
      self._mean_shape[index] = 1

    use_batch_stats = is_training | test_local_stats

    mean, variance = self._build_statistics(input_batch, axis,
                                            use_batch_stats, dtype)

    # Sets up optional gamma and beta parameters
    self._build_scale_offset(dtype)
    # Sets up the batch normalization op.
    out, mean, variance = self._batch_norm_op(input_batch, mean, variance,
                                              use_batch_stats)
    # Sets up the update op.
    update_ops = self._build_update_ops(mean, variance, is_training)

    # Put update ops in the update ops collection if given, otherwise add as
    # control dependencies of the output.
    if update_ops:
      if self._update_ops_collection:
        for update_op in update_ops:
          tf.add_to_collection(self._update_ops_collection, update_op)
      else:
        with tf.control_dependencies(update_ops):
          out = tf.identity(out)

    return out
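
A plain-Python sketch of the default axis handling above, assuming a hypothetical NHWC input; it traces how `axis` and `self._mean_shape` are derived when `axis=None`:

# No TensorFlow needed; this only mirrors the shape logic from the _build above.
input_shape = [8, 32, 32, 64]                # hypothetical [batch, H, W, channels]
axis = tuple(range(len(input_shape))[:-1])   # (0, 1, 2): reduce over all but the last dim

mean_shape = list(input_shape)
for index in axis:
    mean_shape[index] = 1

print(axis)        # (0, 1, 2)
print(mean_shape)  # [1, 1, 1, 64] -> one mean/variance per channel
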
Example no. 4
    def _build(self, inputs):
        """Connects the LayerNorm module into the graph.

        Args:
          inputs: a Tensor of dimensionality >= 2.

        Returns:
          normalized: layer normalized outputs with same shape as inputs.

        Raises:
          base.NotSupportedError: If `inputs` has fewer than 2 dimensions.
        """

        if self._axis is None:
            axis = list(range(1, inputs.shape.ndims))
        else:
            axis = self._axis

        original_dtype = inputs.dtype
        if original_dtype in [tf.float16, tf.bfloat16]:
            inputs = tf.cast(inputs, tf.float32)

        if inputs.get_shape().ndims < 2:
            raise base.NotSupportedError(
                "Layer normalization expects inputs of at least rank 2."
                " Got inputs of rank {}.".format(inputs.get_shape().ndims))

        # Shape for the learnable scale and offset is the number of channels. See
        # https://arxiv.org/pdf/1803.08494.pdf around equation 6.
        params_shape = inputs.get_shape()[-1:]

        if self._scale:
            if self.GAMMA not in self._initializers:
                self._initializers[self.GAMMA] = create_gamma_initializer()
            self._gamma = tf.get_variable(
                self.GAMMA,
                shape=params_shape,
                dtype=inputs.dtype,
                initializer=self._initializers[self.GAMMA],
                partitioner=self._partitioners.get(self.GAMMA),
                regularizer=self._regularizers.get(self.GAMMA))
        else:
            self._gamma = None

        if self._offset:
            if self.BETA not in self._initializers:
                self._initializers[self.BETA] = create_beta_initializer()
            self._beta = tf.get_variable(
                self.BETA,
                shape=params_shape,
                dtype=inputs.dtype,
                initializer=self._initializers[self.BETA],
                partitioner=self._partitioners.get(self.BETA),
                regularizer=self._regularizers.get(self.BETA))
        else:
            self._beta = None

        mean, var = tf.nn.moments(inputs, axis, keep_dims=True)

        normalized = tf.nn.batch_normalization(inputs, mean, var, self._beta,
                                               self._gamma, self._eps)

        if original_dtype in [tf.float16, tf.bfloat16]:
            normalized = tf.cast(normalized, dtype=original_dtype)
        return normalized
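
A minimal sketch of the mixed-precision path above: a `tf.float16` input is normalized in `tf.float32` and cast back to its original dtype. Gamma and beta are omitted (`scale=None`, `offset=None`), and the toy shapes and `tensorflow.compat.v1` setup are assumptions:

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x16 = tf.constant(np.random.randn(2, 4, 4, 8), dtype=tf.float16)  # rank-4 input

x = tf.cast(x16, tf.float32)              # cast up for numerically stable statistics
axis = list(range(1, x.shape.ndims))      # default: normalize over all axes but batch
mean, var = tf.nn.moments(x, axis, keepdims=True)
normalized = tf.nn.batch_normalization(x, mean, var, offset=None, scale=None,
                                       variance_epsilon=1e-5)
normalized = tf.cast(normalized, tf.float16)  # back to the original dtype

with tf.Session() as sess:
    print(sess.run(normalized).dtype)     # float16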