def _SampleGumbelWithMax(phi, target_max, batch_seed, time_step, src_ids,
                         src_paddings):
  """Samples a set of Gumbel noises with a specified maximum value.

  A set of values are sampled from Gumbel distributions with location
  parameters `phi` under the condition that their maximum is equal to
  `target_max`. The numerically stable implementation from Appendix B.3 of
  https://arxiv.org/pdf/1903.06059.pdf is used.

  Args:
    phi: A float tensor of shape [tgt_batch, k] that represents location
      parameters of Gumbel distributions.
    target_max: A float tensor of shape [tgt_batch, 1] that represents the
      target max values.
    batch_seed: An int tensor of shape [src_batch] that holds a seed value for
      each batch item. src_batch must be equal to tgt_batch /
      num_hyps_per_beam. The same seed is used within each consecutive
      num_hyps_per_beam items along the tgt_batch axis.
    time_step: A float tensor used as a secondary seed.
    src_ids: An int tensor of shape [src_batch, src_seq] that represents source
      IDs. Used for turning the random seed into a function of source IDs.
    src_paddings: A 0/1 float tensor of shape [src_batch, src_seq] where 1
      means that the corresponding element of src_ids is a padding.

  Returns:
    A float tensor like `phi` whose maximum values along the second axis are
    (almost) equal to `target_max`.
  """
  dtype = phi.dtype
  tgt_batch = tf.shape(phi)[0]
  k = tf.shape(phi)[1]
  src_batch = tf.shape(batch_seed)[0]
  num_hyps_per_beam = tgt_batch // src_batch

  # Sample noises from Gumbel distributions with location parameters `phi`.
  # shape: [src_batch, num_hyps_per_beam, k]
  gumbel_noises = _BatchSampleGumbel(batch_seed, time_step, src_ids,
                                     src_paddings, [num_hyps_per_beam, k],
                                     dtype)
  # shape: [num_hyps_per_beam, src_batch, k]
  gumbel_noises = tf.transpose(gumbel_noises, perm=[1, 0, 2])
  # shape: [tgt_batch, k]
  gumbel_noises = tf.reshape(gumbel_noises, tf.shape(phi))
  # shape: [tgt_batch, k]
  g_phi = phi + gumbel_noises
  # shape: [tgt_batch, 1]
  z = tf.reduce_max(g_phi, axis=1, keepdims=True)
  # Equation (23).
  # shape: [tgt_batch, k]
  v = target_max - g_phi + tf.math.log1p(
      # Without taking max, sometimes the result of log1p would become NaN on
      # TPU.
      tf.maximum(-tf.exp(g_phi - z), tf.constant(-1., dtype=dtype)))
  # Equation (24).
  return target_max - tf.nn.relu(v) - tf.math.log1p(tf.exp(-tf.abs(v)))
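# Illustrative sketch (not part of the function above): the same Appendix B.3
# shift applied to plain, unseeded Gumbel draws, so the conditioning on a
# target maximum can be checked in isolation. The helper name
# `_sample_gumbel_with_max_demo` and the unseeded noise are assumptions made
# for demonstration only; the real code derives its noise from _BatchSampleGumbel.
def _sample_gumbel_with_max_demo(phi, target_max):
  """Shifts Gumbel(phi) samples so that each row's maximum equals target_max."""
  dtype = phi.dtype
  # Unconditioned Gumbel(0, 1) noise added to the location parameters.
  uniform = tf.random.uniform(
      tf.shape(phi), minval=1e-20, maxval=1., dtype=dtype)
  g_phi = phi - tf.math.log(-tf.math.log(uniform))
  z = tf.reduce_max(g_phi, axis=1, keepdims=True)
  # Equations (23)-(24): numerically stable truncation to the target maximum.
  v = target_max - g_phi + tf.math.log1p(
      tf.maximum(-tf.exp(g_phi - z), tf.constant(-1., dtype=dtype)))
  return target_max - tf.nn.relu(v) - tf.math.log1p(tf.exp(-tf.abs(v)))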
def _KeepTopP(sorted_log_probs, p):
  """Keeps the top-p probability mass of `sorted_log_probs`.

  For each row, elements that are not included in the first `p` probability
  mass are set to `LARGE_NEGATIVE_NUMBER`. The first element is always kept
  as-is.

  Args:
    sorted_log_probs: A float tensor of shape [batch, k] that represents
      log-probabilities sorted in descending order. The probabilities do not
      need to sum to 1.
    p: A float tensor of shape [batch] that represents a probability threshold
      for each batch item.

  Returns:
    A tensor like `sorted_log_probs` where elements outside the top-p
    probability mass are set to `LARGE_NEGATIVE_NUMBER`.
  """
  sorted_cum_probs = tf.math.cumsum(
      tf.exp(sorted_log_probs), exclusive=True, axis=-1)
  mask = tf.less(sorted_cum_probs, tf.expand_dims(p, axis=1))
  # Set mask[:, 0] = True to always keep the first element.
  batch_size = tf.shape(mask)[0]
  true = tf.ones([batch_size, 1], dtype=tf.bool)
  mask = tf.concat([true, mask[:, 1:]], axis=1)
  filtered_sorted_log_probs = tf.where(
      mask, sorted_log_probs,
      tf.fill(
          tf.shape(sorted_log_probs),
          tf.constant(LARGE_NEGATIVE_NUMBER, dtype=sorted_log_probs.dtype)))
  return filtered_sorted_log_probs
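# Minimal usage sketch for the top-p filter above. The constant value -1e9
# standing in for LARGE_NEGATIVE_NUMBER and the literal inputs are assumptions
# for illustration; the real module defines its own constant.
def _keep_top_p_example():
  large_negative_number = -1e9
  # Two rows of log-probs already sorted in descending probability order.
  sorted_log_probs = tf.math.log(
      tf.constant([[0.5, 0.3, 0.15, 0.05],
                   [0.9, 0.05, 0.03, 0.02]], dtype=tf.float32))
  p = tf.constant([0.75, 0.5], dtype=tf.float32)
  sorted_cum_probs = tf.math.cumsum(
      tf.exp(sorted_log_probs), exclusive=True, axis=-1)
  mask = tf.less(sorted_cum_probs, tf.expand_dims(p, axis=1))
  # Always keep the first (highest-probability) element.
  mask = tf.concat(
      [tf.ones([tf.shape(mask)[0], 1], dtype=tf.bool), mask[:, 1:]], axis=1)
  # Row 0 keeps the first two entries (0.5 + 0.3 covers p=0.75); row 1 keeps
  # only the 0.9 entry.
  return tf.where(mask, sorted_log_probs,
                  tf.fill(tf.shape(sorted_log_probs), large_negative_number))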
def GetSequenceInfo(self, ids, enc_out):
  inp_ids = self._AddStartToken(ids)
  dummy_pred = self.decoder.ComputePredictions(
      self.theta.decoder, enc_out,
      py_utils.NestedMap({
          "ids": inp_ids,
          "paddings": self._GetPaddings(inp_ids),
          "weights": tf.ones_like(inp_ids, dtype=tf.float32),
      }))
  # Note that 'softmax_input' in dummy_pred is not the logits. To obtain the
  # actual logits we must pass the predictions through the loss layer, with
  # per_example_tensors enabled so the logits are returned.
  self.decoder.params.per_example_tensors = True
  _, per_example_tensors = self.GetDecoderLoss(self.theta, dummy_pred, inp_ids)
  mask = tf.transpose(1 - self._GetPaddings(ids))
  logits = per_example_tensors["logits"]
  log_p = tf.nn.log_softmax(logits)
  prob = tf.exp(log_p)
  entropy = -tf.reduce_sum(log_p * prob, axis=-1) * mask
  ave_entropy = tf.reduce_sum(entropy, axis=0) / tf.reduce_sum(mask, axis=0)
  return logits, ave_entropy, dummy_pred.attention["probs"]
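# Standalone sketch of the entropy computation used above, assuming time-major
# logits of shape [seq_len, batch, vocab] and a time-major 0/1 mask where 1
# marks a real (non-padded) position. The helper name is hypothetical.
def _average_entropy_demo(logits, mask):
  log_p = tf.nn.log_softmax(logits)
  prob = tf.exp(log_p)
  # Per-position entropy, zeroed out on padded positions.
  entropy = -tf.reduce_sum(log_p * prob, axis=-1) * mask
  # Average over the unpadded positions of each sequence.
  return tf.reduce_sum(entropy, axis=0) / tf.reduce_sum(mask, axis=0)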
def ApplyBias():
  """Bias and update log_probs and consistent."""

  def TileForBeamAndFlatten(tensor):
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tensor = tf.tile(
        tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
    return tf.reshape(tensor, [tgt_batch])

  # Consistent if step_ids == labels from previous step
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later.
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  local_consistence = tf.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  consistent = tf.logical_and(states.consistent, local_consistence)

  # Get label, weight slices corresponding to current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(consistent, p.dtype)

  # Convert from dense label to sparse label probs.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  uncertainty = tf.constant(
      1e-10, p.dtype)  # avoid 0 probs which may cause issues with log
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs.
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)
  return tf.log(probs), consistent
def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
  prev_sum_uvw = tf.stop_gradient((u + v + w) / eps)
  sum_uvw = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

  # Compute the relative changes on margins of P.
  # This will be used for stopping criteria.
  # Note the last update on w would guarantee the margin constraint c is
  # satisfied, so we don't need to check it here.
  p = tf.exp(tf.stop_gradient(score_ / new_eps + sum_uvw))
  p_a = tf.reduce_sum(p, axis=-1, keepdims=True)
  p_b = tf.reduce_sum(p, axis=-2, keepdims=True)
  delta_a = tf.abs(a - p_a) / (a + 1e-6)
  delta_b = tf.abs(b - p_b) / (b + 1e-6)
  new_delta = tf.reduce_max(delta_a)
  new_delta = tf.maximum(new_delta, tf.reduce_max(delta_b))

  # Compute the relative changes on assignment solution P.
  # This will be used for stopping criteria.
  delta_p = tf.abs(tf.exp(prev_sum_uvw) - tf.exp(sum_uvw)) / (
      tf.exp(sum_uvw) + 1e-6)
  new_delta = tf.maximum(new_delta, tf.reduce_max(delta_p))
  return new_delta
def ApplyBias():
  """Bias and update log_probs and consistent."""

  # Consistent if step_ids == labels from previous step
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later.
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  local_consistence = tf.math.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  consistent = tf.math.logical_and(states.consistent, local_consistence)

  # Get label, weight slices corresponding to current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(consistent, py_utils.FPropDtype(p))

  # Convert from dense label to sparse label probs.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  label_probs = tf.one_hot(
      label, vocab_size,
      dtype=py_utils.FPropDtype(p))  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs.
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)
  # Ensure that tf.math.log is applied to positive values.
  probs = tf.maximum(probs, tf.constant(1e-12, dtype=probs.dtype))
  return tf.math.log(probs), consistent
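# Minimal sketch of the biasing step above: interpolating predicted
# probabilities towards a one-hot label distribution with a per-example
# weight. The helper name, shapes, and literal epsilon are assumptions for
# illustration only.
def _bias_log_probs_demo(log_probs, label, weight):
  """log_probs: [batch, vocab]; label: [batch] int ids; weight: [batch] in [0, 1]."""
  vocab_size = tf.shape(log_probs)[1]
  label_probs = tf.one_hot(label, vocab_size, dtype=log_probs.dtype)
  pred_probs = tf.exp(log_probs)
  weight = tf.expand_dims(weight, 1)
  probs = (1.0 - weight) * pred_probs + weight * label_probs
  # Guard against log(0) when weight == 1 and the label distribution is
  # exactly one-hot.
  probs = tf.maximum(probs, tf.constant(1e-12, dtype=probs.dtype))
  return tf.math.log(probs)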
def Exp(x):
  return tf.exp(self.linear.Value(x))
def ResidualsToBBoxes(self,
                      anchor_bboxes,
                      residuals,
                      min_angle_rad=-np.pi,
                      max_angle_rad=np.pi):
  r"""Converts anchor_bboxes and residuals to predicted bboxes.

  This converts predicted residuals into bboxes using the following formulae::

    x_predicted = x_a + x_residual * diagonal_xy
    y_predicted = y_a + y_residual * diagonal_xy
    z_predicted = z_a + z_residual * dz_a

    dx_predicted = dx_a * exp(dx_residual)
    dy_predicted = dy_a * exp(dy_residual)
    dz_predicted = dz_a * exp(dz_residual)

    # Adding the residual, and bounding it between
    # [min_angle_rad, max_angle_rad]
    phi_predicted = NormalizeAngleRad(phi_a + phi_residual,
                                      min_angle_rad, max_angle_rad)

  These equations follow from those in LocalizationResiduals, where we solve
  for the \*_gt variables.

  Args:
    anchor_bboxes: tf.float32. where [..., :7] contains (x, y, z, dx, dy, dz,
      phi), corresponding to each anchor bbox parameters.
    residuals: tf.float32 of the same shape as anchor_bboxes containing
      predicted residuals at each anchor location.
    min_angle_rad: Scalar with the minimum angle allowed (before wrapping) in
      radians.
    max_angle_rad: Scalar with the maximum angle allowed (before wrapping) in
      radians. This value usually should be pi.

  Returns:
    A tf.float32 tensor of the same shape as anchor_bboxes with predicted
    bboxes.
  """
  anchor_bboxes_shape = py_utils.GetShape(anchor_bboxes)
  anchor_bboxes = py_utils.with_dependencies(
      [py_utils.assert_equal(anchor_bboxes_shape[-1], 7)], anchor_bboxes)
  residuals = py_utils.HasShape(residuals, anchor_bboxes_shape)

  x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a = tf.unstack(
      anchor_bboxes, num=7, axis=-1)
  (x_residual, y_residual, z_residual, dx_residual, dy_residual, dz_residual,
   phi_residual) = tf.unstack(residuals, num=7, axis=-1)

  diagonal_xy = tf.sqrt(tf.square(dx_a) + tf.square(dy_a))

  x_predicted = x_a + x_residual * diagonal_xy
  y_predicted = y_a + y_residual * diagonal_xy
  z_predicted = z_a + z_residual * dz_a

  dx_predicted = dx_a * tf.exp(dx_residual)
  dy_predicted = dy_a * tf.exp(dy_residual)
  dz_predicted = dz_a * tf.exp(dz_residual)

  # We bound the angle between [min_angle_rad, max_angle_rad], which should
  # be passed in depending on the heading handling in the calling model.
  # If the model uses a sine(delta_phi) transformation in the loss, it cannot
  # distinguish direction, so [min_angle_rad, max_angle_rad] should be set to
  # [0, np.pi].
  # If there is a heading encoding that is directional, most likely
  # [min_angle_rad, max_angle_rad] should be set to [-np.pi, np.pi].
  phi_predicted = phi_a + phi_residual
  phi_predicted = geometry.WrapAngleRad(phi_predicted, min_angle_rad,
                                        max_angle_rad)

  return tf.stack([
      x_predicted,
      y_predicted,
      z_predicted,
      dx_predicted,
      dy_predicted,
      dz_predicted,
      phi_predicted,
  ], axis=-1)  # pyformat: disable
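# Worked example for the decoding above on a single anchor, with hand-picked
# residuals (illustrative values only). Sizes are decoded multiplicatively via
# exp(), centers additively via the xy diagonal, and the heading would then be
# wrapped into [min_angle_rad, max_angle_rad).
def _residuals_to_bboxes_example():
  anchor = tf.constant([[10., 5., 1., 4., 2., 1.5, 0.]])   # x, y, z, dx, dy, dz, phi
  residual = tf.constant([[0.1, -0.2, 0.0, 0.3, 0.0, -0.1, 0.5]])
  x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a = tf.unstack(anchor, num=7, axis=-1)
  (x_r, y_r, z_r, dx_r, dy_r, dz_r, phi_r) = tf.unstack(
      residual, num=7, axis=-1)
  diagonal_xy = tf.sqrt(tf.square(dx_a) + tf.square(dy_a))  # sqrt(16 + 4) ~ 4.47
  x = x_a + x_r * diagonal_xy                               # ~ 10.45
  y = y_a + y_r * diagonal_xy                               # ~ 4.11
  z = z_a + z_r * dz_a                                      # 1.0
  dx = dx_a * tf.exp(dx_r)                                  # ~ 5.40
  dy = dy_a * tf.exp(dy_r)                                  # 2.0
  dz = dz_a * tf.exp(dz_r)                                  # ~ 1.36
  phi = phi_a + phi_r                                       # 0.5, already in [-pi, pi)
  return tf.stack([x, y, z, dx, dy, dz, phi], axis=-1)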
def Value(self):
  return tf.exp(self.linear.Value())
def max_assignment(score: tf.Tensor,
                   *,
                   elementwise_upper_bound: tf.Tensor,
                   row_sums: tf.Tensor,
                   col_sums: tf.Tensor,
                   epsilon: float = 0.1,
                   num_iterations: int = 50,
                   use_epsilon_scaling: bool = True):
  """Differentiable max assignment with margin and upper bound constraints.

  Args:
    score: a 3D tensor of size [batch_size, n_rows, n_columns]. score[i, j, k]
      denotes the weight if the assignment on this entry is non-zero.
    elementwise_upper_bound: a 3D tensor of size [batch_size, n_rows,
      n_columns]. Each entry denotes the maximum value assignment[i, j, k] can
      take and must be a non-negative value. For example,
      upper_bound[i, j, k]=1.0 for a binary assignment problem.
    row_sums: a 2D tensor of size [batch_size, n_rows]. The row sum constraint.
      The output assignment p[i, j, :] must sum to row_sums[i, j].
    col_sums: a 2D tensor of size [batch_size, n_columns]. The column sum
      constraint. The output assignment p[i, :, k] must sum to col_sums[i, k].
    epsilon: the epsilon coefficient of entropy regularization. The value
      should be within the range (0, 1]. `0.01` might work better than `0.1`.
      `0.1` may not make the assignment close enough to 0 or 1.
    num_iterations: the maximum number of iterations to perform.
    use_epsilon_scaling: whether to use epsilon scaling. In practice, the
      convergence of the iterative algorithm is much better if we start by
      solving the optimization with a larger epsilon value and re-use the
      solution (i.e. dual variables) for the instance with a smaller epsilon.
      This is called the epsilon scaling trick. See [Schmitzer 2019]
      (https://arxiv.org/pdf/1610.06519.pdf) as a reference. Here if
      use_epsilon_scaling=True, after each iteration we decrease the running
      epsilon by a constant factor until it reaches the target epsilon value.
      We found this to work well for gradient backward propagation, while the
      original scaling trick doesn't.

  Returns:
    A tuple with the following values.
      - assignment: a 3D tensor of size [batch_size, n_rows, n_columns]. The
        output assignment.
      - used_iter: a scalar tensor indicating the number of iterations used.
      - eps: a scalar tensor indicating the stopping epsilon value.
      - delta: a scalar tensor indicating the stopping delta value (the
        relative change on the margins of assignment p in the last iteration).
  """
  # Check if all shapes are correct.
  score_shape = score.shape
  bsz = score_shape[0]
  n = score_shape[1]
  m = score_shape[2]
  score = tf.ensure_shape(score, [bsz, n, m])
  elementwise_upper_bound = tf.ensure_shape(elementwise_upper_bound,
                                            [bsz, n, m])
  row_sums = tf.ensure_shape(tf.expand_dims(row_sums, axis=2), [bsz, n, 1])
  col_sums = tf.ensure_shape(tf.expand_dims(col_sums, axis=1), [bsz, 1, m])

  # The total sum of row sums must be equal to the total sum of column sums.
  sum_diff = tf.reduce_sum(row_sums, axis=1) - tf.reduce_sum(col_sums, axis=2)
  sum_diff = tf.abs(sum_diff)
  tf.Assert(tf.reduce_all(sum_diff < 1e-6), [sum_diff])

  # Convert the upper_bound constraint into another margin constraint
  # by adding auxiliary variables & scores. Tensors `a`, `b` and `c`
  # represent the margins (i.e. reduced sums) of the 3 axes respectively.
  max_row_sums = tf.reduce_sum(elementwise_upper_bound, axis=-1, keepdims=True)
  max_col_sums = tf.reduce_sum(elementwise_upper_bound, axis=-2, keepdims=True)
  score_ = tf.stack([score, tf.zeros_like(score)], axis=1)  # (bsz, 2, n, m)
  a = tf.stack([row_sums, max_row_sums - row_sums], axis=1)  # (bsz, 2, n, 1)
  b = tf.stack([col_sums, max_col_sums - col_sums], axis=1)  # (bsz, 2, 1, m)
  c = tf.expand_dims(elementwise_upper_bound, axis=1)  # (bsz, 1, n, m)

  # Clip log(0) to a large negative value -1e+36 to avoid getting inf or NaN
  # values in computation. Cannot use larger values because float32 would
  # turn them into `-inf` automatically.
  tf.Assert(tf.reduce_all(a >= 0), [a])
  tf.Assert(tf.reduce_all(b >= 0), [b])
  tf.Assert(tf.reduce_all(c >= 0), [c])
  log_a = tf.maximum(tf.math.log(a), -1e+36)
  log_b = tf.maximum(tf.math.log(b), -1e+36)
  log_c = tf.maximum(tf.math.log(c), -1e+36)

  # Initialize the dual variables of the margin constraints.
  u = tf.zeros_like(a)
  v = tf.zeros_like(b)
  w = tf.zeros_like(c)

  eps = tf.constant(1.0 if use_epsilon_scaling else epsilon, dtype=score.dtype)
  epsilon = tf.constant(epsilon, dtype=score.dtype)

  def do_updates(cur_iter, eps, u, v, w):  # pylint: disable=unused-argument
    # Epsilon scaling, i.e. gradually decreasing `eps` until it reaches the
    # target `epsilon` value.
    cur_iter = tf.cast(cur_iter, u.dtype)
    scaling = tf.minimum(0.6 * 1.04**cur_iter, 0.85)
    eps = tf.maximum(epsilon, eps * scaling)
    score_div_eps = score_ / eps

    # Update u.
    log_q_1 = score_div_eps + (w + v) / eps
    log_q_1 = tf.reduce_logsumexp(log_q_1, axis=-1, keepdims=True)
    new_u = (log_a - tf.maximum(log_q_1, -1e+30)) * eps

    # Update v.
    log_q_2 = score_div_eps + (w + new_u) / eps
    log_q_2 = tf.reduce_logsumexp(log_q_2, axis=-2, keepdims=True)
    new_v = (log_b - tf.maximum(log_q_2, -1e+30)) * eps

    # Update w.
    log_q_3 = score_div_eps + (new_u + new_v) / eps
    log_q_3 = tf.reduce_logsumexp(log_q_3, axis=-3, keepdims=True)
    new_w = (log_c - tf.maximum(log_q_3, -1e+30)) * eps

    return eps, new_u, new_v, new_w

  def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
    prev_sum_uvw = tf.stop_gradient((u + v + w) / eps)
    sum_uvw = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

    # Compute the relative changes on margins of P.
    # This will be used for stopping criteria.
    # Note the last update on w would guarantee the margin constraint c is
    # satisfied, so we don't need to check it here.
    p = tf.exp(tf.stop_gradient(score_ / new_eps + sum_uvw))
    p_a = tf.reduce_sum(p, axis=-1, keepdims=True)
    p_b = tf.reduce_sum(p, axis=-2, keepdims=True)
    delta_a = tf.abs(a - p_a) / (a + 1e-6)
    delta_b = tf.abs(b - p_b) / (b + 1e-6)
    new_delta = tf.reduce_max(delta_a)
    new_delta = tf.maximum(new_delta, tf.reduce_max(delta_b))

    # Compute the relative changes on assignment solution P.
    # This will be used for stopping criteria.
    delta_p = tf.abs(tf.exp(prev_sum_uvw) - tf.exp(sum_uvw)) / (
        tf.exp(sum_uvw) + 1e-6)
    new_delta = tf.maximum(new_delta, tf.reduce_max(delta_p))
    return new_delta

  for cur_iter in tf.range(num_iterations):
    prev_eps, prev_u, prev_v, prev_w = eps, u, v, w
    eps, u, v, w = do_updates(cur_iter, eps, u, v, w)
    delta = compute_relative_changes(prev_eps, prev_u, prev_v, prev_w, eps, u,
                                     v, w)
  cur_iter = num_iterations

  assignment = tf.exp((score_ + u + v + w) / eps)
  assignment = assignment[:, 0]
  return assignment, cur_iter, eps, delta
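# Hedged usage sketch for max_assignment on a 1x3x3 binary assignment problem.
# The literal score matrix is illustrative only; with unit row/column sums and
# unit upper bounds, the output approaches a permutation matrix as epsilon
# shrinks.
def _max_assignment_example():
  score = tf.constant([[[3., 1., 0.],
                        [1., 3., 1.],
                        [0., 1., 3.]]])
  ones = tf.ones_like(score)
  assignment, used_iter, eps, delta = max_assignment(
      score,
      elementwise_upper_bound=ones,
      row_sums=tf.ones([1, 3]),
      col_sums=tf.ones([1, 3]),
      epsilon=0.01,
      num_iterations=200)
  return assignment, used_iter, eps, delta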
def Value(self, step=None):
  return tf.exp(self.linear.Value(step))
def ResidualsToBBoxes(self, anchor_bboxes, residuals):
  r"""Converts anchor_bboxes and residuals to predicted bboxes.

  This converts predicted residuals into bboxes using the following formulae:

    x_predicted = x_a + x_residual \* diagonal_xy
    y_predicted = y_a + y_residual \* diagonal_xy
    z_predicted = z_a + z_residual \* dz_a

    dx_predicted = dx_a \* exp(dx_residual)
    dy_predicted = dy_a \* exp(dy_residual)
    dz_predicted = dz_a \* exp(dz_residual)

    phi_predicted = phi_a + phi_residual

  These equations follow from those in LocalizationResiduals, where we solve
  for the \*_gt variables.

  Args:
    anchor_bboxes: tf.float32. where [..., :7] contains (x, y, z, dx, dy, dz,
      phi), corresponding to each anchor bbox parameters.
    residuals: tf.float32 of the same shape as anchor_bboxes containing
      predicted residuals at each anchor location.

  Returns:
    A tf.float32 tensor of the same shape as anchor_bboxes with predicted
    bboxes.
  """
  anchor_bboxes_shape = py_utils.GetShape(anchor_bboxes)
  anchor_bboxes = py_utils.with_dependencies(
      [py_utils.assert_equal(anchor_bboxes_shape[-1], 7)], anchor_bboxes)
  residuals = py_utils.HasShape(residuals, anchor_bboxes_shape)

  x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a = tf.unstack(
      anchor_bboxes, num=7, axis=-1)
  (x_residual, y_residual, z_residual, dx_residual, dy_residual, dz_residual,
   phi_residual) = tf.unstack(residuals, num=7, axis=-1)

  diagonal_xy = tf.sqrt(tf.square(dx_a) + tf.square(dy_a))

  x_predicted = x_a + x_residual * diagonal_xy
  y_predicted = y_a + y_residual * diagonal_xy
  z_predicted = z_a + z_residual * dz_a

  dx_predicted = dx_a * tf.exp(dx_residual)
  dy_predicted = dy_a * tf.exp(dy_residual)
  dz_predicted = dz_a * tf.exp(dz_residual)

  # Assuming a sine(delta_phi) transformation is used in the loss, it is not
  # possible to distinguish direction, hence we use floormod here to ensure
  # that phi_predicted is always in [0, np.pi) for consistency.
  # A separate direction classifier should be added to the model if needed.
  phi_predicted = phi_a + phi_residual
  phi_predicted = tf.floormod(phi_predicted, np.pi)

  return tf.stack([
      x_predicted,
      y_predicted,
      z_predicted,
      dx_predicted,
      dy_predicted,
      dz_predicted,
      phi_predicted,
  ], axis=-1)  # pyformat: disable
def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                              num_hyps_per_beam, *args, **kwargs):
  """Wrapper for adding bias to _PreBeamSearchStepCallback.

  Biases results.log_probs towards provided encoder_outputs.targets.

  Args:
    theta: a NestedMap of parameters.
    encoder_outputs: a NestedMap computed by encoder.
    step_ids: A tensor of shape [tgt_batch, 1].
    states: A `.NestedMap` of tensors representing states that the clients
      would like to keep track of for each of the active hyps.
    num_hyps_per_beam: Beam size.
    *args: additional arguments to _PreBeamSearchStepCallback.
    **kwargs: additional arguments to _PreBeamSearchStepCallback.

  Returns:
    A tuple (results, out_states).
    results: A `.NestedMap` of beam search results.
      atten_probs: The updated attention probs, of shape [tgt_batch, src_len].
      log_probs: Log prob for each of the tokens in the target vocab. This is
        of shape [tgt_batch, vocab_size].
    out_states: a `.NestedMap` of the updated states. The states relevant here
      are:
      time_step: A scalar indicating current step of decoder. Must be provided
        and maintained by subclass.
      consistent: A boolean vector of shape [tgt_batch, ] which tracks whether
        each hypothesis has exactly matched encoder_outputs.targets so far.
  """
  p = self.params
  time_step = states.time_step
  bs_results, out_states = self._PreBeamSearchStepCallback(
      theta, encoder_outputs, step_ids, states, num_hyps_per_beam, *args,
      **kwargs)
  labels = encoder_outputs.targets.labels
  weights = encoder_outputs.targets.weights

  def TileForBeamAndFlatten(tensor):
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tensor = tf.tile(
        tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
    return tf.reshape(tensor, [tgt_batch])

  # Consistent if step_ids == labels from previous step
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later.
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  local_consistence = tf.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  out_states.consistent = tf.logical_and(states.consistent, local_consistence)

  # Get label, weight slices corresponding to current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(out_states.consistent, p.dtype)

  # Convert from dense label to sparse label probs.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  uncertainty = tf.constant(
      1e-10, p.dtype)  # avoid 0 probs which may cause issues with log
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs.
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)
  bs_results.log_probs = tf.log(probs)

  return bs_results, out_states