def add_noise_add(d, noise_scale):
  """Inject additive noise."""
  d = smart_cond(
      is_training,
      lambda: d + tf.random_normal(tf.shape(d), stddev=noise_scale),
      lambda: d)
  return d
def dropout(d, seq_len):
  """Dropout with a drop rate dependent on sequence length."""
  if dropout_keep_prob < 1:
    prob = (1.0 - dropout_keep_prob) / seq_len
    d = smart_cond(is_training,
                   lambda: tf.nn.dropout(d, rate=prob),
                   lambda: d)
  return d
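# --- Usage sketch (illustrative; not part of the original source) ---
# The helpers above read `is_training`, `dropout_keep_prob` and `smart_cond`
# from the enclosing module, so those are defined here purely for illustration.
# `smart_cond` is assumed to be tf.contrib.framework.smart_cond (TF 1.x), which
# evaluates the predicate statically when possible and falls back to tf.cond.
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework

smart_cond = contrib_framework.smart_cond
is_training = tf.placeholder_with_default(True, shape=[], name="is_training")
dropout_keep_prob = 0.9  # hypothetical value

embeddings = tf.random_normal([8, 20, 64])          # hypothetical [batch, time, dim]
noisy = add_noise_add(embeddings, noise_scale=0.1)  # additive noise only when training
# Dividing the drop rate by the sequence length keeps the expected number of
# dropped units roughly constant regardless of how long the sequence is.
regularized = dropout(noisy, seq_len=20)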
def _build_update_ops(self, mean, variance, is_training):
  """Builds the moving average update ops when using moving variance.

  Args:
    mean: The mean value to update with.
    variance: The variance value to update with.
    is_training: Boolean Tensor to indicate if we're currently in
      training mode.

  Returns:
    Tuple of `(update_mean_op, update_variance_op)` when `is_training` is or
    could be `True`. Returns `None` when `is_training=False`.
  """

  def build_update_ops():
    """Builds the exponential moving average update ops."""
    update_mean_op = moving_averages.assign_moving_average(
        variable=self._moving_mean,
        value=tf.reshape(mean, (self._num_channels,)),
        decay=self._decay_rate,
        zero_debias=False,
        name="update_moving_mean").op

    update_variance_op = moving_averages.assign_moving_average(
        variable=self._moving_variance,
        value=tf.reshape(variance, (self._num_channels,)),
        decay=self._decay_rate,
        zero_debias=False,
        name="update_moving_variance").op

    return update_mean_op, update_variance_op

  def build_no_ops():
    return tf.no_op(), tf.no_op()

  # Only make the ops if we know that `is_training=True`, or the value of
  # `is_training` is unknown.
  is_training_const = utils.constant_value(is_training)
  if is_training_const is None or is_training_const:
    update_mean_op, update_variance_op = contrib_framework.smart_cond(
        is_training,
        build_update_ops,
        build_no_ops,
    )
    return update_mean_op, update_variance_op
  else:
    return None
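# --- Training-loop sketch (illustrative; not part of the original source) ---
# The moving-average update ops returned above only take effect if something
# runs them. A common TF 1.x pattern is to register them in the UPDATE_OPS
# collection and gate the optimizer step on that collection; `loss` and the
# registration of the ops are assumed to happen elsewhere.
def make_train_op(loss, learning_rate=1e-3):
  """Gates the optimizer step on any pending moving-average update ops."""
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  with tf.control_dependencies(update_ops):
    return optimizer.minimize(loss)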
def crf_beta_backward(inputs, transition_params):
  """Computes backward (beta) values for a linear-chain CRF in log space."""
  batch_size = tf.shape(inputs)[0]
  seq_len = tf.shape(inputs)[1]
  n_tags = tf.shape(inputs)[2]

  def _single_seq_fn():
    return tf.ones([batch_size, 1, n_tags]) * -10000

  def _multi_seq_fn():
    trans_mat_t = tf.transpose(transition_params)

    def scan_step_backward(prev_betas, inputs):
      prev_betas_ex = tf.expand_dims(prev_betas, 2)
      inputs_ex = tf.expand_dims(inputs, 2)
      trans_scores = prev_betas_ex + trans_mat_t + inputs_ex
      new_betas = tf.reduce_logsumexp(trans_scores, 1)
      # new_betas = tf.reduce_logsumexp(trans_scores, 2)
      # trans_scores = prev_betas_ex + trans_mat_t
      # new_betas = inputs + tf.reduce_logsumexp(trans_scores, 2)
      return new_betas

    # Time-major inputs, reversed so the scan runs from the last step backward.
    elems = tf.reverse(tf.transpose(inputs, [1, 0, 2]), [0])
    # init_val = tf.ones([batch_size, n_tags]) * -10000
    init_val = tf.zeros([batch_size, n_tags])
    rest_inputs = elems[:-1]
    betas_m = tf.scan(scan_step_backward, rest_inputs, initializer=init_val)
    betas_m = tf.concat([tf.expand_dims(init_val, 0), betas_m], axis=0)
    betas_m = tf.reverse(betas_m, [0])
    return betas_m

  betas = smart_cond(pred=tf.equal(seq_len, 1),
                     true_fn=_single_seq_fn,
                     false_fn=_multi_seq_fn)
  return betas
def crf_log_norm_forward(inputs, sequence_lengths, transition_params):
  """Computes the CRF log partition function and the forward alpha values."""
  first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
  first_input = tf.squeeze(first_input, [1])

  def _single_seq_fn():
    log_norm = tf.reduce_logsumexp(first_input, [1])
    # Mask `log_norm` of the sequences with length <= zero.
    log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                        tf.zeros_like(log_norm),
                        log_norm)
    return log_norm, inputs

  def _multi_seq_fn():
    """Forward computation of alpha values."""
    rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])
    # Compute the alpha values in the forward algorithm in order to get the
    # partition function.
    forward_cell = CrfForwardRnnCell(transition_params)
    # Sequence length is not allowed to be less than zero.
    sequence_lengths_less_one = tf.maximum(
        tf.constant(0, dtype=sequence_lengths.dtype),
        sequence_lengths - 1)
    all_alphas, alphas_final = rnn.dynamic_rnn(
        cell=forward_cell,
        inputs=rest_of_input,
        sequence_length=sequence_lengths_less_one,
        initial_state=first_input,
        dtype=tf.float32)
    log_norm = tf.reduce_logsumexp(alphas_final, [1])
    # Mask `log_norm` of the sequences with length <= zero.
    log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                        tf.zeros_like(log_norm),
                        log_norm)
    return log_norm, all_alphas

  log_norm_z, alphas = smart_cond(pred=tf.equal(tf.shape(inputs)[1], 1),
                                  true_fn=_single_seq_fn,
                                  false_fn=_multi_seq_fn)
  return log_norm_z, alphas
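# --- Forward-backward sketch (illustrative; not part of the original source) ---
# The two CRF helpers can in principle be combined into per-position tag
# log-marginals in the usual forward-backward way: alpha_t(k) + beta_t(k) - log Z.
# This sketch assumes the alphas returned above cover steps 1..T-1 (step 0 is
# just the first unary score), that crf_beta_backward returns time-major betas
# of shape [T, batch, n_tags], and it ignores the single-step and padded-length
# corner cases.
def crf_marginals_sketch(inputs, sequence_lengths, transition_params):
  """Per-position tag log-marginals via forward-backward (sketch)."""
  log_norm, rest_alphas = crf_log_norm_forward(
      inputs, sequence_lengths, transition_params)
  first_alpha = inputs[:, :1, :]  # alpha_0 is the unary score of the first step
  alphas = tf.concat([first_alpha, rest_alphas], axis=1)       # [batch, T, n_tags]
  betas = tf.transpose(
      crf_beta_backward(inputs, transition_params), [1, 0, 2])  # [batch, T, n_tags]
  # log p(tag_t = k | x) = alpha_t(k) + beta_t(k) - log Z
  return alphas + betas - log_norm[:, tf.newaxis, tf.newaxis]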
def _fused_batch_norm_op(self, input_batch, mean, variance, use_batch_stats):
  """Creates a fused batch normalization op."""
  # Store the original shape of the mean and variance.
  mean_shape = mean.get_shape()
  variance_shape = variance.get_shape()

  # The fused batch norm expects the mean, variance, gamma and beta
  # tensors to have dimension 1, so we flatten them to remove the
  # extra dimensions. In addition, it expects the input_batch to have
  # dimension 4, so we reshape it accordingly.
  gamma_flatten = tf.reshape(self._gamma, shape=(self._num_channels,))
  beta_flatten = tf.reshape(self._beta, shape=(self._num_channels,))
  flatten_mean = tf.reshape(mean, shape=(self._num_channels,))
  flatten_variance = tf.reshape(variance, shape=(self._num_channels,))
  use_batch_stats = tf.convert_to_tensor(use_batch_stats)

  input_shape = input_batch.get_shape()
  output_shape = tf.shape(input_batch)

  flat_image_size = tf.cast(
      tf.reduce_prod(self._image_shape, keepdims=True), tf.int64)

  if len(self._data_format) == 4:
    fusable_data_format = self._data_format
    fusable_batch = input_batch
  elif self._channel_index == 1 and input_shape.rank > 2:
    fusable_data_format = "NCHW"
    fusable_shape = tf.concat(
        [[-1, self._num_channels, 1], flat_image_size], axis=0)
    fusable_batch = tf.reshape(input_batch, shape=fusable_shape)
  else:
    # The CPU implementation of FusedBatchNorm only supports NHWC tensor
    # format for now.
    fusable_data_format = "NHWC"
    fusable_shape = tf.concat(
        [[-1, 1], flat_image_size, [self._num_channels]], axis=0)
    fusable_batch = tf.reshape(input_batch, shape=fusable_shape)

  common_args = {
      "scale": gamma_flatten,
      "offset": beta_flatten,
      "epsilon": self._eps,
      "data_format": fusable_data_format,
      "name": "batch_norm"
  }

  def use_batch_stats_fused_batch_norm():
    return tf.nn.fused_batch_norm(
        fusable_batch, mean=None, variance=None, is_training=True,
        **common_args)

  def moving_average_fused_batch_norm():
    return tf.nn.fused_batch_norm(
        fusable_batch, mean=flatten_mean, variance=flatten_variance,
        is_training=False, **common_args)

  batch_norm_op, mean, variance = contrib_framework.smart_cond(
      use_batch_stats, use_batch_stats_fused_batch_norm,
      moving_average_fused_batch_norm)

  if len(self._data_format) != 4:
    batch_norm_op = tf.reshape(batch_norm_op, output_shape)

  mean = tf.reshape(mean, mean_shape)
  variance = tf.reshape(variance, variance_shape)
  return batch_norm_op, mean, variance
def _build_statistics(self, input_batch, use_batch_stats, stat_dtype):
  """Builds the statistics part of the graph when using moving variance.

  Args:
    input_batch: Input batch Tensor.
    use_batch_stats: Boolean to indicate if batch statistics should be
      calculated, otherwise moving averages are returned.
    stat_dtype: TensorFlow datatype to use for the moving mean and variance.

  Returns:
    Tuple of (mean, variance), each of the same datatype as `input_batch`.
  """
  # Set up our moving statistics. When connecting in parallel, this is shared.
  if self.MOVING_MEAN not in self._initializers:
    self._initializers[self.MOVING_MEAN] = create_mean_initializer()
  self._moving_mean = tf.get_variable(
      "moving_mean",
      dtype=stat_dtype,
      shape=(self._num_channels,),
      collections=[
          tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
          tf.GraphKeys.GLOBAL_VARIABLES,
      ],
      initializer=self._initializers[self.MOVING_MEAN],
      trainable=False)

  if self.MOVING_VARIANCE not in self._initializers:
    self._initializers[self.MOVING_VARIANCE] = create_variance_initializer()
  self._moving_variance = tf.get_variable(
      "moving_variance",
      dtype=stat_dtype,
      shape=(self._num_channels,),
      collections=[
          tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
          tf.GraphKeys.GLOBAL_VARIABLES,
      ],
      initializer=self._initializers[self.MOVING_VARIANCE],
      trainable=False)

  def build_batch_stats():
    """Builds the batch statistics calculation ops."""
    mean, variance = tf.nn.moments(input_batch, self._axis,
                                   keep_dims=True, name="normalize_moments")
    return mean, variance

  def build_moving_stats():
    """Retrieves the moving statistics."""
    # If necessary, cast the moving statistics to match the input type.
    # This is required by tf.nn.batch_normalization.
    input_dtype = input_batch.dtype
    if stat_dtype == input_dtype:
      return (
          tf.identity(self._moving_mean),
          tf.identity(self._moving_variance),
      )
    else:
      return (
          tf.cast(self._moving_mean, input_dtype),
          tf.cast(self._moving_variance, input_dtype),
      )

  mean, variance = contrib_framework.smart_cond(
      use_batch_stats,
      build_batch_stats,
      build_moving_stats,
  )

  return mean, variance
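# --- Composition sketch (illustrative; not part of the original source) ---
# _build_statistics and _build_update_ops are two halves of one batch-norm
# forward pass: choose batch vs. moving statistics, normalize with them, and
# refresh the moving averages while training. A minimal sketch of that glue,
# assuming the gamma/beta/eps attributes used elsewhere in this class:
def _batch_norm_forward_sketch(self, input_batch, is_training,
                               stat_dtype=tf.float32):
  use_batch_stats = is_training  # a common choice; some variants expose this flag separately
  mean, variance = self._build_statistics(input_batch, use_batch_stats, stat_dtype)
  out = tf.nn.batch_normalization(
      input_batch, mean, variance, self._beta, self._gamma, self._eps)
  update_ops = self._build_update_ops(mean, variance, is_training)
  if update_ops is not None:
    # Refresh the moving averages whenever the normalized output is used.
    with tf.control_dependencies(update_ops):
      out = tf.identity(out)
  return out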
def add_noise_mul(d, noise_scale):
  """Inject multiplicative noise."""
  d = smart_cond(
      is_training,
      lambda: d * tf.random_normal(tf.shape(d), mean=1.0, stddev=noise_scale),
      lambda: d)
  return d
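# --- Smoke test (illustrative; not part of the original source) ---
# Multiplicative noise is drawn around mean 1.0, so the expected value of the
# activations is preserved, whereas additive noise perturbs them with a fixed
# standard deviation. Assumes the module-level `is_training` placeholder and
# `smart_cond` from the earlier sketch (TF 1.x graph mode).
if __name__ == "__main__":
  import numpy as np
  x = tf.ones([4, 3])
  y_add = add_noise_add(x, noise_scale=0.05)
  y_mul = add_noise_mul(x, noise_scale=0.05)
  with tf.Session() as sess:
    out_add, out_mul = sess.run([y_add, y_mul], feed_dict={is_training: True})
    print(np.mean(out_add), np.mean(out_mul))  # both should stay close to 1.0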