def main(loss_config="gt_mse", mode="standard", pretrained=False, batch_size=64, **kwargs):
    """Train an rgb->normal UNet under a functional (consistency) loss.

    Args:
        loss_config: key passed to get_functional_loss selecting the loss graph.
        mode: loss mode passed through to get_functional_loss.
        pretrained: if True, start from the pretrained `n` transfer model
            (and use a smaller learning rate); otherwise train a fresh UNet.
        batch_size: per-step batch size for train/val loaders.
        **kwargs: forwarded verbatim to get_functional_loss.
    """
    # MODEL
    # model = DataParallelModel.load(UNet().cuda(), "standardval_rgb2normal_baseline.pth")
    model = functional_transfers.n.load_model() if pretrained else DataParallelModel(UNet())
    # Lower LR when fine-tuning a pretrained net; fresh nets get 3e-4.
    model.compile(torch.optim.Adam, lr=(3e-5 if pretrained else 3e-4),
                  weight_decay=2e-6, amsgrad=True)
    # Decay LR by 0.95 every 5 epochs (milestones 1, 6, 11, ...).
    scheduler = MultiStepLR(model.optimizer,
                            milestones=[5 * i + 1 for i in range(0, 80)], gamma=0.95)

    # FUNCTIONAL LOSS
    functional = get_functional_loss(config=loss_config, mode=mode, model=model, **kwargs)
    print(functional)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda logger, data: model.save(f"{RESULTS_DIR}/model.pth"),
                    feature="loss", freq=400)
    logger.add_hook(lambda logger, data: scheduler.step(), feature="epoch", freq=1)
    # BUG FIX: was `functional.logger_hooks(loggers)` — `loggers` is undefined
    # (NameError at runtime); the hooks must attach to the logger created above.
    functional.logger_hooks(logger)

    # DATA LOADING
    ood_images = load_ood(ood_path=f'{BASE_DIR}/data/ood_images/')
    train_loader, val_loader, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal], batch_size=batch_size,
    )  # train_buildings=["almena"], val_buildings=["almena"])
    test_set, test_images = load_test("rgb", "normal")
    logger.images(test_images, "images", resize=128)
    logger.images(torch.cat(ood_images, dim=0), "ood_images", resize=128)

    # TRAINING
    for epochs in range(0, 800):
        # On the very first epoch of a pretrained run, plot under separate
        # names so the starting point stays visible in visdom.
        preds_name = "start_preds" if epochs == 0 and pretrained else "preds"
        ood_name = "start_ood" if epochs == 0 and pretrained else "ood"
        plot_images(model, logger, test_set, dest_task="normal",
                    ood_images=ood_images, loss_models=functional.plot_losses,
                    preds_name=preds_name, ood_name=ood_name)
        logger.update("epoch", epochs)
        logger.step()

        # islice caps each epoch at a fixed number of loader steps.
        train_set = itertools.islice(train_loader, train_step)
        val_set = itertools.islice(val_loader, val_step)

        # Validate first, then fit — preserves the original epoch ordering.
        val_metrics = model.predict_with_metrics(val_set, loss_fn=functional, logger=logger)
        train_metrics = model.fit_with_metrics(train_set, loss_fn=functional, logger=logger)
        functional.logger_update(logger, train_metrics, val_metrics)
def main():
    """Search the transfer graph from rgb, plotting normal-space predictions
    and per-path MSE bar charts in visdom.

    Runs a depth-first walk over `finetuned_transfers` starting from OOD rgb
    images; whenever a path reaches the `normal` task its prediction is
    plotted, and (in the metric-collecting variants) its MSE against ground
    truth is accumulated into `mse_dict` for a final bar chart per task.
    """
    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda x: logger.step(), feature="loss", freq=25)
    resize = 256

    ood_images = load_ood()[0]
    tasks = [
        get_task(name)
        for name in [
            'rgb', 'normal', 'principal_curvature', 'depth_zbuffer',
            'sobel_edges', 'reshading', 'keypoints3d', 'keypoints2d',
        ]
    ]
    test_loader = torch.utils.data.DataLoader(
        TaskDataset(['almena'], tasks),
        batch_size=64, num_workers=12, shuffle=False, pin_memory=True,
    )
    # One batch of ground truth per task, keyed by task name.
    imgs = list(itertools.islice(test_loader, 1))[0]
    gt = {tasks[i].name: batch.cuda() for i, batch in enumerate(imgs)}
    num_plot = 4

    logger.images(ood_images, f"x", nrow=2, resize=resize)

    edges = finetuned_transfers

    def get_nbrs(task, edges):
        # All transfers leaving `task` in the graph.
        return [e for e in edges if task == e.src_task]

    mse_dict = defaultdict(list)

    def search_small(x, task, prefix, visited, depth, endpoint):
        # Record/plot when we have arrived at `normal`, then keep expanding.
        if task.name == 'normal':
            # Interleave predictions with ground truth for side-by-side view.
            interleave = torch.stack([
                val for pair in zip(x[:num_plot], gt[task.name][:num_plot])
                for val in pair
            ])
            logger.images(interleave.clamp(max=1, min=0), prefix, nrow=2, resize=resize)
            mse, _ = task.loss_func(x, gt[task.name])
            mse_dict[task.name].append((mse.detach().data.cpu().numpy(), prefix))

        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}", next_prefix)
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                search_small(preds, transfer.dest_task, next_prefix,
                             visited, depth + 1, endpoint)
                visited.remove(transfer.dest_task.name)
        return endpoint == task

    def search_full(x, task, prefix, visited, depth, endpoint):
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}", next_prefix)
            if transfer.dest_task.name == 'normal':
                interleave = torch.stack([
                    val for pair in zip(preds[:num_plot],
                                        gt[transfer.dest_task.name][:num_plot])
                    for val in pair
                ])
                logger.images(interleave.clamp(max=1, min=0), next_prefix,
                              nrow=2, resize=resize)
                # NOTE(review): uses the *source* task's loss_func on the
                # dest task's ground truth — looks like it should be
                # transfer.dest_task.loss_func; confirm before changing.
                mse, _ = task.loss_func(preds, gt[transfer.dest_task.name])
                mse_dict[transfer.dest_task.name].append(
                    (mse.detach().data.cpu().numpy(), next_prefix))
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                search_full(preds, transfer.dest_task, next_prefix,
                            visited, depth + 1, endpoint)
                visited.remove(transfer.dest_task.name)
        return endpoint == task

    def search(x, task, prefix, visited, depth):
        # Plot-only variant: no MSE bookkeeping (used for OOD inputs
        # that have no ground truth).
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}", next_prefix)
            if transfer.dest_task.name == 'normal':
                logger.images(preds.clamp(max=1, min=0), next_prefix,
                              nrow=2, resize=resize)
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                search(preds, transfer.dest_task, next_prefix, visited, depth + 1)
                visited.remove(transfer.dest_task.name)

    with torch.no_grad():
        # search_full(gt['rgb'], TASK_MAP['rgb'], 'x', {'rgb'}, 1, TASK_MAP['normal'])
        # BUG FIX: was set('rgb'), which builds {'r', 'g', 'b'} — the visited
        # set holds task *names*, so 'rgb' was never marked visited and the
        # walk could cycle back through rgb. {'rgb'} marks the start task.
        search(ood_images, get_task('rgb'), 'x', {'rgb'}, 1)

    # Render one sorted MSE bar chart per destination task.
    for name, mse_list in mse_dict.items():
        mse_list.sort()
        print(name)
        print(mse_list)
        if len(mse_list) == 1:
            # logger.bar needs at least two rows; pad with a placeholder.
            mse_list.append((0, '-'))
        rownames = [pair[1] for pair in mse_list]
        data = [pair[0] for pair in mse_list]
        print(data, rownames)
        logger.bar(data, f'{name}_path_mse', opts={'rownames': rownames})