def train():
    print("Loading data...")
    SRC, TGT, train, val, test = generate_dataloaders()
    devices = [0, 1, 2, 3]
    pad_idx = TGT.vocab.stoi["<blank>"]

    print("Making model...")
    model = make_model(len(SRC.vocab), len(TGT.vocab), N=6)
    model.cuda()
    criterion = LabelSmoothing(
        size=len(TGT.vocab), padding_idx=pad_idx, smoothing=0.1)
    criterion.cuda()

    BATCH_SIZE = 12000  # tokens per batch (dynamic batching via batch_size_fn)
    train_iter = BatchIterator(train, batch_size=BATCH_SIZE,
                               device=torch.device(0), repeat=False,
                               sort_key=lambda x: (len(x.src), len(x.trg)),
                               batch_size_fn=batch_size_fn, train=True)
    valid_iter = BatchIterator(val, batch_size=BATCH_SIZE,
                               device=torch.device(0), repeat=False,
                               sort_key=lambda x: (len(x.src), len(x.trg)),
                               batch_size_fn=batch_size_fn, train=False)

    model_par = nn.DataParallel(model, device_ids=devices)
    model_opt = NoamOpt(model.src_embed[0].d_model, 1, 2000,
                        torch.optim.Adam(model.parameters(), lr=0,
                                         betas=(0.9, 0.98), eps=1e-9))

    folder = get_unique_folder("./models/", "model")
    if not os.path.exists(folder):
        os.mkdir(folder)

    for epoch in tqdm(range(10)):
        # One full pass over the training data with optimizer updates.
        model_par.train()
        run_epoch((rebatch(pad_idx, b) for b in train_iter),
                  model_par,
                  MultiGPULossCompute(model.generator, criterion,
                                      devices=devices, opt=model_opt))

        # Evaluate on the validation set (no optimizer step).
        model_par.eval()
        loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter),
                         model_par,
                         MultiGPULossCompute(model.generator, criterion,
                                             devices=devices, opt=None))
        # Checkpoint the weights after every epoch.
        torch.save(model.state_dict(),
                   os.path.join(folder, "model.bin." + str(epoch)))
        print("Validation loss:", loss)

        # Greedy-decode the first validation sentence as a sanity check.
        for batch in valid_iter:
            src = batch.src.transpose(0, 1)[:1]
            src_mask = (src != SRC.vocab.stoi["<blank>"]).unsqueeze(-2)
            out = greedy_decode(model, src, src_mask, max_len=60,
                                start_symbol=TGT.vocab.stoi["<s>"])
            print("Translation:", end="\t")
            for i in range(1, out.size(1)):
                sym = TGT.vocab.itos[out[0, i]]
                if sym == "</s>":
                    break
                print(sym, end=" ")
            print()
            print("Target:", end="\t")
            for i in range(1, batch.trg.size(0)):
                sym = TGT.vocab.itos[batch.trg.data[i, 0]]
                if sym == "</s>":
                    break
                print(sym, end=" ")
            print()
            break
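
# For reference: NoamOpt (defined elsewhere in this project) is assumed to
# implement the learning-rate schedule from "Attention Is All You Need":
# linear warmup for `warmup` steps, then inverse-square-root decay. A minimal
# sketch of that rate computation under this assumption; `_noam_rate` is
# illustrative only and is not called by the training code.
def _noam_rate(step, d_model, factor=1, warmup=2000):
    """Learning rate at optimizer step `step` (step >= 1)."""
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)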
def batch_size_fn(new, count, sofar):
    """Dynamic batching: effective batch size in tokens, padding included."""
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)  # +2 for <s>, </s>
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)


if __name__ == "__main__":
    BATCH_SIZE = 12000  # tokens per batch
    parser = argparse.ArgumentParser()
    parser.add_argument('model_name')
    parser.add_argument('log_name')
    args = parser.parse_args()
    log_file = open(args.log_name, 'w+')

    print("Loading data...")
    SRC, TGT, train, val, test = generate_dataloaders("../data_processed/")
    test_iter = BatchIterator(test, batch_size=BATCH_SIZE,
                              device=torch.device(0), repeat=False,
                              sort_key=lambda x: (len(x.src), len(x.trg)),
                              batch_size_fn=batch_size_fn, train=False)

    print("Loading model...")
    model = make_model(len(SRC.vocab), len(TGT.vocab), N=6)
    model.load_state_dict(torch.load(args.model_name))
    model.cuda()
    model.eval()

    print("Generating test output...")
    log("Testing model stored at " + args.model_name + ".", log_file)
    for k, batch in tqdm(enumerate(test_iter)):