def test_inference(self):
    """Build the inference graph on image/pose placeholders.

    Creates float32 placeholders sized from ``self.config['gqcnn_config']``
    inside the network's own graph, wires them through
    ``gqcnn.inference`` and asserts that something non-None comes back.
    No session is run; this only checks that graph construction succeeds.
    """
    gqcnn = NeuralNetWork(self.config)
    net_cfg = self.config['gqcnn_config']
    with gqcnn.graph.as_default():
        # Height before width: this matches the (batch, im_height, im_width,
        # im_channels) batches asserted in test_dataset and TensorFlow's
        # default NHWC layout. The previous order (width, height) only worked
        # when the configured image was square.
        image = tf.placeholder(
            tf.float32,
            [
                None,
                net_cfg['im_height'],
                net_cfg['im_width'],
                net_cfg['im_channels'],
            ],
            name='image_input',
        )
        pose = tf.placeholder(
            tf.float32, [None, net_cfg['pose_dim']], name='pose_input'
        )
        out = gqcnn.inference(image, pose)
        self.assertIsNotNone(out)
def test_dataset(self):
    """Pull one batch from the training iterator and check tensor shapes.

    The expected shapes are read from the ``'training'`` section of the
    test config. Images are checked as
    (train_batch_size, im_height, im_width, im_channels), poses as
    (train_batch_size, pose_len) and labels as (train_batch_size,).
    """
    train_cfg = self.config['training']
    net = NeuralNetWork(self.config, training=True)
    trainer = GQCNNTraing(self.config, net, DATA_PATH, OUT_PATH)
    with tf.Session(graph=trainer._network.graph) as sess:
        sess.run(trainer._train_iterator.initializer)
        next_batch = trainer._train_iterator.get_next()
        images, poses, labels = sess.run(next_batch)
        batch = train_cfg['train_batch_size']
        expectations = (
            (images, (batch, train_cfg['im_height'], train_cfg['im_width'],
                      train_cfg['im_channels'])),
            (poses, (batch, train_cfg['pose_len'])),
            (labels, (batch, )),
        )
        # Checked in the same order as before: images, poses, then labels.
        for array, expected_shape in expectations:
            self.assertEqual(array.shape, expected_shape)
def test_load_npz(self):
    """Load network weights from the npz file at SAVE_PATH.

    SAVE_PATH is written here first, by loading MODEL_PATH with
    ``remap=True`` and calling ``save_to_npz``. Previously this test relied
    on test_save having produced the file. unittest runs test methods in
    alphabetical order, though, so test_load_npz ran *before* test_save and
    failed on a clean output directory.
    """
    # Stage 1: produce SAVE_PATH in its own graph/session.
    exporter = NeuralNetWork(self.config)
    with tf.Session(graph=exporter.graph, config=self.gpu_config) as sess:
        exporter.load_weights(sess, MODEL_PATH, remap=True)
        exporter.save_to_npz(sess, SAVE_PATH)

    # Stage 2: the behaviour under test. Load SAVE_PATH into a fresh network.
    gqcnn = NeuralNetWork(self.config)
    with tf.Session(graph=gqcnn.graph, config=self.gpu_config) as sess:
        gqcnn.load_weights(sess, SAVE_PATH)
def test_save(self):
    """Load MODEL_PATH (with ``remap=True``) and export it via save_to_npz.

    Both calls share one session on the network's graph. The test passes
    if neither call raises; the written file at SAVE_PATH is not inspected.
    """
    net = NeuralNetWork(self.config)
    session_kwargs = dict(graph=net.graph, config=self.gpu_config)
    with tf.Session(**session_kwargs) as session:
        net.load_weights(session, MODEL_PATH, remap=True)
        net.save_to_npz(session, SAVE_PATH)
def test_load_ckpt(self):
    """Restore weights from MODEL_PATH (``remap=True``) into a new session.

    MODEL_PATH is presumably a TF checkpoint, given the test name; confirm.
    The test passes if ``load_weights`` raises nothing; no values are checked.
    """
    net = NeuralNetWork(self.config)
    session = tf.Session(graph=net.graph, config=self.gpu_config)
    with session as sess:
        net.load_weights(sess, MODEL_PATH, remap=True)
def test_create(self):
    """Constructing the network from the test config yields an object."""
    self.assertIsNotNone(NeuralNetWork(self.config))
def main():
    """Run a short training job against the test config and data.

    Sets up logging to TEST_LOG_FILE, loads TEST_CFG_FILE, builds a
    training-mode network and a GQCNNTraing over DATA_PATH/OUT_PATH, then
    calls ``optimize(50)``. The unit of 50 is presumably training
    iterations; confirm against GQCNNTraing.optimize.
    """
    config_logging(TEST_LOG_FILE)
    cfg = load_config(TEST_CFG_FILE)
    net = NeuralNetWork(cfg, training=True)
    trainer = GQCNNTraing(cfg, net, DATA_PATH, OUT_PATH)
    num_steps = 50
    trainer.optimize(num_steps)
def test_create(self):
    """Building a GQCNNTraing around a training-mode network yields an object.

    Only construction is checked here. A real optimisation run is driven
    by ``main()`` in this file rather than by this unit test. The dead,
    commented-out ``train.optimize(10)`` call has been removed.
    """
    network = NeuralNetWork(self.config, training=True)
    train = GQCNNTraing(self.config, network, DATA_PATH, OUT_PATH)
    self.assertIsNotNone(train)