def test_assessor(self):
    """Run an assessor-only MsgDispatcher over a scripted command stream.

    The stream reports periodical metrics for trials "A" and "B", ends both
    trials (A with SYS_CANCELED, B with SUCCEEDED) and finishes with a
    NewTrialJob command, which this dispatcher is expected to reject with an
    AssertionError. The test then checks the trial activity recorded in the
    module-level ``_trials`` / ``_end_trials`` lists (presumably filled in by
    NaiveAssessor) and that the only command emitted is a KillTrialJob for
    trial "A".
    """
    _reverse_io()  # now we are sending to the dispatcher's incoming stream
    send(
        CommandType.ReportMetricData,
        '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":2}')
    send(
        CommandType.ReportMetricData,
        '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":2}')
    send(
        CommandType.ReportMetricData,
        '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":3}')
    send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED"}')
    send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED"}')
    send(CommandType.NewTrialJob, 'null')
    _restore_io()

    assessor = NaiveAssessor()
    dispatcher = MsgDispatcher(None, assessor)
    # The trailing NewTrialJob must make run() raise. assertRaises (rather than
    # a bare try/except) makes the test fail if run() returns normally, instead
    # of silently skipping the error checks below.
    with self.assertRaises(AssertionError) as raised:
        dispatcher.run()
    e = raised.exception
    self.assertIs(type(e), AssertionError)  # exact type, not a subclass
    self.assertEqual(e.args[0], 'Unsupported command: CommandType.NewTrialJob')

    self.assertEqual(_trials, ['A', 'B', 'A'])
    self.assertEqual(_end_trials, [('A', False), ('B', True)])

    _reverse_io()  # now we are receiving from the dispatcher's outgoing stream
    command, data = receive()
    self.assertIs(command, CommandType.KillTrialJob)
    self.assertEqual(data, '"A"')
    self.assertEqual(len(_out_buf.read()), 0)  # no more commands
def test_msg_dispatcher(self):
    """Run a tuner-only MsgDispatcher over a scripted command stream.

    The stream requests trial jobs, reports one periodical and one final
    metric, updates the search space, requests another job and ends with a
    KillTrialJob command, which this dispatcher is expected to reject. The
    rejection is read back from ``dispatcher.worker_exceptions``, and the
    NewTrialJob commands written to the outgoing stream are checked via
    ``self._assert_params``.
    """
    _reverse_io()  # now we are sending to Tuner's incoming stream
    send(CommandType.RequestTrialJobs, '2')
    send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10}')
    send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11}')
    send(CommandType.UpdateSearchSpace, '{"name":"SS0"}')
    send(CommandType.RequestTrialJobs, '1')
    send(CommandType.KillTrialJob, 'null')
    _restore_io()

    tuner = NaiveTuner()
    dispatcher = MsgDispatcher(tuner)
    # Clear the module-level fast-exit flag for this run (the assertions below
    # rely on it being False), and restore the previous value when the test
    # finishes so the change does not leak into other tests.
    dispatcher_module = nni.msg_dispatcher_base
    self.addCleanup(setattr, dispatcher_module, '_worker_fast_exit_on_terminate',
                    dispatcher_module._worker_fast_exit_on_terminate)
    dispatcher_module._worker_fast_exit_on_terminate = False

    dispatcher.run()
    # Report a clear failure (rather than an IndexError) if nothing was recorded.
    self.assertTrue(dispatcher.worker_exceptions,
                    'expected the trailing KillTrialJob to be recorded in worker_exceptions')
    e = dispatcher.worker_exceptions[0]
    self.assertIs(type(e), AssertionError)
    self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')

    _reverse_io()  # now we are receiving from Tuner's outgoing stream
    self._assert_params(0, 2, [], None)
    self._assert_params(1, 4, [], None)
    self._assert_params(2, 6, [[1, 4, 11, False]], {'name': 'SS0'})
    self.assertEqual(len(_out_buf.read()), 0)  # no more commands
def main():
    '''
    Entry point: build the Tuner (and the optional Assessor) described by the
    command-line arguments and run them through a MsgDispatcher.

    On a clean run every created component gets its ``_on_exit`` hook called;
    if anything raises, the exception is logged, every component gets its
    ``_on_error`` hook called, and the exception is re-raised.

    Raises:
        AssertionError: if the tuner or the requested assessor cannot be created.
    '''
    cli_args = parse_args()

    # Builtin classes are looked up by name; anything else is loaded from the
    # user-supplied directory / file.
    if cli_args.tuner_class_name not in ModuleName:
        tuner = create_customized_class_instance(cli_args.tuner_directory,
                                                 cli_args.tuner_class_filename,
                                                 cli_args.tuner_class_name,
                                                 cli_args.tuner_args)
    else:
        tuner = create_builtin_class_instance(cli_args.tuner_class_name, cli_args.tuner_args)
    if tuner is None:
        raise AssertionError('Failed to create Tuner instance')

    assessor = None
    if cli_args.assessor_class_name:
        if cli_args.assessor_class_name not in ModuleName:
            assessor = create_customized_class_instance(
                cli_args.assessor_directory,
                cli_args.assessor_class_filename,
                cli_args.assessor_class_name,
                cli_args.assessor_args)
        else:
            assessor = create_builtin_class_instance(cli_args.assessor_class_name,
                                                     cli_args.assessor_args)
        if assessor is None:
            raise AssertionError('Failed to create Assessor instance')

    # Tuner first, then the assessor if one was configured.
    components = [component for component in (tuner, assessor) if component is not None]
    dispatcher = MsgDispatcher(tuner, assessor)
    try:
        dispatcher.run()
        for component in components:
            component._on_exit()
    except Exception as exception:
        logger.exception(exception)
        for component in components:
            component._on_error()
        raise
def test_tuner(self):
    """Run a tuner-only MsgDispatcher over a scripted stream that includes a customized trial.

    Besides the regular job requests, metrics and search-space update, the
    stream adds a customized trial job ({"param": -1}) and reports a final
    metric for it, then ends with a KillTrialJob command that this dispatcher
    is expected to reject. The rejection is read back from
    ``dispatcher.worker_exceptions``. The outgoing stream is expected to
    contain the regular NewTrialJob commands plus one NewTrialJob whose
    ``parameter_source`` is 'customized'.
    """
    _reverse_io()  # now we are sending to Tuner's incoming stream
    send(CommandType.RequestTrialJobs, '2')
    send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10}')
    send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11}')
    send(CommandType.UpdateSearchSpace, '{"name":"SS0"}')
    send(CommandType.AddCustomizedTrialJob, '{"param":-1}')
    send(CommandType.ReportMetricData, '{"parameter_id":2,"type":"FINAL","value":22}')
    send(CommandType.RequestTrialJobs, '1')
    send(CommandType.KillTrialJob, 'null')
    _restore_io()

    tuner = NaiveTuner()
    dispatcher = MsgDispatcher(tuner)
    # Clear the module-level fast-exit flag for this run (the assertions below
    # rely on it being False), and restore the previous value when the test
    # finishes so the change does not leak into other tests.
    dispatcher_module = nni.msg_dispatcher_base
    self.addCleanup(setattr, dispatcher_module, '_worker_fast_exit_on_terminate',
                    dispatcher_module._worker_fast_exit_on_terminate)
    dispatcher_module._worker_fast_exit_on_terminate = False

    dispatcher.run()
    # Report a clear failure (rather than an IndexError) if nothing was recorded.
    self.assertTrue(dispatcher.worker_exceptions,
                    'expected the trailing KillTrialJob to be recorded in worker_exceptions')
    e = dispatcher.worker_exceptions[0]
    self.assertIs(type(e), AssertionError)
    self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')

    _reverse_io()  # now we are receiving from Tuner's outgoing stream
    self._assert_params(0, 2, [], None)
    self._assert_params(1, 4, [], None)
    command, data = receive()  # this one is customized
    data = json.loads(data)
    self.assertIs(command, CommandType.NewTrialJob)
    self.assertEqual(
        data,
        {
            'parameter_id': 2,
            'parameter_source': 'customized',
            'parameters': {
                'param': -1
            }
        })
    self._assert_params(3, 6, [[1, 4, 11, False], [2, -1, 22, True]], {'name': 'SS0'})
    self.assertEqual(len(_out_buf.read()), 0)  # no more commands
def main():
    '''
    Entry point of the dispatcher process.

    Parses the command-line arguments and applies the optional multi-thread /
    multi-phase switches. If an advisor class is named, the advisor instance
    is run directly as the dispatcher. Otherwise a Tuner (plus an Assessor,
    when one is requested) is wrapped in a MsgDispatcher and run. In that case
    the components' ``_on_exit`` hooks are called on success and their
    ``_on_error`` hooks on failure. Any exception raised while running is
    logged and re-raised.

    Raises:
        AssertionError: if the advisor, tuner or requested assessor cannot be
            created.
    '''
    cli_args = parse_args()

    if cli_args.multi_thread:
        enable_multi_thread()
    if cli_args.multi_phase:
        enable_multi_phase()

    if cli_args.advisor_class_name:
        # Advisor mode: the created advisor instance is run as the dispatcher.
        if cli_args.advisor_class_name not in AdvisorModuleName:
            dispatcher = create_customized_class_instance(
                cli_args.advisor_directory,
                cli_args.advisor_class_filename,
                cli_args.advisor_class_name,
                cli_args.advisor_args)
        else:
            dispatcher = create_builtin_class_instance(
                cli_args.advisor_class_name, cli_args.advisor_args, True)
        if dispatcher is None:
            raise AssertionError('Failed to create Advisor instance')
        try:
            dispatcher.run()
        except Exception as exception:
            logger.exception(exception)
            raise
        return

    # Tuner (and optional assessor) mode.
    if cli_args.tuner_class_name not in ModuleName:
        tuner = create_customized_class_instance(cli_args.tuner_directory,
                                                 cli_args.tuner_class_filename,
                                                 cli_args.tuner_class_name,
                                                 cli_args.tuner_args)
    else:
        tuner = create_builtin_class_instance(cli_args.tuner_class_name, cli_args.tuner_args)
    if tuner is None:
        raise AssertionError('Failed to create Tuner instance')

    assessor = None
    if cli_args.assessor_class_name:
        if cli_args.assessor_class_name not in ModuleName:
            assessor = create_customized_class_instance(
                cli_args.assessor_directory,
                cli_args.assessor_class_filename,
                cli_args.assessor_class_name,
                cli_args.assessor_args)
        else:
            assessor = create_builtin_class_instance(
                cli_args.assessor_class_name, cli_args.assessor_args)
        if assessor is None:
            raise AssertionError('Failed to create Assessor instance')

    # Tuner first, then the assessor if one was configured.
    components = [component for component in (tuner, assessor) if component is not None]
    dispatcher = MsgDispatcher(tuner, assessor)
    try:
        dispatcher.run()
        for component in components:
            component._on_exit()
    except Exception as exception:
        logger.exception(exception)
        for component in components:
            component._on_error()
        raise