    def _test_complete_flow(self, train_input_fn, eval_input_fn,
                            predict_input_fn, input_dimension, label_dimension,
                            batch_size):
        feature_columns = [
            tf.feature_column.numeric_column('x', shape=(input_dimension,))
        ]
        est = linear.LinearEstimator(
            head=head_lib._regression_head(label_dimension=label_dimension),
            feature_columns=feature_columns,
            model_dir=self._model_dir)

        # Train
        num_steps = 10
        est.train(train_input_fn, steps=num_steps)

        # Evaluate
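        # The global step recorded in the evaluation results should equal the
        # number of training steps, and the metrics should include a loss.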
        scores = est.evaluate(eval_input_fn)
        self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])
        self.assertIn('loss', six.iterkeys(scores))

        # Predict
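        # Each prediction is a dict of output tensors; stack the regression
        # outputs and check that one label_dimension-sized vector is produced
        # per input example.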
        predictions = np.array([
            x[prediction_keys.PredictionKeys.PREDICTIONS]
            for x in est.predict(predict_input_fn)
        ])
        self.assertAllEqual((batch_size, label_dimension), predictions.shape)

        # Export
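        # Build a parsing spec from the feature columns and export a
        # SavedModel whose serving signature parses serialized tf.Example
        # protos.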
        feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(
            feature_columns)
        serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
            feature_spec)
        export_dir = est.export_saved_model(tempfile.mkdtemp(),
                                            serving_input_receiver_fn)
        self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))


def _linear_estimator_fn(weight_column=None, label_dimension=1, **kwargs):
    """Returns a LinearEstimator that uses regression_head."""
    return linear.LinearEstimator(
        head=head_lib._regression_head(
            weight_column=weight_column,
            label_dimension=label_dimension,
            # The tests in core (which this test inherits from) check the
            # sum-reduced loss.
            loss_reduction=tf.compat.v1.losses.Reduction.SUM),
        **kwargs)
    def _test_ckpt_converter(self, train_input_fn, eval_input_fn,
                             predict_input_fn, input_dimension,
                             label_dimension, batch_size, optimizer):

        # Create checkpoint in CannedEstimator v1.
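        # Note that the v1 estimator is built with the legacy feature column
        # API and the private regression head.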
        feature_columns_v1 = [
            feature_column._numeric_column('x', shape=(input_dimension,))
        ]

        est_v1 = linear.LinearEstimator(
            head=head_lib._regression_head(label_dimension=label_dimension),
            feature_columns=feature_columns_v1,
            model_dir=self._old_ckpt_dir,
            optimizer=optimizer)
        # Train
        num_steps = 10
        est_v1.train(train_input_fn, steps=num_steps)
        self.assertIsNotNone(est_v1.latest_checkpoint())
        self.assertTrue(est_v1.latest_checkpoint().startswith(
            self._old_ckpt_dir))

        # Convert checkpoint from v1 to v2.
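        # The 'model.ckpt-10' filename corresponds to the global step reached
        # above (num_steps). The converter also reads the exported v1
        # graph.pbtxt alongside the checkpoint.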
        source_checkpoint = os.path.join(self._old_ckpt_dir, 'model.ckpt-10')
        source_graph = os.path.join(self._old_ckpt_dir, 'graph.pbtxt')
        target_checkpoint = os.path.join(self._new_ckpt_dir, 'model.ckpt-10')
        checkpoint_converter.convert_checkpoint('linear', source_checkpoint,
                                                source_graph,
                                                target_checkpoint)

        # Create CannedEstimator V2 and restore from the converted checkpoint.
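        # The V2 estimator uses the public RegressionHead and V2 feature
        # columns, but points at the directory containing the converted
        # checkpoint.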
        feature_columns_v2 = [
            tf.feature_column.numeric_column('x', shape=(input_dimension,))
        ]
        est_v2 = linear.LinearEstimatorV2(
            head=regression_head.RegressionHead(
                label_dimension=label_dimension),
            feature_columns=feature_columns_v2,
            model_dir=self._new_ckpt_dir,
            optimizer=optimizer)
        # Train
        extra_steps = 10
        est_v2.train(train_input_fn, steps=extra_steps)
        self.assertIsNotNone(est_v2.latest_checkpoint())
        self.assertTrue(est_v2.latest_checkpoint().startswith(
            self._new_ckpt_dir))
        # Make sure estimator v2 restores from the converted checkpoint and
        # continues training for the extra steps, so the final global step is
        # num_steps + extra_steps.
        self.assertEqual(
            num_steps + extra_steps,
            est_v2.get_variable_value(tf.compat.v1.GraphKeys.GLOBAL_STEP))