def test_save_checkpoint_calls_torch_save(self, mock_open, mock_dill, mock_torch):
    """Checkpoint.save must torch.save the trainer state and the model,
    and dill-dump both vocabularies through opened files.

    The torch, dill, and open calls are patched in (injected as the mock_*
    parameters by the enclosing @patch decorators), so only call counts and
    arguments are verified — nothing touches disk.
    """
    epoch = 5
    step = 10
    optim = mock.Mock()
    # Expected payload passed to torch.save for the trainer state file.
    state_dict = {'epoch': epoch, 'step': step, 'optimizer': optim}

    mock_model = mock.Mock()
    mock_vocab = mock.Mock()
    mock_open.return_value = mock.MagicMock()

    chk_point = Checkpoint(model=mock_model, optimizer=optim,
                           epoch=epoch, step=step,
                           input_vocab=mock_vocab, output_vocab=mock_vocab)
    path = chk_point.save(self._get_experiment_dir())

    # One torch.save for the trainer state, one for the model.
    # NOTE: assertEqual (not the deprecated assertEquals alias, removed in 3.12).
    self.assertEqual(2, mock_torch.save.call_count)
    mock_torch.save.assert_any_call(
        state_dict,
        os.path.join(chk_point.path, Checkpoint.TRAINER_STATE_NAME))
    mock_torch.save.assert_any_call(
        mock_model,
        os.path.join(chk_point.path, Checkpoint.MODEL_NAME))

    # Both vocab files are opened (mode argument is irrelevant here: ANY).
    self.assertEqual(2, mock_open.call_count)
    mock_open.assert_any_call(
        os.path.join(path, Checkpoint.INPUT_VOCAB_FILE), ANY)
    mock_open.assert_any_call(
        os.path.join(path, Checkpoint.OUTPUT_VOCAB_FILE), ANY)

    # Each vocab is dill-dumped into the file handle yielded by the
    # context manager (open(...).__enter__()).
    self.assertEqual(2, mock_dill.dump.call_count)
    mock_dill.dump.assert_any_call(
        mock_vocab, mock_open.return_value.__enter__.return_value)
def test_save_checkpoint_saves_vocab_if_not_exist(self, mock_torch, mock_os_path_isfile):
    """Checkpoint.save delegates vocab persistence to each vocab object,
    writing them under the checkpoint directory exactly once apiece.

    torch.save and os.path.isfile are patched (injected as the mock_*
    parameters), so the test only inspects the vocab.save call sites.
    """
    epoch, step = 5, 10
    model_dict = {"key1": "val1"}
    opt_dict = {"key2": "val2"}

    mock_model = mock.Mock()
    mock_model.state_dict.return_value = model_dict
    input_vocab, output_vocab = mock.Mock(), mock.Mock()

    checkpoint = Checkpoint(model=mock_model,
                            optimizer_state_dict=opt_dict,
                            epoch=epoch,
                            step=step,
                            input_vocab=input_vocab,
                            output_vocab=output_vocab)
    checkpoint.save(self._get_experiment_dir())

    # Each vocabulary must be saved exactly once, to its own file inside
    # the newly-created checkpoint directory.
    expected_input_path = os.path.join(checkpoint.path, "input_vocab.pt")
    expected_output_path = os.path.join(checkpoint.path, "output_vocab.pt")
    input_vocab.save.assert_called_once_with(expected_input_path)
    output_vocab.save.assert_called_once_with(expected_output_path)
def test_save_checkpoint_calls_torch_save(self, mock_torch):
    """Checkpoint.save must invoke torch.save twice: once with the trainer
    state dict and once with the model itself.

    torch is patched (injected as mock_torch by the enclosing @patch
    decorator), so only the call count and arguments are checked.
    """
    epoch = 5
    step = 10
    opt_state_dict = {"key2": "val2"}
    # Expected trainer-state payload written to TRAINER_STATE_NAME.
    state_dict = {'epoch': epoch, 'step': step, 'optimizer': opt_state_dict}

    mock_model = mock.Mock()
    chk_point = Checkpoint(model=mock_model,
                           optimizer_state_dict=opt_state_dict,
                           epoch=epoch, step=step,
                           input_vocab=mock.Mock(), output_vocab=mock.Mock())
    chk_point.save(self._get_experiment_dir())

    # NOTE: assertEqual (not the deprecated assertEquals alias, removed in 3.12).
    self.assertEqual(2, mock_torch.save.call_count)
    mock_torch.save.assert_any_call(
        state_dict, os.path.join(chk_point.path, Checkpoint.TRAINER_STATE_NAME))
    mock_torch.save.assert_any_call(
        mock_model, os.path.join(chk_point.path, Checkpoint.MODEL_NAME))
def train_epochs_generator(train_generator_dict, model, optimizer, loss_func,
                           epochs, start_epoch, experiment_dir,
                           print_every=100, save_every=None,
                           valid_generator_dict=None):
    """Train `model` for `epochs` epochs, rebuilding data generators per epoch.

    Args:
        train_generator_dict: dict with keys 'path', 'input_word2idx',
            'output_word2idx', 'glove_array', 'max_len', 'batch_size' —
            forwarded verbatim to data_generator().
        model: model to train (model.train() is called each epoch).
        optimizer: optimizer passed to train_on_mini_batch and Checkpoint.
        loss_func: loss function passed to train_on_mini_batch.
        epochs: number of epochs to run.
        start_epoch: last completed epoch (training resumes at start_epoch+1);
            negative values are clamped to 0, i.e. start from epoch 1.
        experiment_dir: directory handed to Checkpoint.save().
        print_every: print a running average loss every this many steps.
        save_every: if not None, save a step checkpoint every this many steps.
        valid_generator_dict: optional dict (same keys as
            train_generator_dict) used to build a validation generator and
            report accuracy via eval_generator() after each epoch.
    """
    best_epoch_loss = 1e8
    if start_epoch < 0:
        start_epoch = 0
    start_epoch += 1

    for epoch in range(start_epoch, start_epoch + epochs):
        # Generators are exhausted after one pass, so build fresh ones
        # at the top of every epoch.
        train_generator = data_generator(
            train_generator_dict['path'],
            train_generator_dict['input_word2idx'],
            train_generator_dict['output_word2idx'],
            train_generator_dict['glove_array'],
            train_generator_dict['max_len'],
            train_generator_dict['batch_size'])

        # BUG FIX: previously valid_generator was only bound when
        # valid_generator_dict was provided, so the reporting branch below
        # raised NameError when no validation data was given.
        valid_generator = None
        if valid_generator_dict is not None:
            valid_generator = data_generator(
                valid_generator_dict['path'],
                valid_generator_dict['input_word2idx'],
                valid_generator_dict['output_word2idx'],
                valid_generator_dict['glove_array'],
                valid_generator_dict['max_len'],
                valid_generator_dict['batch_size'])

        step = 0
        epoch_loss = 0
        model.train()
        for inputs, lengths, targets, indices in train_generator:
            step += 1
            epoch_loss += train_on_mini_batch(model, optimizer, loss_func,
                                              inputs, lengths, targets)
            if step % print_every == 0:
                print('Step : {} Avg Step Loss : {}'.format(
                    step, epoch_loss / step))
            if save_every is not None and step % save_every == 0:
                checkpoint = Checkpoint(model, optimizer, epoch, step,
                                        None, None)
                checkpoint.save(experiment_dir)
                print('Saved step checkpoint ...epoch:{} \tstep:{}'.format(
                    epoch, step))

        # Save whenever this epoch improved on the best total loss so far.
        if best_epoch_loss > epoch_loss:
            best_epoch_loss = epoch_loss
            checkpoint = Checkpoint(model, optimizer, epoch, step, None, None)
            checkpoint.save(experiment_dir)
            print('Saved epoch checkpoint ...epoch:{} \tstep:{}'.format(
                epoch, step))

        if valid_generator is not None:
            val_acc = eval_generator(valid_generator, model)
            print(
                '\n-->--> Epoch: {} \tTraining Loss: {} \tValidation Accuracy: {}'
                .format(epoch, epoch_loss, val_acc))
        else:
            print('\n-->--> Epoch: {} \tTraining Loss: {}'.format(
                epoch, epoch_loss))