def _testPackedInputs(self, dtype=tf.float32):
  p = self._DecoderParams()
  np.random.seed(_NUMPY_RANDOM_SEED)
  src_time = 5
  batch = 2
  emb_dims = 4
  tgt_time = 5
  src_enc = tf.constant(
      np.random.normal(size=[src_time, batch, p.source_dim]), dtype=dtype)
  paddings = tf.zeros([src_time, batch], dtype=dtype)
  tgt_ids = tf.constant(
      np.random.randint(20, size=[batch, tgt_time]), dtype=tf.int32)
  tgt_labels = tf.constant(
      np.random.randint(20, size=[batch, tgt_time]), dtype=tf.int32)
  tgt_paddings = tf.zeros([batch, tgt_time], dtype=dtype)
  tgt_weights = 1.0 - tgt_paddings
  tgts = py_utils.NestedMap({
      'ids': tgt_ids,
      'labels': tgt_labels,
      'weights': tgt_weights,
      'paddings': tgt_paddings
  })
  src_enc_packed = tf.transpose(src_enc, [1, 0, 2])
  src_enc_packed = tf.reshape(src_enc_packed, [-1, 1, emb_dims])
  src_enc_padding_packed = tf.reshape(paddings, [-1, 1])
  target_packed = py_utils.NestedMap({
      'ids': tf.reshape(tgts.ids, [1, -1]),
      'labels': tf.reshape(tgts.labels, [1, -1]),
      'weights': tf.reshape(tgts.weights, [1, -1]),
      'paddings': tf.reshape(tgts.paddings, [1, -1])
  })
  src_segment_id = tf.transpose(
      tf.constant(
          np.asarray([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]]), dtype=tf.float32))
  target_packed.segment_ids = tf.constant(
      np.asarray([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]]), dtype=tf.float32)
  target_packed.segment_pos = tf.constant(
      np.asarray([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]))
  return (src_enc, paddings, tgts, src_enc_packed, src_enc_padding_packed,
          src_segment_id, target_packed)
def testBiEncoderForwardPassWithTransparent(self):
  with self.session(use_gpu=False):
    tf.random.set_seed(8372749040)
    p = self._BiEncoderParams()
    p.is_transparent = True
    mt_enc = encoder.MTEncoderBiRNN(p)
    batch = py_utils.NestedMap()
    batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
    batch.paddings = tf.zeros([2, 4])
    enc_out = mt_enc.FPropDefaultTheta(batch).encoded

    self.evaluate(tf.global_variables_initializer())
    actual_enc_out = enc_out.eval()
    expected_enc_out = [[[-7.4536911e-05, 8.8465633e-05],
                         [2.8940600e-05, 3.2297492e-05]],
                        [[-1.9775725e-05, 9.8312848e-05],
                         [5.1837378e-05, 1.2998647e-05]],
                        [[4.5528584e-05, -6.8125606e-05],
                         [1.0955606e-04, -2.1024598e-04]],
                        [[8.5454740e-05, -1.8263397e-04],
                         [5.2042866e-05, -1.6407830e-04]]]
    self.assertAllClose(expected_enc_out, actual_enc_out)
def _ProcessMASSInput(self, source_id, src): """Perform MASS input processing.""" skip_mass = self.do_eval and not self.params.enable_mass_for_eval if skip_mass or self.mass_layer is None: # At eval time, we copy src to tgt return self._ProcessSingleInput(source_id, src, src) _, labels, paddings = self.StringsToIds(tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key) weights = 1 - paddings actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32) src_lang_ids, tgt_lang_ids = self._GetLangIds(source_id) mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len) features = py_utils.NestedMap() features.src = py_utils.NestedMap() features.src.ids = mass_out.src.ids features.src.paddings = paddings features.src.weights = weights features.src.task_ids = tf.cast(features.src.weights, dtype=tf.int32) * src_lang_ids features.src.source_ids = tf.cast(features.src.weights, dtype=tf.int32) * source_id features.src.ids_indicator = weights features.tgt = py_utils.NestedMap() features.tgt.ids = mass_out.tgt.ids features.tgt.labels = mass_out.tgt.labels features.tgt.paddings = paddings features.tgt.weights = mass_out.tgt.weights features.tgt.task_ids = tf.cast(weights, dtype=tf.int32) * tgt_lang_ids features.tgt.source_ids = tf.cast(weights, dtype=tf.int32) * source_id features.tgt.ids_indicator = weights if not py_utils.use_tpu(): features.src.strs = src features.tgt.strs = src return features.Transform(tf.squeeze)
def _ProcessMASSInput(self, source_id, src): """Perform MASS input processing.""" # TODO(yuancao): By doing so we assume that right now for monolingual # eval/dev sets (xx->xx) are in double-column format (since it bypasses # the Mass op). Ideally we should add a dedicated eval/dev processing # procedure for unsupervised MT cases, so that single-column eval/devs sets # are also supported. This should not be handled by any specific ops like # Mass, but inside the TextPackedInput class. assert not self.do_eval, 'MASS input can only be used for training.' _, labels, paddings = self.StringsToIds( tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key) weights = 1 - paddings actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32) src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id) mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len) features = py_utils.NestedMap() features.src = py_utils.NestedMap() features.src.ids = mass_out.src.ids features.src.paddings = paddings features.src.weights = weights features.src.task_ids = tf.cast( features.src.weights, dtype=tf.int32) * src_lang_ids features.src.ids_indicator = weights features.tgt = py_utils.NestedMap() features.tgt.ids = mass_out.tgt.ids features.tgt.labels = mass_out.tgt.labels features.tgt.paddings = paddings features.tgt.weights = mass_out.tgt.weights features.tgt.task_ids = tf.ones_like( features.src.task_ids, dtype=tf.int32) * tgt_lang_ids features.tgt.ids_indicator = weights if not py_utils.use_tpu(): features.src.strs = src features.tgt.strs = src return features.Transform(tf.squeeze)
def testIntraModalLabels(self):
  # Simulate a batch of 4 examples with 2 items each in the 'text' modality.
  batch_size = 4
  items_per_example = 2
  modality = 'text'
  modality_shape = tf.TensorShape([batch_size, items_per_example])
  inputs = label_lib.ExamplePairs.WithinBatch(
      batch=dict(some_feature=tf.range(batch_size)),
      query_modality=modality,
      result_modality=modality)

  def example_pair_labeler(_):
    return tf.constant([
        [1, 0, 0, X],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [X, 0, 0, 1],
    ])

  labeler = label_lib.MultiItemExampleWrapper(
      example_pair_labeler, modality_batch_shapes={modality: modality_shape})
  labels = labeler(inputs)
  self.assertEqual(modality_shape + modality_shape, labels.shape)
  # The pairwise labels actually have rank 4 (twice the rank of ids), but we
  # compare them in matrix form for easier inspection. There are 8 items
  # total. Each should have a positive label for every other item from the
  # same example. Self-pairs should be ignored (they are neither positive
  # nor negative pairs), as well as pairs from duplicated examples.
  self.assertAllEqual(
      [
          [X, 1, 0, 0, 0, 0, X, X],
          [1, X, 0, 0, 0, 0, X, X],
          [0, 0, X, 1, 0, 0, 0, 0],
          [0, 0, 1, X, 0, 0, 0, 0],
          [0, 0, 0, 0, X, 1, 0, 0],
          [0, 0, 0, 0, 1, X, 0, 0],
          [X, X, 0, 0, 0, 0, X, 1],
          [X, X, 0, 0, 0, 0, 1, X],
      ], tf.reshape(labels, [8, 8]))
def Proc(record):
  """Parses a serialized tf.Example record."""
  features = [
      ('uttid', tf.io.VarLenFeature(tf.int64)),
      # Would like to change 'frames' to tf.int16 in the future, if that is
      # possible.
      ('frames', tf.io.VarLenFeature(tf.float32)),
  ]
  example = tf.io.parse_single_example(record, dict(features))
  fval = {k: v.values for k, v in example.items()}
  # Reshape the flattened vector into its original time-major
  # representation.
  fval['frames'] = tf.reshape(
      fval['frames'], shape=[-1, self.params.frame_size])
  # Input duration determines the bucket.
  bucket_key = tf.cast(tf.shape(fval['frames'])[0], tf.int32)
  if self.params.append_eos_frame:
    bucket_key += 1
  src_paddings = tf.zeros([tf.shape(fval['frames'])[0]], dtype=tf.float32)
  return [fval['uttid'], fval['frames'], src_paddings], bucket_key
def testForwardPassPackedInput(self):
  with self.session(use_gpu=False) as sess:
    bs = 2
    sl = 21
    d = 16
    tf.random.set_seed(8372749040)
    p = self._EncoderParams(packed_input=True)
    mt_enc = p.Instantiate()

    batch = py_utils.NestedMap()
    batch.ids = tf.constant(
        np.random.randint(low=0, high=63, size=[bs, sl], dtype=np.int32))

    # Pack these into a single batch.
    packed_bs = 1
    packed_sl = 2 * sl
    batch.ids = tf.reshape(batch.ids, [packed_bs, packed_sl])
    batch.paddings = tf.zeros([packed_bs, packed_sl])
    batch.segment_pos = [
        list(range(sl)) + list(range(sl)),
    ]
    batch.segment_ids = [
        [0 for i in range(sl)] + [1 for i in range(sl)],
    ]

    out = mt_enc.FPropDefaultTheta(batch)
    enc_out_sum = tf.reduce_sum(out.encoded)

    tf.global_variables_initializer().run()
    actual_enc_out, actual_enc_out_sum = sess.run([out.encoded, enc_out_sum])

    self.assertAllEqual([packed_sl, packed_bs, d], actual_enc_out.shape)
    self.assertAllClose(306.010132, actual_enc_out_sum)
def Unflatten(self, flat_tensors):
  """The inverse of Flatten(); expands the leading dim to `batch_shape`.

  Args:
    flat_tensors: A tensor or structure of tensors to be reshaped.

  Returns:
    The reshaped tensors, with `batch_shape.rank` - 1 more dimensions, in the
    same format (tensor, list, dict) as the input.
  """
  if self._is_no_op:
    return flat_tensors
  batch_shape = self._batch_shape.as_list()
  if batch_shape[0] is None:
    batch_shape[0] = -1
  unflattened_tensors = [
      tf.reshape(flat_tensor, batch_shape + flat_tensor.shape.as_list()[1:])
      for flat_tensor in tf.nest.flatten(flat_tensors)
  ]
  return tf.nest.pack_sequence_as(flat_tensors, unflattened_tensors)
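# A minimal standalone sketch of the Unflatten() reshape above, assuming a
# hypothetical Flatten() counterpart that collapses the leading batch
# dimensions into one. Names here are illustrative, not from the library.
def _SketchUnflattenRoundTrip():
  batch_shape = tf.TensorShape([2, 3])
  x = tf.random.normal([2, 3, 5])
  flat = tf.reshape(x, [-1, 5])  # Flatten(): [6, 5].
  shape = batch_shape.as_list()
  if shape[0] is None:
    shape[0] = -1  # Allow a dynamic leading dimension.
  unflat = tf.reshape(flat, shape + flat.shape.as_list()[1:])  # [2, 3, 5].
  assert unflat.shape.as_list() == x.shape.as_list()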
def testForwardPass(self):
  with self.session(use_gpu=False):
    tf.set_random_seed(8372749040)
    p = self._EncoderParams()
    mt_enc = encoder.MTEncoderV1(p)
    batch = py_utils.NestedMap()
    batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
    batch.paddings = tf.zeros([2, 4])
    enc_out = mt_enc.FPropDefaultTheta(batch).encoded
    tf.global_variables_initializer().run()
    actual_enc_out = enc_out.eval()
    expected_enc_out = [
        [[-7.51581979e-07, 1.55304758e-06, -3.39117889e-07, 2.79457527e-06],
         [-1.06733505e-05, 7.56898862e-06, -4.18875834e-06, -9.10360086e-06]],
        [[1.58444971e-06, 5.11627661e-07, 1.33408967e-05, 1.81603957e-06],
         [-1.59942228e-05, 1.26068180e-05, 4.49321249e-07, -1.43790385e-05]],
        [[5.56546365e-06, -8.01007627e-06, 8.96620350e-06, 3.96485439e-06],
         [-8.77006005e-06, 4.04282991e-06, -4.79895652e-06, -5.90156833e-06]],
        [[-8.59513818e-07, -7.63760727e-06, -5.57065960e-06, 1.80756274e-06],
         [-2.96017470e-06, -1.51323195e-06, -1.03562079e-05, 1.23328198e-06]]
    ]
    self.assertAllClose(expected_enc_out, actual_enc_out)
def _Extract(self, features):
  """Returns the image Tensor."""
  outputs = py_utils.NestedMap()
  p = self.params
  for camera_name in p.camera_names:
    image_shape = tf.reshape(
        _Dense(features['image_%s_shape' % camera_name]), [-1])
    image_shape = tf.cast(image_shape, tf.int32)

    if p.decode_image:
      image = tf.io.decode_png(
          tf.strings.reduce_join(
              _Dense(features['image_%s' % camera_name], default_value='')))
      image = tf.reshape(image, image_shape)
      image = py_utils.PadOrTrimTo(image, p.image_shape)

    intrinsics = tf.reshape(
        _Dense(features['camera_%s_intrinsics' % camera_name]), [9])
    extrinsics = tf.reshape(
        _Dense(features['camera_%s_extrinsics' % camera_name]), [4, 4])
    pose = tf.reshape(_Dense(features['image_%s_pose' % camera_name]), [4, 4])
    velocity = tf.reshape(
        _Dense(features['image_%s_velocity' % camera_name]), [6])

    outputs[camera_name] = py_utils.NestedMap()
    if p.decode_image:
      outputs[camera_name]['image'] = tf.cast(image, p.image_output_dtype)
    outputs[camera_name]['image_shape'] = image_shape
    outputs[camera_name]['intrinsics'] = intrinsics
    outputs[camera_name]['extrinsics'] = extrinsics
    outputs[camera_name]['pose'] = pose
    outputs[camera_name]['velocity'] = velocity
    outputs[camera_name]['rolling_shutter_direction'] = features[
        'camera_%s_rolling_shutter_direction' % camera_name]

    for feat in [
        'shutter', 'camera_trigger_time', 'camera_readout_done_time',
        'pose_timestamp'
    ]:
      outputs[camera_name][feat] = features['image_%s_%s' %
                                            (camera_name, feat)]

  return outputs
def Proc(record):
  """Parses a serialized tf.Example record."""
  features = [
      ('uttid', tf.VarLenFeature(tf.string)),
      ('transcript', tf.VarLenFeature(tf.string)),
      ('frames', tf.VarLenFeature(tf.float32)),
  ]
  example = tf.parse_single_example(record, dict(features))
  fval = {k: v.values for k, v in six.iteritems(example)}
  # Reshape the flattened vector into its original time-major
  # representation.
  fval['frames'] = tf.reshape(
      fval['frames'], shape=[-1, self.params.frame_size])
  # Input duration determines the bucket.
  bucket_key = tf.to_int32(tf.shape(fval['frames'])[0])
  if self.params.append_eos_frame:
    bucket_key += 1
  tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(fval['transcript'])
  src_paddings = tf.zeros([tf.shape(fval['frames'])[0]], dtype=tf.float32)
  return fval['uttid'], tgt_ids, tgt_labels, tgt_paddings, fval[
      'frames'], src_paddings, bucket_key
def testUniEncoderForwardPass(self):
  with self.session(use_gpu=False):
    tf.random.set_seed(8372749040)
    p = self._UniEncoderParams()
    mt_enc = encoder.MTEncoderUniRNN(p)
    batch = py_utils.NestedMap()
    batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
    batch.paddings = tf.zeros([2, 4])
    enc_out = mt_enc.FPropDefaultTheta(batch).encoded

    self.evaluate(tf.global_variables_initializer())
    actual_enc_out = enc_out.eval()
    tf.logging.info('testUniEncoderForwardPass actual_enc_out %r' %
                    actual_enc_out)
    expected_enc_out = [[[-4.3304257e-07, 5.4100457e-07],
                         [-4.0170832e-07, -2.6441572e-07]],
                        [[-1.7024040e-07, -1.8555815e-07],
                         [-6.4563977e-07, -3.7835261e-07]],
                        [[-2.4001852e-07, 5.1114228e-07],
                         [-3.4349023e-07, -1.0049351e-06]],
                        [[1.8068013e-07, -6.8982729e-08],
                         [3.3005003e-07, -8.8834116e-07]]]
    self.assertAllClose(expected_enc_out, actual_enc_out)
def testBiEncoderForwardPassWithDropout(self):
  with self.session(use_gpu=False):
    tf.random.set_seed(8372749040)
    p = self._BiEncoderParams()
    p.dropout_prob = 0.5
    mt_enc = encoder.MTEncoderBiRNN(p)
    batch = py_utils.NestedMap()
    batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
    batch.paddings = tf.zeros([2, 4])
    enc_out = mt_enc.FPropDefaultTheta(batch).encoded

    self.evaluate(tf.global_variables_initializer())
    actual_enc_out = enc_out.eval()
    print('bi_enc_actual_enc_out_with_dropout',
          np.array_repr(actual_enc_out))
    expected_enc_out = [[[1.60383240e-06, 1.22550023e-06],
                         [-7.21660126e-06, 1.05704457e-05]],
                        [[1.42539475e-05, -2.06075638e-05],
                         [-4.98754298e-06, 1.51066461e-05]],
                        [[-7.15192800e-06, -6.44075908e-06],
                         [5.02962678e-07, -3.40795486e-06]],
                        [[-6.54424548e-06, 9.88359807e-06],
                         [1.42836643e-06, -1.68607176e-06]]]
    self.assertAllClose(expected_enc_out, actual_enc_out)
def testBiEncoderForwardPass(self):
  with self.session(use_gpu=False):
    tf.random.set_seed(8372749040)
    p = self._BiEncoderParams()
    mt_enc = encoder.MTEncoderBiRNN(p)
    batch = py_utils.NestedMap()
    batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
    batch.paddings = tf.zeros([2, 4])
    enc_out = mt_enc.FPropDefaultTheta(batch).encoded

    self.evaluate(tf.global_variables_initializer())
    actual_enc_out = enc_out.eval()
    tf.logging.info('testBiEncoderForwardPass actual_enc_out %r' %
                    actual_enc_out)
    expected_enc_out = [[[-2.47998378e-06, 7.36457878e-06],
                         [7.89248020e-07, -2.67464316e-06]],
                        [[-2.98803275e-06, 8.20233890e-06],
                         [1.00139073e-06, -2.24554151e-06]],
                        [[-5.06675951e-06, 1.15983785e-05],
                         [-4.58391014e-07, -2.99553108e-07]],
                        [[-4.34937465e-06, 8.58816838e-06],
                         [-1.74859031e-06, 3.99598093e-06]]]
    self.assertAllClose(expected_enc_out, actual_enc_out)
def FProp(self, theta):
  """Combines the list of input tensors into a single tensor.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.

  Returns:
    A tensor of merger weights of shape [num_sources], with dropout applied.
  """
  p = self.params
  # The constant factor is just meant to support the non-normalized scenario.
  # If softmax is applied, this factor will cancel out.
  w = theta.sum_weight * p.global_weight_scale + (1 / p.num_sources)
  w = tf.reshape(w, [p.num_sources])
  w = self.weighted_merger_dropout.FProp(theta.weighted_merger_dropout, w)
  if p.weighted_merger_softmax:
    residual_weights = p.minimal_prob * p.num_sources
    assert residual_weights >= 0.0
    assert residual_weights < 1.0
    w = tf.nn.softmax(w, axis=0) * (1.0 - residual_weights) + p.minimal_prob
  return w
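# Worked numeric sketch of the softmax mixing above (standalone; values are
# illustrative, not from the layer): with num_sources=2 and minimal_prob=0.1,
# residual_weights=0.2, so each merged weight is at least minimal_prob and
# the weights still sum to 1.
def _SketchMergerSoftmax():
  minimal_prob, num_sources = 0.1, 2
  residual_weights = minimal_prob * num_sources  # 0.2
  w = tf.constant([0.0, 0.0])  # Equal logits.
  out = tf.nn.softmax(w, axis=0) * (1.0 - residual_weights) + minimal_prob
  # out == [0.5, 0.5]; with very unequal logits it approaches [0.1, 0.9].
  return out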
def ConvertToBlocks(x, block_size, padding_val=0.0):
  """Turns a sequence into non-overlapping blocks.

  Args:
    x: a tensor of [batch, time, ...].
    block_size: int. Number of time frames in a block.
    padding_val: float. Value of the padded frames.

  Returns:
    A tensor of [batch, num_blocks, block_size, ...], with necessary paddings,
    where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
  """
  shape = py_utils.GetShape(x)
  b, t = shape[:2]
  if block_size < 1:
    raise ValueError(
        'block_size must be at least 1, got {}'.format(block_size))
  w = block_size
  # Pad t to be a multiple of w.
  num_blocks = (t + w - 1) // w
  pad_to_length = num_blocks * w
  padded = py_utils.PadSequenceDimension(x, pad_to_length, padding_val)
  reshaped = tf.reshape(padded, [b, num_blocks, w] + shape[2:])
  return reshaped
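# Hedged usage sketch for ConvertToBlocks (assumes the function above and its
# py_utils dependencies are in scope). A [batch=2, time=7] input with
# block_size=3 is padded to length 9 and reshaped to 3 blocks of 3 frames.
def _SketchConvertToBlocks():
  x = tf.reshape(tf.range(14, dtype=tf.float32), [2, 7])
  blocks = ConvertToBlocks(x, block_size=3)
  # blocks has shape [2, 3, 3]; blocks[:, i, :] == padded x[:, 3*i:3*(i+1)].
  assert blocks.shape.as_list() == [2, 3, 3]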
def testForwardPass(self):
  with self.session(use_gpu=False):
    tf.random.set_seed(8372749040)
    p = self._EncoderParams()
    mt_enc = encoder.MTEncoderV1(p)
    batch = py_utils.NestedMap()
    batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
    batch.paddings = tf.zeros([2, 4])
    enc_out = mt_enc.FPropDefaultTheta(batch).encoded

    self.evaluate(tf.global_variables_initializer())
    actual_enc_out = enc_out.eval()
    expected_enc_out = [
        [[1.5309354e-06, -1.7816075e-07, 3.8047763e-06, -5.6422067e-07],
         [1.9017770e-06, -2.9778969e-06, -4.5083775e-06, -1.7054812e-06]],
        [[-2.1852782e-06, -1.8208171e-06, -1.4747930e-06, -5.8206351e-06],
         [6.7667429e-07, -3.6828042e-06, -1.0916860e-05, -3.2522742e-06]],
        [[-3.2333378e-07, 3.2147584e-06, 5.0556650e-07, -7.0188378e-07],
         [-6.5340635e-07, 1.9502845e-06, -9.2459632e-06, 5.1955390e-06]],
        [[2.0232728e-06, 4.9331529e-06, 1.1346837e-06, 7.5571520e-06],
         [-5.8475212e-07, 3.5547487e-06, -3.9037773e-06, 8.9575424e-06]]
    ]
    self.assertAllClose(expected_enc_out, actual_enc_out)
def testBiEncoderForwardPassWithDropout(self):
  with self.session(use_gpu=False):
    tf.random.set_seed(8372749040)
    p = self._BiEncoderParams()
    p.dropout_prob = 0.5
    mt_enc = encoder.MTEncoderBiRNN(p)
    batch = py_utils.NestedMap()
    batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
    batch.paddings = tf.zeros([2, 4])
    enc_out = mt_enc.FPropDefaultTheta(batch).encoded

    self.evaluate(tf.global_variables_initializer())
    actual_enc_out = enc_out.eval()
    print('bi_enc_actual_enc_out_with_dropout',
          np.array_repr(actual_enc_out))
    expected_enc_out = [[[-1.8358192e-05, 1.2103478e-05],
                         [2.9347059e-06, -3.0652325e-06]],
                        [[-8.1282624e-06, 4.5443494e-06],
                         [3.0826509e-06, -5.2950490e-06]],
                        [[-4.6669629e-07, 2.4246765e-05],
                         [-1.5221613e-06, -1.9654153e-06]],
                        [[-1.1511075e-05, 1.9061190e-05],
                         [-5.7250163e-06, 9.2785704e-06]]]
    self.assertAllClose(expected_enc_out, actual_enc_out)
def _CellFeaturizer(self, theta, input_batch):
  """Featurizes each center location."""
  # Validate shapes.
  cell_feature = py_utils.HasRank(input_batch.cell_feature, 4)
  batch_size, num_centers, num_points_per_cell = py_utils.GetShape(
      cell_feature, 3)
  cell_points_xyz = py_utils.HasShape(
      input_batch.cell_points_xyz,
      [batch_size, num_centers, num_points_per_cell, 3])
  cell_center_xyz = py_utils.HasShape(input_batch.cell_center_xyz,
                                      [batch_size, num_centers, 3])
  cell_points_padding = py_utils.HasShape(
      input_batch.cell_points_padding,
      [batch_size, num_centers, num_points_per_cell])

  # Center each cell.
  cell_center_xyz = tf.reshape(cell_center_xyz,
                               [batch_size, num_centers, 1, 3])
  centered_cell_points_xyz = cell_points_xyz - cell_center_xyz
  concat_feature = tf.concat(
      [centered_cell_points_xyz, cell_feature], axis=-1)

  # Featurize point clouds at each center.
  point_input = py_utils.NestedMap({
      'points': centered_cell_points_xyz,
      'features': concat_feature,
      'padding': cell_points_padding,
  })
  featurized_cell = self.cell_featurizer.FProp(theta.cell_featurizer,
                                               point_input)
  featurized_cell = py_utils.HasShape(featurized_cell,
                                      [batch_size, num_centers, -1])
  return featurized_cell
def AttenLogitsRPEOneStep(self, query, key, abs_pos_emb):
  """RPE attention logits for one single target (query) step.

  B: batch size
  S: sequence length
  N: num of attention heads.
  H: per-head attention dimension.

  Args:
    query: [B, N, H].
    key: [S, B, N, H] or [S, B, N*H/128, 128].
    abs_pos_emb: [S, 1, N, H].

  Returns:
    A Tensor of shape [S, B, N].
  """
  s, b, _, _ = py_utils.GetShape(key, 4)
  _, n, h = py_utils.GetShape(query, 3)
  key = tf.reshape(key, [s, b, n, h])

  key_emb = key + abs_pos_emb
  query, key_emb = self.ToAqtActActInputs(query, key_emb)
  logits = tf.einsum('BNH,SBNH->SBN', query, key_emb)
  return self.FromAqtActActMatmul(logits)
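# Minimal sketch of the logits einsum above, with illustrative sizes only
# (random values; this just demonstrates the documented shape contract).
def _SketchRPELogitsShapes():
  q = tf.random.normal([2, 4, 8])       # query: [B, N, H].
  k = tf.random.normal([5, 2, 4, 8])    # key: [S, B, N, H].
  emb = tf.random.normal([5, 1, 4, 8])  # abs_pos_emb: [S, 1, N, H].
  logits = tf.einsum('BNH,SBNH->SBN', q, k + emb)  # emb broadcasts over B.
  assert logits.shape.as_list() == [5, 2, 4]  # [S, B, N].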
def testMergeBeamSearchOutputs(self):
  with self.session():
    topk_scores_1 = [[1., 3., 5.], [-2., -1., 0.]]
    topk_ids_1 = [[[10, 11, 12], [30, 31, 32], [50, 51, 52]],
                  [[20, 21, 22], [10, 11, 12], [0, 0, 0]]]
    topk_lens_1 = [[3, 3, 2], [3, 3, 0]]
    topk_hyps_1 = [['one', 'three', 'five'], ['minus two', 'minus one', '']]
    topk_1 = beam_search_helper.BeamSearchDecodeOutput(
        None, tf.constant(topk_hyps_1),
        tf.reshape(tf.constant(topk_ids_1), [6, -1]),
        tf.reshape(tf.constant(topk_lens_1), [-1]),
        tf.reshape(tf.constant(topk_scores_1), [-1]), None, None)

    topk_scores_2 = [[2., 4.], [-3., 0.]]
    topk_ids_2 = [[[20, 21, 22], [40, 41, 42]], [[30, 31, 33], [0, 0, 0]]]
    topk_lens_2 = [[3, 2], [3, 0]]
    topk_hyps_2 = [['two', 'four'], ['minus three', '']]
    topk_2 = beam_search_helper.BeamSearchDecodeOutput(
        None, tf.constant(topk_hyps_2),
        tf.reshape(tf.constant(topk_ids_2), [4, -1]),
        tf.reshape(tf.constant(topk_lens_2), [-1]),
        tf.reshape(tf.constant(topk_scores_2), [-1]), None, None)

    topk = beam_search_helper.MergeBeamSearchOutputs(3, [topk_1, topk_2])
    self.assertIsNone(topk.done_hyps)
    self.assertIsNone(topk.topk_decoded)
    self.assertAllEqual([5., 4., 3., -1., -2., -3.], topk.topk_scores.eval())
    self.assertAllEqual([2, 2, 3, 3, 3, 3], topk.topk_lens.eval())
    self.assertAllEqual([[50, 51, 52], [40, 41, 42], [30, 31, 32],
                         [10, 11, 12], [20, 21, 22], [30, 31, 33]],
                        topk.topk_ids.eval())
    self.assertAllEqual([[b'five', b'four', b'three'],
                         [b'minus one', b'minus two', b'minus three']],
                        topk.topk_hyps.eval())
def _ReshapeTransform(inp):
  """Reshape the transformation tensor to [..., idims, odims]."""
  base_shape = py_utils.GetShape(inp)[:-1]
  out_shape = list(base_shape) + [idims, odims]
  return tf.reshape(inp, out_shape)
def _CombineLastTwoDims(x):
  """Merges the last two dimensions of x into a single dimension."""
  shape = py_utils.GetShape(x)
  return tf.reshape(x, shape[:-2] + [np.prod(shape[-2:])])
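# Quick example of _CombineLastTwoDims (assumes the helper above is in
# scope): the trailing two axes are merged, e.g. [2, 3, 4, 5] -> [2, 3, 20].
def _SketchCombineLastTwoDims():
  y = _CombineLastTwoDims(tf.zeros([2, 3, 4, 5]))
  assert y.shape.as_list() == [2, 3, 20]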
def NeighborhoodIndices(points,
                        query_points,
                        k,
                        points_padding=None,
                        max_distance=None,
                        sample_neighbors_uniformly=False):
  """Get indices to k-neighbors of query_points in points.

  Padding is returned alongside indices. Non-padded points are guaranteed to
  be unique (non-repeated) points from the original non-padded points. Padded
  points arise either from a lack of points (k exceeds the number of original
  non-padded points) or because points are too far away (exceed max_distance).

  Note: Padded point indices may refer to padded points from the original, or
  may be duplicates of the closest point.

  TODO(weihan,jngiam): The PointCNN implementation assumes that padded points
  are repeated points from the original points. This behavior is maintained
  here, but we should update PointCNN to respect indices paddings.

  Args:
    points: tensor of shape [N, P1, dims].
    query_points: tensor of shape [N, P2, dims].
    k: Integer.
    points_padding: optional tensor of shape [N, P1] containing True/1.0 iff
      the point is a padded point. If None, then all points are considered
      real points.
    max_distance: float representing the maximum distance that each neighbor
      can be. If there are no points within the distance, then the closest
      point is returned (regardless of distance). If this is set to None,
      then no filtering by distance is performed.
    sample_neighbors_uniformly: boolean specifying whether to sample
      neighbors uniformly if they are within max distance.

  Returns:
    A pair of tensors:

    - indices: tensor of shape [N, P2, k].
    - padding: tensor of shape [N, P2, k] where 1 represents a padded point,
      and 0 represents an unpadded (real) point.
  """
  n, p1 = py_utils.GetShape(points, 2)
  query_points = py_utils.HasShape(query_points, [n, -1, -1])
  _, p2 = py_utils.GetShape(query_points, 2)

  # Compute pair-wise squared distances.
  # Note that dist_mat contains the squared distance (without sqrt). Thus,
  # when using max_distance, we will need to square max_distance to make sure
  # it's in the same units.
  dist_mat = SquaredDistanceMatrix(query_points, points)
  dist_mat = py_utils.HasShape(dist_mat, [n, p2, p1])

  # Add a large scalar to the distances for padded points.
  # dist_mat[i, j, k] will be:
  #   if k < valid_num[i]: distance between points[i, k] and query_points[i, j]
  #   otherwise: a large scalar added to dist_mat[i, j, k]
  if points_padding is not None:
    points_padding = tf.cast(tf.expand_dims(points_padding, 1), tf.float32)
    points_padding = py_utils.HasShape(points_padding, [n, 1, p1])
    large_scalar = tf.reduce_max(dist_mat) + 1
    dist_mat += points_padding * large_scalar

  # To sample neighbors uniformly and efficiently, we set all neighbors that
  # are within the distance threshold to have distances drawn uniformly at
  # random. Using top_k with this enables selecting a random set quickly
  # without replacement.
  if sample_neighbors_uniformly:
    if max_distance is not None:
      mask_by_distance = tf.less_equal(dist_mat, max_distance**2)
      dist_mat = tf.where(
          mask_by_distance,
          tf.square(max_distance) * tf.random_uniform(tf.shape(dist_mat)),
          dist_mat)
    else:
      raise ValueError('Uniform sampling requires specifying max_distance.')

  top_k_dist, indices = tf.nn.top_k(-dist_mat, k=k, sorted=True)  # N x P2 x K

  # Set padding using top_k_dist; padded points will have distance exceeding
  # the large_scalar.
  if points_padding is not None:
    paddings = tf.greater_equal(-top_k_dist, large_scalar)
  else:
    paddings = tf.zeros_like(top_k_dist, dtype=tf.bool)

  # Filter by max_distance by setting all indices that exceed the
  # max_distance to the closest point.
  if max_distance is not None:
    # Mask is true for points that are further than max_distance.
    mask_by_distance = tf.greater(-top_k_dist, tf.square(max_distance))
    closest_idx = tf.tile(indices[:, :, :1], [1, 1, k])
    indices = tf.where(mask_by_distance, closest_idx, indices)
    paddings |= mask_by_distance

  indices = tf.reshape(indices, [n, p2, k])
  paddings = tf.cast(paddings, tf.float32)
  return indices, paddings
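# Hedged usage sketch for NeighborhoodIndices (assumes the function above and
# SquaredDistanceMatrix are in scope): one batch of 4 2-D points, querying
# the k=2 nearest neighbors of 2 query points, no padding, no distance cap.
def _SketchNeighborhoodIndices():
  pts = tf.constant([[[0., 0.], [1., 0.], [0., 1.], [5., 5.]]])  # [1, 4, 2].
  queries = tf.constant([[[0., 0.], [5., 5.]]])                  # [1, 2, 2].
  indices, paddings = NeighborhoodIndices(pts, queries, k=2)
  # indices[0, 0] holds the two points closest to (0, 0), nearest first;
  # paddings is all zeros since two real points always exist here.
  return indices, paddings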
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` object containing:

      - ids: The inputs tensor of shape [batch, time].
      - paddings: The ids' paddings of shape [batch, time].

  Returns:
    A `.NestedMap` object containing:

    - encoded: The encoded features of shape [time, batch, dim] or
      [batch, time, dim], depending on p.output_data_format.
    - padding: The encoded features' padding of shape [time, batch] or
      [batch, time].
    - segment_id: The segmentation of packed inputs of shape [time, batch] or
      [batch, time] if it is supported by the model, or None otherwise.
    - embedded_inputs: The embedded input tokens without positional encodings,
      of shape [time, batch, dim] or [batch, time, dim].
  """
  p = self.params
  with tf.name_scope(p.name):
    # [batch, time]
    input_ids = input_batch.ids
    # [batch, time]
    paddings = input_batch.paddings
    # [batch, time]
    segment_ids = input_batch.segment_ids if p.packed_input else None

    batch = py_utils.GetShape(input_ids)[0]
    time = py_utils.GetShape(input_ids)[1]

    # Embedding layer.
    # [batch, time, dim]
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb, input_ids)
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
    orig_input_embs = input_embs

    # [1, time, dim]
    if p.packed_input:
      positions = input_batch.segment_pos
      position_embs = tf.expand_dims(
          self.position_emb.FPropWithPosition(theta.position_emb, positions),
          0)
    else:
      position_embs = tf.expand_dims(
          self.position_emb.FProp(theta.position_emb, time), 0)

    # [batch, time, dim]
    input_embs += position_embs

    if p.input_dropout_tpl.fprop_dtype:
      input_embs = tf.cast(input_embs, p.input_dropout_tpl.fprop_dtype)
      paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)
    # [batch, time, dim]
    transformer_input = input_embs
    # Explicitly set the input shape of Transformer layers, to avoid
    # unknown-shape errors from tf.einsum on non-TPU devices.
    transformer_input = tf.reshape(transformer_input,
                                   [batch, time, p.model_dim])

    # Compute self-attention segment mask once.
    if p.packed_input:
      segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids, dtype=transformer_input.dtype)
    else:
      segment_mask = tf.zeros([batch, 1, time, time])

    encoded, padding = self.transformer_stack.FProp(theta.transformer_stack,
                                                    transformer_input,
                                                    paddings, segment_mask)

    if p.final_layer_norm:
      encoded = self.final_ln.FProp(theta.final_ln, encoded)

    seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1), tf.int32)

    if p.output_data_format == 'TBC':
      encoded = tf.transpose(encoded, [1, 0, 2])  # [time, batch, dim]
      padding = tf.transpose(padding)  # [time, batch]
      segment_ids = tf.transpose(segment_ids) if p.packed_input else None
      orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        seq_lengths=seq_lengths,  # used by beam_search_helper.
        segment_id=segment_ids,
        embedded_inputs=orig_input_embs)
def FProp(self,
          theta,
          source_input,
          source_paddings,
          target_input=None,
          target_paddings=None,
          source_segment_id=None,
          target_segment_id=None,
          labels=None,
          label_weights=None,
          source_pos_id=None,
          target_pos_id=None,
          source_task_id=None,
          target_task_id=None):
  """Transforms source sequence of Tensors with Transformers layers.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    source_input: A sequence of ints indicating source input ids of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    source_paddings: A sequence of 0s and 1s indicating input paddings of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    target_input: A sequence of ints indicating target input ids of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    target_paddings: [target_time, target_batch] or
      [target_batch, target_time] if batch_dim is 0.
    source_segment_id: A sequence of ints indicating source segment ids of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    target_segment_id: A sequence of ints indicating target segment ids of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    labels: A sequence of ints indicating label ids of [time, batch] shape,
      or [batch, time] if batch_dim is 0.
    label_weights: A sequence of floats indicating label weights of
      [time, batch] shape, or [batch, time] if batch_dim is 0.
    source_pos_id: A sequence of ints indicating source position ids of
      [time, batch] shape, or [batch, time] if batch_dim is 0.
    target_pos_id: A sequence of ints indicating target position ids of
      [time, batch] shape, or [batch, time] if batch_dim is 0.
    source_task_id: A sequence of ints indicating source task ids of
      [time, batch] shape, or [batch, time] if batch_dim is 0.
    target_task_id: A sequence of ints indicating target task ids of
      [time, batch] shape, or [batch, time] if batch_dim is 0.

  Returns:
    transformer_output with shape [time, batch, dim] or [batch, time, dim]
    if batch_dim is 0.
  """
  p = self.params
  if p.num_decoder_layers > 0:
    assert target_input is not None
    assert target_paddings is not None
  if p.packed_input:
    assert source_segment_id is not None, (
        'Need to specify src_segment_id if packed input is supported.')
    assert source_pos_id is not None, (
        'Need to specify src_pos_id for packed input and embeddings.')

  logits = super(GPipeTransformerStack,
                 self).FProp(theta, source_input, source_paddings,
                             target_input, target_paddings, source_segment_id,
                             target_segment_id, source_pos_id, target_pos_id,
                             source_task_id, target_task_id)
  if not p.softmax_tpl:
    return logits
  label_weights = tf.reshape(label_weights, [-1])
  target_probs = None
  if p.label_smoothing:
    if p.batch_dim:  # Time-major
      target_probs = tf.transpose(
          self.smoother.FProp(
              theta.smoother,
              tf.transpose(target_paddings),
              tf.transpose(labels),
              target_ids=None), [1, 0, 2])
    else:
      target_probs = self.smoother.FProp(
          theta.smoother, target_paddings, labels, target_ids=None)
    target_probs = tf.reshape(target_probs, [-1, p.softmax_tpl.num_classes])
  reshaped_logits = tf.reshape(logits, [-1, p.softmax_tpl.num_classes])
  tgt_labels = tf.reshape(labels, [-1])
  num_splits = len(p.splits)
  softmax = self.children['cell_{}'.format(num_splits - 1)].softmax
  softmax_theta = theta['cell_{}'.format(num_splits - 1)].softmax
  per_example_xent, _ = softmax.XentLossFromLogits(
      softmax_theta,
      reshaped_logits,
      class_weights=tf.reshape(label_weights, [-1]),
      class_ids=tgt_labels,
      class_probabilities=target_probs)
  xent_shape = tf.shape(logits)[:2]
  per_example_xent = tf.reshape(per_example_xent, xent_shape)
  return per_example_xent, logits
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to the total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses
    per beam.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch,
                              tf.shape(output.topk_hyps)[0]),
    ], tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0), tf.fill(tf.shape(scores), -1e6),
      scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks.
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to
      override `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput`.
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  min_score = -1e36
  best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  bs_atten_probs = tf.zeros(
      [max_steps, num_hyps,
       tf.shape(initial_results.atten_probs)[1]],
      dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                   unused_other_states_list):
    return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
               other_states_list):
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, cur_step, step_ids, core_bs_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))
  # [target_seq_len, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # TODO(rpang): avoid inspecting 'encoder_outputs'.
  source_paddings = encoder_outputs.padding
  if isinstance(source_paddings, py_utils.NestedMap):
    source_seq_lengths = tf.cast(
        tf.round(
            tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                          1)), tf.int32)
  else:
    source_seq_lengths = tf.cast(
        tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
        tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      source_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # [num_beams * num_hyps_per_beam, ...].
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                    core_bs_states, other_states, num_hyps_per_beam,
                    pre_beam_search_step_callback,
                    post_beam_search_step_callback):
  """Extend beam search hyps for one step.

  | num_beams = Number of source sequences to be decoded.
  | num_hyps_per_beam = Number of hyps to keep per source sequence.
  | num_hyps = num_beams * num_hyps_per_beam
  | src_seq_len = Number of time steps in the source sequence.
  | src_batch = Number of examples in the source sequence.
  | tgt_seq_len = Maximum allowed time steps in the target sequence.
  | tgt_batch = num_hyps_per_beam * src_batch

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed
      to the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    core_bs_states: A tuple of core beam search states. This list is
      maintained by this helper class.
    other_states: A `.NestedMap` of other beam search states. This
      `.NestedMap` is managed and updated by the client. It is expected that
      each of its member tensors are of rank >= 1. t[i, ...] is the state of
      the i-th hyp at the beginning of this search step.
    num_hyps_per_beam: Num of hyps to keep per beam.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. See class header comments for more details.

  Returns:
    A tuple of the following elements for the next beam search step,
    (next step, all_done, step_ids, core_bs_states, other_states)
  """
  p = self.params

  bs_results, other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

  (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
   in_done_hyps, in_atten_probs) = core_bs_states

  (out_best_scores, out_cumulative_scores, out_scores, out_hyps,
   out_prev_hyps, out_done_hyps, out_atten_probs,
   all_done) = ops.beam_search_step(
       bs_results.log_probs,
       bs_results.atten_probs,
       best_scores,
       cumulative_scores,
       in_scores,
       in_hyps,
       in_prev_hyps,
       in_done_hyps,
       in_atten_probs,
       bs_results.is_last_chunk if self._model_uses_eoc_id else [],
       cur_step,
       eoc_id=p.target_eoc_id,
       eos_id=p.target_eos_id,
       beam_size=p.beam_size,
       num_hyps_per_beam=num_hyps_per_beam,
       valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
       merge_paths=p.merge_paths,
       allow_empty_terminated_hyp=p.allow_empty_terminated_hyp,
       ensure_full_beam=p.ensure_full_beam,
       force_eos_in_last_step=p.force_eos_in_last_step,
       local_eos_threshold=p.local_eos_threshold)

  new_step_ids = tf.reshape(out_hyps[cur_step, :], tf.shape(step_ids))
  new_step_ids.set_shape(step_ids.get_shape())

  old_hyp_ids = tf.reshape(
      tf.slice(out_prev_hyps, begin=[cur_step, 0], size=[1, -1]), [-1])

  if p.batch_major_compute:
    # Transform the indices into the key/value cache for fast decoding
    # (prefix_states in other_states), since the num_hyps dimension of the
    # cache is computed as num_beams by num_hyps_per_beam, which is different
    # from the old_hyp_ids assumption (num_hyps_per_beam by num_beams).
    # Both transpose and recomputation are required to correct the indices.
    num_beams = tf.shape(best_scores)[0]
    old_hyp_ids_in_cache_order = tf.reshape(
        tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])), [-1])
    old_hyp_ids_in_cache_order = (
        (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
        old_hyp_ids_in_cache_order // num_beams)

  new_bs_states = (out_best_scores, out_cumulative_scores, out_scores,
                   out_hyps, out_prev_hyps, out_done_hyps, out_atten_probs)

  def ReOrderHyps(x_in):
    """Reorders x_in based on prev hyp ids."""
    if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims and
        x_in.shape.ndims > 0):
      if x_in.shape.ndims > 2 and not p.batch_major_state:
        # Use corrected indices only here for batch major compute as
        # key/value caches are the states being affected.
        correct_old_hyp_ids = (
            old_hyp_ids_in_cache_order
            if p.batch_major_compute else old_hyp_ids)
        x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
      else:
        x_out = tf.gather(x_in, old_hyp_ids)
      x_out.set_shape(x_in.get_shape())
      return x_out
    else:
      return x_in

  new_other_states = other_states.Transform(ReOrderHyps)

  final_other_states = post_beam_search_step_callback(theta, encoder_outputs,
                                                      new_step_ids,
                                                      new_other_states)

  return (cur_step + 1, all_done, new_step_ids, new_bs_states,
          final_other_states)
def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
  """Loop body for farthest point sampler."""

  def _GetRandomRealPoint():
    """Select the first point.

    For the first point, we want any random real (non padded) point, so we
    create random values per point, and then set all padded ones to some
    large value (more than the maxval). We then take the min per batch
    element to get the first points.

    Returns:
      Tensor containing the index of a random point selected for each example
      in the batch.
    """
    random_values = tf.random.uniform((batch_size, num_points),
                                      minval=0,
                                      maxval=1,
                                      dtype=tf.float32,
                                      seed=random_seed)
    random_values = tf.where(
        tf.equal(padding, 0.0), random_values, padding * 10)
    return tf.argmin(random_values, axis=1, output_type=tf.int32)

  def _GetFurthestPoint():
    """Get the point that is furthest from those already selected.

    We also bias the sampling towards real points by setting the distance to
    padded points negative until we are out of real points.

    Returns:
      Tensor containing the index of the next farthest point selected for
      each example in the batch.
    """
    # Set padded points' distance to negative so they aren't selected.
    padding_masked_distance_to_selected = tf.where(
        tf.equal(padding, 0.0), distance_to_selected,
        -1.0 * tf.ones((batch_size, num_points), dtype=tf.float32))
    # But only do this when we still have valid points left.
    padding_masked_distance_to_selected = tf.where(
        tf.less(curr_idx, num_valid_points),
        padding_masked_distance_to_selected, distance_to_selected)
    return tf.argmax(
        padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

  def _GetSeededPoint():
    """Select a seeded point.

    Seeded points are assumed to be at the beginning of the original points.

    Returns:
      Tensor containing the index of the next seeded point to select for each
      example in the batch.
    """
    return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

  # Select indices for this loop iteration.
  def _Seeded():
    return tf.cond(
        tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
        _GetFurthestPoint)

  def _Real():
    return tf.cond(
        tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

  new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
  sampled_idx = sampled_idx.write(curr_idx, new_selected)

  # Extract the distance to the latest point selected to update
  # distance_to_selected.
  new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                     axis=1)
  if precomputed_squared_distance is not None:
    new_distance = tf.gather_nd(precomputed_squared_distance,
                                new_selected_gather_idx)
  else:
    new_points = tf.reshape(
        tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
    new_distance = tf.reshape(
        SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

  is_newly_closest = tf.less(new_distance, distance_to_selected)
  distance_to_selected = tf.minimum(distance_to_selected, new_distance)

  # Track the index to the closest selected point.
  new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
  closest_idx = tf.cond(
      tf.equal(curr_idx, 0),
      # At the first loop iteration, the init points are the closest.
      lambda: new_selected_tiled,
      # Otherwise, update with the new points based on the distances.
      lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
  return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx
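# Hedged sketch of how a driver might wire _BodyFn into tf.while_loop. The
# variable names (num_sampled_points, the initial distances, etc.) are
# assumptions based on the closure above, not shown in this snippet.
#
#   sampled_idx = tf.TensorArray(tf.int32, size=num_sampled_points)
#   curr_idx = tf.constant(0, dtype=tf.int32)
#   distance_to_selected = float('inf') * tf.ones((batch_size, num_points))
#   closest_idx = tf.zeros((batch_size, num_points), dtype=tf.int32)
#   _, _, sampled_idx, closest_idx = tf.while_loop(
#       lambda curr_idx, *args: tf.less(curr_idx, num_sampled_points),
#       _BodyFn,
#       loop_vars=(curr_idx, distance_to_selected, sampled_idx, closest_idx))
#   sampled_points_idx = tf.transpose(sampled_idx.stack())  # [batch, num].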