Beispiel #1
0
def run(args):
    """Train a denoising network on noisy BSDS images.

    Builds the train/test ``NoisyBSDSDataset`` pair, the network selected by
    ``args.model``, an Adam optimizer and a ``DenoisingStatsManager``, wires
    them into an ``nt.Experiment`` and runs it for ``args.num_epochs`` epochs.
    When ``args.plot`` is set, a 2x2 figure is refreshed through ``plot``.

    Args:
        args: parsed options; reads ``root_dir``, ``image_size``,
            ``test_image_size``, ``sigma``, ``model``, ``D``, ``C``, ``lr``,
            ``batch_size``, ``output_dir``, ``plot`` and ``num_epochs``.

    Raises:
        NameError: if ``args.model`` is not ``dncnn``, ``udncnn`` or
            ``dudncnn``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Both splits share root directory and noise level; mode and size differ.
    train_set = NoisyBSDSDataset(args.root_dir,
                                 image_size=args.image_size,
                                 sigma=args.sigma)
    test_set = NoisyBSDSDataset(args.root_dir,
                                mode='test',
                                image_size=args.test_image_size,
                                sigma=args.sigma)

    # Lambdas defer name lookup, so only the requested architecture is touched.
    builders = {
        'dncnn': lambda: DnCNN(args.D, C=args.C),
        'udncnn': lambda: UDnCNN(args.D, C=args.C),
        'dudncnn': lambda: DUDnCNN(args.D, C=args.C),
    }
    build = builders.get(args.model)
    if build is None:
        raise NameError('Please enter: dncnn, udncnn, or dudncnn')
    net = build().to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    stats = DenoisingStatsManager()

    experiment = nt.Experiment(net,
                               train_set,
                               test_set,
                               optimizer,
                               stats,
                               batch_size=args.batch_size,
                               output_dir=args.output_dir,
                               perform_validation_during_training=True)

    if not args.plot:
        experiment.run(num_epochs=args.num_epochs)
        return

    fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(9, 7))
    # Test item 73 is fetched inside the callback, i.e. anew on every call.
    experiment.run(num_epochs=args.num_epochs,
                   plot=lambda e: plot(
                       e, fig=fig, axes=axes, noisy=test_set[73][0]))
Beispiel #2
0
    if opt.downsampling_factor is not None:
        opt.downsampling_factor = list(
            map(lambda x: int(x), opt.downsampling_factor.split(',')))

    if opt.jpeg_quality is not None:
        opt.jpeg_quality = list(
            map(lambda x: int(x), opt.jpeg_quality.split(',')))

    if not os.path.exists(opt.outputs_dir):
        os.makedirs(opt.outputs_dir)

    torch.manual_seed(opt.seed)

    if opt.arch == 'DnCNN-S':
        model = DnCNN(num_layers=17)
    elif opt.arch == 'DnCNN-B':
        model = DnCNN(num_layers=20)
    elif opt.arch == 'DnCNN-3':
        model = DnCNN(num_layers=20)

    model = model.to(device)
    criterion = nn.MSELoss(reduction='sum')

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    dataset = Dataset(opt.images_dir, opt.patch_size, opt.gaussian_noise_level,
                      opt.downsampling_factor, opt.jpeg_quality,
                      opt.use_fast_loader)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batch_size,
def main(_):
    """Prepare output directories, then train or test the TensorFlow DnCNN.

    Reads the module-level ``args`` namespace:
    ``ckpt_dir``/``sample_dir``/``test_dir`` are created when missing;
    ``use_gpu`` selects a session config capped at 80% of GPU memory;
    ``sigma``, ``lr`` and ``trainset`` configure the model; ``phase == 'train'``
    trains, any other value runs ``model.test()``.

    The positional argument is ignored.
    """
    for directory in (args.ckpt_dir, args.sample_dir, args.test_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # CPU and GPU share one code path; only the session config differs
    # (tf.Session(config=None) is the default session).
    config = None
    if args.use_gpu:
        # added to control the gpu memory
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
        config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=config) as sess:
        model = DnCNN(sess,
                      sigma=args.sigma,
                      lr=args.lr,
                      dataset=args.trainset)
        if args.phase == 'train':
            model.train()
        else:
            model.test()
Beispiel #4
0
        (torch.tensor(np.real(np.fft.ifft2(Low_freq))).unsqueeze(1).float(),
         torch.tensor(np.imag(np.fft.ifft2(Low_freq))).unsqueeze(1).float()),
        dim=1)

    return High_output, Low_output


# Create the checkpoint directory, including any missing parents; exist_ok
# makes this a no-op when it already exists and avoids a check-then-create race.
os.makedirs(save_dir, exist_ok=True)

if __name__ == '__main__':
    # Entry point: build a UNetDense (set to training mode below) and load a
    # pretrained DnCNN (set to eval mode below).
    epoch_min = 100  # not read anywhere in the visible part of this block
    # model selection
    print('===> Building model')
    model = UNetDense(input_channels=2, image_channels=1)
    # NOTE(review): this instance is discarded — the next statement rebinds
    # pre_model to the object loaded from model.pth.
    pre_model = DnCNN(image_channels=1)

    # Presumably a whole pickled module rather than a state_dict, since
    # .eval() is called on the loaded object below.
    pre_model = torch.load(os.path.join(args.load_model_dir, 'model.pth'))

    # Resume `model` from save_dir when findLastCheckpoint returns a positive epoch.
    initial_epoch = findLastCheckpoint(
        save_dir=save_dir)  # load the last model in matconvnet style
    # initial_epoch = 150
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        model.load_state_dict(
            torch.load(os.path.join(save_dir,
                                    'model_%03d.pth' % initial_epoch)))
        # model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))

    # Training mode for the UNetDense; inference mode for the pretrained DnCNN.
    model.train()
    pre_model.eval()
        dim=1)
    Low_output = torch.cat(
        (torch.tensor(np.real(np.fft.ifft2(Low_freq))).unsqueeze(1).float(),
         torch.tensor(np.imag(np.fft.ifft2(Low_freq))).unsqueeze(1).float()),
        dim=1)

    return High_output, Low_output


# Create the checkpoint directory, including any missing parents; exist_ok
# makes this a no-op when it already exists and avoids a check-then-create race.
os.makedirs(save_dir, exist_ok=True)

if __name__ == '__main__':
    # model selection
    print('===> Building model')
    decompose_model = DnCNN(image_channels=1)
    # compose_model = ComposeNet(n_block=32)

    initial_epoch = findLastCheckpoint(
        save_dir=save_dir)  # load the last model in matconvnet style
    # NOTE(review): this hard-coded value overrides the checkpoint search just
    # above, so save_dir/model_011.pth is always loaded (torch.load raises if
    # it is missing). Looks like a debugging leftover — confirm.
    initial_epoch = 11
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        decompose_model.load_state_dict(
            torch.load(os.path.join(save_dir,
                                    'model_%03d.pth' % initial_epoch)))
        # model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))

    # criterion = nn.MSELoss(reduction = 'sum')  # PyTorch 0.4.1
    # criterion = sum_squared_error()
    # MSE loss with nn.MSELoss's default (mean) reduction.
    criterion = nn.MSELoss()
def show(x, title=None, cbar=False, figsize=None, cmap='jet'):
    """Display ``x`` as an image in a new matplotlib figure.

    Args:
        x: image data accepted by ``plt.imshow``.
        title: optional figure title; skipped when falsy.
        cbar: if True, draw a colorbar next to the image.
        figsize: optional ``(width, height)`` in inches for ``plt.figure``.
        cmap: colormap name; defaults to the previously hard-coded ``'jet'``,
            so existing callers are unaffected.

    Ends with ``plt.show()``, which typically blocks until the window is
    closed under an interactive backend.
    """
    # Imported lazily, presumably so matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap=cmap)
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


if __name__ == '__main__':

    args = parse_args()

    model1 = DnCNN()
    model2 = DecomposeNet(image_channels=5, n_block=32)
    # NOTE(review): `model3` is a bare, never-assigned name; this line raises
    # NameError at runtime unless it is defined elsewhere. It looks like a
    # truncated statement.
    model3

    # NOTE(review): `model` and `decompose_model` are not assigned in the
    # visible part of this block (only `model1`/`model2` are); presumably
    # they were meant to be those two — confirm before running.
    # model = torch.load(os.path.join(args.model_dir, args.model_name))
    model.load_state_dict(
        torch.load(os.path.join(args.model_dir, args.model_name)))
    decompose_model.load_state_dict(
        torch.load(os.path.join(args.decom_model_dir, args.decom_model_name)))
    # model = torch.load(os.path.join(args.model_dir, args.model_name))
    log('load trained model')

    #    params = model.state_dict()
    #    print(params.values())
    #    print(params.keys())
    #
Beispiel #7
0
                        type=str,
                        default='DnCNN-S',
                        help='DnCNN-S, DnCNN-B, DnCNN-3')
    parser.add_argument('--weights_path', type=str, required=True)
    parser.add_argument('--image_path', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    parser.add_argument('--gaussian_noise_level', type=int)
    parser.add_argument('--jpeg_quality', type=int)
    parser.add_argument('--downsampling_factor', type=int)
    opt = parser.parse_args()

    if not os.path.exists(opt.outputs_dir):
        os.makedirs(opt.outputs_dir)

    if opt.arch == 'DnCNN-S':
        model = DnCNN(num_layers=17)
    elif opt.arch == 'DnCNN-B':
        model = DnCNN(num_layers=20)
    elif opt.arch == 'DnCNN-3':
        model = DnCNN(num_layers=20)

    state_dict = model.state_dict()
    for n, p in torch.load(opt.weights_path,
                           map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model = model.to(device)
    model.eval()
Beispiel #8
0
def show(x, title=None, cbar=False, figsize=None, cmap='jet'):
    """Display ``x`` as an image in a new matplotlib figure.

    Args:
        x: image data accepted by ``plt.imshow``.
        title: optional figure title; skipped when falsy.
        cbar: if True, draw a colorbar next to the image.
        figsize: optional ``(width, height)`` in inches for ``plt.figure``.
        cmap: colormap name; defaults to the previously hard-coded ``'jet'``,
            so existing callers are unaffected.

    Ends with ``plt.show()``, which typically blocks until the window is
    closed under an interactive backend.
    """
    # Imported lazily, presumably so matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap=cmap)
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


if __name__ == '__main__':

    args = parse_args()

    # Both networks are built first; the DnCNN may be replaced wholesale below.
    model_dncnn = DnCNN()
    model_lowfreq = DecomposeNetLowFreq(image_channels=3, n_block=24)
    # model_lowfreq = DecomposeNet(image_channels=5, n_block=32)

    _dncnn_fallback = not os.path.exists(
        os.path.join(args.model_dir, args.model_name))

    if _dncnn_fallback:
        # Named checkpoint missing: unpickle the complete model from model.pth.
        model_dncnn = torch.load(os.path.join(args.model_dir, 'model.pth'))
    else:
        model_dncnn.load_state_dict(
            torch.load(os.path.join(args.model_dir, args.model_name)))

    # The low-frequency network is restored from its state_dict either way.
    model_lowfreq.load_state_dict(
        torch.load(os.path.join(args.low_model_dir, args.low_model_name)))

    if _dncnn_fallback:
        log('load trained model on Train400 dataset by kai')
Beispiel #9
0
        dim=1)
    Low_output = torch.cat(
        (torch.tensor(np.real(np.fft.ifft2(Low_freq))).unsqueeze(1).float(),
         torch.tensor(np.imag(np.fft.ifft2(Low_freq))).unsqueeze(1).float()),
        dim=1)

    return High_output, Low_output


# Create the checkpoint directory, including any missing parents; exist_ok
# makes this a no-op when it already exists and avoids a check-then-create race.
os.makedirs(save_dir, exist_ok=True)

if __name__ == '__main__':
    # model selection
    print('===> Building model')
    # DnCNN (image_channels=2) whose state_dict is restored just below.
    model = DnCNN(image_channels=2)
    # UNet resumed from the newest save_dir checkpoint when one is found.
    u_model = UNet(input_channels=1, image_channels=1)
    model.load_state_dict(
        torch.load(os.path.join(args.load_model_dir, args.load_model_name)))

    # A positive epoch from findLastCheckpoint selects the checkpoint to resume.
    initial_epoch = findLastCheckpoint(
        save_dir=save_dir)  # load the last model in matconvnet style
    # initial_epoch = 150
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        u_model.load_state_dict(
            torch.load(os.path.join(save_dir,
                                    'model_%03d.pth' % initial_epoch)))

        # model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))
Beispiel #10
0
def _load_noisy_test_set(test_data_dir, eta_min, eta_max, dose):
    """Load the validation images and corrupt each one once with ``nm``.

    Every ``*.png`` in ``test_data_dir`` is read as grayscale, scaled from
    [0, 255] to [0, 1] float32 and given a leading channel axis (``(1, H, W)``).
    Noise is drawn once here, so every epoch validates on identical inputs.

    Returns:
        A list of ``(noisy, clean)`` numpy pairs. It stays a list because the
        images do not share one size (some 512x512, the rest 256x256).

    Raises:
        ValueError: if no PNG is found or one cannot be read.
    """
    test_files = glob.glob(test_data_dir + "/*.png")
    if not test_files:
        raise ValueError("No .png test images found in {}".format(test_data_dir))

    pairs = []
    for path in test_files:
        img = cv2.imread(path, 0)  # flag 0: load as single-channel grayscale
        if img is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise ValueError("Could not read test image {}".format(path))
        img = np.array(img, dtype="float32") / 255.0
        img = np.expand_dims(img, axis=0)
        img_noisy, _ = nm(img, eta_min, eta_max, dose, t=100)
        pairs.append((img_noisy, img))
    return pairs


def _perceptual_loss(vgg, loss_mse, outputs, labels):
    """Return the weighted VGG style + content loss of ``outputs`` vs ``labels``.

    The clean ``labels`` act as both style and content target. Both tensors
    are repeated 3x along dim 1 before VGG (assumes single-channel images —
    confirm if img_channel > 1). Style loss sums the MSE between Gram matrices
    of the first four VGG feature maps; content loss is the MSE of the second
    map (h_relu_2_2). Terms are scaled by STYLE_WEIGHT / CONTENT_WEIGHT.
    """
    batch_size = len(outputs)

    # The target branch needs no gradient, so no autograd graph is built for it.
    with torch.no_grad():
        target_features = vgg(labels.repeat(1, 3, 1, 1))
        target_grams = [gram(fmap) for fmap in target_features]

    output_features = vgg(outputs.repeat(1, 3, 1, 1))
    output_grams = [gram(fmap) for fmap in output_features]

    style_loss = 0.0
    for j in range(4):
        style_loss += loss_mse(output_grams[j], target_grams[j][:batch_size])
    style_loss = STYLE_WEIGHT * style_loss

    content_loss = CONTENT_WEIGHT * loss_mse(output_features[1], target_features[1])
    return content_loss + style_loss


def _evaluate(model, xs_test, device):
    """Return mean ``(PSNR, SSIM)`` of ``model`` over ``(noisy, clean)`` pairs.

    Runs under ``torch.no_grad()``; the caller sets eval/train mode.
    """
    psnrs = []
    ssims = []
    with torch.no_grad():
        for noisy, clean in xs_test:
            inputs = torch.from_numpy(np.expand_dims(noisy, axis=0)).to(device)
            outputs = model(inputs).squeeze().cpu().numpy()
            labels = clean.squeeze()
            psnrs.append(peak_signal_noise_ratio(labels, outputs))
            ssims.append(structural_similarity(outputs, labels))
    return sum(psnrs) / len(psnrs), sum(ssims) / len(ssims)


def train_model(config):
    """Train a DnCNN denoiser with a VGG perceptual loss, validating each epoch.

    Hyper-parameters come from the ``[DnCNN]`` section of ``config`` (a
    ConfigParser-style object: indexed as ``config["DnCNN"][key]`` and queried
    with ``config.getboolean``). Each epoch does the following:

    * regenerates the training patches and trains one pass;
    * reports mean PSNR/SSIM on a fixed noisy test set;
    * steps the LR scheduler and saves the model via ``save_model``;
    * appends a line to ``<log_dir>/dose<int(dose*100)>/train_result.txt``;
    * plots that epoch's batch losses.

    Raises:
        ValueError: if the test set is empty/unreadable, or the training data
            yields no full batch.
    """
    # Define hyper-parameters.
    depth = int(config["DnCNN"]["depth"])
    n_channels = int(config["DnCNN"]["n_channels"])
    img_channel = int(config["DnCNN"]["img_channel"])
    kernel_size = int(config["DnCNN"]["kernel_size"])
    use_bnorm = config.getboolean("DnCNN", "use_bnorm")
    epochs = int(config["DnCNN"]["epoch"])
    batch_size = int(config["DnCNN"]["batch_size"])
    train_data_dir = config["DnCNN"]["train_data_dir"]
    test_data_dir = config["DnCNN"]["test_data_dir"]
    eta_min = float(config["DnCNN"]["eta_min"])
    eta_max = float(config["DnCNN"]["eta_max"])
    dose = float(config["DnCNN"]["dose"])
    model_save_dir = config["DnCNN"]["model_save_dir"]

    # Logs go to <log_dir>/dose<percent>/train_result.txt; parents=True also
    # creates the configured base log_dir when it does not exist yet.
    log_dir = Path(config["DnCNN"]["log_dir"]) / "dose{}".format(str(int(dose * 100)))
    log_file = log_dir / "train_result.txt"
    log_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = DnCNN(depth=depth, n_channels=n_channels,
                  img_channel=img_channel,
                  use_bnorm=use_bnorm,
                  kernel_size=kernel_size)
    model = model.to(device)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)

    xs_test = _load_noisy_test_set(test_data_dir, eta_min, eta_max, dose)

    loss_store = []        # every batch loss, across all epochs
    epoch_loss_store = []  # mean batch loss per epoch
    psnr_store = []        # mean validation PSNR per epoch
    ssim_store = []        # mean validation SSIM per epoch

    loss_mse = torch.nn.MSELoss()

    # Follow `device` rather than hard-coding torch.cuda.FloatTensor, which
    # fails on CPU-only machines even though the DnCNN falls back to CPU.
    vgg = Vgg16().to(device=device, dtype=torch.float32)

    # One figure reused every epoch so unclosed figures do not accumulate.
    fig, ax = plt.subplots()

    for epoch in range(epochs):
        # Fresh clean augmented patches each epoch: uint8 -> float32 in [0, 1],
        # then [N, H, W, C] -> [N, C, H, W].
        xs = dg.datagenerator(data_dir=train_data_dir)
        xs = xs.astype("float32") / 255.0
        xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))

        train_set = dg.DenoisingDatatset(xs, eta_min, eta_max, dose)
        # TODO: with drop_last=True, the dropping in data_generator may be unnecessary.
        train_loader = DataLoader(dataset=train_set, num_workers=4,
                                  drop_last=True, batch_size=batch_size,
                                  shuffle=True)
        if len(train_loader) == 0:
            raise ValueError("Not enough training patches for one batch of "
                             "size {} (drop_last=True)".format(batch_size))

        t_start = timer()
        epoch_loss = 0
        for idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = _perceptual_loss(vgg, loss_mse, outputs, labels)

            batch_loss = loss.item()
            loss_store.append(batch_loss)
            epoch_loss += batch_loss

            loss.backward()
            optimizer.step()

            if idx % 100 == 0:
                print("Epoch [{} / {}], step [{} / {}], loss = {:.5f}, lr = {:.6f}, elapsed time = {:.2f}s".format(
                    epoch + 1, epochs, idx, len(train_loader), batch_loss, *scheduler.get_last_lr(), timer()-t_start))

        epoch_loss_store.append(epoch_loss / len(train_loader))

        # Validate on the fixed noisy test set in eval mode, then switch back.
        model.eval()
        mean_psnr, mean_ssim = _evaluate(model, xs_test, device)
        psnr_store.append(mean_psnr)
        ssim_store.append(mean_ssim)
        print("Validation on test set: epoch [{} / {}], aver PSNR = {:.2f}, aver SSIM = {:.4f}".format(
            epoch + 1, epochs, psnr_store[-1], ssim_store[-1]))
        model.train()

        scheduler.step()

        save_model(model, model_save_dir, epoch, dose * 100)

        with open(log_file, "a+") as fh:
            fh.write("{} Epoch [{} / {}], loss = {:.6f}, "
                     "validation PSNR = {:.2f}dB, validation SSIM = {:.4f}\n".format(
                     datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),
                     epoch + 1, epochs, epoch_loss_store[-1],
                     psnr_store[-1], ssim_store[-1]))

        # Plot this epoch's batch losses (the last len(train_loader) entries).
        ax.clear()
        ax.plot(loss_store[-len(train_loader):])
        ax.set_title("Last {} losses".format(len(train_loader)))
        ax.set_xlabel("iteration")
        fig.show()
Beispiel #11
0
def show(x, title=None, cbar=False, figsize=None, cmap='jet'):
    """Display ``x`` as an image in a new matplotlib figure.

    Args:
        x: image data accepted by ``plt.imshow``.
        title: optional figure title; skipped when falsy.
        cbar: if True, draw a colorbar next to the image.
        figsize: optional ``(width, height)`` in inches for ``plt.figure``.
        cmap: colormap name; defaults to the previously hard-coded ``'jet'``,
            so existing callers are unaffected.

    Ends with ``plt.show()``, which typically blocks until the window is
    closed under an interactive backend.
    """
    # Imported lazily, presumably so matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap=cmap)
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


if __name__ == '__main__':

    args = parse_args()

    # Placeholder; both branches below rebind `model` to an unpickled network.
    model = DnCNN()

    if os.path.exists(os.path.join(args.model_dir, args.model_name)):
        # model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name)))
        model = torch.load(os.path.join(args.model_dir, args.model_name))
        log('load trained model')
    else:
        # Named checkpoint missing: fall back to model.pth in the same directory.
        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        log('load trained model on Train400 dataset by kai')
#    params = model.state_dict()
#    print(params.values())
#    print(params.keys())
#
#    for key, value in params.items():
def show(x, title=None, cbar=False, figsize=None, cmap='jet'):
    """Display ``x`` as an image in a new matplotlib figure.

    Args:
        x: image data accepted by ``plt.imshow``.
        title: optional figure title; skipped when falsy.
        cbar: if True, draw a colorbar next to the image.
        figsize: optional ``(width, height)`` in inches for ``plt.figure``.
        cmap: colormap name; defaults to the previously hard-coded ``'jet'``,
            so existing callers are unaffected.

    Ends with ``plt.show()``, which typically blocks until the window is
    closed under an interactive backend.
    """
    # Imported lazily, presumably so matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap=cmap)
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


if __name__ == '__main__':

    args = parse_args()

    pre_model = DnCNN()
    unet_model = UNet(input_channels=2, image_channels=1)

    # The DnCNN is replaced by the object unpickled from model.pth; the UNet
    # is restored in place from a state_dict.
    pre_model = torch.load(os.path.join(args.two_model_dir, 'model.pth'))
    _unet_state = torch.load(
        os.path.join(args.three_model_dir, args.three_model_name))
    unet_model.load_state_dict(_unet_state)

    log('load trained model')

    #    params = model.state_dict()
    #    print(params.values())
    #    print(params.keys())
    #
    #    for key, value in params.items():
    #        print(key)    # parameter name
Beispiel #13
0
def show(x, title=None, cbar=False, figsize=None, cmap='jet'):
    """Display ``x`` as an image in a new matplotlib figure.

    Args:
        x: image data accepted by ``plt.imshow``.
        title: optional figure title; skipped when falsy.
        cbar: if True, draw a colorbar next to the image.
        figsize: optional ``(width, height)`` in inches for ``plt.figure``.
        cmap: colormap name; defaults to the previously hard-coded ``'jet'``,
            so existing callers are unaffected.

    Ends with ``plt.show()``, which typically blocks until the window is
    closed under an interactive backend.
    """
    # Imported lazily, presumably so matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap=cmap)
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


if __name__ == '__main__':

    args = parse_args()

    high_model = DnCNN()
    low_model = DnCNN()

    _high_ckpt = os.path.join(args.high_model_dir, args.high_model_name)
    if os.path.exists(_high_ckpt):
        # Both named checkpoints are restored as state_dicts.
        high_model.load_state_dict(torch.load(_high_ckpt))
        low_model.load_state_dict(
            torch.load(os.path.join(args.low_model_dir, args.low_model_name)))
        log('load trained model')
    else:
        # Fallback: unpickle a complete network from model.pth.
        # NOTE(review): low_model keeps its fresh initialisation on this path —
        # confirm that is intended.
        high_model = torch.load(os.path.join(args.high_model_dir, 'model.pth'))
        log('load trained model on Train400 dataset by kai')
def show(x, title=None, cbar=False, figsize=None, cmap='jet'):
    """Display ``x`` as an image in a new matplotlib figure.

    Args:
        x: image data accepted by ``plt.imshow``.
        title: optional figure title; skipped when falsy.
        cbar: if True, draw a colorbar next to the image.
        figsize: optional ``(width, height)`` in inches for ``plt.figure``.
        cmap: colormap name; defaults to the previously hard-coded ``'jet'``,
            so existing callers are unaffected.

    Ends with ``plt.show()``, which typically blocks until the window is
    closed under an interactive backend.
    """
    # Imported lazily, presumably so matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap=cmap)
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


if __name__ == '__main__':

    args = parse_args()

    low_model = DnCNN()
    res_model = DnCNN()
    model = DnCNN()

    if os.path.exists(os.path.join(args.low_model_dir, args.low_model_name)):
        # Named checkpoints present: restore state_dicts in place.
        # NOTE(review): `model` is not restored on this path in the code shown.
        low_model.load_state_dict(
            torch.load(os.path.join(args.low_model_dir, args.low_model_name)))
        res_model.load_state_dict(
            torch.load(os.path.join(args.res_model_dir, args.res_model_name)))
    else:
        # Fallback: unpickle each complete network from its directory's model.pth.
        low_model = torch.load(os.path.join(args.low_model_dir, 'model.pth'))
        res_model = torch.load(os.path.join(args.res_model_dir, 'model.pth'))
        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        log('load trained model on Train400 dataset by kai')
Beispiel #15
0
                        help="evaluation interval")
    parser.add_argument("--save_interval",
                        default=1,
                        type=int,
                        help="number of epochs")

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = DnCNN().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()
    training_data_loader, testing_data_loader = dataloader(
        args.train_dir, args.test_dir, args.crop_size, args.batch_size)

    # Noise mean / std forwarded unchanged to train() and validate().
    mean, stddev = args.noise_mean, args.noise_std
    num_epochs = args.num_epochs
    for epoch in range(1, num_epochs + 1):
        train(epoch, model, optimizer, training_data_loader, mean, stddev,
              criterion)

        # Epochs are 1-based: validate after every eval_interval-th epoch.
        if not epoch % args.eval_interval:
            validate(model, testing_data_loader, mean, stddev, criterion)
Beispiel #16
0
def _build_model(params):
    """Return the network described by the saved training ``params``.

    Supports ``params['model_name']`` of ``'DnCNN'`` and ``'DnCNN-res'``;
    returns ``None`` for any other name so the caller can report it.
    """
    if params['model_name'] == 'DnCNN':
        model_cls = DnCNN
    elif params['model_name'] == 'DnCNN-res':
        model_cls = DnCNN_Res
    else:
        return None
    return model_cls(num_channels=params['num_channels'],
                     num_layers=params['num_layers'],
                     kernel_size=params['kernel_size'],
                     stride=params['stride'],
                     num_filters=params['num_filters'])


def main():
    """Denoise gamma spectra with a trained model and report PSNR gains.

    Loads spectra from ``--test_set`` and the training parameters pickled
    next to ``--model``. It denoises either the recreated validation split or
    the whole file (``--all``), printing PSNR before and after per batch.
    With ``--all`` the denoised spectra are written to ``--outdir/--outfile``.

    Returns:
        0 on success, 1 if the saved model name is not supported.

    Raises:
        FileNotFoundError: if ``--test_set`` does not exist.
    """
    start = time.time()

    parser = argparse.ArgumentParser(description='Gamma-Spectra Denoising Trainer')
    parser.add_argument('--dettype', type=str, default='HPGe', help='detector type to train {HPGe, NaI, CZT}')
    parser.add_argument('--test_set', type=str, default='data/training.h5', help='h5 file with training vectors')
    parser.add_argument('--all', default=False, help='denoise all examples in test_set file', action='store_true')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size for denoising')
    parser.add_argument('--seed', type=int, help='random seed')
    parser.add_argument('--model', type=str, default='models/best_model.pt', help='location of model to use')
    parser.add_argument('--outdir', type=str, help='location to save output plots')
    parser.add_argument('--outfile', type=str, help='location to save output data', default='denoised.h5')
    parser.add_argument('--savefigs', help='saves plots of results', default=False, action='store_true')
    args = parser.parse_args()

    # if output directory is not provided, save plots to model directory
    if not args.outdir:
        args.outdir = os.path.dirname(args.model)
    else:
        # make sure output dirs exists
        os.makedirs(args.outdir, exist_ok=True)

    # Explicit check rather than assert: asserts are stripped under `python -O`.
    if not os.path.exists(args.test_set):
        raise FileNotFoundError(f'Cannot find testset vectors file {args.test_set}')

    # detect gpus and setup environment variables
    device_ids = setup_gpus()
    print(f'Cuda devices found: {[torch.cuda.get_device_name(i) for i in device_ids]}')

    print('Loading datasets')
    test_data = load_data(args.test_set, args.dettype.upper())
    spectra_keV = test_data['keV']

    # Insert a channel axis at position 1; spectra are indexed [sample, channel, :] below.
    noisy_spectra = np.expand_dims(test_data['noisy_spectrum'], axis=1)
    clean_spectra = np.expand_dims(test_data['spectrum'], axis=1)

    assert noisy_spectra.shape == clean_spectra.shape, 'Mismatch between shapes of training and target data'

    # Training parameters are pickled next to the weights with a .npy suffix.
    # NOTE: pickle.load can execute arbitrary code — only load trusted files.
    with open(args.model.replace('.pt', '.npy'), 'rb') as params_file:
        params = pickle.load(params_file)['model']

    train_mean = params['train_mean']
    train_std = params['train_std']

    # `is None` rather than falsiness, so an explicit `--seed 0` is honoured.
    if args.seed is None:
        args.seed = params['train_seed']

    # applying random seed for reproducability
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # create dataset for denoising, if not 'all' use training seed to recreate validation set
    if not args.all:
        _, x_val, _, y_val = train_test_split(noisy_spectra, clean_spectra, test_size=0.1, random_state=args.seed)
        val_dataset = TensorDataset(torch.Tensor(x_val), torch.Tensor(y_val))
    else:
        val_dataset = TensorDataset(torch.Tensor(noisy_spectra), torch.Tensor(clean_spectra))

    print(f'Number of examples to denoise: {len(val_dataset)}')

    val_loader = DataLoader(dataset=val_dataset, num_workers=os.cpu_count(), batch_size=args.batch_size, shuffle=False)
    print(f'Number of batches {len(val_loader)}')

    model = _build_model(params)
    if model is None:
        print(f'Model name {params["model_name"]} is not supported.')
        return 1

    # prepare model for data parallelism (use multiple GPUs)
    model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()

    print(f'Loading weights for {params["model_name"]} model from {args.model} for {params["model_type"]}')
    model.load_state_dict(torch.load(args.model))

    print('Denoising spectra')
    model.eval()
    total_psnr_noisy = 0
    total_psnr_denoised = 0

    denoised = []
    with torch.no_grad():
        for num, (noisy_batch, clean_batch) in enumerate(val_loader, start=1):
            # Normalise with the training statistics and predict on the GPU.
            preds = model((noisy_batch.cuda() - train_mean) / train_std)

            # The batches are still on the CPU; only the predictions come back.
            clean_np = clean_batch.numpy().astype(np.float32)
            noisy_np = noisy_batch.numpy().astype(np.float32)
            preds = preds.cpu().numpy().astype(np.float32)
            psnr_noisy = psnr_of_batch(clean_np, noisy_np)

            if params['model_type'] == 'Gen-spectrum':
                denoised_spectrum = preds
            else:
                # Otherwise the prediction is treated as noise and subtracted.
                denoised_spectrum = noisy_np - preds

            denoised.extend(denoised_spectrum.tolist())

            psnr_denoised = psnr_of_batch(clean_np, denoised_spectrum)
            total_psnr_noisy += psnr_noisy
            total_psnr_denoised += psnr_denoised
            print(f'[{num}/{len(val_loader)}] PSNR {psnr_noisy} --> {psnr_denoised}, increase of {psnr_denoised-psnr_noisy}')
            if args.savefigs:
                compare_results(spectra_keV, clean_np[0,0,:], noisy_np[0,0,:], preds[0,0,:], args.outdir, str(num))

    # save denoised data to file, currently only supports entire dataset
    if args.all:
        assert len(test_data['noisy_spectrum']) == len(denoised), f'{len(test_data["noisy_spectrum"])} examples yet {len(denoised)} denoised'
        denoised = np.squeeze(np.array(denoised))
        test_data['noisy_spectrum'] = denoised
        outfile = os.path.join(args.outdir, args.outfile)
        print(f'Saving denoised spectrum to {outfile}')
        save_dataset(args.dettype.upper(), test_data, outfile)

    avg_psnr_noisy = total_psnr_noisy/len(val_loader)
    avg_psnr_denoised = total_psnr_denoised/len(val_loader)

    print(f'Average PSNR: {avg_psnr_denoised}, average increase of {avg_psnr_denoised-avg_psnr_noisy}')

    print(f'Script completed in {time.time()-start:.2f} secs')
    return 0
Beispiel #17
0
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap='jet')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()



if __name__ == '__main__':

    args = parse_args()

    model = DnCNN()
    decompose_model = DecomposeNet(image_channels=5, n_block=32)

    _decom_ckpt = os.path.join(args.decom_model_dir, args.decom_model_name)
    if os.path.exists(_decom_ckpt):
        # Both networks restored in place from state_dicts.
        model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name)))
        decompose_model.load_state_dict(torch.load(_decom_ckpt))
        log('load trained model')
    else:
        # NOTE(review): the fallback unpickles model_dir/model.pth into
        # decompose_model and leaves `model` freshly initialised — confirm.
        decompose_model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        log('load trained model on Train400 dataset by kai')

#    params = model.state_dict()
#    print(params.values())
n_epoch = args.epoch
sigma = args.sigma
save_dir = args.save_dir

toTensor = transforms.ToTensor()
toPILImage = transforms.ToPILImage()

# save_dir = os.path.join('models', args.model+'_' + 'sigma' + str(sigma))

# Create the checkpoint directory, including any missing parents (save_dir may
# be nested, as in the commented example above); exist_ok makes this a no-op
# when it already exists.
os.makedirs(save_dir, exist_ok=True)

if __name__ == '__main__':
    # model selection
    print('===> Building model')
    model = DnCNN()
    low_model = DnCNN()
    low_model.load_state_dict(torch.load(os.path.join(args.low_model_dir, args.low_model_name)))

    initial_epoch = findLastCheckpoint(save_dir=save_dir)  # load the last model in matconvnet style
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        # model.load_state_dict(torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)))
        # model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))
    model.train()
    low_model.eval()
    # criterion = nn.MSELoss(reduction = 'sum')  # PyTorch 0.4.1
    # criterion = sum_squared_error()
    criterion = nn.MSELoss()
    Edge_enhance = torch.FloatTensor(args.batch_size, 1, 40, 40)
    if cuda:
Beispiel #19
0
    parser.add_argument("--crop_size", default=224, type=int,
                        help="size to resize image to")
    parser.add_argument("--model_path",
                        default='trained_model/model_epoch_10.pth',
                        help="path to saved model")

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The checkpoint presumably stores the complete network object under
    # 'arch', so no placeholder DnCNN needs to be built and moved to `device`.
    # map_location places the loaded tensors on `device`.
    checkpoint = torch.load(args.model_path, map_location=device)
    model = checkpoint['arch']
    # NOTE(review): the model is not switched to eval mode here — confirm that
    # predict() does so if the network contains BatchNorm/Dropout layers.

    loader = transforms.Compose([
        transforms.Resize((args.crop_size, args.crop_size)),
        transforms.ToTensor()
    ])

    img_path = args.image_path

    # pass the image into the image_loader function
    image = image_loader(img_path, loader, device)

    # get prediction
    predict(model, image, args.save_image_path)
def show(x, title=None, cbar=False, figsize=None):
    """Render ``x`` as a 'jet'-coloured image in a new figure and display it.

    ``title`` is drawn only when truthy, a colour bar is added when ``cbar``
    is truthy, and ``figsize`` is forwarded unchanged to ``figure``.
    matplotlib is imported inside the function, so it is only required when
    this helper is actually called.
    """
    from matplotlib import pyplot

    pyplot.figure(figsize=figsize)
    pyplot.imshow(x, cmap='jet', interpolation='nearest')
    if title:
        pyplot.title(title)
    if cbar:
        pyplot.colorbar()
    pyplot.show()

if __name__ == '__main__':
    args = parse_args()

    # Five-channel DnCNN; its parameters are filled in (or it is replaced) below.
    decompose_model = DnCNN(image_channels=5)

    custom_weights = os.path.join(args.model_dir, args.model_name)
    if os.path.exists(custom_weights):
        # Restore saved parameters into the network built above.
        decompose_model.load_state_dict(torch.load(custom_weights))
        log('load trained model')
    else:
        # Fall back to model.pth in the same directory, loaded as a complete
        # object that replaces the freshly built network.
        # NOTE(review): nothing checks that this object has 5 input channels
        # like the network it replaces — confirm model.pth matches.
        decompose_model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        log('load trained model on Train400 dataset by kai')

#    params = model.state_dict()
#    print(params.values())
#    print(params.keys())
#
def show(x, title=None, cbar=False, figsize=None):
    """Render ``x`` as a 'jet'-coloured image in a new figure and display it.

    ``title`` is drawn only when truthy, a colour bar is added when ``cbar``
    is truthy, and ``figsize`` is forwarded unchanged to ``figure``.
    matplotlib is imported inside the function, so it is only required when
    this helper is actually called.
    """
    from matplotlib import pyplot

    pyplot.figure(figsize=figsize)
    pyplot.imshow(x, cmap='jet', interpolation='nearest')
    if title:
        pyplot.title(title)
    if cbar:
        pyplot.colorbar()
    pyplot.show()

if __name__ == '__main__':
    args = parse_args()

    # Denoiser and decomposition network; which of them receives trained
    # parameters depends on whether the decomposition checkpoint exists.
    model = DnCNN()
    decompose_model = DecomposeNet(image_channels=5, n_block=32)

    decom_weights = os.path.join(args.decom_model_dir, args.decom_model_name)
    if os.path.exists(decom_weights):
        # Denoiser: complete object returned by torch.load.
        # Decomposition net: parameters restored into the instance built above.
        model = torch.load(os.path.join(args.model_dir, args.model_name))
        decompose_model.load_state_dict(torch.load(decom_weights))
        log('load trained model')
    else:
        # NOTE(review): this fallback stores model.pth in decompose_model and
        # leaves `model` as the freshly built DnCNN(). The log message suggests
        # model.pth is the pretrained denoiser, so `model` may have been the
        # intended target — confirm.
        decompose_model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        log('load trained model on Train400 dataset by kai')
Beispiel #22
0
    plt.imshow(x, interpolation='nearest', cmap='jet')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()



if __name__ == '__main__':
    args = parse_args()

    # Default-configured DnCNN; its parameters are filled in (or it is replaced) below.
    model = DnCNN()

    state_path = os.path.join(args.model_dir, args.model_name)
    if os.path.exists(state_path):
        # Restore saved parameters into the network built above.
        model.load_state_dict(torch.load(state_path))
        log('load trained model')
    else:
        # Fall back to model.pth, loaded as a complete object that replaces
        # the freshly built network.
        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        log('load trained model on Train400 dataset by kai')

#    params = model.state_dict()
#    print(params.values())
#    print(params.keys())
        dim=1)
    Low_output = torch.cat(
        (torch.tensor(np.real(np.fft.ifft2(Low_freq))).unsqueeze(1).float(),
         torch.tensor(np.imag(np.fft.ifft2(Low_freq))).unsqueeze(1).float()),
        dim=1)

    return High_output, Low_output


# Create the checkpoint directory, including any missing parents; exist_ok
# makes this a no-op when it already exists and avoids a check-then-create race.
os.makedirs(save_dir, exist_ok=True)

if __name__ == '__main__':
    # model selection
    print('===> Building model')
    model = DnCNN(image_channels=2)

    # Epoch of the newest checkpoint in save_dir (presumably 0 when none exists).
    initial_epoch = findLastCheckpoint(
        save_dir=save_dir)  # load the last model in matconvnet style
    # initial_epoch = 150
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        # Restore the weights announced above. Without this the run would
        # report a resume but continue from freshly initialised weights.
        # The checkpoint may be either a whole pickled module or a state
        # dict, so both are accepted. It is loaded onto the CPU first; the
        # `cuda` branch below moves it to the GPU when requested.
        checkpoint = torch.load(
            os.path.join(save_dir, 'model_%03d.pth' % initial_epoch),
            map_location='cpu')
        if isinstance(checkpoint, nn.Module):
            model = checkpoint
        else:
            model.load_state_dict(checkpoint)
    model.train()
    criterion = nn.MSELoss()

    if cuda:
        model = model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
def show(x, title=None, cbar=False, figsize=None):
    """Render ``x`` as a 'jet'-coloured image in a new figure and display it.

    ``title`` is drawn only when truthy, a colour bar is added when ``cbar``
    is truthy, and ``figsize`` is forwarded unchanged to ``figure``.
    matplotlib is imported inside the function, so it is only required when
    this helper is actually called.
    """
    from matplotlib import pyplot

    pyplot.figure(figsize=figsize)
    pyplot.imshow(x, cmap='jet', interpolation='nearest')
    if title:
        pyplot.title(title)
    if cbar:
        pyplot.colorbar()
    pyplot.show()

if __name__ == '__main__':
    args = parse_args()

    # Two-channel DnCNN; its parameters are filled in (or it is replaced) below.
    model = DnCNN(image_channels=2)

    state_path = os.path.join(args.model_dir, args.model_name)
    if os.path.exists(state_path):
        # Restore saved parameters into the network built above.
        model.load_state_dict(torch.load(state_path))
        log('load trained model')
    else:
        # Fall back to model.pth, loaded as a complete object that replaces
        # the freshly built network.
        # NOTE(review): nothing checks that this object expects 2 input
        # channels like the network it replaces — confirm.
        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        log('load trained model on Train400 dataset by kai')
#    params = model.state_dict()
#    print(params.values())
#    print(params.keys())
#
# Values taken from the parsed command line (sigma is presumably the noise
# level, judging by the commented-out directory naming below).
sigma = args.sigma
save_dir = args.save_dir

# Shared converters between PIL images and tensors.
toTensor = transforms.ToTensor()
toPILImage = transforms.ToPILImage()

# save_dir = os.path.join('models', args.model+'_' + 'sigma' + str(sigma))

# makedirs also creates missing parent directories (plain os.mkdir raises for
# a nested --save_dir), and exist_ok removes the check-then-create race.
os.makedirs(save_dir, exist_ok=True)

if __name__ == '__main__':
    # model selection
    print('===> Building model')
    model = Model()
    # Auxiliary DnCNN whose parameters come from low_model_dir/low_model_name.
    low_model = DnCNN()
    low_model.load_state_dict(torch.load(os.path.join(args.low_model_dir, args.low_model_name)))

    initial_epoch = findLastCheckpoint(save_dir=save_dir)  # load the last model in matconvnet style
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        # Restore the weights announced above. Without this the run would
        # report a resume but continue from freshly initialised weights.
        # The checkpoint may be either a whole pickled module or a state
        # dict, so both are accepted. It is loaded onto the CPU first; the
        # `cuda` branch below moves it to the GPU when requested.
        checkpoint = torch.load(
            os.path.join(save_dir, 'model_%03d.pth' % initial_epoch),
            map_location='cpu')
        if isinstance(checkpoint, nn.Module):
            model = checkpoint
        else:
            model.load_state_dict(checkpoint)
    model.train()
    # The auxiliary network is switched to evaluation mode.
    low_model.eval()
    # criterion = nn.MSELoss(reduction = 'sum')  # PyTorch 0.4.1
    criterion = nn.MSELoss()
    # Uninitialised storage of shape (batch_size, 1, 40, 40); presumably
    # filled before use later in the script — TODO confirm.
    Edge_enhance = torch.FloatTensor(args.batch_size, 1, 40, 40)
    if cuda:
        model = model.cuda()
Beispiel #26
0
def main(_):
    """Prepare output directories, then train or test a DnCNN model.

    The positional argument is ignored (presumably this is invoked via
    ``tf.app.run``); all configuration is read from the module-level ``args``.
    ``args.phase == 'train'`` trains, and any other value runs ``test()``.
    """
    # Create every output directory the model may write to.
    for directory in (args.ckpt_dir, args.sample_dir, args.test_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Only the session configuration differs between GPU and CPU runs. Model
    # construction and the train/test dispatch are shared below, so the two
    # paths cannot drift apart.
    if args.use_gpu:
        # added to control the gpu memory: cap this process at 90% of it
        print("GPU\n")
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    else:
        print("CPU\n")
        session = tf.Session()

    with session as sess:
        model = DnCNN(sess,
                      sigma=args.sigma,
                      lr=args.lr,
                      epoch=args.epoch,
                      batch_size=args.batch_size)
        if args.phase == 'train':
            model.train()
        else:
            model.test()
def show(x, title=None, cbar=False, figsize=None):
    """Render ``x`` as a 'jet'-coloured image in a new figure and display it.

    ``title`` is drawn only when truthy, a colour bar is added when ``cbar``
    is truthy, and ``figsize`` is forwarded unchanged to ``figure``.
    matplotlib is imported inside the function, so it is only required when
    this helper is actually called.
    """
    from matplotlib import pyplot

    pyplot.figure(figsize=figsize)
    pyplot.imshow(x, cmap='jet', interpolation='nearest')
    if title:
        pyplot.title(title)
    if cbar:
        pyplot.colorbar()
    pyplot.show()

if __name__ == '__main__':
    args = parse_args()

    # Denoiser and decomposition network; which of them receives trained
    # parameters depends on whether the decomposition checkpoint exists.
    model = DnCNN()
    decompose_model = DecomposeNet(image_channels=5, n_block=32)

    decom_weights = os.path.join(args.decom_model_dir, args.decom_model_name)
    if os.path.exists(decom_weights):
        # Denoiser: complete object loaded from model.pth.
        # Decomposition net: parameters restored into the instance built above.
        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        decompose_model.load_state_dict(torch.load(decom_weights))
        log('load trained model')
    else:
        # NOTE(review): this fallback stores model.pth in decompose_model and
        # leaves `model` as the freshly built DnCNN(), whereas the branch above
        # loads the same file into `model`. Confirm which variable is
        # intended here.
        decompose_model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        log('load trained model on Train400 dataset by kai')