예제 #1
0
def build(task: str, cfg: CfgNode, seed: int = 0) -> TransformerBase:
    r"""
    Build and configure every transformer listed in ``cfg.names``.

    Arguments
    ---------
    task: str
        task name; must be a key of ``TASK_TRANSFORMERS``
    cfg: CfgNode
        node name: transformer. ``cfg.names`` lists the transformers to
        build and ``cfg[<name>]`` holds each one's hyper-parameter overrides
    seed: int
        passed as the ``seed`` keyword to every transformer constructor

    Returns
    -------
    list
        configured transformer instances in the order of ``cfg.names``
        (a list, despite the ``TransformerBase`` return annotation)
    """
    assert task in TASK_TRANSFORMERS, "invalid task name"
    registry = TASK_TRANSFORMERS[task]

    def _configure(transformer_name):
        # Instantiate, overlay config values onto the defaults, then let the
        # transformer refresh its derived parameters.
        transformer = registry[transformer_name](seed=seed)
        defaults = transformer.get_hps()
        merged = merge_cfg_into_hps(cfg[transformer_name], defaults)
        transformer.set_hps(merged)
        transformer.update_params()
        return transformer

    return [_configure(transformer_name) for transformer_name in cfg.names]
예제 #2
0
def build(task: str, cfg: CfgNode):
    r"""
    Builder function for task heads.

    Arguments
    ---------
    task: str
        builder task name (track|vos)
    cfg: CfgNode
        builder configuration; ``cfg.name`` selects the head class and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    Returns
    -------
    torch.nn.Module
        module built by builder

    Raises
    ------
    SystemExit
        with code -1 when no heads are registered for ``task`` (an error is
        logged first)
    """
    if task in TASK_HEADS:
        head_modules = TASK_HEADS[task]
    else:
        logger.error("no task model for task {}".format(task))
        # ``exit`` is injected by the ``site`` module for interactive use and
        # is missing under ``python -S`` or in frozen apps; ``sys.exit`` is
        # the portable equivalent (both raise SystemExit(-1)).
        import sys
        sys.exit(-1)

    name = cfg.name
    head_module = head_modules[name]()
    hps = head_module.get_hps()
    hps = merge_cfg_into_hps(cfg[name], hps)
    head_module.set_hps(hps)
    head_module.update_params()

    return head_module
예제 #3
0
def build(task: str, cfg: CfgNode, seed: int = 0) -> DatasetBase:
    r"""
    Build the sampler for ``task`` together with its datasets and optional
    data filter.

    Arguments
    ---------
    task: str
        task name (track|vos); must be a key of ``TASK_SAMPLERS``
    cfg: CfgNode
        node name: sampler. ``cfg.submodules.dataset`` and
        ``cfg.submodules.filter`` configure the sub-builders; ``cfg.name``
        selects the sampler class
    seed: int
        seed for rng initialization, forwarded to the sampler constructor

    Returns
    -------
    the configured sampler instance
    """
    assert task in TASK_SAMPLERS, "invalid task name"
    sampler_classes = TASK_SAMPLERS[task]

    sub_cfg = cfg.submodules
    datasets = dataset_builder.build(task, sub_cfg.dataset)

    # An empty filter name means "no filter".
    data_filter = (filter_builder.build(task, sub_cfg.filter)
                   if sub_cfg.filter.name != "" else None)

    sampler_name = cfg.name
    sampler = sampler_classes[sampler_name](datasets,
                                            seed=seed,
                                            data_filter=data_filter)

    defaults = sampler.get_hps()
    sampler.set_hps(merge_cfg_into_hps(cfg[sampler_name], defaults))
    sampler.update_params()

    return sampler
예제 #4
0
def build(task: str, cfg: CfgNode, pipeline: PipelineBase):
    r"""
    Builder function for testers (one tester per experiment).

    Arguments
    ---------
    task: str
        builder task name (track|vos); must be a key of ``TASK_TESTERS``
    cfg: CfgNode
        builder configuration,
        node name: tester. ``cfg.names`` lists the testers to build and
        ``cfg[<name>]`` holds each one's hyper-parameter overrides
    pipeline: PipelineBase
        pipeline handed to every tester's constructor

    Returns
    -------
    list of TesterBase
        testers built by builder, in the order of ``cfg.names``
    """
    assert task in TASK_TESTERS, "no tester for task {}".format(task)
    tester_classes = TASK_TESTERS[task]

    def _make_tester(tester_name):
        tester = tester_classes[tester_name](pipeline)
        defaults = tester.get_hps()
        tester.set_hps(merge_cfg_into_hps(cfg[tester_name], defaults))
        tester.update_params()
        return tester

    # Multiple experiments may be configured, hence a list of testers.
    return [_make_tester(tester_name) for tester_name in cfg.names]
예제 #5
0
def build(task: str, cfg: CfgNode, seed: int = 0) -> DatapipelineBase:
    r"""
    Assemble the data pipeline: sampler, then transformers, then target.

    Arguments
    ---------
    task: str
        task name (track|vos); must be a key of ``TASK_DATAPIPELINES``
    cfg: CfgNode
        node name: data. Reads the ``sampler``, ``transformer``, ``target``
        and ``datapipeline`` sub-nodes
    seed: int
        seed for rng initialization, forwarded to the sampler and
        transformer builders

    Returns
    -------
    the configured datapipeline instance
    """
    assert task in TASK_DATAPIPELINES, "invalid task name"
    datapipeline_classes = TASK_DATAPIPELINES[task]

    sampler = build_sampler(task, cfg.sampler, seed=seed)
    transformers = build_transformer(task, cfg.transformer, seed=seed)
    target = build_target(task, cfg.target)

    # Processing stages: every transformer first, the target generator last.
    stages = [*transformers, target]

    dp_cfg = cfg.datapipeline
    dp_name = dp_cfg.name
    datapipeline = datapipeline_classes[dp_name](sampler, stages)

    defaults = datapipeline.get_hps()
    datapipeline.set_hps(merge_cfg_into_hps(dp_cfg[dp_name], defaults))
    datapipeline.update_params()

    return datapipeline
예제 #6
0
def build(task: str, cfg: CfgNode) -> DatasetBase:
    r"""
    Build the data filter registered under ``cfg.name`` for ``task``.

    Arguments
    ---------
    task: str
        task name; must be a key of ``TASK_FILTERS``
    cfg: CfgNode
        filter configuration; ``cfg.name`` selects the filter class and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    Returns
    -------
    the configured filter instance (annotated as ``DatasetBase``)
    """
    assert task in TASK_FILTERS, "invalid task name"
    filter_classes = TASK_FILTERS[task]

    filter_name = cfg.name
    data_filter = filter_classes[filter_name]()
    defaults = data_filter.get_hps()
    data_filter.set_hps(merge_cfg_into_hps(cfg[filter_name], defaults))
    data_filter.update_params()
    return data_filter
예제 #7
0
def build(task: str, cfg: CfgNode, model: nn.Module) -> OptimizerBase:
    r"""
    Build the optimizer selected by ``cfg.name``.

    Arguments
    ---------
    task: str
        task name (track|vos); not used by this builder
    cfg: CfgNode
        node name: optim. ``cfg.name`` selects the class in ``OPTIMIZERS``;
        the whole node is also passed to that class's constructor, and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides
    model: nn.Module
        model passed to the optimizer constructor

    Returns
    -------
    OptimizerBase
        the configured optimizer instance
    """
    optimizer_name = cfg.name
    optimizer = OPTIMIZERS[optimizer_name](cfg, model)

    defaults = optimizer.get_hps()
    optimizer.set_hps(merge_cfg_into_hps(cfg[optimizer_name], defaults))
    optimizer.update_params()
    return optimizer
예제 #8
0
def build_sat_model(task: str,
                    cfg: CfgNode,
                    gml_extractor=None,
                    joint_encoder=None,
                    decoder=None,
                    loss: ModuleBase = None):
    r"""
    Builder function for SAT.

    Arguments
    ---------
    task: str
        builder task name; only "vos" is supported
    cfg: CfgNode
        builder configuration; ``cfg.name`` selects the model class and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides
    gml_extractor: torch.nn.Module
        feature extractor for global modeling loop
    joint_encoder: torch.nn.Module
        joint encoder
    decoder: torch.nn.Module
        decoder for SAT
    loss: torch.nn.Module
        criterion module used by task module (for training). None in case other than training.

    Returns
    -------
    torch.nn.Module
        task module built by builder

    Raises
    ------
    SystemExit
        with code -1 when ``task`` is not "vos" (an error is logged first)
    """
    if task != "vos":
        logger.error("sat model builder could not build task {}".format(task))
        # ``sys.exit`` instead of the interactive-only ``site`` helper
        # ``exit``; both raise SystemExit(-1).
        import sys
        sys.exit(-1)

    task_modules = TASK_TASKMODELS[task]
    name = cfg.name
    # e.g. SatVOS
    task_module = task_modules[name](gml_extractor, joint_encoder, decoder,
                                     loss)
    hps = task_module.get_hps()
    hps = merge_cfg_into_hps(cfg[name], hps)
    task_module.set_hps(hps)
    task_module.update_params()
    return task_module
예제 #9
0
def build(
        task: str,
        cfg: CfgNode,
        model: ModuleBase = None,
        segmenter: ModuleBase = None,
        tracker: ModuleBase = None,
):
    r"""
    Builder function.

    Arguments
    ---------
    task: str
        task name; must be a key of ``PIPELINES`` and one of 'track' / 'vos'
    cfg: CfgNode
        builder configuration; ``cfg.name`` selects the pipeline class and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides
    model: ModuleBase
        model instance for siamfcpp (used when task == 'track')
    segmenter: ModuleBase
        segmenter instance (used when task == 'vos')
    tracker: ModuleBase
        tracker instance (used when task == 'vos')

    Returns
    -------
    torch.nn.Module
        module built by builder

    Raises
    ------
    ValueError
        if ``task`` is registered in ``PIPELINES`` but is neither 'track'
        nor 'vos' (the constructor arguments for it are unknown here)
    """
    assert task in PIPELINES, "no pipeline for task {}".format(task)
    pipelines = PIPELINES[task]
    pipeline_name = cfg.name

    if task == 'track':
        pipeline = pipelines[pipeline_name](model)
    elif task == 'vos':
        pipeline = pipelines[pipeline_name](segmenter, tracker)
    else:
        # Without this branch ``pipeline`` stays unbound and the code below
        # crashes with an opaque UnboundLocalError.
        raise ValueError(
            "pipeline builder cannot construct pipelines for task {}".format(
                task))

    hps = pipeline.get_hps()
    hps = merge_cfg_into_hps(cfg[pipeline_name], hps)
    pipeline.set_hps(hps)
    pipeline.update_params()

    return pipeline
예제 #10
0
def build(task: str,
          cfg: CfgNode,
          backbone: ModuleBase,
          head: ModuleBase,
          loss: ModuleBase = None):
    r"""
    Builder function.

    Arguments
    ---------
    task: str
        builder task name; only "track" is currently supported
    cfg: CfgNode
        builder configuration; ``cfg.name`` selects the model class and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides
    backbone: torch.nn.Module
        backbone used by task module.
    head: torch.nn.Module
        head network used by task module.
    loss: torch.nn.Module
        criterion module used by task module (for training). None in case other than training.

    Returns
    -------
    torch.nn.Module
        task module built by builder

    Raises
    ------
    SystemExit
        with code -1 when ``task`` is not registered in ``TASK_TASKMODELS``
        or is not "track" (an error is logged first)
    """
    # ``sys.exit`` instead of the interactive-only ``site`` helper ``exit``;
    # both raise SystemExit(-1).
    import sys

    if task in TASK_TASKMODELS:
        task_modules = TASK_TASKMODELS[task]
    else:
        logger.error("no task model for task {}".format(task))
        sys.exit(-1)

    if task != "track":
        logger.error("task model {} is not completed".format(task))
        sys.exit(-1)

    name = cfg.name
    task_module = task_modules[name](backbone, head, loss)
    hps = task_module.get_hps()
    hps = merge_cfg_into_hps(cfg[name], hps)
    task_module.set_hps(hps)
    task_module.update_params()
    return task_module
예제 #11
0
def build(task: str, cfg: CfgNode) -> TargetBase:
    r"""
    Build the target generator registered under ``cfg.name`` for ``task``.

    Arguments
    ---------
    task: str
        task name; must be a key of ``TASK_TARGETS``
    cfg: CfgNode
        node name: target. ``cfg.name`` selects the class and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    Returns
    -------
    TargetBase
        the configured target instance
    """
    assert task in TASK_TARGETS, "invalid task name"
    target_classes = TASK_TARGETS[task]

    target_name = cfg.name
    target = target_classes[target_name]()
    defaults = target.get_hps()
    target.set_hps(merge_cfg_into_hps(cfg[target_name], defaults))
    target.update_params()
    return target
예제 #12
0
def build(task: str, cfg: CfgNode) -> GradModifierBase:
    r"""
    Build the gradient modifier selected by ``cfg.name``.

    Arguments
    ---------
    task: str
        task name (track|vos); not used by this builder
    cfg: CfgNode
        grad-modifier configuration; ``cfg.name`` selects the class in
        ``GRAD_MODIFIERS`` and ``cfg[cfg.name]`` holds its hyper-parameter
        overrides

    Returns
    -------
    GradModifierBase
        the configured grad modifier instance
    """
    modifier_name = cfg.name
    modifier = GRAD_MODIFIERS[modifier_name]()

    defaults = modifier.get_hps()
    modifier.set_hps(merge_cfg_into_hps(cfg[modifier_name], defaults))
    modifier.update_params()
    return modifier
예제 #13
0
def build(task: str, cfg: CfgNode, basemodel=None):
    r"""
    Builder function.

    Arguments
    ---------
    task: str
        builder task name (track|vos)
    cfg: CfgNode
        builder configuration; ``cfg.name`` selects the backbone class and
        ``cfg[cfg.name]`` holds its hyper-parameter overrides

    basemodel:
        wrap backbone into encoder if not None (passed to the backbone
        constructor)

    Returns
    -------
    torch.nn.Module
        module built by builder

    Raises
    ------
    SystemExit
        with code -1 when no backbones are registered for ``task`` (an error
        is logged first)
    AssertionError
        when ``cfg.name`` is not registered for ``task``
    """
    if task in TASK_BACKBONES:
        modules = TASK_BACKBONES[task]
    else:
        logger.error("no backbone for task {}".format(task))
        # ``sys.exit`` instead of the interactive-only ``site`` helper
        # ``exit``; both raise SystemExit(-1).
        import sys
        sys.exit(-1)

    name = cfg.name
    assert name in modules, "backbone {} not registered for {}!".format(
        name, task)

    # Explicit ``is not None`` (the documented contract): a truthiness test
    # would silently drop a basemodel that happens to be falsy, e.g. an
    # empty container module.
    if basemodel is not None:
        module = modules[name](basemodel)
    else:
        module = modules[name]()

    hps = module.get_hps()
    hps = merge_cfg_into_hps(cfg[name], hps)
    module.set_hps(hps)
    module.update_params()
    return module
예제 #14
0
def build(task: str, cfg: CfgNode) -> DatasetBase:
    r"""
    Build and configure every dataset listed in ``cfg.names``.

    Arguments
    ---------
    task: str
        task name (track|vos); must be a key of ``TASK_DATASETS``
    cfg: CfgNode
        node name: dataset. ``cfg.names`` lists the datasets to build and
        ``cfg[<name>]`` holds each one's hyper-parameter overrides

    Returns
    -------
    list
        configured dataset instances in the order of ``cfg.names``
        (a list, despite the ``DatasetBase`` return annotation)
    """
    assert task in TASK_DATASETS, "invalid task name"
    dataset_classes = TASK_DATASETS[task]

    def _make_dataset(dataset_name):
        dataset = dataset_classes[dataset_name]()
        defaults = dataset.get_hps()
        dataset.set_hps(merge_cfg_into_hps(cfg[dataset_name], defaults))
        dataset.update_params()
        return dataset

    return [_make_dataset(dataset_name) for dataset_name in cfg.names]