def _test_xception(self, img_size):
  """Run a tiny Xception model on random images and check the logits shape.

  Builds ``xception.xception_tiny`` hparams with an ``ImageModality`` input
  and a ``ClassLabelModality`` target. It feeds a batch of random 3-channel
  square images with random class labels, runs one TRAIN-mode forward pass,
  and asserts the logits have shape ``(batch_size, 1, 1, 1, vocab_size)``.

  Args:
    img_size: height and width of the square random input images.
  """
  vocab_size = 9
  batch_size = 3
  # np.random.random_integers is deprecated. randint's upper bound is
  # exclusive, so these still draw pixels from [0, 255] and labels from
  # [1, vocab_size - 1], the same ranges as before.
  x = np.random.randint(256, size=(batch_size, img_size, img_size, 3))
  y = np.random.randint(1, high=vocab_size, size=(batch_size, 1, 1, 1))
  hparams = xception.xception_tiny()
  p_hparams = problem_hparams.test_problem_hparams(
      vocab_size, vocab_size, hparams)
  p_hparams.input_modality["inputs"] = modalities.ImageModality(hparams)
  p_hparams.target_modality = modalities.ClassLabelModality(
      hparams, vocab_size)
  with self.test_session() as session:
    features = {
        "inputs": tf.constant(x, dtype=tf.int32),
        "targets": tf.constant(y, dtype=tf.int32),
    }
    model = xception.Xception(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
    logits, _ = model(features)
    session.run(tf.global_variables_initializer())
    res = session.run(logits)
    self.assertEqual(res.shape, (batch_size, 1, 1, 1, vocab_size))
def _testXception(self, img_size, output_size):
  """Check Xception logits shape on random images with the IMAGE modality.

  Builds ``xception.xception_tiny`` hparams whose ``inputs`` modality is
  ``registry.Modalities.IMAGE``. It feeds a batch of random 3-channel square
  images plus random targets and runs one TRAIN-mode ``model_fn`` pass. The
  sharded logits are concatenated along axis 0 before the shape check.

  Args:
    img_size: height and width of the square random input images.
    output_size: tuple of expected leading logits dimensions; the full
      expected shape is ``output_size + (1, vocab_size)``.
  """
  vocab_size = 9
  batch_size = 3
  # np.random.random_integers is deprecated. randint's upper bound is
  # exclusive, so these still draw pixels from [0, 255] and labels from
  # [1, vocab_size - 1], the same ranges as before.
  x = np.random.randint(256, size=(batch_size, img_size, img_size, 3))
  y = np.random.randint(1, high=vocab_size, size=(batch_size, 1, 1, 1))
  hparams = xception.xception_tiny()
  p_hparams = problem_hparams.test_problem_hparams(
      vocab_size, vocab_size)
  p_hparams.input_modality["inputs"] = (registry.Modalities.IMAGE, None)
  with self.test_session() as session:
    features = {
        "inputs": tf.constant(x, dtype=tf.int32),
        "targets": tf.constant(y, dtype=tf.int32),
    }
    model = xception.Xception(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
    sharded_logits, _ = model.model_fn(features)
    logits = tf.concat(sharded_logits, 0)
    session.run(tf.global_variables_initializer())
    res = session.run(logits)
    self.assertEqual(res.shape, output_size + (1, vocab_size))
def testXception(self):
  """Check Xception logits shape on random integer inputs.

  Uses ``test_problem_hparams`` without overriding any modality. It feeds
  random integer inputs of shape ``(3, 5, 1, 1)`` and targets of shape
  ``(3, 1, 1, 1)``, both drawn from ``[1, vocab_size - 1]``. It then runs one
  TRAIN-mode ``model_fn`` pass, concatenates the sharded logits along axis 0,
  and asserts the shape is ``(3, 5, 1, 1, vocab_size)``.
  """
  vocab_size = 9
  # np.random.random_integers is deprecated. randint's upper bound is
  # exclusive, so high=vocab_size keeps the old range [1, vocab_size - 1].
  x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1))
  y = np.random.randint(1, high=vocab_size, size=(3, 1, 1, 1))
  hparams = xception.xception_tiny()
  p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size)
  with self.test_session() as session:
    features = {
        "inputs": tf.constant(x, dtype=tf.int32),
        "targets": tf.constant(y, dtype=tf.int32),
    }
    model = xception.Xception(
        hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
    sharded_logits, _ = model.model_fn(features)
    logits = tf.concat(sharded_logits, 0)
    session.run(tf.global_variables_initializer())
    res = session.run(logits)
    self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size))
def _test_xception(self, img_size):
  """Forward a random image batch through a tiny Xception model.

  The problem hparams use the IMAGE modality for ``inputs`` and the
  CLASS_LABEL modality for ``targets``. After one TRAIN-mode call, the
  logits are checked to have shape ``(batch_size, 1, 1, 1, vocab_size)``.

  Args:
    img_size: side length of the square, 3-channel random input images.
  """
  vocab_size = 9
  batch_size = 3
  image_shape = (batch_size, img_size, img_size, 3)
  label_shape = (batch_size, 1, 1, 1)
  # Pixels come from [0, 256) and labels from [1, vocab_size).
  images = np.random.randint(256, size=image_shape)
  labels = np.random.randint(1, high=vocab_size, size=label_shape)

  hparams = xception.xception_tiny()
  p_hparams = problem_hparams.test_problem_hparams(
      vocab_size, vocab_size, hparams)
  p_hparams.modality["inputs"] = modalities.ModalityType.IMAGE
  p_hparams.modality["targets"] = modalities.ModalityType.CLASS_LABEL

  with self.test_session() as sess:
    feature_tensors = {
        "inputs": tf.constant(images, dtype=tf.int32),
        "targets": tf.constant(labels, dtype=tf.int32),
    }
    net = xception.Xception(
        hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
    logits, _ = net(feature_tensors)
    sess.run(tf.global_variables_initializer())
    logits_value = sess.run(logits)
    expected_shape = (batch_size, 1, 1, 1, vocab_size)
    self.assertEqual(logits_value.shape, expected_shape)