Code example #1
File: trainer.py  Project: arkhycat/mrqa
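    # (Excerpt from a Trainer class; it assumes self.args, self.dev_files, self.model
    #  and self.tokenizer are set elsewhere, and that eval_qa is a project helper.)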
    def evaluate_model(self, epoch):
        # append this epoch's per-file dev results to a text file in the result directory
        result_file = os.path.join(self.args.result_dir,
                                   "dev_eval_{}.txt".format(epoch))
        fw = open(result_file, "a")
        result_dict = dict()
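        # evaluate every dev file and collect its F1 score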
        for dev_file in self.dev_files:
            print(dev_file)
            file_name = dev_file.split(".")[0]
            prediction_file = os.path.join(
                self.args.result_dir,
                "epoch_{}_{}.json".format(epoch, file_name))
            file_path = os.path.join(self.args.dev_folder, dev_file)
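            # eval_qa returns SQuAD-style metrics (including F1) for this dev file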
            metrics = eval_qa(self.model,
                              file_path,
                              prediction_file,
                              args=self.args,
                              tokenizer=self.tokenizer,
                              batch_size=self.args.batch_size)
            f1 = metrics["f1"]
            fw.write("{} : {}\n".format(file_name, f1))
            result_dict[dev_file] = f1
        fw.close()

        return result_dict
Code example #2
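# (Assumed imports from the surrounding file: os, time, math, torch, torch.nn as nn,
#  torch.optim as optim, a BertTokenizer implementation, and project helpers such as
#  get_data_loader, get_ext_data_loader, BiDAF, QANet, EMA, eval_qa,
#  cal_running_avg_loss, progress_bar and eta.)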
def main(args):
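    # checkpoints are written to a timestamped directory under ./save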
    save_dir = os.path.join("./save", time.strftime("%m%d%H%M%S"))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
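    # either load every training file under ./data/train/ or just SQuAD v1.1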
    if args.all_data:
        data_loader = get_ext_data_loader(tokenizer,
                                          "./data/train/",
                                          shuffle=True,
                                          args=args)
    else:
        data_loader, _, _ = get_data_loader(tokenizer,
                                            "./data/train-v1.1.json",
                                            shuffle=True,
                                            args=args)
    vocab_size = len(tokenizer.vocab)
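    # build either a BiDAF or a QANet model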
    if args.bidaf:
        print("train bidaf")
        model = BiDAF(embedding_size=args.embedding_size,
                      vocab_size=vocab_size,
                      hidden_size=args.hidden_size,
                      drop_prob=args.dropout)
    else:
        ntokens = len(tokenizer.vocab)
        model = QANet(ntokens,
                      embedding=args.embedding,
                      embedding_size=args.embedding_size,
                      hidden_size=args.hidden_size,
                      num_head=args.num_head)
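    # optionally warm-start from a previously saved state dict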
    if args.load_model:
        state_dict = torch.load(args.model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        print("load pre-trained model")
    device = torch.device("cuda")
    model = model.to(device)
    model.train()
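    # exponential moving average of the model weights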
    ema = EMA(model, args.decay)

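    # base LR of 1 so the LambdaLR multiplier below defines the effective learning rate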
    base_lr = 1
    parameters = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adam(lr=base_lr,
                           betas=(0.9, 0.999),
                           eps=1e-7,
                           weight_decay=5e-8,
                           params=parameters)
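    # logarithmic warm-up: the LR multiplier grows as cr * log2(step + 1),
    # then is held at args.lr once lr_warm_up_num steps have passed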
    cr = args.lr / math.log2(args.lr_warm_up_num)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda ee: cr * math.log2(ee + 1)
        if ee < args.lr_warm_up_num else args.lr)
    step = 0
    num_batches = len(data_loader)
    avg_loss = 0
    best_f1 = 0
    for epoch in range(1, args.num_epochs + 1):
        step += 1
        start = time.time()
        model.train()
        for i, batch in enumerate(data_loader, start=1):
            c_ids, q_ids, start_positions, end_positions = batch
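            # drop padding columns: keep only up to the longest context/question in the batch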
            c_len = torch.sum(torch.sign(c_ids), 1)
            max_c_len = torch.max(c_len)
            c_ids = c_ids[:, :max_c_len].to(device)
            q_len = torch.sum(torch.sign(q_ids), 1)
            max_q_len = torch.max(q_len)
            q_ids = q_ids[:, :max_q_len].to(device)

            start_positions = start_positions.to(device)
            end_positions = end_positions.to(device)

            optimizer.zero_grad()
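            # the model returns the span loss when start/end positions are provided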
            loss = model(c_ids,
                         q_ids,
                         start_positions=start_positions,
                         end_positions=end_positions)
            loss.backward()
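            # keep a running average of the loss for the progress display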
            avg_loss = cal_running_avg_loss(loss.item(), avg_loss)
            nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
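            # update the LR schedule and the EMA shadow weights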
            scheduler.step(step)
            ema(model, step // args.batch_size)

            batch_size = c_ids.size(0)
            step += batch_size

            msg = "{}/{} {} - ETA : {} - qa_loss: {:.4f}" \
                .format(i, num_batches, progress_bar(i, num_batches),
                        eta(start, i, num_batches),
                        avg_loss)
            print(msg, end="\r")
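        # evaluate on the dev set at the end of each epoch unless running in debug mode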
        if not args.debug:
            metric_dict = eval_qa(args, model)
            f1 = metric_dict["f1"]
            em = metric_dict["exact_match"]
            print("epoch: {}, final loss: {:.4f}, F1:{:.2f}, EM:{:.2f}".format(
                epoch, avg_loss, f1, em))

            if args.bidaf:
                model_name = "bidaf"
            else:
                model_name = "qanet"
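            # save a new checkpoint whenever dev F1 improves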
            if f1 > best_f1:
                best_f1 = f1
                state_dict = model.state_dict()
                save_file = "{}_{:.2f}_{:.2f}".format(model_name, f1, em)
                path = os.path.join(save_dir, save_file)
                torch.save(state_dict, path)