def train():

    data = load_train_data()
    data = data.reshape((data.shape[0], data.shape[1], data.shape[2], 1))
    data = data.astype('float32') / 255.0
    #pyr = Laplacian_pyramid(data.shape,6)
    #loss = pyr.distance

    # model selection
    if args.pretrain:
        model = load_model(args.pretrain, compile=False)
    elif args.model == 'DnCNN':
        model = models.DnCNN()
    else:
        raise ValueError(f"Unknown model: {args.model}")
    # compile the model
    #model.compile(optimizer=Adam(), loss=[loss])
    model.compile(optimizer=Adam(), loss=['mse'])

    # callbacks: periodic checkpointing, CSV logging and learning-rate scheduling
    ckpt = ModelCheckpoint(save_dir + '/model_{epoch:02d}.h5',
                           monitor='val_loss',
                           verbose=0,
                           period=args.save_every)
    csv_logger = CSVLogger(save_dir + '/log.csv', append=True, separator=',')
    lr = LearningRateScheduler(step_decay)
    # train
    history = model.fit_generator(train_datagen(data,
                                                batch_size=args.batch_size),
                                  steps_per_epoch=len(data) // args.batch_size,
                                  epochs=args.epoch,
                                  verbose=1,
                                  callbacks=[ckpt, csv_logger, lr])

    return model
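# --- Hedged sketch of the two helpers that train() above assumes but does not
# define here: step_decay (learning-rate schedule) and train_datagen (noisy/clean
# batch generator). The decay factors, noise level and target convention are
# assumptions for illustration, not the original project's settings.
import numpy as np


def step_decay(epoch):
    # assumed schedule: start at 1e-3 and halve the learning rate every 30 epochs
    return 1e-3 * (0.5 ** (epoch // 30))


def train_datagen(clean, batch_size=128, sigma=25.0):
    # yield (noisy_input, clean_target) batches forever, adding Gaussian noise
    # on the fly; train() pairs this generator with an 'mse' loss
    n = clean.shape[0]
    while True:
        idx = np.random.permutation(n)
        for start in range(0, n - batch_size + 1, batch_size):
            batch = clean[idx[start:start + batch_size]]
            noise = np.random.normal(0.0, sigma / 255.0, batch.shape).astype('float32')
            yield batch + noise, batch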
def createLearner(data,model_choice=2):
    if model_choice == 1:
        model = mymodels.DnCNN()
    else:
        model = mymodels.myDnCNN()
    learn = Learner(data, model, metrics=mse, loss_func=MSELossFlat())

    return learn
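# --- Hedged sketch of a plain DnCNN-style denoiser of the kind
# mymodels.DnCNN / mymodels.myDnCNN might wrap; the depth, width and
# residual formulation below are standard DnCNN choices, not the project's.
import torch
import torch.nn as nn


class SimpleDnCNN(nn.Module):
    def __init__(self, depth=17, channels=64, image_channels=1):
        super().__init__()
        layers = [nn.Conv2d(image_channels, channels, 3, padding=1),
                  nn.ReLU(inplace=True)]
        for _ in range(depth - 2):
            layers += [nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                       nn.BatchNorm2d(channels),
                       nn.ReLU(inplace=True)]
        layers.append(nn.Conv2d(channels, image_channels, 3, padding=1))
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        # residual learning: the stack predicts the noise, which is
        # subtracted from the noisy input to give the denoised image
        return x - self.body(x)


# hedged usage sketch for createLearner(); `get_databunch` is hypothetical,
# fit_one_cycle / save follow the fastai Learner API:
#   data = get_databunch(...)                    # noisy/clean image pairs
#   learn = createLearner(data, model_choice=2)
#   learn.fit_one_cycle(10, 1e-3)
#   learn.save('dncnn_denoiser')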
def main(config, resume, device):
    start_time = time.time()
    if config["dataset"] == "fastMRI":
        train_data = dataloader.KneeMRI(config["train_loader"]["target_dir"],
                                        config["train_loader"]["noise_dirs"])
        val_data = dataloader.KneeMRI(config["val_loader"]["target_dir"],
                                      config["val_loader"]["noise_dirs"])
    elif config["dataset"] == "BSD500":
        train_data = dataloader.BSD500(config["train_loader"]["target_dir"],
                                       config["sigma"])
        val_data = dataloader.BSD500(config["val_loader"]["target_dir"],
                                     config["sigma"])
    trainloader = DataLoader(train_data,
                             batch_size=config["train_loader"]["batch_size"],
                             shuffle=config["train_loader"]["shuffle"],
                             num_workers=config["train_loader"]["num_workers"])
    valloader = DataLoader(val_data,
                           batch_size=config["val_loader"]["batch_size"],
                           shuffle=config["val_loader"]["shuffle"],
                           num_workers=config["val_loader"]["num_workers"])
    experim_dir = os.path.join(config["trainer"]['save_dir'],
                               config['experim_name'])
    if not os.path.exists(experim_dir):
        os.makedirs(experim_dir)
    for seed in config["seeds"]:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        model = models.DnCNN(
            config,
            depth=config["model"]["depth"],
            n_channels=config["model"]["n_channels"],
            image_channels=config["model"]["image_channels"],
            kernel_size=config["model"]["kernel_size"],
            padding=config["model"]["padding"],
            architecture=config["model"]["architecture"],
            spectral_norm=config["model"]["spectral_norm"],
            shared_activation=config["model"]["shared_activation"],
            shared_channels=config["model"]["shared_channels"],
            device=device)
        if config["dataset"] == "BSD500":
            model.apply(weights_init_kaiming)
        train_logger = Logger()
        trainer = Trainer(config, trainloader, valloader, model, train_logger,
                          seed, resume, device)
        trainer.train()
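# --- Hedged example of the JSON config that main() above expects. Every key
# shown is read by the code; the values and directory paths are placeholders.
example_config = {
    "dataset": "BSD500",
    "sigma": 25,
    "experim_name": "dncnn_bsd500",
    "seeds": [42],
    "train_loader": {"target_dir": "data/BSD500/train",
                     "batch_size": 32, "shuffle": True, "num_workers": 4},
    "val_loader": {"target_dir": "data/BSD500/val",
                   "batch_size": 8, "shuffle": False, "num_workers": 4},
    "trainer": {"save_dir": "experiments"},
    "model": {"depth": 17, "n_channels": 64, "image_channels": 1,
              "kernel_size": 3, "padding": 1, "architecture": "residual",
              "spectral_norm": False, "shared_activation": False,
              "shared_channels": False},
}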
if __name__ == '__main__':

    # ---- input arguments ----
    args = parse_arguments()
    # CONFIG -> make sure a config file was provided
    assert args.config
    config = json.load(open(args.config))

    # ---- load the model ----
    model = models.DnCNN(
        config,
        depth=config["model"]["depth"],
        n_channels=config["model"]["n_channels"],
        image_channels=config["model"]["image_channels"],
        kernel_size=config["model"]["kernel_size"],
        padding=config["model"]["padding"],
        architecture=config["model"]["architecture"],
        spectral_norm=config["model"]["spectral_norm"],
        shared_activation=config["model"]["shared_activation"],
        shared_channels=config["model"]["shared_channels"],
        device=args.device)
    device = args.device
    checkpoint = torch.load(args.model, device)
    if device == 'cpu':
        for key in list(checkpoint['state_dict'].keys()):
            if 'module.' in key:
                checkpoint['state_dict'][key.replace(
                    'module.', '')] = checkpoint['state_dict'][key]
                del checkpoint['state_dict'][key]

    try:
        model.load_state_dict(checkpoint['state_dict'], strict=True)
    except Exception as e:
        print(f'Some modules are missing: {e}')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
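# --- Why the 'module.' stripping above is needed: wrapping a model in
# nn.DataParallel prefixes every state_dict key with 'module.', so a
# checkpoint saved from the wrapped model has to be renamed before it fits
# the bare model on CPU. Small illustration (not project code):
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(1, 8, 3))
print(list(net.state_dict().keys()))                   # ['0.weight', '0.bias']
print(list(nn.DataParallel(net).state_dict().keys()))  # ['module.0.weight', 'module.0.bias']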
def main():
    # parse the command-line arguments
    args = parse_arguments()

    # CONFIG -> make sure a config file was provided
    assert args.config
    config = json.load(open(args.config))

    # DATA
    if config["dataset"] == "fastMRI":
        testdataset = dataloader.KneeMRI(config["test_loader"]["target_dir"],
                                         config["test_loader"]["noise_dirs"])
    elif config["dataset"] == "BSD500":
        testdataset = dataloader.BSD500(config["test_loader"]["target_dir"],
                                        config["sigma"])

    testloader = DataLoader(testdataset,
                            batch_size=config["test_loader"]["batch_size"],
                            shuffle=config["test_loader"]["shuffle"],
                            num_workers=config["test_loader"]["num_workers"])

    # MODEL
    model = models.DnCNN(
        config,
        depth=config["model"]["depth"],
        n_channels=config["model"]["n_channels"],
        image_channels=config["model"]["image_channels"],
        kernel_size=config["model"]["kernel_size"],
        padding=config["model"]["padding"],
        architecture=config["model"]["architecture"],
        spectral_norm=config["model"]["spectral_norm"],
        shared_activation=config["model"]["shared_activation"],
        shared_channels=config["model"]["shared_channels"],
        device=args.device)
    device = args.device
    checkpoint = torch.load(args.model, device)
    if config["dataset"] == "fastMRI":
        criterion = torch.nn.MSELoss(reduction="sum")
    elif config["dataset"] == "BSD500":
        # size_average=False is deprecated; reduction="sum" is the equivalent
        criterion = torch.nn.MSELoss(reduction="sum")

    if device == 'cpu':
        for key in list(checkpoint['state_dict'].keys()):
            if 'module.' in key:
                checkpoint['state_dict'][key.replace(
                    'module.', '')] = checkpoint['state_dict'][key]
                del checkpoint['state_dict'][key]

    try:
        model.load_state_dict(checkpoint['state_dict'], strict=True)
    except Exception as e:
        print(f'Some modules are missing: {e}')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.float()
    model.eval()
    if args.device != 'cpu':
        model.to(device)

    check_directory(args.experiment)
    # LOOP OVER THE DATA
    tbar = tqdm(testloader, ncols=100)

    total_loss_val = AverageMeter()
    Signal = []
    Noise = []

    with torch.no_grad():
        for batch_idx, data in enumerate(tbar):
            if config["dataset"] == "fastMRI":
                cropp1, cropp2, cropp3, cropp4, target1, target2, target3, target4, image_id = data
                cropp = torch.cat([cropp1, cropp2, cropp3, cropp4], dim=0)
                target = torch.cat([target1, target2, target3, target4], dim=0)
            elif config["dataset"] == "BSD500":
                cropp, target = data
            if args.device != 'cpu':
                cropp = cropp.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
            batch_size = cropp.shape[0]
            output = model(cropp)

            # LOSS
            if config["dataset"] == "fastMRI":
                loss = criterion(output, target) / batch_size
            elif config["dataset"] == "BSD500":
                loss = criterion(output, target) / (output.size()[0] * 2)
            total_loss_val.update(loss.cpu())

            # PRINT INFO
            tbar.set_description('EVAL | MSELoss: {:.5f} |'.format(
                total_loss_val.average))

            # save the images
            # move tensors back to host memory before converting to numpy
            output = output.cpu().numpy()
            target = target.cpu().numpy()
            cropp = cropp.cpu().numpy()
            output = np.squeeze(output, axis=1)
            target = np.squeeze(target, axis=1)
            cropp = np.squeeze(cropp, axis=1)
            for idx in range(batch_size):
                signal = np.linalg.norm(target[idx].flatten())
                noise = np.linalg.norm(output[idx].flatten() -
                                       target[idx].flatten())
                Signal.append(signal)
                Noise.append(noise)
            output = batch_scale(output)
            target = batch_scale(target)
            cropp = batch_scale(cropp)
            for i in range(output.shape[0]):
                j = math.floor(i / 4)
                cv.imwrite(
                    f'{args.experiment}/test_result/{image_id[j][:-4]}_{i%4}_prediction.png',
                    output[i])
                cv.imwrite(
                    f'{args.experiment}/test_result/{image_id[j][:-4]}_{i%4}_target.png',
                    target[i])
                cv.imwrite(
                    f'{args.experiment}/test_result/{image_id[j][:-4]}_{i % 4}_input.png',
                    cropp[i])

        Signal = np.asarray(Signal)
        Noise = np.asarray(Noise)
        SNR = 20 * np.log10(Signal.mean() / Noise.mean())
        print("The mean SNR over the test set is : {}".format(SNR))
        # save the metric
        metrics = {"MSE_Loss": np.round(total_loss_val.average, 8), "SNR": SNR}

        with open(f'{args.experiment}/test_result/test.txt', 'w') as f:
            for k, v in list(metrics.items()):
                f.write("%s\n" % (k + ':' + f'{v}'))

        lipschitz_cte = SingularValues(model)
        lipschitz_cte.compute_layer_sv()
        merged_list = list(
            map(lambda x, y: (x, y), lipschitz_cte.names,
                lipschitz_cte.sigmas))
        metrics = {"Layer": merged_list}
        activation = config["model"]["activation_type"]
        QP = config["model"]["QP"]
        if (activation != "relu") and (activation != "leaky_relu") and (
                activation != "prelu"):
            C, slope = model.lipschtiz_exact()
            spline = {}
            splineAll = {}
            for i in range(len(C)):
                spline[f"activation_{i}"] = C[i]
                splineAll[f"activation_{i}"] = slope[i]

            with open(
                    f'{args.experiment}/test_result/{QP}_ActivationSlopes.txt',
                    'w') as f:
                for k, v in list(spline.items()):
                    f.write("%s\n" % (k + ':' + f'{v}'))
                    torch.save(v, f"{args.experiment}/test_result/{QP}_{k}.pt")

            with open(
                    f'{args.experiment}/test_result/{QP}_ActivationSlopesAll.txt',
                    'w') as f:
                for k, v in list(splineAll.items()):
                    f.write("%s\n" % (k + ':' + f'{v}'))
                    torch.save(
                        v, f"{args.experiment}/test_result/{QP}_{k}_All.pt")

        with open(f'{args.experiment}/test_result/sv.txt', 'w') as f:
            for k, v in list(metrics.items()):
                for t in v:
                    f.write("%s\n" % (k + ':' + f'{t[0]}' + ':' + f'{t[1]}'))
def main():
    # parse the command-line arguments
    args = parse_arguments()

    # CONFIG -> make sure a config file was provided
    assert args.config
    config = json.load(open(args.config))

    # MODEL
    model = models.DnCNN(config,
                         depth=config["model"]["depth"],
                         n_channels=config["model"]["n_channels"],
                         image_channels=config["model"]["image_channels"],
                         kernel_size=config["model"]["kernel_size"],
                         padding=config["model"]["padding"],
                         architecture=config["model"]["architecture"],
                         spectral_norm=config["model"]["spectral_norm"],
                         device=args.device)

    if args.activation:
        config["model"]["activation_type"] = args.activation
    if args.QP:
        config["model"]["QP"] = args.QP

    model_QP = models.DnCNN(config,
                            depth=config["model"]["depth"],
                            n_channels=config["model"]["n_channels"],
                            image_channels=config["model"]["image_channels"],
                            kernel_size=config["model"]["kernel_size"],
                            padding=config["model"]["padding"],
                            architecture=config["model"]["architecture"],
                            spectral_norm=config["model"]["spectral_norm"],
                            device=args.device)
    device = args.device
    checkpoint = torch.load(args.model, device)

    if device == 'cpu':
        for key in list(checkpoint['state_dict'].keys()):
            if 'module.' in key:
                checkpoint['state_dict'][key.replace(
                    'module.', '')] = checkpoint['state_dict'][key]
                del checkpoint['state_dict'][key]

    try:
        model.load_state_dict(checkpoint['state_dict'], strict=True)
    except Exception as e:
        print(f'Some modules are missing: {e}')
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    # copy the trained weights from `model` into `model_QP`
    # (parameter names match; only the activation configuration differs)
    for name, param in model.dncnn.named_parameters():
        for name_QP, param_QP in model_QP.dncnn.named_parameters():
            if name == name_QP:
                param_QP.data = param.data

    model.float()
    model.eval()
    if args.device != 'cpu':
        model.to(device)

    model_QP.float()
    model_QP.eval()
    if args.device != 'cpu':
        model_QP.to(device)

    check_directory(args.experiment)

    if config["model"]["activation_type"] != "deepBspline":
        start_time = time.time()
        for module in model_QP.modules_deepspline():
            module.do_lipschitz_projection()
        t = time.time() - start_time
        print("--- %s seconds ---" % t)
        with open(f'{args.experiment}/QP_result/{args.QP}_time.txt', 'w') as f:
            f.write("%s\n" % ('Time :' + f'{t}'))

    lipschitz_cte = SingularValues(model_QP)
    lipschitz_cte.compute_layer_sv()
    merged_list = list(
        map(lambda x, y: (x, y), lipschitz_cte.names, lipschitz_cte.sigmas))
    metrics = {"Layer": merged_list}
    activation = config["model"]["activation_type"]
    if (activation != "relu") or (activation != "leaky_relu") or (activation !=
                                                                  "prelu"):
        C, slope = model_QP.lipschtiz_exact()
        spline = {}
        splineAll = {}
        for i in range(len(C)):
            spline[f"activation_{i}"] = C[i]
            splineAll[f"activation_{i}"] = slope[i]

        with open(
                f'{args.experiment}/QP_result/{args.QP}_ActivationSlopes.txt',
                'w') as f:
            for k, v in list(spline.items()):
                f.write("%s\n" % (k + ':' + f'{v}'))
                torch.save(v, f"{args.experiment}/QP_result/{args.QP}_{k}.pt")

        with open(
                f'{args.experiment}/QP_result/{args.QP}_ActivationSlopesAll.txt',
                'w') as f:
            for k, v in list(splineAll.items()):
                f.write("%s\n" % (k + ':' + f'{v}'))
                torch.save(
                    v, f"{args.experiment}/QP_result/{args.QP}_{k}_All.pt")

    with open(f'{args.experiment}/QP_result/{args.QP}_sv.txt', 'w') as f:
        for k, v in list(metrics.items()):
            for t in v:
                f.write("%s\n" % (k + ':' + f'{t[0]}' + ':' + f'{t[1]}'))