Example #1
    def test_submit_models(self):
        _reset()
        nni.retiarii.debug_configs.framework = 'pytorch'
        os.makedirs('generated', exist_ok=True)
        from nni.runtime import protocol
        import nni.runtime.platform.test as tt
        protocol._set_out_file(
            open('generated/debug_protocol_out_file.py', 'wb'))
        protocol._set_in_file(
            open('generated/debug_protocol_out_file.py', 'rb'))
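        # The out and in streams are backed by the same file, so commands written
        # during submit_models below can be read back with protocol.receive().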

        models = _load_mnist(2)

        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
        advisor._channel = protocol.LegacyCommandChannel()
        advisor.default_worker.start()
        advisor.assessor_worker.start()

        remote = RemoteConfig(machine_list=[])
        remote.machine_list.append(
            RemoteMachineConfig(host='test', gpu_indices=[0, 1, 2, 3]))
        cgo_engine = CGOExecutionEngine(training_service=remote,
                                        batch_waiting_time=0)
        set_execution_engine(cgo_engine)
        submit_models(*models)
        time.sleep(3)

        if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
            cmd, data = protocol.receive()
            params = nni.load(data)

            tt.init_params(params)

            trial_thread = threading.Thread(
                target=CGOExecutionEngine.trial_execute_graph)
            trial_thread.start()
            last_metric = None
            while True:
                time.sleep(1)
                if tt._last_metric:
                    metric = tt.get_last_metric()
                    if metric == last_metric:
                        continue
                    if 'value' in metric:
                        metric['value'] = json.dumps(metric['value'])
                    advisor.handle_report_metric_data(metric)
                    last_metric = metric
                if not trial_thread.is_alive():
                    trial_thread.join()
                    break

            trial_thread.join()

        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()
        cgo_engine.join()
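Examples #1 and #4 also rely on module-level imports and on two local helpers (_reset, _load_mnist) that are not part of the excerpts. The sketch below lists the imports they appear to assume; the exact module paths vary across NNI versions and are a guess, and the two helpers belong to the original test module.

# Assumed imports for Examples #1 and #4 (module paths are a best guess for the
# NNI Retiarii API used above; check them against your NNI version).
import os
import json
import time
import threading

import torch
import nni
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.execution import submit_models, set_execution_engine
from nni.retiarii.execution.cgo_engine import CGOExecutionEngine
from nni.common.device import GPUDevice

# _reset() and _load_mnist(n) are helpers defined in the original test module
# and are not reproduced in these excerpts.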
Example #2
 def _test_report_final_result(self, in_, out):
     nni.report_final_result(in_)
     self.assertEqual(
         test_platform.get_last_metric(), {
             'parameter_id': 'test_param',
             'trial_job_id': 'test_trial_job_id',
             'type': 'FINAL',
             'sequence': 0,
             'value': out
         })
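Example #2 is a parameterized helper rather than a test case by itself; test_platform here is assumed to be nni.runtime.platform.test (imported under the alias tt in Examples #1 and #4). Hypothetical callers might look like the sketch below; the in_/out value pairs are illustrative only, not taken from the original suite.

 # Hypothetical callers of the helper above; the value pairs are examples only.
 def test_report_final_result_float(self):
     self._test_report_final_result(1.5, 1.5)

 def test_report_final_result_dict(self):
     self._test_report_final_result({'default': 0.9, 'loss': 0.1},
                                    {'default': 0.9, 'loss': 0.1})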
Example #3
 def test_report_intermediate_result(self):
     nni.report_intermediate_result(123)
     self.assertEqual(
         test_platform.get_last_metric(), {
             'parameter_id': 'test_param',
             'trial_job_id': 'test_trial_job_id',
             'type': 'PERIODICAL',
             'sequence': 0,
             'value': 123
         })
Example #4
    def test_submit_models(self):
        _reset()
        nni.retiarii.debug_configs.framework = 'pytorch'
        os.makedirs('generated', exist_ok=True)
        from nni.runtime import protocol
        import nni.runtime.platform.test as tt
        protocol._set_out_file(
            open('generated/debug_protocol_out_file.py', 'wb'))
        protocol._set_in_file(
            open('generated/debug_protocol_out_file.py', 'rb'))

        models = _load_mnist(2)

        advisor = RetiariiAdvisor()
        cgo_engine = CGOExecutionEngine(
            devices=[
                GPUDevice("test", 0),
                GPUDevice("test", 1),
                GPUDevice("test", 2),
                GPUDevice("test", 3),
            ],
            batch_waiting_time=0)
        set_execution_engine(cgo_engine)
        submit_models(*models)
        time.sleep(3)

        if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
            cmd, data = protocol.receive()
            params = nni.load(data)

            tt.init_params(params)

            trial_thread = threading.Thread(
                target=CGOExecutionEngine.trial_execute_graph)
            trial_thread.start()
            last_metric = None
            while True:
                time.sleep(1)
                if tt._last_metric:
                    metric = tt.get_last_metric()
                    if metric == last_metric:
                        continue
                    if 'value' in metric:
                        metric['value'] = json.dumps(metric['value'])
                    advisor.handle_report_metric_data(metric)
                    last_metric = metric
                if not trial_thread.is_alive():
                    trial_thread.join()
                    break

            trial_thread.join()

        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()
        cgo_engine.join()
Example #5
 def test_report_final_result_nparray(self):
     arr = np.array([[1, 2, 3], [4, 5, 6]])
     nni.report_final_result(arr)
     out = test_platform.get_last_metric()
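     # Note: `out` (the metric recorded by the test platform) is fetched but not
     # asserted on; the checks below only verify the contents of the reported array.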
     self.assertEqual(len(arr), 2)
     self.assertEqual(len(arr[0]), 3)
     self.assertEqual(len(arr[1]), 3)
     self.assertEqual(arr[0][0], 1)
     self.assertEqual(arr[0][1], 2)
     self.assertEqual(arr[0][2], 3)
     self.assertEqual(arr[1][0], 4)
     self.assertEqual(arr[1][1], 5)
     self.assertEqual(arr[1][2], 6)
Example #6
    def test_submit_models(self):
        os.environ['CGO'] = 'true'
        os.makedirs('generated', exist_ok=True)
        from nni.runtime import protocol, platform
        import nni.runtime.platform.test as tt
        protocol._out_file = open('generated/debug_protocol_out_file.py', 'wb')
        protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb')

        models = _load_mnist(2)
        advisor = RetiariiAdvisor()
        submit_models(*models)

        if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
            cmd, data = protocol.receive()
            params = json.loads(data)
            params['parameters']['training_kwargs']['max_steps'] = 100

            tt.init_params(params)

            trial_thread = threading.Thread(
                target=CGOExecutionEngine.trial_execute_graph)
            trial_thread.start()
            last_metric = None
            while True:
                time.sleep(1)
                if tt._last_metric:
                    metric = tt.get_last_metric()
                    if metric == last_metric:
                        continue
                    advisor.handle_report_metric_data(metric)
                    last_metric = metric
                if not trial_thread.is_alive():
                    break

            trial_thread.join()
        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()
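Example #6 assigns protocol._out_file and protocol._in_file directly; Examples #1 and #4 do the same setup through the protocol's setter helpers. The equivalent form for the setup above is:

# Same stream setup as in Example #6, written with the setters used in Examples #1 and #4.
from nni.runtime import protocol

protocol._set_out_file(open('generated/debug_protocol_out_file.py', 'wb'))
protocol._set_in_file(open('generated/debug_protocol_out_file.py', 'rb'))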