def _prepare_inputs_for_decode(self, features):
  """Prepare inputs for decoding.

  Args:
    features: A map of string to model features.

  Returns:
    Inputs after fixing shape and applying modality.
  """
  dp = self._data_parallelism
  hparams = self._hparams
  inputs = features["inputs"]
  inputs = tf.expand_dims(inputs, axis=1)
  if len(inputs.shape) < 5:
    inputs = tf.expand_dims(inputs, axis=4)
  s = common_layers.shape_list(inputs)
  inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  inputs = self._shard_features({"inputs": inputs})["inputs"]
  input_modality = self._problem_hparams.modality["inputs"]
  input_vocab_size = self._problem_hparams.vocab_size["inputs"]
  if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
    input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
  modality_name = hparams.name.get(
      "inputs",
      modalities.get_name(input_modality))(hparams, input_vocab_size)
  with tf.variable_scope(modality_name):
    bottom = hparams.bottom.get(
        "inputs", modalities.get_bottom(input_modality))
    inputs = dp(bottom, inputs, hparams, input_vocab_size)
  return inputs
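
# Illustrative sketch (not part of the model code): the expand/reshape dance
# in _prepare_inputs_for_decode normalizes decoder inputs to the 4-D
# [batch, length, 1, 1] layout that the modality bottoms expect.  The shapes
# and the helper name below are hypothetical.
def _example_decode_input_reshape():
  inputs = tf.zeros([8, 20, 1], dtype=tf.int32)    # e.g. raw [batch, length, 1] ids
  inputs = tf.expand_dims(inputs, axis=1)          # [8, 1, 20, 1]
  if len(inputs.shape) < 5:
    inputs = tf.expand_dims(inputs, axis=4)        # [8, 1, 20, 1, 1]
  s = inputs.shape.as_list()
  inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  assert inputs.shape.as_list() == [8, 20, 1, 1]   # back to [batch, length, 1, 1]
  return inputs
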
def testGetForAllModalities(self):
  for modality in modalities.ModalityType.get_choices():
    bottom = modalities.get_bottom(modality)
    loss = modalities.get_loss(modality)
    name = modalities.get_name(modality)
    targets_bottom = modalities.get_targets_bottom(modality)
    top = modalities.get_top(modality)
    weights_fn = modalities.get_weights_fn(modality)
    self.assertIsNotNone(
        bottom, msg="{} has no default bottom".format(modality))
    self.assertIsNotNone(loss, msg="{} has no default loss".format(modality))
    self.assertIsNotNone(name, msg="{} has no default name".format(modality))
    self.assertIsNotNone(
        targets_bottom,
        msg="{} has no default targets_bottom".format(modality))
    self.assertIsNotNone(top, msg="{} has no default top".format(modality))
    self.assertIsNotNone(
        weights_fn, msg="{} has no default weights_fn".format(modality))
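
# Hedged usage sketch, not part of the test suite: the getters exercised above
# return plain callables.  The signatures mirror the calls made in
# _fast_decode (bottom(x, hparams, vocab_size), top(body_output, targets,
# hparams, vocab_size)); the function name, hparams values, and shapes below
# are hypothetical.
def _example_bottom_top_roundtrip():
  hp = common_hparams.basic_params1()
  hp.hidden_size = 16
  hp.mode = tf.estimator.ModeKeys.TRAIN
  vocab_size = 100
  ids = tf.zeros([2, 7, 1, 1], dtype=tf.int32)   # [batch, length, 1, 1]
  bottom = modalities.get_bottom(modalities.ModalityType.SYMBOL)
  top = modalities.get_top(modalities.ModalityType.SYMBOL)
  embedded = bottom(ids, hp, vocab_size)         # expected [2, 7, 1, 16]
  # A real model would run its body between bottom and top; here we feed the
  # embeddings straight into the top to keep the sketch short.
  logits = top(embedded, None, hp, vocab_size)   # expected [2, 7, 1, 1, 100]
  return logits
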
def testSymbolModalityInputs(self):
  batch_size = 10
  num_datashards = 5
  length = 5
  vocab_size = 5000
  hidden_size = 9
  model_hparams = common_hparams.basic_params1()
  model_hparams.hidden_size = hidden_size
  model_hparams.mode = tf.estimator.ModeKeys.TRAIN
  x = np.random.randint(vocab_size, size=(batch_size, length, 1, 1))
  data_parallelism = expert_utils.Parallelism(
      ["/device:CPU:0"] * num_datashards)
  xs = tf.split(x, num_datashards)
  sharded_output = data_parallelism(
      modalities.get_bottom(modalities.ModalityType.SYMBOL),
      xs,
      model_hparams,
      vocab_size)
  output = tf.concat(sharded_output, 0)
  self.evaluate(tf.global_variables_initializer())
  res = self.evaluate(output)
  self.assertEqual(res.shape, (batch_size, length, 1, hidden_size))
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding; uses beam search iff
  beam_size > 1, otherwise the beam-search-related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.modality["targets"]
  target_vocab_size = self._problem_hparams.vocab_size["targets"]
  if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
    target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor
  if "targets_segmentation" in features:
    raise NotImplementedError(
        "Decoding is not supported on packed datasets."
        " If you want to decode from a dataset, use the non-packed version"
        " of the dataset when decoding.")
  if self.has_input:
    inputs = features["inputs"]
    if target_modality == modalities.ModalityType.CLASS_LABEL:
      decode_length = 1
    else:
      decode_length = (
          common_layers.shape_list(inputs)[1] + features.get(
              "decode_length", decode_length))

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    batch_size = s[0]
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match.
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.modality["inputs"]
    input_vocab_size = self._problem_hparams.vocab_size["inputs"]
    if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
      input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
    modality_name = hparams.name.get(
        "inputs",
        modalities.get_name(input_modality))(hparams, input_vocab_size)
    with tf.variable_scope(modality_name):
      bottom = hparams.bottom.get(
          "inputs", modalities.get_bottom(input_modality))
      inputs = dp(bottom, inputs, hparams, input_vocab_size)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode,
          inputs,
          features["target_space_id"],
          hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    partial_targets = features.get("partial_targets")
  else:
    # The problem has no inputs.
    encoder_output = None
    encoder_decoder_attention_bias = None

    # Prepare partial targets: taken from either features["inputs"] or
    # features["targets"]. We force the outputs to begin with these sequences.
partial_targets = features.get("inputs") if partial_targets is None: partial_targets = features["targets"] assert partial_targets is not None if partial_targets is not None: partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2) partial_targets = tf.to_int64(partial_targets) partial_targets_shape = common_layers.shape_list(partial_targets) partial_targets_length = partial_targets_shape[1] decode_length = ( partial_targets_length + features.get("decode_length", decode_length)) batch_size = partial_targets_shape[0] if hparams.pos == "timing": positional_encoding = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) elif hparams.pos == "emb": positional_encoding = common_attention.add_positional_embedding( tf.zeros([1, decode_length, hparams.hidden_size]), hparams.max_length, "body/targets_positional_embedding", None) else: positional_encoding = None def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({"targets": targets})["targets"] modality_name = hparams.name.get( "targets", modalities.get_name(target_modality))(hparams, target_vocab_size) with tf.variable_scope(modality_name): bottom = hparams.bottom.get( "targets", modalities.get_targets_bottom(target_modality)) targets = dp(bottom, targets, hparams, target_vocab_size)[0] targets = common_layers.flatten4d3d(targets) # GO embeddings are all zero, this is because transformer_prepare_decoder # Shifts the targets along by one for the input which pads with zeros. # If the modality already maps GO to the zero embeddings this is not # needed. 
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if positional_encoding is not None:
      targets += positional_encoding[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  # Create tensors for the encoder-decoder attention history.
  att_cache = {"attention_history": {}}
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
  att_batch_size, enc_seq_length = common_layers.shape_list(
      encoder_output)[0:2]
  for layer in range(num_layers):
    att_cache["attention_history"]["layer_%d" % layer] = tf.zeros(
        [att_batch_size, hparams.num_heads, 0, enc_seq_length])
  att_cache["body_outputs"] = tf.zeros(
      [att_batch_size, 1, 0, hparams.hidden_size])

  def update_decoder_attention_history(cache):
    """Appends the current encoder-decoder attention weights to the cache."""
    for k in [x for x in self.attention_weights
              if "decoder" in x and "self" not in x and "logits" not in x]:
      m = re.search(r"(layer_\d+)", k)
      if m is None:
        continue
      cache["attention_history"][m[0]] = tf.concat(
          [cache["attention_history"][m[0]], self.attention_weights[k]],
          axis=2)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for the next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode,
          targets,
          cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"),
          bias,
          hparams,
          cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    update_decoder_attention_history(cache)
    cache["body_outputs"] = tf.concat(
        [cache["body_outputs"], body_outputs[0]], axis=2)

    modality_name = hparams.name.get(
        "targets",
        modalities.get_name(target_modality))(hparams, target_vocab_size)
    with tf.variable_scope(modality_name):
      top = hparams.top.get("targets", modalities.get_top(target_modality))
      logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(ret)[1]

      def forced_logits():
        return tf.one_hot(
            tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
            -1e9)

      ret = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: ret)
    return ret, cache

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_vocab_size,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length,
      cache=att_cache)
  if partial_targets is not None:
    if beam_size <= 1 or top_beams <= 1:
      ret["outputs"] = ret["outputs"][:, partial_targets_length:]
    else:
      ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
  return ret
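
# Rough sketch of the control flow that fast_decode implements in the greedy
# case; it is not the library's implementation (which drives
# symbols_to_logits_fn inside the TF graph, via tf.while_loop for greedy
# search and beam search for beam_size > 1).  The function name, the
# numpy-returning callback, and the shapes are hypothetical, shown only to
# illustrate how the symbols_to_logits_fn callback defined above is consumed.
def _greedy_decode_sketch(symbols_to_logits_fn, batch_size, decode_length,
                          cache):
  import numpy as np  # local import, used only by this sketch
  ids = np.zeros([batch_size, 1], dtype=np.int64)        # start from GO (= 0)
  for i in range(decode_length):
    logits, cache = symbols_to_logits_fn(ids, i, cache)  # [batch, vocab_size]
    next_ids = np.argmax(logits, axis=-1)[:, None]       # most likely symbol
    ids = np.concatenate([ids, next_ids], axis=1)        # append, feed back in
  return ids[:, 1:]                                      # drop the GO column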