def main():
    """Visualize every transfer out of ``rgb`` on an OOD image set, mapped to normal.

    Builds a ``TaskGraph`` over ``task_list`` plus an ``'ood'`` reality backed by
    ``ImagePairDataset``. It then pushes the reality's ``rgb`` data through each
    outgoing edge of ``tasks.rgb``. Every prediction whose destination is not
    ``normal`` is passed through the graph's ``(dest, "normal")`` edge. All
    results (clamped to [0, 1]) are logged as one labeled image group next to
    the input rgb.
    """
    batch_size = 28  # shared by the reality and the graph

    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="energy", freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy",
        freq=100,
    )

    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]

    # The image-pair dataset is exposed to the graph as an (rgb, rgb) pair.
    reality = RealityTask(
        'ood',
        dataset=ImagePairDataset(data_dir=OOD_DIR, resize=(256, 256)),
        tasks=[tasks.rgb, tasks.rgb],
        batch_size=batch_size,
    )
    graph = TaskGraph(tasks=[reality, *task_list], batch_size=batch_size)

    task = tasks.rgb
    images = [reality.task_data[task]]
    sources = [task.name]

    # Order panels by destination name. Sorting on a key (instead of on
    # (name, edge) tuples) never falls back to comparing edge objects when two
    # edges share a destination name.
    for edge in sorted(graph.adj[task], key=lambda e: e.dest_task.name):
        if isinstance(edge.src_task, RealityTask):
            continue
        x = edge(reality.task_data[edge.src_task])
        # Re-project non-normal predictions into normal space, presumably so
        # every panel is shown in the same representation.
        if edge.dest_task != tasks.normal:
            to_normal = graph.edge_map[(edge.dest_task.name, tasks.normal.name)]
            x = to_normal(x)
        images.append(x.clamp(min=0, max=1))
        sources.append(edge.dest_task.name)

    logger.images_grouped(images, ", ".join(sources), resize=256)
def main():
    """Log each task's dataset values next to every model prediction of that task.

    Builds an ``'almena'`` reality (``TaskDataset`` over ``task_list``) and a
    ``TaskGraph`` over it. For every non-reality task in the graph, it collects
    the reality's data for that task. It then runs every incoming non-reality
    edge on the reality's data for the edge's source task. The clamped ([0, 1])
    results are logged as one labeled image group per task.
    """
    batch_size = 8  # shared by the reality and the graph

    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="energy", freq=16)
    logger.add_hook(
        lambda logger, data: logger.plot(data["energy"], "free_energy"),
        feature="energy",
        freq=100,
    )

    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]
    reality = RealityTask(
        'almena',
        dataset=TaskDataset(buildings=['almena'], tasks=task_list),
        tasks=task_list,
        batch_size=batch_size,
    )
    graph = TaskGraph(tasks=[reality, *task_list], batch_size=batch_size)

    for task in graph.tasks:
        if isinstance(task, RealityTask):
            continue
        images = [reality.task_data[task].clamp(min=0, max=1)]
        sources = [task.name]
        # Order panels by source name. The key-based sort never compares edge
        # objects, even when two incoming edges share a source name.
        for edge in sorted(graph.in_adj[task], key=lambda e: e.src_task.name):
            if isinstance(edge.src_task, RealityTask):
                continue
            x = edge(reality.task_data[edge.src_task])
            images.append(x.clamp(min=0, max=1))
            sources.append(edge.src_task.name)
        logger.images_grouped(images, ", ".join(sources), resize=192)
def main():
    """Run 4000 epochs of free-energy estimation and averaging on an 'almena' graph.

    Builds an ``'almena'`` reality providing rgb, normal, principal curvature and
    depth. The ``TaskGraph`` is anchored on that reality and ``tasks.rgb``, with
    the direct reality -> {normal, curvature, depth} edges excluded.

    ``graph.p`` and ``graph.estimates`` are compiled with Adam. Each epoch then
    computes ``graph.free_energy(sample=16)`` and calls
    ``graph.averaging_step()`` under ``torch.no_grad()``, logging the energy.
    """
    task_list = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.sobel_edges,
        tasks.depth_zbuffer,
        tasks.reshading,
        tasks.edge_occlusion,
        tasks.keypoints3d,
        tasks.keypoints2d,
    ]
    # Tasks the almena dataset is asked to load. Each consumer gets its own
    # fresh list copy.
    observed = [
        tasks.rgb,
        tasks.normal,
        tasks.principal_curvature,
        tasks.depth_zbuffer,
    ]
    reality = RealityTask(
        'almena',
        dataset=TaskDataset(buildings=['almena'], tasks=list(observed)),
        tasks=list(observed),
        batch_size=4,
    )

    excluded_edges = [
        ('almena', 'normal'),
        ('almena', 'principal_curvature'),
        ('almena', 'depth_zbuffer'),
        # ('rgb', 'keypoints3d') and ('rgb', 'edge_occlusion') were also
        # excluded in an earlier configuration.
    ]
    graph = TaskGraph(
        tasks=[reality, *task_list],
        anchored_tasks=[reality, tasks.rgb],
        reality=reality,
        batch_size=4,
        edges_exclude=excluded_edges,
        initialize_first_order=True,
    )
    graph.p.compile(torch.optim.Adam, lr=4e-2)
    graph.estimates.compile(torch.optim.Adam, lr=1e-2)

    logger = VisdomLogger("train", env=JOB)
    # (callback, feature, freq) registered in this order.
    hook_table = [
        (lambda logger, data: logger.step(), "energy", 16),
        (lambda logger, data: logger.plot(data["energy"], "free_energy"), "energy", 100),
        (lambda logger, data: graph.plot_estimates(logger), "epoch", 32),
        (lambda logger, data: graph.update_paths(logger), "epoch", 32),
    ]
    for callback, feature_name, frequency in hook_table:
        logger.add_hook(callback, feature=feature_name, freq=frequency)

    graph.plot_estimates(logger)
    graph.plot_paths(
        logger,
        dest_tasks=[tasks.normal, tasks.principal_curvature, tasks.depth_zbuffer],
        show_images=False,
    )

    for epoch in range(4000):
        logger.update("epoch", epoch)
        with torch.no_grad():
            free_energy = graph.free_energy(sample=16)
            graph.averaging_step()
        logger.update("energy", free_energy)
def main(
    loss_config="conservative_full",
    mode="standard",
    visualize=False,
    pretrained=True,
    finetuned=False,
    fast=False,
    batch_size=None,
    ood_batch_size=None,
    subset_size=None,
    cont=f"{BASE_DIR}/shared/results_LBP_multipercep_lat_winrate_8/graph.pth",
    cont_gan=None,
    pre_gan=None,
    max_epochs=800,
    use_baseline=False,
    **kwargs,
):
    """Sweep the rgb blur radius and log how perceptual energy tracks MSE.

    Loads graph weights from ``cont``. For each radius in
    ``np.arange(0, 10, 0.5)`` it sets ``tasks.rgb.blur_radius`` and evaluates
    ``energy_loss(..., reduce=False)`` over ``subset_size // batch_size``
    batches of the ``train_subset`` reality. It then logs a scatter of the
    unreduced MSE values against the mean of all ``'percep'`` loss terms, plus
    running mean/std curves per radius. ``tasks.rgb.blur_radius`` is restored
    afterwards, even on error.

    Args:
        loss_config, mode, **kwargs: forwarded to ``get_energy_loss``.
        finetuned: forwarded to ``TaskGraph``.
        fast: smaller default batch size (4 instead of 64); forwarded to
            ``load_train_val``.
        batch_size: overrides the default batch size.
        subset_size: size of the ``train_subset`` split; required.
        cont: path of the graph weights to load.
        visualize, pretrained, ood_batch_size, cont_gan, pre_gan, max_epochs,
            use_baseline: accepted for interface compatibility; unused here.

    Raises:
        ValueError: if ``subset_size`` is None or smaller than the batch size.
            Otherwise the sweep would crash late or average empty arrays into NaN.
    """
    # CONFIG
    batch_size = batch_size or (4 if fast else 64)
    if subset_size is None:
        raise ValueError("subset_size is required for the blur sweep")
    n_batches = subset_size // batch_size
    if n_batches == 0:
        raise ValueError(
            f"subset_size ({subset_size}) must be at least batch_size ({batch_size})"
        )
    energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

    # DATA LOADING
    train_dataset, val_dataset, _, _ = load_train_val(
        energy_loss.get_tasks("train"),
        energy_loss.get_tasks("val"),
        batch_size=batch_size,
        fast=fast,
    )
    train_subset_dataset, _, _, _ = load_train_val(
        energy_loss.get_tasks("train_subset"),
        batch_size=batch_size,
        fast=fast,
        subset_size=subset_size,
    )
    test_set = load_test(energy_loss.get_tasks("test"))

    train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)
    train_subset = RealityTask(
        "train_subset", train_subset_dataset, batch_size=batch_size, shuffle=True
    )
    val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
    test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))

    # GRAPH
    realities = [train, val, test, train_subset]
    graph = TaskGraph(
        tasks=energy_loss.tasks + realities,
        finetuned=finetuned,
        freeze_list=energy_loss.freeze_list,
    )
    graph.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
    graph.load_weights(cont)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    energy_loss.logger_hooks(logger)

    # BLUR SWEEP
    energy_mean_by_blur, energy_std_by_blur = [], []
    mse_mean_by_blur, mse_std_by_blur = [], []
    # tasks.rgb is shared module state: remember its blur so it can be restored.
    original_blur = getattr(tasks.rgb, "blur_radius", None)
    try:
        for blur_size in np.arange(0, 10, 0.5):
            # Radius 0 is expressed as None (presumably "no blur").
            tasks.rgb.blur_radius = blur_size if blur_size > 0 else None
            train_subset.step()

            energy_losses, mse_losses = [], []
            for _ in range(n_batches):
                with torch.no_grad():
                    losses = energy_loss(graph, realities=[train_subset], reduce=False)
                    # Energy per element = mean over all perceptual loss terms.
                    all_perceps = np.stack([
                        losses[loss_name].data.cpu().numpy()
                        for loss_name in losses
                        if 'percep' in loss_name
                    ])
                    energy_losses.extend(all_perceps.mean(0))
                    mse_losses.extend(losses['mse'].data.cpu().numpy())
                train_subset.step()

            mse_arr = np.array(mse_losses)
            energy_arr = np.array(energy_losses)
            logger.text(
                f'blur_radius = {blur_size}, mse = {mse_arr.mean()}, energy = {energy_arr.mean()}'
            )
            logger.scatter(
                mse_arr,
                energy_arr,
                f'mse_energy, blur = {blur_size}',
                opts={'xlabel': 'mse', 'ylabel': 'energy'},
            )

            energy_mean_by_blur.append(energy_arr.mean())
            energy_std_by_blur.append(np.std(energy_arr))
            mse_mean_by_blur.append(mse_arr.mean())
            mse_std_by_blur.append(np.std(mse_arr))
            logger.plot(energy_mean_by_blur, 'energy_mean_by_blur')
            logger.plot(energy_std_by_blur, 'energy_std_by_blur')
            logger.plot(mse_mean_by_blur, 'mse_mean_by_blur')
            logger.plot(mse_std_by_blur, 'mse_std_by_blur')
    finally:
        # Do not leak the last swept radius into later users of tasks.rgb.
        tasks.rgb.blur_radius = original_blur
torch.save((perturbation.data, images.data, targets), f"{output_path}/{k}.pth") if __name__ == "__main__": model = DataParallelModel( DecodingModel(n=DIST_SIZE, distribution=transforms.training)) params = itertools.chain(model.module.classifier.parameters(), model.module.features[-1].parameters()) optimizer = torch.optim.Adam(params, lr=2.5e-3) init_data("data/amnesia") logger = VisdomLogger("train", server="35.230.67.129", port=8000, env=JOB) logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20) logger.add_hook(lambda data: logger.plot(data, "train_loss"), feature="loss", freq=50) logger.add_hook(lambda data: logger.plot(data, "train_bits"), feature="bits", freq=50) logger.add_hook( lambda x: model.save("output/train_test.pth", verbose=True), feature="epoch", freq=100) model.save("output/train_test.pth", verbose=True) files = glob.glob(f"data/amnesia/*.pth") for i, save_file in enumerate( random.choice(files) for i in range(0, 2701)):
from torch.autograd import Variable from models import DecodingModel, DataParallelModel from torchvision import models from logger import Logger, VisdomLogger from utils import * import IPython import transforms # LOGGING logger = VisdomLogger("encoding", server="35.230.67.129", port=8000, env=JOB) logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20) logger.add_hook(lambda x: logger.plot(x, "Encoding Loss", opts=dict(ymin=0)), feature="loss", freq=50) """ Computes the changed images, given a a specified perturbation, standard deviation weighting, and epsilon. """ def compute_changed_images(images, perturbation, std_weights, epsilon=EPSILON): perturbation_w2 = perturbation * std_weights perturbation_zc = ( perturbation_w2 / perturbation_w2.view(perturbation_w2.shape[0], -1) .norm(2, dim=1, keepdim=True)