コード例 #1
0
ファイル: reader.py プロジェクト: KhronosGroup/NNEF-Tools
    def __call__(self, path, input_shapes=None):
        """Load an NNEF model from *path* and convert it with ``_build_graph``.

        *path* may be a model folder or a ``.tgz``/``.gz`` archive of one. An
        archive is extracted (via ``tgz_extract``) into a temporary folder,
        which is removed when the call finishes, including on error.

        :param path: model folder, or compressed model archive
        :param input_shapes: passed to ``nnef.infer_shapes`` as
            ``external_shapes`` (``None`` is treated as empty); only used when
            ``self._infer_shapes`` is truthy
        :return: the result of ``_build_graph`` on the loaded NNEF graph
        :raises IOError: if *path* (after any extraction) is not a folder
        """
        # splitext returns a (root, ext) tuple whose ext keeps the leading dot,
        # so compare only the extension element ('model.tar.gz' -> '.gz').
        compressed = (os.path.splitext(path)[1] in ('.tgz', '.gz')
                      and not os.path.isdir(path))

        folder = None
        try:
            if compressed:
                folder = tempfile.mkdtemp(prefix="nnef_")
                tgz_extract(path, folder)
                path = folder

            if not os.path.isdir(path):
                raise IOError(
                    "NNEF model must be a (compressed) folder, but an uncompressed file was provided"
                )

            nnef_graph = nnef.load_graph(path,
                                         stdlib=self._stdlib,
                                         lowered=self._decomposed)
            if self._infer_shapes:
                nnef.infer_shapes(nnef_graph,
                                  external_shapes=input_shapes or {},
                                  custom_shapes=self._custom_shapes or {})

            return _build_graph(nnef_graph)
        finally:
            # Remove the temporary extraction folder, if one was created.
            if folder is not None:
                shutil.rmtree(folder)
コード例 #2
0
 def test_reshape(self):
     """Check shape inference for ``reshape`` with ``axis_start``/``axis_count``.

     Axes 1 and 2 of the ``[1,2,3,4]`` input (sizes 2 and 3) are replaced by
     ``[6]``, so the output shape is expected to be ``[1,6,4]``.
     """
     graph = nnef.parse_string("""
         version 1.0;
         graph G( input ) -> ( output )
         {
             input = external(shape = [1,2,3,4]);
             output = reshape(input, axis_start = 1, axis_count = 2, shape = [6]);
         }
         """)
     nnef.infer_shapes(graph)
     # Pin the inferred result, not merely that inference did not raise.
     # list() tolerates the shape being stored as either a list or a tuple.
     assert list(graph.tensors['output'].shape) == [1, 6, 4]
コード例 #3
0
def nnef2ir(inputFolder, outputFolder):
    """Convert the NNEF model in *inputFolder* to an IR graph and save it.

    Shapes are inferred on the loaded NNEF graph before it is handed to
    ``nnef_graph_to_ir_graph``; *outputFolder* is passed to the resulting
    graph's ``toFile``.
    """
    loaded = nnef.load_graph(inputFolder)
    nnef.infer_shapes(loaded)
    nnef_graph_to_ir_graph(loaded).toFile(outputFolder)
コード例 #4
0
ファイル: sample_ext.py プロジェクト: jnorwood/NNEF-Tools
import nnef


def shuffle_shape(input, groups):
    """Shape function for the custom ``shuffle`` fragment.

    The shuffle leaves the tensor shape unchanged, so *input* is returned
    as-is after checking that the channel count (axis 1) splits evenly into
    *groups*. Explicit exceptions are used instead of ``assert`` so the
    check still runs under ``python -O``.

    :param input: shape of the ``input`` tensor argument (sequence of ints)
    :param groups: value of the ``groups`` attribute
    :return: *input*, unchanged
    :raises ValueError: if *groups* is not positive, or does not divide the
        channel count
    """
    channels = input[1]
    if groups <= 0:
        raise ValueError("groups must be positive, got {}".format(groups))
    if channels % groups != 0:
        raise ValueError(
            "input channels ({}) is not divisible by groups ({})".format(
                channels, groups))
    return input


# Parse a small network that calls the custom `shuffle` fragment; the fragment
# is only declared (no body), which the KHR_enable_fragment_definitions
# extension line enables.
graph = nnef.parse_string("""
    version 1.0;
    extension KHR_enable_fragment_definitions;

    fragment shuffle<?>( input: tensor<?>, groups: integer ) -> ( output: tensor<?> );

    graph Net( input ) -> ( output )
    {
        input = external(shape = [1,3,224,224]);
        filter = variable(shape = [32,3,5,5], label = 'conv/filter');
        conv = conv(input, filter);
        output = shuffle(conv, groups = 4);
    }
    """)

# Register the Python shape function for the custom fragment by name.
custom_shape_functions = {'shuffle': shuffle_shape}
nnef.infer_shapes(graph, custom_shapes=custom_shape_functions)

# Render the shape-annotated graph back to NNEF text and print it.
formatted = nnef.format_graph(graph.name, graph.inputs, graph.outputs,
                              graph.operations)
print(formatted)
コード例 #5
0
    # Parse command-line options; `ap` is an argument parser set up before
    # this excerpt (not visible here).
    args = ap.parse_args()

    # Optional custom stdlib source: when args.stdlib is set, its file contents
    # are read and passed to nnef.load_graph; otherwise an empty string is used.
    stdlib = ''
    if args.stdlib:
        try:
            with open(args.stdlib) as file:
                stdlib = file.read()
        except FileNotFoundError as e:
            # NOTE(review): `e` is unused, the message goes to stdout, and
            # `exit` is the site-module helper; sys.exit(...) with output to
            # stderr may be more robust — confirm intent.
            print('Could not open file: ' + args.stdlib)
            exit(-1)

    # Load the model. args.lower is split on commas and passed as `lowered`.
    # NOTE(review): an empty args.lower yields [''] rather than [] — confirm
    # that nnef.load_graph tolerates an empty name.
    try:
        graph = nnef.load_graph(args.path,
                                stdlib=stdlib,
                                lowered=args.lower.split(','))
    except nnef.Error as err:
        print('Parse error: ' + str(err))
        exit(-1)

    # Shape inference runs only when args.shapes is set; failures are
    # reported and the process exits.
    if args.shapes:
        try:
            nnef.infer_shapes(graph)
        except nnef.Error as err:
            print('Shape error: ' + str(err))
            exit(-1)

    # Print the graph in NNEF textual form, then report success.
    print(
        nnef.format_graph(graph.name, graph.inputs, graph.outputs,
                          graph.operations))
    print('Validation succeeded')
コード例 #6
0
ファイル: parser_config.py プロジェクト: jnorwood/NNEF-Tools
 def infer_shapes(self, graph):
     """Run ``nnef.infer_shapes`` on *graph* and return the same object.

     The shape functions stored in ``self._shapes`` are forwarded as
     ``custom_shapes``.
     """
     shape_functions = self._shapes
     nnef.infer_shapes(graph, custom_shapes=shape_functions)
     return graph