def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Does top-k sampling. If ignore_ids is on, then we will zero out those logits.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                       like padding maybe
    :param k: top-k threshold to use, either an int or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_k_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)
        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')

        # Find the top kth index to cut off. Careful: we don't want to cut off everything!
        # Result will be [batch_size, vocab_perm].
        k_expanded = k if isinstance(k, int) else k[:, None]
        exclude_mask = tf.range(vocab_size)[None] >= k_expanded

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p sampling. if ignore_ids is on, then we will zero out those logits.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                       like padding maybe
    :param p: topp threshold to use, either a float or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Don't do top-p sampling in this case
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs': probs,
                'sample': tf.random.categorical(
                    logits=logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                    num_samples=num_samples, dtype=tf.int32),
            }

        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')
        cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)

        # find the top pth index to cut off. careful we don't want to cutoff everything!
        # result will be [batch_size, vocab_perm]
        p_expanded = p if isinstance(p, float) else p[:, None]
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

        # OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample
        # unperm_indices = tf.argsort(indices, direction='ASCENDING')
        # include_mask_unperm = tf.batch_gather(include_mask, unperm_indices)
        # logits_to_use = logits - (1 - tf.cast(include_mask_unperm, tf.float32)) * 1e10
        # sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)

    return {
        'probs': probs,
        'sample': sample,
    }
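# A minimal standalone sketch (my own toy example, not from the original
# source) of the nucleus masking trick used in _top_p_sample above: sort the
# probabilities, take their cumulative sum, and mask every token past the
# top-p threshold while always keeping the single most likely token. Note
# that tf.batch_gather is the TF1 spelling of tf.gather(..., batch_dims=1).
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

logits = tf.constant([[2.0, 1.0, 0.5, -1.0]])  # [batch=1, vocab=4]
probs = tf.nn.softmax(logits, axis=-1)
indices = tf.argsort(probs, direction='DESCENDING')
cumulative = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1)
p = 0.9
# Exclude tokens once the cumulative mass passes p, but never the top-1 token.
exclude = tf.logical_not(tf.logical_or(cumulative < p, tf.range(4)[None] < 1))
masked = tf.batch_gather(logits, indices) - tf.cast(exclude, tf.float32) * 1e10
sample_perm = tf.random.categorical(logits=masked, num_samples=1)
sample = tf.batch_gather(indices, sample_perm)  # unsort back to vocab ids

with tf.Session() as sess:
    print(sess.run([cumulative, sample]))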
def test_posterior_mode_invariance_states(self):
  observation_probs_data = tf.constant([[0.12, 0.48, 0.5, 0.1],
                                        [0.4, 0.1, 0.5, 0.0],
                                        [0.1, 0.2, 0.3, 0.4]],
                                       dtype=self.dtype)
  transition_matrix_data = tf.constant([[0.21, 0.49, 0.3],
                                        [0.18, 0.12, 0.7],
                                        [0.75, 0.15, 0.1]],
                                       dtype=self.dtype)
  initial_prob_data = tf.constant([0.8, 0.13, 0.07], dtype=self.dtype)

  (initial_prob, transition_matrix,
   observation_probs) = self.make_placeholders([
       initial_prob_data, transition_matrix_data, observation_probs_data])

  permutations = tf.identity(np.array([np.random.permutation(3)
                                       for _ in range(8)]))
  inverse_permutations = tf.argsort(permutations)

  initial_prob_permuted = tf.gather(initial_prob, inverse_permutations)

  # Permute rows of observation matrix
  observation_probs_permuted = tf.gather(observation_probs,
                                         inverse_permutations)

  # Permute both rows and columns of transition matrix
  transition_matrix_permuted = tf.transpose(
      a=tf.gather(tf.transpose(a=transition_matrix), inverse_permutations),
      perm=[0, 2, 1])
  transition_matrix_permuted = tf1.batch_gather(transition_matrix_permuted,
                                                inverse_permutations)

  observations = tf.constant([1, 0, 3, 1, 3, 0, 2, 1, 2, 1, 3, 0, 0, 1, 1, 2])

  [num_steps] = self.make_placeholders([16])
  model = tfd.HiddenMarkovModel(
      tfd.Categorical(probs=initial_prob_permuted),
      tfd.Categorical(probs=transition_matrix_permuted),
      tfd.Categorical(probs=observation_probs_permuted),
      num_steps=num_steps)

  inferred_states = model.posterior_mode(observations)

  expected_states = [0, 1, 2, 0, 2, 1, 2, 0, 2, 0, 2, 0, 1, 2, 0, 1]

  expected_states_permuted = tf.transpose(
      a=tf1.batch_gather(
          tf.expand_dims(tf.transpose(a=permutations), axis=-1),
          expected_states)[..., 0])

  self.assertAllEqual(inferred_states, expected_states_permuted)
def sample(self, num_samples=1):
  """Sample from the rejection sampling distribution.

  For ease of implementation, draw the maximum number of proposal samples.

  Args:
    num_samples: integer, number of samples to draw.

  Returns:
    samples: Tensor of samples from the distribution, [num_samples] + data_dim
  """
  flat_proposal_samples = self.proposal.sample(num_samples * self.T)
  proposal_samples = tf.reshape(flat_proposal_samples,
                                [num_samples, self.T] + self.data_dim)
  flat_logit_accept = self.logit_accept_fn(flat_proposal_samples)
  logit_accept = tf.reshape(flat_logit_accept, [num_samples, self.T])
  accept_samples = tfd.Bernoulli(logits=logit_accept[:, :-1]).sample()

  # Add forced accept to last sample to ensure truncation
  accept_samples = tf.concat([
      accept_samples,
      tf.ones([num_samples, 1], dtype=accept_samples.dtype)
  ], axis=-1)

  # For each of sample_shape, find the first nonzero accept
  def get_first_nonzero_index(t):
    # t is batch_dims + [T], t is binary.
    _, indices = tf.math.top_k(t, k=1, sorted=False)
    return indices

  accept_indices = get_first_nonzero_index(accept_samples)  # sample_shape
  samples = tf.batch_gather(proposal_samples, accept_indices)
  return tf.squeeze(samples, axis=1)  # Squeeze the selected dim
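# Side note (my own check, not from the source): get_first_nonzero_index
# works because tf.math.top_k breaks ties in favour of the lower index, so on
# a 0/1 tensor, top_k with k=1 returns the position of the first 1. A tiny
# self-contained demonstration, assuming TF1-style sessions:
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

accepts = tf.constant([[0, 0, 1, 0, 1],
                       [1, 0, 0, 0, 0]])
_, first_one = tf.math.top_k(accepts, k=1)
with tf.Session() as sess:
    print(sess.run(first_one))  # [[2], [0]]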
def gaussian_mixture_approximate_mode(gm):
  """Returns the mean of the most probable mixture component."""
  # Find the most likely mixture component.
  mode_alpha = gm.mixture_distribution.mode()[Ellipsis, None]
  mus = gm.components_distribution.mean()
  # Gather the mean of the most likely component.
  return tf.squeeze(tf.batch_gather(mus, mode_alpha), axis=-2)
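# A usage sketch (toy numbers, my own; assumes a TF1-compatible
# tensorflow_probability build): take the mean of the most probable component
# of a small batched Gaussian mixture.
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tf.disable_v2_behavior()

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[[0.2, 0.8]]),  # batch of 1
    components_distribution=tfd.MultivariateNormalDiag(
        loc=[[[0., 0.], [5., 5.]]],          # [batch=1, components=2, dims=2]
        scale_diag=[[[1., 1.], [1., 1.]]]))

approx_mode = gaussian_mixture_approximate_mode(gm)
with tf.Session() as sess:
    print(sess.run(approx_mode))  # [[5. 5.]], the mean of the 0.8 component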
def initialize_from_context(initial_context, ignore_ids, news_config, p_for_topp=0.95, k_for_topk=100,
                            do_topk=False):
    """ same signature as sample_step"""
    batch_size, _ = get_shape_list(initial_context, expected_rank=2)

    context_output = sample_step(tokens=initial_context, ignore_ids=ignore_ids, news_config=news_config,
                                 batch_size=batch_size, p_for_topp=p_for_topp, k_for_topk=k_for_topk,
                                 cache=None, do_topk=do_topk)

    model = context_output['model']
    gt_logprobs = tf.squeeze(tf.batch_gather(model.log_probs[:, :-1], model.input_ids[:, 1:, None]), axis=2)

    return {
        'tokens': tf.concat([initial_context, context_output['new_tokens'][:, None]], 1),
        'cache': context_output['new_cache'],
        'probs': tf.concat([tf.exp(gt_logprobs), context_output['new_probs'][:, None]], axis=1)
    }
def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Helper function that samples from grover for a single step
    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p or top-k threshold
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :param do_topk: if True, use top-k sampling (with p_for_topp cast to an int as k)
                    instead of top-p sampling
    :return: new_tokens, size [batch_size]
             new_probs, also size [batch_size]
             new_cache, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
                              news_config.num_attention_heads,
                              news_config.hidden_size // news_config.num_attention_heads]
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )

    # Extract the logits for the final position in the sequence.
    batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]

    if do_topk:
        sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)

    new_tokens = tf.squeeze(sample_info['sample'], 1)
    new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)

    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    }
def fast_tpu_gather(params, indices, name=None):
  """Fast gather implementation for models running on TPU.

  This function uses one_hot and batch matmul to do gather, which is faster
  than gather_nd on TPU. For params that have dtype of int32 (sequences to
  gather from), batch_gather is used to keep accuracy.

  Args:
    params: A tensor from which to gather values.
      [batch_size, original_size, ...]
    indices: A tensor used as the index to gather values.
      [batch_size, selected_size].
    name: A string, name of the operation (optional).

  Returns:
    gather_result: A tensor that has the same rank as params.
      [batch_size, selected_size, ...]
  """
  with tf.name_scope(name):
    dtype = params.dtype

    def _gather(params, indices):
      """Fast gather using one_hot and batch matmul."""
      if dtype != tf.float32:
        params = tf.to_float(params)
      shape = common_layers.shape_list(params)
      indices_shape = common_layers.shape_list(indices)
      ndims = params.shape.ndims
      # Adjust the shape of params to match one-hot indices, which is the
      # requirement of Batch MatMul.
      if ndims == 2:
        params = tf.expand_dims(params, axis=-1)
      if ndims > 3:
        params = tf.reshape(params, [shape[0], shape[1], -1])
      gather_result = tf.matmul(
          tf.one_hot(indices, shape[1], dtype=params.dtype), params)
      if ndims == 2:
        gather_result = tf.squeeze(gather_result, axis=-1)
      if ndims > 3:
        shape[1] = indices_shape[1]
        gather_result = tf.reshape(gather_result, shape)
      if dtype != tf.float32:
        gather_result = tf.cast(gather_result, dtype)
      return gather_result

    # If the dtype is int, use gather instead of one_hot matmul to avoid
    # precision loss. The max int value that can be represented by bfloat16
    # in the MXU is 256, which is smaller than the possible id values.
    # Encoding/decoding could potentially be used to make it work, but the
    # benefit is small right now.
    if dtype.is_integer:
      gather_result = tf.batch_gather(params, indices)
    else:
      gather_result = _gather(params, indices)

    return gather_result
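# A small equivalence check (my own, not part of the library): the one-hot
# matmul trick at the heart of _gather above gives the same result as
# tf.batch_gather for float params.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

params = tf.constant([[[1., 1.], [2., 2.], [3., 3.]]])  # [B=1, 3, 2]
indices = tf.constant([[2, 0]])                         # [B=1, 2]
via_matmul = tf.matmul(tf.one_hot(indices, 3, dtype=params.dtype), params)
via_gather = tf.batch_gather(params, indices)
with tf.Session() as sess:
    a, b = sess.run([via_matmul, via_gather])
    print((a == b).all())  # True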
def gather_neighbour(pc, neighbor_idx):
    # gather the coordinates or features of neighboring points
    batch_size = tf.shape(pc)[0]
    num_points = tf.shape(pc)[1]
    d = pc.get_shape()[2].value
    index_input = tf.reshape(neighbor_idx, shape=[batch_size, -1])
    features = tf.batch_gather(pc, index_input)
    features = tf.reshape(features,
                          [batch_size, num_points, tf.shape(neighbor_idx)[-1], d])
    return features
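# A toy invocation (made-up values) showing the shapes gather_neighbour
# produces: for each of the N points, look up the features of its K
# neighbours by index.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

pc = tf.constant([[[0., 0., 0.],
                   [1., 1., 1.],
                   [2., 2., 2.]]])                      # [B=1, N=3, d=3]
neighbor_idx = tf.constant([[[0, 1], [1, 2], [2, 0]]])  # [B=1, N=3, K=2]
features = gather_neighbour(pc, neighbor_idx)           # -> [1, 3, 2, 3]
with tf.Session() as sess:
    print(sess.run(features).shape)  # (1, 3, 2, 3)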
def _compute_calibration_bin_statistics(num_bins, logits=None,
                                        labels_true=None,
                                        labels_predicted=None):
  """Compute binning statistics required for calibration measures.

  Args:
    num_bins: int, number of probability bins, e.g. 10.
    logits: Tensor, (n, nlabels), with logits for n instances and nlabels.
    labels_true: Tensor, (n,), with tf.int32 or tf.int64 elements containing
      ground truth class labels in the range [0, nlabels).
    labels_predicted: Tensor, (n,), with tf.int32 or tf.int64 elements
      containing decisions of the predictive system. If `None`, we will use
      the argmax decision using the `logits`.

  Returns:
    bz: Tensor, shape (2, num_bins), tf.int32, counts of incorrect (row 0)
      and correct (row 1) predictions in each of the `num_bins` probability
      bins.
    pmean_observed: Tensor, shape (num_bins,), tf.float32, the mean
      predictive probabilities in each probability bin.
  """
  if labels_predicted is None:
    # If no labels are provided, we take the label with the maximum probability
    # decision. This corresponds to the optimal expected minimum loss decision
    # under 0/1 loss.
    pred_y = tf.argmax(logits, axis=1, output_type=labels_true.dtype)
  else:
    pred_y = labels_predicted

  correct = tf.cast(tf.equal(pred_y, labels_true), tf.int32)

  # Collect predicted probabilities of decisions
  pred = tf.nn.softmax(logits, axis=1)
  prob_y = tf1.batch_gather(pred, pred_y[:, tf.newaxis])  # p(pred_y | x)
  prob_y = tf.reshape(prob_y, (tf.size(prob_y),))

  # Compute b/z histogram statistics:
  # bz[0, bin] contains counts of incorrect predictions in the probability bin.
  # bz[1, bin] contains counts of correct predictions in the probability bin.
  bins = tf.histogram_fixed_width_bins(prob_y, [0.0, 1.0], nbins=num_bins)
  event_bin_counts = tf.math.bincount(correct * num_bins + bins,
                                      minlength=2 * num_bins,
                                      maxlength=2 * num_bins)
  event_bin_counts = tf.reshape(event_bin_counts, (2, num_bins))

  # Compute mean predicted probability value in each of the `num_bins` bins
  pmean_observed = tf.math.unsorted_segment_sum(prob_y, bins, num_bins)
  tiny = np.finfo(dtype_util.as_numpy_dtype(logits.dtype)).tiny
  pmean_observed = pmean_observed / (
      tf.cast(tf.reduce_sum(event_bin_counts, axis=0), logits.dtype) + tiny)

  return event_bin_counts, pmean_observed
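# Follow-up sketch (my own, not from the source library): turning the bin
# statistics above into the expected calibration error (ECE), i.e. the
# bin-mass weighted gap between per-bin accuracy and mean predicted
# probability.
import tensorflow.compat.v1 as tf

def expected_calibration_error(event_bin_counts, pmean_observed):
    event_bin_counts = tf.cast(event_bin_counts, tf.float32)
    per_bin_total = tf.reduce_sum(event_bin_counts, axis=0)  # [num_bins]
    per_bin_accuracy = event_bin_counts[1] / (per_bin_total + 1e-8)
    bin_weights = per_bin_total / tf.reduce_sum(per_bin_total)
    return tf.reduce_sum(bin_weights * tf.abs(per_bin_accuracy - pmean_observed))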
def sample(self, num_samples=1):
  """Sample from the model."""
  flat_proposal_samples = self.proposal.sample(num_samples * self.K)
  proposal_samples = tf.reshape(flat_proposal_samples,
                                [num_samples, self.K] + self.data_dim)
  log_energy = tf.reshape(
      tf.squeeze(self.energy_fn(flat_proposal_samples), axis=-1),
      [num_samples, self.K])
  indexes = tfd.Categorical(logits=log_energy).sample()  # [num_samples]
  samples = tf.batch_gather(proposal_samples,
                            tf.expand_dims(indexes, axis=-1))
  return tf.squeeze(samples, axis=1)  # Squeeze the selected dim
def nearest_interpolation(feature, interp_idx):
    """
    :param feature: [B, N, 1, d] input features matrix (the singleton axis is
                    squeezed out below)
    :param interp_idx: [B, up_num_points, 1] nearest neighbour index
    :return: [B, up_num_points, 1, d] interpolated features matrix
    """
    feature = tf.squeeze(feature, axis=2)
    batch_size = tf.shape(interp_idx)[0]
    up_num_points = tf.shape(interp_idx)[1]
    interp_idx = tf.reshape(interp_idx, [batch_size, up_num_points])
    interpolated_features = tf.batch_gather(feature, interp_idx)
    interpolated_features = tf.expand_dims(interpolated_features, axis=2)
    return interpolated_features
def _build_verified_loss(self, labels):
  """Build verified loss using an upper bound on specification."""
  if not self._specification:
    self._verified_loss = tf.constant(0.)
    self._interval_bounds_accuracy = tf.constant(0.)
    return
  # Interval bounds.
  bounds = self._get_specification_bounds()
  # Select specifications.
  if self._interval_bounds_loss_mode == 'all':
    pass  # Keep bounds the way it is.
  elif self._interval_bounds_loss_mode == 'most':
    bounds = tf.reduce_max(bounds, axis=1, keepdims=True)
  elif self._interval_bounds_loss_mode == 'random':
    idx = tf.random.uniform(
        [tf.shape(bounds)[0], self._interval_bounds_loss_n],
        0, tf.shape(bounds)[1], dtype=tf.int32)
    bounds = tf.batch_gather(bounds, idx)
  else:
    assert self._interval_bounds_loss_mode == 'least'
    # This picks the least violated constraint.
    mask = tf.cast(bounds < 0., tf.float32)
    smallest_violation = tf.reduce_min(
        bounds + mask * _BIG_NUMBER, axis=1, keepdims=True)
    has_violations = tf.less(
        tf.reduce_sum(mask, axis=1, keepdims=True) + .5,
        tf.cast(tf.shape(bounds)[1], tf.float32))
    largest_bounds = tf.reduce_max(bounds, axis=1, keepdims=True)
    bounds = tf.where(has_violations, smallest_violation, largest_bounds)

  if self._interval_bounds_loss_type == 'xent':
    v = tf.concat(
        [bounds, tf.zeros([tf.shape(bounds)[0], 1], dtype=bounds.dtype)],
        axis=1)
    l = tf.concat(
        [tf.zeros_like(bounds),
         tf.ones([tf.shape(bounds)[0], 1], dtype=bounds.dtype)],
        axis=1)
    self._verified_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.stop_gradient(l), logits=v))
  elif self._interval_bounds_loss_type == 'softplus':
    self._verified_loss = tf.reduce_mean(
        tf.nn.softplus(bounds + self._interval_bounds_hinge_margin))
  else:
    assert self._interval_bounds_loss_type == 'hinge'
    self._verified_loss = tf.reduce_mean(
        tf.maximum(bounds, -self._interval_bounds_hinge_margin))
def compute_joint_mlp_logits(sequence, max_span_length):
  """Computes joint span (start, end) logits from sequence input."""
  batch_size, seq_length, hidden_size = modeling.get_shape_list(
      sequence, expected_rank=3)
  projection_size = hidden_size  # This seems to be a reasonable setting.

  with tf.variable_scope("joint_span"):
    projection = tf.layers.dense(
        sequence,
        projection_size * 2,
        activation=None,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
        name="projection")
    start_projection, end_projection = tf.split(projection, 2, axis=-1)

    # 1. The start representations are tiled max_answer_length times.
    # TODO(danielandor): Use the mask to compute an optimal span list.
    starts = tf.reshape(start_projection,
                        [batch_size * seq_length, 1, projection_size])
    starts = tf.tile(starts, [1, max_span_length, 1])
    starts = tf.reshape(
        starts, [batch_size, seq_length * max_span_length, projection_size])

    # 2. To make the end representations, we compute band diagonal indices
    #    and perform a batched gather.
    seqs = tf.expand_dims(tf.range(seq_length), 1)
    offsets = tf.expand_dims(tf.range(max_span_length), 0)
    indices = seqs + offsets  # uses broadcasting
    indices.shape.assert_is_compatible_with((seq_length, max_span_length))
    indices = tf.reshape(indices, [1, seq_length * max_span_length])
    indices = tf.tile(indices, [batch_size, 1])
    indices = tf.minimum(indices, seq_length - 1)  # clips indices
    ends = tf.batch_gather(end_projection, indices)

    # 3. The final step adds the starts and ends.
    ends.shape.assert_is_compatible_with(starts.shape)
    inputs = starts + ends
    inputs = modeling.gelu(inputs)  # Bias is already in the projection.
    inputs = contrib_layers.layer_norm(inputs)
    start_logits = tf.layers.dense(
        inputs,
        1,
        activation=None,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
        name="logits")
  return tf.reshape(start_logits, [batch_size, seq_length, max_span_length])
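# A quick shape sketch (standalone, toy sizes) of the band-diagonal index
# construction in step 2 above: broadcasting a column of start positions
# against a row of span offsets enumerates every candidate (start, start +
# offset) span, clipped at the sequence boundary.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

seq_length, max_span_length = 4, 3
seqs = tf.expand_dims(tf.range(seq_length), 1)          # [4, 1]
offsets = tf.expand_dims(tf.range(max_span_length), 0)  # [1, 3]
indices = tf.minimum(seqs + offsets, seq_length - 1)    # [4, 3], clipped
with tf.Session() as sess:
    print(sess.run(indices))
    # [[0 1 2]
    #  [1 2 3]
    #  [2 3 3]
    #  [3 3 3]]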
def random_sample(feature, pool_idx):
    """
    :param feature: [B, N, 1, d] input features matrix (the singleton axis is
                    squeezed out below)
    :param pool_idx: [B, N', max_num] N' < N, N' is the selected position after pooling
    :return: pool_features = [B, N', 1, d] pooled features matrix
    """
    feature = tf.squeeze(feature, axis=2)
    num_neigh = tf.shape(pool_idx)[-1]
    d = feature.get_shape()[-1]
    batch_size = tf.shape(pool_idx)[0]
    pool_idx = tf.reshape(pool_idx, [batch_size, -1])
    pool_features = tf.batch_gather(feature, pool_idx)
    pool_features = tf.reshape(pool_features, [batch_size, -1, num_neigh, d])
    pool_features = tf.reduce_max(pool_features, axis=2, keepdims=True)
    return pool_features
def positional_sampling(self, layer_input, feature_dimension, name='positional_sampling'):
    featuremap = layer_input[0]
    batch_indices = layer_input[1]
    grid = layer_input[2]

    shape_grid = tf.shape(grid)
    featuremap_flat = tf.reshape(featuremap, [shape_grid[0], -1, feature_dimension])
    batch_indices_flat = tf.reshape(batch_indices, [shape_grid[0], -1])
    batch_ps_flat = tf.batch_gather(featuremap_flat, batch_indices_flat)

    b, h, w, c = shape_grid[0], shape_grid[1], shape_grid[2], feature_dimension
    return tf.reshape(batch_ps_flat, [b, h, w, c])
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = GroverModel(
        config=config,
        is_training=is_training,
        input_ids=input_ids,
        pad_token_id=config.pad_token_id,
        chop_off_last_token=True,
    )

    total_loss = model.lm_loss()

    if is_training:
        train_op, train_metrics = optimization_adafactor.create_optimizer(
            total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
        tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    else:
        train_op = None
        train_metrics = {}
        tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
        (assignment_map, initialized_variable_names) = get_assignment_map_from_checkpoint(
            tvars, init_checkpoint)
        if use_tpu:
            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape, init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if use_tpu:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                host_call=construct_scalar_host_call(metric_dict=train_metrics,
                                                     model_dir=params['model_dir'],
                                                     prefix='training/'),
                scaffold_fn=scaffold_fn)
        else:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[
                    tf.train.LoggingTensorHook(
                        {'loss': tf.metrics.mean(total_loss)[1]}, every_n_iter=100)
                ],
                scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(total_loss):
            loss = tf.metrics.mean(values=total_loss)
            return {
                "eval_loss": loss,
            }

        eval_metrics = (metric_fn, [total_loss])
        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
    else:
        gt_logprobs = tf.squeeze(
            tf.batch_gather(model.log_probs, model.target_ids[:, :, None]), axis=2)

        # Compute the top-p threshold that nucleus sampling would need in order
        # to include the ground-truth token: the total probability mass of all
        # tokens the model ranks above it.
        better_than_gt = model.log_probs > gt_logprobs[:, :, None]
        top_p_required = tf.reduce_sum(
            tf.cast(better_than_gt, tf.float32) * tf.exp(model.log_probs), axis=2)

        # No top-p sampling for now, since this seems to be too slow on TPUs
        if use_tpu:
            predictions = tf.reshape(
                tf.random.categorical(logits=model.logits_flat, num_samples=1),
                get_shape_list(model.target_ids),
            )
        else:
            # Argmax
            # predictions = tf.math.argmax(model.log_probs, axis=-1, output_type=tf.int32)
            predictions = tf.reshape(
                _top_p_sample(model.logits_flat, num_samples=1, p=0.99)['sample'],
                get_shape_list(model.target_ids),
            )
        pred_logprobs = tf.squeeze(
            tf.batch_gather(model.log_probs, predictions[:, :, None]), axis=2)

        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions={
                'gt_logprobs': gt_logprobs,
                'top_p_required': top_p_required,
                'predictions': predictions,
                'pred_logprobs': pred_logprobs,
                'labels': input_ids,
            },
            scaffold_fn=scaffold_fn)
    return output_spec
def define_ppo_step(data_points, hparams, action_space, lr,
                    epoch=-1, distributional_size=1,
                    distributional_subscale=0.04):
  """Define ppo step."""
  del distributional_subscale
  (observation, action, discounted_reward, discounted_reward_probs,
   norm_advantage, old_pdf) = data_points

  obs_shape = common_layers.shape_list(observation)
  observation = tf.reshape(
      observation, [obs_shape[0] * obs_shape[1]] + obs_shape[2:]
  )
  (logits, new_value) = get_policy(observation, hparams, action_space,
                                   epoch=epoch,
                                   distributional_size=distributional_size)
  logits = tf.reshape(logits, obs_shape[:2] + [action_space.n])
  new_policy_dist = tfp.distributions.Categorical(logits=logits)

  new_pdf = new_policy_dist.prob(action)

  ratio = new_pdf / old_pdf
  clipped_ratio = tf.clip_by_value(ratio, 1 - hparams.clipping_coef,
                                   1 + hparams.clipping_coef)

  surrogate_objective = tf.minimum(clipped_ratio * norm_advantage,
                                   ratio * norm_advantage)
  policy_loss = -tf.reduce_mean(surrogate_objective)

  if distributional_size > 1:
    new_value = tf.reshape(new_value, obs_shape[:2] + [distributional_size])
    new_value = tf.nn.log_softmax(new_value, axis=-1)
    value_shape = common_layers.shape_list(new_value)
    # The above is the new value distribution. We are also given as discounted
    # reward the value distribution and the corresponding probabilities.
    # The given discounted reward is already rounded to integers but in range
    # increased by 2x for greater fidelity. Increase range of new_values here.
    new_value_shifted = tf.concat([new_value[1:], new_value[-1:]], axis=0)
    new_value_mean = (new_value + new_value_shifted) / 2
    new_value = tf.concat([tf.expand_dims(new_value, axis=-1),
                           tf.expand_dims(new_value_mean, axis=-1)], -1)
    new_value = tf.reshape(new_value,
                           value_shape[:-1] + [2 * value_shape[-1]])
    # Cast discounted reward to integers and gather the new log-probs for them.
    discounted_reward = tf.cast(discounted_reward, tf.int32)
    value_loss = tf.batch_gather(new_value, discounted_reward)
    # Weight the gathered (new) log-probs by the old probabilities.
    discounted_reward_probs = tf.expand_dims(discounted_reward_probs, axis=1)
    value_loss = -tf.reduce_sum(value_loss * discounted_reward_probs, axis=-1)
    # Take the mean over batch and time as final loss, multiply by coefficient.
    value_loss = hparams.value_loss_coef * tf.reduce_mean(value_loss)
  else:
    new_value = tf.reshape(new_value, obs_shape[:2])
    value_error = new_value - discounted_reward
    value_loss = hparams.value_loss_coef * tf.reduce_mean(value_error ** 2)

  entropy = new_policy_dist.entropy()
  entropy_loss = -hparams.entropy_loss_coef * tf.reduce_mean(entropy)

  losses = [policy_loss, value_loss, entropy_loss]
  loss = sum(losses)
  variables = tf.global_variables(hparams.policy_network + "/.*")
  train_op = optimize.optimize(loss, lr, hparams, variables=variables)

  with tf.control_dependencies([train_op]):
    return [tf.identity(x) for x in losses]
def model_fn(features, labels, mode, params):
  """Model function."""
  del labels

  # ==============================
  # Input features
  # ==============================
  # [batch_size, query_seq_len]
  query_inputs = features["query_inputs"]

  # [batch_size, num_candidates, candidate_seq_len]
  candidate_inputs = features["candidate_inputs"]

  # [batch_size, num_candidates, query_seq_len + candidate_seq_len]
  joint_inputs = features["joint_inputs"]

  # [batch_size, num_masks]
  mlm_targets = features["mlm_targets"]
  mlm_positions = features["mlm_positions"]
  mlm_mask = features["mlm_mask"]

  # ==============================
  # Create modules.
  # ==============================
  bert_module = hub.Module(
      spec=params["bert_hub_module_handle"],
      name="locbert",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(bert_module, "locbert")

  embedder_module = hub.Module(
      spec=params["embedder_hub_module_handle"],
      name="embedder",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(embedder_module, "embedder")

  if params["share_embedders"]:
    query_embedder_module = embedder_module
  else:
    query_embedder_module = hub.Module(
        spec=params["embedder_hub_module_handle"],
        name="embedder",
        tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
        trainable=True)
    hub.register_module_for_export(query_embedder_module, "query_embedder")

  # ==============================
  # Retrieve.
  # ==============================
  # [batch_size, projected_size]
  query_emb = query_embedder_module(
      inputs=dict(
          input_ids=query_inputs.token_ids,
          input_mask=query_inputs.mask,
          segment_ids=query_inputs.segment_ids),
      signature="projected")

  # [batch_size * num_candidates, candidate_seq_len]
  flat_candidate_inputs, unflatten = flatten_bert_inputs(candidate_inputs)

  # [batch_size * num_candidates, projected_size]
  flat_candidate_emb = embedder_module(
      inputs=dict(
          input_ids=flat_candidate_inputs.token_ids,
          input_mask=flat_candidate_inputs.mask,
          segment_ids=flat_candidate_inputs.segment_ids),
      signature="projected")

  # [batch_size, num_candidates, projected_size]
  unflattened_candidate_emb = unflatten(flat_candidate_emb)

  # [batch_size, num_candidates]
  retrieval_score = tf.einsum("BD,BND->BN", query_emb,
                              unflattened_candidate_emb)

  # ==============================
  # Read.
  # ==============================
  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_inputs, unflatten = flatten_bert_inputs(joint_inputs)

  # [batch_size * num_candidates, num_masks]
  flat_mlm_positions, _ = tensor_utils.flatten(
      tf.tile(
          tf.expand_dims(mlm_positions, 1), [1, params["num_candidates"], 1]))

  batch_size, num_masks = tensor_utils.shape(mlm_targets)

  # [batch_size * num_candidates, query_seq_len + candidates_seq_len]
  flat_joint_bert_outputs = bert_module(
      inputs=dict(
          input_ids=flat_joint_inputs.token_ids,
          input_mask=flat_joint_inputs.mask,
          segment_ids=flat_joint_inputs.segment_ids,
          mlm_positions=flat_mlm_positions),
      signature="mlm",
      as_dict=True)

  # [batch_size, num_candidates]
  candidate_score = retrieval_score

  # [batch_size, num_candidates]
  candidate_log_probs = tf.math.log_softmax(candidate_score)

  # ==============================
  # Compute marginal log-likelihood.
  # ==============================
  # [batch_size * num_candidates, num_masks]
  flat_mlm_logits = flat_joint_bert_outputs["mlm_logits"]

  # [batch_size, num_candidates, num_masks, vocab_size]
  mlm_logits = tf.reshape(
      flat_mlm_logits, [batch_size, params["num_candidates"], num_masks, -1])
  mlm_log_probs = tf.math.log_softmax(mlm_logits)

  # [batch_size, num_candidates, num_masks]
  tiled_mlm_targets = tf.tile(
      tf.expand_dims(mlm_targets, 1), [1, params["num_candidates"], 1])

  # [batch_size, num_candidates, num_masks, 1]
  tiled_mlm_targets = tf.expand_dims(tiled_mlm_targets, -1)

  # [batch_size, num_candidates, num_masks, 1]
  gold_log_probs = tf.batch_gather(mlm_log_probs, tiled_mlm_targets)

  # [batch_size, num_candidates, num_masks]
  gold_log_probs = tf.squeeze(gold_log_probs, -1)

  # [batch_size, num_candidates, num_masks]
  joint_gold_log_probs = (
      tf.expand_dims(candidate_log_probs, -1) + gold_log_probs)

  # [batch_size, num_masks]
  marginal_gold_log_probs = tf.reduce_logsumexp(joint_gold_log_probs, 1)

  # [batch_size, num_masks]
  float_mlm_mask = tf.cast(mlm_mask, tf.float32)

  # []
  loss = -tf.div_no_nan(
      tf.reduce_sum(marginal_gold_log_probs * float_mlm_mask),
      tf.reduce_sum(float_mlm_mask))

  # ==============================
  # Optimization
  # ==============================
  num_warmup_steps = min(10000, max(100, int(params["num_train_steps"] / 10)))
  train_op = optimization.create_optimizer(
      loss=loss,
      init_lr=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      num_warmup_steps=num_warmup_steps,
      use_tpu=params["use_tpu"])

  # ==============================
  # Evaluation
  # ==============================
  eval_metric_ops = None if params["use_tpu"] else dict()
  if mode != tf.estimator.ModeKeys.PREDICT:
    # [batch_size, num_masks]
    retrieval_utility = marginal_gold_log_probs - gold_log_probs[:, 0]
    retrieval_utility *= tf.cast(features["mlm_mask"], tf.float32)

    # []
    retrieval_utility = tf.div_no_nan(
        tf.reduce_sum(retrieval_utility), tf.reduce_sum(float_mlm_mask))
    add_mean_metric("retrieval_utility", retrieval_utility, eval_metric_ops)

    has_timestamp = tf.cast(
        tf.greater(features["export_timestamp"], 0), tf.float64)
    off_policy_delay_secs = (
        tf.timestamp() - tf.cast(features["export_timestamp"], tf.float64))
    off_policy_delay_mins = off_policy_delay_secs / 60.0
    off_policy_delay_mins *= tf.cast(has_timestamp, tf.float64)
    add_mean_metric("off_policy_delay_mins", off_policy_delay_mins,
                    eval_metric_ops)

  # Create empty predictions to avoid errors when running in prediction mode.
  predictions = dict()

  if params["use_tpu"]:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions)
  else:
    if eval_metric_ops is not None:
      # Make sure the eval metrics are updated during training so that we get
      # quick feedback from tensorboard summaries when debugging locally.
      with tf.control_dependencies([u for _, u in eval_metric_ops.values()]):
        loss = tf.identity(loss)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        predictions=predictions)
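# Minimal shape sketch (toy numbers, my own) of the gold-log-prob gather used
# above: tf.batch_gather indexes the trailing vocab axis of the log-probs
# with per-position target ids.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

log_probs = tf.math.log([[[0.7, 0.2, 0.1],
                          [0.1, 0.8, 0.1]]])  # [batch=1, num_masks=2, vocab=3]
targets = tf.constant([[[0], [1]]])           # [batch=1, num_masks=2, 1]
gold = tf.squeeze(tf.batch_gather(log_probs, targets), -1)  # [1, 2]
with tf.Session() as sess:
    print(sess.run(gold))  # [[log(0.7), log(0.8)]]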
def non_max_suppression(scores_in, boxes_in, top_k_indices, labels,
                        num_detections=ssd_constants.MAX_NUM_EVAL_BOXES):
  """Implement Non-maximum suppression.

  Args:
    scores_in: a Tensor with shape [batch_size,
      ssd_constants.MAX_NUM_EVAL_BOXES, num_classes]. The top
      ssd_constants.MAX_NUM_EVAL_BOXES box scores for each class.
    boxes_in: a Tensor with shape [batch_size, N, 4], which stacks box
      regression outputs on all feature levels. The N is the number of total
      anchors on all levels.
    top_k_indices: a Tensor with shape [batch_size,
      ssd_constants.MAX_NUM_EVAL_BOXES, num_classes]. The indices for these
      top boxes for each class.
    labels: labels tensor.
    num_detections: maximum output length.

  Returns:
    A tensor size of [batch_size, num_detections, 6] represents boxes, labels
    and scores after NMS.
  """
  _, _, num_classes = scores_in.get_shape().as_list()
  source_id = tf.cast(
      tf.tile(tf.expand_dims(labels[ssd_constants.SOURCE_ID], 1),
              [1, num_detections]), scores_in.dtype)
  raw_shape = tf.cast(
      tf.tile(tf.expand_dims(labels[ssd_constants.RAW_SHAPE], 1),
              [1, num_detections, 1]), scores_in.dtype)

  list_of_all_boxes = []
  list_of_all_scores = []
  list_of_all_classes = []
  # Skip background class.
  for class_i in range(1, num_classes, 1):
    boxes = tf.batch_gather(boxes_in, top_k_indices[:, :, class_i])
    class_i_scores = scores_in[:, :, class_i]
    class_i_scores, boxes = _filter_scores(class_i_scores, boxes)
    (class_i_post_scores,
     class_i_post_boxes) = ssd_architecture.non_max_suppression_padded(
         scores=tf.cast(class_i_scores, scores_in.dtype),
         boxes=tf.cast(boxes, scores_in.dtype),
         max_output_size=num_detections,
         iou_threshold=ssd_constants.OVERLAP_CRITERIA)
    class_i_classes = tf.fill(tf.shape(class_i_post_scores),
                              ssd_constants.CLASS_INV_MAP[class_i])
    list_of_all_boxes.append(class_i_post_boxes)
    list_of_all_scores.append(class_i_post_scores)
    list_of_all_classes.append(class_i_classes)

  post_nms_boxes = tf.concat(list_of_all_boxes, axis=1)
  post_nms_scores = tf.concat(list_of_all_scores, axis=1)
  post_nms_classes = tf.concat(list_of_all_classes, axis=1)

  # sort all results.
  post_nms_scores, sorted_indices = tf.nn.top_k(
      tf.cast(post_nms_scores, scores_in.dtype), k=num_detections,
      sorted=True)

  post_nms_boxes = tf.gather(post_nms_boxes, sorted_indices, batch_dims=1)
  post_nms_classes = tf.gather(post_nms_classes, sorted_indices, batch_dims=1)
  detections_result = tf.stack([
      source_id,
      post_nms_boxes[:, :, 1] * raw_shape[:, :, 1],
      post_nms_boxes[:, :, 0] * raw_shape[:, :, 0],
      (post_nms_boxes[:, :, 3] - post_nms_boxes[:, :, 1]) * raw_shape[:, :, 1],
      (post_nms_boxes[:, :, 2] - post_nms_boxes[:, :, 0]) * raw_shape[:, :, 0],
      post_nms_scores,
      tf.cast(post_nms_classes, scores_in.dtype),
  ], axis=2)

  return detections_result
def read_from_memory(read_keys, read_strengths, mem_state, top_k):
  """Function for cosine similarity content based reading from memory matrix.

  In the args list, we have the following conventions:
    B: batch size
    M: number of slots in a row of the memory matrix
    R: number of rows in the memory matrix
    H: number of read heads (of the controller or the policy)
    K: top_k if top_k > 0

  Args:
    read_keys: the read keys of shape [B, H, M].
    read_strengths: the coefficients used to compute the normalised weighting
      vector of shape [B, H].
    mem_state: the primary memory tensor. Of shape [B, R, M].
    top_k: only use top k read matches, other reads do not go into softmax and
      are zeroed out in the output. top_k=0 (default) means use dense reads.

  Returns:
    The memory reads [B, H, M], read weights [B, H, top k], read indices
    [B, H, top k], and read strengths [B, H].
  """
  _assert_compatible_read_from_memory_inputs(read_keys, read_strengths,
                                             mem_state)
  batch_size = read_keys.shape[0]
  num_read_heads = read_keys.shape[1]

  with tf.name_scope('memory_reading'):
    # Scale such that all rows are L2-unit vectors, for memory and read query.
    scaled_read_keys = tf.math.l2_normalize(read_keys, axis=-1)  # [B, H, M]
    scaled_mem = tf.math.l2_normalize(mem_state, axis=-1)  # [B, R, M]

    # The cosine distance is then their dot product.
    # Find the cosine distance between each read head and each row of memory.
    cosine_distances = tf.matmul(
        scaled_read_keys, scaled_mem, transpose_b=True)  # [B, H, R]

    # The rank must match cosine_distances for broadcasting to work.
    read_strengths = tf.expand_dims(read_strengths, axis=-1)  # [B, H, 1]
    weighted_distances = read_strengths * cosine_distances  # [B, H, R]

    if top_k:
      # Get top k indices (row indices with top k largest weighted distances).
      top_k_output = tf.nn.top_k(weighted_distances, top_k, sorted=False)
      read_indices = top_k_output.indices  # [B, H, K]

      # Create a sub-memory for each read head with only the top k rows.
      # Each batch_gather is [B, K, M] and the list stacks to [B, H, K, M].
      topk_mem_per_head = [
          tf.batch_gather(mem_state, ri_this_head)
          for ri_this_head in tf.unstack(read_indices, axis=1)
      ]
      topk_mem = tf.stack(topk_mem_per_head, axis=1)  # [B, H, K, M]
      topk_scaled_mem = tf.math.l2_normalize(topk_mem, axis=-1)  # [B, H, K, M]

      # Calculate read weights for each head's top k sub-memory.
      expanded_scaled_read_keys = tf.expand_dims(
          scaled_read_keys, axis=2)  # [B, H, 1, M]
      topk_cosine_distances = tf.reduce_sum(
          expanded_scaled_read_keys * topk_scaled_mem, axis=-1)  # [B, H, K]
      topk_weighted_distances = (
          read_strengths * topk_cosine_distances)  # [B, H, K]
      read_weights = tf.nn.softmax(
          topk_weighted_distances, axis=-1)  # [B, H, K]

      # For each head, read using the sub-memories and corresponding weights.
      expanded_weights = tf.expand_dims(read_weights, axis=-1)  # [B, H, K, 1]
      memory_reads = tf.reduce_sum(
          expanded_weights * topk_mem, axis=2)  # [B, H, M]
    else:
      read_weights = tf.nn.softmax(weighted_distances, axis=-1)

      num_rows_memory = mem_state.shape[1]
      all_indices = tf.range(num_rows_memory, dtype=tf.int32)
      all_indices = tf.reshape(all_indices, [1, 1, num_rows_memory])
      read_indices = tf.tile(all_indices,
                             [batch_size, num_read_heads, 1])

      # This is the actual memory access.
      # Note that matmul automatically batch applies for us.
      memory_reads = tf.matmul(read_weights, mem_state)

    read_keys.shape.assert_is_compatible_with(memory_reads.shape)

    read_strengths = tf.squeeze(read_strengths, axis=-1)  # [B, H, 1] -> [B, H]

  return memory_reads, read_weights, read_indices, read_strengths
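# A reduced sketch (toy sizes, one "head" handled per iteration in the loop
# above) of the top-k read path: score each memory row, keep the k best with
# tf.nn.top_k, and pull those rows out with tf.batch_gather.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

mem_state = tf.constant([[[1., 0.], [0., 1.], [1., 1.]]])  # [B=1, R=3, M=2]
scores = tf.constant([[0.1, 0.9, 0.5]])                    # [B=1, R=3]
top_k_output = tf.nn.top_k(scores, k=2, sorted=False)
topk_mem = tf.batch_gather(mem_state, top_k_output.indices)  # [1, 2, 2]
with tf.Session() as sess:
    print(sess.run(topk_mem))  # rows 1 and 2 of the memory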
def generate_detections_per_image_op(cls_outputs, box_outputs, anchor_boxes,
                                     image_id, image_info, num_detections=100,
                                     pre_nms_num_detections=1000,
                                     nms_threshold=0.3,
                                     bbox_reg_weights=(10., 10., 5., 5.)):
  """Generates detections with model outputs and anchors.

  Args:
    cls_outputs: a Tensor with shape [N, num_classes], which stacks class
      logit outputs on all feature levels. The N is the number of total
      anchors on all levels. The num_classes is the number of classes
      predicted by the model. Note that the cls_outputs should be the output
      of softmax().
    box_outputs: a Tensor with shape [N, num_classes*4], which stacks box
      regression outputs on all feature levels. The N is the number of total
      anchors on all levels.
    anchor_boxes: a Tensor with shape [N, 4], which stacks anchors on all
      feature levels. The N is the number of total anchors on all levels.
    image_id: an integer number to specify the image id.
    image_info: a tensor of shape [5] which encodes the input image's
      [height, width, scale, original_height, original_width].
    num_detections: Number of detections after NMS.
    pre_nms_num_detections: Number of candidates before NMS.
    nms_threshold: a float number to specify the threshold of NMS.
    bbox_reg_weights: a list of 4 float scalars, which are default weights on
      (dx, dy, dw, dh) for normalizing bbox regression targets.

  Returns:
    detections: detection results in a tensor with each row representing
      [image_id, ymin, xmin, ymax, xmax, score, class]
  """
  num_boxes, num_classes = cls_outputs.get_shape().as_list()

  # Removes background class scores.
  cls_outputs = cls_outputs[:, 1:num_classes]
  top_k_scores, top_k_indices_with_classes = tf.nn.top_k(
      tf.reshape(cls_outputs, [-1]), k=pre_nms_num_detections, sorted=True)
  classes = tf.mod(top_k_indices_with_classes, num_classes - 1)
  top_k_indices = tf.floordiv(top_k_indices_with_classes, num_classes - 1)

  anchor_boxes = tf.gather(anchor_boxes, top_k_indices)
  box_outputs = tf.reshape(
      box_outputs, [num_boxes, num_classes, 4])[:, 1:num_classes, :]
  box_outputs = tf.gather_nd(box_outputs,
                             tf.stack([top_k_indices, classes], axis=1))

  # Applies bounding box regression to anchors.
  boxes = box_utils.batch_decode_box_outputs_op(
      tf.expand_dims(anchor_boxes, axis=0),
      tf.expand_dims(box_outputs, axis=0), bbox_reg_weights)[0]
  boxes = box_utils.clip_boxes(
      tf.expand_dims(boxes, axis=0),
      tf.expand_dims(image_info[:2], axis=0))[0]

  classes = tf.tile(tf.reshape(classes, [1, pre_nms_num_detections]),
                    [num_classes - 1, 1])
  scores = tf.tile(tf.reshape(top_k_scores, [1, pre_nms_num_detections]),
                   [num_classes - 1, 1])
  boxes = tf.tile(tf.reshape(boxes, [1, pre_nms_num_detections, 4]),
                  [num_classes - 1, 1, 1])

  class_bitmask = tf.tile(
      tf.reshape(tf.range(num_classes - 1), [num_classes - 1, 1]),
      [1, pre_nms_num_detections])
  scores = tf.where(tf.equal(classes, class_bitmask), scores,
                    tf.zeros_like(scores))
  scores = tf.where(tf.greater(scores, 0.05), scores, tf.zeros_like(scores))

  # Reshape classes to be compatible with the top_k function.
  classes = tf.reshape(classes, [num_classes - 1, pre_nms_num_detections, 1])
  scores, sorted_tensors = box_utils.top_k(
      scores, k=pre_nms_num_detections, tensors=[boxes, classes])
  boxes = sorted_tensors[0]
  classes = tf.reshape(sorted_tensors[1],
                       [num_classes - 1, pre_nms_num_detections])

  idx, num_valid = non_max_suppression.non_max_suppression_padded(
      scores, boxes, max_output_size=num_detections,
      iou_threshold=nms_threshold, level=0)

  post_nms_boxes = non_max_suppression.gather_boxes_by_indices(
      boxes, num_detections, idx, num_valid)
  post_nms_scores = non_max_suppression.gather_scores_by_indices(
      scores, num_detections, idx, num_valid)

  # Sorts all results.
  sorted_scores, sorted_indices = tf.nn.top_k(
      tf.to_float(tf.reshape(post_nms_scores, [-1])), k=num_detections,
      sorted=True)
  post_nms_boxes = tf.gather(tf.reshape(post_nms_boxes, [-1, 4]),
                             sorted_indices)
  classes = tf.batch_gather(classes, idx)
  post_nms_classes = tf.gather(tf.reshape(classes, [-1]), sorted_indices) + 1

  if isinstance(image_id, int):
    image_id = tf.constant(image_id)
  image_id = tf.reshape(image_id, [])
  detections_result = tf.stack([
      tf.to_float(tf.fill(tf.shape(sorted_scores), image_id)),
      post_nms_boxes[:, 0],
      post_nms_boxes[:, 1],
      post_nms_boxes[:, 2],
      post_nms_boxes[:, 3],
      sorted_scores,
      tf.to_float(post_nms_classes),
  ], axis=1)

  return detections_result