class RecKoopmanModel(BaseModel):
    def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim, n_objects,
                 I_factor=10, n_blocks=1, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        self.u_dim = u_dim

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim

        feat_dyn_dim = feat_dim // 8
        feat_dyn_dim = 4  # Hard-coded override of the value above
        self.feat_dyn_dim = feat_dyn_dim
        self.feat_cte_dim = feat_dim - feat_dyn_dim

        self.with_u = False
        self.n_iters = 1

        self.ini_alpha = 1
        # Note:
        # - I leave it at 0 for now. If it increases too fast, the gradients might be affected.
        self.incr_alpha = 0.1

        self.cte_resolution = (32, 32)
        self.ori_resolution = (128, 128)
        self.att_resolution = (16, 16)
        self.obj_resolution = (4, 4)
        # self.n_objects = reduce((lambda x, y: x * y), self.obj_resolution)
        self.n_objects = n_objects

        self.spatial_tf = SpatialTransformation(self.cte_resolution, self.ori_resolution)

        self.linear_f_cte_post = nn.Linear(2 * self.feat_cte_dim, 2 * self.feat_cte_dim)
        self.linear_f_dyn_post = nn.Linear(2 * self.feat_dyn_dim, 2 * self.feat_dyn_dim)

        self.bc_decoder = False
        if self.bc_decoder:
            self.image_decoder = ImageBroadcastDecoder(self.feat_cte_dim, out_channels,
                                                       resolution=(16, 16))  # resolution=self.att_resolution
        else:
            self.image_decoder = ImageDecoder(self.feat_cte_dim, out_channels, dyn_dim=self.feat_dyn_dim)
        self.image_encoder = ImageEncoder(in_channels, 2 * self.feat_cte_dim, self.feat_dyn_dim,
                                          self.att_resolution, self.n_objects, ngf, n_layers)  # feat_dim * 2 if sample here
        self.koopman = KoopmanOperators(feat_dyn_dim, nf_particle, nf_effect, g_dim, u_dim, n_timesteps, n_blocks)

    def _get_full_state(self, x, T):
        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1

        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1))
            # torch.cat([torch.zeros_like(x[:, 0, 0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, new_x.shape[-1]), new_T

    def _get_full_state_hankel(self, x, T):
        '''
        :param x: features or observations
        :param T: number of time-steps before concatenation
        :return: Columns of a Hankel matrix with self.n_timesteps rows.
        '''
        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1

        x = x.reshape(-1, T, *x.shape[2:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.stack([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1))
            # torch.cat([torch.zeros_like(x[:, 0, 0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, new_T, new_x.shape[-2] * new_x.shape[-1]), new_T

    def forward(self, input, epoch=1):
        # Note: Add annealing in SPACE
        bs, T, ch, h, w = input.shape
        # Number of trailing steps held out from system identification for free prediction
        free_pred = T // 4
        returned_post = []
        input = input.cuda()

        # Backbone deterministic features
        f_bb = self.image_encoder(input, block='backbone')

        # Dynamic features
        T_inp = T
        f_dyn, shape, f_cte, confi = self.image_encoder(f_bb[:, :T_inp], block='dyn_track')
        f_dyn = f_dyn.reshape(-1, f_dyn.shape[-1])

        # Sample dynamic features or reshape
        # Option 1: Don't sample
        f_dyn = f_dyn.reshape(bs * self.n_objects, T_inp, -1)
        # Option 2: Sample
        # f_mu_dyn, f_logvar_dyn = self.linear_f_dyn_post(f_dyn).reshape(bs, self.n_objects, T_inp, -1).chunk(2, -1)
        # f_dyn_post = Normal(f_mu_dyn, F.softplus(f_logvar_dyn))
        # f_dyn = f_dyn_post.rsample()
        # f_dyn = f_dyn.reshape(bs * self.n_objects, T_inp, -1)
        # returned_post.append(f_dyn_post)

        # Get delayed dynamic features
        f_dyn_state, T_inp = self._get_full_state_hankel(f_dyn, T_inp)

        # Get inputs from delayed dynamic features
        u, u_dist = self.koopman.to_u(f_dyn_state, temp=self.ini_alpha + epoch * self.incr_alpha, ignore=True)
        if not self.with_u:
            u = torch.zeros_like(u)

        # Get observations from delayed dynamic features
        g = self.koopman.to_g(f_dyn_state.reshape(bs * self.n_objects * T_inp, -1), self.psteps)
        g = g.reshape(bs * self.n_objects, T_inp, *g.shape[1:])

        # Get shifted observations for sys ID
        randperm = torch.arange(g.shape[0])  # No permutation
        # randperm = torch.randperm(g.shape[0])  # Random permutation
        if free_pred > 0:
            G_tilde = g[randperm, :-1 - free_pred, None]
            H_tilde = g[randperm, 1:-free_pred, None]
        else:
            G_tilde = g[randperm, :-1, None]
            H_tilde = g[randperm, 1:, None]

        # Sys ID
        A, B, A_inv, fit_err = self.koopman.system_identify(
            G=G_tilde, H=H_tilde, U=u[randperm, :T_inp - free_pred - 1],
            I_factor=self.I_factor)  # Try not permuting U when the input is permuted

        # Rollout from start_step onwards.
        start_step = 2  # g and u must be aligned!!
        G_for_pred = self.koopman.simulate(T=T_inp - start_step - 1, g=g[:, start_step],
                                           u=u[:, start_step:], A=A, B=B)
        g_for_koop = G_for_pred

        assert f_bb[:, self.n_timesteps - 1:self.n_timesteps - 1 + T_inp].shape[1] == \
               f_bb[:, self.n_timesteps - 1:].shape[1]

        rec = {"obs": g,
               "confi": confi[:, :, self.n_timesteps - 1:self.n_timesteps - 1 + T_inp],
               "shape": shape[:, :, self.n_timesteps - 1:self.n_timesteps - 1 + T_inp],
               "f_cte": f_cte[:, :, self.n_timesteps - 1:self.n_timesteps - 1 + T_inp],
               "T": T_inp,
               "name": "rec"}
        pred = {"obs": G_for_pred,
                "confi": confi[:, :, -G_for_pred.shape[1]:],
                "shape": shape[:, :, -G_for_pred.shape[1]:],
                "f_cte": f_cte[:, :, -G_for_pred.shape[1]:],
                "T": G_for_pred.shape[1],
                "name": "pred"}

        outs = {}
        BBs = {}

        # TODO: Check if the indices for f_bb and supervision are correct.
        # Recover partial shape with decoded dynamical features. Iterate with new estimates of the appearance.
        # Note: This process could be iterative.
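        # Index bookkeeping (added note, assuming confi/shape/f_cte are indexed as (bs, n_objects, T, ...)):
        # after the Hankel stacking the first valid frame is n_timesteps - 1, so "rec" is aligned with
        # frames [n_timesteps - 1, n_timesteps - 1 + T_inp), while "pred" keeps only the last
        # G_for_pred.shape[1] frames, i.e. the ones the rollout starting at start_step actually covers.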
        for idx, case in enumerate([rec, pred]):
            case_name = case["name"]

            # Get back dynamic features
            # if case_name == "rec":
            #     f_dyn_state = f_dyn[:, -T_inp:].reshape(-1, *f_dyn.shape[2:])
            # else:
            # TODO: Try with initial f_dyn (no koop)
            f_dyn_state = self.koopman.to_s(gcodes=_get_flat(case["obs"]), psteps=self.psteps)

            # TODO: From f_dyn extract where, scale, etc. with only Y = WX + b. Review.

            # Initial noisy (low variance) constant vector.
            # Note:
            #  - Different realization for each time-step.
            #  - Is it a Variable?
            # f_cte_ini = torch.randn(bs * self.n_objects * case["T"], self.feat_cte_dim).to(f_dyn_state.device)  # * self.f_cte_ini_std

            # Get full feature vector
            # f = torch.cat([f_dyn_state,
            #                f_cte_ini], dim=-1)

            # Get coarse features from which to obtain queries and/or decode
            # f_coarse = self.image_decoder(f, block='coarse')
            # f_coarse = f_coarse.reshape(bs, self.n_objects, case["T"], *f_coarse.shape[1:])

            for _ in range(self.n_iters):
                # Get constant feature vector through attention
                # f_cte = self.image_encoder(case["backbone_features"], f_coarse, block='cte')

                # Encode with raw dynamic features
                # pose = self.pose_encoder(f_dyn_state)
                pose = f_dyn_state
                # M_center, M_relocate = get_affine_params_from_pose(pose, self.pose_limits, self.pose_bias)

                # Option 2: Previous impl
                # bb_feat = case["backbone_features"].unsqueeze(1).repeat_interleave(self.n_objects, dim=1)\
                #     .reshape(-1, *case["backbone_features"].shape[-4:])  # Repeat for number of objects
                # warped_bb_feat = (bb_feat, M_center, *self.cte_resolution)  # Check all resolutions are [32, 32]
                # f_cte = self.image_encoder(warped_bb_feat, f_dyn_state, block='cte')

                # Sample cte features
                f_cte = case["f_cte"].reshape(bs * self.n_objects * case["T"], -1)
                confi = case["confi"].reshape(bs * self.n_objects * case["T"], -1)

                f_mu_cte, f_logvar_cte = self.linear_f_cte_post(f_cte)\
                    .reshape(bs, self.n_objects, case["T"], -1).chunk(2, -1)
                f_cte_post = Normal(f_mu_cte, F.softplus(f_logvar_cte))
                f_cte = f_cte_post.rsample().reshape(bs * self.n_objects * case["T"], -1)

                # Option 2: Previous impl
                # f_mu_cte, f_logvar_cte = self.linear_f_cte_post(f_cte).reshape(bs, self.n_objects, 1, -1).chunk(2, -1)
                # f_cte_post = Normal(f_mu_cte, F.softplus(f_logvar_cte))
                # f_cte = f_cte_post.rsample()
                # f_cte = f_cte.repeat_interleave(case["T"], dim=2)
                # f_cte = f_cte.reshape(bs * self.n_objects * case["T"], self.feat_cte_dim)

                # Register statistics
                returned_post.append(f_cte_post)

                # Get full feature vector
                # f = torch.cat([f_dyn_state,
                #                f_cte], dim=-1)  # Note: Do this if f_dyn_state is used in the appearance
                f = f_cte
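                # Note (added): Normal(f_mu_cte, F.softplus(f_logvar_cte)).rsample() above draws a
                # reparameterized sample, so gradients flow back through both halves of the linear
                # output; softplus is used directly as the positive scale rather than exp(0.5 * logvar).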
                # Get output with the spatial broadcast decoder
                dec_obj = self.image_decoder(f, block='to_x')

                if not self.bc_decoder:
                    grid, area = self.spatial_tf(confi, pose)
                    outs[case_name], out_shape = self.spatial_tf.warp_and_render(dec_obj, case["shape"], confi, grid)
                else:
                    outs[case_name] = dec_obj * confi[..., None, None]
                # outs[case_name] = warp_affine(dec_obj, M_relocate, h=h, w=w)
                BBs[case_name] = dec_obj

        out_rec = outs["rec"]
        out_pred = outs["pred"]
        bb_rec = BBs["rec"]
        # bb_rec = BBs["pred"]

        # Test disentanglement - TO REVIEW
        # if random.random() < 0.1 or self.content is None:
        #     self.content = f_cte
        # f_cte = self.content

        ''' Returned variables '''
        returned_g = torch.cat([g, G_for_pred], dim=1)

        # Observations
        o_touple = (out_rec, out_pred, returned_g.reshape(-1, returned_g.size(-1)))
        o = [item.reshape(torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple]

        # Option 1: one object mapped to 0
        # shape_o0 = o[0].shape
        # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:])
        # o[0][:, 0] = o[0][:, 0] * 0
        # o[0] = o[0].reshape(*shape_o0)

        # Sum across objects. Note: Add Clamp or Sigmoid.
        o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1)
                 for item in o[:2]]
        # o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]]

        # TODO:
        # output = {}
        # output['rec'] = o[0]
        # output['pred'] = o[1]

        o.append(returned_post)
        o.append(A)  # State transition matrix
        o.append(B)  # Input matrix
        o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1]))  # Inputs
        o.append(u_dist.reshape(bs * self.n_objects, -1, u_dist.shape[-1]))  # Input distribution
        o.append(g_for_koop.reshape(bs * self.n_objects, -1, g_for_koop.shape[-1]))  # Observations propagated only with A
        o.append(fit_err)  # Fit error from g to G_for_pred

        # bb_rec = bb_rec.reshape(torch.Size([bs * self.n_objects, -1]) + bb_rec.size()[1:])
        # bb_rec = torch.sum(bb_rec, dim=2, keepdim=True)
        # o.append(bb_rec.reshape(bs, self.n_objects, *bb_rec.shape[1:]))  # Motion field reconstruction

        # TODO: return in dictionary form
        return o
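
# Usage sketch for the class above (illustrative only; the hyperparameter values and shapes
# below are assumptions, and the encoder/decoder/Koopman modules referenced in __init__ must
# be importable for the model to build):
#
#   model = RecKoopmanModel(in_channels=1, feat_dim=32, nf_particle=64, nf_effect=64,
#                           g_dim=16, u_dim=4, n_objects=3, n_timesteps=3,
#                           image_size=[64, 64]).cuda()
#   video = torch.rand(2, 12, 1, 64, 64)   # (bs, T, ch, h, w); moved to GPU inside forward()
#   outs = model(video, epoch=0)
#   # outs = [rec, pred, g, posteriors, A, B, u, u_dist, g_for_koop, fit_err]
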
class RecKoopmanModel(BaseModel):
    def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim, n_objects,
                 I_factor=10, n_blocks=1, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]):
        super().__init__()
        out_channels = 1
        n_layers = int(np.log2(image_size[0])) - 1

        self.u_dim = u_dim

        # Temporal encoding buffers
        if n_timesteps > 1:
            t = torch.linspace(-1, 1, n_timesteps)
            # Add as constant, with extra dims for N and C
            self.register_buffer('t_grid', t)

        # Set state dim with config, depending on how many time-steps we want to take into account
        self.image_size = image_size
        self.n_timesteps = n_timesteps
        self.state_dim = feat_dim
        self.I_factor = I_factor
        self.psteps = psteps
        self.g_dim = g_dim
        self.n_objects = n_objects

        # self.softmax = nn.Softmax(dim=-1)
        # self.sigmoid = nn.Sigmoid()
        # self.linear_g = nn.Linear(g_dim, g_dim)
        # self.initial_conditions = nn.Sequential(nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps),
        #                                         nn.ReLU(),
        #                                         nn.Linear(feat_dim * n_timesteps, g_dim * 2))

        feat_dyn_dim = feat_dim // 6
        self.feat_dyn_dim = feat_dyn_dim
        self.content = None

        # self.reverse = True
        self.reverse = False

        self.ini_alpha = 1
        self.incr_alpha = 0.25

        # self.rnn_f_cte = nn.LSTM(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim, 1, bias=False, batch_first=True)
        # self.rnn_f_cte = nn.GRU(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim, 1, bias=False, batch_first=True)
        # self.linear_f_mu = nn.Linear(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim)
        # self.linear_f_logvar = nn.Linear(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim)
        # self.linear_f_dyn_mu = nn.Linear(feat_dyn_dim, feat_dyn_dim)
        # self.linear_f_dyn_logvar = nn.Linear(feat_dyn_dim, feat_dyn_dim)
        self.linear_f_mu = nn.Linear(feat_dim, feat_dim)
        self.linear_f_logvar = nn.Linear(feat_dim, feat_dim)

        self.image_encoder = ImageEncoder(in_channels, feat_dim, n_objects, ngf, n_layers)  # feat_dim * 2 if sample here
        self.image_decoder = ImageDecoder(feat_dim, out_channels, ngf, n_layers)
        self.koopman = KoopmanOperators(feat_dyn_dim, nf_particle, nf_effect, g_dim, u_dim, n_timesteps, n_blocks)

    def _get_full_state(self, x, T):
        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1

        x = x.reshape(-1, T, *x.shape[1:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1))
            # torch.cat([torch.zeros_like(x[:, 0, 0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, new_x.shape[-1]), new_T

    def _get_full_state_hankel(self, x, T):
        '''
        :param x: features or observations
        :param T: number of time-steps before concatenation
        :return: Columns of a Hankel matrix with self.n_timesteps rows.
        '''
        if self.n_timesteps < 2:
            return x, T
        new_T = T - self.n_timesteps + 1

        x = x.reshape(-1, T, *x.shape[2:])
        new_x = []
        for t in range(new_T):
            new_x.append(torch.stack([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1))
            # torch.cat([torch.zeros_like(x[:, 0, 0:1]) + self.t_grid[idx]], dim=-1)
        new_x = torch.stack(new_x, dim=1)
        return new_x.reshape(-1, new_T, new_x.shape[-2] * new_x.shape[-1]), new_T

    def forward(self, input, epoch=1):
        bs, T, ch, h, w = input.shape
        free_pred = T // 4

        f = self.image_encoder(input)  # (bs * n_objects * T, feat_dim)

        # Option 1: Mean before sampling
        # f_cte = f.reshape(bs * self.n_objects, T, *f.shape[1:])[..., self.feat_dyn_dim:]
        # f_cte = f_cte.mean(1)
        # # LSTM to obtain cte
        # # h0 = torch.zeros_like(f_cte[:, 0])[None]
        # # c0 = torch.zeros_like(f_cte[:, 0])[None]
        # # f_cte, _ = self.rnn_f_cte(f_cte[:, :-free_pred], (h0, c0))
        # # f_cte = f_cte[:, -1]
        # f_mu_cte = self.linear_f_mu(f_cte)[:, None].repeat(1, T, 1).reshape(bs * self.n_objects * T, -1)
        # f_logvar_cte = self.linear_f_logvar(f_cte)[:, None].repeat(1, T, 1).reshape(bs * self.n_objects * T, -1)
        # f_cte = _sample_latent_simple(f_mu_cte, f_logvar_cte)
        #
        # f_dyn = f[..., :self.feat_dyn_dim]
        # f_mu_dyn, f_logvar_dyn = self.linear_f_dyn_mu(f_dyn), self.linear_f_dyn_logvar(f_dyn)
        # f_dyn = _sample_latent_simple(f_mu_dyn, f_logvar_dyn)
        #
        # f_mu = torch.cat([f_mu_dyn, f_mu_cte], dim=-1)
        # f_logvar = torch.cat([f_logvar_dyn, f_logvar_cte], dim=-1)

        # Option 2: Sampling all together
        f_mu = self.linear_f_mu(f)
        f_logvar = self.linear_f_logvar(f)
        f = _sample_latent_simple(f_mu, f_logvar)

        # Note: Split features (cte, dyn)
        f_dyn, f_cte = f[..., :self.feat_dyn_dim], f[..., self.feat_dyn_dim:]
        f_cte = f_cte.reshape(bs * self.n_objects, T, *f_cte.shape[1:])

        # Test disentanglement
        # if random.random() < 0.1 or self.content is None:
        #     self.content = f_cte
        # f_cte = self.content

        f_dyn = f_dyn.reshape(bs * self.n_objects, T, *f_dyn.shape[1:])
        f_dyn, T = self._get_full_state_hankel(f_dyn, T)

        # TODO: U might also depend on features from the scene. Residual features (slot n+1 / background)
        u, u_dist = self.koopman.to_u(f_dyn, temp=self.ini_alpha + epoch * self.incr_alpha)

        g = self.koopman.to_g(f_dyn.reshape(bs * self.n_objects * T, -1), self.psteps)
        g = g.reshape(bs * self.n_objects, T, *g.shape[1:])

        # TODO:
        # 0: I'm putting everything into the time-varying part for now and be done with it.
        # ...
        # 0: State as velocity, acceleration, pose.
        # 0: g fitting error.
        # 0: Hankel view. F**k Koopman for now, it has too many degrees of freedom.
        #    Subtract the input to minimize the rank of H.
        # 1: Invert or sample u randomly and maximize the error in reconstruction.
        # 2: Treat n_timesteps with a conv net? Or does it make sense to mix f1(t) and f2(t-1)?
        # 3: We don't impose A symmetric, but we impose that g in both directions uses the same A and B.
        #    Exploit symmetry: if there's an input in a sequence, there will be the same input in the
        #    reversed sequence. B is valid for both. Then we can't cross, unless we recalculate B.
        #    It might not be necessary because we use the same A for both directions.
        #    INVERT U's temporally --> that's a great idea. Dismiss (n_ini_timesteps - 1) samples at
        #    each extreme. Is this correct? Is the linear input intervention the same if I do it
        #    forward and backward?
        # 5: Eigenvalues and eigenvectors. Fix one, get the other.
        # 6: Observation selector using Gumbel (should apply the mask to emitter and receiver,
        #    and it should be the same across time).
        # 7: Attention mechanism to select - Koopman (sub-Koopman) / Identity. It will be useful for the future.
        # 8: When it works, formalize the code. Hyperparameters through config.yaml.
        # 9: The nonlinear projection should also separate the objects. Can we separate by SVD? Is it possible?
        # 1: Prev_sample not necessary? Start from g(0).
        # 2: Sampling from features.
        # 3: Bottleneck in f, smaller than G. The constant part is big and routed directly;
        #    the varying part is very small and expanded to G.
        # 4: In evaluation, sample features several times but only keep one of the cte. Check if it captures the appearance.
        # 5: Fix B with A learned. If the first doesn't require input, B will be mapped to 0.

        randperm = torch.arange(g.shape[0]) if self.reverse or True \
            else torch.randperm(g.shape[0])
        # Note: Inverting u(t) in the time axis
        # TODO: still here, but I'm tired.
        # if free_pred > 0:
        #     G_tilde = g[randperm, self.n_timesteps - 1:-1 - free_pred, None]
        #     H_tilde = g[randperm, self.n_timesteps:-free_pred, None]
        # else:
        #     G_tilde = g[randperm, self.n_timesteps - 1:-1, None]
        #     H_tilde = g[randperm, self.n_timesteps:, None]
        if free_pred > 0:
            G_tilde = g[randperm, :-1 - free_pred, None]
            H_tilde = g[randperm, 1:-free_pred, None]
        else:
            G_tilde = g[randperm, :-1, None]
            H_tilde = g[randperm, 1:, None]

        # TODO: If we identify with half of the time-steps but use all inputs for rollout,
        #       we might get something. We can also predict only the future.
        A, B, A_inv, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde,
                                                            U=u[randperm, :-1 - free_pred],
                                                            I_factor=self.I_factor)
        # Note: Not permuting the input, but permuting g.
        # A, B, A_inv, fit_err = self.koopman.fit_with_A(G=G_tilde, H=H_tilde, U=u[randperm, :-1], I_factor=self.I_factor)
        # A, B, A_inv, fit_err = self.koopman.fit_with_B(G=G_tilde, H=H_tilde, U=u[randperm, :-1-free_pred], I_factor=self.I_factor)
        # A, B = self.koopman.fit_with_AB(G_tilde.shape[0])
        # A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(G=G_tilde, H=H_tilde, I_factor=self.I_factor)
        # B = None

        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None)

        # G_for_pred = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:, 1:self.n_timesteps, :, -1],
        #                         self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps - 1],
        #                                               u=u[:, self.n_timesteps - 1:], A=A, B=B)], dim=1)
        # g_for_koop = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:, 1:self.n_timesteps, :, -1],
        #                         self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps - 1],
        #                                               u=None, A=A, B=None)], dim=1)

        # G_for_pred = self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps - 1],
        #                                    u=u[:, self.n_timesteps - 1:], A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps - 1],
        #                                    u=None, A=A, B=None)

        # G_for_pred = self.koopman.simulate(T=T - 4, g=g[:, 3], u=u, A=A, B=B)
        # g_for_koop = self.koopman.simulate(T=T - 4, g=g[:, 3], u=None, A=A, B=None)

        g_start, T_start = g[:, 2], T - 3
        # TODO: ATTENTION. U MIGHT NOT BE ALIGNED TEMPORALLY!
        # u[:, 2:] is assumed to start at the same time index as g_start; see the TODO above.
        G_for_pred = self.koopman.simulate(T=T_start, g=g_start, u=u[:, 2:], A=A, B=B)
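
        # What `simulate` is expected to compute (illustrative sketch only; the actual
        # KoopmanOperators conventions for the A/B shapes and multiplication order may differ):
        #
        #   g_t = g_start                          # (bs * n_objects, g_dim)
        #   preds = []
        #   for t in range(T_start):
        #       g_t = g_t @ A + u[:, 2 + t] @ B    # one linear step in observation space
        #       preds.append(g_t)
        #   G_for_pred = torch.stack(preds, dim=1)
        #
        # which is why g_start and the u slice must refer to the same time index.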
        # g_for_koop = self.koopman.simulate(T=T - 2, g=g[:, 1], u=None, A=A, B=None)
        g_for_koop = G_for_pred
        # g_for_koop = G_for_pred
        # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None)

        '''Simple version of the reverse rollout'''
        # G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A_pinv, B=None)
        # G_for_pred = torch.flip(G_for_pred_rev, dims=[1])
        # G_for_pred = torch.cat([g[:, 0:1], self.koopman.simulate(T=T - 2, g=g[:, 1], u=u, A=A, B=B)], dim=1)

        s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), psteps=self.psteps)
        s_for_pred = self.koopman.to_s(gcodes=_get_flat(G_for_pred), psteps=self.psteps)

        # Note: Split features (cte, dyn). In case of reverse, f_cte is averaged over both directions.
        # Note 2: TODO: Reconstruction could be on the reversed g's, or on both!

        # Option 1: Temporally permute appearance
        # T_randperm = torch.randperm(g.shape[1])  # torch.arange(g.shape[1])
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, -T:]], dim=-1)
        # # f_cte = f_cte[:, T_randperm]
        # f_cte = f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:]], dim=-1)

        # T_randperm = torch.randperm(g.shape[1])  # torch.arange(g.shape[1])
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, -T:]], dim=-1)
        # # f_cte = f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)
        # f_cte = f_cte[:, T_randperm]
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:]], dim=-1)

        # Option 2: Average appearance
        s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
                               f_cte[:, None].mean(2).repeat(1, T, 1)], dim=-1)
        s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
                                f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)], dim=-1)

        # Option 3: Copy the first appearance (or a random one)
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, 0:1].repeat(1, T, 1)], dim=-1)
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:-G_for_pred.shape[1] + 1].repeat(1, G_for_pred.shape[1], 1)], dim=-1)

        # Option 4: If f_cte has been computed before.
        # f_cte = f_cte.reshape(bs * self.n_objects, -1, f_cte.shape[-1])
        # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1),
        #                        f_cte[:, -T:]], dim=-1)
        # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1),
        #                         f_cte[:, -G_for_pred.shape[1]:]], dim=-1)

        # s_for_rec = _get_flat(s_for_rec)
        # s_for_pred = _get_flat(s_for_pred)

        # NOTE: Sample after
        # f_mu = self.linear_f_mu(s_for_rec)
        # f_logvar = self.linear_f_logvar(s_for_rec)
        # s_for_rec = _sample_latent_simple(f_mu, f_logvar)
        #
        # f_mu_pred = self.linear_f_mu(s_for_pred)
        # f_logvar_pred = self.linear_f_logvar(s_for_pred)
        # s_for_pred = _sample_latent_simple(f_mu_pred, f_logvar_pred)
        # f_mu, f_mu_pred = f_mu.reshape(bs * self.n_objects, T, -1), f_mu_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1)
        # f_logvar, f_logvar_pred = f_logvar.reshape(bs * self.n_objects, T, -1), f_logvar_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1)

        # Convolutional decoder, normally a Spatial Broadcast decoder.
        out_rec = self.image_decoder(s_for_rec)
        out_pred = self.image_decoder(s_for_pred)

        returned_g = torch.cat([g, G_for_pred], dim=1)
        returned_mus = torch.cat([f_mu], dim=1)
        returned_logvars = torch.cat([f_logvar], dim=1)
        # returned_mus = torch.cat([f_mu, f_mu_pred], dim=1)
        # returned_logvars = torch.cat([f_logvar, f_logvar_pred], dim=1)

        o_touple = (out_rec, out_pred,
                    returned_g.reshape(-1, returned_g.size(-1)),
                    returned_mus.reshape(-1, returned_mus.size(-1)),
                    returned_logvars.reshape(-1, returned_logvars.size(-1)))
        # f_mu.reshape(-1, f_mu.size(-1)),
        # f_logvar.reshape(-1, f_logvar.size(-1)))
        o = [item.reshape(torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple]

        # Option 1: one object mapped to 0
        # shape_o0 = o[0].shape
        # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:])
        # o[0][:, 0] = o[0][:, 0] * 0
        # o[0] = o[0].reshape(*shape_o0)

        # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1)
        #          for item in o[:2]]
        # Test object decomposition
        # o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:])[:, 0:1], dim=1) for item in o[:2]]
        o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]]
        o[3:5] = [item.reshape(bs, self.n_objects, *item.shape[1:]) for item in o[3:5]]

        o.append(A)
        o.append(B)
        # u = u.reshape(*u.shape[:2], self.u_dim, self.n_timesteps)[..., -1]  # TODO: HANKEL. This is for the Hankel view
        o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1]))
        o.append(u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-1:]))  # Append u_dist for the categorical case
        # o.append(u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-2:]))
        o.append(g_for_koop.reshape(bs * self.n_objects, -1, g_for_koop.shape[-1]))
        o.append(fit_err)

        return o
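
# --------------------------------------------------------------------------------------
# Illustrative helper (not used by the models above): both forward passes call
# KoopmanOperators.system_identify(G, H, U, I_factor) to fit a per-sequence linear model
# H ≈ G A^T + U B^T. The sketch below shows one plausible ridge least-squares version of
# that fit, assuming G/H of shape (bs, T', 1, g_dim) and U of shape (bs, T', u_dim); the
# shapes, the I_factor convention, and the return layout are assumptions, not the actual
# KoopmanOperators implementation (which also returns A_inv).
def _system_identify_sketch(G, H, U, I_factor=10.0):
    bs, g_dim = G.shape[0], G.shape[-1]
    G = G.reshape(bs, -1, g_dim)                           # (bs, T', g_dim)
    H = H.reshape(bs, -1, g_dim)                           # (bs, T', g_dim)
    GU = torch.cat([G, U], dim=-1)                         # (bs, T', g_dim + u_dim)
    d = GU.shape[-1]
    # Ridge-regularized normal equations: (GU^T GU + I / I_factor) X = GU^T H
    gram = GU.transpose(1, 2) @ GU + torch.eye(d, device=G.device) / I_factor
    X = torch.linalg.solve(gram, GU.transpose(1, 2) @ H)   # X stacks [A^T; B^T], shape (bs, d, g_dim)
    A, B = X[:, :g_dim].transpose(1, 2), X[:, g_dim:].transpose(1, 2)
    fit_err = (GU @ X - H).pow(2).mean()                   # mean squared one-step fit error
    return A, B, fit_err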