Example #1

import copy

import numpy as np
import tensorflow as tf

# Helpers such as sent_level_augment, get_data_for_worker, tokenize_examples,
# convert_examples_to_features, and dump_tfrecord are assumed to come from the
# surrounding preprocessing module. Note: data_stats_dir is accepted for API
# symmetry but unused in this variant.


def proc_and_save_pseudo_data(processor, sub_set, raw_data_dir, data_stats_dir,
                              pseudo_out_dir, tokenizer, max_seq_length,
                              trunc_keep_right, aug_ops, aug_copy_num,
                              worker_id, replicas):
    random_seed = np.random.randint(0, 100000)
    tf.logging.info('random seed: {:d}'.format(random_seed))
    np.random.seed(random_seed)
    tf.logging.info('getting examples')

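    # Choose the source split: labeled training data or an unsupervised subset.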
    if sub_set == 'train':
        ori_examples = processor.get_train_examples(raw_data_dir)
    elif sub_set.startswith('unsup'):
        ori_examples = processor.get_unsup_examples(raw_data_dir, sub_set)
    else:
        raise ValueError('unsupported sub_set: {}'.format(sub_set))

    # Keep only examples whose label is known to the processor.
    labels = processor.get_labels()
    ori_examples = [example for example in ori_examples
                    if example.label in labels]
    data_total_size = len(ori_examples)
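    # Shard the examples across workers; replicas == -1 means single-process mode.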
    if replicas != -1:
        ori_examples, start, end = get_data_for_worker(ori_examples, replicas,
                                                       worker_id)
    else:
        start = 0
        end = len(ori_examples)

    tf.logging.info('getting augmented examples')
    aug_examples = copy.deepcopy(ori_examples)
    aug_examples = sent_level_augment.run_augment(aug_examples,
                                                  aug_ops,
                                                  sub_set,
                                                  aug_copy_num,
                                                  start,
                                                  end,
                                                  data_total_size,
                                                  aug_only=True)

    tf.logging.info('processing ori & pseudo examples')
    # Combine originals with their augmented copies, keeping only known labels.
    examples = [example for example in ori_examples + aug_examples
                if example.label in labels]
    examples = tokenize_examples(examples, tokenizer)
    tf.logging.info(
        'ori examples = {:d}, aug examples = {:d}, total = {:d}'.format(
            len(ori_examples), len(aug_examples), len(examples)))

    features = convert_examples_to_features(examples, labels,
                                            max_seq_length, tokenizer,
                                            trunc_keep_right, None, None)
    dump_tfrecord(features, pseudo_out_dir, worker_id)
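
All three examples shard their input with get_data_for_worker, which is not shown here. Below is a minimal sketch of the contiguous-slice behavior its call sites imply (examples, replicas, and worker_id in; the sliced examples plus start/end offsets out). The exact remainder handling is this sketch's assumption, not something the examples pin down.

def get_data_for_worker(examples, replicas, worker_id):
    # Hypothetical sketch: give each worker one contiguous slice, spreading
    # any remainder over the lowest worker ids.
    per_worker, remainder = divmod(len(examples), replicas)
    if worker_id < remainder:
        start = (per_worker + 1) * worker_id
        end = start + per_worker + 1
    else:
        start = per_worker * worker_id + remainder
        end = start + per_worker
    return examples[start:end], start, end
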
Example #2

import copy

import numpy as np
import tensorflow as tf


def proc_and_save_unsup_data(processor, sub_set, raw_data_dir, data_stats_dir,
                             unsup_out_dir, tokenizer, max_seq_length,
                             trunc_keep_right, aug_ops, aug_copy_num,
                             worker_id, replicas, input_file):
    # Log the random seed to confirm that different runs use different seeds,
    # so the same original example yields different augmented copies each run.
    random_seed = np.random.randint(0, 100000)
    tf.logging.info("random seed: {:d}".format(random_seed))
    np.random.seed(random_seed)
    tf.logging.info("getting examples")

    if sub_set == "train":
        ori_examples = processor.get_train_examples(raw_data_dir)
    elif sub_set.startswith("unsup"):
        ori_examples = processor.get_unsup_examples(raw_data_dir, sub_set)
    else:
        raise ValueError("unsupported sub_set: {}".format(sub_set))
    # Total size before the data is split across workers.
    data_total_size = len(ori_examples)
    if replicas != -1:
        ori_examples, start, end = get_data_for_worker(ori_examples, replicas,
                                                       worker_id)
    else:
        start = 0
        end = len(ori_examples)

    tf.logging.info("getting augmented examples")
    aug_examples = copy.deepcopy(ori_examples)
    aug_examples = sent_level_augment.run_augment(aug_examples, aug_ops,
                                                  sub_set, aug_copy_num, start,
                                                  end, data_total_size,
                                                  input_file)

    labels = processor.get_labels() + ["unsup"]
    tf.logging.info("processing ori examples")
    ori_examples = tokenize_examples(ori_examples, tokenizer)
    ori_features = convert_examples_to_features(ori_examples, labels,
                                                max_seq_length, tokenizer,
                                                trunc_keep_right, None, None)

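    # Word-level ops such as tf-idf replacement need corpus statistics;
    # sentence-level ops (e.g. back-translation) do not.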
    if "idf" in aug_ops:
        data_stats = get_data_stats(data_stats_dir, sub_set, -1, replicas,
                                    ori_examples)
    else:
        data_stats = None

    tf.logging.info("processing aug examples")
    aug_examples = tokenize_examples(aug_examples, tokenizer)
    aug_features = convert_examples_to_features(aug_examples, labels,
                                                max_seq_length, tokenizer,
                                                trunc_keep_right, data_stats,
                                                aug_ops)

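    # Zip each original feature with its augmented counterpart into a single
    # paired record so that downstream training reads them together.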
    unsup_features = []
    for ori_feat, aug_feat in zip(ori_features, aug_features):
        unsup_features.append(
            PairedUnsupInputFeatures(
                ori_feat.input_ids,
                ori_feat.input_mask,
                ori_feat.input_type_ids,
                aug_feat.input_ids,
                aug_feat.input_mask,
                aug_feat.input_type_ids,
            ))
    dump_tfrecord(unsup_features, unsup_out_dir, worker_id)
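
PairedUnsupInputFeatures and dump_tfrecord are likewise assumed from the surrounding module. Here is a hedged sketch of what the call sites above need: a six-field paired container plus a per-worker TFRecord shard writer. The field order mirrors the constructor call above; the shard file naming is this sketch's assumption.

import collections
import os

import tensorflow as tf

PairedUnsupInputFeatures = collections.namedtuple(
    "PairedUnsupInputFeatures",
    ["ori_input_ids", "ori_input_mask", "ori_input_type_ids",
     "aug_input_ids", "aug_input_mask", "aug_input_type_ids"])


def dump_tfrecord(features, out_dir, worker_id):
    # Write one TFRecord shard per worker, e.g. tf_examples.tfrecord.0
    # (hypothetical naming).
    tf.gfile.MakeDirs(out_dir)
    path = os.path.join(out_dir, "tf_examples.tfrecord.{:d}".format(worker_id))
    with tf.python_io.TFRecordWriter(path) as writer:
        for feat in features:
            feature_map = {
                name: tf.train.Feature(
                    int64_list=tf.train.Int64List(value=list(values)))
                for name, values in feat._asdict().items()
            }
            record = tf.train.Example(
                features=tf.train.Features(feature=feature_map))
            writer.write(record.SerializeToString())
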
Example #3

import copy

import numpy as np
# Assuming absl-style logging, as in the XLNet codebase this variant follows;
# tokenization, sent_level_augment, and the feature-conversion helpers are
# module-local.
from absl import logging


def proc_and_save_unsup_data_xlnet(
    processor, sub_set,
    raw_data_dir, data_stats_dir, unsup_out_dir,
    tokenize_fn,
    max_seq_length, trunc_keep_right,
    aug_ops, aug_copy_num,
    worker_id, replicas):
  # Log the random seed to confirm that different runs use different seeds,
  # so the same original example yields different augmented copies each run.
  random_seed = np.random.randint(0, 100000)
  logging.info("random seed: {:d}".format(random_seed))
  np.random.seed(random_seed)
  logging.info("getting examples")

  if sub_set == "train":
    ori_examples = processor.get_train_examples(raw_data_dir)
  elif sub_set.startswith("unsup"):
    logging.info("loading unsupervised subset: %s", sub_set)
    ori_examples = processor.get_unsup_examples(raw_data_dir, sub_set)
  else:
    raise ValueError("unsupported sub_set: {}".format(sub_set))
  # Total size before the data is split across workers.
  data_total_size = len(ori_examples)
  if replicas != -1:
    ori_examples, start, end = get_data_for_worker(
        ori_examples, replicas, worker_id)
  else:
    start = 0
    end = len(ori_examples)

  logging.info("getting augmented examples")
  aug_examples = copy.deepcopy(ori_examples)

  # Sentence-level augmentation only; this is a no-op for tf-idf (word-level) ops.
  aug_examples = sent_level_augment.run_augment(
      aug_examples, aug_ops, sub_set,
      aug_copy_num,
      start, end, data_total_size)

  labels = processor.get_labels() + ["unsup"]
  logging.info("processing ori examples with labels: {}".format(labels))

  ori_features = file_based_convert_examples_to_features(
      ori_examples, labels, max_seq_length,
      tokenize_fn, num_passes=1)

  tokenized_ori_examples = tokenize_examples(
      ori_examples, tokenization.FullTokenizer(do_lower_case=False))

  if "idf" in aug_ops:
    data_stats = get_data_stats(
        data_stats_dir, sub_set,
        -1, replicas, tokenized_ori_examples)
  else:
    data_stats = None

  logging.info("processing aug examples using aug ops {}".format(aug_ops))

  aug_features = file_based_convert_examples_to_features(
      aug_examples, labels, max_seq_length,
      tokenize_fn, num_passes=1, data_stats=data_stats, aug_ops=aug_ops)

  logging.info("{} Original Features".format(len(ori_features)))
  logging.info("{} Augmented Features".format(len(aug_features)))
  unsup_features = []

  for ori_feat, aug_feat in zip(ori_features, aug_features):
    unsup_features.append(PairedUnsupInputFeaturesXL(
        ori_feat.input_ids,
        ori_feat.input_mask,
        ori_feat.segment_ids,
        ori_feat.is_real_example,
        aug_feat.input_ids,
        aug_feat.input_mask,
        aug_feat.segment_ids,
        aug_feat.is_real_example
        ))
  logging.info("There are {} total unsupervised records".format(len(unsup_features)))
  dump_tfrecord(unsup_features, unsup_out_dir, worker_id)
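
Finally, a small usage sketch for reading the per-worker shards back at training time. It assumes the hypothetical tf_examples.tfrecord.<worker_id> naming from the dump_tfrecord sketch above.

import os

import tensorflow as tf


def unsup_record_dataset(unsup_out_dir, num_workers):
    # Gather every worker's shard and read them back as a single dataset.
    files = [os.path.join(unsup_out_dir, "tf_examples.tfrecord.{:d}".format(i))
             for i in range(num_workers)]
    return tf.data.TFRecordDataset(files)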