def __init__(self, params):
  super(TestInputGenerator, self).__init__(params)
  p = self.params
  self._bprop_variable_filters = ['']
  self._bprop_onehot = tf.constant([1], dtype=tf.float32)
  if p.target_key and not p.target_key_target_shape:
    raise ValueError('target_key_target_shape must be set when '
                     'target_key (%s) is not empty.' % p.target_key)
  if (p.set_tgt_and_additional_tgts and
      (p.target_key_target_shape[0] != p.target_shape[0])):
    raise ValueError(
        'The first dimension of target_key_target_shape (%d) '
        'should match the first dimension of target_shape '
        '(%d) when both have to be set.' %
        (p.target_key_target_shape[0], p.target_shape[0]))
  self._cur_iter = 0
  if p.bprop_filters and p.number_sources:
    raise ValueError(
        'Number of sources will be set to the length of bprop_filters; the '
        'param number_sources should not be used when bprop_filters is set.')
  number_sources = p.number_sources
  if p.bprop_filters:
    self._bprop_variable_filters = p.bprop_filters
    number_sources = len(p.bprop_filters)
  if number_sources and number_sources > 1:
    self._bprop_onehot = tf.one_hot(p.source_selected, number_sources,
                                    dtype=tf.float32)
def ComputeLoss(self, theta, predictions, input_batch):
  p = self.params
  batch = tf.shape(input_batch.data)[0]
  act = predictions.act
  with tf.colocate_with(act):
    tf.logging.info("{}'s device: {}".format(act, act.device))
    # Softmax
    labels = tf.to_int64(input_batch.label)
    onehot_labels = tf.one_hot(labels, p.softmax.num_classes)
    if p.label_smoothing > 0:
      smooth_positives = 1.0 - p.label_smoothing
      smooth_negatives = p.label_smoothing / p.softmax.num_classes
      onehot_labels = onehot_labels * smooth_positives + smooth_negatives
    xent = self.softmax.FProp(
        theta=theta.softmax,
        inputs=act,
        class_weights=input_batch.weight,
        class_probabilities=onehot_labels)
  self._AddSummary(input_batch, xent.per_example_argmax)
  rets = {
      'loss': (xent.avg_xent, batch),
      'log_pplx': (xent.avg_xent, batch),
      'num_preds': (batch, 1),
  }
  if p.is_eval or p.compute_accuracy_for_training:
    acc1 = self._Accuracy(1, xent.logits, labels, input_batch.weight)
    acc5 = self._Accuracy(5, xent.logits, labels, input_batch.weight)
    rets.update(accuracy=(acc1, batch), acc5=(acc5, batch))
  return rets, {}
def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                              unused_step_ids, states,
                              unused_num_hyps_per_beam):
  # Same probs for each id.
  logits = tf.zeros([tgt_batch_size, vocab_size])
  # Except eoc has slightly lower score.
  logits = logits - 1.0 * tf.expand_dims(
      tf.one_hot(p.target_eoc_id, vocab_size), 0)
  # eos has very low score (can not terminate by eos)
  logits = logits + eos_score * tf.expand_dims(
      tf.one_hot(p.target_eos_id, vocab_size), 0)
  return py_utils.NestedMap(
      atten_probs=tf.zeros([tgt_batch_size, 0]),
      log_probs=logits,
      is_last_chunk=tf.fill([tgt_batch_size], value=is_last_chunk)), states
def _ComputeClassificationLoss(self, predictions, input_batch, class_weights):
  """Compute classification loss for the given predictions.

  Args:
    predictions: The output of `ComputePredictions`, contains: logits -
      [b, nx, ny, nz, na, 7 + num_classes]. na is the number of anchor boxes
      per cell. [..., :7] are (dx, dy, dz, dw, dl, dh, dt).
    input_batch: The input batch from which we access the groundtruth.
    class_weights: Per-class weights to use in loss computation.

  Returns:
    Classification loss.
  """
  p = self.params
  predicted_class_logits = py_utils.HasShape(
      predictions.classification_logits,
      [-1, -1, -1, -1, p.num_anchors, p.num_classes])
  bs, nx, ny, nz, na, _ = py_utils.GetShape(predicted_class_logits, 6)
  assigned_gt_labels = py_utils.HasShape(input_batch.assigned_gt_labels,
                                         [bs, nx, ny, nz, na])
  class_loss = py_utils.SigmoidCrossEntropyFocalLoss(
      logits=predicted_class_logits,
      labels=tf.one_hot(assigned_gt_labels, p.num_classes),
      alpha=p.focal_loss_alpha,
      gamma=p.focal_loss_gamma)
  class_loss *= class_weights[..., tf.newaxis]
  class_loss_sum = tf.reduce_sum(class_loss)
  return class_loss_sum
def _ConcatOnehotFn(input_data):
  """Concat the input features with a onehot version of the label ids."""
  features = input_data.features
  label = input_data.label
  num_pts = tf.shape(features)[1]
  label_one_hot = tf.one_hot(tf.cast(label, tf.int32), depth=16)
  label_one_hot = tf.tile(tf.expand_dims(label_one_hot, 1), [1, num_pts, 1])
  input_data.features = tf.concat([features, label_one_hot], axis=-1)
  return input_data
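# Shape-level usage sketch for _ConcatOnehotFn (toy values, not from the
# original file): two point clouds with 4 points and 3 features each, plus an
# integer class label per cloud, yield features of depth 3 + 16.
toy_batch = py_utils.NestedMap(
    features=tf.zeros([2, 4, 3]),    # [batch, num_pts, feature_dim]
    label=tf.constant([5, 11]))      # [batch]
toy_batch = _ConcatOnehotFn(toy_batch)  # features now shaped [2, 4, 3 + 16]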
def _PostprocessSample(self, sample, is_tpu):
  """Add topk_hyps, topk_ids, topk_lens, topk_scores tensors to `sample`.

  These features are required by `.BeamSearchDecodeOutput`.

  Args:
    sample: a NestedMap with `ids`, `paddings`, and `logits` fields.
    is_tpu: whether inference is being run on TPU.

  Returns:
    sample with additional features that match `.BeamSearchDecodeOutput`
    requirements. `topk_hyps` is empty.
  """
  p = self.params
  bs = tf.shape(sample.ids)[0]
  num_hyps_per_beam = p.target_sequence_sampler.num_hyps_per_beam
  vocab_size = tf.shape(sample.logits)[2]

  # tf.string is not supported on tpu.
  sample.topk_hyps = tf.zeros([bs], dtype=tf.int32 if is_tpu else tf.string)
  sample.topk_hyps = tf.reshape(sample.topk_hyps, [-1, num_hyps_per_beam])

  sample.topk_ids = sample.ids
  weights = 1 - sample.paddings
  sample.topk_lens = tf.cast(tf.reduce_sum(weights, axis=1), dtype=tf.int32)

  # Compute the hypothesis scores based on the returned ids.
  mask = tf.one_hot(
      sample.topk_ids, depth=vocab_size, axis=-1, dtype=sample.logits.dtype)
  token_log_probs = tf.einsum('ijk,ijk->ij',
                              tf.nn.log_softmax(sample.logits), mask)
  sample.topk_scores = tf.reduce_sum(token_log_probs * weights, axis=1)

  # At this point the batch dimension is (batch_size * num_hyps_per_beam),
  # interleaved as [num_hyps_per_beam, batch_size].
  # This does not match the order expected by beam search post-processing.
  # Must transpose to [batch_size, num_hyps_per_beam] and flatten back.
  max_len = tf.shape(sample.topk_ids)[1]
  sample.topk_ids = tf.reshape(sample.topk_ids,
                               [num_hyps_per_beam, -1, max_len])
  sample.topk_ids = tf.transpose(sample.topk_ids, perm=[1, 0, 2])
  sample.topk_ids = tf.reshape(sample.topk_ids, [bs, max_len])
  # The same for topk_lens and topk_scores.
  sample.topk_lens = tf.reshape(sample.topk_lens, [num_hyps_per_beam, -1])
  sample.topk_lens = tf.transpose(sample.topk_lens, [1, 0])
  sample.topk_lens = tf.reshape(sample.topk_lens, [-1])
  sample.topk_scores = tf.reshape(sample.topk_scores, [num_hyps_per_beam, -1])
  sample.topk_scores = tf.transpose(sample.topk_scores, [1, 0])
  sample.topk_scores = tf.reshape(sample.topk_scores, [-1])
  return sample
def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                              unused_step_ids, states,
                              unused_num_hyps_per_beam):
  # Same probs for each id.
  logits = tf.zeros([tgt_batch_size, vocab_size])
  # Except eos has a slightly lower prob.
  logits = logits - 1.0 * tf.expand_dims(
      tf.one_hot(p.target_eos_id, vocab_size), 0)
  return py_utils.NestedMap(
      atten_probs=tf.zeros([tgt_batch_size, 0]), log_probs=logits), states
def FProp(self, theta, ids, segment_pos):
  p = self.params
  fprop_dtype = py_utils.FPropDtype(p)
  ids = self._MaybeSplit(ids)
  segment_pos = self._MaybeSplit(segment_pos)

  one_hot_ids = tf.one_hot(ids, p.vocab_size, dtype=fprop_dtype)
  one_hot_ids = self._MaybeSplit(one_hot_ids)
  one_hot_pos = tf.one_hot(segment_pos, p.max_len, dtype=fprop_dtype)
  one_hot_pos = self._MaybeSplit(one_hot_pos)
  token_emb = tf.einsum('VH,BLV->BLH', theta.embedding, one_hot_ids)
  token_emb = self._MaybeSplit(token_emb)
  pos_emb = tf.einsum('VH,BLV->BLH', theta.pos_emb, one_hot_pos)
  pos_emb = self._MaybeSplit(pos_emb)
  return self._MaybeSplit(token_emb + pos_emb)
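# Equivalence sketch (toy shapes, not from the original file): the
# one_hot/einsum lookup used above produces the same values as a plain
# tf.gather on the embedding table; the einsum form is often friendlier to
# sharding/splitting on TPU.
emb = tf.random.uniform([10, 4])                 # [V, H] embedding table
ids = tf.constant([[1, 3, 3], [0, 2, 9]])        # [B, L] token ids
one_hot = tf.one_hot(ids, 10, dtype=emb.dtype)   # [B, L, V]
via_einsum = tf.einsum('VH,BLV->BLH', emb, one_hot)
via_gather = tf.gather(emb, ids)                 # same [B, L, H] values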
def test_entmax_loss_generate_right_loss(self):
  inputs = tf.constant([[[0.5, 1.0, 2.0]] * 3], dtype='bfloat16')
  labels = tf.constant([[0, 1, 2]])
  # Convert to the matrix with given depth, e.g. the vocabulary size.
  labels = tf.one_hot(labels, depth=3)
  expected_loss = tf.constant([[1.5642307, 1.0642307, 0.06423065]],
                              dtype='bfloat16')
  entmax_loss_val = entmax.entmax_loss(labels, inputs, alpha=1.5)
  with self.session(use_gpu=False) as sess:
    output = sess.run(entmax_loss_val)
    self.assertAllClose(expected_loss, output)
def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                              unused_step_ids, states, num_hyps_per_beam):
  self.assertEqual(1, num_hyps_per_beam)
  logits = tf.random.stateless_normal([batch_size, vocab_size],
                                      seed=[8273747, 9])
  # Make it never predict <eos>.
  logits -= tf.one_hot([p.target_eos_id], vocab_size, 1e30)
  is_last_chunk = tf.equal(states.src_step, src_len - 1)
  result = py_utils.NestedMap(log_probs=logits, is_last_chunk=is_last_chunk)
  return result, states
def unmask(h, m):
  with tf.name_scope('unmask'):
    tpu_summary.tensor('unmask_h', h)
    tpu_summary.tensor('unmask_m', m)
    t = tf.cumsum(m, -1) * m - 1
    mh = einsum_i32('bkt,bt->bkt', m, h)
    t2 = tf.one_hot(tf.cast(t, tf.int32), output_len, dtype=fprop_dtype)
    x = einsum_i32('bkt,bktT->bkT', mh, t2)
    return tf.cast(x, h.dtype)
def testOneHotLabels(self):
  """Tests that the loss equals softmax CE when the labels are one hot."""
  num_classes = 400
  batch_size = 7
  label_indices = np.random.randint(0, num_classes, size=(batch_size, 3))
  labels = tf.one_hot(label_indices, depth=num_classes, dtype=tf.float32)
  logits = np.random.uniform(size=(batch_size, 3, num_classes)) * 10 + 1e7
  logits_tensor = tf.convert_to_tensor(logits, dtype=tf.float32)
  losses = label_lib.MultiLabelContrastiveLoss(labels, logits_tensor)
  expected = tf.nn.softmax_cross_entropy_with_logits(
      labels=labels, logits=logits_tensor)
  self.assertAllClose(expected, losses)
def ApplyBias():
  """Bias and update log_probs and consistent."""

  def TileForBeamAndFlatten(tensor):
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tensor = tf.tile(
        tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam * src_batch
    return tf.reshape(tensor, [tgt_batch])

  # Consistent if step_ids == labels from previous step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later.
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  local_consistence = tf.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  consistent = tf.logical_and(states.consistent, local_consistence)

  # Get label and weight slices corresponding to the current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(consistent, p.dtype)

  # Convert from dense label to sparse label probs.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  uncertainty = tf.constant(
      1e-10, p.dtype)  # avoid 0 probs which may cause issues with log
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs.
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)
  return tf.log(probs), consistent
def dec_callback(self, tgt_id, tgt_pos, tgt_segment_id, tgt_mask, dec_state,
                 t):
  del tgt_pos, tgt_segment_id
  [buf] = dec_state
  if tgt_id.shape == (self.batch_size, self.beam_size):
    buf = inplace_ops.alias_inplace_update(buf, t, tgt_id)
  else:
    div = int(tgt_id.shape[1] // self.beam_size)
    for i, x_i in enumerate(tf.split(tgt_id, div, 1)):
      buf = inplace_ops.alias_inplace_update(buf, t + i, x_i)

  buf1 = tf.transpose(buf, [1, 0, 2])
  buf1 = tf.reshape(buf1, [self.batch_size, self.max_steps * self.beam_size])

  # Select next_tgt_id as a function of previous target tokens.
  if self.rule == '+1':
    next_tgt_id = (tgt_id + 1)
    next_tgt_id %= self.vocab_size
  elif self.rule == 'sum':
    # Sum over all previous tokens in tgt_mask.
    next_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(tgt_mask, tf.int32))
    next_tgt_id %= self.vocab_size
  elif self.rule == 'fib':
    # Select the last token according to tgt_mask.
    m = tgt_mask
    m *= tf.cast(
        tf.equal(tf.cumsum(m, -1),
                 tf.reduce_sum(m, -1, keepdims=True) - 1), m.dtype)
    last_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(m, tf.int32))
    next_tgt_id = (last_tgt_id + tgt_id) % self.vocab_size

  # With a lower probability, add an extra +1 to the correct next_tgt_id.
  n = self.vocab_size
  logits = 5 * tf.one_hot(next_tgt_id % n, n)
  logits += 4 * tf.one_hot((next_tgt_id + 1) % n, n)
  logits += 3 * tf.one_hot((next_tgt_id + 2) % n, n)
  logits += 2 * tf.one_hot((next_tgt_id + 3) % n, n)
  logits += 1 * tf.one_hot((next_tgt_id + 4) % n, n)

  # Increase eos_score if the current tgt_id contains 9.
  eos_id = 0
  tgt_id_contains_9 = tf.logical_or(
      tf.equal(tgt_id % 10, 9), tf.equal((tgt_id // 10) % 10, 9))
  logits += 9 * tf.einsum('V,BK->BKV',
                          tf.one_hot(eos_id, self.vocab_size),
                          tf.cast(tgt_id_contains_9, tf.float32))

  # Tie-breaking -- a lower token id wins a little bit.
  tie = np.arange(0., 1., 1. / n)
  tie /= tie.sum()
  logits -= tie

  logits = tf.nn.log_softmax(logits)
  dec_state = [buf]
  return logits, dec_state
def update_nbest(nbest_hyps, cur_hyps):
  """Updates nbest hyps from cur_hyps. Returns new values for nbest_hyps."""
  with tf.name_scope('update_nbest'):
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (cur_mask, cur_score, cur_score_norm) = cur_hyps
    k = int(nbest_mask.shape[1])
    m = int(cur_mask.shape[1])
    mask = tf.concat([nbest_mask, cur_mask], 1)
    score = tf.concat([nbest_score, cur_score], 1)
    score_norm = tf.concat([nbest_score_norm, cur_score_norm], 1)
    nbest_score_norm, i = tf.math.top_k(score_norm, k)
    i_one_hot = tf.one_hot(i, k + m, dtype=mask.dtype)
    nbest_mask = tf.einsum('bkt,bjk->bjt', mask, i_one_hot)
    nbest_score = tf.einsum('bk,bjk->bj', score, i_one_hot)
    return (nbest_mask, nbest_score, nbest_score_norm)
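# Minimal shape sketch for update_nbest (assumed toy shapes, not from the
# original test suite): merge a 2-best list with 3 candidate hyps, keeping the
# top 2 by normalized score.
batch, k, m, buf = 1, 2, 3, 4
nbest = (tf.zeros([batch, k, buf]),        # nbest_mask
         tf.constant([[0.0, -1.0]]),       # nbest_score
         tf.constant([[0.0, -1.0]]))       # nbest_score_norm
cur = (tf.ones([batch, m, buf]),           # cur_mask
       tf.constant([[0.5, -2.0, 1.5]]),    # cur_score
       tf.constant([[0.5, -2.0, 1.5]]))    # cur_score_norm
new_mask, new_score, new_score_norm = update_nbest(nbest, cur)
# new_score_norm == [[1.5, 0.5]]; masks/scores are gathered to match.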
def _Slice(tensor):
  """Return a slice of this tensor at time=state0.t."""
  shape = py_utils.GetShape(tensor)
  # All zeros except for t in the time dimension.
  # e.g. if params.axis=1, begin is [0, t, 0, 0, 0, ...]
  begin = tf.one_hot(self.params.axis, tf.rank(tensor), on_value=state0.t)
  # Same as shape, but with a 1 in the time dimension.
  # e.g. if params.axis=1, shape is [shape[0], 1, shape[2], shape[3], ...]
  size = tf.concat([
      shape[0:self.params.axis],
      tf.constant([1], dtype=tf.int32),
      shape[self.params.axis + 1:]
  ], axis=0)
  # Make a slice where the time dimension is fixed at state0.t.
  time_slice = tf.slice(tensor, begin, size)
  # Remove the time dimension.
  return tf.squeeze(time_slice, axis=self.params.axis)
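# Illustration of the tf.one_hot trick used in _Slice above (standalone toy
# values, not from the original file): with axis=1 and t=2, `begin` becomes
# [0, 2, 0], so tf.slice picks out the single step at time index 2.
x = tf.reshape(tf.range(24), [2, 4, 3])              # [batch, time, feature]
begin = tf.one_hot(1, tf.rank(x), on_value=2)        # [0, 2, 0]
size = tf.concat([[2], [1], [3]], axis=0)            # keep batch/feature, 1 step
step = tf.squeeze(tf.slice(x, begin, size), axis=1)  # [2, 3], equals x[:, 2, :]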
def ApplyBias():
  """Bias and update log_probs and consistent."""
  # Consistent if step_ids == labels from previous step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later.
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  local_consistence = tf.math.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  consistent = tf.math.logical_and(states.consistent, local_consistence)

  # Get label and weight slices corresponding to the current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(consistent, py_utils.FPropDtype(p))

  # Convert from dense label to sparse label probs.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  label_probs = tf.one_hot(
      label, vocab_size,
      dtype=py_utils.FPropDtype(p))  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs.
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)
  # Ensure that tf.math.log is applied to positive values.
  probs = tf.maximum(probs, tf.constant(1e-12, dtype=probs.dtype))
  return tf.math.log(probs), consistent
def test_entmax_loss_generate_right_gradient(self):
  inputs = tf.constant([[0.5, 1.0, 2.0]] * 3)
  labels = tf.constant([0, 1, 2])
  expected_loss_gradient = tf.constant(
      [[[-0.97671956, 0.16207013, 0.8146494],
        [0.02328045, -0.83792984, 0.8146494],
        [0.02328045, 0.16207013, -0.1853506]]])
  # Convert to the matrix with given depth, e.g. the vocabulary size.
  labels = tf.one_hot(labels, depth=3)
  expected_loss = tf.constant(2.692692)
  entmax_loss_val = tf.reduce_sum(entmax.entmax_loss(labels, inputs, 1.5))
  entmax_loss_gradient_val = tf.gradients(entmax_loss_val, inputs)
  with self.session(use_gpu=False) as sess:
    loss_output = sess.run(entmax_loss_val)
    gradient_output = sess.run(entmax_loss_gradient_val)
    self.assertAllClose(expected_loss, loss_output)
    self.assertAllClose(expected_loss_gradient, gradient_output)
def ComputeLoss(self, theta, predictions, input_batch):
  p = self.params
  batch = tf.shape(input_batch.data)[0]
  act = predictions.act
  with tf.ops.colocate_with(act):
    tf.logging.info("{}'s device: {}".format(act, act.device))
    # Softmax
    if py_utils.GetRank(input_batch.label) == 1:
      # Create one_hot labels if rank is 1.
      labels = tf.cast(input_batch.label, tf.int64)
      onehot_labels = tf.one_hot(labels, p.softmax.num_classes)
    else:
      onehot_labels = input_batch.label
      labels = tf.math.argmax(onehot_labels, axis=-1)
    if p.label_smoothing > 0:
      smooth_positives = 1.0 - p.label_smoothing
      smooth_negatives = p.label_smoothing / p.softmax.num_classes
      onehot_labels = onehot_labels * smooth_positives + smooth_negatives
    xent = self.softmax.FProp(
        theta=theta.softmax,
        inputs=act,
        class_weights=input_batch.weight,
        class_probabilities=onehot_labels)
  self._AddSummary(input_batch, xent.per_example_argmax)
  rets = {
      'loss': (xent.avg_xent, batch),
      'log_pplx': (xent.avg_xent, batch),
      'num_preds': (batch, 1),
  }
  if self.do_eval or p.compute_accuracy_for_training:
    acc1 = self._Accuracy(1, xent.logits, labels, input_batch.weight)
    acc5 = self._Accuracy(5, xent.logits, labels, input_batch.weight)
    rets.update(
        accuracy=(acc1, batch),
        acc5=(acc5, batch),
        error=(1. - acc1, batch),
        error5=(1. - acc5, batch))
  return rets, {'loss': xent.per_example_xent}
def _ComputeGradientMask(self, bprop_variable_filters):
  """Compute gradient mask for each variable and bprop_variable_filters.

  Note that per_input_gradient_mask[var][i] will be 1 if var matches
  bprop_variable_filter[i], 0 otherwise.

  Args:
    bprop_variable_filters: A list of regex bprop_variable_filters for each
      file pattern.
  """
  self._per_input_gradient_mask = py_utils.NestedMap()
  all_vars = set(self.vars.Flatten())
  for var in all_vars:
    self._per_input_gradient_mask[var.name] = (
        tf.zeros(len(bprop_variable_filters), dtype=tf.float32))
    for i in range(len(bprop_variable_filters)):
      if re.search(bprop_variable_filters[i], var.name):
        tf.logging.info('Keep gradient after filtering, regex: %s var: %s' %
                        (bprop_variable_filters[i], var.name))
        self._per_input_gradient_mask[var.name] += (
            tf.one_hot(i, len(bprop_variable_filters), dtype=tf.float32))
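# Standalone sketch of the masking rule above (hypothetical variable names,
# plain Python instead of tf.one_hot): a variable gets a 1 in position i of
# its mask for every filter regex i that matches its name.
import re

filters = ['encoder', 'decoder']
for name in ['encoder/w', 'decoder/b', 'embedding/w']:
  mask = [1.0 if re.search(f, name) else 0.0 for f in filters]
  print(name, mask)  # encoder/w [1.0, 0.0], decoder/b [0.0, 1.0], ...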
def FProp(self, theta, x):
  p = self.params
  if not hasattr(theta, 't'):
    return x
  t = theta.t
  if t is None:
    return x
  assert hasattr(theta, 'state')
  state = theta.state

  tf.logging.info('p.name=%r', p.name)
  tf.logging.info('state=%r', state)
  tf.logging.info('x=%r', x)
  tf.logging.info('t=%r', t)

  with tf.name_scope(p.name):
    if not self._for_flat_beam_search:
      # For tpu_beam_search_helper
      z = tf.one_hot(t, tf.shape(state)[1])
      z = tf.expand_dims(z, 0)
      while len(z.shape) < len(x.shape):
        z = tf.expand_dims(z, -1)
      y = state = (1 - z) * state + z * x
    if self._for_flat_beam_search:
      # For flat beam search
      state = tf.InplaceUpdate(state, t, x)  # [T, B, L, ...]
      y = state
      # [T, B, L, ...] -> [B, T, L, ...]
      perm = list(range(len(y.shape)))
      perm[:2] = [1, 0]
      y = tf.transpose(y, perm)
      # [B, T, L, ...] -> [B, T*L, ...]
      y_shape = list(y.shape)
      y_shape[1:3] = [int(y_shape[1]) * int(y_shape[2])]
      y = tf.reshape(y, y_shape)
    theta.state = state

  tf.logging.info('y=%r', y)
  return y
def MatmulGather(source, indices):
  """Drop-in replacement for tf.gather_nd() optimized for speed on TPU.

  TODO(weihan): tf.gather_nd() is supposed to be implemented in the same way
  on TPU. Investigate why it's much slower.

  Args:
    source: tensor of shape [N, P1, C]
    indices: tensor of shape [N, P2, K]

  Returns:
    tensor of shape [N, P2, K, C]
  """
  source = py_utils.HasRank(source, 3)
  n, p1, c = py_utils.GetShape(source)
  indices = py_utils.HasShape(indices, [n, -1, -1])
  _, p2, k = py_utils.GetShape(indices)

  onehot = tf.one_hot(indices, depth=p1)  # N x P2 x K x P1
  reshaped = tf.reshape(onehot, [n, -1, p1])  # N x (P2 x K) x P1
  target = tf.matmul(reshaped, source)  # N x (P2 x K) x C
  return tf.reshape(target, [n, p2, k, c])
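# Usage sketch for MatmulGather (toy shapes, not from the original file):
# gathers per-batch rows of `source` at `indices` via the one-hot matmul,
# which should match a batched gather such as
# tf.gather(source, indices, batch_dims=1) on recent TF versions.
source = tf.random.uniform([2, 5, 3])       # [N, P1, C]
indices = tf.constant([[[0, 4], [2, 2]],
                       [[1, 3], [4, 0]]])   # [N, P2, K]
gathered = MatmulGather(source, indices)    # [N, P2, K, C] == [2, 2, 2, 3]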
def ComputeLoss(self, theta, predictions, input_batch):
  """Compute loss for the sparse detector model v1.

  Args:
    theta: A `.NestedMap` object containing variable values of this task.
    predictions: A `.NestedMap` object containing residuals and
      classification_logits.
    input_batch: A `.NestedMap` expected to contain cell_center_xyz,
      cell_points_xyz, cell_feature, anchor_bboxes,
      anchor_localization_residuals, assigned_gt_labels, and
      assigned_cls_mask. See class doc string for details.

  Returns:
    Two dicts:

    - A dict containing str keys and (metric, weight) pairs as values, where
      one of the keys is expected to be 'loss'.
    - A dict containing arbitrary tensors describing something about each
      training example, where the first dimension of each tensor is the batch
      index.
  """
  p = self.params

  batch_size, num_centers = py_utils.GetShape(input_batch.cell_center_xyz, 2)

  # Assert shapes of inputs.
  anchor_bboxes = py_utils.HasShape(
      input_batch.anchor_bboxes,
      [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])
  anchor_localization_residuals = py_utils.HasShape(
      input_batch.anchor_localization_residuals,
      [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])
  predicted_residuals = py_utils.HasShape(
      predictions.residuals,
      [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])

  assigned_gt_labels = py_utils.HasShape(
      input_batch.assigned_gt_labels,
      [batch_size, num_centers, p.num_anchor_bboxes_per_center])
  predicted_classification_logits = py_utils.HasShape(
      predictions.classification_logits, [
          batch_size, num_centers, p.num_anchor_bboxes_per_center,
          p.num_classes
      ])

  # assigned_cls_mask is for weighting the classification loss.
  # Ignored targets will have their mask = 0; this happens when their IOU is
  # not high enough to be a foreground object and not low enough to be
  # background.
  class_weights = py_utils.HasShape(
      input_batch.assigned_cls_mask,
      [batch_size, num_centers, p.num_anchor_bboxes_per_center])
  class_weights = tf.reshape(
      class_weights,
      [batch_size, num_centers, p.num_anchor_bboxes_per_center, 1])

  # Broadcast per class loss weights. For each anchor, there are num_classes
  # prediction heads; we weight the outputs of these heads by the per class
  # loss weights.
  per_class_loss_weight = tf.constant([[[p.per_class_loss_weight]]],
                                      dtype=tf.float32)
  per_class_loss_weight = py_utils.HasShape(per_class_loss_weight,
                                            [1, 1, 1, p.num_classes])
  class_weights *= per_class_loss_weight
  class_weights = py_utils.HasShape(class_weights, [
      batch_size, num_centers, p.num_anchor_bboxes_per_center, p.num_classes
  ])

  # We use assigned_reg_mask for masking the regression loss.
  # Only foreground objects will have assigned_reg_mask = 1.
  reg_weights = py_utils.HasShape(
      input_batch.assigned_reg_mask,
      [batch_size, num_centers, p.num_anchor_bboxes_per_center])
  reg_weights = tf.reshape(
      reg_weights,
      [batch_size, num_centers, p.num_anchor_bboxes_per_center, 1])

  if p.loss_norm_type == LossNormType.NORM_BY_NUM_POS_PER_CENTER:
    # Compute number of positive anchors per example.
    foreground_mask = py_utils.HasShape(
        input_batch.assigned_reg_mask,
        [batch_size, num_centers, p.num_anchor_bboxes_per_center])
    # Sum to get the number of foreground anchors for each example.
    loss_normalization = tf.reduce_sum(foreground_mask, axis=2)
    loss_normalization = tf.maximum(loss_normalization,
                                    tf.ones_like(loss_normalization))
    # Reshape for broadcasting.
    loss_normalization = tf.reshape(loss_normalization,
                                    [batch_size, num_centers, 1, 1])
    # Normalize so that the loss is independent of # centers.
    loss_normalization *= num_centers
    class_weights /= loss_normalization
    reg_weights /= loss_normalization

  classification_loss = py_utils.SigmoidCrossEntropyFocalLoss(
      logits=predicted_classification_logits,
      labels=tf.one_hot(assigned_gt_labels, p.num_classes),
      alpha=p.focal_loss_alpha,
      gamma=p.focal_loss_gamma)

  # Apply mask.
  classification_loss *= class_weights

  # TODO(jngiam): Consider normalizing by num_foreground_anchors for each
  # example instead. This would match the 1/N_positive normalization in
  # point pillars.
  # Reduce sum over centers, boxes and classes.
  classification_loss = tf.reduce_sum(classification_loss, axis=[1, 2, 3])

  # Reduce mean over batch.
  classification_loss = tf.reduce_mean(classification_loss)

  # Localization regression loss with Huber loss (SmoothL1).
  regression_loc_and_dims_loss = self._utils_3d.ScaledHuberLoss(
      labels=anchor_localization_residuals[..., :6],
      predictions=predicted_residuals[..., :6],
      delta=p.huber_loss_delta)

  # Rotation loss is computed on a transform on rotation_delta. For a
  # direction aware loss, we simply wrap the angles to -pi to pi; for a loss
  # that is symmetric to direction (i.e., rotating by pi), we use a sin
  # transform.
  rotation_delta_transform = tf.sin
  if p.direction_aware_rot_loss:
    rotation_delta_transform = functools.partial(
        geometry.WrapAngleRad, min_val=-np.pi, max_val=np.pi)
  rotation_delta = (
      predicted_residuals[..., 6:] - anchor_localization_residuals[..., 6:])
  regression_rotation_loss = self._utils_3d.ScaledHuberLoss(
      labels=tf.zeros_like(rotation_delta),
      predictions=rotation_delta_transform(rotation_delta),
      delta=p.huber_loss_delta)

  reg_loc_loss = regression_loc_and_dims_loss[..., :3]
  reg_dim_loss = regression_loc_and_dims_loss[..., 3:6]

  gt_bboxes = self._utils_3d.ResidualsToBBoxes(
      anchor_bboxes,
      anchor_localization_residuals,
      min_angle_rad=-np.pi,
      max_angle_rad=np.pi)
  predicted_bboxes = self._utils_3d.ResidualsToBBoxes(
      anchor_bboxes,
      predicted_residuals,
      min_angle_rad=-np.pi,
      max_angle_rad=np.pi)

  # Apply mask to individual losses.
  #
  # And then reduce sum over centers, boxes, residuals, and batch
  # and divide by the batch_size.
  regression_rotation_loss *= reg_weights
  reg_rot_loss = tf.reduce_sum(regression_rotation_loss) / batch_size

  reg_loc_loss *= reg_weights
  reg_loc_loss = tf.reduce_sum(reg_loc_loss) / batch_size

  reg_dim_loss *= reg_weights
  reg_dim_loss = tf.reduce_sum(reg_dim_loss) / batch_size

  # Do not create corner loss graph if weight is 0.0.
  # TODO(bcyang): Remove condition after fixing corner loss NaN issue.
  if p.corner_loss_weight != 0.0:
    reg_corner_loss = self._utils_3d.CornerLoss(
        gt_bboxes=gt_bboxes, predicted_bboxes=predicted_bboxes)
    reg_corner_loss = tf.expand_dims(reg_corner_loss, axis=-1)
    reg_corner_loss *= reg_weights
    reg_corner_loss = tf.reduce_sum(reg_corner_loss) / batch_size
  else:
    reg_corner_loss = 0.0

  # Sum components of regression loss.
  regression_loss = (
      p.location_loss_weight * reg_loc_loss +
      p.dimension_loss_weight * reg_dim_loss +
      p.rotation_loss_weight * reg_rot_loss +
      p.corner_loss_weight * reg_corner_loss)

  # Compute total loss.
  total_loss = (p.loss_weight_localization * regression_loss +
                p.loss_weight_classification * classification_loss)

  metrics_dict = py_utils.NestedMap({
      'loss': (total_loss, batch_size),
      'loss/regression': (regression_loss, batch_size),
      'loss/regression/loc': (reg_loc_loss, batch_size),
      'loss/regression/dim': (reg_dim_loss, batch_size),
      'loss/regression/rot': (reg_rot_loss, batch_size),
      'loss/regression/corner': (reg_corner_loss, batch_size),
      'loss/classification': (classification_loss, batch_size),
  })

  # Calculate dimension errors.
  dimension_errors_dict = self._BBoxDimensionErrors(gt_bboxes,
                                                    predicted_bboxes,
                                                    reg_weights)
  metrics_dict.update(dimension_errors_dict)

  per_example_dict = py_utils.NestedMap({
      'residuals': predicted_residuals,
      'classification_logits': predicted_classification_logits,
      'predicted_bboxes': predicted_bboxes,
      'gt_bboxes': gt_bboxes,
      'reg_weights': reg_weights,
  })

  return metrics_dict, per_example_dict
def Sum(theta, state, inputs):
  next_state = py_utils.NestedMap()
  v = tf.reduce_sum(tf.one_hot(inputs.one_hot, depth=2) * theta.x, axis=0)
  next_state.sum = state.sum + v
  return next_state, py_utils.NestedMap()
def ComputePredictions(self, theta, input_batch):
  """Computes predictions for `input_batch`.

  Args:
    theta: A `.NestedMap` object containing variable values of this task.
    input_batch: A `.NestedMap` expected to contain lasers.points_xyz,
      lasers.points_feature, lasers.points_padding, cell_center_xyz,
      cell_points_xyz, cell_feature, anchor_bboxes,
      anchor_localization_residuals, assigned_gt_labels, and
      assigned_cls_mask. See class doc string for details.

  Returns:
    A `.NestedMap` object containing residuals and classification_logits.
  """
  p = self.params
  input_batch.Transform(lambda x: (x.shape, x.shape.num_elements())).VLog(
      1, 'input_batch shapes: ')

  cell_feature = py_utils.HasRank(input_batch.cell_feature, 4)
  batch_size, num_centers = py_utils.GetShape(cell_feature, 2)

  featurized_cell = self._CellFeaturizer(theta, input_batch)

  # Project each featurized_cell features to each bbox per center.
  featurized_anchors = self.cell_feature_projector.FProp(
      theta.cell_feature_projector, featurized_cell)

  # Reshape output so that we have features per offset.
  featurized_anchors = tf.reshape(
      featurized_anchors,
      [batch_size, num_centers, p.num_anchor_bboxes_offsets, -1])

  # Predict localization residuals.
  predicted_residuals = self.localization_regressor.FProp(
      theta.localization_regressor, featurized_anchors)
  predicted_residuals = tf.reshape(
      predicted_residuals,
      [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])

  if any([p.oracle_location, p.oracle_dimension, p.oracle_rotation]):
    gt_residuals = py_utils.HasShape(
        input_batch.anchor_localization_residuals,
        [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])

    residuals = []
    if p.oracle_location:
      residuals.append(gt_residuals[..., 0:3])
    else:
      residuals.append(predicted_residuals[..., 0:3])

    if p.oracle_dimension:
      residuals.append(gt_residuals[..., 3:6])
    else:
      residuals.append(predicted_residuals[..., 3:6])

    if p.oracle_rotation:
      residuals.append(gt_residuals[..., 6:])
    else:
      residuals.append(predicted_residuals[..., 6:])
    predicted_residuals = tf.concat(residuals, axis=-1)

  if p.squash_rotation_predictions:
    predicted_rotations = predicted_residuals[..., 6:]
    predicted_rotations = np.pi * tf.tanh(predicted_rotations)
    predicted_residuals = tf.concat(
        [predicted_residuals[..., :6], predicted_rotations], axis=-1)

  # Predict object classification at each bbox.
  predicted_classification_logits = self.classifier.FProp(
      theta.classifier, featurized_anchors)
  predicted_classification_logits = tf.reshape(
      predicted_classification_logits, [
          batch_size, num_centers, p.num_anchor_bboxes_per_center,
          p.num_classes
      ])

  if p.oracle_classification:
    assigned_gt_labels = py_utils.HasShape(
        input_batch.assigned_gt_labels,
        [batch_size, num_centers, p.num_anchor_bboxes_per_center])
    predicted_classification_logits = tf.one_hot(assigned_gt_labels,
                                                 p.num_classes)

  return py_utils.NestedMap({
      'residuals': predicted_residuals,
      'classification_logits': predicted_classification_logits,
  })
def MaxPool3D(points, point_features, pooling_idx, closest_idx):
  """Apply max pooling to a point cloud with computed sampling indices.

  sampled_idx and closest_idx are the outputs of a sampler such as
  FurthestPointSampler.

  The pooling operation results in a point cloud with fewer points, where the
  pooled points are specified by pooling_idx. Each element of pooling_idx
  contains an integer in the range [0, P1) containing the index of the point
  in points/point_features.

  Max pooling is performed by assigning each point to its closest pooled
  point, and then taking a max over the features of points assigned. We
  assume that this mapping is provided by closest_idx, where each element
  should contain an integer in the range [0, P2) containing the index of the
  pooled point that each point is assigned to.

  Note: This logic for pooling assumes that there will be at least one
  value > 0 per sampled region for each feature, otherwise it will return 0.
  Additionally, it does a reduce over a masked version of the features, so
  mean and min would not work without a change in the logic.

  Args:
    points: a floating point tf.Tensor with shape [N, P1, 3]
    point_features: a floating point tf.Tensor with shape [N, P1, C]
    pooling_idx: A tf.int32 tf.Tensor of shape [N, P2] with the index of
      which points we want to keep. Each value should be in the range
      [0, P1].
    closest_idx: A tf.int32 tf.Tensor of shape [N, P1] representing which
      sampled point is closest to each original point. Each value should be
      in the range of [0, P2].

  Returns:
    A tuple of tf.Tensors (pooled_points, pooled_features).

    pooled_points has shape [N, P2, 3] representing the locations of each
    selected point. P2 corresponds to num_pooled_points.

    pooled_features has shape [N, P2, C] representing the pooled features at
    each point.
  """
  batch_size, num_points = py_utils.GetShape(points, 2)
  point_features = py_utils.HasShape(point_features,
                                     [batch_size, num_points, -1])
  pooling_idx = py_utils.HasShape(pooling_idx, [batch_size, -1])
  _, num_output_points = py_utils.GetShape(pooling_idx)
  _, _, feature_dims = py_utils.GetShape(point_features, 3)

  # Gather new point locations.
  pooled_points = tf.batch_gather(points, pooling_idx)

  mask = tf.one_hot(closest_idx, num_output_points)  # [N, P1, P2]
  mask = tf.transpose(mask, [2, 0, 1])  # [P2, N, P1]

  def _PartialPoolFeaturesFn(partial_mask):
    partial_mask = tf.tile(
        tf.reshape(partial_mask, [batch_size, num_points, 1]),
        [1, 1, feature_dims])
    # Note: This method of pooling assumes there will be a value > 0
    # and will only work with max under this condition.
    return tf.reduce_max(partial_mask * point_features, axis=1)

  # Performing a map_fn over the pooled points is more memory efficient.
  pooled_point_features = tf.map_fn(_PartialPoolFeaturesFn, mask)  # [P2, N, C]
  pooled_point_features = tf.transpose(pooled_point_features, [1, 0, 2])

  return pooled_points, pooled_point_features
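# Shape-only usage sketch for MaxPool3D (toy random tensors; in practice a
# sampler such as FurthestPointSampler would supply pooling_idx/closest_idx).
points = tf.random.uniform([4, 128, 3])            # [N, P1, 3]
point_features = tf.random.uniform([4, 128, 16])   # [N, P1, C]
pooling_idx = tf.random.uniform([4, 32], maxval=128, dtype=tf.int32)  # [N, P2]
closest_idx = tf.random.uniform([4, 128], maxval=32, dtype=tf.int32)  # [N, P1]
pooled_points, pooled_features = MaxPool3D(points, point_features,
                                           pooling_idx, closest_idx)
# pooled_points: [4, 32, 3], pooled_features: [4, 32, 16]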
def FProp(self, theta, x, paddings=None, update=False):
  """Computes distances of the given input 'x' to all centroids.

  This implementation applies layer normalization on 'x' internally first,
  and the returned 'dists' is computed using the normalized 'x'.

  Args:
    theta: A `.NestedMap` of weights' values of this layer.
    x: A tensor of shape [B, L, N, H].
    paddings: If not None, a tensor of shape [B, L].
    update: bool, whether to update centroids using x.

  Returns:
    dists: "distances" of the given input 'x' to all centroids.
      Shape [B, L, N, K].
    k_means_loss: the average squared Euclidean distances to the closest
      centroid, a scalar.
  """
  p = self.params
  if paddings is None:
    paddings = tf.zeros_like(x[:, :, 0, 0])
  # Shape [B, L, 1, 1]
  paddings_4d = paddings[:, :, None, None]

  if p.apply_layer_norm:
    x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

  # 'x' is normalized (but theta.means is not), we use negative dot product
  # to approximate the Euclidean distance here.
  dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

  # For padded positions we update the distances to very large numbers.
  very_large_dists = tf.ones_like(dists) * tf.constant(
      0.1, dtype=dists.dtype) * dists.dtype.max
  paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
  dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

  # Shape [B, L, N, K], the same as 'dists' above.
  nearest_one_hot = tf.one_hot(
      tf.math.argmin(dists, axis=-1),
      p.num_clusters,
      dtype=py_utils.FPropDtype(p))
  # Same shape as the input 'x'.
  nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                               theta.means)

  diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
  diff = py_utils.ApplyPadding(paddings_4d, diff)
  diff = tf.math.reduce_mean(diff, axis=2)

  # The commitment loss which when back proped against encourages the 'x'
  # values to commit to their chosen centroids.
  k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 - paddings)
  summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

  # TODO(zhouwk): investigate normalizing theta.means after each update.
  means_norm = tf.norm(theta.means)
  summary_utils.scalar('k_means/centroid_l2_norm/min',
                       tf.math.reduce_min(means_norm))
  summary_utils.scalar('k_means/centroid_l2_norm/mean',
                       tf.math.reduce_mean(means_norm))

  if not update:
    return dists, k_means_loss

  # To update the centroids (self.vars.means), we apply gradient descent on
  # the mini-batch of input 'x', which yields the following:
  #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
  # where x_mean is the average over all the input vectors closest to this
  # centroid.
  #
  # Note that this approach is equivalent with backprop via
  #   loss = tf.math.reduce_mean(
  #       tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid))
  # , except that here the learning rate is independently set via 'decay'.

  # Ensure that the padded positions are not used to update the centroids.
  nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

  # Sum away batch and sequence length dimensions to get per cluster count.
  # Shape: [N, K]
  per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
  summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

  # Sum of the input 'x' per each closest centroid.
  sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

  if py_utils.use_tpu():
    per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
    sum_x = tf.tpu.cross_replica_sum(sum_x)

  # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
  # cluster's position will always be 0, hence 'sum_x' in that dimension will
  # be 0.
  new_means = sum_x / tf.maximum(
      tf.constant(1.0, dtype=per_cluster_count.dtype),
      tf.expand_dims(per_cluster_count, axis=-1))

  # We use an exponential moving average. TODO(zhouwk): investigate smoothing
  # this over an exponentially moving averaged per cluster count.
  #
  # Note that we intentionally do not normalize the means after this update
  # as empirically this works better.
  update_means_diff = tf.cast(
      (1.0 - p.decay) * (new_means - theta.means), self.vars.means.dtype)
  return py_utils.with_dependencies(
      [tf.assign_add(self.vars.means, update_means_diff)],
      dists), k_means_loss
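# Standalone numpy sketch of the centroid EMA update used above (toy version
# under assumed shapes, not the layer's actual API): move each centroid a
# fraction (1 - decay) toward the mean of the inputs assigned to it.
import numpy as np


def ema_update(means, x, assignments, decay=0.999):
  """means: [K, H]; x: [M, H]; assignments: [M] int cluster ids."""
  k = means.shape[0]
  one_hot = np.eye(k)[assignments]                 # [M, K]
  counts = one_hot.sum(axis=0)                     # [K]
  sums = one_hot.T @ x                             # [K, H]
  new_means = sums / np.maximum(counts, 1.0)[:, None]
  return means + (1.0 - decay) * (new_means - means)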
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
  """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
  assert beam_size > 0
  assert batch_size > 0
  assert max_steps > 0

  buf_size = beam_size * max_steps
  output_len = max_steps

  if prefix is None:
    assert prefix_len is None
    prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    prefix += tf.one_hot(0, beam_size, dtype=tf.int32) * bos_id
    prefix_len = tf.ones([batch_size], dtype=tf.int32)
  else:
    assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
    assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                    prefix_len.shape)
    output_len += int(prefix.shape[1])

  if debug:
    tpu_summary.tensor('prefix', prefix)
    tpu_summary.tensor('prefix_len', prefix_len)

  with tf.name_scope('init_state'):
    t = tf.constant(0)
    tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_id += bos_id
    tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
    tgt_mask += tf.one_hot(tf.range(beam_size), buf_size, dtype=fprop_dtype)
    hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
    # penalize all hyps except the first
    hyp_score -= tf.cast(
        tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)
    nbest_size = nbest_size or beam_size
    nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
    nbest_score -= 1e9
    nbest_score_norm = nbest_score
    nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                          dtype=fprop_dtype)

  with tf.name_scope('init_ext'):
    # Initialize the extension buffer.
    #
    # Extension buffer stores a (potentially large) set of 'extensions',
    # which consist of a hypothesis (represented by ext_mask) and next token
    # (represented by ext_id). At each decoder iteration, top_k extensions
    # from each hypothesis are added to the buffer and sorted by score.
    #
    # Then top beam_size extensions are removed from the buffer and used
    # in the next decoder iteration. And top 'ext_size' remaining extensions
    # are carried over to be possibly evaluated at a later step.
    #
    # As a result of this manipulation, the decoder is no longer restricted
    # to always compare hyps of the same token length at each iteration.
    # In particular, for a fixed length N it can generate more than beam_size
    # terminated hyps.
    #
    # Setting ext_size = 0 disables this feature.
    if ext_size:
      ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
      ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
      ext_score -= 1e9
      ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                          dtype=fprop_dtype)
    else:
      ext_size = ext_id = ext_score = ext_mask = 0

  with tf.name_scope('init_prefix'):
    # rename prefix->pfx for shorter variables
    pfx = tf.cast(prefix, tf.int32)
    pfx_len = tf.cast(prefix_len, tf.int32)
    del prefix, prefix_len
    # Before the first call to dec_callback() the prefix shall be packed into
    # the tgt_id buffer as follows:
    #
    #   [ P P P P P P - - - - - -  P* - - - ]   ^
    #   [ P P P P P P P P P P - -  P* - - - ]   | batch
    #   [ P - - - - - - - - - - -  P* - - - ]   V
    #   |<------ prefix len ------>|<-- beam -->
    #
    # The last meaningful token in the prefix (P*)
    # must be located at the same position in all batch rows.
    #
    # We then make one dec_callback() with full prefix (minus P*)
    # which will populate the initial dec_state
    # (for transformer -- self-attention key/value cache)
    #
    # The last block [batch, beam] then becomes the first tgt_id for the loop.
    pfx_max = int(pfx.shape[1])
    pfx_mul = pfx_max // beam_size
    assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
    pfx_time = tf.range(pfx_max)
    pfx_pad = tf.cast(
        tf.less(tf.expand_dims(pfx_time, 0), tf.expand_dims(pfx_len - 1, 1)),
        tf.int32)
    pfx_id = pfx * pfx_pad
    pfx_last = einsum_i32(
        'BT,BT->B', pfx, tf.one_hot(pfx_len - 1, pfx_max, dtype=fprop_dtype))
    buf_time = tf.range(buf_size)
    pfx_time_mask = tf.cast(
        tf.less_equal(tf.expand_dims(buf_time, 0), tf.expand_dims(pfx_time,
                                                                  1)),
        fprop_dtype)
    pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                         pfx_time_mask)
    pfx_segment_id = pfx_pad
    pfx_pos = pfx_time * pfx_pad

  if debug:
    tpu_summary.tensor('pfx_id', pfx_id)
    tpu_summary.tensor('pfx_len', pfx_len)
    tpu_summary.tensor('pfx_pos', pfx_pos)
    tpu_summary.tensor('pfx_last', pfx_last)

  # Now call decoder with prefix minus P*:
  # 'dec_state' now shall contain the key/value cache for prefix tokens
  # (for transformer models), and 'logits' we can either discard or
  # roll into the initial hyp_score. Discard is simpler.
  with tf.name_scope('prefix_fprop'):
    # TODO(krikun): remove extra type checks
    assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
    assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
    assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
    assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
    assert (t.dtype == tf.int32), (t.dtype)
    logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                     pfx_mask, dec_state, t)
    del logits

  # Now construct the initial state for the rest of the beam search loop.
  # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
  # 'tgt_pos' is different for each batch row and is equal to prefix_len
  # 'tgt_segment_id' always 1 (no packing)
  # 'hyp_score' is 0 for beam=0 and negative for beam>=1
  tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
      pfx_last, 1)
  tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
      (pfx_len - 1), 1)
  hyp_score = tf.zeros(
      [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
          tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

  # TODO(krikun) Here we make the initial 't' constant and determined by the
  # shape of the prefix tensor 'pfx_max'. It is possible to make it dynamic
  # as t ~ max(pfx_len) / beam_size, and this will need more steps for beam
  # search; however 'max' results in a very slow all-to-all for 'max' on
  # 16x16, and a variable number of decoder steps may result in bad latency.
  t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

  # Initial tgt_mask is such that each token P* has attention on itself
  # (as usual) and on all prefix tokens before it, which are not padding.
  tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
  tgt_mask += tf.cast(
      tf.expand_dims(
          tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
      fprop_dtype)
  tgt_mask += tf.one_hot(
      tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

  if debug:
    tpu_summary.tensor('tgt_id', tgt_id)
    tpu_summary.tensor('tgt_pos', tgt_pos)
    tpu_summary.tensor('tgt_mask', tgt_mask)
    tpu_summary.tensor('t', t)

  with tf.name_scope('init_hist'):
    # h_tgt_id is used to recover topk_ids from nbest_mask
    h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
    h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

    # When non-trivial prefix is present we also write prefix ids to
    # h_tgt_id so that the full sequence including prefix can be recovered
    # by unmask() below. When prefix is empty, pfx_id shape is [batch, 0]
    # and the loop below becomes a no-op.
    # TODO(krikun): maybe a tf.while_loop is more appropriate here.
    for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
      h_tgt_id = h_tgt_id.write(i, x_i)
    for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
      h_tgt_pos = h_tgt_pos.write(i, x_i)

  hist = (h_tgt_id, h_tgt_pos)
  tf.logging.info('hist=%r', hist)

  nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
  tf.logging.info('nbest_hyps=%r', nbest_hyps)

  ext = (ext_id, ext_score, ext_mask)
  tf.logging.info('ext=%r', ext)

  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)
  tf.logging.info('loop_vars=%r', loop_vars)

  def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (ext_id, ext_score, ext_mask) = ext
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
    h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')

    # not using tf.ones() here because of XLA compilation error
    tgt_segment_id = tgt_id * 0 + 1
    logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                     tgt_mask, dec_state, t)

    # take predicted EOS score for each hyp and compute normalized score
    eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

    def length_norm(t):
      t = tf.cast(t, fprop_dtype)
      alpha = length_norm_alpha
      tf.logging.info('length_norm.alpha=%r', alpha)
      return tf.math.pow((t + 5.) / 5., alpha)
    hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
    eos_score_norm = eos_score / length_norm(hyp_len)
    # update the n-best list
    nbest_hyps = update_nbest(nbest_hyps,
                              (tgt_mask, hyp_score, eos_score_norm))

    if debug:
      tpu_summary.tensor('eos_score', eos_score)
      tpu_summary.tensor('hyp_len', hyp_len)

    # take top k tokens for each hyp
    k = beam_size
    with tf.name_scope('topk1'):
      top_score, top_id = top_k_fn(logits, k)
      top_score = tf.cast(top_score, fprop_dtype)
      top_score += tf.expand_dims(hyp_score, -1)
      top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)
      top_score = tf.reshape(top_score, [batch_size, beam_size * k])
      top_id = tf.reshape(top_id, [batch_size, beam_size * k])
      top_mask = tf.repeat(tgt_mask, beam_size, 1)

    if debug:
      tpu_summary.tensor('top_id', top_id)
      tpu_summary.tensor('top_score', top_score)
      # tpu_summary.tensor('top_mask', top_mask)

    with tf.name_scope('update_ext'):
      # combine top k tokens with extension buffer (if any)
      if ext_size:
        ext_id = tf.concat([ext_id, top_id], 1)
        ext_score = tf.concat([ext_score, top_score], 1)
        ext_mask = tf.concat([ext_mask, top_mask], 1)
      else:
        ext_id, ext_score, ext_mask = top_id, top_score, top_mask

      # sort by score
      ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
      i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
      ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
      ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

      # pick top beam_size extensions to evaluate at next iteration
      if ext_size:
        hyp_score = ext_score[:, :beam_size]
        ext_score = ext_score[:, beam_size:]
        tgt_id = ext_id[:, :beam_size]
        ext_id = ext_id[:, beam_size:]
        tgt_mask = ext_mask[:, :beam_size]
        ext_mask = ext_mask[:, beam_size:]
      else:
        hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
        ext_score = ext_id = ext_mask = 0

    tgt_pos = tf.reduce_sum(tgt_mask, -1)
    tgt_pos = tf.cast(tgt_pos, tf.int32)

    t += 1
    with tf.name_scope('tgt_mask_extend'):
      tgt_mask += tf.one_hot(
          tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

    ext = (ext_id, ext_score, ext_mask)
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    return loop_vars, dec_state

  def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    if beam_gap is None:
      (t, _, _, _, _, _, _, _) = loop_vars
      return t < max_steps
    else:
      (t, _, _, _, _, nbest_hyps, _, _) = loop_vars
      (_, nbest_score, _) = nbest_hyps
      # stop early if all current hyps are significantly worse than nbest
      diff = tf.reduce_min(
          tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
      return tf.math.logical_and(t < max_steps, diff < beam_gap)

  with tf.name_scope('flat_beam_search_loop'):
    (loop_vars, dec_state) = tf.while_loop(
        loop_cond,
        loop_step,
        loop_vars=(loop_vars, dec_state),
        back_prop=False,
        swap_memory=False,
        maximum_iterations=max_steps)

  # flatten all tensorarrays into tensors
  (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) = loop_vars
  (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
  (h_tgt_id, h_tgt_pos) = hist
  h_tgt_id = h_tgt_id.stack()
  h_tgt_pos = h_tgt_pos.stack()
  hist = (h_tgt_id, h_tgt_pos)
  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)

  # recover topk_ids from nbest_mask and tgt_id history
  h = tf.transpose(h_tgt_id, [1, 0, 2])
  h = tf.reshape(h, [batch_size, buf_size])

  def unmask(h, m):
    with tf.name_scope('unmask'):
      tpu_summary.tensor('unmask_h', h)
      tpu_summary.tensor('unmask_m', m)
      t = tf.cumsum(m, -1) * m - 1
      mh = einsum_i32('bkt,bt->bkt', m, h)
      t2 = tf.one_hot(tf.cast(t, tf.int32), output_len, dtype=fprop_dtype)
      x = einsum_i32('bkt,bktT->bkT', mh, t2)
      return tf.cast(x, h.dtype)

  topk_ids = unmask(h, nbest_mask)
  topk_len = tf.reduce_sum(nbest_mask, -1)
  topk_len = tf.cast(topk_len, tf.int32)
  # add eos, because nbest_mask does not encode eos
  topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
  topk_len += 1
  topk_len = tf.minimum(topk_len, output_len)
  topk_score = nbest_score_norm

  nbest = (topk_ids, topk_len, topk_score)

  return loop_vars, dec_state, nbest
def ComputeLoss(self, theta, predictions, input_batch):
  """Computes loss and other metrics for the given predictions.

  Args:
    theta: A `.NestedMap` object containing variable values of this task.
    predictions: The output of `ComputePredictions`, contains: logits -
      [b, nx, ny, nz, na, 7 + num_classes]. na is the number of anchor boxes
      per cell. [..., :7] are (dx, dy, dz, dw, dl, dh, dt).
    input_batch: The input batch from which we access the groundtruth.

  Returns:
    Two dicts defined as BaseTask.ComputeLoss.
  """
  p = self.params
  predicted_residuals = py_utils.HasShape(predictions.residuals,
                                          [-1, -1, -1, -1, p.num_anchors, 7])
  predicted_class_logits = py_utils.HasShape(
      predictions.classification_logits,
      [-1, -1, -1, -1, p.num_anchors, p.num_classes])
  bs, nx, ny, nz, na, _ = py_utils.GetShape(predicted_class_logits, 6)

  # Compute class and regression weights.
  class_weights = input_batch.assigned_cls_mask
  class_weights = py_utils.HasShape(class_weights, [bs, nx, ny, nz, na])
  reg_weights = input_batch.assigned_reg_mask
  reg_weights = py_utils.HasShape(reg_weights, [bs, nx, ny, nz, na])
  reg_weights = tf.expand_dims(reg_weights, -1)

  if p.loss_norm_type == LossNormType.NORM_BY_NUM_POSITIVES:
    # Compute number of positive anchors per example.
    foreground_mask = py_utils.HasShape(input_batch.assigned_reg_mask,
                                        [bs, nx, ny, nz, na])
    # Sum to get the number of foreground anchors for each example.
    loss_normalization = tf.reduce_sum(foreground_mask, axis=[1, 2, 3, 4])
    loss_normalization = tf.maximum(loss_normalization,
                                    tf.ones_like(loss_normalization))
    # Reshape for broadcasting.
    loss_normalization = tf.reshape(loss_normalization, [bs, 1, 1, 1, 1, 1])

    class_weights /= loss_normalization
    reg_weights /= loss_normalization

  # Classification loss.
  assigned_gt_labels = py_utils.HasShape(input_batch.assigned_gt_labels,
                                         [bs, nx, ny, nz, na])
  class_loss = py_utils.SigmoidCrossEntropyFocalLoss(
      logits=predicted_class_logits,
      labels=tf.one_hot(assigned_gt_labels, p.num_classes),
      alpha=p.focal_loss_alpha,
      gamma=p.focal_loss_gamma)
  class_loss *= class_weights[..., tf.newaxis]
  class_loss_sum = tf.reduce_sum(class_loss)

  # Regression loss.
  anchor_localization_residuals = py_utils.HasShape(
      input_batch.anchor_localization_residuals, [bs, nx, ny, nz, na, 7])

  # Location and dimensions loss.
  reg_loc_and_dims_loss = self._utils.ScaledHuberLoss(
      predictions=py_utils.HasShape(predicted_residuals[..., :6],
                                    [bs, nx, ny, nz, na, 6]),
      labels=anchor_localization_residuals[..., :6],
      delta=1 / (3.**2))

  # Rotation loss.
  rot_delta = (
      predicted_residuals[..., 6:] -
      input_batch.anchor_localization_residuals[..., 6:])

  if p.use_atan2_heading_loss:
    atan2_of_delta = tf.atan2(tf.sin(rot_delta), tf.cos(rot_delta))
    reg_rot_loss = self._utils.ScaledHuberLoss(
        predictions=atan2_of_delta,
        labels=tf.zeros_like(atan2_of_delta),
        delta=1 / (3.**2))
  else:
    # Rotation loss with SmoothL1(sin(delta)).
    reg_rot_loss = self._utils.ScaledHuberLoss(
        predictions=tf.sin(rot_delta),
        labels=tf.zeros_like(rot_delta),
        delta=1 / (3.**2))

  # Direction loss
  if p.direction_classifier_weight > 0.0:
    # The target rotations are in the assigned_gt_bbox tensor,
    # which already has assigned a gt bounding box to every anchor.
    rot_target = input_batch.assigned_gt_bbox[..., 6]
    # If rotation is > 0, the class is 1, else it is 0.
    rot_dir = tf.cast(rot_target > 0., tf.int32)

    # Compute one-hot labels as a target.
    rot_dir_onehot = tf.one_hot(rot_dir, 2)

    # Manually handle loss reduction.
    dir_loss = tf.losses.softmax_cross_entropy(
        onehot_labels=rot_dir_onehot,
        logits=predictions.predicted_dir,
        weights=tf.squeeze(reg_weights, axis=-1),
        reduction=tf.losses.Reduction.NONE)
    # Reduce across all dimensions (we'll divide by the batch size below).
    dir_loss_sum = tf.reduce_sum(dir_loss)
  else:
    dir_loss_sum = 0.0

  # Compute loss contribution from location and dimension separately.
  reg_loc_loss = reg_loc_and_dims_loss[..., :3] * reg_weights
  reg_loc_loss_sum = tf.reduce_sum(reg_loc_loss)

  reg_dim_loss = reg_loc_and_dims_loss[..., 3:6] * reg_weights
  reg_dim_loss_sum = tf.reduce_sum(reg_dim_loss)

  # Compute rotation loss contribution.
  reg_rot_loss *= reg_weights
  reg_rot_loss_sum = tf.reduce_sum(reg_rot_loss)

  # Num. predictions.
  # TODO(zhifengc): Consider other normalization factors. E.g., # of bboxes.
  preds = tf.cast(bs, class_loss_sum.dtype)

  # Normalize all of the components by batch size.
  reg_loc_loss = reg_loc_loss_sum / preds
  reg_dim_loss = reg_dim_loss_sum / preds
  reg_rot_loss = reg_rot_loss_sum / preds
  class_loss = class_loss_sum / preds
  dir_loss = dir_loss_sum / preds

  # Compute total localization regression loss.
  reg_loss = (
      p.location_loss_weight * reg_loc_loss +
      p.dimension_loss_weight * reg_dim_loss +
      p.rotation_loss_weight * reg_rot_loss)

  # Apply weights to normalized class losses.
  loss = (
      class_loss * p.classification_loss_weight +
      reg_loss * p.localization_loss_weight +
      dir_loss * p.direction_classifier_weight)

  metrics_dict = {
      'loss': (loss, preds),
      'loss/class': (class_loss, preds),
      'loss/reg': (reg_loss, preds),
      'loss/reg/rot': (reg_rot_loss, preds),
      'loss/reg/loc': (reg_loc_loss, preds),
      'loss/reg/dim': (reg_dim_loss, preds),
      'loss/dir': (dir_loss, preds),
  }

  # Calculate dimension errors.
  min_angle_rad = -np.pi if p.use_atan2_heading_loss else 0
  gt_bboxes = self._utils_3d.ResidualsToBBoxes(
      input_batch.anchor_bboxes,
      anchor_localization_residuals,
      min_angle_rad=min_angle_rad,
      max_angle_rad=np.pi)
  predicted_bboxes = self._utils_3d.ResidualsToBBoxes(
      input_batch.anchor_bboxes,
      predicted_residuals,
      min_angle_rad=min_angle_rad,
      max_angle_rad=np.pi)
  dimension_errors_dict = self._BBoxDimensionErrors(gt_bboxes,
                                                    predicted_bboxes,
                                                    reg_weights)
  metrics_dict.update(dimension_errors_dict)

  per_example_dict = {
      'residuals': predicted_residuals,
      'classification_logits': predicted_class_logits,
  }
  return metrics_dict, per_example_dict
def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
  tf.logging.info('loop_vars=%r', loop_vars)
  tf.logging.info('dec_state=%r', dec_state)
  (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) = loop_vars
  (ext_id, ext_score, ext_mask) = ext
  (h_tgt_id, h_tgt_pos) = hist
  h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
  h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')

  # not using tf.ones() here because of XLA compilation error
  tgt_segment_id = tgt_id * 0 + 1
  logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask,
                                   dec_state, t)

  # take predicted EOS score for each hyp and compute normalized score
  eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

  def length_norm(t):
    t = tf.cast(t, fprop_dtype)
    alpha = length_norm_alpha
    tf.logging.info('length_norm.alpha=%r', alpha)
    return tf.math.pow((t + 5.) / 5., alpha)

  hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
  eos_score_norm = eos_score / length_norm(hyp_len)
  # update the n-best list
  nbest_hyps = update_nbest(nbest_hyps, (tgt_mask, hyp_score, eos_score_norm))

  if debug:
    tpu_summary.tensor('eos_score', eos_score)
    tpu_summary.tensor('hyp_len', hyp_len)

  # take top k tokens for each hyp
  k = beam_size
  with tf.name_scope('topk1'):
    top_score, top_id = top_k_fn(logits, k)
    top_score = tf.cast(top_score, fprop_dtype)
    top_score += tf.expand_dims(hyp_score, -1)
    top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)
    top_score = tf.reshape(top_score, [batch_size, beam_size * k])
    top_id = tf.reshape(top_id, [batch_size, beam_size * k])
    top_mask = tf.repeat(tgt_mask, beam_size, 1)

  if debug:
    tpu_summary.tensor('top_id', top_id)
    tpu_summary.tensor('top_score', top_score)
    # tpu_summary.tensor('top_mask', top_mask)

  with tf.name_scope('update_ext'):
    # combine top k tokens with extension buffer (if any)
    if ext_size:
      ext_id = tf.concat([ext_id, top_id], 1)
      ext_score = tf.concat([ext_score, top_score], 1)
      ext_mask = tf.concat([ext_mask, top_mask], 1)
    else:
      ext_id, ext_score, ext_mask = top_id, top_score, top_mask

    # sort by score
    ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
    i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
    ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
    ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

    # pick top beam_size extensions to evaluate at next iteration
    if ext_size:
      hyp_score = ext_score[:, :beam_size]
      ext_score = ext_score[:, beam_size:]
      tgt_id = ext_id[:, :beam_size]
      ext_id = ext_id[:, beam_size:]
      tgt_mask = ext_mask[:, :beam_size]
      ext_mask = ext_mask[:, beam_size:]
    else:
      hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
      ext_score = ext_id = ext_mask = 0

  tgt_pos = tf.reduce_sum(tgt_mask, -1)
  tgt_pos = tf.cast(tgt_pos, tf.int32)

  t += 1
  with tf.name_scope('tgt_mask_extend'):
    tgt_mask += tf.one_hot(
        tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

  ext = (ext_id, ext_score, ext_mask)
  hist = (h_tgt_id, h_tgt_pos)
  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)
  tf.logging.info('loop_vars=%r', loop_vars)
  tf.logging.info('dec_state=%r', dec_state)
  return loop_vars, dec_state