def eval_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    early_break: bool,
):
    """Run one validation pass over ``dataset.val_dataloader()``.

    The model is switched to eval mode. Each batch is moved to ``device`` and
    fed through ``model.set_input`` / ``model.forward`` under
    ``torch.no_grad()``. After every batch the tracker records the model, the
    progress bar shows the tracker's running metrics, and the model's current
    visuals are saved when the visualizer is active. When the loop ends, the
    tracker publishes metrics for ``epoch``, prints a summary, and hands the
    metrics to ``checkpoint``.

    Args:
        epoch: Epoch index forwarded to ``visualizer.reset`` and ``tracker.publish``.
        model: Model to evaluate.
        dataset: Object providing ``val_dataloader()``.
        device: Target passed to each batch's ``.to()``.
        tracker: Metric tracker; reset to the ``"val"`` stage before the loop.
        checkpoint: Receives the published metrics via
            ``save_best_models_under_current_metrics``.
        visualizer: Reset for ``(epoch, "val")``; saves visuals when active.
        early_break: If true, stop after the first batch.
    """
    stage = "val"
    model.eval()
    tracker.reset(stage)
    visualizer.reset(epoch, stage)

    with Ctq(dataset.val_dataloader()) as progress:
        for batch in progress:
            # The device transfer happens outside no_grad, as before.
            batch = batch.to(device)
            with torch.no_grad():
                model.set_input(batch)
                model.forward()

            tracker.track(model)
            running = tracker.get_metrics()
            progress.set_postfix(**running, color=COLORS.VAL_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            if early_break:
                break

    metrics = tracker.publish(epoch)
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, metrics)
def run_epoch(model: BaseModel, loader, device: str, num_batches: int):
    """Call ``process`` on at most the first ``num_batches`` batches of ``loader``.

    The model is put in eval mode and iteration is wrapped in a ``Ctq``
    progress bar. Iteration stops immediately after the last requested batch
    has been processed, so the loader is never asked for a batch that would
    only be discarded. When ``num_batches <= 0``, no batch is fetched at all.

    Args:
        model: Model handed to ``process`` for every batch.
        loader: Iterable of batches.
        device: Forwarded unchanged to ``process``.
        num_batches: Maximum number of batches to process.
    """
    model.eval()
    with Ctq(loader) as tq_loader:
        if num_batches <= 0:
            return
        for batch_count, data in enumerate(tq_loader, start=1):
            process(model, data, device)
            # Break here rather than at the top of the next iteration. Checking
            # first would pull (and load) one extra batch from the loader, and
            # the progress bar would count it, only for it to be thrown away.
            if batch_count >= num_batches:
                break
def test_epoch(model: BaseModel, dataset, device, tracker: BaseTracker, checkpoint: ModelCheckpoint, log):
    """Run one pass over ``dataset.test_dataloader()`` and checkpoint on the result.

    Each batch is moved to ``device`` and fed through ``model.set_input`` /
    ``model.forward`` under ``torch.no_grad()``. The tracker records every
    batch, and its running metrics are shown on the progress bar. Afterwards
    the tracker publishes metrics (no epoch argument), prints a summary, and
    passes the metrics to ``checkpoint``.

    NOTE(review): a later ``test_epoch`` definition in this file (which takes
    ``epoch`` as its first parameter) rebinds this name, so this version
    appears to be unreachable by name. Confirm and remove one of them.
    NOTE(review): ``log`` is accepted but never used.
    """
    model.eval()
    tracker.reset("test")

    with Ctq(dataset.test_dataloader()) as progress:
        for batch in progress:
            batch = batch.to(device)
            with torch.no_grad():
                model.set_input(batch)
                model.forward()
            tracker.track(model)
            progress.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)

    published = tracker.publish()
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, published)
def test_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    early_break: bool,
):
    """Evaluate ``model`` on every loader from ``dataset.test_dataloaders()``.

    Each loader is treated as its own stage, named by
    ``dataset.get_test_dataset_name(index)``. For every stage, the tracker and
    visualizer are reset and the batches are run under ``torch.no_grad()``.
    The tracker then publishes metrics for ``epoch``, prints a summary, and
    hands them to ``checkpoint``. Checkpointing therefore happens once per
    test loader.

    Args:
        epoch: Epoch index forwarded to ``visualizer.reset`` and ``tracker.publish``.
        model: Model to evaluate.
        dataset: Object providing ``test_dataloaders()`` and
            ``get_test_dataset_name(index)``.
        device: Target passed to each batch's ``.to()``.
        tracker: Metric tracker; reset per stage.
        checkpoint: Receives each stage's published metrics.
        visualizer: Reset per stage; saves visuals when active.
        early_break: If true, stop each stage after its first batch.
    """
    model.eval()

    def _evaluate_stage(stage, stage_loader):
        # One complete stage: reset, iterate, then publish/print/checkpoint.
        tracker.reset(stage)
        visualizer.reset(epoch, stage)
        with Ctq(stage_loader) as progress:
            for batch in progress:
                batch = batch.to(device)
                with torch.no_grad():
                    model.set_input(batch)
                    model.forward()

                tracker.track(model)
                progress.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)

                if visualizer.is_active:
                    visualizer.save_visuals(model.get_current_visuals())

                if early_break:
                    break

        stage_metrics = tracker.publish(epoch)
        tracker.print_summary()
        checkpoint.save_best_models_under_current_metrics(model, stage_metrics)

    for position, test_loader in enumerate(dataset.test_dataloaders()):
        _evaluate_stage(dataset.get_test_dataset_name(position), test_loader)