def do_eval(user_sentence):
    check_params(eval_params)

    # Evaluation runs on CPU so a saved checkpoint can always be loaded.
    device = 'cpu'
    src_tokenizer = eval_params.src_tokenizer()
    tgt_tokenizer = eval_params.tgt_tokenizer()
    checkpoint_path = eval_params.checkpoint_path

    base_dir = os.getcwd()
    dataset_dir = os.path.join(base_dir, 'dataset')
    src_vocab_file_path = os.path.join(dataset_dir, eval_params.src_vocab_filename)
    tgt_vocab_file_path = os.path.join(dataset_dir, eval_params.tgt_vocab_filename)
    src_word_embedding_file_path = os.path.join(
        dataset_dir, eval_params.src_word_embedding_filename)
    tgt_word_embedding_file_path = os.path.join(
        dataset_dir, eval_params.tgt_word_embedding_filename)

    src_word2id, src_id2word, src_embedding = check_vocab_embedding(
        src_vocab_file_path, src_word_embedding_file_path)
    tgt_word2id, tgt_id2word, tgt_embedding = check_vocab_embedding(
        tgt_vocab_file_path, tgt_word_embedding_file_path)

    encoder_params.vocab_size = len(src_word2id)
    encoder_params.device = device
    encoder = eval_params.encoder(encoder_params)

    decoder_params.vocab_size = len(tgt_word2id)
    decoder_params.device = device
    decoder = eval_params.decoder(decoder_params)

    model: nn.Module = Seq2Seq(encoder, decoder)
    checkpoint = torch.load(os.path.join(base_dir, checkpoint_path))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    src_max_length = encoder_params.max_seq_len
    src_seqs = user_sentence
    print(f'Input sequence: {src_seqs}')

    with torch.no_grad():
        # Tokenize and map each word to its id, falling back to UNK for
        # out-of-vocabulary words.
        src_tokenized = [
            src_word2id.get(word, UNK_TOKEN_ID)
            for word in src_tokenizer.tokenize(src_seqs)
        ]
        # Record the real sequence length before padding; measuring after
        # padding would always report max_seq_len.
        src_length = torch.tensor(len(src_tokenized)).unsqueeze(0)
        pad_token(src_tokenized, src_max_length)

        src_padded_tokens = torch.tensor(src_tokenized, dtype=torch.long,
                                         device=device).unsqueeze(0)
        logits, preds = model(src_padded_tokens, src_length, None, None)

        # Read out the greedy prediction, stopping at the first PAD token.
        sentence = []
        for token in preds:
            token = token.item()
            if token == PAD_TOKEN_ID:
                break
            sentence.append(tgt_id2word[token].strip())

    print(sentence)
    return sentence
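
# Example usage (a sketch: assumes eval_params points at a trained checkpoint
# and at the vocab/embedding files under ./dataset; the sentence below is a
# stand-in for real source-language input):
#
#     tokens = do_eval('this is an example source sentence')
#     print(' '.join(tokens))
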
def main():
    check_params(train_params)

    device = get_device()
    print(f'Available device is {device}')

    src_tokenizer = train_params.src_tokenizer()
    tgt_tokenizer = train_params.tgt_tokenizer()

    base_dir = os.getcwd()
    dataset_dir = os.path.join(base_dir, 'dataset')
    src_vocab_file_path = os.path.join(dataset_dir, train_params.src_vocab_filename)
    tgt_vocab_file_path = os.path.join(dataset_dir, train_params.tgt_vocab_filename)
    src_word_embedding_file_path = os.path.join(
        dataset_dir, train_params.src_word_embedding_filename)
    tgt_word_embedding_file_path = os.path.join(
        dataset_dir, train_params.tgt_word_embedding_filename)
    src_corpus_file_path = os.path.join(dataset_dir, train_params.src_corpus_filename)
    tgt_corpus_file_path = os.path.join(dataset_dir, train_params.tgt_corpus_filename)

    # Build (or load) vocabularies and pretrained word-embedding matrices.
    src_word2id, src_id2word, src_embed_matrix = ensure_vocab_embedding(
        src_tokenizer, src_vocab_file_path, src_word_embedding_file_path,
        src_corpus_file_path, encoder_params.embedding_dim, "Source")
    tgt_word2id, tgt_id2word, tgt_embed_matrix = ensure_vocab_embedding(
        tgt_tokenizer, tgt_vocab_file_path, tgt_word_embedding_file_path,
        tgt_corpus_file_path, decoder_params.embedding_dim, "Target")

    dataset = ParallelTextDataSet(src_tokenizer, tgt_tokenizer,
                                  src_corpus_file_path, tgt_corpus_file_path,
                                  encoder_params.max_seq_len,
                                  decoder_params.max_seq_len,
                                  src_word2id, tgt_word2id)
    data_loader = DataLoader(dataset, batch_size=train_params.batch_size,
                             shuffle=True, collate_fn=dataset.collate_func)

    encoder_params.vocab_size = len(src_word2id)
    encoder_params.device = device
    decoder_params.vocab_size = len(tgt_word2id)
    decoder_params.device = device

    # Evaluation dataset
    eval_src_tokenizer = eval_params.src_tokenizer()
    eval_tgt_tokenizer = eval_params.tgt_tokenizer()
    eval_src_vocab_file_path = os.path.join(dataset_dir, eval_params.src_vocab_filename)
    eval_tgt_vocab_file_path = os.path.join(dataset_dir, eval_params.tgt_vocab_filename)
    eval_src_word_embedding_file_path = os.path.join(
        dataset_dir, eval_params.src_word_embedding_filename)
    eval_tgt_word_embedding_file_path = os.path.join(
        dataset_dir, eval_params.tgt_word_embedding_filename)
    eval_src_corpus_file_path = os.path.join(dataset_dir, eval_params.src_corpus_filename)
    eval_tgt_corpus_file_path = os.path.join(dataset_dir, eval_params.tgt_corpus_filename)

    eval_src_word2id, eval_src_id2word, eval_src_embedding = check_vocab_embedding(
        eval_src_vocab_file_path, eval_src_word_embedding_file_path)
    eval_tgt_word2id, eval_tgt_id2word, eval_tgt_embedding = check_vocab_embedding(
        eval_tgt_vocab_file_path, eval_tgt_word_embedding_file_path)

    eval_dataset = ParallelTextDataSet(eval_src_tokenizer, eval_tgt_tokenizer,
                                       eval_src_corpus_file_path,
                                       eval_tgt_corpus_file_path,
                                       encoder_params.max_seq_len,
                                       decoder_params.max_seq_len,
                                       eval_src_word2id, eval_tgt_word2id)
    eval_data_loader = DataLoader(eval_dataset,
                                  batch_size=eval_params.batch_size,
                                  collate_fn=eval_dataset.collate_func)

    if train_params.encoder == GruEncoder:
        encoder = train_params.encoder(encoder_params)
        # Freeze word embedding weight
        encoder.init_embedding_weight(src_embed_matrix)
        decoder = train_params.decoder(decoder_params)
        # Freeze word embedding weight
        decoder.init_embedding_weight(tgt_embed_matrix)
        model: nn.Module = Seq2Seq(encoder, decoder)
    elif train_params.encoder == Transformer:
        encoder = train_params.encoder(encoder_params, decoder_params)
        # Freeze word embedding weight
        encoder.init_src_embedding_weight(src_embed_matrix)
        decoder = train_params.decoder(encoder_params, decoder_params)
        # Freeze word embedding weight
        decoder.init_tgt_embedding_weight(tgt_embed_matrix)
        model: nn.Module = Transformer(encoder_params, decoder_params)
    else:
        raise ValueError(f'Unsupported encoder type: {train_params.encoder}')

    model.to(device)

    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=train_params.learning_rate)

    avg_loss = 0.
    best_val_loss = 1e+10
    for epoch in range(train_params.n_epochs):
        avg_loss, val_loss = train_model(model, optimizer, loss_func,
                                         data_loader, eval_data_loader,
                                         eval_tgt_id2word, device,
                                         train_params, encoder_params,
                                         decoder_params, epoch + 1)
        if val_loss < best_val_loss:
            save_dir_path = os.path.join(train_params.model_save_directory,
                                         get_checkpoint_dir_path(epoch + 1))
            if not os.path.exists(save_dir_path):
                os.makedirs(save_dir_path)
            print("[Best model save] train_loss: {}, val_loss: {}".format(
                avg_loss, val_loss))
            # Convert dtypes before saving so it can also run on CPU?
            # Save checkpoint for the best model so far.
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss
                }, os.path.join(save_dir_path, 'checkpoint.tar'))
            best_val_loss = val_loss
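
# Assumed entry point: run training when this module is executed as a script.
if __name__ == '__main__':
    main()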