def test_load_symbol_nodes(self, mock_symbol_load, mock_isfile, mock_json_loads, mock_json_load):
    """Legacy models: nodes must come from the json.loads mock, not json.load.

    The json.load mock yields empty nodes while the json.loads mock yields
    {'node1': 1}, so the assertion only passes when load_symbol_nodes takes the
    legacy branch (mock_symbol_load + json.loads).
    """
    # NOTE(review): the mock_* arguments are presumably injected by @patch
    # decorators above this method (not visible here); their order must match
    # the decorator stack.
    mock_isfile.return_value = True
    mock_json_load.return_value = {'nodes': ''}
    mock_json_loads.return_value = {'nodes': {'node1': 1}}
    mock_symbol_load_obj = MockSymbolLoadObj()
    mock_symbol_load.return_value = mock_symbol_load_obj
    # Patch the loader module's open() so no real file is opened.
    with patch('mo.front.mxnet.loader.open'):
        # Pass the flag by keyword: load_symbol_nodes also accepts input_symbol,
        # so a positional boolean could bind to the wrong parameter.
        self.assertEqual({'node1': 1}, load_symbol_nodes("model_name", legacy_mxnet_model=True))
def test_load_symbol_nodes_with_json(self, mock_symbol_load, mock_isfile, mock_json_loads, mock_json_load):
    """An explicit input_symbol JSON file yields the nodes from the json.load mock.

    Here json.load carries the real nodes and json.loads is left empty, so the
    expected value can only be produced via the json.load mock.
    """
    mock_isfile.return_value = True
    # json.load supplies the nodes for the explicitly given symbol file ...
    mock_json_load.return_value = {'nodes': {'node1': 1}}
    # ... while json.loads is a decoy that must not be used on this path.
    mock_json_loads.return_value = {'nodes': ''}
    mock_symbol_load.return_value = MockSymbolLoadObj()

    with patch('mo.front.mxnet.loader.open'):
        loaded_nodes = load_symbol_nodes(
            "model_name",
            input_symbol="some-symbol.json",
            legacy_mxnet_model=False,
        )
    self.assertEqual({'node1': 1}, loaded_nodes)
def test_load_symbol_with_custom_nodes(self, mock_symbol_load, mock_isfile, mock_json_loads, mock_json_load):
    """Non-legacy loading returns every node from the JSON, including custom ops.

    The json.load mock holds two 'custom_op' nodes; the json.loads mock holds a
    different payload, so both custom nodes must come back unchanged.
    """
    # NOTE(review): mock_* arguments presumably come from @patch decorators
    # above this method (not visible here).
    mock_isfile.return_value = True
    mock_json_load.return_value = {
        'nodes': [
            {'op': 'custom_op'},
            {'op': 'custom_op'},
        ]
    }
    mock_json_loads.return_value = {'nodes': {'node1': 1}}
    mock_symbol_load_obj = MockSymbolLoadObj()
    mock_symbol_load.return_value = mock_symbol_load_obj
    # Patch the loader module's open() so no real file is opened.
    with patch('mo.front.mxnet.loader.open'):
        # Keyword argument makes the intent explicit; a positional boolean could
        # bind to input_symbol instead of legacy_mxnet_model.
        list_nodes = load_symbol_nodes("model_name", legacy_mxnet_model=False)
    self.assertEqual(2, len(list_nodes))
    for node in list_nodes:
        self.assertEqual({'op': 'custom_op'}, node)