Example #1
def test_lstm_model_init_hidden_states_should_return_correct_output():
    # Given
    expected = (torch.zeros(1, 64, 10), torch.zeros(1, 64, 10))
    model = LSTMModel(100, 10, 10, 5, 1)

    # When
    output = model.init_hidden_states(64)

    # Then
    assert torch.equal(output[0], expected[0])
    assert torch.equal(output[1], expected[1])
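For reference, the expected tensors above imply that init_hidden_states allocates a pair of zero tensors shaped (num_layers, batch_size, hidden_units). A minimal sketch of such a method, assuming the constructor parameter order (vocabulary_size, embedding_dimension, hidden_units, n_decoder_blocks, num_layers) used in Example #9; only the pieces exercised by this test are shown:

import torch

class LSTMModel(torch.nn.Module):
    # Hypothetical sketch, not the project's full implementation.
    def __init__(self, vocabulary_size, embedding_dimension, hidden_units,
                 n_decoder_blocks, num_layers):
        super().__init__()
        self.hidden_units = hidden_units
        self.num_layers = num_layers

    def init_hidden_states(self, batch_size):
        # Returns an (h_0, c_0) pair of zeros shaped
        # (num_layers, batch_size, hidden_units), matching the
        # torch.zeros(1, 64, 10) tensors expected by the test.
        shape = (self.num_layers, batch_size, self.hidden_units)
        return torch.zeros(*shape), torch.zeros(*shape)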
Example #2
def test_lstm_model_should_call_other_nn_blocks(
    decoder_block_mock, lstm_mock, embedding_block_mock
):
    # Given
    inputs = MagicMock()
    hidden_states = MagicMock()
    model = LSTMModel(100, 10, 10, 5, 1)

    # When
    _ = model(inputs, hidden_states)

    # Then
    embedding_block_mock.assert_called_with(inputs)
    lstm_mock.assert_called_with(10, hidden_states)
    decoder_block_mock.assert_called_with(5)
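The three mock parameters above are presumably injected by stacked unittest.mock.patch decorators; because patches are applied bottom-up, the lowest decorator supplies the first argument. A sketch of that setup, where the module path language_model.lstm_model and the block class names are assumptions, not taken from the source:

from unittest.mock import MagicMock, patch

# Hypothetical module path and class names; adjust to the real project layout.
@patch("language_model.lstm_model.EmbeddingBlock")  # injected third
@patch("language_model.lstm_model.LSTM")            # injected second
@patch("language_model.lstm_model.DecoderBlock")    # injected first
def test_lstm_model_should_call_other_nn_blocks(
    decoder_block_mock, lstm_mock, embedding_block_mock
):
    ...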
Example #3
def _eval_on_epoch(
    self,
    model: LSTMModel,
    eval_dataloader: LanguageModelingDataLoader,
    criterion: CrossEntropyLoss,
) -> None:
    self._clean_gradients(model)
    self._put_model_to_eval_mode(model)
    with torch.no_grad():
        hidden_states = model.init_hidden_states(self.batch_size)
        for batch_index in tqdm(
                range(0, len(eval_dataloader), eval_dataloader.bptt),
                desc=EVAL_DESCRIPTION_MESSAGE,
        ):
            hidden_states = self._eval_on_batch(
                model,
                next(eval_dataloader.get_batches(batch_index)),
                hidden_states,
                criterion,
            )
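_eval_on_batch itself is not shown in these excerpts. A minimal sketch consistent with the loop above, assuming a batch is an (inputs, targets) pair, the model returns (logits, hidden_states), and the Trainer tracks losses in self.losses (see trainer.losses in Example #9):

def _eval_on_batch(self, model, batch, hidden_states, criterion):
    # Hypothetical sketch: forward pass and loss tracking only, no updates.
    inputs, targets = batch
    logits, hidden_states = model(inputs, hidden_states)
    loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
    self.losses.append(loss.item())  # assumed loss bookkeeping
    # Return the hidden states so the caller can carry them across batches.
    return hidden_states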
Example #4
def _train_on_epoch(
    self,
    model: LSTMModel,
    train_dataloader: LanguageModelingDataLoader,
    criterion: CrossEntropyLoss,
    optimizer: Optimizer,
) -> None:
    self._clean_gradients(model)
    self._put_model_to_train_mode(model)
    hidden_states = model.init_hidden_states(self.batch_size)
    for batch_index in tqdm(
            range(0, len(train_dataloader), train_dataloader.bptt),
            desc=TRAIN_DESCRIPTION_MESSAGE,
    ):
        hidden_states = self._train_on_batch(
            model,
            next(train_dataloader.get_batches(batch_index)),
            hidden_states,
            criterion,
            optimizer,
        )
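Likewise, _train_on_batch is not shown. A sketch of the truncated-BPTT step it presumably performs, detaching the hidden states so gradients do not flow back through earlier batches; the model's return signature and the gradient-clipping value are assumptions:

def _train_on_batch(self, model, batch, hidden_states, criterion, optimizer):
    # Hypothetical sketch of one truncated-BPTT training step.
    inputs, targets = batch
    # Detach the carried-over states before building a new graph.
    hidden_states = tuple(state.detach() for state in hidden_states)
    optimizer.zero_grad()
    logits, hidden_states = model(inputs, hidden_states)
    loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # assumed value
    optimizer.step()
    return hidden_states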
Example #5
def _put_model_to_eval_mode(model: LSTMModel) -> None:
    model.eval()
Example #6
def _put_model_to_train_mode(model: LSTMModel) -> None:
    model.train()
Example #7
def _clean_gradients(model: LSTMModel) -> None:
    model.zero_grad()
Example #8
def _put_model_on_the_device(model: LSTMModel, device=DEVICE) -> None:
    model.to(device)
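The helpers in Examples #5 through #8 take no self argument, so they are presumably declared with @staticmethod on the Trainer class. The DEVICE default in Example #8 is not defined anywhere in these excerpts; a common definition, given here only as an assumption:

import torch

# Assumed module-level constant; the excerpts do not show its definition.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")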
Example #9
def main():
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument("--path_to_train_data",
                                 type=str,
                                 required=True)
    argument_parser.add_argument("--path_to_eval_data",
                                 type=str,
                                 required=False,
                                 default=None)
    argument_parser.add_argument("--n_epochs",
                                 type=int,
                                 required=False,
                                 default=3)
    argument_parser.add_argument("--batch_size",
                                 type=int,
                                 required=False,
                                 default=32)
    argument_parser.add_argument("--bptt",
                                 type=int,
                                 required=False,
                                 default=64)
    argument_parser.add_argument("--lr",
                                 type=float,
                                 required=False,
                                 default=0.0001)
    argument_parser.add_argument("--vocabulary_size",
                                 type=int,
                                 required=False,
                                 default=20000)
    argument_parser.add_argument("--embedding_dimension",
                                 type=int,
                                 required=False,
                                 default=300)
    argument_parser.add_argument("--hidden_units_for_lstm",
                                 type=int,
                                 required=False,
                                 default=256)
    argument_parser.add_argument("--num_of_lstm_layer",
                                 type=int,
                                 required=False,
                                 default=1)
    argument_parser.add_argument("--n_decoder_blocks",
                                 type=int,
                                 required=False,
                                 default=5)

    arguments = argument_parser.parse_args()

    train_language_modeling_dataset = LanguageModelingDataset(
        arguments.batch_size, arguments.bptt)
    train_language_modeling_dataset.set_tokenizer(ByteLevelBPETokenizer())
    train_language_modeling_dataset.fit(
        arguments.path_to_train_data,
        vocabulary_size=arguments.vocabulary_size)

    train_language_modeling_dataloader = LanguageModelingDataLoader(
        arguments.bptt,
        train_language_modeling_dataset.transform(arguments.path_to_train_data,
                                                  return_target=True),
    )

    model = LSTMModel(
        arguments.vocabulary_size,
        arguments.embedding_dimension,
        arguments.hidden_units_for_lstm,
        arguments.n_decoder_blocks,
        arguments.num_of_lstm_layer,
    )

    logger = TensorboardLogger()
    trainer = Trainer(arguments.batch_size)
    trainer.set_logger(logger)

    eval_language_modeling_dataloader = None
    if arguments.path_to_eval_data:
        eval_language_modeling_dataloader = LanguageModelingDataLoader(
            arguments.bptt,
            train_language_modeling_dataset.transform(
                arguments.path_to_eval_data, return_target=True),
        )

    trainer.train(
        model,
        train_language_modeling_dataloader,
        CrossEntropyLoss(),
        Adam(model.parameters(), arguments.lr),
        eval_language_modeling_dataloader,
        arguments.n_epochs,
    )

    logger.log_params(vars(arguments), trainer.losses)
    saver = Saver(logger.log_dir())
    saver.save_preprocessor_and_model(train_language_modeling_dataset, model)
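The excerpt defines main() but does not show an entry-point guard or how the script is invoked; a sketch of both, with the file name and data paths as placeholders:

if __name__ == "__main__":
    # Example invocation (file name and paths are placeholders):
    #   python train.py --path_to_train_data data/train.txt \
    #                   --path_to_eval_data data/valid.txt --n_epochs 3
    main()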