def build(self):
    """Build the trainer by assembling the necessary components."""
    super().build()
    # Only create a default optimizer when the caller has not injected one.
    if self.optimizer is None:
        self.optimizer = Optimizer()(model=self.model,
                                     distributed=self.distributed)
    # Models exposing `add_loss` compose their own overall loss from the
    # configured loss class; otherwise the configured loss is used directly.
    if hasattr(self.model, 'add_loss'):
        loss_cls = Loss()()
        self.model.add_loss(loss_cls)
        self.loss = self.model.overall_loss()
    else:
        self.loss = Loss()()
    # NOTE(review): "muti" looks like a typo for "multi", but the attribute
    # names must match the config and loss API defined elsewhere — do not
    # rename here.
    if self.config.adaptive_muti_loss and hasattr(self.loss, "adaptive_muti_loss"):
        self.loss.adaptive_muti_loss(save_path=self.get_local_worker_path(
            self.step_name, self.worker_id), weight=self.config.loss_weight)
    if self.lr_scheduler is None:
        self.lr_scheduler = LrScheduler()(self.optimizer)
    # When an actions list is configured, keep handles to the top-level
    # optimizer/loss/scheduler so they can be restored after sub-steps.
    if self.actions_list is not None:
        self.total_optimizer = self.optimizer
        self.total_loss = self.loss
        self.total_lr_scheduler = self.lr_scheduler
    # Some trainer has different train batch size from valid batch
    self.train_metrics = self._init_metrics()
    self.valid_metrics = self._init_metrics()
    self._init_horovod_setting()
    # Apex AMP must wrap the final model/optimizer pair, after horovod setup.
    if self.use_amp:
        from apex import amp
        self.model, self.optimizer = amp.initialize(self.model,
                                                    self.optimizer,
                                                    opt_level='O1')
def before_train(self, logs=None):
    """Be called before the train process."""
    self.config = self.trainer.config
    # NOTE(review): relies on vega.is_gpu_device() returning either a device
    # id or the literal True (mapped to GPU index 0) — confirm against vega's
    # device API.
    self.device = vega.is_gpu_device() if vega.is_gpu_device(
    ) is not True else 0
    self.base_net_desc = self.trainer.model.desc
    sess_config = None
    # Build a dummy (1, 3, 32, 32) input on the active backend so FLOPs,
    # parameter count and forward latency can be measured.
    if vega.is_torch_backend():
        if vega.is_npu_device():
            count_input = torch.FloatTensor(1, 3, 32, 32).npu()
        elif vega.is_gpu_device():
            count_input = torch.FloatTensor(1, 3, 32, 32).to(self.device)
    elif vega.is_tf_backend():
        count_input = tf.random.uniform([1, 3, 32, 32], dtype=tf.float32)
        sess_config = self.trainer._init_session_config()
    elif vega.is_ms_backend():
        count_input = mindspore.Tensor(
            np.random.randn(1, 3, 32, 32).astype(np.float32))
    # NOTE(review): count_input is left unbound when torch runs on CPU (or no
    # backend matches), which would raise NameError below — confirm callers
    # only reach this on GPU/NPU/tf/ms.
    self.flops_count, self.params_count = calc_model_flops_params(
        self.trainer.model, count_input)
    self.latency_count = calc_forward_latency(self.trainer.model, count_input,
                                              sess_config)
    logging.info("after prune model glops=%sM, params=%sK, latency=%sms",
                 self.flops_count * 1e-6, self.params_count * 1e-3,
                 self.latency_count * 1000)
    # Replace the trainer's model with the freshly initialized pruned model,
    # then rebuild the torch optimizer/scheduler around the new parameters.
    self.trainer.model = self._generate_init_model()
    if vega.is_torch_backend():
        self.trainer.optimizer = Optimizer()(
            model=self.trainer.model, distributed=self.trainer.distributed)
        self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
def build(self):
    """Build the trainer by assembling the necessary components."""
    super().build()
    # When the lr_scheduler has explicit params, precompute a dynamic lr
    # schedule (MindSpore style) and hand it to the optimizer; otherwise the
    # optimizer uses its own configured learning rate.
    if self.config.lr_scheduler.params:
        self.lr_scheduler = LrScheduler()
        dynamic_lr = self.lr_scheduler()(
            base_lr=self.config.optimizer.params["lr"],
            global_step=self.config.epochs * len(self.train_loader),
            total_epoch=self.config.epochs)
        self.optimizer = Optimizer()(model=self.model, dynamic_lr=dynamic_lr)
    else:
        self.optimizer = Optimizer()(model=self.model)
    # Models exposing `add_loss` compose their own overall loss from the
    # configured loss class; otherwise the configured loss is used directly.
    if hasattr(self.model, 'add_loss'):
        loss_cls = Loss()()
        self.model.add_loss(loss_cls)
        self.loss = self.model.overall_loss()
    else:
        self.loss = Loss()()
    self.metric_name = self.config.metric.type
    # Some trainer has different train batch size from valid batch
    self.train_metrics = None
    self.valid_metrics = self._init_metrics()
    # MindSpore's Model expects a dict of metrics; wrap a single metric
    # object under its configured name.
    # NOTE(review): self.valid_metrics() is invoked twice here — assumed to
    # be side-effect free; verify against the metrics implementation.
    self.ms_metrics = self.valid_metrics() if isinstance(
        self.valid_metrics(), dict) else {
        self.metric_name: self.valid_metrics()
    }
    self.ms_model = MsModel(network=self.model,
                            loss_fn=self.loss,
                            optimizer=self.optimizer,
                            metrics=self.ms_metrics)
def model_fn(self, features, labels, mode):
    """Define cars model_fn used by TensorFlow Estimator."""
    logging.info('Cars model function action')
    self.trainer.loss = Loss()()
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_global_step()
        # Fractional epoch derived from the global step, used to drive the
        # learning-rate schedule inside the graph.
        epoch = tf.cast(global_step, tf.float32) / tf.cast(
            len(self.trainer.train_loader), tf.float32)
        self.trainer.optimizer = Optimizer()(
            distributed=self.trainer.distributed)
        self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
        self.trainer.lr_scheduler.step(epoch)
        self.trainer.model.training = True
        alphas = tf.convert_to_tensor(self.alphas)
        # Accumulate gradients over several sampled individuals, then apply
        # a single clipped gradient update.
        for j in range(self.alg_policy.num_individual_per_iter):
            i = np.random.randint(0, self.alg_policy.num_individual, 1)[0]
            if self.epoch < self.alg_policy.warmup:
                # During warmup, train a randomly sampled path instead of the
                # evolved architectures.
                alpha = tf.convert_to_tensor(
                    self.search_alg.random_sample_path())
            else:
                alpha = alphas[i]
            logits = self.trainer.model(features, alpha=alpha)
            logits = tf.cast(logits, tf.float32)
            loss = self.trainer.loss(logits=logits, labels=labels)
            loss = self.trainer.optimizer.regularize_loss(loss)
            # NOTE(review): `vars` shadows the builtin; kept as-is to avoid a
            # behavior-affecting rename in a doc-only pass.
            grads, vars = zip(
                *self.trainer.optimizer.compute_gradients(loss))
            if j == 0:
                # Lazily create the accumulators once the gradient shapes
                # are known.
                accum_grads = [
                    tf.Variable(tf.zeros_like(grad), trainable=False)
                    for grad in grads
                ]
            accum_grads = [
                accum_grads[k] + grads[k] for k in range(len(grads))
            ]
            if self.epoch < self.alg_policy.warmup:
                # One sampled path per step is enough during warmup.
                break
        clipped_grads, _ = tf.clip_by_global_norm(
            accum_grads, self.trainer.config.grad_clip)
        minimize_op = self.trainer.optimizer.apply_gradients(
            list(zip(clipped_grads, vars)), global_step)
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        train_op = tf.group(minimize_op, update_ops)
    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.EVAL:
        alpha = tf.convert_to_tensor(self.trainer.valid_alpha)
        self.trainer.model.training = False
        logits = self.trainer.model(features, alpha=alpha)
        logits = tf.cast(logits, tf.float32)
        loss = self.trainer.loss(logits=logits, labels=labels)
        eval_metric_ops = self.trainer.valid_metrics(logits, labels)
    # NOTE(review): `loss` is unbound if mode is PREDICT — confirm this
    # model_fn is only ever invoked for TRAIN/EVAL.
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops)
def step(self, epoch=None):
    """Step forward for current scheduler."""
    if self.warmup_finished:
        # Warmup is over: delegate directly to the wrapped scheduler.
        self.after_scheduler.step(epoch)
        return
    # Still warming up: push the warmup learning rates into every param group.
    self.current_iters = epoch
    for group, new_lr in zip(self.optimizer.param_groups, self.get_lr()):
        group['lr'] = new_lr
    if epoch < self.warmup_iters:
        return
    # Warmup just completed: build the follow-up scheduler lazily and adopt
    # its stepping granularity.
    self.warmup_finished = True
    self.after_scheduler = LrScheduler(self.after_scheduler_config)(self.optimizer)
    self.by_epoch = self.after_scheduler.by_epoch
def model_fn(self, features, labels, mode):
    """Darts model_fn used by TensorFlow Estimator."""
    logging.info('Darts model function action')
    global_step = tf.compat.v1.train.get_global_step()
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # The darts input pipeline delivers paired train/valid batches; the
        # valid split drives the architecture-parameter update below.
        features, valid_features = features['train'], features['valid']
        labels, valid_labels = labels['train'], labels['valid']
        # update arch
        epoch = tf.cast(global_step, tf.float32) / tf.cast(
            len(self.trainer.train_loader), tf.float32)
        self.trainer.optimizer = Optimizer()(
            distributed=self.trainer.distributed)
        self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
        self.trainer.lr_scheduler.step(epoch)
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        arch_minimize_op = self.search_alg.step(
            valid_x=valid_features,
            valid_y=valid_labels,
            lr=self.trainer.lr_scheduler.get_lr()[0])
        train_op = tf.group(arch_minimize_op, update_ops)
    self.model.training = mode == tf.estimator.ModeKeys.TRAIN
    logits = self.model(features)
    logits = tf.cast(logits, tf.float32)
    self.trainer.loss = Loss()()
    loss = self.trainer.loss(logits=logits, labels=labels)
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Chain the weight update after the architecture update so both run
        # in a single training step, in that order.
        with tf.control_dependencies([train_op]):
            weight_ops = self.model.get_weight_ops()
            loss_scale = self.trainer.config.loss_scale if self.trainer.use_amp else 1
            train_op = self.trainer.optimizer.step(loss, loss_scale,
                                                   global_step, weight_ops)
    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metric_ops = self.trainer.valid_metrics(logits, labels)
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops)
def __call__(self, model=None, distributed=False):
    """Call Optimizer class."""
    # Build one optimizer/scheduler/loss bundle per configured sub-model and
    # register it under the sub-model's name.
    for sub_conf in self.config:
        name = sub_conf.get('model')
        sub_model = getattr(model, sub_conf.get('model'))
        sub_opt = Optimizer(sub_conf)(sub_model, distributed)
        lr_conf = sub_conf.get('lr_scheduler')
        sub_lr = LrScheduler(config=lr_conf)(sub_opt) if lr_conf else None
        loss_conf = sub_conf.get('loss')
        sub_loss = ClassFactory.get_instance(ClassType.LOSS,
                                             loss_conf) if loss_conf else None
        self._opts[name] = dict(opt=sub_opt, lr=sub_lr, loss=sub_loss,
                                model=sub_model)
    return self
def before_train(self, logs=None):
    """Be called before the train process."""
    self.config = self.trainer.config
    model_code = copy.deepcopy(self.trainer.model.desc)
    model = self.trainer.model
    logging.info('current code: %s, %s', model_code.nbit_w_list,
                 model_code.nbit_a_list)
    # Quantize weights/activations to the bit widths encoded in the desc and
    # install the quantized model on the trainer.
    quantizer = Quantizer(model, model_code.nbit_w_list,
                          model_code.nbit_a_list)
    model = quantizer()
    self.trainer.model = model
    # Dummy-input shape for FLOPs/latency measurement, matching data format.
    count_input = [1, 3, 32, 32]
    if General.data_format == 'channels_last':
        count_input = [1, 32, 32, 3]
    sess_config = None
    if vega.is_torch_backend():
        if vega.is_gpu_device():
            model = model.cuda()
            count_input = torch.FloatTensor(*count_input).cuda()
        elif vega.is_npu_device():
            model = model.npu()
            count_input = torch.FloatTensor(*count_input).npu()
        # Rebuild optimizer/scheduler around the quantized parameters.
        self.trainer.optimizer = Optimizer()(
            model=self.trainer.model, distributed=self.trainer.distributed)
        self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
    elif vega.is_tf_backend():
        tf.compat.v1.reset_default_graph()
        count_input = tf.random.uniform(count_input, dtype=tf.float32)
        sess_config = self.trainer._init_session_config()
    # NOTE(review): on torch-CPU (or ms backend) count_input stays a plain
    # list — confirm calc_model_flops_params accepts that here.
    self.flops_count, self.params_count = calc_model_flops_params(
        model, count_input, custom_hooks=quantizer.custom_hooks())
    self.latency_count = calc_forward_latency(model, count_input, sess_config)
    logging.info("after quant model glops=%sM, params=%sK, latency=%sms",
                 self.flops_count * 1e-6, self.params_count * 1e-3,
                 self.latency_count * 1000)
    self.validate()
class TrainerTf(TrainerBase):
    """Trainer tensorflow class."""

    def build(self):
        """Build the trainer by assembling the necessary components."""
        super().build()
        # Some trainer has different train batch size from valid batch
        self.train_metrics = None
        self.valid_metrics = self._init_metrics()
        self._init_horovod_setting()

    def _set_default_funcs(self):
        # Wire the estimator callbacks to this trainer's defaults.
        self.model_fn = self._default_model_fn
        self.train_input_fn = self._default_train_input_fn
        self.valid_input_fn = self._default_valid_input_fn

    def _set_condition(self):
        # Session must exist before distributed setup and estimator creation.
        self._init_tf_session()
        self._init_distributed_setting()
        self._init_tf_estimator()

    def _train_epoch(self):
        # `train_in_once` trains all epochs with a single estimator call.
        if self.config.train_in_once:
            max_steps = self.config.max_train_steps or len(self.train_loader) * self.epochs
            self.estimator.train(input_fn=self.train_input_fn,
                                 max_steps=max_steps,
                                 hooks=self._init_logging_hook())
        else:
            self.estimator.train(input_fn=self.train_input_fn,
                                 steps=self.config.max_train_steps or len(self.train_loader),
                                 hooks=self._init_logging_hook())

    def _valid_epoch(self):
        """Run one evaluation pass and report metrics to the callbacks."""
        self.callbacks.before_valid()
        valid_logs = None
        eval_metrics = self.estimator.evaluate(input_fn=self.valid_input_fn,
                                               steps=len(self.valid_loader))
        self.valid_metrics.update(eval_metrics)
        valid_logs = dict()
        valid_logs['cur_valid_perfs'] = self.valid_metrics.results
        self.callbacks.after_valid(valid_logs)

    def _init_distributed_setting(self):
        """Initialize rank info for horovod (GPU) or HCCL (NPU)."""
        if not self.distributed:
            return
        if vega.is_npu_device():
            sess_config = self._init_session_config()
            self.sess = tf.compat.v1.Session(config=sess_config)
            from npu_bridge.estimator import npu_ops
            self.npu_init = npu_ops.initialize_system()
            self.npu_shutdown = npu_ops.shutdown_system()
            self.sess.run(self.npu_init)
        if vega.is_gpu_device():
            import horovod.tensorflow as hvd
            self._world_size = hvd.size()
            self._rank_id = hvd.rank()
            self._local_rank_id = hvd.local_rank()
        elif vega.is_npu_device():
            from hccl.manage.api import get_local_rank_id
            from hccl.manage.api import get_rank_size
            from hccl.manage.api import get_rank_id
            self._world_size = get_rank_size()
            self._rank_id = get_rank_id()
            self._local_rank_id = get_local_rank_id()

    def _default_train_input_fn(self):
        return self.train_loader.input_fn()

    def _default_valid_input_fn(self):
        return self.valid_loader.input_fn()

    def _default_model_fn(self, features, labels, mode):
        """Define model_fn used by TensorFlow Estimator.

        :params features: input features
        :type features: tensorflow tensors
        :params labels: label data
        :type labels: tensorflow tensors
        :params mode: mode of estimator
        :type mode: tf.estimator.ModeKeys
        :return: tensorflow EstimatorSpec
        :rtype: tf.estimator.EstimatorSpec
        """
        logging.info('model function action')
        self.model.training = mode == tf.estimator.ModeKeys.TRAIN
        # Forward pass, optionally with mixup augmentation during training.
        if self.config.mixup and mode == tf.estimator.ModeKeys.TRAIN:
            mixup_ratio = tf.compat.v1.distributions.Beta(0.1, 0.1).sample()
            mixed_x, y_a, y_b = self._mixup_batch(features, labels, mixup_ratio)
            logits = self.model(mixed_x)
        else:
            logits = self.model(features)
        logits = tf.cast(logits, tf.float32)
        # Models exposing `add_loss` compose the overall loss themselves.
        if hasattr(self.model, 'add_loss'):
            loss_cls = Loss()()
            self.model.add_loss(loss_cls)
            self.loss = self.model.overall_loss()
        else:
            self.loss = Loss()()
        # loss
        if self.config.mixup and mode == tf.estimator.ModeKeys.TRAIN:
            loss = self._mixup_loss(self.loss, logits, y_a, y_b, mixup_ratio)
        else:
            loss = self.loss(logits, labels)
        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.compat.v1.train.get_or_create_global_step()
            epoch = tf.cast(global_step, tf.float32) / tf.cast(len(self.train_loader), tf.float32)
            self.optimizer = Optimizer()(distributed=self.distributed)
            self.lr_scheduler = LrScheduler()(optimizer=self.optimizer)
            self.lr_scheduler.step(epoch)
            if self.distributed:
                self.optimizer = Optimizer.set_distributed(self.optimizer)
            update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
            loss_scale = self.config.loss_scale if self.use_amp else 1
            minimize_op = self.optimizer.step(loss, loss_scale, global_step)
            train_op = tf.group(minimize_op, update_ops)
        # NOTE(review): this hook reads self.lr_scheduler even outside TRAIN
        # mode, where it may never have been created — an evaluation-only run
        # would fail here; confirm and consider building it inside the TRAIN
        # branch.
        logging_hook = list()
        logging_hook.append(tf.train.LoggingTensorHook(
            tensors={"learning rate": self.lr_scheduler.get_lr()[0]},
            every_n_iter=self.config.train_report_steps))
        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = self.valid_metrics(logits, labels)
        if mode == tf.estimator.ModeKeys.TRAIN:
            return tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, train_op=train_op,
                eval_metric_ops=eval_metric_ops, training_hooks=logging_hook)
        else:
            return tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, train_op=train_op,
                eval_metric_ops=eval_metric_ops)

    def _mixup_batch(self, x, y, ratio):
        # Mix each sample with a randomly chosen partner from the same batch.
        batch_size = tf.shape(x)[0]
        indices = tf.random.shuffle(tf.range(batch_size, dtype=tf.int32))
        mixed_x = ratio * x + (1 - ratio) * tf.gather(x, indices)
        y_a, y_b = y, tf.gather(y, indices)
        return mixed_x, y_a, y_b

    def _mixup_loss(self, loss, pred, y_a, y_b, ratio):
        # Convex combination of the losses against both mixup targets.
        return ratio * loss(pred, y_a) + (1 - ratio) * loss(pred, y_b)

    def _init_tf_estimator(self):
        """Init tensorflow estimator."""
        sess_config = self._init_session_config()
        if vega.is_gpu_device():
            self._init_gpu_estimator(sess_config)
        elif vega.is_npu_device():
            self._init_npu_estimator(sess_config)

    def _init_tf_session(self):
        # Own a dedicated graph so the session is isolated from other graphs.
        sess_config = self._init_session_config()
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.sess = tf.compat.v1.Session(config=sess_config)

    def _init_session_config(self):
        sess_config = self._init_gpu_session_config() if vega.is_gpu_device() else \
            self._init_npu_session_config()
        return sess_config

    def _init_logging_hook(self):
        logging_hook = []
        # Under horovod, broadcast initial variables from rank 0 to all ranks.
        if vega.is_gpu_device() and self.distributed:
            import horovod.tensorflow as hvd
            logging_hook += [hvd.BroadcastGlobalVariablesHook(0)]
        return logging_hook

    def _init_gpu_estimator(self, sess_config):
        """Init tensorflow estimator."""
        distribution = None
        # Single-process multi-GPU training uses MirroredStrategy instead of
        # horovod; in that case the session config is owned by the strategy.
        if not self.distributed and General._parallel and General.devices_per_trainer > 1:
            distribution = tf.contrib.distribute.MirroredStrategy()
        config = tf.estimator.RunConfig(model_dir=self.get_local_worker_path(),
                                        save_checkpoints_steps=self.config.save_steps,
                                        log_step_count_steps=self.config.train_report_steps,
                                        session_config=None if distribution else sess_config,
                                        train_distribute=distribution)
        self.estimator = tf.estimator.Estimator(model_fn=self.model_fn,
                                                config=config)

    def _init_npu_estimator(self, sess_config):
        from npu_bridge.estimator.npu.npu_config import NPURunConfig
        from npu_bridge.estimator.npu.npu_estimator import NPUEstimator
        model_dir = self.get_local_worker_path()
        config = NPURunConfig(model_dir=model_dir,
                              save_checkpoints_steps=self.config.save_steps,
                              log_step_count_steps=self.config.train_report_steps,
                              session_config=sess_config,
                              enable_data_pre_proc=True,
                              iterations_per_loop=1)
        self.estimator = NPUEstimator(model_fn=self.model_fn,
                                      config=config)

    def _init_gpu_session_config(self):
        sess_config = tf.compat.v1.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        # Pin each horovod worker to its local GPU.
        if self.distributed:
            import horovod.tensorflow as hvd
            sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
        return sess_config

    def _init_npu_session_config(self):
        from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
        sess_config = tf.ConfigProto()
        # Disable graph remapping and register the NPU custom optimizer.
        sess_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
        custom_op = sess_config.graph_options.rewrite_options.custom_optimizers.add()
        custom_op.name = "NpuOptimizer"
        if self.use_amp:
            custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
        custom_op.parameter_map["use_off_line"].b = True
        return sess_config
def _default_model_fn(self, features, labels, mode):
    """Define model_fn used by TensorFlow Estimator.

    :params features: input features
    :type features: tensorflow tensors
    :params labels: label data
    :type labels: tensorflow tensors
    :params mode: mode of estimator
    :type mode: tf.estimator.ModeKeys
    :return: tensorflow EstimatorSpec
    :rtype: tf.estimator.EstimatorSpec
    """
    logging.info('model function action')
    self.model.training = mode == tf.estimator.ModeKeys.TRAIN
    # Forward pass, optionally with mixup augmentation during training.
    if self.config.mixup and mode == tf.estimator.ModeKeys.TRAIN:
        mixup_ratio = tf.compat.v1.distributions.Beta(0.1, 0.1).sample()
        mixed_x, y_a, y_b = self._mixup_batch(features, labels, mixup_ratio)
        logits = self.model(mixed_x)
    else:
        logits = self.model(features)
    logits = tf.cast(logits, tf.float32)
    # Models exposing `add_loss` compose the overall loss themselves.
    if hasattr(self.model, 'add_loss'):
        loss_cls = Loss()()
        self.model.add_loss(loss_cls)
        self.loss = self.model.overall_loss()
    else:
        self.loss = Loss()()
    # loss
    if self.config.mixup and mode == tf.estimator.ModeKeys.TRAIN:
        loss = self._mixup_loss(self.loss, logits, y_a, y_b, mixup_ratio)
    else:
        loss = self.loss(logits, labels)
    train_op = None
    logging_hook = list()
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_or_create_global_step()
        # Fractional epoch derived from the global step drives the schedule.
        epoch = tf.cast(global_step, tf.float32) / tf.cast(len(self.train_loader), tf.float32)
        self.optimizer = Optimizer()(distributed=self.distributed)
        self.lr_scheduler = LrScheduler()(optimizer=self.optimizer)
        self.lr_scheduler.step(epoch)
        if self.distributed:
            self.optimizer = Optimizer.set_distributed(self.optimizer)
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        loss_scale = self.config.loss_scale if self.use_amp else 1
        minimize_op = self.optimizer.step(loss, loss_scale, global_step)
        train_op = tf.group(minimize_op, update_ops)
        # Fix: build the learning-rate logging hook only in TRAIN mode.  It
        # is attached solely via training_hooks below, and self.lr_scheduler
        # is only created in this branch — constructing the hook
        # unconditionally crashed evaluation-only runs where no scheduler
        # had ever been set.
        logging_hook.append(tf.train.LoggingTensorHook(
            tensors={"learning rate": self.lr_scheduler.get_lr()[0]},
            every_n_iter=self.config.train_report_steps))
    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metric_ops = self.valid_metrics(logits, labels)
    if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            eval_metric_ops=eval_metric_ops, training_hooks=logging_hook)
    else:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            eval_metric_ops=eval_metric_ops)
class TrainerTorch(TrainerBase):
    """Trainer torch class."""

    def build(self):
        """Build the trainer by assembling the necessary components."""
        super().build()
        # Only create a default optimizer when none was injected by a caller.
        if self.optimizer is None:
            self.optimizer = Optimizer()(model=self.model,
                                         distributed=self.distributed)
        # Models exposing `add_loss` compose their own overall loss.
        if hasattr(self.model, 'add_loss'):
            loss_cls = Loss()()
            self.model.add_loss(loss_cls)
            self.loss = self.model.overall_loss()
        else:
            self.loss = Loss()()
        # NOTE(review): "muti" looks like a typo for "multi", but the names
        # must match the config/loss API defined elsewhere — do not rename.
        if self.config.adaptive_muti_loss and hasattr(self.loss, "adaptive_muti_loss"):
            self.loss.adaptive_muti_loss(save_path=self.get_local_worker_path(
                self.step_name, self.worker_id), weight=self.config.loss_weight)
        if self.lr_scheduler is None:
            self.lr_scheduler = LrScheduler()(self.optimizer)
        # Keep handles to the top-level components when an actions list is
        # configured so they can be restored after per-action sub-steps.
        if self.actions_list is not None:
            self.total_optimizer = self.optimizer
            self.total_loss = self.loss
            self.total_lr_scheduler = self.lr_scheduler
        # Some trainer has different train batch size from valid batch
        self.train_metrics = self._init_metrics()
        self.valid_metrics = self._init_metrics()
        self._init_horovod_setting()
        # Apex AMP must wrap the final model/optimizer pair.
        if self.use_amp:
            from apex import amp
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O1')

    def _set_default_funcs(self):
        self.make_batch = self._default_make_batch
        # A list of optimizer configs means multiple optimizers per step.
        if isinstance(self.config.optimizer, list):
            self.train_step = self._multi_train_step
        else:
            self.train_step = self._default_train_step
        self.valid_step = self._default_valid_step

    def _set_condition(self):
        self._init_distributed_setting()
        torch.manual_seed(self.config.seed)
        self._init_setting()

    def _init_setting(self):
        """Init CUDA setting."""
        if vega.is_gpu_device():
            import torch.cuda
            # NOTE(review): relies on is_gpu_device() returning either a
            # device id or the literal True (mapped to index 0).
            self.config.device = vega.is_gpu_device() if vega.is_gpu_device(
            ) is not True else 0
            if self.distributed:
                torch.cuda.set_device(self._local_rank_id)
            torch.cuda.manual_seed(self.config.seed)
        elif vega.is_npu_device():
            import torch.npu
            device = "npu:{}".format(os.environ.get('DEVICE_ID', 0))
            torch.npu.set_device(device)
            torch.npu.manual_seed(self.config.seed)
        elif vega.is_cpu_device():
            self.config.device = -1
            return
        else:
            raise ValueError('Set a correct device: cuda or npu.')

    def _init_distributed_setting(self):
        # Pull rank information from horovod when running distributed.
        if self.distributed:
            import horovod.torch as hvd
            self._world_size = hvd.size()
            self._rank_id = hvd.rank()
            self._local_rank_id = hvd.local_rank()

    def _init_horovod_setting(self):
        """Init horovod setting."""
        self.is_chief = True
        if self.distributed:
            import horovod.torch as hvd
            # Broadcast initial state from rank 0 so all workers start equal.
            hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(self.optimizer, root_rank=0)
            if hvd.rank() != 0:
                self.is_chief = False
            else:
                self.is_chief = True

    def _train_epoch(self):
        self.model.train()
        for batch_index, batch in enumerate(self.train_loader):
            # Honor the optional hard cap on training steps.
            if self.config.max_train_steps and batch_index > self.config.max_train_steps:
                return
            batch = self.make_batch(batch)
            batch_logs = {'train_batch': batch}
            self.callbacks.before_train_step(batch_index, batch_logs)
            train_batch_output = self.train_step(batch)
            batch_logs.update(train_batch_output)
            if self.config.is_detection_trainer:
                batch_logs.update({'is_detection_trainer': True})
            self.callbacks.after_train_step(batch_index, batch_logs)

    def _valid_epoch(self):
        self.callbacks.before_valid()
        # NOTE(review): valid_logs is always None here; per-step results are
        # reported via after_valid_step instead.
        valid_logs = None
        self.model.eval()
        with torch.no_grad():
            for batch_index, batch in enumerate(self.valid_loader):
                batch = self.make_batch(batch)
                batch_logs = {'valid_batch': batch}
                self.callbacks.before_valid_step(batch_index, batch_logs)
                valid_batch_output = self.valid_step(batch)
                self.callbacks.after_valid_step(batch_index, valid_batch_output)
        self.callbacks.after_valid(valid_logs)

    def _default_make_batch(self, batch):
        """Unpack batch to get input and target."""
        if not vega.is_cpu_device():
            batch = self._set_device(batch)
        return batch

    def _set_device(self, data):
        # Recursively move tensors inside nested dict/list/tuple structures
        # to the active accelerator (GPU or NPU).
        if torch.is_tensor(data):
            if vega.is_gpu_device():
                return data.cuda()
            else:
                return data.npu()
        if isinstance(data, dict):
            return {k: self._set_device(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [self._set_device(v) for v in data]
        elif isinstance(data, tuple):
            return tuple([self._set_device(v) for v in data])
        return data

    def _default_train_step(self, batch):
        """Run one optimization step; returns loss, outputs and current lr."""
        self.optimizer.zero_grad()
        input, target = None, None
        # Dispatch the forward pass on the batch structure: kwargs dict,
        # list of dicts (detection-style), or (input, target) pair.
        if isinstance(batch, dict):
            output = self.model(**batch)
        elif isinstance(batch, list) and isinstance(batch[0], dict):
            output = self.model(batch)
        else:
            # classification
            input, target = batch
            if self.config.mixup:
                mixup_ratio = np.random.beta(0.1, 0.1)
                mixed_x, y_a, y_b = self._mixup_batch(input, target,
                                                      mixup_ratio)
                output = self.model(mixed_x)
            else:
                output = self.model(input) if not isinstance(
                    input, dict) else self.model(**input)
        # loss
        # NOTE(review): with mixup enabled and a dict/list batch, y_a/y_b are
        # unbound here — mixup appears to assume (input, target) batches.
        if self.config.mixup:
            loss = self._mixup_loss(self.loss, output, y_a, y_b, mixup_ratio)
        else:
            loss = self.loss(output, target)
        if self.use_amp:
            # Apex + horovod pattern: scale, backward, synchronize inside the
            # scale_loss context, then step without re-synchronizing.
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
                self.optimizer.synchronize()
            with self.optimizer.skip_synchronize():
                self.optimizer.step()
        else:
            loss.backward()
            if self.config.grad_clip:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.grad_clip)
            self.optimizer.step()
        return {
            'loss': loss.item(),
            'train_batch_output': output,
            'lr': self.lr_scheduler.get_lr()
        }

    def _multi_train_step(self, batch):
        """Run one train step per configured sub-optimizer."""
        train_batch_output = None
        # NOTE(review): this rebinds self.optimizer/self.loss/self.lr_scheduler
        # to each sub-component in turn and leaves the last one installed —
        # confirm that is the intended behavior.
        for opt_name, sub_opt in self.optimizer.get_opts():
            self.optimizer = sub_opt.get('opt')
            self.loss = sub_opt.get('loss')
            self.lr_scheduler = sub_opt.get('lr')
            train_batch_output = self._default_train_step(batch)
        return train_batch_output

    def _default_valid_step(self, batch):
        # Same batch-structure dispatch as the train step, without a loss.
        if isinstance(batch, dict):
            output = self.model(**batch)
        elif isinstance(batch, list) and isinstance(batch[0], dict):
            output = self.model(batch)
        else:
            input, target = batch
            output = self.model(input) if not isinstance(
                input, dict) else self.model(**input)
        return {'valid_batch_output': output}

    def _mixup_batch(self, x, y, ratio):
        # Mix each sample with a randomly chosen partner from the same batch.
        indices = torch.randperm(x.shape[0])
        mixed_x = ratio * x + (1 - ratio) * x[indices]
        y_a, y_b = y, y[indices]
        return mixed_x, y_a, y_b

    def _mixup_loss(self, loss, pred, y_a, y_b, ratio):
        # Convex combination of the losses against both mixup targets.
        return ratio * loss(pred, y_a) + (1 - ratio) * loss(pred, y_b)