def image_quality(gen, batch_gen, noise_shapes, num_instances=1, N_batches=100):
    """Compute per-sample quality metrics of generated images against real ones.

    ``batch_gen`` is reset with a fixed seed and the noise generator is seeded
    too, so repeated calls evaluate the same batches with the same noise.
    For each of ``N_batches`` batches, ``num_instances`` realizations are
    generated and compared with the real images.

    Parameters
    ----------
    gen : model whose ``predict`` takes ``[img_ds] + noise`` (a list).
    batch_gen : batch generator yielding ``(img_real, img_ds)`` tuples; must
        expose ``img_shape``, ``batch_size`` and ``reset(random_seed=...)``.
    noise_shapes : callable mapping ``img_shape`` to the noise shapes passed
        to ``noise.NoiseGenerator``.
    num_instances : number of generated realizations per batch.
    N_batches : number of batches to evaluate.

    Returns
    -------
    (rmse_all, ssim_all, lsd_all) : 1-D arrays, the concatenation of the
        flattened per-batch RMSE, ``msssim.MultiScaleSSIM`` and
        ``log_spectral_distance_batch`` results.
    """
    img_shape = batch_gen.img_shape
    noise_gen = noise.NoiseGenerator(noise_shapes(img_shape),
        batch_size=batch_gen.batch_size, random_seed=1234)
    # Fixed seed so the evaluation is reproducible across calls.
    batch_gen.reset(random_seed=1234)

    rmse_all = []
    ssim_all = []
    lsd_all = []

    for _ in range(N_batches):
        (img_real, img_ds) = next(batch_gen)
        for _ in range(num_instances):
            n = noise_gen()
            img_gen = gen.predict([img_ds] + n)
            # RMSE reduced over axes 2-4; assumes 5-D arrays whose first two
            # axes are (sample, time) -- TODO confirm against batch_gen.
            rmse = np.sqrt(((img_real - img_gen)**2).mean(axis=(2, 3, 4)))
            # 1.0 is passed as the third argument; presumably the data range
            # of normalized images -- verify against msssim.MultiScaleSSIM.
            ssim = msssim.MultiScaleSSIM(img_real, img_gen, 1.0)
            lsd = log_spectral_distance_batch(img_real, img_gen)
            rmse_all.append(rmse.flatten())
            ssim_all.append(ssim.flatten())
            lsd_all.append(lsd.flatten())

    rmse_all = np.concatenate(rmse_all)
    ssim_all = np.concatenate(ssim_all)
    lsd_all = np.concatenate(lsd_all)

    return (rmse_all, ssim_all, lsd_all)
def plot_all(data_fn, gen_weights_fn, application="mch"):
    """Produce the sample, style and transition figures for one application.

    Builds the styled generator, loads ``gen_weights_fn`` into it, sets up a
    batch generator via ``train.setup_gan`` and writes three PDFs under
    ``../figures/`` using fixed noise seeds so the figures are reproducible.

    Parameters
    ----------
    data_fn : data file passed to ``train.setup_gan``.
    gen_weights_fn : weights file loaded into the styled generator.
    application : ``"mch"`` (1 channel) or ``"goes"`` (3 channels); any other
        value raises KeyError.
    """
    channels_per_app = {"mch": 1, "goes": 3}
    num_channels = channels_per_app[application]

    (gen_styled, gen, styling, _) = models.generator_styled(
        num_channels=num_channels)
    gen_styled.load_weights(gen_weights_fn)

    # The noise_shapes from setup_gan is the one used for all plots below.
    (_, batch_gen, noise_shapes, _) = train.setup_gan(
        data_fn,
        n_samples=128,
        sample_random=True,
        application=application,
        random_seed=321459,
    )

    def seeded_noise(seed):
        # Read batch_size at call time, like the inline constructions did.
        return noise.NoiseGenerator(noise_shapes(),
            batch_size=batch_gen.batch_size, random_seed=seed)

    plot_samples(gen_styled, batch_gen, seeded_noise(34),
        out_fn="../figures/{}_samples.pdf".format(application))

    style_noise = [seeded_noise(seed) for seed in (221, 70, 39)]
    plot_styles(gen_styled, batch_gen, style_noise,
        out_fn="../figures/{}_styles.pdf".format(application))

    plot_transition(gen, styling, batch_gen, seeded_noise(241),
        out_fn="../figures/{}_transition.pdf".format(application))

    gc.collect()
def train_gan(wgan, batch_gen, noise_shapes, steps_per_epoch, num_epochs,
    application="mch"):
    """Train ``wgan`` epoch by epoch, writing a progress plot after each epoch.

    Parameters
    ----------
    wgan : object exposing ``train(batch_gen, noise_gen, steps, training_ratio=...)``
        and a ``gen`` attribute.
    batch_gen : batch generator exposing ``img_shape`` and ``batch_size``.
    noise_shapes : callable mapping ``img_shape`` to noise shapes for
        ``noise.NoiseGenerator``.
    steps_per_epoch : training steps per epoch.
    num_epochs : number of epochs to run.
    application : used for the plot settings and the progress file name.

    Returns
    -------
    The ``wgan`` argument, after training.
    """
    noise_gen = noise.NoiseGenerator(
        noise_shapes(batch_gen.img_shape),
        batch_size=batch_gen.batch_size,
    )

    for epoch_num in range(1, num_epochs + 1):
        print("Epoch {}/{}".format(epoch_num, num_epochs))
        wgan.train(batch_gen, noise_gen, steps_per_epoch, training_ratio=5)
        # The same file is overwritten every epoch with the latest samples.
        progress_fn = "../figures/progress_{}.pdf".format(application)
        plots.plot_samples(wgan.gen, batch_gen, noise_gen,
            application=application, out_fn=progress_fn)

    return wgan
def rank_metrics_by_noise(application, run_id, data_file, weights_fn,
    noise_muls=None):
    """Print rank/CRPS metrics of one generator for a range of noise multipliers.

    Loads the generator weights from ``weights_fn`` once, then for each noise
    multiplier computes ensemble ranks on the validation generator and
    prints one whitespace-separated line:
    ``noise_mul KS CvM DKL CRPS mean std``.

    Parameters
    ----------
    application : passed to ``train.setup_gan``.
    run_id : unused; kept for backward compatibility of the signature.
    data_file : data file passed to ``train.setup_gan``.
    weights_fn : generator weights file to evaluate.
    noise_muls : iterable of values passed to ``ensemble_ranks`` as
        ``noise_mul``. Defaults to 0.5, 0.6, ..., 2.5 followed by 3.0, 3.5.
    """
    (wgan, batch_gen_train, batch_gen_valid, _, noise_shapes,
        steps_per_epoch) = train.setup_gan(data_file, application=application)
    gen = wgan.gen
    # Only a single checkpoint is evaluated, so load it once up front.
    gen.load_weights(weights_fn)
    noise_gen = noise.NoiseGenerator(noise_shapes(),
        batch_size=batch_gen_valid.batch_size)

    if noise_muls is None:
        # range() accepts integers only; build the 0.1-spaced grid explicitly
        # and round away accumulated float error (e.g. 0.7000000000000001).
        noise_muls = [round(0.5 + 0.1 * i, 1) for i in range(21)] + [3.0, 3.5]

    for m in noise_muls:
        (ranks, crps_scores) = ensemble_ranks(gen, batch_gen_valid, noise_gen,
            num_batches=32, noise_mul=m)
        KS = rank_KS(ranks)
        CvM = rank_CvM(ranks)
        DKL = rank_DKL(ranks)
        CRPS = crps_scores.mean()
        mean = ranks.mean()
        std = ranks.std()
        print(m, KS, CvM, DKL, CRPS, mean, std)
def rank_metrics_table(application, data_file, weights_fn, method="gan"):
    """Print rank and CRPS metrics of a generator evaluated on the test set.

    Parameters
    ----------
    application : passed to the ``train`` setup function.
    data_file : used both as training and as test data file.
    weights_fn : generator weights file (only used when ``method == "gan"``).
    method : ``"gan"`` to evaluate the trained GAN generator, or
        ``"rainfarm"`` to evaluate ``GeneratorRainFARM`` (no noise inputs).

    Raises
    ------
    ValueError : if ``method`` is not ``"gan"`` or ``"rainfarm"``.
    """
    if method == "gan":
        (wgan, _, _, batch_gen_test, noise_shapes, _) = train.setup_gan(
            data_file, test_data_file=data_file, application=application,
            batch_size=64)
        gen = wgan.gen
        gen.load_weights(weights_fn)
    elif method == "rainfarm":
        (_, _, _, batch_gen_test, _) = train.setup_deterministic(data_file,
            test_data_file=data_file, sample_random=True, n_samples=1,
            batch_size=64, application=application, loss='mse')
        gen = GeneratorRainFARM(16, batch_gen_test.decoder)

        def noise_shapes():
            # RainFARM takes no noise inputs.
            return []
    else:
        raise ValueError("Unknown method: {}".format(method))

    # Size the noise to the generator actually being evaluated.
    noise_gen = noise.NoiseGenerator(noise_shapes(),
        batch_size=batch_gen_test.batch_size)
    (ranks, crps_scores) = ensemble_ranks(gen, batch_gen_test, noise_gen,
        num_batches=16)

    metrics = [
        ("KS", rank_KS(ranks)),
        ("CvM", rank_CvM(ranks)),
        ("DKL", rank_DKL(ranks)),
        ("OP", rank_OP(ranks)),
        ("CRPS", crps_scores.mean()),
        ("mean", ranks.mean()),
        ("std", ranks.std()),
    ]
    for (label, value) in metrics:
        print("{}: {:.3f}".format(label, value))
def rank_metrics_by_time(application, data_file, out_fn, weights_dir,
    check_every=1, N_range=None):
    """Log rank/CRPS metrics for every saved generator checkpoint of an application.

    Checkpoint files in ``weights_dir`` are selected when the second
    ``-``-separated field of their name equals ``application``; the training
    sample count is parsed from the last ``-``-separated field (before the
    extension). A header line and then one line per evaluated checkpoint are
    appended to ``out_fn``.

    Parameters
    ----------
    application : passed to ``train.setup_gan`` and used to select files.
    data_file : data file passed to ``train.setup_gan``.
    out_fn : text file the results are appended to.
    weights_dir : directory containing the checkpoint files.
    check_every : evaluate only every ``check_every``-th matching file
        (in sorted file-name order).
    N_range : optional ``(lo, hi)``; only checkpoints with
        ``lo <= N_samples < hi`` are evaluated.
    """
    (wgan, _, batch_gen_valid, _, noise_shapes, _) = train.setup_gan(
        data_file, application=application, batch_size=64)
    gen = wgan.gen
    noise_gen = noise.NoiseGenerator(noise_shapes(),
        batch_size=batch_gen_valid.batch_size)

    def is_app_weights(fn):
        # Files without a second "-"-separated field (e.g. stray files in
        # the directory) are skipped instead of raising IndexError.
        parts = fn.split("-")
        return len(parts) > 1 and parts[1] == application

    files = sorted(fn for fn in os.listdir(weights_dir) if is_app_weights(fn))

    def log_line(line):
        # Re-open in append mode per line so partial results survive a crash.
        with open(out_fn, 'a') as f:
            print(line, file=f)

    log_line("N KS CvM DKL OP CRPS mean std")
    for fn in files[::check_every]:
        N_samples = int(fn.split("-")[-1].split(".")[0])
        if (N_range is not None) and not (N_range[0] <= N_samples < N_range[1]):
            continue

        gen.load_weights(os.path.join(weights_dir, fn))
        (ranks, crps_scores) = ensemble_ranks(gen, batch_gen_valid, noise_gen,
            num_batches=8)

        KS = rank_KS(ranks)
        CvM = rank_CvM(ranks)
        DKL = rank_DKL(ranks)
        OP = rank_OP(ranks)
        CRPS = crps_scores.mean()
        mean = ranks.mean()
        std = ranks.std()

        log_line("{} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f}".format(
            N_samples, KS, CvM, DKL, OP, CRPS, mean, std))
def reconstruct_time_series_partial(images_fn, gen, noise_shapes, init_model,
    out_fn, time_range, h=None, last_t=None, application="mchrzc",
    ds_factor=16, n_ensemble=4, scaling_fn=path+"/../data/scale_rzc.npy",
    relax_lam=0.0):
    """Downsample a stretch of observed images, regenerate it with the
    recurrent generator, and write both to a netCDF file.

    Frames whose time lies in ``[time_range[0], time_range[1])`` are read from
    ``images_fn``. Each frame is cropped to a multiple of ``ds_factor``,
    downsampled, and fed to ``gen`` for an ensemble of ``n_ensemble`` members.
    The recurrent state ``h`` is carried from frame to frame and is
    re-initialised with ``init_model`` whenever the time step from the
    previous frame is not exactly 600.

    Parameters
    ----------
    images_fn : netCDF file with ``time`` (seconds since 1970-01-01) and
        ``images`` variables.
    gen : recurrent generator; ``predict([img_ds, h, n_update])`` returns
        ``(img_gen, h)``.
    noise_shapes : callable mapping the cropped image shape to noise shapes.
    init_model : model producing the initial state from
        ``[img_ds[:, 0, ...], n_init]``.
    out_fn : output netCDF file (overwritten).
    time_range : ``(start, end)`` naive ``datetime`` objects (compared with
        naive UTC-epoch datetimes built from ``time``).
    h, last_t : recurrent state and time of the last processed frame, e.g.
        as returned by a previous call; ``None`` forces re-initialisation.
    application : only ``"mchrzc"`` is supported.
    ds_factor : downsampling factor.
    n_ensemble : number of ensemble members generated per frame.
    scaling_fn : scaling file passed to ``data.RainRateDecoder``.
    relax_lam : if > 0, after each step ``h`` is pulled towards the state
        produced from an all-zero input by a factor ``1 - relax_lam``.

    Returns
    -------
    (h, last_t) : the final recurrent state and the time of the last
        processed frame, suitable for passing to the next call.

    Raises
    ------
    ValueError : for an unsupported ``application``.
    """
    if application == "mchrzc":
        dec = data.RainRateDecoder(scaling_fn, below_val=np.log10(0.025))
    else:
        raise ValueError("Unknown application.")
    downsampler = data.LogDownsampler(min_val=dec.below_val,
        threshold_val=dec.value_range[0])

    with netCDF4.Dataset(images_fn) as ds_img:
        # np.asarray avoids a copy when possible on every NumPy version,
        # unlike np.array(copy=False), which raises on NumPy >= 2.
        time = np.asarray(ds_img["time"][:])
        # float() so integer-typed time arrays are accepted by timedelta.
        time_dt = [datetime(1970, 1, 1) + timedelta(seconds=float(t))
            for t in time]
        t0 = bisect_left(time_dt, time_range[0])
        t1 = bisect_left(time_dt, time_range[1])
        images = np.asarray(ds_img["images"][t0:t1, ...])
    time = time[t0:t1]

    # Crop so both spatial dimensions are divisible by ds_factor.
    img_shape = images.shape[1:3]
    img_shape = (
        img_shape[0] - img_shape[0] % ds_factor,
        img_shape[1] - img_shape[1] % ds_factor,
    )
    # Integer division: these sizes are array shapes and netCDF dimensions.
    ds_shape = (img_shape[0] // ds_factor, img_shape[1] // ds_factor)
    noise_gen = noise.NoiseGenerator(noise_shapes(img_shape),
        batch_size=n_ensemble)

    images_ds = np.zeros((images.shape[0],) + ds_shape + (1,), dtype=np.uint8)
    images_gen = np.zeros((images.shape[0],) + img_shape + (1, n_ensemble),
        dtype=np.uint8)

    # this finds the nearest index in the R encoding
    def encoder():
        lR = dec.logR
        ind = np.arange(len(lR))
        ip = interp1d(lR, ind)

        def f(x):
            # Values below the decoder's range stay at index 0.
            y = np.zeros(x.shape, dtype=np.uint8)
            valid = (x >= dec.value_range[0])
            y[valid] = ip(x[valid]).round().astype(np.uint8)
            return y
        return f
    encode = encoder()

    for k in range(images.shape[0]):
        print("{}/{}".format(k + 1, images.shape[0]))

        img_real = images[k:k+1, :img_shape[0], :img_shape[1], :]
        img_real = dec(img_real)
        # Add a length-1 time axis after the batch axis.
        img_real = img_real.reshape((1, 1) + img_real.shape[1:])
        img_real[np.isnan(img_real)] = dec.below_val
        img_ds = downsampler(img_real)
        img_ds = dec.normalize(img_ds)
        img_ds_denorm = dec.denormalize(img_ds)
        # Same conditioning input for every ensemble member.
        img_ds = np.tile(img_ds, (n_ensemble, 1, 1, 1, 1))

        (n_init, n_update) = noise_gen()

        # Re-initialise the state at the start or after a gap in the series;
        # a missing last_t cannot be compared, so it also forces a restart.
        if (h is None) or (last_t is None) or (time[k] - last_t != 600):
            h = init_model.predict([img_ds[:, 0, ...], n_init])

        (img_gen, h) = gen.predict([img_ds, h, n_update])
        if relax_lam > 0.0:
            # nudge h towards null
            h_null = init_model.predict([
                np.zeros_like(img_ds[:, 0, ...]), n_init
            ])
            h = h_null + (1.0 - relax_lam) * (h - h_null)

        img_gen = dec.denormalize(img_gen)
        # Move the ensemble axis to the end.
        img_gen = img_gen.transpose((1, 2, 3, 4, 0))

        images_ds[k, ...] = encode(img_ds_denorm[0, ...])
        images_gen[k, ...] = encode(img_gen[0, ...])

        last_t = time[k]

    with netCDF4.Dataset(out_fn, 'w') as ds:
        ds.createDimension("dim_height", img_shape[0])
        ds.createDimension("dim_width", img_shape[1])
        ds.createDimension("dim_height_ds", ds_shape[0])
        ds.createDimension("dim_width_ds", ds_shape[1])
        ds.createDimension("dim_samples", images.shape[0])
        ds.createDimension("dim_ensemble", n_ensemble)
        ds.createDimension("dim_channels", 1)

        var_params = {"zlib": True, "complevel": 9}

        def create_var(name, dims, **params):
            dtype = params.pop("dtype", np.float32)
            var = ds.createVariable(name, dtype, dims, **params)
            return var

        var_img = create_var("images",
            ("dim_samples", "dim_height", "dim_width", "dim_channels",
                "dim_ensemble"),
            chunksizes=(1, 64, 64, 1, 1), dtype=np.uint8, **var_params)
        var_img.units = "Encoded R"
        var_img_ds = create_var("images_ds",
            ("dim_samples", "dim_height_ds", "dim_width_ds", "dim_channels"),
            dtype=np.uint8, **var_params)
        var_img_ds.units = "Encoded R"
        var_time = create_var("time", ("dim_samples",),
            chunksizes=(1,), dtype=np.float64, **var_params)
        var_time.units = "Seconds since 1970-01-01 00:00"

        var_img_ds[:] = images_ds
        var_img[:] = images_gen
        var_time[:] = time

    return (h, last_t)
def plot_sequences_horiz(gen, noise_shapes, batch_gen, samples=(0, 1, 2),
    num_instances=3, out_fn=None, plot_stride=2, random_seed=1234,
    application="mchrzc"):
    """Plot real, downsampled and generated sequences side by side in time.

    For each selected sample a panel is drawn with one row for the real
    sequence, one for the downsampled condition and ``num_instances`` rows
    of generated realizations. Time runs along the columns, subsampled by
    ``plot_stride``. A shared horizontal colorbar is added at the bottom.

    ``batch_gen``'s ``batch_size``, ``augment`` and ``zeros_frac`` are
    changed temporarily to draw exactly the requested samples and are
    restored afterwards, even on error. ``batch_gen.next_ind`` is left set
    to the requested samples.

    Parameters
    ----------
    gen : generator whose ``predict`` takes ``[cond] + noise``.
    noise_shapes : callable mapping the image shape to noise shapes.
    batch_gen : batch generator exposing ``sequences``, ``decoder`` and the
        attributes listed above.
    samples : indices of the samples to plot.
    num_instances : number of generated realizations per sample.
    out_fn : if given, the figure is saved there and closed.
    plot_stride : plot every ``plot_stride``-th frame.
    random_seed : seed of the noise generator.
    application : ``"mchrzc"`` or ``"goescod"``; selects colorbar units/ticks.

    Raises
    ------
    ValueError : for an unsupported ``application`` (raised before
        ``batch_gen`` is touched or a figure is created).
    """
    if application == 'mchrzc':
        units = "Rain rate [mm h$^{-1}$]"
        cb_tick_loc = np.array([-1, 0, 1, 2])
        cb_tick_labels = [0.1, 1, 10, 100]
    elif application == 'goescod':
        units = "Cloud optical thickness"
        cb_tick_loc = np.log([2, 10, 50, 150])
        cb_tick_labels = np.exp(cb_tick_loc).round().astype(int)
    else:
        raise ValueError("Unknown application.")

    num_samples = len(samples)
    old_batch_size = batch_gen.batch_size
    old_augment = batch_gen.augment
    old_zeros_frac = batch_gen.zeros_frac
    img_shape = batch_gen.sequences.shape[2:4]
    noise_gen = noise.NoiseGenerator(noise_shapes(img_shape),
        batch_size=num_samples, random_seed=random_seed)

    # force the batch generator to return the selected samples
    batch_gen.next_ind = np.array(samples)
    try:
        batch_gen.batch_size = num_samples
        batch_gen.augment = False
        batch_gen.zeros_frac = 0.0
        (seq_real, cond) = next(batch_gen)
        seq_gen = [gen.predict([cond] + noise_gen())
            for _ in range(num_instances)]
    finally:
        batch_gen.batch_size = old_batch_size
        batch_gen.augment = old_augment
        batch_gen.zeros_frac = old_zeros_frac

    seq_real = batch_gen.decoder.denormalize(seq_real)
    cond = batch_gen.decoder.denormalize(cond)
    seq_gen = [batch_gen.decoder.denormalize(seq) for seq in seq_gen]

    if plot_stride > 1:
        seq_real = seq_real[:, ::plot_stride, ...]
        cond = cond[:, ::plot_stride, ...]
        seq_gen = [seq[:, ::plot_stride, ...] for seq in seq_gen]

    num_frames = seq_real.shape[1]
    num_rows = num_samples
    num_cols = num_frames
    num_rows_s = 2 + num_instances  # real + downsampled + generated rows

    figsize = (num_cols * 1.5, num_rows * num_rows_s * 1.60)
    plt.figure(figsize=figsize)

    # One sub-grid per sample plus a thin final row for the colorbar.
    gs = gridspec.GridSpec(num_rows + 1, 1, hspace=0.05,
        height_ratios=[1] * num_rows + [0.035])

    value_range = batch_gen.decoder.value_range

    for s in range(num_samples):
        gs_s = gridspec.GridSpecFromSubplotSpec(num_rows_s, num_cols,
            subplot_spec=gs[s, 0], wspace=0.05, hspace=0.05)
        for t in range(num_frames):
            plt.subplot(gs_s[0, t])
            plot_img(seq_real[s, t, :, :, 0], value_range=value_range)
            if t == 0:
                plt.ylabel("Real", fontsize=16)
                # Panel letter "(a)", "(b)", ... for each sample.
                plt.text(0.01, 0.97, "({})".format(ascii_lowercase[s]),
                    horizontalalignment='left', verticalalignment='top',
                    transform=plt.gca().transAxes, fontsize=16)
                if s == 0:
                    plt.title("Time \u2192", fontsize=16)

            plt.subplot(gs_s[1, t])
            plot_img(cond[s, t, :, :, 0], value_range=value_range)
            if t == 0:
                plt.ylabel("Downs.", fontsize=16)

            for k in range(num_instances):
                j = 2 + k
                plt.subplot(gs_s[j, t])
                plot_img(seq_gen[k][s, t, :, :, 0], value_range=value_range)
                if t == 0:
                    plt.ylabel("Gen. #{}".format(k + 1), fontsize=16)

    cax = plt.subplot(gs[-1, 0]).axes
    cb = colorbar.ColorbarBase(cax, norm=colors.Normalize(*value_range),
        orientation='horizontal')
    cb.set_ticks(cb_tick_loc)
    cb.set_ticklabels(cb_tick_labels)
    cax.tick_params(labelsize=16)
    cb.set_label(units, size=16)

    if out_fn is not None:
        plt.savefig(out_fn, bbox_inches='tight')
        plt.close()