Example #1
 def test_add_bucket_id(self):
     input_tensor = tf.constant([[101, 102, 1, 0], [103, 104, 105, 106]])
     target_tensor = tf.constant([[2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0],
                                  [2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]])
     bucket_size = 2
     length_bucket_start_id = 50
     length_bucket_max_id = 53
     input_tensor = utils.add_length_bucket_id(input_tensor, target_tensor,
                                               bucket_size,
                                               length_bucket_start_id,
                                               length_bucket_max_id)
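     # The bucket id is prepended and the last input token is dropped,
     # so the sequence length stays the same.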
     self.assertAllEqual([[53, 101, 102, 1], [51, 103, 104, 105]],
                         input_tensor)
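
The asserted values pin down the contract of utils.add_length_bucket_id: count the non-padding (non-zero) tokens in each target row, map that length to a bucket id starting at length_bucket_start_id, cap it at length_bucket_max_id, and prepend it to the inputs while trimming the last token. A minimal sketch consistent with this test; the exact bucket formula is inferred from the asserted values (8 tokens -> 53, 3 tokens -> 51), not taken from the library:

  import tensorflow as tf

  def add_length_bucket_id_sketch(inputs, targets, bucket_size,
                                  bucket_start_id, bucket_max_id):
    """Hypothetical re-implementation matching the test above."""
    # Count non-padding (non-zero) tokens per target row: 8 and 3 here.
    target_len = tf.reduce_sum(
        tf.cast(tf.not_equal(targets, 0), tf.int32), axis=-1)
    # Bucket the length and cap at the maximum id: 53 and 51 here.
    bucket_id = tf.minimum(
        bucket_start_id + (target_len - 1) // bucket_size, bucket_max_id)
    # Prepend the bucket id; drop the last token to preserve the shape.
    return tf.concat([bucket_id[:, tf.newaxis], inputs[:, :-1]], axis=-1)

Under this sketch the test's assertAllEqual passes: [[53, 101, 102, 1], [51, 103, 104, 105]].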
Example #2
 def parser(input_dic):
   """Parser for string dict."""
   inputs = parsing_ops.encode(
       tf.reshape(input_dic["inputs"], [1]), max_input_len, vocab_filename,
       encoder_type)
   targets = parsing_ops.encode(
       tf.reshape(input_dic["targets"], [1]), max_target_len, vocab_filename,
       encoder_type)
   inputs = utils.add_length_bucket_id(inputs, targets, length_bucket_size,
                                       length_bucket_start_id,
                                       length_bucket_max_id)
   if add_task_id:
     inputs = utils.add_task_id(inputs, task_start_id + input_dic["task_id"])
   return {"inputs": inputs, "targets": targets}
Example #3
  def parser(input_dic):
    """Parser for string dict."""

    if parser_strategy not in [
        "random", "lead", "rouge", "greedy_rouge", "continuous_rouge", "hybrid",
        "none", "dynamic_rouge"
    ]:
      raise ValueError("Invalid parser_strategy. Got %s." % parser_strategy)

    if parser_rouge_metric_type not in ["precision", "recall", "F"]:
      raise ValueError("Invalid parser_rouge_metric_type. ",
                       "Got %s." % parser_rouge_metric_type)
    if parser_rouge_compute_option not in ["standard", "deduplicate", "log"]:
      raise ValueError("Invalid parser_rouge_compute_options. ",
                       "Got %s." % parser_rouge_compute_option)

    supervised = input_dic["supervised"]

    (pretrain_inputs, pretrain_targets,
     pretrain_masked_inputs) = pretrain_parsing_ops.sentence_mask_and_encode(
        input_dic[input_feature], max_input_len, max_target_len,
        max_total_words, parser_strategy, parser_masked_sentence_ratio,
        parser_masked_words_ratio, parser_mask_word_options_prob,
        parser_mask_sentence_options_prob, vocab_filename, encoder_type,
        parser_rouge_ngrams_size, parser_rouge_metric_type,
        parser_rouge_stopwords_filename, parser_rouge_compute_option,
        parser_rouge_noise_ratio, parser_dynamic_mask_min_ratio,
        shift_special_token_id)

    supervised_inputs = parsing_ops.encode(
        tf.reshape(input_dic["inputs"], [1]), max_input_len, vocab_filename,
        encoder_type)
    supervised_targets = parsing_ops.encode(
        tf.reshape(input_dic["targets"], [1]), max_target_len, vocab_filename,
        encoder_type)

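    # Per-example switch: supervised examples use the plain encoded
    # inputs/targets; pretraining examples use the sentence-masked outputs.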
    inputs = tf.cond(pred=supervised, true_fn=lambda: supervised_inputs,
                     false_fn=lambda: pretrain_inputs)
    targets = tf.cond(pred=supervised, true_fn=lambda: supervised_targets,
                      false_fn=lambda: pretrain_targets)
    masked_inputs = tf.cond(pred=supervised, true_fn=lambda: supervised_inputs,
                            false_fn=lambda: pretrain_masked_inputs)

    inputs, targets, masked_inputs = utils.filter_by_length(
        [inputs, targets, masked_inputs],
        min_len_list=[None, pretrain_target_filter_min, None])

    inputs = utils.add_length_bucket_id(inputs, targets, length_bucket_size,
                                        length_bucket_start_id,
                                        length_bucket_max_id)
    masked_inputs = utils.add_length_bucket_id(masked_inputs, targets,
                                               length_bucket_size,
                                               length_bucket_start_id,
                                               length_bucket_max_id)

    if add_task_id:
      inputs = utils.add_task_id(inputs, task_start_id + input_dic["task_id"])
      masked_inputs = utils.add_task_id(masked_inputs,
                                        task_start_id + input_dic["task_id"])

    output_dic = {
        "inputs": inputs,
        "targets": targets,
        "masked_inputs": masked_inputs
    }

    return output_dic
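
The only graph-level control flow above is that per-example supervised/pretrain switch. A self-contained illustration of the tf.cond pattern, with made-up values:

  import tensorflow as tf

  supervised = tf.constant(False)          # per-example boolean flag
  supervised_inputs = tf.constant([[7, 8, 9]])
  pretrain_inputs = tf.constant([[1, 2, 3]])

  # tf.cond evaluates exactly one branch; with supervised=False the
  # pretraining branch is chosen, yielding [[1, 2, 3]].
  inputs = tf.cond(pred=supervised,
                   true_fn=lambda: supervised_inputs,
                   false_fn=lambda: pretrain_inputs)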