def _MergeCandidates(tokens, candidates):
    """Apply the lowest-valued pending merge and refresh the candidates.

    One step of merging in the reverse binary tree: the position holding the
    minimum entry of `candidates` is selected, the two tokens at that
    position and the next one are replaced by that candidate value, and the
    candidate entries next to the merge point are rebuilt.

    Args:
      tokens: Tensor sliced and concatenated along axis 0 (assumed 1-D --
        TODO confirm against the caller).
      candidates: Tensor of merge candidates, indexed by the position of the
        left token of each adjacent pair (assumed 1-D with the same dtype as
        `tokens`, since its elements are concatenated into `tokens`).

    Returns:
      A `(tokens, candidates)` tuple holding the token sequence after the
      merge and the updated candidate tensor.
    """
    merge_pos = tf.argmin(candidates, output_type=tf.int32)

    # Replace the pair (merge_pos, merge_pos + 1) with the merged token.
    prefix = tokens[:merge_pos]
    merged = [candidates[merge_pos]]
    suffix = tokens[merge_pos + 2:]
    tokens = tf.concat([prefix, merged, suffix], axis=0)

    # Only the candidates adjacent to merge_pos are stale; everything else
    # is carried over unchanged.
    no_candidates = tf.zeros([0], dtype=candidates.dtype)

    def _NoCandidates():
        return no_candidates

    def _LeftSide():
        # NOTE(review): _MergeOneToken is defined outside this block;
        # presumably it returns the candidate(s) for the pair starting at
        # the given position of the already-merged `tokens` -- confirm.
        kept = candidates[:merge_pos - 1]
        refreshed = _MergeOneToken(tokens, merge_pos - 1)
        return tf.concat([kept, refreshed], axis=0)

    def _RightSide():
        refreshed = _MergeOneToken(tokens, merge_pos)
        kept = candidates[merge_pos + 2:]
        return tf.concat([refreshed, kept], axis=0)

    # A merge at position 0 has no left neighbor (and merge_pos - 1 would
    # otherwise turn into a negative slice bound).
    merged_at_start = tf.equal(merge_pos, 0)
    left_part = tf.cond(merged_at_start, _NoCandidates, _LeftSide)

    # If the merged token is now the last one, there is no pair to its right.
    merged_at_end = tf.greater_equal(merge_pos, tf.size(tokens) - 1)
    right_part = tf.cond(merged_at_end, _NoCandidates, _RightSide)

    return tokens, tf.concat([left_part, right_part], axis=0)
def _GetRandomRealPoint():
    """Pick a random real (non-padded) point per batch element.

    A uniform random score in [0, 1) is drawn for every point. Wherever
    `padding` is non-zero the score is replaced by `padding * 10`, which is
    above the maximum random score when padded entries are marked with 1.0
    (NOTE(review): assumes that padding convention -- confirm). Taking the
    argmin over the points axis then selects a real point.

    Reads `batch_size`, `num_points`, `padding` and `random_seed` from the
    enclosing scope.

    Returns:
      An int32 Tensor with, for each row of the (batch_size, num_points)
      score matrix, the index of the selected first point.
    """
    # The random op is created first so any graph-derived seed is unchanged.
    scores = tf.random.uniform(
        (batch_size, num_points),
        minval=0,
        maxval=1,
        dtype=tf.float32,
        seed=random_seed,
    )
    is_real = tf.equal(padding, 0.0)
    padded_score = padding * 10
    masked_scores = tf.where(is_real, scores, padded_score)
    return tf.argmin(masked_scores, axis=1, output_type=tf.int32)