def _as_path_str(path):
    """Normalize a file path argument so it can be handed to ``read_pkl``.

    Objects exposing ``.numpy()`` (presumably eager TensorFlow string
    tensors, e.g. as passed through ``tf.py_function`` -- TODO confirm) are
    first converted with ``.numpy()``; ``bytes`` results are then decoded as
    UTF-8. Any other value (``str``, ``os.PathLike``, ``pathlib.Path``) is
    returned unchanged.
    """
    if hasattr(path, 'numpy'):
        path = path.numpy()
    if isinstance(path, bytes):
        path = path.decode('utf-8')
    return path


def read_dataset(input_file_path, target_file_path, read_dataset_info):
    """Read an input/target pair from pickle files and preprocess it.

    Parameters
    ----------
    input_file_path : str, os.PathLike, pathlib.Path or tensor-like
        Path of the pickle file holding the network input (resistance).
        Objects with a ``.numpy()`` method returning UTF-8 ``bytes`` are
        also accepted (the form the original implementation required).
    target_file_path : str, os.PathLike, pathlib.Path or tensor-like
        Path of the pickle file holding the network target
        (log10 resistivity). Accepted forms are the same as for
        ``input_file_path``.
    read_dataset_info : dict
        Must contain the keys ``'preprocess'``, ``'Tx_locations'``,
        ``'Rx_locations'``, ``'nCx'`` and ``'nCy'``. ``'preprocess'`` maps a
        step name (``'add_noise'``, ``'log_transform'``, ``'to_midpoint'``,
        ``'to_txrx'``, ``'to_section'``) to a dict with a ``'perform'`` flag
        and, for ``'add_noise'``/``'log_transform'``, optional ``'kwargs'``.
        Enabled steps run in the dictionary's iteration order; unknown step
        names are ignored.

    Returns
    -------
    resistance : numpy.ndarray
        The input data of the neural network.
    resistivity_log10 : numpy.ndarray
        The target data of the neural network.
    """
    resistance = read_pkl(_as_path_str(input_file_path))
    resistivity_log10 = read_pkl(_as_path_str(target_file_path))

    preprocess = read_dataset_info['preprocess']
    Tx_locations = read_dataset_info['Tx_locations']
    Rx_locations = read_dataset_info['Rx_locations']
    nCx = read_dataset_info['nCx']
    nCy = read_dataset_info['nCy']

    for step, options in preprocess.items():
        if step == 'add_noise' and options.get('perform'):
            # Return value is ignored: assumes add_noise works in place -- TODO confirm.
            add_noise(resistance, **(options.get('kwargs') or {}))
        elif step == 'log_transform' and options.get('perform'):
            # Return value is ignored: assumes log_transform works in place -- TODO confirm.
            log_transform(resistance, **(options.get('kwargs') or {}))
        elif step == 'to_midpoint' and options.get('perform'):
            resistance = to_midpoint(resistance, Tx_locations, Rx_locations)
        elif step == 'to_txrx' and options.get('perform'):
            resistance = to_txrx(resistance, Tx_locations, Rx_locations)
        elif step == 'to_section' and options.get('perform'):
            resistivity_log10 = to_section(resistivity_log10, nCx, nCy)
    return resistance, resistivity_log10
def get_data(self, temp_file_list):
    """Load and preprocess the resistance data for one batch of pickle files.

    Parameters
    ----------
    temp_file_list : sequence
        Paths of the pickle files forming the batch. Each file is read with
        ``read_pkl`` and must provide a ``'resistance'`` entry.

    Returns
    -------
    resistance : numpy.ndarray
        Array of shape ``(len(temp_file_list), *self.input_shape)``.

    Notes
    -----
    Reads ``self.input_shape``, ``self.preprocess``, ``self.Tx_locations`` and
    ``self.Rx_locations``. Per file, ``to_midpoint`` takes precedence over
    ``to_txrx``; if neither is enabled the raw data is reshaped to
    ``self.input_shape``. Missing preprocess steps count as disabled.
    Afterwards, ``add_noise`` and ``log_transform`` are applied once to the
    whole batch, in the order they appear in ``self.preprocess``.
    """
    preprocess = self.preprocess
    use_midpoint = (preprocess.get('to_midpoint') or {}).get('perform')
    use_txrx = (preprocess.get('to_txrx') or {}).get('perform')

    resistance = np.empty((len(temp_file_list), *self.input_shape))
    for i, file_path in enumerate(temp_file_list):
        raw = read_pkl(file_path)['resistance']
        if use_midpoint:
            resistance[i] = to_midpoint(raw, self.Tx_locations, self.Rx_locations)
        elif use_txrx:
            resistance[i] = to_txrx(raw, self.Tx_locations, self.Rx_locations)
        else:
            resistance[i] = raw.reshape(self.input_shape)

    # Batch-level steps run once, after every row has been filled.
    for step, options in preprocess.items():
        if step == 'add_noise' and options.get('perform'):
            # Return value is ignored: assumes add_noise works in place -- TODO confirm.
            add_noise(resistance, **(options.get('kwargs') or {}))
        elif step == 'log_transform' and options.get('perform'):
            # Return value is ignored: assumes log_transform works in place -- TODO confirm.
            log_transform(resistance, **(options.get('kwargs') or {}))
    return resistance
def _append_colorbar(fig, ax, mappable, label):
    """Attach a labelled colorbar in a slim axis to the right of *ax*."""
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(mappable, cax=cax)
    cbar.set_label(label)
    return cbar


def plot_data(iterator, simulator, num_figs):
    """Plot the resistance input and resistivity target of dataset samples.

    For each file, four figures are created:

    1. the resistance arranged by Tx/Rx pair (``to_txrx``);
    2. the resistance arranged by common midpoint (``to_midpoint``);
    3. the log10 resistivity drawn with ``simulator.mesh.plotImage``;
    4. the log10 resistivity drawn with ``contourf`` over the cell centres.

    ``plt.show()`` is called after each file.

    Parameters
    ----------
    iterator : iterable
        Yields objects with a ``path`` attribute (e.g. ``os.DirEntry`` --
        assumption, confirm with callers) naming a pickle file that contains
        ``'resistance'`` and ``'resistivity_log10'``.
    simulator : object
        Provides ``urf.abmn_locations`` (columns 0-3 are passed as source
        locations, columns 4 onward as receiver locations), ``active_idx``,
        and ``mesh`` with ``nCx``, ``nCy``, ``vectorCCx``, ``vectorCCy`` and
        ``plotImage``.
    num_figs : int
        Number of files to plot. Values below 1 are treated as 1.
    """
    from itertools import islice

    abmn_locations = simulator.urf.abmn_locations
    src_loc = abmn_locations[:, :4]
    rec_loc = abmn_locations[:, 4:]
    active_idx = simulator.active_idx
    mesh = simulator.mesh
    nCx = mesh.nCx
    nCy = mesh.nCy
    vectorCCx = mesh.vectorCCx
    vectorCCy = mesh.vectorCCy
    num_figs = max(1, int(num_figs))
    resistance_label = r'$\Delta V/I$'
    resistivity_label = r'$\Omega \bullet m (log_{10})$'

    for entry in islice(iterator, num_figs):
        data = read_pkl(entry.path)
        resistance = data['resistance']
        resistivity = data['resistivity_log10']
        print(resistance.shape, resistivity.shape)
        model = resistivity[active_idx]

        # Resistance, Tx/Rx-pair layout.
        fig, ax = plt.subplots(figsize=(16, 9))
        im = ax.imshow(
            to_txrx(resistance, src_loc, rec_loc, value=np.nan)[:, :, 0],
            origin='lower'
        )
        _append_colorbar(fig, ax, im, resistance_label)
        ax.set_xlabel('Rx_pair')
        ax.set_ylabel('Tx_pair')

        # Resistance, common-midpoint layout.
        fig, ax = plt.subplots(figsize=(4, 3))
        im = ax.imshow(
            to_midpoint(resistance, src_loc, rec_loc, value=np.nan)[:, :, 0]
        )
        _append_colorbar(fig, ax, im, resistance_label)
        ax.set_xlabel('common midpoint')
        ax.set_ylabel('count')
        ax.set_aspect('auto', adjustable='box')

        # Resistivity drawn with the mesh's own plotImage. The colorbar uses
        # its first returned element.
        fig, ax = plt.subplots()
        im = mesh.plotImage(model, ax=ax)
        _append_colorbar(fig, ax, im[0], resistivity_label)
        ax.set_xlabel('m')
        ax.set_ylabel('m')

        # Resistivity drawn with contourf on the cell centres. plotImage is
        # still drawn underneath, as in the original figure.
        fig, ax = plt.subplots()
        mesh.plotImage(model, ax=ax)
        im = ax.contourf(vectorCCx, vectorCCy, model.reshape((nCy, nCx)))
        _append_colorbar(fig, ax, im, resistivity_label)
        ax.set_xlabel('m')
        ax.set_ylabel('m')

        plt.show()