Пример #1
0
    def test_init_metrics() -> None:
        """Check that a MetricTracker can be built from a fresh pipeline.

        The pipeline is started with the ``--no-save`` / ``--no-visualize``
        flags, one epoch and 100 examples; an empty dict is passed in the
        checkpoint position of the tracker's constructor.
        """
        cli_flags = ("--no-save", "--no-visualize", "--num-examples=100",
                     "--epochs=1")
        parsed_args, _device, _checkpoint = init_pipeline(*cli_flags)

        tracker = MetricTracker(parsed_args, {})

        assert tracker
Пример #2
0
def viz() -> None:
    """Rebuild a model from the checkpoint and render its visualizations.

    Sample batches are drawn from the training split. Model constructor
    arguments come from the checkpoint's ``"model_init"`` entry when present,
    otherwise from the values reported by the dataset loader.
    """
    args, device, checkpoint = init_pipeline()
    dataset = get_dataset_initializer(args.dataset)
    train_loader, _val_loader, dataset_init_params = dataset.load_train_data(
        args, device)

    # Constructor arguments stored alongside the weights take precedence.
    model_init_params = checkpoint.get("model_init", dataset_init_params)
    build_model = get_model_initializer(args.model)
    model = build_model(*model_init_params).to(device)
    util.load_state_dict(checkpoint, model)

    samples = util.get_sample_loader(train_loader)
    visualize(args, model, samples)
    visualize_trained(args, model, samples)
Пример #3
0
def test(*arg_list: str) -> None:
    """Evaluate a checkpointed model on the test split.

    Args:
        arg_list: Command-line style arguments forwarded to ``init_pipeline``.
    """
    args, device, checkpoint = init_pipeline(*arg_list)
    loss_fn = get_loss_initializer(args.loss)()
    dataset = get_dataset_initializer(args.dataset)
    test_loader = dataset.load_test_data(args, device)

    # load_test_data reports no constructor arguments, so a checkpoint
    # without "model_init" builds the model with no positional arguments.
    model_init_params = checkpoint.get("model_init", [])
    build_model = get_model_initializer(args.model)
    model = build_model(*model_init_params).to(device)
    util.load_state_dict(checkpoint, model)

    model_summary(args, model, util.get_sample_loader(test_loader))

    test_model(args, model, test_loader, loss_fn)
Пример #4
0
    def test_epoch_update(capsys: pytest.CaptureFixture[str],
                          example_batch: SimpleNamespace) -> None:
        """Check the summary line epoch_update prints after four train batches.

        The same ``example_batch`` is fed four times in ``Mode.TRAIN`` with a
        log interval of 3, then the captured stdout is compared verbatim.
        """
        args, _, _ = init_pipeline("--no-save", "--no-visualize", "--epochs=1",
                                   "--log-interval=3")
        tracker = MetricTracker(args, {})
        total_batches = 4
        for batch_idx in range(total_batches):
            tracker.batch_update(example_batch, batch_idx, total_batches,
                                 Mode.TRAIN)

        tracker.epoch_update(Mode.TRAIN)

        expected_line = "Mode.TRAIN Loss: 0.2100 Accuracy: 66.67% \n"
        assert capsys.readouterr().out == expected_line
Пример #5
0
    def test_one_batch_update(example_batch: SimpleNamespace) -> None:
        """Check tracker state after a single training batch.

        Verifies both the dict returned by ``batch_update`` (rounded to two
        places) and each metric's ``epoch_avg`` (also rounded).
        """
        args, _, _ = init_pipeline("--no-save", "--no-visualize", "--epochs=1")
        tracker = MetricTracker(args, {})

        progress = tracker.batch_update(example_batch, 0, 1, Mode.TRAIN)

        # Round the returned dict in place, keeping its key order.
        progress.update({name: round(val, 2) for name, val in progress.items()})
        epoch_avgs = [round(m.epoch_avg, 2) for m in tracker.metric_data.values()]
        assert progress == {"Loss": 0.21, "Accuracy": 0.67}
        assert epoch_avgs == [0.63, 2]
Пример #6
0
def train(*arg_list: str) -> MetricTracker:
    """Run the training loop and return the tracker that recorded it.

    Every epoch runs one training pass and one validation pass and steps the
    scheduler when there is one. Unless ``args.no_save`` is set, it then
    writes a checkpoint.

    Args:
        arg_list: Command-line style arguments forwarded to ``init_pipeline``.

    Returns:
        The ``MetricTracker`` used throughout training.
    """
    args, device, checkpoint = init_pipeline(*arg_list)
    dataset = get_dataset_initializer(args.dataset)
    train_loader, val_loader, init_params = dataset.load_train_data(
        args, device)
    samples = util.get_sample_loader(train_loader)
    model, criterion, optimizer, scheduler = load_model(
        args, device, init_params, samples)
    # Restore model / optimizer / scheduler state from the checkpoint.
    util.load_state_dict(checkpoint, model, optimizer, scheduler)
    tracker = MetricTracker(args, checkpoint, dataset.CLASS_LABELS)
    visualize(args, model, samples, tracker)

    util.set_rng_state(checkpoint)
    for _epoch in range(args.epochs):
        tracker.next_epoch()
        # The validation pass is given no optimizer.
        train_and_validate(args, model, train_loader, optimizer, criterion,
                           tracker, Mode.TRAIN)
        train_and_validate(args, model, val_loader, None, criterion, tracker,
                           Mode.VAL)
        if scheduler is not None:
            scheduler.step()

        if args.no_save:
            continue

        # Persist weights, optimizer/scheduler state and all RNG states; the
        # start of this function restores these via load_state_dict and
        # set_rng_state.
        checkpoint_dict = {
            "model_init": init_params,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict":
                scheduler.state_dict() if scheduler is not None else None,
            "rng_state": random.getstate(),
            "np_rng_state": np.random.get_state(),
            "torch_rng_state": torch.get_rng_state(),
            "run_name": tracker.run_name,
            "metric_obj": tracker.json_repr(),
        }
        util.save_checkpoint(checkpoint_dict, tracker.is_best)

    # Force autograd on before post-training visualization; presumably the
    # VAL pass may leave it disabled — confirm in train_and_validate.
    torch.set_grad_enabled(True)
    visualize_trained(args, model, samples, tracker)
    return tracker
Пример #7
0
    def test_many_batch_update(example_batch: SimpleNamespace) -> None:
        """Check tracker state after four training batches, log interval 3.

        Verifies the dict returned by the last ``batch_update`` call (rounded
        to two places), that every metric's ``running_avg`` is 0 afterwards,
        and each metric's rounded ``epoch_avg``.
        """
        args, _, _ = init_pipeline("--no-save", "--no-visualize", "--epochs=1",
                                   "--log-interval=3")
        tracker = MetricTracker(args, {})
        batch_count = 4
        last_progress = {}
        for step in range(batch_count):
            last_progress = tracker.batch_update(example_batch, step,
                                                 batch_count, Mode.TRAIN)

        # Round the last returned dict in place, keeping its key order.
        last_progress.update(
            {name: round(val, 2) for name, val in last_progress.items()})
        epoch_avgs = [round(m.epoch_avg, 2) for m in tracker.metric_data.values()]
        assert last_progress == {"Loss": 0.21, "Accuracy": 0.67}
        assert all(m.running_avg == 0 for m in tracker.metric_data.values())
        assert epoch_avgs == [2.52, 8]