def testSliceNet(self):
  """Smoke-test a tiny SliceNet built through `model_fn`.

  Feeds a batch of 3 random integer inputs/targets through the model using
  the image_cifar10 problem hparams, concatenates the sharded logits along
  axis 0, and checks the evaluated logits have shape (3, 1, 1, 1, 10).
  """
  # np.random.random_integers is deprecated; randint's upper bound is
  # exclusive, so randint(256) / randint(10) draw from the same inclusive
  # [0, 255] / [0, 9] ranges as random_integers(0, high=255 / 9).
  x = np.random.randint(256, size=(3, 5, 4, 3))
  y = np.random.randint(10, size=(3, 5, 1, 1))
  hparams = slicenet.slicenet_params1_tiny()
  p_hparams = problem_hparams.image_cifar10(hparams)
  hparams.problems = [p_hparams]
  with self.test_session() as session:
    features = {
        "inputs": tf.constant(x, dtype=tf.int32),
        "targets": tf.constant(y, dtype=tf.int32),
        "target_space_id": tf.constant(1, dtype=tf.int32),
    }
    model = slicenet.SliceNet(hparams, p_hparams)
    # NOTE(review): the second positional argument `True` is presumably the
    # training flag of this SliceNet.model_fn version -- confirm.
    sharded_logits, _, _ = model.model_fn(features, True)
    # Merge the per-shard logits back into a single batch along axis 0.
    logits = tf.concat(sharded_logits, 0)
    session.run(tf.global_variables_initializer())
    res = session.run(logits)
    self.assertEqual(res.shape, (3, 1, 1, 1, 10))
def testSliceNet(self):
  """Check that a tiny SliceNet yields logits of the expected shape.

  Random integer inputs (values in [0, 255]) and targets (values in [0, 9])
  are run through the model in TRAIN mode with the image_cifar10 problem
  hparams; the evaluated logits must have shape (3, 1, 1, 1, 10).
  """
  batch_inputs = np.random.randint(256, size=(3, 5, 5, 3))
  batch_targets = np.random.randint(10, size=(3, 5, 1, 1))

  model_hparams = slicenet.slicenet_params1_tiny()
  model_hparams.add_hparam("data_dir", "")
  cifar_problem = registry.problem("image_cifar10")
  problem_hp = cifar_problem.get_hparams(model_hparams)
  model_hparams.problem_hparams = problem_hp

  with self.test_session() as sess:
    features = dict(
        inputs=tf.constant(batch_inputs, dtype=tf.int32),
        targets=tf.constant(batch_targets, dtype=tf.int32),
        target_space_id=tf.constant(1, dtype=tf.int32),
    )
    net = slicenet.SliceNet(
        model_hparams, tf.estimator.ModeKeys.TRAIN, problem_hp)
    # Calling the model returns a pair; only the logits are checked here.
    logits, _unused = net(features)
    sess.run(tf.global_variables_initializer())
    logits_value = sess.run(logits)
    expected_shape = (3, 1, 1, 1, 10)
    self.assertEqual(logits_value.shape, expected_shape)
def testSliceNet(self):
  """Smoke-test a tiny SliceNet built through `model_fn` in TRAIN mode.

  Feeds a batch of 3 random integer inputs/targets through the model using
  the image_cifar10 problem hparams, concatenates the sharded logits along
  axis 0, and checks the evaluated logits have shape (3, 1, 1, 1, 10).
  """
  # np.random.random_integers is deprecated; randint's upper bound is
  # exclusive, so randint(256) / randint(10) draw from the same inclusive
  # [0, 255] / [0, 9] ranges as random_integers(0, high=255 / 9).
  x = np.random.randint(256, size=(3, 5, 5, 3))
  y = np.random.randint(10, size=(3, 5, 1, 1))
  hparams = slicenet.slicenet_params1_tiny()
  hparams.add_hparam("data_dir", "")
  problem = registry.problem("image_cifar10")
  p_hparams = problem.internal_hparams(hparams)
  hparams.problems = [p_hparams]
  with self.test_session() as session:
    features = {
        "inputs": tf.constant(x, dtype=tf.int32),
        "targets": tf.constant(y, dtype=tf.int32),
        "target_space_id": tf.constant(1, dtype=tf.int32),
    }
    # tf.contrib.learn.ModeKeys is deprecated; tf.estimator.ModeKeys.TRAIN is
    # the same "train" string value.
    model = slicenet.SliceNet(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
    sharded_logits, _ = model.model_fn(features)
    # Merge the per-shard logits back into a single batch along axis 0.
    logits = tf.concat(sharded_logits, 0)
    session.run(tf.global_variables_initializer())
    res = session.run(logits)
    self.assertEqual(res.shape, (3, 1, 1, 1, 10))