Exemple #1
0
    def test_loss(self):
        """Smoke-test ``VaeLossNet`` on the ``[0:1, :]`` slice of ``self.data``.

        The slice is passed through ``self.network`` to obtain a prediction
        and through the network's ``graph_px_g_z.splitter`` to obtain the
        matching target. Both are packed into a ``VaeLossNet.Input`` with
        every weight set to 1.0 and handed to the loss network. Nothing is
        asserted on the result; the test only checks that the call chain runs.
        """
        loss_layer = VaeLossNet(prefix="loss")
        leading_slice = self.data[0:1, :]
        prediction = self.network(leading_slice)
        target = self.network.graph_px_g_z.splitter(leading_slice)
        unit_weights = VaeLossNet.InputWeight(1.0, 1.0, 1.0, 1.0, 1.0)
        loss_input = VaeLossNet.Input.from_output(
            y_true=target,
            model_output=prediction,
            weights=unit_weights,
        )
        losses = loss_layer(loss_input)
Exemple #2
0
 def __init__(self, config: "Vae.Config", **kwargs):
     """Assemble the network, the loss network and the weight schedule layer.

     Args:
         config: Configuration object. It is stored on the instance, passed
             to ``VaeNet`` and ``Vae.CoolingRegime``, and queried for the
             latent parser type and the fake-output getter.
         **kwargs: Forwarded unchanged to ``tf.keras.Model.__init__``,
             ``VaeNet`` and ``VaeLossNet``.
     """
     # Both bases are initialised explicitly because they take different
     # arguments; the Keras base goes first so that ``self.dtype`` is
     # available when the weight getter is built below.
     tf.keras.Model.__init__(self, **kwargs)
     self.config = config
     self.network = VaeNet(config, **kwargs)
     self.lossnet = VaeLossNet(latent_eps=1e-6, prefix="loss", **kwargs)
     self.weight_getter = Vae.CoolingRegime(config, dtype=self.dtype)
     AutoencoderModelBaseMixin.__init__(
         self,
         self.weight_getter,
         self.network,
         self.config.get_latent_parser_type(),
         self.config.get_fake_output_getter(),
     )
Exemple #3
0
 def test_getNetworkFromConfig(self):
     """Smoke-test the network and loss-network types supplied by ``config``.

     ``config`` is not a parameter or a local; it is resolved from the
     enclosing scope. The network type and loss-network type are obtained
     from it, instantiated, and run on the ``[0:1, :]`` slice of
     ``self.data`` with every loss weight set to 1.0. Nothing is asserted
     on the result; the test only checks that the call chain runs.
     """
     network = config.get_network_type()(config)
     loss_layer = config.get_lossnet_type()(prefix="loss")
     leading_slice = self.data[0:1, :]
     prediction = network(leading_slice)
     target = network.graph_px_g_z.splitter(leading_slice)
     loss_input = config.get_lossnet_type().Input.from_output(
         y_true=target,
         model_output=prediction,
         weights=VaeLossNet.InputWeight(1.0, 1.0, 1.0, 1.0, 1.0),
     )
     losses = loss_layer(loss_input)
Exemple #4
0
    def test_step(self, data):
        """Run one evaluation step and return the renamed loss terms.

        Args:
            data: Batch that unpacks into ``(x, y)`` after
                ``data_adapter.expand_1d``.

        Returns:
            Dict with one entry per field of the tuple returned by
            ``self.loss_fn``, keyed by the field's name translated through
            ``self._output_keys_renamed``.
        """
        data = data_adapter.expand_1d(data)
        x, y = data
        y_pred = self.network(x, training=False)
        # Evaluation uses a default-constructed ``InputWeight``.
        losses = self.loss_fn(
            y,
            y_pred,
            VaeLossNet.InputWeight(),
            training=False,
        )
        return {
            self._output_keys_renamed[field]: value
            for field, value in losses._asdict().items()
        }
Exemple #5
0
 def from_output(
     y_true: SplitCovariates,
     model_output: GumbleGmvaeNet.Output,
     weights: GumbleGmvaeNetLossNet.InputWeight,
 ) -> GumbleGmvaeNetLossNet.Input:
     """Pack a network output and its targets into a loss-network input.

     Args:
         y_true: Split target covariates, forwarded to the marginal input
             builder.
         model_output: Network output; its ``py``, ``qy_g_x`` and
             ``marginal`` attributes are read.
         weights: Indexable weight tuple. Element 0 is passed as the last
             field of the returned input; the remaining elements are
             re-packed as a ``VaeLossNet.InputWeight`` for the marginal part.

     Returns:
         ``GumbleGmvaeNetLossNet.Input`` built from ``py``, ``qy_g_x``, the
         marginal loss input and ``weights[0]``.
     """
     marginal_weights = VaeLossNet.InputWeight(*weights[1:])
     marginal_input = MarginalGmVaeLossNet.Input.from_MarginalGmVae_output(
         y_true,
         model_output.marginal,
         marginal_weights,
     )
     return GumbleGmvaeNetLossNet.Input(
         model_output.py,
         model_output.qy_g_x,
         marginal_input,
         weights[0],
     )
Exemple #6
0
 def from_output(
     y_true: SplitCovariates,
     model_output: StackedGmvaeNet.Output,
     weights: StackedGmvaeLossNet.InputWeight,
 ) -> StackedGmvaeLossNet.Input:
     """Pack a stacked network output and its targets into a loss input.

     Args:
         y_true: Split target covariates, shared by every marginal input.
         model_output: Network output; its ``py``, ``qy_g_x`` and
             ``marginals`` attributes are read.
         weights: Indexable weight tuple. Element 0 is passed as the last
             field of the returned input; the remaining elements are
             re-packed as a ``VaeLossNet.InputWeight`` for the marginals.

     Returns:
         ``StackedGmvaeLossNet.Input`` built from ``py``, ``qy_g_x``, a list
         with one marginal loss input per entry of
         ``model_output.marginals``, and ``weights[0]``.
     """
     # The marginal weights do not depend on the marginal, so build them once
     # instead of once per list element.
     marginal_weights = VaeLossNet.InputWeight(*weights[1:])
     marginal_inputs = [
         MarginalGmVaeLossNet.Input.from_MarginalGmVae_output(
             y_true, marginal, marginal_weights)
         for marginal in model_output.marginals
     ]
     return StackedGmvaeLossNet.Input(
         model_output.py,
         model_output.qy_g_x,
         marginal_inputs,
         weights[0],
     )
Exemple #7
0
 def call(self, step):
     """Evaluate the configured weight schedules at ``step``.

     Args:
         step: Step value; it is cast to ``self.dtype`` before being passed
             to the schedules.

     Returns:
         ``VaeLossNet.InputWeight`` holding, in order, the values of the
         ``kld_z``, ``recon_reg``, ``recon_bin``, ``recon_ord`` and
         ``recon_cat`` schedules.
     """
     cstep = tf.cast(step, self.dtype)
     schedules = self.config
     # NOTE(review): ``config.recon_schedule`` used to be evaluated here but
     # its value was never used, so the call was dropped. Confirm whether it
     # was meant to scale the ``recon_*`` weights.
     return VaeLossNet.InputWeight(
         schedules.kld_z_schedule(cstep),
         schedules.recon_reg_schedule(cstep),
         schedules.recon_bin_schedule(cstep),
         schedules.recon_ord_schedule(cstep),
         schedules.recon_cat_schedule(cstep),
     )
Exemple #8
0
    def loss_fn(
        self,
        y_true,
        y_pred: VaeNet.VaeNetOutput,
        weight=VaeLossNet.InputWeight(),
        training=False,
    ) -> "VaeLossNet.Output":
        """Compute the loss terms, each reduced to its mean.

        Args:
            y_true: Target values; cast to ``self.dtype`` and split with the
                network's ``graph_px_g_z.splitter`` before use.
            y_pred: Output of ``self.network`` for the matching inputs.
            weight: Loss weights handed to ``Input.from_output``. Defaults to
                a default-constructed ``VaeLossNet.InputWeight``.
            training: Forwarded to the loss network call.

        Returns:
            ``self.lossnet.Output`` whose fields are ``tf.reduce_mean`` of the
            corresponding element returned by the loss network.
        """
        y_true = tf.cast(y_true, dtype=self.dtype)
        y_split = self.network.graph_px_g_z.splitter(y_true)

        loss_input = self.lossnet.Input.from_output(y_split, y_pred, weight)
        per_term = self.lossnet(loss_input, training)
        return self.lossnet.Output(
            *[tf.reduce_mean(term) for term in per_term])
Exemple #9
0
    def test_loss(self):
        """Smoke-test ``MarginalGmVaeLossNet`` on one slice of the fixture.

        ``self.data`` is indexed as a pair. The ``[0:1, :]`` slice of both
        elements is fed to ``self.network``, and the same slice of element 0
        is split into targets. Prediction and targets are packed with every
        weight set to 1.0 and passed to the loss network with ``False`` as
        the second positional argument. Nothing is asserted on the result;
        the test only checks that the call chain runs.
        """
        lossnet = MarginalGmVaeLossNet()
        pred = self.network([self.data[0][0:1, :], self.data[1][0:1, :]])
        y_true = self.network.graph_px_g_z.splitter(self.data[0][0:1, :])
        inputs = MarginalGmVaeLossNet.Input.from_MarginalGmVae_output(
            y_true=y_true,
            model_output=pred,
            weights=VaeLossNet.InputWeight(1.0, 1.0, 1.0, 1.0, 1.0),
        )

        losses = lossnet(inputs, False)
Exemple #10
0
class Vae(tf.keras.Model, AutoencoderModelBaseMixin):
    """Keras model pairing a ``VaeNet`` with a ``VaeLossNet``.

    Loss weights for training are produced by a ``CoolingRegime`` layer that
    is evaluated at the optimizer's iteration count. ``train_step`` and
    ``test_step`` report the loss terms under the names listed in
    ``_output_keys_renamed``.
    """

    class CoolingRegime(tf.keras.layers.Layer):
        """Layer mapping a step value to a ``VaeLossNet.InputWeight``."""

        class Config(BaseModel):
            # One schedule per weight. Every default below is a
            # ``CyclicalLearningRate`` whose two rate arguments are both 1.0
            # and whose ``scale_fn`` always returns 1.0.
            kld_z_schedule: tf.keras.optimizers.schedules.LearningRateSchedule = (
                tfa.optimizers.CyclicalLearningRate(
                    1.0,
                    1.0,
                    step_size=1,
                    scale_fn=lambda x: 1.0,
                    scale_mode="cycle",
                ))
            # NOTE(review): ``recon_schedule`` is declared here but is not
            # read by ``CoolingRegime.call`` -- confirm its intended use.
            recon_schedule: tf.keras.optimizers.schedules.LearningRateSchedule = (
                tfa.optimizers.CyclicalLearningRate(
                    1.0,
                    1.0,
                    step_size=1,
                    scale_fn=lambda x: 1.0,
                    scale_mode="cycle",
                ))
            recon_reg_schedule: tf.keras.optimizers.schedules.LearningRateSchedule = (
                tfa.optimizers.CyclicalLearningRate(
                    1.0,
                    1.0,
                    step_size=1,
                    scale_fn=lambda x: 1.0,
                    scale_mode="cycle",
                ))
            recon_bin_schedule: tf.keras.optimizers.schedules.LearningRateSchedule = (
                tfa.optimizers.CyclicalLearningRate(
                    1.0,
                    1.0,
                    step_size=1,
                    scale_fn=lambda x: 1.0,
                    scale_mode="cycle",
                ))
            recon_ord_schedule: tf.keras.optimizers.schedules.LearningRateSchedule = (
                tfa.optimizers.CyclicalLearningRate(
                    1.0,
                    1.0,
                    step_size=1,
                    scale_fn=lambda x: 1.0,
                    scale_mode="cycle",
                ))
            recon_cat_schedule: tf.keras.optimizers.schedules.LearningRateSchedule = (
                tfa.optimizers.CyclicalLearningRate(
                    1.0,
                    1.0,
                    step_size=1,
                    scale_fn=lambda x: 1.0,
                    scale_mode="cycle",
                ))

            class Config:
                # The schedule fields hold non-pydantic objects, so arbitrary
                # types have to be allowed.
                arbitrary_types_allowed = True
                smart_union = True

        def __init__(self, config: "Vae.CoolingRegime.Config", **kwargs):
            """Store ``config``; ``kwargs`` go to the Keras ``Layer`` base."""
            super().__init__(**kwargs)
            self.config = config

        def call(self, step):
            """Evaluate the configured weight schedules at ``step``.

            Args:
                step: Step value; it is cast to ``self.dtype`` before being
                    passed to the schedules.

            Returns:
                ``VaeLossNet.InputWeight`` holding, in order, the values of
                the ``kld_z``, ``recon_reg``, ``recon_bin``, ``recon_ord``
                and ``recon_cat`` schedules.
            """
            cstep = tf.cast(step, self.dtype)
            schedules = self.config
            # ``recon_schedule`` used to be evaluated here as well, but its
            # value was never used, so the call was dropped.
            return VaeLossNet.InputWeight(
                schedules.kld_z_schedule(cstep),
                schedules.recon_reg_schedule(cstep),
                schedules.recon_bin_schedule(cstep),
                schedules.recon_ord_schedule(cstep),
                schedules.recon_cat_schedule(cstep),
            )

    # Combined configuration: the network settings plus the weight schedules.
    class Config(VaeNet.Config, CoolingRegime.Config):
        pass

    def __init__(self, config: "Vae.Config", **kwargs):
        """Assemble the network, the loss network and the schedule layer.

        Args:
            config: Configuration object. It is stored on the instance,
                passed to ``VaeNet`` and ``Vae.CoolingRegime``, and queried
                for the latent parser type and the fake-output getter.
            **kwargs: Forwarded unchanged to ``tf.keras.Model.__init__``,
                ``VaeNet`` and ``VaeLossNet``.
        """
        # Both bases are initialised explicitly because they take different
        # arguments; the Keras base goes first so that ``self.dtype`` is
        # available when the weight getter is built below.
        tf.keras.Model.__init__(self, **kwargs)
        self.config = config
        self.network = VaeNet(config, **kwargs)
        self.lossnet = VaeLossNet(latent_eps=1e-6, prefix="loss", **kwargs)
        self.weight_getter = Vae.CoolingRegime(config, dtype=self.dtype)
        AutoencoderModelBaseMixin.__init__(
            self,
            self.weight_getter,
            self.network,
            self.config.get_latent_parser_type(),
            self.config.get_fake_output_getter(),
        )

    @tf.function
    def loss_fn(
        self,
        y_true,
        y_pred: VaeNet.VaeNetOutput,
        weight=VaeLossNet.InputWeight(),
        training=False,
    ) -> "VaeLossNet.Output":
        """Compute the loss terms, each reduced to its mean.

        Args:
            y_true: Target values; cast to ``self.dtype`` and split with the
                network's ``graph_px_g_z.splitter`` before use.
            y_pred: Output of ``self.network`` for the matching inputs.
            weight: Loss weights handed to ``Input.from_output``. Defaults to
                a default-constructed ``VaeLossNet.InputWeight``.
            training: Forwarded to the loss network call.

        Returns:
            ``self.lossnet.Output`` whose fields are ``tf.reduce_mean`` of the
            corresponding element returned by the loss network.
        """
        y_true = tf.cast(y_true, dtype=self.dtype)
        y_split = self.network.graph_px_g_z.splitter(y_true)

        loss = self.lossnet.Output(*[
            tf.reduce_mean(x) for x in self.lossnet(
                self.lossnet.Input.from_output(y_split, y_pred, weight),
                training,
            )
        ])
        return loss

    @tf.function
    def call(self, x, training=False):
        """Run the wrapped network on ``x``."""
        return self.network(x, training)

    @tf.function
    def latent_sample(self, inputs, y, training=False, samples=1):
        """Return the ``"px_g_z__sample"`` entry of a Monte Carlo estimate.

        Args:
            inputs: Passed through to ``self.monte_carlo_estimate``.
            y: Passed through to ``self.monte_carlo_estimate``.
            training: Forwarded as the ``training`` keyword.
            samples: Passed as the first positional argument of
                ``self.monte_carlo_estimate``.

        Returns:
            The value stored under ``"px_g_z__sample"`` in the estimate.
        """
        # The result used to be bound to ``output`` while the lookup read
        # ``outputs``, which raised NameError on every call.
        outputs = self.monte_carlo_estimate(samples,
                                            inputs,
                                            y,
                                            training=training)
        return outputs["px_g_z__sample"]

    def train_step(self, data, training: bool = False):
        """Run one optimisation step and return the renamed loss terms.

        Args:
            data: Batch that unpacks into ``(x, y)`` after
                ``data_adapter.expand_1d``.
            training: Not read by this method; the network and the loss are
                always evaluated with ``training=True``.

        Returns:
            Dict of the loss tuple's fields plus the current ``lambda_z``
            weight, keyed through ``_output_keys_renamed``.
        """
        data = data_adapter.expand_1d(data)
        x, y = data
        # Loss weights follow the optimizer's iteration counter.
        weights = self.weight_getter(self.optimizer.iterations)

        with backprop.GradientTape() as tape:
            y_pred = self.network(x, training=True)
            losses = self.loss_fn(
                y,
                y_pred,
                weights,
                training=True,
            )
            loss = tf.reduce_mean(losses.loss)

        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        return {
            self._output_keys_renamed[k]: v
            for k, v in {
                **losses._asdict(),
                "kld_z_schedule": weights.lambda_z,
            }.items()
        }

    def test_step(self, data):
        """Run one evaluation step and return the renamed loss terms.

        Args:
            data: Batch that unpacks into ``(x, y)`` after
                ``data_adapter.expand_1d``.

        Returns:
            Dict with one entry per field of the tuple returned by
            ``self.loss_fn``, keyed through ``_output_keys_renamed``.
        """
        data = data_adapter.expand_1d(data)
        x, y = data
        y_pred = self.network(x, training=False)
        # Evaluation uses a default-constructed ``InputWeight``.
        losses = self.loss_fn(y,
                              y_pred,
                              VaeLossNet.InputWeight(),
                              training=False)

        return {
            self._output_keys_renamed[k]: v
            for k, v in losses._asdict().items()
        }

    # Translation from internal loss/weight field names to the names reported
    # by ``train_step`` and ``test_step``. "lambda_z" and "kld_z_schedule"
    # deliberately or accidentally share the reported name "weight/lambda_z";
    # when both are present the later entry of the source dict wins.
    _output_keys_renamed = {
        "kl_z": "losses/kl_z",
        "l_pxgz_reg": "reconstruction/l_pxgz_reg",
        "l_pxgz_bin": "reconstruction/l_pxgz_bin",
        "l_pxgz_ord": "reconstruction/l_pxgz_ord",
        "l_pxgz_cat": "reconstruction/l_pxgz_cat",
        "scaled_l_pxgz": "reconstruction/l_pxgz",
        "scaled_elbo": "losses/scaled_elbo",
        "recon_loss": "losses/recon_loss",
        "loss": "losses/loss",
        "lambda_z": "weight/lambda_z",
        "lambda_reg": "weight/lambda_reg",
        "lambda_bin": "weight/lambda_bin",
        "lambda_ord": "weight/lambda_ord",
        "lambda_cat": "weight/lambda_cat",
        "kld_z_schedule": "weight/lambda_z",
    }