Code example #1
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import astropy.visualization
from PIL import Image

import common  # project helper providing get_config()


def blurImage(img):
    # img: single-channel PIL image (mode 'F')
    config = common.get_config()
    KERNEL_SIZE = config.getint('DEFAULT', 'blur_size')

    if config['DEFAULT']['blur'] == 'box':
        kernel = (1/(KERNEL_SIZE**2))*np.ones((1,1,KERNEL_SIZE,KERNEL_SIZE))
    elif config['DEFAULT']['blur'] == 'gauss':
        kernel = np.expand_dims(np.expand_dims(gkern(KERNEL_SIZE), axis=0), axis=0)
    elif config['DEFAULT']['blur'] == 'none':
        Interval = astropy.visualization.MinMaxInterval()
        return Image.fromarray(Interval(np.array(img)), mode='F')
    else:
        raise ValueError("Unrecognized blur: {}".format(config['DEFAULT']['blur']))

    pad = (KERNEL_SIZE - 1) // 2
    m = nn.Conv2d(1, 1, KERNEL_SIZE, stride=1, padding=(pad, pad),
                  padding_mode='reflect', bias=False)  # a bias would offset the blurred values
    kernel = torch.Tensor(kernel)
    m.weight = torch.nn.Parameter(kernel)  # wrap as a Parameter so it can serve as the conv weight

    with torch.no_grad():
        output = m(TF.to_tensor(img).unsqueeze(0))
    Interval = astropy.visualization.MinMaxInterval()
    return Image.fromarray(Interval(output.squeeze(0).numpy()[0]), mode='F')
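gkern is called above but not defined in this section. A minimal sketch of the assumed helper, returning a normalized 2-D Gaussian kernel (the nsig cutoff is an illustrative default):

import numpy as np
import scipy.stats as st

def gkern(kernlen, nsig=1.0):
    """Return a (kernlen, kernlen) 2-D Gaussian kernel that sums to 1."""
    x = np.linspace(-nsig, nsig, kernlen + 1)
    kern1d = np.diff(st.norm.cdf(x))
    kern2d = np.outer(kern1d, kern1d)
    return kern2d / kern2d.sum()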
Code example #2
def downsample(HR_torch):
    """Downsample an image using average pooling.
        HR_torch: tensor image with shape (1,H,W), in [0..1].
    """
    config = common.get_config()
    factor = config.getint('DEFAULT', 'factor')

    m = nn.AvgPool2d(factor)
    LR_torch = m(HR_torch)
    return LR_torch
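A quick usage sketch, assuming the config sets factor = 4 (the tensor shape is hypothetical):

import torch

HR = torch.rand(1, 256, 256)  # (1,H,W) image in [0..1]
LR = downsample(HR)
print(LR.size())  # torch.Size([1, 64, 64])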
Code example #3
def make_baseline_figure(HR, bicubic_HR, LR, name):
    # Make the baseline comparison figure:
    # HR, bicubic-upsampled HR, and LR (nearest-neighbor upsampled), side by side.
    config = common.get_config()
    m = nn.Upsample(scale_factor=config.getint("DEFAULT", "factor"), mode='nearest')
    up_LR = m(LR.unsqueeze(0)).squeeze(0)

    ncols = 3
    nrows = 1
    fig, axes = plt.subplots(nrows, ncols, figsize=(14, 5), dpi=250)
    fig.suptitle('{}'.format(name), fontsize=16)

    n_channels = config.getint("DEFAULT", "n_channels")

    if n_channels == 1:
        HR = HR.permute(1,2,0)[:,:,0]
        bicubic_HR = bicubic_HR.permute(1,2,0)[:,:,0]
        up_LR = up_LR.permute(1,2,0)[:,:,0]
        HR_size = HR.size()
        bicubic_HR_size = bicubic_HR.size()
        up_LR_size = up_LR.size()
    else:
        HR = HR.permute(1,2,0)
        bicubic_HR = bicubic_HR.permute(1,2,0)
        up_LR = up_LR.permute(1,2,0)
        HR_size = (HR.size()[0], HR.size()[1])
        bicubic_HR_size = (bicubic_HR.size()[0], bicubic_HR.size()[1])
        up_LR_size = (up_LR.size()[0], up_LR.size()[1])


    axes[0].imshow(torch.clamp(torch.flip(HR, dims=(0,)), 0, 1), cmap='gray')
    axes[0].set_title('HR')
    axes[0].set_xlabel('({}x{}x{})'.format(HR_size[0], HR_size[1], n_channels))
    
    axes[1].imshow(torch.clamp(torch.flip(bicubic_HR, dims=(0,)), 0, 1), cmap='gray')
    axes[1].set_title('HR_bicubic')
    axes[1].set_xlabel('({}x{}x{})'.format(bicubic_HR_size[0], bicubic_HR_size[1], n_channels))
    
    axes[2].imshow(torch.clamp(torch.flip(up_LR, dims=(0,)), 0, 1), cmap='gray')
    axes[2].set_title('LR (NN-Upsampled)')
    axes[2].set_xlabel('({}x{}x{})'.format(up_LR_size[0], up_LR_size[1], n_channels))

    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig("output/{}.png".format(name))
    plt.close()
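A hypothetical call, using the dictionary entries produced by load_LR_HR_imgs_sr (code example #5); the figure is written to output/frame_0.png:

img = state.imgs[0]  # assuming images have already been loaded
make_baseline_figure(img['HR_torch'], img['HR_torch_bicubic'],
                     img['LR_torch'], 'frame_0')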
Code example #4
def preload_LR_HR(LR_path, HR_path):
    """Load pre-created LR/HR image pairs from disk into state.imgs."""
    config = common.get_config()

    augmented_history = config.has_option('LOADING', 'augmented_history') and \
                        config.getboolean('LOADING', 'augmented_history')

    if augmented_history:
        blurred_path = config['LOADING']['path_to_blurred']
        downsampled_path = config['LOADING']['path_to_downsampled']
        HRs_blurred = sorted(glob.glob(blurred_path+"*"))
        LRs_downsampled = sorted(glob.glob(downsampled_path+"*"))

    LRs = sorted(glob.glob(LR_path+"*"))
    HRs = sorted(glob.glob(HR_path+"*"))

    for i in range(len(LRs)):
        orig_torch = common.get_image(HRs[i])
        HR_torch = orig_torch.clone()
        LR_torch = common.get_image(LRs[i])
        
        input_depth = config.getint('DEFAULT', 'input_depth')
        imsize_x = config.getint('DEFAULT', 'imsize_x')
        imsize_y = config.getint('DEFAULT', 'imsize_y')
        net_input = common.get_noise(input_depth, 'noise', (imsize_y, imsize_x)).type(state.dtype)

        m = nn.Upsample(scale_factor=config.getint("DEFAULT", "factor"), mode='bicubic', align_corners=False)
        HR_torch_bicubic = m(LR_torch.unsqueeze(0)).squeeze(0)
        
        out = {
            'orig_torch': orig_torch,
            #'orig_pil_blurred': orig_pil_blurred,
            'HR_torch': HR_torch,
            #'HR_pil_blurred': HR_pil_blurred,
            'LR_torch': LR_torch,
            #'LR_pil_blurred': LR_pil_blurred,
            'net_input': net_input,
            'HR_torch_bicubic': HR_torch_bicubic,
            #'HR_bicubic_blurred': HR_bicubic_blurred
            'history_low': state.HistoryTracker(augmented_history=augmented_history),
            'history_high': state.HistoryTracker(augmented_history=augmented_history),
        }
        if augmented_history:
            HR_torch_blurred = common.get_image(HRs_blurred[i])
            LR_torch_downsampled = common.get_image(LRs_downsampled[i])
            out['HR_torch_blurred'] = HR_torch_blurred
            out['LR_torch_downsampled'] = LR_torch_downsampled
        
        state.imgs.append(out)
Code example #5
def load_LR_HR_imgs_sr(fname):
    '''Loads an image, crops it to the configured window, and downscales it.

    Returns a dict with the original, HR, LR, and bicubic-upsampled HR
    tensors, plus the network input noise and history trackers.
    '''
    config = common.get_config()

    # Load fits file to [0,1] normalized torch tensor with shape (1,H,W)
    # Load png file to [0,1] normalized torch tensor with shape (3,H,W)
    orig_torch = common.get_image(fname)
    #orig_pil_blurred = blurImage(orig_pil)

    HR_torch = crop(orig_torch)
    #HR_pil_blurred = crop(orig_pil_blurred)

    # Create low resolution
    LR_torch = downsample(HR_torch)
    #LR_pil_blurred = downsample(HR_pil_blurred)

    print('HR and LR resolutions: %s, %s' % (str(HR_torch.size()), str(LR_torch.size())))

    input_depth = config.getint('DEFAULT', 'input_depth')
    imsize_x = config.getint('DEFAULT', 'imsize_x')
    imsize_y = config.getint('DEFAULT', 'imsize_y')
    net_input = common.get_noise(input_depth, 'noise', (imsize_y, imsize_x)).type(state.dtype)

    # Create bicubic upsampled versions of LR images for reference
    m = nn.Upsample(scale_factor=config.getint("DEFAULT", "factor"), mode='bicubic', align_corners=False)
    HR_torch_bicubic = m(LR_torch.unsqueeze(0)).squeeze(0)
    #HR_bicubic_blurred = LR_pil_blurred.resize(HR_pil_blurred.size, Image.BICUBIC)

    out = {
            'orig_torch': orig_torch,
            #'orig_pil_blurred': orig_pil_blurred,
            'HR_torch': HR_torch,
            #'HR_pil_blurred': HR_pil_blurred,
            'LR_torch': LR_torch,
            #'LR_pil_blurred': LR_pil_blurred,
            'net_input': net_input,
            'HR_torch_bicubic': HR_torch_bicubic,
            #'HR_bicubic_blurred': HR_bicubic_blurred
            'history_low': state.HistoryTracker(),
            'history_high': state.HistoryTracker(),
        }

    return out
Code example #6
def save_results():
    """
    state.net.eval()
    # Save output data as fits files.

    for j in range(len(state.imgs)):
        hdu = fits.PrimaryHDU(np.array(state.imgs[j]['LR_pil']))
        hdu.writeto('output/LR_np_{}.fits'.format(j))
        common.saveFigure('output/LR_Ground_Truth_{}.png'.format(j), hdu.data)

        hdu = fits.PrimaryHDU(np.array(state.imgs[j]['HR_pil']))
        hdu.writeto('output/HR_np_{}.fits'.format(j))
        common.saveFigure('output/HR_Ground_Truth_{}.png'.format(j), hdu.data)

        hdu = fits.PrimaryHDU(np.array(state.imgs[j]['HR_bicubic']))
        hdu.writeto('output/HR_bicubic_{}.fits'.format(j))
        common.saveFigure('output/HR_bicubic_{}.png'.format(j), hdu.data)

        bicubic_residual = np.array(state.imgs[j]['HR_bicubic']) - np.array(state.imgs[j]['HR_pil'])
        common.saveFigure('output/HR_bicubic_residual_{}.png'.format(j), bicubic_residual)

        with torch.no_grad():
            state.net.eval()
            data = state.net(state.imgs[j]['net_input']).cpu()
        hdu = fits.PrimaryHDU(data)
        hdu.writeto('output/network_output_{}.fits'.format(j))
        common.saveFigure('output/HR_Output_{}.png'.format(j), hdu.data[0,0])
        output_residual = hdu.data[0,0] - np.array(state.imgs[j]['HR_pil'])
        common.saveFigure('output/Output_Residual_{}.png'.format(j), output_residual)
    """
    config = common.get_config()
    for i in range(len(state.imgs)):
        experiment_name = config['DEFAULT']['experiment_name'] + "_frame_{}".format(i)
        make_summary_figure(
            state.imgs[i]['history_low'].iteration, state.imgs[i]['history_high'].iteration,
            state.imgs[i]['history_low'].psnr_HR, state.imgs[i]['history_high'].psnr_HR,
            state.imgs[i]['history_low'].psnr_LR, state.imgs[i]['history_high'].psnr_LR,
            state.imgs[i]['history_low'].target_loss, state.imgs[i]['history_high'].target_loss,
            state.imgs[i]['history_low'].training_loss, state.imgs[i]['history_high'].training_loss,
            experiment_name,
            psnr_blurred_low=state.imgs[i]['history_low'].psnr_blurred,
            psnr_blurred_high=state.imgs[i]['history_high'].psnr_blurred,
            psnr_downsampled_low=state.imgs[i]['history_low'].psnr_downsampled,
            psnr_downsampled_high=state.imgs[i]['history_high'].psnr_downsampled
        )
Code example #7
def crop(orig_torch):
    """Crop a (1,H,W) torch tensor in [0..1] to the configured window.

    crop_x (from config): left x of the crop
    crop_y (from config): lower y of the crop
    """
    config = common.get_config()
    crops = {
        "crop_x": config.getint('DEFAULT', 'crop_x'),
        "crop_y":  config.getint('DEFAULT', 'crop_y'),
    }

    imsize_x = config.getint('DEFAULT', 'imsize_x')
    imsize_y = config.getint('DEFAULT', 'imsize_y')
    # Crop the image
    HR_torch = orig_torch[:, 
        crops['crop_y']:crops['crop_y']+imsize_y,
        crops['crop_x']:crops['crop_x']+imsize_x
    ]
    return HR_torch
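For illustration, assuming the config sets crop_x = 32, crop_y = 32 and imsize_x = imsize_y = 64:

import torch

orig = torch.rand(1, 128, 128)
patch = crop(orig)
print(patch.size())  # torch.Size([1, 64, 64])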
Code example #8
# Python example to train a doc2vec model (with or without pre-trained word embeddings)

import gensim.models as g
import logging
import common_utils

config = common_utils.get_config()

# pre-trained word embeddings (set to None to train without them)
pretrained_emb = "toy_data/pretrained_word_embeddings.txt"

#enable logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
                    level=logging.INFO)

# load a pre-trained doc2vec model (a training sketch follows below)
model = g.Doc2Vec.load(config.enwiki_dbow)
print(model)
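The snippet above only loads a pre-trained model. A minimal training sketch with vanilla gensim follows; the corpus path and hyperparameters are illustrative, and pretrained_emb is omitted because stock gensim's Doc2Vec does not accept that argument (it requires a modified gensim build):

docs = g.doc2vec.TaggedLineDocument("toy_data/train_docs.txt")  # one document per line
model = g.Doc2Vec(docs, vector_size=100, window=5, min_count=2,
                  dm=0, dbow_words=1, epochs=20, workers=4)
model.save("toy_data/model.bin")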
Code example #9
def skip(
        num_input_channels=2, num_output_channels=1, 
        num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4], 
        filter_size_down=3, filter_size_up=3, filter_skip_size=1,
        need_sigmoid=True, need_bias=True, 
        pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU', 
        need1x1_up=True):
    """Assembles encoder-decoder with skip connections.

    Arguments:
        act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU)
        pad (string): zero|reflection (default: 'zero')
        upsample_mode (string): 'nearest|bilinear' (default: 'nearest')
        downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride')

    Relies on helper layers (conv, bn, act, Concat) and a Sequential.add
    convenience method defined elsewhere in this codebase.
    """
    assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip)

    n_scales = len(num_channels_down) 

    if not isinstance(upsample_mode, (list, tuple)):
        upsample_mode = [upsample_mode] * n_scales

    if not isinstance(downsample_mode, (list, tuple)):
        downsample_mode = [downsample_mode] * n_scales

    if not isinstance(filter_size_down, (list, tuple)):
        filter_size_down = [filter_size_down] * n_scales

    if not isinstance(filter_size_up, (list, tuple)):
        filter_size_up = [filter_size_up] * n_scales

    last_scale = n_scales - 1 

    cur_depth = None

    model = nn.Sequential()
    model_tmp = model

    input_depth = num_input_channels
    for i in range(len(num_channels_down)):

        deeper = nn.Sequential()
        skip = nn.Sequential()

        if num_channels_skip[i] != 0:
            model_tmp.add(Concat(1, skip, deeper))
        else:
            model_tmp.add(deeper)
        
        model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i])))

        if num_channels_skip[i] != 0:
            skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad))
            skip.add(bn(num_channels_skip[i]))
            skip.add(act(act_fun))
            
        # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part))

        deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i]))
        deeper.add(bn(num_channels_down[i]))
        deeper.add(act(act_fun))

        deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad))
        deeper.add(bn(num_channels_down[i]))
        deeper.add(act(act_fun))

        deeper_main = nn.Sequential()

        if i == len(num_channels_down) - 1:
            # The deepest
            k = num_channels_down[i]
        else:
            deeper.add(deeper_main)
            k = num_channels_up[i + 1]

        deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i]))

        model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad))
        model_tmp.add(bn(num_channels_up[i]))
        model_tmp.add(act(act_fun))


        if need1x1_up:
            model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad))
            model_tmp.add(bn(num_channels_up[i]))
            model_tmp.add(act(act_fun))

        input_depth = num_channels_down[i]
        model_tmp = deeper_main

    config = common_utils.get_config()
    # Note: the config value overrides the num_output_channels argument.
    num_output_channels = config.getint("DEFAULT", "n_channels")
    model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad))
    if need_sigmoid:
        model.add(nn.Sigmoid())

    return model
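A hypothetical instantiation; conv, bn, act and Concat are helpers defined elsewhere in this codebase, and the output channel count comes from the config's n_channels:

import torch

net = skip(num_input_channels=32,
           num_channels_down=[128] * 5, num_channels_up=[128] * 5,
           num_channels_skip=[4] * 5,
           upsample_mode='bilinear', pad='reflection')
x = torch.rand(1, 32, 64, 64)
print(net(x).size())  # torch.Size([1, 1, 64, 64]) when n_channels = 1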
Code example #10
import os
import shutil
import glob

import common
import state
import sr_utils
from torch.utils.tensorboard import SummaryWriter  # assuming the standard torch tensorboard writer
from build_closure import build_closure
from build_network import build_network

# Setup space to save checkpoints and inputs
if os.path.exists('output'):
    shutil.rmtree("output")
os.mkdir("output")
os.mkdir("output/checkpoints")
os.mkdir("output/inputs")

# Setup tensorboard summary writer
writer = SummaryWriter(log_dir='output/runs/')

# Read config file
config = common.get_config()

# Load images. Crop to produce HR, downsample to produce LR,
# create bicubic reference frames.
state.imgs = []
if config.has_section('LOADING') and config.getboolean('LOADING', 'load_precreated'):
    sr_utils.preload_LR_HR(
        config['LOADING']['path_to_LR'],
        config['LOADING']['path_to_HR']
    )
else: 
    for im_path in glob.glob(config['DEFAULT']['path_to_images']):
        state.imgs.append(sr_utils.load_LR_HR_imgs_sr(im_path))

# Get baselines, such as psnr and target loss of bicubic.
#sr_utils.get_baselines(state.imgs)
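For reference, a hypothetical minimal config covering the keys this script and the helpers above read; option names are taken from the calls, values are placeholders:

import configparser

cfg = configparser.ConfigParser()
cfg.read_string("""
[DEFAULT]
experiment_name = demo
path_to_images = data/*.fits
factor = 4
n_channels = 1
input_depth = 32
imsize_x = 64
imsize_y = 64
crop_x = 0
crop_y = 0
blur = gauss
blur_size = 5

[LOADING]
load_precreated = no
path_to_LR = data/LR/
path_to_HR = data/HR/
""")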
Code example #11
def make_progress_figure(HR, HR_bicubic, HR_out, LR, LR_out, HR_name, LR_name):
    """Make two figures; one tracks HR progress, one tracks LR progress.
    HR, ..., LR_out: tensors of shape (1,H,W) or (1,H/factor,W/factor)"""
    
    # Make HR comparison figure
    config = common.get_config()

    ncols = 5
    nrows = 1
    fig, axes = plt.subplots(nrows, ncols, figsize=(14, 5), dpi=250)

    n_channels = config.getint("DEFAULT", "n_channels")

    if n_channels == 1:
        HR = HR.permute(1,2,0)[:,:,0]
        HR_out = HR_out.permute(1,2,0)[:,:,0]
        HR_bicubic = HR_bicubic.permute(1,2,0)[:,:,0]
        LR = LR.permute(1,2,0)[:,:,0]
        LR_out = LR_out.permute(1,2,0)[:,:,0]
        HR_size = HR.size()
        LR_size = LR.size()
    else:
        HR = HR.permute(1,2,0)
        HR_out = HR_out.permute(1,2,0)
        HR_bicubic = HR_bicubic.permute(1,2,0)
        LR = LR.permute(1,2,0)
        LR_out = LR_out.permute(1,2,0)
        HR_size = (HR.size()[0], HR.size()[1])
        LR_size = (LR.size()[0], LR.size()[1])

    fig.suptitle('{}\n({}x{}x{})'.format(HR_name, HR_size[0], HR_size[1], n_channels), fontsize=16)

    axes[0].imshow(torch.clamp(torch.flip(HR, dims=(0,)), 0, 1), cmap='gray')
    axes[0].set_title('HR')
    # Make HR output plot; include PSNR
    axes[1].imshow(torch.clamp(torch.flip(HR_out, dims=(0,)), 0, 1), cmap='gray')
    axes[1].set_title('HR Output')
    psnr_HR = compare_psnr(HR.numpy(), HR_out.numpy())
    axes[1].set_xlabel('PSNR HR: {:.2f}'.format(psnr_HR))

    axes[2].imshow(torch.clamp(torch.flip(HR_bicubic, dims=(0,)), 0, 1), cmap='gray')
    axes[2].set_title('HR_bicubic')
    psnr_bicubic = compare_psnr(HR.numpy(), HR_bicubic.numpy())
    axes[2].set_xlabel('PSNR Bicubic: {:.2f}'.format(psnr_bicubic))

    axes[3].imshow(torch.clamp(torch.flip((HR_bicubic-HR), dims=(0,)), 0, 1), cmap='gray')
    axes[3].set_title('Residual: \nBicubic - HR')
    bicubic_loss = compare_mse(HR.numpy(), HR_bicubic.numpy())
    axes[3].set_xlabel('MSE HR / Bicubic: {:.2e}'.format(bicubic_loss))

    axes[4].imshow(torch.clamp(torch.flip((HR_out-HR), dims=(0,)), 0, 1), cmap='gray')
    axes[4].set_title('Residual: \nHR Output - HR')
    target_loss = compare_mse(HR.numpy(), HR_out.numpy())
    axes[4].set_xlabel('MSE HR / HR Output: {:.2e}'.format(target_loss))

    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig("output/{}.png".format(HR_name))
    plt.close()

    # Make LR comparison figure
    ncols = 3
    nrows = 1
    fig, axes = plt.subplots(nrows, ncols, figsize=(14, 5), dpi=250)
    fig.suptitle('{}\n({}x{}x{})'.format(LR_name, LR_size[0], LR_size[1], n_channels), fontsize=16)

    axes[0].imshow(torch.clamp(torch.flip(LR, dims=(0,)), 0, 1), cmap='gray')
    axes[0].set_title('LR')
    
    axes[1].imshow(torch.clamp(torch.flip(LR_out, dims=(0,)), 0, 1), cmap='gray')
    axes[1].set_title('LR Output')
    psnr_LR = compare_psnr(LR.numpy(), LR_out.numpy())
    axes[1].set_xlabel('PSNR LR: {:.2f}'.format(psnr_LR))

    axes[2].imshow(torch.clamp(torch.flip((LR_out-LR), dims=(0,)), 0, 1), cmap='gray')
    axes[2].set_title('Residual: \nLR Output - LR')
    training_loss = compare_mse(LR.numpy(), LR_out.numpy())
    axes[2].set_xlabel('MSE LR / LR Output: {:.2e}'.format(training_loss))

    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig("output/{}.png".format(LR_name))
    plt.close()