示例#1
0
def train_entry(config):
    """Train a BiDAF model with EMA weights and patience-based early stopping.

    Loads the word and character embedding matrices and the dev evaluation
    data (all JSON), builds the model, optionally restores weights from
    ``config.save_path``, and trains with Adadelta for up to
    ``config.num_epoch`` epochs. After every epoch the EMA weights are
    assigned to the model, the dev set is evaluated, a checkpoint is written
    to ``config.save_dir``, and the EMA state is resumed.

    Args:
        config: Configuration object. Attributes read here: ``word_emb_file``,
            ``char_emb_file``, ``dev_eval_file``, ``train_record_file``,
            ``dev_record_file``, ``batch_size``, ``pretrained``,
            ``save_path``, ``decay``, ``num_epoch``, ``early_stop`` and
            ``save_dir``.
    """
    from models import BiDAF

    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "rb") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        # Holds the parsed JSON contents, despite the "_file" suffix.
        dev_eval_file = json.load(fh)

    print("Building model...")

    train_dataset = get_loader(config.train_record_file, config.batch_size)
    dev_dataset = get_loader(config.dev_record_file, config.batch_size)

    # The char matrix's first dimension is passed as the char vocabulary
    # size and its second as the char embedding size.
    c_vocab_size, c_emb_size = char_mat.shape
    # NOTE(review): w_embedding_size is fixed at 300; presumably it must
    # match word_mat.shape[1] -- confirm against the embedding file.
    model = BiDAF(word_mat,
                  w_embedding_size=300,
                  c_embeding_size=c_emb_size,
                  c_vocab_size=c_vocab_size,
                  hidden_size=100,
                  drop_prob=0.2).to(device)
    if config.pretrained:
        print("load pre-trained model")
        # Deserialize onto the CPU; load_state_dict then copies the values
        # into the model's (already device-placed) parameters.
        state_dict = torch.load(config.save_path, map_location="cpu")
        model.load_state_dict(state_dict)

    # Register every trainable parameter with the EMA tracker.
    ema = EMA(config.decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adadelta(lr=0.5, params=parameters)
    best_f1 = 0
    best_em = 0
    patience = 0
    # ``epoch`` rather than ``iter`` so the builtin is not shadowed.
    for epoch in range(config.num_epoch):
        train(model, optimizer, train_dataset, dev_dataset, dev_eval_file,
              epoch, ema)
        # Presumably swaps the EMA weights into the model for evaluation and
        # checkpointing; ema.resume at the end of the iteration undoes it.
        ema.assign(model)
        metrics = test(model, dev_dataset, dev_eval_file,
                       (epoch + 1) * len(train_dataset))
        dev_f1 = metrics["f1"]
        dev_em = metrics["exact_match"]
        # Patience is lost only when BOTH F1 and EM fall below their best.
        if dev_f1 < best_f1 and dev_em < best_em:
            patience += 1
            if patience > config.early_stop:
                # Stops without calling ema.resume, so the model is left in
                # its ema.assign state, and this epoch is not checkpointed.
                break
        else:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)

        # NOTE(review): the filename records the best-so-far dev scores, not
        # this checkpoint's own scores -- confirm that is intended.
        fn = os.path.join(
            config.save_dir,
            "model_{}_{:.2f}_{:.2f}.pt".format(epoch, best_f1, best_em))
        torch.save(model.state_dict(), fn)
        ema.resume(model)
示例#2
0
def main(args):
    """Train a BiDAF or QANet QA model and save the best-F1 checkpoint.

    Builds a ``bert-base-uncased`` tokenizer and a shuffled training loader,
    constructs BiDAF (when ``args.bidaf``) or QANet, optionally restores
    weights from ``args.model_path``, then trains for ``args.num_epochs``
    epochs with Adam, a logarithmic learning-rate warm-up, gradient-norm
    clipping and EMA updates. Unless ``args.debug`` is set, each epoch is
    evaluated with ``eval_qa`` and the model's state dict is saved under a
    timestamped directory in ``./save`` whenever dev F1 improves.

    Args:
        args: Parsed command-line namespace. Fields read here include
            ``all_data``, ``bidaf``, ``embedding``, ``embedding_size``,
            ``hidden_size``, ``dropout``, ``num_head``, ``load_model``,
            ``model_path``, ``decay``, ``lr``, ``lr_warm_up_num``,
            ``num_epochs``, ``max_grad_norm``, ``batch_size`` and ``debug``.
    """
    save_dir = os.path.join("./save", time.strftime("%m%d%H%M%S"))
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    if args.all_data:
        data_loader = get_ext_data_loader(tokenizer,
                                          "./data/train/",
                                          shuffle=True,
                                          args=args)
    else:
        data_loader, _, _ = get_data_loader(tokenizer,
                                            "./data/train-v1.1.json",
                                            shuffle=True,
                                            args=args)
    vocab_size = len(tokenizer.vocab)
    if args.bidaf:
        print("train bidaf")
        model = BiDAF(embedding_size=args.embedding_size,
                      vocab_size=vocab_size,
                      hidden_size=args.hidden_size,
                      drop_prob=args.dropout)
    else:
        model = QANet(vocab_size,
                      embedding=args.embedding,
                      embedding_size=args.embedding_size,
                      hidden_size=args.hidden_size,
                      num_head=args.num_head)
    # Prefix for saved checkpoint files; fixed for the whole run.
    model_name = "bidaf" if args.bidaf else "qanet"
    if args.load_model:
        state_dict = torch.load(args.model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        print("load pre-trained model")
    # Fall back to the CPU instead of crashing when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()
    ema = EMA(model, args.decay)

    # With base_lr == 1, the LambdaLR multiplier is the effective LR.
    base_lr = 1
    parameters = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adam(lr=base_lr,
                           betas=(0.9, 0.999),
                           eps=1e-7,
                           weight_decay=5e-8,
                           params=parameters)
    # Warm-up: the multiplier is cr * log2(ee + 1), which reaches args.lr at
    # ee == lr_warm_up_num - 1, then stays constant at args.lr.
    cr = args.lr / math.log2(args.lr_warm_up_num)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda ee: cr * math.log2(ee + 1)
        if ee < args.lr_warm_up_num else args.lr)
    # NOTE(review): ``step`` grows by the batch size per batch plus one per
    # epoch, and drives both the scheduler and EMA -- confirm intended.
    step = 0
    num_batches = len(data_loader)
    avg_loss = 0
    best_f1 = 0
    for epoch in range(1, args.num_epochs + 1):
        step += 1
        start = time.time()
        # Re-enter training mode; eval_qa presumably switched to eval mode.
        model.train()
        for i, batch in enumerate(data_loader, start=1):
            c_ids, q_ids, start_positions, end_positions = batch
            # Trim trailing padding columns. Assumes pad id 0, non-negative
            # ids and right-padding, so sign-sums give per-row lengths.
            c_len = torch.sum(torch.sign(c_ids), 1)
            max_c_len = torch.max(c_len)
            c_ids = c_ids[:, :max_c_len].to(device)
            q_len = torch.sum(torch.sign(q_ids), 1)
            max_q_len = torch.max(q_len)
            q_ids = q_ids[:, :max_q_len].to(device)

            start_positions = start_positions.to(device)
            end_positions = end_positions.to(device)

            optimizer.zero_grad()
            loss = model(c_ids,
                         q_ids,
                         start_positions=start_positions,
                         end_positions=end_positions)
            loss.backward()
            avg_loss = cal_running_avg_loss(loss.item(), avg_loss)
            nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            # NOTE(review): passing an explicit step to scheduler.step() is
            # deprecated in recent PyTorch releases.
            scheduler.step(step)
            ema(model, step // args.batch_size)

            batch_size = c_ids.size(0)
            step += batch_size

            msg = "{}/{} {} - ETA : {} - qa_loss: {:.4f}" \
                .format(i, num_batches, progress_bar(i, num_batches),
                        eta(start, i, num_batches),
                        avg_loss)
            print(msg, end="\r")
        if not args.debug:
            metric_dict = eval_qa(args, model)
            f1 = metric_dict["f1"]
            em = metric_dict["exact_match"]
            print("epoch: {}, final loss: {:.4f}, F1:{:.2f}, EM:{:.2f}".format(
                epoch, avg_loss, f1, em))

            # Keep a checkpoint only when dev F1 improves.
            if f1 > best_f1:
                best_f1 = f1
                state_dict = model.state_dict()
                save_file = "{}_{:.2f}_{:.2f}".format(model_name, f1, em)
                path = os.path.join(save_dir, save_file)
                torch.save(state_dict, path)