Code Example #1
File: config.py  Project: PaddlePaddle/PaddleSeg
    def learning_rate(self) -> paddle.optimizer.lr.LRScheduler:
        logger.warning(
            '''`learning_rate` in the configuration file will be deprecated; please use `lr_scheduler` instead, e.g.
            lr_scheduler:
                type: PolynomialDecay
                learning_rate: 0.01''')
        _learning_rate = self.dic.get('learning_rate', {}).get('value')
        if not _learning_rate:
            raise RuntimeError(
                'No learning rate specified in the configuration file.')

        args = self.decay_args
        decay_type = args.pop('type')

        if decay_type == 'poly':
            lr = _learning_rate
            return paddle.optimizer.lr.PolynomialDecay(lr, **args)
        elif decay_type == 'piecewise':
            values = _learning_rate
            return paddle.optimizer.lr.PiecewiseDecay(values=values, **args)
        elif decay_type == 'stepdecay':
            lr = _learning_rate
            return paddle.optimizer.lr.StepDecay(lr, **args)
        else:
            raise RuntimeError(
                'Only poly, piecewise and stepdecay decay types are supported.')
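
For reference, the three branches above dispatch to Paddle's built-in schedulers. Below is a minimal sketch of constructing them directly; the numeric arguments are illustrative assumptions, not values taken from any real PaddleSeg configuration.

import paddle

# Illustrative hyperparameters only.
poly = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.01, decay_steps=1000)
piecewise = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[1000, 2000], values=[0.01, 0.001, 0.0001])
step = paddle.optimizer.lr.StepDecay(learning_rate=0.01, step_size=500, gamma=0.1)

# A scheduler object is then passed to an optimizer as its learning rate.
optimizer = paddle.optimizer.Momentum(
    learning_rate=poly, parameters=paddle.nn.Linear(4, 2).parameters())
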
Code Example #2
File: utils.py  Project: windstamp/PaddleSeg
def load_entire_model(model, pretrained):
    if pretrained is not None:
        if os.path.exists(pretrained):
            load_pretrained_model(model, pretrained)
        else:
            raise FileNotFoundError(
                'Pretrained model is not found: {}'.format(pretrained))
    else:
        logger.warning(
            'No pretrained weights are loaded for {}; it will train from '
            'scratch or rely only on a pretrained backbone.'.format(
                model.__class__.__name__))
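
A minimal usage sketch, assuming load_entire_model and its dependencies are in scope; the stand-in model and checkpoint path are arbitrary assumptions.

import paddle

model = paddle.nn.Sequential(paddle.nn.Conv2D(3, 8, 3), paddle.nn.ReLU())

# Save a checkpoint first so there is something real to load (assumed path).
paddle.save(model.state_dict(), '/tmp/demo_model.pdparams')
load_entire_model(model, '/tmp/demo_model.pdparams')

# With pretrained=None the helper only logs a warning and leaves the weights untouched.
load_entire_model(model, None)
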
Code Example #3
File: utils.py  Project: hysunflower/PaddleSeg
def load_pretrained_model(model, pretrained_model):
    if pretrained_model is not None:
        logger.info(
            'Loading pretrained model from {}'.format(pretrained_model))
        # download pretrained model from url
        if urlparse(pretrained_model).netloc:
            pretrained_model = unquote(pretrained_model)
            savename = pretrained_model.split('/')[-1]
            if not savename.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
                savename = pretrained_model.split('/')[-2]
            else:
                savename = savename.split('.')[0]
            with generate_tempdir() as _dir:
                with filelock.FileLock(os.path.join(seg_env.TMP_HOME,
                                                    savename)):
                    pretrained_model = download_file_and_uncompress(
                        pretrained_model,
                        savepath=_dir,
                        extrapath=seg_env.PRETRAINED_MODEL_HOME,
                        extraname=savename)

                    pretrained_model = os.path.join(pretrained_model,
                                                    'model.pdparams')

        if os.path.exists(pretrained_model):
            para_state_dict = paddle.load(pretrained_model)

            model_state_dict = model.state_dict()
            keys = model_state_dict.keys()
            num_params_loaded = 0
            for k in keys:
                if k not in para_state_dict:
                    logger.warning("{} is not in pretrained model".format(k))
                elif list(para_state_dict[k].shape) != list(
                        model_state_dict[k].shape):
                    logger.warning(
                        "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
                        .format(k, para_state_dict[k].shape,
                                model_state_dict[k].shape))
                else:
                    model_state_dict[k] = para_state_dict[k]
                    num_params_loaded += 1
            model.set_dict(model_state_dict)
            logger.info("There are {}/{} variables loaded into {}.".format(
                num_params_loaded, len(model_state_dict),
                model.__class__.__name__))

        else:
            raise ValueError(
                'The pretrained model is not found: {}'.format(
                    pretrained_model))
    else:
        logger.info(
            'No pretrained model to load, {} will be trained from scratch.'.
            format(model.__class__.__name__))
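
The heart of this helper is the name-and-shape filter: a parameter is copied only if its key exists in the checkpoint and its shape matches. Here is a self-contained sketch of that pattern with throwaway layers; the temporary path is an assumption.

import paddle

src = paddle.nn.Linear(4, 4)
dst = paddle.nn.Linear(4, 4)

paddle.save(src.state_dict(), '/tmp/demo.pdparams')
para_state_dict = paddle.load('/tmp/demo.pdparams')

model_state_dict = dst.state_dict()
loaded = 0
for k in model_state_dict.keys():
    if k in para_state_dict and list(para_state_dict[k].shape) == list(
            model_state_dict[k].shape):
        model_state_dict[k] = para_state_dict[k]
        loaded += 1
dst.set_dict(model_state_dict)
print('{}/{} parameters transferred'.format(loaded, len(model_state_dict)))
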
Code Example #4
File: seg_env.py  Project: zjhellofss/PaddleSeg
def _get_seg_home():
    if 'SEG_HOME' in os.environ:
        home_path = os.environ['SEG_HOME']
        if os.path.exists(home_path):
            if os.path.isdir(home_path):
                return home_path
            else:
                logger.warning('SEG_HOME {} is a file!'.format(home_path))
        else:
            return home_path
    return os.path.join(_get_user_home(), '.paddleseg')
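
A short sketch of the resolution order, assuming _get_seg_home (and the logger it uses) is in scope; the override path is hypothetical.

import os

os.environ['SEG_HOME'] = '/data/seg_home'   # hypothetical path
# Returned as-is if it is a directory or does not exist yet; if it exists but
# is a regular file, a warning is logged and the default below is used instead.
print(_get_seg_home())

os.environ.pop('SEG_HOME', None)
print(_get_seg_home())   # without SEG_HOME: <user home>/.paddleseg
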
Code Example #5
    def init_weight(self):
        if self.pretrained is not None:
            para_state_dict = paddle.load(self.pretrained)
            model_state_dict = self.backbone.state_dict()
            keys = model_state_dict.keys()
            num_params_loaded = 0
            for k in keys:
                k_parts = k.split('.')
                torchkey = 'backbone.' + k
                if k_parts[1] == 'layer5':
                    logger.warning("{} should not be loaded".format(k))
                elif torchkey not in para_state_dict:
                    logger.warning("{} is not in pretrained model".format(k))
                elif list(para_state_dict[torchkey].shape) != list(
                        model_state_dict[k].shape):
                    logger.warning(
                        "[SKIP] Shape of pretrained params {} doesn't match. (Pretrained: {}, Actual: {})"
                        .format(k, para_state_dict[torchkey].shape,
                                model_state_dict[k].shape))
                else:
                    model_state_dict[k] = para_state_dict[torchkey]
                    num_params_loaded += 1
            self.backbone.set_dict(model_state_dict)
            logger.info("There are {}/{} variables loaded into {}.".format(
                num_params_loaded, len(model_state_dict),
                self.backbone.__class__.__name__))
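
The distinctive step here is the 'backbone.' key remapping: every backbone parameter in the checkpoint (apparently converted from a PyTorch model, hence torchkey) carries a 'backbone.' prefix, while the local backbone's state_dict does not. A toy sketch of that remapping, with the layer5 skip and the shape check omitted; the dictionaries stand in for paddle.load(...) and backbone.state_dict().

para_state_dict = {'backbone.conv1.weight': 'w1', 'backbone.bn1.weight': 'w2'}
model_state_dict = {'conv1.weight': None, 'bn1.weight': None}

for k in model_state_dict:
    torchkey = 'backbone.' + k          # local key -> checkpoint key
    if torchkey in para_state_dict:
        model_state_dict[k] = para_state_dict[torchkey]

print(model_state_dict)   # both entries filled from the checkpoint
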
Code Example #6
File: utils.py  Project: PaddlePaddle/PaddleSeg
def load_pretrained_model(model, pretrained_model):
    if pretrained_model is not None:
        logger.info(
            'Loading pretrained model from {}'.format(pretrained_model))

        if urlparse(pretrained_model).netloc:
            pretrained_model = download_pretrained_model(pretrained_model)

        if os.path.exists(pretrained_model):
            para_state_dict = paddle.load(pretrained_model)

            model_state_dict = model.state_dict()
            keys = model_state_dict.keys()
            num_params_loaded = 0
            for k in keys:
                if k not in para_state_dict:
                    logger.warning("{} is not in pretrained model".format(k))
                elif list(para_state_dict[k].shape) != list(
                        model_state_dict[k].shape):
                    logger.warning(
                        "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
                        .format(k, para_state_dict[k].shape,
                                model_state_dict[k].shape))
                else:
                    model_state_dict[k] = para_state_dict[k]
                    num_params_loaded += 1
            model.set_dict(model_state_dict)
            logger.info("There are {}/{} variables loaded into {}.".format(
                num_params_loaded, len(model_state_dict),
                model.__class__.__name__))

        else:
            raise ValueError(
                'The pretrained model is not found: {}'.format(
                    pretrained_model))
    else:
        logger.info(
            'No pretrained model to load, {} will be trained from scratch.'.
            format(model.__class__.__name__))
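
Compared with Code Example #3, this newer revision factors the download logic out into a download_pretrained_model() helper, so the body of load_pretrained_model() only has to deal with a local .pdparams path once the URL case has been resolved.
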
Code Example #7
File: utils.py  Project: hysunflower/PaddleSeg
def load_entire_model(model, pretrained):
    if pretrained is not None:
        load_pretrained_model(model, pretrained)
    else:
        logger.warning(
            'No pretrained weights are loaded for {}; it will train from '
            'scratch or rely only on a pretrained backbone.'.format(
                model.__class__.__name__))
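
This variant differs from Code Example #2 only in that it does not check os.path.exists() itself; the existence check (and the resulting error) is left to load_pretrained_model().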