Example 1
 def _init_dataloader(self, mode, loader=None, transforms=None):
     """Init dataloader."""
     if loader is not None:
         return loader
     if mode == "train" and self.hps is not None and self.hps.get(
             "dataset") is not None:
         if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
             dataset_cls = ClassFactory.get_cls(
                 ClassType.DATASET,
                 self.hps.get("dataset").get('type'))
         else:
             dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
         dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
     elif self.hps:
         if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
             dataset_cls = ClassFactory.get_cls(
                 ClassType.DATASET,
                 self.hps.get("dataset").get('type'))
             dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
         else:
             dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
             dataset = dataset_cls(mode=mode)
     else:
         dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
         dataset = dataset_cls(mode=mode)
     if transforms is not None:
         dataset.transforms = transforms
     if self.distributed and mode == "train":
         dataset.set_distributed(self._world_size, self._rank_id)
     # adapt the dataset to specific backend
     dataloader = Adapter(dataset).loader
     return dataloader
Example 2
def from_module(module):
    """From Model."""
    name = module.__class__.__name__
    if ClassFactory.is_exists(ClassType.NETWORK, name):
        module_cls = ClassFactory.get_cls(ClassType.NETWORK, name)
        if hasattr(module_cls, "from_module"):
            return module_cls.from_module(module)
    return module
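
For context, a minimal sketch of the registration that from_module relies on; the decorator form and the MyBlock class are illustrative assumptions, not part of the excerpt above:

from vega.common.class_factory import ClassFactory, ClassType

@ClassFactory.register(ClassType.NETWORK)
class MyBlock:
    """Hypothetical network registered under its own class name."""

    @classmethod
    def from_module(cls, module):
        # rebuild the registered wrapper from a plain module instance
        return cls()

With that registration in place, from_module(MyBlock()) finds "MyBlock" via ClassFactory.is_exists, resolves it with get_cls, and delegates to MyBlock.from_module.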
Example 3
def register_datasets(backend):
    """Import and register datasets automatically."""
    if backend == "pytorch":
        from . import pytorch
    elif backend == "tensorflow":
        from . import tensorflow
    elif backend == "mindspore":
        import mindspore.dataset
        from . import mindspore
    ClassFactory.lazy_register("vega.datasets.common", {"imagenet": ["Imagenet"]})
    from . import common
    from . import transforms
Example 4
 def _train_single_model(self,
                         model_desc=None,
                         model_id=None,
                         weights_file=None):
     cls_trainer = ClassFactory.get_cls(ClassType.TRAINER,
                                        PipeStepConfig.trainer.type)
     step_name = self.task.step_name
     if model_desc is not None:
         sample = dict(worker_id=model_id,
                       desc=model_desc,
                       step_name=step_name)
         record = ReportRecord().load_dict(sample)
         logging.debug("update record=%s", str(record))
         trainer = cls_trainer(model_desc=model_desc,
                               id=model_id,
                               pretrained_model_file=weights_file)
     else:
         trainer = cls_trainer(None, 0)
         record = ReportRecord(trainer.step_name,
                               trainer.worker_id,
                               desc=trainer.model_desc)
     ReportClient().update(**record.to_dict())
     # resume training
     if vega.is_torch_backend() and General._resume:
         trainer.load_checkpoint = True
         trainer._resume_training = True
     if self._distributed_training:
         self._do_distributed_fully_train(trainer)
     else:
         self._do_single_fully_train(trainer)
Example 5
    def _do_hccl_fully_train(self, trainer):
        origin_worker_id = trainer.worker_id
        model_desc = trainer.model_desc
        del trainer

        # restore the HCCL rank settings saved earlier under the ORIGIN_ prefix
        os.environ['RANK_SIZE'] = os.environ['ORIGIN_RANK_SIZE']
        os.environ['RANK_TABLE_FILE'] = os.environ['ORIGIN_RANK_TABLE_FILE']
        origin_parallel_fully_train = General.parallel_fully_train
        origin_parallel = General._parallel
        General.parallel_fully_train = True
        General.dft = True
        General._parallel = True

        cls_trainer = ClassFactory.get_cls(ClassType.TRAINER,
                                           PipeStepConfig.trainer.type)
        self.master = create_master()
        workers_num = int(os.environ['RANK_SIZE'])
        for i in range(workers_num):
            worker_id = "{}-{}".format(origin_worker_id, i)
            trainer = cls_trainer(model_desc, id=worker_id)
            evaluator = self._get_evaluator(
                worker_id) if os.environ['DEVICE_ID'] == "0" else None
            self.master.run(trainer, evaluator)

        self.master.join()
        self.master.close()
        General.parallel_fully_train = origin_parallel_fully_train
        General.dft = False
        General._parallel = origin_parallel
Example 6
 def _get_evaluator(self, worker_id):
     if not PipeStepConfig.evaluator_enable:
         return None
     cls_evaluator = ClassFactory.get_cls('evaluator', "Evaluator")
     evaluator = cls_evaluator({
         "step_name": self.task.step_name,
         "worker_id": worker_id
     })
     return evaluator
Example 7
 def _train_single_model(self, model_desc, model_id, hps, multi_task):
     cls_trainer = ClassFactory.get_cls(ClassType.TRAINER, PipeStepConfig.trainer.type)
     step_name = self.task.step_name
     sample = dict(worker_id=model_id, desc=model_desc, step_name=step_name)
     record = ReportRecord().load_dict(sample)
     logging.debug("update record=%s", str(record))
     trainer = cls_trainer(model_desc=model_desc, id=model_id, hps=hps, multi_task=multi_task)
     ReportClient().update(**record.to_dict())
     if self._distributed_training:
         self._do_distributed_fully_train(trainer)
     else:
         self._do_single_fully_train(trainer)
Example 8
 def to_model(self):
     """Transform a NetworkDesc to a special model."""
     logging.debug("Start to Create a Network.")
     module = ClassFactory.get_cls(ClassType.NETWORK, "Module")
     model = module.from_desc(self._desc)
     if not model:
         raise Exception("Failed to create model, model desc={}".format(
             self._desc))
     model.desc = self._desc
     if hasattr(model, '_apply_names'):
         model._apply_names()
     return model
Example 9
 def __new__(cls,
             model=None,
             id=None,
             hps=None,
             load_ckpt_flag=False,
             model_desc=None,
             **kwargs):
     """Create Trainer clss."""
     if vega.is_torch_backend():
         trainer_cls = ClassFactory.get_cls(ClassType.TRAINER,
                                            "TrainerTorch")
     elif vega.is_tf_backend():
         trainer_cls = ClassFactory.get_cls(ClassType.TRAINER, "TrainerTf")
     else:
         trainer_cls = ClassFactory.get_cls(ClassType.TRAINER, "TrainerMs")
     return trainer_cls(model=model,
                        id=id,
                        hps=hps,
                        load_ckpt_flag=load_ckpt_flag,
                        model_desc=model_desc,
                        **kwargs)
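
The branch taken in __new__ is fixed by the backend selected at startup; a hedged usage sketch, assuming vega.set_backend is the call that makes is_torch_backend() true (device arguments omitted):

import vega

vega.set_backend("pytorch")
# any Trainer(...) constructed afterwards is dispatched to TrainerTorch above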
Example 10
 def __init__(self,
              search_space,
              num_samples,
              max_epochs,
              repeat_times,
              min_epochs=1,
              eta=3,
              multi_obj=False,
              random_samples=None,
              prob_crossover=0.6,
              prob_mutatation=0.2,
              tuner="GP"):
     """Init BOHB."""
     super().__init__(search_space, num_samples, max_epochs, min_epochs,
                      eta)
     # init all the configs
     self.repeat_times = repeat_times
     self.max_epochs = max_epochs
     self.iter_list, self.min_epoch_list = self._get_total_iters(
         num_samples, max_epochs, self.repeat_times, min_epochs, eta)
     self.additional_samples = self._get_additional_samples(eta)
     if random_samples is not None:
         self.random_samples = random_samples
     else:
         self.random_samples = max(self.iter_list[0][0], 2)
     self.tuner_name = "GA" if multi_obj else tuner
     logger.info(
         f"bohb info, iter list: {self.iter_list}, min epoch list: {self.min_epoch_list}, "
         f"addition samples: {self.additional_samples}, tuner: {self.tuner_name}, "
         f"random samples: {self.random_samples}")
     # create sha list
     self.sha_list = self._create_sha_list(search_space, self.iter_list,
                                           self.min_epoch_list,
                                           self.repeat_times)
     # create tuner
     if multi_obj:
         self.tuner = GeneticAlgorithm(search_space,
                                       random_samples=self.random_samples,
                                       prob_crossover=prob_crossover,
                                       prob_mutatation=prob_mutatation)
     elif self.tuner_name == "hebo":
         self.tuner = ClassFactory.get_cls(ClassType.SEARCH_ALGORITHM,
                                           "HeboAdaptor")(search_space)
     else:
         self.tuner = TunerBuilder(search_space, tuner=tuner)
Example 11
 def __init__(self, desc=None):
     """Init SearchSpace."""
     super(SearchSpace, self).__init__()
     if desc is None:
         desc = SearchSpaceConfig().to_dict()
         if desc.type is not None:
             desc = ClassFactory.get_cls(ClassType.SEARCHSPACE,
                                         desc.type).get_space(desc)
     for name, item in desc.items():
         self.__setattr__(name, item)
         self.__setitem__(name, item)
     self._params = OrderedDict()
     self._condition_dict = OrderedDict()
     self._forbidden_list = []
     self._hp_count = 0
     self._dag = DAG()
     if desc is not None:
         self.form_desc(desc)
Example 12
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

"""Lazy import necks."""

from vega.common.class_factory import ClassFactory


ClassFactory.lazy_register("vega.networks.pytorch.necks", {
    "ffm": ["network:FeatureFusionModule"],
    "fpn": ["FPN"]
})
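
Each value follows the pattern "module_name": ["alias:ClassName", ...], where a prefix such as network: selects the class-type label the name is filed under. A hedged retrieval sketch; the first get_cls call is what triggers the deferred import:

from vega.common.class_factory import ClassFactory, ClassType

# importing vega.networks.pytorch.necks.ffm happens here, on first lookup
ffm_cls = ClassFactory.get_cls(ClassType.NETWORK, "FeatureFusionModule")
model = ffm_cls()  # constructor arguments omitted; the real signature is not shown here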
Example 13
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import mindspore network."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register(
    "vega.networks.mindspore", {
        "dnet": ["DNet"],
        "super_network":
        ["DartsNetwork", "CARSDartsNetwork", "GDASDartsNetwork"],
        "backbones.load_official_model": ["OffcialModelLoader"],
        "backbones.resnet_ms": ["ResNetMs"],
        "losses.mix_auxiliary_loss": ["MixAuxiliaryLoss"],
    })
Example 14
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import data augmentation algorithms."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register(
    "vega.algorithms.data_augmentation", {
        "pba_hpo": ["PBAHpo"],
        "pba_trainer_callback": ["PbaTrainerCallback"],
        "cyclesr": ["CyclesrTrainerCallback"],
    })
Example 15
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Import and register evaluator automatically."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register(
    "vega.evaluator", {
        "device_evaluator": ["DeviceEvaluator"],
        "host_evaluator": ["HostEvaluator"],
        "evaluator": ["Evaluator"],
    })
Example 16
def transform_architecture(model, pretrained_model_file=None):
    """Transform architecture."""
    if not hasattr(model, "_arch_params") or not model._arch_params or \
            PipeStepConfig.pipe_step.get("type") == "TrainPipeStep":
        return model
    model._apply_names()
    logging.info(
        "Start to transform architecture, model arch params type: {}".format(
            model._arch_params_type))
    ConnectionsArchParamsCombiner().combine(model)
    if vega.is_ms_backend():
        from mindspore.train.serialization import load_checkpoint
        changed_name_list = []
        mask_weight_list = []
        for name, module in model.named_modules():
            if not ClassFactory.is_exists(model._arch_params_type,
                                          module.model_name):
                continue
            changed_name_list, mask_weight_list = decode_fn_ms(
                module, changed_name_list, mask_weight_list)
        assert len(changed_name_list) == len(mask_weight_list)
        # change model and rebuild
        model_desc = model.desc
        root_name = [
            name for name in list(model_desc.keys())
            if name not in ('type', '_arch_params')
        ]
        for changed_name, mask in zip(changed_name_list, mask_weight_list):
            name = changed_name.split('.')
            name[0] = root_name[int(name[0])]
            assert len(name) <= 6
            if len(name) == 6:
                model_desc[name[0]][name[1]][name[2]][name[3]][name[4]][name[5]] = sum(mask)
            elif len(name) == 5:
                model_desc[name[0]][name[1]][name[2]][name[3]][name[4]] = sum(mask)
            elif len(name) == 4:
                model_desc[name[0]][name[1]][name[2]][name[3]] = sum(mask)
            elif len(name) == 3:
                model_desc[name[0]][name[1]][name[2]] = sum(mask)
            elif len(name) == 2:
                model_desc[name[0]][name[1]] = sum(mask)
        network = NetworkDesc(model_desc)
        model = network.to_model()
        if '_arch_params' in model_desc:
            model_desc.pop('_arch_params')
        model.desc = model_desc
        # change weight
        if hasattr(model, "pretrained"):
            pretrained_weight = model.pretrained(pretrained_model_file)
            load_checkpoint(pretrained_weight, net=model)
            os.remove(pretrained_weight)

    else:
        for name, module in model.named_modules():
            if not ClassFactory.is_exists(model._arch_params_type,
                                          module.model_name):
                continue
            arch_cls = ClassFactory.get_cls(model._arch_params_type,
                                            module.model_name)

            decode_fn(module, arch_cls)
            module.register_forward_pre_hook(arch_cls.fit_weights)
            module.register_forward_hook(module.clear_module_arch_params)
    return model
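
In the torch branch above, fit_weights is installed as a forward pre-hook; a generic sketch of that PyTorch mechanism (the hook body is a placeholder, not vega's fit_weights):

import torch
import torch.nn as nn

def fit_weights(module, inputs):
    # runs before every forward call; vega uses this slot to fit the module's
    # weights to the decoded architecture parameters (placeholder body here)
    return None

layer = nn.Linear(4, 2)
handle = layer.register_forward_pre_hook(fit_weights)
layer(torch.randn(1, 4))  # the pre-hook fires before forward
handle.remove()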
Example 17
from .search_algorithm import SearchAlgorithm
from .ea_conf import EAConfig
from .pareto_front_conf import ParetoFrontConfig
from .pareto_front import ParetoFront
from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register("vega.core.search_algs", {
    "ps_differential": ["DifferentialAlgorithm"],
})
Example 18
 def __init__(self, student, teacher, header=None):
     super(TinyBertDistil, self).__init__(student, teacher)
     self.loss_mse = ops.MSELoss()
     self.head = ClassFactory.get_instance(ClassType.NETWORK, header)
Example 19
from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register("vega.datasets.pytorch", {
    "coco_transforms": ["CocoCategoriesTransform", "PolysToMaskTransform"],
})
Example 20
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import dataset."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register(
    "vega.datasets.common",
    {
        "avazu": ["AvazuDataset"],
        "cifar10": ["Cifar10"],
        "cifar100": ["Cifar100"],
        "div2k": ["DIV2K"],
        "cls_ds": ["ClassificationDataset"],
        "cityscapes": ["Cityscapes"],
        "div2k_unpair": ["Div2kUnpair"],
        "fmnist": ["FashionMnist"],
        # "imagenet": ["Imagenet"],
        "mnist": ["Mnist"],
        "sr_datasets": ["Set5", "Set14", "BSDS100"],
        "auto_lane_datasets": ["AutoLaneDataset"],
        "coco": ["CocoDataset", "DetectionDataset"],
        "glue": ["GlueDataset"],
        "spatiotemporal": ["SpatiotemporalDataset"],
        "reds": ["REDS"],
        "nasbench": ["Nasbench"],
    })
Example 21
 def __init__(self, student, teacher):
     super().__init__()
     self.student = ClassFactory.get_instance(ClassType.NETWORK, student)
     self.teacher = ClassFactory.get_instance(ClassType.NETWORK, teacher)
     self.teacher.freeze('teacher')
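
Unlike get_cls, get_instance constructs the object directly from a description, so student and teacher here are expected to be descs whose type field names a registered network. A hedged sketch; the type names and the enclosing class name are assumptions, since neither appears in the excerpt:

student_desc = {"type": "ResNet", "depth": 18}  # hypothetical registered network
teacher_desc = {"type": "ResNet", "depth": 50}
distil = KnowledgeDistillation(student=student_desc, teacher=teacher_desc)  # class name assumed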
Example 22
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Import and register trainer automatically."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register(
    "vega.trainer", {
        "trainer_torch": ["TrainerTorch"],
        "trainer_tf": ["TrainerTf"],
        "trainer_ms": ["TrainerMs"],
        "trainer": ["Trainer"],
        "script_runner": ["ScriptRunner"],
    })
Example 23
"""Lazy import tensorflow networks."""

from .network import Sequential
from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register(
    "vega.networks.tensorflow",
    {
        "resnet_tf": ["ResNetTF", 'ResNetSlim'],
        # backbones
        "backbones.resnet_det": ["ResNetDet"],
        # customs
        "customs.edvr.edvr": ["EDVR"],
        "customs.gcn_regressor": ["GCNRegressor"],
        # detectors
        "detectors.faster_rcnn_trainer_callback":
        ["FasterRCNNTrainerCallback"],
        "detectors.faster_rcnn": ["FasterRCNN"],
        "detectors.tf_optimizer": ["TFOptimizer"],
        # losses
        "losses.cross_entropy_loss": ["CrossEntropyLoss"],
        "losses.mix_auxiliary_loss": ["MixAuxiliaryLoss"],
        "losses.charbonnier": ["CharbonnierLoss"],
        # necks
        "necks.mask_rcnn_box": ["MaskRCNNBox"],
    })

ClassFactory.lazy_register(
    "vega.networks.tensorflow.utils", {
        "anchor_utils.anchor_generator": ["AnchorGenerator"],
        "hyperparams.initializer": ["Initializer"],
Example 24
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

"""Import and register modules automatically."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register("vega.modules", {
    "module": ["network:Module"],
})


def register_modules():
    """Import and register modules automatically."""
    from . import blocks
    from . import cells
    from . import connections
    from . import operators
    from . import preprocess
    from . import loss
    from . import getters
    from . import necks
    from . import backbones
    from . import distillation
Example 25
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import custom network."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register(
    "vega.networks.pytorch.customs", {
        "nago": ["network:NAGO"],
        "deepfm": ["network:DeepFactorizationMachineModel"],
        "autogate": ["network:AutoGateModel"],
        "autogroup": ["network:AutoGroupModel"],
        "bisenet": ["network:BiSeNet"],
        "modnas": ["network:ModNasArchSpace"],
        "mobilenetv2": ["network:MobileNetV2"],
        "gcn_regressor": ["network:GCNRegressor"],
    })
Example 26
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import compression algorithms."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register(
    "vega.algorithms.compression", {
        "prune_ea":
        ["PruneCodec", "PruneEA", "PruneSearchSpace", "PruneTrainerCallback"],
        "prune_ea_mobilenet":
        ["PruneMobilenetCodec", "PruneMobilenetTrainerCallback"],
        "quant_ea": ["QuantCodec", "QuantEA", "QuantTrainerCallback"],
    })
Example 27
from .callback import Callback
from .callback_list import CallbackList
from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register(
    "vega.trainer.callbacks", {
        "metrics_evaluator": ["trainer.callback:MetricsEvaluator"],
        "progress_logger": ["trainer.callback:ProgressLogger"],
        "performance_saver": ["trainer.callback:PerformanceSaver"],
        "lr_scheduler": ["trainer.callback:LearningRateScheduler"],
        "model_builder": ["trainer.callback:ModelBuilder"],
        "model_statistics": ["trainer.callback:ModelStatistics"],
        "model_checkpoint": ["trainer.callback:ModelCheckpoint"],
        "report_callback": ["trainer.callback:ReportCallback"],
        "runtime_callback": ["trainer.callback:RuntimeCallback"],
        "detection_progress_logger":
        ["trainer.callback:DetectionProgressLogger"],
        "detection_metrics_evaluator":
        ["trainer.callback:DetectionMetricsEvaluator"],
        "visual_callback": ["trainer.callback:VisualCallBack"],
        "model_tuner": ["trainer.callback:ModelTuner"],
        "timm_trainer_callback": ["trainer.callback:TimmTrainerCallback"],
    })
Example 28
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import pytorch blocks."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register("vega.networks.pytorch.blocks", {
    "block": ["Block"],
    "conv_ws": ["ConvWS2d"],
    "stem": ['PreTwoStem'],
})
Example 29
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import head network."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register(
    "vega.networks.pytorch.heads", {
        "auto_lane_head": ["network:AutoLaneHead"],
        "auxiliary_head": ["network:AuxiliaryHead"],
    })
Example 30
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import detector."""

from vega.common.class_factory import ClassFactory

ClassFactory.lazy_register("vega.networks.pytorch.detectors", {
    "auto_lane_detector": ["AutoLaneDetector"],
})