def pred_step(self, x, initializing, states):
    """Predict ``x + 1.0`` while tracking a call counter in ``net_states``.

    On the initialization pass the counter starts at 0. On every later call it
    is one more than the incoming ``states.net_states``. A fresh
    ``elegy.States`` holding only ``net_states`` is returned either way.
    """
    # ``states`` is only read when not initializing.
    counter = 0 if initializing else states.net_states + 1
    new_states = elegy.States(net_states=counter)
    return elegy.PredStep(x + 1.0, new_states)
def pred_step(self, states): nonlocal N N = N + 1 return elegy.PredStep( y_pred=None, states=states.update(net_params=1, net_states=2), )
def pred_step(self, x, states, initializing):
    """Return the constant prediction ``1`` and advance a ``net_states`` counter.

    ``x`` is accepted but not used. The counter is reset to 0 on the
    initialization pass and otherwise incremented from ``states.net_states``.
    """
    if not initializing:
        next_states = elegy.States(net_states=states.net_states + 1)
    else:
        next_states = elegy.States(net_states=0)

    return elegy.PredStep(y_pred=1, states=next_states)
def pred_step(self, x, rng, net_states, net_params, states, initializing):
    """Encode ``x``, decode the latent ``z`` and return logits plus new states.

    This looks like a variational-autoencoder step: the encoder yields ``z``,
    ``mean`` and ``stddev``, and a KL term is added as an auxiliary loss. That
    reading is inferred from names — confirm.

    Args:
        x: Input forwarded to the encoder.
        rng: RNG sequence. ``rng.next()`` supplies the init keys or the
            ``"params"`` rng stream, and ``rng`` itself is also passed to the
            encoder as a positional argument.
        net_states: ``(encoder_states, decoder_states)`` pair. Only read when
            not initializing.
        net_params: ``(encoder_params, decoder_params)`` pair. Only read when
            not initializing.
        states: Container whose ``update`` receives the new params, states
            and rng.
        initializing: When true, variables are created with
            ``init_with_output`` instead of being applied.

    Returns:
        ``elegy.PredStep`` with the decoder logits, the updated states, a
        ``"kl_divergence_loss"`` aux loss (KL scaled by 0.2), and empty aux
        metrics and summaries.
    """
    # NOTE: keep the order of the ``rng.next()`` calls unchanged. Each call
    # presumably advances the sequence, so reordering would change the keys.
    if initializing:
        (z, mean, stddev), encoder_vars = self.encoder.init_with_output(
            rng.next(), x, rng
        )
        logits, decoder_vars = self.decoder.init_with_output(rng.next(), z)
    else:
        enc_states, dec_states = net_states
        enc_params, dec_params = net_params

        encoder_input_vars = dict(params=enc_params, **enc_states)
        (z, mean, stddev), encoder_vars = self.encoder.apply(
            encoder_input_vars,
            x,
            rng,
            rngs={"params": rng.next()},
            mutable=True,
        )

        decoder_input_vars = dict(params=dec_params, **dec_states)
        logits, decoder_vars = self.decoder.apply(
            decoder_input_vars,
            z,
            rngs={"params": rng.next()},
            mutable=True,
        )

    kl_loss = 2e-1 * kl_divergence(mean, stddev)

    # ``pop("params")`` is unpacked as (remaining collections, params). This
    # matches flax ``FrozenDict.pop`` — confirm against the flax version in use.
    enc_states, enc_params = encoder_vars.pop("params")
    dec_states, dec_params = decoder_vars.pop("params")

    return elegy.PredStep(
        logits,
        states.update(
            net_params=(enc_params, dec_params),
            net_states=(enc_states, dec_states),
            rng=rng,
        ),
        aux_losses={"kl_divergence_loss": kl_loss},
        aux_metrics={},
        summaries=[],
    )
def pred_step(self, x, states, initializing):
    """Encode ``x``, decode the latent ``z`` and return logits plus new states.

    This is the same model step as a variant that takes rng/params/states as
    separate arguments. Here they are read from ``states`` (``states.rng``,
    ``states.net_states``, ``states.net_params``). The KL term is reported
    through ``elegy.hooks.add_loss`` rather than being returned. The
    variational-autoencoder reading is inferred from names — confirm.

    Args:
        x: Input forwarded to the encoder.
        states: Provides ``rng`` and, when not initializing, the
            ``(encoder, decoder)`` pairs in ``net_states`` / ``net_params``.
            Its ``update`` receives the new params, states and rng.
        initializing: When true, variables are created with
            ``init_with_output`` instead of being applied.

    Returns:
        ``elegy.PredStep`` holding the decoder logits and the updated states.
    """
    # Local annotations are not evaluated at runtime inside a function.
    rng: elegy.RNGSeq = states.rng

    # NOTE: keep the order of the ``rng.next()`` calls unchanged. Each call
    # presumably advances the sequence, so reordering would change the keys.
    if initializing:
        (z, mean, stddev), encoder_vars = self.encoder.init_with_output(
            rng.next(), x, rng
        )
        logits, decoder_vars = self.decoder.init_with_output(rng.next(), z)
    else:
        enc_states, dec_states = states.net_states
        enc_params, dec_params = states.net_params

        encoder_input_vars = dict(params=enc_params, **enc_states)
        (z, mean, stddev), encoder_vars = self.encoder.apply(
            encoder_input_vars,
            x,
            rng,
            rngs={"params": rng.next()},
            mutable=True,
        )

        decoder_input_vars = dict(params=dec_params, **dec_states)
        logits, decoder_vars = self.decoder.apply(
            decoder_input_vars,
            z,
            rngs={"params": rng.next()},
            mutable=True,
        )

    # KL term, scaled by 0.2, registered as a side-channel loss.
    elegy.hooks.add_loss("kl_divergence_loss", 2e-1 * kl_divergence(mean, stddev))

    # ``pop("params")`` is unpacked as (remaining collections, params). This
    # matches flax ``FrozenDict.pop`` — confirm against the flax version in use.
    enc_states, enc_params = encoder_vars.pop("params")
    dec_states, dec_params = decoder_vars.pop("params")

    return elegy.PredStep(
        logits,
        states.update(
            net_params=(enc_params, dec_params),
            net_states=(enc_states, dec_states),
            rng=rng,
        ),
    )