def run_eval_suite(
    name,
    dest_task=tasks.normal,
    graph_file=None,
    model_file=None,
    logger=None,
    sample=800,
    show_images=False,
    old=False,
):
    """Evaluate one model with ``ValidationMetrics("almena")`` and report it.

    The model is chosen by the first option that applies:

    1. ``graph_file``: load graph weights and use the rgb -> ``dest_task`` edge.
    2. ``old``: load ``model_file`` into a ``UNetOld``.
    3. ``model_file``: load it into a ``UNet(downsample=6)``.
    4. otherwise: the default ``Transfer`` model.

    Args:
        name: Label prepended to the reported result.
        dest_task: Task the model predicts and that metrics are computed on.
        graph_file: Optional path to saved graph weights.
        model_file: Optional path to a single-model checkpoint.
        logger: Object with a ``text(str)`` method; when ``None`` the result
            is printed instead.
        sample: Forwarded to ``ValidationMetrics.evaluate`` as ``sample``.
        show_images: Unused; kept for interface compatibility.
        old: Select the ``UNetOld`` architecture for ``model_file``.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif model_file is not None:
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        # NOTE(review): the sibling suites use src_task=tasks.rgb here;
        # tasks.normal is kept as in the original -- confirm it is intended.
        model = Transfer(src_task=tasks.normal, dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    dataset = ValidationMetrics("almena", dest_task=dest_task)
    # Honour the caller's sample size (it used to be hard-coded to 800).
    result = dataset.evaluate(model, sample=sample)

    message = name + ": " + str(result)
    if logger is None:
        # The default logger=None used to raise AttributeError here.
        print(message)
    else:
        logger.text(message)
def run_viz_suite(
    name,
    data,
    dest_task=tasks.depth_zbuffer,
    graph_file=None,
    model_file=None,
    logger=None,
    old=False,
    multitask=False,
    percep_mode=None,
):
    """Predict ``dest_task`` for ``data`` and return the clamped predictions.

    Model selection, first match wins: ``graph_file`` (rgb -> ``dest_task``
    graph edge), ``old`` (``UNetOld``), ``multitask`` (``UNet`` with
    ``downsample=5, out_channels=6``), ``model_file`` (``UNet(downsample=6)``),
    otherwise the default ``Transfer`` model.

    The last three output channels are kept and clamped to ``[0, 1]``; a
    single-channel result is repeated to three channels. When ``percep_mode``
    is truthy the predictions are additionally passed, in batches of 16,
    through the ``dest_task`` -> normal ``Transfer`` model and those outputs
    are returned instead.

    ``name`` and ``logger`` are accepted but not used.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        edge = graph.edge(tasks.rgb, dest_task)
        model = edge.load_model()
    elif old:
        net = UNetOld().cuda()
        model = DataParallelModel.load(net, model_file)
    elif multitask:
        net = UNet(downsample=5, out_channels=6).cuda()
        model = DataParallelModel.load(net, model_file)
    elif model_file is not None:
        print('here')
        net = UNet(downsample=6).cuda()
        model = DataParallelModel.load(net, model_file)
    else:
        transfer = Transfer(src_task=tasks.rgb, dest_task=dest_task)
        model = transfer.load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # Keep the trailing three channels and bound them to the unit range.
    results = model.predict(data)[:, -3:].clamp(min=0, max=1)
    if results.shape[1] == 1:
        results = torch.cat([results, results, results], dim=1)

    if not percep_mode:
        return results

    percep_model = Transfer(src_task=dest_task, dest_task=tasks.normal).load_model()
    percep_model.eval()

    eval_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(results),
        batch_size=16,
        num_workers=16,
        shuffle=False,
        pin_memory=True,
    )

    final_preds = []
    for (preds,) in eval_loader:
        print('preds shape', preds.shape)
        final_preds.append(percep_model.forward(preds[:, -3:]))

    return torch.cat(final_preds, dim=0)
def run_perceptual_eval_suite(
    name,
    intermediate_task=tasks.normal,
    dest_task=tasks.normal,
    graph_file=None,
    model_file=None,
    logger=None,
    sample=800,
    show_images=False,
    old=False,
    perceptual_transfer=None,
    multitask=False,
):
    """Evaluate a model through a perceptual transfer and report the result.

    A main model (rgb -> ``intermediate_task``) and a perceptual model
    (``intermediate_task`` -> ``dest_task``) are handed to
    ``ValidationMetrics("almena").evaluate_with_percep``.

    Main-model selection, first match wins: ``graph_file``, ``old``
    (``UNetOld``), ``multitask`` (``UNet(downsample=5, out_channels=6)``),
    ``model_file`` (``UNet(downsample=6)``), otherwise the default
    ``Transfer`` model.

    Args:
        name: Label prepended to the reported result.
        intermediate_task: Task predicted by the main model.
        dest_task: Task the metrics are computed on.
        graph_file: Optional path to saved graph weights.
        model_file: Optional path to a single-model checkpoint.
        logger: Object with a ``text(str)`` method; when ``None`` the result
            is printed instead.
        sample: Forwarded to ``evaluate_with_percep`` as ``sample``.
        show_images: Unused; kept for interface compatibility.
        old: Select the ``UNetOld`` architecture for ``model_file``.
        perceptual_transfer: Optional replacement for the default perceptual
            transfer (see the note in the body).
        multitask: Select the six-output-channel ``UNet`` for ``model_file``.
    """
    # Previously percep_model was only bound when perceptual_transfer was
    # None, so supplying one raised NameError further down.
    if perceptual_transfer is None:
        percep_model = Transfer(
            src_task=intermediate_task, dest_task=dest_task).load_model()
    elif hasattr(perceptual_transfer, "load_model"):
        # NOTE(review): assumes a Transfer-like object exposing load_model(),
        # mirroring the default branch -- confirm against callers.
        percep_model = perceptual_transfer.load_model()
    else:
        # Otherwise treat the argument as an already-loaded model.
        percep_model = perceptual_transfer

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, intermediate_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, intermediate_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        print('running multitask')
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(
            src_task=tasks.rgb, dest_task=intermediate_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    dataset = ValidationMetrics("almena", dest_task=dest_task)
    # Honour the caller's sample size (it used to be hard-coded to 800).
    result = dataset.evaluate_with_percep(
        model, sample=sample, percep_model=percep_model)

    message = name + ": " + str(result)
    if logger is None:
        # The default logger=None used to raise AttributeError here.
        print(message)
    else:
        logger.text(message)
def run_viz_suite(
    name,
    data_loader,
    dest_task=tasks.depth_zbuffer,
    graph_file=None,
    model_file=None,
    old=False,
    multitask=False,
    percep_mode=None,
    downsample=6,
    out_channels=3,
    final_task=tasks.normal,
    oldpercep=False,
):
    """Run a model over ``data_loader`` and return the concatenated outputs.

    Main-model selection, first match wins: ``graph_file`` (rgb ->
    ``dest_task`` graph edge), ``old`` (``UNetOld``), ``multitask``
    (``UNet(downsample=5, out_channels=6)``), ``model_file``
    (``UNet(downsample=downsample, out_channels=out_channels)``), otherwise
    the default ``Transfer`` model wrapped in ``DummyModel``.

    Each batch prediction keeps its last three channels and is clamped to
    ``[0, 1]``. When ``percep_mode`` is truthy the prediction is passed
    through a ``dest_task`` -> ``final_task`` model and that output is
    collected instead of the raw prediction.

    Args:
        name: Unused; kept for interface compatibility.
        data_loader: Iterable yielding 1-tuples ``(batch,)``.
        dest_task: Task predicted by the main model.
        graph_file: Optional path to saved graph weights.
        model_file: Optional path to a single-model checkpoint.
        old: Select the ``UNetOld`` architecture for ``model_file``.
        multitask: Select the six-output-channel ``UNet`` for ``model_file``.
        percep_mode: Truthy to return perceptual-model outputs.
        downsample: ``UNet`` ``downsample`` argument for ``model_file``.
        out_channels: ``UNet`` ``out_channels`` argument for ``model_file``.
        final_task: Output task of the perceptual model.
        oldpercep: With ``graph_file`` set, use the default ``Transfer``
            perceptual model instead of the graph's edge.

    Returns:
        All collected batches concatenated along dimension 0, on the CPU.
    """
    extra_task = [final_task] if percep_mode else []

    if graph_file is not None:
        graph = TaskGraph(
            tasks=[tasks.rgb, dest_task] + extra_task, pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # downsample = 5 or 6
        print('loading main model')
        model = DataParallelModel.load(
            UNet(downsample=downsample, out_channels=out_channels).cuda(),
            model_file)
    else:
        model = DummyModel(
            Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model())

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    percep_model = None
    if percep_mode:
        print('Loading percep model...')
        if graph_file is not None and not oldpercep:
            percep_model = graph.edge(dest_task, final_task).load_model()
            percep_model.compile(
                torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)
        else:
            percep_model = Transfer(
                src_task=dest_task, dest_task=final_task).load_model()
        percep_model.eval()

    print("Converting...")
    # Only the list that is actually returned is accumulated; the raw
    # predictions used to be copied to the CPU and kept even in percep mode,
    # where they were thrown away.
    collected = []
    for data, in data_loader:
        preds = model.predict_on_batch(data)[:, -3:].clamp(min=0, max=1)
        if not percep_mode:
            collected.append(preds.detach().cpu())
            continue
        try:
            percep_out = percep_model.forward(preds[:, -3:]).detach().cpu()
        except RuntimeError:
            # Retry with the prediction repeated to three channels
            # (presumably a single-channel input mismatch -- any other
            # RuntimeError will simply be raised again by the retry).
            preds = torch.cat([preds] * 3, dim=1)
            percep_out = percep_model.forward(preds[:, -3:]).detach().cpu()
        collected.append(percep_out)

    return torch.cat(collected, dim=0)
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    use_optimizer=False,
    subset_size=None,
    max_epochs=800,
    **kwargs,
):
    """Train a task graph starting from a saved rgb -> normal checkpoint.

    Builds the energy loss named by ``loss_config``, loads train/val/test/ood
    data, creates a pretrained ``TaskGraph`` whose rgb -> normal edge is
    reloaded from a hard-coded checkpoint path, and then, for each epoch,
    runs a validation pass followed by a training pass.

    Args:
        loss_config: Forwarded to ``get_energy_loss`` as ``config``.
        mode: Forwarded to ``get_energy_loss``.
        visualize: When truthy, return right after the first ``plot_paths``.
        fast: Forwarded to ``load_train_val``; also selects the default
            batch size (4 instead of 64).
        batch_size: Overrides the default batch size when truthy.
        use_optimizer: When truthy, restore optimizer state from a
            hard-coded ``optimizer.pth`` path.
        subset_size: Forwarded to ``load_train_val``.
        max_epochs: Number of epochs to iterate.
        **kwargs: Printed, then forwarded to ``get_energy_loss``.
    """
    # CONFIG
    if not batch_size:
        batch_size = 4 if fast else 64
    print(kwargs)
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    # Only a quarter of the reported steps are run per epoch.
    train_step = train_step // 4
    val_step = val_step // 4
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood([tasks.rgb])
    print(train_step, val_step)

    train = RealityTask(
        "train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask(
        "val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static("ood", ood_set, [tasks.rgb])

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        finetuned=False,
        freeze_list=energy_loss.freeze_list,
    )
    # Drop the rgb -> normal model and reload it from the checkpoint below.
    rgb_to_normal = graph.edge(tasks.rgb, tasks.normal)
    rgb_to_normal.model = None
    rgb_to_normal.path = (
        "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/n.pth")
    rgb_to_normal.load_model()
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    if use_optimizer:
        optimizer = torch.load(
            "mount/shared/results_SAMPLEFF_baseline_fulldata_opt_4/optimizer.pth"
        )
        graph.optimizer.load_state_dict(optimizer.state_dict())

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(
        lambda _, __: graph.save(f"{RESULTS_DIR}/graph.pth"),
        feature="epoch",
        freq=1,
    )
    energy_loss.logger_hooks(logger)
    # Only referenced by the disabled "best" plotting noted at the loop end.
    best_ood_val_loss = float('inf')

    # TRAINING
    for epochs in range(max_epochs):
        logger.update("epoch", epochs)
        plot_prefix = "start" if epochs == 0 else ""
        energy_loss.plot_paths(graph, logger, realities, prefix=plot_prefix)
        if visualize:
            return

        # Validation pass comes first in this script.
        graph.eval()
        for _ in range(val_step):
            with torch.no_grad():
                val_losses = energy_loss(graph, realities=[val])
                val_loss = sum(val_losses[key] for key in val_losses)
            val.step()
            logger.update("loss", val_loss)

        graph.train()
        for _ in range(train_step):
            train_losses = energy_loss(graph, realities=[train])
            train_loss = sum(train_losses[key] for key in train_losses)
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        energy_loss.logger_update(logger)
        # Disabled in the original: re-plot with prefix="best" whenever
        # logger.data["val_mse : y^ -> n(~x)"][-1] beats best_ood_val_loss.
        logger.step()
def main(
    fast=False,
    subset_size=None,
    early_stopping=float('inf'),
    mode='standard',
    max_epochs=800,
    **kwargs,
):
    """Two-stage training: a normal -> curvature net, then a perceptual run.

    Stage 1 trains a graph over rgb / normal / principal_curvature with an
    MSE loss between paths ``f(y)`` and ``z^`` and stops after 8 epochs
    without improvement of ``"val_mse : f(y) -> z^"``. Stage 2 builds a new
    graph, loads the stage-1 weights from ``{RESULTS_DIR}/f.pth`` into the
    normal -> principal_curvature edge, trains with the ``"perceptual"``
    energy loss and stops after 50 epochs without improvement of
    ``"val_mse : n(x) -> y^"``.

    Args:
        fast: Forwarded to ``load_train_val``.
        subset_size: Forwarded to ``load_train_val``.
        early_stopping: NOTE(review): overwritten (8, then 50) before it is
            ever read, so the passed value has no effect.
        mode: Forwarded to ``get_energy_loss`` in stage 2.
        max_epochs: Epoch cap applied to each stage.
        **kwargs: Forwarded to ``get_energy_loss`` in stage 2.
    """
    early_stopping = 8

    loss_config_percepnet = {
        "paths": {
            "y": [tasks.normal],
            "z^": [tasks.principal_curvature],
            "f(y)": [tasks.normal, tasks.principal_curvature],
        },
        "losses": {
            "mse": {
                ("train", "val"): [
                    ("f(y)", "z^"),
                ],
            },
        },
        "plots": {
            "ID": dict(
                size=256,
                realities=("test", "ood"),
                paths=["y", "z^", "f(y)"],
            ),
        },
    }

    # CONFIG
    batch_size = 64
    energy_loss = EnergyLoss(**loss_config_percepnet)
    task_list = [tasks.rgb, tasks.normal, tasks.principal_curvature]

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        task_list,
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(task_list)
    ood_set = load_ood(task_list)
    print(train_step, val_step)

    train = RealityTask(
        "train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask(
        "val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, task_list)
    ood = RealityTask.from_static("ood", ood_set, task_list)
    realities = [train, val, test, ood]

    # LOGGING
    logger = VisdomLogger("train", env=JOB)

    def train_stage(graph, energy_loss, metric_key, patience, save_best):
        """Run train/val epochs until ``metric_key`` stalls for ``patience``."""
        best_val_loss, stop_idx = float('inf'), 0
        for epochs in range(max_epochs):
            logger.update("epoch", epochs)
            energy_loss.plot_paths(
                graph, logger, realities,
                prefix="start" if epochs == 0 else "")

            graph.train()
            for _ in range(train_step):
                train_losses = energy_loss(graph, realities=[train])
                train_loss = sum(train_losses[key] for key in train_losses)
                graph.step(train_loss)
                train.step()
                logger.update("loss", train_loss)

            graph.eval()
            for _ in range(val_step):
                with torch.no_grad():
                    val_losses = energy_loss(graph, realities=[val])
                    val_loss = sum(val_losses[key] for key in val_losses)
                val.step()
                logger.update("loss", val_loss)

            energy_loss.logger_update(logger)
            logger.step()

            stop_idx += 1
            latest = logger.data[metric_key][-1]
            if latest < best_val_loss:
                print("Better val loss, reset stop_idx: ", stop_idx)
                best_val_loss, stop_idx = latest, 0
                energy_loss.plot_paths(graph, logger, realities, prefix="best")
                save_best(graph)

            if stop_idx >= patience:
                print("Stopping training now")
                break

    # ---- Stage 1: train with the f(y) -> z^ MSE loss ----
    graph = TaskGraph(
        tasks=task_list + realities,
        pretrained=False,
        freeze_list=[functional_transfers.n],
    )
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    train_stage(
        graph,
        energy_loss,
        "val_mse : f(y) -> z^",
        early_stopping,
        lambda g: g.save(weights_dir=f"{RESULTS_DIR}"),
    )

    # ---- Stage 2: perceptual training using the stage-1 weights ----
    early_stopping = 50

    # CONFIG
    energy_loss = get_energy_loss(config="perceptual", mode=mode, **kwargs)

    # GRAPH
    graph = TaskGraph(
        tasks=task_list + realities,
        pretrained=False,
        freeze_list=[functional_transfers.f],
    )
    graph.edge(tasks.normal, tasks.principal_curvature).model.load_weights(
        f"{RESULTS_DIR}/f.pth")
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # NOTE(review): the same "loss" hook is registered a second time on the
    # same logger, exactly as in the original -- confirm this is intended.
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    train_stage(
        graph,
        energy_loss,
        "val_mse : n(x) -> y^",
        early_stopping,
        lambda g: g.save(f"{RESULTS_DIR}/graph.pth"),
    )
def main(
    loss_config="baseline",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    path=None,
    subset_size=None,
    early_stopping=float('inf'),
    max_epochs=800,
    **kwargs
):
    """Train a task graph whose rgb -> normal edge is loaded from ``path``.

    Each epoch runs a training pass and then a validation pass. Whenever
    ``"val_mse : n(x) -> y^"`` improves, the graph weights are saved to
    ``RESULTS_DIR``; training returns once ``early_stopping`` consecutive
    epochs pass without improvement.

    Args:
        loss_config: Forwarded to ``get_energy_loss`` as ``config``.
        mode: Forwarded to ``get_energy_loss``.
        visualize: When truthy, return right after the first ``plot_paths``.
        fast: Forwarded to ``load_train_val``; also selects the default
            batch size (4 instead of 64).
        batch_size: Overrides the default batch size when truthy.
        path: Assigned to the rgb -> normal edge's ``path`` before loading.
        subset_size: Forwarded to ``load_train_val``.
        early_stopping: Epochs without improvement tolerated before stopping.
        max_epochs: Epoch cap.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    # CONFIG
    if not batch_size:
        batch_size = 4 if fast else 64
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))
    print('tasks', energy_loss.get_tasks("ood"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print(train_step, val_step)

    train = RealityTask(
        "train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask(
        "val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static(
        "ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=False,
        freeze_list=energy_loss.freeze_list,
    )
    # Drop the rgb -> normal model and reload it from the given path.
    rgb_to_normal = graph.edge(tasks.rgb, tasks.normal)
    rgb_to_normal.model = None
    rgb_to_normal.path = path
    rgb_to_normal.load_model()
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)
    # Weights are written once before any training step is taken.
    graph.save(weights_dir=f"{RESULTS_DIR}")

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    metric_key = "val_mse : n(x) -> y^"
    best_val_loss = float('inf')
    stop_idx = 0

    # TRAINING
    for epochs in range(max_epochs):
        logger.update("epoch", epochs)
        plot_prefix = "start" if epochs == 0 else ""
        energy_loss.plot_paths(graph, logger, realities, prefix=plot_prefix)
        if visualize:
            return

        graph.train()
        for _ in range(train_step):
            train_losses = energy_loss(graph, realities=[train])
            train_loss = sum(train_losses[key] for key in train_losses)
            graph.step(train_loss)
            train.step()
            logger.update("loss", train_loss)

        graph.eval()
        for _ in range(val_step):
            with torch.no_grad():
                val_losses = energy_loss(graph, realities=[val])
                val_loss = sum(val_losses[key] for key in val_losses)
            val.step()
            logger.update("loss", val_loss)

        energy_loss.logger_update(logger)
        logger.step()

        stop_idx += 1
        latest = logger.data[metric_key][-1]
        if latest < best_val_loss:
            print("Better val loss, reset stop_idx: ", stop_idx)
            best_val_loss, stop_idx = latest, 0
            energy_loss.plot_paths(graph, logger, realities, prefix="best")
            graph.save(weights_dir=f"{RESULTS_DIR}")

        if stop_idx >= early_stopping:
            print("Stopping training now")
            return
def main(
    loss_config="baseline",
    mode="standard",
    visualize=False,
    fast=False,
    batch_size=None,
    **kwargs,
):
    """Run a validation-only pass for a saved rgb -> normal checkpoint.

    Builds the energy loss and a pretrained ``TaskGraph``, reloads the
    rgb -> normal edge from ``{SHARED_DIR}/results_SAMPLEFF_consistency1m_25/
    n.pth``, runs the validation loop once (no training) and prints the
    logged ``"val_mse : n(x) -> y^"`` series.

    Args:
        loss_config: Forwarded to ``get_energy_loss`` as ``config``.
        mode: Forwarded to ``get_energy_loss``.
        visualize: Unused in this function.
        fast: Forwarded to ``load_train_val``; also selects the default
            batch size (4 instead of 64).
        batch_size: Overrides the default batch size when truthy.
        **kwargs: Forwarded to ``get_energy_loss``.
    """
    # CONFIG
    if not batch_size:
        batch_size = 4 if fast else 64
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, train_step, val_step = load_train_val(
        energy_loss.get_tasks("train"),
        batch_size=batch_size,
        fast=fast,
    )
    # Four times the reported number of steps is used here.
    train_step = 4 * train_step
    val_step = 4 * val_step
    test_set = load_test(energy_loss.get_tasks("test"))
    ood_set = load_ood(energy_loss.get_tasks("ood"))
    print(train_step, val_step)

    train = RealityTask(
        "train", train_dataset, batch_size=batch_size, shuffle=True)
    val = RealityTask(
        "val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static(
        "test", test_set, energy_loss.get_tasks("test"))
    ood = RealityTask.from_static(
        "ood", ood_set, energy_loss.get_tasks("ood"))

    # GRAPH
    realities = [train, val, test, ood]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        pretrained=True,
        freeze_list=energy_loss.freeze_list,
    )
    # Drop the rgb -> normal model and reload it from the checkpoint below.
    rgb_to_normal = graph.edge(tasks.rgb, tasks.normal)
    rgb_to_normal.model = None
    rgb_to_normal.path = f"{SHARED_DIR}/results_SAMPLEFF_consistency1m_25/n.pth"
    rgb_to_normal.load_model()
    graph.compile(torch.optim.Adam, lr=4e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(
        lambda logger, data: logger.step(), feature="loss", freq=20)

    # VALIDATION ONLY
    graph.eval()
    for _ in range(val_step):
        with torch.no_grad():
            val_losses = energy_loss(graph, realities=[val])
            val_loss = sum(val_losses[key] for key in val_losses)
        val.step()
        logger.update("loss", val_loss)

    energy_loss.logger_update(logger)
    logger.step()

    print("Val mse: ", logger.data["val_mse : n(x) -> y^"])