コード例 #1
0
 def from_module(cls, module):
     """Convert ``module`` through its registered network class, if possible.

     The NETWORK registry entry named after ``module``'s class is looked up.
     When that entry exists and defines ``from_module``, the result of
     ``registered.from_module(module)`` is returned; in every other case
     ``module`` itself is returned unchanged.

     :param module: model/module instance to convert
     :return: the converted module, or ``module`` itself
     """
     network_name = module.__class__.__name__
     if not ClassFactory.is_exists(ClassType.NETWORK, network_name):
         return module
     registered = ClassFactory.get_cls(ClassType.NETWORK, network_name)
     if not hasattr(registered, "from_module"):
         return module
     return registered.from_module(module)
コード例 #2
0
 def __init__(self, aux_weight, loss_base):
     """Init MixAuxiliaryLoss.

     :param aux_weight: stored unchanged on ``self.aux_weight``; presumably the
         weight of the auxiliary loss term -- confirm against the loss's call.
     :param loss_base: dict whose ``'type'`` names the base loss and whose
         ``'params'`` dict holds its constructor keyword arguments (other keys
         are ignored). It is shallow-copied, so the caller's dict is not changed.
     """
     self.aux_weight = aux_weight
     base_cfg = loss_base.copy()
     base_name = base_cfg.pop('type')
     # Prefer a project-registered loss; otherwise look it up in tensorflow.losses.
     if not ClassFactory.is_exists('trainer.loss', base_name):
         base_cls = getattr(importlib.import_module('tensorflow.losses'), base_name)
     else:
         base_cls = ClassFactory.get_cls('trainer.loss', base_name)
     self.loss_fn = base_cls(**base_cfg['params'])
コード例 #3
0
 def from_desc(cls, desc):
     """Create Model from desc.

     Works on a deep copy of ``desc``, so the caller's dict is not modified.

     If the copy carries ``'_arch_params'``, that entry is removed. Its first
     key becomes ``ArchParams._arch_type`` and the matching value is passed to
     ``ArchParams.update``; this mutates class-level ``ArchParams`` state.

     Each name listed in ``desc['modules']`` is looked up in ``desc``:
     - empty entries are skipped;
     - entries with their own ``'modules'`` key are built recursively;
     - any other entry is instantiated from the NETWORK registry by its
       ``'type'``.
     Every built module gets ``.name`` set to its group name.

     With no module groups (and a truthy ``type``) the whole ``desc`` is
     instantiated as a single network. Otherwise a single built module is
     returned as-is, and several are wrapped in the connection class named by
     ``desc['type']``. That name falls back to ``'Sequential'`` when it is not
     registered. If ``desc['loss']`` is set, the registered LOSS class of that
     name is passed to ``model.add_loss``.

     :param desc: network description dict
     :return: the built model
     :raises ValueError: if a leaf module's ``type`` is not a registered network
     """
     desc = deepcopy(desc)
     group_names = desc.get('modules', [])
     connection_name = desc.get('type', 'Sequential')
     loss_name = desc.get('loss')
     if '_arch_params' in desc:
         all_arch_params = desc.pop('_arch_params')
         chosen_arch = list(all_arch_params)[0]
         ArchParams._arch_type = chosen_arch
         ArchParams.update(all_arch_params.get(chosen_arch))
     built = OrderedDict()
     for group in group_names:
         sub_desc = deepcopy(desc.get(group))
         if not sub_desc:
             continue
         if 'modules' in sub_desc:
             sub_module = cls.from_desc(sub_desc)
         else:
             network_name = sub_desc.get('type')
             if not ClassFactory.is_exists(ClassType.NETWORK, network_name):
                 raise ValueError("Network {} not exists.".format(network_name))
             sub_module = ClassFactory.get_instance(ClassType.NETWORK, sub_desc)
         built[group] = sub_module
         sub_module.name = str(group)
     if not group_names and connection_name:
         model = ClassFactory.get_instance(ClassType.NETWORK, desc)
     else:
         # The connection class is resolved even when only one module was
         # built, matching the registry lookups the single-module case performs.
         wrapper_name = (connection_name
                         if ClassFactory.is_exists(SearchSpaceType.CONNECTIONS,
                                                   connection_name)
                         else 'Sequential')
         connections = ClassFactory.get_cls(SearchSpaceType.CONNECTIONS,
                                            wrapper_name)
         if len(built) == 1:
             model = next(iter(built.values()))
         else:
             model = connections(built)
     if loss_name:
         model.add_loss(ClassFactory.get_cls(ClassType.LOSS, loss_name))
     return model
コード例 #4
0
 def _register_models_from_current_module_scope(module):
     for _name in dir(module):
         if _name.startswith("_"):
             continue
         _cls = getattr(module, _name)
         if isinstance(_cls, ModuleType):
             continue
         if ClassFactory.is_exists(ClassType.NETWORK,
                                   'torchvision_' + _cls.__name__):
             continue
         ClassFactory.register_cls(_cls,
                                   ClassType.NETWORK,
                                   alias='torchvision_' + _cls.__name__)
コード例 #5
0
ファイル: ps_differential.py プロジェクト: huawei-noah/vega
 def _init_loss(self):
     """Init loss.

     Torch backend: instantiates ``torch.nn.<type>`` from ``self.criterion``.
     The ``'type'`` key names the class and the remaining keys are passed as
     constructor kwargs.

     TF backend: reads ``self.config.tf_criterion`` the same way. A name
     registered under ``'trainer.loss'`` is instantiated if it is a class, or
     bound with ``functools.partial`` otherwise. Any other name is looked up in
     ``tensorflow.losses`` and bound with ``partial``.

     Any other backend: returns ``None``.

     :return: loss object or callable, or ``None`` for other backends
     """
     if vega.is_torch_backend():
         # NOTE(review): torch reads self.criterion while TF reads
         # self.config.tf_criterion -- confirm the asymmetry is intended.
         torch_cfg = self.criterion.copy()
         torch_loss_name = torch_cfg.pop('type')
         torch_nn = importlib.import_module('torch.nn')
         return getattr(torch_nn, torch_loss_name)(**torch_cfg)
     if not vega.is_tf_backend():
         return None
     from inspect import isclass
     tf_cfg = self.config.tf_criterion.copy()
     tf_loss_name = tf_cfg.pop('type')
     if not ClassFactory.is_exists('trainer.loss', tf_loss_name):
         tf_losses = importlib.import_module('tensorflow.losses')
         return partial(getattr(tf_losses, tf_loss_name), **tf_cfg)
     registered = ClassFactory.get_cls('trainer.loss', tf_loss_name)
     # Registered classes are instantiated; registered functions are bound.
     return registered(**tf_cfg) if isclass(registered) else partial(registered, **tf_cfg)
コード例 #6
0
ファイル: AutoAugment.py プロジェクト: huawei-noah/vega
    def __call__(self, img):
        """Apply ``self.num`` randomly chosen transforms to ``img``.

        When ``self.prob < 1.0``, the image is returned unchanged whenever a
        ``random.random()`` draw exceeds ``self.prob``.

        Otherwise the transform names in ``self._RAND_TRANSFORMS`` that are
        registered under ``ClassType.TRANSFORM`` form the candidate pool.
        ``self.num`` of them are drawn via ``np.random.choice`` (with
        replacement, numpy's default).

        Each drawn transform class is built as ``op(magnitude)`` and applied to
        the image. The magnitude is ``self.magnitude``, replaced by a Gaussian
        sample around it when ``self.magnitude_std`` is positive, and is
        clipped to ``[0, 10]``.

        :param img: input image
        :type img: numpy or tensor
        :return: the image after transform
        :rtype: numpy or tensor
        """
        if self.prob < 1.0 and random.random() > self.prob:
            return img
        transforms = [
            ClassFactory.get_cls(ClassType.TRANSFORM, transform_name)
            for transform_name in self._RAND_TRANSFORMS.values()
            if ClassFactory.is_exists(ClassType.TRANSFORM, transform_name)
        ]
        ops = np.random.choice(transforms, self.num)
        for op in ops:
            # Default to the configured magnitude so a zero/None magnitude_std
            # means "no jitter" instead of leaving `magnitude` unbound.
            magnitude = self.magnitude
            if self.magnitude_std and self.magnitude_std > 0:
                magnitude = random.gauss(self.magnitude, self.magnitude_std)
            magnitude = min(10, max(0, magnitude))
            img = op(magnitude)(img)
        return img