def __call__(self, inputs, *args, **kwargs):
    """Run the activation, quantizing its input and/or output when enabled.

    Args:
        inputs: Value handed to the wrapped activation.
        *args: Extra positional arguments forwarded to `self.activation`.
        **kwargs: Extra keyword arguments forwarded to `self.activation`.

    Returns:
        The activation output, passed through `self.quantizer` before and/or
        after the activation depending on `_should_pre_quantize()` and
        `_should_post_quantize()`.
    """

    def _branch(is_training, tensor, variables):
        # Bind the quantizer and its operands now so that each cond branch is
        # a zero-argument callable specialised on the training flag.
        quantizer = self.quantizer

        def _run():
            return quantizer(tensor, is_training, weights=variables)

        return _run

    def _quantize(tensor, variables):
        return control_flow_util.smart_cond(
            self._training,
            _branch(True, tensor, variables),
            _branch(False, tensor, variables))

    outputs = inputs
    if self._should_pre_quantize():
        outputs = _quantize(outputs, self._pre_activation_vars)

    outputs = self.activation(outputs, *args, **kwargs)

    if self._should_post_quantize():
        outputs = _quantize(outputs, self._post_activation_vars)
    return outputs
def call(self, inputs, training=None):
    """Apply dropout separately to the real and imaginary parts of `inputs`.

    Args:
        inputs: Complex-valued input; it is split with `real`/`imag`.
        training: Training flag. When `None`, `K.learning_phase()` is used.

    Returns:
        `tf.complex` built from the two (possibly dropped) components.
    """
    if training is None:
        training = K.learning_phase()

    components = {
        'real': K.math_ops.real(inputs),
        'imag': K.math_ops.imag(inputs),
    }

    def dropped_inputs(input_type):
        def _dropped_inputs():
            # Validation happens lazily, only when the branch is executed.
            if input_type not in components:
                raise ValueError("Invalid input type. "
                                 "Available values are 'real' and 'imag'")
            component = components[input_type]
            return nn.dropout(component,
                              noise_shape=self._get_noise_shape(component),
                              seed=self.seed,
                              rate=self.rate)

        return _dropped_inputs

    def _maybe_drop(input_type):
        return control_flow_util.smart_cond(
            training,
            dropped_inputs(input_type),
            lambda: array_ops.identity(components[input_type]))

    real_output = _maybe_drop('real')
    imag_output = _maybe_drop('imag')
    return tf.complex(real_output, imag_output)
def call(self, inputs, training=None):
    """Call the wrapped layer with quantized weights, activations and output.

    Args:
        inputs: Inputs forwarded to `self.layer.call`.
        training: Training flag. When `None`, the Keras learning phase is used.

    Returns:
        The wrapped layer's output, passed through the first output quantizer
        when `self._output_quantizers` is non-empty.

    Raises:
        RuntimeError: If an output quantizer is configured and the wrapped
            layer returns a list or tuple.
    """
    if training is None:
        training = tf.keras.backend.learning_phase()

    def _quantize(quantizer, tensor, quantizer_vars):
        # One cond per tensor, with a branch specialised for each mode.
        return control_flow_util.smart_cond(
            training,
            self._make_quantizer_fn(quantizer, tensor, True, quantizer_vars),
            self._make_quantizer_fn(quantizer, tensor, False, quantizer_vars))

    # Quantize every registered weight and install the results in the layer.
    quantized_weights = [
        _quantize(quantizer, weight, quantizer_vars)
        for weight, quantizer, quantizer_vars in self._weight_vars
    ]
    self.quantize_config.set_quantize_weights(self.layer, quantized_weights)

    # Tell each quantize-aware activation which mode it runs in, then install
    # them in the layer.
    for activation in self._quantize_activations:
        activation.training = training
    self.quantize_config.set_quantize_activations(
        self.layer, self._quantize_activations)

    # Only pass `training` through when the wrapped layer accepts it.
    accepted = tf_inspect.getfullargspec(self.layer.call).args
    layer_kwargs = {'training': training} if 'training' in accepted else {}
    outputs = self.layer.call(inputs, **layer_kwargs)

    if not self._output_quantizers:
        return outputs

    # A single output tensor is assumed; anything else is rejected.
    if isinstance(outputs, (list, tuple)):
        raise RuntimeError(
            'Multiple output tensors not handled currently.')

    return _quantize(self._output_quantizers[0], outputs,
                     self._output_quantizer_vars)
def call(self, inputs, weights, training: tf.constant):
    """
    Mask `inputs` with either the rb binary mask or the stored mask.

    The rb mask (`self._calc_rb_binary_mask(weights)`) is used only when both
    `training` and `weights['trainable']` hold; otherwise the mask is
    `binary_mask(weights['mask'])`.

    :param inputs: Target weights to sparsify.
    :param weights: Operation weights; the keys `mask` and `trainable` are read.
    :param training: True if the operation is called in training mode.
    :return: Result of `apply_mask` on `inputs` and the selected mask.
    """

    def _rb_masked():
        return apply_mask(inputs, self._calc_rb_binary_mask(weights))

    def _stored_masked():
        return apply_mask(inputs, binary_mask(weights['mask']))

    def _training_branch():
        return smart_cond(weights['trainable'],
                          true_fn=_rb_masked,
                          false_fn=_stored_masked)

    return smart_cond(training,
                      true_fn=_training_branch,
                      false_fn=_stored_masked)
def call(self, inputs, training=None, mask=None):
    """Compute probabilities and register the matching loss on the layer.

    Args:
        inputs: Pair `(logits, targets)`.
        training: Training flag. When `None`, the Keras learning phase is used.
        mask: Optional boolean mask AND-ed into the loss weights.

    Returns:
        The probabilities produced by `_train_probs_loss` or
        `_eval_probs_loss`, converted back to ragged when the logits were
        ragged.
    """
    if training is None:
        training = tf.keras.backend.learning_phase()

    logits, targets = inputs
    logits = tf.cast(logits, self.compute_dtype)
    logits, row_lengths = convert_inputs_if_ragged(logits)
    targets, _ = convert_inputs_if_ragged(targets)
    ragged = row_lengths is not None

    # Start from all-True weights, then mark padding (ragged case) and
    # externally masked positions as False.
    weights = maybe_convert_to_ragged(
        ragged, tf.ones_like(targets, dtype=tf.bool), row_lengths)
    if ragged:
        weights = weights.to_tensor(False)
    if mask is not None:
        weights = tf.logical_and(weights, mask)
    weights = tf.cast(weights, self.compute_dtype)

    def _train_branch():
        return self._train_probs_loss(logits, targets, weights)

    def _eval_branch():
        return self._eval_probs_loss(logits, targets, weights)

    probs, loss = control_flow_util.smart_cond(
        training, _train_branch, _eval_branch)
    self.add_loss(loss, inputs=True)

    return maybe_convert_to_ragged(ragged, probs, row_lengths)
def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name,
                                   experimental_aggregate_gradients):
    """Apply gradients only when the loss scale update says it is safe.

    Args:
        distribution: Strategy whose `extended.call_for_each_replica` runs the
            actual gradient application.
        grads_and_vars: Iterable of `(gradient, variable)` pairs.
        name: Forwarded to `self._apply_gradients`.
        experimental_aggregate_gradients: Forwarded to
            `self._apply_gradients`.

    Returns:
        A grouped op of the conditional apply and the loss scale update.
    """
    grads = [grad for grad, _unused_var in grads_and_vars]
    loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads)

    def _apply():
        # The wrapped optimizer expects mirrored variables even in a replica
        # context, so the variables are wrapped in an _UnwrapPreventer to stop
        # DistributionStrategy from unwrapping them.
        guarded_vars = _UnwrapPreventer([var for _, var in grads_and_vars])
        replica_args = (grads, guarded_vars, name,
                        experimental_aggregate_gradients)
        return distribution.extended.call_for_each_replica(
            self._apply_gradients, args=replica_args)

    def _skip():
        # self._optimizer.apply_gradients() normally increments the iteration
        # counter; it is not called on this branch, so increment here instead.
        return self._optimizer.iterations.assign_add(1, read_value=False)

    # Note: this cond must be built in a cross-replica context.
    # DistributionStrategy does not support a cond in a replica context with a
    # branch that calls `merge_call`, and self._optimizer.apply_gradients does.
    maybe_apply_op = control_flow_util.smart_cond(
        should_apply_grads, _apply, _skip)
    return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)
def call(self, inputs, training=True):
    """Replace randomly selected entries with entries from shuffled rows.

    Args:
        inputs: Input tensor; its first axis is the one that gets shuffled.
        training: Training flag. `None` falls back to `K.learning_phase()`.

    Returns:
        In training mode, `inputs` with the positions drawn by the binomial
        mask replaced by the same positions of a row-shuffled copy; otherwise
        `inputs` unchanged.
    """
    if training is None:
        training = K.learning_phase()

    def mask_inputs():
        swap_mask = tf.random.stateless_binomial(
            shape=tf.shape(inputs),
            seed=self.seed,
            counts=tf.ones((tf.shape(inputs)[1], )),
            probs=self.probs,
        )
        swap_here = swap_mask == 1

        # tf.random.shuffle() without tf.gather() doesn't work in a custom
        # layer, so shuffle an index range and gather with it instead.
        # ref: https://github.com/tensorflow/tensorflow/issues/6269#issuecomment-465850464
        row_ids = tf.range(tf.shape(inputs)[0])
        permutation = tf.random.shuffle(row_ids, seed=self.seed[0])
        shuffled = tf.gather(inputs, permutation)

        return tf.where(swap_here, shuffled, inputs)

    return control_flow_util.smart_cond(training, mask_inputs, lambda: inputs)
def call(self, x, training=None):
    """Return zeros in training mode and an identity copy of `x` otherwise.

    Args:
        x: Input tensor.
        training: Training flag. When `None`, the Keras learning phase is used.

    Returns:
        `x * 0` when training, else `array_ops.identity(x)`. Outside eager
        execution the result is tagged with `_uses_learning_phase = True`.
    """
    if training is None:
        training = keras.backend.learning_phase()

    def _zeroed():
        return x * 0

    def _unchanged():
        return array_ops.identity(x)

    output = control_flow_util.smart_cond(training, _zeroed, _unchanged)

    if context.executing_eagerly():
        return output
    output._uses_learning_phase = True  # pylint: disable=protected-access
    return output
def call(self, inputs, training=None, initial_state=None):
    """Run the convolutional gates followed by fo-pooling over the sequence.

    Args:
        inputs: Input sequence; laid out time-major when `self.time_major`.
        training: Training flag. When `None`, the Keras learning phase is used.
        initial_state: Optional initial state forwarded to `fo_pool`.

    Returns:
        The output `h` (full sequence or last step, depending on
        `self.return_sequences`), or `(h, last_state)` when
        `self.return_state` is set.
    """
    if training is None:
        training = tf.keras.backend.learning_phase()

    # Axis that gets flipped when the sequence is processed backwards.
    reverse_sim = [0] if self.time_major else [1]
    if self.go_backwards:
        inputs = tf.reverse(inputs, reverse_sim)

    # The convolution is fed batch-major data: transpose time-major inputs on
    # the way in and transpose the gates back on the way out.
    if self.time_major:
        conv_inputs = tf.transpose(inputs, (1, 0, 2), name='to_batch_major')
        gate_values = self.conv1d(conv_inputs)
        gate_values = tf.transpose(gate_values, (1, 0, 2),
                                   name='to_time_major')
    else:
        gate_values = self.conv1d(inputs)

    if self.output_gate:
        z, f, o = tf.split(gate_values, 3, axis=-1)
    else:
        z, f = tf.split(gate_values, 2, axis=-1)

    z = self.act(z)
    f = self.gate_act(f)

    if self.zoneout > 0.:
        # Both branches multiply by (1. - self.zoneout) because dropout
        # rescales the items it preserves.
        def _zoneout_training():
            return self.drop(f) * (1. - self.zoneout)

        def _zoneout_inference():
            return f * (1. - self.zoneout)

        f = control_flow_util.smart_cond(
            training, _zoneout_training, _zoneout_inference)

    c = fo_pool(z, f, initial_state=initial_state, time_major=self.time_major)
    h = self.gate_act(o) * c if self.output_gate else c

    def _last_step(sequence):
        if self.time_major:
            return sequence[-1, :, :]
        return sequence[:, -1, :]

    if not self.return_sequences:
        h = _last_step(h)
    elif self.go_backwards:
        h = tf.reverse(h, reverse_sim)

    if not self.return_state:
        return h
    last_state = _last_step(c)
    return h, last_state
def call(self, inputs, training=None):
    """Quantize `inputs` with `self.quantizer` in the current mode.

    Args:
        inputs: Tensor to quantize.
        training: Training flag. When `None`, the Keras learning phase is used.

    Returns:
        Result of `self.quantizer(inputs, <bool>, weights=self.quantizer_vars)`
        where the boolean is chosen by `smart_cond` on `training`.
    """
    if training is None:
        training = tf.keras.backend.learning_phase()

    def _quantize_training():
        return self.quantizer(inputs, True, weights=self.quantizer_vars)

    def _quantize_inference():
        return self.quantizer(inputs, False, weights=self.quantizer_vars)

    return control_flow_util.smart_cond(
        training, _quantize_training, _quantize_inference)
def _apply_weight_quantizer(self, training, folded_conv_kernel):
    """Quantize the folded convolution kernel in the current mode.

    Args:
        training: Training flag used as the `smart_cond` predicate.
        folded_conv_kernel: Kernel tensor handed to `self.weight_quantizer`.

    Returns:
        The quantized kernel produced by the branch `smart_cond` selects.
    """

    def _branch(is_training):
        # Specialise the quantizer call on a Python bool for each cond branch.
        def _run():
            return self.weight_quantizer(
                folded_conv_kernel, is_training,
                weights=self._weight_quantizer_vars)  # pylint: disable=protected-access

        return _run

    return control_flow_util.smart_cond(
        training, _branch(True), _branch(False))
def call(self, inputs, training=None):
    """Apply the non-scaling drop op in training mode, identity otherwise.

    Args:
        inputs: Input tensor.
        training: Training flag. When `None`, the Keras learning phase is used.

    Returns:
        `inputs` itself when `self.rate == 0.0`; otherwise the result of
        `self._non_scaling_drop_op(inputs)` in training mode or an identity
        copy of `inputs` in inference mode.
    """
    if self.rate == 0.0:
        return inputs

    if training is None:
        training = tf.keras.backend.learning_phase()

    # NOTE(review): the shape tensor of the first input seen is stored on the
    # layer and reused by every later call -- confirm this is intended.
    if self.noise_shape is None:
        self.noise_shape = tf.shape(inputs)

    def _dropped():
        return self._non_scaling_drop_op(inputs)

    def _unchanged():
        return array_ops.identity(inputs)

    return control_flow_util.smart_cond(training, _dropped, _unchanged)
def call(self, inputs, training=None):
    """Randomly shift `inputs` in training mode.

    Args:
        inputs: Rank-2 tensor, `[batch, time]`.
        training: Training flag. When `None`, the Keras learning phase is used.

    Returns:
        `inputs` itself when `self.time_shift` is falsy; otherwise the output
        of `random_shift` in training mode or an identity copy in inference
        mode.

    Raises:
        ValueError: If `inputs` is not rank 2.
    """
    rank = inputs.shape.rank
    if rank != 2:  # [batch, time]
        raise ValueError('inputs.shape.rank:%d must be 2' % inputs.shape.rank)

    if not self.time_shift:
        return inputs

    if training is None:
        training = tf.keras.backend.learning_phase()

    def _shifted():
        return random_shift(inputs, self.time_shift, self.seed)

    def _unchanged():
        return array_ops.identity(inputs)

    return control_flow_util.smart_cond(training, _shifted, _unchanged)
def call(self, inputs, training=None):
    """Mask the spectrogram along dims 1 and 2 in training mode.

    Args:
        inputs: Input spectrogram tensor.
        training: Training flag. When `None`, the Keras learning phase is used.

    Returns:
        The masked tensor in training mode, an identity copy otherwise.
    """
    if training is None:
        training = tf.keras.backend.learning_phase()

    def masked_inputs():
        # Time dimension first ...
        time_masked = spectrogram_masking(
            inputs, 1, self.time_masks_number, self.time_mask_max_size)
        # ... then the frequency dimension on top of it.
        return spectrogram_masking(
            time_masked, 2, self.frequency_masks_number,
            self.frequency_mask_max_size)

    def _unchanged():
        return array_ops.identity(inputs)

    return control_flow_util.smart_cond(training, masked_inputs, _unchanged)
def _apply_scores(self, scores, value, scores_mask=None, training=None):
    """Turn attention scores into a distribution and apply it to `value`.

    Intended use from an attention layer:

    * Compute `scores` from a `query` tensor of shape `[batch_size, Tq]` and
      a `key` tensor of shape `[batch_size, Tv]`.
    * Pass `scores` and `value` here. This method applies `scores_mask`,
      computes `attention_distribution = softmax(scores)` and returns
      `matmul(attention_distribution, value)`.
    * Apply `query_mask` to the result.

    Args:
        scores: Scores float tensor of shape `[batch_size, Tq, Tv]`.
        value: Value tensor of shape `[batch_size, Tv, dim]`.
        scores_mask: A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or
            `[batch_size, Tq, Tv]`. Positions where `scores_mask==False` do
            not contribute to the result. Each line along the last dimension
            must contain at least one `True`.
        training: Python boolean selecting training mode (dropout applied) or
            inference mode (no dropout). `None` uses the backend learning
            phase.

    Returns:
        Tensor of shape `[batch_size, Tq, dim]`.
        Attention scores after masking and softmax with shape
        `[batch_size, Tq, Tv]`.
    """
    if scores_mask is not None:
        padding_mask = math_ops.logical_not(scores_mask)
        # Push padded positions towards -inf so softmax ignores them.
        # 65504. is the max float16 value.
        bias = 65504. if scores.dtype is dtypes.float16 else 1.e9
        scores -= bias * math_ops.cast(padding_mask, dtype=scores.dtype)

    if training is None:
        training = backend.learning_phase()

    distribution = nn.softmax(scores)

    def dropped_weights():
        return nn.dropout(distribution, rate=self.dropout)

    def _plain_weights():
        return array_ops.identity(distribution)

    weights = control_flow_util.smart_cond(
        training, dropped_weights, _plain_weights)
    return math_ops.matmul(weights, value), weights
def call(self, inputs, training=True):
    """Zero out randomly selected entries of `inputs` in training mode.

    Args:
        inputs: Input tensor.
        training: Training flag. `None` falls back to `K.learning_phase()`.

    Returns:
        In training mode, `inputs` with the positions drawn by the binomial
        mask set to zero; otherwise `inputs` unchanged.
    """
    if training is None:
        training = K.learning_phase()

    def mask_inputs():
        drop_mask = tf.random.stateless_binomial(
            shape=tf.shape(inputs),
            seed=self.seed,
            counts=tf.ones((tf.shape(inputs)[1],)),
            probs=self.probs)
        drop_here = drop_mask == 1
        return tf.where(drop_here, tf.zeros_like(inputs), inputs)

    return control_flow_util.smart_cond(training, mask_inputs, lambda: inputs)
def _apply_activation_quantizer(self, training, activation_output):
    """Quantize the activation output in the current mode.

    Args:
        training: Training flag used as the `smart_cond` predicate.
        activation_output: Tensor handed to `self.activation_quantizer`.

    Returns:
        The quantized activation produced by the branch `smart_cond` selects.
    """

    def _branch(is_training):
        # Specialise the quantizer call on a Python bool for each cond branch.
        def _run():
            min_max_vars = {
                'min_var': self._activation_min_var,  # pylint: disable=protected-access
                'max_var': self._activation_max_var,  # pylint: disable=protected-access
            }
            return self.activation_quantizer(
                activation_output, is_training, weights=min_max_vars)

        return _run

    return control_flow_util.smart_cond(
        training, _branch(True), _branch(False))
def wrap_with_training_arg(*args, **kwargs):
    """Call `wrapped_call` with its training argument forced to a Python bool.

    The training value is read from the call arguments; when it is `None` it
    falls back to `default_training_value or K.learning_phase()`. `smart_cond`
    then picks the branch that calls `wrapped_call` with `True` or `False`.
    """
    arg_index = get_training_arg_index(original_call)
    training = get_training_arg(arg_index, args, kwargs)
    if training is None:
        training = default_training_value or K.learning_phase()

    # Mutable copies are handed to set_training_arg (presumably updated in
    # place, since its return value is not used -- confirm).
    call_args = list(args)
    call_kwargs = kwargs.copy()

    def _call_with(value):
        def _branch():
            set_training_arg(value, arg_index, call_args, call_kwargs)
            return wrapped_call(*call_args, **call_kwargs)

        return _branch

    return control_flow_util.smart_cond(
        training, _call_with(True), _call_with(False))
def call(self, inputs, training=None):
    """Apply `self.masks_number` random cutouts in training mode.

    Args:
        inputs: Rank-3 tensor, `[batch, time, feature]`.
        training: Training flag. When `None`, the Keras learning phase is used.

    Returns:
        The masked tensor in training mode, an identity copy otherwise.

    Raises:
        ValueError: If `inputs` is not rank 3.
    """
    rank = inputs.shape.rank
    if rank != 3:  # [batch, time, feature]
        raise ValueError('inputs.shape.rank:%d must be 3' % inputs.shape.rank)

    if training is None:
        training = tf.keras.backend.learning_phase()

    def masked_inputs():
        # random_cutout is fed a tensor with an extra trailing axis, which is
        # squeezed away again after the last mask.
        expanded = tf.keras.backend.expand_dims(inputs, axis=-1)
        mask_size = (self.time_mask_size, self.frequency_mask_size)
        for offset in range(self.masks_number):
            # A distinct seed per mask.
            expanded = random_cutout(
                expanded, mask_size, seed=self.seed + offset)
        return tf.keras.backend.squeeze(expanded, axis=-1)

    def _unchanged():
        return array_ops.identity(inputs)

    return control_flow_util.smart_cond(training, masked_inputs, _unchanged)
def call(self, inputs, training=None, mask=None):
    """Compute output probabilities and register a weighted loss, on CPU.

    Args:
        inputs: Pair `(logits, targets)`.
        training: Training flag. When `None`, the Keras learning phase is used.
        mask: Optional boolean mask AND-ed into the loss weights.

    Returns:
        Softmax of the projected logits, reshaped to the input's leading
        dimensions plus `self.units`, and converted back to ragged when the
        logits were ragged.
    """
    with tf.device('cpu:0'):
        if training is None:
            training = tf.keras.backend.learning_phase()

        logits, targets = inputs
        logits = tf.cast(logits, self.compute_dtype)
        logits, row_lengths = convert_inputs_if_ragged(logits)
        targets, _ = convert_inputs_if_ragged(targets)
        ragged = row_lengths is not None

        # Start from all-True weights, then mark padding (ragged case) and
        # externally masked positions as False.
        weights = maybe_convert_to_ragged(
            ragged, tf.ones_like(targets, dtype=tf.bool), row_lengths)
        if ragged:
            weights = weights.to_tensor(False)
        if mask is not None:
            weights = tf.logical_and(weights, mask)
        weights = tf.cast(weights, self.compute_dtype)

        # Output keeps every leading dimension and swaps the last for units.
        leading_dims = tf.unstack(tf.shape(logits))[:-1]
        output_shape = tf.stack(leading_dims + [self.units])

        # Flatten everything so the projection and the loss work per item.
        flat_logits = tf.reshape(logits, [-1, self.num_channels])
        flat_targets = tf.reshape(targets, [-1])
        flat_weights = tf.reshape(weights, [-1])

        projected = tf.matmul(flat_logits, self.kernel, transpose_b=True)
        output_logits = tf.nn.bias_add(projected, self.bias)

        def _training_loss():
            # The training loss is given the un-projected inputs.
            return self._train_loss(flat_logits, flat_targets)

        def _inference_loss():
            return self._eval_loss(output_logits, flat_targets)

        loss = control_flow_util.smart_cond(
            training, _training_loss, _inference_loss)
        loss = compute_weighted_loss(
            loss, sample_weight=flat_weights, reduction=self.loss_reduction)
        self.add_loss(loss, inputs=True)

        probs = tf.reshape(tf.nn.softmax(output_logits), output_shape)
        return maybe_convert_to_ragged(ragged, probs, row_lengths)
def call(self, inputs, training=None):
    """Register the pruning updates, then call the wrapped layer.

    Args:
        inputs: Inputs forwarded to `self.layer.call`.
        training: Training flag. `None` falls back to `K.learning_phase()`.

    Returns:
        The output of `self.layer.call`.
    """
    if training is None:
        training = K.learning_phase()

    def add_update():
        # The step check is a control dependency of the mask update, which in
        # turn is a control dependency of the returned no-op.
        step_check = tf.debugging.assert_greater_equal(
            self.pruning_step,
            np.int64(0),
            message=self._PRUNE_CALLBACK_ERROR_MSG)
        with tf.control_dependencies([step_check]):
            mask_update = self.pruning_obj.conditional_mask_update()
            with tf.control_dependencies([mask_update]):
                return tf.no_op('update')

    def no_op():
        return tf.no_op('no_update')

    self.add_update(control_flow_util.smart_cond(training, add_update, no_op))

    # Always execute the op that performs weights = weights * mask.
    # Relies on the UpdatePruningStep callback to ensure the weights are
    # sparse after the final backpropagation.
    #
    # self.add_update does nothing during eager execution.
    self.add_update(self.pruning_obj.weight_mask_op())

    # TODO(evcu) remove this check after dropping py2 support. In py3
    # getargspec is deprecated.
    if hasattr(inspect, 'getfullargspec'):
        argspec = inspect.getfullargspec(self.layer.call)
    else:
        argspec = inspect.getargspec(self.layer.call)

    # Propagate the training flag only if the wrapped layer accepts it.
    if 'training' not in argspec.args:
        return self.layer.call(inputs)
    return self.layer.call(inputs, training=training)
def variance_update():
    """Update the moving variance in training mode, else return it as is."""

    def _renorm_update():
        # Epsilon is applied as part of the moving_stddev to mirror the
        # training code path.
        moving_stddev = _do_update(
            self.moving_stddev, math_ops.sqrt(new_variance + self.epsilon))
        # relu guards against floating point rounding pushing the variance
        # below zero.
        clipped_variance = K.relu(
            moving_stddev * moving_stddev - self.epsilon)
        return self._assign_new_value(self.moving_variance, clipped_variance)

    def _plain_update():
        return _do_update(self.moving_variance, new_variance)

    def _keep_current():
        return self.moving_variance

    update_branch = _renorm_update if self.renorm else _plain_update
    return control_flow_util.smart_cond(training, update_branch, _keep_current)
def call(self, inputs, training=None):
    """Apply dropout and also return the boolean keep mask that was used.

    Args:
        inputs: Floating point input tensor.
        training: Training flag. `None` falls back to `K.learning_phase()`.

    Returns:
        A pair `(output, keep_mask)`. In training mode `output` is `inputs`
        scaled by `1 / (1 - rate)` with dropped entries zeroed; in inference
        mode it is an identity copy and the mask is all `True`.

    Raises:
        ValueError: From the training branch, if `rate` is a number outside
            `[0, 1)`, is neither a number nor a tensor, has an incompatible
            dtype, or if `inputs` is not floating point.
    """
    if training is None:
        training = K.learning_phase()

    def dropped_inputs():
        rate = self.rate
        noise_shape = self.noise_shape
        seed = self.seed

        with ops.name_scope(None, "coordinated_dropout", [inputs]):
            rate_is_number = isinstance(rate, numbers.Real)
            if rate_is_number and (rate < 0 or rate >= 1):
                raise ValueError(
                    "rate must be a scalar tensor or a float in the "
                    "range [0, 1), got %g" % rate)

            x = ops.convert_to_tensor(inputs, name="x")
            x_dtype = x.dtype
            if not x_dtype.is_floating:
                raise ValueError(
                    "x has to be a floating point tensor since it's going "
                    "to be scaled. Got a %s tensor instead." % x_dtype)

            if tensor_util.is_tensor(rate):
                # Tensor rate: must be a scalar with a dtype compatible with
                # x; cast it to x's dtype when the dtypes differ.
                rate.get_shape().assert_has_rank(0)
                rate_dtype = rate.dtype
                if rate_dtype != x_dtype:
                    if not rate_dtype.is_compatible_with(x_dtype):
                        raise ValueError(
                            "Tensor dtype %s is incomptaible with Tensor dtype %s: %r"
                            % (x_dtype.name, rate_dtype.name, rate))
                    rate = gen_math_ops.cast(rate, x_dtype, name="rate")
                one_tensor = constant_op.constant(1, dtype=x_dtype)
                scaled = gen_math_ops.real_div(
                    x, gen_math_ops.sub(one_tensor, rate))
            elif rate_is_number:
                # Python number: fold the scale into a constant.
                keep_prob = 1 - rate
                scale = ops.convert_to_tensor(1 / keep_prob, dtype=x_dtype)
                scaled = gen_math_ops.mul(x, scale)
            else:
                raise ValueError(
                    "rate is neither scalar nor scalar tensor %r" % rate)

            noise_shape = nn_ops._get_noise_shape(x, noise_shape)
            # Sample a uniform distribution on [0.0, 1.0) and select values
            # larger than rate.
            #
            # NOTE: Random uniform can only generate 2^23 floats on [1.0, 2.0)
            # and subtract 1.0.
            random_tensor = random_ops.random_uniform(
                noise_shape, seed=seed, dtype=x_dtype)
            # NOTE: if (1.0 + rate) - 1 is equal to rate, then that float is
            # selected, hence a >= comparison is used.
            keep_mask = random_tensor >= rate
            dropped = gen_math_ops.mul(
                scaled, gen_math_ops.cast(keep_mask, x_dtype))
            if not context.executing_eagerly():
                dropped.set_shape(x.get_shape())
            return dropped, keep_mask

    def _passthrough():
        return array_ops.identity(inputs), array_ops.ones_like(inputs) > 0

    return control_flow_util.smart_cond(training, dropped_inputs, _passthrough)
def _subdiv_batch_norm(self, inputs, training=None):
    """Batch-normalize `inputs` while accumulating statistics over subdivisions.

    When the training value is statically False the moving mean/variance are
    used directly. Otherwise the current batch statistics come from
    `self.subdiv_moments`, per-call sums are pushed into the aggregated
    buffers through `self._assign_subdiv_rotating_sum`, and the moving
    statistics are updated from the aggregated values. `self.local_count` is
    incremented by one per call on that path.

    Args:
        inputs: Tensor to normalize. float16/bfloat16 inputs are computed in
            float32 and the output is cast back.
        training: Training flag, resolved via `self._get_training_value`.

    Returns:
        The normalized tensor, with the static shape of `inputs`.
    """
    training = self._get_training_value(training)

    inputs_dtype = inputs.dtype.base_dtype
    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
        # Do all math in float32 if given 16-bit inputs for numeric stability.
        # In particular, it's very easy for variance to overflow in float16 and
        # for safety we also choose to cast bfloat16 to float32.
        inputs = math_ops.cast(inputs, dtypes.float32)

    params_dtype = self._param_dtype

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
        del reduction_axes[1]  # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

    def _broadcast(v):
        # Reshape to `broadcast_shape` only when `v` has fewer dims than the
        # input and the reduction is not over all-but-the-last axis.
        if (v is not None and len(v.shape) != ndims and
                reduction_axes != list(range(ndims - 1))):
            return array_ops.reshape(v, broadcast_shape)
        return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    # Compose two affine transforms: (scale, offset) followed by
    # (then_scale, then_offset) gives
    # x * (scale * then_scale) + (offset * then_scale + then_offset).
    def _compose_transforms(scale, offset, then_scale, then_offset):
        if then_scale is not None:
            scale *= then_scale
            offset *= then_scale
        if then_offset is not None:
            offset += then_offset
        return (scale, offset)

    # Static value of the training flag: True, False, or None when it can
    # only be determined at runtime.
    training_value = control_flow_util.constant_value(training)
    # True when this call completes a group of `self.subdivisions` calls.
    update_value = (self.local_count + 1) % self.subdivisions == 0
    if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
        mean, variance = self.moving_mean, self.moving_variance
    else:
        # training_value could be True or None -> None means determine at runtime
        if self.adjustment:
            adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
            # Adjust only during training.
            adj_scale = control_flow_util.smart_cond(
                training, lambda: adj_scale,
                lambda: array_ops.ones_like(adj_scale))
            adj_bias = control_flow_util.smart_cond(
                training, lambda: adj_bias,
                lambda: array_ops.zeros_like(adj_bias))
            scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                offset)

        keep_dims = self.virtual_batch_size is not None or len(
            self.axis) > 1

        # Normalization stats for the current batch; the sums (`net_sum`,
        # `squared_mean`) and `input_batch_size` feed the aggregation below.
        mean, net_sum, variance, squared_mean, input_batch_size = self.subdiv_moments(
            math_ops.cast(inputs, self._param_dtype),
            reduction_axes,
            keep_dims=keep_dims)

        # Deferred updates that push this call's statistics into the
        # aggregated buffers. They are passed to add_update as callables.
        def _update_aggragate_sum():
            return self._assign_subdiv_rotating_sum(
                self.aggregated_sum_batch, net_sum, self.subdivisions,
                self.local_count, input_batch_size)

        def _update_aggragate_squared_sum():
            return self._assign_subdiv_rotating_sum(
                self.aggregated_square_sum_batch, squared_mean,
                self.subdivisions, self.local_count, input_batch_size)

        def _update_aggragate_batch_size():
            return self._assign_subdiv_rotating_sum(
                self.aggregated_batch_size, input_batch_size,
                self.subdivisions, self.local_count, input_batch_size)

        self.add_update(_update_aggragate_sum)
        self.add_update(_update_aggragate_squared_sum)
        self.add_update(_update_aggragate_batch_size)

        # Aggregated mean and variance (E[x^2] - E[x]^2) from the buffers.
        aggregated_mean = self.aggregated_sum_batch / math_ops.cast(
            self.aggregated_batch_size, params_dtype)
        aggregated_squared_mean = self.aggregated_square_sum_batch / math_ops.cast(
            self.aggregated_batch_size, params_dtype)
        aggregated_variance = aggregated_squared_mean - math_ops.square(
            aggregated_mean)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        # When training, normalize with the stats of this batch; otherwise use
        # the moving averages.
        mean = control_flow_util.smart_cond(
            training,
            true_fn=lambda: mean,
            false_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                moving_mean))
        variance = control_flow_util.smart_cond(
            training,
            true_fn=lambda: variance,
            false_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                moving_variance))

        # Circular update of the mean and variance: the aggregated stats are
        # used only when `update_value` holds, else the moving values are kept.
        new_mean = control_flow_util.smart_cond(
            update_value,
            true_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                aggregated_mean),
            false_fn=lambda: moving_mean)
        new_variance = control_flow_util.smart_cond(
            update_value,
            true_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                aggregated_variance),
            false_fn=lambda: moving_variance)

        if self.renorm:
            r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                new_mean, new_variance, training, input_batch_size)
            # When training, the normalized values (say, x) will be transformed as
            # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
            # = x * (r * gamma) + (d * gamma + beta) with renorm.
            r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
            d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
            scale, offset = _compose_transforms(r, d, scale, offset)

        def _do_update(var, value):
            """Compute the updates for mean and variance."""
            return self._assign_moving_average(var, value, self.momentum,
                                               self.aggregated_batch_size)

        def mean_update():
            # Update the moving mean only in training mode.
            true_branch = lambda: _do_update(self.moving_mean, new_mean)
            false_branch = lambda: self.moving_mean
            return control_flow_util.smart_cond(training, true_branch,
                                                false_branch)

        def variance_update():
            """Update the moving variance."""

            def true_branch_renorm():
                # We apply epsilon as part of the moving_stddev to mirror the training
                # code path.
                moving_stddev = _do_update(
                    self.moving_stddev,
                    math_ops.sqrt(new_variance + self.epsilon))
                return self._assign_new_value(
                    self.moving_variance,
                    # Apply relu in case floating point rounding causes it to go
                    # negative.
                    K.relu(moving_stddev * moving_stddev - self.epsilon))

            if self.renorm:
                true_branch = true_branch_renorm
            else:
                true_branch = lambda: _do_update(self.moving_variance,
                                                 new_variance)

            false_branch = lambda: self.moving_variance
            return control_flow_util.smart_cond(training, true_branch,
                                                false_branch)

        def update_count():
            with K.name_scope('update_count') as scope:
                # update the local count
                return state_ops.assign_add(self.local_count,
                                            tf.cast(
                                                1, self.local_count.dtype),
                                            name=scope)

        self.add_update(mean_update)
        self.add_update(variance_update)
        self.add_update(update_count)

    # `inputs` may have been cast to float32 above, so bring the statistics
    # and the affine parameters to the same dtype.
    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
        offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
        scale = math_ops.cast(scale, inputs.dtype)
    outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                     _broadcast(variance), offset, scale,
                                     self.epsilon)
    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
        outputs = math_ops.cast(outputs, inputs_dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    # NOTE(review): `undo_virtual_batching` is not defined inside this
    # function; unless it exists in an enclosing scope this branch raises
    # NameError -- confirm.
    if self.virtual_batch_size is not None:
        outputs = undo_virtual_batching(outputs)
    return outputs
def call(self, inputs, training=True):
    """Return zeros in training mode and an identity copy otherwise.

    Args:
        inputs: Input tensor.
        training: Training flag used as the `smart_cond` predicate.

    Returns:
        `inputs * 0` when training, else `array_ops.identity(inputs)`.
    """

    def _zeroed():
        return inputs * 0

    def _unchanged():
        return array_ops.identity(inputs)

    return control_flow_util.smart_cond(training, _zeroed, _unchanged)
def mean_update():
    """Update the moving mean in training mode, else return it as is."""

    def _update():
        return _do_update(self.moving_mean, new_mean)

    def _keep_current():
        return self.moving_mean

    return control_flow_util.smart_cond(training, _update, _keep_current)