    def train(self, input_fn, steps=None):
        """Trains a model given training data `input_fn`.

        :param input_fn: A function that constructs the input data for training. The
            function should construct and return one of the following:
            * A `TFDataset` object, each element of which is a tuple `(features, labels)`.
            * A `tf.data.Dataset` object: the outputs of the `Dataset` must be a tuple
            `(features, labels)` with the same constraints as below.
            * A tuple `(features, labels)`: where `features` is a `tf.Tensor` or a dictionary
            of string feature name to `Tensor` and `labels` is a `Tensor` or a
            dictionary of string label name to `Tensor`. Both `features` and
            `labels` are consumed by `model_fn`, so they should satisfy the
            expectations that `model_fn` places on its inputs.
        :param steps: Number of steps for which to train the model.

        :return: `self`, for chaining.
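
        Example (a minimal sketch; the constructor call below is hypothetical,
        and `my_model_fn`/`my_input_fn` are placeholders for your own
        functions)::

            estimator = TFEstimator(my_model_fn, tf.train.AdamOptimizer(),
                                    model_dir="/tmp/estimator")
            estimator.train(my_input_fn, steps=1000)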
        """

        with tf.Graph().as_default() as g:
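            # The global step must be advanced manually: BigDL's TFOptimizer
            # drives the training loop itself, so the usual estimator machinery
            # never increments it (hence `assign_step` below).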
            global_step_tensor = self.estimator._create_and_assert_global_step(g)
            add_step_input = tf.placeholder(dtype=tf.int64, shape=())
            assign_step = tf.assign_add(global_step_tensor, add_step_input)
            result = self.estimator._call_input_fn(input_fn, tf.estimator.ModeKeys.TRAIN)
            if isinstance(result, TFDataset):
                if not result.has_batch:
                    raise ValueError("The batch_size of TFDataset must be " +
                                     "specified when used for training.")
                spec = self._call_model_fn(result.feature_tensors,
                                           result.label_tensors,
                                           tf.estimator.ModeKeys.TRAIN,
                                           self.config)
                optim_method = to_bigdl_optim_method(koptim_method=self.optimizer)
                latest_checkpoint = self.estimator.latest_checkpoint()

                with tf.Session() as sess:
                    saver = tf.train.Saver()
                    if latest_checkpoint:
                        # Resume from the latest estimator checkpoint if one exists...
                        saver.restore(sess, latest_checkpoint)
                    else:
                        # ...otherwise initialize all variables from scratch.
                        sess.run(tf.global_variables_initializer())

                    zoo_ckpt_path = os.path.join(self._model_dir, "analytics-zoo")
                    opt = TFOptimizer.from_loss(spec.loss,
                                                optim_method,
                                                session=sess,
                                                clip_norm=self.gradient_clipping_norm,
                                                clip_value=self.gradient_clipping_constant,
                                                model_dir=zoo_ckpt_path)

                    opt.optimize(MaxIteration(steps))
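                    # Mirror the BigDL iterations into the estimator's global step
                    # so the checkpoint saved below carries the right step number.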
                    sess.run(assign_step, feed_dict={add_step_input: steps})
                    final_step = sess.run(global_step_tensor)
                    model_path = os.path.join(self._model_dir, "model")
                    saver.save(sess, model_path, global_step=final_step)
                    return self

        return self.estimator.train(input_fn, steps=steps)

    @classmethod
    def from_keras(cls,
                   keras_model,
                   dataset,
                   optim_method=None,
                   val_split=0.0,
                   **kwargs):
        """
        Create a TFOptimizer from a tensorflow.keras model. The model must be compiled.
        :param keras_model: the tensorflow.keras model, which must be compiled.
        :param dataset: a TFDataset
        :param optim_method: the optimization method to be used, such as bigdl.optim.optimizer.Adam
        :param val_split: Float between 0 and 1. Fraction of the training data to be used as
            validation data.
        :return: a TFOptimizer.
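
        Example (a minimal sketch; `compiled_model` and `train_dataset` are
        placeholders for a compiled tf.keras model and a TFDataset)::

            optimizer = TFOptimizer.from_keras(compiled_model, train_dataset,
                                               val_split=0.1)
            optimizer.optimize(MaxIteration(1000))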
        """
        import tensorflow.keras.backend as K
        loss = keras_model.total_loss

        model_inputs = keras_model.inputs
        if hasattr(keras_model, "targets"):
            model_targets = keras_model.targets
        else:
            model_targets = keras_model._targets

        inputs = model_inputs + model_targets

        variables = keras_model._collected_trainable_weights
        variables.sort(key=lambda variable: variable.name)
        keras_optimizer = keras_model.optimizer

        grads = K.gradients(loss, variables)
        if None in grads:
            raise ValueError('An operation has `None` for gradient. '
                             'Please make sure that all of your ops have a '
                             'gradient defined (i.e. are differentiable). '
                             'Common ops without gradient: '
                             'K.argmax, K.round, K.eval.')
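        # Carry over any gradient clipping configured on the Keras optimizer;
        # Keras' scalar `clipvalue` clips each gradient to [-v, v], hence the
        # symmetric tuple below.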
        clip_norm = None
        clip_value = None
        if hasattr(keras_optimizer, 'clipnorm'):
            clip_norm = keras_optimizer.clipnorm
        if hasattr(keras_optimizer, 'clipvalue'):
            clip_value = (-keras_optimizer.clipvalue,
                          keras_optimizer.clipvalue)

        sess = K.get_session()
        if optim_method is None:
            optim_method = keras_optimizer
        optim_method = to_bigdl_optim_method(optim_method)

        if keras_model.metrics and (dataset.get_validation_data() is not None
                                    or val_split != 0.0):
            if isinstance(keras_model.metrics, dict):
                raise ValueError(
                    "different metrics for different outputs are not supported right now"
                )

            if len(keras_model.outputs) > 1:
                if not all([
                        name.endswith("loss")
                        for name in keras_model.metrics_names
                ]):
                    raise ValueError(
                        "metrics (except loss) are not supported for multi-head models"
                    )
                else:
                    bigdl_val_methods = [Loss()]
                    val_outputs = keras_model.outputs
                    val_labels = model_targets
            else:
                bigdl_val_methods = \
                    [to_bigdl_metric(m, keras_model.loss) for m in keras_model.metrics_names]
                val_outputs = keras_model.outputs
                val_labels = model_targets
        else:
            val_outputs = None
            val_labels = None
            bigdl_val_methods = None

        # Feed Keras' learning-phase placeholder as True during training and
        # False during validation.
        tensor_with_value = {K.learning_phase(): [True, False]}

        updates = keras_model.updates

        return cls(loss,
                   optim_method,
                   sess,
                   dataset,
                   inputs,
                   grads,
                   variables,
                   loss.graph,
                   val_outputs,
                   val_labels,
                   bigdl_val_methods,
                   val_split,
                   tensors_with_value=tensor_with_value,
                   clip_norm=clip_norm,
                   clip_value=clip_value,
                   updates=updates,
                   **kwargs)

    @classmethod
    def from_keras(cls, keras_model, dataset, optim_method=None,
                   session_config=None, model_dir=None):
        """
        Create a TFOptimizer from a tensorflow.keras model. The model must be compiled.
        :param keras_model: the tensorflow.keras model, which must be compiled.
        :param dataset: a TFDataset
        :param optim_method: the optimization method to be used, such as bigdl.optim.optimizer.Adam
        :param session_config: an optional tf.ConfigProto for the TensorFlow session
        :param model_dir: an optional directory in which to save checkpoints and summaries
        :return: a TFOptimizer.
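
        Example (a minimal sketch, reusing the placeholder names from the
        previous `from_keras`)::

            optimizer = TFOptimizer.from_keras(compiled_model, train_dataset,
                                               model_dir="/tmp/keras_logs")
            optimizer.optimize(MaxIteration(1000))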
        """
        import tensorflow.keras.backend as K

        model_inputs = keras_model.inputs
        if hasattr(keras_model, "targets"):
            model_targets = keras_model.targets
        else:
            model_targets = keras_model._targets

        # Check that the Keras model's inputs line up one-to-one with the
        # flattened feature tensors produced by the TFDataset.
        flatten_inputs = nest.flatten(dataset.feature_tensors)
        assert len(model_inputs) == len(flatten_inputs), \
            ("the keras model and the TFDataset should have the same number of input tensors:" +
             " the keras model has {} inputs" +
             " while the TFDataset has {} inputs").format(len(model_inputs),
                                                          len(flatten_inputs))
        for i in range(len(flatten_inputs)):
            if not TFOptimizer._shape_match(model_inputs[i].shape, flatten_inputs[i].shape):
                raise ValueError(("The {}th input in the keras model {}"
                                  " does not match the TFDataset"
                                  " input {}").format(i,
                                                      model_inputs[i],
                                                      flatten_inputs[i]))

        flatten_targets = nest.flatten(dataset.label_tensors)
        assert len(model_targets) == len(flatten_targets), \
            ("the keras model and the TFDataset should have the same number of target tensors:" +
             " the keras model has {} targets" +
             " while the TFDataset has {} labels").format(len(model_targets),
                                                          len(flatten_targets))
        # TODO: also check the target shapes; doing so currently causes
        # too many false alarms.

        loss = keras_model.total_loss
        variables = keras_model._collected_trainable_weights
        variables.sort(key=lambda variable: variable.name)
        keras_optimizer = keras_model.optimizer

        grads = K.gradients(loss, variables)
        if None in grads:
            raise ValueError('An operation has `None` for gradient. '
                             'Please make sure that all of your ops have a '
                             'gradient defined (i.e. are differentiable). '
                             'Common ops without gradient: '
                             'K.argmax, K.round, K.eval.')
        clip_norm = None
        clip_value = None
        if hasattr(keras_optimizer, 'clipnorm'):
            clip_norm = keras_optimizer.clipnorm
        if hasattr(keras_optimizer, 'clipvalue'):
            clip_value = (-keras_optimizer.clipvalue, keras_optimizer.clipvalue)

        sess = K.get_session()
        if optim_method is None:
            optim_method = keras_optimizer
        optim_method = to_bigdl_optim_method(optim_method)

        if keras_model.metrics and dataset.get_validation_data() is not None:
            if isinstance(keras_model.metrics, dict):
                raise ValueError(
                    "different metrics for different outputs are not supported right now")

            if len(keras_model.outputs) > 1:
                if not all([name.endswith("loss") for name in keras_model.metrics_names]):
                    raise ValueError("metrics (except loss) for multi-head model is not supported")
                else:
                    bigdl_val_methods = [Loss()]
                    val_outputs = keras_model.outputs
                    val_labels = model_targets
            else:
                bigdl_val_methods = \
                    [to_bigdl_metric(m, keras_model.loss) for m in keras_model.metrics_names]
                val_outputs = keras_model.outputs
                val_labels = model_targets
        else:
            val_outputs = None
            val_labels = None
            bigdl_val_methods = None

        tensor_with_value = {
            K.learning_phase(): [True, False]
        }

        updates = keras_model.updates

        metrics = None

        if bigdl_val_methods is not None:
            val_methods = to_list(bigdl_val_methods)
            metrics = {}
            # Wrap each validation method in a BigDLMetric under a generated key.
            for i, method in enumerate(val_methods):
                metrics['bigdl_metric_' + str(i)] = BigDLMetric(method, val_outputs, val_labels)

        tf_model = TFModel.create(loss, sess, model_inputs, model_targets, keras_model.outputs,
                                  grads, variables, loss.graph,
                                  tensor_with_value, session_config, metrics,
                                  updates, model_dir=None)

        return cls(tf_model, optim_method, sess=sess, dataset=dataset,
                   clip_norm=clip_norm, clip_value=clip_value, model_dir=model_dir)