Exemplo n.º 1
0
def read(path, parser_configs=None):
    # type: (str, typing.Optional[typing.List[NNEFParserConfig]])->NNEFGraph
    """Load an NNEF graph from a directory, a .tgz archive or a graph text file.

    Directories and .tgz archives are read with weights; plain .nnef / .txt
    files are read without weights. For a .tgz archive a temporary directory
    is created, handed to ``utils.tgz_extract`` together with the archive
    path, and removed again before this function returns or raises.

    Args:
        path: Directory, or a file name ending in .tgz, .nnef or .txt.
        parser_configs: Optional list of parser configs; they are merged with
            ``NNEFParserConfig.combine_configs`` (an empty list is used when
            nothing is given).

    Returns:
        Whatever ``_read`` produces for the loaded, shape-inferred graph.

    Raises:
        utils.NNEFToolsException: If ``path`` is neither a directory nor has
            one of the supported suffixes.
    """
    is_archive = path.endswith('.tgz')
    is_graph_file = path.endswith('.nnef') or path.endswith('.txt')

    if not is_archive and not is_graph_file and not os.path.isdir(path):
        raise utils.NNEFToolsException(
            "Only .tgz or .nnef or .txt files or directories are supported")

    parser_config = NNEFParserConfig.combine_configs(parser_configs or [])

    # Set only when a temporary extraction directory was created, so the
    # cleanup step knows whether there is something to delete.
    temp_dir = None

    try:
        # A directory takes precedence over the suffix checks.
        if os.path.isdir(path):
            with_weights = True
            path_to_load = path
        elif is_archive:
            with_weights = True
            temp_dir = tempfile.mkdtemp(prefix="nnef_")
            path_to_load = temp_dir
            utils.tgz_extract(path, path_to_load)
        elif is_graph_file:
            with_weights = False
            path_to_load = path
        else:
            assert False

        # If the graph file declares fragments itself and parser_config also
        # carries fragments, fall back to the standard fragments only, to
        # avoid duplicate fragment definitions.
        if parser_config.fragments:
            fragment_keyword = re.compile(r"^fragment\s|\sfragment\s")
            graph_keyword = re.compile(r"^graph\s|\sgraph\s")

            if os.path.isdir(path_to_load):
                graph_file = os.path.join(path_to_load, 'graph.nnef')
            else:
                graph_file = path_to_load

            with open(graph_file, 'r') as stream:
                for text_line in stream:
                    if fragment_keyword.search(text_line):
                        parser_config.fragments = NNEFParserConfig.STANDARD_CONFIG.fragments
                        break
                    # Stop scanning at the first line matching the graph keyword.
                    if graph_keyword.search(text_line):
                        break

        loaded_graph = parser_config.load_graph(path_to_load)
        inferred_graph = parser_config.infer_shapes(loaded_graph)
        return _read(parser_graph=inferred_graph, with_weights=with_weights)

    finally:
        if temp_dir:
            shutil.rmtree(temp_dir)
Exemplo n.º 2
0
def get_writer(input_format,
               output_format,
               compress,
               with_weights=True,
               custom_converters=None):
    """Create the writer object for the requested output format.

    The writer module of each backend is imported inside its own branch, so
    only the backend that is actually requested gets imported.

    Args:
        input_format: Not used by this function.
        output_format: One of 'nnef', 'tensorflow-py', 'tensorflow-pb',
            'tensorflow-lite', 'onnx', 'caffe' or 'caffe2'.
        compress: Compression level for the 'nnef' writer; negative values
            are replaced by 0. Not used for the other formats.
        with_weights: Forwarded as ``write_weights`` to the 'nnef' and
            'tensorflow-py' writers.
        custom_converters: For 'nnef', passed to
            ``NNEFParserConfig.load_configs`` to obtain the fragments; for
            'tensorflow-py', passed to ``get_tf_py_imports_and_op_protos``
            when truthy.

    Returns:
        A ``Writer`` instance of the selected backend.

    Raises:
        AssertionError: If ``output_format`` is not one of the supported
            values.
    """
    if output_format == 'nnef':
        from nnef_tools.io.nnef.nnef_io import Writer
        configs = NNEFParserConfig.load_configs(custom_converters,
                                                load_standard=False)
        fragments = NNEFParserConfig.combine_configs(configs).fragments
        return Writer(write_weights=with_weights,
                      fragments=fragments,
                      only_print_used_fragments=True,
                      compression_level=compress if compress >= 0 else 0)

    if output_format == 'tensorflow-py':
        from nnef_tools.io.tensorflow.tf_py_io import Writer

        if custom_converters:
            custom_imports, custom_op_protos = get_tf_py_imports_and_op_protos(
                custom_converters)
        else:
            custom_imports, custom_op_protos = None, None

        return Writer(write_weights=with_weights,
                      custom_op_protos=custom_op_protos,
                      custom_imports=custom_imports)

    if output_format == 'tensorflow-pb':
        from nnef_tools.io.tensorflow.tf_pb_io import Writer
        return Writer(convert_from_tf_py=True)

    if output_format == 'tensorflow-lite':
        from nnef_tools.io.tensorflow.tflite_io import Writer
        return Writer(convert_from_tf_py=True)

    if output_format == 'onnx':
        from nnef_tools.io.onnx.onnx_io import Writer
        return Writer()

    if output_format == 'caffe':
        from nnef_tools.io.caffe.caffe_io import Writer
        return Writer()

    if output_format == 'caffe2':
        from nnef_tools.io.caffe2.caffe2_io import Writer
        return Writer()

    # Raised explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would make this function silently
    # return None for an unknown format. The exception type is unchanged.
    raise AssertionError(
        "Unsupported output format: {!r}".format(output_format))
Exemplo n.º 3
0
def read(path, parser_configs=None):
    # type: (str, typing.Optional[typing.List[NNEFParserConfig]])->NNEFGraph
    """Load an NNEF graph from a directory, a .tgz archive or a .nnef file.

    Directories and .tgz archives are read with weights; a plain .nnef file
    is read without weights. For a .tgz archive a temporary directory is
    created, handed to ``_tgz_extract`` together with the archive path, and
    removed again before this function returns or raises.

    Args:
        path: Directory, or a file name ending in .tgz or .nnef.
        parser_configs: Optional list of parser configs; they are merged with
            ``NNEFParserConfig.combine_configs`` (an empty list is used when
            nothing is given).

    Returns:
        Whatever ``_read`` produces for the loaded, shape-inferred graph.

    Raises:
        AssertionError: If ``path`` is neither a directory nor ends in .tgz
            or .nnef.
    """
    is_archive = path.endswith('.tgz')
    is_graph_file = path.endswith('.nnef')

    # Raised explicitly rather than through an `assert` statement: asserts
    # are removed under `python -O`, which would let an unsupported path
    # reach the loader. Exception type and message are unchanged.
    if not (is_archive or is_graph_file or os.path.isdir(path)):
        raise AssertionError(
            "Only .tgz or .nnef files or directories are supported")

    parser_config = NNEFParserConfig.combine_configs(
        parser_configs if parser_configs else [])

    # Set only when a temporary extraction directory was created, so the
    # cleanup step knows whether there is something to delete.
    temp_dir = None

    try:
        # A directory takes precedence over the suffix checks.
        if os.path.isdir(path):
            with_weights = True
            path_to_load = path
        elif is_archive:
            with_weights = True
            temp_dir = tempfile.mkdtemp(prefix="nnef_")
            path_to_load = temp_dir
            _tgz_extract(path, path_to_load)
        elif is_graph_file:
            with_weights = False
            path_to_load = path
        else:
            # Only reachable if `path` was a directory during validation and
            # is no longer one now.
            raise AssertionError(
                "Only .tgz or .nnef files or directories are supported")

        return _read(parser_graph=parser_config.infer_shapes(
            parser_config.load_graph(path_to_load)),
                     with_weights=with_weights)

    finally:
        if temp_dir:
            shutil.rmtree(temp_dir)
Exemplo n.º 4
0
def read(
    path,  # type: str
    parser_configs=None,  # type: typing.Optional[typing.List[NNEFParserConfig]]
    input_shape=None,  # type: typing.Union[None, typing.List[int], typing.Dict[str, typing.List[int]]]
):
    # type: (...)->NNEFGraph
    """Load an NNEF graph, optionally overriding the shapes of its inputs.

    Directories and .tgz archives are read with weights; plain .nnef / .txt
    files are read without weights. For a .tgz archive a temporary directory
    is created, handed to ``utils.tgz_extract`` together with the archive
    path, and removed again before this function returns or raises.

    Args:
        path: Directory, or a file name ending in .tgz, .nnef or .txt.
        parser_configs: Optional list of parser configs; they are merged with
            ``NNEFParserConfig.combine_configs`` (an empty list is used when
            nothing is given).
        input_shape: Optional shape override applied to the 'shape' attribute
            of 'external' operations before shape inference. A list is
            applied to every 'external' operation; a dict is looked up by the
            operation's 'output' entry and applied only when that key exists.

    Returns:
        Whatever ``_read`` produces for the loaded, shape-inferred graph.

    Raises:
        utils.NNEFToolsException: If ``path`` is neither a directory nor has
            one of the supported suffixes, or if ``input_shape`` is given but
            is neither a list nor a dict.
    """
    is_archive = path.endswith('.tgz')
    is_graph_file = path.endswith('.nnef') or path.endswith('.txt')

    if not is_archive and not is_graph_file and not os.path.isdir(path):
        raise utils.NNEFToolsException(
            "Only .tgz or .nnef or .txt files or directories are supported")

    parser_config = NNEFParserConfig.combine_configs(parser_configs or [])

    # Set only when a temporary extraction directory was created, so the
    # cleanup step knows whether there is something to delete.
    temp_dir = None

    try:
        # A directory takes precedence over the suffix checks.
        if os.path.isdir(path):
            with_weights = True
            path_to_load = path
        elif is_archive:
            with_weights = True
            temp_dir = tempfile.mkdtemp(prefix="nnef_")
            path_to_load = temp_dir
            utils.tgz_extract(path, path_to_load)
        elif is_graph_file:
            with_weights = False
            path_to_load = path
        else:
            assert False

        # If the graph file declares fragments itself and parser_config also
        # carries fragments, fall back to the standard fragments only, to
        # avoid duplicate fragment definitions.
        if parser_config.fragments:
            fragment_keyword = re.compile(r"^fragment\s|\sfragment\s")
            graph_keyword = re.compile(r"^graph\s|\sgraph\s")

            if os.path.isdir(path_to_load):
                graph_file = os.path.join(path_to_load, 'graph.nnef')
            else:
                graph_file = path_to_load

            with open(graph_file, 'r') as stream:
                for text_line in stream:
                    if fragment_keyword.search(text_line):
                        parser_config.fragments = NNEFParserConfig.STANDARD_CONFIG.fragments
                        break
                    # Stop scanning at the first line matching the graph keyword.
                    if graph_keyword.search(text_line):
                        break

        parser_graph = parser_config.load_graph(path_to_load)

        if input_shape is not None:
            if not isinstance(input_shape, (list, dict)):
                raise utils.NNEFToolsException(
                    "input_shape must be list or dict")

            shape_per_tensor = isinstance(input_shape, dict)
            for operation in parser_graph.operations:
                if operation.name != 'external':
                    continue
                if not shape_per_tensor:
                    # A single list applies to every external operation.
                    operation.attribs['shape'] = input_shape
                    continue
                tensor_name = operation.outputs['output']
                if tensor_name in input_shape:
                    operation.attribs['shape'] = input_shape[tensor_name]

        # Shape inference runs after the overrides above have been applied.
        parser_config.infer_shapes(parser_graph)
        return _read(parser_graph=parser_graph, with_weights=with_weights)

    finally:
        if temp_dir:
            shutil.rmtree(temp_dir)