Code example #1
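Instantiates the configured optimizer class for the active backend: the PyTorch branch passes only the parameters that require gradients, while the TensorFlow branch reads the current learning rate from the scheduler. In distributed mode the optimizer is wrapped with Horovod's DistributedOptimizer on GPU or with NPUDistributedOptimizer on NPU. The module tail registers each backend's optimizer package with the ClassFactory.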
            if vega.is_torch_backend():
                learnable_params = [
                    param for param in model.parameters()
                    if param.requires_grad
                ]
                optimizer = self.optim_cls(learnable_params, **params)
                if distributed:
                    optimizer = hvd.DistributedOptimizer(
                        optimizer,
                        named_parameters=model.named_parameters(),
                        compression=hvd.Compression.none)
            elif vega.is_tf_backend():
                lr_scheduler.step(epoch)
                params['learning_rate'] = lr_scheduler.get_lr()[0]
                optimizer = self.optim_cls(**params)
                if distributed:
                    optimizer = hvd.DistributedOptimizer(optimizer) if vega.is_gpu_device() else \
                        NPUDistributedOptimizer(optimizer)
            return optimizer
        except Exception as ex:
            logging.error("Failed to call Optimizer name={}, params={}".format(
                self.optim_cls.__name__, params))
            raise ex


if vega.is_torch_backend():
    import torch.optim as torch_opt

    ClassFactory.register_from_package(torch_opt, ClassType.OPTIM)
elif vega.is_tf_backend():
    import tensorflow.train as tf_train

    ClassFactory.register_from_package(tf_train, ClassType.OPTIM)
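
Below is a minimal, self-contained sketch of the PyTorch branch, using plain torch with an illustrative model and hyperparameters (not the vega API): only parameters that require gradients are handed to the optimizer, and the distributed wrap is shown as a comment.

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
model.weight.requires_grad_(False)  # e.g. a frozen layer

# Only parameters that require gradients are handed to the optimizer.
learnable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(learnable_params, lr=0.01)

# Distributed case (as in the snippet above, requires horovod):
#   optimizer = hvd.DistributedOptimizer(
#       optimizer,
#       named_parameters=model.named_parameters(),
#       compression=hvd.Compression.none)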
Code example #2
File: lr_scheduler.py Project: zeyefkey/vega
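Fills in scheduler-specific defaults before instantiation: a missing T_max for CosineAnnealingLR defaults to the epoch count (or epochs * steps when counting by step), and WarmupScheduler receives both epochs and steps. The scheduler class is then called with whichever of optimizer and params are available.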
        if self._cls.__name__ == "CosineAnnealingLR":
            if params.get("T_max", -1) == -1:
                if params.get("by_epoch", True):
                    params["T_max"] = epochs
                else:
                    params["T_max"] = epochs * steps

        if self._cls.__name__ == "WarmupScheduler":
            params["epochs"] = epochs
            params["steps"] = steps

        try:
            if params and optimizer:
                return self._cls(optimizer, **params)
            elif optimizer:
                return self._cls(optimizer)
            else:
                return self._cls(**params)
        except Exception as ex:
            logging.error(
                "Failed to call LrScheduler name={}, params={}".format(
                    self._cls.__name__, params))
            raise ex


if vega.is_torch_backend():
    import torch.optim.lr_scheduler as torch_lr

    ClassFactory.register_from_package(torch_lr, ClassType.LR_SCHEDULER)
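
A runnable sketch of the T_max rule, with an illustrative model, epoch count, and steps-per-epoch: when by_epoch is false, the cosine period is measured in optimizer steps rather than epochs.

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epochs, steps = 50, 200
params = {"by_epoch": False}

# Mirror of the default above: period in epochs, or in total steps.
t_max = epochs if params.get("by_epoch", True) else epochs * steps

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)
print(t_max)  # 10000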
Code example #3
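Resolves the configured loss from the ClassFactory registry and builds a callable from it: classes are instantiated with their params, bare functions are bound with functools.partial, and on the PyTorch backend the result is moved to CUDA when the trainer config enables it.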
    def __init__(self):
        """Initialize Loss, resolving the configured loss class from the registry."""
        loss_name = self.config.type
        self._cls = ClassFactory.get_cls(ClassType.LOSS, loss_name)

    def __call__(self):
        """Call loss cls."""
        params = obj2config(self.config).get("params", {})
        logging.debug("Call Loss. name={}, params={}".format(self._cls.__name__, params))
        try:
            if params:
                cls_obj = self._cls(**params) if isclass(self._cls) else partial(self._cls, **params)
            else:
                cls_obj = self._cls() if isclass(self._cls) else partial(self._cls)
            if vega.is_torch_backend() and TrainerConfig().cuda:
                cls_obj = cls_obj.cuda()
            return cls_obj
        except Exception as ex:
            logging.error("Failed to call Loss name={}, params={}".format(self._cls.__name__, params))
            raise ex


if vega.is_torch_backend():
    import torch.nn as torch_nn
    import timm.loss as timm_loss

    ClassFactory.register_from_package(torch_nn, ClassType.LOSS)
    ClassFactory.register_from_package(timm_loss, ClassType.LOSS)
elif vega.is_tf_backend():
    import tensorflow.losses as tf_loss

    ClassFactory.register_from_package(tf_loss, ClassType.LOSS)
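
A stdlib-only sketch of the isclass/partial dispatch; ClassLoss and function_loss are illustrative stand-ins for registered losses, showing that either kind of entry ends up as a callable with its params bound.

from functools import partial
from inspect import isclass

class ClassLoss:
    """A loss registered as a class."""
    def __init__(self, weight=1.0):
        self.weight = weight
    def __call__(self, output, target):
        return self.weight * abs(output - target)

def function_loss(output, target, weight=1.0):
    """A loss registered as a plain function."""
    return weight * abs(output - target)

for loss_cls in (ClassLoss, function_loss):
    params = {"weight": 2.0}
    # Same dispatch as the snippet: instantiate classes, bind functions.
    loss = loss_cls(**params) if isclass(loss_cls) else partial(loss_cls, **params)
    print(loss(3.0, 1.0))  # 4.0 both times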
Code example #4
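The tail of a Metrics container: metric objects live in mdict, reset clears their state between epochs, results merges each metric's result dict, objectives collects the per-metric objectives, and __getattr__ forwards attribute access so an individual metric can be reached by name.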
        return pfms

    def reset(self):
        """Reset states for new evaluation after each epoch."""
        for val in self.mdict.values():
            val.reset()

    @property
    def results(self):
        """Return metrics results."""
        res = {}
        for metric in self.mdict.values():
            res.update(metric.result)
        return res

    @property
    def objectives(self):
        """Return objectives results."""
        return {name: metric.objective for name, metric in self.mdict.items()}

    def __getattr__(self, key):
        """Get a metric by key name.

        :param key: metric name
        :type key: str
        """
        return self.mdict[key]


ClassFactory.register_from_package(metrics, ClassType.METRIC)
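
A toy, self-contained version of the same container pattern; the Accuracy metric and the mini Metrics class are illustrative, not vega's implementation.

class Accuracy:
    """Toy metric: fraction of exact matches."""
    def __init__(self):
        self.reset()
    def reset(self):
        self.correct, self.total = 0, 0
    def update(self, pred, target):
        self.correct += int(pred == target)
        self.total += 1
    @property
    def result(self):
        return {"accuracy": self.correct / max(self.total, 1)}

class Metrics:
    """Container mirroring the snippet: reset, merged results, name access."""
    def __init__(self, **metrics):
        self.mdict = metrics
    def reset(self):
        for val in self.mdict.values():
            val.reset()
    @property
    def results(self):
        res = {}
        for metric in self.mdict.values():
            res.update(metric.result)
        return res
    def __getattr__(self, key):
        return self.mdict[key]

m = Metrics(accuracy=Accuracy())
m.accuracy.update(1, 1)
m.accuracy.update(0, 1)
print(m.results)  # {'accuracy': 0.5}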