Example #1
0
    def _default_model_fn(self, features, labels, mode):
        """Build the EstimatorSpec handed to the TensorFlow Estimator.

        The model is applied to ``features``; the loss, the train op (TRAIN
        mode only) and the evaluation metrics (EVAL mode only) are then
        attached to the returned spec.

        :param features: input features
        :type features: tensorflow tensors
        :param labels: label data
        :type labels: tensorflow tensors
        :param mode: mode of estimator
        :type mode: tf.estimator.ModeKeys
        :return: tensorflow EstimatorSpec
        :rtype: tf.estimator.EstimatorSpec
        """
        logging.info('model function action')
        mode_keys = tf.estimator.ModeKeys
        self.model.training = mode == mode_keys.TRAIN
        logits = self.model(features)
        pretrained_assigns = self.model.pretrained()
        # Ops created inside this scope get a control dependency on whatever
        # ``self.model.pretrained()`` returned.
        with tf.control_dependencies(pretrained_assigns):
            logits = tf.cast(logits, tf.float32)
            if not hasattr(self.model, 'add_loss'):
                self.loss = Loss()()
            else:
                # The model aggregates losses itself: register the configured
                # loss and use the model's combined loss callable instead.
                self.model.add_loss(Loss()())
                self.loss = self.model.overall_loss()
            loss = self.loss(logits, labels)
            train_op = None
            if mode == mode_keys.TRAIN:
                step = tf.compat.v1.train.get_or_create_global_step()
                step_as_float = tf.cast(step, tf.float32)
                loader_len = tf.cast(len(self.train_loader), tf.float32)
                # Global step divided by the train loader length; presumably
                # a fractional epoch index -- confirm against LrScheduler.
                fractional_epoch = step_as_float / loader_len
                self.optimizer = Optimizer()(distributed=self.distributed)
                self.lr_scheduler = LrScheduler()(optimizer=self.optimizer)
                self.lr_scheduler.step(fractional_epoch)
                if self.distributed:
                    self.optimizer = Optimizer.set_distributed(self.optimizer)

                pending_updates = tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.UPDATE_OPS)
                if self.use_amp:
                    scale = self.config.loss_scale
                else:
                    scale = 1
                # Run the ops collected under UPDATE_OPS together with the
                # optimizer step.
                train_op = tf.group(
                    self.optimizer.step(loss, scale, step), pending_updates)

        metric_ops = (self.valid_metrics(logits, labels)
                      if mode == mode_keys.EVAL else None)
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=metric_ops)
Example #2
0
    def _default_model_fn(self, features, labels, mode):
        """Define model_fn used by TensorFlow Estimator.

        When ``self.config.mixup`` is truthy and ``mode`` is TRAIN, the batch
        is mixed through ``self._mixup_batch`` with a ratio sampled from
        ``Beta(0.1, 0.1)`` and the loss is computed by ``self._mixup_loss``;
        otherwise the plain loss on ``(logits, labels)`` is used. In TRAIN
        mode the returned spec also carries a hook that logs the learning
        rate every 10 iterations.

        :param features: input features
        :type features: tensorflow tensors
        :param labels: label data
        :type labels: tensorflow tensors
        :param mode: mode of estimator
        :type mode: tf.estimator.ModeKeys
        :return: tensorflow EstimatorSpec
        :rtype: tf.estimator.EstimatorSpec
        """
        logging.info('model function action')

        is_training = mode == tf.estimator.ModeKeys.TRAIN
        self.model.training = is_training
        # Evaluate the mixup switch once so the forward pass and the loss
        # computation below can never disagree about it.
        use_mixup = self.config.mixup and is_training
        if use_mixup:
            mixup_ratio = tf.compat.v1.distributions.Beta(0.1, 0.1).sample()
            mixed_x, y_a, y_b = self._mixup_batch(features, labels,
                                                  mixup_ratio)
            logits = self.model(mixed_x)
        else:
            logits = self.model(features)
        logits = tf.cast(logits, tf.float32)
        if hasattr(self.model, 'add_loss'):
            # The model aggregates losses itself: register the configured
            # loss and use the model's combined loss callable instead.
            loss_cls = Loss()()
            self.model.add_loss(loss_cls)
            self.loss = self.model.overall_loss()
        else:
            self.loss = Loss()()
        if use_mixup:
            loss = self._mixup_loss(self.loss, logits, y_a, y_b, mixup_ratio)
        else:
            loss = self.loss(logits, labels)

        train_op = None
        training_hooks = None
        if is_training:
            global_step = tf.compat.v1.train.get_or_create_global_step()
            # Global step divided by the train loader length; presumably a
            # fractional epoch index -- confirm against LrScheduler.
            epoch = tf.cast(global_step, tf.float32) / tf.cast(
                len(self.train_loader), tf.float32)
            self.optimizer = Optimizer()(distributed=self.distributed)
            self.lr_scheduler = LrScheduler()(optimizer=self.optimizer)
            self.lr_scheduler.step(epoch)
            if self.distributed:
                self.optimizer = Optimizer.set_distributed(self.optimizer)

            update_ops = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.UPDATE_OPS)
            loss_scale = self.config.loss_scale if self.use_amp else 1
            minimize_op = self.optimizer.step(loss, loss_scale, global_step)
            # Run the ops collected under UPDATE_OPS together with the
            # optimizer step.
            train_op = tf.group(minimize_op, update_ops)
            # Use the compat.v1 path like the rest of this function:
            # ``tf.train.LoggingTensorHook`` is not available in TF 2.x.
            training_hooks = [
                tf.compat.v1.train.LoggingTensorHook(
                    tensors={"learning rate": self.lr_scheduler.get_lr()[0]},
                    every_n_iter=10),
            ]

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = self.valid_metrics(logits, labels)
        # ``training_hooks`` stays None outside TRAIN mode, which is the
        # EstimatorSpec default, so a single return covers every mode.
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops,
                                          training_hooks=training_hooks)