def test(model: BaseModel, dataset, device, tracker: BaseTracker, checkpoint: ModelCheckpoint, log):
    """Run one evaluation pass over the test split and checkpoint on the resulting metrics.

    The model is switched to eval mode and is NOT switched back here, so a
    caller that resumes training must re-enable training mode itself.

    Args:
        model: model under evaluation; fed each batch via ``set_input`` then ``forward``.
        dataset: object exposing ``test_dataloader()``; each yielded batch must support ``.to(device)``.
        device: target passed to ``batch.to(...)``.
        tracker: metric tracker, reset to the ``"test"`` stage before the pass.
        checkpoint: receives the published metrics via
            ``save_best_models_under_current_metrics``.
        log: accepted for interface compatibility; not used in this function.
    """
    model.eval()
    tracker.reset("test")

    with Ctq(dataset.test_dataloader()) as progress:
        for batch in progress:
            batch = batch.to(device)
            # Inference only: no autograd graph is needed for evaluation.
            with torch.no_grad():
                model.set_input(batch)
                model.forward()

            tracker.track(model)
            # Show the running metrics on the progress bar.
            progress.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)

    summary = tracker.publish()
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, summary)
# Main training loop: epochs opt.epoch_count .. opt.niter + opt.niter_decay (inclusive).
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()  # start-of-epoch timestamp
    epoch_iter = 0  # samples processed so far this epoch (advanced by opt.batchSize per step)

    for i, data in enumerate(dataset):
        # Restart the per-iteration timer on every step. It used to be set to 0
        # once per epoch and never updated, which made the reported time per
        # sample equal to (seconds since the Unix epoch) / batchSize.
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize

        model.set_input(data)
        model.forward()

        if total_steps % opt.display_freq == 0:
            # save_result is True every opt.update_html_freq steps; passed through to the visualizer.
            save_result = total_steps % opt.update_html_freq == 0
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

        model.optimize_parameters()

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            # Wall-clock seconds per sample for the current iteration.
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            if opt.display_id > 0:
                # Second argument is epoch_iter / dataset_size
                # (presumably the fraction of the epoch completed; assumes dataset_size is the sample count).
                visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)