def inference(it, num_workers, args):
  from tensorflowonspark import util

  # consume worker number from RDD partition iterator
  for i in it:
    worker_num = i
  print("worker_num: {}".format(i))

  # setup env for single-node TF
  util.single_node_env()

  # load saved_model using default tag and signature
  sess = tf.Session()
  tf.saved_model.loader.load(sess, ['serve'], args.export)

  # parse function for TFRecords
  def parse_tfr(example_proto):
    feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
                   "image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
    features = tf.parse_single_example(example_proto, feature_def)
    norm = tf.constant(255, dtype=tf.float32, shape=(784,))
    image = tf.div(tf.to_float(features['image']), norm)
    label = tf.to_float(features['label'])
    return (image, label)

  # define a new tf.data.Dataset (for inferencing)
  ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
  ds = ds.shard(num_workers, worker_num)
  ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=1)
  ds = ds.map(parse_tfr).batch(10)
  iterator = ds.make_one_shot_iterator()
  image_label = iterator.get_next(name='inf_image')

  # create an output file per spark worker for the predictions
  tf.gfile.MakeDirs(args.output)
  output_file = tf.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w')

  while True:
    try:
      # get images and labels from tf.data.Dataset
      img, lbl = sess.run(['inf_image:0', 'inf_image:1'])

      # inference by feeding these images and labels into the input tensors
      # you can view the exported model signatures via:
      #   saved_model_cli show --dir <export_dir> --all
      # note that we feed directly into the graph tensors (bypassing the exported signatures)
      # these tensors will be shown in the "name" field of the signature definitions
      outputs = sess.run(['dense_2/Softmax:0'], feed_dict={'Placeholder:0': img})
      for p in outputs[0]:
        output_file.write("{}\n".format(np.argmax(p)))
    except tf.errors.OutOfRangeError:
      break

  output_file.close()
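# The feed/fetch tensor names above ('Placeholder:0', 'dense_2/Softmax:0') come from the
# exported signature definitions. Besides saved_model_cli, they can be confirmed from the
# MetaGraphDef returned by the TF1.x loader. A minimal sketch (the helper name is
# illustrative; it assumes the same export directory and 'serve' tag used above):
def print_signature_tensor_names(export_dir):
  with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(sess, ['serve'], export_dir)
    sig = meta_graph_def.signature_def['serving_default']
    for key, tensor_info in sig.inputs.items():
      print("input {}: {}".format(key, tensor_info.name))
    for key, tensor_info in sig.outputs.items():
      print("output {}: {}".format(key, tensor_info.name))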
def inference(it, num_workers, args):
  from tensorflowonspark import util

  # consume worker number from RDD partition iterator
  for i in it:
    worker_num = i
  print("worker_num: {}".format(i))

  # setup env for single-node TF
  util.single_node_env()

  # load saved_model
  saved_model = tf.saved_model.load(args.export_dir, tags='serve')
  predict = saved_model.signatures['serving_default']

  # parse function for TFRecords
  def parse_tfr(example_proto):
    feature_def = {"label": tf.io.FixedLenFeature(1, tf.int64),
                   "image": tf.io.FixedLenFeature(784, tf.int64)}
    features = tf.io.parse_single_example(serialized=example_proto, features=feature_def)
    image = tf.cast(features['image'], dtype=tf.float32) / 255.0
    image = tf.reshape(image, [28, 28, 1])
    label = tf.cast(features['label'], dtype=tf.float32)
    return (image, label)

  # define a new tf.data.Dataset (for inferencing)
  ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels), shuffle=False)
  ds = ds.shard(num_workers, worker_num)
  ds = ds.interleave(tf.data.TFRecordDataset)
  ds = ds.map(parse_tfr)
  ds = ds.batch(10)

  # create an output file per spark worker for the predictions
  tf.io.gfile.makedirs(args.output)
  output_file = tf.io.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w')
  for batch in ds:
    predictions = predict(conv2d_input=batch[0])
    labels = np.reshape(batch[1], -1).astype(np.int64)
    preds = np.argmax(predictions['logits'], axis=1)
    for x in zip(labels, preds):
      output_file.write("{} {}\n".format(x[0], x[1]))
  output_file.close()
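# The keyword argument 'conv2d_input' and output key 'logits' above are specific to how this
# Keras model was exported. If in doubt, the serving signature can be inspected directly.
# A minimal sketch, assuming the same args.export_dir (the helper name is illustrative):
def print_serving_signature(export_dir):
  loaded = tf.saved_model.load(export_dir, tags='serve')
  sig = loaded.signatures['serving_default']
  print(sig.structured_input_signature)  # expected keyword arguments and their TensorSpecs
  print(sig.structured_outputs)          # output keys and their TensorSpecs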
def inference(it, num_workers, args):
  from tensorflowonspark import util

  os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = args.keyfile

  # consume worker number from RDD partition iterator
  for i in it:
    worker_num = i
  print("worker_num: {}".format(i))

  # setup env for single-node TF
  util.single_node_env()

  # load saved_model using default tag and signature
  model_dir = args.model_dir
  tag_set = 'serve'
  signature_def_key = 'serving_default'
  sess = tf.Session()
  tf.saved_model.loader.load(sess, [tag_set], model_dir)

  # define a new tf.data.Dataset (for inferencing)
  ds = tf.data.Dataset.list_files("{}*".format(args.input_dir))
  ds = ds.interleave(tf.data.TextLineDataset, cycle_length=1)
  ds = ds.shard(num_workers, worker_num)
  ds = ds.batch(100)
  iterator = ds.make_one_shot_iterator()
  input_batch = iterator.get_next()

  # create an output file per spark worker for the predictions
  tf.gfile.MakeDirs(args.output_dir)
  file_path = "{}/prediction.results-{:05d}-of-{:05d}".format(args.output_dir, worker_num, num_workers)
  print("worker({0}) write to {1}".format(worker_num, file_path))
  output_file = tf.gfile.GFile(file_path, mode='w')

  # map the signature's input/output keys to the underlying graph tensor names
  inputs, outputs = get_input_and_output_names(model_dir, tag_set, signature_def_key)
  sorted_outputs = sorted(outputs.items())
  output_keys = [key for key, value in sorted_outputs]
  output_names = [value for key, value in sorted_outputs]
  input_tensor_name = list(inputs.values())[0]

  while True:
    try:
      dataset = sess.run(input_batch)
      result = sess.run(output_names, feed_dict={input_tensor_name: dataset})
      cols = len(result)
      rows = len(result[0])
      # write each row of predictions as a dict-like line, e.g. {'key':value,...}
      for row_num in range(rows):
        cells = []
        for col_num in range(cols):
          key = output_keys[col_num]
          value = result[col_num][row_num]
          if isinstance(value, bytes):
            cells.append("'{}':'{}'".format(key, value.decode(default_codec)))
          else:
            cells.append("'{}':{}".format(key, value))
        output_file.write("{" + ",".join(cells) + "}\n")
    except tf.errors.OutOfRangeError:
      break

  output_file.close()
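# Driver-side usage pattern that these inference() functions assume: one RDD partition per
# executor, each holding a single worker number, so foreachPartition() invokes inference()
# exactly once per executor. This is a minimal sketch; the app name, cluster_size argument,
# and run_inference() helper are illustrative assumptions, not part of the listings above.
from pyspark.conf import SparkConf
from pyspark.context import SparkContext

def run_inference(args):
  sc = SparkContext(conf=SparkConf().setAppName("saved_model_inference"))
  num_workers = args.cluster_size  # typically the number of Spark executors
  node_rdd = sc.parallelize(list(range(num_workers)), num_workers)
  node_rdd.foreachPartition(lambda worker_nums: inference(worker_nums, num_workers, args))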