示例#1
0
class TFOptimizer:
    def __init__(self,
                 tf_model,
                 optim_method,
                 sess=None,
                 dataset=None,
                 clip_norm=None,
                 clip_value=None,
                 model_dir=None):
        """
        TFOptimizer is used for distributed training of TensorFlow
        on Spark/BigDL.

        Note that if grads and variables are not None, then they need to be sorted by name
        if you want to use multiple optimization methods for a TensorFlow model according to
        variable names.

        :param tf_model: a TFModel wrapping the TensorFlow graph to train; its
        training_helper_layer is handed to the BigDL Estimator.
        :param optim_method: the optimization method to be used, such as bigdl.optim.optimizer.Adam
        :param sess: the current TensorFlow Session, if you want to used a pre-trained model, you
        should use the Session to load the pre-trained variables and pass it to TFOptimizer.
        :param dataset: the TFDataset supplying training (and optionally validation)
        data; must be batched via batch_size, not batch_per_thread.
        NOTE(review): a None dataset (the default) raises AttributeError at the
        batch_size check below - confirm callers always pass one.
        :param clip_norm: float > 0; if set, gradients are clipped by this L2 norm.
        :param clip_value: tuple (min_value, max_value); if set, gradients are
        clipped into this range.
        :param model_dir: directory used by the Estimator for checkpoints/summaries.
        """

        self.optim_method = optim_method
        self.sess = sess
        self.dataset = dataset

        self.clip_norm = clip_norm
        if clip_value is not None and not isinstance(clip_value, tuple):
            raise ValueError(
                "The clip_value argument should be a tuple (min_value, max_value)"
            )
        self.clip_constant = clip_value

        # batch_per_thread-style datasets report batch_size <= 0 and cannot be
        # used for training.
        if self.dataset.batch_size <= 0:
            raise ValueError(
                "You should set batch_size instead of batch_per_thread for training"
            )

        self.model_dir = model_dir

        self.tf_model = tf_model

        batch_size = self.dataset.batch_size

        self.train_data = self.dataset.get_training_data()
        self.val_data = self.dataset.get_validation_data()

        self.batch_size = batch_size

        self.estimator = Estimator(self.tf_model.training_helper_layer,
                                   self.optim_method, self.model_dir)

        # Apply any gradient clipping configured at construction time.
        if self.clip_norm:
            self.estimator.set_l2_norm_gradient_clipping(self.clip_norm)
        if self.clip_constant:
            min_value, max_value = self.clip_constant
            self.estimator.set_constant_gradient_clipping(min_value, max_value)

    def load_checkpoint(self, path, version):
        """
        Restore model weights and optimizer state from a saved checkpoint.

        :param path: directory containing the checkpoint files
        :param version: checkpoint version number (file-name suffix to load)
        """
        # TODO: make version optional
        weights_file = os.path.join(path, "model.{}".format(version))
        optim_file = os.path.join(
            path, "optimMethod-TFParkTraining.{}".format(version))
        self.tf_model.training_helper_layer.load_checkpoint(weights_file)
        self.optim_method = OptimMethod.load(optim_file)
        # Rebuild the estimator around the restored optimizer state and
        # re-apply the gradient clipping configured at construction time.
        self.estimator = Estimator(self.tf_model.training_helper_layer,
                                   self.optim_method, self.model_dir)
        if self.clip_norm:
            self.estimator.set_l2_norm_gradient_clipping(self.clip_norm)
        if self.clip_constant:
            low, high = self.clip_constant
            self.estimator.set_constant_gradient_clipping(low, high)

    @staticmethod
    def _get_or_create_session(session):
        """Return *session* if given, else a new tf.Session with all global
        variables initialized."""
        import tensorflow as tf
        if session is not None:
            return session
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        return sess

    @staticmethod
    def _get_dataset_from_loss(loss):
        """Recover the TFDataset feeding *loss* from the graph collection
        keyed by the name of its first required placeholder."""
        import tensorflow as tf
        placeholders = find_placeholders([loss])
        return tf.get_collection(placeholders[0].name)[0]

    @staticmethod
    def _get_vars_grads(loss):
        """
        Compute gradients of *loss* with respect to all trainable variables.

        :return: (grads, variables), sorted by variable name, with entries
        whose gradient is None filtered out.
        """
        import tensorflow as tf
        pairs = tf.train.GradientDescentOptimizer(0).compute_gradients(loss)
        pairs.sort(key=lambda pair: pair[1].name)
        kept = [(g, v) for g, v in pairs if g is not None]
        grads = [g for g, _ in kept]
        variables = [v for _, v in kept]
        return grads, variables

    @staticmethod
    def _get_vars_grads_from_train_op(train_op):
        """
        Recover (grads, variables) from a train_op whose gradient tensors were
        wrapped in zoo identity ops named "zoo_identity_op_for_grad*".
        """
        def predicate(t):
            # Zoo wraps each gradient tensor in an identity op with this prefix.
            return t.name.split("/")[-1].startswith("zoo_identity_op_for_grad")

        grads = find_tensors([train_op], predicate)
        grad_ops = [grad.op for grad in grads]
        variables = []
        for grad in grad_ops:
            # The identity op carries a control dependency on the op producing
            # the variable it differentiates.
            var = list(grad.control_inputs)[0]
            # NOTE(review): `var` is an op; op *names* are rarely exactly
            # "VarHandleOp" - this looks like it should compare `var.type`.
            # TODO confirm against resource-variable graphs.
            if var.name == "VarHandleOp":
                variables.append(var)
            else:
                variables.append(list(var.outputs)[0])
        # variables = [grad.op.control_inputs[0].outputs[0] for grad in grads]
        return grads, variables

    @classmethod
    def from_train_op(cls,
                      train_op,
                      loss,
                      *,
                      inputs=None,
                      labels=None,
                      metrics=None,
                      updates=None,
                      sess=None,
                      dataset=None,
                      tensor_with_value=None,
                      session_config=None,
                      model_dir=None):
        """
        Create a TFOptimizer from an existing TensorFlow train_op and loss.

        When *inputs*/*labels* are omitted they default to the original
        tensors of the TFDataset feeding the loss graph.
        """
        sess = TFOptimizer._get_or_create_session(sess)
        grads, variables = TFOptimizer._get_vars_grads_from_train_op(train_op)
        if dataset is None:
            dataset = TFOptimizer._get_dataset_from_loss(loss)
        _ = dataset.tensors  # trigger create tensors if not available

        origin = dataset._original_tensors
        is_pair = isinstance(origin, tuple) and len(origin) == 2
        if inputs is None:
            inputs = origin[0] if is_pair else origin
        if labels is None:
            labels = origin[1] if is_pair else []

        inputs = nest.flatten(inputs)
        labels = nest.flatten(labels)
        from zoo.tfpark.zoo_optimizer import FakeOptimMethod
        return TFOptimizer._from_grads(loss=loss,
                                       sess=sess,
                                       inputs=inputs,
                                       labels=labels,
                                       grads=grads,
                                       variables=variables,
                                       dataset=dataset,
                                       metrics=metrics,
                                       tensor_with_value=tensor_with_value,
                                       optim_method=FakeOptimMethod(),
                                       session_config=session_config,
                                       updates=updates,
                                       model_dir=model_dir,
                                       train_op=train_op)

    @classmethod
    def _from_grads(cls,
                    loss,
                    sess,
                    inputs,
                    labels,
                    grads,
                    variables,
                    dataset,
                    optim_method=None,
                    clip_norm=None,
                    clip_value=None,
                    metrics=None,
                    tensor_with_value=None,
                    session_config=None,
                    model_dir=None,
                    updates=None,
                    train_op=None):
        """
        Internal constructor: build a TFModel from explicit grads/variables
        and wrap it in a TFOptimizer.
        """
        graph = loss.graph
        if metrics is None:
            metrics = {}

        # NOTE(review): TFModel.create is given model_dir=None while the
        # TFOptimizer below receives the real model_dir - confirm this is
        # intentional (checkpointing handled by the Estimator, not TFModel).
        tf_model = TFModel.create(loss,
                                  sess,
                                  inputs,
                                  labels, [],
                                  grads,
                                  variables,
                                  graph,
                                  tensor_with_value,
                                  session_config,
                                  metrics,
                                  updates,
                                  model_dir=None,
                                  train_op=train_op)
        return cls(tf_model,
                   optim_method,
                   sess=sess,
                   dataset=dataset,
                   clip_norm=clip_norm,
                   clip_value=clip_value,
                   model_dir=model_dir)

    @classmethod
    def from_loss(cls,
                  loss,
                  optim_method,
                  session=None,
                  inputs=None,
                  dataset=None,
                  val_outputs=None,
                  val_labels=None,
                  val_method=None,
                  clip_norm=None,
                  clip_value=None,
                  metrics=None,
                  tensor_with_value=None,
                  session_config=None,
                  model_dir=None,
                  updates=None):
        """
        Create a TFOptimizer from a TensorFlow loss tensor.
        The loss tensor must come from a TensorFlow graph that only takes TFDataset.tensors and
        the tensors in `tensor_with_value` as inputs.

        :param loss: The loss tensor of the TensorFlow model, should be a scalar
        :param optim_method: the optimization method to be used, such as bigdl.optim.optimizer.Adam
        :param session: the current TensorFlow Session, if you want to used a pre-trained model,
        you should use the Session to load the pre-trained variables and pass it to TFOptimizer.
        :param inputs: the input tensors of the graph; inferred from the TFDataset
        feeding the loss when both `inputs` and `dataset` are omitted.
        :param dataset: the TFDataset supplying training data; inferred from the
        loss graph when omitted.
        :param val_outputs: the validation output TensorFlow tensor to be used by val_methods
        :param val_labels: the validation label TensorFlow tensor to be used by val_methods
        :param val_method: the BigDL val_method(s) to be used.
        :param clip_norm: float >= 0. Gradients will be clipped when their L2 norm exceeds
        this value.
        :param clip_value: float >= 0. Gradients will be clipped when their absolute value
        exceeds this value.
        :param metrics: a dictionary. The key should be a string representing the metric's name
        and the value should be the corresponding TensorFlow tensor, which should be a scalar.
        :param tensor_with_value: a dictionary. The key is TensorFlow tensor, usually a
        placeholder, the value of the dictionary is a tuple of two elements. The first one of
        the tuple is the value to feed to the tensor in training phase and the second one
        is the value to feed to the tensor in validation phase.
        :return: a TFOptimizer
        :raise ValueError: if inputs are missing or clip_value is invalid
        """
        sess = TFOptimizer._get_or_create_session(session)
        grads, variables = TFOptimizer._get_vars_grads(loss)

        if dataset is None and inputs is None:
            dataset = TFOptimizer._get_dataset_from_loss(loss)
            inputs = dataset._original_tensors
        else:
            if inputs is None:
                raise ValueError("please specify inputs")
            _ = dataset.tensors  # trigger creating placeholders

        if isinstance(inputs, tuple) and len(inputs) == 2:
            inputs, labels = inputs
        else:
            labels = []

        inputs = nest.flatten(inputs)
        labels = nest.flatten(labels)

        if clip_value is not None:
            # A scalar clip_value means symmetric clipping to (-v, v).
            if isinstance(clip_value, (int, float)):
                if clip_value <= 0:
                    # Bug fix: the ValueError was previously constructed but
                    # never raised, silently accepting invalid clip values.
                    raise ValueError(
                        "The clip_value argument should be positive number")
                clip_value = (-float(clip_value), float(clip_value))

            if not isinstance(clip_value, tuple):
                raise ValueError(
                    "The clip_value argument should be" +
                    " a positive float/int which clips to" +
                    " (-clip_value, clip_value); " +
                    "or a tuple which clips to (min_value, max_value)")

        if val_method is not None:
            val_methods = to_list(val_method)
            if metrics is None:
                metrics = {}

            # Wrap each BigDL validation method as a named metric.
            for i, method in enumerate(val_methods):
                metrics['bigdl_metric_' + str(i)] = BigDLMetric(
                    method, val_outputs, val_labels)

        return TFOptimizer._from_grads(loss, sess, inputs, labels, grads,
                                       variables, dataset, optim_method,
                                       clip_norm, clip_value, metrics,
                                       tensor_with_value, session_config,
                                       model_dir, updates)

    @staticmethod
    def export_training_model(export_dir,
                              loss,
                              sess,
                              inputs,
                              labels=None,
                              predictions=None,
                              metrics=None,
                              tensor_with_value=None,
                              updates=None):
        """
        Export a training graph (loss, gradients and variables) to *export_dir*.

        :param export_dir: destination directory for the exported model
        :param loss: scalar loss tensor; gradients are derived from it
        :param sess: TensorFlow session holding the variable values
        :param inputs: model input tensors
        :param labels: model target tensors
        :param predictions: model output tensors
        :param metrics: optional dict of name -> scalar metric tensor
        :param tensor_with_value: optional dict of tensor -> (train, eval) values
        :param updates: optional update ops to run each training step
        """

        grads, variables = TFOptimizer._get_vars_grads(loss)

        TFModel.export(export_dir, loss, sess, inputs, labels, predictions,
                       grads, variables, loss.graph, tensor_with_value,
                       metrics, updates)
        logging.info(
            "Exported TensorFlow model in {} for training".format(export_dir))

    @staticmethod
    def _shape_match(model_shape, dataset_shape):

        for i in range(len(dataset_shape)):
            if dataset_shape[i].value is None:
                return model_shape[i].value is None
            else:
                return dataset_shape[i].value == model_shape[i].value or \
                    model_shape[i].value is None

    @classmethod
    def from_keras(cls,
                   keras_model,
                   dataset,
                   session_config=None,
                   model_dir=None,
                   metrics=None,
                   optimizer=None):
        """
        Create a TFOptimizer from a tensorflow.keras model. The model must be compiled.

        :param keras_model: the tensorflow.keras model, which must be compiled.
        :param dataset: a TFDataset
        :param session_config: session configuration forwarded to TFModel.create
        (presumably a tf.ConfigProto - confirm against TFModel).
        :param model_dir: directory used by the resulting TFOptimizer for
        checkpoints/summaries.
        :param metrics: optional dict of extra metrics, merged with the BigDL
        metrics derived from the keras model's compiled metrics.
        :param optimizer: optional optim method; when given, training bypasses
        the keras optimizer's train_op and uses raw gradients instead.
        :return: a TFOptimizer
        """
        import tensorflow.keras.backend as K

        model_inputs = keras_model.inputs

        # tf.keras renamed the private targets attribute across versions.
        if hasattr(keras_model, "targets"):
            model_targets = keras_model.targets
        else:
            model_targets = keras_model._targets

        # target can be None if loss is None
        model_targets = list(filter(lambda x: x is not None, model_targets))

        # standardize features/labels so they line up with the keras model
        if isinstance(dataset, TFNdarrayDataset):
            dataset = _standarize_feature_label_dataset(dataset, keras_model)

        flatten_inputs = nest.flatten(dataset.feature_tensors)
        assert len(model_inputs) == len(flatten_inputs), \
            ("the keras model and TFDataset should have the same number of tensors" +
             " keras model has {} inputs " +
             "while TFDataset has {} inputs").format(len(model_inputs),
                                                     len(flatten_inputs))
        for i in range(len(flatten_inputs)):
            if not TFOptimizer._shape_match(model_inputs[i].shape,
                                            flatten_inputs[i].shape):
                raise ValueError(("The {}th input in keras model {}"
                                  " does not match the TFDataset"
                                  "input {}").format(i, model_inputs[i],
                                                     flatten_inputs[i]))

        flatten_targets = nest.flatten(dataset.label_tensors)
        # NOTE(review): the message below formats len(flatten_inputs) where
        # len(flatten_targets) appears intended - confirm before relying on it.
        assert len(model_targets) == len(flatten_targets), \
            ("the keras model and TFDataset should have the same number of tensors" +
             " keras model has {} targets " +
             "while TFDataset has {} labels").format(len(model_targets),
                                                     len(flatten_inputs))
        # todo check targets shape, currently checking target shape will
        # cause too much false alarm.

        loss = keras_model.total_loss
        variables = keras_model._collected_trainable_weights
        # sort by name so optim methods can be matched to variables by name
        variables.sort(key=lambda variable: variable.name)
        keras_optimizer = keras_model.optimizer

        from zoo.tfpark.zoo_optimizer import get_gradients_for_keras
        grads = get_gradients_for_keras(keras_optimizer, loss, variables)
        grads_and_vars = list(zip(grads, variables))
        import tensorflow.python.keras.optimizers as koptimizers
        if isinstance(keras_optimizer, koptimizers.TFOptimizer):
            # work around keras TFOptimzier bug
            train_op = keras_optimizer.optimizer.apply_gradients(
                grads_and_vars)
        else:
            train_op = keras_optimizer.apply_gradients(grads_and_vars)

        sess = K.get_session()

        # Derive BigDL validation methods from the compiled keras metrics,
        # but only when there is validation data to evaluate them on.
        if keras_model.metrics and (dataset.get_validation_data() is not None):
            if isinstance(keras_model.metrics, dict):
                raise ValueError(
                    "different metrics for different outputs are not supported right now"
                )

            if len(keras_model.outputs) > 1:
                if not all([
                        name.endswith("loss")
                        for name in keras_model.metrics_names
                ]):
                    raise ValueError(
                        "metrics (except loss) for multi-head model is not supported"
                    )
                else:
                    bigdl_val_methods = [Loss()]
                    val_outputs = keras_model.outputs
                    val_labels = model_targets
            else:
                bigdl_val_methods = \
                    [to_bigdl_metric(m, keras_model.loss) for m in keras_model.metrics_names]
                val_outputs = keras_model.outputs
                val_labels = model_targets
        else:
            val_outputs = None
            val_labels = None
            bigdl_val_methods = None

        # Feed True to the learning-phase placeholder during training and
        # False during validation.
        tensor_with_value = {K.learning_phase(): [True, False]}

        updates = []

        # Unconditional updates (e.g. declared model-wide)
        updates += keras_model.get_updates_for(None)
        # Conditional updates relevant to this model
        updates += keras_model.get_updates_for(keras_model.inputs)

        if bigdl_val_methods is not None:
            val_methods = to_list(bigdl_val_methods)
            bigdl_metrics = {}
            for i, method in enumerate(val_methods):
                bigdl_metrics['bigdl_metric_' + str(i)] = BigDLMetric(
                    method, val_outputs, val_labels)
            if metrics is None:
                metrics = bigdl_metrics
            else:
                metrics.update(bigdl_metrics)

        if optimizer is not None:
            # Explicit optim method: hand raw gradients to TFModel, honoring
            # any clipping configured on the keras optimizer.
            clip_norm = None
            clip_value = None
            if hasattr(keras_optimizer, 'clipnorm'):
                clip_norm = keras_optimizer.clipnorm
            if hasattr(keras_optimizer, 'clipvalue'):
                clip_value = (-keras_optimizer.clipvalue,
                              keras_optimizer.clipvalue)
            # NOTE(review): TFModel.create gets model_dir=None while the real
            # model_dir goes to the TFOptimizer below - confirm intentional.
            tf_model = TFModel.create(loss,
                                      sess,
                                      model_inputs,
                                      model_targets,
                                      keras_model.outputs,
                                      grads,
                                      variables,
                                      loss.graph,
                                      tensor_with_value,
                                      session_config,
                                      metrics,
                                      updates,
                                      model_dir=None)

            return cls(tf_model,
                       optimizer,
                       sess=sess,
                       dataset=dataset,
                       clip_norm=clip_norm,
                       clip_value=clip_value,
                       model_dir=model_dir)

        return cls.from_train_op(train_op,
                                 loss,
                                 inputs=model_inputs,
                                 labels=model_targets,
                                 metrics=metrics,
                                 updates=updates,
                                 sess=sess,
                                 dataset=dataset,
                                 tensor_with_value=tensor_with_value,
                                 session_config=session_config,
                                 model_dir=model_dir)

    def set_constant_gradient_clipping(self, min_value, max_value):
        """
        Configure constant clipping settings.

        :param min_value: the minimum value to clip by
        :param max_value: the maximum value to clip by
        """
        self.estimator.set_constant_gradient_clipping(min_value, max_value)

    def set_gradient_clipping_by_l2_norm(self, clip_norm):
        """
        Configure L2 norm clipping settings.

        :param clip_norm: gradient L2-Norm threshold; gradients whose L2 norm
        exceeds this value are rescaled.
        """
        self.estimator.set_l2_norm_gradient_clipping(clip_norm)

    def optimize(self, end_trigger=None, checkpoint_trigger=None):
        """
        Run the training loop of this optimizer.

        :param end_trigger: BigDL Trigger indicating when to stop training;
        defaults to MaxEpoch(1).
        :param checkpoint_trigger: when to save a checkpoint and evaluate the
        model; defaults to EveryEpoch().
        """
        if end_trigger is None:
            end_trigger = MaxEpoch(1)
        if checkpoint_trigger is None:
            checkpoint_trigger = EveryEpoch()

        fit_args = dict(train_set=self.train_data,
                        criterion=self.tf_model.criterion,
                        end_trigger=end_trigger,
                        checkpoint_trigger=checkpoint_trigger)
        # Attach validation only when both metrics and validation data exist.
        if self.tf_model.val_methods and self.val_data is not None:
            fit_args["validation_set"] = self.val_data
            fit_args["validation_method"] = self.tf_model.val_methods
        self.estimator.train_minibatch(**fit_args)

        # Pull the trained weights back from BigDL into the TF graph.
        self.tf_model.training_helper_layer.get_weights_to_python()
示例#2
0
class TFOptimizer:
    def __init__(self, tf_model, optim_method,
                 sess=None, dataset=None,
                 val_split=0.0,
                 clip_norm=None, clip_value=None,
                 model_dir=None):
        """
        TFOptimizer is used for distributed training of TensorFlow
        on Spark/BigDL.

        Note that if grads and variables are not None, then they need to be sorted by name
        if you want to use multiple optimization methods for a TensorFlow model according to
        variable names.

        :param tf_model: a TFModel wrapping the TensorFlow graph to train; its
        training_helper_layer is handed to the BigDL Estimator.
        :param optim_method: the optimization method to be used, such as bigdl.optim.optimizer.Adam
        :param sess: the current TensorFlow Session, if you want to used a pre-trained model, you
        should use the Session to load the pre-trained variables and pass it to TFOptimizer.
        :param dataset: the TFDataset supplying training data; must be batched via
        batch_size (not batch_per_thread). NOTE(review): a None dataset (the
        default) raises AttributeError below - confirm callers always pass one.
        :param val_split: float in [0, 1); fraction of the training RDD split off
        for validation. When 0.0, the dataset's own validation data is used.
        :param clip_norm: float > 0; if set, gradients are clipped by this L2 norm.
        :param clip_value: tuple (min_value, max_value) for constant gradient clipping.
        :param model_dir: directory used by the Estimator for checkpoints/summaries.
        """

        self.optim_method = optim_method
        self.sess = sess
        self.dataset = dataset

        self.clip_norm = clip_norm
        if clip_value is not None and not isinstance(clip_value, tuple):
            raise ValueError("The clip_value argument should be a tuple (min_value, max_value)")
        self.clip_constant = clip_value

        # batch_per_thread-style datasets report batch_size <= 0 and cannot be
        # used for training.
        if self.dataset.batch_size <= 0:
            raise ValueError("You should set batch_size instead of batch_per_thread for training")

        self.model_dir = model_dir

        self.tf_model = tf_model

        batch_size = self.dataset.batch_size

        sample_rdd = self.dataset.get_training_data()

        if val_split != 0.0:
            # Randomly split the training RDD into train/validation parts.
            training_rdd, val_rdd = sample_rdd.randomSplit([1 - val_split, val_split])
        else:
            training_rdd = sample_rdd
            val_rdd = self.dataset.get_validation_data()

        self.training_rdd = training_rdd
        self.val_rdd = val_rdd
        self.batch_size = batch_size

        self.estimator = Estimator(self.tf_model.training_helper_layer, self.optim_method,
                                   model_dir)

        # Apply any gradient clipping configured at construction time.
        if self.clip_norm:
            self.estimator.set_l2_norm_gradient_clipping(self.clip_norm)
        if self.clip_constant:
            min_value, max_value = self.clip_constant
            self.estimator.set_constant_gradient_clipping(min_value, max_value)

    @staticmethod
    def _get_or_create_session(session):
        """Return *session* if given, else a new tf.Session with all global
        variables initialized."""
        if session is not None:
            return session
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        return sess

    @staticmethod
    def _get_dataset_from_loss(loss):
        """Recover the TFDataset feeding *loss* from the graph collection
        keyed by the name of its first required placeholder."""
        placeholders = find_placeholders([loss])
        return tf.get_collection(placeholders[0].name)[0]

    @staticmethod
    def _get_vars_grads(loss):
        """
        Compute gradients of *loss* with respect to all trainable variables.

        :return: (grads, variables), sorted by variable name, with entries
        whose gradient is None filtered out.
        """
        pairs = tf.train.GradientDescentOptimizer(0).compute_gradients(loss)
        pairs.sort(key=lambda pair: pair[1].name)
        kept = [(g, v) for g, v in pairs if g is not None]
        grads = [g for g, _ in kept]
        variables = [v for _, v in kept]
        return grads, variables

    @staticmethod
    def _get_vars_grads_from_train_op(train_op):
        """
        Recover (grads, variables) from a train_op whose gradient tensors were
        wrapped in zoo identity ops named "zoo_identity_op_for_grad*".
        """
        def predicate(t):
            # Zoo wraps each gradient tensor in an identity op with this prefix.
            return t.name.split("/")[-1].startswith("zoo_identity_op_for_grad")

        grads = find_tensors([train_op], predicate)
        grad_ops = [grad.op for grad in grads]
        variables = []
        for grad in grad_ops:
            # The identity op carries a control dependency on the op producing
            # the variable it differentiates.
            var = list(grad.control_inputs)[0]
            # NOTE(review): `var` is an op; op *names* are rarely exactly
            # "VarHandleOp" - this looks like it should compare `var.type`.
            # TODO confirm against resource-variable graphs.
            if var.name == "VarHandleOp":
                variables.append(var)
            else:
                variables.append(list(var.outputs)[0])
        # variables = [grad.op.control_inputs[0].outputs[0] for grad in grads]
        return grads, variables

    @classmethod
    def from_train_op(cls, train_op, loss, metrics=None, updates=None, sess=None, dataset=None,
                      tensor_with_value=None, session_config=None, model_dir=None):
        """
        Create a TFOptimizer from an existing TensorFlow train_op and its loss.

        The TFDataset is inferred from the loss graph when not supplied, and
        its original tensors (flattened) become the model inputs.
        """
        sess = TFOptimizer._get_or_create_session(sess)
        grads, variables = TFOptimizer._get_vars_grads_from_train_op(train_op)
        if dataset is None:
            dataset = TFOptimizer._get_dataset_from_loss(loss)
        inputs = nest.flatten(dataset._original_tensors)
        # FakeOptimMethod: presumably a no-op since train_op already applies
        # the gradient updates - confirm against zoo_optimizer.
        return TFOptimizer._from_grads(loss=loss, sess=sess, inputs=inputs, grads=grads,
                                       variables=variables, dataset=dataset, metrics=metrics,
                                       tensor_with_value=tensor_with_value,
                                       optim_method=FakeOptimMethod(),
                                       session_config=session_config, updates=updates,
                                       model_dir=model_dir, train_op=train_op)

    @classmethod
    def _from_grads(cls, loss, sess, inputs, grads, variables, dataset, optim_method=None,
                    val_split=0.0, clip_norm=None, clip_value=None,
                    metrics=None, tensor_with_value=None, session_config=None,
                    model_dir=None, updates=None, train_op=None):
        """
        Internal constructor: wrap explicit grads/variables into a TFModel and
        build a TFOptimizer around it.

        :param model_dir: directory for checkpoints/summaries, forwarded to
        both the TFModel and the resulting TFOptimizer.
        """
        graph = loss.graph
        if metrics is None:
            metrics = {}

        tf_model = TFModel.create(loss, sess, inputs, grads, variables, graph,
                                  tensor_with_value, session_config, metrics,
                                  updates, model_dir, train_op=train_op)
        # Bug fix: model_dir was accepted but never forwarded to the
        # TFOptimizer, so the Estimator created in __init__ always received
        # model_dir=None and checkpointing was silently disabled.
        return cls(tf_model, optim_method, sess=sess, dataset=dataset, val_split=val_split,
                   clip_norm=clip_norm, clip_value=clip_value, model_dir=model_dir)

    @classmethod
    def from_loss(cls, loss, optim_method, session=None, val_outputs=None,
                  val_labels=None, val_method=None, val_split=0.0,
                  clip_norm=None, clip_value=None, metrics=None,
                  tensor_with_value=None, session_config=None, model_dir=None, updates=None):
        """
        Create a TFOptimizer from a TensorFlow loss tensor.
        The loss tensor must come from a TensorFlow graph that only takes TFDataset.tensors and
        the tensors in `tensor_with_value` as inputs.

        :param loss: The loss tensor of the TensorFlow model, should be a scalar
        :param optim_method: the optimization method to be used, such as bigdl.optim.optimizer.Adam
        :param session: the current TensorFlow Session, if you want to used a pre-trained model,
        you should use the Session to load the pre-trained variables and pass it to TFOptimizer.
        :param val_outputs: the validation output TensorFlow tensor to be used by val_methods
        :param val_labels: the validation label TensorFlow tensor to be used by val_methods
        :param val_method: the BigDL val_method(s) to be used.
        :param val_split: Float between 0 and 1. Fraction of the training data to be used as
        validation data.
        :param clip_norm: float >= 0. Gradients will be clipped when their L2 norm exceeds
        this value.
        :param clip_value: float >= 0. Gradients will be clipped when their absolute value
        exceeds this value.
        :param metrics: a dictionary. The key should be a string representing the metric's name
        and the value should be the corresponding TensorFlow tensor, which should be a scalar.
        :param tensor_with_value: a dictionary. The key is TensorFlow tensor, usually a
        placeholder, the value of the dictionary is a tuple of two elements. The first one of
        the tuple is the value to feed to the tensor in training phase and the second one
        is the value to feed to the tensor in validation phase.
        :return: a TFOptimizer
        :raise ValueError: if clip_value is not a positive number or a tuple
        """
        sess = TFOptimizer._get_or_create_session(session)
        grads, variables = TFOptimizer._get_vars_grads(loss)
        dataset = TFOptimizer._get_dataset_from_loss(loss)
        inputs = nest.flatten(dataset._original_tensors)

        if clip_value is not None:
            # A scalar clip_value means symmetric clipping to (-v, v).
            if isinstance(clip_value, float) or isinstance(clip_value, int):
                if clip_value <= 0:
                    # Bug fix: the ValueError was previously constructed but
                    # never raised, silently accepting invalid clip values.
                    raise ValueError("The clip_value argument should be positive number")
                clip_value = (-float(clip_value), float(clip_value))

            if not isinstance(clip_value, tuple):
                raise ValueError("The clip_value argument should be" +
                                 " a positive float/int which clips to" +
                                 " (-clip_value, clip_value); " +
                                 "or a tuple which clips to (min_value, max_value)")

        if val_method is not None:
            val_methods = to_list(val_method)
            if metrics is None:
                metrics = {}

            for i, method in enumerate(val_methods):
                # Typo fix: metric keys were previously 'bigdl_metirc_<i>';
                # now consistent with the rest of the codebase.
                metrics['bigdl_metric_' + str(i)] = BigDLMetric(method, val_outputs, val_labels)

        return TFOptimizer._from_grads(loss, sess, inputs, grads, variables, dataset, optim_method,
                                       val_split, clip_norm, clip_value,
                                       metrics, tensor_with_value, session_config,
                                       model_dir, updates)

    @staticmethod
    def export_training_model(export_dir, loss, sess, inputs,
                              metrics=None, tensor_with_value=None, updates=None):
        """
        Export a training graph (loss, gradients and variables) to *export_dir*.

        :param export_dir: destination directory for the exported model
        :param loss: scalar loss tensor; gradients are derived from it
        :param sess: TensorFlow session holding the variable values
        :param inputs: model input tensors
        :param metrics: optional dict of name -> scalar metric tensor
        :param tensor_with_value: optional dict of tensor -> (train, eval) values
        :param updates: optional update ops to run each training step
        """

        grads, variables = TFOptimizer._get_vars_grads(loss)

        TFModel.export(export_dir, loss, sess, inputs, grads, variables, loss.graph,
                       tensor_with_value, metrics, updates)
        logging.info("Exported TensorFlow model in {} for training".format(export_dir))

    @classmethod
    def from_keras(cls, keras_model, dataset, optim_method=None, val_split=0.0,
                   session_config=None, model_dir=None):
        """
        Create a TFOptimizer from a tensorflow.keras model. The model must be compiled.

        :param keras_model: the tensorflow.keras model, which must be compiled.
        :param dataset: a TFDataset providing the training (and optionally validation) data
        :param optim_method: the optimization method to be used, such as
            bigdl.optim.optimizer.Adam; when None, the model's own compiled Keras
            optimizer is converted to a BigDL optim method instead.
        :param val_split: Float between 0 and 1. Fraction of the training data to be used as
            validation data.
        :param session_config: optional session configuration passed through to TFModel.create
        :param model_dir: optional directory passed through to TFModel.create
        :return: a TFOptimizer wrapping the compiled Keras model
        """
        import tensorflow.keras.backend as K

        # Feed tensors are the model's inputs followed by its target (label)
        # placeholders. `targets` vs `_targets` covers different tf.keras
        # versions exposing the attribute under different names.
        model_inputs = keras_model.inputs
        if hasattr(keras_model, "targets"):
            model_targets = keras_model.targets
        else:
            model_targets = keras_model._targets

        inputs = model_inputs + model_targets

        # Use the total loss built by compile() (includes regularization /
        # loss weights, per Keras' own definition of total_loss).
        loss = keras_model.total_loss
        variables = keras_model._collected_trainable_weights
        # Sort by variable name for a deterministic grad/var ordering.
        # NOTE(review): this sorts the Keras model's own list in place —
        # presumably harmless, but confirm nothing relies on the original order.
        variables.sort(key=lambda variable: variable.name)
        keras_optimizer = keras_model.optimizer

        grads = K.gradients(loss, variables)
        if None in grads:
            raise ValueError('An operation has `None` for gradient. '
                             'Please make sure that all of your ops have a '
                             'gradient defined (i.e. are differentiable). '
                             'Common ops without gradient: '
                             'K.argmax, K.round, K.eval.')
        # Carry over any gradient clipping configured on the Keras optimizer.
        clip_norm = None
        clip_value = None
        if hasattr(keras_optimizer, 'clipnorm'):
            clip_norm = keras_optimizer.clipnorm
        if hasattr(keras_optimizer, 'clipvalue'):
            # Keras' scalar clipvalue v means clip into (-v, v); expand it to
            # the (min, max) tuple form expected by this class' constructor.
            clip_value = (-keras_optimizer.clipvalue, keras_optimizer.clipvalue)

        sess = K.get_session()
        if optim_method is None:
            optim_method = keras_optimizer
        optim_method = to_bigdl_optim_method(optim_method)

        # Translate the compiled Keras metrics into BigDL validation methods,
        # but only when there is validation data to evaluate them on.
        if keras_model.metrics and (dataset.get_validation_data() is not None or val_split != 0.0):
            if isinstance(keras_model.metrics, dict):
                raise ValueError(
                    "different metrics for different outputs are not supported right now")

            # NOTE(review): this condition is the negation of part of the outer
            # guard above, so it looks unreachable here — confirm before relying
            # on this error path.
            if dataset.get_validation_data() is None and val_split == 0.0:
                raise ValueError("Validation data is not specified. Please set " +
                                 "val_rdd in TFDataset, or set val_split larger than zero")

            if len(keras_model.outputs) > 1:
                # Multi-output models: only per-output loss metrics are
                # supported; fall back to a single BigDL Loss over all outputs.
                if not all([name.endswith("loss") for name in keras_model.metrics_names]):
                    raise ValueError("metrics (except loss) for multi-head model is not supported")
                else:
                    bigdl_val_methods = [Loss()]
                    val_outputs = keras_model.outputs
                    val_labels = model_targets
            else:
                bigdl_val_methods = \
                    [to_bigdl_metric(m, keras_model.loss) for m in keras_model.metrics_names]
                val_outputs = keras_model.outputs
                val_labels = model_targets
        else:
            val_outputs = None
            val_labels = None
            bigdl_val_methods = None

        # Drive Keras' learning-phase placeholder: the [True, False] pair is
        # the values used for training and validation respectively.
        tensor_with_value = {
            K.learning_phase(): [True, False]
        }

        # Include the model's update ops (e.g. batch-norm moving averages).
        updates = keras_model.updates

        metrics = None

        if bigdl_val_methods is not None:
            val_methods = to_list(bigdl_val_methods)
            metrics = {}
            for i, method in enumerate(val_methods):
                # NOTE(review): 'metirc' is misspelled, but this string is a
                # runtime dict key that other call sites in this file build the
                # same way — renaming it would change behavior, so keep as-is.
                metrics['bigdl_metirc_' + str(i)] = BigDLMetric(method, val_outputs, val_labels)

        tf_model = TFModel.create(loss, sess, inputs, grads, variables, loss.graph,
                                  tensor_with_value, session_config, metrics,
                                  updates, model_dir)

        # NOTE(review): the __init__ shown at the top of this file does not
        # accept a val_split keyword — confirm this call matches the actual
        # constructor signature.
        return cls(tf_model, optim_method, sess=sess, dataset=dataset, val_split=val_split,
                   clip_norm=clip_norm, clip_value=clip_value)

    def set_constant_gradient_clipping(self, min_value, max_value):
        """
        Configure constant clipping settings.

        :param min_value: the minimum value to clip by
        :param max_value: the maxmimum value to clip by
        """
        self.estimator.set_constant_gradient_clipping(min_value, max_value)

    def set_gradient_clipping_by_l2_norm(self, clip_norm):
        """
        Configure L2 norm clipping settings.
        :param clip_norm: gradient L2-Norm threshold
        """
        self.estimator.set_l2_norm_gradient_clipping(clip_norm)

    def optimize(self, end_trigger=None, checkpoint_trigger=None):
        """
        Run the training loop of the this optimizer
        :param end_trigger: BigDL's Trigger to indicate when to stop the training.
        :param checkpoint_trigger: When to save a checkpoint and evaluate model.
        """
        if end_trigger is None:
            end_trigger = MaxEpoch(1)

        if checkpoint_trigger is None:
            checkpoint_trigger = EveryEpoch()

        if self.tf_model.val_methods and self.val_rdd is not None:
            self.estimator.train_minibatch(train_set=self.training_rdd,
                                           criterion=self.tf_model.criterion,
                                           end_trigger=end_trigger,
                                           checkpoint_trigger=checkpoint_trigger,
                                           validation_set=self.val_rdd,
                                           validation_method=self.tf_model.val_methods)
        else:
            self.estimator.train_minibatch(train_set=self.training_rdd,
                                           criterion=self.tf_model.criterion,
                                           end_trigger=end_trigger)

        self.tf_model.training_helper_layer.get_weights_to_python()