コード例 #1
0
ファイル: deform_conv.py プロジェクト: zlx-6/yolodet-pytorch
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Migrate legacy offset-conv keys, then defer to the parent loader.

        When ``local_metadata`` carries no ``version`` (or a version below 2),
        the offset convolution parameters may be stored under
        ``<name>_offset.weight`` / ``<name>_offset.bias``, where ``<name>`` is
        ``prefix`` without its last character. If the current
        ``<prefix>conv_offset.*`` key is absent and the legacy key is present,
        the entry is moved to the current key inside ``state_dict`` (in place).

        All arguments are forwarded unchanged to the parent class's
        ``_load_from_state_dict`` afterwards.
        """
        version = local_metadata.get('version', None)

        if version is None or version < 2:
            # Early checkpoints named the offset conv differently; rename the
            # entries so the parent loader can find them.
            legacy_stem = prefix[:-1] + '_offset.'
            current_stem = prefix + 'conv_offset.'
            for param in ('weight', 'bias'):
                current_key = current_stem + param
                legacy_key = legacy_stem + param
                if current_key in state_dict:
                    continue
                if legacy_key in state_dict:
                    state_dict[current_key] = state_dict.pop(legacy_key)

        # NOTE(review): this message is emitted when the checkpoint is already
        # at version > 1, i.e. when no key migration was attempted above --
        # confirm that this is the intended condition.
        if version is not None and version > 1:
            module_name = prefix.rstrip('.')
            Logging.getLogger().info(
                'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
                    module_name))

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)
コード例 #2
0
ファイル: train.py プロジェクト: molly948/yolodet-pytorch
def train_detector(model,
                   dataset,
                   cfg,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Build data loaders and a ``Runner`` for ``model`` and start training.

    Args:
        model: Detector to train. It is moved to ``cfg.gpu_ids[0]`` when CUDA
            is available and wrapped in ``DataParallel`` when more than one
            CUDA device is visible; otherwise it stays on the CPU.
        dataset: A single dataset or a list/tuple of datasets; one data loader
            is built per dataset.
        cfg: Configuration object. This function reads ``data``, ``gpu_ids``,
            ``optimizer``, ``work_dir``, ``lr_config``, ``optimizer_config``,
            ``checkpoint_config``, ``log_config``, ``resume_from``,
            ``load_from``, ``workflow``, ``total_epochs`` and, optionally,
            ``ema`` and ``evaluation``.
        validate (bool): When true, register an ``EvalHook`` built from
            ``cfg.data.val``.
        timestamp: Value stored on ``runner.timestamp``.
        meta: Forwarded to the ``Runner`` constructor.
    """
    logger = Logging.getLogger()

    # One data loader per dataset; a lone dataset is treated as a 1-item list.
    if not isinstance(dataset, (list, tuple)):
        dataset = [dataset]
    data_loaders = []
    for ds in dataset:
        data_loaders.append(build_dataloader(ds, data=cfg.data))

    # Device placement.
    if not torch.cuda.is_available():
        model.device = 'cpu'
    else:
        model = model.cuda(cfg.gpu_ids[0])
        model.device = cfg.gpu_ids[0]
        if torch.cuda.device_count() > 1:
            model = DataParallel(model, device_ids=cfg.gpu_ids)

    # Build the runner; EMA is optional in the config.
    ema = cfg.ema if 'ema' in cfg else None
    runner = Runner(model,
                    batch_processor,
                    cfg.optimizer,
                    cfg.work_dir,
                    logger=logger,
                    meta=meta,
                    ema=ema)
    # An ugly workaround to make the .log and .log.json filenames the same.
    runner.timestamp = timestamp

    # The eval hook has to be registered before the logging hooks, otherwise
    # the evaluation results are not printed to the log.
    if validate:
        cfg.data.val.train = False
        val_dataset = build_from_dict(cfg.data.val, DATASET)
        val_dataloader = build_dataloader(val_dataset,
                                          shuffle=False,
                                          data=cfg.data)
        eval_cfg = cfg.get('evaluation', {})
        from yolodet.models.hooks.eval_hook import EvalHook
        eval_hook = EvalHook(val_dataloader, **eval_cfg)
        runner.register_hook(eval_hook)

    # Remaining training hooks (lr schedule, optimizer, checkpoint, logging).
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    # Resuming takes precedence over loading plain weights.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
コード例 #3
0
    def __init__(self,
                 model,
                 batch_processor,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 ema=None):
        """Set up runner state, working directory, logger, optimizer and EMA.

        Args:
            model: Model to run. If it exposes a ``module`` attribute, the
                class name of that inner object is recorded as the model name.
            batch_processor: Callable stored on the instance.
            optimizer: Passed to ``self.init_optimizer`` when not ``None``.
            work_dir (str | None): Working directory. A string is converted to
                an absolute path and created if needed.
            logger: Logger to use; ``Logging.getLogger()`` when ``None``.
            meta (dict | None): Extra information stored as ``self.meta``.
            ema: Passed to ``self.init_ema`` when not ``None``.

        Raises:
            AssertionError: If ``batch_processor`` is not callable or ``meta``
                is neither a dict nor ``None``.
            TypeError: If ``work_dir`` is neither a str nor ``None``.
        """
        # Validate arguments up front.
        assert callable(batch_processor)
        if meta is not None:
            assert isinstance(meta, dict), '"meta" must be a dict or None'

        self.model = model
        self.meta = meta
        self.batch_processor = batch_processor

        # Progress bookkeeping, all starting from zero.
        self.mode = None
        self._hooks = []
        self._epoch = self._iter = self._inner_iter = 0
        self._max_epochs = self._max_iters = 0
        self._warmup_max_iters = 0
        self._momentum = 0

        # Resolve (and create) the working directory.
        if work_dir is None:
            self.work_dir = None
        elif isinstance(work_dir, str):
            self.work_dir = osp.abspath(work_dir)
            file_utils.mkdir_or_exist(self.work_dir)
        else:
            raise TypeError('"work_dir" must be a str or None')

        # Look through a wrapper exposing ``module`` to name the real model.
        named_model = (self.model.module
                       if hasattr(self.model, 'module') else self.model)
        self._model_name = named_model.__class__.__name__

        self.logger = Logging.getLogger() if logger is None else logger
        self.log_buffer = LogBuffer()

        self.optimizer = (None if optimizer is None else
                          self.init_optimizer(optimizer))
        self.ema = None if ema is None else self.init_ema(ema)
コード例 #4
0
ファイル: checkpoint.py プロジェクト: zlx-6/yolodet-pytorch
def load_checkpoint(model, filename, strict=False, map_location='cpu'):
    """Load weights from a checkpoint file into ``model``.

    Two checkpoint layouts are accepted:

    * an ``OrderedDict`` that is itself the state dict;
    * a ``dict`` with a ``'model'`` entry whose ``state_dict()`` is used.

    If the first key of the state dict starts with ``'module.'``, that prefix
    is removed from every key that carries it before loading.

    Args:
        model: Target model. If it has a ``module`` attribute, the weights are
            loaded into ``model.module`` instead.
        filename (str): Path of the checkpoint file.
        strict (bool): Forwarded to ``load_state_dict``.
        map_location: Forwarded to ``torch.load``.

    Returns:
        The object returned by ``torch.load`` (the raw checkpoint).

    Raises:
        IOError: If ``filename`` is not an existing file.
        RuntimeError: If no state dict can be found in the checkpoint.
    """
    logger = Logging.getLogger()
    logger.info('load checkpoint from %s', filename)
    if not osp.isfile(filename):
        raise IOError('{} is not a checkpoint file'.format(filename))
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(filename, map_location=map_location)

    # Locate the state dict inside the checkpoint.
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model'].state_dict()
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))

    # Strip the 'module.' prefix. The keys must come from the state dict that
    # was just extracted: neither accepted checkpoint layout is guaranteed to
    # have a 'state_dict' entry. An empty state dict is left untouched.
    wrapper_prefix = 'module.'
    first_key = next(iter(state_dict), None)
    if first_key is not None and first_key.startswith(wrapper_prefix):
        state_dict = {
            (key[len(wrapper_prefix):]
             if key.startswith(wrapper_prefix) else key): value
            for key, value in state_dict.items()
        }

    # Load into the wrapped module when the model is a wrapper.
    target = model.module if hasattr(model, 'module') else model
    load_state_dict(target, state_dict, strict)

    return checkpoint
コード例 #5
0
def main():
    """Parse CLI arguments, build the model and datasets, and start training.

    The configuration file named by ``args.config`` is loaded and then
    overridden by the ``work_dir``, ``resume_from`` and ``device`` command-line
    arguments. Environment information, batch settings and the random seed are
    collected into a ``meta`` dict that is handed to ``train_detector``
    together with the model and the datasets.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    # args.device may be None, in which case cfg.device is set to None too.
    cfg.device = args.device
    device = select_device(cfg.device)
    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # create work_dir
    file_utils.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # init the logger before other steps
    logger = Logging.getLogger()

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = {}
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join('{}: {}'.format(k, v)
                         for k, v in env_info_dict.items())
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['batch_size'] = cfg.data.batch_size
    meta['subdivisions'] = cfg.data.subdivisions
    meta['multi_scale'] = args.multi_scale
    # log some basic info
    logger.info('Config:\n{}'.format(cfg.text))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            args.seed, args.deterministic))
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    model = build_from_dict(cfg.model, DETECTORS)

    # Use .to() rather than .cuda(): the check below expects that ``device``
    # may be a CPU device, and Module.cuda() rejects non-CUDA devices.
    model = model.to(device)
    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # Set after the optional wrapping so the attribute lives on the object
    # that is passed on to train_detector.
    model.device = device

    datasets = [build_from_dict(cfg.data.train, DATASET)]
    if len(cfg.workflow) == 2:
        # A two-stage workflow also needs a dataset for the second stage; it
        # is built from the val config but with the training pipeline.
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_from_dict(val_dataset, DATASET))
    if cfg.checkpoint_config is not None:
        # save the config file content and class names in checkpoints as
        # meta data
        cfg.checkpoint_config.meta = dict(config=cfg.text,
                                          CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    train_detector(model,
                   datasets,
                   cfg,
                   validate=args.validate,
                   timestamp=timestamp,
                   meta=meta)
コード例 #6
0
ファイル: checkpoint.py プロジェクト: zlx-6/yolodet-pytorch
def load_state_dict(module, state_dict, strict=False):
    """Load ``state_dict`` into ``module`` and report any key mismatch.

    Modified from :meth:`torch.nn.Module.load_state_dict`. ``strict``
    defaults to ``False``, and mismatches are reported even when it is
    ``False``.

    Every module in the tree is visited parent-first and its
    ``_load_from_state_dict`` is called on a shallow copy of ``state_dict``
    (the caller's mapping is not modified), together with the metadata entry
    for that module's prefix when the state dict carries ``_metadata``.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): If true, raise when the keys do not match or any
            module reported an error; otherwise the report is only logged
            through ``Logging.getLogger()`` as a warning.
            Default: ``False``.

    Raises:
        RuntimeError: If ``strict`` is true and a mismatch was found.
    """
    logger = Logging.getLogger()
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    # Copy so that key edits made while loading stay local, and carry the
    # version metadata over to the copy.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # Pre-order walk with an explicit stack; _load_from_state_dict is used to
    # enable checkpoint version control.
    pending = [(module, '')]
    while pending:
        current, current_prefix = pending.pop()
        if metadata is None:
            local_metadata = {}
        else:
            local_metadata = metadata.get(current_prefix[:-1], {})
        current._load_from_state_dict(state_dict, current_prefix,
                                      local_metadata, True, all_missing_keys,
                                      unexpected_keys, err_msg)
        children = [(child, current_prefix + child_name + '.')
                    for child_name, child in current._modules.items()
                    if child is not None]
        # Reversed so that the first child is popped (and visited) first.
        pending.extend(reversed(children))

    # ignore "num_batches_tracked" of BN layers
    missing_keys = []
    for key in all_missing_keys:
        if 'num_batches_tracked' not in key:
            missing_keys.append(key)

    if unexpected_keys:
        err_msg.append('unexpected key in source state_dict: {}\n'.format(
            ', '.join(unexpected_keys)))
    if missing_keys:
        err_msg.append('missing keys in source state_dict: {}\n'.format(
            ', '.join(missing_keys)))

    if not err_msg:
        return

    header = 'The model and loaded state dict do not match exactly\n'
    report = '\n'.join([header] + err_msg)
    if strict:
        raise RuntimeError(report)
    if logger is not None:
        logger.warning(report)
    else:
        print(report)