def forward_fw_warp(self, sta_obs_s0, dyn_obs_t1_s0, dyn_obs_t0_s0):
    """Predict optical flow from the difference of the two dynamic observations
    and forward-warp the concatenated static/dynamic observation with it."""
    # dyn_obs_t1_s0.register_hook(lambda grad: print(torch.norm(grad)))
    of = self.motion(dyn_obs_t1_s0 - dyn_obs_t0_s0)  # Range: [-1, 1]
    # of.register_hook(lambda grad: print(torch.norm(grad)))
    of, residual_of = self.constrain_of(of)
    vis_dict_out = {'optical_flow': of}
    # TODO: optical flow with softmax on the difference.
    # of = torch.zeros_like(of) + 0.01
    # dyn_obs_s1, sta_obs_s1 = dyn_obs_s0, sta_obs_s0
    # mode = 'nearest'  # Note: can uv be rounded? Check the PyTorch docs for 'nearest'.
    # dyn_obs_s1 = F.grid_sample(dyn_obs_s0, of.permute(0, 2, 3, 1), mode=mode)  # NTO * 1 * H * W
    # sta_obs_s1 = F.grid_sample(sta_obs_s0, of.permute(0, 2, 3, 1), mode=mode)  # NTO * 1 * H * W
    # dyn_obs_s1 = softsplat.FunctionSoftsplat(tenInput=dyn_obs_s0, tenFlow=of, tenMetric=None, strType='summation')
    # sta_obs_s1 = softsplat.FunctionSoftsplat(tenInput=sta_obs_s0.contiguous(), tenFlow=of, tenMetric=None, strType='summation')
    # TODO: Test with of = 0.
    obs_s0 = torch.cat([sta_obs_s0, dyn_obs_t1_s0], dim=-3)  # + self.soft_residual(residual_of)
    # obs_s0 = obs_s0 + self.soft_residual(residual_of)
    obs_s1 = self.soft_fw(obs_s0, of, temp=0.1, constrain_of=True)
    # obs_s1.register_hook(lambda grad: print(torch.norm(grad), 'after soft_fw'))
    tut.norm_grad(obs_s1, 5)  # TODO: Sure?
    # Split the warped observation back into its static and dynamic parts.
    sta_obs_s1 = obs_s1[..., :-self.input_dim, :, :]
    dyn_obs_t1_s1 = obs_s1[..., -self.input_dim:, :, :]
    # Option 2
    # TODO: Unfold + SpaTX grid_sample for each location.
    return sta_obs_s1, dyn_obs_t1_s1, vis_dict_out
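# A minimal standalone sketch of the forward-splatting step that `self.soft_fw`
# is assumed to perform (its implementation is not shown here): every source
# pixel is pushed along its flow vector and bilinearly distributed over the four
# nearest target pixels with scatter_add. The function name, the summation
# splatting rule, and the pixel-unit flow convention (the `of` above is stated
# to be normalized to [-1, 1]) are assumptions for illustration only.
import torch

def splat_sum_sketch(src, flow):
    """src: N * C * H * W, flow: N * 2 * H * W in pixels (dx, dy)."""
    n, c, h, w = src.shape
    gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    x = gx.to(src) + flow[:, 0]  # target x of each source pixel, N * H * W
    y = gy.to(src) + flow[:, 1]  # target y of each source pixel, N * H * W
    out = src.new_zeros(n, c, h, w)
    x0, y0 = x.floor(), y.floor()
    for dx in (0, 1):  # distribute over the four surrounding target pixels
        for dy in (0, 1):
            xi, yi = x0 + dx, y0 + dy
            wgt = (1 - (x - xi).abs()) * (1 - (y - yi).abs())  # bilinear weight
            inside = ((xi >= 0) & (xi < w) & (yi >= 0) & (yi < h)).to(src)
            idx = (yi.clamp(0, h - 1) * w + xi.clamp(0, w - 1)).long()  # N * H * W
            contrib = (src * (wgt * inside).unsqueeze(1)).reshape(n, c, -1)
            out.view(n, c, -1).scatter_add_(
                2, idx.view(n, 1, -1).expand(n, c, -1), contrib)
    return out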
def to_selector(self, g, temp=1):
    """Map the encoded state g to a hard one-hot selector via straight-through
    Gumbel-softmax. Note: `temp` is currently unused."""
    obs_shape = g.shape
    sel_logit = self.selector_fc(g.reshape(-1, obs_shape[-1]))
    sel_sm = self.softmax(sel_logit)
    tut.norm_grad(sel_sm, 2)
    selector = self.st_gumbel_softmax(sel_sm)
    return selector
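# A minimal standalone sketch of the straight-through Gumbel-softmax that
# `self.st_gumbel_softmax` is assumed to implement (the module itself is defined
# elsewhere): sample a soft categorical with Gumbel noise, emit a hard argmax
# one-hot in the forward pass, and route gradients through the soft sample.
# The helper name and signature are hypothetical, for illustration only.
import torch
import torch.nn.functional as F

def st_gumbel_softmax_sketch(probs, temp=1.0, eps=1e-20):
    """probs: N * K, rows summing to 1. Returns a hard one-hot N * K whose
    gradient is that of the soft Gumbel-softmax sample."""
    logits = torch.log(probs + eps)
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)
    soft = F.softmax((logits + gumbel) / temp, dim=-1)
    hard = torch.zeros_like(soft).scatter_(-1, soft.argmax(-1, keepdim=True), 1.0)
    # Straight-through: the forward pass uses `hard`, the backward pass sees `soft`.
    return (hard - soft).detach() + soft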
def forward(self, h_o_prev, y_e_prev, C_o, T):
    """
    h_o_prev: N * O * dim_h_o
    y_e_prev: N * O * dim_y_e
    C_o:      N * C2_1 * C2_2
    """
    # o = self.o
    dims = self.dims
    conf_dim, layer_dim, pose_dim, shape_dim, app_dim = (
        dims['confidence'], dims['layer'], dims['pose'],
        dims['shape'], dims['appearance'])

    '''Visualize'''
    # if self.att is None and self.mem is None:
    #     self.att = torch.Tensor(T, self.n_objects, self.ntm_cell.ha, self.ntm_cell.wa).to(h_o_prev.device)
    #     self.mem = torch.Tensor(T, self.n_objects, self.ntm_cell.ha, self.ntm_cell.wa).to(h_o_prev.device)

    '''TEM?'''
    # if "no_tem" in o.exp_config:
    #     h_o_prev = torch.zeros_like(h_o_prev).cuda()
    #     y_e_prev = torch.zeros_like(y_e_prev).cuda()

    # Sort h_o_prev and y_e_prev
    # Note: REP
    # delta = torch.arange(0, self.n_objects).float().cuda().unsqueeze(0) * 0.0001  # 1 * O
    # y_e_prev_mdf = y_e_prev.squeeze(2).round() - delta
    # perm_mat = self.permutation_matrix_calculator(y_e_prev_mdf)  # N * O * O
    # h_o_prev = perm_mat.bmm(h_o_prev)  # N * O * dim_h_o
    # y_e_prev = perm_mat.bmm(y_e_prev)  # N * O * dim_y_e
    # TODO: check where this is used.

    # Update h_o: run the NTM cell once per object, threading the shared memory C_o through.
    h_o_prev_split = torch.unbind(h_o_prev, 1)  # O tensors of N * dim_h_o
    h_o_split = {}
    k_split = {}
    r_split = {}
    for i in range(0, self.n_objects):
        self.ntm_cell.i = i
        # TODO: add h_o of the previous object.
        h_o_split[i], C_o, k_split[i], r_split[i] = self.ntm_cell(
            h_o_prev_split[i], C_o)
    h_o = torch.stack(tuple(h_o_split.values()), dim=1)  # N * O * dim_h_o
    k = torch.stack(tuple(k_split.values()), dim=1)  # N * O * C2_2
    r = torch.stack(tuple(r_split.values()), dim=1)  # N * O * C2_2
    # att = self.ntm_cell.att
    # mem = self.ntm_cell.mem

    # Recover the original order of h_o
    # Note: REP
    # perm_mat_inv = perm_mat.transpose(1, 2)  # N * O * O
    # h_o = perm_mat_inv.bmm(h_o)  # N * O * dim_h_o
    # k = perm_mat_inv.bmm(k)  # N * O * dim_c_2
    # r = perm_mat_inv.bmm(r)  # N * O * dim_c_2

    '''Visualization'''
    # att = perm_mat_inv.data[self.ntm_cell.n].mm(att.view(self.n_objects, -1)).view(self.n_objects, -1, self.ntm_cell.wa)  # O * ha * wa
    # mem = perm_mat_inv.data[self.ntm_cell.n].mm(mem.view(self.n_objects, -1)).view(self.n_objects, -1, self.ntm_cell.wa)  # O * ha * wa
    # if o.v > 0:
    #     self.att[self.t].copy_(att)
    #     self.mem[self.t].copy_(mem)

    # Generate outputs: one FCN produces all heads, then the result is sliced by dimension.
    # h_o = smd.CheckBP('h_o', 0)(h_o)
    # TODO: Rethink this FCN. Sample before? FCN only for the sigmoids etc., and conv
    # for shape and appearance? Also check how the Koopman part influences this.
    a = self.fcn(h_o.view(-1, self.dim_h_o))  # NO * dim_y
    a_e = a[:, :conf_dim]  # NO * dim_y_e
    a_l = a[:, conf_dim:conf_dim + layer_dim]  # NO * dim_y_l
    a_p = a[:, conf_dim + layer_dim:conf_dim + layer_dim + pose_dim]  # NO * dim_y_p
    a_s = a[:, conf_dim + layer_dim + pose_dim:conf_dim + layer_dim + pose_dim + shape_dim]  # NO * dim_Y_s
    a_a = a[:, conf_dim + layer_dim + pose_dim + shape_dim:]  # NO * dim_Y_a

    # y_e: confidence?
    # a_e = smd.CheckBP('a_e', 0)(a_e)
    y_e = a_e.tanh().abs()
    y_e = y_e.view(-1, self.n_objects, conf_dim)  # N * O * dim_y_e

    # y_l: layer (hard one-hot via straight-through Gumbel-softmax)
    # a_l = smd.CheckBP('a_l', 0)(a_l)
    y_l = self.softmax(a_l)
    ut.norm_grad(y_l, 2)
    y_l = self.st_gumbel_softmax(y_l)
    y_l = y_l.view(-1, self.n_objects, layer_dim)  # N * O * dim_y_l

    # y_p: pose
    # a_p = smd.CheckBP('a_p', 0)(a_p)
    y_p = a_p.tanh()
    # y_p = a_p
    y_p = y_p.view(-1, self.n_objects, pose_dim)  # N * O * dim_y_p

    # Y_s: shape (hard binary mask via straight-through Gumbel-sigmoid)
    # a_s = smd.CheckBP('a_s', 0)(a_s)
    Y_s = a_s.sigmoid()
    Y_s = self.st_gumbel_sigmoid(Y_s)
    Y_s = Y_s.view(-1, self.n_objects, shape_dim)  # N * O * (1 * h * w)  TODO: o.
    # Y_a: appearance
    # a_a = smd.CheckBP('a_a', 0)(a_a)
    Y_a = a_a.sigmoid()
    Y_a = Y_a.view(-1, self.n_objects, app_dim)  # N * O * (D * h * w)  TODO: o.

    '''Adaptive computation time'''
    # y_e_perm = perm_mat.bmm(y_e).round()  # N * O * dim_y_e
    # y_e_mask = y_e_prev.round() + y_e_perm  # N * O * dim_y_e
    # y_e_mask = y_e_mask.lt(0.5).type_as(y_e_mask)
    # y_e_mask = y_e_mask.cumsum(1)
    # y_e_mask = y_e_mask.lt(0.5).type_as(y_e_mask)
    # ones = torch.ones(y_e_mask.size(0), 1, conf_dim).cuda()  # N * 1 * dim_y_e
    # y_e_mask = torch.cat((ones, y_e_mask[:, 0:self.n_objects - 1]), dim=1)
    # y_e_mask = perm_mat_inv.bmm(y_e_mask)  # N * O * dim_y_e
    # h_o = y_e_mask * (h_o - h_o_prev) + h_o_prev  # N * O * dim_h_o
    # # h_o = y_e_mask * h_o  # N * O * dim_h_o
    # y_e = y_e_mask * y_e  # N * O * dim_y_e
    # y_p = y_e_mask * y_p  # N * O * dim_y_p
    # Y_a = y_e_mask.view(-1, self.n_objects, conf_dim, 1, 1) * Y_a  # N * O * D * h * w (o.dim_y_e)
    # if self.t == T - 1:
    #     print(y_e.data.view(-1, self.n_objects)[0:1, 0:min(self.n_objects, 10)])

    return h_o, y_e, y_l, y_p, Y_s, Y_a
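# A minimal sketch of the straight-through Gumbel-sigmoid assumed behind the
# `self.st_gumbel_sigmoid` call above: binarize the shape probabilities Y_s with
# logistic noise in the forward pass while keeping the soft gradient. The helper
# name, signature, and 0.5 threshold are assumptions for illustration only.
import torch

def st_gumbel_sigmoid_sketch(probs, temp=1.0, eps=1e-20):
    """probs: any shape, values in (0, 1). Returns hard {0, 1} values whose
    gradient is that of the soft noisy sigmoid."""
    u = torch.rand_like(probs)
    noise = torch.log(u + eps) - torch.log(1 - u + eps)  # Logistic(0, 1) sample
    logits = torch.log(probs + eps) - torch.log(1 - probs + eps)  # inverse sigmoid
    soft = torch.sigmoid((logits + noise) / temp)
    hard = (soft > 0.5).to(soft)
    return (hard - soft).detach() + soft  # straight-through estimator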
def forward(self, h_o_prev, C):
    """
    h_o_prev: N * dim_h_o
    C:        N * C2_1 * C2_2
    """
    '''Visualize'''
    # if o.v > 0:
    #     if self.i == 0:
    #         self.att.fill_(0.5)
    #         self.mem.fill_(0.5)
    #     self.mem[self.i].copy_(C.data[n].mean(1).view(self.ha, self.wa))

    '''Attention'''
    # Addressing key
    k = self.linear_k(h_o_prev)  # N * C2_2
    k_expand = k.unsqueeze(1).expand_as(C)  # N * C2_1 * C2_2
    # Key strength: equals beta_pre.exp().log1p() + 1 (softplus + 1), but the split
    # into positive and negative parts avoids 'inf' from exp() on large inputs.
    beta_pre = self.linear_b(h_o_prev)
    beta_pos = beta_pre.clamp(min=0)
    beta_neg = beta_pre.clamp(max=0)
    beta = beta_neg.exp().log1p() + beta_pos + \
        (-beta_pos).exp().log1p() + (1 - np.log(2))  # N * 1
    # Weighting: sharpened cosine similarity between the key and each memory row.
    C_cos = ut.Identity()(C)
    ut.norm_grad(C_cos, 1)
    s = self.cosine_similarity(C_cos, k_expand).view(-1, C.shape[1])  # N * C2_1
    w = self.softmax(s * beta)  # N * C2_1

    '''Read Memory'''
    # Read vector
    w1 = w.unsqueeze(1)  # N * 1 * C2_1
    ut.norm_grad(w1, 1)
    r = w1.bmm(C).squeeze(1)  # N * C2_2

    # RNN
    h_o = self.rnn_cell(r, h_o_prev)

    '''Write Memory'''
    # Erase vector
    e = self.linear_e(h_o).sigmoid().unsqueeze(1)  # N * 1 * C2_2
    # Write vector
    v = self.linear_v(h_o).unsqueeze(1)  # N * 1 * C2_2
    # Update memory
    w2 = w.unsqueeze(2)  # N * C2_1 * 1
    C = C * (1 - w2.bmm(e)) + w2.bmm(v)  # N * C2_1 * C2_2

    '''Visualize'''
    # if o.v > 0:
    #     self.att[self.i].copy_(w.data[n].view(self.ha, self.wa))

    # Reference: dimension bookkeeping for the options object `o`.
    # # C in 3D shape
    # o.dim_C3_1 = o.cnn['out_sizes'][-1][0]
    # o.dim_C3_2 = o.cnn['out_sizes'][-1][1]
    # o.dim_C3_3 = o.cnn['conv_features'][-1]
    # o.dim_h_o = o.dim_C3_3 * 4
    # # C in 2D shape
    # if "no_att" in o.exp_config:
    #     o.dim_C2_1 = 1
    #     o.dim_C2_2 = o.dim_C3_1 * o.dim_C3_2 * o.dim_C3_3
    # else:
    #     o.dim_C2_1 = o.dim_C3_1 * o.dim_C3_2
    #     o.dim_C2_2 = o.dim_C3_3
    # if "no_occ" in o.exp_config:
    #     o.dim_y_l = 1
    #     o.dim_y_e = 1
    #     o.dim_y_p = 4
    #     o.dim_Y_s = 1 * o.h * o.w
    #     o.dim_Y_a = o.D * o.h * o.w

    return h_o, C, k, r
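# Quick numerical check of the key-strength formula above: splitting beta_pre
# into positive and negative parts reproduces softplus(beta_pre) + 1 exactly,
# while never calling exp() on a large positive argument (which would overflow
# to inf in float32). Standalone snippet, for verification only.
import numpy as np
import torch
import torch.nn.functional as F

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
pos, neg = x.clamp(min=0), x.clamp(max=0)
stable = neg.exp().log1p() + pos + (-pos).exp().log1p() + (1 - np.log(2))
print(torch.allclose(stable, F.softplus(x) + 1))  # True
print(x.exp().log1p() + 1)  # the naive form overflows to inf at x = 100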