    def visualize(self, x, t, n, actions, rewards, done):
        # pre-process image x
        im_x = x.view(-1, self.x_size[0], self.x_size[1], self.x_size[2])
        processed_x = self.process_x(im_x)  # max x length is max(t2) + 1
        processed_x = processed_x.view(x.shape[0], x.shape[1], -1)
        if actions is not None:
            rewards = (rewards[..., None] / 10.0).clamp(-1.0, 1.0)
            action_embs = self.action_embedding(actions)
            processed_x = torch.cat([processed_x, action_embs, rewards], -1)
        else:
            action_embs = None

        # aggregate the belief b
        b = self.b_rnn(processed_x, done)[:, t]  # size: bs, time, layers, dim
        t_encodings = self.time_encoding(
            b.new_ones(b.size(0), dtype=torch.long) *
            ((self.t_diff_max - self.t_diff_min) //
             2)) * self.time_encoding_scale

        # compute z from b
        p_z_bs = []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                p_z_b_mu, p_z_b_logvar = self.z_b[layer](b[:, layer])
            else:
                p_z_b_mu, p_z_b_logvar = self.z_b[layer](torch.cat(
                    [b[:, layer], p_z_b], dim=1))
            p_z_b = ops.reparameterize_gaussian(p_z_b_mu, p_z_b_logvar, True)
            p_z_bs.insert(0, p_z_b_mu)

        z = torch.cat(p_z_bs, dim=1)
        rollout_x = [self.x_z(z)]
        for i in range(n - 1):
            next_z = []
            for layer in range(self.layers - 1, -1, -1):
                if layer == self.layers - 1:
                    if action_embs is not None:
                        inputs = torch.cat(
                            [z, t_encodings, action_embs[:, t + i + 1]], dim=1)
                    else:
                        inputs = torch.cat([z, t_encodings], dim=1)
                    pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](inputs)
                else:
                    if action_embs is not None:
                        inputs = torch.cat(
                            [z, pt_z2_z1, t_encodings, action_embs[:, t + i + 1]],
                            dim=1)
                    else:
                        inputs = torch.cat([z, pt_z2_z1, t_encodings], dim=1)
                    pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](inputs)
                pt_z2_z1 = ops.reparameterize_gaussian(pt_z2_z1_mu,
                                                       pt_z2_z1_logvar, True)
                next_z.insert(0, pt_z2_z1_mu)

            z = torch.cat(next_z, dim=1)
            rollout_x.append(self.x_z(z))

        return torch.stack(rollout_x, dim=1)
    def option_reconstruction(self, b, actions, t1, t2):
        qb_z2_b2_mus, qb_z2_b2_logvars, qb_z2_b2s = [], [], []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](b[:, :, layer])
            else:
                qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](torch.cat(
                    [b[:, :, layer], qb_z2_b2], dim=-1))
            qb_z2_b2_mus.insert(0, qb_z2_b2_mu)
            qb_z2_b2_logvars.insert(0, qb_z2_b2_logvar)

            qb_z2_b2 = ops.reparameterize_gaussian(qb_z2_b2_mu,
                                                   qb_z2_b2_logvar,
                                                   self.training)
            qb_z2_b2s.insert(0, qb_z2_b2)

        zs = torch.cat(qb_z2_b2s, -1)

        lengths = (t2 - t1).squeeze(0)
        sorted_lengths, argsort = lengths.sort(descending=True)
        maxlen = torch.max(lengths, 0)[0].item()
        indices = t1[0, :, None] + torch.arange(
            0, maxlen, device=b.device)[None, :].expand(zs.shape[0], -1)
        indices = indices[:, :, None]
        indices = indices.clamp(0, b.shape[1] - 1)
        b_indices = indices.expand(-1, -1, zs.shape[-1])
        states = torch.gather(zs, 1, b_indices)
        actions = torch.gather(actions, 1, indices.squeeze(-1))
        sorted_states = states[argsort]
        sorted_actions = actions[argsort]

        mean, logvar, sizes = self.action_reconstruction(
            sorted_states, sorted_actions, sorted_lengths)
        mean = mean.contiguous()
        logvar = logvar.contiguous()
        parameters = ops.reparameterize_gaussian(mean, logvar, self.training)
        inferred_actions = apply_option(states, parameters, sizes)

        reconstruction_loss = F.cross_entropy(
            inferred_actions.transpose(-1, -2), actions, reduction="none")
        reconstruction_loss = torch.gather(
            torch.cumsum(reconstruction_loss, 1), 1, (lengths - 1)[:, None])

        _, unsort = argsort.sort()
        parameters = parameters[unsort]

        return parameters.detach(), reconstruction_loss, mean, logvar
    def compute_q(self, x, actions, rewards, done):
        # pre-process image x
        im_x = x.view(-1, self.x_size[0], self.x_size[1], self.x_size[2])
        processed_x = self.process_x(im_x)  # max x length is max(t2) + 1
        processed_x = processed_x.view(x.shape[0], x.shape[1], -1)
        if actions is not None:
            rewards = (rewards[..., None] / 10.0).clamp(-1.0, 1.0)
            action_embs = self.action_embedding(actions)
            processed_x = torch.cat([processed_x, action_embs, rewards], -1)

        # aggregate the belief b
        b = self.b_rnn(processed_x, done)  # size: bs, time, layers, dim
        b = b[:, -1]  # size: bs, layers, dim

        zs = []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                z_mu, z_logvar = self.z_b[layer](b[:, layer])
            else:
                z_mu, z_logvar = self.z_b[layer](torch.cat([b[:, layer], z],
                                                           dim=1))

            z = ops.reparameterize_gaussian(z_mu, z_logvar, self.training)
            zs.insert(0, z)

        z = torch.cat(zs, dim=1)
        return self.q_z(z)
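
Every method in this listing calls ops.reparameterize_gaussian(mu, logvar, sample), but the helper itself is not shown. Below is a minimal sketch of the standard reparameterization trick it presumably implements; the actual ops module may differ, for instance in how it treats the sample flag.

import torch

def reparameterize_gaussian(mu, logvar, sample):
    # Assumed behaviour: return z = mu + sigma * eps when sampling,
    # and fall back to the mean otherwise (e.g. at evaluation time).
    if not sample:
        return mu
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + std * eps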
Example #4
    def forward(self, input_):
        output = self.main(input_) + self.output_bias
        if self.reparameterization:
            mean = output[:, :self.latent_size]
            logvar = output[:, self.latent_size:]
            output = ops.reparameterize_gaussian(mean, logvar, self.training)
        return output
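
When reparameterization is enabled, this head treats the first latent_size outputs as the mean and the rest as the log-variance, so self.main must produce 2 * latent_size features. A hypothetical minimal wrapper (the class name GaussianHead, the in_size argument, and the Linear backbone are illustrative, not taken from the original code):

import torch
import torch.nn as nn

class GaussianHead(nn.Module):
    # Hypothetical stand-in mirroring the forward() above.
    def __init__(self, in_size, latent_size, reparameterization=True):
        super().__init__()
        self.latent_size = latent_size
        self.reparameterization = reparameterization
        out_size = 2 * latent_size if reparameterization else latent_size
        self.main = nn.Linear(in_size, out_size)
        self.output_bias = nn.Parameter(torch.zeros(out_size))

    def forward(self, input_):
        output = self.main(input_) + self.output_bias
        if self.reparameterization:
            mean = output[:, :self.latent_size]
            logvar = output[:, self.latent_size:]
            # reparameterize_gaussian as sketched earlier in this listing
            output = reparameterize_gaussian(mean, logvar, self.training)
        return output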
Example #5
    def forward(self, x, actions):
        # pre-process image x
        im_x = x.view(-1, self.x_size[0], self.x_size[1], self.x_size[2])
        processed_x = self.process_x(im_x)  # max x length is max(t2) + 1
        processed_x = processed_x.view(x.shape[0], x.shape[1], -1)

        # q_B(z2 | b2)
        qb_z2_b2_mus, qb_z2_b2_logvars, qb_z2_b2s = [], [], []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](processed_x)
            else:
                qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](torch.cat(
                    [processed_x, qb_z2_b2], dim=-1))
            qb_z2_b2_mus.insert(0, qb_z2_b2_mu)
            qb_z2_b2_logvars.insert(0, qb_z2_b2_logvar)

            qb_z2_b2 = ops.reparameterize_gaussian(qb_z2_b2_mu,
                                                   qb_z2_b2_logvar,
                                                   self.training)
            qb_z2_b2s.insert(0, qb_z2_b2)

        qb_z2_b2_mu = torch.cat(qb_z2_b2_mus, dim=1)
        qb_z2_b2_logvar = torch.cat(qb_z2_b2_logvars, dim=1)
        qb_z2_b2 = torch.cat(qb_z2_b2s, dim=1)

        # p_D(x2 | z2)
        pd_x2_z2 = self.x_z(qb_z2_b2.view(im_x.shape[0], -1))
        return (x, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2, pd_x2_z2)
Example #6
    def visualize(self, x, t, n):
        # pre-process image x
        processed_x = self.process_x(x)  # x length is t + 1

        # aggregate the belief b
        b = self.b_rnn(processed_x)[:, t]  # size: bs, time, layers, dim

        # compute z from b
        p_z_bs = []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                p_z_b_mu, p_z_b_logvar = self.z_b[layer](b[:, layer])
            else:
                p_z_b_mu, p_z_b_logvar = self.z_b[layer](torch.cat(
                    [b[:, layer], p_z_b], dim=1))
            p_z_b = ops.reparameterize_gaussian(p_z_b_mu, p_z_b_logvar, True)
            p_z_bs.insert(0, p_z_b)

        z = torch.cat(p_z_bs, dim=1)
        rollout_x = []

        for _ in range(n):
            next_z = []
            for layer in range(self.layers - 1, -1,
                               -1):  # TODO optionally condition n
                if layer == self.layers - 1:
                    pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](z)
                else:
                    pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](torch.cat(
                        [z, pt_z2_z1], dim=1))
                pt_z2_z1 = ops.reparameterize_gaussian(pt_z2_z1_mu,
                                                       pt_z2_z1_logvar, True)
                next_z.insert(0, pt_z2_z1)

            z = torch.cat(next_z, dim=1)
            rollout_x.append(self.x_z(z))

        return torch.stack(rollout_x, dim=1)
    def predict_forward(self, qs_z1_b1, option, t_encodings):
        pt_z2_z1_mus, pt_z2_z1_logvars, pt_z2_z1s = [], [], []
        option = self.option_embedding(option)
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](torch.cat(
                    [qs_z1_b1, t_encodings, option], dim=-1))
            else:
                pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](torch.cat(
                    [qs_z1_b1, pt_z2_z1s[0], t_encodings, option], dim=-1))
            pt_z2_z1_mus.insert(0, pt_z2_z1_mu)
            pt_z2_z1_logvars.insert(0, pt_z2_z1_logvar)
            pt_z2_z1s.insert(
                0,
                ops.reparameterize_gaussian(pt_z2_z1_mu, pt_z2_z1_logvar,
                                            self.training))

        pt_z2_z1 = torch.cat(pt_z2_z1s, dim=-1)
        pd_g2_z2_mu = self.g_z(
            torch.cat([qs_z1_b1, pt_z2_z1, option, t_encodings], dim=-1))
        value = self.actor_critic.get_value(pt_z2_z1)

        return pt_z2_z1, pd_g2_z2_mu, value
Example #8
    def visualize(self, x, t, n, actions):
        # pre-process image x
        im_x = x.view(-1, self.x_size[0], self.x_size[1], self.x_size[2])
        processed_x = self.process_x(im_x)  # max x length is max(t2) + 1
        processed_x = processed_x.view(x.shape[0], x.shape[1], -1)
        # aggregate the belief b
        # compute z from b
        p_z_bs = []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                p_z_b_mu, p_z_b_logvar = self.z_b[layer](processed_x)
            else:
                p_z_b_mu, p_z_b_logvar = self.z_b[layer](torch.cat(
                    [processed_x, p_z_b_mu], dim=-1))
            p_z_bs.insert(
                0,
                ops.reparameterize_gaussian(p_z_b_mu,
                                            p_z_b_logvar,
                                            sample=False))

        z = torch.cat(p_z_bs, dim=1)
        rollout_x = [self.x_z(z.view(im_x.shape[0], -1))]
        return torch.stack(rollout_x, dim=1)
Example #9
    def forward(self, x):
        # sample t1 and t2
        t1 = torch.randint(0,
                           x.size(1) - self.t_diff_max,
                           (self.samples_per_seq, x.size(0)),
                           device=x.device)
        t2 = t1 + torch.randint(self.t_diff_min,
                                self.t_diff_max + 1,
                                (self.samples_per_seq, x.size(0)),
                                device=x.device)
        # x = x[:, :t2.max() + 1]  # usually not required with big enough batch size

        # pre-process image x
        processed_x = self.process_x(x)  # max x length is max(t2) + 1

        # aggregate the belief b
        b = self.b_rnn(processed_x)  # size: bs, time, layers, dim

        # replicate b multiple times
        b = b[None, ...].expand(self.samples_per_seq, -1, -1, -1,
                                -1)  # size: copy, bs, time, layers, dim

        # Element-wise indexing. sizes: bs, layers, dim
        b1 = torch.gather(
            b, 2, t1[..., None, None,
                     None].expand(-1, -1, -1, b.size(3),
                                  b.size(4))).view(-1, b.size(3), b.size(4))
        b2 = torch.gather(
            b, 2, t2[..., None, None,
                     None].expand(-1, -1, -1, b.size(3),
                                  b.size(4))).view(-1, b.size(3), b.size(4))

        # q_B(z2 | b2)
        qb_z2_b2_mus, qb_z2_b2_logvars, qb_z2_b2s = [], [], []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](b2[:, layer])
            else:
                qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](torch.cat(
                    [b2[:, layer], qb_z2_b2], dim=1))
            qb_z2_b2_mus.insert(0, qb_z2_b2_mu)
            qb_z2_b2_logvars.insert(0, qb_z2_b2_logvar)

            qb_z2_b2 = ops.reparameterize_gaussian(qb_z2_b2_mu,
                                                   qb_z2_b2_logvar,
                                                   self.training)
            qb_z2_b2s.insert(0, qb_z2_b2)

        qb_z2_b2_mu = torch.cat(qb_z2_b2_mus, dim=1)
        qb_z2_b2_logvar = torch.cat(qb_z2_b2_logvars, dim=1)
        qb_z2_b2 = torch.cat(qb_z2_b2s, dim=1)

        # q_S(z1 | z2, b1, b2) ~= q_S(z1 | z2, b1)
        qs_z1_z2_b1_mus, qs_z1_z2_b1_logvars, qs_z1_z2_b1s = [], [], []
        for layer in range(self.layers - 1, -1,
                           -1):  # TODO optionally condition t2 - t1
            if layer == self.layers - 1:
                qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar = self.z1_z2_b1[layer](
                    torch.cat([qb_z2_b2, b1[:, layer]], dim=1))
            else:
                qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar = self.z1_z2_b1[layer](
                    torch.cat([qb_z2_b2, b1[:, layer], qs_z1_z2_b1], dim=1))
            qs_z1_z2_b1_mus.insert(0, qs_z1_z2_b1_mu)
            qs_z1_z2_b1_logvars.insert(0, qs_z1_z2_b1_logvar)

            qs_z1_z2_b1 = ops.reparameterize_gaussian(qs_z1_z2_b1_mu,
                                                      qs_z1_z2_b1_logvar,
                                                      self.training)
            qs_z1_z2_b1s.insert(0, qs_z1_z2_b1)

        qs_z1_z2_b1_mu = torch.cat(qs_z1_z2_b1_mus, dim=1)
        qs_z1_z2_b1_logvar = torch.cat(qs_z1_z2_b1_logvars, dim=1)
        qs_z1_z2_b1 = torch.cat(qs_z1_z2_b1s, dim=1)

        # p_T(z2 | z1), also conditions on q_B(z2) from higher layer
        pt_z2_z1_mus, pt_z2_z1_logvars = [], []
        for layer in range(self.layers - 1, -1,
                           -1):  # TODO optionally condition t2 - t1
            if layer == self.layers - 1:
                pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](qs_z1_z2_b1)
            else:
                pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](torch.cat(
                    [qs_z1_z2_b1, qb_z2_b2s[layer + 1]], dim=1))
            pt_z2_z1_mus.insert(0, pt_z2_z1_mu)
            pt_z2_z1_logvars.insert(0, pt_z2_z1_logvar)

        pt_z2_z1_mu = torch.cat(pt_z2_z1_mus, dim=1)
        pt_z2_z1_logvar = torch.cat(pt_z2_z1_logvars, dim=1)

        # p_B(z1 | b1)
        pb_z1_b1_mus, pb_z1_b1_logvars = [], []
        for layer in range(self.layers - 1, -1,
                           -1):  # TODO optionally condition t2 - t1
            if layer == self.layers - 1:
                pb_z1_b1_mu, pb_z1_b1_logvar = self.z_b[layer](b1[:, layer])
            else:
                pb_z1_b1_mu, pb_z1_b1_logvar = self.z_b[layer](torch.cat(
                    [b1[:, layer], qs_z1_z2_b1s[layer + 1]], dim=1))
            pb_z1_b1_mus.insert(0, pb_z1_b1_mu)
            pb_z1_b1_logvars.insert(0, pb_z1_b1_logvar)

        pb_z1_b1_mu = torch.cat(pb_z1_b1_mus, dim=1)
        pb_z1_b1_logvar = torch.cat(pb_z1_b1_logvars, dim=1)

        # p_D(x2 | z2)
        pd_x2_z2 = self.x_z(qb_z2_b2)

        return (x, t2, qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar, pb_z1_b1_mu,
                pb_z1_b1_logvar, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2,
                pt_z2_z1_mu, pt_z2_z1_logvar, pd_x2_z2)
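
This forward pass only returns distribution parameters; the training loss is computed elsewhere and is not part of this listing. Purely as an illustration, the sketch below shows one plausible way such outputs could be combined into a TD-VAE-style objective, assuming diagonal Gaussian latents and an MSE reconstruction term; it is not the original training code.

import math
import torch
import torch.nn.functional as F

def kl_div_gaussian(mu_q, logvar_q, mu_p, logvar_p):
    # KL( N(mu_q, var_q) || N(mu_p, var_p) ), summed over the latent dimension.
    return 0.5 * (logvar_p - logvar_q - 1
                  + (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp()).sum(-1)

def gaussian_log_prob(mu, logvar, z):
    # log N(z; mu, var), summed over the latent dimension.
    return -0.5 * (logvar + math.log(2 * math.pi)
                   + (z - mu) ** 2 / logvar.exp()).sum(-1)

def tdvae_style_loss(outputs, x2):
    # x2: target observation at time t2, gathered from x outside this sketch.
    (x, t2, qs_z1_mu, qs_z1_logvar, pb_z1_mu, pb_z1_logvar,
     qb_z2_mu, qb_z2_logvar, qb_z2, pt_z2_mu, pt_z2_logvar, pd_x2_z2) = outputs
    # Smoothing posterior q_S(z1 | z2, b1) stays close to the belief prior p_B(z1 | b1).
    kl_z1 = kl_div_gaussian(qs_z1_mu, qs_z1_logvar, pb_z1_mu, pb_z1_logvar)
    # Transition p_T(z2 | z1) should explain the sampled z2 as well as q_B(z2 | b2) does.
    kl_z2 = (gaussian_log_prob(qb_z2_mu, qb_z2_logvar, qb_z2)
             - gaussian_log_prob(pt_z2_mu, pt_z2_logvar, qb_z2))
    # Reconstruction of x2 from z2; MSE is used here purely for illustration.
    recon = F.mse_loss(pd_x2_z2, x2, reduction='none').flatten(1).sum(-1)
    return (recon + kl_z1 + kl_z2).mean()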
    def forward(self, x, actions, rewards, done, t1, t2):
        if t1 is None:
            t1 = torch.randint(0,
                               x.size(1) - int(self.rl) - self.t_diff_max,
                               (self.samples_per_seq, x.size(0)),
                               device=x.device)
        else:
            t1 = t1[None, :]
        if t2 is None:
            t2 = t1 + torch.randint(self.t_diff_min,
                                    self.t_diff_max + 1,
                                    (self.samples_per_seq, x.size(0)),
                                    device=x.device)
        else:
            t2 = t2[None, :]

        q1, q2, action_embs, b1, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2s, qb_z2_b2, pb_z1_b1_mu, pb_z1_b1_logvar = \
            self.q_and_z_b(x, actions, rewards, done, t1, t2)

        t_encodings = self.time_encoding(
            t2 - t1 - self.t_diff_min) * self.time_encoding_scale
        t_encodings = t_encodings.view(-1, t_encodings.size(-1))
        if action_embs is not None:
            t1_next = t1 + 1
            action_embs = action_embs[None, ...].expand(
                self.samples_per_seq, -1, -1, -1)  # size: copy, bs, time, dim
            a1_next = torch.gather(
                action_embs, 2, t1_next[..., None, None].expand(
                    -1, -1, -1,
                    action_embs.shape[-1])).view(-1, action_embs.shape[-1])

        # q_S(z1 | z2, b1, b2) ~= q_S(z1 | z2, b1)
        qs_z1_z2_b1_mus, qs_z1_z2_b1_logvars, qs_z1_z2_b1s = [], [], []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar = self.z1_z2_b[layer](
                    torch.cat([qb_z2_b2, b1[:, layer], t_encodings], dim=1))
            else:
                qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar = self.z1_z2_b[layer](
                    torch.cat(
                        [qb_z2_b2, b1[:, layer], qs_z1_z2_b1, t_encodings],
                        dim=1))
            qs_z1_z2_b1_mus.insert(0, qs_z1_z2_b1_mu)
            qs_z1_z2_b1_logvars.insert(0, qs_z1_z2_b1_logvar)

            qs_z1_z2_b1 = ops.reparameterize_gaussian(qs_z1_z2_b1_mu,
                                                      qs_z1_z2_b1_logvar,
                                                      self.training)
            qs_z1_z2_b1s.insert(0, qs_z1_z2_b1)

        qs_z1_z2_b1_mu = torch.cat(qs_z1_z2_b1_mus, dim=1)
        qs_z1_z2_b1_logvar = torch.cat(qs_z1_z2_b1_logvars, dim=1)
        qs_z1_z2_b1 = torch.cat(qs_z1_z2_b1s, dim=1)

        # p_T(z2 | z1), also conditions on q_B(z2) from higher layer
        pt_z2_z1_mus, pt_z2_z1_logvars = [], []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](torch.cat(
                    [qs_z1_z2_b1, t_encodings, a1_next], dim=1))
            else:
                pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](torch.cat(
                    [qs_z1_z2_b1, qb_z2_b2s[layer + 1], t_encodings, a1_next],
                    dim=1))
            pt_z2_z1_mus.insert(0, pt_z2_z1_mu)
            pt_z2_z1_logvars.insert(0, pt_z2_z1_logvar)

        pt_z2_z1_mu = torch.cat(pt_z2_z1_mus, dim=1)
        pt_z2_z1_logvar = torch.cat(pt_z2_z1_logvars, dim=1)

        # p_D(x2 | z2)
        pd_x2_z2 = self.x_z(qb_z2_b2)
        # p_D(g2 | z1, z2, a1', t2-t1)
        if self.rl:
            pd_g2_z2_mu = self.g_z(
                torch.cat([qs_z1_z2_b1, qb_z2_b2, a1_next, t_encodings],
                          dim=1))
        else:
            pd_g2_z2_mu = None

        return (x, actions, rewards, done, t1, t2, qs_z1_z2_b1_mu,
                qs_z1_z2_b1_logvar, pb_z1_b1_mu, pb_z1_b1_logvar, qb_z2_b2_mu,
                qb_z2_b2_logvar, qb_z2_b2, pt_z2_z1_mu, pt_z2_z1_logvar,
                pd_x2_z2, pd_g2_z2_mu, q1, q2)
    def q_and_z_b(self, x, actions, rewards, done, t1, t2):
        # pre-process image x
        im_x = x.view(-1, self.x_size[0], self.x_size[1], self.x_size[2])
        processed_x = self.process_x(im_x)  # max x length is max(t2) + 1
        processed_x = processed_x.view(x.shape[0], x.shape[1], -1)
        if actions is not None:
            rewards = (rewards[..., None] / 10.0).clamp(-1.0, 1.0)
            action_embs = self.action_embedding(actions)
            processed_x = torch.cat([processed_x, action_embs, rewards], -1)
        else:
            action_embs = None

        # aggregate the belief b
        b = self.b_rnn(processed_x, done)  # size: bs, time, layers, dim

        # replicate b multiple times
        b = b[None, ...].expand(self.samples_per_seq, -1, -1, -1,
                                -1)  # size: copy, bs, time, layers, dim

        # Element-wise indexing. sizes: bs, layers, dim
        b1 = torch.gather(
            b, 2, t1[..., None, None,
                     None].expand(-1, -1, -1, b.size(3),
                                  b.size(4))).view(-1, b.size(3), b.size(4))
        b2 = torch.gather(
            b, 2, t2[..., None, None,
                     None].expand(-1, -1, -1, b.size(3),
                                  b.size(4))).view(-1, b.size(3), b.size(4))

        # q_B(z2 | b2)
        qb_z2_b2_mus, qb_z2_b2_logvars, qb_z2_b2s = [], [], []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](b2[:, layer])
            else:
                qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](torch.cat(
                    [b2[:, layer], qb_z2_b2], dim=1))
            qb_z2_b2_mus.insert(0, qb_z2_b2_mu)
            qb_z2_b2_logvars.insert(0, qb_z2_b2_logvar)

            qb_z2_b2 = ops.reparameterize_gaussian(qb_z2_b2_mu,
                                                   qb_z2_b2_logvar,
                                                   self.training)
            qb_z2_b2s.insert(0, qb_z2_b2)

        qb_z2_b2_mu = torch.cat(qb_z2_b2_mus, dim=1)
        qb_z2_b2_logvar = torch.cat(qb_z2_b2_logvars, dim=1)
        qb_z2_b2 = torch.cat(qb_z2_b2s, dim=1)

        # p_B(z1 | b1)
        pb_z1_b1_mus, pb_z1_b1_logvars = [], []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                pb_z1_b1_mu, pb_z1_b1_logvar = self.z_b[layer](b1[:, layer])
            else:
                pb_z1_b1_mu, pb_z1_b1_logvar = self.z_b[layer](torch.cat(
                    [b1[:, layer], pb_z1_b1], dim=1))
            pb_z1_b1_mus.insert(0, pb_z1_b1_mu)
            pb_z1_b1_logvars.insert(0, pb_z1_b1_logvar)
            pb_z1_b1 = ops.reparameterize_gaussian(pb_z1_b1_mu,
                                                   pb_z1_b1_logvar,
                                                   self.training)

        pb_z1_b1_mu = torch.cat(pb_z1_b1_mus, dim=1)
        pb_z1_b1_logvar = torch.cat(pb_z1_b1_logvars, dim=1)

        if self.rl:
            # Q values
            q1 = self.q_z(pb_z1_b1_mu)
            q2 = self.q_z(qb_z2_b2_mu)
        else:
            q1, q2 = 0, 0

        return q1, q2, action_embs, b1, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2s, qb_z2_b2, pb_z1_b1_mu, pb_z1_b1_logvar
    def predictive_control(self,
                           x,
                           actions,
                           done,
                           rewards,
                           num_rollouts=100,
                           rollout_length=1,
                           option=None,
                           jump_length=10,
                           gamma=0.99,
                           boltzmann=True):
        with torch.no_grad():
            # pre-process image x
            im_x = x.view(-1, self.x_size[0], self.x_size[1], self.x_size[2])
            processed_x = self.process_x(im_x)  # max x length is max(t2) + 1
            processed_x = processed_x.view(x.shape[0], x.shape[1], -1)
            if actions is not None:
                rewards = (rewards[..., None] / 10.0).clamp(0.0, 2.0)
                action_embs = self.action_embedding(actions)
                processed_x = torch.cat([processed_x, action_embs, rewards],
                                        -1)
            else:
                action_embs = None

            # aggregate the belief b
            # size: bs, rollout, layers, dim
            b = self.b_rnn(processed_x, done)[:, -1]
            b = b[:, None].expand(-1, num_rollouts, -1, -1)

            # q_B(z2 | b2)
            qb_z2_b2_mus, qb_z2_b2_logvars, qb_z2_b2s = [], [], []
            for layer in range(self.layers - 1, -1, -1):
                if layer == self.layers - 1:
                    qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](b[:, :,
                                                                     layer])
                else:
                    qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](torch.cat(
                        [b[:, :, layer], qb_z2_b2], dim=-1))
                qb_z2_b2_mus.insert(0, qb_z2_b2_mu)
                qb_z2_b2_logvars.insert(0, qb_z2_b2_logvar)

                qb_z2_b2 = ops.reparameterize_gaussian(qb_z2_b2_mu,
                                                       qb_z2_b2_logvar,
                                                       self.training)
                qb_z2_b2s.insert(0, qb_z2_b2)

            initial = torch.cat(qb_z2_b2s, dim=-1)[:, -1].unsqueeze(1).expand(
                -1, num_rollouts, -1)

            distributions = Normal(0, 1)
            sizes = self.actor_critic.base.total_sizes
            parameters = distributions.sample(
                (x.shape[0], num_rollouts, int(np.sum(sizes)))).to(b.device)

            if option is None:
                current = initial
                running = 0
                jump_encoding = torch.tensor([jump_length],
                                             device=b.device,
                                             dtype=torch.long)
                jump_encoding = jump_encoding[None, None, :].expand(
                    current.shape[0], current.shape[1], -1)
                jump_encoding = self.time_encoding(jump_encoding).squeeze(-2)
                for i in range(rollout_length):
                    current, pd_g2, value = self.predict_forward(
                        current, parameters, jump_encoding)

                    pd_g2 = pd_g2 * (gamma**(i * jump_length))
                    running = pd_g2 + running

                running = value * (gamma**((i + 1) * jump_length)) + running

                print(running[0].mean().item(), running[0].var().item(),
                      running[0].max().item(), running[0].min().item())
                best = torch.max(running, 1)[1]
                indices = best[:, None, ].expand(-1, -1, parameters.shape[-1])
                option = torch.gather(parameters, 1, indices).squeeze(-2)
            action = apply_option(initial[:, 0], option, sizes)
            if boltzmann:
                print(F.softmax(action[0].flatten(), dim=-1))
                dist = Categorical(logits=action)
                action = dist.sample()
            else:
                action = torch.max(action, -1)[1]

        return action.cpu().numpy(), option
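
A hypothetical call site for predictive_control, assuming a trained model and batched observation, action, reward, and done histories; all names and shapes here are illustrative rather than taken from the original code.

# Hypothetical usage; `model` and the history tensors are assumed to exist.
# obs_history:    (bs, T, C, H, W) images
# action_history: (bs, T) discrete action ids, or None for the action-free case
# reward_history: (bs, T) scalar rewards
# done_history:   (bs, T) episode-termination flags for the belief RNN
action, option = model.predictive_control(
    obs_history,
    action_history,
    done_history,
    reward_history,
    num_rollouts=100,   # sampled option parameter vectors to evaluate
    rollout_length=1,   # jumpy prediction steps per rollout
    option=None,        # pass a previously selected option to skip the search
    jump_length=10,     # temporal jump encoded into each prediction step
)
# `action` is a numpy array of sampled (boltzmann=True) or argmax actions,
# and `option` can be cached and reused on subsequent calls.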
    def visualize(self, x, t, n, actions, rewards, done):
        # pre-process image x
        im_x = x.view(-1, self.x_size[0], self.x_size[1], self.x_size[2])
        processed_x = self.process_x(im_x)  # max x length is max(t2) + 1
        processed_x = processed_x.view(x.shape[0], x.shape[1], -1)

        if actions is not None:
            rewards = (rewards[..., None] / 10.0).clamp(0.0, 2.0)
            action_embs = self.action_embedding(actions)
            processed_x = torch.cat([processed_x, action_embs, rewards], -1)
        else:
            action_embs = None

        # aggregate the belief b
        full_b = self.b_rnn(processed_x, done)  # size: bs, time, layers, dim
        b = full_b[:, t]  # Just pick out relevant time
        t_encodings = self.time_encoding(
            b.new_zeros(b.size(0),
                        dtype=torch.long)) * self.time_encoding_scale
        t1 = torch.zeros(x.shape[0], device=x.device)
        t2 = torch.zeros(x.shape[0], device=x.device) + x.shape[1] - 1
        option, _, _, _ = self.option_reconstruction(full_b, actions,
                                                     t1.unsqueeze(0).long(),
                                                     t2.unsqueeze(0).long())
        processed_option = self.option_embedding(option)
        # compute z from b
        p_z_bs = []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                p_z_b_mu, p_z_b_logvar = self.z_b[layer](b[:, layer])
            else:
                p_z_b_mu, p_z_b_logvar = self.z_b[layer](torch.cat(
                    [b[:, layer], p_z_b], dim=1))
            p_z_b = ops.reparameterize_gaussian(p_z_b_mu, p_z_b_logvar, True)
            p_z_bs.insert(0, p_z_b_mu)

        z = torch.cat(p_z_bs, dim=1)
        rollout_x = [self.x_z(z)]
        for i in range(n - 1):
            next_z = []
            for layer in range(self.layers - 1, -1, -1):
                if layer == self.layers - 1:
                    if actions is not None:
                        inputs = torch.cat([z, t_encodings, processed_option],
                                           dim=1)
                    else:
                        inputs = torch.cat([z, t_encodings], dim=1)
                    pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](inputs)
                else:
                    if actions is not None:
                        inputs = torch.cat(
                            [z, pt_z2_z1, t_encodings, processed_option],
                            dim=1)
                    else:
                        inputs = torch.cat([z, pt_z2_z1, t_encodings], dim=1)
                    pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](inputs)
                pt_z2_z1 = ops.reparameterize_gaussian(pt_z2_z1_mu,
                                                       pt_z2_z1_logvar, True)
                next_z.insert(0, pt_z2_z1_mu)

            z = torch.cat(next_z, dim=1)
            rollout_x.append(self.x_z(z))

        return torch.stack(rollout_x, dim=1)