Exemplo n.º 1
0
 def test_tf_net(self):
     """Load the exported TFNet and check forward() on a random batch.

     The export folder is ``<this file's dir>/../../../resources/tfnet``;
     a batch of 32 random 28x28x1 inputs must yield a (32, 10) output.
     """
     here = os.path.dirname(__file__)
     export_dir = os.path.join(here, "../../../resources", "tfnet")
     model = TFNet.from_export_folder(export_dir)
     batch = np.random.rand(32, 28, 28, 1)
     result = model.forward(batch)
     assert result.shape == (32, 10)
Exemplo n.º 2
0
 def test_tf_net_predict_dataset(self):
     """Predict over a TFDataset of random features and check the result shape.

     Sixteen random 4-feature rows are wrapped in a TFDataset. The collected
     predictions are stacked into one array, which must have shape (16, 2).
     """
     export_dir = os.path.join(os.path.dirname(__file__), "../../resources", "tfnet")
     model = TFNet.from_export_folder(export_dir)
     features = np.random.rand(16, 4)
     dataset = TFDataset.from_ndarrays((features,))
     predictions = model.predict(dataset).collect()
     stacked = np.stack(predictions)
     assert stacked.shape == (16, 2)
Exemplo n.º 3
0
 def test_tf_net_predict(self):
     """Run local (non-distributed) TFNet.predict on an ndarray.

     The model is loaded with a TF session config pinned to one inter-op and
     one intra-op thread. Predicting 16 random 4-feature rows in batches of 5
     per thread must produce a (16, 2) array.
     """
     export_dir = os.path.join(os.path.dirname(__file__), "../../resources", "tfnet")
     import tensorflow as tf
     single_thread = dict(inter_op_parallelism_threads=1,
                          intra_op_parallelism_threads=1)
     session_config = tf.ConfigProto(**single_thread)
     model = TFNet.from_export_folder(export_dir, tf_session_config=session_config)
     inputs = np.random.rand(16, 4)
     result = model.predict(inputs, batch_per_thread=5, distributed=False)
     assert result.shape == (16, 2)
Exemplo n.º 4
0
    logits, end_points = inception_v1(images, num_classes=1001)

# Restore the pretrained Inception v1 checkpoint into a TF session, export the
# subgraph from `images` to one Inception end point as a TFNet folder, then wrap
# that TFNet inside a BigDL Sequential with pre-processing and a new Linear head.
# NOTE(review): `sess` is never closed in the visible code; close it once it is
# no longer needed (after export_tf, unless later code still uses it).
sess = tf.Session()
saver = tf.train.Saver()
# You need to edit this path to point to the checkpoint you downloaded.
saver.restore(sess, "file:///home/hduser/slim/checkpoint/inception_v1.ckpt")
# Alternative: restore the same checkpoint from HDFS instead of the local filesystem.
#saver.restore(sess, "hdfs:///slim/checkpoint/inception_v1.ckpt")

from zoo.util.tf import export_tf
# NOTE(review): despite its name, `avg_pool` holds end_points['Mixed_3c'],
# not an average-pooling output; confirm this is the intended cut point.
avg_pool = end_points['Mixed_3c']
# Export the graph from the `images` input to the chosen end point so it can be
# reloaded below as a TFNet.
export_tf(sess,
          "file:///home/hduser/slim/tfnet/",
          inputs=[images],
          outputs=[avg_pool])
from zoo.pipeline.api.net import TFNet
# Reload the exported folder as a TFNet layer; it is added to a BigDL model below.
amodel = TFNet.from_export_folder("file:///home/hduser/slim/tfnet/")
# Several of these names (ReLU, SoftMax, Reshape, SpatialAveragePooling) are not
# used in the visible code; they may be used further on.
from bigdl.nn.layer import Sequential, Transpose, Contiguous, Linear, ReLU, SoftMax, Reshape, View, MulConstant, SpatialAveragePooling
full_model = Sequential()
# Swap dims 2<->4, then 2<->3 (BigDL uses 1-based dims). Presumably this converts
# channel-first (N, C, H, W) input to the channel-last layout the TF graph
# expects. Confirm against the input pipeline.
full_model.add(Transpose([(2, 4), (2, 3)]))
# Scale inputs by 1/255 (presumably 0-255 pixel values down to [0, 1]).
scalar = 1. / 255
full_model.add(MulConstant(scalar))
# Make the transposed tensor contiguous before it is fed to the TFNet layer.
full_model.add(Contiguous())
full_model.add(amodel)
# Flatten each sample to 1024 values for the new classifier head.
# NOTE(review): this assumes the exported tensor has exactly 1024 elements per
# sample; verify that against the shape of end_points['Mixed_3c'] for the input
# resolution used.
full_model.add(View([1024]))
# New trainable head: 1024 features -> 5 outputs (presumably 5 target classes).
full_model.add(Linear(1024, 5))
import re
from bigdl.nn.criterion import CrossEntropyCriterion
from pyspark import SparkConf
from pyspark.ml import Pipeline
from pyspark.sql import SQLContext
from pyspark.sql.functions import col, udf