Exemplo n.º 1
0
def obj_from_dict(info, parent=None, default_args=None):
    """Construct an object from a dict description.

    ``info`` must contain the key "type", which selects what to construct;
    every other item of ``info`` is forwarded as a keyword argument.

    Args:
        info (dict): Object type and arguments. The "type" value is either a
            string or a class.
        parent (object, optional): When given and "type" is a string, the
            string is resolved with ``getattr(parent, name)``. When omitted,
            the string is used as a key into ``sys.modules`` instead.
        default_args (dict, optional): Keyword arguments applied only for
            keys that ``info`` does not already provide.

    Returns:
        The result of calling the resolved type with the merged arguments.

    Raises:
        TypeError: If "type" is neither a string nor a class.
    """
    assert isinstance(info, dict) and "type" in info
    assert default_args is None or isinstance(default_args, dict)

    # Work on a copy so the caller's dict keeps its "type" entry.
    kwargs = info.copy()
    target = kwargs.pop("type")

    if not torchie.is_str(target):
        if not isinstance(target, type):
            raise TypeError(
                "type must be a str or valid type, but got {}".format(
                    type(target)))
    elif parent is None:
        target = sys.modules[target]
    else:
        target = getattr(parent, target)

    # Values given explicitly in ``info`` win over the defaults.
    for key, value in (default_args or {}).items():
        kwargs.setdefault(key, value)

    return target(**kwargs)
Exemplo n.º 2
0
def build_from_cfg(cfg, registry, default_args=None):
    """Instantiate a module described by a config dict.

    Args:
        cfg (dict): Config dict. It must contain the key "type", holding
            either a registered name (str) or a class; the remaining items
            are passed to the constructor as keyword arguments.
        registry (:obj:`Registry`): Registry queried with ``registry.get``
            when "type" is a string.
        default_args (dict, optional): Keyword arguments applied only for
            keys that ``cfg`` does not already provide.

    Returns:
        obj: The constructed object.

    Raises:
        KeyError: If "type" is a string the registry does not know.
        TypeError: If "type" is neither a string nor a class.
    """
    assert isinstance(cfg, dict) and "type" in cfg
    assert default_args is None or isinstance(default_args, dict)

    # Work on a copy so the caller's config keeps its "type" entry.
    kwargs = cfg.copy()
    requested = kwargs.pop("type")

    if torchie.is_str(requested):
        obj_cls = registry.get(requested)
        if obj_cls is None:
            raise KeyError(
                "{} is not in the {} registry".format(requested, registry.name)
            )
    elif not inspect.isclass(requested):
        raise TypeError(
            "type must be a str or valid type, but got {}".format(type(requested))
        )
    else:
        obj_cls = requested

    # Values given explicitly in ``cfg`` win over the defaults.
    for key, value in (default_args or {}).items():
        kwargs.setdefault(key, value)

    return obj_cls(**kwargs)
Exemplo n.º 3
0
def print_map_summary(mean_ap, results, dataset=None):
    """Print one table per scale with per-class statistics and the mAP.

    Args:
        mean_ap (float or list): Calculated from `eval_map`. A value that is
            not a list is wrapped into a one-element list; the list is
            indexed by scale for the final "mAP" row.
        results (list): Calculated from `eval_map`, one entry per class.
            Each entry is read for the keys "ap", "recall", "precision",
            "num_gts" and "num_dets".
        dataset (None or str or list): ``None`` labels the classes "1".."N",
            a string is resolved through ``get_classes``, and anything else
            is used directly as the sequence of label names.
    """
    first_ap = results[0]["ap"]
    num_scales = len(first_ap) if isinstance(first_ap, np.ndarray) else 1
    num_classes = len(results)

    # One row per scale, one column per class. Classes whose recall array is
    # empty keep the zero recall/precision they are initialised with.
    shape = (num_scales, num_classes)
    recalls = np.zeros(shape, dtype=np.float32)
    precisions = np.zeros(shape, dtype=np.float32)
    aps = np.zeros(shape, dtype=np.float32)
    num_gts = np.zeros(shape, dtype=int)
    for col, cls_result in enumerate(results):
        if cls_result["recall"].size > 0:
            # ndmin=2 promotes 1-D input to a single row, so [:, -1] picks
            # the last value of every row in both cases.
            recalls[:, col] = np.array(cls_result["recall"], ndmin=2)[:, -1]
            precisions[:, col] = np.array(
                cls_result["precision"], ndmin=2)[:, -1]
        aps[:, col] = cls_result["ap"]
        num_gts[:, col] = cls_result["num_gts"]

    if dataset is None:
        label_names = [str(idx) for idx in range(1, num_classes + 1)]
    elif torchie.is_str(dataset):
        label_names = get_classes(dataset)
    else:
        label_names = dataset

    mean_aps = mean_ap if isinstance(mean_ap, list) else [mean_ap]

    fmt = "{:.3f}".format
    header = ["class", "gts", "dets", "recall", "precision", "ap"]
    for scale in range(num_scales):
        rows = [
            [
                label_names[cls],
                num_gts[scale, cls],
                results[cls]["num_dets"],
                fmt(recalls[scale, cls]),
                fmt(precisions[scale, cls]),
                fmt(aps[scale, cls]),
            ]
            for cls in range(num_classes)
        ]
        footer = ["mAP", "", "", "", "", fmt(mean_aps[scale])]
        table = AsciiTable([header] + rows + [footer])
        table.inner_footing_row_border = True
        print(table.table)
Exemplo n.º 4
0
    def __init__(
        self,
        model,
        batch_processor,
        optimizer=None,
        lr_scheduler=None,
        work_dir=None,
        log_level=logging.INFO,
        logger=None,
        **kwargs,
    ):
        """Store the training components and reset the progress counters.

        Args:
            model: Model to run. If it exposes a ``module`` attribute, the
                class name of that attribute is recorded as the model name;
                otherwise the class name of ``model`` itself is used.
            batch_processor (callable): Stored as-is; must be callable.
            optimizer: Stored as-is.
            lr_scheduler: Stored as-is.
            work_dir (str, optional): Working directory. A string is made
                absolute and created through ``torchie.mkdir_or_exist``;
                ``None`` means no working directory.
            log_level: Passed to ``init_logger`` when ``logger`` is None.
            logger: Logger to use instead of creating one.
            **kwargs: Accepted but not used by this method.

        Raises:
            TypeError: If ``work_dir`` is neither a string nor None.
        """
        assert callable(batch_processor)
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self.batch_processor = batch_processor

        # Resolve and create the working directory.
        if not torchie.is_str(work_dir):
            if work_dir is not None:
                raise TypeError("'work_dir' must be a str or None")
            self.work_dir = None
        else:
            self.work_dir = osp.abspath(work_dir)
            torchie.mkdir_or_exist(self.work_dir)

        # Name the model after the wrapped object when a ``module``
        # attribute is present, otherwise after the model itself.
        named = self.model.module if hasattr(self.model, "module") else self.model
        self._model_name = named.__class__.__name__

        # Rank and timestamp are assigned before the logger is created.
        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        # The logger receives the work_dir exactly as passed in, not the
        # absolute path stored on the instance.
        self.logger = (
            logger if logger is not None
            else self.init_logger(work_dir, log_level)
        )
        self.log_buffer = LogBuffer()

        # Progress state starts out empty.
        self.mode = None
        self._hooks = []
        self._epoch = self._iter = self._inner_iter = 0
        self._max_epochs = self._max_iters = 0
        self._example_stats = None
Exemplo n.º 5
0
def get_classes(dataset):
    """Get class names of a dataset.

    Args:
        dataset (str): A dataset alias listed in the module-level
            ``dataset_aliases`` mapping (canonical name -> aliases).

    Returns:
        Whatever the module-level ``<name>_classes()`` function for the
        matching canonical name returns.

    Raises:
        TypeError: If ``dataset`` is not a string.
        ValueError: If ``dataset`` is not a known alias.
    """
    if not torchie.is_str(dataset):
        raise TypeError(
            "dataset must be a str, but got {}".format(type(dataset)))

    # Invert name -> aliases into alias -> name. If an alias is listed under
    # several names, the last one wins.
    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items()
        for alias in aliases
    }
    if dataset not in alias2name:
        raise ValueError("Unrecognized dataset: {}".format(dataset))

    # Look the "<name>_classes" function up in the module namespace instead
    # of evaluating a string of source code.
    return globals()[alias2name[dataset] + "_classes"]()