def do_predict():
    train_iterator, valid_iterator, test_iterator, SRC, TGT = prepare_data_multi30k()
    src_pad_idx = SRC.vocab.stoi[SRC.pad_token]
    tgt_pad_idx = TGT.vocab.stoi[TGT.pad_token]
    src_vocab_size = len(SRC.vocab)
    tgt_vocab_size = len(TGT.vocab)

    model = Transformer(n_src_vocab=src_vocab_size, n_trg_vocab=tgt_vocab_size,
                        src_pad_idx=src_pad_idx, trg_pad_idx=tgt_pad_idx,
                        d_word_vec=256, d_model=256, d_inner=512,
                        n_layer=3, n_head=8, dropout=0.1, n_position=200)
    model.cuda()

    # load the checkpoint saved after the last training epoch
    model_dir = "./checkpoint/transformer"
    model_path = os.path.join(model_dir, "model_9.pt")
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model.eval()

    pre_sents = []
    gth_sents = []
    for idx, batch in enumerate(test_iterator):
        if idx % 10 == 0:
            print("[TIME] --- time: {} --- [TIME]".format(time.ctime(time.time())))

        # src_seq: [batch_size, seq_len]
        # tgt_seq: [batch_size, seq_len]
        src_seq, src_len = batch.src
        tgt_seq, tgt_len = batch.trg
        batch_size = src_seq.size(0)

        # translate one source sentence at a time
        pre_tokens = []
        with torch.no_grad():
            for i in range(batch_size):
                tokens = translate_tokens(src_seq[i], SRC, TGT, model, max_len=32)
                pre_tokens.append(tokens)

        # gth_tokens: [batch_size, seq_len]
        gth_tokens = tgt_seq.cpu().detach().numpy().tolist()
        for tokens, gth_ids in zip(pre_tokens, gth_tokens):
            gth = [TGT.vocab.itos[token_id] for token_id in gth_ids]
            pre_sents.append(" ".join(tokens))
            gth_sents.append(" ".join(gth))

    # dump predictions and references for later inspection or scoring
    pre_path = os.path.join(model_dir, "pre.json")
    gth_path = os.path.join(model_dir, "gth.json")
    with open(pre_path, "w", encoding="utf-8") as writer:
        json.dump(pre_sents, writer, ensure_ascii=False, indent=4)
    with open(gth_path, "w", encoding="utf-8") as writer:
        json.dump(gth_sents, writer, ensure_ascii=False, indent=4)
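do_predict leans on a translate_tokens helper to decode one source sentence at a time. If that helper is not defined elsewhere in the post, the sketch below shows one plausible greedy implementation. It assumes the TGT field was built with init_token and eos_token (e.g. "<sos>" / "<eos>") and that model(src_seq, trg_seq) returns batch-first logits of shape [batch_size, trg_len, vocab_size]; those details are assumptions, not taken from the original code.

import torch

def translate_tokens(src_ids, SRC, TGT, model, max_len=32):
    # Greedy decoding sketch for a single source sentence (hypothetical helper).
    device = next(model.parameters()).device
    sos_idx = TGT.vocab.stoi[TGT.init_token]
    eos_idx = TGT.vocab.stoi[TGT.eos_token]

    src_seq = src_ids.unsqueeze(0).to(device)        # [1, src_len]
    trg_ids = [sos_idx]
    for _ in range(max_len):
        trg_seq = torch.tensor(trg_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, cur_len]
        logits = model(src_seq, trg_seq)              # assumed shape: [1, cur_len, vocab_size]
        next_id = logits[0, -1].argmax(dim=-1).item() # highest-scoring token at the last position
        if next_id == eos_idx:
            break
        trg_ids.append(next_id)
    # drop the leading <sos> and map ids back to surface tokens
    return [TGT.vocab.itos[i] for i in trg_ids[1:]]

Because do_predict already wraps the calls in torch.no_grad(), the helper itself does not need to disable gradients again.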
def do_train():
    train_iterator, valid_iterator, test_iterator, SRC, TGT = prepare_data_multi30k()
    src_pad_idx = SRC.vocab.stoi[SRC.pad_token]
    tgt_pad_idx = TGT.vocab.stoi[TGT.pad_token]
    src_vocab_size = len(SRC.vocab)
    tgt_vocab_size = len(TGT.vocab)

    model = Transformer(n_src_vocab=src_vocab_size, n_trg_vocab=tgt_vocab_size,
                        src_pad_idx=src_pad_idx, trg_pad_idx=tgt_pad_idx,
                        d_word_vec=256, d_model=256, d_inner=512,
                        n_layer=3, n_head=8, dropout=0.1, n_position=200)
    model.cuda()

    optimizer = Adam(model.parameters(), lr=5e-4)

    num_epoch = 10
    results = []
    model_dir = "./checkpoint/transformer"
    os.makedirs(model_dir, exist_ok=True)
    for epoch in range(num_epoch):
        train_loss, train_accuracy = train_epoch(model, optimizer, train_iterator, tgt_pad_idx, smoothing=False)
        eval_loss, eval_accuracy = eval_epoch(model, valid_iterator, tgt_pad_idx, smoothing=False)

        # save a checkpoint after every epoch
        model_path = os.path.join(model_dir, f"model_{epoch}.pt")
        torch.save(model.state_dict(), model_path)

        results.append({"epoch": epoch, "train_loss": train_loss, "eval_loss": eval_loss})
        print("[TIME] --- {} --- [TIME]".format(time.ctime(time.time())))
        print("epoch: {}, train_loss: {}, eval_loss: {}".format(epoch, train_loss, eval_loss))
        print("epoch: {}, train_accuracy: {}, eval_accuracy: {}".format(epoch, train_accuracy, eval_accuracy))

    result_path = os.path.join(model_dir, "result.json")
    with open(result_path, "w", encoding="utf-8") as writer:
        json.dump(results, writer, ensure_ascii=False, indent=4)
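train_epoch and eval_epoch are the remaining pieces referenced above. If they are not defined elsewhere in the post, the following sketch shows one way train_epoch can be implemented: teacher-forced decoding with masked cross-entropy over non-pad target positions and optional label smoothing. The function name and argument order mirror the call site, but the batch-first shapes and the use of F.cross_entropy are assumptions, not the post's actual implementation.

import torch
import torch.nn.functional as F

def train_epoch(model, optimizer, iterator, tgt_pad_idx, smoothing=False):
    # One pass over the training set; returns mean batch loss and token-level
    # accuracy over non-pad positions (sketch under the assumptions stated above).
    model.train()
    total_loss, n_correct, n_tokens = 0.0, 0, 0
    for batch in iterator:
        src_seq, _ = batch.src                       # [batch_size, src_len]
        tgt_seq, _ = batch.trg                       # [batch_size, tgt_len]
        src_seq, tgt_seq = src_seq.cuda(), tgt_seq.cuda()

        trg_input = tgt_seq[:, :-1]                  # teacher forcing: feed all but the last token
        gold = tgt_seq[:, 1:].contiguous().view(-1)  # predict all but the leading <sos>

        optimizer.zero_grad()
        logits = model(src_seq, trg_input)           # assumed shape: [batch_size, tgt_len - 1, vocab_size]
        logits = logits.view(-1, logits.size(-1))

        # label_smoothing requires PyTorch >= 1.10
        loss = F.cross_entropy(logits, gold, ignore_index=tgt_pad_idx,
                               label_smoothing=0.1 if smoothing else 0.0)
        loss.backward()
        optimizer.step()

        pred = logits.argmax(dim=-1)
        non_pad = gold.ne(tgt_pad_idx)
        n_correct += pred.eq(gold).masked_select(non_pad).sum().item()
        n_tokens += non_pad.sum().item()
        total_loss += loss.item()

    return total_loss / len(iterator), n_correct / n_tokens

eval_epoch would be the same loop without the backward pass and optimizer step, run under model.eval() and torch.no_grad(). With both entry points in place, do_train() produces the per-epoch checkpoints that do_predict() later loads (model_9.pt being the one saved after the final epoch).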