Example #1
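A minimal `pred_step` override that manages its own state: the initialization pass seeds a step counter in `net_states`, every later call increments it, and the prediction is simply `x + 1.0`.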
    def pred_step(self, x, initializing, states):
        # On the init pass, seed the step counter; afterwards, increment it.
        if initializing:
            states = elegy.States(net_states=0)
        else:
            states = elegy.States(net_states=states.net_states + 1)

        return elegy.PredStep(x + 1.0, states)
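An override like this is typically attached to a `Model` subclass. A minimal usage sketch, assuming `elegy.Model` can be instantiated without a `module` when `pred_step` is fully overridden, and using elegy's Keras-style `predict_on_batch` (the class name and input value are illustrative):

    import numpy as np
    import elegy

    class CountingModel(elegy.Model):
        def pred_step(self, x, initializing, states):
            if initializing:
                states = elegy.States(net_states=0)
            else:
                states = elegy.States(net_states=states.net_states + 1)
            return elegy.PredStep(x + 1.0, states)

    model = CountingModel()
    # The first call runs the initialization branch; subsequent calls
    # increment the counter stored in net_states.
    preds = model.predict_on_batch(x=np.array(1.0))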
Example #2
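`pred_step` arguments are injected by name, so the method only has to declare what it actually uses; this variant requests nothing but `states`. It also mutates a counter `N` from an enclosing scope, which only works when the class is defined inside a function (see the sketch after the snippet).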
    def pred_step(self, states):
        # N lives in an enclosing function scope; `nonlocal` rebinds it there.
        nonlocal N
        N = N + 1

        return elegy.PredStep(
            y_pred=None,
            states=states.update(net_params=1, net_states=2),
        )
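For `nonlocal N` to be valid, the class must be defined inside a function that owns `N`. A sketch of that enclosing scope (names are illustrative):

    def make_counting_model_cls():
        N = 0  # counter owned by the enclosing function

        class CountingModel(elegy.Model):
            def pred_step(self, states):
                nonlocal N  # rebinds the outer N rather than creating a local
                N = N + 1
                return elegy.PredStep(
                    y_pred=None,
                    states=states.update(net_params=1, net_states=2),
                )

        return CountingModel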
Example #3
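The same counter pattern as Example #1, this time returning a constant prediction and passing the `PredStep` fields as keyword arguments. The parameter order differs from Example #1; since elegy matches arguments by name, that is fine.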
    def pred_step(self, x, states, initializing):
        # Same counter as Example #1; only the return style differs.
        if initializing:
            states = elegy.States(net_states=0)
        else:
            states = elegy.States(net_states=states.net_states + 1)

        return elegy.PredStep(
            y_pred=1,
            states=states,
        )
Example #4
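A variational-autoencoder `pred_step` built on two Flax modules, `self.encoder` and `self.decoder`. On the initialization pass it creates their variables with `init_with_output`; on later passes it reassembles the variable dicts from the tracked `net_params`/`net_states` and calls `apply` with `mutable=True` so updated non-parameter collections are returned as well. The weighted KL term is reported through `PredStep`'s `aux_losses` field.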
    def pred_step(self, x, rng, net_states, net_params, states, initializing):

        if initializing:
            # First pass: let Flax create the variables for both modules.
            (z, mean, stddev), enc_variables = self.encoder.init_with_output(
                rng.next(), x, rng)
            logits, dec_variables = self.decoder.init_with_output(
                rng.next(), z)
        else:
            # Later passes: rebuild the variable dicts from the tracked
            # params/states and apply with mutable=True so updated
            # non-parameter collections are returned too.
            enc_states, dec_states = net_states
            enc_params, dec_params = net_params
            enc_variables = dict(params=enc_params, **enc_states)
            (z, mean, stddev), enc_variables = self.encoder.apply(
                enc_variables,
                x,
                rng,
                rngs={"params": rng.next()},
                mutable=True)
            dec_variables = dict(params=dec_params, **dec_states)
            logits, dec_variables = self.decoder.apply(
                dec_variables, z, rngs={"params": rng.next()}, mutable=True)

        # Report the weighted KL penalty through PredStep's aux_losses field.
        aux_losses = dict(kl_divergence_loss=2e-1 *
                          kl_divergence(mean, stddev))

        # Split the variables back into params and the remaining collections
        # (FrozenDict.pop returns the remainder plus the popped value).
        enc_states, enc_params = enc_variables.pop("params")
        dec_states, dec_params = dec_variables.pop("params")
        net_params = (enc_params, dec_params)
        net_states = (enc_states, dec_states)

        return elegy.PredStep(
            logits,
            states.update(
                net_params=net_params,
                net_states=net_states,
                rng=rng,
            ),
            aux_losses=aux_losses,
            aux_metrics={},
            summaries=[],
        )
Example #5
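The same VAE step as Example #4, with two differences: the `RNGSeq` is pulled out of `states` instead of being taken as a parameter, and the KL penalty is registered via `elegy.hooks.add_loss` rather than returned in `aux_losses`.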
    def pred_step(self, x, states, initializing):
        # Pull the RNG sequence out of the states bundle rather than
        # requesting it as a parameter.
        rng: elegy.RNGSeq = states.rng

        if initializing:
            (z, mean, stddev), enc_variables = self.encoder.init_with_output(
                rng.next(), x, rng)
            logits, dec_variables = self.decoder.init_with_output(
                rng.next(), z)
        else:
            enc_states, dec_states = states.net_states
            enc_params, dec_params = states.net_params
            enc_variables = dict(params=enc_params, **enc_states)
            (z, mean, stddev), enc_variables = self.encoder.apply(
                enc_variables,
                x,
                rng,
                rngs={"params": rng.next()},
                mutable=True)
            dec_variables = dict(params=dec_params, **dec_states)
            logits, dec_variables = self.decoder.apply(
                dec_variables, z, rngs={"params": rng.next()}, mutable=True)

        # Register the weighted KL penalty through the hooks system instead
        # of returning it in aux_losses.
        elegy.hooks.add_loss("kl_divergence_loss",
                             2e-1 * kl_divergence(mean, stddev))

        enc_states, enc_params = enc_variables.pop("params")
        dec_states, dec_params = dec_variables.pop("params")
        net_params = (enc_params, dec_params)
        net_states = (enc_states, dec_states)

        return elegy.PredStep(
            logits,
            states.update(
                net_params=net_params,
                net_states=net_states,
                rng=rng,
            ),
        )
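Both VAE examples call a `kl_divergence` helper that is not shown. A minimal sketch of the standard closed form for a diagonal Gaussian against a unit normal (the exact helper used by the original code may differ, e.g. in reduction or weighting):

    import jax.numpy as jnp

    def kl_divergence(mean, stddev):
        # KL( N(mean, stddev^2) || N(0, I) ), summed over features and
        # averaged over the batch.
        return jnp.mean(
            jnp.sum(
                0.5 * (jnp.square(mean) + jnp.square(stddev)
                       - 2.0 * jnp.log(stddev) - 1.0),
                axis=-1,
            )
        )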