def forward(self, img_seq):
    """Encode an image sequence into per-frame keypoints and heatmaps.

    Each frame gets coordinate channels appended, is run through the
    image encoder, projected to per-keypoint heatmaps, and the heatmaps
    are reduced to keypoints via `ops.maps_to_keypoints`.

    Args:
        img_seq: batched image sequence; unstacked along time by
            `unstack_time`.

    Returns:
        Tuple `(keypoints_seq, heatmaps_seq)`, each re-stacked over time.
    """
    # Image to keypoints:
    if self.debug:
        print('Image shape: ', img_seq.shape)
    keypoints_per_frame = []
    heatmaps_per_frame = []
    for frame in unstack_time(img_seq):
        features = self.image_encoder(ops.add_coord_channels(frame))
        heatmaps = self.feats_heatmap(features)
        if self.debug:
            print("Heatmaps shape:", heatmaps.size())
        keypoints = ops.maps_to_keypoints(heatmaps)
        if self.debug:
            print("Keypoints shape:", keypoints.shape)
            print()
        heatmaps_per_frame.append(heatmaps)
        keypoints_per_frame.append(keypoints)
    return stack_time(keypoints_per_frame), stack_time(heatmaps_per_frame)
def forward(self, keypoints_seq):
    """Decode a keypoint sequence back into an image sequence.

    Per-frame keypoints are rendered to Gaussian maps, augmented with
    coordinate channels, channel-adjusted, decoded, and finally mapped
    to output image channels.

    Args:
        keypoints_seq: keypoint sequence; each frame is
            [batch_size, num_keypoints, 3] (per the original docstring).

    Returns:
        Reconstructed image sequence, stacked over time.
    """
    if self.debug:
        print("Keypoints shape: ", keypoints_seq.shape)
    frames = []
    for kp in unstack_time(keypoints_seq):
        maps = self.keypoints_to_maps(kp)
        if self.debug:
            print("Gaussian Heatmap: ", maps.shape)
        maps = ops.add_coord_channels(maps)
        if self.debug:
            print("Gaussian Heatmap with CoordConv: ", maps.shape)
        maps = self.adjust_channels_of_decoder_input(maps)
        if self.debug:
            print("Gaussian Heatmap before decoder: ", maps.shape)
        decoded = self.image_decoder(maps)
        if self.debug:
            print("Decoded Representation: ", decoded.shape)
        reconstructed = self.adjust_channels_of_output_image(decoded)
        if self.debug:
            print("Reconstructed Img: ", reconstructed.shape)
            print()
        frames.append(reconstructed)
    return stack_time(frames)
def forward(self, keypoints_seq):
    """Run the VRNN over a keypoint sequence.

    For the first `cfg.observed_steps` steps the posterior network is
    used (reconstruction, producing a KL divergence term); the remaining
    steps up to `self.num_timesteps` run with the prior only (prediction,
    no KL term).

    Returns:
        Tuple `(output_keypoints_seq, kl_div_seq)`; the KL sequence only
        covers the observed steps.
    """
    inputs = unstack_time(keypoints_seq)
    batch_size = keypoints_seq.shape[0]
    rnn_state = torch.zeros(
        [batch_size, self.cfg.num_rnn_units]).to(keypoints_seq.device)

    outputs = []
    kl_divs = []
    for t in range(self.num_timesteps):
        observed = t < self.cfg.observed_steps
        # Posterior net only participates during the observed window.
        out, rnn_state, kl = self.vrnn_iteration(
            self.cfg, inputs[t], rnn_state, self.rnn_cell, self.prior_net,
            self.decoder, None,
            self.posterior_net if observed else None)
        outputs.append(out)
        if observed:
            kl_divs.append(kl)
    return stack_time(outputs), stack_time(kl_divs)
def forward(self, keypoints_seq, first_frame, first_frame_keypoints):
    """Decode keypoints into images, conditioned on a reference frame.

    Appearance features of the first frame and the Gaussian maps of its
    keypoints are concatenated with each frame's keypoint maps before
    decoding; the decoder output is a residual added onto the first frame.

    Args:
        keypoints_seq: [batch_size, T, num_keypoints, 3]
        first_frame: [batch_size, 3, IM_H, IM_W]
        first_frame_keypoints: [batch_size, num_keypoints, 3]

    Returns:
        Reconstructed image sequence (first frame + predicted residuals).
    """
    if self.debug:
        print("Keypoints shape: ", keypoints_seq.shape)
    # Shapes below per the original inline comments.
    appearance = self.appearance_feature_extractor(
        first_frame)  # batch_size x 128 x 16 x 16
    ref_maps = self.keypoints_to_maps(
        first_frame_keypoints)  # batch_size x 64 x 16 x 16
    frames = []
    for kp in unstack_time(keypoints_seq):
        maps = self.keypoints_to_maps(kp)
        if self.debug:
            print("Gaussian Heatmap: ", maps.shape)
        combined = torch.cat([maps, appearance, ref_maps], dim=1)
        combined = ops.add_coord_channels(combined)
        if self.debug:
            print("Gaussian Heatmap with CoordConv: ", combined.shape)
        combined = self.adjust_channels_of_decoder_input(combined)
        if self.debug:
            print("Gaussian Heatmap before decoder: ", combined.shape)
        decoded = self.image_decoder(combined)
        if self.debug:
            print("Decoded Representation: ", decoded.shape)
        reconstructed = self.adjust_channels_of_output_image(decoded)
        if self.debug:
            print("Reconstructed Img: ", reconstructed.shape)
            print()
        frames.append(reconstructed)
    residual_seq = stack_time(frames)
    # Broadcast the first frame over the time axis and add the residuals.
    return residual_seq + first_frame[:, None, :, :, :]
def unroll(self, keypoints_seq, T_future, T_obs=None):
    """Condition on observed keypoints, then predict `T_future` steps.

    The first `T_obs` steps run with the posterior network to condition
    the RNN state on the observations; the following `T_future` steps
    sample from the prior only.

    Args:
        keypoints_seq: observed keypoint sequence, unstacked along time.
        T_future: number of prior-only prediction steps to unroll.
        T_obs: number of observed (posterior-conditioned) steps. Defaults
            to half the sequence length, preserving the original behavior.

    Returns:
        Output keypoints stacked over time, length `T_obs + T_future`.
    """
    keypoints_list = unstack_time(keypoints_seq)
    if T_obs is None:
        T_obs = len(keypoints_list) // 2
    rnn_state = torch.zeros(
        [keypoints_seq.shape[0], self.cfg.num_rnn_units]).to(keypoints_seq.device)

    outputs = []
    # Observed phase: condition the state using the posterior network.
    for t in range(T_obs):
        out, rnn_state, _ = self.vrnn_iteration(
            self.cfg, keypoints_list[t], rnn_state, self.rnn_cell,
            self.prior_net, self.decoder, None, self.posterior_net)
        outputs.append(out)
    # Prediction phase: posterior_net is None, so sampling comes from the
    # prior. The last observed frame is passed as the input placeholder —
    # NOTE(review): confirm vrnn_iteration ignores the input keypoints
    # when posterior_net is None.
    for _ in range(T_future):
        out, rnn_state, _ = self.vrnn_iteration(
            self.cfg, keypoints_list[-1], rnn_state, self.rnn_cell,
            self.prior_net, self.decoder, None, None)
        outputs.append(out)
    return stack_time(outputs)