示例#1
0
 def testMultiModel(self):
   """Check the shape of MultiModel logits (older model_fn API).

   Feeds random integer inputs and targets through a MultiModel configured
   with the tiny SliceNet hparams plus the image_cifar10 problem hparams,
   concatenates the returned logits shards, and asserts the result has
   shape (3, 1, 1, 1, 10).

   NOTE(review): another method named ``testMultiModel`` is defined later at
   the same indentation. If both live in the same class, that later
   definition rebinds the name and this one is shadowed (never collected by
   the test loader). Rename one of them if both are meant to run.
   """
   # np.random.random_integers is deprecated; randint's upper bound is
   # exclusive, so high=256 reproduces the original inclusive range [0, 255].
   x = np.random.randint(0, 256, size=(3, 5, 4, 3))
   # Inclusive range [0, 9], matching random_integers(0, high=9).
   y = np.random.randint(0, 10, size=(3, 5, 1, 1))
   hparams = slicenet.slicenet_params1_tiny()
   p_hparams = problem_hparams.image_cifar10(hparams)
   hparams.problems = [p_hparams]
   with self.test_session() as session:
     features = {
         "inputs": tf.constant(x, dtype=tf.int32),
         "targets": tf.constant(y, dtype=tf.int32),
         "target_space_id": tf.constant(1, dtype=tf.int32),
     }
     model = multimodel.MultiModel(hparams, p_hparams)
     # model_fn returns a 3-tuple here; only the logits (a list of shards)
     # are needed, and they are joined along axis 0.
     sharded_logits, _, _ = model.model_fn(features, True)
     logits = tf.concat(sharded_logits, 0)
     session.run(tf.global_variables_initializer())
     res = session.run(logits)
   self.assertEqual(res.shape, (3, 1, 1, 1, 10))
 def testMultiModel(self):
     """Check the shape of MultiModel logits (Estimator-mode API).

     Builds a MultiModel in TRAIN mode from ``multimodel_tiny`` hparams and
     the registered ``image_cifar10`` problem, runs it on random integer
     inputs and targets, and asserts the logits have shape (3, 1, 1, 1, 10).
     """
     # np.random.random_integers is deprecated; randint's upper bound is
     # exclusive, so these reproduce the original inclusive ranges
     # [0, 255] for inputs and [0, 9] for targets.
     x = np.random.randint(0, 256, size=(3, 5, 5, 3))
     y = np.random.randint(0, 10, size=(3, 5, 1, 1))
     hparams = multimodel.multimodel_tiny()
     # NOTE(review): data_dir is set to an empty string, presumably because
     # problem.get_hparams reads it; confirm against the Problem API.
     hparams.add_hparam("data_dir", "")
     problem = registry.problem("image_cifar10")
     p_hparams = problem.get_hparams(hparams)
     hparams.problems = [p_hparams]
     with self.test_session() as session:
         features = {
             "inputs": tf.constant(x, dtype=tf.int32),
             "targets": tf.constant(y, dtype=tf.int32),
             "target_space_id": tf.constant(1, dtype=tf.int32),
         }
         model = multimodel.MultiModel(hparams, tf.estimator.ModeKeys.TRAIN,
                                       p_hparams)
         # Calling the model returns a 2-tuple; only the logits are checked.
         logits, _ = model(features)
         session.run(tf.global_variables_initializer())
         res = session.run(logits)
     self.assertEqual(res.shape, (3, 1, 1, 1, 10))