def train(self, input_fn, steps=None):
    """Trains a model given training data `input_fn`.

    When `input_fn` returns a `TFDataset`, the model is trained through
    `TFOptimizer` inside a fresh graph and session, and a checkpoint is
    written under `self.estimator.model_dir`. For any other return value
    the call is delegated to `self.estimator.train`.

    :param input_fn: A function that constructs the input data for
        training. The function should construct and return one of the
        following:

        * A `TFDataset` object, each element of which is a tuple
          `(features, labels)`.
        * A `tf.data.Dataset` object: Outputs of `Dataset` object must be
          a tuple `(features, labels)` with same constraints as below.
        * A tuple `(features, labels)`: Where `features` is a `tf.Tensor`
          or a dictionary of string feature name to `Tensor` and `labels`
          is a `Tensor` or a dictionary of string label name to `Tensor`.
          Both `features` and `labels` are consumed by `model_fn`. They
          should satisfy the expectation of `model_fn` from inputs.
    :param steps: Number of steps for which to train the model. Must not
        be `None` when `input_fn` returns a `TFDataset`.
    :return: `self`, for chaining, when `input_fn` returns a `TFDataset`;
        otherwise the value returned by
        `self.estimator.train(input_fn, steps=steps)`.
    :raises ValueError: If `input_fn` returns a `TFDataset` without a batch
        size, or if `steps` is `None` on the `TFDataset` path.
    """
    with tf.Graph().as_default() as g:
        global_step_tensor = self.estimator._create_and_assert_global_step(g)
        # The global step is advanced by a fed-in amount after training.
        # NOTE(review): presumably TFOptimizer does not advance the TF global
        # step itself -- confirm.
        add_step_input = tf.placeholder(dtype=tf.int64, shape=())
        assign_step = tf.assign_add(global_step_tensor, add_step_input)

        # input_fn is invoked here to find out what kind of input it yields.
        result = self.estimator._call_input_fn(input_fn,
                                               tf.estimator.ModeKeys.TRAIN)
        if isinstance(result, TFDataset):
            if not result.has_batch:
                raise ValueError("The batch_size of TFDataset must be "
                                 "specified when used for training.")
            if steps is None:
                # Both MaxIteration(steps) and the int64 placeholder feed
                # below need a concrete number; fail before any session or
                # optimizer is created instead of failing obscurely later.
                raise ValueError("steps must be specified when training "
                                 "with a TFDataset.")

            spec = self._call_model_fn(result.feature_tensors,
                                       result.label_tensors,
                                       tf.estimator.ModeKeys.TRAIN,
                                       self.config)
            optim_method = TFOptimizer.to_bigdl_optim_method(
                koptim_method=self.optimizer)
            latest_checkpoint = self.estimator.latest_checkpoint()

            with tf.Session() as sess:
                saver = tf.train.Saver()
                # Resume from the latest checkpoint when there is one,
                # otherwise start from freshly initialized variables.
                if latest_checkpoint:
                    saver.restore(sess, latest_checkpoint)
                else:
                    sess.run(tf.global_variables_initializer())

                opt = TFOptimizer.from_loss(
                    spec.loss,
                    optim_method,
                    session=sess,
                    clip_norm=self.gradient_clipping_norm,
                    clip_value=self.gradient_clipping_constant)
                opt.optimize(MaxIteration(steps))

                # Record the completed steps in the global step, then save a
                # checkpoint tagged with the resulting step value.
                sess.run(assign_step, feed_dict={add_step_input: steps})
                final_step = sess.run(global_step_tensor)
                saver.save(sess,
                           self.estimator.model_dir + "/model",
                           global_step=final_step)
            return self

    # Not a TFDataset: hand the original input_fn to the wrapped estimator.
    return self.estimator.train(input_fn, steps=steps)
def train(self, input_fn, steps=None):
    """Trains a model given training data `input_fn`.

    When `input_fn` returns a `TFDataset`, the model is trained through
    `TFOptimizer` inside a fresh graph and session, and a checkpoint is
    written under `self.estimator.model_dir`. For any other return value
    the call is delegated to `self.estimator.train`.

    :param input_fn: A function that constructs the input data for
        training. It is called via `self.estimator._call_input_fn` in
        `TRAIN` mode; a `TFDataset` result selects the `TFOptimizer` path.
    :param steps: Number of steps for which to train the model. Must not
        be `None` when `input_fn` returns a `TFDataset`.
    :return: `self`, for chaining, when `input_fn` returns a `TFDataset`;
        otherwise the value returned by
        `self.estimator.train(input_fn, steps=steps)`.
    :raises ValueError: If `input_fn` returns a `TFDataset` without a batch
        size, or if `steps` is `None` on the `TFDataset` path.
    """
    with tf.Graph().as_default() as g:
        global_step_tensor = self.estimator._create_and_assert_global_step(g)
        # The global step is advanced by a fed-in amount after training.
        # NOTE(review): presumably TFOptimizer does not advance the TF global
        # step itself -- confirm.
        add_step_input = tf.placeholder(dtype=tf.int64, shape=())
        assign_step = tf.assign_add(global_step_tensor, add_step_input)

        # input_fn is invoked here to find out what kind of input it yields.
        result = self.estimator._call_input_fn(input_fn,
                                               tf.estimator.ModeKeys.TRAIN)
        if isinstance(result, TFDataset):
            if not result.has_batch:
                raise ValueError("The batch_size of TFDataset must be "
                                 "specified when used for training.")
            if steps is None:
                # Both MaxIteration(steps) and the int64 placeholder feed
                # below need a concrete number; fail before any session or
                # optimizer is created instead of failing obscurely later.
                raise ValueError("steps must be specified when training "
                                 "with a TFDataset.")

            spec = self._call_model_fn(result.feature_tensors,
                                       result.label_tensors,
                                       tf.estimator.ModeKeys.TRAIN,
                                       self.config)
            optim_method = TFOptimizer.to_bigdl_optim_method(
                koptim_method=self.optimizer)
            latest_checkpoint = self.estimator.latest_checkpoint()

            with tf.Session() as sess:
                saver = tf.train.Saver()
                # Resume from the latest checkpoint when there is one,
                # otherwise start from freshly initialized variables.
                if latest_checkpoint:
                    saver.restore(sess, latest_checkpoint)
                else:
                    sess.run(tf.global_variables_initializer())

                opt = TFOptimizer.from_loss(spec.loss, optim_method,
                                            session=sess)
                opt.optimize(MaxIteration(steps))

                # Record the completed steps in the global step, then save a
                # checkpoint tagged with the resulting step value.
                sess.run(assign_step, feed_dict={add_step_input: steps})
                final_step = sess.run(global_step_tensor)
                saver.save(sess,
                           self.estimator.model_dir + "/model",
                           global_step=final_step)
            return self

    # Not a TFDataset: hand the original input_fn to the wrapped estimator.
    return self.estimator.train(input_fn, steps=steps)