Example #1
    def test_generate_dataset_from_tfds_processor(self, task_type):
        with tfds.testing.mock_data(num_examples=5):
            output_path = os.path.join(self.model_dir, task_type)

            processor = self.processors[task_type]()

            classifier_data_lib.generate_tf_record_from_data_file(
                processor,
                None,
                self.tokenizer,
                train_data_output_path=output_path,
                eval_data_output_path=output_path,
                test_data_output_path=output_path)
            files = tf.io.gfile.glob(output_path)
            self.assertNotEmpty(files)

            train_dataset = tf.data.TFRecordDataset(output_path)
            seq_length = 128
            label_type = tf.int64
            name_to_features = {
                "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
                "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
                "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
                "label_ids": tf.io.FixedLenFeature([], label_type),
            }
            train_dataset = train_dataset.map(
                lambda record: decode_record(record, name_to_features))

            # If data is retrieved without error, then all requirements
            # including data type/shapes are met.
            _ = next(iter(train_dataset))
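
The test maps each record through a decode_record helper that is not part of this excerpt. A minimal sketch, assuming the usual tf.Example parsing pattern from the TF Models BERT data pipeline (parse against the feature spec, then cast int64 features down to int32):

def decode_record(record, name_to_features):
    """Decodes a serialized tf.Example into a dict of tensors."""
    example = tf.io.parse_single_example(record, name_to_features)
    # tf.Example stores integers as int64, while downstream models commonly
    # expect int32, so cast every integer feature.
    for name in list(example.keys()):
        t = example[name]
        if t.dtype == tf.int64:
            t = tf.cast(t, tf.int32)
        example[name] = t
    return example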
Example #2
def generate_classifier_dataset():
  """Generates classifier dataset and returns input meta data."""
  assert (FLAGS.input_data_dir and FLAGS.classification_task_name
          or FLAGS.tfds_params)

  if FLAGS.tokenizer_impl == "word_piece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenizer_impl == "sentence_piece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params,
        process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    processors = {
        "cola": classifier_data_lib.ColaProcessor,
        "mnli": classifier_data_lib.MnliProcessor,
        "mrpc": classifier_data_lib.MrpcProcessor,
        "qnli": classifier_data_lib.QnliProcessor,
        "qqp": classifier_data_lib.QqpProcessor,
        "sst-2": classifier_data_lib.SstProcessor,
        "xnli": functools.partial(classifier_data_lib.XnliProcessor,
                                  language=FLAGS.xnli_language),
        "paws-x": functools.partial(classifier_data_lib.PawsxProcessor,
                                    language=FLAGS.pawsx_language),
    }
    task_name = FLAGS.classification_task_name.lower()
    if task_name not in processors:
      raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name](process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        FLAGS.input_data_dir,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
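
This function dereferences a number of absl flags. A minimal, illustrative set of definitions covering the names used above; the defaults and help strings here are assumptions, not the original script's declarations:

from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string("input_data_dir", None,
                    "Directory containing the raw task data.")
flags.DEFINE_string("classification_task_name", None,
                    "Name of the classification task to process.")
flags.DEFINE_string("tfds_params", "",
                    "Comma-separated parameters for TfdsProcessor.")
flags.DEFINE_enum("tokenizer_impl", "word_piece",
                  ["word_piece", "sentence_piece"],
                  "Which tokenizer implementation to use.")
flags.DEFINE_string("vocab_file", None, "WordPiece vocabulary file.")
flags.DEFINE_bool("do_lower_case", True, "Whether to lower-case input text.")
flags.DEFINE_string("sp_model_file", "", "SentencePiece model file.")
flags.DEFINE_string("train_data_output_path", None,
                    "Output TFRecord path for training data.")
flags.DEFINE_string("eval_data_output_path", None,
                    "Output TFRecord path for eval data.")
flags.DEFINE_string("test_data_output_path", None,
                    "Output TFRecord path for test data.")
flags.DEFINE_integer("max_seq_length", 128,
                     "Maximum sequence length after tokenization.")
flags.DEFINE_string("xnli_language", "en",
                    "Language of the XNLI data to process.")
flags.DEFINE_string("pawsx_language", "en",
                    "Language of the PAWS-X data to process.")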
Example #3
def generate_regression_dataset():
    """Generates regression dataset and returns input meta data."""
    if FLAGS.tokenizer_impl == "word_piece":
        tokenizer = tokenization.FullTokenizer(
            vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
        processor_text_fn = tokenization.convert_to_unicode
    else:
        assert FLAGS.tokenizer_impl == "sentence_piece"
        tokenizer = tokenization.FullSentencePieceTokenizer(
            FLAGS.sp_model_file)
        processor_text_fn = functools.partial(tokenization.preprocess_text,
                                              lower=FLAGS.do_lower_case)

    if FLAGS.tfds_params:
        processor = classifier_data_lib.TfdsProcessor(
            tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
        return classifier_data_lib.generate_tf_record_from_data_file(
            processor,
            None,
            tokenizer,
            train_data_output_path=FLAGS.train_data_output_path,
            eval_data_output_path=FLAGS.eval_data_output_path,
            test_data_output_path=FLAGS.test_data_output_path,
            max_seq_length=FLAGS.max_seq_length)
    else:
        raise ValueError(
            "No data processor found for the given regression task.")
Example #4
def generate_tfrecords(args, dataset_dir, labels):
    """Generates tfrecords from generated tsv files"""
    processor = TextClassificationProcessor(labels)
    # save label mapping
    processor.save_label_mapping(dataset_dir)
    # get tokenizer
    tokenizer = get_tokenizer(args.model_class)
    processor_text_fn = tokenization.convert_to_unicode
    # generate tfrecords
    input_dir = os.path.join(dataset_dir, 'preprocessed')
    output_dir = os.path.join(dataset_dir, 'tfrecords')
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    input_meta_data = generate_tf_record_from_data_file(
        processor,
        input_dir,
        tokenizer,
        train_data_output_path=os.path.join(output_dir, 'train.tfrecords'),
        eval_data_output_path=os.path.join(output_dir, 'dev.tfrecords'),
        max_seq_length=args.max_seq_length)
    with tf.io.gfile.GFile(os.path.join(dataset_dir, 'meta.json'),
                           'w') as writer:
        writer.write(json.dumps(input_meta_data, indent=4) + '\n')
    logger.info(f'Successfully wrote tfrecord files to {output_dir}')
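
TextClassificationProcessor is defined elsewhere in this project. A minimal sketch of what it could look like, assuming it subclasses classifier_data_lib.DataProcessor, that the preprocessed files are two-column text<TAB>label TSVs, and that save_label_mapping writes a label-to-id JSON (the file names here are hypothetical). Only train and dev examples are needed because the call above passes no test_data_output_path:

import csv
import json
import os

from official.nlp.data import classifier_data_lib


class TextClassificationProcessor(classifier_data_lib.DataProcessor):
    """Sketch of a processor for two-column text<TAB>label TSV files."""

    def __init__(self, labels, process_text_fn=None):
        # DataProcessor is assumed to accept a text-normalization function;
        # fall back to its default when none is supplied.
        if process_text_fn is None:
            super().__init__()
        else:
            super().__init__(process_text_fn=process_text_fn)
        self._labels = list(labels)

    def save_label_mapping(self, dataset_dir):
        # Hypothetical on-disk format for the label -> id mapping.
        mapping = {label: i for i, label in enumerate(self._labels)}
        with open(os.path.join(dataset_dir, 'label_mapping.json'), 'w') as f:
            json.dump(mapping, f, indent=2)

    def get_labels(self):
        return self._labels

    @staticmethod
    def get_processor_name():
        return 'text-classification'

    def get_train_examples(self, data_dir):
        return self._create_examples(
            os.path.join(data_dir, 'train.tsv'), 'train')

    def get_dev_examples(self, data_dir):
        return self._create_examples(os.path.join(data_dir, 'dev.tsv'), 'dev')

    def _create_examples(self, path, set_type):
        """Reads a TSV where each row is: text <TAB> label."""
        examples = []
        with open(path, newline='') as f:
            for i, row in enumerate(csv.reader(f, delimiter='\t')):
                examples.append(
                    classifier_data_lib.InputExample(
                        guid='%s-%d' % (set_type, i),
                        text_a=self.process_text_fn(row[0]),
                        label=row[1]))
        return examples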
Example #5
def generate_classifier_dataset():
    """Generates classifier dataset and returns input meta data."""
    if FLAGS.classification_task_name in [
            "COLA",
            "WNLI",
            "SST-2",
            "MRPC",
            "QQP",
            "STS-B",
            "MNLI",
            "QNLI",
            "RTE",
            "AX",
            "SUPERGLUE-RTE",
            "CB",
            "BoolQ",
            "WIC",
    ]:
        assert not FLAGS.input_data_dir or FLAGS.tfds_params
    else:
        assert (FLAGS.input_data_dir and FLAGS.classification_task_name
                or FLAGS.tfds_params)

    if FLAGS.tokenization == "WordPiece":
        tokenizer = tokenization.FullTokenizer(
            vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
        processor_text_fn = tokenization.convert_to_unicode
    else:
        assert FLAGS.tokenization == "SentencePiece"
        tokenizer = tokenization.FullSentencePieceTokenizer(
            FLAGS.sp_model_file)
        processor_text_fn = functools.partial(tokenization.preprocess_text,
                                              lower=FLAGS.do_lower_case)

    if FLAGS.tfds_params:
        processor = classifier_data_lib.TfdsProcessor(
            tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
        return classifier_data_lib.generate_tf_record_from_data_file(
            processor,
            None,
            tokenizer,
            train_data_output_path=FLAGS.train_data_output_path,
            eval_data_output_path=FLAGS.eval_data_output_path,
            test_data_output_path=FLAGS.test_data_output_path,
            max_seq_length=FLAGS.max_seq_length)
    else:
        processors = {
            "ax":
            classifier_data_lib.AxProcessor,
            "cola":
            classifier_data_lib.ColaProcessor,
            "imdb":
            classifier_data_lib.ImdbProcessor,
            "mnli":
            functools.partial(classifier_data_lib.MnliProcessor,
                              mnli_type=FLAGS.mnli_type),
            "mrpc":
            classifier_data_lib.MrpcProcessor,
            "qnli":
            classifier_data_lib.QnliProcessor,
            "qqp":
            classifier_data_lib.QqpProcessor,
            "rte":
            classifier_data_lib.RteProcessor,
            "sst-2":
            classifier_data_lib.SstProcessor,
            "sts-b":
            classifier_data_lib.StsBProcessor,
            "xnli":
            functools.partial(classifier_data_lib.XnliProcessor,
                              language=FLAGS.xnli_language),
            "paws-x":
            functools.partial(classifier_data_lib.PawsxProcessor,
                              language=FLAGS.pawsx_language),
            "wnli":
            classifier_data_lib.WnliProcessor,
            "xtreme-xnli":
            functools.partial(
                classifier_data_lib.XtremeXnliProcessor,
                translated_data_dir=FLAGS.translated_input_data_dir,
                only_use_en_dev=FLAGS.only_use_en_dev),
            "xtreme-paws-x":
            functools.partial(
                classifier_data_lib.XtremePawsxProcessor,
                translated_data_dir=FLAGS.translated_input_data_dir,
                only_use_en_dev=FLAGS.only_use_en_dev),
            "ax-g":
            classifier_data_lib.AXgProcessor,
            "superglue-rte":
            classifier_data_lib.SuperGLUERTEProcessor,
            "cb":
            classifier_data_lib.CBProcessor,
            "boolq":
            classifier_data_lib.BoolQProcessor,
            "wic":
            classifier_data_lib.WnliProcessor,
        }
        task_name = FLAGS.classification_task_name.lower()
        if task_name not in processors:
            raise ValueError("Task not found: %s" % (task_name))

        processor = processors[task_name](process_text_fn=processor_text_fn)
        return classifier_data_lib.generate_tf_record_from_data_file(
            processor,
            FLAGS.input_data_dir,
            tokenizer,
            train_data_output_path=FLAGS.train_data_output_path,
            eval_data_output_path=FLAGS.eval_data_output_path,
            test_data_output_path=FLAGS.test_data_output_path,
            max_seq_length=FLAGS.max_seq_length)
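
A typical driver wires this generator into an absl app and persists the returned metadata, mirroring what Example #4 does by hand; the meta_data_file_path flag name is an assumption:

import json

from absl import app
from absl import flags
import tensorflow as tf

FLAGS = flags.FLAGS


def main(_):
    input_meta_data = generate_classifier_dataset()
    # Save the metadata (label count, max sequence length, ...) so the
    # training script can rebuild the same input pipeline.
    with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
        writer.write(json.dumps(input_meta_data, indent=4) + "\n")


if __name__ == "__main__":
    app.run(main)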