def add_noise_add(d, noise_scale):
    """Inject additive Gaussian noise (training mode only)."""
    # `is_training` is assumed to be defined in the enclosing scope.
    d = smart_cond(
        is_training,
        lambda: d + tf.random_normal(tf.shape(d), stddev=noise_scale),
        lambda: d)
    return d


def dropout(d, seq_len):
    """Apply dropout with a rate scaled by the sequence length."""
    # `is_training` and `dropout_keep_prob` are assumed to be defined in the
    # enclosing scope.
    if dropout_keep_prob < 1:
        prob = (1.0 - dropout_keep_prob) / seq_len
        d = smart_cond(is_training, lambda: tf.nn.dropout(d, rate=prob),
                       lambda: d)
    return d
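
These two helpers rely on `is_training`, `dropout_keep_prob`, and `smart_cond` being available in the enclosing scope. A minimal sketch of how they might be wired into a TF1 graph, with those names supplied explicitly (the `smart_cond` import path and all values below are assumptions for illustration):

import tensorflow as tf
from tensorflow.contrib.framework import smart_cond  # assumed TF1-era import path

# Values that the helpers above capture from their enclosing scope.
is_training = tf.placeholder_with_default(True, shape=(), name="is_training")
dropout_keep_prob = 0.9

features = tf.random_normal([32, 20, 128])        # hypothetical [batch, time, channels] input
noisy = add_noise_add(features, noise_scale=0.1)  # Gaussian noise only in training mode
regularized = dropout(noisy, 20.0)                # dropout rate scaled by the sequence length
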
Example #3
  def _build_update_ops(self, mean, variance, is_training):
    """Builds the moving average update ops when using moving variance.

    Args:
      mean: The mean value to update with.
      variance: The variance value to update with.
      is_training: Boolean Tensor to indicate if we're currently in
        training mode.

    Returns:
      Tuple of `(update_mean_op, update_variance_op)` when `is_training` is
      `True` or cannot be determined statically. Returns `None` when
      `is_training` is statically `False`.
    """

    def build_update_ops():
      """Builds the exponential moving average update ops."""

      update_mean_op = moving_averages.assign_moving_average(
          variable=self._moving_mean,
          value=tf.reshape(mean, (self._num_channels,)),
          decay=self._decay_rate,
          zero_debias=False,
          name="update_moving_mean").op

      update_variance_op = moving_averages.assign_moving_average(
          variable=self._moving_variance,
          value=tf.reshape(variance, (self._num_channels,)),
          decay=self._decay_rate,
          zero_debias=False,
          name="update_moving_variance").op

      return update_mean_op, update_variance_op

    def build_no_ops():
      return tf.no_op(), tf.no_op()

    # Only make the ops if we know that `is_training=True`, or the value of
    # `is_training` is unknown.
    is_training_const = utils.constant_value(is_training)
    if is_training_const is None or is_training_const:
      update_mean_op, update_variance_op = contrib_framework.smart_cond(
          is_training,
          build_update_ops,
          build_no_ops,
      )
      return update_mean_op, update_variance_op
    else:
      return None
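
The interesting part of this example is the short-circuit: when `is_training` is a Python constant `False`, no conditional is built at all. The same pattern can be written in isolation; a minimal sketch, assuming TF1 import paths and a hypothetical `maybe_build_updates` helper:

import tensorflow as tf
from tensorflow.contrib.framework import smart_cond   # assumed TF1-era import path
from tensorflow.contrib.util import constant_value    # resolves statically known bools

def maybe_build_updates(is_training, build_update_op):
  """Build `build_update_op()` only when `is_training` is True or unknown at graph time.

  `build_update_op` should return a single op (e.g. a `tf.group` of the
  moving-average updates) so both branches of the cond have the same structure.
  """
  is_training_const = constant_value(tf.convert_to_tensor(is_training))
  if is_training_const is None or is_training_const:
    # Dynamic bool Tensor or statically True: let smart_cond pick the branch.
    return smart_cond(is_training, build_update_op, tf.no_op)
  # Statically False: skip building the update machinery entirely.
  return None
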
Example #4
def crf_beta_backward(inputs, transition_params):
    batch_size = tf.shape(inputs)[0]
    seq_len = tf.shape(inputs)[1]
    n_tags = tf.shape(inputs)[2]

    def _single_seq_fn():
        return tf.ones([batch_size, 1, n_tags]) * -10000

    def _multi_seq_fn():
        trans_mat_t = tf.transpose(transition_params)

        def scan_step_backward(prev_betas, inputs):
            prev_betas_ex = tf.expand_dims(prev_betas, 2)
            inputs_ex = tf.expand_dims(inputs, 2)
            trans_scores = prev_betas_ex + trans_mat_t + inputs_ex
            new_betas = tf.reduce_logsumexp(trans_scores, 1)

            # new_betas = tf.reduce_logsumexp(trans_scores, 2)
            # trans_scores = prev_betas_ex + trans_mat_t
            # new_betas = inputs + tf.reduce_logsumexp(trans_scores, 2)
            return new_betas

        elems = tf.reverse(tf.transpose(inputs, [1, 0, 2]), [0])
        # init_val = tf.ones([batch_size, n_tags]) * -10000
        init_val = tf.zeros([batch_size, n_tags])
        rest_inputs = elems[:-1]
        betas_m = tf.scan(scan_step_backward,
                          rest_inputs,
                          initializer=init_val)
        betas_m = tf.concat([tf.expand_dims(init_val, 0), betas_m], axis=0)
        betas_m = tf.reverse(betas_m, [0])
        return betas_m

    betas = smart_cond(pred=tf.equal(seq_len, 1),
                       true_fn=_single_seq_fn,
                       false_fn=_multi_seq_fn)

    return betas
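
A minimal sketch of calling this function with concrete tensors, assuming `smart_cond` has been imported in the surrounding module (for instance from `tf.contrib.framework`) and TF1 session semantics; the potentials below are hypothetical random values:

import numpy as np
import tensorflow as tf

unary = tf.constant(np.random.randn(2, 5, 3), dtype=tf.float32)    # [batch, time, tags]
transitions = tf.constant(np.random.randn(3, 3), dtype=tf.float32)

# Note: in the multi-step branch the betas come out of tf.scan time-major.
betas = crf_beta_backward(unary, transitions)

with tf.Session() as sess:
    print(sess.run(betas).shape)
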
Example #5
def crf_log_norm_forward(inputs, sequence_lengths, transition_params):
    first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
    first_input = tf.squeeze(first_input, [1])

    def _single_seq_fn():
        log_norm = tf.reduce_logsumexp(first_input, [1])
        # Mask `log_norm` of the sequences with length <= zero.
        log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                            tf.zeros_like(log_norm), log_norm)
        return log_norm, inputs

    def _multi_seq_fn():
        """Forward computation of alpha values."""
        rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])

        # Compute the alpha values in the forward algorithm in order to get the
        # partition function.
        forward_cell = CrfForwardRnnCell(transition_params)
        # Sequence length is not allowed to be less than zero.
        sequence_lengths_less_one = tf.maximum(
            tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1)
        all_alphas, alphas_final = rnn.dynamic_rnn(
            cell=forward_cell,
            inputs=rest_of_input,
            sequence_length=sequence_lengths_less_one,
            initial_state=first_input,
            dtype=tf.float32)
        log_norm = tf.reduce_logsumexp(alphas_final, [1])
        # Mask `log_norm` of the sequences with length <= zero.
        log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                            tf.zeros_like(log_norm), log_norm)
        return log_norm, all_alphas

    log_norm_z, alphas = smart_cond(pred=tf.equal(tf.shape(inputs)[1], 1),
                                    true_fn=_single_seq_fn,
                                    false_fn=_multi_seq_fn)

    return log_norm_z, alphas
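
As above, a minimal usage sketch, assuming `CrfForwardRnnCell`, `rnn`, and `smart_cond` are imported in the surrounding module (for example from `tf.contrib.crf`, `tensorflow.python.ops.rnn`, and `tf.contrib.framework`); the inputs are hypothetical:

import numpy as np
import tensorflow as tf

unary = tf.constant(np.random.randn(2, 5, 3), dtype=tf.float32)    # [batch, time, tags]
transitions = tf.constant(np.random.randn(3, 3), dtype=tf.float32)
lengths = tf.constant([5, 3], dtype=tf.int32)

log_z, alphas = crf_log_norm_forward(unary, lengths, transitions)

with tf.Session() as sess:
    print(sess.run(log_z))   # one log partition value per sequence, shape (2,)
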
Example #6
  def _fused_batch_norm_op(self, input_batch, mean, variance, use_batch_stats):
    """Creates a fused batch normalization op."""
    # Store the original shape of the mean and variance.
    mean_shape = mean.get_shape()
    variance_shape = variance.get_shape()
    # The fused batch norm expects the mean, variance, gamma and beta
    # tensors to have dimension 1, so we flatten them to remove the
    # extra dimensions. In addition, it expects the input_batch to have
    # dimension 4, so we reshape it accordingly.
    gamma_flatten = tf.reshape(self._gamma, shape=(self._num_channels,))
    beta_flatten = tf.reshape(self._beta, shape=(self._num_channels,))
    flatten_mean = tf.reshape(mean, shape=(self._num_channels,))
    flatten_variance = tf.reshape(variance, shape=(self._num_channels,))
    use_batch_stats = tf.convert_to_tensor(use_batch_stats)

    input_shape = input_batch.get_shape()
    output_shape = tf.shape(input_batch)

    flat_image_size = tf.cast(tf.reduce_prod(self._image_shape, keepdims=True),
                              tf.int64)

    if len(self._data_format) == 4:
      fusable_data_format = self._data_format
      fusable_batch = input_batch
    elif self._channel_index == 1 and input_shape.rank > 2:
      fusable_data_format = "NCHW"
      fusable_shape = tf.concat(
          [[-1, self._num_channels, 1], flat_image_size], axis=0)
      fusable_batch = tf.reshape(input_batch, shape=fusable_shape)
    else:
      # The CPU implementation of FusedBatchNorm only supports NHWC tensor
      # format for now.
      fusable_data_format = "NHWC"
      fusable_shape = tf.concat(
          [[-1, 1], flat_image_size, [self._num_channels]], axis=0)
      fusable_batch = tf.reshape(input_batch, shape=fusable_shape)

    common_args = {
        "scale": gamma_flatten,
        "offset": beta_flatten,
        "epsilon": self._eps,
        "data_format": fusable_data_format,
        "name": "batch_norm"
    }

    def use_batch_stats_fused_batch_norm():
      return tf.nn.fused_batch_norm(
          fusable_batch,
          mean=None,
          variance=None,
          is_training=True,
          **common_args)

    def moving_average_fused_batch_norm():
      return tf.nn.fused_batch_norm(
          fusable_batch,
          mean=flatten_mean,
          variance=flatten_variance,
          is_training=False,
          **common_args)

    batch_norm_op, mean, variance = contrib_framework.smart_cond(
        use_batch_stats, use_batch_stats_fused_batch_norm,
        moving_average_fused_batch_norm)

    if len(self._data_format) != 4:
      batch_norm_op = tf.reshape(batch_norm_op, output_shape)
    mean = tf.reshape(mean, mean_shape)
    variance = tf.reshape(variance, variance_shape)
    return batch_norm_op, mean, variance
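
Stripped of the reshaping bookkeeping, the pattern above is a single `tf.nn.fused_batch_norm` call whose statistics source is chosen by `smart_cond`. A standalone sketch of just that switch, with every name below hypothetical and the `smart_cond` import path assumed:

import tensorflow as tf
from tensorflow.contrib.framework import smart_cond  # assumed TF1-era import path

x = tf.random_normal([8, 16, 16, 32])                      # NHWC input
gamma = tf.ones([32])
beta = tf.zeros([32])
moving_mean = tf.zeros([32])
moving_var = tf.ones([32])
use_batch_stats = tf.placeholder_with_default(True, shape=())

def use_batch_stats_branch():
  # Compute mean/variance from the current batch.
  return tf.nn.fused_batch_norm(x, gamma, beta, epsilon=1e-3, is_training=True)

def moving_average_branch():
  # Reuse the stored moving statistics.
  return tf.nn.fused_batch_norm(x, gamma, beta, mean=moving_mean,
                                variance=moving_var, epsilon=1e-3,
                                is_training=False)

y, batch_mean, batch_var = smart_cond(
    use_batch_stats, use_batch_stats_branch, moving_average_branch)
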
Example #7
  def _build_statistics(self, input_batch, use_batch_stats, stat_dtype):
    """Builds the statistics part of the graph when using moving variance.

    Args:
      input_batch: Input batch Tensor.
      use_batch_stats: Boolean to indicate if batch statistics should be
        calculated, otherwise moving averages are returned.
      stat_dtype: TensorFlow datatype to use for the moving mean and variance.

    Returns:
      Tuple of (mean, variance), each of the same datatype as `input_batch`.
    """
    # Set up our moving statistics. When connecting in parallel, this is shared.
    if self.MOVING_MEAN not in self._initializers:
      self._initializers[self.MOVING_MEAN] = create_mean_initializer()
    self._moving_mean = tf.get_variable(
        "moving_mean",
        dtype=stat_dtype,
        shape=(self._num_channels,),
        collections=[
            tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
            tf.GraphKeys.GLOBAL_VARIABLES,
        ],
        initializer=self._initializers[self.MOVING_MEAN],
        trainable=False)

    if self.MOVING_VARIANCE not in self._initializers:
      self._initializers[self.MOVING_VARIANCE] = create_variance_initializer()
    self._moving_variance = tf.get_variable(
        "moving_variance",
        dtype=stat_dtype,
        shape=(self._num_channels,),
        collections=[
            tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
            tf.GraphKeys.GLOBAL_VARIABLES,
        ],
        initializer=self._initializers[self.MOVING_VARIANCE],
        trainable=False)

    def build_batch_stats():
      """Builds the batch statistics calculation ops."""
      mean, variance = tf.nn.moments(input_batch, self._axis,
                                     keep_dims=True, name="normalize_moments")

      return mean, variance

    def build_moving_stats():
      """Retrieves the moving statistics."""
      # If necessary, cast the moving statistics to match the input type.
      # This is required by tf.nn.batch_normalization.
      input_dtype = input_batch.dtype
      if stat_dtype == input_dtype:
        return (
            tf.identity(self._moving_mean),
            tf.identity(self._moving_variance),
        )
      else:
        return (
            tf.cast(self._moving_mean, input_dtype),
            tf.cast(self._moving_variance, input_dtype),
        )

    mean, variance = contrib_framework.smart_cond(
        use_batch_stats,
        build_batch_stats,
        build_moving_stats,
    )

    return mean, variance
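
Downstream, these statistics typically feed `tf.nn.batch_normalization`, which is why `build_moving_stats` casts them to the input dtype. A hypothetical continuation at a call site inside the same module (the surrounding variables are the ones used in Example #6):

    mean, variance = self._build_statistics(input_batch, use_batch_stats,
                                             stat_dtype=tf.float32)
    out = tf.nn.batch_normalization(input_batch, mean, variance,
                                    offset=self._beta, scale=self._gamma,
                                    variance_epsilon=self._eps)
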
def add_noise_mul(d, noise_scale):
    """Inject multiplicative Gaussian noise (training mode only)."""
    # `is_training` is assumed to be defined in the enclosing scope.
    d = smart_cond(
        is_training,
        lambda: d * tf.random_normal(tf.shape(d), mean=1.0, stddev=noise_scale),
        lambda: d)
    return d
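
Usage mirrors the additive variant; a hypothetical call on the same `features` tensor as in the first sketch:

noisy_mul = add_noise_mul(features, noise_scale=0.05)  # mean-1.0 Gaussian scaling when training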