def __call__(self, inputs, *args, **kwargs):
        """Apply `self.activation`, optionally quantizing before and/or after.

        Pre-quantization runs when `self._should_pre_quantize()` is true and
        uses `self._pre_activation_vars`. Post-quantization runs when
        `self._should_post_quantize()` is true and uses
        `self._post_activation_vars`. `self._training` selects the
        training/inference branch through `smart_cond`.

        Args:
            inputs: Tensor fed to the (optional) pre-quantizer and activation.
            *args, **kwargs: Forwarded unchanged to `self.activation`.

        Returns:
            The activation output, quantized afterwards if configured.
        """
        def quantize(tensor, quantizer_vars):
            def branch(training):
                # Default arguments bind the current values now, so the
                # zero-argument closure handed to smart_cond is self-contained.
                def quantizer_fn(x=tensor,
                                 quantizer=self.quantizer,
                                 quantizer_vars=quantizer_vars):
                    return quantizer(x, training, weights=quantizer_vars)

                return quantizer_fn

            return control_flow_util.smart_cond(
                self._training, branch(True), branch(False))

        outputs = inputs
        if self._should_pre_quantize():
            outputs = quantize(outputs, self._pre_activation_vars)

        outputs = self.activation(outputs, *args, **kwargs)

        if self._should_post_quantize():
            outputs = quantize(outputs, self._post_activation_vars)

        return outputs
Beispiel #2
0
    def call(self, inputs, training=None):
        """Apply dropout separately to the real and imaginary parts of `inputs`.

        Args:
            inputs: Complex-valued tensor.
            training: Python bool or boolean tensor; defaults to
                `K.learning_phase()` when None.

        Returns:
            `tf.complex(real_output, imag_output)`. In training mode each part
            has gone through `nn.dropout`; otherwise each part is an identity
            copy.
        """
        if training is None:
            training = K.learning_phase()
        components = {
            'real': K.math_ops.real(inputs),
            'imag': K.math_ops.imag(inputs),
        }

        def dropped_inputs(input_type):
            # smart_cond expects zero-argument callables, so return a closure.
            def _dropped_inputs():
                if input_type not in components:
                    raise ValueError("Invalid input type. "
                                     "Available values are 'real' and 'imag'")
                part = components[input_type]
                return nn.dropout(part,
                                  noise_shape=self._get_noise_shape(part),
                                  seed=self.seed,
                                  rate=self.rate)

            return _dropped_inputs

        def kept_inputs(input_type):
            part = components[input_type]
            return lambda: array_ops.identity(part)

        real_output, imag_output = [
            control_flow_util.smart_cond(training, dropped_inputs(part_name),
                                         kept_inputs(part_name))
            for part_name in ('real', 'imag')
        ]
        return tf.complex(real_output, imag_output)
Beispiel #3
0
    def call(self, inputs, training=None):
        """Run the wrapped layer with quantized weights, activations and output.

        Args:
            inputs: Forwarded to `self.layer.call`.
            training: Python bool or boolean tensor; defaults to
                `tf.keras.backend.learning_phase()` when None.

        Returns:
            The wrapped layer's output. If `self._output_quantizers` is
            non-empty, the output is quantized by its first quantizer.

        Raises:
            RuntimeError: if output quantizers are configured and the wrapped
                layer returns a list or tuple.
        """
        if training is None:
            training = tf.keras.backend.learning_phase()

        def quantize(quantizer, tensor, quantizer_vars):
            # Pick the training or inference flavour of the quantizer.
            return control_flow_util.smart_cond(
                training,
                self._make_quantizer_fn(quantizer, tensor, True,
                                        quantizer_vars),
                self._make_quantizer_fn(quantizer, tensor, False,
                                        quantizer_vars))

        # Quantize every weight and swap the results into the underlying layer.
        quantized_weights = [
            quantize(quantizer, weight, quantizer_vars)
            for weight, quantizer, quantizer_vars in self._weight_vars
        ]
        self.quantize_config.set_quantize_weights(self.layer,
                                                  quantized_weights)

        # Replace activations with `QuantizeAwareActivation`s, which can
        # quantize activation tensors during graph construction.
        for activation in self._quantize_activations:
            activation.training = training
        self.quantize_config.set_quantize_activations(
            self.layer, self._quantize_activations)

        layer_call_args = tf_inspect.getfullargspec(self.layer.call).args
        if 'training' in layer_call_args:
            outputs = self.layer.call(inputs, training=training)
        else:
            outputs = self.layer.call(inputs)

        if not self._output_quantizers:
            return outputs

        # Only single-tensor outputs are supported for now; handle multi-output
        # layers when such a layer needs to be enabled.
        if isinstance(outputs, (list, tuple)):
            raise RuntimeError(
                'Multiple output tensors not handled currently.')

        return quantize(self._output_quantizers[0], outputs,
                        self._output_quantizer_vars)
Beispiel #4
0
    def call(self, inputs, weights, training: tf.constant):
        """
        Apply rb sparsity mask to given weights.

        The mask derived by `_calc_rb_binary_mask(weights)` is used only when
        both `training` and `weights['trainable']` are true. Otherwise the
        stored `weights['mask']` (through `binary_mask`) is used.

        :param inputs: Target weights to sparsify.
        :param weights: Operation weights; contains `mask` and the `trainable`
            flag.
        :param training: True if operation called in training mode
            else False (a Python bool or boolean tensor, as accepted by
            `smart_cond`).
        :return: `inputs` with the selected binary mask applied.
        """
        # NOTE(review): the annotation names the function `tf.constant`, not a
        # type; it is kept unchanged here for interface compatibility.

        def stored_mask_fn():
            return apply_mask(inputs, binary_mask(weights['mask']))

        def computed_mask_fn():
            return apply_mask(inputs, self._calc_rb_binary_mask(weights))

        def training_fn():
            return smart_cond(weights['trainable'],
                              true_fn=computed_mask_fn,
                              false_fn=stored_mask_fn)

        return smart_cond(training,
                          true_fn=training_fn,
                          false_fn=stored_mask_fn)
Beispiel #5
0
    def call(self, inputs, training=None, mask=None):
        """Compute probabilities and register the matching loss via `add_loss`.

        Args:
            inputs: Pair `(logits, targets)`. Either element may be ragged;
                raggedness is detected via `convert_inputs_if_ragged`.
            training: Python bool or boolean tensor choosing
                `_train_probs_loss` vs `_eval_probs_loss`; defaults to
                `tf.keras.backend.learning_phase()`.
            mask: Optional boolean tensor AND-ed into the loss weights.

        Returns:
            The probabilities, passed through `maybe_convert_to_ragged` with
            the row lengths of the logits.
        """
        if training is None:
            training = tf.keras.backend.learning_phase()

        logits, targets = inputs
        logits = tf.cast(logits, self.compute_dtype)

        logits, row_lengths = convert_inputs_if_ragged(logits)
        targets, _ = convert_inputs_if_ragged(targets)
        is_ragged_input = row_lengths is not None

        # All-True weights shaped like the targets. For ragged input they are
        # (presumably) re-raggedized and then densified with False padding.
        weights = maybe_convert_to_ragged(
            is_ragged_input, tf.ones_like(targets, dtype=tf.bool), row_lengths)
        if is_ragged_input:
            weights = weights.to_tensor(False)
        if mask is not None:
            weights = tf.logical_and(weights, mask)
        weights = tf.cast(weights, self.compute_dtype)

        def train_branch():
            return self._train_probs_loss(logits, targets, weights)

        def eval_branch():
            return self._eval_probs_loss(logits, targets, weights)

        probs, loss = control_flow_util.smart_cond(training, train_branch,
                                                   eval_branch)
        self.add_loss(loss, inputs=True)

        return maybe_convert_to_ragged(is_ragged_input, probs, row_lengths)
  def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name,
                                     experimental_aggregate_gradients):
    """Apply gradients only when the loss-scale update says it is safe.

    Calls `self._loss_scale.update(grads)`. Depending on the returned
    `should_apply_grads` flag, this either applies the gradients on every
    replica or only increments `self._optimizer.iterations`.

    Returns:
      `control_flow_ops.group` of the conditional apply/skip op and the
      loss-scale update op.
    """
    gradients = [grad for grad, _ in grads_and_vars]
    loss_scale_update_op, should_apply_grads = self._loss_scale.update(
        gradients)

    def _apply():
      # DistributionStrategy must not unwrap the MirroredVariables in
      # grads_and_vars: even in a replica context the wrapped optimizer
      # expects mirrored variables, so they are hidden behind an
      # _UnwrapPreventer.
      unwrap_preventer = _UnwrapPreventer([var for _, var in grads_and_vars])
      return distribution.extended.call_for_each_replica(
          self._apply_gradients,
          args=(gradients, unwrap_preventer, name,
                experimental_aggregate_gradients))

    def _skip():
      # self._optimizer.apply_gradients() normally increments the iteration
      # counter. It is not called on this branch, so increment it here.
      return self._optimizer.iterations.assign_add(1, read_value=False)

    # This cond must run in a cross-replica context: DistributionStrategy does
    # not support a replica-context cond whose branch calls `merge_call`, and
    # self._optimizer.apply_gradients calls `merge_call`.
    conditional_op = control_flow_util.smart_cond(should_apply_grads, _apply,
                                                  _skip)
    return control_flow_ops.group(conditional_op, loss_scale_update_op)
Beispiel #7
0
    def call(self, inputs, training=True):
        """Replace random entries with the same column from shuffled rows.

        In training mode, a Bernoulli mask (binomial with `counts` of 1 and
        `probs=self.probs`) selects entries that are taken from a row-shuffled
        copy of `inputs`. All other entries are kept. Outside training mode,
        `inputs` is returned as-is.

        Args:
            inputs: Tensor indexed along axes 0 and 1 (assumed
                `[batch, features]` -- TODO confirm).
            training: Defaults to True, so the `K.learning_phase()` fallback
                only applies when None is passed explicitly.

        Returns:
            The (possibly) mixed tensor.
        """
        # NOTE(review): `self.seed` is passed to a *stateless* sampler, so the
        # mask repeats on every call unless `self.seed` changes between calls
        # -- confirm this is intended.
        if training is None:
            training = K.learning_phase()

        def mask_inputs():
            swap_mask = tf.random.stateless_binomial(
                shape=tf.shape(inputs),
                seed=self.seed,
                counts=tf.ones((tf.shape(inputs)[1], )),
                probs=self.probs,
            )

            # tf.random.shuffle() without tf.gather() doesn't work in a custom layer
            # ref: https://github.com/tensorflow/tensorflow/issues/6269#issuecomment-465850464
            row_order = tf.random.shuffle(tf.range(tf.shape(inputs)[0]),
                                          seed=self.seed[0])
            shuffled_rows = tf.gather(inputs, row_order)
            return tf.where(swap_mask == 1, shuffled_rows, inputs)

        return control_flow_util.smart_cond(training, mask_inputs,
                                            lambda: inputs)
 def call(self, x, training=None):
     """Return `x * 0` in training mode and `array_ops.identity(x)` otherwise.

     `training` falls back to `keras.backend.learning_phase()` when None.
     """
     if training is None:
         training = keras.backend.learning_phase()

     def zeroed():
         return x * 0

     def passthrough():
         return array_ops.identity(x)

     result = control_flow_util.smart_cond(training, zeroed, passthrough)
     if not context.executing_eagerly():
         # Graph mode only: mark the output as learning-phase dependent.
         result._uses_learning_phase = True  # pylint: disable=protected-access
     return result
Beispiel #9
0
    def call(self, inputs, training=None, initial_state=None):
        """Compute gates with `self.conv1d` and pool them with `fo_pool`.

        Args:
            inputs: Sequence tensor. Time is axis 0 when `self.time_major`,
                otherwise axis 1. Presumably rank 3 -- the transposes below
                use perm `(1, 0, 2)`; TODO confirm.
            training: Python bool or boolean tensor; defaults to
                `tf.keras.backend.learning_phase()`. Only used to choose the
                zoneout branch.
            initial_state: Forwarded to `fo_pool`.

        Returns:
            `h`: the full output sequence if `self.return_sequences`,
            otherwise only the last time step. When `self.return_state`, the
            tuple `(h, last_state)` is returned instead, where `last_state` is
            the last time step of the pooled state `c`.
        """
        if training is None:
            training = tf.keras.backend.learning_phase()

        # Time axis for tf.reverse: 0 when time-major, 1 when batch-major.
        reverse_sim = [0] if self.time_major else [1]
        if self.go_backwards:
            inputs = tf.reverse(inputs, reverse_sim)

        inputs_batch_major = inputs
        if self.time_major:
            # go to batch_major for convolution if needed
            inputs_batch_major = tf.transpose(inputs, (1, 0, 2),
                                              name='to_batch_major')

        gate_values = self.conv1d(inputs_batch_major)
        if self.time_major:
            # return to time_major if needed
            gate_values = tf.transpose(gate_values, (1, 0, 2),
                                       name='to_time_major')

        # Split the conv output along the last axis into z, f (and o when
        # there is an output gate).
        gate_values = tf.split(gate_values,
                               3 if self.output_gate else 2,
                               axis=-1)
        if self.output_gate:
            z, f, o = gate_values
        else:
            z, f = gate_values

        z = self.act(z)
        f = self.gate_act(f)

        if self.zoneout > 0.:
            # Both branches scale by (1 - zoneout). The training branch also
            # passes `f` through `self.drop` first.
            f = control_flow_util.smart_cond(
                training,
                # multiply by (1. - self.zoneout) due to dropout scales preserved items
                lambda: self.drop(f) * (1. - self.zoneout),
                lambda: f * (1. - self.zoneout))

        c = fo_pool(z,
                    f,
                    initial_state=initial_state,
                    time_major=self.time_major)
        # `o` is only bound when self.output_gate is true, matching this guard.
        h = self.gate_act(o) * c if self.output_gate else c

        if not self.return_sequences:
            # Last step in processing order (the reversed order if go_backwards).
            h = h[:, -1, :] if not self.time_major else h[-1, :, :]
        elif self.go_backwards:
            # Undo the input reversal so the returned sequence is in the
            # original time order.
            h = tf.reverse(h, reverse_sim)

        if self.return_state:
            last_state = c[:, -1, :] if not self.time_major else c[-1, :, :]

            return h, last_state

        return h
    def call(self, inputs, training=None):
        """Quantize `inputs` with `self.quantizer` and `self.quantizer_vars`.

        `training` (default `tf.keras.backend.learning_phase()`) decides
        whether the quantizer is called with True or False.
        """
        if training is None:
            training = tf.keras.backend.learning_phase()

        def quantize_with(train_var):
            # Bind the flag now; smart_cond calls the closure with no args.
            return lambda: self.quantizer(inputs,
                                          train_var,
                                          weights=self.quantizer_vars)

        return control_flow_util.smart_cond(training, quantize_with(True),
                                            quantize_with(False))
    def _apply_weight_quantizer(self, training, folded_conv_kernel):
        """All Keras call() logic for applying weight quantization.

        Args:
            training: Python bool or boolean tensor; selects which branch
                `smart_cond` runs.
            folded_conv_kernel: Kernel tensor to quantize.

        Returns:
            `self.weight_quantizer(folded_conv_kernel, <True|False>,
            weights=self._weight_quantizer_vars)` for the selected branch.
        """
        def quantize(train_flag):
            # Bind the flag now; smart_cond calls the closure with no args.
            return lambda: self.weight_quantizer(
                folded_conv_kernel,
                train_flag,
                weights=self._weight_quantizer_vars)  # pylint: disable=protected-access

        return control_flow_util.smart_cond(training, quantize(True),
                                            quantize(False))
Beispiel #12
0
    def call(self, inputs, training=None):
        """Apply `self._non_scaling_drop_op` to `inputs` in training mode.

        Returns `inputs` unchanged when `self.rate == 0.0`, and
        `array_ops.identity(inputs)` outside training mode. `training`
        defaults to `tf.keras.backend.learning_phase()`.
        """

        if self.rate == 0.0:
            return inputs

        if training is None:
            training = tf.keras.backend.learning_phase()

        # NOTE(review): this stores the dynamic shape of the *first* input on
        # the layer and never refreshes it. If `_non_scaling_drop_op` reads
        # `self.noise_shape`, a later call with a different batch size (or
        # from a different graph) would reuse a stale shape tensor -- confirm
        # and consider computing it per call instead.
        if self.noise_shape is None:
            self.noise_shape = tf.shape(inputs)

        return control_flow_util.smart_cond(
            training, lambda: self._non_scaling_drop_op(inputs),
            lambda: array_ops.identity(inputs))
Beispiel #13
0
    def call(self, inputs, training=None):
        """Randomly shift `inputs` along the time axis in training mode.

        Args:
            inputs: Tensor of rank 2, `[batch, time]`.
            training: Python bool or boolean tensor; defaults to
                `tf.keras.backend.learning_phase()` when None.

        Returns:
            `inputs` itself when `self.time_shift` is falsy. Otherwise
            `random_shift(inputs, self.time_shift, self.seed)` in training
            mode and `array_ops.identity(inputs)` outside it.

        Raises:
            ValueError: if the rank of `inputs` is not 2, including when the
                rank is unknown (None).
        """
        rank = inputs.shape.rank
        if rank != 2:  # [batch, time]
            # '%s' rather than '%d': rank is None for unknown-rank inputs, and
            # '%d' % None would raise TypeError instead of this ValueError.
            raise ValueError('inputs.shape.rank:%s must be 2' % rank)

        if not self.time_shift:
            return inputs

        if training is None:
            training = tf.keras.backend.learning_phase()
        # pylint: disable=g-long-lambda
        return control_flow_util.smart_cond(
            training, lambda: random_shift(inputs, self.time_shift, self.seed),
            lambda: array_ops.identity(inputs))
  def call(self, inputs, training=None):
    """Mask `inputs` along axes 1 and 2 via `spectrogram_masking` in training.

    Axis 1 uses `time_masks_number` / `time_mask_max_size`, and axis 2 uses
    `frequency_masks_number` / `frequency_mask_max_size`. Outside training
    mode, `array_ops.identity(inputs)` is returned. `training` defaults to
    `tf.keras.backend.learning_phase()`.
    """
    if training is None:
      training = tf.keras.backend.learning_phase()

    def masked_inputs():
      # in time dim
      time_masked = spectrogram_masking(inputs, 1, self.time_masks_number,
                                        self.time_mask_max_size)
      # in frequency dim
      return spectrogram_masking(time_masked, 2, self.frequency_masks_number,
                                 self.frequency_mask_max_size)

    return control_flow_util.smart_cond(training, masked_inputs,
                                        lambda: array_ops.identity(inputs))
Beispiel #15
0
    def _apply_scores(self, scores, value, scores_mask=None, training=None):
        """Applies attention scores to the given value tensor.

        To use this method in your attention layer, follow the steps:

        * Use `query` tensor of shape `[batch_size, Tq]` and `key` tensor of
          shape `[batch_size, Tv]` to calculate the attention `scores`.
        * Pass `scores` and `value` tensors to this method. The method applies
          `scores_mask`, calculates `attention_distribution = softmax(scores)`,
          then returns `matmul(attention_distribution, value)`.
        * Apply `query_mask` and return the result.

        Args:
          scores: Scores float tensor of shape `[batch_size, Tq, Tv]`.
          value: Value tensor of shape `[batch_size, Tv, dim]`.
          scores_mask: A boolean mask `Tensor` of shape `[batch_size, 1, Tv]`
            or `[batch_size, Tq, Tv]`. If given, scores at positions where
            `scores_mask==False` do not contribute to the result. It must
            contain at least one `True` value in each line along the last
            dimension.
          training: Python boolean indicating whether the layer should behave
            in training mode (adding dropout) or in inference mode (no
            dropout). Defaults to `backend.learning_phase()` when None.

        Returns:
          A tuple `(outputs, weights)`:
            outputs: Tensor of shape `[batch_size, Tq, dim]`.
            weights: Attention weights after masking and softmax (and dropout
              in training mode), shape `[batch_size, Tq, Tv]`.
        """
        if scores_mask is not None:
            padding_mask = math_ops.logical_not(scores_mask)
            # Subtract a large bias at padded positions so they do not
            # contribute to the attention distribution. 65504. is the largest
            # float16 value, so it replaces 1e9 for float16 scores.
            bias = 65504. if scores.dtype is dtypes.float16 else 1.e9
            scores -= bias * math_ops.cast(padding_mask, dtype=scores.dtype)

        if training is None:
            training = backend.learning_phase()
        distribution = nn.softmax(scores)

        def dropped_weights():
            return nn.dropout(distribution, rate=self.dropout)

        def kept_weights():
            return array_ops.identity(distribution)

        weights = control_flow_util.smart_cond(training, dropped_weights,
                                               kept_weights)
        return math_ops.matmul(weights, value), weights
Beispiel #16
0
    def call(self, inputs, training=True):
        """Zero out random entries of `inputs` in training mode.

        Each entry is zeroed where a binomial sample (with `counts` of 1 and
        `probs=self.probs`) equals 1. Outside training mode, `inputs` is
        returned as-is. `training` defaults to True, so the
        `K.learning_phase()` fallback only applies when None is passed
        explicitly.
        """
        # NOTE(review): `self.seed` feeds a *stateless* sampler, so the mask
        # repeats on every call unless `self.seed` changes -- confirm intent.
        if training is None:
            training = K.learning_phase()

        def mask_inputs():
            drop_mask = tf.random.stateless_binomial(
                shape=tf.shape(inputs),
                seed=self.seed,
                counts=tf.ones((tf.shape(inputs)[1],)),
                probs=self.probs)
            return tf.where(drop_mask == 1, tf.zeros_like(inputs), inputs)

        return control_flow_util.smart_cond(training, mask_inputs,
                                            lambda: inputs)
    def _apply_activation_quantizer(self, training, activation_output):
        """All Keras call() logic for applying activation quantization.

        Args:
            training: Python bool or boolean tensor; selects which branch
                `smart_cond` runs.
            activation_output: Tensor to quantize.

        Returns:
            `self.activation_quantizer(activation_output, <True|False>,
            weights={'min_var': ..., 'max_var': ...})` for the selected branch.
        """
        def make_quantizer_fn(train_flag):
            def quantizer_fn():
                # Variables handed to the quantizer under 'min_var'/'max_var'.
                range_vars = {
                    'min_var': self._activation_min_var,  # pylint: disable=protected-access
                    'max_var': self._activation_max_var,  # pylint: disable=protected-access
                }
                return self.activation_quantizer(activation_output,
                                                 train_flag,
                                                 weights=range_vars)

            return quantizer_fn

        return control_flow_util.smart_cond(
            training, make_quantizer_fn(True), make_quantizer_fn(False))
Beispiel #18
0
    def wrap_with_training_arg(*args, **kwargs):
        """Wrap the `wrapped_call` function, and set training argument.

        Resolves the training value from the call arguments (falling back to
        `default_training_value`, then `K.learning_phase()`). It then
        dispatches through `smart_cond`, so `wrapped_call` is invoked with the
        training argument replaced by a Python `True` or `False`.
        """
        training_arg_index = get_training_arg_index(original_call)
        training = get_training_arg(training_arg_index, args, kwargs)
        if training is None:
            # NOTE(review): `or` treats any falsy default (including an
            # explicit False) as unset and falls back to K.learning_phase();
            # confirm this precedence is intended.
            training = default_training_value or K.learning_phase()

        # Mutable copies, presumably so set_training_arg can write into them
        # without modifying the caller's kwargs dict.
        args = list(args)
        kwargs = kwargs.copy()

        def replace_training_and_call(training):
            # Called once per smart_cond branch with a Python bool literal.
            set_training_arg(training, training_arg_index, args, kwargs)
            return wrapped_call(*args, **kwargs)

        return control_flow_util.smart_cond(
            training, lambda: replace_training_and_call(True),
            lambda: replace_training_and_call(False))
Beispiel #19
0
  def call(self, inputs, training=None):
    """Apply `self.masks_number` random cutouts to `inputs` in training mode.

    Args:
      inputs: Tensor of rank 3, `[batch, time, feature]`.
      training: Python bool or boolean tensor; defaults to
        `tf.keras.backend.learning_phase()` when None.

    Returns:
      In training mode, `inputs` after repeated `random_cutout` calls with
      size `(self.time_mask_size, self.frequency_mask_size)` and seeds
      `self.seed + i`. Otherwise `array_ops.identity(inputs)`.

    Raises:
      ValueError: if the rank of `inputs` is not 3, including when the rank
        is unknown (None).
    """
    rank = inputs.shape.rank
    if rank != 3:  # [batch, time, feature]
      # '%s' rather than '%d': rank is None for unknown-rank inputs, and
      # '%d' % None would raise TypeError instead of this ValueError.
      raise ValueError('inputs.shape.rank:%s must be 3' % rank)

    if training is None:
      training = tf.keras.backend.learning_phase()

    def masked_inputs():
      # random_cutout works on a trailing channel axis; add it, then drop it.
      net = tf.keras.backend.expand_dims(inputs, axis=-1)
      for i in range(self.masks_number):
        net = random_cutout(
            net, (self.time_mask_size, self.frequency_mask_size),
            seed=self.seed + i)
      net = tf.keras.backend.squeeze(net, axis=-1)
      return net

    outputs = control_flow_util.smart_cond(training, masked_inputs,
                                           lambda: array_ops.identity(inputs))
    return outputs
Beispiel #20
0
    def call(self, inputs, training=None, mask=None):
        """Project inputs to `self.units` logits, add the loss, return softmax.

        All ops are created under `tf.device('cpu:0')`.

        Args:
            inputs: Pair `(input_logits, input_targets)`; either may be ragged.
                Despite its name, `input_logits` is the layer input that gets
                reshaped to `[-1, self.num_channels]` and projected by
                `self.kernel` below.
            training: Python bool or boolean tensor choosing `_train_loss`
                vs `_eval_loss`; defaults to `tf.keras.backend.learning_phase()`.
            mask: Optional boolean tensor AND-ed into the loss weights.

        Returns:
            Softmax probabilities with the input's leading shape and last
            dimension `self.units`, passed through `maybe_convert_to_ragged`.
        """
        with tf.device('cpu:0'):
            if training is None:
                training = tf.keras.backend.learning_phase()

            input_logits, input_targets = inputs
            input_logits = tf.cast(input_logits, self.compute_dtype)
            input_logits, row_lengths = convert_inputs_if_ragged(input_logits)
            input_targets, _ = convert_inputs_if_ragged(input_targets)
            is_ragged_input = (row_lengths is not None)

            # All-True weights shaped like the targets; for ragged input the
            # densified padding positions are filled with False.
            loss_weights = tf.ones_like(input_targets, dtype=tf.bool)
            loss_weights = maybe_convert_to_ragged(is_ragged_input, loss_weights, row_lengths)
            if is_ragged_input:
                loss_weights = loss_weights.to_tensor(False)
            if mask is not None:
                loss_weights = tf.logical_and(loss_weights, mask)
            loss_weights = tf.cast(loss_weights, self.compute_dtype)

            # Output keeps every leading dimension, with the last replaced by
            # self.units; everything else is flattened to 2-D / 1-D here.
            input_shape = tf.shape(input_logits)
            output_shape = tf.stack(tf.unstack(input_shape)[:-1] + [self.units])
            input_logits = tf.reshape(input_logits, [-1, self.num_channels])
            input_targets = tf.reshape(input_targets, [-1])
            loss_weights = tf.reshape(loss_weights, [-1])

            # With transpose_b=True the kernel is presumably
            # [units, num_channels] -- TODO confirm against build().
            output_logits = tf.matmul(input_logits, self.kernel, transpose_b=True)
            output_logits = tf.nn.bias_add(output_logits, self.bias)

            # The training branch receives the flattened *inputs*, the eval
            # branch the full output logits.
            loss = control_flow_util.smart_cond(
                training,
                lambda: self._train_loss(input_logits, input_targets),
                lambda: self._eval_loss(output_logits, input_targets)
            )
            loss = compute_weighted_loss(loss, sample_weight=loss_weights, reduction=self.loss_reduction)
            self.add_loss(loss, inputs=True)

            output_probs = tf.nn.softmax(output_logits)
            output_probs = tf.reshape(output_probs, output_shape)
            output_probs = maybe_convert_to_ragged(is_ragged_input, output_probs, row_lengths)

            return output_probs
    def call(self, inputs, training=None):
        """Register pruning updates, then run the wrapped layer.

        In training mode, an update op is added that runs
        `self.pruning_obj.conditional_mask_update()` after asserting that
        `self.pruning_step >= 0`. `self.pruning_obj.weight_mask_op()` is
        always added as an update.

        Args:
            inputs: Forwarded to `self.layer.call`.
            training: Python bool or boolean tensor; defaults to
                `K.learning_phase()`. Also forwarded to the wrapped layer
                when its `call` accepts a `training` argument.

        Returns:
            The output of `self.layer.call`.
        """
        if training is None:
            training = K.learning_phase()

        def update_mask():
            step_check = tf.debugging.assert_greater_equal(
                self.pruning_step,
                np.int64(0),
                message=self._PRUNE_CALLBACK_ERROR_MSG)
            with tf.control_dependencies([step_check]):
                with tf.control_dependencies(
                        [self.pruning_obj.conditional_mask_update()]):
                    return tf.no_op('update')

        def skip_update():
            return tf.no_op('no_update')

        self.add_update(
            control_flow_util.smart_cond(training, update_mask, skip_update))
        # Always execute the op that performs weights = weights * mask.
        # Relies on the UpdatePruningStep callback to ensure the weights are
        # sparse after the final backpropagation.
        #
        # self.add_update does nothing during eager execution.
        self.add_update(self.pruning_obj.weight_mask_op())

        # TODO(evcu) remove the getargspec fallback after dropping py2
        # support; in py3 getargspec is deprecated.
        get_argspec = getattr(inspect, 'getfullargspec', None)
        if get_argspec is None:
            get_argspec = inspect.getargspec
        layer_call_args = get_argspec(self.layer.call).args

        if 'training' not in layer_call_args:
            return self.layer.call(inputs)
        # Propagate the training flag to layers that accept it.
        return self.layer.call(inputs, training=training)
Beispiel #22
0
            def variance_update():
                """Update the moving variance in training mode.

                When not training, `self.moving_variance` is returned as-is.
                With `self.renorm`, the update goes through
                `self.moving_stddev`.
                """
                def renorm_update():
                    # epsilon is folded into moving_stddev to mirror the
                    # training code path.
                    moving_stddev = _do_update(
                        self.moving_stddev,
                        math_ops.sqrt(new_variance + self.epsilon))
                    # relu guards against floating point rounding pushing the
                    # variance slightly negative.
                    return self._assign_new_value(
                        self.moving_variance,
                        K.relu(moving_stddev * moving_stddev - self.epsilon))

                def plain_update():
                    return _do_update(self.moving_variance, new_variance)

                def keep_current():
                    return self.moving_variance

                update_fn = renorm_update if self.renorm else plain_update
                return control_flow_util.smart_cond(training, update_fn,
                                                    keep_current)
Beispiel #23
0
    def call(self, inputs, training=None):
        """Dropout that also returns the keep mask it used.

        Returns:
            A tuple `(output, keep_mask)`. In training mode, `output` is
            `inputs` scaled by `1 / (1 - rate)` and multiplied by the cast
            `keep_mask`, which is `random_uniform(noise_shape) >= rate`.
            Outside training mode, `output` is an identity copy of `inputs`
            and `keep_mask` is all True (`ones_like(inputs) > 0`).

        Raises:
            ValueError: for a numeric rate outside [0, 1), a non-floating
                input, a rate that is neither a number nor a tensor, or a
                rate tensor whose dtype is incompatible with the input.
        """
        if training is None:
            training = K.learning_phase()

        def dropped_inputs():
            rate = self.rate
            noise_shape = self.noise_shape
            seed = self.seed
            with ops.name_scope(None, "coordinated_dropout", [inputs]) as name:
                is_rate_number = isinstance(rate, numbers.Real)
                if is_rate_number and (rate < 0 or rate >= 1):
                    raise ValueError(
                        "rate must be a scalar tensor or a float in the "
                        "range [0, 1), got %g" % rate)
                x = ops.convert_to_tensor(inputs, name="x")
                x_dtype = x.dtype
                if not x_dtype.is_floating:
                    raise ValueError(
                        "x has to be a floating point tensor since it's going "
                        "to be scaled. Got a %s tensor instead." % x_dtype)
                is_executing_eagerly = context.executing_eagerly()
                # Pre-scale kept values by 1 / (1 - rate) before masking.
                if not tensor_util.is_tensor(rate):
                    if is_rate_number:
                        keep_prob = 1 - rate
                        scale = 1 / keep_prob
                        scale = ops.convert_to_tensor(scale, dtype=x_dtype)
                        ret = gen_math_ops.mul(x, scale)
                    else:
                        raise ValueError(
                            "rate is neither scalar nor scalar tensor %r" %
                            rate)
                else:
                    rate.get_shape().assert_has_rank(0)
                    rate_dtype = rate.dtype
                    if rate_dtype != x_dtype:
                        # NOTE(review): "incomptaible" is a typo in this
                        # runtime message; fixing it changes the error text.
                        if not rate_dtype.is_compatible_with(x_dtype):
                            raise ValueError(
                                "Tensor dtype %s is incomptaible with Tensor dtype %s: %r"
                                % (x_dtype.name, rate_dtype.name, rate))
                        rate = gen_math_ops.cast(rate, x_dtype, name="rate")
                    one_tensor = constant_op.constant(1, dtype=x_dtype)
                    ret = gen_math_ops.real_div(
                        x, gen_math_ops.sub(one_tensor, rate))

                noise_shape = nn_ops._get_noise_shape(x, noise_shape)
                # Sample a uniform distribution on [0.0, 1.0) and select values larger
                # than rate.
                #
                # NOTE: Random uniform can only generate 2^23 floats on [1.0, 2.0)
                # and subtract 1.0.
                random_tensor = random_ops.random_uniform(noise_shape,
                                                          seed=seed,
                                                          dtype=x_dtype)
                # NOTE: if (1.0 + rate) - 1 is equal to rate, then that float is selected,
                # hence a >= comparison is used.
                # NOTE(review): keep_mask has `noise_shape`, while the
                # inference branch's mask has the full input shape; the two
                # differ whenever a custom noise_shape broadcasts -- confirm
                # callers handle both.
                keep_mask = random_tensor >= rate
                ret = gen_math_ops.mul(ret,
                                       gen_math_ops.cast(keep_mask, x_dtype))
                if not is_executing_eagerly:
                    ret.set_shape(x.get_shape())
                return ret, keep_mask

        output = control_flow_util.smart_cond(
            training, dropped_inputs, lambda:
            (array_ops.identity(inputs), array_ops.ones_like(inputs) > 0))
        return output
Beispiel #24
0
    def _subdiv_batch_norm(self, inputs, training=None):
        # tf.print('bn', self.local_count)
        training = self._get_training_value(training)

        inputs_dtype = inputs.dtype.base_dtype
        if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
            # Do all math in float32 if given 16-bit inputs for numeric stability.
            # In particular, it's very easy for variance to overflow in float16 and
            # for safety we also choose to cast bfloat16 to float32.
            inputs = math_ops.cast(inputs, dtypes.float32)

        params_dtype = self._param_dtype

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.shape) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        # what does this do...
        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # is training value true false or None
        training_value = control_flow_util.constant_value(training)
        update_value = (self.local_count + 1) % self.subdivisions == 0
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            mean, variance = self.moving_mean, self.moving_variance
        else:
            # training_value could be True or None -> None means determine at runtime
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = control_flow_util.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = control_flow_util.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1

            # normalization stats for the current batch important = mean and squared_mean
            mean, net_sum, variance, squared_mean, input_batch_size = self.subdiv_moments(
                math_ops.cast(inputs, self._param_dtype),
                reduction_axes,
                keep_dims=keep_dims)

            # aggregate the things
            def _update_aggragate_sum():
                return self._assign_subdiv_rotating_sum(
                    self.aggregated_sum_batch, net_sum, self.subdivisions,
                    self.local_count, input_batch_size)

            def _update_aggragate_squared_sum():
                return self._assign_subdiv_rotating_sum(
                    self.aggregated_square_sum_batch, squared_mean,
                    self.subdivisions, self.local_count, input_batch_size)

            def _update_aggragate_batch_size():
                return self._assign_subdiv_rotating_sum(
                    self.aggregated_batch_size, input_batch_size,
                    self.subdivisions, self.local_count, input_batch_size)

            self.add_update(_update_aggragate_sum)
            self.add_update(_update_aggragate_squared_sum)
            self.add_update(_update_aggragate_batch_size)

            aggregated_mean = self.aggregated_sum_batch / math_ops.cast(
                self.aggregated_batch_size, params_dtype)
            aggregated_squared_mean = self.aggregated_square_sum_batch / math_ops.cast(
                self.aggregated_batch_size, params_dtype)
            aggregated_variance = aggregated_squared_mean - math_ops.square(
                aggregated_mean)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            # if we are training use the stats for this batch for normalizing this
            # value other wise use the moving average

            # should only happen when we update the moving values
            mean = control_flow_util.smart_cond(
                training,
                true_fn=lambda: mean,
                false_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                    moving_mean))
            variance = control_flow_util.smart_cond(
                training,
                true_fn=lambda: variance,
                false_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                    moving_variance))

            # circular update of the mean and variance
            new_mean = control_flow_util.smart_cond(
                update_value,
                true_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                    aggregated_mean),
                false_fn=lambda: moving_mean)

            new_variance = control_flow_util.smart_cond(
                update_value,
                true_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                    aggregated_variance),
                false_fn=lambda: moving_variance)

            # # should only be done when the moving mean is updated
            # tf.print(new_variance, self.local_count, update_value, self.aggregated_batch_size, self.aggregated_sum_batch)

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training, input_batch_size)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   self.aggregated_batch_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, new_mean)
                false_branch = lambda: self.moving_mean
                return control_flow_util.smart_cond(training, true_branch,
                                                    false_branch)

            def variance_update():
                """Update the moving variance."""
                def true_branch_renorm():
                    # We apply epsilon as part of the moving_stddev to mirror the training
                    # code path.
                    moving_stddev = _do_update(
                        self.moving_stddev,
                        math_ops.sqrt(new_variance + self.epsilon))
                    return self._assign_new_value(
                        self.moving_variance,
                        # Apply relu in case floating point rounding causes it to go
                        # negative.
                        K.relu(moving_stddev * moving_stddev - self.epsilon))

                if self.renorm:
                    true_branch = true_branch_renorm
                else:
                    true_branch = lambda: _do_update(self.moving_variance,
                                                     new_variance)

                false_branch = lambda: self.moving_variance
                return control_flow_util.smart_cond(training, true_branch,
                                                    false_branch)

            def update_count():
                with K.name_scope('update_count') as scope:
                    # update the local count
                    return state_ops.assign_add(self.local_count,
                                                tf.cast(
                                                    1, self.local_count.dtype),
                                                name=scope)

            self.add_update(mean_update)
            self.add_update(variance_update)
            self.add_update(update_count)

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        if scale is not None:
            scale = math_ops.cast(scale, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
            outputs = math_ops.cast(outputs, inputs_dtype)

        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs
 def call(self, inputs, training=True):
     """Zero out `inputs` in training mode; pass them through otherwise.

     Args:
       inputs: Value to transform (presumably a tensor -- TODO confirm).
       training: Condition handed to `control_flow_util.smart_cond` to pick
         the branch; defaults to True, i.e. the zeroing branch.

     Returns:
       `inputs * 0` when `training` selects the true branch, otherwise
       `array_ops.identity(inputs)`.
     """
     def _zeroed():
         # Multiply rather than build a fresh zeros tensor so the result is
         # derived directly from `inputs`.
         return inputs * 0

     def _passthrough():
         return array_ops.identity(inputs)

     return control_flow_util.smart_cond(training, _zeroed, _passthrough)
Beispiel #26
0
 def mean_update():
     """Build the moving-mean update, gated on `training`.

     When `training` selects the true branch, this returns the result of
     `_do_update(self.moving_mean, new_mean)`. Otherwise it returns
     `self.moving_mean` unchanged. `self`, `new_mean`, `training` and
     `_do_update` are closed over from the enclosing function.
     """
     def _apply_update():
         return _do_update(self.moving_mean, new_mean)

     def _keep_current():
         # No update outside training: hand back the existing variable.
         return self.moving_mean

     return control_flow_util.smart_cond(training, _apply_update,
                                         _keep_current)