Example #1
def create_merged_stripe(files, used_idxs, s3_file_name, n_x, y, trace_span):
    with tracing.start_span('parallel stripe loading', parent=trace_span):
        with concurrent.futures.ThreadPoolExecutor(10) as executor:
            contents = list(
                executor.map(partial(load_stripe, used_idxs, y, n_x), files))

    with tracing.start_span('merged stripe save', parent=trace_span):
        d = numpy.concatenate(contents, axis=1).tobytes()
        s3_put(f"{y}/{s3_file_name}", d)
Example #2
def ingest_from_queue():
    q = get_queue()
    for ingest_req in q:
        # Queue is empty for now
        if ingest_req is None:
            logger.info("Empty queue")
            break

        ingest_req = ingest_req.data

        # Expire out anything whose run time is very old (probably a bad request/URL)
        if datetime.utcfromtimestamp(
                ingest_req['run_time']) < datetime.utcnow() - timedelta(
                    hours=12):
            logger.info("Expiring old request %s", ingest_req)
            continue

        # If either URL doesn't exist yet, try again in a few minutes
        if not (url_exists(ingest_req['url'])
                and url_exists(ingest_req['idx_url'])):
            logger.info("Rescheduling request %s", ingest_req)
            q.put(ingest_req, '5m')
            continue

        with tracing.start_span('ingest item') as span:
            for k, v in ingest_req.items():
                span.set_attribute(k, v)

            try:
                source = Source.query.filter_by(
                    short_name=ingest_req['source']).first()

                with tempfile.NamedTemporaryFile() as reduced:
                    with tracing.start_span('download'):
                        logger.info(
                            f"Downloading and reducing {ingest_req['url']} from {ingest_req['run_time']} {source.short_name}"
                        )
                        reduce_grib(ingest_req['url'], ingest_req['idx_url'],
                                    source.fields, reduced)
                    with tracing.start_span('ingest'):
                        logging.info("Ingesting all")
                        ingest_grib_file(reduced.name, source)

                source.last_updated = datetime.utcnow()

                db.session.commit()
            except KeyboardInterrupt:
                raise
            except Exception:
                logger.exception("Exception while ingesting %s. Will retry",
                                 ingest_req)
                q.put(ingest_req, '5m')
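
The queue helpers (get_queue and the delayed q.put retry) aren't shown in this excerpt. The sketch below only spells out the interface the loop above appears to assume; it is not the project's implementation, and every name besides put is illustrative.

# Hypothetical interface sketch only, inferred from how the loop uses the queue.
from typing import Iterator, Optional, Protocol


class QueueItem(Protocol):
    data: dict  # the ingest request payload


class IngestQueue(Protocol):
    def __iter__(self) -> Iterator[Optional[QueueItem]]:
        """Yields queued items, or None once the queue is currently empty."""

    def put(self, payload: dict, delay: Optional[str] = None) -> None:
        """(Re-)enqueues a payload, optionally after a delay such as '5m'."""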
Example #3
def ingest_grib_file(file_path, source):
    """
    Ingests a given GRIB file into the backend.
    :param file_path: Path to the GRIB file
    :param source: Source object which denotes which source this data is from
    :return: None
    """
    logger.info("Processing GRIB file '%s'", file_path)

    grib = pygrib.open(file_path)

    # Holds all of the data we'll be inserting at the end.
    # Map of proj_id to map of {(field_id, valid_time, run_time) -> [msg.values, ...]}
    data_by_projection = collections.defaultdict(lambda: collections.defaultdict(list))

    for field in SourceField.query.filter(SourceField.source_id == source.id, SourceField.metric.has(Metric.intermediate == False)).all():
        try:
            msgs = grib.select(**field.selectors)
        except ValueError:
            logger.warning("Could not find message(s) in grib matching selectors %s", field.selectors)
            continue

        for msg in msgs:
            with tracing.start_span('parse message') as span:
                span.set_attribute('message', str(msg))

                if field.projection is None or field.projection.params != msg.projparams:
                    projection = get_or_create_projection(msg)
                    field.projection_id = projection.id
                    db.session.commit()

                valid_date = get_end_valid_time(msg)
                data_by_projection[field.projection.id][(field.id, valid_date, msg.analDate)].append(msg.values)

    with tracing.start_span('generate derived'):
        logger.info("Generating derived fields")
        for proj_id, fields in get_source_module(source.short_name).generate_derived(grib).items():
            for k, v in fields.items():
                data_by_projection[proj_id][k].extend(v)

    with tracing.start_span('save denormalized'):
        logger.info("Saving denormalized location/time data for all messages")
        for proj_id, fields in data_by_projection.items():
            create_files(proj_id, fields)

    logger.info("Done saving denormalized data")
Example #4
def merge():
    """
    Merge all (small) files into larger files to reduce the number of S3 requests each query needs to make.
    """
    all_files = FileMeta.query.filter(
        FileMeta.file_name.in_(
            FileBandMeta.query.filter(
                FileBandMeta.valid_time > datetime.utcnow()).with_entities(
                    FileBandMeta.file_name)), ).order_by(
                        FileMeta.loc_size.asc(), ).all()

    proj_files = collections.defaultdict(list)
    for f in all_files:
        proj_files[f.projection].append(f)

    # Pull from the projection with the most backlog first
    for proj, files_for_proj in sorted(proj_files.items(),
                                       key=lambda pair: len(pair[1]),
                                       reverse=True):
        # Don't waste time if we don't really have that many files
        if len(files_for_proj) < 8:
            continue

        # Merge in batches of at most 50 files to more quickly reduce the S3
        # load per query; small backlogs (fewer than 40 files) go in one batch.
        batch_size = min(ceil(len(files_for_proj) / 4), 50)
        if len(files_for_proj) < 40:
            batch_size = len(files_for_proj)

        for files in chunk(files_for_proj, batch_size):
            # This next part is all about figuring out what items are still used in
            # each file so that the merge process can effectively garbage collect
            # unused data.

            # Dict of FileMeta -> list of float32 item indexes still used by some band
            used_idxs = collections.defaultdict(list)

            offset = 0
            # Dict of FileBandMeta -> offset
            new_offsets = {}

            for f in files:
                for band in f.bands:
                    # Don't bother merging old data. This prevents racing with the
                    # cleaner, and old data probably won't be queried anyway.
                    if band.valid_time < datetime.utcnow():
                        continue

                    new_offsets[band] = offset
                    offset += 4 * band.vals_per_loc

                    start_idx = band.offset // 4
                    used_idxs[f].extend(
                        range(start_idx, start_idx + band.vals_per_loc))

            s3_file_name = hashlib.md5(
                ('-'.join(f.file_name
                          for f in files)).encode('utf-8')).hexdigest()

            merged_meta = FileMeta(
                file_name=s3_file_name,
                projection_id=proj.id,
                loc_size=offset,
            )
            db.session.add(merged_meta)

            logger.info("Merging %s into %s",
                        ','.join(f.file_name for f in files), s3_file_name)

            n_y, n_x = proj.shape()

            # If we fail to create any merged stripe, don't commit the changes to
            # band offset/file name, but _do_ commit the FileMeta to the DB.
            # This way the normal cleaning process will remove any orphaned bands.
            commit_merged = True

            # max_workers = 10 to limit memory utilization.
            # In the approximate worst case, each row holds
            # (5 sources * 70 runs * 2000 units wide * 20 metrics/unit * 4 bytes per metric),
            # or ~50MB/row in memory.
            # 10 rows keeps us well under the 1GB this job should be provisioned for.
            with tracing.start_span('parallel stripe creation') as span:
                span.set_attribute("s3_file_name", s3_file_name)
                span.set_attribute("num_files", len(files))

                with concurrent.futures.ThreadPoolExecutor(10) as executor:
                    futures = concurrent.futures.wait([
                        executor.submit(create_merged_stripe, files, used_idxs,
                                        s3_file_name, n_x, y, span)
                        for y in range(n_y)
                    ])
                    for fut in futures.done:
                        if fut.exception() is not None:
                            logger.error("Exception merging: %s",
                                         fut.exception())
                            commit_merged = False

                span.set_attribute("commit", commit_merged)

            if commit_merged:
                for band, offset in new_offsets.items():
                    band.offset = offset
                    band.file_name = merged_meta.file_name

                logger.info("Updated file band meta")

            db.session.commit()

        # We know we won't need this projection again, so clear it
        clear_proj_cache()
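
To make the garbage-collection bookkeeping concrete, here is a toy, self-contained version of the offset/used_idxs loop above with made-up band sizes; it is an illustration, not project code.

import collections

# Hypothetical bands: (file name, vals_per_loc, old offset in bytes)
bands = [
    ("file_a", 2, 0),
    ("file_a", 1, 8),
    ("file_b", 3, 0),
]

used_idxs = collections.defaultdict(list)
new_offsets = {}
offset = 0

for file_name, vals_per_loc, old_offset in bands:
    new_offsets[(file_name, old_offset)] = offset
    offset += 4 * vals_per_loc  # float32 -> 4 bytes per value
    start_idx = old_offset // 4
    used_idxs[file_name].extend(range(start_idx, start_idx + vals_per_loc))

print(new_offsets)      # {('file_a', 0): 0, ('file_a', 8): 8, ('file_b', 0): 12}
print(dict(used_idxs))  # {'file_a': [0, 1, 2], 'file_b': [0, 1, 2]}
print(offset)           # 24 -> loc_size of the merged file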
Example #5


if __name__ == "__main__":
    init_sentry()
    logging.basicConfig(level=logging.INFO)
    init_tracing('merge')
    with tracing.start_span('merge'):
        merge()
Example #6


if __name__ == "__main__":
    init_sentry()
    logging.basicConfig(level=logging.INFO)
    init_tracing('queue_worker')
    with tracing.start_span('queue worker'):
        ingest_from_queue()
Example #7
def summarize():
    """
    Summarizes the weather in a natural way.
    Returns a list of summary objects, one per requested day.
    """
    lat = float(request.args['lat'])
    lon = float(request.args['lon'])
    start = request.args.get('start', type=int)
    days = int(request.args['days'])

    if lat > 90 or lat < -90 or lon > 180 or lon < -180:
        abort(400)

    if days > 10:
        abort(400)

    # TODO: This should be done relative to the location's local TZ
    now = datetime.utcnow()
    if start is None:
        start = now
    else:
        start = datetime.utcfromtimestamp(start)

        if not app.debug:
            if start < now - timedelta(days=1):
                start = now - timedelta(days=1)

    source_fields = SourceField.query.filter(
        or_(
            SourceField.metric == metrics.temp,
            SourceField.metric == metrics.raining,
            SourceField.metric == metrics.snowing,
            SourceField.metric_id.in_([metrics.wind_speed.id, metrics.wind_direction.id, metrics.gust_speed.id]),
            SourceField.metric == metrics.cloud_cover,
            SourceField.metric == metrics.composite_reflectivity,
        ),
        SourceField.projection_id != None,  # noqa: E711
    ).all()

    with tracing.start_span("load_data_points") as span:
        end = start + timedelta(days=days)
        span.set_attribute("start", str(start))
        span.set_attribute("end", str(end))
        span.set_attribute("source_fields", str(source_fields))
        data_points = load_data_points((lat, lon), start, end, source_fields)

    with tracing.start_span("combine_models") as span:
        combined_data_points = combine_models(data_points)

    time_ranges = [(start, start.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1))]
    for d in range(1, days):
        last_end = time_ranges[-1][1]
        time_ranges.append((last_end, last_end + timedelta(days=1)))

    summarizations = []

    with tracing.start_span("summarizations") as span:
        for dstart, dend in time_ranges:
            summary = SummarizedData(dstart, dend, combined_data_points)
            summarizations.append(summary.dict())

    return jsonify(summarizations)
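
The first range covers the remainder of the starting day (up to the next UTC midnight) and every later range is a full day. A quick worked example with made-up dates:

from datetime import datetime, timedelta

start = datetime(2021, 3, 14, 15, 30)
days = 3

time_ranges = [(start, start.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1))]
for d in range(1, days):
    last_end = time_ranges[-1][1]
    time_ranges.append((last_end, last_end + timedelta(days=1)))

# time_ranges is now:
#   (2021-03-14 15:30, 2021-03-15 00:00)  partial first day
#   (2021-03-15 00:00, 2021-03-16 00:00)
#   (2021-03-16 00:00, 2021-03-17 00:00)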
Example #8
def wx_for_location():
    """
    Gets the weather for a specific location, optionally limiting by metric and time.
    at that time.
    """
    lat = float(request.args['lat'])
    lon = float(request.args['lon'])

    if lat > 90 or lat < -90 or lon > 180 or lon < -180:
        abort(400)

    requested_metrics = request.args.getlist('metrics', int)

    if requested_metrics:
        metric_ids = set(requested_metrics)
    else:
        metric_ids = Metric.query.with_entities(Metric.id)

    now = datetime.utcnow()
    start = request.args.get('start', type=int)
    end = request.args.get('end', type=int)

    if start is None:
        start = now - timedelta(hours=1)
    else:
        start = datetime.utcfromtimestamp(start)

        if not app.debug:
            if start < now - timedelta(days=1):
                start = now - timedelta(days=1)

    if end is None:
        end = now + timedelta(hours=12)
    else:
        end = datetime.utcfromtimestamp(end)

        if not app.debug:
            if end > now + timedelta(days=7):
                end = now + timedelta(days=7)

    requested_source_fields = SourceField.query.filter(
        SourceField.metric_id.in_(metric_ids),
        SourceField.projection_id != None,  # noqa: E711
    ).all()

    with tracing.start_span("load_data_points") as span:
        span.set_attribute("start", str(start))
        span.set_attribute("end", str(end))
        span.set_attribute("source_fields", str(requested_source_fields))
        data_points = load_data_points((lat, lon), start, end, requested_source_fields)

    # valid time -> data points
    datas = collections.defaultdict(list)

    for dp in data_points:
        datas[datetime2unix(dp.valid_time)].append({
            'run_time': datetime2unix(dp.run_time),
            'src_field_id': dp.source_field_id,
            'value': dp.median(),
            'raw_values': dp.values,
        })

    wx = {
        'data': datas,
        'ordered_times': sorted(datas.keys()),
    }

    return jsonify(wx)
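
For reference, the JSON payload returned by this handler has the shape sketched below; every value is made up, and jsonify serializes the integer timestamp keys as strings.

# Illustrative response shape only; all values are hypothetical.
wx = {
    "data": {
        1615712400: [
            {
                "run_time": 1615690800,
                "src_field_id": 3,
                "value": 280.4,
                "raw_values": [280.1, 280.7],
            },
        ],
    },
    "ordered_times": [1615712400],
}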
Example #9
def load_data_points(
    coords: Tuple[float, float],
    start: datetime.datetime,
    end: datetime.datetime,
    source_fields: Optional[Iterable[SourceField]] = None
) -> List[DataPointSet]:

    if source_fields is None:
        source_fields = SourceField.query.all()

    # Determine all valid source fields (fields in source_fields which cover the given coords),
    # and the (x, y) grid location for each projection those fields use.
    valid_source_fields = []
    locs: Dict[int, Tuple[float, float]] = {}
    for sf in source_fields:
        if sf.projection_id in locs and locs[sf.projection_id] is None:
            continue

        if sf.projection_id not in locs:
            with tracing.start_span("get_xy_for_coord") as span:
                span.set_attribute("projection_id", sf.projection_id)
                loc = get_xy_for_coord(sf.projection, coords)

            # Cache the lookup result (even when None) so other fields on the
            # same projection don't repeat it.
            locs[sf.projection_id] = loc

            # Skip if given projection does not cover coords
            if loc is None:
                continue

        valid_source_fields.append(sf)

    with tracing.start_span("load file band metas") as span:
        fbms: List[FileBandMeta] = FileBandMeta.query.filter(
            FileBandMeta.source_field_id.in_(
                [sf.id for sf in valid_source_fields]),
            FileBandMeta.valid_time >= start,
            FileBandMeta.valid_time < end,
        ).all()

    # Gather all files we need data from
    file_metas = set(fbm.file_meta for fbm in fbms)

    file_contents = {}

    # Read them in (in parallel)
    # TODO: use asyncio here instead once everything else is ported?
    with tracing.start_span("load file chunks") as span:
        span.set_attribute("num_files", len(file_metas))
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(load_file_chunk, fm, locs[fm.projection_id]):
                fm
                for fm in file_metas
            }
            for future in concurrent.futures.as_completed(futures):
                fm = futures[future]
                file_contents[fm.file_name] = future.result()

    # filebandmeta -> values
    data_points = []
    for fbm in fbms:
        raw = file_contents[fbm.file_name][fbm.offset:fbm.offset +
                                           (4 * fbm.vals_per_loc)]
        data_values: List[float] = array.array("f", raw).tolist()
        data_point = DataPointSet(
            values=data_values,
            metric_id=fbm.source_field.metric.id,
            valid_time=fbm.valid_time,
            source_field_id=fbm.source_field_id,
            run_time=fbm.run_time,
        )

        data_points.append(data_point)

    return data_points
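
The final decode step slices a band's contiguous run of float32s out of the per-location chunk. A toy, self-contained illustration with made-up numbers:

import array
import struct

# Pretend per-location chunk containing four float32 values.
chunk = struct.pack("4f", 280.15, 281.0, 0.0, 12.5)

offset = 4        # like fbm.offset: this band starts 4 bytes in
vals_per_loc = 2  # like fbm.vals_per_loc

raw = chunk[offset:offset + 4 * vals_per_loc]
values = array.array("f", raw).tolist()
print(values)     # [281.0, 0.0]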