Exemple #1
0
  def setUpClass(cls):
    """Recreate the TF test temp directory and generate "tiny_algo" data."""
    scratch_dir = tf.test.get_temp_dir()
    # Remove and recreate the directory so it is empty before data generation.
    shutil.rmtree(scratch_dir)
    os.mkdir(scratch_dir)

    # The problem name is published through FLAGS and the directory through the
    # test class attribute before the registry lookup happens.
    FLAGS.problems = "tiny_algo"
    TrainerUtilsTest.data_dir = scratch_dir
    tiny_problem = registry.problem(FLAGS.problems)
    tiny_problem.generate_data(TrainerUtilsTest.data_dir, None)
Exemple #2
0
  def testSingleEvalStepRawSession(self):
    """Run a single EVAL step of a T2T model inside a plain session."""

    # Mimic the settings that would normally come from the command line.
    name_of_model = "transformer"
    FLAGS.hparams_set = "transformer_test"
    FLAGS.problems = "tiny_algo"
    # Only consulted when a vocab file (or similar) has to be located.
    data_dir = "/tmp"

    # Problem encoders and hparams (with the problem hparams attached).
    encoders = registry.problem(FLAGS.problems).feature_encoders(data_dir)
    hparams = trainer_utils.create_hparams(FLAGS.hparams_set, data_dir)
    trainer_utils.add_problem_hparams(hparams, FLAGS.problems)

    # Placeholders carry only the length dimension; reshape each one to 4D.
    inputs_ph = tf.placeholder(dtype=tf.int32)
    inputs_4d = tf.reshape(inputs_ph, [1, -1, 1, 1])
    # Targets are fed in this test (the original note says that in INFER mode
    # they can be None).
    targets_ph = tf.placeholder(dtype=tf.int32)
    targets_4d = tf.reshape(targets_ph, [1, -1, 1, 1])

    # Everything below refers to the first (and only) configured problem.
    first_problem = hparams.problems[0]
    features = {
        "inputs": inputs_4d,
        "targets": targets_4d,
        "problem_choice": 0,
        "input_space_id": first_problem.input_space_id,
        "target_space_id": first_problem.target_space_id
    }

    # Build the graph for EVAL mode by calling model_fn directly.
    eval_mode = tf.estimator.ModeKeys.EVAL
    spec = model_builder.model_fn(
        name_of_model,
        features,
        eval_mode,
        hparams,
        problem_names=[FLAGS.problems])
    # The data are not images, so axes 2 and 3 are squeezed away.
    predictions = tf.squeeze(spec.predictions["predictions"], axis=[2, 3])

    # With the graph in place, push some data through it.
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      # Raw strings are turned into model inputs by the problem's encoders.
      feed = {
          inputs_ph: encoders["inputs"].encode("0 1 0"),
          targets_ph: encoders["targets"].encode("0 1 0"),
      }
      np_predictions = sess.run(predictions, feed)
      # Expected shape is batch x length x vocab_size, i.e. 1 x 3 x 4 here.
      self.assertEqual(np_predictions.shape, (1, 3, 4))
Exemple #3
0
 def testSliceNet(self):
     """Build SliceNet in TRAIN mode on random data and check logits shape."""
     # np.random.random_integers is deprecated; randint is its replacement.
     # randint's upper bound is exclusive, so 256 and 10 keep the same
     # inclusive value ranges (0..255 and 0..9) as before.
     x = np.random.randint(0, high=256, size=(3, 5, 5, 3))
     y = np.random.randint(0, high=10, size=(3, 5, 1, 1))
     hparams = slicenet.slicenet_params1_tiny()
     hparams.add_hparam("data_dir", "")
     problem = registry.problem("image_cifar10")
     p_hparams = problem.get_hparams(hparams)
     hparams.problems = [p_hparams]
     with self.test_session() as session:
         features = {
             "inputs": tf.constant(x, dtype=tf.int32),
             "targets": tf.constant(y, dtype=tf.int32),
             "target_space_id": tf.constant(1, dtype=tf.int32),
         }
         model = slicenet.SliceNet(hparams, tf.estimator.ModeKeys.TRAIN,
                                   p_hparams)
         # model_fn returns sharded logits plus a second value unused here.
         sharded_logits, _ = model.model_fn(features)
         # Join the shards along axis 0 into a single logits tensor.
         logits = tf.concat(sharded_logits, 0)
         session.run(tf.global_variables_initializer())
         res = session.run(logits)
     self.assertEqual(res.shape, (3, 1, 1, 1, 10))
Exemple #4
0
def problem(name):
    """Return whatever `registry.problem` yields for `name`."""
    looked_up = registry.problem(name)
    return looked_up
Exemple #5
0
 def setUpClass(cls):
     """Seed TF, look up "test_problem" and generate its test data files."""
     tf.set_random_seed(1)
     test_problem = registry.problem("test_problem")
     tmp_root = tempfile.gettempdir()
     # Publish the problem and directory on the class before generating data.
     cls.problem, cls.data_dir = test_problem, tmp_root
     cls.filepatterns = generate_test_data(test_problem, tmp_root)