コード例 #1
0
def test_tensorboard_log_graph(tmpdir, example_input_array):
    """Check that ``log_graph`` works with the model's own input and with an external one.

    When ``example_input_array`` is supplied, the model's ``example_input_array``
    is cleared so that only the externally passed array is available. When it is
    ``None``, the model's attribute is left in place, presumably so the logger
    falls back to it.
    """
    boring = BoringModel()
    external_given = example_input_array is not None
    if external_given:
        # Hide the model's own sample input so only the passed-in array remains.
        boring.example_input_array = None

    tb_logger = TensorBoardLogger(tmpdir, log_graph=True)
    tb_logger.log_graph(boring, example_input_array)
コード例 #2
0
def test_tensorboard_log_graph(tmpdir, example_input_array):
    """Test that log graph works with both ``model.example_input_array`` and
    an array passed externally.

    When an external ``example_input_array`` is given, the model's own
    ``example_input_array`` is cleared so the external array is the only
    input available. When it is ``None``, the model's attribute is kept so
    that path is exercised too.
    """
    model = EvalModelTemplate()
    # Clear the model's attribute only when an external array is supplied.
    # Clearing it when the external array is None would leave no input from
    # either source, so the model.example_input_array path would never run.
    if example_input_array is not None:
        model.example_input_array = None
    # NOTE(review): if this TensorBoardLogger version has a ``log_graph``
    # constructor flag defaulting to False, pass ``log_graph=True`` so the call
    # below actually logs; confirm against the installed version.
    logger = TensorBoardLogger(tmpdir)
    logger.log_graph(model, example_input_array)
コード例 #3
0
def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
    """Check that ``log_graph`` warns when no input is available for the graph.

    The model's ``example_input_array`` is cleared and no ``input_array`` is
    passed, so a ``UserWarning`` with the message below is expected.
    """
    expected_warning = (
        'Could not log computational graph since the `model.example_input_array`'
        ' attribute is not set or `input_array` was not given'
    )
    model = BoringModel()
    model.example_input_array = None
    tb_logger = TensorBoardLogger(tmpdir, log_graph=True)
    with pytest.warns(UserWarning, match=expected_warning):
        tb_logger.log_graph(model)
コード例 #4
0
def test_tensorboard_graph_log(dataloaders_with_covariates, model, tmp_path):
    """Check that ``TensorBoardLogger.log_graph`` accepts a real training batch.

    Logs are written under ``tmp_path`` so the test does not depend on, or
    write into, the current working directory.
    """
    batch = next(iter(dataloaders_with_covariates["train"]))
    # Pass save_dir/name by keyword in their intended roles. The swapped form
    # (save_dir="test", name=<absolute tmp path>) only resolves inside tmp_path
    # because os.path.join discards a relative prefix before an absolute
    # component.
    logger = TensorBoardLogger(save_dir=str(tmp_path), name="test", log_graph=True)
    # Only the first element of the batch is passed as the graph input.
    logger.log_graph(model, batch[0])