def train_pet(
    ensemble_model_config: WrapperConfig,
    ensemble_train_config: TrainConfig,
    ensemble_eval_config: EvalConfig,
    final_model_config: WrapperConfig,
    final_train_config: TrainConfig,
    final_eval_config: EvalConfig,
    pattern_ids: List[int],
    output_dir: str,
    ensemble_repetitions: int = 3,
    final_repetitions: int = 1,
    reduction: str = "wmean",
    train_data: List[InputExample] = None,
    unlabeled_data: List[InputExample] = None,
    dev_data: List[InputExample] = None,
    test_data: List[InputExample] = None,
    do_train: bool = True,
    do_eval: bool = True,
    no_distillation: bool = False,
    seed: int = 42,
    overwrite_dir: bool = False,
    save_model=False,
    local_rank=-1,
):
    """
    Train and evaluate a new PET model for a given task.

    :param ensemble_model_config: the model configuration for each model corresponding to an individual PVP
    :param ensemble_train_config: the training configuration for each model corresponding to an individual PVP
    :param ensemble_eval_config: the evaluation configuration for each model corresponding to an individual PVP
    :param final_model_config: the model configuration for the final distilled sequence classifier
    :param final_train_config: the training configuration for the final distilled sequence classifier
    :param final_eval_config: the evaluation configuration for the final distilled sequence classifier
    :param pattern_ids: the ids of all PVPs to use
    :param output_dir: the output directory
    :param ensemble_repetitions: the number of training repetitions for each model corresponding to an individual PVP
    :param final_repetitions: the number of training repetitions for the final distilled sequence classifier
    :param reduction: the reduction strategy for merging predictions, either 'mean' or 'wmean'
    :param train_data: the training examples to use
    :param unlabeled_data: the unlabeled examples to use
    :param dev_data: the evaluation examples to use
    :param test_data: the test examples to use
    :param do_train: whether to perform training
    :param do_eval: whether to perform evaluation
    :param no_distillation: if true, no distillation is performed
    :param seed: the random seed to use
    :param overwrite_dir: whether to overwrite the output directory if it already exists
    :param save_model: whether to save the trained models
    :param local_rank: the local rank for distributed training (-1 for non-distributed training)
    """

    # Step 1: Train an ensemble of models corresponding to individual patterns
    final_results = train_pet_ensemble(
        ensemble_model_config,
        ensemble_train_config,
        ensemble_eval_config,
        pattern_ids,
        output_dir,
        repetitions=ensemble_repetitions,
        train_data=train_data,
        unlabeled_data=unlabeled_data,
        dev_data=dev_data,
        test_data=test_data,
        do_train=do_train,
        do_eval=do_eval,
        save_unlabeled_logits=not no_distillation,
        seed=seed,
        overwrite_dir=overwrite_dir,
        save_model=save_model,
        local_rank=local_rank,
    )

    if no_distillation:
        return final_results

    # Step 2: Merge the annotations created by each individual model
    logits_file = os.path.join(output_dir, "unlabeled_logits.txt")
    merge_logits(output_dir, logits_file, reduction)
    logits = LogitsList.load(logits_file).logits
    assert len(logits) == len(unlabeled_data)
    logger.info("Got {} logits from file {}".format(len(logits), logits_file))

    for example, example_logits in zip(unlabeled_data, logits):
        example.logits = example_logits

    # Step 3: Train the final sequence classifier model
    final_model_config.wrapper_type = SEQUENCE_CLASSIFIER_WRAPPER
    final_train_config.use_logits = True

    return train_classifier(
        final_model_config,
        final_train_config,
        final_eval_config,
        os.path.join(output_dir, "final"),
        repetitions=final_repetitions,
        train_data=train_data,
        unlabeled_data=unlabeled_data,
        dev_data=dev_data,
        test_data=test_data,
        do_train=do_train,
        do_eval=do_eval,
        seed=seed,
        local_rank=local_rank,
    )
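# --- Usage sketch (illustrative only, not part of this module) ---
# A minimal, commented-out example of wiring up train_pet under the assumption
# that the three config triples and the InputExample lists are built elsewhere.
# `build_pet_configs` and `load_examples_for_task` are hypothetical helper names
# used purely for illustration; only the train_pet signature above is taken
# from this module.
#
# model_cfg, train_cfg, eval_cfg = build_pet_configs(...)                      # per-pattern ensemble configs
# final_model_cfg, final_train_cfg, final_eval_cfg = build_pet_configs(...)    # distillation configs
# train_examples, unlabeled_examples, dev_examples, test_examples = load_examples_for_task(...)
#
# results = train_pet(
#     model_cfg, train_cfg, eval_cfg,
#     final_model_cfg, final_train_cfg, final_eval_cfg,
#     pattern_ids=[0, 1, 2, 3],
#     output_dir="output/pet",
#     ensemble_repetitions=3,
#     final_repetitions=1,
#     reduction="wmean",
#     train_data=train_examples,
#     unlabeled_data=unlabeled_examples,
#     dev_data=dev_examples,
#     test_data=test_examples,
#     do_train=True,
#     do_eval=True,
# )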
def train_ipet(
    ensemble_model_config: WrapperConfig,
    ensemble_train_config: TrainConfig,
    ensemble_eval_config: EvalConfig,
    ipet_config: IPetConfig,
    final_model_config: WrapperConfig,
    final_train_config: TrainConfig,
    final_eval_config: EvalConfig,
    pattern_ids: List[int],
    output_dir: str,
    ensemble_repetitions: int = 3,
    final_repetitions: int = 1,
    reduction: str = "wmean",
    train_data: List[InputExample] = None,
    unlabeled_data: List[InputExample] = None,
    dev_data: List[InputExample] = None,
    test_data: List[InputExample] = None,
    do_train: bool = True,
    do_eval: bool = True,
    seed: int = 42,
    overwrite_dir: bool = False,
    save_model=False,
    local_rank=-1,
):
    """
    Train and evaluate a new iPET model for a given task.

    :param ensemble_model_config: the model configuration for each model corresponding to an individual PVP
    :param ensemble_train_config: the training configuration for each model corresponding to an individual PVP
    :param ensemble_eval_config: the evaluation configuration for each model corresponding to an individual PVP
    :param ipet_config: the iPET training configuration
    :param final_model_config: the model configuration for the final distilled sequence classifier
    :param final_train_config: the training configuration for the final distilled sequence classifier
    :param final_eval_config: the evaluation configuration for the final distilled sequence classifier
    :param pattern_ids: the ids of all PVPs to use
    :param output_dir: the output directory
    :param ensemble_repetitions: the number of training repetitions for each model corresponding to an individual PVP
    :param final_repetitions: the number of training repetitions for the final distilled sequence classifier
    :param reduction: the reduction strategy for merging predictions, either 'mean' or 'wmean'
    :param train_data: the training examples to use
    :param unlabeled_data: the unlabeled examples to use
    :param dev_data: the evaluation examples to use
    :param test_data: the test examples to use
    :param do_train: whether to perform training
    :param do_eval: whether to perform evaluation
    :param seed: the random seed to use
    :param overwrite_dir: whether to overwrite the output directory if it already exists
    :param save_model: whether to save the trained models
    :param local_rank: the local rank for distributed training (-1 for non-distributed training)
    """

    for gen in range(ipet_config.generations):
        gen_output_dir = os.path.join(output_dir, f"g{gen}")

        # Step 1: Train an ensemble of models corresponding to individual patterns
        ipet_data_dir = os.path.join(output_dir, f"g{gen - 1}", "next-gen-train-data") if gen > 0 else None
        train_pet_ensemble(
            ensemble_model_config,
            ensemble_train_config,
            ensemble_eval_config,
            pattern_ids,
            gen_output_dir,
            ipet_data_dir=ipet_data_dir,
            repetitions=ensemble_repetitions,
            train_data=train_data,
            unlabeled_data=unlabeled_data,
            dev_data=dev_data,
            test_data=test_data,
            do_train=do_train,
            do_eval=do_eval,
            save_unlabeled_logits=True,
            overwrite_dir=overwrite_dir,
            save_model=save_model,
            local_rank=local_rank,
        )

        # Step 2: Use the model to annotate examples for the next generation
        original_data_size = len(train_data) if train_data else 10 / ipet_config.scale_factor
        num_new_examples = int(original_data_size * (ipet_config.scale_factor ** (gen + 1)) - len(train_data))
        generate_ipet_train_sets(
            train_data=train_data,
            unlabeled_data=unlabeled_data,
            labels=ensemble_model_config.label_list,
            logits_dir=gen_output_dir,
            output_dir=os.path.join(gen_output_dir, "next-gen-train-data"),
            reduction=reduction,
            num_new_examples=num_new_examples,
            logits_percentage=ipet_config.logits_percentage,
            n_most_likely=ipet_config.n_most_likely if gen == 0 else -1,
            seed=seed,
            local_rank=local_rank,
        )

    # Step 3: Merge the annotations created by each individual model
    logits_dir = os.path.join(output_dir, f"g{ipet_config.generations - 1}")
    logits_file = os.path.join(logits_dir, "unlabeled_logits.txt")
    if local_rank in [-1, 0]:
        merge_logits(logits_dir, logits_file, reduction)
    if local_rank != -1:
        # Only synchronize when a distributed process group exists; calling
        # barrier() in non-distributed mode would raise an error.
        torch.distributed.barrier()
    logits = LogitsList.load(logits_file).logits
    assert len(logits) == len(unlabeled_data)
    logger.info("Got {} logits from file {}".format(len(logits), logits_file))

    for example, example_logits in zip(unlabeled_data, logits):
        example.logits = example_logits

    # Step 4: Train the final sequence classifier model
    final_model_config.wrapper_type = SEQUENCE_CLASSIFIER_WRAPPER
    final_train_config.use_logits = True

    final_results = train_classifier(
        final_model_config,
        final_train_config,
        final_eval_config,
        os.path.join(output_dir, "final"),
        repetitions=final_repetitions,
        train_data=train_data,
        unlabeled_data=unlabeled_data,
        dev_data=dev_data,
        test_data=test_data,
        do_train=do_train,
        do_eval=do_eval,
        local_rank=local_rank,
    )
    return final_results
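# --- Illustration (not used by the training code above) ---
# The per-generation schedule in Step 2 of train_ipet grows the labelled set
# geometrically: generation `gen` requests enough newly annotated examples so
# that the next generation trains on roughly
# len(train_data) * scale_factor ** (gen + 1) examples in total.
# The helper below only reproduces that arithmetic for inspection; it mirrors
# the variable names above but is not part of this module's API.

def _ipet_examples_per_generation(train_size: int, scale_factor: float, generations: int) -> List[int]:
    """Return the number of newly annotated examples requested in each generation."""
    return [
        int(train_size * scale_factor ** (gen + 1) - train_size)
        for gen in range(generations)
    ]

# Example: 10 labelled examples, scale_factor=5, 3 generations
# -> [40, 240, 1240] new examples requested in generations 0, 1 and 2.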