def main():

    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="energy", freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy", freq=100,
    )

    task_list = [
        tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.sobel_edges,
        tasks.depth_zbuffer, tasks.reshading, tasks.edge_occlusion,
        tasks.keypoints3d, tasks.keypoints2d,
    ]
    reality = RealityTask(
        'ood',
        dataset=ImagePairDataset(data_dir=OOD_DIR, resize=(256, 256)),
        tasks=[tasks.rgb, tasks.rgb],
        batch_size=28,
    )
    # reality = RealityTask('almena',
    #     dataset=TaskDataset(
    #         buildings=['almena'],
    #         tasks=task_list,
    #     ),
    #     tasks=task_list,
    #     batch_size=8,
    # )
    graph = TaskGraph(tasks=[reality, *task_list], batch_size=28)

    task = tasks.rgb
    images = [reality.task_data[task]]
    sources = [task.name]

    for _, edge in sorted((edge.dest_task.name, edge) for edge in graph.adj[task]):
        if isinstance(edge.src_task, RealityTask):
            continue
        x = edge(reality.task_data[edge.src_task])
        if edge.dest_task != tasks.normal:
            edge2 = graph.edge_map[(edge.dest_task.name, tasks.normal.name)]
            x = edge2(x)
        images.append(x.clamp(min=0, max=1))
        sources.append(edge.dest_task.name)

    logger.images_grouped(images, ", ".join(sources), resize=256)
def main():

    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="energy", freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy", freq=100,
    )

    task_list = [
        tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.sobel_edges,
        tasks.depth_zbuffer, tasks.reshading, tasks.edge_occlusion,
        tasks.keypoints3d, tasks.keypoints2d,
    ]
    reality = RealityTask(
        'almena',
        dataset=TaskDataset(buildings=['almena'], tasks=task_list),
        tasks=task_list,
        batch_size=8,
    )
    graph = TaskGraph(tasks=[reality, *task_list], batch_size=8)

    for task in graph.tasks:
        if isinstance(task, RealityTask):
            continue
        images = [reality.task_data[task].clamp(min=0, max=1)]
        sources = [task.name]
        for _, edge in sorted((edge.src_task.name, edge) for edge in graph.in_adj[task]):
            if isinstance(edge.src_task, RealityTask):
                continue
            x = edge(reality.task_data[edge.src_task])
            images.append(x.clamp(min=0, max=1))
            sources.append(edge.src_task.name)
        logger.images_grouped(images, ", ".join(sources), resize=192)
def main(src_task, dest_task, fast=False):

    # src_task, dest_task = get_task(src_task), get_task(dest_task)
    model = DataParallelModel(get_model(src_task, dest_task).cuda())
    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=25)
    logger.add_hook(partial(jointplot, loss_type="mse_loss"), feature="val_mse_loss", freq=1)
    logger.add_hook(
        lambda logger, data: model.save(f"{RESULTS_DIR}/{src_task.name}2{dest_task.name}.pth"),
        feature="loss", freq=400,
    )

    # DATA LOADING
    train_loader, val_loader, train_step, val_step = load_train_val(
        [src_task, dest_task],
        batch_size=48,
        train_buildings=["almena"], val_buildings=["almena"],
    )
    train_step, val_step = 5, 5
    test_set, test_images = load_test(src_task, dest_task)
    src_task.plot_func(test_images, "images", logger, resize=128)

    for epochs in range(0, 800):
        logger.update("epoch", epochs)
        plot_images(model, logger, test_set, dest_task, show_masks=True)

        train_set = itertools.islice(train_loader, train_step)
        val_set = itertools.islice(val_loader, val_step)

        (train_mse_data,) = model.fit_with_metrics(train_set, loss_fn=dest_task.norm, logger=logger)
        logger.update("train_mse_loss", train_mse_data)
        (val_mse_data,) = model.predict_with_metrics(val_set, loss_fn=dest_task.norm, logger=logger)
        logger.update("val_mse_loss", val_mse_data)
def main(loss_config="gt_mse", mode="standard", pretrained=False, batch_size=64, **kwargs): # MODEL # model = DataParallelModel.load(UNet().cuda(), "standardval_rgb2normal_baseline.pth") model = functional_transfers.n.load_model() if pretrained else DataParallelModel(UNet()) model.compile(torch.optim.Adam, lr=(3e-5 if pretrained else 3e-4), weight_decay=2e-6, amsgrad=True) scheduler = MultiStepLR(model.optimizer, milestones=[5*i + 1 for i in range(0, 80)], gamma=0.95) # FUNCTIONAL LOSS functional = get_functional_loss(config=loss_config, mode=mode, model=model, **kwargs) print (functional) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda logger, data: model.save(f"{RESULTS_DIR}/model.pth"), feature="loss", freq=400) logger.add_hook(lambda logger, data: scheduler.step(), feature="epoch", freq=1) functional.logger_hooks(loggers) # DATA LOADING ood_images = load_ood(ood_path=f'{BASE_DIR}/data/ood_images/') train_loader, val_loader, train_step, val_step = load_train_val([tasks.rgb, tasks.normal], batch_size=batch_size) # train_buildings=["almena"], val_buildings=["almena"]) test_set, test_images = load_test("rgb", "normal") logger.images(test_images, "images", resize=128) logger.images(torch.cat(ood_images, dim=0), "ood_images", resize=128) # TRAINING for epochs in range(0, 800): preds_name = "start_preds" if epochs == 0 and pretrained else "preds" ood_name = "start_ood" if epochs == 0 and pretrained else "ood" plot_images(model, logger, test_set, dest_task="normal", ood_images=ood_images, loss_models=functional.plot_losses, preds_name=preds_name, ood_name=ood_name ) logger.update("epoch", epochs) logger.step() train_set = itertools.islice(train_loader, train_step) val_set = itertools.islice(val_loader, val_step) val_metrics = model.predict_with_metrics(val_set, loss_fn=functional, logger=logger) train_metrics = model.fit_with_metrics(train_set, loss_fn=functional, logger=logger) functional.logger_update(logger, train_metrics, val_metrics)
def main():

    src_task, dest_task = tasks.rgb, tasks.normal
    model = functional_transfers.n.load_model()
    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=25)
    logger.add_hook(partial(jointplot, loss_type="mse_loss"), feature="val_mse_loss", freq=1)
    logger.add_hook(lambda logger, data: model.save(f"{RESULTS_DIR}/model.pth"), feature="epoch", freq=1)

    # DATA LOADING
    train_loader, val_loader, train_step, val_step, test_set, test_images = load_sintel_train_val_test(
        src_task, dest_task, batch_size=8)
    src_task.plot_func(test_images, "images", logger, resize=128)

    for epochs in range(0, 800):
        logger.update("epoch", epochs)
        plot_images(model, logger, test_set, dest_task, show_masks=True)

        train_set = itertools.islice(train_loader, train_step)
        val_set = itertools.islice(val_loader, val_step)

        (train_mse_data,) = model.fit_with_metrics(train_set, loss_fn=dest_task.norm, logger=logger)
        logger.update("train_mse_loss", torch.mean(torch.tensor(train_mse_data)))
        (val_mse_data,) = model.predict_with_metrics(val_set, loss_fn=dest_task.norm, logger=logger)
        logger.update("val_mse_loss", torch.mean(torch.tensor(val_mse_data)))
def main( loss_config="baseline_normal", mode="standard", visualize=False, fast=False, batch_size=None, learning_rate=5e-4, subset_size=None, max_epochs=5000, dataaug=False, **kwargs, ): # CONFIG wandb.config.update({ "loss_config": loss_config, "batch_size": batch_size, "data_aug": dataaug, "lr": learning_rate }) batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [ tasks.rgb, ]) # GRAPH realities = [train, val, test, ood] graph = TaskGraph( tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.compile(torch.optim.Adam, lr=learning_rate, weight_decay=2e-6, amsgrad=True) # LOGGING logger = VisdomLogger("train", env=JOB) # fake visdom logger logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) if (epochs % 100 == 0) or (epochs % 10 == 0 and epochs < 30): path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs) graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) data = logger.step() del data['loss'] del data['epoch'] data = {k: v[0] for k, v in data.items()} wandb.log(data, step=epochs) # save model and opt state every 10 epochs if epochs % 10 == 0: graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth") # lower lr after 1500 epochs if epochs == 1500: graph.optimizer.param_groups[0]['lr'] = 3e-5 graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
def main():

    task_list = [
        tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.sobel_edges,
        tasks.depth_zbuffer, tasks.reshading, tasks.edge_occlusion,
        tasks.keypoints3d, tasks.keypoints2d,
    ]
    reality = RealityTask(
        'almena',
        dataset=TaskDataset(
            buildings=['almena'],
            tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.depth_zbuffer],
        ),
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.depth_zbuffer],
        batch_size=4,
    )
    graph = TaskGraph(
        tasks=[reality, *task_list],
        anchored_tasks=[reality, tasks.rgb],
        reality=reality,
        batch_size=4,
        edges_exclude=[
            ('almena', 'normal'),
            ('almena', 'principal_curvature'),
            ('almena', 'depth_zbuffer'),
            # ('rgb', 'keypoints3d'),
            # ('rgb', 'edge_occlusion'),
        ],
        initialize_first_order=True,
    )
    graph.p.compile(torch.optim.Adam, lr=4e-2)
    graph.estimates.compile(torch.optim.Adam, lr=1e-2)

    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="energy", freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy", freq=100,
    )
    logger.add_hook(lambda logger, data: graph.plot_estimates(logger), feature="epoch", freq=32)
    logger.add_hook(lambda logger, data: graph.update_paths(logger), feature="epoch", freq=32)

    graph.plot_estimates(logger)
    graph.plot_paths(
        logger,
        dest_tasks=[tasks.normal, tasks.principal_curvature, tasks.depth_zbuffer],
        show_images=False,
    )

    for epochs in range(0, 4000):
        logger.update("epoch", epochs)
        with torch.no_grad():
            free_energy = graph.free_energy(sample=16)
            graph.averaging_step()
        logger.update("energy", free_energy)
def main( loss_config="conservative_full", mode="standard", visualize=False, pretrained=True, finetuned=False, fast=False, batch_size=None, ood_batch_size=None, subset_size=None, cont=f"{BASE_DIR}/shared/results_LBP_multipercep_lat_winrate_8/graph.pth", cont_gan=None, pre_gan=None, max_epochs=800, use_baseline=False, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), energy_loss.get_tasks("val"), batch_size=batch_size, fast=fast, ) train_subset_dataset, _, _, _ = load_train_val( energy_loss.get_tasks("train_subset"), batch_size=batch_size, fast=fast, subset_size=subset_size) if not fast: train_step, val_step = train_step // (16 * 4), val_step // (16) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) train_subset = RealityTask("train_subset", train_subset_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) # ood = RealityTask.from_static("ood", ood_set, [energy_loss.get_tasks("ood")]) # GRAPH realities = [ train, val, test, ] + [train_subset] #[ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned, freeze_list=energy_loss.freeze_list) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) graph.load_weights(cont) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) # logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) best_ood_val_loss = float('inf') energy_losses = [] mse_losses = [] pearsonr_vals = [] percep_losses = defaultdict(list) pearson_percep = defaultdict(list) # # TRAINING # for epochs in range(0, max_epochs): # logger.update("epoch", epochs) # if epochs == 0: # energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") # # if visualize: return # graph.eval() # for _ in range(0, val_step): # with torch.no_grad(): # losses = energy_loss(graph, realities=[val]) # all_perceps = [losses[loss_name] for loss_name in losses if 'percep' in loss_name ] # energy_avg = sum(all_perceps) / len(all_perceps) # for loss_name in losses: # if 'percep' not in loss_name: continue # percep_losses[loss_name] += [losses[loss_name].data.cpu().numpy()] # mse = losses['mse'] # energy_losses.append(energy_avg.data.cpu().numpy()) # mse_losses.append(mse.data.cpu().numpy()) # val.step() # mse_arr = np.array(mse_losses) # energy_arr = np.array(energy_losses) # # logger.scatter(mse_arr - mse_arr.mean() / np.std(mse_arr), \ # # energy_arr - energy_arr.mean() / np.std(energy_arr), \ # # 'unit_normal_all', opts={'xlabel':'mse','ylabel':'energy'}) # logger.scatter(mse_arr, energy_arr, \ # 'mse_energy_all', opts={'xlabel':'mse','ylabel':'energy'}) # pearsonr, p = scipy.stats.pearsonr(mse_arr, energy_arr) # logger.text(f'pearsonr = {pearsonr}, p = {p}') # pearsonr_vals.append(pearsonr) # logger.plot(pearsonr_vals, 'pearsonr_all') # for percep_name in percep_losses: # percep_loss_arr = np.array(percep_losses[percep_name]) # logger.scatter(mse_arr, percep_loss_arr, f'mse_energy_{percep_name}', \ # 
opts={'xlabel':'mse','ylabel':'energy'}) # pearsonr, p = scipy.stats.pearsonr(mse_arr, percep_loss_arr) # pearson_percep[percep_name] += [pearsonr] # logger.plot(pearson_percep[percep_name], f'pearson_{percep_name}') # energy_loss.logger_update(logger) # if logger.data['val_mse : n(~x) -> y^'][-1] < best_ood_val_loss: # best_ood_val_loss = logger.data['val_mse : n(~x) -> y^'][-1] # energy_loss.plot_paths(graph, logger, realities, prefix="best") energy_mean_by_blur = [] energy_std_by_blur = [] mse_mean_by_blur = [] mse_std_by_blur = [] for blur_size in np.arange(0, 10, 0.5): tasks.rgb.blur_radius = blur_size if blur_size > 0 else None train_subset.step() # energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") energy_losses = [] mse_losses = [] for epochs in range(subset_size // batch_size): with torch.no_grad(): flosses = energy_loss(graph, realities=[train_subset], reduce=False) losses = energy_loss(graph, realities=[train_subset], reduce=False) all_perceps = np.stack([ losses[loss_name].data.cpu().numpy() for loss_name in losses if 'percep' in loss_name ]) energy_losses += list(all_perceps.mean(0)) mse_losses += list(losses['mse'].data.cpu().numpy()) train_subset.step() mse_losses = np.array(mse_losses) energy_losses = np.array(energy_losses) logger.text( f'blur_radius = {blur_size}, mse = {mse_losses.mean()}, energy = {energy_losses.mean()}' ) logger.scatter(mse_losses, energy_losses, \ f'mse_energy, blur = {blur_size}', opts={'xlabel':'mse','ylabel':'energy'}) energy_mean_by_blur += [energy_losses.mean()] energy_std_by_blur += [np.std(energy_losses)] mse_mean_by_blur += [mse_losses.mean()] mse_std_by_blur += [np.std(mse_losses)] logger.plot(energy_mean_by_blur, f'energy_mean_by_blur') logger.plot(energy_std_by_blur, f'energy_std_by_blur') logger.plot(mse_mean_by_blur, f'mse_mean_by_blur') logger.plot(mse_std_by_blur, f'mse_std_by_blur')
def main( loss_config="multiperceptual", mode="standard", visualize=False, fast=False, batch_size=None, subset_size=None, max_epochs=5000, dataaug=False, **kwargs, ): # CONFIG wandb.config.update({"loss_config":loss_config,"batch_size":batch_size,"data_aug":dataaug,"lr":"3e-5", "n_gauss":1,"distribution":"laplace"}) batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, val_noaug_dataset, train_step, val_step = load_train_val_merging( energy_loss.get_tasks("train_c"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood"), ood_path='./assets/ood_natural/') ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"), ood_path='./assets/st_syn_distortions/') ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"), ood_path='./assets/ood_syn_distortions/', sample=35) train = RealityTask("train_c", train_dataset, batch_size=batch_size, shuffle=True) # distorted and undistorted val = RealityTask("val_c", val_dataset, batch_size=batch_size, shuffle=True) # distorted and undistorted val_noaug = RealityTask("val", val_noaug_dataset, batch_size=batch_size, shuffle=True) # no augmentation test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,]) ## standard ood set - natural ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb,]) ## synthetic distortion images used for sig training ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb,]) ## unseen syn distortions # GRAPH realities = [train, val, val_noaug, test, ood, ood_syn_aug, ood_syn] graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) # LOGGING logger = VisdomLogger("train", env=JOB) # fake visdom logger logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) graph.eval() path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0) with torch.no_grad(): for reality in [val,val_noaug]: for _ in range(0, val_step): val_loss = energy_loss(graph, realities=[reality]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) reality.step() logger.update("loss", val_loss) for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) data=logger.step() del data['loss'] data = {k:v[0] for k,v in data.items()} wandb.log(data, step=0) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) graph.eval() for reality in [val,val_noaug]: for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[reality]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) 
reality.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) data=logger.step() del data['loss'] del data['epoch'] data = {k:v[0] for k,v in data.items()} wandb.log(data, step=epochs+1) if epochs % 10 == 0: graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth") if epochs % 25 == 0: path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs+1) graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth")
def main( loss_config="baseline", mode="standard", visualize=False, fast=False, batch_size=None, path=None, subset_size=None, early_stopping=float('inf'), max_epochs=800, **kwargs ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) print('tasks', energy_loss.get_tasks("ood")) ood_set = load_ood(energy_loss.get_tasks("ood")) print (train_step, val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) # GRAPH realities = [train, val, test, ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=False, freeze_list=energy_loss.freeze_list, ) graph.edge(tasks.rgb, tasks.normal).model = None graph.edge(tasks.rgb, tasks.normal).path = path graph.edge(tasks.rgb, tasks.normal).load_model() graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True) graph.save(weights_dir=f"{RESULTS_DIR}") # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) best_val_loss, stop_idx = float('inf'), 0 # return # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train]) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step() stop_idx += 1 if logger.data["val_mse : n(x) -> y^"][-1] < best_val_loss: print ("Better val loss, reset stop_idx: ", stop_idx) best_val_loss, stop_idx = logger.data["val_mse : n(x) -> y^"][-1], 0 energy_loss.plot_paths(graph, logger, realities, prefix="best") graph.save(weights_dir=f"{RESULTS_DIR}") if stop_idx >= early_stopping: print ("Stopping training now") return
def main(
    pretrained=True, finetuned=False, fast=False, batch_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    max_epochs=800, **kwargs,
):

    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.depth_zbuffer,
        # tasks.sobel_edges,
    ]

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list, batch_size=batch_size, fast=fast,
    )
    test_set = load_test(task_list)
    train_step, val_step = train_step // 4, val_step // 4
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])

    # GRAPH
    realities = [train, val, test]
    graph = TaskGraph(tasks=task_list + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    logger.add_hook(partial(jointplot, loss_type="energy"), feature="val_energy", freq=1)

    def path_name(path):
        if len(path) == 1:
            print(path)
            return str(path[0])
        if str((path[-2].name, path[-1].name)) not in graph.edge_map:
            return None
        sub_name = path_name(path[:-1])
        if sub_name is None:
            return None
        return f'{get_transfer_name(Transfer(path[-2], path[-1]))}({sub_name})'

    for i in range(0, 20):

        gt_paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 1)
            if path_name(path) is not None
        }
        baseline_paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 2)
            if path_name(path) is not None
        }
        paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 3)
            if path_name(path) is not None
        }
        selected_paths = dict(random.sample(paths.items(), k=3))
        print("Chosen paths: ", selected_paths)

        loss_config = {
            "paths": {**gt_paths, **baseline_paths, **paths},
            "losses": {
                "baseline_mse": {
                    ("train", "val"): [
                        (path_name, str(path[-1]))
                        for path_name, path in baseline_paths.items()
                    ],
                },
                "mse": {
                    ("train", "val"): [
                        (path_name, str(path[-1]))
                        for path_name, path in selected_paths.items()
                    ],
                },
                "eval_mse": {
                    ("val",): [
                        (path_name, str(path[-1]))
                        for path_name, path in paths.items()
                    ],
                },
            },
            "plots": {
                "ID": dict(
                    size=256,
                    realities=("test",),
                    paths=[path_name for path_name, path in selected_paths.items()]
                        + [str(path[-1]) for path_name, path in selected_paths.items()],
                ),
            },
        }
        energy_loss = EnergyLoss(**loss_config)
        energy_loss.logger_hooks(logger)

        # TRAINING
        for epochs in range(0, 5):

            logger.update("epoch", epochs)
            energy_loss.plot_paths(graph, logger, realities, prefix="")

            graph.train()
            for _ in range(0, train_step):
                train_loss = energy_loss(graph, realities=[train], loss_types=["mse", "baseline_mse"])
                train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            graph.eval()
            for _ in range(0, val_step):
                with torch.no_grad():
                    val_loss = energy_loss(graph, realities=[val], loss_types=["mse", "baseline_mse"])
                    val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                val.step()
                logger.update("loss", val_loss)

            val_loss = energy_loss(graph, realities=[val], loss_types=["eval_mse"])
            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("train_energy", val_loss)
            logger.update("val_energy", val_loss)

            energy_loss.logger_update(logger)
            logger.step()
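# A toy, dependency-free illustration of the nested naming scheme that path_name
# above produces (e.g. "f(n(rgb))"), assuming transfers are identified by plain
# strings; `toy_path_name` and `transfer_names` are hypothetical stand-ins for
# graph.edge_map / get_transfer_name and are not part of this codebase.
transfer_names = {("rgb", "normal"): "n", ("normal", "curvature"): "f"}

def toy_path_name(path):
    if len(path) == 1:
        return path[0]
    if (path[-2], path[-1]) not in transfer_names:
        return None  # no transfer exists for the final hop
    sub_name = toy_path_name(path[:-1])
    if sub_name is None:
        return None
    return f"{transfer_names[(path[-2], path[-1])]}({sub_name})"

print(toy_path_name(["rgb", "normal", "curvature"]))  # f(n(rgb))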
from sklearn.metrics import accuracy_score, roc_auc_score
from lifelines.utils import concordance_index
from scipy.stats import pearsonr

from utils import *
from models import TrainableModel
from modules import Highway
from generators import AbstractPatientGenerator
from logger import Logger, VisdomLogger
from data import fetch

import IPython

# LOGGING
logger = VisdomLogger("train", server='35.230.67.129', port=7000, env="cancer")
logger.add_hook(lambda x: logger.step(), feature='loss', freq=80)


def jointplot(data):
    data = np.stack([
        logger.data["cox_loss"],
        logger.data["sim_loss"],
        logger.data["train_c_index"],
        logger.data["val_c_index"],
    ], axis=1)
    np.savez_compressed(
        "data_multimodal.npz",
        cox_loss=logger.data["cox_loss"],
        sim_loss=logger.data["sim_loss"],
def main():

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda x: logger.step(), feature="loss", freq=25)

    resize = 256
    ood_images = load_ood()[0]
    tasks = [
        get_task(name) for name in [
            'rgb', 'normal', 'principal_curvature', 'depth_zbuffer',
            'sobel_edges', 'reshading', 'keypoints3d', 'keypoints2d',
        ]
    ]
    test_loader = torch.utils.data.DataLoader(
        TaskDataset(['almena'], tasks),
        batch_size=64, num_workers=12, shuffle=False, pin_memory=True,
    )
    imgs = list(itertools.islice(test_loader, 1))[0]
    gt = {tasks[i].name: batch.cuda() for i, batch in enumerate(imgs)}
    num_plot = 4

    logger.images(ood_images, f"x", nrow=2, resize=resize)

    edges = finetuned_transfers

    def get_nbrs(task, edges):
        res = []
        for e in edges:
            if task == e.src_task:
                res.append(e)
        return res

    max_depth = 10
    mse_dict = defaultdict(list)

    def search_small(x, task, prefix, visited, depth, endpoint):
        if task.name == 'normal':
            interleave = torch.stack([
                val for pair in zip(x[:num_plot], gt[task.name][:num_plot]) for val in pair
            ])
            logger.images(interleave.clamp(max=1, min=0), prefix, nrow=2, resize=resize)
            mse, _ = task.loss_func(x, gt[task.name])
            mse_dict[task.name].append((mse.detach().data.cpu().numpy(), prefix))

        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}", next_prefix)
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                res = search_small(preds, transfer.dest_task, next_prefix, visited, depth + 1, endpoint)
                visited.remove(transfer.dest_task.name)
        return endpoint == task

    def search_full(x, task, prefix, visited, depth, endpoint):
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}", next_prefix)
            if transfer.dest_task.name == 'normal':
                interleave = torch.stack([
                    val for pair in zip(preds[:num_plot], gt[transfer.dest_task.name][:num_plot])
                    for val in pair
                ])
                logger.images(interleave.clamp(max=1, min=0), next_prefix, nrow=2, resize=resize)
                mse, _ = task.loss_func(preds, gt[transfer.dest_task.name])
                mse_dict[transfer.dest_task.name].append((mse.detach().data.cpu().numpy(), next_prefix))
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                res = search_full(preds, transfer.dest_task, next_prefix, visited, depth + 1, endpoint)
                visited.remove(transfer.dest_task.name)
        return endpoint == task

    def search(x, task, prefix, visited, depth):
        for transfer in get_nbrs(task, edges):
            preds = transfer(x)
            next_prefix = f'{transfer.name}({prefix})'
            print(f"{transfer.src_task.name}2{transfer.dest_task.name}", next_prefix)
            if transfer.dest_task.name == 'normal':
                logger.images(preds.clamp(max=1, min=0), next_prefix, nrow=2, resize=resize)
            if transfer.dest_task.name not in visited:
                visited.add(transfer.dest_task.name)
                res = search(preds, transfer.dest_task, next_prefix, visited, depth + 1)
                visited.remove(transfer.dest_task.name)

    with torch.no_grad():
        # search_full(gt['rgb'], TASK_MAP['rgb'], 'x', set('rgb'), 1, TASK_MAP['normal'])
        search(ood_images, get_task('rgb'), 'x', set('rgb'), 1)

    for name, mse_list in mse_dict.items():
        mse_list.sort()
        print(name)
        print(mse_list)
        if len(mse_list) == 1:
            mse_list.append((0, '-'))
        rownames = [pair[1] for pair in mse_list]
        data = [pair[0] for pair in mse_list]
        print(data, rownames)
        logger.bar(data, f'{name}_path_mse', opts={'rownames': rownames})
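# A self-contained toy version of the recursive search above: enumerate chains of
# transfers out of 'rgb' without revisiting a task. `toy_edges` is a hypothetical
# adjacency list standing in for finetuned_transfers; no models are actually run.
toy_edges = {"rgb": ["normal", "depth"], "depth": ["normal"], "normal": []}

def toy_search(task, prefix, visited):
    for nbr in toy_edges.get(task, []):
        next_prefix = f"{task}2{nbr}({prefix})"
        print(next_prefix)
        if nbr not in visited:
            visited.add(nbr)
            toy_search(nbr, next_prefix, visited)
            visited.remove(nbr)

toy_search("rgb", "x", {"rgb"})
# prints: rgb2normal(x), rgb2depth(x), depth2normal(rgb2depth(x))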
def main(
    job_config="jobinfo.txt", models_dir="models",
    fast=False, batch_size=None,
    subset_size=None, max_epochs=500, dataaug=False,
    **kwargs,
):

    loss_config, loss_mode, model_class = None, None, None
    experiment, base_dir = None, None
    current_dir = os.path.dirname(__file__)
    job_config = os.path.normpath(os.path.join(os.path.join(current_dir, "config"), job_config))
    if os.path.isfile(job_config):
        with open(job_config) as config_file:
            out = config_file.read().strip().split(',\n')
            loss_config, loss_mode, model_class, experiment, base_dir = out
    loss_config = loss_config or LOSS_CONFIG
    loss_mode = loss_mode or LOSS_MODE
    model_class = model_class or MODEL_CLASS
    base_dir = base_dir or BASE_DIR

    base_dir = os.path.normpath(os.path.join(current_dir, base_dir))
    experiment = experiment or EXPERIMENT
    job = "_".join(experiment.split("_")[0:-1])

    models_dir = os.path.join(base_dir, models_dir)
    results_dir = f"{base_dir}/results/results_{experiment}"
    results_dir_models = f"{base_dir}/results/results_{experiment}/models"

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    energy_loss = get_energy_loss(config=loss_config, loss_mode=loss_mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size, fast=fast,
        subset_size=subset_size, dataaug=dataaug,
    )
    if fast:
        train_dataset = val_dataset
        train_step, val_step = 2, 2

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)

    test_set = load_test(energy_loss.get_tasks("test"), buildings=['almena', 'albertville', 'espanola'])
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb])
    realities = [train, val, test, ood]

    # GRAPH
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        tasks_in=energy_loss.tasks_in, tasks_out=energy_loss.tasks_out,
        pretrained=True, models_dir=models_dir,
        freeze_list=energy_loss.freeze_list,
        direct_edges=energy_loss.direct_edges,
        model_class=model_class,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(results_dir_models, exist_ok=True)
    logger = VisdomLogger("train", env=job, port=PORT, server=SERVER)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{results_dir}/graph.pth", results_dir_models),
                    feature="epoch", freq=1)
    energy_loss.logger_hooks(logger)
    energy_loss.plot_paths(graph, logger, realities, prefix="start")

    # BASELINE
    graph.eval()
    with torch.no_grad():
        for _ in range(0, val_step * 4):
            val_loss, _ = energy_loss(graph, realities=[val])
            val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)
        for _ in range(0, train_step * 4):
            train_loss, _ = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            train.step()
            logger.update("loss", train_loss)
    energy_loss.logger_update(logger)

    # TRAINING
    for epochs in range(0, max_epochs):

        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="finish")

        graph.train()
        for _ in range(0, train_step):
            train_loss, grad_mse_coeff = energy_loss(graph, realities=[train], compute_grad_ratio=True)
            graph.step(train_loss, losses=energy_loss.losses, paths=energy_loss.paths)
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            train.step()
            logger.update("loss", train_loss)
            del train_loss
        print(grad_mse_coeff)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss, _ = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()
def main( loss_config="conservative_full", mode="standard", visualize=False, pretrained=True, finetuned=False, fast=False, batch_size=None, cont=f"{MODELS_DIR}/conservative/conservative.pth", cont_gan=None, pre_gan=None, use_patches=False, patch_size=64, use_baseline=False, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) # GRAPH realities = [train, val, test, ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned, freeze_list=energy_loss.freeze_list) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if not use_baseline and not USE_RAID: graph.load_weights(cont) pre_gan = pre_gan or 1 discriminator = Discriminator(energy_loss.losses['gan'], size=(patch_size if use_patches else 224), use_patches=use_patches) # if cont_gan is not None: discriminator.load_weights(cont_gan) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) logger.add_hook( lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) # TRAINING for epochs in range(0, 80): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.train() discriminator.train() for _ in range(0, train_step): if epochs > pre_gan: energy_loss.train_iter += 1 train_loss = energy_loss(graph, discriminator=discriminator, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) for i in range(5 if epochs <= pre_gan else 1): train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train]) discriminator.step(train_loss2) train.step() graph.eval() discriminator.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, discriminator=discriminator, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step()
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import *
import transforms
from encoding import encode_binary
from models import BaseModel, DecodingModel, DataParallelModel
from logger import Logger, VisdomLogger

import IPython

# LOGGING
logger = VisdomLogger("test", server="35.230.67.129", port=8000, env=JOB)
logger.add_hook(lambda x: logger.step(), feature="orig", freq=1)


def sweep(images, targets, model, transform, name, samples=10):

    min_val, max_val = transform.plot_range
    results = []
    for val in tqdm.tqdm(np.linspace(min_val, max_val, samples), ncols=30):
        transformed = transform(images, val)
        predictions = model(transformed).mean(dim=1).cpu().data.numpy()

        mse_loss = np.mean([binary.mse_dist(x, y) for x, y in zip(predictions, targets)])
        binary_loss = np.mean([binary.distance(x, y) for x, y in zip(predictions, targets)])
    targets = [binary.random(n=TARGET_SIZE) for i in range(len(images))]
    torch.save((perturbation.data, images.data, targets), f"{output_path}/{k}.pth")


if __name__ == "__main__":

    model = DataParallelModel(DecodingModel(n=DIST_SIZE, distribution=transforms.training))
    params = itertools.chain(model.module.classifier.parameters(),
                             model.module.features[-1].parameters())
    optimizer = torch.optim.Adam(params, lr=2.5e-3)
    init_data("data/amnesia")

    logger = VisdomLogger("train", server="35.230.67.129", port=8000, env=JOB)
    logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20)
    logger.add_hook(lambda data: logger.plot(data, "train_loss"), feature="loss", freq=50)
    logger.add_hook(lambda data: logger.plot(data, "train_bits"), feature="bits", freq=50)
    logger.add_hook(lambda x: model.save("output/train_test.pth", verbose=True), feature="epoch", freq=100)

    model.save("output/train_test.pth", verbose=True)

    files = glob.glob(f"data/amnesia/*.pth")
    for i, save_file in enumerate(random.choice(files) for i in range(0, 2701)):
def main( loss_config="conservative_full", mode="standard", visualize=False, pretrained=True, finetuned=False, fast=False, batch_size=None, ood_batch_size=None, subset_size=64, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) ood_batch_size = ood_batch_size or batch_size energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( [tasks.rgb, tasks.normal, tasks.principal_curvature], return_dataset=True, batch_size=batch_size, train_buildings=["almena"] if fast else None, val_buildings=["almena"] if fast else None, resize=256, ) ood_consistency_dataset, _, _, _ = load_train_val( [ tasks.rgb, ], return_dataset=True, train_buildings=["almena"] if fast else None, val_buildings=["almena"] if fast else None, resize=512, ) train_subset_dataset, _, _, _ = load_train_val( [ tasks.rgb, tasks.normal, ], return_dataset=True, train_buildings=["almena"] if fast else None, val_buildings=["almena"] if fast else None, resize=512, subset_size=subset_size, ) train_step, val_step = train_step // 4, val_step // 4 if fast: train_step, val_step = 20, 20 test_set = load_test([tasks.rgb, tasks.normal, tasks.principal_curvature]) ood_images = load_ood() ood_images_large = load_ood(resize=512, sample=8) ood_consistency_test = load_test([ tasks.rgb, ], resize=512) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) ood_consistency = RealityTask("ood_consistency", ood_consistency_dataset, batch_size=ood_batch_size, shuffle=True) train_subset = RealityTask("train_subset", train_subset_dataset, tasks=[tasks.rgb, tasks.normal], batch_size=ood_batch_size, shuffle=False) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static( "test", test_set, [tasks.rgb, tasks.normal, tasks.principal_curvature]) ood_test = RealityTask.from_static("ood_test", (ood_images, ), [ tasks.rgb, ]) ood_test_large = RealityTask.from_static("ood_test_large", (ood_images_large, ), [ tasks.rgb, ]) ood_consistency_test = RealityTask.from_static("ood_consistency_test", ood_consistency_test, [ tasks.rgb, ]) realities = [ train, val, train_subset, ood_consistency, test, ood_test, ood_test_large, ood_consistency_test ] energy_loss.load_realities(realities) # GRAPH graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) # graph.load_weights(f"{MODELS_DIR}/conservative/conservative.pth") graph.load_weights( f"{SHARED_DIR}/results_2FF_train_subset_512_true_baseline_3/graph_baseline.pth" ) print(graph) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook( lambda _, __: graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth"), feature="epoch", freq=1, ) graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth") energy_loss.logger_hooks(logger) # TRAINING for epochs in range(0, 800): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, prefix="start" if epochs == 0 else "") if visualize: return graph.train() for _ in range(0, train_step): train.step() train_loss = energy_loss(graph, reality=train) graph.step(train_loss) logger.update("loss", train_loss) train_subset.step() train_subset_loss = energy_loss(graph, reality=train_subset) graph.step(train_subset_loss) ood_consistency.step() ood_consistency_loss = energy_loss(graph, reality=ood_consistency) if ood_consistency_loss is not None: 
graph.step(ood_consistency_loss) graph.eval() for _ in range(0, val_step): val.step() with torch.no_grad(): val_loss = energy_loss(graph, reality=val) logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step()
def main( loss_config="conservative_full", mode="standard", visualize=False, fast=False, batch_size=None, subset_size=None, max_epochs=800, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) print (kwargs) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size ) train_step, val_step = 24, 12 test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood([tasks.rgb,]) print (train_step, val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,]) # GRAPH realities = [train, val, test, ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.compile(torch.optim.Adam, lr=1e-6, weight_decay=2e-6, amsgrad=True) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) best_ood_val_loss = float('inf') # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() # logger.update("loss", val_loss) graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train]) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() # logger.update("loss", train_loss) energy_loss.logger_update(logger) for param_group in graph.optimizer.param_groups: param_group['lr'] *= 1.2 print ("LR: ", param_group['lr']) logger.step()
def main( loss_config="conservative_full", mode="standard", visualize=False, fast=False, batch_size=None, max_epochs=800, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, ) train_step, val_step = train_step // 4, val_step // 4 test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) print("Train step: ", train_step, "Val step: ", val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) # GRAPH realities = [train, val, test, ood] graph = TaskGraph( tasks=energy_loss.tasks + realities, pretrained=True, finetuned=True, freeze_list=[functional_transfers.a, functional_transfers.RC], ) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) activated_triangles = set() triangle_energy = { "triangle1_mse": float('inf'), "triangle2_mse": float('inf') } logger.add_hook(partial(jointplot, loss_type=f"energy"), feature=f"val_energy", freq=1) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.train() for _ in range(0, train_step): # loss_type = random.choice(["triangle1_mse", "triangle2_mse"]) loss_type = max(triangle_energy, key=triangle_energy.get) activated_triangles.add(loss_type) train_loss = energy_loss(graph, realities=[train], loss_types=[loss_type]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() if loss_type == "triangle1_mse": consistency_tr1 = energy_loss.metrics["train"][ "triangle1_mse : F(RC(x)) -> n(x)"][-1] error_tr1 = energy_loss.metrics["train"][ "triangle1_mse : n(x) -> y^"][-1] triangle_energy["triangle1_mse"] = float(consistency_tr1 / error_tr1) elif loss_type == "triangle2_mse": consistency_tr2 = energy_loss.metrics["train"][ "triangle2_mse : S(a(x)) -> n(x)"][-1] error_tr2 = energy_loss.metrics["train"][ "triangle2_mse : n(x) -> y^"][-1] triangle_energy["triangle2_mse"] = float(consistency_tr2 / error_tr2) print("Triangle energy: ", triangle_energy) logger.update("loss", train_loss) energy = sum(triangle_energy.values()) if (energy < float('inf')): logger.update("train_energy", energy) logger.update("val_energy", energy) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val], loss_types=list(activated_triangles)) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) activated_triangles = set() energy_loss.logger_update(logger) logger.step()
def main( loss_config="geonet", mode="geonet", visualize=False, fast=False, batch_size=None, subset_size=None, early_stopping=float('inf'), max_epochs=800, **kwargs, ): print(kwargs) # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) print (train_step, val_step) # GRAPH print(energy_loss.tasks) print('train tasks', energy_loss.get_tasks("train")) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) print('val tasks', energy_loss.get_tasks("val")) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) print('test tasks', energy_loss.get_tasks("test")) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) print('ood tasks', energy_loss.get_tasks("ood")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) print('done') # GRAPH realities = [train, val, test, ood] # graph = GeoNetTaskGraph(tasks=energy_loss.tasks, realities=realities, pretrained=False) graph = GeoNetTaskGraph(tasks=energy_loss.tasks, realities=realities, pretrained=True) # n(x)/norm(n(x)) # (f(n(x)) / RC(x)) #graph.compile(torch.optim.Adam, grad_clip=2.0, lr=1e-5, weight_decay=0e-6, amsgrad=True) graph.compile(torch.optim.Adam, grad_clip=5.0, lr=4e-5, weight_decay=2e-6, amsgrad=True) #graph.compile(torch.optim.Adam, grad_clip=5.0, lr=1e-6, weight_decay=2e-6, amsgrad=True) #graph.compile(torch.optim.Adam, grad_clip=5.0, lr=1e-5, weight_decay=2e-6, amsgrad=True) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) best_val_loss, stop_idx = float('inf'), 0 # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) try: energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") except: pass if visualize: return graph.train() print('training for', train_step, 'steps') for _ in range(0, train_step): try: train_loss = energy_loss(graph, realities=[train]) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) except NotImplementedError: pass graph.eval() for _ in range(0, val_step): try: with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) except NotImplementedError: pass energy_loss.logger_update(logger) logger.step() stop_idx += 1 try: curr_val_loss = (logger.data["val_mse : N(rgb) -> normal"][-1] + logger.data["val_mse : D(rgb) -> depth"][-1]) if curr_val_loss < best_val_loss: print ("Better val loss, reset stop_idx: ", stop_idx) best_val_loss, stop_idx = curr_val_loss, 0 energy_loss.plot_paths(graph, logger, realities, prefix="best") graph.save(f"{RESULTS_DIR}/graph.pth") except NotImplementedError: pass if stop_idx >= early_stopping: print ("Stopping training now") return
def main( loss_config="multiperceptual", mode="standard", visualize=False, fast=False, batch_size=None, resume=False, learning_rate=3e-5, subset_size=None, max_epochs=2000, dataaug=False, **kwargs, ): # CONFIG wandb.config.update({ "loss_config": loss_config, "batch_size": batch_size, "data_aug": dataaug, "lr": learning_rate }) batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [ tasks.rgb, ]) # GRAPH realities = [train, val, test, ood] graph = TaskGraph( tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if resume: graph.load_weights('/workspace/shared/results_test_1/graph.pth') graph.optimizer.load_state_dict( torch.load('/workspace/shared/results_test_1/opt.pth')) # LOGGING logger = VisdomLogger("train", env=JOB) # fake visdom logger logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) ######## baseline computation if not resume: graph.eval() with torch.no_grad(): for _ in range(0, val_step): val_loss, _, _ = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) for _ in range(0, train_step): train_loss, _, _ = energy_loss(graph, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) data = logger.step() del data['loss'] data = {k: v[0] for k, v in data.items()} wandb.log(data, step=0) path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0) ########### # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss, _, _ = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) graph.train() for _ in range(0, train_step): train_loss, coeffs, avg_grads = energy_loss( graph, realities=[train], compute_grad_ratio=True) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) data = logger.step() del data['loss'] del data['epoch'] data = {k: v[0] for k, v in data.items()} wandb.log(data, step=epochs) wandb.log(coeffs, step=epochs) wandb.log(avg_grads, step=epochs) if epochs % 5 == 0: graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth") if epochs % 10 == 0: path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: 
[wandb.Image(reality_images)]}, step=epochs + 1) graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
def main( loss_config="conservative_full", mode="standard", visualize=False, pretrained=True, finetuned=False, fast=False, batch_size=None, ood_batch_size=None, subset_size=None, cont=f"{MODELS_DIR}/conservative/conservative.pth", max_epochs=800, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train_step, val_step = 4, 4 print(train_step, val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) # GRAPH realities = [train, val, test, ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, freeze_list=energy_loss.freeze_list, finetuned=finetuned) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if not USE_RAID: graph.load_weights(cont) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) best_ood_val_loss = float('inf') # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix=f"epoch_{epochs}") if visualize: return graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.select_losses(val) if epochs != 0: energy_loss.logger_update(logger) else: energy_loss.metrics = {} logger.step() logger.text(f"Chosen losses: {energy_loss.chosen_losses}") logger.text(f"Percep winrate: {energy_loss.percep_winrate}") graph.train() for _ in range(0, train_step): train_loss2 = energy_loss(graph, realities=[train]) train_loss = sum(train_loss2.values()) graph.step(train_loss) train.step() logger.update("loss", train_loss)
import torch.optim as optim
from torch.autograd import Variable

from models import DecodingModel, DataParallelModel
from torchvision import models
from logger import Logger, VisdomLogger
from utils import *
import IPython
import transforms

# LOGGING
logger = VisdomLogger("encoding", server="35.230.67.129", port=8000, env=JOB)
logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20)
logger.add_hook(lambda x: logger.plot(x, "Encoding Loss", opts=dict(ymin=0)),
                feature="loss", freq=50)

"""Computes the changed images, given a specified perturbation,
standard deviation weighting, and epsilon."""

def compute_changed_images(images, perturbation, std_weights, epsilon=EPSILON):
    perturbation_w2 = perturbation * std_weights
    perturbation_zc = (
        perturbation_w2 / perturbation_w2.view(perturbation_w2.shape[0], -1)
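# Illustrative aside: compute_changed_images above is cut off mid-expression in
# this dump. Purely as a hypothetical sketch of the usual pattern (not a
# reconstruction of the original): weight the perturbation, normalize it per
# image, scale it to an epsilon budget, and add it to the batch. EPSILON here is
# an illustrative value, not the project's constant.
import torch

EPSILON = 8.0 / 255.0

def scale_perturbation(images, perturbation, std_weights, epsilon=EPSILON):
    weighted = perturbation * std_weights
    norms = weighted.view(weighted.shape[0], -1).norm(dim=1).clamp(min=1e-12)
    scaled = epsilon * weighted / norms.view(-1, 1, 1, 1)
    return (images + scaled).clamp(0.0, 1.0)

if __name__ == "__main__":
    imgs = torch.rand(2, 3, 64, 64)
    delta = torch.randn_like(imgs)
    weights = torch.ones_like(imgs)
    print(scale_perturbation(imgs, delta, weights).shape)  # torch.Size([2, 3, 64, 64])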
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]
        res = []
        # random.randint needs integer bounds; 1e10 is a float and fails on newer Pythons
        seed = random.randint(0, int(1e10))
        for task in self.tasks:
            image = task.file_loader(file, seed=seed)
            if image.shape[0] == 1:
                image = image.expand(3, -1, -1)
            res.append(image)
        return tuple(res)


if __name__ == "__main__":
    logger = VisdomLogger("data", env=JOB)
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal, tasks.principal_curvature, tasks.rgb(size=512)],
        batch_size=32,
    )
    print("created dataset")
    logger.add_hook(lambda logger, data: logger.step(), freq=32)
    for i, _ in enumerate(train_dataset):
        logger.update("epoch", i)
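# Illustrative aside: because __getitem__ above returns one tensor per task as a
# tuple, the default collate function batches it into a tuple of stacked tensors,
# so a plain DataLoader is enough. Sketch with a stand-in dataset (not the
# project's TaskDataset):
import torch
from torch.utils.data import Dataset, DataLoader

class FakeTaskDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        rgb = torch.rand(3, 256, 256)      # stands in for task.file_loader(file, seed=seed)
        normal = torch.rand(3, 256, 256)
        return rgb, normal                 # one entry per task, as in the loader above

if __name__ == "__main__":
    loader = DataLoader(FakeTaskDataset(), batch_size=4, shuffle=True, num_workers=0)
    rgb_batch, normal_batch = next(iter(loader))
    print(rgb_batch.shape, normal_batch.shape)  # torch.Size([4, 3, 256, 256]) twice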
def main( fast=False, subset_size=None, early_stopping=float('inf'), mode='standard', max_epochs=800, **kwargs, ): early_stopping = 8 loss_config_percepnet = { "paths": { "y": [tasks.normal], "z^": [tasks.principal_curvature], "f(y)": [tasks.normal, tasks.principal_curvature], }, "losses": { "mse": { ("train", "val"): [ ("f(y)", "z^"), ], }, }, "plots": { "ID": dict(size=256, realities=("test", "ood"), paths=[ "y", "z^", "f(y)", ]), }, } # CONFIG batch_size = 64 energy_loss = EnergyLoss(**loss_config_percepnet) task_list = [tasks.rgb, tasks.normal, tasks.principal_curvature] # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( task_list, batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(task_list) ood_set = load_ood(task_list) print(train_step, val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, task_list) ood = RealityTask.from_static("ood", ood_set, task_list) # GRAPH realities = [train, val, test, ood] graph = TaskGraph( tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities, pretrained=False, freeze_list=[functional_transfers.n], ) graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) best_val_loss, stop_idx = float('inf'), 0 # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step() stop_idx += 1 if logger.data["val_mse : f(y) -> z^"][-1] < best_val_loss: print("Better val loss, reset stop_idx: ", stop_idx) best_val_loss, stop_idx = logger.data["val_mse : f(y) -> z^"][ -1], 0 energy_loss.plot_paths(graph, logger, realities, prefix="best") graph.save(weights_dir=f"{RESULTS_DIR}") if stop_idx >= early_stopping: print("Stopping training now") break early_stopping = 50 # CONFIG energy_loss = get_energy_loss(config="perceptual", mode=mode, **kwargs) # GRAPH realities = [train, val, test, ood] graph = TaskGraph( tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities, pretrained=False, freeze_list=[functional_transfers.f], ) graph.edge( tasks.normal, tasks.principal_curvature).model.load_weights(f"{RESULTS_DIR}/f.pth") graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True) # LOGGING logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) best_val_loss, stop_idx = float('inf'), 0 # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in 
train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step() stop_idx += 1 if logger.data["val_mse : n(x) -> y^"][-1] < best_val_loss: print("Better val loss, reset stop_idx: ", stop_idx) best_val_loss, stop_idx = logger.data["val_mse : n(x) -> y^"][ -1], 0 energy_loss.plot_paths(graph, logger, realities, prefix="best") graph.save(f"{RESULTS_DIR}/graph.pth") if stop_idx >= early_stopping: print("Stopping training now") break
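# Illustrative aside: both training stages above implement patience-based early
# stopping by hand -- reset a counter whenever the tracked validation metric
# improves, stop once it has gone `early_stopping` epochs without improving.
# Equivalent logic in isolation (names illustrative):
def should_stop(history, patience):
    """history: per-epoch validation losses, lower is better."""
    if len(history) <= patience:
        return False
    best_before = min(history[:-patience])
    return min(history[-patience:]) >= best_before

if __name__ == "__main__":
    print(should_stop([1.0, 0.9, 0.91, 0.92, 0.93], patience=3))  # True: no recent improvement
    print(should_stop([1.0, 0.9, 0.91, 0.89, 0.93], patience=3))  # False: 0.89 beat the best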
def main( loss_config="conservative_full", mode="standard", visualize=False, pretrained=True, finetuned=False, fast=False, batch_size=None, ood_batch_size=None, subset_size=None, cont=f"{MODELS_DIR}/conservative/conservative.pth", cont_gan=None, pre_gan=None, max_epochs=800, use_baseline=False, use_patches=False, patch_frac=None, patch_size=64, patch_sigma=0, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, ) train_subset_dataset, _, _, _ = load_train_val( energy_loss.get_tasks("train_subset"), batch_size=batch_size, fast=fast, subset_size=subset_size) train_step, val_step = train_step // 16, val_step // 16 test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) train_subset = RealityTask("train_subset", train_subset_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) # GRAPH realities = [train, train_subset, val, test, ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if not USE_RAID and not use_baseline: graph.load_weights(cont) pre_gan = pre_gan or 1 discriminator = Discriminator(energy_loss.losses['gan'], frac=patch_frac, size=(patch_size if use_patches else 224), sigma=patch_sigma, use_patches=use_patches) if cont_gan is not None: discriminator.load_weights(cont_gan) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) logger.add_hook( lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) best_ood_val_loss = float('inf') logger.add_hook(partial(jointplot, loss_type=f"gan_subset"), feature=f"val_gan_subset", freq=1) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.train() discriminator.train() for _ in range(0, train_step): if epochs > pre_gan: train_loss = energy_loss(graph, discriminator=discriminator, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) # train_loss1 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse']) # train_loss1 = sum([train_loss1[loss_name] for loss_name in train_loss1]) # train.step() # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['gan']) # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2]) # train.step() # graph.step(train_loss1 + train_loss2) # logger.update("loss", train_loss1 + train_loss2) # train_loss1 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse_id']) # train_loss1 = sum([train_loss1[loss_name] for loss_name in train_loss1]) # 
graph.step(train_loss1) # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse_ood']) # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2]) # graph.step(train_loss2) # train_loss3 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['gan']) # train_loss3 = sum([train_loss3[loss_name] for loss_name in train_loss3]) # graph.step(train_loss3) # logger.update("loss", train_loss1 + train_loss2 + train_loss3) # train.step() # graph fooling loss # n(~x), and y^ (128 subset) # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train]) # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2]) # train_loss = train_loss1 + train_loss2 warmup = 5 # if epochs < pre_gan else 1 for i in range(warmup): # y_hat = graph.sample_path([tasks.normal(size=512)], reality=train_subset) # n_x = graph.sample_path([tasks.rgb(size=512), tasks.normal(size=512)], reality=train) y_hat = graph.sample_path([tasks.normal], reality=train_subset) n_x = graph.sample_path( [tasks.rgb(blur_radius=6), tasks.normal(blur_radius=6)], reality=train) def coeff_hook(coeff): def fun1(grad): return coeff * grad.clone() return fun1 logit_path1 = discriminator(y_hat.detach()) coeff = 0.1 path_value2 = n_x * 1.0 path_value2.register_hook(coeff_hook(coeff)) logit_path2 = discriminator(path_value2) binary_label = torch.Tensor( [1] * logit_path1.size(0) + [0] * logit_path2.size(0)).float().cuda() gan_loss = nn.BCEWithLogitsLoss(size_average=True)(torch.cat( (logit_path1, logit_path2), dim=0).view(-1), binary_label) discriminator.discriminator.step(gan_loss) logger.update("train_gan_subset", gan_loss) logger.update("val_gan_subset", gan_loss) # print ("Gan loss: ", (-gan_loss).data.cpu().numpy()) train.step() train_subset.step() graph.eval() discriminator.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, discriminator=discriminator, realities=[val, train_subset]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) if epochs > pre_gan: energy_loss.logger_update(logger) logger.step() if logger.data["train_subset_val_ood : y^ -> n(~x)"][ -1] < best_ood_val_loss: best_ood_val_loss = logger.data[ "train_subset_val_ood : y^ -> n(~x)"][-1] energy_loss.plot_paths(graph, logger, realities, prefix="best")
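# Illustrative aside: stand-alone demonstration of the coeff_hook trick used in the
# GAN block above. Tensor.register_hook rescales the gradient flowing back through
# an intermediate tensor, so the discriminator's fooling signal reaches the
# upstream network with a reduced weight (coeff = 0.1 above) without changing the
# forward values.
import torch

def coeff_hook(coeff):
    def scale_grad(grad):
        return coeff * grad.clone()
    return scale_grad

if __name__ == "__main__":
    x = torch.ones(3, requires_grad=True)
    y = x * 1.0                       # identity op so there is an intermediate tensor to hook
    y.register_hook(coeff_hook(0.1))  # gradients past this point are multiplied by 0.1
    y.sum().backward()
    print(x.grad)                     # tensor([0.1000, 0.1000, 0.1000])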
def main( loss_config="multiperceptual", mode="winrate", visualize=False, fast=False, batch_size=None, subset_size=None, max_epochs=800, dataaug=False, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size, dataaug=dataaug, ) if fast: train_dataset = val_dataset train_step, val_step = 2,2 train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) if fast: train_dataset = val_dataset train_step, val_step = 2,2 realities = [train, val] else: test_set = load_test(energy_loss.get_tasks("test"), buildings=['almena', 'albertville']) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) realities = [train, val, test] # If you wanted to just do some qualitative predictions on inputs w/o labels, you could do: # ood_set = load_ood(energy_loss.get_tasks("ood")) # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,]) # realities.append(ood) # GRAPH graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, initialize_from_transfer=False, ) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) # LOGGING os.makedirs(RESULTS_DIR, exist_ok=True) logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) energy_loss.plot_paths(graph, logger, realities, prefix="start") # BASELINE graph.eval() with torch.no_grad(): for _ in range(0, val_step*4): val_loss, _ = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) for _ in range(0, train_step*4): train_loss, _ = energy_loss(graph, realities=[train]) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="") if visualize: return graph.train() for _ in range(0, train_step): train_loss, mse_coeff = energy_loss(graph, realities=[train], compute_grad_ratio=True) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss, _ = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step()
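# Illustrative aside: a hedged sketch of one way a compute_grad_ratio-style
# coefficient could be derived (this is NOT claimed to be the project's
# implementation): measure the gradient norm each loss term induces on shared
# parameters and rescale the terms so no single one dominates the update.
# nn.Linear and the two losses are stand-ins.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
x = torch.randn(8, 4)
losses = {
    "direct_mse": ((model(x) - 1.0) ** 2).mean(),
    "percep_l1": model(x).abs().mean(),
}

grad_norms = {}
for name, loss in losses.items():
    grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True)
    grad_norms[name] = torch.sqrt(sum((g ** 2).sum() for g in grads))

mean_norm = sum(grad_norms.values()) / len(grad_norms)
coeffs = {name: float(mean_norm / norm.clamp(min=1e-12)) for name, norm in grad_norms.items()}

weighted = sum(coeffs[name] * losses[name] for name in losses)
weighted.backward()
print(coeffs)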
def main( loss_config="conservative_full", mode="standard", visualize=False, fast=False, batch_size=None, use_optimizer=False, subset_size=None, max_epochs=800, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) print(kwargs) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size) train_step, val_step = train_step // 4, val_step // 4 test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood([ tasks.rgb, ]) print(train_step, val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [ tasks.rgb, ]) # GRAPH realities = [train, val, test, ood] graph = TaskGraph( tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.edge(tasks.rgb, tasks.normal).model = None graph.edge( tasks.rgb, tasks.normal ).path = "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/n.pth" graph.edge(tasks.rgb, tasks.normal).load_model() graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if use_optimizer: optimizer = torch.load( "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/optimizer.pth" ) graph.optimizer.load_state_dict(optimizer.state_dict()) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) best_ood_val_loss = float('inf') # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) # if logger.data["val_mse : y^ -> n(~x)"][-1] < best_ood_val_loss: # best_ood_val_loss = logger.data["val_mse : y^ -> n(~x)"][-1] # energy_loss.plot_paths(graph, logger, realities, prefix="best") logger.step()
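# Illustrative aside: what the edge surgery above (replacing the rgb -> normal
# edge's model and loading its weights from a .pth file) amounts to in plain
# PyTorch -- restore a checkpoint into one submodule while leaving the rest of the
# model untouched. TinyGraph and the /tmp path are stand-ins, not project code.
import torch
import torch.nn as nn

class TinyGraph(nn.Module):
    def __init__(self):
        super().__init__()
        self.rgb_to_normal = nn.Linear(4, 4)   # stand-in for the rgb -> normal edge
        self.rgb_to_depth = nn.Linear(4, 4)    # other edges keep their current weights

if __name__ == "__main__":
    pretrained = TinyGraph()
    torch.save(pretrained.rgb_to_normal.state_dict(), "/tmp/n.pth")

    graph = TinyGraph()
    graph.rgb_to_normal.load_state_dict(torch.load("/tmp/n.pth"))  # only this edge changes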
def main( loss_config="trainsig_edgereshade", mode="standard", visualize=False, fast=False, batch_size=32, learning_rate=3e-5, resume=False, subset_size=None, max_epochs=800, dataaug=False, **kwargs, ): # CONFIG wandb.config.update({"loss_config":loss_config,"batch_size":batch_size,"lr":learning_rate}) batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_undist_dataset, train_dist_dataset, val_ooddist_dataset, val_dist_dataset, val_dataset, train_step, val_step = load_train_val_sig( energy_loss.get_tasks("val"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood"), ood_path='./assets/ood_natural/') ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"), ood_path='./assets/st_syn_distortions/') ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"), ood_path='./assets/ood_syn_distortions/', sample=35) train_undist = RealityTask("train_undist", train_undist_dataset, batch_size=batch_size, shuffle=True) train_dist = RealityTask("train_dist", train_dist_dataset, batch_size=batch_size, shuffle=True) val_ooddist = RealityTask("val_ooddist", val_ooddist_dataset, batch_size=batch_size, shuffle=True) val_dist = RealityTask("val_dist", val_dist_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,]) ## standard ood set - natural ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb,]) ## synthetic distortion images used for sig training ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb,]) ## unseen syn distortions # GRAPH realities = [train_undist, train_dist, val_ooddist, val_dist, val, test, ood, ood_syn_aug, ood_syn] graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if resume: graph.load_weights('/workspace/shared/results_test_1/graph.pth') graph.optimizer.load_state_dict(torch.load('/workspace/shared/results_test_1/opt.pth')) # else: # folder_name='/workspace/shared/results_wavelet2normal_depthreshadecurvimgnetl1perceps_0.1nll/' # # pdb.set_trace() # in_domain='wav' # out_domain='normal' # graph.load_weights(folder_name+'graph.pth', [str((in_domain, out_domain))]) # create_t0_graph(folder_name,in_domain,out_domain) # graph.load_weights(folder_name+'graph_t0.pth', [str((in_domain, f'{out_domain}_t0'))]) # LOGGING logger = VisdomLogger("train", env=JOB) # fake visdom logger logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) # BASELINE if not resume: graph.eval() with torch.no_grad(): for reality in [val_ooddist,val_dist,val]: for _ in range(0, val_step): val_loss = energy_loss(graph, realities=[reality]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) reality.step() logger.update("loss", val_loss) for reality in [train_undist,train_dist]: for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[reality]) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) reality.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) data=logger.step() del data['loss'] data = {k:v[0] for k,v 
in data.items()} wandb.log(data, step=0) path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) graph.train() for _ in range(0, train_step): train_loss_nll = energy_loss(graph, realities=[train_undist]) train_loss_nll = sum([train_loss_nll[loss_name] for loss_name in train_loss_nll]) train_loss_lwfsig = energy_loss(graph, realities=[train_dist]) train_loss_lwfsig = sum([train_loss_lwfsig[loss_name] for loss_name in train_loss_lwfsig]) train_loss = train_loss_nll+train_loss_lwfsig graph.step(train_loss) train_undist.step() train_dist.step() logger.update("loss", train_loss) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val_dist]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val_dist.step() logger.update("loss", val_loss) if epochs % 20 == 0: for reality in [val,val_ooddist]: for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[reality]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) reality.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) data=logger.step() del data['loss'] data = {k:v[0] for k,v in data.items()} wandb.log(data, step=epochs+1) if epochs % 10 == 0: graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth") if (epochs % 100 == 0) or (epochs % 15 == 0 and epochs <= 30): path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs+1) graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth")
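# Illustrative aside: stand-alone version of the wandb image-logging pattern these
# scripts use -- wrap an array (or PIL image) in wandb.Image and log it under a
# key, keeping `step` consistent with the scalar metrics for the same epoch. The
# project name is illustrative, and mode="offline" just avoids needing an account
# for the demo.
import numpy as np
import wandb

if __name__ == "__main__":
    wandb.init(project="consistency-demo", mode="offline")
    panel = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
    wandb.log({"val_loss": 0.123}, step=0)
    wandb.log({"normal_paths": [wandb.Image(panel, caption="epoch 0")]}, step=0)
    wandb.finish()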