Exemplo n.º 1
0
    def testModel(self):
        """Run the "transformer" model once on tiny_algo and check output shapes.

        Builds "transformer_tiny" hparams, feeds one padded batch of 10
        training examples through the model in TRAIN mode, and asserts that a
        "training" loss is returned, that the logits have shape
        [10, *, 1, 1, 4] (axis 1 is not checked), and that the loss is a
        scalar.
        """
        # HParams
        hparams = trainer_lib.create_hparams(
            "transformer_tiny",
            data_dir=algorithmic.TinyAlgo.data_dir,
            problem_name="tiny_algo")

        # Dataset
        problem = hparams.problem
        dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN,
                                  algorithmic.TinyAlgo.data_dir)
        # Batches of 10, padded to the dataset's own output shapes.
        dataset = dataset.repeat(None).padded_batch(10, dataset.output_shapes)
        features = dataset.make_one_shot_iterator().get_next()
        features = problem_lib.standardize_shapes(features)

        # Model
        model = registry.model("transformer")(hparams,
                                              tf.estimator.ModeKeys.TRAIN)
        logits, losses = model(features)

        # assertIn reports the missing key and the container on failure,
        # unlike assertTrue(... in ...), which only says "False is not true".
        self.assertIn("training", losses)
        loss = losses["training"]

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            logits_val, loss_val = sess.run([logits, loss])
            logits_shape = list(logits_val.shape)
            # Axis 1 is masked out of the comparison; presumably it is the
            # padded sequence length, which varies per batch -- TODO confirm.
            logits_shape[1] = None
            self.assertAllEqual(logits_shape, [10, None, 1, 1, 4])
            # An empty shape tuple means the loss is a scalar.
            self.assertEqual(loss_val.shape, ())
Exemplo n.º 2
0
  def testModel(self):
    """Run the "transformer" model once on tiny_algo and check output shapes.

    Builds "transformer_tiny" hparams from `self.data_dir`, feeds one padded
    batch of 10 training examples through the model in TRAIN mode, and asserts
    that a "training" loss is returned, that the logits have shape
    [10, *, 1, 1, 4] (axis 1 is not checked), and that the loss is a scalar.
    """
    # HParams
    hparams = trainer_lib.create_hparams(
        "transformer_tiny", data_dir=self.data_dir, problem_name="tiny_algo")

    # Dataset
    problem = hparams.problem
    dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, self.data_dir)
    # Batches of 10, padded to the dataset's own output shapes.
    dataset = dataset.repeat(None).padded_batch(10, dataset.output_shapes)
    features = dataset.make_one_shot_iterator().get_next()
    features = problem_lib.standardize_shapes(features)

    # Model
    model = registry.model("transformer")(hparams, tf.estimator.ModeKeys.TRAIN)
    logits, losses = model(features)

    # assertIn reports the missing key and the container on failure, unlike
    # assertTrue(... in ...), which only says "False is not true".
    self.assertIn("training", losses)
    loss = losses["training"]

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      logits_val, loss_val = sess.run([logits, loss])
      logits_shape = list(logits_val.shape)
      # Axis 1 is masked out of the comparison; presumably it is the padded
      # sequence length, which varies per batch -- TODO confirm.
      logits_shape[1] = None
      self.assertAllEqual(logits_shape, [10, None, 1, 1, 4])
      # An empty shape tuple means the loss is a scalar.
      self.assertEqual(loss_val.shape, ())
Exemplo n.º 3
0
    def testMultipleTargetModalities(self):
        """Check that the model builds and runs with three target modalities.

        Replaces the problem's single target modality with a dict keyed by
        "targets", "A" and "B" (all sharing the original modality), duplicates
        the "targets" feature under "A" and "B", wraps the model body so it
        returns one output per key, and then runs logits and the "training"
        loss once in a session. Only successful execution is verified; no
        values are compared.
        """
        # Use existing hparams and override target modality.
        hparams = trainer_lib.create_hparams(
            "transformer_tiny",
            data_dir=algorithmic.TinyAlgo.data_dir,
            problem_name="tiny_algo")
        # Manually turn off sharing. It is not currently supported for multitargets.
        hparams.shared_embedding_and_softmax_weights = 0  # pylint: disable=line-too-long
        hparams.problem_hparams.target_modality = {
            "targets": hparams.problem_hparams.target_modality,
            "A": hparams.problem_hparams.target_modality,
            "B": hparams.problem_hparams.target_modality,
        }
        # Presumably makes the problem object hand out the overridden
        # problem hparams -- TODO confirm against Problem.get_hparams.
        hparams.problem._hparams = hparams.problem_hparams

        # Dataset
        problem = hparams.problem
        dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN,
                                  algorithmic.TinyAlgo.data_dir)
        dataset = dataset.repeat(None).padded_batch(10, dataset.output_shapes)
        features = dataset.make_one_shot_iterator().get_next()
        features = problem_lib.standardize_shapes(features)
        # The extra targets reuse the very same tensor as "targets".
        features["A"] = features["B"] = features["targets"]

        # Model
        model = registry.model("transformer")(hparams,
                                              tf.estimator.ModeKeys.TRAIN)

        # The default argument captures the original model.body before it is
        # replaced below, so the wrapper does not call itself.
        def body(args, mb=model.body):
            out = mb(args)
            return {"targets": out, "A": out, "B": out}

        model.body = body

        logits, losses = model(features)

        # assertIn reports the missing key and the container on failure,
        # unlike assertTrue(... in ...), which only says "False is not true".
        self.assertIn("training", losses)
        loss = losses["training"]

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run([logits, loss])
Exemplo n.º 4
0
  def testMultipleTargetModalities(self):
    """Check that the model builds and runs with three target modalities.

    Replaces the problem's modality dict with one keyed by "targets",
    "targets_A" and "targets_B" (all sharing the original "targets" modality),
    duplicates the "targets" feature under the two new keys, wraps the model
    body so it returns one output per key, and then runs logits and the
    "training" loss once in a session. Only successful execution is verified;
    no values are compared.
    """
    # Use existing hparams and override target modality.
    hparams = trainer_lib.create_hparams(
        "transformer_tiny", data_dir=algorithmic.TinyAlgo.data_dir,
        problem_name="tiny_algo")
    # Manually turn off sharing. It is not currently supported for multitargets.
    hparams.shared_embedding_and_softmax_weights = 0  # pylint: disable=line-too-long
    hparams.problem_hparams.modality = {
        "targets": hparams.problem_hparams.modality["targets"],
        "targets_A": hparams.problem_hparams.modality["targets"],
        "targets_B": hparams.problem_hparams.modality["targets"],
    }
    # Presumably makes the problem object hand out the overridden problem
    # hparams -- TODO confirm against Problem.get_hparams.
    hparams.problem._hparams = hparams.problem_hparams

    # Dataset
    problem = hparams.problem
    dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN,
                              algorithmic.TinyAlgo.data_dir)
    dataset = dataset.repeat(None).padded_batch(10, dataset.output_shapes)
    features = dataset.make_one_shot_iterator().get_next()
    features = problem_lib.standardize_shapes(features)
    # The extra targets reuse the very same tensor as "targets".
    features["targets_A"] = features["targets_B"] = features["targets"]

    # Model
    model = registry.model("transformer")(hparams, tf.estimator.ModeKeys.TRAIN)

    # The default argument captures the original model.body before it is
    # replaced below, so the wrapper does not call itself.
    def body(args, mb=model.body):
      out = mb(args)
      return {"targets": out, "targets_A": out, "targets_B": out}

    model.body = body

    logits, losses = model(features)

    # assertIn reports the missing key and the container on failure, unlike
    # assertTrue(... in ...), which only says "False is not true".
    self.assertIn("training", losses)
    loss = losses["training"]

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run([logits, loss])
Exemplo n.º 5
0
    def testMultipleTargetModalities(self):
        """Check that the model builds and runs with three target modalities.

        Replaces the target modality on the problem's hparams with a dict
        keyed by "targets", "A" and "B" (all sharing the original modality),
        duplicates the "targets" feature under "A" and "B", wraps the model
        body so it returns one output per key, and then runs logits and the
        "training" loss once in a session. Only successful execution is
        verified; no values are compared.
        """
        # HParams
        hparams = trainer_lib.create_hparams(
            "transformer_tiny",
            data_dir=algorithmic.TinyAlgo.data_dir,
            problem_name="tiny_algo")
        tm = hparams.problem.get_hparams().target_modality
        # NOTE(review): this only takes effect if get_hparams() returns the
        # same object on later calls -- confirm it caches its result.
        hparams.problem.get_hparams().target_modality = {
            "targets": tm,
            "A": tm,
            "B": tm
        }

        # Dataset
        problem = hparams.problem
        dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN,
                                  algorithmic.TinyAlgo.data_dir)
        dataset = dataset.repeat(None).padded_batch(10, dataset.output_shapes)
        features = dataset.make_one_shot_iterator().get_next()
        features = problem_lib.standardize_shapes(features)
        # The extra targets reuse the very same tensor as "targets".
        features["A"] = features["B"] = features["targets"]

        # Model
        model = registry.model("transformer")(hparams,
                                              tf.estimator.ModeKeys.TRAIN)

        # The default argument captures the original model.body before it is
        # replaced below, so the wrapper does not call itself.
        def body(args, mb=model.body):
            out = mb(args)
            return {"targets": out, "A": out, "B": out}

        model.body = body

        logits, losses = model(features)

        # assertIn reports the missing key and the container on failure,
        # unlike assertTrue(... in ...), which only says "False is not true".
        self.assertIn("training", losses)
        loss = losses["training"]

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run([logits, loss])
Exemplo n.º 6
0
  def testMultipleTargetModalities(self):
    """Check that the model builds and runs with three target modalities.

    Replaces the target modality on the problem's hparams with a dict keyed by
    "targets", "A" and "B" (all sharing the original modality), duplicates the
    "targets" feature under "A" and "B", wraps the model body so it returns
    one output per key, and then runs logits and the "training" loss once in a
    session. Only successful execution is verified; no values are compared.
    """
    # HParams
    hparams = trainer_lib.create_hparams(
        "transformer_tiny", data_dir=self.data_dir, problem_name="tiny_algo")
    tm = hparams.problem.get_hparams().target_modality
    # NOTE(review): this only takes effect if get_hparams() returns the same
    # object on later calls -- confirm it caches its result.
    hparams.problem.get_hparams().target_modality = {
        "targets": tm,
        "A": tm,
        "B": tm
    }

    # Dataset
    problem = hparams.problem
    dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, self.data_dir)
    dataset = dataset.repeat(None).padded_batch(10, dataset.output_shapes)
    features = dataset.make_one_shot_iterator().get_next()
    features = problem_lib.standardize_shapes(features)
    # The extra targets reuse the very same tensor as "targets".
    features["A"] = features["B"] = features["targets"]

    # Model
    model = registry.model("transformer")(hparams, tf.estimator.ModeKeys.TRAIN)

    # The default argument captures the original model.body before it is
    # replaced below, so the wrapper does not call itself.
    def body(args, mb=model.body):
      out = mb(args)
      return {"targets": out, "A": out, "B": out}

    model.body = body

    logits, losses = model(features)

    # assertIn reports the missing key and the container on failure, unlike
    # assertTrue(... in ...), which only says "False is not true".
    self.assertIn("training", losses)
    loss = losses["training"]

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run([logits, loss])