Esempio n. 1
0
def eval_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    early_break: bool,
):
    """Run one validation pass over ``dataset.val_dataloader()``.

    The model is switched to eval mode, and the tracker and visualizer are
    reset for the ``"val"`` stage. Each batch is then moved to ``device`` and
    forwarded without gradient tracking. After each batch the tracker is
    updated and its current metrics are shown on the progress bar. Visuals
    are saved when the visualizer is active.

    Args:
        epoch: Epoch index passed to ``visualizer.reset`` and ``tracker.publish``.
        model: Model exposing ``set_input`` / ``forward`` / ``get_current_visuals``.
        dataset: Object providing ``val_dataloader()``.
        device: Target passed to each batch's ``.to(...)``.
        tracker: Metric tracker that is reset, updated per batch and published.
        checkpoint: Receives the published metrics via
            ``save_best_models_under_current_metrics``.
        visualizer: Saves per-batch visuals when ``is_active`` is true.
        early_break: When true, only the first batch is processed.
    """
    stage = "val"
    model.eval()
    tracker.reset(stage)
    visualizer.reset(epoch, stage)

    with Ctq(dataset.val_dataloader()) as progress:
        for batch in progress:
            batch = batch.to(device)
            # Inference only: no autograd graph is built for the forward pass.
            with torch.no_grad():
                model.set_input(batch)
                model.forward()

            tracker.track(model)
            running = tracker.get_metrics()
            progress.set_postfix(**running, color=COLORS.VAL_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            # Stop after the first batch when early_break is requested.
            if early_break:
                break

    summary = tracker.publish(epoch)
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, summary)
Esempio n. 2
0
def run_epoch(model: BaseModel, loader, device: str, num_batches: int):
    """Put ``model`` in eval mode and feed it batches from ``loader``.

    Each batch whose position is below ``num_batches`` is handed to
    ``process(model, data, device)``. If the loader holds more than
    ``num_batches`` batches, one further batch is drawn from it before the
    loop stops, and that batch is discarded.

    Args:
        model: Model to switch to eval mode before iterating.
        loader: Iterable of batches, wrapped in a ``Ctq`` progress context.
        device: Forwarded unchanged to ``process``.
        num_batches: Upper bound on the number of batches processed.
    """
    model.eval()
    with Ctq(loader) as progress:
        for index, batch in enumerate(progress):
            within_budget = index < num_batches
            if not within_budget:
                break
            process(model, batch, device)
Esempio n. 3
0
def test_epoch(model: BaseModel, dataset, device, tracker: BaseTracker,
               checkpoint: ModelCheckpoint, log):
    """Evaluate ``model`` on ``dataset.test_dataloader()``.

    The tracker is reset for the ``"test"`` stage. Each batch is moved to
    ``device`` and forwarded without gradient tracking, and the tracker is
    updated after every batch. The running metrics are shown on the progress
    bar. Finally the metrics are published, a summary is printed, and
    ``checkpoint.save_best_models_under_current_metrics`` is called with them.

    Args:
        model: Model exposing ``set_input`` / ``forward``.
        dataset: Object providing ``test_dataloader()``.
        device: Target passed to each batch's ``.to(...)``.
        tracker: Metric tracker that is reset, updated per batch and published.
        checkpoint: Receives the published metrics.
        log: Accepted for interface compatibility; not used in this body.
    """
    model.eval()
    tracker.reset("test")
    with Ctq(dataset.test_dataloader()) as progress:
        for batch in progress:
            inputs = batch.to(device)
            # Inference only: no autograd graph is built for the forward pass.
            with torch.no_grad():
                model.set_input(inputs)
                model.forward()

            tracker.track(model)
            progress.set_postfix(**tracker.get_metrics(),
                                 color=COLORS.TEST_COLOR)

    final_metrics = tracker.publish()
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, final_metrics)
Esempio n. 4
0
def test_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    early_break: bool,
):
    """Evaluate ``model`` on every loader from ``dataset.test_dataloaders()``.

    Each loader is treated as its own stage, named by
    ``dataset.get_test_dataset_name(idx)``. For every stage:

    - the tracker and visualizer are reset under that name;
    - each batch is moved to ``device`` and forwarded without gradient
      tracking;
    - the tracker is updated and its running metrics are shown on the
      progress bar;
    - visuals are saved when the visualizer is active.

    At the end of each stage the metrics are published for ``epoch``, a
    summary is printed, and
    ``checkpoint.save_best_models_under_current_metrics`` is called.

    Args:
        epoch: Epoch index passed to ``visualizer.reset`` and ``tracker.publish``.
        model: Model exposing ``set_input`` / ``forward`` / ``get_current_visuals``.
        dataset: Object providing ``test_dataloaders()`` and
            ``get_test_dataset_name(idx)``.
        device: Target passed to each batch's ``.to(...)``.
        tracker: Metric tracker, reset and published once per stage.
        checkpoint: Receives each stage's published metrics.
        visualizer: Saves per-batch visuals when ``is_active`` is true.
        early_break: When true, only the first batch of each stage is processed.
    """
    model.eval()

    stage_loaders = dataset.test_dataloaders()

    for stage_idx, stage_loader in enumerate(stage_loaders):
        stage = dataset.get_test_dataset_name(stage_idx)
        tracker.reset(stage)
        visualizer.reset(epoch, stage)

        with Ctq(stage_loader) as progress:
            for batch in progress:
                batch = batch.to(device)
                # Inference only: no autograd graph is built for the forward pass.
                with torch.no_grad():
                    model.set_input(batch)
                    model.forward()

                tracker.track(model)
                running = tracker.get_metrics()
                progress.set_postfix(**running, color=COLORS.TEST_COLOR)

                if visualizer.is_active:
                    visualizer.save_visuals(model.get_current_visuals())

                # Only the current stage's loop is cut short; later stages still run.
                if early_break:
                    break

        stage_metrics = tracker.publish(epoch)
        tracker.print_summary()
        checkpoint.save_best_models_under_current_metrics(model, stage_metrics)