Пример #1
0
def test_defaults():
    """
    Verify the default callbacks registered by SupervisedExperiment.

    SupervisedExperiment is a child class of BaseExperiment, so this test
    only covers the default-callback behaviour. The experiment is built from
    nothing more than a model and a single "train" loader.
    """
    expected = OrderedDict(
        _timer=TimerCallback,
        _metrics=MetricManagerCallback,
        _validation=ValidationManagerCallback,
        _saver=CheckpointCallback,
        _console=ConsoleLogger,
        _tensorboard=TensorboardLogger,
        _exception=ExceptionCallback,
    )

    # Placeholder model and dataset: nothing is trained here, only the
    # experiment's callbacks are handed to _test_callbacks.
    train_loader = torch.utils.data.DataLoader(torch.utils.data.Dataset())
    experiment = SupervisedExperiment(
        model=torch.nn.Module(),
        loaders=OrderedDict([("train", train_loader)]),
    )
    _test_callbacks(expected, experiment)
Пример #2
0
def test_scheduler():
    """
    Verify the callbacks of a SupervisedExperiment given an optimizer and
    a learning-rate scheduler.

    The expected mapping is the default callback set followed by
    ``_optimizer`` and ``_scheduler`` entries.
    """
    net = torch.nn.Linear(10, 10)
    adam = torch.optim.Adam(net.parameters())
    step_lr = torch.optim.lr_scheduler.StepLR(adam, step_size=10)
    train_loader = torch.utils.data.DataLoader(torch.utils.data.Dataset())

    experiment = SupervisedExperiment(
        model=net,
        loaders=OrderedDict([("train", train_loader)]),
        optimizer=adam,
        scheduler=step_lr,
    )

    expected = OrderedDict(
        _timer=TimerCallback,
        _metrics=MetricManagerCallback,
        _validation=ValidationManagerCallback,
        _saver=CheckpointCallback,
        _console=ConsoleLogger,
        _tensorboard=TensorboardLogger,
        _exception=ExceptionCallback,
        _optimizer=OptimizerCallback,
        _scheduler=SchedulerCallback,
    )
    _test_callbacks(expected, experiment)
Пример #3
0
def test_infer_all():
    """
    Verify the callbacks checked for the "infer" stage of a fully
    configured SupervisedExperiment.

    The experiment receives a criterion, optimizer, scheduler and has both
    ``verbose`` and ``check_run`` enabled. Only the verbose logger, the
    check-run callback and the exception callback are expected.
    """
    net = torch.nn.Linear(10, 10)
    loss_fn = torch.nn.CrossEntropyLoss()
    adam = torch.optim.Adam(net.parameters())
    step_lr = torch.optim.lr_scheduler.StepLR(adam, step_size=10)
    train_loader = torch.utils.data.DataLoader(torch.utils.data.Dataset())

    experiment = SupervisedExperiment(
        model=net,
        loaders=OrderedDict([("train", train_loader)]),
        criterion=loss_fn,
        optimizer=adam,
        scheduler=step_lr,
        verbose=True,
        check_run=True,
    )

    expected = OrderedDict(
        _verbose=VerboseLogger,
        _check=CheckRunCallback,
        _exception=ExceptionCallback,
    )
    # NOTE(review): the third argument is presumably the stage name that
    # _test_callbacks inspects; confirm against its definition.
    _test_callbacks(expected, experiment, "infer")
Пример #4
0
def test_criterion():
    """
    Verify the callbacks of a SupervisedExperiment given only a criterion.

    Optimizer and scheduler are passed explicitly as ``None``. The expected
    mapping is the default callback set followed by a ``_criterion`` entry.
    """
    net = torch.nn.Linear(10, 10)
    loss_fn = torch.nn.CrossEntropyLoss()
    train_loader = torch.utils.data.DataLoader(torch.utils.data.Dataset())

    experiment = SupervisedExperiment(
        model=net,
        loaders=OrderedDict([("train", train_loader)]),
        criterion=loss_fn,
        optimizer=None,
        scheduler=None,
    )

    expected = OrderedDict(
        _timer=TimerCallback,
        _metrics=MetricManagerCallback,
        _validation=ValidationManagerCallback,
        _saver=CheckpointCallback,
        _console=ConsoleLogger,
        _tensorboard=TensorboardLogger,
        _exception=ExceptionCallback,
        _criterion=CriterionCallback,
    )
    _test_callbacks(expected, experiment)
Пример #5
0
def test_defaults_check():
    """
    Verify the default callbacks of SupervisedExperiment with
    ``check_run=True``.

    The expected mapping starts with a ``_check`` entry, followed by the
    default callback set.
    """
    expected = OrderedDict([("_check", CheckRunCallback)])
    expected.update(
        [
            ("_timer", TimerCallback),
            ("_metrics", MetricManagerCallback),
            ("_validation", ValidationManagerCallback),
            ("_saver", CheckpointCallback),
            ("_console", ConsoleLogger),
            ("_tensorboard", TensorboardLogger),
            ("_exception", ExceptionCallback),
        ]
    )

    net = torch.nn.Module()
    train_loader = torch.utils.data.DataLoader(torch.utils.data.Dataset())
    experiment = SupervisedExperiment(
        model=net,
        loaders=OrderedDict([("train", train_loader)]),
        check_run=True,
    )
    _test_callbacks(expected, experiment)