Example #1
import pickle

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import BertConfig, BertModel, BertTokenizer

# Project-local helpers; the module paths below are assumptions, not part of
# the original snippet.
from utils import (dump_config, get_dd_corpus, get_uttr_token, save_model,
                   set_random_seed, write2tensorboard)
from model import BertSelect
from dataset import SelectionDataset


def main(args):
    set_random_seed(args.random_seed)

    dump_config(args)
    device = torch.device("cuda")

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

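    # Register the utterance-separator token and "[NOTA]" ("none of the
    # above") so the tokenizer treats them as atomic special tokens.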
    UTTR_TOKEN = get_uttr_token()
    special_tokens_dict = {"additional_special_tokens": [UTTR_TOKEN, "[NOTA]"]}
    tokenizer.add_special_tokens(special_tokens_dict)

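    # Use pretrained BERT weights unless random initialization is requested;
    # the embedding matrix is resized to cover the newly added special tokens.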
    if not args.random_initialization:
        bert = BertModel.from_pretrained("bert-base-uncased")
    else:
        bert = BertModel(BertConfig())
    bert.resize_token_embeddings(len(tokenizer))
    model = BertSelect(bert)
    model = torch.nn.DataParallel(model)
    model.to(device)

    raw_dd_train = get_dd_corpus("train")
    raw_dd_dev = get_dd_corpus("validation")

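    # Optionally load pre-computed corrupted-context dumps (the "uw_unk"
    # naming suggests unknown-word corruption) for the requested ratio.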
    print("Load begin!")
    if args.uw_unk_ratio != 0.0:
        with open(args.uw_unk_dump_fname.format(args.uw_unk_ratio, "train"),
                  "rb") as f:
            unk_train_dump = pickle.load(f)
        with open(args.uw_unk_dump_fname.format(args.uw_unk_ratio, "dev"),
                  "rb") as f:
            unk_dev_dump = pickle.load(f)
    else:
        unk_train_dump, unk_dev_dump = None, None

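    # Build the selection datasets; 300 is presumably the maximum sequence
    # length, and the .pck paths cache the tokenized text/tensor candidates.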
    train_dataset = SelectionDataset(
        raw_dd_train,
        tokenizer,
        "train",
        300,
        args.retrieval_candidate_num,
        UTTR_TOKEN,
        "./data/selection/text_cand{}".format(args.retrieval_candidate_num) +
        "_{}.pck",
        "./data/selection/tensor_cand{}".format(args.retrieval_candidate_num) +
        "_{}.pck",
        corrupted_context_dataset=unk_train_dump,
    )

    dev_dataset = SelectionDataset(
        raw_dd_dev,
        tokenizer,
        "dev",
        300,
        args.retrieval_candidate_num,
        UTTR_TOKEN,
        "./data/selection/text_cand{}".format(args.retrieval_candidate_num) +
        "_{}.pck",
        "./data/selection/tensor_cand{}".format(args.retrieval_candidate_num) +
        "_{}.pck",
        corrupted_context_dataset=unk_dev_dump,
    )
    print("Load end!")

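    # drop_last=True keeps every batch at exactly args.batch_size examples.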
    trainloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=args.batch_size,
        drop_last=True,
    )
    validloader = DataLoader(dev_dataset,
                             batch_size=args.batch_size,
                             drop_last=True)
    """
    Training
    """
    crossentropy = CrossEntropyLoss()
    optimizer = Adam(
        model.parameters(),
        lr=args.lr,
    )
    writer = SummaryWriter(args.board_path)

    save_model(model, "begin", args.model_path)
    global_step = 0
    for epoch in range(args.epoch):
        print("Epoch {}".format(epoch))
        model.train()
        for step, batch in enumerate(tqdm(trainloader)):
            optimizer.zero_grad()
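            # The collated batch is a flat sequence: the first
            # retrieval_candidate_num entries are input-id tensors, the next
            # retrieval_candidate_num are attention masks, and the last is
            # the index of the gold candidate.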
            ids_list, mask_list, label = (
                batch[:args.retrieval_candidate_num],
                batch[args.retrieval_candidate_num:2 *
                      args.retrieval_candidate_num],
                batch[2 * args.retrieval_candidate_num],
            )
            label = label.to(device)
            bs = label.shape[0]
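            # Flatten to (bs * num_candidates, 300) so each context-candidate
            # pair is scored by BERT independently.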
            ids_list = (torch.cat(ids_list,
                                  1).reshape(bs * args.retrieval_candidate_num,
                                             300).to(device))
            mask_list = (torch.cat(mask_list, 1).reshape(
                bs * args.retrieval_candidate_num, 300).to(device))

            output = model(ids_list, mask_list)
            output = output.reshape(bs, -1)
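            # With scores reshaped to (bs, num_candidates), cross-entropy
            # trains the model to rank the gold candidate highest.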
            loss = crossentropy(output, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            write2tensorboard(writer, {"loss": loss}, "train", global_step)
            global_step += 1

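        # Per-epoch validation: average dev loss. The try/except below means
        # a validation failure is printed but does not abort training.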
        model.eval()
        loss_list = []
        try:
            with torch.no_grad():
                for step, batch in enumerate(tqdm(validloader)):
                    ids_list, mask_list, label = (
                        batch[:args.retrieval_candidate_num],
                        batch[args.retrieval_candidate_num:2 *
                              args.retrieval_candidate_num],
                        batch[2 * args.retrieval_candidate_num],
                    )
                    label = label.to(device)
                    bs = label.shape[0]
                    ids_list = (torch.cat(ids_list, 1).reshape(
                        bs * args.retrieval_candidate_num, 300).to(device))
                    mask_list = (torch.cat(mask_list, 1).reshape(
                        bs * args.retrieval_candidate_num, 300).to(device))
                    output = model(ids_list, mask_list)
                    output = output.reshape(bs, -1)
                    loss = crossentropy(output, label)
                    loss_list.append(loss.cpu().detach().numpy())
                    write2tensorboard(writer, {"loss": loss}, "valid",
                                      global_step)
                final_loss = sum(loss_list) / len(loss_list)
                write2tensorboard(writer, {"loss": final_loss}, "valid",
                                  global_step)
        except Exception as err:
            print(err)
        save_model(model, epoch, args.model_path)
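

# A minimal, hypothetical command-line entry point for this script. The
# argument names are inferred from how `args` is used above; every default
# value here is illustrative, not taken from the original code.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--random_seed", type=int, default=42)
    parser.add_argument("--random_initialization", action="store_true")
    parser.add_argument("--uw_unk_ratio", type=float, default=0.0)
    parser.add_argument("--uw_unk_dump_fname", type=str,
                        default="./data/unk_dump_{}_{}.pck")
    parser.add_argument("--retrieval_candidate_num", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--epoch", type=int, default=3)
    parser.add_argument("--board_path", type=str, default="./board")
    parser.add_argument("--model_path", type=str, default="./checkpoints")
    main(parser.parse_args())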