def test_next_node_on_gpu(self, app_client):
    """Continue to node ``Default/TransData-op99`` on a mocked GPU backend.

    Sends a node-level ``continue`` control request and checks the immediate
    reply. It then retrieves the full state (``mode='all'``) and compares it
    against ``retrieve_next_node_on_gpu.json``.
    """
    expected_reply = {
        'metadata': {
            'state': 'sending',
            'enable_recheck': False,
        }
    }
    run_to_node = {
        'mode': 'continue',
        'level': 'node',
        'name': 'Default/TransData-op99',
    }
    mock_client = MockDebuggerClient(backend='GPU')
    with mock_client.get_thread_instance():
        check_state(app_client)
        # Issue the run command (original intent: reach a watchpoint hit).
        reply = get_request_result(app_client, 'control', run_to_node)
        assert reply == expected_reply
        # Wait for the state to settle, then compare the full retrieved state.
        check_state(app_client)
        send_and_compare_result(app_client, 'retrieve', {'mode': 'all'},
                                'retrieve_next_node_on_gpu.json')
        send_terminate_cmd(app_client)
def test_continue_on_gpu(self, app_client, params, expect_file):
    """Test the ``continue`` control command on GPU with recommended watchpoints on.

    Args:
        app_client: Client used to send requests to the debugger server.
        params: Extra fields merged into the ``{'mode': 'continue'}`` control
            request body (presumably supplied by a parametrize decorator
            outside this view — confirm).
        expect_file: Name of the expected-result file that the
            ``retrieve`` (``mode='all'``) response is compared against.
    """
    gpu_debugger_client = MockDebuggerClient(backend='GPU', graph_num=2)
    # Temporarily enable recommended watchpoints for this test only. The
    # ``finally`` restores the previous global value even if an assertion
    # fails, so the setting does not leak into other tests.
    original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
    settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
    try:
        with gpu_debugger_client.get_thread_instance():
            check_state(app_client)
            # send run command to get watchpoint hit
            url = 'control'
            body_data = {'mode': 'continue'}
            body_data.update(params)
            res = get_request_result(app_client, url, body_data)
            assert res == {
                'metadata': {
                    'state': 'sending',
                    'enable_recheck': False
                }
            }
            # get metadata
            check_state(app_client)
            url = 'retrieve'
            body_data = {'mode': 'all'}
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)
    finally:
        settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value
def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
    """Check a ``retrieve`` request made at training start on a mocked Ascend backend.

    ``body_data`` is sent to the ``retrieve`` endpoint and the response is
    compared against ``expect_file``.
    """
    ascend_client = MockDebuggerClient(backend='Ascend', graph_num=2)
    with ascend_client.get_thread_instance():
        check_state(app_client)
        send_and_compare_result(app_client, 'retrieve', body_data, expect_file)
        send_terminate_cmd(app_client)