def _init_transforms(self):
    """Instantiate the transforms listed in ``self.args.transforms``.

    Each entry is a mapping whose ``type`` key names the transform and whose
    remaining keys are passed to its constructor as keyword arguments. A name
    known to ``ClassFactory`` under ``ClassType.TRANSFORM`` is taken from
    there; any other name is looked up in ``torchvision.transforms``.

    A single (non-list) entry is wrapped into a list and stored back on
    ``self.args.transforms``. The entries themselves are left untouched: the
    keyword arguments are taken from a shallow copy, so the configuration
    keeps its ``type`` key and the method can safely be called again.

    :return: the instantiated transforms in configuration order, or an empty
        list when ``self.args`` has no ``transforms`` key
    :rtype: list
    :raises KeyError: if an entry has no ``type`` key
    :raises AttributeError: if the name is neither registered nor an
        attribute of ``torchvision.transforms``
    """
    if "transforms" not in self.args.keys():
        return []
    if not isinstance(self.args.transforms, list):
        self.args.transforms = [self.args.transforms]
    transforms = []
    for transform_cfg in self.args.transforms:
        # Pop from a copy: popping from the stored entry would strip "type"
        # from the configuration and break every later call.
        kwargs = dict(transform_cfg)
        transform_name = kwargs.pop("type")
        if ClassFactory.is_exists(ClassType.TRANSFORM, transform_name):
            transform_cls = ClassFactory.get_cls(
                ClassType.TRANSFORM, transform_name)
        else:
            transform_cls = getattr(
                importlib.import_module('torchvision.transforms'),
                transform_name)
        transforms.append(transform_cls(**kwargs))
    return transforms
def _init_callbacks(self, callbacks):
    """Build, order and install the trainer's ``CallbackList``.

    When ``callbacks`` is ``None`` the callbacks are created from
    ``self.cfg.callbacks``: every value is a mapping whose ``type`` key names
    a class registered in ``ClassFactory`` under ``ClassType.CALLBACK`` and
    whose other keys are constructor keyword arguments. The configuration is
    not modified. Otherwise the given callbacks are used as supplied.

    The final order is: ``ModelStatistics`` (only when
    ``self.cfg.model_statistic`` is set and truthy), ``MetricsEvaluator``,
    ``ModelCheckpoint``, the customized callbacks, then the remaining
    predefined callbacks. A missing ``MetricsEvaluator`` or
    ``ModelCheckpoint`` (and, when enabled, ``ModelStatistics``) is created
    with default arguments. Each callback appears once in the result; when
    several instances of one of these special classes are supplied, the last
    one is used.

    The list is stored as ``self.callbacks`` and receives the trainer and a
    parameter dict through ``set_trainer`` / ``set_params``.

    :param callbacks: callback instances to use, or ``None`` to build them
        from ``self.cfg.callbacks``
    :raises ValueError: if a configured callback type is not registered
    """
    # Initialize callbacks by configuration or parameters
    if callbacks is None:
        _callbacks = []
        for callback_config in self.cfg.callbacks.values():
            # Pop "type" from a copy so the stored configuration survives.
            callback_kwargs = dict(callback_config)
            callback_name = callback_kwargs.pop('type')
            if not ClassFactory.is_exists(ClassType.CALLBACK, callback_name):
                raise ValueError(
                    "Undefined callback {}".format(callback_name))
            callback_class = ClassFactory.get_cls(
                ClassType.CALLBACK, callback_name)
            _callbacks.append(callback_class(**callback_kwargs))
    else:
        _callbacks = callbacks

    statistics_enabled = bool(
        'model_statistic' in self.cfg and self.cfg.model_statistic)

    # Sort the callbacks
    metrics_evaluator = None
    model_checkpoint = None
    model_statistics = None
    predefined_callbacks = []
    customized_callbacks = []
    for callback in _callbacks:
        if not isinstance(callback, self._predefined_callbacks()):
            customized_callbacks.append(callback)
        elif isinstance(callback, MetricsEvaluator):
            metrics_evaluator = callback
        elif isinstance(callback, ModelCheckpoint):
            model_checkpoint = callback
        elif statistics_enabled and isinstance(callback, ModelStatistics):
            # Only pulled out when it will be re-inserted at the front below;
            # otherwise it stays with the other predefined callbacks.
            model_statistics = callback
        else:
            predefined_callbacks.append(callback)

    if metrics_evaluator is None:
        metrics_evaluator = MetricsEvaluator()
    if model_checkpoint is None:
        model_checkpoint = ModelCheckpoint()
    _callbacks = [metrics_evaluator, model_checkpoint] + \
        customized_callbacks + predefined_callbacks
    if statistics_enabled:
        if model_statistics is None:
            model_statistics = ModelStatistics()
        _callbacks = [model_statistics] + _callbacks

    # Create the CallbackList and set its trainer and parameters
    self.callbacks = CallbackList(_callbacks)
    _callbacks_params = {
        'epochs': self.epochs,
        'is_chief': self.is_chief,
        'use_cuda': self.use_cuda,
        'do_validation': self.do_validation,
        'is_detection_trainer': self.cfg.is_detection_trainer
    }
    self.callbacks.set_params(_callbacks_params)
    self.callbacks.set_trainer(self)
def __init__(self, aux_weight, loss_base):
    """Set up the auxiliary weight and the base loss function.

    :param aux_weight: stored unchanged as ``self.aux_weight``
    :param loss_base: mapping describing the base loss; its ``type`` key
        names the loss and the remaining keys are constructor keyword
        arguments. Only a copy of the mapping is consumed.
    """
    self.aux_weight = aux_weight
    base_kwargs = loss_base.copy()
    base_name = base_kwargs.pop('type')
    if not ClassFactory.is_exists('trainer.loss', base_name):
        # Not registered with ClassFactory: fall back to tensorflow.losses.
        tf_losses = importlib.import_module('tensorflow.losses')
        base_cls = getattr(tf_losses, base_name)
    else:
        base_cls = ClassFactory.get_cls('trainer.loss', base_name)
    self.loss_fn = base_cls(**base_kwargs)
def _init_lr_scheduler(self, scheduler=None):
    """Return the learning-rate scheduler to use.

    A scheduler passed in is returned unchanged. Otherwise one is built from
    a copy of ``self.cfg.lr_scheduler``: the ``type`` key names the class,
    taken from ``ClassFactory`` (``ClassType.LR_SCHEDULER``) when registered
    there and from ``torch.optim.lr_scheduler`` otherwise. The class is
    called with ``self.optimizer`` and the remaining keys as keyword
    arguments.

    :param scheduler: ready-made scheduler, or ``None`` to build one
    :return: the given or newly constructed scheduler
    """
    if scheduler is not None:
        return scheduler
    kwargs = self.cfg.lr_scheduler.copy()
    name = kwargs.pop('type')
    if ClassFactory.is_exists(ClassType.LR_SCHEDULER, name):
        chosen_cls = ClassFactory.get_cls(ClassType.LR_SCHEDULER, name)
    else:
        torch_schedulers = importlib.import_module('torch.optim.lr_scheduler')
        chosen_cls = getattr(torch_schedulers, name)
    return chosen_cls(self.optimizer, **kwargs)
def _init_after_scheduler(self):
    """Create ``self.after_scheduler`` from ``self.after_scheduler_config``.

    Nothing happens unless the configuration is a ``dict``. A deep copy of
    it is printed, its ``type`` key selects the scheduler class
    (``ClassFactory`` under ``ClassType.LR_SCHEDULER`` first, then
    ``torch.optim.lr_scheduler``), and the class is called with
    ``self.optimizer`` plus the remaining keys as keyword arguments.
    """
    if not isinstance(self.after_scheduler_config, dict):
        return
    settings = copy.deepcopy(self.after_scheduler_config)
    print("after_scheduler_config: {}".format(settings))
    name = settings.pop('type')
    if ClassFactory.is_exists(ClassType.LR_SCHEDULER, name):
        chosen_cls = ClassFactory.get_cls(ClassType.LR_SCHEDULER, name)
    else:
        torch_schedulers = importlib.import_module('torch.optim.lr_scheduler')
        chosen_cls = getattr(torch_schedulers, name)
    self.after_scheduler = chosen_cls(self.optimizer, **settings)
def _init_loss(self, loss_fn=None):
    """Return the loss function to use.

    A loss passed in is returned unchanged. Otherwise one is built from a
    copy of ``self.cfg.loss``; the ``type`` key is resolved, in order,
    against ``NetworkFactory`` (``NetTypes.LOSS``), ``ClassFactory``
    (``'trainer.loss'``) and finally ``torch.nn``. The remaining keys are
    the constructor keyword arguments. When ``self.cfg.cuda`` is truthy the
    result of the loss's ``cuda()`` call is returned instead.

    :param loss_fn: ready-made loss, or ``None`` to build one
    :return: the given or newly constructed loss
    """
    if loss_fn is not None:
        return loss_fn
    kwargs = self.cfg.loss.copy()
    name = kwargs.pop('type')
    if NetworkFactory.is_exists(NetTypes.LOSS, name):
        chosen_cls = NetworkFactory.get_network(NetTypes.LOSS, name)
    elif ClassFactory.is_exists('trainer.loss', name):
        chosen_cls = ClassFactory.get_cls('trainer.loss', name)
    else:
        torch_nn = importlib.import_module('torch.nn')
        chosen_cls = getattr(torch_nn, name)
    criterion = chosen_cls(**kwargs)
    return criterion.cuda() if self.cfg.cuda else criterion
def __init__(self, metric_cfg):
    """Build the metric table ``self.mdict`` from a metric configuration.

    :param metric_cfg: one metric description or a list of them. Each has a
        ``type`` key naming the metric; the remaining keys are its keyword
        arguments. The argument is deep-copied first, so it is not modified.

    A name registered in ``ClassFactory`` under ``ClassType.METRIC`` is
    taken from there, any other name from ``vega.core.metrics``. Plain
    functions are bound to their keyword arguments with ``partial``;
    everything else is called with them. The results are stored under the
    metric name and the table is finally wrapped in ``Config``.
    """
    entries = deepcopy(metric_cfg)
    self.mdict = {}
    if not isinstance(entries, list):
        entries = [entries]
    for entry in entries:
        name = entry.pop('type')
        if ClassFactory.is_exists(ClassType.METRIC, name):
            target = ClassFactory.get_cls(ClassType.METRIC, name)
        else:
            metrics_module = importlib.import_module('vega.core.metrics')
            target = getattr(metrics_module, name)
        # Functions are deferred with their kwargs; other objects are
        # instantiated right away.
        self.mdict[name] = (
            partial(target, **entry) if isfunction(target)
            else target(**entry))
    self.mdict = Config(self.mdict)
def _init_optimizer(self, optimizer=None):
    """Return the optimizer to use.

    An optimizer passed in is returned unchanged. Otherwise one is built
    from a copy of ``self.cfg.optim``: the ``type`` key names the class,
    taken from ``ClassFactory`` (``ClassType.OPTIM``) when registered there
    and from ``torch.optim`` otherwise. It is constructed over the model
    parameters whose ``requires_grad`` is truthy, with the remaining keys as
    keyword arguments. When ``self.horovod`` is truthy the optimizer is
    wrapped in ``hvd.DistributedOptimizer``.

    :param optimizer: ready-made optimizer, or ``None`` to build one
    :return: the given or newly constructed optimizer
    """
    if optimizer is not None:
        return optimizer
    kwargs = self.cfg.optim.copy()
    name = kwargs.pop('type')
    if ClassFactory.is_exists(ClassType.OPTIM, name):
        chosen_cls = ClassFactory.get_cls(ClassType.OPTIM, name)
    else:
        torch_optim = importlib.import_module('torch.optim')
        chosen_cls = getattr(torch_optim, name)
    trainable = [p for p in self.model.parameters() if p.requires_grad]
    built = chosen_cls(trainable, **kwargs)
    if not self.horovod:
        return built
    return hvd.DistributedOptimizer(
        built,
        named_parameters=self.model.named_parameters(),
        compression=hvd.Compression.none)
def _init_loss(self):
    """Build the loss for the active backend.

    Torch backend: a copy of ``self.criterion`` is used; its ``type`` key
    names a class in ``torch.nn``, which is instantiated with the remaining
    keys.

    TensorFlow backend: a copy of ``self.config.tf_criterion`` is used. A
    name registered in ``ClassFactory`` under ``'trainer.loss'`` is
    instantiated when it is a class and wrapped in ``partial`` otherwise; an
    unregistered name is taken from ``tensorflow.losses`` and wrapped in
    ``partial``.

    :return: the loss object or callable, or ``None`` when neither backend
        check succeeds
    """
    if vega.is_torch_backend():
        kwargs = self.criterion.copy()
        name = kwargs.pop('type')
        torch_nn = importlib.import_module('torch.nn')
        return getattr(torch_nn, name)(**kwargs)
    if not vega.is_tf_backend():
        return None
    from inspect import isclass
    kwargs = self.config.tf_criterion.copy()
    name = kwargs.pop('type')
    if not ClassFactory.is_exists('trainer.loss', name):
        tf_losses = importlib.import_module('tensorflow.losses')
        return partial(getattr(tf_losses, name), **kwargs)
    registered = ClassFactory.get_cls('trainer.loss', name)
    if isclass(registered):
        return registered(**kwargs)
    return partial(registered, **kwargs)
def _init_after_scheduler(self):
    """Create ``self.after_scheduler`` from ``self.after_scheduler_config``.

    Nothing happens unless the configuration is a ``dict``. A deep copy of
    it is printed and its ``type`` key selects the scheduler class
    (``ClassFactory`` under ``ClassType.LR_SCHEDULER`` first, then
    ``torch.optim.lr_scheduler``).

    For a class named ``CosineAnnealingLR`` whose ``T_max`` is missing or
    ``-1``, ``T_max`` is filled in: ``self.epochs`` when ``by_epoch`` is
    truthy (the default), otherwise ``self.epochs * self.steps``.

    The class is then called with ``self.optimizer`` and the remaining keys
    as keyword arguments.
    """
    if not isinstance(self.after_scheduler_config, dict):
        return
    settings = copy.deepcopy(self.after_scheduler_config)
    print("after_scheduler_config: {}".format(settings))
    name = settings.pop('type')
    if ClassFactory.is_exists(ClassType.LR_SCHEDULER, name):
        chosen_cls = ClassFactory.get_cls(ClassType.LR_SCHEDULER, name)
    else:
        torch_schedulers = importlib.import_module('torch.optim.lr_scheduler')
        chosen_cls = getattr(torch_schedulers, name)
    needs_t_max = (
        chosen_cls.__name__ == "CosineAnnealingLR"
        and settings.get("T_max", -1) == -1)
    if needs_t_max:
        by_epoch = settings.get("by_epoch", True)
        settings["T_max"] = (
            self.epochs if by_epoch else self.epochs * self.steps)
    self.after_scheduler = chosen_cls(self.optimizer, **settings)