コード例 #1
0
    def test_assessor(self):
        """Drive a NaiveAssessor-backed dispatcher with a scripted command stream.

        The script reports periodical metrics for trials A and B, ends both
        trials, and finally sends a NewTrialJob command.  The test then checks
        that the dispatcher recorded an AssertionError for that last command,
        that the module-level ``_trials`` / ``_end_trials`` records match the
        script, and that exactly one KillTrialJob command (for trial "A") was
        emitted.
        """
        # NOTE(review): _reverse_io/_restore_io presumably swap the protocol's
        # in/out buffers so the test can write what the dispatcher will read,
        # and later read what it wrote -- confirm against their definitions.
        _reverse_io()
        metric_reports = (
            '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}',
            '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}',
            '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}',
        )
        for report in metric_reports:
            send(CommandType.ReportMetricData, report)
        trial_endings = (
            '{"trial_job_id":"A","event":"SYS_CANCELED"}',
            '{"trial_job_id":"B","event":"SUCCEEDED"}',
        )
        for ending in trial_endings:
            send(CommandType.TrialEnd, ending)
        # Last command in the script; the assertions below expect it to be
        # rejected as unsupported.
        send(CommandType.NewTrialJob, 'null')
        _restore_io()

        naive_assessor = NaiveAssessor()
        dispatcher = MsgDispatcher(None, naive_assessor)
        msg_dispatcher_base._worker_fast_exit_on_terminate = False

        dispatcher.run()

        # The first recorded worker exception must be the rejection of the
        # NewTrialJob command.
        first_error = dispatcher.worker_exceptions[0]
        self.assertIs(type(first_error), AssertionError)
        self.assertEqual(
            first_error.args[0],
            'Unsupported command: CommandType.NewTrialJob',
        )

        # Trial ids seen in metric reports, and (trial id, flag) pairs seen at
        # trial end, in arrival order.
        self.assertEqual(_trials, ['A', 'B', 'A'])
        self.assertEqual(_end_trials, [('A', False), ('B', True)])

        # Read back what the dispatcher emitted: a single kill request for "A"
        # and nothing after it.
        _reverse_io()
        command, data = receive()
        self.assertIs(command, CommandType.KillTrialJob)
        self.assertEqual(data, '"A"')
        leftover = _out_buf.read()
        self.assertEqual(len(leftover), 0)
コード例 #2
0
 def _assert_params(self, parameter_id, param, trial_results, search_space):
     """Receive one command and check it is the expected NewTrialJob.

     The payload is decoded as JSON; its ``parameter_id`` must equal the
     given id, its ``parameter_source`` must be ``'algorithm'``, and the
     ``param``, ``trial_results`` and ``search_space`` entries under
     ``parameters`` must equal the corresponding arguments.
     """
     command, raw_payload = receive()
     self.assertIs(command, CommandType.NewTrialJob)

     payload = json.loads(raw_payload)
     self.assertEqual(payload['parameter_id'], parameter_id)
     self.assertEqual(payload['parameter_source'], 'algorithm')

     # Compare the nested parameter fields in a fixed order.
     generated = payload['parameters']
     expected_fields = (
         ('param', param),
         ('trial_results', trial_results),
         ('search_space', search_space),
     )
     for field_name, expected_value in expected_fields:
         self.assertEqual(generated[field_name], expected_value)
コード例 #3
0
    def test_submit_models(self):
        """Submit two MNIST models through the CGO engine and relay trial metrics.

        The protocol's in/out files are redirected to a scratch file under
        ``generated/`` so that commands written by the advisor can be read
        back by this test.  When at least two CUDA devices are available, the
        trial is executed in a background thread while this thread polls the
        test platform for new metrics and forwards them to the advisor.
        Finally the advisor's workers and the engine are joined.
        """
        _reset()
        nni.retiarii.debug_configs.framework = 'pytorch'
        os.makedirs('generated', exist_ok=True)
        from nni.runtime import protocol
        import nni.runtime.platform.test as tt
        # Writer and reader share the same scratch file.  The handles are
        # handed to the protocol module; this test does not close them.
        protocol._set_out_file(
            open('generated/debug_protocol_out_file.py', 'wb'))
        protocol._set_in_file(
            open('generated/debug_protocol_out_file.py', 'rb'))

        models = _load_mnist(2)

        advisor = RetiariiAdvisor()
        cgo_engine = CGOExecutionEngine(
            devices=[
                GPUDevice("test", 0),
                GPUDevice("test", 1),
                GPUDevice("test", 2),
                GPUDevice("test", 3),
            ],
            batch_waiting_time=0)
        set_execution_engine(cgo_engine)
        submit_models(*models)
        # Give the engine time to act on the submission before reading the
        # protocol file.
        time.sleep(3)

        if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
            cmd, data = protocol.receive()
            params = nni.load(data)

            tt.init_params(params)

            trial_thread = threading.Thread(
                target=CGOExecutionEngine.trial_execute_graph)
            trial_thread.start()
            last_metric = None
            while True:
                time.sleep(1)
                if tt._last_metric:
                    metric = tt.get_last_metric()
                    # Forward only metrics that differ from the last one
                    # forwarded.  This must not `continue` past the liveness
                    # check below: otherwise an unchanged metric after the
                    # trial thread has finished keeps the loop spinning
                    # forever.
                    # NOTE(review): last_metric holds the dict after 'value'
                    # was re-serialized below; whether it can ever compare
                    # equal to a fresh get_last_metric() result depends on
                    # that helper -- confirm.
                    if metric != last_metric:
                        if 'value' in metric:
                            metric['value'] = json.dumps(metric['value'])
                        advisor.handle_report_metric_data(metric)
                        last_metric = metric
                if not trial_thread.is_alive():
                    break

            trial_thread.join()

        # Stop the advisor and wait for all background workers to finish.
        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()
        cgo_engine.join()
コード例 #4
0
ファイル: test_cgo_engine.py プロジェクト: zyz0577/nni
    def test_submit_models(self):
        """Submit two MNIST models with CGO enabled and relay trial metrics.

        The protocol's in/out files are redirected to a scratch file under
        ``generated/`` so that commands written by the advisor can be read
        back by this test.  When at least two CUDA devices are available, the
        trial is executed in a background thread (capped at 100 training
        steps) while this thread polls the test platform for new metrics and
        forwards them to the advisor.  Finally the advisor's workers are
        joined.
        """
        os.environ['CGO'] = 'true'
        os.makedirs('generated', exist_ok=True)
        from nni.runtime import protocol, platform
        import nni.runtime.platform.test as tt
        # Writer and reader share the same scratch file.  The handles are
        # stored on the protocol module; this test does not close them.
        protocol._out_file = open('generated/debug_protocol_out_file.py', 'wb')
        protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb')

        models = _load_mnist(2)
        advisor = RetiariiAdvisor()
        submit_models(*models)

        if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
            cmd, data = protocol.receive()
            params = json.loads(data)
            # Keep the trial short.
            params['parameters']['training_kwargs']['max_steps'] = 100

            tt.init_params(params)

            # Pass the callable itself.  Calling it here would run the whole
            # trial synchronously on this thread and hand its return value to
            # Thread as the target.
            trial_thread = threading.Thread(
                target=CGOExecutionEngine.trial_execute_graph)
            trial_thread.start()
            last_metric = None
            while True:
                time.sleep(1)
                if tt._last_metric:
                    metric = tt.get_last_metric()
                    # Forward only metrics that differ from the last one
                    # forwarded.  This must not `continue` past the liveness
                    # check below: otherwise an unchanged metric after the
                    # trial thread has finished keeps the loop spinning
                    # forever.
                    if metric != last_metric:
                        advisor.handle_report_metric_data(metric)
                        last_metric = metric
                if not trial_thread.is_alive():
                    break

            trial_thread.join()
        # Stop the advisor and wait for its background workers to finish.
        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()
コード例 #5
0
 def test_receive_zh(self):
     """Check that receive() decodes a UTF-8 payload of non-ASCII text.

     The prepared frame carries the two characters '世界'; the test expects
     an Initialize command whose data is exactly that string.
     """
     # NOTE(review): the 14-digit field after "IN" looks like the payload
     # length in bytes (6 for two CJK characters in UTF-8) -- confirm
     # against the protocol implementation.
     encoded_frame = 'IN00000000000006世界'.encode('utf8')
     _prepare_receive(encoded_frame)

     command, payload = receive()

     self.assertIs(command, CommandType.Initialize)
     self.assertEqual(payload, '世界')
コード例 #6
0
 def test_receive_en(self):
     """Check that receive() decodes a plain ASCII payload.

     The prepared frame carries the text 'hello'; the test expects an
     Initialize command whose data is exactly that string.
     """
     ascii_frame = b'IN00000000000005hello'
     _prepare_receive(ascii_frame)

     command, payload = receive()

     self.assertIs(command, CommandType.Initialize)
     self.assertEqual(payload, 'hello')