Ejemplo n.º 1
0
def run_movie(flags_obj):
  """Build the MovieLens input pipeline and hand it to the wide/deep run loop.

  If requested, the dataset is downloaded first. The train/eval input
  functions and the model-column function come from
  `movielens_dataset.construct_input_fns`; everything is then forwarded to
  `wide_deep_run_loop.run_loop` with early stopping disabled.

  Args:
    flags_obj: Object containing user specified flags. This function reads
      `download_if_missing`, `dataset`, `data_dir`, `batch_size` and
      `epochs_between_evals`, and also passes the whole object to run_loop.
  """
  if flags_obj.download_if_missing:
    movielens.download(dataset=flags_obj.dataset, data_dir=flags_obj.data_dir)

  # construct_input_fns returns a 3-tuple; unpack it after the call.
  input_fns = movielens_dataset.construct_input_fns(
      dataset=flags_obj.dataset,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      repeat=flags_obj.epochs_between_evals)
  train_input_fn, eval_input_fn, model_column_fn = input_fns

  # Name of the loss tensor to log. The `{loss_prefix}` part is a plain
  # string (not an f-string); presumably run_loop fills it in — confirm there.
  loss_tensor_name = '{loss_prefix}head/weighted_loss/value'

  wide_deep_run_loop.run_loop(
      name="MovieLens",
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      model_column_fn=model_column_fn,
      build_estimator_fn=build_estimator,
      flags_obj=flags_obj,
      tensors_to_log={'loss': loss_tensor_name},
      early_stop=False)
Ejemplo n.º 2
0
  def test_input_fn(self):
    """Check the first ML-1M training example against known feature values.

    Builds the training input fn (batch_size=8, repeat=1), pulls one batch
    through a one-shot iterator, and compares the features of element 0 with
    TEST_INPUT_VALUES and its label with 1.0.
    """
    train_input_fn, _, _ = movielens_dataset.construct_input_fns(
        dataset=movielens.ML_1M, data_dir=self.temp_dir, batch_size=8, repeat=1)

    dataset = train_input_fn()
    features, labels = dataset.make_one_shot_iterator().get_next()

    # NOTE(review): newer TF releases deprecate test_session() in favour of
    # session()/cached_session(); left unchanged here to stay compatible with
    # whichever TF version this file targets.
    with self.test_session() as sess:
      # Evaluate the tensors into distinct names rather than rebinding them.
      feature_values, label_values = sess.run((features, labels))

      # Compare the two features dictionaries. assertIn reports the missing
      # key on failure, unlike assertTrue(key in ...).
      for key in TEST_INPUT_VALUES:
        self.assertIn(key, feature_values)
        # [0] presumably selects the first example of the batch.
        self.assertAllClose(TEST_INPUT_VALUES[key], feature_values[key][0])

      # (expected, actual) order, matching the feature comparison above.
      self.assertAllClose([1.0], label_values[0])
Ejemplo n.º 3
0
    def test_input_fn(self):
        """Check the first ML-1M training example against known feature values.

        Builds the training input fn (batch_size=8, repeat=1), pulls one batch
        through a one-shot iterator, and compares the features of element 0
        with TEST_INPUT_VALUES and its label with 1.0.
        """
        train_input_fn, _, _ = movielens_dataset.construct_input_fns(
            dataset=movielens.ML_1M,
            data_dir=self.temp_dir,
            batch_size=8,
            repeat=1)

        dataset = train_input_fn()
        features, labels = dataset.make_one_shot_iterator().get_next()

        with self.session() as sess:
            # Evaluate the tensors into distinct names rather than rebinding.
            feature_values, label_values = sess.run((features, labels))

            # Compare the two features dictionaries. assertIn reports the
            # missing key on failure, unlike assertTrue(key in ...).
            for key in TEST_INPUT_VALUES:
                self.assertIn(key, feature_values)
                # [0] presumably selects the first example of the batch.
                self.assertAllClose(TEST_INPUT_VALUES[key],
                                    feature_values[key][0])

            # (expected, actual) order, matching the feature comparison above.
            self.assertAllClose([1.0], label_values[0])