Example #1
File: cli.py Project: yuweifamily/pet
def load_pet_configs(
        args) -> Tuple[WrapperConfig, pet.TrainConfig, pet.EvalConfig]:
    """
    Load the model, training and evaluation configs for PET from the given command line arguments.
    """
    model_cfg = WrapperConfig(
        model_type=args.model_type,
        model_name_or_path=args.model_name_or_path,
        wrapper_type=args.wrapper_type,
        task_name=args.task_name,
        label_list=args.label_list,
        max_seq_length=args.pet_max_seq_length,
        verbalizer_file=args.verbalizer_file,
        cache_dir=args.cache_dir,
    )

    train_cfg = pet.TrainConfig(
        device=args.device,
        per_gpu_train_batch_size=args.pet_per_gpu_train_batch_size,
        per_gpu_unlabeled_batch_size=args.pet_per_gpu_unlabeled_batch_size,
        n_gpu=args.n_gpu,
        num_train_epochs=args.pet_num_train_epochs,
        max_steps=args.pet_max_steps,
        min_steps=args.pet_min_steps,
        gradient_accumulation_steps=args.pet_gradient_accumulation_steps,
        weight_decay=args.weight_decay,
        learning_rate=args.learning_rate,
        adam_epsilon=args.adam_epsilon,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        lm_training=args.lm_training,
        logging_steps=args.logging_steps,
        logging_number=args.logging_number,
        alpha=args.alpha,
        local_rank=args.local_rank,
    )

    eval_cfg = pet.EvalConfig(
        device=args.device,
        n_gpu=args.n_gpu,
        metrics=args.metrics,
        per_gpu_eval_batch_size=args.pet_per_gpu_eval_batch_size,
        decoding_strategy=args.decoding_strategy,
        priming=args.priming,
        local_rank=args.local_rank,
    )

    return model_cfg, train_cfg, eval_cfg
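
A minimal usage sketch (not from the repository): load_pet_configs simply maps a parsed argparse namespace onto the three config objects, so the namespace must carry the pet_*-prefixed hyperparameters and shared fields referenced above. Looking the label list up via a task processor mirrors what petal.py's main() does in Example #7.

# Hypothetical wiring; `parser` is assumed to define every argument read by load_pet_configs.
args = parser.parse_args()
args.label_list = PROCESSORS[args.task_name]().get_labels()  # labels come from the task processor

pet_model_cfg, pet_train_cfg, pet_eval_cfg = load_pet_configs(args)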
Example #2
File: cli.py Project: yuweifamily/pet
def load_sequence_classifier_configs(
        args) -> Tuple[WrapperConfig, pet.TrainConfig, pet.EvalConfig]:
    """
    Load the model, training and evaluation configs for a regular sequence classifier from the given command line
    arguments. This classifier can either be used as a standalone model or as the final classifier for PET/iPET.
    """
    model_cfg = WrapperConfig(
        model_type=args.model_type,
        model_name_or_path=args.model_name_or_path,
        wrapper_type=SEQUENCE_CLASSIFIER_WRAPPER,
        task_name=args.task_name,
        label_list=args.label_list,
        max_seq_length=args.sc_max_seq_length,
        verbalizer_file=args.verbalizer_file,
        cache_dir=args.cache_dir,
    )

    train_cfg = pet.TrainConfig(
        device=args.device,
        per_gpu_train_batch_size=args.sc_per_gpu_train_batch_size,
        per_gpu_unlabeled_batch_size=args.sc_per_gpu_unlabeled_batch_size,
        n_gpu=args.n_gpu,
        num_train_epochs=args.sc_num_train_epochs,
        max_steps=args.sc_max_steps,
        min_steps=args.sc_min_steps,
        temperature=args.temperature,
        gradient_accumulation_steps=args.sc_gradient_accumulation_steps,
        weight_decay=args.weight_decay,
        learning_rate=args.learning_rate,
        adam_epsilon=args.adam_epsilon,
        warmup_steps=args.warmup_steps,
        logging_steps=args.logging_steps,
        logging_number=args.logging_number,
        max_grad_norm=args.max_grad_norm,
        use_logits=args.method != "sequence_classifier",
        local_rank=args.local_rank,
    )

    eval_cfg = pet.EvalConfig(
        device=args.device,
        n_gpu=args.n_gpu,
        metrics=args.metrics,
        per_gpu_eval_batch_size=args.sc_per_gpu_eval_batch_size,
        local_rank=args.local_rank,
    )

    return model_cfg, train_cfg, eval_cfg
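
Since this classifier can also be trained as a standalone baseline (args.method == "sequence_classifier"), its configs can be passed straight to train_classifier, the same entry point that Examples #4 and #5 below use for the final distilled model. A rough sketch, with argument names taken from those call sites and the data assumed to be loaded beforehand:

# Hypothetical standalone-baseline sketch (train_data/dev_data/test_data prepared elsewhere).
sc_model_cfg, sc_train_cfg, sc_eval_cfg = load_sequence_classifier_configs(args)

train_classifier(
    sc_model_cfg,
    sc_train_cfg,
    sc_eval_cfg,
    os.path.join(args.output_dir, "baseline"),
    repetitions=1,
    train_data=train_data,
    dev_data=dev_data,
    test_data=test_data,
    do_train=True,
    do_eval=True,
    seed=args.seed,
)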
Example #3
def train_pet_ensemble(
    model_config: WrapperConfig,
    train_config: TrainConfig,
    eval_config: EvalConfig,
    pattern_ids: List[Union[str, int]],
    output_dir: str,
    ipet_data_dir: str = None,
    repetitions: int = 3,
    train_data: List[InputExample] = None,
    unlabeled_data: List[InputExample] = None,
    dev_data: List[InputExample] = None,
    test_data: List[InputExample] = None,
    do_train: bool = True,
    do_eval: bool = True,
    save_unlabeled_logits: bool = False,
    seed: int = 42,
    overwrite_dir: bool = False,
    save_model=False,
    local_rank=-1,
):
    """
    Train and evaluate an ensemble of PET models without knowledge distillation.

    :param model_config: the model configuration to use
    :param train_config: the training configuration to use
    :param eval_config: the evaluation configuration to use
    :param pattern_ids: the ids of all PVPs to use
    :param output_dir: the output directory
    :param ipet_data_dir: optional directory containing additional training data for iPET
    :param repetitions: the number of training repetitions
    :param train_data: the training examples to use
    :param unlabeled_data: the unlabeled examples to use
    :param dev_data: the development set examples to use for evaluation
    :param test_data: the test set examples to use for evaluation
    :param do_train: whether to perform training
    :param do_eval: whether to perform evaluation
    :param save_unlabeled_logits: whether logits for unlabeled examples should be saved in a file ``logits.txt``. This
           is required for both iPET and knowledge distillation.
    :param seed: the random seed to use
    """

    results = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    set_seed(seed)

    for pattern_id in pattern_ids:
        for iteration in range(repetitions):

            model_config.pattern_id = pattern_id
            results_dict = {}

            shots = 0 if train_data is None else len(train_data)
            pattern_iter_output_dir = "{}/{}shots-{}-i{}-seed{}".format(
                output_dir, shots, pattern_name(pattern_id), iteration, seed)

            if os.path.exists(pattern_iter_output_dir) and not overwrite_dir:
                logger.warning(
                    f"Path {pattern_iter_output_dir} already exists, skipping it..."
                )
                continue

            if not os.path.exists(pattern_iter_output_dir) and local_rank in [
                    -1, 0
            ]:
                os.makedirs(pattern_iter_output_dir)

            wrapper = init_model(model_config)

            # Training
            if do_train:
                if ipet_data_dir:
                    p = os.path.join(
                        ipet_data_dir,
                        "{}-i{}-train.bin".format(pattern_name(pattern_id),
                                                  iteration))
                    ipet_train_data = InputExample.load_examples(p)
                    for example in ipet_train_data:
                        example.logits = None
                else:
                    ipet_train_data = None

                results_dict.update(
                    train_single_model(
                        wrapper,
                        train_data,
                        train_config,
                        pattern_iter_output_dir,
                        dev_data,
                        eval_config,
                        ipet_train_data=ipet_train_data,
                        unlabeled_data=unlabeled_data,
                        return_train_set_results=False,
                        local_rank=local_rank,
                    ))

                with open(os.path.join(pattern_iter_output_dir, "results.txt"),
                          "w") as fh:
                    fh.write(str(results_dict))

                if local_rank in [-1, 0]:
                    logger.info("Saving trained model at {}...".format(
                        pattern_iter_output_dir))
                    train_config.save(
                        os.path.join(pattern_iter_output_dir,
                                     "train_config.json"))
                    eval_config.save(
                        os.path.join(pattern_iter_output_dir,
                                     "eval_config.json"))
                    logger.info("Saving complete")

                    if save_unlabeled_logits:
                        logits = evaluate(wrapper,
                                          unlabeled_data,
                                          eval_config,
                                          local_rank=local_rank)["logits"]
                        save_logits(
                            os.path.join(pattern_iter_output_dir,
                                         "logits.txt"), logits)

                if not do_eval:
                    wrapper.model = None
                    wrapper = None
                    torch.cuda.empty_cache()

            # Evaluation
            if do_eval:
                logger.info("Starting evaluation...")
                try:
                    wrapper = TransformerModelWrapper.from_pretrained(
                        pattern_iter_output_dir)
                except OSError:
                    warnings.warn(
                        "No model found saved, proceeding with current model instead of best"
                    )
                    pass

                for split, eval_data in {
                        "dev": dev_data,
                        "test": test_data
                }.items():
                    if eval_data is None:
                        continue
                    eval_result = evaluate(wrapper,
                                           eval_data,
                                           eval_config,
                                           priming_data=train_data,
                                           local_rank=local_rank)

                    if local_rank in [-1, 0]:
                        save_predictions(
                            os.path.join(pattern_iter_output_dir,
                                         "predictions.jsonl"), wrapper,
                            eval_result)
                        save_logits(
                            os.path.join(pattern_iter_output_dir,
                                         "eval_logits.txt"),
                            eval_result["logits"])

                    scores = eval_result["scores"]
                    logger.info(
                        "--- {} result (pattern_id={}, iteration={}) ---".
                        format(split, pattern_id, iteration))
                    logger.info(scores)

                    results_dict[f"{split}_set_after_training"] = scores
                    with open(
                            os.path.join(pattern_iter_output_dir,
                                         "results.json"), "w") as fh:
                        json.dump(results_dict, fh)

                    for metric, value in scores.items():
                        results[split][metric][pattern_id].append(value)

                wrapper.model = None
                wrapper = None
                torch.cuda.empty_cache()

    if do_eval:
        logger.info("=== OVERALL RESULTS ===")
        results_to_log = _write_results(
            os.path.join(output_dir, "result_test.txt"), results)
    else:
        logger.info("=== ENSEMBLE TRAINING COMPLETE ===")
        results_to_log = None

    if do_train and not save_model:
        outputs = os.listdir(pattern_iter_output_dir)
        for item in outputs:
            if item.endswith(".bin"):
                os.remove(os.path.join(pattern_iter_output_dir, item))

    return results_to_log
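
For orientation, the per-pattern output directories created above follow a fixed naming scheme; everything written during one training/evaluation pass (results.txt, results.json, train_config.json, eval_config.json, predictions.jsonl, eval_logits.txt and, with save_unlabeled_logits, logits.txt) lands in that directory. An illustrative computation of the name (pattern_name is a repository helper not shown here; for integer pattern ids it is assumed to reduce to the id itself):

# Illustrative only: reproduce the directory naming used by train_pet_ensemble.
output_dir, shots, pattern_id, iteration, seed = "output/pet", 32, 0, 1, 42
pattern_iter_output_dir = "{}/{}shots-{}-i{}-seed{}".format(
    output_dir, shots, pattern_id, iteration, seed)
print(pattern_iter_output_dir)  # output/pet/32shots-0-i1-seed42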
Example #4
def train_pet(
    ensemble_model_config: WrapperConfig,
    ensemble_train_config: TrainConfig,
    ensemble_eval_config: EvalConfig,
    final_model_config: WrapperConfig,
    final_train_config: TrainConfig,
    final_eval_config: EvalConfig,
    pattern_ids: List[int],
    output_dir: str,
    ensemble_repetitions: int = 3,
    final_repetitions: int = 1,
    reduction: str = "wmean",
    train_data: List[InputExample] = None,
    unlabeled_data: List[InputExample] = None,
    dev_data: List[InputExample] = None,
    test_data: List[InputExample] = None,
    do_train: bool = True,
    do_eval: bool = True,
    no_distillation: bool = False,
    seed: int = 42,
    overwrite_dir: bool = False,
    save_model=False,
    local_rank=-1,
):
    """
    Train and evaluate a new PET model for a given task.

    :param ensemble_model_config: the model configuration for each model corresponding to an individual PVP
    :param ensemble_train_config: the training configuration for each model corresponding to an individual PVP
    :param ensemble_eval_config: the evaluation configuration for each model corresponding to an individual PVP
    :param final_model_config: the model configuration for the final distilled sequence classifier
    :param final_train_config: the training configuration for the final distilled sequence classifier
    :param final_eval_config: the evaluation configuration for the final distilled sequence classifier
    :param pattern_ids: the ids of all PVPs to use
    :param output_dir: the output directory
    :param ensemble_repetitions: the number of training repetitions for each model corresponding to an individual PVP
    :param final_repetitions: the number of training repetitions for the final distilled sequence classifier
    :param reduction: the reduction strategy for merging predictions, either 'mean' or 'wmean'
    :param train_data: the training examples to use
    :param unlabeled_data: the unlabeled examples to use
    :param dev_data: the development set examples to use for evaluation
    :param test_data: the test set examples to use for evaluation
    :param do_train: whether to perform training
    :param do_eval: whether to perform evaluation
    :param no_distillation: if true, no distillation is performed
    :param seed: the random seed to use
    """

    # Step 1: Train an ensemble of models corresponding to individual patterns
    final_results = train_pet_ensemble(
        ensemble_model_config,
        ensemble_train_config,
        ensemble_eval_config,
        pattern_ids,
        output_dir,
        repetitions=ensemble_repetitions,
        train_data=train_data,
        unlabeled_data=unlabeled_data,
        dev_data=dev_data,
        test_data=test_data,
        do_train=do_train,
        do_eval=do_eval,
        save_unlabeled_logits=not no_distillation,
        seed=seed,
        overwrite_dir=overwrite_dir,
        save_model=save_model,
        local_rank=local_rank,
    )

    if no_distillation:
        return final_results

    # Step 2: Merge the annotations created by each individual model
    logits_file = os.path.join(output_dir, "unlabeled_logits.txt")
    merge_logits(output_dir, logits_file, reduction)
    logits = LogitsList.load(logits_file).logits
    assert len(logits) == len(unlabeled_data)
    logger.info("Got {} logits from file {}".format(len(logits), logits_file))
    for example, example_logits in zip(unlabeled_data, logits):
        example.logits = example_logits

    # Step 3: Train the final sequence classifier model
    final_model_config.wrapper_type = SEQUENCE_CLASSIFIER_WRAPPER
    final_train_config.use_logits = True

    return train_classifier(
        final_model_config,
        final_train_config,
        final_eval_config,
        os.path.join(output_dir, "final"),
        repetitions=final_repetitions,
        train_data=train_data,
        unlabeled_data=unlabeled_data,
        dev_data=dev_data,
        test_data=test_data,
        do_train=do_train,
        do_eval=do_eval,
        seed=seed,
        local_rank=local_rank,
    )
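
Putting Examples #1, #2 and #4 together, an assumed end-to-end wiring from parsed arguments to a full PET run with distillation could look as follows (data loading via load_examples is left out):

# Hypothetical sketch; the config loaders are those shown in Examples #1 and #2.
pet_model_cfg, pet_train_cfg, pet_eval_cfg = load_pet_configs(args)
sc_model_cfg, sc_train_cfg, sc_eval_cfg = load_sequence_classifier_configs(args)

train_pet(
    pet_model_cfg, pet_train_cfg, pet_eval_cfg,
    sc_model_cfg, sc_train_cfg, sc_eval_cfg,
    pattern_ids=[0, 1, 2],
    output_dir=args.output_dir,
    ensemble_repetitions=3,
    final_repetitions=1,
    reduction="wmean",
    train_data=train_data,
    unlabeled_data=unlabeled_data,
    dev_data=dev_data,
    do_train=True,
    do_eval=True,
    seed=args.seed,
)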
Example #5
def train_ipet(
    ensemble_model_config: WrapperConfig,
    ensemble_train_config: TrainConfig,
    ensemble_eval_config: EvalConfig,
    ipet_config: IPetConfig,
    final_model_config: WrapperConfig,
    final_train_config: TrainConfig,
    final_eval_config: EvalConfig,
    pattern_ids: List[int],
    output_dir: str,
    ensemble_repetitions: int = 3,
    final_repetitions: int = 1,
    reduction: str = "wmean",
    train_data: List[InputExample] = None,
    unlabeled_data: List[InputExample] = None,
    dev_data: List[InputExample] = None,
    test_data: List[InputExample] = None,
    do_train: bool = True,
    do_eval: bool = True,
    seed: int = 42,
    overwrite_dir: bool = False,
    save_model=False,
    local_rank=-1,
):
    """
    Train and evaluate a new iPET model for a given task.

    :param ensemble_model_config: the model configuration for each model corresponding to an individual PVP
    :param ensemble_train_config: the training configuration for each model corresponding to an individual PVP
    :param ensemble_eval_config: the evaluation configuration for each model corresponding to an individual PVP
    :param ipet_config: the iPET training configuration
    :param final_model_config: the model configuration for the final distilled sequence classifier
    :param final_train_config: the training configuration for the final distilled sequence classifier
    :param final_eval_config: the evaluation configuration for the final distilled sequence classifier
    :param pattern_ids: the ids of all PVPs to use
    :param output_dir: the output directory
    :param ensemble_repetitions: the number of training repetitions for each model corresponding to an individual PVP
    :param final_repetitions: the number of training repetitions for the final distilled sequence classifier
    :param reduction: the reduction strategy for merging predictions, either 'mean' or 'wmean'
    :param train_data: the training examples to use
    :param unlabeled_data: the unlabeled examples to use
    :param dev_data: the development set examples to use for evaluation
    :param test_data: the test set examples to use for evaluation
    :param do_train: whether to perform training
    :param do_eval: whether to perform evaluation
    :param seed: the random seed to use
    """
    for gen in range(ipet_config.generations):
        gen_output_dir = os.path.join(output_dir, f"g{gen}")

        # Step 1: Train an ensemble of models corresponding to individual patterns
        ipet_data_dir = os.path.join(
            output_dir, f"g{gen - 1}",
            "next-gen-train-data") if gen > 0 else None
        train_pet_ensemble(
            ensemble_model_config,
            ensemble_train_config,
            ensemble_eval_config,
            pattern_ids,
            gen_output_dir,
            ipet_data_dir=ipet_data_dir,
            repetitions=ensemble_repetitions,
            train_data=train_data,
            unlabeled_data=unlabeled_data,
            dev_data=dev_data,
            test_data=test_data,
            do_train=do_train,
            do_eval=do_eval,
            save_unlabeled_logits=True,
            overwrite_dir=overwrite_dir,
            save_model=save_model,
            local_rank=local_rank,
        )

        # Step 2: Use the model to annotate examples for the next generation
        original_data_size = len(
            train_data) if train_data else 10 / ipet_config.scale_factor
        num_new_examples = int(original_data_size *
                               (ipet_config.scale_factor**(gen + 1)) -
                               len(train_data))
        generate_ipet_train_sets(
            train_data=train_data,
            unlabeled_data=unlabeled_data,
            labels=ensemble_model_config.label_list,
            logits_dir=gen_output_dir,
            output_dir=os.path.join(gen_output_dir, "next-gen-train-data"),
            reduction=reduction,
            num_new_examples=num_new_examples,
            logits_percentage=ipet_config.logits_percentage,
            n_most_likely=ipet_config.n_most_likely if gen == 0 else -1,
            seed=seed,
            local_rank=local_rank,
        )

    # Step 3: Merge the annotations created by each individual model
    logits_dir = os.path.join(output_dir, f"g{ipet_config.generations - 1}")
    logits_file = os.path.join(logits_dir, "unlabeled_logits.txt")
    if local_rank in [-1, 0]:
        merge_logits(logits_dir, logits_file, reduction)
    if local_rank != -1:
        # in distributed runs, wait for the main process to finish merging the logits
        torch.distributed.barrier()
    logits = LogitsList.load(logits_file).logits
    assert len(logits) == len(unlabeled_data)
    logger.info("Got {} logits from file {}".format(len(logits), logits_file))
    for example, example_logits in zip(unlabeled_data, logits):
        example.logits = example_logits

    # Step 4: Train the final sequence classifier model
    final_model_config.wrapper_type = SEQUENCE_CLASSIFIER_WRAPPER
    final_train_config.use_logits = True

    final_results = train_classifier(
        final_model_config,
        final_train_config,
        final_eval_config,
        os.path.join(output_dir, "final"),
        repetitions=final_repetitions,
        train_data=train_data,
        unlabeled_data=unlabeled_data,
        dev_data=dev_data,
        test_data=test_data,
        do_train=do_train,
        do_eval=do_eval,
        local_rank=local_rank,
    )

    return final_results
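
The generation sizing in Step 2 grows the iPET training sets geometrically. A worked example of that arithmetic (illustrative numbers only):

# With 10 labeled examples and scale_factor = 5, each generation labels enough additional
# unlabeled examples to multiply the per-model training set by another factor of 5.
train_size, scale_factor = 10, 5
for gen in range(3):
    num_new_examples = int(train_size * scale_factor ** (gen + 1) - train_size)
    print(gen, num_new_examples)  # gen 0: 40, gen 1: 240, gen 2: 1240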
Example #6
File: modeling.py Project: dwright37/pet
def train_pet_ensemble(model_config: WrapperConfig,
                       train_config: TrainConfig,
                       eval_config: EvalConfig,
                       pattern_ids: List[int],
                       output_dir: str,
                       ipet_data_dir: str = None,
                       repetitions: int = 3,
                       train_data: List[InputExample] = None,
                       unlabeled_data: List[InputExample] = None,
                       eval_data: List[InputExample] = None,
                       do_train: bool = True,
                       do_eval: bool = True,
                       save_unlabeled_logits: bool = False,
                       seed: int = 42):
    """
    Train and evaluate an ensemble of PET models without knowledge distillation.

    :param model_config: the model configuration to use
    :param train_config: the training configuration to use
    :param eval_config: the evaluation configuration to use
    :param pattern_ids: the ids of all PVPs to use
    :param output_dir: the output directory
    :param ipet_data_dir: optional directory containing additional training data for iPET
    :param repetitions: the number of training repetitions
    :param train_data: the training examples to use
    :param unlabeled_data: the unlabeled examples to use
    :param eval_data: the evaluation examples to use
    :param do_train: whether to perform training
    :param do_eval: whether to perform evaluation
    :param save_unlabeled_logits: whether logits for unlabeled examples should be saved in a file ``logits.txt``. This
           is required for both iPET and knowledge distillation.
    :param seed: the random seed to use
    """

    results = defaultdict(lambda: defaultdict(list))
    set_seed(seed)

    for pattern_id in pattern_ids:
        for iteration in range(repetitions):

            model_config.pattern_id = pattern_id
            results_dict = {}

            pattern_iter_output_dir = "{}/p{}-i{}".format(
                output_dir, pattern_id, iteration)

            if os.path.exists(pattern_iter_output_dir):
                logger.warning(
                    f"Path {pattern_iter_output_dir} already exists, skipping it..."
                )
                continue

            if not os.path.exists(pattern_iter_output_dir):
                os.makedirs(pattern_iter_output_dir)

            wrapper = init_model(model_config)

            # Training
            if do_train:
                if ipet_data_dir:
                    p = os.path.join(
                        ipet_data_dir,
                        'p{}-i{}-train.bin'.format(pattern_id, iteration))
                    ipet_train_data = InputExample.load_examples(p)
                    for example in ipet_train_data:
                        example.logits = None
                else:
                    ipet_train_data = None

                results_dict.update(
                    train_single_model(wrapper,
                                       train_data,
                                       train_config,
                                       eval_config,
                                       ipet_train_data=ipet_train_data,
                                       unlabeled_data=unlabeled_data))

                with open(os.path.join(pattern_iter_output_dir, 'results.txt'),
                          'w') as fh:
                    fh.write(str(results_dict))

                logger.info("Saving trained model at {}...".format(
                    pattern_iter_output_dir))
                wrapper.save(pattern_iter_output_dir)
                train_config.save(
                    os.path.join(pattern_iter_output_dir, 'train_config.json'))
                eval_config.save(
                    os.path.join(pattern_iter_output_dir, 'eval_config.json'))
                logger.info("Saving complete")

                if save_unlabeled_logits:
                    logits = evaluate(wrapper, unlabeled_data,
                                      eval_config)['logits']
                    save_logits(
                        os.path.join(pattern_iter_output_dir, 'logits.txt'),
                        logits)

                if not do_eval:
                    wrapper.model = None
                    wrapper = None
                    torch.cuda.empty_cache()

            # Evaluation
            if do_eval:
                logger.info("Starting evaluation...")
                if not wrapper:
                    wrapper = TransformerModelWrapper.from_pretrained(
                        pattern_iter_output_dir)

                eval_result = evaluate(wrapper,
                                       eval_data,
                                       eval_config,
                                       priming_data=train_data)

                save_predictions(
                    os.path.join(pattern_iter_output_dir, 'predictions.jsonl'),
                    wrapper, eval_result)
                save_logits(
                    os.path.join(pattern_iter_output_dir, 'eval_logits.txt'),
                    eval_result['logits'])

                scores = eval_result['scores']
                logger.info(
                    "--- RESULT (pattern_id={}, iteration={}) ---".format(
                        pattern_id, iteration))
                logger.info(scores)

                results_dict['test_set_after_training'] = scores
                with open(
                        os.path.join(pattern_iter_output_dir, 'results.json'),
                        'w') as fh:
                    json.dump(results_dict, fh)

                for metric, value in scores.items():
                    results[metric][pattern_id].append(value)

                wrapper.model = None
                wrapper = None
                torch.cuda.empty_cache()

    if do_eval:
        logger.info("=== OVERALL RESULTS ===")
        _write_results(os.path.join(output_dir, 'result_test.txt'), results)
    else:
        logger.info("=== ENSEMBLE TRAINING COMPLETE ===")
Example #7
File: petal.py Project: puraminy/pet
def main():
    parser = argparse.ArgumentParser()

    # required parameters
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory. The verbalizers are written to a file 'verbalizer.json' in this directory.",
    )
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the data files for the task.",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="The model type",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " +
        ", ".join(PROCESSORS.keys()),
    )

    # verbalizer search hyperparameters
    parser.add_argument(
        "--normalize",
        action="store_true",
        help=
        "Whether to normalize the loss as proposed in the paper. It is recommended to set this to 'true'.",
    )
    parser.add_argument(
        "--combine_patterns",
        action="store_true",
        help=
        "If set to true, a single joint verbalizer is searched for all patterns",
    )
    parser.add_argument(
        "--num_candidates",
        default=1000,
        type=int,
        help=
        "The number of candidate tokens to consider as verbalizers (see Section 4.1 of the paper)",
    )
    parser.add_argument(
        "--words_per_label",
        default=10,
        type=int,
        help="The number of verbalizer tokens to assign to each label",
    )
    parser.add_argument(
        "--score_fct",
        default="llr",
        choices=["llr", "ce", "random"],
        help=
        "The function used to score verbalizers. Choices are: the log-likelihood ratio loss proposed in the paper "
        "('llr'), cross-entropy loss ('ce') and 'random', which assigns random tokens to each label.",
    )

    # other optional parameters
    parser.add_argument(
        "--train_examples",
        default=50,
        type=int,
        help=
        "The total number of train examples to use, where -1 equals all examples.",
    )
    parser.add_argument(
        "--pattern_ids",
        default=[0],
        type=int,
        nargs="+",
        help="The ids of the PVPs to be used",
    )
    parser.add_argument(
        "--max_seq_length",
        default=256,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument(
        "--words_file",
        default=None,
        type=str,
        help=
        "Path to a file containing (unlabeled) texts from the task's domain. This text is used to compute "
        "verbalization candidates by selecting the most frequent words.",
    )
    parser.add_argument(
        "--max_words",
        default=10000,
        type=int,
        help=
        "Only the 10,000 tokens that occur most frequently in the task’s unlabeled data (see --words_file) are "
        "considered as verbalization candidates",
    )
    parser.add_argument(
        "--additional_input_examples",
        type=str,
        help=
        "An optional path to an additional set of input examples (e.g., obtained using iPET)",
    )
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        help="random seed for initialization")

    args = parser.parse_args()
    random.seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    with open(os.path.join(args.output_dir, "config.txt"),
              "w",
              encoding="utf8") as fh:
        json.dump(args.__dict__, fh, indent=2)

    # setup gpu/cpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    # prepare task
    args.task_name = args.task_name.lower()
    if args.task_name not in PROCESSORS:
        raise ValueError("Task not found: {}".format(args.task_name))
    processor = PROCESSORS[args.task_name]()
    args.label_list = processor.get_labels()
    args.cache_dir = ""
    args.do_lower_case = False
    args.verbalizer_file = None
    args.wrapper_type = "mlm"

    # get training data
    train_examples_per_label = (eq_div(args.train_examples, len(
        args.label_list)) if args.train_examples != -1 else -1)
    train_data = load_examples(
        args.task_name,
        args.data_dir,
        set_type=TRAIN_SET,
        num_examples_per_label=train_examples_per_label,
    )
    if args.additional_input_examples:
        additional_data = InputExample.load_examples(
            args.additional_input_examples)
        train_data += additional_data
        logger.info(
            f"Loaded {len(additional_data)} additional examples from {args.additional_input_examples}, total "
            f"training set size is now {len(train_data)}")

    expected = {
        label: np.array([1 if x.label == label else 0 for x in train_data])
        for label in args.label_list
    }

    if args.words_file:
        with open(args.words_file, "r", encoding="utf8") as fh:
            word_counts = Counter(fh.read().split())
    else:
        word_counts = None

    tokenizer_class = MODEL_CLASSES[args.model_type]["tokenizer"]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    word2idx = get_word_to_id_map(tokenizer,
                                  word_counts=word_counts,
                                  max_words=args.max_words)

    logits = []

    for pattern_id in args.pattern_ids:
        logger.info(f"Processing examples with pattern id {pattern_id}...")
        args.pattern_id = pattern_id

        config = WrapperConfig(
            model_type=args.model_type,
            model_name_or_path=args.model_name_or_path,
            wrapper_type="mlm",
            task_name=args.task_name,
            max_seq_length=args.max_seq_length,
            label_list=args.label_list,
            pattern_id=args.pattern_id,
        )

        wrapper = TransformerModelWrapper(config)
        wrapper.model.to(device)
        # modify all patterns so that they return a single text segment instead of two segments
        get_parts = wrapper.preprocessor.pvp.get_parts
        wrapper.preprocessor.pvp.get_parts = lambda example: (
            get_parts(example)[0] + get_parts(example)[1],
            [],
        )
        wrapper.preprocessor.pvp.convert_mlm_logits_to_cls_logits = lambda mask, x, _=None: x[
            mask >= 0]

        pattern_logits = wrapper.eval(
            train_data,
            device,
            per_gpu_eval_batch_size=args.per_gpu_eval_batch_size,
            n_gpu=args.n_gpu,
        )["logits"]
        pattern_logits = pattern_logits - np.expand_dims(
            np.max(pattern_logits, axis=1), axis=1)
        logits.append(pattern_logits)

    logger.info("Starting verbalizer search...")

    if args.combine_patterns:
        avs = AutomaticVerbalizerSearch(word2idx, args.label_list, logits,
                                        expected)
        verbalizer = avs.find_verbalizer(
            num_candidates=args.num_candidates,
            words_per_label=args.words_per_label,
            normalize=args.normalize,
            score_fct=args.score_fct,
        )
        verbalizers = {
            pattern_id: verbalizer
            for pattern_id in args.pattern_ids
        }

    else:
        verbalizers = {}
        for idx, pattern_id in enumerate(args.pattern_ids):
            avs = AutomaticVerbalizerSearch(word2idx, args.label_list,
                                            [logits[idx]], expected)
            verbalizers[pattern_id] = avs.find_verbalizer(
                num_candidates=args.num_candidates,
                words_per_label=args.words_per_label,
                normalize=args.normalize,
                score_fct=args.score_fct,
            )

    print(json.dumps(verbalizers, indent=2))
    logger.info("Verbalizer search complete, writing output...")

    with open(os.path.join(args.output_dir, "verbalizers.json"),
              "w",
              encoding="utf8") as fh:
        json.dump(verbalizers, fh, indent=2)

    logger.info("Done")
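
To make the expected dictionary built in main() concrete, here is a tiny worked example with hypothetical labels: each label is mapped to an indicator vector over the training examples, and that dictionary is handed to AutomaticVerbalizerSearch together with the per-pattern logits.

# Illustrative sketch of the expected structure, for a 4-example training set with labels "0"/"1".
import numpy as np

train_labels = ["1", "0", "1", "1"]  # stands in for [x.label for x in train_data]
label_list = ["0", "1"]
expected = {label: np.array([1 if y == label else 0 for y in train_labels])
            for label in label_list}
# expected == {"0": array([0, 1, 0, 0]), "1": array([1, 0, 1, 1])}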