示例#1
0
def run_eval_suite(name,
                   dest_task=tasks.normal,
                   graph_file=None,
                   model_file=None,
                   logger=None,
                   sample=800,
                   show_images=False,
                   old=False):
    """Load a model for ``dest_task`` and score it with ``ValidationMetrics("almena")``.

    The model source is chosen in priority order: a saved ``TaskGraph``
    (``graph_file``), a legacy ``UNetOld`` checkpoint (``old``), a
    ``UNet(downsample=6)`` checkpoint (``model_file``), or else the default
    ``Transfer`` model.

    Args:
        name: label prefixed to the reported result.
        dest_task: task the evaluated model predicts.
        graph_file: path to TaskGraph weights; takes precedence over the rest.
        model_file: checkpoint path used by the ``old`` and UNet branches.
        logger: object with a ``text`` method; when None the result is printed.
        sample: forwarded as ``sample=`` to ``ValidationMetrics.evaluate``.
        show_images: currently unused.
        old: load ``model_file`` into ``UNetOld`` instead of ``UNet``.

    Returns:
        Whatever ``ValidationMetrics.evaluate`` returns.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif model_file is not None:
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        # NOTE(review): the sibling suites use tasks.rgb as the source task here;
        # confirm tasks.normal is intended.
        model = Transfer(src_task=tasks.normal,
                         dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    dataset = ValidationMetrics("almena", dest_task=dest_task)
    # Honour the caller's ``sample`` (it used to be silently replaced by 800).
    result = dataset.evaluate(model, sample=sample)

    message = name + ": " + str(result)
    if logger is not None:
        logger.text(message)
    else:
        print(message)
    return result
示例#2
0
 def load_model(self):
     """Return ``self.model``, constructing and caching it on first call.

     With ``self.path`` set, weights are loaded via ``DataParallelModel.load``
     into a ``self.model_type()`` instance moved to ``DEVICE``. Without a path,
     ``self.model_type()`` is instantiated and wrapped in ``DataParallelModel``
     only when it is an ``nn.Module``.
     """
     if self.model is not None:
         return self.model

     print("Loading pretrained model from:", self.path)
     if self.path is None:
         self.model = self.model_type()
         if isinstance(self.model, nn.Module):
             self.model = DataParallelModel(self.model)
     else:
         fresh = self.model_type().to(DEVICE)
         self.model = DataParallelModel.load(fresh, self.path)
     return self.model
示例#3
0
 def load_model(self):
     """Return ``self.model``, constructing and caching it on first call.

     With ``self.path`` set, weights are loaded via ``DataParallelModel.load``
     into a ``self.model_type()`` instance moved to ``DEVICE``. Without a path,
     ``self.model_type()`` is instantiated; if the result is an ``nn.Module`` it
     is moved to ``DEVICE`` and wrapped in ``DataParallelModel``, otherwise it
     is cached as-is.
     """
     if self.model is None:
         if self.path is not None:
             self.model = DataParallelModel.load(
                 self.model_type().to(DEVICE), self.path)
         else:
             self.model = self.model_type()
             if isinstance(self.model, nn.Module):
                 # Only modules support ``.to``; checking first keeps non-module
                 # model types (which the isinstance test anticipates) usable.
                 self.model = DataParallelModel(self.model.to(DEVICE))
     return self.model
示例#4
0
def run_viz_suite(name,
                  data,
                  dest_task=tasks.depth_zbuffer,
                  graph_file=None,
                  model_file=None,
                  logger=None,
                  old=False,
                  multitask=False,
                  percep_mode=None):
    """Predict ``dest_task`` for ``data`` and optionally re-project it to normals.

    Args:
        name: not used by this function (kept for call-site compatibility).
        data: input passed to ``model.predict``.
        dest_task: task predicted from rgb.
        graph_file: TaskGraph weights; takes precedence over the other sources.
        model_file: checkpoint for the ``old`` / ``multitask`` / UNet branches.
        logger: not used by this function.
        old: load ``model_file`` into ``UNetOld``.
        multitask: load ``model_file`` into ``UNet(downsample=5, out_channels=6)``.
        percep_mode: when truthy, feed the predictions through a
            ``dest_task -> tasks.normal`` Transfer model and return its output.

    Returns:
        The last (up to 3) prediction channels clamped to [0, 1], repeated to 3
        channels when single-channel; or the perceptual model's output.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    results = model.predict(data)[:, -3:].clamp(min=0, max=1)
    if results.shape[1] == 1:
        results = torch.cat([results] * 3, dim=1)

    if percep_mode:
        percep_model = Transfer(src_task=dest_task,
                                dest_task=tasks.normal).load_model()
        percep_model.eval()
        # NOTE(review): pin_memory / multi-worker loading assume ``results`` is a
        # CPU tensor — confirm what ``model.predict`` returns.
        eval_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(results),
            batch_size=16,
            num_workers=16,
            shuffle=False,
            pin_memory=True)
        final_preds = []
        # Pure inference: without no_grad the autograd graph of every batch is
        # retained inside ``final_preds`` until the result is dropped.
        with torch.no_grad():
            for preds, in eval_loader:
                final_preds.append(percep_model.forward(preds[:, -3:]))
        results = torch.cat(final_preds, dim=0)

    return results
示例#5
0
def run_perceptual_eval_suite(name,
                              intermediate_task=tasks.normal,
                              dest_task=tasks.normal,
                              graph_file=None,
                              model_file=None,
                              logger=None,
                              sample=800,
                              show_images=False,
                              old=False,
                              perceptual_transfer=None,
                              multitask=False):
    """Evaluate an rgb -> ``intermediate_task`` model through a perceptual model.

    The main model is chosen in priority order: TaskGraph weights
    (``graph_file``), ``UNetOld`` (``old``), a 6-channel multitask ``UNet``
    (``multitask``), ``UNet(downsample=6)`` (``model_file``), else the default
    ``Transfer``. It is scored by ``ValidationMetrics("almena").evaluate_with_percep``.

    Args:
        name: label prefixed to the reported result.
        intermediate_task: task predicted by the main model.
        dest_task: target task of the perceptual model and of the metrics.
        graph_file, model_file, old, multitask: select the main model (above).
        logger: object with a ``text`` method; when None the result is printed.
        sample: forwarded as ``sample=`` to ``evaluate_with_percep``.
        show_images: currently unused.
        perceptual_transfer: None to use the default
            ``intermediate_task -> dest_task`` Transfer; an object with a
            ``load_model`` method; or an already-loaded model.

    Returns:
        Whatever ``evaluate_with_percep`` returns.
    """
    if perceptual_transfer is None:
        percep_model = Transfer(src_task=intermediate_task,
                                dest_task=dest_task).load_model()
    elif hasattr(perceptual_transfer, "load_model"):
        # Transfer-like object: materialise its model.
        percep_model = perceptual_transfer.load_model()
    else:
        # Assume the caller passed a ready model (previously this path raised
        # NameError because ``percep_model`` was never bound).
        percep_model = perceptual_transfer

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, intermediate_task],
                          pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, intermediate_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        print('running multitask')
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb,
                         dest_task=intermediate_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    dataset = ValidationMetrics("almena", dest_task=dest_task)
    result = dataset.evaluate_with_percep(model,
                                          sample=sample,
                                          percep_model=percep_model)

    message = name + ": " + str(result)
    if logger is not None:
        logger.text(message)
    else:
        print(message)
    return result
示例#6
0
    def load_model(self, parents=None):
        """Return the model for this transfer chain, memoised by checkpoint path.

        Args:
            parents: transfers preceding ``self`` in the chain (any iterable;
                defaults to none). ``parents + [self]`` is resolved to a
                checkpoint path with ``get_finetuned_model_path``.

        Returns:
            The cached model, also assigned to ``self.model``. If the fine-tuned
            checkpoint file does not exist, the parent class's ``load_model()``
            result is cached under that path instead.
        """
        # ``None`` default avoids a shared mutable ``[]`` default object.
        chain = (list(parents) if parents is not None else []) + [self]
        model_path = get_finetuned_model_path(chain)

        if model_path not in self.cached_models:
            if not os.path.exists(model_path):
                self.cached_models[model_path] = super().load_model()
            else:
                self.cached_models[model_path] = DataParallelModel.load(
                    self.model_type().cuda(), model_path)

        self.model = self.cached_models[model_path]
        return self.model
示例#7
0
    def load_model(self, parents=None):
        """Return the model for this transfer chain, memoised by checkpoint path.

        Args:
            parents: transfers preceding ``self`` in the chain (any iterable;
                defaults to none). ``parents + [self]`` is resolved to a
                checkpoint path with ``get_finetuned_model_path``.

        Returns:
            The cached model, also assigned to ``self.model``. A missing
            fine-tuned checkpoint falls back to the parent class's
            ``load_model()``; which branch was taken is printed.
        """
        # ``None`` default avoids a shared mutable ``[]`` default object.
        chain = (list(parents) if parents is not None else []) + [self]
        model_path = get_finetuned_model_path(chain)

        if model_path not in self.cached_models:
            if not os.path.exists(model_path):
                print(f"{model_path} not found, loading pretrained")
                self.cached_models[model_path] = super().load_model()
            else:
                print(f"{model_path} found, loading finetuned")
                self.cached_models[model_path] = DataParallelModel.load(
                    self.model_type().cuda(), model_path)
                print()
        self.model = self.cached_models[model_path]
        return self.model
示例#8
0
def main(loss_config="gt_mse", mode="standard", pretrained=False, batch_size=64, **kwargs):
    """Train an rgb -> normal UNet under a configurable functional loss.

    Args:
        loss_config: passed to ``get_functional_loss`` as ``config``.
        mode: passed to ``get_functional_loss``.
        pretrained: start from ``functional_transfers.n.load_model()`` with a
            learning rate of 3e-5 instead of a fresh ``UNet`` at 3e-4.
        batch_size: batch size for ``load_train_val``.
        **kwargs: forwarded to ``get_functional_loss``.

    Runs 800 epochs. Each epoch plots test/OOD predictions, validates, then
    trains. The model is saved to ``{RESULTS_DIR}/model.pth`` every 400 "loss"
    updates.
    """
    # MODEL
    model = functional_transfers.n.load_model() if pretrained else DataParallelModel(UNet())
    model.compile(torch.optim.Adam, lr=(3e-5 if pretrained else 3e-4), weight_decay=2e-6, amsgrad=True)
    # gamma=0.95 decay at scheduler steps 1, 6, 11, ...; the scheduler is stepped
    # by the "epoch" hook registered below.
    scheduler = MultiStepLR(model.optimizer, milestones=[5 * i + 1 for i in range(0, 80)], gamma=0.95)

    # FUNCTIONAL LOSS
    functional = get_functional_loss(config=loss_config, mode=mode, model=model, **kwargs)
    print(functional)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)
    logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
    logger.add_hook(lambda logger, data: model.save(f"{RESULTS_DIR}/model.pth"), feature="loss", freq=400)
    logger.add_hook(lambda logger, data: scheduler.step(), feature="epoch", freq=1)
    # Was ``loggers`` — an undefined name that raised NameError here.
    functional.logger_hooks(logger)

    # DATA LOADING
    ood_images = load_ood(ood_path=f'{BASE_DIR}/data/ood_images/')
    train_loader, val_loader, train_step, val_step = load_train_val([tasks.rgb, tasks.normal], batch_size=batch_size)
    # (optionally restrict with train_buildings=["almena"], val_buildings=["almena"])
    test_set, test_images = load_test("rgb", "normal")
    logger.images(test_images, "images", resize=128)
    logger.images(torch.cat(ood_images, dim=0), "ood_images", resize=128)

    # TRAINING
    for epochs in range(0, 800):
        # With a pretrained start, the first epoch's plots get distinct names so
        # the starting predictions stay visible.
        preds_name = "start_preds" if epochs == 0 and pretrained else "preds"
        ood_name = "start_ood" if epochs == 0 and pretrained else "ood"
        plot_images(model, logger, test_set, dest_task="normal", ood_images=ood_images,
            loss_models=functional.plot_losses, preds_name=preds_name, ood_name=ood_name
        )
        logger.update("epoch", epochs)
        logger.step()

        train_set = itertools.islice(train_loader, train_step)
        val_set = itertools.islice(val_loader, val_step)

        val_metrics = model.predict_with_metrics(val_set, loss_fn=functional, logger=logger)
        train_metrics = model.fit_with_metrics(train_set, loss_fn=functional, logger=logger)
        functional.logger_update(logger, train_metrics, val_metrics)
示例#9
0
def main(src_task, dest_task, fast=False):
    """Train a ``src_task -> dest_task`` network on the "almena" building for 800 epochs.

    Args:
        src_task: input task object; must expose ``name`` and ``plot_func``.
        dest_task: output task object; ``dest_task.norm`` is the loss function.
        fast: accepted but not read by this body.
            NOTE(review): the step counts below are always forced to 5 — this
            override was possibly meant to depend on ``fast``; confirm.

    The checkpoint is saved to ``{RESULTS_DIR}/{src}2{dest}.pth`` every 400
    "loss" updates.
    """
    network = DataParallelModel(get_model(src_task, dest_task).cuda())
    network.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # Hook callbacks keep the (logger, data) parameter names of the lambdas
    # they replace and return the wrapped call's value.
    def _advance(logger, data):
        return logger.step()

    def _checkpoint(logger, data):
        return network.save(
            f"{RESULTS_DIR}/{src_task.name}2{dest_task.name}.pth")

    log = VisdomLogger("train", env=JOB)
    log.add_hook(_advance, feature="loss", freq=25)
    log.add_hook(partial(jointplot, loss_type="mse_loss"),
                 feature="val_mse_loss",
                 freq=1)
    log.add_hook(_checkpoint, feature="loss", freq=400)

    # The step counts returned by load_train_val are overridden right below.
    train_loader, val_loader, _, _ = load_train_val(
        [src_task, dest_task],
        batch_size=48,
        train_buildings=["almena"],
        val_buildings=["almena"])
    train_step = val_step = 5
    test_set, test_images = load_test(src_task, dest_task)
    src_task.plot_func(test_images, "images", log, resize=128)

    for epoch in range(800):
        log.update("epoch", epoch)
        plot_images(network, log, test_set, dest_task, show_masks=True)

        train_batches = itertools.islice(train_loader, train_step)
        val_batches = itertools.islice(val_loader, val_step)

        train_mse, = network.fit_with_metrics(train_batches,
                                              loss_fn=dest_task.norm,
                                              logger=log)
        log.update("train_mse_loss", train_mse)
        val_mse, = network.predict_with_metrics(val_batches,
                                                loss_fn=dest_task.norm,
                                                logger=log)
        log.update("val_mse_loss", val_mse)
示例#10
0
def encode(
    image,
    out,
    target=None,
    n=96,
    model=None,
    max_iter=500,
    use_weighting=True,
    perturbation_out=None,
):
    """Embed a binary code into an image and save the encoded result.

    Args:
        image: path of the input image.
        out: path the encoded image is saved to.
        target: code to embed (anything ``binary.parse(str(target))`` accepts).
            Defaults to a fresh random ``TARGET_SIZE``-bit code on every call.
        n: distribution size used when loading the model and when encoding.
        model: a ``DecodingModel``, or a weights path loaded into a
            ``DecodingModel`` (``transforms.encoding``) wrapped in
            ``DataParallelModel``.
        max_iter: optimisation iterations for ``encode_binary``.
        use_weighting: forwarded to ``encode_binary``.
        perturbation_out: if given, path where ``image - encoded`` is saved.
    """
    if target is None:
        # Generated per call. A default expression in the signature would be
        # evaluated once at import, so every call reused the same "random" code.
        target = binary.str(binary.random(TARGET_SIZE))

    if not isinstance(model, DecodingModel):
        model = DataParallelModel(DecodingModel.load(distribution=transforms.encoding, n=n, weights_file=model))
    image = im.torch(im.load(image)).unsqueeze(0)
    print("Target: ", target)
    target = binary.parse(str(target))
    encoded = encode_binary(image, [target], model, n=n, verbose=True, max_iter=max_iter, use_weighting=use_weighting)
    im.save(im.numpy(encoded.squeeze()), file=out)
    if perturbation_out is not None:
        im.save(im.numpy((image - encoded).squeeze()), file=perturbation_out)
示例#11
0
def evaluate(model, image, target, test_transforms=False):
    """Decode the code hidden in ``image`` and optionally sweep rotations.

    Args:
        model: a ``BaseModel``, or a weights path loaded into a
            ``DecodingModel`` with ``transforms.identity`` and ``n=1``.
        image: path of the image to decode.
        target: expected code; parsed via ``binary.parse(str(target))``.
        test_transforms: when True, also run ``sweep`` with
            ``transforms.rotate`` (60 samples) against ``target``.

    Returns:
        The decoded prediction: ``binary.get`` applied to the model output
        averaged over dim 1.
    """
    if not isinstance(model, BaseModel):
        model = DataParallelModel(
            DecodingModel.load(distribution=transforms.identity,
                               n=1,
                               weights_file=model))

    image = im.torch(im.load(image)).unsqueeze(0)
    target = binary.parse(str(target))
    prediction = model(image).mean(dim=1).squeeze().cpu().data.numpy()
    prediction = binary.get(prediction)

    if test_transforms:
        sweep(image, [target],
              model,
              transform=transforms.rotate,
              name="eval",
              samples=60)

    # Previously computed and discarded; returned so callers can use it.
    return prediction
示例#12
0
def run_viz_suite(name,
                  data_loader,
                  dest_task=tasks.depth_zbuffer,
                  graph_file=None,
                  model_file=None,
                  old=False,
                  multitask=False,
                  percep_mode=None,
                  downsample=6,
                  out_channels=3,
                  final_task=tasks.normal,
                  oldpercep=False):
    """Run an rgb -> ``dest_task`` model over ``data_loader`` and collect outputs.

    Args:
        name: not used by this function.
        data_loader: iterable yielding 1-tuples ``(batch,)``.
        dest_task: task predicted from rgb.
        graph_file: TaskGraph weights (also supplies the perceptual edge unless
            ``oldpercep``); takes precedence over the other model sources.
        model_file: checkpoint for the ``old`` / ``multitask`` / UNet branches.
        old: load ``model_file`` into ``UNetOld``.
        multitask: load ``model_file`` into ``UNet(downsample=5, out_channels=6)``.
        percep_mode: when truthy, map predictions through a
            ``dest_task -> final_task`` model and return those outputs instead.
        downsample, out_channels: UNet configuration for the ``model_file`` branch.
        final_task: target task of the perceptual model.
        oldpercep: use the standalone ``Transfer`` perceptual model even when a
            graph is loaded.

    Returns:
        CPU tensor concatenating, along dim 0, the clamped last-3-channel
        predictions or, in percep mode, the perceptual outputs.
    """
    extra_task = [final_task] if percep_mode else []

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task] + extra_task,
                          pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        print('loading main model')
        model = DataParallelModel.load(
            UNet(downsample=downsample, out_channels=out_channels).cuda(),
            model_file)
    else:
        model = DummyModel(
            Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model())

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    if percep_mode:
        print('Loading percep model...')
        if graph_file is not None and not oldpercep:
            percep_model = graph.edge(dest_task, final_task).load_model()
            percep_model.compile(torch.optim.Adam,
                                 lr=3e-4,
                                 weight_decay=2e-6,
                                 amsgrad=True)
        else:
            percep_model = Transfer(src_task=dest_task,
                                    dest_task=final_task).load_model()
        percep_model.eval()

    # Holds either the raw predictions or, in percep mode, only the perceptual
    # outputs (the raw ones are not returned, so they are not kept).
    outputs = []
    print("Converting...")
    for data, in data_loader:
        preds = model.predict_on_batch(data)[:, -3:].clamp(min=0, max=1)
        if not percep_mode:
            outputs.append(preds.detach().cpu())
            continue
        with torch.no_grad():
            try:
                percep_out = percep_model.forward(preds[:, -3:])
            except RuntimeError:
                # Fallback for fewer-than-3-channel predictions: tile channels.
                # With 3 channels the retry would feed the identical tensor, so
                # re-raise instead of masking unrelated errors (e.g. OOM).
                if preds.shape[1] >= 3:
                    raise
                preds = torch.cat([preds] * 3, dim=1)
                percep_out = percep_model.forward(preds[:, -3:])
        outputs.append(percep_out.detach().cpu())

    return torch.cat(outputs, dim=0)
示例#13
0
@torch.no_grad()
def save_outputs(img_path, output_file_name):
    """Run the 8-path consensus and the 8-member deep ensemble on one image.

    Relies on script-level globals defined elsewhere: ``args``, ``distortion``,
    ``path``, ``path_deepens``, ``map_location``, ``models``,
    ``models_deepens``, ``target_task``, ``DEVICE``, ``trans_totensor``,
    ``trans_topil``, ``WrapperModel`` and the input transforms
    (``emboss4d_kernel``, ``greyscale``, ``laplace_kernel``, ``gauss_kernel``,
    ``sobel_kernel``, ``wav_kernel``, ``sharp_kernel``).

    Writes PNGs to ``args.output_path``: ``distorted_input.png`` (the input
    tensor, only when ``distortion`` is not None), and
    ``<output_file_name>_<args.task>_{ours,baseline,deepens}_{mean,sig}.png``.
    Runs without autograd; nothing is returned.
    """

    def _save_png(image, suffix):
        # <output_path>/<output_file_name>_<task>_<suffix>.png
        trans_topil(image).save(args.output_path + '/' + output_file_name +
                                '_' + args.task + '_' + suffix + '.png')

    def _epistemic_variance(stacked):
        # Variance across the 8 paths of each path's mean channel(s); a path
        # occupies 2 * nchannels consecutive channels, means first.
        if nchannels == 1:
            return stacked[:, [0, 2, 4, 6, 8, 10, 12, 14]].var(1, keepdim=True)
        # NOTE(review): this branch assumes nchannels == 3.
        return torch.cat(
            (stacked[:, [0, 6, 12, 18, 24, 30, 36, 42]].var(1, keepdim=True),
             stacked[:, [1, 7, 13, 19, 25, 31, 37, 43]].var(1, keepdim=True),
             stacked[:, [2, 8, 14, 20, 26, 32, 38, 44]].var(1, keepdim=True)),
            dim=1)

    img = Image.open(img_path)
    img_tensor = trans_totensor(img)[:3].unsqueeze(0).to(DEVICE)

    if distortion is not None:
        trans_topil(img_tensor[0].clamp(
            min=0,
            max=1).cpu()).save(args.output_path + '/' + 'distorted_input.png')

    # ---- 8-path consensus -------------------------------------------------
    all_models_state_dict = torch.load(path, map_location=map_location)

    # (source name used in the checkpoint key, input transform or None).
    # The position in this tuple is also the index into the global ``models``.
    path_specs = (
        ('rgb', None),
        ('emboss4d', emboss4d_kernel),
        ('grey', greyscale),
        ('laplace_edges', laplace_kernel),
        ('gauss', gauss_kernel),
        ('sobel_edges', sobel_kernel),
        ('wav', wav_kernel),
        ('sharp', sharp_kernel),
    )
    path_outputs = []
    for model_idx, (src_name, input_transform) in enumerate(path_specs):
        path_model = WrapperModel(DataParallelModel(models[model_idx].to(DEVICE)))
        path_model.load_state_dict(
            all_models_state_dict[f"('{src_name}', '{target_task}')"])
        path_input = img_tensor if input_transform is None else input_transform(img_tensor)
        path_outputs.append(path_model(path_input))
    direct_output = path_outputs[0]
    merged_outputs = torch.cat(path_outputs, dim=1)

    npaths = len(path_specs)
    di_ind = np.diag_indices(npaths)
    nchannels = int(merged_outputs.size(1) // (npaths * 2))
    inds = np.arange(npaths) * 2 + 1  # indices of channel0 sigmas
    SQRT2 = math.sqrt(2)
    for i in range(nchannels):
        inds_ = nchannels * inds + i
        # convert to sigma from log(b)
        merged_outputs[:, inds_] = merged_outputs[:, inds_].exp() * SQRT2

    ######## get sig avg weights
    if nchannels == 1:
        muind = inds - 1
        sigind = inds
    else:
        muind = np.array([
            0, 1, 2, 6, 7, 8, 12, 13, 14, 18, 19, 20, 24, 25, 26, 30, 31, 32,
            36, 37, 38, 42, 43, 44
        ])  #  8 paths
        sigind = muind + 3
    # Per-path weight: sum of the path's 1/sigma^2 over the sum for all paths.
    sig_avg_weights = torch.cuda.FloatTensor(
        merged_outputs[:, :npaths].size()).fill_(0.0)
    total_inv_sig = (1. / merged_outputs[:, sigind].pow(2)).sum(1)
    for i in range(npaths):
        sig_avg_weights[:, i] = (
            1. /
            merged_outputs[:, 2 * i * nchannels + nchannels:2 *
                           (i + 1) * nchannels].pow(2)).sum(1) / total_inv_sig

    weights = sig_avg_weights
    merged_mu = torch.cuda.FloatTensor(
        merged_outputs[:, :nchannels].size()).fill_(0.0)
    merged_sig = torch.cuda.FloatTensor(
        merged_outputs[:, :nchannels].size()).fill_(0.0)

    for i in range(nchannels):
        inds_ = i + nchannels * inds
        ## compute correl mat (diagonal: per-path sigma^2)
        cov_mat = torch.cuda.FloatTensor(merged_mu.size(0), merged_mu.size(-1),
                                         merged_mu.size(-1), int(npaths),
                                         int(npaths)).fill_(0.0)
        cov_mat[:, :, :, di_ind[0],
                di_ind[1]] = merged_outputs[:,
                                            inds_].pow(2).permute(0, 2, 3, 1)

        ## merge
        merged_mu[:, i] = (merged_outputs[:, inds_ - nchannels] *
                           weights).sum(dim=1)
        weights = weights.permute(0, 2, 3, 1)
        merged_sig[:, i] = (
            weights.unsqueeze(-2) @ cov_mat
            @ weights.unsqueeze(-1)).squeeze(-1).squeeze(-1).sqrt()
        weights = weights.permute(0, 3, 1, 2)

    var_epistemic = _epistemic_variance(merged_outputs)

    baseline_mu = direct_output[:, :nchannels]
    baseline_sig = direct_output[:, nchannels:]

    _save_png(merged_mu[0].clamp(min=0, max=1).cpu(), 'ours_mean')
    _save_png((var_epistemic[0].sqrt() * SQRT2).clamp(min=0, max=1).cpu(),
              'ours_sig')
    _save_png(baseline_mu[0].clamp(min=0, max=1).cpu(), 'baseline_mean')
    _save_png((baseline_sig[0].exp() * SQRT2).cpu(), 'baseline_sig')

    # ---- deep ensemble: one network, 8 checkpoints ------------------------
    all_models_state_dict = torch.load(path_deepens, map_location=map_location)
    ens_model = WrapperModel(DataParallelModel(models_deepens.to(DEVICE)))
    ens_outputs = []
    for member in range(1, 9):
        ens_model.load_state_dict(
            all_models_state_dict[f"('rgb', '{target_task}{member}_ens')"])
        ens_outputs.append(ens_model(img_tensor))

    # Mean of the members' mean channels (summed in member order, then * 1/8).
    ens_mu_sum = ens_outputs[0][:, :nchannels]
    for member_output in ens_outputs[1:]:
        ens_mu_sum = ens_mu_sum + member_output[:, :nchannels]
    mu_ens = 0.125 * ens_mu_sum
    merged_outputs_ens = torch.cat(ens_outputs, dim=1)
    var_epistemic_ens = _epistemic_variance(merged_outputs_ens)

    _save_png(mu_ens[0].clamp(min=0, max=1).cpu(), 'deepens_mean')
    _save_png((var_epistemic_ens[0].sqrt() * SQRT2).clamp(min=0, max=1).cpu(),
              'deepens_sig')
示例#14
0
    for k, files in tqdm.tqdm(list(
            enumerate(batch(image_files, batch_size=BATCH_SIZE))),
                              ncols=50):

        images = im.stack([im.load(img_file) for img_file in files]).detach()
        perturbation = nn.Parameter(0.03 *
                                    torch.randn(images.size()).to(DEVICE) +
                                    0.0)
        targets = [binary.random(n=TARGET_SIZE) for i in range(len(images))]
        torch.save((perturbation.data, images.data, targets),
                   f"{output_path}/{k}.pth")


if __name__ == "__main__":

    model = DataParallelModel(
        DecodingModel(n=DIST_SIZE, distribution=transforms.training))
    params = itertools.chain(model.module.classifier.parameters(),
                             model.module.features[-1].parameters())
    optimizer = torch.optim.Adam(params, lr=2.5e-3)
    init_data("data/amnesia")

    logger = VisdomLogger("train", server="35.230.67.129", port=8000, env=JOB)
    logger.add_hook(lambda x: logger.step(), feature="epoch", freq=20)
    logger.add_hook(lambda data: logger.plot(data, "train_loss"),
                    feature="loss",
                    freq=50)
    logger.add_hook(lambda data: logger.plot(data, "train_bits"),
                    feature="bits",
                    freq=50)
    logger.add_hook(
        lambda x: model.save("output/train_test.pth", verbose=True),
示例#15
0
def test_transforms(model=None,
                    image_files=VAL_FILES,
                    name="test",
                    max_iter=250):
    """Encode random codes into images and sweep the decoder over transforms.

    Random ``TARGET_SIZE``-bit codes are embedded with ``encode_binary``.
    Originals and encodings are logged and written under ``output/``. The
    decoding distance with the identity distribution is logged as "orig", and
    ``sweep`` (60 samples) is run for every transform listed below.

    Args:
        model: a ``BaseModel``, or a weights path loaded into a
            ``DecodingModel`` (``transforms.new_dist``, ``ENCODING_DIST_SIZE``).
        image_files: paths of the images to encode.
        name: label passed to ``sweep``.
        max_iter: optimisation iterations for ``encode_binary``.

    On exit, including when an exception propagates, the model is reset to
    ``transforms.training`` with ``n=DIST_SIZE`` and put back in train mode.
    """
    if not isinstance(model, BaseModel):
        print(f"Loading model from {model}")
        model = DataParallelModel(
            DecodingModel.load(distribution=transforms.new_dist,
                               n=ENCODING_DIST_SIZE,
                               weights_file=model))

    images = im.stack([im.load(image) for image in image_files])
    targets = [binary.random(n=TARGET_SIZE) for _ in range(len(images))]
    model.eval()

    try:
        encoded_images = encode_binary(images,
                                       targets,
                                       model,
                                       n=ENCODING_DIST_SIZE,
                                       verbose=True,
                                       max_iter=max_iter,
                                       use_weighting=True)

        logger.images(images, "original_images", resize=196)
        logger.images(encoded_images, "encoded_images", resize=196)
        for img, encoded_im, filename, target in zip(images, encoded_images,
                                                     image_files, targets):
            base_name = filename.split('/')[-1]
            code = binary.str(target)
            im.save(im.numpy(img), file=f"output/_{code}_original_{base_name}")
            im.save(im.numpy(encoded_im),
                    file=f"output/_{code}_encoded_{base_name}")

        model.set_distribution(transforms.identity, n=1)
        predictions = model(encoded_images).mean(dim=1).cpu().data.numpy()
        binary_loss = np.mean(
            [binary.distance(x, y) for x, y in zip(predictions, targets)])

        for transform in [
                transforms.pixilate,
                transforms.blur,
                transforms.rotate,
                transforms.scale,
                transforms.translate,
                transforms.noise,
                transforms.crop,
                transforms.gauss,
                transforms.whiteout,
                transforms.resize_rect,
                transforms.color_jitter,
                transforms.jpeg_transform,
                transforms.elastic,
                transforms.brightness,
                transforms.contrast,
                transforms.flip,
        ]:
            sweep(encoded_images,
                  targets,
                  model,
                  transform=transform,
                  name=name,
                  samples=60)

        # Hand-tuned explicit sweeps used previously (transform arg: min..max, samples):
        #   rotate theta -0.6..0.6 (80); scale scale_val 0.6..1.4 (50);
        #   translate radius 0.0..1.0 (50); noise intensity 0.0..0.1 (30);
        #   crop p 0.1..1.0 (50); gauss sigma 0.3..4 (50);
        #   whiteout scale 0.02..0.2 (50); resize_rect ratio 0.5..1.5 (50);
        #   color_jitter jitter 0..0.2 (50); convertToJpeg q 10..100 (50).

        logger.update("orig", binary_loss)
    finally:
        # Restore the training configuration even if encoding/sweeping fails.
        model.set_distribution(transforms.training, n=DIST_SIZE)
        model.train()
示例#16
0
def test_transfer(model=None, image_files=VAL_FILES, max_iter=250):
    """Build a matrix of decoding errors between pairs of image transforms.

    For each "source" transform, random codes are encoded with the model's
    distribution set to that transform's ``random`` sampler. The model is then
    switched to the identity distribution (``n=1``), and the encodings are
    swept (``sweep``, 60 samples) under every transform in the list. Entry
    ``[row, col]`` is the error ``sweep`` returns for encodings built against
    transform ``row`` and evaluated under transform ``col``.

    Side effects: prints each off-diagonal score, saves ``labels`` and
    ``score_matrix`` with ``np.save`` and calls ``create_heatmap``. The model
    is left in eval mode with the identity distribution.
    """
    if not isinstance(model, BaseModel):
        print(f"Loading model from {model}")
        decoder = DecodingModel.load(distribution=transforms.encoding,
                                     n=ENCODING_DIST_SIZE,
                                     weights_file=model)
        model = DataParallelModel(decoder)

    images = im.stack([im.load(image) for image in image_files])
    targets = [binary.random(n=TARGET_SIZE) for _ in range(len(images))]
    model.eval()

    transform_list = [
        transforms.rotate,
        transforms.translate,
        transforms.scale,
        transforms.resize_rect,
        transforms.crop,
        transforms.whiteout,
        transforms.elastic,
        transforms.motion_blur,
        transforms.brightness,
        transforms.contrast,
        transforms.pixilate,
        transforms.blur,
        transforms.color_jitter,
        transforms.gauss,
        transforms.noise,
        transforms.impulse_noise,
        transforms.flip,
    ]

    labels = [transform.__name__ for transform in transform_list]
    n_transforms = len(transform_list)
    score_matrix = np.zeros((n_transforms, n_transforms))

    for row, source in enumerate(transform_list):
        source_name = source.__name__

        # The lambda reads ``source`` when called; it is only used (by
        # encode_binary) within this iteration.
        model.set_distribution(lambda x: source.random(x), n=ENCODING_DIST_SIZE)
        encoded_images = encode_binary(images,
                                       targets,
                                       model,
                                       n=ENCODING_DIST_SIZE,
                                       verbose=True,
                                       max_iter=max_iter,
                                       use_weighting=True)

        model.set_distribution(transforms.identity, n=1)
        own_error = sweep(encoded_images,
                          targets,
                          model,
                          transform=source,
                          name=f"{source_name}",
                          samples=60)

        for col, probe in enumerate(transform_list):
            probe_name = probe.__name__
            if probe_name == source_name:
                # Diagonal: reuse the sweep just computed for the source.
                score_matrix[row, col] = own_error
                continue
            cross_error = sweep(encoded_images,
                                targets,
                                model,
                                transform=probe,
                                name=f'{source_name}->{probe_name}',
                                samples=60)
            score_matrix[row, col] = cross_error
            print(f'{source_name} --> {probe_name}: {cross_error}')

    np.save('labels', labels)
    np.save('score_matrix', score_matrix)
    create_heatmap(score_matrix, labels)