def forward(self, X_seq, **kwargs):
    """Extract features, update the trackers, render a reconstruction and return the loss.

    Args:
        X_seq: input frame sequence. The inline shape comments suggest
            N * T * D * H * W (NOTE(review): confirm against the data loader).
        **kwargs:
            X_bg_seq: background sequence. Required whenever ``o.bg != 0``:
                the renderer consumes it when ``o.bg == 1`` and the loss
                consumes it for any ``o.bg != 0``.
            X_org_seq: only read when ``o.v > 0`` and ``o.metric == 1``
                (passed to ``self.visualize`` as ``X_org``).

    Returns:
        ``loss.sum() / (o.N * o.T)`` where ``loss`` comes from
        ``self.loss_calculator``.

    Raises:
        ValueError: if ``o.bg != 0`` and ``X_bg_seq`` was not supplied. The
            check runs before any tracker state is saved, so a failed call
            leaves the recurrent state untouched.
    """
    o = self.o
    if 'X_bg_seq' in kwargs:
        Y_b_seq = kwargs['X_bg_seq']
    elif o.bg != 0:
        # Without this check the missing background only surfaced later as an
        # UnboundLocalError, after save_states() had already run.
        raise ValueError("'X_bg_seq' is required when o.bg != 0")

    # Extract features
    X_seq_cat = torch.cat((X_seq, Variable(self.coor.clone())), 2)  # N * T * D+2 * H * W
    C_o_seq = self.feature_extractor(X_seq_cat)  # N * T * M * R
    C_o_seq = smd.CheckBP('C_o_seq')(C_o_seq)

    # Update trackers
    h_o_prev, y_e_prev = self.load_states('h_o_prev', 'y_e_prev')
    h_o_seq, y_e_seq, y_l_seq, y_p_seq, Y_s_seq, Y_a_seq = self.tracker_array(
        h_o_prev, y_e_prev, C_o_seq)  # N * T * O * ...
    if o.r == 1:
        # Persist the new states so the next call continues from them.
        self.save_states(h_o_prev=h_o_seq, y_e_prev=y_e_seq)

    # Render the image using tracker outputs
    ka = {}
    if o.bg == 1:
        ka['Y_b'] = Y_b_seq
    X_r_seq, area = self.renderer(y_e_seq, y_l_seq, y_p_seq, Y_s_seq, Y_a_seq,
                                  **ka)  # N * T * D * H * W

    # Calculate the loss
    ka = {'y_e': y_e_seq}
    if o.bg == 0:
        ka['Y_a'] = Y_a_seq
    else:
        ka['Y_b'] = Y_b_seq
    if o.metric == 0:
        ka['y_p'] = y_p_seq
    loss = self.loss_calculator(X_r_seq, X_seq, area, **ka)
    loss = loss.sum() / (o.N * o.T)

    # Visualize
    if o.v > 0:
        ka = {
            'X': X_seq,
            'X_r': X_r_seq,
            'y_e': y_e_seq,
            'y_l': y_l_seq,
            'y_p': y_p_seq,
            'Y_s': Y_s_seq,
            'Y_a': Y_a_seq,
        }
        if o.bg == 1:
            ka['Y_b'] = Y_b_seq
        if o.metric == 1:
            ka['X_org'] = kwargs['X_org_seq']
        self.visualize(**ka)
    return loss
def forward(self, X_seq, **kwargs):
    """Extract features, update the trackers, render, and return the loss (plus a video dict when visualizing).

    Args:
        X_seq: input frame sequence. The inline shape comments suggest
            N * T * D * H * W (NOTE(review): confirm against the data loader).
        **kwargs:
            X_bg_seq: background sequence. Required whenever ``o.bg != 0``:
                the renderer consumes it when ``o.bg == 1`` and the loss
                consumes it for any ``o.bg != 0``.
            X_org_seq: only read when ``o.v > 0`` and ``o.metric == 1``.

    Returns:
        When ``o.v <= 0``: ``loss.sum() / (o.N * o.T)`` from
        ``self.loss_calculator``.
        When ``o.v > 0``: a tuple ``(loss, video)``. Here ``loss`` is a summed
        MSE between 0.5x-interpolated reconstruction and input, and ``video``
        is a dict with keys ``'class'``, ``'filename'`` and ``'frames'``
        (the return value of ``self.visualize``).

    Raises:
        ValueError: if ``o.bg != 0`` and ``X_bg_seq`` was not supplied. The
            check runs before any tracker state is saved.
    """
    o = self.o
    if 'X_bg_seq' in kwargs:
        Y_b_seq = kwargs['X_bg_seq']
    elif o.bg != 0:
        # Without this check the missing background only surfaced later as an
        # UnboundLocalError, after save_states() had already run.
        raise ValueError("'X_bg_seq' is required when o.bg != 0")

    # Extract features
    X_seq_cat = torch.cat((X_seq, self.coor.clone()), 2)  # N * T * D+2 * H * W
    C_o_seq = self.feature_extractor(X_seq_cat)  # N * T * M * R
    C_o_seq = smd.CheckBP('C_o_seq')(C_o_seq)

    # Update trackers
    h_o_prev, y_e_prev = self.load_states('h_o_prev', 'y_e_prev')
    h_o_seq, y_e_seq, y_l_seq, y_p_seq, Y_s_seq, Y_a_seq = self.tracker_array(
        h_o_prev, y_e_prev, C_o_seq)  # N * T * O * ...
    if o.r == 1:
        self.save_states(h_o_prev=h_o_seq, y_e_prev=y_e_seq)

    # Render the image using tracker outputs; this renderer returns a 3-tuple
    # and the third element is not used here.
    ka = {}
    if o.bg == 1:
        ka['Y_b'] = Y_b_seq
    X_r_seq, area, _ = self.renderer(y_e_seq, y_l_seq, y_p_seq, Y_s_seq, Y_a_seq,
                                     **ka)  # N * T * D * H * W

    # Calculate the loss
    ka = {'y_e': y_e_seq}
    if o.bg == 0:
        ka['Y_a'] = Y_a_seq
    else:
        ka['Y_b'] = Y_b_seq
    if o.metric == 0:
        ka['y_p'] = y_p_seq
    loss = self.loss_calculator(X_r_seq, X_seq, area, **ka)
    loss = loss.sum() / (o.N * o.T)

    # Visualize
    if o.v > 0:
        # NOTE(review): this replaces the loss computed above with a summed MSE
        # between 0.5x-interpolated prediction and target, so callers receive
        # a different quantity in visualization mode. Kept as-is; confirm it
        # is intended.
        # reshape (not view): same result, but also works if the renderer
        # output is non-contiguous.
        downsampled_pred = nn.functional.interpolate(
            X_r_seq.reshape(-1, o.D, o.H, o.W), scale_factor=0.5)
        downsampled_target = nn.functional.interpolate(
            X_seq.reshape(-1, o.D, o.H, o.W), scale_factor=0.5)
        loss = nn.functional.mse_loss(downsampled_pred, downsampled_target,
                                      reduction='sum')
        video = {
            'class': 'video',
            'filename': 'video_id_{}'.format(o.batch_id),
        }
        ka = {
            'X': X_seq,
            'X_r': X_r_seq,
            'y_e': y_e_seq,
            'y_l': y_l_seq,
            'y_p': y_p_seq,
            'Y_s': Y_s_seq,
            'Y_a': Y_a_seq,
        }
        if o.bg == 1:
            ka['Y_b'] = Y_b_seq
        if o.metric == 1:
            ka['X_org'] = kwargs['X_org_seq']
        video['frames'] = self.visualize(**ka)
        return loss, video
    return loss
def forward(self, X_seq, **kwargs):
    """Extract features, update the trackers, render and return the loss, with optional obs/pred phases.

    Args:
        X_seq: input frame sequence. The inline shape comments suggest
            N * T * D * H * W (NOTE(review): confirm against the data loader).
        **kwargs:
            X_bg_seq: background sequence. Required whenever ``o.bg != 0``.
            X_base_img: optional indexable of four items
                ``(data, path, actions, phase)``. ``path`` and ``phase`` are
                forwarded to ``self.tracker_array``.
                - If ``actions[:, 0]`` sums to a negative value, recurrent
                  state, ``self.memory`` and ``self.obs_loss`` are reset and
                  ``o.new_track`` is set.
                - ``phase == 'pred'`` drives the tracker from
                  ``self.memory``, which caches the first loaded hidden
                  state since it was last cleared.
                - ``phase == 'obs'`` accumulates the loss into
                  ``self.obs_loss`` and returns ``None``. When
                  ``o.train == 0`` it also writes the inputs and base data
                  under ``o.pic_dir``.
            X_org_seq: only read when ``o.v > 0`` and ``o.metric == 1``.

    Returns:
        ``None`` in the ``'obs'`` phase. Otherwise
        ``loss.sum() / (o.N * o.T)`` from ``self.loss_calculator``.

    Raises:
        ValueError: if ``o.bg != 0`` and ``X_bg_seq`` was not supplied. The
            check runs before any state is reset or saved.
    """
    o = self.o
    if 'X_bg_seq' in kwargs:
        Y_b_seq = kwargs['X_bg_seq']
    elif o.bg != 0:
        # Without this check the missing background only surfaced later as an
        # UnboundLocalError, after states had already been reset/saved.
        raise ValueError("'X_bg_seq' is required when o.bg != 0")

    X_base_img = None
    data, path, actions, phase = None, None, None, None
    coords_info = local_coords(o, o.H, o.W)
    if 'X_base_img' in kwargs:
        X_base_img = kwargs['X_base_img']
        data, path, actions, phase = [X_base_img[i] for i in range(4)]
        # A negative sum over the first action column is treated as the start
        # of a new track (o.new_track): clear all carried-over state.
        if actions[:, 0].sum() < 0:
            print('%%%%%%%%%%' * 5)
            self.reset_states()
            self.memory = None
            self.obs_loss = 0
            o.new_track = True

    # Extract features
    X_seq_cat = torch.cat((X_seq, Variable(coords_info.clone())), 2)  # N * T * D+2 * H * W
    C_o_seq = self.feature_extractor(X_seq_cat)  # N * T * M * R
    C_o_seq = smd.CheckBP('C_o_seq')(C_o_seq)

    # Update trackers
    h_o_prev, y_e_prev = self.load_states('h_o_prev', 'y_e_prev')
    if phase == 'pred':
        print('@@@@@$$$$$$$$' * 5)
        # Prediction keeps feeding the hidden state cached on the first 'pred'
        # call (since memory was last cleared) instead of the freshly loaded one.
        if self.memory is None:
            self.memory = h_o_prev
        h_o_in = self.memory
    else:
        self.memory = None
        h_o_in = h_o_prev
    h_o_seq, y_e_seq, y_l_seq, y_p_seq, Y_s_seq, Y_a_seq = self.tracker_array(
        h_o_in, y_e_prev, C_o_seq, path, phase)  # N * T * O * ...
    if o.r == 1:
        self.save_states(h_o_prev=h_o_seq, y_e_prev=y_e_seq)

    # Render the image
    ka = {}
    if o.bg == 1:
        ka['Y_b'] = Y_b_seq
    X_r_seq, area = self.renderer(y_e_seq, y_l_seq, y_p_seq, Y_s_seq, Y_a_seq,
                                  **ka)  # N * T * D * H * W
    if o.train == 0:
        area = area.unsqueeze(0)

    # Calculate the loss
    ka = {'y_e': y_e_seq}
    if o.bg == 0:
        ka['Y_a'] = Y_a_seq
    else:
        ka['Y_b'] = Y_b_seq
    if o.metric == 0:
        ka['y_p'] = y_p_seq
    loss = self.loss_calculator(X_r_seq, X_seq, area, **ka)
    loss = loss.sum() / (o.N * o.T)

    if phase == 'obs':
        n = 0
        if o.train == 0:
            # Dump the observed input frames and the base data of sample n.
            save_dir = os.path.join(o.pic_dir, str(n))
            for t in range(0, o.T):
                img = X_seq.data[n, t].permute(1, 2, 0).clamp(0, 1)
                tao = o.batch_id * o.T + t
                utils.mkdir(os.path.join(save_dir, 'input'))
                utils.imwrite(img, os.path.join(save_dir, 'input', "%05d" % (tao)))
            utils.mkdir(os.path.join(save_dir, 'base'))
            torch.save((data[n], path[n], actions[n]),
                       os.path.join(save_dir, 'base', str(o.batch_id) + '.pt'))
        # detach(): obs_loss is only reported, never back-propagated. Adding the
        # graph-carrying tensor would keep every observed batch's autograd
        # graph alive until the next reset.
        self.obs_loss += loss.detach()
        print('Loss: {}'.format(loss))
        return None

    # Visualize
    if o.v > 0:
        ka = {
            'X': X_seq,
            'X_r': X_r_seq,
            'y_e': y_e_seq,
            'y_l': y_l_seq,
            'y_p': y_p_seq,
            'Y_s': Y_s_seq,
            'Y_a': Y_a_seq,
        }
        if X_base_img is not None:
            ka['X_base_img'] = X_base_img
        if o.bg == 1:
            ka['Y_b'] = Y_b_seq
        if o.metric == 1:
            ka['X_org'] = kwargs['X_org_seq']
        self.visualize(**ka)
    print('Obs_loss: {}'.format(self.obs_loss))
    return loss