def test_epoch(
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Evaluate ``model`` on every test loader of ``dataset``, with optional voting.

    For each loader in ``dataset.test_dataloaders`` the tracker is reset under
    the loader's ``dataset.name``. The loader is then iterated ``voting_runs``
    times; every batch is pushed through the model under ``torch.no_grad()``
    and fed to ``tracker.track``. The tracker is not reset between passes.
    After all passes the tracker is finalised and its summary printed.

    This function does not change the model's train/eval mode; the caller is
    responsible for putting the model in the desired mode beforehand.

    Args:
        model: Model exposing ``set_input(data, device)`` and ``forward()``.
        dataset: Object exposing ``test_dataloaders``.
        device: Device forwarded to ``model.set_input``.
        tracker: Metric tracker; reset once per test loader.
        checkpoint: Not used by this function; kept for signature compatibility.
        voting_runs: Number of full passes over each test loader.
        tracker_options: Extra keyword arguments forwarded to both
            ``tracker.track`` and ``tracker.finalise``. ``None`` (the default)
            means no extra options.
    """
    # Use a fresh dict per call rather than a shared mutable default argument.
    if tracker_options is None:
        tracker_options = {}

    for loader in dataset.test_dataloaders:
        stage_name = loader.dataset.name
        tracker.reset(stage_name)
        for _ in range(voting_runs):
            with Ctq(loader) as tq_test_loader:
                for data in tq_test_loader:
                    with torch.no_grad():
                        model.set_input(data, device)
                        model.forward()
                        tracker.track(model, **tracker_options)
                    tq_test_loader.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)
        tracker.finalise(**tracker_options)
        tracker.print_summary()
def train_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device: str,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Train ``model`` for one epoch over ``dataset.train_dataloader``.

    Puts the model in training mode, resets the tracker and visualizer for the
    ``"train"`` stage, and calls ``model.optimize_parameters`` on every batch.
    Afterwards the tracked metrics are published, ``checkpoint`` is asked to
    save the best models for those metrics, and the learning rate is logged.

    Args:
        epoch: Current epoch index, forwarded to the model, visualizer and tracker.
        model: Model being trained.
        dataset: Object exposing ``train_dataloader`` and ``batch_size``.
        device: Device forwarded to ``model.set_input``.
        tracker: Metric tracker for the ``"train"`` stage.
        checkpoint: Receives the published metrics to save best models.
        visualizer: Saves per-batch visuals when ``visualizer.is_active``.
        debugging: Object read via ``getattr`` for the optional flags
            ``early_break`` (stop after the first batch), ``profiling`` and
            ``num_batches`` (default 50).

    Returns:
        ``0`` when profiling ends the epoch early (metrics are then neither
        published nor checkpointed); ``None`` otherwise.
    """
    early_break = getattr(debugging, "early_break", False)
    profiling = getattr(debugging, "profiling", False)
    # Loop-invariant: read once instead of on every batch.
    num_batches = getattr(debugging, "num_batches", 50)

    model.train()
    tracker.reset("train")
    visualizer.reset(epoch, "train")
    train_loader = dataset.train_dataloader

    iter_data_time = time.time()
    with Ctq(train_loader) as tq_train_loader:
        for i, data in enumerate(tq_train_loader):
            model.set_input(data, device)
            # Measured after set_input, so this includes the time spent there.
            t_data = time.time() - iter_data_time

            iter_start_time = time.time()
            model.optimize_parameters(epoch, dataset.batch_size)
            # Metrics are only tracked on every 10th iteration.
            if i % 10 == 0:
                tracker.track(model)

            # The iteration time is taken when this call's arguments are
            # evaluated, i.e. after tracking and get_metrics().
            tq_train_loader.set_postfix(
                **tracker.get_metrics(),
                data_loading=float(t_data),
                iteration=float(time.time() - iter_start_time),
                color=COLORS.TRAIN_COLOR,
            )

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            iter_data_time = time.time()

            if early_break:
                break

            # First true at i == num_batches + 1, i.e. after num_batches + 2
            # batches; skips publishing and checkpointing entirely.
            if profiling and i > num_batches:
                return 0

    metrics = tracker.publish(epoch)
    checkpoint.save_best_models_under_current_metrics(model, metrics, tracker.metric_func)
    # Lazy %-style args: the logger formats the message only if it is emitted.
    log.info("Learning rate = %f", model.learning_rate)
def test_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Run the model in eval mode over each test loader and record results.

    Every loader in ``dataset.test_dataloaders`` is treated as its own stage,
    named after ``loader.dataset.name``. For each stage:

    1. reset the tracker and visualizer;
    2. forward every batch under ``torch.no_grad()`` and track it;
    3. save visuals when the visualizer is active;
    4. finalise, publish and print the tracker's metrics, then let
       ``checkpoint`` save the best models for those metrics.

    If ``debugging.early_break`` is truthy, each loader stops after one batch.
    """
    stop_after_first_batch = getattr(debugging, "early_break", False)
    model.eval()

    for test_loader in dataset.test_dataloaders:
        stage = test_loader.dataset.name
        tracker.reset(stage)
        visualizer.reset(epoch, stage)

        with Ctq(test_loader) as progress:
            for batch in progress:
                with torch.no_grad():
                    model.set_input(batch, device)
                    model.forward()
                    tracker.track(model)
                    progress.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)

                if visualizer.is_active:
                    visualizer.save_visuals(model.get_current_visuals())

                if stop_after_first_batch:
                    break

        # Per-stage wrap-up: metrics are published and checkpointed per loader.
        tracker.finalise()
        stage_metrics = tracker.publish(epoch)
        tracker.print_summary()
        checkpoint.save_best_models_under_current_metrics(
            model, stage_metrics, tracker.metric_func
        )
def eval_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Run one validation pass over ``dataset.val_dataloader``.

    Puts the model in eval mode and resets the tracker and visualizer for the
    ``"val"`` stage. Every batch is forwarded under ``torch.no_grad()`` and
    tracked, with visuals saved when the visualizer is active. The tracker is
    then finalised, its metrics published and summarised, and ``checkpoint``
    is asked to save the best models for those metrics.

    Args:
        epoch: Current epoch index, forwarded to the visualizer and tracker.
        model: Model exposing ``eval``, ``set_input``, ``forward`` and
            ``get_current_visuals``.
        dataset: Object exposing ``val_dataloader``.
        device: Device forwarded to ``model.set_input``.
        tracker: Metric tracker for the ``"val"`` stage.
        checkpoint: Receives the published metrics to save best models.
        visualizer: Saves per-batch visuals when ``visualizer.is_active``.
        debugging: Object whose optional ``early_break`` flag stops the pass
            after the first batch.
    """
    early_break = getattr(debugging, "early_break", False)
    model.eval()
    tracker.reset("val")
    visualizer.reset(epoch, "val")
    loader = dataset.val_dataloader
    with Ctq(loader) as tq_val_loader:
        for data in tq_val_loader:
            with torch.no_grad():
                model.set_input(data, device)
                model.forward()
                tracker.track(model)
                tq_val_loader.set_postfix(**tracker.get_metrics(), color=COLORS.VAL_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            if early_break:
                break

    # Finalise before publishing, matching the test-stage flow; a tracker may
    # defer part of its metric computation to finalise(), and the published
    # metrics drive checkpoint selection below.
    tracker.finalise()
    metrics = tracker.publish(epoch)
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, metrics, tracker.metric_func)
def eval_epoch(
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Evaluate ``model`` on ``dataset.val_dataloader``, with optional voting.

    The tracker is reset once for the ``"val"`` stage. The validation loader is
    then iterated ``voting_runs`` times; every batch is pushed through the model
    under ``torch.no_grad()`` and fed to ``tracker.track``. The tracker is not
    reset between passes. Finally the tracker is finalised and its summary
    printed.

    This function does not change the model's train/eval mode; the caller is
    responsible for putting the model in the desired mode beforehand.

    Args:
        model: Model exposing ``set_input(data, device)`` and ``forward()``.
        dataset: Object exposing ``val_dataloader``.
        device: Device forwarded to ``model.set_input``.
        tracker: Metric tracker for the ``"val"`` stage.
        checkpoint: Not used by this function; kept for signature compatibility.
        voting_runs: Number of full passes over the validation loader.
        tracker_options: Extra keyword arguments forwarded to both
            ``tracker.track`` and ``tracker.finalise``. ``None`` (the default)
            means no extra options.
    """
    # Use a fresh dict per call rather than a shared mutable default argument.
    if tracker_options is None:
        tracker_options = {}

    tracker.reset("val")
    loader = dataset.val_dataloader
    for _ in range(voting_runs):
        with Ctq(loader) as tq_val_loader:
            for data in tq_val_loader:
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()
                    tracker.track(model, **tracker_options)
                tq_val_loader.set_postfix(**tracker.get_metrics(), color=COLORS.VAL_COLOR)

    tracker.finalise(**tracker_options)
    tracker.print_summary()