コード例 #1
0
ファイル: NFE_Solver.py プロジェクト: tomtomatron/vezda
def solver(medium, s, Uh, V, alpha, domain):

    #==============================================================================
    # Load the receiver coordinates and recording times from the data directory
    datadir = np.load('datadir.npz')
    times = np.load(str(datadir['recordingTimes']))
    receiverPoints = np.load(str(datadir['receivers']))

    # Compute length of time step.
    # This parameter is used for FFT shifting and time windowing
    dt = times[1] - times[0]

    # Load the windowing parameters for the receiver and time axes of
    # the 3D data array
    if Path('window.npz').exists():
        windowDict = np.load('window.npz')

        # Time window parameters (with units of time)
        tstart = windowDict['tstart']
        tstop = windowDict['tstop']

        # Convert time window parameters to corresponding array indices
        tstart = int(round(tstart / dt))
        tstop = int(round(tstop / dt))
        tstep = windowDict['tstep']

        # Receiver window parameters
        rstart = windowDict['rstart']
        rstop = windowDict['rstop']
        rstep = windowDict['rstep']

    else:
        # Set default window parameters if user did
        # not specify window parameters.

        # Time window parameters (integers corresponding to indices in an array)
        tstart = 0
        tstop = len(times)
        tstep = 1

        # Receiver window parameters
        rstart = 0
        rstop = receiverPoints.shape[0]
        rstep = 1

    # Slice the recording times according to the time window parameters
    # to create a time window array
    tinterval = np.arange(tstart, tstop, tstep)
    times = times[tinterval]
    T = times[-1] - times[0]
    times = np.linspace(-T, T, 2 * len(times) - 1)

    # Slice the receiverPoints array according to the receiver window parameters
    rinterval = np.arange(rstart, rstop, rstep)
    receiverPoints = receiverPoints[rinterval, :]

    Nr = receiverPoints.shape[0]
    Nt = len(times)  # number of samples in time window

    # Get information about the pulse function used to
    # generate the interrogating wave (These parameters are
    # used to help Vezda decide if it needs to recompute the
    # test functions in the case the user changes these parameters.)
    velocity = pulseFun.velocity  # only used if medium == constant
    peakFreq = pulseFun.peakFreq  # peak frequency
    peakTime = pulseFun.peakTime  # time at which the pulse amplitude is maximum

    # Used for getting time and frequency units
    if Path('plotParams.pkl').exists():
        plotParams = pickle.load(open('plotParams.pkl', 'rb'))
    else:
        plotParams = default_params()

    # Get machine precision
    eps = np.finfo(float).eps  # about 2e-16 (used in division
    # so we never divide by zero)
    #==============================================================================
    # Load the user-specified space-time sampling grid
    try:
        samplingGrid = np.load('samplingGrid.npz')
    except FileNotFoundError:
        samplingGrid = None

    if samplingGrid is None:
        sys.exit(
            textwrap.dedent('''
                A space-time sampling grid needs to be set up before running the
                \'vzsolve\' command. Enter:
                    
                    vzgrid --help
                
                from the command-line for more information on how to set up a
                sampling grid.
                '''))

    if 'z' not in samplingGrid:
        # Apply linear sampling method to three-dimensional space-time
        x = samplingGrid['x']
        y = samplingGrid['y']
        tau = samplingGrid['tau']
        z = None

        # Get number of sampling points in space and time
        Nx = len(x)
        Ny = len(y)
        X, Y = np.meshgrid(x, y, indexing='ij')

        # Initialize the Histogram for storing images at each sampling point in time.
        # Initialize the Image (time-integrated Histogram with respect to L2 norm)
        Image = np.zeros(X.shape)

        if medium == 'constant':
            # Vezda will compute free-space test functions over the space-time
            # sampling grid via function calls to 'FundamentalSolutions.py'. This is
            # much more efficient than applying a forward and inverse FFT pair to
            # shift the test functions in time corresponding to different sampling
            # points in time. FFT pairs are only used when medium == variable.
            pulse = lambda t: pulseFun.pulse(t)

            # Previously computed test functions and parameters from pulseFun module
            # are stored in 'VZTestFuncs.npz'. If the current space-time sampling grid
            # and pulseFun module parameters are consistent with those stored in
            # 'VZTestFuncs.npz', then Vezda will load the previously computed test
            # functions. Otherwise, Vezda will recompute the test functions. This reduces
            # computational cost by only computing test functions when necessary.
            if Path('VZTestFuncs.npz').exists():
                print(
                    '\nDetected that free-space test functions have already been computed...'
                )
                print(
                    'Checking consistency with current space-time sampling grid and pulse function...'
                )
                TFDict = np.load('VZTestFuncs.npz')

                if samplingIsCurrent(TFDict, receiverPoints, times, velocity,
                                     tau, x, y, z, peakFreq, peakTime):
                    print('Moving forward to imaging algorithm...')
                    TFarray = TFDict['TFarray']

                else:
                    if tau[0] != 0:
                        tu = plotParams['tu']
                        if tu != '':
                            print(
                                'Recomputing test functions for focusing time %0.2f %s...'
                                % (tau[0], tu))
                        else:
                            print(
                                'Recomputing test functions for focusing time %0.2f...'
                                % (tau[0]))
                        TFarray, samplingPoints = sampleSpace(
                            receiverPoints, times - tau[0], velocity, x, y, z,
                            pulse)
                    else:
                        print('Recomputing test functions...')
                        TFarray, samplingPoints = sampleSpace(
                            receiverPoints, times, velocity, x, y, z, pulse)

                    np.savez('VZTestFuncs.npz',
                             TFarray=TFarray,
                             time=times,
                             receivers=receiverPoints,
                             peakFreq=peakFreq,
                             peakTime=peakTime,
                             velocity=velocity,
                             x=x,
                             y=y,
                             tau=tau,
                             samplingPoints=samplingPoints)

            else:
                print(
                    '\nComputing free-space test functions for the current space-time sampling grid...'
                )
                if tau[0] != 0:
                    tu = plotParams['tu']
                    if tu != '':
                        print(
                            'Computing test functions for focusing time %0.2f %s...'
                            % (tau[0], tu))
                    else:
                        print(
                            'Computing test functions for focusing time %0.2f...'
                            % (tau[0]))
                    TFarray, samplingPoints = sampleSpace(
                        receiverPoints, times - tau[0], velocity, x, y, z,
                        pulse)
                else:
                    TFarray, samplingPoints = sampleSpace(
                        receiverPoints, times, velocity, x, y, z, pulse)

                np.savez('VZTestFuncs.npz',
                         TFarray=TFarray,
                         time=times,
                         receivers=receiverPoints,
                         peakFreq=peakFreq,
                         peakTime=peakTime,
                         velocity=velocity,
                         x=x,
                         y=y,
                         tau=tau,
                         samplingPoints=samplingPoints)
            #==============================================================================
            if domain == 'freq':
                # Transform test functions into the frequency domain and bandpass for efficient solution
                # to near-field equation
                print('Transforming test functions to the frequency domain...')

                N = nextPow2(Nt)
                TFarray = np.fft.rfft(TFarray, n=N, axis=1)

                if plotParams['fmax'] is None:
                    freqs = np.fft.rfftfreq(N, tstep * dt)
                    plotParams['fmax'] = np.max(freqs)
                    pickle.dump(plotParams, open('plotParams.pkl', 'wb'),
                                pickle.HIGHEST_PROTOCOL)

                # Apply the frequency window
                fmin = plotParams['fmin']
                fmax = plotParams['fmax']
                fu = plotParams['fu']  # frequency units (e.g., Hz)

                if fu != '':
                    print('Applying bandpass filter: [%0.2f %s, %0.2f %s]' %
                          (fmin, fu, fmax, fu))
                else:
                    print('Applying bandpass filter: [%0.2f, %0.2f]' %
                          (fmin, fmax))

                df = 1.0 / (N * tstep * dt)
                startIndex = int(round(fmin / df))
                stopIndex = int(round(fmax / df))

                finterval = np.arange(startIndex, stopIndex, 1)
                TFarray = TFarray[:, finterval, :]

            N = TFarray.shape[1]

            #==============================================================================
            # Solve the near-field equation for each sampling point
            print('Localizing the source...')
            # Compute the Tikhonov-regularized solution to the near-field equation N * phi = tf.
            # 'tf' is a test function
            # 'alpha' is the regularization parameter
            # 'phi_alpha' is the regularized solution given 'alpha'

            k = 0  # counter for spatial sampling points
            for ix in trange(Nx, desc='Solving system'):
                for iy in range(Ny):
                    tf = np.reshape(TFarray[:, :, k], (N * Nr, 1))
                    phi_alpha = Tikhonov(Uh, s, V, tf, alpha)
                    Image[ix, iy] = 1.0 / (norm(phi_alpha) + eps)
                    k += 1

            Imin = np.min(Image)
            Imax = np.max(Image)
            Image = (Image - Imin) / (Imax - Imin + eps)

        elif medium == 'variable':
            if 'testFuncs' in datadir:
                # Load the user-provided test functions
                TFarray = np.load(str(datadir['testFuncs']))

                # Apply the receiver/time windows, if any
                TFarray = TFarray[rinterval, :, :]
                TFarray = TFarray[:, tinterval, :]

                #==============================================================================
                if domain == 'freq':
                    # Transform test functions into the frequency domain and bandpass for efficient solution
                    # to near-field equation
                    print(
                        'Transforming test functions to the frequency domain...'
                    )

                    N = nextPow2(Nt)
                    TFarray = np.fft.rfft(TFarray, n=N, axis=1)

                    if plotParams['fmax'] is None:
                        freqs = np.fft.rfftfreq(N, tstep * dt)
                        plotParams['fmax'] = np.max(freqs)
                        pickle.dump(plotParams, open('plotParams.pkl', 'wb'),
                                    pickle.HIGHEST_PROTOCOL)

                    # Apply the frequency window
                    fmin = plotParams['fmin']
                    fmax = plotParams['fmax']
                    fu = plotParams['fu']  # frequency units (e.g., Hz)

                    if fu != '':
                        print(
                            'Applying bandpass filter: [%0.2f %s, %0.2f %s]' %
                            (fmin, fu, fmax, fu))
                    else:
                        print('Applying bandpass filter: [%0.2f, %0.2f]' %
                              (fmin, fmax))

                    df = 1.0 / (N * tstep * dt)
                    startIndex = int(round(fmin / df))
                    stopIndex = int(round(fmax / df))

                    finterval = np.arange(startIndex, stopIndex, 1)
                    TFarray = TFarray[:, finterval, :]

                N = TFarray.shape[1]

                # Load the sampling points
                samplingPoints = np.load(str(datadir['samplingPoints']))

            else:
                sys.exit(
                    textwrap.dedent('''
                        FileNotFoundError: Attempted to load file containing test
                        functions, but no such file exists. If a file exists containing
                        the test functions, run:
                            
                            'vzdata --path=<path/to/data/>'
                        
                        and specify the file containing the test functions when prompted.
                        Otherwise, specify 'no' when asked if a file containing the test
                        functions exists.
                        '''))

            userResponded = False
            print(
                textwrap.dedent('''
                 In what order was the sampling grid spanned to compute the test functions?
                 
                 Enter 'xy' if for each x, loop over y. (Default)
                 Enter 'yx' if for each y, loop over x.
                 Enter 'q/quit' to abort the calculation.
                 '''))
            while userResponded == False:
                order = input('Order: ')
                if order == '' or order == 'xy':
                    print('Proceeding with order \'xy\'...')
                    print('Localizing the source...')
                    # Compute the Tikhonov-regularized solution to the near-field equation N * phi = tf.
                    # 'tf' is a test function
                    # 'alpha' is the regularization parameter
                    # 'phi_alpha' is the regularized solution given 'alpha'

                    k = 0  # counter for spatial sampling points
                    for ix in trange(Nx, desc='Solving system'):
                        for iy in range(Ny):
                            tf = np.reshape(TFarray[:, :, k], (N * Nr, 1))
                            phi_alpha = Tikhonov(Uh, s, V, tf, alpha)
                            Image[ix, iy] = 1.0 / (norm(phi_alpha) + eps)
                            k += 1

                    Imin = np.min(Image)
                    Imax = np.max(Image)
                    Image = (Image - Imin) / (Imax - Imin + eps)
                    userResponded = True
                    break

                elif order == 'yx':
                    print('Proceeding with order \'yx\'...')
                    print('Localizing the source...')
                    # Compute the Tikhonov-regularized solution to the near-field equation N * phi = tf.
                    # 'tf' is a test function
                    # 'alpha' is the regularization parameter
                    # 'phi_alpha' is the regularized solution given 'alpha'

                    k = 0  # counter for spatial sampling points
                    for iy in trange(Ny, desc='Solving system'):
                        for ix in range(Nx):
                            tf = np.reshape(TFarray[:, :, k], (N * Nr, 1))
                            phi_alpha = Tikhonov(Uh, s, V, tf, alpha)
                            Image[ix, iy] = 1.0 / (norm(phi_alpha) + eps)
                            k += 1

                    Imin = np.min(Image)
                    Imax = np.max(Image)
                    Image = (Image - Imin) / (Imax - Imin + eps)
                    userResponded = True
                    break

                elif order == 'q' or order == 'quit':
                    sys.exit('Aborting calculation.')

                else:
                    print(
                        textwrap.dedent('''
                         Invalid response. Please enter one of the following:
                         
                         Enter 'xy' if for each x, loop over y. (Default)
                         Enter 'yx' if for each y, loop over x.
                         Enter 'q/quit' to abort the calculation.
                         '''))

        if domain == 'freq':
            np.savez('imageNFE.npz', Image=Image, alpha=alpha, X=X, Y=Y)
        else:
            np.savez('imageNFE.npz',
                     Image=Image,
                     alpha=alpha,
                     X=X,
                     Y=Y,
                     tau=tau)

    #==============================================================================
    else:
        # Apply linear sampling method to four-dimensional space-time
        x = samplingGrid['x']
        y = samplingGrid['y']
        z = samplingGrid['z']
        tau = samplingGrid['tau']

        # Get number of sampling points in space and time
        Nx = len(x)
        Ny = len(y)
        Nz = len(z)
        X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

        # Initialize the Histogram for storing images at each sampling point in time.
        # Initialize the Image (time-integrated Histogram with respect to L2 norm)
        Image = np.zeros(X.shape)

        if medium == 'constant':
            # Vezda will compute free-space test functions over the space-time
            # sampling grid via function calls to 'FundamentalSolutions.py'. This is
            # much more efficient than applying a forward and inverse FFT pair to
            # shift the test functions in time corresponding to different sampling
            # points in time. FFT pairs are only used when medium == variable.
            pulse = lambda t: pulseFun.pulse(t)

            # Previously computed test functions and parameters from pulseFun module
            # are stored in 'VZTestFuncs.npz'. If the current space-time sampling grid
            # and pulseFun module parameters are consistent with those stored in
            # 'VZTestFuncs.npz', then Vezda will load the previously computed test
            # functions. Otherwise, Vezda will recompute the test functions. This reduces
            # computational cost by only computing test functions when necessary.

            if Path('VZTestFuncs.npz').exists():
                print(
                    '\nDetected that free-space test functions have already been computed...'
                )
                print(
                    'Checking consistency with current space-time sampling grid and pulse function...'
                )
                TFDict = np.load('VZTestFuncs.npz')

                if samplingIsCurrent(TFDict, receiverPoints, times, velocity,
                                     tau, x, y, z, peakFreq, peakTime):
                    print('Moving forward to imaging algorithm...')
                    TFarray = TFDict['TFarray']

                else:

                    if tau[0] != 0:
                        tu = plotParams['tu']
                        if tu != '':
                            print(
                                'Recomputing test functions for focusing time %0.2f %s...'
                                % (tau[0], tu))
                        else:
                            print(
                                'Recomputing test functions for focusing time %0.2f...'
                                % (tau[0]))
                        TFarray, samplingPoints = sampleSpace(
                            receiverPoints, times - tau[0], velocity, x, y, z,
                            pulse)
                    else:
                        print('Recomputing test functions...')
                        TFarray, samplingPoints = sampleSpace(
                            receiverPoints, times, velocity, x, y, z, pulse)

                    np.savez('VZTestFuncs.npz',
                             TFarray=TFarray,
                             time=times,
                             receivers=receiverPoints,
                             peakFreq=peakFreq,
                             peakTime=peakTime,
                             velocity=velocity,
                             x=x,
                             y=y,
                             z=z,
                             tau=tau,
                             samplingPoints=samplingPoints)

            else:
                print(
                    '\nComputing free-space test functions for the current space-time sampling grid...'
                )
                if tau[0] != 0:
                    tu = plotParams['tu']
                    if tu != '':
                        print(
                            'Computing test functions for focusing time %0.2f %s...'
                            % (tau[0], tu))
                    else:
                        print(
                            'Computing test functions for focusing time %0.2f...'
                            % (tau[0]))
                    TFarray, samplingPoints = sampleSpace(
                        receiverPoints, times - tau[0], velocity, x, y, z,
                        pulse)
                else:
                    TFarray, samplingPoints = sampleSpace(
                        receiverPoints, times, velocity, x, y, z, pulse)

                np.savez('VZTestFuncs.npz',
                         TFarray=TFarray,
                         time=times,
                         receivers=receiverPoints,
                         peakFreq=peakFreq,
                         peakTime=peakTime,
                         velocity=velocity,
                         x=x,
                         y=y,
                         z=z,
                         tau=tau,
                         samplingPoints=samplingPoints)

            #==============================================================================
            if domain == 'freq':
                # Transform test functions into the frequency domain and bandpass for efficient solution
                # to near-field equation
                print('Transforming test functions to the frequency domain...')

                N = nextPow2(Nt)
                TFarray = np.fft.rfft(TFarray, n=N, axis=1)

                if plotParams['fmax'] is None:
                    freqs = np.fft.rfftfreq(N, tstep * dt)
                    plotParams['fmax'] = np.max(freqs)
                    pickle.dump(plotParams, open('plotParams.pkl', 'wb'),
                                pickle.HIGHEST_PROTOCOL)

                # Apply the frequency window
                fmin = plotParams['fmin']
                fmax = plotParams['fmax']
                fu = plotParams['fu']  # frequency units (e.g., Hz)

                if fu != '':
                    print('Applying bandpass filter: [%0.2f %s, %0.2f %s]' %
                          (fmin, fu, fmax, fu))
                else:
                    print('Applying bandpass filter: [%0.2f, %0.2f]' %
                          (fmin, fmax))

                df = 1.0 / (N * tstep * dt)
                startIndex = int(round(fmin / df))
                stopIndex = int(round(fmax / df))

                finterval = np.arange(startIndex, stopIndex, 1)
                TFarray = TFarray[:, finterval, :]

            N = TFarray.shape[1]
            #==============================================================================
            # Solve the near-field equation for each sampling point
            print('Localizing the source...')
            # Compute the Tikhonov-regularized solution to the near-field equation N * phi = tf.
            # 'tf' is a test function
            # 'alpha' is the regularization parameter
            # 'phi_alpha' is the regularized solution given 'alpha'

            k = 0  # counter for spatial sampling points
            for ix in trange(Nx, desc='Solving system'):
                for iy in range(Ny):
                    for iz in range(Nz):
                        tf = np.reshape(TFarray[:, :, k], (N * Nr, 1))
                        phi_alpha = Tikhonov(Uh, s, V, tf, alpha)
                        Image[ix, iy, iz] = 1.0 / (norm(phi_alpha) + eps)
                        k += 1

            Imin = np.min(Image)
            Imax = np.max(Image)
            Image = (Image - Imin) / (Imax - Imin + eps)

        elif medium == 'variable':
            if 'testFuncs' in datadir:
                # Load the user-provided test functions
                TFarray = np.load(str(datadir['testFuncs']))

                # Apply the receiver/time windows, if any
                TFarray = TFarray[rinterval, :, :]
                TFarray = TFarray[:, tinterval, :]

                #==============================================================================
                if domain == 'freq':
                    # Transform test functions into the frequency domain and bandpass for efficient solution
                    # to near-field equation
                    print(
                        'Transforming test functions to the frequency domain...'
                    )

                    N = nextPow2(Nt)
                    TFarray = np.fft.rfft(TFarray, n=N, axis=1)

                    if plotParams['fmax'] is None:
                        freqs = np.fft.rfftfreq(N, tstep * dt)
                        plotParams['fmax'] = np.max(freqs)
                        pickle.dump(plotParams, open('plotParams.pkl', 'wb'),
                                    pickle.HIGHEST_PROTOCOL)

                    # Apply the frequency window
                    fmin = plotParams['fmin']
                    fmax = plotParams['fmax']
                    fu = plotParams['fu']  # frequency units (e.g., Hz)

                    if fu != '':
                        print(
                            'Applying bandpass filter: [%0.2f %s, %0.2f %s]' %
                            (fmin, fu, fmax, fu))
                    else:
                        print('Applying bandpass filter: [%0.2f, %0.2f]' %
                              (fmin, fmax))

                    df = 1.0 / (N * tstep * dt)
                    startIndex = int(round(fmin / df))
                    stopIndex = int(round(fmax / df))

                    finterval = np.arange(startIndex, stopIndex, 1)
                    TFarray = TFarray[:, finterval, :]

                N = TFarray.shape[1]

                # Load the sampling points
                samplingPoints = np.load(str(datadir['samplingPoints']))

            else:
                sys.exit(
                    textwrap.dedent('''
                        FileNotFoundError: Attempted to load file containing test
                        functions, but no such file exists. If a file exists containing
                        the test functions, run:
                            
                            'vzdata --path=<path/to/data/>'
                        
                        and specify the file containing the test functions when prompted.
                        Otherwise, specify 'no' when asked if a file containing the test
                        functions exists.
                        '''))

            userResponded = False
            print(
                textwrap.dedent('''
                 In what order was the sampling grid spanned to compute the test functions?
                 
                 Enter 'xyz' if for each x, for each y, loop over z. (Default)
                 Enter 'xzy' if for each x, for each z, loop over y.
                 Enter 'yxz' if for each y, for each x, loop over z.
                 Enter 'yzx' if for each y, for each z, loop over x.
                 Enter 'zxy' if for each z, for each x, loop over y.
                 Enter 'zyx' if for each z, for each y, loop over x
                 Enter 'q/quit' to abort the calculation.
                 '''))
            while userResponded == False:
                order = input('Order: ')
                if order == '' or order == 'xyz':
                    print('Proceeding with order \'xyz\'...')
                    print('Localizing the source...')
                    # Compute the Tikhonov-regularized solution to the near-field equation N * phi = tf.
                    # 'tf' is a test function
                    # 'alpha' is the regularization parameter
                    # 'phi_alpha' is the regularized solution given 'alpha'

                    k = 0  # counter for spatial sampling points
                    for ix in trange(Nx, desc='Solving system'):
                        for iy in range(Ny):
                            for iz in range(Nz):
                                tf = np.reshape(TFarray[:, :, k], (N * Nr, 1))
                                phi_alpha = Tikhonov(Uh, s, V, tf, alpha)
                                Image[ix, iy,
                                      iz] = 1.0 / (norm(phi_alpha) + eps)
                                k += 1

                    Imin = np.min(Image)
                    Imax = np.max(Image)
                    Image = (Image - Imin) / (Imax - Imin + eps)
                    userResponded = True
                    break

                elif order == 'xzy':
                    print('Proceeding with order \'xzy\'...')
                    print('Localizing the source...')
                    # Compute the Tikhonov-regularized solution to the near-field equation N * phi = tf.
                    # 'tf' is a test function
                    # 'alpha' is the regularization parameter
                    # 'phi_alpha' is the regularized solution given 'alpha'

                    k = 0  # counter for spatial sampling points
                    for ix in trange(Nx, desc='Solving system'):
                        for iz in range(Nz):
                            for iy in range(Ny):
                                tf = np.reshape(TFarray[:, :, k], (N * Nr, 1))
                                phi_alpha = Tikhonov(Uh, s, V, tf, alpha)
                                Image[ix, iy,
                                      iz] = 1.0 / (norm(phi_alpha) + eps)
                                k += 1

                    Imin = np.min(Image)
                    Imax = np.max(Image)
                    Image = (Image - Imin) / (Imax - Imin + eps)
                    userResponded = True
                    break

                elif order == 'yxz':
                    print('Proceeding with order \'yxz\'...')
                    print('Localizing the source...')
                    # Compute the Tikhonov-regularized solution to the near-field equation N * phi = tf.
                    # 'tf' is a test function
                    # 'alpha' is the regularization parameter
                    # 'phi_alpha' is the regularized solution given 'alpha'

                    k = 0  # counter for spatial sampling points
                    for iy in trange(Ny, desc='Solving system'):
                        for ix in range(Nx):
                            for iz in range(Nz):
                                tf = np.reshape(TFarray[:, :, k], (N * Nr, 1))
                                phi_alpha = Tikhonov(Uh, s, V, tf, alpha)
                                Image[ix, iy,
                                      iz] = 1.0 / (norm(phi_alpha) + eps)
                                k += 1

                    Imin = np.min(Image)
                    Imax = np.max(Image)
                    Image = (Image - Imin) / (Imax - Imin + eps)
                    userResponded = True
                    break

                elif order == 'yzx':
                    print('Proceeding with order \'yzx\'...')
                    print('Localizing the source...')
                    # Compute the Tikhonov-regularized solution to the near-field equation N * phi = tf.
                    # 'tf' is a test function
                    # 'alpha' is the regularization parameter
                    # 'phi_alpha' is the regularized solution given 'alpha'

                    k = 0  # counter for spatial sampling points
                    for iy in trange(Ny, desc='Solving system'):
                        for iz in range(Nz):
                            for ix in range(Nx):
                                tf = np.reshape(TFarray[:, :, k], (N * Nr, 1))
                                phi_alpha = Tikhonov(Uh, s, V, tf, alpha)
                                Image[ix, iy,
                                      iz] = 1.0 / (norm(phi_alpha) + eps)
                                k += 1

                    Imin = np.min(Image)
                    Imax = np.max(Image)
                    Image = (Image - Imin) / (Imax - Imin + eps)
                    userResponded = True
                    break

                elif order == 'zxy':
                    print('Proceeding with order \'zxy\'...')
                    print('Localizing the source...')
                    # Compute the Tikhonov-regularized solution to the near-field equation N * phi = tf.
                    # 'tf' is a test function
                    # 'alpha' is the regularization parameter
                    # 'phi_alpha' is the regularized solution given 'alpha'

                    k = 0  # counter for spatial sampling points
                    for iz in trange(Nz, desc='Solving system'):
                        for ix in range(Nx):
                            for iy in range(Ny):
                                tf = np.reshape(TFarray[:, :, k], (N * Nr, 1))
                                phi_alpha = Tikhonov(Uh, s, V, tf, alpha)
                                Image[ix, iy,
                                      iz] = 1.0 / (norm(phi_alpha) + eps)
                                k += 1

                    Imin = np.min(Image)
                    Imax = np.max(Image)
                    Image = (Image - Imin) / (Imax - Imin + eps)
                    userResponded = True
                    break

                elif order == 'zyx':
                    print('Proceeding with order \'zyx\'...')
                    print('Localizing the source...')
                    # Compute the Tikhonov-regularized solution to the near-field equation N * phi = tf.
                    # 'tf' is a test function
                    # 'alpha' is the regularization parameter
                    # 'phi_alpha' is the regularized solution given 'alpha'

                    k = 0  # counter for spatial sampling points
                    for iz in trange(Nz, desc='Solving system'):
                        for iy in range(Ny):
                            for ix in range(Nx):
                                tf = np.reshape(TFarray[:, :, k], (N * Nr, 1))
                                phi_alpha = Tikhonov(Uh, s, V, tf, alpha)
                                Image[ix, iy,
                                      iz] = 1.0 / (norm(phi_alpha) + eps)
                                k += 1

                    Imin = np.min(Image)
                    Imax = np.max(Image)
                    Image = (Image - Imin) / (Imax - Imin + eps)
                    userResponded = True
                    break

                elif order == 'q' or order == 'quit':
                    sys.exit('Aborting calculation.')
                else:
                    print(
                        textwrap.dedent('''
                         Invalid response. Please enter one of the following:
                         
                         Enter 'xyz' if for each x, for each y, loop over z. (Default)
                         Enter 'xzy' if for each x, for each z, loop over y.
                         Enter 'yxz' if for each y, for each x, loop over z.
                         Enter 'yzx' if for each y, for each z, loop over x.
                         Enter 'zxy' if for each z, for each x, loop over y.
                         Enter 'zyx' if for each z, for each y, loop over x
                         Enter 'q/quit' to abort the calculation.
                         '''))

        if domain == 'freq':
            np.savez('imageNFE.npz', Image=Image, alpha=alpha, X=X, Y=Y, Z=Z)
        else:
            np.savez('imageNFE.npz',
                     Image=Image,
                     alpha=alpha,
                     X=X,
                     Y=Y,
                     Z=Z,
                     tau=tau)
コード例 #2
0
def cli():
    """Command-line entry point for adding band-limited white noise to the data.

    Behavior depends on which of '--fmin', '--fmax' and '--snr' are given:

    * none given: report the noise parameters stored in 'noisyData.npz'
      (or that no noise has been added) and exit;
    * all given: validate them, add noise to the recorded data with
      ``add_noise`` and save the noisy data together with the parameters
      to 'noisyData.npz';
    * only some given: exit with an error message.

    Raises:
        SystemExit: when only reporting, when the arguments are invalid, or
            when the arguments are incomplete.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--fmin',
        type=float,
        help='Specify the minimum frequency component of the noise.')
    parser.add_argument(
        '--fmax',
        type=float,
        help='Specify the maximum frequency component of the noise.')
    parser.add_argument(
        '--snr',
        type=float,
        help='''Specify the desired signal-to-noise ratio. Must be a positive
                                real number.''')
    args = parser.parse_args()

    #==============================================================================
    # Recover the noise parameters of a previous run, if there was one.
    # The archive is closed as soon as the three values have been read.
    try:
        with np.load('noisyData.npz') as Dict:
            fmin = Dict['fmin']
            fmax = Dict['fmax']
            snr = Dict['snr']
    except FileNotFoundError:
        fmin, fmax, snr = None, None, None

    # Used for getting frequency units
    if Path('plotParams.pkl').exists():
        with open('plotParams.pkl', 'rb') as paramFile:
            plotParams = pickle.load(paramFile)
    else:
        plotParams = default_params()

    passed = [args.fmin, args.fmax, args.snr]

    if all(v is None for v in passed):
        # if no arguments are passed

        if all(v is None for v in [fmin, fmax, snr]):
            # and no parameters have been assigned values
            sys.exit(
                textwrap.dedent('''
                    No noise has been added to the data.
                    '''))

        # otherwise print fmin, fmax, snr and exit
        fu = plotParams['fu']
        if fu != '':
            sys.exit(
                textwrap.dedent('''
                    Band-limited white noise has already been added to the data:

                    Minimum frequency: {:0.2f} {}
                    Maximum frequency: {:0.2f} {}
                    Signal-to-noise ratio: {:0.2f}
                    '''.format(fmin, fu, fmax, fu, snr)))
        sys.exit(
            textwrap.dedent('''
                Band-limited white noise has already been added to the data:

                Minimum frequency: {:0.2f}
                Maximum frequency: {:0.2f}
                Signal-to-noise ratio: {:0.2f}
                '''.format(fmin, fmax, snr)))

    if any(v is None for v in passed):
        # only some of the arguments were passed
        sys.exit(
            textwrap.dedent('''
                Error: All command-line arguments \'--fmin\', \'--fmax\', and \'--snr\' must
                be used when parameterizing the noise.
                '''))

    # All arguments were passed: validate them before doing any work.
    if args.fmax < args.fmin:
        sys.exit(
            textwrap.dedent('''
                RelationError: The maximum frequency component of the noise must be greater
                than or equal to the minimum frequency component.
                '''))
    if args.fmin <= 0:
        sys.exit(
            textwrap.dedent('''
                ValueError: The minimum frequency component of the noise must be strictly positive.
                '''))
    if args.snr <= 0:
        sys.exit(
            textwrap.dedent('''
                ValueError: The signal-to-noise ratio (SNR) must be strictly positive.
                '''))

    fmin = args.fmin
    fmax = args.fmax
    snr = args.snr
    fu = plotParams['fu']
    if fu != '':
        print(
            textwrap.dedent('''
              Adding band-limited white noise:

              Minimum frequency: {:0.2f} {}
              Maximum frequency: {:0.2f} {}
              Signal-to-noise ratio: {:0.2f}
              '''.format(fmin, fu, fmax, fu, snr)))
    else:
        print(
            textwrap.dedent('''
              Adding band-limited white noise:

              Minimum frequency: {:0.2f}
              Maximum frequency: {:0.2f}
              Signal-to-noise ratio: {:0.2f}
              '''.format(fmin, fmax, snr)))

    # Load the 3D data array and recording times from data directory
    datadir = np.load('datadir.npz')
    recordedData = np.load(str(datadir['recordedData']))
    recordingTimes = np.load(str(datadir['recordingTimes']))
    # Length of one time step, taken from the first two recording times
    dt = recordingTimes[1] - recordingTimes[0]

    noisyData = add_noise(recordedData, dt, fmin, fmax, snr)
    np.savez('noisyData.npz',
             noisyData=noisyData,
             fmin=fmin,
             fmax=fmax,
             snr=snr)
コード例 #3
0
def cli():
    """Command-line entry point for plotting wiggle traces.

    Parses the command-line options, merges them into the plot parameters
    stored in 'plotParams.pkl' (falling back to ``default_params()`` when the
    file does not exist), writes the parameters back to 'plotParams.pkl', and
    opens an interactive figure of either the recorded data (the default) or
    the simulated test functions.  With '--map', a second axis showing the
    receiver and source/sampling point locations is drawn next to the wiggles.

    Raises:
        SystemExit: if both '--data' and '--testfunc' are passed, or if the
            user answers 'q'/'quit' when asked which test functions to view.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', action='store_true',
                        help='Plot the recorded data. (Default)')
    parser.add_argument('--testfunc', action='store_true',
                        help='Plot the simulated test functions.')
    parser.add_argument('--tu', type=str,
                        help='Specify the time units (e.g., \'s\' or \'ms\').')
    parser.add_argument('--au', type=str,
                        help='Specify the amplitude units (e.g., \'m\' or \'mm\').')
    parser.add_argument('--pclip', type=float,
                        help='''Specify the percentage (0-1) of the peak amplitude to display. This
                        parameter is used for pcolormesh plots only. Default is set to 1.''')
    parser.add_argument('--title', type=str,
                        help='''Specify a title for the wiggle plot. Default title is
                        \'Data\' if \'--data\' is passed and 'Test Function' if \'--testfunc\'
                        is passed.''')
    # NOTE(review): '--format' is parsed but never read in this function.
    parser.add_argument('--format', '-f', type=str, default='pdf', choices=['png', 'pdf', 'ps', 'eps', 'svg'],
                        help='''Specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument('--map', action='store_true',
                        help='''Plot a map of the receiver and source/sampling point locations. The current
                        source/sampling point will be highlighted. The boundary of the scatterer will also
                        be shown if available.''')
    parser.add_argument('--mode', type=str, choices=['light', 'dark'], required=False,
                        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')

    args = parser.parse_args()

    #==============================================================================
    # Reject contradictory flags up front, before any file is read or written,
    # so that an invalid invocation leaves the saved plot parameters untouched.
    if args.data and args.testfunc:
        sys.exit(textwrap.dedent(
                '''
                Error: Cannot plot both recorded data and simulated test functions. Use

                    vzwiggles --data

                to plot the recorded data or

                    vzwiggles --testfunc

                to plot the simulated test functions.
                '''))

    #==============================================================================
    # Start from the saved plot parameters if they exist, otherwise from the
    # defaults.  'plotParams.pkl' is the file this function writes below;
    # pickle must not be pointed at untrusted data.
    if Path('plotParams.pkl').exists():
        with open('plotParams.pkl', 'rb') as paramFile:
            plotParams = pickle.load(paramFile)
    else:
        plotParams = default_params()

    # Update parameters for wiggle plots based on passed arguments.  The same
    # updates apply whether the parameters were loaded or freshly created.
    if args.mode is not None:
        plotParams['view_mode'] = args.mode

    if args.tu is not None:
        plotParams['tu'] = args.tu

    if args.au is not None:
        plotParams['au'] = args.au

    if args.pclip is not None:
        if 0 <= args.pclip <= 1:
            plotParams['pclip'] = args.pclip
        else:
            print(textwrap.dedent(
                  '''
                  Warning: Invalid value passed to argument \'--pclip\'. Value must be
                  between 0 and 1.
                  '''))

    if args.title is not None:
        # Recorded data is plotted unless '--testfunc' is passed, so a bare
        # '--title' applies to the data plot.
        if args.testfunc:
            plotParams['tf_title'] = args.title
        else:
            plotParams['data_title'] = args.title

    with open('plotParams.pkl', 'wb') as paramFile:
        pickle.dump(plotParams, paramFile, pickle.HIGHEST_PROTOCOL)

    #==============================================================================
    # Load the relevant data to plot
    datadir = np.load('datadir.npz')
    receiverPoints = np.load(str(datadir['receivers']))
    time = np.load(str(datadir['recordingTimes']))

    # Apply any user-specified windows
    rinterval, tinterval, tstep, dt, sinterval = get_user_windows()
    receiverPoints = receiverPoints[rinterval, :]
    time = time[tinterval]

    # Load the scatterer boundary, if it exists
    if 'scatterer' in datadir:
        scatterer = np.load(str(datadir['scatterer']))
    else:
        scatterer = None

    if args.testfunc:
        wiggleType = 'testfunc'

        # Update time to convolution times
        T = time[-1] - time[0]
        time = np.linspace(-T, T, 2 * len(time) - 1)

        userProvided = 'testFuncs' in datadir
        vezdaComputed = Path('VZTestFuncs.npz').exists()

        if not userProvided:
            # No user-provided test functions: load with medium='constant',
            # whether or not 'VZTestFuncs.npz' already exists.
            medium = 'constant'
        elif not vezdaComputed:
            medium = 'variable'
        else:
            print(textwrap.dedent(
                 '''
                 Two files are available containing simulated test functions.

                 Enter '1' to view the user-provided test functions. (Default)
                 Enter '2' to view the test functions computed by Vezda.
                 Enter 'q/quit' to exit.
                 '''))
            while True:
                answer = input('Action: ')

                if answer == '' or answer == '1':
                    medium = 'variable'
                    break

                if answer == '2':
                    medium = 'constant'
                    break

                if answer == 'q' or answer == 'quit':
                    sys.exit('Exiting program.')

                print('Invalid response. Please enter \'1\', \'2\', or \'q/quit\'.')

        X, sourcePoints = load_test_funcs(domain='time', medium=medium,
                                          verbose=True, return_sampling_points=True)

        if medium == 'variable':
            # Pad time axis with zeros to length of convolution 2Nt-1
            npad = ((0, 0), (X.shape[1] - 1, 0), (0, 0))
            X = np.pad(X, pad_width=npad, mode='constant', constant_values=0)

    else:
        # Plot recorded data ('--data' passed, or nothing specified).
        # load the 3D data array into variable 'X'
        # X[receiver, time, source]
        wiggleType = 'data'
        X = load_data(domain='time', verbose=True)

        if 'sources' in datadir:
            sourcePoints = np.load(str(datadir['sources']))
            sourcePoints = sourcePoints[sinterval, :]
        else:
            sourcePoints = None

    #==============================================================================
    # increment source/recording interval and receiver interval to be consistent
    # with one-based indexing (i.e., count from one instead of zero)
    sinterval += 1
    rinterval += 1

    Ns = X.shape[2]

    remove_keymap_conflicts({'left', 'right', 'up', 'down', 'save'})
    if args.map:
        fig, wiggleAx, mapAx = setFigure(num_axes=2, mode=plotParams['view_mode'],
                                         ax2_dim=receiverPoints.shape[1])
    else:
        fig, wiggleAx = setFigure(num_axes=1, mode=plotParams['view_mode'])
        mapAx = None

    # Start the viewer on the middle slice along the third axis of X
    wiggleAx.volume = X
    wiggleAx.index = Ns // 2
    title = wave_title(wiggleAx.index, sinterval, sourcePoints, wiggleType, plotParams)
    plotWiggles(wiggleAx, X[:, :, wiggleAx.index], time, rinterval, receiverPoints,
                title, wiggleType, plotParams)

    if mapAx is not None:
        mapAx.index = wiggleAx.index
        plotMap(mapAx, mapAx.index, receiverPoints, sourcePoints, scatterer,
                wiggleType, plotParams)

    plt.tight_layout()
    fig.canvas.mpl_connect('key_press_event',
                           lambda event: process_key_waves(event, time, rinterval, sinterval,
                                                           receiverPoints, sourcePoints, Ns, scatterer,
                                                           args.map, wiggleType, plotParams))

    plt.show()
コード例 #4
0
def cli():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--nfe',
        action='store_true',
        help='''Plot the image obtained by solving the near-field equation.''')
    parser.add_argument(
        '--lse',
        action='store_true',
        help=
        '''Plot the image obtained by solving the Lippmann-Schwinger equation.'''
    )
    parser.add_argument(
        '--movie',
        action='store_true',
        help='Save a three-dimensional figure as a rotating frame.')
    parser.add_argument('--isolevel',
                        type=float,
                        default=None,
                        help='''Specify the contour level of the isosurface for
                        three-dimensional visualizations. Level must be between 0 and 1.'''
                        )
    parser.add_argument(
        '--format',
        '-f',
        type=str,
        default=None,
        choices=['png', 'pdf', 'ps', 'eps', 'svg'],
        help=
        '''specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument('--xlabel',
                        type=str,
                        default=None,
                        help='''specify a label for the x-axis.''')
    parser.add_argument('--ylabel',
                        type=str,
                        default=None,
                        help='''specify a label for the y-axis.''')
    parser.add_argument('--zlabel',
                        type=str,
                        default=None,
                        help='''specify a label for the z-axis.''')
    parser.add_argument(
        '--xu',
        type=str,
        default=None,
        help='Specify the units for the x-axis (e.g., \'m\' or \'km\').')
    parser.add_argument(
        '--yu',
        type=str,
        default=None,
        help='Specify the units for the y-axis (e.g., \'m\' or \'km\').')
    parser.add_argument(
        '--zu',
        type=str,
        default=None,
        help='Specify the units for the z-axis (e.g., \'m\' or \'km\').')
    parser.add_argument(
        '--colormap',
        type=str,
        default=None,
        choices=['viridis', 'plasma', 'inferno', 'magma', 'cividis'],
        help=
        'specify a perceptually uniform sequential colormap. Default is \'magma\''
    )
    parser.add_argument(
        '--colorbar',
        type=str,
        default=None,
        choices=['n', 'no', 'false', 'y', 'yes', 'true'],
        help=
        'specify \'y/yes/true\' to plot colorbar. Default is \'n/no/false\'')
    parser.add_argument(
        '--invert_xaxis',
        type=str,
        default=None,
        choices=['n', 'no', 'false', 'y', 'yes', 'true'],
        help=
        'specify \'y/yes/true\' to invert x-axis. Default is \'n/no/false\'')
    parser.add_argument(
        '--invert_yaxis',
        type=str,
        default=None,
        choices=['n', 'no', 'false', 'y', 'yes', 'true'],
        help=
        'specify \'y/yes/true\' to invert y-axis. Default is \'n/no/false\'')
    parser.add_argument(
        '--invert_zaxis',
        type=str,
        default=None,
        choices=['n', 'no', 'false', 'y', 'yes', 'true'],
        help=
        'specify \'y/yes/true\' to invert z-axis. Default is \'n/no/false\'')
    parser.add_argument(
        '--show_scatterer',
        type=str,
        default=None,
        choices=['n', 'no', 'false', 'y', 'yes', 'true'],
        help=
        'specify \'y/yes/true\' to show scatterer boundary. Default is \'n/no/false\''
    )
    parser.add_argument(
        '--show_sources',
        type=str,
        default=None,
        choices=['n', 'no', 'false', 'y', 'yes', 'true'],
        help='specify \'n/no/false\' to hide sources. Default is \'y/yes/true\''
    )
    parser.add_argument(
        '--show_receivers',
        type=str,
        default=None,
        choices=['n', 'no', 'false', 'y', 'yes', 'true'],
        help=
        'specify \'n/no/false\' to hide receivers. Default is \'y/yes/true\'')
    parser.add_argument(
        '--mode',
        type=str,
        choices=['light', 'dark'],
        required=False,
        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')
    args = parser.parse_args()

    #==============================================================================

    # if a plotParams.pkl file already exists, load relevant parameters
    if Path('plotParams.pkl').exists():
        plotParams = pickle.load(open('plotParams.pkl', 'rb'))

        # for both wiggle plots and image plots
        if args.format is not None:
            plotParams['pltformat'] = args.format

        if args.mode is not None:
            plotParams['view_mode'] = args.mode

        # for image/map plots
        if args.isolevel is not None:
            plotParams['isolevel'] = args.isolevel

        if args.xlabel is not None:
            plotParams['xlabel'] = args.xlabel

        if args.ylabel is not None:
            plotParams['ylabel'] = args.ylabel

        if args.zlabel is not None:
            plotParams['zlabel'] = args.zlabel
        #==============================================================================
        # handle units here
        if args.xu is not None:
            plotParams['xu'] = args.xu

        if args.yu is not None:
            plotParams['yu'] = args.yu

        if args.zu is not None:
            plotParams['zu'] = args.zu
        #==============================================================================
        if args.colormap is not None:
            plotParams['colormap'] = args.colormap
        #==============================================================================
        if args.colorbar is not None:
            if args.colorbar == 'n' or args.colorbar == 'no' or args.colorbar == 'false':
                plotParams['colorbar'] = False

            elif args.colorbar == 'y' or args.colorbar == 'yes' or args.colorbar == 'true':
                plotParams['colorbar'] = True
        #==============================================================================
        if args.invert_xaxis is not None:
            if args.invert_xaxis == 'n' or args.invert_xaxis == 'no' or args.invert_xaxis == 'false':
                plotParams['invert_xaxis'] = False

            elif args.invert_xaxis == 'y' or args.invert_xaxis == 'yes' or args.invert_xaxis == 'true':
                plotParams['invert_xaxis'] = True
        #==============================================================================
        if args.invert_yaxis is not None:
            if args.invert_yaxis == 'n' or args.invert_yaxis == 'no' or args.invert_yaxis == 'false':
                plotParams['invert_yaxis'] = False

            elif args.invert_yaxis == 'y' or args.invert_yaxis == 'yes' or args.invert_yaxis == 'true':
                plotParams['invert_yaxis'] = True
        #==============================================================================
        if args.invert_zaxis is not None:
            if args.invert_zaxis == 'n' or args.invert_zaxis == 'no' or args.invert_zaxis == 'false':
                plotParams['invert_zaxis'] = False

            elif args.invert_zaxis == 'y' or args.invert_zaxis == 'yes' or args.invert_zaxis == 'true':
                plotParams['invert_zaxis'] = True
        #==============================================================================
        if args.show_scatterer is not None:
            if args.show_scatterer == 'n' or args.show_scatterer == 'no' or args.show_scatterer == 'false':
                plotParams['show_scatterer'] = False

            elif args.show_scatterer == 'y' or args.show_scatterer == 'yes' or args.show_scatterer == 'true':
                plotParams['show_scatterer'] = True
        #==============================================================================
        if args.show_sources is not None:
            if args.show_sources == 'n' or args.show_sources == 'no' or args.show_sources == 'false':
                plotParams['show_sources'] = False

            elif args.show_sources == 'y' or args.show_sources == 'yes' or args.show_sources == 'true':
                plotParams['show_sources'] = True
        #==============================================================================
        if args.show_receivers is not None:
            if args.show_receivers == 'n' or args.show_receivers == 'no' or args.show_receivers == 'false':
                plotParams['show_receivers'] = False

            elif args.show_receivers == 'y' or args.show_receivers == 'yes' or args.show_receivers == 'true':
                plotParams['show_receivers'] = True

    #==============================================================================
    else:  # create a plotParams dictionary file with default values
        plotParams = default_params()

        # updated parameters based on passed arguments
        #for both image and wiggle plots
        if args.format is not None:
            plotParams['pltformat'] = args.format

        # for image/map plots
        if args.isolevel is not None:
            plotParams['isolevel'] = args.isolevel

        if args.colormap is not None:
            plotParams['colormap'] = args.colormap

        if args.colorbar is not None:
            if args.colorbar == 'n' or args.colorbar == 'no' or args.colorbar == 'false':
                plotParams['colorbar'] = False

            elif args.colorbar == 'y' or args.colorbar == 'yes' or args.colorbar == 'true':
                plotParams['colorbar'] = True

        if args.xlabel is not None:
            plotParams['xlabel'] = args.xlabel

        if args.ylabel is not None:
            plotParams['ylabel'] = args.ylabel

        if args.zlabel is not None:
            plotParams['zlabel'] = args.zlabel

        #==============================================================================
        # update units
        if args.xu is not None:
            plotParams['xu'] = args.xu

        if args.yu is not None:
            plotParams['yu'] = args.yu

        if args.zu is not None:
            plotParams['zu'] = args.zu
        #==============================================================================
        if args.invert_xaxis is not None:
            if args.invert_xaxis == 'n' or args.invert_xaxis == 'no' or args.invert_xaxis == 'false':
                plotParams['invert_xaxis'] = False

            elif args.invert_xaxis == 'y' or args.invert_xaxis == 'yes' or args.invert_xaxis == 'true':
                plotParams['invert_xaxis'] = True
        #==============================================================================
        if args.invert_yaxis is not None:
            if args.invert_yaxis == 'n' or args.invert_yaxis == 'no' or args.invert_yaxis == 'false':
                plotParams['invert_yaxis'] = False

            elif args.invert_yaxis == 'y' or args.invert_yaxis == 'yes' or args.invert_yaxis == 'true':
                plotParams['invert_yaxis'] = True
        #==============================================================================
        if args.invert_zaxis is not None:
            if args.invert_zaxis == 'n' or args.invert_zaxis == 'no' or args.invert_zaxis == 'false':
                plotParams['invert_zaxis'] = False

            elif args.invert_zaxis == 'y' or args.invert_zaxis == 'yes' or args.invert_zaxis == 'true':
                plotParams['invert_zaxis'] = True
        #==============================================================================
        if args.show_scatterer is not None:
            if args.show_scatterer == 'n' or args.show_scatterer == 'no' or args.show_scatterer == 'false':
                plotParams['show_scatterer'] = False

            elif args.show_scatterer == 'y' or args.show_scatterer == 'yes' or args.show_scatterer == 'true':
                plotParams['show_scatterer'] = True
        #==============================================================================
        if args.show_sources is not None:
            if args.show_sources == 'n' or args.show_sources == 'no' or args.show_sources == 'false':
                plotParams['show_sources'] = False

            elif args.show_sources == 'y' or args.show_sources == 'yes' or args.show_sources == 'true':
                plotParams['show_sources'] = True
        #==============================================================================
        if args.show_receivers is not None:
            if args.show_receivers == 'n' or args.show_receivers == 'no' or args.show_receivers == 'false':
                plotParams['show_receivers'] = False

            elif args.show_receivers == 'y' or args.show_receivers == 'yes' or args.show_receivers == 'true':
                plotParams['show_receivers'] = True

    pickle.dump(plotParams, open('plotParams.pkl', 'wb'),
                pickle.HIGHEST_PROTOCOL)

    #==============================================================================
    # load the shape of the scatterer and source/receiver locations
    datadir = np.load('datadir.npz')
    receiverPoints = np.load(str(datadir['receivers']))

    if 'sources' in datadir:
        sourcePoints = np.load(str(datadir['sources']))
    else:
        sourcePoints = None

    if 'scatterer' in datadir and plotParams['show_scatterer']:
        scatterer = np.load(str(datadir['scatterer']))
    else:
        scatterer = None
        if plotParams['show_scatterer']:
            print(
                textwrap.dedent('''
                  Warning: Attempted to load the file containing the scatterer coordinates,
                  but no such file exists. If a file exists containing the scatterer
                  points, run:
                      
                      'vzdata --path=<path/to/data/>'
                      
                  and specify the file containing the scatterer points when prompted. Otherwise,
                  specify 'no' when asked if a file containing the scatterer points exists.
                  '''))

    if Path('window.npz').exists():
        windowDict = np.load('window.npz')

        # Apply the receiver window
        rstart = windowDict['rstart']
        rstop = windowDict['rstop']
        rstep = windowDict['rstep']

        rinterval = np.arange(rstart, rstop, rstep)
        receiverPoints = receiverPoints[rinterval, :]

        if sourcePoints is not None:
            # Apply the source window
            sstart = windowDict['sstart']
            sstop = windowDict['sstop']
            sstep = windowDict['sstep']

            sinterval = np.arange(sstart, sstop, sstep)
            sourcePoints = sourcePoints[sinterval, :]

    #==============================================================================
    # Load the user-specified sampling grid
    if 'samplingGrid' in datadir:
        samplingGrid = np.load(str(datadir['samplingGrid']))
    else:
        try:
            samplingGrid = np.load('samplingGrid.npz')
        except FileNotFoundError:
            samplingGrid = None

    if samplingGrid is None:
        sys.exit(
            textwrap.dedent('''
                A sampling grid needs to be set up before it can be plotted.
                Enter:
                    
                    vzgrid --help
                
                from the command-line for more information on how to set up a
                sampling grid.
                '''))

    x = samplingGrid['x']
    y = samplingGrid['y']
    tau = samplingGrid['tau']
    if 'z' not in samplingGrid:
        X, Y = np.meshgrid(x, y, indexing='ij')
        Z = None
    else:
        z = samplingGrid['z']
        X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

    #==============================================================================
    if Path('imageNFE.npz').exists() and not Path('imageLSE.npz').exists():
        if args.lse:
            sys.exit(
                textwrap.dedent('''
                    PlotError: User requested to plot an image obtained by solving
                    the Lippmann-Schwinger equation (LSE), but no such image exists.
                    '''))

        # plot the image obtained by solving the near-field equaiton (NFE)
        Dict = np.load('imageNFE.npz')
        flag = 'NFE'
        fig, ax = plotImage(Dict, X, Y, Z, tau, plotParams, flag, args.movie)

    elif not Path('imageNFE.npz').exists() and Path('imageLSE.npz').exists():
        if args.nfe:
            sys.exit(
                textwrap.dedent('''
                    PlotError: User requested to plot an image obtained by solving
                    the near-field equation (NFE), but no such image exists.
                    '''))

        # plot the image obtained by solving the Lippmann-Schwinger equation (LSE)
        Dict = np.load('imageLSE.npz')
        flag = 'LSE'
        fig, ax = plotImage(Dict, X, Y, Z, tau, plotParams, flag)

    elif Path('imageNFE.npz').exists() and Path('imageLSE.npz').exists():
        if args.nfe and not args.lse:
            # plot the image obtained by solving the near-field equaiton (NFE)
            Dict = np.load('imageNFE.npz')
            flag = 'NFE'
            fig, ax = plotImage(Dict, X, Y, Z, tau, plotParams, flag,
                                args.movie)

        elif not args.nfe and args.lse:
            # plot the image obtained by solving the Lippmann-Schwinger equation (LSE)
            Dict = np.load('imageLSE.npz')
            flag = 'LSE'
            fig, ax = plotImage(Dict, X, Y, Z, tau, plotParams, flag)

        elif args.nfe and args.lse:
            sys.exit(
                textwrap.dedent('''
                    PlotError: Please specify only one of the arguments \'--nfe\' or \'--lse\' to
                    view the corresponding image.'''))

        else:
            flag = ''
            print(
                textwrap.dedent('''
                  Images obtained by solving both NFE and LSE are available. Enter:
                      
                      vzimage --nfe
                        
                  to view the image obtained by solving NFE or
                    
                      vzimage --lse
                        
                  to view the image obtained by solving LSE.
                  '''))

    else:
        flag = ''
        if args.nfe:
            print(
                'Warning: An image obtained by solving the near-field equation (NFE) does not exist.'
            )

        elif args.lse:
            print(
                'Warning: An image obtained by solving the Lippmann-Schwinger equation (LSE) does not exist.'
            )

        elif args.nfe and args.lse:
            print(
                textwrap.dedent('''
                    Warning: An image has not yet been obtained by solving either the
                    near-field equation (NFE) or the Lippmann-Schwinger equation (LSE).
                    '''))

    try:
        ax
    except NameError:
        fig, ax = setFigure(num_axes=1,
                            mode=plotParams['view_mode'],
                            ax1_dim=receiverPoints.shape[1])

    plotMap(ax, None, receiverPoints, sourcePoints, scatterer, 'data',
            plotParams)

    #==============================================================================

    pltformat = plotParams['pltformat']
    fig.savefig('image' + flag + '.' + pltformat,
                format=pltformat,
                bbox_inches='tight',
                facecolor=fig.get_facecolor(),
                transparent=True)
    plt.show()
コード例 #5
0
def _parse_image_args():
    """Build the argument parser for the image plotter and parse ``sys.argv``.

    Every option defaults to ``None`` (or ``False`` for the ``store_true``
    switches) so that the caller can tell "not given on the command line"
    apart from an explicit value.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    # Accepted spellings for the yes/no style options.
    yes_no = ['n', 'no', 'false', 'y', 'yes', 'true']

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--nfe',
        action='store_true',
        help='''Plot the image obtained by solving the near-field equation.''')
    parser.add_argument(
        '--lse',
        action='store_true',
        help=
        '''Plot the image obtained by solving the Lippmann-Schwinger equation.'''
    )
    parser.add_argument(
        '--spacetime',
        action='store_true',
        help='''Plot the support of the source function in space-time.''')
    parser.add_argument(
        '--movie',
        action='store_true',
        help='''Save the space-time figure as a rotating frame (2D space)
                        or as a time-lapse structure (3D space).''')
    parser.add_argument('--isolevel',
                        type=float,
                        default=None,
                        help='''Specify the contour level of the isosurface for
                        three-dimensional visualizations. Level must be between 0 and 1.'''
                        )
    parser.add_argument(
        '--format',
        '-f',
        type=str,
        default=None,
        choices=['png', 'pdf', 'ps', 'eps', 'svg'],
        help=
        '''specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument('--xlabel',
                        type=str,
                        default=None,
                        help='''specify a label for the x-axis.''')
    parser.add_argument('--ylabel',
                        type=str,
                        default=None,
                        help='''specify a label for the y-axis.''')
    parser.add_argument('--zlabel',
                        type=str,
                        default=None,
                        help='''specify a label for the z-axis.''')
    parser.add_argument(
        '--xu',
        type=str,
        default=None,
        help='Specify the units for the x-axis (e.g., \'m\' or \'km\').')
    parser.add_argument(
        '--yu',
        type=str,
        default=None,
        help='Specify the units for the y-axis (e.g., \'m\' or \'km\').')
    parser.add_argument(
        '--zu',
        type=str,
        default=None,
        help='Specify the units for the z-axis (e.g., \'m\' or \'km\').')
    parser.add_argument(
        '--colormap',
        type=str,
        default=None,
        choices=['viridis', 'plasma', 'inferno', 'magma', 'cividis'],
        help=
        'specify a perceptually uniform sequential colormap. Default is \'magma\''
    )
    parser.add_argument(
        '--colorbar',
        type=str,
        default=None,
        choices=yes_no,
        help=
        'specify \'y/yes/true\' to plot colorbar. Default is \'n/no/false\'')
    parser.add_argument(
        '--invert_xaxis',
        type=str,
        default=None,
        choices=yes_no,
        help=
        'specify \'y/yes/true\' to invert x-axis. Default is \'n/no/false\'')
    parser.add_argument(
        '--invert_yaxis',
        type=str,
        default=None,
        choices=yes_no,
        help=
        'specify \'y/yes/true\' to invert y-axis. Default is \'n/no/false\'')
    parser.add_argument(
        '--invert_zaxis',
        type=str,
        default=None,
        choices=yes_no,
        help=
        'specify \'y/yes/true\' to invert z-axis. Default is \'n/no/false\'')
    parser.add_argument(
        '--show_scatterer',
        type=str,
        default=None,
        choices=yes_no,
        help=
        'specify \'y/yes/true\' to show scatterer boundary. Default is \'n/no/false\''
    )
    parser.add_argument(
        '--show_sources',
        type=str,
        default=None,
        choices=yes_no,
        help='specify \'n/no/false\' to hide sources. Default is \'y/yes/true\''
    )
    parser.add_argument(
        '--show_receivers',
        type=str,
        default=None,
        choices=yes_no,
        help=
        'specify \'n/no/false\' to hide receivers. Default is \'y/yes/true\'')
    parser.add_argument(
        '--mode',
        type=str,
        choices=['light', 'dark'],
        required=False,
        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')
    return parser.parse_args()


def _apply_image_overrides(plotParams, args):
    """Copy every option given on the command line into ``plotParams``.

    Options left at ``None`` are skipped so the stored (or default) value
    is kept. The dictionary is modified in place.

    Args:
        plotParams: dictionary of plot parameters to update.
        args: namespace returned by ``_parse_image_args``.
    """
    # (plotParams key, argparse attribute) pairs whose value is stored as-is.
    verbatim = (
        ('pltformat', 'format'),
        ('view_mode', 'mode'),
        ('isolevel', 'isolevel'),
        ('xlabel', 'xlabel'),
        ('ylabel', 'ylabel'),
        ('zlabel', 'zlabel'),
        ('xu', 'xu'),
        ('yu', 'yu'),
        ('zu', 'zu'),
        ('colormap', 'colormap'),
    )
    for key, attr in verbatim:
        value = getattr(args, attr)
        if value is not None:
            plotParams[key] = value

    # Yes/no options share their name with the plotParams key and are
    # stored as booleans.
    switches = ('colorbar', 'invert_xaxis', 'invert_yaxis', 'invert_zaxis',
                'show_scatterer', 'show_sources', 'show_receivers')
    for name in switches:
        answer = getattr(args, name)
        if answer in ('y', 'yes', 'true'):
            plotParams[name] = True
        elif answer in ('n', 'no', 'false'):
            plotParams[name] = False


def _load_map_geometry(plotParams):
    """Load receiver, source and scatterer points for the map overlay.

    File locations are read from ``datadir.npz``. If ``window.npz`` exists,
    the receiver window (and the source window, when sources are available)
    is applied by row-indexing the point arrays. A warning is printed when
    the scatterer is requested through ``plotParams['show_scatterer']`` but
    ``datadir.npz`` has no ``'scatterer'`` entry.

    Args:
        plotParams: dictionary of plot parameters; only ``'show_scatterer'``
            is read.

    Returns:
        tuple: ``(receiverPoints, sourcePoints, scatterer)`` where the last
        two may be ``None``.
    """
    datadir = np.load('datadir.npz')
    receiverPoints = np.load(str(datadir['receivers']))

    sourcePoints = None
    if 'sources' in datadir:
        sourcePoints = np.load(str(datadir['sources']))

    scatterer = None
    if plotParams['show_scatterer']:
        if 'scatterer' in datadir:
            scatterer = np.load(str(datadir['scatterer']))
        else:
            print(
                textwrap.dedent('''
                  Warning: Attempted to load the file containing the scatterer coordinates,
                  but no such file exists. If a file exists containing the scatterer
                  points, run:
                      
                      'vzdata --path=<path/to/data/>'
                      
                  and specify the file containing the scatterer points when prompted. Otherwise,
                  specify 'no' when asked if a file containing the scatterer points exists.
                  '''))

    if Path('window.npz').exists():
        windowDict = np.load('window.npz')

        # Apply the receiver window
        rinterval = np.arange(windowDict['rstart'], windowDict['rstop'],
                              windowDict['rstep'])
        receiverPoints = receiverPoints[rinterval, :]

        if sourcePoints is not None:
            # Apply the source window
            sinterval = np.arange(windowDict['sstart'], windowDict['sstop'],
                                  windowDict['sstep'])
            sourcePoints = sourcePoints[sinterval, :]

    return receiverPoints, sourcePoints, scatterer


def _select_image(args):
    """Decide which saved image to plot from the files present and the flags.

    Exits the program (``sys.exit`` with a message) when the request cannot
    be honoured: the requested image does not exist while the other one
    does, both ``--nfe`` and ``--lse`` were given while both images exist,
    or neither was given while both images exist.

    Args:
        args: namespace returned by ``_parse_image_args``.

    Returns:
        str: ``'NFE'``, ``'LSE'``, or ``''`` when no image file exists.
    """
    nfe_exists = Path('imageNFE.npz').exists()
    lse_exists = Path('imageLSE.npz').exists()

    if nfe_exists and lse_exists:
        if args.nfe and args.lse:
            sys.exit(
                textwrap.dedent('''
                    PlotError: Please specify only one of the arguments \'--nfe\' or \'--lse\' to
                    view the corresponding image.'''))
        if not args.nfe and not args.lse:
            sys.exit(
                textwrap.dedent('''
                    Images obtained by solving both NFE and LSE are available. Enter:
                        
                        vzimage --nfe
                        
                    to view the image obtained by solving NFE or
                    
                        vzimage --lse
                        
                    to view the image obtained by solving LSE.
                    '''))
        return 'NFE' if args.nfe else 'LSE'

    if nfe_exists:
        if args.lse:
            sys.exit(
                textwrap.dedent('''
                    PlotError: User requested to plot an image obtained by solving
                    the Lippmann-Schwinger equation (LSE), but no such image exists.
                    '''))
        return 'NFE'

    if lse_exists:
        if args.nfe:
            sys.exit(
                textwrap.dedent('''
                    PlotError: User requested to plot an image obtained by solving
                    the near-field equation (NFE), but no such image exists.
                    '''))
        return 'LSE'

    # Neither image exists. The combined case must be tested first,
    # otherwise the single-flag branches shadow it.
    if args.nfe and args.lse:
        print(
            textwrap.dedent('''
                Warning: An image has not yet been obtained by solving either the
                near-field equation (NFE) or the Lippmann-Schwinger equation (LSE).
                '''))
    elif args.nfe:
        print(
            'Warning: An image obtained by solving the near-field equation (NFE) does not exist.'
        )
    elif args.lse:
        print(
            'Warning: An image obtained by solving the Lippmann-Schwinger equation (LSE) does not exist.'
        )
    return ''


def _plot_nfe_image(plotParams, args):
    """Plot the image stored in ``imageNFE.npz``.

    Args:
        plotParams: dictionary of plot parameters passed on to ``plotImage``.
        args: namespace returned by ``_parse_image_args``; ``spacetime`` and
            ``movie`` are forwarded to ``plotImage``.

    Returns:
        tuple: ``(fig1, ax1, fig2, fig3, state)``. ``fig2`` and ``fig3`` are
        ``None`` when ``plotImage`` did not return them. ``state`` is the
        tuple ``(alpha, X, Y, Z, Ntau, tau)`` read from the image file, where
        ``Z`` and ``tau`` are ``None`` if absent from the file and ``Ntau``
        is ``None`` whenever ``tau`` is.
    """
    Dict = np.load('imageNFE.npz')
    X = Dict['X']
    Y = Dict['Y']
    Z = Dict['Z'] if 'Z' in Dict else None
    tau = Dict['tau'] if 'tau' in Dict else None
    # Always bind Ntau so the key-press callback never hits an unbound name.
    Ntau = len(tau) if tau is not None else None
    alpha = Dict['alpha']

    fig1, ax1, *otherImages = plotImage(Dict, plotParams, 'NFE',
                                        args.spacetime, args.movie)

    # plotImage may return one or two extra (figure, axes) pairs.
    fig2 = fig3 = None
    if len(otherImages) == 2:
        fig2 = otherImages[0]
    elif len(otherImages) == 4:
        fig2, fig3 = otherImages[0], otherImages[2]

    return fig1, ax1, fig2, fig3, (alpha, X, Y, Z, Ntau, tau)


def cli():
    """Command-line entry point that plots a saved NFE or LSE image.

    Steps, in order:

    1. Parse the command line.
    2. Load ``plotParams.pkl`` if it exists, otherwise start from
       ``default_params()``; apply every option given on the command line
       (including ``--mode`` in both cases) and write the result back to
       ``plotParams.pkl``.
    3. Load receiver/source/scatterer points and apply ``window.npz``.
    4. Choose between ``imageNFE.npz`` and ``imageLSE.npz`` and plot it; if
       neither exists, an empty figure is created instead.
    5. Overlay the map of points, save ``image<flag>.<format>``, and, for an
       NFE image with ``--spacetime``, save ``spacetimeNFE.<format>`` and
       connect the key-press handler.
    6. Show the figures.

    NOTE: ``plotParams.pkl`` is unpickled from the current working directory;
    only run this in directories whose contents you trust.
    """
    args = _parse_image_args()

    #==============================================================================
    # Load stored plot parameters (or defaults), apply overrides, and persist.
    if Path('plotParams.pkl').exists():
        with open('plotParams.pkl', 'rb') as pkl:
            plotParams = pickle.load(pkl)
    else:
        plotParams = default_params()

    _apply_image_overrides(plotParams, args)

    with open('plotParams.pkl', 'wb') as pkl:
        pickle.dump(plotParams, pkl, pickle.HIGHEST_PROTOCOL)

    #==============================================================================
    # load the shape of the scatterer and source/receiver locations
    receiverPoints, sourcePoints, scatterer = _load_map_geometry(plotParams)

    #==============================================================================
    flag = _select_image(args)
    fig2 = fig3 = None
    state = None

    if flag == 'NFE':
        # plot the image obtained by solving the near-field equation (NFE)
        fig1, ax1, fig2, fig3, state = _plot_nfe_image(plotParams, args)
    elif flag == 'LSE':
        # plot the image obtained by solving the Lippmann-Schwinger equation (LSE)
        fig1, ax1 = plotImage(np.load('imageLSE.npz'), plotParams, flag)
    else:
        # No image available: draw the map on an empty figure.
        fig1, ax1 = setFigure(num_axes=1,
                              mode=plotParams['view_mode'],
                              ax1_dim=receiverPoints.shape[1])

    plotMap(ax1, None, receiverPoints, sourcePoints, scatterer, 'data',
            plotParams)

    #==============================================================================

    pltformat = plotParams['pltformat']
    fig1.savefig('image' + flag + '.' + pltformat,
                 format=pltformat,
                 bbox_inches='tight',
                 facecolor=fig1.get_facecolor(),
                 transparent=True)

    if flag == 'NFE' and args.spacetime:
        remove_keymap_conflicts({'left', 'right', 'up', 'down', 'save'})
        alpha, X, Y, Z, Ntau, tau = state

        # Without a Z grid the space-time figure is the third one returned
        # by plotImage; otherwise it is the second.
        spacetime_fig = fig3 if Z is None else fig2

        if spacetime_fig is None:
            print(
                '\nSpace-time reconstruction is not available for a single sampling point in time.\n'
            )
        else:
            spacetime_fig.savefig('spacetime' + flag + '.' + pltformat,
                                  format=pltformat,
                                  bbox_inches='tight',
                                  facecolor=spacetime_fig.get_facecolor(),
                                  transparent=True)
            # fig3 is only ever set together with fig2, so fig2 exists here.
            fig2.canvas.mpl_connect(
                'key_press_event', lambda event: process_key_images(
                    event, plotParams, alpha, X, Y, Z, Ntau, tau))

    plt.show()
コード例 #6
0
def _wiggles_update_params(plotParams, args):
    """Overlay the parsed command-line options onto ``plotParams`` in place.

    Only options the user actually passed (i.e. that are not ``None``) are
    applied. The same rules are used whether ``plotParams`` was loaded from
    'plotParams.pkl' or freshly created from the defaults, so '--pclip' is
    honoured on the very first run as well.
    """
    if args.mode is not None:
        plotParams['view_mode'] = args.mode

    if args.tu is not None:
        plotParams['tu'] = args.tu

    if args.au is not None:
        plotParams['au'] = args.au

    if args.pclip is not None:
        if 0 <= args.pclip <= 1:
            plotParams['pclip'] = args.pclip
        else:
            print(
                textwrap.dedent('''
                      Warning: Invalid value passed to argument \'--pclip\'. Value must be
                      between 0 and 1.
                      '''))

    if args.title is not None:
        if args.data:
            plotParams['data_title'] = args.title
        elif args.testfunc:
            plotParams['tf_title'] = args.title


def _wiggles_load_recorded_data(datadir):
    """Return the recorded data array, offering the noisy version if present.

    If 'noisyData.npz' exists, the user is asked interactively whether to
    plot the noisy data (default), the noise-free data, or to quit.
    Otherwise the array referenced by ``datadir['recordedData']`` is loaded.
    """
    if not Path('noisyData.npz').exists():
        return np.load(str(datadir['recordedData']))

    print(
        textwrap.dedent('''
          Detected that band-limited noise has been added to the data array.
          Would you like to plot the noisy data? ([y]/n)

          Enter 'q/quit' to exit the program.
          '''))
    while True:
        answer = input('Action: ')
        if answer in ('', 'y', 'yes'):
            print('Proceeding with plot of noisy data...')
            return np.load('noisyData.npz')['noisyData']
        if answer in ('n', 'no'):
            print('Proceeding with plot of noise-free data...')
            return np.load(str(datadir['recordedData']))
        if answer in ('q', 'quit'):
            sys.exit('Exiting program.\n')
        print(
            'Invalid response. Please enter \'y/yes\', \'n/no\', or \'q/quit\'.'
        )


def _wiggles_convolution_times(recordingTimes, tinterval):
    """Return the symmetric time axis spanning [-T, T] for the time window.

    ``T`` is the length of the windowed recording interval, and the axis has
    ``2 * N - 1`` samples where ``N`` is the number of windowed time samples.
    """
    window = recordingTimes[tinterval]
    T = window[-1] - window[0]
    return np.linspace(-T, T, 2 * len(window) - 1)


def _wiggles_load_sampling_grid():
    """Load 'samplingGrid.npz' and return ``(x, y, z, tau)``.

    ``z`` is ``None`` when the file holds no 'z' entry.
    """
    samplingGrid = np.load('samplingGrid.npz')
    z = samplingGrid['z'] if 'z' in samplingGrid else None
    return samplingGrid['x'], samplingGrid['y'], z, samplingGrid['tau']


def _wiggles_compute_test_funcs(receiverPoints, time, grid, plotParams,
                                recompute):
    """Compute free-space test functions and cache them in 'VZTestFuncs.npz'.

    ``grid`` is the ``(x, y, z, tau)`` tuple of the sampling grid. When
    ``tau[0]`` is nonzero the time axis is shifted by that focusing time
    before sampling. ``recompute`` only selects the wording of the progress
    messages. Returns ``(X, sourcePoints)`` as produced by ``sampleSpace``.
    """
    x, y, z, tau = grid
    verb = 'Recomputing' if recompute else 'Computing'

    if tau[0] != 0:
        tu = plotParams['tu']
        if tu != '':
            print('%s test functions for focusing time %0.2f %s...' %
                  (verb, tau[0], tu))
        else:
            print('%s test functions for focusing time %0.2f...' %
                  (verb, tau[0]))
        sampleTimes = time - tau[0]
    else:
        if recompute:
            print('Recomputing test functions...')
        sampleTimes = time

    X, sourcePoints = sampleSpace(receiverPoints, sampleTimes,
                                  pulseFun.velocity, x, y, z, pulseFun.pulse)

    # The unshifted time axis is what gets stored alongside the arrays.
    contents = dict(TFarray=X,
                    time=time,
                    receivers=receiverPoints,
                    peakFreq=pulseFun.peakFreq,
                    peakTime=pulseFun.peakTime,
                    velocity=pulseFun.velocity,
                    x=x,
                    y=y,
                    tau=tau,
                    samplingPoints=sourcePoints)
    if z is not None:
        contents['z'] = z
    np.savez('VZTestFuncs.npz', **contents)

    return X, sourcePoints


def _wiggles_cached_test_funcs(receiverPoints, time, plotParams):
    """Return the test functions in 'VZTestFuncs.npz', recomputing if stale.

    The cached file is reused only when ``samplingIsCurrent`` reports that it
    matches the current receivers, time axis, pulse parameters and sampling
    grid; otherwise the test functions are recomputed and the file rewritten.
    """
    print(
        '\nDetected that free-space test functions have already been computed...'
    )
    print('Checking consistency with current space-time sampling grid...')
    TFDict = np.load('VZTestFuncs.npz')
    grid = _wiggles_load_sampling_grid()
    x, y, z, tau = grid

    if samplingIsCurrent(TFDict, receiverPoints, time, pulseFun.velocity, tau,
                         x, y, z, pulseFun.peakFreq, pulseFun.peakTime):
        print('Moving forward to plot test functions...')
        return TFDict['TFarray'], TFDict['samplingPoints']

    return _wiggles_compute_test_funcs(receiverPoints,
                                       time,
                                       grid,
                                       plotParams,
                                       recompute=True)


def _wiggles_load_test_funcs(datadir, recordingTimes, tinterval, rinterval,
                             receiverPoints, plotParams):
    """Return ``(X, time, sourcePoints)`` for the test functions to plot.

    Sources, in order of consideration:
      * user-provided test functions referenced by ``datadir['testFuncs']``,
      * test functions previously computed and cached in 'VZTestFuncs.npz',
      * freshly computed test functions (when neither of the above exists).
    When both of the first two are available the user is asked to choose.
    """
    userProvided = 'testFuncs' in datadir
    vezdaComputed = Path('VZTestFuncs.npz').exists()

    if userProvided and vezdaComputed:
        print(
            textwrap.dedent('''
                 Two files are available containing simulated test functions.

                 Enter '1' to view the user-provided test functions. (Default)
                 Enter '2' to view the test functions computed by Vezda.
                 Enter 'q/quit' to exit.
                 '''))
        while True:
            answer = input('Action: ')
            if answer in ('', '1'):
                vezdaComputed = False
                break
            if answer == '2':
                userProvided = False
                break
            if answer in ('q', 'quit'):
                sys.exit('Exiting program.')
            print(
                'Invalid response. Please enter \'1\', \'2\', or \'q/quit\'.')

    if userProvided:
        X = np.load(str(datadir['testFuncs']))
        time = np.load(str(datadir['convolutionTimes']))
        sourcePoints = np.load(str(datadir['samplingPoints']))
        return X[rinterval, :, :], time, sourcePoints

    # set up the convolution times based on the length of the recording time interval
    time = _wiggles_convolution_times(recordingTimes, tinterval)
    if vezdaComputed:
        X, sourcePoints = _wiggles_cached_test_funcs(receiverPoints, time,
                                                     plotParams)
    else:
        print(
            '\nComputing free-space test functions for the current space-time sampling grid...'
        )
        X, sourcePoints = _wiggles_compute_test_funcs(
            receiverPoints,
            time,
            _wiggles_load_sampling_grid(),
            plotParams,
            recompute=False)

    return X, time, sourcePoints


def cli():
    """Command-line entry point for interactive wiggle plots.

    Parses the command-line options, merges them into the persisted plot
    parameters ('plotParams.pkl'), loads either the recorded data (default,
    or '--data') or the simulated test functions ('--testfunc'), applies the
    receiver/time/source windows from 'window.npz' when that file exists,
    and opens a figure whose displayed source can be changed with the arrow
    keys. With '--map' a second axis shows the acquisition geometry.

    Exits via ``sys.exit`` if both '--data' and '--testfunc' are passed or
    if the user answers 'q'/'quit' at an interactive prompt.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data',
                        action='store_true',
                        help='Plot the recorded data. (Default)')
    parser.add_argument('--testfunc',
                        action='store_true',
                        help='Plot the simulated test functions.')
    parser.add_argument('--tu',
                        type=str,
                        help='Specify the time units (e.g., \'s\' or \'ms\').')
    parser.add_argument(
        '--au',
        type=str,
        help='Specify the amplitude units (e.g., \'m\' or \'mm\').')
    parser.add_argument(
        '--pclip',
        type=float,
        help=
        '''Specify the percentage (0-1) of the peak amplitude to display. This
                        parameter is used for pcolormesh plots only. Default is set to 1.'''
    )
    parser.add_argument(
        '--title',
        type=str,
        help='''Specify a title for the wiggle plot. Default title is
                        \'Data\' if \'--data\' is passed and 'Test Function' if \'--testfunc\'
                        is passed.''')
    # NOTE(review): '--format' is parsed but not referenced anywhere in this
    # function.
    parser.add_argument(
        '--format',
        '-f',
        type=str,
        default='pdf',
        choices=['png', 'pdf', 'ps', 'eps', 'svg'],
        help=
        '''Specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument(
        '--map',
        action='store_true',
        help=
        '''Plot a map of the receiver and source/sampling point locations. The current
                        source/sampling point will be highlighted. The boundary of the scatterer will also
                        be shown if available.''')
    parser.add_argument(
        '--mode',
        type=str,
        choices=['light', 'dark'],
        required=False,
        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')

    args = parser.parse_args()
    #==============================================================================
    # Load the persisted plot parameters if available (otherwise start from
    # the defaults), apply the passed arguments, and persist the result.
    # NOTE: pickle must only be used on trusted files; 'plotParams.pkl' is
    # expected to be the file written by this tool in the working directory.
    if Path('plotParams.pkl').exists():
        with open('plotParams.pkl', 'rb') as paramsFile:
            plotParams = pickle.load(paramsFile)
    else:
        plotParams = default_params()

    _wiggles_update_params(plotParams, args)

    with open('plotParams.pkl', 'wb') as paramsFile:
        pickle.dump(plotParams, paramsFile, pickle.HIGHEST_PROTOCOL)

    #==============================================================================
    # Load the relevant data to plot
    datadir = np.load('datadir.npz')
    receiverPoints = np.load(str(datadir['receivers']))
    recordingTimes = np.load(str(datadir['recordingTimes']))
    dt = recordingTimes[1] - recordingTimes[0]

    if 'scatterer' in datadir:
        scatterer = np.load(str(datadir['scatterer']))
    else:
        scatterer = None

    if Path('window.npz').exists():
        windowDict = np.load('window.npz')

        # Receiver window
        rstart = windowDict['rstart']
        rstop = windowDict['rstop']
        rstep = windowDict['rstep']

        # Time window (tstart/tstop are in units of time)
        tstart = windowDict['tstart']
        tstop = windowDict['tstop']
        tstep = windowDict['tstep']

        # Convert time window parameters to corresponding array indices
        Tstart = int(round(tstart / dt))
        Tstop = int(round(tstop / dt))

    else:
        windowDict = None

        rstart = 0
        rstop = receiverPoints.shape[0]
        rstep = 1

        Tstart = 0
        Tstop = len(recordingTimes)
        tstep = 1

    rinterval = np.arange(rstart, rstop, rstep)
    receiverPoints = receiverPoints[rinterval, :]

    tinterval = np.arange(Tstart, Tstop, tstep)

    if args.data and args.testfunc:
        # User specified both data and testfuncs for plotting
        # Send error message and exit.
        sys.exit(
            textwrap.dedent('''
                Error: Cannot plot both recorded data and simulated test functions. Use

                    vzwiggles --data

                to plot the recorded data or

                    vzwiggles --testfunc

                to plot the simulated test functions.
                '''))

    if args.testfunc:
        wiggleType = 'testfunc'
        X, time, sourcePoints = _wiggles_load_test_funcs(
            datadir, recordingTimes, tinterval, rinterval, receiverPoints,
            plotParams)
    else:
        # '--data' was passed, or nothing was specified: plot the recorded
        # data by default. X[receiver, time, source]
        wiggleType = 'data'
        X = _wiggles_load_recorded_data(datadir)
        time = recordingTimes
        if 'sources' in datadir:
            sourcePoints = np.load(str(datadir['sources']))
        else:
            sourcePoints = None
        X = X[rinterval, :, :]

    #==============================================================================
    if windowDict is not None and wiggleType == 'data':
        t0 = tstart
        tf = tstop

        # Apply the source window
        sstart = windowDict['sstart']
        sstop = windowDict['sstop']
        sstep = windowDict['sstep']

    else:
        t0 = time[0]
        tf = time[-1]

        sstart = 0
        sstop = X.shape[2]
        sstep = 1

    sinterval = np.arange(sstart, sstop, sstep)

    X = X[:, :, sinterval]
    if sourcePoints is not None:
        sourcePoints = sourcePoints[sinterval, :]

    # increment source/recording interval and receiver interval to be consistent
    # with one-based indexing (i.e., count from one instead of zero)
    sinterval += 1
    rinterval += 1
    rstart += 1

    Ns = X.shape[2]

    remove_keymap_conflicts({'left', 'right', 'up', 'down', 'save'})
    if args.map:
        fig, wiggleAx, mapAx = setFigure(num_axes=2,
                                         mode=plotParams['view_mode'],
                                         ax2_dim=receiverPoints.shape[1])
    else:
        fig, wiggleAx = setFigure(num_axes=1, mode=plotParams['view_mode'])
        mapAx = None

    # Start on the middle source; the key handler steps through the others.
    wiggleAx.volume = X
    wiggleAx.index = Ns // 2
    title = wave_title(wiggleAx.index, sinterval, sourcePoints, wiggleType,
                       plotParams)
    plotWiggles(wiggleAx, X[:, :, wiggleAx.index], time, t0, tf, rstart,
                rinterval, receiverPoints, title, wiggleType, plotParams)

    if mapAx is not None:
        mapAx.index = wiggleAx.index
        plotMap(mapAx, mapAx.index, receiverPoints, sourcePoints, scatterer,
                wiggleType, plotParams)

    plt.tight_layout()
    fig.canvas.mpl_connect(
        'key_press_event', lambda event: process_key_waves(
            event, time, t0, tf, rstart, rinterval, sinterval,
            receiverPoints, sourcePoints, Ns, scatterer, args.map,
            wiggleType, plotParams))

    plt.show()
コード例 #7
0
ファイル: data_utils.py プロジェクト: aaronprunty/vezda
from pathlib import Path
import textwrap
from vezda.math_utils import nextPow2
from vezda.signal_utils import tukey_taper
from vezda.sampling_utils import samplingIsCurrent, compute_impulse_responses
from vezda.plot_utils import default_params
sys.path.append(os.getcwd())
import pulseFun

# File-path lookup table for the current experiment (receivers, recording
# times, ...), read once at import time from the working directory.
datadir = np.load('datadir.npz')

# Used for getting time and frequency units
if Path('plotParams.pkl').exists():
    # Close the file handle deterministically instead of leaving it to the
    # garbage collector. NOTE: only unpickle files this tool wrote itself.
    with open('plotParams.pkl', 'rb') as _params_file:
        plotParams = pickle.load(_params_file)
else:
    plotParams = default_params()


def load_data(domain, taper=False, verbose=False, skip_fft=False):
    # load the recorded data
    print('Loading recorded waveforms...')
    if Path('noisyData.npz').exists():
        userResponded = False
        print(
            textwrap.dedent('''
              Detected that band-limited noise has been added to the data array.
              Would you like to use the noisy data? ([y]/n)
              
              Enter 'q/quit' exit the program.
              '''))
        while userResponded == False:
コード例 #8
0
def cli():
    """Command-line entry point that draws and saves a Picard plot.

    Loads the singular values ('singularValues.npy') and left singular
    vectors ('leftVectors.npy') together with the test functions stored in
    'VZTestFuncs.npz', computes the coefficients ``c = |U.T @ b|`` and
    ``d = c / s`` for the first test function ``b``, and passes them to
    ``PicardPlot``. The figure is saved as 'Picard.<format>' and shown.

    Exits via ``sys.exit`` with an explanatory message when the singular
    value decomposition or the test functions are not available on disk.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--format',
        '-f',
        type=str,
        default='pdf',
        choices=['png', 'pdf', 'ps', 'eps', 'svg'],
        help=
        '''specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument(
        '--mode',
        type=str,
        choices=['light', 'dark'],
        required=False,
        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')
    args = parser.parse_args()

    #==============================================================================
    # Both SVD files are required; bail out with guidance if either is missing.
    try:
        s = np.load('singularValues.npy')
        U = np.load('leftVectors.npy')
    except FileNotFoundError:
        sys.exit(
            textwrap.dedent('''
                A singular-value decomposition needs to be computed before any
                spectral analysis can be performed. Enter:

                    vzsvd --help

                from the command line for more information on how to compute
                a singular-value decomposition.
                '''))

    try:
        TFDict = np.load('VZTestFuncs.npz')
    except FileNotFoundError:
        sys.exit(
            textwrap.dedent('''
                Error: The file 'VZTestFuncs.npz' was not found. Test functions
                need to be computed before a Picard plot can be generated.
                '''))

    k = s.size
    s = np.reshape(s, (k, 1))
    # Compute coefficients for Picard plot
    TFarray = TFDict['TFarray']
    # Use the first test function; the indexing requires TFarray to be 4-D.
    TF = TFarray[:, :, 0, 0]
    Nr, Nt = TF.shape
    b = np.reshape(TF, (Nt * Nr, 1))
    # NOTE(review): a plain transpose is used here; if U can be complex a
    # conjugate transpose may be intended -- confirm against the SVD writer.
    c = np.abs(U.T @ b)
    d = np.divide(c, s)

    #==============================================================================
    remove_keymap_conflicts({'left', 'right', 'up', 'down', 'save'})

    # NOTE: only unpickle files this tool wrote itself.
    if Path('plotParams.pkl').exists():
        with open('plotParams.pkl', 'rb') as paramsFile:
            plotParams = pickle.load(paramsFile)
    else:
        plotParams = default_params()

    # Persist the viewing mode only when the user explicitly changed it.
    if args.mode is not None:
        plotParams['view_mode'] = args.mode
        with open('plotParams.pkl', 'wb') as paramsFile:
            pickle.dump(plotParams, paramsFile, pickle.HIGHEST_PROTOCOL)

    fig, ax = setFigure(num_axes=1, mode=plotParams['view_mode'])

    ax.index = k // 2
    PicardPlot(ax, s, c, d)
    plt.tight_layout()
    fig.savefig('Picard.' + args.format,
                format=args.format,
                bbox_inches='tight')
    plt.show()
コード例 #9
0
def cli():
    """Command-line entry point that plots a stored singular-value decomposition.

    Parses ``--nfo`` / ``--lso`` (exactly one must be given), ``--format`` and
    ``--mode``, loads the matching SVD file (``NFO_SVD.npz`` or
    ``LSO_SVD.npz``) through ``load_svd`` and then

    * plots the first left- and right-singular vectors and connects a
      key-press handler to each figure (separate real/imaginary figures when
      ``U`` is complex, a single wiggle-plot figure otherwise), and
    * plots the singular values and saves that figure to
      ``singularValues.<format>``.

    The (possibly updated) plot parameters are written back to
    ``plotParams.pkl``.

    The function terminates through ``sys.exit`` with an explanatory message
    when the operator selection is missing or ambiguous, when ``load_svd``
    raises ``IOError``, or when no sampling points can be found for the
    Lippmann-Schwinger operator.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--nfo', action='store_true',
                        help='''Plot the singular-value decomposition of the
                        near-field operator (NFO).''')
    parser.add_argument('--lso', action='store_true',
                        help='''Plot the singular-value decomposition of the
                        Lippmann-Schwinger operator (LSO).''')
    parser.add_argument('--format', '-f', type=str, default='pdf', choices=['png', 'pdf', 'ps', 'eps', 'svg'],
                        help='''Specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument('--mode', type=str, choices=['light', 'dark'], required=False,
                        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')
    args = parser.parse_args()

    #==============================================================================
    # Select the operator: exactly one of --nfo / --lso must be given
    #==============================================================================
    if args.nfo and args.lso:
        sys.exit(textwrap.dedent(
                '''
                UsageError: Please specify only one of the arguments \'--nfo\' or \'--lso\'.
                '''))

    if not (args.nfo or args.lso):
        sys.exit(textwrap.dedent(
                '''
                For which operator would you like to plot a singular-value decomposition?
                Enter:

                    vzsvd --nfo

                for the near-field operator or

                    vzsvd --lso

                for the Lippmann-Schwinger operator.
                '''))

    if args.nfo:
        operatorName = 'near-field operator'
        filename = 'NFO_SVD.npz'
    else:
        operatorName = 'Lippmann-Schwinger operator'
        filename = 'LSO_SVD.npz'

    # See if an SVD already exists. If so, attempt to load it...
    try:
        U, s, Vh = load_svd(filename)
    except IOError:
        sys.exit(textwrap.dedent(
                '''
                A singular-value decomposition of the {s} does not exist.
                '''.format(s=operatorName)))

    #==============================================================================
    # Read in data files
    #==============================================================================
    datadir = np.load('datadir.npz')
    receiverPoints = np.load(str(datadir['receivers']))
    recordingTimes = np.load(str(datadir['recordingTimes']))

    # Apply user-specified windows
    rinterval, tinterval, tstep, dt, sinterval = get_user_windows()
    receiverPoints = receiverPoints[rinterval, :]
    recordingTimes = recordingTimes[tinterval]

    # Load appropriate source points and source window
    if args.nfo:
        # Near-field operator: source locations are optional
        if 'sources' in datadir:
            sourcePoints = np.load(str(datadir['sources']))
            sourcePoints = sourcePoints[sinterval, :]
        else:
            sourcePoints = None

    else:
        # Lippmann-Schwinger operator: 'sourcePoints' correspond to sampling
        # points, which should always exist.
        if 'testFuncs' in datadir:
            sourcePoints = np.load(str(datadir['samplingPoints']))

        elif Path('VZTestFuncs.npz').exists():
            TFDict = np.load('VZTestFuncs.npz')
            sourcePoints = TFDict['samplingPoints']

        else:
            sys.exit(textwrap.dedent(
                    '''
                    Error: A sampling grid must exist and test functions computed
                    before a singular-value decomposition of the Lippmann-Schwinger
                    operator can be computed or plotted.
                    '''))

        # update sinterval for test functions
        sinterval = np.arange(0, sourcePoints.shape[0], 1)

    # increment receiver/source intervals to be consistent with
    # one-based indexing (i.e., count from one instead of zero)
    rinterval += 1
    sinterval += 1

    #==============================================================================
    # Determine whether to plot SVD in time domain or frequency domain
    #==============================================================================
    # A complex-valued U is taken to mean the SVD was computed in the
    # frequency domain.
    if np.issubdtype(U.dtype, np.complexfloating):
        domain = 'freq'
    else:
        domain = 'time'

    # Load plot parameters
    if Path('plotParams.pkl').exists():
        # NOTE: pickle must only be used on files written by this tool itself.
        with open('plotParams.pkl', 'rb') as paramsFile:
            plotParams = pickle.load(paramsFile)
    else:
        plotParams = default_params()

    Nr = receiverPoints.shape[0]
    Nt = len(recordingTimes)
    k = len(s)

    if domain == 'freq':
        # plot singular vectors in frequency domain
        N = nextPow2(2 * Nt)
        freqs = np.fft.rfftfreq(N, tstep * dt)

        if plotParams['fmax'] is None:
            plotParams['fmax'] = np.max(freqs)

        # Apply the frequency window
        fmin = plotParams['fmin']
        fmax = plotParams['fmax']
        df = 1.0 / (N * tstep * dt)

        startIndex = int(round(fmin / df))
        stopIndex = int(round(fmax / df))
        finterval = np.arange(startIndex, stopIndex, 1)
        freqs = freqs[finterval]

        M = len(freqs)
        Ns = Vh.shape[1] // M
        # U and Vh expose toarray()/getH() here, i.e. they are not plain
        # ndarrays in this branch.
        U = U.toarray().reshape((Nr, M, k))
        V = Vh.getH().toarray().reshape((Ns, M, k))

    else:
        # domain == 'time'
        M = 2 * Nt - 1
        Ns = Vh.shape[1] // M
        U = U.reshape((Nr, M, k))
        V = Vh.T.reshape((Ns, M, k))
        T = recordingTimes[-1] - recordingTimes[0]
        times = np.linspace(-T, T, M)

    if args.mode is not None:
        plotParams['view_mode'] = args.mode

    with open('plotParams.pkl', 'wb') as paramsFile:
        pickle.dump(plotParams, paramsFile, pickle.HIGHEST_PROTOCOL)

    remove_keymap_conflicts({'left', 'right', 'up', 'down', 'save'})
    if domain == 'freq':

        # plot the left singular vectors
        fig_lvec, ax_lvec_r, ax_lvec_i = setFigure(num_axes=2, mode=plotParams['view_mode'])
        ax_lvec_r.volume = U.real
        ax_lvec_i.volume = U.imag
        ax_lvec_r.index = 0
        ax_lvec_i.index = 0
        fig_lvec.suptitle('Left-Singular Vector', color=ax_lvec_r.titlecolor, fontsize=16)
        fig_lvec.subplots_adjust(bottom=0.27, top=0.86)
        leftTitle_r = vector_title('left', ax_lvec_r.index + 1, 'real')
        leftTitle_i = vector_title('left', ax_lvec_i.index + 1, 'imag')
        for ax, title in zip([ax_lvec_r, ax_lvec_i], [leftTitle_r, leftTitle_i]):
            left_im = plotFreqVectors(ax, ax.volume[:, :, ax.index], freqs, rinterval,
                                      receiverPoints, title, 'left', plotParams)

        # One horizontal colorbar spanning both the real and imaginary axes
        lp0 = ax_lvec_r.get_position().get_points().flatten()
        lp1 = ax_lvec_i.get_position().get_points().flatten()
        left_cax = fig_lvec.add_axes([lp0[0], 0.12, lp1[2] - lp0[0], 0.03])
        lcbar = fig_lvec.colorbar(left_im, left_cax, orientation='horizontal')
        lcbar.outline.set_edgecolor(ax_lvec_r.cbaredgecolor)
        lcbar.ax.tick_params(axis='x', colors=ax_lvec_r.labelcolor)
        # NOTE(review): the colorbar is horizontal, so a y-axis formatter
        # presumably has no visible effect -- confirm the intended axis.
        lcbar.ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        lcbar.set_label('Amplitude',
                        labelpad=5, rotation=0, fontsize=12, color=ax_lvec_r.labelcolor)
        fig_lvec.canvas.mpl_connect(
            'key_press_event',
            lambda event: process_key_vectors(event, freqs, rinterval, sinterval,
                                              receiverPoints, sourcePoints, plotParams,
                                              'cmplx_left'))

        # plot the right singular vectors
        fig_rvec, ax_rvec_r, ax_rvec_i = setFigure(num_axes=2, mode=plotParams['view_mode'])
        ax_rvec_r.volume = V.real
        ax_rvec_i.volume = V.imag
        ax_rvec_r.index = 0
        ax_rvec_i.index = 0
        fig_rvec.suptitle('Right-Singular Vector', color=ax_rvec_r.titlecolor, fontsize=16)
        fig_rvec.subplots_adjust(bottom=0.27, top=0.86)
        rightTitle_r = vector_title('right', ax_rvec_r.index + 1, 'real')
        rightTitle_i = vector_title('right', ax_rvec_i.index + 1, 'imag')
        for ax, title in zip([ax_rvec_r, ax_rvec_i], [rightTitle_r, rightTitle_i]):
            right_im = plotFreqVectors(ax, ax.volume[:, :, ax.index], freqs, sinterval,
                                       sourcePoints, title, 'right', plotParams)

        rp0 = ax_rvec_r.get_position().get_points().flatten()
        rp1 = ax_rvec_i.get_position().get_points().flatten()
        right_cax = fig_rvec.add_axes([rp0[0], 0.12, rp1[2] - rp0[0], 0.03])
        rcbar = fig_rvec.colorbar(right_im, right_cax, orientation='horizontal')
        rcbar.outline.set_edgecolor(ax_rvec_r.cbaredgecolor)
        rcbar.ax.tick_params(axis='x', colors=ax_rvec_r.labelcolor)
        # NOTE(review): same remark as for the left colorbar formatter.
        rcbar.ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        # Style the label from the right-vector axes (was ax_lvec_r).
        rcbar.set_label('Amplitude',
                        labelpad=5, rotation=0, fontsize=12, color=ax_rvec_r.labelcolor)
        fig_rvec.canvas.mpl_connect(
            'key_press_event',
            lambda event: process_key_vectors(event, freqs, rinterval, sinterval,
                                              receiverPoints, sourcePoints, plotParams,
                                              'cmplx_right'))

    else:
        # domain == 'time'
        fig_vec, ax_lvec, ax_rvec = setFigure(num_axes=2, mode=plotParams['view_mode'])

        ax_lvec.volume = U
        ax_lvec.index = 0
        leftTitle = vector_title('left', ax_lvec.index + 1)
        plotWiggles(ax_lvec, ax_lvec.volume[:, :, ax_lvec.index], times, rinterval,
                    receiverPoints, leftTitle, 'left', plotParams)

        ax_rvec.volume = V
        ax_rvec.index = 0
        rightTitle = vector_title('right', ax_rvec.index + 1)
        plotWiggles(ax_rvec, ax_rvec.volume[:, :, ax_rvec.index], times, sinterval,
                    sourcePoints, rightTitle, 'right', plotParams)
        fig_vec.tight_layout()
        fig_vec.canvas.mpl_connect(
            'key_press_event',
            lambda event: process_key_vectors(event, times, rinterval, sinterval,
                                              receiverPoints, sourcePoints, plotParams))

    #==============================================================================
    # plot the singular values
    # figure and axis for singular values
    fig_vals, ax_vals = setFigure(num_axes=1, mode=plotParams['view_mode'])

    n = np.arange(1, k + 1, 1)
    # condition number = first / last singular value (max / min when s is
    # sorted in descending order)
    kappa = s[0] / s[-1]
    ax_vals.plot(n, s, '.', clip_on=False, markersize=9,
                 label=r'Condition Number: %0.1e' % (kappa), color=ax_vals.pointcolor)
    ax_vals.set_xlabel('n', color=ax_vals.labelcolor)
    # Raw string: '\s' is not a valid escape sequence in a normal literal.
    ax_vals.set_ylabel(r'$\sigma_n$', color=ax_vals.labelcolor)
    legend = ax_vals.legend(title='Singular Values', loc='upper center', bbox_to_anchor=(0.5, 1.25),
                            markerscale=0, handlelength=0, handletextpad=0, fancybox=True, shadow=True,
                            fontsize='large')
    legend.get_title().set_fontsize('large')
    ax_vals.set_xlim([1, k])
    ax_vals.set_ylim(bottom=0)
    # 'nbins' is the keyword MaxNLocator understands; 'nticks' is not accepted.
    ax_vals.locator_params(axis='y', nbins=6)
    ax_vals.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
    fig_vals.tight_layout()
    fig_vals.savefig('singularValues.' + args.format, format=args.format,
                     bbox_inches='tight', facecolor=fig_vals.get_facecolor())

    plt.show()
コード例 #10
0
ファイル: plotSpectra.py プロジェクト: wjchu1995/vezda
def cli():
    """Command-line entry point that plots a mean amplitude or power spectrum.

    Parses ``--data`` / ``--testfunc`` (mutually exclusive; the recorded data
    are used when neither is given), ``--power``, ``--fmin``, ``--fmax``,
    ``--fu``, ``--format`` and ``--mode``.  The selected signals are loaded
    with ``load_data`` or ``load_test_funcs``, their spectrum is obtained from
    ``compute_spectrum(X, tstep * dt, args.power)``, the plot parameters are
    updated and written back to ``plotParams.pkl``, and the figure is saved
    as ``powerSpectrum.<format>`` or ``amplitudeSpectrum.<format>`` before
    being shown.

    The function terminates through ``sys.exit`` when both ``--data`` and
    ``--testfunc`` are given, when the user quits the interactive prompt, or
    when the requested frequency limits are negative or inconsistent.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', action='store_true',
                        help='Plot the frequency spectrum of the recorded data. (Default)')
    parser.add_argument('--testfunc', action='store_true',
                        help='Plot the frequency spectrum of the simulated test functions.')
    parser.add_argument('--power', action='store_true',
                        help='''Plot the mean power spectrum of the input signals. Default is to plot the
                        mean amplitude spectrum of the Fourier transform.''')
    parser.add_argument('--fmin', type=float,
                        help='Specify the minimum frequency of the amplitude/power spectrum plot. Default is set to 0.')
    parser.add_argument('--fmax', type=float,
                        help='''Specify the maximum frequency of the amplitude/power spectrum plot. Default is set to the
                        maximum frequency bin based on the length of the time signal.''')
    parser.add_argument('--fu', type=str,
                        help='Specify the frequency units (e.g., Hz)')
    parser.add_argument('--format', '-f', type=str, default='pdf', choices=['png', 'pdf', 'ps', 'eps', 'svg'],
                        help='''specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument('--mode', type=str, choices=['light', 'dark'], required=False,
                        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')
    args = parser.parse_args()

    #==============================================================================
    # Get time window parameters. Only the step and the sampling interval are
    # needed here: they give the effective sampling interval tstep * dt that
    # is passed to compute_spectrum below.
    tstep, dt = get_user_windows()[2:4]
    datadir = np.load('datadir.npz')

    # Used for getting time and frequency units
    if Path('plotParams.pkl').exists():
        # NOTE: pickle must only be used on files written by this tool itself.
        with open('plotParams.pkl', 'rb') as paramsFile:
            plotParams = pickle.load(paramsFile)
    else:
        plotParams = default_params()

    #==============================================================================
    # Select and load the signals
    if args.data and args.testfunc:
        # The usage hint names the option exactly as declared ('--testfunc').
        sys.exit(textwrap.dedent(
                '''
                Error: Cannot plot frequency spectrum of both recorded data and
                simulated test functions. Use

                    vzspectra --data

                to plot the frequency spectrum of the recorded data or

                    vzspectra --testfunc

                to plot the frequency spectrum of the simulated test functions.
                '''))

    if not args.testfunc:
        # default is to plot spectra of data if user does not specify either
        # args.data or args.testfunc
        X = load_data(domain='time', verbose=True)

    else:
        userFuncsExist = 'testFuncs' in datadir
        vezdaFuncsExist = Path('VZTestFuncs.npz').exists()

        if userFuncsExist and vezdaFuncsExist:
            # Both sources are available: let the user choose.
            print(textwrap.dedent(
                 '''
                 Two files are available containing simulated test functions.

                 Enter '1' to view the user-provided test functions. (Default)
                 Enter '2' to view the test functions computed by Vezda.
                 Enter 'q/quit' to exit.
                 '''))
            while True:
                answer = input('Action: ')

                if answer == '' or answer == '1':
                    medium = 'variable'
                    break

                elif answer == '2':
                    medium = 'constant'
                    break

                elif answer == 'q' or answer == 'quit':
                    sys.exit('Exiting program.')

                else:
                    print('Invalid response. Please enter \'1\', \'2\', or \'q/quit\'.')

        elif userFuncsExist:
            medium = 'variable'

        else:
            # Covers both "only VZTestFuncs.npz exists" and "neither exists";
            # the original code used medium='constant' in both situations.
            medium = 'constant'

        X = load_test_funcs(domain='time', medium=medium, verbose=True)

    #==============================================================================
    # compute spectra
    freqs, amplitudes = compute_spectrum(X, tstep * dt, args.power)

    if args.power:
        plotLabel = 'power'
        plotParams['freq_title'] = 'Mean Power Spectrum'
        plotParams['freq_ylabel'] = 'Power'
    else:
        plotLabel = 'amplitude'
        plotParams['freq_title'] = 'Mean Amplitude Spectrum'
        plotParams['freq_ylabel'] = 'Amplitude'

    if args.testfunc:
        plotParams['freq_title'] += ' [' + plotParams['tf_title'] + 's]'
    else:
        plotParams['freq_title'] += ' [' + plotParams['data_title'] + ']'

    #==============================================================================
    # Validate and store the frequency limits
    #
    # Resolve the default upper limit first, so that a stored fmax of None can
    # never be compared against a user-supplied --fmin.
    if args.fmax is None and plotParams['fmax'] is None:
        plotParams['fmax'] = np.max(freqs)

    if args.fmin is not None and not args.fmin >= 0:
        sys.exit(textwrap.dedent(
                '''
                ValueError: The specified minimum frequency of the %s spectrum 
                plot must be nonnegative.
                ''' % (plotLabel)))

    if args.fmin is not None and args.fmax is not None:
        if args.fmax > args.fmin:
            plotParams['fmin'] = args.fmin
            plotParams['fmax'] = args.fmax
        else:
            sys.exit(textwrap.dedent(
                    '''
                    RelationError: The maximum frequency of the %s spectrum plot must
                    be greater than the mininum frequency.
                    ''' % (plotLabel)))

    elif args.fmin is not None:
        # Only --fmin given: compare against the stored upper limit.
        if plotParams['fmax'] > args.fmin:
            plotParams['fmin'] = args.fmin
        else:
            sys.exit(textwrap.dedent(
                    '''
                    RelationError: The specified minimum frequency of the %s spectrum 
                    plot must be less than the maximum frequency.
                    ''' % (plotLabel)))

    elif args.fmax is not None:
        # Only --fmax given: compare against the stored lower limit.
        if args.fmax > plotParams['fmin']:
            plotParams['fmax'] = args.fmax
        else:
            sys.exit(textwrap.dedent(
                    '''
                    RelationError: The specified maximum frequency of the %s spectrum 
                    plot must be greater than the minimum frequency.
                    ''' % (plotLabel)))

    #===================================================================================
    if args.fu is not None:
        plotParams['fu'] = args.fu

    if args.mode is not None:
        plotParams['view_mode'] = args.mode

    with open('plotParams.pkl', 'wb') as paramsFile:
        pickle.dump(plotParams, paramsFile, pickle.HIGHEST_PROTOCOL)

    #===================================================================================
    # Plot the spectrum
    fig, ax = setFigure(num_axes=1, mode=plotParams['view_mode'])
    ax.plot(freqs, amplitudes, color=ax.linecolor, linewidth=ax.linewidth)
    ax.set_title(plotParams['freq_title'], color=ax.titlecolor)

    # get frequency units from plotParams
    fu = plotParams['fu']
    fmin = plotParams['fmin']
    fmax = plotParams['fmax']
    if fu != '':
        ax.set_xlabel('Frequency (%s)' % (fu), color=ax.labelcolor)
    else:
        ax.set_xlabel('Frequency', color=ax.labelcolor)
    ax.set_ylabel(plotParams['freq_ylabel'], color=ax.labelcolor)
    ax.set_xlim([fmin, fmax])
    ax.set_ylim(bottom=0)
    ax.fill_between(freqs, 0, amplitudes, where=(amplitudes > 0), color='m', alpha=ax.alpha)
    # 'nbins' is the keyword MaxNLocator understands; 'nticks' is not accepted.
    ax.locator_params(axis='y', nbins=6)
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
    plt.tight_layout()
    fig.savefig(plotLabel + 'Spectrum.' + args.format, format=args.format,
                bbox_inches='tight', facecolor=fig.get_facecolor())
    plt.show()
    
コード例 #11
0
def cli():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data',
        action='store_true',
        help='Plot the frequency spectrum of the recorded data. (Default)')
    parser.add_argument(
        '--testfunc',
        action='store_true',
        help='Plot the frequency spectrum of the simulated test functions.')
    parser.add_argument(
        '--power',
        action='store_true',
        help=
        '''Plot the mean power spectrum of the input signals. Default is to plot the
                        mean amplitude spectrum of the Fourier transform.''')
    parser.add_argument(
        '--fmin',
        type=float,
        help=
        'Specify the minimum frequency of the amplitude/power spectrum plot. Default is set to 0.'
    )
    parser.add_argument(
        '--fmax',
        type=float,
        help=
        '''Specify the maximum frequency of the amplitude/power spectrum plot. Default is set to the
                        maximum frequency bin based on the length of the time signal.'''
    )
    parser.add_argument('--fu',
                        type=str,
                        help='Specify the frequency units (e.g., Hz)')
    parser.add_argument(
        '--format',
        '-f',
        type=str,
        default='pdf',
        choices=['png', 'pdf', 'ps', 'eps', 'svg'],
        help=
        '''specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument(
        '--mode',
        type=str,
        choices=['light', 'dark'],
        required=False,
        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')
    args = parser.parse_args()

    #==============================================================================
    # Load the recording times from the data directory
    datadir = np.load('datadir.npz')
    receiverPoints = np.load(str(datadir['receivers']))
    recordingTimes = np.load(str(datadir['recordingTimes']))
    dt = recordingTimes[1] - recordingTimes[0]

    if Path('window.npz').exists():
        windowDict = np.load('window.npz')

        # Apply the receiver window
        rstart = windowDict['rstart']
        rstop = windowDict['rstop']
        rstep = windowDict['rstep']

        # Apply the time window
        tstart = windowDict['tstart']
        tstop = windowDict['tstop']
        tstep = windowDict['tstep']

        # Convert time window parameters to corresponding array indices
        Tstart = int(round(tstart / dt))
        Tstop = int(round(tstop / dt))

    else:
        rstart = 0
        rstop = receiverPoints.shape[0]
        rstep = 1

        tstart = recordingTimes[0]
        tstop = recordingTimes[-1]

        Tstart = 0
        Tstop = len(recordingTimes)
        tstep = 1

    # Apply the receiver window
    rinterval = np.arange(rstart, rstop, rstep)
    receiverPoints = receiverPoints[rinterval, :]

    # Apply the time window
    tinterval = np.arange(Tstart, Tstop, tstep)
    recordingTimes = recordingTimes[tinterval]

    # Used for getting time and frequency units
    if Path('plotParams.pkl').exists():
        plotParams = pickle.load(open('plotParams.pkl', 'rb'))
    else:
        plotParams = default_params()

    if all(v is True for v in [args.data, args.testfunc]):
        sys.exit(
            textwrap.dedent('''
                Error: Cannot plot frequency spectrum of both recorded data and
                simulated test functions. Use
                
                    vzspectra --data
                    
                to plot the frequency spectrum of the recorded data or
                
                    vzspectra --testfuncs
                    
                to plot the frequency spectrum of the simulated test functions.
                '''))

    elif (args.data and not args.testfunc) or all(
            v is not True for v in [args.data, args.testfunc]):
        # default is to plot spectra of data if user does not specify either args.data or args.testfunc
        if Path('noisyData.npz').exists():
            userResponded = False
            print(
                textwrap.dedent('''
                  Detected that band-limited noise has been added to the data array.
                  Would you like to plot the amplitude/power spectrum of the noisy data? ([y]/n)
              
                  Enter 'q/quit' exit the program.
                  '''))
            while userResponded == False:
                answer = input('Action: ')
                if answer == '' or answer == 'y' or answer == 'yes':
                    print('Proceeding with noisy data...')
                    # read in the noisy data array
                    noisy = True
                    X = np.load('noisyData.npz')['noisyData']
                    userResponded = True
                elif answer == 'n' or answer == 'no':
                    print('Proceeding with noise-free data...')
                    # read in the recorded data array
                    noisy = False
                    X = np.load(str(datadir['recordedData']))
                    userResponded = True
                elif answer == 'q' or answer == 'quit':
                    sys.exit('Exiting program.\n')
                else:
                    print(
                        'Invalid response. Please enter \'y/yes\', \'n\no\', or \'q/quit\'.'
                    )

        else:
            # read in the recorded data array
            noisy = False
            X = np.load(str(datadir['recordedData']))

        # Load the windowing parameters for the receiver and time axes of
        # the 3D data array
        if Path('window.npz').exists():
            print('Detected user-specified window:\n')

            # For display/printing purposes, count receivers with one-based
            # indexing. This amounts to incrementing the rstart parameter by 1
            print('window @ receivers : start =', rstart + 1)
            print('window @ receivers : stop =', rstop)
            print('window @ receivers : step =', rstep, '\n')

            tu = plotParams['tu']
            if tu != '':
                print('window @ time : start = %0.2f %s' % (tstart, tu))
                print('window @ time : stop = %0.2f %s' % (tstop, tu))
            else:
                print('window @ time : start =', tstart)
                print('window @ time : stop =', tstop)
            print('window @ time : step =', tstep, '\n')

            # Apply the source window
            slabel = windowDict['slabel']
            sstart = windowDict['sstart']
            sstop = windowDict['sstop']
            sstep = windowDict['sstep']
            sinterval = np.arange(sstart, sstop, sstep)

            # For display/printing purposes, count recordings/sources with one-based
            # indexing. This amounts to incrementing the sstart parameter by 1
            print('window @ %s : start = %s' % (slabel, sstart + 1))
            print('window @ %s : stop = %s' % (slabel, sstop))
            print('window @ %s : step = %s\n' % (slabel, sstep))

            print('Applying window to data volume...')
            X = X[rinterval, :, :]
            X = X[:, tinterval, :]
            X = X[:, :, sinterval]

            # Apply tapered cosine (Tukey) window to time signals.
            # This ensures the fast fourier transform (FFT) used in
            # the definition of the matrix-vector product below is
            # acting on a function that is continuous at its edges.

            Nt = X.shape[1]
            peakFreq = pulseFun.peakFreq
            # Np : Number of samples in the dominant period T = 1 / peakFreq
            Np = int(round(1 / (tstep * dt * peakFreq)))
            # alpha is set to taper over 6 of the dominant period of the
            # pulse function (3 periods from each end of the signal)
            alpha = 6 * Np / Nt
            print('Tapering time signals with Tukey window: %d' %
                  (int(round(alpha * 100))) + '%')
            TukeyWindow = tukey(Nt, alpha)
            X *= TukeyWindow[None, :, None]

    elif not args.data and args.testfunc:
        if Path('samplingGrid.npz').exists():
            samplingGrid = np.load('samplingGrid.npz')
            x = samplingGrid['x']
            y = samplingGrid['y']
            tau = samplingGrid['tau']
            if 'z' in samplingGrid:
                z = samplingGrid['z']
            else:
                z = None

        else:
            sys.exit(
                textwrap.dedent('''
                    A sampling grid needs to be set up and test functions
                    computed before their Fourier spectrum can be plotted.
                    Enter:
                        
                        vzgrid --help
                        
                    from the command-line for more information on how to set up a
                    sampling grid.
                    '''))

        pulse = lambda t: pulseFun.pulse(t)
        velocity = pulseFun.velocity
        peakFreq = pulseFun.peakFreq
        peakTime = pulseFun.peakTime

        if Path('VZTestFuncs.npz').exists():
            print(
                '\nDetected that free-space test functions have already been computed...'
            )
            print(
                'Checking consistency with current space-time sampling grid...'
            )
            TFDict = np.load('VZTestFuncs.npz')

            if samplingIsCurrent(TFDict, receiverPoints, recordingTimes,
                                 velocity, tau, x, y, z, peakFreq, peakTime):
                X = TFDict['TFarray']
                sourcePoints = TFDict['samplingPoints']

            else:
                print('Recomputing test functions...')
                if tau[0] != 0:
                    tu = plotParams['tu']
                    if tu != '':
                        print(
                            'Recomputing test functions for focusing time %0.2f %s...'
                            % (tau[0], tu))
                    else:
                        print(
                            'Recomputing test functions for focusing time %0.2f...'
                            % (tau[0]))
                    X, sourcePoints = sampleSpace(receiverPoints,
                                                  recordingTimes - tau[0],
                                                  velocity, x, y, z, pulse)
                else:
                    X, sourcePoints = sampleSpace(receiverPoints,
                                                  recordingTimes, velocity, x,
                                                  y, z, pulse)

                if z is None:
                    np.savez('VZTestFuncs.npz',
                             TFarray=X,
                             time=recordingTimes,
                             receivers=receiverPoints,
                             peakFreq=peakFreq,
                             peakTime=peakTime,
                             velocity=velocity,
                             x=x,
                             y=y,
                             tau=tau,
                             samplingPoints=sourcePoints)
                else:
                    np.savez('VZTestFuncs.npz',
                             TFarray=X,
                             time=recordingTimes,
                             receivers=receiverPoints,
                             peakFreq=peakFreq,
                             peakTime=peakTime,
                             velocity=velocity,
                             x=x,
                             y=y,
                             z=z,
                             tau=tau,
                             samplingPoints=sourcePoints)

        else:
            print(
                '\nComputing free-space test functions for the current space-time sampling grid...'
            )
            if tau[0] != 0:
                if tu != '':
                    print(
                        'Computing test functions for focusing time %0.2f %s...'
                        % (tau[0], tu))
                else:
                    print(
                        'Computing test functions for focusing time %0.2f...' %
                        (tau[0]))
                X, sourcePoints = sampleSpace(receiverPoints,
                                              recordingTimes - tau[0],
                                              velocity, x, y, z, pulse)
            else:
                X, sourcePoints = sampleSpace(receiverPoints, recordingTimes,
                                              velocity, x, y, z, pulse)

            if z is None:
                np.savez('VZTestFuncs.npz',
                         TFarray=X,
                         time=recordingTimes,
                         receivers=receiverPoints,
                         peakFreq=peakFreq,
                         peakTime=peakTime,
                         velocity=velocity,
                         x=x,
                         y=y,
                         tau=tau,
                         samplingPoints=sourcePoints)
            else:
                np.savez('VZTestFuncs.npz',
                         TFarray=X,
                         time=recordingTimes,
                         receivers=receiverPoints,
                         peakFreq=peakFreq,
                         peakTime=peakTime,
                         velocity=velocity,
                         x=x,
                         y=y,
                         z=z,
                         tau=tau,
                         samplingPoints=sourcePoints)

    #==============================================================================
    # compute spectra
    freqs, amplitudes = compute_spectrum(X, tstep * dt, args.power)

    if args.power:
        plotLabel = 'power'
        plotParams['freq_title'] = 'Mean Power Spectrum'
        plotParams['freq_ylabel'] = 'Power'
    else:
        plotLabel = 'amplitude'
        plotParams['freq_title'] = 'Mean Amplitude Spectrum'
        plotParams['freq_ylabel'] = 'Amplitude'

    if args.data or all(v is not True for v in [args.data, args.testfunc]):
        if noisy:
            plotParams[
                'freq_title'] += ' [Noisy ' + plotParams['data_title'] + ']'
        else:
            plotParams['freq_title'] += ' [' + plotParams['data_title'] + ']'
    elif args.testfunc:
        plotParams['freq_title'] += ' [' + plotParams['tf_title'] + 's]'

    if args.fmin is not None:
        if args.fmin >= 0:
            if args.fmax is not None:
                if args.fmax > args.fmin:
                    plotParams['fmin'] = args.fmin
                    plotParams['fmax'] = args.fmax
                else:
                    sys.exit(
                        textwrap.dedent('''
                            RelationError: The maximum frequency of the %s spectrum plot must
                            be greater than the mininum frequency.
                            ''' % (plotLabel)))
            else:
                fmax = plotParams['fmax']
                if fmax > args.fmin:
                    plotParams['fmin'] = args.fmin
                else:
                    sys.exit(
                        textwrap.dedent('''
                            RelationError: The specified minimum frequency of the %s spectrum 
                            plot must be less than the maximum frequency.
                            ''' % (plotLabel)))
        else:
            sys.exit(
                textwrap.dedent('''
                    ValueError: The specified minimum frequency of the %s spectrum 
                    plot must be nonnegative.
                    ''' % (plotLabel)))

    #===============================================================================
    if args.fmax is not None:
        if args.fmin is not None:
            if args.fmin >= 0:
                if args.fmax > args.fmin:
                    plotParams['fmin'] = args.fmin
                    plotParams['fmax'] = args.fmax
                else:
                    sys.exit(
                        textwrap.dedent('''
                            RelationError: The maximum frequency of the %s spectrum plot must
                            be greater than the mininum frequency.
                            ''' % (plotLabel)))
            else:
                sys.exit(
                    textwrap.dedent('''
                        ValueError: The specified minimum frequency of the %s spectrum 
                        plot must be nonnegative.
                        ''' % (plotLabel)))
        else:
            fmin = plotParams['fmin']
            if args.fmax > fmin:
                plotParams['fmax'] = args.fmax
            else:
                sys.exit(
                    textwrap.dedent('''
                        RelationError: The specified maximum frequency of the %s spectrum 
                        plot must be greater than the minimum frequency.
                        ''' % (plotLabel)))
    elif plotParams['fmax'] is None:
        plotParams['fmax'] = np.max(freqs)

    #===================================================================================
    if args.fu is not None:
        plotParams['fu'] = args.fu

    if args.mode is not None:
        plotParams['view_mode'] = args.mode

    pickle.dump(plotParams, open('plotParams.pkl', 'wb'),
                pickle.HIGHEST_PROTOCOL)

    fig, ax = setFigure(num_axes=1, mode=plotParams['view_mode'])
    ax.plot(freqs, amplitudes, color=ax.linecolor, linewidth=ax.linewidth)
    ax.set_title(plotParams['freq_title'], color=ax.titlecolor)

    # get frequency units from plotParams
    fu = plotParams['fu']
    fmin = plotParams['fmin']
    fmax = plotParams['fmax']
    if fu != '':
        ax.set_xlabel('Frequency (%s)' % (fu), color=ax.labelcolor)
    else:
        ax.set_xlabel('Frequency', color=ax.labelcolor)
    ax.set_ylabel(plotParams['freq_ylabel'], color=ax.labelcolor)
    ax.set_xlim([fmin, fmax])
    ax.set_ylim(bottom=0)
    ax.fill_between(freqs,
                    0,
                    amplitudes,
                    where=(amplitudes > 0),
                    color='m',
                    alpha=ax.alpha)
    ax.locator_params(axis='y', nticks=6)
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
    plt.tight_layout()
    fig.savefig(plotLabel + 'Spectrum.' + args.format,
                format=args.format,
                bbox_inches='tight',
                facecolor=fig.get_facecolor())
    plt.show()
コード例 #12
0
def cli():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--nfo',
        action='store_true',
        help='''Compute or plot the singular-value decomposition of the
                        near-field operator (NFO).''')
    parser.add_argument(
        '--lso',
        action='store_true',
        help='''Compute or plot the singular-value decomposition of the
                        Lippmann-Schwinger operator (LSO).''')
    parser.add_argument(
        '--numVals',
        '-k',
        type=int,
        help='''Specify the number of singular values/vectors to compute.
                        Must a positive integer between 1 and the order of the square
                        input matrix.''')
    parser.add_argument(
        '--domain',
        '-d',
        type=str,
        choices=['time', 'freq'],
        help='''Specify whether to compute the singular-value decomposition in
                        the time domain or frequency domain. Default is set to frequency domain
                        for faster, more accurate performance.''')
    parser.add_argument(
        '--plot',
        '-p',
        action='store_true',
        help='''Plot the computed singular values and vectors.''')
    parser.add_argument(
        '--format',
        '-f',
        type=str,
        default='pdf',
        choices=['png', 'pdf', 'ps', 'eps', 'svg'],
        help=
        '''Specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument(
        '--mode',
        type=str,
        choices=['light', 'dark'],
        required=False,
        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')
    args = parser.parse_args()

    if args.nfo and not args.lso:
        operatorType = 'near-field operator'
        inputType = 'data'
        try:
            SVD = np.load('NFO_SVD.npz')
            s = SVD['s']
            Uh = SVD['Uh']
            V = SVD['V']
            domain = SVD['domain']

        except FileNotFoundError:
            s, Uh, V, domain = None, None, None, 'freq'

    elif not args.nfo and args.lso:
        operatorType = 'Lippmann-Schwinger operator'
        inputType = 'test functions'
        try:
            SVD = np.load('LSO_SVD.npz')
            s = SVD['s']
            Uh = SVD['Uh']
            V = SVD['V']
            domain = SVD['domain']

        except FileNotFoundError:
            s, Uh, V, domain = None, None, None, 'freq'

    elif args.nfo and args.lso:
        sys.exit(
            textwrap.dedent('''
                UsageError: Please specify only one of the arguments \'--nfo\' or \'--lso\'.
                '''))

    else:
        sys.exit(
            textwrap.dedent('''
                For which operator would you like to compute or plot a singular-value decomposition?
                Enter:
                    
                    vzsvd --nfo
                
                for the near-field operator or
                
                    vzsvd --lso
                    
                for the Lippmann-Schwinger operator.
                '''))

    #==============================================================================
    # if an SVD already exists...
    if any(v is not None for v in
           [s, Uh, V]) and args.numVals is not None and args.plot is True:
        if args.numVals >= 1 and args.numVals == len(s):
            userResponded = False
            print(
                textwrap.dedent('''
                 A singular-value decomposition of the {s} for {n} values/vectors already exists. 
                 What would you like to do?
                 
                 Enter '1' to specify a new number of values/vectors to compute. (Default)
                 Enter '2' to recompute a singular-value decomposition for {n} values/vectors.
                 Enter 'q/quit' to exit.
                 '''.format(s=operatorType, n=args.numVals)))
            while userResponded == False:
                answer = input('Action: ')
                if answer == '' or answer == '1':
                    k = int(
                        input(
                            'Please specify the number of singular values/vectors to compute: '
                        ))
                    if isValid(k):
                        print('Proceeding with numVals = %s...' % (k))
                        userResponded = True
                        computeSVD = True
                        break
                    else:
                        break
                elif answer == '2':
                    k = args.numVals
                    print(
                        'Recomputing SVD of the %s for %s singular values/vectors...'
                        % (operatorType, k))
                    userResponded = True
                    computeSVD = True
                elif answer == 'q' or answer == 'quit':
                    sys.exit('Exiting program.\n')
                else:
                    print(
                        'Invalid response. Please enter \'1\', \'2\', or \'q/quit\'.'
                    )

        elif args.numVals >= 1 and args.numVals != len(s):
            k = args.numVals
            computeSVD = True

        elif args.numVals < 1:
            userResponded = False
            print(
                textwrap.dedent('''
                 ValueError: Argument '-k/--numVals' must be a positive integer 
                 between 1 and the order of the square input matrix. The parameter will
                 be set to the default value of 6.
                 What would you like to do?
                 
                 Enter '1' to specify a value of the parameter. (Default)
                 Enter '2' to proceed with the default value.
                 Enter 'q/quit' exit the program.
                 '''))
            while userResponded == False:
                answer = input('Action: ')
                if answer == '' or answer == '1':
                    k = int(
                        input(
                            'Please specify the number of singular values/vectors to compute: '
                        ))
                    if isValid(k):
                        print('Proceeding with numVals = %s...' % (k))
                        userResponded = True
                        computeSVD = True
                        break
                    else:
                        break
                elif answer == '2':
                    k = 6
                    print('Proceeding with the default value numVals = %s...' %
                          (k))
                    computeSVD = True
                    userResponded = True
                    break
                elif answer == 'q' or answer == 'quit':
                    sys.exit('Exiting program.\n')
                else:
                    print(
                        'Invalid response. Please enter \'1\', \'2\', or \'q/quit\'.'
                    )

    elif all(v is not None for v in
             [s, Uh, V]) and args.numVals is None and args.plot is True:
        computeSVD = False

    elif all(v is not None for v in
             [s, Uh, V]) and args.numVals is not None and args.plot is False:
        if args.numVals >= 1 and args.numVals == len(s):
            userResponded = False
            print(
                textwrap.dedent('''
                 A singular-value decomposition of the {s} for {n} values/vectors already exists. 
                 What would you like to do?
                 
                 Enter '1' to specify a new number of values/vectors to compute. (Default)
                 Enter '2' to recompute a singular-value decomposition for {n} values/vectors.
                 Enter 'q/quit' to exit.
                 '''.format(s=operatorType, n=args.numVals)))
            while userResponded == False:
                answer = input('Action: ')
                if answer == '' or answer == '1':
                    k = int(
                        input(
                            'Please specify the number of singular values/vectors to compute: '
                        ))
                    if isValid(k):
                        print('Proceeding with numVals = %s...' % (k))
                        userResponded = True
                        computeSVD = True
                        break
                    else:
                        break
                elif answer == '2':
                    k = args.numVals
                    print(
                        'Recomputing SVD of the %s for %s singular values/vectors...'
                        % (operatorType, k))
                    userResponded = True
                    computeSVD = True
                elif answer == 'q' or answer == 'quit':
                    sys.exit('Exiting program.\n')
                else:
                    print(
                        'Invalid response. Please enter \'1\', \'2\', or \'q/quit\'.'
                    )

        elif args.numVals >= 1 and args.numVals != len(s):
            k = args.numVals
            computeSVD = True

        elif args.numVals < 1:
            userResponded = False
            print(
                textwrap.dedent('''
                 ValueError: Argument '-k/--numVals' must be a positive integer 
                 between 1 and the order of the square input matrix. The parameter will
                 be set to the default value of 6.
                 What would you like to do?
                 
                 Enter '1' to specify a value of the parameter. (Default)
                 Enter '2' to proceed with the default value.
                 Enter 'q/quit' exit the program.
                 '''))
            while userResponded == False:
                answer = input('Action: ')
                if answer == '' or answer == '1':
                    k = int(
                        input(
                            'Please specify the number of singular values/vectors to compute: '
                        ))
                    if isValid(k):
                        print('Proceeding with numVals = %s...' % (k))
                        userResponded = True
                        computeSVD = True
                        break
                    else:
                        break
                elif answer == '2':
                    k = 6
                    print('Proceeding with the default value numVals = %s...' %
                          (k))
                    computeSVD = True
                    userResponded = True
                    break
                elif answer == 'q' or answer == 'quit':
                    sys.exit('Exiting program.\n')
                else:
                    print(
                        'Invalid response. Please enter \'1\', \'2\', or \'q/quit\'.'
                    )

    elif all(v is not None for v in
             [s, Uh, V]) and args.numVals is None and args.plot is False:
        sys.exit(
            textwrap.dedent('''
                No action specified. A singular-value decomposition of the %s
                for %s values/vectors already exists. Please specify at least one of '-k/--numVals'
                or '-p/--plot' arguments with 'vzsvd' command.
                ''' % (operatorType, len(s))))
    #==============================================================================
    # if an SVD does not already exist...
    elif any(v is None for v in
             [s, Uh, V]) and args.numVals is not None and args.plot is True:
        if args.numVals >= 1:
            computeSVD = True
            k = args.numVals

        elif args.numVals < 1:
            userResponded = False
            print(
                textwrap.dedent('''
                 ValueError: Argument '-k/--numVals' must be a positive integer 
                 between 1 and the order of the square input matrix. The parameter will
                 be set to the default value of 6.
                 What would you like to do?
                 
                 Enter '1' to specify a value of the parameter. (Default)
                 Enter '2' to proceed with the default value.
                 Enter 'q/quit' exit the program.
                 '''))
            while userResponded == False:
                answer = input('Action: ')
                if answer == '' or answer == '1':
                    k = int(
                        input(
                            'Please specify the number of singular values/vectors to compute: '
                        ))
                    if isValid(k):
                        print('Proceeding with numVals = %s...' % (k))
                        userResponded = True
                        computeSVD = True
                        break
                    else:
                        break
                elif answer == '2':
                    k = 6
                    print('Proceeding with the default value numVals = %s...' %
                          (k))
                    computeSVD = True
                    userResponded = True
                    break
                elif answer == 'q' or answer == 'quit':
                    sys.exit('Exiting program.\n')
                else:
                    print(
                        'Invalid response. Please enter \'1\', \'2\', or \'q/quit\'.'
                    )

    elif any(v is None for v in
             [s, Uh, V]) and args.numVals is None and args.plot is True:
        userResponded = False
        print(
            textwrap.dedent('''
             PlotError: A singular-value decomposition of the {s} does not exist. A plot will be
             generated after a singular-value decomposition has been computed.
             
             Enter '1' to specify a number of singular values/vectors to compute. (Default)
             Enter 'q/quit' to exit.
             '''.format(s=operatorType)))
        while userResponded == False:
            answer = input('Action: ')
            if answer == '' or answer == '1':
                k = int(
                    input(
                        'Please specify the number of singular values/vectors to compute: '
                    ))
                if isValid(k):
                    print('Proceeding with numVals = %s...' % (k))
                    userResponded = True
                    computeSVD = True
                    break
                else:
                    break
            elif answer == 'q' or answer == 'quit':
                sys.exit('Exiting program.\n')
            else:
                print('Invalid response. Please enter \'1\', or \'q/quit\'.')

    elif any(v is None for v in
             [s, Uh, V]) and args.numVals is not None and args.plot is False:
        if args.numVals >= 1:
            k = args.numVals
            computeSVD = True

        elif args.numVals < 1:
            userResponded = False
            print(
                textwrap.dedent('''
                 ValueError: Argument '-k/--numVals' must be a positive integer 
                 between 1 and the order of the square input matrix. The parameter will
                 be set to the default value of 6.
                 What would you like to do?
                 
                 Enter '1' to specify a value of the parameter. (Default)
                 Enter '2' to proceed with the default value.
                 Enter 'q/quit' exit the program.
                 '''))
            while userResponded == False:
                answer = input('Action: ')
                if answer == '' or answer == '1':
                    k = int(
                        input(
                            'Please specify the number of singular values/vectors to compute: '
                        ))
                    if isValid(k):
                        print('Proceeding with numVals = %s...' % (k))
                        userResponded = True
                        computeSVD = True
                        break
                    else:
                        break
                elif answer == '2':
                    k = 6
                    print('Proceeding with the default value numVals = %s...' %
                          (k))
                    computeSVD = True
                    userResponded = True
                    break
                elif answer == 'q' or answer == 'quit':
                    sys.exit('Exiting program.\n')
                else:
                    print(
                        'Invalid response. Please enter \'1\', \'2\', or \'q/quit\'.'
                    )

    elif any(v is None for v in
             [s, Uh, V]) and args.numVals is None and args.plot is False:
        sys.exit(
            textwrap.dedent('''
                Nothing to be done. A singular-value decomposition of the {s} does not exist.
                Please specify at least one of '-k/--numVals' or '-p/--plot'
                arguments with 'vzsvd' command.
                '''.format(s=operatorType)))
    #==============================================================================
    # Read in data files
    datadir = np.load('datadir.npz')
    receiverPoints = np.load(str(datadir['receivers']))
    recordingTimes = np.load(str(datadir['recordingTimes']))
    dt = recordingTimes[1] - recordingTimes[0]

    if Path('window.npz').exists():
        windowDict = np.load('window.npz')

        # Receiver window parameters (the window itself is applied below)
        rstart = windowDict['rstart']
        rstop = windowDict['rstop']
        rstep = windowDict['rstep']

        # Time window parameters (tstart/tstop are given in units of time)
        tstart = windowDict['tstart']
        tstop = windowDict['tstop']
        tstep = windowDict['tstep']

        # Convert time window parameters to corresponding array indices
        Tstart = int(round(tstart / dt))
        Tstop = int(round(tstop / dt))

    else:
        rstart = 0
        rstop = receiverPoints.shape[0]
        rstep = 1

        tstart = recordingTimes[0]
        tstop = recordingTimes[-1]

        Tstart = 0
        Tstop = len(recordingTimes)
        tstep = 1

    # Apply the receiver window
    rinterval = np.arange(rstart, rstop, rstep)
    receiverPoints = receiverPoints[rinterval, :]

    # Apply the time window
    tinterval = np.arange(Tstart, Tstop, tstep)
    recordingTimes = recordingTimes[tinterval]

    # Used for getting time and frequency units
    if Path('plotParams.pkl').exists():
        plotParams = pickle.load(open('plotParams.pkl', 'rb'))
    else:
        plotParams = default_params()

    if computeSVD:
        # get time units for printing time windows or time shifts
        tu = plotParams['tu']

        if args.nfo:

            if Path('noisyData.npz').exists():
                userResponded = False
                print(
                    textwrap.dedent('''
                      Detected that band-limited noise has been added to the data array.
                      Would you like to compute an SVD of the noisy data? ([y]/n)
                      
                      Enter 'q/quit' exit the program.
                      '''))
                while userResponded == False:
                    answer = input('Action: ')
                    if answer == '' or answer == 'y' or answer == 'yes':
                        print(
                            'Proceeding with singular-value decomposition using noisy data...'
                        )
                        # read in the noisy data array
                        X = np.load('noisyData.npz')['noisyData']
                        userResponded = True
                    elif answer == 'n' or answer == 'no':
                        print(
                            'Proceeding with singular-value decomposition using noise-free data...'
                        )
                        # read in the recorded data array
                        X = np.load(str(datadir['recordedData']))
                        userResponded = True
                    elif answer == 'q' or answer == 'quit':
                        sys.exit('Exiting program.\n')
                    else:
                        print(
                            'Invalid response. Please enter \'y/yes\', \'n\no\', or \'q/quit\'.'
                        )

            else:
                # read in the recorded data array
                X = np.load(str(datadir['recordedData']))

            if Path('window.npz').exists():
                print('Detected user-specified window:\n')

                # For display/printing purposes, count receivers with one-based
                # indexing. This amounts to incrementing the rstart parameter by 1
                print('window @ receivers : start =', rstart + 1)
                print('window @ receivers : stop =', rstop)
                print('window @ receivers : step =', rstep, '\n')

                if tu != '':
                    print('window @ time : start = %0.2f %s' % (tstart, tu))
                    print('window @ time : stop = %0.2f %s' % (tstop, tu))
                else:
                    print('window @ time : start =', tstart)
                    print('window @ time : stop =', tstop)
                print('window @ time : step =', tstep, '\n')

                # Apply the source window
                slabel = windowDict['slabel']
                sstart = windowDict['sstart']
                sstop = windowDict['sstop']
                sstep = windowDict['sstep']
                sinterval = np.arange(sstart, sstop, sstep)

                # For display/printing purposes, count recordings/sources with one-based
                # indexing. This amounts to incrementing the sstart parameter by 1
                print('window @ %s : start = %s' % (slabel, sstart + 1))
                print('window @ %s : stop = %s' % (slabel, sstop))
                print('window @ %s : step = %s\n' % (slabel, sstep))

                print('Applying window to data volume...')
                X = X[rinterval, :, :]
                X = X[:, tinterval, :]
                X = X[:, :, sinterval]
                Nr, Nt, Ns = X.shape

                # Apply tapered cosine (Tukey) window to time signals.
                # This ensures the fast Fourier transform (FFT) used in
                # the definition of the matrix-vector product below is
                # acting on a function that is continuous at its edges.

                peakFreq = pulseFun.peakFreq
                # Np : Number of samples in the dominant period T = 1 / peakFreq
                Np = int(round(1 / (tstep * dt * peakFreq)))
                # alpha is set to taper over 6 dominant periods of the
                # pulse function (3 periods at each end of the signal)
                alpha = 6 * Np / Nt
                print('Tapering time signals with Tukey window: %d' %
                      (int(round(alpha * 100))) + '%')
                TukeyWindow = tukey(Nt, alpha)
                X *= TukeyWindow[None, :, None]

            else:
                Nr, Nt, Ns = X.shape

        elif args.lso:

            if Path('samplingGrid.npz').exists():
                samplingGrid = np.load('samplingGrid.npz')
                x = samplingGrid['x']
                y = samplingGrid['y']
                tau = samplingGrid['tau']
                if 'z' in samplingGrid:
                    z = samplingGrid['z']
                else:
                    z = None

            else:
                sys.exit(
                    textwrap.dedent('''
                        A sampling grid needs to be set up before computing a
                        singular-value decomposition of the %s.
                        Enter:
                            
                            vzgrid --help
                            
                        from the command-line for more information on how to set up a
                        sampling grid.
                        ''' % (operatorType)))

            pulse = lambda t: pulseFun.pulse(t)
            velocity = pulseFun.velocity
            peakFreq = pulseFun.peakFreq
            peakTime = pulseFun.peakTime

            if Path('VZTestFuncs.npz').exists():
                print(
                    '\nDetected that free-space test functions have already been computed...'
                )
                print(
                    'Checking consistency with current space-time sampling grid...'
                )
                TFDict = np.load('VZTestFuncs.npz')

                if samplingIsCurrent(TFDict, receiverPoints, recordingTimes,
                                     velocity, tau, x, y, z, peakFreq,
                                     peakTime):
                    X = TFDict['TFarray']
                    sourcePoints = TFDict['samplingPoints']
                    print('Moving forward to SVD...')

                else:
                    print('Recomputing test functions...')
                    # set up the convolution times based on length of recording time interval
                    T = recordingTimes[-1] - recordingTimes[0]
                    convolutionTimes = np.linspace(-T, T,
                                                   2 * len(recordingTimes) - 1)

                    if tau[0] != 0:
                        if tu != '':
                            print(
                                'Recomputing test functions for focusing time %0.2f %s...'
                                % (tau[0], tu))
                        else:
                            print(
                                'Recomputing test functions for focusing time %0.2f...'
                                % (tau[0]))
                        X, sourcePoints = sampleSpace(
                            receiverPoints, convolutionTimes - tau[0],
                            velocity, x, y, z, pulse)
                    else:
                        X, sourcePoints = sampleSpace(receiverPoints,
                                                      convolutionTimes,
                                                      velocity, x, y, z, pulse)

                    if z is None:
                        np.savez('VZTestFuncs.npz',
                                 TFarray=X,
                                 time=recordingTimes,
                                 receivers=receiverPoints,
                                 peakFreq=peakFreq,
                                 peakTime=peakTime,
                                 velocity=velocity,
                                 x=x,
                                 y=y,
                                 tau=tau,
                                 samplingPoints=sourcePoints)
                    else:
                        np.savez('VZTestFuncs.npz',
                                 TFarray=X,
                                 time=recordingTimes,
                                 receivers=receiverPoints,
                                 peakFreq=peakFreq,
                                 peakTime=peakTime,
                                 velocity=velocity,
                                 x=x,
                                 y=y,
                                 z=z,
                                 tau=tau,
                                 samplingPoints=sourcePoints)

            else:
                print(
                    '\nComputing free-space test functions for the current space-time sampling grid...'
                )
                if tau[0] != 0:
                    if tu != '':
                        print(
                            'Computing test functions for focusing time %0.2f %s...'
                            % (tau[0], tu))
                    else:
                        print(
                            'Computing test functions for focusing time %0.2f...'
                            % (tau[0]))
                    X, sourcePoints = sampleSpace(receiverPoints,
                                                  recordingTimes - tau[0],
                                                  velocity, x, y, z, pulse)
                else:
                    X, sourcePoints = sampleSpace(receiverPoints,
                                                  recordingTimes, velocity, x,
                                                  y, z, pulse)

                if z is None:
                    np.savez('VZTestFuncs.npz',
                             TFarray=X,
                             time=recordingTimes,
                             receivers=receiverPoints,
                             peakFreq=peakFreq,
                             peakTime=peakTime,
                             velocity=velocity,
                             x=x,
                             y=y,
                             tau=tau,
                             samplingPoints=sourcePoints)
                else:
                    np.savez('VZTestFuncs.npz',
                             TFarray=X,
                             time=recordingTimes,
                             receivers=receiverPoints,
                             peakFreq=peakFreq,
                             peakTime=peakTime,
                             velocity=velocity,
                             x=x,
                             y=y,
                             z=z,
                             tau=tau,
                             samplingPoints=sourcePoints)

            Nr, Nt, Ns = X.shape

        #==============================================================================
        if args.domain is not None:
            domain = args.domain

        if domain == 'freq':
            # Transform convolutional operator into frequency domain and bandpass for efficient SVD
            print('Transforming %s to the frequency domain...' % (inputType))
            N = nextPow2(2 * Nt)
            X = np.fft.rfft(X, n=N, axis=1)

            if plotParams['fmax'] is None:
                freqs = np.fft.rfftfreq(N, tstep * dt)
                plotParams['fmax'] = np.max(freqs)

            # Apply the frequency window
            fmin = plotParams['fmin']
            fmax = plotParams['fmax']
            fu = plotParams['fu']  # frequency units (e.g., Hz)

            if fu != '':
                print('Applying bandpass filter: [%0.2f %s, %0.2f %s]' %
                      (fmin, fu, fmax, fu))
            else:
                print('Applying bandpass filter: [%0.2f, %0.2f]' %
                      (fmin, fmax))

            df = 1.0 / (N * tstep * dt)
            startIndex = int(round(fmin / df))
            stopIndex = int(round(fmax / df))

            finterval = np.arange(startIndex, stopIndex, 1)
            X = X[:, finterval, :]

        #==============================================================================
        # Compute the k largest singular values (which='LM') of the operator A
        # Singular values are elements of the vector 's'
        # Left singular vectors are columns of 'U'
        # Right singular vectors are columns of 'V'

        A = asConvolutionalOperator(X)

        if k == 1:
            print('Computing SVD of the %s for 1 singular value/vector...' %
                  (operatorType))
        else:
            print('Computing SVD of the %s for %s singular values/vectors...' %
                  (operatorType, k))
        startTime = time.time()
        U, s, Vh = svds(A, k, which='LM')
        endTime = time.time()
        print('Elapsed time:', humanReadable(endTime - startTime), '\n')

        # sort the singular values and corresponding vectors in descending order
        # (i.e., largest to smallest)
        index = s.argsort()[::-1]
        s = s[index]
        Uh = U[:, index].conj().T
        V = Vh[index, :].conj().T

        # Write binary output with numpy
        if args.nfo:
            np.savez('NFO_SVD.npz', s=s, Uh=Uh, V=V, domain=domain)
        elif args.lso:
            np.savez('LSO_SVD.npz', s=s, Uh=Uh, V=V, domain=domain)

    #==============================================================================
    if args.plot and all(v is not None for v in [s, Uh, V]):

        Nr = receiverPoints.shape[0]
        Nt = len(recordingTimes)

        try:
            k
        except NameError:
            k = len(s)

        if args.domain is not None and domain != args.domain:
            if domain == 'freq':
                s1 = 'time'
                s2 = 'frequency'
            else:
                s1 = 'frequency'
                s2 = 'time'
            sys.exit(
                textwrap.dedent('''
                    Error: Attempted to plot the singular-value decomposition in the %s
                    domain, but the decomposition was computed in the %s domain.
                    ''' % (s1, s2)))

        if domain == 'freq':
            # plot singular vectors in frequency domain
            N = nextPow2(2 * Nt)
            freqs = np.fft.rfftfreq(N, tstep * dt)

            if plotParams['fmax'] is None:
                plotParams['fmax'] = np.max(freqs)

            # Apply the frequency window
            fmin = plotParams['fmin']
            fmax = plotParams['fmax']
            df = 1.0 / (N * tstep * dt)

            startIndex = int(round(fmin / df))
            stopIndex = int(round(fmax / df))
            finterval = np.arange(startIndex, stopIndex, 1)
            freqs = freqs[finterval]
            fmax = freqs[-1]

            M = len(freqs)
            Ns = int(V.shape[0] / M)
            U = np.reshape(Uh.conj().T, (Nr, M, k))
            V = np.reshape(V, (Ns, M, k))

        else:  # domain == 'time'
            M = 2 * Nt - 1
            Ns = int(V.shape[0] / M)
            U = np.reshape(Uh.T, (Nr, M, k))
            V = np.reshape(V, (Ns, M, k))
            T = recordingTimes[-1] - recordingTimes[0]
            times = np.linspace(-T, T, M)

        if args.nfo:  # Near-field operator
            try:
                sinterval
            except NameError:
                if Path('window.npz').exists():
                    sstart = windowDict['sstart']
                    sstop = windowDict['sstop']
                    sstep = windowDict['sstep']
                else:
                    sstart = 0
                    sstop = Ns
                    sstep = 1

                sinterval = np.arange(sstart, sstop, sstep)

            if 'sources' in datadir:
                sourcePoints = np.load(str(datadir['sources']))
                sourcePoints = sourcePoints[sinterval, :]
            else:
                sourcePoints = None

        else:
            # if args.lso (Lippmann-Schwinger operator)

            # in the case of the Lippmann-Schwinger operator, 'sourcePoints'
            # correspond to sampling points, which should always exist.
            try:
                sourcePoints
            except NameError:
                if Path('VZTestFuncs.npz').exists():
                    TFDict = np.load('VZTestFuncs.npz')
                    sourcePoints = TFDict['samplingPoints']
                else:
                    sys.exit(
                        textwrap.dedent('''
                            Error: A sampling grid must exist and test functions computed
                            before a singular-value decomposition of the Lippmann-Schwinger
                            operator can be computed or plotted.
                            '''))

            sstart = 0
            sstop = sourcePoints.shape[0]
            sstep = 1
            sinterval = np.arange(sstart, sstop, sstep)

        # increment source/recording interval and receiver interval to be consistent
        # with one-based indexing (i.e., count from one instead of zero)
        sinterval += 1
        rinterval += 1
        rstart += 1
        sstart += 1

        if args.mode is not None:
            plotParams['view_mode'] = args.mode

        pickle.dump(plotParams, open('plotParams.pkl', 'wb'),
                    pickle.HIGHEST_PROTOCOL)

        remove_keymap_conflicts({'left', 'right', 'up', 'down', 'save'})
        if domain == 'freq':

            # plot the left singular vectors
            fig_lvec, ax_lvec_r, ax_lvec_i = setFigure(
                num_axes=2, mode=plotParams['view_mode'])
            ax_lvec_r.volume = U.real
            ax_lvec_i.volume = U.imag
            ax_lvec_r.index = 0
            ax_lvec_i.index = 0
            fig_lvec.suptitle('Left-Singular Vector',
                              color=ax_lvec_r.titlecolor,
                              fontsize=16)
            fig_lvec.subplots_adjust(bottom=0.27, top=0.86)
            leftTitle_r = vector_title('left', ax_lvec_r.index + 1, 'real')
            leftTitle_i = vector_title('left', ax_lvec_i.index + 1, 'imag')
            for ax, title in zip([ax_lvec_r, ax_lvec_i],
                                 [leftTitle_r, leftTitle_i]):
                left_im = plotFreqVectors(ax, ax.volume[:, :, ax.index], freqs,
                                          fmin, fmax, rstart, rinterval,
                                          receiverPoints, title, 'left',
                                          plotParams)

            lp0 = ax_lvec_r.get_position().get_points().flatten()
            lp1 = ax_lvec_i.get_position().get_points().flatten()
            left_cax = fig_lvec.add_axes([lp0[0], 0.12, lp1[2] - lp0[0], 0.03])
            lcbar = fig_lvec.colorbar(left_im,
                                      left_cax,
                                      orientation='horizontal')
            lcbar.outline.set_edgecolor(ax_lvec_r.cbaredgecolor)
            lcbar.ax.tick_params(axis='x', colors=ax_lvec_r.labelcolor)
            lcbar.ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
            lcbar.set_label('Amplitude',
                            labelpad=5,
                            rotation=0,
                            fontsize=12,
                            color=ax_lvec_r.labelcolor)
            fig_lvec.canvas.mpl_connect(
                'key_press_event', lambda event: process_key_vectors(
                    event, freqs, fmin, fmax, rstart, sstart, rinterval,
                    sinterval, receiverPoints, sourcePoints, plotParams,
                    'cmplx_left'))

            # plot the right singular vectors
            fig_rvec, ax_rvec_r, ax_rvec_i = setFigure(
                num_axes=2, mode=plotParams['view_mode'])
            ax_rvec_r.volume = V.real
            ax_rvec_i.volume = V.imag
            ax_rvec_r.index = 0
            ax_rvec_i.index = 0
            fig_rvec.suptitle('Right-Singular Vector',
                              color=ax_rvec_r.titlecolor,
                              fontsize=16)
            fig_rvec.subplots_adjust(bottom=0.27, top=0.86)
            rightTitle_r = vector_title('right', ax_rvec_r.index + 1, 'real')
            rightTitle_i = vector_title('right', ax_rvec_i.index + 1, 'imag')
            for ax, title in zip([ax_rvec_r, ax_rvec_i],
                                 [rightTitle_r, rightTitle_i]):
                right_im = plotFreqVectors(ax, ax.volume[:, :, ax.index],
                                           freqs, fmin, fmax, sstart,
                                           sinterval, sourcePoints, title,
                                           'right', plotParams)

            rp0 = ax_rvec_r.get_position().get_points().flatten()
            rp1 = ax_rvec_i.get_position().get_points().flatten()
            right_cax = fig_rvec.add_axes(
                [rp0[0], 0.12, rp1[2] - rp0[0], 0.03])
            rcbar = fig_rvec.colorbar(right_im,
                                      right_cax,
                                      orientation='horizontal')
            rcbar.outline.set_edgecolor(ax_rvec_r.cbaredgecolor)
            rcbar.ax.tick_params(axis='x', colors=ax_rvec_r.labelcolor)
            rcbar.ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
            rcbar.set_label('Amplitude',
                            labelpad=5,
                            rotation=0,
                            fontsize=12,
                            color=ax_lvec_r.labelcolor)
            fig_rvec.canvas.mpl_connect(
                'key_press_event', lambda event: process_key_vectors(
                    event, freqs, fmin, fmax, rstart, sstart, rinterval,
                    sinterval, receiverPoints, sourcePoints, plotParams,
                    'cmplx_right'))

        else:
            # domain == 'time'
            fig_vec, ax_lvec, ax_rvec = setFigure(num_axes=2,
                                                  mode=plotParams['view_mode'])

            ax_lvec.volume = U
            ax_lvec.index = 0
            leftTitle = vector_title('left', ax_lvec.index + 1)
            plotWiggles(ax_lvec, ax_lvec.volume[:, :, ax_lvec.index], times,
                        -T, T, rstart, rinterval, receiverPoints, leftTitle,
                        'left', plotParams)

            ax_rvec.volume = V
            ax_rvec.index = 0
            rightTitle = vector_title('right', ax_rvec.index + 1)
            plotWiggles(ax_rvec, ax_rvec.volume[:, :, ax_rvec.index], times,
                        -T, T, sstart, sinterval, sourcePoints, rightTitle,
                        'right', plotParams)
            fig_vec.tight_layout()
            fig_vec.canvas.mpl_connect(
                'key_press_event', lambda event: process_key_vectors(
                    event, times, -T, T, rstart, sstart, rinterval, sinterval,
                    receiverPoints, sourcePoints, plotParams))
        #==============================================================================
        # plot the singular values
        # figure and axis for singular values
        fig_vals, ax_vals = setFigure(num_axes=1, mode=plotParams['view_mode'])

        n = np.arange(1, k + 1, 1)
        kappa = s[0] / s[-1]  # condition number = max(s) / min(s)
        ax_vals.plot(n,
                     s,
                     '.',
                     clip_on=False,
                     markersize=9,
                     label=r'Condition Number: %0.1e' % (kappa),
                     color=ax_vals.pointcolor)
        ax_vals.set_xlabel('n', color=ax_vals.labelcolor)
        ax_vals.set_ylabel('$\sigma_n$', color=ax_vals.labelcolor)
        legend = ax_vals.legend(title='Singular Values',
                                loc='upper center',
                                bbox_to_anchor=(0.5, 1.25),
                                markerscale=0,
                                handlelength=0,
                                handletextpad=0,
                                fancybox=True,
                                shadow=True,
                                fontsize='large')
        legend.get_title().set_fontsize('large')
        ax_vals.set_xlim([1, k])
        ax_vals.set_ylim(bottom=0)
        ax_vals.locator_params(axis='y', nticks=6)
        ax_vals.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
        fig_vals.tight_layout()
        fig_vals.savefig('singularValues.' + args.format,
                         format=args.format,
                         bbox_inches='tight',
                         facecolor=fig_vals.get_facecolor())

        plt.show()
コード例 #13
0
ファイル: plotWiggles.py プロジェクト: aaronprunty/vezda
def cli():
    """Command-line entry point that plots wiggle traces.

    Parses the command-line flags, merges them into the plotting parameters
    persisted in 'plotParams.pkl' (created from ``default_params()`` when the
    file does not exist yet), loads the receiver/source geometry listed in
    'datadir.npz', applies the user-defined windows, and hands either the
    recorded data (default) or the simulated impulse responses to ``Plotter``.

    Raises:
        SystemExit: if both '--data' and '--impulse' are passed. The check
            runs before any file is read or written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', action='store_true',
                        help='Plot the recorded data. (Default)')
    parser.add_argument('--impulse', action='store_true',
                        help='Plot the simulated impulse responses.')
    parser.add_argument('--medium', type=str, default='constant', choices=['constant', 'variable'],
                        help='''Specify whether the background medium is constant or variable
                        (inhomogeneous). If argument is set to 'constant', the velocity defined in
                        the required 'pulseFun.py' file is used. Default is set to 'constant'.''')
    parser.add_argument('--tu', type=str,
                        help='Specify the time units (e.g., \'s\' or \'ms\').')
    parser.add_argument('--au', type=str,
                        help='Specify the amplitude units (e.g., \'m\' or \'mm\').')
    parser.add_argument('--colormap', type=str, default=None, choices=['grays', 'seismic', 'native'],
                        help='specify a colormap for wiggle plots. Default is \'grays\'.')
    parser.add_argument('--pclip', type=float,
                        help='''Specify the percentage (0-1) of the peak amplitude to display. This
                        parameter is used for pcolormesh plots only. Default is set to 1.''')
    parser.add_argument('--title', type=str,
                        help='''Specify a title for the wiggle plot. Default title is
                        \'Data\' if \'--data\' is passed and 'Impulse Response' if \'--impulse\'
                        is passed.''')
    parser.add_argument('--format', '-f', type=str, default='pdf', choices=['png', 'pdf', 'ps', 'eps', 'svg'],
                        help='''Specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument('--map', action='store_true',
                        help='''Plot a map of the receiver and source/search point locations. The current
                        source/search point will be highlighted. The boundary of the scatterer will also
                        be shown if available.''')
    parser.add_argument('--mode', type=str, choices=['light', 'dark'], required=False,
                        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')

    args = parser.parse_args()

    # Reject contradictory flags up front, before the persisted plotting
    # parameters are rewritten or any data file is loaded.
    if args.data and args.impulse:
        sys.exit(textwrap.dedent(
                '''
                Error: Cannot plot both recorded data and simulated impulse responses. Use

                    vzwiggles --data

                to plot the recorded data or

                    vzwiggles --impulse

                to plot the simulated impulse responses.
                '''))

    #==============================================================================
    # Start from the persisted plotting parameters when they exist, otherwise
    # from the defaults; the command-line overrides below apply to both cases.
    paramsFile = Path('plotParams.pkl')
    if paramsFile.exists():
        with paramsFile.open('rb') as f:
            plotParams = pickle.load(f)
    else:
        plotParams = default_params()

    if args.mode is not None:
        plotParams['view_mode'] = args.mode

    if args.tu is not None:
        plotParams['tu'] = args.tu

    if args.au is not None:
        plotParams['au'] = args.au

    if args.colormap is not None:
        plotParams['wiggle_colormap'] = args.colormap

    if args.pclip is not None:
        if 0 <= args.pclip <= 1:
            plotParams['pclip'] = args.pclip
        else:
            print(textwrap.dedent(
                  '''
                  Warning: Invalid value passed to argument \'--pclip\'. Value must be
                  between 0 and 1.
                  '''))

    if args.title is not None:
        # Recorded data is plotted unless '--impulse' is given, so a title
        # passed without either flag belongs to the data plot.
        titleKey = 'tf_title' if args.impulse else 'data_title'
        plotParams[titleKey] = args.title

    with paramsFile.open('wb') as f:
        pickle.dump(plotParams, f, pickle.HIGHEST_PROTOCOL)

    #==============================================================================
    # Load the relevant data to plot
    datadir = np.load('datadir.npz')
    receiverPoints = np.load(str(datadir['receivers']))
    recordingTimes = np.load(str(datadir['recordingTimes']))

    # Apply any user-specified windows (the time step values returned by
    # get_user_windows are not needed for plotting wiggles)
    receiverNumbers, tinterval, _tstep, _dt, sourceNumbers = get_user_windows()
    receiverPoints = receiverPoints[receiverNumbers, :]
    recordingTimes = recordingTimes[tinterval]

    if 'sources' in datadir:
        sourcePoints = np.load(str(datadir['sources']))
        sourcePoints = sourcePoints[sourceNumbers, :]

        # Check for source-receiver reciprocity.
        # The builtin int replaces the np.int alias, which no longer exists
        # in current NumPy releases.
        reciprocalNumbers = get_unique_indices(sourcePoints, receiverPoints)
        reciprocalNumbers = np.asarray(reciprocalNumbers, dtype=int)

        reciprocity = len(reciprocalNumbers) > 0
        if reciprocity:
            newReceivers = sourcePoints[reciprocalNumbers, :]
    else:
        reciprocity = False
        sourcePoints = None

    # Load the scatterer boundary, if it exists
    if 'scatterer' in datadir:
        scatterer = np.load(str(datadir['scatterer']))
    else:
        scatterer = None

    ER = None
    if not args.impulse:
        # Plot the recorded data (also the default when no flag is given).
        # The 3D data array is indexed as X[receiver, time, source]
        wiggleType = 'data'
        X = load_data(domain='time', verbose=True)

        if reciprocity:
            Nr = len(receiverNumbers)
            Ns = len(sourceNumbers)
            M = len(reciprocalNumbers)

            XR = X[-M:, :, -Nr:]
            X = X[:Nr, :, :Ns]

            reciprocalNumbers += 1
            ER = Experiment(XR, recordingTimes, reciprocalNumbers, newReceivers,
                            receiverNumbers, receiverPoints, wiggleType)

    else:
        wiggleType = 'impulse'

        # Update time to convolution times
        T = recordingTimes[-1] - recordingTimes[0]
        recordingTimes = np.linspace(-T, T, 2 * len(recordingTimes) - 1)
        X, sourcePoints = load_impulse_responses(domain='time', medium=args.medium,
                                                 verbose=True, return_search_points=True)

        # Update sourceNumbers to match search points
        sourceNumbers = np.arange(sourcePoints.shape[0])

        if reciprocity:
            Nr = len(receiverNumbers)
            M = len(reciprocalNumbers)

            XR = X[-M:, :, :]
            X = X[:Nr, :, :]

            reciprocalNumbers += 1
            ER = Experiment(XR, recordingTimes, reciprocalNumbers, newReceivers,
                            sourceNumbers, sourcePoints, wiggleType)

    #==============================================================================
    # increment source/receiver numbers to be consistent with
    # one-based indexing (i.e., count from one instead of zero)
    # NOTE(review): these stay in-place (+=) on purpose. The same objects were
    # already passed to the reciprocal Experiment above, so, presumably being
    # NumPy arrays that Experiment keeps by reference, ER sees the shift as
    # well. Confirm before rebinding them to new arrays.
    sourceNumbers += 1
    receiverNumbers += 1

    E = Experiment(X, recordingTimes, receiverNumbers, receiverPoints,
                   sourceNumbers, sourcePoints, wiggleType)

    p = Plotter(E, ER)
    p.plot(scatterer, plotParams, args.map)
    plt.show()
コード例 #14
0
def _load_spectrum_impulse_responses(datadir):
    """Load the time-domain impulse responses whose spectra will be plotted.

    The medium passed to ``load_impulse_responses`` is chosen as follows:

    * ``'testFuncs'`` in *datadir* and ``VZImpulseResponses.npz`` exists:
      the user is asked interactively which set to view
      ('1' or empty -> ``'variable'``, '2' -> ``'constant'``, 'q'/'quit' exits).
    * only ``'testFuncs'`` in *datadir*: ``'variable'``.
    * otherwise: ``'constant'``.

    Parameters
    ----------
    datadir : mapping
        The loaded ``datadir.npz`` archive; only membership of the key
        ``'testFuncs'`` is inspected here.

    Returns
    -------
    Whatever ``load_impulse_responses(domain='time', ...)`` returns
    (presumably a 3D array indexed [receiver, time, source] -- TODO confirm).
    """
    hasUserResponses = 'testFuncs' in datadir
    hasVezdaResponses = Path('VZImpulseResponses.npz').exists()

    if not (hasUserResponses and hasVezdaResponses):
        # No ambiguity: at most one kind of impulse response is available.
        medium = 'variable' if hasUserResponses else 'constant'
        return load_impulse_responses(domain='time', medium=medium, verbose=True)

    print(textwrap.dedent(
         '''
         Two files are available containing simulated impulse responses.
         
         Enter '1' to view the user-provided impulse responses. (Default)
         Enter '2' to view the impulse responses computed by Vezda.
         Enter 'q/quit' to exit.
         '''))
    # Keep prompting until the user gives a recognized answer.
    while True:
        answer = input('Action: ')

        if answer in ('', '1'):
            return load_impulse_responses(domain='time', medium='variable', verbose=True)

        if answer == '2':
            return load_impulse_responses(domain='time', medium='constant', verbose=True)

        if answer in ('q', 'quit'):
            sys.exit('Exiting program.')

        print('Invalid response. Please enter \'1\', \'2\', or \'q/quit\'.')


def cli():
    """Command-line entry point: plot the mean spectrum of data and/or impulse responses.

    Parses the command-line options, loads the recorded data (default) and/or
    the simulated impulse responses, computes their spectra with
    ``compute_spectra``, validates and stores the requested frequency limits
    and units in ``plotParams`` (persisted to ``plotParams.pkl``), then draws
    the spectra, saves the figure as ``<label>Spectrum.<format>`` and shows it.

    Exits via ``sys.exit`` with an explanatory message when ``--nseg``,
    ``--fmin`` or ``--fmax`` are invalid, or when the user quits at the
    impulse-response prompt.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', action='store_true',
                        help='Plot the frequency spectra of the recorded data. (Default)')
    parser.add_argument('--impulse', action='store_true',
                        help='Plot the frequency spectra of the simulated impulse responses.')
    parser.add_argument('--scaling', '-s', type=str, default='amp', choices=['amp', 'pow', 'psd'],
                        help='''Specify the scaling of the spectrum. Choose from amplitude ('amp'), power ('pow'),
                        or power spectral density ('psd') using Welch's method. Default is set to 'amp'.''')
    # NOTE(review): the help text below advertises a default of 20 segments,
    # but this function falls back to nseg = 1 when '--nseg' is omitted.
    # Confirm which value is intended before changing either.
    parser.add_argument('--nseg', type=int,
                        help='''Specify the approximate number of segments into which the time signal will be partitioned.
                        Only used if scaling is set to 'psd'. Increasing the number of segments increases computational
                        cost and the accuracy of the PSD estimate, but decreases frequency resolution. Default is set to 20.''')
    parser.add_argument('--fmin', type=float,
                        help='Specify the minimum frequency of the amplitude/power spectrum plot. Default is set to 0.')
    parser.add_argument('--fmax', type=float,
                        help='''Specify the maximum frequency of the amplitude/power spectrum plot. Default is set to the
                        maximum frequency bin based on the length of the time signal.''')
    parser.add_argument('--au', type=str,
                        help='Specify the amplitude units (e.g., Pa)')
    parser.add_argument('--fu', type=str,
                        help='Specify the frequency units (e.g., Hz)')
    parser.add_argument('--format', '-f', type=str, default='pdf', choices=['png', 'pdf', 'ps', 'eps', 'svg'],
                        help='''specify the image format of the saved file. Accepted formats are png, pdf,
                        ps, eps, and svg. Default format is set to pdf.''')
    parser.add_argument('--mode', type=str, choices=['light', 'dark'], required=False,
                        help='''Specify whether to view plots in light mode for daytime viewing
                        or dark mode for nighttime viewing.
                        Mode must be either \'light\' or \'dark\'.''')
    args = parser.parse_args()

    #==============================================================================
    # Get time window parameters (only the time step information is needed here)
    _, tstep, dt = get_user_windows()[1:4]
    datadir = np.load('datadir.npz')

    # Used for getting time and frequency units
    if Path('plotParams.pkl').exists():
        # plotParams.pkl is a local cache written by this program; close the
        # handle deterministically instead of leaking it.
        with open('plotParams.pkl', 'rb') as paramFile:
            plotParams = pickle.load(paramFile)
    else:
        plotParams = default_params()

    #==============================================================================
    # Load the signals: X is always plotted, X2 only when both flags are given
    if args.impulse and not args.data:
        X = _load_spectrum_impulse_responses(datadir)
        Xlabel = plotParams['impulse_title']
        Xcolor = 'c'
        X2 = None
    else:
        # default is to plot spectra of data if user does not specify
        # either args.data or args.impulse
        X = load_data(domain='time', verbose=True)
        Xlabel = plotParams['data_title']
        Xcolor = 'm'

        if args.impulse:
            X2 = _load_spectrum_impulse_responses(datadir)
            X2label = plotParams['impulse_title']
            X2color = 'c'
        else:
            X2 = None

    #==============================================================================
    # compute spectra
    if args.nseg is None:
        nseg = 1
    elif args.nseg >= 1:
        nseg = args.nseg
    else:
        sys.exit(textwrap.dedent(
                '''
                Error: Optional argument '--nseg' must be greater than or equal to one.
                '''))

    freqs, A = compute_spectra(X, tstep * dt, scaling=args.scaling, nseg=nseg)
    if X2 is not None:
        # Keep only the trailing Nt time samples of the impulse responses so
        # both signals have the same length along the time axis.
        Nt = X.shape[1]
        X2 = X2[:, -Nt:, :]
        freqs2, A2 = compute_spectra(X2, tstep * dt, scaling=args.scaling, nseg=nseg)
    else:
        freqs2 = None
        A2 = None

    if args.au is not None:
        plotParams['au'] = args.au

    if args.fu is not None:
        plotParams['fu'] = args.fu

    au = plotParams['au']
    fu = plotParams['fu']

    if args.scaling == 'amp':
        plotLabel = 'amplitude'
        plotParams['freq_title'] = 'Mean Amplitude Spectrum'
        if au != '':
            plotParams['freq_ylabel'] = 'Amplitude (%s)' %(au)
        else:
            plotParams['freq_ylabel'] = 'Amplitude'

    elif args.scaling == 'pow':
        plotLabel = 'power'
        plotParams['freq_title'] = 'Mean Power Spectrum'
        if au != '':
            plotParams['freq_ylabel'] = 'Power (%s)' %(au + '$^2$')
        else:
            plotParams['freq_ylabel'] = 'Power'

    elif args.scaling == 'psd':
        plotLabel = 'psd'
        plotParams['freq_title'] = 'Mean Power Spectral Density'
        if au != '' and fu != '':
            plotParams['freq_ylabel'] = 'Power/Frequency (%s)' %(au + '$^2/$' + fu)
        else:
            plotParams['freq_ylabel'] = 'Power/Frequency'

    #===============================================================================
    # Validate and store the frequency limits.
    #
    # Resolve a missing stored maximum first, so that a lone '--fmin' is
    # compared against a real number rather than None.
    if args.fmax is None and plotParams['fmax'] is None:
        plotParams['fmax'] = np.max(freqs)

    if args.fmin is not None and args.fmin < 0:
        sys.exit(textwrap.dedent(
                '''
                ValueError: The specified minimum frequency of the %s spectrum 
                plot must be nonnegative.
                ''' %(plotLabel)))

    if args.fmin is not None and args.fmax is not None:
        if args.fmax > args.fmin:
            plotParams['fmin'] = args.fmin
            plotParams['fmax'] = args.fmax
        else:
            sys.exit(textwrap.dedent(
                    '''
                    RelationError: The maximum frequency of the %s spectrum plot must
                    be greater than the mininum frequency.
                    ''' %(plotLabel)))

    elif args.fmin is not None:
        # Only the minimum was given: check it against the stored maximum.
        if plotParams['fmax'] > args.fmin:
            plotParams['fmin'] = args.fmin
        else:
            sys.exit(textwrap.dedent(
                    '''
                    RelationError: The specified minimum frequency of the %s spectrum 
                    plot must be less than the maximum frequency.
                    ''' %(plotLabel)))

    elif args.fmax is not None:
        # Only the maximum was given: check it against the stored minimum.
        if args.fmax > plotParams['fmin']:
            plotParams['fmax'] = args.fmax
        else:
            sys.exit(textwrap.dedent(
                    '''
                    RelationError: The specified maximum frequency of the %s spectrum 
                    plot must be greater than the minimum frequency.
                    ''' %(plotLabel)))

    #===================================================================================
    if args.mode is not None:
        plotParams['view_mode'] = args.mode

    # Persist the (possibly updated) plot parameters; the context manager
    # guarantees the file is flushed and closed.
    with open('plotParams.pkl', 'wb') as paramFile:
        pickle.dump(plotParams, paramFile, pickle.HIGHEST_PROTOCOL)

    fig, ax = setFigure(num_axes=1, mode=plotParams['view_mode'])

    plotscale = 'log' if args.scaling == 'psd' else 'linear'

    gradient_fill(freqs, A, fill_color=Xcolor, ax=ax, scale=plotscale, zorder=2)
    handles = [Line2D([0], [0], color=Xcolor, lw=4)]
    labels = [Xlabel]
    if freqs2 is not None and A2 is not None:
        # Second spectrum is drawn underneath the first (lower zorder).
        gradient_fill(freqs2, A2, fill_color=X2color, ax=ax, scale=plotscale, zorder=1)
        handles.append(Line2D([0], [0], color=X2color, lw=4))
        labels.append(X2label)

    ax.legend(handles, labels, fancybox=True, framealpha=1, shadow=True, loc='upper right')
    ax.set_title(plotParams['freq_title'], color=ax.titlecolor)

    fmin = plotParams['fmin']
    fmax = plotParams['fmax']
    if fu != '':
        ax.set_xlabel('Frequency (%s)' %(fu), color=ax.labelcolor)
    else:
        ax.set_xlabel('Frequency', color=ax.labelcolor)
    ax.set_ylabel(plotParams['freq_ylabel'], color=ax.labelcolor)
    ax.set_xlim([fmin, fmax])
    if args.scaling != 'psd':
        ax.set_ylim(bottom=0)
        ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))

    plt.tight_layout()
    fig.savefig(plotLabel + 'Spectrum.' + args.format, format=args.format, bbox_inches='tight', facecolor=fig.get_facecolor())
    plt.show()