def main( loss_config="conservative_full", mode="standard", visualize=False, fast=False, batch_size=None, use_optimizer=False, subset_size=None, max_epochs=800, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) print(kwargs) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size) train_step, val_step = train_step // 4, val_step // 4 test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood([ tasks.rgb, ]) print(train_step, val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [ tasks.rgb, ]) # GRAPH realities = [train, val, test, ood] graph = TaskGraph( tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.edge(tasks.rgb, tasks.normal).model = None graph.edge( tasks.rgb, tasks.normal ).path = "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/n.pth" graph.edge(tasks.rgb, tasks.normal).load_model() graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if use_optimizer: optimizer = torch.load( "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/optimizer.pth" ) graph.optimizer.load_state_dict(optimizer.state_dict()) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) best_ood_val_loss = float('inf') # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) # if logger.data["val_mse : y^ -> n(~x)"][-1] < best_ood_val_loss: # best_ood_val_loss = logger.data["val_mse : y^ -> n(~x)"][-1] # energy_loss.plot_paths(graph, logger, realities, prefix="best") logger.step()
def main( loss_config="conservative_full", mode="standard", visualize=False, pretrained=True, finetuned=False, fast=False, batch_size=None, ood_batch_size=None, subset_size=None, cont=f"{MODELS_DIR}/conservative/conservative.pth", cont_gan=None, pre_gan=None, max_epochs=800, use_baseline=False, use_patches=False, patch_frac=None, patch_size=64, patch_sigma=0, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, ) train_subset_dataset, _, _, _ = load_train_val( energy_loss.get_tasks("train_subset"), batch_size=batch_size, fast=fast, subset_size=subset_size) train_step, val_step = train_step // 16, val_step // 16 test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) train_subset = RealityTask("train_subset", train_subset_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) # GRAPH realities = [train, train_subset, val, test, ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if not USE_RAID and not use_baseline: graph.load_weights(cont) pre_gan = pre_gan or 1 discriminator = Discriminator(energy_loss.losses['gan'], frac=patch_frac, size=(patch_size if use_patches else 224), sigma=patch_sigma, use_patches=use_patches) if cont_gan is not None: discriminator.load_weights(cont_gan) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) logger.add_hook( lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) best_ood_val_loss = float('inf') logger.add_hook(partial(jointplot, loss_type=f"gan_subset"), feature=f"val_gan_subset", freq=1) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.train() discriminator.train() for _ in range(0, train_step): if epochs > pre_gan: train_loss = energy_loss(graph, discriminator=discriminator, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) # train_loss1 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse']) # train_loss1 = sum([train_loss1[loss_name] for loss_name in train_loss1]) # train.step() # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['gan']) # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2]) # train.step() # graph.step(train_loss1 + train_loss2) # logger.update("loss", train_loss1 + train_loss2) # train_loss1 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse_id']) # train_loss1 = sum([train_loss1[loss_name] for loss_name in train_loss1]) # 
graph.step(train_loss1) # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['mse_ood']) # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2]) # graph.step(train_loss2) # train_loss3 = energy_loss(graph, discriminator=discriminator, realities=[train], loss_types=['gan']) # train_loss3 = sum([train_loss3[loss_name] for loss_name in train_loss3]) # graph.step(train_loss3) # logger.update("loss", train_loss1 + train_loss2 + train_loss3) # train.step() # graph fooling loss # n(~x), and y^ (128 subset) # train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train]) # train_loss2 = sum([train_loss2[loss_name] for loss_name in train_loss2]) # train_loss = train_loss1 + train_loss2 warmup = 5 # if epochs < pre_gan else 1 for i in range(warmup): # y_hat = graph.sample_path([tasks.normal(size=512)], reality=train_subset) # n_x = graph.sample_path([tasks.rgb(size=512), tasks.normal(size=512)], reality=train) y_hat = graph.sample_path([tasks.normal], reality=train_subset) n_x = graph.sample_path( [tasks.rgb(blur_radius=6), tasks.normal(blur_radius=6)], reality=train) def coeff_hook(coeff): def fun1(grad): return coeff * grad.clone() return fun1 logit_path1 = discriminator(y_hat.detach()) coeff = 0.1 path_value2 = n_x * 1.0 path_value2.register_hook(coeff_hook(coeff)) logit_path2 = discriminator(path_value2) binary_label = torch.Tensor( [1] * logit_path1.size(0) + [0] * logit_path2.size(0)).float().cuda() gan_loss = nn.BCEWithLogitsLoss(size_average=True)(torch.cat( (logit_path1, logit_path2), dim=0).view(-1), binary_label) discriminator.discriminator.step(gan_loss) logger.update("train_gan_subset", gan_loss) logger.update("val_gan_subset", gan_loss) # print ("Gan loss: ", (-gan_loss).data.cpu().numpy()) train.step() train_subset.step() graph.eval() discriminator.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, discriminator=discriminator, realities=[val, train_subset]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) if epochs > pre_gan: energy_loss.logger_update(logger) logger.step() if logger.data["train_subset_val_ood : y^ -> n(~x)"][ -1] < best_ood_val_loss: best_ood_val_loss = logger.data[ "train_subset_val_ood : y^ -> n(~x)"][-1] energy_loss.plot_paths(graph, logger, realities, prefix="best")
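# Aside: a minimal, self-contained sketch of the gradient-scaling trick that
# `coeff_hook` implements above. `register_hook` leaves the forward value
# untouched but rescales the gradient flowing back through the hooked tensor,
# so the discriminator trains on full-strength activations while the upstream
# generator path receives only a damped adversarial signal.
# (Illustrative only; the tensors below are hypothetical stand-ins.)

def _demo_grad_scaling():
    import torch

    def scale_grad(coeff):
        def hook(grad):
            return coeff * grad.clone()
        return hook

    x = torch.randn(4, requires_grad=True)
    y = x * 1.0                       # non-leaf copy, same role as `path_value2 = n_x * 1.0`
    y.register_hook(scale_grad(0.1))  # backward pass sees 0.1 * grad
    y.sum().backward()
    print(x.grad)                     # 0.1 * ones(4): forward unchanged, gradient scaled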
def main( loss_config="baseline_normal", mode="standard", visualize=False, fast=False, batch_size=None, learning_rate=5e-4, subset_size=None, max_epochs=5000, dataaug=False, **kwargs, ): # CONFIG wandb.config.update({ "loss_config": loss_config, "batch_size": batch_size, "data_aug": dataaug, "lr": learning_rate }) batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [ tasks.rgb, ]) # GRAPH realities = [train, val, test, ood] graph = TaskGraph( tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.compile(torch.optim.Adam, lr=learning_rate, weight_decay=2e-6, amsgrad=True) # LOGGING logger = VisdomLogger("train", env=JOB) # fake visdom logger logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) if (epochs % 100 == 0) or (epochs % 10 == 0 and epochs < 30): path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs) graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) data = logger.step() del data['loss'] del data['epoch'] data = {k: v[0] for k, v in data.items()} wandb.log(data, step=epochs) # save model and opt state every 10 epochs if epochs % 10 == 0: graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth") # lower lr after 1500 epochs if epochs == 1500: graph.optimizer.param_groups[0]['lr'] = 3e-5 graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
def main( loss_config="conservative_full", mode="standard", visualize=False, fast=False, batch_size=None, subset_size=None, max_epochs=800, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) print (kwargs) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size ) train_step, val_step = 24, 12 test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood([tasks.rgb,]) print (train_step, val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,]) # GRAPH realities = [train, val, test, ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.compile(torch.optim.Adam, lr=1e-6, weight_decay=2e-6, amsgrad=True) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) best_ood_val_loss = float('inf') # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() # logger.update("loss", val_loss) graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train]) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() # logger.update("loss", train_loss) energy_loss.logger_update(logger) for param_group in graph.optimizer.param_groups: param_group['lr'] *= 1.2 print ("LR: ", param_group['lr']) logger.step()
def main( loss_config="conservative_full", mode="standard", visualize=False, pretrained=True, finetuned=False, fast=False, batch_size=None, ood_batch_size=None, subset_size=None, cont=f"{MODELS_DIR}/conservative/conservative.pth", max_epochs=800, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train_step, val_step = 4, 4 print(train_step, val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) # GRAPH realities = [train, val, test, ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, freeze_list=energy_loss.freeze_list, finetuned=finetuned) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if not USE_RAID: graph.load_weights(cont) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) best_ood_val_loss = float('inf') # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix=f"epoch_{epochs}") if visualize: return graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.select_losses(val) if epochs != 0: energy_loss.logger_update(logger) else: energy_loss.metrics = {} logger.step() logger.text(f"Chosen losses: {energy_loss.chosen_losses}") logger.text(f"Percep winrate: {energy_loss.percep_winrate}") graph.train() for _ in range(0, train_step): train_loss2 = energy_loss(graph, realities=[train]) train_loss = sum(train_loss2.values()) graph.step(train_loss) train.step() logger.update("loss", train_loss)
def main( loss_config="conservative_full", mode="standard", visualize=False, pretrained=True, finetuned=False, fast=False, batch_size=None, cont=f"{MODELS_DIR}/conservative/conservative.pth", cont_gan=None, pre_gan=None, use_patches=False, patch_size=64, use_baseline=False, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) # GRAPH realities = [train, val, test, ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned, freeze_list=energy_loss.freeze_list) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if not use_baseline and not USE_RAID: graph.load_weights(cont) pre_gan = pre_gan or 1 discriminator = Discriminator(energy_loss.losses['gan'], size=(patch_size if use_patches else 224), use_patches=use_patches) # if cont_gan is not None: discriminator.load_weights(cont_gan) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) logger.add_hook( lambda _, __: discriminator.save(f"{RESULTS_DIR}/discriminator.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) # TRAINING for epochs in range(0, 80): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.train() discriminator.train() for _ in range(0, train_step): if epochs > pre_gan: energy_loss.train_iter += 1 train_loss = energy_loss(graph, discriminator=discriminator, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) for i in range(5 if epochs <= pre_gan else 1): train_loss2 = energy_loss(graph, discriminator=discriminator, realities=[train]) discriminator.step(train_loss2) train.step() graph.eval() discriminator.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, discriminator=discriminator, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step()
def main( loss_config="multiperceptual", mode="standard", visualize=False, fast=False, batch_size=None, subset_size=None, max_epochs=5000, dataaug=False, **kwargs, ): # CONFIG wandb.config.update({"loss_config":loss_config,"batch_size":batch_size,"data_aug":dataaug,"lr":"3e-5", "n_gauss":1,"distribution":"laplace"}) batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, val_noaug_dataset, train_step, val_step = load_train_val_merging( energy_loss.get_tasks("train_c"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood"), ood_path='./assets/ood_natural/') ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"), ood_path='./assets/st_syn_distortions/') ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"), ood_path='./assets/ood_syn_distortions/', sample=35) train = RealityTask("train_c", train_dataset, batch_size=batch_size, shuffle=True) # distorted and undistorted val = RealityTask("val_c", val_dataset, batch_size=batch_size, shuffle=True) # distorted and undistorted val_noaug = RealityTask("val", val_noaug_dataset, batch_size=batch_size, shuffle=True) # no augmentation test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,]) ## standard ood set - natural ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb,]) ## synthetic distortion images used for sig training ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb,]) ## unseen syn distortions # GRAPH realities = [train, val, val_noaug, test, ood, ood_syn_aug, ood_syn] graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) # LOGGING logger = VisdomLogger("train", env=JOB) # fake visdom logger logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) graph.eval() path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0) with torch.no_grad(): for reality in [val,val_noaug]: for _ in range(0, val_step): val_loss = energy_loss(graph, realities=[reality]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) reality.step() logger.update("loss", val_loss) for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) data=logger.step() del data['loss'] data = {k:v[0] for k,v in data.items()} wandb.log(data, step=0) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train], compute_grad_ratio=True) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) graph.eval() for reality in [val,val_noaug]: for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[reality]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) 
reality.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) data=logger.step() del data['loss'] del data['epoch'] data = {k:v[0] for k,v in data.items()} wandb.log(data, step=epochs+1) if epochs % 10 == 0: graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth") if epochs % 25 == 0: path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs+1) graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth")
def main(
    fast=False, subset_size=None, early_stopping=float('inf'),
    mode='standard', max_epochs=800, **kwargs,
):

    early_stopping = 8  # note: overrides the `early_stopping` argument for phase 1

    loss_config_percepnet = {
        "paths": {
            "y": [tasks.normal],
            "z^": [tasks.principal_curvature],
            "f(y)": [tasks.normal, tasks.principal_curvature],
        },
        "losses": {
            "mse": {
                ("train", "val"): [("f(y)", "z^")],
            },
        },
        "plots": {
            "ID": dict(size=256, realities=("test", "ood"), paths=["y", "z^", "f(y)"]),
        },
    }

    # CONFIG
    batch_size = 64
    energy_loss = EnergyLoss(**loss_config_percepnet)
    task_list = [tasks.rgb, tasks.normal, tasks.principal_curvature]

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list,
        batch_size=batch_size, fast=fast, subset_size=subset_size,
    )
    test_set = load_test(task_list)
    ood_set = load_ood(task_list)
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    ood = RealityTask.from_static("ood", ood_set, task_list)

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities,
        pretrained=False,
        freeze_list=[functional_transfers.n],
    )
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)
    best_val_loss, stop_idx = float('inf'), 0

    # TRAINING (phase 1: fit f on top of a frozen n)
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        stop_idx += 1
        if logger.data["val_mse : f(y) -> z^"][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = logger.data["val_mse : f(y) -> z^"][-1], 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(weights_dir=f"{RESULTS_DIR}")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            break

    early_stopping = 50

    # CONFIG (phase 2: train n against the frozen perceptual net f)
    energy_loss = get_energy_loss(config="perceptual", mode=mode, **kwargs)

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=[tasks.rgb, tasks.normal, tasks.principal_curvature] + realities,
        pretrained=False,
        freeze_list=[functional_transfers.f],
    )
    graph.edge(tasks.normal, tasks.principal_curvature).model.load_weights(f"{RESULTS_DIR}/f.pth")
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)
    best_val_loss, stop_idx = float('inf'), 0

    # TRAINING
    for epochs in range(0, max_epochs):
        logger.update("epoch", epochs)
        energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "")

        graph.train()
        for _ in range(0, train_step):
            train_loss = energy_loss(graph, realities=[train])
            train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(0, val_step):
            with torch.no_grad():
                val_loss = energy_loss(graph, realities=[val])
                val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        stop_idx += 1
        if logger.data["val_mse : n(x) -> y^"][-1] < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = logger.data["val_mse : n(x) -> y^"][-1], 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(f"{RESULTS_DIR}/graph.pth")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            break
def main( loss_config="multiperceptual", mode="winrate", visualize=False, fast=False, batch_size=None, subset_size=None, max_epochs=800, dataaug=False, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size, dataaug=dataaug, ) if fast: train_dataset = val_dataset train_step, val_step = 2,2 train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) if fast: train_dataset = val_dataset train_step, val_step = 2,2 realities = [train, val] else: test_set = load_test(energy_loss.get_tasks("test"), buildings=['almena', 'albertville']) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) realities = [train, val, test] # If you wanted to just do some qualitative predictions on inputs w/o labels, you could do: # ood_set = load_ood(energy_loss.get_tasks("ood")) # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,]) # realities.append(ood) # GRAPH graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, initialize_from_transfer=False, ) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) # LOGGING os.makedirs(RESULTS_DIR, exist_ok=True) logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) energy_loss.plot_paths(graph, logger, realities, prefix="start") # BASELINE graph.eval() with torch.no_grad(): for _ in range(0, val_step*4): val_loss, _ = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) for _ in range(0, train_step*4): train_loss, _ = energy_loss(graph, realities=[train]) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="") if visualize: return graph.train() for _ in range(0, train_step): train_loss, mse_coeff = energy_loss(graph, realities=[train], compute_grad_ratio=True) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss, _ = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step()
def main( loss_config="conservative_full", mode="standard", visualize=False, fast=False, batch_size=None, max_epochs=800, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, ) train_step, val_step = train_step // 4, val_step // 4 test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) print("Train step: ", train_step, "Val step: ", val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) # GRAPH realities = [train, val, test, ood] graph = TaskGraph( tasks=energy_loss.tasks + realities, pretrained=True, finetuned=True, freeze_list=[functional_transfers.a, functional_transfers.RC], ) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1) energy_loss.logger_hooks(logger) activated_triangles = set() triangle_energy = { "triangle1_mse": float('inf'), "triangle2_mse": float('inf') } logger.add_hook(partial(jointplot, loss_type=f"energy"), feature=f"val_energy", freq=1) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.train() for _ in range(0, train_step): # loss_type = random.choice(["triangle1_mse", "triangle2_mse"]) loss_type = max(triangle_energy, key=triangle_energy.get) activated_triangles.add(loss_type) train_loss = energy_loss(graph, realities=[train], loss_types=[loss_type]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() if loss_type == "triangle1_mse": consistency_tr1 = energy_loss.metrics["train"][ "triangle1_mse : F(RC(x)) -> n(x)"][-1] error_tr1 = energy_loss.metrics["train"][ "triangle1_mse : n(x) -> y^"][-1] triangle_energy["triangle1_mse"] = float(consistency_tr1 / error_tr1) elif loss_type == "triangle2_mse": consistency_tr2 = energy_loss.metrics["train"][ "triangle2_mse : S(a(x)) -> n(x)"][-1] error_tr2 = energy_loss.metrics["train"][ "triangle2_mse : n(x) -> y^"][-1] triangle_energy["triangle2_mse"] = float(consistency_tr2 / error_tr2) print("Triangle energy: ", triangle_energy) logger.update("loss", train_loss) energy = sum(triangle_energy.values()) if (energy < float('inf')): logger.update("train_energy", energy) logger.update("val_energy", energy) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val], loss_types=list(activated_triangles)) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) activated_triangles = set() energy_loss.logger_update(logger) logger.step()
def main( loss_config="multiperceptual", mode="standard", visualize=False, fast=False, batch_size=None, resume=False, learning_rate=3e-5, subset_size=None, max_epochs=2000, dataaug=False, **kwargs, ): # CONFIG wandb.config.update({ "loss_config": loss_config, "batch_size": batch_size, "data_aug": dataaug, "lr": learning_rate }) batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood")) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [ tasks.rgb, ]) # GRAPH realities = [train, val, test, ood] graph = TaskGraph( tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if resume: graph.load_weights('/workspace/shared/results_test_1/graph.pth') graph.optimizer.load_state_dict( torch.load('/workspace/shared/results_test_1/opt.pth')) # LOGGING logger = VisdomLogger("train", env=JOB) # fake visdom logger logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) ######## baseline computation if not resume: graph.eval() with torch.no_grad(): for _ in range(0, val_step): val_loss, _, _ = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) for _ in range(0, train_step): train_loss, _, _ = energy_loss(graph, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) data = logger.step() del data['loss'] data = {k: v[0] for k, v in data.items()} wandb.log(data, step=0) path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0) ########### # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss, _, _ = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) graph.train() for _ in range(0, train_step): train_loss, coeffs, avg_grads = energy_loss( graph, realities=[train], compute_grad_ratio=True) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) data = logger.step() del data['loss'] del data['epoch'] data = {k: v[0] for k, v in data.items()} wandb.log(data, step=epochs) wandb.log(coeffs, step=epochs) wandb.log(avg_grads, step=epochs) if epochs % 5 == 0: graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth") if epochs % 10 == 0: path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: 
[wandb.Image(reality_images)]}, step=epochs + 1) graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")
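# Aside: `compute_grad_ratio=True` makes EnergyLoss return per-loss
# coefficients (`coeffs`) and average gradient magnitudes (`avg_grads`) along
# with the losses. The real implementation lives inside EnergyLoss and is not
# shown here; the sketch below is a hedged approximation of the idea — weight
# each loss term by the inverse of the mean gradient magnitude it induces on
# shared parameters, so no single term dominates the update. All names below
# are hypothetical.

def _sketch_grad_ratio_coeffs(losses, shared_params):
    """losses: dict of name -> scalar loss; returns (coeffs, avg_grads)."""
    import torch
    avg_grads = {}
    for name, loss in losses.items():
        grads = torch.autograd.grad(loss, shared_params, retain_graph=True, allow_unused=True)
        mags = [g.abs().mean() for g in grads if g is not None]
        avg_grads[name] = torch.stack(mags).mean().item() if mags else 0.0
    # inverse-magnitude coefficients, normalized to sum to 1
    inv = {k: 1.0 / (v + 1e-12) for k, v in avg_grads.items()}
    total = sum(inv.values())
    coeffs = {k: v / total for k, v in inv.items()}
    return coeffs, avg_grads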
def main( loss_config="conservative_full", mode="standard", visualize=False, pretrained=True, finetuned=False, fast=False, batch_size=None, ood_batch_size=None, subset_size=64, **kwargs, ): # CONFIG batch_size = batch_size or (4 if fast else 64) ood_batch_size = ood_batch_size or batch_size energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( [tasks.rgb, tasks.normal, tasks.principal_curvature], return_dataset=True, batch_size=batch_size, train_buildings=["almena"] if fast else None, val_buildings=["almena"] if fast else None, resize=256, ) ood_consistency_dataset, _, _, _ = load_train_val( [ tasks.rgb, ], return_dataset=True, train_buildings=["almena"] if fast else None, val_buildings=["almena"] if fast else None, resize=512, ) train_subset_dataset, _, _, _ = load_train_val( [ tasks.rgb, tasks.normal, ], return_dataset=True, train_buildings=["almena"] if fast else None, val_buildings=["almena"] if fast else None, resize=512, subset_size=subset_size, ) train_step, val_step = train_step // 4, val_step // 4 if fast: train_step, val_step = 20, 20 test_set = load_test([tasks.rgb, tasks.normal, tasks.principal_curvature]) ood_images = load_ood() ood_images_large = load_ood(resize=512, sample=8) ood_consistency_test = load_test([ tasks.rgb, ], resize=512) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) ood_consistency = RealityTask("ood_consistency", ood_consistency_dataset, batch_size=ood_batch_size, shuffle=True) train_subset = RealityTask("train_subset", train_subset_dataset, tasks=[tasks.rgb, tasks.normal], batch_size=ood_batch_size, shuffle=False) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static( "test", test_set, [tasks.rgb, tasks.normal, tasks.principal_curvature]) ood_test = RealityTask.from_static("ood_test", (ood_images, ), [ tasks.rgb, ]) ood_test_large = RealityTask.from_static("ood_test_large", (ood_images_large, ), [ tasks.rgb, ]) ood_consistency_test = RealityTask.from_static("ood_consistency_test", ood_consistency_test, [ tasks.rgb, ]) realities = [ train, val, train_subset, ood_consistency, test, ood_test, ood_test_large, ood_consistency_test ] energy_loss.load_realities(realities) # GRAPH graph = TaskGraph(tasks=energy_loss.tasks + realities, finetuned=finetuned) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) # graph.load_weights(f"{MODELS_DIR}/conservative/conservative.pth") graph.load_weights( f"{SHARED_DIR}/results_2FF_train_subset_512_true_baseline_3/graph_baseline.pth" ) print(graph) # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook( lambda _, __: graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth"), feature="epoch", freq=1, ) graph.save(f"{RESULTS_DIR}/graph_{loss_config}.pth") energy_loss.logger_hooks(logger) # TRAINING for epochs in range(0, 800): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, prefix="start" if epochs == 0 else "") if visualize: return graph.train() for _ in range(0, train_step): train.step() train_loss = energy_loss(graph, reality=train) graph.step(train_loss) logger.update("loss", train_loss) train_subset.step() train_subset_loss = energy_loss(graph, reality=train_subset) graph.step(train_subset_loss) ood_consistency.step() ood_consistency_loss = energy_loss(graph, reality=ood_consistency) if ood_consistency_loss is not None: 
graph.step(ood_consistency_loss) graph.eval() for _ in range(0, val_step): val.step() with torch.no_grad(): val_loss = energy_loss(graph, reality=val) logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step()
def main( loss_config="baseline", mode="standard", visualize=False, fast=False, batch_size=None, path=None, subset_size=None, early_stopping=float('inf'), max_epochs=800, **kwargs ): # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) print('tasks', energy_loss.get_tasks("ood")) ood_set = load_ood(energy_loss.get_tasks("ood")) print (train_step, val_step) train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, energy_loss.get_tasks("ood")) # GRAPH realities = [train, val, test, ood] graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=False, freeze_list=energy_loss.freeze_list, ) graph.edge(tasks.rgb, tasks.normal).model = None graph.edge(tasks.rgb, tasks.normal).path = path graph.edge(tasks.rgb, tasks.normal).load_model() graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True) graph.save(weights_dir=f"{RESULTS_DIR}") # LOGGING logger = VisdomLogger("train", env=JOB) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) best_val_loss, stop_idx = float('inf'), 0 # return # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="start" if epochs == 0 else "") if visualize: return graph.train() for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[train]) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) graph.step(train_loss) train.step() logger.update("loss", train_loss) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step() stop_idx += 1 if logger.data["val_mse : n(x) -> y^"][-1] < best_val_loss: print ("Better val loss, reset stop_idx: ", stop_idx) best_val_loss, stop_idx = logger.data["val_mse : n(x) -> y^"][-1], 0 energy_loss.plot_paths(graph, logger, realities, prefix="best") graph.save(weights_dir=f"{RESULTS_DIR}") if stop_idx >= early_stopping: print ("Stopping training now") return
def main(
    pretrained=True, finetuned=False, fast=False, batch_size=None,
    cont=f"{MODELS_DIR}/conservative/conservative.pth",
    max_epochs=800, **kwargs,
):

    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.depth_zbuffer,
        # tasks.sobel_edges,
    ]

    # CONFIG
    batch_size = batch_size or (4 if fast else 64)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list,
        batch_size=batch_size, fast=fast,
    )
    test_set = load_test(task_list)
    train_step, val_step = train_step // 4, val_step // 4
    print(train_step, val_step)

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    # ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])

    # GRAPH
    realities = [train, val, test]
    graph = TaskGraph(tasks=task_list + realities, finetuned=finetuned)
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"), feature="epoch", freq=1)
    logger.add_hook(partial(jointplot, loss_type="energy"), feature="val_energy", freq=1)

    def path_name(path):
        if len(path) == 1:
            print(path)
            return str(path[0])
        if str((path[-2].name, path[-1].name)) not in graph.edge_map:
            return None
        sub_name = path_name(path[:-1])
        if sub_name is None:
            return None
        return f'{get_transfer_name(Transfer(path[-2], path[-1]))}({sub_name})'

    for i in range(0, 20):
        gt_paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 1)
            if path_name(path) is not None
        }
        baseline_paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 2)
            if path_name(path) is not None
        }
        paths = {
            path_name(path): list(path)
            for path in itertools.permutations(task_list, 3)
            if path_name(path) is not None
        }
        # random.sample requires a sequence, not a dict view
        selected_paths = dict(random.sample(list(paths.items()), k=3))
        print("Chosen paths: ", selected_paths)

        loss_config = {
            "paths": {**gt_paths, **baseline_paths, **paths},
            "losses": {
                "baseline_mse": {
                    ("train", "val"): [(path_name, str(path[-1]))
                                       for path_name, path in baseline_paths.items()],
                },
                "mse": {
                    ("train", "val"): [(path_name, str(path[-1]))
                                       for path_name, path in selected_paths.items()],
                },
                "eval_mse": {
                    ("val",): [(path_name, str(path[-1]))
                               for path_name, path in paths.items()],
                },
            },
            "plots": {
                "ID": dict(
                    size=256,
                    realities=("test",),
                    paths=[path_name for path_name, path in selected_paths.items()]
                        + [str(path[-1]) for path_name, path in selected_paths.items()],
                ),
            },
        }
        energy_loss = EnergyLoss(**loss_config)
        energy_loss.logger_hooks(logger)

        # TRAINING
        for epochs in range(0, 5):
            logger.update("epoch", epochs)
            energy_loss.plot_paths(graph, logger, realities, prefix="")

            graph.train()
            for _ in range(0, train_step):
                train_loss = energy_loss(graph, realities=[train], loss_types=["mse", "baseline_mse"])
                train_loss = sum([train_loss[loss_name] for loss_name in train_loss])
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            graph.eval()
            for _ in range(0, val_step):
                with torch.no_grad():
                    val_loss = energy_loss(graph, realities=[val], loss_types=["mse", "baseline_mse"])
                    val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                    val.step()
                    logger.update("loss", val_loss)

                    val_loss = energy_loss(graph, realities=[val], loss_types=["eval_mse"])
                    val_loss = sum([val_loss[loss_name] for loss_name in val_loss])
                    val.step()
                    logger.update("train_energy", val_loss)
                    logger.update("val_energy", val_loss)

            energy_loss.logger_update(logger)
            logger.step()
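# Aside: a tiny standalone illustration of how the path dictionaries above
# are built — `itertools.permutations` enumerates every ordered task chain of
# a given length, and `path_name` rejects chains whose final hop has no edge
# in the graph. The string task names and toy edge set here are placeholders
# for the real Task objects and `graph.edge_map`.

def _demo_path_enumeration():
    import itertools

    task_names = ["rgb", "normal", "curvature", "depth"]
    edge_map = {("rgb", "normal"), ("normal", "curvature"), ("rgb", "depth")}

    def path_name(path):
        if len(path) == 1:
            return path[0]
        if (path[-2], path[-1]) not in edge_map:
            return None  # chain uses a transfer the graph doesn't have
        sub = path_name(path[:-1])
        return None if sub is None else f"{path[-1]}({sub})"

    two_hop = {
        path_name(p): list(p)
        for p in itertools.permutations(task_names, 2)
        if path_name(p) is not None
    }
    print(two_hop)  # {'normal(rgb)': [...], 'curvature(normal)': [...], 'depth(rgb)': [...]}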
def main( job_config="jobinfo.txt", models_dir="models", fast=False, batch_size=None, subset_size=None, max_epochs=500, dataaug=False, **kwargs, ): loss_config, loss_mode, model_class = None, None, None experiment, base_dir = None, None current_dir = os.path.dirname(__file__) job_config = os.path.normpath( os.path.join(os.path.join(current_dir, "config"), job_config)) if os.path.isfile(job_config): with open(job_config) as config_file: out = config_file.read().strip().split(',\n') loss_config, loss_mode, model_class, experiment, base_dir = out loss_config = loss_config or LOSS_CONFIG loss_mode = loss_mode or LOSS_MODE model_class = model_class or MODEL_CLASS base_dir = base_dir or BASE_DIR base_dir = os.path.normpath(os.path.join(current_dir, base_dir)) experiment = experiment or EXPERIMENT job = "_".join(experiment.split("_")[0:-1]) models_dir = os.path.join(base_dir, models_dir) results_dir = f"{base_dir}/results/results_{experiment}" results_dir_models = f"{base_dir}/results/results_{experiment}/models" # CONFIG batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, loss_mode=loss_mode, **kwargs) # DATA LOADING train_dataset, val_dataset, train_step, val_step = load_train_val( energy_loss.get_tasks("train"), batch_size=batch_size, fast=fast, subset_size=subset_size, dataaug=dataaug, ) if fast: train_dataset = val_dataset train_step, val_step = 2, 2 train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test_set = load_test(energy_loss.get_tasks("test"), buildings=['almena', 'albertville', 'espanola']) ood_set = load_ood(energy_loss.get_tasks("ood")) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [ tasks.rgb, ]) realities = [train, val, test, ood] # GRAPH graph = TaskGraph(tasks=energy_loss.tasks + realities, tasks_in=energy_loss.tasks_in, tasks_out=energy_loss.tasks_out, pretrained=True, models_dir=models_dir, freeze_list=energy_loss.freeze_list, direct_edges=energy_loss.direct_edges, model_class=model_class) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) # LOGGING os.makedirs(results_dir, exist_ok=True) os.makedirs(results_dir_models, exist_ok=True) logger = VisdomLogger("train", env=job, port=PORT, server=SERVER) logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) logger.add_hook(lambda _, __: graph.save(f"{results_dir}/graph.pth", results_dir_models), feature="epoch", freq=1) energy_loss.logger_hooks(logger) energy_loss.plot_paths(graph, logger, realities, prefix="start") # BASELINE graph.eval() with torch.no_grad(): for _ in range(0, val_step * 4): val_loss, _ = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) for _ in range(0, train_step * 4): train_loss, _ = energy_loss(graph, realities=[train]) train_loss = sum( [train_loss[loss_name] for loss_name in train_loss]) train.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) energy_loss.plot_paths(graph, logger, realities, prefix="finish") graph.train() for _ in range(0, train_step): train_loss, grad_mse_coeff = energy_loss(graph, realities=[train], compute_grad_ratio=True) graph.step(train_loss, losses=energy_loss.losses, paths=energy_loss.paths) train_loss = sum( 
[train_loss[loss_name] for loss_name in train_loss]) train.step() logger.update("loss", train_loss) del train_loss print(grad_mse_coeff) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss, _ = energy_loss(graph, realities=[val]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) logger.step()
def main( loss_config="trainsig_edgereshade", mode="standard", visualize=False, fast=False, batch_size=32, learning_rate=3e-5, resume=False, subset_size=None, max_epochs=800, dataaug=False, **kwargs, ): # CONFIG wandb.config.update({"loss_config":loss_config,"batch_size":batch_size,"lr":learning_rate}) batch_size = batch_size or (4 if fast else 64) energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs) # DATA LOADING train_undist_dataset, train_dist_dataset, val_ooddist_dataset, val_dist_dataset, val_dataset, train_step, val_step = load_train_val_sig( energy_loss.get_tasks("val"), batch_size=batch_size, fast=fast, subset_size=subset_size, ) test_set = load_test(energy_loss.get_tasks("test")) ood_set = load_ood(energy_loss.get_tasks("ood"), ood_path='./assets/ood_natural/') ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"), ood_path='./assets/st_syn_distortions/') ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"), ood_path='./assets/ood_syn_distortions/', sample=35) train_undist = RealityTask("train_undist", train_undist_dataset, batch_size=batch_size, shuffle=True) train_dist = RealityTask("train_dist", train_dist_dataset, batch_size=batch_size, shuffle=True) val_ooddist = RealityTask("val_ooddist", val_ooddist_dataset, batch_size=batch_size, shuffle=True) val_dist = RealityTask("val_dist", val_dist_dataset, batch_size=batch_size, shuffle=True) val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True) test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test")) ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,]) ## standard ood set - natural ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb,]) ## synthetic distortion images used for sig training ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb,]) ## unseen syn distortions # GRAPH realities = [train_undist, train_dist, val_ooddist, val_dist, val, test, ood, ood_syn_aug, ood_syn] graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False, freeze_list=energy_loss.freeze_list, ) graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True) if resume: graph.load_weights('/workspace/shared/results_test_1/graph.pth') graph.optimizer.load_state_dict(torch.load('/workspace/shared/results_test_1/opt.pth')) # else: # folder_name='/workspace/shared/results_wavelet2normal_depthreshadecurvimgnetl1perceps_0.1nll/' # # pdb.set_trace() # in_domain='wav' # out_domain='normal' # graph.load_weights(folder_name+'graph.pth', [str((in_domain, out_domain))]) # create_t0_graph(folder_name,in_domain,out_domain) # graph.load_weights(folder_name+'graph_t0.pth', [str((in_domain, f'{out_domain}_t0'))]) # LOGGING logger = VisdomLogger("train", env=JOB) # fake visdom logger logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20) energy_loss.logger_hooks(logger) # BASELINE if not resume: graph.eval() with torch.no_grad(): for reality in [val_ooddist,val_dist,val]: for _ in range(0, val_step): val_loss = energy_loss(graph, realities=[reality]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) reality.step() logger.update("loss", val_loss) for reality in [train_undist,train_dist]: for _ in range(0, train_step): train_loss = energy_loss(graph, realities=[reality]) train_loss = sum([train_loss[loss_name] for loss_name in train_loss]) reality.step() logger.update("loss", train_loss) energy_loss.logger_update(logger) data=logger.step() del data['loss'] data = {k:v[0] for k,v 
in data.items()} wandb.log(data, step=0) path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=0) # TRAINING for epochs in range(0, max_epochs): logger.update("epoch", epochs) graph.train() for _ in range(0, train_step): train_loss_nll = energy_loss(graph, realities=[train_undist]) train_loss_nll = sum([train_loss_nll[loss_name] for loss_name in train_loss_nll]) train_loss_lwfsig = energy_loss(graph, realities=[train_dist]) train_loss_lwfsig = sum([train_loss_lwfsig[loss_name] for loss_name in train_loss_lwfsig]) train_loss = train_loss_nll+train_loss_lwfsig graph.step(train_loss) train_undist.step() train_dist.step() logger.update("loss", train_loss) graph.eval() for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[val_dist]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) val_dist.step() logger.update("loss", val_loss) if epochs % 20 == 0: for reality in [val,val_ooddist]: for _ in range(0, val_step): with torch.no_grad(): val_loss = energy_loss(graph, realities=[reality]) val_loss = sum([val_loss[loss_name] for loss_name in val_loss]) reality.step() logger.update("loss", val_loss) energy_loss.logger_update(logger) data=logger.step() del data['loss'] data = {k:v[0] for k,v in data.items()} wandb.log(data, step=epochs+1) if epochs % 10 == 0: graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth") if (epochs % 100 == 0) or (epochs % 15 == 0 and epochs <= 30): path_values = energy_loss.plot_paths(graph, logger, realities, prefix="") for reality_paths, reality_images in path_values.items(): wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=epochs+1) graph.save(f"{RESULTS_DIR}/graph.pth") torch.save(graph.optimizer.state_dict(),f"{RESULTS_DIR}/opt.pth")