Пример #1
0
    def __process_affine(self, image, target, theta, nopoints, aux_info=None):
        """Warp one image with ``theta`` and build the matching landmark labels.

        The method operates on copies (``clone``/``copy``) of ``image``,
        ``target`` and ``theta``.

        Args:
            image: tensor whose ``size()`` unpacks into three values (C, H, W).
            target: annotation object providing ``copy()`` and ``get_points()``.
            theta: affine parameters; forwarded to ``apply_affine2point`` and
                ``affine2image``.
            nopoints: truthy when the sample carries no landmark annotation.
            aux_info: extra context, used only inside the warning message.

        Returns:
            A 6-tuple ``(affineImage, heatmaps, mask, norm_trans_points,
            theta, transpose_theta)``.
        """
        image = image.clone()
        target = target.copy()
        theta = theta.clone()

        # Unpacking enforces a 3-D image; only H and W are used afterwards.
        _, H, W = image.size()
        height, width = self.shape
        map_h = height // self.downsample
        map_w = width // self.downsample

        if nopoints:
            # No annotation: all-zero points/heatmaps, all-one mask, identity theta.
            norm_trans_points = torch.zeros((3, self.NUM_PTS))
            heatmaps = torch.zeros((self.NUM_PTS + 1, map_h, map_w))
            mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
            transpose_theta = identity2affine(False)
        else:
            norm_trans_points = apply_boundary(
                apply_affine2point(target.get_points(), theta, (H, W)))

            # Rows 0-1 are converted from normalized to real coordinates on a
            # copy, so ``norm_trans_points`` itself stays normalized.
            real_trans_points = norm_trans_points.clone()
            real_trans_points[:2, :] = denormalize_points(
                self.shape, real_trans_points[:2, :])

            # generate_label_map returns arrays laid out as H*W*C.
            label_maps, label_mask = generate_label_map(
                real_trans_points.numpy(), map_h, map_w, self.sigma,
                self.downsample, nopoints, self.heatmap_type)
            heatmaps = torch.from_numpy(
                label_maps.transpose((2, 0, 1))).type(torch.FloatTensor)
            mask = torch.from_numpy(
                label_mask.transpose((2, 0, 1))).type(torch.ByteTensor)

            if self.mean_face is None:
                # Without a mean face the identity is used and no warning is
                # emitted.
                transpose_theta = identity2affine(False)
            elif torch.sum(norm_trans_points[2, :] == 1) < 3:
                # Fewer than three entries of row 2 equal 1: fall back to
                # the identity instead of solving for a transform.
                warnings.warn(
                    'In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}'.format(
                        aux_info))
                transpose_theta = identity2affine(False)
            else:
                transpose_theta = solve2theta(norm_trans_points,
                                              self.mean_face.clone())

        affineImage = affine2image(image, theta, self.shape)
        if self.cutout is not None:
            affineImage = self.cutout(affineImage)

        return (affineImage, heatmaps, mask, norm_trans_points, theta,
                transpose_theta)
def calculate_temporal_loss(criterion, heatmaps, locs, past2now, future2now,
                            FBcheck, mask, config):
    """Compute the temporal consistency loss around frame ``config.video_L``.

    Args:
        criterion: object exposing ``loss_l1_func`` / ``loss_mse_func`` (both
            called with ``reduction='none'``) and callable as
            ``criterion(prediction, target, mask)`` for the heatmap loss.
        heatmaps: sequence of heatmap tensors, used only when
            ``config.sbr_loss_type == 'HEAT'``; each is indexed as
            ``x[:, frame]`` and viewed as (batch * num_pts, 1, H, W).
        locs: 4-D tensor unpacked as (batch, frames, num_pts, _).
        past2now: tensor with sizes (batch, frames - 1, num_pts, ...).
        future2now: tensor with sizes (batch, frames - 1, num_pts, ...).
        FBcheck: tensor with sizes (batch, num_pts, ...); entries 0 and 1 of
            the last axis are combined into a Euclidean norm that is compared
            with ``config.fb_thresh``.
        mask: 4-D tensor that can be viewed as (batch, num_pts).
        config: provides ``sbr_loss_type``, ``video_L``, ``fb_thresh`` and
            ``dis_thresh``.

    Returns:
        ``(final_loss, nums, loss_string)`` where ``nums`` is the number of
        valid (sample, point) pairs. ``final_loss`` is the plain integer ``0``
        when ``nums`` is zero.

    Raises:
        ValueError: if ``config.sbr_loss_type`` is not 'L1', 'MSE' or 'HEAT'.
    """
    batch, frames, num_pts, _ = locs.size()
    sizes_msg = '{:} vs {:} vs {:} vs {:}'.format(locs.size(), past2now.size(),
                                                  future2now.size(),
                                                  FBcheck.size())
    assert batch == past2now.size(0) == future2now.size(0) == FBcheck.size(
        0), sizes_msg
    assert num_pts == past2now.size(2) == future2now.size(2) == FBcheck.size(
        1), sizes_msg
    assert frames - 1 == past2now.size(1) == future2now.size(1), sizes_msg
    assert mask.dim() == 4 and mask.size(0) == batch and mask.size(
        1) == num_pts, 'mask : {:}'.format(mask.size())

    locs = locs.contiguous()
    past2now = past2now.contiguous()
    future2now = future2now.contiguous()
    FBcheck = FBcheck.contiguous()
    mask = mask.view(batch, num_pts).contiguous()
    center = config.video_L

    # Validity mask per (sample, point); built without tracking gradients.
    with torch.no_grad():
        past2now_l1_dis = criterion.loss_l1_func(locs[:, 1:],
                                                 past2now,
                                                 reduction='none')
        futu2now_l1_dis = criterion.loss_l1_func(locs[:, :-1],
                                                 future2now,
                                                 reduction='none')

        # A point must pass get_in_map in every frame.
        inmap_ok = get_in_map(locs).sum(1) == frames
        check_ok = torch.sqrt(FBcheck[:, :, 0]**2 +
                              FBcheck[:, :, 1]**2) < config.fb_thresh
        # The averaged L1 distance must stay under the threshold for every
        # consecutive frame pair.
        distc_ok = (past2now_l1_dis.sum(-1) +
                    futu2now_l1_dis.sum(-1)) / 4 < config.dis_thresh
        distc_ok = distc_ok.sum(1) == frames - 1
        data_ok = (inmap_ok.view(batch, 1, num_pts, 1).int() +
                   check_ok.view(batch, 1, num_pts, 1).int() +
                   distc_ok.view(batch, 1, num_pts, 1).int() +
                   mask.view(batch, 1, num_pts, 1).int()) == 4

    loss_type = config.sbr_loss_type
    if loss_type in ('L1', 'MSE'):
        loss_func = (criterion.loss_l1_func
                     if loss_type == 'L1' else criterion.loss_mse_func)
        past_loss = loss_func(locs[:, center],
                              past2now[:, center - 1],
                              reduction='none')
        future_loss = loss_func(locs[:, center],
                                future2now[:, center],
                                reduction='none')
        temporal_loss = past_loss + future_loss
        # The single-frame loss has no frame axis (batch, num_pts, coords --
        # assuming the loss function keeps the input shape with
        # reduction='none'). The mask must be aligned to that layout:
        # broadcasting the 4-D (batch, 1, num_pts, 1) mask against it would
        # pair every sample's mask with every other sample's loss.
        point_ok = data_ok.view(batch, num_pts, 1)
        final_loss = torch.mean(torch.masked_select(temporal_loss, point_ok))
        loss_string = ''
    elif loss_type == 'HEAT':
        H, W = heatmaps[0].size(-2), heatmaps[0].size(-1)
        # Build the sampling grid on the same device as the inputs instead of
        # assuming CUDA.
        identity_grid = F.affine_grid(
            identity2affine().to(locs.device).view(1, 2, 3),
            torch.Size([1, 1, H, W]),
            align_corners=True)
        identity_grid = identity_grid.view(1, H, W, 2)

        # PAST: shift the previous frame's heatmaps by the tracked offset.
        past_shift = (past2now[:, center - 1] - locs[:, center - 1]).view(
            batch * num_pts, 1, 1, 2)
        past2now_grid = identity_grid - past_shift
        past2now_predicts = [
            F.grid_sample(x[:, center - 1].contiguous().view(
                batch * num_pts, 1, H, W),
                          past2now_grid,
                          mode='bilinear',
                          align_corners=True).view(batch, num_pts, H, W)
            for x in heatmaps
        ]

        # FUTURE: shift the next frame's heatmaps by the tracked offset.
        futu_shift = (future2now[:, center] - locs[:, center + 1]).view(
            batch * num_pts, 1, 1, 2)
        futu2now_grid = identity_grid - futu_shift
        futu2now_predicts = [
            F.grid_sample(x[:, center + 1].contiguous().view(
                batch * num_pts, 1, H, W),
                          futu2now_grid,
                          mode='bilinear',
                          align_corners=True).view(batch, num_pts, H, W)
            for x in heatmaps
        ]

        heatmaps_targets = [x[:, center].contiguous() for x in heatmaps]

        data_ok = data_ok.view(batch, num_pts, 1, 1)
        loss_list, stage_strings = [], []
        for index, (past_pred, futu_pred, hm_target) in enumerate(
                zip(past2now_predicts, futu2now_predicts, heatmaps_targets)):
            past_loss = criterion(past_pred, hm_target, data_ok)
            futu_loss = criterion(futu_pred, hm_target, data_ok)
            stage_strings.append('S{:}[P={:.6f}, F={:.6f}]'.format(
                index, past_loss.item(), futu_loss.item()))
            loss_list.append(past_loss)
            loss_list.append(futu_loss)
        loss_string = ' '.join(stage_strings)
        final_loss = sum(loss_list)
    else:
        raise ValueError('invalid SBR loss type : {:}'.format(
            config.sbr_loss_type))

    nums = torch.sum(data_ok).item()
    if nums == 0:
        # Nothing valid to supervise: return a plain 0 instead of the loss.
        return 0, nums, loss_string
    return final_loss, nums, loss_string
Пример #3
0
def calculate_multiview_loss(criterion, mv_heatmaps, mv_locs, proj_locs, masks,
                             config):
    """Compute the multi-view consistency loss between two sets of locations.

    Args:
        criterion: object exposing ``loss_l1_func`` / ``loss_mse_func`` (both
            called with ``reduction='none'``) and callable as
            ``criterion(prediction, target, mask)`` for the heatmap loss.
        mv_heatmaps: sequence of heatmap tensors, used only when
            ``config.sbt_loss_type == 'HEAT'``; each is viewed as
            (batch * cameras * num_pts, 1, H, W).
        mv_locs: tensor of size (batch, cameras, num_pts, 2).
        proj_locs: tensor with the same size as ``mv_locs``.
        masks: tensor that can be viewed as (batch, 1, num_pts, 1).
        config: provides ``sbt_loss_type`` and ``stm_dis_thresh``.

    Returns:
        ``(final_loss, nums)`` where ``nums`` is the number of valid entries
        divided by the number of cameras. ``final_loss`` is the plain integer
        ``0`` when ``nums`` is zero.

    Raises:
        ValueError: if ``config.sbt_loss_type`` is not 'L1', 'MSE' or 'HEAT'.
    """
    assert mv_locs.dim() == 4 and mv_locs.size(
        -1) == 2, 'invalid mv-locs size : {:}'.format(mv_locs.shape)
    assert mv_locs.size() == proj_locs.size(), '{:} vs {:}'.format(
        mv_locs.shape, proj_locs.shape)
    batch, cameras, num_pts, _ = mv_locs.size()

    # Validity mask of size (batch, cameras, num_pts, 1); no gradients needed.
    with torch.no_grad():
        # Both location sets must pass get_in_map in every camera.
        inmap_ok1 = get_in_map(mv_locs).sum(1) == cameras
        inmap_ok2 = get_in_map(proj_locs).sum(1) == cameras
        inmap_ok = (inmap_ok1.int() + inmap_ok2.int()) == 2
        stm_dis = criterion.loss_l1_func(mv_locs, proj_locs, reduction='none')
        distc_ok = stm_dis.sum(-1) / 2 < config.stm_dis_thresh
        data_ok = (inmap_ok.view(batch, 1, num_pts, 1).int() +
                   distc_ok.view(batch, -1, num_pts, 1).int() +
                   masks.view(batch, 1, num_pts, 1).int()) == 3

    loss_type = config.sbt_loss_type
    if loss_type in ('L1', 'MSE'):
        loss_func = (criterion.loss_l1_func
                     if loss_type == 'L1' else criterion.loss_mse_func)
        # Recomputed outside no_grad so the loss carries gradients.
        stm_losses = loss_func(mv_locs, proj_locs, reduction='none')
        final_loss = torch.mean(torch.masked_select(stm_losses, data_ok))
    elif loss_type == 'HEAT':
        H, W = mv_heatmaps[0].size(-2), mv_heatmaps[0].size(-1)
        # Build the sampling grid on the same device as the inputs instead of
        # assuming CUDA.
        identity_grid = F.affine_grid(
            identity2affine().to(mv_locs.device).view(1, 2, 3),
            torch.Size([1, 1, H, W]),
            align_corners=True)
        offsets = (proj_locs - mv_locs).view(batch * cameras * num_pts, 1, 1,
                                             2)
        multiview_grid = identity_grid + offsets
        multiview_predicts = [
            F.grid_sample(x.contiguous().view(batch * cameras * num_pts, 1, H,
                                              W),
                          multiview_grid,
                          mode='bilinear',
                          padding_mode='border',
                          align_corners=True).view(batch, cameras, num_pts, H,
                                                   W) for x in mv_heatmaps
        ]

        data_ok = data_ok.view(batch, cameras, num_pts, 1, 1)
        loss_list = [
            criterion(predict, heatmap, data_ok)
            for predict, heatmap in zip(multiview_predicts, mv_heatmaps)
        ]
        final_loss = sum(loss_list)
    else:
        raise ValueError('invalid SBT loss type : {:}'.format(
            config.sbt_loss_type))

    nums = torch.sum(data_ok).item() * 1.0 / cameras
    if nums == 0:
        # Nothing valid to supervise: return a plain 0 instead of the loss.
        return 0, nums
    return final_loss, nums
Пример #4
0
    def __process_affine(self,
                         frames,
                         target,
                         theta,
                         nopoints,
                         skip_opt,
                         aux_info=None):
        """Warp a frame sequence with ``theta`` and build labels and flows.

        The method operates on copies (``clone``/``copy``) of the frames,
        ``target`` and ``theta``.

        Args:
            frames: non-empty sequence of tensors; the first one's ``size()``
                unpacks into three values (C, H, W).
            target: annotation object providing ``copy()`` and ``get_points()``.
            theta: affine parameters; forwarded to ``apply_affine2point`` and
                ``affine2image``.
            nopoints: truthy when the sample carries no landmark annotation.
            skip_opt: when truthy, optical flow is not computed and all-zero
                flow tensors are returned instead.
            aux_info: extra context, used only inside the warning message.

        Returns:
            An 8-tuple ``(stacked_frames, forward_flow, backward_flow,
            heatmaps, mask, norm_trans_points, theta, transpose_theta)``.
            Per the original layout notes, the stacked frames are
            #frames x #channel x #height x #width and each flow tensor is
            (#frames-1) x #height x #width x 2.
        """
        frames = [frame.clone() for frame in frames]
        target = target.copy()
        theta = theta.clone()

        # Unpacking enforces 3-D frames; only H and W are used afterwards.
        _, H, W = frames[0].size()
        height, width = self.shape
        map_h = height // self.downsample
        map_w = width // self.downsample

        if nopoints:
            # No annotation: all-zero points/heatmaps, all-one mask, identity theta.
            norm_trans_points = torch.zeros((3, self.NUM_PTS))
            heatmaps = torch.zeros((self.NUM_PTS + 1, map_h, map_w))
            mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
            transpose_theta = identity2affine(False)
        else:
            norm_trans_points = apply_affine2point(target.get_points(), theta,
                                                   (H, W))
            norm_trans_points = apply_boundary(norm_trans_points)
            # Rows 0-1 are converted from normalized to real coordinates on a
            # copy, so ``norm_trans_points`` itself stays normalized.
            real_trans_points = norm_trans_points.clone()
            real_trans_points[:2, :] = denormalize_points(
                self.shape, real_trans_points[:2, :])
            # generate_label_map returns arrays laid out as H*W*C.
            heatmaps, mask = generate_label_map(real_trans_points.numpy(),
                                                map_h, map_w, self.sigma,
                                                self.downsample, nopoints,
                                                self.heatmap_type)
            heatmaps = torch.from_numpy(heatmaps.transpose(
                (2, 0, 1))).type(torch.FloatTensor)
            mask = torch.from_numpy(mask.transpose(
                (2, 0, 1))).type(torch.ByteTensor)
            # NOTE: the warning below is also emitted when mean_face is None,
            # not only when fewer than three entries of row 2 equal 1.
            if torch.sum(norm_trans_points[2, :] ==
                         1) < 3 or self.mean_face is None:
                warnings.warn(
                    'In GeneralDatasetV2 after transformation, no visiable point, using identity instead. Aux: {:}'
                    .format(aux_info))
                transpose_theta = identity2affine(False)
            else:
                transpose_theta = solve2theta(norm_trans_points,
                                              self.mean_face.clone())

        affineFrames = [
            affine2image(frame, theta, self.shape) for frame in frames
        ]
        num_pairs = len(affineFrames) - 1

        if not skip_opt and num_pairs > 0:
            Gframes = [self.tensor2img(frame) for frame in affineFrames]
            forward_flow, backward_flow = [], []
            # Call order is kept as in the original loop: for each index the
            # forward flow (previous -> current) is computed before the
            # backward flow (next -> current).
            for idx in range(len(Gframes)):
                if idx > 0:
                    forward_flow.append(
                        self.optflow.calc(Gframes[idx - 1], Gframes[idx],
                                          None))
                if idx + 1 < len(Gframes):
                    backward_flow.append(
                        self.optflow.calc(Gframes[idx + 1], Gframes[idx],
                                          None))
            forward_flow = torch.stack(
                [torch.from_numpy(x) for x in forward_flow])
            backward_flow = torch.stack(
                [torch.from_numpy(x) for x in backward_flow])
        else:
            # Flow is skipped, or there is a single frame and therefore no
            # frame pair. torch.stack rejects an empty list, so the
            # single-frame case yields empty (0, height, width, 2) tensors.
            forward_flow = torch.zeros((num_pairs, height, width, 2))
            backward_flow = torch.zeros((num_pairs, height, width, 2))

        return (torch.stack(affineFrames), forward_flow, backward_flow,
                heatmaps, mask, norm_trans_points, theta, transpose_theta)