def test_element_type(self, mock_argparse):
        """Check that converting via the mock frontend sets the element type once, to int8.

        ``mock_argparse`` is presumably injected by a ``patch`` decorator that
        is outside this view; it is not used directly in the body.
        """
        parser = argparse.ArgumentParser()
        main(parser, fem, 'mock_mo_ngraph_frontend')
        model_stat = get_model_statistic()

        # 'set_element_type' must have been invoked exactly once, with int8
        assert model_stat.set_element_type == 1
        assert model_stat.lastArgElementType == get_element_type(np.int8)
예제 #2
0
 def test_decode_name_with_port_noname(self):
     """An unknown name without a port specifier must raise 'No node with name ...'.

     The tensor-name and operation-name lookups must each have been attempted
     exactly once.
     """
     # assertRaisesRegex uses re.search, so no trailing wildcard is needed;
     # a trailing '*' would only make the final character of the name optional.
     with self.assertRaisesRegex(Error, r'No node with name.*mocknoname'):
         decode_name_with_port(self.model, 'mocknoname')
     model_stat = get_model_statistic()
     assert model_stat.get_place_by_tensor_name == 1
     assert model_stat.get_place_by_operation_name == 1
    def test_input_shape(self, mock_argparse):
        """Check that the converter sets the input partial shape to [1, 2, 3, 4] exactly once."""
        main(argparse.ArgumentParser(), fem, 'mock_mo_ngraph_frontend')
        statistics = get_model_statistic()

        # exactly one 'set_partial_shape' call, carrying the expected shape
        expected_shape = PartialShape([1, 2, 3, 4])
        assert statistics.set_partial_shape == 1
        assert statistics.lastArgPartialShape == expected_shape
예제 #4
0
    def test_decode_name_with_port_op(self):
        """A plain name ('operation') decodes to a place.

        Both the tensor-name and the operation-name lookups run exactly once.
        """
        place = decode_name_with_port(self.model, "operation")
        counters = get_model_statistic()

        assert counters.get_place_by_tensor_name == 1
        assert counters.get_place_by_operation_name == 1
        assert place
    def test_extract_subgraph(self, mock_argparse):
        """Check that the converter cuts the model via 'extract_subgraph'.

        Neither 'override_all_inputs' nor 'override_all_outputs' may be used.
        """
        main(argparse.ArgumentParser(), fem, 'mock_mo_ngraph_frontend')
        model_stat = get_model_statistic()

        # only 'extract_subgraph' is expected; no full input/output override
        assert model_stat.override_all_inputs == 0
        assert model_stat.override_all_outputs == 0
        assert model_stat.extract_subgraph == 1
예제 #6
0
 def test_decode_name_with_port_collision(self):
     """'tensorAndOp:0' must raise a 'Name collision' error.

     The tensor-name and operation-name lookups each run once, and the
     candidate places are compared with is_equal_data at least once.
     """
     # match the name in full; a trailing '*' would only make its last
     # character optional (assertRaisesRegex already searches, not matches)
     with self.assertRaisesRegex(Error, r'Name collision.*tensorAndOp'):
         decode_name_with_port(self.model, 'tensorAndOp:0')
     model_stat = get_model_statistic()
     place_stat = get_place_statistic()
     assert model_stat.get_place_by_tensor_name == 1
     assert model_stat.get_place_by_operation_name == 1
     assert place_stat.is_equal_data > 0
예제 #7
0
    def test_decode_name_with_port_delim_tensor_no_collision_in(self):
        """'0:tensor' decodes successfully without any input-port request.

        The tensor-name lookup runs once and the operation-name lookup twice;
        get_input_port is never called.
        """
        decoded = decode_name_with_port(self.model, '0:tensor')
        model_counts = get_model_statistic()
        place_counts = get_place_statistic()

        assert model_counts.get_place_by_tensor_name == 1
        assert model_counts.get_place_by_operation_name == 2
        # '0:tensor' must not end up being resolved as input port 0
        assert place_counts.get_input_port == 0
        assert decoded
    def test_set_batch_size(self, mock_argparse):
        """Check that the first dimension of the input shapes is replaced by 123.

        mock_return_partial_shape presumably makes the mock model report the
        shape [-1, 2, 3, 4] for its inputs. After conversion, the last shape
        passed to 'set_partial_shape' must be [123, 2, 3, 4]. The value 123
        presumably comes from the mocked command-line arguments (not visible
        here).
        """
        mock_return_partial_shape(PartialShape([-1, 2, 3, 4]))
        main(argparse.ArgumentParser(), fem, 'mock_mo_ngraph_frontend')
        stat = get_model_statistic()

        # verify that 'get_partial_shape' and 'set_partial_shape' were called
        # 2 is because mock model has 2 inputs
        assert stat.get_partial_shape == 2
        assert stat.set_partial_shape == 2
        assert stat.lastArgPartialShape == PartialShape([123, 2, 3, 4])
    def test_override_same_outputs(self, mock_argparse):
        """Check that only the inputs are overridden when the outputs stay unchanged."""
        main(argparse.ArgumentParser(), fem, 'mock_mo_ngraph_frontend')
        model_stat = get_model_statistic()

        # outputs were not changed, so 'override_all_inputs' alone is expected
        assert model_stat.override_all_inputs == 1
        assert model_stat.override_all_outputs == 0
        assert model_stat.extract_subgraph == 0
예제 #10
0
    def test_decode_name_with_port_delim_op_out(self):
        """'operation:7' decodes to a place after requesting output port 7."""
        place = decode_name_with_port(self.model, 'operation:7')
        model_stats = get_model_statistic()
        place_stats = get_place_statistic()

        assert model_stats.get_place_by_tensor_name == 1
        assert model_stats.get_place_by_operation_name == 2
        # exactly one output-port request, for index 7
        assert place_stats.get_output_port == 1
        assert place_stats.lastArgInt == 7
        assert place
예제 #11
0
    def test_decode_name_with_port_delim_no_port_in(self):
        """'1234:operation' must fail with 'No node with name ...'.

        Input port 1234 is requested exactly once before the lookup fails.
        """
        # match the name in full; a trailing '*' would only make its last
        # character optional
        with self.assertRaisesRegex(Error, r'No node with name.*1234:operation'):
            decode_name_with_port(self.model, '1234:operation')
        model_stat = get_model_statistic()
        place_stat = get_place_statistic()

        assert model_stat.get_place_by_tensor_name == 1
        assert model_stat.get_place_by_operation_name == 2
        assert place_stat.get_input_port == 1
        assert place_stat.lastArgInt == 1234
예제 #12
0
    def test_decode_name_with_port_delim_equal_data_in(self):
        """'0:conv2d' decodes successfully when the candidate places hold equal data.

        set_equal_data('conv2d', 'conv2d') is assumed to make the mock report
        the candidates as equal, so no name collision is raised.
        """
        set_equal_data('conv2d', 'conv2d')
        decoded = decode_name_with_port(self.model, '0:conv2d')
        model_counts = get_model_statistic()
        place_counts = get_place_statistic()

        assert model_counts.get_place_by_tensor_name == 1
        assert model_counts.get_place_by_operation_name == 2
        assert place_counts.get_input_port == 1
        # candidates were compared for equality at least once
        assert place_counts.is_equal_data > 0
        assert decoded
예제 #13
0
    def test_decode_name_with_port_delim_op_collision_in(self):
        """'0:operation' must raise a name collision that involves no 'Tensor' candidate.

        Input port 0 is requested exactly once, and the candidate places are
        compared with is_equal_data at least once.
        """
        # (?!.*Tensor) rejects a message mentioning 'Tensor' after
        # 'Name collision' on the same line; the name is matched in full
        # (a trailing '*' would only make its last character optional).
        with self.assertRaisesRegex(
                Error, r'Name collision(?!.*Tensor).*0:operation'):
            decode_name_with_port(self.model, '0:operation')
        model_stat = get_model_statistic()
        place_stat = get_place_statistic()

        assert model_stat.get_place_by_tensor_name == 1
        assert model_stat.get_place_by_operation_name == 2
        assert place_stat.is_equal_data > 0
        assert place_stat.get_input_port == 1
        assert place_stat.lastArgInt == 0
예제 #14
0
    def test_decode_name_with_port_delim_no_port_out(self):
        """'operation:1234' must fail with 'No node with name ...'.

        Output port 1234 is requested exactly once before the lookup fails.
        """
        # match the port number in full; with a trailing '*' the pattern
        # would also accept 'operation:123'
        with self.assertRaisesRegex(
                Error, r'No node with name.*operation:1234'):
            decode_name_with_port(self.model, 'operation:1234')
        model_stat = get_model_statistic()
        place_stat = get_place_statistic()

        assert model_stat.get_place_by_tensor_name == 1
        assert model_stat.get_place_by_operation_name >= 1
        assert place_stat.get_output_port == 1
        assert place_stat.lastArgInt == 1234
예제 #15
0
    def test_decode_name_with_port_delim_all_same_data(self):
        """'8:9' decodes successfully when every candidate place holds the same data.

        set_equal_data('8', '9') is assumed to make the mock report all found
        places as equal, so the input-port, output-port and plain-name lookups
        do not collide.
        """
        set_equal_data('8', '9')
        place = decode_name_with_port(self.model, '8:9')
        model_stats = get_model_statistic()
        place_stats = get_place_statistic()

        assert model_stats.get_place_by_tensor_name == 1
        assert model_stats.get_place_by_operation_name == 3
        assert place_stats.get_input_port == 1
        assert place_stats.get_output_port == 1
        # at least 3 place comparisons are expected
        assert place_stats.is_equal_data > 2
        assert place
    def test_error_batch(self, mock_argparse):
        """Check that a first dimension that doesn't look like a batch is left alone.

        mock_return_partial_shape presumably makes the mock model report
        [122, 2, 3, 4]. MO shall not convert anything and must log an error
        containing 'question=39'.
        """
        mock_return_partial_shape(PartialShape([122, 2, 3, 4]))
        with self.assertLogs() as captured:
            main(argparse.ArgumentParser(), fem, 'mock_mo_ngraph_frontend')

        stat = get_model_statistic()

        # the expected error must appear in at least one captured log line;
        # show all captured lines on failure to ease debugging
        assert any('question=39' in line for line in captured.output), \
            captured.output

        # verify that 'get_partial_shape' was called
        assert stat.get_partial_shape == 1
        # verify that 'set_partial_shape' was not called
        assert stat.set_partial_shape == 0