def StreamStep(self, theta, inputs, paddings, state0):
  if py_utils.testonly_skip_norm_layers():
    return inputs, paddings, state0

  p = self.params
  assert p.cumulative
  inputs = py_utils.HasRank(inputs, p.input_rank)

  group_size = self.group_size
  num_groups = self.num_groups
  tf.logging.vlog(1, 'group_size: %s', group_size)
  tf.logging.vlog(1, 'num_groups: %s', num_groups)

  input_shape = py_utils.GetShape(inputs)

  with tf.name_scope(f'{p.name}/StreamStep'):
    x = tf.reshape(inputs, input_shape[:-1] + [num_groups, group_size])
    expanded_rank = p.input_rank + 1
    expanded_paddings = tf.reshape(
        paddings, input_shape[:2] + [1] * (expanded_rank - 2))

    (group_mean, group_variance, cached_sum, cached_count,
     cached_var) = self._StreamMoments(x, expanded_paddings, state0.cached_sum,
                                       state0.cached_count, state0.cached_var)
    outputs = self._Normalize(theta, x, group_mean, group_variance)

    return outputs, paddings, py_utils.NestedMap(
        cached_sum=cached_sum,
        cached_count=cached_count,
        cached_var=cached_var)
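
# A minimal, self-contained sketch (not the Lingvo implementation) of the kind
# of streaming-moment bookkeeping StreamStep relies on: per step, masked sums,
# counts and sums of squares are accumulated into the carried state so that the
# mean/variance at step t cover everything seen up to and including t. The
# helper name `stream_moments`, the [B, T, G, K] layout, and the assumption
# that `cached_var` carries a running sum of squares are all illustrative
# assumptions, not the actual `_StreamMoments` contract.
import tensorflow as tf


def stream_moments(x, paddings, cached_sum, cached_count, cached_var):
  """Cumulative masked moments for x: [B, T, G, K], paddings: [B, T, 1, 1]."""
  mask = 1.0 - paddings                                    # 1.0 at valid steps
  # Per-step masked sum, element count and sum of squares, reduced over K.
  sum_t = tf.reduce_sum(x * mask, axis=-1, keepdims=True)          # [B, T, G, 1]
  count_t = mask * tf.cast(tf.shape(x)[-1], x.dtype)               # [B, T, 1, 1]
  sq_t = tf.reduce_sum(tf.square(x) * mask, axis=-1, keepdims=True)
  # Running (causal) totals, seeded with the state carried over from the
  # previous chunk (all-zeros state for the first chunk).
  csum = tf.cumsum(sum_t, axis=1) + cached_sum
  ccount = tf.cumsum(count_t, axis=1) + cached_count
  cvar = tf.cumsum(sq_t, axis=1) + cached_var
  mean = csum / tf.maximum(ccount, 1.0)
  variance = cvar / tf.maximum(ccount, 1.0) - tf.square(mean)
  # The new state is the totals after the last step of this chunk.
  state = (csum[:, -1:], ccount[:, -1:], cvar[:, -1:])
  return mean, variance, state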
def FProp(self, theta, inputs, paddings=None):
  """Apply batch normalization.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1] or [...], the rank is
      either the same as inputs or tf.rank(inputs) - 1.

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  inputs, paddings = self._CastToFPropDtype((inputs, paddings))
  if py_utils.testonly_skip_norm_layers():
    return inputs

  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)

  # shape [..., 1]
  paddings = self._MaybeExpandPaddings(inputs, paddings)

  with tf.name_scope(p.name):
    norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
        theta, inputs, paddings)
    return self._ComputeBN(inputs, paddings, gamma, beta, norm_mean,
                           norm_variance)
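
# A rough, framework-free sketch of the two steps this FProp delegates to
# (ComputeAndUpdateMoments and _ComputeBN): compute moments only over
# non-padded positions, then apply standard normalization. The function name
# `masked_batch_norm` and the [B, T, D]/[B, T, 1] layout are assumptions for
# illustration; the real layer additionally maintains moving averages.
import tensorflow as tf


def masked_batch_norm(inputs, paddings, beta, gamma, epsilon=1e-3):
  """inputs: [B, T, D]; paddings: [B, T, 1] with 1.0 at padded positions."""
  mask = 1.0 - paddings                                        # [B, T, 1]
  count = tf.maximum(tf.reduce_sum(mask), 1.0)                 # valid frames
  mean = tf.reduce_sum(inputs * mask, axis=[0, 1]) / count     # [D]
  variance = tf.reduce_sum(
      tf.square(inputs - mean) * mask, axis=[0, 1]) / count    # [D]
  out = tf.nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon)
  return out * mask                                            # zero padded steps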
def FProp(self, theta, inputs):
  """Applies batch normalization.

  Using the implementation in
  github.com/tensorflow/tpu/blob/master/models/official/amoeba_net/network_utils.py#L550

  Args:
    theta: A nested map object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  if py_utils.testonly_skip_norm_layers():
    return inputs

  p = self.params
  inputs_dtype = inputs.dtype
  inputs = tf.cast(inputs, p.dtype)
  inputs = py_utils.with_dependencies([
      py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                  tf.shape(theta.beta))
  ], inputs)
  with tf.name_scope(p.name) as scope:
    if self.do_eval:
      outputs = tf.nn.batch_normalization(inputs, theta.moving_mean,
                                          theta.moving_variance, theta.beta,
                                          theta.gamma, p.epsilon)
    else:
      mean, variance = self._Moments(inputs, p.bn_group_size)
      mean = py_utils.CheckNumerics(
          mean, 'mean of {} failed numeric check'.format(scope))
      variance = py_utils.CheckNumerics(
          variance, 'variance of {} failed numeric check'.format(scope))
      outputs = tf.nn.batch_normalization(inputs, mean, variance, theta.beta,
                                          theta.gamma, p.epsilon)
    outputs.set_shape(inputs.get_shape())
    return tf.cast(outputs, inputs_dtype)
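
# A condensed sketch of the eval/train branch above with the Lingvo helpers
# (py_utils.*, self._Moments, cross-replica group averaging) stripped out:
# inference uses the accumulated moving statistics, training uses statistics
# of the current batch. The function name and arguments are assumptions for
# the example; in the real layer the moving averages are updated separately.
import tensorflow as tf


def batch_norm_no_padding(inputs, beta, gamma, moving_mean, moving_variance,
                          is_eval, epsilon=1e-3):
  if is_eval:
    # Inference: normalize with the moving statistics.
    return tf.nn.batch_normalization(inputs, moving_mean, moving_variance,
                                     beta, gamma, epsilon)
  # Training: normalize with the statistics of the current batch, reducing
  # over every axis except the feature axis.
  mean, variance = tf.nn.moments(
      inputs, axes=list(range(inputs.shape.rank - 1)))
  return tf.nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                   epsilon)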
def FProp(self, theta, inputs, paddings, class_emb):
  """Apply batch normalization.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [batch, ..., dim].
    paddings: The paddings tensor. Shaped [batch, ..., 1], with the same rank
      as the input tensor.
    class_emb: The conditioning inputs. Shaped [batch, emb_dim].

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  if py_utils.testonly_skip_norm_layers():
    return inputs

  p = self.params
  batch = py_utils.GetShape(inputs)[0]
  class_emb = py_utils.HasShape(class_emb, [batch, p.class_emb_dim])
  if not py_utils.use_tpu():
    class_emb = py_utils.with_dependencies([
        py_utils.assert_less_equal(
            tf.cast(class_emb, tf.int32), 1, name='one_hot_assert1'),
        py_utils.assert_greater_equal(
            tf.cast(class_emb, tf.int32), 0, name='one_hot_assert2'),
        py_utils.assert_equal(
            tf.ones([batch], tf.int32),
            tf.cast(tf.reduce_sum(class_emb, -1), tf.int32),
            name='one_hot_assert3'),
    ], class_emb)

  with tf.name_scope(p.name):
    norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
        theta, inputs, paddings=paddings, class_emb=class_emb)
    return self._ComputeBN(inputs, paddings, gamma, beta, norm_mean,
                           norm_variance)
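
# What distinguishes this layer from plain batch norm is that beta and gamma
# are looked up per example from class-conditional tables using the one-hot
# class_emb (which the asserts above verify is a valid one-hot vector on CPU).
# A minimal sketch of that lookup; the table names and shapes are assumptions
# for the example, not the layer's actual variable layout.
import tensorflow as tf


def select_class_conditional_scale_shift(class_emb, beta_table, gamma_table):
  """class_emb: one-hot [batch, num_classes]; tables: [num_classes, dim].

  Returns per-example beta and gamma of shape [batch, dim], ready to broadcast
  against [batch, ..., dim] inputs inside the normalization.
  """
  beta = tf.matmul(class_emb, beta_table)     # [batch, dim]
  gamma = tf.matmul(class_emb, gamma_table)   # [batch, dim]
  return beta, gamma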
def FProp(self, theta, inputs, paddings=None):
  """Apply group normalization.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor with shape [batch_size, height, width, channel]
      if p.rank == 4, else [batch, height, channel].
    paddings: The paddings tensor with shape [batch_size, height]. Intended
      to be used for sequence processing where `height` is `time`.

  Returns:
    A single tensor as the output after applying group normalization, with
    the same shape as 'inputs', or an (output, output_paddings) pair if the
    input paddings is not None.
  """
  inputs, paddings = self._CastToFPropDtype((inputs, paddings))
  if py_utils.testonly_skip_norm_layers():
    if paddings is None:
      return inputs
    else:
      return inputs, paddings

  p = self.params
  inputs = py_utils.HasRank(inputs, p.input_rank)
  num_groups = self.num_groups

  input_shape = py_utils.GetShape(inputs)
  with tf.name_scope(p.name):
    x = tf.reshape(inputs, input_shape[:-1] + [num_groups, self.group_size])
    expanded_rank = p.input_rank + 1
    all_dims = list(range(expanded_rank))
    if paddings is None or not p.cumulative:
      # Skips batch and num_groups.
      reduce_over_dims = all_dims[1:-2] + all_dims[-1:]
    else:
      # Skips batch, seqlen and num_groups.
      reduce_over_dims = all_dims[2:-2] + all_dims[-1:]

    if paddings is None and not p.cumulative:
      # Fast path on tpu without reshape.
      counts, means_ss, variance_ss, _ = tf.nn.sufficient_statistics(
          x, axes=reduce_over_dims, keepdims=True)
      group_mean, group_variance = tf.nn.normalize_moments(
          counts, means_ss, variance_ss, None)
    else:
      expanded_paddings = tf.reshape(
          paddings, input_shape[:2] + [1] * (expanded_rank - 2))
      group_mean, group_variance = ComputeMoments(
          x,
          expanded_paddings,
          reduce_over_dims,
          cumulative_axis=1,
          enable_cross_replica_sum_on_tpu=p.enable_cross_replica_sum_on_tpu,
          keepdims=True)

    outputs = self._Normalize(theta, x, group_mean, group_variance)

    if paddings is None:
      return outputs
    else:
      return outputs, paddings
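
# A self-contained sketch of the no-padding fast path above: reshape the
# channel dimension into [num_groups, group_size], compute moments per
# (batch, group) over the remaining dims, normalize, and reshape back. The
# function `group_norm`, its arguments, and the statically known [B, H, W, C]
# shape are assumptions for the example; the real layer folds beta/gamma in
# via _Normalize and handles paddings/cumulative modes as shown above.
import tensorflow as tf


def group_norm(inputs, num_groups, beta, gamma, epsilon=1e-3):
  """inputs: [B, H, W, C] with C divisible by num_groups; beta, gamma: [C]."""
  b, h, w, c = inputs.shape
  group_size = c // num_groups
  x = tf.reshape(inputs, [b, h, w, num_groups, group_size])
  # Reduce over everything except batch and num_groups (dims 1, 2, 4).
  mean, variance = tf.nn.moments(x, axes=[1, 2, 4], keepdims=True)
  x = (x - mean) * tf.math.rsqrt(variance + epsilon)
  x = tf.reshape(x, [b, h, w, c])
  return x * gamma + beta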