Пример #1
0
    def train(self, input_fn, steps=None):
        """Trains a model given training data `input_fn`.

        If `input_fn` produces a `TFDataset`, training runs through an Analytics Zoo
        `TFOptimizer` built from the model_fn's `train_op` and `loss`. Any other
        input is handed off to `self.estimator.train`.

        :param input_fn: A function that constructs the input data for training. The
            function should construct and return one of the following:
            * A `TFDataset` object, each element of which is a tuple `(features, labels)`.
            * A `tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
            `(features, labels)` with the same constraints as the tuple form below.
            * A tuple `(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary
            of string feature name to `Tensor` and `labels` is a `Tensor` or a
            dictionary of string label name to `Tensor`. Both `features` and
            `labels` are consumed by `model_fn`. They should satisfy the expectation
            of `model_fn` from inputs.
        :param steps: Number of steps for which to train the model. Required (and
            must be positive) when `input_fn` returns a `TFDataset`.
        :return: `self`, for chaining.
        :raises ValueError: If `input_fn` returns a `TFDataset` whose batch size was
            not specified, or if `steps` is None or not positive for a `TFDataset`.
        """

        with tf.Graph().as_default() as g:
            # Global step of this graph, plus an op that advances it by a value fed
            # at run time. The TFDataset path uses it to add the number of steps
            # run by TFOptimizer to the TF global step.
            global_step_tensor = self.estimator._create_and_assert_global_step(
                g)
            add_step_input = tf.placeholder(dtype=tf.int64, shape=())
            assign_step = tf.assign_add(global_step_tensor, add_step_input)

            # Invoke input_fn once to discover what kind of input it produces.
            result = self.estimator._call_input_fn(input_fn,
                                                   tf.estimator.ModeKeys.TRAIN)
            if isinstance(result, TFDataset):
                if not result.has_batch:
                    raise ValueError("The batch_size of TFDataset must be " +
                                     "specified when used for training.")
                # Both MaxIteration(steps) and the global-step feed below need a
                # concrete count; fail early with a clear message rather than an
                # obscure error from the optimizer or the session feed.
                if steps is None or steps <= 0:
                    raise ValueError("steps must be a positive integer when "
                                     "input_fn returns a TFDataset, got %r"
                                     % (steps,))
                spec = self._call_model_fn(result.feature_tensors,
                                           result.label_tensors,
                                           tf.estimator.ModeKeys.TRAIN,
                                           self.config)
                latest_checkpoint = self.estimator.latest_checkpoint()

                with tf.Session() as sess:
                    saver = tf.train.Saver()
                    # Resume from the latest estimator checkpoint when one exists,
                    # otherwise start from freshly initialized variables.
                    if latest_checkpoint:
                        saver.restore(sess, latest_checkpoint)
                    else:
                        sess.run(tf.global_variables_initializer())

                    # Separate directory handed to TFOptimizer as its model_dir
                    # (presumably for its own checkpoints).
                    zoo_ckpt_path = os.path.join(self._model_dir,
                                                 "analytics-zoo")

                    opt = TFOptimizer.from_train_op(spec.train_op,
                                                    spec.loss,
                                                    sess=sess,
                                                    dataset=result,
                                                    model_dir=zoo_ckpt_path)

                    opt.optimize(MaxIteration(steps))
                    # Advance the TF global step by the steps just trained, then
                    # save a TF checkpoint tagged with it so that a later call
                    # can restore from it via latest_checkpoint() (assuming the
                    # estimator's model_dir is self._model_dir — confirm).
                    sess.run(assign_step, feed_dict={add_step_input: steps})
                    final_step = sess.run(global_step_tensor)
                    model_path = os.path.join(self._model_dir, "model")
                    saver.save(sess, model_path, global_step=final_step)
                    return self

        # Non-TFDataset input: delegate to the wrapped estimator. Note that
        # input_fn was already invoked once above (in a throwaway graph) to probe
        # its return type; self.estimator.train will invoke it again.
        return self.estimator.train(input_fn, steps=steps)
Пример #2
0
    def _fit_distributed(self, dataset, epochs, **kwargs):
        """Fit ``self.model`` on ``dataset`` for ``epochs`` epochs via a TFOptimizer.

        The optimizer is built with ``TFOptimizer.from_keras`` using this object's
        ``model_dir``. It is stored on ``self.tf_optimizer`` before training
        starts.

        :param dataset: Training data forwarded to ``TFOptimizer.from_keras``.
        :param epochs: Number of epochs; training stops at ``MaxEpoch(epochs)``.
        :param kwargs: Extra keyword arguments forwarded to ``TFOptimizer.from_keras``.
        """
        optimizer = TFOptimizer.from_keras(self.model, dataset,
                                           model_dir=self.model_dir, **kwargs)
        # Keep a handle on the optimizer so it is reachable after training.
        self.tf_optimizer = optimizer
        optimizer.optimize(MaxEpoch(epochs))
Пример #3
0
    def _fit_distributed(self, dataset, validation_split, epochs, **kwargs):
        """Fit ``self.model`` on ``dataset`` for ``epochs`` epochs via a TFOptimizer.

        The optimizer is built with ``TFOptimizer.from_keras`` using
        ``validation_split`` as its ``val_split``. It is stored on
        ``self.tf_optimizer``. Training and validation summaries are attached
        only when they have been configured (i.e. are not None).

        :param dataset: Training data forwarded to ``TFOptimizer.from_keras``.
        :param validation_split: Passed to ``TFOptimizer.from_keras`` as ``val_split``.
        :param epochs: Number of epochs; training stops at ``MaxEpoch(epochs)``.
        :param kwargs: Extra keyword arguments forwarded to ``TFOptimizer.from_keras``.
        """
        optimizer = TFOptimizer.from_keras(self.model,
                                           dataset,
                                           val_split=validation_split,
                                           **kwargs)
        self.tf_optimizer = optimizer

        # Optional summaries: attach each one only if it has been set.
        train_summary = self.train_summary
        if train_summary is not None:
            optimizer.set_train_summary(train_summary)

        val_summary = self.val_summary
        if val_summary is not None:
            optimizer.set_val_summary(val_summary)

        optimizer.optimize(MaxEpoch(epochs))