# Assumed imports for this excerpt; the original module may differ. Helpers
# such as `contrib_seq2seq`, `flatten_maybe_padded_sequences`, and
# `_get_noise_shape` are expected to be defined elsewhere in the module.
import numbers
from typing import Optional, Sequence, Text, Union

from absl import logging
import tensorflow.compat.v1 as tf


def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

  If `pred` is a bool or has a constant value, we return either `true_fn()`
  or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  if not callable(true_fn):
    raise TypeError('`true_fn` must be callable.')
  if not callable(false_fn):
    raise TypeError('`false_fn` must be callable.')

  pred_value = tf.get_static_value(pred)
  if isinstance(pred, tf.Variable) or pred_value is None:
    return tf.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
  if pred_value:
    return true_fn()
  else:
    return false_fn()
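
# A minimal usage sketch for `smart_cond` (hypothetical, not part of the
# original module). With a static Python bool the branch is resolved at trace
# time and only one callable runs; a predicate with no static value falls
# through to `tf.cond`, which traces both branches.
def _smart_cond_example():
  # Static predicate: resolved immediately, `false_fn` is never called.
  static_result = smart_cond(
      True, true_fn=lambda: tf.constant(1), false_fn=lambda: tf.constant(0))
  # Dynamic predicate: `tf.get_static_value` returns None for this tensor,
  # so routing happens at run time via `tf.cond`.
  pred = tf.less(tf.random.uniform([]), 0.5)
  dynamic_result = smart_cond(
      pred, true_fn=lambda: tf.constant(1), false_fn=lambda: tf.constant(0))
  return static_result, dynamic_result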
def get_ckpt_var_map(ckpt_path, ckpt_scope, var_scope, skip_mismatch=None):
  """Get a var map for restoring from pretrained checkpoints.

  Args:
    ckpt_path: string. A pretrained checkpoint path.
    ckpt_scope: string. Scope name for checkpoint variables.
    var_scope: string. Scope name for model variables.
    skip_mismatch: if True, skip variables that are missing from the
      checkpoint or whose shapes mismatch, instead of raising an error.

  Returns:
    var_map: a dictionary from checkpoint name to model variables.
  """
  logging.info('Init model from checkpoint {}'.format(ckpt_path))
  if not ckpt_scope.endswith('/') or not var_scope.endswith('/'):
    raise ValueError('Please specify scope names ending with /')
  if ckpt_scope.startswith('/'):
    ckpt_scope = ckpt_scope[1:]
  if var_scope.startswith('/'):
    var_scope = var_scope[1:]

  var_map = {}
  # Get the list of vars to restore.
  model_vars = tf.get_collection(
      tf.GraphKeys.GLOBAL_VARIABLES, scope=var_scope)
  reader = tf.train.load_checkpoint(ckpt_path)
  ckpt_var_name_to_shape = reader.get_variable_to_shape_map()
  ckpt_var_names = set(ckpt_var_name_to_shape.keys())

  if tf.distribute.get_replica_context():
    replica_id = tf.get_static_value(
        tf.distribute.get_replica_context().replica_id_in_sync_group)
  else:
    replica_id = 0

  for i, v in enumerate(model_vars):
    var_op_name = v.op.name

    if replica_id >= 1:
      # Strip the per-replica suffix so every replica maps to the same
      # checkpoint variable.
      var_op_name = ''.join(var_op_name.rsplit(f'/replica_{replica_id}', 1))

    if not var_op_name.startswith(var_scope):
      logging.info('skip {} -- does not match scope {}'.format(
          var_op_name, var_scope))
      continue

    ckpt_var = ckpt_scope + var_op_name[len(var_scope):]
    if 'global_step' in ckpt_var:
      continue

    if (ckpt_var not in ckpt_var_names and
        var_op_name.endswith('/ExponentialMovingAverage')):
      # Fall back to the non-EMA variable if the EMA name is absent.
      ckpt_var = ckpt_scope + var_op_name[:-len('/ExponentialMovingAverage')]

    if ckpt_var not in ckpt_var_names:
      if 'Momentum' in ckpt_var or 'RMSProp' in ckpt_var:
        # Skip optimizer variables.
        continue
      if skip_mismatch:
        logging.info('skip {} ({}) -- not in ckpt'.format(
            var_op_name, ckpt_var))
        continue
      raise ValueError('{} is not in ckpt {}'.format(v.op, ckpt_path))

    if v.shape != ckpt_var_name_to_shape[ckpt_var]:
      if skip_mismatch:
        logging.info('skip {} ({} vs {}) -- shape mismatch'.format(
            var_op_name, v.shape, ckpt_var_name_to_shape[ckpt_var]))
        continue
      raise ValueError('shape mismatch {} ({} vs {})'.format(
          var_op_name, v.shape, ckpt_var_name_to_shape[ckpt_var]))

    if i < 5:
      # Log the first few elements for sanity check.
      logging.info('Init {} from ckpt var {}'.format(var_op_name, ckpt_var))
    var_map[ckpt_var] = v

  return var_map
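
# Hypothetical usage sketch (not in the original module): build a var map for
# a pretrained backbone and wire it into variable initialization. The scope
# names below are placeholders for illustration only.
def _restore_from_pretrained(ckpt_path):
  var_map = get_ckpt_var_map(
      ckpt_path=ckpt_path,
      ckpt_scope='efficientnet-b0/',  # assumed checkpoint scope
      var_scope='efficientnet-b0/',   # assumed model scope
      skip_mismatch=True)
  # `tf.train.init_from_checkpoint` rewrites variable initializers, so the
  # mapped checkpoint values are loaded when variables are initialized.
  tf.train.init_from_checkpoint(ckpt_path, var_map)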
def stateless_dropout(x: tf.Tensor,
                      rate: float,
                      seed: tf.Tensor,
                      noise_shape: Optional[Union[Sequence[int],
                                                  tf.TensorShape]] = None,
                      name: Optional[Text] = None) -> tf.Tensor:
  """Computes dropout: randomly sets elements to zero to prevent overfitting.

  See https://www.tensorflow.org/api_docs/python/tf/nn/dropout.
  This version differs in that the seed is required if the rate is nonzero.

  Args:
    x: A floating point tensor.
    rate: A scalar `Tensor` with the same type as x. The probability that each
      element is dropped. For example, setting rate=0.1 would drop 10% of
      input elements.
    seed: A shape [2] integer Tensor of seeds to the random number generator.
      Must have dtype `tf.int32` when compiling to XLA.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for
      randomly generated keep/drop flags.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` of the same shape as `x`.

  Raises:
    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
      tensor. `rate=1` is disallowed, because the output would be all zeros,
      which is likely not what was intended.
  """
  with tf.name_scope(name or 'stateless_dropout') as name:
    x = tf.convert_to_tensor(x, name='x')
    if not x.dtype.is_floating:
      raise ValueError('x has to be a floating point tensor since it\'s going '
                       'to be scaled. Got a %s tensor instead.' % x.dtype)
    if isinstance(rate, numbers.Real):
      if not (rate >= 0 and rate < 1):
        raise ValueError('rate must be a scalar tensor or a float in the '
                         'range [0, 1), got %g' % rate)
      if rate > 0.5:
        logging.log_first_n(
            logging.WARN, 'Large dropout rate: %g (>0.5). In TensorFlow '
            '2.x, dropout() uses dropout rate instead of keep_prob. '
            'Please ensure that this is intended.', 5, rate)

    # Early return if nothing needs to be dropped.
    if tf.get_static_value(rate) == 0:
      return x

    rate = tf.convert_to_tensor(rate, dtype=x.dtype, name='rate')
    rate.shape.assert_has_rank(0)
    noise_shape = _get_noise_shape(x, noise_shape)
    # Sample a uniform distribution on [0.0, 1.0) and select values larger
    # than `rate`.
    #
    # NOTE: Random uniform can actually only generate 2^23 floats on [1.0,
    # 2.0) and subtract 1.0.
    random_tensor = tf.random.stateless_uniform(
        noise_shape, seed=seed, dtype=x.dtype)
    keep_prob = 1 - rate
    scale = 1 / keep_prob
    # NOTE: if (1.0 + rate) - 1 is equal to rate, then we want to consider
    # that float to be selected, hence we use a >= comparison.
    keep_mask = random_tensor >= rate
    ret = x * scale * tf.cast(keep_mask, x.dtype)
    if not tf.executing_eagerly():
      ret.set_shape(x.get_shape())
    return ret
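
# Hypothetical sketch (not in the original module): unlike stateful dropout,
# the mask here is fully determined by `seed`, so repeated calls with the same
# seed produce identical outputs. This is what makes the op reproducible and
# XLA-friendly.
def _stateless_dropout_example():
  x = tf.ones([4, 8])
  seed = tf.constant([1, 2], dtype=tf.int32)
  a = stateless_dropout(x, rate=0.5, seed=seed)
  b = stateless_dropout(x, rate=0.5, seed=seed)
  # `a` and `b` evaluate to the same values: the seed determines the mask.
  return a, b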
def reconstruction_loss(self, x_input, x_target, x_length, z=None,
                        c_input=None):
  """Reconstruction loss calculation.

  Args:
    x_input: Batch of decoder input sequences for teacher forcing, sized
      `[batch_size, max(x_length), output_depth]`.
    x_target: Batch of expected output sequences to compute loss against,
      sized `[batch_size, max(x_length), output_depth]`.
    x_length: Length of input/output sequences, sized `[batch_size]`.
    z: (Optional) Latent vectors. Required if model is conditional. Sized
      `[n, z_size]`.
    c_input: (Optional) Batch of control sequences, sized
      `[batch_size, max(x_length), control_depth]`. Required if conditioning
      on control sequences.

  Returns:
    r_loss: The reconstruction loss for each sequence in the batch.
    metric_map: Map from metric name to tf.metrics return values for logging.
    decode_results: The LstmDecodeResults.
  """
  batch_size = int(x_input.shape[0])

  has_z = z is not None
  z = tf.zeros([batch_size, 0]) if z is None else z
  repeated_z = tf.tile(
      tf.expand_dims(z, axis=1), [1, tf.shape(x_input)[1], 1])

  has_control = c_input is not None
  if c_input is None:
    c_input = tf.zeros([batch_size, tf.shape(x_input)[1], 0])

  sampling_probability_static = tf.get_static_value(
      self._sampling_probability)
  if sampling_probability_static == 0.0:
    # Use teacher forcing.
    x_input = tf.concat([x_input, repeated_z, c_input], axis=2)
    helper = contrib_seq2seq.TrainingHelper(x_input, x_length)
  else:
    # Use scheduled sampling.
    if has_z or has_control:
      auxiliary_inputs = tf.zeros([batch_size, tf.shape(x_input)[1], 0])
      if has_z:
        auxiliary_inputs = tf.concat([auxiliary_inputs, repeated_z], axis=2)
      if has_control:
        auxiliary_inputs = tf.concat([auxiliary_inputs, c_input], axis=2)
    else:
      auxiliary_inputs = None
    helper = contrib_seq2seq.ScheduledOutputTrainingHelper(
        inputs=x_input,
        sequence_length=x_length,
        auxiliary_inputs=auxiliary_inputs,
        sampling_probability=self._sampling_probability,
        next_inputs_fn=self._sample)

  decode_results = self._decode(
      z, helper=helper, input_shape=helper.inputs.shape[2:])
  flat_x_target = flatten_maybe_padded_sequences(x_target, x_length)
  flat_rnn_output = flatten_maybe_padded_sequences(
      decode_results.rnn_output, x_length)
  r_loss, metric_map = self._flat_reconstruction_loss(
      flat_x_target, flat_rnn_output)

  # Sum loss over sequences.
  cum_x_len = tf.concat([(0,), tf.cumsum(x_length)], axis=0)
  r_losses = []
  for i in range(batch_size):
    b, e = cum_x_len[i], cum_x_len[i + 1]
    r_losses.append(tf.reduce_sum(r_loss[b:e]))
  r_loss = tf.stack(r_losses)

  return r_loss, metric_map, decode_results
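
# Hypothetical sketch (not in the original module) of the final reduction in
# `reconstruction_loss` above: `flatten_maybe_padded_sequences` concatenates
# only the valid steps of each sequence, so cumulative lengths delimit each
# sequence's slice of the flat per-step losses.
def _sum_loss_per_sequence(flat_loss, x_length):
  # Assumes a statically known batch size, as in `reconstruction_loss`.
  cum_x_len = tf.concat([(0,), tf.cumsum(x_length)], axis=0)
  per_sequence = []
  for i in range(int(x_length.shape[0])):
    # Slice out sequence i's steps and sum its per-step losses.
    per_sequence.append(
        tf.reduce_sum(flat_loss[cum_x_len[i]:cum_x_len[i + 1]]))
  return tf.stack(per_sequence)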