Example #1
0
 def build(self):
     """Assemble optimizer, loss, LR scheduler and metrics for the trainer.

     An optimizer or LR scheduler that is already set is kept as-is. When
     ``use_amp`` is enabled, the model and optimizer are wrapped by apex AMP
     at opt level ``'O1'``.
     """
     super().build()
     if self.optimizer is None:
         self.optimizer = Optimizer()(model=self.model,
                                      distributed=self.distributed)
     criterion = Loss()()
     if hasattr(self.model, 'add_loss'):
         # The model aggregates losses itself: register ours, use its total.
         self.model.add_loss(criterion)
         self.loss = self.model.overall_loss()
     else:
         self.loss = criterion
     if self.config.adaptive_muti_loss and hasattr(self.loss, "adaptive_muti_loss"):
         worker_path = self.get_local_worker_path(self.step_name, self.worker_id)
         self.loss.adaptive_muti_loss(save_path=worker_path,
                                      weight=self.config.loss_weight)
     if self.lr_scheduler is None:
         self.lr_scheduler = LrScheduler()(self.optimizer)
     if self.actions_list is not None:
         # Remember the complete component set (presumably restored by the
         # actions_list handling elsewhere — confirm).
         self.total_optimizer, self.total_loss, self.total_lr_scheduler = (
             self.optimizer, self.loss, self.lr_scheduler)
     # Separate metric instances: train and valid batch sizes may differ.
     self.train_metrics = self._init_metrics()
     self.valid_metrics = self._init_metrics()
     self._init_horovod_setting()
     if not self.use_amp:
         return
     from apex import amp
     self.model, self.optimizer = amp.initialize(
         self.model, self.optimizer, opt_level='O1')
Example #2
0
 def before_train(self, logs=None):
     """Measure the current model's cost, then re-initialise it for training.

     Builds a dummy 1x3x32x32 float input for the active backend, logs the
     FLOPs / parameter count / forward latency of ``self.trainer.model``,
     replaces the trainer model with ``self._generate_init_model()`` and, on
     the torch backend, rebuilds the optimizer and LR scheduler for it.

     :param logs: unused; part of the callback interface.
     """
     self.config = self.trainer.config
     gpu_device = vega.is_gpu_device()
     # When is_gpu_device() returns exactly True, use device 0; otherwise use
     # its return value unchanged.
     self.device = gpu_device if gpu_device is not True else 0
     self.base_net_desc = self.trainer.model.desc
     sess_config = None
     if vega.is_torch_backend():
         if vega.is_npu_device():
             count_input = torch.FloatTensor(1, 3, 32, 32).npu()
         elif gpu_device:
             count_input = torch.FloatTensor(1, 3, 32, 32).to(self.device)
         else:
             # CPU fallback: previously count_input was left unbound here and
             # the cost measurement below raised UnboundLocalError.
             count_input = torch.FloatTensor(1, 3, 32, 32)
     elif vega.is_tf_backend():
         count_input = tf.random.uniform([1, 3, 32, 32], dtype=tf.float32)
         sess_config = self.trainer._init_session_config()
     elif vega.is_ms_backend():
         count_input = mindspore.Tensor(
             np.random.randn(1, 3, 32, 32).astype(np.float32))
     self.flops_count, self.params_count = calc_model_flops_params(
         self.trainer.model, count_input)
     self.latency_count = calc_forward_latency(self.trainer.model,
                                               count_input, sess_config)
     logging.info("after prune model glops=%sM, params=%sK, latency=%sms",
                  self.flops_count * 1e-6, self.params_count * 1e-3,
                  self.latency_count * 1000)
     self.trainer.model = self._generate_init_model()
     if vega.is_torch_backend():
         # The model object changed, so optimizer/scheduler must be rebuilt.
         self.trainer.optimizer = Optimizer()(
             model=self.trainer.model, distributed=self.trainer.distributed)
         self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
Example #3
0
    def build(self):
        """Build the trainer by assembling the necessary components.

        Creates the optimizer (fed a dynamic learning rate when
        ``config.lr_scheduler.params`` is set), the loss and the validation
        metrics, and wraps them together with the model in ``MsModel``
        (presumably MindSpore's high-level Model — confirm).
        """
        super().build()
        if self.config.lr_scheduler.params:
            self.lr_scheduler = LrScheduler()
            dynamic_lr = self.lr_scheduler()(
                base_lr=self.config.optimizer.params["lr"],
                global_step=self.config.epochs * len(self.train_loader),
                total_epoch=self.config.epochs)
            self.optimizer = Optimizer()(model=self.model,
                                         dynamic_lr=dynamic_lr)
        else:
            self.optimizer = Optimizer()(model=self.model)
        if hasattr(self.model, 'add_loss'):
            loss_cls = Loss()()
            self.model.add_loss(loss_cls)
            self.loss = self.model.overall_loss()
        else:
            self.loss = Loss()()
        self.metric_name = self.config.metric.type

        # Some trainer has different train batch size from valid batch
        self.train_metrics = None
        self.valid_metrics = self._init_metrics()
        # Evaluate the metrics callable once and reuse the result; it used to
        # be called up to three times, with two results discarded.
        metrics = self.valid_metrics()
        if isinstance(metrics, dict):
            self.ms_metrics = metrics
        else:
            self.ms_metrics = {self.metric_name: metrics}

        self.ms_model = MsModel(network=self.model,
                                loss_fn=self.loss,
                                optimizer=self.optimizer,
                                metrics=self.ms_metrics)
Example #4
0
    def model_fn(self, features, labels, mode):
        """Define cars model_fn used by TensorFlow Estimator.

        In TRAIN mode, gradients are computed for up to
        ``alg_policy.num_individual_per_iter`` sampled architectures (alphas),
        summed, clipped by global norm and applied once. During warmup only a
        single randomly sampled path is used. In EVAL mode the model runs with
        ``self.trainer.valid_alpha``.

        :param features: input features for the estimator
        :param labels: label data
        :param mode: a ``tf.estimator.ModeKeys`` value
        :return: ``tf.estimator.EstimatorSpec``
        """
        logging.info('Cars model function action')
        self.trainer.loss = Loss()()

        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.compat.v1.train.get_global_step()
            # Fractional epoch as a graph tensor: global_step / batches-per-epoch.
            epoch = tf.cast(global_step, tf.float32) / tf.cast(
                len(self.trainer.train_loader), tf.float32)
            self.trainer.optimizer = Optimizer()(
                distributed=self.trainer.distributed)
            self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
            self.trainer.lr_scheduler.step(epoch)
            self.trainer.model.training = True
            alphas = tf.convert_to_tensor(self.alphas)
            for j in range(self.alg_policy.num_individual_per_iter):
                # NOTE(review): np.random and the self.epoch comparison run in
                # Python while the graph is built, so the chosen individual is
                # fixed for the lifetime of this graph, not re-sampled per step.
                i = np.random.randint(0, self.alg_policy.num_individual, 1)[0]
                if self.epoch < self.alg_policy.warmup:
                    alpha = tf.convert_to_tensor(
                        self.search_alg.random_sample_path())
                else:
                    alpha = alphas[i]
                logits = self.trainer.model(features, alpha=alpha)
                logits = tf.cast(logits, tf.float32)
                loss = self.trainer.loss(logits=logits, labels=labels)
                loss = self.trainer.optimizer.regularize_loss(loss)
                # NOTE: `vars` shadows the builtin; it is reused after the loop.
                grads, vars = zip(
                    *self.trainer.optimizer.compute_gradients(loss))
                if j == 0:
                    # Zero-initialised, non-trainable variables that are never
                    # assigned; the sum below is effectively the plain sum of
                    # all sampled gradients.
                    accum_grads = [
                        tf.Variable(tf.zeros_like(grad), trainable=False)
                        for grad in grads
                    ]
                accum_grads = [
                    accum_grads[k] + grads[k] for k in range(len(grads))
                ]
                if self.epoch < self.alg_policy.warmup:
                    # Warmup trains a single sampled path per step.
                    break
            clipped_grads, _ = tf.clip_by_global_norm(
                accum_grads, self.trainer.config.grad_clip)
            minimize_op = self.trainer.optimizer.apply_gradients(
                list(zip(clipped_grads, vars)), global_step)
            update_ops = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.UPDATE_OPS)
            train_op = tf.group(minimize_op, update_ops)

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            alpha = tf.convert_to_tensor(self.trainer.valid_alpha)
            self.trainer.model.training = False
            logits = self.trainer.model(features, alpha=alpha)
            logits = tf.cast(logits, tf.float32)
            loss = self.trainer.loss(logits=logits, labels=labels)
            eval_metric_ops = self.trainer.valid_metrics(logits, labels)

        # In TRAIN mode `loss` is the last sampled individual's loss. In any
        # mode other than TRAIN/EVAL `loss` is unbound (NameError) — PREDICT
        # is not supported here.
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)
    def step(self, epoch=None):
        """Advance the warmup schedule, or delegate once warmup has finished.

        While warming up, each optimizer param group receives the matching
        value from ``self.get_lr()``. Once ``epoch`` reaches
        ``self.warmup_iters``, the post-warmup scheduler is built from
        ``self.after_scheduler_config`` and handles every later call.
        """
        if not self.warmup_finished:
            self.current_iters = epoch
            lrs = self.get_lr()
            for group, value in zip(self.optimizer.param_groups, lrs):
                group['lr'] = value
            if epoch >= self.warmup_iters:
                self.warmup_finished = True
                scheduler_factory = LrScheduler(self.after_scheduler_config)
                self.after_scheduler = scheduler_factory(self.optimizer)
                self.by_epoch = self.after_scheduler.by_epoch
            return
        self.after_scheduler.step(epoch)
    def model_fn(self, features, labels, mode):
        """Darts model_fn used by TensorFlow Estimator.

        In TRAIN mode ``features`` and ``labels`` are dicts with ``'train'``
        and ``'valid'`` entries. The architecture is first updated by
        ``self.search_alg.step`` on the validation split. The weight update
        (``optimizer.step``) is then created under a control dependency on
        that architecture step. In EVAL mode only metrics are computed.

        :param features: input features (dict in TRAIN mode, see above)
        :param labels: label data (dict in TRAIN mode, see above)
        :param mode: a ``tf.estimator.ModeKeys`` value
        :return: ``tf.estimator.EstimatorSpec``
        """
        logging.info('Darts model function action')
        global_step = tf.compat.v1.train.get_global_step()
        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            features, valid_features = features['train'], features['valid']
            labels, valid_labels = labels['train'], labels['valid']
            # update arch
            # Fractional epoch as a graph tensor: global_step / batches-per-epoch.
            epoch = tf.cast(global_step, tf.float32) / tf.cast(
                len(self.trainer.train_loader), tf.float32)
            self.trainer.optimizer = Optimizer()(
                distributed=self.trainer.distributed)
            self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
            self.trainer.lr_scheduler.step(epoch)
            update_ops = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.UPDATE_OPS)
            arch_minimize_op = self.search_alg.step(
                valid_x=valid_features,
                valid_y=valid_labels,
                lr=self.trainer.lr_scheduler.get_lr()[0])
            train_op = tf.group(arch_minimize_op, update_ops)
        self.model.training = mode == tf.estimator.ModeKeys.TRAIN
        logits = self.model(features)
        logits = tf.cast(logits, tf.float32)
        self.trainer.loss = Loss()()
        loss = self.trainer.loss(logits=logits, labels=labels)

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Weight update is sequenced after the architecture update above;
            # train_op is replaced by the weight step.
            with tf.control_dependencies([train_op]):
                weight_ops = self.model.get_weight_ops()
                loss_scale = self.trainer.config.loss_scale if self.trainer.use_amp else 1
                train_op = self.trainer.optimizer.step(loss, loss_scale,
                                                       global_step, weight_ops)

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = self.trainer.valid_metrics(logits, labels)

        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)
Example #7
0
 def __call__(self, model=None, distributed=False):
     """Build one optimizer bundle per sub-model listed in ``self.config``.

     For every config entry, the attribute of ``model`` named by the entry's
     ``'model'`` key gets its own optimizer, plus an optional LR scheduler
     (``'lr_scheduler'`` key) and loss (``'loss'`` key). Each bundle is a dict
     with keys ``opt``, ``lr``, ``loss`` and ``model``, stored in
     ``self._opts`` under the sub-model name.

     :param model: object holding the sub-models as attributes.
     :param distributed: forwarded to every sub-optimizer.
     :return: ``self``, so the call can be chained.
     """
     for entry in self.config:
         sub_name = entry.get('model')
         sub_model = getattr(model, sub_name)
         sub_opt = Optimizer(entry)(sub_model, distributed)
         lr_conf = entry.get('lr_scheduler')
         sub_lr = LrScheduler(config=lr_conf)(sub_opt) if lr_conf else None
         loss_conf = entry.get('loss')
         if loss_conf:
             sub_loss = ClassFactory.get_instance(ClassType.LOSS, loss_conf)
         else:
             sub_loss = None
         self._opts[sub_name] = {'opt': sub_opt, 'lr': sub_lr,
                                 'loss': sub_loss, 'model': sub_model}
     return self
 def before_train(self, logs=None):
     """Quantize the trainer model, log its cost, then run validation.

     The model is quantized with the bit widths from its description
     (``nbit_w_list`` / ``nbit_a_list``). A dummy input of shape
     [1, 3, 32, 32] (or [1, 32, 32, 3] for ``channels_last``) is used to
     measure FLOPs, parameter count and forward latency.

     :param logs: unused; part of the callback interface.
     """
     self.config = self.trainer.config
     model_code = copy.deepcopy(self.trainer.model.desc)
     model = self.trainer.model
     logging.info('current code: %s, %s', model_code.nbit_w_list,
                  model_code.nbit_a_list)
     quantizer = Quantizer(model, model_code.nbit_w_list,
                           model_code.nbit_a_list)
     model = quantizer()
     self.trainer.model = model
     count_input = [1, 3, 32, 32]
     if General.data_format == 'channels_last':
         count_input = [1, 32, 32, 3]
     sess_config = None
     if vega.is_torch_backend():
         if vega.is_gpu_device():
             model = model.cuda()
             count_input = torch.FloatTensor(*count_input).cuda()
         elif vega.is_npu_device():
             model = model.npu()
             count_input = torch.FloatTensor(*count_input).npu()
         else:
             # CPU fallback: previously count_input stayed a plain shape list
             # here and was passed to the flops/latency counters as-is.
             count_input = torch.FloatTensor(*count_input)
         # The quantized model is new, so optimizer/scheduler are rebuilt.
         self.trainer.optimizer = Optimizer()(
             model=self.trainer.model, distributed=self.trainer.distributed)
         self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
     elif vega.is_tf_backend():
         tf.compat.v1.reset_default_graph()
         count_input = tf.random.uniform(count_input, dtype=tf.float32)
         sess_config = self.trainer._init_session_config()
     # NOTE(review): other backends still pass the shape list unchanged —
     # confirm they are unsupported here.
     self.flops_count, self.params_count = calc_model_flops_params(
         model, count_input, custom_hooks=quantizer.custom_hooks())
     self.latency_count = calc_forward_latency(model, count_input,
                                               sess_config)
     logging.info("after quant model glops=%sM, params=%sK, latency=%sms",
                  self.flops_count * 1e-6, self.params_count * 1e-3,
                  self.latency_count * 1000)
     self.validate()
Example #9
0
class TrainerTf(TrainerBase):
    """Trainer tensorflow class.

    Trains and validates through a ``tf.estimator.Estimator`` on GPU or an
    ``NPUEstimator`` on NPU. The loss, optimizer and LR scheduler are built
    inside the estimator ``model_fn``.
    """

    def build(self):
        """Build the trainer by assembling the necessary components."""
        super().build()

        # Some trainer has different train batch size from valid batch
        self.train_metrics = None
        self.valid_metrics = self._init_metrics()
        self._init_horovod_setting()

    def _set_default_funcs(self):
        """Bind the default estimator model_fn and train/valid input_fns."""
        self.model_fn = self._default_model_fn
        self.train_input_fn = self._default_train_input_fn
        self.valid_input_fn = self._default_valid_input_fn

    def _set_condition(self):
        """Create the TF session, distributed ranks and the estimator."""
        self._init_tf_session()
        self._init_distributed_setting()
        self._init_tf_estimator()

    def _train_epoch(self):
        """Train with the estimator, either one epoch or the whole run at once.

        With ``config.train_in_once`` the estimator trains up to ``max_steps``
        (``config.max_train_steps`` or ``len(train_loader) * epochs``).
        Otherwise it trains ``config.max_train_steps`` steps, or one epoch's
        worth of steps when that is unset.
        """
        if self.config.train_in_once:
            max_steps = self.config.max_train_steps or len(self.train_loader) * self.epochs
            self.estimator.train(input_fn=self.train_input_fn,
                                 max_steps=max_steps,
                                 hooks=self._init_logging_hook())
        else:
            self.estimator.train(input_fn=self.train_input_fn,
                                 steps=self.config.max_train_steps or len(self.train_loader),
                                 hooks=self._init_logging_hook())

    def _valid_epoch(self):
        """Evaluate over the whole valid loader and report via callbacks."""
        self.callbacks.before_valid()

        eval_metrics = self.estimator.evaluate(input_fn=self.valid_input_fn,
                                               steps=len(self.valid_loader))
        self.valid_metrics.update(eval_metrics)
        valid_logs = dict()
        valid_logs['cur_valid_perfs'] = self.valid_metrics.results

        self.callbacks.after_valid(valid_logs)

    def _init_distributed_setting(self):
        """Initialise NPU system and record world size / rank ids.

        No-op unless ``self.distributed``. GPU ranks come from horovod, NPU
        ranks from the HCCL management API.
        """
        if not self.distributed:
            return

        if vega.is_npu_device():
            sess_config = self._init_session_config()
            self.sess = tf.compat.v1.Session(config=sess_config)
            from npu_bridge.estimator import npu_ops
            self.npu_init = npu_ops.initialize_system()
            self.npu_shutdown = npu_ops.shutdown_system()
            self.sess.run(self.npu_init)

        if vega.is_gpu_device():
            import horovod.tensorflow as hvd
            self._world_size = hvd.size()
            self._rank_id = hvd.rank()
            self._local_rank_id = hvd.local_rank()
        elif vega.is_npu_device():
            from hccl.manage.api import get_local_rank_id
            from hccl.manage.api import get_rank_size
            from hccl.manage.api import get_rank_id
            self._world_size = get_rank_size()
            self._rank_id = get_rank_id()
            self._local_rank_id = get_local_rank_id()

    def _default_train_input_fn(self):
        """Return the train loader's estimator input_fn result."""
        return self.train_loader.input_fn()

    def _default_valid_input_fn(self):
        """Return the valid loader's estimator input_fn result."""
        return self.valid_loader.input_fn()

    def _default_model_fn(self, features, labels, mode):
        """Define model_fn used by TensorFlow Estimator.

        :params features: input features
        :type features: tensorflow tensors
        :params labels: label data
        :type labels: tensorflow tensors
        :params mode: mode of estimator
        :type mode: tf.estimator.ModeKeys
        :return: tensorflow EstimatorSpec
        :rtype: tf.estimator.EstimatorSpec
        """
        logging.info('model function action')

        self.model.training = mode == tf.estimator.ModeKeys.TRAIN
        if self.config.mixup and mode == tf.estimator.ModeKeys.TRAIN:
            mixup_ratio = tf.compat.v1.distributions.Beta(0.1, 0.1).sample()
            mixed_x, y_a, y_b = self._mixup_batch(features, labels, mixup_ratio)
            logits = self.model(mixed_x)
        else:
            logits = self.model(features)
        logits = tf.cast(logits, tf.float32)
        # NOTE(review): add_loss runs on every model_fn call (train and eval
        # graphs); confirm add_loss does not accumulate duplicates.
        if hasattr(self.model, 'add_loss'):
            loss_cls = Loss()()
            self.model.add_loss(loss_cls)
            self.loss = self.model.overall_loss()
        else:
            self.loss = Loss()()
        # loss
        if self.config.mixup and mode == tf.estimator.ModeKeys.TRAIN:
            loss = self._mixup_loss(self.loss, logits, y_a, y_b, mixup_ratio)
        else:
            loss = self.loss(logits, labels)
        train_op = None
        # EstimatorSpec's default for training_hooks is None, so only the
        # TRAIN branch needs to set it.
        training_hooks = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.compat.v1.train.get_or_create_global_step()
            epoch = tf.cast(global_step, tf.float32) / tf.cast(len(self.train_loader), tf.float32)
            self.optimizer = Optimizer()(distributed=self.distributed)
            self.lr_scheduler = LrScheduler()(optimizer=self.optimizer)
            self.lr_scheduler.step(epoch)
            if self.distributed:
                self.optimizer = Optimizer.set_distributed(self.optimizer)

            update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
            loss_scale = self.config.loss_scale if self.use_amp else 1
            minimize_op = self.optimizer.step(loss, loss_scale, global_step)
            train_op = tf.group(minimize_op, update_ops)
            training_hooks = [tf.train.LoggingTensorHook(
                tensors={"learning rate": self.lr_scheduler.get_lr()[0]},
                every_n_iter=self.config.train_report_steps)]

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = self.valid_metrics(logits, labels)
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            training_hooks=training_hooks)

    def _mixup_batch(self, x, y, ratio):
        """Blend ``x`` with a shuffled copy of itself; return (mixed_x, y, y_shuffled)."""
        batch_size = tf.shape(x)[0]
        indices = tf.random.shuffle(tf.range(batch_size, dtype=tf.int32))
        mixed_x = ratio * x + (1 - ratio) * tf.gather(x, indices)
        y_a, y_b = y, tf.gather(y, indices)
        return mixed_x, y_a, y_b

    def _mixup_loss(self, loss, pred, y_a, y_b, ratio):
        """Return the ``ratio``-weighted mix of the losses against both label sets."""
        return ratio * loss(pred, y_a) + (1 - ratio) * loss(pred, y_b)

    def _init_tf_estimator(self):
        """Init tensorflow estimator."""
        sess_config = self._init_session_config()
        if vega.is_gpu_device():
            self._init_gpu_estimator(sess_config)
        elif vega.is_npu_device():
            self._init_npu_estimator(sess_config)

    def _init_tf_session(self):
        """Create ``self.graph`` and a session bound to it."""
        sess_config = self._init_session_config()
        self.graph = tf.Graph()

        with self.graph.as_default():
            self.sess = tf.compat.v1.Session(config=sess_config)

    def _init_session_config(self):
        """Return the GPU session config on GPU, otherwise the NPU one."""
        # NOTE(review): non-GPU devices (including CPU) get the NPU config.
        sess_config = self._init_gpu_session_config() if vega.is_gpu_device() else \
            self._init_npu_session_config()
        return sess_config

    def _init_logging_hook(self):
        """Return training hooks; adds the horovod broadcast hook on distributed GPU."""
        logging_hook = []
        if vega.is_gpu_device() and self.distributed:
            import horovod.tensorflow as hvd
            logging_hook += [hvd.BroadcastGlobalVariablesHook(0)]
        return logging_hook

    def _init_gpu_estimator(self, sess_config):
        """Init tensorflow estimator.

        Uses a ``MirroredStrategy`` (and no session config) when not horovod
        distributed but ``General._parallel`` is set with more than one device
        per trainer.
        """
        distribution = None
        if not self.distributed and General._parallel and General.devices_per_trainer > 1:
            distribution = tf.contrib.distribute.MirroredStrategy()
        config = tf.estimator.RunConfig(model_dir=self.get_local_worker_path(),
                                        save_checkpoints_steps=self.config.save_steps,
                                        log_step_count_steps=self.config.train_report_steps,
                                        session_config=None if distribution else sess_config,
                                        train_distribute=distribution)
        self.estimator = tf.estimator.Estimator(model_fn=self.model_fn, config=config)

    def _init_npu_estimator(self, sess_config):
        """Create an ``NPUEstimator`` writing to the local worker path."""
        from npu_bridge.estimator.npu.npu_config import NPURunConfig
        from npu_bridge.estimator.npu.npu_estimator import NPUEstimator
        model_dir = self.get_local_worker_path()
        config = NPURunConfig(model_dir=model_dir,
                              save_checkpoints_steps=self.config.save_steps,
                              log_step_count_steps=self.config.train_report_steps,
                              session_config=sess_config,
                              enable_data_pre_proc=True,
                              iterations_per_loop=1)
        self.estimator = NPUEstimator(model_fn=self.model_fn,
                                      config=config)

    def _init_gpu_session_config(self):
        """GPU session config: grow memory on demand, pin to local rank if distributed."""
        sess_config = tf.compat.v1.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        if self.distributed:
            import horovod.tensorflow as hvd
            sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
        return sess_config

    def _init_npu_session_config(self):
        """NPU session config: register NpuOptimizer, optional mixed precision."""
        from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
        # tf.compat.v1 for consistency with the GPU path; tf.ConfigProto does
        # not exist under TF2.
        sess_config = tf.compat.v1.ConfigProto()
        sess_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
        custom_op = sess_config.graph_options.rewrite_options.custom_optimizers.add()
        custom_op.name = "NpuOptimizer"
        if self.use_amp:
            custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
        custom_op.parameter_map["use_off_line"].b = True

        return sess_config
Example #10
0
    def _default_model_fn(self, features, labels, mode):
        """Estimator ``model_fn``: build the graph for the given mode.

        In TRAIN mode (optionally with mixup) this creates the optimizer, LR
        scheduler, train op and a learning-rate logging hook. In EVAL mode it
        creates the validation metric ops.

        :params features: input features
        :type features: tensorflow tensors
        :params labels: label data
        :type labels: tensorflow tensors
        :params mode: mode of estimator
        :type mode: tf.estimator.ModeKeys
        :return: tensorflow EstimatorSpec
        :rtype: tf.estimator.EstimatorSpec
        """
        logging.info('model function action')

        is_train = mode == tf.estimator.ModeKeys.TRAIN
        use_mixup = self.config.mixup and is_train
        self.model.training = is_train
        if use_mixup:
            mixup_ratio = tf.compat.v1.distributions.Beta(0.1, 0.1).sample()
            mixed_x, y_a, y_b = self._mixup_batch(features, labels, mixup_ratio)
            net_out = self.model(mixed_x)
        else:
            net_out = self.model(features)
        logits = tf.cast(net_out, tf.float32)

        criterion = Loss()()
        if hasattr(self.model, 'add_loss'):
            # The model aggregates losses itself: register ours, use its total.
            self.model.add_loss(criterion)
            self.loss = self.model.overall_loss()
        else:
            self.loss = criterion

        if use_mixup:
            loss = self._mixup_loss(self.loss, logits, y_a, y_b, mixup_ratio)
        else:
            loss = self.loss(logits, labels)

        train_op = None
        training_hooks = None  # EstimatorSpec's own default
        if is_train:
            global_step = tf.compat.v1.train.get_or_create_global_step()
            steps_per_epoch = tf.cast(len(self.train_loader), tf.float32)
            epoch = tf.cast(global_step, tf.float32) / steps_per_epoch
            self.optimizer = Optimizer()(distributed=self.distributed)
            self.lr_scheduler = LrScheduler()(optimizer=self.optimizer)
            self.lr_scheduler.step(epoch)
            if self.distributed:
                self.optimizer = Optimizer.set_distributed(self.optimizer)

            update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
            loss_scale = self.config.loss_scale if self.use_amp else 1
            minimize_op = self.optimizer.step(loss, loss_scale, global_step)
            train_op = tf.group(minimize_op, update_ops)
            lr_hook = tf.train.LoggingTensorHook(
                tensors={"learning rate": self.lr_scheduler.get_lr()[0]},
                every_n_iter=self.config.train_report_steps)
            training_hooks = [lr_hook]

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = self.valid_metrics(logits, labels)
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            training_hooks=training_hooks)
Example #11
0
class TrainerTorch(TrainerBase):
    """Trainer torch class."""
    def build(self):
        """Build the trainer by assembling the necessary components.

        An optimizer or LR scheduler that is already set is kept. When
        ``use_amp`` is enabled, the model and optimizer are wrapped by apex AMP
        (opt level ``'O1'``).
        """
        super().build()
        if self.optimizer is None:
            self.optimizer = Optimizer()(model=self.model,
                                         distributed=self.distributed)
        if hasattr(self.model, 'add_loss'):
            loss_cls = Loss()()
            self.model.add_loss(loss_cls)
            self.loss = self.model.overall_loss()
        else:
            self.loss = Loss()()
        if self.config.adaptive_muti_loss and hasattr(self.loss,
                                                      "adaptive_muti_loss"):
            self.loss.adaptive_muti_loss(save_path=self.get_local_worker_path(
                self.step_name, self.worker_id),
                                         weight=self.config.loss_weight)
        if self.lr_scheduler is None:
            self.lr_scheduler = LrScheduler()(self.optimizer)
        if self.actions_list is not None:
            self.total_optimizer = self.optimizer
            self.total_loss = self.loss
            self.total_lr_scheduler = self.lr_scheduler
        # Some trainer has different train batch size from valid batch
        self.train_metrics = self._init_metrics()
        self.valid_metrics = self._init_metrics()
        self._init_horovod_setting()
        if self.use_amp:
            from apex import amp
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O1')

    def _set_default_funcs(self):
        """Bind batch/step functions; a list of optimizer configs selects multi-step."""
        self.make_batch = self._default_make_batch
        if isinstance(self.config.optimizer, list):
            self.train_step = self._multi_train_step
        else:
            self.train_step = self._default_train_step
        self.valid_step = self._default_valid_step

    def _set_condition(self):
        """Set up distributed ranks, seed torch and configure the device."""
        self._init_distributed_setting()
        torch.manual_seed(self.config.seed)
        self._init_setting()

    def _init_setting(self):
        """Init CUDA setting.

        Selects and seeds the GPU or NPU device. CPU sets
        ``config.device = -1``.

        :raises ValueError: if no GPU, NPU or CPU device is detected.
        """
        if vega.is_gpu_device():
            import torch.cuda
            self.config.device = vega.is_gpu_device() if vega.is_gpu_device(
            ) is not True else 0
            if self.distributed:
                torch.cuda.set_device(self._local_rank_id)
            torch.cuda.manual_seed(self.config.seed)
        elif vega.is_npu_device():
            import torch.npu
            device = "npu:{}".format(os.environ.get('DEVICE_ID', 0))
            torch.npu.set_device(device)
            torch.npu.manual_seed(self.config.seed)
        elif vega.is_cpu_device():
            self.config.device = -1
            return
        else:
            raise ValueError('Set a correct device: cuda or npu.')

    def _init_distributed_setting(self):
        """Record horovod world size and rank ids when distributed."""
        if self.distributed:
            import horovod.torch as hvd
            self._world_size = hvd.size()
            self._rank_id = hvd.rank()
            self._local_rank_id = hvd.local_rank()

    def _init_horovod_setting(self):
        """Init horovod setting.

        Broadcasts model and optimizer state from rank 0 when distributed.
        Only rank 0 is the chief; non-distributed runs are always chief.
        """
        self.is_chief = True
        if self.distributed:
            import horovod.torch as hvd
            hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(self.optimizer, root_rank=0)
            self.is_chief = hvd.rank() == 0

    def _train_epoch(self):
        """Run one training epoch, capped at ``config.max_train_steps`` batches if set."""
        self.model.train()
        max_steps = self.config.max_train_steps
        for batch_index, batch in enumerate(self.train_loader):
            # ``>=`` so exactly ``max_steps`` batches run (indices 0..max-1),
            # matching the TF trainer's ``steps=max_train_steps``; ``>`` ran
            # one extra step.
            if max_steps and batch_index >= max_steps:
                return
            batch = self.make_batch(batch)
            batch_logs = {'train_batch': batch}
            self.callbacks.before_train_step(batch_index, batch_logs)
            train_batch_output = self.train_step(batch)
            batch_logs.update(train_batch_output)
            if self.config.is_detection_trainer:
                batch_logs.update({'is_detection_trainer': True})
            self.callbacks.after_train_step(batch_index, batch_logs)

    def _valid_epoch(self):
        """Run the valid loader without gradients, reporting through callbacks."""
        self.callbacks.before_valid()
        valid_logs = None

        self.model.eval()
        with torch.no_grad():
            for batch_index, batch in enumerate(self.valid_loader):
                batch = self.make_batch(batch)
                batch_logs = {'valid_batch': batch}
                self.callbacks.before_valid_step(batch_index, batch_logs)
                valid_batch_output = self.valid_step(batch)
                self.callbacks.after_valid_step(batch_index,
                                                valid_batch_output)

        self.callbacks.after_valid(valid_logs)

    def _default_make_batch(self, batch):
        """Unpack batch to get input and target; moves tensors off CPU if needed."""
        if not vega.is_cpu_device():
            batch = self._set_device(batch)
        return batch

    def _set_device(self, data):
        """Recursively move tensors in dicts/lists/tuples to the GPU or NPU."""
        if torch.is_tensor(data):
            if vega.is_gpu_device():
                return data.cuda()
            else:
                return data.npu()
        if isinstance(data, dict):
            return {k: self._set_device(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [self._set_device(v) for v in data]
        elif isinstance(data, tuple):
            return tuple(self._set_device(v) for v in data)
        return data

    def _default_train_step(self, batch):
        """Run forward, backward and one optimizer step on ``batch``.

        Dict batches are passed as keyword args, lists of dicts as-is, and
        anything else is unpacked as ``(input, target)`` (optionally mixed up).

        :return: dict with ``loss`` (float), ``train_batch_output`` and ``lr``.
        """
        self.optimizer.zero_grad()
        inputs, target = None, None
        # Mixup only applies to (input, target) batches; track whether it was
        # actually applied so dict-style batches never reach y_a/y_b (which
        # previously raised NameError when config.mixup was on).
        mixup_applied = False
        if isinstance(batch, dict):
            output = self.model(**batch)
        elif isinstance(batch, list) and isinstance(batch[0], dict):
            output = self.model(batch)
        else:
            # classification
            inputs, target = batch
            if self.config.mixup:
                mixup_ratio = np.random.beta(0.1, 0.1)
                mixed_x, y_a, y_b = self._mixup_batch(inputs, target,
                                                      mixup_ratio)
                output = self.model(mixed_x)
                mixup_applied = True
            else:
                output = self.model(inputs) if not isinstance(
                    inputs, dict) else self.model(**inputs)
        # loss
        if mixup_applied:
            loss = self._mixup_loss(self.loss, output, y_a, y_b, mixup_ratio)
        else:
            loss = self.loss(output, target)
        if self.use_amp:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
                if self.distributed:
                    # Horovod's documented apex pattern: all-reduce inside
                    # scale_loss, then step without re-synchronizing. These
                    # methods exist only on horovod's DistributedOptimizer
                    # (assumed to be what Optimizer() returns when distributed
                    # — confirm), so plain optimizers skip them.
                    self.optimizer.synchronize()
            if self.distributed:
                with self.optimizer.skip_synchronize():
                    self.optimizer.step()
            else:
                self.optimizer.step()
            # NOTE(review): config.grad_clip is not applied on the AMP path.
        else:
            loss.backward()
            if self.config.grad_clip:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.grad_clip)
            self.optimizer.step()
        return {
            'loss': loss.item(),
            'train_batch_output': output,
            'lr': self.lr_scheduler.get_lr()
        }

    def _multi_train_step(self, batch):
        """Run one train step per sub-optimizer of the multi-optimizer container.

        ``self.optimizer`` / ``self.loss`` / ``self.lr_scheduler`` are rebound
        to each sub-component in turn (and left on the last one). The container
        itself is kept in ``self._multi_optimizer``, so later batches can still
        iterate it. Previously the second batch called ``get_opts()`` on a
        plain sub-optimizer.

        :return: the output of the last sub-step (``None`` if there are none).
        """
        if hasattr(self.optimizer, 'get_opts'):
            self._multi_optimizer = self.optimizer
        multi_optimizer = self._multi_optimizer
        train_batch_output = None
        for _, sub_opt in multi_optimizer.get_opts():
            self.optimizer = sub_opt.get('opt')
            self.loss = sub_opt.get('loss')
            self.lr_scheduler = sub_opt.get('lr')
            train_batch_output = self._default_train_step(batch)
        return train_batch_output

    def _default_valid_step(self, batch):
        """Run a forward pass on ``batch``; return ``{'valid_batch_output': output}``."""
        if isinstance(batch, dict):
            output = self.model(**batch)
        elif isinstance(batch, list) and isinstance(batch[0], dict):
            output = self.model(batch)
        else:
            inputs, target = batch
            output = self.model(inputs) if not isinstance(
                inputs, dict) else self.model(**inputs)
        return {'valid_batch_output': output}

    def _mixup_batch(self, x, y, ratio):
        """Blend ``x`` with a random permutation of itself; return (mixed_x, y, y_perm)."""
        indices = torch.randperm(x.shape[0])
        mixed_x = ratio * x + (1 - ratio) * x[indices]
        y_a, y_b = y, y[indices]
        return mixed_x, y_a, y_b

    def _mixup_loss(self, loss, pred, y_a, y_b, ratio):
        """Return the ``ratio``-weighted mix of the losses against both label sets."""
        return ratio * loss(pred, y_a) + (1 - ratio) * loss(pred, y_b)