def _fused_batch_norm(self, inputs, training):
  """Runs fused batch norm on `inputs` and registers moving-average updates.

  Args:
    inputs: Tensor handed to `nn.fused_batch_norm`.
    training: Predicate given to `utils.smart_cond`; picks the branch that
      normalizes with batch statistics or the one that uses the moving
      statistics.

  Returns:
    The normalized output produced by the selected fused batch norm branch.
  """
  # TODO(reedwm): Add support for fp16 inputs.
  beta = self.beta if self.center else self._beta_const
  gamma = self.gamma if self.scale else self._gamma_const

  def _train_branch():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        epsilon=self.epsilon,
        data_format=self._data_format)

  def _inference_branch():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        mean=self.moving_mean,
        variance=self.moving_variance,
        epsilon=self.epsilon,
        is_training=False,
        data_format=self._data_format)

  output, mean, variance = utils.smart_cond(
      training, _train_branch, _inference_branch)

  # Give the statistics the same static shape as the moving variables.
  mean = array_ops.reshape(mean, shape=self.moving_mean.get_shape())
  variance = array_ops.reshape(
      variance, shape=self.moving_variance.get_shape())

  if not self._bessels_correction_test_only:
    # The fused op reports a variance that includes Bessel's correction;
    # undo it so the value agrees with the non-fused batch norm.
    sample_size = math_ops.cast(
        array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
    one = math_ops.cast(1.0, variance.dtype)
    factor = (sample_size - one) / sample_size
    variance = variance * factor

  training_value = utils.constant_value(training)
  if training_value is not None:
    one_minus_decay = ops.convert_to_tensor(self._one_minus_decay)
  else:
    # `training` is only known at run time: use a zero update rate on the
    # inference side.
    one_minus_decay = utils.smart_cond(
        training, lambda: self._one_minus_decay, lambda: 0.)

  if training_value is None or training_value:
    mean_update = self._assign_moving_average(
        self.moving_mean, mean, one_minus_decay)
    variance_update = self._assign_moving_average(
        self.moving_variance, variance, one_minus_decay)
    if context.in_graph_mode():
      # In Eager mode the updates already ran inside
      # assign_moving_averages, so only graph mode needs them collected.
      self.add_update(mean_update, inputs=inputs)
      self.add_update(variance_update, inputs=inputs)

  return output
def _renorm_correction_and_moments(self, mean, variance, training):
  """Returns the correction and update values for renorm.

  Args:
    mean: Batch mean tensor.
    variance: Batch variance tensor; turned into a standard deviation with
      `self.epsilon` added before the square root.
    training: Predicate given to `utils.smart_cond`. When it selects the
      false branch, `r` becomes ones, `d` becomes zeros and the decay is 1.

  Returns:
    A tuple `(r, d, new_mean, new_variance)` where `r` and `d` are the
    (optionally clipped) renorm corrections and `new_mean` / `new_variance`
    are derived from the updated renorm moving averages.
  """
  stddev = math_ops.sqrt(variance + self.epsilon)
  # Compute the average mean and standard deviation, as if they were
  # initialized with this batch's moments.
  mixed_renorm_mean = (self.renorm_mean +
                       (1. - self.renorm_mean_weight) * mean)
  mixed_renorm_stddev = (self.renorm_stddev +
                         (1. - self.renorm_stddev_weight) * stddev)
  # Compute the corrections for batch renorm.
  r = stddev / mixed_renorm_stddev
  d = (mean - mixed_renorm_mean) / mixed_renorm_stddev
  # Ensure the corrections use pre-update moving averages: everything that
  # consumes `mean` / `stddev` below is forced to run after `r` and `d`.
  with ops.control_dependencies([r, d]):
    mean = array_ops.identity(mean)
    stddev = array_ops.identity(stddev)
  # Clipping bounds are optional; a missing key leaves that bound unapplied.
  rmin, rmax, dmax = [self.renorm_clipping.get(key)
                      for key in ['rmin', 'rmax', 'dmax']]
  if rmin is not None:
    r = math_ops.maximum(r, rmin)
  if rmax is not None:
    r = math_ops.minimum(r, rmax)
  if dmax is not None:
    d = math_ops.maximum(d, -dmax)
    d = math_ops.minimum(d, dmax)
  # When not training, use r=1, d=0, and decay=1 meaning no updates.
  r = utils.smart_cond(training, lambda: r, lambda: array_ops.ones_like(r))
  d = utils.smart_cond(training, lambda: d, lambda: array_ops.zeros_like(d))
  decay = utils.smart_cond(training, lambda: self.renorm_momentum, lambda: 1.)

  def _update_renorm_variable(var, weight, value):
    """Updates a moving average and weight, returns the unbiased value."""
    # Update the variables without zero debiasing. The debiasing will be
    # accomplished by dividing the exponential moving average by the weight.
    # For example, after a single update, the moving average would be
    # (1-decay) * value. and the weight will be 1-decay, with their ratio
    # giving value.
    # Make sure the weight is not updated until before r and d computation.
    # NOTE(review): the extent of this `with` block was reconstructed from a
    # whitespace-collapsed source - confirm that only `weight_value` belongs
    # inside it.
    value = array_ops.identity(value)
    with ops.control_dependencies([value]):
      weight_value = array_ops.constant(1., dtype=weight.dtype)
    new_var = moving_averages.assign_moving_average(
        var, value, decay, zero_debias=False)
    new_weight = moving_averages.assign_moving_average(
        weight, weight_value, decay, zero_debias=False)
    return new_var / new_weight

  with ops.colocate_with(self.moving_mean):
    new_mean = _update_renorm_variable(self.renorm_mean,
                                       self.renorm_mean_weight,
                                       mean)
  with ops.colocate_with(self.moving_variance):
    new_stddev = _update_renorm_variable(self.renorm_stddev,
                                         self.renorm_stddev_weight,
                                         stddev)
    # Make sqrt(moving_variance + epsilon) = new_stddev.
    new_variance = math_ops.square(new_stddev) - self.epsilon

  return (r, d, new_mean, new_variance)
def _fused_batch_norm(self, inputs, training):
  """Runs fused batch norm on `inputs` and registers moving-average updates.

  Args:
    inputs: Tensor handed to `nn.fused_batch_norm`.
    training: Predicate given to `utils.smart_cond`; picks between the
      batch-statistics branch and the moving-statistics branch.

  Returns:
    The normalized output produced by the selected fused batch norm branch.
  """
  beta = self.beta if self.center else self._beta_const
  gamma = self.gamma if self.scale else self._gamma_const

  def _train_branch():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        epsilon=self.epsilon,
        data_format=self._data_format)

  def _inference_branch():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        mean=self.moving_mean,
        variance=self.moving_variance,
        epsilon=self.epsilon,
        is_training=False,
        data_format=self._data_format)

  output, mean, variance = utils.smart_cond(
      training, _train_branch, _inference_branch)

  if not self._bessels_correction_test_only:
    # The fused op reports a variance that includes Bessel's correction;
    # undo it so the value agrees with the non-fused batch norm.
    sample_size = math_ops.cast(
        array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
    one = math_ops.cast(1.0, variance.dtype)
    factor = (sample_size - one) / sample_size
    variance = variance * factor

  training_value = utils.constant_value(training)
  if training_value is not None:
    momentum = ops.convert_to_tensor(self.momentum)
  else:
    # `training` is only known at run time: a momentum of 1.0 is used on the
    # inference side.
    momentum = utils.smart_cond(
        training, lambda: self.momentum, lambda: 1.0)

  if training_value is None or training_value:
    mean_update = self._assign_moving_average(
        self.moving_mean, mean, momentum)
    variance_update = self._assign_moving_average(
        self.moving_variance, variance, momentum)
    self.add_update(mean_update, inputs=inputs)
    self.add_update(variance_update, inputs=inputs)

  return output
def _update_renorm_variable(var, weight, value):
  """Updates a moving average and its weight; returns the debiased value.

  NOTE(review): `self` and `training` are not parameters; this presumably
  lives as a closure inside a method that provides them - confirm.

  Args:
    var: Moving-average variable to update.
    weight: Companion weight variable, updated towards 1.
    value: New value folded into `var`.

  Returns:
    `updated var / updated weight` on the training branch, otherwise an
    identity of `var`.
  """
  value = array_ops.identity(value)

  def _apply_update():
    # The variables are updated without zero debiasing; dividing the moving
    # average by the weight does the debiasing instead. After one update the
    # average is (1-decay) * value and the weight is 1-decay, so their ratio
    # is the value.
    # Make sure the weight is not updated until before r and d computation.
    with ops.control_dependencies([value]):
      unit_weight = array_ops.constant(1., dtype=weight.dtype)
    updated_var = moving_averages.assign_moving_average(
        var, value, self.renorm_momentum, zero_debias=False)
    updated_weight = moving_averages.assign_moving_average(
        weight, unit_weight, self.renorm_momentum, zero_debias=False)
    return updated_var / updated_weight

  def _skip_update():
    return array_ops.identity(var)

  return utils.smart_cond(training, _apply_update, _skip_update)
def _update_renorm_variable(var, weight, value):
  """Updates a moving average and its weight; returns the debiased value.

  NOTE(review): `self` and `training` are not parameters; this presumably
  lives as a closure inside a method that provides them - confirm.

  Args:
    var: Moving-average variable to update.
    weight: Companion weight variable, updated towards 1.
    value: New value folded into `var`.

  Returns:
    `updated var / updated weight` on the training branch, otherwise an
    identity of `var`.
  """
  value = array_ops.identity(value)

  def _apply_update():
    """Updates the var and weight, returns their updated ratio."""
    # The variables are updated without zero debiasing; dividing the moving
    # average by the weight does the debiasing instead. After one update the
    # average is (1-decay) * value and the weight is 1-decay, so their ratio
    # is the value.
    # Make sure the weight is not updated until before r and d computation.
    with ops.control_dependencies([value]):
      unit_weight = array_ops.constant(1., dtype=weight.dtype)
    updated_var = self._assign_moving_average(
        var, value, self.renorm_momentum)
    updated_weight = self._assign_moving_average(
        weight, unit_weight, self.renorm_momentum)
    # TODO(yuefengz): the updates to var and weighted can not be batched
    # together if we fetch their updated values here. Consider calculating
    # new values and delaying the updates.
    return updated_var / updated_weight

  def _skip_update():
    return array_ops.identity(var)

  return utils.smart_cond(training, _apply_update, _skip_update)
def summary_writer_function(name, tensor, function, family=None):
  """Conditionally writes a summary and records the op in a collection.

  Args:
    name: name of the summary
    tensor: main tensor to form the summary
    function: function taking a tag and a scope which writes the summary
    family: optional, the summary's family

  Returns:
    The op produced by `utils.smart_cond` over `should_record_summaries()`.
  """

  def _write_summary():
    scope_ctx = summary_op_util.summary_scope(name, family, values=[tensor])
    with scope_ctx as (tag, scope):
      write_op = function(tag, scope)
      # The returned constant depends on the write, so fetching it triggers
      # the summary.
      with ops.control_dependencies([write_op]):
        return constant_op.constant(True)

  with ops.device("cpu:0"):
    summary_op = utils.smart_cond(
        should_record_summaries(), _write_summary, _nothing, name="")
    ops.add_to_collection(_SUMMARY_COLLECTION_NAME, summary_op)
  return summary_op
def _build_statistics(self, input_batch, axis, use_batch_stats, dtype):
  """Creates the moving statistics and picks the statistics to normalize with.

  Args:
    input_batch: Input batch Tensor.
    axis: Indices of `input_batch` to reduce over.
    use_batch_stats: Boolean to indicate if batch statistics should be
      calculated, otherwise moving averages are returned.
    dtype: TensorFlow datatype to use for the moving mean and variance.

  Returns:
    Tuple of (mean, variance).
  """

  def _moving_stat(var_name, init_key, make_initializer):
    """Fetches one moving-statistic variable, filling in its initializer."""
    if init_key not in self._initializers:
      self._initializers[init_key] = make_initializer()
    return tf.get_variable(
        var_name,
        dtype=dtype,
        shape=self._mean_shape,
        collections=[
            tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
            tf.GraphKeys.GLOBAL_VARIABLES,
        ],
        initializer=self._initializers[init_key],
        trainable=False)

  # Set up our moving statistics. When connecting in parallel, this is shared.
  self._moving_mean = _moving_stat(
      "moving_mean", self.MOVING_MEAN, create_mean_initializer)
  self._moving_variance = _moving_stat(
      "moving_variance", self.MOVING_VARIANCE, create_variance_initializer)

  def _batch_stats():
    """Builds the batch statistics calculation ops."""
    batch_mean, batch_variance = tf.nn.moments(
        input_batch, axis, keep_dims=True, name="normalize_moments")
    return batch_mean, batch_variance

  def _moving_stats():
    return (
        tf.identity(self._moving_mean),
        tf.identity(self._moving_variance),
    )

  mean, variance = utils.smart_cond(
      use_batch_stats, _batch_stats, _moving_stats)
  return mean, variance
def summary_writer_function(name, tensor, function, family=None):
  """Conditionally writes a summary and records the op in a collection.

  Args:
    name: name of the summary
    tensor: main tensor to form the summary
    function: function taking a tag and a scope which writes the summary
    family: optional, the summary's family

  Returns:
    A no-op when the context has no summary writer resource; otherwise the
    op produced by `utils.smart_cond` over `should_record_summaries()`.
  """
  # Nothing to write to: bail out before building any summary ops.
  if context.context().summary_writer_resource is None:
    return control_flow_ops.no_op()

  def _write_summary():
    scope_ctx = summary_op_util.summary_scope(name, family, values=[tensor])
    with scope_ctx as (tag, scope):
      write_op = function(tag, scope)
      # The returned constant depends on the write, so fetching it triggers
      # the summary.
      with ops.control_dependencies([write_op]):
        return constant_op.constant(True)

  with ops.device("cpu:0"):
    summary_op = utils.smart_cond(
        should_record_summaries(), _write_summary, _nothing, name="")
    ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, summary_op)  # pylint: disable=protected-access
  return summary_op
def _fused_batch_norm_op(self, input_batch, mean, variance, use_batch_stats):
  """Builds the fused batch norm op for either batch or moving statistics.

  Args:
    input_batch: Tensor to normalize.
    mean: Mean tensor used by the moving-statistics branch.
    variance: Variance tensor used by the moving-statistics branch.
    use_batch_stats: Value converted to a tensor and used as the predicate
      selecting the batch-statistics branch.

  Returns:
    Tuple `(batch_norm_op, mean, variance)` as returned by the selected
    `tf.nn.fused_batch_norm` call.
  """
  # The fused batch norm expects the mean, variance, gamma and beta
  # tensors to have dimension 1, so we flatten them to remove the
  # extra dimensions.
  flat_shape = (-1,)
  flat_gamma = tf.reshape(self._gamma, shape=flat_shape)
  flat_beta = tf.reshape(self._beta, shape=flat_shape)
  flat_mean = tf.reshape(mean, shape=flat_shape)
  flat_variance = tf.reshape(variance, shape=flat_shape)

  use_batch_stats = tf.convert_to_tensor(use_batch_stats)

  shared_kwargs = dict(
      scale=flat_gamma,
      offset=flat_beta,
      epsilon=self._eps,
      data_format=self._infer_fused_data_format(input_batch),
      name="batch_norm")

  def _with_batch_stats():
    return tf.nn.fused_batch_norm(
        input_batch, mean=None, variance=None, is_training=True,
        **shared_kwargs)

  def _with_moving_stats():
    return tf.nn.fused_batch_norm(
        input_batch, mean=flat_mean, variance=flat_variance,
        is_training=False, **shared_kwargs)

  batch_norm_op, mean, variance = utils.smart_cond(
      use_batch_stats, _with_batch_stats, _with_moving_stats)
  return batch_norm_op, mean, variance
def distort_color(image, batch_position=0, distort_color_in_yiq=False):
  """Randomly distorts the colors of `image`, then clips to [0, 1].

  Two orderings of the same random ops exist; the parity of
  `batch_position` chooses between them.

  Args:
    image: Image tensor to distort.
    batch_position: Position used to pick the ordering (even -> variant 0).
    distort_color_in_yiq: If true, saturation and hue are distorted with
      `distort_image_ops.random_hsv_in_yiq` instead of the `tf.image` ops.

  Returns:
    The distorted image, clipped to the range [0.0, 1.0].
  """

  def _brightness(img):
    return tf.image.random_brightness(img, max_delta=R_BRIGHTNESS_MAX_DELTA)

  def _contrast(img):
    return tf.image.random_contrast(
        img, lower=R_CONSTRAST_LOWER, upper=R_CONSTRAST_UPPER)

  def _saturation_and_hue(img):
    if distort_color_in_yiq:
      return distort_image_ops.random_hsv_in_yiq(
          img,
          lower_saturation=R_SATURATION_LOWER,
          upper_saturation=R_SATURATION_UPPER,
          max_delta_hue=R_HUE_MAX_DELTA * math.pi)
    img = tf.image.random_saturation(
        img, lower=R_SATURATION_LOWER, upper=R_SATURATION_UPPER)
    return tf.image.random_hue(img, max_delta=R_HUE_MAX_DELTA)

  def distort_fn_0(image=image):
    """Variant 0: brightness, saturation/hue, contrast."""
    return _contrast(_saturation_and_hue(_brightness(image)))

  def distort_fn_1(image=image):
    """Variant 1: brightness, contrast, saturation/hue."""
    return _saturation_and_hue(_contrast(_brightness(image)))

  is_even_position = batch_position % 2 == 0
  distorted = utils.smart_cond(is_even_position, distort_fn_0, distort_fn_1)
  # The random_* ops do not necessarily clamp.
  return tf.clip_by_value(distorted, 0.0, 1.0)
def _build(self, inputs, is_training=True, dropout_keep_prob=0.5):
  """Connects the `MLP` layers to `inputs`, one after another.

  Args:
    inputs: A 2D Tensor of size `[batch_size, input_size]`.
    is_training: A bool or tf.Bool Tensor. Indicates whether we are
      currently training. Defaults to `True`.
    dropout_keep_prob: The probability that each element is kept when
      both `use_dropout` and `is_training` are True. Defaults to 0.5.

  Returns:
    A 2D Tensor of size `[batch_size, output_sizes[-1]]`.
  """
  self._input_shape = tuple(inputs.get_shape().as_list())
  last_layer = self._num_layers - 1

  net = inputs
  for layer_id in xrange(self._num_layers):
    net = self._layers[layer_id](net)

    # The final layer stays linear unless `_activate_final` is set.
    if layer_id == last_layer and not self._activate_final:
      continue

    # Only perform dropout whenever we are activating the layer's outputs.
    if self._use_dropout:
      keep_prob = utils.smart_cond(
          is_training,
          true_fn=lambda: dropout_keep_prob,
          false_fn=lambda: tf.constant(1.0))
      net = tf.nn.dropout(net, keep_prob=keep_prob)
    net = self._activation(net)

  return net
def summary_writer_function(name, tensor, function, family=None):
  """Conditionally writes a summary and records the op in a collection.

  Args:
    name: name of the summary
    tensor: main tensor to form the summary
    function: function taking a tag and a scope which writes the summary
    family: optional, the summary's family

  Returns:
    A no-op when the context has no summary writer resource; otherwise the
    op produced by `utils.smart_cond` over `should_record_summaries()`.
  """
  writer_resource = context.context().summary_writer_resource
  if writer_resource is None:
    # No writer is configured, so there is nothing to record.
    return control_flow_ops.no_op()

  def _write_summary():
    scope_ctx = summary_op_util.summary_scope(name, family, values=[tensor])
    with scope_ctx as (tag, scope):
      write_op = function(tag, scope)
      # Tie the returned constant to the write so fetching it runs the write.
      with ops.control_dependencies([write_op]):
        return constant_op.constant(True)

  with ops.device("cpu:0"):
    summary_op = utils.smart_cond(
        should_record_summaries(), _write_summary, _nothing, name="")
    ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, summary_op)  # pylint: disable=protected-access
  return summary_op
def dropout_selu(x, rate, alpha=-1.7580993408473766, fixedPointMean=0.0,
                 fixedPointVar=1.0, noise_shape=None, seed=None, name=None,
                 training=False):
  """Dropout to a value with rescaling.

  Dropped elements are set to `alpha` instead of zero, and the result is
  passed through an affine map `a * ret + b` computed from
  `fixedPointMean` and `fixedPointVar`.

  Args:
    x: Input, converted to a tensor.
    rate: Drop rate; the keep probability is `1.0 - rate`.
    alpha: Scalar value assigned to dropped elements.
    fixedPointMean: Mean used when computing the affine correction.
    fixedPointVar: Variance used when computing the affine correction.
    noise_shape: Shape of the random keep/drop mask; defaults to the shape
      of `x`.
    seed: Seed forwarded to `random_ops.random_uniform`.
    name: Optional name for the enclosing name scope.
    training: Predicate for `utils.smart_cond`; when false an identity of
      `x` is returned.

  Returns:
    The dropped-out and rescaled tensor, or an identity of `x`.

  Raises:
    ValueError: If the keep probability is a real number outside (0, 1].
  """

  def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):
    keep_prob = 1.0 - rate
    x = ops.convert_to_tensor(x, name="x")
    if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
      raise ValueError(
          "keep_prob must be a scalar tensor or a float in the "
          "range (0, 1], got %g" % keep_prob)

    keep_prob = ops.convert_to_tensor(
        keep_prob, dtype=x.dtype, name="keep_prob")
    keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

    alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
    alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar())

    # Nothing is dropped when keep_prob is statically known to be 1.
    if tensor_util.constant_value(keep_prob) == 1:
      return x

    if noise_shape is None:
      noise_shape = array_ops.shape(x)

    # Uniform in [keep_prob, 1 + keep_prob); its floor is 1 for kept
    # elements and 0 for dropped ones.
    random_tensor = keep_prob + random_ops.random_uniform(
        noise_shape, seed=seed, dtype=x.dtype)
    keep_mask = math_ops.floor(random_tensor)
    ret = x * keep_mask + alpha * (1 - keep_mask)

    spread = ((1 - keep_prob) * math_ops.pow(alpha - fixedPointMean, 2)
              + fixedPointVar)
    a = math_ops.sqrt(fixedPointVar / (keep_prob * spread))
    b = fixedPointMean - a * (
        keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
    ret = a * ret + b
    ret.set_shape(x.get_shape())
    return ret

  with ops.name_scope(name, "dropout", [x]) as name:
    return utils.smart_cond(
        training,
        lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name),
        lambda: array_ops.identity(x))

# If you use the selu that ships with Keras:
# when selu is applied after normalization / standardization, the layer can
# be initialized with 'lecun_normal' when it is created (the default is
# RandomNormal).
# e.g.: model.add(Dense(inputdim=41, units=100, activation='selu', kernel_initializer='lecun_normal'))
def act(o: [so], noisy=True):
  """Builds the action for observation `o`, optionally with noise added.

  Batch norm layers created here are built with `is_training=False`.

  Args:
    o: Observation passed to `preprocess`.
    noisy: Forwarded to `actor` as `noise` and used as the predicate that
      decides whether `noise(...)` is applied to the action.

  Returns:
    The (possibly noised) action.
  """
  with arg_scope([layers.batch_norm], is_training=False):
    state = preprocess(o)
    raw_action = actor(state, noise=noisy)
    action = smart_cond(
        noisy, lambda: noise(raw_action), lambda: raw_action)
    q_value = critic(state, action)
    layers.summarize_tensors([state, action, q_value])
  return action
def call(self, inputs, training=False):
  """Applies dropout to `inputs` when `training` selects it.

  Args:
    inputs: Tensor to apply dropout to.
    training: Predicate for `utils.smart_cond`; when false an identity of
      `inputs` is returned.

  Returns:
    The dropped-out tensor, or an identity of `inputs`.
  """

  def _passthrough():
    return array_ops.identity(inputs)

  def _apply_dropout():
    keep_prob = 1 - self.rate
    return nn.dropout(
        inputs, keep_prob, noise_shape=self.noise_shape, seed=self.seed)

  return utils.smart_cond(training, _apply_dropout, _passthrough)
def call(self, inputs, training=True):
  """Applies concrete dropout to `inputs` when `training` selects it.

  Unless `self.reuse` is set, the dropout regularizer is applied to
  `inputs` first.

  Args:
    inputs: Tensor to apply dropout to.
    training: Predicate for `utils.smart_cond`; when false an identity of
      `inputs` is returned.

  Returns:
    The result of `self.concrete_dropout(inputs)`, or an identity of
    `inputs`.
  """
  if not self.reuse:
    self.apply_dropout_regularizer(inputs)

  return utils.smart_cond(
      training,
      lambda: self.concrete_dropout(inputs),
      lambda: array_ops.identity(inputs))
def call(self, inputs, training=False):
  """Applies dropout to `inputs` when `training` selects it.

  Args:
    inputs: Tensor to apply dropout to.
    training: Predicate for `utils.smart_cond`; when false an identity of
      `inputs` is returned.

  Returns:
    The dropped-out tensor, or an identity of `inputs`.
  """

  def _passthrough():
    return array_ops.identity(inputs)

  def _apply_dropout():
    keep_prob = 1 - self.rate
    return nn.dropout(
        inputs,
        keep_prob,
        noise_shape=self._get_noise_shape(inputs),
        seed=self.seed)

  return utils.smart_cond(training, _apply_dropout, _passthrough)
def distort_color(image, batch_position=0, distort_color_in_yiq=False,
                  scope=None):
  """Distort the color of the image.

  The color ops do not commute, so their order matters. Instead of randomly
  permuting the order, one of two fixed orderings is chosen from the parity
  of the image's position in the batch.

  Args:
    image: float32 Tensor containing single image. Tensor values should be
      in range [0, 1].
    batch_position: the position of the image in a batch. NOTE: this
      argument can be an integer or a tensor
    distort_color_in_yiq: distort color of input images in YIQ space.
    scope: Optional scope for op_scope.

  Returns:
    color-distorted image
  """
  with tf.name_scope(scope or 'distort_color'):

    def _brightness(img):
      return tf.image.random_brightness(img, max_delta=32. / 255.)

    def _contrast(img):
      return tf.image.random_contrast(img, lower=0.5, upper=1.5)

    def _saturation_and_hue(img):
      if distort_color_in_yiq:
        return distort_image_ops.random_hsv_in_yiq(
            img,
            lower_saturation=0.5,
            upper_saturation=1.5,
            max_delta_hue=0.2 * math.pi)
      img = tf.image.random_saturation(img, lower=0.5, upper=1.5)
      return tf.image.random_hue(img, max_delta=0.2)

    def distort_fn_0(image=image):
      """Variant 0: brightness, saturation/hue, contrast."""
      return _contrast(_saturation_and_hue(_brightness(image)))

    def distort_fn_1(image=image):
      """Variant 1: brightness, contrast, saturation/hue."""
      return _saturation_and_hue(_contrast(_brightness(image)))

    is_even_position = batch_position % 2 == 0
    distorted = utils.smart_cond(
        is_even_position, distort_fn_0, distort_fn_1)
    # The random_* ops do not necessarily clamp.
    return tf.clip_by_value(distorted, 0.0, 1.0)
def dropout_selu(x, rate, alpha=-1.7580993408473766, fixedPointMean=0.0,
                 fixedPointVar=1.0, noise_shape=None, seed=None, name=None,
                 training=False):
  """Dropout to a value with rescaling.

  Dropped elements are set to `alpha` instead of zero, and the result is
  passed through an affine map `a * ret + b` computed from
  `fixedPointMean` and `fixedPointVar`.

  Args:
    x: Input, converted to a tensor.
    rate: Drop rate; the keep probability is `1.0 - rate`.
    alpha: Scalar value assigned to dropped elements.
    fixedPointMean: Mean used when computing the affine correction.
    fixedPointVar: Variance used when computing the affine correction.
    noise_shape: Shape of the random keep/drop mask; defaults to the shape
      of `x`.
    seed: Seed forwarded to `random_ops.random_uniform`.
    name: Optional name for the enclosing name scope.
    training: Predicate for `utils.smart_cond`; when false an identity of
      `x` is returned.

  Returns:
    The dropped-out and rescaled tensor, or an identity of `x`.

  Raises:
    ValueError: If the keep probability is a real number outside (0, 1].
  """
  # Function-scope imports: the function carries its own dependencies.
  from tensorflow.python.framework import ops, tensor_shape, tensor_util
  from tensorflow.python.ops import array_ops, random_ops, math_ops
  from tensorflow.python.layers import utils
  import numbers
  import tensorflow as tf

  def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):
    keep_prob = 1.0 - rate
    x = ops.convert_to_tensor(x, name="x")
    if isinstance(keep_prob, numbers.Real) and not 0. < keep_prob <= 1.:
      raise ValueError(
          "keep_prob must be a scalar tensor or a float in the "
          "range (0, 1], got %g" % keep_prob)
    keep_prob = ops.convert_to_tensor(
        keep_prob, dtype=x.dtype, name="keep_prob")
    keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

    alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
    # Validate alpha itself; this check used to re-test keep_prob by mistake.
    alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar())

    # Nothing is dropped when keep_prob is statically known to be 1.
    if tensor_util.constant_value(keep_prob) == 1:
      return x

    noise_shape = (noise_shape if noise_shape is not None
                   else array_ops.shape(x))
    # Uniform in [keep_prob, 1 + keep_prob); its floor is 1 for kept
    # elements and 0 for dropped ones.
    random_tensor = keep_prob
    random_tensor += random_ops.random_uniform(
        noise_shape, seed=seed, dtype=x.dtype)
    binary_tensor = math_ops.floor(random_tensor)
    ret = x * binary_tensor + alpha * (1 - binary_tensor)

    a = tf.sqrt(fixedPointVar / (keep_prob * (
        (1 - keep_prob) * tf.pow(alpha - fixedPointMean, 2)
        + fixedPointVar)))
    b = fixedPointMean - a * (
        keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
    ret = a * ret + b
    ret.set_shape(x.get_shape())
    return ret

  with ops.name_scope(name, "dropout", [x]) as name:
    return utils.smart_cond(
        training,
        lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name),
        lambda: array_ops.identity(x))
def crf_log_norm(inputs, sequence_lengths, transition_params):
  """Computes the normalization for a CRF.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
      to use as input to the CRF layer.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.

  Returns:
    log_norm: A [batch_size] vector of normalizers for a CRF.
  """
  # Split up the first and rest of the inputs in preparation for the forward
  # algorithm.
  first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
  first_input = array_ops.squeeze(first_input, [1])

  def _zero_out_empty(log_norm):
    """Masks `log_norm` of the sequences with length <= zero."""
    is_empty = math_ops.less_equal(sequence_lengths, 0)
    return array_ops.where(
        is_empty, array_ops.zeros_like(log_norm), log_norm)

  def _single_seq_fn():
    # With max_seq_len == 1 the forward algorithm is skipped: the
    # normalizer is the logsumexp over the unary potentials.
    return _zero_out_empty(math_ops.reduce_logsumexp(first_input, [1]))

  def _multi_seq_fn():
    """Forward computation of alpha values."""
    rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])

    # Compute the alpha values in the forward algorithm in order to get the
    # partition function.
    forward_cell = CrfForwardRnnCell(transition_params)
    # Sequence length is not allowed to be less than zero.
    zero = constant_op.constant(0, dtype=sequence_lengths.dtype)
    sequence_lengths_less_one = math_ops.maximum(zero, sequence_lengths - 1)
    _, alphas = rnn.dynamic_rnn(
        cell=forward_cell,
        inputs=rest_of_input,
        sequence_length=sequence_lengths_less_one,
        initial_state=first_input,
        dtype=dtypes.float32)
    return _zero_out_empty(math_ops.reduce_logsumexp(alphas, [1]))

  # Prefer the static max_seq_len; fall back to the dynamic shape.
  max_seq_len = (tensor_shape.dimension_value(inputs.shape[1])
                 or array_ops.shape(inputs)[1])
  return utils.smart_cond(
      pred=math_ops.equal(max_seq_len, 1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)
def layer_dropout(self, x, reuse, is_training, l_ind):
  """Applies the configured dropout flavor to `x`.

  Args:
    x: Tensor to apply dropout to.
    reuse: Not used by this method.
    is_training: Predicate / flag that enables dropout.
    l_ind: Layer index appended to `self.dropout_name` to name the op.

  Returns:
    Alpha dropout of `x` (or an identity when not training) if
    `self.alpha_drop_flag` is set; otherwise the result of
    `tf.layers.dropout`.
  """
  name = self.dropout_name + "_" + str(l_ind)

  if not self.alpha_drop_flag:
    return tf.layers.dropout(
        inputs=x, rate=self.dropout, training=is_training, name=name)

  def _alpha_dropout():
    return tf.contrib.nn.alpha_dropout(
        x, self.dropout, noise_shape=None, seed=None, name=name)

  return utils.smart_cond(
      is_training, _alpha_dropout, lambda: array_ops.identity(x))
def gaussian_noise_layer(input_layer, std, training):
  """Adds zero-mean Gaussian noise to `input_layer` while training.

  Args:
    input_layer: Tensor to perturb.
    std: Standard deviation of the float32 noise.
    training: Predicate for `utils.smart_cond`; when false `input_layer` is
      returned unchanged.

  Returns:
    `input_layer` plus noise, or `input_layer` itself.
  """

  def _noisy():
    noise_shape = tf.shape(input_layer)
    noise = tf.random_normal(
        shape=noise_shape, mean=0.0, stddev=std, dtype=tf.float32)
    return input_layer + noise

  def _unchanged():
    return input_layer

  return utils.smart_cond(training, _noisy, _unchanged)
def _fused_batch_norm(self, inputs, training):
  """Runs fused batch norm on `inputs` and registers moving-average updates.

  Args:
    inputs: Tensor handed to `nn.fused_batch_norm`.
    training: Predicate given to `utils.smart_cond`; picks between the
      batch-statistics branch and the moving-statistics branch.

  Returns:
    The normalized output produced by the selected fused batch norm branch.
  """
  # TODO(reedwm): Add support for fp16 inputs.
  beta = self.beta if self.center else self._beta_const
  gamma = self.gamma if self.scale else self._gamma_const

  def _train_branch():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        epsilon=self.epsilon,
        data_format=self._data_format)

  def _inference_branch():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        mean=self.moving_mean,
        variance=self.moving_variance,
        epsilon=self.epsilon,
        is_training=False,
        data_format=self._data_format)

  output, mean, variance = utils.smart_cond(
      training, _train_branch, _inference_branch)

  if not self._bessels_correction_test_only:
    # The fused op reports a variance that includes Bessel's correction;
    # undo it so the value agrees with the non-fused batch norm.
    sample_size = math_ops.cast(
        array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
    one = math_ops.cast(1.0, variance.dtype)
    factor = (sample_size - one) / sample_size
    variance = variance * factor

  training_value = utils.constant_value(training)
  if training_value is not None:
    one_minus_decay = ops.convert_to_tensor(self._one_minus_decay)
  else:
    # `training` is only known at run time: use a zero update rate on the
    # inference side.
    one_minus_decay = _smart_select(
        training, lambda: self._one_minus_decay, lambda: 0.)

  if training_value is None or training_value:
    mean_update = self._assign_moving_average(
        self.moving_mean, mean, one_minus_decay)
    variance_update = self._assign_moving_average(
        self.moving_variance, variance, one_minus_decay)
    if context.in_graph_mode():
      # In Eager mode the updates already ran inside
      # assign_moving_averages, so only graph mode needs them collected.
      self.add_update(mean_update, inputs=inputs)
      self.add_update(variance_update, inputs=inputs)

  return output
def _drop_attention_logits(self, logits: tf.Tensor, pad_mask: tf.Tensor, training: tf.Tensor) -> tf.Tensor:
  """Randomly suppresses attention logits while training.

  A uniform random value plus `pad_mask` is drawn per logit; positions
  where it falls below `self.attention_dropout_rate` get -1e9 added.

  Args:
    logits: Attention logits.
    pad_mask: Tensor added to the random draw before thresholding.
    training: Predicate for `smart_cond`; when false an identity of
      `logits` is returned.

  Returns:
    The logits with dropped positions pushed to a large negative value, or
    an identity of `logits`.
  """

  def _with_dropped_positions() -> tf.Tensor:
    draw = tf.random.uniform(tf.shape(logits), 0, 1) + pad_mask
    is_dropped = tf.less(draw, self.attention_dropout_rate)
    drop_mask = tf.cast(is_dropped, logits.dtype)
    return logits + drop_mask * -1e9

  def _unchanged() -> tf.Tensor:
    return tf.identity(logits)

  return smart_cond(training, _with_dropped_positions, _unchanged)
def testConstantValue(self):
  """Checks `utils.constant_value` on `smart_cond` results and bad input."""

  def f1():
    return constant_op.constant(5)

  def f2():
    return constant_op.constant(32)

  # Boolean preds first, then integer preds.
  for pred, expected in ((True, 5), (False, 32), (1, 5), (0, 32)):
    result = utils.smart_cond(pred, f1, f2)
    self.assertEqual(expected, utils.constant_value(result))

  # Unknown pred
  pred = array_ops.placeholder_with_default(True, shape=())
  self.assertIsNone(utils.constant_value(utils.smart_cond(pred, f1, f2)))

  # Error case
  with self.assertRaises(TypeError):
    utils.constant_value(5)
def call(self, inputs, training=False):
  """Masks `inputs` to a random number of top entries while training.

  A random `k` below the size of the last dimension is drawn, the indices
  of the top `k` entries are turned into a 0/1 mask, and the mask is
  multiplied into `inputs` on the training branch.

  Args:
    inputs: Tensor whose last dimension has a statically known size.
    training: Predicate for `utils.smart_cond`; when false an identity of
      `inputs` is returned.

  Returns:
    `mask * inputs`, or an identity of `inputs`.
  """
  input_dim = inputs.get_shape()[-1].value
  k = random_ops.random_uniform(
      [1], maxval=input_dim, dtype=dtypes.int32)[0]
  _, top_indices = nn_ops.top_k(inputs, k, sorted=False)

  # One-hot each selected index, then sum over the k axis to merge them.
  one_hot_rows = array_ops.one_hot(top_indices, input_dim, axis=-1)
  mask = math_ops.reduce_sum(one_hot_rows, axis=-2)

  def _masked():
    return mask * inputs

  def _unchanged():
    return array_ops.identity(inputs)

  return utils.smart_cond(training, _masked, _unchanged)
def call(self, inputs, training=False):
  """Adds the l0 loss and returns the training or prediction masked layer.

  Args:
    inputs: Tensor passed to the l0 regularizer op.
    training: Predicate for `utils.smart_cond` selecting the training
      tensor over the prediction tensor.

  Returns:
    `inputs` unchanged when `l0_regularizer` yields no op; otherwise one of
    the two tensors returned by `get_l0_maskeds`.
  """
  l0_op = l0_regularizer(self.reg_const, seed=self.seed)
  if l0_op is None:
    # No regularizer is configured: pass the inputs straight through.
    return inputs

  self.add_loss(l0_op(inputs))
  l0_regularizer_scope = self.scope_name + '/l0_regularizer/'
  layer_trng, layer_pred = get_l0_maskeds(l0_regularizer_scope)
  return utils.smart_cond(
      training, lambda: layer_trng, lambda: layer_pred)
def dropout_selu(x, rate, alpha=-1.7580993408473766, noise_shape=None,
                 seed=None, name=None, training=False):
  """Dropout to a value with rescaling.

  Dropped elements are set to `alpha` instead of zero, and the result is
  passed through an affine map `a * ret + b` computed from the keep
  probability and `alpha`.

  Args:
    x: Input, converted to a tensor.
    rate: Drop rate; the keep probability is `1.0 - rate`.
    alpha: Scalar value assigned to dropped elements.
    noise_shape: Shape of the random keep/drop mask; defaults to the shape
      of `x`.
    seed: Seed forwarded to `random_ops.random_uniform`.
    name: Optional name for the enclosing name scope.
    training: Predicate for `utils.smart_cond`; when false an identity of
      `x` is returned.

  Returns:
    The dropped-out and rescaled tensor, or an identity of `x`.

  Raises:
    ValueError: If the keep probability is a real number outside (0, 1].
  """

  def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):
    keep_prob = 1.0 - rate
    x = ops.convert_to_tensor(x, name="x")
    if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
      raise ValueError(
          "keep_prob must be a scalar tensor or a float in the "
          "range (0, 1], got %g" % keep_prob)
    keep_prob = ops.convert_to_tensor(
        keep_prob, dtype=x.dtype, name="keep_prob")
    keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

    alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
    # Validate alpha itself; this check used to re-test keep_prob by mistake.
    alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar())

    # Do nothing if we know keep_prob == 1
    if tensor_util.constant_value(keep_prob) == 1:
      return x

    noise_shape = (noise_shape if noise_shape is not None
                   else array_ops.shape(x))
    # uniform [keep_prob, 1.0 + keep_prob)
    random_tensor = keep_prob
    random_tensor += random_ops.random_uniform(
        noise_shape, seed=seed, dtype=x.dtype)
    # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
    binary_tensor = math_ops.floor(random_tensor)
    ret = x * binary_tensor + alpha * (1 - binary_tensor)

    a = tf.sqrt(
        1.0 / (keep_prob + alpha * alpha * keep_prob * (1.0 - keep_prob)))
    b = -a * (1.0 - keep_prob) * alpha
    ret = a * ret + b
    ret.set_shape(x.get_shape())
    return ret

  with ops.name_scope(name, "dropout", [x]) as name:
    return utils.smart_cond(
        training,
        lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name),
        lambda: array_ops.identity(x))
def _fused_batch_norm_op(self, input_batch, mean, variance, use_batch_stats):
  """Builds the fused batch norm op for either batch or moving statistics.

  Args:
    input_batch: Tensor to normalize.
    mean: Mean tensor used by the moving-statistics branch.
    variance: Variance tensor used by the moving-statistics branch.
    use_batch_stats: Value converted to a tensor and used as the predicate
      selecting the batch-statistics branch.

  Returns:
    Tuple `(batch_norm_op, mean, variance)`; the returned statistics are
    reshaped to the static shapes of the `mean` / `variance` arguments.
  """
  # Store the original shape of the mean and variance.
  original_mean_shape = mean.get_shape()
  original_variance_shape = variance.get_shape()

  # The fused batch norm expects the mean, variance, gamma and beta
  # tensors to have dimension 1, so we flatten them to remove the
  # extra dimensions.
  flat_shape = (-1,)
  flat_gamma = tf.reshape(self._gamma, shape=flat_shape)
  flat_beta = tf.reshape(self._beta, shape=flat_shape)
  flat_mean = tf.reshape(mean, shape=flat_shape)
  flat_variance = tf.reshape(variance, shape=flat_shape)

  use_batch_stats = tf.convert_to_tensor(use_batch_stats)

  shared_kwargs = dict(
      scale=flat_gamma,
      offset=flat_beta,
      epsilon=self._eps,
      data_format=self._infer_fused_data_format(input_batch),
      name="batch_norm")

  def _with_batch_stats():
    return tf.nn.fused_batch_norm(
        input_batch, mean=None, variance=None, is_training=True,
        **shared_kwargs)

  def _with_moving_stats():
    return tf.nn.fused_batch_norm(
        input_batch, mean=flat_mean, variance=flat_variance,
        is_training=False, **shared_kwargs)

  batch_norm_op, mean, variance = utils.smart_cond(
      use_batch_stats, _with_batch_stats, _with_moving_stats)

  # Undo the flattening so callers see the shapes they passed in.
  mean = tf.reshape(mean, original_mean_shape)
  variance = tf.reshape(variance, original_variance_shape)
  return batch_norm_op, mean, variance
def dropout_selu(x, keep_prob, alpha=-1.7580993408473766,
                 fixedPointMean=0.0, fixedPointVar=1.0, noise_shape=None,
                 seed=None, name=None, training=False):
  """Dropout to a value with rescaling (dropout variant for SELU).

  NOTE(review): despite its name, `keep_prob` is forwarded to the
  implementation as the drop `rate`, so the probability of keeping an
  element is `1 - keep_prob`. This is left unchanged to preserve behavior
  for existing callers - confirm the intended meaning.

  Args:
    x: Input, converted to a tensor.
    keep_prob: Passed on as the drop rate (see the note above).
    alpha: Scalar value assigned to dropped elements.
    fixedPointMean: Mean used when computing the affine correction.
    fixedPointVar: Variance used when computing the affine correction.
    noise_shape: Shape of the random keep/drop mask; defaults to the shape
      of `x`.
    seed: Seed forwarded to `random_ops.random_uniform`.
    name: Optional name for the enclosing name scope.
    training: Predicate for `utils.smart_cond`; when false an identity of
      `x` is returned.

  Returns:
    The dropped-out and rescaled tensor, or an identity of `x`.

  Raises:
    ValueError: If the effective keep probability is a real number outside
      (0, 1].
  """

  def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):
    # Keep probability derived from the drop rate.
    keep_prob = 1.0 - rate
    x = ops.convert_to_tensor(x, name="x")
    # Reject keep probabilities outside the valid range.
    if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
      raise ValueError(
          "keep_prob must be a scalar tensor or a float in the range (0, 1], got %g"
          % keep_prob)
    keep_prob = ops.convert_to_tensor(
        keep_prob, dtype=x.dtype, name="keep_prob")
    keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

    alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
    # Validate alpha itself; this check used to re-test keep_prob by mistake.
    alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar())

    # Return the input untouched when keep_prob is statically 1.
    if tensor_util.constant_value(keep_prob) == 1:
      return x

    # Fall back to the shape of x when no noise_shape is given.
    noise_shape = (noise_shape if noise_shape is not None
                   else array_ops.shape(x))
    # keep_prob plus a uniform random value ...
    random_tensor = keep_prob
    random_tensor += random_ops.random_uniform(
        noise_shape, seed=seed, dtype=x.dtype)
    # ... floored: the largest integer not greater than the value.
    binary_tensor = math_ops.floor(random_tensor)
    ret = x * binary_tensor + alpha * (1 - binary_tensor)

    a = tf.sqrt(fixedPointVar / (keep_prob * (
        (1 - keep_prob) * tf.pow(alpha - fixedPointMean, 2)
        + fixedPointVar)))
    b = fixedPointMean - a * (
        keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
    ret = a * ret + b
    ret.set_shape(x.get_shape())
    return ret

  with ops.name_scope(name, "dropout", [x]) as name:
    return utils.smart_cond(
        training,
        lambda: dropout_selu_impl(
            x, keep_prob, alpha, noise_shape, seed, name),
        lambda: array_ops.identity(x))
def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
                                transition_params):
    """Computes the unnormalized score of all tag sequences matching tag_bitmap.

    `tag_bitmap` allows several tags to count as correct at one time step,
    which is useful when an observation is consistent with more than one tag.
    With one-hot vectors in `tag_bitmap` the result is identical to
    `crf_sequence_score`.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
            potentials to use as input to the CRF layer.
        tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
            marking the active tags at each index.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
        sequence_scores: A [batch_size] vector of unnormalized sequence
        scores.
    """

    def _mask_inactive_tags():
        # Inactive tags get -inf so they vanish from any logsumexp.
        neg_inf = array_ops.fill(array_ops.shape(inputs), float("-inf"))
        return array_ops.where(tag_bitmap, inputs, neg_inf)

    def _length_one_score():
        # max_seq_len == 1: reduce directly over the unary potentials.
        return math_ops.reduce_logsumexp(
            _mask_inactive_tags(), axis=[1, 2], keepdims=False)

    def _general_score():
        # Logsumexp over every sequence consistent with the active tags.
        return crf_log_norm(
            inputs=_mask_inactive_tags(),
            sequence_lengths=sequence_lengths,
            transition_params=transition_params)

    # Use the static time dimension when known, else the dynamic one.
    max_seq_len = (tensor_shape.dimension_value(inputs.shape[1])
                   or array_ops.shape(inputs)[1])
    return utils.smart_cond(
        pred=math_ops.equal(max_seq_len, 1),
        true_fn=_length_one_score,
        false_fn=_general_score)
def _build_update_ops(self, mean, variance, is_training):
    """Builds the moving average update ops when using moving variance.

    Args:
        mean: The mean value to update with.
        variance: The variance value to update with.
        is_training: Boolean Tensor to indicate if we're currently in
            training mode.

    Returns:
        Tuple of `(update_mean_op, update_variance_op)` when `is_training` is
        or could be `True`. Returns `None` when `is_training=False`.
    """
    # When `is_training` is statically known to be false, build nothing.
    static_is_training = utils.constant_value(is_training)
    if static_is_training is not None and not static_is_training:
        return None

    def _ema_update_op(variable, value, op_name):
        """Returns the op that folds `value` into `variable`."""
        return moving_averages.assign_moving_average(
            variable=variable,
            value=value,
            decay=self._decay_rate,
            zero_debias=False,
            name=op_name,
        ).op

    def _training_branch():
        mean_op = _ema_update_op(
            self._moving_mean, mean, "update_moving_mean")
        variance_op = _ema_update_op(
            self._moving_variance, variance, "update_moving_variance")
        return mean_op, variance_op

    def _inference_branch():
        return (tf.no_op(), tf.no_op())

    mean_op, variance_op = utils.smart_cond(
        is_training, _training_branch, _inference_branch)
    return (mean_op, variance_op)
def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
                                transition_params):
    """Computes the unnormalized score of all tag sequences matching tag_bitmap.

    Several tags may be marked as correct at a single time step through
    `tag_bitmap`, so the score accounts for every tag consistent with the
    observation. One-hot rows in `tag_bitmap` give the same result as
    `crf_sequence_score`.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
            potentials to use as input to the CRF layer.
        tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
            representing all active tags at each index.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
        sequence_scores: A [batch_size] vector of unnormalized sequence
        scores.
    """

    def _active_potentials():
        # Replace the potentials of inactive tags with -inf.
        minus_infinity = array_ops.fill(
            array_ops.shape(inputs), float("-inf"))
        return array_ops.where(tag_bitmap, inputs, minus_infinity)

    def _score_single_step():
        # With one time step there are no transitions to account for.
        masked = _active_potentials()
        return math_ops.reduce_logsumexp(masked, axis=[1, 2], keepdims=False)

    def _score_multi_step():
        masked = _active_potentials()
        return crf_log_norm(inputs=masked,
                            sequence_lengths=sequence_lengths,
                            transition_params=transition_params)

    time_steps = (tensor_shape.dimension_value(inputs.shape[1])
                  or array_ops.shape(inputs)[1])
    has_single_step = math_ops.equal(time_steps, 1)
    return utils.smart_cond(pred=has_single_step,
                            true_fn=_score_single_step,
                            false_fn=_score_multi_step)
def _build_update_ops(self, mean, variance, is_training):
    """Builds the moving average update ops when using moving variance.

    Args:
        mean: The mean value to update with.
        variance: The variance value to update with.
        is_training: Boolean Tensor to indicate if we're currently in
            training mode.

    Returns:
        Tuple of `(update_mean_op, update_variance_op)` when `is_training` is
        or could be `True`. Returns `None` when `is_training=False`.
    """

    def _make_updates():
        """Creates one moving-average assignment op per statistic."""
        targets = (
            (self._moving_mean, mean, "update_moving_mean"),
            (self._moving_variance, variance, "update_moving_variance"),
        )
        update_ops = [
            moving_averages.assign_moving_average(
                variable=variable,
                value=value,
                decay=self._decay_rate,
                zero_debias=False,
                name=op_name).op
            for variable, value, op_name in targets
        ]
        return update_ops[0], update_ops[1]

    def _make_no_ops():
        return (tf.no_op(), tf.no_op())

    # Ops are only needed if training is enabled or cannot be ruled out.
    known_value = utils.constant_value(is_training)
    may_train = known_value is None or known_value
    if not may_train:
        return None

    update_mean_op, update_variance_op = utils.smart_cond(
        is_training, _make_updates, _make_no_ops)
    return (update_mean_op, update_variance_op)
def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm.

    Besides computing the normalized output, this registers moving-average
    updates for `moving_mean` / `moving_variance` unless `training` is
    statically known to be False.
    """
    offset = self.beta if self.center else self._beta_const
    scale = self.gamma if self.scale else self._gamma_const

    def _training_branch():
        return nn.fused_batch_norm(
            inputs, scale, offset,
            epsilon=self.epsilon,
            data_format=self._data_format)

    def _inference_branch():
        return nn.fused_batch_norm(
            inputs, scale, offset,
            mean=self.moving_mean,
            variance=self.moving_variance,
            epsilon=self.epsilon,
            is_training=False,
            data_format=self._data_format)

    output, mean, variance = utils.smart_cond(
        training, _training_branch, _inference_branch)

    if not self._bessels_correction_test_only:
        # The fused kernel reports variance with Bessel's correction; scale by
        # (n - 1) / n to stay consistent with non-fused batch norm.
        sample_size = math_ops.cast(
            array_ops.size(inputs) / array_ops.size(variance),
            variance.dtype)
        one = math_ops.cast(1.0, variance.dtype)
        variance = variance * ((sample_size - one) / sample_size)

    if utils.constant_value(training) is False:
        return output

    # A decay of 1 leaves the moving averages untouched when not training.
    decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
    mean_update = moving_averages.assign_moving_average(
        self.moving_mean, mean, decay, zero_debias=False)
    variance_update = moving_averages.assign_moving_average(
        self.moving_variance, variance, decay, zero_debias=False)
    for update in (mean_update, variance_update):
        self.add_update(update, inputs=inputs)

    return output
def saltpepper_noise(x, noise_rate, ones_rate=0.5, training=False, name=None):
    """Adds saltpepper noise (sets some elements to either 0 or 1).

    Each element is replaced with probability `noise_rate`; a replaced
    element becomes 1 with probability `ones_rate` and 0 otherwise.

    ones_rate == 0 => dropout (noise value is always 0)
    ones_rate == 1 => dropin (noise value is always 1)

    Noise is only applied when `training` is true; otherwise `x` is passed
    through an identity op.
    """

    def _corrupted():
        assert 0 <= noise_rate <= 1
        assert 0 <= ones_rate <= 1
        static_shape = x.get_shape()
        # floor(U[0, 1) + ones_rate) is 1 with probability ones_rate, else 0.
        noise_values = tf.floor(
            tf.random_uniform(static_shape, 0, 1) + ones_rate)
        # Elements whose independent draw falls below noise_rate are replaced.
        replace_mask = tf.random_uniform(static_shape, 0, 1) < noise_rate
        return tf.where(replace_mask, noise_values, x)

    def _unchanged():
        return array_ops.identity(x)

    with ops.name_scope(name, "inputnoise", [x]):
        return utils.smart_cond(training, _corrupted, _unchanged)
def _get_examples(file_name_queue, reader, num_threads, read_batch_size,
                  filter_fn, parse_fn):
    """Builds one read (and optional filter/parse) pipeline per thread.

    Args:
        file_name_queue: A queue implementation that dequeues elements in
            first-in first-out order.
        reader: A function or class that returns an object with `read`
            method, (filename tensor) -> (example tensor).
        num_threads: The number of threads enqueuing examples.
        read_batch_size: An int or scalar `Tensor` specifying the number of
            records to read at once.
        filter_fn: Filtering function, takes both keys as well as an
            `Example` Tensors and returns a boolean mask of the same shape as
            the input Tensors to be applied for filtering. If `None`, no
            filtering is done.
        parse_fn: Parsing function, takes `Example` Tensor returns parsed
            representation. If `None`, no parsing is done.

    Returns:
        A list with `num_threads` entries. Each entry is the parsed dict with
        the keys stored under `KEY_FEATURE_NAME`, or a `(keys, examples)`
        tuple when the parsed result is not a dict or no parsing is done.
    """

    def _read_one():
        keys, protos = utils.smart_cond(
            read_batch_size > 1,
            lambda: reader().read_up_to(file_name_queue, read_batch_size),
            lambda: reader().read(file_name_queue))

        if filter_fn:
            keep = filter_fn(keys, protos)
            keys = array_ops.boolean_mask(keys, keep)
            protos = array_ops.boolean_mask(protos, keep)

        if not parse_fn:
            return (keys, protos)

        parsed = parse_fn(protos)
        if isinstance(parsed, dict):
            # Map keys into example map because batch_join doesn't support
            # tuple of Tensor + dict.
            parsed[KEY_FEATURE_NAME] = keys
            return parsed
        return (keys, parsed)

    with ops.name_scope('read'):
        return [_read_one() for _ in range(num_threads)]
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
    """Computes the unnormalized score for a tag sequence.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
            potentials to use as input to the CRF layer.
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices for
            which we compute the unnormalized score.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
        sequence_scores: A [batch_size] vector of unnormalized sequence
        scores.
    """

    def _score_length_one():
        # max_seq_len == 1: just gather the unary potential of the single tag.
        index_dtype = tag_indices.dtype
        num_examples = array_ops.shape(inputs, out_type=index_dtype)[0]
        row_ids = math_ops.range(num_examples, dtype=index_dtype)
        row_ids = array_ops.reshape(row_ids, [-1, 1])
        unary = array_ops.squeeze(inputs, [1])
        gather_index = array_ops.concat([row_ids, tag_indices], axis=1)
        gathered = array_ops.gather_nd(unary, gather_index)
        # Examples whose length is <= 0 contribute a score of zero.
        is_empty = math_ops.less_equal(sequence_lengths, 0)
        return array_ops.where(
            is_empty, array_ops.zeros_like(gathered), gathered)

    def _score_general():
        # Unary plus transition (binary) scores of the given tag sequence.
        unary = crf_unary_score(tag_indices, sequence_lengths, inputs)
        binary = crf_binary_score(
            tag_indices, sequence_lengths, transition_params)
        return unary + binary

    max_seq_len = (tensor_shape.dimension_value(inputs.shape[1])
                   or array_ops.shape(inputs)[1])
    return utils.smart_cond(
        pred=math_ops.equal(max_seq_len, 1),
        true_fn=_score_length_one,
        false_fn=_score_general)
def _distort_color(image, batch_position=0, distort_color_in_yiq=False,
                   scope=None):
    """Distort the color of the image.

    Two orderings of the same random ops are available; even
    `batch_position` values use the first and odd values the second. The
    result is clipped to [0, 1].
    """

    def _brightness(img):
        return tf.image.random_brightness(img, max_delta=32. / 255.)

    def _contrast(img):
        return tf.image.random_contrast(img, lower=0.5, upper=1.5)

    def _saturation_and_hue(img):
        # Either one combined op in YIQ space or separate saturation/hue ops.
        if distort_color_in_yiq:
            return distort_image_ops.random_hsv_in_yiq(
                img, lower_saturation=0.5, upper_saturation=1.5,
                max_delta_hue=0.2 * math.pi)
        img = tf.image.random_saturation(img, lower=0.5, upper=1.5)
        return tf.image.random_hue(img, max_delta=0.2)

    def _variant_0(img=image):
        """Brightness, then saturation/hue, then contrast."""
        return _contrast(_saturation_and_hue(_brightness(img)))

    def _variant_1(img=image):
        """Brightness, then contrast, then saturation/hue."""
        return _saturation_and_hue(_contrast(_brightness(img)))

    with tf.name_scope(scope or 'distort_color'):
        use_variant_0 = batch_position % 2 == 0
        distorted = utils.smart_cond(use_variant_0, _variant_0, _variant_1)
        # The random_* ops do not necessarily clamp.
        return tf.clip_by_value(distorted, 0.0, 1.0)
def summary_writer_function(name, tensor, function, family=None):
    """Helper function to write summaries.

    Args:
        name: name of the summary
        tensor: main tensor to form the summary
        function: function taking a tag and a scope which writes the summary
        family: optional, the summary's family

    Returns:
        The result of writing the summary.
    """

    def _write():
        # Open the summary scope, hand its tag/scope to the writer callback.
        summary_ctx = summary_op_util.summary_scope(
            name, family, values=[tensor])
        with summary_ctx as (tag, scope):
            function(tag, scope)
            return True

    recording = should_record_summaries()
    return utils.smart_cond(recording, _write, _nothing, name="")
def dropout_selu(x, rate, alpha=-1.7580993408473766, fixedPointMean=0.0,
                 fixedPointVar=1.0, noise_shape=None, seed=None, name=None,
                 training=False):
    """Dropout to a value with rescaling.

    When `training` is true, each element of `x` is kept with probability
    `1 - rate` and otherwise replaced by `alpha`; the result is then passed
    through the affine map `a * ret + b` computed from `fixedPointMean` and
    `fixedPointVar`. When `training` is false, `x` is returned through an
    identity op.

    Args:
        x: Tensor (or value convertible to one) to apply dropout to.
        rate: Probability of dropping an element.
        alpha: Scalar value that dropped elements are set to.
        fixedPointMean: Mean used when computing the affine correction.
        fixedPointVar: Variance used when computing the affine correction.
        noise_shape: Shape of the random keep/drop mask; defaults to the
            dynamic shape of `x`.
        seed: Seed passed to the random op.
        name: Optional name for the enclosing name scope.
        training: Bool (or boolean tensor); dropout is applied only when true.

    Returns:
        A tensor with the same static shape as `x`.

    Raises:
        ValueError: If `1 - rate` is a Python number outside `(0, 1]`, or if
            the keep probability or `alpha` is not scalar-compatible.
    """
    # Function-scope imports keep this helper self-contained.
    from tensorflow.python.framework import ops, tensor_shape, tensor_util
    from tensorflow.python.ops import array_ops, random_ops, math_ops
    from tensorflow.python.layers import utils
    import numbers
    import tensorflow as tf

    def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):
        keep_prob = 1.0 - rate
        x = ops.convert_to_tensor(x, name="x")
        # Reject statically-known probabilities outside the valid range.
        if isinstance(keep_prob, numbers.Real) and not 0. < keep_prob <= 1.:
            raise ValueError("keep_prob must be a scalar tensor or a float in the "
                             "range (0, 1], got %g" % keep_prob)
        keep_prob = ops.convert_to_tensor(
            keep_prob, dtype=x.dtype, name="keep_prob")
        keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

        alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
        # Validate `alpha` itself; the original re-checked `keep_prob` here.
        alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar())

        # Nothing is dropped when the keep probability is statically 1.
        if tensor_util.constant_value(keep_prob) == 1:
            return x

        noise_shape = (noise_shape if noise_shape is not None
                       else array_ops.shape(x))
        # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else 0.
        random_tensor = keep_prob
        random_tensor += random_ops.random_uniform(
            noise_shape, seed=seed, dtype=x.dtype)
        binary_tensor = math_ops.floor(random_tensor)
        # Kept elements pass through; dropped elements become alpha.
        ret = x * binary_tensor + alpha * (1 - binary_tensor)

        # Affine correction built from the requested fixed-point statistics.
        a = tf.sqrt(fixedPointVar / (keep_prob * (
            (1 - keep_prob) * tf.pow(alpha - fixedPointMean, 2)
            + fixedPointVar)))
        b = fixedPointMean - a * (
            keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
        ret = a * ret + b
        ret.set_shape(x.get_shape())
        return ret

    with ops.name_scope(name, "dropout", [x]) as name:
        return utils.smart_cond(
            training,
            lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name),
            lambda: array_ops.identity(x))
def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm.

    Moving-average updates for `moving_mean` / `moving_variance` are
    registered unless `training` is statically known to be False.
    """
    offset = self.beta if self.center else self._beta_const
    scale = self.gamma if self.scale else self._gamma_const

    def _with_batch_statistics():
        return nn.fused_batch_norm(
            inputs, scale, offset,
            epsilon=self.epsilon,
            data_format=self._data_format)

    def _with_moving_statistics():
        return nn.fused_batch_norm(
            inputs, scale, offset,
            mean=self.moving_mean,
            variance=self.moving_variance,
            epsilon=self.epsilon,
            is_training=False,
            data_format=self._data_format)

    output, mean, variance = utils.smart_cond(
        training, _with_batch_statistics, _with_moving_statistics)

    if utils.constant_value(training) is False:
        return output

    # A decay of 1 leaves the moving averages untouched when not training.
    decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
    mean_update = moving_averages.assign_moving_average(
        self.moving_mean, mean, decay, zero_debias=False)
    variance_update = moving_averages.assign_moving_average(
        self.moving_variance, variance, decay, zero_debias=False)
    for update in (mean_update, variance_update):
        self.add_update(update, inputs=inputs)

    return output
def _update_renorm_variable(var, weight, value):
    """Updates a moving average and weight, returns the unbiased value.

    When `training` selects the update branch, both `var` and `weight` are
    advanced with `self.renorm_momentum` and their ratio is returned;
    otherwise `var` is returned unchanged through an identity op.

    NOTE(review): `self` and `training` are not parameters of this function;
    presumably it is a closure nested inside a method that defines them --
    confirm against the enclosing scope.

    Args:
        var: Variable holding the (biased) moving average.
        weight: Variable holding the accumulated weight used for debiasing.
        value: New value to fold into the moving average.

    Returns:
        `new_var / new_weight` from the update branch, or `var` otherwise.
    """
    # Work on an identity of `value` so the control dependency below has a
    # dedicated op to attach to.
    value = array_ops.identity(value)

    def _do_update():
        """Updates the var and weight, returns their updated ratio."""
        # Update the variables without zero debiasing. The debiasing will be
        # accomplished by dividing the exponential moving average by the weight.
        # For example, after a single update, the moving average would be
        # (1-decay) * value. and the weight will be 1-decay, with their ratio
        # giving the value.
        # Make sure the weight is not updated until before r and d computation.
        # NOTE(review): only the constant is created under the control
        # dependency here; confirm the two assignments below are meant to sit
        # outside that scope.
        with ops.control_dependencies([value]):
            weight_value = array_ops.constant(1., dtype=weight.dtype)
        new_var = self._assign_moving_average(var, value, self.renorm_momentum)
        new_weight = self._assign_moving_average(weight, weight_value,
                                                 self.renorm_momentum)
        # TODO(yuefengz): the updates to var and weighted can not be batched
        # together if we fetch their updated values here. Consider calculating
        # new values and delaying the updates.
        return new_var / new_weight

    def _fake_update():
        # No update: expose the current value of `var` as a tensor.
        return array_ops.identity(var)

    return utils.smart_cond(training, _do_update, _fake_update)
def crf_decode(potentials, transition_params, sequence_length):
    """Decode the highest scoring sequence of tags in TensorFlow.

    This variant operates on tensors and builds graph ops. The result is
    chosen between a shortcut for `max_seq_len == 1` and a forward/backward
    pass over two `dynamic_rnn` calls.

    Args:
        potentials: A [batch_size, max_seq_len, num_tags] tensor of
            unary potentials.
        transition_params: A [num_tags, num_tags] matrix of
            binary potentials.
        sequence_length: A [batch_size] vector of true sequence lengths.

    Returns:
        decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
            Contains the highest scoring tag indices.
        best_score: A [batch_size] vector, containing the score of `decode_tags`.
    """
    # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
    # and the max activation.
    def _single_seq_fn():
        # Drop the length-1 time axis, then take argmax / max over the tags.
        squeezed_potentials = array_ops.squeeze(potentials, [1])
        decode_tags = array_ops.expand_dims(
            math_ops.argmax(squeezed_potentials, axis=1), 1)
        best_score = math_ops.reduce_max(squeezed_potentials, axis=1)
        return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score

    def _multi_seq_fn():
        """Decoding of highest scoring sequence."""

        # For simplicity, in shape comments, denote:
        # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
        # NOTE(review): `.value` assumes the tag dimension is statically known
        # and that shape entries are Dimension objects -- confirm for the
        # TensorFlow version in use.
        num_tags = potentials.get_shape()[2].value

        # Computes forward decoding. Get last score and backpointers.
        crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
        # The first time step seeds the forward state; the rest is RNN input.
        initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
        initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
        inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
        # Sequence length is not allowed to be less than zero.
        sequence_length_less_one = math_ops.maximum(
            constant_op.constant(0, dtype=sequence_length.dtype),
            sequence_length - 1)
        backpointers, last_score = rnn.dynamic_rnn(  # [B, T - 1, O], [B, O]
            crf_fwd_cell,
            inputs=inputs,
            sequence_length=sequence_length_less_one,
            initial_state=initial_state,
            time_major=False,
            dtype=dtypes.int32)
        # Reverse the valid prefix so the backward pass can walk it forwards.
        backpointers = gen_array_ops.reverse_sequence(  # [B, T - 1, O]
            backpointers, sequence_length_less_one, seq_dim=1)

        # Computes backward decoding. Extract tag indices from backpointers.
        crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
        initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),  # [B]
                                      dtype=dtypes.int32)
        initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
        decode_tags, _ = rnn.dynamic_rnn(  # [B, T - 1, 1]
            crf_bwd_cell,
            inputs=backpointers,
            sequence_length=sequence_length_less_one,
            initial_state=initial_state,
            time_major=False,
            dtype=dtypes.int32)
        decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T - 1]
        # Prepend the final tag, then undo the reversal over the full length.
        decode_tags = array_ops.concat([initial_state, decode_tags],  # [B, T]
                                       axis=1)
        decode_tags = gen_array_ops.reverse_sequence(  # [B, T]
            decode_tags, sequence_length, seq_dim=1)

        best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
        return decode_tags, best_score

    # Prefer the static time dimension; fall back to the dynamic shape.
    return utils.smart_cond(
        pred=math_ops.equal(potentials.shape[1].value or
                            array_ops.shape(potentials)[1], 1),
        true_fn=_single_seq_fn,
        false_fn=_multi_seq_fn)
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
                                 fused_batch_norm):
    """Computes batch norm correction params.

    Before batch normalization is frozen:
        We use batch statistics for batch norm.
        correction_scale = sigma_b/sigma_mv
        correction_recip = 1/correction_scale
        correction_offset = 0

    After batch normalization is frozen:
        correction_scale = sigma_b/sigma_mv
        correction_recip = 1
        correction_offset = gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

    Batch norm is frozen if global_step > bn_freeze_delay.
    The corrections ensure that:
    a) The weights are quantized after scaling by gamma/sigma_mv. This enables
       smoother training as the scaling on the weights changes slowly, rather
       than jump across mini-batches
    b) Changing the values of the corrections allows for one to switch between
       using batch statistics to using moving mean and average, without
       requiring changes to batch_norm

    As a side effect, the consumers of the batch norm decay tensor(s) in
    `match` are rerouted so that the decay becomes 0.0 once batch norm is
    frozen.

    Args:
        context: The scope under which we look for batch norm params
        match: Object containing required batch norm tensors for correction
            computation.
        freeze_batch_norm_delay: Delay in steps at which computation switches
            from regular batch norm to frozen mean and variance. `None` means
            batch norm is never frozen.
        fused_batch_norm: Bool, true if fused batch norm is used.

    Returns:
        A tuple of correction_scale, correction_recip, correction_offset
    """
    g = ops.get_default_graph()
    prefix = '' if not context else context + '/'
    with g.name_scope(prefix + 'batch_norm_correction'):
        recip_sigma_mv = math_ops.rsqrt(
            match.moving_variance_tensor + match.batch_epsilon)
        recip_sigma = math_ops.rsqrt(
            match.variance_tensor + match.batch_epsilon)
        # rsqrt(var_mv) / rsqrt(var_b) == sigma_b / sigma_mv.
        correction_scale = math_ops.divide(
            recip_sigma_mv, recip_sigma, name='scale_compute')
        correction_scale = array_ops.identity(
            correction_scale, name='correction_scale')
        correction_recip = math_ops.reciprocal(
            correction_scale, name='reciprocal_compute')
        correction_offset = math_ops.multiply(
            match.gamma_tensor,
            match.mean_tensor * recip_sigma -
            match.moving_mean_tensor * recip_sigma_mv,
            name='offset_compute')

        if freeze_batch_norm_delay is not None:
            use_mv_avg = math_ops.greater_equal(
                common.CreateOrGetQuantizationStep(),
                freeze_batch_norm_delay,
                name='use_moving_average')
        else:
            use_mv_avg = False

        bn_decay_zero = 0.0
        # Snapshot the consumers before the cond below adds itself as one, so
        # only the pre-existing consumers get rerouted.
        bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())

        bn_decay_mean_out = utils.smart_cond(
            use_mv_avg,
            lambda: bn_decay_zero,
            lambda: match.bn_decay_mean_tensor,
            name='freeze_moving_mean')
        graph_editor.reroute_ts(
            [bn_decay_mean_out], [match.bn_decay_mean_tensor],
            can_modify=bn_decay_mean_consumers)

        # The variance decay tensor is only rerouted in the non-fused case.
        if fused_batch_norm is False:
            bn_decay_var_consumers = list(
                match.bn_decay_var_tensor.consumers())
            bn_decay_var_out = utils.smart_cond(
                use_mv_avg,
                lambda: bn_decay_zero,
                lambda: match.bn_decay_var_tensor,
                name='freeze_moving_var')
            graph_editor.reroute_ts(
                [bn_decay_var_out], [match.bn_decay_var_tensor],
                can_modify=bn_decay_var_consumers)

        # Frozen: recip becomes 1 and the offset is applied. Not frozen: the
        # reciprocal is kept and the offset is zero.
        correction_recip = utils.smart_cond(
            use_mv_avg,
            lambda: array_ops.ones(correction_scale.shape),
            lambda: correction_recip,
            name='correction_recip')

        correction_offset = utils.smart_cond(
            use_mv_avg,
            lambda: correction_offset,
            lambda: array_ops.zeros(correction_offset.shape),
            name='correction_offset')
    return correction_scale, correction_recip, correction_offset
def resize_method_2():
    """Returns lookup(2) when the method index is 2, otherwise lookup(3)."""
    method_index = batch_position % len(resize_methods)
    return utils.smart_cond(
        method_index == 2,
        lambda: lookup(2),
        lambda: lookup(3))
def resize_method_1():
    """Returns lookup(1) when the method index is 1, else resize_method_2()."""
    method_index = batch_position % len(resize_methods)
    return utils.smart_cond(
        method_index == 1,
        lambda: lookup(1),
        resize_method_2)
def call(self, inputs, training=False):
    """Applies batch normalization to `inputs`.

    Delegates to `self._fused_batch_norm` when `self.fused` is set. Otherwise
    it computes moments over the non-normalized axes, optionally applies the
    `adjustment` and renorm corrections, registers moving-average updates,
    and normalizes with `nn.batch_normalization`.

    Args:
        inputs: Tensor to normalize.
        training: Python bool or boolean tensor. Its static value (if any)
            decides whether the training-time ops are built at all.

    Returns:
        The normalized tensor. With `virtual_batch_size` set, the result is
        reshaped back to the original (non-virtual) batch layout.
    """
    if self.virtual_batch_size is not None:
        # Virtual batches (aka ghost batches) can be simulated by reshaping the
        # Tensor and reusing the existing batch norm implementation
        original_shape = [-1] + inputs.shape.as_list()[1:]
        expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

        # Will cause errors if virtual_batch_size does not divide the batch size
        inputs = array_ops.reshape(inputs, expanded_shape)

        def undo_virtual_batching(outputs):
            outputs = array_ops.reshape(outputs, original_shape)
            return outputs

    if self.fused:
        outputs = self._fused_batch_norm(inputs, training=training)
        if self.virtual_batch_size is not None:
            # Currently never reaches here since fused_batch_norm does not support
            # virtual batching
            return undo_virtual_batching(outputs)
        return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.get_shape()
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
        del reduction_axes[1]  # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

    def _broadcast(v):
        # Reshape `v` to `broadcast_shape` unless it is None, already has
        # `ndims` dimensions, or the reduction covers every axis but the last.
        if (v is not None and
                len(v.get_shape()) != ndims and
                reduction_axes != list(range(ndims - 1))):
            return array_ops.reshape(v, broadcast_shape)
        return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
        # Composes x -> x * scale + offset with a following
        # y -> y * then_scale + then_offset; either follow-up may be None.
        if then_scale is not None:
            scale *= then_scale
            offset *= then_scale
        if then_offset is not None:
            offset += then_offset
        return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
        if self.adjustment:
            adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
            # Adjust only during training.
            adj_scale = utils.smart_cond(training,
                                         lambda: adj_scale,
                                         lambda: array_ops.ones_like(adj_scale))
            adj_bias = utils.smart_cond(training,
                                        lambda: adj_bias,
                                        lambda: array_ops.zeros_like(adj_bias))
            scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
        mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        # Batch statistics while training, moving statistics otherwise.
        mean = utils.smart_cond(training,
                                lambda: mean,
                                lambda: moving_mean)
        variance = utils.smart_cond(training,
                                    lambda: variance,
                                    lambda: moving_variance)

        if self.renorm:
            r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                mean, variance, training)
            # When training, the normalized values (say, x) will be transformed as
            # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
            # = x * (r * gamma) + (d * gamma + beta) with renorm.
            r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
            d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
            scale, offset = _compose_transforms(r, d, scale, offset)
        else:
            new_mean, new_variance = mean, variance

        if self.virtual_batch_size is not None:
            # This isn't strictly correct since in ghost batch norm, you are
            # supposed to sequentially update the moving_mean and moving_variance
            # with each sub-batch. However, since the moving statistics are only
            # used during evaluation, it is more efficient to just update in one
            # step and should not make a significant difference in the result.
            new_mean = math_ops.reduce_mean(new_mean,
                                            axis=1, keep_dims=True)
            new_variance = math_ops.reduce_mean(new_variance,
                                                axis=1, keep_dims=True)

        def _do_update(var, value):
            return moving_averages.assign_moving_average(
                var, value, self.momentum, zero_debias=False)

        # The moving averages are only assigned in the training branch; the
        # other branch returns the variable unchanged.
        mean_update = utils.smart_cond(
            training,
            lambda: _do_update(self.moving_mean, new_mean),
            lambda: self.moving_mean)
        variance_update = utils.smart_cond(
            training,
            lambda: _do_update(self.moving_variance, new_variance),
            lambda: self.moving_variance)
        # Updates are only registered on the layer in graph mode.
        if context.in_graph_mode():
            self.add_update(mean_update, inputs=inputs)
            self.add_update(variance_update, inputs=inputs)

    else:
        mean, variance = self.moving_mean, self.moving_variance

    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
        return undo_virtual_batching(outputs)

    return outputs
def _fused_batch_norm_op(self, input_batch, mean, variance, use_batch_stats):
    """Builds the fused batch-norm op, reshaping the input to 4-D if needed.

    Args:
        input_batch: Tensor to normalize.
        mean: Moving mean; used when `use_batch_stats` is false.
        variance: Moving variance; used when `use_batch_stats` is false.
        use_batch_stats: Bool (or boolean tensor) choosing between statistics
            computed from the batch and the supplied `mean`/`variance`.

    Returns:
        Tuple `(batch_norm_op, mean, variance)`. The output is reshaped back
        to the input's layout when a 4-D reshape was applied, and the
        statistics are reshaped to the shapes of the `mean`/`variance`
        arguments.
    """
    # Remember the incoming shapes so the returned statistics can match them.
    original_mean_shape = mean.get_shape()
    original_variance_shape = variance.get_shape()

    # The fused op takes rank-1 per-channel vectors for scale, offset, mean
    # and variance.
    per_channel = (self._num_channels,)
    scale_1d = tf.reshape(self._gamma, shape=per_channel)
    offset_1d = tf.reshape(self._beta, shape=per_channel)
    mean_1d = tf.reshape(mean, shape=per_channel)
    variance_1d = tf.reshape(variance, shape=per_channel)

    batch_stats_pred = tf.convert_to_tensor(use_batch_stats)

    static_input_shape = input_batch.get_shape()
    restored_shape = [-1] + static_input_shape.as_list()[1:]
    pixels_per_image = np.prod(self._image_shape, dtype=np.int32)

    # The fused op also needs a 4-D input: pass 4-D formats through and fold
    # every other layout into one spatial axis of size pixels_per_image.
    is_four_d_format = len(self._data_format) == 4
    if is_four_d_format:
        fused_format = self._data_format
        fused_input = input_batch
    elif self._channel_index == 1 and self._image_shape:
        fused_format = "NCHW"
        fused_input = tf.reshape(
            input_batch,
            shape=(-1, self._num_channels, 1, pixels_per_image))
    else:
        # The CPU implementation of FusedBatchNorm only supports NHWC tensor
        # format for now.
        fused_format = "NHWC"
        fused_input = tf.reshape(
            input_batch,
            shape=(-1, 1, pixels_per_image, self._num_channels))

    shared_kwargs = dict(
        scale=scale_1d,
        offset=offset_1d,
        epsilon=self._eps,
        data_format=fused_format,
        name="batch_norm",
    )

    def _normalize_with_batch_stats():
        return tf.nn.fused_batch_norm(
            fused_input,
            mean=None,
            variance=None,
            is_training=True,
            **shared_kwargs)

    def _normalize_with_moving_stats():
        return tf.nn.fused_batch_norm(
            fused_input,
            mean=mean_1d,
            variance=variance_1d,
            is_training=False,
            **shared_kwargs)

    normalized, used_mean, used_variance = utils.smart_cond(
        batch_stats_pred,
        _normalize_with_batch_stats,
        _normalize_with_moving_stats)

    # Undo the 4-D reshape applied above, if any.
    if len(self._data_format) != 4:
        normalized = tf.reshape(normalized, restored_shape)

    return (
        normalized,
        tf.reshape(used_mean, original_mean_shape),
        tf.reshape(used_variance, original_variance_shape),
    )
def bucket(tensors,
           which_bucket,
           batch_size,
           num_buckets,
           num_threads=1,
           capacity=32,
           bucket_capacities=None,
           shapes=None,
           dynamic_pad=False,
           allow_smaller_final_batch=False,
           keep_input=True,
           shared_name=None,
           name=None):
    """Lazy bucketing of input tensors according to `which_bucket`.

    The argument `tensors` can be a list or a dictionary of tensors.
    The value returned by the function will be of the same type
    as `tensors`.

    The tensors entering this function are put into the bucket given by
    `which_bucket`. Each bucket has its own queue. When a bucket contains
    `batch_size` elements, this minibatch is pushed onto a top queue. The
    tensors returned from this function are the result of dequeueing the
    next minibatch from this top queue.

    This function is implemented using several queues. A `QueueRunner` for the
    queues is added to the current `Graph`'s `QUEUE_RUNNER` collection.

    As the returned tensors are the result of a dequeue operation, evaluating
    them will throw a `tf.errors.OutOfRangeError` when the input queue is
    exhausted. If these tensors are feeding another input queue, its queue
    runner will catch this exception, however, if they are used in your main
    thread you are responsible for catching this yourself.

    *N.B.:* If `dynamic_pad` is `False`, you must ensure that either
    (i) the `shapes` argument is passed, or (ii) all of the tensors in
    `tensors` must have fully-defined shapes. `ValueError` will be
    raised if neither of these conditions holds.

    If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
    tensors is known, but individual dimensions may have shape `None`.
    In this case, for each enqueue the dimensions with value `None`
    may have a variable length; upon dequeue, the output tensors will be
    padded on the right to the maximum shape of the tensors in the current
    minibatch. For numbers, this padding takes value 0. For strings, this
    padding is the empty string. See `PaddingFIFOQueue` for more info.

    If `allow_smaller_final_batch` is `True`, a smaller batch value than
    `batch_size` is returned when the queues are closed and there are not
    enough elements to fill the batch, otherwise the pending elements are
    discarded. In addition, all output tensors' static shapes, as accessed via
    the `get_shape()` method will have a 0th `Dimension` value of `None`, and
    operations that depend on fixed batch_size would fail.

    Args:
      tensors: The list or dictionary of tensors, representing a single
        element, to bucket. Nested lists are not supported.
      which_bucket: An `int32` scalar Tensor taking a value in
        `[0, num_buckets)`.
      batch_size: The new batch size pulled from the queue (all queues will
        have the same size). If a list is passed in then each bucket will
        have a different batch_size.
        (python int, int32 scalar or iterable of integers of length
        num_buckets). A list or tuple argument is copied, never modified.
      num_buckets: A python integer, the number of buckets.
      num_threads: An integer. The number of threads enqueuing `tensors`.
      capacity: An integer. The maximum number of minibatches in the top
        queue, and also (by default) the maximum number of elements within
        each bucket.
      bucket_capacities: (Optional) None or a list of integers, the capacities
        of each bucket. If None, capacity is used (default). If specified, it
        must be a list of integers of length num_buckets: the i-th element is
        used as capacity for the i-th bucket queue.
      shapes: (Optional) The shapes for each example. Defaults to the
        inferred shapes for `tensors`.
      dynamic_pad: Boolean. Allow variable dimensions in input shapes.
        The given dimensions are padded upon dequeue so that tensors within a
        batch have the same shapes.
      allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the
        final batches to be smaller if there are insufficient items left in
        the queues.
      keep_input: A `bool` scalar Tensor. If provided, this tensor controls
        whether the input is added to the queue or not. If it evaluates
        `True`, then `tensors` are added to the bucket; otherwise they are
        dropped. This tensor essentially acts as a filtering mechanism.
      shared_name: (Optional). If set, the queues will be shared under the
        given name across multiple sessions.
      name: (Optional) A name for the operations.

    Returns:
      A tuple `(bucket, outputs)` where `bucket` is an `int32` scalar tensor
      and `outputs` is a list or dictionary of batched outputs corresponding
      to elements of `tensors`. Every step will receive a new bucket of
      outputs.

    Raises:
      ValueError: If the `shapes` are not specified, and cannot be
        inferred from the elements of `tensors` or if batch_size is a
        sequence but its length != num_buckets. Also if bucket_capacities is
        not None but its length != num_buckets.
    """
    batch_size_per_bucket = False
    if isinstance(batch_size, (list, tuple)):
        batch_size_per_bucket = True
        if len(batch_size) != num_buckets:
            raise ValueError(
                "If batch_size is a list it must have num_buckets elements")
        # Work on a private list: entries are overwritten by index below,
        # which would fail on a tuple and would otherwise silently modify the
        # caller's list.
        batch_size = list(batch_size)
    else:
        batch_size = [batch_size] * num_buckets

    if bucket_capacities is None:
        bucket_capacities = [capacity] * num_buckets
    if len(bucket_capacities) != num_buckets:
        raise ValueError(
            "The list bucket_capacities (%s) must have exactly num_buckets (%d) "
            "elements." % (str(bucket_capacities), num_buckets))

    tensor_list = _as_tensor_list(tensors)
    with ops.name_scope(name, "bucket", tensor_list) as name:
        tensor_list = _validate_bucket(tensor_list)
        keep_input = _validate_keep_input(keep_input, enqueue_many=False)
        (tensor_list, sparse_info) = _store_sparse_tensors(
            tensor_list, enqueue_many=False, keep_input=keep_input)

        # Round-trip batch_size to a tensor, and possibly back
        for i, bucket_batch_size in enumerate(batch_size):
            bucket_batch_size = ops.convert_to_tensor(
                bucket_batch_size, dtype=dtypes.int32, name="batch_size")
            static_batch_size = tensor_util.constant_value(bucket_batch_size)
            batch_size[i] = (static_batch_size
                             if static_batch_size is not None
                             else bucket_batch_size)

        types = _dtypes([tensor_list])
        shapes = _shapes([tensor_list], shapes, enqueue_many=False)

        which_bucket = ops.convert_to_tensor(
            which_bucket, dtype=dtypes.int32, name="which_bucket")

        queue_creator = _which_queue(dynamic_pad)
        bucket_queues = []
        for i in range(num_buckets):
            shared_name_i = ("%s_%d" % (shared_name, i)
                             if shared_name is not None else None)
            bucket_queues.append(
                queue_creator(
                    capacity=bucket_capacities[i],
                    dtypes=types,
                    shapes=shapes,
                    shared_name=shared_name_i,
                    name="bucket_queue_%d" % i))

        # The batch dimension is only statically known when every bucket
        # shares one batch size and partial final batches are not allowed.
        maybe_static_batch_size = (
            None if (allow_smaller_final_batch or batch_size_per_bucket)
            else static_batch_size)

        bucket_shapes = [
            tensor_shape.vector(maybe_static_batch_size).concatenate(s)
            for s in bucket_queues[0].shapes
        ]
        # top_queue is a PaddingFIFOQueue even if the bucket queues are
        # regular FIFO queues because if we use allow_smaller_final_batch,
        # shapes will contain Nones in their first entry; as a result, a
        # regular FIFOQueue would die when being passed shapes that are not
        # fully defined.
        top_queue = data_flow_ops.PaddingFIFOQueue(
            capacity=capacity,
            dtypes=[dtypes.int32] + types,
            shapes=[tensor_shape.scalar()] + bucket_shapes,
            shared_name=shared_name,
            name="top_queue")

        def enqueue_which():
            """Return an op that enqueues conditionally in one of the queues."""

            def enqueue_single(i):
                return bucket_queues[i].enqueue(tensor_list)

            enqueues = [
                control_flow_ops.cond(
                    math_ops.equal(which_bucket, i),
                    functools.partial(enqueue_single, i),
                    control_flow_ops.no_op)
                for i in range(num_buckets)
            ]
            return control_flow_ops.group(*enqueues, name="group_enqueues")

        maybe_enqueue = utils.smart_cond(
            keep_input, enqueue_which, control_flow_ops.no_op)

        bucket_enqueue_ops = [maybe_enqueue] * num_threads

        if allow_smaller_final_batch:
            which_dequeue = lambda q: q.dequeue_up_to
        else:
            which_dequeue = lambda q: q.dequeue_many

        def make_list(t):
            if isinstance(t, (list, tuple)):
                return t
            else:
                return [t]

        # One op per bucket that moves a full minibatch, tagged with the
        # bucket index, from the bucket queue to the top queue.
        enqueues_to_top = [
            top_queue.enqueue(
                [constant_op.constant(i)] + make_list(which_dequeue(q)(
                    bs, name="read_bucket_%d" % i)),
                name="enqueue_from_bucket_%d" % i)
            for i, (q, bs) in enumerate(zip(bucket_queues, batch_size))
        ]

        queue_runner.add_queue_runner(
            queue_runner.QueueRunner(
                bucket_queues[0],
                enqueues_to_top,
                close_op=top_queue.close(),
                cancel_op=top_queue.close(cancel_pending_enqueues=True),
                queue_closed_exception_types=(errors.OutOfRangeError,
                                              errors.CancelledError)))
        queue_runner.add_queue_runner(
            queue_runner.QueueRunner(
                top_queue,
                bucket_enqueue_ops,
                close_op=control_flow_ops.group(
                    *[q.close() for q in bucket_queues]),
                cancel_op=control_flow_ops.group(
                    *[q.close(cancel_pending_enqueues=True)
                      for q in bucket_queues]),
                queue_closed_exception_types=(errors.OutOfRangeError,
                                              errors.CancelledError)))

        # Each bucket summary reports that bucket queue's own size; the top
        # queue's fill level is reported separately below.
        for q in bucket_queues:
            summary.scalar("bucket/%s/size" % q.name,
                           math_ops.cast(q.size(), dtypes.float32))
        summary.scalar(
            "bucket/%s/fraction_of_%d_full" % (top_queue.name, capacity),
            math_ops.cast(top_queue.size(), dtypes.float32) *
            (1. / capacity))

        dequeued = top_queue.dequeue(name="dequeue_top")
        which_bucket_dequeued = dequeued[0]
        dequeued = dequeued[1:]
        if len(dequeued) == 1:
            dequeued = dequeued[0]
        dequeued = _restore_sparse_tensors(dequeued, sparse_info)
        return (which_bucket_dequeued, _as_original_type(tensors, dequeued))
def _build_statistics(self, input_batch, use_batch_stats, stat_dtype):
    """Creates the moving statistics and selects which statistics to use.

    Args:
      input_batch: Input batch Tensor.
      use_batch_stats: Boolean to indicate if batch statistics should be
        calculated, otherwise moving averages are returned.
      stat_dtype: TensorFlow datatype to use for the moving mean and variance.

    Returns:
      Tuple of (mean, variance), each of the same datatype as `input_batch`.
    """

    def _moving_variable(variable_name, initializer_key):
        """Creates one non-trainable, per-channel moving statistic."""
        return tf.get_variable(
            variable_name,
            dtype=stat_dtype,
            shape=(self._num_channels,),
            collections=[
                tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
                tf.GraphKeys.GLOBAL_VARIABLES,
            ],
            initializer=self._initializers[initializer_key],
            trainable=False)

    # Set up our moving statistics, filling in a default initializer for any
    # statistic that has none registered. When connecting in parallel, these
    # variables are shared.
    if self.MOVING_MEAN not in self._initializers:
        self._initializers[self.MOVING_MEAN] = create_mean_initializer()
    self._moving_mean = _moving_variable("moving_mean", self.MOVING_MEAN)

    if self.MOVING_VARIANCE not in self._initializers:
        self._initializers[self.MOVING_VARIANCE] = (
            create_variance_initializer())
    self._moving_variance = _moving_variable(
        "moving_variance", self.MOVING_VARIANCE)

    def build_batch_stats():
        """Computes mean and variance from the current batch."""
        return tf.nn.moments(
            input_batch, self._axis, keep_dims=True, name="normalize_moments")

    def build_moving_stats():
        """Reads the moving statistics in the dtype of the input."""
        # tf.nn.batch_normalization requires the statistics to match the
        # input type, so cast only when the stored dtype differs.
        input_dtype = input_batch.dtype.base_dtype
        if stat_dtype == input_dtype:
            return (
                tf.identity(self._moving_mean),
                tf.identity(self._moving_variance),
            )
        return (
            tf.cast(self._moving_mean, input_dtype),
            tf.cast(self._moving_variance, input_dtype),
        )

    selected_mean, selected_variance = utils.smart_cond(
        use_batch_stats,
        build_batch_stats,
        build_moving_stats,
    )
    return selected_mean, selected_variance
def call(self, inputs, training=False):
    """Applies batch normalization to `inputs`.

    Args:
      inputs: Tensor to normalize. Its static rank must be known, and the
        size of dimension `self.axis` is read from the static shape.
      training: Python boolean or boolean tensor. Selects between normalizing
        with the statistics of the current batch and normalizing with the
        moving averages.

    Returns:
      The normalized tensor.
    """
    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(ndim))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * ndim
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Broadcasting is only needed when the normalized axis is not the last
    # one. The comparison must be list-to-list: on Python 3 `range(ndim)[:-1]`
    # is a range object that never compares equal to a list, which would
    # force the broadcasting path for every input.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)

    if needs_broadcasting:
        # In this case we must explicitly broadcast all parameters.
        if self.center:
            broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
        else:
            broadcast_beta = None
        if self.scale:
            broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
        else:
            broadcast_gamma = None

    # Batch moments are built unless `training` is statically known to be
    # False, in which case only the inference branch below is ever used.
    if training_value is not False:
        if needs_broadcasting:
            broadcast_mean, broadcast_variance = nn.moments(
                inputs, reduction_axes, keep_dims=True)
            mean = array_ops.reshape(broadcast_mean, [-1])
            variance = array_ops.reshape(broadcast_variance, [-1])
        else:
            mean, variance = nn.moments(inputs, reduction_axes)

        # Prepare updates if necessary.
        if not self.updates:
            mean_update = moving_averages.assign_moving_average(
                self.moving_mean, mean, self.momentum, zero_debias=False)
            variance_update = moving_averages.assign_moving_average(
                self.moving_variance, variance, self.momentum,
                zero_debias=False)
            # In the future this should be refactored into a self.add_update
            # methods in order to allow for instance-based BN layer sharing
            # across unrelated input streams (e.g. like in Keras).
            self.updates.append(mean_update)
            self.updates.append(variance_update)

    # Normalize batch. We do this inside separate functions for training
    # and inference so as to avoid evaluating both branches.
    def normalize_in_test():
        if needs_broadcasting:
            broadcast_moving_mean = array_ops.reshape(self.moving_mean,
                                                      broadcast_shape)
            broadcast_moving_variance = array_ops.reshape(
                self.moving_variance, broadcast_shape)
            return nn.batch_normalization(inputs,
                                          broadcast_moving_mean,
                                          broadcast_moving_variance,
                                          broadcast_beta,
                                          broadcast_gamma,
                                          self.epsilon)
        else:
            return nn.batch_normalization(inputs,
                                          self.moving_mean,
                                          self.moving_variance,
                                          self.beta if self.center else None,
                                          self.gamma if self.scale else None,
                                          self.epsilon)

    def normalize_in_training():
        if needs_broadcasting:
            return nn.batch_normalization(inputs,
                                          broadcast_mean,
                                          broadcast_variance,
                                          broadcast_beta,
                                          broadcast_gamma,
                                          self.epsilon)
        else:
            return nn.batch_normalization(inputs,
                                          mean,
                                          variance,
                                          self.beta if self.center else None,
                                          self.gamma if self.scale else None,
                                          self.epsilon)

    return utils.smart_cond(training,
                            normalize_in_training,
                            normalize_in_test)