Example #1
import tensorflow as tf


def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

  If `pred` is a bool or has a constant value, we return either `true_fn()`
  or `false_fn()`; otherwise we use `tf.cond` to dynamically route to both.

  Arguments:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if `pred` is true.
    false_fn: The callable to be performed if `pred` is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  if not callable(true_fn):
    raise TypeError('`true_fn` must be callable.')
  if not callable(false_fn):
    raise TypeError('`false_fn` must be callable.')
  pred_value = tf.get_static_value(pred)
  if isinstance(pred, tf.Variable) or pred_value is None:
    return tf.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
  if pred_value:
    return true_fn()
  else:
    return false_fn()
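
A minimal usage sketch (assuming TensorFlow 2.x; the tensors and the `tf.function` wrapper below are illustrative, not part of the source):

x = tf.constant(3.0)

# Constant predicate: tf.get_static_value resolves it, so only true_fn runs.
y = smart_cond(tf.constant(True), true_fn=lambda: x * 2.0, false_fn=lambda: x)

# Data-dependent predicate inside a tf.function: no static value is
# available at trace time, so the call falls back to tf.cond and traces
# both branches.
@tf.function
def branch(flag):
  return smart_cond(flag, true_fn=lambda: x * 2.0, false_fn=lambda: x)

branch(tf.constant(False))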
Example #2
File: utils.py  Project: google/automl
from absl import logging
import tensorflow.compat.v1 as tf


def get_ckpt_var_map(ckpt_path, ckpt_scope, var_scope, skip_mismatch=None):
  """Get a var map for restoring from pretrained checkpoints.

  Args:
    ckpt_path: string. A pretrained checkpoint path.
    ckpt_scope: string. Scope name for checkpoint variables.
    var_scope: string. Scope name for model variables.
    skip_mismatch: whether to skip variables whose names or shapes do not
      match the checkpoint instead of raising an error.

  Returns:
    var_map: a dictionary from checkpoint name to model variables.
  """
  logging.info('Init model from checkpoint {}'.format(ckpt_path))
  if not ckpt_scope.endswith('/') or not var_scope.endswith('/'):
    raise ValueError('Please specify scope names ending with /')
  if ckpt_scope.startswith('/'):
    ckpt_scope = ckpt_scope[1:]
  if var_scope.startswith('/'):
    var_scope = var_scope[1:]

  var_map = {}
  # Get the list of vars to restore.
  model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=var_scope)
  reader = tf.train.load_checkpoint(ckpt_path)
  ckpt_var_name_to_shape = reader.get_variable_to_shape_map()
  ckpt_var_names = set(ckpt_var_name_to_shape.keys())

  if tf.distribute.get_replica_context():
    replica_id = tf.get_static_value(
        tf.distribute.get_replica_context().replica_id_in_sync_group)
  else:
    replica_id = 0

  for i, v in enumerate(model_vars):
    var_op_name = v.op.name

    if replica_id >= 1:
      var_op_name = ''.join(var_op_name.rsplit(f'/replica_{replica_id}', 1))

    if not var_op_name.startswith(var_scope):
      logging.info('skip {} -- does not match scope {}'.format(
          var_op_name, var_scope))
      continue
    ckpt_var = ckpt_scope + var_op_name[len(var_scope):]
    if 'global_step' in ckpt_var:
      continue

    if (ckpt_var not in ckpt_var_names and
        var_op_name.endswith('/ExponentialMovingAverage')):
      ckpt_var = ckpt_scope + var_op_name[:-len('/ExponentialMovingAverage')]

    if ckpt_var not in ckpt_var_names:
      if 'Momentum' in ckpt_var or 'RMSProp' in ckpt_var:
        # Skip optimizer variables.
        continue
      if skip_mismatch:
        logging.info('skip {} ({}) -- not in ckpt'.format(
            var_op_name, ckpt_var))
        continue
      raise ValueError('{} is not in ckpt {}'.format(var_op_name, ckpt_path))

    if v.shape != ckpt_var_name_to_shape[ckpt_var]:
      if skip_mismatch:
        logging.info('skip {} ({} vs {}) -- shape mismatch'.format(
            var_op_name, v.shape, ckpt_var_name_to_shape[ckpt_var]))
        continue
      raise ValueError('shape mismatch {} ({} vs {})'.format(
          var_op_name, v.shape, ckpt_var_name_to_shape[ckpt_var]))

    if i < 5:
      # Log the first few elements for sanity check.
      logging.info('Init {} from ckpt var {}'.format(var_op_name, ckpt_var))
    var_map[ckpt_var] = v

  return var_map
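
A hedged usage sketch; the checkpoint path and scope names below are hypothetical, and the pairing with `tf.train.init_from_checkpoint` follows the usual TF1-style warm-start pattern rather than anything mandated by this function:

ckpt_path = '/tmp/efficientnet-b0/model'  # hypothetical checkpoint prefix
var_map = get_ckpt_var_map(
    ckpt_path=ckpt_path,
    ckpt_scope='efficientnet-b0/',  # both scopes must end with '/'
    var_scope='efficientnet-b0/',
    skip_mismatch=True)
# Override the initializers of the matched model variables with checkpoint
# values; the override takes effect when the variables are initialized.
tf.train.init_from_checkpoint(ckpt_path, var_map)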
Example #3
import numbers
from typing import Optional, Sequence, Text, Union

from absl import logging
import tensorflow as tf


def stateless_dropout(x: tf.Tensor,
                      rate: float,
                      seed: tf.Tensor,
                      noise_shape: Optional[Union[Sequence[int],
                                                  tf.TensorShape]] = None,
                      name: Optional[Text] = None) -> tf.Tensor:
  """Computes dropout: randomly sets elements to zero to prevent overfitting.

  See https://www.tensorflow.org/api_docs/python/tf/nn/dropout.
  This version differs in that the seed is required if the rate is nonzero.

  Args:
    x: A floating point tensor.
    rate: A scalar `Tensor` with the same type as `x`. The probability that
      each element is dropped. For example, setting rate=0.1 would drop 10%
      of input elements.
    seed: A shape [2] integer Tensor of seeds to the random number generator.
      Must have dtype `tf.int32` when compiling to XLA.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for
      randomly generated keep/drop flags.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` of the same shape as `x`.

  Raises:
    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
      tensor. `rate=1` is disallowed, because the output would be all zeros,
      which is likely not what was intended.
  """
  with tf.name_scope(name or 'stateless_dropout') as name:
    x = tf.convert_to_tensor(x, name='x')
    if not x.dtype.is_floating:
      raise ValueError('x has to be a floating point tensor since it\'s going '
                       'to be scaled. Got a %s tensor instead.' % x.dtype)
    if isinstance(rate, numbers.Real):
      if not (rate >= 0 and rate < 1):
        raise ValueError('rate must be a scalar tensor or a float in the '
                         'range [0, 1), got %g' % rate)
      if rate > 0.5:
        logging.log_first_n(
            logging.WARN,
            'Large dropout rate: %g (>0.5). In TensorFlow '
            '2.x, dropout() uses dropout rate instead of keep_prob. '
            'Please ensure that this is intended.', 5, rate)

    # Early return if nothing needs to be dropped.
    if tf.get_static_value(rate) == 0:
      return x

    rate = tf.convert_to_tensor(rate, dtype=x.dtype, name='rate')
    rate.shape.assert_has_rank(0)
    # `_get_noise_shape` is a module-local helper (same role as the private
    # helper inside `tf.nn.dropout`) that resolves `noise_shape` against the
    # shape of `x`; its definition is not part of this snippet.
    noise_shape = _get_noise_shape(x, noise_shape)
    # Sample a uniform distribution on [0.0, 1.0) and select values larger
    # than rate.
    #
    # NOTE: Random uniform can actually only generate 2^23 floats on
    # [1.0, 2.0) and subtract 1.0.
    random_tensor = tf.random.stateless_uniform(
        noise_shape, seed=seed, dtype=x.dtype)
    keep_prob = 1 - rate
    scale = 1 / keep_prob
    # NOTE: if (1.0 + rate) - 1 is equal to rate, then we want to consider
    # that float to be selected, hence we use a >= comparison.
    keep_mask = random_tensor >= rate
    ret = x * scale * tf.cast(keep_mask, x.dtype)
    if not tf.executing_eagerly():
      ret.set_shape(x.get_shape())
    return ret
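
A short sketch of the determinism the stateless seed provides (assuming TensorFlow 2.x eager execution; values are illustrative):

x = tf.ones([2, 4])
seed = tf.constant([7, 42], dtype=tf.int32)

# The same seed yields the same keep/drop mask, so the two calls agree
# elementwise -- unlike tf.nn.dropout, which draws from stateful RNG.
a = stateless_dropout(x, rate=0.5, seed=seed)
b = stateless_dropout(x, rate=0.5, seed=seed)
assert bool(tf.reduce_all(tf.equal(a, b)))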
Example #4
  def reconstruction_loss(self, x_input, x_target, x_length, z=None,
                          c_input=None):
    """Reconstruction loss calculation.

    Args:
      x_input: Batch of decoder input sequences for teacher forcing, sized
        `[batch_size, max(x_length), output_depth]`.
      x_target: Batch of expected output sequences to compute loss against,
        sized `[batch_size, max(x_length), output_depth]`.
      x_length: Length of input/output sequences, sized `[batch_size]`.
      z: (Optional) Latent vectors. Required if model is conditional. Sized
        `[n, z_size]`.
      c_input: (Optional) Batch of control sequences, sized
        `[batch_size, max(x_length), control_depth]`. Required if conditioning
        on control sequences.

    Returns:
      r_loss: The reconstruction loss for each sequence in the batch.
      metric_map: Map from metric name to tf.metrics return values for logging.
      decode_results: The LstmDecodeResults.
    """
    batch_size = int(x_input.shape[0])

    has_z = z is not None
    z = tf.zeros([batch_size, 0]) if z is None else z
    repeated_z = tf.tile(
        tf.expand_dims(z, axis=1), [1, tf.shape(x_input)[1], 1])

    has_control = c_input is not None
    if c_input is None:
      c_input = tf.zeros([batch_size, tf.shape(x_input)[1], 0])

    sampling_probability_static = tf.get_static_value(
        self._sampling_probability)
    if sampling_probability_static == 0.0:
      # Use teacher forcing.
      x_input = tf.concat([x_input, repeated_z, c_input], axis=2)
      helper = contrib_seq2seq.TrainingHelper(x_input, x_length)
    else:
      # Use scheduled sampling.
      if has_z or has_control:
        auxiliary_inputs = tf.zeros([batch_size, tf.shape(x_input)[1], 0])
        if has_z:
          auxiliary_inputs = tf.concat([auxiliary_inputs, repeated_z], axis=2)
        if has_control:
          auxiliary_inputs = tf.concat([auxiliary_inputs, c_input], axis=2)
      else:
        auxiliary_inputs = None
      helper = contrib_seq2seq.ScheduledOutputTrainingHelper(
          inputs=x_input,
          sequence_length=x_length,
          auxiliary_inputs=auxiliary_inputs,
          sampling_probability=self._sampling_probability,
          next_inputs_fn=self._sample)

    decode_results = self._decode(
        z, helper=helper, input_shape=helper.inputs.shape[2:])
    flat_x_target = flatten_maybe_padded_sequences(x_target, x_length)
    flat_rnn_output = flatten_maybe_padded_sequences(
        decode_results.rnn_output, x_length)
    r_loss, metric_map = self._flat_reconstruction_loss(
        flat_x_target, flat_rnn_output)

    # Sum loss over sequences.
    cum_x_len = tf.concat([(0,), tf.cumsum(x_length)], axis=0)
    r_losses = []
    for i in range(batch_size):
      b, e = cum_x_len[i], cum_x_len[i + 1]
      r_losses.append(tf.reduce_sum(r_loss[b:e]))
    r_loss = tf.stack(r_losses)

    return r_loss, metric_map, decode_results
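
The summation loop above is effectively a segment sum over the flattened per-timestep losses; a standalone sketch of the same idea (tensor values are illustrative, not from the source):

x_length = tf.constant([2, 3])              # two sequences, lengths 2 and 3
r_loss = tf.constant([1., 2., 3., 4., 5.])  # flat per-timestep losses

# Loop form, as in the method: slice boundaries from cumulative lengths.
cum_x_len = tf.concat([(0,), tf.cumsum(x_length)], axis=0)  # [0, 2, 5]
per_seq = tf.stack([tf.reduce_sum(r_loss[cum_x_len[i]:cum_x_len[i + 1]])
                    for i in range(2)])                     # [3., 12.]

# Equivalent vectorized form via segment ids [0, 0, 1, 1, 1].
seg_ids = tf.repeat(tf.range(2), x_length)
per_seq_v = tf.math.segment_sum(r_loss, seg_ids)            # [3., 12.]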