Example 1

# NOTE: only the standard imports are shown; project-level helpers such as
# check_params, eval_params, check_vocab_embedding, Seq2Seq, pad_token,
# UNK_TOKEN_ID and PAD_TOKEN_ID are assumed to come from the surrounding
# package.
import os

import torch
from torch import nn

def do_eval(user_sentence):
    check_params(eval_params)
    device = 'cpu'

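    # Instantiate the source/target tokenizers and read the checkpoint path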
    src_tokenizer = eval_params.src_tokenizer()
    tgt_tokenizer = eval_params.tgt_tokenizer()
    checkpoint_path = eval_params.checkpoint_path

    base_dir = os.getcwd()
    dataset_dir = os.path.join(base_dir, 'dataset')

    src_vocab_file_path = os.path.join(dataset_dir,
                                       eval_params.src_vocab_filename)
    tgt_vocab_file_path = os.path.join(dataset_dir,
                                       eval_params.tgt_vocab_filename)
    src_word_embedding_file_path = os.path.join(
        dataset_dir, eval_params.src_word_embedding_filename)
    tgt_word_embedding_file_path = os.path.join(
        dataset_dir, eval_params.tgt_word_embedding_filename)

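    # Load the vocabularies (word2id/id2word) and pretrained embeddings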
    src_word2id, src_id2word, src_embedding = check_vocab_embedding(
        src_vocab_file_path, src_word_embedding_file_path)
    tgt_word2id, tgt_id2word, tgt_embedding = check_vocab_embedding(
        tgt_vocab_file_path, tgt_word_embedding_file_path)

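    # Build encoder and decoder; vocab sizes are only known after loading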
    encoder_params.vocab_size = len(src_word2id)
    encoder_params.device = device
    encoder = eval_params.encoder(encoder_params)

    decoder_params.vocab_size = len(tgt_word2id)
    decoder_params.device = device
    decoder = eval_params.decoder(decoder_params)

    model: nn.Module = Seq2Seq(encoder, decoder)

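    # Restore the trained weights and switch to inference mode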
    checkpoint = torch.load(os.path.join(base_dir, checkpoint_path),
                            map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    src_max_length = encoder_params.max_seq_len
    src_seqs = user_sentence
    print(f'Input sequence: {src_seqs}')
    with torch.no_grad():
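        # Split the raw sentence into word tokens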
        src_tokenized = src_tokenizer.tokenize(src_seqs)
        print(src_tokenized)

        # Map words to ids, falling back to UNK for out-of-vocabulary words
        src_tokenized = [
            src_word2id.get(word, UNK_TOKEN_ID) for word in src_tokenized
        ]

        # Capture the sequence length before padding (measuring it after
        # padding would always yield src_max_length)
        src_length = torch.tensor([len(src_tokenized)])
        pad_token(src_tokenized, src_max_length)
        print(src_tokenized)

        src_padded_tokens = torch.tensor(src_tokenized,
                                         dtype=torch.long,
                                         device=device).unsqueeze(0)
        logits, preds = model(src_padded_tokens, src_length, None, None)

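        # Map predicted ids back to target words, stopping at the first PAD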
        sentence = []
        for token in preds:
            token = token.item()
            if token == PAD_TOKEN_ID:
                break
            sentence.append(tgt_id2word[token].strip())
        print(f'Output sequence ({len(sentence)} tokens): {sentence}')
    return sentence
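
The helper pad_token is not defined in this example; it is called only for
its side effect, so it must mutate the id list in place. A minimal sketch
consistent with that call site, assuming PAD_TOKEN_ID is the project-wide
padding id:

def pad_token(token_ids, max_length):
    """Pad (or truncate) a list of token ids in place to max_length."""
    if len(token_ids) > max_length:
        # Truncate inputs longer than the encoder accepts (an assumption;
        # the call site above only demonstrates the padding case)
        del token_ids[max_length:]
    else:
        token_ids.extend([PAD_TOKEN_ID] * (max_length - len(token_ids)))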
Example 2

# NOTE: as in Example 1, project-level helpers (train_params, eval_params,
# ParallelTextDataSet, GruEncoder, Transformer, Seq2Seq, train_model,
# ensure_vocab_embedding, check_vocab_embedding, get_device,
# get_checkpoint_dir_path, ...) are assumed to come from the surrounding
# package.
import os

import torch
from torch import nn
from torch.utils.data import DataLoader

def main():
    check_params(train_params)

    device = get_device()
    print(f'  Available device is {device}')

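    # Instantiate tokenizers for the source and target languages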
    src_tokenizer = train_params.src_tokenizer()
    tgt_tokenizer = train_params.tgt_tokenizer()

    base_dir = os.getcwd()
    dataset_dir = os.path.join(base_dir, 'dataset')

    src_vocab_file_path = os.path.join(dataset_dir,
                                       train_params.src_vocab_filename)
    tgt_vocab_file_path = os.path.join(dataset_dir,
                                       train_params.tgt_vocab_filename)
    src_word_embedding_file_path = os.path.join(
        dataset_dir, train_params.src_word_embedding_filename)
    tgt_word_embedding_file_path = os.path.join(
        dataset_dir, train_params.tgt_word_embedding_filename)
    src_corpus_file_path = os.path.join(dataset_dir,
                                        train_params.src_corpus_filename)
    tgt_corpus_file_path = os.path.join(dataset_dir,
                                        train_params.tgt_corpus_filename)

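    # Ensure vocabularies and pretrained embedding matrices exist for both
    # languages, building them from the corpus files if needed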
    src_word2id, src_id2word, src_embed_matrix = ensure_vocab_embedding(
        src_tokenizer, src_vocab_file_path, src_word_embedding_file_path,
        src_corpus_file_path, encoder_params.embedding_dim, "Source")

    tgt_word2id, tgt_id2word, tgt_embed_matrix = ensure_vocab_embedding(
        tgt_tokenizer, tgt_vocab_file_path, tgt_word_embedding_file_path,
        tgt_corpus_file_path, decoder_params.embedding_dim, "Target")

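    # Wrap the parallel corpus in a dataset and a shuffled training loader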
    dataset = ParallelTextDataSet(src_tokenizer, tgt_tokenizer,
                                  src_corpus_file_path, tgt_corpus_file_path,
                                  encoder_params.max_seq_len,
                                  decoder_params.max_seq_len, src_word2id,
                                  tgt_word2id)
    data_loader = DataLoader(dataset,
                             batch_size=train_params.batch_size,
                             shuffle=True,
                             collate_fn=dataset.collate_func)

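    # Vocab sizes become known only after the vocabularies are built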
    encoder_params.vocab_size = len(src_word2id)
    encoder_params.device = device

    decoder_params.vocab_size = len(tgt_word2id)
    decoder_params.device = device

    ## Evaluation dataset

    eval_src_tokenizer = eval_params.src_tokenizer()
    eval_tgt_tokenizer = eval_params.tgt_tokenizer()

    eval_src_vocab_file_path = os.path.join(dataset_dir,
                                            eval_params.src_vocab_filename)
    eval_tgt_vocab_file_path = os.path.join(dataset_dir,
                                            eval_params.tgt_vocab_filename)
    eval_src_word_embedding_file_path = os.path.join(
        dataset_dir, eval_params.src_word_embedding_filename)
    eval_tgt_word_embedding_file_path = os.path.join(
        dataset_dir, eval_params.tgt_word_embedding_filename)
    eval_src_corpus_file_path = os.path.join(dataset_dir,
                                             eval_params.src_corpus_filename)
    eval_tgt_corpus_file_path = os.path.join(dataset_dir,
                                             eval_params.tgt_corpus_filename)

    eval_src_word2id, eval_src_id2word, eval_src_embedding = check_vocab_embedding(
        eval_src_vocab_file_path, eval_src_word_embedding_file_path)
    eval_tgt_word2id, eval_tgt_id2word, eval_tgt_embedding = check_vocab_embedding(
        eval_tgt_vocab_file_path, eval_tgt_word_embedding_file_path)

    eval_dataset = ParallelTextDataSet(eval_src_tokenizer, eval_tgt_tokenizer,
                                       eval_src_corpus_file_path,
                                       eval_tgt_corpus_file_path,
                                       encoder_params.max_seq_len,
                                       decoder_params.max_seq_len,
                                       eval_src_word2id, eval_tgt_word2id)
    eval_data_loader = DataLoader(eval_dataset,
                                  batch_size=eval_params.batch_size,
                                  collate_fn=eval_dataset.collate_func)

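    # Build the model architecture selected in train_params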
    if train_params.encoder == GruEncoder:
        encoder = train_params.encoder(encoder_params)
        # Freeze word embedding weight
        encoder.init_embedding_weight(src_embed_matrix)
        decoder = train_params.decoder(decoder_params)
        # Freeze word embedding weight
        decoder.init_embedding_weight(tgt_embed_matrix)
        model: nn.Module = Seq2Seq(encoder, decoder)

    elif train_params.encoder == Transformer:
        encoder = train_params.encoder(encoder_params, decoder_params)
        # Freeze word embedding weight
        encoder.init_src_embedding_weight(src_embed_matrix)
        decoder = train_params.decoder(decoder_params, decoder_params)
        # Freeze word embedding weight
        decoder.init_tgt_embedding_weight(tgt_embed_matrix)
        # NOTE: Transformer is constructed from the params directly; the
        # encoder and decoder objects initialized above are never passed to it
        model: nn.Module = Transformer(encoder_params, decoder_params)
    else:
        raise ValueError(f'Unsupported encoder: {train_params.encoder}')

    model.to(device)

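    # Token-level cross-entropy loss with the Adam optimizer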
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=train_params.learning_rate)

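    # Training loop: checkpoint whenever validation loss improves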
    best_val_loss = float('inf')
    for epoch in range(train_params.n_epochs):
        avg_loss, val_loss = train_model(model, optimizer, loss_func,
                                         data_loader, eval_data_loader,
                                         eval_tgt_id2word, device,
                                         train_params, encoder_params,
                                         decoder_params, epoch + 1)

        if val_loss < best_val_loss:
            save_dir_path = os.path.join(train_params.model_save_directory,
                                         get_checkpoint_dir_path(epoch + 1))
            if not os.path.exists(save_dir_path):
                os.makedirs(save_dir_path)

            print("[Best model Save] train_loss: {}, val_loss: {}".format(
                avg_loss, val_loss))
            # Save after converting tensor dtypes so the checkpoint also
            # loads on CPU?
            # save checkpoint for best model
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss
                }, os.path.join(save_dir_path, 'checkpoint.tar'))

            best_val_loss = val_loss
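
Because the checkpoint dictionary above also stores optimizer state, it can
be used to resume training, not just for the evaluation path shown in
Example 1 (which restores model_state_dict only). A minimal resume sketch,
assuming model and optimizer were rebuilt exactly as in main():

checkpoint = torch.load(os.path.join(save_dir_path, 'checkpoint.tar'),
                        map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # continue after the last saved epoch
model.train()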