def _default_model_fn(self, features, labels, mode):
    """Build the tf.estimator.EstimatorSpec for train/eval.

    :params features: input features
    :type features: tensorflow tensors
    :params labels: label data
    :type labels: tensorflow tensors
    :params mode: mode of estimator
    :type mode: tf.estimator.ModeKeys
    :return: tensorflow EstimatorSpec
    :rtype: tf.estimator.EstimatorSpec
    """
    logging.info('model function action')
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    self.model.training = is_training

    # Forward pass; pretrained() returns weight-assign ops that must run
    # before anything consuming the logits, hence the control dependency
    # on the downstream cast.
    logits = self.model(features)
    assign_ops = self.model.pretrained()
    with tf.control_dependencies(assign_ops):
        logits = tf.cast(logits, tf.float32)

    # Models exposing add_loss() aggregate extra losses themselves;
    # otherwise fall back to the plain configured loss.
    if not hasattr(self.model, 'add_loss'):
        self.loss = Loss()()
    else:
        loss_cls = Loss()()
        self.model.add_loss(loss_cls)
        self.loss = self.model.overall_loss()
    loss = self.loss(logits, labels)

    train_op = None
    if is_training:
        global_step = tf.compat.v1.train.get_or_create_global_step()
        # Fractional epoch derived from step count and steps-per-epoch.
        steps_per_epoch = tf.cast(len(self.train_loader), tf.float32)
        epoch = tf.cast(global_step, tf.float32) / steps_per_epoch
        self.optimizer = Optimizer()(distributed=self.distributed)
        self.lr_scheduler = LrScheduler()(optimizer=self.optimizer)
        self.lr_scheduler.step(epoch)
        if self.distributed:
            self.optimizer = Optimizer.set_distributed(self.optimizer)
        # Group batch-norm (and other) update ops with the minimize op so
        # they run on every training step.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        loss_scale = self.config.loss_scale if self.use_amp else 1
        minimize_op = self.optimizer.step(loss, loss_scale, global_step)
        train_op = tf.group(minimize_op, update_ops)

    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metric_ops = self.valid_metrics(logits, labels)

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op,
        eval_metric_ops=eval_metric_ops)
def _default_model_fn(self, features, labels, mode):
    """Define model_fn used by TensorFlow Estimator.

    Builds the forward graph (optionally with mixup augmentation during
    training), the loss, the training op with LR scheduling, and the
    evaluation metrics, and packs them into an EstimatorSpec.

    :params features: input features
    :type features: tensorflow tensors
    :params labels: label data
    :type labels: tensorflow tensors
    :params mode: mode of estimator
    :type mode: tf.estimator.ModeKeys
    :return: tensorflow EstimatorSpec
    :rtype: tf.estimator.EstimatorSpec
    """
    logging.info('model function action')
    self.model.training = mode == tf.estimator.ModeKeys.TRAIN
    if self.config.mixup and mode == tf.estimator.ModeKeys.TRAIN:
        # Sample the mixup interpolation ratio from Beta(0.1, 0.1) and run
        # the forward pass on the mixed batch.
        mixup_ratio = tf.compat.v1.distributions.Beta(0.1, 0.1).sample()
        mixed_x, y_a, y_b = self._mixup_batch(features, labels, mixup_ratio)
        logits = self.model(mixed_x)
    else:
        logits = self.model(features)
    logits = tf.cast(logits, tf.float32)
    if hasattr(self.model, 'add_loss'):
        # Model aggregates its own auxiliary losses into an overall loss.
        loss_cls = Loss()()
        self.model.add_loss(loss_cls)
        self.loss = self.model.overall_loss()
    else:
        self.loss = Loss()()
    # loss: mixup training interpolates the loss between both label sets.
    if self.config.mixup and mode == tf.estimator.ModeKeys.TRAIN:
        loss = self._mixup_loss(self.loss, logits, y_a, y_b, mixup_ratio)
    else:
        loss = self.loss(logits, labels)
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_or_create_global_step()
        # Fractional epoch = step / steps-per-epoch, fed to the scheduler.
        epoch = tf.cast(global_step, tf.float32) / tf.cast(
            len(self.train_loader), tf.float32)
        self.optimizer = Optimizer()(distributed=self.distributed)
        self.lr_scheduler = LrScheduler()(optimizer=self.optimizer)
        self.lr_scheduler.step(epoch)
        if self.distributed:
            self.optimizer = Optimizer.set_distributed(self.optimizer)
        # Run UPDATE_OPS (e.g. batch-norm moving averages) with every step.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        loss_scale = self.config.loss_scale if self.use_amp else 1
        minimize_op = self.optimizer.step(loss, loss_scale, global_step)
        train_op = tf.group(minimize_op, update_ops)
        logging_hook = list()
        # Fixed: tf.train.LoggingTensorHook was removed in TF2; use the
        # compat.v1 alias, consistent with the rest of this function.
        logging_hook.append(
            tf.compat.v1.train.LoggingTensorHook(
                tensors={"learning rate": self.lr_scheduler.get_lr()[0]},
                every_n_iter=10))
    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metric_ops = self.valid_metrics(logits, labels)
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Only the training spec carries the LR logging hook.
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            training_hooks=logging_hook)
    else:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            eval_metric_ops=eval_metric_ops)