def __process_affine(self, image, target, theta, nopoints, aux_info=None):
    """Warp a single image with ``theta`` and build its landmark labels.

    Args:
        image: image tensor; ``image.size()`` must unpack into three values.
        target: annotation object offering ``copy()`` and ``get_points()``.
        theta: affine parameters forwarded to ``apply_affine2point`` and
            ``affine2image``.
        nopoints: true when the sample carries no landmark annotation.
        aux_info: optional context, only embedded in the warning message.

    Returns:
        Tuple ``(affine_image, heatmaps, mask, norm_trans_points, theta,
        transpose_theta)``.
    """
    # Clone / copy the inputs before using them.
    image, target, theta = image.clone(), target.copy(), theta.clone()
    _, src_h, src_w = image.size()
    out_h, out_w = self.shape
    # Spatial size of the label maps.
    label_h, label_w = out_h // self.downsample, out_w // self.downsample

    if nopoints:
        # No label available: zero points / heatmaps, all-ones mask.
        norm_trans_points = torch.zeros((3, self.NUM_PTS))
        heatmaps = torch.zeros((self.NUM_PTS + 1, label_h, label_w))
        mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
        transpose_theta = identity2affine(False)
    else:
        moved_points = apply_affine2point(target.get_points(), theta,
                                          (src_h, src_w))
        norm_trans_points = apply_boundary(moved_points)

        # Rows 0-1 are de-normalized with respect to self.shape.
        real_trans_points = norm_trans_points.clone()
        real_trans_points[:2, :] = denormalize_points(
            self.shape, real_trans_points[:2, :])

        label_maps, label_mask = generate_label_map(
            real_trans_points.numpy(), label_h, label_w, self.sigma,
            self.downsample, nopoints, self.heatmap_type)
        # H*W*C -> C*H*W
        heatmaps = torch.from_numpy(
            label_maps.transpose((2, 0, 1))).type(torch.FloatTensor)
        mask = torch.from_numpy(
            label_mask.transpose((2, 0, 1))).type(torch.ByteTensor)

        if self.mean_face is None:
            # No mean face: silently use the identity transform.
            transpose_theta = identity2affine(False)
        elif torch.sum(norm_trans_points[2, :] == 1) < 3:
            warnings.warn(
                'In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}'
                .format(aux_info))
            transpose_theta = identity2affine(False)
        else:
            transpose_theta = solve2theta(norm_trans_points,
                                          self.mean_face.clone())

    affineImage = affine2image(image, theta, self.shape)
    if self.cutout is not None:
        affineImage = self.cutout(affineImage)
    return (affineImage, heatmaps, mask, norm_trans_points, theta,
            transpose_theta)
def calculate_temporal_loss(criterion, heatmaps, locs, past2now, future2now,
                            FBcheck, mask, config):
    """Compute the temporal-consistency (SBR) loss around frame ``config.video_L``.

    Shapes below are the ones enforced by the assertions / indexing in this
    function:

    Args:
        criterion: object exposing ``loss_l1_func`` / ``loss_mse_func``
            (called with ``reduction='none'``) and callable as
            ``criterion(prediction, target, mask)`` for the ``'HEAT'`` type.
        heatmaps: sequence of heatmap tensors (one per stage); only read for
            the ``'HEAT'`` loss type, where each is indexed as ``x[:, frame]``
            and viewed as ``(batch * num_pts, 1, H, W)``.
        locs: 4-D tensor ``(batch, frames, num_pts, _)`` of predicted points.
        past2now: tracked points with ``size(1) == frames - 1`` and
            ``size(2) == num_pts``.
        future2now: same layout as ``past2now``.
        FBcheck: tensor with ``size(0) == batch``, ``size(1) == num_pts`` whose
            last axis holds at least two components.
        mask: 4-D tensor viewable as ``(batch, num_pts)``.
        config: provides ``fb_thresh``, ``dis_thresh``, ``sbr_loss_type`` and
            ``video_L``.

    Returns:
        ``(final_loss, nums, loss_string)``. ``nums`` is the number of points
        that passed every reliability check; when it is zero ``final_loss`` is
        the plain integer ``0``.

    Raises:
        ValueError: if ``config.sbr_loss_type`` is not ``'L1'``, ``'MSE'`` or
            ``'HEAT'``.
    """
    batch, frames, num_pts, _ = locs.size()

    def _sizes():
        # Only evaluated when an assertion fails.
        return '{:} vs {:} vs {:} vs {:}'.format(
            locs.size(), past2now.size(), future2now.size(), FBcheck.size())

    assert batch == past2now.size(0) == future2now.size(0) == FBcheck.size(0), _sizes()
    assert num_pts == past2now.size(2) == future2now.size(2) == FBcheck.size(1), _sizes()
    assert frames - 1 == past2now.size(1) == future2now.size(1), _sizes()
    assert mask.dim() == 4 and mask.size(0) == batch and mask.size(1) == num_pts, \
        'mask : {:}'.format(mask.size())

    locs, past2now, future2now = (locs.contiguous(), past2now.contiguous(),
                                  future2now.contiguous())
    FBcheck, mask = FBcheck.contiguous(), mask.view(batch, num_pts).contiguous()

    # Reliability mask: a point is kept only if all four checks hold.
    with torch.no_grad():
        past2now_l1_dis = criterion.loss_l1_func(locs[:, 1:], past2now,
                                                 reduction='none')
        futu2now_l1_dis = criterion.loss_l1_func(locs[:, :-1], future2now,
                                                 reduction='none')
        # inside the map in every frame
        inmap_ok = get_in_map(locs).sum(1) == frames
        # magnitude of the first two FBcheck components under the threshold
        check_ok = torch.sqrt(FBcheck[:, :, 0]**2 + FBcheck[:, :, 1]**2) < config.fb_thresh
        # tracked vs. detected distance under the threshold for every frame pair
        distc_ok = (past2now_l1_dis.sum(-1) + futu2now_l1_dis.sum(-1)) / 4 < config.dis_thresh
        distc_ok = distc_ok.sum(1) == frames - 1
        data_ok = (inmap_ok.view(batch, 1, num_pts, 1).int() +
                   check_ok.view(batch, 1, num_pts, 1).int() +
                   distc_ok.view(batch, 1, num_pts, 1).int() +
                   mask.view(batch, 1, num_pts, 1).int()) == 4

    loss_type = config.sbr_loss_type
    if loss_type in ('L1', 'MSE'):
        point_loss = (criterion.loss_l1_func if loss_type == 'L1'
                      else criterion.loss_mse_func)
        current = locs[:, config.video_L]
        past_loss = point_loss(current, past2now[:, config.video_L - 1],
                               reduction='none')
        future_loss = point_loss(current, future2now[:, config.video_L],
                                 reduction='none')
        temporal_loss = past_loss + future_loss
        # The per-point loss has no frame axis any more, so the mask must be
        # (batch, num_pts, 1, ...) as well.  Using the 4-D (batch, 1, num_pts, 1)
        # mask here would broadcast to (batch, batch, ...) and select each
        # sample's loss with every other sample's mask.
        keep = data_ok.view(batch, num_pts,
                            *([1] * (temporal_loss.dim() - 2)))
        final_loss = torch.mean(torch.masked_select(temporal_loss, keep))
        loss_string = ''
    elif loss_type == 'HEAT':
        H, W = heatmaps[0].size(-2), heatmaps[0].size(-1)
        # Build the sampling grid on the same device as the points instead of
        # unconditionally on the current CUDA device.
        identity_grid = F.affine_grid(
            identity2affine().to(locs.device).view(1, 2, 3),
            torch.Size([1, 1, H, W]), align_corners=True)
        identity_grid = identity_grid.view(1, H, W, 2)
        prev_idx, now_idx, next_idx = (config.video_L - 1, config.video_L,
                                       config.video_L + 1)

        def _warp(frame_idx, grid):
            # Resample every stage's heatmap of `frame_idx` with `grid`.
            return [
                F.grid_sample(
                    x[:, frame_idx].contiguous().view(batch * num_pts, 1, H, W),
                    grid, mode='bilinear',
                    align_corners=True).view(batch, num_pts, H, W)
                for x in heatmaps
            ]

        # PAST: shift the previous frame's heatmaps by the tracked offset.
        past2now_grid = identity_grid - (
            past2now[:, prev_idx] - locs[:, prev_idx]).view(
                batch * num_pts, 1, 1, 2)
        past2now_predicts = _warp(prev_idx, past2now_grid)
        # FUTURE: shift the next frame's heatmaps by the tracked offset.
        futu2now_grid = identity_grid - (
            future2now[:, now_idx] - locs[:, next_idx]).view(
                batch * num_pts, 1, 1, 2)
        futu2now_predicts = _warp(next_idx, futu2now_grid)

        heatmaps_targets = [x[:, now_idx].contiguous() for x in heatmaps]
        data_ok = data_ok.view(batch, num_pts, 1, 1)
        loss_list, stage_strings = [], []
        for index, (past_pred, futu_pred, stage_target) in enumerate(
                zip(past2now_predicts, futu2now_predicts, heatmaps_targets)):
            past_loss = criterion(past_pred, stage_target, data_ok)
            futu_loss = criterion(futu_pred, stage_target, data_ok)
            stage_strings.append('S{:}[P={:.6f}, F={:.6f}]'.format(
                index, past_loss.item(), futu_loss.item()))
            loss_list.extend((past_loss, futu_loss))
        loss_string = ' '.join(stage_strings)
        final_loss = sum(loss_list)
    else:
        raise ValueError('invalid SBR loss type : {:}'.format(
            config.sbr_loss_type))

    nums = torch.sum(data_ok).item()
    if nums == 0:
        # Nothing reliable to supervise; the mean over an empty selection
        # would not be a usable loss value.
        return 0, nums, loss_string
    return final_loss, nums, loss_string
def calculate_multiview_loss(criterion, mv_heatmaps, mv_locs, proj_locs,
                             masks, config):
    """Compute the multi-view consistency (SBT) loss.

    Args:
        criterion: object exposing ``loss_l1_func`` / ``loss_mse_func``
            (called with ``reduction='none'``) and callable as
            ``criterion(prediction, target, mask)`` for the ``'HEAT'`` type.
        mv_heatmaps: sequence of heatmap tensors (one per stage); only read
            for the ``'HEAT'`` loss type, where each is viewed as
            ``(batch * cameras * num_pts, 1, H, W)``.
        mv_locs: 4-D tensor ``(batch, cameras, num_pts, 2)`` of detected
            points.
        proj_locs: tensor of the same size as ``mv_locs`` to compare against.
        masks: tensor viewable as ``(batch, 1, num_pts, 1)``.
        config: provides ``stm_dis_thresh`` and ``sbt_loss_type``.

    Returns:
        ``(final_loss, nums)`` where ``nums`` is the number of selected
        entries divided by the number of cameras. When ``nums`` is zero
        ``final_loss`` is the plain integer ``0``.

    Raises:
        ValueError: if ``config.sbt_loss_type`` is not ``'L1'``, ``'MSE'`` or
            ``'HEAT'``.
    """
    assert mv_locs.dim() == 4 and mv_locs.size(-1) == 2, \
        'invalid mv-locs size : {:}'.format(mv_locs.shape)
    assert mv_locs.size() == proj_locs.size(), '{:} vs {:}'.format(
        mv_locs.shape, proj_locs.shape)
    batch, cameras, num_pts, _ = mv_locs.size()

    # Reliability mask: a (camera, point) entry is kept only if all checks hold.
    with torch.no_grad():
        # both point sets inside the map for every camera
        inmap_ok1 = get_in_map(mv_locs).sum(1) == cameras
        inmap_ok2 = get_in_map(proj_locs).sum(1) == cameras
        inmap_ok = (inmap_ok1.int() + inmap_ok2.int()) == 2
        # per-camera distance between the two point sets under the threshold
        stm_dis = criterion.loss_l1_func(mv_locs, proj_locs, reduction='none')
        distc_ok = stm_dis.sum(-1) / 2 < config.stm_dis_thresh
        data_ok = (inmap_ok.view(batch, 1, num_pts, 1).int() +
                   distc_ok.view(batch, -1, num_pts, 1).int() +
                   masks.view(batch, 1, num_pts, 1).int()) == 3

    loss_type = config.sbt_loss_type
    if loss_type in ('L1', 'MSE'):
        point_loss = (criterion.loss_l1_func if loss_type == 'L1'
                      else criterion.loss_mse_func)
        stm_losses = point_loss(mv_locs, proj_locs, reduction='none')
        final_loss = torch.mean(torch.masked_select(stm_losses, data_ok))
    elif loss_type == 'HEAT':
        H, W = mv_heatmaps[0].size(-2), mv_heatmaps[0].size(-1)
        # Build the sampling grid on the same device as the points instead of
        # unconditionally on the current CUDA device.
        identity_grid = F.affine_grid(
            identity2affine().to(mv_locs.device).view(1, 2, 3),
            torch.Size([1, 1, H, W]), align_corners=True)
        multiview_grid = identity_grid + (proj_locs - mv_locs).view(
            batch * cameras * num_pts, 1, 1, 2)
        multiview_predicts = [
            F.grid_sample(
                x.contiguous().view(batch * cameras * num_pts, 1, H, W),
                multiview_grid, mode='bilinear', padding_mode='border',
                align_corners=True).view(batch, cameras, num_pts, H, W)
            for x in mv_heatmaps
        ]
        data_ok = data_ok.view(batch, cameras, num_pts, 1, 1)
        loss_list = [
            criterion(stage_predict, stage_heatmap, data_ok)
            for stage_predict, stage_heatmap in zip(multiview_predicts,
                                                    mv_heatmaps)
        ]
        final_loss = sum(loss_list)
    else:
        raise ValueError('invalid SBT loss type : {:}'.format(
            config.sbt_loss_type))

    nums = torch.sum(data_ok).item() * 1.0 / cameras
    if nums == 0:
        # Nothing reliable to supervise; the mean over an empty selection
        # would not be a usable loss value.
        return 0, nums
    return final_loss, nums
def __process_affine(self, frames, target, theta, nopoints, skip_opt,
                     aux_info=None):
    """Warp a clip with ``theta``, build landmark labels and optical flow.

    Args:
        frames: non-empty sequence of frame tensors; ``frames[0].size()`` must
            unpack into three values.
        target: annotation object offering ``copy()`` and ``get_points()``.
        theta: affine parameters forwarded to ``apply_affine2point`` and
            ``affine2image``.
        nopoints: true when the sample carries no landmark annotation.
        skip_opt: when true no optical flow is computed and zero tensors are
            returned for both flow directions.
        aux_info: optional context, only embedded in the warning message.

    Returns:
        Tuple ``(frames, forward_flow, backward_flow, heatmaps, mask,
        norm_trans_points, theta, transpose_theta)`` where

        * ``frames`` is the stack of warped frames,
          #frames x #channel x #height x #width;
        * ``forward_flow`` is (#frames-1) x #height x #width x 2, computed
          from frame ``i`` to frame ``i + 1``;
        * ``backward_flow`` is (#frames-1) x #height x #width x 2, computed
          from frame ``i + 1`` to frame ``i``.
    """
    # Clone / copy the inputs before using them.
    frames = [frame.clone() for frame in frames]
    target, theta = target.copy(), theta.clone()
    (C, H, W), (height, width) = frames[0].size(), self.shape
    label_h, label_w = height // self.downsample, width // self.downsample

    if nopoints:
        # No label available: zero points / heatmaps, all-ones mask.
        norm_trans_points = torch.zeros((3, self.NUM_PTS))
        heatmaps = torch.zeros((self.NUM_PTS + 1, label_h, label_w))
        mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
        transpose_theta = identity2affine(False)
    else:
        norm_trans_points = apply_affine2point(target.get_points(), theta,
                                               (H, W))
        norm_trans_points = apply_boundary(norm_trans_points)
        # Rows 0-1 are de-normalized with respect to self.shape.
        real_trans_points = norm_trans_points.clone()
        real_trans_points[:2, :] = denormalize_points(
            self.shape, real_trans_points[:2, :])
        heatmaps, mask = generate_label_map(
            real_trans_points.numpy(), label_h, label_w, self.sigma,
            self.downsample, nopoints, self.heatmap_type)
        # H*W*C -> C*H*W
        heatmaps = torch.from_numpy(heatmaps.transpose(
            (2, 0, 1))).type(torch.FloatTensor)
        mask = torch.from_numpy(mask.transpose(
            (2, 0, 1))).type(torch.ByteTensor)

        if torch.sum(norm_trans_points[2, :] == 1) < 3:
            warnings.warn(
                'In GeneralDatasetV2 after transformation, no visiable point, using identity instead. Aux: {:}'
                .format(aux_info))
            transpose_theta = identity2affine(False)
        elif self.mean_face is None:
            # Nothing to align to. The "no visible point" warning would be
            # false here, so fall back to the identity silently.
            transpose_theta = identity2affine(False)
        else:
            transpose_theta = solve2theta(norm_trans_points,
                                          self.mean_face.clone())

    affineFrames = [affine2image(frame, theta, self.shape) for frame in frames]
    num_pairs = len(affineFrames) - 1

    if skip_opt or num_pairs == 0:
        # Either flow is not requested, or a single frame has no neighbour
        # (torch.stack would fail on an empty list); return zero-filled flow.
        forward_flow = torch.zeros((num_pairs, height, width, 2))
        backward_flow = torch.zeros((num_pairs, height, width, 2))
    else:
        Gframes = [self.tensor2img(frame) for frame in affineFrames]
        forward_flow, backward_flow = [], []
        # For each consecutive pair: backward (next -> previous) first, then
        # forward (previous -> next), keeping the original call order of
        # self.optflow.calc.
        for previous, following in zip(Gframes[:-1], Gframes[1:]):
            backward_flow.append(self.optflow.calc(following, previous, None))
            forward_flow.append(self.optflow.calc(previous, following, None))
        forward_flow = torch.stack(
            [torch.from_numpy(flow) for flow in forward_flow])
        backward_flow = torch.stack(
            [torch.from_numpy(flow) for flow in backward_flow])

    return (torch.stack(affineFrames), forward_flow, backward_flow, heatmaps,
            mask, norm_trans_points, theta, transpose_theta)