def _Upd(c, x):
  """Folds finiteness checks on `x` into the running condition `c`.

  Args:
    c: a boolean tensor holding the condition accumulated so far.
    x: a tensor to be checked.

  Returns:
    `c` unchanged when `self._cond_is_finite` is falsy. Otherwise `c` AND-ed
    with "every element of x is finite" and "no element of x is inf".
  """
  if self._cond_is_finite:
    all_finite = tf.reduce_all(tf.math.is_finite(x))
    none_inf = tf.reduce_all(tf.math.logical_not(tf.math.is_inf(x)))
    c = tf.math.logical_and(tf.math.logical_and(c, all_finite), none_inf)
  return c
def IsWithinBBox(points, bbox):
  """Checks if points are within a 2-d bbox.

  The function returns true if points are strictly inside the box. It also
  returns true when the points are exactly on the box edges. A box whose
  `_BBoxArea` is not greater than zero contains no points.

  Args:
    points: a float Tensor of shape [..., 2] of points to be tested. The last
      coordinates are (x, y).
    bbox: a float Tensor of shape [..., 4, 2] of bboxes. The last coordinates
      are the four corners of the bbox and (x, y). The corners are assumed to
      be given in counter-clockwise order; this is asserted at run time.

  Returns:
    Tensor: If ``pshape = tf.shape(points)[:-1]`` and
    ``bshape = tf.shape(bbox)[:-2]``, returns a boolean tensor of shape
    ``tf.concat([pshape, bshape], axis=0)``, where each element is true if the
    point is inside the corresponding box. If a point falls exactly on an edge
    of the bbox, it is also true.
  """
  bshape = py_utils.GetShape(bbox)[:-2]
  pshape = py_utils.GetShape(points)[:-1]
  bbox = py_utils.HasShape(bbox, tf.concat([bshape, [4, 2]], axis=0))
  points = py_utils.HasShape(points, tf.concat([pshape, [2]], axis=0))
  # Enumerate all 4 corners; consecutive pairs form the 4 edges.
  v1, v2, v3, v4 = (bbox[..., 0, :], bbox[..., 1, :], bbox[..., 2, :],
                    bbox[..., 3, :])
  # Every consecutive corner triple must turn counter-clockwise.
  v1v2v3_check = tf.reduce_all(_IsCounterClockwiseDirection(v1, v2, v3))
  v2v3v4_check = tf.reduce_all(_IsCounterClockwiseDirection(v2, v3, v4))
  v4v1v2_check = tf.reduce_all(_IsCounterClockwiseDirection(v4, v1, v2))
  v3v4v1_check = tf.reduce_all(_IsCounterClockwiseDirection(v3, v4, v1))
  # The data attached to each assertion is exactly the corner triple that the
  # corresponding check examined, so a failure reports the offending corners.
  with tf.control_dependencies([
      py_utils.Assert(v1v2v3_check, [v1, v2, v3]),
      py_utils.Assert(v2v3v4_check, [v2, v3, v4]),
      py_utils.Assert(v4v1v2_check, [v4, v1, v2]),
      py_utils.Assert(v3v4v1_check, [v3, v4, v1])
  ]):
    # A point is inside when it is on the left of (or on) all four edges.
    is_inside = tf.math.logical_and(
        tf.math.logical_and(
            _IsOnLeftHandSideOrOn(points, v1, v2),
            _IsOnLeftHandSideOrOn(points, v2, v3)),
        tf.math.logical_and(
            _IsOnLeftHandSideOrOn(points, v3, v4),
            _IsOnLeftHandSideOrOn(points, v4, v1)))
  has_non_zero_area = tf.greater(_BBoxArea(bbox), 0)
  is_inside = tf.logical_and(tf.cast(is_inside, tf.bool), has_non_zero_area)
  # Swap the last two dimensions. The einsum runs on an int32 copy and the
  # result is cast back to bool.
  is_inside = tf.einsum('...ij->...ji', tf.cast(is_inside, tf.int32))
  return tf.cast(is_inside, tf.bool)
def IsWithinBBox(points, bbox):
  """Checks if points are within a 2-d bbox.

  The function returns true if points are strictly inside the box. It also
  returns true when the points are exactly on the box edges.

  Args:
    points: a float Tensor of shape [..., 2] of points to be tested. The last
      coordinates are (x, y).
    bbox: a float Tensor of shape [..., 4, 2] of bboxes. The last coordinates
      are the four corners of the bbox and (x, y). The corners are assumed to
      be given in counter-clockwise order; this is asserted at run time.

  Returns:
    If pshape = tf.shape(points)[:-1] and bshape = tf.shape(bbox)[:-2], a
    tensor of shape tf.concat([pshape, bshape], axis=0), of booleans, where
    each element is true if the point is inside the corresponding box. If a
    point falls exactly on an edge of the bbox, it is also true.
  """
  bshape = py_utils.GetShape(bbox)[:-2]
  pshape = py_utils.GetShape(points)[:-1]
  bbox = py_utils.HasShape(bbox, bshape + [4, 2])
  points = py_utils.HasShape(points, pshape + [2])
  # Enumerate all 4 corners; consecutive pairs form the 4 edges.
  v1, v2, v3, v4 = (bbox[..., 0, :], bbox[..., 1, :], bbox[..., 2, :],
                    bbox[..., 3, :])
  # Every consecutive corner triple must turn counter-clockwise.
  v1v2v3_check = tf.reduce_all(_IsCounterClockwiseDirection(v1, v2, v3))
  v2v3v4_check = tf.reduce_all(_IsCounterClockwiseDirection(v2, v3, v4))
  v4v1v2_check = tf.reduce_all(_IsCounterClockwiseDirection(v4, v1, v2))
  v3v4v1_check = tf.reduce_all(_IsCounterClockwiseDirection(v3, v4, v1))
  # The data attached to each assertion is exactly the corner triple that the
  # corresponding check examined, so a failure reports the offending corners.
  with tf.control_dependencies([
      py_utils.Assert(v1v2v3_check, [v1, v2, v3]),
      py_utils.Assert(v2v3v4_check, [v2, v3, v4]),
      py_utils.Assert(v4v1v2_check, [v4, v1, v2]),
      py_utils.Assert(v3v4v1_check, [v3, v4, v1])
  ]):
    # A point is inside when it is on the left of (or on) all four edges.
    is_inside = tf.logical_and(
        tf.logical_and(
            _IsOnLeftHandSideOrOn(points, v1, v2),
            _IsOnLeftHandSideOrOn(points, v2, v3)),
        tf.logical_and(
            _IsOnLeftHandSideOrOn(points, v3, v4),
            _IsOnLeftHandSideOrOn(points, v4, v1)))
  return is_inside
def get_accuracy(self, loss, pred, target):
  """Writes per-example and per-character accuracy entries into `loss`.

  Positions where `target` equals the configured pad value are masked: the
  prediction there is multiplied by zero, and those positions do not count
  towards the per-character accuracy.

  Args:
    loss: a mutable mapping. The keys "accuracy_per_example" and
      "accuracy_per_char" are set to `(value, p.input.batch_size)` tuples.
    pred: predicted ids; its dtype is used for the target and the mask.
    target: reference ids; cast to the dtype of `pred`.
  """
  p = self.params
  pad_id = int(p.input.feature_neighborhood_input.batch_opts.pad_value)
  int_dtype = pred.dtype
  target = tf.cast(target, int_dtype)

  # 1 at real positions, 0 at padded positions.
  mask = tf.cast(tf.math.not_equal(target, pad_id), int_dtype)
  pred *= mask
  hits = tf.math.equal(pred, target)

  # An example counts as correct only if every position along axis 1 matches.
  example_hits = tf.cast(tf.reduce_all(hits, axis=1), tf.float32)
  loss["accuracy_per_example"] = (tf.reduce_mean(example_hits),
                                  p.input.batch_size)

  # Per-character accuracy is the matching unmasked positions divided by the
  # number of unmasked positions.
  char_hits = tf.cast(hits, tf.float32)
  char_hits *= tf.cast(mask, tf.float32)
  num_non_zero = tf.cast(tf.reduce_sum(mask), tf.float32)
  loss["accuracy_per_char"] = (tf.reduce_sum(char_hits) / num_non_zero,
                               p.input.batch_size)
def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                              num_hyps_per_beam, *args, **kwargs):
  """Wrapper for adding bias to _PreBeamSearchStepCallback.

  Biases results.log_probs towards provided encoder_outputs.targets.

  Args:
    theta: a NestedMap of parameters.
    encoder_outputs: a NestedMap computed by encoder.
    step_ids: A tensor of shape [tgt_batch, 1].
    states: A `.NestedMap` of tensors representing states that the clients
      would like to keep track of for each of the active hyps.
    num_hyps_per_beam: Beam size.
    *args: additional arguments to _PreBeamSearchStepCallback.
    **kwargs: additional arguments to _PreBeamSearchStepCallback.

  Returns:
    A tuple (results, out_states).

    results: A `.NestedMap` of beam search results.
      atten_probs: The updated attention probs, of shape
        [tgt_batch, src_len].
      log_probs: Log prob for each of the tokens in the target vocab. This is
        of shape [tgt_batch, vocab_size].
    out_states: a `.NestedMap` The updated states. The states relevant here
      are:
      time_step: A scalar indicating current step of decoder. Must be
        provided and maintained by subclass.
      consistent: A boolean vector of shape [tgt_batch, ] which tracks whether
        each hypothesis has exactly matched encoder_outputs.targets so far.
  """
  p = self.params
  cur_step = states.time_step
  bs_results, out_states = self._PreBeamSearchStepCallback(
      theta, encoder_outputs, step_ids, states, num_hyps_per_beam, *args,
      **kwargs)
  target_labels = encoder_outputs.targets.labels
  target_weights = encoder_outputs.targets.weights

  def _Biased():
    """Returns biased log_probs and the updated consistent vector."""

    def _SliceForHyps(per_step, step):
      """Takes column `step`, tiles it per beam and flattens to [tgt_batch]."""
      column = tf.gather(per_step, step, axis=1)
      row = tf.reshape(column, [1, -1])  # [1, src_batch]
      # [num_hyps_per_beam, src_batch]
      tiled = tf.tile(row, [num_hyps_per_beam, 1])
      tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
      return tf.reshape(tiled, [tgt_batch])

    # A hyp stays consistent if step_ids equals the label of the previous
    # step.
    # TODO(navari): Consider updating consistent only if weights > 0. Then
    # re-evaluate the need for bias_only_if_consistent=True.
    # The previous label is meaningless at step 0; the step-0 test below
    # overrides the comparison in that case.
    prev_label = _SliceForHyps(target_labels, tf.maximum(cur_step - 1, 0))
    at_first_step = tf.equal(cur_step, 0)
    matches_prev = tf.logical_or(
        at_first_step, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
    still_consistent = tf.logical_and(states.consistent, matches_prev)

    # Label and weight slices for the current time step.
    cur_label = _SliceForHyps(target_labels, cur_step)
    cur_weight = _SliceForHyps(target_weights, cur_step)
    if p.bias_only_if_consistent:
      cur_weight = cur_weight * tf.cast(still_consistent, p.dtype)

    # Turn the label ids into a smoothed one-hot distribution. The small
    # uncertainty keeps every probability above zero so the log is safe.
    vocab_size = tf.shape(bs_results.log_probs)[1]
    uncertainty = tf.constant(1e-10, p.dtype)
    label_probs = tf.one_hot(
        cur_label,
        vocab_size,
        on_value=1 - uncertainty,
        off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
        dtype=p.dtype)  # [tgt_batch, vocab_size]
    model_probs = tf.exp(bs_results.log_probs)

    # Interpolate between the model distribution and the label distribution;
    # the weight is asserted to lie in [0, 1].
    cur_weight = tf.expand_dims(cur_weight, 1)
    range_checks = [
        py_utils.assert_less_equal(cur_weight, 1.),
        py_utils.assert_greater_equal(cur_weight, 0.)
    ]
    mixed = (1.0 - cur_weight) * model_probs + cur_weight * label_probs
    mixed = py_utils.with_dependencies(range_checks, mixed)
    return tf.log(mixed), still_consistent

  def _Unbiased():
    """No-op branch: passes the original log_probs and consistent through."""
    return bs_results.log_probs, states.consistent

  # Skip the biasing computation entirely when every target weight is zero.
  all_weights_zero = tf.reduce_all(tf.equal(target_weights, 0.0))
  new_log_probs, new_consistent = tf.cond(all_weights_zero, _Unbiased,
                                          _Biased)
  bs_results.log_probs = new_log_probs
  out_states.consistent = new_consistent
  return bs_results, out_states
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].
      - segment_ids, segment_pos: read only when p.packed_input is set.

  Returns:
    A NestedMap containing

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set in
      transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [time, batch, depth] embedded inputs tokens without
      positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    # ids and paddings must have the same shape, and ids must be rank 2.
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    if (not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs):
      # Truncate the time dimension to the longest unpadded sequence in the
      # batch (only off-TPU and when the flag is set).
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      # Everything beyond max_seq_length must be padding.
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    # With shared_emb the softmax layer doubles as the embedding table.
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    # The lookup was done on flattened ids; restore [batch, time, dim].
    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [time, batch, dim]
    # Captured before positional/task embeddings are added.
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      # Add a leading dim of 1 so the same positions broadcast over batch.
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.task_emb:
      input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                            input_batch.task_ids)

    # Project to the model dimension only when it differs from the embedding
    # dimension.
    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    # Switch paddings (and segment ids) to time-major.
    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

  if not self.do_eval and p.apply_source_mask:
    # Augment padding for masked source word positions.
    dtype = paddings.dtype
    source_mask = tf.where(
        tf.equal(input_ids, p.source_mask_id),
        tf.ones_like(input_ids, dtype=dtype),
        tf.zeros_like(input_ids, dtype=dtype))
    # Make sure padding is between 0 and 1.
    paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                1.0)

  encoded, padding, segment_id = self.transformer_stack.FProp(
      theta.transformer_stack, transformer_input, paddings, src_segment_id)
  return py_utils.NestedMap(
      encoded=encoded,
      padding=padding,
      segment_id=segment_id,
      embedded_inputs=orig_input_embs)
def _Upd(c, k, x):
  """Records `x` under key `k` and queues a finiteness check for it.

  Both `stats` and `is_finite_checks` come from the enclosing scope.

  Args:
    c: a value that is passed through untouched.
    k: the key under which `x` is stored in `stats`.
    x: the tensor to record and check.

  Returns:
    `c`, unmodified.
  """
  stats[k] = x
  x_is_finite = tf.reduce_all(tf.math.is_finite(x))
  is_finite_checks.append(x_is_finite)
  return c
def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids, unused_hyp_lens,
                 done_hyps, unused_other_states_list):
  """Loop predicate: true while steps remain and some hyp is not done.

  `max_steps` is taken from the enclosing scope.

  Args:
    cur_step: the current step counter.
    unused_step_ids: ignored.
    unused_hyp_ids: ignored.
    unused_hyp_lens: ignored.
    done_hyps: a boolean tensor; the predicate is false once all are true.
    unused_other_states_list: ignored.

  Returns:
    A boolean tensor: `cur_step < max_steps` AND NOT all(done_hyps).
  """
  under_step_limit = cur_step < max_steps
  some_hyp_pending = tf.logical_not(tf.reduce_all(done_hyps))
  return tf.logical_and(under_step_limit, some_hyp_pending)
def Callback(theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
             *args, **kwargs):
  """Pre-step beam search callback with optional biasing and perturbation.

  Calls `self._PreBeamSearchStepCallback` and then post-processes the returned
  `log_probs` in up to two stages, selected by the enclosing-scope flags
  `biased` and `stochastic`. `self` and `k` (the top-k size) are also taken
  from the enclosing scope.

  Args:
    theta: forwarded to `_PreBeamSearchStepCallback`.
    encoder_outputs: forwarded to `_PreBeamSearchStepCallback`. When `biased`,
      `targets.labels` and `targets.weights` are read; when `stochastic`, the
      fields under `stochastic_beam_search` are read.
    step_ids: forwarded; also squeezed on axis 1 and used for tgt_batch.
    states: forwarded; `time_step` is always read, `consistent` when `biased`,
      and `tmp_states`, `cumulative_log_probs` and
      `perturbed_cumulative_log_probs` when `stochastic`.
    num_hyps_per_beam: forwarded; also the tiling factor per source example.
    *args: forwarded to `_PreBeamSearchStepCallback`.
    **kwargs: forwarded to `_PreBeamSearchStepCallback`.

  Returns:
    A tuple (bs_results, out_states) as returned by
    `_PreBeamSearchStepCallback`, with `bs_results.log_probs` and the
    relevant `out_states` fields updated.
  """
  p = self.params
  time_step = states.time_step
  bs_results, out_states = self._PreBeamSearchStepCallback(
      theta, encoder_outputs, step_ids, states, num_hyps_per_beam, *args,
      **kwargs)

  def TileForBeamAndFlatten(tensor):
    """Tiles a per-source tensor once per hyp and flattens to [tgt_batch]."""
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tensor = tf.tile(tensor,
                     [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
    return tf.reshape(tensor, [tgt_batch])

  if biased:
    labels = encoder_outputs.targets.labels
    weights = encoder_outputs.targets.weights

    def ApplyBias():
      """Bias and update log_probs and consistent."""

      # Consistent if step_ids == labels from previous step
      # TODO(navari): Consider updating consistent only if weights > 0. Then
      # re-evaluate the need for bias_only_if_consistent=True.
      # Note that prev_label is incorrect for step 0 but is overridden
      # later
      prev_label = TileForBeamAndFlatten(
          tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
      is_step0 = tf.equal(time_step, 0)
      local_consistence = tf.math.logical_or(
          is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
      consistent = tf.math.logical_and(states.consistent, local_consistence)

      # get label, weight slices corresponding to current time_step
      label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
      weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
      if p.bias_only_if_consistent:
        weight = weight * tf.cast(consistent, py_utils.FPropDtype(p))

      # convert from dense label to sparse label probs
      vocab_size = tf.shape(bs_results.log_probs)[1]
      label_probs = tf.one_hot(
          label, vocab_size,
          dtype=py_utils.FPropDtype(p))  # [tgt_batch, vocab_size]
      pred_probs = tf.exp(bs_results.log_probs)

      # interpolate predicted probs and label probs
      weight = tf.expand_dims(weight, 1)
      probs = py_utils.with_dependencies([
          py_utils.assert_less_equal(weight, 1.),
          py_utils.assert_greater_equal(weight, 0.)
      ], (1.0 - weight) * pred_probs + weight * label_probs)
      # Ensure that tf.math.log is applied to positive values.
      probs = tf.maximum(probs, tf.constant(1e-12, dtype=probs.dtype))
      return tf.math.log(probs), consistent

    def NoApplyBias():
      """No-op. Return original log_probs and consistent."""
      return bs_results.log_probs, states.consistent

    # Biasing is skipped entirely when every target weight is zero.
    log_probs, consistent = tf.cond(
        tf.reduce_all(tf.equal(weights, 0.0)), NoApplyBias, ApplyBias)
    bs_results.log_probs = log_probs
    out_states.consistent = consistent

  if stochastic:
    # Picks up the (possibly biased) log_probs produced above.
    log_probs = bs_results.log_probs

    def PerturbedLogProbs():
      """Returns perturbed per-step log_probs and the temporary states."""
      # STEP 1: Perform top-k filtering. This is done as a performance
      # optimization of avoiding sorting the entire `log_probs`, which is
      # prohibitively slow.
      top_k = tf.math.top_k(log_probs, k, sorted=True)
      # shape: [tgt_batch, k]
      top_k_log_probs = top_k.values
      # shape: [tgt_batch, k]
      top_k_ids = top_k.indices

      # STEP 2: Perform top-p filtering.
      # shape: [tgt_batch]
      top_p_threshold = encoder_outputs.stochastic_beam_search.top_p_threshold
      top_p_threshold = tf.clip_by_value(top_p_threshold, 0., 1.)
      top_p_threshold = TileForBeamAndFlatten(top_p_threshold)
      # shape: [tgt_batch, k]
      filtered_top_k_log_probs = _KeepTopP(top_k_log_probs, top_p_threshold)

      # STEP 3: Perturb cumulative log-probs.
      # shape: [tgt_batch, 1]
      last_cumulative_log_probs = states.cumulative_log_probs
      # shape: [tgt_batch, 1]
      last_perturbed_cumulative_log_probs = states.perturbed_cumulative_log_probs
      # Compute cumulative log-probs of the current step.
      # shape: [tgt_batch, k]
      cumulative_log_probs = (
          last_cumulative_log_probs + filtered_top_k_log_probs)
      # Perturb cumulative log-probs by Gumbel noises under the condition
      # that the max of the new perturbed log-probs is equal to
      # perturbed_cumulative_log_probs of the previous step.
      # shape: [tgt_batch, k]
      new_perturbed_cumulative_log_probs = _SampleGumbelWithMax(
          cumulative_log_probs, last_perturbed_cumulative_log_probs,
          encoder_outputs.stochastic_beam_search.seed, time_step,
          encoder_outputs.stochastic_beam_search.src_ids,
          encoder_outputs.stochastic_beam_search.src_paddings)

      # STEP 4: Compute updated log_probs. This step is necessary because
      # the output of PreBeamSearchStepCallback must be "per-step"
      # log-probs, whereas so far "cumulative" log-probs have been computed.
      # shape: [tgt_batch, k]
      updated_top_k_log_probs = (
          new_perturbed_cumulative_log_probs -
          last_perturbed_cumulative_log_probs)
      # Convert to the shape [tgt_batch, vocab_size].
      # Entries outside the top-k keep LARGE_NEGATIVE_NUMBER.
      updated_log_probs = tf.fill(
          tf.shape(log_probs),
          tf.constant(LARGE_NEGATIVE_NUMBER, dtype=log_probs.dtype))
      updated_log_probs = _BatchScatter(updated_log_probs, top_k_ids,
                                        updated_top_k_log_probs)

      return (updated_log_probs,
              py_utils.NestedMap(
                  new_perturbed_cumulative_log_probs=
                  new_perturbed_cumulative_log_probs,
                  top_k_log_probs=top_k_log_probs,
                  top_k_ids=top_k_ids,
              ))

    # Perturbation is applied only when the `enable` tensor is true at run
    # time; otherwise log_probs and tmp_states pass through unchanged.
    (bs_results.log_probs, out_states.tmp_states) = tf.cond(
        encoder_outputs.stochastic_beam_search.enable,
        PerturbedLogProbs,
        # No-op.
        lambda: (bs_results.log_probs, states.tmp_states))
    # These states are not updated here but will be updated in
    # PostBeamSearchStepCallback since doing so requires the knowledge of
    # the next step IDs.
    out_states.cumulative_log_probs = states.cumulative_log_probs
    out_states.perturbed_cumulative_log_probs = states.perturbed_cumulative_log_probs

  return bs_results, out_states
def StopFn(recurrent_theta, state, inputs):
  """Returns true unless every entry of `state.ids` equals the EOS id.

  `p` is taken from the enclosing scope.

  Args:
    recurrent_theta: ignored.
    state: an object whose `ids` field is compared with `p.target_eos_id`.
    inputs: ignored.

  Returns:
    A boolean scalar tensor: NOT all(state.ids == p.target_eos_id).
  """
  del recurrent_theta, inputs  # Unused.
  all_at_eos = tf.reduce_all(tf.equal(state.ids, p.target_eos_id))
  return tf.logical_not(all_at_eos)
def max_assignment(score: tf.Tensor,
                   *,
                   elementwise_upper_bound: tf.Tensor,
                   row_sums: tf.Tensor,
                   col_sums: tf.Tensor,
                   epsilon: float = 0.1,
                   num_iterations: int = 50,
                   use_epsilon_scaling: bool = True):
  """Differentiable max assignment with margin and upper bound constraints.

  Args:
    score: a 3D tensor of size [batch_size, n_rows, n_columns]. score[i, j, k]
      denotes the weight if the assignment on this entry is non-zero.
    elementwise_upper_bound: a 3D tensor of size [batch_size, n_rows,
      n_columns]. Each entry denotes the maximum value assignment[i, j, k] can
      take and must be a non-negative value. For example, upper_bound[i, j,
      k]=1.0 for binary assignment problem.
    row_sums: a 2D tensor of size [batch_size, n_rows]. The row sum constraint.
      The output assignment p[i, j, :] must sum to row_sums[i, j].
    col_sums: a 2D tensor of size [batch_size, n_columns]. The column sum
      constraint. The output assignment p[i, :, k] must sum to col_sums[i, k].
    epsilon: the epsilon coefficient of entropy regularization. The value
      should be within the range (0, 1]. `0.01` might work better than `0.1`.
      `0.1` may not make the assignment close enough to 0 or 1.
    num_iterations: the number of update iterations to perform. The loop below
      has no early exit, so exactly this many iterations are run.
    use_epsilon_scaling: whether to use epsilon scaling. In practice, the
      convergence of the iterative algorithm is much better if we start by
      solving the optimization with a larger epsilon value and re-use the
      solution (i.e. dual variables) for the instance with a smaller epsilon.
      This is called the epsilon scaling trick. See [Schmitzer 2019]
      (https://arxiv.org/pdf/1610.06519.pdf) as a reference. Here if
      use_epsilon_scaling=True, after each iteration we decrease the running
      epsilon by a constant factor until it reaches the target epsilon value.
      We found this to work well for gradient backward propagation, while the
      original scaling trick doesn't.

  Returns:
    A tuple with the following values.

    - assignment: a 3D tensor of size [batch_size, n_rows, n_columns].
      The output assignment.
    - used_iter: the number of iterations used. This is the `num_iterations`
      argument itself, returned as passed in.
    - eps: a scalar tensor indicating the stopping epsilon value.
    - delta: a scalar tensor indicating the stopping delta value (the relative
      change on the margins of assignment p in the last iteration).
  """

  # Check if all shapes are correct
  score_shape = score.shape
  bsz = score_shape[0]
  n = score_shape[1]
  m = score_shape[2]
  score = tf.ensure_shape(score, [bsz, n, m])
  elementwise_upper_bound = tf.ensure_shape(elementwise_upper_bound,
                                            [bsz, n, m])
  # row_sums becomes [bsz, n, 1] and col_sums [bsz, 1, m] so they broadcast
  # against the [bsz, n, m] tensors.
  row_sums = tf.ensure_shape(tf.expand_dims(row_sums, axis=2), [bsz, n, 1])
  col_sums = tf.ensure_shape(tf.expand_dims(col_sums, axis=1), [bsz, 1, m])

  # the total sum of row sums must be equal to total sum of column sums
  # NOTE(review): the result of tf.Assert is not used here (nor for the
  # asserts further down) — confirm the checks actually run in the execution
  # mode this function is used in.
  sum_diff = tf.reduce_sum(row_sums, axis=1) - tf.reduce_sum(col_sums, axis=2)
  sum_diff = tf.abs(sum_diff)
  tf.Assert(tf.reduce_all(sum_diff < 1e-6), [sum_diff])

  # Convert upper_bound constraint into another margin constraint
  # by adding auxiliary variables & scores. Tensor `a`, `b` and `c`
  # represent the margins (i.e. reduced sum) of 3 axes respectively.
  #
  max_row_sums = tf.reduce_sum(elementwise_upper_bound, axis=-1, keepdims=True)
  max_col_sums = tf.reduce_sum(elementwise_upper_bound, axis=-2, keepdims=True)
  # Index 0 on the new axis is the real assignment (real scores); index 1 is
  # the auxiliary slack (zero scores).
  score_ = tf.stack([score, tf.zeros_like(score)], axis=1)  # (bsz, 2, n, m)
  a = tf.stack([row_sums, max_row_sums - row_sums], axis=1)  # (bsz, 2, n, 1)
  b = tf.stack([col_sums, max_col_sums - col_sums], axis=1)  # (bsz, 2, 1, m)
  c = tf.expand_dims(elementwise_upper_bound, axis=1)  # (bsz, 1, n, m)

  # Clip log(0) to a large negative values -1e+36 to avoid
  # getting inf or NaN values in computation. Cannot use larger
  # values because float32 would use `-inf` automatically.
  #
  tf.Assert(tf.reduce_all(a >= 0), [a])
  tf.Assert(tf.reduce_all(b >= 0), [b])
  tf.Assert(tf.reduce_all(c >= 0), [c])
  log_a = tf.maximum(tf.math.log(a), -1e+36)
  log_b = tf.maximum(tf.math.log(b), -1e+36)
  log_c = tf.maximum(tf.math.log(c), -1e+36)

  # Initialize the dual variables of margin constraints
  u = tf.zeros_like(a)
  v = tf.zeros_like(b)
  w = tf.zeros_like(c)

  # `eps` is the running value (starts at 1.0 with scaling, else at the
  # target); `epsilon` is the target value as a tensor.
  eps = tf.constant(1.0 if use_epsilon_scaling else epsilon, dtype=score.dtype)
  epsilon = tf.constant(epsilon, dtype=score.dtype)

  def do_updates(cur_iter, eps, u, v, w):  # pylint: disable=unused-argument
    """Performs one round of u, v, w updates; returns (eps, u, v, w)."""
    # Epsilon scaling, i.e. gradually decreasing `eps` until it
    # reaches the target `epsilon` value
    cur_iter = tf.cast(cur_iter, u.dtype)
    # The decay factor grows with the iteration index and is capped at 0.85.
    scaling = tf.minimum(0.6 * 1.04**cur_iter, 0.85)
    eps = tf.maximum(epsilon, eps * scaling)
    score_div_eps = score_ / eps

    # Update u
    log_q_1 = score_div_eps + (w + v) / eps
    log_q_1 = tf.reduce_logsumexp(log_q_1, axis=-1, keepdims=True)
    new_u = (log_a - tf.maximum(log_q_1, -1e+30)) * eps

    # Update v (uses the freshly computed new_u)
    log_q_2 = score_div_eps + (w + new_u) / eps
    log_q_2 = tf.reduce_logsumexp(log_q_2, axis=-2, keepdims=True)
    new_v = (log_b - tf.maximum(log_q_2, -1e+30)) * eps

    # Update w (uses new_u and new_v)
    log_q_3 = score_div_eps + (new_u + new_v) / eps
    log_q_3 = tf.reduce_logsumexp(log_q_3, axis=-3, keepdims=True)
    new_w = (log_c - tf.maximum(log_q_3, -1e+30)) * eps
    return eps, new_u, new_v, new_w

  def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
    """Returns the max relative change between two consecutive iterates."""
    # Gradients are stopped: this value is diagnostic only.
    prev_sum_uvw = tf.stop_gradient((u + v + w) / eps)
    sum_uvw = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

    # Compute the relative changes on margins of P.
    # This will be used for stopping criteria.
    # Note the last update on w would guarantee the
    # margin constraint c is satisfied, so we don't
    # need to check it here.
    p = tf.exp(tf.stop_gradient(score_ / new_eps + sum_uvw))
    p_a = tf.reduce_sum(p, axis=-1, keepdims=True)
    p_b = tf.reduce_sum(p, axis=-2, keepdims=True)
    # 1e-6 in the denominators guards against division by zero.
    delta_a = tf.abs(a - p_a) / (a + 1e-6)
    delta_b = tf.abs(b - p_b) / (b + 1e-6)
    new_delta = tf.reduce_max(delta_a)
    new_delta = tf.maximum(new_delta, tf.reduce_max(delta_b))

    # Compute the relative changes on assignment solution P.
    # This will be used for stopping criteria.
    delta_p = tf.abs(tf.exp(prev_sum_uvw) -
                     tf.exp(sum_uvw)) / (tf.exp(sum_uvw) + 1e-6)
    new_delta = tf.maximum(new_delta, tf.reduce_max(delta_p))
    return new_delta

  # Fixed number of iterations; `delta` is computed for reporting and does
  # not terminate the loop early.
  # NOTE(review): with num_iterations == 0 the body never runs and `delta`
  # is never assigned — confirm callers always pass a positive value.
  for cur_iter in tf.range(num_iterations):
    prev_eps, prev_u, prev_v, prev_w = eps, u, v, w
    eps, u, v, w = do_updates(cur_iter, eps, u, v, w)
    delta = compute_relative_changes(prev_eps, prev_u, prev_v, prev_w, eps, u,
                                     v, w)
  cur_iter = num_iterations
  assignment = tf.exp((score_ + u + v + w) / eps)
  # Keep only the real assignment and drop the auxiliary slack slice.
  assignment = assignment[:, 0]
  return assignment, cur_iter, eps, delta
def FProp(self, theta, input_batch, interpolation_batch=None, lambdas=None):
  # pyformat: disable
  """Interpolates source ids in input_batch and interpolation_batch.

  Refer to Eq. (4) in paper https://arxiv.org/abs/2106.04060.
  It is a standard Transformer Encoder if interpolation_batch is None.

  Args:
    theta: A `.NestedMap` object containing weights values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].
      - embs: Optional. When interpolating and this field is present and not
        None, it is used instead of the looked-up embeddings of ids.
    interpolation_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].
      - embs: Embeddings of ids.
    lambdas: A pair of tensors to combine embeddings of ids in input_batch and
      interpolation_batch. Only used when interpolation_batch is not None.

  Returns:
    A NestedMap of

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set in
      transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: embedded input tokens (interpolated when
      interpolation_batch is given) without task or positional encodings.
      The tensor is reshaped to [-1, time, depth] and returned without a
      transpose, i.e. it is batch-major here.
  """
  # pyformat: enable

  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    # ids and paddings must have the same shape, and ids must be rank 2.
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    max_seq_length = None
    if (not py_utils.use_tpu() and
        FLAGS.transformer_encoder_truncates_inputs):
      # Truncate the time dimension to the longest unpadded sequence in the
      # batch (only off-TPU and when the flag is set).
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      # Everything beyond max_seq_length must be padding.
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    # With shared_emb the softmax layer doubles as the embedding table.
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    if interpolation_batch is not None:
      # NOTE(review): the interpolation batch's ids and paddings are used
      # untruncated even when input_batch was truncated above — confirm the
      # two batches always agree in length.
      other_input_ids = interpolation_batch.ids
      if not p.shared_emb:
        other_input_embs = self.token_emb.EmbLookup(
            theta.token_emb, tf.reshape(other_input_ids, [-1]))
      else:
        other_input_embs = self.softmax.EmbLookup(
            theta.softmax, tf.reshape(other_input_ids, [-1]))
      # Add a trailing dim so each lambda broadcasts over the embedding dim.
      lambdas = [tf.expand_dims(a, -1) for a in lambdas]

      # Precomputed embeddings, when supplied, replace the lookups.
      if 'embs' in input_batch and input_batch.embs is not None:
        input_embs = input_batch.embs
      if 'embs' in interpolation_batch and interpolation_batch.embs is not None:
        other_input_embs = interpolation_batch.embs
      else:
        # The lookups were done on flattened ids; restore [batch, time, dim].
        input_embs = tf.reshape(
            input_embs,
            [-1, tf.shape(input_ids)[1], p.token_emb.embedding_dim])
        other_input_embs = tf.reshape(
            other_input_embs,
            [-1, tf.shape(other_input_ids)[1], p.token_emb.embedding_dim])
      input_embs = lambdas[0] * input_embs + lambdas[1] * other_input_embs
      # After the -1 and clip, a position is padding (1.0) only when it is
      # padding in both batches.
      paddings = paddings + interpolation_batch.paddings - 1.0
      paddings = tf.clip_by_value(paddings, 0.0, 1.0)

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])

    # Captured before task and positional embeddings are added.
    orig_input_embs = input_embs
    if p.task_emb:
      if interpolation_batch is None:
        input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                              input_batch.task_ids)
      else:
        # Task embeddings are interpolated with the same lambdas.
        task_embs = self.task_emb.EmbLookup(theta.task_emb,
                                            input_batch.task_ids)
        other_task_embs = self.task_emb.EmbLookup(
            theta.task_emb, interpolation_batch.task_ids)
        task_embs = lambdas[0] * task_embs + lambdas[1] * other_task_embs
        input_embs += task_embs

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      # Add a leading dim of 1 so the same positions broadcast over batch.
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs

    # Project to the model dimension only when it differs from the embedding
    # dimension.
    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    # Switch paddings (and segment ids) to time-major.
    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

  if not self.do_eval and p.apply_source_mask:
    # Augment padding for masked source word positions.
    dtype = paddings.dtype
    source_mask = tf.where(
        tf.equal(input_ids, p.source_mask_id),
        tf.ones_like(input_ids, dtype=dtype),
        tf.zeros_like(input_ids, dtype=dtype))
    # Make sure padding is between 0 and 1.
    paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                1.0)

  encoded, padding, segment_id = self.transformer_stack.FProp(
      theta.transformer_stack, transformer_input, paddings, src_segment_id)
  return py_utils.NestedMap(
      encoded=encoded,
      padding=padding,
      segment_id=segment_id,
      embedded_inputs=orig_input_embs)