Example #1
    def __init__(self, asdf_file_name, netsta_list='*'):

        self._data_path = asdf_file_name
        self._earth_radius = 6371  # km

        self.fds = FederatedASDFDataSet(asdf_file_name)
        # Gather station metadata
        netsta_list_subset = set(
            netsta_list.split()) if netsta_list != '*' else netsta_list
        self.netsta_list = []
        self.metadata = defaultdict(list)

        # accumulate (radius, colatitude, longitude) triplets per station
        rtps = []
        for netsta in self.fds.unique_coordinates.keys():
            if (netsta_list_subset != '*'):
                if netsta not in netsta_list_subset:
                    continue

            self.netsta_list.append(netsta)
            self.metadata[netsta] = self.fds.unique_coordinates[netsta]

            rtps.append([
                self._earth_radius,
                np.radians(90 - self.metadata[netsta][1]),  # latitude -> colatitude
                np.radians(self.metadata[netsta][0])  # longitude
            ])
        # end for

        rtps = np.array(rtps)
        xyzs = rtp2xyz(rtps[:, 0], rtps[:, 1], rtps[:, 2])

        # KD-tree over Cartesian station locations for neighbour queries
        self._tree = cKDTree(xyzs)
        self._cart_location = defaultdict(list)
        for i, ns in enumerate(self.netsta_list):
            self._cart_location[ns] = xyzs[i, :]
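The constructor above relies on an rtp2xyz helper that is not shown; a minimal sketch consistent with how it is called (1-D arrays of radius, colatitude and longitude, angles in radians), mapping spherical triplets to Cartesian coordinates so that neighbour searches can use a Euclidean KD-tree:

import numpy as np

def rtp2xyz(r, theta, phi):
    # r: radius, theta: colatitude (radians), phi: longitude (radians)
    xout = np.zeros((r.shape[0], 3))
    xout[:, 0] = r * np.sin(theta) * np.cos(phi)
    xout[:, 1] = r * np.sin(theta) * np.sin(phi)
    xout[:, 2] = r * np.cos(theta)
    return xout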
Example #2
def merge_results(output_path):
    search_strings = ['p_arrivals*', 's_arrivals*']
    output_fns = ['p_combined.txt', 's_combined.txt']

    for ss, ofn in zip(search_strings, output_fns):
        files = recursive_glob(output_path, ss)

        header = None
        data = set()
        for i, fn in enumerate(files):
            with open(fn, 'r') as f:
                lines = f.readlines()

            # keep the header line from the first file only
            if (i == 0):
                header = lines[0]
            # end if

            # de-duplicate arrival rows across all files
            for line in lines[1:]:
                data.add(line)
            # end for

            os.remove(fn)
        # end for

        with open(os.path.join(output_path, ofn), 'w+') as of:
            if (header):
                of.write(header)
            for line in data:
                of.write(line)
            # end for
        # end with
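merge_results depends on a recursive_glob helper that is not shown above; a plausible sketch using os.walk and fnmatch, assuming it returns full paths of files matching a pattern anywhere under a directory tree:

import fnmatch
import os

def recursive_glob(treeroot, pattern):
    results = []
    for base, dirs, files in os.walk(treeroot):
        for fn in fnmatch.filter(files, pattern):
            results.append(os.path.join(base, fn))
        # end for
    # end for
    return results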
Example #3
def test_get_stations():
    fds = FederatedASDFDataSet(asdf_file_list)

    rows = np.array(
        fds.get_stations('1900-01-01T00:00:00', '2100-01-01T00:00:00'))

    station_set = set()
    for n, s in rows[:, 0:2]:
        station_set.add((n, s))

    # There are eight stations in the h5 file
    assert len(station_set) == 8
Example #4
def test_get_coordinates():
    fds = FederatedASDFDataSet(asdf_file_list)

    rows = np.array(
        fds.get_stations('1900-01-01T00:00:00', '2100-01-01T00:00:00'))

    station_set = set()
    for n, s in rows[:, 0:2]:
        station_set.add((n, s))

    # we should have coordinates for each station
    assert len(fds.unique_coordinates) == len(station_set)
Example #5
    def get_unique_station_pairs(self, other_dataset, nn=1):
        pairs = set()
        for st1 in self.netsta_list:
            st2list = None
            if (nn != -1):
                if self == other_dataset:
                    # within the same dataset each station is its own nearest
                    # neighbour, so fetch one extra and drop the station itself
                    st2list = set(
                        self.get_closest_stations(st1,
                                                  other_dataset,
                                                  nn=nn + 1))
                    if st1 in st2list:
                        st2list.remove(st1)
                    st2list = list(st2list)
                else:
                    st2list = self.get_closest_stations(st1,
                                                        other_dataset,
                                                        nn=nn)
            else:
                st2list = other_dataset.netsta_list
            # end if

            for st2 in st2list:
                pairs.add((st1, st2))
            # end for
        # end for

        pairs_subset = set()
        for item in pairs:
            if (item[0] == item[1]): continue

            dup_item = (item[1], item[0])
            if (dup_item not in pairs_subset and item not in pairs_subset):
                pairs_subset.add(item)
            # end if
        # end for

        return list(pairs_subset)
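get_unique_station_pairs calls a get_closest_stations method that is not shown; a minimal sketch, assuming it wraps the cKDTree and _cart_location mapping built in the constructor of Example #1:

    def get_closest_stations(self, netsta, other_dataset, nn=1):
        assert netsta in self._cart_location, '%s not found' % netsta

        # query the other dataset's KD-tree with this station's location
        d, l = other_dataset._tree.query(self._cart_location[netsta], nn)
        l = np.atleast_1d(l)

        # when nn exceeds the number of stations in the tree, missing
        # neighbours are flagged with an index equal to the tree size
        return [other_dataset.netsta_list[i] for i in l
                if i < len(other_dataset.netsta_list)]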
Example #6
        def cull_pairs(pairs, keep_list_fn):
            result = set()
            pairs_set = set()

            for pair in pairs:
                pairs_set.add('%s.%s' % (pair[0], pair[1]))
            # end for

            with open(keep_list_fn, 'r') as f:
                keep_list = f.readlines()
            for keep_pair in keep_list:
                keep_pair = keep_pair.strip()
                if (len(keep_pair)):
                    knet1, ksta1, knet2, ksta2 = keep_pair.split('.')

                    keep_pair_alt = '%s.%s.%s.%s' % (knet2, ksta2, knet1,
                                                     ksta1)

                    if (keep_pair in pairs_set or keep_pair_alt in pairs_set):
                        result.add(('%s.%s' % (knet1, ksta1),
                                    '%s.%s' % (knet2, ksta2)))
                # end if
            # end for

            return list(result)
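A hypothetical usage sketch of cull_pairs (written here as if it were a free function); the keep-list file is assumed to hold one NET1.STA1.NET2.STA2 entry per line, and the station codes below are made up:

with open('/tmp/keep.txt', 'w') as f:
    f.write('AU.ARMA.AU.QIS\n')

pairs = [('AU.ARMA', 'AU.QIS'), ('AU.ARMA', 'AU.EIDS')]
print(cull_pairs(pairs, '/tmp/keep.txt'))  # [('AU.ARMA', 'AU.QIS')]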
Example #7
def test_get_local_net_sta_list():
    fds = FederatedASDFDataSet(asdf_file_list)

    local_netsta_list = list(fds.local_net_sta_list())
    rows = np.array(
        fds.get_stations('1900-01-01T00:00:00', '2100-01-01T00:00:00'))

    # Get a list of unique stations
    stations = set()
    for n, s in rows[:, 0:2]:
        stations.add((n, s))
    # end for

    # On serial runs, all stations should be allocated to rank 0
    assert len(local_netsta_list) == len(stations)
Example #8
0
    def get_stations(self,
                     starttime,
                     endtime,
                     network=None,
                     station=None,
                     location=None,
                     channel=None):
        starttime = UTCDateTime(starttime).timestamp
        endtime = UTCDateTime(endtime).timestamp

        # assemble the WHERE clause: optional metadata filters followed by
        # the mandatory time-window overlap test
        clauses = []
        if (network): clauses.append("net='%s'" % (network))
        if (station): clauses.append("sta='%s'" % (station))
        if (location): clauses.append("loc='%s'" % (location))
        if (channel): clauses.append("cha='%s'" % (channel))
        clauses.append('et>=%f and st<=%f' % (starttime, endtime))

        query = 'select * from wdb where %s group by net, sta, loc, cha' % \
                ' and '.join(clauses)

        rows = self.conn.execute(query).fetchall()
        results = set()
        for row in rows:
            # each row: data-source id, net, sta, loc, cha, start, end, tag
            ds_id, net, sta, loc, cha, st, et, tag = row

            rv = (net, sta, loc, cha,
                  self.asdf_station_coordinates[ds_id]['%s.%s' %
                                                       (net, sta)][0],
                  self.asdf_station_coordinates[ds_id]['%s.%s' %
                                                       (net, sta)][1])
            results.add(rv)
        # end for

        return list(results)
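A hypothetical usage sketch of get_stations, mirroring the test functions above; the network and channel filter values are made up, and each returned row is a (net, sta, loc, cha, lon, lat) tuple:

fds = FederatedASDFDataSet(asdf_file_list)
rows = fds.get_stations('2000-01-01T00:00:00', '2001-01-01T00:00:00',
                        network='AU', channel='BHZ')
for net, sta, loc, cha, lon, lat in rows:
    print('%s.%s.%s.%s (%.2f, %.2f)' % (net, sta, loc, cha, lon, lat))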
Example #9
def test_get_global_time_range():
    fds = FederatedASDFDataSet(asdf_file_list)

    rows = np.array(
        fds.get_stations('1900-01-01T00:00:00', '2100-01-01T00:00:00'))

    station_set = set()
    for n, s in rows[:, 0:2]:
        station_set.add((n, s))

    minlist = []
    maxlist = []
    for (n, s) in station_set:
        tmin, tmax = fds.get_global_time_range(n, s)
        minlist.append(tmin)
        maxlist.append(tmax)
    # end for

    tmin = UTCDateTime(np.array(minlist).min())
    tmax = UTCDateTime(np.array(maxlist).max())

    # ensure the aggregate min/max match the corresponding values in the db
    assert tmin == UTCDateTime('2000-01-01T00:00:00.000000Z')
    assert tmax == UTCDateTime('2002-01-01T00:00:00.000000Z')
Example #10
    def _load_events_helper(self):
        eventList = []
        poTimestamps = []
        lines = []

        for ifn, fn in enumerate(self.csv_files):
            print(('Reading %s' % (fn)))
            for iline, line in enumerate(open(fn, 'r').readlines()):
                lines.append(line)
            # end for
        # end for

        if (self.rank == 0):
            eid_set = set()
            for iline, line in enumerate(lines):
                if (line[0] == '#'):
                    items = line.split(',')
                    vals = list(map(float, items[1:]))

                    year = int(vals[0])
                    month = int(vals[1])
                    day = int(vals[2])
                    hour = int(vals[3])
                    minute = int(vals[4])
                    second = vals[5]

                    lon = vals[6]
                    lat = vals[7]
                    if (lon < -180 or lon > 180): continue
                    if (lat < -90 or lat > 90): continue

                    depth = vals[8] if vals[8] >= 0 else 0

                    mb = vals[10]
                    ms = vals[11]
                    mi = vals[12]
                    mw = vals[13]
                    mag = 0
                    magtype = 'mw'
                    if (mw > 0):
                        mag = mw
                        magtype = 'mw'
                    elif (ms > 0):
                        mag = ms
                        magtype = 'ms'
                    elif (mb > 0):
                        mag = mb
                        magtype = 'mb'
                    elif (mi > 0):
                        mag = mi
                        magtype = 'mi'
                    # end if

                    eventid = vals[-1]
                    if (eventid in eid_set):
                        raise RuntimeError(
                            'Duplicate event-id found. Aborting..')
                    else:
                        eid_set.add(eventid)
                    # end if

                    utctime = None
                    try:
                        utctime = UTCDateTime(year, month, day, hour, minute,
                                              second)
                    except Exception:
                        continue
                    # end try

                    origin = Origin(utctime, lat, lon, depth)
                    event = Event()
                    event.public_id = int(eventid)
                    event.preferred_origin = origin
                    event.preferred_magnitude = Magnitude(mag, magtype)

                    eventList.append(event)
                    poTimestamps.append(origin.utctime.timestamp)
                else:
                    # non-header lines are arrivals; store their line indices
                    # for parsing after the workload has been scattered
                    eventList[-1].preferred_origin.arrival_list.append(iline)
                # end if

            # end for
            eventList = split_list(eventList, self.nproc)
            poTimestamps = split_list(poTimestamps, self.nproc)
        # end if

        # scatter workload across all processors

        eventList = self.comm.scatter(eventList, root=0)
        poTimestamps = self.comm.scatter(poTimestamps, root=0)

        print(
            ('Processing %d events on rank %d' % (len(eventList), self.rank)))
        for e in eventList:
            lineIndices = copy.copy(e.preferred_origin.arrival_list)
            e.preferred_origin.arrival_list = []
            for lineIndex in lineIndices:
                items = lines[lineIndex].split(',')
                vals = list(map(float, items[8:]))

                year = int(vals[0])
                month = int(vals[1])
                day = int(vals[2])
                hour = int(vals[3])
                minute = int(vals[4])
                second = vals[5]

                utctime = None
                try:
                    utctime = UTCDateTime(year, month, day, hour, minute,
                                          second)
                except Exception:
                    continue
                # end try

                try:
                    lon = float(items[4])
                except ValueError:
                    lon = 0
                try:
                    lat = float(items[5])
                except ValueError:
                    lat = 0
                try:
                    elev = float(items[6])
                except ValueError:
                    elev = 0

                distance = vals[-1]
                a = Arrival(items[3].strip(), items[0].strip(),
                            items[2].strip(), items[1].strip(), lon, lat, elev,
                            items[7].strip(), utctime, distance)
                e.preferred_origin.arrival_list.append(a)
            # end for
        # end for

        self.eventList = eventList
        self.poTimestamps = np.array(poTimestamps)
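The event loader above (and the process function in Example #12) leans on a split_list helper that is not shown; a minimal sketch, assuming it partitions a list into npartitions roughly equal contiguous chunks:

def split_list(lst, npartitions):
    # distribute the remainder one element at a time over the first chunks
    k, m = divmod(len(lst), npartitions)
    return [lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]
            for i in range(npartitions)]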
Example #11
def plot_results(stations, results, output_basename):
    # collate indices for each channel for each station
    assert len(stations) == len(results)
    groupIndices = defaultdict(list)
    for i in np.arange(len(results)):
        groupIndices['%s.%s' % (stations[i][0], stations[i][1])].append(i)
    # end for

    # gather number of days of usable data for each station
    usableStationDays = defaultdict(int)
    for k, v in groupIndices.items():
        for index in v:
            x, means = results[index]

            means = np.array(means)
            days = np.sum(~np.isnan(means) & (means != 0))
            usableStationDays[k] = max(usableStationDays[k], days)
        # end for
    # end for

    # colour-scale limits over the per-station usable-day counts
    maxUsableDays = max(usableStationDays.values())
    minUsableDays = min(usableStationDays.values())

    # Plot station map
    pdf = PdfPages('%s.pdf' % output_basename)

    fig = plt.figure(figsize=(20, 30))
    ax1 = fig.add_axes([0.05, 0.05, 0.9, 0.7])
    ax2 = fig.add_axes([0.05, 0.7, 0.9, 0.3])
    ax2.set_visible(False)

    minLon = 1e32
    maxLon = -1e32
    minLat = 1e32
    maxLat = -1e32
    for s in stations:
        lon, lat = s[4], s[5]
        if lon < 0:
            lon += 360
        # end if
        minLon = min(minLon, lon)
        maxLon = max(maxLon, lon)
        minLat = min(minLat, lat)
        maxLat = max(maxLat, lat)
    # end for

    minLon -= 1
    maxLon += 1
    minLat -= 1
    maxLat += 1

    m = Basemap(ax=ax1, projection='merc',
                resolution='i', llcrnrlat=minLat, urcrnrlat=maxLat,
                llcrnrlon=minLon, urcrnrlon=maxLon,
                lat_0=(minLat + maxLat) / 2., lon_0=(minLon + maxLon) / 2.)
    # draw coastlines.
    m.drawcoastlines()

    # draw grid
    parallels = np.linspace(np.around(minLat / 5) * 5 - 5, np.around(maxLat / 5) * 5 + 5, 6)
    m.drawparallels(parallels, labels=[True, True, False, False], fontsize=20)
    meridians = np.linspace(np.around(minLon / 5) * 5 - 5, np.around(maxLon / 5) * 5 + 5, 6)
    m.drawmeridians(meridians, labels=[False, False, True, True], fontsize=20)

    # plot stations
    norm = matplotlib.colors.Normalize(vmin=minUsableDays, vmax=maxUsableDays, clip=True)
    mapper = cm.ScalarMappable(norm=norm, cmap=cm.jet_r)
    plotted = set()
    for s in stations:
        if s[1] in plotted:
            continue
        else:
            plotted.add(s[1])
        # end if

        lon, lat = s[4], s[5]

        px, py = m(lon, lat)
        pxl, pyl = m(lon, lat - 0.1)
        days = usableStationDays['%s.%s' % (s[0], s[1])]
        m.scatter(px, py, s=400, marker='v',
                  c=mapper.to_rgba(days),
                  edgecolor='none', label='%s: %d' % (s[1], days))
        ax1.annotate(s[1], xy=(px + 0.05, py + 0.05), fontsize=22)
    # end for

    fig.axes[0].set_title("Network Name: %s" % s[0], fontsize=30, y=1.05)
    fig.axes[0].legend(prop={'size': 16}, loc=(0.2, 1.3),
                       ncol=5, fancybox=True, title='No. of Usable Days',
                       title_fontsize=16)

    pdf.savefig()
    plt.close()

    # Plot results
    for k, v in groupIndices.items():
        axesCount = 0
        for i in v:
            assert (k == '%s.%s' % (stations[i][0], stations[i][1]))
            # only need axes for non-null results
            a, b = results[i]
            if len(a) and len(b):
                axesCount += 1
            # end if
        # end for
        fig, axes = plt.subplots(axesCount, sharex=True)
        fig.set_size_inches(20, 15)

        axes = np.atleast_1d(axes)

        if len(axes):
            axesIdx = 0
            for i, index in enumerate(v):
                try:
                    x, means = results[index]

                    if len(x) and len(means):
                        x = [a.matplotlib_date for a in x]
                        d = np.array(means)

                        if len(d):
                            d[0] = np.nanmedian(d)
                        # end if

                        dnorm = d
                        dnormmin = np.nanmin(dnorm)
                        dnormmax = np.nanmax(dnorm)

                        axes[axesIdx].scatter(x, dnorm, marker='.', s=20)
                        axes[axesIdx].plot(x, dnorm, c='k', label='24 hr mean\n'
                                                                  'Gaps indicate no-data', lw=2, alpha=0.7)
                        axes[axesIdx].grid(axis='x', linestyle=':', alpha=0.3)

                        axes[axesIdx].fill_between(x, dnormmax * np.int_(d == 0), dnormmin * np.int_(d == 0),
                                                   where=dnormmax * np.int_(d == 0) - dnormmin * np.int_(d == 0) > 0,
                                                   color='r', alpha=0.5, label='All 0 Samples')

                        axes[axesIdx].fill_between(x, dnormmax * np.int_(np.isnan(d)), dnormmin * np.int_(np.isnan(d)),
                                                   where=dnormmax * np.int_(np.isnan(d)) - dnormmin * np.int_(
                                                       np.isnan(d)) > 1,
                                                   color='b', alpha=0.5, label='No Data')

                        axes[axesIdx].xaxis.set_major_locator(AutoDateLocator())
                        axes[axesIdx].xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))

                        for tick in axes[axesIdx].get_xticklabels():
                            tick.set_rotation(45)
                        # end for
                        axes[axesIdx].legend(loc='upper right', prop={'size': 12})
                        axes[axesIdx].tick_params(axis='both', labelsize=16)
                        stn = stations[index]
                        axes[axesIdx].set_title('Channel %s.%s' % (stn[2], stn[3]),
                                                fontsize=18, y=0.95, va='top')
                        axes[axesIdx].set_xlim(xmin=min(x), xmax=max(x))
                        axes[axesIdx].set_ylim(ymin=dnormmin, ymax=dnormmax)
                        axes[axesIdx].set_ylabel('Ampl.', fontsize=16)

                        axesIdx += 1
                    # end if
                except Exception:
                    # plotting fails when an axes contains <2 values; just
                    # move on in those instances
                    logging.warning('Plotting failed on station %s' % k)
                # end try
            # end for
            axes[-1].set_xlabel('Days', fontsize=16)
        # end if

        plt.suptitle('%s Data Availability (~%d days)' % (k, usableStationDays[k]),
                     y=0.96, fontsize=20)
        pdf.savefig()
        plt.close()
        gc.collect()

    # end for

    pdf.close()
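A hypothetical driver for plot_results with synthetic data, assuming each results entry pairs a list of UTCDateTime days with their daily mean amplitudes (NaNs would mark no-data days):

from obspy import UTCDateTime
import numpy as np

fds = FederatedASDFDataSet(asdf_file_list)
stations = fds.get_stations('2000-01-01T00:00:00', '2002-01-01T00:00:00')

# one (dates, daily-means) pair per station row
days = [UTCDateTime('2000-01-01T00:00:00') + i * 86400 for i in range(100)]
results = [(days, np.random.random(100)) for _ in stations]

plot_results(stations, results, 'data_availability')  # -> data_availability.pdf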
Example #12
def process(input_folder, inventory, output_file_name, min_length_sec,
            merge_threshold, ntraces_per_file):
    """
    INPUT_FOLDER: Path to input folder containing miniseed files \n
    INVENTORY: Path to FDSNStationXML inventory containing channel-level metadata for all stations \n
    OUTPUT_FILE_NAME: Name of output ASDF file \n
    """
    comm = MPI.COMM_WORLD
    nproc = comm.Get_size()
    rank = comm.Get_rank()

    # Read inventory
    inv = None
    try:
        inv = read_inventory(inventory)
    except Exception as e:
        print(e)
    # end try

    files = np.array(glob.glob(input_folder + '/*.mseed'))
    random.Random(nproc).shuffle(files)

    ustations = set()
    ustationInv = defaultdict(list)
    networklist = []
    stationlist = []
    for file in files:
        # file names are expected to contain five dot-separated fields,
        # with network and station codes in the third and fourth
        _, _, net, sta, _ = file.split('.')
        ustations.add('%s.%s' % (net, sta))
        networklist.append(net)
        stationlist.append(sta)
    # end for

    networklist = np.array(networklist)
    stationlist = np.array(stationlist)

    idx = np.lexsort((networklist, stationlist))
    files = files[idx]

    myfiles = split_list(files, nproc)[rank]

    if (rank == 0):
        for i, ustation in enumerate(ustations):
            net, sta = ustation.split('.')
            sinv = inv.select(network=net, station=sta)
            if (not len(sinv.networks)):
                print(('Missing station: %s.%s' % (net, sta)))
                ustationInv[ustation] = None
            else:
                ustationInv[ustation] = sinv
            # end if
        # end for
    # end if
    ustationInv = comm.bcast(ustationInv, root=0)

    # Extract trace-count-lists in parallel
    mytrccountlist = np.zeros(len(myfiles))
    for ifile, file in enumerate(tqdm(myfiles)):
        try:
            st = read(file, headonly=True)
            mytrccountlist[ifile] = len(st)
        except Exception as e:
            print(e)
        # end try
    # end for

    trccountlist = comm.gather(mytrccountlist, root=0)

    if (rank == 0):
        trccountlist = np.array(
            [item for sublist in trccountlist for item in sublist])

        # Some mseed files can be problematic in terms of having way too many traces --
        # e.g. 250k+ traces, each a couple of samples long, for a day mseed file. We
        # need to blacklist them and exclude them from the ASDF file.
        print(('Blacklisted %d files out of %d files; ' \
               'average trace-count %f, std: %f' % (np.sum(trccountlist > ntraces_per_file),
                                                    len(trccountlist),
                                                    np.mean(trccountlist),
                                                    np.std(trccountlist))))
        with open(str(os.path.splitext(output_file_name)[0] + '.trccount.txt'),
                  'w+') as f:
            for i in range(len(files)):
                f.write('%s\t%d\n' % (files[i], trccountlist[i]))
            # end for
        # end with

        if (os.path.exists(output_file_name)): os.remove(output_file_name)
        ds = pyasdf.ASDFDataSet(output_file_name,
                                compression='gzip-3',
                                mpi=False)

        for ifile, file in enumerate(tqdm(files)):
            st = []
            if (trccountlist[ifile] > ntraces_per_file):
                continue
            else:
                try:
                    st = read(file)
                except Exception as e:
                    print(e)
                    continue
                # end try
            # end if

            if (len(st)):
                netsta = st[0].stats.network + '.' + st[0].stats.station

                if (ustationInv[netsta]):
                    if (merge_threshold):
                        ntraces = len(st)
                        if (ntraces > merge_threshold):
                            try:
                                st = st.merge(method=1,
                                              fill_value='interpolate')
                            except Exception:
                                print('Failed to merge traces. Moving along..')
                                continue
                            # end try
                            print(
                                ('Merging stream with %d traces' % (ntraces)))
                        # end if
                    # end if
                # end if

                for tr in st:

                    if (tr.stats.npts == 0): continue
                    if (min_length_sec):
                        if (tr.stats.npts * tr.stats.delta < min_length_sec):
                            continue
                    # end if

                    try:
                        ds.add_waveforms(tr, tag='raw_recording')
                    except Exception as e:
                        print(e)
                        print('Failed to append trace:')
                        print(tr)
                    # end try
                # end for

                try:
                    ds.add_stationxml(ustationInv[netsta])
                except Exception as e:
                    print(e)
                    print('Failed to append inventory:')
                    print((ustationInv[netsta]))
                # end try
            # end if
        # end for
        print('Closing asdf file..')
        del ds
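To sanity-check the output, the ASDF file can be read back with pyasdf; a minimal sketch (the station list shown will depend on the input data):

import pyasdf

ds = pyasdf.ASDFDataSet(output_file_name, mode='r')
print(ds.waveforms.list())  # e.g. ['AU.ARMA', ...]

# streams are stored under the 'raw_recording' tag used above
st = ds.waveforms[ds.waveforms.list()[0]].raw_recording
print(st)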