Exemplo n.º 1
0
 def _correspondence_check(self):
     """Check that the sparse-depth, ground-truth and RGB path lists line up.

     The i-th ground-truth path must equal the i-th sparse depth path with
     ``data_depth_velodyne`` replaced by ``data_depth_annotated`` and
     ``velodyne_raw`` replaced by ``groundtruth``. When ``self.load_rgb`` is
     set, the i-th entry of ``self.rgb_paths`` must equal the path rebuilt
     from the sparse path as ``self.rgb_path/<date_long[:10]>/<date_long>/<id>``.
     Here ``date_long`` is the text between ``/<dataset_type>/`` and
     ``/proj_depth/``, and ``<id>`` is the text after ``/velodyne_raw/`` with
     every '/' turned into '/data/'.

     Prints a 'Dataset is complete' notification when all checks pass.

     Raises:
         AssertionError: if the path lists differ in length, a sparse path
             lacks the expected components, or any pair of paths disagrees.
             Raised explicitly instead of via ``assert`` so the check still
             runs under ``python -O``.
     """
     sparse_paths = list(self.sparse_depth_paths)
     gt_paths = list(self.gt_paths)
     rgb_paths = list(self.rgb_paths) if self.load_rgb else None

     # zip() stops at the shortest input, so a missing file would otherwise
     # go unnoticed; compare the list lengths up front.
     lengths = [len(sparse_paths), len(gt_paths)]
     if rgb_paths is not None:
         lengths.append(len(rgb_paths))
     if len(set(lengths)) != 1:
         raise AssertionError(
             'Path lists differ in length (sparse, gt[, rgb]): {0}'.format(
                 lengths))

     date_pattern = '/{0}/(.*)/proj_depth/'.format(self.dataset_type)
     for idx, (sparse_img_path, gt_img_path) in enumerate(
             zip(sparse_paths, gt_paths)):
         # check img_id, camera_id, footage_id
         expected_gt_path = sparse_img_path.replace(
             'data_depth_velodyne', 'data_depth_annotated').replace(
                 'velodyne_raw', 'groundtruth')
         if gt_img_path != expected_gt_path:
             raise AssertionError(
                 'Ground-truth path mismatch: {0} (expected {1})'.format(
                     gt_img_path, expected_gt_path))
         if rgb_paths is None:
             continue

         date_match = re.search(date_pattern, sparse_img_path)
         id_match = re.search('/velodyne_raw/(.*)', sparse_img_path)
         if date_match is None or id_match is None:
             raise AssertionError(
                 'Unexpected sparse depth path layout: {0}'.format(
                     sparse_img_path))
         date_long = date_match[1]
         date = date_long[:10]
         expected_rgb_path = os.path.join(self.rgb_path, date, date_long,
                                          id_match[1].replace('/', '/data/'))
         if rgb_paths[idx] != expected_rgb_path:
             raise AssertionError(
                 'RGB path mismatch: {0} (expected {1})'.format(
                     rgb_paths[idx], expected_rgb_path))
     utils.print_notification(['Dataset is complete'])
Exemplo n.º 2
0
def print_opt(opt):
    """Print all options held by ``opt`` under an 'OPTIONS' heading.

    One line is emitted per attribute of ``opt`` (as listed by ``vars``),
    sorted by name. Each line is the name right-justified to 25 columns,
    two spaces, then ``str()`` of the value.
    """
    rows = [
        '{0}  {1}'.format(name.rjust(25, ' '), str(getattr(opt, name)))
        for name in sorted(vars(opt))
    ]
    utils.print_notification(rows, 'OPTIONS')
Exemplo n.º 3
0
 def check_options(self):
     """Abort unless ``self.opt.guess_model`` is ``'init_guess'``.

     HomographyInference only accepts ``init_guess`` as its upstream guess
     model. Any other value prints an ERROR notification, including the
     configured value, and exits with status 1.
     """
     if self.opt.guess_model == 'init_guess':
         return
     content_list = [
         'HomographyInference currently only support init_guess as upstream',
         'guess_model in options: {0}'.format(self.opt.guess_model),
     ]
     utils.print_notification(content_list, 'ERROR')
     # SystemExit directly: the builtin exit() comes from the site module,
     # is meant for the interactive shell, and also closes sys.stdin.
     raise SystemExit(1)
Exemplo n.º 4
0
 def check_options(self):
     """Abort unless the configured ``guess_model`` matches this class.

     Compares ``self.opt.guess_model`` with ``self.name``. On mismatch it
     prints an ERROR notification naming both values and exits with status 1.
     """
     if self.opt.guess_model == self.name:
         return
     content_list = [
         'You are not using the correct class for training or eval',
         'guess_model in options: {0}, current guess_model class: {1}'.format(
             self.opt.guess_model, self.name),
     ]
     utils.print_notification(content_list, 'ERROR')
     # SystemExit directly: the builtin exit() comes from the site module,
     # is meant for the interactive shell, and also closes sys.stdin.
     raise SystemExit(1)
Exemplo n.º 5
0
 def _verify_checkpoint(self, checkpoint):
     """Abort if the checkpoint's ``prevent_neg`` differs from the options.

     Args:
         checkpoint: mapping loaded from a checkpoint. It must contain a
             ``'prevent_neg'`` entry; a missing key raises ``KeyError``.

     On mismatch with ``self.opt.prevent_neg``, prints an ERROR notification
     naming both values and exits with status 1.
     """
     saved_prevent_neg = checkpoint['prevent_neg']
     if saved_prevent_neg == self.opt.prevent_neg:
         return
     content_list = [
         'Prevent negative method are different between the checkpoint and user options',
         'checkpoint: {0}, options: {1}'.format(saved_prevent_neg,
                                                self.opt.prevent_neg),
     ]
     utils.print_notification(content_list, 'ERROR')
     # SystemExit directly: the builtin exit() comes from the site module,
     # is meant for the interactive shell, and also closes sys.stdin.
     raise SystemExit(1)
 def check_options(self):
     """Validate the End2EndOptim options.

     ``self.opt.error_model`` must be one of the supported optimization
     objectives, and ``self.opt.optim_iters`` must be greater than 0.
     Otherwise an ERROR notification is printed and the process exits with
     status 1.
     """
     valid_models = ['loss_surface']
     if self.opt.error_model not in valid_models:
         content_list = [
             'End2EndOptim current only support {0} as optimization objective'
             .format(valid_models)
         ]
         utils.print_notification(content_list, 'ERROR')
         raise SystemExit(1)
     # Explicit check rather than ``assert``: asserts are stripped under
     # ``python -O``. ``not (x > 0)`` also rejects NaN.
     if not (self.opt.optim_iters > 0):
         content_list = [
             'optimization iterations should be larger than 0',
             'optim_iters in options: {0}'.format(self.opt.optim_iters),
         ]
         utils.print_notification(content_list, 'ERROR')
         raise SystemExit(1)
Exemplo n.º 7
0
 def make_value_positive(self, x):
     """Map ``x`` to non-negative values using the ``self.prevent_neg`` method.

     Only ``'sigmoid'`` is supported; it returns ``torch.sigmoid(x)``. Any
     other method prints an ERROR notification and exits with status 1.
     """
     if self.prevent_neg == 'sigmoid':
         return torch.sigmoid(x)
     content_list = [
         'Unknown prevent_neg method: {0}'.format(self.prevent_neg)
     ]
     utils.print_notification(content_list, 'ERROR')
     # SystemExit directly: the builtin exit() comes from the site module,
     # is meant for the interactive shell, and also closes sys.stdin.
     raise SystemExit(1)
Exemplo n.º 8
0
    def create_resnet_config(self):
        """Build the ResNet backbone config via the parent class.

        Loading ImageNet-pretrained weights is not supported for
        LossSurfaceRegressor. If ``self.opt`` has a truthy
        ``imagenet_pretrain``, an ERROR notification is printed and the
        process exits with status 1.
        """
        if getattr(self.opt, 'imagenet_pretrain', False):
            content_list = [
                'LossSurfaceRegressor do not support imagenet pretrained weights loading'
            ]
            utils.print_notification(content_list, 'ERROR')
            # SystemExit directly: the builtin exit() comes from the site
            # module, is meant for the interactive shell, and closes stdin.
            raise SystemExit(1)
        return super().create_resnet_config()
Exemplo n.º 9
0
 def check_options(self):
     """Abort unless the configured ``error_model`` matches this class.

     Compares ``self.opt.error_model`` with ``self.name``. On mismatch it
     prints an ERROR notification naming both values and exits with status 1.
     """
     if self.opt.error_model == self.name:
         return
     content_list = [
         'You are not using the correct class for training or eval',
         'error_model in options: {0}, current error_model class: {1}'.format(
             self.opt.error_model, self.name),
     ]
     utils.print_notification(content_list, 'ERROR')
     # SystemExit directly: the builtin exit() comes from the site module,
     # is meant for the interactive shell, and also closes sys.stdin.
     raise SystemExit(1)
Exemplo n.º 10
0
 def _get_rgb_paths(self):
     """Build ``self.rgb_paths``: one existing RGB image per sparse depth path.

     For each entry of ``self.sparse_depth_paths``:

     - ``date_long`` is the text between ``/<dataset_type>/`` and
       ``/proj_depth/``.
     - The image id is the text after ``/velodyne_raw/``.
     - The RGB path is ``self.rgb_path/<date_long[:10]>/<date_long>/<id>``,
       where every '/' in the id is replaced by '/data/'.

     Prints an ERROR notification naming the offending file and exits with
     status 1 if a sparse path does not have this layout or the RGB file does
     not exist.
     """
     date_pattern = '/{0}/(.*)/proj_depth/'.format(self.dataset_type)
     self.rgb_paths = []
     for fname in self.sparse_depth_paths:
         date_match = re.search(date_pattern, fname)
         id_match = re.search('/velodyne_raw/(.*)', fname)
         if date_match is None or id_match is None:
             content_list = ['Unexpected sparse depth path layout', fname]
             utils.print_notification(content_list, 'ERROR')
             raise SystemExit(1)
         date_long = date_match[1]
         date = date_long[:10]
         rgb_img_path = os.path.join(self.rgb_path, date, date_long,
                                     id_match[1].replace('/', '/data/'))
         if not os.path.isfile(rgb_img_path):
             content_list = ['Cannot find corresponding RGB images',
                             rgb_img_path]
             utils.print_notification(content_list, 'ERROR')
             # SystemExit directly: the builtin exit() is meant for the
             # interactive shell and also closes sys.stdin.
             raise SystemExit(1)
         self.rgb_paths.append(rgb_img_path)
Exemplo n.º 11
0
def check_prevent_neg(opt):
    """Abort if ``opt.prevent_neg`` differs from the error model's weights.

    The check only runs when ``opt`` has a ``prevent_neg`` attribute and a
    truthy ``load_weights_error_model``. In that case the ``prevent_neg``
    value stored in ``<out_dir>/<load_weights_error_model>/params.json`` must
    equal ``opt.prevent_neg``. Otherwise an ERROR notification naming both
    values is printed and the process exits with status 1.
    """
    weights_name = getattr(opt, 'load_weights_error_model', None)
    # A None/empty name means no error-model weights are being loaded.
    # Joining it into a path would raise (None) or read the wrong
    # params.json ('').
    if not hasattr(opt, 'prevent_neg') or not weights_name:
        return
    json_path = os.path.join(opt.out_dir, weights_name, 'params.json')
    with open(json_path, 'r', encoding='utf-8') as f:
        model_config = json.load(f)
    weights_prevent_neg = model_config['prevent_neg']
    if weights_prevent_neg != opt.prevent_neg:
        content_list = [
            'Prevent negative method are different between the checkpoint and user options',
            'checkpoint: {0}, options: {1}'.format(weights_prevent_neg,
                                                   opt.prevent_neg),
        ]
        utils.print_notification(content_list, 'ERROR')
        # SystemExit directly: the builtin exit() is meant for the
        # interactive shell and also closes sys.stdin.
        raise SystemExit(1)
Exemplo n.º 12
0
def check_pretrained_weights(opt):
    """Validate the ``opt.load_weights`` option.

    Does nothing when ``opt.load_weights`` is falsy. Otherwise it must not be
    combined with a truthy ``opt.resume``, and
    ``<out_dir>/<load_weights>/checkpoint.pth.tar`` must exist. On violation
    an ERROR notification is printed and the process exits with status 1.
    """
    if not opt.load_weights:
        return
    if getattr(opt, 'resume', False):
        utils.print_notification(['Resume or load weights, make your choice'],
                                 'ERROR')
        raise SystemExit(1)
    weights_path = os.path.join(opt.out_dir, opt.load_weights,
                                'checkpoint.pth.tar')
    if not os.path.exists(weights_path):
        content_list = ['Cannot find pretrained weights', weights_path]
        utils.print_notification(content_list, 'ERROR')
        # SystemExit directly: the builtin exit() is meant for the
        # interactive shell and also closes sys.stdin.
        raise SystemExit(1)
Exemplo n.º 13
0
 def print_resnet_config(self, resnet_config):
     """Print a summary of the ResNet backbone settings for ``self.name``.

     The summary reports ``resnet_config.need_spectral_norm`` and the
     normalization layer: BN when ``resnet_config.group_norm == 0``,
     otherwise GN with that many groups. It also reports
     ``resnet_config.pretrained``.
     """
     summary = [
         'Resnet backbone config for {0}'.format(self.name),
         'Spectral norm for resnet: {0}'.format(
             resnet_config.need_spectral_norm),
         ('Using BN for resnet' if resnet_config.group_norm == 0 else
          'Using GN for resnet, number of groups: {0}'.format(
              resnet_config.group_norm)),
         'Imagenet pretrain weights for resnet: {0}'.format(
             resnet_config.pretrained),
     ]
     utils.print_notification(summary)
Exemplo n.º 14
0
    def load_pretrained_weights(self):
        '''load pretrained weights
        this function can load weights from another model.

        Steps:

        - Fetch the checkpoint from ``self._get_checkpoint_path()``.
        - Check it with ``self._verify_checkpoint``.
        - Load ``checkpoint['model_state_dict']``.

        To handle the DataParallel ``module.`` key prefix, the state dict is
        tried as saved, then with a leading ``module.`` stripped from each
        key, then with ``module.`` prepended. If all three ``load_state_dict``
        attempts raise ``RuntimeError``, an ERROR notification is printed and
        the process exits with status 1.
        '''
        # 1. load check point
        checkpoint_path = self._get_checkpoint_path()
        checkpoint = self._load_checkpoint(checkpoint_path)

        # 2. verify check point
        self._verify_checkpoint(checkpoint)

        # 3. try loading weights
        saved_weights = checkpoint['model_state_dict']
        prefix = 'module.'

        def _candidate_state_dicts():
            # Yield the key variants lazily, in the order they are tried.
            yield saved_weights
            # Strip only a *leading* prefix: str.replace would also mangle
            # keys such as 'submodule.conv.weight' -> 'subconv.weight'.
            yield {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in saved_weights.items()
            }
            yield {prefix + k: v for k, v in saved_weights.items()}

        for weights in _candidate_state_dicts():
            try:
                self.load_state_dict(weights)
            except RuntimeError:
                pass
            else:
                break
        else:
            content_list = ['Cannot load weights for {0}'.format(self.name)]
            utils.print_notification(content_list, 'ERROR')
            # SystemExit directly: the builtin exit() is meant for the
            # interactive shell and also closes sys.stdin.
            raise SystemExit(1)

        # 4. loaded
        content_list = [
            'Weights loaded for {0}'.format(self.name),
            'From: {0}'.format(checkpoint_path),
        ]
        utils.print_notification(content_list)
Exemplo n.º 15
0
def check_pretrained_weights(opt):
    """Verify that every requested pretrained checkpoint exists.

    For each of ``opt.load_weights_upstream`` and
    ``opt.load_weights_error_model`` that is present and truthy,
    ``<out_dir>/<name>/checkpoint.pth.tar`` must exist. Otherwise an ERROR
    notification is printed and the process exits with status 1.

    If error-model weights are requested and ``opt`` has an ``error_model``
    attribute, ``check_prevent_neg(opt)`` is then run once.
    """
    pretrained_weights_option_list = [
        'load_weights_upstream', 'load_weights_error_model'
    ]
    for pretrained_weights_option in pretrained_weights_option_list:
        weights_name = getattr(opt, pretrained_weights_option, None)
        if not weights_name:
            continue
        weights_path = os.path.join(opt.out_dir, weights_name,
                                    'checkpoint.pth.tar')
        if not os.path.exists(weights_path):
            content_list = [
                'Cannot find pretrained weights for {0}, at {1}'.format(
                    pretrained_weights_option, weights_path)
            ]
            utils.print_notification(content_list, 'ERROR')
            # SystemExit directly: the builtin exit() is meant for the
            # interactive shell and also closes sys.stdin.
            raise SystemExit(1)
    # prevent_neg compatibility concerns only the error-model weights. Check
    # it once, and only after their checkpoint is known to exist, instead of
    # for every option. Otherwise it would also run when only upstream
    # weights are set.
    if (getattr(opt, 'load_weights_error_model', None)
            and hasattr(opt, 'error_model')):
        check_prevent_neg(opt)
Exemplo n.º 16
0
def resnet18(opt, pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        opt: options object. It is passed as the first argument to ``ResNet``,
            and ``opt.group_norm`` is forwarded as ``group_norm``.
        pretrained (bool): If True, initialise from the weights at
            ``model_urls['resnet18']``. A strict ``load_state_dict`` is tried
            first. If it raises ``RuntimeError``, only keys that exist in this
            model are copied, ``fc.weight``/``fc.bias`` excluded, and the
            loaded key names are reported.
        **kwargs: forwarded to ``ResNet``.
    """
    model = ResNet(opt,
                   BasicBlock, [2, 2, 2, 2],
                   group_norm=opt.group_norm,
                   **kwargs)
    if not pretrained:
        return model

    # Fetch the weights once; the strict and partial loads share them.
    pretrained_dict = model_zoo.load_url(model_urls['resnet18'])
    try:
        model.load_state_dict(pretrained_dict)
    except RuntimeError:
        # Key or shape mismatch with this architecture: fall back to a
        # partial load. (Only RuntimeError is caught; a bare except would
        # also swallow KeyboardInterrupt.)
        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. pop-out fc
        pretrained_dict.pop('fc.weight', None)
        pretrained_dict.pop('fc.bias', None)
        # 3. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 4. load the new state dict
        model.load_state_dict(model_dict)
        content_list = [
            'Imagenet pretrained weights partially loaded',
            str(pretrained_dict.keys()),
        ]
        utils.print_notification(content_list)
    else:
        utils.print_notification(['Imagenet pretrained weights fully loaded'])
    return model
Exemplo n.º 17
0
def check_existing(opt):
    """Report whether the output paths ``opt.out`` / ``opt.tfb_out`` exist.

    If either exists, prints a WARNING listing both paths and whether each
    exists. Otherwise it prints a 'New model, no history' notice. When
    neither exists but ``opt.resume`` is truthy, it prints an ERROR and exits
    with status 1.
    """
    out_exists = os.path.exists(opt.out)
    tfb_exists = os.path.exists(opt.tfb_out)
    existing = out_exists or tfb_exists
    if existing:
        content_list = [
            opt.out, str(out_exists),
            opt.tfb_out, str(tfb_exists),
            'Found existing checkpoint and log',
        ]
        utils.print_notification(content_list, 'WARNING')
    else:
        utils.print_notification(['New model, no history'])
    # Truthiness rather than ``is True``, so any truthy resume flag counts
    # (consistent with the other option checks).
    if not existing and opt.resume:
        utils.print_notification(['No history, cannot resume'], 'ERROR')
        # SystemExit directly: the builtin exit() is meant for the
        # interactive shell and also closes sys.stdin.
        raise SystemExit(1)