def test_add_bucket_id(self):
  input_tensor = tf.constant([[101, 102, 1, 0], [103, 104, 105, 106]])
  target_tensor = tf.constant([[2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0],
                               [2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]])
  bucket_size = 2
  length_bucket_start_id = 50
  length_bucket_max_id = 53
  input_tensor = utils.add_length_bucket_id(input_tensor, target_tensor,
                                            bucket_size,
                                            length_bucket_start_id,
                                            length_bucket_max_id)
  # Non-padding target lengths are 8 and 3, so the expected bucket ids are
  # 50 + 8 // 2 = 54 (clamped to the max id 53) and 50 + 3 // 2 = 51; the
  # bucket id is prepended and the last input token dropped, keeping length.
  self.assertAllEqual([[53, 101, 102, 1], [51, 103, 104, 105]], input_tensor)
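# The sketch below is not from the original source: it is a minimal
# re-implementation of utils.add_length_bucket_id inferred from the test's
# expected values (floor-divided target length, clamped bucket id, bucket id
# prepended with the last input token dropped). Treat it as illustrative only.
import tensorflow as tf

def add_length_bucket_id_sketch(inputs, targets, bucket_size,
                                bucket_start_id, bucket_max_id):
  """Hypothetical sketch of the bucketing behavior checked above."""
  # Count non-padding (non-zero) target tokens per example.
  target_len = tf.reduce_sum(
      tf.cast(tf.not_equal(targets, 0), tf.int32), axis=-1)
  # Map the length to a bucket id and clamp at the maximum bucket id.
  bucket_id = tf.minimum(bucket_start_id + target_len // bucket_size,
                         bucket_max_id)
  # Prepend the bucket id; drop the last token so the length is unchanged.
  return tf.concat([bucket_id[:, tf.newaxis], inputs[:, :-1]], axis=-1)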
def parser(input_dic):
  """Parser for string dict."""
  inputs = parsing_ops.encode(
      tf.reshape(input_dic["inputs"], [1]), max_input_len, vocab_filename,
      encoder_type)
  targets = parsing_ops.encode(
      tf.reshape(input_dic["targets"], [1]), max_target_len, vocab_filename,
      encoder_type)
  inputs = utils.add_length_bucket_id(inputs, targets, length_bucket_size,
                                      length_bucket_start_id,
                                      length_bucket_max_id)
  if add_task_id:
    inputs = utils.add_task_id(inputs, task_start_id + input_dic["task_id"])
  return {"inputs": inputs, "targets": targets}
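# Hypothetical usage sketch (not in the original source): the parser above is
# a closure over configuration such as max_input_len and vocab_filename, and
# is normally applied per example via tf.data. apply_parser_sketch and the
# AUTOTUNE setting are assumptions for illustration.
def apply_parser_sketch(dataset):
  """Maps the string-dict parser over a tf.data.Dataset of examples."""
  # Each element is expected to be a dict with "inputs" and "targets" string
  # features, plus "task_id" when add_task_id is enabled.
  return dataset.map(
      parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)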
def parser(input_dic):
  """Parser for string dict."""
  if parser_strategy not in [
      "random", "lead", "rouge", "greedy_rouge", "continuous_rouge", "hybrid",
      "none", "dynamic_rouge"
  ]:
    raise ValueError("Invalid parser_strategy. Got %s." % parser_strategy)
  if parser_rouge_metric_type not in ["precision", "recall", "F"]:
    raise ValueError(
        "Invalid parser_rouge_metric_type. Got %s." % parser_rouge_metric_type)
  if parser_rouge_compute_option not in ["standard", "deduplicate", "log"]:
    raise ValueError(
        "Invalid parser_rouge_compute_option. Got %s." %
        parser_rouge_compute_option)
  supervised = input_dic["supervised"]
  (pretrain_inputs, pretrain_targets,
   pretrain_masked_inputs) = pretrain_parsing_ops.sentence_mask_and_encode(
       input_dic[input_feature], max_input_len, max_target_len,
       max_total_words, parser_strategy, parser_masked_sentence_ratio,
       parser_masked_words_ratio, parser_mask_word_options_prob,
       parser_mask_sentence_options_prob, vocab_filename, encoder_type,
       parser_rouge_ngrams_size, parser_rouge_metric_type,
       parser_rouge_stopwords_filename, parser_rouge_compute_option,
       parser_rouge_noise_ratio, parser_dynamic_mask_min_ratio,
       shift_special_token_id)
  supervised_inputs = parsing_ops.encode(
      tf.reshape(input_dic["inputs"], [1]), max_input_len, vocab_filename,
      encoder_type)
  supervised_targets = parsing_ops.encode(
      tf.reshape(input_dic["targets"], [1]), max_target_len, vocab_filename,
      encoder_type)
  # Select the supervised or pretraining tensors per example at graph time.
  inputs = tf.cond(
      pred=supervised,
      true_fn=lambda: supervised_inputs,
      false_fn=lambda: pretrain_inputs)
  targets = tf.cond(
      pred=supervised,
      true_fn=lambda: supervised_targets,
      false_fn=lambda: pretrain_targets)
  masked_inputs = tf.cond(
      pred=supervised,
      true_fn=lambda: supervised_inputs,
      false_fn=lambda: pretrain_masked_inputs)
  inputs, targets, masked_inputs = utils.filter_by_length(
      [inputs, targets, masked_inputs],
      min_len_list=[None, pretrain_target_filter_min, None])
  inputs = utils.add_length_bucket_id(inputs, targets, length_bucket_size,
                                      length_bucket_start_id,
                                      length_bucket_max_id)
  masked_inputs = utils.add_length_bucket_id(masked_inputs, targets,
                                             length_bucket_size,
                                             length_bucket_start_id,
                                             length_bucket_max_id)
  if add_task_id:
    inputs = utils.add_task_id(inputs, task_start_id + input_dic["task_id"])
    masked_inputs = utils.add_task_id(masked_inputs,
                                      task_start_id + input_dic["task_id"])
  return {
      "inputs": inputs,
      "targets": targets,
      "masked_inputs": masked_inputs,
  }
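# The helper below is not from the original source. utils.filter_by_length is
# assumed to zero out an example whose target length falls below
# pretrain_target_filter_min (a map-level stand-in for dropping it, since a
# per-example parser cannot remove elements); this is a minimal sketch under
# that assumption.
import tensorflow as tf

def filter_by_length_sketch(tensor_list, min_len_list):
  """Hypothetical sketch: zero out all tensors when any length check fails."""
  keep = tf.constant(True)
  for tensor, min_len in zip(tensor_list, min_len_list):
    if min_len is None:
      continue  # No minimum-length constraint for this tensor.
    # Treat zeros as padding; count the remaining tokens.
    length = tf.reduce_sum(tf.cast(tf.not_equal(tensor, 0), tf.int32))
    keep = tf.logical_and(keep, length >= min_len)
  # Multiply by 0 or 1 so filtered examples become all-padding.
  mask = tf.cast(keep, tensor_list[0].dtype)
  return [tensor * mask for tensor in tensor_list]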