Example no. 1
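
# Assumed context (imports are not shown in this excerpt): numpy as np, torch,
# torch.nn as nn, torch.nn.functional as F, torch.distributions.Normal, and the
# project-local modules/helpers BaseModel, ImageEncoder, ImageDecoder,
# ImageBroadcastDecoder, SpatialTransformation, KoopmanOperators, _get_flat and
# _sample_latent_simple.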
class RecKoopmanModel(BaseModel):
    def __init__(self,
                 in_channels,
                 feat_dim,
                 nf_particle,
                 nf_effect,
                 g_dim,
                 u_dim,
                 n_objects,
                 I_factor=10,
                 n_blocks=1,
                 psteps=1,
                 n_timesteps=1,
                 ngf=8,
                 image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        self.u_dim = u_dim

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim

        # feat_dyn_dim is currently fixed to 4 (the feat_dim // 8 heuristic is overridden).
        feat_dyn_dim = 4
        self.feat_dyn_dim = feat_dyn_dim
        self.feat_cte_dim = feat_dim - feat_dyn_dim
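        # feat_dim is split into a small dynamic part (fed to the Koopman module)
        # and a constant/appearance part (fed to the image decoder).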

        self.with_u = False
        self.n_iters = 1
        self.ini_alpha = 1
        # Note:
        #  - Keep the increment small. If alpha increases too fast, the gradients might be affected.
        self.incr_alpha = 0.1

        self.cte_resolution = (32, 32)
        self.ori_resolution = (128, 128)
        self.att_resolution = (16, 16)
        self.obj_resolution = (4, 4)
        # self.n_objects = reduce((lambda x, y: x * y), self.obj_resolution)
        self.n_objects = n_objects

        self.spatial_tf = SpatialTransformation(self.cte_resolution,
                                                self.ori_resolution)

        self.linear_f_cte_post = nn.Linear(2 * self.feat_cte_dim,
                                           2 * self.feat_cte_dim)
        self.linear_f_dyn_post = nn.Linear(2 * self.feat_dyn_dim,
                                           2 * self.feat_dyn_dim)
        self.bc_decoder = False
        if self.bc_decoder:
            self.image_decoder = ImageBroadcastDecoder(
                self.feat_cte_dim, out_channels,
                resolution=(16, 16))  # resolution=self.att_resolution
        else:
            self.image_decoder = ImageDecoder(self.feat_cte_dim,
                                              out_channels,
                                              dyn_dim=self.feat_dyn_dim)
        self.image_encoder = ImageEncoder(
            in_channels, 2 * self.feat_cte_dim, self.feat_dyn_dim,
            self.att_resolution, self.n_objects, ngf,
            n_layers)  # feat_dim * 2 if sample here
        self.koopman = KoopmanOperators(feat_dyn_dim, nf_particle, nf_effect,
                                        g_dim, u_dim, n_timesteps, n_blocks)

    def _get_full_state(self, x, T):

        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(
                torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)],
                          dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, new_x.shape[-1]), new_T

    def _get_full_state_hankel(self, x, T):
        '''
        :param x: features or observations
        :param T: number of time-steps before concatenation
        :return: Columns of a hankel matrix with self.n_timesteps rows.
        '''
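        # Shape sketch: x of shape (bs * n_objects, T, D) becomes
        # (bs * n_objects, new_T, D * n_timesteps); each output step t stacks
        # [x_t, ..., x_{t + n_timesteps - 1}], i.e. one column of a Hankel matrix.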
        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1

        x = x.reshape(-1, T, *x.shape[2:])
        new_x = []
        for t in range(new_T):
            new_x.append(
                torch.stack([x[:, t + idx] for idx in range(self.n_timesteps)],
                            dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)

        return new_x.reshape(-1, new_T,
                             new_x.shape[-2] * new_x.shape[-1]), new_T

    def forward(self, input, epoch=1):
        # Note: Add annealing in SPACE
        bs, T, ch, h, w = input.shape

        # Fraction of the sequence held out for free prediction (excluded from system identification)
        free_pred = T // 4
        returned_post = []

        input = input.cuda()
        # Backbone deterministic features
        f_bb = self.image_encoder(input, block='backbone')

        # Dynamic features
        T_inp = T
        f_dyn, shape, f_cte, confi = self.image_encoder(f_bb[:, :T_inp],
                                                        block='dyn_track')
        f_dyn = f_dyn.reshape(-1, f_dyn.shape[-1])

        # Sample dynamic features or reshape
        # Option 1: Don't sample
        f_dyn = f_dyn.reshape(bs * self.n_objects, T_inp, -1)
        # Option 2: Sample
        # f_mu_dyn, f_logvar_dyn = self.linear_f_dyn_post(f_dyn).reshape(bs, self.n_objects, T_inp, -1).chunk(2, -1)
        # f_dyn_post = Normal(f_mu_dyn, F.softplus(f_logvar_dyn))
        # f_dyn = f_dyn_post.rsample()
        # f_dyn = f_dyn.reshape(bs * self.n_objects, T_inp, -1)
        # returned_post.append(f_dyn_post)

        # Get delayed dynamic features
        f_dyn_state, T_inp = self._get_full_state_hankel(f_dyn, T_inp)

        # Get inputs from delayed dynamic features
        u, u_dist = self.koopman.to_u(f_dyn_state,
                                      temp=self.ini_alpha +
                                      epoch * self.incr_alpha,
                                      ignore=True)
        if not self.with_u:
            u = torch.zeros_like(u)
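            # With u zeroed, the input term effectively contributes nothing to identification or rollout.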
        # Get observations from delayed dynamic features
        g = self.koopman.to_g(
            f_dyn_state.reshape(bs * self.n_objects * T_inp, -1), self.psteps)
        g = g.reshape(bs * self.n_objects, T_inp, *g.shape[1:])

        # Get shifted observations for sys ID
        randperm = torch.arange(g.shape[0])  # No permutation
        # randperm = torch.randperm(g.shape[0]) # Random permutation
        if free_pred > 0:
            G_tilde = g[randperm, :-1 - free_pred, None]
            H_tilde = g[randperm, 1:-free_pred, None]
        else:
            G_tilde = g[randperm, :-1, None]
            H_tilde = g[randperm, 1:, None]
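
        # G_tilde/H_tilde pair each observation with its successor; the last
        # free_pred steps (if any) are held out of the system identification.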

        # Sys ID
        A, B, A_inv, fit_err = self.koopman.system_identify(
            G=G_tilde,
            H=H_tilde,
            U=u[randperm, :T_inp - free_pred - 1],
            I_factor=self.I_factor
        )  # Try not permuting U when the input is permuted
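        # Presumably A and B fit the linear model g_{t+1} ≈ A g_t + B u_t on the
        # (G_tilde, H_tilde) pairs (ridge-regularized by I_factor); fit_err is the residual of that fit.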

        # Rollout from start_step onwards.
        start_step = 2  # g and u must be aligned!!
        G_for_pred = self.koopman.simulate(T=T_inp - start_step - 1,
                                           g=g[:, start_step],
                                           u=u[:, start_step:],
                                           A=A,
                                           B=B)
        g_for_koop = G_for_pred

        assert f_bb[:, self.n_timesteps - 1:self.n_timesteps - 1 +
                    T_inp].shape[1] == f_bb[:, self.n_timesteps - 1:].shape[1]
        rec = {
            "obs": g,
            "confi": confi[:, :,
                           self.n_timesteps - 1:self.n_timesteps - 1 + T_inp],
            "shape": shape[:, :,
                           self.n_timesteps - 1:self.n_timesteps - 1 + T_inp],
            "f_cte": f_cte[:, :,
                           self.n_timesteps - 1:self.n_timesteps - 1 + T_inp],
            "T": T_inp,
            "name": "rec"
        }
        pred = {
            "obs": G_for_pred,
            "confi": confi[:, :, -G_for_pred.shape[1]:],
            "shape": shape[:, :, -G_for_pred.shape[1]:],
            "f_cte": f_cte[:, :, -G_for_pred.shape[1]:],
            "T": G_for_pred.shape[1],
            "name": "pred"
        }
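
        # "rec" decodes the encoder observations aligned to the last T_inp frames,
        # while "pred" decodes the rollout of the identified linear system.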
        outs = {}
        BBs = {}

        # TODO: Check if the indices for f_bb and supervision are correct.
        # Recover partial shape with decoded dynamical features. Iterate with new estimates of the appearance.
        # Note: This process could be iterative.
        for idx, case in enumerate([rec, pred]):

            case_name = case["name"]
            # get back dynamic features
            # if case_name == "rec":
            #     f_dyn_state = f_dyn[:, -T_inp:].reshape(-1, *f_dyn.shape[2:])
            # else:
            # TODO: Try with Initial f_dyn (no koop)
            f_dyn_state = self.koopman.to_s(gcodes=_get_flat(case["obs"]),
                                            psteps=self.psteps)
            # TODO: From f_dyn extract where, scale, etc. with only Y = WX + b. Review.

            # Initial noisy (low variance) constant vector.
            # Note:
            #  - Different realization for each time-step.
            #  - Is it a Variable?
            # f_cte_ini = torch.randn(bs * self.n_objects * case["T"], self.feat_cte_dim).to(f_dyn_state.device) #* self.f_cte_ini_std

            # Get full feature vector
            # f = torch.cat([ f_dyn_state,
            #                 f_cte_ini], dim=-1)

            # Get coarse features from which obtain queries and/or decode
            # f_coarse = self.image_decoder(f, block = 'coarse')
            # f_coarse = f_coarse.reshape(bs, self.n_objects, case["T"], *f_coarse.shape[1:])

            for _ in range(self.n_iters):

                # Get constant feature vector through attention
                # f_cte = self.image_encoder(case["backbone_features"], f_coarse, block='cte')

                # Encode with raw dynamic features
                # pose = self.pose_encoder(f_dyn_state)
                pose = f_dyn_state
                # M_center, M_relocate = get_affine_params_from_pose(pose, self.pose_limits, self.pose_bias)

                # Option 2: Previous impl
                # bb_feat = case["backbone_features"].unsqueeze(1).repeat_interleave(self.n_objects, dim=1)\
                #     .reshape(-1, *case["backbone_features"].shape[-4:]) # Repeat for number of objects
                # warped_bb_feat = (bb_feat, M_center, *self.cte_resolution) # Check all resolutions are coordinate [32, 32]
                # f_cte = self.image_encoder(warped_bb_feat, f_dyn_state, block='cte')

                # Sample cte features
                f_cte = case["f_cte"].reshape(bs * self.n_objects * case["T"],
                                              -1)
                confi = case["confi"].reshape(bs * self.n_objects * case["T"],
                                              -1)
                #

                f_mu_cte, f_logvar_cte = self.linear_f_cte_post(f_cte).reshape(
                    bs, self.n_objects, case["T"], -1).chunk(2, -1)
                f_cte_post = Normal(f_mu_cte, F.softplus(f_logvar_cte))
                f_cte = f_cte_post.rsample().reshape(
                    bs * self.n_objects * case["T"], -1)
                # Option 2: Previous impl
                # f_mu_cte, f_logvar_cte = self.linear_f_cte_post(f_cte).reshape(bs, self.n_objects, 1, -1).chunk(2, -1)
                # f_cte_post = Normal(f_mu_cte, F.softplus(f_logvar_cte))
                # f_cte = f_cte_post.rsample()
                # f_cte = f_cte.repeat_interleave(case["T"], dim=2)
                # f_cte = f_cte.reshape(bs * self.n_objects * case["T"], self.feat_cte_dim)

                # Register statistics
                returned_post.append(f_cte_post)
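                # The collected posteriors are returned with the outputs, presumably for a KL term in the loss.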

                # Get full feature vector
                # f = torch.cat([ f_dyn_state,
                #                 f_cte], dim=-1) # Note: Do if f_dyn_state is used in the appearance
                f = f_cte

            # Get output. Spatial broadcast decoder

            dec_obj = self.image_decoder(f, block='to_x')
            if not self.bc_decoder:
                grid, area = self.spatial_tf(confi, pose)
                outs[case_name], out_shape = self.spatial_tf.warp_and_render(
                    dec_obj, case["shape"], confi, grid)
            else:
                outs[case_name] = dec_obj * confi[..., None, None]
            # outs[case_name] = warp_affine(dec_obj, M_relocate, h=h, w=w)

            BBs[case_name] = dec_obj

        out_rec = outs["rec"]
        out_pred = outs["pred"]
        bb_rec = BBs["rec"]
        # bb_rec = BBs["pred"]

        # Test disentanglement - TO REVIEW
        # if random.random() < 0.1 or self.content is None:
        #     self.content = f_cte
        # f_cte = self.content
        ''' Returned variables '''
        returned_g = torch.cat([g, G_for_pred], dim=1)  # Observations

        o_touple = (out_rec, out_pred,
                    returned_g.reshape(-1, returned_g.size(-1)))
        o = [
            item.reshape(
                torch.Size([bs * self.n_objects, -1]) + item.size()[1:])
            for item in o_touple
        ]

        # Option 1: one object mapped to 0
        # shape_o0 = o[0].shape
        # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:])
        # o[0][:,0] = o[0][:,0]*0
        # o[0] = o[0].reshape(*shape_o0)

        # Sum across objects. Note: Add Clamp or Sigmoid.
        o[:2] = [
            torch.clamp(torch.sum(item.reshape(bs, self.n_objects,
                                               *item.shape[1:]),
                                  dim=1),
                        min=0,
                        max=1) for item in o[:2]
        ]
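        # Per-object renders are summed over the object axis and clamped to [0, 1]
        # to form the composite reconstruction and prediction frames.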
        # o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]]

        # TODO: output = {}
        #  output['rec'] = o[0]
        #  output['pred'] = o[1]

        o.append(returned_post)
        o.append(A)  # State transition matrix
        o.append(B)  # Input matrix
        o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1]))  # Inputs
        o.append(u_dist.reshape(bs * self.n_objects, -1,
                                u_dist.shape[-1]))  # Input distribution
        o.append(
            g_for_koop.reshape(
                bs * self.n_objects, -1,
                g_for_koop.shape[-1]))  # Observation propagated only with A
        o.append(fit_err)  # Fit error g to G_for_pred

        # bb_rec = bb_rec.reshape(torch.Size([bs * self.n_objects, -1]) + bb_rec.size()[1:])
        # bb_rec =torch.sum(bb_rec, dim=2, keepdim=True)
        # o.append(bb_rec.reshape(bs, self.n_objects, *bb_rec.shape[1:])) # Motion field reconstruction

        # TODO: return in dictionary form
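        # Output order: rec, pred, observations g, posterior list, A, B, u, u_dist, g_for_koop, fit_err.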
        return o
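

# A minimal usage sketch for the class above; the constructor values and tensor
# shapes below are illustrative assumptions, not taken from the original training code:
#   model = RecKoopmanModel(in_channels=1, feat_dim=32, nf_particle=64, nf_effect=64,
#                           g_dim=16, u_dim=4, n_objects=3, n_timesteps=3).cuda()
#   video = torch.randn(8, 12, 1, 64, 64).cuda()  # (bs, T, ch, h, w)
#   outs = model(video, epoch=0)
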
class RecKoopmanModel(BaseModel):
    def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim,
                 n_objects, I_factor=10, n_blocks=1, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        self.u_dim = u_dim

        # Temporal encoding buffers
        if n_timesteps > 1:
            t = torch.linspace(-1, 1, n_timesteps)
            # Add as a constant (note: the extra dims for N and C are not added here)
            self.register_buffer('t_grid', t)
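            # register_buffer keeps t_grid on the module's device and in its state_dict
            # without treating it as a learnable parameter.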

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim
        self.n_objects = n_objects

        # self.softmax = nn.Softmax(dim=-1)
        # self.sigmoid = nn.Sigmoid()

        # self.linear_g = nn.Linear(g_dim, g_dim)

        # self.initial_conditions = nn.Sequential(nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps),
        #                                         nn.ReLU(),
        #                                         nn.Linear(feat_dim * n_timesteps, g_dim * 2))

        feat_dyn_dim = feat_dim // 6
        self.feat_dyn_dim = feat_dyn_dim
        self.content = None
        # self.reverse = True
        self.reverse = False
        self.ini_alpha = 1
        self.incr_alpha = 0.25

        # self.rnn_f_cte = nn.LSTM(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim, 1, bias=False, batch_first=True)
        # self.rnn_f_cte = nn.GRU(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim, 1, bias=False, batch_first=True)

        # self.linear_f_mu = nn.Linear(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim)
        # self.linear_f_logvar = nn.Linear(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim)
        # self.linear_f_dyn_mu = nn.Linear(feat_dyn_dim, feat_dyn_dim)
        # self.linear_f_dyn_logvar = nn.Linear(feat_dyn_dim, feat_dyn_dim)
        self.linear_f_mu = nn.Linear(feat_dim, feat_dim)
        self.linear_f_logvar = nn.Linear(feat_dim, feat_dim)

        self.image_encoder = ImageEncoder(in_channels, feat_dim, n_objects, ngf, n_layers)  # feat_dim * 2 if sample here
        self.image_decoder = ImageDecoder(feat_dim, out_channels, ngf, n_layers)
        self.koopman = KoopmanOperators(feat_dyn_dim, nf_particle, nf_effect, g_dim, u_dim, n_timesteps, n_blocks)

    def _get_full_state(self, x, T):

        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1
        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.cat([x[:, t + idx]
                         for idx in range(self.n_timesteps)], dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, new_x.shape[-1]), new_T

    def _get_full_state_hankel(self, x, T):
        '''
        :param x: features or observations
        :param T: number of time-steps before concatenation
        :return: Columns of a hankel matrix with self.n_timesteps rows.
        '''
        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1

        x = x.reshape(-1, T, *x.shape[2:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.stack([x[:, t + idx]
                                    for idx in range(self.n_timesteps)], dim=-1))
        # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)

        return new_x.reshape(-1, new_T, new_x.shape[-2] * new_x.shape[-1]), new_T


    def forward(self, input, epoch=1):
        bs, T, ch, h, w = input.shape
        free_pred = T//4
        f = self.image_encoder(input) # (bs * n_objects * T, feat_dim)

        # Option 1: Mean before sampling
        # f_cte = f.reshape(bs * self.n_objects, T, *f.shape[1:])\
        #         [..., self.feat_dyn_dim:]
        # f_cte = f_cte.mean(1)
        # # LSTM to obtain cte
        # # h0 = torch.zeros_like(f_cte[:, 0])[None]
        # # c0 = torch.zeros_like(f_cte[:, 0])[None]
        # # f_cte, _ = self.rnn_f_cte(f_cte[:, :-free_pred], (h0, c0))
        # # f_cte = f_cte[:, -1]
        # f_mu_cte, f_logvar_cte = self.linear_f_mu(f_cte)[:, None].repeat(1, T, 1).reshape(bs * self.n_objects * T, -1), \
        #                          self.linear_f_logvar(f_cte)[:, None].repeat(1, T, 1).reshape(bs * self.n_objects * T, -1)
        # f_cte =  _sample_latent_simple(f_mu_cte, f_logvar_cte)
        # #
        # f_dyn = f[..., :self.feat_dyn_dim]
        # f_mu_dyn, f_logvar_dyn = self.linear_f_dyn_mu(f_dyn), \
        #                          self.linear_f_dyn_logvar(f_dyn)
        # f_dyn = _sample_latent_simple(f_mu_dyn, f_logvar_dyn)
        # #
        # f_mu = torch.cat([f_mu_dyn, f_mu_cte], dim=-1)
        # f_logvar = torch.cat([f_logvar_dyn, f_logvar_cte], dim=-1)

        # Option 2: Sampling all together
        f_mu = self.linear_f_mu(f)
        f_logvar = self.linear_f_logvar(f)
        f = _sample_latent_simple(f_mu, f_logvar)
        # Note: Split features (cte, dyn)
        f_dyn, f_cte = f[..., :self.feat_dyn_dim], f[..., self.feat_dyn_dim:]
        f_cte = f_cte.reshape(bs * self.n_objects, T, *f_cte.shape[1:])

        # Test disentanglement
        # if random.random() < 0.1 or self.content is None:
        #     self.content = f_cte
        # f_cte = self.content

        f_dyn = f_dyn.reshape(bs * self.n_objects, T, *f_dyn.shape[1:])
        f_dyn, T = self._get_full_state_hankel(f_dyn, T)

        # TODO: U might also depend on features from the scene. Residual features (slot n+1 / Background)
        u, u_dist = self.koopman.to_u(f_dyn, temp=self.ini_alpha + epoch * self.incr_alpha)
        g = self.koopman.to_g(f_dyn.reshape(bs * self.n_objects * T, -1), self.psteps)
        g = g.reshape(bs * self.n_objects, T, *g.shape[1:])
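        # g holds the Koopman observables, one vector per object and (delayed) time step.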

        # TODO:
        #  0: Put everything into the varying part and move on.
        #  ...
        #  0: state as velocity acceleration, pose
        #  0: g fitting error.
        #  0: Hankel view. F**k koopman for now, it has too many degrees of freedom. Subtract the input to minimize the rank of H.
        #  1: Invert or Sample randomly u and maximize the error in reconstruction.
        #  2: Treat n_timesteps with conv_nn? Or does it make sense to mix f1(t) and f2(t-1)?
        #  3: We don't impose A symmetric, but we impose that g in both directions use the same A and B.
        #       Exploit symmetry. If there's an input in a sequence, there will be the same input in the reversed sequence.
        #       B is valid for both. Then we can't cross, unless we recalculate B. It might not be necessary because we use
        #       the same A for both directions. INVERT U's temporally --> that's a great idea. Discard (n_ini_timesteps - 1) samples at each end. Is this correct?
        #       Is it the same linear input intervention if I do it forward and backward?
        #  5: Eigenvalues and Eigenvector. Fix one get the other.
        #  6: Observation selector using Gumbel (should apply mask emissor and receptor
        #       and it should be the same across time)
        #  7: Attention mechanism to select - koopman (sub-koopman) / Identity. It will be useful for the future.
        #  8: When works, formalize code. Hyperparameters through config.yaml
        #  9: The nonlinear projection should also separate the objects. We can separate by the svd? Is it possible?
        # 1: Prev_sample not necessary? Start from g(0)
        # 2: Sampling from features
        # 3: Bottleneck in f. Smaller than G. Constant part is big and routed directly.
        #       Variant part is very small and expanded to G.
        # 4: In evaluation, sample features several times but only keep one of the cte. Check if it captures the appearance
        # 5: Fix B with A learned. If the first doesn't require input, B will be mapped to 0.

        # The `or True` makes this branch unconditional: no permutation is applied for now.
        randperm = torch.arange(g.shape[0]) if self.reverse or True \
            else torch.randperm(g.shape[0])

        # Note: Inverting u(t) along the time axis
        # TODO: still pending.


        # if free_pred > 0:
        #     G_tilde = g[randperm, self.n_timesteps -1:-1-free_pred, None]
        #     H_tilde = g[randperm, self.n_timesteps:-free_pred, None]
        # else:
        #     G_tilde = g[randperm, self.n_timesteps -1:-1, None]
        #     H_tilde = g[randperm, self.n_timesteps:, None]

        if free_pred > 0:
            G_tilde = g[randperm, :-1-free_pred, None]
            H_tilde = g[randperm, 1:-free_pred, None]
        else:
            G_tilde = g[randperm, :-1, None]
            H_tilde = g[randperm, 1:, None]

        # TODO: If we identify with half of the timesteps, but use all inputs for rollout,
        #  we might get something. We can also predict only the future.
        # Note: Not permuting the input, but permuting g.
        A, B, A_inv, fit_err = self.koopman.system_identify(
            G=G_tilde, H=H_tilde, U=u[randperm, :-1 - free_pred], I_factor=self.I_factor)
        # A, B, A_inv, fit_err = self.koopman.fit_with_A(G=G_tilde, H=H_tilde, U=u[randperm, :-1], I_factor=self.I_factor)
        # A, B, A_inv, fit_err = self.koopman.fit_with_B(G=G_tilde, H=H_tilde, U=u[randperm, :-1-free_pred], I_factor=self.I_factor)
        # A, B = self.koopman.fit_with_AB(G_tilde.shape[0])
        # A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(G=G_tilde, H=H_tilde, I_factor=self.I_factor)
        # B = None

        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None)
        # G_for_pred = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:,1:self.n_timesteps, :, -1],
        #                         self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=u[:, self.n_timesteps -1:], A=A, B=B)], dim=1)
        # g_for_koop = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:,1:self.n_timesteps, :, -1],
        #                         self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=None, A=A, B=None)], dim=1)
        # G_for_pred = self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=u[:, self.n_timesteps -1:], A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=None, A=A, B=None)

        # G_for_pred = self.koopman.simulate(T=T - 4, g=g[:, 3], u=u, A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - 4, g=g[:, 3], u=None, A=A, B=None)
        g_start, T_start = g[:, 2], T - 3
        # TODO: ATTENTION. U MIGHT NOT BE ALIGNED TEMPORALLY!
        # Assumed alignment: the rollout starts from g[:, 2], so the inputs are sliced from step 2 as well.
        G_for_pred = self.koopman.simulate(T=T_start, g=g_start, u=u[:, 2:], A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - 2, g=g[:, 1], u=None, A=A, B=None)
        g_for_koop = G_for_pred


        # g_for_koop = G_for_pred

        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None)
        '''Simple version of the reverse rollout'''
        # G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A_pinv, B=None)
        # G_for_pred = torch.flip(G_for_pred_rev, dims=[1])
        # G_for_pred = torch.cat([g[:,0:1],self.koopman.simulate(T=T - 2, g=g[:, 1], u=u, A=A, B=B)], dim=1)


        s_for_rec = self.koopman.to_s(gcodes=_get_flat(g),
                                      psteps=self.psteps)
        s_for_pred = self.koopman.to_s(gcodes=_get_flat(G_for_pred),
                                       psteps=self.psteps)

        # Note: Split features (cte, dyn). In case of reverse, f_cte averages for both directions.
        # Note 2: TODO: Reconstruction could be in reversed g's or both!
        # Option 1: Temporally permute appearance
        # T_randperm =  torch.randperm(g.shape[1]) # torch.arange(g.shape[1])
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, -T:]], dim=-1)
        # # f_cte = f_cte[:, T_randperm]
        # f_cte = f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:]], dim=-1)

        # T_randperm =  torch.randperm(g.shape[1]) # torch.arange(g.shape[1])
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, -T:]], dim=-1)
        # # f_cte = f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)
        # f_cte = f_cte[:, T_randperm]
        #
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:]], dim=-1)
        # Option 2: Average appearance
        s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
                              f_cte[:, None].mean(2).repeat(1, T, 1)], dim=-1)
        s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
                              f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)], dim=-1)
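        # The appearance code is averaged over time and tiled across both the
        # reconstruction window and the prediction horizon.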
        # Option 3: Copy the first appearance (or a random one)
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, 0:1].repeat(1, T, 1)], dim=-1)
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:-G_for_pred.shape[1]+1].repeat(1, G_for_pred.shape[1], 1)], dim=-1)

        # Option 4: If f_cte has been computed before.
        # f_cte = f_cte.reshape(bs * self.n_objects, -1, f_cte.shape[-1])
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, -T:]], dim=-1)
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:]], dim=-1)
        #
        # s_for_rec = _get_flat(s_for_rec)
        # s_for_pred = _get_flat(s_for_pred)

        # NOTE: Sample after
        # f_mu = self.linear_f_mu(s_for_rec)
        # f_logvar = self.linear_f_logvar(s_for_rec)
        # s_for_rec = _sample_latent_simple(f_mu, f_logvar)
        # #
        # f_mu_pred = self.linear_f_mu(s_for_pred)
        # f_logvar_pred = self.linear_f_logvar(s_for_pred)
        # s_for_pred = _sample_latent_simple(f_mu_pred, f_logvar_pred)
        # f_mu, f_mu_pred = f_mu.reshape(bs * self.n_objects, T, -1), f_mu_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1)
        # f_logvar, f_logvar_pred = f_logvar.reshape(bs * self.n_objects, T, -1), f_logvar_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1)


        # Convolutional decoder. Normally Spatial Broadcasting decoder
        out_rec = self.image_decoder(s_for_rec)
        out_pred = self.image_decoder(s_for_pred)

        returned_g = torch.cat([g, G_for_pred], dim=1)

        returned_mus = torch.cat([f_mu], dim=1)
        returned_logvars = torch.cat([f_logvar], dim=1)
        # returned_mus = torch.cat([f_mu, f_mu_pred], dim=1)
        # returned_logvars = torch.cat([f_logvar, f_logvar_pred], dim=1)

        o_touple = (out_rec, out_pred, returned_g.reshape(-1, returned_g.size(-1)),
                    returned_mus.reshape(-1, returned_mus.size(-1)),
                    returned_logvars.reshape(-1, returned_logvars.size(-1)))
                    # f_mu.reshape(-1, f_mu.size(-1)),
                    # f_logvar.reshape(-1, f_logvar.size(-1)))
        o = [item.reshape(torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple]
        # Option 1: one object mapped to 0
        # shape_o0 = o[0].shape
        # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:])
        # o[0][:,0] = o[0][:,0]*0
        # o[0] = o[0].reshape(*shape_o0)

        # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1) for item in o[:2]]

        # Test object decomposition
        # o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:])[:,0:1], dim=1) for item in o[:2]]
        o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]]

        o[3:5] = [item.reshape(bs, self.n_objects, *item.shape[1:]) for item in o[3:5]]

        o.append(A)
        o.append(B)

        # u = u.reshape(*u.shape[:2], self.u_dim, self.n_timesteps)[..., -1] #TODO:HANKEL This is for hankel view
        o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1]))
        o.append(u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-1:]))  # Append u_dist for the categorical distribution
        # o.append(u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-2:]))
        o.append(g_for_koop.reshape(bs * self.n_objects, -1, g_for_koop.shape[-1]))
        o.append(fit_err)
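
        # Output order: out_rec, out_pred, g, mus, logvars, A, B, u, u_dist, g_for_koop, fit_err.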

        return o