コード例 #1
0
def _validate(**kwargs):
    """
    Check each feed type and keep valid results.

    Keyword arguments:
        version: required key (may be None/falsy). The version to validate against;
            when falsy, each feed is validated against the version its own payloads report.
        All keyword arguments (including version) are forwarded to common.get_data().

    Returns a list with one tuple per non-empty feed:
        (record_type, version, datasource, valid, errors, removed)

    A feed whose payloads report more than one version, or whose version does not match
    the one being validated against, is reported with empty valid/removed lists and a
    single UnexpectedVersionError in errors.
    """
    results = []
    # keep the requested version separate from the per-feed version, so a version
    # inferred from one feed never leaks into the validation of the next feed
    requested_version = kwargs["version"]

    for record_type in [mds.STATUS_CHANGES, mds.TRIPS]:
        datasource = common.get_data(record_type, **kwargs)

        if len(datasource) == 0:
            continue

        # distinct payload versions in first-seen order, so that the reported
        # expected/unexpected pair is deterministic
        versions = list(dict.fromkeys(d["version"] for d in datasource))

        if len(versions) > 1:
            expected = mds.Version(versions[0])
            unexpected = mds.Version(versions[1])
            error = mds.versions.UnexpectedVersionError(expected, unexpected)
            results.append((record_type, expected, datasource, [], [error], []))
            continue

        version = mds.Version(requested_version or versions[0])

        try:
            valid, errors, removed = validate(record_type, datasource, version)
            results.append((record_type, version, datasource, valid, errors, removed))
        except mds.versions.UnexpectedVersionError as unexpected_version:
            results.append((record_type, version, datasource, [], [unexpected_version], []))

    return results
コード例 #2
0
def setup_cli():
    """
    Create the cli argument interface, and parse incoming args.

    Returns a tuple:
        - the argument parser
        - the parsed args
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--availability",
                        action="store_true",
                        help="Run the availability calculation.")
    parser.add_argument(
        "--cutoff",
        type=int,
        default=-1,
        help=
        "Maximum allowed length of a time-windowed event (e.g. availability window, trip), in days."
    )
    parser.add_argument("--debug",
                        action="store_true",
                        help="Print debug messages.")
    # the help text names the flags this parser actually defines (--start / --end)
    parser.add_argument(
        "--duration",
        type=int,
        help=
        "Number of seconds; with --start or --end, defines a time query range."
    )
    parser.add_argument(
        "--end",
        type=str,
        help="The end of the time query range for this request.\
        Should be either int Unix seconds or ISO-8601 datetime format.\
        At least one of end or start is required.")
    parser.add_argument("--local",
                        action="store_true",
                        help="Input and query times are local.")
    # each --query value is split once on "=", and the resulting lists are
    # accumulated in args.queries
    parser.add_argument(
        "--query",
        action="append",
        type=lambda kv: kv.split("=", 1),
        dest="queries",
        metavar="QUERY",
        help=
        "A {provider_name}={vehicle_type} pair; multiple pairs will be analyzed separately."
    )
    parser.add_argument(
        "--start",
        type=str,
        help="The beginning of the time query range for this request.\
        Should be either int Unix seconds or ISO-8601 datetime format.\
        At least one of end or start is required.")
    parser.add_argument(
        "--version",
        type=lambda v: mds.Version(v),
        default=mds.Version("0.2.1"),
        help="The release version at which to reference MDS, e.g. 0.3.1")

    return parser, parser.parse_args()
コード例 #3
0
def load(datasource, record_type, **kwargs):
    """
    Load data into a database.

    Keyword arguments read here:
        columns: column names used for duplicate detection; defaults to COLUMNS[record_type].
        update_actions: list of (column, expression) tuples and/or bare True flags.
            A lone True selects UPDATE_ACTIONS[record_type]; tuples are converted to a dict.
            An empty result disables ON CONFLICT UPDATE.
        version: MDS version; defaults to common.default_version.
        stage_first: passed to mds.Database as an int; defaults to True.
        db: an existing database object to load into; one is created when not given.

    Only mds.STATUS_CHANGES and mds.TRIPS are loaded; other record types are ignored.
    """
    print(f"Loading {record_type}")

    # treat a missing, None or empty value as "use the defaults"
    columns = kwargs.pop("columns", None) or COLUMNS[record_type]

    actions = kwargs.pop("update_actions", None) or []

    if len(actions) == 1 and actions[0] is True:
        # flag-only option, use defaults
        actions = UPDATE_ACTIONS[record_type]
    elif len(actions) > 0:
        # convert action tuples to dict, filtering any flag-only options
        # (this must also cover a single tuple, which the conflict update helpers index as a dict)
        actions = dict(action for action in actions if action is not True)

    conflict_update = len(actions) > 0

    version = mds.Version(kwargs.pop("version", common.default_version))
    stage_first = int(kwargs.pop("stage_first", True))

    # only build a Database when the caller did not supply one
    db = kwargs.get("db")
    if db is None:
        db = mds.Database(stage_first=stage_first, version=version, **env())

    load_config = dict(table=record_type, drop_duplicates=columns)
    if record_type == mds.STATUS_CHANGES:
        load_config["on_conflict_update"] = (
            status_changes_conflict_update(columns, actions, version) if conflict_update else None
        )
        db.load_status_changes(datasource, **load_config)
    elif record_type == mds.TRIPS:
        load_config["on_conflict_update"] = (
            trips_conflict_update(columns, actions, version) if conflict_update else None
        )
        db.load_trips(datasource, **load_config)
コード例 #4
0
def _validate_provider(provider, **kwargs):
    """
    Run feed validation for a single provider.

    Resolves the query time range, the provider configuration and the MDS version,
    builds an API client, then hands everything to _validate() and returns its result.
    """
    # work out the query range; a feed may only care about one side of it
    has_start = "start_time" in kwargs
    has_end = "end_time" in kwargs

    if has_start and has_end:
        # the caller gave the full range
        start, end = common.parse_time_range(**kwargs)
    elif has_start or has_end:
        # only one side given: derive the other so the range spans one hour
        start, end = common.parse_time_range(duration=3600, **kwargs)
    else:
        # nothing given: the one-hour window that starts 25 hours ago
        end = datetime.datetime.utcnow() - datetime.timedelta(days=1)
        start = end - datetime.timedelta(seconds=3600)

    kwargs.update(start_time=start, end_time=end)

    config = common.get_config(provider, kwargs.get("config"))

    # a version in the provider config takes priority over the keyword argument
    requested = config.pop("version", kwargs.get("version"))
    version = mds.Version(requested)
    version.raise_if_unsupported()

    kwargs.update(version=version, no_paging=False, rate_limit=0)
    kwargs["client"] = mds.Client(provider, version=version, **config)

    return _validate(**kwargs)
コード例 #5
0
def status_changes_conflict_update(columns, actions, version=None):
    """
    Create a tuple for generating the status_changes ON CONFLICT UPDATE statement.

    Parameters:
        columns: passed through to prepare_conflict_update().
        actions: mapping of column name to update expression. It is not modified;
            a copy with the version-specific trip association column added is returned.
        version: passed through to prepare_conflict_update(), which returns the
            version used for the comparison below.

    Returns a tuple (condition, actions).
    """
    condition, version = prepare_conflict_update(columns, version)

    # copy first: the caller may hand in a shared defaults mapping, and adding keys
    # to it in place would carry over into later calls
    actions = dict(actions)

    # the trip association column was renamed (and changed from array to scalar) at 0.3.0;
    # an action the caller already supplied for that column is kept
    if version < mds.Version("0.3.0"):
        actions.setdefault("associated_trips", "cast(EXCLUDED.associated_trips as uuid[])")
    else:
        actions.setdefault("associated_trip", "cast(EXCLUDED.associated_trip as uuid)")

    return condition, actions
コード例 #6
0
def setup_cli(**kwargs):
    """
    Set up the common command line arguments.

    Keyword arguments are passed to the ArgumentParser instance.

    Returns the ArgumentParser.
    """

    def header_pair(text):
        """
        Split a 'Header: value' string on the first colon into a stripped (name, value) tuple.

        Raises argparse.ArgumentTypeError when there is no colon, so argparse reports
        a usage error instead of surfacing an IndexError traceback.
        """
        name, separator, value = text.partition(":")
        if not separator:
            raise argparse.ArgumentTypeError(f"expected 'Header: value', got {text!r}")
        return name.strip(), value.strip()

    parser = argparse.ArgumentParser(**kwargs)

    parser.add_argument(
        "--auth_type",
        type=str,
        default="Bearer",
        help="The type used for the Authorization header for requests to the provider\
        (e.g. Basic, Bearer)."
    )

    parser.add_argument(
        "--config",
        type=str,
        help="Path to a provider configuration file to use."
    )

    parser.add_argument(
        "-H",
        "--header",
        dest="headers",
        action="append",
        type=header_pair,
        default=[],
        help="One or more 'Header: value' combinations, sent with each request."
    )

    parser.add_argument(
        "--output",
        type=str,
        help="Write results to json files in this directory."
    )

    parser.add_argument(
        "--version",
        type=lambda v: mds.Version(v),
        default=DEFAULT_VERSION,
        help=f"The release version at which to reference MDS, e.g. {DEFAULT_VERSION}"
    )

    return parser
コード例 #7
0
def ingest(record_type, **kwargs):
    """
    Run one ingestion pass for a record type.

    Steps, in order:

    1. fetch the data (from files or the API) via common.get_data
    2. validate it and drop invalid records, unless no_validate is set
    3. write payloads to files when an output location is given
    4. load the remaining payloads into the database, unless no_load is set

    The version, no_validate, output and no_load keyword arguments are consumed here;
    whatever is left is forwarded to database.load().
    """
    version = mds.Version(kwargs.pop("version", common.DEFAULT_VERSION))
    version.raise_if_unsupported()

    datasource = common.get_data(record_type, **kwargs, version=version)
    data_key = mds.Schema(record_type).data_key

    def record_count(payloads):
        # total number of records across all payloads
        return sum(len(payload["data"][data_key]) for payload in payloads)

    # step 2: validation and filtering
    skip_validation = kwargs.pop("no_validate", False)
    if skip_validation:
        print("Skipping data validation")
        valid, removed = datasource, None
    else:
        print(f"Validating {record_type} @ {version}")

        valid, _errors, removed = validation.validate(record_type, datasource, version=version)

        seen = record_count(datasource)
        passed = record_count(valid)
        failed = record_count(removed)

        print(f"{seen} records, {passed} passed, {failed} failed")

    # step 3: file output, only when a location was requested
    output = kwargs.pop("output", None)
    if output:
        data_file = mds.DataFile(record_type, output)
        data_file.dump_payloads(valid)
        if removed:
            data_file.dump_payloads(removed)

    # step 4: database load
    skip_load = kwargs.pop("no_load", False)
    if skip_load or len(valid) == 0:
        print("Skipping data load")
    else:
        database.load(valid, record_type, **kwargs, version=version)

    print(f"{record_type} complete")
コード例 #8
0
def load(datasource, record_type, **kwargs):
    """
    Load data into a database.

    Keyword arguments read here:
        version: MDS version; defaults to common.DEFAULT_VERSION and must be supported.
        columns: column names used for duplicate detection; defaults to COLUMNS[record_type].
        update_actions: list of (column, expression) tuples and/or bare True flags.
            A lone True selects default_conflict_update_actions(record_type, version);
            tuples are converted to a dict. An empty result disables ON CONFLICT UPDATE.
        stage_first: passed to mds.Database as an int; defaults to True.
        db: an existing database object to load into; one is created when not given.

    Raises:
        ValueError: when record_type is not available at the given version.
    """
    print(f"Loading {record_type}")

    version = mds.Version(kwargs.pop("version", common.DEFAULT_VERSION))
    version.raise_if_unsupported()

    # the record type constants are referenced through the mds module, as everywhere else here
    if version < mds.Version._040_() and record_type not in [mds.STATUS_CHANGES, mds.TRIPS]:
        raise ValueError(f"MDS Version {version} only supports {mds.STATUS_CHANGES} and {mds.TRIPS}.")
    elif version < mds.Version._041_() and record_type == mds.VEHICLES:
        raise ValueError(f"MDS Version {version} does not support the {mds.VEHICLES} endpoint.")

    # treat a missing, None or empty value as "use the defaults"
    columns = kwargs.pop("columns", None) or COLUMNS[record_type]

    actions = kwargs.pop("update_actions", None) or []

    if len(actions) == 1 and actions[0] is True:
        # flag-only option, use defaults
        actions = default_conflict_update_actions(record_type, version)
    elif len(actions) > 0:
        # convert action tuples to dict, filtering any flag-only options
        # (a single tuple is converted too, so actions always ends up in the same shape)
        actions = dict(action for action in actions if action is not True)

    stage_first = int(kwargs.pop("stage_first", True))

    # only build a Database when the caller did not supply one
    db = kwargs.get("db")
    if db is None:
        db = mds.Database(stage_first=stage_first, version=version, **env())

    load_config = dict(table=record_type, drop_duplicates=columns)
    if len(actions) > 0:
        load_config["on_conflict_update"] = conflict_update_condition(columns), actions

    if record_type == mds.EVENTS:
        db.load_events(datasource, **load_config)
    elif record_type == mds.STATUS_CHANGES:
        db.load_status_changes(datasource, **load_config)
    elif record_type == mds.TRIPS:
        db.load_trips(datasource, **load_config)
    elif record_type == mds.VEHICLES:
        db.load_vehicles(datasource, **load_config)
コード例 #9
0
def get_data(record_type, **kwargs):
    """
    Return provider data for record_type as in-memory payloads.

    When a truthy source keyword argument is given, payloads are read from it with
    mds.DataFile. Otherwise the client, start_time and end_time keyword arguments are
    required and the data is requested through client.get().
    """
    source = kwargs.get("source")
    if source:
        print(f"Reading {record_type} from {source}")
        return mds.DataFile(record_type, source).load_payloads()

    # the API path cannot work without these three
    client = kwargs.pop("client")
    start_time = kwargs.pop("start_time")
    end_time = kwargs.pop("end_time")

    version = kwargs.get("version")

    # options sent with every request
    api_kwargs = dict(
        paging=not kwargs.get("no_paging"),
        rate_limit=kwargs.get("rate_limit"),
    )

    print(f"Requesting {record_type} from {client.provider.provider_name}")
    print(f"Time range: {start_time.isoformat()} to {end_time.isoformat()}")

    time_range = dict(start_time=start_time, end_time=end_time)

    if record_type == mds.STATUS_CHANGES:
        api_kwargs.update(time_range)
    elif record_type == mds.TRIPS:
        api_kwargs["device_id"] = kwargs.get("device_id")
        api_kwargs["vehicle_id"] = kwargs.get("vehicle_id")

        # before 0.3.0 trips take start_time/end_time; from 0.3.0 on the same
        # range is sent as min_end_time/max_end_time
        if version < mds.Version("0.3.0"):
            api_kwargs.update(time_range)
        else:
            api_kwargs.update(min_end_time=start_time, max_end_time=end_time)

    return client.get(record_type, **api_kwargs)
コード例 #10
0
def _validate_provider(provider, **kwargs):
    """
    Validate the feeds for a provider.

    Resolves the query time range, the provider configuration and the MDS version,
    builds an API client, then returns the result of _validate().

    Raises:
        SystemExit: (status 1) when only one of start_time / end_time is given.
        mds.UnsupportedVersionError: when the resolved version is unsupported.
    """
    has_start = bool(kwargs.get("start_time"))
    has_end = bool(kwargs.get("end_time"))

    # assert the time parameters -> if giving one, both must be given
    if has_start != has_end:
        print(
            "Both --start_time and --end_time are required for custom query ranges."
        )
        # raise SystemExit directly; the exit() builtin is an interactive-shell
        # helper installed by the site module
        raise SystemExit(1)

    if has_start:
        # parse from user input
        start, end = common.parse_time_range(**kwargs)
    else:
        # default to the hour beginning 25 hours before the current time
        end = datetime.datetime.utcnow() - datetime.timedelta(days=1)
        start = end - datetime.timedelta(seconds=3600)

    kwargs["start_time"] = start
    kwargs["end_time"] = end

    config = common.get_config(provider, kwargs.get("config"))

    # assert the version parameter; a version in the provider config takes priority
    version = mds.Version(config.pop("version", kwargs.get("version")))
    if version.unsupported:
        raise mds.UnsupportedVersionError(version)

    kwargs["version"] = version
    kwargs["no_paging"] = False
    kwargs["rate_limit"] = 0
    kwargs["client"] = mds.Client(provider, version=version, **config)

    return _validate(**kwargs)
コード例 #11
0
"""
Helper functions for shared functionality.
"""

import argparse
import datetime
import pathlib

import mds

# MDS version used as a fallback when no version is supplied
# (presumably read by other modules as common.default_version - confirm against callers).
default_version = mds.Version("0.3.2")


def count_seconds(ts):
    """
    Return the number of seconds elapsed since the datetime ts, rounded to the nearest second.

    ts may be a naive datetime, which is compared against the current UTC time as a
    naive value, or a timezone-aware datetime, which is compared against the current
    aware UTC time.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    if ts.utcoffset() is None:
        # naive input: drop tzinfo so the subtraction stays naive-vs-naive
        now = now.replace(tzinfo=None)
    return round((now - ts).total_seconds())


def get_config(provider, config_path=None):
    """
    Return the provider's configuration data.

    The configuration is read from config_path when one is given, otherwise from
    ./config.json when that file exists. With neither available, an empty dict is returned.
    """
    fallback_path = "./config.json"

    if not config_path:
        if not pathlib.Path(fallback_path).exists():
            return {}
        config_path = fallback_path

    return mds.ConfigFile(config_path, provider).dump()
コード例 #12
0
def setup_cli():
    """
    Build the command line interface and parse the incoming arguments.

    When --vehicle_types or --propulsion_types are not given, they are filled in
    from the trips schema at the requested --version.

    Returns a tuple:
        - the argument parser
        - the parsed args
    """
    # the 'master' schema only supplies the example values shown in the help text
    master_schema = mds.Schema(mds.TRIPS, 'master')
    propulsion_examples = " ".join(master_schema.propulsion_types)
    vehicle_examples = " ".join(master_schema.vehicle_types)

    # (flag, add_argument options), listed in registration order
    argument_specs = (
        ("--boundary", dict(
            type=str,
            help="Path to a data file with geographic bounds for the generated data. Overrides the MDS_BOUNDARY environment variable.",
        )),
        ("--close", dict(
            type=int,
            default=19,
            help="The hour of the day (24-hr format) that provider stops operations. Overrides --start and --end.",
        )),
        ("--date_format", dict(
            type=str,
            default="unix",
            help="Format for datetime input (to this CLI) and output (to stdout and files). Options:\
            - 'unix' for Unix timestamps (default)\
            - 'iso8601' for ISO 8601 format\
            - '<python format string>' for custom formats,\
            see https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior",
        )),
        ("--devices", dict(
            type=int,
            help="The number of devices to model in the generated data",
        )),
        ("--end", dict(
            type=str,
            help="The latest event in the generated data, in --date_format format",
        )),
        ("--inactivity", dict(
            type=float,
            help="Describes the portion of the fleet that remains inactive.",
        )),
        ("--open", dict(
            type=int,
            default=7,
            help="The hour of the day (24-hr format) that provider begins operations. Overrides --start and --end.",
        )),
        ("--output", dict(
            type=str,
            help="Path to a directory to write the resulting data file(s)",
        )),
        ("--propulsion_types", dict(
            type=str,
            nargs="+",
            default=[],  # replaced after parsing when the user gives none
            metavar="PROPULSION_TYPE",
            help=f"A list of propulsion_types to use for the generated data, e.g. '{propulsion_examples}'",
        )),
        ("--provider_name", dict(
            type=str,
            help="The name of the fake mobility as a service provider",
        )),
        ("--provider_id", dict(
            type=uuid.UUID,
            help="The ID of the fake mobility as a service provider",
        )),
        ("--start", dict(
            type=str,
            help="The earliest event in the generated data, in --date_format format",
        )),
        ("--speed_mph", dict(
            type=float,
            help="The average speed of devices in miles per hour. Cannot be used with --speed_ms",
        )),
        ("--speed_ms", dict(
            type=float,
            help="The average speed of devices in meters per second. Always takes precedence",
        )),
        ("--vehicle_types", dict(
            type=str,
            nargs="+",
            default=[],  # replaced after parsing when the user gives none
            metavar="VEHICLE_TYPE",
            help=f"A list of vehicle_types to use for the generated data, e.g. '{vehicle_examples}'",
        )),
        ("--version", dict(
            type=lambda v: mds.Version(v),
            default=mds.Version("0.2.1"),
            help="The release version at which to reference MDS, e.g. 0.3.1",
        )),
    )

    parser = argparse.ArgumentParser()
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # fall back to the types defined by the schema of the requested MDS version
    versioned_schema = mds.Schema(mds.TRIPS, args.version)

    if not args.vehicle_types:
        args.vehicle_types = versioned_schema.vehicle_types

    if not args.propulsion_types:
        args.propulsion_types = versioned_schema.propulsion_types

    return parser, args
コード例 #13
0
    arg_parser, args = setup_cli()

    # assert the data type parameters
    if not any([args.events, args.status_changes, args.trips, args.vehicles]):
        print("At least one of --events, --status_changes, --trips, or --vehicles is required.")
        print("Run main.py --help for more information.")
        print("Exiting.")
        print()
        exit(1)

    print(f"Starting ingestion run: {now.isoformat()}")

    config = common.get_config(args.provider, args.config)

    # assert the version parameter
    args.version = mds.Version(config.pop("version", args.version))
    args.version.raise_if_unsupported()

    print(f"Referencing MDS @ {args.version}")

    # shortcut for loading from files
    if args.source:
        if args.events:
            ingest(mds.EVENTS, **vars(args))
        if args.status_changes:
            ingest(mds.STATUS_CHANGES, **vars(args))
        if args.trips:
            ingest(mds.TRIPS, **vars(args))
        if args.vehicles:
            ingest(mds.VEHICLES, **vars(args))
        # finished
コード例 #14
0
def validate(record_type, sources, version, **kwargs):
    """
    Partition sources into a tuple of (valid, errors, removed)

        - valid: the sources with remaining valid data records
        - errors: a list of mds.schemas.DataValidationError
        - removed: the sources with invalid data records

    Parameters:
        record_type: the record type; also used as the key of the record list inside
            each payload's "data" dict.
        sources: a list of payload dicts, each with "data" and "version" keys.
        version: the version every payload is expected to report.
        validator: (keyword, optional) the validator to use; one is created with
            _validator(record_type, version) when not given.

    A payload with a payload-level failure contributes its errors but appears in
    neither valid nor removed. The original payloads are not modified.

    Raises:
        TypeError: when sources is not a list of payload dicts.
        mds.versions.UnexpectedVersionError: when any payload reports a different version.
    """
    if not all(isinstance(d, dict) and "data" in d for d in sources):
        raise TypeError(
            "Sources appears to be the wrong data type. Expected a list of payload dicts."
        )

    source_versions = [mds.Version(d["version"]) for d in sources]

    if any(version != v for v in source_versions):
        raise mds.versions.UnexpectedVersionError(source_versions[0], version)

    valid = []
    errors = []
    removed = []

    # only build the default validator when the caller did not supply one
    validator = kwargs.get("validator")
    if validator is None:
        validator = _validator(record_type, version)

    for source in sources:
        records = list(source.get("data", {}).get(record_type, []))
        invalid_source = False
        invalid_idx = set()

        # schema validation
        for error in validator.validate(source):
            errors.append(error)
            failure, idx = _failure(record_type, error)
            invalid_source = invalid_source or failure

            # this was a problem with a single item, mark it for removal
            if not failure and isinstance(idx, int):
                invalid_idx.add(idx)

        # nothing is kept from a payload that failed as a whole
        if invalid_source:
            continue

        # split by position: looking records up by value is quadratic and
        # resolves equal records to the first occurrence's index
        valid_records = [r for i, r in enumerate(records) if i not in invalid_idx]
        invalid_records = [r for i, r in enumerate(records) if i in invalid_idx]

        if len(valid_records) > 0:
            # create a copy to preserve the original payload
            valid.append({**source, "data": {record_type: valid_records}})

        if len(invalid_records) > 0:
            # create a copy to preserve the original payload
            removed.append({**source, "data": {record_type: invalid_records}})

    return valid, errors, removed
コード例 #15
0
"""
Helper functions for shared functionality.
"""

import argparse
import datetime
import pathlib

import mds


# MDS version used as a fallback when no version is supplied.
DEFAULT_VERSION = mds.Version("0.3.2")
# Named constant for MDS 0.4.0 (presumably for version comparisons elsewhere - confirm against callers).
VERSION_040 = mds.Version("0.4.0")


def count_seconds(ts):
    """
    Return the number of seconds elapsed since the datetime ts, rounded to the nearest second.

    ts may be a naive datetime, which is compared against the current UTC time as a
    naive value, or a timezone-aware datetime, which is compared against the current
    aware UTC time.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    if ts.utcoffset() is None:
        # naive input: drop tzinfo so the subtraction stays naive-vs-naive
        now = now.replace(tzinfo=None)
    return round((now - ts).total_seconds())


def get_config(provider, config_path=None):
    """
    Obtain provider's configuration data from the given file path, or the default file path if None.
    """
    if config_path:
        return mds.ConfigFile(config_path, provider).dump()
    elif pathlib.Path("./config.json").exists():
        return mds.ConfigFile("./config.json", provider).dump()
    else: