def run_eval_suite(name, dest_task=tasks.normal, graph_file=None, model_file=None, logger=None, sample=800, show_images=False, old=False):
    """Evaluate one rgb -> dest_task model on the "almena" validation split.

    Exactly one model source is used, in priority order:
      graph_file  -> load the rgb->dest_task edge out of a saved TaskGraph;
      old         -> legacy UNetOld checkpoint at model_file;
      model_file  -> UNet(downsample=6) checkpoint;
      otherwise   -> pretrained Transfer(normal -> dest_task).

    Args:
        name: label prepended to the result line sent to the logger.
        dest_task: prediction target task.
        graph_file: optional TaskGraph weights path.
        model_file: optional standalone model checkpoint path.
        logger: logger with a .text() method (required; no default).
        sample: number of validation samples to evaluate.
        show_images: unused here; kept for interface compatibility.
        old: load the legacy UNetOld architecture.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif model_file is not None:
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.normal, dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)
    dataset = ValidationMetrics("almena", dest_task=dest_task)
    # BUG FIX: was hard-coded `sample=800`, silently ignoring the `sample` argument.
    result = dataset.evaluate(model, sample=sample)
    logger.text(name + ": " + str(result))
def run_viz_suite(name, data, dest_task=tasks.depth_zbuffer, graph_file=None, model_file=None, logger=None, old=False, multitask=False, percep_mode=None):
    """Run a model over `data` and return clamped 3-channel predictions.

    Model selection mirrors run_eval_suite, plus `multitask` for a
    6-output-channel UNet. When `percep_mode` is truthy, predictions are
    additionally pushed through a dest_task -> normal "perceptual" transfer.

    Args:
        name: label for this run (unused in computation; kept for callers).
        data: batched input passed straight to model.predict.
        dest_task: prediction target task.
        graph_file / model_file / old / multitask: checkpoint selection, see body.
        logger: unused here; kept for interface compatibility.
        percep_mode: if truthy, return percep-model outputs instead.

    Returns:
        Tensor of predictions clamped to [0, 1], always 3 channels.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # FIX: removed leftover debug print('here').
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # DATA LOADING 1: keep only the last 3 channels (multitask nets emit 6).
    results = model.predict(data)[:, -3:].clamp(min=0, max=1)
    if results.shape[1] == 1:
        # Broadcast single-channel output to 3 channels for visualization.
        results = torch.cat([results] * 3, dim=1)

    if percep_mode:
        # Re-batch predictions through a dest_task -> normal perceptual model.
        percep_model = Transfer(src_task=dest_task, dest_task=tasks.normal).load_model()
        percep_model.eval()
        eval_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(results),
            batch_size=16, num_workers=16, shuffle=False, pin_memory=True,
        )
        final_preds = []
        for preds, in eval_loader:
            final_preds += [percep_model.forward(preds[:, -3:])]
        results = torch.cat(final_preds, dim=0)

    return results
def main():
    """Visualize rgb -> X -> normal predictions on an out-of-distribution image set."""
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="energy", freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy", freq=100)

    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]
    # OOD reality: rgb image pairs resized to 256x256 (no ground-truth labels).
    reality = RealityTask('ood',
                          dataset=ImagePairDataset(data_dir=OOD_DIR, resize=(256, 256)),
                          tasks=[tasks.rgb, tasks.rgb],
                          batch_size=28)
    # reality = RealityTask('almena',
    #     dataset=TaskDataset(
    #         buildings=['almena'],
    #         tasks=task_list,
    #     ),
    #     tasks=task_list,
    #     batch_size=8
    # )
    graph = TaskGraph(tasks=[reality, *task_list], batch_size=28)

    task = tasks.rgb
    images = [reality.task_data[task]]
    sources = [task.name]
    # Walk rgb's outgoing edges in alphabetical order of destination-task name.
    for _, edge in sorted(
            ((edge.dest_task.name, edge) for edge in graph.adj[task])):
        if isinstance(edge.src_task, RealityTask): continue
        # NOTE(review): bare attribute access below has no visible effect —
        # presumably forces lazy loading of the reality data; confirm before removing.
        reality.task_data[edge.src_task]
        x = edge(reality.task_data[edge.src_task])
        if edge.dest_task != tasks.normal:
            # Chain a second edge to map non-normal predictions into normal space.
            edge2 = graph.edge_map[(edge.dest_task.name, tasks.normal.name)]
            x = edge2(x)
        images.append(x.clamp(min=0, max=1))
        sources.append(edge.dest_task.name)

    logger.images_grouped(images, ", ".join(sources), resize=256)
def run_perceptual_eval_suite(name, intermediate_task=tasks.normal, dest_task=tasks.normal, graph_file=None, model_file=None, logger=None, sample=800, show_images=False, old=False, perceptual_transfer=None, multitask=False):
    """Evaluate rgb -> intermediate_task composed with a perceptual model to dest_task.

    Model selection for the rgb -> intermediate_task stage mirrors
    run_eval_suite/run_viz_suite (graph_file > old > multitask > model_file >
    pretrained Transfer).

    Args:
        name: label prepended to the result line sent to the logger.
        intermediate_task: output task of the first-stage model.
        dest_task: target task of the perceptual evaluation.
        graph_file / model_file / old / multitask: checkpoint selection, see body.
        logger: logger with a .text() method (required; no default).
        sample: number of validation samples to evaluate.
        show_images: unused here; kept for interface compatibility.
        perceptual_transfer: if None, a default intermediate->dest transfer is built.
    """
    if perceptual_transfer is None:
        percep_model = Transfer(src_task=intermediate_task, dest_task=dest_task).load_model()
    # NOTE(review): when perceptual_transfer is NOT None, percep_model is never
    # assigned and the evaluate_with_percep call below raises NameError — the
    # intended behavior (percep_model = perceptual_transfer?) needs confirming.

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, intermediate_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, intermediate_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        print('running multitask')
        model = DataParallelModel.load(UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb, dest_task=intermediate_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)
    dataset = ValidationMetrics("almena", dest_task=dest_task)
    # BUG FIX: was hard-coded `sample=800`, silently ignoring the `sample` argument.
    result = dataset.evaluate_with_percep(model, sample=sample, percep_model=percep_model)
    logger.text(name + ": " + str(result))
def main():
    """Visualize, for each task, ground truth plus every first-order transfer
    prediction on the 'almena' reality."""
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="energy", freq=16)
    logger.add_hook(lambda logger, data: logger.plot(data["energy"], "free_energy"),
                    feature="energy", freq=100)

    task_list = [
        tasks.rgb, tasks.normal, tasks.principal_curvature,
        tasks.sobel_edges, tasks.depth_zbuffer, tasks.reshading,
        tasks.edge_occlusion, tasks.keypoints3d, tasks.keypoints2d,
    ]
    reality = RealityTask(
        'almena',
        dataset=TaskDataset(buildings=['almena'], tasks=task_list),
        tasks=task_list,
        batch_size=8,
    )
    graph = TaskGraph(tasks=[reality, *task_list], batch_size=8)

    for dest in graph.tasks:
        if isinstance(dest, RealityTask):
            continue
        # Ground truth first, then one prediction per incoming edge,
        # ordered alphabetically by source-task name.
        grid = [reality.task_data[dest].clamp(min=0, max=1)]
        labels = [dest.name]
        incoming = sorted((e.src_task.name, e) for e in graph.in_adj[dest])
        for src_name, transfer in incoming:
            if isinstance(transfer.src_task, RealityTask):
                continue
            pred = transfer(reality.task_data[transfer.src_task])
            grid.append(pred.clamp(min=0, max=1))
            labels.append(src_name)
        logger.images_grouped(grid, ", ".join(labels), resize=192)
def main(
    loss_config="conservative_full", mode="standard", visualize=False,
    pretrained=True, finetuned=False, fast=False, batch_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    cont_gan=None, pre_gan=None, use_patches=False, patch_size=64,
    use_baseline=False, **kwargs,
):
    """Adversarial consistency training: task graph + GAN discriminator.

    For the first `pre_gan` epochs only the discriminator is trained (5 inner
    steps per batch); afterwards the graph takes a generator step followed by
    one discriminator step per batch.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned,
                      freeze_list=energy_loss.freeze_list)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not use_baseline and not USE_RAID:
        # Warm-start from the "conservative" checkpoint unless running the baseline.
        graph.load_weights(cont)
    pre_gan = pre_gan or 1
    discriminator = Discriminator(energy_loss.losses['gan'],
                                  size=(patch_size if use_patches else 224),
                                  use_patches=use_patches)
    # if cont_gan is not None: discriminator.load_weights(cont_gan)

    # LOGGING: checkpoint graph and discriminator once per epoch.
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    logger.add_hook(
        lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"),
        feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, 80):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.train()
        discriminator.train()
        for _ in range(0, train_step):
            if epochs > pre_gan:
                # Generator step (only after the discriminator warm-up phase).
                energy_loss.train_iter += 1
                train_loss = energy_loss(graph, discriminator=discriminator, realities=[train])
                train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)
            # Discriminator step(s): 5 per batch during warm-up, 1 afterwards.
            for i in range(5 if epochs <= pre_gan else 1):
                train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train])
                discriminator.step(train_loss2)
                train.step()

        graph.eval()
        discriminator.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, discriminator=discriminator, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
def main(
    loss_config="trainsig_edgereshade", mode="standard", visualize=False,
    fast=False, batch_size=32,
    learning_rate=3e-5, resume=False,
    subset_size=None, max_epochs=800, dataaug=False, **kwargs,
):
    """Train on undistorted + distorted ("sig") data with wandb logging.

    Per train batch the losses from the undistorted and distorted realities are
    summed into one graph step. Validation runs on val_dist every epoch and on
    val/val_ooddist every 20 epochs; checkpoints every 10 epochs.
    """
    # CONFIG
    wandb.config.update({"loss_config":loss_config,"batch_size":batch_size,"lr":learning_rate})
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_undist_dataset, train_dist_dataset, val_ooddist_dataset, val_dist_dataset, val_dataset, train_step, val_step = load_train_val_sig(
        energy_loss.get_tasks("val"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"), ood_path='./assets/ood_natural/')
    ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"), ood_path='./assets/st_syn_distortions/')
    ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"), ood_path='./assets/ood_syn_distortions/', sample=35)
    train_undist = RealityTask("train_undist", train_undist_dataset, batch_size=batch_size, shuffle=True)
    train_dist = RealityTask("train_dist", train_dist_dataset, batch_size=batch_size, shuffle=True)
    val_ooddist = RealityTask("val_ooddist", val_ooddist_dataset, batch_size=batch_size, shuffle=True)
    val_dist = RealityTask("val_dist", val_dist_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])  ## standard ood set - natural
    ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb,])  ## synthetic distortion images used for sig training
    ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb,])  ## unseen syn distortions

    # GRAPH
    realities = [train_undist, train_dist, val_ooddist, val_dist, val, test, ood, ood_syn_aug, ood_syn]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
                      freeze_list=energy_loss.freeze_list,
                      )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if resume:
        # Restore both weights and optimizer state from a fixed workspace path.
        graph.load_weights('/workspace/shared/results_test_1/graph.pth')
        graph.optimizer.load_state_dict(torch.load('/workspace/shared/results_test_1/opt.pth'))
    # else:
    #     folder_name='/workspace/shared/results_wavelet2normal_depthreshadecurvimgnetl1perceps_0.1nll/'
    #     # pdb.set_trace()
    #     in_domain='wav'
    #     out_domain='normal'
    #     graph.load_weights(folder_name+'graph.pth', [str((in_domain, out_domain))])
    #     create_t0_graph(folder_name,in_domain,out_domain)
    #     graph.load_weights(folder_name+'graph_t0.pth', [str((in_domain, f'{out_domain}_t0'))])

    # LOGGING
    logger = VisdomLogger("train", env=JOB)  # fake visdom logger
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    # BASELINE: log pre-training losses (step=0 in wandb) when starting fresh.
    if not resume:
        graph.eval()
        with torch.no_grad():
            for reality in [val_ooddist,val_dist,val]:
                for _ in range(0, val_step):
                    val_loss = energy_loss(graph, realities=[reality])
                    val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                    reality.step()
                    logger.update("loss", val_loss)
            for reality in [train_undist,train_dist]:
                for _ in range(0, train_step):
                    train_loss = energy_loss(graph, realities=[reality])
                    train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
                    reality.step()
                    logger.update("loss", train_loss)
        energy_loss.logger_update(logger)
        data=logger.step()
        del data['loss']
        # Each logged series holds a single value at this point; unwrap it.
        data = {k:v[0] for k,v in data.items()}
        wandb.log(data, step=0)
        path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
        for reality_paths, reality_images in path_values.items():
            wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        graph.train()
        for _ in range(0, train_step):
            # Combined loss: undistorted (nll) + distorted (lwf/sig) realities.
            train_loss_nll = energy_loss(graph, realities=[train_undist])
            train_loss_nll = sum([train_loss_nll[loss_name] for loss_name in train_loss_nll])
            train_loss_lwfsig = energy_loss(graph, realities=[train_dist])
            train_loss_lwfsig = sum([train_loss_lwfsig[loss_name] for loss_name in train_loss_lwfsig])
            train_loss = train_loss_nll+train_loss_lwfsig
            graph.step(train_loss)
            train_undist.step()
            train_dist.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val_dist])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val_dist.step()
            logger.update("loss", val_loss)

        # Extra validation on the clean and ood-distorted splits every 20 epochs.
        if epochs % 20 == 0:
            for reality in [val,val_ooddist]:
                for _ in range(0, val_step):
                    with torch.no_grad():
                        val_loss = energy_loss(graph, realities=[reality])
                        val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                    reality.step()
                    logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        data=logger.step()
        del data['loss']
        data = {k:v[0] for k,v in data.items()}
        wandb.log(data, step=epochs+1)

        # Periodic checkpointing of weights + optimizer state.
        if epochs % 10 == 0:
            graph.save(f"{RESULTS_DIR}/graph.pth")
            torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth")

        # Periodic qualitative plots (denser early in training).
        if (epochs % 100 == 0) or (epochs % 15 == 0 and epochs <= 30):
            path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
            for reality_paths, reality_images in path_values.items():
                wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs+1)

    # Final checkpoint after the last epoch.
    graph.save(f"{RESULTS_DIR}/graph.pth")
    torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth")
def main(
    loss_config="conservative_full", mode="standard", visualize=False,
    fast=False, batch_size=None,
    use_optimizer=False, subset_size=None, max_epochs=800,
    **kwargs,
):
    """Training run that warm-starts the rgb->normal edge from a fixed checkpoint.

    Validation is run BEFORE the train loop each epoch; steps are quartered
    relative to load_train_val's defaults.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    print(kwargs)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast, subset_size=subset_size)
    # Quarter the per-epoch step counts.
    train_step, val_step = train_step // 4, val_step // 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood([tasks.rgb,])
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
                      freeze_list=energy_loss.freeze_list,
                      )
    # Replace the rgb->normal edge weights with the SAMPLEFF baseline checkpoint.
    graph.edge(tasks.rgb, tasks.normal).model = None
    graph.edge(tasks.rgb, tasks.normal).path = "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/n.pth"
    graph.edge(tasks.rgb, tasks.normal).load_model()
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    if use_optimizer:
        # Restore optimizer state saved as a whole optimizer object.
        optimizer = torch.load("mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/optimizer.pth")
        graph.optimizer.load_state_dict(optimizer.state_dict())

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
        if visualize: return

        # Validate first, then train.
        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        energy_loss.logger_update(logger)
        # if logger.data["val_mse : y^ -> n(~x)"][-1] < best_ood_val_loss:
        #     best_ood_val_loss = logger.data["val_mse : y^ -> n(~x)"][-1]
        #     energy_loss.plot_paths(graph, logger, realities, prefix="best")
        logger.step()
def main(
    mode="standard", visualize=False,
    pretrained=True, finetuned=False, batch_size=None,
    **kwargs,
):
    """Generate VISUALS3 qualitative plots across image sizes for a table of configs.

    Outer loop: i indexes the 5 image sizes (256..512) shared by every config's
    loss_configs list. Inner loop: each named config builds fresh graphs and plots
    to its own visdom env (base / "_ood" / "_oodfull").
    """
    configs = {
        "VISUALS3_rgb2normals2x_multipercep8_winrate_standardized_upd": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            cont="mount/shared/results_LBP_multipercep8_winrate_standardized_upd_3/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_latwinrate_reshadetarget": dict(
            loss_configs=["baseline_reshade_size256", "baseline_reshade_size320", "baseline_reshade_size384", "baseline_reshade_size448", "baseline_reshade_size512"],
            cont="mount/shared/results_LBP_multipercep_latwinrate_reshadingtarget_6/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_reshadebaseline": dict(
            loss_configs=["baseline_reshade_size256", "baseline_reshade_size320", "baseline_reshade_size384", "baseline_reshade_size448", "baseline_reshade_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_latwinrate_depthtarget": dict(
            loss_configs=["baseline_depth_size256", "baseline_depth_size320", "baseline_depth_size384", "baseline_depth_size448", "baseline_depth_size512"],
            cont="mount/shared/results_LBP_multipercep_latwinrate_reshadingtarget_6/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_depthbaseline": dict(
            loss_configs=["baseline_depth_size256", "baseline_depth_size320", "baseline_depth_size384", "baseline_depth_size448", "baseline_depth_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2normals2x_baseline": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2normals2x_multipercep": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            test=True, ood=True, oodfull=False,
            cont="mount/shared/results_LBP_multipercep_32/graph.pth",
        ),
        "VISUALS3_rgb2x2normals_baseline": dict(
            loss_configs=["rgb2x2normals_plots", "rgb2x2normals_plots_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
            finetuned=False, test=True, ood=True, ood_full=False,
        ),
        "VISUALS3_rgb2x2normals_finetuned": dict(
            loss_configs=["rgb2x2normals_plots", "rgb2x_plots2normals_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
            finetuned=True, test=True, ood=True, ood_full=False,
        ),
        "VISUALS3_rgb2x_baseline": dict(
            loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
            finetuned=False, test=True, ood=True, ood_full=False,
        ),
        "VISUALS3_rgb2x_finetuned": dict(
            loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
            finetuned=True, test=True, ood=True, ood_full=False,
        ),
    }
    # configs = {
    #     "VISUALS_rgb2normals2x_latv2": dict(
    #         loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #         cont="mount/shared/results_LBP_multipercep_latv2_10/graph.pth",
    #     ),
    #     "VISUALS_rgb2normals2x_lat_winrate": dict(
    #         loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #         cont="mount/shared/results_LBP_multipercep_lat_winrate_8/graph.pth",
    #     ),
    #     "VISUALS_rgb2normals2x_multipercep": dict(
    #         loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #         cont="mount/shared/results_LBP_multipercep_32/graph.pth",
    #     ),
    #     "VISUALS_rgb2normals2x_rndv2": dict(
    #         loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #         cont="mount/shared/results_LBP_multipercep_rnd_11/graph.pth",
    #     ),
    #     "VISUALS_rgb2normals2x_baseline": dict(
    #         loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #         cont=None,
    #     ),
    #     "VISUALS_rgb2x2normals_baseline": dict(
    #         loss_configs=["rgb2x2normals_plots", "rgb2x2normals_plots_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
    #         finetuned=False,
    #     ),
    #     "VISUALS_rgb2x2normals_finetuned": dict(
    #         loss_configs=["rgb2x2normals_plots", "rgb2x2normals_plots_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
    #         finetuned=True,
    #     ),
    #     "VISUALS_y2normals_baseline": dict(
    #         loss_configs=["y2normals_plots", "y2normals_plots_size320", "y2normals_plots_size384", "y2normals_plots_size448", "y2normals_plots_size512"],
    #         finetuned=False,
    #     ),
    #     "VISUALS_y2normals_finetuned": dict(
    #         loss_configs=["y2normals_plots", "y2normals_plots_size320", "y2normals_plots_size384", "y2normals_plots_size448", "y2normals_plots_size512"],
    #         finetuned=True,
    #     ),
    #     "VISUALS_rgb2x_baseline": dict(
    #         loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
    #         finetuned=False,
    #     ),
    #     "VISUALS_rgb2x_finetuned": dict(
    #         loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
    #         finetuned=True,
    #     ),
    # }
    for i in range(0, 5):
        # NOTE(review): the energy loss / datasets below are always built from the
        # FIRST config's loss_configs[i] — presumably intentional because every
        # config shares the same size schedule; confirm before relying on it.
        config = configs[list(configs.keys())[0]]
        finetuned = config.get("finetuned", False)
        loss_configs = config["loss_configs"]
        loss_config = loss_configs[i]
        batch_size = batch_size or 32
        energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

        # DATA LOADING 1: standard test + ood realities for this size.
        test_set = load_test(energy_loss.get_tasks("test"), sample=8)
        ood_tasks = [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']
        ood_set = load_ood(ood_tasks, sample=4)
        print (ood_tasks)
        test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
        ood = RealityTask.from_static("ood", ood_set, ood_tasks)

        # DATA LOADING 2: rgb-only variants for the "_ood" plots.
        ood_tasks = list(set([tasks.rgb] + [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']))
        test_set = load_test(ood_tasks, sample=2)
        ood_set = load_ood(ood_tasks)
        test2 = RealityTask.from_static("test", test_set, ood_tasks)
        ood2 = RealityTask.from_static("ood", ood_set, ood_tasks)

        # DATA LOADING 3: raw ood images from disk for the "_oodfull" plots.
        test_set = load_test(energy_loss.get_tasks("test"), sample=8)
        ood_tasks = [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']
        ood_loader = torch.utils.data.DataLoader(
            ImageDataset(tasks=ood_tasks, data_dir=f"{SHARED_DIR}/ood_images"),
            batch_size=32, num_workers=32, shuffle=False, pin_memory=True
        )
        # Take the first two batches: one as "test", one as "ood".
        data = list(itertools.islice(ood_loader, 2))
        test_set = data[0]
        ood_set = data[1]
        test3 = RealityTask.from_static("test", test_set, ood_tasks)
        ood3 = RealityTask.from_static("ood", ood_set, ood_tasks)

        for name, config in configs.items():
            finetuned = config.get("finetuned", False)
            loss_configs = config["loss_configs"]
            cont = config.get("cont", None)

            # Clear each visdom env only on the first size iteration.
            logger = VisdomLogger("train", env=name, delete=True if i == 0 else False)
            if config.get("test", False):
                # GRAPH
                realities = [test, ood]
                print ("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None:
                    graph.load_weights(cont)
                # LOGGING
                energy_loss.plot_paths_errors(graph, logger, realities, prefix=loss_config)

            logger = VisdomLogger("train", env=name + "_ood", delete=True if i == 0 else False)
            if config.get("ood", False):
                # GRAPH
                realities = [test2, ood2]
                print ("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None:
                    graph.load_weights(cont)
                energy_loss.plot_paths(graph, logger, realities, prefix=loss_config)

            logger = VisdomLogger("train", env=name + "_oodfull", delete=True if i == 0 else False)
            if config.get("oodfull", False):
                # GRAPH
                realities = [test3, ood3]
                print ("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None:
                    graph.load_weights(cont)
                energy_loss.plot_paths(graph, logger, realities, prefix=loss_config)
def main():
    """Free-energy minimization over an anchored task graph on 'almena' data."""
    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]
    # Reality provides ground truth for only 4 of the tasks.
    reality = RealityTask('almena',
                          dataset=TaskDataset(buildings=['almena'],
                                              tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.depth_zbuffer]),
                          tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.depth_zbuffer],
                          batch_size=4)
    graph = TaskGraph(
        tasks=[reality, *task_list],
        anchored_tasks=[reality, tasks.rgb],
        reality=reality,
        batch_size=4,
        # Exclude direct reality->label edges so those estimates stay latent.
        edges_exclude=[
            ('almena', 'normal'),
            ('almena', 'principal_curvature'),
            ('almena', 'depth_zbuffer'),
            # ('rgb', 'keypoints3d'),
            # ('rgb', 'edge_occlusion'),
        ],
        initialize_first_order=True,
    )
    # Separate optimizers for edge probabilities (p) and task estimates.
    graph.p.compile(torch.optim.Adam, lr=4e-2)
    graph.estimates.compile(torch.optim.Adam, lr=1e-2)

    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="energy", freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy", freq=100)
    logger.add_hook(lambda logger, data: graph.plot_estimates(logger), feature="epoch", freq=32)
    logger.add_hook(lambda logger, data: graph.update_paths(logger), feature="epoch", freq=32)

    graph.plot_estimates(logger)
    graph.plot_paths(logger,
                     dest_tasks=[tasks.normal, tasks.principal_curvature, tasks.depth_zbuffer],
                     show_images=False)

    # Gradient-free optimization: sample free energy, then take an averaging step.
    for epochs in range(0, 4000):
        logger.update("epoch", epochs)
        with torch.no_grad():
            free_energy = graph.free_energy(sample=16)
            graph.averaging_step()
        logger.update("energy", free_energy)
def main(
    fast=False, batch_size=None,
    **kwargs,
):
    """Render a house-tour video through four graphs with a zoom-out/in schedule.

    For each frame batch: resize per a piecewise zoom schedule, save the
    distorted input, then save direct (rgb->normal) and two-hop
    (rgb->curvature->normal) predictions per graph, recording the per-frame
    consistency energy between the two paths.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 32)
    energy_loss = get_energy_loss(config="consistency_two_path", mode="standard", **kwargs)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)

    # DATA LOADING: frames sorted by their numeric index in "image<N>.png".
    video_dataset = ImageDataset(
        files=sorted(
            glob.glob(f"mount/taskonomy_house_tour/original/image*.png"),
            key=lambda x: int(os.path.basename(x)[5:-4])),
        return_tuple=True,
        resize=720,
    )
    video = RealityTask("video", video_dataset, [tasks.rgb,],
                        batch_size=batch_size, shuffle=False)

    # GRAPHS: four variants differing in finetuning / loaded weights.
    graph_baseline = TaskGraph(tasks=energy_loss.tasks + [video], finetuned=False)
    graph_baseline.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    graph_finetuned = TaskGraph(tasks=energy_loss.tasks + [video], finetuned=True)
    graph_finetuned.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    graph_conservative = TaskGraph(tasks=energy_loss.tasks + [video], finetuned=True)
    graph_conservative.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph_conservative.load_weights(f"{MODELS_DIR}/conservative/conservative.pth")

    graph_ood_conservative = TaskGraph(tasks=energy_loss.tasks + [video], finetuned=True)
    graph_ood_conservative.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph_ood_conservative.load_weights(
        f"{SHARED_DIR}/results_2F_grounded_1percent_gt_twopath_512_256_crop_7/graph_grounded_1percent_gt_twopath.pth"
    )

    graphs = {
        "baseline": graph_baseline,
        "finetuned": graph_finetuned,
        "conservative": graph_conservative,
        "ood_conservative": graph_ood_conservative,
    }
    inv_transform = transforms.ToPILImage()
    data = {key: {"losses": [], "zooms": []} for key in graphs}
    size = 256
    for batch in range(0, 700):
        if batch * batch_size > len(video_dataset.files):
            break
        # Zoom schedule over the video: 256 -> 128 (first 30%),
        # 128 -> 256 (next 20%), 256 -> 720 (final 50%).
        frac = (batch * batch_size * 1.0) / len(video_dataset.files)
        if frac < 0.3:
            size = int(256.0 - 128 * frac / 0.3)
        elif frac < 0.5:
            size = int(128.0 + 128 * (frac - 0.3) / 0.2)
        else:
            size = int(256.0 + (720 - 256) * (frac - 0.5) / 0.5)
        print(size)
        # video.reload()
        size = (size // 32) * 32  # snap to a multiple of 32 for the networks
        print(size)
        video.step()
        video.task_data[tasks.rgb] = resize(
            video.task_data[tasks.rgb].to(DEVICE), size).data
        print(video.task_data[tasks.rgb].shape)

        with torch.no_grad():
            # Save the (resized) input frames.
            for i, img in enumerate(video.task_data[tasks.rgb]):
                inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                    f"mount/taskonomy_house_tour/distorted/image{batch*batch_size + i}.png"
                )
            for name, graph in graphs.items():
                normals = graph.sample_path([tasks.rgb, tasks.normal], reality=video)
                normals2 = graph.sample_path(
                    [tasks.rgb, tasks.principal_curvature, tasks.normal], reality=video)
                for i, img in enumerate(normals):
                    # Consistency energy between the direct and two-hop paths.
                    energy, _ = tasks.normal.norm(normals[i:(i + 1)], normals2[i:(i + 1)])
                    data[name]["losses"] += [energy.data.cpu().numpy().mean()]
                    data[name]["zooms"] += [size]
                    inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                        f"mount/taskonomy_house_tour/normals_{name}/image{batch*batch_size + i}.png"
                    )
                for i, img in enumerate(normals2):
                    inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                        f"mount/taskonomy_house_tour/path2_{name}/image{batch*batch_size + i}.png"
                    )
    pickle.dump(data, open(f"mount/taskonomy_house_tour/data.pkl", 'wb'))
    os.system("bash ~/scaling/scripts/create_vids.sh")
def main(
    loss_config="baseline", mode="standard", visualize=False,
    fast=False, batch_size=None, path=None,
    subset_size=None, early_stopping=float('inf'), max_epochs=800, **kwargs
):
    """Train a task graph whose rgb->normal edge is loaded from `path`,
    with early stopping on the "val_mse : n(x) -> y^" metric.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    print('tasks', energy_loss.get_tasks("ood"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=False,
        freeze_list=energy_loss.freeze_list,
    )
    # Swap in the rgb->normal model stored at `path`.
    graph.edge(tasks.rgb, tasks.normal).model = None
    graph.edge(tasks.rgb, tasks.normal).path = path
    graph.edge(tasks.rgb, tasks.normal).load_model()
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)
    graph.save(weights_dir=f"{RESULTS_DIR}")

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    best_val_loss, stop_idx = float('inf'), 0
    # return
    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
        if visualize:
            return

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(train_loss[loss_name] for loss_name in train_loss)
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum(val_loss[loss_name] for loss_name in val_loss)
                val.step()
                logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        # Early stopping: reset the counter whenever the tracked metric improves.
        stop_idx += 1
        if logger.data["val_mse : n(x) -> y^"][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = logger.data["val_mse : n(x) -> y^"][-1], 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(weights_dir=f"{RESULTS_DIR}")
        if stop_idx >= early_stopping:
            print("Stopping training now")
            return
def main(
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    max_epochs=800,
    **kwargs,
):
    """Repeatedly sample random length-3 transfer paths over the task list and
    train them against ground truth and length-2 baselines, evaluating all
    candidate paths with an eval-MSE energy each round.
    """
    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.depth_zbuffer,
        # tasks.sobel_edges,
    ]
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list,
        batch_size=batch_size,
        fast=fast,
    )
    test_set = load_test(task_list)
    train_step, val_step = train_step // 4, val_step // 4
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])

    # GRAPH
    realities = [train, val, test]
    graph = TaskGraph(tasks=task_list + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch", freq=1)
    logger.add_hook(partial(jointplot, loss_type=f"energy"),
                    feature=f"val_energy", freq=1)

    def path_name(path):
        # Human-readable name for a transfer path, or None if any hop
        # is not an edge of the graph.
        if len(path) == 1:
            print(path)
            return str(path[0])
        if str((path[-2].name, path[-1].name)) not in graph.edge_map:
            return None
        sub_name = path_name(path[:-1])
        if sub_name is None:
            return None
        return f'{get_transfer_name(Transfer(path[-2], path[-1]))}({sub_name})'

    for i in range(0, 20):
        gt_paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 1)
            if path_name(path) is not None
        }
        baseline_paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 2)
            if path_name(path) is not None
        }
        paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 3)
            if path_name(path) is not None
        }
        # FIX: random.sample() requires a sequence; sampling directly from
        # dict.items() was deprecated in Python 3.9 and raises TypeError on
        # Python >= 3.11 — materialize the items first.
        selected_paths = dict(random.sample(list(paths.items()), k=3))
        print("Chosen paths: ", selected_paths)

        loss_config = {
            "paths": {**gt_paths, **baseline_paths, **paths},
            "losses": {
                "baseline_mse": {
                    ("train", "val",): [
                        (path_name, str(path[-1]))
                        for path_name, path in baseline_paths.items()
                    ]
                },
                "mse": {
                    ("train", "val",): [
                        (path_name, str(path[-1]))
                        for path_name, path in selected_paths.items()
                    ]
                },
                "eval_mse": {
                    ("val",): [
                        (path_name, str(path[-1]))
                        for path_name, path in paths.items()
                    ]
                },
            },
            "plots": {
                "ID": dict(
                    size=256,
                    realities=("test",),
                    paths=[path_name for path_name, path in selected_paths.items()]
                    + [str(path[-1]) for path_name, path in selected_paths.items()]),
            },
        }
        energy_loss = EnergyLoss(**loss_config)
        energy_loss.logger_hooks(logger)

        # TRAINING
        for epochs in range(0, 5):
            logger.update("epoch", epochs)
            energy_loss.plot_paths(graph, logger, realities, prefix="")

            graph.train()
            for _ in range(0, train_step):
                train_loss = energy_loss(graph, realities=[train],
                                         loss_types=["mse", "baseline_mse"])
                train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            graph.eval()
            for _ in range(0, val_step):
                with torch.no_grad():
                    val_loss = energy_loss(graph, realities=[val],
                                           loss_types=["mse", "baseline_mse"])
                    val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                val.step()
                logger.update("loss", val_loss)

            # Evaluate every candidate path; both series get the same value so
            # the jointplot hook has a train and a val curve to draw.
            val_loss = energy_loss(graph, realities=[val], loss_types=["eval_mse"])
            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("train_energy", val_loss)
            logger.update("val_energy", val_loss)

            energy_loss.logger_update(logger)
            logger.step()
def main(
    job_config="jobinfo.txt",
    models_dir="models",
    fast=False,
    batch_size=None,
    subset_size=None,
    max_epochs=500,
    dataaug=False,
    **kwargs,
):
    """Run a training job described by a jobinfo config file.

    The config file (comma-newline separated: loss_config, loss_mode,
    model_class, experiment, base_dir) overrides module-level defaults.
    Records a no-training baseline pass, then trains with per-loss gradient
    balancing (compute_grad_ratio).
    """
    loss_config, loss_mode, model_class = None, None, None
    experiment, base_dir = None, None
    current_dir = os.path.dirname(__file__)
    job_config = os.path.normpath(
        os.path.join(os.path.join(current_dir, "config"), job_config))
    if os.path.isfile(job_config):
        with open(job_config) as config_file:
            out = config_file.read().strip().split(',\n')
            loss_config, loss_mode, model_class, experiment, base_dir = out
    # Fall back to module-level defaults for anything the file left unset.
    loss_config = loss_config or LOSS_CONFIG
    loss_mode = loss_mode or LOSS_MODE
    model_class = model_class or MODEL_CLASS
    base_dir = base_dir or BASE_DIR
    base_dir = os.path.normpath(os.path.join(current_dir, base_dir))
    experiment = experiment or EXPERIMENT
    job = "_".join(experiment.split("_")[0:-1])
    models_dir = os.path.join(base_dir, models_dir)
    results_dir = f"{base_dir}/results/results_{experiment}"
    results_dir_models = f"{base_dir}/results/results_{experiment}/models"

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, loss_mode=loss_mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
        dataaug=dataaug,
    )
    if fast:
        train_dataset = val_dataset
        train_step, val_step = 2, 2
    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test_set = load_test(energy_loss.get_tasks("test"),
                         buildings=['almena', 'albertville', 'espanola'])
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb])
    realities = [train, val, test, ood]

    # GRAPH
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        tasks_in=energy_loss.tasks_in,
        tasks_out=energy_loss.tasks_out,
        pretrained=True,
        models_dir=models_dir,
        freeze_list=energy_loss.freeze_list,
        direct_edges=energy_loss.direct_edges,
        model_class=model_class,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(results_dir_models, exist_ok=True)
    logger = VisdomLogger("train", env=job, port=PORT, server=SERVER)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(
        lambda _, __: graph.save(f"{results_dir}/graph.pth", results_dir_models),
        feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    energy_loss.plot_paths(graph, logger, realities, prefix="start")

    # BASELINE: measure initial losses without taking optimizer steps.
    graph.eval()
    with torch.no_grad():
        for _ in range(0, val_step * 4):
            val_loss, _ = energy_loss(graph, realities=[val])
            val_loss = sum(val_loss[loss_name] for loss_name in val_loss)
            val.step()
            logger.update("loss", val_loss)
        for _ in range(0, train_step * 4):
            train_loss, _ = energy_loss(graph, realities=[train])
            train_loss = sum(train_loss[loss_name] for loss_name in train_loss)
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="finish")

        graph.train()
        for _ in range(0, train_step):
            train_loss, grad_mse_coeff = energy_loss(
                graph, realities=[train], compute_grad_ratio=True)
            # Step on the per-loss dict first; the scalar sum is only logged.
            graph.step(train_loss, losses=energy_loss.losses, paths=energy_loss.paths)
            train_loss = sum(train_loss[loss_name] for loss_name in train_loss)
            train.step()
            logger.update("loss", train_loss)
            del train_loss
            print(grad_mse_coeff)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _ = energy_loss(graph, realities=[val])
                val_loss = sum(val_loss[loss_name] for loss_name in val_loss)
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
def main(
    loss_config="multiperceptual",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    subset_size=None,
    max_epochs=5000,
    dataaug=False,
    **kwargs,
):
    """Train on merged (distorted + undistorted) data, mirroring metrics and
    path visualizations to Weights & Biases at every epoch.
    """
    # CONFIG
    wandb.config.update({
        "loss_config": loss_config,
        "batch_size": batch_size,
        "data_aug": dataaug,
        "lr": "3e-5",
        "n_gauss": 1,
        "distribution": "laplace",
    })
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, val_noaug_dataset, train_step, val_step = load_train_val_merging(
        energy_loss.get_tasks("train_c"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"),
                       ood_path='./assets/ood_natural/')
    ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"),
                               ood_path='./assets/st_syn_distortions/')
    ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"),
                           ood_path='./assets/ood_syn_distortions/', sample=35)

    train = RealityTask("train_c", train_dataset, batch_size=batch_size, shuffle=True)      # distorted and undistorted
    val = RealityTask("val_c", val_dataset, batch_size=batch_size, shuffle=True)            # distorted and undistorted
    val_noaug = RealityTask("val", val_noaug_dataset, batch_size=batch_size, shuffle=True)  # no augmentation
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb])                              # standard ood set - natural
    ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb])      # synthetic distortion images used for sig training
    ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb])                  # unseen syn distortions

    # GRAPH
    realities = [train, val, val_noaug, test, ood, ood_syn_aug, ood_syn]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)  # fake visdom logger
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    # BASELINE: record losses and visualizations before any training step.
    graph.eval()
    path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
    for reality_paths, reality_images in path_values.items():
        wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0)
    with torch.no_grad():
        for reality in [val, val_noaug]:
            for _ in range(0, val_step):
                val_loss = energy_loss(graph, realities=[reality])
                val_loss = sum(val_loss[loss_name] for loss_name in val_loss)
                reality.step()
                logger.update("loss", val_loss)
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum(train_loss[loss_name] for loss_name in train_loss)
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)
    data = logger.step()
    del data['loss']
    data = {k: v[0] for k, v in data.items()}
    wandb.log(data, step=0)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum(train_loss[loss_name] for loss_name in train_loss)
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for reality in [val, val_noaug]:
            for _ in range(0, val_step):
                with torch.no_grad():
                    val_loss = energy_loss(graph, realities=[reality])
                    val_loss = sum(val_loss[loss_name] for loss_name in val_loss)
                reality.step()
                logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        data = logger.step()
        del data['loss']
        del data['epoch']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=epochs + 1)

        # Periodic checkpoints and (less often) path visualizations.
        if epochs % 10 == 0:
            graph.save(f"{RESULTS_DIR}/graph.pth")
            torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
        if epochs % 25 == 0:
            path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
            for reality_paths, reality_images in path_values.items():
                wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs + 1)

    # Final checkpoint after the last epoch.
    graph.save(f"{RESULTS_DIR}/graph.pth")
    torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
def main(
    loss_config="conservative_full",
    mode="standard",
    pretrained=True,
    finetuned=False,
    batch_size=16,
    ood_batch_size=None,
    subset_size=None,
    cont=None,
    use_l1=True,
    num_workers=32,
    data_dir=None,
    save_dir='mount/shared/',
    **kwargs,
):
    """Compute consistency energies and direct errors over a data subset and
    plot their correlation (z-scored energy vs. z-scored error).

    With `data_dir` set, energies are computed for a custom image set and
    compared against a previously saved `data.csv`; otherwise a fresh CSV of
    per-sample energies/errors is written.
    """
    # CONFIG
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    if data_dir is None:
        buildings = ["almena", "albertville"]
        train_subset_dataset = TaskDataset(
            buildings, tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature])
    else:
        train_subset_dataset = ImageDataset(data_dir=data_dir)
        data_dir = 'CUSTOM'  # sentinel: custom-image mode from here on

    train_subset = RealityTask("train_subset", train_subset_dataset,
                               batch_size=batch_size, shuffle=False)
    if subset_size is None:
        subset_size = len(train_subset_dataset)
    subset_size = min(subset_size, len(train_subset_dataset))

    # GRAPH
    realities = [train_subset]
    # NOTE(review): `edges` is built but never passed anywhere below — kept
    # for fidelity; confirm whether it can be deleted.
    edges = []
    for t in energy_loss.tasks:
        if t != tasks.rgb:
            edges.append((tasks.rgb, t))
    edges.append((tasks.rgb, tasks.normal))
    graph = TaskGraph(
        tasks=energy_loss.tasks + [train_subset],
        finetuned=finetuned,
        freeze_list=energy_loss.freeze_list,
        lazy=True,
        initialize_from_transfer=True,
    )
    # print('file', cont)
    # graph.load_weights(cont)
    graph.compile(optimizer=None)

    # Add consistency links
    for target in ['reshading', 'depth_zbuffer', 'normal']:
        graph.edge_map[str(('rgb', target))].path = None
        graph.edge_map[str(('rgb', target))].load_model()
    graph.edge_map[str(('rgb', 'reshading'))].model.load_weights(
        './models/rgb2reshading_consistency.pth', backward_compatible=True)
    graph.edge_map[str(('rgb', 'depth_zbuffer'))].model.load_weights(
        './models/rgb2depth_consistency.pth', backward_compatible=True)
    graph.edge_map[str(('rgb', 'normal'))].model.load_weights(
        './models/rgb2normal_consistency.pth', backward_compatible=True)

    energy_losses, mse_losses = [], []
    percep_losses = defaultdict(list)
    energy_mean_by_blur, energy_std_by_blur = [], []
    error_mean_by_blur, error_std_by_blur = [], []
    energy_losses, error_losses = [], []
    energy_losses_all, energy_losses_headings = [], []
    fnames = []
    train_subset.reload()

    # Compute energies
    for epochs in tqdm(range(subset_size // batch_size)):
        with torch.no_grad():
            losses = energy_loss(graph, realities=[train_subset],
                                 reduce=False, use_l1=use_l1)
            if len(energy_losses_headings) == 0:
                # Fix the percep-loss column order on the first batch.
                energy_losses_headings = sorted(
                    [loss_name for loss_name in losses if 'percep' in loss_name])
            all_perceps = [losses[loss_name].cpu().numpy()
                           for loss_name in energy_losses_headings]
            direct_losses = [losses[loss_name].cpu().numpy()
                             for loss_name in losses if 'direct' in loss_name]
            if len(all_perceps) > 0:
                energy_losses_all += [all_perceps]
                all_perceps = np.stack(all_perceps)
                energy_losses += list(all_perceps.mean(0))
            if len(direct_losses) > 0:
                direct_losses = np.stack(direct_losses)
                error_losses += list(direct_losses.mean(0))
            if False:
                fnames += train_subset.task_data[tasks.filename]
            train_subset.step()

    # log losses
    if len(energy_losses) > 0:
        energy_losses = np.array(energy_losses)
        print(f'energy = {energy_losses.mean()}')
        energy_mean_by_blur += [energy_losses.mean()]
        energy_std_by_blur += [np.std(energy_losses)]
    if len(error_losses) > 0:
        error_losses = np.array(error_losses)
        print(f'error = {error_losses.mean()}')
        error_mean_by_blur += [error_losses.mean()]
        error_std_by_blur += [np.std(error_losses)]

    # save to csv
    save_error_losses = error_losses if len(error_losses) > 0 else [0] * subset_size
    save_energy_losses = energy_losses if len(energy_losses) > 0 else [0] * subset_size
    z_score = lambda x: (x - x.mean()) / x.std()

    def get_standardized_energy(df, use_std=False, compare_to_in_domain=False):
        # Mean-center each percep column, scale by its mean absolute
        # deviation, and average the selected columns into one energy.
        percepts = [c for c in df.columns if 'percep' in c]
        stdize = lambda x: (x - x.mean()).abs().mean()
        means = {k: df[k].mean() for k in percepts}
        stds = {k: stdize(df[k]) for k in percepts}
        stdized = {k: (df[k] - means[k]) / stds[k] for k in percepts}
        energies = np.stack(
            [v for k, v in stdized.items() if k[-1] == '_' or '__' in k]).mean(0)
        return energies

    os.makedirs(save_dir, exist_ok=True)
    # FIX: the original compared with `is` / `is not` against the string
    # literal 'CUSTOM', which tests object identity and only worked by way of
    # CPython string interning (SyntaxWarning on 3.8+). Use equality.
    if data_dir == 'CUSTOM':
        eng_curr = np.array(energy_losses).mean()
        df = pd.read_csv(os.path.join(save_dir, 'data.csv'))
    else:
        percep_losses = {
            k: v for k, v in zip(energy_losses_headings,
                                 np.concatenate(energy_losses_all, axis=-1))}
        df = pd.DataFrame(both(
            {'energy': save_energy_losses, 'error': save_error_losses},
            percep_losses))

    # compute correlation
    df['normalized_energy'] = get_standardized_energy(df, use_std=False)
    df['normalized_error'] = z_score(df['error'])
    print(scipy.stats.spearmanr(z_score(df['error']), df['normalized_energy']))
    print("Pearson r:", scipy.stats.pearsonr(df['error'], df['normalized_energy']))
    if data_dir != 'CUSTOM':
        df.to_csv(f"{save_dir}/data.csv", mode='w', header=True)

    # plot correlation
    plt.figure(figsize=(4, 4))
    g = sns.regplot(df['normalized_error'], df['normalized_energy'], robust=False)
    if data_dir == 'CUSTOM':
        ax1 = g.axes
        ax1.axhline(eng_curr, ls='--', color='red')
        ax1.text(0.5, 25, "Query Image Energy Line")
    plt.xlabel('Error (z-score)')
    plt.ylabel('Energy (z-score)')
    plt.title('')
    plt.savefig(f'{save_dir}/energy.pdf')
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    subset_size=None,
    max_epochs=800,
    **kwargs,
):
    """Short-cycle training run that multiplies the learning rate by 1.2x
    after every epoch, starting from lr=1e-6 (presumably an LR sweep /
    range test — confirm with the experiment owner).
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    print(kwargs)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size
    )
    train_step, val_step = 24, 12  # fixed short epochs for the sweep
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood([tasks.rgb])
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam, lr=1e-6, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)

    best_ood_val_loss = float('inf')
    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
        if visualize:
            return

        # Validation pass first, then training.
        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum(val_loss[loss_name] for loss_name in val_loss)
                val.step()
                # logger.update("loss", val_loss)

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(train_loss[loss_name] for loss_name in train_loss)
            graph.step(train_loss)
            train.step()
            # logger.update("loss", train_loss)

        energy_loss.logger_update(logger)

        # Scale the learning rate up by 1.2x for the next epoch.
        for param_group in graph.optimizer.param_groups:
            param_group['lr'] *= 1.2
            print("LR: ", param_group['lr'])
        logger.step()
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=64,
    **kwargs,
):
    """Fine-tune a task graph with three interleaved objectives per step:
    the main training loss, a small train-subset loss, and an OOD
    consistency loss on 512px inputs.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    ood_batch_size = ood_batch_size or batch_size
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal, tasks.principal_curvature],
        return_dataset=True,
        batch_size=batch_size,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=256,
    )
    ood_consistency_dataset, _, _, _ = load_train_val(
        [tasks.rgb],
        return_dataset=True,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=512,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        [tasks.rgb, tasks.normal],
        return_dataset=True,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=512,
        subset_size=subset_size,
    )
    train_step, val_step = train_step // 4, val_step // 4
    if fast:
        train_step, val_step = 20, 20
    test_set = load_test([tasks.rgb, tasks.normal, tasks.principal_curvature])
    ood_images = load_ood()
    ood_images_large = load_ood(resize=512, sample=8)
    ood_consistency_test = load_test([tasks.rgb], resize=512)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    ood_consistency = RealityTask("ood_consistency", ood_consistency_dataset,
                                  batch_size=ood_batch_size, shuffle=True)
    train_subset = RealityTask("train_subset", train_subset_dataset,
                               tasks=[tasks.rgb, tasks.normal],
                               batch_size=ood_batch_size, shuffle=False)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, [tasks.rgb, tasks.normal, tasks.principal_curvature])
    ood_test = RealityTask.from_static("ood_test", (ood_images,), [tasks.rgb])
    ood_test_large = RealityTask.from_static("ood_test_large", (ood_images_large,), [tasks.rgb])
    ood_consistency_test = RealityTask.from_static(
        "ood_consistency_test", ood_consistency_test, [tasks.rgb])
    realities = [
        train, val, train_subset, ood_consistency,
        test, ood_test, ood_test_large, ood_consistency_test,
    ]
    energy_loss.load_realities(realities)

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    # graph.load_weights(f"{MODELS_DIR}/conservative/conservative.pth")
    graph.load_weights(
        f"{SHARED_DIR}/results_2FF_train_subset_512_true_baseline_3/graph_baseline.pth"
    )
    print(graph)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth"),
        feature="epoch",
        freq=1,
    )
    graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth")
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, 800):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, prefix="start" if epochs == 0 else "")
        if visualize:
            return

        graph.train()
        for _ in range(0, train_step):
            train.step()
            train_loss = energy_loss(graph, reality=train)
            graph.step(train_loss)
            logger.update("loss", train_loss)

            train_subset.step()
            train_subset_loss = energy_loss(graph, reality=train_subset)
            graph.step(train_subset_loss)

            ood_consistency.step()
            ood_consistency_loss = energy_loss(graph, reality=ood_consistency)
            if ood_consistency_loss is not None:
                graph.step(ood_consistency_loss)

        graph.eval()
        for _ in range(0, val_step):
            val.step()
            with torch.no_grad():
                val_loss = energy_loss(graph, reality=val)
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{BASE_DIR}/shared/results_LBP_multipercep_lat_winrate_8/graph.pth",
    cont_gan=None,
    pre_gan=None,
    max_epochs=800,
    use_baseline=False,
    **kwargs,
):
    """Measure perceptual (energy) and MSE losses on a training subset as a
    function of Gaussian blur radius, and plot mean/std of both per radius.

    NOTE(review): `subset_size` must be passed explicitly — the default None
    would fail at `subset_size // batch_size` below.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        energy_loss.get_tasks("val"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    if not fast:
        train_step, val_step = train_step // (16 * 4), val_step // (16)
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    train_subset = RealityTask("train_subset", train_subset_dataset,
                               batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    # ood = RealityTask.from_static("ood", ood_set, [energy_loss.get_tasks("ood")])

    # GRAPH
    realities = [train, val, test] + [train_subset]  # [ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      finetuned=finetuned,
                      freeze_list=energy_loss.freeze_list)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    # logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)

    best_ood_val_loss = float('inf')
    energy_losses = []
    mse_losses = []
    pearsonr_vals = []
    percep_losses = defaultdict(list)
    pearson_percep = defaultdict(list)

    # (A large commented-out per-epoch mse/energy pearson-correlation analysis
    # was removed here; see version control history if it is needed again.)

    energy_mean_by_blur = []
    energy_std_by_blur = []
    mse_mean_by_blur = []
    mse_std_by_blur = []
    for blur_size in np.arange(0, 10, 0.5):
        # Blur the RGB inputs; radius 0 means no blur.
        tasks.rgb.blur_radius = blur_size if blur_size > 0 else None
        train_subset.step()
        # energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
        energy_losses = []
        mse_losses = []
        for epochs in range(subset_size // batch_size):
            with torch.no_grad():
                # FIX: the loss was computed twice per step and the first
                # result (`flosses`) was discarded — a redundant full forward
                # pass. Compute it once.
                losses = energy_loss(graph, realities=[train_subset], reduce=False)
            all_perceps = np.stack([
                losses[loss_name].data.cpu().numpy()
                for loss_name in losses if 'percep' in loss_name
            ])
            energy_losses += list(all_perceps.mean(0))
            mse_losses += list(losses['mse'].data.cpu().numpy())
            train_subset.step()
        mse_losses = np.array(mse_losses)
        energy_losses = np.array(energy_losses)
        logger.text(
            f'blur_radius = {blur_size}, mse = {mse_losses.mean()}, energy = {energy_losses.mean()}'
        )
        logger.scatter(mse_losses, energy_losses,
                       f'mse_energy, blur = {blur_size}',
                       opts={'xlabel': 'mse', 'ylabel': 'energy'})
        energy_mean_by_blur += [energy_losses.mean()]
        energy_std_by_blur += [np.std(energy_losses)]
        mse_mean_by_blur += [mse_losses.mean()]
        mse_std_by_blur += [np.std(mse_losses)]

    logger.plot(energy_mean_by_blur, f'energy_mean_by_blur')
    logger.plot(energy_std_by_blur, f'energy_std_by_blur')
    logger.plot(mse_mean_by_blur, f'mse_mean_by_blur')
    logger.plot(mse_std_by_blur, f'mse_std_by_blur')
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    max_epochs=800,
    **kwargs,
):
    """Train the consistency graph by alternating between two "triangle" losses.

    Each train step activates whichever triangle currently has the higher
    "energy" (its consistency error divided by its direct-supervision error),
    steps the graph on that loss alone, then re-estimates that triangle's
    energy from the freshly logged metrics. Validation each epoch runs only
    the triangles that were activated during training.

    Args:
        loss_config: name of the energy-loss configuration to load.
        mode: energy-loss mode forwarded to get_energy_loss.
        visualize: if True, plot the initial paths and return immediately.
        fast: debug mode — small batch size and fast data loading.
        batch_size: overrides the default batch size (4 if fast else 64).
        max_epochs: number of training epochs.
        **kwargs: forwarded to get_energy_loss.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    # Quarter the per-epoch step counts suggested by the loader.
    train_step, val_step = train_step // 4, val_step // 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print("Train step: ", train_step, "Val step: ", val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=True,
        freeze_list=[functional_transfers.a, functional_transfers.RC],
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    energy_loss.logger_hooks(logger)

    # Triangles activated this epoch, and each triangle's current energy
    # estimate (both start at +inf so each is tried at least once).
    activated_triangles = set()
    triangle_energy = {
        "triangle1_mse": float('inf'),
        "triangle2_mse": float('inf')
    }

    logger.add_hook(partial(jointplot, loss_type=f"energy"),
                    feature=f"val_energy",
                    freq=1)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities,
                               prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.train()
        for _ in range(0, train_step):
            # Greedily pick the triangle with the highest current energy.
            # loss_type = random.choice(["triangle1_mse", "triangle2_mse"])
            loss_type = max(triangle_energy, key=triangle_energy.get)
            activated_triangles.add(loss_type)
            train_loss = energy_loss(graph, realities=[train], loss_types=[loss_type])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])

            graph.step(train_loss)
            train.step()

            # Re-estimate the stepped triangle's energy as the ratio of its
            # consistency metric to its direct-supervision metric (latest
            # logged values).
            if loss_type == "triangle1_mse":
                consistency_tr1 = energy_loss.metrics["train"][
                    "triangle1_mse : F(RC(x)) -> n(x)"][-1]
                error_tr1 = energy_loss.metrics["train"][
                    "triangle1_mse : n(x) -> y^"][-1]
                triangle_energy["triangle1_mse"] = float(consistency_tr1 / error_tr1)
            elif loss_type == "triangle2_mse":
                consistency_tr2 = energy_loss.metrics["train"][
                    "triangle2_mse : S(a(x)) -> n(x)"][-1]
                error_tr2 = energy_loss.metrics["train"][
                    "triangle2_mse : n(x) -> y^"][-1]
                triangle_energy["triangle2_mse"] = float(consistency_tr2 / error_tr2)

            print("Triangle energy: ", triangle_energy)
            logger.update("loss", train_loss)

            # Only log the combined energy once both estimates are finite.
            energy = sum(triangle_energy.values())
            if (energy < float('inf')):
                logger.update("train_energy", energy)
                logger.update("val_energy", energy)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                # Validate only on the triangles trained this epoch.
                val_loss = energy_loss(graph, realities=[val],
                                       loss_types=list(activated_triangles))
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        activated_triangles = set()
        energy_loss.logger_update(logger)
        logger.step()
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    max_epochs=800,
    **kwargs,
):
    """Fine-tune a pre-trained consistency graph with per-epoch loss selection.

    Each epoch: plot paths, run validation, let the energy loss (re)select
    its active losses from the validation reality, then train on the current
    selection for a fixed 4 steps.

    NOTE(review): `pretrained` and `ood_batch_size` are accepted but never
    used in this body, and `best_ood_val_loss` is initialized but never
    read/updated below — confirm whether these are vestigial.

    Args:
        cont: checkpoint to warm-start the graph from (skipped on RAID).
        subset_size: optional cap on the training subset size.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    # Hard-coded short epochs: 4 train and 4 val steps each (overrides the
    # loader's suggestion).
    train_step, val_step = 4, 4
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      freeze_list=energy_loss.freeze_list,
                      finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    # Warm start from the "conservative" checkpoint unless running on RAID.
    if not USE_RAID: graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix=f"epoch_{epochs}")
        if visualize: return

        # Validation first: the loss selection below is driven by val metrics.
        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.select_losses(val)
        # On epoch 0 the metrics are cleared instead of logged, so the
        # untrained first pass does not pollute the curves.
        if epochs != 0:
            energy_loss.logger_update(logger)
        else:
            energy_loss.metrics = {}
        logger.step()

        logger.text(f"Chosen losses: {energy_loss.chosen_losses}")
        logger.text(f"Percep winrate: {energy_loss.percep_winrate}")

        graph.train()
        for _ in range(0, train_step):
            train_loss2 = energy_loss(graph, realities=[train])
            train_loss = sum(train_loss2.values())
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)
def main(
    loss_config="multiperceptual",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    resume=False,
    learning_rate=3e-5,
    subset_size=None,
    max_epochs=2000,
    dataaug=False,
    **kwargs,
):
    """Train a multi-perceptual consistency graph, logging to wandb.

    Runs an optional baseline pass (when not resuming) to record epoch-0
    metrics and path images, then alternates validation and training each
    epoch, logging scalars, grad-ratio coefficients, and periodic path
    images to wandb; checkpoints the graph and optimizer state.

    Fix: ``learning_rate`` was accepted (and reported to wandb.config) but
    the optimizer was compiled with a hard-coded ``lr=3e-5``, silently
    ignoring the caller's value. It is now passed through; the default is
    unchanged, so existing callers are unaffected.

    Args:
        resume: reload graph + optimizer state from results_test_1 and skip
            the baseline pass.
        learning_rate: Adam learning rate (now actually used).
        subset_size: optional cap on the training subset size.
    """
    # CONFIG
    wandb.config.update({
        "loss_config": loss_config,
        "batch_size": batch_size,
        "data_aug": dataaug,
        "lr": learning_rate
    })
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    # OOD reality carries only RGB inputs (no labels).
    ood = RealityTask.from_static("ood", ood_set, [
        tasks.rgb,
    ])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    # FIX: honor the learning_rate argument (was hard-coded to 3e-5).
    graph.compile(torch.optim.Adam, lr=learning_rate, weight_decay=2e-6, amsgrad=True)
    if resume:
        graph.load_weights('/workspace/shared/results_test_1/graph.pth')
        graph.optimizer.load_state_dict(
            torch.load('/workspace/shared/results_test_1/opt.pth'))

    # LOGGING
    logger = VisdomLogger("train", env=JOB)  # fake visdom logger
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    ######## baseline computation
    # Record untrained (epoch-0) metrics and path images before training.
    if not resume:
        graph.eval()
        with torch.no_grad():
            for _ in range(0, val_step):
                val_loss, _, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                val.step()
                logger.update("loss", val_loss)
            for _ in range(0, train_step):
                train_loss, _, _ = energy_loss(graph, realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])
                train.step()
                logger.update("loss", train_loss)
        energy_loss.logger_update(logger)
        data = logger.step()
        del data['loss']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=0)

        path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
        for reality_paths, reality_images in path_values.items():
            wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0)
    ###########

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        graph.train()
        for _ in range(0, train_step):
            # compute_grad_ratio also returns per-loss coefficients and
            # average gradients for logging.
            train_loss, coeffs, avg_grads = energy_loss(
                graph, realities=[train], compute_grad_ratio=True)
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        energy_loss.logger_update(logger)

        data = logger.step()
        del data['loss']
        del data['epoch']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=epochs)
        wandb.log(coeffs, step=epochs)
        wandb.log(avg_grads, step=epochs)

        if epochs % 5 == 0:
            graph.save(f"{RESULTS_DIR}/graph.pth")
            torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")

        if epochs % 10 == 0:
            path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
            for reality_paths, reality_images in path_values.items():
                wandb.log({reality_paths: [wandb.Image(reality_images)]},
                          step=epochs + 1)

    graph.save(f"{RESULTS_DIR}/graph.pth")
    torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    cont_gan=None,
    pre_gan=None,
    max_epochs=800,
    use_baseline=False,
    use_patches=False,
    patch_frac=None,
    patch_size=64,
    patch_sigma=0,
    **kwargs,
):
    """Adversarially train the consistency graph against a discriminator.

    The graph trains on the energy loss only after ``pre_gan`` warm-up
    epochs; every train step also runs a short discriminator warm-up loop
    that contrasts graph predictions on a clean subset against predictions
    on blurred inputs, with the gradient into the graph scaled down via a
    backward hook.

    NOTE(review): `pretrained`, `ood_batch_size`, and `visualize`'s early
    return are the only uses of several arguments — confirm vestigials.
    NOTE(review): `nn.BCEWithLogitsLoss(size_average=True)` uses a
    deprecated argument; the modern equivalent is `reduction='mean'`.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    # A small labelled subset used as the discriminator's "real" side.
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    train_step, val_step = train_step // 16, val_step // 16
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    train_subset = RealityTask("train_subset",
                               train_subset_dataset,
                               batch_size=batch_size,
                               shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, train_subset, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    # Warm start from the conservative checkpoint unless on RAID or running
    # the baseline ablation.
    if not USE_RAID and not use_baseline:
        graph.load_weights(cont)

    pre_gan = pre_gan or 1
    discriminator = Discriminator(energy_loss.losses['gan'],
                                  frac=patch_frac,
                                  size=(patch_size if use_patches else 224),
                                  sigma=patch_sigma,
                                  use_patches=use_patches)
    if cont_gan is not None:
        discriminator.load_weights(cont_gan)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(),
                    feature="loss",
                    freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    logger.add_hook(
        lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"),
        feature="epoch",
        freq=1)
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')
    logger.add_hook(partial(jointplot, loss_type=f"gan_subset"),
                    feature=f"val_gan_subset",
                    freq=1)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities,
                               prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.train()
        discriminator.train()

        for _ in range(0, train_step):
            # Graph update only after the discriminator warm-up period.
            if epochs > pre_gan:
                train_loss = energy_loss(graph,
                                         discriminator=discriminator,
                                         realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            # Earlier experiment variants, kept for reference:
            # train_loss1 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse'])
            # train_loss1 = sum([train_loss1[loss_name] for loss_name in train_loss1])
            # train.step()
            # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['gan'])
            # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2])
            # train.step()
            # graph.step(train_loss1 + train_loss2)
            # logger.update("loss", train_loss1 + train_loss2)

            # train_loss1 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse_id'])
            # train_loss1 = sum([train_loss1[loss_name] for loss_name in train_loss1])
            # graph.step(train_loss1)
            # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse_ood'])
            # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2])
            # graph.step(train_loss2)
            # train_loss3 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['gan'])
            # train_loss3 = sum([train_loss3[loss_name] for loss_name in train_loss3])
            # graph.step(train_loss3)
            # logger.update("loss", train_loss1 + train_loss2 + train_loss3)
            # train.step()

            # graph fooling loss
            # n(~x), and y^ (128 subset)
            # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train])
            # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2])
            # train_loss = train_loss1 + train_loss2

            # Discriminator warm-up: contrast subset ground-truth-side
            # predictions against blurred-input predictions.
            warmup = 5  # if epochs < pre_gan else 1
            for i in range(warmup):
                # y_hat = graph.sample_path([tasks.normal(size=512)], reality=train_subset)
                # n_x = graph.sample_path([tasks.rgb(size=512), tasks.normal(size=512)], reality=train)
                y_hat = graph.sample_path([tasks.normal], reality=train_subset)
                n_x = graph.sample_path(
                    [tasks.rgb(blur_radius=6), tasks.normal(blur_radius=6)],
                    reality=train)

                def coeff_hook(coeff):
                    # Backward hook factory: scales the gradient flowing back
                    # into the graph by `coeff`.
                    def fun1(grad):
                        return coeff * grad.clone()

                    return fun1

                # "Real" side is detached: only the discriminator learns from it.
                logit_path1 = discriminator(y_hat.detach())
                coeff = 0.1
                # `* 1.0` clones the node so the hook attaches to a fresh tensor.
                path_value2 = n_x * 1.0
                path_value2.register_hook(coeff_hook(coeff))
                logit_path2 = discriminator(path_value2)
                binary_label = torch.Tensor(
                    [1] * logit_path1.size(0) +
                    [0] * logit_path2.size(0)).float().cuda()
                gan_loss = nn.BCEWithLogitsLoss(size_average=True)(torch.cat(
                    (logit_path1, logit_path2), dim=0).view(-1), binary_label)
                discriminator.discriminator.step(gan_loss)
                logger.update("train_gan_subset", gan_loss)
                logger.update("val_gan_subset", gan_loss)
                # print ("Gan loss: ", (-gan_loss).data.cpu().numpy())
                train.step()
                train_subset.step()

        graph.eval()
        discriminator.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph,
                                       discriminator=discriminator,
                                       realities=[val, train_subset])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        if epochs > pre_gan:
            energy_loss.logger_update(logger)
            logger.step()
        # Keep a "best" snapshot of the paths by subset OOD validation loss.
        if logger.data["train_subset_val_ood : y^ -> n(~x)"][
                -1] < best_ood_val_loss:
            best_ood_val_loss = logger.data[
                "train_subset_val_ood : y^ -> n(~x)"][-1]
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
def main(
    fast=False,
    subset_size=None,
    early_stopping=float('inf'),
    mode='standard',
    max_epochs=800,
    **kwargs,
):
    """Two-stage training with early stopping.

    Stage one fits f (normal -> principal curvature) with direct MSE
    supervision and saves the best weights. Stage two trains the perceptual
    graph with f frozen, loading the stage-one f weights.

    NOTE(review): the ``early_stopping`` argument is immediately overwritten
    (8 for stage one, 50 for stage two), so the caller's value is ignored —
    confirm this is intentional.
    """
    early_stopping = 8

    # Stage-one config: supervise f(y) — curvature predicted from
    # ground-truth normals — against the curvature label z^.
    loss_config_percepnet = {
        "paths": {
            "y": [tasks.normal],
            "z^": [tasks.principal_curvature],
            "f(y)": [tasks.normal, tasks.principal_curvature],
        },
        "losses": {
            "mse": {
                ("train", "val"): [
                    ("f(y)", "z^"),
                ],
            },
        },
        "plots": {
            "ID": dict(size=256, realities=("test", "ood"), paths=[
                "y",
                "z^",
                "f(y)",
            ]),
        },
    }

    # CONFIG
    batch_size = 64
    energy_loss = EnergyLoss(**loss_config_percepnet)
    task_list = [tasks.rgb, tasks.normal, tasks.principal_curvature]

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list,
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(task_list)
    ood_set = load_ood(task_list)
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    ood = RealityTask.from_static("ood", ood_set, task_list)

    # GRAPH (stage one: n frozen, f trainable)
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities,
        pretrained=False,
        freeze_list=[functional_transfers.n],
    )
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)
    best_val_loss, stop_idx = float('inf'), 0

    # TRAINING (stage one)
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities,
                               prefix="start" if epochs == 0 else "")

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        stop_idx += 1
        # Save on improvement; stop after `early_stopping` epochs without one.
        if logger.data["val_mse : f(y) -> z^"][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = logger.data["val_mse : f(y) -> z^"][
                -1], 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(weights_dir=f"{RESULTS_DIR}")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            break

    early_stopping = 50

    # CONFIG (stage two: perceptual training with f frozen)
    energy_loss = get_energy_loss(config="perceptual", mode=mode, **kwargs)

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities,
        pretrained=False,
        freeze_list=[functional_transfers.f],
    )
    # Load the stage-one weights for f before training the rest.
    graph.edge(
        tasks.normal,
        tasks.principal_curvature).model.load_weights(f"{RESULTS_DIR}/f.pth")
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING (reuses the same logger instance)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)
    best_val_loss, stop_idx = float('inf'), 0

    # TRAINING (stage two)
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities,
                               prefix="start" if epochs == 0 else "")

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        stop_idx += 1
        if logger.data["val_mse : n(x) -> y^"][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = logger.data["val_mse : n(x) -> y^"][
                -1], 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(f"{RESULTS_DIR}/graph.pth")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            break
def main(
    loss_config="baseline_normal",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    learning_rate=5e-4,
    subset_size=None,
    max_epochs=5000,
    dataaug=False,
    **kwargs,
):
    """Train the baseline normal-prediction graph, logging to wandb.

    Logs qualitative path images periodically, checkpoints graph + optimizer
    state every 10 epochs, and lowers the learning rate once at epoch 1500.
    """
    # CONFIG
    wandb.config.update({
        "loss_config": loss_config,
        "batch_size": batch_size,
        "data_aug": dataaug,
        "lr": learning_rate
    })
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    # OOD reality carries only RGB inputs (no labels).
    ood = RealityTask.from_static("ood", ood_set, [
        tasks.rgb,
    ])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam,
                  lr=learning_rate,
                  weight_decay=2e-6,
                  amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)  # fake visdom logger
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        # Qualitative path images: every 100 epochs, plus every 10 during the
        # first 30 epochs.
        if (epochs % 100 == 0) or (epochs % 10 == 0 and epochs < 30):
            path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
            for reality_paths, reality_images in path_values.items():
                wandb.log({reality_paths: [wandb.Image(reality_images)]},
                          step=epochs)

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph,
                                     realities=[train],
                                     compute_grad_ratio=True)
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)

        data = logger.step()
        del data['loss']
        del data['epoch']
        data = {k: v[0] for k, v in data.items()}
        wandb.log(data, step=epochs)

        # save model and opt state every 10 epochs
        if epochs % 10 == 0:
            graph.save(f"{RESULTS_DIR}/graph.pth")
            torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
        # lower lr after 1500 epochs
        if epochs == 1500:
            graph.optimizer.param_groups[0]['lr'] = 3e-5

    graph.save(f"{RESULTS_DIR}/graph.pth")
    torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
def main(
    loss_config="multiperceptual",
    mode="winrate",
    visualize=False,
    fast=False,
    batch_size=None,
    subset_size=None,
    max_epochs=800,
    dataaug=False,
    **kwargs,
):
    """Train the multiperceptual graph in winrate mode.

    Runs a no-grad baseline pass (4x the usual step counts) to seed the
    metrics, then alternates training (with grad-ratio balancing) and
    validation each epoch.

    NOTE(review): the ``if fast:`` override below appears twice (before and
    after constructing the realities) — the second occurrence re-assigns
    values that are already set; confirm whether one copy can be removed.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
        dataaug=dataaug,
    )
    if fast:
        train_dataset = val_dataset
        train_step, val_step = 2,2
    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    if fast:
        train_dataset = val_dataset
        train_step, val_step = 2,2
        realities = [train, val]
    else:
        test_set = load_test(energy_loss.get_tasks("test"),
                             buildings=['almena', 'albertville'])
        test = RealityTask.from_static("test", test_set,
                                       energy_loss.get_tasks("test"))
        realities = [train, val, test]

    # If you wanted to just do some qualitative predictions on inputs w/o labels, you could do:
    # ood_set = load_ood(energy_loss.get_tasks("ood"))
    # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])
    # realities.append(ood)

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      pretrained=True,
                      finetuned=False,
                      freeze_list=energy_loss.freeze_list,
                      initialize_from_transfer=False,
                      )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    os.makedirs(RESULTS_DIR, exist_ok=True)
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch",
                    freq=1)
    energy_loss.logger_hooks(logger)
    energy_loss.plot_paths(graph, logger, realities, prefix="start")

    # BASELINE: seed the metrics with a no-grad pass before any training.
    graph.eval()
    with torch.no_grad():
        for _ in range(0, val_step*4):
            val_loss, _ = energy_loss(graph, realities=[val])
            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)
        for _ in range(0, train_step*4):
            train_loss, _ = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="")
        if visualize: return

        graph.train()
        for _ in range(0, train_step):
            # compute_grad_ratio also returns the MSE balancing coefficient.
            train_loss, mse_coeff = energy_loss(graph,
                                                realities=[train],
                                                compute_grad_ratio=True)
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
def run_viz_suite(name,
                  data_loader,
                  dest_task=tasks.depth_zbuffer,
                  graph_file=None,
                  model_file=None,
                  old=False,
                  multitask=False,
                  percep_mode=None,
                  downsample=6,
                  out_channels=3,
                  final_task=tasks.normal,
                  oldpercep=False):
    """Run batched visualization predictions for rgb -> dest_task.

    Selects a model by priority (consistency-graph edge > legacy UNetOld >
    multitask UNet > plain UNet checkpoint > pretrained transfer), predicts
    on every batch from ``data_loader``, and — when ``percep_mode`` is set —
    additionally maps each prediction through a dest_task -> final_task
    perceptual model, returning the perceptual outputs instead.

    NOTE(review): the ``name`` argument is never used in this body — confirm
    it is kept only for call-site symmetry with the eval suite.

    Returns:
        A single tensor of all (perceptual) predictions concatenated along
        the batch dimension, on CPU.
    """
    # The graph needs the perceptual task registered to expose its edge.
    extra_task = [final_task] if percep_mode else []

    # Model selection (see docstring for the priority order).
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task] + extra_task,
                          pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # downsample = 5 or 6
        print('loading main model')
        #model = DataParallelModel.load(UNetReshade(downsample=downsample, out_channels=out_channels).cuda(), model_file)
        model = DataParallelModel.load(
            UNet(downsample=downsample, out_channels=out_channels).cuda(),
            model_file)
        #model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = DummyModel(
            Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model())

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # DATA LOADING 1
    results = []
    final_preds = []

    if percep_mode:
        print('Loading percep model...')
        # Prefer the graph's own dest_task -> final_task edge when available.
        if graph_file is not None and not oldpercep:
            percep_model = graph.edge(dest_task, final_task).load_model()
            percep_model.compile(torch.optim.Adam,
                                 lr=3e-4,
                                 weight_decay=2e-6,
                                 amsgrad=True)
        else:
            percep_model = Transfer(src_task=dest_task,
                                    dest_task=final_task).load_model()
        percep_model.eval()

    print("Converting...")
    for data, in data_loader:
        # Keep only the last 3 channels and clamp into image range.
        preds = model.predict_on_batch(data)[:, -3:].clamp(min=0, max=1)
        results.append(preds.detach().cpu())
        if percep_mode:
            try:
                final_preds += [
                    percep_model.forward(preds[:, -3:]).detach().cpu()
                ]
            except RuntimeError:
                # Presumably a channel mismatch for single-channel
                # predictions: replicate to 3 channels and retry — TODO
                # confirm this is the only RuntimeError expected here.
                preds = torch.cat([preds] * 3, dim=1)
                final_preds += [
                    percep_model.forward(preds[:, -3:]).detach().cpu()
                ]
        #break

    # In percep mode only the perceptual outputs are returned.
    if percep_mode:
        results = torch.cat(final_preds, dim=0)
    else:
        results = torch.cat(results, dim=0)
    return results
def main(
    loss_config="baseline",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    **kwargs,
):
    """Evaluate a saved rgb->normal checkpoint on the validation split.

    Builds the task graph, swaps the rgb->normal edge's weights for the
    SAMPLEFF consistency checkpoint, runs validation only (no training),
    and prints the resulting "val_mse : n(x) -> y^" history.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    loaded = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    train_dataset, val_dataset, train_step, val_step = loaded
    # Quadruple the per-epoch step counts suggested by the loader.
    train_step *= 4
    val_step *= 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print(train_step, val_step)

    reality_train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    reality_val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    reality_test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    reality_ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [reality_train, reality_val, reality_test, reality_ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        freeze_list=energy_loss.freeze_list,
    )
    # Replace the rgb->normal edge's model with the SAMPLEFF checkpoint.
    graph.edge(tasks.rgb, tasks.normal).model = None
    graph.edge(tasks.rgb, tasks.normal
               ).path = f"{SHARED_DIR}/results_SAMPLEFF_consistency1m_25/n.pth"
    graph.edge(tasks.rgb, tasks.normal).load_model()
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)

    # EVALUATION ONLY — no graph.step() anywhere.
    graph.eval()
    for _ in range(val_step):
        with torch.no_grad():
            losses = energy_loss(graph, realities=[reality_val])
            val_loss = sum(losses.values())
        reality_val.step()
        logger.update("loss", val_loss)

    energy_loss.logger_update(logger)
    logger.step()

    # print ("Train mse: ", logger.data["train_mse : n(x) -> y^"])
    print("Val mse: ", logger.data["val_mse : n(x) -> y^"])