Exemplo n.º 1
0
 def _test_xception(self, img_size):
     """Run a tiny Xception forward pass on random images and check logit shape.

     Random integer pixels shaped (batch_size, img_size, img_size, 3) are fed as
     an ImageModality input, and random class labels as a ClassLabelModality
     target. The model is built in TRAIN mode from `xception.xception_tiny()`
     hparams. The resulting logits must have shape
     (batch_size, 1, 1, 1, vocab_size).

     Args:
       img_size: height and width of the random square input images.
     """
     vocab_size = 9
     batch_size = 3
     # np.random.random_integers is deprecated (since NumPy 1.11). randint's
     # upper bound is exclusive, so these draw pixels in [0, 255] and labels in
     # [1, vocab_size - 1].
     x = np.random.randint(256, size=(batch_size, img_size, img_size, 3))
     y = np.random.randint(1, high=vocab_size, size=(batch_size, 1, 1, 1))
     hparams = xception.xception_tiny()
     p_hparams = problem_hparams.test_problem_hparams(
         vocab_size, vocab_size, hparams)
     # Override the default test modalities: image inputs, class-label targets.
     p_hparams.input_modality["inputs"] = modalities.ImageModality(hparams)
     p_hparams.target_modality = modalities.ClassLabelModality(
         hparams, vocab_size)
     with self.test_session() as session:
         features = {
             "inputs": tf.constant(x, dtype=tf.int32),
             "targets": tf.constant(y, dtype=tf.int32),
         }
         model = xception.Xception(hparams, tf.estimator.ModeKeys.TRAIN,
                                   p_hparams)
         logits, _ = model(features)
         session.run(tf.global_variables_initializer())
         res = session.run(logits)
     self.assertEqual(res.shape, (batch_size, 1, 1, 1, vocab_size))
Exemplo n.º 2
0
 def _testXception(self, img_size, output_size):
     """Run a tiny Xception `model_fn` on random images and check logit shape.

     Random integer pixels shaped (batch_size, img_size, img_size, 3) are used as
     image inputs, and random labels shaped (batch_size, 1, 1, 1) as targets. The
     model is built in TRAIN mode from `xception.xception_tiny()` hparams. The
     per-shard logits returned by `model_fn` are concatenated along axis 0.

     Args:
       img_size: height and width of the random square input images.
       output_size: tuple of expected leading logit dimensions; the test
         appends (1, vocab_size) to it before comparing.
     """
     vocab_size = 9
     batch_size = 3
     # np.random.random_integers is deprecated (since NumPy 1.11). randint's
     # upper bound is exclusive, so these draw pixels in [0, 255] and labels in
     # [1, vocab_size - 1].
     x = np.random.randint(256, size=(batch_size, img_size, img_size, 3))
     y = np.random.randint(1, high=vocab_size, size=(batch_size, 1, 1, 1))
     hparams = xception.xception_tiny()
     p_hparams = problem_hparams.test_problem_hparams(
         vocab_size, vocab_size)
     # Declare the inputs as IMAGE modality instead of the test default.
     p_hparams.input_modality["inputs"] = (registry.Modalities.IMAGE, None)
     with self.test_session() as session:
         features = {
             "inputs": tf.constant(x, dtype=tf.int32),
             "targets": tf.constant(y, dtype=tf.int32),
         }
         model = xception.Xception(hparams, tf.estimator.ModeKeys.TRAIN,
                                   p_hparams)
         sharded_logits, _ = model.model_fn(features)
         # model_fn returns a list of per-shard logits; join them on axis 0.
         logits = tf.concat(sharded_logits, 0)
         session.run(tf.global_variables_initializer())
         res = session.run(logits)
     self.assertEqual(res.shape, output_size + (1, vocab_size))
Exemplo n.º 3
0
 def testXception(self):
   """Run a tiny Xception `model_fn` on random integer inputs and check shape.

   Inputs and targets are random ids in [1, vocab_size - 1], using the
   default modalities of `problem_hparams.test_problem_hparams`. The model is
   built in TRAIN mode from `xception.xception_tiny()` hparams. The
   concatenated logits must have shape (3, 5, 1, 1, vocab_size).
   """
   vocab_size = 9
   batch_size = 3
   length = 5
   # np.random.random_integers is deprecated (since NumPy 1.11). randint's
   # upper bound is exclusive, so high=vocab_size yields ids in
   # [1, vocab_size - 1].
   x = np.random.randint(1, high=vocab_size, size=(batch_size, length, 1, 1))
   y = np.random.randint(1, high=vocab_size, size=(batch_size, 1, 1, 1))
   hparams = xception.xception_tiny()
   p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size)
   with self.test_session() as session:
     features = {
         "inputs": tf.constant(x, dtype=tf.int32),
         "targets": tf.constant(y, dtype=tf.int32),
     }
     model = xception.Xception(
         hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
     sharded_logits, _ = model.model_fn(features)
     # model_fn returns a list of per-shard logits; join them on axis 0.
     logits = tf.concat(sharded_logits, 0)
     session.run(tf.global_variables_initializer())
     res = session.run(logits)
   self.assertEqual(res.shape, (batch_size, length, 1, 1, vocab_size))
Exemplo n.º 4
0
 def _test_xception(self, img_size):
     """Check the logit shape of one TRAIN-mode pass through a tiny Xception.

     Random integer pixels in [0, 255], shaped (3, img_size, img_size, 3), are
     used as IMAGE inputs. Random labels in [1, 8], shaped (3, 1, 1, 1), are
     used as CLASS_LABEL targets. The model uses `xception.xception_tiny()`
     hparams, and its logits must come out shaped (3, 1, 1, 1, 9).

     Args:
       img_size: height and width of the random square input images.
     """
     num_classes = 9
     num_examples = 3
     image_shape = (num_examples, img_size, img_size, 3)
     label_shape = (num_examples, 1, 1, 1)
     pixels = np.random.randint(256, size=image_shape)
     labels = np.random.randint(1, high=num_classes, size=label_shape)

     model_hparams = xception.xception_tiny()
     problem = problem_hparams.test_problem_hparams(
         num_classes, num_classes, model_hparams)
     # Image inputs, class-label targets instead of the test defaults.
     problem.modality["inputs"] = modalities.ModalityType.IMAGE
     problem.modality["targets"] = modalities.ModalityType.CLASS_LABEL

     with self.test_session() as sess:
         inputs_t = tf.constant(pixels, dtype=tf.int32)
         targets_t = tf.constant(labels, dtype=tf.int32)
         net = xception.Xception(
             model_hparams, tf_estimator.ModeKeys.TRAIN, problem)
         logits, _ = net({"inputs": inputs_t, "targets": targets_t})
         sess.run(tf.global_variables_initializer())
         output = sess.run(logits)

     expected_shape = (num_examples, 1, 1, 1, num_classes)
     self.assertEqual(output.shape, expected_shape)