# Exemplo n.º 1 (example marker from scrape source)
# 0
def print_sort_model():
    """Print every run's metrics, ranked by average cross-run metric.

    Each run is scored by the mean of its symmetrized cross-evaluation
    metric against all *other* runs (the diagonal/self entry is
    excluded), and runs are printed from highest to lowest score.
    """
    config_manager = ConfigManager(config)
    num_runs = config_manager.get_num_left_config()

    # Get original per-run metrics.
    # BUG FIX: get_metrics() returns a 4-tuple
    # (ori_metrics, ori_cr, ori_info, ori_other_metrics); the original
    # 3-name unpacking raised ValueError before anything was printed.
    ori_metrics, ori_cr, ori_info, ori_other_metrics = get_metrics()

    # Get cross results (metric_values / metric_names are unused
    # placeholders returned by get_cross_metrics()).
    metric_values, metric_names, cross_metrics, cross_metrics_average = \
        get_cross_metrics()

    # Mean cross-metric per run, excluding the diagonal (self) entry.
    sort_list = []
    for i in range(num_runs):
        off_diagonal = (list(cross_metrics_average[i, :i]) +
                        list(cross_metrics_average[i, i + 1:]))
        sort_list.append(np.mean(off_diagonal))

    # Run indices ordered from best to worst average cross-metric.
    sort_order = np.argsort(sort_list)[::-1]

    for i in range(num_runs):
        run_id = sort_order[i]
        other_metric_str = ""
        # METRICS is assumed to be a list of (group_key, [sub_metric, ...])
        # pairs indexing into each run's final-metrics dict — TODO confirm.
        for metric in METRICS:
            for sub_metric in metric[1]:
                other_metric_str += "{}={}, ".format(
                    sub_metric,
                    ori_other_metrics[run_id][metric[0]][sub_metric])
        print("{:0.3f} ({}, cr={}, info={}, {})".format(
            sort_list[run_id],
            ori_metrics[run_id],
            ori_cr[run_id],
            ori_info[run_id],
            other_metric_str))
# Exemplo n.º 2 (example marker from scrape source)
# 0
def get_cross_metrics():
    """Collect the pairwise cross-evaluation matrix over all runs.

    Reads ``cross_evaluation.pkl`` from each run's work directory and
    assembles the factorVAE metric of model i evaluated on run j's
    metric data into a square matrix.

    Returns:
        metric_values: placeholder, always an empty list.
        metric_names: placeholder, always an empty list.
        cross_metrics: (num_runs, num_runs) ndarray; entry [i, j] is
            model i's factorVAE metric on run j's data.
        cross_metrics_average: symmetrized matrix,
            (cross_metrics + cross_metrics.T) / 2.
    """
    metric_values, metric_names = [], []

    manager = ConfigManager(config)
    num_runs = manager.get_num_left_config()
    cross_metrics = np.zeros((num_runs, num_runs))
    all_configs = []

    row = 0
    while manager.get_num_left_config() > 0:
        run_config = manager.get_next_config()
        result_path = os.path.join(run_config["work_dir"],
                                   "cross_evaluation.pkl")
        with open(result_path, "rb") as f:
            payload = pickle.load(f)
        all_configs.append(payload["configs"])

        for col in range(num_runs):
            cross_metrics[row, col] = \
                payload["results"][col]["factorVAE_metric"]

        row += 1

    # Sanity check: every run recorded the identical list of configs.
    reference = all_configs[0]
    for saved_configs in all_configs:
        for j in range(len(reference)):
            for key in reference[0]:
                assert saved_configs[j][key] == reference[j][key]

    cross_metrics_average = (cross_metrics + cross_metrics.T) / 2

    return metric_values, metric_names, cross_metrics, cross_metrics_average
# Exemplo n.º 3 (example marker from scrape source)
# 0
def get_metrics():
    """Load the final metrics recorded by each configured run.

    For every run, reads the final factorVAE metric from ``metric.csv``
    (the row with epoch_id == "27" and batch_id == "-1"), the run's
    ``cr_coe_increase`` / ``info_coe_de`` hyper-parameters, and the
    unpickled contents of ``final_metrics.pkl``.

    Returns:
        ori_metrics: final factorVAE metric per run (float).
        ori_cr: "cr_coe_increase" config value per run.
        ori_info: "info_coe_de" config value per run.
        ori_other_metrics: final_metrics.pkl contents per run.
    """
    ori_metrics, ori_cr, ori_info, ori_other_metrics = [], [], [], []

    manager = ConfigManager(config)
    num_runs = manager.get_num_left_config()
    while manager.get_num_left_config() > 0:
        run = manager.get_next_config()

        run_params = run["config"]
        ori_cr.append(run_params["cr_coe_increase"])
        ori_info.append(run_params["info_coe_de"])

        work_dir = run["work_dir"]
        with open(os.path.join(work_dir, "metric.csv"), "r") as f:
            for row in csv.DictReader(f):
                # Only the final evaluation row (epoch 27, aggregate
                # batch id -1) contributes.
                if row["epoch_id"] == "27" and row["batch_id"] == "-1":
                    ori_metrics.append(float(row["factorVAE_metric"]))

        with open(os.path.join(work_dir, "final_metrics.pkl"), "rb") as f:
            ori_other_metrics.append(pickle.load(f))

    # Every run must have produced exactly one final-metric row.
    assert len(ori_metrics) == num_runs

    return ori_metrics, ori_cr, ori_info, ori_other_metrics
    def main(self):
        """Cross-evaluate a trained FactorVAE checkpoint on every run's
        metric data and dump the results to ``cross_evaluation.pkl``.

        Rebuilds the FactorVAE model exactly as during training, loads
        its checkpoint from this task's work dir, then for each config
        in the sweep loads that run's ``metric_data.pkl`` and computes
        the factorVAE metric of *this* model on *that* run's data.
        """
        # Imports are local — presumably this method runs inside a
        # gpu_task_scheduler worker process; verify against the scheduler.
        import os
        import tensorflow as tf
        import pickle
        from gan.load_data import load_dSprites
        from gan.latent import GaussianLatent, JointLatent
        from gan.network import VAEDecoder, VAEEncoder, TCDiscriminator, \
            MetricRegresser
        from gan.factorVAE import FactorVAE
        from gan.metric import FactorVAEMetric, DSpritesInceptionScore, \
            DHSICMetric

        from gpu_task_scheduler.config_manager import ConfigManager
        from config_mc import config

        data, metric_data, latent_values, metadata = \
            load_dSprites("data/dSprites")
        _, height, width, depth = data.shape

        latent_list = []

        # One regularized 1-D Gaussian latent per configured dimension.
        for i in range(self._config["gaussian_dim"]):
            latent_list.append(
                GaussianLatent(in_dim=1,
                               out_dim=1,
                               loc=0.0,
                               scale=1.0,
                               q_std=1.0,
                               apply_reg=True))
        latent = JointLatent(latent_list=latent_list)

        decoder = VAEDecoder(output_width=width,
                             output_height=height,
                             output_depth=depth)
        encoder = VAEEncoder(output_length=latent.reg_in_dim)
        tcDiscriminator = TCDiscriminator()

        # Pre-trained shape regresser used by the inception-score metric.
        shape_network = MetricRegresser(
            output_length=3, scope_name="dSpritesSampleQualityMetric_shape")

        # Work-dir layout: checkpoint/, sample/, time.txt, metric.csv.
        checkpoint_dir = os.path.join(self._work_dir, "checkpoint")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        sample_dir = os.path.join(self._work_dir, "sample")
        if not os.path.exists(sample_dir):
            os.makedirs(sample_dir)
        time_path = os.path.join(self._work_dir, "time.txt")
        metric_path = os.path.join(self._work_dir, "metric.csv")

        run_config = tf.ConfigProto()
        # Grow GPU memory on demand instead of grabbing it all up front.
        run_config.gpu_options.allow_growth = True
        with tf.Session(config=run_config) as sess:
            factorVAEMetric = FactorVAEMetric(metric_data, sess=sess)
            dSpritesInceptionScore = DSpritesInceptionScore(
                sess=sess,
                do_training=False,
                data=data,
                metadata=metadata,
                latent_values=latent_values,
                network_path="metric_model/DSprites",
                shape_network=shape_network,
                sample_dir=sample_dir)
            dHSICMetric = DHSICMetric(sess=sess, data=data)
            metric_callbacks = [
                factorVAEMetric, dSpritesInceptionScore, dHSICMetric
            ]
            # Model hyper-parameters mirror the training run so the saved
            # checkpoint can be restored into an identical graph.
            vae = FactorVAE(sess=sess,
                            checkpoint_dir=checkpoint_dir,
                            sample_dir=sample_dir,
                            time_path=time_path,
                            epoch=self._config["epoch"],
                            batch_size=self._config["batch_size"],
                            data=data,
                            vis_freq=self._config["vis_freq"],
                            vis_num_sample=self._config["vis_num_sample"],
                            vis_num_rep=self._config["vis_num_rep"],
                            latent=latent,
                            decoder=decoder,
                            encoder=encoder,
                            tcDiscriminator=tcDiscriminator,
                            tc_coe=self._config["tc_coe"],
                            metric_callbacks=metric_callbacks,
                            metric_freq=self._config["metric_freq"],
                            metric_path=metric_path,
                            output_reverse=self._config["output_reverse"])
            vae.build()
            # Restore the trained weights from checkpoint_dir.
            vae.load()

            results = []
            configs = []

            # Evaluate this model against every run's metric data.
            sub_config_manager = ConfigManager(config)
            while sub_config_manager.get_num_left_config() > 0:
                sub_config = sub_config_manager.get_next_config()

                metric_data_path = os.path.join(sub_config["work_dir"],
                                                "metric_data.pkl")
                with open(metric_data_path, "rb") as f:
                    sub_metric_data = pickle.load(f)

                sub_factorVAEMetric = FactorVAEMetric(sub_metric_data,
                                                      sess=sess)
                sub_factorVAEMetric.set_model(vae)
                # (-1, -1, -1) appears to be a sentinel epoch/batch/step
                # id for an out-of-training evaluation — TODO confirm.
                sub_result = sub_factorVAEMetric.evaluate(-1, -1, -1)

                results.append(sub_result)
                configs.append(sub_config)

            # protocol=2 keeps the pickle readable from Python 2.
            with open(os.path.join(self._work_dir, "cross_evaluation.pkl"),
                      "wb") as f:
                pickle.dump({
                    "results": results,
                    "configs": configs
                },
                            f,
                            protocol=2)
    def main(self):
        """Cross-evaluate a trained InfoGAN-CR checkpoint on every run's
        metric data and dump the results to ``cross_evaluation.pkl``.

        Rebuilds the INFOGAN_CR model exactly as during training, loads
        its checkpoint from this task's work dir, then for each config
        in the sweep loads that run's ``metric_data.pkl`` and computes
        the factorVAE metric of *this* model on *that* run's data.
        """
        # Imports are local — presumably this method runs inside a
        # gpu_task_scheduler worker process; verify against the scheduler.
        import os
        import tensorflow as tf
        import pickle
        from gan.load_data import load_dSprites
        from gan.latent import UniformLatent, JointLatent
        from gan.network import Decoder, InfoGANDiscriminator, \
            CrDiscriminator, MetricRegresser
        from gan.infogan_cr import INFOGAN_CR
        from gan.metric import FactorVAEMetric, DSpritesInceptionScore, \
            DHSICMetric

        from gpu_task_scheduler.config_manager import ConfigManager
        from config_mc import config

        data, metric_data, latent_values, metadata = \
            load_dSprites("data/dSprites")
        _, height, width, depth = data.shape

        latent_list = []

        # Regularized 1-D uniform latents, one per configured dimension.
        for i in range(self._config["uniform_reg_dim"]):
            latent_list.append(
                UniformLatent(in_dim=1,
                              out_dim=1,
                              low=-1.0,
                              high=1.0,
                              q_std=1.0,
                              apply_reg=True))
        # Optional unregularized noise dimensions, bundled as one latent.
        if self._config["uniform_not_reg_dim"] > 0:
            latent_list.append(
                UniformLatent(in_dim=self._config["uniform_not_reg_dim"],
                              out_dim=self._config["uniform_not_reg_dim"],
                              low=-1.0,
                              high=1.0,
                              q_std=1.0,
                              apply_reg=False))
        latent = JointLatent(latent_list=latent_list)

        decoder = Decoder(output_width=width,
                          output_height=height,
                          output_depth=depth)
        infoGANDiscriminator = \
            InfoGANDiscriminator(
                output_length=latent.reg_out_dim,
                q_l_dim=self._config["q_l_dim"])
        crDiscriminator = CrDiscriminator(output_length=latent.num_reg_latent)

        # Pre-trained shape regresser used by the inception-score metric.
        shape_network = MetricRegresser(
            output_length=3, scope_name="dSpritesSampleQualityMetric_shape")

        # Work-dir layout: checkpoint/, sample/, time.txt, metric.csv.
        checkpoint_dir = os.path.join(self._work_dir, "checkpoint")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        sample_dir = os.path.join(self._work_dir, "sample")
        if not os.path.exists(sample_dir):
            os.makedirs(sample_dir)
        time_path = os.path.join(self._work_dir, "time.txt")
        metric_path = os.path.join(self._work_dir, "metric.csv")

        run_config = tf.ConfigProto()
        with tf.Session(config=run_config) as sess:
            factorVAEMetric = FactorVAEMetric(metric_data, sess=sess)
            dSpritesInceptionScore = DSpritesInceptionScore(
                sess=sess,
                do_training=False,
                data=data,
                metadata=metadata,
                latent_values=latent_values,
                network_path="metric_model/DSprites",
                shape_network=shape_network,
                sample_dir=sample_dir)
            dHSICMetric = DHSICMetric(sess=sess, data=data)
            metric_callbacks = [
                factorVAEMetric, dSpritesInceptionScore, dHSICMetric
            ]
            # Model hyper-parameters mirror the training run so the saved
            # checkpoint can be restored into an identical graph.
            gan = INFOGAN_CR(
                sess=sess,
                checkpoint_dir=checkpoint_dir,
                sample_dir=sample_dir,
                time_path=time_path,
                epoch=self._config["epoch"],
                batch_size=self._config["batch_size"],
                data=data,
                vis_freq=self._config["vis_freq"],
                vis_num_sample=self._config["vis_num_sample"],
                vis_num_rep=self._config["vis_num_rep"],
                latent=latent,
                decoder=decoder,
                infoGANDiscriminator=infoGANDiscriminator,
                crDiscriminator=crDiscriminator,
                gap_start=self._config["gap_start"],
                gap_decrease_times=self._config["gap_decrease_times"],
                gap_decrease=self._config["gap_decrease"],
                gap_decrease_batch=self._config["gap_decrease_batch"],
                cr_coe_start=self._config["cr_coe_start"],
                cr_coe_increase_times=self._config["cr_coe_increase_times"],
                cr_coe_increase=self._config["cr_coe_increase"],
                cr_coe_increase_batch=self._config["cr_coe_increase_batch"],
                info_coe_de=self._config["info_coe_de"],
                info_coe_infod=self._config["info_coe_infod"],
                metric_callbacks=metric_callbacks,
                metric_freq=self._config["metric_freq"],
                metric_path=metric_path,
                output_reverse=self._config["output_reverse"],
                de_lr=self._config["de_lr"],
                infod_lr=self._config["infod_lr"],
                crd_lr=self._config["crd_lr"],
                summary_freq=self._config["summary_freq"])
            gan.build()
            # Restore the trained weights from checkpoint_dir.
            gan.load()

            results = []
            configs = []

            # Evaluate this model against every run's metric data.
            sub_config_manager = ConfigManager(config)
            while sub_config_manager.get_num_left_config() > 0:
                sub_config = sub_config_manager.get_next_config()

                metric_data_path = os.path.join(sub_config["work_dir"],
                                                "metric_data.pkl")
                with open(metric_data_path, "rb") as f:
                    sub_metric_data = pickle.load(f)

                sub_factorVAEMetric = FactorVAEMetric(sub_metric_data,
                                                      sess=sess)
                sub_factorVAEMetric.set_model(gan)
                # (-1, -1, -1) appears to be a sentinel epoch/batch/step
                # id for an out-of-training evaluation — TODO confirm.
                sub_result = sub_factorVAEMetric.evaluate(-1, -1, -1)

                results.append(sub_result)
                configs.append(sub_config)

            # protocol=2 keeps the pickle readable from Python 2.
            with open(os.path.join(self._work_dir, "cross_evaluation.pkl"),
                      "wb") as f:
                pickle.dump({
                    "results": results,
                    "configs": configs
                },
                            f,
                            protocol=2)