Esempio n. 1
0
    def main(self):
        """Train an InfoGAN-CR model on the dSprites dataset.

        Hyper-parameters are read from ``self._config``. The latent code is a
        ``JointLatent`` made of ``uniform_reg_dim`` regularized 1-D uniform
        factors on [-1, 1], plus, when ``uniform_not_reg_dim > 0``, one
        unregularized uniform block of that width. Paths under
        ``self._work_dir`` (``checkpoint/``, ``sample/``, ``time.txt`` and
        ``metric.csv``) are handed to the model. The FactorVAE,
        dSprites-inception-score and dHSIC metrics are attached as callbacks,
        and the model is built and trained inside one TensorFlow session.
        """
        import os
        import tensorflow as tf
        from gan.load_data import load_dSprites
        from gan.latent import UniformLatent, JointLatent
        from gan.network import Decoder, InfoGANDiscriminator, \
            CrDiscriminator, MetricRegresser
        from gan.infogan_cr import INFOGAN_CR
        from gan.metric import FactorVAEMetric, DSpritesInceptionScore, \
            DHSICMetric

        data, metric_data, latent_values, metadata = load_dSprites(
            "data/dSprites")
        _, height, width, depth = data.shape

        cfg = self._config

        # One independent, regularized 1-D uniform factor per dimension.
        latent_list = [
            UniformLatent(in_dim=1, out_dim=1, low=-1.0, high=1.0,
                          q_std=1.0, apply_reg=True)
            for _ in range(cfg["uniform_reg_dim"])
        ]
        # Optional unregularized block; its width is both in_dim and out_dim.
        not_reg_dim = cfg["uniform_not_reg_dim"]
        if not_reg_dim > 0:
            latent_list.append(
                UniformLatent(in_dim=not_reg_dim, out_dim=not_reg_dim,
                              low=-1.0, high=1.0, q_std=1.0,
                              apply_reg=False))
        latent = JointLatent(latent_list=latent_list)

        decoder = Decoder(output_width=width, output_height=height,
                          output_depth=depth)
        info_discriminator = InfoGANDiscriminator(
            output_length=latent.reg_out_dim, q_l_dim=cfg["q_l_dim"])
        cr_discriminator = CrDiscriminator(
            output_length=latent.num_reg_latent)
        # Shape regressor consumed by DSpritesInceptionScore below.
        shape_network = MetricRegresser(
            output_length=3, scope_name="dSpritesSampleQualityMetric_shape")

        work_dir = self._work_dir
        checkpoint_dir = os.path.join(work_dir, "checkpoint")
        sample_dir = os.path.join(work_dir, "sample")
        for needed_dir in (checkpoint_dir, sample_dir):
            if not os.path.exists(needed_dir):
                os.makedirs(needed_dir)
        time_path = os.path.join(work_dir, "time.txt")
        metric_path = os.path.join(work_dir, "metric.csv")

        # INFOGAN_CR keyword arguments whose names equal their config keys.
        passthrough_keys = (
            "epoch", "batch_size", "vis_freq", "vis_num_sample",
            "vis_num_rep", "gap_start", "gap_decrease_times", "gap_decrease",
            "gap_decrease_batch", "cr_coe_start", "cr_coe_increase_times",
            "cr_coe_increase", "cr_coe_increase_batch", "info_coe_de",
            "info_coe_infod", "metric_freq", "output_reverse", "de_lr",
            "infod_lr", "crd_lr", "summary_freq",
        )

        with tf.Session(config=tf.ConfigProto()) as sess:
            metric_callbacks = [
                FactorVAEMetric(metric_data, sess=sess),
                DSpritesInceptionScore(
                    sess=sess,
                    do_training=False,
                    data=data,
                    metadata=metadata,
                    latent_values=latent_values,
                    network_path="metric_model/DSprites",
                    shape_network=shape_network,
                    sample_dir=sample_dir),
                DHSICMetric(sess=sess, data=data),
            ]
            hparams = {key: cfg[key] for key in passthrough_keys}
            gan = INFOGAN_CR(
                sess=sess,
                checkpoint_dir=checkpoint_dir,
                sample_dir=sample_dir,
                time_path=time_path,
                data=data,
                latent=latent,
                decoder=decoder,
                infoGANDiscriminator=info_discriminator,
                crDiscriminator=cr_discriminator,
                metric_callbacks=metric_callbacks,
                metric_path=metric_path,
                **hparams)
            gan.build()
            gan.train()
Esempio n. 2
0
    def main(self):
        """Train a FactorVAE model on the dSprites dataset.

        The latent code is a ``JointLatent`` of ``gaussian_dim`` regularized
        1-D Gaussian factors (loc 0, scale 1). A VAE encoder/decoder pair and
        a total-correlation discriminator are constructed. The FactorVAE,
        dSprites-inception-score and dHSIC metrics are attached as callbacks,
        and the model is built and trained inside one TensorFlow session.
        Paths under ``self._work_dir`` (``checkpoint/``, ``sample/``,
        ``time.txt`` and ``metric.csv``) are handed to the model.
        """
        import os
        import tensorflow as tf
        from gan.load_data import load_dSprites
        from gan.latent import GaussianLatent, JointLatent
        from gan.network import VAEDecoder, VAEEncoder, TCDiscriminator, \
            MetricRegresser
        from gan.factorVAE import FactorVAE
        from gan.metric import FactorVAEMetric, DSpritesInceptionScore, \
            DHSICMetric

        data, metric_data, latent_values, metadata = load_dSprites(
            "data/dSprites")
        _, height, width, depth = data.shape

        cfg = self._config

        # One independent, regularized standard-normal factor per dimension.
        latent = JointLatent(latent_list=[
            GaussianLatent(in_dim=1, out_dim=1, loc=0.0, scale=1.0,
                           q_std=1.0, apply_reg=True)
            for _ in range(cfg["gaussian_dim"])
        ])

        decoder = VAEDecoder(output_width=width, output_height=height,
                             output_depth=depth)
        encoder = VAEEncoder(output_length=latent.reg_in_dim)
        tc_discriminator = TCDiscriminator()
        # Shape regressor consumed by DSpritesInceptionScore below.
        shape_network = MetricRegresser(
            output_length=3, scope_name="dSpritesSampleQualityMetric_shape")

        work_dir = self._work_dir
        checkpoint_dir = os.path.join(work_dir, "checkpoint")
        sample_dir = os.path.join(work_dir, "sample")
        for needed_dir in (checkpoint_dir, sample_dir):
            if not os.path.exists(needed_dir):
                os.makedirs(needed_dir)
        time_path = os.path.join(work_dir, "time.txt")
        metric_path = os.path.join(work_dir, "metric.csv")

        # FactorVAE keyword arguments whose names equal their config keys.
        passthrough_keys = (
            "epoch", "batch_size", "vis_freq", "vis_num_sample",
            "vis_num_rep", "tc_coe", "metric_freq", "output_reverse",
        )

        with tf.Session(config=tf.ConfigProto()) as sess:
            metric_callbacks = [
                FactorVAEMetric(metric_data, sess=sess),
                DSpritesInceptionScore(
                    sess=sess,
                    do_training=False,
                    data=data,
                    metadata=metadata,
                    latent_values=latent_values,
                    network_path="metric_model/DSprites",
                    shape_network=shape_network,
                    sample_dir=sample_dir),
                DHSICMetric(sess=sess, data=data),
            ]
            hparams = {key: cfg[key] for key in passthrough_keys}
            vae = FactorVAE(
                sess=sess,
                checkpoint_dir=checkpoint_dir,
                sample_dir=sample_dir,
                time_path=time_path,
                data=data,
                latent=latent,
                decoder=decoder,
                encoder=encoder,
                tcDiscriminator=tc_discriminator,
                metric_callbacks=metric_callbacks,
                metric_path=metric_path,
                **hparams)
            vae.build()
            vae.train()
    def main(self):
        """Sample a FactorVAE-style metric dataset from a trained FactorVAE.

        Rebuilds the FactorVAE described by ``self._config`` and calls
        ``vae.load()``. It then draws images from ``vae.sample_from`` and
        pickles (protocol 2) a dict to ``<self._work_dir>/metric_data.pkl``
        with two entries:

        * ``"groups"``: ``num_groups`` dicts, each holding ``"img"`` (images
          for ``group_size`` latent samples in which one latent column is
          copied from the first sample, i.e. held constant across the group)
          and ``"label"`` (that column's index, cycling over 0..9).
        * ``"img_eval_std"``: images for ``len(data) // 10`` fresh samples.
        """
        import os
        import tensorflow as tf
        import pickle
        from gan.load_data import load_dSprites
        from gan.latent import GaussianLatent, JointLatent
        from gan.network import VAEDecoder, VAEEncoder, TCDiscriminator, \
            MetricRegresser
        from gan.factorVAE import FactorVAE
        from gan.metric import FactorVAEMetric, DSpritesInceptionScore, \
            DHSICMetric

        data, metric_data, latent_values, metadata = \
            load_dSprites("data/dSprites")
        _, height, width, depth = data.shape

        latent = JointLatent(latent_list=[
            GaussianLatent(in_dim=1, out_dim=1, loc=0.0, scale=1.0,
                           q_std=1.0, apply_reg=True)
            for _ in range(self._config["gaussian_dim"])
        ])

        decoder = VAEDecoder(
            output_width=width, output_height=height, output_depth=depth)
        encoder = VAEEncoder(output_length=latent.reg_in_dim)
        tcDiscriminator = TCDiscriminator()

        shape_network = MetricRegresser(
            output_length=3,
            scope_name="dSpritesSampleQualityMetric_shape")

        checkpoint_dir = os.path.join(self._work_dir, "checkpoint")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        sample_dir = os.path.join(self._work_dir, "sample")
        if not os.path.exists(sample_dir):
            os.makedirs(sample_dir)
        time_path = os.path.join(self._work_dir, "time.txt")
        metric_path = os.path.join(self._work_dir, "metric.csv")

        run_config = tf.ConfigProto()
        run_config.gpu_options.allow_growth = True
        with tf.Session(config=run_config) as sess:
            factorVAEMetric = FactorVAEMetric(metric_data, sess=sess)
            dSpritesInceptionScore = DSpritesInceptionScore(
                sess=sess,
                do_training=False,
                data=data,
                metadata=metadata,
                latent_values=latent_values,
                network_path="metric_model/DSprites",
                shape_network=shape_network,
                sample_dir=sample_dir)
            dHSICMetric = DHSICMetric(
                sess=sess,
                data=data)
            metric_callbacks = [factorVAEMetric,
                                dSpritesInceptionScore,
                                dHSICMetric]
            vae = FactorVAE(
                sess=sess,
                checkpoint_dir=checkpoint_dir,
                sample_dir=sample_dir,
                time_path=time_path,
                epoch=self._config["epoch"],
                batch_size=self._config["batch_size"],
                data=data,
                vis_freq=self._config["vis_freq"],
                vis_num_sample=self._config["vis_num_sample"],
                vis_num_rep=self._config["vis_num_rep"],
                latent=latent,
                decoder=decoder,
                encoder=encoder,
                tcDiscriminator=tcDiscriminator,
                tc_coe=self._config["tc_coe"],
                metric_callbacks=metric_callbacks,
                metric_freq=self._config["metric_freq"],
                metric_path=metric_path,
                output_reverse=self._config["output_reverse"])
            vae.build()
            vae.load()

            group_size = 100    # latent samples per group
            num_groups = 1000   # number of groups generated
            # NOTE(review): labels cycle over latent columns 0..9, which
            # assumes gaussian_dim >= 10 -- TODO confirm for all configs.
            num_fixed_factors = 10

            metric_data_groups = []
            for group_id in range(num_groups):
                fixed_latent_id = group_id % num_fixed_factors
                latents_sampled = vae.latent.sample(group_size)
                # Copy the first sample's value of the chosen column to every
                # row so only that latent is held constant within the group.
                latents_sampled[:, fixed_latent_id] = \
                    latents_sampled[0, fixed_latent_id]
                imgs_sampled = vae.sample_from(latents_sampled)
                metric_data_groups.append(
                    {"img": imgs_sampled,
                     "label": fixed_latent_id})

            # Integer division: with `/` Python 3 would pass a float sample
            # count, while the sampler presumably expects an int. `//` gives
            # the same value Python 2's `/` did.
            latents_sampled = vae.latent.sample(data.shape[0] // 10)
            metric_data_eval_std = vae.sample_from(latents_sampled)

            # Distinct name so the dataset-provided `metric_data` used by the
            # callbacks above is not shadowed.
            generated_metric_data = {
                "groups": metric_data_groups,
                "img_eval_std": metric_data_eval_std}

            metric_data_path = os.path.join(self._work_dir, "metric_data.pkl")
            with open(metric_data_path, "wb") as f:
                pickle.dump(generated_metric_data, f, protocol=2)
    def main(self):
        """Cross-evaluate this run's FactorVAE on other runs' metric data.

        Rebuilds the FactorVAE described by ``self._config`` and calls
        ``vae.load()``. It then walks every config yielded by
        ``ConfigManager(config)``, where ``config`` comes from ``config_mc``.
        For each sub-config it unpickles
        ``<sub_config["work_dir"]>/metric_data.pkl`` and scores this model
        with a ``FactorVAEMetric`` built on that data. The collected scores
        and sub-configs are pickled (protocol 2) to
        ``<self._work_dir>/cross_evaluation.pkl``.
        """
        import os
        import tensorflow as tf
        import pickle
        from gan.load_data import load_dSprites
        from gan.latent import GaussianLatent, JointLatent
        from gan.network import VAEDecoder, VAEEncoder, TCDiscriminator, \
            MetricRegresser
        from gan.factorVAE import FactorVAE
        from gan.metric import FactorVAEMetric, DSpritesInceptionScore, \
            DHSICMetric

        from gpu_task_scheduler.config_manager import ConfigManager
        from config_mc import config

        data, metric_data, latent_values, metadata = load_dSprites(
            "data/dSprites")
        _, height, width, depth = data.shape

        cfg = self._config

        # One independent, regularized standard-normal factor per dimension.
        latent = JointLatent(latent_list=[
            GaussianLatent(in_dim=1, out_dim=1, loc=0.0, scale=1.0,
                           q_std=1.0, apply_reg=True)
            for _ in range(cfg["gaussian_dim"])
        ])

        decoder = VAEDecoder(output_width=width, output_height=height,
                             output_depth=depth)
        encoder = VAEEncoder(output_length=latent.reg_in_dim)
        tc_discriminator = TCDiscriminator()
        shape_network = MetricRegresser(
            output_length=3, scope_name="dSpritesSampleQualityMetric_shape")

        work_dir = self._work_dir
        checkpoint_dir = os.path.join(work_dir, "checkpoint")
        sample_dir = os.path.join(work_dir, "sample")
        for needed_dir in (checkpoint_dir, sample_dir):
            if not os.path.exists(needed_dir):
                os.makedirs(needed_dir)
        time_path = os.path.join(work_dir, "time.txt")
        metric_path = os.path.join(work_dir, "metric.csv")

        # FactorVAE keyword arguments whose names equal their config keys.
        passthrough_keys = (
            "epoch", "batch_size", "vis_freq", "vis_num_sample",
            "vis_num_rep", "tc_coe", "metric_freq", "output_reverse",
        )

        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True
        with tf.Session(config=session_config) as sess:
            metric_callbacks = [
                FactorVAEMetric(metric_data, sess=sess),
                DSpritesInceptionScore(
                    sess=sess,
                    do_training=False,
                    data=data,
                    metadata=metadata,
                    latent_values=latent_values,
                    network_path="metric_model/DSprites",
                    shape_network=shape_network,
                    sample_dir=sample_dir),
                DHSICMetric(sess=sess, data=data),
            ]
            hparams = {key: cfg[key] for key in passthrough_keys}
            vae = FactorVAE(
                sess=sess,
                checkpoint_dir=checkpoint_dir,
                sample_dir=sample_dir,
                time_path=time_path,
                data=data,
                latent=latent,
                decoder=decoder,
                encoder=encoder,
                tcDiscriminator=tc_discriminator,
                metric_callbacks=metric_callbacks,
                metric_path=metric_path,
                **hparams)
            vae.build()
            vae.load()

            results, configs = [], []
            sub_config_manager = ConfigManager(config)
            while sub_config_manager.get_num_left_config() > 0:
                sub_config = sub_config_manager.get_next_config()

                # NOTE(review): unpickles files written by sibling runs; this
                # is only safe while those work dirs are trusted.
                sub_data_path = os.path.join(sub_config["work_dir"],
                                             "metric_data.pkl")
                with open(sub_data_path, "rb") as f:
                    sub_metric_data = pickle.load(f)

                evaluator = FactorVAEMetric(sub_metric_data, sess=sess)
                evaluator.set_model(vae)
                results.append(evaluator.evaluate(-1, -1, -1))
                configs.append(sub_config)

            output = {"results": results, "configs": configs}
            output_path = os.path.join(work_dir, "cross_evaluation.pkl")
            with open(output_path, "wb") as f:
                pickle.dump(output, f, protocol=2)
Esempio n. 5
0
    def main(self):
        """Compute the final disentanglement metrics for a trained InfoGAN-CR.

        Rebuilds the InfoGAN-CR described by ``self._config`` (same latent
        and network layout as training) and calls ``gan.load()``. It then
        evaluates, in order, the FactorVAE, beta-VAE, SAP, F-statistic and
        MIG metrics, followed by DCI with each listed regressor. The
        resulting dict is pickled to ``<self._work_dir>/final_metrics.pkl``.
        """
        import os
        import tensorflow as tf
        from gan.load_data import load_dSprites
        from gan.latent import UniformLatent, JointLatent
        from gan.network import Decoder, InfoGANDiscriminator, \
            CrDiscriminator, MetricRegresser
        from gan.infogan_cr import INFOGAN_CR
        from gan.metric import FactorVAEMetric, DSpritesInceptionScore, \
            DHSICMetric, \
            BetaVAEMetric, SAPMetric, FStatMetric, MIGMetric, DCIMetric
        import pickle

        data, metric_data, latent_values, metadata = load_dSprites(
            "data/dSprites")
        _, height, width, depth = data.shape

        cfg = self._config

        # One independent, regularized 1-D uniform factor per dimension.
        latent_list = [
            UniformLatent(in_dim=1, out_dim=1, low=-1.0, high=1.0,
                          q_std=1.0, apply_reg=True)
            for _ in range(cfg["uniform_reg_dim"])
        ]
        # Optional unregularized block; its width is both in_dim and out_dim.
        not_reg_dim = cfg["uniform_not_reg_dim"]
        if not_reg_dim > 0:
            latent_list.append(
                UniformLatent(in_dim=not_reg_dim, out_dim=not_reg_dim,
                              low=-1.0, high=1.0, q_std=1.0,
                              apply_reg=False))
        latent = JointLatent(latent_list=latent_list)

        decoder = Decoder(output_width=width, output_height=height,
                          output_depth=depth)
        info_discriminator = InfoGANDiscriminator(
            output_length=latent.reg_out_dim, q_l_dim=cfg["q_l_dim"])
        cr_discriminator = CrDiscriminator(
            output_length=latent.num_reg_latent)
        shape_network = MetricRegresser(
            output_length=3, scope_name="dSpritesSampleQualityMetric_shape")

        work_dir = self._work_dir
        checkpoint_dir = os.path.join(work_dir, "checkpoint")
        sample_dir = os.path.join(work_dir, "sample")
        for needed_dir in (checkpoint_dir, sample_dir):
            if not os.path.exists(needed_dir):
                os.makedirs(needed_dir)
        time_path = os.path.join(work_dir, "time.txt")
        metric_path = os.path.join(work_dir, "metric.csv")

        # INFOGAN_CR keyword arguments whose names equal their config keys.
        passthrough_keys = (
            "epoch", "batch_size", "vis_freq", "vis_num_sample",
            "vis_num_rep", "gap_start", "gap_decrease_times", "gap_decrease",
            "gap_decrease_batch", "cr_coe_start", "cr_coe_increase_times",
            "cr_coe_increase", "cr_coe_increase_batch", "info_coe_de",
            "info_coe_infod", "metric_freq", "output_reverse", "de_lr",
            "infod_lr", "crd_lr", "summary_freq",
        )

        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True
        with tf.Session(config=session_config) as sess:
            metric_callbacks = [
                FactorVAEMetric(metric_data, sess=sess),
                DSpritesInceptionScore(
                    sess=sess,
                    do_training=False,
                    data=data,
                    metadata=metadata,
                    latent_values=latent_values,
                    network_path="metric_model/DSprites",
                    shape_network=shape_network,
                    sample_dir=sample_dir),
                DHSICMetric(sess=sess, data=data),
            ]
            hparams = {key: cfg[key] for key in passthrough_keys}
            gan = INFOGAN_CR(
                sess=sess,
                checkpoint_dir=checkpoint_dir,
                sample_dir=sample_dir,
                time_path=time_path,
                data=data,
                latent=latent,
                decoder=decoder,
                infoGANDiscriminator=info_discriminator,
                crDiscriminator=cr_discriminator,
                metric_callbacks=metric_callbacks,
                metric_path=metric_path,
                **hparams)
            gan.build()
            gan.load()

            results = {}

            # Single-configuration metrics, evaluated in this order.
            simple_metrics = (
                ("FactorVAE", FactorVAEMetric),
                ("betaVAE", BetaVAEMetric),
                ("SAP", SAPMetric),
                ("FStat", FStatMetric),
                ("MIG", MIGMetric),
            )
            for result_key, metric_cls in simple_metrics:
                evaluator = metric_cls(metric_data, sess=sess)
                evaluator.set_model(gan)
                results[result_key] = evaluator.evaluate(-1, -1, -1)

            # DCI is reported once per regressor backend.
            dci_regressors = ("Lasso", "LassoCV", "RandomForest",
                              "RandomForestIBGAN", "RandomForestCV")
            for regressor in dci_regressors:
                evaluator = DCIMetric(metric_data, sess=sess,
                                      regressor=regressor)
                evaluator.set_model(gan)
                results["DCI_{}".format(regressor)] = evaluator.evaluate(
                    -1, -1, -1)

            output_path = os.path.join(work_dir, "final_metrics.pkl")
            with open(output_path, "wb") as f:
                pickle.dump(results, f)
Esempio n. 6
0
    def main(self):
        """Compute the final disentanglement metrics for a trained FactorVAE.

        Rebuilds the FactorVAE described by ``self._config`` (same latent and
        network layout as training) and calls ``vae.load()``. It then
        evaluates, in order, the FactorVAE, beta-VAE, SAP, F-statistic and
        MIG metrics, followed by DCI with each listed regressor. The
        resulting dict is pickled to ``<self._work_dir>/final_metrics.pkl``.
        """
        import os
        import tensorflow as tf
        from gan.load_data import load_dSprites
        from gan.latent import GaussianLatent, JointLatent
        from gan.network import VAEDecoder, VAEEncoder, TCDiscriminator, \
            MetricRegresser
        from gan.factorVAE import FactorVAE
        from gan.metric import FactorVAEMetric, DSpritesInceptionScore, \
            DHSICMetric, \
            BetaVAEMetric, SAPMetric, FStatMetric, MIGMetric, DCIMetric
        import pickle

        data, metric_data, latent_values, metadata = load_dSprites(
            "data/dSprites")
        _, height, width, depth = data.shape

        cfg = self._config

        # One independent, regularized standard-normal factor per dimension.
        latent = JointLatent(latent_list=[
            GaussianLatent(in_dim=1, out_dim=1, loc=0.0, scale=1.0,
                           q_std=1.0, apply_reg=True)
            for _ in range(cfg["gaussian_dim"])
        ])

        decoder = VAEDecoder(output_width=width, output_height=height,
                             output_depth=depth)
        encoder = VAEEncoder(output_length=latent.reg_in_dim)
        tc_discriminator = TCDiscriminator()
        shape_network = MetricRegresser(
            output_length=3, scope_name="dSpritesSampleQualityMetric_shape")

        work_dir = self._work_dir
        checkpoint_dir = os.path.join(work_dir, "checkpoint")
        sample_dir = os.path.join(work_dir, "sample")
        for needed_dir in (checkpoint_dir, sample_dir):
            if not os.path.exists(needed_dir):
                os.makedirs(needed_dir)
        time_path = os.path.join(work_dir, "time.txt")
        metric_path = os.path.join(work_dir, "metric.csv")

        # FactorVAE keyword arguments whose names equal their config keys.
        passthrough_keys = (
            "epoch", "batch_size", "vis_freq", "vis_num_sample",
            "vis_num_rep", "tc_coe", "metric_freq", "output_reverse",
        )

        with tf.Session(config=tf.ConfigProto()) as sess:
            metric_callbacks = [
                FactorVAEMetric(metric_data, sess=sess),
                DSpritesInceptionScore(
                    sess=sess,
                    do_training=False,
                    data=data,
                    metadata=metadata,
                    latent_values=latent_values,
                    network_path="metric_model/DSprites",
                    shape_network=shape_network,
                    sample_dir=sample_dir),
                DHSICMetric(sess=sess, data=data),
            ]
            hparams = {key: cfg[key] for key in passthrough_keys}
            vae = FactorVAE(
                sess=sess,
                checkpoint_dir=checkpoint_dir,
                sample_dir=sample_dir,
                time_path=time_path,
                data=data,
                latent=latent,
                decoder=decoder,
                encoder=encoder,
                tcDiscriminator=tc_discriminator,
                metric_callbacks=metric_callbacks,
                metric_path=metric_path,
                **hparams)
            vae.build()
            vae.load()

            results = {}

            # Single-configuration metrics, evaluated in this order.
            simple_metrics = (
                ("FactorVAE", FactorVAEMetric),
                ("betaVAE", BetaVAEMetric),
                ("SAP", SAPMetric),
                ("FStat", FStatMetric),
                ("MIG", MIGMetric),
            )
            for result_key, metric_cls in simple_metrics:
                evaluator = metric_cls(metric_data, sess=sess)
                evaluator.set_model(vae)
                results[result_key] = evaluator.evaluate(-1, -1, -1)

            # DCI is reported once per regressor backend.
            dci_regressors = ("Lasso", "LassoCV", "RandomForest",
                              "RandomForestIBGAN", "RandomForestCV")
            for regressor in dci_regressors:
                evaluator = DCIMetric(metric_data, sess=sess,
                                      regressor=regressor)
                evaluator.set_model(vae)
                results["DCI_{}".format(regressor)] = evaluator.evaluate(
                    -1, -1, -1)

            output_path = os.path.join(work_dir, "final_metrics.pkl")
            with open(output_path, "wb") as f:
                pickle.dump(results, f)