Example #1
 def _init_callbacks(self, callbacks):
     # Initialize custom callbacks by configuration or parameters
     if callbacks is not None:
         return callbacks
     disables = []
     if not self.config.model_statistics:
         disables.append('ModelStatistics')
     self.callbacks = CallbackList(self.config.callbacks, disables)
     self.callbacks.set_trainer(self)
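
Example #1 delegates filtering to Vega's CallbackList, handing it the configured callbacks plus a list of callback names to disable. As a rough self-contained sketch of that disable-list idea (the CallbackList below is a stand-in, not Vega's actual class, and it filters ready-made callback instances rather than config entries):

class CallbackList:
    """Toy callback container that drops callbacks whose class name is disabled."""

    def __init__(self, callbacks, disables=None):
        disables = disables or []
        self.callbacks = [cb for cb in callbacks
                          if type(cb).__name__ not in disables]

    def set_trainer(self, trainer):
        # Hand every surviving callback a reference back to the trainer.
        for cb in self.callbacks:
            cb.trainer = trainer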
Example #2
 def _init_callbacks(self, callbacks):
     # Initialize callbacks by configuration or parameters
     if callbacks is None:
         _callbacks = []
         callbacks_config = self.cfg.callbacks.copy()
         for callback_config in callbacks_config.values():
             callback_name = callback_config.pop('type')
             if ClassFactory.is_exists(ClassType.CALLBACK, callback_name):
                 callback_class = ClassFactory.get_cls(
                     ClassType.CALLBACK, callback_name)
                 callback = callback_class(**callback_config)
                 _callbacks.append(callback)
             else:
                 raise ValueError(
                     "Undefined callback {}".format(callback_name))
     else:
         _callbacks = callbacks
     # Sort the callbacks
     metrics_evaluator = None
     model_checkpoint = None
     model_statistics = None
     predefined_callbacks = []
     customized_callbacks = []
     for callback in _callbacks:
         if isinstance(callback, self._predefined_callbacks()):
             if isinstance(callback, MetricsEvaluator):
                 metrics_evaluator = callback
             elif isinstance(callback, ModelStatistics):
                 model_statistics = callback
             elif isinstance(callback, ModelCheckpoint):
                 model_checkpoint = callback
             else:
                 predefined_callbacks.append(callback)
         else:
             customized_callbacks.append(callback)
     if metrics_evaluator is None:
         metrics_evaluator = MetricsEvaluator()
     if model_checkpoint is None:
         model_checkpoint = ModelCheckpoint()
     _callbacks = [metrics_evaluator, model_checkpoint] + \
         customized_callbacks + predefined_callbacks
     if 'model_statistic' in self.cfg and self.cfg.model_statistic:
         if model_statistics is None:
             model_statistics = ModelStatistics()
         _callbacks = [model_statistics] + _callbacks
     # Create CallbackList and set its trainer and parameters
     self.callbacks = CallbackList(_callbacks)
     _callbacks_params = {
         'epochs': self.epochs,
         'is_chief': self.is_chief,
         'use_cuda': self.use_cuda,
         'do_validation': self.do_validation,
         'is_detection_trainer': self.cfg.is_detection_trainer
     }
     self.callbacks.set_params(_callbacks_params)
     self.callbacks.set_trainer(self)
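
The loop above expects self.cfg.callbacks to be a mapping whose values each carry a 'type' key naming a callback class registered with ClassFactory; whatever keys remain after popping 'type' are passed to the constructor as keyword arguments. A hypothetical configuration (callback names taken from _callbacks_mapping in Example #3, parameter values illustrative only) could look like this:

callbacks_config = {
    'progress_logger': {
        'type': 'ProgressLogger',         # resolved via ClassFactory.get_cls
        'train_report_steps': 50,         # leftover keys become constructor kwargs
        'train_verbose': 2,
    },
    'lr_scheduler_point': {
        'type': 'LearningRateScheduler',
    },
}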
Example #3
class Trainer(DistributedWorker):
    """Trainer class.

    :param id: id of the model, defaults to None
    :type id: int, optional
    """

    __worker_id__ = 0

    def __init__(self,
                 model=None,
                 id=None,
                 hps=None,
                 load_ckpt_flag=False,
                 **kwargs):
        """Init Trainer."""
        super(Trainer, self).__init__(self.cfg)

        self.worker_type = WorkerTypes.TRAINER
        Trainer.__worker_id__ += 1
        if id is not None:
            self._worker_id = id
        else:
            self._worker_id = Trainer.__worker_id__

        # Data member list of Trainer
        self.is_chief = True
        self.use_cuda = True
        self.epochs = self.cfg.epochs
        self.do_validation = True
        self.auto_save_ckpt = True
        self.auto_save_perf = True
        self.skip_train = False
        self.valid_freq = self.cfg.get('valid_freq', 1)
        self.hps = hps
        self.model = model
        self.optimizer = None
        self.lr_scheduler = None
        self.loss = None
        self.use_syncbn = self.cfg.get('syncbn', False)
        self.use_amp = self.cfg.get('amp', False)
        self.train_metrics = None
        self.valid_metrics = None
        self.train_loader = None
        self.valid_loader = None
        self.train_step = None
        self.valid_step = None
        self.make_batch = None
        self.callbacks = None
        self.model_desc = {}
        self.visual_data = {}
        self.load_ckpt_flag = load_ckpt_flag
        self.checkpoint_file_name = 'weights.pth'
        self.model_pickle_file_name = 'model.pkl'
        self.performance_file_name = 'performance.txt'
        self.horovod = self.cfg.get('horovod', False)
        # Used by TimmTrainerCallbacks since it builds its trainer in
        # the before_train callback
        self.lazy_built = self.cfg.get('lazy_built', False)
        # Indicate whether the necessary components of a trainer
        # have been built for running
        self.has_built = False
        self._callbacks_mapping()

    def _callbacks_mapping(self):
        """Convert config to callback setting."""
        mapping = {}
        callback_config = self.cfg.get('callbacks')
        if callback_config:
            mapping[callback_config] = {'type': callback_config}
        if self.cfg.get('model_statistics'):
            mapping['model_statistics'] = {'type': 'ModelStatistics'}
        # default callbacks
        if 'call_point' in self.cfg.lr_scheduler:
            lr_scheduler_point = self.cfg.lr_scheduler.pop('call_point')
            mapping['lr_scheduler_point'] = {
                'type': 'LearningRateScheduler',
                'call_point': lr_scheduler_point
            }
        else:
            mapping['lr_scheduler_point'] = {'type': 'LearningRateScheduler'}
        mapping['progress_logger'] = {
            'type': 'ProgressLogger',
            'train_verbose': self.cfg.get('report_verbose', 2),
            'train_report_steps': self.cfg.report_freq
        }
        self.cfg.callbacks = Config(mapping)

    def train_process(self):
        """Whole train process of the TrainWorker specified in config.

        After training, the model and validation results are saved to local_worker_path and s3_path.
        """
        init_log(log_file="worker_{}.txt".format(self.worker_id))
        logging.debug("Use the unified Trainer")
        if not self.lazy_built:
            self.build(model=self.model,
                       hps=self.hps,
                       load_ckpt_flag=self.load_ckpt_flag)
        self.train()

    def build(self,
              model=None,
              optimizer=None,
              loss=None,
              lr_scheduler=None,
              metrics=None,
              hps=None,
              callbacks=None,
              train_loader=None,
              valid_loader=None,
              make_batch=None,
              train_step=None,
              valid_step=None,
              load_ckpt_flag=False,
              checkpoint_file_name="weights.pth",
              model_pickle_file_name="model.pkl",
              performance_file_name="performance.txt"):
        """Build the trainer by assembling the necessary components."""
        # Initialize hyperparameters by parameters or configurations
        self.checkpoint_file_name = checkpoint_file_name
        self.model_pickle_file_name = model_pickle_file_name
        self.performance_file_name = performance_file_name

        self._init_cuda_setting()
        self._init_hps(hps)

        self.do_validation = self.cfg.with_valid
        self.model = self._init_model(model)
        self.load_ckpt_flag = load_ckpt_flag
        if self.load_ckpt_flag:
            self.load_checkpoint()
        else:
            self._load_pretrained_model()
        if self.model is not None and self.use_cuda:
            self.model = self.model.cuda()

        self.use_syncbn = self.cfg.get('syncbn', False)
        if self.use_syncbn:
            self.model = apex.parallel.convert_syncbn_model(self.model)
        self.optimizer = self._init_optimizer(optimizer)
        self.loss = self._init_loss(loss)
        self.lr_scheduler = self._init_lr_scheduler(lr_scheduler)
        # Some trainers have a different train batch size from the valid batch size
        self.train_metrics = self._init_metrics(metrics)
        self.valid_metrics = self._init_metrics(metrics)
        self.train_loader = self._init_dataloader(mode='train',
                                                  loader=train_loader)
        self.valid_loader = self._init_dataloader(mode='test',
                                                  loader=valid_loader)
        self._init_horovod_setting()
        self.use_amp = self.cfg.get('amp', False)
        if self.use_amp:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O1')
        if self.callbacks is None:
            self.callbacks = callbacks
        self._init_step_functions(make_batch, train_step, valid_step)
        # self.output_model_desc()
        cur_working_dir = FileOps.join_path(self.local_output_path,
                                            self.step_name)
        FileOps.make_dir(cur_working_dir)
        # Make sure Trainer has been built for training
        self.has_built = True

    def train(self):
        """Do the training with data, callbacks and step functions etc."""
        self._init_callbacks(self.callbacks)
        self._train_loop()

    def _train_loop(self):
        """Do the training with data, callbacks and step functions etc."""
        # Allow users to build the trainer in the before_train() callback,
        # but then lazy_built must be set to True in the configuration file
        self.callbacks.before_train()
        if self.skip_train:
            return
        for epoch in range(self.epochs):
            epoch_logs = {'train_num_batches': len(self.train_loader)}
            if self.do_validation:
                epoch_logs.update(
                    {'valid_num_batches': len(self.valid_loader)})
            self.callbacks.before_epoch(epoch, epoch_logs)
            for batch_index, batch in enumerate(self.train_loader):
                batch = self.make_batch(batch)
                batch_logs = {'train_batch': batch}
                self.callbacks.before_train_step(batch_index, batch_logs)
                train_batch_output = self.train_step(batch)
                batch_logs.update(train_batch_output)
                if self.cfg.is_detection_trainer:
                    batch_logs.update({'is_detection_trainer': True})
                self.callbacks.after_train_step(batch_index, batch_logs)
            if self.do_validation and self._should_run_validation(epoch):
                self._valid_loop()
            self.callbacks.after_epoch(epoch)
        self.callbacks.after_train()

    def _valid_loop(self):
        self.callbacks.before_valid()
        self.model.eval()
        with torch.no_grad():
            for batch_index, batch in enumerate(self.valid_loader):
                batch = self.make_batch(batch)
                batch_logs = {'valid_batch': batch}
                self.callbacks.before_valid_step(batch_index, batch_logs)
                valid_batch_output = self.valid_step(batch)
                self.callbacks.after_valid_step(batch_index,
                                                valid_batch_output)
            # TODO: will be moved to a callback
            pfm = self.valid_metrics.results
            if self.horovod:
                pfm = pfm[0][0] if isinstance(pfm[0], list) else pfm[0]
                pfm = self._metric_average(
                    pfm,
                    list(self.valid_metrics.results_dict.keys())[0])
            if self.is_chief and self.auto_save_perf:
                self._save_performance(pfm)
        self.callbacks.after_valid()

    def _default_train_step(self, batch):
        self.model.train()
        input, target = batch
        self.optimizer.zero_grad()
        output = self.model(input)
        loss = self.loss(output, target)
        if self.use_amp:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
                self.optimizer.synchronize()
            with self.optimizer.skip_synchronize():
                self.optimizer.step()
        else:
            loss.backward()
            if 'grad_clip' in self.cfg:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.cfg.grad_clip)
            self.optimizer.step()
        return {'loss': loss.item(), 'train_batch_output': output}

    def _default_valid_step(self, batch):
        input, target = batch
        if self.cfg.is_detection_trainer:
            output = self.model(input, forward_train=False)
        else:
            output = self.model(input)
        return {'valid_batch_output': output}

    def _should_run_validation(self, epoch):
        # A valid_freq of zero means the trainer's _valid_loop is not run;
        # the user may provide _valid_loop in other callbacks instead
        if self.valid_freq == 0:
            return False
        else:
            return (epoch + 1) % self.valid_freq == 0

    def _default_make_batch(self, batch):
        """Unpack batch to get input and target."""
        input, target = batch
        if self.use_cuda and not self.cfg.is_detection_trainer:
            input, target = input.cuda(), target.cuda()
        return (input, target)

    def _predefined_callbacks(self):
        predefined_callbacks = (MetricsEvaluator, LearningRateScheduler,
                                ProgressLogger, PerformanceSaver,
                                ModelStatistics, ModelCheckpoint)
        return predefined_callbacks

    def _init_callbacks(self, callbacks):
        # Initialize callbacks by configuration or parameters
        if callbacks is None:
            _callbacks = []
            callbacks_config = self.cfg.callbacks.copy()
            for callback_config in callbacks_config.values():
                callback_name = callback_config.pop('type')
                if ClassFactory.is_exists(ClassType.CALLBACK, callback_name):
                    callback_class = ClassFactory.get_cls(
                        ClassType.CALLBACK, callback_name)
                    callback = callback_class(**callback_config)
                    _callbacks.append(callback)
                else:
                    raise ValueError(
                        "Undefined callback {}".format(callback_name))
        else:
            _callbacks = callbacks
        # Sort the callbacks
        metrics_evaluator = None
        model_checkpoint = None
        model_statistics = None
        predefined_callbacks = []
        customized_callbacks = []
        for callback in _callbacks:
            if isinstance(callback, self._predefined_callbacks()):
                if isinstance(callback, MetricsEvaluator):
                    metrics_evaluator = callback
                elif isinstance(callback, ModelStatistics):
                    model_statistics = callback
                elif isinstance(callback, ModelCheckpoint):
                    model_checkpoint = callback
                else:
                    predefined_callbacks.append(callback)
            else:
                customized_callbacks.append(callback)
        if metrics_evaluator is None:
            metrics_evaluator = MetricsEvaluator()
        if model_checkpoint is None:
            model_checkpoint = ModelCheckpoint()
        _callbacks = [metrics_evaluator, model_checkpoint] + \
            customized_callbacks + predefined_callbacks
        if 'model_statistic' in self.cfg and self.cfg.model_statistic:
            if model_statistics is None:
                model_statistics = ModelStatistics()
            _callbacks = [model_statistics] + _callbacks
        # Create CallbackList and set its trainer and parameters
        self.callbacks = CallbackList(_callbacks)
        _callbacks_params = {
            'epochs': self.epochs,
            'is_chief': self.is_chief,
            'use_cuda': self.use_cuda,
            'do_validation': self.do_validation,
            'is_detection_trainer': self.cfg.is_detection_trainer
        }
        self.callbacks.set_params(_callbacks_params)
        self.callbacks.set_trainer(self)

    def _init_step_functions(self,
                             make_batch=None,
                             train_step=None,
                             valid_step=None):
        # Use the user-provided make_batch function, or fall back to the default
        if self.make_batch is None:
            if make_batch is not None:
                self.make_batch = make_batch
            else:
                self.make_batch = self._default_make_batch

        # Use the user-provided train_step function, or fall back to the default
        if self.train_step is None:
            if train_step is not None:
                self.train_step = train_step
            else:
                self.train_step = self._default_train_step

        # Use the user-provided valid_step function, or fall back to the default
        if self.valid_step is None:
            if valid_step is not None:
                self.valid_step = valid_step
            else:
                self.valid_step = self._default_valid_step

    def _init_all_settings(self):
        """Init all settings from config."""
        if self.cfg.cuda:
            self._init_cuda_setting()
        self._init_hps(self.hps)
        if self.model is None:
            self.model = self._init_model()
        if self.model is not None and self.cfg.cuda:
            self.model = self.model.cuda()
        if self._flag_load_checkpoint:
            self.load_checkpoint()
        else:
            self._load_pretrained_model()
        self.use_syncbn = self.cfg.get('syncbn', False)
        if self.use_syncbn:
            self.model = apex.parallel.convert_syncbn_model(self.model)
        self.epochs = self.cfg.epochs
        self.optimizer = self._init_optimizer()
        self.lr_scheduler = self._init_lr_scheduler()
        self.loss = self._init_loss()
        if self.horovod:
            self._init_horovod_setting()
        self.use_amp = self.cfg.get('amp', False)
        if self.use_amp:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O1')
        self.train_loader = self._init_dataloader(mode='train')
        self.valid_loader = self._init_dataloader(mode='test')

    def _init_cuda_setting(self):
        """Init CUDA setting."""
        if not self.cfg.cuda:
            self.cfg.device = -1
            return
        self.cfg.device = self.cfg.cuda if self.cfg.cuda is not True else 0
        self.use_cuda = True
        if self.horovod:
            torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(self.cfg.seed)

    def _init_horovod_setting(self):
        """Init horovod setting."""
        self.is_chief = True
        if self.horovod:
            hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(self.optimizer, root_rank=0)
            if hvd.rank() != 0:
                self.is_chief = False
            else:
                self.is_chief = True

    def _init_model(self, model=None):
        """Load model desc from save path and parse to model."""
        if model is not None:
            return model
        model_cfg = ClassFactory.__configs__.get('model')
        if 'model_desc_file' in model_cfg and model_cfg.model_desc_file is not None:
            desc_file = model_cfg.model_desc_file.replace(
                "{model_zoo}", self.model_zoo_path)
            desc_file = desc_file.replace("{local_base_path}",
                                          self.local_base_path)
            if ":" not in desc_file:
                desc_file = os.path.abspath(desc_file)
            if ":" in desc_file:
                local_desc_file = FileOps.join_path(
                    self.local_output_path, os.path.basename(desc_file))
                FileOps.copy_file(desc_file, local_desc_file)
                desc_file = local_desc_file
            if self.horovod:
                hvd.join()
            model_desc = Config(desc_file)
            logging.info("net_desc:{}".format(model_desc))
        elif 'model_desc' in model_cfg and model_cfg.model_desc is not None:
            model_desc = model_cfg.model_desc
        else:
            return None
        if model_desc is not None:
            self.model_desc = model_desc
            net_desc = NetworkDesc(model_desc)
            model = net_desc.to_model()
            return model
        else:
            return None

    def _init_hps(self, hps):
        """Convert trainer values in hps to cfg.

        :param hps: hyperparameters
        :type hps: dict
        """
        if "hps_file" in self.cfg and self.cfg.hps_file is not None:
            hps_file = self.cfg.hps_file.replace("{local_base_path}",
                                                 self.local_base_path)
            hps = Config(hps_file)
        if hps is not None:
            self.cfg = Config(update_dict(hps.get('trainer'), self.cfg))
            self.hps = hps
        # if 'hps_file' in self.cfg and self.cfg.hps_file is not None:
        #     hps_dict = {}
        #     local_hp_file = FileOps.join_path(self.local_output_path,
        #                                       os.path.basename(self.cfg.hps_file))
        #     FileOps.copy_file(self.cfg.hps_file, local_hp_file)
        #     with open(local_hp_file) as json_file:
        #         hps_dict = json.load(json_file)
        #         hps_dict = Config(hps_dict)
        #     self.cfg = Config(update_dict(hps_dict.get('trainer'), self.cfg))
        #     if self.hps is None:
        #         self.hps = hps_dict
        #     else:
        #         update_dict(hps_dict, self.hps)
        #     logging.info("load hps file:{}".format(self.hps))

    def _init_optimizer(self, optimizer=None):
        """Init optimizer from torch.optim according to optim type in config."""
        if optimizer is not None:
            return optimizer
        optim_config = self.cfg.optim.copy()
        optim_name = optim_config.pop('type')
        if ClassFactory.is_exists(ClassType.OPTIM, optim_name):
            optim_class = ClassFactory.get_cls(ClassType.OPTIM, optim_name)
        else:
            optim_class = getattr(importlib.import_module('torch.optim'),
                                  optim_name)
        learnable_params = [
            param for param in self.model.parameters() if param.requires_grad
        ]
        optimizer = optim_class(learnable_params, **optim_config)
        if self.horovod:
            optimizer = hvd.DistributedOptimizer(
                optimizer,
                named_parameters=self.model.named_parameters(),
                compression=hvd.Compression.none)
        return optimizer

    def _init_loss(self, loss_fn=None):
        """Init loss function from torch according to type in config."""
        if loss_fn is not None:
            return loss_fn
        loss_config = self.cfg.loss.copy()
        loss_name = loss_config.pop('type')
        if NetworkFactory.is_exists(NetTypes.LOSS, loss_name):
            loss_class = NetworkFactory.get_network(NetTypes.LOSS, loss_name)
        elif ClassFactory.is_exists('trainer.loss', loss_name):
            loss_class = ClassFactory.get_cls('trainer.loss', loss_name)
        else:
            loss_class = getattr(importlib.import_module('torch.nn'),
                                 loss_name)
        loss_fn = loss_class(**loss_config)
        if self.cfg.cuda:
            loss_fn = loss_fn.cuda()
        return loss_fn

    def _init_lr_scheduler(self, scheduler=None):
        """Init lr scheduler from torch.optim.lr_scheduler according to type in config."""
        if scheduler is not None:
            return scheduler
        scheduler_config = self.cfg.lr_scheduler.copy()
        scheduler_name = scheduler_config.pop('type')
        if ClassFactory.is_exists(ClassType.LR_SCHEDULER, scheduler_name):
            scheduler_class = ClassFactory.get_cls(ClassType.LR_SCHEDULER,
                                                   scheduler_name)
        else:
            scheduler_class = getattr(
                importlib.import_module('torch.optim.lr_scheduler'),
                scheduler_name)
        return scheduler_class(self.optimizer, **scheduler_config)

    def _init_metrics(self, metrics=None):
        """Init metrics."""
        if metrics is not None:
            return metrics
        else:
            return Metrics(self.cfg.metric)

    def _init_dataloader(self, mode, loader=None):
        """Init dataloader."""
        if loader is not None:
            return loader
        if self.horovod:
            if hvd.local_rank() == 0:
                Dataset()
            hvd.join()
        if mode == "train" and self.hps is not None and self.hps.get(
                "dataset") is not None:
            dataset = Dataset(mode=mode, hp=self.hps)
        else:
            dataset = Dataset(mode=mode)
        if self.horovod:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, num_replicas=hvd.size(), rank=hvd.rank())
            dataset.sampler = sampler
        return dataset.dataloader

    def _load_pretrained_model(self):
        if self.model is None:
            return
        if "pretrained_model_file" in self.cfg and self.cfg.pretrained_model_file is not None:
            model_file = self.cfg.pretrained_model_file.replace(
                "{model_zoo}", self.model_zoo_path)
            model_file = os.path.abspath(model_file)
            ckpt = torch.load(model_file)
            self.model.load_state_dict(ckpt)
            return

    def load_checkpoint(self,
                        worker_id=None,
                        step_name=None,
                        saved_folder=None):
        """Load checkpoint."""
        if saved_folder is None:
            if worker_id is None:
                worker_id = self.worker_id
            if step_name is None:
                step_name = self.step_name
            saved_folder = self.get_local_worker_path(step_name, worker_id)
        checkpoint_file = FileOps.join_path(saved_folder,
                                            self.checkpoint_file_name)
        model_pickle_file = FileOps.join_path(saved_folder,
                                              self.model_pickle_file_name)
        try:
            with open(model_pickle_file, 'rb') as f:
                model = pickle.load(f)
                ckpt = torch.load(checkpoint_file,
                                  map_location=torch.device('cpu'))
                model.load_state_dict(ckpt['weight'])
                if self.cfg.cuda:
                    model = model.cuda()
                self.model = model
        except Exception:
            logging.info(
                'Checkpoint file does not exist, using the default model now.')
            return

    def _save_checkpoint(self, epoch):
        """Save checkpoint."""
        checkpoint_file = FileOps.join_path(self.get_local_worker_path(),
                                            self.checkpoint_file_name)
        model_pickle_file = FileOps.join_path(self.get_local_worker_path(),
                                              self.model_pickle_file_name)
        # pickle model
        with open(model_pickle_file, 'wb') as handle:
            pickle.dump(self.model, handle, protocol=pickle.HIGHEST_PROTOCOL)
        # save checkpoint
        ckpt = {
            'epoch': epoch,
            'weight': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
        }
        torch.save(ckpt, checkpoint_file)

    def _save_performance(self, performance):
        """Save performance into performance.txt.

        :param performance: performance value
        """
        logging.debug("performance=%s", str(performance))
        self.performance_file = FileOps.join_path(self.get_local_worker_path(),
                                                  self.performance_file_name)
        with open(self.performance_file, 'w') as f:
            if isinstance(performance, list):
                for p in performance:
                    f.write("{}\n".format(p))
            elif isinstance(performance, dict):
                for p in performance.values():
                    f.write("{}\n".format(p))
            else:
                f.write("{}".format(performance))

    def get_performance(self,
                        worker_id=None,
                        step_name=None,
                        saved_folder=None):
        """Read Performance values from perform.txt.

        :param step_name: step name in the pipeline.
        :type step_name: str.
        :param worker_id: the worker's worker id.
        :type worker_id: str.
        :return: performance value
        :rtype: int/float/list

        """
        if saved_folder is None:
            if worker_id is None:
                worker_id = self.worker_id
            if step_name is None:
                step_name = self.step_name
            saved_folder = self.get_local_worker_path(step_name, worker_id)
        performance_file = FileOps.join_path(saved_folder,
                                             self.performance_file_name)
        if not os.path.isfile(performance_file):
            logging.info("Performance file is not exited, file={}".format(
                performance_file))
            return []
        with open(performance_file, 'r') as f:
            performance = []
            for line in f.readlines():
                line = line.strip()
                if line == "":
                    continue
                data = json.loads(line)
                if isinstance(data, list):
                    data = data[0]
                performance.append(data)
            logging.info("performance={}".format(performance))
        return performance

    def _metric_average(self, val, name):
        """Do metric average.

        :param val: input value
        :param name: metric name
        :return:
        """
        tensor = torch.tensor(val)
        avg_tensor = hvd.allreduce(tensor, name=name)
        return avg_tensor.item()

    @property
    def _first_rank(self):
        """Check if the first rank."""
        if self.horovod and hvd.rank() != 0:
            return False
        else:
            return True

    def output_model_desc(self, id=None, model_desc=None, performance=None):
        """Save model desc and performance.

        :param id: model desc id, usually the worker id.
        :type id: int or str.
        :param model_desc: model description.
        :type model_desc: json.
        :param performance: performance value, eg. {"accuracy": 98.23}.
        :type performance: json.

        """
        if id is None:
            id = self.worker_id
        if model_desc is None:
            if not hasattr(self, "model_desc"):
                logger.error(
                    "Failed to save model desc, param 'model_desc' is not assigned."
                )
                return
            model_desc = self.model_desc
        _file = FileOps.join_path(self.local_output_path, self.step_name,
                                  "model_desc_{}.json".format(str(id)))
        FileOps.make_base_dir(_file)
        try:
            with open(_file, "w") as f:
                json.dump(model_desc, f)
        except Exception as ex:
            logger.error(
                "Failed to save model desc, file={}, desc={}, msg={}".format(
                    _file, model_desc, str(ex)))
            return
        if performance is not None:
            self.output_evaluate_result(id, performance)

    def _backup(self):
        """Backup result worker folder."""
        if self.need_backup is True and self.backup_base_path is not None:
            backup_worker_path = FileOps.join_path(self.backup_base_path,
                                                   self.get_worker_subpath())
            FileOps.copy_folder(self.get_local_worker_path(),
                                backup_worker_path)

    def _save_visual_data(self, is_train=True, pfms=None, loss=None, lr=None):
        # TODO: will be moved to the metric base class later.
        for _name, value in pfms.items():
            if is_train:
                _name = "{}_{}".format("t", _name)
            else:
                _name = "{}_{}".format("v", _name)
            if isinstance(value, list):
                for i, _item in enumerate(value):
                    # Key each item by the metric name plus its index.
                    self.visual_data["{}_{}".format(_name, i)] = _item.data.item()
            elif isinstance(value, dict):
                # Iterate items() and key each entry by the metric name plus the dict key.
                for k, v in value.items():
                    self.visual_data["{}_{}".format(_name, k)] = v
            elif value is not None:
                self.visual_data[_name] = value.data.item()
        if loss is not None:
            self.visual_data["loss"] = loss
        if lr is not None:
            self.visual_data["lr"] = lr

    def output_evaluate_result(self,
                               id=None,
                               performance=None,
                               evaluate_type="gpu"):
        """Save model performance.

        :param id: model desc id, usually the worker id.
        :type id: int or str.
        :param performance: performance value, eg. {"accuracy": 98.23}.
        :type performance: json.
        :param evaluate_type: evaluate type, eg. "gpu", "davinci", "arm".
        :type evaluate_type: str.
        """
        if performance is None:
            return
        if id is None:
            id = self.worker_id
        _file = FileOps.join_path(
            self.local_output_path, self.step_name,
            "performance_{}_{}.txt".format(evaluate_type, str(id)))
        FileOps.make_base_dir(_file)
        try:
            performance = str(performance)
            with open(_file, "w") as f:
                f.write(performance)
        except Exception as ex:
            logger.error(
                "Failed to save performance, file={}, pfm={}, msg={}".format(
                    _file, performance, str(ex)))
            return

    def output_hps(self, id=None, hps=None):
        """Save model desc and performance.

        :param id: model desc id, usually worker id.
        :type id: int or str.
        :param hps: hyper parameters.
        :type hps: json.

        """
        if id is None:
            id = self.worker_id
        if hps is None:
            if not hasattr(self, "hps"):
                logger.error(
                    "Failed to save hyperparameters, param 'hps' is not assigned."
                )
                return
            hps = self.hps
        _file = FileOps.join_path(self.local_output_path, self.step_name,
                                  "hyperparameters.json")
        FileOps.make_base_dir(_file)
        try:
            with open(_file, "w") as f:
                json.dump({str(id): hps}, f)
        except Exception as ex:
            logger.error(
                "Failed to save hyperparameters, file={}, hps={}, msg={}".
                format(_file, hps, str(ex)))
            return

    def output_model(self,
                     id=None,
                     model=None,
                     model_desc=None,
                     performance=None):
        """Save model, model description, performance.

        :param id: model desc id, usually worker id.
        :type id: int or str.
        :param model: model to be saved.
        :type model: model object.

        """
        if id is None:
            id = self.worker_id
        if model is None:
            if not hasattr(self, "model"):
                logger.error(
                    "Failed to save model, param 'model' is not assigned.")
                return
            model = self.model
        if model_desc is None:
            if not hasattr(self, "model_desc"):
                logger.error(
                    "Failed to save model, param 'model_desc' is not assigned."
                )
                return
            model_desc = self.model_desc
        _pth_file = FileOps.join_path(self.local_output_path, self.step_name,
                                      "model_{}.pth".format(id))
        FileOps.make_base_dir(_pth_file)
        try:
            torch.save(model.state_dict(), _pth_file)
        except Exception as ex:
            logger.error("Failed to save model pth, file={}, msg={}".format(
                _pth_file, str(ex)))
        self.output_model_desc(id, model_desc, performance)
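
The _save_checkpoint/load_checkpoint pair above pickles the whole model object and stores the weights separately as a state-dict checkpoint. That round trip can be reproduced standalone; the sketch below assumes only PyTorch and reuses the trainer's default file names:

import pickle

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# Save: pickle the full model object, then checkpoint its state dict.
with open('model.pkl', 'wb') as handle:
    pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)
torch.save({'epoch': 0, 'weight': model.state_dict()}, 'weights.pth')

# Load: unpickle the model object, then restore its weights on CPU.
with open('model.pkl', 'rb') as f:
    restored = pickle.load(f)
ckpt = torch.load('weights.pth', map_location=torch.device('cpu'))
restored.load_state_dict(ckpt['weight'])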
Example #4
class Trainer(DistributedWorker):
    """Trainer class.

    :param model: input model, defaults to None
    :type model: tf model, optional
    :param id: id of the model, defaults to None
    :type id: int, optional
    :param hps: hyperparameters, defaults to None
    :type hps: dict, optional
    """

    __worker_id__ = 0  # class-level counter, incremented per instance in __init__
    config = TrainerConfig()

    def __init__(self,
                 model=None,
                 id=None,
                 hps=None,
                 load_ckpt_flag=False,
                 **kwargs):
        super(Trainer, self).__init__()
        self.worker_type = WorkerTypes.TRAINER
        Trainer.__worker_id__ += 1
        if id is not None:
            self._worker_id = id
        else:
            self._worker_id = Trainer.__worker_id__

        # Data member list of Trainer
        self.is_chief = True
        self.use_cuda = self.config.cuda
        self.epochs = self.config.epochs
        self.do_validation = True
        self.auto_save_ckpt = True
        self.auto_save_perf = True
        self.skip_train = False
        self.valid_interval = self.config.valid_interval
        self.hps = hps
        self.model = model
        self.optimizer = None
        self.lr_scheduler = None
        self.loss = None
        self.use_syncbn = self.config.syncbn
        self.use_amp = self.config.amp
        self.train_metrics = None
        self.valid_metrics = None
        self.call_metrics_on_train = self.config.call_metrics_on_train
        self.train_loader = None
        self.valid_loader = None
        self.train_step = None
        self.valid_step = None
        self.make_batch = None
        self.model_fn = None
        self.train_input_fn = None
        self.valid_input_fn = None
        self.callbacks = None
        self.performance = None
        self.model_desc = {}
        self.visual_data = {}
        self.load_ckpt_flag = load_ckpt_flag
        self.checkpoint_file_name = 'checkpoint.pth'
        self.model_pickle_file_name = 'model.pkl'
        self.model_path = FileOps.join_path(self.get_local_worker_path(),
                                            self.model_pickle_file_name)
        self.checkpoint_file = FileOps.join_path(self.get_local_worker_path(),
                                                 self.checkpoint_file_name)
        self.weights_file = FileOps.join_path(
            self.get_local_worker_path(),
            "model_{}.pth".format(self.worker_id))
        self.distributed = self.config.distributed
        # Used by TimmTrainerCallbacks since it builds its trainer in
        # the before_train callback
        self.lazy_built = self.config.lazy_built
        # Indicate whether the necessary components of a trainer
        # have been built for running
        self.has_built = False
        self._world_size = 1
        self._rank_id = 0
        self._local_rank_id = 0
        self.config.kwargs = kwargs

    def train_process(self):
        """Whole train process of the TrainWorker specified in config.

        After training, the model and validation results are saved to local_worker_path and s3_path.
        """
        init_log(log_file="worker_{}.txt".format(self.worker_id))
        logging.debug("Use the unified Trainer")

        if not self.lazy_built:
            self.build(model=self.model,
                       hps=self.hps,
                       load_ckpt_flag=self.load_ckpt_flag)
        self._init_callbacks(self.callbacks)
        self._train_loop()

    def build(self,
              model=None,
              optimizer=None,
              loss=None,
              lr_scheduler=None,
              metrics=None,
              hps=None,
              callbacks=None,
              train_loader=None,
              valid_loader=None,
              make_batch=None,
              train_step=None,
              valid_step=None,
              model_fn=None,
              train_input_fn=None,
              valid_input_fn=None,
              load_ckpt_flag=False,
              checkpoint_file_name="checkpoint.pth",
              model_pickle_file_name="model.pkl"):
        """Build the trainer by assembling the necessary components."""
        # Initialize hyperparameters by parameters or configurations
        self._init_hps(hps)
        logging.debug("Trainer Config: {}".format(obj2config(self.config)))
        self.checkpoint_file_name = checkpoint_file_name
        self.model_pickle_file_name = model_pickle_file_name
        if vega.is_torch_backend():
            self._init_step_functions(make_batch, train_step, valid_step)
        elif vega.is_tf_backend():
            self._init_estimator_fn(model_fn, train_input_fn, valid_input_fn)
        self._init_tf_session()
        self._init_distributed_setting()
        self._init_cuda_setting()
        self._init_tf_estimator()
        self.do_validation = self.config.with_valid
        self.model = self._init_model(model)
        self.load_ckpt_flag = load_ckpt_flag
        if self.load_ckpt_flag:
            self.load_checkpoint()
        else:
            self._load_pretrained_model()
        self.use_syncbn = self.config.syncbn
        if self.use_syncbn and vega.is_torch_backend():
            self.model = apex.parallel.convert_syncbn_model(self.model)
        self.train_loader = self._init_dataloader(mode='train',
                                                  loader=train_loader)
        self.valid_loader = self._init_dataloader(mode='val',
                                                  loader=valid_loader)
        if vega.is_torch_backend():
            self.optimizer = Optimizer()(model=self.model, distributed=self.distributed) \
                if optimizer is None else optimizer
            self.loss = Loss()() if loss is None else loss
            self.lr_scheduler = LrScheduler()(
                self.optimizer) if lr_scheduler is None else lr_scheduler
        # Some trainers have a different train batch size from the valid batch size
        self.train_metrics = self._init_metrics(
            metrics) if vega.is_torch_backend() else None
        self.valid_metrics = self._init_metrics(metrics)

        self._init_horovod_setting()
        if self.use_amp and vega.is_torch_backend():
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O1')
        if self.callbacks is None:
            self.callbacks = callbacks
        # self.output_model_desc()
        cur_working_dir = FileOps.join_path(self.local_output_path,
                                            self.step_name)
        FileOps.make_dir(cur_working_dir)
        # Make sure Trainer has been built for training
        self.has_built = True

    def _init_cuda_setting(self):
        """Init CUDA setting."""
        if not vega.is_torch_backend():
            return
        if not self.config.cuda:
            self.config.device = -1
            return
        self.config.device = self.config.cuda if self.config.cuda is not True else 0
        self.use_cuda = True
        if self.distributed:
            torch.cuda.set_device(self._local_rank_id)
        torch.cuda.manual_seed(self.config.seed)

    def _init_distributed_setting(self):
        if not self.distributed:
            return
        if vega.is_npu_device():
            self.npu_init = npu_ops.initialize_system()
            self.npu_shutdown = npu_ops.shutdown_system()
            self.sess.run(self.npu_init)
        self._world_size = hvd.size() if vega.is_gpu_device(
        ) else get_rank_size()
        self._rank_id = hvd.rank() if vega.is_gpu_device() else get_rank_id()
        self._local_rank_id = hvd.local_rank() if vega.is_gpu_device(
        ) else get_local_rank_id()

    def _init_horovod_setting(self):
        """Init horovod setting."""
        self.is_chief = True
        if self.distributed and vega.is_torch_backend():
            hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(self.optimizer, root_rank=0)
            if hvd.rank() != 0:
                self.is_chief = False
            else:
                self.is_chief = True

    def _init_hps(self, hps=None):
        """Load hps from file."""
        if hps is not None:
            self.hps = hps
        elif self.config.hps_file is not None:
            desc_file = self.config.hps_file.replace("{local_base_path}",
                                                     self.local_base_path)
            self.hps = Config(desc_file)
        elif self.config.hps_folder is not None:
            folder = self.config.hps_folder.replace("{local_base_path}",
                                                    self.local_base_path)
            pattern = FileOps.join_path(folder, "desc_*.json")
            desc_file = glob.glob(pattern)[0]
            self.hps = Config(desc_file)
        if self.hps and self.hps.get('trainer'):
            load_conf_from_desc(self.config, self.hps.get('trainer'))

    def _init_model(self, model=None):
        """Load model desc from save path and parse to model."""
        if model is not None:
            if vega.is_torch_backend() and self.use_cuda:
                model = model.cuda()
            return model
        model_cfg = Config(ClassFactory.__configs__.get('model'))
        if "model_desc_file" in model_cfg and model_cfg.model_desc_file is not None:
            desc_file = model_cfg.model_desc_file
            desc_file = desc_file.replace("{local_base_path}",
                                          self.local_base_path)
            if ":" not in desc_file:
                desc_file = os.path.abspath(desc_file)
            if ":" in desc_file:
                local_desc_file = FileOps.join_path(
                    self.local_output_path, os.path.basename(desc_file))
                FileOps.copy_file(desc_file, local_desc_file)
                desc_file = local_desc_file
            model_desc = Config(desc_file)
            logging.info("net_desc:{}".format(model_desc))
        elif "model_desc" in model_cfg and model_cfg.model_desc is not None:
            model_desc = model_cfg.model_desc
        elif "models_folder" in model_cfg and model_cfg.models_folder is not None:
            folder = model_cfg.models_folder.replace("{local_base_path}",
                                                     self.local_base_path)
            pattern = FileOps.join_path(folder, "desc_*.json")
            desc_file = glob.glob(pattern)[0]
            model_desc = Config(desc_file)
        else:
            return None
        if model_desc is not None:
            self.model_desc = model_desc
            net_desc = NetworkDesc(model_desc)
            model = net_desc.to_model()
            if vega.is_torch_backend() and self.use_cuda:
                model = model.cuda()
            return model
        else:
            return None

    def _load_pretrained_model(self):
        if self.model is None:
            return
        if self.config.pretrained_model_file is not None:
            model_file = self.config.pretrained_model_file
            model_file = os.path.abspath(model_file)
            if vega.is_torch_backend():
                ckpt = torch.load(model_file)
                self.model.load_state_dict(ckpt)
            elif vega.is_tf_backend():
                model_folder = os.path.dirname(model_file)
                FileOps.copy_folder(model_folder, self.get_local_worker_path())
            return

    def load_checkpoint(self,
                        worker_id=None,
                        step_name=None,
                        saved_folder=None):
        """Load checkpoint."""
        if saved_folder is None:
            if worker_id is None:
                worker_id = self.worker_id
            if step_name is None:
                step_name = self.step_name
            saved_folder = self.get_local_worker_path(step_name, worker_id)
        checkpoint_file = FileOps.join_path(saved_folder,
                                            self.checkpoint_file_name)
        model_pickle_file = FileOps.join_path(saved_folder,
                                              self.model_pickle_file_name)
        try:
            with open(model_pickle_file, 'rb') as f:
                model = pickle.load(f)
                if vega.is_torch_backend():
                    ckpt = torch.load(checkpoint_file,
                                      map_location=torch.device('cpu'))
                    model.load_state_dict(ckpt['weight'])
                    if self.config.cuda:
                        model = model.cuda()
                elif vega.is_tf_backend():
                    FileOps.copy_folder(saved_folder,
                                        self.get_local_worker_path())
                self.model = model
        except Exception:
            logging.info(
                'Checkpoint file does not exist, using the default model now.')
            return

    def _init_metrics(self, metrics=None):
        """Init metrics."""
        if metrics is not None:
            return metrics
        else:
            return Metrics()

    def _init_dataloader(self, mode, loader=None):
        """Init dataloader."""
        if loader is not None:
            return loader
        if mode == "train" and self.hps is not None and self.hps.get(
                "dataset") is not None:
            dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
            dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
        else:
            dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
            dataset = dataset_cls(mode=mode)
        if vega.is_torch_backend():
            if self.distributed:
                sampler = torch.utils.data.distributed.DistributedSampler(
                    dataset, num_replicas=hvd.size(), rank=hvd.rank())
                dataset.sampler = sampler
            return dataset.dataloader
        elif vega.is_tf_backend():
            if self.distributed:
                dataset.set_distributed(self._world_size, self._rank_id)
            return dataset

    def _train_loop(self):
        """Do the training with data, callbacks and step functions etc."""
        # Allow users to build the trainer in the before_train() callback,
        # but then lazy_built must be set to True in the configuration file
        self.callbacks.before_train()
        if self.skip_train:
            return
        for epoch in range(self.epochs):
            epoch_logs = {'train_num_batches': len(self.train_loader)}
            if self.do_validation:
                epoch_logs.update(
                    {'valid_num_batches': len(self.valid_loader)})
            self.callbacks.before_epoch(epoch, epoch_logs)
            self._train_epoch()
            if self.do_validation and self._should_run_validation(epoch):
                self._valid_epoch()
            self.callbacks.after_epoch(epoch)
        self.callbacks.after_train()
        if self.distributed:
            self._shutdown_distributed()

    def _train_epoch(self):
        if vega.is_torch_backend():
            self.model.train()
            for batch_index, batch in enumerate(self.train_loader):
                batch = self.make_batch(batch)
                batch_logs = {'train_batch': batch}
                self.callbacks.before_train_step(batch_index, batch_logs)
                train_batch_output = self.train_step(batch)
                batch_logs.update(train_batch_output)
                if self.config.is_detection_trainer:
                    batch_logs.update({'is_detection_trainer': True})
                self.callbacks.after_train_step(batch_index, batch_logs)
        elif vega.is_tf_backend():
            self.estimator.train(input_fn=self.train_input_fn,
                                 steps=len(self.train_loader),
                                 hooks=self._init_logging_hook())

    def _valid_epoch(self):
        self.callbacks.before_valid()
        valid_logs = None
        if vega.is_torch_backend():
            self.model.eval()
            with torch.no_grad():
                for batch_index, batch in enumerate(self.valid_loader):
                    batch = self.make_batch(batch)
                    batch_logs = {'valid_batch': batch}
                    self.callbacks.before_valid_step(batch_index, batch_logs)
                    valid_batch_output = self.valid_step(batch)
                    self.callbacks.after_valid_step(batch_index,
                                                    valid_batch_output)
        elif vega.is_tf_backend():
            eval_metrics = self.estimator.evaluate(
                input_fn=self.valid_input_fn, steps=len(self.valid_loader))
            self.valid_metrics.update(eval_metrics)
            valid_logs = dict()
            valid_logs['cur_valid_perfs'] = self.valid_metrics.results
        self.callbacks.after_valid(valid_logs)

    def _init_step_functions(self,
                             make_batch=None,
                             train_step=None,
                             valid_step=None):
        # Use the user-provided make_batch function, or fall back to the default
        if self.make_batch is None:
            if make_batch is not None:
                self.make_batch = make_batch
            else:
                self.make_batch = self._default_make_batch

        # Use the user-provided train_step function, or fall back to the default
        if self.train_step is None:
            if train_step is not None:
                self.train_step = train_step
            else:
                self.train_step = self._default_train_step

        # Use the user-provided valid_step function, or fall back to the default
        if self.valid_step is None:
            if valid_step is not None:
                self.valid_step = valid_step
            else:
                self.valid_step = self._default_valid_step

    def _default_make_batch(self, batch):
        """Unpack batch to get input and target."""
        input, target = batch
        if self.use_cuda and not self.config.is_detection_trainer:
            input, target = input.cuda(), target.cuda()
        return (input, target)

    def _default_train_step(self, batch):
        input, target = batch
        self.optimizer.zero_grad()
        output = self.model(input)
        loss = self.loss(output, target)
        if self.use_amp:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
                self.optimizer.synchronize()
            with self.optimizer.skip_synchronize():
                self.optimizer.step()
        else:
            loss.backward()
            if self.config.grad_clip:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.grad_clip)
            self.optimizer.step()
        return {'loss': loss.item(), 'train_batch_output': output}

    def _default_valid_step(self, batch):
        input, target = batch
        if self.config.is_detection_trainer:
            output = self.model(input, forward_train=False)
        else:
            output = self.model(input)
        return {'valid_batch_output': output}

    def _init_estimator_fn(self, model_fn, train_input_fn, valid_input_fn):
        if self.model_fn is None:
            if model_fn is not None:
                self.model_fn = model_fn
            else:
                self.model_fn = self._default_model_fn

        if self.train_input_fn is None:
            if train_input_fn is not None:
                self.train_input_fn = train_input_fn
            else:
                self.train_input_fn = self._default_train_input_fn

        if self.valid_input_fn is None:
            if valid_input_fn is not None:
                self.valid_input_fn = valid_input_fn
            else:
                self.valid_input_fn = self._default_valid_input_fn

    def _init_minimize_op(self, loss, global_step, var_list=None):
        """Init loss minimize operation, include loss scale method."""
        loss_scale = self.config.loss_scale if self.use_amp else 1.
        if loss_scale != 1:
            scaled_grad_vars = self.optimizer.compute_gradients(
                loss * loss_scale, var_list=var_list)
            unscaled_grad_vars = [(grad / loss_scale, var)
                                  for grad, var in scaled_grad_vars]
            minimize_op = self.optimizer.apply_gradients(
                unscaled_grad_vars, global_step)
        else:
            grad_vars = self.optimizer.compute_gradients(loss,
                                                         var_list=var_list)
            minimize_op = self.optimizer.apply_gradients(
                grad_vars, global_step)
        return minimize_op
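    # Worked example of the loss-scale arithmetic above: with loss_scale=128,
    # gradients are computed on loss * 128 and each gradient is divided by
    # 128 before apply_gradients, so the applied update equals the unscaled
    # gradient while the backward pass stays in a range that avoids fp16
    # underflow.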

    def _default_train_input_fn(self):
        return self.train_loader.input_fn()

    def _default_valid_input_fn(self):
        return self.valid_loader.input_fn()

    def _default_model_fn(self, features, labels, mode):
        """Define model_fn used by TensorFlow Estimator.

        :param features: input features
        :type features: tensorflow tensors
        :param labels: label data
        :type labels: tensorflow tensors
        :param mode: mode of estimator
        :type mode: tf.estimator.ModeKeys
        :return: tensorflow EstimatorSpec
        :rtype: tf.estimator.EstimatorSpec
        """
        logging.info('Building default model_fn for the estimator.')
        logits = self.model(features, mode == tf.estimator.ModeKeys.TRAIN)
        logits = tf.cast(logits, tf.float32)
        self.loss = Loss()()
        loss = self.loss(logits=logits, labels=labels)
        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_or_create_global_step()
            epoch = tf.cast(global_step, tf.float32) / tf.cast(
                len(self.train_loader), tf.float32)
            self.lr_scheduler = LrScheduler()()
            self.optimizer = Optimizer()(lr_scheduler=self.lr_scheduler,
                                         epoch=epoch,
                                         distributed=self.distributed)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            minimize_op = self._init_minimize_op(loss, global_step)
            train_op = tf.group(minimize_op, update_ops)

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = self.valid_metrics(logits, labels)
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)
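    # Flow sketch of the model_fn above: TRAIN mode builds the optimizer and
    # a train_op grouped with the UPDATE_OPS collection (e.g. batch-norm
    # updates); EVAL mode attaches the metric ops instead; both modes share
    # the same loss tensor. PREDICT mode is not handled by this default.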

    def _should_run_validation(self, epoch):
        # A valid_interval of zero disables the trainer's _valid_loop;
        # validation may instead be provided by other callbacks.
        if self.valid_interval == 0:
            return False
        return epoch % self.valid_interval == 0 or (epoch + 1) == self.epochs
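    # Example (assuming a 0-based epoch counter, which the final-epoch check
    # epoch + 1 == epochs implies): with valid_interval=3 and epochs=10,
    # validation runs at epochs 0, 3, 6, and 9, where epoch 9 is also the
    # final epoch.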

    def _init_callbacks(self, callbacks):
        # Initialize custom callbacks by configuration or parameters
        if callbacks is not None:
            # Keep the attribute consistent with the config-driven path
            # below, so later hooks can always find self.callbacks.
            self.callbacks = callbacks
            return callbacks
        disables = []
        if not self.config.model_statistics:
            disables.append('ModelStatistics')
        self.callbacks = CallbackList(self.config.callbacks, disables)
        self.callbacks.set_trainer(self)

    def _metric_average(self, val, name):
        """Do metric average.

        :param val: input value
        :param name: metric name
        :return:
        """
        tensor = torch.tensor(val)
        avg_tensor = hvd.allreduce(tensor, name=name)
        return avg_tensor.item()
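    # Numeric sketch of the averaging above: with four workers reporting
    # accuracies 0.70, 0.72, 0.74 and 0.76, hvd.allreduce (which averages
    # by default) returns 0.73 on every worker.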

    @property
    def _first_rank(self):
        """Check if the first rank."""
        if self.distributed and hvd.rank() != 0:
            return False
        else:
            return True

    def _backup(self):
        """Backup result worker folder."""
        if self.need_backup and self.backup_base_path is not None:
            backup_worker_path = FileOps.join_path(self.backup_base_path,
                                                   self.get_worker_subpath())
            FileOps.copy_folder(
                self.get_local_worker_path(self.step_name, self.worker_id),
                backup_worker_path)

    def _save_visual_data(self, is_train=True, pfms=None, loss=None, lr=None):
        # TODO Will move to metric base class later.
        for _name, value in pfms.items():
            _name = "{}_{}".format("t" if is_train else "v", _name)
            if isinstance(value, list):
                # One entry per list item, e.g. t_accuracy_0, t_accuracy_1.
                for i, _item in enumerate(value):
                    self.visual_data["{}_{}".format(_name, i)] = _item.data.item()
            elif isinstance(value, dict):
                # One entry per dict key, suffixed with that key.
                for k, v in value.items():
                    self.visual_data["{}_{}".format(_name, k)] = v
            elif value is not None:
                self.visual_data[_name] = value.data.item()
        if loss is not None:
            self.visual_data["loss"] = loss
        if lr is not None:
            self.visual_data["lr"] = lr

    def _init_tf_estimator(self):
        """Init tensorflow estimator."""
        if not vega.is_tf_backend():
            return
        sess_config = self._init_session_config()
        if vega.is_gpu_device():
            self._init_gpu_estimator(sess_config)
        elif vega.is_npu_device():
            self._init_npu_estimator(sess_config)

    def _init_tf_session(self):
        if not vega.is_tf_backend():
            return
        sess_config = self._init_session_config()
        self.sess = tf.Session(config=sess_config)

    def _init_session_config(self):
        sess_config = self._init_gpu_session_config() if vega.is_gpu_device() else \
            self._init_npu_session_config()
        return sess_config

    def _init_logging_hook(self):
        logging_hook = []
        if vega.is_gpu_device() and self.distributed:
            logging_hook += [hvd.BroadcastGlobalVariablesHook(0)]
        return logging_hook

    def _init_gpu_estimator(self, sess_config):
        """Init tensorflow estimator."""
        config = tf.estimator.RunConfig(
            model_dir=self.get_local_worker_path(),
            save_checkpoints_steps=self.config.save_steps,
            log_step_count_steps=self.config.report_freq,
            session_config=sess_config)
        self.estimator = tf.estimator.Estimator(model_fn=self.model_fn,
                                                config=config)

    def _init_npu_estimator(self, sess_config):
        model_dir = self.get_local_worker_path()
        if self.distributed:
            model_dir = FileOps.join_path(model_dir, str(self._local_rank_id))
        config = NPURunConfig(model_dir=model_dir,
                              save_checkpoints_steps=self.config.save_steps,
                              log_step_count_steps=self.config.report_freq,
                              session_config=sess_config,
                              enable_data_pre_proc=True,
                              iterations_per_loop=1)
        self.estimator = NPUEstimator(model_fn=self.model_fn, config=config)

    def _init_gpu_session_config(self):
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        if self.distributed:
            sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
        return sess_config
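    # Under Horovod each process pins itself to one GPU: e.g. a process with
    # hvd.local_rank() == 2 sets visible_device_list = '2', so TensorFlow in
    # that process only sees (and grows memory on) GPU 2.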

    def _init_npu_session_config(self):
        sess_config = tf.ConfigProto()
        sess_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
        custom_op = sess_config.graph_options.rewrite_options.custom_optimizers.add()
        custom_op.name = "NpuOptimizer"
        if self.use_amp:
            custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes(
                "allow_mix_precision")
        custom_op.parameter_map["use_off_line"].b = True
        # custom_op.parameter_map['hcom_parallel'].b = True
        # custom_op.parameter_map["enable_data_pre_proc"].b = True
        # custom_op.parameter_map["mix_compile_mode"].b = True  # mixed calculation
        # custom_op.parameter_map["min_group_size"].b = 1
        return sess_config

    def _shutdown_distributed(self):
        if vega.is_npu_device() and self.distributed:
            self.sess.run(self.npu_shutdown)
            self.sess.close()
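    # A minimal end-to-end sketch (hypothetical names, not part of the
    # original source) of plugging a custom batch function into the trainer:
    #
    #     def my_make_batch(batch):
    #         input, target = batch
    #         return input.cuda(), target.cuda()
    #
    #     trainer = Trainer(model=my_model)
    #     trainer._init_step_functions(make_batch=my_make_batch)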