def weights_fn(labels):
    """Return per-token loss weights for `labels`.

    The result is the elementwise product of the weights produced by
    `targets_weights_fn(labels)` and the explicitly given `weights`. Both
    names are free variables resolved outside this function.
    """
    from_modality = targets_weights_fn(labels)

    # TF broadcasts along major dimensions by default, so a lower-rank
    # `weights` is expanded along the minor dimensions before multiplying.
    if len(weights.shape) < len(from_modality.shape):
        from_caller = common_layers.expand_squeeze_to_nd(
            weights, from_modality.shape.ndims)
    else:
        from_caller = weights

    return from_caller * from_modality
def _import_feature(self, features, mesh, key): """Import a feature from the features dictionary into a mtf.Tensor. Args: features: a features dictionary mesh: a Mesh key: a string Returns: a mtf.Tensor with dtype int32 and shape self.batch_dims + self.length_dim """ if key not in features: return None x = tf.to_int32(features[key]) x = common_layers.expand_squeeze_to_nd(x, 2) # pad to length extra_length = self.length_dim.size - tf.shape(x)[1] x = tf.pad(x, [[0, 0], [0, extra_length]]) mtf_shape = mtf.Shape(self.batch_dims + [self.length_dim]) x = tf.reshape(x, mtf_shape.to_integer_list) return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)
def sample(self, features, mesh):
    """Sample autoregressively, forcing outputs to start with a given prefix.

    Args:
        features: a features dictionary. The prefix is read from
            features["inputs"] or, when that is missing or None, from
            features["targets"]. When neither is available an all-zero id
            tensor is used instead.
        mesh: a Mesh

    Returns:
        Whatever `self.model().sample_autoregressive(...)` returns.

    Raises:
        NotImplementedError: if hparams.beam_size != 1.
    """
    hparams = self._hparams
    model = self.model()

    prefix = features.get("inputs")
    if prefix is None:
        prefix = features.get("targets")

    if prefix is None:
        # Nothing to force: start from a zero-filled id tensor.
        zero_ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
        prefix = mtf.constant(mesh, 0, zero_ids_shape, dtype=tf.int32)
    else:
        prefix = tf.to_int32(common_layers.expand_squeeze_to_nd(prefix, 2))
        prefix_batch = tf.shape(prefix)[0]
        prefix_length = tf.shape(prefix)[1]
        # Zero-pad up to the full batch size and sequence length.
        paddings = [
            [0, hparams.batch_size - prefix_batch],
            [0, self.length_dim.size - prefix_length],
        ]
        prefix = tf.pad(prefix, paddings)
        prefix = self._import_to_batch_by_length(
            prefix, "partial_targets", mesh)
        # strip EOS: entries equal to 1 are multiplied by 0.
        prefix *= mtf.to_int32(mtf.not_equal(prefix, 1))

    if hparams.beam_size != 1:
        # Beam search would need a "beam" dimension in the ids shape; that
        # path has not been written.
        raise NotImplementedError("not implemented")

    prefix = mtf.Print(prefix, [prefix], "Partial_Targets", summarize=1000)
    return model.sample_autoregressive(
        prefix,
        temperature=hparams.sampling_temp,
        variable_dtype=self.variable_dtype)
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0):
    """Decode step by step through `fast_decode`, with context features.

    When the problem has inputs, every feature whose name contains
    'context' but not 'raw' is passed through its modality's bottom and
    concatenated (axis 1) in front of the embedded inputs before encoding.
    When the problem has no inputs, features["inputs"] (or
    features["targets"]) is used as a prefix the outputs are forced to
    start with; that prefix is removed from the returned outputs.

    Args:
        features: a map of string to model features.
        decode_length: number of steps added to the input (or prefix) length;
            may be overridden by features["decode_length"].
        beam_size: forwarded to `fast_decode`; also the tile count for the
            forced prefix ids.
        top_beams: forwarded to `fast_decode`.
        alpha: forwarded to `fast_decode`.

    Returns:
        The dict returned by `fast_decode`; when a prefix was forced,
        ret["outputs"] has the prefix positions sliced off.

    Raises:
        NotImplementedError: if there is more than one data shard, or if
            `features` contains "targets_segmentation" (packed dataset).
    """
    if self._num_datashards != 1:
        raise NotImplementedError(
            "Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.modality["targets"]
    if "targets_segmentation" in features:
        raise NotImplementedError(
            "Decoding not supported on packed datasets "
            " If you want to decode from a dataset, use the non-packed version"
            " of the dataset when decoding.")
    if self.has_input:
        inputs = features["inputs"]
        if target_modality.is_class_modality:
            decode_length = 1
        else:
            decode_length = (common_layers.shape_list(inputs)[1] +
                             features.get("decode_length", decode_length))
        # Collect the context features by name.
        contexts = {}
        for feature_name in features:
            if 'context' in feature_name and 'raw' not in feature_name:
                contexts[feature_name] = features[feature_name]
        # Insert an axis at position 1, make sure the tensor is rank 5, then
        # fold axes 0 and 1 together.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = common_layers.shape_list(inputs)
        batch_size = s[0]
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.modality["inputs"]
        # A context feature uses its own modality when one is registered
        # under its name, otherwise it falls back to the input modality.
        context_modality = {}
        for context_name in contexts:
            if context_name in self._problem_hparams.modality:
                context_modality[
                    context_name] = self._problem_hparams.modality[
                        context_name]
            else:
                context_modality[context_name] = input_modality
        with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE):
            inputs = input_modality.bottom_sharded(inputs, dp)
        for feature_name in contexts:
            with tf.variable_scope(context_modality[feature_name].name,
                                   reuse=tf.AUTO_REUSE):
                contexts[feature_name] = context_modality[
                    feature_name].bottom_sharded(contexts[feature_name], dp)
        # Only shard 0 of each context is used (single-shard decoding).
        contexts_list = [
            contexts[feature_name][0] for feature_name in contexts
        ]
        # NOTE(review): contexts_list is empty when no feature name matches
        # the filter above - confirm callers always supply a context feature.
        contexts = tf.concat(contexts_list, axis=1)
        # Contexts go in front of the inputs along axis 1.
        inputs = [tf.concat([contexts, inputs[0]], axis=1)]
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode, inputs, features["target_space_id"], hparams,
                features=features)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
        partial_targets = None
    else:
        # The problem has no inputs.
        encoder_output = None
        encoder_decoder_attention_bias = None
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs")
        if partial_targets is None:
            partial_targets = features["targets"]
        assert partial_targets is not None
        partial_targets = common_layers.expand_squeeze_to_nd(
            partial_targets, 2)
        partial_targets = tf.to_int64(partial_targets)
        partial_targets_shape = common_layers.shape_list(partial_targets)
        partial_targets_length = partial_targets_shape[1]
        decode_length = (partial_targets_length +
                         features.get("decode_length", decode_length))
        batch_size = partial_targets_shape[0]
    if hparams.pos == "timing":
        positional_encoding = common_attention.get_timing_signal_1d(
            decode_length + 1, hparams.hidden_size)
    elif hparams.pos == "emb":
        positional_encoding = common_attention.add_positional_embedding(
            tf.zeros([1, decode_length, hparams.hidden_size]),
            hparams.max_length, "body/targets_positional_embedding", None)
    else:
        positional_encoding = None

    def preprocess_targets(targets, i):
        """Performs preprocessing steps on the targets to prepare for the decoder.

        This includes:
          - Embedding the ids.
          - Flattening to 3D tensor.
          - Optionally adding timing signals.

        Args:
          targets: inputs ids to the decoder. [batch_size, 1]
          i: scalar, Step number of the decoding loop.

        Returns:
          Processed targets [batch_size, 1, hidden_dim]
        """
        # _shard_features called to ensure that the variable names match
        targets = self._shard_features({"targets": targets})["targets"]
        with tf.variable_scope(target_modality.name):
            targets = target_modality.targets_bottom_sharded(targets, dp)[0]
        targets = common_layers.flatten4d3d(targets)
        # At step 0 the embedded targets are replaced by zeros.
        targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                          lambda: targets)
        if positional_encoding is not None:
            targets += positional_encoding[:, i:i + 1]
        return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            decode_length)

    def symbols_to_logits_fn(ids, i, cache):
        """Go from ids to logits for next symbol."""
        # Only the most recent id is fed to the decoder.
        ids = ids[:, -1:]
        targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
        targets = preprocess_targets(targets, i)
        # Row i of the self-attention bias, limited to positions 0..i.
        bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
        with tf.variable_scope("body"):
            body_outputs = dp(self.decode, targets,
                              cache.get("encoder_output"),
                              cache.get("encoder_decoder_attention_bias"),
                              bias, hparams, cache,
                              nonpadding=features_to_nonpadding(
                                  features, "targets"))
        with tf.variable_scope(target_modality.name):
            logits = target_modality.top_sharded(body_outputs, None, dp)[0]
        ret = tf.squeeze(logits, axis=[1, 2, 3])
        if partial_targets is not None:
            # If the position is within the given partial targets, we alter the
            # logits to always return those values.
            # A faster approach would be to process the partial targets in one
            # iteration in order to fill the corresponding parts of the cache.
            # This would require broader changes, though.
            vocab_size = tf.shape(ret)[1]

            def forced_logits():
                # 0.0 at the forced id and -1e9 everywhere else.
                return tf.one_hot(
                    tf.tile(partial_targets[:, i], [beam_size]), vocab_size,
                    0.0, -1e9)

            ret = tf.cond(tf.less(i, partial_targets_length), forced_logits,
                          lambda: ret)
        return ret, cache

    ret = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_modality.top_dimensionality,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length)
    if partial_targets is not None:
        # Drop the forced prefix from the outputs; the beam result carries
        # one extra leading axis.
        if beam_size <= 1 or top_beams <= 1:
            ret["outputs"] = ret["outputs"][:, partial_targets_length:]
        else:
            ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
    return ret
def _sample(self, features, mesh):
    """Build the decoding computation for this model on `mesh`.

    For transformer_type "encdec" the inputs are encoded once and the
    encoder-decoder key/value tensors are precomputed per "enc_att" layer.
    For "decoder" an optional prefix is read from features["inputs"] or
    features["targets"]. Decoding then runs through
    `mtf.beam_search.greedy_decode` when hparams.beam_size == 1 and through
    `mtf.beam_search.beam_search` otherwise.

    Args:
        features: a features dictionary
        mesh: a Mesh

    Returns:
        The result of `greedy_decode`, or, for beam search, the entry at
        index 0 along the beam dimension of the returned beams.

    Raises:
        ValueError: if hparams.transformer_type is neither "encdec" nor
            "decoder".
    """
    hparams = self._hparams
    (inputs_embedding_var, targets_embedding_var, softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "encdec":
        inputs = features["inputs"]
        # Squeeze axis 2 repeatedly until the inputs are rank 2.
        while len(inputs.shape.as_list()) > 2:
            inputs = tf.squeeze(inputs, axis=2)
        actual_batch_size = tf.shape(inputs)[0]
        actual_length = tf.shape(inputs)[1]
        # Zero-pad to hparams.batch_size x hparams.max_length.
        inputs = tf.pad(inputs, [[0, hparams.batch_size - actual_batch_size],
                                 [0, hparams.max_length - actual_length]])
        inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                 hparams)
        # Token embedding plus positional embedding.
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.reshape(positional_embedding_var,
                         mtf.Shape([self.length_dim, self.model_dim])))
        encoder_attention_mask = (mtf.layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
        with tf.variable_scope("encoder"):
            x = self._layer_stack(
                x, hparams.encoder_layers,
                self_attention_mask=encoder_attention_mask)
        encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                              self.memory_length_dim.name)
        # One entry per decoder layer: (q_var, o_var, k, v) for "enc_att"
        # layers, None for every other layer type.
        encdec_tensors = []
        for layer_num, layer_type in enumerate(hparams.decoder_layers):
            if layer_type == "enc_att":
                with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
                    q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                        mesh, self.heads_dim, self.model_dim, self.kv_dim,
                        self.master_dtype, self.slice_dtype,
                        self.activation_dtype)
                    k = mtf.einsum([encoder_output, k_var],
                                   mtf.Shape(self.batch_dims + [
                                       self.heads_dim, self.memory_length_dim,
                                       self.kv_dim
                                   ]))
                    v = mtf.einsum([encoder_output, v_var],
                                   mtf.Shape(self.batch_dims + [
                                       self.heads_dim, self.memory_length_dim,
                                       self.kv_dim
                                   ]))
                encdec_tensors.append((q_var, o_var, k, v))
            else:
                encdec_tensors.append(None)
        partial_targets = None
    elif hparams.transformer_type == "decoder":
        encdec_tensors = None
        encoder_output = None
        encoder_attention_mask = None
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs", None)
        if partial_targets is None:
            partial_targets = features.get("targets", None)
        if partial_targets is not None:
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int32(partial_targets)
            partial_targets_batch = tf.shape(partial_targets)[0]
            partial_targets_length = tf.shape(partial_targets)[1]
            partial_targets = tf.pad(
                partial_targets,
                [[0, hparams.batch_size - partial_targets_batch],
                 [0, hparams.max_length - partial_targets_length]])
            partial_targets = self._import_to_batch_by_length(
                partial_targets, "partial_targets", mesh, hparams)
    else:
        # NOTE(review): the message says "model_type" but the value checked
        # and reported is hparams.transformer_type.
        raise ValueError("hparams.model_type = %s not yet supported" %
                         hparams.transformer_type)
    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    # Shapes for the ids and the per-layer key/value states; the beam-search
    # variants carry an extra "beam" dimension.
    if hparams.beam_size == 1:
        ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
        kv_shape = mtf.Shape(
            self.batch_dims +
            [self.heads_dim, self.memory_length_dim, self.kv_dim])
        local_kv_shape = mtf.Shape(
            self.batch_dims +
            [self.heads_dim, local_attention_window, self.kv_dim])
    else:
        beam_dim = mtf.Dimension("beam", hparams.beam_size)
        ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
        kv_shape = mtf.Shape(self.batch_dims + [
            beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim
        ])
        local_kv_shape = mtf.Shape(self.batch_dims + [
            beam_dim, self.heads_dim, local_attention_window, self.kv_dim
        ])
    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    # Two zero state tensors per "att" / "local_att" layer. The `* 2` repeats
    # the same tensor object twice in the list.
    initial_states = []
    for layer in hparams.decoder_layers:
        if layer == "att":
            initial_states.extend(
                [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
        elif layer == "local_att":
            initial_states.extend([
                mtf.zeros(
                    mesh, local_kv_shape, dtype=self.activation_dtype)
            ] * 2)

    def logits_fn(step_num, ids, states):
        """Produce logits for this step, and new states."""
        # The decoder input at this step is the id at position step_num - 1.
        ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
        x = (mtf.gather(targets_embedding_var, ids_this_step,
                        self.targets_vocab_dim) +
             mtf.gather(positional_embedding_var, step_num,
                        self.max_length_dim))
        with tf.variable_scope("decoder"):
            x, new_states = self._layer_stack(
                x, hparams.decoder_layers,
                encdec_attention_mask=encoder_attention_mask,
                step_num=step_num,
                encdec_tensors=encdec_tensors,
                states=states)
        logits = mtf.matmul(x, softmax_var)
        return logits, new_states

    if hparams.beam_size == 1:
        # Temperature 0.0 is used for "argmax" sampling.
        temperature = (0.0 if hparams.sampling_method == "argmax" else
                       hparams.sampling_temp)
        return mtf.beam_search.greedy_decode(logits_fn,
                                             initial_ids,
                                             temperature=temperature,
                                             initial_states=initial_states,
                                             forced_ids=partial_targets,
                                             use_tpu=hparams.use_tpu)
    else:
        # NOTE(review): partial_targets is not passed to beam_search, so a
        # prefix is only forced on the greedy path - confirm this is intended.
        if hparams.transformer_type == "encdec":
            # Count the nonzero input ids per example, then scale the
            # maximum count to obtain the decode length.
            input_length = mtf.reduce_sum(mtf.to_float(
                mtf.cast(inputs, tf.bool)), reduced_dim=self.length_dim)
            max_input_length = mtf.reduce_max(input_length)
            decode_length = mtf.cast(
                max_input_length * hparams.decode_length_multiplier +
                hparams.decode_length_constant, tf.int32)
        else:
            decode_length = None
        beams, unused_scores = mtf.beam_search.beam_search(
            logits_fn,
            initial_ids,
            hparams.alpha,
            states=initial_states,
            decode_length=decode_length,
            use_tpu=hparams.use_tpu,
            dtype=self.activation_dtype)
        # Keep only index 0 along the beam dimension.
        return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32),
                          beam_dim)
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    In addition to decoding, this variant seeds the decoding cache with
    empty "attention_history" tensors (one per decoder layer) and an empty
    "body_outputs" tensor, and appends to them at every step.

    Args:
      features: a map of string to model features.
      decode_length: an integer. How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha,
        stronger the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards, or if
        `features` contains "targets_segmentation" (packed dataset).
    """
    if self._num_datashards != 1:
        raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.modality["targets"]
    target_vocab_size = self._problem_hparams.vocab_size["targets"]
    # Round the vocab size up to the next multiple of vocab_divisor.
    if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor
    if "targets_segmentation" in features:
        raise NotImplementedError(
            "Decoding not supported on packed datasets "
            " If you want to decode from a dataset, use the non-packed version"
            " of the dataset when decoding.")
    if self.has_input:
        inputs = features["inputs"]
        if target_modality == modalities.ModalityType.CLASS_LABEL:
            decode_length = 1
        else:
            decode_length = (
                common_layers.shape_list(inputs)[1] + features.get(
                    "decode_length", decode_length))
        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = common_layers.shape_list(inputs)
        batch_size = s[0]
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.modality["inputs"]
        input_vocab_size = self._problem_hparams.vocab_size["inputs"]
        # Same rounding as for the target vocab size above.
        if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
            input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
        # hparams.name / hparams.bottom may override the modality's default
        # name function / bottom function for "inputs".
        modality_name = hparams.name.get(
            "inputs", modalities.get_name(input_modality))(hparams,
                                                           input_vocab_size)
        with tf.variable_scope(modality_name):
            bottom = hparams.bottom.get("inputs",
                                        modalities.get_bottom(input_modality))
            inputs = dp(bottom, inputs, hparams, input_vocab_size)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode, inputs, features["target_space_id"], hparams,
                features=features)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
        # An optional forced prefix may be supplied alongside the inputs.
        if 'partial_targets' in features:
            partial_targets = features['partial_targets']
        else:
            partial_targets = None
    else:
        # The problem has no inputs.
        encoder_output = None
        encoder_decoder_attention_bias = None
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs")
        if partial_targets is None:
            partial_targets = features["targets"]
        assert partial_targets is not None
    if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int64(partial_targets)
        partial_targets_shape = common_layers.shape_list(partial_targets)
        partial_targets_length = partial_targets_shape[1]
        # With a prefix, the decode length is counted from the prefix length.
        decode_length = (
            partial_targets_length + features.get("decode_length", decode_length))
        batch_size = partial_targets_shape[0]
    if hparams.pos == "timing":
        positional_encoding = common_attention.get_timing_signal_1d(
            decode_length + 1, hparams.hidden_size)
    elif hparams.pos == "emb":
        positional_encoding = common_attention.add_positional_embedding(
            tf.zeros([1, decode_length, hparams.hidden_size]),
            hparams.max_length, "body/targets_positional_embedding", None)
    else:
        positional_encoding = None

    def preprocess_targets(targets, i):
        """Performs preprocessing steps on the targets to prepare for the decoder.

        This includes:
          - Embedding the ids.
          - Flattening to 3D tensor.
          - Optionally adding timing signals.

        Args:
          targets: inputs ids to the decoder. [batch_size, 1]
          i: scalar, Step number of the decoding loop.

        Returns:
          Processed targets [batch_size, 1, hidden_dim]
        """
        # _shard_features called to ensure that the variable names match
        targets = self._shard_features({"targets": targets})["targets"]
        modality_name = hparams.name.get(
            "targets", modalities.get_name(target_modality))(hparams,
                                                             target_vocab_size)
        with tf.variable_scope(modality_name):
            bottom = hparams.bottom.get(
                "targets", modalities.get_targets_bottom(target_modality))
            targets = dp(bottom, targets, hparams, target_vocab_size)[0]
        targets = common_layers.flatten4d3d(targets)
        # GO embeddings are all zero, this is because transformer_prepare_decoder
        # Shifts the targets along by one for the input which pads with zeros.
        # If the modality already maps GO to the zero embeddings this is not
        # needed.
        targets = tf.cond(
            tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)
        if positional_encoding is not None:
            targets += positional_encoding[:, i:i + 1]
        return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            decode_length)
    # Create tensors for encoder-decoder attention history
    att_cache = {"attention_history": {}}
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    # NOTE(review): encoder_output is None on the no-input branch above, so
    # this shape_list call is reached with None there - confirm whether this
    # method is ever used for problems without inputs.
    att_batch_size, enc_seq_length = common_layers.shape_list(encoder_output)[0:2]
    # Histories start with size 0 on axis 2 and grow by concatenation.
    for layer in range(num_layers):
        att_cache["attention_history"]["layer_%d" % layer] = tf.zeros(
            [att_batch_size, hparams.num_heads, 0, enc_seq_length])
    att_cache["body_outputs"] = tf.zeros([att_batch_size, 1, 0, hparams.hidden_size])

    def update_decoder_attention_history(cache):
        """Append the current matching attention weights to the cache."""
        # Select weight names containing "decoder" but neither "self" nor
        # "logits"; the cache key is the "layer_<n>" part of the name.
        for k in filter(lambda x: "decoder" in x and not "self" in x and not "logits" in x, self.attention_weights.keys()):
            m = re.search(r"(layer_\d+)", k)
            if m is None:
                continue
            cache["attention_history"][m[0]] = tf.concat(
                [cache["attention_history"][m[0]], self.attention_weights[k]], axis=2)

    def symbols_to_logits_fn(ids, i, cache):
        """Go from ids to logits for next symbol."""
        # Only the most recent id is fed to the decoder.
        ids = ids[:, -1:]
        targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
        targets = preprocess_targets(targets, i)
        # Row i of the self-attention bias, limited to positions 0..i.
        bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
        with tf.variable_scope("body"):
            body_outputs = dp(
                self.decode, targets, cache.get("encoder_output"),
                cache.get("encoder_decoder_attention_bias"), bias, hparams,
                cache, nonpadding=features_to_nonpadding(features, "targets"))
        # Record this step's attention weights and decoder output.
        update_decoder_attention_history(cache)
        cache["body_outputs"] = tf.concat([cache["body_outputs"], body_outputs[0]], axis=2)
        modality_name = hparams.name.get(
            "targets", modalities.get_name(target_modality))(hparams,
                                                             target_vocab_size)
        with tf.variable_scope(modality_name):
            top = hparams.top.get("targets",
                                  modalities.get_top(target_modality))
            logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]
        ret = tf.squeeze(logits, axis=[1, 2, 3])
        if partial_targets is not None:
            # If the position is within the given partial targets, we alter the
            # logits to always return those values.
            # A faster approach would be to process the partial targets in one
            # iteration in order to fill the corresponding parts of the cache.
            # This would require broader changes, though.
            vocab_size = tf.shape(ret)[1]

            def forced_logits():
                # 0.0 at the forced id and -1e9 everywhere else.
                return tf.one_hot(
                    tf.tile(partial_targets[:, i], [beam_size]), vocab_size,
                    0.0, -1e9)

            ret = tf.cond(
                tf.less(i, partial_targets_length), forced_logits, lambda: ret)
        return ret, cache

    ret = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_vocab_size,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length,
        cache=att_cache)
    if partial_targets is not None:
        # Drop the forced prefix from the outputs; the beam result carries
        # one extra leading axis.
        if beam_size <= 1 or top_beams <= 1:
            ret["outputs"] = ret["outputs"][:, partial_targets_length:]
        else:
            ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
    return ret
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    This variant bypasses the modality bottom/top layers: target ids are
    embedded by gathering from `self._get_targets_emb_var`, and the first
    element of the decoder output is used directly as the logits.

    Args:
      features: a map of string to model features.
      decode_length: an integer. How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha,
        stronger the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards, or if
        `features` contains "targets_segmentation" (packed dataset).
    """
    if self._num_datashards != 1:
        raise NotImplementedError(
            "Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.target_modality
    if "targets_segmentation" in features:
        raise NotImplementedError(
            "Decoding not supported on packed datasets "
            " If you want to decode from a dataset, use the non-packed version"
            " of the dataset when decoding.")
    if self.has_input:
        inputs = features["inputs"]
        if target_modality.is_class_modality:
            decode_length = 1
        else:
            decode_length = (common_layers.shape_list(inputs)[1] +
                             features.get("decode_length", decode_length))
        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = common_layers.shape_list(inputs)
        batch_size = s[0]
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # Remove axes 2 and 3 so the encoder receives the ids directly.
        inputs = tf.squeeze(inputs, (2, 3))
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        # Disabled: the input modality bottom is not applied in this variant.
        # input_modality = self._problem_hparams.input_modality["inputs"]
        # with tf.variable_scope(input_modality.name):
        #   inputs = input_modality.bottom_sharded(inputs, dp)
        encoder_output, encoder_decoder_attention_bias = dp(
            self.encode, inputs, hparams, features=features)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
        partial_targets = None
    else:
        # The problem has no inputs.
        encoder_output = None
        encoder_decoder_attention_bias = None
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs")
        if partial_targets is None:
            partial_targets = features["targets"]
        assert partial_targets is not None
        partial_targets = common_layers.expand_squeeze_to_nd(
            partial_targets, 2)
        partial_targets = tf.to_int64(partial_targets)
        partial_targets_shape = common_layers.shape_list(partial_targets)
        partial_targets_length = partial_targets_shape[1]
        decode_length = (partial_targets_length +
                         features.get("decode_length", decode_length))
        batch_size = partial_targets_shape[0]
    # The positional embedding is applied to zeros so that the result holds
    # only the positional term; AUTO_REUSE lets an existing
    # "positional_embedding" variable in the current scope be reused.
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        positional_encoding = common_attention.add_positional_embedding(
            tf.zeros([1, decode_length, hparams.d_model]),
            hparams.max_length, "positional_embedding", None)

    def preprocess_targets(targets, i):
        """Performs preprocessing steps on the targets to prepare for the decoder.

        This includes:
          - Embedding the ids.
          - Flattening to 3D tensor.
          - Optionally adding timing signals.

        Args:
          targets: inputs ids to the decoder. [batch_size, 1]
          i: scalar, Step number of the decoding loop.

        Returns:
          Processed targets [batch_size, 1, hidden_dim]
        """
        # NOTE(review): accessed without a call - presumably a property on
        # the class; confirm.
        targets_emb_var = self._get_targets_emb_var
        targets = tf.gather(targets_emb_var, targets)
        tf.logging.info("targets = %s" % targets)
        targets = tf.squeeze(targets, (2, 3))
        if positional_encoding is not None:
            targets += positional_encoding[:, i:i + 1]
        return targets

    def symbols_to_logits_fn(ids, i, cache):
        """Go from ids to logits for next symbol."""
        # Only the most recent id is fed to the decoder.
        ids = ids[:, -1:]
        targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
        targets = preprocess_targets(targets, i)
        # No self-attention bias is passed in this variant.
        bias = None  # decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
        body_outputs = dp(self.decode, targets, cache.get("encoder_output"),
                          cache.get("encoder_decoder_attention_bias"), bias,
                          hparams, cache)
        # The decoder output is used as the logits directly.
        logits = body_outputs[0]
        # Disabled: the target modality top is not applied in this variant.
        # with tf.variable_scope(target_modality.name):
        #   logits = target_modality.top_sharded(body_outputs, None, dp)[0]
        ret = tf.squeeze(logits, axis=[1, 2, 3])
        if partial_targets is not None:
            # If the position is within the given partial targets, we alter the
            # logits to always return those values.
            # A faster approach would be to process the partial targets in one
            # iteration in order to fill the corresponding parts of the cache.
            # This would require broader changes, though.
            vocab_size = tf.shape(ret)[1]

            def forced_logits():
                # 0.0 at the forced id and -1e9 everywhere else.
                return tf.one_hot(
                    tf.tile(partial_targets[:, i], [beam_size]), vocab_size,
                    0.0, -1e9)

            ret = tf.cond(tf.less(i, partial_targets_length), forced_logits,
                          lambda: ret)
        return ret, cache

    ret = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_modality.top_dimensionality,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length)
    if partial_targets is not None:
        # Drop the forced prefix from the outputs; the beam result carries
        # one extra leading axis.
        if beam_size <= 1 or top_beams <= 1:
            ret["outputs"] = ret["outputs"][:, partial_targets_length:]
        else:
            ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
    return ret
def _fast_decode(
    self,
    features,
    decode_length,
    beam_size=1,
    top_beams=1,
    alpha=1.0,
    preprocess_targets_method=None,
):
    """Fast autoregressive decoding that emits a token and an edit tag per step.

    Args:
        features: a map of string to model features.
        decode_length: an integer, number of additional timesteps to decode.
            `features['decode_length']` takes precedence when present.
        beam_size: number of beams, forwarded to `fast_decode`.
        top_beams: how many beams to return, forwarded to `fast_decode`.
        alpha: length-penalty strength, forwarded to `fast_decode`.
        preprocess_targets_method: optional callable `(targets, i)` used to
            embed the previously decoded token ids. Defaults to the local
            `preprocess_targets`.

    Returns:
        The dict produced by `fast_decode`. When partial targets were
        supplied, the forced prefix is sliced off `ret['outputs']`.

    Raises:
        NotImplementedError: if there is more than one data shard, or the
            features come from a packed dataset.
    """
    if self._num_datashards != 1:
        raise NotImplementedError(
            'Fast decoding only supports a single shard.')
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.modality['targets']
    target_vocab_size = self._problem_hparams.vocab_size['targets']
    if target_vocab_size is not None and hasattr(hparams, 'vocab_divisor'):
        # Round the vocabulary size up to a multiple of vocab_divisor.
        target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor

    target_tag_modality = self._problem_hparams.modality[
        'targets_error_tag']
    target_tag_vocab_size = self._problem_hparams.vocab_size[
        'targets_error_tag']
    if target_tag_vocab_size is not None and hasattr(
            hparams, 'vocab_divisor'):
        target_tag_vocab_size += (
            -target_tag_vocab_size) % hparams.vocab_divisor

    if 'targets_segmentation' in features:
        raise NotImplementedError(
            'Decoding not supported on packed datasets '
            ' If you want to decode from a dataset, use the non-packed version'
            ' of the dataset when decoding.')

    if self.has_input:
        inputs_shape = common_layers.shape_list(features['inputs'])
        if (target_modality == modalities.ModalityType.CLASS_LABEL or
                self._problem_hparams.get('regression_targets')):
            decode_length = 1
        else:
            decode_length = inputs_shape[1] + features.get(
                'decode_length', decode_length)
        batch_size = inputs_shape[0]
        inputs = self._prepare_inputs_for_decode(features)
        with tf.variable_scope('body'):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode,
                inputs,
                features['target_space_id'],
                hparams,
                features=features,
            )
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
        partial_targets = features.get('partial_targets')
    else:
        encoder_output = None
        encoder_decoder_attention_bias = None
        # Without inputs, the outputs are forced to start with the sequence
        # found in features['inputs'] or, failing that, features['targets'].
        partial_targets = features.get('inputs')
        if partial_targets is None:
            partial_targets = features['targets']
        assert partial_targets is not None

    if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(
            partial_targets, 2)
        partial_targets = tf.to_int64(partial_targets)
        partial_targets_shape = common_layers.shape_list(partial_targets)
        partial_targets_length = partial_targets_shape[1]
        decode_length = partial_targets_length + features.get(
            'decode_length', decode_length)
        batch_size = partial_targets_shape[0]

    if hparams.pos == 'timing':
        positional_encoding = common_attention.get_timing_signal_1d(
            decode_length + 1, hparams.hidden_size)
    elif hparams.pos == 'timing_from_features':
        positional_encoding = common_attention.add_timing_signals_from_features(
            tf.zeros([1, decode_length, hparams.hidden_size]),
            features,
            hparams.position_features,
        )
    elif hparams.pos == 'emb':
        positional_encoding = common_attention.add_positional_embedding(
            tf.zeros([1, decode_length, hparams.hidden_size]),
            hparams.max_length,
            'body/targets_positional_embedding',
            None,
        )
    else:
        positional_encoding = None

    def preprocess_targets(targets, i):
        """Embed the previous token ids and add the step-i position signal."""
        targets = self._shard_features({'targets': targets})['targets']
        modality_name = hparams.name.get(
            'targets',
            modalities.get_name(target_modality))(hparams, target_vocab_size)
        with tf.variable_scope(modality_name + '/targets'):
            bottom = hparams.bottom.get(
                'targets', modalities.get_targets_bottom(target_modality))
            targets = dp(bottom, targets, hparams, target_vocab_size)[0]
        targets = common_layers.flatten4d3d(targets)

        # With no explicit start id, step 0 is fed an all-zero embedding.
        if not self.get_decode_start_id():
            targets = tf.cond(
                tf.equal(i, 0),
                lambda: tf.zeros_like(targets),
                lambda: targets,
            )

        if positional_encoding is not None:
            targets += positional_encoding[:, i:i + 1]
        return targets

    def preprocess_targets_tag_method(targets, i):
        """Embed the previous tag ids and add the step-i position signal."""
        targets = self._shard_features(
            {'targets_error_tag': targets})['targets_error_tag']
        modality_name = hparams.name.get(
            'targets_error_tag',
            modalities.get_name(target_tag_modality))(
                hparams, target_tag_vocab_size)
        with tf.variable_scope(modality_name + '/targets_error_tag'):
            bottom = hparams.bottom.get(
                'targets_error_tag',
                modalities.get_targets_bottom(target_tag_modality),
            )
            targets = dp(bottom, targets, hparams,
                         target_tag_vocab_size)[0]
        targets = common_layers.flatten4d3d(targets)

        if not self.get_decode_start_id():
            targets = tf.cond(
                tf.equal(i, 0),
                lambda: tf.zeros_like(targets),
                lambda: targets,
            )

        if positional_encoding is not None:
            targets += positional_encoding[:, i:i + 1]
        return targets

    decoder_self_attention_bias = common_attention.attention_bias_lower_triangle(
        decode_length)
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            decode_length)

    # One empty history tensor per decoder layer; each decoding step
    # concatenates its attention weights along axis 2.
    att_cache = {'attention_history': {}}
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    if encoder_output is not None:
        att_batch_size, enc_seq_length = common_layers.shape_list(
            encoder_output)[0:2]
        for layer in range(num_layers):
            att_cache['attention_history']['layer_%d' % layer] = tf.zeros(
                [att_batch_size, hparams.num_heads, 0, enc_seq_length])

    def update_decoder_attention_history(cache):
        """Append this step's non-self decoder attention weights to the cache."""
        for k in self.attention_weights:
            if 'decoder' not in k or 'self' in k or 'logits' in k:
                continue
            idx = k.find('layer_')
            if idx < 0:
                continue
            # Read the complete run of digits after 'layer_'. The scan must
            # be allowed to reach the end of the string so that names which
            # END with the layer number are parsed in full.
            suffix = k[idx + 6:]
            end = 0
            while end < len(suffix) and suffix[end].isdigit():
                end += 1
            if end == 0:
                # 'layer_' is not followed by a number; nothing to record.
                continue
            layer_nbr = 'layer_%d' % int(suffix[:end])
            if layer_nbr in cache['attention_history']:
                cache['attention_history'][layer_nbr] = tf.concat(
                    [
                        cache['attention_history'][layer_nbr],
                        self.attention_weights[k],
                    ],
                    axis=2,
                )

    if not preprocess_targets_method:
        preprocess_targets_method = preprocess_targets

    def symbols_to_logits_fn(ids, ids_tag, i, cache):
        """Go from ids to logits for next symbol."""
        ids = ids[:, -1:]
        targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
        targets = preprocess_targets_method(targets, i)

        ids_tag = ids_tag[:, -1:]
        targets_tag = tf.expand_dims(
            tf.expand_dims(ids_tag, axis=2), axis=3)
        targets_tag = preprocess_targets_tag_method(targets_tag, i)

        bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

        # NOTE(review): the variable-scope nesting below was reconstructed
        # from whitespace-collapsed source -- confirm against the variable
        # names in the training checkpoint.
        with tf.variable_scope('body'):
            with tf.variable_scope('edit_ops_layer'):
                with tf.variable_scope('ffn'):
                    x = targets

                    def preproc(z):
                        return common_layers.layer_preprocess(
                            z, hparams, layer_collection=None)

                    layer_inputs = [
                        tf.concat(preproc(x), axis=0),
                        tf.concat(preproc(targets_tag), axis=0),
                    ]
                    y = transformer_layers.transformer_ffn_layer(
                        tf.concat(layer_inputs, axis=2),
                        hparams,
                        conv_padding='LEFT',
                        nonpadding_mask=features_to_nonpadding(
                            features, 'targets'),
                        losses=None,
                        cache=cache,
                        decode_loop_step=None,
                        layer_collection=None,
                    )
                    targets = common_layers.layer_postprocess(
                        x, y, hparams)

            if hparams.middle_prediction:
                # NOTE(review): this rewrites hparams.num_decoder_layers in
                # place, so every additional call of this function divides
                # it again -- confirm the decode loop traces it only once.
                num_decoder_layers = (hparams.num_decoder_layers or
                                      hparams.num_hidden_layers)
                hparams.num_decoder_layers = int(
                    num_decoder_layers /
                    hparams.middle_prediction_layer_factor)

            body_outputs = dp(
                self.decode,
                targets,
                cache.get('encoder_output'),
                cache.get('encoder_decoder_attention_bias'),
                bias,
                hparams,
                cache,
                nonpadding=features_to_nonpadding(features, 'targets'),
            )[0]

            body_outputs, logits_tag = dp(
                self._prediction_cascade_predict,
                hparams,
                features_to_nonpadding(features, 'targets'),
                cache.get('encoder_decoder_attention_bias'),
                cache.get('encoder_output'),
                body_outputs,
            )
            logits_tag = logits_tag[0]['targets_error_tag']
            if hparams.middle_prediction:
                with tf.variable_scope('after_prediction'):
                    body_outputs = dp(
                        self.decode,
                        targets + body_outputs[0],
                        cache.get('encoder_output'),
                        cache.get('encoder_decoder_attention_bias'),
                        bias,
                        hparams,
                        cache,
                        nonpadding=features_to_nonpadding(
                            features, 'targets'),
                    )

        update_decoder_attention_history(cache)

        modality_name = hparams.name.get(
            'targets',
            modalities.get_name(target_modality))(hparams, target_vocab_size)
        with tf.variable_scope('targets/' + modality_name):
            top = hparams.top.get('targets',
                                  modalities.get_top(target_modality))
            logits = dp(top, body_outputs, None, hparams,
                        target_vocab_size)[0]

        ret = tf.squeeze(logits, axis=[1, 2])
        if partial_targets is not None:
            # While still inside the given prefix, replace the logits with a
            # one-hot distribution (0.0 on the forced id, -1e9 elsewhere).
            vocab_size = tf.shape(ret)[1]

            def forced_logits():
                return tf.one_hot(
                    tf.tile(partial_targets[:, i], [beam_size]),
                    vocab_size,
                    0.0,
                    -1e9,
                )

            ret = tf.cond(
                tf.less(i, partial_targets_length),
                forced_logits,
                lambda: ret,
            )
        logits_tag = tf.squeeze(logits_tag, axis=[1])
        return ret, logits_tag, cache

    sos_id = self.get_decode_start_id() or 0
    eos_id = self.get_decode_end_id() or beam_search.EOS_ID
    temperature = features.get('sampling_temp',
                               getattr(hparams, 'sampling_temp', 0.0))
    top_k = features.get('sampling_keep_top_k',
                         getattr(hparams, 'sampling_keep_top_k', -1))

    ret = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_vocab_size,
        init_cache_fn=_init_transformer_cache,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length,
        sos_id=sos_id,
        eos_id=eos_id,
        sampling_temperature=temperature,
        top_k=top_k,
        cache=att_cache,
    )
    if partial_targets is not None:
        # Drop the forced prefix; the beam axis is present only when more
        # than one beam is returned.
        if beam_size <= 1 or top_beams <= 1:
            ret['outputs'] = ret['outputs'][:, partial_targets_length:]
        else:
            ret['outputs'] = ret['outputs'][:, :, partial_targets_length:]
    return ret
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
    """Run fast autoregressive decoding (greedy or beam search).

    Overrides tensor2tensor.models.transformer.Transformer._fast_decode so
    that symbols_to_logits_fn may return multiple things. Beam search is
    used iff beam_size > 1; otherwise the beam-related arguments are ignored.

    Args:
        features: a map of string to model features.
        decode_length: an integer. How many additional timesteps to decode.
        beam_size: number of beams.
        top_beams: an integer. How many of the beams to return.
        alpha: float controlling the length penalty; larger values prefer
            longer outputs.

    Returns:
        A dict holding everything `fast_decode` returns (including
        "outputs" and "scores"), the entries of its "cache" item merged in
        at top level when present, and "encoded_inputs" (the padding-masked
        sum of the encoder output over axis 1) when the problem has inputs.

    Raises:
        NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
        raise NotImplementedError(
            "Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams

    # Resolve which target feature/modality drives decoding.
    target_modality = self._problem_hparams.target_modality
    if not isinstance(target_modality, dict):
        primary_target_feature = "targets"
        primary_target_modality = target_modality
        bottom_variable_scope = target_modality.name
    else:
        primary_target_feature = self._problem_hparams.primary_target_modality
        primary_target_modality = target_modality[primary_target_feature]
        bottom_variable_scope = "%s/%s" % (primary_target_modality.name,
                                           primary_target_feature)

    if self.has_input:
        inputs = features["inputs"]
        if primary_target_modality.is_class_modality:
            decode_length = 1
        else:
            input_length = common_layers.shape_list(inputs)[1]
            decode_length = input_length + features.get(
                "decode_length", decode_length)

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        dims = common_layers.shape_list(inputs)
        batch_size = dims[0]
        inputs = tf.reshape(
            inputs, [dims[0] * dims[1], dims[2], dims[3], dims[4]])
        # Shard so the variable names line up with training.
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode,
                inputs,
                features["target_space_id"],
                hparams,
                features=features)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
        partial_targets = None
    else:
        # No inputs: there is nothing to encode.
        encoder_output = None
        encoder_decoder_attention_bias = None

        # The decoded sequence is forced to begin with the partial targets,
        # taken from features["inputs"] or else the primary target feature.
        partial_targets = features.get("inputs")
        if partial_targets is None:
            partial_targets = features[primary_target_feature]
        assert partial_targets is not None
        partial_targets = tf.to_int64(
            common_layers.expand_squeeze_to_nd(partial_targets, 2))
        prefix_shape = common_layers.shape_list(partial_targets)
        partial_targets_length = prefix_shape[1]
        decode_length = partial_targets_length + features.get(
            "decode_length", decode_length)
        batch_size = prefix_shape[0]

    positional_encoding = None
    if hparams.pos == "timing":
        positional_encoding = common_attention.get_timing_signal_1d(
            decode_length + 1, hparams.hidden_size)
    elif hparams.pos == "emb":
        positional_encoding = common_attention.add_positional_embedding(
            tf.zeros([1, decode_length + 1, hparams.hidden_size]),
            hparams.max_length, "targets_positional_embedding", None)

    def preprocess_targets(targets, i):
        """Turn the latest decoded ids into decoder inputs.

        Embeds the ids, flattens the result to 3D and optionally adds the
        positional signal for step i.

        Args:
            targets: inputs ids to the decoder. [batch_size, 1]
            i: scalar, Step number of the decoding loop.

        Returns:
            Processed targets [batch_size, 1, hidden_dim]
        """
        # Shard so the variable names line up with training.
        sharded = self._shard_features({primary_target_feature: targets})
        targets = sharded[primary_target_feature]
        with tf.variable_scope(bottom_variable_scope):
            targets = primary_target_modality.targets_bottom_sharded(
                targets, dp)[0]
        targets = common_layers.flatten4d3d(targets)

        # Step 0 stands for the start symbol, so it gets an all-zero
        # embedding, mirroring what transformer_prepare_decoder does to the
        # target outputs during training.
        targets = tf.cond(
            tf.equal(i, 0),
            lambda: tf.zeros_like(targets),
            lambda: targets)

        if positional_encoding is not None:
            targets += positional_encoding[:, i:i + 1]
        return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            decode_length)

    def symbols_to_logits_fn(ids, i, cache):
        """Go from ids to logits for next symbol."""
        last_ids = ids[:, -1:]
        step_targets = tf.expand_dims(
            tf.expand_dims(last_ids, axis=2), axis=3)
        step_targets = preprocess_targets(step_targets, i)

        step_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

        model_logits = self._symbols_to_logits_fn(
            step_targets, features, step_bias, cache)
        model_logits = tf.squeeze(model_logits, axis=[1, 2, 3])
        if partial_targets is None:
            return model_logits, cache

        # Inside the given prefix the logits are overridden so that the
        # prefix ids are always produced. Filling the cache for the whole
        # prefix in one iteration would be faster but needs broader changes.
        vocab_size = tf.shape(model_logits)[1]

        def forced_logits():
            return tf.one_hot(
                tf.tile(partial_targets[:, i], [beam_size]), vocab_size,
                0.0, -1e9)

        logits = tf.cond(
            tf.less(i, partial_targets_length), forced_logits,
            lambda: model_logits)
        return logits, cache

    cache = {}
    infer_out = {}
    if encoder_output is not None:
        keep_mask = 1. - common_attention.attention_bias_to_padding(
            encoder_decoder_attention_bias)
        masked_encoding = encoder_output * tf.expand_dims(keep_mask, axis=2)
        infer_out["encoded_inputs"] = tf.reduce_sum(masked_encoding, axis=1)

    # NOTE(review): placed outside the `if` above when re-indenting the
    # whitespace-collapsed source -- confirm against the original layout.
    self._prepare_decoder_cache(batch_size, features, cache)

    decoded = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=primary_target_modality.top_dimensionality,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length,
        cache=cache)
    infer_out.update(decoded)
    if "cache" in decoded:
        infer_out.update(decoded["cache"])

    if partial_targets is not None:
        # Strip the forced prefix from the decoded ids.
        outputs = infer_out["outputs"]
        if beam_size <= 1 or top_beams <= 1:
            infer_out["outputs"] = outputs[:, partial_targets_length:]
        else:
            infer_out["outputs"] = outputs[:, :, partial_targets_length:]
    return infer_out
def _sample(self, features, mesh):
    """Build the Mesh-TensorFlow graph that decodes output ids.

    For `hparams.transformer_type == "encdec"` the inputs are encoded once
    and the encoder-decoder attention keys/values are precomputed. For
    "decoder" the ids found in features["inputs"] (or features["targets"])
    are passed to greedy decoding as forced ids.

    Args:
        features: a features dictionary of tf.Tensors.
        mesh: the Mesh on which variables and tensors are created.

    Returns:
        The result of `mtf.beam_search.greedy_decode` when
        `hparams.beam_size == 1`; otherwise beam 0 gathered from the beams
        returned by `mtf.beam_search.beam_search`.

    Raises:
        ValueError: if `hparams.transformer_type` is neither "encdec" nor
            "decoder".
    """
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "encdec":
        inputs = features["inputs"]
        # Squeeze away extra axes until the inputs are rank 2.
        while len(inputs.shape.as_list()) > 2:
            inputs = tf.squeeze(inputs, axis=2)
        actual_batch_size = tf.shape(inputs)[0]
        actual_length = tf.shape(inputs)[1]
        # Pad up to the static [batch_size, max_length] shape.
        inputs = tf.pad(
            inputs, [[0, hparams.batch_size - actual_batch_size],
                     [0, hparams.max_length - actual_length]])
        inputs = self._import_to_batch_by_length(
            inputs, "inputs", mesh, hparams)
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.reshape(positional_embedding_var,
                         mtf.Shape([self.length_dim, self.model_dim])))
        encoder_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
        with tf.variable_scope("encoder"):
            x = self._layer_stack(x,
                                  hparams.encoder_layers,
                                  self_attention_mask=encoder_attention_mask)
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)
        # One entry per decoder layer: (q_var, o_var, k, v) for "enc_att"
        # layers, None for every other layer type.
        encdec_tensors = []
        for layer_num, layer_type in enumerate(hparams.decoder_layers):
            if layer_type == "enc_att":
                with tf.variable_scope(
                        "decoder/enc_att_%d/enc_att" % layer_num):
                    q_var, k_var, v_var, o_var = (
                        mtf.layers.multihead_attention_vars(
                            mesh, self.heads_dim, self.model_dim,
                            self.kv_dim, self.master_dtype, self.slice_dtype,
                            self.activation_dtype))
                    k = mtf.einsum(
                        [encoder_output, k_var],
                        mtf.Shape(
                            self.batch_dims + [self.heads_dim,
                                               self.memory_length_dim,
                                               self.kv_dim]))
                    v = mtf.einsum(
                        [encoder_output, v_var],
                        mtf.Shape(
                            self.batch_dims + [self.heads_dim,
                                               self.memory_length_dim,
                                               self.kv_dim]))
                encdec_tensors.append((q_var, o_var, k, v))
            else:
                encdec_tensors.append(None)
        partial_targets = None
    elif hparams.transformer_type == "decoder":
        encdec_tensors = None
        encoder_output = None
        encoder_attention_mask = None
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs", None)
        if partial_targets is None:
            partial_targets = features.get("targets", None)
        if partial_targets is not None:
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int32(partial_targets)
            partial_targets_batch = tf.shape(partial_targets)[0]
            partial_targets_length = tf.shape(partial_targets)[1]
            partial_targets = tf.pad(
                partial_targets,
                [[0, hparams.batch_size - partial_targets_batch],
                 [0, hparams.max_length - partial_targets_length]])
            partial_targets = self._import_to_batch_by_length(
                partial_targets, "partial_targets", mesh, hparams)
    else:
        # The message names the hyperparameter that is actually checked.
        raise ValueError(
            "hparams.transformer_type = %s not yet supported"
            % hparams.transformer_type)

    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    if hparams.beam_size == 1:
        ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
        kv_shape = mtf.Shape(self.batch_dims +
                             [self.heads_dim,
                              self.memory_length_dim, self.kv_dim])
        local_kv_shape = mtf.Shape(self.batch_dims +
                                   [self.heads_dim,
                                    local_attention_window, self.kv_dim])
    else:
        # Beam search carries an extra beam dimension after the batch dims.
        beam_dim = mtf.Dimension("beam", hparams.beam_size)
        ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
        kv_shape = mtf.Shape(self.batch_dims +
                             [beam_dim, self.heads_dim,
                              self.memory_length_dim, self.kv_dim])
        local_kv_shape = mtf.Shape(self.batch_dims +
                                   [beam_dim, self.heads_dim,
                                    local_attention_window, self.kv_dim])
    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    # Two zero-initialised state tensors per "att" / "local_att" layer.
    initial_states = []
    for layer in hparams.decoder_layers:
        if layer == "att":
            initial_states.extend(
                [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
        elif layer == "local_att":
            initial_states.extend(
                [mtf.zeros(
                    mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

    def logits_fn(step_num, ids, states):
        """Produce logits for this step, and new states."""
        ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
        x = (mtf.gather(targets_embedding_var, ids_this_step,
                        self.targets_vocab_dim) +
             mtf.gather(positional_embedding_var, step_num,
                        self.max_length_dim))
        with tf.variable_scope("decoder"):
            x, new_states = self._layer_stack(
                x,
                hparams.decoder_layers,
                encdec_attention_mask=encoder_attention_mask,
                step_num=step_num,
                encdec_tensors=encdec_tensors,
                states=states)
        logits = mtf.matmul(x, softmax_var)
        return logits, new_states

    if hparams.beam_size == 1:
        temperature = (0.0 if hparams.sampling_method == "argmax"
                       else hparams.sampling_temp)
        return mtf.beam_search.greedy_decode(
            logits_fn,
            initial_ids,
            temperature=temperature,
            initial_states=initial_states,
            forced_ids=partial_targets,
            use_tpu=hparams.use_tpu)
    else:
        if hparams.transformer_type == "encdec":
            # Derive the decode length from the longest input, counting
            # non-zero ids.
            input_length = mtf.reduce_sum(
                mtf.to_float(mtf.cast(inputs, tf.bool)),
                reduced_dim=self.length_dim)
            max_input_length = mtf.reduce_max(input_length)
            decode_length = mtf.cast(
                max_input_length * hparams.decode_length_multiplier
                + hparams.decode_length_constant, tf.int32)
        else:
            decode_length = None
        beams, unused_scores = mtf.beam_search.beam_search(
            logits_fn,
            initial_ids,
            hparams.alpha,
            states=initial_states,
            decode_length=decode_length,
            use_tpu=hparams.use_tpu,
            dtype=self.activation_dtype)
        return mtf.gather(
            beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def _fast_decode_tpu(self,
                     features,
                     decode_length,
                     beam_size=1):
    """Fast decoding.

    Implements only greedy decoding on TPU. Tensors are indexed with
    `tf.slice` and statically known sizes instead of Python slicing on the
    loop counter.

    Args:
        features: A map of string to model features.
        decode_length: An integer, how many additional timesteps to decode.
        beam_size: An integer, number of beams.

    Returns:
        A dict of decoding results {
            "outputs": integer `Tensor` of decoded ids of shape
                [batch_size, <= decode_length]
            "scores": decoding log probs from the beam search,
                None if using greedy decoding (beam_size=1)
        }.

    Raises:
        NotImplementedError: If there are multiple data shards or
            beam_size > 1.
    """
    if self._num_datashards != 1:
        raise NotImplementedError(
            "Fast decoding only supports a single shard.")
    if "targets_segmentation" in features:
        raise NotImplementedError(
            "Decoding not supported on packed datasets "
            " If you want to decode from a dataset, use the non-packed version"
            " of the dataset when decoding.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.target_modality

    if self.has_input:
        inputs = features["inputs"]
        if target_modality.is_class_modality:
            decode_length = 1
        else:
            decode_length = (
                common_layers.shape_list(inputs)[1] +
                features.get("decode_length", decode_length))

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = common_layers.shape_list(inputs)
        batch_size = s[0]
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode,
                inputs,
                features["target_space_id"],
                hparams,
                features=features)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
        partial_targets = None
    else:
        # The problem has no inputs.
        encoder_output = None
        encoder_decoder_attention_bias = None

        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs")
        if partial_targets is None:
            partial_targets = features["targets"]
        assert partial_targets is not None
        partial_targets = common_layers.expand_squeeze_to_nd(
            partial_targets, 2)
        partial_targets = tf.to_int64(partial_targets)
        partial_targets_shape = common_layers.shape_list(partial_targets)
        partial_targets_length = partial_targets_shape[1]
        decode_length = (
            partial_targets_length +
            features.get("decode_length", decode_length))
        batch_size = partial_targets_shape[0]

    if hparams.pos == "timing":
        positional_encoding = common_attention.get_timing_signal_1d(
            decode_length + 1, hparams.hidden_size)
    elif hparams.pos == "emb":
        positional_encoding = common_attention.add_positional_embedding(
            tf.zeros([1, decode_length + 1, hparams.hidden_size]),
            hparams.max_length, "body/targets_positional_embedding", None)
    else:
        positional_encoding = None

    def preprocess_targets(targets, i):
        """Performs preprocessing steps on the targets to prepare for the decoder.

        This includes:
          - Embedding the ids.
          - Flattening to 3D tensor.
          - Optionally adding timing signals.

        Args:
            targets: A tensor, inputs ids to the decoder. [batch_size, 1].
            i: An integer, Step number of the decoding loop.

        Returns:
            A tensor, processed targets [batch_size, 1, hidden_dim].
        """
        # _shard_features called to ensure that the variable names match
        targets = self._shard_features({"targets": targets})["targets"]
        with tf.variable_scope(target_modality.name):
            targets = target_modality.targets_bottom_sharded(targets, dp)[0]
        targets = common_layers.flatten4d3d(targets)

        # TODO(llion): Explain! Is this even needed?
        targets = tf.cond(
            tf.equal(i, 0),
            lambda: tf.zeros_like(targets),
            lambda: targets)

        if positional_encoding is not None:
            # Take the single position i with tf.slice and static sizes.
            positional_encoding_shape = positional_encoding.shape.as_list()
            targets += tf.slice(
                positional_encoding, [0, i, 0],
                [positional_encoding_shape[0], 1,
                 positional_encoding_shape[2]])
        return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            decode_length)

    def symbols_to_logits_tpu_fn(ids, i, cache):
        """Go from ids to logits for next symbol on TPU.

        Args:
            ids: A tensor, symbol IDs.
            i: An integer, step number of the decoding loop. Only used for
                inference on TPU.
            cache: A dict, containing tensors which are the results of
                previous attentions, used for fast decoding.

        Returns:
            ret: A tensor, computed logits.
            cache: A dict, containing tensors which are the results of
                previous attentions, used for fast decoding.
        """
        ids = ids[:, -1:]
        targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
        targets = preprocess_targets(targets, i)

        # Row i of the bias across its full last axis.
        bias_shape = decoder_self_attention_bias.shape.as_list()
        bias = tf.slice(decoder_self_attention_bias, [0, 0, i, 0],
                        [bias_shape[0], bias_shape[1], 1, bias_shape[3]])

        with tf.variable_scope("body"):
            body_outputs = dp(
                self.decode,
                targets,
                cache.get("encoder_output"),
                cache.get("encoder_decoder_attention_bias"),
                bias,
                hparams,
                cache,
                i,
                nonpadding=features_to_nonpadding(features, "targets"))

        with tf.variable_scope(target_modality.name):
            logits = target_modality.top_sharded(body_outputs, None, dp)[0]

        ret = tf.squeeze(logits, axis=[1, 2, 3])
        if partial_targets is not None:
            # If the position is within the given partial targets, we alter
            # the logits to always return those values.
            # A faster approach would be to process the partial targets in
            # one iteration in order to fill the corresponding parts of the
            # cache. This would require broader changes, though.
            vocab_size = tf.shape(ret)[1]

            def forced_logits():
                # The slice has size [batch, 1]. tf.tile needs one multiple
                # per input dimension, so drop the length-1 axis before
                # tiling with [beam_size]; one_hot then yields rank-2
                # logits like `ret`.
                step_ids = tf.squeeze(
                    tf.slice(partial_targets, [0, i],
                             [partial_targets.shape.as_list()[0], 1]),
                    axis=[1])
                return tf.one_hot(
                    tf.tile(step_ids, [beam_size]), vocab_size, 0.0, -1e9)

            ret = tf.cond(
                tf.less(i, partial_targets_length), forced_logits,
                lambda: ret)
        return ret, cache

    ret = fast_decode_tpu(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_tpu_fn,
        hparams=hparams,
        decode_length=decode_length,
        beam_size=beam_size,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length)
    if partial_targets is not None:
        # Strip the forced prefix from the decoded ids.
        ret["outputs"] = ret["outputs"][:, partial_targets_length:]
    return ret