def main(
    loss_config="conservative_full", mode="standard", visualize=False,
    pretrained=True, finetuned=False, fast=False, batch_size=None,
    ood_batch_size=None, subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    max_epochs=800, **kwargs,
):
    """Train a task graph with per-epoch loss selection.

    Each epoch: evaluate on `val`, let the energy loss pick which losses to
    activate (`select_losses`), then train on the chosen losses.

    Args:
        loss_config: energy-loss configuration name.
        mode: energy-loss mode.
        visualize: if True, plot paths once and return without training.
        finetuned: passed through to TaskGraph.
        fast: debug mode — small batch size.
        batch_size: overrides the fast/full default (4 / 64).
        subset_size: optional training-subset size for load_train_val.
        cont: checkpoint to load when not running on RAID.
        max_epochs: number of training epochs.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast, subset_size=subset_size)
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    # Deliberate override of the computed step counts (short epochs).
    train_step, val_step = 4, 4
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      freeze_list=energy_loss.freeze_list,
                      finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not USE_RAID:
        graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix=f"epoch_{epochs}")
        if visualize: return

        # Validation pass (no gradients); feeds the loss-selection heuristic.
        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum(val_loss.values())
            val.step()
            logger.update("loss", val_loss)

        # Pick which losses to train on this epoch based on val statistics.
        energy_loss.select_losses(val)
        if epochs != 0:
            energy_loss.logger_update(logger)
        else:
            energy_loss.metrics = {}  # discard warm-up metrics from epoch 0
        logger.step()

        logger.text(f"Chosen losses: {energy_loss.chosen_losses}")
        logger.text(f"Percep winrate: {energy_loss.percep_winrate}")

        # Training pass on the selected losses.
        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum(train_loss.values())
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)
def main(
    loss_config="conservative_full", mode="standard", visualize=False,
    pretrained=True, finetuned=False, fast=False, batch_size=None,
    ood_batch_size=None, subset_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    cont_gan=None, pre_gan=None, max_epochs=800, use_baseline=False,
    use_patches=False, patch_frac=None, patch_size=64, patch_sigma=0,
    **kwargs,
):
    """Adversarial consistency training.

    Alternates task-graph updates with discriminator updates that separate
    normal predictions on a labeled train subset ("real") from predictions on
    blurred train inputs ("fake", with down-scaled gradients).

    Args:
        cont: task-graph checkpoint (skipped on RAID or with use_baseline).
        cont_gan: optional discriminator checkpoint.
        pre_gan: number of warm-up epochs before graph updates start (default 1).
        use_patches/patch_*: discriminator patch configuration.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size, fast=fast, subset_size=subset_size)
    train_step, val_step = train_step // 16, val_step // 16
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    train_subset = RealityTask("train_subset", train_subset_dataset,
                               batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, train_subset, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not USE_RAID and not use_baseline:
        graph.load_weights(cont)
    pre_gan = pre_gan or 1
    discriminator = Discriminator(
        energy_loss.losses['gan'],
        frac=patch_frac,
        size=(patch_size if use_patches else 224),
        sigma=patch_sigma,
        use_patches=use_patches)
    if cont_gan is not None:
        discriminator.load_weights(cont_gan)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch", freq=1)
    logger.add_hook(
        lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"),
        feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')
    logger.add_hook(partial(jointplot, loss_type="gan_subset"),
                    feature="val_gan_subset", freq=1)

    def coeff_hook(coeff):
        """Return a gradient hook that scales gradients by `coeff`.

        Hoisted out of the inner loop: it is loop-invariant, the original
        redefined it on every discriminator step.
        """
        def fun1(grad):
            return coeff * grad.clone()
        return fun1

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities,
                               prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.train()
        discriminator.train()
        for _ in range(0, train_step):
            # Task-graph update — only after the discriminator warm-up phase.
            if epochs > pre_gan:
                train_loss = energy_loss(graph, discriminator=discriminator,
                                         realities=[train])
                train_loss = sum(train_loss.values())
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            # Discriminator updates: "real" = normals predicted on the labeled
            # subset, "fake" = normals from blurred train inputs. Gradients
            # into the graph through the fake path are damped by `coeff`.
            warmup = 5
            for i in range(warmup):
                y_hat = graph.sample_path([tasks.normal], reality=train_subset)
                n_x = graph.sample_path(
                    [tasks.rgb(blur_radius=6), tasks.normal(blur_radius=6)],
                    reality=train)

                logit_path1 = discriminator(y_hat.detach())
                coeff = 0.1
                path_value2 = n_x * 1.0
                path_value2.register_hook(coeff_hook(coeff))
                logit_path2 = discriminator(path_value2)
                binary_label = torch.Tensor(
                    [1] * logit_path1.size(0) +
                    [0] * logit_path2.size(0)).float().cuda()
                # `reduction='mean'` replaces the deprecated size_average=True
                # (identical behavior).
                gan_loss = nn.BCEWithLogitsLoss(reduction='mean')(
                    torch.cat((logit_path1, logit_path2), dim=0).view(-1),
                    binary_label)
                discriminator.discriminator.step(gan_loss)
                logger.update("train_gan_subset", gan_loss)
                logger.update("val_gan_subset", gan_loss)

                # Advance both data streams for the next discriminator step.
                train.step()
                train_subset.step()

        graph.eval()
        discriminator.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, discriminator=discriminator,
                                       realities=[val, train_subset])
                val_loss = sum(val_loss.values())
            val.step()
            logger.update("loss", val_loss)

        if epochs > pre_gan:
            energy_loss.logger_update(logger)
        logger.step()

        # Checkpoint plots whenever the tracked OOD metric improves.
        if logger.data["train_subset_val_ood : y^ -> n(~x)"][-1] < best_ood_val_loss:
            best_ood_val_loss = logger.data[
                "train_subset_val_ood : y^ -> n(~x)"][-1]
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
def main(
    loss_config="conservative_full", mode="standard", visualize=False,
    pretrained=True, finetuned=False, fast=False, batch_size=None,
    ood_batch_size=None, subset_size=None,
    cont=f"{BASE_DIR}/shared/results_LBP_multipercep_lat_winrate_8/graph.pth",
    cont_gan=None, pre_gan=None, max_epochs=800, use_baseline=False,
    **kwargs,
):
    """Blur-sweep analysis: no training.

    Loads a pretrained graph, then sweeps a Gaussian blur radius (0..9.5) over
    the RGB input and records, per blur level, the per-sample mean perceptual
    "energy" and MSE on the train subset, scatter-plotting and summarizing
    their mean/std via Visdom.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        energy_loss.get_tasks("val"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size, fast=fast, subset_size=subset_size)
    if not fast:
        train_step, val_step = train_step // (16 * 4), val_step // (16)
    test_set = load_test(energy_loss.get_tasks("test"))
    # NOTE(review): loaded but unused — the ood reality below is disabled.
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    train_subset = RealityTask("train_subset", train_subset_dataset,
                               batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))

    # GRAPH
    realities = [train, val, test, train_subset]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned,
                      freeze_list=energy_loss.freeze_list)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    # BLUR SWEEP
    energy_mean_by_blur = []
    energy_std_by_blur = []
    mse_mean_by_blur = []
    mse_std_by_blur = []
    for blur_size in np.arange(0, 10, 0.5):
        # blur_radius=None disables blurring (radius 0 is not valid).
        tasks.rgb.blur_radius = blur_size if blur_size > 0 else None
        train_subset.step()
        energy_losses = []
        mse_losses = []
        # One pass over the subset, collecting unreduced per-sample losses.
        for epochs in range(subset_size // batch_size):
            with torch.no_grad():
                # (The original computed this identical call twice and
                # discarded the first result.)
                losses = energy_loss(graph, realities=[train_subset], reduce=False)
                all_perceps = np.stack([
                    losses[loss_name].data.cpu().numpy()
                    for loss_name in losses if 'percep' in loss_name
                ])
                # Energy = mean over all perceptual losses, per sample.
                energy_losses += list(all_perceps.mean(0))
                mse_losses += list(losses['mse'].data.cpu().numpy())
            train_subset.step()
        mse_losses = np.array(mse_losses)
        energy_losses = np.array(energy_losses)
        logger.text(
            f'blur_radius = {blur_size}, mse = {mse_losses.mean()}, energy = {energy_losses.mean()}'
        )
        logger.scatter(mse_losses, energy_losses, \
         f'mse_energy, blur = {blur_size}', opts={'xlabel':'mse','ylabel':'energy'})
        energy_mean_by_blur += [energy_losses.mean()]
        energy_std_by_blur += [np.std(energy_losses)]
        mse_mean_by_blur += [mse_losses.mean()]
        mse_std_by_blur += [np.std(mse_losses)]

    logger.plot(energy_mean_by_blur, 'energy_mean_by_blur')
    logger.plot(energy_std_by_blur, 'energy_std_by_blur')
    logger.plot(mse_mean_by_blur, 'mse_mean_by_blur')
    logger.plot(mse_std_by_blur, 'mse_std_by_blur')
def main(
    fast=False, batch_size=None,
    **kwargs,
):
    """Render a zooming "house tour" video through several pretrained graphs.

    Applies a piecewise zoom schedule to each frame, saves the distorted
    inputs and the normal predictions of four graphs (baseline / finetuned /
    conservative / ood_conservative), records a per-frame two-path
    consistency energy, pickles the results, and invokes the video script.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 32)
    energy_loss = get_energy_loss(config="consistency_two_path",
                                  mode="standard", **kwargs)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)

    # DATA LOADING — frames sorted by the number in "imageN.png".
    video_dataset = ImageDataset(
        files=sorted(
            glob.glob("mount/taskonomy_house_tour/original/image*.png"),
            key=lambda x: int(os.path.basename(x)[5:-4])),
        return_tuple=True,
        resize=720,
    )
    video = RealityTask("video", video_dataset, [
        tasks.rgb,
    ], batch_size=batch_size, shuffle=False)

    # GRAPHS — four variants compared frame-by-frame.
    graph_baseline = TaskGraph(tasks=energy_loss.tasks + [video], finetuned=False)
    graph_baseline.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    graph_finetuned = TaskGraph(tasks=energy_loss.tasks + [video], finetuned=True)
    graph_finetuned.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    graph_conservative = TaskGraph(tasks=energy_loss.tasks + [video], finetuned=True)
    graph_conservative.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph_conservative.load_weights(
        f"{MODELS_DIR}/conservative/conservative.pth")

    graph_ood_conservative = TaskGraph(tasks=energy_loss.tasks + [video], finetuned=True)
    graph_ood_conservative.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph_ood_conservative.load_weights(
        f"{SHARED_DIR}/results_2F_grounded_1percent_gt_twopath_512_256_crop_7/graph_grounded_1percent_gt_twopath.pth"
    )

    graphs = {
        "baseline": graph_baseline,
        "finetuned": graph_finetuned,
        "conservative": graph_conservative,
        "ood_conservative": graph_ood_conservative,
    }
    inv_transform = transforms.ToPILImage()
    data = {key: {"losses": [], "zooms": []} for key in graphs}

    size = 256
    for batch in range(0, 700):
        if batch * batch_size > len(video_dataset.files):
            break
        # Zoom schedule: 256→128 over the first 30% of frames, back to 256
        # by 50%, then up to 720 by the end.
        frac = (batch * batch_size * 1.0) / len(video_dataset.files)
        if frac < 0.3:
            size = int(256.0 - 128 * frac / 0.3)
        elif frac < 0.5:
            size = int(128.0 + 128 * (frac - 0.3) / 0.2)
        else:
            size = int(256.0 + (720 - 256) * (frac - 0.5) / 0.5)
        print(size)
        size = (size // 32) * 32  # snap down to a multiple of 32
        print(size)
        video.step()
        video.task_data[tasks.rgb] = resize(
            video.task_data[tasks.rgb].to(DEVICE), size).data
        print(video.task_data[tasks.rgb].shape)

        with torch.no_grad():
            # Save the zoom-distorted inputs.
            for i, img in enumerate(video.task_data[tasks.rgb]):
                inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                    f"mount/taskonomy_house_tour/distorted/image{batch*batch_size + i}.png"
                )
            for name, graph in graphs.items():
                # Direct path rgb→normal vs. indirect rgb→curvature→normal.
                normals = graph.sample_path([tasks.rgb, tasks.normal],
                                            reality=video)
                normals2 = graph.sample_path(
                    [tasks.rgb, tasks.principal_curvature, tasks.normal],
                    reality=video)
                for i, img in enumerate(normals):
                    # Per-frame consistency energy between the two paths.
                    energy, _ = tasks.normal.norm(normals[i:(i + 1)],
                                                  normals2[i:(i + 1)])
                    data[name]["losses"] += [energy.data.cpu().numpy().mean()]
                    data[name]["zooms"] += [size]
                    inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                        f"mount/taskonomy_house_tour/normals_{name}/image{batch*batch_size + i}.png"
                    )
                for i, img in enumerate(normals2):
                    inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                        f"mount/taskonomy_house_tour/path2_{name}/image{batch*batch_size + i}.png"
                    )

    # Close the pickle deterministically (the original leaked the handle).
    with open("mount/taskonomy_house_tour/data.pkl", 'wb') as f:
        pickle.dump(data, f)
    os.system("bash ~/scaling/scripts/create_vids.sh")
def main(
    loss_config="conservative_full", mode="standard", visualize=False,
    fast=False, batch_size=None, subset_size=None,
    max_epochs=800, **kwargs,
):
    """Training run with an exponentially increasing learning rate.

    The optimizer LR starts at 1e-6 and is multiplied by 1.2 after every
    epoch — presumably an LR range test; TODO confirm intent with the author.
    Per-step loss logging is disabled (commented out) in both loops.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    print (kwargs)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast, subset_size=subset_size
    )
    # Fixed step counts override whatever load_train_val computed.
    train_step, val_step = 24, 12
    test_set = load_test(energy_loss.get_tasks("test"))
    # OOD reality only carries RGB inputs (qualitative predictions).
    ood_set = load_ood([tasks.rgb,])
    print (train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
        pretrained=True, finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    # Very small starting LR — it grows 1.2x per epoch below.
    graph.compile(torch.optim.Adam, lr=1e-6, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)

    best_ood_val_loss = float('inf')  # NOTE(review): assigned but never read

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities,
                               prefix="start" if epochs == 0 else "")
        if visualize: return

        # Validation pass (no gradients).
        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            # logger.update("loss", val_loss)

        # Training pass.
        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            # logger.update("loss", train_loss)

        energy_loss.logger_update(logger)

        # Exponential LR ramp: multiply every param group's LR by 1.2.
        for param_group in graph.optimizer.param_groups:
            param_group['lr'] *= 1.2
            print ("LR: ", param_group['lr'])
        logger.step()
def main(
    loss_config="geonet", mode="geonet", visualize=False,
    fast=False, batch_size=None, subset_size=None,
    early_stopping=float('inf'),
    max_epochs=800, **kwargs,
):
    """Train a GeoNet task graph with early stopping on combined val MSE.

    Per-step failures raising NotImplementedError are tolerated (skipped);
    the best checkpoint (normal + depth val MSE) is saved to RESULTS_DIR.

    Args:
        early_stopping: stop after this many epochs without val improvement.
    """
    print(kwargs)
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast, subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print (train_step, val_step)

    # GRAPH
    print(energy_loss.tasks)
    print('train tasks', energy_loss.get_tasks("train"))
    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    print('val tasks', energy_loss.get_tasks("val"))
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    print('test tasks', energy_loss.get_tasks("test"))
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    print('ood tasks', energy_loss.get_tasks("ood"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))
    print('done')

    realities = [train, val, test, ood]
    graph = GeoNetTaskGraph(tasks=energy_loss.tasks, realities=realities,
                            pretrained=True)
    graph.compile(torch.optim.Adam, grad_clip=5.0,
                  lr=4e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    best_val_loss, stop_idx = float('inf'), 0

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        try:
            energy_loss.plot_paths(graph, logger, realities,
                                   prefix="start" if epochs == 0 else "")
        except Exception:
            # Plotting is best-effort; was a bare `except:` which also
            # swallowed KeyboardInterrupt/SystemExit.
            pass
        if visualize: return

        graph.train()
        print('training for', train_step, 'steps')
        for _ in range(0, train_step):
            try:
                train_loss = energy_loss(graph, realities=[train])
                train_loss = sum(
                    [train_loss[loss_name] for loss_name in train_loss])
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)
            except NotImplementedError:
                pass

        graph.eval()
        for _ in range(0, val_step):
            try:
                with torch.no_grad():
                    val_loss = energy_loss(graph, realities=[val])
                    val_loss = sum(
                        [val_loss[loss_name] for loss_name in val_loss])
                val.step()
                logger.update("loss", val_loss)
            except NotImplementedError:
                pass

        energy_loss.logger_update(logger)
        logger.step()

        stop_idx += 1
        try:
            # NOTE(review): a missing logger key raises KeyError, which this
            # does NOT catch — confirm whether that was intended.
            curr_val_loss = (logger.data["val_mse : N(rgb) -> normal"][-1] +
                             logger.data["val_mse : D(rgb) -> depth"][-1])
            if curr_val_loss < best_val_loss:
                print ("Better val loss, reset stop_idx: ", stop_idx)
                best_val_loss, stop_idx = curr_val_loss, 0
                energy_loss.plot_paths(graph, logger, realities, prefix="best")
                graph.save(f"{RESULTS_DIR}/graph.pth")
        except NotImplementedError:
            pass

        if stop_idx >= early_stopping:
            print ("Stopping training now")
            return
def main(
    loss_config="multiperceptual", mode="winrate", visualize=False,
    fast=False, batch_size=None, subset_size=None,
    max_epochs=800, dataaug=False, **kwargs,
):
    """Multi-perceptual winrate training with a frozen baseline pass.

    First measures baseline losses with the graph in eval mode, then trains
    with gradient-ratio-balanced losses (`compute_grad_ratio=True`).

    Args:
        fast: debug mode — reuse the val set for training, 2 steps per epoch,
            and skip the test reality.
        dataaug: enable data augmentation in load_train_val.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size, dataaug=dataaug,
    )
    if fast:
        train_dataset = val_dataset
        train_step, val_step = 2, 2
    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    if fast:
        # (The original re-assigned train_dataset/steps here too; those
        # re-assignments were dead — the realities were already built.)
        realities = [train, val]
    else:
        test_set = load_test(energy_loss.get_tasks("test"),
                             buildings=['almena', 'albertville'])
        test = RealityTask.from_static("test", test_set,
                                       energy_loss.get_tasks("test"))
        realities = [train, val, test]

    # If you wanted to just do some qualitative predictions on inputs w/o labels, you could do:
    # ood_set = load_ood(energy_loss.get_tasks("ood"))
    # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])
    # realities.append(ood)

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      pretrained=True, finetuned=False,
                      freeze_list=energy_loss.freeze_list,
                      initialize_from_transfer=False,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    os.makedirs(RESULTS_DIR, exist_ok=True)
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    energy_loss.plot_paths(graph, logger, realities, prefix="start")

    # BASELINE — measure losses without any parameter updates.
    graph.eval()
    with torch.no_grad():
        for _ in range(0, val_step * 4):
            val_loss, _ = energy_loss(graph, realities=[val])
            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)
        for _ in range(0, train_step * 4):
            train_loss, _ = energy_loss(graph, realities=[train])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="")
        if visualize: return

        graph.train()
        for _ in range(0, train_step):
            # Second return value (MSE coefficient) is unused here.
            train_loss, _ = energy_loss(graph, realities=[train],
                                        compute_grad_ratio=True)
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _ = energy_loss(graph, realities=[val])
                val_loss = sum(
                    [val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
def main(
    job_config="jobinfo.txt", models_dir="models",
    fast=False, batch_size=None,
    subset_size=None, max_epochs=500, dataaug=False,
    **kwargs,
):
    """Training run configured from a job file.

    Reads `config/<job_config>` — five comma+newline separated fields:
    loss_config, loss_mode, model_class, experiment, base_dir — falling back
    to module-level defaults for any empty field, then trains with
    gradient-ratio-balanced losses after a frozen baseline pass.
    """
    loss_config, loss_mode, model_class = None, None, None
    experiment, base_dir = None, None
    current_dir = os.path.dirname(__file__)
    job_config = os.path.normpath(
        os.path.join(os.path.join(current_dir, "config"), job_config))
    if os.path.isfile(job_config):
        with open(job_config) as config_file:
            # Fields are separated by ",\n" in the job file.
            out = config_file.read().strip().split(',\n')
            loss_config, loss_mode, model_class, experiment, base_dir = out
    # Empty/missing fields fall back to module-level defaults.
    loss_config = loss_config or LOSS_CONFIG
    loss_mode = loss_mode or LOSS_MODE
    model_class = model_class or MODEL_CLASS
    base_dir = base_dir or BASE_DIR

    base_dir = os.path.normpath(os.path.join(current_dir, base_dir))
    experiment = experiment or EXPERIMENT
    # Visdom env name: experiment minus its trailing "_<suffix>" token.
    job = "_".join(experiment.split("_")[0:-1])

    models_dir = os.path.join(base_dir, models_dir)
    results_dir = f"{base_dir}/results/results_{experiment}"
    results_dir_models = f"{base_dir}/results/results_{experiment}/models"

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, loss_mode=loss_mode,
                                  **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size, dataaug=dataaug,
    )
    if fast:
        # Debug mode: train on the val set, 2 steps per epoch.
        train_dataset = val_dataset
        train_step, val_step = 2, 2
    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test_set = load_test(energy_loss.get_tasks("test"),
                         buildings=['almena', 'albertville', 'espanola'])
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    # OOD reality only carries RGB inputs.
    ood = RealityTask.from_static("ood", ood_set, [
        tasks.rgb,
    ])
    realities = [train, val, test, ood]

    # GRAPH
    graph = TaskGraph(tasks=energy_loss.tasks + realities,
                      tasks_in=energy_loss.tasks_in,
                      tasks_out=energy_loss.tasks_out,
                      pretrained=True, models_dir=models_dir,
                      freeze_list=energy_loss.freeze_list,
                      direct_edges=energy_loss.direct_edges,
                      model_class=model_class)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(results_dir_models, exist_ok=True)
    logger = VisdomLogger("train", env=job, port=PORT, server=SERVER)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{results_dir}/graph.pth",
                                             results_dir_models),
                    feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    energy_loss.plot_paths(graph, logger, realities, prefix="start")

    # BASELINE — measure losses without any parameter updates.
    graph.eval()
    with torch.no_grad():
        for _ in range(0, val_step * 4):
            val_loss, _ = energy_loss(graph, realities=[val])
            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)
        for _ in range(0, train_step * 4):
            train_loss, _ = energy_loss(graph, realities=[train])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="finish")

        graph.train()
        for _ in range(0, train_step):
            train_loss, grad_mse_coeff = energy_loss(graph, realities=[train],
                                                     compute_grad_ratio=True)
            # NOTE: graph.step consumes the *dict* of losses (with the loss
            # and path metadata) before train_loss is collapsed to a scalar.
            graph.step(train_loss, losses=energy_loss.losses,
                       paths=energy_loss.paths)
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])
            train.step()
            logger.update("loss", train_loss)
            del train_loss  # free graph memory before the next step
            print(grad_mse_coeff)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _ = energy_loss(graph, realities=[val])
                val_loss = sum(
                    [val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
def main(
    loss_config="conservative_full", mode="standard", visualize=False,
    fast=False, batch_size=None,
    max_epochs=800, **kwargs,
):
    """Greedy triangle training: each step trains whichever of the two
    triangle losses currently has the higher energy (consistency / error
    ratio), starting from both at +inf so each is tried at least once.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
    )
    train_step, val_step = train_step // 4, val_step // 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print("Train step: ", train_step, "Val step: ", val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=True,
        # Keep the a and RC transfer edges frozen.
        freeze_list=[functional_transfers.a, functional_transfers.RC],
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
                    feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)

    # Triangles trained this epoch (validated on the same set below).
    activated_triangles = set()
    # Energy per triangle; +inf so each triangle is selected at least once.
    triangle_energy = {
        "triangle1_mse": float('inf'),
        "triangle2_mse": float('inf')
    }

    logger.add_hook(partial(jointplot, loss_type=f"energy"),
                    feature=f"val_energy", freq=1)

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities,
                               prefix="start" if epochs == 0 else "")
        if visualize: return

        graph.train()
        for _ in range(0, train_step):
            # Greedily pick the triangle with the highest current energy.
            # loss_type = random.choice(["triangle1_mse", "triangle2_mse"])
            loss_type = max(triangle_energy, key=triangle_energy.get)
            activated_triangles.add(loss_type)
            train_loss = energy_loss(graph, realities=[train],
                                     loss_types=[loss_type])
            train_loss = sum(
                [train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()

            # Refresh the chosen triangle's energy = consistency / error,
            # using the metrics recorded by the loss call above.
            if loss_type == "triangle1_mse":
                consistency_tr1 = energy_loss.metrics["train"][
                    "triangle1_mse : F(RC(x)) -> n(x)"][-1]
                error_tr1 = energy_loss.metrics["train"][
                    "triangle1_mse : n(x) -> y^"][-1]
                triangle_energy["triangle1_mse"] = float(consistency_tr1 /
                                                         error_tr1)
            elif loss_type == "triangle2_mse":
                consistency_tr2 = energy_loss.metrics["train"][
                    "triangle2_mse : S(a(x)) -> n(x)"][-1]
                error_tr2 = energy_loss.metrics["train"][
                    "triangle2_mse : n(x) -> y^"][-1]
                triangle_energy["triangle2_mse"] = float(consistency_tr2 /
                                                         error_tr2)

            print("Triangle energy: ", triangle_energy)
            logger.update("loss", train_loss)

            # Log combined energy once both triangles have finite values.
            energy = sum(triangle_energy.values())
            if (energy < float('inf')):
                logger.update("train_energy", energy)
                logger.update("val_energy", energy)

        # Validate only on the triangles that were actually trained.
        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val],
                                       loss_types=list(activated_triangles))
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)
        activated_triangles = set()
        energy_loss.logger_update(logger)
        logger.step()
def main(
    fast=False,
    subset_size=None,
    early_stopping=float('inf'),
    mode='standard',
    max_epochs=800,
    **kwargs,
):
    """Two-phase trainer: (1) pretrain a normal->curvature 'percepnet' f with MSE
    and early stopping; (2) train rgb->normal with a perceptual loss that uses
    the frozen f from phase 1.

    Args:
        fast: if True, load a reduced dataset for smoke-testing.
        subset_size: optional cap on the training subset size.
        early_stopping: NOTE(review) this parameter is immediately overwritten
            (8 for phase 1, 50 for phase 2) and therefore has no effect.
        mode: energy-loss mode for the phase-2 perceptual loss.
        max_epochs: max epochs per phase.
        **kwargs: forwarded to get_energy_loss in phase 2.
    """
    # Phase-1 patience is hard-coded, shadowing the parameter above.
    early_stopping = 8
    # Inline energy-loss spec: y = normal, z^ = curvature target, f(y) = predicted curvature.
    loss_config_percepnet = {
        "paths": {
            "y": [tasks.normal],
            "z^": [tasks.principal_curvature],
            "f(y)": [tasks.normal, tasks.principal_curvature],
        },
        "losses": {
            "mse": {
                ("train", "val"): [
                    ("f(y)", "z^"),
                ],
            },
        },
        "plots": {
            "ID": dict(size=256, realities=("test", "ood"), paths=[
                "y",
                "z^",
                "f(y)",
            ]),
        },
    }

    # CONFIG
    batch_size = 64
    energy_loss = EnergyLoss(**loss_config_percepnet)

    task_list = [tasks.rgb, tasks.normal, tasks.principal_curvature]
    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list,
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(task_list)
    ood_set = load_ood(task_list)
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    ood = RealityTask.from_static("ood", ood_set, task_list)

    # GRAPH — phase 1: train f (normal->curvature), freeze n (rgb->normal).
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities,
        pretrained=False,
        freeze_list=[functional_transfers.n],
    )
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    best_val_loss, stop_idx = float('inf'), 0
    # TRAINING — phase 1
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        # Early stopping on the f(y)->z^ validation MSE; save best weights.
        stop_idx += 1
        if logger.data["val_mse : f(y) -> z^"][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = logger.data["val_mse : f(y) -> z^"][-1], 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(weights_dir=f"{RESULTS_DIR}")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            break

    # Phase-2 patience.
    early_stopping = 50
    # CONFIG — phase 2: perceptual training of n with f frozen.
    energy_loss = get_energy_loss(config="perceptual", mode=mode, **kwargs)

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities,
        pretrained=False,
        freeze_list=[functional_transfers.f],
    )
    # Load the phase-1 f weights saved above.
    graph.edge(tasks.normal, tasks.principal_curvature).model.load_weights(f"{RESULTS_DIR}/f.pth")
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING (reuses the phase-1 logger)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    best_val_loss, stop_idx = float('inf'), 0
    # TRAINING — phase 2
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        # Early stopping on the n(x)->y^ validation MSE; save the full graph at the best point.
        stop_idx += 1
        if logger.data["val_mse : n(x) -> y^"][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = logger.data["val_mse : n(x) -> y^"][-1], 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(f"{RESULTS_DIR}/graph.pth")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            break
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=64,
    **kwargs,
):
    """Train from a saved baseline graph using three interleaved data streams:
    the full train set, a small labeled train subset, and an unlabeled
    OOD-consistency stream (rgb only, 512px).

    Args:
        loss_config: energy-loss configuration name (also used in checkpoint filenames).
        mode: energy-loss mode.
        visualize: if True, plot paths once and return.
        pretrained: accepted for interface compatibility; not used below.
        finetuned: forwarded to TaskGraph.
        fast: restrict data to one building and 20 steps per epoch.
        batch_size: main-stream batch size; defaults to 4 (fast) or 64.
        ood_batch_size: batch size for the subset/OOD streams; defaults to batch_size.
        subset_size: number of samples in the labeled train subset.
        **kwargs: forwarded to get_energy_loss.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    ood_batch_size = ood_batch_size or batch_size
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING — main labeled stream at 256px.
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal, tasks.principal_curvature],
        return_dataset=True,
        batch_size=batch_size,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=256,
    )
    # Unlabeled rgb-only stream at 512px for consistency training.
    ood_consistency_dataset, _, _, _ = load_train_val(
        [tasks.rgb, ],
        return_dataset=True,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=512,
    )
    # Small fixed labeled subset at 512px (not shuffled, so it repeats deterministically).
    train_subset_dataset, _, _, _ = load_train_val(
        [tasks.rgb, tasks.normal, ],
        return_dataset=True,
        train_buildings=["almena"] if fast else None,
        val_buildings=["almena"] if fast else None,
        resize=512,
        subset_size=subset_size,
    )
    train_step, val_step = train_step // 4, val_step // 4
    if fast:
        train_step, val_step = 20, 20

    test_set = load_test([tasks.rgb, tasks.normal, tasks.principal_curvature])
    ood_images = load_ood()
    ood_images_large = load_ood(resize=512, sample=8)
    ood_consistency_test = load_test([tasks.rgb, ], resize=512)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    ood_consistency = RealityTask("ood_consistency", ood_consistency_dataset, batch_size=ood_batch_size, shuffle=True)
    train_subset = RealityTask("train_subset", train_subset_dataset, tasks=[tasks.rgb, tasks.normal], batch_size=ood_batch_size, shuffle=False)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, [tasks.rgb, tasks.normal, tasks.principal_curvature])
    ood_test = RealityTask.from_static("ood_test", (ood_images, ), [tasks.rgb, ])
    ood_test_large = RealityTask.from_static("ood_test_large", (ood_images_large, ), [tasks.rgb, ])
    ood_consistency_test = RealityTask.from_static("ood_consistency_test", ood_consistency_test, [tasks.rgb, ])

    realities = [
        train, val, train_subset, ood_consistency,
        test, ood_test, ood_test_large, ood_consistency_test
    ]
    energy_loss.load_realities(realities)

    # GRAPH — warm-start from a previously trained baseline checkpoint.
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    # graph.load_weights(f"{MODELS_DIR}/conservative/conservative.pth")
    graph.load_weights(
        f"{SHARED_DIR}/results_2FF_train_subset_512_true_baseline_3/graph_baseline.pth"
    )
    print(graph)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth"),
        feature="epoch",
        freq=1,
    )
    # Save an initial checkpoint before any training.
    graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth")
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, 800):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, prefix="start" if epochs == 0 else "")
        if visualize:
            return

        graph.train()
        for _ in range(0, train_step):
            # Each iteration takes one optimizer step per data stream, in a fixed order.
            train.step()
            train_loss = energy_loss(graph, reality=train)
            graph.step(train_loss)
            logger.update("loss", train_loss)

            train_subset.step()
            train_subset_loss = energy_loss(graph, reality=train_subset)
            graph.step(train_subset_loss)

            ood_consistency.step()
            ood_consistency_loss = energy_loss(graph, reality=ood_consistency)
            # The consistency loss may be undefined for some configs/steps.
            if ood_consistency_loss is not None:
                graph.step(ood_consistency_loss)

        graph.eval()
        for _ in range(0, val_step):
            val.step()
            with torch.no_grad():
                val_loss = energy_loss(graph, reality=val)
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
def main(
    loss_config="conservative_full",
    mode="standard",
    pretrained=True,
    finetuned=False,
    batch_size=16,
    ood_batch_size=None,
    subset_size=None,
    cont=None,
    use_l1=True,
    num_workers=32,
    data_dir=None,
    save_dir='mount/shared/',
    **kwargs,
):
    """Compute per-image consistency 'energy' and direct-error over a dataset,
    measure their correlation, and save a CSV plus a regression plot.

    Two modes:
      * default (data_dir is None): run on the built-in buildings dataset and
        write save_dir/data.csv with per-image energy/error columns;
      * custom (data_dir given): run on an ImageDataset, compare the mean
        energy of the query images against a previously saved data.csv.

    Args:
        loss_config, mode, **kwargs: forwarded to get_energy_loss.
        pretrained, finetuned: graph construction flags (pretrained is unused here).
        batch_size: evaluation batch size.
        ood_batch_size, num_workers, cont: accepted for interface compatibility;
            not used below (cont loading is commented out).
        subset_size: cap on how many images to evaluate; defaults to the full dataset.
        use_l1: forwarded to the energy loss (L1 vs L2 per-pixel losses).
        data_dir: optional directory of custom query images.
        save_dir: output directory for data.csv and energy.pdf.
    """
    # CONFIG
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    if data_dir is None:
        buildings = ["almena", "albertville"]
        train_subset_dataset = TaskDataset(buildings, tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature])
    else:
        train_subset_dataset = ImageDataset(data_dir=data_dir)
        # Sentinel: downstream branches check for the custom-image mode via this value.
        data_dir = 'CUSTOM'

    train_subset = RealityTask("train_subset", train_subset_dataset, batch_size=batch_size, shuffle=False)
    if subset_size is None:
        subset_size = len(train_subset_dataset)
    subset_size = min(subset_size, len(train_subset_dataset))

    # GRAPH
    realities = [train_subset]
    edges = []
    for t in energy_loss.tasks:
        if t != tasks.rgb:
            edges.append((tasks.rgb, t))
    edges.append((tasks.rgb, tasks.normal))
    graph = TaskGraph(
        tasks=energy_loss.tasks + [train_subset],
        finetuned=finetuned,
        freeze_list=energy_loss.freeze_list,
        lazy=True,
        initialize_from_transfer=True,
    )
    # print('file', cont)
    # graph.load_weights(cont)
    graph.compile(optimizer=None)

    # Add consistency links: reset the three rgb->X edges and load consistency-trained weights.
    for target in ['reshading', 'depth_zbuffer', 'normal']:
        graph.edge_map[str(('rgb', target))].path = None
        graph.edge_map[str(('rgb', target))].load_model()
    graph.edge_map[str(('rgb', 'reshading'))].model.load_weights('./models/rgb2reshading_consistency.pth', backward_compatible=True)
    graph.edge_map[str(('rgb', 'depth_zbuffer'))].model.load_weights('./models/rgb2depth_consistency.pth', backward_compatible=True)
    graph.edge_map[str(('rgb', 'normal'))].model.load_weights('./models/rgb2normal_consistency.pth', backward_compatible=True)

    energy_losses, mse_losses = [], []
    percep_losses = defaultdict(list)
    energy_mean_by_blur, energy_std_by_blur = [], []
    error_mean_by_blur, error_std_by_blur = [], []
    energy_losses, error_losses = [], []
    energy_losses_all, energy_losses_headings = [], []
    fnames = []
    train_subset.reload()

    # Compute energies: one forward pass per batch, accumulating per-image values.
    for epochs in tqdm(range(subset_size // batch_size)):
        with torch.no_grad():
            losses = energy_loss(graph, realities=[train_subset], reduce=False, use_l1=use_l1)
            # Fix the column order on the first batch so np.stack axes line up later.
            if len(energy_losses_headings) == 0:
                energy_losses_headings = sorted([loss_name for loss_name in losses if 'percep' in loss_name])
            all_perceps = [losses[loss_name].cpu().numpy() for loss_name in energy_losses_headings]
            direct_losses = [losses[loss_name].cpu().numpy() for loss_name in losses if 'direct' in loss_name]
            if len(all_perceps) > 0:
                energy_losses_all += [all_perceps]
                all_perceps = np.stack(all_perceps)
                # Energy of an image = mean over its perceptual losses.
                energy_losses += list(all_perceps.mean(0))
            if len(direct_losses) > 0:
                direct_losses = np.stack(direct_losses)
                # Error of an image = mean over its direct (supervised) losses.
                error_losses += list(direct_losses.mean(0))
            if False:  # disabled: filename collection for debugging
                fnames += train_subset.task_data[tasks.filename]
        train_subset.step()

    # log losses
    if len(energy_losses) > 0:
        energy_losses = np.array(energy_losses)
        print(f'energy = {energy_losses.mean()}')
        energy_mean_by_blur += [energy_losses.mean()]
        energy_std_by_blur += [np.std(energy_losses)]
    if len(error_losses) > 0:
        error_losses = np.array(error_losses)
        print(f'error = {error_losses.mean()}')
        error_mean_by_blur += [error_losses.mean()]
        error_std_by_blur += [np.std(error_losses)]

    # save to csv (fall back to zero columns if a loss family was absent)
    save_error_losses = error_losses if len(error_losses) > 0 else [0] * subset_size
    save_energy_losses = energy_losses if len(energy_losses) > 0 else [0] * subset_size

    def z_score(x):
        """Standard z-score of a pandas Series / numpy array."""
        return (x - x.mean()) / x.std()

    def get_standardized_energy(df, use_std=False, compare_to_in_domain=False):
        """Standardize each 'percep' column by its mean absolute deviation and
        average the selected columns into a single energy per row."""
        percepts = [c for c in df.columns if 'percep' in c]

        def stdize(x):
            return (x - x.mean()).abs().mean()

        means = {k: df[k].mean() for k in percepts}
        stds = {k: stdize(df[k]) for k in percepts}
        stdized = {k: (df[k] - means[k]) / stds[k] for k in percepts}
        # Only columns whose names end in '_' or contain '__' contribute to the energy.
        energies = np.stack([v for k, v in stdized.items() if k[-1] == '_' or '__' in k]).mean(0)
        return energies

    os.makedirs(save_dir, exist_ok=True)
    # BUG FIX: was `data_dir is 'CUSTOM'` — identity comparison with a str literal
    # is implementation-dependent (and a SyntaxWarning since Python 3.8); use equality.
    if data_dir == 'CUSTOM':
        eng_curr = np.array(energy_losses).mean()
        df = pd.read_csv(os.path.join(save_dir, 'data.csv'))
    else:
        percep_losses = {
            k: v for k, v in zip(energy_losses_headings, np.concatenate(energy_losses_all, axis=-1))
        }
        df = pd.DataFrame(both(
            {'energy': save_energy_losses, 'error': save_error_losses},
            percep_losses
        ))

    # compute correlation
    df['normalized_energy'] = get_standardized_energy(df, use_std=False)
    df['normalized_error'] = z_score(df['error'])
    print(scipy.stats.spearmanr(z_score(df['error']), df['normalized_energy']))
    print("Pearson r:", scipy.stats.pearsonr(df['error'], df['normalized_energy']))
    # BUG FIX: was `data_dir is not 'CUSTOM'` — same identity-vs-equality issue.
    if data_dir != 'CUSTOM':
        df.to_csv(f"{save_dir}/data.csv", mode='w', header=True)

    # plot correlation
    plt.figure(figsize=(4, 4))
    g = sns.regplot(df['normalized_error'], df['normalized_energy'], robust=False)
    if data_dir == 'CUSTOM':
        ax1 = g.axes
        ax1.axhline(eng_curr, ls='--', color='red')
        ax1.text(0.5, 25, "Query Image Energy Line")
    plt.xlabel('Error (z-score)')
    plt.ylabel('Energy (z-score)')
    plt.title('')
    plt.savefig(f'{save_dir}/energy.pdf')
def main(
    loss_config="baseline",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    path=None,
    subset_size=None,
    early_stopping=float('inf'),
    max_epochs=800,
    **kwargs
):
    """Train a graph (rgb->normal loaded from `path`) with early stopping on
    the n(x)->y^ validation MSE; saves weights whenever validation improves.

    Args:
        loss_config: energy-loss configuration name.
        mode: energy-loss mode.
        visualize: if True, plot paths once and return.
        fast: small batch size / reduced data for smoke-testing.
        batch_size: defaults to 4 (fast) or 64.
        path: checkpoint path for the rgb->normal edge model.
        subset_size: optional cap on training-set size.
        early_stopping: patience in epochs without validation improvement.
        max_epochs: maximum epochs.
        **kwargs: forwarded to get_energy_loss.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    print('tasks', energy_loss.get_tasks("ood"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH — replace the rgb->normal edge model with the one at `path`.
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=False, freeze_list=energy_loss.freeze_list, )
    graph.edge(tasks.rgb, tasks.normal).model = None
    graph.edge(tasks.rgb, tasks.normal).path = path
    graph.edge(tasks.rgb, tasks.normal).load_model()
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)
    # Save the initial (untrained) weights so RESULTS_DIR is always populated.
    graph.save(weights_dir=f"{RESULTS_DIR}")

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)
    best_val_loss, stop_idx = float('inf'), 0

    # return
    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
        if visualize:
            return

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        # Early stopping on n(x)->y^ validation MSE; checkpoint at each new best.
        stop_idx += 1
        if logger.data["val_mse : n(x) -> y^"][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = logger.data["val_mse : n(x) -> y^"][-1], 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(weights_dir=f"{RESULTS_DIR}")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            return
def main(
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    batch_size=None,
    **kwargs,
):
    """Generate visualization plots for a set of trained graphs at five input
    resolutions (256..512), on test / OOD / full-OOD data.

    Each entry of `configs` names a Visdom environment and provides five
    per-resolution loss configs plus optionally a checkpoint (`cont`) and flags
    selecting which data regimes to plot.

    Args:
        mode: energy-loss mode.
        visualize, pretrained: accepted for interface compatibility; not used below.
        finetuned: default finetuned flag (overridden per config entry).
        batch_size: defaults to 32.
        **kwargs: forwarded to get_energy_loss.
    """
    configs = {
        "VISUALS3_rgb2normals2x_multipercep8_winrate_standardized_upd": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            cont="mount/shared/results_LBP_multipercep8_winrate_standardized_upd_3/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_latwinrate_reshadetarget": dict(
            loss_configs=["baseline_reshade_size256", "baseline_reshade_size320", "baseline_reshade_size384", "baseline_reshade_size448", "baseline_reshade_size512"],
            cont="mount/shared/results_LBP_multipercep_latwinrate_reshadingtarget_6/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_reshadebaseline": dict(
            loss_configs=["baseline_reshade_size256", "baseline_reshade_size320", "baseline_reshade_size384", "baseline_reshade_size448", "baseline_reshade_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_latwinrate_depthtarget": dict(
            loss_configs=["baseline_depth_size256", "baseline_depth_size320", "baseline_depth_size384", "baseline_depth_size448", "baseline_depth_size512"],
            cont="mount/shared/results_LBP_multipercep_latwinrate_reshadingtarget_6/graph.pth",
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2reshade2x_depthbaseline": dict(
            loss_configs=["baseline_depth_size256", "baseline_depth_size320", "baseline_depth_size384", "baseline_depth_size448", "baseline_depth_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2normals2x_baseline": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            test=True, ood=True, oodfull=False,
        ),
        "VISUALS3_rgb2normals2x_multipercep": dict(
            loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
            test=True, ood=True, oodfull=False,
            cont="mount/shared/results_LBP_multipercep_32/graph.pth",
        ),
        "VISUALS3_rgb2x2normals_baseline": dict(
            loss_configs=["rgb2x2normals_plots", "rgb2x2normals_plots_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
            finetuned=False, test=True, ood=True, ood_full=False,
        ),
        "VISUALS3_rgb2x2normals_finetuned": dict(
            loss_configs=["rgb2x2normals_plots", "rgb2x_plots2normals_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
            finetuned=True, test=True, ood=True, ood_full=False,
        ),
        "VISUALS3_rgb2x_baseline": dict(
            loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
            finetuned=False, test=True, ood=True, ood_full=False,
        ),
        "VISUALS3_rgb2x_finetuned": dict(
            loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
            finetuned=True, test=True, ood=True, ood_full=False,
        ),
    }
    # Earlier config set, kept for reference:
    # configs = {
    #     "VISUALS_rgb2normals2x_latv2": dict(
    #         loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #         cont="mount/shared/results_LBP_multipercep_latv2_10/graph.pth",
    #     ),
    #     "VISUALS_rgb2normals2x_lat_winrate": dict(
    #         loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #         cont="mount/shared/results_LBP_multipercep_lat_winrate_8/graph.pth",
    #     ),
    #     "VISUALS_rgb2normals2x_multipercep": dict(
    #         loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #         cont="mount/shared/results_LBP_multipercep_32/graph.pth",
    #     ),
    #     "VISUALS_rgb2normals2x_rndv2": dict(
    #         loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #         cont="mount/shared/results_LBP_multipercep_rnd_11/graph.pth",
    #     ),
    #     "VISUALS_rgb2normals2x_baseline": dict(
    #         loss_configs=["baseline_size256", "baseline_size320", "baseline_size384", "baseline_size448", "baseline_size512"],
    #         cont=None,
    #     ),
    #     "VISUALS_rgb2x2normals_baseline": dict(
    #         loss_configs=["rgb2x2normals_plots", "rgb2x2normals_plots_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
    #         finetuned=False,
    #     ),
    #     "VISUALS_rgb2x2normals_finetuned": dict(
    #         loss_configs=["rgb2x2normals_plots", "rgb2x2normals_plots_size320", "rgb2x2normals_plots_size384", "rgb2x2normals_plots_size448", "rgb2x2normals_plots_size512"],
    #         finetuned=True,
    #     ),
    #     "VISUALS_y2normals_baseline": dict(
    #         loss_configs=["y2normals_plots", "y2normals_plots_size320", "y2normals_plots_size384", "y2normals_plots_size448", "y2normals_plots_size512"],
    #         finetuned=False,
    #     ),
    #     "VISUALS_y2normals_finetuned": dict(
    #         loss_configs=["y2normals_plots", "y2normals_plots_size320", "y2normals_plots_size384", "y2normals_plots_size448", "y2normals_plots_size512"],
    #         finetuned=True,
    #     ),
    #     "VISUALS_rgb2x_baseline": dict(
    #         loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
    #         finetuned=False,
    #     ),
    #     "VISUALS_rgb2x_finetuned": dict(
    #         loss_configs=["rgb2x_plots", "rgb2x_plots_size320", "rgb2x_plots_size384", "rgb2x_plots_size448", "rgb2x_plots_size512"],
    #         finetuned=True,
    #     ),
    # }

    # Outer loop: one pass per resolution index i (0..4).
    for i in range(0, 5):
        # NOTE(review): data loading below always uses the FIRST config entry's
        # loss_configs[i] — presumably all entries share compatible task sets; verify.
        config = configs[list(configs.keys())[0]]
        finetuned = config.get("finetuned", False)
        loss_configs = config["loss_configs"]
        loss_config = loss_configs[i]
        batch_size = batch_size or 32
        energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

        # DATA LOADING 1 — standard test/ood samples for this resolution.
        test_set = load_test(energy_loss.get_tasks("test"), sample=8)
        ood_tasks = [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']
        ood_set = load_ood(ood_tasks, sample=4)
        print(ood_tasks)
        test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
        ood = RealityTask.from_static("ood", ood_set, ood_tasks)

        # DATA LOADING 2 — rgb-only variant (always includes tasks.rgb).
        ood_tasks = list(set([tasks.rgb] + [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']))
        test_set = load_test(ood_tasks, sample=2)
        ood_set = load_ood(ood_tasks)
        test2 = RealityTask.from_static("test", test_set, ood_tasks)
        ood2 = RealityTask.from_static("ood", ood_set, ood_tasks)

        # DATA LOADING 3 — two batches straight from the shared OOD image directory.
        test_set = load_test(energy_loss.get_tasks("test"), sample=8)
        ood_tasks = [task for task in energy_loss.get_tasks("ood") if task.kind == 'rgb']
        ood_loader = torch.utils.data.DataLoader(
            ImageDataset(tasks=ood_tasks, data_dir=f"{SHARED_DIR}/ood_images"),
            batch_size=32,
            num_workers=32,
            shuffle=False,
            pin_memory=True
        )
        data = list(itertools.islice(ood_loader, 2))
        test_set = data[0]
        ood_set = data[1]
        test3 = RealityTask.from_static("test", test_set, ood_tasks)
        ood3 = RealityTask.from_static("ood", ood_set, ood_tasks)

        # Plot every config against the three data regimes it opted into.
        for name, config in configs.items():
            finetuned = config.get("finetuned", False)
            loss_configs = config["loss_configs"]
            cont = config.get("cont", None)

            # Delete the Visdom env only on the first resolution pass.
            logger = VisdomLogger("train", env=name, delete=True if i == 0 else False)
            if config.get("test", False):
                # GRAPH
                realities = [test, ood]
                print("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None:
                    graph.load_weights(cont)
                # LOGGING
                energy_loss.plot_paths_errors(graph, logger, realities, prefix=loss_config)

            logger = VisdomLogger("train", env=name + "_ood", delete=True if i == 0 else False)
            if config.get("ood", False):
                # GRAPH
                realities = [test2, ood2]
                print("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None:
                    graph.load_weights(cont)
                energy_loss.plot_paths(graph, logger, realities, prefix=loss_config)

            logger = VisdomLogger("train", env=name + "_oodfull", delete=True if i == 0 else False)
            if config.get("oodfull", False):
                # GRAPH
                realities = [test3, ood3]
                print("Finetuned: ", finetuned)
                graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=finetuned, lazy=True)
                if cont is not None:
                    graph.load_weights(cont)
                energy_loss.plot_paths(graph, logger, realities, prefix=loss_config)
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    cont_gan=None,
    pre_gan=None,
    use_patches=False,
    patch_size=64,
    use_baseline=False,
    **kwargs,
):
    """Adversarial trainer: optimizes the task graph against a discriminator.

    The discriminator is warmed up (5 steps per generator step) for the first
    `pre_gan` epochs, then trained 1:1 afterwards.

    Args:
        loss_config, mode, **kwargs: forwarded to get_energy_loss.
        visualize: if True, plot paths once and return.
        pretrained: accepted for interface compatibility; not used below.
        finetuned: forwarded to TaskGraph.
        fast: small batch size for smoke-testing.
        batch_size: defaults to 4 (fast) or 64.
        cont: graph checkpoint to resume from (skipped when use_baseline or USE_RAID).
        cont_gan: discriminator checkpoint (loading is currently commented out).
        pre_gan: warm-up epoch count for the discriminator; defaults to 1.
        use_patches / patch_size: patch-based discriminator configuration.
        use_baseline: if True, do not load the `cont` checkpoint.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned, freeze_list=energy_loss.freeze_list)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if not use_baseline and not USE_RAID:
        graph.load_weights(cont)

    pre_gan = pre_gan or 1
    # Discriminator input size is the patch size when patch mode is on, else 224.
    discriminator = Discriminator(
        energy_loss.losses['gan'],
        size=(patch_size if use_patches else 224),
        use_patches=use_patches
    )
    # if cont_gan is not None: discriminator.load_weights(cont_gan)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    logger.add_hook(
        lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"),
        feature="epoch", freq=1
    )
    energy_loss.logger_hooks(logger)

    # TRAINING
    for epochs in range(0, 80):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
        if visualize:
            return

        graph.train()
        discriminator.train()
        for _ in range(0, train_step):
            # After the warm-up period, advance the adversarial schedule counter.
            if epochs > pre_gan:
                energy_loss.train_iter += 1

            # Generator (graph) update.
            train_loss = energy_loss(graph, discriminator=discriminator, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

            # Discriminator updates: 5 per generator step during warm-up, then 1.
            for i in range(5 if epochs <= pre_gan else 1):
                train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train])
                discriminator.step(train_loss2)
                train.step()

        graph.eval()
        discriminator.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, discriminator=discriminator, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    use_optimizer=False,
    subset_size=None,
    max_epochs=800,
    **kwargs,
):
    """Resume training from a saved rgb->normal model (and optionally its
    optimizer state); validates BEFORE each training pass so the first logged
    validation reflects the restored weights.

    Args:
        loss_config, mode, **kwargs: forwarded to get_energy_loss.
        visualize: if True, plot paths once and return.
        fast: small batch size / reduced data for smoke-testing.
        batch_size: defaults to 4 (fast) or 64.
        use_optimizer: if True, restore the Adam state saved alongside the model.
        subset_size: optional cap on training-set size.
        max_epochs: maximum epochs.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    print(kwargs)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size)
    # Quarter the per-epoch steps so epoch-level hooks fire more frequently.
    train_step, val_step = train_step // 4, val_step // 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood([tasks.rgb, ])
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb, ])

    # GRAPH — swap in the saved rgb->normal edge model.
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    graph.edge(tasks.rgb, tasks.normal).model = None
    graph.edge(tasks.rgb, tasks.normal).path = "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/n.pth"
    graph.edge(tasks.rgb, tasks.normal).load_model()
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    if use_optimizer:
        # Restore the optimizer saved with the checkpoint so momentum/amsgrad
        # statistics carry over across the resume boundary.
        optimizer = torch.load(
            "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/optimizer.pth"
        )
        graph.optimizer.load_state_dict(optimizer.state_dict())

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    best_ood_val_loss = float('inf')

    # TRAINING — note: validation runs FIRST each epoch.
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")
        if visualize:
            return

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        energy_loss.logger_update(logger)
        # if logger.data["val_mse : y^ -> n(~x)"][-1] < best_ood_val_loss:
        #     best_ood_val_loss = logger.data["val_mse : y^ -> n(~x)"][-1]
        #     energy_loss.plot_paths(graph, logger, realities, prefix="best")
        logger.step()
def main(
    loss_config="baseline",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    **kwargs,
):
    """Evaluation-only entry point: load a saved rgb->normal model into the
    graph, run the validation set once, and print the resulting val MSE.
    No gradient steps are taken."""
    # CONFIG
    if batch_size is None:
        batch_size = 4 if fast else 64
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    train_step *= 4
    val_step *= 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH — point the rgb->normal edge at the saved checkpoint and reload it.
    realities = [train, val, test, ood]
    graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, freeze_list=energy_loss.freeze_list, )
    rgb2normal_edge = graph.edge(tasks.rgb, tasks.normal)
    rgb2normal_edge.model = None
    rgb2normal_edge.path = f"{SHARED_DIR}/results_SAMPLEFF_consistency1m_25/n.pth"
    rgb2normal_edge.load_model()
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)

    # EVALUATION — a single validation sweep; no training loop.
    graph.eval()
    for _ in range(val_step):
        with torch.no_grad():
            losses_by_name = energy_loss(graph, realities=[val])
            val_loss = sum([losses_by_name[loss_name] for loss_name in losses_by_name])
        val.step()
        logger.update("loss", val_loss)

    energy_loss.logger_update(logger)
    logger.step()

    # print ("Train mse: ", logger.data["train_mse : n(x) -> y^"])
    print("Val mse: ", logger.data["val_mse : n(x) -> y^"])