Example #1
 # Assumes: from copy import deepcopy; from collections import OrderedDict;
 # ClassFactory, ClassType and SearchSpaceType come from the Vega class
 # registry.
 @classmethod
 def from_desc(cls, desc):
     """Create a Model from a description dict."""
     desc = deepcopy(desc)
     module_groups = desc.get('modules', [])
     module_type = desc.get('type', 'Sequential')
     loss = desc.get('loss')
     modules = OrderedDict()
     for group_name in module_groups:
         module_desc = deepcopy(desc.get(group_name))
         if 'modules' in module_desc:
             # A nested 'modules' key marks a composite group: recurse.
             module = cls.from_desc(module_desc)
         else:
             cls_name = module_desc.get('type')
             if not ClassFactory.is_exists(ClassType.NETWORK, cls_name):
                 raise ValueError("Network {} does not exist.".format(cls_name))
             module = ClassFactory.get_instance(ClassType.NETWORK,
                                                module_desc)
         modules[group_name] = module
     if not modules and module_type:
         # No named groups: the desc itself describes a single network.
         model = ClassFactory.get_instance(ClassType.NETWORK, desc)
     else:
         # Resolve the connection class, falling back to Sequential.
         if ClassFactory.is_exists(SearchSpaceType.CONNECTIONS,
                                   module_type):
             connections = ClassFactory.get_cls(SearchSpaceType.CONNECTIONS,
                                                module_type)
         else:
             connections = ClassFactory.get_cls(SearchSpaceType.CONNECTIONS,
                                                'Sequential')
         # A single module needs no connection wrapper.
         if len(modules) == 1:
             model = list(modules.values())[0]
         else:
             model = connections(modules)
     if loss:
         model.add_loss(ClassFactory.get_cls(ClassType.LOSS, loss))
     return model
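For orientation, a minimal sketch of the description dict this factory consumes; the aliases 'ResNetBackbone' and 'LinearHead', the loss name, and the owning Model class are hypothetical stand-ins for entries registered under the corresponding ClassType:

 # Hypothetical desc (sketch, not from the source): 'modules' lists the
 # group names, and each group has its own sub-description.
 desc = {
     'modules': ['backbone', 'head'],
     'type': 'Sequential',
     'loss': 'CrossEntropyLoss',
     'backbone': {'type': 'ResNetBackbone'},  # assumed NETWORK alias
     'head': {'type': 'LinearHead'},          # assumed NETWORK alias
 }
 model = Model.from_desc(desc)  # Model: the class owning from_desc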
Example #2
 # Variant that resolves modules from the SEARCH_SPACE registry and
 # returns None (rather than raising) for unknown module types.
 @classmethod
 def from_desc(cls, desc):
     """Create a Model from a description dict."""
     module_groups = desc.get('modules', [])
     module_type = desc.get('type', 'Sequential')
     loss = desc.get('loss')
     modules = OrderedDict()
     for group_name in module_groups:
         module_desc = deepcopy(desc.get(group_name))
         if 'modules' in module_desc:
             module = cls.from_desc(module_desc)
         else:
             cls_name = module_desc.get('type')
             if not ClassFactory.is_exists(ClassType.SEARCH_SPACE,
                                           cls_name):
                 return None
             module = ClassFactory.get_instance(ClassType.SEARCH_SPACE,
                                                module_desc)
         modules[group_name] = module
     if ClassFactory.is_exists(SearchSpaceType.CONNECTIONS, module_type):
         connections = ClassFactory.get_cls(SearchSpaceType.CONNECTIONS,
                                            module_type)
     else:
         connections = ClassFactory.get_cls(SearchSpaceType.CONNECTIONS,
                                            'Sequential')
     if len(modules) == 1:
         model = list(modules.values())[0]
     else:
         model = connections(modules)
     if loss:
         model.add_loss(ClassFactory.get_cls(ClassType.LOSS, loss))
     return model
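Note the differences from Example #1: this variant resolves modules against the SEARCH_SPACE registry instead of NETWORK, and it returns None rather than raising when a type is not registered, so callers must check the result.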
Example #3
    # Assumes: import importlib; ClassFactory and ClassType come from the
    # Vega class registry.
    def _init_transforms(self):
        """Initialize transforms method.

        :return: a list of object
        :rtype: list
        """
        if "transforms" in self.args.keys():
            transforms = list()
            if not isinstance(self.args.transforms, list):
                self.args.transforms = [self.args.transforms]
            for i in range(len(self.args.transforms)):
                transform_name = self.args.transforms[i].pop("type")
                kwargs = self.args.transforms[i]
                # Prefer a Vega-registered transform; otherwise fall back to
                # the class of the same name in torchvision.transforms.
                if ClassFactory.is_exists(ClassType.TRANSFORM, transform_name):
                    transforms.append(
                        ClassFactory.get_cls(ClassType.TRANSFORM,
                                             transform_name)(**kwargs))
                else:
                    transforms.append(
                        getattr(
                            importlib.import_module('torchvision.transforms'),
                            transform_name)(**kwargs))
            return transforms
        else:
            return list()
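A sketch of the transforms configuration this method expects; the two entries are hypothetical, and any name not registered under ClassType.TRANSFORM is resolved against torchvision.transforms:

    # Hypothetical self.args.transforms (sketch): each entry names a
    # transform 'type' plus its constructor kwargs.
    transforms_config = [
        {"type": "RandomCrop", "size": 32, "padding": 4},
        {"type": "RandomHorizontalFlip"},
    ]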
Example #4
 # Assumes: from types import ModuleType; ClassFactory and ClassType come
 # from the Vega class registry.
 def _register_models_from_current_module_scope(module):
     for _name in dir(module):
         if _name.startswith("_"):
             continue
         _cls = getattr(module, _name)
         # Skip submodules; register every remaining public attribute
         # under a 'torchvision_' alias unless it is already registered.
         if isinstance(_cls, ModuleType):
             continue
         if ClassFactory.is_exists(ClassType.SEARCH_SPACE,
                                   'torchvision_' + _cls.__name__):
             continue
         ClassFactory.register_cls(_cls, ClassType.SEARCH_SPACE,
                                   alias='torchvision_' + _cls.__name__)
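A sketch of how such a helper would typically be invoked, here against torchvision.models; the call site is an assumption and not shown in the source:

 # Hypothetical call site (sketch): registers every public torchvision
 # model constructor under a 'torchvision_<name>' alias.
 import torchvision.models
 _register_models_from_current_module_scope(torchvision.models)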
Example #5
 # Assumes: import importlib; ClassFactory comes from the Vega registry.
 def __init__(self, aux_weight, loss_base):
     """Initialize MixAuxiliaryLoss with the base loss description."""
     self.aux_weight = aux_weight
     loss_base_cp = loss_base.copy()
     loss_base_name = loss_base_cp.pop('type')
     if ClassFactory.is_exists('trainer.loss', loss_base_name):
         loss_class = ClassFactory.get_cls('trainer.loss', loss_base_name)
     else:
         # Fall back to a function from tensorflow.losses.
         loss_class = getattr(importlib.import_module('tensorflow.losses'),
                              loss_base_name)
     self.loss_fn = loss_class(**loss_base_cp['params'])
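A sketch of the loss_base argument this constructor expects; the loss name is hypothetical and assumed to be registered under 'trainer.loss':

 # Hypothetical loss_base (sketch): 'type' selects the loss, 'params'
 # holds its constructor kwargs (the key is required by __init__ above).
 loss_base = {'type': 'CrossEntropyLoss', 'params': {}}
 mix_loss = MixAuxiliaryLoss(aux_weight=0.4, loss_base=loss_base)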
Example #6
 # Assumes: import importlib; import vega; from functools import partial;
 # ClassFactory comes from the Vega registry.
 def _init_loss(self):
     """Init loss."""
     if vega.is_torch_backend():
         loss_config = self.criterion.copy()
         loss_name = loss_config.pop('type')
         loss_class = getattr(importlib.import_module('torch.nn'), loss_name)
         return loss_class(**loss_config)
     elif vega.is_tf_backend():
         from inspect import isclass
         loss_config = self.config.tf_criterion.copy()
         loss_name = loss_config.pop('type')
         if ClassFactory.is_exists('trainer.loss', loss_name):
             loss_class = ClassFactory.get_cls('trainer.loss', loss_name)
             # Registered losses may be classes or plain functions;
             # functions get their config bound via partial.
             if isclass(loss_class):
                 return loss_class(**loss_config)
             else:
                 return partial(loss_class, **loss_config)
         else:
             # TensorFlow losses are functions, so bind the config the
             # same way.
             loss_class = getattr(
                 importlib.import_module('tensorflow.losses'), loss_name)
             return partial(loss_class, **loss_config)
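The partial wrapping is what lets class-based and function-based losses share one calling convention: either way the trainer gets back a callable with the loss configuration already bound and only the runtime arguments left free.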
Example #7
    # Assumes: import copy; import importlib; ClassFactory and ClassType
    # come from the Vega class registry.
    def _init_after_scheduler(self):
        """Init after_scheduler with after_scheduler_config."""
        if isinstance(self.after_scheduler_config, dict):
            scheduler_config = copy.deepcopy(self.after_scheduler_config)
            print("after_scheduler_config: {}".format(scheduler_config))
            scheduler_name = scheduler_config.pop('type')
            if ClassFactory.is_exists(ClassType.LR_SCHEDULER, scheduler_name):
                scheduler_class = ClassFactory.get_cls(ClassType.LR_SCHEDULER,
                                                       scheduler_name)
            else:
                scheduler_class = getattr(
                    importlib.import_module('torch.optim.lr_scheduler'),
                    scheduler_name)

            if scheduler_class.__name__ == "CosineAnnealingLR":
                if scheduler_config.get("T_max", -1) == -1:
                    if scheduler_config.get("by_epoch", True):
                        scheduler_config["T_max"] = self.epochs
                    else:
                        scheduler_config["T_max"] = self.epochs * self.steps

            self.after_scheduler = scheduler_class(self.optimizer,
                                                   **scheduler_config)
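A sketch of an after_scheduler_config that exercises the T_max defaulting above; the keys mirror torch.optim.lr_scheduler.CosineAnnealingLR except by_epoch, which is consumed before construction:

    # Hypothetical after_scheduler_config (sketch): T_max == -1 asks the
    # trainer to substitute the training length.
    after_scheduler_config = {
        "type": "CosineAnnealingLR",
        "T_max": -1,
        "by_epoch": True,
    }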