Example #1
def validate(cfgs):
    utils.set_device(cfgs.get("val", "device"),
                     cfgs.getint("val", "device_id"))
    loader = get_imagelist_dataloader(cfgs, "val")
    model_str = cfgs.get("val", "params")
    if model_str:
      model_list = glob(model_str)
    else:
      path = os.path.join(cfgs.get("train", "snapshot_prefix"),
                          cfgs.get("model", "net"))
      model_list = [os.path.join(path, p) for p in os.listdir(path) if ".pt" in p]
      # sort lexicographically, then stably by length, so numbered checkpoints
      # order naturally (model_2.pt before model_10.pt)
      model_list = sorted(sorted(model_list), key=len)

    for model_path in model_list:
      val_list = cfgs.get('val','imagelist').split('/')[-1].split('.')[0]
      prefix = model_path.replace('.pt', '_%s'%val_list)
      res_file = glob(prefix+'*.npy')
      if len(res_file) != 0:
        logging.info("Test Result %s already existed!"%res_file[0])
        continue
      
      logging.info("Loading Model %s"%model_path)
      net = get_net(cfgs.get("model", "net"),
                    cfgs.getint("model", "classes"),
                    model_path)
      val_net(net, loader, model_path, cfgs)
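The `cfgs` object is read with per-section `get`/`getint` calls, which matches Python's `configparser` API. A minimal sketch of a config file these functions could consume; the section and key names are taken from the calls above, while every value is a hypothetical placeholder:

import configparser

SAMPLE = """
[model]
net = resnet18
classes = 10

[train]
device = gpu
device_id = 0
snapshot_prefix = ./snapshots
params =

[val]
device = gpu
device_id = 0
params = ./snapshots/resnet18/*.pt
imagelist = ./data/val_list.txt
"""

cfgs = configparser.ConfigParser()
cfgs.read_string(SAMPLE)
assert cfgs.getint("model", "classes") == 10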
Example #2
def train(cfgs):
    utils.set_device(cfgs.get("train", "device"),
                     cfgs.getint("train", "device_id"))
    trainloader = get_imagelist_dataloader(cfgs, "train")

    net = get_net(cfgs.get("model", "net"), cfgs.getint("model", "classes"),
                  cfgs.get("train", "params"))
    criterion = nn.CrossEntropyLoss()
    #criterion = LSoftmaxLinear()
    optimizer = get_optimizer(net, cfgs)
    scheduler = get_scheduler(optimizer, cfgs)
    train_net(net, criterion, trainloader, optimizer, scheduler, cfgs)
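`get_optimizer` and `get_scheduler` are project helpers that the snippet does not show. A plausible sketch of the dispatch they might perform, assuming a hypothetical [optim] config section; only two optimizers are illustrated:

import torch.optim as optim

def get_optimizer(net, cfgs):
    # hypothetical: dispatch on a config value
    name = cfgs.get("optim", "optimizer").lower()
    lr = cfgs.getfloat("optim", "lr")
    if name == "sgd":
        return optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    return optim.Adam(net.parameters(), lr=lr)

def get_scheduler(optimizer, cfgs):
    # hypothetical: a plain step schedule
    step = cfgs.getint("optim", "step_size")
    return optim.lr_scheduler.StepLR(optimizer, step_size=step, gamma=0.1)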
Example #3
    def __init__(self,
                 init_lr,
                 net_params,
                 n_epochs,
                 loss,
                 optimizer=None,
                 scheduler=None):
        super().__init__()

        self.model = get_net(**net_params)
        self.init_lr = init_lr
        self.n_epochs = n_epochs
        self.loss = loss
        if args.snap_mix:  # `args` is assumed to be a module-level namespace
            self.snapmix_criterion = SnapMixLoss()
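`get_net(**net_params)` unpacks a dict into keyword arguments, so `net_params` must carry exactly the factory's parameter names. A self-contained illustration of the pattern (the factory body and key names here are hypothetical stand-ins):

def get_net(name, num_classes, pretrained=False):
    # stand-in for the project's model factory
    print(f"building {name} with {num_classes} classes (pretrained={pretrained})")

net_params = {"name": "resnet34", "num_classes": 10}
get_net(**net_params)  # same as get_net(name="resnet34", num_classes=10)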
Example #4
def test(cfgs):
    utils.set_device(cfgs.get("test", "device"),
                     cfgs.getint("test", "device_id"))
    model_str = cfgs.get("test", "params")
    if model_str:
        model_list = glob(model_str)
    else:
        logging.info("Please specify the model!")
        sys.exit(1)

    for model_path in model_list:
        logging.info("Loading model %s" % model_path)
        net = get_net(cfgs.get("model", "net"),
                      cfgs.getint("model", "classes"), model_path)
        test_net(net, model_path, cfgs)
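Here `params` is treated as a glob pattern, so one config entry can select many checkpoints. A small illustration with a hypothetical snapshot layout, also showing the double-sort trick from Example #1:

from glob import glob

# hypothetical layout: snapshots/resnet18/epoch_1.pt ... epoch_30.pt
model_list = glob("snapshots/resnet18/*.pt")
# lexicographic sort, then a stable sort by length, orders numbered
# checkpoints naturally: epoch_2.pt before epoch_10.pt
for model_path in sorted(sorted(model_list), key=len):
    print(model_path)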
Example #5
File: train.py Project: ioalzx/AdaS
def main(args: APNamespace):
    root_path = Path(args.root).expanduser()
    config_path = root_path / Path(args.config).expanduser()
    data_path = root_path / Path(args.data).expanduser()
    output_path = root_path / Path(args.output).expanduser()
    global checkpoint_path
    checkpoint_path = root_path / Path(args.checkpoint).expanduser()

    if not config_path.exists():
        # logging.critical(f"AdaS: Config path {config_path} does not exist")
        print(f"AdaS: Config path {config_path} does not exist")
        raise ValueError
    if not data_path.exists():
        print(f"AdaS: Data dir {data_path} does not exists, building")
        data_path.mkdir(exist_ok=True, parents=True)
    if not output_path.exists():
        print(f"AdaS: Output dir {output_path} does not exists, building")
        output_path.mkdir(exist_ok=True, parents=True)
    if not checkpoint_path.exists():
        if args.resume:
            print("AdaS: Cannot resume from checkpoint without specifying " +
                  "checkpoint dir")
            raise ValueError
        print(f"AdaS: Checkpoint dir {checkpoint_path} does not exist, " +
              "building")
        checkpoint_path.mkdir(exist_ok=True, parents=True)

    with config_path.open() as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    global best_acc
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    print("Adas: Argument Parser Options")
    print("-"*45)
    print(f"    {'config':<20}: {args.config:<20}")
    print(f"    {'data':<20}: {args.data:<20}")
    print(f"    {'output':<20}: {args.output:<20}")
    print(f"    {'checkpoint':<20}: {args.checkpoint:<20}")
    print(f"    {'resume':<20}: {args.resume:<20}")
    print("\nAdas: Train: Config")
    print(f"    {'Key':<20} {'Value':<20}")
    print("-"*45)
    for k, v in config.items():
        print(f"    {k:<20} {v:<20}")

    for trial in range(config['n_trials']):
        # Data
        # logging.info("Adas: Preparing Data")
        train_loader, test_loader = get_data(
            root=data_path,
            dataset=config['dataset'],
            mini_batch_size=config['mini_batch_size'])
        global performance_statistics, net, metrics, adas
        performance_statistics = {}

        # logging.info("AdaS: Building Model")
        net = get_net(config['network'],
                      num_classes=10 if config['dataset'] == 'CIFAR10' else
                      100 if config['dataset'] == 'CIFAR100' else
                      1000 if config['dataset'] == 'ImageNet' else 10)
        metrics = Metrics(list(net.parameters()),
                          p=config['p'])
        if config['lr_scheduler'] == 'AdaS':
            adas = AdaS(parameters=list(net.parameters()),
                        beta=config['beta'],
                        zeta=config['zeta'],
                        init_lr=float(config['init_lr']),
                        min_lr=float(config['min_lr']),
                        p=config['p'])

        net = net.to(device)

        global criterion
        criterion = get_loss(config['loss'])

        # TODO config
        optimizer, scheduler = get_optimizer_scheduler(
            init_lr=float(config['init_lr']),
            optim_method=config['optim_method'],
            lr_scheduler=config['lr_scheduler'])

        if device == 'cuda':
            net = torch.nn.DataParallel(net)
            cudnn.benchmark = True

        if args.resume:
            # Load checkpoint.
            print("Adas: Resuming from checkpoint...")
            if checkpoint_path.is_dir():
                checkpoint = torch.load(str(checkpoint_path / 'ckpt.pth'))
            else:
                checkpoint = torch.load(str(checkpoint_path))
            net.load_state_dict(checkpoint['net'])
            best_acc = checkpoint['acc']
            start_epoch = checkpoint['epoch']
            if adas is not None:  # `adas` is only assigned when lr_scheduler == 'AdaS'
                adas.historical_io_metrics = \
                    checkpoint['historical_io_metrics']

        # model_parameters = filter(lambda p: p.requires_grad,
        #                           net.parameters())
        # params = sum([np.prod(p.size()) for p in model_parameters])
        # print(params)
        epochs = range(start_epoch, start_epoch + config['max_epoch'])
        for epoch in epochs:
            start_time = time.time()
            print(f"AdaS: Epoch {epoch} Started.")
            train_loss, train_accuracy = epoch_iteration(
                train_loader, epoch, device, optimizer)
            end_time = time.time()
            if config['lr_scheduler'] == 'StepLR':
                scheduler.step()
            test_loss, test_accuracy = test_main(test_loader, epoch, device)
            total_time = time.time()
            print(
                f"AdaS: Epoch {epoch}/{epochs[-1]} Ended | " +
                "Total Time: {:.3f}s | ".format(total_time - start_time) +
                "Epoch Time: {:.3f}s | ".format(end_time - start_time) +
                "Est. Time Remaining: {:.3f}s | ".format(
                    (total_time - start_time) * (epochs[-1] - epoch)),
                "Train Loss: {:.4f}% | Train Acc. {:.4f}% | ".format(
                    train_loss,
                    train_accuracy) +
                "Test Loss: {:.4f}% | Test Acc. {:.4f}%".format(test_loss,
                                                                test_accuracy))
            df = pd.DataFrame(data=performance_statistics)
            if config['lr_scheduler'] == 'AdaS':
                xlsx_name = \
                    f"{config['optim_method']}_AdaS_trial={trial}_" +\
                    f"beta={config['beta']}_initlr={config['init_lr']}_" +\
                    f"net={config['network']}_dataset={config['dataset']}.xlsx"
            else:
                xlsx_name = \
                    f"{config['optim_method']}_{config['lr_scheduler']}_" +\
                    f"trial={trial}_initlr={config['init_lr']}_" +\
                    f"net={config['network']}_dataset={config['dataset']}.xlsx"

            df.to_excel(str(output_path / xlsx_name))
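The script is driven entirely by a YAML file; a minimal sketch that covers every key read above (the values are placeholders, not the repository's defaults):

import yaml

SAMPLE_CONFIG = """
n_trials: 1
dataset: CIFAR10
mini_batch_size: 128
network: VGG16
optim_method: SGD
lr_scheduler: AdaS
init_lr: 0.03
min_lr: 0.00001
beta: 0.8
zeta: 1.0
p: 1
loss: cross_entropy
max_epoch: 150
"""

config = yaml.load(SAMPLE_CONFIG, Loader=yaml.SafeLoader)
assert config['lr_scheduler'] == 'AdaS'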
Example #6
def main():

    batch_size = config.batch_size
    img_size = (config.img_size, config.img_size)
    num_classes = config.num_classes
    if config.png:
        img_type = '.png'
    else:
        img_type = '.jpg'

    dataset = HumanDataset(
        config.data,
        target_file='/content/gdrive/My Drive/atlas/sample_submission.csv',
        train=False,
        multi_label=config.multi_label,
        tags_type='all',
        img_type=img_type,
        img_size=img_size,
        test_aug=config.tta,
    )

    tags = get_tags()
    output_col = ['Id'] + tags
    submission_col = ['Id', 'Predicted']

    loader = data.DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=config.num_processes)

    model = get_net(config.model,
                    num_classes=config.num_classes,
                    drop_rate=0.5,
                    channels=config.channels)

    if not config.no_cuda:
        if config.num_gpu > 1:
            model = torch.nn.DataParallel(model,
                                          device_ids=list(range(
                                              config.num_gpu))).cuda()
        else:
            model.cuda()

    if config.resume is not None:
        assert os.path.isfile(config.resume), '%s not found' % config.resume
        checkpoint = torch.load(config.resume)
        print('Restoring model with %s architecture...' % checkpoint['arch'])

        model.load_state_dict(checkpoint['state_dict'])

        if 'threshold' in checkpoint:
            threshold = checkpoint['threshold']
            threshold = torch.FloatTensor(threshold)
            print('Using thresholds:', threshold)
            if not config.no_cuda:
                threshold = threshold.cuda()
        else:
            threshold = 0.2
        csplit = os.path.normpath(config.resume).split(sep=os.path.sep)
        if len(csplit) > 1:
            exp_name = csplit[-2] + '-' + csplit[-1].split('.')[0]
        else:
            exp_name = ''
        print('Model restored from file: %s' % config.resume)
    else:
        assert False, "No checkpoint specified"

    if config.output:
        output_base = config.output
    else:
        output_base = os.path.join('/content/gdrive/My Drive/', 'output')
    if not exp_name:
        exp_name = '-'.join([
            config.model,
            str(config.img_size), 'f' + str(config.fold),
            'png' if config.png else 'jpg'
        ])
    output_dir = get_outdir(output_base, 'predictions', exp_name)

    model.eval()

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    results_raw = []
    results_thr = []
    results_sub = []
    try:
        end = time.time()
        for batch_idx, (input, target, index) in enumerate(loader):
            data_time_m.update(time.time() - end)
            if not config.no_cuda:
                input = input.cuda()
            # volatile=True is the legacy (pre-0.4) PyTorch inference mode;
            # torch.no_grad() is the modern equivalent
            input_var = autograd.Variable(input, volatile=True)
            output = model(input_var)

            # augmentation reduction
            reduce_factor = loader.dataset.get_aug_factor()
            if reduce_factor > 1:
                output.data = output.data.unfold(
                    0, reduce_factor, reduce_factor).mean(dim=2).squeeze(dim=2)
                index = index[0:index.size(0):reduce_factor]

            # output non-linearity and thresholding
            output = torch.sigmoid(output)
            if isinstance(threshold, torch.FloatTensor) or isinstance(
                    threshold, torch.cuda.FloatTensor):
                threshold_m = torch.unsqueeze(threshold,
                                              0).expand_as(output.data)
                output_thr = (output.data > threshold_m).byte()
            else:
                output_thr = (output.data > threshold).byte()

            # move data to CPU and collect
            output = output.cpu().data.numpy()
            output_thr = output_thr.cpu().numpy()
            index = index.cpu().numpy().flatten()
            for i, o, ot in zip(index, output, output_thr):
                #print(dataset.inputs[i], o, ot)
                image_name = os.path.splitext(
                    os.path.basename(dataset.inputs[i]))[0]
                results_raw.append([image_name] + list(o))
                results_thr.append([image_name] + list(ot))
                results_sub.append([image_name] + [vector_to_tags(ot, tags)])
                # end iterating through batch

            batch_time_m.update(time.time() - end)
            if batch_idx % config.log_interval == 0:
                print('Inference: [{}/{} ({:.0f}%)]  '
                      'Time: {batch_time.val:.3f}s, {rate:.3f}/s  '
                      '({batch_time.avg:.3f}s, {rate_avg:.3f}/s)  '
                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                          batch_idx * len(input),
                          len(loader.sampler),
                          100. * batch_idx / len(loader),
                          batch_time=batch_time_m,
                          rate=input_var.size(0) / batch_time_m.val,
                          rate_avg=input_var.size(0) / batch_time_m.avg,
                          data_time=data_time_m))

            end = time.time()
            #end iterating through dataset
    except KeyboardInterrupt:
        pass
    results_raw_df = pd.DataFrame(results_raw, columns=output_col)
    results_raw_df.to_csv(os.path.join(output_dir, 'results_raw.csv'),
                          index=False)
    results_thr_df = pd.DataFrame(results_thr, columns=output_col)
    results_thr_df.to_csv(os.path.join(output_dir, 'results_thr.csv'),
                          index=False)
    results_sub_df = pd.DataFrame(results_sub, columns=submission_col)
    results_sub_df.to_csv(os.path.join(output_dir, 'submission.csv'),
                          index=False)
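`vector_to_tags` is not shown in the snippet. A plausible sketch consistent with how it is called above, assuming the submission column wants a space-separated list of the tag names whose bit is set:

def vector_to_tags(vec, tags):
    # hypothetical helper: binary vector -> 'tag1 tag2 ...' string
    return ' '.join(tag for bit, tag in zip(vec, tags) if bit)

print(vector_to_tags([1, 0, 1], ['nucleus', 'cytosol', 'mitochondria']))
# -> 'nucleus mitochondria'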
Example #7
def main(img: int = 0,
         num_iter: int = 40000,
         lr: float = 3e-4,
         gpu: int = 0,
         seed: int = 42,
         save: bool = True):
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    dtype = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.FloatTensor

    global i, out_avg, psnr_noisy_last, last_net, net_input, losses, psnrs, ssims, average_dropout_rate, no_layers, \
           img_mean, sample_count, recons, uncerts, uncerts_ale, loss_last, roll_back

    imsize = (256, 256)
    PLOT = True

    timestamp = int(time.time())
    save_path = '/media/fastdata/laves/unsure'
    os.mkdir(f'{save_path}/{timestamp}')

    # denoising
    if img == 0:
        fname = '../NORMAL-4951060-8.jpeg'
        imsize = (256, 256)
    elif img == 1:
        fname = '../BACTERIA-1351146-0006.png'
        imsize = (256, 256)
    elif img == 2:
        fname = '../081_HC.png'
        imsize = (256, 256)
    elif img == 3:
        fname = '../CNV-9997680-30.png'
        imsize = (256, 256)
    else:
        assert False

    if fname == '../NORMAL-4951060-8.jpeg':

        # Add Gaussian noise to simulate speckle
        img_pil = crop_image(get_image(fname, imsize)[0], d=32)
        img_np = pil_to_np(img_pil)
        print(img_np.shape)
        p_sigma = 0.1
        img_noisy_pil, img_noisy_np = get_noisy_image_gaussian(img_np, p_sigma)

    elif fname == '../BACTERIA-1351146-0006.png':

        # Add Poisson noise to simulate low dose X-ray
        img_pil = crop_image(get_image(fname, imsize)[0], d=32)
        img_np = pil_to_np(img_pil)
        print(img_np.shape)
        #p_lambda = 50.0
        #img_noisy_pil, img_noisy_np = get_noisy_image_poisson(img_np, p_lambda)
        # for lam > 20, poisson can be approximated with Gaussian
        p_sigma = 0.1
        img_noisy_pil, img_noisy_np = get_noisy_image_gaussian(img_np, p_sigma)

    elif fname == '../081_HC.png':

        # Add Gaussian noise to simulate speckle
        img_pil = crop_image(get_image(fname, imsize)[0], d=32)
        img_np = pil_to_np(img_pil)
        print(img_np.shape)
        p_sigma = 0.1
        img_noisy_pil, img_noisy_np = get_noisy_image_gaussian(img_np, p_sigma)

    elif fname == '../CNV-9997680-30.png':

        # Add Gaussian noise to simulate speckle
        img_pil = crop_image(get_image(fname, imsize)[0], d=32)
        img_np = pil_to_np(img_pil)
        print(img_np.shape)
        p_sigma = 0.1
        img_noisy_pil, img_noisy_np = get_noisy_image_gaussian(img_np, p_sigma)

    else:
        assert False

    if PLOT:
        q = plot_image_grid([img_np, img_noisy_np], 4, 6)
        out_pil = np_to_pil(q)
        out_pil.save(f'{save_path}/{timestamp}/input.png', 'PNG')

    INPUT = 'noise'
    pad = 'reflection'
    OPT_OVER = 'net'  # 'net,input'

    reg_noise_std = 1. / 10.
    LR = lr
    roll_back = False  # option to roll the net back to damp oscillations during training
    input_depth = 32

    show_every = 100
    exp_weight = 0.99

    mse = torch.nn.MSELoss()

    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

    LOSSES = {}
    RECONS = {}
    UNCERTS = {}
    UNCERTS_ALE = {}
    PSNRS = {}
    SSIMS = {}

    # # SGD

    OPTIMIZER = 'adamw'
    weight_decay = 0
    LOSS = 'mse'
    figsize = 4

    NET_TYPE = 'skip'

    skip_n33d = 128
    skip_n33u = 128
    skip_n11 = 4
    num_scales = 5
    upsample_mode = 'bilinear'

    dropout_mode_down = 'None'
    dropout_p_down = 0.0
    dropout_mode_up = 'None'
    dropout_p_up = dropout_p_down
    dropout_mode_skip = 'None'
    dropout_p_skip = dropout_p_down
    dropout_mode_output = 'None'
    dropout_p_output = dropout_p_down

    net_input = get_noise(
        input_depth, INPUT,
        (img_pil.size[1], img_pil.size[0])).type(dtype).detach()

    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()

    out_avg = None
    last_net = None
    mc_iter = 1

    def closure_dip():

        global i, out_avg, psnr_noisy_last, last_net, net_input, losses, psnrs, ssims, average_dropout_rate, no_layers,\
               img_mean, sample_count, recons, uncerts, loss_last

        if reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)

        out = net(net_input)
        out[:, :1] = out[:, :1].sigmoid()

        _loss = mse(out[:, :1], img_noisy_torch)
        _loss.backward()

        # Smoothing
        if out_avg is None:
            out_avg = out.detach()
        else:
            out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

        losses.append(mse(out_avg[:, :1], img_noisy_torch).item())

        _out = out.detach().cpu().numpy()[0, :1]
        _out_avg = out_avg.detach().cpu().numpy()[0, :1]

        psnr_noisy = compare_psnr(img_noisy_np, _out)
        psnr_gt = compare_psnr(img_np, _out)
        psnr_gt_sm = compare_psnr(img_np, _out_avg)

        ssim_noisy = compare_ssim(img_noisy_np[0], _out[0])
        ssim_gt = compare_ssim(img_np[0], _out[0])
        ssim_gt_sm = compare_ssim(img_np[0], _out_avg[0])

        psnrs.append([psnr_noisy, psnr_gt, psnr_gt_sm])
        ssims.append([ssim_noisy, ssim_gt, ssim_gt_sm])

        if PLOT and i % show_every == 0:
            print(
                f'Iteration: {i} Loss: {_loss.item():.4f} PSNR_noisy: {psnr_noisy:.4f} PSNR_gt: {psnr_gt:.4f} PSNR_gt_sm: {psnr_gt_sm:.4f}'
            )

            out_np = _out

            psnr_noisy = compare_psnr(img_noisy_np, out_np)
            psnr_gt = compare_psnr(img_np, out_np)

            if sample_count != 0:
                psnr_mean = compare_psnr(img_np, img_mean / sample_count)
            else:
                psnr_mean = 0

            print('###################')

            recons.append(out_np)

        i += 1

        return _loss

    # NOTE: a non-empty string literal is always truthy, so this branch
    # always runs regardless of which image was selected
    if '../NORMAL-4951060-8.jpeg':
        net = get_net(input_depth,
                      NET_TYPE,
                      pad,
                      skip_n33d=skip_n33d,
                      skip_n33u=skip_n33u,
                      skip_n11=skip_n11,
                      num_scales=num_scales,
                      n_channels=1,
                      upsample_mode=upsample_mode,
                      dropout_mode_down=dropout_mode_down,
                      dropout_p_down=dropout_p_down,
                      dropout_mode_up=dropout_mode_up,
                      dropout_p_up=dropout_p_up,
                      dropout_mode_skip=dropout_mode_skip,
                      dropout_p_skip=dropout_p_skip,
                      dropout_mode_output=dropout_mode_output,
                      dropout_p_output=dropout_p_output).type(dtype)
    else:
        assert False

    net.apply(init_normal)

    losses = []
    recons = []
    uncerts = []
    uncerts_ale = []
    psnrs = []
    ssims = []

    img_mean = 0
    sample_count = 0
    i = 0
    psnr_noisy_last = 0
    loss_last = 1e16

    parameters = get_params(OPT_OVER, net, net_input)
    out_avg = None
    optimizer = torch.optim.AdamW(parameters, lr=LR, weight_decay=weight_decay)
    optimize(optimizer, closure_dip, num_iter)

    LOSSES['dip'] = losses
    RECONS['dip'] = recons
    UNCERTS['dip'] = uncerts
    UNCERTS_ALE['dip'] = uncerts_ale
    PSNRS['dip'] = psnrs
    SSIMS['dip'] = ssims

    to_plot = [img_np] + [np.clip(img, 0, 1) for img in RECONS['dip']]
    q = plot_image_grid(to_plot, factor=13)

    out_pil = np_to_pil(q)
    out_pil.save(f'{save_path}/{timestamp}/dip_recons.png', 'PNG')

    ## SGLD

    weight_decay = 1e-4
    LOSS = 'mse'
    input_depth = 32
    param_noise_sigma = 2

    NET_TYPE = 'skip'

    skip_n33d = 128
    skip_n33u = 128
    skip_n11 = 4
    num_scales = 5
    upsample_mode = 'bilinear'

    dropout_mode_down = 'None'
    dropout_p_down = 0.0
    dropout_mode_up = 'None'
    dropout_p_up = dropout_p_down
    dropout_mode_skip = 'None'
    dropout_p_skip = dropout_p_down
    dropout_mode_output = 'None'
    dropout_p_output = dropout_p_down

    net_input = get_noise(
        input_depth, INPUT,
        (img_pil.size[1], img_pil.size[0])).type(dtype).detach()

    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()

    mc_iter = 25

    def add_noise(model):
        for n in [x for x in model.parameters() if len(x.size()) == 4]:
            noise = torch.randn(n.size()) * param_noise_sigma * LR
            noise = noise.type(dtype)
            n.data = n.data + noise

    def closure_sgld():

        global i, out_avg, psnr_noisy_last, last_net, net_input, losses, psnrs, ssims, average_dropout_rate, no_layers, img_mean, sample_count, recons, uncerts, loss_last

        add_noise(net)

        if reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)

        out = net(net_input)
        out[:, :1] = out[:, :1].sigmoid()

        _loss = mse(out[:, :1], img_noisy_torch)
        _loss.backward()

        # Smoothing
        if out_avg is None:
            out_avg = out.detach()
        else:
            out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

        losses.append(mse(out_avg[:, :1], img_noisy_torch).item())

        _out = out.detach().cpu().numpy()[0, :1]
        _out_avg = out_avg.detach().cpu().numpy()[0, :1]

        psnr_noisy = compare_psnr(img_noisy_np, _out)
        psnr_gt = compare_psnr(img_np, _out)
        psnr_gt_sm = compare_psnr(img_np, _out_avg)

        ssim_noisy = compare_ssim(img_noisy_np[0], _out[0])
        ssim_gt = compare_ssim(img_np[0], _out[0])
        ssim_gt_sm = compare_ssim(img_np[0], _out_avg[0])

        psnrs.append([psnr_noisy, psnr_gt, psnr_gt_sm])
        ssims.append([ssim_noisy, ssim_gt, ssim_gt_sm])

        if PLOT and i % show_every == 0:
            print(
                f'Iteration: {i} Loss: {_loss.item():.4f} PSNR_noisy: {psnr_noisy:.4f} PSNR_gt: {psnr_gt:.4f} PSNR_gt_sm: {psnr_gt_sm:.4f}'
            )

            out_np = _out
            recons.append(out_np)

            out_np_var = np.var(np.array(recons[-mc_iter:]), axis=0)[:1]

            print('mean epi', out_np_var.mean())
            print('###################')

            uncerts.append(out_np_var)

        i += 1

        return _loss

    if '../NORMAL-4951060-8.jpeg':  # always truthy; see the note above
        net = get_net(input_depth,
                      NET_TYPE,
                      pad,
                      skip_n33d=skip_n33d,
                      skip_n33u=skip_n33u,
                      skip_n11=skip_n11,
                      num_scales=num_scales,
                      n_channels=1,
                      upsample_mode=upsample_mode,
                      dropout_mode_down=dropout_mode_down,
                      dropout_p_down=dropout_p_down,
                      dropout_mode_up=dropout_mode_up,
                      dropout_p_up=dropout_p_up,
                      dropout_mode_skip=dropout_mode_skip,
                      dropout_p_skip=dropout_p_skip,
                      dropout_mode_output=dropout_mode_output,
                      dropout_p_output=dropout_p_output).type(dtype)
    else:
        assert False

    net.apply(init_normal)

    losses = []
    recons = []
    uncerts = []
    uncerts_ale = []
    psnrs = []
    ssims = []

    img_mean = 0
    sample_count = 0
    i = 0
    psnr_noisy_last = 0
    loss_last = 1e10
    out_avg = None
    last_net = None

    parameters = get_params(OPT_OVER, net, net_input)
    optimizer = torch.optim.AdamW(parameters, lr=LR, weight_decay=weight_decay)
    optimize(optimizer, closure_sgld, num_iter)

    LOSSES['sgld'] = losses
    RECONS['sgld'] = recons
    UNCERTS['sgld'] = uncerts
    UNCERTS_ALE['sgld'] = uncerts_ale
    PSNRS['sgld'] = psnrs
    SSIMS['sgld'] = ssims

    to_plot = [img_np] + [np.clip(img, 0, 1) for img in RECONS['sgld']]
    q = plot_image_grid(to_plot, factor=13)

    out_pil = np_to_pil(q)
    out_pil.save(f'{save_path}/{timestamp}/sgld_recons.png', 'PNG')

    errs = img_noisy_torch.cpu() - torch.tensor(RECONS['sgld'][-1])
    uncerts_epi = torch.tensor(UNCERTS['sgld'][-1]).unsqueeze(0)
    uncerts = uncerts_epi
    uce, err, uncert, freq = uceloss(errs**2, uncerts, n_bins=21)
    fig, ax = plot_uncert(err, uncert, freq, outlier_freq=0.001)
    ax.set_title(
        f'U = {uncerts.mean().sqrt().item():.4f}, UCE = {uce.item()*100:.3f}')
    plt.tight_layout()
    fig.savefig(f'{save_path}/{timestamp}/sgld_calib.png')

    ## SGLD + NLL

    LOSS = 'nll'

    net_input = get_noise(
        input_depth, INPUT,
        (img_pil.size[1], img_pil.size[0])).type(dtype).detach()

    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()

    def closure_sgldnll():

        global i, out_avg, psnr_noisy_last, last_net, net_input, losses, psnrs, ssims, average_dropout_rate, no_layers,\
               img_mean, sample_count, recons, uncerts, uncerts_ale, loss_last

        add_noise(net)

        if reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)

        out = net(net_input)
        out[:, :1] = out[:, :1].sigmoid()

        _loss = gaussian_nll(out[:, :1], out[:, 1:], img_noisy_torch)
        _loss.backward()

        out[:, 1:] = torch.exp(-out[:, 1:])  # aleatoric uncertainty

        # Smoothing
        if out_avg is None:
            out_avg = out.detach()
        else:
            out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

        with torch.no_grad():
            mse_loss = mse(out_avg[:, :1], img_noisy_torch).item()

        losses.append(mse_loss)

        _out = out.detach().cpu().numpy()[0, :1]
        _out_avg = out_avg.detach().cpu().numpy()[0, :1]

        psnr_noisy = compare_psnr(img_noisy_np, _out)
        psnr_gt = compare_psnr(img_np, _out)
        psnr_gt_sm = compare_psnr(img_np, _out_avg)

        ssim_noisy = compare_ssim(img_noisy_np[0], _out[0])
        ssim_gt = compare_ssim(img_np[0], _out[0])
        ssim_gt_sm = compare_ssim(img_np[0], _out_avg[0])

        psnrs.append([psnr_noisy, psnr_gt, psnr_gt_sm])
        ssims.append([ssim_noisy, ssim_gt, ssim_gt_sm])

        if PLOT and i % show_every == 0:
            print(
                f'Iteration: {i} Loss: {_loss.item():.4f} PSNR_noisy: {psnr_noisy:.4f} PSNR_gt: {psnr_gt:.4f} PSNR_gt_sm: {psnr_gt_sm:.4f}'
            )

            out_np = _out
            recons.append(out_np)
            out_np_ale = out.detach().cpu().numpy()[0, 1:]
            out_np_var = np.var(np.array(recons[-mc_iter:]), axis=0)[:1]

            print('mean epi', out_np_var.mean())
            print('mean ale', out_np_ale.mean())
            print('###################')

            uncerts.append(out_np_var)
            uncerts_ale.append(out_np_ale)

        i += 1

        return _loss

    if '../NORMAL-4951060-8.jpeg':  # always truthy; see the note above
        net = get_net(input_depth,
                      NET_TYPE,
                      pad,
                      skip_n33d=skip_n33d,
                      skip_n33u=skip_n33u,
                      skip_n11=skip_n11,
                      num_scales=num_scales,
                      n_channels=2,
                      upsample_mode=upsample_mode,
                      dropout_mode_down=dropout_mode_down,
                      dropout_p_down=dropout_p_down,
                      dropout_mode_up=dropout_mode_up,
                      dropout_p_up=dropout_p_up,
                      dropout_mode_skip=dropout_mode_skip,
                      dropout_p_skip=dropout_p_skip,
                      dropout_mode_output=dropout_mode_output,
                      dropout_p_output=dropout_p_output).type(dtype)
    else:
        assert False

    net.apply(init_normal)

    losses = []
    recons = []
    uncerts = []
    uncerts_ale = []
    psnrs = []
    ssims = []

    img_mean = 0
    sample_count = 0
    i = 0
    psnr_noisy_last = 0
    loss_last = 1e6
    out_avg = None
    last_net = None

    parameters = get_params(OPT_OVER, net, net_input)
    optimizer = torch.optim.AdamW(parameters, lr=LR, weight_decay=weight_decay)
    optimize(optimizer, closure_sgldnll, num_iter)

    LOSSES['sgldnll'] = losses
    RECONS['sgldnll'] = recons
    UNCERTS['sgldnll'] = uncerts
    UNCERTS_ALE['sgldnll'] = uncerts_ale
    PSNRS['sgldnll'] = psnrs
    SSIMS['sgldnll'] = ssims

    to_plot = [img_np] + [np.clip(img, 0, 1) for img in RECONS['sgldnll']]
    q = plot_image_grid(to_plot, factor=13)

    out_pil = np_to_pil(q)
    out_pil.save(f'{save_path}/{timestamp}/sgldnll_recons.png', 'PNG')

    errs = img_noisy_torch.cpu() - torch.tensor(RECONS['sgldnll'][-1])
    uncerts_epi = torch.tensor(UNCERTS['sgldnll'][-1]).unsqueeze(0)
    uncerts_ale = torch.tensor(UNCERTS_ALE['sgldnll'][-1]).unsqueeze(0)
    uncerts = uncerts_epi + uncerts_ale
    uce, err, uncert, freq = uceloss(errs**2, uncerts, n_bins=21)
    fig, ax = plot_uncert(err, uncert, freq, outlier_freq=0.001)
    ax.set_title(
        f'U = {uncerts.mean().sqrt().item():.4f}, UCE = {uce.item()*100:.3f}')
    plt.tight_layout()
    fig.savefig(f'{save_path}/{timestamp}/sgldnll_calib.png')

    errs = torch.tensor(img_np).unsqueeze(0) - torch.tensor(
        RECONS['sgldnll'][-1])
    uncerts_epi = torch.tensor(UNCERTS['sgldnll'][-1]).unsqueeze(0)
    uncerts_ale = torch.tensor(UNCERTS_ALE['sgldnll'][-1]).unsqueeze(0)
    uncerts = uncerts_epi + uncerts_ale
    uce, err, uncert, freq = uceloss(errs**2, uncerts, n_bins=21)
    fig, ax = plot_uncert(err, uncert, freq, outlier_freq=0.001)
    ax.set_title(
        f'U = {uncerts.mean().sqrt().item():.4f}, UCE = {uce.item()*100:.3f}')
    plt.tight_layout()
    fig.savefig(f'{save_path}/{timestamp}/sgldnll_calib2.png')

    ## MCDIP

    OPTIMIZER = 'adamw'
    weight_decay = 1e-4
    LOSS = 'nll'
    input_depth = 32
    figsize = 4

    NET_TYPE = 'skip'

    skip_n33d = 128
    skip_n33u = 128
    skip_n11 = 4
    num_scales = 5
    upsample_mode = 'bilinear'

    dropout_mode_down = '2d'
    dropout_p_down = 0.3
    dropout_mode_up = '2d'
    dropout_p_up = dropout_p_down
    dropout_mode_skip = 'None'
    dropout_p_skip = dropout_p_down
    dropout_mode_output = 'None'
    dropout_p_output = dropout_p_down

    net_input = get_noise(
        input_depth, INPUT,
        (img_pil.size[1], img_pil.size[0])).type(dtype).detach()

    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()

    mc_iter = 25

    def closure_mcdip():

        global i, out_avg, psnr_noisy_last, last_net, net_input, losses, psnrs, ssims, average_dropout_rate, no_layers,\
               img_mean, sample_count, recons, uncerts, uncerts_ale, loss_last

        if reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)

        out = net(net_input)
        out[:, :1] = out[:, :1].sigmoid()

        _loss = gaussian_nll(out[:, :1], out[:, 1:], img_noisy_torch)
        _loss.backward()

        out[:, 1:] = torch.exp(-out[:, 1:])  # aleatoric uncertainty

        # Smoothing
        if out_avg is None:
            out_avg = out.detach()
        else:
            out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

        losses.append(mse(out_avg[:, :1], img_noisy_torch).item())

        _out = out.detach().cpu().numpy()[0, :1]
        _out_avg = out_avg.detach().cpu().numpy()[0, :1]

        psnr_noisy = compare_psnr(img_noisy_np, _out)
        psnr_gt = compare_psnr(img_np, _out)
        psnr_gt_sm = compare_psnr(img_np, _out_avg)

        ssim_noisy = compare_ssim(img_noisy_np[0], _out[0])
        ssim_gt = compare_ssim(img_np[0], _out[0])
        ssim_gt_sm = compare_ssim(img_np[0], _out_avg[0])

        psnrs.append([psnr_noisy, psnr_gt, psnr_gt_sm])
        ssims.append([ssim_noisy, ssim_gt, ssim_gt_sm])

        if PLOT and i % show_every == 0:
            print(
                f'Iteration: {i} Loss: {_loss.item():.4f} PSNR_noisy: {psnr_noisy:.4f} PSNR_gt: {psnr_gt:.4f} PSNR_gt_sm: {psnr_gt_sm:.4f}'
            )

            img_list = []
            aleatoric_list = []

            with torch.no_grad():
                net_input = net_input_saved + (noise.normal_() * reg_noise_std)

                for _ in range(mc_iter):
                    img = net(net_input)
                    img[:, :1] = img[:, :1].sigmoid()
                    img[:, 1:] = torch.exp(-img[:, 1:])
                    img_list.append(torch_to_np(img[:1]))
                    aleatoric_list.append(torch_to_np(img[:, 1:]))

            img_list_np = np.array(img_list)
            out_np = np.mean(img_list_np, axis=0)[:1]
            out_np_ale = np.mean(aleatoric_list, axis=0)[:1]
            out_np_var = np.var(img_list_np, axis=0)[:1]

            psnr_noisy = compare_psnr(img_noisy_np, out_np)
            psnr_gt = compare_psnr(img_np, out_np)

            print('mean epi', out_np_var.mean())
            print('mean ale', out_np_ale.mean())
            print('###################')

            recons.append(out_np)
            uncerts.append(out_np_var)
            uncerts_ale.append(out_np_ale)

        i += 1

        return _loss

    if '../NORMAL-4951060-8.jpeg':  # always truthy; see the note above
        net = get_net(input_depth,
                      NET_TYPE,
                      pad,
                      skip_n33d=skip_n33d,
                      skip_n33u=skip_n33u,
                      skip_n11=skip_n11,
                      num_scales=num_scales,
                      n_channels=2,
                      upsample_mode=upsample_mode,
                      dropout_mode_down=dropout_mode_down,
                      dropout_p_down=dropout_p_down,
                      dropout_mode_up=dropout_mode_up,
                      dropout_p_up=dropout_p_up,
                      dropout_mode_skip=dropout_mode_skip,
                      dropout_p_skip=dropout_p_skip,
                      dropout_mode_output=dropout_mode_output,
                      dropout_p_output=dropout_p_output).type(dtype)
    else:
        assert False

    net.apply(init_normal)

    losses = []
    recons = []
    uncerts = []
    uncerts_ale = []
    psnrs = []
    ssims = []

    img_mean = 0
    sample_count = 0
    i = 0
    psnr_noisy_last = 0
    loss_last = 1e16
    out_avg = None
    last_net = None

    parameters = get_params(OPT_OVER, net, net_input)
    optimizer = torch.optim.AdamW(parameters, lr=LR, weight_decay=weight_decay)
    optimize(optimizer, closure_mcdip, num_iter)

    LOSSES['mcdip'] = losses
    RECONS['mcdip'] = recons
    UNCERTS['mcdip'] = uncerts
    UNCERTS_ALE['mcdip'] = uncerts_ale
    PSNRS['mcdip'] = psnrs
    SSIMS['mcdip'] = ssims

    # In[75]:

    to_plot = [img_np] + [np.clip(img, 0, 1) for img in RECONS['mcdip']]
    q = plot_image_grid(to_plot, factor=13)

    out_pil = np_to_pil(q)
    out_pil.save(f'{save_path}/{timestamp}/mcdip_recons.png', 'PNG')

    # In[85]:

    errs = img_noisy_torch.cpu() - torch.tensor(RECONS['mcdip'][-1])
    uncerts_epi = torch.tensor(UNCERTS['mcdip'][-1]).unsqueeze(0)
    uncerts_ale = torch.tensor(UNCERTS_ALE['mcdip'][-1]).unsqueeze(0)
    uncerts = uncerts_epi + uncerts_ale
    uce, err, uncert, freq = uceloss(errs**2, uncerts, n_bins=21)
    fig, ax = plot_uncert(err, uncert, freq, outlier_freq=0.001)
    ax.set_title(
        f'U = {uncerts.mean().sqrt().item():.4f}, UCE = {uce.item()*100:.3f}')
    plt.tight_layout()
    fig.savefig(f'{save_path}/{timestamp}/mcdip_calib.png')

    # In[86]:

    errs = torch.tensor(img_np).unsqueeze(0) - torch.tensor(
        RECONS['mcdip'][-1])
    uncerts_epi = torch.tensor(UNCERTS['mcdip'][-1]).unsqueeze(0)
    uncerts_ale = torch.tensor(UNCERTS_ALE['mcdip'][-1]).unsqueeze(0)
    uncerts = uncerts_epi + uncerts_ale
    uce, err, uncert, freq = uceloss(errs**2, uncerts, n_bins=21)
    fig, ax = plot_uncert(err, uncert, freq, outlier_freq=0.001)
    ax.set_title(
        f'U = {uncerts.mean().sqrt().item():.4f}, UCE = {uce.item()*100:.3f}')
    plt.tight_layout()
    fig.savefig(f'{save_path}/{timestamp}/mcdip_calib2.png')

    fig, ax0 = plt.subplots(1, 1)

    for key, loss in LOSSES.items():
        ax0.plot(range(len(loss)), loss, label=key)
        ax0.set_title('MSE')
        ax0.set_xlabel('iteration')
        ax0.set_ylabel('mse loss')
        ax0.set_ylim(0, 0.03)
        ax0.grid(True)
        ax0.legend()

    plt.tight_layout()
    plt.savefig(f'{save_path}/{timestamp}/losses.png')
    plt.show()

    fig, axs = plt.subplots(1, 3, constrained_layout=True)
    labels = ["psnr_noisy", "psnr_gt", "psnr_gt_sm"]

    for key, psnr in PSNRS.items():
        psnr = np.array(psnr)
        for i in range(psnr.shape[1]):
            axs[i].plot(range(psnr.shape[0]), psnr[:, i], label=key)
            axs[i].set_title(labels[i])
            axs[i].set_xlabel('iteration')
            axs[i].set_ylabel('psnr')
            axs[i].legend()

    plt.savefig(f'{save_path}/{timestamp}/psnrs.png')
    plt.show()

    fig, axs = plt.subplots(1, 3, constrained_layout=True)
    labels = ["ssim_noisy", "ssim_gt", "ssim_gt_sm"]

    for key, ssim in SSIMS.items():
        ssim = np.array(ssim)
        for i in range(ssim.shape[1]):
            axs[i].plot(range(ssim.shape[0]), ssim[:, i], label=key)
            axs[i].set_title(labels[i])
            axs[i].set_xlabel('iteration')
            axs[i].legend()
            axs[i].set_ylabel('ssim')

    plt.savefig(f'{save_path}/{timestamp}/ssims.png')
    plt.show()

    # save stuff for plotting
    if save:
        np.savez(f"{save_path}/{timestamp}/save.npz",
                 noisy_img=img_noisy_np,
                 losses=LOSSES,
                 recons=RECONS,
                 uncerts=UNCERTS,
                 uncerts_ale=UNCERTS_ALE,
                 psnrs=PSNRS,
                 ssims=SSIMS)
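`gaussian_nll` comes from the project's utilities and is not shown. The closures pass the second output channel as a log-precision s and afterwards read exp(-s) back as the aleatoric variance, which is consistent with this sketch (an assumption, not the repository's exact code):

import torch

def gaussian_nll(mu, log_prec, x):
    # Gaussian negative log-likelihood with s = log(1/sigma^2):
    # 0.5 * E[exp(s) * (x - mu)^2 - s]; the variance is recovered as exp(-s)
    return 0.5 * torch.mean(torch.exp(log_prec) * (x - mu) ** 2 - log_prec)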
Example #8
def main(args: APNamespace):
    root_path = Path(args.root).expanduser()
    config_path = Path(args.config).expanduser()
    data_path = root_path / Path(args.data).expanduser()
    output_path = root_path / Path(args.output).expanduser()
    global checkpoint_path, config
    checkpoint_path = root_path / Path(args.checkpoint).expanduser()

    if not config_path.exists():
        # logging.critical(f"AdaS: Config path {config_path} does not exist")
        print(f"AdaS: Config path {config_path} does not exist")
        raise ValueError
    if not data_path.exists():
        print(f"AdaS: Data dir {data_path} does not exist, building")
        data_path.mkdir(exist_ok=True, parents=True)
    if not output_path.exists():
        print(f"AdaS: Output dir {output_path} does not exist, building")
        output_path.mkdir(exist_ok=True, parents=True)
    if not checkpoint_path.exists():
        if args.resume:
            print(f"AdaS: Cannot resume from checkpoint without specifying " +
                  "checkpoint dir")
            raise ValueError
        checkpoint_path.mkdir(exist_ok=True, parents=True)
    with config_path.open() as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    print("Adas: Argument Parser Options")
    print("-" * 45)
    print(f"    {'config':<20}: {args.config:<40}")
    print(f"    {'data':<20}: {str(Path(args.root) / args.data):<40}")
    print(f"    {'output':<20}: {str(Path(args.root) / args.output):<40}")
    print(f"    {'checkpoint':<20}: " +
          f"{str(Path(args.root) / args.checkpoint):<40}")
    print(f"    {'root':<20}: {args.root:<40}")
    print(f"    {'resume':<20}: {'True' if args.resume else 'False':<20}")
    print("\nAdas: Train: Config")
    print(f"    {'Key':<20} {'Value':<20}")
    print("-" * 45)
    for k, v in config.items():
        print(f"    {k:<20} {v:<20}")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"AdaS: Pytorch device is set to {device}")
    global best_acc
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    if np.less(float(config['early_stop_threshold']), 0):
        print("AdaS: Notice: early stop will not be used as it was set to " +
              f"{config['early_stop_threshold']}, training till completion.")

    for trial in range(config['n_trials']):
        if config['lr_scheduler'] == 'AdaS':
            filename = \
                f"stats_{config['optim_method']}_AdaS_trial={trial}_" +\
                f"beta={config['beta']}_initlr={config['init_lr']}_" +\
                f"net={config['network']}_dataset={config['dataset']}.csv"
        else:
            filename = \
                f"stats_{config['optim_method']}_{config['lr_scheduler']}_" +\
                f"trial={trial}_initlr={config['init_lr']}" +\
                f"net={config['network']}_dataset={config['dataset']}.csv"
        Profiler.filename = output_path / filename
        # Data
        # logging.info("Adas: Preparing Data")
        train_loader, test_loader = get_data(
            root=data_path,
            dataset=config['dataset'],
            mini_batch_size=config['mini_batch_size'])
        global performance_statistics, net, metrics, adas
        performance_statistics = {}

        # logging.info("AdaS: Building Model")
        net = get_net(config['network'],
                      num_classes=10 if config['dataset'] == 'CIFAR10' else
                      100 if config['dataset'] == 'CIFAR100' else
                      1000 if config['dataset'] == 'ImageNet' else 10)
        metrics = Metrics(list(net.parameters()), p=config['p'])
        if config['lr_scheduler'] == 'AdaS':
            adas = AdaS(parameters=list(net.parameters()),
                        beta=config['beta'],
                        zeta=config['zeta'],
                        init_lr=float(config['init_lr']),
                        min_lr=float(config['min_lr']),
                        p=config['p'])

        net = net.to(device)

        global criterion
        criterion = get_loss(config['loss'])

        optimizer, scheduler = get_optimizer_scheduler(
            net_parameters=net.parameters(),
            init_lr=float(config['init_lr']),
            optim_method=config['optim_method'],
            lr_scheduler=config['lr_scheduler'],
            train_loader_len=len(train_loader),
            max_epochs=int(config['max_epoch']))
        early_stop = EarlyStop(patience=int(config['early_stop_patience']),
                               threshold=float(config['early_stop_threshold']))

        if device == 'cuda':
            net = torch.nn.DataParallel(net)
            cudnn.benchmark = True

        if args.resume:
            # Load checkpoint.
            print("Adas: Resuming from checkpoint...")
            checkpoint = torch.load(str(checkpoint_path / 'ckpt.pth'))
            # if checkpoint_path.is_dir():
            #     checkpoint = torch.load(str(checkpoint_path / 'ckpt.pth'))
            # else:
            #     checkpoint = torch.load(str(checkpoint_path))
            net.load_state_dict(checkpoint['net'])
            best_acc = checkpoint['acc']
            start_epoch = checkpoint['epoch']
            if adas is not None:
                metrics.historical_metrics = \
                    checkpoint['historical_io_metrics']

        # model_parameters = filter(lambda p: p.requires_grad,
        #                           net.parameters())
        # params = sum([np.prod(p.size()) for p in model_parameters])
        # print(params)
        epochs = range(start_epoch, start_epoch + config['max_epoch'])
        for epoch in epochs:
            start_time = time.time()
            # print(f"AdaS: Epoch {epoch}/{epochs[-1]} Started.")
            train_loss, train_accuracy, test_loss, test_accuracy = epoch_iteration(
                train_loader, test_loader, epoch, device, optimizer, scheduler)
            end_time = time.time()
            if config['lr_scheduler'] == 'StepLR':
                scheduler.step()
            total_time = time.time()
            print(
                f"AdaS: Trial {trial}/{config['n_trials'] - 1} | " +
                f"Epoch {epoch}/{epochs[-1]} Ended | " +
                "Total Time: {:.3f}s | ".format(total_time - start_time) +
                "Epoch Time: {:.3f}s | ".format(end_time - start_time) +
                "~Time Left: {:.3f}s | ".format(
                    (total_time - start_time) * (epochs[-1] - epoch)),
                "Train Loss: {:.4f}% | Train Acc. {:.4f}% | ".format(
                    train_loss, train_accuracy) +
                "Test Loss: {:.4f}% | Test Acc. {:.4f}%".format(
                    test_loss, test_accuracy))
            df = pd.DataFrame(data=performance_statistics)
            if config['lr_scheduler'] == 'AdaS':
                xlsx_name = \
                    f"{config['optim_method']}_AdaS_trial={trial}_" +\
                    f"beta={config['beta']}_initlr={config['init_lr']}_" +\
                    f"net={config['network']}_dataset={config['dataset']}.xlsx"
            else:
                xlsx_name = \
                    f"{config['optim_method']}_{config['lr_scheduler']}_" +\
                    f"trial={trial}_initlr={config['init_lr']}" +\
                    f"net={config['network']}_dataset={config['dataset']}.xlsx"

            df.to_excel(str(output_path / xlsx_name))
            if early_stop(train_loss):
                print("AdaS: Early stop activated.")
                break
    return
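`EarlyStop` is constructed from a patience and a threshold and is called once per epoch with the train loss, returning True when training should stop. A minimal sketch under exactly those assumptions (not the repository's implementation):

class EarlyStop:
    def __init__(self, patience, threshold):
        self.patience = patience
        self.threshold = threshold
        self.best = float('inf')
        self.count = 0

    def __call__(self, loss):
        # reset the counter on a sufficient improvement, otherwise count up
        if loss < self.best - self.threshold:
            self.best = loss
            self.count = 0
        else:
            self.count += 1
        return self.count > self.patience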
Example #9
def main():
    config = DefaultConfigs()
    train_input_root = os.path.join(config.data)
    train_labels_file = 'labels.csv'

    output_base = config.output
    if not os.path.exists(output_base):
        os.makedirs(output_base)

    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"), config.model,
        str(config.img_size), 'f' + str(config.fold)
    ])
    mask_exp_name = '-'.join(
        [config.model,
         str(config.img_size), 'f' + str(config.fold)])
    mask_exp_name = glob.glob(
        os.path.join(output_base, 'train', '*' + mask_exp_name))
    if config.resume and mask_exp_name:
        output_dir = mask_exp_name[0]  # glob.glob returns a list; take the first match
    else:
        output_dir = get_outdir(output_base, 'train', exp_name)

    batch_size = config.batch_size
    test_batch_size = config.test_batch_size
    num_epochs = config.epochs
    img_type = config.image_type
    img_size = (config.img_size, config.img_size)
    num_classes = get_tags_size(config.labels)

    torch.manual_seed(config.seed)

    dataset_train = HumanDataset(
        train_input_root,
        train_labels_file,
        train=True,
        multi_label=config.multi_label,
        img_type=img_type,
        img_size=img_size,
        fold=config.fold,
    )

    #sampler = WeightedRandomOverSampler(dataset_train.get_sample_weights())

    loader_train = data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        #sampler=sampler,
        num_workers=config.num_processes)

    dataset_eval = HumanDataset(
        train_input_root,
        train_labels_file,
        train=False,
        multi_label=config.multi_label,
        img_type=img_type,
        img_size=img_size,
        test_aug=config.tta,
        fold=config.fold,
    )

    loader_eval = data.DataLoader(dataset_eval,
                                  batch_size=test_batch_size,
                                  shuffle=False,
                                  num_workers=config.num_processes)

    #    model = model_factory.create_model(
    #        config.model,
    #        pretrained=True,
    #        num_classes=num_classes,
    #        drop_rate=config.drop,
    #        global_pool=config.gp)

    model = get_net(config.model, num_classes, config.drop, config.channels)

    if not config.no_cuda:
        if config.num_gpu > 1:
            model = torch.nn.DataParallel(model,
                                          device_ids=list(range(
                                              config.num_gpu))).cuda()
        else:
            model.cuda()

    if config.opt.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=config.lr,
                              momentum=config.momentum,
                              weight_decay=config.weight_decay)
    elif config.opt.lower() == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=config.lr,
                               weight_decay=config.weight_decay)
    elif config.opt.lower() == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(),
                                   lr=config.lr,
                                   weight_decay=config.weight_decay)
    elif config.opt.lower() == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=config.lr,
                                  alpha=0.9,
                                  momentum=config.momentum,
                                  weight_decay=config.weight_decay)
    elif config.opt.lower() == 'yellowfin':
        optimizer = YFOptimizer(model.parameters(),
                                lr=config.lr,
                                weight_decay=config.weight_decay,
                                clip_thresh=2)
    else:
        assert False, "Invalid optimizer"

    if not config.decay_epochs:
        lr_scheduler = ReduceLROnPlateau(optimizer, patience=8)
    else:
        lr_scheduler = None

    if config.class_weights:
        class_weights = torch.from_numpy(
            dataset_train.get_class_weights()).float()
        class_weights_norm = class_weights / class_weights.sum()
        if not config.no_cuda:
            class_weights = class_weights.cuda()
            class_weights_norm = class_weights_norm.cuda()
    else:
        class_weights = None
        class_weights_norm = None

    if config.loss.lower() == 'nll':
        #assert not args.multi_label and 'Cannot use crossentropy with multi-label target.'
        loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
    elif config.loss.lower() == 'mlsm':
        assert config.multi_label
        loss_fn = torch.nn.MultiLabelSoftMarginLoss(weight=class_weights)
    else:
        assert False, "Invalid loss function"

    if not config.no_cuda:
        loss_fn = loss_fn.cuda()

    # optionally resume from a checkpoint
    start_epoch = 1
    if config.resume:
        if os.path.isfile(config.resume):
            print("=> loading checkpoint '{}'".format(config.resume))
            checkpoint = torch.load(config.resume)
            config.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                config.resume, checkpoint['epoch']))
            start_epoch = checkpoint['epoch']
        else:
            print("=> no checkpoint found at '{}'".format(config.resume))
            exit(-1)

    use_tensorboard = not config.no_tb and CrayonClient is not None
    if use_tensorboard:
        hostname = '127.0.0.1'
        port = 8889
        host_port = config.tbh.split(':')[:2]
        if len(host_port) == 1:
            hostname = host_port[0]
        elif len(host_port) >= 2:
            hostname, port = host_port[:2]
        try:
            cc = CrayonClient(hostname=hostname, port=port)
            try:
                cc.remove_experiment(exp_name)
            except ValueError:
                pass
            exp = cc.create_experiment(exp_name)
        except Exception as e:
            exp = None
            print(
                "Error (%s) connecting to Tensoboard/Crayon server. Giving up..."
                % str(e))
    else:
        exp = None

    # Optional fine-tune of only the final classifier weights for specified number of epochs (or part of)
    if not config.resume and config.ft_epochs > 0.:
        if config.opt.lower() == 'adam':
            finetune_optimizer = optim.Adam(model.get_fc().parameters(),
                                            lr=config.ft_lr,
                                            weight_decay=config.weight_decay)
        else:
            finetune_optimizer = optim.SGD(model.get_fc().parameters(),
                                           lr=config.ft_lr,
                                           momentum=config.momentum,
                                           weight_decay=config.weight_decay)

        finetune_epochs_int = int(np.ceil(config.ft_epochs))
        finetune_final_batches = int(
            np.ceil((1 - (finetune_epochs_int - config.ft_epochs)) *
                    len(loader_train)))
        print(finetune_epochs_int, finetune_final_batches)
        for fepoch in range(1, finetune_epochs_int + 1):
            if fepoch == finetune_epochs_int and finetune_final_batches:
                batch_limit = finetune_final_batches
            else:
                batch_limit = 0
            train_epoch(fepoch,
                        model,
                        loader_train,
                        finetune_optimizer,
                        loss_fn,
                        config,
                        class_weights_norm,
                        output_dir,
                        batch_limit=batch_limit)
            step = fepoch * len(loader_train)
            score, _ = validate(step, model, loader_eval, loss_fn, config, 0.3,
                                output_dir)

    score_metric = 'f2'
    best_loss = None
    best_f2 = None
    threshold = 0.2
    try:
        for epoch in range(start_epoch, num_epochs + 1):
            if config.decay_epochs:
                adjust_learning_rate(optimizer,
                                     epoch,
                                     initial_lr=config.lr,
                                     decay_epochs=config.decay_epochs)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        loss_fn,
                                        config,
                                        class_weights_norm,
                                        output_dir,
                                        exp=exp)

            step = epoch * len(loader_train)
            eval_metrics, latest_threshold = validate(step,
                                                      model,
                                                      loader_eval,
                                                      loss_fn,
                                                      config,
                                                      threshold,
                                                      output_dir,
                                                      exp=exp)

            if lr_scheduler is not None:
                lr_scheduler.step(eval_metrics['eval_loss'])

            rowd = OrderedDict(epoch=epoch)
            rowd.update(train_metrics)
            rowd.update(eval_metrics)
            with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
                dw = csv.DictWriter(cf, fieldnames=rowd.keys())
                if best_loss is None:  # first write; checking epoch == 1 would fail after a resume
                    dw.writeheader()
                dw.writerow(rowd)

            best = False
            if best_loss is None or eval_metrics['eval_loss'] < best_loss[1]:
                best_loss = (epoch, eval_metrics['eval_loss'])
                if score_metric == 'loss':
                    best = True
            if best_f2 is None or eval_metrics['eval_f2'] > best_f2[1]:
                best_f2 = (epoch, eval_metrics['eval_f2'])
                if score_metric == 'f2':
                    best = True

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': config.model,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'threshold': latest_threshold,
                    'config': config
                },
                is_best=best,
                filename=os.path.join(config.checkpoint_path,
                                      'checkpoint-%d.pth.tar' % epoch),
                output_dir=output_dir)

    except KeyboardInterrupt:
        pass
    if best_loss is not None:
        print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
    if best_f2 is not None:
        print('*** Best f2: {0} (epoch {1})'.format(best_f2[1], best_f2[0]))
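
A minimal sketch of the step-decay `adjust_learning_rate` helper the loop above assumes; the `decay_rate` of 0.1 is an assumption, not taken from the original code:

def adjust_learning_rate(optimizer, epoch, initial_lr, decay_epochs, decay_rate=0.1):
    """Multiply the learning rate by decay_rate every decay_epochs epochs."""
    lr = initial_lr * (decay_rate ** ((epoch - 1) // decay_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr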
Exemplo n.º 10
0
    print("\nload model relu for second run")
    train_result_model_batch = train_and_evaluate_model(
        model_batch_norm, batch_size2, epochs3, BaseX_train, BaseY_train,
        X_test, Y_test, BaseX_val, BaseY_val)
    # plot_train_stat(train_result_model_batch)
    """Part 2"""
    input_shape = (32, 32, 1)
    learn_rate = 1e-5
    decay = 1e-03
    batch_size = 64
    epochs = 25
    drop = True
    dropRate = 0.3
    reg = 1e-2
    filters1 = [64, 128, 128, 256, 256]
    NNet1 = get_net(filters1, input_shape, drop, dropRate, reg)

    # Define the optimizer parameters:
    AdamOpt = Adam(lr=learn_rate, decay=decay)

    # Compile the network:
    NNet1.compile(optimizer=AdamOpt,
                  metrics=['acc'],
                  loss='categorical_crossentropy')

    # Saving checkpoints during training:
    # Checkpath = os.getcwd()

    h1 = NNet1.fit(x=BaseX_train,
                   y=BaseY_train,
                   batch_size=batch_size,
                   epochs=epochs,  # assumed completion; the original snippet is truncated here
                   validation_data=(BaseX_val, BaseY_val))
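
A minimal sketch of what the commented-out `plot_train_stat` helper could look like, assuming only the standard Keras History object returned by fit (the 'acc' key matches the metric compiled above):

import matplotlib.pyplot as plt

def plot_train_stat(h):
    """Plot loss and accuracy curves from a Keras History object."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    ax_loss.plot(h.history['loss'], label='train loss')
    if 'val_loss' in h.history:
        ax_loss.plot(h.history['val_loss'], label='val loss')
    ax_loss.legend()
    ax_acc.plot(h.history['acc'], label='train acc')
    if 'val_acc' in h.history:
        ax_acc.plot(h.history['val_acc'], label='val acc')
    ax_acc.legend()
    plt.show()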
Exemplo n.º 11
0
          "raccoons. baseline accuracy", print_statement)

print("x_test: ", len(x_test))
print("y_test: ", y_test.shape[0])
unique, counts = np.unique(y_test, return_counts=True)
if counts[0] >= counts[1]:
    baseline_acc = counts[0] / (counts[0] + counts[1])
    print_statement = str(round(baseline_acc, 2)) + " (always coyote)."
else:
    baseline_acc = counts[1] / (counts[0] + counts[1])
    print_statement = str(round(baseline_acc, 2)) + " (always raccoon)."
print("test class distribution: ", counts[0], "coyotes and", counts[1],
      "raccoons. baseline accuracy", print_statement)

# get model and data handler
net = get_net(dataset_name)
handler = get_handler(dataset_name)

# GPU enabled?
print("Using GPU - {}".format(torch.cuda.is_available()))

final_train_accs = []
final_test_accs = []
train_process = Train(envs, x_test, y_test, net, handler, args)

print()
if args['optimizer_args']['penalty_weight'] > 1.0:
    model_name = model_name + "IRM"
    print(
        "========================================IRM========================================"
    )
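
The majority-class baseline above can be factored into a small helper to avoid the duplicated branches (a sketch; the class names are taken from the prints above):

import numpy as np

def majority_baseline(y, class_names=('coyote', 'raccoon')):
    """Return the accuracy of always predicting the majority class, and its name."""
    _, counts = np.unique(y, return_counts=True)
    majority = int(np.argmax(counts))
    return counts[majority] / counts.sum(), class_names[majority]

# usage: acc, name = majority_baseline(y_test)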
Exemplo n.º 12
0
def denoise(fname, plot=False):
    """Add AWGN with sigma=25 to the given image and denoise it.

    Args:
        fname: Path to the image.
        plot: If True, display the noisy input and intermediate outputs.

    Returns:
        A tuple with the denoised image in numpy format as the first element,
        and a history of the PSNR in the second element.

    """
    dtype = torch.cuda.FloatTensor

    sigma = 25
    sigma_ = sigma/255.
    imsize = -1

    np.random.seed(7)

    img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_np = pil_to_np(img_pil)
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)

    if plot:
        plot_image_grid([img_np, img_noisy_np], 4, 6)

    INPUT = 'noise'  # 'meshgrid'
    pad = 'reflection'
    OPT_OVER = 'net'  # 'net,input'

    reg_noise_std = 1./30.  # set to 1./20. for sigma=50
    LR = 0.01
    exp_weight = 0.99  # Exponential averaging coefficient

    OPTIMIZER = 'adam'  # 'LBFGS'
    show_every = 500

    num_iter = 1800
    input_depth = 32
    figsize = 4

    net = get_net(input_depth, 'skip', pad,
                  skip_n33d=128,
                  skip_n33u=128,
                  skip_n11=4,
                  num_scales=5,
                  upsample_mode='bilinear').type(dtype)

    net_input = get_noise(input_depth, INPUT, (img_np.shape[1], img_np.shape[2])).type(dtype).detach()

    # Compute number of parameters
    s = sum([np.prod(list(p.size())) for p in net.parameters()])
    print('Number of params: %d' % s)

    # Loss
    mse = torch.nn.MSELoss().type(dtype)

    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)
    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()
    out_avg = None
    last_net = None
    psrn_noisy_last = 0

    i = 0
    psnr_history = []

    def closure():
        nonlocal i, out_avg, psrn_noisy_last, last_net, psnr_history, net_input

        if reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)

        out = net(net_input)

        # Smoothing
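        # Keep an exponential moving average of the network output; the
        # averaged image is typically a better denoising estimate than any
        # single iterate.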
        if exp_weight is not None:
            if out_avg is None:
                out_avg = out.detach()
            else:
                out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

        total_loss = mse(out, img_noisy_torch)
        total_loss.backward()

        psrn_noisy = compare_psnr(img_noisy_np, out.detach().cpu().numpy()[0])
        psrn_gt = compare_psnr(img_np, out.detach().cpu().numpy()[0])
        psrn_gt_sm = compare_psnr(img_np, out_avg.detach().cpu().numpy()[0])
        psnr_history.append(psrn_gt_sm)

        print('Iteration %05d    Loss %f   PSNR_noisy: %f   PSNR_gt: %f   PSNR_gt_sm: %f' % (i, total_loss.item(), psrn_noisy, psrn_gt, psrn_gt_sm), '\r', end='')
        if plot and i % show_every == 0:
            out_np = torch_to_np(out)
            plot_image_grid([np.clip(out_np, 0, 1), np.clip(torch_to_np(out_avg), 0, 1)], factor=figsize, nrow=1)

        # Backtracking
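        # If PSNR against the noisy target has dropped by more than 5 dB, the
        # optimization has likely diverged: roll the parameters back to the
        # last good snapshot instead of keeping the current step.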
        if i % show_every:
            if psrn_noisy - psrn_noisy_last < -5:
                print('Falling back to previous checkpoint.')

                for new_param, net_param in zip(last_net, net.parameters()):
                    net_param.data.copy_(new_param.cuda())

                return total_loss*0
            else:
                last_net = [x.data.cpu() for x in net.parameters()]
                psrn_noisy_last = psrn_noisy

        i += 1
        return total_loss

    p = get_params(OPT_OVER, net, net_input)
    try:
        optimize(OPTIMIZER, p, closure, LR, num_iter)
    except StopIteration:
        pass

    return torch_to_np(out_avg), psnr_history  # convert to numpy to match the docstring
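
A minimal usage sketch (the image path is a placeholder):

if __name__ == '__main__':
    import matplotlib.pyplot as plt

    denoised, history = denoise('data/example.png', plot=False)  # placeholder path
    plt.plot(history)
    plt.xlabel('iteration')
    plt.ylabel('PSNR of smoothed output vs. ground truth')
    plt.show()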