Example #1
  def parser(input_dic):
    """Parser for string dict."""
    inputs = parsing_ops.encode(
        tf.reshape(input_dic["inputs"], [1]), max_input_len, vocab_filename,
        encoder_type)
    targets = parsing_ops.encode(
        tf.reshape(input_dic["targets"], [1]), max_target_len, vocab_filename,
        encoder_type)
    inputs = utils.add_length_bucket_id(inputs, targets, length_bucket_size,
                                        length_bucket_start_id,
                                        length_bucket_max_id)
    if add_task_id:
      inputs = utils.add_task_id(inputs, task_start_id + input_dic["task_id"])
    return {"inputs": inputs, "targets": targets}
Example #2
def serving_input_fn(params):
  """Returns expected input spec for exported savedmodels."""
  inputs_ph = tf.placeholder(
      dtype=tf.string, shape=[params.batch_size], name="inputs")

  inputs = public_parsing_ops.encode(inputs_ph, params.max_input_len,
                                     params.vocab_filename, params.encoder_type,
                                     params.length_bucket_size > 0)
  inputs = tf.reshape(inputs, [params.batch_size, params.max_input_len])
  features = {"inputs": inputs}
  return tf.estimator.export.ServingInputReceiver(
      features=features, receiver_tensors=inputs_ph)
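A hedged sketch of how a serving input function like this is typically handed to a TF1 Estimator export. The `estimator` instance and the export path are assumptions, not part of the example above:

import functools

# export_saved_model expects a zero-argument receiver fn, so bind `params`
# up front. The export directory is a hypothetical path.
serving_fn = functools.partial(serving_input_fn, params)
estimator.export_saved_model("/tmp/pegasus_export", serving_fn)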
Example #3
  def parser(input_dic):
    """Parser for string dict."""

    if parser_strategy not in [
        "random", "lead", "rouge", "greedy_rouge", "continuous_rouge", "hybrid",
        "none", "dynamic_rouge"
    ]:
      raise ValueError("Invalid parser_strategy. Got %s." % parser_strategy)

    if parser_rouge_metric_type not in ["precision", "recall", "F"]:
      raise ValueError("Invalid parser_rouge_metric_type. Got %s." %
                       parser_rouge_metric_type)
    if parser_rouge_compute_option not in ["standard", "deduplicate", "log"]:
      raise ValueError("Invalid parser_rouge_compute_option. Got %s." %
                       parser_rouge_compute_option)

    supervised = input_dic["supervised"]

    (pretrain_inputs, pretrain_targets,
     pretrain_masked_inputs) = pretrain_parsing_ops.sentence_mask_and_encode(
        input_dic[input_feature], max_input_len, max_target_len,
        max_total_words, parser_strategy, parser_masked_sentence_ratio,
        parser_masked_words_ratio, parser_mask_word_options_prob,
        parser_mask_sentence_options_prob, vocab_filename, encoder_type,
        parser_rouge_ngrams_size, parser_rouge_metric_type,
        parser_rouge_stopwords_filename, parser_rouge_compute_option,
        parser_rouge_noise_ratio, parser_dynamic_mask_min_ratio,
        shift_special_token_id)

    supervised_inputs = parsing_ops.encode(
        tf.reshape(input_dic["inputs"], [1]), max_input_len, vocab_filename,
        encoder_type)
    supervised_targets = parsing_ops.encode(
        tf.reshape(input_dic["targets"], [1]), max_target_len, vocab_filename,
        encoder_type)

    inputs = tf.cond(pred=supervised, true_fn=lambda: supervised_inputs,
                     false_fn=lambda: pretrain_inputs)
    targets = tf.cond(pred=supervised, true_fn=lambda: supervised_targets,
                      false_fn=lambda: pretrain_targets)
    masked_inputs = tf.cond(pred=supervised, true_fn=lambda: supervised_inputs,
                            false_fn=lambda: pretrain_masked_inputs)

    inputs, targets, masked_inputs = utils.filter_by_length(
        [inputs, targets, masked_inputs],
        min_len_list=[None, pretrain_target_filter_min, None])

    inputs = utils.add_length_bucket_id(inputs, targets, length_bucket_size,
                                        length_bucket_start_id,
                                        length_bucket_max_id)
    masked_inputs = utils.add_length_bucket_id(masked_inputs, targets,
                                               length_bucket_size,
                                               length_bucket_start_id,
                                               length_bucket_max_id)

    if add_task_id:
      inputs = utils.add_task_id(inputs, task_start_id + input_dic["task_id"])
      masked_inputs = utils.add_task_id(masked_inputs,
                                        task_start_id + input_dic["task_id"])

    output_dic = {
        "inputs": inputs,
        "targets": targets,
        "masked_inputs": masked_inputs
    }

    return output_dic
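A small, self-contained illustration of the tf.cond pattern the parser relies on to switch between the supervised and pretraining branches. Everything here is a toy stand-in for the real encoded tensors:

import tensorflow as tf

# Toy stand-ins for the real encoded id tensors.
supervised = tf.placeholder(dtype=tf.bool, shape=[], name="supervised")
supervised_ids = tf.constant([1, 2, 3])
pretrain_ids = tf.constant([4, 5, 6])

# Both branch fns are traced into the graph; only the predicate decides
# which result is produced at run time, exactly as in the parser above.
chosen = tf.cond(pred=supervised,
                 true_fn=lambda: supervised_ids,
                 false_fn=lambda: pretrain_ids)

with tf.Session() as sess:
  print(sess.run(chosen, feed_dict={supervised: True}))   # [1 2 3]
  print(sess.run(chosen, feed_dict={supervised: False}))  # [4 5 6]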
Example #4
  def test_tf_encode(self, encoder_type):
    string = tf.constant(["the quick brown fox.", "the quick brown\n"])
    self.assertAllEqual(
        parsing_ops.encode(string, 10, _SPM_VOCAB, encoder_type),
        public_parsing_ops.encode(string, 10, _SPM_VOCAB, encoder_type))
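For context, a hedged sketch of the harness a parameterized test method like this usually sits in. The class name and the encoder-type parameter values are assumptions, not taken from the original:

import tensorflow as tf
from absl.testing import parameterized

class PublicParsingOpsTest(tf.test.TestCase, parameterized.TestCase):

  # Parameter values below are assumed encoder types, for illustration only.
  @parameterized.parameters("sentencepiece", "sentencepiece_newline")
  def test_tf_encode(self, encoder_type):
    ...  # body as in Example #4

if __name__ == "__main__":
  tf.test.main()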