예제 #1
0
 def test_mergePreviews(self):
     """
     Tests the merging of Previews.

     Covers the error cases (non-preview traces, differing sampling rates,
     differing dtypes), the empty-Stream case and two real merges: fully
     overlapping previews (expected values equal the element-wise maximum)
     and previews separated by a gap (gap samples are expected to be -1).
     """
     # Merging non-preview traces in one Stream object should raise.
     st = Stream(traces=[Trace(data=np.empty(2)),
                         Trace(data=np.empty(2))])
     self.assertRaises(Exception, mergePreviews, st)
     # Merging an empty Stream should return a new empty Stream object.
     st = Stream()
     stream_id = id(st)
     st2 = mergePreviews(st)
     self.assertNotEqual(stream_id, id(st2))
     # Check the returned Stream, not the (trivially empty) input.
     self.assertEqual(len(st2.traces), 0)
     # Different sampling rates in one Stream object causes problems.
     tr1 = Trace(data=np.empty(10))
     tr1.stats.preview = True
     tr1.stats.sampling_rate = 100
     tr2 = Trace(data=np.empty(10))
     tr2.stats.preview = True
     st = Stream(traces=[tr1, tr2])
     self.assertRaises(Exception, mergePreviews, st)
     # Different data types should raise.
     tr1 = Trace(data=np.empty(10, dtype=np.int32))
     tr1.stats.preview = True
     tr2 = Trace(data=np.empty(10, dtype=np.float64))
     tr2.stats.preview = True
     st = Stream(traces=[tr1, tr2])
     self.assertRaises(Exception, mergePreviews, st)
     # Now some real tests.
     # 1: two fully overlapping previews
     tr1 = Trace(data=np.array([1, 2] * 100))
     tr1.stats.preview = True
     tr1.stats.starttime = UTCDateTime(500)
     tr2 = Trace(data=np.array([3, 1] * 100))
     tr2.stats.preview = True
     tr2.stats.starttime = UTCDateTime(500)
     st = Stream(traces=[tr1, tr2])
     st2 = mergePreviews(st)
     self.assertEqual(len(st2.traces), 1)
     self.assertEqual(st2[0].stats.starttime, UTCDateTime(500))
     np.testing.assert_array_equal(st2[0].data, np.array([3, 2] * 100))
     # 2: two previews separated by a gap
     tr1 = Trace(data=np.array([1] * 10))
     tr1.stats.preview = True
     tr2 = Trace(data=np.array([2] * 9))
     tr2.stats.starttime = tr2.stats.starttime + 20
     tr2.stats.preview = True
     st = Stream(traces=[tr1, tr2])
     st2 = mergePreviews(st)
     self.assertEqual(len(st2.traces), 1)
     self.assertEqual(st2[0].stats.starttime, tr1.stats.starttime)
     np.testing.assert_array_equal(st2[0].data,
                                   np.array([1] * 10 + [-1] * 10 + [2] * 9))
예제 #2
0
 def test_mergePreviews(self):
     """
     Tests the merging of Previews.

     Covers the error cases (non-preview traces, differing sampling rates,
     differing dtypes), the empty-Stream case and two real merges: fully
     overlapping previews (expected values equal the element-wise maximum)
     and previews separated by a gap (gap samples are expected to be -1).
     """
     # Merging non-preview traces in one Stream object should raise.
     st = Stream(traces=[Trace(data=np.empty(2)),
                         Trace(data=np.empty(2))])
     self.assertRaises(Exception, mergePreviews, st)
     # Merging an empty Stream should return a new empty Stream object.
     st = Stream()
     stream_id = id(st)
     st2 = mergePreviews(st)
     self.assertNotEqual(stream_id, id(st2))
     # Check the returned Stream, not the (trivially empty) input.
     self.assertEqual(len(st2.traces), 0)
     # Different sampling rates in one Stream object causes problems.
     tr1 = Trace(data=np.empty(10))
     tr1.stats.preview = True
     tr1.stats.sampling_rate = 100
     tr2 = Trace(data=np.empty(10))
     tr2.stats.preview = True
     st = Stream(traces=[tr1, tr2])
     self.assertRaises(Exception, mergePreviews, st)
     # Different data types should raise.
     tr1 = Trace(data=np.empty(10, dtype='int32'))
     tr1.stats.preview = True
     tr2 = Trace(data=np.empty(10, dtype='float64'))
     tr2.stats.preview = True
     st = Stream(traces=[tr1, tr2])
     self.assertRaises(Exception, mergePreviews, st)
     # Now some real tests.
     # 1: two fully overlapping previews
     tr1 = Trace(data=np.array([1, 2] * 100))
     tr1.stats.preview = True
     tr1.stats.starttime = UTCDateTime(500)
     tr2 = Trace(data=np.array([3, 1] * 100))
     tr2.stats.preview = True
     tr2.stats.starttime = UTCDateTime(500)
     st = Stream(traces=[tr1, tr2])
     st2 = mergePreviews(st)
     self.assertEqual(len(st2.traces), 1)
     self.assertEqual(st2[0].stats.starttime, UTCDateTime(500))
     np.testing.assert_array_equal(st2[0].data, np.array([3, 2] * 100))
     # 2: two previews separated by a gap
     tr1 = Trace(data=np.array([1] * 10))
     tr1.stats.preview = True
     tr2 = Trace(data=np.array([2] * 9))
     tr2.stats.starttime = tr2.stats.starttime + 20
     tr2.stats.preview = True
     st = Stream(traces=[tr1, tr2])
     st2 = mergePreviews(st)
     self.assertEqual(len(st2.traces), 1)
     self.assertEqual(st2[0].stats.starttime, tr1.stats.starttime)
     np.testing.assert_array_equal(st2[0].data,
                                   np.array([1] * 10 + [-1] * 10 + [2] * 9))
예제 #3
0
 def test_mergePreviews2(self):
     """
     Regression test for issue #84.

     Two consecutive preview traces (delta 30 s) must merge into a single
     preview trace that starts with the first, ends with the second and
     holds 5760 samples.
     """
     def make_preview(npts, starttime):
         # np.ones instead of np.empty on purpose: uninitialised data
         # triggers NumPy warnings related to the max function.
         tr = Trace(data=np.ones(npts))
         tr.stats.starttime = UTCDateTime(starttime)
         tr.stats.delta = 30.0
         tr.stats.preview = True
         tr.verify()
         return tr

     first = make_preview(2880, "2010-01-01T00:00:00.670000Z")
     second = make_preview(2881, "2010-01-01T23:59:30.670000Z")
     source = Stream([first, second])
     source.verify()
     # merge the two previews
     merged = mergePreviews(source)
     merged.verify()
     # verify the merged preview
     result = merged[0]
     self.assertTrue(result.stats.preview)
     self.assertEqual(result.stats.starttime, first.stats.starttime)
     self.assertEqual(result.stats.endtime, second.stats.endtime)
     self.assertEqual(result.stats.npts, 5760)
     self.assertEqual(len(result), 5760)
예제 #4
0
 def test_mergePreviews2(self):
     """
     Regression test for issue #84.

     Two consecutive preview traces (delta 30 s) must merge into a single
     preview trace that starts with the first, ends with the second and
     holds 5760 samples.
     """
     def make_preview(npts, starttime):
         # np.ones instead of np.empty on purpose: uninitialised data
         # triggers NumPy warnings related to the max function.
         tr = Trace(data=np.ones(npts))
         tr.stats.starttime = UTCDateTime(starttime)
         tr.stats.delta = 30.0
         tr.stats.preview = True
         tr.verify()
         return tr

     first = make_preview(2880, "2010-01-01T00:00:00.670000Z")
     second = make_preview(2881, "2010-01-01T23:59:30.670000Z")
     source = Stream([first, second])
     source.verify()
     # merge the two previews
     merged = mergePreviews(source)
     merged.verify()
     # verify the merged preview
     result = merged[0]
     self.assertTrue(result.stats.preview)
     self.assertEqual(result.stats.starttime, first.stats.starttime)
     self.assertEqual(result.stats.endtime, second.stats.endtime)
     self.assertEqual(result.stats.npts, 5760)
     self.assertEqual(len(result), 5760)
예제 #5
0
def _getPreview(session, **kwargs):
    """
    Query preview traces of all matching waveform channels and merge them.

    :param session: Database session used for the query. It is closed once
        the query has run, also if the query fails.
    :param kwargs: Optional filters:
        ``start_datetime`` / ``end_datetime`` -- converted with
        ``UTCDateTime``; if the conversion fails (presumably also for a
        missing value -- confirm) they default to 20 minutes ago and now.
        ``trace_ids`` -- comma-separated ``NET.STA.LOC.CHA`` ids; entries
        without exactly four dot-separated parts are ignored.
        ``network_id``, ``station_id``, ``location_id``, ``channel_id`` --
        only used when ``trace_ids`` is absent; ``*``/``?`` act as
        wildcards and an empty string matches NULL columns.
    :returns: Tuple ``(stream, start, end)``: the merged previews trimmed
        to ``[start, end]`` together with the effective time window.
    """
    # start and end time; unparsable values fall back to defaults
    try:
        start = UTCDateTime(kwargs.get('start_datetime'))
    except Exception:
        # default: 20 minutes ago
        start = UTCDateTime() - 60 * 20
    try:
        end = UTCDateTime(kwargs.get('end_datetime'))
    except Exception:
        # default: now
        end = UTCDateTime()
    # only channels overlapping the requested time window
    clauses = [WaveformChannel.endtime > start.datetime,
               WaveformChannel.starttime < end.datetime]
    # process arguments
    if 'trace_ids' in kwargs:
        # filter over a comma-separated trace id list (None -> no ids)
        trace_filter = or_()
        for trace_id in (kwargs['trace_ids'] or '').split(','):
            parts = trace_id.split('.')
            if len(parts) != 4:
                continue
            trace_filter.append(and_(
                WaveformChannel.network == parts[0],
                WaveformChannel.station == parts[1],
                WaveformChannel.location == parts[2],
                WaveformChannel.channel == parts[3]))
        if trace_filter.clauses:
            clauses.append(trace_filter)
    else:
        # filter over network/station/location/channel id
        for key in ('network_id', 'station_id', 'location_id', 'channel_id'):
            text = kwargs.get(key)
            if text is None:
                continue
            # column name is the key without its '_id' suffix
            col = getattr(WaveformChannel, key[:-3])
            if text == "":
                # SQLAlchemy renders ``== None`` as IS NULL; do not use 'is'
                clauses.append(col == None)  # noqa: E711
            elif '*' in text or '?' in text:
                # translate shell-style wildcards into SQL LIKE wildcards
                text = text.replace('?', '_').replace('*', '%')
                clauses.append(col.like(text))
            else:
                clauses.append(col == text)
    # execute query; always close the session, even if the query fails
    try:
        query = session.query(WaveformChannel)
        for clause in clauses:
            query = query.filter(clause)
        results = query.all()
    finally:
        session.close()
    # create Stream
    st = Stream()
    for result in results:
        st.append(result.getPreview())
    # merge and trim
    st = mergePreviews(st)
    st.trim(start, end)
    return st, start, end
예제 #6
0
 def getPreview(self, trace_ids=None, starttime=None, endtime=None,
                network=None, station=None, location=None, channel=None,
                pad=False):
     """
     Returns the preview trace.

     Queries all waveform channels overlapping ``[starttime, endtime]``,
     collects their previews and merges them with ``mergePreviews``.

     :param trace_ids: Iterable of ``'NET.STA.LOC.CHA'`` ids; entries
         without exactly four dot-separated parts are ignored. When
         non-empty, the network/station/location/channel filters are
         ignored. ``None`` (default) behaves like an empty list.
     :param starttime: Converted with ``UTCDateTime``; if the conversion
         fails (presumably also for ``None`` -- confirm) it defaults to
         20 minutes before now.
     :param endtime: Converted with ``UTCDateTime``; if the conversion
         fails it defaults to now.
     :param network: Optional filter; ``*`` and ``?`` act as wildcards.
     :param station: Optional filter; ``*`` and ``?`` act as wildcards.
     :param location: Optional filter; ``*`` and ``?`` act as wildcards.
     :param channel: Optional filter; ``*`` and ``?`` act as wildcards.
     :param pad: Passed through to ``Stream.trim``.
     :returns: The merged and trimmed preview ``Stream``.
     """
     # start and end time; unparsable values fall back to defaults
     try:
         starttime = UTCDateTime(starttime)
     except Exception:
         # default: 20 minutes ago
         starttime = UTCDateTime() - 60 * 20
     try:
         endtime = UTCDateTime(endtime)
     except Exception:
         # default: now
         endtime = UTCDateTime()
     # only channels overlapping the requested time window
     clauses = [WaveformChannel.endtime > starttime.datetime,
                WaveformChannel.starttime < endtime.datetime]
     # process arguments
     if trace_ids:
         # filter over trace id list
         trace_filter = or_()
         for trace_id in trace_ids:
             parts = trace_id.split('.')
             if len(parts) != 4:
                 continue
             trace_filter.append(and_(
                 WaveformChannel.network == parts[0],
                 WaveformChannel.station == parts[1],
                 WaveformChannel.location == parts[2],
                 WaveformChannel.channel == parts[3]))
         if trace_filter.clauses:
             clauses.append(trace_filter)
     else:
         # filter over network/station/location/channel id
         for key, value in (('network', network), ('station', station),
                            ('location', location), ('channel', channel)):
             if value is None:
                 continue
             col = getattr(WaveformChannel, key)
             if '*' in value or '?' in value:
                 # translate shell-style wildcards into SQL LIKE wildcards
                 value = value.replace('?', '_').replace('*', '%')
                 clauses.append(col.like(value))
             else:
                 clauses.append(col == value)
     # execute query; always close the session, even if the query fails
     session = self.session()
     try:
         query = session.query(WaveformChannel)
         for clause in clauses:
             query = query.filter(clause)
         results = query.all()
     finally:
         session.close()
     # create Stream
     st = Stream()
     for result in results:
         st.append(result.getPreview())
     # merge and trim
     st = mergePreviews(st)
     st.trim(starttime, endtime, pad=pad)
     return st
예제 #7
0
 def getPreview(self,
                trace_ids=None,
                starttime=None,
                endtime=None,
                network=None,
                station=None,
                location=None,
                channel=None,
                pad=False):
     """
     Returns the preview trace.

     Queries all waveform channels overlapping ``[starttime, endtime]``,
     collects their previews and merges them with ``mergePreviews``.

     :param trace_ids: Iterable of ``'NET.STA.LOC.CHA'`` ids; entries
         without exactly four dot-separated parts are ignored. When
         non-empty, the network/station/location/channel filters are
         ignored. ``None`` (default) behaves like an empty list.
     :param starttime: Converted with ``UTCDateTime``; if the conversion
         fails (presumably also for ``None`` -- confirm) it defaults to
         20 minutes before now.
     :param endtime: Converted with ``UTCDateTime``; if the conversion
         fails it defaults to now.
     :param network: Optional filter; ``*`` and ``?`` act as wildcards.
     :param station: Optional filter; ``*`` and ``?`` act as wildcards.
     :param location: Optional filter; ``*`` and ``?`` act as wildcards.
     :param channel: Optional filter; ``*`` and ``?`` act as wildcards.
     :param pad: Passed through to ``Stream.trim``.
     :returns: The merged and trimmed preview ``Stream``.
     """
     # start and end time; unparsable values fall back to defaults
     try:
         starttime = UTCDateTime(starttime)
     except Exception:
         # default: 20 minutes ago
         starttime = UTCDateTime() - 60 * 20
     try:
         endtime = UTCDateTime(endtime)
     except Exception:
         # default: now
         endtime = UTCDateTime()
     # only channels overlapping the requested time window
     clauses = [
         WaveformChannel.endtime > starttime.datetime,
         WaveformChannel.starttime < endtime.datetime,
     ]
     # process arguments
     if trace_ids:
         # filter over trace id list
         trace_filter = or_()
         for trace_id in trace_ids:
             parts = trace_id.split('.')
             if len(parts) != 4:
                 continue
             trace_filter.append(
                 and_(WaveformChannel.network == parts[0],
                      WaveformChannel.station == parts[1],
                      WaveformChannel.location == parts[2],
                      WaveformChannel.channel == parts[3]))
         if trace_filter.clauses:
             clauses.append(trace_filter)
     else:
         # filter over network/station/location/channel id
         filters = (
             ('network', network),
             ('station', station),
             ('location', location),
             ('channel', channel),
         )
         for key, value in filters:
             if value is None:
                 continue
             col = getattr(WaveformChannel, key)
             if '*' in value or '?' in value:
                 # translate shell-style wildcards into SQL LIKE wildcards
                 value = value.replace('?', '_').replace('*', '%')
                 clauses.append(col.like(value))
             else:
                 clauses.append(col == value)
     # execute query; always close the session, even if the query fails
     session = self.session()
     try:
         query = session.query(WaveformChannel)
         for clause in clauses:
             query = query.filter(clause)
         results = query.all()
     finally:
         session.close()
     # create Stream
     st = Stream()
     for result in results:
         st.append(result.getPreview())
     # merge and trim
     st = mergePreviews(st)
     st.trim(starttime, endtime, pad=pad)
     return st
예제 #8
0
    def __plotStraight(self, trace, ax, *args, **kwargs):  # @UnusedVariable
        """
        Just plots the data samples in the self.stream. Useful for smaller
        datasets up to around 1000000 samples (depending on the machine it's
        being run on).

        Slow and high memory consumption for large datasets.

        :param trace: Sequence of traces for one plot. Several traces are
            merged first (``mergePreviews`` for preview traces,
            ``Stream.merge(method=1)`` otherwise). The data and stats of the
            trace that finally gets plotted are modified in place.
        :param ax: Axes to draw on (``plot``, ``xaxis.grid``, ``yaxis.grid``
            and ``set_xlim`` are called on it).

        Side effect: appends ``[label, mean, min, max]`` (values scaled by
        ``stats.calib``) to ``self.stats``.
        """
        if len(trace) > 1:
            stream = Stream(traces=trace)
            # Merge with 'interpolation'. In case of overlaps this method will
            # always use the longest available trace.
            if hasattr(trace[0].stats, 'preview') and trace[0].stats.preview:
                stream = mergePreviews(stream)
            else:
                stream.merge(method=1)
            trace = stream[0]
        else:
            trace = trace[0]
        is_preview = hasattr(trace.stats, 'preview') and trace.stats.preview
        # Check if it is a preview file and adjust accordingly.
        # XXX: Will look weird if the preview file is too small.
        if is_preview:
            # Mask the gaps (marked with -1, cf. mergePreviews).
            trace.data = np.ma.masked_array(trace.data)
            trace.data[trace.data == -1] = np.ma.masked
            # Recreate the min_max scene: every sample becomes a +/- half
            # amplitude pair. A masked buffer keeps the gap mask; a plain
            # np.empty buffer would silently drop it.
            half = trace.data / 2.0
            old_time_range = trace.stats.endtime - trace.stats.starttime
            data = np.ma.empty(2 * trace.stats.npts, dtype=trace.data.dtype)
            data[0::2] = half
            data[1::2] = -half
            trace.data = data
            # The times are not supposed to change.
            trace.stats.delta = old_time_range / float(trace.stats.npts - 1)
        # Write to self.stats.
        calib = trace.stats.calib
        data_max = trace.data.max()
        data_min = trace.data.min()
        # set label
        if is_preview:
            tr_id = trace.id + ' [preview]'
        elif hasattr(trace, 'label'):
            tr_id = trace.label
        else:
            tr_id = trace.id
        self.stats.append([tr_id, calib * trace.data.mean(),
                           calib * data_min, calib * data_max])
        # Pad the beginning and the end with masked values if necessary. Might
        # seem like overkill but it works really fast and is a clean solution
        # to gaps at the beginning/end. Concatenate the data array itself (not
        # the Trace object) so that an existing mask is preserved.
        concat = [trace.data]
        if self.starttime != trace.stats.starttime:
            samples = (trace.stats.starttime - self.starttime) * \
                trace.stats.sampling_rate
            concat.insert(0, np.ma.masked_all(int(samples)))
        if self.endtime != trace.stats.endtime:
            samples = (self.endtime - trace.stats.endtime) * \
                trace.stats.sampling_rate
            concat.append(np.ma.masked_all(int(samples)))
        if len(concat) > 1:
            # Use the masked array concatenate, otherwise it will result in a
            # not masked array.
            trace.data = np.ma.concatenate(concat)
            # set starttime and calculate endtime
            trace.stats.starttime = self.starttime
        # Rebind instead of ``*=``: in-place multiplication of integer data by
        # a float calib is rejected by NumPy's casting rules.
        trace.data = trace.data * calib
        ax.plot(trace.data, color=self.color, linewidth=self.linewidth,
                linestyle=self.linestyle)
        ax.xaxis.grid(color=self.grid_color, linestyle=self.grid_linestyle,
                      linewidth=self.grid_linewidth)
        ax.yaxis.grid(color=self.grid_color, linestyle=self.grid_linestyle,
                      linewidth=self.grid_linewidth)
        # Set the x limit for the graph to also show the masked values at the
        # beginning/end.
        ax.set_xlim(0, len(trace.data) - 1)
예제 #9
0
    def __plotStraight(self, trace, ax, *args, **kwargs):  # @UnusedVariable
        """
        Just plots the data samples in the self.stream. Useful for smaller
        datasets up to around 1000000 samples (depending on the machine it's
        being run on).

        Slow and high memory consumption for large datasets.

        :param trace: Sequence of traces for one plot. Several traces are
            merged first (``mergePreviews`` for preview traces,
            ``Stream.merge(method=1)`` otherwise). The data and stats of the
            trace that finally gets plotted are modified in place.
        :param ax: Axes to draw on (``plot``, ``xaxis.grid``, ``yaxis.grid``
            and ``set_xlim`` are called on it).

        Side effect: appends ``[label, mean, min, max]`` (values scaled by
        ``stats.calib``) to ``self.stats``.
        """
        if len(trace) > 1:
            stream = Stream(traces=trace)
            # Merge with 'interpolation'. In case of overlaps this method will
            # always use the longest available trace.
            if hasattr(trace[0].stats, 'preview') and trace[0].stats.preview:
                stream = mergePreviews(stream)
            else:
                stream.merge(method=1)
            trace = stream[0]
        else:
            trace = trace[0]
        is_preview = hasattr(trace.stats, 'preview') and trace.stats.preview
        # Check if it is a preview file and adjust accordingly.
        # XXX: Will look weird if the preview file is too small.
        if is_preview:
            # Mask the gaps (marked with -1, cf. mergePreviews).
            trace.data = np.ma.masked_array(trace.data)
            trace.data[trace.data == -1] = np.ma.masked
            # Recreate the min_max scene: every sample becomes a +/- half
            # amplitude pair. A masked buffer keeps the gap mask; a plain
            # np.empty buffer would silently drop it.
            half = trace.data / 2.0
            old_time_range = trace.stats.endtime - trace.stats.starttime
            data = np.ma.empty(2 * trace.stats.npts, dtype=trace.data.dtype)
            data[0::2] = half
            data[1::2] = -half
            trace.data = data
            # The times are not supposed to change.
            trace.stats.delta = old_time_range / float(trace.stats.npts - 1)
        # Write to self.stats.
        calib = trace.stats.calib
        data_max = trace.data.max()
        data_min = trace.data.min()
        # set label
        if is_preview:
            tr_id = trace.id + ' [preview]'
        elif hasattr(trace, 'label'):
            tr_id = trace.label
        else:
            tr_id = trace.id
        self.stats.append(
            [tr_id, calib * trace.data.mean(), calib * data_min,
             calib * data_max])
        # Pad the beginning and the end with masked values if necessary. Might
        # seem like overkill but it works really fast and is a clean solution
        # to gaps at the beginning/end. Concatenate the data array itself (not
        # the Trace object) so that an existing mask is preserved.
        concat = [trace.data]
        if self.starttime != trace.stats.starttime:
            samples = (trace.stats.starttime - self.starttime) * \
                trace.stats.sampling_rate
            concat.insert(0, np.ma.masked_all(int(samples)))
        if self.endtime != trace.stats.endtime:
            samples = (self.endtime - trace.stats.endtime) * \
                trace.stats.sampling_rate
            concat.append(np.ma.masked_all(int(samples)))
        if len(concat) > 1:
            # Use the masked array concatenate, otherwise it will result in a
            # not masked array.
            trace.data = np.ma.concatenate(concat)
            # set starttime and calculate endtime
            trace.stats.starttime = self.starttime
        # Rebind instead of ``*=``: in-place multiplication of integer data by
        # a float calib is rejected by NumPy's casting rules.
        trace.data = trace.data * calib
        ax.plot(trace.data,
                color=self.color,
                linewidth=self.linewidth,
                linestyle=self.linestyle)
        ax.xaxis.grid(color=self.grid_color,
                      linestyle=self.grid_linestyle,
                      linewidth=self.grid_linewidth)
        ax.yaxis.grid(color=self.grid_color,
                      linestyle=self.grid_linestyle,
                      linewidth=self.grid_linewidth)
        # Set the x limit for the graph to also show the masked values at the
        # beginning/end.
        ax.set_xlim(0, len(trace.data) - 1)
예제 #10
0
 def getWaveform(self, network, station, location, channel, id):
     """
     Gets the waveform. Loads it from the cache or requests it from SeisHub.

     Cache files live in ``<cache_dir>/<network>/<station>/`` and are named
     ``<channel>[<location>]--<start>--<end>--cache``, with start/end being
     timestamps. Parts of the requested window
     (``self.env.starttime``..``self.env.endtime``) missing from the cache
     are fetched with ``self.loadGaps`` while SeisHub is online. The merged
     result is written back as a single cache file.

     :param id: Not used by this method.
     :returns: One of the following:
         the result of ``self.getPreview`` when nothing is cached yet and
         SeisHub is online;
         ``None`` when SeisHub has no data for the missing time frames;
         otherwise the merged preview Stream, which may be empty when
         offline and nothing usable is cached.
     """
     if self.env.debug and not self.env.seishub.online:
         msg = 'No connection to SeisHub server. Only locally cached ' + \
               'information is available.'
         print(msg)
     # Go through directory structure and create all necessary
     # folders if necessary.
     network_path = os.path.join(self.env.cache_dir, network)
     if not os.path.exists(network_path):
         os.mkdir(network_path)
     station_path = os.path.join(network_path, station)
     if not os.path.exists(station_path):
         os.mkdir(station_path)
     # Keep only the cache files of the requested channel/location.
     prefix = '%s[%s]' % (channel, location)
     files = [name for name in os.listdir(station_path)
              if name.endswith('--cache') and name.split('--')[0] == prefix]
     # If no file exists get it from SeisHub. It will also get cached for
     # future access.
     if not files and self.env.seishub.online:
         if self.env.debug:
             print(' * No cached file found for %s.%s.%s.%s'
                   % (network, station, location, channel))
         return self.getPreview(network, station, location, channel,
                                station_path)
     # Otherwise figure out if the requested time span is already cached.
     times = []
     for name in files:
         parts = name.split('--')
         times.append((float(parts[1]), float(parts[2]),
                       os.path.join(station_path, name)))
     # os.listdir() returns names in arbitrary order, but the gap detection
     # below relies on the cached spans being sorted by start time.
     times.sort()
     starttime = self.env.starttime.timestamp
     endtime = self.env.endtime.timestamp
     buf = self.env.buffer
     # Keep only spans overlapping the requested window. By design there
     # should be no overlaps between cached spans.
     times = [t for t in times if t[0] <= endtime and t[1] >= starttime]
     missing_time_frames = []
     if times:
         if starttime < times[0][0]:
             missing_time_frames.append((starttime, times[0][0] + buf))
         for previous, following in zip(times, times[1:]):
             missing_time_frames.append((previous[1] - buf,
                                         following[0] + buf))
         if endtime > times[-1][1]:
             missing_time_frames.append((times[-1][1] - buf, endtime))
         # Load all cached files.
         stream = self.loadFiles(times)
     else:
         # Timestamps, consistent with the time frames built above.
         missing_time_frames.append((starttime - buf, endtime + buf))
         stream = Stream()
     # Get the gaps.
     if missing_time_frames and self.env.seishub.online:
         if self.env.debug:
             print(' * Only partially cached file found for %s.%s.%s.%s.'
                   % (network, station, location, channel) +
                   ' Requesting the rest from SeisHub...')
         stream += self.loadGaps(missing_time_frames, network, station,
                                 location, channel)
         if not stream:
             print('No data available for %s.%s.%s.%s for the selected '
                   'timeframes' % (network, station, location, channel))
             return
     elif self.env.debug and not missing_time_frames:
         print(' * Cached file found for %s.%s.%s.%s'
               % (network, station, location, channel))
     # XXX: Pretty ugly to ensure all data has the same dtype.
     for trace in stream:
         trace.data = np.require(trace.data, dtype='float32')
     # Merge everything and pickle once again.
     stream = mergePreviews(stream)
     # Pickle the stream object for future reference. Do not pickle it if it
     # is empty or smaller than 200 samples. Just not worth the hassle.
     if stream and stream[0].stats.npts > 200:
         stats = stream[0].stats
         filename = os.path.join(station_path, '%s[%s]--%s--%s--cache' % (
             channel, location, str(stats.starttime.timestamp),
             str(stats.endtime.timestamp)))
         # Write the new cache file before deleting the old ones so that a
         # failing write does not throw away the cached data.
         with open(filename, 'wb') as fh:
             pickle.dump(stream, fh, 2)
         for _, _, old_file in times:
             if old_file != filename:
                 os.remove(old_file)
     return stream