Exemple #1
0
    def test_assessor(self):
        """Replay assessor traffic through an assessor-only MsgDispatcher.

        Metrics for trials A and B are interleaved, both trials end, and a
        final ``NewTrialJob`` is sent that this dispatcher does not handle.
        Afterwards the recorded trial/end-trial bookkeeping is checked and
        the single ``KillTrialJob`` command the dispatcher emitted is read
        back from the outgoing stream.
        """
        _reverse_io()
        incoming = (
            (CommandType.ReportMetricData,
             '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":2}'),
            (CommandType.ReportMetricData,
             '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":2}'),
            (CommandType.ReportMetricData,
             '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":3}'),
            (CommandType.TrialEnd,
             '{"trial_job_id":"A","event":"SYS_CANCELED"}'),
            (CommandType.TrialEnd,
             '{"trial_job_id":"B","event":"SUCCEEDED"}'),
            (CommandType.NewTrialJob, 'null'),
        )
        for kind, payload in incoming:
            send(kind, payload)
        _restore_io()

        dispatcher = MsgDispatcher(None, NaiveAssessor())
        try:
            dispatcher.run()
        except Exception as err:
            # If run() propagates the rejected command, it must be exactly
            # this AssertionError; if run() returns normally, nothing to check.
            self.assertIs(type(err), AssertionError)
            self.assertEqual(err.args[0],
                             'Unsupported command: CommandType.NewTrialJob')

        # Every metric report is recorded in arrival order; the end-trial
        # flag is presumably "succeeded" (A was canceled, B succeeded).
        self.assertEqual(_trials, ['A', 'B', 'A'])
        self.assertEqual(_end_trials, [('A', False), ('B', True)])

        # The dispatcher should have emitted exactly one command: kill trial A.
        _reverse_io()
        kind, payload = receive()
        self.assertIs(kind, CommandType.KillTrialJob)
        self.assertEqual(payload, '"A"')
        leftover = _out_buf.read()
        self.assertEqual(len(leftover), 0)
    def test_msg_dispatcher(self):
        """Replay tuner commands through a tuner-only MsgDispatcher.

        The trailing ``KillTrialJob`` is expected to be rejected with an
        ``AssertionError`` that the dispatcher records in
        ``worker_exceptions``; the trial jobs the tuner emitted are then read
        back from the outgoing stream and checked in order.
        """
        _reverse_io()  # now we are sending to Tuner's incoming stream
        send(CommandType.RequestTrialJobs, '2')
        send(CommandType.ReportMetricData,
             '{"parameter_id":0,"type":"PERIODICAL","value":10}')
        send(CommandType.ReportMetricData,
             '{"parameter_id":1,"type":"FINAL","value":11}')
        send(CommandType.UpdateSearchSpace, '{"name":"SS0"}')
        send(CommandType.RequestTrialJobs, '1')
        send(CommandType.KillTrialJob, 'null')
        _restore_io()

        tuner = NaiveTuner()
        dispatcher = MsgDispatcher(tuner)

        # Override the module-level fast-exit flag only for this run and put
        # the previous value back afterwards, so the override does not leak
        # into other tests that share the module.
        # NOTE(review): presumably False makes the worker process the queued
        # commands before exiting -- confirm against msg_dispatcher_base.
        dispatcher_base = nni.msg_dispatcher_base
        saved_fast_exit = dispatcher_base._worker_fast_exit_on_terminate
        dispatcher_base._worker_fast_exit_on_terminate = False
        try:
            dispatcher.run()
        finally:
            dispatcher_base._worker_fast_exit_on_terminate = saved_fast_exit

        # Fail with a clear message (instead of an IndexError) if the
        # unsupported command was not recorded at all.
        self.assertTrue(dispatcher.worker_exceptions,
                        'no worker exception was recorded')
        e = dispatcher.worker_exceptions[0]
        self.assertIs(type(e), AssertionError)
        self.assertEqual(e.args[0],
                         'Unsupported command: CommandType.KillTrialJob')

        _reverse_io()  # now we are receiving from Tuner's outgoing stream
        # The two jobs produced by the first RequestTrialJobs('2').
        self._assert_params(0, 2, [], None)
        self._assert_params(1, 4, [], None)

        # The job produced by RequestTrialJobs('1'), sent after the FINAL
        # metric for parameter 1 and the search-space update.
        self._assert_params(2, 6, [[1, 4, 11, False]], {'name': 'SS0'})

        self.assertEqual(len(_out_buf.read()), 0)  # no more commands
Exemple #3
0
def main():
    '''
    Entry point: build the tuner (and the optional assessor) described by the
    command-line arguments, then run a MsgDispatcher with them.

    On a normal finish the ``_on_exit`` hook of the tuner (then the assessor,
    if any) is called.

    Raises:
        AssertionError: if the tuner, or a requested assessor, cannot be
            created.
        Exception: anything raised by the dispatcher run or by the exit hooks
            is logged, the ``_on_error`` hooks are called (tuner first), and
            the exception is re-raised.
    '''
    args = parse_args()

    if args.tuner_class_name not in ModuleName:
        tuner = create_customized_class_instance(args.tuner_directory,
                                                 args.tuner_class_filename,
                                                 args.tuner_class_name,
                                                 args.tuner_args)
    else:
        tuner = create_builtin_class_instance(args.tuner_class_name,
                                              args.tuner_args)
    if tuner is None:
        raise AssertionError('Failed to create Tuner instance')

    assessor = None
    if args.assessor_class_name:
        if args.assessor_class_name not in ModuleName:
            assessor = create_customized_class_instance(
                args.assessor_directory, args.assessor_class_filename,
                args.assessor_class_name, args.assessor_args)
        else:
            assessor = create_builtin_class_instance(args.assessor_class_name,
                                                     args.assessor_args)
        if assessor is None:
            raise AssertionError('Failed to create Assessor instance')

    # Objects whose lifecycle hooks must be invoked, tuner always first.
    participants = [tuner] if assessor is None else [tuner, assessor]

    dispatcher = MsgDispatcher(tuner, assessor)
    try:
        dispatcher.run()
        for participant in participants:
            participant._on_exit()
    except Exception as err:
        logger.exception(err)
        for participant in participants:
            participant._on_error()
        raise
Exemple #4
0
    def test_tuner(self):
        """Replay tuner commands, including a customized trial job.

        The trailing ``KillTrialJob`` is expected to be rejected with an
        ``AssertionError`` recorded in ``dispatcher.worker_exceptions``. The
        outgoing stream should then hold the two requested jobs, the
        customized job, and one more requested job whose check includes the
        FINAL results for both parameter 1 and the customized parameter 2.
        """
        _reverse_io()  # now we are sending to Tuner's incoming stream
        send(CommandType.RequestTrialJobs, '2')
        send(CommandType.ReportMetricData,
             '{"parameter_id":0,"type":"PERIODICAL","value":10}')
        send(CommandType.ReportMetricData,
             '{"parameter_id":1,"type":"FINAL","value":11}')
        send(CommandType.UpdateSearchSpace, '{"name":"SS0"}')
        send(CommandType.AddCustomizedTrialJob, '{"param":-1}')
        send(CommandType.ReportMetricData,
             '{"parameter_id":2,"type":"FINAL","value":22}')
        send(CommandType.RequestTrialJobs, '1')
        send(CommandType.KillTrialJob, 'null')
        _restore_io()

        tuner = NaiveTuner()
        dispatcher = MsgDispatcher(tuner)

        # Override the module-level fast-exit flag only for this run and put
        # the previous value back afterwards, so the override does not leak
        # into other tests that share the module.
        # NOTE(review): presumably False makes the worker process the queued
        # commands before exiting -- confirm against msg_dispatcher_base.
        dispatcher_base = nni.msg_dispatcher_base
        saved_fast_exit = dispatcher_base._worker_fast_exit_on_terminate
        dispatcher_base._worker_fast_exit_on_terminate = False
        try:
            dispatcher.run()
        finally:
            dispatcher_base._worker_fast_exit_on_terminate = saved_fast_exit

        # Fail with a clear message (instead of an IndexError) if the
        # unsupported command was not recorded at all.
        self.assertTrue(dispatcher.worker_exceptions,
                        'no worker exception was recorded')
        e = dispatcher.worker_exceptions[0]
        self.assertIs(type(e), AssertionError)
        self.assertEqual(e.args[0],
                         'Unsupported command: CommandType.KillTrialJob')

        _reverse_io()  # now we are receiving from Tuner's outgoing stream
        # The two jobs produced by the first RequestTrialJobs('2').
        self._assert_params(0, 2, [], None)
        self._assert_params(1, 4, [], None)

        command, data = receive()  # this one is customized
        data = json.loads(data)
        self.assertIs(command, CommandType.NewTrialJob)
        self.assertEqual(
            data, {
                'parameter_id': 2,
                'parameter_source': 'customized',
                'parameters': {
                    'param': -1
                }
            })

        # The job produced by RequestTrialJobs('1'), after FINAL metrics for
        # parameter 1 and the customized parameter 2 were delivered.
        self._assert_params(3, 6, [[1, 4, 11, False], [2, -1, 22, True]],
                            {'name': 'SS0'})

        self.assertEqual(len(_out_buf.read()), 0)  # no more commands
Exemple #5
0
def _run_advisor(args):
    '''
    Create the advisor named by ``args.advisor_class_name`` and run it; the
    advisor instance itself is used as the message dispatcher.

    Raises:
        AssertionError: if the advisor instance cannot be created.
        Exception: anything raised while running is logged and re-raised.
    '''
    if args.advisor_class_name in AdvisorModuleName:
        dispatcher = create_builtin_class_instance(args.advisor_class_name,
                                                   args.advisor_args, True)
    else:
        dispatcher = create_customized_class_instance(
            args.advisor_directory, args.advisor_class_filename,
            args.advisor_class_name, args.advisor_args)
    if dispatcher is None:
        raise AssertionError('Failed to create Advisor instance')
    try:
        dispatcher.run()
    except Exception as exception:
        logger.exception(exception)
        raise


def _run_tuner(args):
    '''
    Create the tuner (and the optional assessor) named in *args* and run a
    MsgDispatcher with them.

    On a normal finish the ``_on_exit`` hook of the tuner (then the assessor,
    if any) is called.

    Raises:
        AssertionError: if the tuner, or a requested assessor, cannot be
            created.
        Exception: anything raised by the dispatcher run or by the exit hooks
            is logged, the ``_on_error`` hooks are called (tuner first), and
            the exception is re-raised.
    '''
    if args.tuner_class_name in ModuleName:
        tuner = create_builtin_class_instance(args.tuner_class_name,
                                              args.tuner_args)
    else:
        tuner = create_customized_class_instance(args.tuner_directory,
                                                 args.tuner_class_filename,
                                                 args.tuner_class_name,
                                                 args.tuner_args)
    if tuner is None:
        raise AssertionError('Failed to create Tuner instance')

    assessor = None
    if args.assessor_class_name:
        if args.assessor_class_name in ModuleName:
            assessor = create_builtin_class_instance(
                args.assessor_class_name, args.assessor_args)
        else:
            assessor = create_customized_class_instance(
                args.assessor_directory, args.assessor_class_filename,
                args.assessor_class_name, args.assessor_args)
        if assessor is None:
            raise AssertionError('Failed to create Assessor instance')

    dispatcher = MsgDispatcher(tuner, assessor)

    try:
        dispatcher.run()
        tuner._on_exit()
        if assessor is not None:
            assessor._on_exit()
    except Exception as exception:
        logger.exception(exception)
        tuner._on_error()
        if assessor is not None:
            assessor._on_error()
        raise


def main():
    '''
    main function.

    Parses the command-line arguments, enables the multi-thread and
    multi-phase modes when requested, and then runs either the advisor (when
    ``advisor_class_name`` is given) or the tuner/assessor pair.
    '''
    args = parse_args()
    if args.multi_thread:
        enable_multi_thread()
    if args.multi_phase:
        enable_multi_phase()

    if args.advisor_class_name:
        # advisor is enabled and starts to run
        _run_advisor(args)
    else:
        # tuner (and assessor) is enabled and starts to run
        _run_tuner(args)