예제 #1
0
파일: cli.py 프로젝트: dwright37/pet
def main():
    """Parse command line arguments and run PET, iPET or a plain sequence classifier.

    The selected ``--method`` decides which of ``pet.train_pet``,
    ``pet.train_ipet`` or ``pet.train_classifier`` is called. Each receives the
    train/eval/unlabeled examples loaded from ``--data_dir`` and the model,
    train and eval configs built from the parsed arguments.

    Raises:
        ValueError: if ``--output_dir`` exists, is non-empty, ``--do_train`` is
            set and ``--overwrite_output_dir`` is not; if the task is unknown;
            or if ``--method`` is not implemented.
    """
    parser = argparse.ArgumentParser(
        description="Command line interface for PET/iPET")

    # Required parameters
    parser.add_argument(
        "--method",
        required=True,
        choices=['pet', 'ipet', 'sequence_classifier'],
        help=
        "The training method to use. Either regular sequence classification, PET or iPET."
    )
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the data files for the task.")
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        choices=MODEL_CLASSES.keys(),
        help="The type of the pretrained language model to use")
    parser.add_argument("--model_name_or_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the pre-trained model or shortcut name")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        choices=PROCESSORS.keys(),
                        help="The name of the task to train/evaluate on")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written"
    )

    # PET-specific optional parameters
    parser.add_argument(
        "--wrapper_type",
        default="mlm",
        choices=WRAPPER_TYPES,
        help=
        "The wrapper type. Set this to 'mlm' for a masked language model like BERT or to 'plm' "
        "for a permuted language model like XLNet (only for PET)")
    parser.add_argument("--pattern_ids",
                        default=[0],
                        type=int,
                        nargs='+',
                        help="The ids of the PVPs to be used (only for PET)")
    parser.add_argument(
        "--lm_training",
        action='store_true',
        help="Whether to use language modeling as auxiliary task (only for PET)"
    )
    parser.add_argument(
        "--alpha",
        default=0.9999,
        type=float,
        help=
        "Weighting term for the auxiliary language modeling task (only for PET)"
    )
    parser.add_argument(
        "--soft_labels",
        action='store_true',
        help="Whether or not the training data uses soft labels (only for PET)"
    )
    parser.add_argument(
        "--temperature",
        default=2,
        type=float,
        help="Temperature used for combining PVPs (only for PET)")
    parser.add_argument(
        "--verbalizer_file",
        default=None,
        help="The path to a file to override default verbalizers (only for PET)"
    )
    parser.add_argument(
        "--reduction",
        default='wmean',
        choices=['wmean', 'mean'],
        help=
        "Reduction strategy for merging predictions from multiple PET models. Select either "
        "uniform weighting (mean) or weighting based on train set accuracy (wmean)"
    )
    parser.add_argument(
        "--decoding_strategy",
        default='default',
        choices=['default', 'ltr', 'parallel'],
        help="The decoding strategy for PET with multiple masks (only for PET)"
    )
    parser.add_argument(
        "--no_distillation",
        action='store_true',
        help="If set to true, no distillation is performed (only for PET)")
    parser.add_argument(
        "--pet_repetitions",
        default=3,
        type=int,
        help=
        "The number of times to repeat PET training and testing with different seeds."
    )
    parser.add_argument(
        "--pet_max_seq_length",
        default=256,
        type=int,
        help=
        "The maximum total input sequence length after tokenization for PET. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--pet_per_gpu_train_batch_size",
                        default=4,
                        type=int,
                        help="Batch size per GPU/CPU for PET training.")
    parser.add_argument("--pet_per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for PET evaluation.")
    parser.add_argument(
        "--pet_per_gpu_unlabeled_batch_size",
        default=4,
        type=int,
        help=
        "Batch size per GPU/CPU for auxiliary language modeling examples in PET."
    )
    parser.add_argument(
        '--pet_gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass in PET."
    )
    parser.add_argument(
        "--pet_num_train_epochs",
        default=3,
        type=float,
        help="Total number of training epochs to perform in PET.")
    parser.add_argument(
        "--pet_max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform in PET. Override num_train_epochs."
    )

    # SequenceClassifier-specific optional parameters (also used for the final PET classifier)
    parser.add_argument(
        "--sc_repetitions",
        default=1,
        type=int,
        help=
        "The number of times to repeat seq. classifier training and testing with different seeds."
    )
    parser.add_argument(
        "--sc_max_seq_length",
        default=256,
        type=int,
        help=
        "The maximum total input sequence length after tokenization for sequence classification. "
        "Sequences longer than this will be truncated, sequences shorter will be padded."
    )
    parser.add_argument(
        "--sc_per_gpu_train_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for sequence classifier training.")
    parser.add_argument(
        "--sc_per_gpu_eval_batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for sequence classifier evaluation.")
    parser.add_argument(
        "--sc_per_gpu_unlabeled_batch_size",
        default=4,
        type=int,
        help=
        "Batch size per GPU/CPU for unlabeled examples used for distillation.")
    parser.add_argument(
        '--sc_gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass for "
        "sequence classifier training.")
    parser.add_argument(
        "--sc_num_train_epochs",
        default=3,
        type=float,
        help=
        "Total number of training epochs to perform for sequence classifier training."
    )
    parser.add_argument(
        "--sc_max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform for sequence classifier training. "
        "Override num_train_epochs.")

    # iPET-specific optional parameters
    parser.add_argument(
        "--ipet_generations",
        default=3,
        type=int,
        help="The number of generations to train (only for iPET)")
    parser.add_argument(
        "--ipet_logits_percentage",
        default=0.25,
        type=float,
        help=
        "The percentage of models to choose for annotating new training sets (only for iPET)"
    )
    parser.add_argument(
        "--ipet_scale_factor",
        default=5,
        type=float,
        help=
        "The factor by which to increase the training set size per generation (only for iPET)"
    )
    parser.add_argument(
        "--ipet_n_most_likely",
        default=-1,
        type=int,
        help=
        "If >0, in the first generation the n_most_likely examples per label are chosen even "
        "if their predicted label is different (only for iPET)")

    # Other optional parameters
    parser.add_argument(
        "--train_examples",
        default=-1,
        type=int,
        help=
        "The total number of train examples to use, where -1 equals all examples."
    )
    parser.add_argument(
        "--test_examples",
        default=-1,
        type=int,
        help=
        "The total number of test examples to use, where -1 equals all examples."
    )
    parser.add_argument(
        "--unlabeled_examples",
        default=-1,
        type=int,
        help=
        "The total number of unlabeled examples to use, where -1 equals all examples"
    )
    parser.add_argument(
        "--split_examples_evenly",
        action='store_true',
        help=
        "If true, train examples are not chosen randomly, but split evenly across all labels."
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where to store the pre-trained models downloaded from S3.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--do_train',
                        action='store_true',
                        help="Whether to perform training")
    parser.add_argument('--do_eval',
                        action='store_true',
                        help="Whether to perform evaluation")
    parser.add_argument('--priming',
                        action='store_true',
                        help="Whether to use priming for evaluation")
    parser.add_argument(
        "--eval_set",
        choices=['dev', 'test'],
        default='dev',
        help="Whether to perform evaluation on the dev set or the test set")

    args = parser.parse_args()
    logger.info("Parameters: {}".format(args))

    # Refuse to train into a non-empty directory unless overwriting was requested.
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) \
            and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))

    # Setup CUDA, GPU & distributed training
    args.device = "cuda:0" if torch.cuda.is_available(
    ) and not args.no_cuda else "cpu"
    # Only count GPUs when we actually run on them: with --no_cuda on a GPU
    # host, device_count() would report > 0 GPUs although the device is "cpu".
    # (On hosts without CUDA the count is 0 anyway.)
    args.n_gpu = torch.cuda.device_count() if args.device != "cpu" else 0

    # Prepare task
    args.task_name = args.task_name.lower()
    if args.task_name not in PROCESSORS:
        raise ValueError("Task '{}' not found".format(args.task_name))
    processor = PROCESSORS[args.task_name]()
    args.label_list = processor.get_labels()

    # Either a total number of examples (num_examples) or, with
    # --split_examples_evenly, an equal number per label is requested.
    train_ex_per_label, test_ex_per_label = None, None
    train_ex, test_ex = args.train_examples, args.test_examples
    if args.split_examples_evenly:
        train_ex_per_label = eq_div(args.train_examples, len(
            args.label_list)) if args.train_examples != -1 else -1
        test_ex_per_label = eq_div(args.test_examples, len(
            args.label_list)) if args.test_examples != -1 else -1
        train_ex, test_ex = None, None

    eval_set = TEST_SET if args.eval_set == 'test' else DEV_SET

    train_data = load_examples(args.task_name,
                               args.data_dir,
                               TRAIN_SET,
                               num_examples=train_ex,
                               num_examples_per_label=train_ex_per_label)
    eval_data = load_examples(args.task_name,
                              args.data_dir,
                              eval_set,
                              num_examples=test_ex,
                              num_examples_per_label=test_ex_per_label)
    unlabeled_data = load_examples(args.task_name,
                                   args.data_dir,
                                   UNLABELED_SET,
                                   num_examples=args.unlabeled_examples)

    args.metrics = METRICS.get(args.task_name, DEFAULT_METRICS)

    pet_model_cfg, pet_train_cfg, pet_eval_cfg = load_pet_configs(args)
    sc_model_cfg, sc_train_cfg, sc_eval_cfg = load_sequence_classifier_configs(
        args)
    ipet_cfg = load_ipet_config(args)

    if args.method == 'pet':
        pet.train_pet(pet_model_cfg,
                      pet_train_cfg,
                      pet_eval_cfg,
                      sc_model_cfg,
                      sc_train_cfg,
                      sc_eval_cfg,
                      pattern_ids=args.pattern_ids,
                      output_dir=args.output_dir,
                      ensemble_repetitions=args.pet_repetitions,
                      final_repetitions=args.sc_repetitions,
                      reduction=args.reduction,
                      train_data=train_data,
                      unlabeled_data=unlabeled_data,
                      eval_data=eval_data,
                      do_train=args.do_train,
                      do_eval=args.do_eval,
                      no_distillation=args.no_distillation,
                      seed=args.seed)

    elif args.method == 'ipet':
        pet.train_ipet(pet_model_cfg,
                       pet_train_cfg,
                       pet_eval_cfg,
                       ipet_cfg,
                       sc_model_cfg,
                       sc_train_cfg,
                       sc_eval_cfg,
                       pattern_ids=args.pattern_ids,
                       output_dir=args.output_dir,
                       ensemble_repetitions=args.pet_repetitions,
                       final_repetitions=args.sc_repetitions,
                       reduction=args.reduction,
                       train_data=train_data,
                       unlabeled_data=unlabeled_data,
                       eval_data=eval_data,
                       do_train=args.do_train,
                       do_eval=args.do_eval,
                       seed=args.seed)

    elif args.method == 'sequence_classifier':
        pet.train_classifier(sc_model_cfg,
                             sc_train_cfg,
                             sc_eval_cfg,
                             output_dir=args.output_dir,
                             repetitions=args.sc_repetitions,
                             train_data=train_data,
                             unlabeled_data=unlabeled_data,
                             eval_data=eval_data,
                             do_train=args.do_train,
                             do_eval=args.do_eval,
                             seed=args.seed)

    else:
        # Unreachable through the CLI because --method has `choices`, kept as a guard.
        raise ValueError(f"Training method '{args.method}' not implemented")
예제 #2
0
파일: petal.py 프로젝트: puraminy/pet
def main():
    """Automatically search verbalizers (PETAL) and write them to ``--output_dir``.

    For every id in ``--pattern_ids`` the masked-LM logits of the training
    examples are computed with a ``TransformerModelWrapper``.
    ``AutomaticVerbalizerSearch`` then scores candidate tokens with
    ``--score_fct`` and assigns ``--words_per_label`` tokens to each label.
    This is done either jointly for all patterns (``--combine_patterns``) or
    separately per pattern. The resulting verbalizers are printed and written to
    ``verbalizers.json``; the parsed arguments are stored in ``config.txt``.

    Raises:
        ValueError: if the task name is not a known processor.
    """
    parser = argparse.ArgumentParser()

    # required parameters
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory. The verbalizers are written to a file 'verbalizers.json' in this directory.",
    )
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the data files for the task.",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        # MODEL_CLASSES[args.model_type] is indexed below; validating here fails
        # fast instead of raising a KeyError after the data has been loaded.
        choices=MODEL_CLASSES.keys(),
        help="The model type",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " +
        ", ".join(PROCESSORS.keys()),
    )

    # verbalizer search hyperparameters
    parser.add_argument(
        "--normalize",
        action="store_true",
        help=
        "Whether to normalize the loss as proposed in the paper. It is recommended to set this to 'true'.",
    )
    parser.add_argument(
        "--combine_patterns",
        action="store_true",
        help=
        "If set to true, a single joint verbalizer is searched for all patterns",
    )
    parser.add_argument(
        "--num_candidates",
        default=1000,
        type=int,
        help=
        "The number of candidate tokens to consider as verbalizers (see Section 4.1 of the paper)",
    )
    parser.add_argument(
        "--words_per_label",
        default=10,
        type=int,
        help="The number of verbalizer tokens to assign to each label",
    )
    parser.add_argument(
        "--score_fct",
        default="llr",
        choices=["llr", "ce", "random"],
        help=
        "The function used to score verbalizers. Choices are: the log-likelihood ratio loss proposed in the paper "
        "('llr'), cross-entropy loss ('ce') and 'random', which assigns random tokens to each label.",
    )

    # other optional parameters
    parser.add_argument(
        "--train_examples",
        default=50,
        type=int,
        help=
        "The total number of train examples to use, where -1 equals all examples.",
    )
    parser.add_argument(
        "--pattern_ids",
        default=[0],
        type=int,
        nargs="+",
        help="The ids of the PVPs to be used",
    )
    parser.add_argument(
        "--max_seq_length",
        default=256,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument(
        "--words_file",
        default=None,
        type=str,
        help=
        "Path to a file containing (unlabeled) texts from the task's domain. This text is used to compute "
        "verbalization candidates by selecting the most frequent words.",
    )
    parser.add_argument(
        "--max_words",
        default=10000,
        type=int,
        # The help used to hard-code "10,000" regardless of the value passed.
        help=
        "The number of most frequent tokens in the task’s unlabeled data (see --words_file) that are "
        "considered as verbalization candidates (default: %(default)s)",
    )
    parser.add_argument(
        "--additional_input_examples",
        type=str,
        help=
        "An optional path to an additional set of input examples (e.g., obtained using iPET)",
    )
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        help="random seed for initialization")

    args = parser.parse_args()
    random.seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    # Persist the run configuration next to the verbalizers.
    with open(os.path.join(args.output_dir, "config.txt"),
              "w",
              encoding="utf8") as fh:
        json.dump(args.__dict__, fh, indent=2)

    # setup gpu/cpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    # prepare task
    args.task_name = args.task_name.lower()
    if args.task_name not in PROCESSORS:
        raise ValueError("Task not found: {}".format(args.task_name))
    processor = PROCESSORS[args.task_name]()
    args.label_list = processor.get_labels()
    args.cache_dir = ""
    args.do_lower_case = False
    args.verbalizer_file = None
    args.wrapper_type = "mlm"

    # get training data (split evenly across labels; -1 means all examples)
    train_examples_per_label = (eq_div(args.train_examples, len(
        args.label_list)) if args.train_examples != -1 else -1)
    train_data = load_examples(
        args.task_name,
        args.data_dir,
        set_type=TRAIN_SET,
        num_examples_per_label=train_examples_per_label,
    )
    if args.additional_input_examples:
        additional_data = InputExample.load_examples(
            args.additional_input_examples)
        train_data += additional_data
        logger.info(
            f"Loaded {len(additional_data)} additional examples from {args.additional_input_examples}, total "
            f"training set size is now {len(train_data)}")

    # One 0/1 indicator vector per label, marking which training examples have that label.
    expected = {
        label: np.array([1 if x.label == label else 0 for x in train_data])
        for label in args.label_list
    }

    if args.words_file:
        with open(args.words_file, "r", encoding="utf8") as fh:
            word_counts = Counter(fh.read().split())
    else:
        word_counts = None

    tokenizer_class = MODEL_CLASSES[args.model_type]["tokenizer"]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    word2idx = get_word_to_id_map(tokenizer,
                                  word_counts=word_counts,
                                  max_words=args.max_words)

    logits = []

    for pattern_id in args.pattern_ids:
        logger.info(f"Processing examples with pattern id {pattern_id}...")
        args.pattern_id = pattern_id

        config = WrapperConfig(
            model_type=args.model_type,
            model_name_or_path=args.model_name_or_path,
            wrapper_type="mlm",
            task_name=args.task_name,
            max_seq_length=args.max_seq_length,
            label_list=args.label_list,
            pattern_id=args.pattern_id,
        )

        wrapper = TransformerModelWrapper(config)
        wrapper.model.to(device)

        # modify all patterns so that they return a single text segment instead of two segments.
        # The original bound method is captured as a default argument (before it is
        # replaced below) and is called only once per example.
        def single_segment_parts(
                example, _get_parts=wrapper.preprocessor.pvp.get_parts):
            parts = _get_parts(example)
            return parts[0] + parts[1], []

        wrapper.preprocessor.pvp.get_parts = single_segment_parts
        # Return the raw logits at mask positions instead of per-label logits.
        wrapper.preprocessor.pvp.convert_mlm_logits_to_cls_logits = lambda mask, x, _=None: x[
            mask >= 0]

        pattern_logits = wrapper.eval(
            train_data,
            device,
            per_gpu_eval_batch_size=args.per_gpu_eval_batch_size,
            n_gpu=args.n_gpu,
        )["logits"]
        # Shift each row so that its maximum logit is 0.
        pattern_logits = pattern_logits - np.expand_dims(
            np.max(pattern_logits, axis=1), axis=1)
        logits.append(pattern_logits)

    logger.info("Starting verbalizer search...")

    if args.combine_patterns:
        avs = AutomaticVerbalizerSearch(word2idx, args.label_list, logits,
                                        expected)
        verbalizer = avs.find_verbalizer(
            num_candidates=args.num_candidates,
            words_per_label=args.words_per_label,
            normalize=args.normalize,
            score_fct=args.score_fct,
        )
        verbalizers = {
            pattern_id: verbalizer
            for pattern_id in args.pattern_ids
        }

    else:
        verbalizers = {}
        for idx, pattern_id in enumerate(args.pattern_ids):
            avs = AutomaticVerbalizerSearch(word2idx, args.label_list,
                                            [logits[idx]], expected)
            verbalizers[pattern_id] = avs.find_verbalizer(
                num_candidates=args.num_candidates,
                words_per_label=args.words_per_label,
                normalize=args.normalize,
                score_fct=args.score_fct,
            )

    print(json.dumps(verbalizers, indent=2))
    logger.info("Verbalizer search complete, writing output...")

    with open(os.path.join(args.output_dir, "verbalizers.json"),
              "w",
              encoding="utf8") as fh:
        json.dump(verbalizers, fh, indent=2)

    logger.info("Done")
예제 #3
0
파일: cli.py 프로젝트: yuweifamily/pet
def main():
    """Train/evaluate PET, iPET or a sequence classifier for several train-set sizes.

    Arguments come from the module-level ``parser``. For every value in
    ``--train_examples`` the task's train/dev/(optional) test/unlabeled data is
    loaded from ``<data_dir>/<task_name>``. The selected ``--method`` is then
    run, and on the main process (``local_rank`` -1 or 0) its final results are
    logged to Weights & Biases together with the number of training points.

    Raises:
        ValueError: if the task is unknown or ``--method`` is not implemented.
    """
    args = parser.parse_args()
    logger.info("Parameters: {}".format(args))

    # Setup CUDA, GPU & distributed training
    use_cuda = torch.cuda.is_available() and not args.no_cuda
    if args.local_rank != -1:
        args.n_gpu = 1
        args.device = args.local_rank if use_cuda else "cpu"
    else:
        # Only count GPUs when we actually run on them; with --no_cuda on a GPU
        # host, device_count() would otherwise report > 0 while using the CPU.
        args.n_gpu = torch.cuda.device_count() if use_cuda else 0
        args.device = "cuda" if use_cuda else "cpu"

    # Prepare task
    args.task_name = args.task_name.lower()
    if args.task_name not in PROCESSORS:
        raise ValueError("Task '{}' not found".format(args.task_name))
    if args.verbalizer_file is not None:
        args.verbalizer_file = args.verbalizer_file.replace(
            "[TASK_NAME]", args.task_name)
    processor = PROCESSORS[args.task_name]()
    args.label_list = processor.get_labels()

    wandb_initialized = False

    if args.local_rank != -1:
        if use_cuda:
            # Bind this process to its own GPU before NCCL sets up communicators;
            # otherwise collectives may all default to the current device (GPU 0).
            torch.cuda.set_device(args.local_rank)
            backend = "nccl"
        else:
            # NCCL only supports CUDA devices; use Gloo for CPU-only runs.
            backend = "gloo"
        # NOTE(review): rank=local_rank equals the global rank only on a single
        # node — confirm before running multi-node.
        torch.distributed.init_process_group(backend, rank=args.local_rank)

    # These values do not depend on the number of training examples.
    data_dir = os.path.join(args.data_dir, args.task_name)
    # NOTE(review): the same output_dir is reused for every --train_examples
    # value; confirm the train_* functions keep these runs apart.
    output_dir = args.output_dir.replace("[TASK_NAME]", args.task_name)
    args.metrics = METRICS.get(args.task_name, DEFAULT_METRICS)

    for n_train_examples in args.train_examples:
        # Either a total number of examples or, with --split_examples_evenly,
        # an equal number per label (-1 means all examples).
        train_ex_per_label, test_ex_per_label = None, None
        train_ex, test_ex = n_train_examples, args.test_examples
        if args.split_examples_evenly:
            train_ex_per_label = eq_div(n_train_examples, len(
                args.label_list)) if n_train_examples != -1 else -1
            test_ex_per_label = eq_div(args.test_examples, len(
                args.label_list)) if args.test_examples != -1 else -1
            train_ex, test_ex = None, None

        train_data = load_examples(args.task_name,
                                   data_dir,
                                   TRAIN_SET,
                                   num_examples=train_ex,
                                   num_examples_per_label=train_ex_per_label)
        dev_data = load_examples(args.task_name,
                                 data_dir,
                                 DEV_SET,
                                 num_examples=test_ex,
                                 num_examples_per_label=test_ex_per_label)
        if args.do_test:
            try:
                test_data = load_examples(
                    args.task_name,
                    data_dir,
                    TEST_SET,
                    num_examples=test_ex,
                    num_examples_per_label=test_ex_per_label)
            except (FileNotFoundError, NotImplementedError):
                test_data = None
                warnings.warn("Test data not found.")
        else:
            test_data = None
        try:
            unlabeled_data = load_examples(
                args.task_name,
                data_dir,
                UNLABELED_SET,
                num_examples=args.unlabeled_examples)
        except FileNotFoundError:
            warnings.warn("Unlabeled data not found.")
            unlabeled_data = None

        # Configs are rebuilt per iteration so every run starts from fresh objects.
        pet_model_cfg, pet_train_cfg, pet_eval_cfg = load_pet_configs(args)
        sc_model_cfg, sc_train_cfg, sc_eval_cfg = load_sequence_classifier_configs(
            args)
        ipet_cfg = load_ipet_config(args)

        try:
            if args.method == "pet":
                final_results = pet.train_pet(
                    pet_model_cfg,
                    pet_train_cfg,
                    pet_eval_cfg,
                    sc_model_cfg,
                    sc_train_cfg,
                    sc_eval_cfg,
                    pattern_ids=args.pattern_ids,
                    output_dir=output_dir,
                    ensemble_repetitions=args.pet_repetitions,
                    final_repetitions=args.sc_repetitions,
                    reduction=args.reduction,
                    train_data=train_data,
                    unlabeled_data=unlabeled_data,
                    dev_data=dev_data,
                    test_data=test_data,
                    do_train=args.do_train,
                    do_eval=args.do_eval,
                    no_distillation=args.no_distillation,
                    seed=args.seed,
                    overwrite_dir=args.overwrite_output_dir,
                    save_model=args.save_model,
                    local_rank=args.local_rank,
                )

            elif args.method == "ipet":
                final_results = pet.train_ipet(
                    pet_model_cfg,
                    pet_train_cfg,
                    pet_eval_cfg,
                    ipet_cfg,
                    sc_model_cfg,
                    sc_train_cfg,
                    sc_eval_cfg,
                    pattern_ids=args.pattern_ids,
                    output_dir=output_dir,
                    ensemble_repetitions=args.pet_repetitions,
                    final_repetitions=args.sc_repetitions,
                    reduction=args.reduction,
                    train_data=train_data,
                    unlabeled_data=unlabeled_data,
                    dev_data=dev_data,
                    test_data=test_data,
                    do_train=args.do_train,
                    do_eval=args.do_eval,
                    seed=args.seed,
                    overwrite_dir=args.overwrite_output_dir,
                    save_model=args.save_model,
                    local_rank=args.local_rank,
                )

            elif args.method == "sequence_classifier":
                final_results = pet.train_classifier(
                    sc_model_cfg,
                    sc_train_cfg,
                    sc_eval_cfg,
                    output_dir=output_dir,
                    repetitions=args.sc_repetitions,
                    train_data=train_data,
                    unlabeled_data=unlabeled_data,
                    dev_data=dev_data,
                    test_data=test_data,
                    do_train=args.do_train,
                    do_eval=args.do_eval,
                    seed=args.seed,
                    overwrite_dir=args.overwrite_output_dir,
                    save_model=args.save_model,
                    local_rank=args.local_rank,
                )

            else:
                raise ValueError(
                    f"Training method '{args.method}' not implemented")

        except json.decoder.JSONDecodeError:
            # Skip this train-set size instead of aborting the whole sweep.
            warnings.warn("JSONDecodeError in transformers")
            continue

        # Only the main process logs to Weights & Biases.
        if final_results is not None and args.local_rank in [-1, 0]:
            if not wandb_initialized:
                wandb.init(project=f"pvp-vs-finetuning-{args.task_name}",
                           name=naming_convention(args))
                wandb_initialized = True
            final_results["training_points"] = n_train_examples
            wandb.log(final_results)