Code example #1
File: test_worker.py Project: wuqixiaobai/ATM
def test_is_datarun_finished(db, dataset, datarun):
    r1 = db.get_datarun(1)
    worker = Worker(db, r1)
    assert worker.is_datarun_finished()

    r2 = db.get_datarun(2)
    worker = Worker(db, r2)
    assert not worker.is_datarun_finished()

    deadline = (datetime.datetime.now() - datetime.timedelta(seconds=1)).strftime(TIME_FMT)
    worker = get_new_worker(deadline=deadline)
    assert worker.is_datarun_finished()
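For context, here is a minimal sketch of the deadline logic the final assertion exercises. The value of TIME_FMT is an assumption, and the real Worker.is_datarun_finished in ATM also considers the datarun's status and classifier budget:

import datetime

TIME_FMT = '%Y-%m-%d %H:%M'  # assumed value of the test suite's constant

def is_past_deadline(deadline_str, fmt=TIME_FMT):
    # True once the formatted deadline lies in the past,
    # as in the final assertion above
    deadline = datetime.datetime.strptime(deadline_str, fmt)
    return datetime.datetime.now() > deadline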
Code example #2
def get_new_worker(**kwargs):
    kwargs['methods'] = kwargs.get('methods', ['logreg', 'dt'])
    sql_conf = SQLConfig(database=DB_PATH)
    run_conf = RunConfig(**kwargs)
    run_id = enter_data(sql_conf, run_conf)
    db = Database(**vars(sql_conf))
    datarun = db.get_datarun(run_id)
    return Worker(db, datarun)
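A usage sketch for this helper, assuming RunConfig accepts a deadline keyword as example #1 suggests:

# Build a worker whose datarun is already past its deadline.
deadline = (datetime.datetime.now() - datetime.timedelta(seconds=1)).strftime(TIME_FMT)
worker = get_new_worker(deadline=deadline, methods=['knn'])
assert worker.is_datarun_finished()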
Code example #3
def test_save_classifier(db, datarun, model, metrics):
    worker = Worker(db, datarun, models_dir=MODEL_DIR, metrics_dir=METRIC_DIR)
    hp = db.get_hyperpartitions(datarun_id=worker.datarun.id)[0]
    classifier = worker.db.start_classifier(hyperpartition_id=hp.id,
                                            datarun_id=worker.datarun.id,
                                            host='localhost',
                                            hyperparameter_values=DT_PARAMS)

    worker.db.complete_classifier = Mock()
    worker.save_classifier(classifier.id, model, metrics)
    worker.db.complete_classifier.assert_called()

    with DBSession(worker.db):
        clf = db.get_classifier(classifier.id)

        loaded = load_model(clf, MODEL_DIR)
        assert isinstance(loaded, Model)
        assert loaded.method == model.method
        assert loaded.random_state == model.random_state

        assert load_metrics(clf, METRIC_DIR) == metrics
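The mocking pattern used above, shown in isolation with only the standard library: replace the collaborator method with a Mock, exercise the code under test, then assert the mock was invoked.

from unittest.mock import Mock

class FakeDB:
    def complete_classifier(self, classifier_id):
        pass  # stand-in for the real database write

db = FakeDB()
db.complete_classifier = Mock()   # patch the method on the instance
db.complete_classifier(42)        # in the test, worker.save_classifier(...) makes this call
db.complete_classifier.assert_called()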
Code example #4
File: test_worker.py Project: wuqixiaobai/ATM
def test_save_classifier(db, datarun, model, metrics):
    log_conf = LogConfig(model_dir=MODEL_DIR, metric_dir=METRIC_DIR)
    worker = Worker(db, datarun, log_config=log_conf)
    hp = db.get_hyperpartitions(datarun_id=worker.datarun.id)[0]
    classifier = worker.db.start_classifier(hyperpartition_id=hp.id,
                                            datarun_id=worker.datarun.id,
                                            host='localhost',
                                            hyperparameter_values=DT_PARAMS)

    worker.db.complete_classifier = Mock()
    worker.save_classifier(classifier.id, model, metrics)
    worker.db.complete_classifier.assert_called()

    with db_session(worker.db):
        clf = db.get_classifier(classifier.id)

        loaded = load_model(clf, MODEL_DIR)
        assert isinstance(loaded, Model)
        assert loaded.method == model.method
        assert loaded.random_state == model.random_state

        assert load_metrics(clf, METRIC_DIR) == metrics
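Examples #3 and #4 are the same test written against two versions of the Worker API: one passes models_dir and metrics_dir directly to the constructor, while the other bundles them into a LogConfig object; the session helper is also renamed from DBSession to db_session between the two.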
Code example #5
File: test_worker.py Project: zenghanfu/ATM
def test_is_datarun_finished(db, dataset, datarun):
    r1 = db.get_datarun(1)
    worker = Worker(db, r1)
    assert worker.is_datarun_finished()

    r2 = db.get_datarun(2)
    worker = Worker(db, r2)
    assert not worker.is_datarun_finished()

    deadline = (datetime.datetime.now() - datetime.timedelta(seconds=1)).strftime(TIME_FMT)
    worker = get_new_worker(deadline=deadline)
    assert worker.is_datarun_finished()
Code example #6
def get_new_worker(**kwargs):
    kwargs['dataset_id'] = kwargs.get('dataset_id', None)
    kwargs['methods'] = kwargs.get('methods', ['logreg', 'dt'])
    sql_conf = SQLConfig({'sql_database': DB_PATH})
    run_conf = RunConfig(kwargs)

    dataset_conf = DatasetConfig(kwargs)

    db = Database(**sql_conf.to_dict())
    atm = ATM(sql_conf, None, None)

    # enter_data here returns a datarun object, not a bare id
    run = atm.enter_data(dataset_conf, run_conf)
    datarun = db.get_datarun(run.id)

    return Worker(db, datarun)
Code example #7
def get_new_worker(**kwargs):
    kwargs['dataset_id'] = kwargs.get('dataset_id', None)
    kwargs['methods'] = kwargs.get('methods', ['logreg', 'dt'])
    run_conf = RunConfig(kwargs)

    kwargs['train_path'] = POLLUTION_PATH
    dataset_conf = DatasetConfig(kwargs)

    db = Database(dialect='sqlite', database=DB_PATH)
    atm = ATM(dialect='sqlite', database=DB_PATH)

    dataset = atm.add_dataset(**dataset_conf.to_dict())
    run_conf.dataset_id = dataset.id
    datarun = atm.add_datarun(**run_conf.to_dict())

    return Worker(db, datarun)
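Taken together, examples #2, #6, and #7 trace the evolution of the setup API: enter_data as a free function, then an ATM.enter_data method driven by config objects, then explicit add_dataset/add_datarun calls. A usage sketch for the newest variant, with the method list as an assumption:

# Create a fresh datarun and train a single classifier on it.
worker = get_new_worker(methods=['knn'])
worker.run_classifier()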
Code example #8
File: core.py Project: singh8477/ATM
    def work(self, datarun_ids=None, save_files=True, choose_randomly=True,
             cloud_mode=False, total_time=None, wait=True, verbose=False):
        """
        Check the ModelHub database for unfinished dataruns, and spawn workers to
        work on them as they are added. This process will continue to run until it
        exceeds total_time or is broken with ctrl-C.

        datarun_ids (optional): list of IDs of dataruns to compute on. If None,
            this will work on all unfinished dataruns in the database.
        choose_randomly: if True, work on all highest-priority dataruns in random
            order. If False, work on them in sequential order (by ID)
        cloud_mode: if True, save processed datasets to AWS.
        total_time (optional): if set to an integer, this worker will only work for
            total_time seconds. Otherwise, it will continue working until all
            dataruns are complete (or indefinitely).
        wait: if True, once all dataruns in the database are complete, keep spinning
            and wait for new runs to be added. If False, exit once all dataruns are
            complete.
        """
        start_time = datetime.now()

        # main loop
        while True:
            # get all pending and running dataruns, restricted to the given
            # list of IDs if one was supplied
            dataruns = self.db.get_dataruns(include_ids=datarun_ids, ignore_complete=True)
            if not dataruns:
                if wait:
                    LOGGER.debug('No dataruns found. Sleeping %d seconds and trying again.',
                                 ATM.LOOP_WAIT)
                    time.sleep(ATM.LOOP_WAIT)
                    continue

                else:
                    LOGGER.info('No dataruns found. Exiting.')
                    break

            # either choose a run at random, or take the run with the lowest ID
            if choose_randomly:
                run = random.choice(dataruns)
            else:
                run = sorted(dataruns, key=attrgetter('id'))[0]

            # say we've started working on this datarun, if we haven't already
            self.db.mark_datarun_running(run.id)

            LOGGER.info('Computing on datarun %d', run.id)
            # actual work happens here
            worker = Worker(self.db, run, save_files=save_files,
                            cloud_mode=cloud_mode, aws_access_key=self.aws_access_key,
                            aws_secret_key=self.aws_secret_key, s3_bucket=self.s3_bucket,
                            s3_folder=self.s3_folder, models_dir=self.models_dir,
                            metrics_dir=self.metrics_dir, verbose_metrics=self.verbose_metrics)
            try:
                if run.budget_type == 'classifier':
                    pbar = tqdm(
                        total=run.budget,
                        ascii=True,
                        initial=run.completed_classifiers,
                        disable=not verbose
                    )

                    while run.status != RunStatus.COMPLETE:
                        worker.run_classifier()
                        run = self.db.get_datarun(run.id)
                        if verbose and run.completed_classifiers > pbar.last_print_n:
                            pbar.update(run.completed_classifiers - pbar.last_print_n)

                    pbar.close()

                elif run.budget_type == 'walltime':
                    pbar = tqdm(
                        disable=not verbose,
                        ascii=True,
                        initial=run.completed_classifiers,
                        unit=' Classifiers'
                    )

                    while run.status != RunStatus.COMPLETE:
                        worker.run_classifier()
                        run = self.db.get_datarun(run.id)  # Refresh the datarun object.
                        if verbose and run.completed_classifiers > pbar.last_print_n:
                            pbar.update(run.completed_classifiers - pbar.last_print_n)

                    pbar.close()

            except ClassifierError:
                # the exception has already been handled; just wait a sec so we
                # don't go out of control reporting errors
                LOGGER.error('Something went wrong. Sleeping %d seconds.', ATM.LOOP_WAIT)
                time.sleep(ATM.LOOP_WAIT)

            elapsed_time = (datetime.now() - start_time).total_seconds()
            if total_time is not None and elapsed_time >= total_time:
                LOGGER.info('Total run time for worker exceeded; exiting.')
                break
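A hedged usage sketch for this method; the constructor arguments follow example #7 and the database path is hypothetical:

# Work on datarun 1 only, with a progress bar, for at most an hour.
atm = ATM(dialect='sqlite', database='atm.db')
atm.work(datarun_ids=[1], total_time=3600, wait=False, verbose=True)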
Code example #9
    def work(self,
             datarun_ids=None,
             save_files=False,
             choose_randomly=True,
             cloud_mode=False,
             total_time=None,
             wait=True):
        """
        Check the ModelHub database for unfinished dataruns, and spawn workers to
        work on them as they are added. This process will continue to run until it
        exceeds total_time or is broken with ctrl-C.

        datarun_ids (optional): list of IDs of dataruns to compute on. If None,
            this will work on all unfinished dataruns in the database.
        choose_randomly: if True, work on all highest-priority dataruns in random
            order. If False, work on them in sequential order (by ID)
        cloud_mode: if True, save processed datasets to AWS. If this option is set,
            aws_config must be supplied.
        total_time (optional): if set to an integer, this worker will only work for
            total_time seconds. Otherwise, it will continue working until all
            dataruns are complete (or indefinitely).
        wait: if True, once all dataruns in the database are complete, keep spinning
            and wait for new runs to be added. If False, exit once all dataruns are
            complete.
        """
        start_time = datetime.now()
        public_ip = get_public_ip()

        # main loop
        while True:
            # get all pending and running dataruns, or all pending/running dataruns
            # from the list we were given
            dataruns = self.db.get_dataruns(include_ids=datarun_ids,
                                            ignore_complete=True)
            if not dataruns:
                if wait:
                    logger.warning(
                        'No dataruns found. Sleeping %d seconds and trying again.',
                        ATM.LOOP_WAIT)
                    time.sleep(ATM.LOOP_WAIT)
                    continue

                else:
                    logger.warning('No dataruns found. Exiting.')
                    break

            max_priority = max(datarun.priority for datarun in dataruns)
            priority_runs = [r for r in dataruns if r.priority == max_priority]

            # either choose one of the highest-priority runs at random,
            # or take the one with the lowest ID
            if choose_randomly:
                run = random.choice(priority_runs)
            else:
                run = sorted(priority_runs, key=attrgetter('id'))[0]

            # say we've started working on this datarun, if we haven't already
            self.db.mark_datarun_running(run.id)

            logger.info('Computing on datarun %d', run.id)
            # actual work happens here
            worker = Worker(self.db,
                            run,
                            save_files=save_files,
                            cloud_mode=cloud_mode,
                            aws_config=self.aws_conf,
                            log_config=self.log_conf,
                            public_ip=public_ip)
            try:
                worker.run_classifier()

            except ClassifierError:
                # the exception has already been handled; just wait a sec so we
                # don't go out of control reporting errors
                logger.warning('Something went wrong. Sleeping %d seconds.',
                               ATM.LOOP_WAIT)
                time.sleep(ATM.LOOP_WAIT)

            elapsed_time = (datetime.now() - start_time).total_seconds()
            if total_time is not None and elapsed_time >= total_time:
                logger.warning('Total run time for worker exceeded; exiting.')
                break
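Compared with example #8, this version restricts the choice to the highest-priority dataruns, runs exactly one classifier per loop iteration instead of looping until the run completes, and records the worker's public IP; it also omits the progress bar.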
Code example #10
def worker(db, datarun):
    return Worker(db, datarun)
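Given its (db, datarun) parameters, this helper reads like a pytest fixture; declared as one, it would look like this sketch:

import pytest

@pytest.fixture
def worker(db, datarun):
    # relies on db and datarun fixtures defined elsewhere in the suite
    return Worker(db, datarun)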
Code example #11
    def work(self, datarun_ids=None, save_files=True, choose_randomly=True,
             cloud_mode=False, total_time=None, wait=True, verbose=False):
        """Get unfinished Dataruns from the database and work on them.

        Check the ModelHub Database for unfinished Dataruns, and work on them
        as they are added. This process will continue to run until it exceeds
        total_time or there are no more Dataruns to process or it is killed.

        Args:
            datarun_ids (list):
                list of IDs of Dataruns to work on. If ``None``, this will work on any
                unfinished Dataruns found in the database. Optional. Defaults to ``None``.
            save_files (bool):
                Whether to save the fitted classifiers and their metrics or not.
                Optional. Defaults to True.
            choose_randomly (bool):
                If ``True``, work on all the highest-priority dataruns in random order.
                Otherwise, work on them in sequential order (by ID).
                Optional. Defaults to ``True``.
            cloud_mode (bool):
                Save the models and metrics in AWS S3 instead of locally. This option
                works only if S3 configuration has been provided on initialization.
                Optional. Defaults to ``False``.
            total_time (int):
                Total time to run the work process, in seconds. If ``None``, continue to
                run until interrupted or there are no more Dataruns to process.
                Optional. Defaults to ``None``.
            wait (bool):
                If ``True``, wait for more Dataruns to be inserted into the Database
                once all have been processed. Otherwise, exit the worker loop
                when they run out.
                Optional. Defaults to ``True``.
            verbose (bool):
                Whether to be verbose about the process. Optional. Defaults to ``False``.
        """
        start_time = datetime.now()

        # main loop
        while True:
            # get all pending and running dataruns, restricted to the given
            # list of IDs if one was supplied
            dataruns = self.db.get_dataruns(include_ids=datarun_ids, ignore_complete=True)
            if not dataruns:
                if wait:
                    LOGGER.debug('No dataruns found. Sleeping %d seconds and trying again.',
                                 self._LOOP_WAIT)
                    time.sleep(self._LOOP_WAIT)
                    continue

                else:
                    LOGGER.info('No dataruns found. Exiting.')
                    break

            # either choose a run at random, or take the run with the lowest ID
            if choose_randomly:
                run = random.choice(dataruns)
            else:
                run = sorted(dataruns, key=attrgetter('id'))[0]

            # say we've started working on this datarun, if we haven't already
            self.db.mark_datarun_running(run.id)

            LOGGER.info('Computing on datarun %d', run.id)
            # actual work happens here
            worker = Worker(self.db, run, save_files=save_files,
                            cloud_mode=cloud_mode, aws_access_key=self.aws_access_key,
                            aws_secret_key=self.aws_secret_key, s3_bucket=self.s3_bucket,
                            s3_folder=self.s3_folder, models_dir=self.models_dir,
                            metrics_dir=self.metrics_dir, verbose_metrics=self.verbose_metrics)

            try:
                if run.budget_type == 'classifier':
                    pbar = tqdm(
                        total=run.budget,
                        ascii=True,
                        initial=run.completed_classifiers,
                        disable=not verbose
                    )

                    while run.status != RunStatus.COMPLETE:
                        worker.run_classifier()
                        run = self.db.get_datarun(run.id)
                        if verbose and run.completed_classifiers > pbar.last_print_n:
                            pbar.update(run.completed_classifiers - pbar.last_print_n)

                    pbar.close()

                elif run.budget_type == 'walltime':
                    pbar = tqdm(
                        disable=not verbose,
                        ascii=True,
                        initial=run.completed_classifiers,
                        unit=' Classifiers'
                    )

                    while run.status != RunStatus.COMPLETE:
                        worker.run_classifier()
                        run = self.db.get_datarun(run.id)  # Refresh the datarun object.
                        if verbose and run.completed_classifiers > pbar.last_print_n:
                            pbar.update(run.completed_classifiers - pbar.last_print_n)

                    pbar.close()

            except ClassifierError:
                # the exception has already been handled; just wait a sec so we
                # don't go out of control reporting errors
                LOGGER.error('Something went wrong. Sleeping %d seconds.', self._LOOP_WAIT)
                time.sleep(self._LOOP_WAIT)

            elapsed_time = (datetime.now() - start_time).total_seconds()
            if total_time is not None and elapsed_time >= total_time:
                LOGGER.info('Total run time for worker exceeded; exiting.')
                break
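The progress-bar bookkeeping above, in isolation: tqdm's last_print_n records the count at the last screen refresh, so the bar is advanced by the number of classifiers completed since then. A minimal self-contained sketch:

from tqdm import tqdm

completed = 0                     # stands in for run.completed_classifiers
pbar = tqdm(total=10, ascii=True, initial=completed)
for _ in range(10):
    completed += 1                # one more classifier finished
    if completed > pbar.last_print_n:
        pbar.update(completed - pbar.last_print_n)
pbar.close()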
Code example #12
def get_datarun_steps_info(datarun_id,
                           classifier_start=None,
                           classifier_end=None,
                           nice=False):
    """
    Get the scores of the hyperpartitions/methods at each step.
    :param datarun_id: the id of the datarun
    :param classifier_start: only return scores from the `classifier_start`-th classifier onwards
    :param classifier_end: only return scores from before the `classifier_end`-th classifier.
        Note that classifier_start and classifier_end are not ids; they start from 1.
        (This is because the caller may not know the classifier ids of the datarun.)
    :param nice: if True, return the results in a nicer, per-method format
    :return:
        if nice is False,
        [
            {"1": 0.2, "2": 0.3, ...},
            ...
        ]
        if nice is True,
        [
            {
                "knn": [0.2, 0.3],
                "logreg": [0.1],
                ...
            },
            ...
        ]
    """
    if classifier_start is None:
        classifier_start = -np.inf
    if classifier_end is None:
        classifier_end = np.inf
    db = get_db()

    datarun = db.get_datarun(datarun_id=datarun_id)
    hyperpartitions = db.get_hyperpartitions(datarun_id=datarun_id)

    # load classifiers and build scores lists
    # make sure all hyperpartitions are present in the dict, even ones that
    # don't have any classifiers. That way the selector can choose hyperpartitions
    # that haven't been scored yet.
    hyperpartition_scores = {fs.id: [] for fs in hyperpartitions}
    classifiers = db.get_classifiers(datarun_id=datarun_id,
                                     status=ClassifierStatus.COMPLETE)
    selected_classifiers = [
        c for c in classifiers if c.hyperpartition_id in hyperpartition_scores
    ]
    # Create a temporary worker
    worker = Worker(db, datarun, public_ip=get_public_ip())
    bandit_scores_of_steps = []
    for i, c in enumerate(selected_classifiers):
        if i >= classifier_end:
            break
        # the cast to float is necessary because the score is a Decimal;
        # doing Decimal-float arithmetic throws errors later on.
        score = float(getattr(c, datarun.score_target) or 0)
        hyperpartition_scores[c.hyperpartition_id].append(score)
        bandit_scores = selector_bandit_scores(worker.selector,
                                               hyperpartition_scores)
        bandit_scores = {
            key: float("%.5f" % val)
            for key, val in bandit_scores.items()
        }
        if i < classifier_start:
            continue
        bandit_scores_of_steps.append(bandit_scores)
    # For a nicer formatted output
    if nice:
        results = []
        hp_id2method = {fs.id: fs.method for fs in hyperpartitions}
        for bandit_scores in bandit_scores_of_steps:
            res = defaultdict(list)
            for hp_id, score in bandit_scores.items():
                res[hp_id2method[hp_id]].append(score)
            results.append(res)
        return results

    return bandit_scores_of_steps
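A usage sketch; the datarun id and step window are hypothetical:

# Bandit scores per method for a window of steps of datarun 1.
steps = get_datarun_steps_info(1, classifier_start=10, classifier_end=20, nice=True)
for step_scores in steps:
    print(step_scores)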