Пример #1
0
def _get_atm(args):
    """Build an ``ATM`` instance configured from ``args``.

    ``args`` is handed unchanged to ``SQLConfig``, ``AWSConfig`` and
    ``LogConfig``. The ``to_dict()`` output of each config is then
    unpacked as keyword arguments of the ``ATM`` call.

    Returns:
        The object produced by calling ``ATM`` with the combined keywords.
    """
    configs = [SQLConfig(args), AWSConfig(args), LogConfig(args)]
    sql_kwargs, aws_kwargs, log_kwargs = [conf.to_dict() for conf in configs]

    # The three dicts are unpacked separately (not pre-merged) so that a key
    # appearing in more than one of them is still rejected by the call itself.
    return ATM(**sql_kwargs, **aws_kwargs, **log_kwargs)
Пример #2
0
def test_run_per_partition(dataset):
    """Check the dataruns created when ``run_per_partition`` is enabled.

    Data is entered for the ``logreg`` method only. The test then asserts
    that the number of dataruns found in the database equals
    ``METHOD_HYPERPARTS['logreg']`` and that each of them holds exactly one
    hyperpartition.
    """
    sql_conf = SQLConfig({'sql_database': DB_PATH})
    db = Database(**sql_conf.to_dict())

    run_settings = {
        'dataset_id': dataset.id,
        'methods': ['logreg'],
        'run_per_partition': True,
    }
    run_conf = RunConfig(run_settings)

    atm = ATM(sql_conf, None, None)
    entered_runs = atm.enter_data(None, run_conf)

    with db_session(db):
        # Look every entered run up again and keep only those that exist.
        fetched = (db.get_datarun(entry.id) for entry in entered_runs)
        runs = [found for found in fetched if found is not None]

        assert len(runs) == METHOD_HYPERPARTS['logreg']
        assert all([len(found.hyperpartitions) == 1 for found in runs])
Пример #3
0
Файл: cli.py Проект: zwcdp/ATM
def _get_atm(args):
    """Build an ``ATM`` instance configured from ``args``.

    ``args`` is handed unchanged to ``SQLConfig``, ``AWSConfig`` and
    ``LogConfig``. The AWS and log ``to_dict()`` outputs are merged, in that
    order, into the SQL one, so a later config overwrites an earlier one on
    a shared key. The merged dict is unpacked as keyword arguments of ``ATM``.

    Returns:
        The object produced by calling ``ATM`` with the merged keywords.
    """
    sql_conf, aws_conf, log_conf = SQLConfig(args), AWSConfig(args), LogConfig(args)

    # A single merged dict is built because Python 2.7 does not accept
    # several ``**`` unpackings in one call.
    atm_kwargs = sql_conf.to_dict()
    for extra_conf in (aws_conf, log_conf):
        atm_kwargs.update(extra_conf.to_dict())

    return ATM(**atm_kwargs)
Пример #4
0
def test_enter_data_all(dataset):
    """Check that entering data with every known method creates all hyperpartitions.

    A run is configured with every key of ``METHOD_HYPERPARTS`` as its
    methods. The resulting datarun must reference ``dataset`` and hold as
    many hyperpartitions as the sum of the ``METHOD_HYPERPARTS`` values.
    """
    sql_conf = SQLConfig({'sql_database': DB_PATH})
    db = Database(**sql_conf.to_dict())

    # On Python 3 ``dict.keys()`` is a view object, not a list; materialize
    # it so ``methods`` is a plain list of method names.
    run_conf = RunConfig({
        'dataset_id': dataset.id,
        'methods': list(METHOD_HYPERPARTS),
    })

    atm = ATM(sql_conf, None, None)

    run_id = atm.enter_data(None, run_conf)

    with db_session(db):
        run = db.get_datarun(run_id.id)
        assert run.dataset.id == dataset.id
        assert len(run.hyperpartitions) == sum(METHOD_HYPERPARTS.values())
Пример #5
0
def test_enter_data_by_methods(dataset):
    """Check the hyperpartition count produced by each method on its own.

    For every ``(method, count)`` pair in ``METHOD_HYPERPARTS`` the run
    config's ``methods`` is set to that single method and data is entered.
    The resulting datarun must reference ``dataset`` and hold ``count``
    hyperpartitions.
    """
    sql_conf = SQLConfig({'sql_database': DB_PATH})
    db = Database(**sql_conf.to_dict())
    run_conf = RunConfig({'dataset_id': dataset.id})
    atm = ATM(sql_conf, None, None)

    for method_name, expected_parts in METHOD_HYPERPARTS.items():
        # The same config object is reused; only its method list changes.
        run_conf.methods = [method_name]
        entered = atm.enter_data(None, run_conf)

        with db_session(db):
            datarun = db.get_datarun(entered.id)
            assert datarun.dataset.id == dataset.id
            assert len(datarun.hyperpartitions) == expected_parts
Пример #6
0
def get_new_worker(**kwargs):
    """Create a ``Worker`` for a newly entered datarun.

    The keyword arguments are passed, as one dict, to both ``RunConfig`` and
    ``DatasetConfig``. When not supplied, ``dataset_id`` defaults to ``None``
    and ``methods`` to ``['logreg', 'dt']``.

    Returns:
        ``Worker(db, datarun)``, where ``datarun`` is looked up from the
        result of ``ATM.enter_data``.
    """
    # Fill in defaults without overriding values the caller provided.
    kwargs.setdefault('dataset_id', None)
    kwargs.setdefault('methods', ['logreg', 'dt'])

    sql_conf = SQLConfig({'sql_database': DB_PATH})
    run_conf = RunConfig(kwargs)
    dataset_conf = DatasetConfig(kwargs)

    db = Database(**sql_conf.to_dict())
    atm = ATM(sql_conf, None, None)

    entered = atm.enter_data(dataset_conf, run_conf)
    new_datarun = db.get_datarun(entered.id)
    return Worker(db, new_datarun)