def test_copy(self):
    """
    Should create a copy of the database.
    """
    # Fill a fresh database with the benchmark picks.
    source_db = PickDatabaseConnection(':memory:')
    for pick_fields in uniq_picks:
        source_db.add_pick(**pick_fields)
    duplicate_db = source_db.copy()
    # The first row read back from each connection should compare equal.
    first_source = source_db.get_picks()[0]
    first_duplicate = duplicate_db.get_picks()[0]
    self.assertEqual(first_source, first_duplicate)
class MenusTestCase(unittest.TestCase):
    """
    Master suite for testing the menu classes.
    """

    def setUp(self):
        """
        Read the example SEG-Y file and fill a pick database for each test.
        """
        segy_path = get_example_file("ew0210_o30.segy")
        self.segy = readSEGY(segy_path, unpack_headers=True)
        # update_pick() is called once per benchmark pick
        self.pickdb = PickDatabaseConnection(":memory:")
        for pick_fields in uniq_picks:
            self.pickdb.update_pick(**pick_fields)
def setUp(self):
    """
    Read the example SEG-Y file and fill a pick database for each test.
    """
    example_path = get_example_file("ew0210_o30.segy")
    self.segy = readSEGY(example_path, unpack_headers=True)
    # Store every benchmark pick in the ":memory:" database.
    self.pickdb = PickDatabaseConnection(":memory:")
    for fields in uniq_picks:
        self.pickdb.update_pick(**fields)
def test_raytrace_branch_id(self):
    """
    Raytracing should honor branch ids

    One pick is stored for each of branches 1-3; tracing with a single
    branch selected should produce exactly one ray, and that ray should
    reach at least the depth of the interface ``vm.rf[branch - 1]``.
    """
    vmfile = get_example_file('inactive_layers.vm')
    pickdbfile = 'temp.sqlite'
    rayfile = 'temp.ray'
    # Start from a clean slate in case an earlier run left a database behind
    if os.path.isfile(pickdbfile):
        os.remove(pickdbfile)
    try:
        # Create pick database
        pickdb = PickDatabaseConnection(pickdbfile)
        pickdb.update_pick(event='P1', ensemble=100, trace=1,
                           branch=1, subbranch=0,
                           time=5.0,
                           source_x=10, source_y=0.0, source_z=0.006,
                           receiver_x=40, receiver_y=0.0, receiver_z=4.9)
        pickdb.update_pick(event='P2', ensemble=100, trace=1,
                           branch=2, subbranch=0, time=5.0)
        pickdb.update_pick(event='P3', ensemble=100, trace=1,
                           branch=3, subbranch=0, time=5.0)
        pickdb.commit()
        # Raytrace
        vm = readVM(vmfile)
        for branch in range(1, 4):
            # Remove the previous branch's output so the existence check
            # below cannot pass on a stale file
            if os.path.isfile(rayfile):
                os.remove(rayfile)
            raytrace(vmfile, pickdb, rayfile, branch=branch)
            # Should have created a rayfile
            self.assertTrue(os.path.isfile(rayfile))
            # Load rayfans
            rays = readRayfanGroup(rayfile)
            # Should have traced just one ray
            self.assertEqual(len(rays.rayfans), 1)
            rfn = rays.rayfans[0]
            self.assertEqual(len(rfn.paths), 1)
            # Rays should turn in specified layer
            zmax = max(p[2] for p in rfn.paths[0])
            self.assertGreaterEqual(zmax, vm.rf[branch - 1][0][0])
    finally:
        # cleanup, even when an assertion above fails
        for filename in [rayfile, pickdbfile]:
            if os.path.isfile(filename):
                os.remove(filename)
def test_plot_picks(self):
    """
    Should add picks to an axes and the layer managers.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    pickdb = PickDatabaseConnection(':memory:')
    for pick in uniq_picks:
        pickdb.update_pick(**pick)
    splt = SEGYPickPlotter(ax, self.segy, pickdb)
    # should add a single artist to the line dict. for each event
    splt.plot_picks()
    # Iterate the events of the database handed to the plotter.
    # assertEqual is required here: assertTrue(x, 1) would treat the 1 as
    # the failure message and never compare the count.
    for event in pickdb.events:
        self.assertEqual(len(splt.ACTIVE_LINES[event]), 1)
    # should be able to add new picks and have them be accessible by the
    # SEGYPickPlotter
    new_event = '--tracer--'
    # copy so that the shared benchmark pick is not modified
    new_pick = copy.copy(uniq_picks[0])
    new_pick['event'] = new_event
    pickdb.update_pick(**new_pick)
    splt.plot_picks()
    self.assertTrue(new_event in splt.ACTIVE_LINES)
def test_parallel_raytrace(self):
    """
    Should run raytracing in parallel

    Runs ``parallel_raytrace`` with 1, 2 and 8 processes against the same
    pick database. Each run must leave both working directories behind;
    removing them raises (and fails the test) if they are missing.
    """
    pickdbfile = 'temp.sqlite'
    input_dir = 'test.input'
    output_dir = 'test.output'
    # Start from a clean slate in case an earlier run left a database behind
    if os.path.isfile(pickdbfile):
        os.remove(pickdbfile)
    try:
        # Create pick database: three events, one per branch, three
        # ensembles each
        pickdb = PickDatabaseConnection(pickdbfile)
        for branch, event in enumerate(['P1', 'P2', 'P3'], 1):
            for ens in range(3):
                pickdb.update_pick(event=event, ensemble=ens,
                                   trace=1, branch=branch, subbranch=0,
                                   time=5.0, source_x=10, source_y=0.0,
                                   source_z=0.006, receiver_x=40,
                                   receiver_y=0.0, receiver_z=4.9)
        pickdb.commit()
        # set velocity model
        vmfile = get_example_file('inactive_layers.vm')
        # raytrace
        for nproc in [1, 2, 8]:
            parallel_raytrace(vmfile, pickdb, branches=[1, 2, 3],
                              input_dir=input_dir, output_dir=output_dir,
                              nproc=nproc, ensemble_field='ensemble')
            shutil.rmtree(input_dir)
            shutil.rmtree(output_dir)
    finally:
        # Remove whatever is left over, also when a run above failed
        for dirname in (input_dir, output_dir):
            shutil.rmtree(dirname, ignore_errors=True)
        if os.path.isfile(pickdbfile):
            os.remove(pickdbfile)
def test_locate_on_surface(self):
    """
    Should locate a receiver on a surface.

    For every model in ``TEST_MODELS`` a receiver is placed on surface
    ``iref`` at the mean x and y of the model, synthetic traveltimes are
    generated by raytracing, and ``locate_on_surface`` should recover the
    receiver position to within rounding at zero decimal places.
    """
    inst_id = 100
    dx = 1
    iref = 0
    rayfile = 'temp.ray'
    for _vmfile in TEST_MODELS:
        vmfile = get_example_file(_vmfile)
        vm = readVM(vmfile)
        # calculate synthetic times
        pickdb = PickDatabaseConnection(':memory:')
        x0 = np.mean(vm.x)
        y0 = np.mean(vm.y)
        xsearch = vm.xrange2i(max(vm.r1[0], x0 - dx),
                              min(vm.r2[0], x0 + dx))
        # The y index must come from the y-axis lookup (it was previously
        # taken from vm.x2i); it is the same for every source position.
        iy = vm.y2i([y0])[0]
        for i, ix in enumerate(xsearch):
            x = vm.x[ix]
            z0 = vm.rf[iref][ix][iy]
            pickdb.add_pick(event='Pw', ensemble=inst_id, trace=i,
                            time=1e30,
                            source_x=x, source_y=y0, source_z=0.006,
                            receiver_x=x0, receiver_y=y0, receiver_z=z0,
                            vm_branch=1, vm_subid=0)
        raytrace(vmfile, pickdb, rayfile)
        # NOTE(review): 'temp.syn.sqlite' is never removed, so later models
        # (and later test runs) open the database left by earlier ones --
        # confirm whether that is intended.
        raydb = rayfan2db(rayfile, 'temp.syn.sqlite', synthetic=True)
        os.remove(rayfile)
        # run locate
        x, y, z, rms = locate_on_surface(vmfile, raydb, 0, x0=x0, y0=y0,
                                         dx=dx, dy=dx)
        # compare result
        self.assertAlmostEqual(x, x0, 0)
        self.assertAlmostEqual(y, y0, 0)
def setUp(self):
    """
    Setup benchmark data.
    """
    example_path = get_example_file('ew0210_o30.segy')
    self.segy = readSEGY(example_path, unpack_headers=True)
    # Store every benchmark pick in the ':memory:' database.
    self.pickdb = PickDatabaseConnection(':memory:')
    for fields in uniq_picks:
        self.pickdb.update_pick(**fields)
    # Parameter names made available to the tests as a list of strings.
    self.default_params = [
        'ABSCISSA_KEY',
        'GAIN',
        'CLIP',
        'NORMALIZATION_METHOD',
        'OFFSET_GAIN_POWER',
        'WIGGLE_PEN_COLOR',
        'WIGGLE_PEN_WIDTH',
        'NEG_FILL_COLOR',
        'POS_FILL_COLOR',
        'DISTANCE_UNIT',
        'TIME_UNIT',
        'SEGY_TIME_UNITS',
        'SEGY_DISTANCE_UNITS',
        'SEGY_HEADER_ALIASES',
    ]
def test_get_events(self):
    """
    Should return a list of event names in the database.
    """
    pickdb = PickDatabaseConnection(':memory:')
    # Store every benchmark pick while collecting each distinct event name
    expected_events = []
    for fields in uniq_picks:
        pickdb.add_pick(**fields)
        name = fields['event']
        if name in expected_events:
            continue
        expected_events.append(name)
    # the private lookup should report the same names, in any order
    self.assertEqual(sorted(pickdb._get_events()), sorted(expected_events))
    # the public 'events' attribute should give what _get_events() returns
    self.assertEqual(pickdb.events, pickdb._get_events())
    # a database without picks should report no events
    empty_db = PickDatabaseConnection(':memory:')
    self.assertEqual(empty_db.events, [])
pickdb_file = 'synthetic.sqlite' # Cleanup from any earlier run for filename in [vmfile, temp_pickdb_file, rayfile, pickdb_file]: if os.path.isfile(filename): os.remove(filename) # Create a simple 1D model vm = VM(r1=(0, 0, 0), r2=(100, 0, 15), dx=0.5, dy=1, dz=0.1) vm.insert_interface(np.asarray([[z] for z in 2.0 * np.ones((vm.nx))])) vm.define_constant_layer_velocity(0, 1.5) vm.define_stretched_layer_velocities(1, [1.5, 4.0, 6.5, 8.0, 8.5]) vm.write(vmfile) # Create a pickdb to define source/receiver geometry pickdb = PickDatabaseConnection(temp_pickdb_file) sx = np.arange(1., 80., 1.) rx = [1.] for irec, _rx in enumerate(rx): for isrc, _sx in enumerate(sx): d = {'event': 'Pg', 'vm_branch': 1, 'vm_subid': 0, 'ensemble': irec + 10, 'trace': isrc + 5000, 'time': 0.0, 'time_reduced': 0.0, 'source_x': _sx, 'source_y': 0.0, 'source_z': 1.0, 'receiver_x': _rx,
from rockfish.picking.database import PickDatabaseConnection
from rockfish.vmtomo.vm import VMFile
from rockfish.vmtomo.rayfan import RayfanFile
from rockfish.vmtomo.raytracing import trace

cleanup = True

# Filenames for the pick database, velocity model, raytracer input
# directory and rayfan output.
# (alternative model: '../tests/data/cranis_northeast.00.00.vm')
pickdbfile = 'temp.pickdb.sqlite'
vmfile = '../tests/data/goc_l26.15.00.vm'
input_dir = 'temp.input'
rayfile = 'temp.ray'

# Build an example pick database holding a single pick
example_pick = {
    'event': 'Pg',
    'ensemble': 100,
    'trace': 1,
    'vm_branch': 3,
    'vm_subid': 0,
    'time': 5.0,
    'time_reduced': 5.0,
    'source_x': 50,
    'source_y': 0.0,
    'source_z': 0.006,
    'receiver_x': 100,
    'receiver_y': 0.0,
    'receiver_z': 2,
}
pickdb = PickDatabaseConnection(pickdbfile)
pickdb.add_pick(**example_pick)
pickdb.commit()

# Raytrace the model
trace(vmfile, pickdb, rayfile, input_dir=input_dir, cleanup=False)

# Plot the model in the first of three stacked subplots
fig = plt.figure()
ax = fig.add_subplot(311)
vm = VMFile(vmfile)
vm.plot(ax=ax)
class SEGYPlotterTestCase(unittest.TestCase):
    """
    Test cases for SEGYPlotter and SEGYPickPlotter
    """

    def setUp(self):
        """
        Setup benchmark data.
        """
        self.segy = readSEGY(get_example_file('ew0210_o30.segy'),
                             unpack_headers=True)
        self.pickdb = PickDatabaseConnection(':memory:')
        for pick in uniq_picks:
            self.pickdb.update_pick(**pick)

    def _check_fill_toggle(self, layer):
        """
        Toggle one wiggle-fill layer and check the patch bookkeeping.

        :param layer: Keyword understood by ``plot_wiggles`` and used as
            the key in the patch dictionaries ('negative_fills' or
            'positive_fills').
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)
        # should add a single artist to the active patch dict.
        splt.plot_wiggles(**{layer: True})
        self.assertEqual(len(splt.ACTIVE_PATCHES[layer]), 1)
        # same artist should be in the axes list
        self.assertTrue(splt.ACTIVE_PATCHES[layer][0] in splt.ax.patches)
        # should be able to directly remove it
        splt.ax.patches.remove(splt.ACTIVE_PATCHES[layer][0])
        self.assertEqual(len(splt.ax.patches), 0)
        # and re-add it
        splt.ax.patches.append(splt.ACTIVE_PATCHES[layer][0])
        self.assertEqual(len(splt.ax.patches), 1)
        # should move artist to the inactive patch dict.
        splt.plot_wiggles(**{layer: False})
        self.assertEqual(len(splt.INACTIVE_PATCHES[layer]), 1)
        self.assertEqual(len(splt.ax.patches), 0)
        # calling again should change nothing
        splt.plot_wiggles(**{layer: False})
        self.assertEqual(len(splt.INACTIVE_PATCHES[layer]), 1)
        self.assertEqual(len(splt.ax.patches), 0)
        # should move artist back to active patch dict.
        splt.plot_wiggles(**{layer: True})
        self.assertEqual(len(splt.ACTIVE_PATCHES[layer]), 1)
        self.assertEqual(len(splt.ax.patches), 1)
        # new plot items should be accessible through the orig. axes
        self.assertEqual(len(splt.ax.patches), len(ax.patches))

    def test_init_SEGYPlotter(self):
        """
        Should create a new instance of the SEGYPlotter.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)
        # should inherit from SEGYPlotManager
        for name, _ in inspect.getmembers(SEGYPlotManager):
            self.assertTrue(hasattr(splt, name))
        # should *not* build header lookup table
        self.assertFalse(hasattr(splt, 'sdb'))
        # should attach axes
        self.assertTrue(isinstance(splt.ax, matplotlib.axes.Axes))

    def test_plot_negative_wiggle_fills(self):
        """
        Should add negative wiggle fills to the axes.
        """
        self._check_fill_toggle('negative_fills')

    def test_plot_positive_wiggle_fills(self):
        """
        Should add positive wiggle fills to the axes.
        """
        self._check_fill_toggle('positive_fills')

    def test_plot_wiggle_traces(self):
        """
        Should add wiggle traces to the axes.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)
        # should add a single artist to the active line dict.
        splt.plot_wiggles(wiggle_traces=True)
        self.assertEqual(len(splt.ACTIVE_LINES['wiggle_traces']), 1)
        # same artist should be in the axes list
        self.assertTrue(splt.ACTIVE_LINES['wiggle_traces'][0]
                        in splt.ax.lines)
        # should be able to directly remove it
        splt.ax.lines.remove(splt.ACTIVE_LINES['wiggle_traces'][0])
        self.assertEqual(len(splt.ax.lines), 0)
        # and re-add it
        splt.ax.lines.append(splt.ACTIVE_LINES['wiggle_traces'][0])
        self.assertEqual(len(splt.ax.lines), 1)
        # should move artist to the inactive line dict.
        splt.plot_wiggles(wiggle_traces=False)
        self.assertEqual(len(splt.INACTIVE_LINES['wiggle_traces']), 1)
        self.assertEqual(len(splt.ax.lines), 0)
        # calling again should change nothing
        splt.plot_wiggles(wiggle_traces=False)
        self.assertEqual(len(splt.INACTIVE_LINES['wiggle_traces']), 1)
        self.assertEqual(len(splt.ax.lines), 0)
        # should move artist back to active line dict and re-add to axes
        splt.plot_wiggles(wiggle_traces=True)
        self.assertEqual(len(splt.ACTIVE_LINES['wiggle_traces']), 1)
        self.assertEqual(len(splt.ax.lines), 1)
        # new plot items should be accessible through the orig. axes
        self.assertEqual(len(splt.ax.lines), len(ax.lines))

    def test_link_axes(self):
        """
        Plotting should update items in the linked axes.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)
        # should add one artist to our axes
        splt.plot_wiggles(wiggle_traces=True)
        self.assertEqual(len(splt.ACTIVE_LINES['wiggle_traces']), 1)
        self.assertTrue('wiggle_traces' not in splt.INACTIVE_LINES)
        self.assertEqual(len(ax.lines), 1)
        # should remove one artist from our axes
        splt.plot_wiggles(wiggle_traces=False)
        self.assertTrue('wiggle_traces' not in splt.ACTIVE_LINES)
        self.assertEqual(len(splt.INACTIVE_LINES['wiggle_traces']), 1)
        self.assertEqual(len(ax.lines), 0)

    def test_init_SEGYPickPlotter(self):
        """
        Should create a new instance of the SEGYPickPlotter.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPickPlotter(ax, self.segy, pickdb=self.pickdb)
        # should inherit from SEGYPlotter
        for name, _ in inspect.getmembers(SEGYPlotter):
            self.assertTrue(hasattr(splt, name))
        # should build header lookup table
        self.assertTrue(isinstance(splt.sdb, SEGYHeaderDatabase))
        # should attach axes
        self.assertTrue(isinstance(splt.ax, matplotlib.axes.Axes))

    def test_plot_picks(self):
        """
        Should add picks to an axes and the layer managers.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        pickdb = PickDatabaseConnection(':memory:')
        for pick in uniq_picks:
            pickdb.update_pick(**pick)
        splt = SEGYPickPlotter(ax, self.segy, pickdb)
        # should add a single artist to the line dict. for each event
        splt.plot_picks()
        # Iterate the events of the database handed to the plotter.
        # assertEqual is required here: assertTrue(x, 1) would treat the 1
        # as the failure message and never compare the count.
        for event in pickdb.events:
            self.assertEqual(len(splt.ACTIVE_LINES[event]), 1)
        # should be able to add new picks and have them be accessible by the
        # SEGYPickPlotter
        new_event = '--tracer--'
        # copy so that the shared benchmark pick is not modified
        new_pick = copy.copy(uniq_picks[0])
        new_pick['event'] = new_event
        pickdb.update_pick(**new_pick)
        splt.plot_picks()
        self.assertTrue(new_event in splt.ACTIVE_LINES)
def rayfan2db(rayfan_file, raydb_file=':memory:', synthetic=False,
              noise=None, pickdb=None, raypaths=False):
    """
    Read a rayfan file and store its data in a database.

    Data are stored in a modified version of a
    :class:`rockfish.picking.database.PickDatabaseConnection`.

    Parameters
    ----------
    rayfan_file: {str, file}
        An open file-like object or a string which is assumed to be a
        filename of a rayfan binary file.
    raydb_file: str, optional
        The filename of the new database. Default is to create a new
        database in memory.
    synthetic: bool, optional
        Determines whether or not to record traced traveltimes as picked
        traveltimes. If ``True``, the stored time is the traced traveltime
        (plus noise), no predicted time is stored and the residual is zero.
    noise: {float, None}
        Maximum amplitude of random noise to add to the travel times.
        If ``None``, no noise is added.
    pickdb: :class:`rockfish.database.PickDatabaseConnection`, optional
        An active connection to the pick database that was used to trace
        rays in ``rayfan_file``. Values for extra fields
        (e.g., 'trace_in_file') are copied from this database to the new
        database along with rayfan data. Default is to ignore these extra
        fields.
    raypaths: bool, optional
        If ``True``, raypath coordinates are stored as text in a new
        table 'raypaths'.

    Returns
    -------
    raydb: PickDatabaseConnection
        Connection to the database holding the rayfan data. A warning is
        issued if the number of rows added to the picks table differs
        from the number of rays in the file.
    """
    raydb = PickDatabaseConnection(raydb_file)
    rays = readRayfanGroup(rayfan_file)
    # Single-argument call form prints identically under Python 2 and 3
    print("Adding {:} traveltimes to {:} ..."
          .format(rays.nrays, raydb_file))
    # add fields for raypaths
    if raypaths:
        raydb._create_table_if_not_exists(RAYPATH_TABLE, RAYPATH_FIELDS)
    # Row count before adding anything, so that only new rows are counted
    ndb0 = raydb.execute('SELECT COUNT(rowid) FROM picks').fetchone()[0]
    for rfn in rays.rayfans:
        for i, traveltime in enumerate(rfn.travel_times):
            if noise is not None:
                # uniform noise in [-noise, noise)
                _noise = noise * 2 * (np.random.random() - 0.5)
            else:
                _noise = 0.0
            # first and last raypath points are stored as the source and
            # receiver positions
            sx, sy, sz = rfn.paths[i][0]
            rx, ry, rz = rfn.paths[i][-1]
            if pickdb is not None:
                event = pickdb.vmbranch2event[rfn.event_ids[i]]
            else:
                event = rfn.event_ids[i]
            if synthetic:
                pick_time = traveltime + _noise
                predicted = None
                residual = 0.
            else:
                pick_time = rfn.pick_times[i]
                predicted = traveltime
                residual = rfn.residuals[i]
            d = {'event': event,
                 'ensemble': rfn.start_point_id,
                 'trace': rfn.end_point_ids[i],
                 'vm_branch': rfn.event_ids[i],
                 'vm_subid': rfn.event_subids[i],
                 'time': pick_time,
                 'time_reduced': pick_time,
                 'predicted': predicted,
                 'residual': residual,
                 'error': rfn.pick_errors[i],
                 'source_x': sx,
                 'source_y': sy,
                 'source_z': sz,
                 'receiver_x': rx,
                 'receiver_y': ry,
                 'receiver_z': rz,
                 'offset': rfn.offsets[i],
                 'faz': rfn.azimuths[i],
                 'method': 'rayfan2db({:})'.format(rayfan_file),
                 'data_file': rays.file.name}
            # Copy data from pickdb
            if pickdb is not None:
                pick = pickdb.get_picks(event=[event],
                                        ensemble=[d['ensemble']],
                                        trace=[d['trace']])
                if len(pick) > 0:
                    for f in ['trace_in_file', 'line', 'site', 'data_file']:
                        try:
                            d[f] = pick[0][f]
                        except (KeyError, IndexError):
                            # sqlite3.Row raises IndexError (not KeyError)
                            # for a column that is not in the row
                            pass
            # add data to standard tables
            raydb.update_pick(**d)
            # add raypath data to new table
            if raypaths:
                # NOTE(review): bottom_points is read from the group
                # (``rays``) but indexed with the per-rayfan ray index
                # ``i`` -- confirm this should not be ``rfn``.
                path = rfn.paths[i]
                raypath_row = {
                    'event': event,
                    'ensemble': rfn.start_point_id,
                    'trace': rfn.end_point_ids[i],
                    'ray_btm_x': rays.bottom_points[i][0],
                    'ray_btm_y': rays.bottom_points[i][1],
                    'ray_btm_z': rays.bottom_points[i][2],
                    'ray_x': str([p[0] for p in path]),
                    'ray_y': str([p[1] for p in path]),
                    'ray_z': str([p[2] for p in path])}
                raydb._insert(RAYPATH_TABLE, **raypath_row)
    raydb.commit()
    ndb = raydb.execute('SELECT COUNT(rowid) FROM picks').fetchone()[0]
    if (ndb - ndb0) != rays.nrays:
        msg = 'Only added {:} of {:} travel times to the database.'\
                .format(ndb - ndb0, rays.nrays)
        warnings.warn(msg)
    return raydb
def test_add_remove_picks(self):
    """
    Should add a pick to the picks table.
    """
    pickdb = PickDatabaseConnection(':memory:')

    def count_rows():
        # number of rows currently in the picks table
        return len(pickdb.execute('SELECT * FROM picks').fetchall())

    # every benchmark pick should be stored
    for pick in uniq_picks:
        pickdb.add_pick(**pick)
    pickdb.commit()
    self.assertEqual(count_rows(), len(uniq_picks))
    # a pick lacking the primary fields should be rejected
    with self.assertRaises(sqlite3.IntegrityError):
        pickdb.add_pick(event='Foobar', time=9.834)
    # so should a pick that is already stored
    for pick in uniq_picks:
        with self.assertRaises(sqlite3.IntegrityError):
            pickdb.add_pick(**pick)
    # removing a pick and adding it straight back should work
    for pick in uniq_picks:
        pickdb.remove_pick(**pick)
        pickdb.add_pick(**pick)
    pickdb.commit()
    self.assertEqual(count_rows(), len(uniq_picks))
    # invalid column names should raise OperationalError
    with self.assertRaises(sqlite3.OperationalError):
        pickdb.remove_pick(not_a_field=999)
    with self.assertRaises(sqlite3.OperationalError):
        pickdb.add_pick(not_a_field=999)
    # `pick` is still the last pick from the loop above: remove it ...
    pickdb.remove_pick(**pick)
    # ... and removing it a second time should do nothing
    pickdb.remove_pick(**pick)
    # update_pick should add missing picks and update existing ones
    for pick in uniq_picks:
        pickdb.update_pick(**pick)
    pickdb.commit()
    self.assertEqual(count_rows(), len(uniq_picks))
    # a pick without the required fields should be rejected
    with self.assertRaises(sqlite3.IntegrityError):
        pickdb.add_pick(event='Pg', ensemble=1, trace=1)
def test_get_picks(self):
    """
    Should return data rows.
    """
    pickdb = PickDatabaseConnection(':memory:')
    # store the benchmark picks and remember their times
    expected_times = []
    for pick in uniq_picks:
        pickdb.add_pick(**pick)
        expected_times.append(pick['time'])
    pickdb.commit()
    # rows should be sqlite3.Row objects (`pick` is still the last pick
    # from the loop above)
    for row in pickdb.get_picks(event=pick['event']):
        self.assertTrue(isinstance(row, sqlite3.Row))
    # looking up each pick by its own fields should give back its time
    stored_times = []
    for pick in uniq_picks:
        first_match = pickdb.get_picks(**pick)[0]
        stored_times.append(first_match['time'])
    self.assertEqual(sorted(expected_times), sorted(stored_times))
    # without arguments, all rows should be returned
    self.assertEqual(len(uniq_picks), len(pickdb.get_picks()))
    # no matches should give an empty result
    self.assertEqual(len(pickdb.get_picks(event='golly_gee')), 0)
    # and so should a database without data
    empty_db = PickDatabaseConnection(':memory:')
    self.assertEqual(len(empty_db.get_picks()), 0)
def test_counts(self):
    """
    Count functions should return number of distinct rows in tables.
    """
    pickdb = PickDatabaseConnection(':memory:')
    # add a single pick
    event = 'TestCount'
    # Work on a copy: the event and ensemble are changed below, and editing
    # the shared benchmark pick in place would leak into other tests.
    pick = dict(uniq_picks[0])
    pick['event'] = event
    pickdb.add_pick(**pick)
    pickdb.commit()
    # should have a single pick for this event in picks table
    n = pickdb.count_picks_by_event(event)
    self.assertEqual(n, 1)
    # should also have a single entry in the event table
    n = pickdb._count_distinct(pickdb.EVENT_TABLE, event=event)
    self.assertEqual(n, 1)
    # now, another pick for the same event
    pick['ensemble'] += 1  # ensure unique
    pickdb.add_pick(**pick)
    pickdb.commit()
    # should have two picks for this event in the picks table
    n = pickdb.count_picks_by_event(event)
    self.assertEqual(n, 2)
    # should still just have one entry in the event table for this event
    n = pickdb._count_distinct(pickdb.EVENT_TABLE, event=event)
    self.assertEqual(n, 1)
def locate_on_surface(vmfile, pickdb, iref, pick_keys=None, x0=None, y0=None,
                      dx=None, dy=None, plot=False):
    """
    Locate a receiver on a surface.

    Every node of the search grid is treated as a candidate receiver
    position on surface ``iref``. All selected picks are copied once per
    candidate, the whole set is raytraced in a single run, and the
    candidate whose rayfan has the smallest RMS is returned.

    :param vmfile: Filename of the VM Tomography slowness model to
        raytrace.
    :param pickdb: :class:`rockfish.picking.database.PickDatabaseConnection`
        to get picks from for raytracing.
    :param iref: index of the surface to move the receiver on
    :param pick_keys: Optional. ``dict`` of keys and values to use for
        selecting picks from ``pickdb``.  Default is to use all picks.
    :param x0, y0: Optional. Initial guess at x and y locations. Default is
        center of the model.
    :param dx, dy: Optional. Distance in x and y to search from
        ``x0, y0``. Default is to search the entire model.
    :param plot: Optional. Plot results of the location. Default is false.
    :returns: ``(x, y, z, rms)`` of the best-fitting candidate position.
    """
    if pick_keys is None:
        pick_keys = {}
    # Load model
    print("Loading VM model...")
    vm = readVM(vmfile)
    # Setup search domain
    print("Setting up search domain...")
    if x0 is None:
        x0 = np.mean(vm.x)
    if y0 is None:
        y0 = np.mean(vm.y)
    if dx is None:
        xsearch = vm.x
    else:
        # clip the search window to the model bounds
        ix = vm.xrange2i(max(vm.r1[0], x0 - dx), min(vm.r2[0], x0 + dx))
        xsearch = vm.x[ix]
    if dy is None:
        ysearch = vm.y
    else:
        iy = vm.yrange2i(max(vm.r1[1], y0 - dy), min(vm.r2[1], y0 + dy))
        ysearch = vm.y[iy]
    # trace each point in the search region
    print("Building traveltime database for {:}x{:} search grid..."
          .format(len(xsearch), len(ysearch)))
    zz = vm.rf[iref]
    # The selection does not depend on the grid node, so query it once
    picks = pickdb.get_picks(**pick_keys)
    population = []
    db = PickDatabaseConnection(':memory:')
    for x in xsearch:
        for y in ysearch:
            z = zz[vm.x2i([x])[0], vm.y2i([y])[0]]
            # The ensemble number doubles as the index into ``population``
            ipop = len(population)
            for _p in picks:
                p = dict(_p)
                p['ensemble'] = ipop
                p['receiver_x'] = x
                p['receiver_y'] = y
                p['receiver_z'] = z
                db.add_pick(**p)
            population.append((x, y, z))
    print("Raytracing...")
    rayfile = '.locate.ray'
    raytrace(vmfile, db, rayfile, stderr=subprocess.STDOUT,
             stdout=subprocess.PIPE)
    rays = readRayfanGroup(rayfile)
    rms_values = [rfn.rms for rfn in rays.rayfans]
    imin = np.argmin(rms_values)
    # start_point_id is used as the index of the best candidate
    ipop = rays.rayfans[imin].start_point_id
    rms = rms_values[imin]
    xfit, yfit, zfit = population[ipop]
    # TODO: warn when the best-fit point lies on the edge of the search
    # domain (a larger search region should then be tried).
    # plot
    if plot:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        if vm.ny == 1:
            # single y node: RMS as a function of x
            ax.plot([p[0] for p in population], rms_values, '.k')
            ax.plot(xfit, rms, '*r')
            plt.xlabel('x (km)')
            plt.ylabel('RMS error (s)')
        else:
            # map view: candidates, best fit and initial guess
            ax.plot([p[0] for p in population],
                    [p[1] for p in population], '.k')
            ax.plot(xfit, yfit, '*r')
            ax.plot(x0, y0, 'og')
            plt.xlabel('x (km)')
            plt.ylabel('y (km)')
        plt.title('Results of locate_on_surface()')
        plt.show()
    return xfit, yfit, zfit, rms
class SEGYPlotManagerTestCase(unittest.TestCase):
    """
    Test cases for SEGYPlotManager
    """

    def setUp(self):
        """
        Setup benchmark data.
        """
        self.segy = readSEGY(get_example_file('ew0210_o30.segy'),
                             unpack_headers=True)
        self.pickdb = PickDatabaseConnection(':memory:')
        for pick in uniq_picks:
            self.pickdb.update_pick(**pick)
        # attributes every new plot manager is expected to have
        self.default_params = [
            'ABSCISSA_KEY', 'GAIN', 'CLIP',
            'NORMALIZATION_METHOD', 'OFFSET_GAIN_POWER',
            'WIGGLE_PEN_COLOR', 'WIGGLE_PEN_WIDTH',
            'NEG_FILL_COLOR', 'POS_FILL_COLOR', 'DISTANCE_UNIT',
            'TIME_UNIT', 'SEGY_TIME_UNITS', 'SEGY_DISTANCE_UNITS',
            'SEGY_HEADER_ALIASES']

    def test_init_SEGYPlotManager(self):
        """
        Should create a new instance of the SEGYPlotManager.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # should have the default parameters
        for attr in self.default_params:
            self.assertTrue(hasattr(splt, attr))
        # by default, should *not* have header database
        self.assertFalse(hasattr(splt, 'sdb'))
        # should attach axes
        self.assertTrue(isinstance(splt.ax, matplotlib.axes.Axes))
        # class should be able to update the axes
        splt.ax.plot([0, 1], [0, 1])
        self.assertEqual(len(splt.ax.lines), 1)
        # updates outside the class should be seen inside the class
        ax.plot([0, 1], [0, 1])
        self.assertEqual(len(splt.ax.lines), 2)
        self.assertEqual(len(ax.lines), 2)
        # if a pick database is given, should attach pickdb and build lookup db
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb)
        self.assertTrue(isinstance(splt.pickdb, PickDatabaseConnection))
        self.assertTrue(isinstance(splt.sdb, SEGYHeaderDatabase))

    def test_build_header_database(self):
        """
        Should connect to a database for looking up header attributes.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        # connecting to a database in memory should not create a file
        splt = SEGYPlotManager(ax, self.segy,
                               trace_header_database=':memory:')
        self.assertFalse(os.path.isfile(':memory:'))
        # connecting to a filename should create a file if it does not exist
        filename = 'temp.test_build_header_database.sqlite'
        if os.path.isfile(filename):
            os.remove(filename)
        try:
            # should not create database if no pickdb is given
            splt = SEGYPlotManager(ax, self.segy,
                                   trace_header_database=filename)
            self.assertFalse(os.path.isfile(filename))
            # should create database if pickdb is given
            splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb,
                                   trace_header_database=filename)
            self.assertTrue(os.path.isfile(filename))
        finally:
            # clean up, even when an assertion above fails
            if os.path.isfile(filename):
                os.remove(filename)
        # create plot manager in memory with pickdb
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb)
        # should be able to get primary fields from the picks trace table
        pick_keys = splt.pickdb._get_primary_fields(splt.pickdb.TRACE_TABLE)
        self.assertTrue(len(pick_keys) > 0)
        # pick primary keys should be in the alias dictionary
        # or be a header attribute
        header = self.segy.traces[0].header
        for k in pick_keys:
            self.assertTrue(k in splt.SEGY_HEADER_ALIASES
                            or hasattr(header, k))
        # should add header fields for primary keys in the pick database
        sql = 'SELECT * FROM %s' % splt.sdb.TRACE_TABLE
        # Materialise the rows: a result that can only be iterated once
        # would be used up by the first key, leaving every later key
        # unchecked.
        rows = list(splt.sdb.execute(sql))
        for key in pick_keys:
            _key = splt._get_header_alias(key)
            for i, row in enumerate(rows):
                segy_value = splt.get_header_value(
                    self.segy.traces[i].header, key, convert_units=False)
                db_value = row[_key]
                self.assertEqual(segy_value, db_value)

    def test_get_units(self):
        """
        Should return (segy units, plot units) or None.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # marker strings make the returned plot units easy to recognise
        splt.DISTANCE_UNIT = 'distance_unit_marker'
        splt.TIME_UNIT = 'time_unit_marker'
        for key in TRACE_HEADER_KEYS:
            if key in splt.SEGY_TIME_UNITS:
                # should return TIME_UNIT for a time attribute
                self.assertEqual(splt._get_units(key)[1],
                                 'time_unit_marker')
            elif key in splt.SEGY_DISTANCE_UNITS:
                # should return DISTANCE_UNIT for a distance attribute
                self.assertEqual(splt._get_units(key)[1],
                                 'distance_unit_marker')
            else:
                # should return None for values that are unitless
                self.assertEqual(splt._get_units(key), None)

    def test_convert_units(self):
        """
        Should convert header units to plot units.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # should correctly perform unit conversions for distance
        splt.DISTANCE_UNIT = 'km'
        self.assertEqual(splt._convert_units('offset', [1000]), [1])
        # should correctly perform unit conversions for time
        splt.TIME_UNIT = 's'
        self.assertEqual(splt._convert_units('delay', [1000]), [1])

    def test_get_time_array(self):
        """
        Should return an array of time values for a single trace.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        tr = self.segy.traces[0]
        # should have npts values
        npts = splt.get_header_value(tr.header, 'npts')
        t = splt.get_time_array(tr.header)
        self.assertEqual(npts, len(t))

    def test_get_header_value(self):
        """
        Should be able to use aliases to get unit-converted values from
        headers.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # should return exact header values when convert_units=False
        tr = self.segy.traces[0]
        for alias in splt.SEGY_HEADER_ALIASES:
            key = splt.SEGY_HEADER_ALIASES[alias]
            value = splt.get_header_value(tr.header, alias,
                                          convert_units=False)
            _value = getattr(tr.header, key)
            self.assertEqual(value, _value)
        # default should return header values in the plot units
        splt.DISTANCE_UNIT = 'km'
        alias = 'offset'
        key = splt.SEGY_HEADER_ALIASES[alias]
        scaled_value = splt.get_header_value(tr.header, alias)
        header_value = getattr(tr.header, key)
        self.assertEqual(np.round(scaled_value, decimals=3),
                         header_value / 1000.)

    def test_get_abscissa(self):
        """
        Should return a list of abscissa values.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb)
        # just get abscissa for a subset of the traces
        idx = [0, 100, 103, 104, 500, 550]
        values = []
        for i in idx:
            header = self.segy.traces[i].header
            ensemble = header.ensemble_number
            trace = header.trace_number_within_the_ensemble
            values.append((ensemble, trace))
        # should return a list of values for plot x-axis
        x = splt._get_abscissa(['ensemble', 'trace'], values)
        _key = splt._get_header_alias(splt.ABSCISSA_KEY)
        for j, i in enumerate(idx):
            tr = self.segy.traces[i]
            self.assertEqual(getattr(tr.header, _key), x[j])

    def test_manage_layers(self):
        """
        Should handle moving plot items around.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # should echo parameters if item is not in dicts
        self.assertTrue(splt._manage_layers(foobar=True)['foobar'])
        self.assertFalse(splt._manage_layers(foobar=False)['foobar'])
        # for active item and True, should do nothing
        splt.ACTIVE_LINES['foobar'] = ax.plot([0, 1], [0, 1])
        self.assertTrue('foobar' in splt.ACTIVE_LINES)
        self.assertFalse('foobar' in splt.INACTIVE_LINES)
        self.assertFalse(splt._manage_layers(foobar=True)['foobar'])
        self.assertTrue('foobar' in splt.ACTIVE_LINES)
        self.assertFalse('foobar' in splt.INACTIVE_LINES)
        # for active item and False, should move to inactive and return False
        self.assertFalse(splt._manage_layers(foobar=False)['foobar'])
        self.assertFalse('foobar' in splt.ACTIVE_LINES)
        self.assertTrue('foobar' in splt.INACTIVE_LINES)
        # for force_new=True, should remove from active and inactive and
        # return True
        # item is currently in inactive list
        need2plot = splt._manage_layers(force_new=True, foobar=True)
        self.assertTrue(need2plot['foobar'])
        self.assertFalse('foobar' in splt.ACTIVE_LINES)
        self.assertFalse('foobar' in splt.INACTIVE_LINES)
        # item is now in active list
        splt.ACTIVE_LINES['foobar'] = ax.plot([0, 1], [0, 1])
        need2plot = splt._manage_layers(force_new=True, foobar=True)
        self.assertTrue(need2plot['foobar'])
        self.assertFalse('foobar' in splt.ACTIVE_LINES)
        self.assertFalse('foobar' in splt.INACTIVE_LINES)