コード例 #1
0
ファイル: loader_test.py プロジェクト: pc2/CustoNN2
 def test_load_symbol_nodes(self, mock_symbol_load, mock_isfile,
                            mock_json_loads, mock_json_load):
     """Check the nodes returned when the second positional argument is True.

     ``mock_json_loads`` is configured with a populated 'nodes' entry while
     ``mock_json_load`` only carries an empty string, so the assertion
     holds only if the result comes from the ``mock_json_loads`` stub.

     NOTE(review): the patch decorators supplying the mock arguments are
     not visible here; presumably they patch ``json.loads``/``json.load``,
     ``os.path.isfile`` and the MXNet symbol loader -- confirm.
     """
     # Stub out every collaborator before invoking the loader.
     mock_symbol_load.return_value = MockSymbolLoadObj()
     mock_json_loads.return_value = {'nodes': {'node1': 1}}
     mock_json_load.return_value = {'nodes': ''}
     mock_isfile.return_value = True

     # ``open`` inside the loader module is replaced for the duration of
     # the call; the mock object itself is never inspected.
     with patch('mo.front.mxnet.loader.open'):
         loaded_nodes = load_symbol_nodes("model_name", True)
         self.assertEqual({'node1': 1}, loaded_nodes)
コード例 #2
0
 def test_load_symbol_nodes_with_json(self, mock_symbol_load, mock_isfile,
                                      mock_json_loads, mock_json_load):
     """Check the nodes returned when an explicit symbol file is given.

     Here the populated 'nodes' entry sits on ``mock_json_load`` (the
     ``json.load`` stub) and ``mock_json_loads`` only carries an empty
     string, so the assertion holds only if the result comes from the
     ``mock_json_load`` stub.

     NOTE(review): the patch decorators supplying the mock arguments are
     not visible here -- confirm what each one replaces.
     """
     # Stub out every collaborator before invoking the loader.
     mock_symbol_load.return_value = MockSymbolLoadObj()
     mock_json_loads.return_value = {'nodes': ''}
     # json.load
     mock_json_load.return_value = {'nodes': {'node1': 1}}
     mock_isfile.return_value = True

     loader_kwargs = {
         'input_symbol': "some-symbol.json",
         'legacy_mxnet_model': False,
     }
     # ``open`` inside the loader module is replaced for the duration of
     # the call; the mock object itself is never inspected.
     with patch('mo.front.mxnet.loader.open'):
         loaded_nodes = load_symbol_nodes("model_name", **loader_kwargs)
         self.assertEqual({'node1': 1}, loaded_nodes)
コード例 #3
0
ファイル: loader_test.py プロジェクト: pc2/CustoNN2
 def test_load_symbol_with_custom_nodes(self, mock_symbol_load, mock_isfile,
                                        mock_json_loads, mock_json_load):
     """Check that two 'custom_op' nodes are returned unchanged.

     ``mock_json_load`` is configured with a list of two
     ``{'op': 'custom_op'}`` nodes; the test asserts that the loader,
     called with False as its second positional argument, returns exactly
     two entries and that each equals that node description.

     NOTE(review): the patch decorators supplying the mock arguments are
     not visible here -- confirm what each one replaces.
     """
     # Stub out every collaborator before invoking the loader.
     mock_symbol_load.return_value = MockSymbolLoadObj()
     mock_json_loads.return_value = {'nodes': {'node1': 1}}
     # Two separate (but equal) node dicts, as in a parsed symbol file.
     mock_json_load.return_value = {
         'nodes': [{'op': 'custom_op'} for _ in range(2)],
     }
     mock_isfile.return_value = True

     # ``open`` inside the loader module is replaced for the duration of
     # the call; the mock object itself is never inspected.
     with patch('mo.front.mxnet.loader.open'):
         loaded_nodes = load_symbol_nodes("model_name", False)
         self.assertEqual(2, len(loaded_nodes))
         for entry in loaded_nodes:
             self.assertEqual({'op': 'custom_op'}, entry)