예제 #1
0
class LrScheduler(object):
    """Resolve a registered learning-rate scheduler class from config and build it.

    ``__init__`` maps the scheduler config to backend-specific type/parameter
    names (``LrSchedulerMappingDict`` + ``ConfigBackendMapping``) and fetches the
    class from ``ClassFactory``; ``__call__`` instantiates it for an optimizer.
    """

    # Class-level default config; used when __init__ receives a falsy config.
    config = LrSchedulerConfig()

    def __init__(self, config=None):
        """Resolve the scheduler class described by the config.

        :param config: optional scheduler config; when falsy, the class-level
            ``LrScheduler.config`` is used instead.
        """
        if not config:
            self.config = LrScheduler.config
            mapping_input = self.config.to_dict()
        else:
            self.config = Config(config)
            mapping_input = deepcopy(self.config)
        mapping_input.type = self.config.type
        mappings = LrSchedulerMappingDict()
        backend_mapper = ConfigBackendMapping(mappings.type_mapping_dict, mappings.params_mapping_dict)
        self.map_config = backend_mapper.backend_mapping(mapping_input)
        self._cls = ClassFactory.get_cls(ClassType.LR_SCHEDULER, self.map_config.type)

    def __call__(self, optimizer=None, epochs=None, steps=None):
        """Instantiate the resolved scheduler class for ``optimizer``.

        Before instantiation, ``by_epoch`` is set as an attribute on the
        scheduler *class* (not an instance): ``self.config.by_epoch`` when the
        config has it, otherwise True. ``epochs`` and ``steps`` are accepted
        but not used here.

        :param optimizer: passed as the first positional constructor argument
        :return: the scheduler instance
        :raises Exception: re-raises whatever the constructor raises, after logging
        """
        params = self.map_config.get("params", {})
        scheduler_cls = self._cls
        logging.debug("Call LrScheduler. name={}, params={}".format(scheduler_cls.__name__, params))

        by_epoch = getattr(self.config, "by_epoch", True)
        setattr(scheduler_cls, "by_epoch", by_epoch)

        try:
            # A falsy params (empty or None) means "no keyword arguments".
            return scheduler_cls(optimizer, **params) if params else scheduler_cls(optimizer)
        except Exception:
            logging.error("Failed to call LrScheduler name={}, params={}".format(scheduler_cls.__name__, params))
            raise
예제 #2
0
파일: optim.py 프로젝트: vineetrao25/vega
class Optimizer(object):
    """Resolve a registered optimizer class from config and build it per backend.

    ``__init__`` maps the optimizer config to backend-specific type/parameter
    names through ``OptimMappingDict`` and ``ConfigBackendMapping`` and fetches
    the class from ``ClassFactory``; ``__call__`` instantiates it for the active
    backend (torch, tf or ms).
    """

    # Class-level default config; used when __init__ receives config=None.
    config = OptimConfig()

    def __new__(cls, *args, **kwargs):
        """Create optimizer or multi-optimizer class.

        Returns an instance of the registered ``MultiOptimizers`` class when
        ``cls.config.to_dict`` is a list, otherwise an instance of ``cls``.
        """
        # NOTE(review): __init__ calls ``self.config.to_dict()`` as a method, so
        # this check (on the attribute itself, not on its return value) is only
        # true if the config has had ``to_dict`` replaced by a list. Confirm
        # whether ``to_dict()`` was intended before changing it.
        if isinstance(cls.config.to_dict, list):
            t_cls = ClassFactory.get_cls(ClassType.OPTIMIZER,
                                         'MultiOptimizers')
            return super().__new__(t_cls)
        return super().__new__(cls)

    def __init__(self, config=None):
        """Resolve the optimizer class described by the config.

        :param config: optional optimizer config; when None, the class-level
            ``config`` is used.
        """
        self.is_multi_opt = False
        if config is not None:
            self.config = Config(config)
        raw_config = self.config.to_dict()
        raw_config.type = self.config.type
        # The mapping dict's attributes are read from the class itself (it is
        # not instantiated here).
        map_dict = OptimMappingDict
        self.map_config = ConfigBackendMapping(
            map_dict.type_mapping_dict,
            map_dict.params_mapping_dict).backend_mapping(raw_config)
        self.optim_cls = ClassFactory.get_cls(ClassType.OPTIMIZER,
                                              self.map_config.type)

    def __call__(self, model=None, distributed=False, **kwargs):
        """Call Optimizer class.

        :param model: model whose parameters are optimized (torch and ms backends)
        :param distributed: torch only; wrap the optimizer via ``set_distributed``
        :param kwargs: ms only; ``dynamic_lr`` overrides ``learning_rate`` for
            this call without modifying the stored config
        :return: optimizer, or None when no supported backend is active
        :raises Exception: re-raises whatever optimizer construction raises,
            after logging
        """
        params = self.map_config.get("params", {})
        logging.debug("Call Optimizer. name={}, params={}".format(
            self.optim_cls.__name__, params))
        optimizer = None
        try:
            if zeus.is_torch_backend():
                learnable_params = [
                    param for param in model.parameters()
                    if param.requires_grad
                ]
                optimizer = self.optim_cls(learnable_params, **params)
                if distributed:
                    optimizer = self.set_distributed(optimizer, model)
            elif zeus.is_tf_backend():
                optimizer = dynamic_optimizer(self.optim_cls, **params)
            elif zeus.is_ms_backend():
                if "dynamic_lr" in kwargs:
                    # Build a new dict rather than updating in place: when the key
                    # exists, ``get`` returns the params object stored in
                    # ``self.map_config``, so an in-place update would leak this
                    # call's dynamic_lr into every later call.
                    params = dict(params, learning_rate=kwargs["dynamic_lr"])
                learnable_params = [
                    param for param in model.trainable_params()
                    if param.requires_grad
                ]
                optimizer = self.optim_cls(learnable_params, **params)
            return optimizer
        except Exception:
            logging.error("Failed to call Optimizer name={}, params={}".format(
                self.optim_cls.__name__, params))
            raise

    @classmethod
    def set_distributed(cls, optimizer, model=None):
        """Set distributed optimizer.

        torch: wrap ``optimizer`` in ``hvd.DistributedOptimizer`` using
        ``model.named_parameters()`` and no compression.
        tf: wrap it via ``dynamic_distributed_optimizer`` with
        ``hvd.DistributedOptimizer`` on GPU devices, otherwise
        ``NPUDistributedOptimizer``.
        Any other backend: return ``optimizer`` unchanged.
        """
        if zeus.is_torch_backend():
            optimizer = hvd.DistributedOptimizer(
                optimizer,
                named_parameters=model.named_parameters(),
                compression=hvd.Compression.none)
        elif zeus.is_tf_backend():
            if zeus.is_gpu_device():
                optim_class = hvd.DistributedOptimizer
            else:
                optim_class = NPUDistributedOptimizer
            optimizer = dynamic_distributed_optimizer(optim_class, optimizer)
        return optimizer