示例#1
0
    def test_eval(self, fold_index=0):
        """Run a saved fold checkpoint over ``test.csv`` and print ranking metrics.

        Reads ``test.csv`` from ``self.data_dir``, converts the examples to
        features with ``self.tokenizer``, loads the weights stored in
        ``self.output_dir`` as ``pytorch_model_<fold_index>.bin`` and prints the
        values returned by ``compute_MRR_CQA`` and ``compute_5R20``.

        Args:
            fold_index: Index used in the checkpoint file name. Defaults to 0,
                the checkpoint this method previously hard-coded; ``train``
                writes one such file per split.
        """
        data = DATACQA(debug=False, data_dir=self.data_dir)
        test_examples = data.read_examples_test(os.path.join(self.data_dir, 'test.csv'))
        print('eval_examples的数量', len(test_examples))

        questions = [example.text_a for example in test_examples]
        test_features = data.convert_examples_to_features(test_examples, self.tokenizer, self.max_seq_length)
        all_input_ids = torch.tensor(data.select_field(test_features, 'input_ids'), dtype=torch.long)
        all_input_mask = torch.tensor(data.select_field(test_features, 'input_mask'), dtype=torch.long)
        all_segment_ids = torch.tensor(data.select_field(test_features, 'segment_ids'), dtype=torch.long)
        all_label = torch.tensor([f.label for f in test_features], dtype=torch.long)

        # Sequential sampling keeps the predictions aligned with `questions`.
        test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
        test_dataloader = DataLoader(
            test_data,
            sampler=SequentialSampler(test_data),
            batch_size=self.eval_batch_size,
        )

        config = BertConfig.from_pretrained(self.model_name_or_path, num_labels=self.num_labels)
        checkpoint_path = os.path.join(self.output_dir, "pytorch_model_{}.bin".format(fold_index))
        model = BertForSequenceClassification.from_pretrained(checkpoint_path, self.args, config=config)
        model.to(self.device)
        model.eval()

        scores = []
        gold_labels = []
        for input_ids, input_mask, segment_ids, label_ids in test_dataloader:
            with torch.no_grad():
                # Called without labels, the model's output is used directly as logits.
                logits = model(
                    input_ids=input_ids.to(self.device),
                    token_type_ids=segment_ids.to(self.device),
                    attention_mask=input_mask.to(self.device),
                )
            scores.append(logits.detach().cpu().numpy())
            # Labels are only needed on the host, so they never leave the CPU.
            gold_labels.append(label_ids.numpy())

        gold_labels = np.concatenate(gold_labels, 0)
        scores = np.concatenate(scores, 0)

        eval_mrr = compute_MRR_CQA(scores, gold_labels, questions)
        eval_5R20 = compute_5R20(scores, gold_labels, questions)
        print('eval_mrr',eval_mrr)
        print('eval_5R20',eval_5R20)
示例#2
0
    def train(self):
        """Fine-tune one classifier per data split, keeping each split's best checkpoint.

        For every split returned by ``DATACQA.load_data(..., n_splits=5)`` a fresh
        model is loaded from ``self.model_name_or_path`` and trained for
        ``self.train_steps`` micro-batches. Every
        ``self.eval_steps * self.gradient_accumulation_steps`` micro-batches the
        running training loss is logged and, when ``self.do_eval`` is set, the
        split's evaluation data is scored. Metrics are appended to
        ``eval_results.txt`` in ``self.output_dir``; whenever the ``accuracyCQA``
        score beats the split's previous best, the weights are saved as
        ``pytorch_model_<split_index>.bin``.
        """
        os.makedirs(self.output_dir, exist_ok=True)

        data_splitList = DATACQA.load_data(os.path.join(self.data_dir, 'train.csv'),n_splits=5)
        report_interval = self.eval_steps * self.gradient_accumulation_steps
        output_eval_file = os.path.join(self.output_dir, "eval_results.txt")
        # Parameters whose name contains one of these substrings get no weight decay.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        for split_index, each_data in enumerate(data_splitList):
            # Prepare model
            config = BertConfig.from_pretrained(self.model_name_or_path, num_labels=self.num_labels)
            model = BertForSequenceClassification.from_pretrained(self.model_name_or_path, self.args, config=config)
            model.to(self.device)

            logger.info(f'Fold {split_index + 1}')
            train_dataloader, eval_dataloader, train_examples, eval_examples = self.create_dataloader(each_data)

            # Prepare optimizer
            named_params = list(model.named_parameters())
            optimizer_grouped_parameters = [
                {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
                 'weight_decay': self.weight_decay},
                {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
            optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon)
            scheduler = WarmupLinearSchedule(optimizer, warmup_steps=self.warmup_steps, t_total=self.train_steps)

            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", len(train_examples))
            logger.info("  Batch size = %d", self.train_batch_size)
            logger.info("  Num steps = %d", self.train_steps)

            questions = [x.text_a for x in eval_examples]
            global_step = 0
            best_acc = 0
            tr_loss = 0
            nb_tr_steps = 0
            model.train()
            # Loop over the loader endlessly; the step budget, not epochs, ends training.
            batches = cycle(train_dataloader)

            for step in range(self.train_steps):
                batch = tuple(t.to(self.device) for t in next(batches))
                input_ids, input_mask, segment_ids, label_ids = batch
                # Called with labels, the model's output is used directly as the loss.
                loss = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
                tr_loss += loss.item()
                nb_tr_steps += 1
                # Mean loss since the last report.
                train_loss = round(tr_loss / nb_tr_steps, 4)

                # NOTE(review): the loss is not divided by gradient_accumulation_steps,
                # so accumulated gradients are summed, not averaged - confirm intended.
                loss.backward()

                # Count micro-batches with `step`, which is never reset, so every
                # update (including the first) sees a full accumulation window.
                if (step + 1) % self.gradient_accumulation_steps == 0:
                    # optimizer.step() must precede scheduler.step(); the reverse
                    # order skips the first value of the learning-rate schedule.
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                if (step + 1) % report_interval != 0:
                    continue

                tr_loss = 0
                nb_tr_steps = 0
                logger.info("***** Report result *****")
                logger.info("  %s = %s", 'global_step', str(global_step))
                logger.info("  %s = %s", 'train loss', str(train_loss))

                if not self.do_eval:
                    continue

                logger.info("***** Running evaluation *****")
                logger.info("  Num examples = %d", len(eval_examples))
                logger.info("  Batch size = %d", self.eval_batch_size)

                # Run prediction for full data
                model.eval()
                eval_loss = 0
                nb_eval_steps = 0
                scores = []
                gold_labels = []
                for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                    input_ids = input_ids.to(self.device)
                    input_mask = input_mask.to(self.device)
                    segment_ids = segment_ids.to(self.device)
                    label_ids = label_ids.to(self.device)

                    with torch.no_grad():
                        # Two forward passes: with labels for the loss, without for logits.
                        tmp_eval_loss = model(
                            input_ids=input_ids,
                            token_type_ids=segment_ids,
                            attention_mask=input_mask,
                            labels=label_ids)
                        logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask)

                    scores.append(logits.detach().cpu().numpy())
                    gold_labels.append(label_ids.to('cpu').numpy())
                    eval_loss += tmp_eval_loss.mean().item()
                    nb_eval_steps += 1

                gold_labels = np.concatenate(gold_labels, 0)
                scores = np.concatenate(scores, 0)
                model.train()
                eval_loss = eval_loss / nb_eval_steps
                eval_accuracy = accuracyCQA(scores, gold_labels)
                eval_mrr = compute_MRR_CQA(scores,gold_labels,questions)
                eval_5R20 = compute_5R20(scores,gold_labels,questions)

                # The 'eval_F1' entry holds the value returned by accuracyCQA.
                result = {'eval_loss': eval_loss,
                          'eval_F1': eval_accuracy,
                          'eval_MRR':eval_mrr,
                          'eval_5R20':eval_5R20,
                          'global_step': global_step,
                          'loss': train_loss}

                with open(output_eval_file, "a") as writer:
                    for key in sorted(result.keys()):
                        logger.info("  %s = %s", key, str(result[key]))
                        writer.write("%s = %s\n" % (key, str(result[key])))
                    writer.write('*' * 80)
                    writer.write('\n')

                print("=" * 80)
                if eval_accuracy > best_acc:
                    print("Best F1", eval_accuracy)
                    print("Saving Model......")
                    best_acc = eval_accuracy
                    # Save a trained model; unwrap a `.module` wrapper if one is present.
                    model_to_save = model.module if hasattr(model,'module') else model
                    output_model_file = os.path.join(self.output_dir, "pytorch_model_{}.bin".format(split_index))
                    torch.save(model_to_save.state_dict(), output_model_file)
                    print("=" * 80)

            # Drop this split's model before the next one is loaded.
            del model
            gc.collect()