def predict_by_neural_network(keypoint_coord3d_v, known_finger_poses, pb_file, threshold):
    """Classify a hand pose with a frozen TensorFlow graph.

    The serialized ``GraphDef`` stored at *pb_file* is imported into a fresh
    graph on every call. The keypoints are concatenated into a single feature
    row and fed to the tensor named ``input:0``. The first row of ``output:0``
    is treated as the per-pose scores.

    Args:
        keypoint_coord3d_v: Nested iterable of keypoint coordinates; all inner
            iterables are concatenated, in order, into one flat feature vector.
        known_finger_poses: Forwarded to ``get_position_name_with_pose_id`` to
            turn the winning index into a name.
        pb_file: Path of the frozen graph file (read in binary mode).
        threshold: Minimum score the best-scoring index must reach.

    Returns:
        The name for the highest-scoring index when its score is
        ``>= threshold``; otherwise the string ``'Undefined'``.
    """
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(pb_file, 'rb') as model_file:
            graph_def.ParseFromString(model_file.read())
        tf.import_graph_def(graph_def, name='')

        with tf.Session(graph=graph) as session:
            in_tensor = graph.get_tensor_by_name('input:0')
            out_tensor = graph.get_tensor_by_name('output:0')

            features = np.array(
                [value for row in keypoint_coord3d_v for value in row])
            # Add a leading batch dimension of size 1.
            batch = features[np.newaxis, ...]
            scores = session.run(out_tensor, feed_dict={in_tensor: batch})[0]

            best = np.argmax(scores)
            # Keep the ">=" test: a NaN score must fall through to 'Undefined'.
            if scores[best] >= threshold:
                return get_position_name_with_pose_id(best, known_finger_poses)
            return 'Undefined'
def predict_by_svm(keypoint_coord3d_v, known_finger_poses, svc_file):
    """Classify a hand pose with a pickled classifier.

    Args:
        keypoint_coord3d_v: Nested iterable of keypoint coordinates; all inner
            iterables are concatenated, in order, into one flat feature vector.
        known_finger_poses: Forwarded to ``get_position_name_with_pose_id`` to
            turn the predicted index into a name.
        svc_file: Path of a pickled object exposing a ``predict`` method.

    Returns:
        The name associated with the first prediction for the single input row.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point svc_file at trusted model artifacts.
    with open(svc_file, 'rb') as model_file:
        classifier = pickle.load(model_file)

    row = np.array([value for coords in keypoint_coord3d_v for value in coords])
    # Add a leading batch dimension of size 1.
    predicted = classifier.predict(row[np.newaxis, ...])[0]
    return get_position_name_with_pose_id(predicted, known_finger_poses)