class MenusTestCase(unittest.TestCase):
    """
    Master suite for testing the menu classes.
    """
    def setUp(self):
        """
        Setup benchmark data.
        """
        self.segy = readSEGY(get_example_file("ew0210_o30.segy"),
                             unpack_headers=True)
        self.pickdb = PickDatabaseConnection(":memory:")
        for pick in uniq_picks:
            self.pickdb.update_pick(**pick)

    def test_add_remove_picks(self):
        """
        Should add a pick to the picks table.
        """
        pickdb = PickDatabaseConnection(':memory:')
        # should add all picks to the database
        for pick in uniq_picks:
            pickdb.add_pick(**pick)
        pickdb.commit()
        ndb = len(pickdb.execute('SELECT * FROM picks').fetchall())
        self.assertEqual(ndb, len(uniq_picks))
        # attempting to add pick without primary fields should raise an error
        with self.assertRaises(sqlite3.IntegrityError):
            pickdb.add_pick(event='Foobar', time=9.834)
        # attempting to add duplicate picks should raise error
        for pick in uniq_picks:
            with self.assertRaises(sqlite3.IntegrityError):
                pickdb.add_pick(**pick)
        # directly removing pick and then re-adding should work
        for pick in uniq_picks:
            pickdb.remove_pick(**pick)
            pickdb.add_pick(**pick)
        pickdb.commit()
        ndb = len(pickdb.execute('SELECT * FROM picks').fetchall())
        self.assertEqual(ndb, len(uniq_picks))
        # invalid column names should raise OperationalError
        with self.assertRaises(sqlite3.OperationalError):
            pickdb.remove_pick(not_a_field=999)
        with self.assertRaises(sqlite3.OperationalError):
            pickdb.add_pick(not_a_field=999)
        # remove the last pick that we added (the final loop value above)
        last_pick = uniq_picks[-1]
        pickdb.remove_pick(**last_pick)
        # attempting to remove a non-existent pick should do nothing
        pickdb.remove_pick(**last_pick)
        # updating picks should add picks if they don't exist and update picks
        # if they do exist
        for pick in uniq_picks:
            pickdb.update_pick(**pick)
        pickdb.commit()
        ndb = len(pickdb.execute('SELECT * FROM picks').fetchall())
        self.assertEqual(ndb, len(uniq_picks))
        # should not be able to add pick without required fields
        with self.assertRaises(sqlite3.IntegrityError):
            pickdb.add_pick(event='Pg', ensemble=1, trace=1)

    def test_plot_picks(self):
        """
        Should add picks to an axes and the layer managers.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        pickdb = PickDatabaseConnection(':memory:')
        for pick in uniq_picks:
            pickdb.update_pick(**pick)
        splt = SEGYPickPlotter(ax, self.segy, pickdb)
        # should add a single artist to the line dict. for each event
        # (iterate the database actually being plotted; assertEqual, because
        # assertTrue(x, 1) treats the 1 as a message and never checks count)
        splt.plot_picks()
        for event in pickdb.events:
            self.assertEqual(len(splt.ACTIVE_LINES[event]), 1)
        # should be able to add new picks and have them be accessible by the
        # SEGYPickPlotter
        new_event = '--tracer--'
        new_pick = copy.copy(uniq_picks[0])
        new_pick['event'] = new_event
        pickdb.update_pick(**new_pick)
        splt.plot_picks()
        self.assertIn(new_event, splt.ACTIVE_LINES)

    def test_parallel_raytrace(self):
        """
        Should run raytracing in parallel
        """
        # Create pick database
        pickdbfile = 'temp.sqlite'
        if os.path.isfile(pickdbfile):
            os.remove(pickdbfile)
        pickdb = PickDatabaseConnection(pickdbfile)
        input_dir = 'test.input'
        output_dir = 'test.output'
        try:
            # three events, each on its own branch (P1 -> 1, P2 -> 2, ...)
            for branch, event in enumerate(['P1', 'P2', 'P3'], 1):
                for ens in range(3):
                    pickdb.update_pick(event=event, ensemble=ens, trace=1,
                                       branch=branch, subbranch=0, time=5.0,
                                       source_x=10, source_y=0.0,
                                       source_z=0.006, receiver_x=40,
                                       receiver_y=0.0, receiver_z=4.9)
            pickdb.commit()
            # set velocity model
            vmfile = get_example_file('inactive_layers.vm')
            # raytrace with several process counts
            for nproc in [1, 2, 8]:
                try:
                    parallel_raytrace(vmfile, pickdb, branches=[1, 2, 3],
                                      input_dir=input_dir,
                                      output_dir=output_dir, nproc=nproc,
                                      ensemble_field='ensemble')
                    # the working directories should have been created
                    self.assertTrue(os.path.isdir(input_dir))
                    self.assertTrue(os.path.isdir(output_dir))
                finally:
                    for path in (input_dir, output_dir):
                        if os.path.isdir(path):
                            shutil.rmtree(path)
        finally:
            # close before deleting so the file is not held open
            pickdb.close()
            if os.path.isfile(pickdbfile):
                os.remove(pickdbfile)

    def test_raytrace_branch_id(self):
        """
        Raytracing should honor branch ids
        """
        vmfile = get_example_file('inactive_layers.vm')
        pickdbfile = 'temp.sqlite'
        rayfile = 'temp.ray'
        # Create pick database
        if os.path.isfile(pickdbfile):
            os.remove(pickdbfile)
        pickdb = PickDatabaseConnection(pickdbfile)
        try:
            pickdb.update_pick(event='P1', ensemble=100, trace=1, branch=1,
                               subbranch=0, time=5.0, source_x=10,
                               source_y=0.0, source_z=0.006, receiver_x=40,
                               receiver_y=0.0, receiver_z=4.9)
            pickdb.update_pick(event='P2', ensemble=100, trace=1, branch=2,
                               subbranch=0, time=5.0)
            pickdb.update_pick(event='P3', ensemble=100, trace=1, branch=3,
                               subbranch=0, time=5.0)
            pickdb.commit()
            # Raytrace each branch separately
            vm = readVM(vmfile)
            for branch in range(1, 4):
                if os.path.isfile(rayfile):
                    os.remove(rayfile)
                raytrace(vmfile, pickdb, rayfile, branch=branch)
                # Should have created a rayfile
                self.assertTrue(os.path.isfile(rayfile))
                # Load rayfans
                rays = readRayfanGroup(rayfile)
                # Should have traced just one ray
                self.assertEqual(len(rays.rayfans), 1)
                rfn = rays.rayfans[0]
                self.assertEqual(len(rfn.paths), 1)
                # Rays should turn in specified layer: the deepest path point
                # must reach vm.rf[branch - 1][0][0]
                zmax = max(p[2] for p in rfn.paths[0])
                self.assertGreaterEqual(zmax, vm.rf[branch - 1][0][0])
        finally:
            # cleanup; close the database first so it can be removed
            pickdb.close()
            for filename in [rayfile, pickdbfile]:
                if os.path.isfile(filename):
                    os.remove(filename)
for isrc, _sx in enumerate(sx): d = {'event': 'Pg', 'vm_branch': 1, 'vm_subid': 0, 'ensemble': irec + 10, 'trace': isrc + 5000, 'time': 0.0, 'time_reduced': 0.0, 'source_x': _sx, 'source_y': 0.0, 'source_z': 1.0, 'receiver_x': _rx, 'receiver_y': 0.0, 'receiver_z': 2.0, 'offset' : np.abs(_rx - _sx)} pickdb.update_pick(**d) pickdb.commit() # Raytrace with these picks raytrace(vmfile, pickdb, rayfile) pickdb.close() # Transfer traced to a picks to a new pick database and add noise pickdb = rayfan2db(rayfile, pickdb_file, synthetic=True, noise=0.02) # Raytrace with the new pick database raytrace(vmfile, pickdb, rayfile) # Plot the traced rays and traveltimes fig = plt.figure() ax = fig.add_subplot(211)
class SEGYPlotterTestCase(unittest.TestCase):
    """
    Test cases for SEGYPlotter and SEGYPickPlotter
    """
    def setUp(self):
        """
        Setup benchmark data.
        """
        self.segy = readSEGY(get_example_file('ew0210_o30.segy'),
                             unpack_headers=True)
        self.pickdb = PickDatabaseConnection(':memory:')
        for pick in uniq_picks:
            self.pickdb.update_pick(**pick)

    def test_init_SEGYPlotter(self):
        """
        Should create a new instance of the SEGYPlotter.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)
        # should inherit from SEGYPlotManager
        for name, _ in inspect.getmembers(SEGYPlotManager):
            self.assertTrue(hasattr(splt, name))
        # should *not* build header lookup table
        self.assertFalse(hasattr(splt, 'sdb'))
        # should attach axes
        self.assertTrue(isinstance(splt.ax, matplotlib.axes.Axes))

    def _check_wiggle_fill_toggle(self, fill_key):
        """
        Toggle the ``fill_key`` wiggle-fill layer and check that its single
        artist moves between ACTIVE_PATCHES/INACTIVE_PATCHES and the axes.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)
        # should add a single artist to the active patch dict.
        splt.plot_wiggles(**{fill_key: True})
        self.assertEqual(len(splt.ACTIVE_PATCHES[fill_key]), 1)
        artist = splt.ACTIVE_PATCHES[fill_key][0]
        # same artist should be in the axes list
        self.assertIn(artist, splt.ax.patches)
        # should be able to directly remove it
        splt.ax.patches.remove(artist)
        self.assertEqual(len(splt.ax.patches), 0)
        # and re-add it
        splt.ax.patches.append(artist)
        self.assertEqual(len(splt.ax.patches), 1)
        # should move artist to the inactive patch dict.
        splt.plot_wiggles(**{fill_key: False})
        self.assertEqual(len(splt.INACTIVE_PATCHES[fill_key]), 1)
        self.assertEqual(len(splt.ax.patches), 0)
        # calling again should change nothing
        splt.plot_wiggles(**{fill_key: False})
        self.assertEqual(len(splt.INACTIVE_PATCHES[fill_key]), 1)
        self.assertEqual(len(splt.ax.patches), 0)
        # should move artist back to active patch dict.
        splt.plot_wiggles(**{fill_key: True})
        self.assertEqual(len(splt.ACTIVE_PATCHES[fill_key]), 1)
        self.assertEqual(len(splt.ax.patches), 1)
        # new plot items should be accessible through the orig. axes
        self.assertEqual(len(splt.ax.patches), len(ax.patches))

    def test_plot_negative_wiggle_fills(self):
        """
        Should add negative wiggle fills to the axes.
        """
        self._check_wiggle_fill_toggle('negative_fills')

    def test_plot_positive_wiggle_fills(self):
        """
        Should add positive wiggle fills to the axes.
        """
        self._check_wiggle_fill_toggle('positive_fills')

    def test_plot_wiggle_traces(self):
        """
        Should add wiggle traces to the axes.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)
        # should add a single artist to the active line dict.
        splt.plot_wiggles(wiggle_traces=True)
        self.assertEqual(len(splt.ACTIVE_LINES['wiggle_traces']), 1)
        artist = splt.ACTIVE_LINES['wiggle_traces'][0]
        # same artist should be in the axes list
        self.assertIn(artist, splt.ax.lines)
        # should be able to directly remove it
        splt.ax.lines.remove(artist)
        self.assertEqual(len(splt.ax.lines), 0)
        # and re-add it
        splt.ax.lines.append(artist)
        self.assertEqual(len(splt.ax.lines), 1)
        # should move artist to the inactive line dict.
        splt.plot_wiggles(wiggle_traces=False)
        self.assertEqual(len(splt.INACTIVE_LINES['wiggle_traces']), 1)
        self.assertEqual(len(splt.ax.lines), 0)
        # calling again should change nothing
        splt.plot_wiggles(wiggle_traces=False)
        self.assertEqual(len(splt.INACTIVE_LINES['wiggle_traces']), 1)
        self.assertEqual(len(splt.ax.lines), 0)
        # should move artist back to active line dict and re-add to axes
        splt.plot_wiggles(wiggle_traces=True)
        self.assertEqual(len(splt.ACTIVE_LINES['wiggle_traces']), 1)
        self.assertEqual(len(splt.ax.lines), 1)
        # new plot items should be accessible through the orig. axes
        self.assertEqual(len(splt.ax.lines), len(ax.lines))

    def test_link_axes(self):
        """
        Plotting should update items in the linked axes.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotter(ax, self.segy)
        # should add one artist to our axes
        splt.plot_wiggles(wiggle_traces=True)
        self.assertEqual(len(splt.ACTIVE_LINES['wiggle_traces']), 1)
        self.assertNotIn('wiggle_traces', splt.INACTIVE_LINES)
        self.assertEqual(len(ax.lines), 1)
        # should remove one artist from our axes
        splt.plot_wiggles(wiggle_traces=False)
        self.assertNotIn('wiggle_traces', splt.ACTIVE_LINES)
        self.assertEqual(len(splt.INACTIVE_LINES['wiggle_traces']), 1)
        self.assertEqual(len(ax.lines), 0)

    def test_init_SEGYPickPlotter(self):
        """
        Should create a new instance of the SEGYPickPlotter.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPickPlotter(ax, self.segy, pickdb=self.pickdb)
        # should inherit from SEGYPlotter
        for name, _ in inspect.getmembers(SEGYPlotter):
            self.assertTrue(hasattr(splt, name))
        # should build header lookup table
        self.assertTrue(isinstance(splt.sdb, SEGYHeaderDatabase))
        # should attach axes
        self.assertTrue(isinstance(splt.ax, matplotlib.axes.Axes))

    def test_plot_picks(self):
        """
        Should add picks to an axes and the layer managers.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        pickdb = PickDatabaseConnection(':memory:')
        for pick in uniq_picks:
            pickdb.update_pick(**pick)
        splt = SEGYPickPlotter(ax, self.segy, pickdb)
        # should add a single artist to the line dict. for each event
        # (assertEqual: assertTrue(x, 1) would treat 1 as the message)
        splt.plot_picks()
        for event in pickdb.events:
            self.assertEqual(len(splt.ACTIVE_LINES[event]), 1)
        # should be able to add new picks and have them be accessible by the
        # SEGYPickPlotter
        new_event = '--tracer--'
        new_pick = copy.copy(uniq_picks[0])
        new_pick['event'] = new_event
        pickdb.update_pick(**new_pick)
        splt.plot_picks()
        self.assertIn(new_event, splt.ACTIVE_LINES)
def rayfan2db(rayfan_file, raydb_file=':memory:', synthetic=False,
              noise=None, pickdb=None, raypaths=False):
    """
    Read a rayfan file and store its data in a database.

    Data are stored in a modified version of a
    :class:`rockfish.picking.database.PickDatabaseConnection`.

    Parameters
    ----------
    rayfan_file: {str, file}
        An open file-like object or a string which is assumed to be a
        filename of a rayfan binary file.
    raydb_file: str, optional
        The filename of the new database. Default is to create a new
        database in memory.
    synthetic: bool, optional
        Determines whether or not to record traced traveltimes as picked
        traveltimes.
    noise: {float, None}
        Maximum amplitude of uniform random noise to add to the travel
        times (only applied when ``synthetic`` is ``True``). If ``None``, no
        noise is added.
    pickdb: :class:`rockfish.database.PickDatabaseConnection`, optional
        An active connection to the pick database that was used to trace
        rays in ``rayfan_file``. Values for extra fields (e.g.,
        'trace_in_file') are copied from this database to the new database
        along with rayfan data. Default is to ignore these extra fields.
    raypaths: bool, optional
        If ``True``, raypath coordinates are stored as text in a new table
        'raypaths'.

    Returns
    -------
    raydb: PickDatabaseConnection
        Connection to the populated database (committed).
    """
    raydb = PickDatabaseConnection(raydb_file)
    rays = readRayfanGroup(rayfan_file)
    # parenthesized single argument: same output on Python 2, valid on 3
    print("Adding {:} traveltimes to {:} ..."
          .format(rays.nrays, raydb_file))
    # add fields for raypaths
    if raypaths:
        raydb._create_table_if_not_exists(RAYPATH_TABLE, RAYPATH_FIELDS)
    ndb0 = raydb.execute('SELECT COUNT(rowid) FROM picks').fetchone()[0]
    for rfn in rays.rayfans:
        for i, traced_time in enumerate(rfn.travel_times):
            if noise is not None:
                # uniform noise in [-noise, noise)
                _noise = noise * 2 * (np.random.random() - 0.5)
            else:
                _noise = 0.0
            # first and last path points are the source and receiver
            sx, sy, sz = rfn.paths[i][0]
            rx, ry, rz = rfn.paths[i][-1]
            if pickdb is not None:
                event = pickdb.vmbranch2event[rfn.event_ids[i]]
            else:
                event = rfn.event_ids[i]
            if synthetic:
                # the traced time becomes the "picked" time
                pick_time = traced_time + _noise
                predicted = None
                residual = 0.
            else:
                pick_time = rfn.pick_times[i]
                predicted = traced_time
                residual = rfn.residuals[i]
            d = {'event': event,
                 'ensemble': rfn.start_point_id,
                 'trace': rfn.end_point_ids[i],
                 'vm_branch': rfn.event_ids[i],
                 'vm_subid': rfn.event_subids[i],
                 'time': pick_time,
                 'time_reduced': pick_time,
                 'predicted': predicted,
                 'residual': residual,
                 'error': rfn.pick_errors[i],
                 'source_x': sx, 'source_y': sy, 'source_z': sz,
                 'receiver_x': rx, 'receiver_y': ry, 'receiver_z': rz,
                 'offset': rfn.offsets[i],
                 'faz': rfn.azimuths[i],
                 'method': 'rayfan2db({:})'.format(rayfan_file),
                 'data_file': rays.file.name}
            # Copy data from pickdb
            if pickdb is not None:
                pick = pickdb.get_picks(event=[event],
                                        ensemble=[d['ensemble']],
                                        trace=[d['trace']])
                if len(pick) > 0:
                    for f in ['trace_in_file', 'line', 'site', 'data_file']:
                        # best-effort copy: a missing column raises KeyError
                        # for dicts but IndexError for sqlite3.Row
                        try:
                            d[f] = pick[0][f]
                        except (KeyError, IndexError):
                            pass
            # add data to standard tables
            raydb.update_pick(**d)
            # add raypath data to new table
            if raypaths:
                path = rfn.paths[i]
                # NOTE(review): indexes the group-level bottom_points with
                # the per-rayfan index i; verify this is intended when the
                # file holds more than one rayfan.
                bottom = rays.bottom_points[i]
                ray_row = {'event': event,
                           'ensemble': rfn.start_point_id,
                           'trace': rfn.end_point_ids[i],
                           'ray_btm_x': bottom[0],
                           'ray_btm_y': bottom[1],
                           'ray_btm_z': bottom[2],
                           'ray_x': str([p[0] for p in path]),
                           'ray_y': str([p[1] for p in path]),
                           'ray_z': str([p[2] for p in path])}
                raydb._insert(RAYPATH_TABLE, **ray_row)
    raydb.commit()
    ndb = raydb.execute('SELECT COUNT(rowid) FROM picks').fetchone()[0]
    if (ndb - ndb0) != rays.nrays:
        msg = 'Only added {:} of {:} travel times to the database.'\
                .format(ndb - ndb0, rays.nrays)
        warnings.warn(msg)
    return raydb
class SEGYPlotManagerTestCase(unittest.TestCase):
    """
    Test cases for SEGYPlotManager
    """
    def setUp(self):
        """
        Setup benchmark data.
        """
        self.segy = readSEGY(get_example_file('ew0210_o30.segy'),
                             unpack_headers=True)
        self.pickdb = PickDatabaseConnection(':memory:')
        for pick in uniq_picks:
            self.pickdb.update_pick(**pick)
        # parameters every SEGYPlotManager instance should expose
        self.default_params = [
            'ABSCISSA_KEY', 'GAIN', 'CLIP', 'NORMALIZATION_METHOD',
            'OFFSET_GAIN_POWER', 'WIGGLE_PEN_COLOR', 'WIGGLE_PEN_WIDTH',
            'NEG_FILL_COLOR', 'POS_FILL_COLOR', 'DISTANCE_UNIT', 'TIME_UNIT',
            'SEGY_TIME_UNITS', 'SEGY_DISTANCE_UNITS', 'SEGY_HEADER_ALIASES']

    def test_init_SEGYPlotManager(self):
        """
        Should create a new instance of the SEGYPlotManager.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # should have the default parameters
        for attr in self.default_params:
            self.assertTrue(hasattr(splt, attr))
        # by default, should *not* have header database
        self.assertFalse(hasattr(splt, 'sdb'))
        # should attach axes
        self.assertTrue(isinstance(splt.ax, matplotlib.axes.Axes))
        # class should be able to update the axes
        splt.ax.plot([0, 1], [0, 1])
        self.assertEqual(len(splt.ax.lines), 1)
        # updates outside the class should be seen inside the class
        ax.plot([0, 1], [0, 1])
        self.assertEqual(len(splt.ax.lines), 2)
        self.assertEqual(len(ax.lines), 2)
        # if a pick database is given, should attach pickdb and build lookup db
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb)
        self.assertTrue(isinstance(splt.pickdb, PickDatabaseConnection))
        self.assertTrue(isinstance(splt.sdb, SEGYHeaderDatabase))

    def test_build_header_database(self):
        """
        Should connect to a database for looking up header attributes.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        # connecting to a database in memory should not create a file
        splt = SEGYPlotManager(ax, self.segy, trace_header_database=':memory:')
        self.assertFalse(os.path.isfile(':memory:'))
        # connecting to a filename should create a file if it does not exist
        filename = 'temp.test_build_header_database.sqlite'
        if os.path.isfile(filename):
            os.remove(filename)
        # should not create database if no pickdb is given
        splt = SEGYPlotManager(ax, self.segy, trace_header_database=filename)
        self.assertFalse(os.path.isfile(filename))
        # should create database if pickdb is given
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb,
                               trace_header_database=filename)
        self.assertTrue(os.path.isfile(filename))
        # clean up
        os.remove(filename)
        # create plot manager in memory with pickdb
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb)
        # should be able to get primary fields from the picks trace table
        pick_keys = splt.pickdb._get_primary_fields(splt.pickdb.TRACE_TABLE)
        self.assertTrue(len(pick_keys) > 0)
        # pick primary keys should be in the alias dictionary
        # or be a header attribute
        header = self.segy.traces[0].header
        for k in pick_keys:
            self.assertTrue(k in splt.SEGY_HEADER_ALIASES
                            or hasattr(header, k))
        # should add header fields for primary keys in the pick database.
        # Materialize the rows once: a cursor is exhausted after one pass,
        # so re-iterating it per key would silently skip every key but the
        # first.
        sql = 'SELECT * FROM %s' % splt.sdb.TRACE_TABLE
        rows = list(splt.sdb.execute(sql))
        self.assertTrue(len(rows) > 0)
        for key in pick_keys:
            _key = splt._get_header_alias(key)
            # assumes table rows are in the same order as self.segy.traces
            for i, row in enumerate(rows):
                segy_value = splt.get_header_value(
                    self.segy.traces[i].header, key, convert_units=False)
                db_value = row[_key]
                self.assertEqual(segy_value, db_value)

    def test_get_units(self):
        """
        Should return (segy units, plot units) or None.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        splt.DISTANCE_UNIT = 'distance_unit_marker'
        splt.TIME_UNIT = 'time_unit_marker'
        for key in TRACE_HEADER_KEYS:
            if key in splt.SEGY_TIME_UNITS:
                # should return TIME_UNIT for a time attribute
                self.assertEqual(splt._get_units(key)[1], 'time_unit_marker')
            elif key in splt.SEGY_DISTANCE_UNITS:
                # should return DISTANCE_UNIT for a distance attribute
                self.assertEqual(splt._get_units(key)[1],
                                 'distance_unit_marker')
            else:
                # should return None for unitless values
                self.assertIsNone(splt._get_units(key))

    def test_convert_units(self):
        """
        Should convert header units to plot units.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # should correctly perform unit conversions for distance
        splt.DISTANCE_UNIT = 'km'
        self.assertEqual(splt._convert_units('offset', [1000]), [1])
        # should correctly perform unit conversions for time
        splt.TIME_UNIT = 's'
        self.assertEqual(splt._convert_units('delay', [1000]), [1])

    def test_get_time_array(self):
        """
        Should return an array of time values for a single trace.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        tr = self.segy.traces[0]
        # should have npts values
        npts = splt.get_header_value(tr.header, 'npts')
        t = splt.get_time_array(tr.header)
        self.assertEqual(npts, len(t))

    def test_get_header_value(self):
        """
        Should be able to use aliases to get unit-converted values from
        headers.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # should return exact header values when convert_units=False
        tr = self.segy.traces[0]
        for alias in splt.SEGY_HEADER_ALIASES:
            key = splt.SEGY_HEADER_ALIASES[alias]
            value = splt.get_header_value(tr.header, alias,
                                          convert_units=False)
            self.assertEqual(value, getattr(tr.header, key))
        # default should return header values in the plot units
        splt.DISTANCE_UNIT = 'km'
        alias = 'offset'
        key = splt.SEGY_HEADER_ALIASES[alias]
        scaled_value = splt.get_header_value(tr.header, alias)
        header_value = getattr(tr.header, key)
        self.assertEqual(np.round(scaled_value, decimals=3),
                         header_value / 1000.)

    def test_get_abscissa(self):
        """
        Should return a list of abscissa values.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy, pickdb=self.pickdb)
        # just get abscissa for a subset of the traces
        idx = [0, 100, 103, 104, 500, 550]
        values = []
        for i in idx:
            header = self.segy.traces[i].header
            values.append((header.ensemble_number,
                           header.trace_number_within_the_ensemble))
        # should return a list of values for plot x-axis
        x = splt._get_abscissa(['ensemble', 'trace'], values)
        _key = splt._get_header_alias(splt.ABSCISSA_KEY)
        for j, i in enumerate(idx):
            tr = self.segy.traces[i]
            self.assertEqual(getattr(tr.header, _key), x[j])

    def test_manage_layers(self):
        """
        Should handle moving plot items around.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        splt = SEGYPlotManager(ax, self.segy)
        # should echo parameters if item is not in dicts
        self.assertTrue(splt._manage_layers(foobar=True)['foobar'])
        self.assertFalse(splt._manage_layers(foobar=False)['foobar'])
        # for active item and True, should do nothing
        splt.ACTIVE_LINES['foobar'] = ax.plot([0, 1], [0, 1])
        self.assertIn('foobar', splt.ACTIVE_LINES)
        self.assertNotIn('foobar', splt.INACTIVE_LINES)
        self.assertFalse(splt._manage_layers(foobar=True)['foobar'])
        self.assertIn('foobar', splt.ACTIVE_LINES)
        self.assertNotIn('foobar', splt.INACTIVE_LINES)
        # for active item and False, should move to inactive and return False
        self.assertFalse(splt._manage_layers(foobar=False)['foobar'])
        self.assertNotIn('foobar', splt.ACTIVE_LINES)
        self.assertIn('foobar', splt.INACTIVE_LINES)
        # for force_new=True, should remove from active and inactive and
        # return True
        # item is currently in inactive list
        need2plot = splt._manage_layers(force_new=True, foobar=True)
        self.assertTrue(need2plot['foobar'])
        self.assertNotIn('foobar', splt.ACTIVE_LINES)
        self.assertNotIn('foobar', splt.INACTIVE_LINES)
        # item is now in active list
        splt.ACTIVE_LINES['foobar'] = ax.plot([0, 1], [0, 1])
        need2plot = splt._manage_layers(force_new=True, foobar=True)
        self.assertTrue(need2plot['foobar'])
        self.assertNotIn('foobar', splt.ACTIVE_LINES)
        self.assertNotIn('foobar', splt.INACTIVE_LINES)