# Example 1
def quality_metrics_table(application, data_fn, weights_fn, method="gan"):
    """Print image quality metrics (RMSE, MS-SSIM, LSD) for one model.

    Parameters:
        application: application identifier passed to the train setup helpers.
        data_fn: path to the data file (also used as the test data file).
        weights_fn: path to the trained weights to load (GAN / deterministic).
        method: one of "gan", "gen_det", "lanczos" or "rainfarm".

    Raises:
        ValueError: if `method` is not one of the supported options.
    """
    if method == "gan":
        (wgan, batch_gen_train, batch_gen_valid, batch_gen_test,
            noise_shapes, steps_per_epoch) = train.setup_gan(data_fn,
                test_data_file=data_fn, application=application, batch_size=32)
        gen = wgan.gen
        gen.load_weights(weights_fn)
    elif method == "gen_det":
        (gen_det, batch_gen_train, batch_gen_valid, 
            batch_gen_test, steps_per_epoch) = train.setup_deterministic(data_fn,
            test_data_file=data_fn, sample_random=True, n_samples=1, batch_size=32,
            application=application, loss='mse')
        gen_det.load_weights(weights_fn)
        gen = GeneratorDeterministicPlaceholder(gen_det)
        # Deterministic generators take no noise inputs.
        noise_shapes = lambda s: []
    elif method == "lanczos":
        (gen_det, batch_gen_train, batch_gen_valid, 
            batch_gen_test, steps_per_epoch) = train.setup_deterministic(data_fn,
            test_data_file=data_fn, sample_random=True, n_samples=1, batch_size=32,
            application=application, loss='mse')
        gen = GeneratorLanczos((128,128))
        noise_shapes = lambda s: []
    elif method == "rainfarm":
        (gen_det, batch_gen_train, batch_gen_valid, 
            batch_gen_test, steps_per_epoch) = train.setup_deterministic(data_fn,
            test_data_file=data_fn, sample_random=True, n_samples=1, batch_size=32,
            application=application, loss='mse')
        gen = GeneratorRainFARM(16, batch_gen_test.decoder)
        noise_shapes = lambda s: []
    else:
        # Fail fast: previously an unknown method fell through and raised
        # UnboundLocalError on `gen` below.
        raise ValueError("Unknown method: {}".format(method))

    (rmse, ssim, lsd) = image_quality(gen, batch_gen_test, noise_shapes)

    print("RMSE: {:.3f}".format(rmse.mean()))
    print("MSSSIM: {:.3f}".format(ssim.mean()))
    print("LSD: {:.3f}".format(np.nanmean(lsd)))
# Example 2
def quality_metrics_by_time(application, data_fn, out_fn,
    weights_dir, check_every=1):
    """Evaluate image quality metrics for each saved checkpoint.

    Scans `weights_dir` for checkpoint files matching `application`,
    evaluates RMSE / MS-SSIM / LSD on the validation generator for every
    `check_every`-th checkpoint, and appends the results to `out_fn`.
    """
    (wgan, batch_gen_train, batch_gen_valid, _,
        noise_shapes, steps_per_epoch) = train.setup_gan(data_fn,
            application=application, batch_size=32)
    gen = wgan.gen

    def matches_application(name):
        # Checkpoint files look like "<prefix>-<application>-<N>.<ext>".
        return name.split("-")[1] == application

    weight_files = sorted(
        name for name in os.listdir(weights_dir) if matches_application(name)
    )

    def log_line(line):
        with open(out_fn, 'a') as f:
            print(line, file=f)
    log_line("N RMSE MSSSIM LSD")

    for name in weight_files[::check_every]:
        num_samples = int(name.split("-")[-1].split(".")[0])
        print(num_samples)
        gen.load_weights(weights_dir+"/"+name)

        (rmse, ssim, lsd) = image_quality(gen, batch_gen_valid, noise_shapes)
        log_line("{} {:.6f} {:.6f} {:.6f}".format(
            num_samples, rmse.mean(), ssim.mean(), np.nanmean(lsd)))
def latents(model_name="../models/masc_infogan_combined",
            latents_fn="../data/masc_latents.nc",
            latent_dist_fn="../data/masc_latent_dist.nc"):
    """Compute latent codes for the entire dataset with a trained GAN.

    Evaluating the latent-code distributions takes quite a while, so this
    should probably be run on a GPU.
    """
    data_files = ["../data/masc_davos.nc", "../data/masc_apres3.nc"]
    (infogan, batch_gen, latent_gen, noise_gen, steps_per_epoch) = \
        train.setup_gan(data_files, batch_size=64)
    infogan.load(infogan.filenames_from_root(model_name))

    # Write latent codes and latent-code distributions to NetCDF files.
    eval.latents_all(infogan.disc, batch_gen, out_file=latents_fn)
    eval.latent_dist_all(infogan.disc, batch_gen, out_file=latent_dist_fn)
# Example 4
def plot_examples_goescod(data_fn, weights_fn, plot_fn):
    """Plot example GOES-COD sequences from a trained GAN.

    Loads generator weights from `weights_fn` and saves the resulting
    figure to `plot_fn`.
    """
    setup = train.setup_gan(data_fn,
                            test_data_file=data_fn,
                            sample_random=True,
                            n_samples=1,
                            application='goescod',
                            random_seed=1234)
    (wgan, batch_gen_train, batch_gen_valid,
     batch_gen_test, noise_shapes, steps_per_epoch) = setup
    gen = wgan.gen
    gen.load_weights(weights_fn)
    plot_sequences_horiz(gen, noise_shapes, batch_gen_test,
                         samples=[0, 1, 2],
                         application='goescod',
                         plot_stride=1)
    plt.savefig(plot_fn, bbox_inches='tight')
    plt.close()
# Example 5
def plot_all(data_fn, gen_weights_fn, application="mch"):
    """Create the samples, styles and transition figures for one application.

    Figures are written to ../figures/<application>_{samples,styles,transition}.pdf.
    """
    num_channels = {"mch": 1, "goes": 3}[application]
    (gen_styled, gen, styling,
     noise_shapes) = models.generator_styled(num_channels=num_channels)
    gen_styled.load_weights(gen_weights_fn)

    # NOTE(review): this unpacking rebinds `noise_shapes`, shadowing the one
    # returned by models.generator_styled above — confirm this is intended.
    (wgan, batch_gen, noise_shapes,
     steps_per_epoch) = train.setup_gan(data_fn,
                                        n_samples=128,
                                        sample_random=True,
                                        application=application,
                                        random_seed=321459)

    def make_noise_gen(seed):
        # One fresh generator per figure so the seeds stay independent.
        return noise.NoiseGenerator(noise_shapes(),
                                    batch_size=batch_gen.batch_size,
                                    random_seed=seed)

    plot_samples(gen_styled,
                 batch_gen,
                 make_noise_gen(34),
                 out_fn="../figures/{}_samples.pdf".format(application))

    plot_styles(gen_styled,
                batch_gen,
                [make_noise_gen(seed) for seed in (221, 70, 39)],
                out_fn="../figures/{}_styles.pdf".format(application))

    plot_transition(gen,
                    styling,
                    batch_gen,
                    make_noise_gen(241),
                    out_fn="../figures/{}_transition.pdf".format(application))

    gc.collect()
def training(model_load_name="",
             model_save_name="../models/masc_infogan_combined",
             batch_size=64,
             batches_per_epoch=500):
    """Run (or resume) the GAN training loop indefinitely.

    Given enough time, this should replicate the GAN training.  The
    train_gan function creates a progress.pdf figure in the figures
    directory that can be used to monitor the training.  The default
    model_save_name is deliberately different from the pre-trained model
    names so you don't accidentally overwrite them.
    """
    data_files = ["../data/masc_davos.nc","../data/masc_apres3.nc"]
    (infogan, batch_gen, latent_gen, noise_gen, steps_per_epoch) = \
        train.setup_gan(data_files, batch_size=batch_size)
    if model_load_name:
        # Resume from a previously saved model.
        infogan.load(infogan.filenames_from_root(model_load_name))
    # Train one "epoch" at a time, checkpointing after each.
    while True:
        train.train_gan(infogan, batch_gen, latent_gen, noise_gen,
                        batches_per_epoch, 1)
        infogan.save(model_save_name)
# Example 7
def rank_metrics_by_noise(application, run_id, data_file,
    weights_fn):
    """Print rank-based metrics as a function of the noise multiplier.

    Loads the generator weights once from `weights_fn`, then evaluates
    ensemble rank statistics and CRPS for a range of noise multipliers.

    Bug fixes versus the original:
    - `range(0.5, 2.51, 0.1)` raised TypeError (range rejects floats);
      the multipliers are now built with np.arange.
    - The loop body referenced undefined `fn`/`weights_dir` (copied from
      rank_metrics_by_time) and never used `weights_fn`; the weights are
      now loaded once before the loop and the multiplier is printed.
    """
    (wgan, batch_gen_train, batch_gen_valid, _,
        noise_shapes, steps_per_epoch) = train.setup_gan(data_file,
        application=application)
    gen = wgan.gen
    gen.load_weights(weights_fn)
    noise_gen = noise.NoiseGenerator(noise_shapes(),
        batch_size=batch_gen_valid.batch_size)

    noise_multipliers = list(np.arange(0.5, 2.51, 0.1)) + [3.0, 3.5]
    for m in noise_multipliers:
        (ranks, crps_scores) = ensemble_ranks(gen, batch_gen_valid,
            noise_gen, num_batches=32, noise_mul=m)

        KS = rank_KS(ranks)
        CvM = rank_CvM(ranks)
        DKL = rank_DKL(ranks)
        CRPS = crps_scores.mean()
        mean = ranks.mean()
        std = ranks.std()

        print(m, KS, CvM, DKL, CRPS, mean, std)
# Example 8
def rank_metrics_table(application, data_file, weights_fn, method="gan"):
    """Print a table of rank-based metrics for one trained model.

    Parameters:
        application: application identifier passed to the train setup helpers.
        data_file: path to the data file (also used as the test data file).
        weights_fn: path to the trained GAN weights (used for method="gan").
        method: "gan" or "rainfarm".

    Raises:
        ValueError: if `method` is not one of the supported options.
    """
    if method=="gan":
        (wgan, batch_gen_train, batch_gen_valid, batch_gen_test,
            noise_shapes, steps_per_epoch) = train.setup_gan(data_file,
            test_data_file=data_file, application=application, batch_size=64)
        gen = wgan.gen
        gen.load_weights(weights_fn)
    elif method=="rainfarm":
        (gen_det, batch_gen_train, batch_gen_valid, 
            batch_gen_test, steps_per_epoch) = train.setup_deterministic(data_file,
            test_data_file=data_file, sample_random=True, n_samples=1, batch_size=64,
            application=application, loss='mse')
        gen = GeneratorRainFARM(16, batch_gen_test.decoder)
        # RainFARM needs no noise inputs.
        noise_shapes = lambda: []
    else:
        # Fail fast: previously an unknown method fell through and raised
        # UnboundLocalError on `noise_shapes` below.
        raise ValueError("Unknown method: {}".format(method))

    noise_gen = noise.NoiseGenerator(noise_shapes(),
        batch_size=batch_gen_valid.batch_size)

    (ranks, crps_scores) = ensemble_ranks(gen, batch_gen_test,
        noise_gen, num_batches=16)

    KS = rank_KS(ranks)
    CvM = rank_CvM(ranks)
    DKL = rank_DKL(ranks)
    OP = rank_OP(ranks)
    CRPS = crps_scores.mean()
    mean = ranks.mean()
    std = ranks.std()

    print("KS: {:.3f}".format(KS))
    print("CvM: {:.3f}".format(CvM))
    print("DKL: {:.3f}".format(DKL))
    print("OP: {:.3f}".format(OP))
    print("CRPS: {:.3f}".format(CRPS))
    print("mean: {:.3f}".format(mean))
    print("std: {:.3f}".format(std))
# Example 9
def rank_metrics_by_time(application, data_file, out_fn,
    weights_dir, check_every=1, N_range=None):
    """Log rank-based metrics for each checkpoint in `weights_dir`.

    Appends one line per evaluated checkpoint to `out_fn`.  If `N_range`
    is given, only checkpoints whose sample count N satisfies
    N_range[0] <= N < N_range[1] are evaluated.
    """
    (wgan, batch_gen_train, batch_gen_valid, batch_gen_test,
        noise_shapes, steps_per_epoch) = train.setup_gan(data_file,
        application=application, batch_size=64)
    gen = wgan.gen
    noise_gen = noise.NoiseGenerator(noise_shapes(),
        batch_size=batch_gen_valid.batch_size)

    def belongs_to_application(name):
        # Checkpoint files look like "<prefix>-<application>-<N>.<ext>".
        return name.split("-")[1] == application

    checkpoints = sorted(
        name for name in os.listdir(weights_dir)
        if belongs_to_application(name)
    )

    def log_line(line):
        with open(out_fn, 'a') as f:
            print(line, file=f)
    log_line("N KS CvM DKL OP CRPS mean std")

    for name in checkpoints[::check_every]:
        N_samples = int(name.split("-")[-1].split(".")[0])
        out_of_range = (N_range is not None) and \
            not (N_range[0] <= N_samples < N_range[1])
        if out_of_range:
            continue
        gen.load_weights(weights_dir+"/"+name)
        (ranks, crps_scores) = ensemble_ranks(gen, batch_gen_valid,
            noise_gen, num_batches=8)

        log_line("{} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f}".format(
            N_samples, rank_KS(ranks), rank_CvM(ranks), rank_DKL(ranks),
            rank_OP(ranks), crps_scores.mean(), ranks.mean(), ranks.std()))
# Example 10
def plot_examples_goescod_random(data_fn,
                                 weights_fn,
                                 plot_dir,
                                 num_examples=16):
    """Save `num_examples` randomly sampled GOES-COD example plots.

    One PDF per example is written into `plot_dir`.
    """
    (wgan, batch_gen_train, batch_gen_valid, batch_gen_test, noise_shapes,
     steps_per_epoch) = train.setup_gan(data_fn,
                                        test_data_file=data_fn,
                                        sample_random=True,
                                        n_samples=1,
                                        application='goescod',
                                        random_seed=2345)
    gen = wgan.gen
    gen.load_weights(weights_fn)
    for sample_index in range(num_examples):
        out_fn = plot_dir + "/examples-goescod-random-{:02d}.pdf".format(sample_index)
        plot_sequences_horiz(gen,
                             noise_shapes,
                             batch_gen_test,
                             samples=[sample_index],
                             application='goescod',
                             plot_stride=1,
                             num_instances=12)
        plt.savefig(out_fn, bbox_inches='tight')
        plt.close()
# Example 11
        # NOTE(review): fragment of a larger CLI handler; the enclosing
        # function/branch (and `args`, `path`) are defined above this excerpt.
        data_fn = args.data_file
        application = args.application
        load_weights_root = args.load_weights_root
        save_weights_root = args.save_weights_root
        log_path = args.log_path
        steps_per_epoch = args.steps_per_epoch
        batch_size = args.batch_size
        num_samples = args.num_samples
        opt_switch_point = args.opt_switch_point

        # Default save location when none was given on the command line.
        if not save_weights_root:
            save_weights_root = path + "../models/downscaleseqgan"

        # initialize GAN
        (wgan, batch_gen_train, batch_gen_valid, _, noise_shapes, _) = \
            train.setup_gan(data_fn,
                batch_size=batch_size, application=application)

        if load_weights_root:  # load weights and run status
            wgan.load(wgan.filenames_from_root(load_weights_root))
            # Restore how many training samples have been processed so far.
            with open(load_weights_root + "-run_status.json", 'r') as f:
                run_status = json.load(f)
            training_samples = run_status["training_samples"]

            if log_path:
                # Continue appending to the existing per-application log.
                log_file = "{}/log-{}.txt".format(log_path, application)
                log = pd.read_csv(log_file)

        else:  # initialize run status
            # `chars` is presumably material for generating a random run id;
            # its actual use lies outside this excerpt — confirm.
            chars = string.ascii_lowercase + string.digits
            training_samples = 0
def experiments(model_name="../models/masc_infogan_combined",
                latents_fn="../data/masc_latent_dist_combined.nc"):
    '''
    Creates the plots and computes the key evaluation numbers reported in
    the paper. This should be runnable on a CPU given a pre-trained model,
    although it will take some time.
    '''
    # Fix: in the original, this docstring was placed AFTER the first
    # statement, making it a no-op string expression (__doc__ was None);
    # it is now a real docstring.
    (latents, ind) = cluster.sample_latents(latents_fn, random_seed=1000)

    # number of clusters
    (K, cost, min_medoids) = cluster.cluster_cost(latents, ind, K_max=20)
    np.save("../data/cluster_losses_combined.npy", (K, cost, min_medoids))

    kmed_16 = cluster.KMedoids(latents,
                               metric=cluster.distribution_distance,
                               num_medoids=16)
    kmed_16.medoid_ind = min_medoids[15]
    kmed_16.rearrange_medoids()
    kmed_16.medoid_ind = kmed_16.medoid_ind[[
        4,
        0,
        1,
        3,
        2,
        15,
        13,
        14,
        11,
        12,
        9,
        10,
        5,
        7,
        6,
        8,
    ]]  # cosmetic rearrangement
    (K_16, costs_16) = cluster.hierarchy_cost(kmed_16)
    (branches_16, joins_16) = cluster.cluster_hierarchy(kmed_16)

    kmed_6 = cluster.KMedoids(latents,
                              metric=cluster.distribution_distance,
                              num_medoids=6)
    kmed_6.medoid_ind = min_medoids[5]
    kmed_6.rearrange_medoids()
    (K_6, costs_6) = cluster.hierarchy_cost(kmed_6)
    (branches_6, joins_6) = cluster.cluster_hierarchy(kmed_6)

    plots.cluster_number(K, cost, costs_16, costs_6)
    plt.savefig("../figures/Kmedoids_loss.pdf", bbox_inches='tight')
    plt.close()

    # cluster_samples
    plots.cluster_samples(kmed_16, ind, random_seed=1001)
    plt.savefig("../figures/class_samples_16.pdf", bbox_inches='tight')
    plt.close()

    plots.cluster_samples(kmed_6, ind, random_seed=1002)
    plt.savefig("../figures/class_samples_6.pdf", bbox_inches='tight')
    plt.close()

    # cluster distance matrix
    plots.cluster_distance_matrix(kmed_16)
    plt.savefig("../figures/class_distance_matrix.pdf", bbox_inches='tight')
    plt.close()

    # class membership matrix
    plots.class_membership_matrix(kmed_16)
    plt.savefig("../figures/membership_matrix.pdf", bbox_inches='tight')
    plt.close()

    # class statistics matrix
    plots.class_statistics_matrix(kmed_16)
    plt.savefig("../figures/statistics_matrix.pdf", bbox_inches='tight')
    plt.close()

    # GAN samples
    (infogan, batch_gen, latent_gen, noise_gen, steps_per_epoch) = \
        train.setup_gan(["../data/masc_davos.nc","../data/masc_apres3.nc"],
            batch_size=8, latent_seed=1003, noise_seed=1004, batch_seed=1005)
    infogan.load(infogan.filenames_from_root(model_name))
    plots.sample_images(infogan.disc,
                        infogan.gen,
                        batch_gen,
                        latent_gen,
                        noise_gen,
                        num_samples=8)
    plt.savefig("../figures/sample_images.pdf", bbox_inches='tight')
    plt.close()

    latent_gen = data.NoiseGenerator([(8, )],
                                     batch_size=batch_gen.batch_size,
                                     random_seed=1040)
    noise_gen = data.NoiseGenerator(noise_gen.noise_shapes,
                                    batch_size=batch_gen.batch_size,
                                    random_seed=1007)
    plots.latent_variation(infogan.gen, noise_gen, latent_gen)
    plt.savefig("../figures/latent_variation.pdf", bbox_inches='tight')
    plt.close()

    gc.collect()

    # Distribution distance / SED comparison
    cam_index = eval.group_images_by_cam(
        ["../data/masc_davos.nc", "../data/masc_apres3.nc"])
    (d, dm, dr, d_dist, dm_dist,
     dr_dist) = eval.evaluate_distance("../data/masc_latent_combined.nc",
                                       "../data/masc_latent_dist_combined.nc",
                                       cam_index,
                                       random_seed=1008)
    print("Median SED (pairs) = {:.3f}".format(np.median(d)))
    print("Median SED (all) = {:.3f}".format(np.median(dm)))
    print("Median distance rank for SED = {:.3f}%".format(np.median(dr) * 100))
    print("Median Bhattacharyya distance (pairs) = {:.3f}".format(
        np.median(d_dist)))
    print("Median Bhattacharyya distance (all) = {:.3f}".format(
        np.median(dm_dist)))
    print("Median distance rank for Bhattacharyya distance = {:.3f}%".format(
        np.median(dr_dist) * 100))

    gc.collect()

    # SSIM
    (median_ssim, mean_ssim, std_ssim) = eval.average_ssim(infogan.gen,
                                                           infogan.disc,
                                                           batch_gen,
                                                           noise_gen,
                                                           num_batches=128,
                                                           verbose=True)
    print("Median SSIM = {:.3f}".format(median_ssim))
    print("Mean SSIM = {:.3f}".format(mean_ssim))
    print("SSIM st. dev. = {:.3f}".format(std_ssim))

    gc.collect()
# Example 13
                        help="Network weights file root")
    # NOTE(review): this excerpt starts partway through a
    # parser.add_argument(...) call; its opening line is above this view.
    parser.add_argument('--application',
                        type=str,
                        default="mch",
                        help="Application (mch or goes)")

    args = parser.parse_args()
    mode = args.mode
    data_fn = args.data_file
    weights_root = args.weights_root
    application = args.application

    if mode == "train":
        # Build the GAN and data pipeline for the chosen application.
        (wgan, batch_gen, noise_shapes,
         steps_per_epoch) = train.setup_gan(data_fn,
                                            batch_size=64,
                                            application=application)
        if weights_root:
            # Resume from previously saved weights.
            wgan.load(wgan.filenames_from_root(weights_root))
        # Train indefinitely, checkpointing after every 100-step epoch.
        while True:
            train.train_gan(wgan,
                            batch_gen,
                            noise_shapes,
                            100,
                            1,
                            application=application)
            wgan.save(weights_root)
    elif mode == "plot":
        # Generate all figures from the saved generator weights.
        gen_weights_file = weights_root + "-gen_weights.h5"
        plots.plot_all(data_fn, gen_weights_file, application=application)