def test_3d_shape(self):
    """Check the output shape of ToyNet on a batch of 3-D volumes.

    Feeds a (2, 32, 32, 32, 1) tensor of ones through a ``ToyNet`` built
    with ``num_classes=160`` (in training mode), evaluates it in a cached
    session after initialising global variables, and asserts the result
    has shape (2, 32, 32, 32, 160).
    """
    input_shape = (2, 32, 32, 32, 1)
    x = tf.ones(input_shape)
    toynet_instance = ToyNet(num_classes=160)
    # Keep the symbolic graph tensor and the evaluated array under separate
    # names rather than rebinding one variable to both.
    out_tensor = toynet_instance(x, is_training=True)
    with self.cached_session() as sess:
        sess.run(tf.global_variables_initializer())
        out_value = sess.run(out_tensor)
        # Shapes are integer tuples, so compare them exactly instead of
        # with a floating-point tolerance (assertAllClose).
        self.assertAllEqual((2, 32, 32, 32, 160), out_value.shape)
def get_test_network():
    """Build and return a new ``ToyNet`` instance with ``num_classes=4``."""
    return ToyNet(num_classes=4)