Beispiel #1
0
def graph_to_function_v2(graph: Union[GraphDef, tf.Graph]) -> Callable:
    """Turn a GraphDef (or a TF1 ``tf.Graph``) into a callable TF2 function.

    The graph definition is imported inside a ``tf.compat.v1.wrap_function``
    wrapper. The wrapper is then pruned down to the graph's input and output
    tensors, as reported by ``util.get_input_tensors`` and
    ``util.get_output_tensors``. Typical inputs are the result of
    ``load_graph_model`` or a frozen TF1 graph.

    Args:
        graph: a GraphDef protocol buffer message, or a ``tf.Graph``. For a
            ``tf.Graph``, its ``as_graph_def()`` is used instead.

    Returns:
        The pruned wrapped function. Call it with input tensors or ``numpy``
        arrays to get the model outputs back as a list of tensors.
    """
    if isinstance(graph, tf.Graph):
        graph_def = graph.as_graph_def()
    else:
        graph_def = graph

    def _import_into_wrapper():
        # Copy every node of graph_def into the wrapper's graph without a
        # name prefix, so the original tensor names still resolve below.
        tf.graph_util.import_graph_def(graph_def, name='')

    wrapper = tf.compat.v1.wrap_function(_import_into_wrapper, [])
    lookup = wrapper.graph.as_graph_element

    input_names = util.get_input_tensors(graph_def)
    output_names = util.get_output_tensors(graph_def)
    feeds = tf.nest.map_structure(lookup, input_names)
    fetches = tf.nest.map_structure(lookup, output_names)
    return wrapper.prune(feeds, fetches)
Beispiel #2
0
def capture_display_video(vc, sess, graph, nb_frames, fake_camera):
    """Segment frames read from ``vc`` and composite them over a background.

    ``process_frame`` refreshes the person mask once every ``nb_frames``
    frames, and the frames in between reuse the last mask. The composited
    frame can be previewed (``show_preview``) and/or written to
    ``fake_camera`` (``write_to_device``).

    The function reads these module-level settings: ``chosen_background``,
    ``blurring_kernel_size``, ``threshold``, ``output_stride``, ``model``,
    ``show_preview``, ``write_to_device``, ``placeholder``, ``second_text``
    and ``text_placeholder``.

    Args:
        vc: video source with ``read() -> (ret, frame)`` and ``release()``.
            It is presumably a ``cv2.VideoCapture`` (TODO confirm). It is
            always released on return, including when an error occurs.
        sess: session forwarded to ``process_frame``.
        graph: model graph, used to look up the input/output tensor names.
        nb_frames: run the model once every this many frames. Must be >= 1.
        fake_camera: sink with ``schedule_frame(rgb_frame)``. Only used when
            ``write_to_device`` is truthy.
    """
    try:
        input_tensor_names = tfjs_util.get_input_tensors(graph)
        output_tensor_names = tfjs_util.get_output_tensors(graph)

        ret, original_frame = vc.read()
        if not ret:
            # The source produced no first frame, so there is nothing to do.
            # Previously this crashed on ``None.shape``.
            return
        mask = np.zeros(original_frame.shape)

        # Compare by value. The old `is not "blur"` test compared identity,
        # which is False for a "blur" string read from config/argv. The code
        # then tried to load an image file literally named "blur".
        if chosen_background and chosen_background != "blur":
            background = cv2.resize(
                cv2.imread(chosen_background),
                (original_frame.shape[1], original_frame.shape[0]),
            )
        else:
            background = original_frame

        i = 0
        while ret:
            if i == 0 and chosen_background:
                start = time.time()
                if chosen_background == "blur":
                    background = cv2.GaussianBlur(
                        original_frame,
                        (blurring_kernel_size, blurring_kernel_size),
                        0,
                    )

                mask = process_frame(
                    original_frame,
                    sess,
                    graph,
                    input_tensor_names,
                    output_tensor_names,
                    threshold,
                    output_stride,
                    model,
                )

                processing_time = time.time() - start
                start = time.time()

                # Expand the single-channel mask to 3 channels, resize it to
                # the frame and blur it so the edges have no "stairs".
                mask = cv2.cvtColor(mask.astype(np.float32), cv2.COLOR_GRAY2BGR)
                mask = cv2.resize(
                    mask,
                    (original_frame.shape[1], original_frame.shape[0]),
                    interpolation=cv2.INTER_CUBIC,
                )
                mask = cv2.GaussianBlur(mask, (31, 31), 31)
                second_text.text(f"Model took: {round(processing_time, 3)}. Alpha: {round(time.time() - start, 3)}")

            if chosen_background:
                # Alpha-blend: mask selects the frame, (1 - mask) the background.
                displayed_frame = cv2.convertScaleAbs(
                    original_frame * mask + background * (1 - mask)
                )
            else:
                displayed_frame = original_frame

            if show_preview:
                placeholder.image(displayed_frame, channels="BGR")
            if write_to_device:
                fake_camera.schedule_frame(cv2.cvtColor(displayed_frame, cv2.COLOR_BGR2RGB))
            text_placeholder.text(f'Preview: {show_preview} -- v4l2: {write_to_device}')

            ret, original_frame = vc.read()
            # Wrap at nb_frames so the model runs exactly once every nb_frames
            # frames. The old `i == nb_frames + 1` reset made it run on two
            # consecutive frames in each cycle.
            i = (i + 1) % nb_frames
    finally:
        vc.release()
    print("Model: resnet50 (stride={stride})".format(stride=output_stride))
    model_path = 'bodypix_resnet50_float_model-stride{stride}'.format(
        stride=output_stride)
else:
    print('Unknown model type. Use "mobilenet" or "resnet50".')
    sys.exit(1)

# Load the graph from model_path and open a TF1-compatible session on it.
print("Loading model...")
graph = tfjs_api.load_graph_model(model_path)
print("done.")

sess = tf.compat.v1.Session(graph=graph)

# Resolve the model's tensor names, and keep a handle on the first input tensor.
input_tensor_names = tfjs_util.get_input_tensors(graph)
output_tensor_names = tfjs_util.get_output_tensors(graph)
input_tensor = graph.get_tensor_by_name(input_tensor_names[0])

# Build the layers described by the current config.
layers = reload_layers(config)

# If the configured "video device" names an image file, read a frame from cap once.
static_image = None
if config['real_video_device'].lower().endswith(("jpg", "jpeg", "png")):
    success, static_image = cap.read()

def mainloop():
    global config, masks, layers, config_mtime
Beispiel #4
0
 def test_get_input_tensors(self):
     """get_input_tensors lists each graph input as '<node name>:0'."""
     sample = testutils.get_sample_graph_def()
     names = util.get_input_tensors(sample)
     self.assertEqual(
         names,
         [node.name + ':0' for node in testutils.get_inputs(sample)],
     )
Beispiel #5
0
def get_tensors_graph(graph):
    """Return the graph's first input tensor and the names of its outputs.

    Args:
        graph: graph object providing ``get_tensor_by_name``. It is passed to
            ``tfjsutil`` to discover its input and output tensor names.

    Returns:
        A ``(input_tensor, output_tensor_names)`` tuple. Raises IndexError if
        the graph reports no inputs.
    """
    in_names = tfjsutil.get_input_tensors(graph)
    out_names = tfjsutil.get_output_tensors(graph)
    return graph.get_tensor_by_name(in_names[0]), out_names