def _create_predictions(self, decoder_output, features, labels, losses=None):
    """Assemble the prediction dictionary returned by the model.

    Args:
        decoder_output: Namedtuple-like decoder result; it must expose
            ``_fields`` and be iterable in the same order. Per the original
            author's note, decoders return time-major values ``[T, B, ...]``.
        features: Model features, flattened under the ``"features"`` root.
        labels: Optional labels, flattened under the ``"labels"`` root.
            Skipped when ``None``.
        losses: Optional losses. When given, they are passed through
            ``_transpose_batch_time`` and stored under ``"losses"``.

    Returns:
        A dict holding the flattened features (and labels), the optional
        ``"losses"`` entry, every flattened decoder output after
        ``_transpose_batch_time``, and, when ``"predicted_ids"`` is present,
        a ``"predicted_tokens"`` entry looked up from the target vocab table.
    """
    result = dict(_flatten_dict({"features": features}))
    if labels is not None:
        result.update(_flatten_dict({"labels": labels}))

    if losses is not None:
        result["losses"] = _transpose_batch_time(losses)

    # Decoder outputs are time-major; hand them back to the user batch-major.
    named_outputs = collections.OrderedDict(
        zip(decoder_output._fields, decoder_output))
    for key, value in _flatten_dict(named_outputs).items():
        result[key] = _transpose_batch_time(value)

    # Without predicted ids there is nothing to map back into the vocab.
    if "predicted_ids" not in result:
        return result

    id_to_token = graph_utils.get_dict_from_collection(
        "vocab_tables")["target_id_to_vocab"]
    token_ids = tf.to_int64(result["predicted_ids"])
    # Raw predicted tokens (no further post-processing is applied here).
    result["predicted_tokens"] = id_to_token.lookup(token_ids)
    return result
def predict(decoder_output):
    """Return the decoder outputs as a flat, transposed dictionary.

    Args:
        decoder_output: Namedtuple-like decoder result; it must expose
            ``_fields`` and be iterable in the same order. Per the original
            author's note, decoders return time-major values ``[T, B, ...]``.

    Returns:
        A new dict mapping each key produced by ``_flatten_dict`` to the
        corresponding value passed through ``_transpose_batch_time``.
    """
    named_outputs = collections.OrderedDict(
        zip(decoder_output._fields, decoder_output))
    # Flatten first, then transpose each entry back to batch-major.
    return {
        name: _transpose_batch_time(value)
        for name, value in _flatten_dict(named_outputs).items()
    }
def create_predictions(self, decoder_output, features, labels, losses=None):
    """Assemble the prediction dictionary returned by the model.

    In this variant ``losses`` and the decoder outputs are stored exactly as
    received; no batch/time transpose is applied.

    Args:
        decoder_output: Namedtuple-like decoder result; it must expose
            ``_fields`` and be iterable in the same order.
        features: Model features, flattened under the ``"features"`` root.
        labels: Optional labels, flattened under the ``"labels"`` root.
            Skipped when ``None``.
        losses: Optional losses, stored unchanged under ``"losses"``.

    Returns:
        A dict holding the flattened features (and labels), the optional
        ``"losses"`` entry, every flattened decoder output, and, when
        ``"predicted_ids"`` is present, a ``"predicted_tokens"`` entry
        looked up from the target vocab table.
    """
    result = dict(_flatten_dict({"features": features}))
    if labels is not None:
        result.update(_flatten_dict({"labels": labels}))

    if losses is not None:
        result["losses"] = losses

    # NOTE(review): unlike `_create_predictions`, the decoder outputs are
    # added as-is (no `_transpose_batch_time`) -- confirm this is intended.
    named_outputs = collections.OrderedDict(
        zip(decoder_output._fields, decoder_output))
    result.update(_flatten_dict(named_outputs))

    # Without predicted ids there is nothing to map back into the vocab.
    if "predicted_ids" not in result:
        return result

    id_to_token = graph_utils.get_dict_from_collection(
        "vocab_tables")["target_id_to_vocab"]
    # Raw predicted tokens. A reshape to [batch_size, 1, 1] (to match the
    # regular seq2seq output shape and reuse its post-processing) was
    # considered by the original author but is not applied.
    result["predicted_tokens"] = id_to_token.lookup(
        tf.to_int64(result["predicted_ids"]))
    return result