def _validate(**kwargs):
    """
    Check each feed type and keep valid results.

    Keyword arguments are forwarded unchanged to common.get_data(). The
    "version" entry, when truthy, is the MDS version every feed is validated
    against; otherwise each feed is validated against the version declared by
    its own payloads.

    Returns a list with one tuple per non-empty feed:
        (record_type, version, datasource, valid, errors, removed)
    """
    results = []
    requested_version = kwargs.get("version")

    for record_type in [mds.STATUS_CHANGES, mds.TRIPS]:
        datasource = common.get_data(record_type, **kwargs)

        # nothing to validate for this feed
        if len(datasource) == 0:
            continue

        versions = {d["version"] for d in datasource}

        if len(versions) > 1:
            # payloads within a single feed disagree about their version;
            # report the first two distinct versions found and move on
            expected, unexpected = mds.Version(versions.pop()), mds.Version(versions.pop())
            error = mds.versions.UnexpectedVersionError(expected, unexpected)
            results.append((record_type, expected, datasource, [], [error], []))
            continue

        # resolve the version per feed, so a version detected from one feed
        # is never carried over and imposed on the next feed
        feed_version = mds.Version(requested_version or versions.pop())

        try:
            valid, errors, removed = validate(record_type, datasource, feed_version)
        except mds.versions.UnexpectedVersionError as unexpected_version:
            results.append((record_type, feed_version, datasource, [], [unexpected_version], []))
        else:
            results.append((record_type, feed_version, datasource, valid, errors, removed))

    return results
def setup_cli():
    """
    Create the cli argument interface, and parse incoming args.

    Returns a tuple:
        - the argument parser
        - the parsed args
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--availability",
        action="store_true",
        help="Run the availability calculation."
    )

    parser.add_argument(
        "--cutoff",
        type=int,
        default=-1,
        help="Maximum allowed length of a time-windowed event (e.g. availability window, trip), in days."
    )

    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print debug messages."
    )

    # the flags defined by this parser are --start and --end
    parser.add_argument(
        "--duration",
        type=int,
        help="Number of seconds; with --start or --end, defines a time query range."
    )

    # adjacent string literals are used instead of a backslash continuation
    # inside the literal, which would embed source indentation in the text
    parser.add_argument(
        "--end",
        type=str,
        help="The end of the time query range for this request. "
             "Should be either int Unix seconds or ISO-8601 datetime format. "
             "At least one of end or start is required."
    )

    parser.add_argument(
        "--local",
        action="store_true",
        help="Input and query times are local."
    )

    # each --query value is split once on "=", producing a [name, type] list;
    # repeated flags accumulate into args.queries
    parser.add_argument(
        "--query",
        action="append",
        type=lambda kv: kv.split("=", 1),
        dest="queries",
        metavar="QUERY",
        help="A {provider_name}={vehicle_type} pair; multiple pairs will be analyzed separately."
    )

    parser.add_argument(
        "--start",
        type=str,
        help="The beginning of the time query range for this request. "
             "Should be either int Unix seconds or ISO-8601 datetime format. "
             "At least one of end or start is required."
    )

    parser.add_argument(
        "--version",
        type=lambda v: mds.Version(v),
        default=mds.Version("0.2.1"),
        help="The release version at which to reference MDS, e.g. 0.3.1"
    )

    return parser, parser.parse_args()
def load(datasource, record_type, **kwargs):
    """
    Load data into a database.

    Recognized keyword arguments:
        - columns: column names used to drop duplicates; defaults to COLUMNS[record_type]
        - update_actions: a list whose items are either True (flag-only, meaning
          "use the defaults") or (column, expression) tuples
        - version: the MDS version; defaults to common.default_version
        - stage_first: passed to mds.Database (as an int); defaults to True
        - db: an existing database object to use instead of creating one

    Only mds.STATUS_CHANGES and mds.TRIPS are handled; any other record_type
    results in no load being performed.
    """
    print(f"Loading {record_type}")

    columns = kwargs.pop("columns", None) or []
    if len(columns) == 0:
        columns = COLUMNS[record_type]

    actions = kwargs.pop("update_actions", None) or []
    if len(actions) == 1 and actions[0] is True:
        # flag-only option, use defaults; copy so the conflict-update helpers
        # cannot modify the shared module-level mapping
        actions = dict(UPDATE_ACTIONS[record_type])
    else:
        # convert action tuples to dict, filtering any flag-only options;
        # this also covers a single custom tuple, which must become a dict too
        actions = dict(action for action in actions if action is not True)

    conflict_update = len(actions) > 0

    version = mds.Version(kwargs.pop("version", common.default_version))
    stage_first = int(kwargs.pop("stage_first", True))

    # only build a Database (and read the environment) when none was supplied
    db = kwargs.get("db")
    if db is None:
        db = mds.Database(stage_first=stage_first, version=version, **env())

    load_config = dict(table=record_type, drop_duplicates=columns)

    if record_type == mds.STATUS_CHANGES:
        load_config["on_conflict_update"] = (
            status_changes_conflict_update(columns, actions, version) if conflict_update else None
        )
        db.load_status_changes(datasource, **load_config)
    elif record_type == mds.TRIPS:
        load_config["on_conflict_update"] = (
            trips_conflict_update(columns, actions, version) if conflict_update else None
        )
        db.load_trips(datasource, **load_config)
def _validate_provider(provider, **kwargs):
    """
    Validate the feeds for a provider.

    Resolves the time query range and MDS version, builds an mds.Client for
    the provider, and hands everything to _validate().
    """
    has_start = "start_time" in kwargs
    has_end = "end_time" in kwargs

    # compute a time query range; one or both sides may not be relevant for all feeds.
    if has_start and has_end:
        # both sides of range provided
        start, end = common.parse_time_range(**kwargs)
    elif has_start or has_end:
        # one side of range provided, compute the other side for a total range of an hour
        start, end = common.parse_time_range(duration=3600, **kwargs)
    else:
        # default to the hour beginning 25 hours before the current time
        one_day = datetime.timedelta(days=1)
        one_hour = datetime.timedelta(seconds=3600)
        end = datetime.datetime.utcnow() - one_day
        start = end - one_hour

    kwargs.update(start_time=start, end_time=end)

    config = common.get_config(provider, kwargs.get("config"))

    # assert the version parameter; a version in the provider config wins over the kwarg
    version = mds.Version(config.pop("version", kwargs.get("version")))
    version.raise_if_unsupported()

    kwargs.update(
        version=version,
        no_paging=False,
        rate_limit=0,
        client=mds.Client(provider, version=version, **config),
    )

    return _validate(**kwargs)
def status_changes_conflict_update(columns, actions, version=None):
    """
    Create a tuple for generating the status_changes ON CONFLICT UPDATE statement.

    The condition and the effective version come from prepare_conflict_update().
    A cast for the associated trip column(s) is added to the actions unless the
    caller already provided one: "associated_trips" before MDS 0.3.0,
    "associated_trip" from 0.3.0 on.

    Returns a tuple of (condition, actions). The actions mapping passed in is
    not modified; a copy is returned.
    """
    condition, version = prepare_conflict_update(columns, version)

    # copy first so caller-owned (possibly shared default) mappings stay untouched
    actions = dict(actions)

    if version < mds.Version("0.3.0"):
        actions.setdefault("associated_trips", "cast(EXCLUDED.associated_trips as uuid[])")
    else:
        actions.setdefault("associated_trip", "cast(EXCLUDED.associated_trip as uuid)")

    return condition, actions
def _parse_header(value):
    """
    Split a 'Header: value' string on its first colon into a (name, value) tuple of stripped strings.

    Raises argparse.ArgumentTypeError when there is no colon, so argparse
    reports a usage error instead of an unhandled IndexError.
    """
    name, separator, content = value.partition(":")
    if not separator:
        raise argparse.ArgumentTypeError(f"expected 'Header: value', got '{value}'")
    return name.strip(), content.strip()


def setup_cli(**kwargs):
    """
    Set up the common command line arguments.

    Keyword arguments are passed to the ArgumentParser instance.

    Returns the ArgumentParser.
    """
    parser = argparse.ArgumentParser(**kwargs)

    # adjacent string literals are used instead of a backslash continuation
    # inside the literal, which would embed source indentation in the text
    parser.add_argument(
        "--auth_type",
        type=str,
        default="Bearer",
        help="The type used for the Authorization header for requests to the provider "
             "(e.g. Basic, Bearer)."
    )

    parser.add_argument(
        "--config",
        type=str,
        help="Path to a provider configuration file to use."
    )

    # repeated flags accumulate into args.headers as (name, value) tuples
    parser.add_argument(
        "-H",
        "--header",
        dest="headers",
        action="append",
        type=_parse_header,
        default=[],
        help="One or more 'Header: value' combinations, sent with each request."
    )

    parser.add_argument(
        "--output",
        type=str,
        help="Write results to json files in this directory."
    )

    parser.add_argument(
        "--version",
        type=lambda v: mds.Version(v),
        default=DEFAULT_VERSION,
        help=f"The release version at which to reference MDS, e.g. {DEFAULT_VERSION}"
    )

    return parser
def ingest(record_type, **kwargs):
    """
    Run the ingestion flow:

    1. acquire data from files or API
    2. optionally validate data, filtering invalid records
    3. optionally write data to output files
    4. optionally load valid records into the database

    The "version", "no_validate", "output" and "no_load" keyword arguments are
    consumed here; everything else is forwarded to the data and load helpers.
    """
    version = mds.Version(kwargs.pop("version", common.DEFAULT_VERSION))
    version.raise_if_unsupported()

    datasource = common.get_data(record_type, **kwargs, version=version)
    data_key = mds.Schema(record_type).data_key

    def record_count(payloads):
        # total number of records across all payloads
        return sum(len(payload["data"][data_key]) for payload in payloads)

    # validation and filtering
    skip_validation = kwargs.pop("no_validate", False)
    if skip_validation:
        print("Skipping data validation")
        valid, removed = datasource, None
    else:
        print(f"Validating {record_type} @ {version}")
        valid, _errors, removed = validation.validate(record_type, datasource, version=version)

        seen, passed, failed = record_count(datasource), record_count(valid), record_count(removed)
        print(f"{seen} records, {passed} passed, {failed} failed")

    # output to files if needed
    output = kwargs.pop("output", None)
    if output:
        data_file = mds.DataFile(record_type, output)
        data_file.dump_payloads(valid)
        if removed:
            data_file.dump_payloads(removed)

    # load to database
    skip_load = kwargs.pop("no_load", False)
    if skip_load or len(valid) == 0:
        print("Skipping data load")
    else:
        database.load(valid, record_type, **kwargs, version=version)

    print(f"{record_type} complete")
def load(datasource, record_type, **kwargs):
    """
    Load data into a database.

    Recognized keyword arguments:
        - version: the MDS version; defaults to common.DEFAULT_VERSION
        - columns: column names used to drop duplicates; defaults to COLUMNS[record_type]
        - update_actions: a list whose items are either True (flag-only, meaning
          "use the defaults") or (column, expression) tuples
        - stage_first: passed to mds.Database (as an int); defaults to True
        - db: an existing database object to use instead of creating one

    Raises ValueError when the record_type is not available at the given version.
    """
    print(f"Loading {record_type}")

    version = mds.Version(kwargs.pop("version", common.DEFAULT_VERSION))
    version.raise_if_unsupported()

    # record type names are referenced through the mds module, as everywhere else here
    if version < mds.Version._040_() and record_type not in [mds.STATUS_CHANGES, mds.TRIPS]:
        raise ValueError(f"MDS Version {version} only supports {mds.STATUS_CHANGES} and {mds.TRIPS}.")
    elif version < mds.Version._041_() and record_type == mds.VEHICLES:
        raise ValueError(f"MDS Version {version} does not support the {mds.VEHICLES} endpoint.")

    columns = kwargs.pop("columns", None) or []
    if len(columns) == 0:
        columns = COLUMNS[record_type]

    actions = kwargs.pop("update_actions", None) or []
    if len(actions) == 1 and actions[0] is True:
        # flag-only option, use defaults
        actions = default_conflict_update_actions(record_type, version)
    else:
        # convert action tuples to dict, filtering any flag-only options;
        # this also covers a single custom tuple, which must become a dict too
        actions = dict(action for action in actions if action is not True)

    stage_first = int(kwargs.pop("stage_first", True))

    # only build a Database (and read the environment) when none was supplied
    db = kwargs.get("db")
    if db is None:
        db = mds.Database(stage_first=stage_first, version=version, **env())

    load_config = dict(table=record_type, drop_duplicates=columns)
    if len(actions) > 0:
        load_config["on_conflict_update"] = conflict_update_condition(columns), actions

    if record_type == mds.EVENTS:
        db.load_events(datasource, **load_config)
    elif record_type == mds.STATUS_CHANGES:
        db.load_status_changes(datasource, **load_config)
    elif record_type == mds.TRIPS:
        db.load_trips(datasource, **load_config)
    elif record_type == mds.VEHICLES:
        db.load_vehicles(datasource, **load_config)
def get_data(record_type, **kwargs):
    """
    Get provider data as in-memory objects.

    With a truthy "source" keyword argument, payloads are read from that file
    source. Otherwise "client", "start_time" and "end_time" are required and
    the data is requested through the client.
    """
    source = kwargs.get("source")
    if source:
        print(f"Reading {record_type} from {source}")
        return mds.DataFile(record_type, source).load_payloads()

    # required for API calls
    client = kwargs.pop("client")
    start_time = kwargs.pop("start_time")
    end_time = kwargs.pop("end_time")

    version = kwargs.get("version")

    # package up for API requests
    api_kwargs = {
        "paging": not kwargs.get("no_paging"),
        "rate_limit": kwargs.get("rate_limit"),
    }

    print(f"Requesting {record_type} from {client.provider.provider_name}")
    print(f"Time range: {start_time.isoformat()} to {end_time.isoformat()}")

    if record_type == mds.STATUS_CHANGES:
        api_kwargs.update(start_time=start_time, end_time=end_time)
    elif record_type == mds.TRIPS:
        api_kwargs.update(
            device_id=kwargs.get("device_id"),
            vehicle_id=kwargs.get("vehicle_id"),
        )
        # the trips time parameters are named differently from 0.3.0 on
        if version < mds.Version("0.3.0"):
            start_key, end_key = "start_time", "end_time"
        else:
            start_key, end_key = "min_end_time", "max_end_time"
        api_kwargs[start_key] = start_time
        api_kwargs[end_key] = end_time

    return client.get(record_type, **api_kwargs)
def _validate_provider(provider, **kwargs):
    """
    Validate the feeds for a provider.

    Exits the process with status 1 when only one of start_time / end_time is
    given. Raises mds.UnsupportedVersionError for an unsupported version.
    """
    has_start = bool(kwargs.get("start_time"))
    has_end = bool(kwargs.get("end_time"))

    # assert the time parameters -> if giving one, both must be given
    if has_start != has_end:
        print(
            "Both --start_time and --end_time are required for custom query ranges."
        )
        exit(1)

    if has_start or has_end:
        # parse from user input
        start, end = common.parse_time_range(**kwargs)
    else:
        # default to the hour beginning 25 hours before the current time
        end = datetime.datetime.utcnow() - datetime.timedelta(days=1)
        start = end - datetime.timedelta(seconds=3600)

    kwargs.update(start_time=start, end_time=end)

    config = common.get_config(provider, kwargs.get("config"))

    # assert the version parameter; a version in the provider config wins over the kwarg
    version = mds.Version(config.pop("version", kwargs.get("version")))
    if version.unsupported:
        raise mds.UnsupportedVersionError(version)

    kwargs["version"] = version
    kwargs["no_paging"] = False
    kwargs["rate_limit"] = 0
    kwargs["client"] = mds.Client(provider, version=version, **config)

    return _validate(**kwargs)
""" Helper functions for shared functionality. """ import argparse import datetime import pathlib import mds default_version = mds.Version("0.3.2") def count_seconds(ts): """ Return the number of seconds since a given UNIX datetime. """ return round((datetime.datetime.utcnow() - ts).total_seconds()) def get_config(provider, config_path=None): """ Obtain provider's configuration data from the given file path, or the default file path if None. """ if config_path: return mds.ConfigFile(config_path, provider).dump() elif pathlib.Path("./config.json").exists(): return mds.ConfigFile("./config.json", provider).dump() else: return {}
def setup_cli():
    """
    Create the cli argument interface, and parses incoming args.

    Arguments are declared in a table and registered in order. Vehicle and
    propulsion types left unspecified by the user are filled in from the trips
    schema at the requested MDS version.

    Returns a tuple:
        - the argument parser
        - the parsed args
    """
    # used to display CLI options for vehicle_types and propulsion_types
    schema = mds.Schema(mds.TRIPS, 'master')

    options = [
        ("--boundary", dict(
            type=str,
            help="Path to a data file with geographic bounds for the generated data. Overrides the MDS_BOUNDARY environment variable."
        )),
        ("--close", dict(
            type=int,
            default=19,
            help="The hour of the day (24-hr format) that provider stops operations. Overrides --start and --end."
        )),
        ("--date_format", dict(
            type=str,
            default="unix",
            help="Format for datetime input (to this CLI) and output (to stdout and files). Options:\
                - 'unix' for Unix timestamps (default)\
                - 'iso8601' for ISO 8601 format\
                - '<python format string>' for custom formats,\
                see https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior"
        )),
        ("--devices", dict(
            type=int,
            help="The number of devices to model in the generated data"
        )),
        ("--end", dict(
            type=str,
            help="The latest event in the generated data, in --date_format format"
        )),
        ("--inactivity", dict(
            type=float,
            help="Describes the portion of the fleet that remains inactive."
        )),
        ("--open", dict(
            type=int,
            default=7,
            help="The hour of the day (24-hr format) that provider begins operations. Overrides --start and --end."
        )),
        ("--output", dict(
            type=str,
            help="Path to a directory to write the resulting data file(s)"
        )),
        ("--propulsion_types", dict(
            type=str,
            nargs="+",
            default=[],  # to be filled in below if not specified by user
            metavar="PROPULSION_TYPE",
            help="A list of propulsion_types to use for the generated data, e.g. '{}'".format(
                " ".join(schema.propulsion_types))
        )),
        ("--provider_name", dict(
            type=str,
            help="The name of the fake mobility as a service provider"
        )),
        ("--provider_id", dict(
            type=uuid.UUID,
            help="The ID of the fake mobility as a service provider"
        )),
        ("--start", dict(
            type=str,
            help="The earliest event in the generated data, in --date_format format"
        )),
        ("--speed_mph", dict(
            type=float,
            help="The average speed of devices in miles per hour. Cannot be used with --speed_ms"
        )),
        ("--speed_ms", dict(
            type=float,
            help="The average speed of devices in meters per second. Always takes precedence"
        )),
        ("--vehicle_types", dict(
            type=str,
            nargs="+",
            default=[],  # to be filled in below if not specified by user
            metavar="VEHICLE_TYPE",
            help="A list of vehicle_types to use for the generated data, e.g. '{}'".format(
                " ".join(schema.vehicle_types))
        )),
        ("--version", dict(
            type=lambda v: mds.Version(v),
            default=mds.Version("0.2.1"),
            help="The release version at which to reference MDS, e.g. 0.3.1"
        )),
    ]

    parser = argparse.ArgumentParser()
    for flag, settings in options:
        parser.add_argument(flag, **settings)

    args = parser.parse_args()

    # use the specified MDS schema for filling in vehicle and propulsion types if not specified by user
    trips_schema = mds.Schema(mds.TRIPS, args.version)
    args.vehicle_types = args.vehicle_types or trips_schema.vehicle_types
    args.propulsion_types = args.propulsion_types or trips_schema.propulsion_types

    return parser, args
arg_parser, args = setup_cli()

# assert the data type parameters
requested_feeds = [args.events, args.status_changes, args.trips, args.vehicles]
if not any(requested_feeds):
    for message in (
        "At least one of --events, --status_changes, --trips, or --vehicles is required.",
        "Run main.py --help for more information.",
        "Exiting.",
    ):
        print(message)
    print()
    exit(1)

print(f"Starting ingestion run: {now.isoformat()}")

config = common.get_config(args.provider, args.config)

# assert the version parameter; a version in the provider config wins over the CLI value
args.version = mds.Version(config.pop("version", args.version))
args.version.raise_if_unsupported()

print(f"Referencing MDS @ {args.version}")

# shortcut for loading from files
if args.source:
    # (CLI flag, name of the record type constant on the mds module), in run order;
    # the constant is only looked up for feeds that were actually requested
    feed_flags = (
        ("events", "EVENTS"),
        ("status_changes", "STATUS_CHANGES"),
        ("trips", "TRIPS"),
        ("vehicles", "VEHICLES"),
    )
    for flag_name, constant_name in feed_flags:
        if getattr(args, flag_name):
            ingest(getattr(mds, constant_name), **vars(args))
    # finished
def validate(record_type, sources, version, **kwargs):
    """
    Partition sources into a tuple of (valid, errors, removed)

        - valid: the sources with remaining valid data records
        - errors: a list of mds.schemas.DataValidationError
        - removed: the sources with invalid data records

    A source with a payload-level failure contributes its errors only; it is
    added to neither valid nor removed.

    A "validator" keyword argument may supply the validator to use; otherwise
    one is created for the record_type and version.

    Raises TypeError when sources is not a list of payload dicts, and
    mds.versions.UnexpectedVersionError when a source's version differs from
    the given version.
    """
    if not all(isinstance(d, dict) and "data" in d for d in sources):
        raise TypeError(
            "Sources appears to be the wrong data type. Expected a list of payload dicts."
        )

    source_versions = [mds.Version(d["version"]) for d in sources]
    # report the version that actually mismatches, not simply the first source's version
    mismatched = next((v for v in source_versions if version != v), None)
    if mismatched is not None:
        raise mds.versions.UnexpectedVersionError(mismatched, version)

    valid = []
    errors = []
    removed = []

    # only build the default validator when the caller did not supply one
    validator = kwargs.get("validator")
    if validator is None:
        validator = _validator(record_type, version)

    for source in sources:
        records = list(source.get("data", {}).get(record_type, []))
        invalid_source = False
        invalid_idx = set()

        # schema validation
        for error in validator.validate(source):
            errors.append(error)
            failure, idx = _failure(record_type, error)
            invalid_source = invalid_source or failure

            # this was a problem with a single item, mark it for removal
            if not failure and isinstance(idx, int):
                invalid_idx.add(idx)

        # only filter invalid items if the overall payload was OK
        if invalid_source:
            continue

        # partition by position; a single pass avoids repeated list.index() scans
        # and keeps equal-but-distinct records apart
        valid_records = [r for i, r in enumerate(records) if i not in invalid_idx]
        invalid_records = [r for i, r in enumerate(records) if i in invalid_idx]

        if len(valid_records) > 0:
            # create a copy to preserve the original payload
            valid.append({**source, "data": {record_type: valid_records}})

        if len(invalid_records) > 0:
            # create a copy to preserve the original payload
            removed.append({**source, "data": {record_type: invalid_records}})

    return valid, errors, removed
""" Helper functions for shared functionality. """ import argparse import datetime import pathlib import mds DEFAULT_VERSION = mds.Version("0.3.2") VERSION_040 = mds.Version("0.4.0") def count_seconds(ts): """ Return the number of seconds since a given UNIX datetime. """ return round((datetime.datetime.utcnow() - ts).total_seconds()) def get_config(provider, config_path=None): """ Obtain provider's configuration data from the given file path, or the default file path if None. """ if config_path: return mds.ConfigFile(config_path, provider).dump() elif pathlib.Path("./config.json").exists(): return mds.ConfigFile("./config.json", provider).dump() else: