Example #1
0
def image_quality(gen, batch_gen, noise_shapes, num_instances=1,
    N_batches=100):
    """Collect per-sample image-quality metrics for generated images.

    ``batch_gen`` is reset with a fixed seed and ``N_batches`` batches are
    drawn from it. For every batch, ``num_instances`` realizations are
    generated with ``gen`` from seeded noise and compared to the real
    images with RMSE, MS-SSIM and log-spectral distance.

    Args:
        gen: model whose ``predict([img_ds] + noise)`` returns generated
            images.
        batch_gen: batch generator yielding ``(img_real, img_ds)`` and
            exposing ``img_shape``, ``batch_size`` and
            ``reset(random_seed=...)``.
        noise_shapes: callable mapping ``img_shape`` to the noise shapes
            passed to ``noise.NoiseGenerator``.
        num_instances: generated realizations per batch.
        N_batches: number of batches to evaluate.

    Returns:
        ``(rmse_all, ssim_all, lsd_all)``: 1-D arrays concatenating the
        flattened RMSE, MS-SSIM and log-spectral-distance values.
    """
    noise_gen = noise.NoiseGenerator(noise_shapes(batch_gen.img_shape),
        batch_size=batch_gen.batch_size, random_seed=1234)

    # Fixed seeds for both the data and the noise make repeated
    # evaluations comparable.
    batch_gen.reset(random_seed=1234)
    rmse_all = []
    ssim_all = []
    lsd_all = []

    for _ in range(N_batches):
        (img_real, img_ds) = next(batch_gen)
        for _ in range(num_instances):
            img_gen = gen.predict([img_ds] + noise_gen())
            # Mean over axes 2-4; presumably (height, width, channels) of
            # 5-D (batch, time, h, w, c) arrays -- TODO confirm.
            rmse = np.sqrt(((img_real-img_gen)**2).mean(axis=(2,3,4)))
            ssim = msssim.MultiScaleSSIM(img_real, img_gen, 1.0)
            lsd = log_spectral_distance_batch(img_real, img_gen)
            rmse_all.append(rmse.flatten())
            ssim_all.append(ssim.flatten())
            lsd_all.append(lsd.flatten())

    rmse_all = np.concatenate(rmse_all)
    ssim_all = np.concatenate(ssim_all)
    lsd_all = np.concatenate(lsd_all)

    return (rmse_all, ssim_all, lsd_all)
Example #2
0
def plot_all(data_fn, gen_weights_fn, application="mch"):
    """Render the sample, style and transition figures for one application.

    Loads ``gen_weights_fn`` into a styled generator, sets up a batch
    generator from ``data_fn`` with a fixed seed, and writes three PDFs
    under ``../figures/``: ``<application>_samples.pdf``,
    ``<application>_styles.pdf`` and ``<application>_transition.pdf``.
    Each figure gets its own noise generator with a fixed seed.

    Args:
        data_fn: data file passed to ``train.setup_gan``.
        gen_weights_fn: weights loaded into the styled generator.
        application: ``"mch"`` (1 channel) or ``"goes"`` (3 channels).
    """
    channel_count = {"mch": 1, "goes": 3}[application]
    (gen_styled, gen, styling, _) = models.generator_styled(
        num_channels=channel_count)
    gen_styled.load_weights(gen_weights_fn)

    # The noise-shape callable used for plotting comes from setup_gan.
    (_, batch_gen, noise_shapes, _) = train.setup_gan(
        data_fn,
        n_samples=128,
        sample_random=True,
        application=application,
        random_seed=321459,
    )

    def seeded_noise(seed):
        # One independent, fixed-seed noise stream per figure/panel set.
        return noise.NoiseGenerator(noise_shapes(),
                                    batch_size=batch_gen.batch_size,
                                    random_seed=seed)

    def figure_path(kind):
        return "../figures/{}_{}.pdf".format(application, kind)

    plot_samples(gen_styled, batch_gen, seeded_noise(34),
                 out_fn=figure_path("samples"))

    style_noise = [seeded_noise(seed) for seed in (221, 70, 39)]
    plot_styles(gen_styled, batch_gen, style_noise,
                out_fn=figure_path("styles"))

    plot_transition(gen, styling, batch_gen, seeded_noise(241),
                    out_fn=figure_path("transition"))

    gc.collect()
Example #3
0
def train_gan(wgan, batch_gen, noise_shapes, steps_per_epoch, num_epochs,
              application="mch"):
    """Train ``wgan`` for ``num_epochs`` epochs and plot progress after each.

    Each epoch runs ``wgan.train`` for ``steps_per_epoch`` steps with a
    training ratio of 5, then writes samples from ``wgan.gen`` to
    ``../figures/progress_<application>.pdf`` (the same path every epoch).

    Returns:
        The same ``wgan`` object that was passed in.
    """
    noise_gen = noise.NoiseGenerator(noise_shapes(batch_gen.img_shape),
                                     batch_size=batch_gen.batch_size)
    progress_fn = "../figures/progress_{}.pdf".format(application)

    for epoch_no in range(1, num_epochs + 1):
        print("Epoch {}/{}".format(epoch_no, num_epochs))
        wgan.train(batch_gen, noise_gen, steps_per_epoch, training_ratio=5)
        plots.plot_samples(wgan.gen, batch_gen, noise_gen,
                           application=application, out_fn=progress_fn)

    return wgan
Example #4
0
def rank_metrics_by_noise(application, run_id, data_file,
    weights_fn):
    (wgan, batch_gen_train, batch_gen_valid, _,
        noise_shapes, steps_per_epoch) = train.setup_gan(data_file,
        application=application)
    gen = wgan.gen
    noise_gen = noise.NoiseGenerator(noise_shapes(),
        batch_size=batch_gen_valid.batch_size)

    for m in list(range(0.5,2.51,0.1))+[3.0,3.5]:
        N_samples = int(fn.split("-")[-1].split(".")[0])
        gen.load_weights(weights_dir+"/"+fn)
        (ranks, crps_scores) = ensemble_ranks(gen, batch_gen_valid,
            noise_gen, num_batches=32, noise_mul=m)
        
        KS = rank_KS(ranks)
        CvM = rank_CvM(ranks) 
        DKL = rank_DKL(ranks)
        CRPS = crps_scores.mean()
        mean = ranks.mean()
        std = ranks.std()

        print(N_samples, KS, CvM, DKL, CRPS, mean, std)
Example #5
0
def rank_metrics_table(application, data_file, weights_fn, method="gan"):
    """Print rank-histogram and CRPS metrics on the test set.

    Args:
        application: application name passed to the training setup.
        data_file: data file used both for setup and as the test data.
        weights_fn: generator weights to load (used only when
            ``method == "gan"``).
        method: ``"gan"`` for the trained generator or ``"rainfarm"`` for
            the RainFARM baseline.

    Raises:
        ValueError: if ``method`` is not ``"gan"`` or ``"rainfarm"``.
    """
    if method == "gan":
        (wgan, batch_gen_train, batch_gen_valid, batch_gen_test,
            noise_shapes, steps_per_epoch) = train.setup_gan(data_file,
            test_data_file=data_file, application=application, batch_size=64)
        gen = wgan.gen
        gen.load_weights(weights_fn)
    elif method == "rainfarm":
        (gen_det, batch_gen_train, batch_gen_valid,
            batch_gen_test, steps_per_epoch) = train.setup_deterministic(data_file,
            test_data_file=data_file, sample_random=True, n_samples=1, batch_size=64,
            application=application, loss='mse')
        gen = GeneratorRainFARM(16, batch_gen_test.decoder)
        # No noise inputs are supplied for the RainFARM baseline.
        noise_shapes = lambda: []
    else:
        raise ValueError("Unknown method: {}".format(method))

    # Noise batches must match the batches ensemble_ranks draws, which come
    # from the test generator.
    noise_gen = noise.NoiseGenerator(noise_shapes(),
        batch_size=batch_gen_test.batch_size)

    (ranks, crps_scores) = ensemble_ranks(gen, batch_gen_test,
        noise_gen, num_batches=16)

    metrics = (
        ("KS", rank_KS(ranks)),
        ("CvM", rank_CvM(ranks)),
        ("DKL", rank_DKL(ranks)),
        ("OP", rank_OP(ranks)),
        ("CRPS", crps_scores.mean()),
        ("mean", ranks.mean()),
        ("std", ranks.std()),
    )
    for (label, value) in metrics:
        print("{}: {:.3f}".format(label, value))
Example #6
0
def rank_metrics_by_time(application, data_file, out_fn,
    weights_dir, check_every=1, N_range=None):
    (wgan, batch_gen_train, batch_gen_valid, batch_gen_test,
        noise_shapes, steps_per_epoch) = train.setup_gan(data_file,
        application=application, batch_size=64)
    gen = wgan.gen
    noise_gen = noise.NoiseGenerator(noise_shapes(),
        batch_size=batch_gen_valid.batch_size)

    files = os.listdir(weights_dir)
    def get_id(fn):
        return fn.split("-")[1]
    files = sorted(fn for fn in files if get_id(fn)==application)

    def log_line(line):
        with open(out_fn, 'a') as f:
            print(line, file=f)
    log_line("N KS CvM DKL OP CRPS mean std")

    for fn in files[::check_every]:
        N_samples = int(fn.split("-")[-1].split(".")[0])
        if (N_range is not None) and not (N_range[0] <= N_samples < N_range[1]):
            continue
        gen.load_weights(weights_dir+"/"+fn)
        (ranks, crps_scores) = ensemble_ranks(gen, batch_gen_valid,
            noise_gen, num_batches=8)
        
        KS = rank_KS(ranks)
        CvM = rank_CvM(ranks) 
        DKL = rank_DKL(ranks)
        OP = rank_OP(ranks)
        CRPS = crps_scores.mean() 
        mean = ranks.mean()
        std = ranks.std()

        log_line("{} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f}".format(
            N_samples, KS, CvM, DKL, OP, CRPS, mean, std))
Example #7
0
def reconstruct_time_series_partial(images_fn, gen, noise_shapes,
    init_model, out_fn,
    time_range, h=None, last_t=None, application="mchrzc", ds_factor=16, n_ensemble=4,
    scaling_fn=path+"/../data/scale_rzc.npy", relax_lam=0.0, time_step=600):
    """Downsample a time range of images and regenerate it with a recurrent generator.

    Frames of ``images_fn`` from ``time_range[0]`` up to (not including)
    ``time_range[1]`` are selected with ``bisect_left`` (the time axis is
    assumed sorted -- confirm). Each frame is decoded, cropped to a
    multiple of ``ds_factor``, downsampled, and fed through ``gen`` to
    produce ``n_ensemble`` realizations. The downsampled inputs and the
    generated ensembles are re-encoded to uint8 indices of ``dec.logR``
    and written, with the frame times, to the netCDF file ``out_fn``
    (opened in ``'w'`` mode).

    Args:
        images_fn: netCDF file with ``time`` (seconds since 1970-01-01) and
            ``images`` variables.
        gen: recurrent generator; ``gen.predict([img_ds, h, noise])``
            returns ``(img_gen, h)``.
        noise_shapes: callable mapping the cropped image shape to noise
            shapes for ``noise.NoiseGenerator``; the generator yields
            ``(n_init, n_update)``.
        init_model: model whose ``predict([img_ds_frame, n_init])`` gives an
            initial recurrent state ``h``.
        out_fn: output netCDF path.
        time_range: ``(start, end)`` datetimes.
        h, last_t: recurrent state and time of the last processed frame,
            e.g. from a previous call. The state is re-initialized when
            either is None or the gap to the current frame is not
            ``time_step`` seconds.
        application: only ``"mchrzc"`` is supported.
        ds_factor: downsampling factor.
        n_ensemble: realizations generated per frame.
        scaling_fn: scaling file passed to ``data.RainRateDecoder``.
        relax_lam: if > 0, after each step ``h`` is pulled toward the state
            obtained from an all-zero input:
            ``h_null + (1 - relax_lam) * (h - h_null)``.
        time_step: expected spacing between consecutive frames, in seconds.

    Returns:
        ``(h, last_t)`` so a subsequent call can continue the sequence.

    Raises:
        ValueError: for an unknown ``application``.
    """
    if application == "mchrzc":
        dec = data.RainRateDecoder(scaling_fn, below_val=np.log10(0.025))
    else:
        raise ValueError("Unknown application.")
    downsampler = data.LogDownsampler(min_val=dec.below_val,
        threshold_val=dec.value_range[0])

    with netCDF4.Dataset(images_fn) as ds_img:
        # np.asarray avoids a copy when possible; unlike
        # np.array(..., copy=False) it does not raise under NumPy >= 2
        # when a copy is required.
        time = np.asarray(ds_img["time"][:])
        time_dt = [datetime(1970,1,1)+timedelta(seconds=t) for t in time]
        t0 = bisect_left(time_dt, time_range[0])
        t1 = bisect_left(time_dt, time_range[1])
        images = np.asarray(ds_img["images"][t0:t1,...])
        time = time[t0:t1]

    # Crop height and width down to multiples of ds_factor.
    img_shape = images.shape[1:3]
    img_shape = (
        img_shape[0] - img_shape[0]%ds_factor,
        img_shape[1] - img_shape[1]%ds_factor,
    )
    noise_gen = noise.NoiseGenerator(noise_shapes(img_shape),
        batch_size=n_ensemble)

    images_ds = np.zeros(
        (images.shape[0],img_shape[0]//ds_factor,img_shape[1]//ds_factor,1),
        dtype=np.uint8
    )
    images_gen = np.zeros(
        (images.shape[0],)+img_shape+(1,n_ensemble),
        dtype=np.uint8
    )

    # this finds the nearest index in the R encoding
    def encoder():
        lR = dec.logR
        ind = np.arange(len(lR))
        ip = interp1d(lR,ind)
        def f(x):
            # Values below the decoder's range stay at index 0.
            y = np.zeros(x.shape, dtype=np.uint8)
            valid = (x >= dec.value_range[0])
            y[valid] = ip(x[valid]).round().astype(np.uint8)
            return y
        return f
    encode = encoder()

    for k in range(images.shape[0]):
        print("{}/{}".format(k+1,images.shape[0]))
        img_real = images[k:k+1,:img_shape[0],:img_shape[1],:]
        img_real = dec(img_real)
        img_real = img_real.reshape(
            (1,1)+img_real.shape[1:])
        img_real[np.isnan(img_real)] = dec.below_val
        img_ds = downsampler(img_real)
        img_ds = dec.normalize(img_ds)
        img_ds_denorm = dec.denormalize(img_ds)
        # The same conditioning input for every ensemble member.
        img_ds = np.tile(img_ds, (n_ensemble,1,1,1,1))

        (n_init, n_update) = noise_gen()

        # Re-initialize the recurrent state at the start of a sequence,
        # when the previous frame time is unknown, or after a gap.
        if (h is None) or (last_t is None) or (time[k]-last_t != time_step):
            h = init_model.predict([img_ds[:,0,...], n_init])

        (img_gen,h) = gen.predict([img_ds, h, n_update])
        if relax_lam > 0.0:
            # nudge h towards null
            h_null = init_model.predict([
                np.zeros_like(img_ds[:,0,...]), n_init
            ])
            h = h_null + (1.0-relax_lam)*(h-h_null)
        img_gen = dec.denormalize(img_gen)
        # Move the ensemble axis last to match images_gen's layout.
        img_gen = img_gen.transpose((1,2,3,4,0))

        images_ds[k,...] = encode(img_ds_denorm[0,...])
        images_gen[k,...] = encode(img_gen[0,...])
        last_t = time[k]

    with netCDF4.Dataset(out_fn, 'w') as ds:
        ds.createDimension("dim_height", img_shape[0])
        ds.createDimension("dim_width", img_shape[1])
        # Integer division: dimension sizes must be ints.
        ds.createDimension("dim_height_ds", img_shape[0]//ds_factor)
        ds.createDimension("dim_width_ds", img_shape[1]//ds_factor)
        ds.createDimension("dim_samples", images.shape[0])
        ds.createDimension("dim_ensemble", n_ensemble)
        ds.createDimension("dim_channels", 1)

        var_params = {"zlib": True, "complevel": 9}

        def create_var(name, dims, **params):
            dtype = params.pop("dtype", np.float32)
            var = ds.createVariable(name, dtype, dims, **params)
            return var

        # netCDF-4 rejects chunk sizes larger than a fixed dimension, so
        # clamp the 64x64 spatial chunk for small images.
        var_img = create_var("images",
            ("dim_samples","dim_height","dim_width","dim_channels",
                "dim_ensemble"),
            chunksizes=(1,min(64,img_shape[0]),min(64,img_shape[1]),1,1),
            dtype=np.uint8, **var_params)
        var_img.units = "Encoded R"
        var_img_ds = create_var("images_ds",
            ("dim_samples","dim_height_ds","dim_width_ds","dim_channels"),
            dtype=np.uint8, **var_params)
        var_img_ds.units = "Encoded R"
        var_time = create_var("time", ("dim_samples",),
            chunksizes=(1,), dtype=np.float64, **var_params)
        var_time.units = "Seconds since 1970-01-01 00:00"

        var_img_ds[:] = images_ds
        var_img[:] = images_gen
        var_time[:] = time

    return (h, last_t)
Example #8
0
def plot_sequences_horiz(gen,
                         noise_shapes,
                         batch_gen,
                         samples=[0, 1, 2],
                         num_instances=3,
                         out_fn=None,
                         plot_stride=2,
                         random_seed=1234,
                         application="mchrzc"):

    num_samples = len(samples)
    old_batch_size = batch_gen.batch_size
    old_augment = batch_gen.augment
    old_zeros_frac = batch_gen.zeros_frac
    img_shape = batch_gen.sequences.shape[2:4]
    noise_gen = noise.NoiseGenerator(noise_shapes(img_shape),
                                     batch_size=num_samples,
                                     random_seed=random_seed)
    # force the batch generator to return the selected samples
    batch_gen.next_ind = np.array(samples)
    try:
        batch_gen.batch_size = num_samples
        batch_gen.augment = False
        batch_gen.zeros_frac = 0.0
        (seq_real, cond) = next(batch_gen)
        seq_gen = []
        for i in range(num_instances):
            seq_gen.append(gen.predict([cond] + noise_gen()))
    finally:
        batch_gen.batch_size = old_batch_size
        batch_gen.augment = old_augment
        batch_gen.zeros_frac = old_zeros_frac

    seq_real = batch_gen.decoder.denormalize(seq_real)
    cond = batch_gen.decoder.denormalize(cond)
    seq_gen = [batch_gen.decoder.denormalize(seq) for seq in seq_gen]

    num_frames = batch_gen.num_frames
    if plot_stride > 1:
        seq_real = seq_real[:, ::plot_stride, ...]
        cond = cond[:, ::plot_stride, ...]
        for i in range(len(seq_gen)):
            seq_gen[i] = seq_gen[i][:, ::plot_stride, ...]
        num_frames = seq_real.shape[1]

    num_rows = num_samples
    num_cols = num_frames
    num_rows_s = 2 + num_instances

    figsize = (num_cols * 1.5, num_rows * num_rows_s * 1.60)
    fig = plt.figure(figsize=figsize)

    gs = gridspec.GridSpec(num_rows + 1,
                           1,
                           hspace=0.05,
                           height_ratios=[1] * num_rows + [0.035])

    value_range = batch_gen.decoder.value_range

    for s in range(num_samples):
        gs_s = gridspec.GridSpecFromSubplotSpec(num_rows_s,
                                                num_cols,
                                                subplot_spec=gs[s, 0],
                                                wspace=0.05,
                                                hspace=0.05)
        for t in range(num_frames):
            plt.subplot(gs_s[0, t])
            plot_img(seq_real[s, t, :, :, 0], value_range=value_range)
            if t == 0:
                plt.ylabel("Real", fontsize=16)
                plt.text(0.01,
                         0.97,
                         "({})".format(ascii_lowercase[s]),
                         horizontalalignment='left',
                         verticalalignment='top',
                         transform=plt.gca().transAxes,
                         fontsize=16)
                if s == 0:
                    plt.title("Time \u2192", fontsize=16)
            plt.subplot(gs_s[1, t])
            plot_img(cond[s, t, :, :, 0], value_range=value_range)
            if t == 0:
                plt.ylabel("Downs.", fontsize=16)
            for k in range(num_instances):
                j = 2 + k
                plt.subplot(gs_s[j, t])
                plot_img(seq_gen[k][s, t, :, :, 0], value_range=value_range)
                if t == 0:
                    plt.ylabel("Gen. #{}".format(k + 1), fontsize=16)

    if application == 'mchrzc':
        units = "Rain rate [mm h$^{-1}$]"
        cb_tick_loc = np.array([-1, 0, 1, 2])
        cb_tick_labels = [0.1, 1, 10, 100]
    elif application == 'goescod':
        units = "Cloud optical thickness"
        cb_tick_loc = np.log([2, 10, 50, 150])
        cb_tick_labels = np.exp(cb_tick_loc).round().astype(int)

    cax = plt.subplot(gs[-1, 0]).axes
    cb = colorbar.ColorbarBase(cax,
                               norm=colors.Normalize(*value_range),
                               orientation='horizontal')
    cb.set_ticks(cb_tick_loc)
    cb.set_ticklabels(cb_tick_labels)
    cax.tick_params(labelsize=16)
    cb.set_label(units, size=16)

    if out_fn is not None:
        plt.savefig(out_fn, bbox_inches='tight')
        plt.close()