def test(opts):
    source_vocab = vocabs.load_vocabs_from_file(opts.source_vocab)
    target_vocab = vocabs.load_vocabs_from_file(opts.target_vocab)

    test_dataset = Seq2SeqDataset(opts.testing_dir, source_vocab, target_vocab,
                                  opts.source_lang, opts.target_lang)
    test_dataloader = Seq2SeqDataLoader(
        test_dataset,
        test_dataset.source_pad_id,
        test_dataset.target_pad_id,
        batch_first=True,
        batch_size=opts.batch_size,
        shuffle=False,  # evaluation order does not matter; no need to shuffle
        pin_memory=(opts.device.type == "cuda"),
        num_workers=4,
    )

    model = helper.build_model(
        opts,
        test_dataset.source_vocab_size,
        test_dataset.target_vocab_size,
        test_dataset.source_pad_id,
        test_dataset.target_sos,
        test_dataset.target_eos,
        test_dataset.target_pad_id,
        opts.device,
    )
    model.load_state_dict(
        torch.load(opts.model_path, map_location=opts.device))
    model.eval()

    # The loss function
    loss_function = torch.nn.CrossEntropyLoss(
        ignore_index=test_dataset.target_pad_id)

    # Evaluate the model
    test_loss = evaluate_model_by_loss_function(model, loss_function,
                                                test_dataloader, opts.device)
    test_bleu = evaluate_model_by_bleu_score(
        model,
        test_dataloader,
        opts.device,
        test_dataset.target_sos,
        test_dataset.target_eos,
        test_dataset.target_pad_id,
        target_vocab.get_id2word(),
    )

    print(f"Test loss={test_loss}, Test Bleu={test_bleu}")
def test(opts):
    torch.manual_seed(opts.seed)

    vocab = vocabs.load_vocabs_from_file(opts.vocab)

    test_dataset = Seq2VecDataset(opts.testing_dir, vocab, opts.langs)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opts.batch_size,
        shuffle=False,  # evaluation order does not matter; no need to shuffle
        pin_memory=(opts.device.type == "cuda"),
        num_workers=4,
    )

    # A binary classifier (2 output classes) with hidden layers of 100 and 25 neurons
    model = Seq2VecNN(len(vocab.get_word2id()),
                      2,
                      num_neurons_per_layer=[100, 25])
    model = model.to(opts.device)
    model.load_state_dict(
        torch.load(opts.model_path, map_location=opts.device))
    model.eval()

    # The loss function
    loss_function = torch.nn.CrossEntropyLoss()

    # Evaluate the model
    test_loss, test_accuracy = evaluate_model(model, loss_function,
                                              test_dataloader, opts.device)

    print(f"Test loss={test_loss}, Test Accuracy={test_accuracy}")
    def __init__(self, vocabs_filepath, saved_model_filepath):
        vocab = vocabs.load_vocabs_from_file(vocabs_filepath)
        self.word2index = vocab.get_word2id()

        self.model = Seq2VecNN(len(self.word2index),
                               2,
                               num_neurons_per_layer=[100, 25])
        self.model.load_state_dict(torch.load(saved_model_filepath))
        self.model.eval()
def train(opts):
    """ Trains the model """
    torch.manual_seed(opts.seed)

    vocab = vocabs.load_vocabs_from_file(opts.vocab)

    dataset = Seq2VecDataset(opts.training_dir, vocab, opts.langs)

    num_training_data = int(len(dataset) * opts.train_val_ratio)
    num_val_data = len(dataset) - num_training_data

    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [num_training_data, num_val_data])

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opts.batch_size,
        shuffle=True,
        pin_memory=(opts.device.type == "cuda"),
        num_workers=2,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=opts.batch_size,
        shuffle=True,
        pin_memory=(opts.device.type == "cuda"),
        num_workers=2,
    )

    model = Seq2VecNN(len(vocab.get_word2id()),
                      2,
                      num_neurons_per_layer=[100, 25])
    model = model.to(opts.device)

    patience = opts.patience
    num_epochs = opts.epochs

    # Either a fixed epoch budget or patience-based early stopping is
    # used, never both: with no patience set, train for opts.epochs;
    # otherwise train until the validation loss stops improving.
    if opts.patience is None:
        patience = float("inf")
    else:
        num_epochs = float("inf")

    best_val_loss = float("inf")

    num_poor = 0
    epoch = 1

    optimizer = torch.optim.Adam(model.parameters(), lr=opts.learning_rate)

    if opts.resume_from_checkpoint and os.path.isfile(
            opts.resume_from_checkpoint):
        print("Loading from checkpoint")
        best_val_loss, num_poor, epoch = load_checkpoint(
            opts.resume_from_checkpoint, model, optimizer)
        print(
            f"Previous state > Epoch {epoch}: Val loss={best_val_loss}, num_poor={num_poor}"
        )

    while epoch <= num_epochs and num_poor < patience:

        loss_function = torch.nn.CrossEntropyLoss()

        # Train
        train_loss, train_accuracy = train_for_one_epoch(
            model, loss_function, optimizer, train_dataloader, opts.device)

        # Evaluate the model
        eval_loss, eval_accuracy = test.evaluate_model(model, loss_function,
                                                       val_dataloader,
                                                       opts.device)

        print(
            f"Epoch={epoch} Train-Loss={train_loss} Train-Acc={train_accuracy} Test-Loss={eval_loss} Test-Acc={eval_accuracy} Num-Poor={num_poor}"
        )

        # Move weights to CPU so the saved files are device-agnostic
        model.cpu()
        if eval_loss >= best_val_loss:
            num_poor += 1
        else:
            num_poor = 0
            best_val_loss = eval_loss

            print("Saved model")
            torch.save(model.state_dict(), opts.model_path)

        save_checkpoint(
            opts.save_checkpoint_to,
            model,
            optimizer,
            best_val_loss,
            num_poor,
            epoch,
        )
        print("Saved checkpoint")
        model.to(opts.device)

        epoch += 1

    if epoch > num_epochs:
        print(f"Finished {num_epochs} epochs")
    else:
        print(f"Loss did not improve after {patience} epochs")
    def __init__(
        self,
        source_lang,
        target_lang,
        source_vocabs_path,
        target_vocabs_path,
        model_weights_path,
        word_embedding_size: int = 256,
        num_encoder_layers: int = 3,
        num_encoder_heads: int = 8,
        encoder_pf_dim: int = 512,
        encoder_dropout: float = 0.1,
        num_decoder_layers: int = 3,
        num_decoder_heads: int = 8,
        decoder_pf_dim: int = 512,
        decoder_dropout: float = 0.1,
    ):
        """ Sets up the translator

            Parameters
            ----------
            source_lang : str
                The source language. Currently supports one of: { 'en', 'fr' }
            target_lang : str
                The target language. Currently supports one of: { 'en', 'fr' }
            source_vocabs_path : str
                The file path to the source vocabs
            target_vocabs_path : str
                The file path to the target vocabs
            model_weights_path : str
                The file path to the model weights 
        """
        # Set up Spacy
        self.spacy_instance = get_spacy_instance(source_lang)

        # A set of punctuation characters
        self.punctuations = set(string.punctuation)

        # Set the model params
        self.word_embedding_size = word_embedding_size
        self.num_encoder_layers = num_encoder_layers
        self.num_encoder_heads = num_encoder_heads
        self.encoder_pf_dim = encoder_pf_dim
        self.encoder_dropout = encoder_dropout
        self.num_decoder_layers = num_decoder_layers
        self.num_decoder_heads = num_decoder_heads
        self.decoder_pf_dim = decoder_pf_dim
        self.decoder_dropout = decoder_dropout

        # Get the source and target vocab word mappings
        self.source_vocab = vocabs.load_vocabs_from_file(source_vocabs_path)
        self.target_vocab = vocabs.load_vocabs_from_file(target_vocabs_path)

        source_word2id = self.source_vocab.get_word2id()
        target_word2id = self.target_vocab.get_word2id()

        # Set up the sizes, unk, padding, eos, sos
        self.source_vocab_size = len(source_word2id) + 2
        self.target_vocab_size = len(target_word2id) + 4

        self.source_unk_id = len(source_word2id)
        self.source_pad_id = len(source_word2id) + 1

        self.target_unk_id = len(target_word2id)
        self.target_sos_id = len(target_word2id) + 1
        self.target_eos_id = len(target_word2id) + 2
        self.target_pad_id = len(target_word2id) + 3

        # Set up the model
        self.model = self.__build_model__()
        self.model.load_state_dict(
            torch.load(model_weights_path, map_location="cpu"))
        self.model.eval()

        del source_word2id, target_word2id
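# A usage sketch for the constructor above, assuming the enclosing
# class is named Translator (the class header is not shown here); the
# file paths are hypothetical. The keyword defaults must match the
# architecture the saved weights were trained with, otherwise
# load_state_dict will fail with a shape mismatch.
translator = Translator(
    source_lang="en",
    target_lang="fr",
    source_vocabs_path="vocabs/en.vocab",        # hypothetical path
    target_vocabs_path="vocabs/fr.vocab",        # hypothetical path
    model_weights_path="models/transformer.pt",  # hypothetical path
)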
def predict(opts):

    # Get our current version of spacy
    spacy_instance = utils.get_spacy_instance(opts.source_lang)

    # Lowercase the text and strip surrounding whitespace
    input_text = opts.input_text.lower().strip()

    # Parse input into tokens with spacy
    input_tokens = [
        token.text for token in spacy_instance.tokenizer(input_text)
    ]

    print("Input:", " ".join(input_tokens))

    # Get the vocabs
    # TODO: Handle the case of translating from fr to en
    source_vocab = vocabs.load_vocabs_from_file(opts.source_vocab)
    target_vocab = vocabs.load_vocabs_from_file(opts.target_vocab)

    # Get the mappings
    source_word2id = source_vocab.get_word2id()
    target_word2id = target_vocab.get_word2id()

    source_id2word = source_vocab.get_id2word()
    target_id2word = target_vocab.get_id2word()

    # The special token ids are appended contiguously after the
    # in-vocabulary ids, so they can be unpacked from a range
    source_vocab_size = len(source_word2id) + 2
    target_vocab_size = len(target_word2id) + 4

    src_unk, src_pad = range(len(source_word2id), source_vocab_size)
    trg_unk, trg_sos, trg_eos, trg_pad = range(len(target_word2id),
                                               target_vocab_size)

    model = helper.build_model(
        opts,
        source_vocab_size,
        target_vocab_size,
        src_pad,
        trg_sos,
        trg_eos,
        trg_pad,
        opts.device,
    )
    model.load_state_dict(
        torch.load(opts.model_path, map_location=opts.device))
    model.eval()

    # Map out-of-vocabulary tokens to the unk id instead of raising KeyError
    src = [torch.tensor([source_word2id.get(word, src_unk)
                         for word in input_tokens])]
    src_lens = torch.tensor([len(input_tokens)])
    # batch_first to match the layout used by the dataloaders elsewhere
    src = torch.nn.utils.rnn.pad_sequence(src,
                                          padding_value=src_pad,
                                          batch_first=True)

    predicted_words = None
    with torch.no_grad():

        # Get the output
        logits = model(src, src_lens)
        predicted_trg = logits.argmax(2)[0, :]

        # Strip the leading SOS and trailing EOS tokens
        predicted_trg = predicted_trg[1:-1]

        # Get the resultant sequence of words
        predicted_words = [
            target_id2word.get(word_id.item(), "NAN")
            for word_id in predicted_trg
        ]

    return predicted_words
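# A tiny worked example of the special-id layout predict() relies on,
# using a hypothetical three-word target vocabulary: in-vocabulary words
# take ids 0..2, and unk/sos/eos/pad are appended contiguously after
# them, which is what makes the range-unpacking above work.
target_word2id = {"le": 0, "chat": 1, "tapis": 2}  # hypothetical vocab
target_vocab_size = len(target_word2id) + 4        # 3 words + 4 specials
trg_unk, trg_sos, trg_eos, trg_pad = range(len(target_word2id),
                                           target_vocab_size)
assert (trg_unk, trg_sos, trg_eos, trg_pad) == (3, 4, 5, 6)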
def train(opts):
    """ Trains the model """
    torch.manual_seed(opts.seed)

    source_vocab = vocabs.load_vocabs_from_file(opts.source_vocab)
    target_vocab = vocabs.load_vocabs_from_file(opts.target_vocab)

    dataset = Seq2SeqDataset(
        opts.training_dir,
        source_vocab,
        target_vocab,
        opts.source_lang,
        opts.target_lang,
    )

    num_training_data = int(len(dataset) * opts.train_val_ratio)
    num_val_data = len(dataset) - num_training_data

    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [num_training_data, num_val_data])

    train_dataloader = Seq2SeqDataLoader(
        train_dataset,
        dataset.source_pad_id,
        dataset.target_pad_id,
        batch_first=True,
        batch_size=opts.batch_size,
        shuffle=True,
        pin_memory=(opts.device.type == "cuda"),
        num_workers=4,
    )
    val_dataloader = Seq2SeqDataLoader(
        val_dataset,
        dataset.source_pad_id,
        dataset.target_pad_id,
        batch_first=True,
        batch_size=opts.batch_size,
        shuffle=True,
        pin_memory=(opts.device.type == "cuda"),
        num_workers=4,
    )

    model = helper.build_model(
        opts,
        dataset.source_vocab_size,
        dataset.target_vocab_size,
        dataset.source_pad_id,
        dataset.target_sos,
        dataset.target_eos,
        dataset.target_pad_id,
        opts.device,
    )

    patience = opts.patience
    num_epochs = opts.epochs

    # As above: either a fixed epoch budget or patience-based early
    # stopping is used, never both
    if opts.patience is None:
        patience = float("inf")
    else:
        num_epochs = float("inf")

    best_val_loss = float("inf")

    num_poor = 0
    epoch = 1

    optimizer = torch.optim.Adam(model.parameters(), lr=opts.learning_rate)

    if opts.resume_from_checkpoint and os.path.isfile(
            opts.resume_from_checkpoint):
        print("Loading from checkpoint")
        best_val_loss, num_poor, epoch = load_checkpoint(
            opts.resume_from_checkpoint, model, optimizer)
        print(
            f"Previous state > Epoch {epoch}: Val loss={best_val_loss}, num_poor={num_poor}"
        )

    while epoch <= num_epochs and num_poor < patience:

        # Train
        loss_function = nn.CrossEntropyLoss(ignore_index=dataset.target_pad_id)
        train_loss = train_for_one_epoch(model, loss_function, optimizer,
                                         train_dataloader, opts.device)

        # Evaluate the model
        val_loss = test.evaluate_model_by_loss_function(
            model, loss_function, val_dataloader, opts.device)

        print(f"Epoch {epoch}: Train loss={train_loss}, Val loss={val_loss}")

        model.cpu()
        if val_loss > best_val_loss:
            num_poor += 1
        else:
            num_poor = 0
            best_val_loss = val_loss

            print("Saved model")
            torch.save(model.state_dict(), opts.model_path)

        save_checkpoint(
            opts.save_checkpoint_to,
            model,
            optimizer,
            best_val_loss,
            num_poor,
            epoch,
        )
        print("Saved checkpoint")

        model.to(opts.device)

        epoch += 1

    if epoch > num_epochs:
        print(f"Finished {num_epochs} epochs")
    else:
        print(f"Loss did not improve after {patience} epochs")

    val_bleu_score = test.evaluate_model_by_bleu_score(
        model,
        val_dataloader,
        opts.device,
        dataset.target_sos,
        dataset.target_eos,
        dataset.target_pad_id,
        target_vocab.get_id2word(),
    )
    print(f"Final BLEU score: {val_bleu_score}. Done.")