Code Example #1
    def forward(self, img_seq):
        """ img_seq: [batch_size, T, 3, IM_H, IM_W] """
        # Image to keypoints:
        if self.debug: print('Image shape: ', img_seq.shape)

        img_list = unstack_time(img_seq)

        keypoints_list = []
        heatmaps_list = []
        for img in img_list:
            img = ops.add_coord_channels(img)
            encoded = self.image_encoder(img)
            heatmaps = self.feats_heatmap(encoded)

            if self.debug: print("Heatmaps shape:", heatmaps.size())

            keypoints = ops.maps_to_keypoints(heatmaps)

            if self.debug: print("Keypoints shape:", keypoints.shape)
            if self.debug: print()

            heatmaps_list.append(heatmaps)
            keypoints_list.append(keypoints)

        keypoints_seq = stack_time(keypoints_list)
        heatmaps_seq = stack_time(heatmaps_list)

        return keypoints_seq, heatmaps_seq
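
The encoder runs per frame, so the sequence is pulled apart along the time axis and re-stacked at the end. The unstack_time / stack_time helpers are not shown in this listing; below is a minimal sketch of what they are assumed to do, given the [batch_size, T, ...] layout implied by the code:

import torch

def unstack_time(seq):
    # [batch_size, T, ...] -> list of T tensors of shape [batch_size, ...]
    return [frame.squeeze(1) for frame in seq.split(1, dim=1)]

def stack_time(tensor_list):
    # list of T tensors of shape [batch_size, ...] -> [batch_size, T, ...]
    return torch.stack(tensor_list, dim=1)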
Code Example #2
    def forward(self, keypoints_seq):
        """ keypoints: [batch_size, num_keypoints, 3] """

        if self.debug: print("Keypoints shape: ", keypoints_seq.shape)

        keypoints_list = unstack_time(keypoints_seq)

        reconstructed_img_list = []
        for keypoints in keypoints_list:
            gaussian_maps = self.keypoints_to_maps(keypoints)
            if self.debug: print("Gaussian Heatmap: ", gaussian_maps.shape)

            gaussian_maps = ops.add_coord_channels(gaussian_maps)
            if self.debug:
                print("Gaussian Heatmap with CoordConv: ", gaussian_maps.shape)

            gaussian_maps = self.adjust_channels_of_decoder_input(
                gaussian_maps)
            if self.debug:
                print("Gaussian Heatmap before decoder: ", gaussian_maps.shape)

            decoded_rep = self.image_decoder(gaussian_maps)
            if self.debug: print("Decoded Representation: ", decoded_rep.shape)

            reconstructed_img = self.adjust_channels_of_output_image(
                decoded_rep)
            if self.debug:
                print("Reconstructed Img: ", reconstructed_img.shape)

            if self.debug: print()

            reconstructed_img_list.append(reconstructed_img)

        reconstructed_img_seq = stack_time(reconstructed_img_list)
        return reconstructed_img_seq
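
self.keypoints_to_maps turns each (x, y, intensity) keypoint back into an isotropic Gaussian blob before decoding. Its implementation is not shown above; the following sketch only illustrates the idea, with map_size and sigma as assumed parameters:

import torch

def keypoints_to_gaussian_maps(keypoints, map_size=16, sigma=1.5):
    # keypoints: [batch_size, num_keypoints, 3] holding (x, y, intensity)
    # with coordinates in [-1, 1]. Returns [B, num_keypoints, map_size, map_size].
    coords = torch.linspace(-1.0, 1.0, map_size, device=keypoints.device)
    x, y, mu = keypoints[..., 0], keypoints[..., 1], keypoints[..., 2]
    dist_y = (coords.view(1, 1, -1) - y.unsqueeze(-1)) ** 2  # [B, K, H]
    dist_x = (coords.view(1, 1, -1) - x.unsqueeze(-1)) ** 2  # [B, K, W]
    dist = dist_y.unsqueeze(-1) + dist_x.unsqueeze(-2)       # [B, K, H, W]
    # One Gaussian blob per keypoint, scaled by the keypoint's intensity mu.
    return torch.exp(-dist / (sigma ** 2)) * mu[..., None, None]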
Code Example #3
    def forward(self, keypoints_seq):
        """ keypoints_seq: [batch_size, T, num_keypoints, 3] """
        keypoints_seq_list = unstack_time(keypoints_seq)

        output_keypoints_list = [None] * self.num_timesteps
        kl_div_list = [None] * self.cfg.observed_steps

        rnn_state = torch.zeros(keypoints_seq.shape[0],
                                self.cfg.num_rnn_units,
                                device=keypoints_seq.device)

        # Observed steps: condition on the input keypoints through the
        # posterior net and collect one KL term per step.
        for t in range(self.cfg.observed_steps):
            output_keypoints_list[t], rnn_state, kl_div_list[t] = \
                self.vrnn_iteration(self.cfg, keypoints_seq_list[t], rnn_state,
                                    self.rnn_cell, self.prior_net, self.decoder,
                                    None, self.posterior_net)

        # Remaining steps: no posterior net, so the prior drives the
        # prediction and no KL term is produced.
        for t in range(self.cfg.observed_steps, self.num_timesteps):
            output_keypoints_list[t], rnn_state, _ = self.vrnn_iteration(
                self.cfg, keypoints_seq_list[t], rnn_state, self.rnn_cell,
                self.prior_net, self.decoder, None, None)

        output_keypoints_seq = stack_time(output_keypoints_list)
        kl_div_seq = stack_time(kl_div_list)

        return output_keypoints_seq, kl_div_seq
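
vrnn_iteration returns the decoded keypoints, the updated RNN state, and, when a posterior net is supplied, a KL divergence between posterior and prior; that KL sequence is what makes this a variational RNN. A hedged sketch of how it might enter the training objective follows (dynamics_model, kl_weight, and the reconstruction term are assumptions, not this repository's actual loss code):

import torch.nn.functional as F

# Hypothetical training step; all names below are placeholders.
output_keypoints_seq, kl_div_seq = dynamics_model(keypoints_seq)
recon_loss = F.mse_loss(output_keypoints_seq[:, :cfg.observed_steps],
                        keypoints_seq[:, :cfg.observed_steps])
loss = recon_loss + kl_weight * kl_div_seq.mean()
loss.backward()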
Code Example #4
    def forward(self, keypoints_seq, first_frame, first_frame_keypoints):
        """ keypoints: [batch_size, T, num_keypoints, 3]
            first_frame: [batch_size, 3, IM_H, IM_W]
            first_frame_keypoints: [batch_size, num_keypoints, 3]
        """

        if self.debug: print("Keypoints shape: ", keypoints_seq.shape)

        first_frame_features = self.appearance_feature_extractor(
            first_frame)  # batch_size x 128 x 16 x 16
        first_frame_gaussian_maps = self.keypoints_to_maps(
            first_frame_keypoints)  # batch_size x 64 x 16 x 16

        keypoints_list = unstack_time(keypoints_seq)

        reconstructed_img_list = []
        for keypoints in keypoints_list:
            gaussian_maps = self.keypoints_to_maps(keypoints)
            if self.debug: print("Gaussian Heatmap: ", gaussian_maps.shape)

            combined_maps = torch.cat(
                [gaussian_maps, first_frame_features, first_frame_gaussian_maps],
                dim=1)

            combined_maps = ops.add_coord_channels(combined_maps)
            if self.debug:
                print("Gaussian Heatmap with CoordConv: ", combined_maps.shape)

            combined_maps = self.adjust_channels_of_decoder_input(
                combined_maps)
            if self.debug:
                print("Gaussian Heatmap before decoder: ", combined_maps.shape)

            decoded_rep = self.image_decoder(combined_maps)
            if self.debug: print("Decoded Representation: ", decoded_rep.shape)

            reconstructed_img = self.adjust_channels_of_output_image(
                decoded_rep)
            if self.debug:
                print("Reconstructed Img: ", reconstructed_img.shape)

            if self.debug: print()

            reconstructed_img_list.append(reconstructed_img)

        reconstructed_img_seq = stack_time(reconstructed_img_list)
        # Residual connection: the decoder predicts the change relative to
        # the first frame, which is broadcast across the time axis.
        reconstructed_img_seq = (reconstructed_img_seq +
                                 first_frame[:, None, :, :, :])
        return reconstructed_img_seq
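
Because of that final addition, the decoder only has to predict the change relative to the first frame. The [:, None, :, :, :] indexing inserts a time axis so the frame broadcasts over all T steps; a toy shape check (the sizes here are arbitrary):

import torch

B, T, C, H, W = 4, 8, 3, 64, 64       # arbitrary toy sizes
seq = torch.zeros(B, T, C, H, W)      # stands in for the decoded deltas
first = torch.randn(B, C, H, W)       # stands in for first_frame
out = seq + first[:, None, :, :, :]   # broadcast over the time axis
assert out.shape == (B, T, C, H, W)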
Code Example #5
    def unroll(self, keypoints_seq, T_future):
        """ Condition on the first half of keypoints_seq, then predict
            T_future additional steps from the prior alone. """
        keypoints_seq_list = unstack_time(keypoints_seq)
        T_obs = len(keypoints_seq_list) // 2
        output_keypoints_list = [None] * (T_obs + T_future)

        rnn_state = torch.zeros(keypoints_seq.shape[0],
                                self.cfg.num_rnn_units,
                                device=keypoints_seq.device)
        for t in range(T_obs):
            output_keypoints_list[t], rnn_state, _ = self.vrnn_iteration(
                self.cfg, keypoints_seq_list[t], rnn_state, self.rnn_cell,
                self.prior_net, self.decoder, None, self.posterior_net)

        for t in range(T_obs, T_obs + T_future):
            # No posterior net beyond the observed steps: each new step is
            # sampled from the prior, with the last observed keypoints re-fed
            # as the input at every future step.
            output_keypoints_list[t], rnn_state, _ = self.vrnn_iteration(
                self.cfg, keypoints_seq_list[-1], rnn_state, self.rnn_cell,
                self.prior_net, self.decoder, None, None)

        output_keypoints_seq = stack_time(output_keypoints_list)

        return output_keypoints_seq
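
A hedged usage sketch: with a 16-step input sequence, unroll conditions on the first 8 steps and then generates from the prior (dynamics_model, batch_size, and num_keypoints are placeholder names):

import torch

batch_size, num_keypoints = 4, 64      # placeholder sizes
keypoints_seq = torch.randn(batch_size, 16, num_keypoints, 3)
predicted = dynamics_model.unroll(keypoints_seq, T_future=10)
# T_obs = 16 // 2 = 8 observed steps + 10 predicted steps, so:
# predicted.shape == [batch_size, 18, num_keypoints, 3]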