def _build(self, inputs):
  """Connects the LayerNorm module into the graph.

  Args:
    inputs: a Tensor of shape `[batch_size, layer_dim]`.

  Returns:
    normalized: layer normalized outputs with same shape as inputs.

  Raises:
    base.NotSupportedError: If `inputs` has data type of `tf.float16`.
  """
  if inputs.dtype == tf.float16:
    raise base.NotSupportedError(
        "LayerNorm does not support `tf.float16`, insufficient "
        "precision for calculating sufficient statistics.")

  if inputs.get_shape().ndims != 2:
    raise base.NotSupportedError(
        "Layer normalization expects inputs of rank 2."
        " Got inputs of rank {}.".format(inputs.get_shape().ndims))

  hidden_size = inputs.get_shape()[1].value

  if self.GAMMA not in self._initializers:
    self._initializers[self.GAMMA] = create_gamma_initializer()
  self._gamma = tf.get_variable(
      self.GAMMA,
      shape=[hidden_size],
      dtype=inputs.dtype,
      initializer=self._initializers[self.GAMMA],
      partitioner=self._partitioners.get(self.GAMMA),
      regularizer=self._regularizers.get(self.GAMMA))

  if self.BETA not in self._initializers:
    self._initializers[self.BETA] = create_beta_initializer()
  self._beta = tf.get_variable(
      self.BETA,
      shape=[hidden_size],
      dtype=inputs.dtype,
      initializer=self._initializers[self.BETA],
      partitioner=self._partitioners.get(self.BETA),
      regularizer=self._regularizers.get(self.BETA))

  mean, var = tf.nn.moments(inputs, [1], keep_dims=True)

  normalized = tf.nn.batch_normalization(inputs, mean, var, self._beta,
                                         self._gamma, self._eps)
  return normalized
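
# Illustrative usage sketch (not part of the original source). It assumes this
# module is exported as `snt.LayerNorm`, as in Sonnet v1, and that calling the
# module object invokes the _build method above. Inputs must be rank-2 and not
# tf.float16.
def _example_layer_norm_usage():
  import tensorflow as tf
  import sonnet as snt

  inputs = tf.placeholder(tf.float32, shape=[32, 128])  # [batch_size, layer_dim]
  layer_norm = snt.LayerNorm()
  normalized = layer_norm(inputs)  # Same shape as `inputs`: [32, 128].
  return normalized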
def _build(self, input_batch, is_training, test_local_stats=False):
  """Connects the BatchNormV2 module into the graph.

  Args:
    input_batch: A Tensor with rank equal to `len(data_format)`.
    is_training: A boolean to indicate if the module should be connected in
      training mode, meaning the moving averages are updated. Can be a Tensor.
    test_local_stats: A boolean to indicate if local batch statistics should
      be used when `is_training=False`. If not, moving averages are used.
      By default `False`. Can be a Tensor.

  Returns:
    A tensor with the same shape as `input_batch`.

  Raises:
    base.IncompatibleShapeError: If `data_format` is not valid for the
      input shape.
    base.NotSupportedError: If `input_batch` has data type of `tf.bfloat16`
      and fused batch norm is used.
  """
  input_shape = input_batch.get_shape()

  if not self._data_format:
    if len(input_shape) == 2:
      self._data_format = "NC"
    elif len(input_shape) == 3:
      self._data_format = "NWC"
    elif len(input_shape) == 4:
      self._data_format = "NHWC"
    elif len(input_shape) == 5:
      self._data_format = "NDHWC"
    else:
      raise base.IncompatibleShapeError(
          "Input shape {} has too many or too few dimensions.".format(
              input_shape))

  self._channel_index = self._data_format.index("C")
  # Materialize the range as a list so the channel index can be deleted
  # (python3's `range` does not support item deletion).
  self._axis = list(range(len(self._data_format)))
  del self._axis[self._channel_index]

  if len(self._data_format) != len(input_shape):
    raise base.IncompatibleShapeError(
        "Incorrect data format {} for input shape {}.".format(
            self._data_format, input_shape))

  dtype = input_batch.dtype.base_dtype
  if self._fused and dtype == tf.bfloat16:
    raise base.NotSupportedError(
        "Fused batch norm does not support tf.bfloat16.")
  # Maintain moving averages at a minimum precision of tf.float32.
  stat_dtype = tf.float32 if dtype in [tf.float16, tf.bfloat16] else dtype

  self._num_channels = int(input_shape[self._channel_index])
  if self._channel_index == 1:
    self._image_shape = [int(x) for x in input_shape[2:]]
  else:
    self._image_shape = [int(x) for x in input_shape[1:-1]]

  self._expanded_mean_shape = [1] * len(input_shape)
  self._expanded_mean_shape[self._channel_index] = self._num_channels

  use_batch_stats = is_training | test_local_stats

  mean, variance = self._build_statistics(input_batch, use_batch_stats,
                                          stat_dtype)

  # Sets up optional gamma and beta parameters
  self._build_scale_offset(dtype)
  # Sets up the batch normalization op.
  out, mean, variance = self._batch_norm_op(input_batch, mean, variance,
                                            use_batch_stats, stat_dtype)
  # Sets up the update op.
  update_ops = self._build_update_ops(mean, variance, is_training)

  # Put update ops in the update ops collection if given, otherwise add as
  # control dependencies of the output.
  if update_ops:
    if self._update_ops_collection:
      for update_op in update_ops:
        tf.add_to_collection(self._update_ops_collection, update_op)
    else:
      with tf.control_dependencies(update_ops):
        out = tf.identity(out)

  return out
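
# Illustrative usage sketch (not part of the original source). It assumes this
# module is exported as `snt.BatchNormV2` as in Sonnet v1; the constructor
# argument names below (`data_format`, `update_ops_collection`) are taken from
# that API and may differ in other versions. With `update_ops_collection=None`,
# the moving-average updates become control dependencies of the output, as in
# the final branch of _build above.
def _example_batch_norm_v2_usage():
  import tensorflow as tf
  import sonnet as snt

  images = tf.placeholder(tf.float32, shape=[16, 28, 28, 3])  # NHWC
  bn = snt.BatchNormV2(data_format="NHWC", update_ops_collection=None)
  train_out = bn(images, is_training=True)  # Batch stats; EMAs updated via control deps.
  test_out = bn(images, is_training=False, test_local_stats=False)  # Moving averages.
  return train_out, test_out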
def _build(self, input_batch, is_training, test_local_stats=True):
  """Connects the BatchNorm module into the graph.

  Args:
    input_batch: A Tensor of arbitrary dimension. By default, the final
      dimension is not reduced over when computing the minibatch statistics.
    is_training: A boolean to indicate if the module should be connected in
      training mode, meaning the moving averages are updated. Can be a Tensor.
    test_local_stats: A boolean to indicate if local batch statistics should
      be used when `is_training=False`. If not, moving averages are used.
      By default `True`. Can be a Tensor.

  Returns:
    A tensor with the same shape as `input_batch`.

  Raises:
    base.IncompatibleShapeError: If `axis` is not valid for the input shape
      or has negative entries.
    base.NotSupportedError: If `input_batch` has data type of `tf.float16`.
  """
  input_shape = input_batch.get_shape()

  if self._axis is not None:
    if len(self._axis) > len(input_shape):
      raise base.IncompatibleShapeError(
          "Too many indices specified in axis: len({}) > len({}).".format(
              self._axis, input_shape))

    if max(self._axis) >= len(input_shape):
      raise base.IncompatibleShapeError(
          "One or more indices in axis are too large for the "
          "input shape: {} >= {:d}.".format(self._axis, len(input_shape)))

    if min(self._axis) < 0:
      raise base.IncompatibleShapeError(
          "Indices in axis must be non-negative: {} < 0.".format(
              self._axis))

    axis = self._axis
  else:
    # Reduce over all dimensions except the last.
    axis = tuple(range(len(input_shape))[:-1])

  # See following for important note on accuracy for dtype=tf.float16
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_impl.py#L63
  dtype = input_batch.dtype
  if dtype == tf.float16:
    raise base.NotSupportedError(
        "BatchNorm does not support `tf.float16`, insufficient "
        "precision for calculating sufficient statistics.")

  self._mean_shape = input_batch.get_shape().as_list()
  for index in axis:
    self._mean_shape[index] = 1

  use_batch_stats = is_training | test_local_stats

  mean, variance = self._build_statistics(input_batch, axis,
                                          use_batch_stats, dtype)

  # Sets up optional gamma and beta parameters
  self._build_scale_offset(dtype)
  # Sets up the batch normalization op.
  out, mean, variance = self._batch_norm_op(input_batch, mean, variance,
                                            use_batch_stats)
  # Sets up the update op.
  update_ops = self._build_update_ops(mean, variance, is_training)

  # Put update ops in the update ops collection if given, otherwise add as
  # control dependencies of the output.
  if update_ops:
    if self._update_ops_collection:
      for update_op in update_ops:
        tf.add_to_collection(self._update_ops_collection, update_op)
    else:
      with tf.control_dependencies(update_ops):
        out = tf.identity(out)

  return out
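
# Illustrative usage sketch (not part of the original source). It assumes this
# module is exported as `snt.BatchNorm` as in Sonnet v1 and that its
# constructor accepts `update_ops_collection`; when a collection name is given,
# the moving-average update ops land in that collection (first branch of the
# `if update_ops:` block above) and must be run alongside the training step.
def _example_batch_norm_usage():
  import tensorflow as tf
  import sonnet as snt

  inputs = tf.placeholder(tf.float32, shape=[64, 256])
  bn = snt.BatchNorm(update_ops_collection=tf.GraphKeys.UPDATE_OPS)
  outputs = bn(inputs, is_training=True)

  loss = tf.reduce_mean(tf.square(outputs))
  optimizer = tf.train.GradientDescentOptimizer(0.1)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):  # Run the EMA updates with each step.
    train_op = optimizer.minimize(loss)
  return train_op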
def _build(self, inputs):
  """Connects the LayerNorm module into the graph.

  Args:
    inputs: a Tensor of dimensionality >= 2.

  Returns:
    normalized: layer normalized outputs with same shape as inputs.

  Raises:
    base.NotSupportedError: If `inputs` has less than 2 dimensions.
  """
  if self._axis is None:
    axis = list(range(1, inputs.shape.ndims))
  else:
    axis = self._axis

  original_dtype = inputs.dtype
  if original_dtype in [tf.float16, tf.bfloat16]:
    inputs = tf.cast(inputs, tf.float32)

  if inputs.get_shape().ndims < 2:
    raise base.NotSupportedError(
        "Layer normalization expects inputs of at least rank 2."
        " Got inputs of rank {}.".format(inputs.get_shape().ndims))

  # Shape for the learnable scale and offset is the number of channels. See
  # https://arxiv.org/pdf/1803.08494.pdf around equation 6.
  params_shape = inputs.get_shape()[-1:]

  if self._scale:
    if self.GAMMA not in self._initializers:
      self._initializers[self.GAMMA] = create_gamma_initializer()
    self._gamma = tf.get_variable(
        self.GAMMA,
        shape=params_shape,
        dtype=inputs.dtype,
        initializer=self._initializers[self.GAMMA],
        partitioner=self._partitioners.get(self.GAMMA),
        regularizer=self._regularizers.get(self.GAMMA))
  else:
    self._gamma = None

  if self._offset:
    if self.BETA not in self._initializers:
      self._initializers[self.BETA] = create_beta_initializer()
    self._beta = tf.get_variable(
        self.BETA,
        shape=params_shape,
        dtype=inputs.dtype,
        initializer=self._initializers[self.BETA],
        partitioner=self._partitioners.get(self.BETA),
        regularizer=self._regularizers.get(self.BETA))
  else:
    self._beta = None

  mean, var = tf.nn.moments(inputs, axis, keep_dims=True)

  normalized = tf.nn.batch_normalization(inputs, mean, var, self._beta,
                                         self._gamma, self._eps)

  if original_dtype in [tf.float16, tf.bfloat16]:
    normalized = tf.cast(normalized, dtype=original_dtype)
  return normalized
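
# Illustrative usage sketch (not part of the original source). It assumes this
# module is exported as `snt.LayerNorm` and that its constructor accepts
# `axis`, `scale` and `offset` arguments, as suggested by the attributes used
# in _build above. Choosing `axis` controls which dimensions are reduced over;
# float16/bfloat16 inputs are normalized in float32 internally and cast back
# to the original dtype on output.
def _example_layer_norm_axis_usage():
  import tensorflow as tf
  import sonnet as snt

  feature_maps = tf.placeholder(tf.float16, shape=[8, 32, 32, 16])  # NHWC
  # Default (axis=None): normalize over every dimension except batch.
  layer_norm = snt.LayerNorm()
  ln_out = layer_norm(feature_maps)  # Returned in tf.float16.
  # Instance-norm-style reduction over the spatial dimensions only.
  instance_norm = snt.LayerNorm(axis=[1, 2], scale=True, offset=True)
  in_out = instance_norm(feature_maps)
  return ln_out, in_out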