def test_lstm_model_init_hidden_states_should_return_correct_output():
    # Given
    expected = (torch.zeros(1, 64, 10), torch.zeros(1, 64, 10))
    model = LSTMModel(100, 10, 10, 5, 1)

    # When
    output = model.init_hidden_states(64)

    # Then
    assert torch.equal(output[0], expected[0])
    assert torch.equal(output[1], expected[1])

def test_lstm_model_should_call_other_nn_blocks(
    decoder_block_mock, lstm_mock, embedding_block_mock
):
    # Given
    inputs = MagicMock()
    hidden_states = MagicMock()
    model = LSTMModel(100, 10, 10, 5, 1)

    # When
    _ = model(inputs, hidden_states)

    # Then
    embedding_block_mock.assert_called_with(inputs)
    lstm_mock.assert_called_with(10, hidden_states)
    decoder_block_mock.assert_called_with(5)

def _eval_on_epoch(
    self,
    model: LSTMModel,
    eval_dataloader: LanguageModelingDataLoader,
    criterion: CrossEntropyLoss,
) -> None:
    self._clean_gradients(model)
    self._put_model_to_eval_mode(model)
    with torch.no_grad():
        hidden_states = model.init_hidden_states(self.batch_size)
        # Walk through the corpus in BPTT-sized chunks, carrying the hidden
        # states from one batch to the next.
        for batch_index in tqdm(
            range(0, len(eval_dataloader), eval_dataloader.bptt),
            desc=EVAL_DESCRIPTION_MESSAGE,
        ):
            hidden_states = self._eval_on_batch(
                model,
                next(eval_dataloader.get_batches(batch_index)),
                hidden_states,
                criterion,
            )

def _train_on_epoch(
    self,
    model: LSTMModel,
    train_dataloader: LanguageModelingDataLoader,
    criterion: CrossEntropyLoss,
    optimizer: Optimizer,
) -> None:
    self._clean_gradients(model)
    self._put_model_to_train_mode(model)
    hidden_states = model.init_hidden_states(self.batch_size)
    # Same chunked traversal as evaluation, but with gradient updates on each batch.
    for batch_index in tqdm(
        range(0, len(train_dataloader), train_dataloader.bptt),
        desc=TRAIN_DESCRIPTION_MESSAGE,
    ):
        hidden_states = self._train_on_batch(
            model,
            next(train_dataloader.get_batches(batch_index)),
            hidden_states,
            criterion,
            optimizer,
        )

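# The per-batch helpers (_train_on_batch / _eval_on_batch) are implemented elsewhere
# in the repo. As a rough, hypothetical sketch only: the training step presumably
# detaches the incoming hidden states so gradients do not flow across batch
# boundaries (standard truncated BPTT), runs a forward pass, backpropagates the
# loss, and steps the optimizer. The forward return signature and loss reshaping
# below are assumptions, not the repo's actual code:
#
#     def _train_on_batch(self, model, batch, hidden_states, criterion, optimizer):
#         inputs, targets = batch
#         hidden_states = tuple(state.detach() for state in hidden_states)
#         predictions, hidden_states = model(inputs, hidden_states)  # assumed output shape
#         loss = criterion(predictions.view(-1, predictions.size(-1)), targets.view(-1))
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         return hidden_states
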
def _put_model_to_eval_mode(model: LSTMModel) -> None:
    model.eval()


def _put_model_to_train_mode(model: LSTMModel) -> None:
    model.train()


def _clean_gradients(model: LSTMModel) -> None:
    model.zero_grad()


def _put_model_on_the_device(model: LSTMModel, device=DEVICE) -> None:
    model.to(device)

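# DEVICE is assumed to be a module-level constant defined elsewhere in the repo;
# a typical (hypothetical) definition would be:
#
#     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
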
def main():
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument("--path_to_train_data", type=str, required=True)
    argument_parser.add_argument("--path_to_eval_data", type=str, required=False, default=None)
    argument_parser.add_argument("--n_epochs", type=int, required=False, default=3)
    argument_parser.add_argument("--batch_size", type=int, required=False, default=32)
    argument_parser.add_argument("--bptt", type=int, required=False, default=64)
    argument_parser.add_argument("--lr", type=float, required=False, default=0.0001)
    argument_parser.add_argument("--vocabulary_size", type=int, required=False, default=20000)
    argument_parser.add_argument("--embedding_dimension", type=int, required=False, default=300)
    argument_parser.add_argument("--hidden_units_for_lstm", type=int, required=False, default=256)
    argument_parser.add_argument("--num_of_lstm_layer", type=int, required=False, default=1)
    argument_parser.add_argument("--n_decoder_blocks", type=int, required=False, default=5)
    arguments = argument_parser.parse_args()

    train_language_modeling_dataset = LanguageModelingDataset(
        arguments.batch_size, arguments.bptt
    )
    train_language_modeling_dataset.set_tokenizer(ByteLevelBPETokenizer())
    train_language_modeling_dataset.fit(
        arguments.path_to_train_data, vocabulary_size=arguments.vocabulary_size
    )

    train_language_modeling_dataloader = LanguageModelingDataLoader(
        arguments.bptt,
        train_language_modeling_dataset.transform(
            arguments.path_to_train_data, return_target=True
        ),
    )

    model = LSTMModel(
        arguments.vocabulary_size,
        arguments.embedding_dimension,
        arguments.hidden_units_for_lstm,
        arguments.n_decoder_blocks,
        arguments.num_of_lstm_layer,
    )

    logger = TensorboardLogger()
    trainer = Trainer(arguments.batch_size)
    trainer.set_logger(logger)

    if arguments.path_to_eval_data:
        eval_language_modeling_dataloader = LanguageModelingDataLoader(
            arguments.bptt,
            train_language_modeling_dataset.transform(
                arguments.path_to_eval_data, return_target=True
            ),
        )
        trainer.train(
            model,
            train_language_modeling_dataloader,
            CrossEntropyLoss(),
            Adam(model.parameters(), arguments.lr),
            eval_language_modeling_dataloader,
            arguments.n_epochs,
        )
    else:
        trainer.train(
            model,
            train_language_modeling_dataloader,
            CrossEntropyLoss(),
            Adam(model.parameters(), arguments.lr),
            None,
            arguments.n_epochs,
        )

    logger.log_params(vars(arguments), trainer.losses)
    saver = Saver(logger.log_dir())
    saver.save_preprocessor_and_model(train_language_modeling_dataset, model)
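
# Example invocation (hypothetical script name and data paths; every flag except
# --path_to_train_data falls back to the defaults declared above):
#
#     python train_language_model.py \
#         --path_to_train_data data/train.txt \
#         --path_to_eval_data data/eval.txt \
#         --n_epochs 5 \
#         --lr 0.0001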