示例#1
0
def convert_resource_gather(name, tf_node, inputs, uff_graph, **kwargs):
    """Convert a TensorFlow ``ResourceGather`` node into a UFF ``gather_v2`` node.

    Args:
        name: Name given to the emitted UFF node.
        tf_node: The TensorFlow node being converted (presumably a NodeDef).
            Its ``Tindices`` and ``dtype`` attrs supply the index and
            parameter dtypes.
        inputs: Input names of ``tf_node``. When there are more than two,
            the last one is read as a constant holding the gather axis and
            is dropped from the list forwarded to UFF. Otherwise the axis
            defaults to 0.
        uff_graph: UFF graph that receives the ``gather_v2`` node.
        **kwargs: Must provide ``"tf_nodes"`` when the axis input is
            present. It is indexed by that input's name, and the result is
            passed to ``tf2uff.convert_tf2numpy_const_node``.

    Returns:
        Element ``[0]`` of ``tf2uff.split_node_name_and_output`` for each
        forwarded input. Presumably these are the producer node names the
        caller converts next (TODO confirm against tf2uff).
    """
    gather_axis = 0
    data_inputs = inputs
    if len(inputs) > 2:
        # The axis is folded into the UFF op as an attribute instead of
        # being kept as a graph input, so its node is not returned below.
        axis_tf_node = kwargs["tf_nodes"][inputs[-1]]
        gather_axis = int(tf2uff.convert_tf2numpy_const_node(axis_tf_node))
        data_inputs = inputs[:-1]

    node_attrs = tf_node.attr
    idx_np_dtype = tf2uff.convert_tf2numpy_dtype(node_attrs['Tindices'].type)
    param_np_dtype = tf2uff.convert_tf2numpy_dtype(node_attrs['dtype'].type)
    uff_graph.gather_v2(data_inputs, name, gather_axis, idx_np_dtype, param_np_dtype)

    producers = []
    for src in data_inputs:
        producers.append(tf2uff.split_node_name_and_output(src)[0])
    return producers
示例#2
0
def convert_gather(name, tf_node, inputs, uff_graph, **kwargs):
    """Convert a TensorFlow ``Gather`` node into a UFF ``gather`` node.

    Args:
        name: Name given to the emitted UFF node.
        tf_node: The TensorFlow node being converted (presumably a NodeDef).
            Its ``Tindices`` and ``Tparams`` attrs supply the index and
            parameter dtypes. Its boolean ``validate_indices`` attr is
            forwarded unchanged.
        inputs: Input names of ``tf_node``. All of them are forwarded to
            UFF.
        uff_graph: UFF graph that receives the ``gather`` node.
        **kwargs: Accepted for signature compatibility with the other
            converters; not used here.

    Returns:
        Element ``[0]`` of ``tf2uff.split_node_name_and_output`` for each
        input. Presumably these are the producer node names the caller
        converts next (TODO confirm against tf2uff).
    """
    to_numpy_dtype = tf2uff.convert_tf2numpy_dtype
    node_attrs = tf_node.attr
    # Arguments are evaluated left to right, matching the original
    # Tindices -> Tparams -> validate_indices order.
    uff_graph.gather(
        inputs,
        name,
        to_numpy_dtype(node_attrs['Tindices'].type),
        to_numpy_dtype(node_attrs['Tparams'].type),
        node_attrs['validate_indices'].b,
    )
    return [tf2uff.split_node_name_and_output(src)[0] for src in inputs]
示例#3
0
def convert_placeholder(name, tf_node, inputs, uff_graph, **kwargs):
    """Convert a TensorFlow ``Placeholder`` node into a UFF graph input.

    Args:
        name: Name given to the emitted UFF input.
        tf_node: The TensorFlow node being converted (presumably a NodeDef).
            Its ``dtype`` attr gives the element type, and its ``shape``
            attr is converted to an int list by
            ``tf2uff.get_tf_shape_as_int_list``.
        inputs: Input names of ``tf_node``. A placeholder presumably has
            none, in which case the result is empty.
        uff_graph: UFF graph that receives the new input.
        **kwargs: Accepted for signature compatibility with the other
            converters; not used here.

    Returns:
        Element ``[0]`` of ``tf2uff.split_node_name_and_output`` for each
        input (TODO confirm semantics against tf2uff).
    """
    node_attrs = tf_node.attr
    np_dtype = tf2uff.convert_tf2numpy_dtype(node_attrs['dtype'].type)
    dims = tf2uff.get_tf_shape_as_int_list(node_attrs['shape'])
    uff_graph.input(dims, np_dtype, name)

    producers = []
    for src in inputs:
        producers.append(tf2uff.split_node_name_and_output(src)[0])
    return producers