def run_viz_suite(name, data, dest_task=tasks.depth_zbuffer, graph_file=None,
                  model_file=None, logger=None, old=False, multitask=False,
                  percep_mode=None):
    """Predict ``dest_task`` outputs for ``data`` (for visualization).

    The rgb -> ``dest_task`` model is chosen by the first matching option:
    ``graph_file`` (edge of a ``TaskGraph`` loaded from that file), ``old``
    (``UNetOld`` weights from ``model_file``), ``multitask`` (6-channel
    ``UNet`` with downsample=5 from ``model_file``), ``model_file``
    (``UNet`` with downsample=6), otherwise a default ``Transfer`` model.

    Args:
        name: Unused in this function; kept for call compatibility.
        data: Input passed directly to ``model.predict``.
        dest_task: Target task of the rgb -> task model.
        graph_file: Optional weights file for a ``TaskGraph``.
        model_file: Optional weights file for a UNet-style model.
        logger: Unused in this function; kept for call compatibility.
        old: Load ``UNetOld`` weights from ``model_file``.
        multitask: Load a 6-output-channel UNet from ``model_file``.
        percep_mode: If truthy, predictions are additionally mapped through
            a ``dest_task -> tasks.normal`` ``Transfer`` model.

    Returns:
        The model's last 3 output channels clamped to [0, 1] (single-channel
        outputs replicated to 3 channels), or, when ``percep_mode`` is set,
        the perceptual model's outputs concatenated along dim 0.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        print('here')
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model()
    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # Keep only the last 3 output channels (for the 6-channel multitask UNet
    # this drops the first 3) and clamp into the displayable [0, 1] range.
    results = model.predict(data)[:, -3:].clamp(min=0, max=1)
    if results.shape[1] == 1:
        # Replicate single-channel outputs so every result has 3 channels.
        results = torch.cat([results] * 3, dim=1)

    if percep_mode:
        percep_model = Transfer(src_task=dest_task, dest_task=tasks.normal).load_model()
        percep_model.eval()
        # NOTE(review): 16 worker processes over an in-memory TensorDataset is
        # mostly overhead, and would fail if `results` were a CUDA tensor —
        # confirm model.predict returns CPU tensors.
        eval_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(results),
            batch_size=16, num_workers=16, shuffle=False, pin_memory=True)
        final_preds = []
        # Inference only: without no_grad, each batch's autograd graph (when
        # the model's parameters require grad) would stay alive inside
        # final_preds, and the returned tensor would require grad.
        with torch.no_grad():
            for preds, in eval_loader:
                print('preds shape', preds.shape)
                final_preds.append(percep_model.forward(preds[:, -3:]))
        results = torch.cat(final_preds, dim=0)
    return results
def _percep_forward_cpu(percep_model, preds):
    """Run ``percep_model`` on the last 3 channels of ``preds`` without grad.

    If the forward pass raises ``RuntimeError`` (e.g. the model expects more
    channels than ``preds`` has), retry once with ``preds`` replicated 3x
    along dim 1. When ``preds`` already has at least 3 channels, that
    replicated input would be identical to the first attempt, so the
    original error is re-raised instead of retried.

    Returns:
        The model output, detached and moved to CPU.
    """
    with torch.no_grad():
        try:
            out = percep_model.forward(preds[:, -3:])
        except RuntimeError:
            if preds.shape[1] >= 3:
                raise
            expanded = torch.cat([preds] * 3, dim=1)
            out = percep_model.forward(expanded[:, -3:])
    return out.detach().cpu()


def run_viz_suite(name, data_loader, dest_task=tasks.depth_zbuffer,
                  graph_file=None, model_file=None, old=False, multitask=False,
                  percep_mode=None, downsample=6, out_channels=3,
                  final_task=tasks.normal, oldpercep=False):
    """Predict ``dest_task`` outputs batch-by-batch over ``data_loader``.

    The rgb -> ``dest_task`` model is chosen by the first matching option:
    ``graph_file`` (edge of a ``TaskGraph`` loaded from that file), ``old``
    (``UNetOld`` from ``model_file``), ``multitask`` (6-channel ``UNet``
    with downsample=5), ``model_file`` (``UNet`` built from ``downsample``
    and ``out_channels``), otherwise a ``DummyModel``-wrapped ``Transfer``.

    Args:
        name: Unused in this function; kept for call compatibility.
        data_loader: Iterable yielding 1-tuples ``(data,)``; each ``data`` is
            passed to ``model.predict_on_batch``.
        dest_task: Target task of the rgb -> task model.
        graph_file: Optional ``TaskGraph`` weights file. When ``percep_mode``
            is set, ``final_task`` is also added to the graph.
        model_file: Optional weights file for a UNet-style model.
        old: Load ``UNetOld`` weights from ``model_file``.
        multitask: Load a 6-output-channel UNet from ``model_file``.
        percep_mode: If truthy, map predictions through a
            ``dest_task -> final_task`` perceptual model and return those.
        downsample: ``downsample`` argument for the ``model_file`` UNet.
        out_channels: ``out_channels`` argument for the ``model_file`` UNet.
        final_task: Target task of the perceptual model.
        oldpercep: Use a standalone ``Transfer`` perceptual model even when
            ``graph_file`` is given.

    Returns:
        Predictions concatenated along dim 0 on CPU: the last 3 channels
        clamped to [0, 1], or the perceptual model's outputs when
        ``percep_mode`` is set.
    """
    extra_task = [final_task] if percep_mode else []
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task] + extra_task, pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        print('loading main model')
        model = DataParallelModel.load(
            UNet(downsample=downsample, out_channels=out_channels).cuda(),
            model_file)
    else:
        model = DummyModel(
            Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model())
    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    if percep_mode:
        print('Loading percep model...')
        if graph_file is not None and not oldpercep:
            percep_model = graph.edge(dest_task, final_task).load_model()
            percep_model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)
        else:
            percep_model = Transfer(src_task=dest_task, dest_task=final_task).load_model()
        percep_model.eval()

    print("Converting...")
    outputs = []
    for data, in data_loader:
        preds = model.predict_on_batch(data)[:, -3:].clamp(min=0, max=1)
        if percep_mode:
            outputs.append(_percep_forward_cpu(percep_model, preds))
        else:
            # Only keep raw predictions when they are what gets returned;
            # in percep mode they were previously stored and then discarded.
            outputs.append(preds.detach().cpu())
    return torch.cat(outputs, dim=0)