Exemple #1
0
def check_graph(tmp_path, graph, model_name, model_framework, check_weights=False):
    """
    Save `graph` as IR into `tmp_path` and compare it against the reference IR.

    The IR base name is '<model_name>_<model_framework>'. When the reference
    '.xml' is absent from REFERENCE_MODELS_PATH, the freshly saved IR is copied
    there first (together with the '.bin' if check_weights is True), so the
    comparison then runs against that copy.

    :param tmp_path: directory (path object) the IR is saved into
    :param graph: graph passed to save_graph
    :param model_name: model name, first part of the IR base name
    :param model_framework: framework name, second part of the IR base name
    :param check_weights: when True the '.bin' files are also handed to
        IREngine; otherwise None is passed instead of the '.bin' paths
    """
    full_name = '_'.join((model_name, model_framework))
    xml_file = full_name + '.xml'

    test_xml = tmp_path.joinpath(xml_file)
    save_graph(graph, tmp_path.as_posix(), full_name)
    ref_xml = REFERENCE_MODELS_PATH.joinpath(xml_file)

    # Weight files take part only when explicitly requested.
    test_bin = ref_bin = None
    if check_weights:
        bin_file = full_name + '.bin'
        test_bin = tmp_path.joinpath(bin_file).as_posix()
        ref_bin = REFERENCE_MODELS_PATH.joinpath(bin_file).as_posix()

    # No reference yet: seed it from the IR that was just produced.
    if not ref_xml.exists():
        shutil.copyfile(test_xml.as_posix(), ref_xml.as_posix())
        if check_weights:
            shutil.copyfile(test_bin, ref_bin)

    reference = IREngine(ref_xml.as_posix(), ref_bin)
    candidate = IREngine(test_xml.as_posix(), test_bin)

    is_equal, stderr = reference.compare(candidate)
    if stderr:
        print(stderr)
    assert is_equal
Exemple #2
0
    def test_load_bin_hashes(self):
        """
        Check that an IR loaded with a generated bin-hashes file has hashes
        on every Const node, including Const nodes inside TensorIterator bodies.

        The file returned by generate_bin_hashes_file() is removed afterwards,
        even when the check fails or loading raises.
        """
        path_for_file = self.IR.generate_bin_hashes_file()
        try:
            IR = IREngine(path_to_xml=str(self.xml),
                          path_to_bin=str(path_for_file))
            is_ok = True
            # Check for constant nodes
            for node in IR.graph.get_op_nodes(type='Const'):
                if not node.has_valid('hashes'):
                    log.error('Constant node {} do not include hashes'.format(
                        node.name))
                    is_ok = False

            # Check for TensorIterator Body
            for ti in IR.graph.get_op_nodes(type='TensorIterator'):
                if not ti.has_valid('body'):
                    # A missing body is only logged; it does not fail the test.
                    log.error(
                        "TensorIterator doesn't have body attribute for node: {}".
                        format(ti.name))
                    continue
                for node in ti.body.graph.get_op_nodes(type='Const'):
                    if not node.has_valid('hashes'):
                        log.error(
                            'Constant node {} do not include hashes'.format(
                                node.name))
                        is_ok = False

            self.assertTrue(is_ok, 'Test for function load_bin_hashes failed')
        finally:
            # Always clean up the generated file so a failing run leaves no residue.
            os.remove(path_for_file)
Exemple #3
0
 def test_is_float(self, test_data, result):
     """Check that IREngine.__isfloat returns `result` for `test_data`."""
     self.assertEqual(
         IREngine._IREngine__isfloat(test_data), result,
         "Function __isfloat is not working with value: {}".format(
             test_data))
     log.info(
         'Test for function __is_float passed with value: {}, expected result: {}'
         .format(test_data, result))
Exemple #4
0
    def setUp(self):
        """
        Locate the test IR files and build the IREngine instances used by tests.

        Sets self.xml / self.bin (positive model), self.xml_negative, and three
        engines: self.IR and self.IR_ref from the same files, and
        self.IR_negative from the negative XML with the same BIN.
        """
        parent_dir = os.path.split(os.path.dirname(__file__))[0]
        data_dir = os.path.join(parent_dir, os.pardir, os.pardir,
                                "utils", "test_data")

        self.xml = os.path.join(
            data_dir, "mxnet_synthetic_gru_bidirectional_FP16_1_v6.xml")
        self.xml_negative = os.path.join(
            data_dir,
            "mxnet_synthetic_gru_bidirectional_FP16_1_v6_negative.xml")
        # The weights file sits next to the XML with the same base name.
        xml_stem, _ = os.path.splitext(self.xml)
        self.bin = xml_stem + '.bin'

        xml_missing_msg = 'XML file not found: {}'.format(self.xml)
        self.assertTrue(os.path.exists(self.xml), xml_missing_msg)
        bin_missing_msg = 'BIN file not found: {}'.format(self.bin)
        self.assertTrue(os.path.exists(self.bin), bin_missing_msg)

        xml_path = str(self.xml)
        bin_path = str(self.bin)
        self.IR = IREngine(path_to_xml=xml_path, path_to_bin=bin_path)
        self.IR_ref = IREngine(path_to_xml=xml_path, path_to_bin=bin_path)
        self.IR_negative = IREngine(path_to_xml=str(self.xml_negative),
                                    path_to_bin=bin_path)
Exemple #5
0
def restore_graph_from_ir(path_to_xml: str, path_to_bin: str = None) -> (Graph, dict):
    """
    Function to make valid graph and metadata for MO back stage from IR.

    The IR is loaded with IREngine, its 'ir_version' graph attribute is checked
    to be at least 10, and the graph is copied with copy_graph_with_ops.

    :param path_to_xml: path to the IR .xml file, passed to IREngine
    :param path_to_bin: optional path to the IR .bin file, passed to IREngine
    :return: (restored graph, meta data)
    :raises AssertionError: if the IR version is missing or lower than 10
    """
    ir = IREngine(path_to_xml, path_to_bin)

    ir_version = ir.graph.graph.get('ir_version')
    # `.get` may yield None; comparing None with an int raises TypeError, so a
    # missing version is reported through the same "not supported" assertion.
    assert ir_version is not None and ir_version >= 10, 'IR version {} is not supported, ' \
        'please generate actual IR for your model and use it.'.format(ir_version)

    path = get_mo_root_dir()
    collect_ops(path)
    collect_extenders(path)

    # Create a new copy of graph with correct attributes (shape & type infer, backend attrs etc.)
    new_graph = copy_graph_with_ops(ir.graph)

    return new_graph, copy(ir.meta_data)
Exemple #6
0
 def test_find_input(self):
     """Check that IREngine.__find_input returns a list holding node '0'."""
     # Function under test:
     found_inputs = IREngine._IREngine__find_input(self.IR.graph)
     # Expected result for this test:
     expected_inputs = [Node(self.IR.graph, '0')]
     self.assertTrue(found_inputs == expected_inputs, 'Error')
Exemple #7
0
class TestFunction(unittest.TestCase):
    """Unit tests for IREngine built on the synthetic bidirectional GRU test IR."""

    def setUp(self):
        """
        Locate the test IR files and build the IREngine instances used by tests.

        self.IR and self.IR_ref are loaded from the same XML/BIN pair;
        self.IR_negative is loaded from the negative XML with the same BIN.
        """
        path, _ = os.path.split(os.path.dirname(__file__))
        self.xml = os.path.join(
            path, os.pardir, os.pardir, "utils", "test_data",
            "mxnet_synthetic_gru_bidirectional_FP16_1_v6.xml")
        self.xml_negative = os.path.join(
            path, os.pardir, os.pardir, "utils", "test_data",
            "mxnet_synthetic_gru_bidirectional_FP16_1_v6_negative.xml")
        self.bin = os.path.splitext(self.xml)[0] + '.bin'
        self.assertTrue(os.path.exists(self.xml),
                        'XML file not found: {}'.format(self.xml))
        self.assertTrue(os.path.exists(self.bin),
                        'BIN file not found: {}'.format(self.bin))

        self.IR = IREngine(path_to_xml=str(self.xml),
                           path_to_bin=str(self.bin))
        self.IR_ref = IREngine(path_to_xml=str(self.xml),
                               path_to_bin=str(self.bin))
        self.IR_negative = IREngine(path_to_xml=str(self.xml_negative),
                                    path_to_bin=str(self.bin))

    @generate(*[(4.4, True), ('aaaa', False)])
    def test_is_float(self, test_data, result):
        """Check that IREngine.__isfloat returns `result` for `test_data`."""
        self.assertEqual(
            IREngine._IREngine__isfloat(test_data), result,
            "Function __isfloat is not working with value: {}".format(
                test_data))
        log.info(
            'Test for function __is_float passed with value: {}, expected result: {}'
            .format(test_data, result))

    # TODO add comparison not for type IREngine
    def test_compare(self):
        """Check that compare() reports two engines loaded from the same IR as equal."""
        flag, _ = self.IR.compare(self.IR_ref)
        self.assertTrue(flag, 'Comparing false, test compare function failed')
        log.info('Test for function compare passed')

    def test_compare_negative(self):
        """Check that compare() against the negative IR fails with the expected message."""
        # Reference data for test:
        reference_msg = 'Current node "2" with type "Const" and reference node "2" with type "Input" have different ' \
                        'attr "type" : Const and Input'
        # Check function:
        flag, msg = self.IR.compare(self.IR_negative)
        self.assertFalse(
            flag, 'Comparing flag failed, test compare function failed')
        self.assertEqual(
            '\n'.join(msg), reference_msg,
            'Comparing message failed, test compare negative failed')

        log.info('Test for function compare passed')

    def test_find_input(self):
        """Check that IREngine.__find_input returns a list holding node '0'."""
        # Create references for this test:
        ref_nodes = [Node(self.IR.graph, '0')]
        # Check function:
        a = IREngine._IREngine__find_input(self.IR.graph)
        self.assertTrue(a == ref_nodes, 'Error')

    def test_get_inputs(self):
        """Check that get_inputs() reports the 'data' input with shape [1, 10, 16]."""
        # Reference data for test:
        ref_input_dict = {'data': shape_array([1, 10, 16])}
        # Check function:
        inputs_dict = self.IR.get_inputs()
        self.assertTrue(
            strict_compare_tensors(ref_input_dict['data'],
                                   inputs_dict['data']),
            'Test on function get_inputs failed')
        log.info('Test for function get_inputs passed')

    def test_eq_function(self):
        """Check that `==` treats two engines loaded from the same IR as equal."""
        self.assertTrue(self.IR == self.IR_ref,
                        'Comparing false, test eq function failed')
        log.info('Test for function eq passed')

    @unittest.mock.patch('numpy.savez_compressed')
    def test_generate_bin_hashes_file(self, numpy_savez):
        """Check that generate_bin_hashes_file() calls numpy.savez_compressed once."""
        # Generate bin_hashes file in default directory
        self.IR.generate_bin_hashes_file()
        numpy_savez.assert_called_once()
        log.info(
            'Test for function generate_bin_hashes_file with default folder passed'
        )

    @unittest.mock.patch('numpy.savez_compressed')
    def test_generate_bin_hashes_file_custom_directory(self, numpy_savez):
        """Same check as above, with an explicit `path_for_file` argument."""
        # Generate bin_hashes file in custom directory
        directory_for_file = os.path.join(
            os.path.split(os.path.dirname(__file__))[0], "utils", "test_data",
            "bin_hash")
        self.IR.generate_bin_hashes_file(path_for_file=directory_for_file)
        numpy_savez.assert_called_once()
        log.info(
            'Test for function generate_bin_hashes_file with custom folder passed'
        )

    @generate(*[({
        'order': '1,0,2'
    }, {
        'order': [1, 0, 2]
    }), ({
        'order': '1'
    }, {
        'order': 1
    })])
    def test_normalize_attr(self, test_data, reference):
        """Check that IREngine.__normalize_attrs turns `test_data` into `reference`."""
        result_dict = IREngine._IREngine__normalize_attrs(attrs=test_data)
        self.assertTrue(reference == result_dict,
                        'Test on function normalize_attr failed')
        log.info('Test for function normalize_attr passed')

    def test_load_bin_hashes(self):
        """
        Check that an IR loaded with a generated bin-hashes file has hashes
        on every Const node, including Const nodes inside TensorIterator bodies.

        The file returned by generate_bin_hashes_file() is removed afterwards,
        even when the check fails or loading raises.
        """
        path_for_file = self.IR.generate_bin_hashes_file()
        try:
            IR = IREngine(path_to_xml=str(self.xml),
                          path_to_bin=str(path_for_file))
            is_ok = True
            # Check for constant nodes
            const_nodes = IR.graph.get_op_nodes(type='Const')
            for node in const_nodes:
                if not node.has_valid('hashes'):
                    log.error('Constant node {} do not include hashes'.format(
                        node.name))
                    is_ok = False

            # Check for TensorIterator Body
            ti_nodes = IR.graph.get_op_nodes(type='TensorIterator')
            for ti in ti_nodes:
                if not ti.has_valid('body'):
                    # A missing body is only logged; it does not fail the test.
                    log.error(
                        "TensorIterator doesn't have body attribute for node: {}".
                        format(ti.name))
                else:
                    const_ti_nodes = ti.body.graph.get_op_nodes(type='Const')
                    for node in const_ti_nodes:
                        if not node.has_valid('hashes'):
                            log.error(
                                'Constant node {} do not include hashes'.format(
                                    node.name))
                            is_ok = False

            self.assertTrue(is_ok, 'Test for function load_bin_hashes failed')
        finally:
            # Always clean up the generated file so a failing run leaves no residue.
            os.remove(path_for_file)

    @generate(*[
        ("0", True),
        ("1", True),
        ("-1", True),
        ("-", False),
        ("+1", True),
        ("+", False),
        ("1.0", False),
        ("-1.0", False),
        ("1.5", False),
        ("+1.5", False),
        ("abracadabra", False),
    ])
    def test_isint(self, value, result):
        """Check that IREngine.__isint returns `result` for the string `value`."""
        self.assertEqual(IREngine._IREngine__isint(value), result)
Exemple #8
0
 def test_isint(self, value, result):
     """Check that IREngine.__isint returns `result` for `value`."""
     actual = IREngine._IREngine__isint(value)
     self.assertEqual(actual, result)
Exemple #9
0
 def test_normalize_attr(self, test_data, reference):
     """Check that IREngine.__normalize_attrs turns `test_data` into `reference`."""
     normalized = IREngine._IREngine__normalize_attrs(attrs=test_data)
     matches_reference = reference == normalized
     self.assertTrue(matches_reference,
                     'Test on function normalize_attr failed')
     log.info('Test for function normalize_attr passed')