コード例 #1
0
def load(dataset, method, datafolder=None):
    """Load the convergence histories of one method on one dataset.

    ``method`` is decoded by ``get_methodName_and_L`` into a method name and
    a rank ``L``.

    - For "slang", the variants of the "slang_convergence_final" experiment
      are used when their names contain both the dataset's vilib name and
      ``"L_<L>_"``.
    - For any other method, the variants of "bbb_convergence_final" are
      used when their names contain the dataset's vilib name.

    Args:
        dataset: key into DATASET_VILIB_NAME, MINIBATCH_SIZES and Ns.
        method: method identifier understood by get_methodName_and_L.
        datafolder: accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(x, nlZs, nlls, accs)``.

        - ``x`` is the history index scaled by MINIBATCH_SIZES[dataset] /
          Ns[dataset], plus one (presumably an epoch axis — TODO confirm).
        - ``nlZs``, ``nlls`` and ``accs`` are the 'elbo_neg_ave',
          'test_pred_logloss' and 'test_pred_accuracy' histories, stacked
          with np.vstack (one row per result).
    """
    print(dataset, method)

    method, L = get_methodName_and_L(method)

    # Pick the experiment and the variant filter; the fetch below is shared.
    if method == "slang":
        experiment_base = slang.slang_base
        experiment_name = "slang_convergence_final"

        def is_relevant(exp):
            # Keep variants for this dataset *and* this rank L.
            return (DATASET_VILIB_NAME[dataset] in exp
                    and "L_" + str(L) + "_" in exp)
    else:
        experiment_base = bbb.bbb_copy_slang
        experiment_name = "bbb_convergence_final"

        def is_relevant(exp):
            return DATASET_VILIB_NAME[dataset] in exp

    candidate_exps = experiment_base.get_variant(experiment_name).variants
    relevant_exps = list(filter(is_relevant, candidate_exps))
    results = get_experiment_results(experiment_base, experiment_name,
                                     relevant_exps)

    nlZs, nlls, accs = [], [], []
    for record in results:
        history = record['metric_history']
        nlZs.append(history['elbo_neg_ave'])
        nlls.append(history['test_pred_logloss'])
        accs.append(history['test_pred_accuracy'])

    # The x-axis length comes from the first result's ELBO history.
    iterations = np.arange(len(nlZs[0]))
    x = iterations * MINIBATCH_SIZES[dataset] / Ns[dataset]

    return x + 1, np.vstack(nlZs), np.vstack(nlls), np.vstack(accs)
コード例 #2
0
def plot_slang_convergence(data_set, L, metrics, prior_prec_indices):
    """Plot SLANG convergence curves for a selection of prior precisions.

    For every selected prior precision, the matching ``slang.slang_cv``
    restarts are summarized with ``record_utils.summarize_metric_histories``.
    The ``'mean'`` history of each requested metric is then plotted, one
    curve per precision. One PDF per metric is written to
    ``plots/prior_selection/<data_set>/``.

    Args:
        data_set: 'australian_presplit', 'breastcancer_presplit' or
            'usps_3vs5'.
        L: SLANG rank. It is used to filter variant names (``'_<L>_'``) and
            appears in titles and file names. Documented valid values are
            1 and 10.
        metrics: 'all', or a collection of metric names. Only
            'test_pred_logloss', 'train_pred_logloss' and 'elbo_neg_ave' are
            plotted, and 'all' selects exactly these three.
        prior_prec_indices: list of integers in 0..8 indexing
            [0.001, 0.01, 0.1, 1, 8, 32, 64, 128, 512].

    Raises:
        ValueError: if ``data_set`` is not one of the supported datasets.

    Notes:
        - ``variants`` and ``experiment_name`` are not defined in this
          function (presumably module-level globals — confirm).
        - A precision is plotted as missing (skipped) when no variant
          matches it or when any of its restarts has no result.
    """
    import re
    import warnings

    if data_set not in ('australian_presplit', 'breastcancer_presplit',
                        'usps_3vs5'):
        raise ValueError('"' + data_set +
                         '" is not a valid value for the data_set argument.')

    if metrics == 'all':
        metrics = ['test_pred_logloss', 'train_pred_logloss', 'elbo_neg_ave']

    prior_precisions = [0.001, 0.01, 0.1, 1, 8, 32, 64, 128, 512]

    # NOTE(review): '_<L>_' may also match another field of a variant name
    # (e.g. a '_1_' inside an L=10 variant); confirm against the naming scheme.
    l_tag = '_' + str(L) + '_'
    variants_to_plot = [z for z in variants if data_set in z and l_tag in z]

    experiment = slang.slang_cv.get_variant(experiment_name)
    experiment_results = {}
    for index in prior_prec_indices:
        # Key by the selected precision itself (not by loop position) so the
        # legend labels are correct for any choice of indices.
        prec = prior_precisions[index]
        # A plain substring test would let 'prior_prec_1' also select the
        # 'prior_prec_128' variants; reject matches followed by another digit.
        pattern = re.compile(re.escape('prior_prec_' + str(prec)) + r'(?!\d)')
        restarts = [z for z in variants_to_plot if pattern.search(z)]

        if not restarts:
            warnings.warn('No variants found for prior precision ' +
                          str(prec) + '; skipping it.')
            experiment_results[prec] = None
            continue

        crashed = any(
            not experiment.get_variant(restart).get_latest_record().has_result()
            for restart in restarts)
        if crashed:
            experiment_results[prec] = None  # One of these restarts crashed.
            continue

        results = record_utils.get_experiment_results(
            slang.slang_cv, experiment_name, restarts)
        results = list(results[0].values())
        experiment_results[prec] = record_utils.summarize_metric_histories(
            results)

    out_dir = 'plots/prior_selection/' + data_set + '/'
    os.makedirs(out_dir, exist_ok=True)

    ######################
    #### Plot Results ####
    ######################

    plt.ioff()

    title = "L = " + str(L) + ", Dataset = " + data_set
    file_prefix = out_dir + "slang_L_" + str(L) + "_Dataset_" + data_set
    plot_specs = [
        # (metric key, y-axis label, file-name suffix, plotting function)
        ('test_pred_logloss', "Test logloss", "_test_logloss.pdf", plt.plot),
        ('train_pred_logloss', "Train logloss", "_train_logloss.pdf",
         plt.plot),
        ('elbo_neg_ave', "Neg. Average ELBO", "_elbo.pdf", plt.loglog),
    ]
    for metric, ylabel, suffix, plot_fn in plot_specs:
        if metric not in metrics:
            continue
        plt.figure()
        labels = []
        for key, item in experiment_results.items():
            if item is None:
                continue
            mean = item[metric]['mean']
            # 1-based epoch axis sized from the data rather than hard-coded.
            plot_fn(1 + np.arange(len(mean)), mean)
            labels.append(str(key))

        plt.legend(labels)
        plt.grid()
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(file_prefix + suffix)
        plt.close()
コード例 #3
0
def plot_bbb_convergence(data_set, metrics, lr_indices):
    """Plot BBB convergence curves for a selection of learning rates.

    For every selected learning rate, the matching ``bbb.bbb_cv`` restarts
    are summarized with ``record_utils.summarize_metric_histories``. The
    ``'mean'`` history of each requested metric is then plotted, one curve
    per learning rate. One PDF per metric is written to
    ``plots/lr_selection/<data_set>/``.

    Args:
        data_set: 'australian_presplit', 'breastcancer_presplit' or
            'usps_3vs5'.
        metrics: 'all', or a collection of metric names. Plots are produced
            for 'test_pred_logloss', 'train_pred_logloss' and 'elbo_neg_ave'.
            NOTE(review): 'accuracy' is accepted, and 'all' includes it, but
            no accuracy plot is produced.
        lr_indices: list of integers in 0..9 indexing
            ``np.logspace(-4, -1, 10)``.

    Raises:
        ValueError: if ``data_set`` is not one of the supported datasets.

    Notes:
        - ``variants`` and ``experiment_name`` are not defined in this
          function (presumably module-level globals — confirm).
        - A learning rate is skipped when no variant matches it or when any
          of its restarts has no result.
        - The output file names keep their existing 'bbb__Dataset_' form.
    """
    import warnings

    if data_set not in ('australian_presplit', 'breastcancer_presplit',
                        'usps_3vs5'):
        raise ValueError('"' + data_set +
                         '" is not a valid value for the data_set argument.')

    if metrics == 'all':
        metrics = [
            'test_pred_logloss', 'train_pred_logloss', 'elbo_neg_ave',
            'accuracy'
        ]

    variants_to_plot = [z for z in variants if data_set in z]
    lrs = np.logspace(-4, -1, 10)[lr_indices]

    experiment = bbb.bbb_cv.get_variant(experiment_name)
    experiment_results = {}
    for lr in lrs:
        lr_string = 'lr_' + str(lr)
        restarts = [z for z in variants_to_plot if lr_string in z]

        if not restarts:
            warnings.warn('No variants found for learning rate ' + str(lr) +
                          '; skipping it.')
            experiment_results[lr] = None
            continue

        crashed = any(
            not experiment.get_variant(restart).get_latest_record().has_result()
            for restart in restarts)
        if crashed:
            experiment_results[lr] = None  # One of these restarts crashed.
            continue

        results = record_utils.get_experiment_results(
            bbb.bbb_cv, experiment_name, restarts)
        results = list(results[0].values())
        experiment_results[lr] = record_utils.summarize_metric_histories(
            results)

    out_dir = 'plots/lr_selection/' + data_set + '/'
    os.makedirs(out_dir, exist_ok=True)

    ######################
    #### Plot Results ####
    ######################

    plt.ioff()

    title = "Dataset = " + data_set
    file_prefix = out_dir + "bbb_" + "_Dataset_" + data_set
    plot_specs = [
        # (metric key, y-axis label, file-name suffix, plotting function)
        ('test_pred_logloss', "Test logloss", "_test_logloss.pdf", plt.plot),
        ('train_pred_logloss', "Train logloss", "_train_logloss.pdf",
         plt.plot),
        ('elbo_neg_ave', "Neg. Average ELBO", "_elbo.pdf", plt.loglog),
    ]
    for metric, ylabel, suffix, plot_fn in plot_specs:
        if metric not in metrics:
            continue
        plt.figure()
        labels = []
        for key, item in experiment_results.items():
            if item is None:
                continue
            mean = item[metric]['mean']
            # 1-based epoch axis sized from the data rather than hard-coded.
            plot_fn(1 + np.arange(len(mean)), mean)
            labels.append(str(key))

        plt.legend(labels)
        plt.grid()
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(file_prefix + suffix)
        plt.close()
コード例 #4
0
def plot_convergence(data_set, metrics):
    """Plot final SLANG (L = 1, 8, 32, 64) and BBB convergence curves.

    Restarts of ``slang_final`` (one group per L) and of ``bbb_final`` that
    belong to ``data_set`` are summarized with
    ``record_utils.summarize_metric_histories``. The ``'mean'`` history of
    each requested metric is then plotted, one curve per group. One PDF per
    metric is written to ``plots/final_plots/<data_set>/``.

    The x-axis is ``1 + index``, sized from the data. This matches the other
    convergence plots and keeps the log-log ELBO plot away from x = 0.

    Args:
        data_set: 'australian_presplit', 'breastcancer_presplit' or
            'usps_3vs5'.
        metrics: 'all', or a collection containing any of
            'test_pred_logloss', 'train_pred_logloss', 'elbo_neg_ave' and
            'accuracy'. 'accuracy' plots the 'test_pred_accuracy' history.

    Raises:
        ValueError: if ``data_set`` is not one of the supported datasets.

    Notes:
        - ``slang_final`` and ``bbb_final`` are not defined in this function
          (presumably module-level globals — confirm).
        - A group is skipped when it has no matching variants or when any of
          its restarts has no result.
    """
    import re
    import warnings

    if data_set not in ('australian_presplit', 'breastcancer_presplit',
                        'usps_3vs5'):
        raise ValueError('"' + data_set +
                         '" is not a valid value for the data_set argument.')

    if metrics == 'all':
        metrics = ['test_pred_logloss', 'train_pred_logloss', 'elbo_neg_ave',
                   'accuracy']

    def summarize(experiment_base, experiment_name, restarts, label):
        # Summarize `restarts`; return None if there are none or any crashed.
        if not restarts:
            warnings.warn('No variants found for ' + label + '; skipping it.')
            return None
        experiment = experiment_base.get_variant(experiment_name)
        if any(not experiment.get_variant(r).get_latest_record().has_result()
               for r in restarts):
            return None  # One of these restarts crashed.
        results = record_utils.get_experiment_results(
            experiment_base, experiment_name, restarts)
        return record_utils.summarize_metric_histories(results)

    slang_variants_to_plot = [z for z in slang_final.variants if data_set in z]
    bbb_variants_to_plot = [z for z in bbb_final.variants if data_set in z]

    experiment_results = {}
    for rank in (1, 8, 32, 64):
        # Reject matches followed by another digit so 'L_1' cannot pick up
        # e.g. 'L_16' variants.
        pattern = re.compile(re.escape('L_' + str(rank)) + r'(?!\d)')
        restarts = [z for z in slang_variants_to_plot if pattern.search(z)]
        experiment_results[rank] = summarize(
            slang.slang_base, slang_final.experiment_name, restarts,
            'SLANG L=' + str(rank))

    experiment_results["BBB"] = summarize(
        bbb.bbb_copy_slang, bbb_final.experiment_name, bbb_variants_to_plot,
        'BBB')

    # Create the output directory (and its parents) if it does not exist.
    out_dir = 'plots/final_plots/' + data_set + '/'
    os.makedirs(out_dir, exist_ok=True)

    ######################
    #### Plot Results ####
    ######################

    plt.ioff()

    title = "Dataset = " + data_set
    plot_specs = [
        # (name in `metrics`, history key, y-axis label, file name, plot fn)
        ('test_pred_logloss', 'test_pred_logloss', "Test logloss",
         "test_logloss.pdf", plt.plot),
        ('accuracy', 'test_pred_accuracy', "Test Accuracy",
         "test_accuracy.pdf", plt.plot),
        ('train_pred_logloss', 'train_pred_logloss', "Train logloss",
         "train_logloss.pdf", plt.plot),
        ('elbo_neg_ave', 'elbo_neg_ave', "Neg. Average ELBO", "elbo.pdf",
         plt.loglog),
    ]
    for metric, history_key, ylabel, file_name, plot_fn in plot_specs:
        if metric not in metrics:
            continue
        plt.figure()
        labels = []
        for key, item in experiment_results.items():
            if item is None:
                continue
            mean = item[history_key]['mean']
            plot_fn(1 + np.arange(len(mean)), mean)
            labels.append(str(key))

        plt.legend(labels)
        plt.grid()
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(out_dir + file_name)
        plt.close()