Ejemplo n.º 1
0
    def test_call_no_command(self):
        """Check the extension lifecycle when no command is ever queued.

        A leftover ``commands`` file is created in the result directory
        before the extension starts. The test then asserts that, after
        ``initialize``, the trainer's stop trigger is a
        ``_CommandIntervalTrigger``, the leftover file is gone and the job
        status is ``RUNNING``; after ``finalize`` the status must be
        ``STOPPED``.
        """
        result_dir = os.path.join(self._dir, 'results')
        os.makedirs(result_dir)

        # Simulate a stale commands file left in the result directory.
        stale_commands = os.path.join(result_dir, 'commands')
        with open(stale_commands, 'w'):
            pass
        assert os.path.isfile(stale_commands)

        extension = CommandsExtension()
        mock_trainer = _MockTrainer(result_dir)
        extension.initialize(mock_trainer)

        # Exact type match (not isinstance) is what this test requires.
        assert type(mock_trainer.stop_trigger) is _CommandIntervalTrigger
        assert not os.path.isfile(stale_commands)
        assert CommandsState.job_status(result_dir) == JobStatus.RUNNING

        extension.finalize()
        assert CommandsState.job_status(result_dir) == JobStatus.STOPPED
Ejemplo n.º 2
0
    def post(self, result_id, project_id):
        """POST /api/v1/results/<int:id>/commands.

        Append a new open command request to the command list stored under
        the result's ``path_name`` and return the refreshed command list.

        Args:
            result_id: Id used to look up the ``Result`` row.
            project_id: Accepted from the route; not used by this handler.

        Returns:
            A JSON response, paired with an explicit status code on errors:
            404 when no result matches ``result_id``; 400 when the training
            job is not running, when the request body is missing or is not
            a JSON object, when ``name`` is missing, or when ``schedule``
            is rejected by ``CommandItem.is_valid_schedule``. On success,
            a JSON object with the result's ``commands``.
        """

        result = db.session.query(Result).filter_by(id=result_id).first()

        if result is None:
            return jsonify({
                'result': None,
                'message': 'No interface defined for URL.'
            }), 404

        # Commands are only accepted while the job is running; every other
        # status maps to its own explanatory 400 message.
        job_status = CommandsState.job_status(result.path_name)
        if job_status != JobStatus.RUNNING:
            if job_status == JobStatus.NO_EXTENSION_ERROR:
                message = '\'CommandsExtension\' is not set or disabled.'
            elif job_status == JobStatus.INITIALIZED:
                message = 'The target training job has not run, yet'
            elif job_status == JobStatus.STOPPED:
                message = 'The target training job has already stopped'
            else:
                message = 'Cannot get the target training job status'
            return jsonify({'message': message}), 400

        request_json = request.get_json()
        if request_json is None:
            return jsonify({'message': 'Empty request.'}), 400
        if not isinstance(request_json, dict):
            # A JSON array or scalar parses fine but has no ``.get``;
            # reject it here instead of failing with an AttributeError.
            return jsonify(
                {'message': 'Request body must be a JSON object.'}), 400

        command_name = request_json.get('name', None)
        if command_name is None:
            return jsonify({'message': 'Name is required.'}), 400

        schedule = request_json.get('schedule', None)
        if not CommandItem.is_valid_schedule(schedule):
            return jsonify({'message': 'Schedule is invalid.'}), 400

        command = CommandItem(name=command_name)

        # Reuse the schedule that was just validated rather than reading
        # it from the request a second time.
        command.set_request(CommandItem.REQUEST_OPEN,
                            request_json.get('body', None),
                            schedule)

        commands = CommandItem.load_commands(result.path_name)
        commands.append(command)

        CommandItem.dump_commands(commands, result.path_name)

        # Re-crawl with force=True so the response reflects the new command.
        new_result = crawl_result(result, force=True)
        new_result_dict = new_result.serialize

        return jsonify({'commands': new_result_dict['commands']})
Ejemplo n.º 3
0
    def test_call(self):
        """Drive the extension through three queued commands.

        Three open requests are dumped to the result directory in this
        order: ``take_snapshot`` (no schedule), ``stop`` (scheduled at
        epoch 10) and ``adjust_hyperparams`` (scheduled at iteration 10).
        The extension is then invoked repeatedly while the mock updater's
        iteration and epoch are advanced, and the test asserts which
        command has received a response after each call.
        """
        result_dir = os.path.join(self.dir, 'results')
        os.makedirs(result_dir)
        commands_file = os.path.join(result_dir, 'commands')
        with open(commands_file, 'w'):
            pass
        assert os.path.isfile(commands_file)

        # Start the extension with a one-iteration trigger.
        extension = CommandsExtension(trigger=(1, 'iteration'))
        mock_trainer = _MockTrainer(result_dir)
        extension.initialize(mock_trainer)
        assert not mock_trainer.stop_trigger._loop_stop
        assert not os.path.isfile(commands_file)
        assert CommandsState.job_status(result_dir) == JobStatus.RUNNING

        # Queue the requests as (name, body, schedule); order matters
        # because the assertions below index into the reloaded list.
        request_specs = [
            ('take_snapshot', None, None),
            ('stop', None, {'key': 'epoch', 'value': 10}),
            (
                'adjust_hyperparams',
                {
                    'optimizer': 'MomentumSGD',
                    'hyperparam': {
                        'lr': 0.01,
                        'beta': None,
                        'gamma': 1.0
                    }
                },
                {'key': 'iteration', 'value': 10},
            ),
        ]
        queued = CommandItem.load_commands(result_dir)
        for name, body, schedule in request_specs:
            item = CommandItem(name=name)
            item.set_request(CommandItem.REQUEST_OPEN, body, schedule)
            queued.append(item)
        CommandItem.dump_commands(queued, result_dir)

        def invoke_and_reload():
            # Run the extension once and read back the stored commands.
            extension(mock_trainer)
            reloaded = CommandItem.load_commands(result_dir)
            assert len(reloaded) == 3
            return reloaded

        # First call: expected to be skipped by the interval trigger, so
        # no command has a response yet.
        stored = invoke_and_reload()
        assert stored[0].response is None
        assert stored[1].response is None
        assert stored[2].response is None

        # Iteration 1: only 'take_snapshot' gets a response.
        mock_trainer.updater.iteration = 1
        stored = invoke_and_reload()
        response = stored[0].response
        assert response['epoch'] == 0
        assert response['iteration'] == 1
        assert response['status'] == CommandItem.RESPONSE_SUCCESS
        assert stored[1].response is None
        assert stored[2].response is None

        # Iteration 10: 'adjust_hyperparams' gets a response; 'stop' is
        # still pending.
        mock_trainer.updater.iteration = 10
        stored = invoke_and_reload()
        response = stored[2].response
        assert response['epoch'] == 0
        assert response['iteration'] == 10
        assert response['status'] == CommandItem.RESPONSE_SUCCESS
        assert response['body'] is not None
        assert response['body']['optimizer'] == 'MomentumSGD'
        assert response['body']['hyperparam'] == {'lr': 0.01}
        assert stored[1].response is None

        # Epoch 10: 'stop' gets a response and the stop trigger is set.
        mock_trainer.updater.iteration = 100
        mock_trainer.updater.epoch = 10
        stored = invoke_and_reload()
        response = stored[1].response
        assert response['epoch'] == 10
        assert response['iteration'] == 100
        assert response['status'] == CommandItem.RESPONSE_SUCCESS
        assert response['body'] is None
        assert mock_trainer.stop_trigger._loop_stop

        extension.finalize()
        assert CommandsState.job_status(result_dir) == JobStatus.STOPPED