def run(self, inputs, use_learned_prior=True):
    """Policy interface for the model. Runs the decoder if the action plan is empty,
    otherwise returns the next action from the action plan.
    :arg inputs: dict with 'states', 'actions', 'images' keys from environment
    :arg use_learned_prior: if True, samples the latent from the learned prior,
                            otherwise from the uniform prior
    """
    if not self._action_plan:
        inputs = map2torch(inputs, device=self.device)

        # sample latent variable from prior
        z = self.compute_learned_prior(self._learned_prior_input(inputs), first_only=True).sample() \
            if use_learned_prior else \
            Gaussian(torch.zeros((1, self._hp.nz_vae * 2), device=self.device)).sample()

        # decode into action plan
        z = z.repeat(self._hp.batch_size, 1)    # HACK: the flat LSTM decoder can only take batch_size inputs
        input_obs = self._learned_prior_input(inputs).repeat(self._hp.batch_size, 1)
        actions = self.decode(z, cond_inputs=input_obs, steps=self._hp.n_rollout_steps)[0]
        self._action_plan = deque(split_along_axis(map2np(actions), axis=0))

    return AttrDict(action=self._action_plan.popleft()[None])
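# Illustrative sketch (not from the original code): `split_along_axis` above is
# assumed to turn the decoded, roughly (n_rollout_steps, action_dim)-shaped
# array into a list of per-step actions for the deque. A minimal NumPy version
# of such a helper could look like this; the real utility may differ.
import numpy as np

def split_along_axis(array, axis):
    # slice the array into a list of sub-arrays along `axis`
    return [np.take(array, i, axis=axis) for i in range(array.shape[axis])]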
def loss(self, model_output, inputs):
    """Loss computation of the SPIRL model.
    :arg model_output: output of SPIRL model forward pass
    :arg inputs: dict with 'states', 'actions', 'images' keys from data loader
    """
    losses = AttrDict()

    # reconstruction loss; assumes unit-variance Gaussian model output
    losses.rec_mse = NLL(self._hp.reconstruction_mse_weight)(
        Gaussian(model_output.reconstruction, torch.zeros_like(model_output.reconstruction)),
        self._regression_targets(inputs))

    # KL loss between posterior q and prior p
    losses.kl_loss = KLDivLoss(self._hp.kl_div_weight)(model_output.q, model_output.p)

    # learned skill prior net loss
    losses.q_hat_loss = self._compute_learned_prior_loss(model_output)

    losses.total = self._compute_total_loss(losses)
    return losses
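# Illustrative sketch (not from the original code): the `_compute_total_loss`
# call above is plausibly a weighted sum over the individual loss terms.
# This assumes each entry in `losses` is an AttrDict carrying `value` and
# `weight` fields, which is how the loss classes above take their weights.
import torch

def compute_total_loss(losses):
    # weighted sum over all active (positively weighted) loss terms
    weighted = [l.value * l.weight for l in losses.values() if l.weight > 0]
    return AttrDict(value=torch.stack(weighted).sum())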
def forward(self, e_l, e_r):
    g1 = self.p1(e_l, e_r)
    z1 = Gaussian(g1).sample()
    g2 = self.q_p_shared(z1, e_l, e_r)  # make sure it's the same order of arguments as in the usage above!
    return SequentialGaussian_SharedPQ(g1, z1, g2)
def forward(self, e_l, e_r, e_tilde):
    g1 = self.q1(e_l, e_r, e_tilde)
    z1 = Gaussian(g1).sample()
    g2 = self.q2(z1, e_l, e_r)
    return SequentialGaussian_SharedPQ(g1, z1, g2)
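# Illustrative sketch (not from the original code): the `Gaussian` wrapper used
# throughout these snippets, reconstructed from its call sites. A single tensor
# argument is assumed to hold mean and log-sigma concatenated along the last
# dimension (hence the `nz_vae * 2` zeros for a unit prior in `run` above);
# the real class likely offers more functionality.
import torch

class Gaussian:
    def __init__(self, mu, log_sigma=None):
        if log_sigma is None:
            mu, log_sigma = torch.chunk(mu, 2, dim=-1)  # split concatenated params
        self.mu, self.log_sigma = mu, log_sigma
        self.sigma = log_sigma.exp()

    def sample(self):
        # reparameterized sample: mu + sigma * eps
        return self.mu + self.sigma * torch.randn_like(self.sigma)

    def tensor(self):
        # re-concatenate the distribution parameters into one flat tensor
        return torch.cat([self.mu, self.log_sigma], dim=-1)

    def kl_divergence(self, other):
        # closed-form KL(self || other) for diagonal Gaussians, per dimension
        return other.log_sigma - self.log_sigma - 0.5 \
            + (self.sigma ** 2 + (self.mu - other.mu) ** 2) / (2 * other.sigma ** 2)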
def forward(self, *inputs):
    # interpret the raw network output as Gaussian parameters and return them as a flat tensor
    return Gaussian(super().forward(*inputs)).tensor()
def compute(self, estimates, targets):
    # coerce raw parameter tensors into Gaussian objects before computing the KL
    if not isinstance(estimates, Gaussian):
        estimates = Gaussian(estimates)
    if not isinstance(targets, Gaussian):
        targets = Gaussian(targets)
    return estimates.kl_divergence(targets)
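# Hypothetical usage of the loss above with the sketched Gaussian wrapper,
# assuming the KLDivLoss constructor takes the loss weight as in `loss` above:
# KL of N(0.5, 1) against the standard normal is 0.125 per dimension.
mu = torch.zeros(1, 4)
estimates = Gaussian(torch.cat([mu + 0.5, mu], dim=-1))  # mean 0.5, log-sigma 0
targets = Gaussian(torch.cat([mu, mu], dim=-1))          # standard normal
kl = KLDivLoss(1.0).compute(estimates, targets)          # tensor of 0.125s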