Example #1
0
def test_run_per_partition(dataset):
    """Check that ``run_per_partition`` splits the logreg run into per-partition dataruns.

    With only ``logreg`` selected, the dataruns that can be looked back up must
    number ``METHOD_HYPERPARTS['logreg']``, and each must own exactly one
    hyperpartition.
    """
    sql_conf = SQLConfig({'sql_database': DB_PATH})
    db = Database(**sql_conf.to_dict())

    run_settings = {
        'dataset_id': dataset.id,
        'methods': ['logreg'],
        'run_per_partition': True,
    }
    run_conf = RunConfig(run_settings)

    atm = ATM(sql_conf, None, None)
    entered = atm.enter_data(None, run_conf)

    with db_session(db):
        # Look every entered datarun back up, dropping lookups that return None.
        fetched = (db.get_datarun(entry.id) for entry in entered)
        runs = [run for run in fetched if run is not None]

        assert len(runs) == METHOD_HYPERPARTS['logreg']
        assert all(len(run.hyperpartitions) == 1 for run in runs)
Example #2
0
File: cli.py Project: singh8477/ATM
def _get_atm(args):
    """Build an ``ATM`` instance from parsed command-line ``args``.

    The SQL, AWS and log configurations are each read from ``args``. Their
    ``to_dict()`` mappings are passed to ``ATM`` as keyword arguments. A key
    present in more than one mapping makes the call raise ``TypeError``
    (duplicate keyword argument).
    """
    sql_conf = SQLConfig(args)
    aws_conf = AWSConfig(args)
    log_conf = LogConfig(args)

    sql_kwargs = sql_conf.to_dict()
    aws_kwargs = aws_conf.to_dict()
    log_kwargs = log_conf.to_dict()
    # Several ``**`` unpackings in a single call require Python 3.5+.
    return ATM(**sql_kwargs, **aws_kwargs, **log_kwargs)
Example #3
0
File: cli.py Project: zwcdp/ATM
def _get_atm(args):
    """Build an ``ATM`` instance from parsed command-line ``args``.

    The SQL, AWS and log configurations are read from ``args``. Their
    ``to_dict()`` mappings are merged, in that order, into one keyword-argument
    dict. On duplicate keys the later configuration wins.
    """
    configs = (SQLConfig(args), AWSConfig(args), LogConfig(args))

    # Merge into a single dict instead of ``ATM(**a, **b, **c)``: Python 2.7
    # does not allow more than one ``**`` unpacking in a call.
    atm_args = {}
    for conf in configs:
        atm_args.update(conf.to_dict())

    return ATM(**atm_args)
Example #4
0
def test_enter_data_all(dataset):
    """Entering a run with every known method creates all of their hyperpartitions.

    The datarun returned by ``enter_data`` must reference ``dataset`` and own
    as many hyperpartitions as the per-method counts in ``METHOD_HYPERPARTS``
    add up to.
    """
    sql_conf = SQLConfig({'sql_database': DB_PATH})
    db = Database(**sql_conf.to_dict())
    # Pass a real list of method names, as the other tests do, rather than a
    # live ``dict_keys`` view.
    methods = list(METHOD_HYPERPARTS)
    run_conf = RunConfig({'dataset_id': dataset.id, 'methods': methods})

    atm = ATM(sql_conf, None, None)

    run_id = atm.enter_data(None, run_conf)

    with db_session(db):
        run = db.get_datarun(run_id.id)
        assert run.dataset.id == dataset.id
        assert len(run.hyperpartitions) == sum(METHOD_HYPERPARTS.values())
Example #5
0
def test_enter_data_by_methods(dataset):
    """Enter one datarun per method and check each gets that method's hyperpartitions.

    A single ``RunConfig`` is reused: its ``methods`` attribute is reassigned to
    a one-element list before every ``enter_data`` call.
    """
    sql_conf = SQLConfig({'sql_database': DB_PATH})
    db = Database(**sql_conf.to_dict())
    run_conf = RunConfig({'dataset_id': dataset.id})

    atm = ATM(sql_conf, None, None)

    for method_name, expected_parts in METHOD_HYPERPARTS.items():
        run_conf.methods = [method_name]
        entered = atm.enter_data(None, run_conf)

        with db_session(db):
            datarun = db.get_datarun(entered.id)
            assert datarun.dataset.id == dataset.id
            assert len(datarun.hyperpartitions) == expected_parts
Example #6
0
def get_new_worker(**kwargs):
    """Enter a new datarun and return a ``Worker`` bound to it.

    ``kwargs`` feed both ``RunConfig`` and ``DatasetConfig``. ``dataset_id``
    defaults to ``None`` and ``methods`` to ``['logreg', 'dt']`` when the
    caller does not supply them.
    """
    kwargs.setdefault('dataset_id', None)
    kwargs.setdefault('methods', ['logreg', 'dt'])

    sql_conf = SQLConfig({'sql_database': DB_PATH})
    run_conf = RunConfig(kwargs)
    dataset_conf = DatasetConfig(kwargs)

    db = Database(**sql_conf.to_dict())
    atm = ATM(sql_conf, None, None)

    entered = atm.enter_data(dataset_conf, run_conf)
    return Worker(db, db.get_datarun(entered.id))
Example #7
0
def get_new_worker(**kwargs):
    """Enter a datarun built from ``kwargs`` and wrap it in a ``Worker``.

    ``methods`` defaults to ``['logreg', 'dt']`` when the caller omits it.
    """
    kwargs.setdefault('methods', ['logreg', 'dt'])
    sql_conf = SQLConfig(database=DB_PATH)
    run_conf = RunConfig(**kwargs)
    datarun_id = enter_data(sql_conf, run_conf)
    db = Database(**vars(sql_conf))
    return Worker(db, db.get_datarun(datarun_id))
def test_enter_data_all(dataset):
    """Entering a run with every known method creates all of their hyperpartitions.

    The datarun id returned by ``enter_data`` must resolve to a run that
    references ``dataset`` and owns as many hyperpartitions as the per-method
    counts in ``METHOD_HYPERPARTS`` add up to.
    """
    sql_conf = SQLConfig(database=DB_PATH)
    db = Database(**vars(sql_conf))
    # Pass a real list of method names, as the other tests do, rather than a
    # live ``dict_keys`` view.
    run_conf = RunConfig(dataset_id=dataset.id,
                         methods=list(METHOD_HYPERPARTS))

    run_id = enter_data(sql_conf, run_conf)

    with db_session(db):
        run = db.get_datarun(run_id)
        assert run.dataset.id == dataset.id
        assert len(run.hyperpartitions) == sum(METHOD_HYPERPARTS.values())
def test_enter_data_by_methods(dataset):
    """Enter one datarun per method and check each gets that method's hyperpartitions.

    A single ``RunConfig`` is reused: its ``methods`` attribute is reassigned to
    a one-element list before every ``enter_data`` call.
    """
    sql_conf = SQLConfig(database=DB_PATH)
    db = Database(**vars(sql_conf))
    run_conf = RunConfig(dataset_id=dataset.id)

    for method, n_parts in METHOD_HYPERPARTS.items():
        run_conf.methods = [method]
        run_id = enter_data(sql_conf, run_conf)

        with db_session(db):
            run = db.get_datarun(run_id)
            # Check existence with the same lookup used below, inside the
            # session, instead of issuing a second query outside ``db_session``.
            assert run is not None, 'no datarun found for method %r' % (method,)
            assert run.dataset.id == dataset.id
            assert len(run.hyperpartitions) == n_parts
def test_run_per_partition(dataset):
    """Check that ``run_per_partition=True`` splits the logreg run into per-partition dataruns.

    The ids returned by ``enter_data`` that resolve to a datarun must number
    ``METHOD_HYPERPARTS['logreg']``, and each such run must own exactly one
    hyperpartition.
    """
    sql_conf = SQLConfig(database=DB_PATH)
    db = Database(**vars(sql_conf))
    run_conf = RunConfig(dataset_id=dataset.id, methods=['logreg'])

    datarun_ids = enter_data(sql_conf, run_conf, run_per_partition=True)

    with db_session(db):
        # Resolve every id, then drop the lookups that returned None.
        looked_up = [db.get_datarun(datarun_id) for datarun_id in datarun_ids]
        runs = [run for run in looked_up if run is not None]

        assert len(runs) == METHOD_HYPERPARTS['logreg']
        assert all(len(run.hyperpartitions) == 1 for run in runs)
Example #11
0
def test_create_dataset(db):
    """Create a dataset from remote CSV URLs and check the metadata stored for it.

    Local copies of the train/test CSVs are deleted first, so the existence
    checks below verify that ``dataset.load()`` itself writes them. They are
    deleted again in a ``finally`` block so that a failing assertion does not
    leave the files behind.
    """
    train_url = DATA_URL + 'pollution_1_train.csv'
    test_url = DATA_URL + 'pollution_1_test.csv'

    sql_conf = SQLConfig({'sql_database': DB_PATH})

    train_path_local = get_local_path('pollution_test.csv', train_url, None)
    test_path_local = get_local_path('pollution_test_test.csv', test_url, None)

    def _remove_local_copies():
        # Delete whichever of the local train/test CSV copies currently exist.
        for path in (train_path_local, test_path_local):
            if os.path.exists(path):
                os.remove(path)

    _remove_local_copies()
    try:
        dataset_conf = DatasetConfig({
            'name': 'pollution_test',
            'train_path': train_url,
            'test_path': test_url,
            'data_description': 'test',
            'class_column': 'class'
        })

        atm = ATM(sql_conf, None, None)

        dataset = atm.create_dataset(dataset_conf)
        dataset = db.get_dataset(dataset.id)

        # Loading is expected to write both CSVs to their local paths.
        _train, _test = dataset.load()

        assert os.path.exists(train_path_local)
        assert os.path.exists(test_path_local)

        assert dataset.train_path == train_url
        assert dataset.test_path == test_url
        assert dataset.description == 'test'
        assert dataset.class_column == 'class'
        assert dataset.n_examples == 40
        assert dataset.d_features == 16
        assert dataset.k_classes == 2
        assert dataset.majority >= 0.5
    finally:
        # Always remove the test dataset files, even when an assertion fails.
        _remove_local_copies()
Example #12
0
def _get_parser():
    """Build the top-level ``argparse`` parser for the ATM command line interface.

    Each subcommand stores its handler function in ``args.action`` through
    ``set_defaults``. When no subcommand is given, ``action`` stays ``None``.
    Shared option groups are built once as ``add_help=False`` parsers and
    attached to subcommands through ``parents=``.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    # Logging options are attached to the main parser and to every subcommand,
    # so they are accepted both before and after the subcommand name.
    logging_args = argparse.ArgumentParser(add_help=False)
    logging_args.add_argument('-v', '--verbose', action='count', default=0)
    logging_args.add_argument('-l', '--logfile')

    parser = argparse.ArgumentParser(description='ATM Command Line Interface',
                                     parents=[logging_args])

    subparsers = parser.add_subparsers(title='action', help='Action to perform')
    parser.set_defaults(action=None)

    # Common Arguments
    sql_args = SQLConfig.get_parser()
    aws_args = AWSConfig.get_parser()
    log_args = LogConfig.get_parser()
    run_args = RunConfig.get_parser()
    dataset_args = DatasetConfig.get_parser()

    # Enter Data Parser
    enter_data_parents = [
        logging_args,
        sql_args,
        aws_args,
        dataset_args,
        log_args,
        run_args
    ]
    enter_data = subparsers.add_parser('enter_data', parents=enter_data_parents,
                                       help='Add a Dataset and trigger a Datarun on it.')
    enter_data.set_defaults(action=_enter_data)

    # Worker Args
    worker_args = argparse.ArgumentParser(add_help=False)
    worker_args.add_argument('--cloud-mode', action='store_true', default=False,
                             help='Whether to run this worker in cloud mode')
    worker_args.add_argument('--no-save', dest='save_files', action='store_false',
                             help="don't save models and metrics at all")

    # Worker
    worker_parents = [
        logging_args,
        worker_args,
        sql_args,
        aws_args,
        log_args
    ]
    worker = subparsers.add_parser('worker', parents=worker_parents,
                                   help='Start a single worker in foreground.')
    worker.set_defaults(action=_work)
    worker.add_argument('--dataruns', help='Only train on dataruns with these ids', nargs='+')
    worker.add_argument('--total-time', help='Number of seconds to run worker', type=int)

    # Server Args
    server_args = argparse.ArgumentParser(add_help=False)
    server_args.add_argument('--host', help='IP to listen at')
    server_args.add_argument('--port', help='Port to listen at', type=int)

    # Server (SQL options come from the ``sql_args`` parent parser)
    server = subparsers.add_parser('server', parents=[logging_args, server_args, sql_args],
                                   help='Start the REST API Server in foreground.')
    server.set_defaults(action=_serve)
    server.add_argument('--debug', help='Start in debug mode', action='store_true')

    # Background Args
    background_args = argparse.ArgumentParser(add_help=False)
    background_args.add_argument('--pid', help='PID file to use.', default='atm.pid')

    # Start Args
    start_args = argparse.ArgumentParser(add_help=False)
    start_args.add_argument('--foreground', action='store_true', help='Run on foreground')
    start_args.add_argument('-w', '--workers', default=1, type=int, help='Number of workers')
    start_args.add_argument('--no-server', dest='server', action='store_false',
                            help='Do not start the REST server')

    # Start
    start_parents = [
        logging_args,
        worker_args,
        server_args,
        background_args,
        start_args,
        sql_args,
        aws_args,
        log_args
    ]
    start = subparsers.add_parser('start', parents=start_parents,
                                  help='Start an ATM Local Cluster.')
    start.set_defaults(action=_start)

    # Status (the only subcommand that previously had no help text)
    status = subparsers.add_parser('status', parents=[logging_args, background_args],
                                   help='Get the status of an ATM Local Cluster.')
    status.set_defaults(action=_status)

    # Stop Args
    stop_args = argparse.ArgumentParser(add_help=False)
    stop_args.add_argument('-t', '--timeout', default=5, type=int,
                           help='Seconds to wait before killing the process.')
    stop_args.add_argument('-f', '--force', action='store_true',
                           help='Kill the process if it does not terminate gracefully.')

    # Stop
    stop = subparsers.add_parser('stop', parents=[logging_args, stop_args, background_args],
                                 help='Stop an ATM Local Cluster.')
    stop.set_defaults(action=_stop)

    # Restart: accepts every start option plus the stop options.
    restart = subparsers.add_parser('restart', parents=start_parents + [stop_args],
                                    help='Restart an ATM Local Cluster.')
    restart.set_defaults(action=_restart)

    # Make Config
    make_config = subparsers.add_parser('make_config', parents=[logging_args],
                                        help='Generate a config templates folder in the cwd.')
    make_config.set_defaults(action=_make_config)

    # Get Demos
    get_demos = subparsers.add_parser('get_demos', parents=[logging_args],
                                      help='Generate a demos folder with demo CSVs in the cwd.')
    get_demos.set_defaults(action=_get_demos)

    return parser
Example #13
0
def _get_atm(args):
    """Construct an ``ATM`` from the SQL, AWS and log configs parsed out of ``args``.

    The three config objects are built in that order and passed to ``ATM``
    positionally.
    """
    configs = (SQLConfig(args), AWSConfig(args), LogConfig(args))
    return ATM(*configs)