def do_train():
    """Train a Transformer on the Multi30k dataset and checkpoint every epoch.

    Builds the data iterators and vocabularies via ``prepare_data_multi30k``,
    trains for a fixed 10 epochs on CUDA, saves ``model_{epoch}.pt`` and a
    cumulative ``result.json`` under ``./checkpoint/transformer``.

    Side effects: creates the checkpoint directory, writes model and result
    files, prints per-epoch metrics to stdout. Returns None.
    """
    train_iterator, valid_iterator, test_iterator, SRC, TGT = prepare_data_multi30k()

    # Padding indices are needed both by the model (masking) and the loss.
    src_pad_idx = SRC.vocab.stoi[SRC.pad_token]
    tgt_pad_idx = TGT.vocab.stoi[TGT.pad_token]
    src_vocab_size = len(SRC.vocab)
    tgt_vocab_size = len(TGT.vocab)

    model = Transformer(
        n_src_vocab=src_vocab_size,
        n_trg_vocab=tgt_vocab_size,
        src_pad_idx=src_pad_idx,
        trg_pad_idx=tgt_pad_idx,
        d_word_vec=256,
        d_model=256,
        d_inner=512,
        n_layer=3,
        n_head=8,
        dropout=0.1,
        n_position=200,
    )
    model.cuda()  # assumes a CUDA device is available — TODO confirm
    optimizer = Adam(model.parameters(), lr=5e-4)

    num_epoch = 10
    results = []
    model_dir = os.path.join("./checkpoint/transformer")
    # Create the checkpoint directory once up front (the original re-created
    # it on every epoch).
    os.makedirs(model_dir, exist_ok=True)
    result_path = os.path.join(model_dir, "result.json")

    for epoch in range(num_epoch):
        train_loss, train_accuracy = train_epoch(
            model, optimizer, train_iterator, tgt_pad_idx, smoothing=False
        )
        eval_loss, eval_accuracy = eval_epoch(
            model, valid_iterator, tgt_pad_idx, smoothing=False
        )

        model_path = os.path.join(model_dir, f"model_{epoch}.pt")
        torch.save(model.state_dict(), model_path)

        # FIX: accuracies were computed and printed but never persisted;
        # include them alongside the losses in the saved results.
        results.append(
            {
                "epoch": epoch,
                "train_loss": train_loss,
                "eval_loss": eval_loss,
                "train_accuracy": train_accuracy,
                "eval_accuracy": eval_accuracy,
            }
        )

        print("[TIME] --- {} --- [TIME]".format(time.ctime(time.time())))
        print("epoch: {}, train_loss: {}, eval_loss: {}".format(epoch, train_loss, eval_loss))
        print("epoch: {}, train_accuracy: {}, eval_accuracy: {}".format(epoch, train_accuracy, eval_accuracy))

        # Rewrite the cumulative results each epoch so an interrupted run
        # still leaves the metrics collected so far on disk; the final file
        # is identical to a single write after the loop.
        with open(result_path, "w", encoding="utf-8") as writer:
            json.dump(results, writer, ensure_ascii=False, indent=4)
def run(args):
    """Build the dataset and model from ``args`` and run the training loop.

    Args:
        args: namespace with at least ``device``, ``dropout``, ``checkpoint``
            (path or None), ``lr`` and ``epochs``; also forwarded wholesale
            to ``train``/``validate``.

    Side effects: writes TensorBoard events via ``SummaryWriter``, saves a
    ``models/model_{epoch}_{acc}.pth`` checkpoint each epoch, prints progress
    to stdout. Returns None.
    """
    writer = SummaryWriter()
    src, tgt, train_iterator, val_iterator = build_dataset(args)
    src_vocab_size = len(src.vocab.itos)
    tgt_vocab_size = len(tgt.vocab.itos)

    print('Instantiating model...')
    device = args.device
    model = Transformer(src_vocab_size, tgt_vocab_size, device, p_dropout=args.dropout)
    model = model.to(device)

    if args.checkpoint is not None:
        # Resume from a previously saved state dict.
        model.load_state_dict(torch.load(args.checkpoint))
    else:
        # Xavier-initialise weight matrices only; 1-D parameters
        # (biases, layer-norm gains) keep their default init.
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    print('Model instantiated!')

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)

    print('Starting training...')
    # FIX: torch.save raises if the target directory does not exist;
    # create 'models/' once before the first checkpoint is written.
    import os
    os.makedirs('models', exist_ok=True)

    for epoch in range(args.epochs):
        acc = train(model, epoch + 1, train_iterator, optimizer, src.vocab, tgt.vocab, args, writer)
        model_file = 'models/model_' + str(epoch) + '_' + str(acc) + '.pth'
        torch.save(model.state_dict(), model_file)
        print('Saved model to ' + model_file)
        validate(model, epoch + 1, val_iterator, src.vocab, tgt.vocab, args, writer)
    print('Finished training.')