コード例 #1
0
def test_if_inference_output_is_valid(tmpdir):
    """Check that ONNX Runtime reproduces the PyTorch output of the exported model.

    The model is fitted, evaluated once on its ``example_input_array`` in eval
    mode, exported to an ONNX file under ``tmpdir``, and the ONNX Runtime result
    for the same input is compared to the PyTorch result within a tolerance.
    """

    def to_numpy(tensor):
        # A tensor that tracks gradients has to be detached before conversion.
        if tensor.requires_grad:
            return tensor.detach().cpu().numpy()
        return tensor.cpu().numpy()

    model = EvalModelTemplate()
    Trainer(max_epochs=5).fit(model)

    # reference prediction from PyTorch, taken in eval mode without autograd
    model.eval()
    with torch.no_grad():
        expected = model(model.example_input_array)

    # export the trained model, weights included
    onnx_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(onnx_path, model.example_input_array, export_params=True)

    # feed the same example input through ONNX Runtime
    session = onnxruntime.InferenceSession(onnx_path)
    input_name = session.get_inputs()[0].name
    feed = {input_name: to_numpy(model.example_input_array)}
    onnx_outputs = session.run(None, feed)

    # the first ONNX output must match the PyTorch output within tolerance
    outputs_match = np.allclose(
        to_numpy(expected), onnx_outputs[0], rtol=1e-03, atol=1e-05
    )
    assert outputs_match
コード例 #2
0
def test_model_saves_with_example_output(tmpdir):
    """Test that the ONNX model is saved when an example output is provided.

    Fits the model for one epoch, computes an example output for a random
    ``(1, 28 * 28)`` input in eval mode, passes both to ``to_onnx`` and checks
    that the target file exists afterwards.
    """
    model = EvalModelTemplate()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)

    file_path = os.path.join(tmpdir, "model.onnx")
    input_sample = torch.randn((1, 28 * 28))
    model.eval()
    # Call the module itself rather than ``forward`` so registered hooks run.
    example_outputs = model(input_sample)
    model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
    assert os.path.exists(file_path)
コード例 #3
0
def test_model_saving_loading(tmpdir):
    """Test that a checkpoint saved by the trainer can be loaded independently.

    Trains a model, records its prediction on one test batch, saves a
    checkpoint through the trainer, reloads it with ``load_from_checkpoint``
    together with the logged ``hparams.yaml``, and checks that the reloaded
    model produces identical predictions for the same batch.
    """
    model = EvalModelTemplate()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    trainer_options = dict(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # grab one batch to predict on: the first batch of the last test dataloader
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]
    x, _ = next(iter(dataloaders[-1]))

    # flatten everything but the batch dimension before feeding the model
    x = x.view(x.size(0), -1)

    # generate preds before saving model
    model.eval()
    pred_before_saving = model(x)

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, 'hparams.yaml')
    model_2 = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        hparams_file=hparams_path
    )
    model_2.eval()

    # the reloaded model must reproduce the earlier prediction exactly
    new_pred = model_2(x)
    assert torch.equal(pred_before_saving, new_pred)
コード例 #4
0
def test_cpu_slurm_save_load(tmpdir):
    """Verify model save/load/checkpoint on CPU.

    Trains a model, records its prediction on one test batch, writes an HPC
    checkpoint with ``trainer.hpc_save``, then fits a fresh model with a new
    trainer that uses the same logger version. An ``on_epoch_start`` hook on
    the fresh model checks that the global step and the predictions match the
    values captured before saving.
    """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(hparams)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)
    version = logger.version

    # fit model
    trainer = Trainer(max_epochs=1,
                      logger=logger,
                      checkpoint_callback=ModelCheckpoint(tmpdir))
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'cpu model failed to complete'

    # grab one batch to predict on: the first batch of the last test dataloader
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]
    x, _ = next(iter(dataloaders[-1]))

    # flatten everything but the batch dimension before feeding the model
    x = x.view(x.size(0), -1)

    # predict with trained model before saving
    model.eval()
    pred_before_saving = model(x)

    # test HPC saving
    # simulate snapshot on slurm
    saved_filepath = trainer.hpc_save(tmpdir, logger)
    assert os.path.exists(saved_filepath)

    # new logger file to get meta
    logger = tutils.get_default_logger(tmpdir, version=version)

    trainer = Trainer(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    model = EvalModelTemplate(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    # NOTE(review): if the hook is never invoked, the checks below are skipped
    # silently -- confirm that the second fit actually reaches an epoch start.
    def assert_pred_same():
        assert trainer.global_step == real_global_step and trainer.global_step > 0

        # predict with loaded model to make sure answers are the same
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.equal(pred_before_saving, new_pred)

    model.on_epoch_start = assert_pred_same

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)