Example #1
def run_evaluation(model, opt, options, dataset_name, log_freq=50):
    """Run evaluation on the datasets and metrics we report in the paper. """

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Create SMPL model
    smpl = SMPL().to(device)
    if dataset_name == '3dpw' or dataset_name == 'surreal':
        smpl_male = SMPL(cfg.MALE_SMPL_FILE).to(device)
        smpl_female = SMPL(cfg.FEMALE_SMPL_FILE).to(device)

    batch_size = opt.batch_size

    # Create dataloader for the dataset
    if dataset_name == 'surreal':
        dataset = SurrealDataset(options, use_augmentation=False, is_train=False, use_IUV=False)
    else:
        dataset = BaseDataset(options, dataset_name, use_augmentation=False, is_train=False, use_IUV=False)

    data_loader = DataLoader(dataset,  batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.num_workers),
                             pin_memory=True)

    print('Data loader ready')

    # Transfer model to the GPU
    model.to(device)
    model.eval()

    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    mpjpe = np.zeros(len(dataset))
    mpjpe_pa = np.zeros(len(dataset))

    # Shape metrics
    # Mean per-vertex error
    shape_err = np.zeros(len(dataset))

    # Mask and part metrics
    # Accuracy
    accuracy = 0.
    parts_accuracy = 0.
    # True positive, false positive and false negative
    tp = np.zeros((2, 1))
    fp = np.zeros((2, 1))
    fn = np.zeros((2, 1))
    parts_tp = np.zeros((7, 1))
    parts_fp = np.zeros((7, 1))
    parts_fn = np.zeros((7, 1))
    # Pixel count accumulators
    pixel_count = 0
    parts_pixel_count = 0

    eval_pose = False
    eval_shape = False
    eval_masks = False
    eval_parts = False
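    # Select the evaluation joint subset: 17 joints for MPI-INF-3DHP, 14 joints otherwise.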
    joint_mapper = cfg.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else cfg.J24_TO_J14
    # Choose appropriate evaluation for each dataset
    if 'h36m' in dataset_name or dataset_name == '3dpw' or dataset_name == 'mpi-inf-3dhp':
        eval_pose = True
    elif dataset_name in ['up-3d', 'surreal']:
        eval_shape = True
    elif dataset_name == 'lsp':
        eval_masks = True
        eval_parts = True
        annot_path = cfg.DATASET_FOLDERS['upi-s1h']

    if eval_parts or eval_masks:
        from utils.part_utils import PartRenderer
        renderer = PartRenderer()

    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_vertices = smpl(gt_pose, gt_betas)
        images = batch['img'].to(device)

        curr_batch_size = images.shape[0]

        # Run inference
        with torch.no_grad():
            out_dict = model(images)

        pred_vertices = out_dict['pred_vertices']
        camera = out_dict['camera']
        # 3D pose evaluation
        if eval_pose:
            # Get the ground-truth joints in the evaluation subset
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name:
                gt_keypoints_3d = batch['pose_3d'].to(device)
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper, :-1]
                gt_pelvis = (gt_keypoints_3d[:, [2]] + gt_keypoints_3d[:, [3]]) / 2
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis
            else:
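                # 3DPW provides per-sample gender labels, so build gendered ground-truth meshes (gender == 1 selects the female model).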
                gender = batch['gender'].to(device)
                gt_vertices = smpl_male(gt_pose, gt_betas)
                gt_vertices_female = smpl_female(gt_pose, gt_betas)
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender == 1, :, :]

                gt_keypoints_3d = smpl.get_train_joints(gt_vertices)[:, joint_mapper]
                # gt_keypoints_3d = smpl.get_lsp_joints(gt_vertices)    # joints_regressor used in cmr
                gt_pelvis = (gt_keypoints_3d[:, [2]] + gt_keypoints_3d[:, [3]]) / 2
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis

            # Get the predicted evaluation joints from the non-parametric mesh
            pred_keypoints_3d = smpl.get_train_joints(pred_vertices)[:, joint_mapper]
            # pred_keypoints_3d = smpl.get_lsp_joints(pred_vertices)    # joints_regressor used in cmr
            pred_pelvis = (pred_keypoints_3d[:, [2]] + pred_keypoints_3d[:, [3]]) / 2
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis

            # Absolute error (MPJPE)
            error = torch.sqrt(((pred_keypoints_3d - gt_keypoints_3d) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            mpjpe[step * batch_size:step * batch_size + curr_batch_size] = error

            # Reconstruction error (Procrustes-aligned MPJPE)
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(), gt_keypoints_3d.cpu().numpy(),
                                           reduction=None)
            mpjpe_pa[step * batch_size:step * batch_size + curr_batch_size] = r_error

        # Shape evaluation (Mean per-vertex error)
        if eval_shape:
            if dataset_name == 'surreal':
                gender = batch['gender'].to(device)
                gt_vertices = smpl_male(gt_pose, gt_betas)
                gt_vertices_female = smpl_female(gt_pose, gt_betas)
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender == 1, :, :]

            gt_pelvis_mesh = smpl.get_eval_joints(gt_vertices)
            pred_pelvis_mesh = smpl.get_eval_joints(pred_vertices)
            gt_pelvis_mesh = (gt_pelvis_mesh[:, [2]] + gt_pelvis_mesh[:, [3]]) / 2
            pred_pelvis_mesh = (pred_pelvis_mesh[:, [2]] + pred_pelvis_mesh[:, [3]]) / 2

            # se = torch.sqrt(((pred_vertices - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
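            # Align predicted and ground-truth meshes at the pelvis before computing the per-vertex error.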
            se = torch.sqrt(((pred_vertices - pred_pelvis_mesh - gt_vertices + gt_pelvis_mesh) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            shape_err[step * batch_size:step * batch_size + curr_batch_size] = se

        # If mask or part evaluation, render the mask and part images
        if eval_masks or eval_parts:
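            # Render the predicted mesh into a silhouette mask and a body-part segmentation image.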
            mask, parts = renderer(pred_vertices, camera)
        # Mask evaluation (for LSP)
        if eval_masks:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            # Dimensions of original image
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                # After rendering, convert the image back to the original resolution
                pred_mask = uncrop(mask[i].cpu().numpy(), center[i], scale[i], orig_shape[i]) > 0
                # Load gt mask
                gt_mask = cv2.imread(os.path.join(annot_path, batch['maskname'][i]), 0) > 0
                # Evaluation consistent with the original UP-3D code
                accuracy += (gt_mask == pred_mask).sum()
                pixel_count += np.prod(np.array(gt_mask.shape))
                for c in range(2):
                    cgt = gt_mask == c
                    cpred = pred_mask == c
                    tp[c] += (cgt & cpred).sum()
                    fp[c] += (~cgt & cpred).sum()
                    fn[c] += (cgt & ~cpred).sum()
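                # Per-class F1 for background/foreground from the running TP/FP/FN counts.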
                f1 = 2 * tp / (2 * tp + fp + fn)

        # Part evaluation (for LSP)
        if eval_parts:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                pred_parts = uncrop(parts[i].cpu().numpy().astype(np.uint8), center[i], scale[i], orig_shape[i])
                # Load gt part segmentation
                gt_parts = cv2.imread(os.path.join(annot_path, batch['partname'][i]), 0)
                # Evaluation consistent with the original UP-3D code
                # 6 parts + background
                for c in range(7):
                    cgt = gt_parts == c
                    cpred = pred_parts == c
                    cpred[gt_parts == 255] = 0
                    parts_tp[c] += (cgt & cpred).sum()
                    parts_fp[c] += (~cgt & cpred).sum()
                    parts_fn[c] += (cgt & ~cpred).sum()
                gt_parts[gt_parts == 255] = 0
                pred_parts[pred_parts == 255] = 0
                parts_f1 = 2 * parts_tp / (2 * parts_tp + parts_fp + parts_fn)
                parts_accuracy += (gt_parts == pred_parts).sum()
                parts_pixel_count += np.prod(np.array(gt_parts.shape))

        # Print intermediate results during evaluation
        if step % log_freq == log_freq - 1:
            if eval_pose:
                print('MPJPE: ' + str(1000 * mpjpe[:step * batch_size].mean()))
                print('MPJPE-PA: ' + str(1000 * mpjpe_pa[:step * batch_size].mean()))
                print()
            if eval_shape:
                print('Shape Error: ' + str(1000 * shape_err[:step * batch_size].mean()))
                print()
            if eval_masks:
                print('Accuracy: ', accuracy / pixel_count)
                print('F1: ', f1.mean())
                print()
            if eval_parts:
                print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
                print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean())
                print()

    # Print final results during evaluation
    print('*** Final Results ***')
    print()
    if eval_pose:
        print('MPJPE: ' + str(1000 * mpjpe.mean()))
        print('MPJPE-PA: ' + str(1000 * mpjpe_pa.mean()))
        print()
    if eval_shape:
        print('Shape Error: ' + str(1000 * shape_err.mean()))
        print()
    if eval_masks:
        print('Accuracy: ', accuracy / pixel_count)
        print('F1: ', f1.mean())
        print()
    if eval_parts:
        print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
        print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean())
        print()

    # Save final results to a .txt file
    txt_name = os.path.join(opt.save_root, dataset_name + '.txt')
    with open(txt_name, 'w') as f:
        f.write('*** Final Results ***')
        f.write('\n')
        if eval_pose:
            f.write('MPJPE: ' + str(1000 * mpjpe.mean()))
            f.write('\n')
            f.write('MPJPE-PA: ' + str(1000 * mpjpe_pa.mean()))
            f.write('\n')
        if eval_shape:
            f.write('Shape Error: ' + str(1000 * shape_err.mean()))
            f.write('\n')
        if eval_masks:
            f.write('Accuracy: ' + str(accuracy / pixel_count))
            f.write('\n')
            f.write('F1: ' + str(f1.mean()))
            f.write('\n')
        if eval_parts:
            f.write('Parts Accuracy: ' + str(parts_accuracy / parts_pixel_count))
            f.write('\n')
            f.write('Parts F1 (BG): ' + str(parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean()))
            f.write('\n')
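
# A minimal sketch (not necessarily the project's reconstruction_error) of the
# Procrustes-aligned error (MPJPE-PA) reported above: rigidly align the predicted
# joints to the ground truth per sample with an optimal similarity transform,
# then average the per-joint distances. Assumes NumPy arrays of shape (B, J, 3).
import numpy as np

def procrustes_aligned_error(pred, gt):
    """Per-sample mean joint error after optimal similarity alignment (scale + rotation)."""
    errors = np.zeros(pred.shape[0])
    for i in range(pred.shape[0]):
        p, g = pred[i], gt[i]
        # Center both joint sets.
        p0, g0 = p - p.mean(axis=0), g - g.mean(axis=0)
        # Optimal rotation from the SVD of the cross-covariance matrix.
        U, S, Vt = np.linalg.svd(g0.T @ p0)
        R = U @ Vt
        if np.linalg.det(R) < 0:  # avoid reflections
            U[:, -1] *= -1
            S[-1] *= -1
            R = U @ Vt
        scale = S.sum() / (p0 ** 2).sum()
        aligned = scale * p0 @ R.T + g.mean(axis=0)
        errors[i] = np.linalg.norm(aligned - g, axis=-1).mean()
    return errors
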
Example #2
class Trainer(BaseTrainer):
    def init_fn(self):
        # create training dataset
        self.train_ds = create_dataset(self.options.dataset,
                                       self.options,
                                       use_IUV=True)
        self.dp_res = int(self.options.img_res // (2**self.options.warp_level))

        self.CNet = DPNet(warp_lv=self.options.warp_level,
                          norm_type=self.options.norm_type).to(self.device)

        self.LNet = get_LNet(self.options).to(self.device)
        self.smpl = SMPL().to(self.device)
        self.female_smpl = SMPL(cfg.FEMALE_SMPL_FILE).to(self.device)
        self.male_smpl = SMPL(cfg.MALE_SMPL_FILE).to(self.device)

        uv_res = self.options.uv_res
        self.uv_type = self.options.uv_type
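        # Index_UV_Generator converts between SMPL mesh vertices and the UV (location-map) representation.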
        self.sampler = Index_UV_Generator(UV_height=uv_res,
                                          UV_width=-1,
                                          uv_type=self.uv_type).to(self.device)

        weight_file = 'data/weight_p24_h{:04d}_w{:04d}_{}.npy'.format(
            uv_res, uv_res, self.uv_type)
        if not os.path.exists(weight_file):
            cal_uv_weight(self.sampler, weight_file)

        uv_weight = torch.from_numpy(np.load(weight_file)).to(
            self.device).float()
        uv_weight = uv_weight * self.sampler.mask.to(uv_weight.device).float()
        uv_weight = uv_weight / uv_weight.mean()
        self.uv_weight = uv_weight[None, :, :, None]
        self.tv_factor = (uv_res - 1) * (uv_res - 1)

        # Setup an optimizer
        if self.options.stage == 'dp':
            self.optimizer = torch.optim.Adam(
                params=list(self.CNet.parameters()),
                lr=self.options.lr,
                betas=(self.options.adam_beta1, 0.999),
                weight_decay=self.options.wd)
            self.models_dict = {'CNet': self.CNet}
            self.optimizers_dict = {'optimizer': self.optimizer}

        else:
            self.optimizer = torch.optim.Adam(
                params=list(self.LNet.parameters()) +
                list(self.CNet.parameters()),
                lr=self.options.lr,
                betas=(self.options.adam_beta1, 0.999),
                weight_decay=self.options.wd)
            self.models_dict = {'CNet': self.CNet, 'LNet': self.LNet}
            self.optimizers_dict = {'optimizer': self.optimizer}

        # Create loss functions
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_uv = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_keypoints_3d = nn.L1Loss(reduction='none').to(
            self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)

        # LSP indices from full list of keypoints
        self.to_lsp = list(range(14))
        self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())

        # Optionally start training from a pretrained checkpoint
        # Note that this is different from resuming training
        # For the latter use --resume
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

    def train_step(self, input_batch):
        """Training step."""
        dtype = torch.float32

        if self.options.stage == 'dp':
            self.CNet.train()

            # Grab data from the batch
            has_dp = input_batch['has_dp']
            images = input_batch['img']
            gt_dp_iuv = input_batch['gt_iuv']
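            # Scale the U/V channels of the ground-truth IUV map to [0, 1]; the first channel is left as-is.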
            gt_dp_iuv[:, 1:] = gt_dp_iuv[:, 1:] / 255.0
            batch_size = images.shape[0]

            if images.is_cuda and self.options.ngpu > 1:
                pred_dp, dp_feature, codes = data_parallel(
                    self.CNet, images, range(self.options.ngpu))
            else:
                pred_dp, dp_feature, codes = self.CNet(images)

            if self.options.adaptive_weight:
                fit_joint_error = input_batch['fit_joint_error']
                ada_weight = self.error_adaptive_weight(fit_joint_error).type(
                    dtype)
            else:
                # ada_weight = pred_scale.new_ones(batch_size).type(dtype)
                ada_weight = None

            losses = {}
            '''loss on dense pose result'''
            loss_dp_mask, loss_dp_uv = self.dp_loss(pred_dp, gt_dp_iuv, has_dp,
                                                    ada_weight)
            loss_dp_mask = loss_dp_mask * self.options.lam_dp_mask
            loss_dp_uv = loss_dp_uv * self.options.lam_dp_uv
            losses['dp_mask'] = loss_dp_mask
            losses['dp_uv'] = loss_dp_uv
            loss_total = sum(loss for loss in losses.values())
            # Do backprop
            self.optimizer.zero_grad()
            loss_total.backward()
            self.optimizer.step()

            # For visualization
            if (self.step_count + 1) % self.options.summary_steps == 0:
                data = {}
                vis_num = min(4, batch_size)
                data['image'] = input_batch['img_orig'][0:vis_num].detach()
                data['pred_dp'] = pred_dp[0:vis_num].detach()
                data['gt_dp'] = gt_dp_iuv[0:vis_num].detach()
                self.vis_data = data

            # Pack detached loss values into a dict for logging
            out_args = {
                key: losses[key].detach().item()
                for key in losses.keys()
            }
            out_args['total'] = loss_total.detach().item()
            self.loss_item = out_args

        elif self.options.stage == 'end':
            self.CNet.train()
            self.LNet.train()

            # Grab data from the batch
            # gt_keypoints_2d = input_batch['keypoints']
            # gt_keypoints_3d = input_batch['pose_3d']
            # gt_keypoints_2d = torch.cat([input_batch['keypoints'], input_batch['keypoints_smpl']], dim=1)
            # gt_keypoints_3d = torch.cat([input_batch['pose_3d'], input_batch['pose_3d_smpl']], dim=1)
            gt_keypoints_2d = input_batch['keypoints']
            gt_keypoints_3d = input_batch['pose_3d']
            has_pose_3d = input_batch['has_pose_3d']

            gt_keypoints_2d_smpl = input_batch['keypoints_smpl']
            gt_keypoints_3d_smpl = input_batch['pose_3d_smpl']
            has_pose_3d_smpl = input_batch['has_pose_3d_smpl']

            gt_pose = input_batch['pose']
            gt_betas = input_batch['betas']
            has_smpl = input_batch['has_smpl']
            has_dp = input_batch['has_dp']
            images = input_batch['img']
            gender = input_batch['gender']

            # images.requires_grad_()
            gt_dp_iuv = input_batch['gt_iuv']
            gt_dp_iuv[:, 1:] = gt_dp_iuv[:, 1:] / 255.0
            batch_size = images.shape[0]

            gt_vertices = images.new_zeros([batch_size, 6890, 3])
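            # SMPL meshes have 6890 vertices; the rows are filled per gender below (< 0 neutral, 0 male, 1 female).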
            if images.is_cuda and self.options.ngpu > 1:
                with torch.no_grad():
                    gt_vertices[gender < 0] = data_parallel(
                        self.smpl, (gt_pose[gender < 0], gt_betas[gender < 0]),
                        range(self.options.ngpu))
                    gt_vertices[gender == 0] = data_parallel(
                        self.male_smpl,
                        (gt_pose[gender == 0], gt_betas[gender == 0]),
                        range(self.options.ngpu))
                    gt_vertices[gender == 1] = data_parallel(
                        self.female_smpl,
                        (gt_pose[gender == 1], gt_betas[gender == 1]),
                        range(self.options.ngpu))
                    gt_uv_map = data_parallel(self.sampler, gt_vertices,
                                              range(self.options.ngpu))
                pred_dp, dp_feature, codes = data_parallel(
                    self.CNet, images, range(self.options.ngpu))
                pred_uv_map, pred_camera = data_parallel(
                    self.LNet, (pred_dp, dp_feature, codes),
                    range(self.options.ngpu))
            else:
                # gt_vertices = self.smpl(gt_pose, gt_betas)
                with torch.no_grad():
                    gt_vertices[gender < 0] = self.smpl(
                        gt_pose[gender < 0], gt_betas[gender < 0])
                    gt_vertices[gender == 0] = self.male_smpl(
                        gt_pose[gender == 0], gt_betas[gender == 0])
                    gt_vertices[gender == 1] = self.female_smpl(
                        gt_pose[gender == 1], gt_betas[gender == 1])
                    gt_uv_map = self.sampler.get_UV_map(gt_vertices.float())
                pred_dp, dp_feature, codes = self.CNet(images)
                pred_uv_map, pred_camera = self.LNet(pred_dp, dp_feature,
                                                     codes)

            if self.options.adaptive_weight:
                # Get the confidence of the GT mesh, which is used to weight the loss terms.
                # The confidence is related to the fitting error; for data with GT SMPL
                # parameters the confidence is 1.0.
                fit_joint_error = input_batch['fit_joint_error']
                ada_weight = self.error_adaptive_weight(fit_joint_error).type(
                    dtype)
            else:
                ada_weight = None

            losses = {}
            '''loss on dense pose result'''
            loss_dp_mask, loss_dp_uv = self.dp_loss(pred_dp, gt_dp_iuv, has_dp,
                                                    ada_weight)
            loss_dp_mask = loss_dp_mask * self.options.lam_dp_mask
            loss_dp_uv = loss_dp_uv * self.options.lam_dp_uv
            losses['dp_mask'] = loss_dp_mask
            losses['dp_uv'] = loss_dp_uv
            '''loss on location map'''
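            # Resample the predicted UV location map back to mesh vertices for the mesh and joint losses.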
            sampled_vertices = self.sampler.resample(
                pred_uv_map.float()).type(dtype)
            loss_uv = self.uv_loss(
                gt_uv_map.float(), pred_uv_map.float(), has_smpl,
                ada_weight).type(dtype) * self.options.lam_uv
            losses['uv'] = loss_uv

            if self.options.lam_tv > 0:
                loss_tv = self.tv_loss(pred_uv_map) * self.options.lam_tv
                losses['tv'] = loss_tv
            '''loss on mesh'''
            if self.options.lam_mesh > 0:
                loss_mesh = self.shape_loss(sampled_vertices, gt_vertices,
                                            has_smpl,
                                            ada_weight) * self.options.lam_mesh
                losses['mesh'] = loss_mesh
            '''loss on joints'''
            weight_key = sampled_vertices.new_ones(batch_size)
            if self.options.gtkey3d_from_mesh:
                # For the data without GT 3D keypoints but with SMPL parameters,
                # we can get the GT 3D keypoints from the mesh.
                # The confidence of the keypoints is related to the confidence of the mesh.
                gt_keypoints_3d_mesh = self.smpl.get_train_joints(gt_vertices)
                gt_keypoints_3d_mesh = torch.cat([
                    gt_keypoints_3d_mesh,
                    gt_keypoints_3d_mesh.new_ones([batch_size, 24, 1])
                ],
                                                 dim=-1)
                valid = has_smpl > has_pose_3d
                gt_keypoints_3d[valid] = gt_keypoints_3d_mesh[valid]
                has_pose_3d[valid] = 1
                if ada_weight is not None:
                    weight_key[valid] = ada_weight[valid]

            sampled_joints_3d = self.smpl.get_train_joints(sampled_vertices)
            loss_keypoints_3d = self.keypoint_3d_loss(sampled_joints_3d,
                                                      gt_keypoints_3d,
                                                      has_pose_3d, weight_key)
            loss_keypoints_3d = loss_keypoints_3d * self.options.lam_key3d
            losses['key3D'] = loss_keypoints_3d

            sampled_joints_2d = orthographic_projection(
                sampled_joints_3d, pred_camera)[:, :, :2]
            loss_keypoints_2d = self.keypoint_loss(
                sampled_joints_2d, gt_keypoints_2d) * self.options.lam_key2d
            losses['key2D'] = loss_keypoints_2d

            # We additionally supervise the 24 SMPL joints for training on the SURREAL dataset.
            weight_key_smpl = sampled_vertices.new_ones(batch_size)
            if self.options.gtkey3d_from_mesh:
                gt_keypoints_3d_mesh = self.smpl.get_smpl_joints(gt_vertices)
                gt_keypoints_3d_mesh = torch.cat([
                    gt_keypoints_3d_mesh,
                    gt_keypoints_3d_mesh.new_ones([batch_size, 24, 1])
                ],
                                                 dim=-1)
                valid = has_smpl > has_pose_3d_smpl
                gt_keypoints_3d_smpl[valid] = gt_keypoints_3d_mesh[valid]
                has_pose_3d_smpl[valid] = 1
                if ada_weight is not None:
                    weight_key_smpl[valid] = ada_weight[valid]

            if self.options.use_smpl_joints:
                sampled_joints_3d_smpl = self.smpl.get_smpl_joints(
                    sampled_vertices)
                loss_keypoints_3d_smpl = self.smpl_keypoint_3d_loss(
                    sampled_joints_3d_smpl, gt_keypoints_3d_smpl,
                    has_pose_3d_smpl, weight_key_smpl)
                loss_keypoints_3d_smpl = loss_keypoints_3d_smpl * self.options.lam_key3d_smpl
                losses['key3D_smpl'] = loss_keypoints_3d_smpl

                sampled_joints_2d_smpl = orthographic_projection(
                    sampled_joints_3d_smpl, pred_camera)[:, :, :2]
                loss_keypoints_2d_smpl = self.keypoint_loss(
                    sampled_joints_2d_smpl,
                    gt_keypoints_2d_smpl) * self.options.lam_key2d_smpl
                losses['key2D_smpl'] = loss_keypoints_2d_smpl
            '''consistent loss'''
            if self.options.lam_con != 0:
                loss_con = self.consistent_loss(
                    gt_dp_iuv, pred_uv_map, pred_camera,
                    ada_weight) * self.options.lam_con
                losses['con'] = loss_con

            loss_total = sum(loss for loss in losses.values())
            # Do backprop
            self.optimizer.zero_grad()
            loss_total.backward()
            self.optimizer.step()

            # For visualization
            if (self.step_count + 1) % self.options.summary_steps == 0:
                data = {}
                vis_num = min(4, batch_size)
                data['image'] = input_batch['img_orig'][0:vis_num].detach()
                data['gt_vert'] = gt_vertices[0:vis_num].detach()
                data['pred_vert'] = sampled_vertices[0:vis_num].detach()
                data['pred_cam'] = pred_camera[0:vis_num].detach()
                data['pred_joint'] = sampled_joints_2d[0:vis_num].detach()
                data['gt_joint'] = gt_keypoints_2d[0:vis_num].detach()
                data['pred_uv'] = pred_uv_map[0:vis_num].detach()
                data['gt_uv'] = gt_uv_map[0:vis_num].detach()
                data['pred_dp'] = pred_dp[0:vis_num].detach()
                data['gt_dp'] = gt_dp_iuv[0:vis_num].detach()
                self.vis_data = data

            # Pack detached loss values into a dict for logging
            out_args = {
                key: losses[key].detach().item()
                for key in losses.keys()
            }
            out_args['total'] = loss_total.detach().item()
            self.loss_item = out_args

        return out_args

    def train_summaries(self, batch, epoch):
        """Tensorboard logging."""
        if self.options.stage == 'dp':
            dtype = self.vis_data['pred_dp'].dtype
            rend_imgs = []
            vis_size = self.vis_data['pred_dp'].shape[0]
            # Do visualization for the first 4 images of the batch
            for i in range(vis_size):
                img = self.vis_data['image'][i].cpu().numpy().transpose(
                    1, 2, 0)
                H, W, C = img.shape
                rend_img = img.transpose(2, 0, 1)

                gt_dp = self.vis_data['gt_dp'][i]
                gt_dp = torch.nn.functional.interpolate(gt_dp[None, :],
                                                        size=[H, W])[0]
                # gt_dp = torch.cat((gt_dp, gt_dp.new_ones(1, H, W)), dim=0).cpu().numpy()
                gt_dp = gt_dp.cpu().numpy()
                rend_img = np.concatenate((rend_img, gt_dp), axis=2)

                pred_dp = self.vis_data['pred_dp'][i]
                pred_dp[0] = (pred_dp[0] > 0.5).type(dtype)
                pred_dp[1:] = pred_dp[1:] * pred_dp[0]
                pred_dp = torch.nn.functional.interpolate(pred_dp[None, :],
                                                          size=[H, W])[0]
                pred_dp = pred_dp.cpu().numpy()
                rend_img = np.concatenate((rend_img, pred_dp), axis=2)

                # import matplotlib.pyplot as plt
                # plt.imshow(rend_img.transpose([1, 2, 0]))
                rend_imgs.append(torch.from_numpy(rend_img))

            rend_imgs = make_grid(rend_imgs, nrow=1)
            self.summary_writer.add_image('imgs', rend_imgs, self.step_count)

        else:
            gt_keypoints_2d = self.vis_data['gt_joint'].cpu().numpy()
            pred_vertices = self.vis_data['pred_vert']
            pred_keypoints_2d = self.vis_data['pred_joint']
            pred_camera = self.vis_data['pred_cam']
            dtype = pred_camera.dtype
            rend_imgs = []
            vis_size = pred_vertices.shape[0]
            # Do visualization for the first 4 images of the batch
            for i in range(vis_size):
                img = self.vis_data['image'][i].cpu().numpy().transpose(
                    1, 2, 0)
                H, W, C = img.shape

                # Get LSP keypoints from the full list of keypoints
                gt_keypoints_2d_ = gt_keypoints_2d[i, self.to_lsp]
                pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[
                    i, self.to_lsp]
                vertices = pred_vertices[i].cpu().numpy()
                cam = pred_camera[i].cpu().numpy()
                # Visualize reconstruction and detected pose
                rend_img = visualize_reconstruction(img, self.options.img_res,
                                                    gt_keypoints_2d_, vertices,
                                                    pred_keypoints_2d_, cam,
                                                    self.renderer)
                rend_img = rend_img.transpose(2, 0, 1)

                if 'gt_vert' in self.vis_data.keys():
                    rend_img2 = vis_mesh(
                        img,
                        self.vis_data['gt_vert'][i].cpu().numpy(),
                        cam,
                        self.renderer,
                        color='blue')
                    rend_img2 = rend_img2.transpose(2, 0, 1)
                    rend_img = np.concatenate((rend_img, rend_img2), axis=2)

                gt_dp = self.vis_data['gt_dp'][i]
                gt_dp = torch.nn.functional.interpolate(gt_dp[None, :],
                                                        size=[H, W])[0]
                gt_dp = gt_dp.cpu().numpy()
                # gt_dp = torch.cat((gt_dp, gt_dp.new_ones(1, H, W)), dim=0).cpu().numpy()
                rend_img = np.concatenate((rend_img, gt_dp), axis=2)

                pred_dp = self.vis_data['pred_dp'][i]
                pred_dp[0] = (pred_dp[0] > 0.5).type(dtype)
                pred_dp[1:] = pred_dp[1:] * pred_dp[0]
                pred_dp = torch.nn.functional.interpolate(pred_dp[None, :],
                                                          size=[H, W])[0]
                pred_dp = pred_dp.cpu().numpy()
                rend_img = np.concatenate((rend_img, pred_dp), axis=2)

                # import matplotlib.pyplot as plt
                # plt.imshow(rend_img.transpose([1, 2, 0]))
                rend_imgs.append(torch.from_numpy(rend_img))

            rend_imgs = make_grid(rend_imgs, nrow=1)

            uv_maps = []
            for i in range(vis_size):
                uv_temp = torch.cat(
                    (self.vis_data['pred_uv'][i], self.vis_data['gt_uv'][i]),
                    dim=1)
                uv_maps.append(uv_temp.permute(2, 0, 1))

            uv_maps = make_grid(uv_maps, nrow=1)
            uv_maps = uv_maps.abs()
            uv_maps = uv_maps / uv_maps.max()

            # Save results in Tensorboard
            self.summary_writer.add_image('imgs', rend_imgs, self.step_count)
            self.summary_writer.add_image('uv_maps', uv_maps, self.step_count)

        for key in self.loss_item.keys():
            self.summary_writer.add_scalar('loss_' + key, self.loss_item[key],
                                           self.step_count)

    def train(self):
        """Training process."""
        # Run training for num_epochs epochs
        for epoch in range(self.epoch_count, self.options.num_epochs):
            # Create new DataLoader every epoch and (possibly) resume from an arbitrary step inside an epoch
            train_data_loader = CheckpointDataLoader(
                self.train_ds,
                checkpoint=self.checkpoint,
                batch_size=self.options.batch_size,
                num_workers=self.options.num_workers,
                pin_memory=self.options.pin_memory,
                shuffle=self.options.shuffle_train)

            # Iterate over all batches in an epoch
            batch_len = len(self.train_ds) // self.options.batch_size
            data_stream = tqdm(train_data_loader,
                               desc='Epoch ' + str(epoch),
                               total=len(self.train_ds) //
                               self.options.batch_size,
                               initial=train_data_loader.checkpoint_batch_idx)
            for step, batch in enumerate(
                    data_stream, train_data_loader.checkpoint_batch_idx):
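                # Train until the wall-clock deadline; on timeout the else branch saves a checkpoint and exits.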
                if time.time() < self.endtime:

                    batch = {
                        k:
                        v.to(self.device) if isinstance(v, torch.Tensor) else v
                        for k, v in batch.items()
                    }

                    loss_dict = self.train_step(batch)
                    self.step_count += 1

                    tqdm_info = 'Epoch:%d| %d/%d ' % (epoch, step, batch_len)
                    for k, v in loss_dict.items():
                        tqdm_info += ' %s:%.4f' % (k, v)
                    data_stream.set_description(tqdm_info)

                    if self.step_count % self.options.summary_steps == 0:
                        self.train_summaries(step, epoch)

                    # Save checkpoint every checkpoint_steps steps
                    if self.step_count % self.options.checkpoint_steps == 0 and self.step_count > 0:
                        self.saver.save_checkpoint(
                            self.models_dict, self.optimizers_dict, epoch,
                            step + 1, self.options.batch_size,
                            train_data_loader.sampler.dataset_perm,
                            self.step_count)
                        tqdm.write('Checkpoint saved')

                    # Run validation every test_steps steps
                    if self.step_count % self.options.test_steps == 0:
                        self.test()

                else:
                    tqdm.write('Timeout reached')
                    self.saver.save_checkpoint(
                        self.models_dict, self.optimizers_dict, epoch, step,
                        self.options.batch_size,
                        train_data_loader.sampler.dataset_perm,
                        self.step_count)
                    tqdm.write('Checkpoint saved')
                    sys.exit(0)

            # Load a checkpoint only on startup; for the next epochs just iterate over the dataset as usual
            self.checkpoint = None
            # Save a checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                self.saver.save_checkpoint(self.models_dict,
                                           self.optimizers_dict, epoch + 1, 0,
                                           self.options.batch_size, None,
                                           self.step_count)

        self.saver.save_checkpoint(self.models_dict,
                                   self.optimizers_dict,
                                   epoch + 1,
                                   0,
                                   self.options.batch_size,
                                   None,
                                   self.step_count,
                                   checkpoint_filename='final')
        return
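
# A hedged sketch of the total-variation regularizer suggested by tv_factor above
# (the exact tv_loss used by the trainer may differ): penalize absolute differences
# between neighboring pixels of the predicted UV location map, normalized by the
# number of neighboring-pixel pairs. Assumes uv_map has shape [batch, H, W, 3].
import torch

def tv_loss_sketch(uv_map, tv_factor):
    # Differences along the height and width dimensions of the UV map.
    d_h = (uv_map[:, 1:, :, :] - uv_map[:, :-1, :, :]).abs().sum()
    d_w = (uv_map[:, :, 1:, :] - uv_map[:, :, :-1, :]).abs().sum()
    # Averaging over the batch is an assumption; the original may normalize differently.
    return (d_h + d_w) / (tv_factor * uv_map.shape[0])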