Example #1
0
 def calc_generator_moments_loss(y_true, y_pred):
     """Return a moment-matching loss between ``y_true`` and ``y_pred``.

     Mean and variance of both tensors are taken over axis 0 (presumably
     the batch axis -- confirm against the caller). The loss is the sum of:

     * the mean absolute difference between the two mean tensors, and
     * the mean absolute difference between the two standard deviations,
       each computed as ``sqrt(var + 1e-6)``.
     """
     true_mean, true_var = nn.moments(x=y_true, axes=[0])
     pred_mean, pred_var = nn.moments(x=y_pred, axes=[0])

     mean_term = reduce_mean(abs(true_mean - pred_mean))

     # The small constant is added to the variance before the square root.
     true_std = sqrt(true_var + 1e-6)
     pred_std = sqrt(pred_var + 1e-6)
     std_term = reduce_mean(abs(true_std - pred_std))

     return mean_term + std_term
Example #2
0
    def call(self, inputs):
        """Normalize ``inputs`` with moments taken over ``self.moments_axes``.

        The moments keep their reduced dimensions so they broadcast against
        ``inputs``. No offset or scale is applied; ``self.epsilon`` is passed
        as the variance epsilon.
        """
        moments = nn.moments(inputs, self.moments_axes, keepdims=True)
        mean, variance = moments

        return nn.batch_normalization(
            inputs,
            mean,
            variance,
            offset=None,
            scale=None,
            variance_epsilon=self.epsilon,
            name='LayerInstanceNorm')
Example #3
0
    def call(self, inputs, gamma, beta):
        """Normalize ``inputs`` and apply the supplied affine parameters.

        Moments are taken over ``self.moments_axes`` with the reduced
        dimensions kept, so they broadcast against ``inputs``.

        Args:
            inputs: Tensor to normalize.
            gamma: Multiplicative scale applied after normalization.
            beta: Additive offset applied after scaling.

        Returns:
            ``gamma * (inputs - mean) / sqrt(variance + self.epsilon) + beta``
            as computed by ``nn.batch_normalization``.

        NOTE: earlier revisions passed ``gamma`` and ``beta`` positionally,
        which bound ``gamma`` to ``offset`` and ``beta`` to ``scale``.
        Checkpoints trained with that ordering will see the two roles
        exchanged.
        """
        mean, variance = nn.moments(
            inputs, self.moments_axes, keepdims=True)

        # nn.batch_normalization's positional order is
        # (x, mean, variance, offset, scale, variance_epsilon); pass the
        # affine terms by keyword so gamma is the scale and beta the offset.
        outputs = nn.batch_normalization(
            inputs,
            mean,
            variance,
            offset=beta,
            scale=gamma,
            variance_epsilon=self.epsilon,
            name='AdaptiveInstanceNorm')

        return outputs
Example #4
0
def batch_norm_layer(x, is_train, decay=0.9, name_or_scope=None,
                     epsilon=1e-6):
    """Batch-normalize ``x`` over axis 0 with a learned offset and scale.

    x: [b, emb_dim]
    is_train: Python truth value. When true, the moments of ``x`` are used
        and assign ops that update the moving statistics are added to the
        ``tf.GraphKeys.TRAIN_OP`` collection. This function does not run
        those ops; whoever builds the training step has to fetch them from
        that collection. When false, the stored moving statistics are used.
    decay: weight given to the previous moving value in each update.
    name_or_scope: forwarded to ``tf.variable_scope``; the default scope
        name is "batch_norm_layer".
    epsilon: value added to the variance inside the normalization.

    Returns the normalized tensor.
    """
    with tf.variable_scope(name_or_scope=name_or_scope,
                           default_name="batch_norm_layer"):
        params_shape = [1, x.shape[-1]]

        def _variable(name, init_value, trainable=True):
            # All four variables share shape and dtype; only the name,
            # constant initial value and trainability differ.
            return tf.get_variable(
                name,
                shape=params_shape,
                dtype=tf.float32,
                initializer=tf.constant_initializer(init_value, tf.float32),
                trainable=trainable)

        beta = _variable("beta", 0.0)
        gamma = _variable("gamma", 1.0)
        # Created in both modes so training and inference share the same
        # variable names inside the scope.
        moving_mean = _variable("moving_mean", 0.0, trainable=False)
        moving_variance = _variable("moving_variance", 1.0, trainable=False)

        if is_train:
            mean, variance = tfnn.moments(x, axes=[0], keep_dims=True)
            tf.add_to_collection(
                tf.GraphKeys.TRAIN_OP,
                tf.assign(moving_mean,
                          decay * moving_mean + (1 - decay) * mean))
            tf.add_to_collection(
                tf.GraphKeys.TRAIN_OP,
                tf.assign(moving_variance,
                          decay * moving_variance + (1 - decay) * variance))
        else:
            mean, variance = moving_mean, moving_variance

        x = tfnn.batch_normalization(x,
                                     mean,
                                     variance,
                                     offset=beta,
                                     scale=gamma,
                                     variance_epsilon=epsilon)
    return x
Example #5
0
 def call(self, x):
     """Normalize ``x`` over axes 1 and 2, then apply scale and offset.

     Moments keep their reduced dimensions so they broadcast against ``x``.
     The result is ``self.scale * normalized + self.offset``, where
     ``normalized`` is ``(x - mean) * rsqrt(variance + self.epsilon)``.
     """
     mean, variance = nn.moments(x, axes=[1, 2], keepdims=True)
     centered = x - mean
     inv_std = math.rsqrt(variance + self.epsilon)
     normalized = centered * inv_std
     return self.scale * normalized + self.offset
    def call(self, inputs):
        """Layer-normalize ``inputs`` over the dimensions in ``self.axis``.

        Two code paths produce the result, selected by ``self._fused``:

        * Non-fused: moments are taken over ``self.axis`` and fed to
          ``nn.batch_normalization`` together with ``self.gamma`` /
          ``self.beta``. float16/bfloat16 inputs are computed in float32
          when ``self.dtype`` is float32, and the result is cast back to
          the input dtype.
        * Fused: ``inputs`` is reshaped to ``[1, pre_dim, in_dim, 1]`` and
          normalized by ``nn.fused_batch_norm`` with a constant unit scale
          and zero offset; ``self.gamma`` / ``self.beta`` are applied
          afterwards when they are not None.

        Returns:
            The normalized tensor, with its static shape set to that of
            ``inputs``.
        """
        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)

        # Broadcasting only necessary for norm where the axis is not just
        # the last dimension
        broadcast_shape = [1] * ndims
        for dim in self.axis:
            broadcast_shape[dim] = input_shape.dims[dim].value

        def _broadcast(v):
            # Reshape a parameter so it lines up with the normalized axes;
            # left untouched when it is None, already has full rank, or the
            # only normalized axis is the last one.
            if (v is not None and len(v.shape) != ndims
                    and self.axis != [ndims - 1]):
                return array_ops.reshape(v, broadcast_shape)
            return v

        if not self._fused:
            input_dtype = inputs.dtype
            if input_dtype in ('float16',
                               'bfloat16') and self.dtype == 'float32':
                # If mixed precision is used, cast inputs to float32 so that this is at
                # least as numerically stable as the fused version.
                inputs = math_ops.cast(inputs, 'float32')

            # Calculate the moments over self.axis, keeping reduced dims so
            # they broadcast against the inputs.
            mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

            scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

            # Compute layer normalization using the batch_normalization function.
            outputs = nn.batch_normalization(inputs,
                                             mean,
                                             variance,
                                             offset=offset,
                                             scale=scale,
                                             variance_epsilon=self.epsilon)
            # Undo the float32 upcast (a no-op when no upcast happened).
            outputs = math_ops.cast(outputs, input_dtype)
        else:
            # Collapse dims before self.axis, and dims in self.axis
            pre_dim, in_dim = (1, 1)
            axis = sorted(self.axis)
            tensor_shape = array_ops.shape(inputs)
            for dim in range(0, ndims):
                dim_tensor = tensor_shape[dim]
                if dim < axis[0]:
                    pre_dim = pre_dim * dim_tensor
                else:
                    # Every dim from the first normalized axis onwards must
                    # itself be a normalized axis for the collapse to work.
                    assert dim in axis
                    in_dim = in_dim * dim_tensor

            squeezed_shape = [1, pre_dim, in_dim, 1]
            # This fused operation requires reshaped inputs to be NCHW.
            data_format = 'NCHW'

            inputs = array_ops.reshape(inputs, squeezed_shape)

            def _set_const_tensor(val, dtype, shape):
                return array_ops.fill(shape,
                                      constant_op.constant(val, dtype=dtype))

            # self.gamma and self.beta have the wrong shape for fused_batch_norm, so
            # we cannot pass them as the scale and offset parameters. Therefore, we
            # create two constant tensors in correct shapes for fused_batch_norm and
            # later construct a separate calculation on the scale and offset.
            scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
            offset = _set_const_tensor(0.0, self.dtype, [pre_dim])

            # Compute layer normalization using the fused_batch_norm function.
            outputs, _, _ = nn.fused_batch_norm(inputs,
                                                scale=scale,
                                                offset=offset,
                                                epsilon=self.epsilon,
                                                data_format=data_format)

            # Restore the original (dynamic) shape of the inputs.
            outputs = array_ops.reshape(outputs, tensor_shape)

            scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

            if scale is not None:
                outputs = outputs * math_ops.cast(scale, outputs.dtype)
            if offset is not None:
                outputs = outputs + math_ops.cast(offset, outputs.dtype)

        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        return outputs