Code example #1
    def _init_transforms(self):
        """Initialize transforms from the configuration.

        :return: a list of transform objects
        :rtype: list
        """
        if "transforms" in self.args:
            transforms = list()
            if not isinstance(self.args.transforms, list):
                self.args.transforms = [self.args.transforms]
            for transform in self.args.transforms:
                # "type" names the class; the remaining keys become kwargs.
                transform_name = transform.pop("type")
                kwargs = transform
                if ClassFactory.is_exists(ClassType.TRANSFORM, transform_name):
                    transforms.append(
                        ClassFactory.get_cls(ClassType.TRANSFORM,
                                             transform_name)(**kwargs))
                else:
                    # Fall back to torchvision's built-in transforms.
                    transforms.append(
                        getattr(
                            importlib.import_module('torchvision.transforms'),
                            transform_name)(**kwargs))
            return transforms
        else:
            return list()
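This resolver expects each entry in the transforms section to carry a "type" key naming the class, with the remaining keys passed through as constructor kwargs. A hypothetical args fragment (names and values illustrative, not from the source):

    args = {
        "transforms": [
            {"type": "RandomCrop", "size": 32, "padding": 4},   # resolved from torchvision.transforms
            {"type": "RandomHorizontalFlip", "p": 0.5},
            {"type": "Normalize", "mean": [0.4914], "std": [0.2470]},
        ]
    }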
Code example #2
File: trainer.py Project: zhwzhong/vega
 def _init_callbacks(self, callbacks):
     # Initialize callbacks by configuration or parameters
     if callbacks is None:
         _callbacks = []
         callbacks_config = self.cfg.callbacks.copy()
         for callback_config in callbacks_config.values():
             callback_name = callback_config.pop('type')
             if ClassFactory.is_exists(ClassType.CALLBACK, callback_name):
                 callback_class = ClassFactory.get_cls(
                     ClassType.CALLBACK, callback_name)
                 callback = callback_class(**callback_config)
                 _callbacks.append(callback)
             else:
                 raise ValueError(
                     "Undefined callback {}".format(callback_name))
     else:
         _callbacks = callbacks
     # Sort the callbacks
     metrics_evaluator = None
     model_checkpoint = None
     model_statistics = None
     predefined_callbacks = []
     customized_callbacks = []
     for callback in _callbacks:
         if isinstance(callback, self._predefined_callbacks()):
             # An elif chain captures each special callback exactly once;
             # only other predefined callbacks land in predefined_callbacks.
             if isinstance(callback, MetricsEvaluator):
                 metrics_evaluator = callback
             elif isinstance(callback, ModelStatistics):
                 model_statistics = callback
             elif isinstance(callback, ModelCheckpoint):
                 model_checkpoint = callback
             else:
                 predefined_callbacks.append(callback)
         else:
             customized_callbacks.append(callback)
     if metrics_evaluator is None:
         metrics_evaluator = MetricsEvaluator()
     if model_checkpoint is None:
         model_checkpoint = ModelCheckpoint()
     _callbacks = [metrics_evaluator, model_checkpoint] + \
         customized_callbacks + predefined_callbacks
     if 'model_statistic' in self.cfg and self.cfg.model_statistic:
         if model_statistics is None:
             model_statistics = ModelStatistics()
         _callbacks = [model_statistics] + _callbacks
     # Create the CallbackList and set its trainer and parameters
     self.callbacks = CallbackList(_callbacks)
     _callbacks_params = {
         'epochs': self.epochs,
         'is_chief': self.is_chief,
         'use_cuda': self.use_cuda,
         'do_validation': self.do_validation,
         'is_detection_trainer': self.cfg.is_detection_trainer
     }
     self.callbacks.set_params(_callbacks_params)
     self.callbacks.set_trainer(self)
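The net effect of the sorting above is a fixed head order: ModelStatistics (when enabled) first, then MetricsEvaluator and ModelCheckpoint, then customized callbacks, then the remaining predefined ones. A minimal cfg.callbacks fragment this method would accept (keys hypothetical):

    callbacks = {
        "metrics": {"type": "MetricsEvaluator"},
        "checkpoint": {"type": "ModelCheckpoint"},
    }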
Code example #3
 def __init__(self, aux_weight, loss_base):
     """Init MixAuxiliaryLoss."""
     self.aux_weight = aux_weight
     loss_base_cp = loss_base.copy()
     loss_base_name = loss_base_cp.pop('type')
     if ClassFactory.is_exists('trainer.loss', loss_base_name):
         loss_class = ClassFactory.get_cls('trainer.loss', loss_base_name)
     else:
         loss_class = getattr(importlib.import_module('tensorflow.losses'),
                              loss_base_name)
     self.loss_fn = loss_class(**loss_base_cp)
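A caller passes loss_base as a dict whose "type" either names a loss registered under 'trainer.loss' or a member of tensorflow.losses. A sketch of an invocation, assuming a hypothetical registered loss:

    loss_base = {"type": "CustomCrossEntropy", "smoothing": 0.1}  # hypothetical registered loss
    mixed_loss = MixAuxiliaryLoss(aux_weight=0.4, loss_base=loss_base)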
Code example #4
File: trainer.py Project: zhwzhong/vega
 def _init_lr_scheduler(self, scheduler=None):
     """Init lr scheduler from torch.optim.lr_scheduler according to type in config."""
     if scheduler is not None:
         return scheduler
     scheduler_config = self.cfg.lr_scheduler.copy()
     scheduler_name = scheduler_config.pop('type')
     if ClassFactory.is_exists(ClassType.LR_SCHEDULER, scheduler_name):
         scheduler_class = ClassFactory.get_cls(ClassType.LR_SCHEDULER,
                                                scheduler_name)
     else:
         scheduler_class = getattr(
             importlib.import_module('torch.optim.lr_scheduler'),
             scheduler_name)
     return scheduler_class(self.optimizer, **scheduler_config)
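When no scheduler is registered under ClassType.LR_SCHEDULER, the fallback resolves the standard torch schedulers by name, for example (values illustrative):

    # cfg.lr_scheduler = {"type": "StepLR", "step_size": 30, "gamma": 0.1}
    # -> torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=30, gamma=0.1)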
Code example #5
File: warmup_scheduler.py Project: zhwzhong/vega
 def _init_after_scheduler(self):
     """Init after_scheduler with after_scheduler_config."""
     if isinstance(self.after_scheduler_config, dict):
         scheduler_config = copy.deepcopy(self.after_scheduler_config)
         print("after_scheduler_config: {}".format(scheduler_config))
         scheduler_name = scheduler_config.pop('type')
         if ClassFactory.is_exists(ClassType.LR_SCHEDULER, scheduler_name):
             scheduler_class = ClassFactory.get_cls(ClassType.LR_SCHEDULER,
                                                    scheduler_name)
         else:
             scheduler_class = getattr(importlib.import_module('torch.optim.lr_scheduler'),
                                       scheduler_name)
         self.after_scheduler = scheduler_class(self.optimizer,
                                                **scheduler_config)
Code example #6
File: trainer.py Project: zhwzhong/vega
 def _init_loss(self, loss_fn=None):
     """Init loss function from torch according to type in config."""
     if loss_fn is not None:
         return loss_fn
     loss_config = self.cfg.loss.copy()
     loss_name = loss_config.pop('type')
     if NetworkFactory.is_exists(NetTypes.LOSS, loss_name):
         loss_class = NetworkFactory.get_network(NetTypes.LOSS, loss_name)
     elif ClassFactory.is_exists('trainer.loss', loss_name):
         loss_class = ClassFactory.get_cls('trainer.loss', loss_name)
     else:
         loss_class = getattr(importlib.import_module('torch.nn'),
                              loss_name)
     loss_fn = loss_class(**loss_config)
     if self.cfg.cuda:
         loss_fn = loss_fn.cuda()
     return loss_fn
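The final branch makes any torch.nn loss available by name, for example (values illustrative):

    # cfg.loss = {"type": "CrossEntropyLoss", "reduction": "mean"}
    # -> torch.nn.CrossEntropyLoss(reduction="mean"), moved to GPU when cfg.cuda is set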
Code example #7
File: metrics.py Project: zhwzhong/vega
 def __init__(self, metric_cfg):
     """Init Metrics."""
     metric_config = deepcopy(metric_cfg)
     self.mdict = {}
     if not isinstance(metric_config, list):
         metric_config = [metric_config]
     for metric_item in metric_config:
         metric_name = metric_item.pop('type')
         if ClassFactory.is_exists(ClassType.METRIC, metric_name):
             metric_class = ClassFactory.get_cls(ClassType.METRIC,
                                                 metric_name)
         else:
             metric_class = getattr(
                 importlib.import_module('vega.core.metrics'), metric_name)
         if isfunction(metric_class):
             metric_class = partial(metric_class, **metric_item)
         else:
             metric_class = metric_class(**metric_item)
         self.mdict[metric_name] = metric_class
     self.mdict = Config(self.mdict)
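The isfunction check covers the two metric shapes: a class is instantiated once with its kwargs, while a plain function is wrapped with functools.partial so the kwargs are bound without calling it. A hypothetical metric_cfg showing both cases:

    metric_cfg = [
        {"type": "accuracy"},                             # function-style, wrapped with partial
        {"type": "SegmentationMetric", "num_class": 21},  # class-style, instantiated
    ]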
Code example #8
File: trainer.py Project: zhwzhong/vega
 def _init_optimizer(self, optimizer=None):
     """Init optimizer from torch.optim according to optim type in config."""
     if optimizer is not None:
         return optimizer
     optim_config = self.cfg.optim.copy()
     optim_name = optim_config.pop('type')
     if ClassFactory.is_exists(ClassType.OPTIM, optim_name):
         optim_class = ClassFactory.get_cls(ClassType.OPTIM, optim_name)
     else:
         optim_class = getattr(importlib.import_module('torch.optim'),
                               optim_name)
     learnable_params = [
         param for param in self.model.parameters() if param.requires_grad
     ]
     optimizer = optim_class(learnable_params, **optim_config)
     if self.horovod:
         optimizer = hvd.DistributedOptimizer(
             optimizer,
             named_parameters=self.model.named_parameters(),
             compression=hvd.Compression.none)
     return optimizer
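Only parameters with requires_grad=True are handed to the optimizer, and Horovod wraps the result for distributed training. The fallback resolves any torch.optim optimizer by name, for example (values illustrative):

    # cfg.optim = {"type": "SGD", "lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4}
    # -> torch.optim.SGD(learnable_params, lr=0.1, momentum=0.9, weight_decay=1e-4)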
Code example #9
File: ps_differential.py Project: zeyefkey/vega
 def _init_loss(self):
     """Init loss."""
     if vega.is_torch_backend():
         loss_config = self.criterion.copy()
         loss_name = loss_config.pop('type')
         loss_class = getattr(importlib.import_module('torch.nn'),
                              loss_name)
         return loss_class(**loss_config)
     elif vega.is_tf_backend():
         from inspect import isclass
         loss_config = self.config.tf_criterion.copy()
         loss_name = loss_config.pop('type')
         if ClassFactory.is_exists('trainer.loss', loss_name):
             loss_class = ClassFactory.get_cls('trainer.loss', loss_name)
             if isclass(loss_class):
                 return loss_class(**loss_config)
             else:
                 return partial(loss_class, **loss_config)
         else:
             loss_class = getattr(
                 importlib.import_module('tensorflow.losses'), loss_name)
             return partial(loss_class, **loss_config)
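Under the TF backend a resolved loss may be a plain function rather than a class, so partial binds the config kwargs while deferring the call until labels and logits exist. A sketch, assuming the TF 1.x tf.losses API this snippet imports from:

    # loss_fn = partial(tf.losses.sigmoid_cross_entropy, label_smoothing=0.1)
    # later: loss = loss_fn(multi_class_labels=labels, logits=logits)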
Code example #10
    def _init_after_scheduler(self):
        """Init after_scheduler with after_scheduler_config."""
        if isinstance(self.after_scheduler_config, dict):
            scheduler_config = copy.deepcopy(self.after_scheduler_config)
            print("after_scheduler_config: {}".format(scheduler_config))
            scheduler_name = scheduler_config.pop('type')
            if ClassFactory.is_exists(ClassType.LR_SCHEDULER, scheduler_name):
                scheduler_class = ClassFactory.get_cls(ClassType.LR_SCHEDULER,
                                                       scheduler_name)
            else:
                scheduler_class = getattr(
                    importlib.import_module('torch.optim.lr_scheduler'),
                    scheduler_name)

            if scheduler_class.__name__ == "CosineAnnealingLR":
                if scheduler_config.get("T_max", -1) == -1:
                    if scheduler_config.get("by_epoch", True):
                        scheduler_config["T_max"] = self.epochs
                    else:
                        scheduler_config["T_max"] = self.epochs * self.steps

            self.after_scheduler = scheduler_class(self.optimizer,
                                                   **scheduler_config)
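The T_max default ties the cosine period to the length of the run; with 50 epochs and 200 steps per epoch (numbers illustrative):

    # by_epoch=True  -> T_max = 50         (scheduler.step() once per epoch)
    # by_epoch=False -> T_max = 50 * 200   (scheduler.step() once per batch)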