def testLinearRegression(self):
    """Checks that LinearRegressor training is reproducible under a fixed seed.

    Two regressors are each fit for one step on the Boston data set. Each one
    lives in its own fresh graph, seeded identically through Python's
    ``random`` module, the graph-level seed and ``RunConfig.tf_random_seed``.
    Their weights, biases and predictions must then agree.
    """
    my_seed = 42
    config = run_config.RunConfig(tf_random_seed=my_seed)
    boston = base.load_boston()
    # NOTE(review): dimension=13 presumably matches the Boston feature count.
    columns = [feature_column.real_valued_column('', dimension=13)]

    def _fit_seeded_regressor():
      """Builds and fits one regressor in a fresh, identically seeded graph."""
      with ops.Graph().as_default() as g:
        random.seed(my_seed)
        g.seed = my_seed
        variables.create_global_step()
        regressor = linear.LinearRegressor(
            optimizer=_NULL_OPTIMIZER, feature_columns=columns, config=config)
        regressor.fit(x=boston.data, y=boston.target, steps=1)
      return regressor

    # We train with identical seeds, data, config and optimizer in two
    # independent graphs, so any mismatch below points at nondeterminism.
    regressor1 = _fit_seeded_regressor()
    regressor2 = _fit_seeded_regressor()

    self.assertAllClose(regressor1.weights_, regressor2.weights_)
    self.assertAllClose(regressor1.bias_, regressor2.bias_)
    self.assertAllClose(
        list(regressor1.predict(
            boston.data, as_iterable=True)),
        list(regressor2.predict(
            boston.data, as_iterable=True)),
        atol=1e-05)
Example No. 2
0
def _build_estimator_for_export_tests(tmpdir):
  """Builds a fitted LinearRegressor plus a serving input fn for export tests.

  The estimator is fit for 20 steps on the iris data. The returned serving
  input fn wraps a parsing serving input fn and additionally wires a lookup
  table backed by a vocab file into the graph, so that asset export can be
  exercised.

  Args:
    tmpdir: Directory in which the vocab file ``my_vocab_file`` is (re)written
      every time the returned serving input fn is called.

  Returns:
    A ``(estimator, serving_input_fn_with_asset)`` tuple.
  """

  def _input_fn():
    """Returns the iris data (150 examples) as constant feature/label tensors."""
    iris = base.load_iris()
    return {
        'feature': constant_op.constant(
            iris.data, dtype=dtypes.float32)
    }, constant_op.constant(
        iris.target, shape=[150], dtype=dtypes.int32)

  feature_columns = [
      feature_column_lib.real_valued_column(
          'feature', dimension=4)
  ]

  est = linear.LinearRegressor(feature_columns)
  est.fit(input_fn=_input_fn, steps=20)

  feature_spec = feature_column_lib.create_feature_spec_for_parsing(
      feature_columns)
  serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)

  # Hack in an op that uses an asset, in order to test asset export.
  # This is not actually valid, of course.
  def serving_input_fn_with_asset():
    """Parsing serving input fn augmented with an asset-backed lookup op."""
    features, labels, inputs = serving_input_fn()

    vocab_file_name = os.path.join(tmpdir, 'my_vocab_file')
    # Context manager guarantees the handle is closed even if write() raises.
    with gfile.GFile(vocab_file_name, mode='w') as vocab_file:
      vocab_file.write(VOCAB_FILE_CONTENT)
    hashtable = lookup.HashTable(
        lookup.TextFileStringTableInitializer(vocab_file_name), 'x')
    # Casting the float feature to int64 keys is deliberately bogus; the op
    # only exists so the vocab file is registered as an asset.
    features['bogus_lookup'] = hashtable.lookup(
        math_ops.to_int64(features['feature']))

    return input_fn_utils.InputFnOps(features, labels, inputs)

  return est, serving_input_fn_with_asset