Example #1
    def run(self, inputs, use_learned_prior=True):
        """Policy interface for model. Runs decoder if action plan is empty, otherwise returns next action from action plan.
        :arg inputs: dict with 'states', 'actions', 'images' keys from environment
        :arg use_learned_prior: if True, uses learned prior otherwise samples latent from uniform prior
        """
        if not self._action_plan:
            inputs = map2torch(inputs, device=self.device)

            # sample latent variable from prior
            z = self.compute_learned_prior(self._learned_prior_input(inputs), first_only=True).sample() \
                if use_learned_prior else Gaussian(torch.zeros((1, self._hp.nz_vae*2), device=self.device)).sample()

            # decode into action plan
            z = z.repeat(
                self._hp.batch_size, 1
            )  # HACK: the flat LSTM decoder can only take batch_size inputs
            input_obs = self._learned_prior_input(inputs).repeat(
                self._hp.batch_size, 1)
            actions = self.decode(z,
                                  cond_inputs=input_obs,
                                  steps=self._hp.n_rollout_steps)[0]
            self._action_plan = deque(split_along_axis(map2np(actions),
                                                       axis=0))

        return AttrDict(action=self._action_plan.popleft()[None])
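Example #1 caches a decoded action plan and serves one action per call. The sketch below isolates that plan-caching pattern with a random stand-in for the decoder; the class name, horizon, and action dimension are illustrative assumptions, not part of the original code.

from collections import deque

import numpy as np


class PlanCachingPolicy:
    def __init__(self, horizon=10, action_dim=4):
        self.horizon, self.action_dim = horizon, action_dim
        self._action_plan = deque()

    def _decode_plan(self):
        # placeholder "decoder": the real model decodes a latent skill into an action sequence
        actions = np.random.randn(self.horizon, self.action_dim)
        return deque(actions)  # one row (action) per time step

    def run(self):
        if not self._action_plan:  # only re-decode when the plan is exhausted
            self._action_plan = self._decode_plan()
        return self._action_plan.popleft()[None]  # add a leading batch dimension


policy = PlanCachingPolicy()
actions = [policy.run() for _ in range(25)]  # re-decodes every `horizon` steps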
Example #2
    def loss(self, model_output, inputs):
        """Loss computation of the SPIRL model.
        :arg model_output: output of SPIRL model forward pass
        :arg inputs: dict with 'states', 'actions', 'images' keys from data loader
        """
        losses = AttrDict()

        # reconstruction loss; assumes a unit-variance Gaussian over the model output
        losses.rec_mse = NLL(self._hp.reconstruction_mse_weight) \
            (Gaussian(model_output.reconstruction, torch.zeros_like(model_output.reconstruction)),
             self._regression_targets(inputs))

        # KL loss
        losses.kl_loss = KLDivLoss(self._hp.kl_div_weight)(model_output.q, model_output.p)

        # learned skill prior net loss
        losses.q_hat_loss = self._compute_learned_prior_loss(model_output)

        losses.total = self._compute_total_loss(losses)
        return losses
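The reconstruction term in Example #2 evaluates the negative log-likelihood under a Gaussian whose log-sigma is zero, i.e. a unit-variance Gaussian. Assuming a diagonal Gaussian parameterized by mean and log standard deviation (which is how the surrounding examples appear to use it), this NLL differs from half the squared error only by a constant, so minimizing it is equivalent to minimizing the MSE. A minimal numerical check of that equivalence, not taken from the original code:

import math

import torch

mu = torch.randn(4, 10)  # stand-in for model_output.reconstruction
x = torch.randn(4, 10)   # stand-in for the regression targets

# unit-variance Gaussian NLL per dimension: 0.5 * (x - mu)^2 + 0.5 * log(2*pi)
nll = 0.5 * (x - mu) ** 2 + 0.5 * math.log(2 * math.pi)
half_mse = 0.5 * (x - mu) ** 2

# they differ only by a constant, so their gradients w.r.t. mu are identical
assert torch.allclose(nll - half_mse, torch.full_like(nll, 0.5 * math.log(2 * math.pi)))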
Example #3
    def forward(self, e_l, e_r):
        g1 = self.p1(e_l, e_r)
        z1 = Gaussian(g1).sample()
        g2 = self.q_p_shared(z1, e_l, e_r)  # NOTE: keep the argument order consistent with the usage above!

        return SequentialGaussian_SharedPQ(g1, z1, g2)
Example #4
    def forward(self, e_l, e_r, e_tilde):
        # sample z1 from the first inference Gaussian, then condition the second Gaussian on it
        g1 = self.q1(e_l, e_r, e_tilde)
        z1 = Gaussian(g1).sample()
        g2 = self.q2(z1, e_l, e_r)
        return SequentialGaussian_SharedPQ(g1, z1, g2)
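Examples #3 and #4 share the same two-stage (ancestral) sampling pattern: sample z1 from a first Gaussian, then condition a second Gaussian on z1 and the original inputs. The sketch below reproduces that pattern with plain linear layers standing in for p1/q1 and q_p_shared/q2; the module name and dimensions are assumptions for illustration, and the SequentialGaussian_SharedPQ container is replaced by a plain tuple.

import torch


class TwoStageSampler(torch.nn.Module):
    def __init__(self, in_dim, z_dim):
        super().__init__()
        self.stage1 = torch.nn.Linear(in_dim, 2 * z_dim)          # predicts (mu1, log_sigma1)
        self.stage2 = torch.nn.Linear(in_dim + z_dim, 2 * z_dim)  # conditioned on z1

    def forward(self, e):
        mu1, log_sigma1 = self.stage1(e).chunk(2, dim=-1)
        z1 = mu1 + log_sigma1.exp() * torch.randn_like(mu1)       # reparameterized sample
        mu2, log_sigma2 = self.stage2(torch.cat([z1, e], dim=-1)).chunk(2, dim=-1)
        return (mu1, log_sigma1), z1, (mu2, log_sigma2)


sampler = TwoStageSampler(in_dim=8, z_dim=4)
g1, z1, g2 = sampler(torch.randn(2, 8))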
Example #5
    def forward(self, *inputs):
        # wrap the raw network output in a Gaussian and return its flat parameter tensor
        return Gaussian(super().forward(*inputs)).tensor()
Example #6
    def compute(self, estimates, targets):
        # wrap raw parameter tensors into Gaussians if needed, then compute KL(estimates || targets)
        if not isinstance(estimates, Gaussian): estimates = Gaussian(estimates)
        if not isinstance(targets, Gaussian): targets = Gaussian(targets)
        kl_divergence = estimates.kl_divergence(targets)
        return kl_divergence
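Example #6 delegates the KL computation to the Gaussian class. Assuming the distributions are diagonal Gaussians parameterized by mean and log standard deviation (which matches how the other examples construct them), the divergence has the closed form sketched below; the helper function is illustrative, not the library's implementation, and is checked against torch.distributions.

import torch


def diag_gaussian_kl(mu_p, log_sigma_p, mu_q, log_sigma_q):
    """Closed-form KL(p || q) for diagonal Gaussians, summed over the last dimension."""
    var_p, var_q = torch.exp(2 * log_sigma_p), torch.exp(2 * log_sigma_q)
    return (log_sigma_q - log_sigma_p
            + (var_p + (mu_p - mu_q) ** 2) / (2 * var_q)
            - 0.5).sum(-1)


mu_p, ls_p = torch.randn(3, 5), torch.randn(3, 5)
mu_q, ls_q = torch.randn(3, 5), torch.randn(3, 5)
p = torch.distributions.Normal(mu_p, ls_p.exp())
q = torch.distributions.Normal(mu_q, ls_q.exp())
reference = torch.distributions.kl_divergence(p, q).sum(-1)
assert torch.allclose(diag_gaussian_kl(mu_p, ls_p, mu_q, ls_q), reference, atol=1e-5)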