def run_eval_suite(name, dest_task=tasks.normal, graph_file=None, model_file=None,
                   logger=None, sample=800, show_images=False, old=False):

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif model_file is not None:
        # model = DataParallelModel.load(UNet(downsample=5).cuda(), model_file)
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.normal, dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)
    dataset = ValidationMetrics("almena", dest_task=dest_task)
    result = dataset.evaluate(model, sample=sample)
    logger.text(name + ": " + str(result))
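# A minimal usage sketch for run_eval_suite (hedged): "graph.pth" and the
# "eval" Visdom environment name are placeholder assumptions for illustration,
# not files or names defined in this repo.
eval_logger = VisdomLogger("eval", env=JOB)
run_eval_suite("consistency_normals", dest_task=tasks.normal,
               graph_file="graph.pth", logger=eval_logger, sample=800)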
def load_model(self):
    if self.model is None:
        if self.path is not None:
            print("Loading pretrained model from:", self.path)
            self.model = DataParallelModel.load(self.model_type().to(DEVICE), self.path)
        else:
            self.model = self.model_type()
            if isinstance(self.model, nn.Module):
                self.model = DataParallelModel(self.model)
    return self.model
def load_model(self):
    if self.model is None:
        if self.path is not None:
            self.model = DataParallelModel.load(self.model_type().to(DEVICE), self.path)
            # if optimizer:
            #     self.model.compile(torch.optim.Adam, lr=3e-5, weight_decay=2e-6, amsgrad=True)
        else:
            self.model = self.model_type().to(DEVICE)
            if isinstance(self.model, nn.Module):
                self.model = DataParallelModel(self.model)
    return self.model
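# Hedged illustration of the lazy-loading contract shared by both load_model
# variants above: the first call builds (and, if self.path is set, restores)
# the network; subsequent calls return the cached DataParallelModel unchanged.
transfer = Transfer(src_task=tasks.rgb, dest_task=tasks.normal)
model_a = transfer.load_model()
model_b = transfer.load_model()
assert model_a is model_b  # cached in self.model after the first call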
def run_viz_suite(name, data, dest_task=tasks.depth_zbuffer, graph_file=None,
                  model_file=None, logger=None, old=False, multitask=False,
                  percep_mode=None):

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        print("loading model from checkpoint:", model_file)
        # model = DataParallelModel.load(UNet(downsample=5).cuda(), model_file)
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # DATA LOADING 1
    results = model.predict(data)[:, -3:].clamp(min=0, max=1)
    if results.shape[1] == 1:
        results = torch.cat([results] * 3, dim=1)

    if percep_mode:
        percep_model = Transfer(src_task=dest_task, dest_task=tasks.normal).load_model()
        percep_model.eval()
        eval_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(results),
            batch_size=16, num_workers=16, shuffle=False, pin_memory=True)
        final_preds = []
        for preds, in eval_loader:
            final_preds += [percep_model.forward(preds[:, -3:])]
        results = torch.cat(final_preds, dim=0)

    return results
def run_perceptual_eval_suite(name, intermediate_task=tasks.normal, dest_task=tasks.normal,
                              graph_file=None, model_file=None, logger=None, sample=800,
                              show_images=False, old=False, perceptual_transfer=None,
                              multitask=False):

    if perceptual_transfer is None:
        percep_model = Transfer(src_task=intermediate_task, dest_task=dest_task).load_model()

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, intermediate_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, intermediate_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        print("running multitask")
        model = DataParallelModel.load(UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # model = DataParallelModel.load(UNet(downsample=5).cuda(), model_file)
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb, dest_task=intermediate_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)
    dataset = ValidationMetrics("almena", dest_task=dest_task)
    result = dataset.evaluate_with_percep(model, sample=sample, percep_model=percep_model)
    logger.text(name + ": " + str(result))
def load_model(self, parents=None):
    parents = parents or []  # avoid a mutable default argument
    model_path = get_finetuned_model_path(parents + [self])
    if model_path not in self.cached_models:
        if not os.path.exists(model_path):
            self.cached_models[model_path] = super().load_model()
        else:
            self.cached_models[model_path] = DataParallelModel.load(
                self.model_type().cuda(), model_path)
    self.model = self.cached_models[model_path]
    return self.model
def load_model(self, parents=None):
    parents = parents or []  # avoid a mutable default argument
    model_path = get_finetuned_model_path(parents + [self])
    if model_path not in self.cached_models:
        if not os.path.exists(model_path):
            print(f"{model_path} not found, loading pretrained")
            self.cached_models[model_path] = super().load_model()
        else:
            print(f"{model_path} found, loading finetuned")
            self.cached_models[model_path] = DataParallelModel.load(
                self.model_type().cuda(), model_path)
        print()
    self.model = self.cached_models[model_path]
    return self.model
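# Hedged usage sketch of the finetuned/pretrained fallback above. It assumes
# the functional_transfers entries (e.g. `n`, the transfer used elsewhere in
# this repo) are instances of this finetuned-transfer class; if not, treat
# this as purely illustrative.
transfer = functional_transfers.n
model = transfer.load_model(parents=[])  # finetuned checkpoint if present on disk,
                                         # otherwise falls back to super().load_model()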
def main(loss_config="gt_mse", mode="standard", pretrained=False, batch_size=64, **kwargs):

    # MODEL
    # model = DataParallelModel.load(UNet().cuda(), "standardval_rgb2normal_baseline.pth")
    model = functional_transfers.n.load_model() if pretrained else DataParallelModel(UNet())
    model.compile(torch.optim.Adam, lr=(3e-5 if pretrained else 3e-4),
                  weight_decay=2e-6, amsgrad=True)
    scheduler = MultiStepLR(model.optimizer, milestones=[5 * i + 1 for i in range(0, 80)], gamma=0.95)

    # FUNCTIONAL LOSS
    functional = get_functional_loss(config=loss_config, mode=mode, model=model, **kwargs)
    print(functional)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda logger, data: model.save(f"{RESULTS_DIR}/model.pth"), feature="loss", freq=400)
    logger.add_hook(lambda logger, data: scheduler.step(), feature="epoch", freq=1)
    functional.logger_hooks(logger)

    # DATA LOADING
    ood_images = load_ood(ood_path=f"{BASE_DIR}/data/ood_images/")
    train_loader, val_loader, train_step, val_step = load_train_val(
        [tasks.rgb, tasks.normal], batch_size=batch_size)
        # train_buildings=["almena"], val_buildings=["almena"])
    test_set, test_images = load_test("rgb", "normal")
    logger.images(test_images, "images", resize=128)
    logger.images(torch.cat(ood_images, dim=0), "ood_images", resize=128)

    # TRAINING
    for epochs in range(0, 800):
        preds_name = "start_preds" if epochs == 0 and pretrained else "preds"
        ood_name = "start_ood" if epochs == 0 and pretrained else "ood"
        plot_images(model, logger, test_set, dest_task="normal",
                    ood_images=ood_images, loss_models=functional.plot_losses,
                    preds_name=preds_name, ood_name=ood_name)
        logger.update("epoch", epochs)
        logger.step()

        train_set = itertools.islice(train_loader, train_step)
        val_set = itertools.islice(val_loader, val_step)
        val_metrics = model.predict_with_metrics(val_set, loss_fn=functional, logger=logger)
        train_metrics = model.fit_with_metrics(train_set, loss_fn=functional, logger=logger)
        functional.logger_update(logger, train_metrics, val_metrics)
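# The MultiStepLR above multiplies the learning rate by 0.95 at epochs
# 1, 6, 11, ..., 396 (the scheduler steps once per epoch via the hook).
# A quick standalone check of the resulting schedule, plain Python:
milestones = [5 * i + 1 for i in range(0, 80)]
lr = 3e-5  # pretrained starting LR
for epoch in range(800):
    if epoch in milestones:
        lr *= 0.95
print(lr)  # 3e-5 * 0.95**80 ≈ 4.9e-7 after the last milestone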
def main(src_task, dest_task, fast=False):

    # src_task, dest_task = get_task(src_task), get_task(dest_task)
    model = DataParallelModel(get_model(src_task, dest_task).cuda())
    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=25)
    logger.add_hook(partial(jointplot, loss_type="mse_loss"), feature="val_mse_loss", freq=1)
    logger.add_hook(lambda logger, data: model.save(
        f"{RESULTS_DIR}/{src_task.name}2{dest_task.name}.pth"), feature="loss", freq=400)

    # DATA LOADING
    train_loader, val_loader, train_step, val_step = load_train_val(
        [src_task, dest_task], batch_size=48,
        train_buildings=["almena"], val_buildings=["almena"])
    if fast:  # shortcut for quick debugging runs
        train_step, val_step = 5, 5
    test_set, test_images = load_test(src_task, dest_task)
    src_task.plot_func(test_images, "images", logger, resize=128)

    for epochs in range(0, 800):
        logger.update("epoch", epochs)
        plot_images(model, logger, test_set, dest_task, show_masks=True)

        train_set = itertools.islice(train_loader, train_step)
        val_set = itertools.islice(val_loader, val_step)

        (train_mse_data,) = model.fit_with_metrics(train_set, loss_fn=dest_task.norm, logger=logger)
        logger.update("train_mse_loss", train_mse_data)
        (val_mse_data,) = model.predict_with_metrics(val_set, loss_fn=dest_task.norm, logger=logger)
        logger.update("val_mse_loss", val_mse_data)
def encode(image, out, target=binary.str(binary.random(TARGET_SIZE)), n=96,
           model=None, max_iter=500, use_weighting=True, perturbation_out=None):
    # NB: the default `target` is sampled once, at function-definition time
    if not isinstance(model, DecodingModel):
        model = DataParallelModel(
            DecodingModel.load(distribution=transforms.encoding, n=n, weights_file=model))

    image = im.torch(im.load(image)).unsqueeze(0)
    print("Target: ", target)
    target = binary.parse(str(target))
    encoded = encode_binary(image, [target], model, n=n, verbose=True,
                            max_iter=max_iter, use_weighting=use_weighting)
    im.save(im.numpy(encoded.squeeze()), file=out)

    if perturbation_out is not None:
        im.save(im.numpy((image - encoded).squeeze()), file=perturbation_out)
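# Hedged usage sketch for encode: all file paths below are placeholder
# assumptions. A random TARGET_SIZE-bit code is embedded into the image, and
# the additive perturbation is saved separately for inspection.
encode("images/house.jpg", "output/house_encoded.jpg",
       model="output/train_test.pth", max_iter=500,
       perturbation_out="output/house_perturbation.jpg")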
def evaluate(model, image, target, test_transforms=False):
    if not isinstance(model, BaseModel):
        model = DataParallelModel(
            DecodingModel.load(distribution=transforms.identity, n=1, weights_file=model))

    image = im.torch(im.load(image)).unsqueeze(0)
    target = binary.parse(str(target))
    prediction = model(image).mean(dim=1).squeeze().cpu().data.numpy()
    prediction = binary.get(prediction)
    # print(f"Target: {binary.str(target)}, Prediction: {binary.str(prediction)}, "
    #       f"Distance: {binary.distance(target, prediction)}")

    if test_transforms:
        sweep(image, [target], model, transform=transforms.rotate, name="eval", samples=60)
def run_viz_suite(name, data_loader, dest_task=tasks.depth_zbuffer, graph_file=None,
                  model_file=None, old=False, multitask=False, percep_mode=None,
                  downsample=6, out_channels=3, final_task=tasks.normal, oldpercep=False):

    extra_task = [final_task] if percep_mode else []

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task] + extra_task, pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # downsample = 5 or 6
        print("loading main model")
        # model = DataParallelModel.load(UNetReshade(downsample=downsample, out_channels=out_channels).cuda(), model_file)
        model = DataParallelModel.load(
            UNet(downsample=downsample, out_channels=out_channels).cuda(), model_file)
        # model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = DummyModel(Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model())

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # DATA LOADING 1
    results = []
    final_preds = []

    if percep_mode:
        print("Loading percep model...")
        if graph_file is not None and not oldpercep:
            percep_model = graph.edge(dest_task, final_task).load_model()
            percep_model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)
        else:
            percep_model = Transfer(src_task=dest_task, dest_task=final_task).load_model()
        percep_model.eval()

    print("Converting...")
    for data, in data_loader:
        preds = model.predict_on_batch(data)[:, -3:].clamp(min=0, max=1)
        results.append(preds.detach().cpu())
        if percep_mode:
            try:
                final_preds += [percep_model.forward(preds[:, -3:]).detach().cpu()]
            except RuntimeError:
                # single-channel predictions: tile to 3 channels for the percep model
                preds = torch.cat([preds] * 3, dim=1)
                final_preds += [percep_model.forward(preds[:, -3:]).detach().cpu()]

    results = torch.cat(final_preds, dim=0) if percep_mode else torch.cat(results, dim=0)
    return results
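# Hedged usage sketch for this run_viz_suite variant: wrap a stacked image
# tensor in a DataLoader (the loop above expects 1-tuples, which TensorDataset
# yields), then optionally map predictions through the dest_task->final_task
# perceptual model. "graph.pth" and the random batch are placeholders.
images = torch.rand(8, 3, 256, 256)  # stand-in batch for illustration
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(images), batch_size=4, shuffle=False)
preds = run_viz_suite("depth_viz", loader, dest_task=tasks.depth_zbuffer,
                      graph_file="graph.pth", percep_mode=True,
                      final_task=tasks.normal)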
def save_outputs(img_path, output_file_name):
    img = Image.open(img_path)
    img_tensor = trans_totensor(img)[:3].unsqueeze(0).to(DEVICE)
    if distortion is not None:
        trans_topil(img_tensor[0].clamp(min=0, max=1).cpu()).save(
            args.output_path + '/' + 'distorted_input.png')

    # compute baseline and consistency output
    # for type in ['baseline', 'consistency']:
    #     path = root_dir + 'rgb2' + args.task + '_' + type + '.pth'
    #     model_state_dict = torch.load(path, map_location=map_location)
    #     model.load_state_dict(model_state_dict)
    #     baseline_output = model(img_tensor).clamp(min=0, max=1)
    #     trans_topil(baseline_output[0]).save(args.output_path + '/' + output_file_name + '_' + args.task + '_' + type + '.png')

    # compute all 8 path outputs
    all_models_state_dict = torch.load(path, map_location=map_location)

    direct_model = WrapperModel(DataParallelModel(models[0].to(DEVICE)))
    direct_model.load_state_dict(all_models_state_dict["('rgb', '" + target_task + "')"])
    direct_output = direct_model(img_tensor)  # .clamp(min=0, max=1)

    emboss_model = WrapperModel(DataParallelModel(models[1].to(DEVICE)))
    emboss_model.load_state_dict(all_models_state_dict["('emboss4d', '" + target_task + "')"])
    emboss_output = emboss_model(emboss4d_kernel(img_tensor))  # .clamp(min=0, max=1)

    grey_model = WrapperModel(DataParallelModel(models[2].to(DEVICE)))
    grey_model.load_state_dict(all_models_state_dict["('grey', '" + target_task + "')"])
    grey_output = grey_model(greyscale(img_tensor))  # .clamp(min=0, max=1)

    laplace_model = WrapperModel(DataParallelModel(models[3].to(DEVICE)))
    laplace_model.load_state_dict(all_models_state_dict["('laplace_edges', '" + target_task + "')"])
    laplace_output = laplace_model(laplace_kernel(img_tensor))  # .clamp(min=0, max=1)

    gauss_model = WrapperModel(DataParallelModel(models[4].to(DEVICE)))
    gauss_model.load_state_dict(all_models_state_dict["('gauss', '" + target_task + "')"])
    gauss_output = gauss_model(gauss_kernel(img_tensor))  # .clamp(min=0, max=1)

    sobel_model = WrapperModel(DataParallelModel(models[5].to(DEVICE)))
    sobel_model.load_state_dict(all_models_state_dict["('sobel_edges', '" + target_task + "')"])
    sobel_output = sobel_model(sobel_kernel(img_tensor))  # .clamp(min=0, max=1)

    wav_model = WrapperModel(DataParallelModel(models[6].to(DEVICE)))
    wav_model.load_state_dict(all_models_state_dict["('wav', '" + target_task + "')"])
    wav_output = wav_model(wav_kernel(img_tensor))  # .clamp(min=0, max=1)

    sharp_model = WrapperModel(DataParallelModel(models[7].to(DEVICE)))
    sharp_model.load_state_dict(all_models_state_dict["('sharp', '" + target_task + "')"])
    sharp_output = sharp_model(sharp_kernel(img_tensor))  # .clamp(min=0, max=1)

    merged_outputs = torch.cat(
        (direct_output, emboss_output, grey_output, laplace_output,
         gauss_output, sobel_output, wav_output, sharp_output), dim=1)

    npaths = 8
    di_ind = np.diag_indices(npaths)
    nchannels = int(merged_outputs.size(1) // (npaths * 2))
    inds = np.arange(npaths) * 2 + 1  # indices of channel-0 sigmas
    SQRT2 = math.sqrt(2)
    for i in range(nchannels):
        inds_ = nchannels * inds + i
        # convert from log(b) of a Laplacian to sigma
        merged_outputs[:, inds_] = merged_outputs[:, inds_].exp() * SQRT2

    ######## get sig avg weights
    if nchannels == 1:
        muind = inds - 1
        sigind = inds
    else:
        muind = np.array([0, 1, 2, 6, 7, 8, 12, 13, 14, 18, 19, 20,
                          24, 25, 26, 30, 31, 32, 36, 37, 38, 42, 43, 44])  # 8 paths
        sigind = muind + 3
    sig_avg_weights = torch.cuda.FloatTensor(merged_outputs[:, :npaths].size()).fill_(0.0)
    total_inv_sig = (1. / merged_outputs[:, sigind].pow(2)).sum(1)
    for i in range(npaths):
        sig_avg_weights[:, i] = (
            1. / merged_outputs[:, 2 * i * nchannels + nchannels:2 * (i + 1) * nchannels].pow(2)
        ).sum(1) / total_inv_sig
    weights = sig_avg_weights

    merged_mu = torch.cuda.FloatTensor(merged_outputs[:, :nchannels].size()).fill_(0.0)
    merged_sig = torch.cuda.FloatTensor(merged_outputs[:, :nchannels].size()).fill_(0.0)
    for i in range(nchannels):
        inds_ = i + nchannels * inds
        ## compute correl mat (diagonal covariance across paths)
        cov_mat = torch.cuda.FloatTensor(merged_mu.size(0), merged_mu.size(-1),
                                         merged_mu.size(-1), int(npaths), int(npaths)).fill_(0.0)
        cov_mat[:, :, :, di_ind[0], di_ind[1]] = merged_outputs[:, inds_].pow(2).permute(0, 2, 3, 1)
        ## merge
        merged_mu[:, i] = (merged_outputs[:, inds_ - nchannels] * weights).sum(dim=1)
        weights = weights.permute(0, 2, 3, 1)
        merged_sig[:, i] = (weights.unsqueeze(-2) @ cov_mat
                            @ weights.unsqueeze(-1)).squeeze(-1).squeeze(-1).sqrt()
        weights = weights.permute(0, 3, 1, 2)

    if nchannels == 1:
        var_epistemic = merged_outputs[:, [0, 2, 4, 6, 8, 10, 12, 14]].var(1, keepdim=True)
    else:
        var_epistemic_r = merged_outputs[:, [0, 6, 12, 18, 24, 30, 36, 42]].var(1, keepdim=True)
        var_epistemic_g = merged_outputs[:, [1, 7, 13, 19, 25, 31, 37, 43]].var(1, keepdim=True)
        var_epistemic_b = merged_outputs[:, [2, 8, 14, 20, 26, 32, 38, 44]].var(1, keepdim=True)
        var_epistemic = torch.cat((var_epistemic_r, var_epistemic_g, var_epistemic_b), dim=1)

    baseline_mu = direct_output[:, :nchannels]
    baseline_sig = direct_output[:, nchannels:]

    trans_topil(merged_mu[0].clamp(min=0, max=1).cpu()).save(
        args.output_path + '/' + output_file_name + '_' + args.task + '_' + 'ours_mean.png')
    # trans_topil((merged_sig[0].exp() * SQRT2).cpu()).save(args.output_path + '/' + output_file_name + '_' + args.task + '_' + 'ours_var.png')
    trans_topil((var_epistemic[0].sqrt() * SQRT2).clamp(min=0, max=1).cpu()).save(
        args.output_path + '/' + output_file_name + '_' + args.task + '_' + 'ours_sig.png')
    trans_topil(baseline_mu[0].clamp(min=0, max=1).cpu()).save(
        args.output_path + '/' + output_file_name + '_' + args.task + '_' + 'baseline_mean.png')
    trans_topil((baseline_sig[0].exp() * SQRT2).cpu()).save(
        args.output_path + '/' + output_file_name + '_' + args.task + '_' + 'baseline_sig.png')

    ### GET DEEP ENS RESULTS ###
    all_models_state_dict = torch.load(path_deepens, map_location=map_location)
    direct_model = WrapperModel(DataParallelModel(models_deepens.to(DEVICE)))

    # run the 8 ensemble members sequentially, reusing the same wrapper
    ens_outputs = []
    for member in range(1, 9):
        direct_model.load_state_dict(
            all_models_state_dict["('rgb', '" + target_task + str(member) + "_ens')"])
        ens_outputs.append(direct_model(img_tensor))  # .clamp(min=0, max=1)

    mu_ens = sum(o[:, :nchannels] for o in ens_outputs) / 8.0
    merged_outputs_ens = torch.cat(ens_outputs, dim=1)

    if nchannels == 1:
        var_epistemic_ens = merged_outputs_ens[:, [0, 2, 4, 6, 8, 10, 12, 14]].var(1, keepdim=True)
    else:
        var_epistemic_ens_r = merged_outputs_ens[:, [0, 6, 12, 18, 24, 30, 36, 42]].var(1, keepdim=True)
        var_epistemic_ens_g = merged_outputs_ens[:, [1, 7, 13, 19, 25, 31, 37, 43]].var(1, keepdim=True)
        var_epistemic_ens_b = merged_outputs_ens[:, [2, 8, 14, 20, 26, 32, 38, 44]].var(1, keepdim=True)
        var_epistemic_ens = torch.cat(
            (var_epistemic_ens_r, var_epistemic_ens_g, var_epistemic_ens_b), dim=1)

    trans_topil(mu_ens[0].clamp(min=0, max=1).cpu()).save(
        args.output_path + '/' + output_file_name + '_' + args.task + '_' + 'deepens_mean.png')
    trans_topil((var_epistemic_ens[0].sqrt() * SQRT2).clamp(min=0, max=1).cpu()).save(
        args.output_path + '/' + output_file_name + '_' + args.task + '_' + 'deepens_sig.png')
for k, files in tqdm.tqdm(list(enumerate(batch(image_files, batch_size=BATCH_SIZE))), ncols=50):
    images = im.stack([im.load(img_file) for img_file in files]).detach()
    perturbation = nn.Parameter(0.03 * torch.randn(images.size()).to(DEVICE) + 0.0)
    targets = [binary.random(n=TARGET_SIZE) for i in range(len(images))]
    torch.save((perturbation.data, images.data, targets), f"{output_path}/{k}.pth")


if __name__ == "__main__":
    model = DataParallelModel(DecodingModel(n=DIST_SIZE, distribution=transforms.training))
    params = itertools.chain(model.module.classifier.parameters(),
                             model.module.features[-1].parameters())
    optimizer = torch.optim.Adam(params, lr=2.5e-3)

    init_data("data/amnesia")

    logger = VisdomLogger("train", server="35.230.67.129", port=8000, env=JOB)
    logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20)
    logger.add_hook(lambda data: logger.plot(data, "train_loss"), feature="loss", freq=50)
    logger.add_hook(lambda data: logger.plot(data, "train_bits"), feature="bits", freq=50)
    logger.add_hook(lambda x: model.save("output/train_test.pth", verbose=True),
                    feature="epoch", freq=100)  # feature/freq are assumed; the original call is truncated here
def test_transforms(model=None, image_files=VAL_FILES, name="test", max_iter=250):

    if not isinstance(model, BaseModel):
        print(f"Loading model from {model}")
        model = DataParallelModel(
            DecodingModel.load(distribution=transforms.new_dist,
                               n=ENCODING_DIST_SIZE, weights_file=model))

    images = [im.load(image) for image in image_files]
    images = im.stack(images)
    targets = [binary.random(n=TARGET_SIZE) for _ in range(0, len(images))]

    model.eval()
    encoded_images = encode_binary(images, targets, model, n=ENCODING_DIST_SIZE,
                                   verbose=True, max_iter=max_iter, use_weighting=True)

    logger.images(images, "original_images", resize=196)
    logger.images(encoded_images, "encoded_images", resize=196)

    for img, encoded_im, filename, target in zip(images, encoded_images, image_files, targets):
        im.save(im.numpy(img),
                file=f"output/_{binary.str(target)}_original_{filename.split('/')[-1]}")
        im.save(im.numpy(encoded_im),
                file=f"output/_{binary.str(target)}_encoded_{filename.split('/')[-1]}")

    model.set_distribution(transforms.identity, n=1)
    predictions = model(encoded_images).mean(dim=1).cpu().data.numpy()
    binary_loss = np.mean([binary.distance(x, y) for x, y in zip(predictions, targets)])

    for transform in [
            transforms.pixilate, transforms.blur, transforms.rotate, transforms.scale,
            transforms.translate, transforms.noise, transforms.crop, transforms.gauss,
            transforms.whiteout, transforms.resize_rect, transforms.color_jitter,
            transforms.jpeg_transform, transforms.elastic, transforms.brightness,
            transforms.contrast, transforms.flip,
    ]:
        sweep(encoded_images, targets, model, transform=transform, name=name, samples=60)

    # Per-transform sweeps with explicit value ranges, superseded by the loop above
    # (each was sweep(encoded_images, targets, model, transform=lambda x, val: ..., name=name, ...)):
    #   rotate        theta     in [-0.6, 0.6],  samples=80, rand_val=False
    #   scale         scale_val in [0.6, 1.4],   samples=50, rand_val=False
    #   translate     radius    in [0.0, 1.0],   samples=50, rand_val=False
    #   noise         intensity in [0.0, 0.1],   samples=30
    #   crop          p         in [0.1, 1.0],   samples=50
    #   gauss         sigma     in [0.3, 4],     samples=50, rand_val=False
    #   whiteout      scale     in [0.02, 0.2],  samples=50, rand_val=False
    #   resize_rect   ratio     in [0.5, 1.5],   samples=50, rand_val=False
    #   color_jitter  jitter    in [0, 0.2],     samples=50
    #   convertToJpeg q         in [10, 100],    samples=50

    logger.update("orig", binary_loss)
    model.set_distribution(transforms.training, n=DIST_SIZE)
    model.train()
def test_transfer(model=None, image_files=VAL_FILES, max_iter=250):

    if not isinstance(model, BaseModel):
        print(f"Loading model from {model}")
        model = DataParallelModel(
            DecodingModel.load(distribution=transforms.encoding,
                               n=ENCODING_DIST_SIZE, weights_file=model))

    images = [im.load(image) for image in image_files]
    images = im.stack(images)
    targets = [binary.random(n=TARGET_SIZE) for _ in range(0, len(images))]
    model.eval()

    transform_list = [
        transforms.rotate, transforms.translate, transforms.scale,
        transforms.resize_rect, transforms.crop, transforms.whiteout,
        transforms.elastic, transforms.motion_blur, transforms.brightness,
        transforms.contrast, transforms.pixilate, transforms.blur,
        transforms.color_jitter, transforms.gauss, transforms.noise,
        transforms.impulse_noise, transforms.flip,
    ]
    labels = [t.__name__ for t in transform_list]
    score_matrix = np.zeros((len(transform_list), len(transform_list)))

    for i, t1 in enumerate(transform_list):
        # encode against t1, then measure decoding error under every t2
        model.set_distribution(lambda x: t1.random(x), n=ENCODING_DIST_SIZE)
        encoded_images = encode_binary(images, targets, model, n=ENCODING_DIST_SIZE,
                                       verbose=True, max_iter=max_iter, use_weighting=True)
        model.set_distribution(transforms.identity, n=1)
        t1_error = sweep(encoded_images, targets, model, transform=t1,
                         name=f"{t1.__name__}", samples=60)
        for j, t2 in enumerate(transform_list):
            if t1.__name__ == t2.__name__:
                score_matrix[i, j] = t1_error
                continue
            t2_error = sweep(encoded_images, targets, model, transform=t2,
                             name=f"{t1.__name__}->{t2.__name__}", samples=60)
            score_matrix[i, j] = t2_error
            print(f"{t1.__name__} --> {t2.__name__}: {t2_error}")

    np.save("labels", labels)
    np.save("score_matrix", score_matrix)
    create_heatmap(score_matrix, labels)
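# Hedged sketch of consuming the artifacts written by test_transfer:
# score_matrix[i, j] is the decoding error when encoding against labels[i]
# and evaluating under labels[j]; the diagonal is the matched case.
import numpy as np

labels = np.load("labels.npy")
score_matrix = np.load("score_matrix.npy")
for i, row in enumerate(score_matrix):
    best = row.argmin()
    print(f"encoded vs {labels[i]}: lowest error under {labels[best]} ({row[best]:.3f})")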