示例#1
0
def run_movie(flags_obj):
    """Build the MovieLens input/model functions and delegate to run_loop.

    Args:
        flags_obj: Object containing user specified flags. This function reads
            ``download_if_missing``, ``dataset``, ``data_dir``, ``batch_size``
            and ``epochs_between_evals`` from it, and forwards the whole object
            to ``wide_deep_run_loop.run_loop``.
    """
    # Fetch the raw dataset first, but only when the user asked for it.
    if flags_obj.download_if_missing:
        movielens.download(
            dataset=flags_obj.dataset,
            data_dir=flags_obj.data_dir,
        )

    # epochs_between_evals is handed to the input pipeline as its repeat count.
    input_fns = movielens_dataset.construct_input_fns(
        dataset=flags_obj.dataset,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        repeat=flags_obj.epochs_between_evals,
    )
    train_fn, eval_fn, column_fn = input_fns

    # '{loss_prefix}' is left as a literal placeholder here; presumably
    # run_loop substitutes it -- confirm in wide_deep_run_loop.
    log_spec = {'loss': '{loss_prefix}head/weighted_loss/value'}

    loop_kwargs = dict(
        name="MovieLens",
        train_input_fn=train_fn,
        eval_input_fn=eval_fn,
        model_column_fn=column_fn,
        build_estimator_fn=build_estimator,
        flags_obj=flags_obj,
        tensors_to_log=log_spec,
        early_stop=False,
    )
    wide_deep_run_loop.run_loop(**loop_kwargs)
示例#2
0
  def test_input_fn(self):
    """Check the first example of a training batch from the ML-1M input_fn.

    Builds the training input_fn over the data in ``self.temp_dir`` (batch
    size 8, one repeat), fetches a single batch, and checks that every key in
    TEST_INPUT_VALUES is present in the features and that its first row
    matches the expected value. It also checks that the first label is 1.0.
    """
    train_input_fn, _, _ = movielens_dataset.construct_input_fns(
        dataset=movielens.ML_1M, data_dir=self.temp_dir, batch_size=8, repeat=1)

    dataset = train_input_fn()
    # Keep the graph tensors under separate names from the fetched values
    # below, so it is clear which objects are symbolic and which are numpy.
    feature_tensors, label_tensor = dataset.make_one_shot_iterator().get_next()

    with self.session() as sess:
      features, labels = sess.run((feature_tensors, label_tensor))

      # Compare each expected feature value against the first example of
      # the batch. assertIn (unlike assertTrue(key in ...)) reports the
      # missing key on failure.
      for key, expected in TEST_INPUT_VALUES.items():
        self.assertIn(key, features)
        self.assertAllClose(expected, features[key][0])

      self.assertAllClose(labels[0], [1.0])