Esempio n. 1
0
    def __init__(self, cfg: CfgNode):
        """
        Initialize chart-based loss from configuration options

        Reads loss weights and switches from
        ``cfg.MODEL.ROI_DENSEPOSE_HEAD``, teacher-student settings from
        ``cfg.MODEL`` and auxiliary-supervision settings from
        ``cfg.MODEL.CONDINST``. When ``cfg.MODEL.TEACHER_STUDENT`` is set, a
        teacher model is built from ``cfg.MODEL.TEACHER_CFG_FILE``, loaded from
        ``cfg.MODEL.TEACHER_WEIGHTS`` and switched to eval mode.

        Args:
            cfg (CfgNode): configuration options
        """
        # fmt: off
        self.heatmap_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE
        self.w_points = cfg.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS
        self.w_part = cfg.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS
        self.use_part_focal_loss = cfg.MODEL.ROI_DENSEPOSE_HEAD.PART_FOCAL_LOSS
        self.w_segm = cfg.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS
        self.w_body = cfg.MODEL.ROI_DENSEPOSE_HEAD.BODY_WEIGHTS
        self.n_segm_chan = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
        self.w_smooth = cfg.MODEL.ROI_DENSEPOSE_HEAD.SMOOTH_WEIGHTS
        self.w_tv = cfg.MODEL.ROI_DENSEPOSE_HEAD.TV_WEIGHTS
        # fmt: on
        self.segm_trained_by_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
        self.use_mean_uv = cfg.MODEL.ROI_DENSEPOSE_HEAD.MEAN_UV_LOSS

        # Focal loss for part classification is only constructed when enabled;
        # `self.focal_loss` does not exist otherwise.
        if self.use_part_focal_loss:
            gamma = cfg.MODEL.ROI_DENSEPOSE_HEAD.PART_FOCAL_GAMMA
            self.focal_loss = FocalLoss(gamma=gamma,
                                        alpha=None,
                                        size_average=False)

        # Teacher-student distillation settings.
        self.use_teacher_student = cfg.MODEL.TEACHER_STUDENT
        # NOTE: holds the teacher config *file path* here; it is replaced by
        # the loaded teacher CfgNode below when teacher-student is enabled.
        self.teacher_cfg = cfg.MODEL.TEACHER_CFG_FILE
        self.teacher_weights = cfg.MODEL.TEACHER_WEIGHTS
        self.teach_ins_wo_gt_dp = cfg.MODEL.TEACH_INS_WO_GT_DP
        self.w_part_teach = cfg.MODEL.TEACH_PART_WEIGHTS
        self.w_points_teach = cfg.MODEL.TEACH_POINT_REGRESSION_WEIGHTS
        if self.use_teacher_student:
            # Imported lazily, presumably to avoid an import cycle between the
            # loss module and the engine/modeling packages -- TODO confirm.
            from densepose.engine import Trainer
            from densepose.modeling.densepose_checkpoint import DensePoseCheckpointer
            from densepose.config import get_cfg, add_densepose_config
            self.teacher_cfg = get_cfg()
            add_densepose_config(self.teacher_cfg)
            self.teacher_cfg.merge_from_file(cfg.MODEL.TEACHER_CFG_FILE)
            self.teacher_model = Trainer.build_model(self.teacher_cfg)
            DensePoseCheckpointer(self.teacher_model).load(
                cfg.MODEL.TEACHER_WEIGHTS)
            # The teacher is only used for inference.
            self.teacher_model.eval()

        # Auxiliary supervision switches (flags) and their loss weights.
        self.use_aux_global_s = cfg.MODEL.CONDINST.AUX_SUPERVISION_GLOBAL_S
        self.use_aux_global_skeleton = cfg.MODEL.CONDINST.AUX_SUPERVISION_GLOBAL_SKELETON
        self.use_aux_body_semantics = cfg.MODEL.CONDINST.AUX_SUPERVISION_BODY_SEMANTICS
        self.w_aux_global_s = cfg.MODEL.CONDINST.AUX_SUPERVISION_GLOBAL_S_WEIGHTS
        self.w_aux_global_skeleton = cfg.MODEL.CONDINST.AUX_SUPERVISION_GLOBAL_SKELETON_WEIGHTS
        # The body-semantics weight must come from the *_WEIGHTS key, like the
        # two weights above. Reading AUX_SUPERVISION_BODY_SEMANTICS would set
        # the weight to the on/off flag. Configs that do not define the
        # *_WEIGHTS key fall back to the flag value (previous behavior).
        self.w_aux_body_semantics = getattr(
            cfg.MODEL.CONDINST,
            "AUX_SUPERVISION_BODY_SEMANTICS_WEIGHTS",
            cfg.MODEL.CONDINST.AUX_SUPERVISION_BODY_SEMANTICS,
        )

        self.pred_ins_body = cfg.MODEL.CONDINST.PREDICT_INSTANCE_BODY

        self.segm_loss = MaskOrSegmentationLoss(cfg)
Esempio n. 2
0
def main(args):
    """Evaluate a checkpoint or run training, depending on ``args``.

    With ``args.eval_only`` set, the model is built, its weights are loaded
    from ``cfg.MODEL.WEIGHTS`` (or resumed from ``cfg.OUTPUT_DIR``), and it is
    tested (plus test-time augmentation when ``cfg.TEST.AUG.ENABLED``). The main
    process then verifies the results, and the results are returned.
    Otherwise a trainer is created, optionally given a TTA evaluation hook,
    and its ``train()`` result is returned.
    """
    cfg = setup(args)
    # Path handlers may receive extra kwargs (e.g. timeout hints during
    # DensePose evaluation); do not reject unknown keyword arguments.
    PathManager.set_strict_kwargs_checking(False)

    if not args.eval_only:
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        if cfg.TEST.AUG.ENABLED:
            tta_hook = hooks.EvalHook(
                0, lambda: trainer.test_with_TTA(cfg, trainer.model)
            )
            trainer.register_hooks([tta_hook])
        return trainer.train()

    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)

    results = Trainer.test(cfg, model)
    if cfg.TEST.AUG.ENABLED:
        results.update(Trainer.test_with_TTA(cfg, model))
    if comm.is_main_process():
        verify_results(cfg, results)
    return results