def _InitBeamSearchStateCallback(self, theta, encoder_outputs,
                                 num_hyps_per_beam):
  """Returns initial beam search states.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    encoder_outputs: a NestedMap computed by encoder.
    num_hyps_per_beam: An int, number of hyps to keep per source sentence.

  Returns:
    A tuple (initial_results, states).
      initial_results: a `.NestedMap` of initial results.
        log_probs: The initial log probs, of shape [num_hyps, vocab_size].
        atten_probs: The initial attention probs, of shape [tgt_batch,
          src_len].
      states: a `.NestedMap` of initial model states.
        prefix_states: A `.NestedMap` of zero-initialized per-layer attention
          key/value caches used for fast decoding.
        time_step: A scalar, the current decode step, 0-based.
  """
  p = self.params
  source_encs = encoder_outputs.encoded
  num_hyps = py_utils.GetShape(source_encs)[1] * num_hyps_per_beam
  source_len = py_utils.GetShape(source_encs)[0]

  # Dummy attention probs
  atten_probs = tf.ones([num_hyps, source_len]) / tf.to_float(source_len)
  initial_results = py_utils.NestedMap(
      log_probs=tf.zeros([num_hyps, p.softmax.num_classes],
                         dtype=py_utils.FPropDtype(p)),
      atten_probs=atten_probs)

  batch_size = num_hyps
  atten_hidden_dim = p.trans_tpl.tr_atten_tpl.atten_hidden_dim
  if not atten_hidden_dim:
    atten_hidden_dim = p.model_dim

  if p.beam_search.name == 'tpu_beam_search':
    seq_len = p.target_seq_len
  else:
    seq_len = 0

  prefix_states = py_utils.NestedMap({
      'layer_%d' % layer: py_utils.NestedMap({
          'key':
              tf.zeros([seq_len, batch_size, atten_hidden_dim],
                       dtype=py_utils.FPropDtype(p)),
          'value':
              tf.zeros([seq_len, batch_size, atten_hidden_dim],
                       dtype=py_utils.FPropDtype(p)),
      }) for layer in range(p.num_trans_layers)
  })

  return initial_results, py_utils.NestedMap({
      'prefix_states': prefix_states,
      'time_step': tf.constant(0)
  })

def _InitBeamSearchStateCallback(self,
                                 theta,
                                 source_encs,
                                 source_paddings,
                                 num_hyps_per_beam,
                                 additional_source_info=None):
  """Returns initial beam search states.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    source_encs: A tensor of shape [src_len, src_batch, source_dim]. Can be
      [time, batch, depth, num_layers] if is_transparent is set.
    source_paddings: A tensor of shape [src_len, src_batch].
    num_hyps_per_beam: An int, number of hyps to keep per source sentence.
    additional_source_info: a `.NestedMap` of tensors containing extra context
      information about the source that may be useful for decoding.

  Returns:
    A tuple (initial_results, states).
      initial_results: a `.NestedMap` of initial results.
        atten_probs: The initial attention probs, of shape [tgt_batch,
          src_len].
      states: a `.NestedMap` of initial model states.
        prefix_states: A `.NestedMap` of zero-initialized per-layer attention
          key/value caches used for fast decoding.
        time_step: The current decode step, 0-based.
  """
  p = self.params
  # additional_source_info is currently not used.
  del additional_source_info
  num_hyps = py_utils.GetShape(source_encs)[1] * num_hyps_per_beam
  source_len = py_utils.GetShape(source_encs)[0]

  # Dummy attention probs
  atten_probs = tf.ones([num_hyps, source_len]) / tf.to_float(source_len)
  initial_results = py_utils.NestedMap({'atten_probs': atten_probs})

  batch_size = num_hyps
  key_channels = p.model_dim
  value_channels = p.model_dim

  prefix_states = py_utils.NestedMap({
      'layer_%d' % layer: py_utils.NestedMap({
          'key':
              tf.zeros([batch_size, 0, key_channels],
                       dtype=py_utils.FPropDtype(p)),
          'value':
              tf.zeros([batch_size, 0, value_channels],
                       dtype=py_utils.FPropDtype(p)),
      }) for layer in range(p.num_trans_layers)
  })

  return initial_results, py_utils.NestedMap({
      'prefix_states': prefix_states,
      'time_step': 0
  })

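# A minimal sketch (not library code) of how the empty per-layer key/value
# caches initialized above are assumed to grow during fast decoding: each
# decode step appends the new step's projected key/value along the time axis.
# Shapes and values below are made up for illustration.
import numpy as np

batch_size, key_channels = 2, 4
key_cache = np.zeros([batch_size, 0, key_channels])  # as in prefix_states
for step in range(3):
  new_key = np.random.randn(batch_size, 1, key_channels)  # current step only
  key_cache = np.concatenate([key_cache, new_key], axis=1)
# After 3 steps the cache holds all previously computed keys.
assert key_cache.shape == (2, 3, 4)
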
def _ValidateBiases(self, content_bias, positional_bias, n, h):
  if content_bias is not None:
    content_bias = py_utils.HasShape(content_bias, [n, h])
  else:
    content_bias = tf.constant(0, dtype=py_utils.FPropDtype(self.params))
  if positional_bias is not None:
    positional_bias = py_utils.HasShape(positional_bias, [n, h])
  else:
    positional_bias = tf.constant(0, dtype=py_utils.FPropDtype(self.params))
  return content_bias, positional_bias

def ApplyBias():
  """Bias and update log_probs and consistent."""

  def TileForBeamAndFlatten(tensor):
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tensor = tf.tile(
        tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam * src_batch
    return tf.reshape(tensor, [tgt_batch])

  # Consistent if step_ids == labels from previous step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later.
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  local_consistence = tf.math.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  consistent = tf.math.logical_and(states.consistent, local_consistence)

  # Get label, weight slices corresponding to current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(consistent, py_utils.FPropDtype(p))

  # Convert from dense label to sparse label probs.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  label_probs = tf.one_hot(
      label, vocab_size,
      dtype=py_utils.FPropDtype(p))  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs.
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)

  # Ensure that tf.math.log is applied to positive values.
  probs = tf.maximum(probs, tf.constant(1e-12, dtype=probs.dtype))
  return tf.math.log(probs), consistent

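# A minimal sketch (not library code) of the label-biasing interpolation used
# in ApplyBias above: with a per-token bias weight w in [0, 1], the emitted
# distribution is (1 - w) * model_probs + w * one_hot(label). Values and
# shapes here are made up for illustration.
import numpy as np

vocab_size = 4
model_probs = np.array([0.1, 0.2, 0.3, 0.4])  # exp(bs_results.log_probs)
label = 2
w = 0.5
label_probs = np.eye(vocab_size)[label]
biased = (1.0 - w) * model_probs + w * label_probs
# biased == [0.05, 0.1, 0.65, 0.2]; log(max(biased, 1e-12)) is what the
# beam search consumes as the next-step log-probs.
print(np.log(np.maximum(biased, 1e-12)))
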
def zero_state(self, batch_size):
  p = self.params
  num_groups = self.num_groups

  if not p.cumulative:
    return py_utils.NestedMap()

  # Note: Prefer storing data in <=4D tensors, as TFLite doesn't support
  # implicit broadcasting for 5D (or larger) tensors on many operators.
  cached_count_shape = [batch_size, 1]
  cached_moment_shape = [batch_size, num_groups]
  cached_sum = tf.zeros(cached_moment_shape, py_utils.FPropDtype(p))
  cached_count = tf.zeros(cached_count_shape, py_utils.FPropDtype(p))
  cached_var = tf.zeros(cached_moment_shape, py_utils.FPropDtype(p))
  return py_utils.NestedMap(
      cached_sum=cached_sum, cached_count=cached_count, cached_var=cached_var)

def zero_state(self, batch_size):
  p = self.params
  num_groups = self.num_groups

  if not p.cumulative:
    return py_utils.NestedMap()

  if p.input_rank == 4:
    cache_shape = [batch_size, 1, 1, num_groups, 1]
  else:
    cache_shape = [batch_size, 1, num_groups, 1]
  cached_sum = tf.zeros(cache_shape, py_utils.FPropDtype(p))
  cached_count = tf.zeros(cache_shape, py_utils.FPropDtype(p))
  cached_var = tf.zeros(cache_shape, py_utils.FPropDtype(p))
  return py_utils.NestedMap(
      cached_sum=cached_sum, cached_count=cached_count, cached_var=cached_var)

def _Cast(x):
  if x is None:
    return None
  x = tf.convert_to_tensor(x)
  if not x.dtype.is_floating:
    return x
  return tf.cast(x, py_utils.FPropDtype(self.params))

def _config_outfeed(self, xformer, infeed_batch):
  """Sets up the outfeed ops."""
  fprop_dtype = py_utils.FPropDtype(self.model_params.task)

  assert len(infeed_batch) == 6 or len(infeed_batch) == 7, len(infeed_batch)
  if len(infeed_batch) == 7:
    (key, tgt_ids, tgt_segment_id, tgt_segment_pos, tgt_labels, _,
     _) = infeed_batch
  elif len(infeed_batch) == 6:
    (key, tgt_ids, tgt_segment_id, tgt_segment_pos, tgt_labels,
     _) = infeed_batch
  tgt_segment_id = tf.cast(tgt_segment_id, fprop_dtype)

  input_batch = py_utils.NestedMap()
  input_batch.src = py_utils.NestedMap()
  input_batch.src.ids = (0 * tgt_ids)  # unused
  input_batch.src.segment_ids = (0 * tgt_segment_id)  # unused
  input_batch.src.segment_pos = (0 * tgt_segment_pos)  # unused
  input_batch.tgt = py_utils.NestedMap()
  input_batch.tgt.ids = tgt_ids
  input_batch.tgt.segment_ids = tgt_segment_id
  input_batch.tgt.segment_pos = tgt_segment_pos
  input_batch.tgt.labels = tgt_labels  # only used when --fprop=true

  with tpu_summary.context(rewrite_while_loop=True):
    dec_ret = xformer.DecodeIds(xformer.theta, input_batch)
    dec_metrics = tpu_summary.merge_all()
    key = infeed_batch[0]
    return [
        key, tgt_ids, tgt_segment_id, dec_ret.topk_ids, dec_ret.topk_lens,
        dec_ret.topk_scores, dec_metrics
    ]

def testMoEFFLayerFProp(self, use_fflayer_start_moe, use_fflayer_end_moe,
                        expected_aux_loss):
  p = self._GetParams()
  if use_fflayer_start_moe:
    p.fflayer_start_tpl = gshard_builder.MoEBuilder.Params().Set(
        e_dim=2, c_dim=2, num_devices=2)
  if use_fflayer_end_moe:
    p.fflayer_end_tpl = gshard_builder.MoEBuilder.Params().Set(
        e_dim=2, c_dim=2, num_devices=2)
  l = p.Instantiate()
  inputs, paddings = self._GetInputs()
  inputs = tf.convert_to_tensor(inputs)
  paddings = tf.convert_to_tensor(paddings)
  in_nmap = py_utils.NestedMap(features=inputs, paddings=paddings)
  in_nmap.aux_loss = tf.convert_to_tensor(0., py_utils.FPropDtype(p))
  out_nmap = l.FPropDefaultTheta(in_nmap)
  self.assertIn('aux_loss', out_nmap)
  loss = tf.reduce_sum(out_nmap.features) + 0.01 * out_nmap.aux_loss
  grads = tf.gradients(
      loss,
      l.vars.Flatten(),
      unconnected_gradients=tf.UnconnectedGradients.ZERO)
  with self.session() as sess:
    tf.global_variables_initializer().run()
    out_vals = sess.run(out_nmap.features)
    grad_vals = sess.run(grads)
    self.assertEqual(out_nmap.aux_loss.shape, ())
    aux_loss = sess.run(out_nmap.aux_loss)
    self.assertAlmostEqual(expected_aux_loss, aux_loss, places=5)
    print([x.shape for x in out_vals])
    print([g.shape for g in grad_vals])

def _InferenceSubgraph_Default(self):
  """Default inference subgraph."""
  batch_size = None
  seq_length = None
  fp_dtype = py_utils.FPropDtype(self.params)
  tshape = (batch_size, seq_length)
  input_ids = tf.placeholder(dtype=tf.int32, shape=tshape)
  targets = tf.placeholder(dtype=tf.int32, shape=tshape)
  paddings = tf.placeholder(dtype=fp_dtype, shape=tshape)
  weights = tf.placeholder(dtype=fp_dtype, shape=tshape)
  segment_ids = tf.placeholder(dtype=tf.int32, shape=tshape)
  segment_pos = tf.placeholder(dtype=tf.int32, shape=tshape)
  word_count = tf.placeholder(dtype=tf.int32, shape=(batch_size))
  num_sentences = tf.placeholder(dtype=tf.int32, shape=(batch_size))
  feeds = {
      'ids': input_ids,
      'labels': targets,
      'paddings': paddings,
      'weights': weights,
      'segment_ids': segment_ids,
      'segment_pos': segment_pos,
      'word_count': word_count,
      'num_sentences': num_sentences
  }
  input_batch = py_utils.NestedMap(feeds)
  loss, _ = self.FPropTower(self.theta, input_batch)
  fetches = {'loss': loss['loss'][0]}
  return fetches, feeds

def testTransformerStackAlternateLayers(self):
  batch = 3
  tf.flags.FLAGS.tpu_compatible = True
  with self.session(use_gpu=False) as sess:
    model_dim = 2
    num_transformer_layers = 2
    transformer_tpl = layers_with_attention.TransformerLayer.Params()
    transformer_tpl.tr_atten_tpl.num_attention_heads = 1
    transformer_tpl.tr_fflayer_tpl.hidden_dim = 2
    params = mt_layers.TransformerStack.Params().Set(
        name='transformer',
        model_dim=model_dim,
        num_transformer_layers=num_transformer_layers,
        transformer_tpl=[
            transformer_tpl.Copy() for _ in range(num_transformer_layers)
        ],
        random_seed=123456)
    xformer = mt_layers.TransformerStack(params)

    input_arr = np.array([
        [[0, 1]] * batch,
        [[1, -1]] * batch,
    ], dtype=int)
    paddings_arr = np.array([[0] * batch, [0] * batch], dtype=int)
    inputs = tf.constant(
        input_arr.tolist(), dtype=py_utils.FPropDtype(params))
    paddings = tf.constant(
        paddings_arr.tolist(), dtype=py_utils.FPropDtype(params))
    output, _, _ = xformer.FProp(xformer.theta, inputs, paddings)

    tf.global_variables_initializer().run()
    output = sess.run(output)
    print(repr(output))
    # pylint: disable=bad-whitespace
    # pyformat: disable
    self.assertAllCloseAccordingToType(
        np.array([[[-2.17566538, -0.2821945],
                   [-2.17566514, -0.28219438],
                   [-2.17566514, -0.28219438]],
                  [[-0.71516591, -0.90594757],
                   [-0.71516603, -0.90594769],
                   [-0.71516603, -0.90594769]]]),
        output)

def _testDecoderFPropHelper(self, params):
  """Creates decoder from params and computes loss with random inputs."""
  dec = decoder.AsrDecoder(params)
  src_seq_len = 5
  src_enc = tf.random_normal([src_seq_len, 2, 8],
                             seed=982774838,
                             dtype=py_utils.FPropDtype(params))
  src_enc_padding = tf.constant(
      [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
      dtype=py_utils.FPropDtype(params))
  encoder_outputs = py_utils.NestedMap(
      encoded=src_enc, padding=src_enc_padding)
  # shape=[4, 5]
  target_ids = tf.transpose(
      tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                   [5, 6, 7, 8], [10, 5, 2, 5]],
                  dtype=tf.int32))
  # shape=[4, 5]
  target_labels = tf.transpose(
      tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                   [5, 7, 8, 10], [10, 5, 2, 4]],
                  dtype=tf.int32))
  # shape=[4, 5]
  target_paddings = tf.transpose(
      tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0],
                   [1, 1, 1, 0]],
                  dtype=py_utils.FPropDtype(params)))
  target_transcripts = tf.constant(['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
  target_weights = 1.0 - target_paddings
  # ids/labels/weights/paddings are all in [batch, time] shape.
  targets = py_utils.NestedMap({
      'ids': target_ids,
      'labels': target_labels,
      'weights': target_weights,
      'paddings': target_paddings,
      'transcripts': target_transcripts,
  })
  metrics, per_sequence_loss = dec.FPropWithPerExampleLoss(
      encoder_outputs, targets)
  loss = metrics['loss']
  return loss, per_sequence_loss

def ExtendStep(self,
               theta,
               source_vecs,
               prefix_states,
               aux_vecs=None,
               aux_paddings=None,
               t=None):
  """Transformer Layer, extend one step in decoding.

  This function is expected to be called during fast decoding of Transformer
  models.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    source_vecs: [source_batch, dim].
    prefix_states: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    aux_vecs: [aux_time, aux_batch, dim]
    aux_paddings: [aux_time, aux_batch]
    t: a scalar, the current time step, 0-based.

  Returns:
    The attention context vector, [target_batch, source_dim]

    The attention probability vector, [source_time, target_batch]

    Updated prefix states
  """
  p = self.params

  if p.has_aux_atten:
    assert aux_vecs is not None
    assert aux_paddings is not None

  batch_size = tf.shape(source_vecs)[0]

  # First the self-attention layer.
  atten_vec, atten_prob, new_states = self.self_atten.ExtendStep(
      theta.self_atten, source_vecs, prefix_states, t)

  atten_vec = tf.expand_dims(atten_vec, axis=0)
  # Next the source attention layer.
  if p.has_aux_atten:
    atten_vec, atten_prob = self.atten.FProp(theta.atten, atten_vec,
                                             aux_paddings, aux_vecs)

  # Finally, the feedforward layer.
  h = self.fflayer.FProp(
      theta.fflayer, atten_vec,
      tf.zeros([1, batch_size], dtype=py_utils.FPropDtype(p)))
  h = tf.squeeze(h, 0)
  return h, atten_prob, new_states

def FPropTower(self, theta, input_batch):
  p = self.params
  tf.logging.info('input_batch=%r', input_batch)
  ids, paddings, labels_ids, weights = self._TrimIfPossible(
      input_batch.ids, input_batch.paddings, input_batch.labels,
      input_batch.weights)
  fprop_dtype = py_utils.FPropDtype(p)
  paddings = tf.cast(paddings, fprop_dtype)
  weights = tf.cast(weights, fprop_dtype)
  tf.logging.info('inputs={}'.format((ids, paddings, labels_ids, weights)))

  batch_size = tf.shape(ids)[0]
  state0 = None
  labels = py_utils.NestedMap(class_ids=labels_ids, class_weights=weights)
  fprop_kwargs = dict()
  if 'segment_ids' in input_batch:
    fprop_kwargs.update(
        segment_ids=input_batch.segment_ids,
        segment_pos=input_batch.segment_pos)
  xent_output, _ = self.lm.FProp(theta.lm, ids, paddings, state0, labels,
                                 **fprop_kwargs)

  if 'segment_ids' in input_batch:
    num_sentences = input_batch.num_sentences
  else:
    num_sentences = tf.ones(shape=[batch_size], dtype=tf.int32)

  # +num_sentences to account for the end of sequence symbol.
  num_words = tf.cast(
      tf.reduce_sum(input_batch.word_count + num_sentences), fprop_dtype)
  predicted_labels = tf.cast(xent_output.per_example_argmax, labels_ids.dtype)

  num_preds = xent_output.total_weight
  mean_acc = tf.reduce_sum(
      tf.cast(tf.equal(labels_ids, predicted_labels), fprop_dtype) *
      weights) / tf.math.maximum(num_preds, 1)

  loss = xent_output.avg_xent
  per_sequence_loss = tf.reduce_sum(
      xent_output.per_example_xent * weights, axis=1)
  if p.train.sum_loss_across_tokens_in_batch:
    loss = xent_output.total_xent
  else:
    per_sequence_loss /= tf.reduce_sum(weights, axis=1)
  return {
      'loss': (loss, num_preds),
      'fraction_of_correct_next_step_preds': (mean_acc, num_preds),
      'log_pplx': (xent_output.avg_xent, num_preds),
      'log_pplx_per_word': (xent_output.total_xent / num_words, num_words),
      'num_predictions': (num_preds, 1),
      'num_words': (num_words, 1),
      'num_sentences': (tf.reduce_sum(num_sentences), 1),
  }, {
      'loss': per_sequence_loss,
  }

def _getDecoderFPropMetrics(self, params):
  """Creates decoder from params and computes metrics with random inputs."""
  dec = params.Instantiate()
  src_seq_len = 5
  src_enc = tf.random.normal([src_seq_len, 2, 8],
                             seed=982774838,
                             dtype=py_utils.FPropDtype(params))
  src_enc_padding = tf.constant(
      [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
      dtype=py_utils.FPropDtype(params))
  encoder_outputs = py_utils.NestedMap(
      encoded=src_enc, padding=src_enc_padding)
  # shape=[4, 5]
  target_ids = tf.transpose(
      tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                   [5, 6, 7, 8], [10, 5, 2, 5]],
                  dtype=tf.int32))
  # shape=[4, 5]
  target_labels = tf.transpose(
      tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                   [5, 7, 8, 10], [10, 5, 2, 4]],
                  dtype=tf.int32))
  # shape=[4, 5]
  target_paddings = tf.transpose(
      tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0],
                   [1, 1, 1, 0]],
                  dtype=py_utils.FPropDtype(params)))
  target_transcripts = tf.constant(['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
  target_weights = 1.0 - target_paddings
  # ids/labels/weights/paddings are all in [batch, time] shape.
  targets = py_utils.NestedMap({
      'ids': target_ids,
      'labels': target_labels,
      'weights': target_weights,
      'paddings': target_paddings,
      'transcripts': target_transcripts,
  })
  decoder_outputs = dec.FPropDefaultTheta(encoder_outputs, targets)
  return decoder_outputs.metrics, decoder_outputs.per_sequence['loss']

def _UpdateVnConfig(self):
  """Update vn config from the various vn flags."""
  p = self.params
  tp = p.train
  if tp:
    vn_enabled = ((tp.vn_std > 0) and p.vn and
                  (p.vn.global_vn or p.vn.per_step_vn))
    if p.is_eval or (not vn_enabled):
      p.vn = py_utils.VariationalNoiseParams(None, False, False)
    else:
      # vn.scale is dependent on global_step.
      p.vn.scale = tf.cast(self._global_step > tp.vn_start_step,
                           py_utils.FPropDtype(p)) * tp.vn_std

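# A minimal sketch (not library code) of the variational-noise schedule set
# up in _UpdateVnConfig above: the noise scale stays at 0 until the global
# step passes vn_start_step, then jumps to vn_std. The values below are made
# up for illustration.
vn_std = 0.075
vn_start_step = 1000
for global_step in [0, 999, 1000, 1001]:
  scale = float(global_step > vn_start_step) * vn_std
  print(global_step, scale)  # prints 0.0, 0.0, 0.0, 0.075
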
def testTransformerStackAlternateLayers(self):
  batch = 3
  tf.flags.FLAGS.tpu_compatible = True
  with self.session(use_gpu=False):
    model_dim = 2
    num_transformer_layers = 2
    transformer_tpl = layers_with_attention.TransformerLayer.Params()
    transformer_tpl.tr_atten_tpl.num_attention_heads = 1
    transformer_tpl.tr_fflayer_tpl.hidden_dim = 2
    params = mt_layers.TransformerStack.Params().Set(
        name='transformer',
        model_dim=model_dim,
        num_transformer_layers=num_transformer_layers,
        transformer_tpl=[
            transformer_tpl.Copy() for _ in range(num_transformer_layers)
        ],
        random_seed=123456)
    xformer = mt_layers.TransformerStack(params)

    input_arr = np.array([
        [[0, 1]] * batch,
        [[1, -1]] * batch,
    ], dtype=int)
    paddings_arr = np.array([[0] * batch, [0] * batch], dtype=int)
    inputs = tf.constant(
        input_arr.tolist(), dtype=py_utils.FPropDtype(params))
    paddings = tf.constant(
        paddings_arr.tolist(), dtype=py_utils.FPropDtype(params))
    output, _, _ = xformer.FProp(xformer.theta, inputs, paddings)

    self.evaluate(tf.global_variables_initializer())
    output = self.evaluate(output)
    print(repr(output))
    self.assertAllCloseAccordingToType(
        np.array([[[-0.940543, 1.479253]] * batch,
                  [[-0.413938, -2.550903]] * batch]),
        output)

def FProp(self, theta, inp):
  """Look up style embedding."""

  p = self.params
  b_size = tf.shape(inp)[0]
  styles_w = tf.tile(tf.nn.tanh(theta.styles_w), [1, b_size, 1])
  styles_paddings = tf.zeros([p.num_styles, b_size],
                             dtype=py_utils.FPropDtype(p))
  packed_src = self.atten.InitForSourcePacked(theta.atten, styles_w,
                                              styles_w, styles_paddings)
  style_emb, probs, _ = self.atten.ComputeContextVectorWithSource(
      theta.atten, packed_src, inp)
  # TODO(yonghui): Extract and return the attention probabilities.
  return style_emb, probs

def _CreateSourceAndTargets(params):
  """Creates encoder outputs and targets from params for the decoder."""
  src_seq_len = 5
  src_enc = tf.random.normal([src_seq_len, 2, 8],
                             seed=982774838,
                             dtype=py_utils.FPropDtype(params))
  src_enc_padding = tf.constant(
      [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
      dtype=py_utils.FPropDtype(params))
  encoder_outputs = py_utils.NestedMap(
      encoded=src_enc, padding=src_enc_padding)
  # shape=[4, 5]
  target_ids = tf.transpose(
      tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                   [5, 6, 7, 8], [10, 5, 2, 5]],
                  dtype=tf.int32))
  # shape=[4, 5]
  target_labels = tf.transpose(
      tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                   [5, 7, 8, 10], [10, 5, 2, 4]],
                  dtype=tf.int32))
  # shape=[4, 5]
  target_paddings = tf.transpose(
      tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0],
                   [1, 1, 1, 0]],
                  dtype=py_utils.FPropDtype(params)))
  target_transcripts = tf.constant(['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
  target_weights = 1.0 - target_paddings
  # ids/labels/weights/paddings are all in [batch, time] shape.
  targets = py_utils.NestedMap({
      'ids': target_ids,
      'labels': target_labels,
      'weights': target_weights,
      'paddings': target_paddings,
      'transcripts': target_transcripts,
  })
  return encoder_outputs, targets

def FPropTower(self, theta, input_batch):
  p = self.params
  fprop_dtype = py_utils.FPropDtype(p)
  tf.logging.info('input_batch=%r', input_batch)
  ids = input_batch.ids
  labels_ids = input_batch.labels
  paddings = tf.cast(input_batch.paddings, fprop_dtype)
  weights = tf.cast(input_batch.weights, fprop_dtype)
  tf.logging.info('inputs={}'.format((ids, paddings, labels_ids, weights)))

  batch_size = tf.shape(ids)[0]
  state0 = self.lm.zero_state(theta.lm, batch_size)
  labels = py_utils.NestedMap(class_ids=labels_ids, class_weights=weights)
  xent_output, _ = self.lm.FProp(
      theta.lm,
      ids,
      paddings,
      state0,
      labels,
      segment_ids=input_batch.segment_ids,
      segment_pos=input_batch.segment_pos)

  # +input_batch.num_sentences to account for the end of sequence symbol.
  num_words = tf.cast(
      tf.reduce_sum(input_batch.word_count +
                    tf.cast(input_batch.num_sentences, dtype=tf.int32)),
      fprop_dtype)
  predicted_labels = tf.cast(xent_output.per_example_argmax, labels_ids.dtype)

  num_sentences = tf.reduce_sum(input_batch.num_sentences)
  num_preds = xent_output.total_weight
  mean_acc = tf.reduce_sum(
      tf.cast(tf.equal(labels_ids, predicted_labels), fprop_dtype) *
      weights) / tf.math.maximum(num_preds, 1)
  loss = xent_output.avg_xent
  return {
      'loss': (loss, num_preds),
      'fraction_of_correct_next_step_preds': (mean_acc, num_preds),
      'log_pplx': (xent_output.avg_xent, num_preds),
      'log_pplx_per_word': (xent_output.total_xent / num_words, num_words),
      'num_predictions': (num_preds, 1),
      'num_words': (num_words, 1),
      'num_sentences': (num_sentences, 1)
  }, {}

def FProp(self, theta, ids, segment_pos):
  p = self.params
  fprop_dtype = py_utils.FPropDtype(p)

  ids = self._MaybeSplit(ids)
  segment_pos = self._MaybeSplit(segment_pos)

  one_hot_ids = tf.one_hot(ids, p.vocab_size, dtype=fprop_dtype)
  one_hot_ids = self._MaybeSplit(one_hot_ids)

  one_hot_pos = tf.one_hot(segment_pos, p.max_len, dtype=fprop_dtype)
  one_hot_pos = self._MaybeSplit(one_hot_pos)

  token_emb = tf.einsum('VH,BLV->BLH', theta.embedding, one_hot_ids)
  token_emb = self._MaybeSplit(token_emb)

  pos_emb = tf.einsum('VH,BLV->BLH', theta.pos_emb, one_hot_pos)
  pos_emb = self._MaybeSplit(pos_emb)
  return self._MaybeSplit(token_emb + pos_emb)

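# A minimal sketch (not library code) checking the identity the one-hot
# einsum lookup in FProp above relies on: einsum('VH,BLV->BLH', emb,
# one_hot(ids)) equals a plain embedding gather, emb[ids]. Sizes are made up.
import numpy as np

vocab_size, hidden, batch, seqlen = 5, 3, 2, 4
emb = np.random.randn(vocab_size, hidden)
ids = np.random.randint(0, vocab_size, size=(batch, seqlen))
one_hot_ids = np.eye(vocab_size)[ids]  # [B, L, V]
via_einsum = np.einsum('VH,BLV->BLH', emb, one_hot_ids)
via_gather = emb[ids]
assert np.allclose(via_einsum, via_gather)
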
def zero_state(self, batch_size):
  """Returns the initial state given the batch size.

  Args:
    batch_size: the batch size.

  Returns:
    state0: A NestedMap of tensors including:
      - context: A Tensor of shape [b, filter_shape[0]-1, 1, c].
  """
  p = self.params
  assert p.filter_shape[1] == 1, (
      'zero_state() only supports 1d causal convolution.')

  context = tf.zeros(
      shape=[batch_size] +
      [p.filter_shape[0] - 1, p.filter_shape[1], p.filter_shape[2]],
      dtype=py_utils.FPropDtype(p))
  return py_utils.NestedMap(context=context)

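# A minimal sketch (not library code) of the zero-state shape produced above
# for a 1-d causal convolution: the cached context holds the most recent
# filter_shape[0] - 1 frames. The filter shape below is made up and assumed
# to be [time, 1, in_channels, channel_multiplier].
import numpy as np

filter_shape = [3, 1, 4, 8]
batch_size = 2
context = np.zeros(
    [batch_size, filter_shape[0] - 1, filter_shape[1], filter_shape[2]])
assert context.shape == (2, 2, 1, 4)  # [b, filter_shape[0]-1, 1, c]
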
def _InitBeamSearchStateCallback(self, theta, encoder_outputs,
                                 num_hyps_per_beam):
  """Returns initial beam search states.

  Args:
    theta: a NestedMap of parameters.
    encoder_outputs: a NestedMap computed by encoder.
    num_hyps_per_beam: An int, number of hyps to keep per source sentence.

  Returns:
    A tuple (initial_results, states).
      initial_results: a `.NestedMap` of initial results.
        atten_probs: The initial attention probs, of shape [tgt_batch,
          src_len].
      states: a `.NestedMap` of initial model states.
        rnn_states: Initial state of the RNN.
        atten_context: Initial attention context vector.
        atten_states: Initial attention state.
  """
  p = self.params
  num_beams = py_utils.GetShape(encoder_outputs.padding)[1]
  num_hyps = num_beams * num_hyps_per_beam
  rnn_states, init_atten_context, atten_probs, atten_states = (
      self._InitDecoder(theta, encoder_outputs, num_hyps))

  initial_results = py_utils.NestedMap(
      log_probs=tf.zeros([num_hyps, p.softmax.num_classes],
                         dtype=py_utils.FPropDtype(p)),
      atten_probs=atten_probs)

  return initial_results, py_utils.NestedMap({
      'rnn_states': rnn_states,
      'atten_context': init_atten_context,
      'atten_probs': atten_probs,
      'atten_states': atten_states,
  })

def Callback(theta, encoder_outputs, num_hyps_per_beam):
  initial_results, states = self._InitBeamSearchStateCallback(
      theta, encoder_outputs, num_hyps_per_beam)
  assert hasattr(states, 'time_step')
  if tf.is_tensor(encoder_outputs.padding):
    batch_size = tf.shape(encoder_outputs.padding)[1]
  else:
    # Required for multisource models.
    batch_size = tf.shape(list(encoder_outputs.padding.values())[0])[1]
  num_hyps = batch_size * num_hyps_per_beam

  if biased:
    # states.consistent is initially all True
    states.consistent = tf.ones([
        num_hyps,
    ], dtype=tf.bool)

  if stochastic:
    dtype = py_utils.FPropDtype(self.params)
    states.cumulative_log_probs = tf.zeros([num_hyps, 1], dtype=dtype)
    states.perturbed_cumulative_log_probs = tf.zeros([num_hyps, 1],
                                                     dtype=dtype)
    # Temporary tensors that store information passed from
    # PreBeamSearchStepCallback to PostBeamSearchStepCallback. These are
    # used for updating states.cumulative_log_probs and
    # states.perturbed_cumulative_log_probs for the next step, which
    # requires the knowledge of the chosen IDs, which only becomes available
    # after PreBeamSearchStepCallback.
    states.tmp_states = py_utils.NestedMap(
        # Top-k (non-perturbed) log-probs. Used for updating
        # `cumulative_log_probs` in PostBeamSearchStepCallback.
        top_k_log_probs=tf.zeros([num_hyps, k], dtype=dtype),
        # Vocab ID of each item of `top_k_log_probs`.
        top_k_ids=tf.zeros([num_hyps, k], dtype=tf.int32),
        # Perturbed cumulative log-probs of the top-k IDs. Used for updating
        # `perturbed_cumulative_log_probs` in PostBeamSearchStepCallback.
        new_perturbed_cumulative_log_probs=tf.zeros([num_hyps, k],
                                                    dtype=dtype),
    )

  return initial_results, states

def FProp(self, theta, inputs, paddings, state0=None, labels=None):
  """Computes xent loss given the language model input activations.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: Input ids. An int32 tensor of shape [time, batch].
    paddings: A 0/1 tensor of shape [time, batch].
    state0: Not used for Transformer.
    labels: If not None, a `.NestedMap` containing the following fields:

      - class_weights, a tensor with shape [time, batch] containing the
        weights for each target word.
      - class_ids, a tensor with shape [time, batch] of int32 dtype containing
        the target class labels.
      - class_probabilities, a tensor with shape [time, batch, vocab_size] of
        float values indicating class-membership probabilities.

  Returns:
    If `labels` is not None, returns (xent_output, state1), where
    `xent_output` is a `.NestedMap` as defined by `SoftmaxLayer`'s return
    value and `state1` is the next recurrent state. Otherwise, `xent_output`
    only contains the softmax logits.
  """
  p = self.params
  ids = py_utils.HasRank(inputs, 2)
  paddings = py_utils.HasShape(paddings, tf.shape(ids))

  per_example_xent, logits = self.stack.FProp(
      theta.stack, ids, paddings, None, None, None, None,
      tf.cast(labels.class_ids, py_utils.FPropDtype(p)),
      labels.class_weights)

  per_example_argmax = py_utils.ArgMax(logits)
  total_xent = tf.reduce_sum(per_example_xent * labels.class_weights)
  total_weights = tf.reduce_sum(labels.class_weights)

  xent_output = py_utils.NestedMap(
      total_weight=total_weights,
      per_example_xent=per_example_xent,
      logits=logits,
      per_example_argmax=per_example_argmax,
      avg_xent=total_xent / total_weights,
      total_xent=total_xent)
  return xent_output, {}

def StyleEmbFromProbs(self, theta, inp):
  """Looks up style embedding based on fed-in probabilities.

  Args:
    theta: params for this layer and its sub-layers.
    inp: attention probabilities of shape [batch_size, num_styles].

  Returns:
    style_emb - weighted combined style embedding based on inp.
  """
  p = self.params
  b_size = tf.shape(inp)[0]
  styles_w = tf.tile(tf.nn.tanh(theta.styles_w), [1, b_size, 1])
  styles_paddings = tf.zeros([p.num_styles, b_size],
                             dtype=py_utils.FPropDtype(p))
  atten_probs = tf.tile(tf.expand_dims(inp, 1), [1, p.num_heads, 1])
  atten_probs = tf.reshape(atten_probs, [-1, p.num_styles])
  packed_src = self.atten.InitForSourcePacked(theta.atten, styles_w,
                                              styles_w, styles_paddings)
  style_emb, _ = self.atten.ComputeContextVectorWithAttenProbs(
      theta.atten, packed_src.source_contexts, atten_probs)
  return style_emb

def ZeroState(self, theta, prepared_inputs, batch_size):
  """Produce a zero state for this step.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    prepared_inputs: A set of inputs pre-processed by using
      PrepareExternalInputs.
    batch_size: Number of elements in the batched input.

  Returns:
    state0, a state parameter to pass to FProp on its first invocation.
  """
  max_seq_length = py_utils.GetShape(prepared_inputs.src, 3)[0]
  atten_state = self.atten.ZeroAttentionState(max_seq_length, batch_size)
  (new_atten_context, _,
   new_atten_states) = self.atten.ComputeContextVectorWithSource(
       theta.atten,
       prepared_inputs.packed_src,
       tf.zeros([batch_size, self.params.atten.query_dim],
                dtype=py_utils.FPropDtype(self.params)),
       attention_state=atten_state)
  return py_utils.NestedMap(
      atten_context=new_atten_context, atten_state=new_atten_states)

def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task ids
        of shape [batch, time].

  Returns:
    A NestedMap containing

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set in
      transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [time, batch, depth] embedded input tokens without
      positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    if (not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs):
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [time, batch, dim]
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.task_emb:
      input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                            input_batch.task_ids)

    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)

def Cast(self, v):
  """Cast tensor dtype to fprop_dtype."""
  if not v.dtype.is_floating:
    return v
  return tf.cast(v, py_utils.FPropDtype(self.params))

def FProp(self,
          theta,
          query_vec,
          source_paddings,
          source_vecs=None,
          query_segment_id=None,
          source_segment_id=None):
  """Transformer attention, residual and normalization layer.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    query_vec: [target_time, target_batch, dim]
    source_paddings: [source_time, source_batch]
    source_vecs: [source_time, source_batch, dim].
    query_segment_id: [target_time, target_batch]
    source_segment_id: [source_time, source_batch]

  Returns:
    (output, atten_probs). output is of shape [target_time, target_batch,
    source_dim], atten_probs is of shape [target_time, target_batch,
    source_time].
  """
  p = self.params
  unnormalized_query_vec = query_vec
  query_vec = self.layer_norm.FProp(theta.layer_norm, query_vec)

  if source_vecs is None:
    source_vecs = query_vec
    source_segment_id = query_segment_id

  if p.is_masked:
    assert source_vecs is not None
    query_vec = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(source_vecs), tf.shape(query_vec))
    ], query_vec)

    # Prepares mask for self-attention
    # [time, time]
    target_time = tf.shape(query_vec)[0]
    target_bs = tf.shape(query_vec)[1]
    triangle_padding = 1.0 - tf.matrix_band_part(
        tf.ones([target_time, target_time], dtype=py_utils.FPropDtype(p)),
        -1, 0)
    # [time, batch, time]
    causal_padding = tf.tile(
        tf.expand_dims(triangle_padding, 1), [1, target_bs, 1])
    causal_padding = tf.reshape(causal_padding, [-1, target_time])
  else:
    causal_padding = None

  query_dim = tf.shape(query_vec)[-1]
  packed_src = self.atten.PackSource(theta.atten, source_vecs, source_vecs,
                                     source_paddings, source_segment_id)

  if query_segment_id is not None:
    query_segment_id = tf.reshape(query_segment_id, [-1])
  ctx_vec, atten_prob, _ = self.atten.ComputeContextVectorWithSource(
      theta.atten,
      packed_src,
      tf.reshape(query_vec, [-1, query_dim]),
      per_step_source_padding=causal_padding,
      query_segment_id=query_segment_id)
  ctx_vec = self.residual_dropout.FProp(theta.residual_dropout, ctx_vec)
  input_to_add = (
      unnormalized_query_vec if p.add_unnormalized_input else query_vec)
  h = input_to_add + tf.reshape(ctx_vec, tf.shape(query_vec))
  atten_prob = tf.reshape(atten_prob, [
      tf.shape(query_vec)[0],
      tf.shape(query_vec)[1],
      tf.shape(source_vecs)[0]
  ])
  return h, atten_prob

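# A minimal sketch (not library code) of the causal self-attention padding
# built in FProp above: 1.0 - band_part(ones, -1, 0) keeps only the strictly
# upper triangle, so position t is padded out from attending to positions > t.
import numpy as np

target_time = 4
triangle_padding = 1.0 - np.tril(np.ones([target_time, target_time]))
# [[0., 1., 1., 1.],
#  [0., 0., 1., 1.],
#  [0., 0., 0., 1.],
#  [0., 0., 0., 0.]]
print(triangle_padding)
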