def run_multiple_choice(self, model_name, task_name, fp16):
        model_args = ModelArguments(model_name_or_path=model_name,
                                    cache_dir=self.cache_dir)
        data_args = DataTrainingArguments(task_name=task_name,
                                          data_dir=self.data_dir,
                                          max_seq_length=self.max_seq_length)

        training_args = TrainingArguments(
            output_dir=os.path.join(self.output_dir, task_name),
            do_train=True,
            do_eval=True,
            per_gpu_train_batch_size=self.train_batch_size,
            per_gpu_eval_batch_size=self.eval_batch_size,
            learning_rate=self.learning_rate,
            num_train_epochs=self.num_train_epochs,
            local_rank=self.local_rank,
            overwrite_output_dir=self.overwrite_output_dir,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            fp16=fp16,
            logging_steps=self.logging_steps)
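        # NOTE: per_gpu_{train,eval}_batch_size come from older transformers
        # releases; newer versions renamed them to per_device_*.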

        # Setup logging
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO
            if training_args.local_rank in [-1, 0] else logging.WARN,
        )
        logger.warning(
            "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
            training_args.local_rank,
            training_args.device,
            training_args.n_gpu,
            bool(training_args.local_rank != -1),
            training_args.fp16,
        )
        logger.info("Training/evaluation parameters %s", training_args)

        set_seed(training_args.seed)
        onnxruntime.set_seed(training_args.seed)
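        # Seed both the PyTorch side and the onnxruntime training backend so
        # the ORT run is reproducible.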

        # The processor is hard-coded to SWAG, so the task lookup cannot fail
        # and no KeyError handling is needed here.
        processor = SwagProcessor()
        label_list = processor.get_labels()
        num_labels = len(label_list)

        config = AutoConfig.from_pretrained(
            model_args.config_name
            if model_args.config_name else model_args.model_name_or_path,
            num_labels=num_labels,
            finetuning_task=data_args.task_name,
            cache_dir=model_args.cache_dir,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name
            if model_args.tokenizer_name else model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )

        model = AutoModelForMultipleChoice.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )

        # Get datasets
        train_dataset = (MultipleChoiceDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            task=data_args.task_name,
            processor=processor,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.train,
        ) if training_args.do_train else None)
        eval_dataset = (MultipleChoiceDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            task=data_args.task_name,
            processor=processor,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.dev,
        ) if training_args.do_eval else None)

        def compute_metrics(p: EvalPrediction) -> Dict:
            preds = np.argmax(p.predictions, axis=1)
            return {"acc": simple_accuracy(preds, p.label_ids)}

        if model_name.startswith('bert'):
            model_desc = ModelDescription([
                IODescription('input_ids', [
                    self.train_batch_size, num_labels, data_args.max_seq_length
                ],
                              torch.int64,
                              num_classes=model.config.vocab_size),
                IODescription('attention_mask', [
                    self.train_batch_size, num_labels, data_args.max_seq_length
                ],
                              torch.int64,
                              num_classes=2),
                IODescription('token_type_ids', [
                    self.train_batch_size, num_labels, data_args.max_seq_length
                ],
                              torch.int64,
                              num_classes=2),
                IODescription('labels', [self.train_batch_size, num_labels],
                              torch.int64,
                              num_classes=num_labels)
            ], [
                IODescription('loss', [], torch.float32),
                IODescription('reshaped_logits',
                              [self.train_batch_size, num_labels],
                              torch.float32)
            ])
        else:
            model_desc = ModelDescription([
                IODescription('input_ids',
                              ['batch', num_labels, 'max_seq_len_in_batch'],
                              torch.int64,
                              num_classes=model.config.vocab_size),
                IODescription('attention_mask',
                              ['batch', num_labels, 'max_seq_len_in_batch'],
                              torch.int64,
                              num_classes=2),
                IODescription('labels', ['batch', num_labels],
                              torch.int64,
                              num_classes=num_labels)
            ], [
                IODescription('loss', [], torch.float32),
                IODescription('reshaped_logits', ['batch', num_labels],
                              torch.float32)
            ])
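        # The BERT branch pins every axis (batch size, choice count, sequence
        # length) so ORT can build a fixed-shape training graph; the fallback
        # uses symbolic axes ('batch', 'max_seq_len_in_batch') for dynamic
        # shapes. `num_classes` is assumed to bound the valid value range of
        # each int64 input (vocab ids, 0/1 masks, label ids).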

        # Initialize the ORTTrainer within ORTTransformerTrainer
        trainer = ORTTransformerTrainer(
            model=model,
            model_desc=model_desc,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
        )

        # Training
        if training_args.do_train:
            trainer.train()
            trainer.save_model()

        # Evaluation
        results = {}
        if training_args.do_eval and training_args.local_rank in [-1, 0]:
            logger.info("*** Evaluate ***")

            result = trainer.evaluate()

            logger.info("***** Eval results {} *****".format(
                data_args.task_name))
            for key, value in result.items():
                logger.info("  %s = %s", key, value)

            results.update(result)

        return results
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, MetaTrainingArguments))
    model_args, data_args, training_args, metatraining_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        processor = processors[data_args.task_name]()
        label_list = processor.get_labels()
        num_labels = len(label_list)
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    # BertForMultipleChoice
    model = AutoModelForMultipleChoice.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    s1_train_dataset = (
        MetaMultipleChoiceDataset(
            data_dir=os.path.join(data_args.data_dir, 'swag'),
            tokenizer=tokenizer,
            task=data_args.task_name,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.train,
            num_task=20,
            k_support=5,
            k_query=1,
        )
        if training_args.do_train
        else None
    )

    # s2_train_dataset = (
    #     MetaMultipleChoiceDataset(
    #         data_dir=os.path.join(data_args.data_dir, 'ComVE_A'),
    #         tokenizer=tokenizer,
    #         task=data_args.task_name,
    #         max_seq_length=data_args.max_seq_length,
    #         overwrite_cache=data_args.overwrite_cache,
    #         mode=Split.train,
    #         num_task=100,
    #         k_support=5,
    #         k_query=1,
    #     )
    #     if training_args.do_train
    #     else None
    # )

    # s3_train_dataset = (
    #     MetaMultipleChoiceDataset(
    #         data_dir=os.path.join(data_args.data_dir, 'ComVE_B'),
    #         tokenizer=tokenizer,
    #         task=data_args.task_name,
    #         max_seq_length=data_args.max_seq_length,
    #         overwrite_cache=data_args.overwrite_cache,
    #         mode=Split.train,
    #         num_task=100,
    #         k_support=5,
    #         k_query=1,
    #     )
    #     if training_args.do_train
    #     else None
    # )
    # s1_train_dataset = (
    #     MultipleChoiceDataset(
    #         data_dir=os.path.join(data_args.data_dir, 'swag'),
    #         tokenizer=tokenizer,
    #         task=data_args.task_name,
    #         max_seq_length=data_args.max_seq_length,
    #         overwrite_cache=data_args.overwrite_cache,
    #         mode=Split.train,
    #     )
    #     if training_args.do_train
    #     else None
    # )
    # eval_dataset = (
    #     MultipleChoiceDataset(
    #         data_dir=data_args.data_dir,
    #         tokenizer=tokenizer,
    #         task=data_args.task_name,
    #         max_seq_length=data_args.max_seq_length,
    #         overwrite_cache=data_args.overwrite_cache,
    #         mode=Split.test,
    #     )
    #     if training_args.do_eval
    #     else None
    # )

    target_train_dataset = (
        MultipleChoiceDataset(
            data_dir=os.path.join(data_args.data_dir, 'cqa'), 
            tokenizer=tokenizer,
            task='cqa_clf',
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.train,
        )
        if training_args.do_train
        else None
    )

    # [TODO]: Modify this...
    # target_test_dataset = (
    #     MultipleChoiceDataset(
    #         data_dir=os.path.join(data_args.data_dir, 'cqa'), 
    #         tokenizer=tokenizer,
    #         task='cqa_clf',
    #         max_seq_length=data_args.max_seq_length,
    #         overwrite_cache=data_args.overwrite_cache,
    #         mode=Split.test,
    #     )
    #     if training_args.do_train
    #     else None
    # )

    def compute_metrics(p: EvalPrediction) -> Dict:
        preds = np.argmax(p.predictions, axis=1)
        return {"acc": simple_accuracy(preds, p.label_ids)}

    # Create meta batch
    s1_db = create_batch_of_tasks(s1_train_dataset, is_shuffle=True, batch_size=metatraining_args.outer_batch_size)
    # s2_db = create_batch_of_tasks(s2_train_dataset, is_shuffle=True, batch_size=metatraining_args.outer_batch_size)
    # s3_db = create_batch_of_tasks(s3_train_dataset, is_shuffle=True, batch_size=metatraining_args.outer_batch_size)
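    # `create_batch_of_tasks` is assumed to shuffle the meta-dataset and yield
    # lists of `batch_size` tasks; a minimal sketch under that assumption:
    #
    #     def create_batch_of_tasks(taskset, is_shuffle=True, batch_size=4):
    #         idxs = list(range(len(taskset)))
    #         if is_shuffle:
    #             random.shuffle(idxs)
    #         for i in range(0, len(idxs), batch_size):
    #             yield [taskset[j] for j in idxs[i:i + batch_size]]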

    # Define Data Loader

    def _get_train_sampler(train_dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            return None
        return RandomSampler(train_dataset)

    # s1_train_sampler = _get_train_sampler(s1_train_dataset)

    # s1_train_dataloader = DataLoader(s1_train_dataset,
    #     batch_size=training_args.train_batch_size,
    #     sampler=s1_train_sampler,
    #     collate_fn=DataCollatorWithPadding(tokenizer),
    #     drop_last=training_args.dataloader_drop_last)
    
    target_train_sampler = _get_train_sampler(target_train_dataset)

    target_train_dataloader = DataLoader(
        target_train_dataset,
        batch_size=training_args.train_batch_size,
        sampler=target_train_sampler,
        collate_fn=default_data_collator,  # DataCollatorWithPadding(tokenizer)
        drop_last=training_args.dataloader_drop_last,
    )
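    # default_data_collator simply stacks already-tensorized features; the
    # commented DataCollatorWithPadding alternative would only be needed if
    # the dataset returned unpadded encodings.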

    
    metalearner = MetaLearner(metatraining_args, tokenizer)
    mtl_optimizer = Adam(metalearner.model.parameters(), lr=metatraining_args.mtl_update_lr)
   

    for source_idx, db in enumerate([s1_db]):  # , s2_db, s3_db]):

        for step, task_batch in enumerate(db):
            # Meta-training (FOMAML)
            acc, loss = metalearner(task_batch)
            print('Step:', step, '\tTraining Loss | Acc:', loss, '|', acc)
            with open('log.txt', 'a') as f:  # reopen per step so the handle is closed
                f.write(str(acc) + '\n')

        # Fine-tuning on the target set
        # target_batch = next(iter(target_train_dataloader))
        target_train_loss = []
        target_train_acc = []
        metalearner.model.cuda()
        metalearner.model.train()

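        # One gradient step per target batch using the meta-learner's
        # outer-loop optimizer, adapting the meta-learned initialization
        # directly to the target task.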
        for target_batch in tqdm.tqdm(target_train_dataloader):
            target_batch = metalearner.prepare_inputs(target_batch)
            outputs = metalearner.model(**target_batch)
            loss = outputs[0]
            loss.backward()
            metalearner.outer_optimizer.step()
            metalearner.outer_optimizer.zero_grad()
            target_train_loss.append(loss.item())

            # Compute accuracy on the target batch
            logits = F.softmax(outputs[1], dim=1)
            target_label_id = target_batch.get('labels')
            pre_label_id = torch.argmax(logits, dim=1)
            pre_label_id = pre_label_id.detach().cpu().numpy().tolist()
            target_label_id = target_label_id.detach().cpu().numpy().tolist()
            acc = accuracy_score(target_label_id, pre_label_id)
            target_train_acc.append(acc)

        print("Target Loss: ", np.mean(target_train_loss))
        print("Target Acc: ", np.mean(target_train_acc))
        # end fine-tuning
    # end MML 
    
    # MTL: normal fine-tuning
    target_finetune_loss = []
    for target_batch in target_train_dataloader:
        metalearner.model.train()
        target_batch = metalearner.prepare_inputs(target_batch)
        outputs = metalearner.model(**target_batch)
        loss = outputs[0]              
        loss.backward()
        mtl_optimizer.step()
        mtl_optimizer.zero_grad()
        target_finetune_loss.append(loss.item())

    print("Target Loss: ", np.mean(target_finetune_loss))
Example #3
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, MultiLingAdapterArguments))
    model_args, data_args, training_args, adapter_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        processor = processors[data_args.task_name]()
        label_list = processor.get_labels()
        num_labels = len(label_list)
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForMultipleChoice.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Setup adapters
    if adapter_args.train_adapter:
        task_name = data_args.task_name
        # check if adapter already exists, otherwise add it
        if task_name not in model.config.adapters.adapter_list(AdapterType.text_task):
            # resolve the adapter config
            adapter_config = AdapterConfig.load(
                adapter_args.adapter_config,
                non_linearity=adapter_args.adapter_non_linearity,
                reduction_factor=adapter_args.adapter_reduction_factor,
            )
            # load a pre-trained from Hub if specified
            if adapter_args.load_adapter:
                model.load_adapter(
                    adapter_args.load_adapter, AdapterType.text_task, config=adapter_config, load_as=task_name,
                )
            # otherwise, add a fresh adapter
            else:
                model.add_adapter(task_name, AdapterType.text_task, config=adapter_config)
        # optionally load a pre-trained language adapter
        if adapter_args.load_lang_adapter:
            # resolve the language adapter config
            lang_adapter_config = AdapterConfig.load(
                adapter_args.lang_adapter_config,
                non_linearity=adapter_args.lang_adapter_non_linearity,
                reduction_factor=adapter_args.lang_adapter_reduction_factor,
            )
            # load the language adapter from Hub
            lang_adapter_name = model.load_adapter(
                adapter_args.load_lang_adapter,
                AdapterType.text_lang,
                config=lang_adapter_config,
                load_as=adapter_args.language,
            )
        else:
            lang_adapter_name = None
        # Freeze all model weights except of those of this adapter
        model.train_adapter([task_name])
        # Set the adapters to be used in every forward pass
        if lang_adapter_name:
            model.set_active_adapters([lang_adapter_name, task_name])
        else:
            model.set_active_adapters([task_name])
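    # After train_adapter, only the task adapter's weights receive gradients;
    # set_active_adapters fixes the stack used at every forward pass
    # (language adapter first, then task adapter, when both are present).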

    # Get datasets
    train_dataset = (
        MultipleChoiceDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            task=data_args.task_name,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.train,
        )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        MultipleChoiceDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            task=data_args.task_name,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.dev,
        )
        if training_args.do_eval
        else None
    )

    def compute_metrics(p: EvalPrediction) -> Dict:
        preds = np.argmax(p.predictions, axis=1)
        return {"acc": simple_accuracy(preds, p.label_ids)}

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        do_save_full_model=not adapter_args.train_adapter,
        do_save_adapters=adapter_args.train_adapter,
    )

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        result = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        if trainer.is_world_master():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in result.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

                results.update(result)

    return results
    do_predict=True,
    learning_rate=learning_rate,
    num_train_epochs=num_train_epochs,
)
set_seed(training_args.seed)

# Data Preprocess
tokenizer = BertTokenizer.from_pretrained(
    tokenizer_path,
    cache_dir=tokenizer_path,
)

train_dataset = (MultipleChoiceDataset(
    data_dir=data_path,
    tokenizer=tokenizer,
    task=task_name,
    max_seq_length=max_seq_length,
    overwrite_cache=overwrite_tokenizer,
    mode=Split.train,
) if training_args.do_train else None)

eval_dataset = (MultipleChoiceDataset(
    data_dir=data_path,
    tokenizer=tokenizer,
    task=task_name,
    max_seq_length=max_seq_length,
    overwrite_cache=overwrite_tokenizer,
    mode=Split.dev,
) if training_args.do_eval else None)

test_dataset = (MultipleChoiceDataset(
    data_dir=data_path,
    tokenizer=tokenizer,
    task=task_name,
    max_seq_length=max_seq_length,
    overwrite_cache=overwrite_tokenizer,
    mode=Split.test,
) if training_args.do_predict else None)
Example #5
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        processor = processors[data_args.task_name]()
        label_list = processor.get_labels()
        num_labels = len(label_list)
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForMultipleChoice.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = (
        MultipleChoiceDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            task=data_args.task_name,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.train,
        )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        MultipleChoiceDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            task=data_args.task_name,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.dev,
        )
        if training_args.do_eval
        else None
    )

    def compute_metrics(p: EvalPrediction) -> Dict:
        preds = np.argmax(p.predictions, axis=1)
        return {"acc": simple_accuracy(preds, p.label_ids)}

    # Data collator
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) if training_args.fp16 else None
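    # Padding to a multiple of 8 keeps sequence lengths aligned with fp16
    # Tensor Core kernels; in fp32 mode the Trainer's default collation is
    # sufficient.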

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        result = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        if trainer.is_world_master():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in result.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

                results.update(result)

    return results
Example #6
def main():
    # args
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # data
    processor = processors['race']()
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # load model
    global_config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    global_model = AutoModelForMultipleChoice.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=global_config,
        cache_dir=model_args.cache_dir,
    )

    # local_model = BertForMaskedLM.from_pretrained(

    # )

    # Get datasets
    train_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.train,
    ) if training_args.do_train else None)
    eval_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.dev,
    ) if training_args.do_eval else None)

    def compute_metrics(p: EvalPrediction) -> Dict:
        preds = np.argmax(p.predictions, axis=1)
        return {"acc": simple_accuracy(preds, p.label_ids)}

    # Initialize our Trainer
    trainer = Trainer(
        model=global_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path
                      if os.path.isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        result = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir,
                                        "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key, value in result.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

            results.update(result)

    return results
Example #7
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForMultipleChoice.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = (
        MultipleChoiceDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            task=data_args.task_name,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.train,
        )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        MultipleChoiceDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            task=data_args.task_name,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.dev,
        )
        if training_args.do_eval
        else None
    )
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        processor = processors[data_args.task_name]()
        label_list = processor.get_labels()
        num_labels = len(label_list)
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForMultipleChoice.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    train_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.train,
        perturbation_type=data_args.perturbation_type,
        perturbation_num=data_args.perturbation_num_train,
        augment=data_args.augment,
        name_gender_or_race=data_args.name_gender_or_race,
    ) if training_args.do_train else None)

    eval_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.dev,
    ) if training_args.do_train else None)

    def compute_metrics(p: EvalPrediction) -> Dict:
        preds = np.argmax(p.predictions, axis=1)
        return {"acc": simple_accuracy(preds, p.label_ids)}

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path
                      if os.path.isdir(model_args.model_name_or_path) else None)
        trainer.save_model()

        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Test
    test_dataset = MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.test,
        perturbation_type=data_args.perturbation_type,
        perturbation_num=data_args.perturbation_num_test,
        augment=data_args.augment,
        name_gender_or_race=data_args.name_gender_or_race,
    )

    predictions, label_ids, metrics = trainer.predict(test_dataset)

    predictions_file = os.path.join(training_args.output_dir,
                                    "test_predictions")
    labels_ids_file = os.path.join(training_args.output_dir, "test_labels_id")
    torch.save(predictions, predictions_file)
    torch.save(label_ids, labels_ids_file)

    examples_ids = []
    perturbated = []
    run = []

    for input_feature in test_dataset.features:
        examples_ids.append(input_feature.example_id)

    for examples in test_dataset.examples:
        perturbated.append(examples.perturbated)
        run.append(examples.run)

    examples_ids_file = os.path.join(training_args.output_dir, "examples_ids")
    torch.save(examples_ids, examples_ids_file)
    perturbated_file = os.path.join(training_args.output_dir, "perturbated")
    torch.save(perturbated, perturbated_file)
    run_file = os.path.join(training_args.output_dir, "run")
    torch.save(run, run_file)

    output_eval_file = os.path.join(training_args.output_dir,
                                    "test_results.txt")

    if trainer.is_world_master():
        with open(output_eval_file, "w") as writer:
            logger.info("***** Test results *****")
            for key, value in metrics.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))
    return metrics
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments,
                               TrainingArguments, AdapterArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args, adapter_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, adapter_args = parser.parse_args_into_dataclasses(
        )

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            f" Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        processor = processors[data_args.task_name]()
        label_list = processor.get_labels()
        num_labels = len(label_list)
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForMultipleChoice.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Setup adapters
    from transformers.adapter_config import AdapterType

    # base_model = getattr(model, model.base_model_prefix, model)
    # base_model.set_adapter_config(AdapterType.text_task, adapter_args.adapter_config)

    from transformers.adapter_config import PfeifferConfig
    model.load_adapter("/home/theorist17/projects/adapter/adapters/MNLI/mnli",
                       "text_task",
                       config=PfeifferConfig(),
                       with_head=False)
    model.load_adapter(
        "/home/theorist17/projects/adapter/adapters/commonsenseqa/commonsenseqa",
        "text_task",
        config=PfeifferConfig(),
        with_head=False)
    model.load_adapter(
        "/home/theorist17/projects/adapter/adapters/conceptnet/conceptnet",
        "text_task",
        config=PfeifferConfig(),
        with_head=False)
    adapter_names = [["mnli", "commonsenseqa", "conceptnet"]]

    model.add_fusion(adapter_names[0], "dynamic")
    #model.base_model.set_active_adapters(adapter_names)
    #model.train_fusion(adapter_names)
    model.train_fusion(adapter_names)
    # inspect parameters of the fusion layer
    for (n, p) in model.named_parameters():
        print(n, p.requires_grad)
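    # train_fusion is assumed to freeze the base model and the three loaded
    # adapters, leaving only the AdapterFusion parameters trainable; the loop
    # above prints requires_grad per parameter so this can be verified.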

    # Get datasets
    train_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.train,
    ) if training_args.do_train else None)
    eval_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.dev,
    ) if training_args.do_eval else None)

    def simple_accuracy(preds, labels):
        return (preds == labels).mean()

    def compute_metrics(p: EvalPrediction) -> Dict:
        preds = np.argmax(p.predictions, axis=1)
        return {"acc": simple_accuracy(preds, p.label_ids)}

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        do_save_full_model=False,
        do_save_adapter_fusion=True,
        adapter_names=adapter_names,
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path
                      if os.path.isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        result = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir,
                                        "eval_results.txt")
        if trainer.is_world_master():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in result.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

                eval_results.update(result)

    return eval_results
Example #10
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        processor = processors[data_args.task_name]()
        label_list = processor.get_labels()
        num_labels = len(label_list)
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForMultipleChoice.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    if data_args.reinit_pooler:
        if model_args.model_type in ["bert", "roberta"]:
            encoder_temp = getattr(model, model_args.model_type)
            encoder_temp.pooler.dense.weight.data.normal_(
                mean=0.0, std=encoder_temp.config.initializer_range)
            encoder_temp.pooler.dense.bias.data.zero_()
            for p in encoder_temp.pooler.parameters():
                p.requires_grad = True
        elif model_args.model_type in ["xlnet", "bart", "electra"]:
            raise ValueError(
                f"{model_args.model_type} does not have a pooler at the end")
        else:
            raise NotImplementedError
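    # Re-initializing the pooler here (and the top `reinit_layers` encoder
    # layers below) re-draws top-of-network weights from the initializer
    # distribution instead of reusing pretrained values, as in few-sample
    # BERT fine-tuning recipes.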

    if data_args.reinit_layers > 0:
        if model_args.model_type in ["bert", "roberta", "electra"]:
            assert data_args.reinit_pooler or model_args.model_type == "electra"
            from transformers.modeling_bert import BertLayerNorm

            encoder_temp = getattr(model, model_args.model_type)
            for layer in encoder_temp.encoder.layer[-data_args.reinit_layers:]:
                for module in layer.modules():
                    if isinstance(module, (nn.Linear, nn.Embedding)):
                        module.weight.data.normal_(
                            mean=0.0,
                            std=encoder_temp.config.initializer_range)
                    elif isinstance(module, BertLayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)
                    if isinstance(module,
                                  nn.Linear) and module.bias is not None:
                        module.bias.data.zero_()
        elif model_args.model_type == "xlnet":
            from transformers.modeling_xlnet import XLNetLayerNorm, XLNetRelativeAttention

            for layer in model.transformer.layer[-data_args.reinit_layers:]:
                for module in layer.modules():
                    if isinstance(module, (nn.Linear, nn.Embedding)):
                        module.weight.data.normal_(
                            mean=0.0,
                            std=model.transformer.config.initializer_range)
                        if isinstance(module,
                                      nn.Linear) and module.bias is not None:
                            module.bias.data.zero_()
                    elif isinstance(module, XLNetLayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)
                    elif isinstance(module, XLNetRelativeAttention):
                        for param in [
                                module.q,
                                module.k,
                                module.v,
                                module.o,
                                module.r,
                                module.r_r_bias,
                                module.r_s_bias,
                                module.r_w_bias,
                                module.seg_embed,
                        ]:
                            param.data.normal_(
                                mean=0.0,
                                std=model.transformer.config.initializer_range)
        elif model_args.model_type == "bart":
            for layer in model.model.decoder.layers[-data_args.reinit_layers:]:
                for module in layer.modules():
                    model.model._init_weights(module)

        else:
            raise NotImplementedError

    train_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.train,
        solve_coref=data_args.solve_coref,
    ) if training_args.do_train else None)
    eval_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.dev,
        solve_coref=data_args.solve_coref,
    ) if training_args.do_eval else None)

    test_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.test,
        solve_coref=data_args.solve_coref,
    ) if training_args.do_predict else None)

    test_dataset_high = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.test,
        solve_coref=data_args.solve_coref,
        group='high',
    ) if training_args.do_predict else None)

    test_dataset_middle = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.test,
        solve_coref=data_args.solve_coref,
        group='middle',
    ) if training_args.do_predict else None)

    def compute_metrics(p: EvalPrediction) -> Dict:
        preds = np.argmax(p.predictions, axis=1)
        return {"acc": simple_accuracy(preds, p.label_ids)}

    # Initialize our Trainer
    if training_args.freelb:
        trainer = FreeLBTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
        )
    else:
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
        )
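    # FreeLBTrainer is assumed to implement FreeLB-style adversarial training
    # (gradient-based perturbations in the embedding space during
    # fine-tuning); otherwise the stock Trainer is used.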

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path
                      if os.path.isdir(model_args.model_name_or_path) else None)
        trainer.save_model()

        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        result = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir,
                                        "eval_results.txt")
        if trainer.is_world_master():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in result.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

                results.update(result)

    if training_args.do_predict:
        predictions, label_ids, metrics = trainer.predict(test_dataset)
        predictions_high, label_ids_high, metrics_high = trainer.predict(
            test_dataset_high)
        predictions_middle, label_ids_middle, metrics_middle = trainer.predict(
            test_dataset_middle)

        predictions_file = os.path.join(training_args.output_dir,
                                        "test_predictions")
        labels_ids_file = os.path.join(training_args.output_dir,
                                       "test_labels_id")

        predictions_file_high = os.path.join(training_args.output_dir,
                                             "test_predictions_high")
        labels_ids_file_high = os.path.join(training_args.output_dir,
                                            "test_labels_id_high")

        predictions_file_middle = os.path.join(training_args.output_dir,
                                               "test_predictions_middle")
        labels_ids_file_middle = os.path.join(training_args.output_dir,
                                              "test_labels_id_middle")

        torch.save(predictions, predictions_file)
        torch.save(label_ids, labels_ids_file)

        torch.save(predictions_high, predictions_file_high)
        torch.save(label_ids_high, labels_ids_file_high)

        torch.save(predictions_middle, predictions_file_middle)
        torch.save(label_ids_middle, labels_ids_file_middle)

        examples_ids = []
        for input_feature in test_dataset.features:
            examples_ids.append(input_feature.example_id)

        examples_ids_file = os.path.join(training_args.output_dir,
                                         "examples_ids")
        torch.save(examples_ids, examples_ids_file)

        output_eval_file = os.path.join(training_args.output_dir,
                                        "test_results.txt")

        if trainer.is_world_master():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Test results *****")
                for key, value in metrics.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))
                for key, value in metrics_high.items():
                    logger.info("  high %s = %s", key, value)
                    writer.write("high %s = %s\n" % (key, value))
                for key, value in metrics_middle.items():
                    logger.info("  middle %s = %s", key, value)
                    writer.write("middle %s = %s\n" % (key, value))

    return results