def test_if_inference_output_is_valid(tmpdir):
    """Test that the output inferred from ONNX model is same as from PyTorch.

    The model is exported to ``<tmpdir>/model.onnx``, loaded with ONNX Runtime,
    and fed the same example input that was given to the PyTorch model. The two
    outputs must agree within ``rtol=1e-03`` / ``atol=1e-05``.
    """
    model = BoringModel()
    model.example_input_array = torch.randn(5, 32)

    # NOTE(review): this Trainer is not used anywhere below; it is kept so the
    # test setup stays exactly as before.
    trainer = Trainer(fast_dev_run=True)

    # reference output from the PyTorch model, in eval mode and without autograd
    model.eval()
    with torch.no_grad():
        torch_out = model(model.example_input_array)

    file_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(file_path, model.example_input_array, export_params=True)

    ort_session = onnxruntime.InferenceSession(file_path)

    def to_numpy(tensor):
        # tensors that require grad must be detached before conversion to numpy
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(model.example_input_array)}
    # passing ``None`` as the output names requests every model output
    ort_outs = ort_session.run(None, ort_inputs)

    # compare ONNX Runtime and PyTorch results
    assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
def test_model_saves_with_example_output(tmpdir):
    """Check that the ONNX export writes a file when example outputs are passed along."""
    model = BoringModel()
    trainer = Trainer(fast_dev_run=True)

    onnx_path = os.path.join(tmpdir, "model.onnx")
    sample = torch.randn((1, 32))

    # the example outputs come from the model's own forward pass in eval mode
    model.eval()
    outputs = model.forward(sample)

    model.to_onnx(onnx_path, sample, example_outputs=outputs)

    # the export must have produced a file at the requested location
    was_written = os.path.exists(onnx_path)
    assert was_written is True
def test_torchscript_input_output_trace():
    """Check that a module exported via tracing reproduces the model's own forward output."""
    model = BoringModel()
    sample = torch.randn(1, 32)

    # export by tracing the forward pass on the sample input
    traced = model.to_torchscript(example_inputs=sample, method='trace')
    assert isinstance(traced, torch.jit.ScriptModule)

    # reference output from the original model
    model.eval()
    with torch.no_grad():
        expected = model(sample)

    actual = traced(sample)
    assert torch.allclose(actual, expected)
def test_model_saving_loading(tmpdir):
    """Tests use case where trainer saves the model, and user loads it from tags independently.

    The model is trained briefly, a prediction is recorded, the trainer saves a
    checkpoint, and a second model restored from that checkpoint (plus the
    logger's ``hparams.yaml``) must produce exactly the same prediction.
    """
    model = BoringModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
        default_root_dir=tmpdir,
    )
    # the trainer has to actually run before its state can be FINISHED
    trainer.fit(model)

    # training complete
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    batch = next(iter(dataloaders[0]))

    # generate preds before saving model
    model.eval()
    pred_before_saving = model(batch)

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, 'hparams.yaml')
    model_2 = BoringModel.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        hparams_file=hparams_path,
    )
    model_2.eval()

    # make prediction
    # assert that both predictions are the same
    new_pred = model_2(batch)
    assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
def test_cpu_slurm_save_load(tmpdir):
    """Verify model save/load/checkpoint on CPU.

    A model is trained, a prediction is recorded, and an HPC checkpoint is
    written. A fresh model is then fitted with a new Trainer; a callback hooked
    on the start of the training epoch asserts that the global step matches the
    first run and that the model reproduces the recorded prediction.
    """
    model = BoringModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)
    version = logger.version

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    # run training so that global_step and the finished state are meaningful
    trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert trainer.state.finished, "cpu model failed to complete"

    # predict with trained model before saving
    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    # leaves `batch` bound to the first batch of the last dataloader
    for dataloader in dataloaders:
        for batch in dataloader:
            break

    model.eval()
    pred_before_saving = model(batch)

    # test HPC saving
    # simulate snapshot on slurm
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer.checkpoint_connector.hpc_save_path(trainer.weights_save_path)
    trainer.save_checkpoint(hpc_save_path)
    assert os.path.exists(hpc_save_path)

    # new logger file to get meta
    logger = tutils.get_default_logger(tmpdir, version=version)

    model = BoringModel()

    class _StartCallback(Callback):
        # set the epoch start hook so we can predict before the model does the full training
        def on_train_epoch_start(self, trainer, model):
            assert trainer.global_step == real_global_step and trainer.global_step > 0

            # predict with loaded model to make sure answers are the same.
            # Remember the boolean training flag: `model.eval()` returns the
            # module itself, which is not a valid argument for `model.train()`.
            was_training = model.training
            model.eval()
            new_pred = model(batch)
            assert torch.eq(pred_before_saving, new_pred).all()
            model.train(was_training)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        callbacks=[_StartCallback(), ModelCheckpoint(dirpath=tmpdir)],
    )
    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)