def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save a checkpoint to ``filename``.

    The checkpoint dict has up to 3 fields: ``meta``, ``model`` and
    ``optimizer``.

    Args:
        model (Module): Module whose params are to be saved. If the model
            is wrapped (has a ``module`` attribute, e.g. DataParallel),
            the wrapped module is saved instead.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer whose
            ``state_dict()`` is stored under the ``optimizer`` key.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither a dict nor ``None``.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError('meta must be a dict or None, but got {}'.format(
            type(meta)))
    # BUG FIX: the original called mkdir_or_exist(filename), which created a
    # *directory* named after the checkpoint file, so the subsequent
    # torch.save(...) to that path would fail. Create the parent directory
    # instead; abspath guarantees dirname() is non-empty for bare filenames.
    file_utils.mkdir_or_exist(osp.dirname(osp.abspath(filename)))
    # Unwrap (Distributed)DataParallel-style wrappers.
    if hasattr(model, 'module'):
        model = model.module
    checkpoint = {
        'meta': meta,
        # NOTE(review): the whole module object is pickled here rather than
        # model.state_dict(); loading therefore requires the exact model
        # class code to be importable. Confirm this is intentional before
        # switching to state_dict-based checkpoints.
        'model': model,
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()
    torch.save(checkpoint, filename)
def __init__(self, model, batch_processor, optimizer=None, work_dir=None,
             logger=None, meta=None, ema=None):
    """Initialize the runner.

    Args:
        model (Module): Model to be run; may be wrapped (``.module``).
        batch_processor (callable): Callable that processes one data batch.
        optimizer (optional): Optimizer config/object, passed through
            ``self.init_optimizer`` when given.
        work_dir (str, optional): Directory for outputs; created if needed.
        logger (optional): Logger instance; a default one is created when
            ``None``.
        meta (dict, optional): Arbitrary run metadata.
        ema (optional): EMA config/object, passed through ``self.init_ema``
            when given.

    Raises:
        AssertionError: If ``batch_processor`` is not callable or ``meta``
            is neither a dict nor ``None``.
        TypeError: If ``work_dir`` is neither a str nor ``None``.
    """
    assert callable(batch_processor)
    self.model = model
    if meta is not None:
        assert isinstance(meta, dict), '"meta" must be a dict or None'
    # NOTE(review): source formatting was lost; following the usual runner
    # convention of storing meta unconditionally (None is allowed) —
    # confirm against the original file.
    self.meta = meta
    self.mode = None
    # Internal bookkeeping counters used during training.
    self._hooks = []
    self._epoch = 0
    self._iter = 0
    self._inner_iter = 0
    self._max_epochs = 0
    self._max_iters = 0
    self._warmup_max_iters = 0
    self._momentum = 0
    self.batch_processor = batch_processor
    # create work_dir
    if isinstance(work_dir, str):
        self.work_dir = osp.abspath(work_dir)
        file_utils.mkdir_or_exist(self.work_dir)
    elif work_dir is None:
        self.work_dir = None
    else:
        raise TypeError('"work_dir" must be a str or None')
    # get model name from the model class (unwrap DataParallel-style
    # wrappers exposing a .module attribute)
    if hasattr(self.model, 'module'):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__
    if logger is None:
        self.logger = Logging.getLogger()
    else:
        self.logger = logger
    self.log_buffer = LogBuffer()
    # Optimizer and EMA are optional; init helpers are only invoked when
    # a config/object was actually supplied.
    if optimizer is not None:
        self.optimizer = self.init_optimizer(optimizer)
    else:
        self.optimizer = None
    if ema is not None:
        self.ema = self.init_ema(ema)
    else:
        self.ema = None
def show_result(img, result, class_names, show=True, out_file=None):
    """Draw detection results onto an image.

    Args:
        img (str | np.ndarray): Image path or loaded image array.
        result (iterable of dict): Detections, each with ``label``,
            ``score`` and ``bbox`` (x1, y1, x2, y2) entries.
        class_names (sequence | None): Names indexed by label; when None a
            generic "cls N" caption is used.
        show (bool): Pop up an OpenCV window with the drawn image.
        out_file (str | None): When given, the drawn image is also written
            to this path (parent directory is created as needed).

    Returns:
        np.ndarray: The image with boxes and captions drawn on a copy.
    """
    assert isinstance(img, (str, np.ndarray))
    if isinstance(img, str):
        # fromfile + imdecode instead of cv2.imread: tolerates non-ASCII paths
        img = cv2.imdecode(np.fromfile(img, dtype=np.uint8), -1)
    img = img.copy()
    for det in result:
        label = det['label']
        score = det['score']
        box = det['bbox']
        top_left = (box[0], box[1])
        bottom_right = (box[2], box[3])
        cv2.rectangle(img, top_left, bottom_right, color=(0, 0, 255),
                      thickness=2)
        if class_names is not None:
            caption = class_names[label]
        else:
            caption = 'cls {}'.format(label)
        caption += '|{:.02f}'.format(score)
        cv2.putText(img, caption, (box[0], box[1] - 2),
                    cv2.FONT_HERSHEY_COMPLEX, fontScale=0.5,
                    color=(0, 255, 0))
    if out_file is not None:
        mkdir_or_exist(osp.abspath(osp.dirname(out_file)))
        cv2.imwrite(out_file, img)
    if show:
        win_name = 'inference result images'
        wait_time = 0
        cv2.imshow(win_name, img)
        if wait_time == 0:
            # Poll instead of a blocking waitKey(0) so we don't hang if the
            # user closes the window instead of pressing a key.
            while True:
                pressed = cv2.waitKey(1)
                window_closed = cv2.getWindowProperty(
                    win_name, cv2.WND_PROP_VISIBLE) < 1
                if window_closed or pressed != -1:
                    break
        else:
            cv2.waitKey(wait_time)
    return img
def main():
    """Entry point: parse CLI args, build model/datasets from config, train."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark (speeds up fixed-shape convolutions)
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args (CLI overrides config file)
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.device is not None:
        cfg.device = args.device
    else:
        cfg.device = None
    device = select_device(cfg.device)
    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
    # create work_dir
    file_utils.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # init the logger before other steps
    # timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    # log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
    logger = Logging.getLogger()
    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([('{}: {}'.format(k, v))
                          for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['batch_size'] = cfg.data.batch_size
    meta['subdivisions'] = cfg.data.subdivisions
    meta['multi_scale'] = args.multi_scale
    # log some basic info
    logger.info('Config:\n{}'.format(cfg.text))
    # set random seeds (must happen before model/dataset construction so
    # weight init and data shuffling are reproducible)
    if args.seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            args.seed, args.deterministic))
        set_random_seed(args.seed, deterministic=args.deterministic)
    # NOTE(review): source formatting was lost; seed recording is placed
    # outside the `if` per the upstream convention — confirm against the
    # original file.
    cfg.seed = args.seed
    meta['seed'] = args.seed
    model = build_from_dict(cfg.model, DETECTORS)
    model = model.cuda(device)
    # model.device = device
    # Wrap in DataParallel only for multi-GPU CUDA runs.
    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.device = device
    datasets = [build_from_dict(cfg.data.train, DATASET)]
    if len(cfg.workflow) == 2:
        # A 2-stage workflow means train + val; the val dataset reuses the
        # training pipeline here.
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_from_dict(val_dataset, DATASET))
    if cfg.checkpoint_config is not None:
        # save config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(config=cfg.text,
                                          CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    train_detector(model, datasets, cfg, validate=args.validate,
                   timestamp=timestamp, meta=meta)