Code example #1
0
    def forward(self, X_seq, **kwargs):
        """Run one forward pass: extract features, update trackers, render
        the reconstruction, and return the normalized loss.

        Args:
            X_seq: input frame sequence; assumed N * T * D * H * W per the
                inline shape comments -- TODO confirm.
            **kwargs: optional 'X_bg_seq' (background sequence; required when
                o.bg == 1) and 'X_org_seq' (original frames, used for
                visualization when o.metric == 1).

        Returns:
            Scalar loss averaged over batch size o.N and sequence length o.T.
        """
        o = self.o
        if 'X_bg_seq' in kwargs:  # idiomatic membership test (no .keys())
            Y_b_seq = kwargs['X_bg_seq']

        # Extract features: append per-pixel coordinate channels along dim 2.
        # The deprecated Variable wrapper was dropped -- it is a no-op since
        # PyTorch 0.4 (Tensor/Variable merge) and the sibling implementation
        # in this file already omits it.
        X_seq_cat = torch.cat((X_seq, self.coor.clone()),
                              2)  # N * T * D+2 * H * W
        C_o_seq = self.feature_extractor(X_seq_cat)  # N * T * M * R
        C_o_seq = smd.CheckBP('C_o_seq')(C_o_seq)

        # Update trackers from the previous hidden / confidence states.
        h_o_prev, y_e_prev = self.load_states('h_o_prev', 'y_e_prev')
        h_o_seq, y_e_seq, y_l_seq, y_p_seq, Y_s_seq, Y_a_seq = self.tracker_array(
            h_o_prev, y_e_prev, C_o_seq)  # N * T * O * ...
        if o.r == 1:
            # Recurrent mode: carry tracker state over to the next call.
            self.save_states(h_o_prev=h_o_seq, y_e_prev=y_e_seq)

        # Render the image using tracker outputs.
        ka = {}
        if o.bg == 1:
            # NOTE(review): when o.bg == 1, 'X_bg_seq' must be in kwargs or
            # Y_b_seq is unbound here -- confirm callers guarantee this.
            ka['Y_b'] = Y_b_seq
        X_r_seq, area = self.renderer(y_e_seq, y_l_seq, y_p_seq, Y_s_seq,
                                      Y_a_seq, **ka)  # N * T * D * H * W

        # Calculate the loss; the kwargs passed depend on the background and
        # metric configuration.
        ka = {'y_e': y_e_seq}
        if o.bg == 0:
            ka['Y_a'] = Y_a_seq
        else:
            ka['Y_b'] = Y_b_seq
            if o.metric == 0:
                ka['y_p'] = y_p_seq
        loss = self.loss_calculator(X_r_seq, X_seq, area, **ka)
        loss = loss.sum() / (o.N * o.T)

        # Visualize intermediate and final outputs when verbosity is enabled.
        if o.v > 0:
            ka = {
                'X': X_seq,
                'X_r': X_r_seq,
                'y_e': y_e_seq,
                'y_l': y_l_seq,
                'y_p': y_p_seq,
                'Y_s': Y_s_seq,
                'Y_a': Y_a_seq
            }
            if o.bg == 1:
                ka['Y_b'] = Y_b_seq
                if o.metric == 1:
                    ka['X_org'] = kwargs['X_org_seq']
            self.visualize(**ka)

        return loss
Code example #2
0
    def forward(self, X_seq, **kwargs):
        """Run one forward pass and, when visualizing, also build a video dict.

        Args:
            X_seq: input frame sequence; assumed N * T * D * H * W per the
                inline shape comments -- TODO confirm.
            **kwargs: optional 'X_bg_seq' (background sequence; required when
                o.bg == 1) and 'X_org_seq' (original frames, used for
                visualization when o.metric == 1).

        Returns:
            Scalar loss when o.v <= 0; otherwise a (loss, video) pair where
            video is a dict with 'class', 'filename' and 'frames' keys.
        """
        o = self.o
        if 'X_bg_seq' in kwargs:  # idiomatic membership test (no .keys())
            Y_b_seq = kwargs['X_bg_seq']

        # Extract features: append per-pixel coordinate channels along dim 2.
        X_seq_cat = torch.cat((X_seq, self.coor.clone()),
                              2)  # N * T * D+2 * H * W
        C_o_seq = self.feature_extractor(X_seq_cat)  # N * T * M * R
        C_o_seq = smd.CheckBP('C_o_seq')(C_o_seq)

        # Update trackers from the previous hidden / confidence states.
        h_o_prev, y_e_prev = self.load_states('h_o_prev', 'y_e_prev')
        h_o_seq, y_e_seq, y_l_seq, y_p_seq, Y_s_seq, Y_a_seq = self.tracker_array(
            h_o_prev, y_e_prev, C_o_seq)  # N * T * O * ...
        if o.r == 1:
            # Recurrent mode: carry tracker state over to the next call.
            self.save_states(h_o_prev=h_o_seq, y_e_prev=y_e_seq)

        # Render the image using tracker outputs (third return value unused).
        ka = {}
        if o.bg == 1:
            ka['Y_b'] = Y_b_seq
        X_r_seq, area, _ = self.renderer(y_e_seq, y_l_seq, y_p_seq, Y_s_seq,
                                         Y_a_seq, **ka)  # N * T * D * H * W

        # Calculate the loss; the kwargs passed depend on the background and
        # metric configuration.
        ka = {'y_e': y_e_seq}
        if o.bg == 0:
            ka['Y_a'] = Y_a_seq
        else:
            ka['Y_b'] = Y_b_seq
            if o.metric == 0:
                ka['y_p'] = y_p_seq
        loss = self.loss_calculator(X_r_seq, X_seq, area, **ka)
        loss = loss.sum() / (o.N * o.T)

        # Visualize: build a frame sequence and return it alongside the loss.
        if o.v > 0:
            # NOTE(review): this branch REPLACES the training loss above with
            # a half-resolution summed MSE between prediction and target --
            # presumably an evaluation metric for the visualization path;
            # confirm the overwrite is intentional.
            downsampled_pred = nn.functional.interpolate(X_r_seq.view(
                -1, o.D, o.H, o.W),
                                                         scale_factor=0.5)
            downsampled_target = nn.functional.interpolate(X_seq.view(
                -1, o.D, o.H, o.W),
                                                           scale_factor=0.5)

            loss = nn.functional.mse_loss(downsampled_pred,
                                          downsampled_target,
                                          reduction='sum')

            video = {}
            video['class'] = 'video'
            video['filename'] = 'video_id_{}'.format(o.batch_id)

            ka = {
                'X': X_seq,
                'X_r': X_r_seq,
                'y_e': y_e_seq,
                'y_l': y_l_seq,
                'y_p': y_p_seq,
                'Y_s': Y_s_seq,
                'Y_a': Y_a_seq
            }
            if o.bg == 1:
                ka['Y_b'] = Y_b_seq
                if o.metric == 1:
                    ka['X_org'] = kwargs['X_org_seq']
            frames = self.visualize(**ka)
            video['frames'] = frames

            return loss, video

        return loss
Code example #3
0
    def forward(self, X_seq, **kwargs):
        """Forward pass with optional base-image context and obs/pred phases.

        Args:
            X_seq: input frame sequence; assumed N * T * D * H * W per the
                inline shape comments -- TODO confirm.
            **kwargs: optional 'X_bg_seq' (background; required when
                o.bg == 1), 'X_base_img' (a 4-sequence of
                (data, path, actions, phase)), and 'X_org_seq' (original
                frames for visualization when o.metric == 1).

        Returns:
            Scalar loss averaged over o.N * o.T, or None in the 'obs' phase
            (the loss is accumulated into self.obs_loss instead).
        """
        o = self.o
        if 'X_bg_seq' in kwargs:  # idiomatic membership test (no .keys())
            Y_b_seq = kwargs['X_bg_seq']

        X_base_img = None
        data, path, actions, phase = None, None, None, None
        coords_info = local_coords(o, o.H, o.W)
        if 'X_base_img' in kwargs:
            X_base_img = kwargs['X_base_img']
            data, path, actions, phase = [X_base_img[i] for i in range(4)]
            _, H, W = data.shape
            # A negative sum over the first action column marks a new track:
            # reset all recurrent state and the accumulated observation loss.
            if (sum(actions[:, 0]) < 0):
                print('%%%%%%%%%%' * 5)
                self.reset_states()
                self.memory = None
                self.obs_loss = 0
                o.new_track = True

        # Extract features: append per-pixel coordinate channels along dim 2.
        # The deprecated Variable wrapper was dropped -- it is a no-op since
        # PyTorch 0.4 (Tensor/Variable merge) and the sibling implementation
        # in this file already omits it.
        X_seq_cat = torch.cat((X_seq, coords_info.clone()),
                              2)  # N * T * D+2 * H * W
        C_o_seq = self.feature_extractor(X_seq_cat)  # N * T * M * R
        C_o_seq = smd.CheckBP('C_o_seq')(C_o_seq)

        # Update trackers from the previous hidden / confidence states.
        h_o_prev, y_e_prev = self.load_states('h_o_prev', 'y_e_prev')
        if (phase == 'pred'):
            print('@@@@@$$$$$$$$' * 5)
            # Prediction phase: roll forward from the hidden state memorized
            # at the end of the observation phase.
            if (self.memory is None):
                self.memory = h_o_prev
            h_o_seq, y_e_seq, y_l_seq, y_p_seq, Y_s_seq, Y_a_seq = self.tracker_array(
                self.memory, y_e_prev, C_o_seq, path, phase)
        else:
            self.memory = None
            h_o_seq, y_e_seq, y_l_seq, y_p_seq, Y_s_seq, Y_a_seq = self.tracker_array(
                h_o_prev, y_e_prev, C_o_seq, path, phase)  # N * T * O * ...
        if o.r == 1:
            # Recurrent mode: carry tracker state over to the next call.
            self.save_states(h_o_prev=h_o_seq, y_e_prev=y_e_seq)

        # Render the image using tracker outputs.
        ka = {}
        if o.bg == 1:
            # NOTE(review): when o.bg == 1, 'X_bg_seq' must be in kwargs or
            # Y_b_seq is unbound here -- confirm callers guarantee this.
            ka['Y_b'] = Y_b_seq
        X_r_seq, area = self.renderer(y_e_seq, y_l_seq, y_p_seq, Y_s_seq,
                                      Y_a_seq, **ka)  # N * T * D * H * W
        if (o.train == 0):
            # At test time the renderer's area lacks the batch dimension;
            # restore it for the loss calculator -- TODO confirm.
            area = area.unsqueeze(0)

        # Calculate the loss; the kwargs passed depend on the background and
        # metric configuration.
        ka = {'y_e': y_e_seq}
        if o.bg == 0:
            ka['Y_a'] = Y_a_seq
        else:
            ka['Y_b'] = Y_b_seq
            if o.metric == 0:
                ka['y_p'] = y_p_seq
        loss = self.loss_calculator(X_r_seq, X_seq, area, **ka)
        loss = loss.sum() / (o.N * o.T)

        if (phase == 'obs'):
            # Observation phase: optionally dump the input frames and base
            # data to disk, accumulate the loss, and return None.
            n = 0
            if o.train == 0:
                save_dir = os.path.join(o.pic_dir, str(n))
                for t in range(0, o.T):
                    img = X_seq.data[n, t].permute(1, 2, 0).clamp(0, 1)
                    tao = o.batch_id * o.T + t
                    utils.mkdir(os.path.join(save_dir, 'input'))
                    utils.imwrite(
                        img, os.path.join(save_dir, 'input', "%05d" % (tao)))
                utils.mkdir(os.path.join(save_dir, 'base'))
                data, path, actions, phase = X_base_img
                torch.save((data[n], path[n], actions[n]),
                           os.path.join(save_dir, 'base',
                                        str(o.batch_id) + '.pt'))
            self.obs_loss += loss
            print('Loss: {}'.format(loss))
            return None

        # Visualize intermediate and final outputs when verbosity is enabled.
        if o.v > 0:
            ka = {
                'X': X_seq,
                'X_r': X_r_seq,
                'y_e': y_e_seq,
                'y_l': y_l_seq,
                'y_p': y_p_seq,
                'Y_s': Y_s_seq,
                'Y_a': Y_a_seq
            }
            # `is not None` instead of truthiness: X_base_img is None or a
            # 4-sequence here, and an explicit None check will not break if it
            # ever becomes a tensor-like container.
            if X_base_img is not None:
                ka['X_base_img'] = X_base_img
            if o.bg == 1:
                ka['Y_b'] = Y_b_seq
                if o.metric == 1:
                    ka['X_org'] = kwargs['X_org_seq']
            self.visualize(**ka)

        print('Obs_loss: {}'.format(self.obs_loss))
        return loss