Example #1
    def init_fn(self):
        # submitit populates the SLURM job environment; local_rank selects
        # this process's GPU for distributed training.
        job_env = submitit.JobEnvironment()

        self.train_ds = MixedDataset(self.options,
                                     ignore_3d=self.options.ignore_3d,
                                     is_train=True)

        # Build the HMR network, move it to this process's GPU, and wrap it in
        # DistributedDataParallel.
        self.model = hmr(config.SMPL_MEAN_PARAMS,
                         pretrained=True).to(self.device)
        self.model.cuda(job_env.local_rank)
        self.model = torch.nn.parallel.DistributedDataParallel(
            self.model,
            device_ids=[job_env.local_rank],
            output_device=job_env.local_rank)

        # Exemplar fine-tuning uses a reduced learning rate.
        if self.options.bExemplarMode:
            lr = 5e-5 * 0.2
        else:
            lr = self.options.lr
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=lr,
                                          weight_decay=0)

        if self.options.bUseSMPLX:
            # SMPL-X model. No change is required for HMR training: SMPL uses
            # 23 body joints while SMPL-X uses 21, so the last two SMPL joints
            # (the hands) are ignored automatically.
            self.smpl = SMPLX(config.SMPL_MODEL_DIR,
                              batch_size=self.options.batch_size,
                              create_transl=False).to(self.device)
        else:
            # Original SMPL model
            self.smpl = SMPL(config.SMPL_MODEL_DIR,
                             batch_size=self.options.batch_size,
                             create_transl=False).to(self.device)

        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss
        # No reduction because confidence weighting needs to be applied
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss().to(self.device)
        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        self.focal_length = constants.FOCAL_LENGTH

        # Initialize SMPLify fitting module
        self.smplify = SMPLify(step_size=1e-2,
                               batch_size=self.options.batch_size,
                               num_iters=self.options.num_smplify_iters,
                               focal_length=self.focal_length)
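        # (SPIN-style training runs SMPLify in the loop: the fitted SMPL
        # parameters supervise the regressor, and improved network predictions
        # re-initialize the fitting via the fits dictionary loaded below.)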
        if self.options.pretrained_checkpoint is not None:
            print(">>> Loading pretrained model: {}".format(
                self.options.pretrained_checkpoint))
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)
            self.backupModel()

        # Multi-GPU training is assumed; verify this here, after the model
        # has been loaded.
        assert torch.cuda.device_count() > 1
        print("Let's use", torch.cuda.device_count(), "GPUs!")

        # Load dictionary of fits
        self.fits_dict = FitsDict(self.options, self.train_ds)

        # Renderer is disabled here; for visualization, construct e.g.
        # Renderer(focal_length=self.focal_length, img_res=self.options.img_res, faces=self.smpl.faces)
        self.renderer = None

        # Debug helper: invert the input image normalization for visualization.
        from torchvision.transforms import Normalize
        self.de_normalize_img = Normalize(
            mean=[-m / s for m, s in zip(constants.IMG_NORM_MEAN,
                                         constants.IMG_NORM_STD)],
            std=[1 / s for s in constants.IMG_NORM_STD])
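
For context, here is a minimal sketch of how a trainer exposing an init_fn like the one above might be launched with submitit. The Trainer class, its options argument, and the resource numbers are illustrative assumptions, not part of the original code; only the submitit calls themselves are standard API.

import submitit

# Hypothetical trainer object whose init_fn runs inside the SLURM job;
# submitit pickles the callable and executes it on each allocated task.
trainer = Trainer(options)  # assumed class, not shown above

executor = submitit.AutoExecutor(folder="submitit_logs")
executor.update_parameters(
    timeout_min=60 * 24,
    nodes=1,
    tasks_per_node=2,   # one task per GPU
    gpus_per_node=2,
)
job = executor.submit(trainer)
print(job.job_id)  # inside the job, submitit.JobEnvironment() is populated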
Example #2
        # Alternative checkpoints, kept for reference but overridden below:
        # params = [
        #     '--checkpoint',
        #     'spinmodel_shared/11-13-78679-bab_spin_mlc3d_fter60_ag-9589/checkpoints/2019_11_14-02_14_28-best-55.79321086406708.pt'
        # ]  # Ours, 3D + augmentation
        # params = [
        #     '--checkpoint',
        #     'logs/11-13-78679-bab_spin_mlc3d_fter60-7183/checkpoints/2019_11_14-08_12_35-best-56.12510070204735.pt'
        # ]  # Ours, 3D (no augmentation)
        params = ['--checkpoint', 'data/model_checkpoint.pt']  # Original

        params += ['--dataset', '3dpw-vibe']
        # params += ['--num_workers', '0']  # argparse expects string values

        args = parser.parse_args(params)
        args.batch_size = 64
        args.num_workers = 4

    # Build the HMR model and load the selected checkpoint for evaluation.
    model = hmr(config.SMPL_MEAN_PARAMS)
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['model'], strict=False)
    model.cuda()
    model.eval()

    # Setup evaluation dataset
    dataset = BaseDataset(None, '3dpw', is_train=False, bMiniTest=False,
                          bEnforceUpperOnly=False)
    # dataset = BaseDataset(None, '3dpw-crop', is_train=False, bMiniTest=False,
    #                       bEnforceUpperOnly=True)  # cropped upper-body variant

    # Run evaluation
    run_evaluation(model, '3dpw', dataset, args.result_file,
                   batch_size=args.batch_size,
                   shuffle=args.shuffle,
                   log_freq=args.log_freq, num_workers=args.num_workers)
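
The run_evaluation routine invoked above is defined elsewhere in the repository. As a rough, self-contained sketch of the core metric such 3DPW evaluations report, here is mean per-joint position error (MPJPE) with pelvis alignment; the joint layout and meter units are assumptions about the dataset, not taken from the original code.

import torch

def mpjpe_mm(pred_joints, gt_joints):
    # pred_joints, gt_joints: (B, J, 3) tensors in meters, joint 0 = pelvis.
    # Root-align both skeletons, then average the per-joint Euclidean error.
    pred = pred_joints - pred_joints[:, [0]]
    gt = gt_joints - gt_joints[:, [0]]
    per_joint = torch.sqrt(((pred - gt) ** 2).sum(dim=-1))
    return 1000.0 * per_joint.mean().item()  # meters -> millimeters

# Usage with dummy tensors:
pred = torch.randn(64, 24, 3)
gt = torch.randn(64, 24, 3)
print(mpjpe_mm(pred, gt))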