Exemple #1
0
 def test_next_node_on_gpu(self, app_client):
     """Test continuing to a named node on the GPU backend.

     Sends a node-level 'continue' control request against a mock GPU
     debugger client, checks the immediate reply, then compares the
     result of a full 'retrieve' request with the expected file.

     Args:
         app_client: Client passed to the request helpers.
     """
     mock_client = MockDebuggerClient(backend='GPU')
     with mock_client.get_thread_instance():
         check_state(app_client)
         # Ask the debugger to continue, at node level, to the named node.
         control_body = {
             'mode': 'continue',
             'level': 'node',
             'name': 'Default/TransData-op99',
         }
         reply = get_request_result(app_client, 'control', control_body)
         expected_reply = {
             'metadata': {'state': 'sending', 'enable_recheck': False},
         }
         assert reply == expected_reply
         # Retrieve everything and compare it with the stored expectation.
         check_state(app_client)
         send_and_compare_result(
             app_client,
             'retrieve',
             {'mode': 'all'},
             'retrieve_next_node_on_gpu.json',
         )
         send_terminate_cmd(app_client)
Exemple #2
0
 def test_continue_on_gpu(self, app_client, params, expect_file):
     """Test the 'continue' control command on GPU with two graphs.

     ENABLE_RECOMMENDED_WATCHPOINTS is switched on for the duration of
     the test and its previous value is restored afterwards, even when
     the test body raises.

     Args:
         app_client: Client passed to the request helpers.
         params: Extra fields merged into the 'continue' request body.
         expect_file: Expected-result file handed to
             send_and_compare_result for the 'retrieve' request.
     """
     mock_client = MockDebuggerClient(backend='GPU', graph_num=2)
     saved_flag = settings.ENABLE_RECOMMENDED_WATCHPOINTS
     settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
     try:
         with mock_client.get_thread_instance():
             check_state(app_client)
             # Issue 'continue', extended with the case-specific params.
             control_body = {'mode': 'continue'}
             control_body.update(params)
             reply = get_request_result(app_client, 'control', control_body)
             expected_reply = {
                 'metadata': {'state': 'sending', 'enable_recheck': False},
             }
             assert reply == expected_reply
             # Retrieve everything and compare it with the expectation.
             check_state(app_client)
             send_and_compare_result(
                 app_client, 'retrieve', {'mode': 'all'}, expect_file)
             send_terminate_cmd(app_client)
     finally:
         # Put the global setting back so other tests are unaffected.
         settings.ENABLE_RECOMMENDED_WATCHPOINTS = saved_flag
Exemple #3
0
 def test_multi_retrieve_when_train_begin(self, app_client, body_data,
                                          expect_file):
     """Test 'retrieve' at train begin on Ascend with two graphs.

     Args:
         app_client: Client passed to the request helpers.
         body_data: Body of the 'retrieve' request.
         expect_file: Expected-result file handed to
             send_and_compare_result.
     """
     mock_client = MockDebuggerClient(backend='Ascend', graph_num=2)
     with mock_client.get_thread_instance():
         check_state(app_client)
         send_and_compare_result(
             app_client, 'retrieve', body_data, expect_file)
         send_terminate_cmd(app_client)