class RecKoopmanModel(BaseModel): def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim, n_objects, I_factor=10, n_blocks=1, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]): super().__init__() out_channels = 1 n_layers = int(np.log2(image_size[0])) - 1 self.u_dim = u_dim # Temporal encoding buffers if n_timesteps > 1: t = torch.linspace(-1, 1, n_timesteps) # Add as constant, with extra dims for N and C self.register_buffer('t_grid', t) # Set state dim with config, depending on how many time-steps we want to take into account self.image_size = image_size self.n_timesteps = n_timesteps self.state_dim = feat_dim self.I_factor = I_factor self.psteps = psteps self.g_dim = g_dim self.n_objects = n_objects self.softmax = nn.Softmax(dim=-1) self.sigmoid = nn.Sigmoid() self.linear_g = nn.Linear(g_dim, g_dim) self.linear_u = nn.Linear(g_dim, u_dim) self.initial_conditions = nn.Sequential(nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps), nn.ReLU(), nn.Linear(feat_dim * n_timesteps, g_dim * 2)) self.image_encoder = ImageEncoder(in_channels, feat_dim * 2, n_objects, ngf, n_layers) # feat_dim * 2 if sample here self.image_decoder = ImageDecoder(g_dim, out_channels, ngf, n_layers) self.koopman = KoopmanOperators(feat_dim * 2, nf_particle * 2, nf_effect * 2, g_dim, u_dim, n_timesteps, n_blocks) #Object attention: # if n_objects > 1: # self.obj_attention = ObjectAttention(in_channels, feat_dim, n_objects) def _get_full_state(self, x, T): if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[1:]) new_x = [] for t in range(new_T): new_x.append(torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=2) return new_x.reshape(-1, *new_x.shape[3:]), new_T def forward(self, input): bs, T, ch, h, w = input.shape f = self.image_encoder(input) # Concatenates the features from current time and previous to create full state f, T = self._get_full_state(f, T) # TODO: Might be bug. I'm reshaping with objects before T. 
f = f.view(torch.Size([bs * self.n_objects, T]) + f.size()[1:]) f_mu, f_logvar, a_mu, a_logvar = [], [], [], [] gs, us = [], [] for t in range(T): if t==0: prev_sample = torch.zeros_like(f[:, 0, :self.g_dim]) f_t = torch.cat([f[:, t], prev_sample], dim=-1) g = self.koopman.to_g(f_t, self.psteps) # g = self.initial_conditions(f[:, t]) else: f_t = torch.cat([f[:, t], prev_sample], dim=-1) g = self.koopman.to_g(f_t, self.psteps) # TODO: try a recurrent encoder, I suppose # g, prev_hidden, u = self.koopman.to_g(f_t, prev_hidden) #,psteps=self.psteps g_mu, g_logvar = torch.chunk(g, 2, dim=-1) g = _sample_latent_simple(g_mu, g_logvar) # if t < T-1: # u_mu, u_logvar = torch.chunk(u, 2, dim=-1) # u = _sample_latent_simple(u_mu, u_logvar) # else: # u_mu, u_logvar = torch.chunk(torch.zeros_like(u), 2, dim=-1) # u = u_mu # g, u = self.linear_g(g), self.sigmoid(self.linear_u(u)) g, u = self.linear_g(g), self.sigmoid(self.linear_u(g)) prev_sample = g gs.append(g) us.append(u) f_mu.append(g_mu) f_logvar.append(g_logvar) # a_mu.append(u_mu) # a_logvar.append(u_logvar) g = torch.stack(gs, dim=1) f_mu = torch.stack(f_mu, dim=1) f_logvar = torch.stack(f_logvar, dim=1) u = torch.stack(us, dim=1) # u_zeros = torch.zeros_like(u) # u = torch.where(u > 0.8, u, u_zeros) # a_mu = torch.stack(a_mu, dim=1) # a_logvar = torch.stack(a_logvar, dim=1) free_pred = 0 if free_pred > 0: G_tilde = g[:, 1:-1-free_pred, None] # new axis corresponding to N number of objects H_tilde = g[:, 2:-free_pred, None] else: G_tilde = g[:, 1:-1, None] # new axis corresponding to N number of objects H_tilde = g[:, 2:, None] A, B, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, U=u[:, 1:-1], I_factor=self.I_factor) # TODO: clamp A before invert. Threshold u over 0.9 # A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(G=G_tilde, H=H_tilde, I_factor=self.I_factor) # TODO: Try simulating backwards # B=None # Rollout. From observation in time 0, predict with Koopman operator # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B) # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None) '''Simple version of the reverse rollout''' # G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A_pinv, B=None) # G_for_pred = torch.flip(G_for_pred_rev, dims=[1]) G_for_pred = torch.cat([g[:,0:1],self.koopman.simulate(T=T - 2, g=g[:, 1], u=u, A=A, B=B)], dim=1) # Option 1: use the koopman object decoder # s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), # pstep=self.psteps) # s_for_pred = self.koopman.to_s(gcodes=_get_flat(torch.cat([g[:, :1],G_for_pred], dim=1)), # pstep=self.psteps) # Option 2: we don't use the koopman object decoder s_for_rec = _get_flat(g) s_for_pred = _get_flat(G_for_pred) # Convolutional decoder.
Normally Spatial Broadcasting decoder out_rec = self.image_decoder(s_for_rec) out_pred = self.image_decoder(s_for_pred) returned_g = torch.cat([g, G_for_pred], dim=1) returned_mus = torch.cat([f_mu], dim=-1) returned_logvars = torch.cat([f_logvar], dim=-1) # returned_mus = torch.cat([f_mu, a_mu], dim=-1) # returned_logvars = torch.cat([f_logvar, a_logvar], dim=-1) o_touple = (out_rec, out_pred, returned_g.reshape(-1, returned_g.size(-1)), returned_mus.reshape(-1, returned_mus.size(-1)), returned_logvars.reshape(-1, returned_logvars.size(-1))) # f_mu.reshape(-1, f_mu.size(-1)), # f_logvar.reshape(-1, f_logvar.size(-1))) o = [item.reshape(torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple] # Option 1: one object mapped to 0 # shape_o0 = o[0].shape # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:]) # o[0][:,0] = o[0][:,0]*0 # o[0] = o[0].reshape(*shape_o0) # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1) for item in o[:2]] o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]] o[3:5] = [item.reshape(bs, self.n_objects, *item.shape[1:]) for item in o[3:5]] o.append(A) o.append(B) o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1])) # # Show images # plt.imshow(input[0, 0, :].reshape(16, 16).cpu().detach().numpy()) # plt.savefig('test_attention.png') return o
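# --- Illustrative sketch (not part of the original model) ---------------------------------
# The class above relies on `KoopmanOperators.system_identify(G, H, U, I_factor)` to fit the
# linear dynamics g_{t+1} = g_t A + u_t B from shifted observation pairs. Below is a minimal,
# hedged sketch of such a batched ridge-regularized least-squares fit; the function name, the
# row-vector convention, and using I_factor as a damping knob are assumptions, not the repo's
# actual implementation.
import torch

def least_squares_koopman_fit(G, H, U, I_factor=10.0):
    """G, H: (bs, T, 1, g_dim) current/next observations; U: (bs, T, u_dim) inputs.
    Returns per-sequence A (g_dim x g_dim), B (u_dim x g_dim) and the mean squared fit error."""
    bs, T, _, g_dim = G.shape
    u_dim = U.shape[-1]
    G, H = G.reshape(bs, T, g_dim), H.reshape(bs, T, g_dim)
    GU = torch.cat([G, U], dim=-1)                                   # (bs, T, g_dim + u_dim)
    lam = 1.0 / I_factor                                             # assumed ridge strength
    eye = torch.eye(g_dim + u_dim, device=G.device, dtype=G.dtype)
    gram = GU.transpose(1, 2) @ GU + lam * eye                       # damped normal equations
    rhs = GU.transpose(1, 2) @ H
    X = torch.linalg.solve(gram, rhs)                                # X = [A; B], (bs, g+u, g)
    A, B = X[:, :g_dim], X[:, g_dim:]
    fit_err = (GU @ X - H).pow(2).mean()
    return A, B, fit_err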
class cswm_KoopmanModel(BaseModel): def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, I_factor=10, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]): super().__init__() out_channels = 1 n_layers = int(np.log2(image_size[0])) - 1 num_objects = 1 # Set state dim with config, depending on how many time-steps we want to take into account self.image_size = image_size self.n_timesteps = n_timesteps self.state_dim = feat_dim self.I_factor = I_factor self.psteps = psteps self.g_dim = g_dim self.image_encoder = EncoderCNNLarge( input_dim=in_channels, hidden_dim=feat_dim // 8, output_dim=feat_dim, num_objects=num_objects) self.obj_encoder = EncoderMLP( input_dim=np.prod(image_size) // (8 ** 2) * feat_dim, hidden_dim=feat_dim // 2, output_dim=feat_dim, num_objects=num_objects) self.state_encoder = EncoderMLP( input_dim=feat_dim * self.n_timesteps, hidden_dim=feat_dim, output_dim=g_dim, num_objects=num_objects) self.state_decoder = EncoderMLP( input_dim=g_dim, hidden_dim=feat_dim, output_dim=feat_dim, num_objects=num_objects) # self.transition_model = TransitionGNN( # input_dim=embedding_dim, # hidden_dim=hidden_dim, # action_dim=action_dim, # num_objects=num_objects, # ignore_action=ignore_action, # copy_action=copy_action) self.image_decoder = SBImageDecoder(feat_dim, out_channels, ngf, n_layers, image_size) self.koopman = KoopmanOperators(self.state_dim, nf_particle, nf_effect, g_dim, n_timesteps) def _add_positional_encoding(self, x): x = torch.cat((self.x_grid_enc.expand(*x.shape[:2], -1, -1, -1), self.y_grid_enc.expand(*x.shape[:2], -1, -1, -1), x), dim=-3) return x def _get_full_state(self, x, T): new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[1:]) new_x = [] for t in range(new_T): new_x.append(torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) new_x = torch.stack(new_x, dim=1) return new_x.reshape(-1, *new_x.shape[2:]), new_T def forward(self, input): bs, T, ch, h, w = input.shape cnnf = self.image_encoder(input.reshape(-1, ch, h, w)) f = self.obj_encoder(cnnf) f_mu, f_logvar = f, f # Note (assumption): deterministic encoder here, so the features double as placeholder statistics; without this the returned tuple below references undefined names # Note: test image encoder with attention # f_mu, f_logvar = torch.chunk(self.att_image_encoder(input.reshape(-1, ch, h, w)), 2, dim=-1) # f = _sample(f_mu, f_logvar) # input = self._add_positional_encoding(input) f, T = self._get_full_state(f, T) # f = f.view(torch.Size([bs, T]) + f.size()[1:]) # g, G_for_pred = f, f g = self.koopman.to_g(f, self.psteps) g = g.view(torch.Size([bs, T]) + g.size()[1:]) G_tilde = g[:, :-1, None] # new axis corresponding to N number of objects H_tilde = g[:, 1:, None] A, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, I_factor=self.I_factor) # TODO: maybe ''' rollout: BT x D ''' G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], A=A) # Rollout from initial observation to the last s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), pstep=self.psteps) s_for_pred = self.koopman.to_s(gcodes=_get_flat(G_for_pred), pstep=self.psteps) out_rec = self.image_decoder(s_for_rec) out_pred = self.image_decoder(s_for_pred) o_touple = (out_rec, out_pred, g.reshape(-1, g.size(-1)), f_mu.reshape(-1, f_mu.size(-1)), f_logvar.reshape(-1, f_logvar.size(-1))) o = [item.reshape(torch.Size([bs, -1]) + item.size()[1:]) for item in o_touple] return o
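# --- Illustrative sketch (not part of the original model) ---------------------------------
# `KoopmanOperators.simulate(T, g, A)` in the class above is assumed to roll the identified
# operator A forward from the initial observation g_0. A minimal version of such a rollout,
# using the row-vector convention g_{t+1} = g_t A, is sketched here for clarity only.
import torch

def simulate_linear(T, g0, A):
    """g0: (bs, g_dim) initial observation; A: (bs, g_dim, g_dim). Returns (bs, T, g_dim)."""
    g_t = g0.unsqueeze(1)                     # (bs, 1, g_dim)
    preds = []
    for _ in range(T):
        g_t = torch.bmm(g_t, A)               # one linear step in observation space
        preds.append(g_t)
    return torch.cat(preds, dim=1)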
class KoopmanModel(BaseModel): def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, n_objects, I_factor=10, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]): super().__init__() out_channels = 1 n_layers = int(np.log2(image_size[0])) - 1 # Positional encoding buffers x = torch.linspace(-1, 1, image_size[0]) y = torch.linspace(-1, 1, image_size[1]) x_grid, y_grid = torch.meshgrid(x, y) # Add as constant, with extra dims for N and C self.register_buffer('x_grid_enc', x_grid.view((1, 1, 1) + x_grid.shape)) self.register_buffer('y_grid_enc', y_grid.view((1, 1, 1) + y_grid.shape)) # Temporal encoding buffers if n_timesteps > 1: t = torch.linspace(-1, 1, n_timesteps) # Add as constant, with extra dims for N and C self.register_buffer('t_grid', t) # Set state dim with config, depending on how many time-steps we want to take into account self.image_size = image_size self.n_timesteps = n_timesteps self.state_dim = feat_dim self.I_factor = I_factor self.psteps = psteps self.g_dim = g_dim self.n_objects = n_objects self.image_encoder = ImageEncoder(in_channels, feat_dim * 2, n_objects, ngf, n_layers) # feat_dim * 2 if sample here self.image_decoder = ImageDecoder(g_dim, out_channels, ngf, n_layers) self.koopman = KoopmanOperators(feat_dim * 2, nf_particle * 2, nf_effect * 2, g_dim, n_timesteps) #Object attention: # if n_objects > 1: # self.obj_attention = ObjectAttention(in_channels, feat_dim, n_objects) def _get_full_state(self, x, T): if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[1:]) new_x = [] for t in range(new_T): new_x.append(torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=1).permute(0, 2, 1, 3) return new_x.reshape(-1, *new_x.shape[3:]), new_T def forward(self, input): bs, T, ch, h, w = input.shape # input = self._add_positional_encoding(input) f = self.image_encoder(input) # # Concatenates the features from current time and previous to create full state f, T = self._get_full_state(f, T) g = self.koopman.to_g(f, self.psteps) g = g.view(torch.Size([bs * self.n_objects, T]) + g.size()[1:]) free_pred = 0 # TODO: Set it to >1 after training for several epochs. if free_pred > 0: G_tilde = g[:, :-1-free_pred, None] # new axis corresponding to N number of objects H_tilde = g[:, 1:-free_pred, None] else: G_tilde = g[:, :-1, None] # new axis corresponding to N number of objects H_tilde = g[:, 1:, None] # Find Koopman operator from current data (matrix a) # TODO: Predict longer from an A obtained for a shorter length. A, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, I_factor=self.I_factor) # Rollout. 
From observation in time 0, predict with Koopman operator G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], A=A) # Rollout from initial observation to the last # # Note: Ignore - If we have split the representation: Merge constant and dynamic components # if g.size(-1) < self.g_dim: # g = torch.cat([g, g_cte[:, 0:1].repeat(1, T, 1)], dim=-1) # G_for_pred = torch.cat([G_for_pred, g_cte[:, 0:1].repeat(1, T - 1, 1)], dim=-1) g_mu, g_logvar = torch.chunk(g, 2, dim=-1) g_mu_pred, g_logvar_pred = torch.chunk(G_for_pred, 2, dim=-1) g = _sample_latent_simple(g_mu, g_logvar) G_for_pred = _sample_latent_simple(g_mu_pred, g_logvar_pred) f_mu, f_logvar = g_mu, g_logvar # Option 1: use the koopman object decoder # s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), # pstep=self.psteps) # s_for_pred = self.koopman.to_s(gcodes=_get_flat(torch.cat([g[:, :1],G_for_pred], dim=1)), # pstep=self.psteps) # Option 2: we don't use the koopman object decoder s_for_rec = _get_flat(g) s_for_pred = _get_flat(G_for_pred) # Convolutional decoder. Normally Spatial Broadcasting decoder out_rec = self.image_decoder(s_for_rec) out_pred = self.image_decoder(s_for_pred) returned_mus = torch.cat([f_mu, g_mu_pred], dim=1) returned_logvars = torch.cat([f_logvar, g_logvar_pred], dim=1) o_touple = (out_rec, out_pred, returned_mus.reshape(-1, returned_mus.size(-1)), returned_mus.reshape(-1, returned_mus.size(-1)), returned_logvars.reshape(-1, returned_logvars.size(-1))) # f_mu.reshape(-1, f_mu.size(-1)), # f_logvar.reshape(-1, f_logvar.size(-1))) o = [item.reshape(torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple] # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1.5) for item in o[:2]] o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]] # A_to_show = self.block_diagonal(A) o.append(A[..., :self.g_dim, :self.g_dim]) # if self.n_objects > 1: # o.append(o_a) # Returns reconstruction, prediction, g_representation, mu and sigma from which we sample to compute KL divergence # # Show images # plt.imshow(input[0, 0, :].reshape(16, 16).cpu().detach().numpy()) # plt.savefig('test_attention.png') return o
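# --- Illustrative sketch (not part of the original model) ---------------------------------
# The models above call `_sample_latent_simple(mu, logvar)` to draw the observations g from
# the predicted statistics. A hedged sketch of the standard reparameterization trick that
# such a helper typically implements:
import torch

def sample_latent_simple(mu, logvar):
    """Draw z ~ N(mu, sigma^2) with sigma = exp(0.5 * logvar), keeping gradients w.r.t. mu and logvar."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std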
class SingleObjKoopmanModel(BaseModel): def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, I_factor=10, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]): super().__init__() out_channels = 1 n_layers = int(np.log2(image_size[0])) - 1 # Positional encoding buffers x = torch.linspace(-1, 1, image_size[0]) y = torch.linspace(-1, 1, image_size[1]) x_grid, y_grid = torch.meshgrid(x, y) # Add as constant, with extra dims for N and C self.register_buffer('x_grid_enc', x_grid.view((1, 1, 1) + x_grid.shape)) self.register_buffer('y_grid_enc', y_grid.view((1, 1, 1) + y_grid.shape)) # Temporal encoding buffers if n_timesteps > 1: t = torch.linspace(-1, 1, n_timesteps) # Add as constant, with extra dims for N and C self.register_buffer('t_grid', t) # Set state dim with config, depending on how many time-steps we want to take into account self.image_size = image_size self.n_timesteps = n_timesteps self.state_dim = feat_dim self.I_factor = I_factor self.psteps = psteps self.g_dim = g_dim # self.att_image_encoder = AttImageEncoder(in_channels, feat_dim, ngf, n_layers) self.image_encoder = ImageEncoder(in_channels, feat_dim * 4, ngf, n_layers) # feat_dim * 2 if sample here self.image_decoder = ImageDecoder(g_dim, out_channels, ngf, n_layers) # self.image_decoder = SimpleSBImageDecoder(feat_dim, out_channels, ngf, n_layers, image_size) self.koopman = KoopmanOperators(self.state_dim * 4, nf_particle, nf_effect, g_dim, n_timesteps) def _add_positional_encoding(self, x): x = torch.cat((self.x_grid_enc.expand(*x.shape[:2], -1, -1, -1), self.y_grid_enc.expand(*x.shape[:2], -1, -1, -1), x), dim=-3) return x def _get_full_state(self, x, T): if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[1:]) new_x = [] for t in range(new_T): new_x.append(torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=1) return new_x.reshape(-1, *new_x.shape[2:]), new_T def forward(self, input): bs, T, ch, h, w = input.shape # Note: test image encoder with attention # f_mu, f_logvar = torch.chunk(self.att_image_encoder(input.reshape(-1, ch, h, w)), 2, dim=-1) # f = _sample(f_mu, f_logvar) # Note: Koopman applied to mu and logvar # f = self.att_image_encoder(input.reshape(-1, ch, h, w)) # input = self._add_positional_encoding(input) f = self.image_encoder(input.reshape(-1, ch, h, w)) #Note: deterministic AE # f_mu, f_logvar = f, f # f_mu, f_logvar = torch.chunk(self.image_encoder(input.reshape(-1, ch+2, h, w)), 2, dim=-1) # f = _sample(f_mu, f_logvar) # f = f_mu # # Concatenates the features from current time and previous to create full state f, T = self._get_full_state(f, T) # f = f.view(torch.Size([bs, T]) + f.size()[1:]) # g, G_for_pred = f, f g = self.koopman.to_g(f, self.psteps) g = g.view(torch.Size([bs, T]) + g.size()[1:]) # g = f.view(torch.Size([bs, T]) + f.size()[1:]) # # Note: Ignore - Split representation into constant and dynamic by hardcoding # g, g_cte = torch.chunk(g, 2, dim=-1) # Option 1 # g, g_cte = g[..., :2], g[..., 2:] # Option 2 free_pred = 3 G_tilde = g[:, :-1-free_pred, None] # new axis corresponding to N number of objects H_tilde = g[:, 1:-free_pred, None] # Find Koopman operator from current data (matrix a) # TODO: Predict longer from an A obtained for a shorter length. A, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, I_factor=self.I_factor) # Rollout. 
From observation in time 0, predict with Koopman operator G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], A=A) # Rollout from initial observation to the last # # Note: Ignore - If we have split the representation: Merge constant and dynamic components # if g.size(-1) < self.g_dim: # g = torch.cat([g, g_cte[:, 0:1].repeat(1, T, 1)], dim=-1) # G_for_pred = torch.cat([G_for_pred, g_cte[:, 0:1].repeat(1, T - 1, 1)], dim=-1) g_mu, g_logvar = torch.chunk(g, 2, dim=-1) g_mu_pred, g_logvar_pred = torch.chunk(G_for_pred, 2, dim=-1) g = _sample_latent_simple(g_mu, g_logvar) G_for_pred = _sample_latent_simple(g_mu_pred, g_logvar_pred) f_mu, f_logvar = g_mu, g_logvar # Option 1: use the koopman object decoder # s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), # pstep=self.psteps) # s_for_pred = self.koopman.to_s(gcodes=_get_flat(torch.cat([g[:, :1],G_for_pred], dim=1)), # pstep=self.psteps) # Option 2: we don't use the koopman object decoder s_for_rec = _get_flat(g) s_for_pred = _get_flat(G_for_pred) # Convolutional decoder. Normally Spatial Broadcasting decoder out_rec = self.image_decoder(s_for_rec) out_pred = self.image_decoder(s_for_pred) o_touple = (out_rec, out_pred, f_mu.reshape(-1, f_mu.size(-1)), f_mu.reshape(-1, f_mu.size(-1)), f_logvar.reshape(-1, f_logvar.size(-1))) o = [item.reshape(torch.Size([bs, -1]) + item.size()[1:]) for item in o_touple] # Returns reconstruction, prediction, g_representation, mu and sigma from which we sample to compute KL divergence # # Show images # plt.imshow(input[0, 0, :].reshape(16, 16).cpu().detach().numpy()) # plt.savefig('test_attention.png') return o
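# --- Illustrative sketch (not part of the original model) ---------------------------------
# `_get_full_state` above builds a delay embedding: each new state concatenates the features
# of n_timesteps consecutive frames, so T shrinks to T - n_timesteps + 1 while the feature
# dimension grows by a factor of n_timesteps. A standalone version, for shape reference only:
import torch

def delay_embed(x, T, n_timesteps):
    """x: (bs*T, D) flattened per-frame features. Returns ((bs*new_T, D*n_timesteps), new_T)."""
    if n_timesteps < 2:
        return x, T
    new_T = T - n_timesteps + 1
    x = x.reshape(-1, T, x.shape[-1])
    cols = [torch.cat([x[:, t + i] for i in range(n_timesteps)], dim=-1) for t in range(new_T)]
    out = torch.stack(cols, dim=1)            # (bs, new_T, D * n_timesteps)
    return out.reshape(-1, out.shape[-1]), new_T

# e.g. features of shape (bs*10, 32) with n_timesteps=3 become (bs*8, 96), with new_T = 8.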
class RecKoopmanModel(BaseModel): def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim, n_objects, I_factor=10, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]): super().__init__() out_channels = 1 n_layers = int(np.log2(image_size[0])) - 1 self.u_dim = u_dim self.linear_g = nn.Linear(g_dim, g_dim) self.linear_u = nn.Linear(g_dim, u_dim) # Temporal encoding buffers if n_timesteps > 1: t = torch.linspace(-1, 1, n_timesteps) # Add as constant, with extra dims for N and C self.register_buffer('t_grid', t) # Set state dim with config, depending on how many time-steps we want to take into account self.image_size = image_size self.n_timesteps = n_timesteps self.state_dim = feat_dim self.I_factor = I_factor self.psteps = psteps self.g_dim = g_dim self.n_objects = n_objects self.n_directions = 2 self.softmax = nn.Softmax(dim=-1) self.sigmoid = nn.Sigmoid() self.initial_conditions = nn.Sequential( nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps), nn.ReLU(), nn.Linear(feat_dim * n_timesteps, g_dim * 2)) self.image_encoder = ImageEncoder( in_channels, feat_dim * 2, n_objects, ngf, n_layers) # feat_dim * 2 if sample here self.image_decoder = ImageDecoder(g_dim, out_channels, ngf, n_layers) self.koopman = KoopmanOperators(feat_dim * 2, nf_particle * 2, nf_effect * 2, g_dim, u_dim, n_timesteps) #Object attention: # if n_objects > 1: # self.obj_attention = ObjectAttention(in_channels, feat_dim, n_objects) def _get_full_state(self, x, T): if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[1:]) new_x = [] for t in range(new_T): new_x.append( torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=2) return new_x.reshape(-1, *new_x.shape[3:]), new_T def forward(self, input): bs, T, ch, h, w = input.shape f = self.image_encoder(input, reverse=True) # Concatenates the features from current time and previous to create full state f, T = self._get_full_state(f, T) f = f.reshape(self.n_directions * bs * self.n_objects, T, f.shape[-1]) f_mu, f_logvar, a_mu, a_logvar = [], [], [], [] gs, us = [], [] for t in range(T): if t == 0: # f_t = self.initial_conditions(f[:, t]) prev_sample = torch.zeros_like(f[:, 0, :self.g_dim]) # f_t = torch.cat([f_ini, prev_sample], dim=-1) prev_hidden = None f_t = torch.cat([f[:, t], prev_sample], dim=-1) g, u = self.koopman.to_g(f_t, psteps=self.psteps) else: f_t = torch.cat([f[:, t], prev_sample], dim=-1) # g, prev_hidden, u = self.koopman.to_g(f_t, prev_hidden) #,psteps=self.psteps g, u = self.koopman.to_g(f_t, psteps=self.psteps) g_mu, g_logvar = torch.chunk(g, 2, dim=-1) g = _sample_latent_simple(g_mu, g_logvar) # if t < T-1: # u_mu, u_logvar = torch.chunk(u, 2, dim=-1) # u = _sample_latent_simple(u_mu, u_logvar) # else: # u_mu, u_logvar = torch.chunk(torch.zeros_like(u), 2, dim=-1) # u = u_mu # g, u = self.linear_g(g), self.sigmoid(self.linear_u(u)) g, u = self.linear_g(g), self.linear_u(g) prev_sample = g gs.append(g) us.append(u) f_mu.append(g_mu) f_logvar.append(g_logvar) # a_mu.append(u_mu) # a_logvar.append(u_logvar) g = torch.stack(gs, dim=1) f_mu = torch.stack(f_mu, dim=1) f_logvar = torch.stack(f_logvar, dim=1) u = torch.stack(us, dim=1) # a_mu = torch.stack(a_mu, dim=1) # a_logvar = torch.stack(a_logvar, dim=1) reverse = True if reverse: g_rsh = g.view( torch.Size([bs, self.n_directions, self.n_objects, T]) + g.size()[2:]) g_rev = g_rsh[:, 0].reshape(bs * 
self.n_objects, T, g_rsh.shape[-1]) g = g_rsh[:, 1].reshape(bs * self.n_objects, T, g_rsh.shape[-1]) free_pred = 0 if free_pred > 0: G_tilde = g[:, :-1 - free_pred, None] # new axis corresponding to N number of objects H_tilde = g[:, 1:-free_pred, None] else: G_tilde = g[:, :-1, None] # new axis corresponding to N number of objects H_tilde = g[:, 1:, None] # A, B, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, U=u[:, :-1], I_factor=self.I_factor) A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A( G=G_tilde, H=H_tilde, I_factor=self.I_factor) # Rollout. From observation in time 0, predict with Koopman operator # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B) G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None) G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g_rev[:, 0], u=None, A=A_pinv, B=None) B = None # G_for_pred = torch.cat([g[:,0:1],self.koopman.simulate(T=T - 2, g=g[:, 1], A=A)], dim=1) # Option 1: use the koopman object decoder # s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), # pstep=self.psteps) # s_for_pred = self.koopman.to_s(gcodes=_get_flat(torch.cat([g[:, :1],G_for_pred], dim=1)), # pstep=self.psteps) # Option 2: we don't use the koopman object decoder g_all = torch.stack([g, g_rev], dim=1).reshape( self.n_directions * bs * self.n_objects, -1, g.shape[-1]) G_pred_all = torch.stack([G_for_pred, G_for_pred_rev], dim=1).reshape( self.n_directions * bs * self.n_objects, -1, G_for_pred.shape[-1]) s_for_rec = _get_flat(g_all) s_for_pred = _get_flat(G_pred_all) # Convolutional decoder. Normally Spatial Broadcasting decoder out_rec = self.image_decoder(s_for_rec) out_rec = out_rec.reshape( torch.Size([bs, self.n_objects, self.n_directions, -1]) + out_rec.size()[1:]) out_pred = self.image_decoder(s_for_pred) out_pred = out_pred.reshape( torch.Size([bs, self.n_objects, self.n_directions, -1]) + out_pred.size()[1:]) returned_g = torch.cat([g, G_for_pred], dim=1) returned_g_rev = torch.cat([g_rev, G_for_pred_rev], dim=1) # returned_mus = torch.cat([f_mu], dim=-1) returned_mus = f_mu.reshape(bs, self.n_directions, self.n_objects, *f_mu.shape[1:]) returned_mus = returned_mus.reshape(bs, self.n_directions, self.n_objects, -1, returned_mus.shape[-1]) # returned_logvars = torch.cat([f_logvar], dim=-1) returned_logvars = f_logvar.reshape(bs, self.n_directions, self.n_objects, *f_logvar.shape[1:]) returned_logvars = returned_logvars.reshape(bs, self.n_directions, self.n_objects, -1, returned_logvars.shape[-1]) # returned_mus = torch.cat([f_mu, a_mu], dim=-1) # returned_logvars = torch.cat([f_logvar, a_logvar], dim=-1) o = [out_rec, out_pred, returned_g, returned_mus, returned_logvars] # Option 1: one object mapped to 0 # shape_o0 = o[0].shape # o[0] = o[0].reshape(bs, self.n_objects, 2, *o[0].shape[1:]) # o[0][:,0] = o[0][:,0]*0 # o[0] = o[0].reshape(*shape_o0) # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, self.n_directions, *item.shape[3:]), dim=1), min=0, max=1) for item in o[:2]] o[:2] = [ torch.sum(item.reshape(bs, self.n_objects, self.n_directions, *item.shape[3:]), dim=1) for item in o[:2] ] o[:2] = [ item.reshape(bs * self.n_directions, *item.shape[2:]) for item in o[:2] ] o.append(A) o.append(B) o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1])) o.append( returned_g_rev.reshape(bs * self.n_objects, -1, g_rev.shape[-1])) return o
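# --- Illustrative sketch (not part of the original model) ---------------------------------
# The class above runs a forward rollout with the fitted A and a backward rollout with the
# pseudo-inverse returned by `fit_block_diagonal_A`, so both directions share one identified
# operator. A self-contained, hedged sketch of that idea; the names and the row-vector
# convention g_{t+1} = g_t A are assumptions, not the repo's API.
import torch

def bidirectional_rollout(T, g_fw0, g_bw0, A):
    """g_fw0, g_bw0: (bs, g_dim) start states for each direction; A: (bs, g_dim, g_dim)."""
    A_pinv = torch.linalg.pinv(A)             # batched pseudo-inverse for the reverse dynamics

    def roll(g0, op):
        g_t, steps = g0.unsqueeze(1), []
        for _ in range(T):
            g_t = torch.bmm(g_t, op)          # one linear step in observation space
            steps.append(g_t)
        return torch.cat(steps, dim=1)

    return roll(g_fw0, A), roll(g_bw0, A_pinv)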
class RecKoopmanModel(BaseModel): def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim, n_objects, I_factor=10, n_blocks=1, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]): super().__init__() out_channels = 1 n_layers = int(np.log2(image_size[0])) - 1 self.u_dim = u_dim # Temporal encoding buffers if n_timesteps > 1: t = torch.linspace(-1, 1, n_timesteps) # Add as constant, with extra dims for N and C self.register_buffer('t_grid', t) # Set state dim with config, depending on how many time-steps we want to take into account self.image_size = image_size self.n_timesteps = n_timesteps self.state_dim = feat_dim self.I_factor = I_factor self.psteps = psteps self.g_dim = g_dim self.n_objects = n_objects self.softmax = nn.Softmax(dim=-1) self.sigmoid = nn.Sigmoid() # self.linear_g = nn.Linear(g_dim, g_dim) # self.initial_conditions = nn.Sequential(nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps), # nn.ReLU(), # nn.Linear(feat_dim * n_timesteps, g_dim * 2)) feat_dyn_dim = feat_dim // 6 self.feat_dyn_dim = feat_dyn_dim self.content = None self.reverse = False self.hankel = True self.ini_alpha = 1 self.incr_alpha = 0.5 # self.linear_u = nn.Linear(g_dim + u_dim, u_dim * 2) # self.linear_u_2_f = nn.Linear(feat_dyn_dim * n_timesteps, u_dim * 2) # self.linear_u_T_f = nn.Linear(feat_dyn_dim * n_timesteps, u_dim) if self.hankel: # self.linear_u_2_g = nn.Linear(g_dim * self.n_timesteps, u_dim * 2) # self.linear_u_2_g = nn.Sequential(nn.Linear(g_dim * self.n_timesteps, g_dim), # nn.ReLU(), # nn.Linear(g_dim, u_dim * 2)) # self.linear_u_all_g = nn.Sequential(nn.Linear(g_dim * self.n_timesteps, g_dim), # nn.ReLU(), # nn.Linear(g_dim, u_dim + 1)) # self.gru_u_all_g = nn.GRU(g_dim, u_dim + 1, num_layers = 2, batch_first=True) self.linear_u_1_g = nn.Sequential( nn.Linear(g_dim * self.n_timesteps, g_dim), nn.ReLU(), nn.Linear(g_dim, u_dim)) else: self.linear_u_2_g = nn.Linear(g_dim, u_dim * 2) self.linear_f_mu = nn.Linear(feat_dim * 2, feat_dim) self.linear_f_logvar = nn.Linear(feat_dim * 2, feat_dim) self.image_encoder = ImageEncoder( in_channels, feat_dim * 2, n_objects, ngf, n_layers) # feat_dim * 2 if sample here self.image_decoder = ImageDecoder(feat_dim, out_channels, ngf, n_layers) self.koopman = KoopmanOperators(feat_dyn_dim, nf_particle, nf_effect, g_dim, u_dim, n_timesteps, n_blocks) def _get_full_state(self, x, T): if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[1:]) new_x = [] for t in range(new_T): new_x.append( torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=1) return new_x.reshape(-1, new_x.shape[-1]), new_T def _get_full_state_hankel(self, x, T): ''' :param x: features or observations :param T: number of time-steps before concatenation :return: Columns of a hankel matrix with self.n_timesteps rows. 
''' if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[2:]) new_x = [] for t in range(new_T): new_x.append( torch.stack([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=1) return new_x.reshape(-1, new_T, new_x.shape[-2] * new_x.shape[-1]), new_T def forward(self, input, epoch=1): bs, T, ch, h, w = input.shape (n_dirs, bs) = (2, 2 * bs) if self.reverse else (1, bs) f = self.image_encoder(input, reverse=self.reverse) f_mu = self.linear_f_mu(f) f_logvar = self.linear_f_logvar(f) # Concatenates the features from current time and previous to create full state # f_mu, f_logvar = torch.chunk(f, 2, dim=-1) f = _sample_latent_simple(f_mu, f_logvar) # Note: Split features (cte, dyn) f_dyn, f_cte = f[..., :self.feat_dyn_dim], f[..., self.feat_dyn_dim:] f_cte = f_cte.reshape(bs * self.n_objects, T, *f_cte.shape[1:]) # Test disentanglement # if random.random() < 0.1 or self.content is None: # self.content = f_cte # f_cte = self.content f = f_dyn f = f.reshape(bs * self.n_objects * T, -1) if not self.hankel: f, T = self._get_full_state(f, T) g = self.koopman.to_g(f, self.psteps) g = g.reshape(bs * self.n_objects, T, *g.shape[1:]) if self.hankel: g_flat = g g, T = self._get_full_state_hankel(g, T) if self.reverse: bs = bs // n_dirs g = g.reshape(bs, n_dirs, self.n_objects, T, *g.shape[2:]) g_fw = g[:, 0].reshape(bs * self.n_objects, T, *g.shape[4:]) g_bw = g[:, 1].reshape(bs * self.n_objects, T, *g.shape[4:]) g = g_bw # Option 1: Sigmoid. All time-steps can be activated, binary activation. With dropout u_dist = self.linear_u_1_g( g.reshape(bs * self.n_objects * T, g.shape[-1])) u_dist = u_dist.reshape(u_dist.size(0), self.u_dim) u = self.sigmoid( (self.ini_alpha + epoch * self.incr_alpha) * u_dist).reshape( -1, T, self.u_dim) # u = torch.abs(u - torch.ones_like(u)*0.5)*2 # do = nn.Dropout(p=0.2) # u = do(u) # Option 2: Gumbel/softmax. All time-steps can be activated. 1 action at a time, or None. # F.relu # u_dist, _ = (self.gru_u_all_g(g_flat # .reshape(bs * self.n_objects, T + self.n_timesteps -1, g_flat.shape[-1]))) # u_dist = u_dist[:, self.n_timesteps-1:].reshape(bs *self.n_objects* T, - 1) # # u_dist = self.linear_u_all_g(g.reshape(bs * self.n_objects * T, g.shape[-1])) # u_dist = u_dist.reshape(u_dist.size(0), self.u_dim + 1) # u = F.gumbel_softmax(u_dist, tau=1, hard=True) # u = u[..., 1:].reshape(-1, T, self.u_dim) # # u_dist = u_dist.reshape(u_dist.size(0), self.u_dim + 1) # # temp = 0.01 # # u = nn.Softmax(dim=-1)(u_dist/temp) # # u = u[..., 1:].reshape(-1, T, self.u_dim) # Option 3: Categorical. All time-steps can be activated. 1 action at a time. Non-diff # u_dist = F.sigmoid(self.linear_u_all_g(g.reshape(bs * self.n_objects * T, g.shape[-1]))) # u_dist = u_dist.reshape(u_dist.size(0), self.u_dim + 1) # u_cat = Categorical(u_dist) # u = u_cat.sample() # u = F.one_hot(u).float() # u = u[..., 1:].reshape(-1, T, self.u_dim) # The first possible action is the inaction. Which should have the max prob. if self.hankel: zero_pad = torch.zeros_like(u[:, :self.n_timesteps - 1]) u = torch.cat([zero_pad, u], dim=1) u, T_u = self._get_full_state_hankel(u, T + self.n_timesteps - 1) # u = u.reshape(*u.shape[:-1], u.shape[-1] // self.n_timesteps, self.n_timesteps) assert T_u == T # Option 2: One activation for all time-steps. Too strong of an assumption. 
# u_dist = F.relu(self.linear_u_T(f)) # u_dist = u_dist.reshape(-1, T, self.u_dim).permute(0, 2, 1) # u = F.gumbel_softmax(u_dist, tau=1, hard=True) # u = u.permute(0, 2, 1) # Option 3: Actions are given by the observations g, obtained one at a time (Not necessary?) # us, u_dist = [], [] # for t in range(T): # if t==0: # prev_u_sample = torch.zeros_like(g[:, 0, :self.u_dim]) # q_y = F.relu(self.linear_u(f)) # q_y = q_y.view(q_y.size(0),self.u_dim,2) # u = F.gumbel_softmax(q_y, tau=1, hard=True) # u = u[..., 0] # # prev_u_sample = u # us.append(u) # u_dist.append(q_y) # u = torch.stack(us, dim=1) # u_dist = torch.stack(u_dist, dim=1) # TODO: # 0: state as velocity, acceleration, pose # 0: g fitting error. # 0: Hankel view. F**k koopman for now, it has too many degrees of freedom. Subtract the input to minimize the rank of H. # 1: Invert or sample u randomly and maximize the error in reconstruction. # 2: Treat n_timesteps with conv_nn? Or does it make sense to mix f1(t) and f2(t-1)? # 3: We don't impose A symmetric, but we impose that g in both directions use the same A and B. # Exploit symmetry. If there's input in a sequence, there will be the same input in the reverse sequence. # B is valid for both. Then we can't cross, unless we recalculate B. It might not be necessary because we use # the same A for both dirs. INVERT U's temporally --> that's a great idea. Discard (n_ini_timesteps - 1) samples at each end. Is this correct? # Is it the same linear input intervention if I do it forward and backward? # 5: Eigenvalues and eigenvectors. Fix one, get the other. # 6: Observation selector using Gumbel (should apply a mask on emitter and receiver # and it should be the same across time) # 7: Attention mechanism to select - koopman (sub-koopman) / Identity. It will be useful for the future. # 8: When it works, formalize the code. Hyperparameters through config.yaml # 9: The nonlinear projection should also separate the objects. Can we separate them by SVD? Is it possible? # 1: Prev_sample not necessary? Start from g(0) # 2: Sampling from features # 3: Bottleneck in f. Smaller than G. Constant part is big and routed directly. # Variant part is very small and expanded to G. # 4: In evaluation, sample features several times but only keep one of the cte. Check if it captures the appearance # 5: Fix B with A learned. If the first doesn't require input, B will be mapped to 0.
randperm = torch.arange(g.shape[0]) if self.reverse \ else torch.randperm(g.shape[0]) #Note: Inverting u(t) in the time axis if self.reverse: u_fw = u zeros = torch.zeros_like(u[:, :self.n_timesteps]) u_bw = torch.flip(u, dims=[1]) u_bw = torch.cat([u_bw[:, self.n_timesteps:], zeros], dim=1) u = u_bw free_pred = self.n_timesteps - 1 + 4 else: free_pred = 0 if free_pred > 0: G_tilde = g[randperm, self.n_timesteps - 1:-1 - free_pred, None] H_tilde = g[randperm, self.n_timesteps:-free_pred, None] else: G_tilde = g[randperm, self.n_timesteps - 1:-1, None] H_tilde = g[randperm, self.n_timesteps:, None] # if free_pred > 0: # G_tilde = g[randperm, :-1-free_pred, None] # H_tilde = g[randperm, 1:-free_pred, None] # else: # G_tilde = g[randperm, :-1, None] # H_tilde = g[randperm, 1:, None] A, B, A_inv, fit_err = self.koopman.system_identify( G=G_tilde, H=H_tilde, U=u[randperm, self.n_timesteps - 1:-1 - free_pred], I_factor=self.I_factor) # A, B, A_inv, fit_err = self.koopman.fit_with_A(G=G_tilde, H=H_tilde, U=u[randperm, :-1], I_factor=self.I_factor) # A, B, A_inv, fit_err = self.koopman.fit_with_B(G=G_tilde, H=H_tilde, U=u[randperm, :-1-free_pred], I_factor=self.I_factor) # A, B = self.koopman.fit_with_AB(G_tilde.shape[0]) # A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(G=G_tilde, H=H_tilde, I_factor=self.I_factor) # B = None if self.reverse: g = g_fw u = u_fw f_cte_shape = f_cte.shape f_cte = f_cte.reshape(bs, n_dirs * f_cte_shape[1], *f_cte_shape[2:]) # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B) # g_for_koop = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None) # G_for_pred = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:,1:self.n_timesteps, :, -1], # self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=u[:, self.n_timesteps -1:], A=A, B=B)], dim=1) # g_for_koop = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:,1:self.n_timesteps, :, -1], # self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=None, A=A, B=None)], dim=1) G_for_pred = self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps - 1], u=u[:, self.n_timesteps - 1:], A=A, B=B) g_for_koop = self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps - 1], u=None, A=A, B=None) g_for_koop, T_prime = self._get_full_state_hankel( g_for_koop, g_for_koop.shape[1]) # g_for_koop = G_for_pred # TODO: create recursive rollout. We obtain the input at each step from g # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None) '''Simple version of the reverse rollout''' # G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A_pinv, B=None) # G_for_pred = torch.flip(G_for_pred_rev, dims=[1]) # G_for_pred = torch.cat([g[:,0:1],self.koopman.simulate(T=T - 2, g=g[:, 1], u=u, A=A, B=B)], dim=1) if self.hankel: # G_for_pred = G_for_pred.reshape(*G_for_pred.shape[:-1], # G_for_pred.shape[-1]// self.n_timesteps, # self.n_timesteps)[..., self.n_timesteps - 1] # g = g.reshape(*g.shape[:-1], # g.shape[-1]// self.n_timesteps, # self.n_timesteps)[..., self.n_timesteps - 1] # # g_for_koop = g_for_koop.reshape(*g_for_koop.shape[:-1], # g_for_koop.shape[-1]// self.n_timesteps, # self.n_timesteps) #Option 2: with hankel structure. 
G_for_pred = G_for_pred.reshape(*G_for_pred.shape[:-1], G_for_pred.shape[-1]) g = g.reshape(*g.shape[:-1], g.shape[-1] // self.n_timesteps, self.n_timesteps)[..., self.n_timesteps - 1] # Option 1: use the koopman object decoder s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), psteps=self.psteps) s_for_pred = self.koopman.to_s(gcodes=_get_flat(G_for_pred), psteps=self.psteps) # Note: Split features (cte, dyn). In case of reverse, f_cte averages for both directions. # Note 2: TODO: Reconstruction could be in reversed g's or both! s_for_rec = torch.cat([ s_for_rec.reshape(bs * self.n_objects, T, -1), f_cte[:, None].mean(2).repeat(1, T, 1) ], dim=-1) s_for_pred = torch.cat([ s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1), f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1) ], dim=-1) s_for_rec = _get_flat(s_for_rec) s_for_pred = _get_flat(s_for_pred) # Option 2: we don't use the koopman object decoder # s_for_rec = _get_flat(g) # s_for_pred = _get_flat(G_for_pred) # Convolutional decoder. Normally Spatial Broadcasting decoder out_rec = self.image_decoder(s_for_rec) out_pred = self.image_decoder(s_for_pred) returned_g = torch.cat([g, G_for_pred], dim=1) returned_mus = torch.cat([f_mu], dim=-1) returned_logvars = torch.cat([f_logvar], dim=-1) # returned_mus = torch.cat([f_mu, a_mu], dim=-1) # returned_logvars = torch.cat([f_logvar, a_logvar], dim=-1) o_touple = (out_rec, out_pred, returned_g.reshape(-1, returned_g.size(-1)), returned_mus.reshape(-1, returned_mus.size(-1)), returned_logvars.reshape(-1, returned_logvars.size(-1))) # f_mu.reshape(-1, f_mu.size(-1)), # f_logvar.reshape(-1, f_logvar.size(-1))) o = [ item.reshape( torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple ] # Option 1: one object mapped to 0 # shape_o0 = o[0].shape # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:]) # o[0][:,0] = o[0][:,0]*0 # o[0] = o[0].reshape(*shape_o0) # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1) for item in o[:2]] o[:2] = [ torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2] ] o[3:5] = [ item.reshape(bs, self.n_objects, *item.shape[1:]) for item in o[3:5] ] o.append(A) o.append(B) u = u.reshape(*u.shape[:2], self.u_dim, self.n_timesteps)[..., -1] o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1])) o.append( u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-1:])) #Append udist for categorical # o.append(u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-2:])) o.append( g_for_koop.reshape(bs * self.n_objects, -1, g_for_koop.shape[-1] // self.n_timesteps, self.n_timesteps)) o.append(fit_err) return o
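# --- Illustrative sketch (not part of the original model) ---------------------------------
# In the class above, the per-step inputs u come from a linear map over the Hankel state that
# is squashed by a sigmoid whose slope grows with the epoch (ini_alpha + epoch * incr_alpha),
# so the gates drift towards binary on/off decisions as training progresses. A minimal sketch
# of that annealed gate (illustrative names, not the repo's API):
import torch

def annealed_input_gate(u_logits, epoch, ini_alpha=1.0, incr_alpha=0.5):
    """u_logits: (bs, T, u_dim) raw scores. Returns soft gates in (0, 1) that sharpen over epochs."""
    slope = ini_alpha + epoch * incr_alpha
    return torch.sigmoid(slope * u_logits)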
class RecKoopmanModel(BaseModel): def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim, n_objects, I_factor=10, n_blocks=1, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]): super().__init__() out_channels = 1 n_layers = int(np.log2(image_size[0])) - 1 self.u_dim = u_dim # Set state dim with config, depending on how many time-steps we want to take into account self.image_size = image_size self.n_timesteps = n_timesteps self.state_dim = feat_dim self.I_factor = I_factor self.psteps = psteps self.g_dim = g_dim feat_dyn_dim = feat_dim // 8 feat_dyn_dim = 4 self.feat_dyn_dim = feat_dyn_dim self.feat_cte_dim = feat_dim - feat_dyn_dim self.with_u = False self.n_iters = 1 self.ini_alpha = 1 # Note: # - I leave it to 0 now. If it increases too fast, the gradients might be affected self.incr_alpha = 0.1 self.cte_resolution = (32, 32) self.ori_resolution = (128, 128) self.att_resolution = (16, 16) self.obj_resolution = (4, 4) # self.n_objects = reduce((lambda x, y: x * y), self.obj_resolution) self.n_objects = n_objects self.spatial_tf = SpatialTransformation(self.cte_resolution, self.ori_resolution) self.linear_f_cte_post = nn.Linear(2 * self.feat_cte_dim, 2 * self.feat_cte_dim) self.linear_f_dyn_post = nn.Linear(2 * self.feat_dyn_dim, 2 * self.feat_dyn_dim) # self.bc_decoder = False #2 * if self.bc_decoder: self.image_decoder = ImageBroadcastDecoder( self.feat_cte_dim, out_channels, resolution=(16, 16)) # resolution=self.att_resolution else: self.image_decoder = ImageDecoder(self.feat_cte_dim, out_channels, dyn_dim=self.feat_dyn_dim) self.image_encoder = ImageEncoder( in_channels, 2 * self.feat_cte_dim, self.feat_dyn_dim, self.att_resolution, self.n_objects, ngf, n_layers) # feat_dim * 2 if sample here self.koopman = KoopmanOperators(feat_dyn_dim, nf_particle, nf_effect, g_dim, u_dim, n_timesteps, n_blocks) def _get_full_state(self, x, T): if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[1:]) new_x = [] for t in range(new_T): new_x.append( torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=1) return new_x.reshape(-1, new_x.shape[-1]), new_T def _get_full_state_hankel(self, x, T): ''' :param x: features or observations :param T: number of time-steps before concatenation :return: Columns of a hankel matrix with self.n_timesteps rows. 
''' if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[2:]) new_x = [] for t in range(new_T): new_x.append( torch.stack([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=1) return new_x.reshape(-1, new_T, new_x.shape[-2] * new_x.shape[-1]), new_T def forward(self, input, epoch=1): # Note: Add annealing in SPACE bs, T, ch, h, w = input.shape # Percentage of output free_pred = T // 4 returned_post = [] input = input.cuda() # Backbone deterministic features f_bb = self.image_encoder(input, block='backbone') # Dynamic features T_inp = T f_dyn, shape, f_cte, confi = self.image_encoder(f_bb[:, :T_inp], block='dyn_track') f_dyn = f_dyn.reshape(-1, f_dyn.shape[-1]) # Sample dynamic features or reshape # Option 1: Don't sample f_dyn = f_dyn.reshape(bs * self.n_objects, T_inp, -1) # Option 2: Sample # f_mu_dyn, f_logvar_dyn = self.linear_f_dyn_post(f_dyn).reshape(bs, self.n_objects, T_inp, -1).chunk(2, -1) # f_dyn_post = Normal(f_mu_dyn, F.softplus(f_logvar_dyn)) # f_dyn = f_dyn_post.rsample() # f_dyn = f_dyn.reshape(bs * self.n_objects, T_inp, -1) # returned_post.append(f_dyn_post) # Get delayed dynamic features f_dyn_state, T_inp = self._get_full_state_hankel(f_dyn, T_inp) # Get inputs from delayed dynamic features u, u_dist = self.koopman.to_u(f_dyn_state, temp=self.ini_alpha + epoch * self.incr_alpha, ignore=True) if not self.with_u: u = torch.zeros_like(u) # Get observations from delayed dynamic features g = self.koopman.to_g( f_dyn_state.reshape(bs * self.n_objects * T_inp, -1), self.psteps) g = g.reshape(bs * self.n_objects, T_inp, *g.shape[1:]) # Get shifted observations for sys ID randperm = torch.arange(g.shape[0]) # No permutation # randperm = torch.randperm(g.shape[0]) # Random permutation if free_pred > 0: G_tilde = g[randperm, :-1 - free_pred, None] H_tilde = g[randperm, 1:-free_pred, None] else: G_tilde = g[randperm, :-1, None] H_tilde = g[randperm, 1:, None] # Sys ID A, B, A_inv, fit_err = self.koopman.system_identify( G=G_tilde, H=H_tilde, U=u[randperm, :T_inp - free_pred - 1], I_factor=self.I_factor ) # Try not permuting U when inp is permutted # Rollout from start_step onwards. start_step = 2 # g and u must be aligned!! G_for_pred = self.koopman.simulate(T=T_inp - start_step - 1, g=g[:, start_step], u=u[:, start_step:], A=A, B=B) g_for_koop = G_for_pred assert f_bb[:, self.n_timesteps - 1:self.n_timesteps - 1 + T_inp].shape[1] == f_bb[:, self.n_timesteps - 1:].shape[1] rec = { "obs": g, "confi": confi[:, :, self.n_timesteps - 1:self.n_timesteps - 1 + T_inp], "shape": shape[:, :, self.n_timesteps - 1:self.n_timesteps - 1 + T_inp], "f_cte": f_cte[:, :, self.n_timesteps - 1:self.n_timesteps - 1 + T_inp], "T": T_inp, "name": "rec" } pred = { "obs": G_for_pred, "confi": confi[:, :, -G_for_pred.shape[1]:], "shape": shape[:, :, -G_for_pred.shape[1]:], "f_cte": f_cte[:, :, -G_for_pred.shape[1]:], "T": G_for_pred.shape[1], "name": "pred" } outs = {} BBs = {} # TODO: Check if the indices for f_bb and supervision are correct. # Recover partial shape with decoded dynamical features. Iterate with new estimates of the appearance. # Note: This process could be iterative. 
for idx, case in enumerate([rec, pred]): case_name = case["name"] # get back dynamic features # if case_name == "rec": # f_dyn_state = f_dyn[:, -T_inp:].reshape(-1, *f_dyn.shape[2:]) # else: # TODO: Try with Initial f_dyn (no koop) f_dyn_state = self.koopman.to_s(gcodes=_get_flat(case["obs"]), psteps=self.psteps) # TODO: From f_dyn extract where, scale, etc. with only Y = WX + b. Review. # Initial noisy (low variance) constant vector. # Note: # - Different realization for each time-step. # - Is it a Variable? # f_cte_ini = torch.randn(bs * self.n_objects * case["T"], self.feat_cte_dim).to(f_dyn_state.device) #* self.f_cte_ini_std # Get full feature vector # f = torch.cat([ f_dyn_state, # f_cte_ini], dim=-1) # Get coarse features from which obtain queries and/or decode # f_coarse = self.image_decoder(f, block = 'coarse') # f_coarse = f_coarse.reshape(bs, self.n_objects, case["T"], *f_coarse.shape[1:]) for _ in range(self.n_iters): # Get constant feature vector through attention # f_cte = self.image_encoder(case["backbone_features"], f_coarse, block='cte') # Encode with raw dynamic features # pose = self.pose_encoder(f_dyn_state) pose = f_dyn_state # M_center, M_relocate = get_affine_params_from_pose(pose, self.pose_limits, self.pose_bias) # Option 2: Previous impl # bb_feat = case["backbone_features"].unsqueeze(1).repeat_interleave(self.n_objects, dim=1)\ # .reshape(-1, *case["backbone_features"].shape[-4:]) # Repeat for number of objects # warped_bb_feat = (bb_feat, M_center, *self.cte_resolution) # Check all resolutions are coordinate [32, 32] # f_cte = self.image_encoder(warped_bb_feat, f_dyn_state, block='cte') # Sample cte features f_cte = case["f_cte"].reshape(bs * self.n_objects * case["T"], -1) confi = case["confi"].reshape(bs * self.n_objects * case["T"], -1) # f_mu_cte, f_logvar_cte = self.linear_f_cte_post(f_cte).reshape( bs, self.n_objects, case["T"], -1).chunk(2, -1) f_cte_post = Normal(f_mu_cte, F.softplus(f_logvar_cte)) f_cte = f_cte_post.rsample().reshape( bs * self.n_objects * case["T"], -1) # Option 2: Previous impl # f_mu_cte, f_logvar_cte = self.linear_f_cte_post(f_cte).reshape(bs, self.n_objects, 1, -1).chunk(2, -1) # f_cte_post = Normal(f_mu_cte, F.softplus(f_logvar_cte)) # f_cte = f_cte_post.rsample() # f_cte = f_cte.repeat_interleave(case["T"], dim=2) # f_cte = f_cte.reshape(bs * self.n_objects * case["T"], self.feat_cte_dim) # Register statistics returned_post.append(f_cte_post) # Get full feature vector # f = torch.cat([ f_dyn_state, # f_cte], dim=-1) # Note: Do if f_dyn_state is used in the appearance f = f_cte # Get output. 
Spatial broadcast decoder dec_obj = self.image_decoder(f, block='to_x') if not self.bc_decoder: grid, area = self.spatial_tf(confi, pose) outs[case_name], out_shape = self.spatial_tf.warp_and_render( dec_obj, case["shape"], confi, grid) else: outs[case_name] = dec_obj * confi[..., None, None] # outs[case_name] = warp_affine(dec_obj, M_relocate, h=h, w=w) BBs[case_name] = dec_obj out_rec = outs["rec"] out_pred = outs["pred"] bb_rec = BBs["rec"] # bb_rec = BBs["pred"] # Test disentanglement - TO REVIEW # if random.random() < 0.1 or self.content is None: # self.content = f_cte # f_cte = self.content ''' Returned variables ''' returned_g = torch.cat([g, G_for_pred], dim=1) # Observations o_touple = (out_rec, out_pred, returned_g.reshape(-1, returned_g.size(-1))) o = [ item.reshape( torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple ] # Option 1: one object mapped to 0 # shape_o0 = o[0].shape # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:]) # o[0][:,0] = o[0][:,0]*0 # o[0] = o[0].reshape(*shape_o0) # Sum across objects. Note: Add Clamp or Sigmoid. o[:2] = [ torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1) for item in o[:2] ] # o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]] # TODO: output = {} # output['rec'] = o[0] # output['pred'] = o[1] o.append(returned_post) o.append(A) # State transition matrix o.append(B) # Input matrix o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1])) # Inputs o.append(u_dist.reshape(bs * self.n_objects, -1, u_dist.shape[-1])) # Input distribution o.append( g_for_koop.reshape( bs * self.n_objects, -1, g_for_koop.shape[-1])) # Observation propagated only with A o.append(fit_err) # Fit error g to G_for_pred # bb_rec = bb_rec.reshape(torch.Size([bs * self.n_objects, -1]) + bb_rec.size()[1:]) # bb_rec =torch.sum(bb_rec, dim=2, keepdim=True) # o.append(bb_rec.reshape(bs, self.n_objects, *bb_rec.shape[1:])) # Motion field reconstruction # TODO: return in dictionary form return o
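# --- Illustrative sketch (not part of the original model) ---------------------------------
# The class above samples the constant (appearance) features from a Normal posterior whose
# scale is made positive with softplus and which is drawn with rsample() so gradients flow
# through the posterior parameters. A hedged, standalone version of that step:
import torch
import torch.nn.functional as F
from torch.distributions import Normal

def sample_cte_posterior(stats):
    """stats: (..., 2 * feat_cte_dim) output of the posterior head. Returns (sample, posterior)."""
    mu, pre_scale = stats.chunk(2, dim=-1)
    post = Normal(mu, F.softplus(pre_scale))
    return post.rsample(), post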
class RecKoopmanModel(BaseModel): def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim, n_objects, I_factor=10, n_blocks=1, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]): super().__init__() out_channels = 1 n_layers = int(np.log2(image_size[0])) - 1 self.u_dim = u_dim # Temporal encoding buffers if n_timesteps > 1: t = torch.linspace(-1, 1, n_timesteps) # Add as constant, with extra dims for N and C self.register_buffer('t_grid', t) # Set state dim with config, depending on how many time-steps we want to take into account self.image_size = image_size self.n_timesteps = n_timesteps self.state_dim = feat_dim self.I_factor = I_factor self.psteps = psteps self.g_dim = g_dim self.n_objects = n_objects # self.softmax = nn.Softmax(dim=-1) # self.sigmoid = nn.Sigmoid() # self.linear_g = nn.Linear(g_dim, g_dim) # self.initial_conditions = nn.Sequential(nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps), # nn.ReLU(), # nn.Linear(feat_dim * n_timesteps, g_dim * 2)) feat_dyn_dim = feat_dim // 6 self.feat_dyn_dim = feat_dyn_dim self.content = None # self.reverse = True self.reverse = False self.ini_alpha = 1 self.incr_alpha = 0.25 # self.rnn_f_cte = nn.LSTM(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim, 1, bias=False, batch_first=True) # self.rnn_f_cte = nn.GRU(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim, 1, bias=False, batch_first=True) # self.linear_f_mu = nn.Linear(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim) # self.linear_f_logvar = nn.Linear(feat_dim - feat_dyn_dim, feat_dim - feat_dyn_dim) # self.linear_f_dyn_mu = nn.Linear(feat_dyn_dim, feat_dyn_dim) # self.linear_f_dyn_logvar = nn.Linear(feat_dyn_dim, feat_dyn_dim) self.linear_f_mu = nn.Linear(feat_dim, feat_dim) self.linear_f_logvar = nn.Linear(feat_dim, feat_dim) self.image_encoder = ImageEncoder(in_channels, feat_dim, n_objects, ngf, n_layers) # feat_dim * 2 if sample here self.image_decoder = ImageDecoder(feat_dim, out_channels, ngf, n_layers) self.koopman = KoopmanOperators(feat_dyn_dim, nf_particle, nf_effect, g_dim, u_dim, n_timesteps, n_blocks) def _get_full_state(self, x, T): if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[1:]) new_x = [] for t in range(new_T): new_x.append(torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=1) return new_x.reshape(-1, new_x.shape[-1]), new_T def _get_full_state_hankel(self, x, T): ''' :param x: features or observations :param T: number of time-steps before concatenation :return: Columns of a hankel matrix with self.n_timesteps rows. 
''' if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[2:]) new_x = [] for t in range(new_T): new_x.append(torch.stack([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=1) return new_x.reshape(-1, new_T, new_x.shape[-2] * new_x.shape[-1]), new_T def forward(self, input, epoch = 1): bs, T, ch, h, w = input.shape free_pred = T//4 f = self.image_encoder(input) # (bs * n_objects * T, feat_dim) # Option 1: Mean before sampling # f_cte = f.reshape(bs * self.n_objects, T, *f.shape[1:])\ # [..., self.feat_dyn_dim:] # f_cte = f_cte.mean(1) # # LSTM to obtain cte # # h0 = torch.zeros_like(f_cte[:, 0])[None] # # c0 = torch.zeros_like(f_cte[:, 0])[None] # # f_cte, _ = self.rnn_f_cte(f_cte[:, :-free_pred], (h0, c0)) # # f_cte = f_cte[:, -1] # f_mu_cte, f_logvar_cte = self.linear_f_mu(f_cte)[:, None].repeat(1, T, 1).reshape(bs * self.n_objects * T, -1), \ # self.linear_f_logvar(f_cte)[:, None].repeat(1, T, 1).reshape(bs * self.n_objects * T, -1) # f_cte = _sample_latent_simple(f_mu_cte, f_logvar_cte) # # # f_dyn = f[..., :self.feat_dyn_dim] # f_mu_dyn, f_logvar_dyn = self.linear_f_dyn_mu(f_dyn), \ # self.linear_f_dyn_logvar(f_dyn) # f_dyn = _sample_latent_simple(f_mu_dyn, f_logvar_dyn) # # # f_mu = torch.cat([f_mu_dyn, f_mu_cte], dim=-1) # f_logvar = torch.cat([f_logvar_dyn, f_logvar_cte], dim=-1) # Option 2: Sampling all together f_mu = self.linear_f_mu(f) f_logvar = self.linear_f_logvar(f) f = _sample_latent_simple(f_mu, f_logvar) # Note: Split features (cte, dyn) f_dyn, f_cte = f[..., :self.feat_dyn_dim], f[..., self.feat_dyn_dim:] f_cte = f_cte.reshape(bs * self.n_objects, T, *f_cte.shape[1:]) # Test disentanglement # if random.random() < 0.1 or self.content is None: # self.content = f_cte # f_cte = self.content f_dyn = f_dyn.reshape(bs * self.n_objects, T, *f_dyn.shape[1:]) f_dyn, T = self._get_full_state_hankel(f_dyn, T) # TODO: U might also depend on features from the scene. Residual features (slot n+1 / Background) u, u_dist = self.koopman.to_u(f_dyn, temp=self.ini_alpha + epoch * self.incr_alpha) g = self.koopman.to_g(f_dyn.reshape(bs * self.n_objects * T, -1), self.psteps) g = g.reshape(bs * self.n_objects, T, *g.shape[1:]) # TODO: # 0: Put everything into the varying part and be done with it # ... # 0: state as velocity, acceleration, pose # 0: g fitting error. # 0: Hankel view. F**k koopman for now, it has too many degrees of freedom. Subtract the input to minimize the rank of H. # 1: Invert or sample u randomly and maximize the error in reconstruction. # 2: Treat n_timesteps with conv_nn? Or does it make sense to mix f1(t) and f2(t-1)? # 3: We don't impose A symmetric, but we impose that g in both directions use the same A and B. # Exploit symmetry. If there's input in a sequence, there will be the same input in the reverse sequence. # B is valid for both. Then we can't cross, unless we recalculate B. It might not be necessary because we use # the same A for both dirs. INVERT U's temporally --> that's a great idea. Discard (n_ini_timesteps - 1) samples at each end. Is this correct? # Is it the same linear input intervention if I do it forward and backward? # 5: Eigenvalues and eigenvectors. Fix one, get the other. # 6: Observation selector using Gumbel (should apply a mask on emitter and receiver # and it should be the same across time) # 7: Attention mechanism to select - koopman (sub-koopman) / Identity. It will be useful for the future.
# 8: When it works, formalize the code. Hyperparameters through config.yaml # 9: The nonlinear projection should also separate the objects. Can we separate by the SVD? Is it possible? # 1: Prev_sample not necessary? Start from g(0) # 2: Sampling from features # 3: Bottleneck in f. Smaller than G. The constant part is big and routed directly. # The varying part is very small and expanded to G. # 4: In evaluation, sample features several times but only keep one of the cte. Check if it captures the appearance # 5: Fix B with A learned. If the first doesn't require input, B will be mapped to 0. randperm = torch.arange(g.shape[0]) if self.reverse or True else torch.randperm(g.shape[0]) # Note: Inverting u(t) in the time axis # TODO: still here but I'm tired. # if free_pred > 0: # G_tilde = g[randperm, self.n_timesteps -1:-1-free_pred, None] # H_tilde = g[randperm, self.n_timesteps:-free_pred, None] # else: # G_tilde = g[randperm, self.n_timesteps -1:-1, None] # H_tilde = g[randperm, self.n_timesteps:, None] if free_pred > 0: G_tilde = g[randperm, :-1-free_pred, None] H_tilde = g[randperm, 1:-free_pred, None] else: G_tilde = g[randperm, :-1, None] H_tilde = g[randperm, 1:, None] # TODO: If we identify with half of the timesteps, but use all inputs for rollout, # we might get something. We can also predict only the future. A, B, A_inv, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, U=u[randperm, :-1-free_pred], I_factor=self.I_factor) # Note: Not permuting the input, but permuting g. # A, B, A_inv, fit_err = self.koopman.fit_with_A(G=G_tilde, H=H_tilde, U=u[randperm, :-1], I_factor=self.I_factor) # A, B, A_inv, fit_err = self.koopman.fit_with_B(G=G_tilde, H=H_tilde, U=u[randperm, :-1-free_pred], I_factor=self.I_factor) # A, B = self.koopman.fit_with_AB(G_tilde.shape[0]) # A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(G=G_tilde, H=H_tilde, I_factor=self.I_factor) # B = None # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B) # g_for_koop = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None) # G_for_pred = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:,1:self.n_timesteps, :, -1], # self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=u[:, self.n_timesteps -1:], A=A, B=B)], dim=1) # g_for_koop = torch.cat([g.reshape(*g.shape[:2], -1, self.n_timesteps)[:,1:self.n_timesteps, :, -1], # self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=None, A=A, B=None)], dim=1) # G_for_pred = self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=u[:, self.n_timesteps -1:], A=A, B=B) # g_for_koop = self.koopman.simulate(T=T - self.n_timesteps, g=g[:, self.n_timesteps -1], u=None, A=A, B=None) # G_for_pred = self.koopman.simulate(T=T - 4, g=g[:, 3], u=u, A=A, B=B) # g_for_koop = self.koopman.simulate(T=T - 4, g=g[:, 3], u=None, A=A, B=None) g_start, T_start = g[:, 2], T-3 G_for_pred = self.koopman.simulate(T=T_start, g=g_start, u=u[:, 2:], A=A, B=B) # TODO: ATTENTION. U MIGHT NOT BE ALIGNED TEMPORALLY! Starting the input at index 2 to match g_start = g[:, 2] is an assumption. 
# g_for_koop = self.koopman.simulate(T=T - 2, g=g[:, 1], u=None, A=A, B=None) g_for_koop= G_for_pred # g_for_koop = G_for_pred # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None) '''Simple version of the reverse rollout''' # G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A_pinv, B=None) # G_for_pred = torch.flip(G_for_pred_rev, dims=[1]) # G_for_pred = torch.cat([g[:,0:1],self.koopman.simulate(T=T - 2, g=g[:, 1], u=u, A=A, B=B)], dim=1) s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), psteps=self.psteps) s_for_pred = self.koopman.to_s(gcodes=_get_flat(G_for_pred), psteps=self.psteps) # Note: Split features (cte, dyn). In case of reverse, f_cte averages for both directions. # Note 2: TODO: Reconstruction could be in reversed g's or both! # Option 1: Temporally permute appearance # T_randperm = torch.randperm(g.shape[1]) # torch.arange(g.shape[1]) # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1), # f_cte[:, -T:]], dim=-1) # # f_cte = f_cte[:, T_randperm] # f_cte = f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1) # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1), # f_cte[:, -G_for_pred.shape[1]:]], dim=-1) # T_randperm = torch.randperm(g.shape[1]) # torch.arange(g.shape[1]) # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1), # f_cte[:, -T:]], dim=-1) # # f_cte = f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1) # f_cte = f_cte[:, T_randperm] # # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1), # f_cte[:, -G_for_pred.shape[1]:]], dim=-1) # Option 2: Average appearance s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1), f_cte[:, None].mean(2).repeat(1, T, 1)], dim=-1) s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1), f_cte[:, None].mean(2).repeat(1, G_for_pred.shape[1], 1)], dim=-1) # Option 3: Copy the first appearance (or a random one) # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1), # f_cte[:, 0:1].repeat(1, T, 1)], dim=-1) # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1), # f_cte[:, -G_for_pred.shape[1]:-G_for_pred.shape[1]+1].repeat(1, G_for_pred.shape[1], 1)], dim=-1) # Option 4: If f_cte has been computed before. # f_cte = f_cte.reshape(bs * self.n_objects, -1, f_cte.shape[-1]) # s_for_rec = torch.cat([s_for_rec.reshape(bs * self.n_objects, T, -1), # f_cte[:, -T:]], dim=-1) # s_for_pred = torch.cat([s_for_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1), # f_cte[:, -G_for_pred.shape[1]:]], dim=-1) # # s_for_rec = _get_flat(s_for_rec) # s_for_pred = _get_flat(s_for_pred) # NOTE: Sample after # f_mu = self.linear_f_mu(s_for_rec) # f_logvar = self.linear_f_logvar(s_for_rec) # s_for_rec = _sample_latent_simple(f_mu, f_logvar) # # # f_mu_pred = self.linear_f_mu(s_for_pred) # f_logvar_pred = self.linear_f_logvar(s_for_pred) # s_for_pred = _sample_latent_simple(f_mu_pred, f_logvar_pred) # f_mu, f_mu_pred = f_mu.reshape(bs * self.n_objects, T, -1), f_mu_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1) # f_logvar, f_logvar_pred = f_logvar.reshape(bs * self.n_objects, T, -1), f_logvar_pred.reshape(bs * self.n_objects, G_for_pred.shape[1], -1) # Convolutional decoder. 
Normally Spatial Broadcasting decoder out_rec = self.image_decoder(s_for_rec) out_pred = self.image_decoder(s_for_pred) returned_g = torch.cat([g, G_for_pred], dim=1) returned_mus = torch.cat([f_mu], dim=1) returned_logvars = torch.cat([f_logvar], dim=1) # returned_mus = torch.cat([f_mu, f_mu_pred], dim=1) # returned_logvars = torch.cat([f_logvar, f_logvar_pred], dim=1) o_touple = (out_rec, out_pred, returned_g.reshape(-1, returned_g.size(-1)), returned_mus.reshape(-1, returned_mus.size(-1)), returned_logvars.reshape(-1, returned_logvars.size(-1))) # f_mu.reshape(-1, f_mu.size(-1)), # f_logvar.reshape(-1, f_logvar.size(-1))) o = [item.reshape(torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple] # Option 1: one object mapped to 0 # shape_o0 = o[0].shape # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:]) # o[0][:,0] = o[0][:,0]*0 # o[0] = o[0].reshape(*shape_o0) # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1) for item in o[:2]] # Test object decomposition # o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:])[:,0:1], dim=1) for item in o[:2]] o[:2] = [torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2]] o[3:5] = [item.reshape(bs, self.n_objects, *item.shape[1:]) for item in o[3:5]] o.append(A) o.append(B) # u = u.reshape(*u.shape[:2], self.u_dim, self.n_timesteps)[..., -1] #TODO:HANKEL This is for hankel view o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1])) o.append(u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-1:])) #Append udist for categorical # o.append(u_dist.reshape(bs * self.n_objects, -1, *u_dist.shape[-2:])) o.append(g_for_koop.reshape(bs * self.n_objects, -1, g_for_koop.shape[-1])) o.append(fit_err) return o
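# ---------------------------------------------------------------------------
# Illustration (not part of the model): a minimal, self-contained sketch of the
# time-delay (Hankel) embedding that `_get_full_state_hankel` performs above.
# The helper name `hankel_embed` and the toy shapes below are ours, added only
# to make the windowing explicit.
import torch

def hankel_embed(x, n_timesteps):
    """x: (B, T, D) -> (B, T - n_timesteps + 1, D * n_timesteps).

    Output step t stacks the window [x_t, ..., x_{t + n_timesteps - 1}], i.e.
    the columns of a Hankel matrix with `n_timesteps` block rows."""
    B, T, D = x.shape
    new_T = T - n_timesteps + 1
    cols = [torch.stack([x[:, t + idx] for idx in range(n_timesteps)], dim=-1)  # (B, D, n_timesteps)
            for t in range(new_T)]
    out = torch.stack(cols, dim=1)  # (B, new_T, D, n_timesteps)
    return out.reshape(B, new_T, D * n_timesteps)

# Example: 2 sequences, 6 steps, 3-dim features, window of 3 -> (2, 4, 9).
# _x = torch.randn(2, 6, 3); assert hankel_embed(_x, 3).shape == (2, 4, 9)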
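# ---------------------------------------------------------------------------
# Illustration (not part of the model): a hedged sketch of what a call such as
# `system_identify(G, H, U, I_factor)` plausibly computes given its signature --
# a ridge-regularized least-squares fit of the linear dynamics H ~ G @ A + U @ B.
# The actual KoopmanOperators implementation may differ (batched/block-diagonal
# operators, a different regularization scaling, returning A_inv, etc.).
def ridge_identify(G, H, U, I_factor=10.0):
    """G, H: (N, g_dim); U: (N, u_dim) -> A (g_dim, g_dim), B (u_dim, g_dim), scalar fit error."""
    GU = torch.cat([G, U], dim=-1)                   # (N, g_dim + u_dim)
    d = GU.shape[-1]
    reg = torch.eye(d, device=GU.device) / I_factor  # Tikhonov term; the exact strength is an assumption
    # Normal equations: (GU^T GU + reg) W = GU^T H
    W = torch.linalg.solve(GU.t() @ GU + reg, GU.t() @ H)
    A, B = W[:G.shape[-1]], W[G.shape[-1]:]
    fit_err = (GU @ W - H).pow(2).mean()
    return A, B, fit_err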
class RecKoopmanModel(BaseModel): def __init__(self, in_channels, feat_dim, nf_particle, nf_effect, g_dim, u_dim, n_objects, I_factor=10, n_blocks=1, psteps=1, n_timesteps=1, ngf=8, image_size=[64, 64]): super().__init__() out_channels = 1 n_layers = int(np.log2(image_size[0])) - 1 self.u_dim = u_dim # Temporal encoding buffers if n_timesteps > 1: t = torch.linspace(-1, 1, n_timesteps) # Add as constant, with extra dims for N and C self.register_buffer('t_grid', t) # Set state dim with config, depending on how many time-steps we want to take into account self.image_size = image_size self.n_timesteps = n_timesteps self.state_dim = feat_dim self.I_factor = I_factor self.psteps = psteps self.g_dim = g_dim self.n_objects = n_objects self.softmax = nn.Softmax(dim=-1) self.sigmoid = nn.Sigmoid() feat_dyn_dim = feat_dim #// 8 self.linear_g = nn.Linear(g_dim, g_dim) self.linear_u = nn.Linear(g_dim + u_dim, u_dim * 2) # self.initial_conditions = nn.Sequential(nn.Linear(feat_dim * n_timesteps * 2, feat_dim * n_timesteps), # nn.ReLU(), # nn.Linear(feat_dim * n_timesteps, g_dim * 2)) self.image_encoder = ImageEncoder( in_channels, feat_dim, n_objects, ngf, n_layers) # feat_dim * 2 if sample here self.image_decoder = ImageDecoder(feat_dyn_dim, out_channels, ngf, n_layers) self.koopman = KoopmanOperators(feat_dim, nf_particle, nf_effect, g_dim, u_dim, n_timesteps, n_blocks) def _get_full_state(self, x, T): if self.n_timesteps < 2: return x, T new_T = T - self.n_timesteps + 1 x = x.reshape(-1, T, *x.shape[1:]) new_x = [] for t in range(new_T): new_x.append( torch.cat([x[:, t + idx] for idx in range(self.n_timesteps)], dim=-1)) # torch.cat([ torch.zeros_like( , x[:,0,0:1]) + self.t_grid[idx]], dim=-1) new_x = torch.stack(new_x, dim=2) return new_x.reshape(-1, *new_x.shape[3:]), new_T def forward(self, input): bs, T, ch, h, w = input.shape f = self.image_encoder(input) # Concatenates the features from current time and previous to create full state f, T = self._get_full_state(f, T) f = f.view(torch.Size([bs * self.n_objects, T]) + f.size()[1:]) f_mu, f_logvar, u_dist = [], [], [] gs, us = [], [] for t in range(T): if t == 0: prev_sample = torch.zeros_like(f[:, 0, :1].repeat(1, self.g_dim)) prev_u_sample = torch.zeros_like(f[:, 0, :self.u_dim]) # TODO: # 1: Prev_sample not necessary? Start from g(0) # 2: Sampling from features # 3: Bottleneck in f. Smaller than G. Constant part is big and routed directly. # Variant part is very small and expanded to G. # 4: Clip gradient, clip A_inv, norm of the matrix: Spectral normalization # 5: Eigenvalues and Eigenvector. Fix one get the other. # 6: Half of features skip the koopman embedding # 7: Observation selector using Gumbel (should apply mask emissor and receptor # and it should be the same across time) # 8: Should I write the conversation with Sandesh? 
# Finally, check compositional koopman implementation # f_t = self.initial_conditions(f[:, t]) f_t = torch.cat([f[:, t], prev_sample], dim=-1) g = self.koopman.to_g(f_t, self.psteps) else: f_t = torch.cat([f[:, t], prev_sample], dim=-1) g = self.koopman.to_g(f_t, self.psteps) # g, prev_hidden, u = self.koopman.to_g(f_t, prev_hidden) #,psteps=self.psteps g_mu, g_logvar = torch.chunk(g, 2, dim=-1) g = _sample_latent_simple(g_mu, g_logvar) g, q_y = self.linear_g(g), F.relu( self.linear_u(torch.cat([g, prev_u_sample], dim=-1))) q_y = q_y.view(q_y.size(0), self.u_dim, 2) u = F.gumbel_softmax(q_y, tau=1, hard=True) u = u[..., 0] prev_sample, prev_u_sample = g, u gs.append(g) us.append(u) f_mu.append(g_mu) f_logvar.append(g_logvar) u_dist.append(q_y) g = torch.stack(gs, dim=1) f_mu = torch.stack(f_mu, dim=1) f_logvar = torch.stack(f_logvar, dim=1) u = torch.stack(us, dim=1) u_dist = torch.stack(u_dist, dim=1) free_pred = 0 if free_pred > 0: G_tilde = g[:, 1:-1 - free_pred, None] # new axis corresponding to N number of objects H_tilde = g[:, 2:-free_pred, None] else: G_tilde = g[:, 1:-1, None] # new axis corresponding to N number of objects H_tilde = g[:, 2:, None] # A, B, fit_err = self.koopman.system_identify(G=G_tilde, H=H_tilde, U=u[:, 1:-1], I_factor=self.I_factor) A, B, fit_err = self.koopman.fit_with_A(G=G_tilde, H=H_tilde, U=u[:, 1:-1], I_factor=self.I_factor) # A, A_pinv, fit_err = self.koopman.fit_block_diagonal_A(G=G_tilde, H=H_tilde, I_factor=self.I_factor) # TODO: Try simulating backwards # B=None # Rollout. From observation in time 0, predict with Koopman operator # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=u, A=A, B=B) # G_for_pred = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A, B=None) '''Simple version of the reverse rollout''' # G_for_pred_rev = self.koopman.simulate(T=T - 1, g=g[:, 0], u=None, A=A_pinv, B=None) # G_for_pred = torch.flip(G_for_pred_rev, dims=[1]) G_for_pred = torch.cat([ g[:, 0:1], self.koopman.simulate(T=T - 2, g=g[:, 1], u=u, A=A, B=B) ], dim=1) # Option 1: use the koopman object decoder s_for_rec = self.koopman.to_s(gcodes=_get_flat(g), psteps=self.psteps) # torch.cat([g[:, :1],G_for_pred], dim=1) s_for_pred = self.koopman.to_s(gcodes=_get_flat(G_for_pred), psteps=self.psteps) # Option 2: we don't use the koopman object decoder # s_for_rec = _get_flat(g) # s_for_pred = _get_flat(G_for_pred) # Convolutional decoder. 
Normally Spatial Broadcasting decoder out_rec = self.image_decoder(s_for_rec) out_pred = self.image_decoder(s_for_pred) returned_g = torch.cat([g, G_for_pred], dim=1) returned_mus = torch.cat([f_mu], dim=-1) returned_logvars = torch.cat([f_logvar], dim=-1) # returned_mus = torch.cat([f_mu, a_mu], dim=-1) # returned_logvars = torch.cat([f_logvar, a_logvar], dim=-1) o_touple = (out_rec, out_pred, returned_g.reshape(-1, returned_g.size(-1)), returned_mus.reshape(-1, returned_mus.size(-1)), returned_logvars.reshape(-1, returned_logvars.size(-1))) # f_mu.reshape(-1, f_mu.size(-1)), # f_logvar.reshape(-1, f_logvar.size(-1))) o = [ item.reshape( torch.Size([bs * self.n_objects, -1]) + item.size()[1:]) for item in o_touple ] # Option 1: one object mapped to 0 # shape_o0 = o[0].shape # o[0] = o[0].reshape(bs, self.n_objects, *o[0].shape[1:]) # o[0][:,0] = o[0][:,0]*0 # o[0] = o[0].reshape(*shape_o0) # o[:2] = [torch.clamp(torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1), min=0, max=1) for item in o[:2]] o[:2] = [ torch.sum(item.reshape(bs, self.n_objects, *item.shape[1:]), dim=1) for item in o[:2] ] o[3:5] = [ item.reshape(bs, self.n_objects, *item.shape[1:]) for item in o[3:5] ] o.append(A) o.append(B) o.append(u.reshape(bs * self.n_objects, -1, u.shape[-1])) o.append(u_dist) #.reshape(bs * self.n_objects, -1, u_dist.shape[-1])) # # Show images # plt.imshow(input[0, 0, :].reshape(16, 16).cpu().detach().numpy()) # plt.savefig('test_attention.png') return o
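# ---------------------------------------------------------------------------
# Illustration (not part of the model): a minimal sketch of the rollout implied
# by calls like `simulate(T, g, u, A, B)` -- iterating the identified linear
# dynamics g_{t+1} = g_t @ A + u_t @ B for T steps from an initial observation.
# The shapes, the input alignment and the right-multiplication convention follow
# the `ridge_identify` sketch above and are assumptions about the real method.
def rollout(T, g0, u, A, B=None):
    """g0: (bs, g_dim); u: (bs, >= T, u_dim) or None -> predictions (bs, T, g_dim)."""
    g_t, preds = g0, []
    for t in range(T):
        g_t = g_t @ A
        if B is not None and u is not None:
            g_t = g_t + u[:, t] @ B  # input-driven term, dropped in the free-rollout variants
        preds.append(g_t)
    return torch.stack(preds, dim=1)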
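# ---------------------------------------------------------------------------
# Illustration (not part of the model): the straight-through Gumbel-softmax gate
# used in the recurrent loop above, isolated. Two logits per input dimension are
# turned into a hard 0/1 gate; gradients flow through the soft sample. The helper
# name is ours; tau=1 and keeping channel 0 follow the usage in forward().
import torch.nn.functional as F

def sample_input_gate(q_y, tau=1.0):
    """q_y: (bs, u_dim, 2) unnormalized logits -> (bs, u_dim) hard {0, 1} gate."""
    one_hot = F.gumbel_softmax(q_y, tau=tau, hard=True)  # one-hot along the last dim
    return one_hot[..., 0]  # keep the "on" channel, as in the forward pass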