def get_predictions(reader_outputs, params):
  """Builds prediction tensors for the best answer span across all blocks.

  Selects the (block, candidate) pair with the highest reader logit, extracts
  the corresponding span in both the original and wordpiece tokenizations, and
  reconciles the two via SQuAD-style post-processing.

  Args:
    reader_outputs: reader model outputs; this code reads `.logits` (a
      [num_blocks, num_candidates] score matrix — assumed from the argmax
      pattern, TODO confirm), `.blocks`, `.orig_blocks`, `.orig_tokens`,
      `.candidate_orig_starts`, `.candidate_orig_ends`, `.token_ids`,
      `.candidate_starts`, and `.candidate_ends`.
    params: dict with at least "reader_module_path", the path of the BERT hub
      module whose tokenization info (vocab file, lowercasing flag) is used.

  Returns:
    A dict of tensors: block_index, candidate, block, orig_block, orig_tokens,
    orig_start, orig_end, and the final post-processed string `answer`.
  """
  tokenization_info = bert_utils.get_tokenization_info(
      params["reader_module_path"])
  # Load the wordpiece vocab so token ids can be mapped back to strings.
  with tf.io.gfile.GFile(tokenization_info["vocab_file"]) as vocab_file:
    vocab = tf.constant([l.strip() for l in vocab_file.readlines()])

  # []
  # Jointly pick the best block (row) and best candidate span (column) of the
  # single highest logit in the score matrix.
  predicted_block_index = tf.argmax(tf.reduce_max(reader_outputs.logits, 1))
  predicted_candidate = tf.argmax(tf.reduce_max(reader_outputs.logits, 0))
  predicted_block = tf.gather(reader_outputs.blocks, predicted_block_index)
  predicted_orig_block = tf.gather(reader_outputs.orig_blocks,
                                   predicted_block_index)
  predicted_orig_tokens = tf.gather(reader_outputs.orig_tokens,
                                    predicted_block_index)
  # Span endpoints in the ORIGINAL (pre-wordpiece) tokenization: first select
  # the winning block's candidate list, then the winning candidate.
  predicted_orig_start = tf.gather(
      tf.gather(reader_outputs.candidate_orig_starts, predicted_block_index),
      predicted_candidate)
  predicted_orig_end = tf.gather(
      tf.gather(reader_outputs.candidate_orig_ends, predicted_block_index),
      predicted_candidate)
  # End index is inclusive, hence the +1 when slicing.
  predicted_orig_answer = tf.reduce_join(
      predicted_orig_tokens[predicted_orig_start:predicted_orig_end + 1],
      separator=" ")

  # Same span in the wordpiece tokenization, detokenized via the vocab.
  predicted_token_ids = tf.gather(reader_outputs.token_ids,
                                  predicted_block_index)
  predicted_tokens = tf.gather(vocab, predicted_token_ids)
  predicted_start = tf.gather(reader_outputs.candidate_starts,
                              predicted_candidate)
  predicted_end = tf.gather(reader_outputs.candidate_ends, predicted_candidate)
  predicted_normalized_answer = tf.reduce_join(
      predicted_tokens[predicted_start:predicted_end + 1], separator=" ")

  def _get_final_text(pred_text, orig_text):
    # Runs in Python (via tf.py_func): aligns the normalized wordpiece answer
    # back onto the original text, honoring the module's lowercasing setting.
    pred_text = six.ensure_text(pred_text, errors="ignore")
    orig_text = six.ensure_text(orig_text, errors="ignore")
    return squad_lib.get_final_text(
        pred_text=pred_text,
        orig_text=orig_text,
        do_lower_case=tokenization_info["do_lower_case"])

  predicted_answer = tf.py_func(
      func=_get_final_text,
      inp=[predicted_normalized_answer, predicted_orig_answer],
      Tout=tf.string)

  return dict(
      block_index=predicted_block_index,
      candidate=predicted_candidate,
      block=predicted_block,
      orig_block=predicted_orig_block,
      orig_tokens=predicted_orig_tokens,
      orig_start=predicted_orig_start,
      orig_end=predicted_orig_end,
      answer=predicted_answer)
def get_predictions(reader_outputs):
  """Builds prediction tensors for the best answer span across all blocks.

  Picks the (block, candidate) pair holding the single highest logit and
  extracts the corresponding span from the original tokenization.

  Args:
    reader_outputs: reader model outputs; reads `.logits` (assumed
      [num_blocks, num_candidates] — TODO confirm), `.blocks`, `.orig_blocks`,
      `.orig_tokens`, `.candidate_orig_start`, and `.candidate_orig_end`.

  Returns:
    A dict of tensors: block_index, candidate, block, orig_block, orig_tokens,
    orig_start, orig_end, and the joined answer string.
  """
  logits = reader_outputs.logits
  # Row (block) and column (candidate) of the globally best logit.
  block_idx = tf.argmax(tf.reduce_max(logits, 1))
  candidate_idx = tf.argmax(tf.reduce_max(logits, 0))

  block = tf.gather(reader_outputs.blocks, block_idx)
  orig_block = tf.gather(reader_outputs.orig_blocks, block_idx)
  orig_tokens = tf.gather(reader_outputs.orig_tokens, block_idx)

  def _winning_endpoint(per_block_values):
    # Select the winning block's candidate list, then the winning candidate.
    return tf.gather(tf.gather(per_block_values, block_idx), candidate_idx)

  orig_start = _winning_endpoint(reader_outputs.candidate_orig_start)
  orig_end = _winning_endpoint(reader_outputs.candidate_orig_end)
  # End index is inclusive, hence the +1 in the slice.
  answer = tf.reduce_join(
      orig_tokens[orig_start:orig_end + 1], separator=" ")

  return {
      "block_index": block_idx,
      "candidate": candidate_idx,
      "block": block,
      "orig_block": orig_block,
      "orig_tokens": orig_tokens,
      "orig_start": orig_start,
      "orig_end": orig_end,
      "answer": answer,
  }
def from_tokens(raw, lookup_):
  """Decodes integer token ids back into whitespace-split string tokens.

  Args:
    raw: integer tensor of token ids (cast to int32 before lookup).
    lookup_: string tensor mapping id -> token text.

  Returns:
    A SparseTensor of string tokens, as produced by tf.string_split.
  """
  ids = tf.cast(raw, tf.int32)
  text = tf.reduce_join(tf.gather(lookup_, ids), axis=1)
  # Drop everything from the first <EOS> marker onward, then turn the
  # underscore word separators back into spaces.
  text = tf.regex_replace(text, b"<EOS>.*", b"")
  text = tf.regex_replace(text, b"_", b" ")
  return tf.string_split(text, " ")
def from_characters(raw, lookup_):
  """Convert ascii+2 encoded codes to string-tokens.

  Args:
    raw: integer tensor of character codes offset by +2.
    lookup_: string tensor mapping character code -> character.

  Returns:
    A SparseTensor of string tokens, as produced by tf.string_split.
  """
  # Undo the +2 offset, clamp into byte range, and reinterpret as uint8.
  codes = tf.bitcast(tf.clip_by_value(tf.subtract(raw, 2), 0, 255), tf.uint8)
  chars = tf.gather(lookup_, tf.cast(codes, tf.int32))[:, :, 0]
  text = tf.reduce_join(chars, axis=1)
  # Strip NUL padding before splitting on spaces.
  text = tf.regex_replace(text, b"\0", b"")
  return tf.string_split(text, " ")
def _get_prediction_text(args, window=5):
  """Get the prediction text for a single row in the batch.

  Renders the predicted span wrapped in ** markers, with up to `window`
  context tokens on either side.

  Args:
    args: tuple of (context tokens, inclusive start index, inclusive end
      index) for one row.
    window: number of surrounding context tokens to include on each side.

  Returns:
    A scalar string tensor with the marked-up span.
  """
  context, start, end = args
  # Clamp the context window to the bounds of this row's token sequence.
  lo = tf.maximum(start - window, 0)
  hi = tf.minimum(end + 1 + window, tf.shape(context)[0])
  pieces = tf.concat([
      context[lo:start],
      ["**"],
      context[start:end + 1],
      ["**"],
      context[end + 1:hi],
  ], 0)
  return tf.reduce_join(pieces, separator=" ")
def random_dna_sequence(template, shape=()):
  """Generate a random DNA sequence according to a template.

  Args:
    template: string of characters representing a template for making DNA
      sequences. Each occurrence of an 'N' is replaced by a randomly chosen
      base; every other character is kept as-is.
    shape: optional tuple indicating the desired shape of the returned Tensor.

  Returns:
    tf.Tensor with dtype=strings.
  """
  n_random = sum(base == 'N' for base in template)
  all_bases = tf.constant([x.encode('utf8') for x in dna.DNA_BASES])
  # BUG FIX: tf.random_uniform's maxval is EXCLUSIVE, so the previous
  # hard-coded maxval=3 could never draw the last entry of DNA_BASES (DNA has
  # four bases). Draw uniformly over the whole alphabet instead.
  random_index = tf.random_uniform(
      (n_random,) + shape, maxval=len(dna.DNA_BASES), dtype=tf.int32)
  random_bases = tf.gather(all_bases, random_index)
  # Consume one random base per 'N' in template order; literal characters are
  # broadcast to `shape` with tf.fill.
  iter_bases = iter(random_bases[i] for i in range(n_random))
  bases = [
      next(iter_bases) if base == 'N' else tf.fill(shape, base)
      for base in template
  ]
  return tf.reduce_join(bases, 0)
def get_text_summary(question, context, start_predictions, end_predictions):
  """Get a text summary of the question and the predicted answer.

  Args:
    question: string tensor of question tokens (joined along the last axis).
    context: batched string tensor of context tokens.
    start_predictions: batched inclusive span start indices.
    end_predictions: batched inclusive span end indices.

  Returns:
    A tf.summary.text op pairing each question with its rendered prediction.
  """
  question_text = tf.reduce_join(question, -1, separator=" ")

  def _render_row(args, window=5):
    """Formats one row: predicted span marked with **, plus nearby context."""
    ctx, start, end = args
    lo = tf.maximum(start - window, 0)
    hi = tf.minimum(end + 1 + window, tf.shape(ctx)[0])
    parts = tf.concat([
        ctx[lo:start],
        ["**"],
        ctx[start:end + 1],
        ["**"],
        ctx[end + 1:hi],
    ], 0)
    return tf.reduce_join(parts, separator=" ")

  # Render each batch row independently; no gradients flow through summaries.
  prediction_text = tf.map_fn(
      fn=_render_row,
      elems=[context, start_predictions, end_predictions],
      dtype=tf.string,
      back_prop=False)
  return tf.summary.text(
      "predictions", tf.stack([question_text, prediction_text], -1))