def test_tf_net(self):
    """Forward a random (32, 28, 28, 1) batch through the exported TFNet.

    The loaded network is expected to produce a (32, 10) output.
    """
    here = os.path.dirname(__file__)
    export_dir = os.path.join(here, "../../../resources", "tfnet")
    model = TFNet.from_export_folder(export_dir)
    batch = np.random.rand(32, 28, 28, 1)
    result = model.forward(batch)
    assert result.shape == (32, 10)
def test_tf_net_predict_dataset(self):
    """Run TFNet.predict over a TFDataset built from 16 random 4-feature rows.

    The collected predictions, stacked into one array, must have shape (16, 2).
    """
    export_dir = os.path.join(os.path.dirname(__file__), "../../resources", "tfnet")
    model = TFNet.from_export_folder(export_dir)
    features = np.random.rand(16, 4)
    # The single feature array is passed wrapped in a 1-tuple.
    ds = TFDataset.from_ndarrays((features,))
    # predict() returns something with .collect() (presumably an RDD);
    # gather the per-record results locally and stack them into one array.
    stacked = np.stack(model.predict(ds).collect())
    assert stacked.shape == (16, 2)
def test_tf_net_predict(self):
    """Predict locally (distributed=False) using a single-threaded TF session config.

    16 random 4-feature rows are fed with batch_per_thread=5; the result must
    have shape (16, 2).
    """
    export_dir = os.path.join(os.path.dirname(__file__), "../../resources", "tfnet")
    import tensorflow as tf
    # Limit TensorFlow to one intra-op and one inter-op thread.
    session_cfg = tf.ConfigProto(intra_op_parallelism_threads=1,
                                 inter_op_parallelism_threads=1)
    model = TFNet.from_export_folder(export_dir, tf_session_config=session_cfg)
    result = model.predict(np.random.rand(16, 4), batch_per_thread=5, distributed=False)
    assert result.shape == (16, 2)
# Build the Inception-v1 graph on `images` (defined earlier in the script,
# outside this fragment) and restore pretrained weights into a TF session.
logits, end_points = inception_v1(images, num_classes=1001)
# NOTE(review): this session is never closed in the visible code.
sess = tf.Session()
saver = tf.train.Saver()
# You need to edit this path to point at the checkpoint you downloaded
# (the path below is machine-specific).
saver.restore(sess, "file:///home/hduser/slim/checkpoint/inception_v1.ckpt")
# Alternative: restore the same checkpoint from HDFS instead of the local filesystem.
#saver.restore(sess, "hdfs:///slim/checkpoint/inception_v1.ckpt")
from zoo.util.tf import export_tf
# NOTE(review): despite its name, `avg_pool` holds the 'Mixed_3c' endpoint, while
# the head below flattens to exactly 1024 features (View([1024]) -> Linear(1024, 5)).
# Confirm this endpoint yields 1024 values per sample for your input size; the
# imported-but-unused SpatialAveragePooling hints a pooling step may have been intended.
avg_pool = end_points['Mixed_3c']
# Export the graph from `images` to the chosen endpoint into a TFNet folder
# (machine-specific path).
export_tf(sess, "file:///home/hduser/slim/tfnet/", inputs=[images], outputs=[avg_pool])
from zoo.pipeline.api.net import TFNet
# Load the exported graph back as a layer that can be placed in a BigDL model.
amodel = TFNet.from_export_folder("file:///home/hduser/slim/tfnet/")
# ReLU, SoftMax, Reshape and SpatialAveragePooling are not used within this fragment.
from bigdl.nn.layer import Sequential, Transpose, Contiguous, Linear, ReLU, SoftMax, Reshape, View, MulConstant, SpatialAveragePooling
full_model = Sequential()
# Swap dims 2<->4 then 2<->3. Assuming 1-based dims and a channels-first
# (N, C, H, W) input, this yields (N, H, W, C) for the TF graph -- confirm.
full_model.add(Transpose([(2, 4), (2, 3)]))
# Presumably rescales 0-255 pixel values into [0, 1].
scalar = 1. / 255
full_model.add(MulConstant(scalar))
# Presumably makes the transposed tensor contiguous before handing it to TFNet.
full_model.add(Contiguous())
full_model.add(amodel)
# Flatten to 1024 features and attach a new 5-way linear head.
full_model.add(View([1024]))
full_model.add(Linear(1024, 5))
# The imports below are not used within this fragment; presumably they serve
# the data-loading / training code that follows.
import re
from bigdl.nn.criterion import CrossEntropyCriterion
from pyspark import SparkConf
from pyspark.ml import Pipeline
from pyspark.sql import SQLContext
from pyspark.sql.functions import col, udf