Esempio n. 1
0
def build(task: str, cfg: CfgNode):
    r"""
    Build the backbone selected by ``cfg.name`` for the given task.

    Arguments
    ---------
    task: str
        builder task name (track|vos); must be a key of TASK_BACKBONES
    cfg: CfgNode
        builder configuration; ``cfg.name`` picks the backbone and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    Returns
    -------
    torch.nn.Module
        module built by builder

    Raises
    ------
    SystemExit
        if no backbone registry exists for ``task``
    AssertionError
        if ``cfg.name`` is not registered for ``task``
    """
    # Guard clause: an unknown task is logged and ends the process.
    if task not in TASK_BACKBONES:
        logger.error("no backbone for task {}".format(task))
        exit(-1)
    registry = TASK_BACKBONES[task]

    backbone_name = cfg.name
    assert backbone_name in registry, "backbone {} not registered for {}!".format(
        backbone_name, task)

    backbone = registry[backbone_name]()
    # Start from the instance's own hyper-parameters, then overlay cfg values.
    default_hps = backbone.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[backbone_name], default_hps)
    backbone.set_hps(merged_hps)
    backbone.update_params()
    return backbone
Esempio n. 2
0
def build(task: str, cfg: CfgNode, seed: int = 0) -> TransformerBase:
    r"""
    Build every transformer listed in ``cfg.names`` for the given task.

    Arguments
    ---------
    task: str
        task name; must be a key of TASK_TRANSFORMERS
    cfg: CfgNode
        node name: transformer; ``cfg.names`` lists the transformers and
        ``cfg[name]`` holds the hyper-parameter overrides of each one
    seed: int
        forwarded to each transformer constructor as ``seed``

    Returns
    -------
    list
        transformer instances, in the order given by ``cfg.names``
    """
    assert task in TASK_TRANSFORMERS, "invalid task name"
    registry = TASK_TRANSFORMERS[task]

    built = []
    for transformer_name in cfg.names:
        transformer = registry[transformer_name](seed=seed)
        # Overlay the configured values on the instance defaults.
        default_hps = transformer.get_hps()
        merged_hps = merge_cfg_into_hps(cfg[transformer_name], default_hps)
        transformer.set_hps(merged_hps)
        transformer.update_params()
        built.append(transformer)

    return built
Esempio n. 3
0
def build(task: str, cfg: CfgNode):
    r"""
    Build the head selected by ``cfg.name`` for the given task.

    Arguments
    ---------
    task: str
        builder task name (track|vos); must be a key of TASK_HEADS
    cfg: CfgNode
        builder configuration; ``cfg.name`` picks the head and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    Returns
    -------
    torch.nn.Module
        module built by builder

    Raises
    ------
    SystemExit
        if no head registry exists for ``task``
    KeyError
        if ``cfg.name`` is not registered for ``task``
    """
    if task in TASK_HEADS:
        head_modules = TASK_HEADS[task]
    else:
        # This builder resolves heads, so the message names heads
        # (it previously said "task model").
        logger.error("no head for task {}".format(task))
        exit(-1)

    name = cfg.name
    # Fail with a descriptive message instead of a bare KeyError from the
    # registry lookup; the exception type is kept for existing callers.
    if name not in head_modules:
        raise KeyError("head {} not registered for {}!".format(name, task))

    head_module = head_modules[name]()
    # Overlay the configured values on the instance's default hyper-parameters.
    hps = head_module.get_hps()
    hps = merge_cfg_into_hps(cfg[name], hps)
    head_module.set_hps(hps)
    head_module.update_params()

    return head_module
Esempio n. 4
0
def build(task: str, cfg: CfgNode, seed: int = 0) -> DatasetBase:
    r"""
    Build the sampler selected by ``cfg.name`` together with its submodules.

    Arguments
    ---------
    task: str
        task name (track|vos); must be a key of TASK_SAMPLERS
    cfg: CfgNode
        node name: sampler; ``cfg.submodules.dataset`` is required and
        ``cfg.submodules.filter`` is optional
    seed: int
        seed for rng initialization, forwarded to the sampler constructor

    Returns
    -------
    sampler instance registered under ``cfg.name``, constructed with the
    built datasets, ``seed`` and the optional filter
    """
    assert task in TASK_SAMPLERS, "invalid task name"
    registry = TASK_SAMPLERS[task]

    sub_cfg = cfg.submodules

    # Datasets are always built.
    datasets = dataset_builder.build(task, sub_cfg.dataset)

    # The filter is built only when a "filter" node is present.
    filter_cfg = getattr(sub_cfg, "filter", None)
    if filter_cfg is not None:
        filt = filter_builder.build(task, filter_cfg)
    else:
        filt = None

    sampler_name = cfg.name
    sampler = registry[sampler_name](datasets, seed=seed, filt=filt)

    default_hps = sampler.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[sampler_name], default_hps)
    sampler.set_hps(merged_hps)
    sampler.update_params()

    return sampler
Esempio n. 5
0
def build(task: str, cfg: CfgNode, seed: int = 0) -> DatapipelineBase:
    r"""
    Build the datapipeline: sampler, transformers and target, wrapped by the
    datapipeline class selected by ``cfg.datapipeline.name``.

    Arguments
    ---------
    task: str
        task name (track|vos); must be a key of TASK_DATAPIPELINES
    cfg: CfgNode
        node name: data; reads the ``sampler``, ``transformer``, ``target``
        and ``datapipeline`` sub-nodes
    seed: int
        seed for rng initialization, forwarded to the sampler and
        transformer builders

    Returns
    -------
    DatapipelineBase
        datapipeline instance with merged hyper-parameters applied
    """
    assert task in TASK_DATAPIPELINES, "invalid task name"
    registry = TASK_DATAPIPELINES[task]

    sampler = build_sampler(task, cfg.sampler, seed=seed)
    transformers = build_transformer(task, cfg.transformer, seed=seed)
    target = build_target(task, cfg.target)

    # Processing stages: every transformer first, the target last.
    stages = [*transformers, target]

    datapipeline_cfg = cfg.datapipeline
    datapipeline_name = datapipeline_cfg.name
    datapipeline = registry[datapipeline_name](sampler, stages)

    default_hps = datapipeline.get_hps()
    merged_hps = merge_cfg_into_hps(datapipeline_cfg[datapipeline_name],
                                    default_hps)
    datapipeline.set_hps(merged_hps)
    datapipeline.update_params()

    return datapipeline
Esempio n. 6
0
def build(task: str, cfg: CfgNode, model: ModuleBase):
    r"""
    Build the pipeline selected by ``cfg.name`` around a model.

    Arguments
    ---------
    task: str
        task name; must be a key of PIPELINES
    cfg: CfgNode
        builder configuration; ``cfg.name`` picks the pipeline and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides
    model: ModuleBase
        model instance passed to the pipeline constructor

    Returns
    -------
    torch.nn.Module
        module built by builder
    """
    assert task in PIPELINES, "no pipeline for task {}".format(task)
    registry = PIPELINES[task]

    name = cfg.name
    built = registry[name](model)

    # Overlay the configured values on the pipeline's default hyper-parameters.
    default_hps = built.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[name], default_hps)
    built.set_hps(merged_hps)
    built.update_params()

    return built
Esempio n. 7
0
def build(task: str, cfg: CfgNode, pipeline: PipelineBase):
    r"""
    Build every tester listed in ``cfg.tester.names`` for the given task.

    Arguments
    ---------
    task: str
        builder task name (track|vos); must be a key of TASK_TESTERS
    cfg: CfgNode
        builder configuration; the ``tester`` sub-node provides ``names``
        and the per-tester hyper-parameter overrides
    pipeline: PipelineBase
        pipeline instance passed to each tester constructor

    Returns
    -------
    list
        tester instances, one per entry of ``cfg.tester.names``
    """
    assert task in TASK_TESTERS, "no tester for task {}".format(task)
    registry = TASK_TESTERS[task]

    tester_cfg = cfg.tester
    built = []
    # One tester per listed name (multiple experiments).
    for tester_name in tester_cfg.names:
        instance = registry[tester_name](pipeline)
        default_hps = instance.get_hps()
        merged_hps = merge_cfg_into_hps(cfg.tester[tester_name], default_hps)
        instance.set_hps(merged_hps)
        instance.update_params()
        built.append(instance)
    return built
Esempio n. 8
0
def build(task: str, cfg: CfgNode) -> DatasetBase:
    r"""
    Build the filter selected by ``cfg.name`` for the given task.

    Arguments
    ---------
    task: str
        task name; must be a key of TASK_FILTERS
    cfg: CfgNode
        filter configuration; ``cfg.name`` picks the filter and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    Returns
    -------
    filter instance with merged hyper-parameters applied
    """
    assert task in TASK_FILTERS, "invalid task name"
    registry = TASK_FILTERS[task]

    filter_name = cfg.name
    filt = registry[filter_name]()

    default_hps = filt.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[filter_name], default_hps)
    filt.set_hps(merged_hps)
    filt.update_params()

    return filt
Esempio n. 9
0
def build(task: str,
          cfg: CfgNode,
          backbone: ModuleBase,
          head: ModuleBase,
          loss: ModuleBase = None):
    r"""
    Build the task model selected by ``cfg.name``.

    Only the "track" task is actually built; "vos" and any other task name
    are logged as errors and end the process.

    Arguments
    ---------
    task: str
        builder task name
    cfg: CfgNode
        builder configuration; ``cfg.name`` picks the task model and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides
    backbone: torch.nn.Module
        backbone used by task module.
    head: torch.nn.Module
        head network used by task module.
    loss: torch.nn.Module
        criterion module used by task module (for training). None in case other than training.

    Returns
    -------
    torch.nn.Module
        task module built by builder

    Raises
    ------
    SystemExit
        if ``task`` is not "track"
    """
    if task == "track":
        registry = TRACK_TASKMODELS
    elif task == "vos":
        registry = VOS_TASKMODELS
    else:
        logger.error("no task model for task {}".format(task))
        exit(-1)

    # Guard clause: anything other than "track" is reported as not completed.
    if task != "track":
        logger.error("task model {} is not completed".format(task))
        exit(-1)

    model_name = cfg.name
    task_model = registry[model_name](backbone, head, loss)
    default_hps = task_model.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[model_name], default_hps)
    task_model.set_hps(merged_hps)
    task_model.update_params()
    return task_model
Esempio n. 10
0
def build(
        task: str,
        cfg: CfgNode,
        model: ModuleBase = None,
        segmenter: ModuleBase = None,
        tracker: ModuleBase = None,
):
    r"""
    Build the pipeline selected by ``cfg.name`` for a track or vos task.

    Arguments
    ---------
    task: str
        task name; must be a key of PIPELINES
    cfg: CfgNode
        builder configuration; ``cfg.name`` picks the pipeline and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides
    model: ModuleBase
        model instance for siamfcpp; passed to the constructor when task is 'track'
    segmenter: ModuleBase
        segmenter instance; passed to the constructor when task is 'vos'
    tracker: ModuleBase
        tracker instance; passed to the constructor when task is 'vos'

    Returns
    -------
    torch.nn.Module
        module built by builder

    Raises
    ------
    SystemExit
        if ``task`` is registered in PIPELINES but is neither 'track' nor 'vos'
    """
    assert task in PIPELINES, "no pipeline for task {}".format(task)
    registry = PIPELINES[task]
    name = cfg.name

    # Each task takes a different set of constructor arguments.
    if task == 'track':
        ctor_args = (model, )
    elif task == 'vos':
        ctor_args = (segmenter, tracker)
    else:
        logger.error("unknown task {} for pipline".format(task))
        exit(-1)
    built = registry[name](*ctor_args)

    default_hps = built.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[name], default_hps)
    built.set_hps(merged_hps)
    built.update_params()

    return built
Esempio n. 11
0
def build(task: str, cfg: CfgNode, basemodel=None):
    r"""
    Backbone build function.

    Builds and returns the backbone module selected by the given task
    (track|vos) and configuration.

    Arguments
    ---------
    task: str
        builder task name (track|vos); must be a key of TASK_BACKBONES
    cfg: CfgNode
        builder configuration; ``cfg.name`` picks the backbone and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    basemodel:
        wrap backbone into encoder if not None; passed as the single
        constructor argument of the registered backbone class

    Returns
    -------
    torch.nn.Module
        module built by builder

    Raises
    ------
    SystemExit
        if no backbone registry exists for ``task``
    AssertionError
        if ``cfg.name`` is not registered for ``task``
    """
    if task in TASK_BACKBONES:
        # Backbones registered for this task.
        modules = TASK_BACKBONES[task]
    else:
        logger.error("no backbone for task {}".format(task))
        exit(-1)  # registry lookup failed

    name = cfg.name  # name of the requested backbone
    assert name in modules, "backbone {} not registered for {}!".format(
        name, task)

    # Compare against None explicitly: a truthiness test would silently drop
    # a basemodel whose __len__ is 0 (e.g. an empty container module).
    if basemodel is not None:
        module = modules[name](basemodel)
    else:
        # Instantiate the registered backbone class with no arguments.
        module = modules[name]()

    hps = module.get_hps()  # hyper-parameter dict of the new instance
    hps = merge_cfg_into_hps(cfg[name], hps)  # overlay values from cfg
    module.set_hps(hps)  # write the merged hyper-parameters back
    module.update_params()
    return module
Esempio n. 12
0
def build(task: str, cfg: CfgNode, model: nn.Module) -> OptimizerBase:
    r"""
    Build the optimizer selected by ``cfg.name`` from OPTIMIZERS.

    Arguments
    ---------
    task: str
        task name (track|vos); not consulted by this builder
    cfg: CfgNode
        node name: optim; passed whole to the optimizer constructor, and
        ``cfg[cfg.name]`` holds the hyper-parameter overrides
    model: nn.Module
        model passed to the optimizer constructor

    Returns
    -------
    OptimizerBase
        optimizer instance with merged hyper-parameters applied
    """
    optimizer_name = cfg.name
    optimizer = OPTIMIZERS[optimizer_name](cfg, model)

    default_hps = optimizer.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[optimizer_name], default_hps)
    optimizer.set_hps(merged_hps)
    optimizer.update_params()

    return optimizer
Esempio n. 13
0
def build_sat_model(task: str,
                    cfg: CfgNode,
                    gml_extractor=None,
                    joint_encoder=None,
                    decoder=None,
                    loss: ModuleBase = None):
    r"""
    Builder function for SAT; only the "vos" task is supported.

    Arguments
    ---------
    task: str
        builder task name; anything other than "vos" ends the process
    cfg: CfgNode
        builder configuration; ``cfg.name`` picks the task model and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides
    gml_extractor: torch.nn.Module
        feature extractor for global modeling loop
    joint_encoder: torch.nn.Module
        joint encoder
    decoder: torch.nn.Module
        decoder for SAT
    loss: torch.nn.Module
        criterion module used by task module (for training). None in case other than training.

    Returns
    -------
    torch.nn.Module
        task module built by builder

    Raises
    ------
    SystemExit
        if ``task`` is not "vos"
    """
    # Guard clause: this builder handles the vos task only.
    if task != "vos":
        logger.error("sat model builder could not build task {}".format(task))
        exit(-1)
    registry = TASK_TASKMODELS[task]

    model_name = cfg.name
    # e.g. SatVOS
    sat_model = registry[model_name](gml_extractor, joint_encoder, decoder,
                                     loss)
    default_hps = sat_model.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[model_name], default_hps)
    sat_model.set_hps(merged_hps)
    sat_model.update_params()
    return sat_model
Esempio n. 14
0
def build(task: str, cfg: CfgNode) -> TargetBase:
    r"""
    Build the target selected by ``cfg.name`` for the given task.

    Arguments
    ---------
    task: str
        task name; must be a key of TASK_TARGETS
    cfg: CfgNode
        node name: target; ``cfg.name`` picks the target and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    Returns
    -------
    TargetBase
        target instance with merged hyper-parameters applied
    """
    assert task in TASK_TARGETS, "invalid task name"
    registry = TASK_TARGETS[task]

    target_name = cfg.name
    target = registry[target_name]()

    default_hps = target.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[target_name], default_hps)
    target.set_hps(merged_hps)
    target.update_params()

    return target
Esempio n. 15
0
def build(task: str, cfg: CfgNode, model: nn.Module) -> OptimizerBase:
    r"""
    Build the optimizer selected by ``cfg.name`` for the given task.

    Arguments
    ---------
    task: str
        task name (track|vos); must be a key of TASK_OPTIMIZERS
    cfg: CfgNode
        node name: optim; passed whole to the optimizer constructor, and
        ``cfg[cfg.name]`` holds the hyper-parameter overrides
    model: nn.Module
        model passed to the optimizer constructor

    Returns
    -------
    OptimizerBase
        optimizer instance with merged hyper-parameters applied
    """
    assert task in TASK_OPTIMIZERS, "invalid task name"
    registry = TASK_OPTIMIZERS[task]

    optimizer_name = cfg.name
    optimizer = registry[optimizer_name](cfg, model)

    default_hps = optimizer.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[optimizer_name], default_hps)
    optimizer.set_hps(merged_hps)
    optimizer.update_params()

    return optimizer
Esempio n. 16
0
def build(task: str, cfg: CfgNode) -> ContribModule:
    r"""
    Build the contrib module selected by ``cfg.name`` for the given task.

    Arguments
    ---------
    task: str
        task name (track|vos); must be a key of TASK_CONTRIB_MODULES
    cfg: CfgNode
        node name: contrib_module; ``cfg.name`` picks the module and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    Returns
    -------
    ContribModule
        instance with merged hyper-parameters applied
    """
    assert task in TASK_CONTRIB_MODULES, "invalid task name"
    registry = TASK_CONTRIB_MODULES[task]

    module_name = cfg.name
    contrib = registry[module_name]()

    default_hps = contrib.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[module_name], default_hps)
    contrib.set_hps(merged_hps)
    contrib.update_params()

    return contrib
Esempio n. 17
0
def build(task: str, cfg: CfgNode) -> GradModifierBase:
    r"""
    Build the gradient modifier selected by ``cfg.name`` from GRAD_MODIFIERS.

    Arguments
    ---------
    task: str
        task name (track|vos); not consulted by this builder
    cfg: CfgNode
        builder configuration; ``cfg.name`` picks the grad modifier and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    Returns
    -------
    GradModifierBase
        instance with merged hyper-parameters applied
    """
    modifier_name = cfg.name
    modifier = GRAD_MODIFIERS[modifier_name]()

    default_hps = modifier.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[modifier_name], default_hps)
    modifier.set_hps(merged_hps)
    modifier.update_params()

    return modifier
Esempio n. 18
0
def build(task: str, cfg: CfgNode) -> TemplateModuleBase:
    r"""
    Build the template module selected by ``cfg.name`` for the given task.

    Arguments
    ---------
    task: str
        task name (track|vos); must be a key of TASK_TEMPLATE_MODULES
    cfg: CfgNode
        node name: template_module; ``cfg.name`` picks the module and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    Returns
    -------
    TemplateModuleBase
        instance with merged hyper-parameters applied
    """
    assert task in TASK_TEMPLATE_MODULES, "invalid task name"
    # Modules registered for this task.
    registry = TASK_TEMPLATE_MODULES[task]

    module_name = cfg.name
    template = registry[module_name]()

    default_hps = template.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[module_name], default_hps)
    template.set_hps(merged_hps)
    template.update_params()

    return template
Esempio n. 19
0
def build(task: str, cfg: CfgNode, pyramid_model=None):
    r"""
    Build the pyramid selected by ``cfg.name`` for the given task.

    Arguments
    ---------
    task: str
        builder task name (track|vos); must be a key of TASK_PYRAMIDS
    cfg: CfgNode
        builder configuration; ``cfg.name`` picks the pyramid and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    pyramid_model:
        wrap pyramid into encoder if not None; passed as the single
        constructor argument of the registered pyramid class

    Returns
    -------
    torch.nn.Module
        module built by builder

    Raises
    ------
    SystemExit
        if no pyramid registry exists for ``task``
    AssertionError
        if ``cfg.name`` is not registered for ``task``
    """
    if task in TASK_PYRAMIDS:
        modules = TASK_PYRAMIDS[task]
    else:
        logger.error("no pyramid for task {}".format(task))
        exit(-1)

    name = cfg.name
    assert name in modules, "pyramid {} not registered for {}!".format(
        name, task)

    # Compare against None explicitly: a truthiness test would silently drop
    # a pyramid_model whose __len__ is 0 (e.g. an empty container module).
    if pyramid_model is not None:
        module = modules[name](pyramid_model)
    else:
        module = modules[name]()

    # Overlay the configured values on the instance's default hyper-parameters.
    hps = module.get_hps()
    hps = merge_cfg_into_hps(cfg[name], hps)
    module.set_hps(hps)
    module.update_params()
    return module
Esempio n. 20
0
def build_track_dual_backbone(task: str,
                              cfg: CfgNode,
                              basemodel_target=None,
                              basemodel_search=None,
                              head=None,
                              loss: ModuleBase = None):
    r"""
    Builder function for SiamFCpp
    In case of the siamese branches do not share weights

    Arguments
    ---------
    task: str
        builder task name; used as the key into TASK_TASKMODELS
    cfg: CfgNode
        builder configuration; ``cfg.name`` picks the task model and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides
    basemodel_target: torch.nn.Module
        backbone used by target image backbone.
    basemodel_search: torch.nn.Module
        backbone used by search image backbone.
    head: torch.nn.Module
        head network used by task module.
    loss: torch.nn.Module
        criterion module used by task module (for training). None in case other than training.

    Returns
    -------
    torch.nn.Module
        task module built by builder
    """
    registry = TASK_TASKMODELS[task]
    model_name = cfg.name
    # The two backbones are passed separately, target first, then search.
    dual_model = registry[model_name](basemodel_target, basemodel_search, head,
                                      loss)
    default_hps = dual_model.get_hps()
    merged_hps = merge_cfg_into_hps(cfg[model_name], default_hps)
    dual_model.set_hps(merged_hps)
    dual_model.update_params()
    return dual_model
Esempio n. 21
0
def build(task: str, cfg: CfgNode) -> DatasetBase:
    r"""
    Build every dataset listed in ``cfg.names`` for the given task.

    Arguments
    ---------
    task: str
        task name (track|vos); must be a key of TASK_DATASETS
    cfg: CfgNode
        node name: dataset; ``cfg.names`` lists the datasets and
        ``cfg[name]`` holds the hyper-parameter overrides of each one

    Returns
    -------
    list
        dataset instances, in the order given by ``cfg.names``
    """
    assert task in TASK_DATASETS, "invalid task name"
    registry = TASK_DATASETS[task]

    built = []
    for dataset_name in cfg.names:
        dataset = registry[dataset_name]()
        default_hps = dataset.get_hps()
        merged_hps = merge_cfg_into_hps(cfg[dataset_name], default_hps)
        dataset.set_hps(merged_hps)
        dataset.update_params()
        built.append(dataset)

    return built
Esempio n. 22
0
def build(task: str, cfg: CfgNode, seed: int = 0) -> DatapipelineBase:
    r"""
    Build the datapipeline: sampler, transformers and target, wrapped by the
    datapipeline class selected by ``cfg.datapipeline.name``.

    Arguments
    ---------
    task: str
        task name (track|vos); must be a key of TASK_DATAPIPELINES
    cfg: CfgNode
        node name: data; reads the ``sampler``, ``transformer``, ``target``
        and ``datapipeline`` sub-nodes
    seed: int
        seed for rng initialization, forwarded to the sampler and
        transformer builders

    Returns
    -------
    DatapipelineBase
        datapipeline instance with merged hyper-parameters applied
    """
    assert task in TASK_DATAPIPELINES, "invalid task name"
    registry = TASK_DATAPIPELINES[task]

    # Sampler: yields a positive pair from one sequence or a negative pair
    # from different sequences.
    sampler = build_sampler(task, cfg.sampler, seed=seed)
    # Data augmentation, mainly scale and shift, so the search and template
    # patches are slightly offset rather than exactly centred on the target.
    transformers = build_transformer(task, cfg.transformer, seed=seed)
    # Label construction.
    target = build_target(task, cfg.target)

    # Processing stages: every transformer first, the target last.
    stages = [*transformers, target]

    datapipeline_cfg = cfg.datapipeline
    datapipeline_name = datapipeline_cfg.name
    datapipeline = registry[datapipeline_name](sampler, stages)

    default_hps = datapipeline.get_hps()
    merged_hps = merge_cfg_into_hps(datapipeline_cfg[datapipeline_name],
                                    default_hps)
    datapipeline.set_hps(merged_hps)
    datapipeline.update_params()

    return datapipeline