Example #1

# Imports assumed from the upstream `pet` package layout (adjust if your fork
# differs); TrainConfig, EvalConfig, train_pet_ensemble, merge_logits and
# train_classifier are defined in this same module:
import os
from typing import List

import torch

import log
from pet.utils import InputExample, LogitsList
from pet.wrapper import WrapperConfig, SEQUENCE_CLASSIFIER_WRAPPER

logger = log.get_logger("root")


def train_pet(
    ensemble_model_config: WrapperConfig,
    ensemble_train_config: TrainConfig,
    ensemble_eval_config: EvalConfig,
    final_model_config: WrapperConfig,
    final_train_config: TrainConfig,
    final_eval_config: EvalConfig,
    pattern_ids: List[int],
    output_dir: str,
    ensemble_repetitions: int = 3,
    final_repetitions: int = 1,
    reduction: str = "wmean",
    train_data: List[InputExample] = None,
    unlabeled_data: List[InputExample] = None,
    dev_data: List[InputExample] = None,
    test_data: List[InputExample] = None,
    do_train: bool = True,
    do_eval: bool = True,
    no_distillation: bool = False,
    seed: int = 42,
    overwrite_dir: bool = False,
    save_model: bool = False,
    local_rank: int = -1,
):
    """
    Train and evaluate a new PET model for a given task.

    :param ensemble_model_config: the model configuration for each model corresponding to an individual PVP
    :param ensemble_train_config: the training configuration for each model corresponding to an individual PVP
    :param ensemble_eval_config: the evaluation configuration for each model corresponding to an individual PVP
    :param final_model_config: the model configuration for the final distilled sequence classifier
    :param final_train_config: the training configuration for the final distilled sequence classifier
    :param final_eval_config: the evaluation configuration for the final distilled sequence classifier
    :param pattern_ids: the ids of all PVPs to use
    :param output_dir: the output directory
    :param ensemble_repetitions: the number of training repetitions for each model corresponding to an individual PVP
    :param final_repetitions: the number of training repetitions for the final distilled sequence classifier
    :param reduction: the reduction strategy for merging predictions, either 'mean' or 'wmean'
    :param train_data: the training examples to use
    :param unlabeled_data: the unlabeled examples to use
    :param dev_data: the evaluation examples to use
    :param test_data: the test examples to use
    :param do_train: whether to perform training
    :param do_eval: whether to perform evaluation
    :param no_distillation: if true, no distillation is performed
    :param seed: the random seed to use
    :param overwrite_dir: whether to overwrite the output directory if it already exists
    :param save_model: whether to save each trained model to disk
    :param local_rank: the rank of this process for distributed training (-1 for non-distributed)
    """

    # Step 1: Train an ensemble of models corresponding to individual patterns
    final_results = train_pet_ensemble(
        ensemble_model_config,
        ensemble_train_config,
        ensemble_eval_config,
        pattern_ids,
        output_dir,
        repetitions=ensemble_repetitions,
        train_data=train_data,
        unlabeled_data=unlabeled_data,
        dev_data=dev_data,
        test_data=test_data,
        do_train=do_train,
        do_eval=do_eval,
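        # Unlabeled-set logits are only needed as soft targets for the
        # distillation step, so skip saving them when distillation is off.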
        save_unlabeled_logits=not no_distillation,
        seed=seed,
        overwrite_dir=overwrite_dir,
        save_model=save_model,
        local_rank=local_rank,
    )

    if no_distillation:
        return final_results

    # Step 2: Merge the annotations created by each individual model
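    # ('wmean' weights each pattern's logits by that model's score on the
    # training set, while 'mean' averages all patterns uniformly.)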
    logits_file = os.path.join(output_dir, "unlabeled_logits.txt")
    if local_rank in [-1, 0]:
        merge_logits(output_dir, logits_file, reduction)
    if local_rank != -1:
        # In distributed runs, let rank 0 write the merged logits and make the
        # other ranks wait before reading the file.
        torch.distributed.barrier()
    logits = LogitsList.load(logits_file).logits
    assert len(logits) == len(unlabeled_data)
    logger.info("Got {} logits from file {}".format(len(logits), logits_file))
    for example, example_logits in zip(unlabeled_data, logits):
        example.logits = example_logits

    # Step 3: Train the final sequence classifier model
    final_model_config.wrapper_type = SEQUENCE_CLASSIFIER_WRAPPER
    final_train_config.use_logits = True
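    # use_logits makes the final classifier fit the merged soft logits
    # (knowledge distillation) instead of hard gold labels.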

    return train_classifier(
        final_model_config,
        final_train_config,
        final_eval_config,
        os.path.join(output_dir, "final"),
        repetitions=final_repetitions,
        train_data=train_data,
        unlabeled_data=unlabeled_data,
        dev_data=dev_data,
        test_data=test_data,
        do_train=do_train,
        do_eval=do_eval,
        seed=seed,
        local_rank=local_rank,
    )
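
To make the Step 2 reduction concrete, here is a small self-contained sketch of how 'mean' and 'wmean' can combine per-pattern logits. It is an illustration, not the repo's merge_logits implementation; the scores stand in for each model's training-set performance:

import numpy as np

def merge(per_model_logits, scores, reduction="wmean"):
    # per_model_logits: one [num_examples, num_labels] array per pattern model;
    # scores: one weight per model (e.g. its training-set accuracy).
    stacked = np.stack(per_model_logits)  # [num_models, num_examples, num_labels]
    if reduction == "mean":
        return stacked.mean(axis=0)  # uniform average over models
    if reduction == "wmean":
        return np.average(stacked, axis=0, weights=scores)  # score-weighted
    raise ValueError(f"unknown reduction: {reduction}")

# Two toy pattern models, three examples, two labels:
m1 = np.array([[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
m2 = np.array([[0.0, 2.0], [0.0, 3.0], [1.0, 0.0]])
print(merge([m1, m2], scores=[0.9, 0.1]))  # dominated by the stronger model
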
Example #2

# Same assumed module-level imports as Example #1 (IPetConfig and
# generate_ipet_train_sets are likewise defined in this module):
import os
from typing import List

import torch

import log
from pet.utils import InputExample, LogitsList
from pet.wrapper import WrapperConfig, SEQUENCE_CLASSIFIER_WRAPPER

logger = log.get_logger("root")


def train_ipet(
    ensemble_model_config: WrapperConfig,
    ensemble_train_config: TrainConfig,
    ensemble_eval_config: EvalConfig,
    ipet_config: IPetConfig,
    final_model_config: WrapperConfig,
    final_train_config: TrainConfig,
    final_eval_config: EvalConfig,
    pattern_ids: List[int],
    output_dir: str,
    ensemble_repetitions: int = 3,
    final_repetitions: int = 1,
    reduction: str = "wmean",
    train_data: List[InputExample] = None,
    unlabeled_data: List[InputExample] = None,
    dev_data: List[InputExample] = None,
    test_data: List[InputExample] = None,
    do_train: bool = True,
    do_eval: bool = True,
    seed: int = 42,
    overwrite_dir: bool = False,
    save_model: bool = False,
    local_rank: int = -1,
):
    """
    Train and evaluate a new iPET model for a given task.

    :param ensemble_model_config: the model configuration for each model corresponding to an individual PVP
    :param ensemble_train_config: the training configuration for each model corresponding to an individual PVP
    :param ensemble_eval_config: the evaluation configuration for each model corresponding to an individual PVP
    :param ipet_config: the iPET training configuration
    :param final_model_config: the model configuration for the final distilled sequence classifier
    :param final_train_config: the training configuration for the final distilled sequence classifier
    :param final_eval_config: the evaluation configuration for the final distilled sequence classifier
    :param pattern_ids: the ids of all PVPs to use
    :param output_dir: the output directory
    :param ensemble_repetitions: the number of training repetitions for each model corresponding to an individual PVP
    :param final_repetitions: the number of training repetitions for the final distilled sequence classifier
    :param reduction: the reduction strategy for merging predictions, either 'mean' or 'wmean'
    :param train_data: the training examples to use
    :param unlabeled_data: the unlabeled examples to use
    :param dev_data: the evaluation examples to use
    :param test_data: the test examples to use
    :param do_train: whether to perform training
    :param do_eval: whether to perform evaluation
    :param seed: the random seed to use
    :param overwrite_dir: whether to overwrite the output directory if it already exists
    :param save_model: whether to save each trained model to disk
    :param local_rank: the rank of this process for distributed training (-1 for non-distributed)
    """
    for gen in range(ipet_config.generations):
        gen_output_dir = os.path.join(output_dir, f"g{gen}")

        # Step 1: Train an ensemble of models corresponding to individual patterns
        ipet_data_dir = (
            os.path.join(output_dir, f"g{gen - 1}", "next-gen-train-data")
            if gen > 0 else None
        )
        train_pet_ensemble(
            ensemble_model_config,
            ensemble_train_config,
            ensemble_eval_config,
            pattern_ids,
            gen_output_dir,
            ipet_data_dir=ipet_data_dir,
            repetitions=ensemble_repetitions,
            train_data=train_data,
            unlabeled_data=unlabeled_data,
            dev_data=dev_data,
            test_data=test_data,
            do_train=do_train,
            do_eval=do_eval,
            save_unlabeled_logits=True,
            overwrite_dir=overwrite_dir,
            save_model=save_model,
            local_rank=local_rank,
        )

        # Step 2: Use the model to annotate examples for the next generation
        original_data_size = (
            len(train_data) if train_data else 10 / ipet_config.scale_factor
        )
        num_new_examples = int(
            original_data_size * (ipet_config.scale_factor ** (gen + 1))
            - (len(train_data) if train_data else 0)
        )
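        # Worked example with illustrative numbers: with 10 labelled examples
        # and scale_factor == 5, generation 0 pseudo-labels int(10 * 5 - 10) =
        # 40 new examples and generation 1 int(10 * 25 - 10) = 240, so each
        # generation trains on roughly scale_factor times more data.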
        generate_ipet_train_sets(
            train_data=train_data,
            unlabeled_data=unlabeled_data,
            labels=ensemble_model_config.label_list,
            logits_dir=gen_output_dir,
            output_dir=os.path.join(gen_output_dir, "next-gen-train-data"),
            reduction=reduction,
            num_new_examples=num_new_examples,
            logits_percentage=ipet_config.logits_percentage,
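            # Only generation 0 restricts itself to the n most confident
            # predictions per label; later generations pass -1 to disable this.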
            n_most_likely=ipet_config.n_most_likely if gen == 0 else -1,
            seed=seed,
            local_rank=local_rank,
        )

    # Step 3: Merge the annotations created by each individual model
    logits_dir = os.path.join(output_dir, f"g{ipet_config.generations - 1}")
    logits_file = os.path.join(logits_dir, "unlabeled_logits.txt")
    if local_rank in [-1, 0]:
        merge_logits(logits_dir, logits_file, reduction)
    if local_rank != -1:
        # Only synchronize in distributed runs; calling barrier() without an
        # initialized process group would fail in single-process training.
        torch.distributed.barrier()
    logits = LogitsList.load(logits_file).logits
    assert len(logits) == len(unlabeled_data)
    logger.info("Got {} logits from file {}".format(len(logits), logits_file))
    for example, example_logits in zip(unlabeled_data, logits):
        example.logits = example_logits

    # Step 4: Train the final sequence classifier model
    final_model_config.wrapper_type = SEQUENCE_CLASSIFIER_WRAPPER
    final_train_config.use_logits = True

    final_results = train_classifier(
        final_model_config,
        final_train_config,
        final_eval_config,
        os.path.join(output_dir, "final"),
        repetitions=final_repetitions,
        train_data=train_data,
        unlabeled_data=unlabeled_data,
        dev_data=dev_data,
        test_data=test_data,
        do_train=do_train,
        do_eval=do_eval,
        local_rank=local_rank,
    )

    return final_results
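
For orientation, the directory layout produced by train_ipet (read directly off the os.path.join calls above) looks like:

output_dir/
    g0/                        # generation-0 ensemble output
        next-gen-train-data/   # examples pseudo-labelled for generation 1
    g1/
        next-gen-train-data/
    ...
    g{generations-1}/
        unlabeled_logits.txt   # merged logits used for distillation (Step 3)
    final/                     # distilled sequence classifier (Step 4)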