Example #1
0
 def _init_dataloader(self, mode, loader=None, transforms=None):
     """Init dataloader."""
     if loader is not None:
         return loader
     if mode == "train" and self.hps is not None and self.hps.get(
             "dataset") is not None:
         if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
             dataset_cls = ClassFactory.get_cls(
                 ClassType.DATASET,
                 self.hps.get("dataset").get('type'))
         else:
             dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
         dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
     elif self.hps:
         if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
             dataset_cls = ClassFactory.get_cls(
                 ClassType.DATASET,
                 self.hps.get("dataset").get('type'))
             dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
         else:
             dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
             dataset = dataset_cls(mode=mode)
     else:
         dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
         dataset = dataset_cls(mode=mode)
     if transforms is not None:
         dataset.transforms = transforms
     if self.distributed and mode == "train":
         dataset.set_distributed(self._world_size, self._rank_id)
     # adapt the dataset to specific backend
     dataloader = Adapter(dataset).loader
     return dataloader
Example #2
0
    def _do_hccl_fully_train(self, trainer):
        """Re-run fully training with one HCCL worker per rank.

        The incoming ``trainer`` only supplies ``worker_id`` and
        ``model_desc``. Its local reference is then dropped, and ``RANK_SIZE``
        fresh trainers with ids ``"<worker_id>-<i>"`` are submitted to a newly
        created master, each with its own evaluator.

        ``General.parallel_fully_train``, ``General.dft`` and
        ``General._parallel`` are switched on for the run. They are restored
        to their previous values afterwards, even if the run raises.

        :param trainer: trainer whose ``worker_id``/``model_desc`` are reused.
        :raises KeyError: if the ``RANK_SIZE`` environment variable is unset.
        :raises ValueError: if ``RANK_SIZE`` is not an integer.
        """
        origin_worker_id = trainer.worker_id
        model_desc = trainer.model_desc
        del trainer
        # Read the worker count before touching global state, so a missing or
        # malformed RANK_SIZE fails without leaving General modified.
        workers_num = int(os.environ['RANK_SIZE'])

        origin_parallel_fully_train = General.parallel_fully_train
        # getattr: the original code only ever wrote this flag, so it may not
        # exist yet; default to False, the value the original reset it to.
        origin_dft = getattr(General, 'dft', False)
        origin_parallel = General._parallel
        General.parallel_fully_train = True
        General.dft = True
        General._parallel = True
        try:
            cls_trainer = ClassFactory.get_cls('trainer')
            self.master = create_master()
            for i in range(workers_num):
                worker_id = "{}-{}".format(origin_worker_id, i)
                trainer = cls_trainer(model_desc, id=worker_id)
                evaluator = self._get_evaluator(worker_id)
                self.master.run(trainer, evaluator)
            self.master.join()
            self.master.shutdown()
        finally:
            # Always restore the process-wide flags, success or failure.
            General.parallel_fully_train = origin_parallel_fully_train
            General.dft = origin_dft
            General._parallel = origin_parallel
Example #3
0
 def _train_single_model(self,
                         model_desc=None,
                         model_id=None,
                         weights_file=None):
     """Fully train one model and register it with the report service.

     With a ``model_desc``, a trainer is created for it, using ``model_id`` as
     its id and ``weights_file`` as its pretrained model file. The broadcast
     record is then built from these arguments. Without one, a default trainer
     ``cls_trainer(None, 0)`` is created, and the record is built from the
     trainer's own ``step_name``/``worker_id``/``model_desc``.

     The record is broadcast and the trainer is watched by ``ReportServer``.
     Resume flags are set on the torch backend when ``General._resume`` is on.
     Training then runs through the distributed or the single fully-train
     path, depending on ``self._distributed_training``.

     :param model_desc: network description, or None for the default trainer.
     :param model_id: worker id used when ``model_desc`` is given.
     :param weights_file: pretrained model file, used only with ``model_desc``.
     """
     trainer_cls = ClassFactory.get_cls('trainer')
     current_step = self.task.step_name
     if model_desc is None:
         trainer = trainer_cls(None, 0)
         record = ReportRecord(trainer.step_name,
                               trainer.worker_id,
                               desc=trainer.model_desc)
     else:
         record = ReportRecord().load_dict(
             dict(worker_id=model_id, desc=model_desc, step_name=current_step))
         logging.debug("Broadcast Record=%s", str(record))
         trainer = trainer_cls(model_desc=model_desc,
                               id=model_id,
                               pretrained_model_file=weights_file)
     ReportClient.broadcast(record)
     ReportServer.add_watched_var(trainer.step_name, trainer.worker_id)

     # resume training: ask the trainer to load its checkpoint (torch only)
     if vega.is_torch_backend() and General._resume:
         trainer.load_checkpoint = True
         trainer._resume_training = True

     if self._distributed_training:
         run_fully_train = self._do_distributed_fully_train
     else:
         run_fully_train = self._do_single_fully_train
     run_fully_train(trainer)
Example #4
0
    def create_model(self, model_info):
        """Create the Deep-Q network and wrap it in a TensorFlow ``Trainer``.

        ``model_info['model_config']`` must provide these keys:

        - ``zeus_model_name``: the network registered under ``ClassType.NETWORK``.
        - ``loss``
        - ``optim``
        - ``lr``

        Note: ``loss``, ``optim`` and ``lr`` are written onto the module-level
        ``LossConfig`` / ``OptimConfig`` objects. Every later reader of those
        objects therefore sees them, not just this model.

        :param model_info: dict holding a ``model_config`` dict.
        :return: the ``Trainer`` built with ``lazy_build=False`` on GPU.
        """
        cfg = model_info['model_config']
        network_cls = ClassFactory.get_cls(ClassType.NETWORK,
                                           cfg['zeus_model_name'])
        network = network_cls(state_dim=self.state_dim,
                              action_dim=self.action_dim)

        LossConfig.type = cfg['loss']
        OptimConfig.type = cfg['optim']
        OptimConfig.params.update({'lr': cfg['lr']})

        def _io_spec(name, shape):
            # one float32 tensor description, as expected by loss_input
            return [{"name": name, "type": "float32", "shape": shape}]

        loss_input = {
            'inputs': _io_spec("input_state", self.state_dim),
            'labels': _io_spec("target_value", self.action_dim),
        }
        return Trainer(model=network,
                       backend='tensorflow',
                       device='GPU',
                       loss_input=loss_input,
                       lazy_build=False)
Example #5
0
 def _get_evaluator(self, worker_id):
     """Build the evaluator for ``worker_id``.

     :param worker_id: id of the worker the evaluator is created for.
     :return: an ``Evaluator`` configured with the current step name and
         ``worker_id``, or None when ``PipeStepConfig.evaluator_enable`` is
         falsy.
     """
     if PipeStepConfig.evaluator_enable:
         evaluator_cls = ClassFactory.get_cls('evaluator', "Evaluator")
         settings = dict(step_name=self.task.step_name, worker_id=worker_id)
         return evaluator_cls(settings)
     return None
Example #6
0
 def to_model(self):
     """Transform this NetworkDesc into a concrete model.

     The ``Module`` class registered under ``ClassType.NETWORK`` builds the
     model via ``from_desc(self._desc)``. Unless ``self.is_deformation`` is
     set, the desc is also attached to the model as ``model.desc``.

     :return: the created model.
     :raises Exception: if ``from_desc`` returns None.
     """
     logging.debug("Start to Create a Network.")
     module = ClassFactory.get_cls(ClassType.NETWORK, "Module")
     model = module.from_desc(self._desc)
     # Compare against None rather than relying on truthiness. Container
     # modules such as torch.nn.Sequential define __len__, so an empty but
     # valid model would otherwise be reported as a failure.
     if model is None:
         raise Exception("Failed to create model, model desc={}".format(
             self._desc))
     if not self.is_deformation:
         model.desc = self._desc
     return model
Example #7
0
 def __init__(self, desc=None):
     """Initialise the search space from ``desc``.

     When ``desc`` is omitted, it is taken from
     ``SearchSpaceConfig().to_dict()``. If that config names a ``type``, it is
     replaced by the result of that search-space class's ``get_space``.

     Every top-level entry of ``desc`` is exposed both as an attribute and as
     an item. Internal bookkeeping is then reset, and ``form_desc`` is applied
     to ``desc``.

     :param desc: search-space description mapping, or None to use the config.
     """
     super(SearchSpace, self).__init__()
     if desc is None:
         desc = SearchSpaceConfig().to_dict()
         space_type = desc.type
         if space_type is not None:
             space_cls = ClassFactory.get_cls(ClassType.SEARCHSPACE, space_type)
             desc = space_cls.get_space(desc)
     for key, value in desc.items():
         self.__setattr__(key, value)
         self.__setitem__(key, value)
     # NOTE(review): these are assigned after the loop above, so a desc key
     # with one of these names would be overwritten -- confirm that is fine.
     self._params = OrderedDict()
     self._condition_dict = OrderedDict()
     self._forbidden_list = []
     self._hp_count = 0
     self._dag = DAG()
     # desc is never None at this point: it was filled in above, and
     # desc.items() would already have raised otherwise.
     self.form_desc(desc)
Example #8
0
def register_datasets(backend):
    """Import the dataset modules for ``backend`` and register them.

    The backend-specific dataset package is imported, and ``Imagenet`` is
    lazily registered from the module matching the backend:

    - "zeus.datasets.tensorflow" for tensorflow.
    - "zeus.datasets.common" for pytorch and mindspore.

    Any other ``backend`` value registers no Imagenet. The shared ``common``
    and ``transforms`` modules are imported in every case.

    :param backend: expected to be "pytorch", "tensorflow" or "mindspore".
    """
    imagenet_module = None
    if backend == "pytorch":
        from . import pytorch
        imagenet_module = "zeus.datasets.common"
    elif backend == "tensorflow":
        from . import tensorflow
        imagenet_module = "zeus.datasets.tensorflow"
    elif backend == "mindspore":
        # presumably imported first so a missing MindSpore install fails early
        import mindspore.dataset
        from . import mindspore
        imagenet_module = "zeus.datasets.common"
    if imagenet_module is not None:
        ClassFactory.lazy_register(imagenet_module, {"imagenet": ["Imagenet"]})
    from . import common
    from . import transforms
Example #9
0
from .metrics import Metrics
from zeus.common.class_factory import ClassFactory

# Lazy registration table for the PyTorch metrics. Keys appear to be
# submodules of "zeus.metrics.pytorch"; values are the names each provides.
# The "trainer.metric:" prefix presumably names the ClassType a class is
# registered under -- confirm against ClassFactory.lazy_register.
ClassFactory.lazy_register(
    "zeus.metrics.pytorch", {
        "lane_metric": ["trainer.metric:LaneMetric"],
        "regression": ["trainer.metric:MSE", "trainer.metric:mse"],
        "detection_metric":
        ["trainer.metric:CocoMetric", "trainer.metric:coco"],
        "gan_metric": ["trainer.metric:GANMetric"],
        "classifier_metric": [
            "trainer.metric:accuracy", "trainer.metric:Accuracy",
            "trainer.metric:SklearnMetrics"
        ],
        "auc_metrics": ["trainer.metric:AUC", "trainer.metric:auc"],
        "segmentation_metric": ["trainer.metric:IoUMetric"],
        "sr_metric": ["trainer.metric:PSNR", "trainer.metric:SSIM"],
    })
Example #10
0
# -*- coding:utf-8 -*-
from .pipe_step import PipeStep
from .pipeline import Pipeline
from zeus.common.class_factory import ClassFactory

# Lazy registration table for the pipe steps of "vega.core.pipeline".
# Each key appears to be a submodule and each value the class(es) it defines.
ClassFactory.lazy_register(
    "vega.core.pipeline", {
        "search_pipe_step": ["SearchPipeStep"],
        "train_pipe_step": ["TrainPipeStep"],
        "benchmark_pipe_step": ["BenchmarkPipeStep"],
    })
Example #11
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import compression algorithms."""

from zeus.common.class_factory import ClassFactory

# Lazy registration table for the compression algorithms. Each key appears to
# be a submodule of "vega.algorithms.compression" (codec, search algorithm and
# trainer callback per algorithm); each value lists the classes it provides.
ClassFactory.lazy_register(
    "vega.algorithms.compression", {
        "prune_ea":
        ["PruneCodec", "PruneEA", "PruneSearchSpace", "PruneTrainerCallback"],
        "prune_ea_mobilenet":
        ["PruneMobilenetCodec", "PruneMobilenetTrainerCallback"],
        "quant_ea": ["QuantCodec", "QuantEA", "QuantTrainerCallback"],
    })
Example #12
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

"""Lazy import data augmentation algorithms."""

from zeus.common.class_factory import ClassFactory


# Lazy registration table for the data-augmentation algorithms. Each key
# appears to be a submodule of "vega.algorithms.data_augmentation"; each value
# lists the class(es) it provides.
ClassFactory.lazy_register("vega.algorithms.data_augmentation", {
    "pba_hpo": ["PBAHpo"],
    "pba_trainer_callback": ["PbaTrainerCallback"],
    "cyclesr": ["CyclesrTrainerCallback"],
})
Example #13
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import hpo algorithms."""

from zeus.common.class_factory import ClassFactory

# Lazy registration table for the HPO algorithms. Each key appears to be a
# submodule of "vega.algorithms.hpo"; each value is the search class it provides.
ClassFactory.lazy_register(
    "vega.algorithms.hpo", {
        "asha_hpo": ["AshaHpo"],
        "bo_hpo": ["BoHpo"],
        "bohb_hpo": ["BohbHpo"],
        "boss_hpo": ["BossHpo"],
        "random_hpo": ["RandomSearch"],
        "random_pareto_hpo": ["RandomParetoHpo"],
        "evolution_search": ["EvolutionAlgorithm"],
    })
Example #14
0
from .metrics import Metrics
from zeus.common.class_factory import ClassFactory


# Lazy registration table for the TensorFlow metrics. Keys appear to be
# submodules of "zeus.metrics.tensorflow"; the "trainer.metric:" prefix
# presumably names the ClassType used for registration -- confirm.
ClassFactory.lazy_register("zeus.metrics.tensorflow", {
    "segmentation_metric": ["trainer.metric:IoUMetric"],
    "classifier_metric": ["trainer.metric:accuracy"],
    "sr_metric": ["trainer.metric:PSNR", "trainer.metric:SSIM"],
    "forecast": ["trainer.metric:MSE", "trainer.metric:RMSE"],
})
Example #15
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import ops."""

from zeus.common.class_factory import ClassFactory

# Lazy registration table for the PyTorch ops. The "network:" prefix
# presumably registers these names under the network ClassType -- confirm
# against ClassFactory.lazy_register.
ClassFactory.lazy_register("zeus.networks.pytorch.ops", {
    "fmdunit": ["network:FMDUnit", "network:LinearScheduler"],
})
Example #16
0
from .metrics import Metrics
from zeus.common.class_factory import ClassFactory

# Lazy registration table for the MindSpore metrics. Keys appear to be
# submodules of "zeus.metrics.mindspore"; the "trainer.metric:" prefix
# presumably names the ClassType used for registration -- confirm.
ClassFactory.lazy_register(
    "zeus.metrics.mindspore", {
        "segmentation_metric": ["trainer.metric:IoUMetric"],
        "classifier_metric": ["trainer.metric:accuracy"],
        "sr_metric": ["trainer.metric:PSNR", "trainer.metric:SSIM"],
    })
Example #17
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import modules."""

from zeus.common.class_factory import ClassFactory

# Lazy registration table for the deformation modules. Each key appears to be
# a submodule of "zeus.modules.deformations"; each value is the class it provides.
ClassFactory.lazy_register(
    "zeus.modules.deformations", {
        "deformation": ["Deformation"],
        "backbone_deformation": ["BackboneDeformation"],
        "prune_deformation": ["PruneDeformation"]
    })
Example #18
0
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import tensorflow networks."""

from .network import Sequential
from zeus.common.class_factory import ClassFactory

# Lazy registration tables for the TensorFlow networks. Keys may be dotted
# ("backbones.resnet_det"), which appear to be nested submodule paths
# relative to the package string; values are the classes each provides.
ClassFactory.lazy_register(
    "zeus.networks.tensorflow",
    {
        "resnet_tf": ["ResNetTF"],
        # backbones
        "backbones.resnet_det": ["ResNetDet"],
        # detectors
        "detectors.faster_rcnn_trainer_callback":
        ["FasterRCNNTrainerCallback"],
        "detectors.faster_rcnn": ["FasterRCNN"],
        "detectors.tf_optimizer": ["TFOptimizer"],
        # losses
        "losses.cross_entropy_loss": ["CrossEntropyLoss"],
        "losses.mix_auxiliary_loss": ["MixAuxiliaryLoss"],
        # necks
        "necks.mask_rcnn_box": ["MaskRCNNBox"],
    })

# helper utilities (anchors, initializers, regularizers) live in a separate
# package and are registered separately
ClassFactory.lazy_register(
    "zeus.networks.tensorflow.utils", {
        "anchor_utils.anchor_generator": ["AnchorGenerator"],
        "hyperparams.initializer": ["Initializer"],
        "hyperparams.regularizer": ["Regularizer"],
    })
Example #19
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import loss functions."""

from zeus.common.class_factory import ClassFactory

# Lazy registration table for the PyTorch loss functions. The "trainer.loss:"
# prefix presumably registers each name under the loss ClassType -- confirm
# against ClassFactory.lazy_register.
ClassFactory.lazy_register(
    "zeus.networks.pytorch.losses", {
        "sum_loss": ["trainer.loss:SumLoss"],
        "smooth_l1_loss": ["trainer.loss:SmoothL1Loss"],
        "custom_cross_entropy_loss": ["trainer.loss:CustomCrossEntropyLoss"],
        "cross_entropy_label_smooth": ["trainer.loss:CrossEntropyLabelSmooth"],
        "mix_auxiliary_loss": ["trainer.loss:MixAuxiliaryLoss"],
    })
Example #20
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

"""Lazy import pytorch backbones."""

from zeus.common.class_factory import ClassFactory


# Lazy registration table for the PyTorch backbones.
# NOTE(review): "OffcialModelLoader" looks misspelled, but it must match the
# class name as defined in load_official_model; rename both together if at all.
ClassFactory.lazy_register("zeus.networks.pytorch.backbones", {
    "getter": ["BackboneGetter"],
    "load_official_model": ["OffcialModelLoader"],
    "resnet_variant_det": ["ResNetVariantDet"],
    "resnext_variant_det": ["ResNeXtVariantDet"],
})
Example #21
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

"""Lazy import necks."""

from zeus.common.class_factory import ClassFactory


# Lazy registration table for the PyTorch necks; the "network:" prefix
# presumably selects the network ClassType -- confirm.
ClassFactory.lazy_register("zeus.networks.pytorch.necks", {
    "ffm": ["network:FeatureFusionModule"],
    "fpn": ["TorchFPN"]
})
Example #22
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import detector."""

from zeus.common.class_factory import ClassFactory

# Lazy registration table for the PyTorch detectors.
ClassFactory.lazy_register("zeus.networks.pytorch.detectors", {
    "auto_lane_detector": ["AutoLaneDetector"],
})
Example #23
0
from .loss import Loss
from zeus.common.class_factory import ClassFactory

# Lazy registration table for the generic losses in "zeus.modules.loss". The
# "trainer.loss:" prefix presumably names the ClassType -- confirm. Note the
# "ProbOhemCrossEntropy2d" key must match its submodule's file name exactly.
ClassFactory.lazy_register(
    "zeus.modules.loss", {
        "multiloss": ["trainer.loss:MultiLoss"],
        "focal_loss": ["trainer.loss:FocalLoss"],
        "f1_loss": ["trainer.loss:F1Loss"],
        "forecast_loss": ["trainer.loss:ForecastLoss"],
        "mean_loss": ["trainer.loss:MeanLoss"],
        "ProbOhemCrossEntropy2d": ["trainer.loss:ProbOhemCrossEntropy2d"],
        "gan_loss": ["trainer.loss:GANLoss"],
    })
Example #24
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import mindspore network."""

from zeus.common.class_factory import ClassFactory

# Lazy registration table for the MindSpore networks; dotted keys appear to be
# nested submodule paths under "zeus.networks.mindspore".
# NOTE(review): "OffcialModelLoader" must match the class name as defined.
ClassFactory.lazy_register(
    "zeus.networks.mindspore", {
        "dnet": ["DNet"],
        "super_network":
        ["DartsNetwork", "CARSDartsNetwork", "GDASDartsNetwork"],
        "prune_deformation": ["PruneDeformation", "Deformation"],
        "backbones.load_official_model": ["OffcialModelLoader"],
        "backbones.resnet_ms": ["ResNetMs"],
        "losses.mix_auxiliary_loss": ["MixAuxiliaryLoss"],
    })
Example #25
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Lazy import custom network."""

from zeus.common.class_factory import ClassFactory

# Lazy registration table for the custom PyTorch networks; the "network:"
# prefix presumably selects the network ClassType -- confirm.
ClassFactory.lazy_register(
    "zeus.networks.pytorch.customs", {
        "nago": ["network:NAGO"],
        "deepfm": ["network:DeepFactorizationMachineModel"],
        "autogate": ["network:AutoGateModel"],
        "autogroup": ["network:AutoGroupModel"],
        "bisenet": ["network:BiSeNet"],
        "modnas": ["network:ModNasArchSpace"],
        "mobilenetv2": ["network:MobileNetV2"],
    })
Example #26
0
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

"""Lazy import dataset."""

from zeus.common.class_factory import ClassFactory


# Lazy registration table for the backend-agnostic datasets. Each key appears
# to be a submodule of "zeus.datasets.common"; each value lists its classes.
ClassFactory.lazy_register("zeus.datasets.common", {
    "avazu": ["AvazuDataset"],
    "cifar10": ["Cifar10"],
    "cifar100": ["Cifar100"],
    "div2k": ["DIV2K"],
    "cls_ds": ["ClassificationDataset"],
    "cityscapes": ["Cityscapes"],
    "div2k_unpair": ["Div2kUnpair"],
    "fmnist": ["FashionMnist"],
    # "imagenet" is deliberately absent: register_datasets() appears to
    # register it per backend instead.
    "mnist": ["Mnist"],
    "sr_datasets": ["Set5", "Set14", "BSDS100"],
    "auto_lane_datasets": ["AutoLaneDataset"],
    "coco": ["CocoDataset"],
    "mrpc": ["MrpcDataset"],
    "spatiotemporal": ["SpatiotemporalDataset"]
})
Example #27
0
# -*- coding:utf-8 -*-

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

"""Lazy import cyclesr bodys."""

from zeus.common.class_factory import ClassFactory


# Lazy registration table for the CycleSR network bodies.
ClassFactory.lazy_register("zeus.networks.pytorch.cyclesrbodys", {
    "cyclesr_net": ["CycleSRModel"],
})
Example #28
0
from .network_desc import NetworkDesc

# Lazy registration table for the backend-independent networks of
# "zeus.networks"; each key appears to be a submodule, each value its classes.
# NOTE(review): ClassFactory is not imported in the visible lines of this
# module -- confirm `from zeus.common.class_factory import ClassFactory`
# exists above, otherwise this raises NameError at import time.
ClassFactory.lazy_register(
    "zeus.networks",
    {
        "adelaide": ["AdelaideFastNAS"],
        "bert": ["BertClassifier"],
        "dnet": ["DNet", "DNetBackbone"],
        "erdb_esr": ["ESRN"],
        "faster_backbone": ["FasterBackbone"],
        "faster_rcnn": ["FasterRCNN"],
        "mobilenet": ["MobileNetV3Tiny", "MobileNetV2Tiny"],
        "mobilenetv3": ["MobileNetV3Small", "MobileNetV3Large"],
        "necks": ["FPN"],
        "quant": ["Quantizer"],
        "resnet_det": ["ResNetDet"],
        "resnet_general": ["ResNetGeneral"],
        "resnet": ["ResNet"],
        "resnext_det": ["ResNeXtDet"],
        "sgas_network": ["SGASNetwork"],
        "simple_cnn": ["SimpleCnn"],
        "spnet_backbone": ["SpResNetDet"],
        "super_network":
        ["DartsNetwork", "CARSDartsNetwork", "GDASDartsNetwork"],
        "text_cnn": ["TextCells", "TextCNN"],
        "gcn": ["GCN"],
        # xingtian
        "mtm_sr": ["MtMSR"],
        "xt_model": ["DqnMlpNet", "DqnCnnNet", "mse_loss"],
    })

Example #29
0
# Lazy registration table for the data transforms; here each key is both the
# submodule name and (mostly) the class it provides.
# NOTE(review): "RandomMirrow_pair" / "RandomVerticallFlip_pair" look
# misspelled, but they must match the module and class names as defined.
ClassFactory.lazy_register("zeus.datasets.transforms", {
    # common
    "AutoContrast": ["AutoContrast"],
    "BboxTransform": ["BboxTransform"],
    "Brightness": ["Brightness"],
    "Color": ["Color"],
    "Compose": ["Compose", "ComposeAll"],
    "Compose_pair": ["Compose_pair"],
    "Contrast": ["Contrast"],
    "Cutout": ["Cutout"],
    "Equalize": ["Equalize"],
    "RandomCrop_pair": ["RandomCrop_pair"],
    "RandomHorizontalFlip_pair": ["RandomHorizontalFlip_pair"],
    "RandomMirrow_pair": ["RandomMirrow_pair"],
    "RandomRotate90_pair": ["RandomRotate90_pair"],
    "RandomVerticallFlip_pair": ["RandomVerticallFlip_pair"],
    "RandomColor_pair": ["RandomColor_pair"],
    "RandomRotate_pair": ["RandomRotate_pair"],
    "Rescale_pair": ["Rescale_pair"],
    "Normalize_pair": ["Normalize_pair"],
    # GPU only
    "ImageTransform": ["ImageTransform"],
    "Invert": ["Invert"],
    "MaskTransform": ["MaskTransform"],
    "Posterize": ["Posterize"],
    "Rotate": ["Rotate"],
    "SegMapTransform": ["SegMapTransform"],
    "Sharpness": ["Sharpness"],
    "Shear_X": ["Shear_X"],
    "Shear_Y": ["Shear_Y"],
    "Solarize": ["Solarize"],
    "Translate_X": ["Translate_X"],
    "Translate_Y": ["Translate_Y"],
    "RandomGaussianBlur_pair": ["RandomGaussianBlur_pair"],
    "RandomHorizontalFlipWithBoxes": ["RandomHorizontalFlipWithBoxes"],
})
Example #30
0
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

"""Lazy import nas algorithms."""

from zeus.common.class_factory import ClassFactory

# Lazy registration table for the NAS algorithms. Each key appears to be a
# submodule of "vega.algorithms.nas"; each value lists its codecs, search
# algorithms and trainer callbacks. The "search_algorithm:" prefix presumably
# names the ClassType for that entry -- confirm against lazy_register.
ClassFactory.lazy_register("vega.algorithms.nas", {
    "adelaide_ea": ["AdelaideCodec", "AdelaideMutate", "AdelaideRandom", "AdelaideEATrainerCallback"],
    "auto_lane": ["AutoLaneNas", "AutoLaneNasCodec", "AutoLaneTrainerCallback"],
    "backbone_nas": ["BackboneNasCodec", "BackboneNasSearchSpace", "BackboneNas"],
    "cars": ["CARSAlgorithm", "CARSTrainerCallback", "CARSPolicyConfig"],
    "darts_cnn": ["DartsCodec", "DartsFullTrainerCallback", "DartsNetworkTemplateConfig", "DartsTrainerCallback"],
    "dnet_nas": ["DblockNasCodec", "DblockNas", "DnetNasCodec", "DnetNas"],
    "esr_ea": ["ESRCodec", "ESRTrainerCallback", "ESRSearch"],
    "fis": ["AutoGateGrdaS1TrainerCallback", "AutoGateGrdaS2TrainerCallback", "AutoGateS1TrainerCallback",
            "AutoGateS2TrainerCallback", "AutoGroupTrainerCallback", "CtrTrainerCallback"],
    "mfkd": ["MFKD1", "SimpleCnnMFKD"],
    "modnas": ["ModNasAlgorithm", "ModNasTrainerCallback"],
    "segmentation_ea": ["SegmentationCodec", "SegmentationEATrainerCallback", "SegmentationNas"],
    "sgas": ["SGASTrainerCallback"],
    "sm_nas": ["SmNasCodec", "SMNasM"],
    "sp_nas": ["SpNasS", "SpNasP"],
    "sr_ea": ["SRCodec", "SRMutate", "SRRandom"],
    "mfasc": ["search_algorithm:MFASC"]
})