def _prepare_inputs_for_decode(self, features):
  """Prepare inputs for decoding.

  Args:
    features: A map of string to model features.

  Returns:
    Inputs after fixing shape and applying modality.
  """
  dp = self._data_parallelism
  hparams = self._hparams
  inputs = features['inputs']
  inputs = tf.expand_dims(inputs, axis=1)
  if len(inputs.shape) < 5:
    inputs = tf.expand_dims(inputs, axis=4)
  s = common_layers.shape_list(inputs)
  inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  inputs = self._shard_features({'inputs': inputs})['inputs']
  input_modality = self._problem_hparams.modality['inputs']
  input_vocab_size = self._problem_hparams.vocab_size['inputs']
  if input_vocab_size is not None and hasattr(hparams, 'vocab_divisor'):
    input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
  modality_name = hparams.name.get(
      'inputs',
      modalities.get_name(input_modality))(hparams, input_vocab_size)
  with tf.variable_scope(modality_name):
    bottom = hparams.bottom.get('inputs',
                                modalities.get_bottom(input_modality))
    inputs = dp(bottom, inputs, hparams, input_vocab_size)
  return inputs
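# --- Illustrative sketch (not library code) ---
# The `(-input_vocab_size) % hparams.vocab_divisor` expression above rounds
# the vocabulary size up to the next multiple of vocab_divisor. A minimal,
# framework-free sketch of just that arithmetic; the helper name and the
# numbers below are made up for illustration.
def round_up_to_multiple(vocab_size, divisor):
  """Smallest multiple of `divisor` that is >= `vocab_size`."""
  return vocab_size + (-vocab_size) % divisor

assert round_up_to_multiple(32003, 8) == 32008
assert round_up_to_multiple(32000, 8) == 32000  # already a multiple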
def testGetForAllModalities(self):
  for modality in modalities.ModalityType.get_choices():
    bottom = modalities.get_bottom(modality)
    loss = modalities.get_loss(modality)
    name = modalities.get_name(modality)
    targets_bottom = modalities.get_targets_bottom(modality)
    top = modalities.get_top(modality)
    weights_fn = modalities.get_weights_fn(modality)
    self.assertIsNotNone(
        bottom, msg="{} has no default bottom".format(modality))
    self.assertIsNotNone(loss, msg="{} has no default loss".format(modality))
    self.assertIsNotNone(name, msg="{} has no default name".format(modality))
    self.assertIsNotNone(
        targets_bottom,
        msg="{} has no default targets_bottom".format(modality))
    self.assertIsNotNone(top, msg="{} has no default top".format(modality))
    self.assertIsNotNone(
        weights_fn, msg="{} has no default weights_fn".format(modality))
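# --- Illustrative sketch (not library code) ---
# The getters exercised by this test back the override-with-default lookups
# used by the decoding code in this section, e.g.
# hparams.bottom.get("inputs", modalities.get_bottom(input_modality)).
# A standalone sketch of that pattern with stand-in names; DEFAULT_BOTTOMS,
# get_bottom and the overrides dict below are hypothetical, not the T2T API.
DEFAULT_BOTTOMS = {"symbol": lambda x: x + 1}  # stand-in default registry

def get_bottom(modality_type):
  """Returns the registered default bottom for a modality type."""
  return DEFAULT_BOTTOMS[modality_type]

overrides = {}  # per-feature overrides, analogous to hparams.bottom
bottom = overrides.get("inputs", get_bottom("symbol"))
assert bottom(1) == 2  # falls back to the default when no override is set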
def preprocess_targets(targets, i):
  targets = self._shard_features({'targets': targets})['targets']
  modality_name = hparams.name.get(
      'targets',
      modalities.get_name(target_modality))(hparams, target_vocab_size)
  with tf.variable_scope(modality_name + '/targets'):
    bottom = hparams.bottom.get(
        'targets', modalities.get_targets_bottom(target_modality))
    targets = dp(bottom, targets, hparams, target_vocab_size)[0]
  targets = common_layers.flatten4d3d(targets)
  if not self.get_decode_start_id():
    targets = tf.cond(
        tf.equal(i, 0),
        lambda: tf.zeros_like(targets),
        lambda: targets,
    )
  if positional_encoding is not None:
    targets += positional_encoding[:, i:i + 1]
  return targets
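# --- Illustrative sketch (not library code) ---
# At decode step i, only one timestep of the precomputed positional encoding
# is added: positional_encoding[:, i:i + 1] has shape [1, 1, hidden_dim] and
# broadcasts over the batch. A hedged numpy sketch of that slicing; shapes
# and values are made up for illustration.
import numpy as np

decode_length, hidden_dim = 8, 4
positional_encoding = np.random.rand(1, decode_length + 1, hidden_dim)
targets = np.zeros((2, 1, hidden_dim))  # [batch_size, 1, hidden_dim]
i = 3
targets = targets + positional_encoding[:, i:i + 1]  # broadcasts over batch
assert targets.shape == (2, 1, hidden_dim)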
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding; beam search is used iff
  beam_size > 1, otherwise the beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.modality["targets"]
  target_vocab_size = self._problem_hparams.vocab_size["targets"]
  if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
    target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor
  if "targets_segmentation" in features:
    raise NotImplementedError(
        "Decoding not supported on packed datasets. If you want to decode "
        "from a dataset, use the non-packed version of the dataset when "
        "decoding.")
  if self.has_input:
    inputs = features["inputs"]
    if target_modality == modalities.ModalityType.CLASS_LABEL:
      decode_length = 1
    else:
      decode_length = (
          common_layers.shape_list(inputs)[1] + features.get(
              "decode_length", decode_length))

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    batch_size = s[0]
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.modality["inputs"]
    input_vocab_size = self._problem_hparams.vocab_size["inputs"]
    if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
      input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
    modality_name = hparams.name.get(
        "inputs",
        modalities.get_name(input_modality))(hparams, input_vocab_size)
    with tf.variable_scope(modality_name):
      bottom = hparams.bottom.get("inputs",
                                  modalities.get_bottom(input_modality))
      inputs = dp(bottom, inputs, hparams, input_vocab_size)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode,
          inputs,
          features["target_space_id"],
          hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    if "partial_targets" in features:
      partial_targets = features["partial_targets"]
    else:
      partial_targets = None
  else:
    # The problem has no inputs.
    encoder_output = None
    encoder_decoder_attention_bias = None

    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs")
    if partial_targets is None:
      partial_targets = features["targets"]
    assert partial_targets is not None

  if partial_targets is not None:
    partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
    partial_targets = tf.to_int64(partial_targets)
    partial_targets_shape = common_layers.shape_list(partial_targets)
    partial_targets_length = partial_targets_shape[1]
    decode_length = (
        partial_targets_length + features.get("decode_length", decode_length))
    batch_size = partial_targets_shape[0]

  if hparams.pos == "timing":
    positional_encoding = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)
  elif hparams.pos == "emb":
    positional_encoding = common_attention.add_positional_embedding(
        tf.zeros([1, decode_length, hparams.hidden_size]),
        hparams.max_length, "body/targets_positional_embedding", None)
  else:
    positional_encoding = None

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: input ids to the decoder. [batch_size, 1]
      i: scalar, step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    modality_name = hparams.name.get(
        "targets",
        modalities.get_name(target_modality))(hparams, target_vocab_size)
    with tf.variable_scope(modality_name):
      bottom = hparams.bottom.get(
          "targets", modalities.get_targets_bottom(target_modality))
      targets = dp(bottom, targets, hparams, target_vocab_size)[0]
    targets = common_layers.flatten4d3d(targets)

    # GO embeddings are all zero; this is because transformer_prepare_decoder
    # shifts the targets along by one for the input, which pads with zeros.
    # If the modality already maps GO to the zero embeddings this is not
    # needed.
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if positional_encoding is not None:
      targets += positional_encoding[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  # Create tensors for encoder-decoder attention history.
  att_cache = {"attention_history": {}}
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
  att_batch_size, enc_seq_length = common_layers.shape_list(
      encoder_output)[0:2]
  for layer in range(num_layers):
    att_cache["attention_history"]["layer_%d" % layer] = tf.zeros(
        [att_batch_size, hparams.num_heads, 0, enc_seq_length])
  att_cache["body_outputs"] = tf.zeros(
      [att_batch_size, 1, 0, hparams.hidden_size])

  def update_decoder_attention_history(cache):
    for k in filter(
        lambda x: "decoder" in x and "self" not in x and "logits" not in x,
        self.attention_weights.keys()):
      m = re.search(r"(layer_\d+)", k)
      if m is None:
        continue
      cache["attention_history"][m[0]] = tf.concat(
          [cache["attention_history"][m[0]], self.attention_weights[k]],
          axis=2)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode,
          targets,
          cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"),
          bias,
          hparams,
          cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    update_decoder_attention_history(cache)
    cache["body_outputs"] = tf.concat(
        [cache["body_outputs"], body_outputs[0]], axis=2)

    modality_name = hparams.name.get(
        "targets",
        modalities.get_name(target_modality))(hparams, target_vocab_size)
    with tf.variable_scope(modality_name):
      top = hparams.top.get("targets", modalities.get_top(target_modality))
      logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(ret)[1]

      def forced_logits():
        return tf.one_hot(
            tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
            -1e9)

      ret = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: ret)
    return ret, cache

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_vocab_size,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length,
      cache=att_cache)
  if partial_targets is not None:
    if beam_size <= 1 or top_beams <= 1:
      ret["outputs"] = ret["outputs"][:, partial_targets_length:]
    else:
      ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
  return ret
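# --- Illustrative sketch (not library code) ---
# The forced_logits trick above builds a one-hot row with on_value=0.0 and
# off_value=-1e9, so that after the softmax the forced token is the only one
# with non-negligible probability. A hedged numpy sketch of that effect;
# vocab_size and forced_id are made up for illustration.
import numpy as np

vocab_size, forced_id = 6, 2
forced = np.full(vocab_size, -1e9)
forced[forced_id] = 0.0
probs = np.exp(forced - forced.max())
probs /= probs.sum()
assert probs.argmax() == forced_id and probs[forced_id] > 0.999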
def symbols_to_logits_fn(ids, ids_tag, i, cache):
  """Go from ids to logits for next symbol."""
  ids = ids[:, -1:]
  targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
  targets = preprocess_targets_method(targets, i)

  ids_tag = ids_tag[:, -1:]
  targets_tag = tf.expand_dims(tf.expand_dims(ids_tag, axis=2), axis=3)
  targets_tag = preprocess_targets_tag_method(targets_tag, i)

  bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
  with tf.variable_scope('body'):
    with tf.variable_scope('edit_ops_layer'):
      with tf.variable_scope('ffn'):
        x = targets
        preproc = lambda z: common_layers.layer_preprocess(
            z, hparams, layer_collection=None)
        layer_inputs = [
            tf.concat(preproc(x), axis=0),
            tf.concat(preproc(targets_tag), axis=0),
        ]
        y = transformer_layers.transformer_ffn_layer(
            tf.concat(layer_inputs, axis=2),
            hparams,
            conv_padding='LEFT',
            nonpadding_mask=features_to_nonpadding(features, 'targets'),
            losses=None,
            cache=cache,
            decode_loop_step=None,
            layer_collection=None,
        )
        targets = common_layers.layer_postprocess(x, y, hparams)

    if hparams.middle_prediction:
      num_decoder_layers = (
          hparams.num_decoder_layers or hparams.num_hidden_layers)
      hparams.num_decoder_layers = int(
          num_decoder_layers / hparams.middle_prediction_layer_factor)

    body_outputs = dp(
        self.decode,
        targets,
        cache.get('encoder_output'),
        cache.get('encoder_decoder_attention_bias'),
        bias,
        hparams,
        cache,
        nonpadding=features_to_nonpadding(features, 'targets'),
    )[0]

    body_outputs, logits_tag = dp(
        self._prediction_cascade_predict,
        hparams,
        features_to_nonpadding(features, 'targets'),
        cache.get('encoder_decoder_attention_bias'),
        cache.get('encoder_output'),
        body_outputs,
    )
    logits_tag = logits_tag[0]['targets_error_tag']

    if hparams.middle_prediction:
      with tf.variable_scope('after_prediction'):
        body_outputs = dp(
            self.decode,
            targets + body_outputs[0],
            cache.get('encoder_output'),
            cache.get('encoder_decoder_attention_bias'),
            bias,
            hparams,
            cache,
            nonpadding=features_to_nonpadding(features, 'targets'),
        )

  update_decoder_attention_history(cache)

  modality_name = hparams.name.get(
      'targets',
      modalities.get_name(target_modality))(hparams, target_vocab_size)
  with tf.variable_scope('targets/' + modality_name):
    top = hparams.top.get('targets', modalities.get_top(target_modality))
    logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

  ret = tf.squeeze(logits, axis=[1, 2])
  if partial_targets is not None:
    vocab_size = tf.shape(ret)[1]

    def forced_logits():
      return tf.one_hot(
          tf.tile(partial_targets[:, i], [beam_size]),
          vocab_size,
          0.0,
          -1e9,
      )

    ret = tf.cond(
        tf.less(i, partial_targets_length),
        forced_logits,
        lambda: ret,
    )
  logits_tag = tf.squeeze(logits_tag, axis=[1])
  return ret, logits_tag, cache
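# --- Illustrative sketch (not library code) ---
# In edit_ops_layer above, the token embedding and the error-tag embedding
# for the current step are concatenated along the hidden dimension (axis=2)
# before the feed-forward layer. A hedged numpy sketch of that shape
# arithmetic; the sizes are made up for illustration.
import numpy as np

batch_size, hidden_dim = 2, 8
targets = np.random.rand(batch_size, 1, hidden_dim)
targets_tag = np.random.rand(batch_size, 1, hidden_dim)
ffn_input = np.concatenate([targets, targets_tag], axis=2)
assert ffn_input.shape == (batch_size, 1, 2 * hidden_dim)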