Пример #1
0
def _get_atm(args):
    """Create an ``ATM`` instance from the parsed command line arguments.

    ``args`` is handed to ``SQLConfig``, ``AWSConfig`` and ``LogConfig``; the
    dictionary returned by each config's ``to_dict()`` is expanded into the
    keyword arguments of ``ATM``.
    """
    sql_conf, aws_conf, log_conf = SQLConfig(args), AWSConfig(args), LogConfig(args)

    sql_kwargs = sql_conf.to_dict()
    aws_kwargs = aws_conf.to_dict()
    log_kwargs = log_conf.to_dict()

    # Separate ``**`` expansions (rather than one merged dict) so that a key
    # present in more than one config still raises instead of being overridden.
    return ATM(**sql_kwargs, **aws_kwargs, **log_kwargs)
Пример #2
0
Файл: cli.py Проект: zwcdp/ATM
def _get_atm(args):
    """Create an ``ATM`` instance from the parsed command line arguments.

    The dictionaries produced by ``SQLConfig``, ``AWSConfig`` and
    ``LogConfig`` are merged, in that order, into a single keyword-argument
    dictionary; on a key collision the later config wins.
    """
    sql_conf = SQLConfig(args)
    aws_conf = AWSConfig(args)
    log_conf = LogConfig(args)

    # A single merged dictionary is used because Python 2.7 does not allow
    # several ``**`` expansions in one call.
    atm_kwargs = sql_conf.to_dict()
    for extra_conf in (aws_conf, log_conf):
        atm_kwargs.update(extra_conf.to_dict())

    return ATM(**atm_kwargs)
Пример #3
0
    def __init__(self, database, datarun, save_files=True, cloud_mode=False,
                 aws_config=None, log_config=None, public_ip='localhost'):
        """Initialize the instance for a single datarun.

        Stores the given options, makes sure the model and metric directories
        exist, loads the datarun's dataset from the database and finally
        loads the selector and the tuner.

        Args:
            database: Database object with connection information.
            datarun: Datarun ORM object to work on.
            save_files: if True, save model and metrics files to disk or cloud.
            cloud_mode: if True, save classifiers to the cloud.
            aws_config: S3Config object with amazon s3 connection info.
            log_config: object exposing ``model_dir``, ``metric_dir`` and
                ``verbose_metrics``. A default ``LogConfig()`` is used when
                this is ``None`` (or otherwise falsy).
            public_ip: value stored as ``self.public_ip``.
        """
        self.db = database
        self.datarun = datarun
        self.save_files = save_files
        self.cloud_mode = cloud_mode
        self.aws_config = aws_config
        self.public_ip = public_ip

        # Fall back to the default logging configuration when none is given.
        if not log_config:
            log_config = LogConfig()

        self.model_dir = log_config.model_dir
        self.metric_dir = log_config.metric_dir
        self.verbose_metrics = log_config.verbose_metrics

        # Both output directories must exist before anything is written.
        for directory in (self.model_dir, self.metric_dir):
            ensure_directory(directory)

        # Load the Dataset referenced by the datarun from the database.
        dataset_id = self.datarun.dataset_id
        self.dataset = self.db.get_dataset(dataset_id)

        # Load the Selector and Tuner classes specified by the datarun.
        self.load_selector()
        self.load_tuner()
Пример #4
0
def test_save_classifier(db, datarun, model, metrics):
    """Check that ``Worker.save_classifier`` stores a loadable model and metrics.

    A classifier is started for the first hyperpartition of the datarun,
    saved through the worker, and then the model and metrics are read back
    from ``MODEL_DIR`` and ``METRIC_DIR`` and compared with the originals.
    """
    log_conf = LogConfig(model_dir=MODEL_DIR, metric_dir=METRIC_DIR)
    worker = Worker(db, datarun, log_config=log_conf)
    hp = db.get_hyperpartitions(datarun_id=worker.datarun.id)[0]
    classifier = worker.db.start_classifier(hyperpartition_id=hp.id,
                                            datarun_id=worker.datarun.id,
                                            host='localhost',
                                            hyperparameter_values=DT_PARAMS)

    # Replace complete_classifier with a Mock so the call made by
    # save_classifier can be asserted.
    worker.db.complete_classifier = Mock()
    worker.save_classifier(classifier.id, model, metrics)
    worker.db.complete_classifier.assert_called()

    with db_session(worker.db):
        clf = db.get_classifier(classifier.id)

        loaded = load_model(clf, MODEL_DIR)
        # isinstance is the idiomatic type check (pycodestyle E721).
        assert isinstance(loaded, Model)
        assert loaded.method == model.method
        assert loaded.random_state == model.random_state

        assert load_metrics(clf, METRIC_DIR) == metrics
Пример #5
0
def _get_parser():
    """Build the argument parser of the ATM command line interface.

    The top-level parser accepts the logging options and one sub-command per
    action: ``enter_data``, ``worker``, ``server``, ``start``, ``status``,
    ``stop``, ``restart``, ``make_config`` and ``get_demos``. Every
    sub-command stores its handler function under ``action`` through
    ``set_defaults``; at the top level ``action`` defaults to ``None``.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    # ------------------------------------------------------------------
    # Reusable parent parsers
    # ------------------------------------------------------------------

    # Logging options, accepted by the top-level parser and every sub-command.
    logging_parent = argparse.ArgumentParser(add_help=False)
    logging_parent.add_argument('-v', '--verbose', action='count', default=0)
    logging_parent.add_argument('-l', '--logfile')

    # Parent parsers supplied by the configuration classes.
    sql_parent = SQLConfig.get_parser()
    aws_parent = AWSConfig.get_parser()
    log_parent = LogConfig.get_parser()
    run_parent = RunConfig.get_parser()
    dataset_parent = DatasetConfig.get_parser()

    # Worker options.
    worker_parent = argparse.ArgumentParser(add_help=False)
    worker_parent.add_argument('--cloud-mode', action='store_true', default=False,
                               help='Whether to run this worker in cloud mode')
    worker_parent.add_argument('--no-save', dest='save_files', action='store_false',
                               help="don't save models and metrics at all")

    # Server options.
    server_parent = argparse.ArgumentParser(add_help=False)
    server_parent.add_argument('--host', help='IP to listen at')
    server_parent.add_argument('--port', help='Port to listen at', type=int)

    # Background process options.
    background_parent = argparse.ArgumentParser(add_help=False)
    background_parent.add_argument('--pid', help='PID file to use.', default='atm.pid')

    # Start options.
    start_parent = argparse.ArgumentParser(add_help=False)
    start_parent.add_argument('--foreground', action='store_true', help='Run on foreground')
    start_parent.add_argument('-w', '--workers', default=1, type=int, help='Number of workers')
    start_parent.add_argument('--no-server', dest='server', action='store_false',
                              help='Do not start the REST server')

    # Stop options.
    stop_parent = argparse.ArgumentParser(add_help=False)
    stop_parent.add_argument('-t', '--timeout', default=5, type=int,
                             help='Seconds to wait before killing the process.')
    stop_parent.add_argument('-f', '--force', action='store_true',
                             help='Kill the process if it does not terminate gracefully.')

    # Parents shared by the ``start`` and ``restart`` sub-commands.
    cluster_parents = [
        logging_parent,
        worker_parent,
        server_parent,
        background_parent,
        start_parent,
        sql_parent,
        aws_parent,
        log_parent,
    ]

    # ------------------------------------------------------------------
    # Top-level parser and sub-commands
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(
        description='ATM Command Line Interface',
        parents=[logging_parent],
    )
    commands = parser.add_subparsers(title='action', help='Action to perform')
    parser.set_defaults(action=None)

    # enter_data
    enter_data_cmd = commands.add_parser(
        'enter_data',
        parents=[logging_parent, sql_parent, aws_parent, dataset_parent, log_parent, run_parent],
        help='Add a Dataset and trigger a Datarun on it.',
    )
    enter_data_cmd.set_defaults(action=_enter_data)

    # worker
    worker_cmd = commands.add_parser(
        'worker',
        parents=[logging_parent, worker_parent, sql_parent, aws_parent, log_parent],
        help='Start a single worker in foreground.',
    )
    worker_cmd.set_defaults(action=_work)
    worker_cmd.add_argument('--dataruns', help='Only train on dataruns with these ids',
                            nargs='+')
    worker_cmd.add_argument('--total-time', help='Number of seconds to run worker', type=int)

    # server
    server_cmd = commands.add_parser(
        'server',
        parents=[logging_parent, server_parent, sql_parent],
        help='Start the REST API Server in foreground.',
    )
    server_cmd.set_defaults(action=_serve)
    server_cmd.add_argument('--debug', help='Start in debug mode', action='store_true')

    # start
    start_cmd = commands.add_parser(
        'start',
        parents=cluster_parents,
        help='Start an ATM Local Cluster.',
    )
    start_cmd.set_defaults(action=_start)

    # status
    status_cmd = commands.add_parser('status', parents=[logging_parent, background_parent])
    status_cmd.set_defaults(action=_status)

    # stop
    stop_cmd = commands.add_parser(
        'stop',
        parents=[logging_parent, stop_parent, background_parent],
        help='Stop an ATM Local Cluster.',
    )
    stop_cmd.set_defaults(action=_stop)

    # restart: everything ``start`` accepts plus the stop options.
    restart_cmd = commands.add_parser(
        'restart',
        parents=cluster_parents + [stop_parent],
        help='Restart an ATM Local Cluster.',
    )
    restart_cmd.set_defaults(action=_restart)

    # make_config
    make_config_cmd = commands.add_parser(
        'make_config',
        parents=[logging_parent],
        help='Generate a config templates folder in the cwd.',
    )
    make_config_cmd.set_defaults(action=_make_config)

    # get_demos
    get_demos_cmd = commands.add_parser(
        'get_demos',
        parents=[logging_parent],
        help='Generate a demos folder with demo CSVs in the cwd.',
    )
    get_demos_cmd.set_defaults(action=_get_demos)

    return parser
Пример #6
0
def _get_atm(args):
    """Create an ``ATM`` instance from the parsed command line arguments.

    ``args`` is used to build a ``SQLConfig``, an ``AWSConfig`` and a
    ``LogConfig``, which are passed to ``ATM`` positionally in that order.
    """
    configs = [conf_class(args) for conf_class in (SQLConfig, AWSConfig, LogConfig)]
    return ATM(*configs)