Example #1
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint has up to three fields: ``meta``, ``model`` and
    ``optimizer``. ``meta`` defaults to an empty dict when not given.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError('meta must be a dict or None, but got {}'.format(
            type(meta)))

    # make sure the directory containing the checkpoint exists
    file_utils.mkdir_or_exist(osp.dirname(filename))

    # unwrap DataParallel/DistributedDataParallel before saving
    if hasattr(model, 'module'):
        model = model.module

    checkpoint = {
        'meta': meta,
        # the whole module is serialized here, not just model.state_dict()
        'model': model,
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()

    torch.save(checkpoint, filename)
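
A usage sketch for ``save_checkpoint``; the model, optimizer and checkpoint path below are stand-ins, not part of the original snippet.

import torch
import torchvision

model = torchvision.models.resnet18(num_classes=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# writes {'meta': {...}, 'model': <module>, 'optimizer': <optimizer state dict>}
save_checkpoint(model,
                'work_dirs/demo/epoch_1.pth',
                optimizer=optimizer,
                meta={'epoch': 1})

Because the full module object is pickled, loading the file back requires the model class to be importable in the loading environment (and, on recent PyTorch versions, ``torch.load(..., weights_only=False)``).
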
Example #2
    def __init__(self,
                 model,
                 batch_processor,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 ema=None):
        assert callable(batch_processor)
        self.model = model
        if meta is not None:
            assert isinstance(meta, dict), '"meta" must be a dict or None'
        self.meta = meta

        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0
        self._warmup_max_iters = 0
        self._momentum = 0

        self.batch_processor = batch_processor

        # create work_dir
        if isinstance(work_dir, str):
            self.work_dir = osp.abspath(work_dir)
            file_utils.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__
        if logger is None:
            self.logger = Logging.getLogger()
        else:
            self.logger = logger
        self.log_buffer = LogBuffer()

        if optimizer is not None:
            self.optimizer = self.init_optimizer(optimizer)
        else:
            self.optimizer = None

        if ema is not None:
            self.ema = self.init_ema(ema)
        else:
            self.ema = None
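
A construction sketch for the initializer above. The enclosing class is not shown, so the class name ``Runner`` and the batch-processor signature are assumptions; ``optimizer`` is forwarded to ``self.init_optimizer``, whose accepted formats are project-specific.

import torch
import torch.nn as nn

def batch_processor(model, data_batch, train_mode):
    # hypothetical processor: run the model and return whatever dict
    # the surrounding training loop expects (losses, log variables, ...)
    loss = model(data_batch).mean()
    return {'loss': loss}

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

runner = Runner(model,                        # class name assumed, see note above
                batch_processor,
                optimizer=optimizer,
                work_dir='./work_dirs/demo',  # created via mkdir_or_exist
                meta={'exp_name': 'demo'})
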
Example #3
def show_result(img, result, class_names, show=True, out_file=None):
    assert class_names is None or isinstance(class_names, (tuple, list))
    assert isinstance(img, (str, np.ndarray))
    if isinstance(img, str):
        # read via np.fromfile + cv2.imdecode so non-ASCII paths also work
        img = cv2.imdecode(np.fromfile(img, dtype=np.uint8), -1)
    img = img.copy()

    for rslt in result:
        label = rslt['label']
        score = rslt['score']
        bbox_int = rslt['bbox']
        left_top = (bbox_int[0], bbox_int[1])
        right_bottom = (bbox_int[2], bbox_int[3])
        cv2.rectangle(img,
                      left_top,
                      right_bottom,
                      color=(0, 0, 255),
                      thickness=2)
        label_text = class_names[
            label] if class_names is not None else 'cls {}'.format(label)
        label_text += '|{:.02f}'.format(score)
        cv2.putText(img,
                    label_text, (bbox_int[0], bbox_int[1] - 2),
                    cv2.FONT_HERSHEY_COMPLEX,
                    fontScale=0.5,
                    color=(0, 255, 0))

    if out_file is not None:
        dir_name = osp.abspath(osp.dirname(out_file))
        mkdir_or_exist(dir_name)
        cv2.imwrite(out_file, img)
    if show:
        win_name = 'inference result images'
        wait_time = 0
        cv2.imshow(win_name, img)
        if wait_time == 0:  # prevent hanging if the window was closed
            while True:
                ret = cv2.waitKey(1)

                closed = cv2.getWindowProperty(win_name,
                                               cv2.WND_PROP_VISIBLE) < 1
                # if user closed window or if some key pressed
                if closed or ret != -1:
                    break
        else:
            ret = cv2.waitKey(wait_time)

    return img
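
``show_result`` expects ``result`` to be an iterable of dicts with ``label``, ``score`` and integer ``bbox`` (x1, y1, x2, y2) entries, as read inside the loop. A small sketch with made-up detections:

import numpy as np

class_names = ('person', 'car')
result = [
    {'label': 0, 'score': 0.92, 'bbox': [30, 40, 120, 200]},
    {'label': 1, 'score': 0.71, 'bbox': [150, 60, 300, 180]},
]

img = np.zeros((256, 320, 3), dtype=np.uint8)  # stand-in for a real image; a file path also works
drawn = show_result(img, result, class_names,
                    show=False,                        # skip the interactive window
                    out_file='work_dirs/demo/vis.jpg')
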
Example #4
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    # args.device may be None, in which case the choice is left to select_device
    cfg.device = args.device
    device = select_device(cfg.device)
    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # create work_dir
    file_utils.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # init the logger before other steps
    logger = Logging.getLogger()

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([('{}: {}'.format(k, v))
                          for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['batch_size'] = cfg.data.batch_size
    meta['subdivisions'] = cfg.data.subdivisions
    meta['multi_scale'] = args.multi_scale
    # log some basic info
    logger.info('Config:\n{}'.format(cfg.text))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            args.seed, args.deterministic))
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    model = build_from_dict(cfg.model, DETECTORS)

    model = model.cuda(device)
    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    model.device = device

    datasets = [build_from_dict(cfg.data.train, DATASET)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_from_dict(val_dataset, DATASET))
    if cfg.checkpoint_config is not None:
        # save the config file content and class names in checkpoints as
        # meta data
        cfg.checkpoint_config.meta = dict(config=cfg.text,
                                          CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    train_detector(model,
                   datasets,
                   cfg,
                   validate=args.validate,
                   timestamp=timestamp,
                   meta=meta)
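
The ``args.autoscale_lr`` branch above applies the linear scaling rule from https://arxiv.org/abs/1706.02677: the configured learning rate is scaled by the number of GPUs relative to a reference of 8. A standalone sketch of the same arithmetic (the helper name is hypothetical):

def autoscale_lr(base_lr, num_gpus, reference_gpus=8):
    # linear scaling rule: the learning rate grows proportionally with the
    # effective batch size, assumed here to be proportional to the GPU count
    return base_lr * num_gpus / reference_gpus

assert autoscale_lr(0.01, 8) == 0.01     # reference setup keeps the configured lr
assert autoscale_lr(0.01, 4) == 0.005    # half the GPUs -> half the lr
assert autoscale_lr(0.01, 16) == 0.02    # twice the GPUs -> twice the lr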