コード例 #1
0
def _print_symbols(label, symbols, eos="</s>"):
    """Print ``label``, a tab, then the symbols up to (excluding) ``eos``.

    ``symbols`` may be a lazy iterable; it is consumed only until the first
    ``eos`` symbol is seen. Symbols are separated by single spaces and the
    line is terminated with a newline.
    """
    print(label, end="\t")
    for sym in symbols:
        if sym == eos:
            break
        print(sym, end=" ")
    print()


def train():
    """Train the translation model on four GPUs and print one sample decode.

    Steps:
      1. Load vocabularies and datasets via ``generate_dataloaders()``.
      2. Build the model and label-smoothing criterion and move both to CUDA.
      3. Run 10 epochs; each epoch does a training pass, a validation pass,
         saves the model's state dict to ``<folder>/model.bin.<epoch>`` and
         prints the value ``run_epoch`` returned for the validation pass.
      4. Greedy-decode the first sentence of the first validation batch and
         print it next to its reference target.

    Takes no arguments and returns ``None``.
    """
    print("Loading data...")
    SRC, TGT, train, val, test = generate_dataloaders()

    devices = [0, 1, 2, 3]
    pad_idx = TGT.vocab.stoi["<blank>"]

    print("Making model...")
    model = make_model(len(SRC.vocab), len(TGT.vocab), N=6)
    model.cuda()
    criterion = LabelSmoothing(
        size=len(TGT.vocab), padding_idx=pad_idx, smoothing=0.1)
    criterion.cuda()

    BATCH_SIZE = 12000

    def length_key(example):
        # Sort examples by (source length, target length).
        return (len(example.src), len(example.trg))

    # Settings shared by the training and validation iterators; only the
    # dataset and the ``train`` flag differ between the two.
    iter_kwargs = dict(batch_size=BATCH_SIZE, device=torch.device(0),
                       repeat=False, sort_key=length_key,
                       batch_size_fn=batch_size_fn)
    train_iter = BatchIterator(train, train=True, **iter_kwargs)
    valid_iter = BatchIterator(val, train=False, **iter_kwargs)

    model_par = nn.DataParallel(model, device_ids=devices)
    # NOTE(review): the positional arguments 1 and 2000 are presumably the
    # Noam schedule's scale factor and warmup steps -- confirm against NoamOpt.
    model_opt = NoamOpt(model.src_embed[0].d_model, 1, 2000,
                        torch.optim.Adam(model.parameters(), lr=0,
                                         betas=(0.9, 0.98), eps=1e-9))

    folder = get_unique_folder("./models/", "model")
    # makedirs also creates a missing "./models/" parent and is a no-op if
    # the directory already exists (no check-then-create race).
    os.makedirs(folder, exist_ok=True)

    for epoch in tqdm(range(10)):
        model_par.train()
        run_epoch((rebatch(pad_idx, b) for b in train_iter),
                  model_par,
                  MultiGPULossCompute(model.generator, criterion,
                                      devices=devices, opt=model_opt))
        model_par.eval()
        loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter),
                         model_par,
                         MultiGPULossCompute(model.generator, criterion,
                                             devices=devices, opt=None))
        # Save the state dict itself: state_dict must be *called*. Passing
        # the bound method would pickle something that
        # ``model.load_state_dict(torch.load(path))`` cannot consume.
        torch.save(model.state_dict(),
                   os.path.join(folder, "model.bin." + str(epoch)))
        print(loss)

    # Show a single example: only the first validation batch is inspected.
    for batch in valid_iter:
        # batch.src is transposed and the first row kept, i.e. one sentence.
        src = batch.src.transpose(0, 1)[:1]
        src_mask = (src != SRC.vocab.stoi["<blank>"]).unsqueeze(-2)
        out = greedy_decode(model, src, src_mask,
                            max_len=60, start_symbol=TGT.vocab.stoi["<s>"])
        # Position 0 is skipped in both sequences (decoding starts from the
        # "<s>" start symbol).
        _print_symbols("Translation:",
                       (TGT.vocab.itos[idx] for idx in out[0, 1:]))
        _print_symbols("Target:",
                       (TGT.vocab.itos[idx] for idx in batch.trg.data[1:, 0]))
        break
コード例 #2
0
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)


if __name__ == "__main__":
    BATCH_SIZE = 12000
    parser = argparse.ArgumentParser()
    parser.add_argument('model_name')
    parser.add_argument('log_name')

    args = parser.parse_args()
    model_file = open(args.model_name, 'rb')
    log_file = open(args.log_name, 'w+')

    print("Loading data...")
    SRC, TGT, train, val, test = generate_dataloaders("../data_processed/")
    test_iter = BatchIterator(test,
                              batch_size=BATCH_SIZE,
                              device=torch.device(0),
                              repeat=False,
                              sort_key=lambda x: (len(x.src), len(x.trg)),
                              batch_size_fn=batch_size_fn,
                              train=False)
    print("Loading model...")
    model = make_model(len(SRC.vocab), len(TGT.vocab), N=6)
    model.load_state_dict(torch.load(args.model_name))
    model.cuda()
    model.eval()
    print("Generating test output...")
    log("Testing model stored at " + args.model_name + ".", log_file)
    for k, batch in tqdm(enumerate(test_iter)):