コード例 #1
0
ファイル: test_tfnet.py プロジェクト: zzti-bsj/analytics-zoo
 def test_tf_net_predict_dataset(self):
     tfnet_path = os.path.join(TestTF.resource_path, "tfnet")
     net = TFNet.from_export_folder(tfnet_path)
     dataset = TFDataset.from_ndarrays((np.random.rand(16, 4), ))
     output = net.predict(dataset)
     output = np.stack(output.collect())
     assert output.shape == (16, 2)
コード例 #2
0
ファイル: test_tfnet.py プロジェクト: zzti-bsj/analytics-zoo
 def test_tf_net_predict(self):
     tfnet_path = os.path.join(TestTF.resource_path, "tfnet")
     import tensorflow as tf
     tf_session_config = tf.ConfigProto(inter_op_parallelism_threads=1,
                                        intra_op_parallelism_threads=1)
     net = TFNet.from_export_folder(tfnet_path,
                                    tf_session_config=tf_session_config)
     output = net.predict(np.random.rand(16, 4),
                          batch_per_thread=5,
                          distributed=False)
     assert output.shape == (16, 2)
コード例 #3
0
ファイル: test_tfnet.py プロジェクト: zzti-bsj/analytics-zoo
 def test_init_tf_net(self):
     tfnet_path = os.path.join(TestTF.resource_path, "tfnet")
     net = TFNet.from_export_folder(tfnet_path)
     output = net.forward(np.random.rand(2, 4))
     assert output.shape == (2, 2)