Example #1
class RecKoopmanModel(BaseModel):
    def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim,
                 n_objects, I_factor=10, n_blocks=1, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        self.u_dim = u_dim

        # Temporal encoding buffers
        if n_timesteps > 1:
            t = torch.linspace(-1, 1, n_timesteps)
            # Add as constant, with extra dims for N and C
            self.register_buffer('t_grid', t)

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim
        self.n_objects = n_objects

        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

        self.linear_g = nn.Linear(g_dim, g_dim)
        self.linear_u = nn.Linear(g_dim, u_dim)
        self.initial_conditions = nn.Sequential(nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps),
                                                nn.ReLU(),
                                                nn.Linear(feat_dim * n_timesteps, g_dim * 2))

        self.image_encoder = ImageEncoder(in_channels, feat_dim * 2, n_objects, ngf, n_layers)  # feat_dim * 2 if sample here
        self.image_decoder = ImageDecoder(g_dim, out_channels, ngf, n_layers)
        self.koopman = KoopmanOperators(feat_dim * 2, nf_particle * 2, nf_effect * 2, g_dim, u_dim, n_timesteps, n_blocks)

        #Object attention:
        # if n_objects > 1:
        #     self.obj_attention = ObjectAttention(in_channels, feat_dim, n_objects)

    def _get_full_state(self, x, T):

        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.cat([x[:, t + idx]
                         for idx in range(self.n_timesteps)], dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=2)

        return new_x.reshape(-1, *new_x.shape[3:]), new_T

    def forward(self, input):
        bs, T, ch, h, w = input.shape

        f = self.image_encoder(input)
        # Concatenates the features from current time and previous to create full state
        f, T = self._get_full_state(f, T)
        # TODO: Might be bug. I'm reshaping with objects before T.
        f = f.view(torch.Size([bs * self.n_objects, T]) + f.size()[1:])

        f_mu, f_logvar, a_mu, a_logvar = [], [], [], []
        gs, us = [], []
        for t in range(T):
            if t==0:
                prev_sample = torch.zeros_like(f[:, 0, :self.g_dim])
                f_t = torch.cat([f[:, t], prev_sample], dim=-1)
                g = self.koopman.to_g(f_t, self.psteps)
                # g = self.initial_conditions(f[:, t])
            else:
                f_t = torch.cat([f[:, t], prev_sample], dim=-1)
                g = self.koopman.to_g(f_t, self.psteps)
                # TODO: try the recurrent version, I suppose
                # g, prev_hidden, u = self.koopman.to_g(f_t, prev_hidden) #,psteps=self.psteps
            g_mu, g_logvar = torch.chunk(g, 2, dim=-1)
            g = _sample_latent_simple(g_mu, g_logvar)

            # if t < T-1:
            #     u_mu, u_logvar = torch.chunk(u, 2, dim=-1)
            #     u = _sample_latent_simple(u_mu, u_logvar)
            # else:
            #     u_mu, u_logvar = torch.chunk(torch.zeros_like(u), 2, dim=-1)
            #     u = u_mu
            # g, u = self.linear_g(g), self.sigmoid(self.linear_u(u))
            g, u = self.linear_g(g), self.sigmoid(self.linear_u(g))

            prev_sample = g

            gs.append(g)
            us.append(u)

            f_mu.append(g_mu)
            f_logvar.append(g_logvar)
            # a_mu.append(u_mu)
            # a_logvar.append(u_logvar)

        g = torch.stack(gs, dim=1)
        f_mu = torch.stack(f_mu, dim=1)
        f_logvar = torch.stack(f_logvar, dim=1)

        u = torch.stack(us, dim=1)
        # u_zeros = torch.zeros_like(u)
        # u = torch.where(u > 0.8, u, u_zeros)
        # a_mu = torch.stack(a_mu, dim=1)
        # a_logvar = torch.stack(a_logvar, dim=1)

        free_pred = 0
        if free_pred > 0:
            G_tilde = g[:, 1:-1-free_pred, None]  # new axis corresponding to N number of objects
            H_tilde = g[:, 2:-free_pred, None]
        else:
            G_tilde = g[:, 1:-1, None]  # new axis corresponding to N number of objects
            H_tilde = g[:, 2:, None]

        A, B, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, U=u[:, 1:-1], I_factor=self.I_factor)
        # TODO: clamp A before invert. Threshold u over 0.9

        # A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(G=G_tilde, H=H_tilde, I_factor=self.I_factor)
        # TODO: Try simulating backwards
        # B=None
        # Rollout. From observation in time 0, predict with Koopman operator
        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B)
        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None)

        '''Simple version of the reverse rollout'''
        # G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A_pinv, B=None)
        # G_for_pred = torch.flip(G_for_pred_rev, dims=[1])

        G_for_pred = torch.cat([g[:,0:1],self.koopman.simulate(T=T - 2, g=g[:, 1], u=u, A=A, B=B)], dim=1)


        # Option 1: use the koopman object decoder
        # s_for_rec = self.koopman.to_s(gcodes=_get_flat(g),
        #                               pstep=self.psteps)
        # s_for_pred = self.koopman.to_s(gcodes=_get_flat(torch.cat([g[:, :1],G_for_pred], dim=1)),
        #                                pstep=self.psteps)
        # Option 2: we don't use the koopman object decoder
        s_for_rec = _get_flat(g)
        s_for_pred = _get_flat(G_for_pred)

        # Convolutional decoder. Normally Spatial Broadcasting decoder
        out_rec = self.image_decoder(s_for_rec)
        out_pred = self.image_decoder(s_for_pred)

        returned_g = torch.cat([g, G_for_pred], dim=1)
        returned_mus = torch.cat([f_mu], dim=-1)
        returned_logvars = torch.cat([f_logvar], dim=-1)
        # returned_mus = torch.cat([f_mu, a_mu], dim=-1)
        # returned_logvars = torch.cat([f_logvar, a_logvar], dim=-1)

        o_touple = (out_rec, out_pred, returned_g.reshape(-1, returned_g.size(-1)),
                    returned_mus.reshape(-1, returned_mus.size(-1)),
                    returned_logvars.reshape(-1, returned_logvars.size(-1)))
                    # f_mu.reshape(-1, f_mu.size(-1)),
                    # f_logvar.reshape(-1, f_logvar.size(-1)))
        o = [item.reshape(torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple]
        # Option 1: one object mapped to 0
        # shape_o0 = o[0].shape
        # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:])
        # o[0][:,0] = o[0][:,0]*0
        # o[0] = o[0].reshape(*shape_o0)

        # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1) for item in o[:2]]
        o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]]
        o[3:5] = [item.reshape(bs, self.n_objects, *item.shape[1:]) for item in o[3:5]]

        o.append(A)
        o.append(B)
        o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1]))
        # # Show images
        # plt.imshow(input[0, 0, :].reshape(16, 16).cpu().detach().numpy())
        # plt.savefig('test_attention.png')
        return o
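A minimal smoke-test sketch for the model above; the constructor values and tensor shapes below are illustrative assumptions, and ImageEncoder, ImageDecoder, KoopmanOperators, _sample_latent_simple and _get_flat must come from the surrounding project.

import torch

# Hedged usage sketch: all dimensions are assumptions, not values from the project config.
model = RecKoopmanModel(in_channels=1, feat_dim=32, nf_particle=16, nf_effect=16,
                        g_dim=16, u_dim=4, n_objects=2, n_timesteps=1,
                        image_size=[64, 64])
video = torch.randn(2, 10, 1, 64, 64)   # (batch, time, channels, height, width)
outputs = model(video)                   # [rec, pred, g, mu, logvar, A, B, u]
rec, pred = outputs[0], outputs[1]       # already summed over objects (see o[:2] above)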
Example #2
class cswm_KoopmanModel(BaseModel):
    def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim,
                 I_factor=10, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1
        num_objects = 1

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim

        self.image_encoder = EncoderCNNLarge(
            input_dim=in_channels,
            hidden_dim=feat_dim // 8,
            output_dim=feat_dim,
            num_objects=num_objects)

        self.obj_encoder = EncoderMLP(
            input_dim=np.prod(image_size) // (8 ** 2) * feat_dim,
            hidden_dim=feat_dim // 2,
            output_dim=feat_dim,
            num_objects=num_objects)

        self.state_encoder = EncoderMLP(
            input_dim=feat_dim * self.n_timesteps,
            hidden_dim=feat_dim,
            output_dim=g_dim,
            num_objects=num_objects)

        self.state_decoder = EncoderMLP(
            input_dim=g_dim,
            hidden_dim=feat_dim,
            output_dim=feat_dim,
            num_objects=num_objects)

        # self.transition_model = TransitionGNN(
        #     input_dim=embedding_dim,
        #     hidden_dim=hidden_dim,
        #     action_dim=action_dim,
        #     num_objects=num_objects,
        #     ignore_action=ignore_action,
        #     copy_action=copy_action)

        self.image_decoder = SBImageDecoder(feat_dim, out_channels, ngf, n_layers, image_size)
        self.koopman = KoopmanOperators(self.state_dim, nf_particle, nf_effect, g_dim, n_timesteps)

    def _add_positional_encoding(self, x):
        x = torch.cat((self.x_grid_enc.expand(*x.shape[:2], -1, -1, -1),
                       self.y_grid_enc.expand(*x.shape[:2], -1, -1, -1),
                       x), dim=-3)

        return x

    def _get_full_state(self, x, T):
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1))

        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, *new_x.shape[2:]), new_T

    def forward(self, input):
        bs, T, ch, h, w = input.shape

        cnnf = self.image_encoder(input.reshape(-1, ch, h, w))
        f = self.obj_encoder(cnnf)

        # Note: test image encoder with attention
        # f_mu, f_logvar = torch.chunk(self.att_image_encoder(input.reshape(-1, ch, h, w)), 2, dim=-1)
        # f = _sample(f_mu, f_logvar)

        # input = self._add_positional_encoding(input)

        f, T = self._get_full_state(f, T)
        # f = f.view(torch.Size([bs, T]) + f.size()[1:])
        # g, G_for_pred = f, f

        g = self.koopman.to_g(f, self.psteps)
        g = g.view(torch.Size([bs, T]) + g.size()[1:])

        G_tilde = g[:, :-1, None]  # new axis corresponding to N number of objects
        H_tilde = g[:, 1:, None]

        A, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, I_factor=self.I_factor)  # TODO: maybe

        ''' rollout: BT x D '''
        G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], A=A)  # Rollout from initial observation to the last

        s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), pstep=self.psteps)
        s_for_pred = self.koopman.to_s(gcodes=_get_flat(G_for_pred), pstep=self.psteps)

        out_rec = self.image_decoder(s_for_rec)
        out_pred = self.image_decoder(s_for_pred)

        # f_mu / f_logvar are only produced by the commented-out variational encoder above;
        # with this deterministic encoder, fall back to the features as mean with zero log-variance.
        f_mu, f_logvar = f, torch.zeros_like(f)
        o_touple = (out_rec, out_pred, g.reshape(-1, g.size(-1)), f_mu.reshape(-1, f_mu.size(-1)),
                    f_logvar.reshape(-1, f_logvar.size(-1)))
        o = [item.reshape(torch.Size([bs, -1]) + item.size()[1:]) for item in o_touple]

        return o
Example #3
class KoopmanModel(BaseModel):
    def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim,
                 n_objects, I_factor=10, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        # Positional encoding buffers
        x = torch.linspace(-1, 1, image_size[0])
        y = torch.linspace(-1, 1, image_size[1])
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid_enc', x_grid.view((1, 1, 1) + x_grid.shape))
        self.register_buffer('y_grid_enc', y_grid.view((1, 1, 1) + y_grid.shape))

        # Temporal encoding buffers
        if n_timesteps > 1:
            t = torch.linspace(-1, 1, n_timesteps)
            # Add as constant, with extra dims for N and C
            self.register_buffer('t_grid', t)

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim
        self.n_objects = n_objects

        self.image_encoder = ImageEncoder(in_channels, feat_dim * 2, n_objects, ngf, n_layers)  # feat_dim * 2 if sample here
        self.image_decoder = ImageDecoder(g_dim, out_channels, ngf, n_layers)
        self.koopman = KoopmanOperators(feat_dim * 2, nf_particle * 2, nf_effect * 2, g_dim, n_timesteps)

        #Object attention:
        # if n_objects > 1:
        #     self.obj_attention = ObjectAttention(in_channels, feat_dim, n_objects)

    def _get_full_state(self, x, T):

        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.cat([x[:, t + idx]
                         for idx in range(self.n_timesteps)], dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1).permute(0, 2, 1, 3)

        return new_x.reshape(-1, *new_x.shape[3:]), new_T

    def forward(self, input):
        bs, T, ch, h, w = input.shape

        # input = self._add_positional_encoding(input)
        f = self.image_encoder(input)

        # # Concatenates the features from current time and previous to create full state
        f, T = self._get_full_state(f, T)

        g = self.koopman.to_g(f, self.psteps)
        g = g.view(torch.Size([bs * self.n_objects, T]) + g.size()[1:])

        free_pred = 0 # TODO: Set it to >1 after training for several epochs.
        if free_pred > 0:
            G_tilde = g[:, :-1-free_pred, None]  # new axis corresponding to N number of objects
            H_tilde = g[:, 1:-free_pred, None]
        else:
            G_tilde = g[:, :-1, None]  # new axis corresponding to N number of objects
            H_tilde = g[:, 1:, None]

        # Find Koopman operator from current data (matrix a)
        # TODO: Predict longer from an A obtained for a shorter length.
        A, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, I_factor=self.I_factor)

        # Rollout. From observation in time 0, predict with Koopman operator
        G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], A=A)  # Rollout from initial observation to the last

        # # Note: Ignore - If we have split the representation: Merge constant and dynamic components
        # if g.size(-1) < self.g_dim:
        #     g = torch.cat([g, g_cte[:, 0:1].repeat(1, T, 1)], dim=-1)
        #     G_for_pred = torch.cat([G_for_pred, g_cte[:, 0:1].repeat(1, T - 1, 1)], dim=-1)


        g_mu, g_logvar = torch.chunk(g, 2, dim=-1)
        g_mu_pred, g_logvar_pred = torch.chunk(G_for_pred, 2, dim=-1)
        g = _sample_latent_simple(g_mu, g_logvar)
        G_for_pred = _sample_latent_simple(g_mu_pred, g_logvar_pred)
        f_mu, f_logvar = g_mu, g_logvar

        # Option 1: use the koopman object decoder
        # s_for_rec = self.koopman.to_s(gcodes=_get_flat(g),
        #                               pstep=self.psteps)
        # s_for_pred = self.koopman.to_s(gcodes=_get_flat(torch.cat([g[:, :1],G_for_pred], dim=1)),
        #                                pstep=self.psteps)
        # Option 2: we don't use the koopman object decoder
        s_for_rec = _get_flat(g)
        s_for_pred = _get_flat(G_for_pred)

        # Convolutional decoder. Normally Spatial Broadcasting decoder
        out_rec = self.image_decoder(s_for_rec)
        out_pred = self.image_decoder(s_for_pred)

        returned_mus = torch.cat([f_mu, g_mu_pred], dim=1)
        returned_logvars = torch.cat([f_logvar, g_logvar_pred], dim=1)
        o_touple = (out_rec, out_pred, returned_mus.reshape(-1, returned_mus.size(-1)),
                    returned_mus.reshape(-1, returned_mus.size(-1)),
                    returned_logvars.reshape(-1, returned_logvars.size(-1)))
                    # f_mu.reshape(-1, f_mu.size(-1)),
                    # f_logvar.reshape(-1, f_logvar.size(-1)))
        o = [item.reshape(torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple]
        # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1.5) for item in o[:2]]
        o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]]

        # A_to_show = self.block_diagonal(A)
        o.append(A[..., :self.g_dim, :self.g_dim])
        # if self.n_objects > 1:
        #     o.append(o_a)
        # Returns reconstruction, prediction, g_representation, mu and sigma from which we sample to compute KL divergence

        # # Show images
        # plt.imshow(input[0, 0, :].reshape(16, 16).cpu().detach().numpy())
        # plt.savefig('test_attention.png')
        return o
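system_identify (called above with G, H and I_factor) lives in KoopmanOperators, which is not part of this listing. The standalone function below is only a plausible sketch of such a fit, assuming a ridge-regularized least-squares solve; the name fit_koopman_ridge and the exact regularization are assumptions rather than the project's implementation.

import torch

def fit_koopman_ridge(G, H, I_factor=10):
    # G: (bs, T-1, D) observations g_t, H: (bs, T-1, D) observations g_{t+1}.
    # Solves H ≈ G @ A via A = (G^T G + (1/I_factor) I)^{-1} G^T H, per batch element.
    bs, _, D = G.shape
    I = torch.eye(D, device=G.device).expand(bs, D, D)
    GtG = torch.bmm(G.transpose(1, 2), G)             # (bs, D, D)
    GtH = torch.bmm(G.transpose(1, 2), H)             # (bs, D, D)
    A = torch.linalg.solve(GtG + I / I_factor, GtH)   # (bs, D, D); acts as g_{t+1} = g_t @ A
    fit_err = (torch.bmm(G, A) - H).pow(2).mean()
    return A, fit_err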
Example #4
class SingleObjKoopmanModel(BaseModel):
    def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim,
                 I_factor=10, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        # Positional encoding buffers
        x = torch.linspace(-1, 1, image_size[0])
        y = torch.linspace(-1, 1, image_size[1])
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid_enc', x_grid.view((1, 1, 1) + x_grid.shape))
        self.register_buffer('y_grid_enc', y_grid.view((1, 1, 1) + y_grid.shape))

        # Temporal encoding buffers
        if n_timesteps > 1:
            t = torch.linspace(-1, 1, n_timesteps)
            # Add as constant, with extra dims for N and C
            self.register_buffer('t_grid', t)

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim

        # self.att_image_encoder = AttImageEncoder(in_channels, feat_dim, ngf, n_layers)
        self.image_encoder = ImageEncoder(in_channels, feat_dim * 4, ngf, n_layers)  # feat_dim * 2 if sample here
        self.image_decoder = ImageDecoder(g_dim, out_channels, ngf, n_layers)
        # self.image_decoder = SimpleSBImageDecoder(feat_dim, out_channels, ngf, n_layers, image_size)
        self.koopman = KoopmanOperators(self.state_dim * 4, nf_particle, nf_effect, g_dim, n_timesteps)

    def _add_positional_encoding(self, x):

        x = torch.cat((self.x_grid_enc.expand(*x.shape[:2], -1, -1, -1),
                       self.y_grid_enc.expand(*x.shape[:2], -1, -1, -1),
                       x), dim=-3)

        return x

    def _get_full_state(self, x, T):

        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.cat([x[:, t + idx]
                         for idx in range(self.n_timesteps)], dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, *new_x.shape[2:]), new_T

    def forward(self, input):
        bs, T, ch, h, w = input.shape

        # Note: test image encoder with attention
        # f_mu, f_logvar = torch.chunk(self.att_image_encoder(input.reshape(-1, ch, h, w)), 2, dim=-1)
        # f = _sample(f_mu, f_logvar)

        # Note: Koopman applied to mu and logvar
        # f = self.att_image_encoder(input.reshape(-1, ch, h, w))

        # input = self._add_positional_encoding(input)
        f = self.image_encoder(input.reshape(-1, ch, h, w)) #Note: deterministic AE
        # f_mu, f_logvar = f, f
        # f_mu, f_logvar = torch.chunk(self.image_encoder(input.reshape(-1, ch+2, h, w)), 2, dim=-1)
        # f = _sample(f_mu, f_logvar)
        # f = f_mu

        # # Concatenates the features from current time and previous to create full state
        f, T = self._get_full_state(f, T)
        # f = f.view(torch.Size([bs, T]) + f.size()[1:])
        # g, G_for_pred = f, f

        g = self.koopman.to_g(f, self.psteps)
        g = g.view(torch.Size([bs, T]) + g.size()[1:])
        # g = f.view(torch.Size([bs, T]) + f.size()[1:])

        # # Note: Ignore - Split representation into constant and dynamic by hardcoding
        # g, g_cte = torch.chunk(g, 2, dim=-1) # Option 1
        # g, g_cte = g[..., :2], g[..., 2:] # Option 2

        free_pred = 3
        G_tilde = g[:, :-1-free_pred, None]  # new axis corresponding to N number of objects
        H_tilde = g[:, 1:-free_pred, None]

        # Find Koopman operator from current data (matrix a)
        # TODO: Predict longer from an A obtained for a shorter length.
        A, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, I_factor=self.I_factor)

        # Rollout. From observation in time 0, predict with Koopman operator
        G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], A=A)  # Rollout from initial observation to the last

        # # Note: Ignore - If we have split the representation: Merge constant and dynamic components
        # if g.size(-1) < self.g_dim:
        #     g = torch.cat([g, g_cte[:, 0:1].repeat(1, T, 1)], dim=-1)
        #     G_for_pred = torch.cat([G_for_pred, g_cte[:, 0:1].repeat(1, T - 1, 1)], dim=-1)


        g_mu, g_logvar = torch.chunk(g, 2, dim=-1)
        g_mu_pred, g_logvar_pred = torch.chunk(G_for_pred, 2, dim=-1)
        g = _sample_latent_simple(g_mu, g_logvar)
        G_for_pred = _sample_latent_simple(g_mu_pred, g_logvar_pred)
        f_mu, f_logvar = g_mu, g_logvar

        # Option 1: use the koopman object decoder
        # s_for_rec = self.koopman.to_s(gcodes=_get_flat(g),
        #                               pstep=self.psteps)
        # s_for_pred = self.koopman.to_s(gcodes=_get_flat(torch.cat([g[:, :1],G_for_pred], dim=1)),
        #                                pstep=self.psteps)
        # Option 2: we don't use the koopman object decoder
        s_for_rec = _get_flat(g)
        s_for_pred = _get_flat(G_for_pred)

        # Convolutional decoder. Normally Spatial Broadcasting decoder
        out_rec = self.image_decoder(s_for_rec)
        out_pred = self.image_decoder(s_for_pred)

        o_touple = (out_rec, out_pred, f_mu.reshape(-1, f_mu.size(-1)), f_mu.reshape(-1, f_mu.size(-1)),
                    f_logvar.reshape(-1, f_logvar.size(-1)))
        o = [item.reshape(torch.Size([bs, -1]) + item.size()[1:]) for item in o_touple]
        # Returns reconstruction, prediction, g_representation, mu and sigma from which we sample to compute KL divergence

        # # Show images
        # plt.imshow(input[0, 0, :].reshape(16, 16).cpu().detach().numpy())
        # plt.savefig('test_attention.png')
        return o
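_sample_latent_simple and _get_flat are called throughout these examples but are not shown in the listing. Minimal sketches consistent with how they are used, assuming the standard reparameterization trick and a flatten over the leading batch/time dimensions:

import torch

def _sample_latent_simple(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)

def _get_flat(x):
    # Merge the leading (batch, time) dimensions into one big batch for the decoder
    return x.reshape(-1, *x.shape[2:])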
Example #5
class RecKoopmanModel(BaseModel):
    def __init__(self,
                 in_channels,
                 feat_dim,
                 nf_particle,
                 nf_effect,
                 g_dim,
                 u_dim,
                 n_objects,
                 I_factor=10,
                 psteps=1,
                 n_timesteps=1,
                 ngf=8,
                 image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        self.u_dim = u_dim
        self.linear_g = nn.Linear(g_dim, g_dim)
        self.linear_u = nn.Linear(g_dim, u_dim)

        # Temporal encoding buffers
        if n_timesteps > 1:
            t = torch.linspace(-1, 1, n_timesteps)
            # Add as constant, with extra dims for N and C
            self.register_buffer('t_grid', t)

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim
        self.n_objects = n_objects
        self.n_directions = 2

        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

        self.initial_conditions = nn.Sequential(
            nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps),
            nn.ReLU(), nn.Linear(feat_dim * n_timesteps, g_dim * 2))

        self.image_encoder = ImageEncoder(
            in_channels, feat_dim * 2, n_objects, ngf,
            n_layers)  # feat_dim * 2 if sample here
        self.image_decoder = ImageDecoder(g_dim, out_channels, ngf, n_layers)
        self.koopman = KoopmanOperators(feat_dim * 2, nf_particle * 2,
                                        nf_effect * 2, g_dim, u_dim,
                                        n_timesteps)

        #Object attention:
        # if n_objects > 1:
        #     self.obj_attention = ObjectAttention(in_channels, feat_dim, n_objects)

    def _get_full_state(self, x, T):

        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(
                torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)],
                          dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=2)
        return new_x.reshape(-1, *new_x.shape[3:]), new_T

    def forward(self, input):
        bs, T, ch, h, w = input.shape

        f = self.image_encoder(input, reverse=True)
        # Concatenates the features from current time and previous to create full state
        f, T = self._get_full_state(f, T)
        f = f.reshape(self.n_directions * bs * self.n_objects, T, f.shape[-1])

        f_mu, f_logvar, a_mu, a_logvar = [], [], [], []
        gs, us = [], []
        for t in range(T):
            if t == 0:
                # f_t = self.initial_conditions(f[:, t])
                prev_sample = torch.zeros_like(f[:, 0, :self.g_dim])
                # f_t = torch.cat([f_ini, prev_sample], dim=-1)
                prev_hidden = None

                f_t = torch.cat([f[:, t], prev_sample], dim=-1)
                g, u = self.koopman.to_g(f_t, psteps=self.psteps)
            else:
                f_t = torch.cat([f[:, t], prev_sample], dim=-1)
                # g, prev_hidden, u = self.koopman.to_g(f_t, prev_hidden) #,psteps=self.psteps
                g, u = self.koopman.to_g(f_t, psteps=self.psteps)
            g_mu, g_logvar = torch.chunk(g, 2, dim=-1)
            g = _sample_latent_simple(g_mu, g_logvar)
            # if t < T-1:
            #     u_mu, u_logvar = torch.chunk(u, 2, dim=-1)
            #     u = _sample_latent_simple(u_mu, u_logvar)
            # else:
            #     u_mu, u_logvar = torch.chunk(torch.zeros_like(u), 2, dim=-1)
            #     u = u_mu
            # g, u = self.linear_g(g), self.sigmoid(self.linear_u(u))
            g, u = self.linear_g(g), self.linear_u(g)

            prev_sample = g

            gs.append(g)
            us.append(u)

            f_mu.append(g_mu)
            f_logvar.append(g_logvar)
            # a_mu.append(u_mu)
            # a_logvar.append(u_logvar)

        g = torch.stack(gs, dim=1)
        f_mu = torch.stack(f_mu, dim=1)
        f_logvar = torch.stack(f_logvar, dim=1)

        u = torch.stack(us, dim=1)
        # a_mu = torch.stack(a_mu, dim=1)
        # a_logvar = torch.stack(a_logvar, dim=1)

        reverse = True
        if reverse:
            g_rsh = g.view(
                torch.Size([bs, self.n_directions, self.n_objects, T]) +
                g.size()[2:])
            g_rev = g_rsh[:, 0].reshape(bs * self.n_objects, T,
                                        g_rsh.shape[-1])
            g = g_rsh[:, 1].reshape(bs * self.n_objects, T, g_rsh.shape[-1])

        free_pred = 0
        if free_pred > 0:
            G_tilde = g[:, :-1 - free_pred,
                        None]  # new axis corresponding to N number of objects
            H_tilde = g[:, 1:-free_pred, None]
        else:
            G_tilde = g[:, :-1,
                        None]  # new axis corresponding to N number of objects
            H_tilde = g[:, 1:, None]

        # A, B, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, U=u[:, :-1], I_factor=self.I_factor)
        A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(
            G=G_tilde, H=H_tilde, I_factor=self.I_factor)

        # Rollout. From observation in time 0, predict with Koopman operator
        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B)
        G_for_pred = self.koopman.simulate(T=T - 1,
                                           g=g[:, 0],
                                           u=None,
                                           A=A,
                                           B=None)
        G_for_pred_rev = self.koopman.simulate(T=T - 1,
                                               g=g_rev[:, 0],
                                               u=None,
                                               A=A_pinv,
                                               B=None)
        B = None
        # G_for_pred = torch.cat([g[:,0:1],self.koopman.simulate(T=T - 2, g=g[:, 1], A=A)], dim=1)

        # Option 1: use the koopman object decoder
        # s_for_rec = self.koopman.to_s(gcodes=_get_flat(g),
        #                               pstep=self.psteps)
        # s_for_pred = self.koopman.to_s(gcodes=_get_flat(torch.cat([g[:, :1],G_for_pred], dim=1)),
        #                                pstep=self.psteps)
        # Option 2: we don't use the koopman object decoder
        g_all = torch.stack([g, g_rev], dim=1).reshape(
            self.n_directions * bs * self.n_objects, -1, g.shape[-1])
        G_pred_all = torch.stack([G_for_pred, G_for_pred_rev], dim=1).reshape(
            self.n_directions * bs * self.n_objects, -1, G_for_pred.shape[-1])
        s_for_rec = _get_flat(g_all)
        s_for_pred = _get_flat(G_pred_all)

        # Convolutional decoder. Normally Spatial Broadcasting decoder
        out_rec = self.image_decoder(s_for_rec)
        out_rec = out_rec.reshape(
            torch.Size([bs, self.n_objects, self.n_directions, -1]) +
            out_rec.size()[1:])
        out_pred = self.image_decoder(s_for_pred)
        out_pred = out_pred.reshape(
            torch.Size([bs, self.n_objects, self.n_directions, -1]) +
            out_pred.size()[1:])

        returned_g = torch.cat([g, G_for_pred], dim=1)
        returned_g_rev = torch.cat([g_rev, G_for_pred_rev], dim=1)

        # returned_mus = torch.cat([f_mu], dim=-1)
        returned_mus = f_mu.reshape(bs, self.n_directions, self.n_objects,
                                    *f_mu.shape[1:])
        returned_mus = returned_mus.reshape(bs, self.n_directions,
                                            self.n_objects, -1,
                                            returned_mus.shape[-1])

        # returned_logvars = torch.cat([f_logvar], dim=-1)
        returned_logvars = f_logvar.reshape(bs, self.n_directions,
                                            self.n_objects,
                                            *f_logvar.shape[1:])
        returned_logvars = returned_logvars.reshape(bs, self.n_directions,
                                                    self.n_objects, -1,
                                                    returned_logvars.shape[-1])

        # returned_mus = torch.cat([f_mu, a_mu], dim=-1)
        # returned_logvars = torch.cat([f_logvar, a_logvar], dim=-1)

        o = [out_rec, out_pred, returned_g, returned_mus, returned_logvars]

        # Option 1: one object mapped to 0
        # shape_o0 = o[0].shape
        # o[0] = o[0].reshape(bs, self.n_objects, 2, *o[0].shape[1:])
        # o[0][:,0] = o[0][:,0]*0
        # o[0] = o[0].reshape(*shape_o0)

        # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, self.n_directions, *item.shape[3:]), dim=1), min=0, max=1) for item in o[:2]]
        o[:2] = [
            torch.sum(item.reshape(bs, self.n_objects, self.n_directions,
                                   *item.shape[3:]),
                      dim=1) for item in o[:2]
        ]
        o[:2] = [
            item.reshape(bs * self.n_directions, *item.shape[2:])
            for item in o[:2]
        ]

        o.append(A)
        o.append(B)
        o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1]))
        o.append(
            returned_g_rev.reshape(bs * self.n_objects, -1, g_rev.shape[-1]))

        return o
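This example rolls the latent state forward with A and backward with the pseudo-inverse A_pinv, then decodes both directions. A toy check of that convention, assuming the rollout is the plain linear step g_{t+1} = g_t @ A used by the simpler examples:

import torch

# Made-up 2x2 dynamics; A_pinv equals the inverse since this A is invertible
A = torch.tensor([[0.9, 0.1], [-0.1, 0.9]])
A_pinv = torch.linalg.pinv(A)

g0 = torch.tensor([[1.0, 0.0]])
forward = [g0]
for _ in range(4):
    forward.append(forward[-1] @ A)          # forward rollout with A

backward = [forward[-1]]
for _ in range(4):
    backward.append(backward[-1] @ A_pinv)   # reverse rollout with the pseudo-inverse

# Flipping the reverse rollout recovers the forward trajectory (up to numerical error)
print(torch.allclose(torch.cat(forward), torch.cat(list(reversed(backward))), atol=1e-5))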
Example #6
class RecKoopmanModel(BaseModel):
    def __init__(self,
                 in_channels,
                 feat_dim,
                 nf_particle,
                 nf_effect,
                 g_dim,
                 u_dim,
                 n_objects,
                 I_factor=10,
                 n_blocks=1,
                 psteps=1,
                 n_timesteps=1,
                 ngf=8,
                 image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        self.u_dim = u_dim

        # Temporal encoding buffers
        if n_timesteps > 1:
            t = torch.linspace(-1, 1, n_timesteps)
            # Add as constant, with extra dims for N and C
            self.register_buffer('t_grid', t)

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim
        self.n_objects = n_objects

        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

        # self.linear_g = nn.Linear(g_dim, g_dim)

        # self.initial_conditions = nn.Sequential(nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps),
        #                                         nn.ReLU(),
        #                                         nn.Linear(feat_dim * n_timesteps, g_dim * 2))

        feat_dyn_dim = feat_dim // 6
        self.feat_dyn_dim = feat_dyn_dim
        self.content = None
        self.reverse = False
        self.hankel = True
        self.ini_alpha = 1
        self.incr_alpha = 0.5

        # self.linear_u = nn.Linear(g_dim + u_dim, u_dim * 2)
        # self.linear_u_2_f = nn.Linear(feat_dyn_dim * n_timesteps, u_dim * 2)
        # self.linear_u_T_f = nn.Linear(feat_dyn_dim * n_timesteps, u_dim)

        if self.hankel:
            # self.linear_u_2_g = nn.Linear(g_dim * self.n_timesteps, u_dim * 2)
            # self.linear_u_2_g = nn.Sequential(nn.Linear(g_dim * self.n_timesteps, g_dim),
            #                                    nn.ReLU(),
            #                                    nn.Linear(g_dim, u_dim * 2))
            # self.linear_u_all_g = nn.Sequential(nn.Linear(g_dim * self.n_timesteps, g_dim),
            #                                   nn.ReLU(),
            #                                   nn.Linear(g_dim, u_dim + 1))
            # self.gru_u_all_g = nn.GRU(g_dim, u_dim + 1, num_layers = 2, batch_first=True)
            self.linear_u_1_g = nn.Sequential(
                nn.Linear(g_dim * self.n_timesteps, g_dim), nn.ReLU(),
                nn.Linear(g_dim, u_dim))
        else:
            self.linear_u_2_g = nn.Linear(g_dim, u_dim * 2)

        self.linear_f_mu = nn.Linear(feat_dim * 2, feat_dim)
        self.linear_f_logvar = nn.Linear(feat_dim * 2, feat_dim)

        self.image_encoder = ImageEncoder(
            in_channels, feat_dim * 2, n_objects, ngf,
            n_layers)  # feat_dim * 2 if sample here
        self.image_decoder = ImageDecoder(feat_dim, out_channels, ngf,
                                          n_layers)
        self.koopman = KoopmanOperators(feat_dyn_dim, nf_particle, nf_effect,
                                        g_dim, u_dim, n_timesteps, n_blocks)

    def _get_full_state(self, x, T):

        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(
                torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)],
                          dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, new_x.shape[-1]), new_T

    def _get_full_state_hankel(self, x, T):
        '''
        :param x: features or observations
        :param T: number of time-steps before concatenation
        :return: Columns of a Hankel matrix with self.n_timesteps rows.
        '''
        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1

        x = x.reshape(-1, T, *x.shape[2:])
        new_x = []
        for t in range(new_T):
            new_x.append(
                torch.stack([x[:, t + idx] for idx in range(self.n_timesteps)],
                            dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)

        return new_x.reshape(-1, new_T,
                             new_x.shape[-2] * new_x.shape[-1]), new_T
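    # Illustrative example (assumption, not executed code): with n_timesteps = 3 and a
    # per-object sequence x = [x0, x1, x2, x3, x4] (T = 5), this method returns
    # new_T = 3 delayed states, each stacking n_timesteps consecutive frames:
    #     [x0, x1, x2], [x1, x2, x3], [x2, x3, x4]
    # i.e. the columns of a Hankel (delay-embedding) matrix with n_timesteps rows,
    # flattened into the last dimension.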

    def forward(self, input, epoch=1):
        bs, T, ch, h, w = input.shape
        (n_dirs, bs) = (2, 2 * bs) if self.reverse else (1, bs)
        f = self.image_encoder(input, reverse=self.reverse)
        f_mu = self.linear_f_mu(f)
        f_logvar = self.linear_f_logvar(f)
        # Concatenates the features from current time and previous to create full state
        # f_mu, f_logvar = torch.chunk(f, 2, dim=-1)
        f = _sample_latent_simple(f_mu, f_logvar)

        # Note: Split features (cte, dyn)
        f_dyn, f_cte = f[..., :self.feat_dyn_dim], f[..., self.feat_dyn_dim:]
        f_cte = f_cte.reshape(bs * self.n_objects, T, *f_cte.shape[1:])

        # Test disentanglement
        # if random.random() < 0.1 or self.content is None:
        #     self.content = f_cte
        # f_cte = self.content

        f = f_dyn
        f = f.reshape(bs * self.n_objects * T, -1)

        if not self.hankel:
            f, T = self._get_full_state(f, T)

        g = self.koopman.to_g(f, self.psteps)
        g = g.reshape(bs * self.n_objects, T, *g.shape[1:])

        if self.hankel:
            g_flat = g
            g, T = self._get_full_state_hankel(g, T)

        if self.reverse:
            bs = bs // n_dirs
            g = g.reshape(bs, n_dirs, self.n_objects, T, *g.shape[2:])
            g_fw = g[:, 0].reshape(bs * self.n_objects, T, *g.shape[4:])
            g_bw = g[:, 1].reshape(bs * self.n_objects, T, *g.shape[4:])
            g = g_bw

        # Option 1: Sigmoid. All time-steps can be activated, binary activation. With dropout
        u_dist = self.linear_u_1_g(
            g.reshape(bs * self.n_objects * T, g.shape[-1]))
        u_dist = u_dist.reshape(u_dist.size(0), self.u_dim)
        u = self.sigmoid(
            (self.ini_alpha + epoch * self.incr_alpha) * u_dist).reshape(
                -1, T, self.u_dim)
        # u = torch.abs(u - torch.ones_like(u)*0.5)*2
        # do = nn.Dropout(p=0.2)
        # u = do(u)
        # Option 2: Gumbel/softmax. All time-steps can be activated. 1 action at a time, or None.
        # F.relu
        # u_dist, _ = (self.gru_u_all_g(g_flat
        #                           .reshape(bs * self.n_objects, T + self.n_timesteps -1, g_flat.shape[-1])))
        # u_dist = u_dist[:, self.n_timesteps-1:].reshape(bs *self.n_objects* T, - 1)
        # # u_dist = self.linear_u_all_g(g.reshape(bs * self.n_objects * T, g.shape[-1]))
        # u_dist = u_dist.reshape(u_dist.size(0), self.u_dim + 1)
        # u = F.gumbel_softmax(u_dist, tau=1, hard=True)
        # u = u[..., 1:].reshape(-1, T, self.u_dim)
        # # u_dist = u_dist.reshape(u_dist.size(0), self.u_dim + 1)
        # # temp = 0.01
        # # u = nn.Softmax(dim=-1)(u_dist/temp)
        # # u = u[..., 1:].reshape(-1, T, self.u_dim)
        # Option 3: Categorical. All time-steps can be activated. 1 action at a time. Non-diff
        # u_dist = F.sigmoid(self.linear_u_all_g(g.reshape(bs * self.n_objects * T, g.shape[-1])))
        # u_dist = u_dist.reshape(u_dist.size(0), self.u_dim + 1)
        # u_cat = Categorical(u_dist)
        # u = u_cat.sample()
        # u = F.one_hot(u).float()
        # u = u[..., 1:].reshape(-1, T, self.u_dim) # The first possible action is the inaction. Which should have the max prob.

        if self.hankel:
            zero_pad = torch.zeros_like(u[:, :self.n_timesteps - 1])
            u = torch.cat([zero_pad, u], dim=1)
            u, T_u = self._get_full_state_hankel(u, T + self.n_timesteps - 1)
            # u = u.reshape(*u.shape[:-1], u.shape[-1] // self.n_timesteps, self.n_timesteps)
            assert T_u == T

        # Option 2: One activation for all time-steps. Too strong of an assumption.
        # u_dist = F.relu(self.linear_u_T(f))
        # u_dist = u_dist.reshape(-1, T, self.u_dim).permute(0, 2, 1)
        # u = F.gumbel_softmax(u_dist, tau=1, hard=True)
        # u = u.permute(0, 2, 1)
        # Option 3: Actions are given by the observations g, obtained one at a time (Not necessary?)
        # us, u_dist = [], []
        # for t in range(T):
        #     if t==0:
        #         prev_u_sample = torch.zeros_like(g[:, 0, :self.u_dim])
        #     q_y = F.relu(self.linear_u(f))
        #     q_y = q_y.view(q_y.size(0),self.u_dim,2)
        #     u = F.gumbel_softmax(q_y, tau=1, hard=True)
        #     u = u[..., 0]
        #
        #     prev_u_sample = u
        #     us.append(u)
        #     u_dist.append(q_y)
        # u = torch.stack(us, dim=1)
        # u_dist = torch.stack(u_dist, dim=1)

        # TODO:
        #  0: state as velocity acceleration, pose
        #  0: g fitting error.
        #  0: Hankel view. F**k koopman for now, it has too many degrees of freedom. Subtract the input to minimize the rank of H.
        #  1: Invert or Sample randomly u and maximize the error in reconstruction.
        #  2: Treat n_timesteps with conv_nn? Or does it make sense to mix f1(t) and f2(t-1)?
        #  3: We don't impose A symmetric, but we impose that g in both directions use the same A and B.
        #       Exploit symmetry. If there's an input in the sequence, there will be the same input in the reverse sequence.
        #       B is valid for both. Then we can't cross, unless we recalculate B. It might not be necessary because we use
        #       the same A for both dirs. INVERT U's temporally --> that's a great idea. Dismiss (n_ini_timesteps - 1) samples at each end. Is this correct?
        #       Is it the same linear input intervention if I do it forward and backward?
        #  5: Eigenvalues and Eigenvector. Fix one get the other.
        #  6: Observation selector using Gumbel (should apply the mask to emitter and receiver,
        #       and it should be the same across time)
        #  7: Attention mechanism to select - koopman (sub-koopman) / Identity. It will be useful for the future.
        #  8: When works, formalize code. Hyperparameters through config.yaml
        #  9: The nonlinear projection should also separate the objects. We can separate by the svd? Is it possible?
        # 1: Prev_sample not necessary? Start from g(0)
        # 2: Sampling from features
        # 3: Bottleneck in f. Smaller than G. Constant part is big and routed directly.
        #       Variant part is very small and expanded to G.
        # 4: In evaluation, sample features several times but only keep one of the cte. Check if it captures the appearance
        # 5: Fix B with A learned. If the first doesn't require input, B will be mapped to 0.

        randperm = torch.arange(g.shape[0])  if self.reverse \
            else torch.randperm(g.shape[0])

        #Note: Inverting u(t) in the time axis
        if self.reverse:
            u_fw = u
            zeros = torch.zeros_like(u[:, :self.n_timesteps])
            u_bw = torch.flip(u, dims=[1])
            u_bw = torch.cat([u_bw[:, self.n_timesteps:], zeros], dim=1)
            u = u_bw
            free_pred = self.n_timesteps - 1 + 4
        else:
            free_pred = 0

        if free_pred > 0:
            G_tilde = g[randperm, self.n_timesteps - 1:-1 - free_pred, None]
            H_tilde = g[randperm, self.n_timesteps:-free_pred, None]
        else:
            G_tilde = g[randperm, self.n_timesteps - 1:-1, None]
            H_tilde = g[randperm, self.n_timesteps:, None]

        # if free_pred > 0:
        #     G_tilde = g[randperm, :-1-free_pred, None]
        #     H_tilde = g[randperm, 1:-free_pred, None]
        # else:
        #     G_tilde = g[randperm, :-1, None]
        #     H_tilde = g[randperm, 1:, None]

        A, B, A_inv, fit_err = self.koopman.system_identify(
            G=G_tilde,
            H=H_tilde,
            U=u[randperm, self.n_timesteps - 1:-1 - free_pred],
            I_factor=self.I_factor)
        # A, B, A_inv, fit_err = self.koopman.fit_with_A(G=G_tilde, H=H_tilde, U=u[randperm, :-1], I_factor=self.I_factor)
        # A, B, A_inv, fit_err = self.koopman.fit_with_B(G=G_tilde, H=H_tilde, U=u[randperm, :-1-free_pred], I_factor=self.I_factor)
        # A, B = self.koopman.fit_with_AB(G_tilde.shape[0])
        # A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(G=G_tilde, H=H_tilde, I_factor=self.I_factor)
        # B = None

        if self.reverse:
            g = g_fw
            u = u_fw
            f_cte_shape = f_cte.shape
            f_cte = f_cte.reshape(bs, n_dirs * f_cte_shape[1],
                                  *f_cte_shape[2:])

        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None)
        # G_for_pred = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:,1:self.n_timesteps, :, -1],
        #                         self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=u[:, self.n_timesteps -1:], A=A, B=B)], dim=1)
        # g_for_koop = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:,1:self.n_timesteps, :, -1],
        #                         self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=None, A=A, B=None)], dim=1)
        G_for_pred = self.koopman.simulate(T=T - self.n_timesteps,
                                           g=g[:, self.n_timesteps - 1],
                                           u=u[:, self.n_timesteps - 1:],
                                           A=A,
                                           B=B)
        g_for_koop = self.koopman.simulate(T=T - self.n_timesteps,
                                           g=g[:, self.n_timesteps - 1],
                                           u=None,
                                           A=A,
                                           B=None)

        g_for_koop, T_prime = self._get_full_state_hankel(
            g_for_koop, g_for_koop.shape[1])

        # g_for_koop = G_for_pred

        # TODO: create recursive rollout. We obtain the input at each step from g
        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None)
        '''Simple version of the reverse rollout'''
        # G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A_pinv, B=None)
        # G_for_pred = torch.flip(G_for_pred_rev, dims=[1])
        # G_for_pred = torch.cat([g[:,0:1],self.koopman.simulate(T=T - 2, g=g[:, 1], u=u, A=A, B=B)], dim=1)

        if self.hankel:
            # G_for_pred = G_for_pred.reshape(*G_for_pred.shape[:-1],
            #                                 G_for_pred.shape[-1]// self.n_timesteps,
            #                                 self.n_timesteps)[..., self.n_timesteps - 1]
            # g = g.reshape(*g.shape[:-1],
            #               g.shape[-1]// self.n_timesteps,
            #               self.n_timesteps)[..., self.n_timesteps - 1]
            #
            # g_for_koop = g_for_koop.reshape(*g_for_koop.shape[:-1],
            #               g_for_koop.shape[-1]// self.n_timesteps,
            #               self.n_timesteps)
            #Option 2: with hankel structure.
            G_for_pred = G_for_pred.reshape(*G_for_pred.shape[:-1],
                                            G_for_pred.shape[-1])
            g = g.reshape(*g.shape[:-1], g.shape[-1] // self.n_timesteps,
                          self.n_timesteps)[..., self.n_timesteps - 1]

        # Option 1: use the koopman object decoder
        s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), psteps=self.psteps)
        s_for_pred = self.koopman.to_s(gcodes=_get_flat(G_for_pred),
                                       psteps=self.psteps)

        # Note: Split features (cte, dyn). In case of reverse, f_cte averages for both directions.
        # Note 2: TODO: Reconstruction could be in reversed g's or both!
        s_for_rec = torch.cat([
            s_for_rec.reshape(bs * self.n_objects, T, -1),
            f_cte[:, None].mean(2).repeat(1, T, 1)
        ],
                              dim=-1)
        s_for_pred = torch.cat([
            s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
            f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)
        ],
                               dim=-1)
        s_for_rec = _get_flat(s_for_rec)
        s_for_pred = _get_flat(s_for_pred)

        # Option 2: we don't use the koopman object decoder
        # s_for_rec = _get_flat(g)
        # s_for_pred = _get_flat(G_for_pred)

        # Convolutional decoder. Normally Spatial Broadcasting decoder
        out_rec = self.image_decoder(s_for_rec)
        out_pred = self.image_decoder(s_for_pred)

        returned_g = torch.cat([g, G_for_pred], dim=1)
        returned_mus = torch.cat([f_mu], dim=-1)
        returned_logvars = torch.cat([f_logvar], dim=-1)
        # returned_mus = torch.cat([f_mu, a_mu], dim=-1)
        # returned_logvars = torch.cat([f_logvar, a_logvar], dim=-1)

        o_touple = (out_rec, out_pred,
                    returned_g.reshape(-1, returned_g.size(-1)),
                    returned_mus.reshape(-1, returned_mus.size(-1)),
                    returned_logvars.reshape(-1, returned_logvars.size(-1)))
        # f_mu.reshape(-1, f_mu.size(-1)),
        # f_logvar.reshape(-1, f_logvar.size(-1)))
        o = [
            item.reshape(
                torch.Size([bs * self.n_objects, -1]) + item.size()[1:])
            for item in o_touple
        ]
        # Option 1: one object mapped to 0
        # shape_o0 = o[0].shape
        # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:])
        # o[0][:,0] = o[0][:,0]*0
        # o[0] = o[0].reshape(*shape_o0)

        # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1) for item in o[:2]]
        o[:2] = [
            torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1)
            for item in o[:2]
        ]
        o[3:5] = [
            item.reshape(bs, self.n_objects, *item.shape[1:])
            for item in o[3:5]
        ]

        o.append(A)
        o.append(B)

        u = u.reshape(*u.shape[:2], self.u_dim, self.n_timesteps)[..., -1]
        o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1]))
        o.append(
            u_dist.reshape(bs * self.n_objects, -1,
                           *u_dist.shape[-1:]))  #Append udist for categorical
        # o.append(u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-2:]))
        o.append(
            g_for_koop.reshape(bs * self.n_objects, -1,
                               g_for_koop.shape[-1] // self.n_timesteps,
                               self.n_timesteps))
        o.append(fit_err)

        return o
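The input gate in this example is an annealed sigmoid whose slope (ini_alpha + epoch * incr_alpha) grows with the epoch, so the activations sharpen towards 0/1 as training progresses; the commented-out Option 2 would instead draw a hard one-hot action with a straight-through Gumbel-softmax. A small illustration with made-up logits:

import torch
import torch.nn.functional as F

u_dist = torch.tensor([[-0.4, 0.1, 1.2]])       # made-up pre-activations for u_dim = 3

# Option 1 (used above): annealed sigmoid, slope grows with the epoch
for epoch in (0, 10, 50):
    alpha = 1 + epoch * 0.5                     # self.ini_alpha + epoch * self.incr_alpha
    print(epoch, torch.sigmoid(alpha * u_dist))

# Option 2 (commented out above): straight-through Gumbel-softmax over u_dim + 1 classes,
# where class 0 stands for "no action" and is dropped afterwards
logits = torch.tensor([[2.0, -0.4, 0.1, 1.2]])
u_hard = F.gumbel_softmax(logits, tau=1.0, hard=True)[..., 1:]
print(u_hard)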
Example #7
class RecKoopmanModel(BaseModel):
    def __init__(self,
                 in_channels,
                 feat_dim,
                 nf_particle,
                 nf_effect,
                 g_dim,
                 u_dim,
                 n_objects,
                 I_factor=10,
                 n_blocks=1,
                 psteps=1,
                 n_timesteps=1,
                 ngf=8,
                 image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        self.u_dim = u_dim

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim

        feat_dyn_dim = feat_dim // 8
        feat_dyn_dim = 4
        self.feat_dyn_dim = feat_dyn_dim
        self.feat_cte_dim = feat_dim - feat_dyn_dim

        self.with_u = False
        self.n_iters = 1
        self.ini_alpha = 1
        # Note:
        #  - I leave it to 0 now. If it increases too fast, the gradients might be affected
        self.incr_alpha = 0.1

        self.cte_resolution = (32, 32)
        self.ori_resolution = (128, 128)
        self.att_resolution = (16, 16)
        self.obj_resolution = (4, 4)
        # self.n_objects = reduce((lambda x, y: x * y), self.obj_resolution)
        self.n_objects = n_objects

        self.spatial_tf = SpatialTransformation(self.cte_resolution,
                                                self.ori_resolution)

        self.linear_f_cte_post = nn.Linear(2 * self.feat_cte_dim,
                                           2 * self.feat_cte_dim)
        self.linear_f_dyn_post = nn.Linear(2 * self.feat_dyn_dim,
                                           2 * self.feat_dyn_dim)
        #
        self.bc_decoder = False
        if self.bc_decoder:
            self.image_decoder = ImageBroadcastDecoder(
                self.feat_cte_dim, out_channels,
                resolution=(16, 16))  # resolution=self.att_resolution
        else:
            self.image_decoder = ImageDecoder(self.feat_cte_dim,
                                              out_channels,
                                              dyn_dim=self.feat_dyn_dim)
        self.image_encoder = ImageEncoder(
            in_channels, 2 * self.feat_cte_dim, self.feat_dyn_dim,
            self.att_resolution, self.n_objects, ngf,
            n_layers)  # feat_dim * 2 if sample here
        self.koopman = KoopmanOperators(feat_dyn_dim, nf_particle, nf_effect,
                                        g_dim, u_dim, n_timesteps, n_blocks)

    def _get_full_state(self, x, T):

        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(
                torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)],
                          dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, new_x.shape[-1]), new_T

    def _get_full_state_hankel(self, x, T):
        '''
        :param x: features or observations
        :param T: number of time-steps before concatenation
        :return: Columns of a Hankel matrix with self.n_timesteps rows.
        '''
        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1

        x = x.reshape(-1, T, *x.shape[2:])
        new_x = []
        for t in range(new_T):
            new_x.append(
                torch.stack([x[:, t + idx] for idx in range(self.n_timesteps)],
                            dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)

        return new_x.reshape(-1, new_T,
                             new_x.shape[-2] * new_x.shape[-1]), new_T

    def forward(self, input, epoch=1):
        # Note: Add annealing in SPACE
        bs, T, ch, h, w = input.shape
        # Number of time-steps held out from system identification for free prediction
        free_pred = T // 4
        returned_post = []

        input = input.cuda()
        # Backbone deterministic features
        f_bb = self.image_encoder(input, block='backbone')

        # Dynamic features
        T_inp = T
        f_dyn, shape, f_cte, confi = self.image_encoder(f_bb[:, :T_inp],
                                                        block='dyn_track')
        f_dyn = f_dyn.reshape(-1, f_dyn.shape[-1])

        # Sample dynamic features or reshape
        # Option 1: Don't sample
        f_dyn = f_dyn.reshape(bs * self.n_objects, T_inp, -1)
        # Option 2: Sample
        # f_mu_dyn, f_logvar_dyn = self.linear_f_dyn_post(f_dyn).reshape(bs, self.n_objects, T_inp, -1).chunk(2, -1)
        # f_dyn_post = Normal(f_mu_dyn, F.softplus(f_logvar_dyn))
        # f_dyn = f_dyn_post.rsample()
        # f_dyn = f_dyn.reshape(bs * self.n_objects, T_inp, -1)
        # returned_post.append(f_dyn_post)

        # Get delayed dynamic features
        f_dyn_state, T_inp = self._get_full_state_hankel(f_dyn, T_inp)
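        # f_dyn_state: (bs * n_objects, T_inp, feat_dyn_dim * n_timesteps) delay-embedded dynamic features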

        # Get inputs from delayed dynamic features
        u, u_dist = self.koopman.to_u(f_dyn_state,
                                      temp=self.ini_alpha +
                                      epoch * self.incr_alpha,
                                      ignore=True)
        if not self.with_u:
            u = torch.zeros_like(u)
        # Get observations from delayed dynamic features
        g = self.koopman.to_g(
            f_dyn_state.reshape(bs * self.n_objects * T_inp, -1), self.psteps)
        g = g.reshape(bs * self.n_objects, T_inp, *g.shape[1:])

        # Get shifted observations for sys ID
        randperm = torch.arange(g.shape[0])  # No permutation
        # randperm = torch.randperm(g.shape[0]) # Random permutation
        if free_pred > 0:
            G_tilde = g[randperm, :-1 - free_pred, None]
            H_tilde = g[randperm, 1:-free_pred, None]
        else:
            G_tilde = g[randperm, :-1, None]
            H_tilde = g[randperm, 1:, None]
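        # G_tilde / H_tilde hold consecutive observation pairs (g_t, g_{t+1}); system_identify fits g_{t+1} ≈ A g_t + B u_t from them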

        # Sys ID
        A, B, A_inv, fit_err = self.koopman.system_identify(
            G=G_tilde,
            H=H_tilde,
            U=u[randperm, :T_inp - free_pred - 1],
            I_factor=self.I_factor
        )  # Try not permuting U when the input is permuted

        # Rollout from start_step onwards.
        start_step = 2  # g and u must be aligned!!
        G_for_pred = self.koopman.simulate(T=T_inp - start_step - 1,
                                           g=g[:, start_step],
                                           u=u[:, start_step:],
                                           A=A,
                                           B=B)
        g_for_koop = G_for_pred

        assert f_bb[:, self.n_timesteps - 1:self.n_timesteps - 1 +
                    T_inp].shape[1] == f_bb[:, self.n_timesteps - 1:].shape[1]
        rec = {
            "obs": g,
            "confi": confi[:, :,
                           self.n_timesteps - 1:self.n_timesteps - 1 + T_inp],
            "shape": shape[:, :,
                           self.n_timesteps - 1:self.n_timesteps - 1 + T_inp],
            "f_cte": f_cte[:, :,
                           self.n_timesteps - 1:self.n_timesteps - 1 + T_inp],
            "T": T_inp,
            "name": "rec"
        }
        pred = {
            "obs": G_for_pred,
            "confi": confi[:, :, -G_for_pred.shape[1]:],
            "shape": shape[:, :, -G_for_pred.shape[1]:],
            "f_cte": f_cte[:, :, -G_for_pred.shape[1]:],
            "T": G_for_pred.shape[1],
            "name": "pred"
        }
        outs = {}
        BBs = {}

        # TODO: Check if the indices for f_bb and supervision are correct.
        # Recover partial shape with decoded dynamical features. Iterate with new estimates of the appearance.
        # Note: This process could be iterative.
        for idx, case in enumerate([rec, pred]):

            case_name = case["name"]
            # get back dynamic features
            # if case_name == "rec":
            #     f_dyn_state = f_dyn[:, -T_inp:].reshape(-1, *f_dyn.shape[2:])
            # else:
            # TODO: Try with Initial f_dyn (no koop)
            f_dyn_state = self.koopman.to_s(gcodes=_get_flat(case["obs"]),
                                            psteps=self.psteps)
            # TODO: From f_dyn extract where, scale, etc. with only Y = WX + b. Review.

            # Initial noisy (low variance) constant vector.
            # Note:
            #  - Different realization for each time-step.
            #  - Is it a Variable?
            # f_cte_ini = torch.randn(bs * self.n_objects * case["T"], self.feat_cte_dim).to(f_dyn_state.device) #* self.f_cte_ini_std

            # Get full feature vector
            # f = torch.cat([ f_dyn_state,
            #                 f_cte_ini], dim=-1)

            # Get coarse features from which obtain queries and/or decode
            # f_coarse = self.image_decoder(f, block = 'coarse')
            # f_coarse = f_coarse.reshape(bs, self.n_objects, case["T"], *f_coarse.shape[1:])

            for _ in range(self.n_iters):

                # Get constant feature vector through attention
                # f_cte = self.image_encoder(case["backbone_features"], f_coarse, block='cte')

                # Encode with raw dynamic features
                # pose = self.pose_encoder(f_dyn_state)
                pose = f_dyn_state
                # M_center, M_relocate = get_affine_params_from_pose(pose, self.pose_limits, self.pose_bias)

                # Option 2: Previous impl
                # bb_feat = case["backbone_features"].unsqueeze(1).repeat_interleave(self.n_objects, dim=1)\
                #     .reshape(-1, *case["backbone_features"].shape[-4:]) # Repeat for number of objects
                # warped_bb_feat = (bb_feat, M_center, *self.cte_resolution) # Check all resolutions are coordinate [32, 32]
                # f_cte = self.image_encoder(warped_bb_feat, f_dyn_state, block='cte')

                # Sample cte features
                f_cte = case["f_cte"].reshape(bs * self.n_objects * case["T"],
                                              -1)
                confi = case["confi"].reshape(bs * self.n_objects * case["T"],
                                              -1)
                #

                f_mu_cte, f_logvar_cte = self.linear_f_cte_post(f_cte).reshape(
                    bs, self.n_objects, case["T"], -1).chunk(2, -1)
                f_cte_post = Normal(f_mu_cte, F.softplus(f_logvar_cte))
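                # Note: the second chunk goes through softplus and is used directly as the posterior std (despite the 'logvar' name)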
                f_cte = f_cte_post.rsample().reshape(
                    bs * self.n_objects * case["T"], -1)
                # Option 2: Previous impl
                # f_mu_cte, f_logvar_cte = self.linear_f_cte_post(f_cte).reshape(bs, self.n_objects, 1, -1).chunk(2, -1)
                # f_cte_post = Normal(f_mu_cte, F.softplus(f_logvar_cte))
                # f_cte = f_cte_post.rsample()
                # f_cte = f_cte.repeat_interleave(case["T"], dim=2)
                # f_cte = f_cte.reshape(bs * self.n_objects * case["T"], self.feat_cte_dim)

                # Register statistics
                returned_post.append(f_cte_post)

                # Get full feature vector
                # f = torch.cat([ f_dyn_state,
                #                 f_cte], dim=-1) # Note: Do if f_dyn_state is used in the appearance
                f = f_cte

            # Get output. Spatial broadcast decoder

            dec_obj = self.image_decoder(f, block='to_x')
            if not self.bc_decoder:
                grid, area = self.spatial_tf(confi, pose)
                outs[case_name], out_shape = self.spatial_tf.warp_and_render(
                    dec_obj, case["shape"], confi, grid)
            else:
                outs[case_name] = dec_obj * confi[..., None, None]
            # outs[case_name] = warp_affine(dec_obj, M_relocate, h=h, w=w)

            BBs[case_name] = dec_obj

        out_rec = outs["rec"]
        out_pred = outs["pred"]
        bb_rec = BBs["rec"]
        # bb_rec = BBs["pred"]

        # Test disentanglement - TO REVIEW
        # if random.random() < 0.1 or self.content is None:
        #     self.content = f_cte
        # f_cte = self.content
        ''' Returned variables '''
        returned_g = torch.cat([g, G_for_pred], dim=1)  # Observations

        o_touple = (out_rec, out_pred,
                    returned_g.reshape(-1, returned_g.size(-1)))
        o = [
            item.reshape(
                torch.Size([bs * self.n_objects, -1]) + item.size()[1:])
            for item in o_touple
        ]

        # Option 1: one object mapped to 0
        # shape_o0 = o[0].shape
        # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:])
        # o[0][:,0] = o[0][:,0]*0
        # o[0] = o[0].reshape(*shape_o0)

        # Sum across objects and clamp to [0, 1]. Note: a Sigmoid could be used instead of clamping.
        o[:2] = [
            torch.clamp(torch.sum(item.reshape(bs, self.n_objects,
                                               *item.shape[1:]),
                                  dim=1),
                        min=0,
                        max=1) for item in o[:2]
        ]
        # o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]]

        # TODO: output = {}
        #  output['rec'] = o[0]
        #  output['pred'] = o[1]

        o.append(returned_post)
        o.append(A)  # State transition matrix
        o.append(B)  # Input matrix
        o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1]))  # Inputs
        o.append(u_dist.reshape(bs * self.n_objects, -1,
                                u_dist.shape[-1]))  # Input distribution
        o.append(
            g_for_koop.reshape(
                bs * self.n_objects, -1,
                g_for_koop.shape[-1]))  # Observation propagated only with A
        o.append(fit_err)  # Fit error g to G_for_pred

        # bb_rec = bb_rec.reshape(torch.Size([bs * self.n_objects, -1]) + bb_rec.size()[1:])
        # bb_rec =torch.sum(bb_rec, dim=2, keepdim=True)
        # o.append(bb_rec.reshape(bs, self.n_objects, *bb_rec.shape[1:])) # Motion field reconstruction

        # TODO: return in dictionary form
        return o
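
# For reference: a minimal standalone sketch of the delay (Hankel) embedding performed by
# _get_full_state_hankel above, applied to a dummy feature tensor. It is illustrative only and
# not part of the model; the helper name and toy shapes below are assumptions chosen to mirror
# the method's behaviour.
import torch

def hankel_embed(x, n_timesteps):
    # x: (batch, T, feat_dim) -> (batch, T - n_timesteps + 1, feat_dim * n_timesteps)
    if n_timesteps < 2:
        return x
    new_T = x.shape[1] - n_timesteps + 1
    cols = [torch.stack([x[:, t + i] for i in range(n_timesteps)], dim=-1)
            for t in range(new_T)]
    out = torch.stack(cols, dim=1)             # (batch, new_T, feat_dim, n_timesteps)
    return out.reshape(x.shape[0], new_T, -1)  # flatten the delay axis

x = torch.randn(2, 10, 4)        # batch of 2, T = 10 steps, 4-dim features
print(hankel_embed(x, 3).shape)  # torch.Size([2, 8, 12])
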
Example No. 8
class RecKoopmanModel(BaseModel):
    def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim,
                 n_objects, I_factor=10, n_blocks=1, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        self.u_dim = u_dim

        # Temporal encoding buffers
        if n_timesteps > 1:
            t = torch.linspace(-1, 1, n_timesteps)
            # Add as constant, with extra dims for N and C
            self.register_buffer('t_grid', t)

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim
        self.n_objects = n_objects

        # self.softmax = nn.Softmax(dim=-1)
        # self.sigmoid = nn.Sigmoid()

        # self.linear_g = nn.Linear(g_dim, g_dim)

        # self.initial_conditions = nn.Sequential(nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps),
        #                                         nn.ReLU(),
        #                                         nn.Linear(feat_dim * n_timesteps, g_dim * 2))

        feat_dyn_dim = feat_dim // 6
        self.feat_dyn_dim = feat_dyn_dim
        self.content = None
        # self.reverse = True
        self.reverse = False
        self.ini_alpha = 1
        self.incr_alpha = 0.25

        # self.rnn_f_cte = nn.LSTM(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim, 1, bias=False, batch_first=True)
        # self.rnn_f_cte = nn.GRU(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim, 1, bias=False, batch_first=True)

        # self.linear_f_mu = nn.Linear(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim)
        # self.linear_f_logvar = nn.Linear(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim)
        # self.linear_f_dyn_mu = nn.Linear(feat_dyn_dim, feat_dyn_dim)
        # self.linear_f_dyn_logvar = nn.Linear(feat_dyn_dim, feat_dyn_dim)
        self.linear_f_mu = nn.Linear(feat_dim, feat_dim)
        self.linear_f_logvar = nn.Linear(feat_dim, feat_dim)

        self.image_encoder = ImageEncoder(in_channels, feat_dim, n_objects, ngf, n_layers)  # feat_dim * 2 if sample here
        self.image_decoder = ImageDecoder(feat_dim, out_channels, ngf, n_layers)
        self.koopman = KoopmanOperators(feat_dyn_dim, nf_particle, nf_effect, g_dim, u_dim, n_timesteps, n_blocks)

    def _get_full_state(self, x, T):

        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.cat([x[:, t + idx]
                         for idx in range(self.n_timesteps)], dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, new_x.shape[-1]), new_T

    def _get_full_state_hankel(self, x, T):
        '''
        :param x: features or observations
        :param T: number of time-steps before concatenation
        :return: Columns of a hankel matrix with self.n_timesteps rows.
        '''
        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1

        x = x.reshape(-1, T, *x.shape[2:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.stack([x[:, t + idx]
                                      for idx in range(self.n_timesteps)], dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)

        return new_x.reshape(-1, new_T, new_x.shape[-2] * new_x.shape[-1]), new_T


    def forward(self, input, epoch=1):
        bs, T, ch, h, w = input.shape
        free_pred = T // 4  # time-steps held out for free prediction
        f = self.image_encoder(input)  # (bs * n_objects * T, feat_dim)

        # Option 1: Mean before sampling
        # f_cte = f.reshape(bs * self.n_objects, T, *f.shape[1:])\
        #         [..., self.feat_dyn_dim:]
        # f_cte = f_cte.mean(1)
        # # LSTM to obtain cte
        # # h0 = torch.zeros_like(f_cte[:, 0])[None]
        # # c0 = torch.zeros_like(f_cte[:, 0])[None]
        # # f_cte, _ = self.rnn_f_cte(f_cte[:, :-free_pred], (h0, c0))
        # # f_cte = f_cte[:, -1]
        # f_mu_cte, f_logvar_cte = self.linear_f_mu(f_cte)[:, None].repeat(1, T, 1).reshape(bs * self.n_objects * T, -1), \
        #                          self.linear_f_logvar(f_cte)[:, None].repeat(1, T, 1).reshape(bs * self.n_objects * T, -1)
        # f_cte =  _sample_latent_simple(f_mu_cte, f_logvar_cte)
        # #
        # f_dyn = f[..., :self.feat_dyn_dim]
        # f_mu_dyn, f_logvar_dyn = self.linear_f_dyn_mu(f_dyn), \
        #                          self.linear_f_dyn_logvar(f_dyn)
        # f_dyn = _sample_latent_simple(f_mu_dyn, f_logvar_dyn)
        # #
        # f_mu = torch.cat([f_mu_dyn, f_mu_cte], dim=-1)
        # f_logvar = torch.cat([f_logvar_dyn, f_logvar_cte], dim=-1)

        # Option 2: Sampling all together
        f_mu = self.linear_f_mu(f)
        f_logvar = self.linear_f_logvar(f)
        f = _sample_latent_simple(f_mu, f_logvar)
        # Note: Split features (cte, dyn)
        f_dyn, f_cte = f[..., :self.feat_dyn_dim], f[..., self.feat_dyn_dim:]
        f_cte = f_cte.reshape(bs * self.n_objects, T, *f_cte.shape[1:])

        # Test disentanglement
        # if random.random() < 0.1 or self.content is None:
        #     self.content = f_cte
        # f_cte = self.content

        f_dyn = f_dyn.reshape(bs * self.n_objects, T, *f_dyn.shape[1:])
        f_dyn, T = self._get_full_state_hankel(f_dyn, T)

        # TODO: U might depend also in features from the scene. Residual features (slot n+1 / Background)
        u, u_dist = self.koopman.to_u(f_dyn, temp=self.ini_alpha + epoch * self.incr_alpha)
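        # Sampling temperature for to_u anneals linearly with the epoch: temp = ini_alpha + epoch * incr_alpha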
        g = self.koopman.to_g(f_dyn.reshape(bs * self.n_objects * T, -1), self.psteps)
        g = g.reshape(bs * self.n_objects, T, *g.shape[1:])

        # TODO:
        #  0: Put everything into the time-varying part and be done with it.
        #  ...
        #  0: State as velocity, acceleration, pose.
        #  0: g fitting error.
        #  0: Hankel view. Set Koopman aside for now, it has too many degrees of freedom. Subtract the input to minimize the rank of H.
        #  1: Invert or randomly sample u and maximize the reconstruction error.
        #  2: Treat n_timesteps with a conv net? Or does it make sense to mix f1(t) and f2(t-1)?
        #  3: We don't impose that A is symmetric, but we do impose that g in both directions uses the same A and B.
        #       Exploit the symmetry. If there's an input in the forward sequence, the same input appears in the reverse sequence,
        #       so B is valid for both. Then we can't cross directions unless we recompute B; that might not be necessary because we use
        #       the same A for both directions. Invert the u's temporally --> that's a great idea. Discard (n_ini_timesteps - 1) samples at each end. Is this correct?
        #       Is it the same linear input intervention if I do it forward and backward?
        #  5: Eigenvalues and eigenvectors. Fix one, get the other.
        #  6: Observation selector using Gumbel (the mask should apply to emitter and receiver
        #       and be the same across time).
        #  7: Attention mechanism to select - Koopman (sub-Koopman) / identity. It will be useful in the future.
        #  8: When it works, formalize the code. Hyperparameters through config.yaml.
        #  9: The nonlinear projection should also separate the objects. Can we separate via the SVD? Is it possible?
        # 1: prev_sample not necessary? Start from g(0).
        # 2: Sampling from features.
        # 3: Bottleneck in f. Smaller than G. The constant part is big and routed directly;
        #       the varying part is very small and expanded to G.
        # 4: In evaluation, sample the features several times but keep only one of the constant (cte) parts. Check whether it captures the appearance.
        # 5: Fix B with A learned. If the first one doesn't require input, B will be mapped to 0.

        # No permutation is applied for now (the `or True` disables the random-permutation branch)
        randperm = (torch.arange(g.shape[0]) if self.reverse or True
                    else torch.randperm(g.shape[0]))

        # Note: inverting u(t) along the time axis
        # TODO: still to be done.


        # if free_pred > 0:
        #     G_tilde = g[randperm, self.n_timesteps -1:-1-free_pred, None]
        #     H_tilde = g[randperm, self.n_timesteps:-free_pred, None]
        # else:
        #     G_tilde = g[randperm, self.n_timesteps -1:-1, None]
        #     H_tilde = g[randperm, self.n_timesteps:, None]

        if free_pred > 0:
            G_tilde = g[randperm, :-1-free_pred, None]
            H_tilde = g[randperm, 1:-free_pred, None]
        else:
            G_tilde = g[randperm, :-1, None]
            H_tilde = g[randperm, 1:, None]

        # TODO: If we identify with half of the timesteps, but use all inputs for rollout,
        #  we might get something. We can also predict only the future.
        A, B, A_inv, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, U=u[randperm, :-1-free_pred], I_factor=self.I_factor)  # Note: not permuting the input, but permuting g.
        # A, B, A_inv, fit_err = self.koopman.fit_with_A(G=G_tilde, H=H_tilde, U=u[randperm, :-1], I_factor=self.I_factor)
        # A, B, A_inv, fit_err = self.koopman.fit_with_B(G=G_tilde, H=H_tilde, U=u[randperm, :-1-free_pred], I_factor=self.I_factor)
        # A, B = self.koopman.fit_with_AB(G_tilde.shape[0])
        # A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(G=G_tilde, H=H_tilde, I_factor=self.I_factor)
        # B = None

        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None)
        # G_for_pred = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:,1:self.n_timesteps, :, -1],
        #                         self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=u[:, self.n_timesteps -1:], A=A, B=B)], dim=1)
        # g_for_koop = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:,1:self.n_timesteps, :, -1],
        #                         self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=None, A=A, B=None)], dim=1)
        # G_for_pred = self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=u[:, self.n_timesteps -1:], A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=None, A=A, B=None)

        # G_for_pred = self.koopman.simulate(T=T - 4, g=g[:, 3], u=u, A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - 4, g=g[:, 3], u=None, A=A, B=None)
        g_start, T_start = g[:, 2], T - 3
        # TODO: ATTENTION. u might not be aligned temporally with g_start; slicing u[:, 2:] assumes the inputs start at the same step as g_start.
        G_for_pred = self.koopman.simulate(T=T_start, g=g_start, u=u[:, 2:], A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - 2, g=g[:, 1], u=None, A=A, B=None)
        g_for_koop = G_for_pred



        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None)
        '''Simple version of the reverse rollout'''
        # G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A_pinv, B=None)
        # G_for_pred = torch.flip(G_for_pred_rev, dims=[1])
        # G_for_pred = torch.cat([g[:,0:1],self.koopman.simulate(T=T - 2, g=g[:, 1], u=u, A=A, B=B)], dim=1)


        s_for_rec = self.koopman.to_s(gcodes=_get_flat(g),
                                      psteps=self.psteps)
        s_for_pred = self.koopman.to_s(gcodes=_get_flat(G_for_pred),
                                       psteps=self.psteps)

        # Note: Split features (cte, dyn). In case of reverse, f_cte averages for both directions.
        # Note 2: TODO: Reconstruction could be in reversed g's or both!
        # Option 1: Temporally permute appearance
        # T_randperm =  torch.randperm(g.shape[1]) # torch.arange(g.shape[1])
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, -T:]], dim=-1)
        # # f_cte = f_cte[:, T_randperm]
        # f_cte = f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:]], dim=-1)

        # T_randperm =  torch.randperm(g.shape[1]) # torch.arange(g.shape[1])
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, -T:]], dim=-1)
        # # f_cte = f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)
        # f_cte = f_cte[:, T_randperm]
        #
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:]], dim=-1)
        # Option 2: Average appearance
        s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
                              f_cte[:, None].mean(2).repeat(1, T, 1)], dim=-1)
        s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
                              f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)], dim=-1)
        # Option 3: Copy the first appearance (or a random one)
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, 0:1].repeat(1, T, 1)], dim=-1)
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:-G_for_pred.shape[1]+1].repeat(1, G_for_pred.shape[1], 1)], dim=-1)

        # Option 4: If f_cte has been computed before.
        # f_cte = f_cte.reshape(bs * self.n_objects, -1, f_cte.shape[-1])
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, -T:]], dim=-1)
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:]], dim=-1)
        #
        # s_for_rec = _get_flat(s_for_rec)
        # s_for_pred = _get_flat(s_for_pred)

        # NOTE: Sample after
        # f_mu = self.linear_f_mu(s_for_rec)
        # f_logvar = self.linear_f_logvar(s_for_rec)
        # s_for_rec = _sample_latent_simple(f_mu, f_logvar)
        # #
        # f_mu_pred = self.linear_f_mu(s_for_pred)
        # f_logvar_pred = self.linear_f_logvar(s_for_pred)
        # s_for_pred = _sample_latent_simple(f_mu_pred, f_logvar_pred)
        # f_mu, f_mu_pred = f_mu.reshape(bs * self.n_objects, T, -1), f_mu_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1)
        # f_logvar, f_logvar_pred = f_logvar.reshape(bs * self.n_objects, T, -1), f_logvar_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1)


        # Convolutional decoder. Normally Spatial Broadcasting decoder
        out_rec = self.image_decoder(s_for_rec)
        out_pred = self.image_decoder(s_for_pred)

        returned_g = torch.cat([g, G_for_pred], dim=1)

        returned_mus = torch.cat([f_mu], dim=1)
        returned_logvars = torch.cat([f_logvar], dim=1)
        # returned_mus = torch.cat([f_mu, f_mu_pred], dim=1)
        # returned_logvars = torch.cat([f_logvar, f_logvar_pred], dim=1)

        o_touple = (out_rec, out_pred, returned_g.reshape(-1, returned_g.size(-1)),
                    returned_mus.reshape(-1, returned_mus.size(-1)),
                    returned_logvars.reshape(-1, returned_logvars.size(-1)))
                    # f_mu.reshape(-1, f_mu.size(-1)),
                    # f_logvar.reshape(-1, f_logvar.size(-1)))
        o = [item.reshape(torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple]
        # Option 1: one object mapped to 0
        # shape_o0 = o[0].shape
        # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:])
        # o[0][:,0] = o[0][:,0]*0
        # o[0] = o[0].reshape(*shape_o0)

        # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1) for item in o[:2]]

        # Test object decomposition
        # o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:])[:,0:1], dim=1) for item in o[:2]]
        o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]]

        o[3:5] = [item.reshape(bs, self.n_objects, *item.shape[1:]) for item in o[3:5]]

        o.append(A)
        o.append(B)

        # u = u.reshape(*u.shape[:2], self.u_dim, self.n_timesteps)[..., -1] #TODO:HANKEL This is for hankel view
        o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1]))
        o.append(u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-1:]))  # Append u_dist for the categorical case
        # o.append(u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-2:]))
        o.append(g_for_koop.reshape(bs * self.n_objects, -1, g_for_koop.shape[-1]))
        o.append(fit_err)

        return o
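
# The internals of KoopmanOperators.system_identify are not shown in these examples. Below is a
# plausible standalone sketch of a ridge-regularized least-squares fit of g_{t+1} ≈ A g_t + B u_t,
# with I_factor controlling the ridge term. The function name, the exact regularization scaling
# and the simplified shapes (no extra object axis) are assumptions, not the library's actual code.
import torch

def least_squares_system_id(G, H, U, I_factor=10.0):
    # G, H: (batch, T-1, g_dim) consecutive observations; U: (batch, T-1, u_dim) inputs
    GU = torch.cat([G, U], dim=-1)                       # (batch, T-1, g_dim + u_dim)
    d = GU.shape[-1]
    reg = torch.eye(d, device=G.device) / I_factor       # assumed ridge scaling
    # Solve (GU^T GU + reg) X = GU^T H batched over the leading dim; X stacks A^T over B^T
    X = torch.linalg.solve(GU.transpose(-1, -2) @ GU + reg,
                           GU.transpose(-1, -2) @ H)     # (batch, g_dim + u_dim, g_dim)
    A = X[:, :G.shape[-1]].transpose(-1, -2)             # (batch, g_dim, g_dim)
    B = X[:, G.shape[-1]:].transpose(-1, -2)             # (batch, g_dim, u_dim)
    fit_err = (GU @ X - H).pow(2).mean()
    return A, B, fit_err

g = torch.randn(4, 9, 8)   # toy observations: batch 4, 9 steps, g_dim 8
u = torch.randn(4, 8, 2)   # toy inputs aligned with the 8 transitions
A, B, err = least_squares_system_id(g[:, :-1], g[:, 1:], u)
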
Example No. 9
class RecKoopmanModel(BaseModel):
    def __init__(self,
                 in_channels,
                 feat_dim,
                 nf_particle,
                 nf_effect,
                 g_dim,
                 u_dim,
                 n_objects,
                 I_factor=10,
                 n_blocks=1,
                 psteps=1,
                 n_timesteps=1,
                 ngf=8,
                 image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        self.u_dim = u_dim

        # Temporal encoding buffers
        if n_timesteps > 1:
            t = torch.linspace(-1, 1, n_timesteps)
            # Add as constant, with extra dims for N and C
            self.register_buffer('t_grid', t)

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim
        self.n_objects = n_objects

        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

        feat_dyn_dim = feat_dim  # alternatively: feat_dim // 8
        self.linear_g = nn.Linear(g_dim, g_dim)
        self.linear_u = nn.Linear(g_dim + u_dim, u_dim * 2)
        # self.initial_conditions = nn.Sequential(nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps),
        #                                         nn.ReLU(),
        #                                         nn.Linear(feat_dim * n_timesteps, g_dim * 2))

        self.image_encoder = ImageEncoder(
            in_channels, feat_dim, n_objects, ngf,
            n_layers)  # feat_dim * 2 if sample here
        self.image_decoder = ImageDecoder(feat_dyn_dim, out_channels, ngf,
                                          n_layers)
        self.koopman = KoopmanOperators(feat_dim, nf_particle, nf_effect,
                                        g_dim, u_dim, n_timesteps, n_blocks)

    def _get_full_state(self, x, T):

        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(
                torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)],
                          dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=2)

        return new_x.reshape(-1, *new_x.shape[3:]), new_T

    def forward(self, input):
        bs, T, ch, h, w = input.shape

        f = self.image_encoder(input)
        # Concatenates the features from current time and previous to create full state
        f, T = self._get_full_state(f, T)

        f = f.view(torch.Size([bs * self.n_objects, T]) + f.size()[1:])

        f_mu, f_logvar, u_dist = [], [], []
        gs, us = [], []
        for t in range(T):
            if t == 0:
                prev_sample = torch.zeros_like(f[:, 0, :1].repeat(1, self.g_dim))
                prev_u_sample = torch.zeros_like(f[:, 0, :self.u_dim])

                # TODO:
                #  1: Prev_sample not necessary? Start from g(0)
                #  2: Sampling from features
                #  3: Bottleneck in f. Smaller than G. Constant part is big and routed directly.
                #       Variant part is very small and expanded to G.
                #  4: Clip gradient, clip A_inv, norm of the matrix: Spectral normalization
                #  5: Eigenvalues and Eigenvector. Fix one get the other.
                #  6: Half of features skip the koopman embedding
                #  7: Observation selector using Gumbel (should apply mask emissor and receptor
                #       and it should be the same across time)
                #  8: Should I write the conversation with Sandesh?
                #  Finally, check compositional koopman implementation

                # f_t = self.initial_conditions(f[:, t])
                f_t = torch.cat([f[:, t], prev_sample], dim=-1)
                g = self.koopman.to_g(f_t, self.psteps)
            else:
                f_t = torch.cat([f[:, t], prev_sample], dim=-1)
                g = self.koopman.to_g(f_t, self.psteps)
                # g, prev_hidden, u = self.koopman.to_g(f_t, prev_hidden) #,psteps=self.psteps

            g_mu, g_logvar = torch.chunk(g, 2, dim=-1)
            g = _sample_latent_simple(g_mu, g_logvar)
            g, q_y = self.linear_g(g), F.relu(
                self.linear_u(torch.cat([g, prev_u_sample], dim=-1)))

            q_y = q_y.view(q_y.size(0), self.u_dim, 2)
            u = F.gumbel_softmax(q_y, tau=1, hard=True)
            u = u[..., 0]
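            # Hard Gumbel-softmax over two categories per input dimension; keeping channel 0 yields a binary on/off gate u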

            prev_sample, prev_u_sample = g, u
            gs.append(g)
            us.append(u)

            f_mu.append(g_mu)
            f_logvar.append(g_logvar)

            u_dist.append(q_y)

        g = torch.stack(gs, dim=1)
        f_mu = torch.stack(f_mu, dim=1)
        f_logvar = torch.stack(f_logvar, dim=1)

        u = torch.stack(us, dim=1)
        u_dist = torch.stack(u_dist, dim=1)

        free_pred = 0
        if free_pred > 0:
            G_tilde = g[:, 1:-1 - free_pred,
                        None]  # new axis corresponding to N number of objects
            H_tilde = g[:, 2:-free_pred, None]
        else:
            G_tilde = g[:, 1:-1,
                        None]  # new axis corresponding to N number of objects
            H_tilde = g[:, 2:, None]

        # A, B, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, U=u[:, 1:-1], I_factor=self.I_factor)
        A, B, fit_err = self.koopman.fit_with_A(G=G_tilde,
                                                H=H_tilde,
                                                U=u[:, 1:-1],
                                                I_factor=self.I_factor)

        # A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(G=G_tilde, H=H_tilde, I_factor=self.I_factor)
        # TODO: Try simulating backwards
        # B=None
        # Rollout. From observation in time 0, predict with Koopman operator
        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B)
        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None)
        '''Simple version of the reverse rollout'''
        # G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A_pinv, B=None)
        # G_for_pred = torch.flip(G_for_pred_rev, dims=[1])

        G_for_pred = torch.cat([
            g[:, 0:1],
            self.koopman.simulate(T=T - 2, g=g[:, 1], u=u, A=A, B=B)
        ],
                               dim=1)

        # Option 1: use the koopman object decoder
        s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), psteps=self.psteps)
        # torch.cat([g[:, :1],G_for_pred], dim=1)
        s_for_pred = self.koopman.to_s(gcodes=_get_flat(G_for_pred),
                                       psteps=self.psteps)
        # Option 2: we don't use the koopman object decoder
        # s_for_rec = _get_flat(g)
        # s_for_pred = _get_flat(G_for_pred)

        # Convolutional decoder. Normally Spatial Broadcasting decoder
        out_rec = self.image_decoder(s_for_rec)
        out_pred = self.image_decoder(s_for_pred)

        returned_g = torch.cat([g, G_for_pred], dim=1)
        returned_mus = torch.cat([f_mu], dim=-1)
        returned_logvars = torch.cat([f_logvar], dim=-1)
        # returned_mus = torch.cat([f_mu, a_mu], dim=-1)
        # returned_logvars = torch.cat([f_logvar, a_logvar], dim=-1)

        o_touple = (out_rec, out_pred,
                    returned_g.reshape(-1, returned_g.size(-1)),
                    returned_mus.reshape(-1, returned_mus.size(-1)),
                    returned_logvars.reshape(-1, returned_logvars.size(-1)))
        # f_mu.reshape(-1, f_mu.size(-1)),
        # f_logvar.reshape(-1, f_logvar.size(-1)))
        o = [
            item.reshape(
                torch.Size([bs * self.n_objects, -1]) + item.size()[1:])
            for item in o_touple
        ]
        # Option 1: one object mapped to 0
        # shape_o0 = o[0].shape
        # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:])
        # o[0][:,0] = o[0][:,0]*0
        # o[0] = o[0].reshape(*shape_o0)

        # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1) for item in o[:2]]
        o[:2] = [
            torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1)
            for item in o[:2]
        ]
        o[3:5] = [
            item.reshape(bs, self.n_objects, *item.shape[1:])
            for item in o[3:5]
        ]

        o.append(A)
        o.append(B)
        o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1]))
        o.append(u_dist)  #.reshape(bs * self.n_objects, -1, u_dist.shape[-1]))
        # # Show images
        # plt.imshow(input[0, 0, :].reshape(16, 16).cpu().detach().numpy())
        # plt.savefig('test_attention.png')
        return o
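
# KoopmanOperators.simulate is also not shown in these examples; this is a minimal standalone
# sketch of the linear rollout g_{t+1} = A g_t + B u_t that the simulate(...) calls above rely on.
# The function name, argument layout and toy shapes are assumptions for illustration only.
import torch

def rollout(g0, u, A, B, T):
    # g0: (batch, g_dim) initial observation; u: (batch, >=T, u_dim) inputs
    # A: (batch, g_dim, g_dim) state-transition matrix; B: (batch, g_dim, u_dim) input matrix
    g, preds = g0, []
    for t in range(T):
        g = (torch.bmm(A, g.unsqueeze(-1)).squeeze(-1)
             + torch.bmm(B, u[:, t].unsqueeze(-1)).squeeze(-1))
        preds.append(g)
    return torch.stack(preds, dim=1)  # (batch, T, g_dim)

g0 = torch.randn(4, 8)
u = torch.randn(4, 6, 2)
A = 0.1 * torch.randn(4, 8, 8)
B = 0.1 * torch.randn(4, 8, 2)
print(rollout(g0, u, A, B, T=6).shape)  # torch.Size([4, 6, 8])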