def test_copy(self):
    """
    Copying a pick database should produce an equivalent database.

    Fills an in-memory database with the benchmark picks, copies it, and
    compares the first pick row returned by the source and by the copy.
    """
    source_db = PickDatabaseConnection(':memory:')
    for fields in uniq_picks:
        source_db.add_pick(**fields)
    copied_db = source_db.copy()
    expected_row = source_db.get_picks()[0]
    copied_row = copied_db.get_picks()[0]
    self.assertEqual(expected_row, copied_row)
Exemple #2
0
class MenusTestCase(unittest.TestCase):
    """
    Top-level test suite for the menu classes.
    """

    def setUp(self):
        """
        Load the example SEG-Y file (with unpacked headers) and build an
        in-memory pick database holding the benchmark picks.
        """
        segy_path = get_example_file("ew0210_o30.segy")
        self.segy = readSEGY(segy_path, unpack_headers=True)
        self.pickdb = PickDatabaseConnection(":memory:")
        store_pick = self.pickdb.update_pick
        for pick_fields in uniq_picks:
            store_pick(**pick_fields)
Exemple #3
0
 def setUp(self):
     """
     Prepare fixtures: the example SEG-Y file (headers unpacked) and an
     in-memory pick database filled with the benchmark picks.
     """
     example_path = get_example_file("ew0210_o30.segy")
     self.segy = readSEGY(example_path, unpack_headers=True)
     self.pickdb = PickDatabaseConnection(":memory:")
     store = self.pickdb.update_pick
     for fields in uniq_picks:
         store(**fields)
Exemple #4
0
    def test_raytrace_branch_id(self):
        """
        Raytracing should honor branch ids.

        One pick is stored for each of branches 1-3.  Tracing with
        ``branch=<n>`` must produce exactly one rayfan holding one ray whose
        maximum z is at least ``vm.rf[n - 1][0][0]``.  The temporary pick
        database and ray file are removed even when an assertion fails.
        """
        vmfile = get_example_file('inactive_layers.vm')
        pickdbfile = 'temp.sqlite'
        rayfile = 'temp.ray'

        def _remove_temp_files():
            # Delete any leftover temporary files from this or earlier runs.
            for filename in (rayfile, pickdbfile):
                if os.path.isfile(filename):
                    os.remove(filename)

        _remove_temp_files()
        try:
            # Create pick database
            pickdb = PickDatabaseConnection(pickdbfile)
            pickdb.update_pick(event='P1', ensemble=100, trace=1,
                               branch=1, subbranch=0,
                               time=5.0, source_x=10, source_y=0.0,
                               source_z=0.006, receiver_x=40,
                               receiver_y=0.0, receiver_z=4.9)
            pickdb.update_pick(event='P2', ensemble=100, trace=1,
                               branch=2, subbranch=0, time=5.0)
            pickdb.update_pick(event='P3', ensemble=100, trace=1,
                               branch=3, subbranch=0, time=5.0)
            pickdb.commit()

            # Raytrace each branch separately
            vm = readVM(vmfile)
            for branch in range(1, 4):
                # a stale ray file would make the existence check meaningless
                if os.path.isfile(rayfile):
                    os.remove(rayfile)
                raytrace(vmfile, pickdb, rayfile, branch=branch)
                # Should have created a rayfile
                self.assertTrue(os.path.isfile(rayfile))
                # Load rayfans
                rays = readRayfanGroup(rayfile)
                # Should have traced just one ray
                self.assertEqual(len(rays.rayfans), 1)
                rfn = rays.rayfans[0]
                self.assertEqual(len(rfn.paths), 1)
                # Rays should turn in specified layer
                zmax = max([p[2] for p in rfn.paths[0]])
                self.assertGreaterEqual(zmax, vm.rf[branch - 1][0][0])
        finally:
            # cleanup, also on failure
            _remove_temp_files()
 def test_plot_picks(self):
     """
     Should add picks to an axes and the layer managers.

     Every event in the plotted pick database should get exactly one line
     artist, and picks added later should show up on the next plot call.
     """
     fig = plt.figure()
     ax = fig.add_subplot(111)
     pickdb = PickDatabaseConnection(':memory:')
     for pick in uniq_picks:
         pickdb.update_pick(**pick)
     splt = SEGYPickPlotter(ax, self.segy, pickdb)
     # should add a single artist to the line dict. for each event
     splt.plot_picks()
     # Iterate over the database that was actually plotted.  assertEqual is
     # required here: assertTrue's second argument is only a failure
     # message, so assertTrue(n, 1) would accept any non-zero count.
     for event in pickdb.events:
         self.assertEqual(len(splt.ACTIVE_LINES[event]), 1)
     # should be able to add new picks and have them be accessible by the
     # SEGYPickPlotter
     new_event = '--tracer--'
     # copy so the shared benchmark fixture is left untouched
     new_pick = copy.copy(uniq_picks[0])
     new_pick['event'] = new_event
     pickdb.update_pick(**new_pick)
     splt.plot_picks()
     self.assertIn(new_event, splt.ACTIVE_LINES)
Exemple #6
0
    def test_parallel_raytrace(self):
        """
        Should run raytracing in parallel.

        Builds an on-disk pick database with three events (one per branch,
        three ensembles each) and runs ``parallel_raytrace`` with 1, 2 and 8
        processes.  The pick database and working directories are removed
        even if a run fails.
        """
        pickdbfile = 'temp.sqlite'
        input_dir = 'test.input'
        output_dir = 'test.output'
        if os.path.isfile(pickdbfile):
            os.remove(pickdbfile)
        try:
            # Create pick database
            pickdb = PickDatabaseConnection(pickdbfile)
            for branch, event in enumerate(['P1', 'P2', 'P3'], 1):
                for ens in range(3):
                    pickdb.update_pick(event=event, ensemble=ens, trace=1,
                                       branch=branch, subbranch=0,
                                       time=5.0, source_x=10, source_y=0.0,
                                       source_z=0.006, receiver_x=40,
                                       receiver_y=0.0, receiver_z=4.9)
            pickdb.commit()

            # set velocity model
            vmfile = get_example_file('inactive_layers.vm')

            # raytrace with an increasing number of processes
            for nproc in [1, 2, 8]:
                parallel_raytrace(vmfile, pickdb, branches=[1, 2, 3],
                                  input_dir=input_dir, output_dir=output_dir,
                                  nproc=nproc, ensemble_field='ensemble')
            # both working directories should exist once the runs complete
            self.assertTrue(os.path.isdir(input_dir))
            self.assertTrue(os.path.isdir(output_dir))
        finally:
            # guard each removal so cleanup never masks the real failure
            for dirname in (input_dir, output_dir):
                if os.path.isdir(dirname):
                    shutil.rmtree(dirname)
            if os.path.isfile(pickdbfile):
                os.remove(pickdbfile)
Exemple #7
0
 def test_locate_on_surface(self):
     """
     Should locate a receiver on a surface.

     For each test model, a receiver is placed on interface ``iref`` at the
     model centre, synthetic travel times are traced to nearby sources, and
     ``locate_on_surface`` must recover the receiver's x and y position.
     """
     inst_id = 100
     dx = 1
     iref = 0
     for _vmfile in TEST_MODELS:
         vmfile = get_example_file(_vmfile)
         vm = readVM(vmfile)
         # calculate synthetic times
         pickdb = PickDatabaseConnection(':memory:')
         x0 = np.mean(vm.x)
         y0 = np.mean(vm.y)
         # y index of the receiver: a y coordinate must be converted with
         # y2i (x2i would index the x axis); it is the same for every source
         iy = vm.y2i([y0])[0]
         xsearch = vm.xrange2i(max(vm.r1[0], x0 - dx),
                               min(vm.r2[0], x0 + dx))
         for i, ix in enumerate(xsearch):
             x = vm.x[ix]
             z0 = vm.rf[iref][ix][iy]
             pickdb.add_pick(event='Pw', ensemble=inst_id,
                             trace=i, time=1e30,
                             source_x=x, source_y=y0, source_z=0.006,
                             receiver_x=x0, receiver_y=y0, receiver_z=z0,
                             vm_branch=1, vm_subid=0)
         rayfile = 'temp.ray'
         try:
             raytrace(vmfile, pickdb, rayfile)
             # keep the synthetic database in memory so picks from one model
             # can never leak into the next iteration (or the next run)
             raydb = rayfan2db(rayfile, ':memory:', synthetic=True)
         finally:
             if os.path.isfile(rayfile):
                 os.remove(rayfile)
         # run locate
         x, y, z, rms = locate_on_surface(vmfile, raydb, iref, x0=x0,
                                          y0=y0, dx=dx, dy=dx)
         # compare result
         self.assertAlmostEqual(x, x0, 0)
         self.assertAlmostEqual(y, y0, 0)
 def setUp(self):
     """
     Build the shared fixtures: the example SEG-Y file (headers unpacked),
     an in-memory pick database holding the benchmark picks, and the names
     of the default plotting parameters.
     """
     segy_path = get_example_file('ew0210_o30.segy')
     self.segy = readSEGY(segy_path, unpack_headers=True)
     self.pickdb = PickDatabaseConnection(':memory:')
     store = self.pickdb.update_pick
     for fields in uniq_picks:
         store(**fields)
     self.default_params = [
         'ABSCISSA_KEY',
         'GAIN',
         'CLIP',
         'NORMALIZATION_METHOD',
         'OFFSET_GAIN_POWER',
         'WIGGLE_PEN_COLOR',
         'WIGGLE_PEN_WIDTH',
         'NEG_FILL_COLOR',
         'POS_FILL_COLOR',
         'DISTANCE_UNIT',
         'TIME_UNIT',
         'SEGY_TIME_UNITS',
         'SEGY_DISTANCE_UNITS',
         'SEGY_HEADER_ALIASES',
     ]
 def test_get_events(self):
     """
     The database should report the distinct event names it holds.
     """
     pickdb = PickDatabaseConnection(':memory:')
     for fields in uniq_picks:
         pickdb.add_pick(**fields)
     # distinct event names among the inserted picks (order is irrelevant,
     # both sides are sorted before comparing)
     expected = set(fields['event'] for fields in uniq_picks)
     self.assertEqual(sorted(pickdb._get_events()),
                      sorted(expected))
     # the public ``events`` attribute must agree with the private helper
     self.assertEqual(pickdb.events,
                      pickdb._get_events())
     # a fresh database reports no events at all
     pickdb = PickDatabaseConnection(':memory:')
     self.assertEqual(pickdb.events, [])
Exemple #10
0
# Output database for the synthetic picks.
# NOTE(review): vmfile, temp_pickdb_file and rayfile are defined earlier in
# the script (not visible in this excerpt).
pickdb_file = 'synthetic.sqlite'

# Cleanup from any earlier run, so every output below is rebuilt from scratch
for filename in [vmfile, temp_pickdb_file, rayfile, pickdb_file]:
    if os.path.isfile(filename):
        os.remove(filename)

# Create a simple 1D model.
# r1/r2 appear to be the (x, y, z) corners of the model and dx/dy/dz the
# grid spacing -- TODO confirm against VM.  If so, the y extent is zero and
# the model is a single x-z section.
vm = VM(r1=(0, 0, 0), r2=(100, 0, 15), dx=0.5, dy=1, dz=0.1)
# One flat interface at z = 2.0, given as an (nx, 1) array (one value per
# x node).
vm.insert_interface(np.asarray([[z] for z in 2.0 * np.ones((vm.nx))]))
# Layer 0 gets a constant velocity of 1.5.
vm.define_constant_layer_velocity(0, 1.5)
# Layer 1 velocities are built from this list of values (stretched across
# the layer, going by the method name -- confirm).
vm.define_stretched_layer_velocities(1, [1.5, 4.0, 6.5, 8.0, 8.5])
vm.write(vmfile)

# Create a pickdb to define source/receiver geometry:
# source x positions 1, 2, ..., 79 and a single receiver x position of 1.
pickdb = PickDatabaseConnection(temp_pickdb_file)
sx = np.arange(1., 80., 1.)
rx = [1.]
for irec, _rx in enumerate(rx):
    for isrc, _sx in enumerate(sx):
        d = {'event': 'Pg',
             'vm_branch': 1,
             'vm_subid': 0,
             'ensemble': irec + 10,
             'trace': isrc + 5000,
             'time': 0.0,
             'time_reduced': 0.0,
             'source_x': _sx,
             'source_y': 0.0,
             'source_z': 1.0,
             'receiver_x': _rx,
from rockfish.picking.database import PickDatabaseConnection
from rockfish.vmtomo.vm import VMFile
from rockfish.vmtomo.rayfan import RayfanFile
from rockfish.vmtomo.raytracing import trace

cleanup = True

# Set filenames for model, pick database, etc.
pickdbfile = 'temp.pickdb.sqlite'
vmfile = '../tests/data/goc_l26.15.00.vm'
#vmfile = '../tests/data/cranis_northeast.00.00.vm'
input_dir = 'temp.input'
rayfile = 'temp.ray'

# Build an example pick database.  The database is a file on disk that
# nothing in this script removes, so use update_pick (insert-or-update)
# rather than add_pick, which raises sqlite3.IntegrityError for a pick that
# is already stored -- i.e. on every re-run of the script.
pickdb = PickDatabaseConnection(pickdbfile)
pickdb.update_pick(event='Pg', ensemble=100, trace=1,
                   vm_branch=3, vm_subid=0,
                   time=5.0, time_reduced=5.0,
                   source_x=50, source_y=0.0, source_z=0.006,
                   receiver_x=100, receiver_y=0.0, receiver_z=2)
pickdb.commit()

# Raytrace the model.
# NOTE(review): the ``cleanup`` flag above is not passed here (cleanup=False
# is hard-coded); confirm which behavior is intended.
trace(vmfile, pickdb, rayfile, input_dir=input_dir, cleanup=False)

# Plot model and rays.
# NOTE(review): ``plt`` (matplotlib.pyplot) is not imported in the visible
# import block; it must be imported elsewhere for this to run.
fig = plt.figure()
ax = fig.add_subplot(311)
vm = VMFile(vmfile)
vm.plot(ax=ax)
Exemple #12
0
class SEGYPlotterTestCase(unittest.TestCase):
    """
    Test cases for SEGYPlotter and SEGYPickPlotter
    """
    def setUp(self):
        """
        Setup benchmark data.
        """
        self.segy = readSEGY(get_example_file('ew0210_o30.segy'),
                             unpack_headers=True)
        self.pickdb = PickDatabaseConnection(':memory:')
        for pick in uniq_picks:
            self.pickdb.update_pick(**pick)

    def _check_layer_toggle(self, layer, active_name, inactive_name,
                            artists_name):
        """
        Exercise the on/off cycle of one ``SEGYPlotter.plot_wiggles`` layer.

        :param layer: ``plot_wiggles`` keyword, also used as the key in the
            active/inactive artist dictionaries (e.g. ``'negative_fills'``).
        :param active_name: attribute name of the plotter's active-artist
            dictionary (``'ACTIVE_PATCHES'`` or ``'ACTIVE_LINES'``).
        :param inactive_name: attribute name of the matching inactive-artist
            dictionary (``'INACTIVE_PATCHES'`` or ``'INACTIVE_LINES'``).
        :param artists_name: name of the axes artist list the layer is drawn
            into (``'patches'`` or ``'lines'``).
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)

        # Look the containers up on every use: layer keys move between the
        # active and inactive dictionaries as the layer is toggled.
        def active():
            return getattr(splt, active_name)[layer]

        def inactive():
            return getattr(splt, inactive_name)[layer]

        def artists():
            return getattr(splt.ax, artists_name)

        # should add a single artist to the active dict.
        splt.plot_wiggles(**{layer: True})
        self.assertEqual(len(active()), 1)
        # same artist should be in the axes list
        self.assertIn(active()[0], artists())
        # should be able to directly remove it
        artists().remove(active()[0])
        self.assertEqual(len(artists()), 0)
        # and re-add it
        artists().append(active()[0])
        self.assertEqual(len(artists()), 1)
        # should move artist to the inactive dict.
        splt.plot_wiggles(**{layer: False})
        self.assertEqual(len(inactive()), 1)
        self.assertEqual(len(artists()), 0)
        # calling again should change nothing
        splt.plot_wiggles(**{layer: False})
        self.assertEqual(len(inactive()), 1)
        self.assertEqual(len(artists()), 0)
        # should move artist back to the active dict. and re-add to axes
        splt.plot_wiggles(**{layer: True})
        self.assertEqual(len(active()), 1)
        self.assertEqual(len(artists()), 1)
        # new plot items should be accessible through the orig. axes
        self.assertEqual(len(artists()), len(getattr(ax, artists_name)))

    def test_init_SEGYPlotter(self):
        """
        Should create a new instance of the SEGYPlotter.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)
        # should inherit from SEGYPlotManager
        for member in inspect.getmembers(SEGYPlotManager):
            self.assertTrue(hasattr(splt, member[0]))
        # should *not* build header lookup table
        self.assertFalse(hasattr(splt, 'sdb'))
        # should attach axes
        self.assertTrue(isinstance(splt.ax, matplotlib.axes.Axes))

    def test_plot_negative_wiggle_fills(self):
        """
        Should add negative wiggle fills to the axes.
        """
        self._check_layer_toggle('negative_fills', 'ACTIVE_PATCHES',
                                 'INACTIVE_PATCHES', 'patches')

    def test_plot_positive_wiggle_fills(self):
        """
        Should add positive wiggle fills to the axes.
        """
        self._check_layer_toggle('positive_fills', 'ACTIVE_PATCHES',
                                 'INACTIVE_PATCHES', 'patches')

    def test_plot_wiggle_traces(self):
        """
        Should add wiggle traces to the axes.
        """
        self._check_layer_toggle('wiggle_traces', 'ACTIVE_LINES',
                                 'INACTIVE_LINES', 'lines')

    def test_link_axes(self):
        """
        Plotting should update items in the linked axes.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)
        # should add one artist to our axes
        splt.plot_wiggles(wiggle_traces=True)
        self.assertEqual(len(splt.ACTIVE_LINES['wiggle_traces']), 1)
        self.assertNotIn('wiggle_traces', splt.INACTIVE_LINES)
        self.assertEqual(len(ax.lines), 1)
        # should remove one artist from our axes
        splt.plot_wiggles(wiggle_traces=False)
        self.assertNotIn('wiggle_traces', splt.ACTIVE_LINES)
        self.assertEqual(len(splt.INACTIVE_LINES['wiggle_traces']), 1)
        self.assertEqual(len(ax.lines), 0)

    def test_init_SEGYPickPlotter(self):
        """
        Should create a new instance of the SEGYPickPlotter.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPickPlotter(ax, self.segy, pickdb=self.pickdb)
        # should inherit from SEGYPlotter
        for member in inspect.getmembers(SEGYPlotter):
            self.assertTrue(hasattr(splt, member[0]))
        # should build header lookup table
        self.assertTrue(isinstance(splt.sdb, SEGYHeaderDatabase))
        # should attach axes
        self.assertTrue(isinstance(splt.ax, matplotlib.axes.Axes))

    def test_plot_picks(self):
        """
        Should add picks to an axes and the layer managers.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        pickdb = PickDatabaseConnection(':memory:')
        for pick in uniq_picks:
            pickdb.update_pick(**pick)
        splt = SEGYPickPlotter(ax, self.segy, pickdb)
        # should add a single artist to the line dict. for each event.
        # assertEqual is required: assertTrue's second argument is only a
        # failure message, so assertTrue(n, 1) accepts any non-zero count.
        splt.plot_picks()
        for event in pickdb.events:
            self.assertEqual(len(splt.ACTIVE_LINES[event]), 1)
        # should be able to add new picks and have them be accessible by the
        # SEGYPickPlotter
        new_event = '--tracer--'
        # copy so the shared benchmark fixture is left untouched
        new_pick = copy.copy(uniq_picks[0])
        new_pick['event'] = new_event
        pickdb.update_pick(**new_pick)
        splt.plot_picks()
        self.assertIn(new_event, splt.ACTIVE_LINES)
Exemple #13
0
def rayfan2db(rayfan_file, raydb_file=':memory:', synthetic=False, noise=None,
              pickdb=None, raypaths=False):
    """
    Read a rayfan file and store its data in a database.

    Data are stored in a modified version of a
    :class:`rockfish.picking.database.PickDatabaseConnection`.

    Parameters
    ----------
    rayfan_file: {str, file}
        An open file-like object or a string which is assumed to be a
        filename of a rayfan binary file.
    raydb_file: str, optional
        The filename of the new database. Default is to create
        a new database in memory.
    synthetic: bool, optional
        If ``True``, the traced travel times (plus optional noise) are
        recorded as the picked travel times, with zero residual and no
        predicted time.  Otherwise the original pick times are recorded and
        the traced times are stored as predicted times.
    noise: {float, None}
        Maximum amplitude of the uniform random noise added to the traced
        travel times when ``synthetic`` is ``True``.  If ``None``, no noise
        is added.
    pickdb: :class:`rockfish.database.PickDatabaseConnection`, optional
        An active connection to the pick database that was used to
        trace rays in ``rayfan_file``.  It is used to map branch ids back to
        event names, and values for extra fields ('trace_in_file', 'line',
        'site', 'data_file') are copied from it to the new database along
        with rayfan data. Default is to ignore these extra fields.
    raypaths: bool, optional
        If ``True``, raypath coordinates are stored as text in a new table
        'raypaths'.

    Returns
    -------
    raydb : PickDatabaseConnection
        Connection to the populated database.  A warning is issued if fewer
        new rows than rays were added to the picks table.
    """
    raydb = PickDatabaseConnection(raydb_file)
    rays = readRayfanGroup(rayfan_file)
    print("Adding {:} traveltimes to {:} ...".format(rays.nrays, raydb_file))
    # add fields for raypaths
    if raypaths:
        raydb._create_table_if_not_exists(RAYPATH_TABLE, RAYPATH_FIELDS)
    # row count before inserting, used below to check how many were added
    ndb0 = raydb.execute('SELECT COUNT(rowid) FROM picks').fetchone()[0]
    # Index of the current ray within the whole group.  ``rays.bottom_points``
    # belongs to the group, so it is indexed by this running counter rather
    # than by ``i``, which restarts at 0 for every rayfan.
    # NOTE(review): assumes bottom_points holds one point per ray, in rayfan
    # order -- confirm against RayfanGroup.
    iray = -1
    for rfn in rays.rayfans:
        for i, traced_time in enumerate(rfn.travel_times):
            iray += 1
            # uniform noise in [-noise, noise); drawn whenever ``noise`` is
            # set, but only applied to synthetic times
            if noise is not None:
                _noise = noise * 2 * (np.random.random() - 0.5)
            else:
                _noise = 0.0
            sx, sy, sz = rfn.paths[i][0]
            rx, ry, rz = rfn.paths[i][-1]
            if pickdb is not None:
                event = pickdb.vmbranch2event[rfn.event_ids[i]]
            else:
                event = rfn.event_ids[i]
            if synthetic:
                pick_time = traced_time + _noise
                predicted = None
                residual = 0.
            else:
                pick_time = rfn.pick_times[i]
                predicted = traced_time
                residual = rfn.residuals[i]
            d = {'event': event,
                 'ensemble': rfn.start_point_id,
                 'trace': rfn.end_point_ids[i],
                 'vm_branch': rfn.event_ids[i],
                 'vm_subid': rfn.event_subids[i],
                 'time': pick_time,
                 'time_reduced': pick_time,
                 'predicted': predicted,
                 'residual': residual,
                 'error': rfn.pick_errors[i],
                 'source_x': sx, 'source_y': sy, 'source_z': sz,
                 'receiver_x': rx, 'receiver_y': ry, 'receiver_z': rz,
                 'offset': rfn.offsets[i],
                 'faz': rfn.azimuths[i],
                 'method': 'rayfan2db({:})'.format(rayfan_file),
                 'data_file': rays.file.name}
            # Copy data from pickdb
            if pickdb is not None:
                pick = pickdb.get_picks(event=[event],
                                        ensemble=[d['ensemble']],
                                        trace=[d['trace']])
                if len(pick) > 0:
                    for f in ['trace_in_file', 'line', 'site', 'data_file']:
                        try:
                            d[f] = pick[0][f]
                        except KeyError:
                            # field not present in the source database
                            pass
            # add data to standard tables
            raydb.update_pick(**d)
            # add raypath data to new table
            if raypaths:
                path = rfn.paths[i]
                bottom = rays.bottom_points[iray]
                raydb._insert(RAYPATH_TABLE,
                              event=event,
                              ensemble=rfn.start_point_id,
                              trace=rfn.end_point_ids[i],
                              ray_btm_x=bottom[0],
                              ray_btm_y=bottom[1],
                              ray_btm_z=bottom[2],
                              ray_x=str([p[0] for p in path]),
                              ray_y=str([p[1] for p in path]),
                              ray_z=str([p[2] for p in path]))

        # commit once per rayfan
        raydb.commit()
    ndb = raydb.execute('SELECT COUNT(rowid) FROM picks').fetchone()[0]
    if (ndb - ndb0) != rays.nrays:
        msg = 'Only added {:} of {:} travel times to the database.'\
                .format(ndb - ndb0, rays.nrays)
        warnings.warn(msg)
    return raydb
Exemple #14
0
 def test_add_remove_picks(self):
     """
     Picks should be insertable, removable and updatable in the picks table.
     """
     pickdb = PickDatabaseConnection(':memory:')

     def count_rows():
         # number of rows currently stored in the picks table
         return len(pickdb.execute('SELECT * FROM picks').fetchall())

     # every benchmark pick should be stored
     for fields in uniq_picks:
         pickdb.add_pick(**fields)
     pickdb.commit()
     self.assertEqual(count_rows(), len(uniq_picks))
     # a pick lacking the primary fields is rejected
     with self.assertRaises(sqlite3.IntegrityError):
         pickdb.add_pick(event='Foobar', time=9.834)
     # inserting an already-stored pick a second time is rejected
     for fields in uniq_picks:
         with self.assertRaises(sqlite3.IntegrityError):
             pickdb.add_pick(**fields)
     # removing a pick and then inserting it again succeeds
     for fields in uniq_picks:
         pickdb.remove_pick(**fields)
         pickdb.add_pick(**fields)
     pickdb.commit()
     self.assertEqual(count_rows(), len(uniq_picks))
     # unknown column names raise OperationalError
     with self.assertRaises(sqlite3.OperationalError):
         pickdb.remove_pick(not_a_field=999)
     with self.assertRaises(sqlite3.OperationalError):
         pickdb.add_pick(not_a_field=999)
     # drop the pick that was re-inserted last ...
     last_pick = uniq_picks[-1]
     pickdb.remove_pick(**last_pick)
     # ... and removing it once more is a silent no-op
     pickdb.remove_pick(**last_pick)
     # update_pick inserts missing picks and overwrites existing ones
     for fields in uniq_picks:
         pickdb.update_pick(**fields)
     pickdb.commit()
     self.assertEqual(count_rows(), len(uniq_picks))
     # a pick carrying only event/ensemble/trace is rejected
     with self.assertRaises(sqlite3.IntegrityError):
         pickdb.add_pick(event='Pg', ensemble=1, trace=1)
Exemple #15
0
 def test_get_picks(self):
     """
     get_picks should return the stored rows.
     """
     pickdb = PickDatabaseConnection(':memory:')
     # store the benchmark picks, remembering their times
     expected_times = []
     for fields in uniq_picks:
         pickdb.add_pick(**fields)
         expected_times.append(fields['time'])
     pickdb.commit()
     # rows come back as sqlite3.Row objects
     last_event = uniq_picks[-1]['event']
     for row in pickdb.get_picks(event=last_event):
         self.assertIsInstance(row, sqlite3.Row)
     # looking each pick up by its own fields gives back the same times
     found_times = [pickdb.get_picks(**fields)[0]['time']
                    for fields in uniq_picks]
     self.assertEqual(sorted(expected_times), sorted(found_times))
     # no filter returns every row
     self.assertEqual(len(uniq_picks),
                      len(pickdb.get_picks()))
     # a filter that matches nothing returns an empty result
     self.assertEqual(len(pickdb.get_picks(event='golly_gee')), 0)
     # an empty database returns an empty result as well
     pickdb = PickDatabaseConnection(':memory:')
     self.assertEqual(len(pickdb.get_picks()), 0)
Exemple #16
0
 def test_counts(self):
     """
     Count functions should return number of distinct rows in tables.

     Adds two picks for one event and checks that the picks table counts
     both while the event table holds a single entry for that event.
     """
     pickdb = PickDatabaseConnection(':memory:')
     # add a single pick; work on a copy so the shared ``uniq_picks``
     # fixture is not modified for other tests
     event = 'TestCount'
     pick = dict(uniq_picks[0])
     pick['event'] = event
     pickdb.add_pick(**pick)
     pickdb.commit()
     # should have a single pick for this event in picks table
     n = pickdb.count_picks_by_event(event)
     self.assertEqual(n, 1)
     # should also have a single entry in the event table
     n = pickdb._count_distinct(pickdb.EVENT_TABLE, event=event)
     self.assertEqual(n, 1)
     # now, another pick for the same event
     pick['ensemble'] += 1  # ensure unique
     pickdb.add_pick(**pick)
     pickdb.commit()
     # should have two picks for this event in the picks table
     n = pickdb.count_picks_by_event(event)
     self.assertEqual(n, 2)
     # should still just have one entry in the event table for this event
     n = pickdb._count_distinct(pickdb.EVENT_TABLE, event=event)
     self.assertEqual(n, 1)
Exemple #17
0
def locate_on_surface(vmfile, pickdb, iref, pick_keys=None, x0=None, y0=None,
                      dx=None, dy=None, plot=False):
    """
    Locate a receiver on a surface.

    Every grid node of the search region is tried as a candidate receiver
    position on interface ``iref``.  The selected picks are copied once per
    candidate, with the candidate index used as the ensemble number.  All
    candidates are raytraced together, and the candidate whose rayfan has
    the smallest RMS misfit is returned.

    :param vmfile: Filename of the VM Tomography slowness model to
        raytrace.
    :param pickdb: :class:`rockfish.picking.database.PickDatabaseConnection`
        to get picks from for raytracing.
    :param iref: index of the surface to move the receiver on
    :param pick_keys: Optional. ``dict`` of keys and values to use for
        selecting picks from ``pickdb``.  Default (``None``) is to use all
        picks.
    :param x0, y0: Optional. Initial guess at x and y locations. Default is
        center of the model.
    :param dx, dy: Optional. Distance in x and y to search from ``x0, y0``.
        Default is to search the entire model.
    :param plot: Optional. Plot results of the location. Default is false.
    :returns: ``(x, y, z, rms)`` of the best-fitting candidate position and
        its RMS misfit.
    :raises ValueError: if raytracing produced no rayfans.
    """
    if pick_keys is None:
        pick_keys = {}
    # Load model
    print("Loading VM model...")
    vm = readVM(vmfile)
    # Setup search domain
    print("Setting up search domain...")
    if x0 is None:
        x0 = np.mean(vm.x)
    if y0 is None:
        y0 = np.mean(vm.y)
    if dx is None:
        xsearch = vm.x
    else:
        ix = vm.xrange2i(max(vm.r1[0], x0 - dx), min(vm.r2[0], x0 + dx))
        xsearch = vm.x[ix]
    if dy is None:
        ysearch = vm.y
    else:
        iy = vm.yrange2i(max(vm.r1[1], y0 - dy), min(vm.r2[1], y0 + dy))
        ysearch = vm.y[iy]
    # trace each point in the search region
    print("Building traveltime database for {:}x{:} search grid..."
          .format(len(xsearch), len(ysearch)))
    zz = vm.rf[iref]
    # The pick selection does not depend on the candidate position, so query
    # it once instead of once per grid node.
    base_picks = [dict(_p) for _p in pickdb.get_picks(**pick_keys)]
    population = []
    db = PickDatabaseConnection(':memory:')
    for x in xsearch:
        for y in ysearch:
            z = zz[vm.x2i([x])[0], vm.y2i([y])[0]]
            # the candidate index doubles as the ensemble id, so each rayfan
            # maps back to its candidate through ``start_point_id``
            ipop = len(population)
            for base in base_picks:
                p = dict(base)
                p['ensemble'] = ipop
                p['receiver_x'] = x
                p['receiver_y'] = y
                p['receiver_z'] = z
                db.add_pick(**p)
            population.append((x, y, z))
    print("Raytracing...")
    rayfile = '.locate.ray'
    raytrace(vmfile, db, rayfile, stderr=subprocess.STDOUT,
             stdout=subprocess.PIPE)
    rays = readRayfanGroup(rayfile)
    if len(rays.rayfans) == 0:
        raise ValueError('No rays were traced in {:}; cannot locate the'
                         ' receiver.'.format(rayfile))
    rms_values = [rfn.rms for rfn in rays.rayfans]
    imin = np.argmin(rms_values)
    ipop = rays.rayfans[imin].start_point_id
    rms = rms_values[imin]
    # TODO: warn when the best-fit point lies on the edge of the search grid
    # (xsearch/ysearch end points); a larger search region may be needed.
    # plot
    if plot:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        if vm.ny == 1:
            # pair each RMS with its own candidate position, which stays
            # correct even if some candidates produced no rays
            ax.plot([population[rfn.start_point_id][0]
                     for rfn in rays.rayfans],
                    rms_values, '.k')
            ax.plot(population[ipop][0], rms, '*r')
            plt.xlabel('x (km)')
            plt.ylabel('RMS error (s)')
        else:
            ax.plot([p[0] for p in population], [p[1] for p in population],
                    '.k')
            ax.plot(population[ipop][0], population[ipop][1], '*r')
            ax.plot(x0, y0, 'og')
            plt.xlabel('x (km)')
            plt.ylabel('y (km)')
        plt.title('Results of locate_on_surface()')
        plt.show()
    return population[ipop][0], population[ipop][1], population[ipop][2], rms
Exemple #18
0
class SEGYPlotManagerTestCase(unittest.TestCase):
    """
    Test cases for SEGYPlotManager
    """
    def setUp(self):
        """
        Build the fixtures shared by every test in this case.

        Reads the example SEG-Y file with unpacked headers, fills an
        in-memory pick database from ``uniq_picks``, and records the names
        of the parameters every SEGYPlotManager is expected to expose.
        """
        segy_path = get_example_file('ew0210_o30.segy')
        self.segy = readSEGY(segy_path, unpack_headers=True)
        self.pickdb = pickdb = PickDatabaseConnection(':memory:')
        for pick_kwargs in uniq_picks:
            pickdb.update_pick(**pick_kwargs)
        self.default_params = [
            'ABSCISSA_KEY',
            'GAIN',
            'CLIP',
            'NORMALIZATION_METHOD',
            'OFFSET_GAIN_POWER',
            'WIGGLE_PEN_COLOR',
            'WIGGLE_PEN_WIDTH',
            'NEG_FILL_COLOR',
            'POS_FILL_COLOR',
            'DISTANCE_UNIT',
            'TIME_UNIT',
            'SEGY_TIME_UNITS',
            'SEGY_DISTANCE_UNITS',
            'SEGY_HEADER_ALIASES',
        ]

    def test_init_SEGYPlotManager(self):
        """
        Should create a new instance of the SEGYPlotManager.

        Verifies that the default parameters are present and that the
        manager shares the matplotlib axes it was given, in both directions.
        Also checks that the header lookup database (``sdb``) is only built
        when a pick database is supplied.
        """
        fig = plt.figure()
        # close the figure when this test ends; otherwise each test leaves
        # an open pyplot figure behind and they accumulate across the suite
        self.addCleanup(plt.close, fig)
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # should have the default parameters
        for attr in self.default_params:
            self.assertTrue(hasattr(splt, attr),
                            'missing default parameter: %s' % attr)
        # by default, should *not* have a header database
        self.assertFalse(hasattr(splt, 'sdb'))
        # should attach axes
        self.assertIsInstance(splt.ax, matplotlib.axes.Axes)
        # class should be able to update the axes
        splt.ax.plot([0, 1], [0, 1])
        self.assertEqual(len(splt.ax.lines), 1)
        # updates outside the class should be seen inside the class
        ax.plot([0, 1], [0, 1])
        self.assertEqual(len(splt.ax.lines), 2)
        self.assertEqual(len(ax.lines), 2)
        # if a pick database is given, should attach pickdb and build lookup db
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb)
        self.assertIsInstance(splt.pickdb, PickDatabaseConnection)
        self.assertIsInstance(splt.sdb, SEGYHeaderDatabase)
    
    def test_build_header_database(self):
        """
        Should connect to a database for looking up header attributes.

        This test covers four behaviors:

        - An in-memory header database does not create a file.
        - A file-backed header database is created only when a pick
          database is supplied.
        - Every primary key of the pick trace table is resolvable, either
          as an alias or as a header attribute.
        - The header table stores the raw SEG-Y header value for each of
          those keys.
        """
        fig = plt.figure()
        # close the figure when this test ends so figures do not pile up
        self.addCleanup(plt.close, fig)
        ax = fig.add_subplot(111)
        # connecting to a database in memory should not create a file
        splt = SEGYPlotManager(ax, self.segy, trace_header_database=':memory:')
        self.assertFalse(os.path.isfile(':memory:'))
        # connecting to a filename should create a file if it does not exist
        filename = 'temp.test_build_header_database.sqlite'
        if os.path.isfile(filename):
            os.remove(filename)
        # should not create database if no pickdb is given
        splt = SEGYPlotManager(ax, self.segy, trace_header_database=filename)
        self.assertFalse(os.path.isfile(filename))
        # should create database if pickdb is given
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb,
                               trace_header_database=filename)
        self.assertTrue(os.path.isfile(filename))
        # clean up
        os.remove(filename)
        # create plot manager in memory with pickdb
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb)
        # should be able to get primary fields from the picks trace table
        pick_keys = splt.pickdb._get_primary_fields(splt.pickdb.TRACE_TABLE)
        self.assertTrue(len(pick_keys) > 0)
        # pick primary keys should be in the alias dictionary
        # or be a header attribute
        header = self.segy.traces[0].header
        for k in pick_keys:
            self.assertTrue(k in splt.SEGY_HEADER_ALIASES
                            or hasattr(header, k))
        # should add header fields for primary keys in the pick database
        sql = 'SELECT * FROM %s' % splt.sdb.TRACE_TABLE
        # Materialize the result once. A DB-API cursor can only be iterated
        # a single time, so looping over it again for each key would
        # silently check only the first key.
        rows = list(splt.sdb.execute(sql))
        for key in pick_keys:
            _key = splt._get_header_alias(key)
            for i, row in enumerate(rows):
                segy_value = splt.get_header_value(
                    self.segy.traces[i].header, key, convert_units=False)
                db_value = row[_key]
                self.assertEqual(segy_value, db_value)

    def test_get_units(self):
        """
        ``_get_units`` should return a (segy units, plot units) pair for
        time and distance keys, and None for keys without units.
        """
        ax = plt.figure().add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        distance_marker = 'distance_unit_marker'
        time_marker = 'time_unit_marker'
        splt.DISTANCE_UNIT = distance_marker
        splt.TIME_UNIT = time_marker
        for key in TRACE_HEADER_KEYS:
            if key in splt.SEGY_TIME_UNITS:
                expected_plot_unit = time_marker
            elif key in splt.SEGY_DISTANCE_UNITS:
                expected_plot_unit = distance_marker
            else:
                # unitless keys get no unit pair at all
                self.assertEqual(splt._get_units(key), None)
                continue
            # second element of the pair is the configured plot unit
            self.assertEqual(splt._get_units(key)[1], expected_plot_unit)

    def test_convert_units(self):
        """
        ``_convert_units`` should rescale header values into the configured
        plot units for both distance and time keys.
        """
        splt = SEGYPlotManager(plt.figure().add_subplot(111), self.segy)
        # distance key: a header value of 1000 becomes 1 in 'km'
        splt.DISTANCE_UNIT = 'km'
        converted_offset = splt._convert_units('offset', [1000])
        self.assertEqual(converted_offset, [1])
        # time key: a header value of 1000 becomes 1 in 's'
        splt.TIME_UNIT = 's'
        converted_delay = splt._convert_units('delay', [1000])
        self.assertEqual(converted_delay, [1])

    def test_get_time_array(self):
        """
        ``get_time_array`` should produce one time value per sample, so its
        length matches the trace's ``npts`` header value.
        """
        ax = plt.figure().add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        first_header = self.segy.traces[0].header
        n_samples = splt.get_header_value(first_header, 'npts')
        times = splt.get_time_array(first_header)
        self.assertEqual(n_samples, len(times))

    def test_get_header_value(self):
        """
        ``get_header_value`` should resolve aliases to header attributes.

        With ``convert_units=False`` it returns the raw header value; by
        default it returns the value converted to plot units.
        """
        ax = plt.figure().add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        header = self.segy.traces[0].header
        # raw lookups: every alias must match direct attribute access
        for alias in splt.SEGY_HEADER_ALIASES:
            attr_name = splt.SEGY_HEADER_ALIASES[alias]
            raw = splt.get_header_value(header, alias, convert_units=False)
            self.assertEqual(raw, header.__getattribute__(attr_name))
        # converted lookup: 'offset' in km equals the raw value / 1000
        splt.DISTANCE_UNIT = 'km'
        offset_attr = splt.SEGY_HEADER_ALIASES['offset']
        in_km = splt.get_header_value(header, 'offset')
        raw_offset = header.__getattribute__(offset_attr)
        self.assertEqual(np.round(in_km, decimals=3), raw_offset / 1000.)

    def test_get_abscissa(self):
        """
        ``_get_abscissa`` should map (ensemble, trace) pairs to the values
        of the header attribute selected by ``ABSCISSA_KEY``.
        """
        ax = plt.figure().add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb)
        # only check a handful of traces
        trace_indices = [0, 100, 103, 104, 500, 550]
        headers = [self.segy.traces[i].header for i in trace_indices]
        lookup = [(hdr.ensemble_number, hdr.trace_number_within_the_ensemble)
                  for hdr in headers]
        abscissa = splt._get_abscissa(['ensemble', 'trace'], lookup)
        attr_name = splt._get_header_alias(splt.ABSCISSA_KEY)
        for pos, hdr in enumerate(headers):
            self.assertEqual(hdr.__getattribute__(attr_name), abscissa[pos])

    def test_manage_layers(self):
        """
        Should handle moving plot items around.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # should echo parameters if item is not in dicts
        self.assertTrue(splt._manage_layers(foobar=True)['foobar'])
        self.assertFalse(splt._manage_layers(foobar=False)['foobar'])
        # for active item and True, should do nothing
        splt.ACTIVE_LINES['foobar'] = ax.plot([0,1], [0,1])
        self.assertTrue('foobar' in splt.ACTIVE_LINES)
        self.assertFalse('foobar' in splt.INACTIVE_LINES)
        self.assertFalse(splt._manage_layers(foobar=True)['foobar'])
        self.assertTrue('foobar' in splt.ACTIVE_LINES)
        self.assertFalse('foobar' in splt.INACTIVE_LINES)
        # for active item and False, should move to inactive and return False
        self.assertFalse(splt._manage_layers(foobar=False)['foobar'])
        self.assertFalse('foobar' in splt.ACTIVE_LINES)
        self.assertTrue('foobar' in splt.INACTIVE_LINES)
        # for force_new=True, should remove from active and inactive and return
        # True
        # item is currently in inactive list
        need2plot = splt._manage_layers(force_new=True, foobar=True)
        self.assertTrue(need2plot['foobar'])
        self.assertFalse('foobar' in splt.ACTIVE_LINES)
        self.assertFalse('foobar' in splt.INACTIVE_LINES)
        # item is now in active list
        splt.ACTIVE_LINES['foobar'] = ax.plot([0,1], [0,1])
        need2plot = splt._manage_layers(force_new=True, foobar=True)
        self.assertTrue(need2plot['foobar'])
        self.assertFalse('foobar' in splt.ACTIVE_LINES)
        self.assertFalse('foobar' in splt.INACTIVE_LINES)