def NLL(amp, f0, g):
    """Return the weighted chi-square of a damped-oscillator model (magnitude + phase).

    The magnitude residual is model amplitude minus ``mag[b] / magscale``.
    The phase residual is model phase (unit amplitude, offset ``phase0``)
    minus ``unphase[b]``. Each squared residual is divided by the square of
    its per-point weight, and everything is summed.

    NOTE(review): this module-level copy duplicates the nested ``NLL`` inside
    ``build_Hfuncs`` but reads ``keys``, ``b``, ``mag``, ``magscale``,
    ``phase0``, ``unphase``, ``mag_weights`` and ``phase_weights`` as module
    globals. None of these is defined at module scope in the visible source,
    so calling it here would raise NameError. It looks like a stray paste;
    confirm and remove.
    """
    fit_freqs_sel = keys[b]
    mag_resid = bu.damped_osc_amp(fit_freqs_sel, amp, f0, g) - mag[b] / magscale
    phase_resid = bu.damped_osc_phase(fit_freqs_sel, 1.0, f0, g, phase0=phase0) - unphase[b]
    mag_chi2 = mag_resid**2 / mag_weights[b]**2
    phase_chi2 = phase_resid**2 / phase_weights[b]**2
    return np.sum(mag_chi2 + phase_chi2)
def make_tf_array(freqs, Hfunc, suppress_off_diag=False, smoothing=1.0, \
                  adjust_phase=False, adjust_phase_dict=None):
    '''Sample the fitted transfer functions and return their matrix inverse.

    Builds an (Nfreq, 3, 3) complex array Harr from the per-component
    descriptions produced by build_Hfuncs(). It then inverts each 3x3
    matrix with np.linalg.inv for use in diagonalization.

           INPUTS: freqs, array of frequencies at which to sample the TF
                   Hfunc, (fits, interps) pair as returned by build_Hfuncs()
                   suppress_off_diag, if True only the diagonal components
                       are sampled, and the off-diagonal components of the
                       output are forced to exactly zero
                   smoothing, smoothing factor 's' passed to
                       UnivariateSpline for interpolated components
                   adjust_phase, if True overwrite sampled phases using
                       adjust_phase_dict
                   adjust_phase_dict, dict keyed by '{drive}{resp}' strings
                       (e.g. '01'). Each value maps frequency -> phase; that
                       phase replaces the one at the nearest bin in freqs.
                       None (default) means no adjustments.

           OUTPUTS: Hout, (Nfreq, 3, 3) complex array, the inverse of Harr
                    at each frequency'''

    # None instead of a shared mutable {} default. An empty dict means
    # "no phase adjustments", exactly as the old default did.
    if adjust_phase_dict is None:
        adjust_phase_dict = {}

    fits, interps = Hfunc

    # Accept lists as well as arrays (freqs - freq below needs an ndarray)
    freqs = np.asarray(freqs)
    Nfreq = len(freqs)
    Harr = np.zeros((Nfreq, 3, 3), dtype=np.complex128)

    ### Sample the Hfunc at the desired frequencies
    for drive in [0, 1, 2]:
        for resp in [0, 1, 2]:
            if suppress_off_diag and (drive != resp):
                continue
            fit = fits[resp][drive]

            if interps[resp][drive]:
                ### Rebuild the smoothing splines and extrapolators from the
                ### raw data stored by build_Hfuncs():
                ###   fit[0] = (freqs, mag, pts, arb_power_law)
                ###   fit[1] = (freqs, unwrapped phase, pts, arb_power_law)
                oldfreqs = fit[0][0]
                oldmag = fit[0][1]
                oldphase = fit[1][1]

                ### Uniform weights set by the scatter of the first 10
                ### points, the same weighting build_Hfuncs() uses
                mw = (1.0 / np.std(oldmag[:10])) * np.ones(len(oldfreqs))
                pw = (1.0 / np.std(oldphase[:10])) * np.ones(len(oldfreqs))
                magfunc = interp.UnivariateSpline(oldfreqs, oldmag, w=mw, k=2, s=smoothing)
                phasefunc = interp.UnivariateSpline(oldfreqs, oldphase, w=pw, k=2, s=smoothing)

                # NOTE(review): build_Hfuncs() passes order=(0, 0) to
                # make_extrapolator for the fits it plots. These calls rely
                # on make_extrapolator's default order, which is not visible
                # here. Confirm that the sampled TF matches the plotted one.
                mag_extrap = make_extrapolator(magfunc, xs=oldfreqs, ys=oldmag, \
                                               pts=fit[0][2], arb_power_law=fit[0][3])
                phase_extrap = make_extrapolator(phasefunc, xs=oldfreqs, ys=oldphase, \
                                                 pts=fit[1][2], arb_power_law=fit[1][3], \
                                                 semilogx=True)
                mag = mag_extrap(freqs)
                phase = phase_extrap(freqs)
            else:
                ### Analytic damped-oscillator model:
                ### fit = (magnitude params, phase params, phase0)
                mag = bu.damped_osc_amp(freqs, *fit[0])
                phase = bu.damped_osc_phase(freqs, *fit[1], phase0=fit[2])

            if adjust_phase:
                ### Overwrite the phase at the bin nearest each requested
                ### frequency with the user-supplied value
                adjust_key = '{:d}{:d}'.format(drive, resp)
                for freq, new_phase in adjust_phase_dict.get(adjust_key, {}).items():
                    freqind = np.argmin(np.abs(freqs - freq))
                    phase[freqind] = new_phase

            ### Harr is filled as [freq, drive, resp], i.e. transposed with
            ### respect to the fits[resp][drive] layout.
            ### NOTE(review): presumably intentional for how Hout is applied
            ### downstream; confirm.
            Harr[:, drive, resp] = mag * np.exp(1.0j * phase)

    ### A DC-bin fix (Harr[0] = Harr[1]) used to go here. It avoided
    ### singular matrices when an interpolated z-direction response goes to
    ### 0 at 0 Hz. It is disabled because newer data don't include a DC
    ### value.

    ### np.linalg.inv inverts every 3x3 matrix in the (Nfreq, 3, 3) stack
    Hout = np.linalg.inv(Harr)

    ### If the off-diagonal components are suppressed, the inversion can
    ### still leave nonzero off-diagonal entries, so explicitly set them
    ### back to 0
    if suppress_off_diag:
        Hout[:, ~np.eye(3, dtype=bool)] = 0.0 + 0.0j

    return Hout
 def NLL(amp, f0, g):
     """Return the weighted chi-square of a damped-oscillator magnitude model.

     The squared residual between the model amplitude and ``mag[b] / magscale``
     is divided by ``mag_weights[b]**2`` and summed.

     NOTE(review): stray module-level copy of the magnitude-only ``NLL``
     nested in ``build_Hfuncs``. It reads ``keys``, ``b``, ``mag``,
     ``magscale`` and ``mag_weights`` as globals, which are not defined at
     module scope in the visible source. Its leading one-space indent also
     does not match any enclosing block; confirm and remove.
     """
     model_amp = bu.damped_osc_amp(keys[b], amp, f0, g)
     resid_sq = (model_amp - mag[b] / magscale)**2
     return np.sum(resid_sq / mag_weights[b]**2)
def _fill_nans_forward(arr):
    '''Fill NaNs in a 1-D array in place and return the array.

    Each NaN takes the value of the element before it, so a run of NaNs is
    filled with the last preceding non-NaN value. Leading NaNs have no
    predecessor. They are back-filled with the first non-NaN value instead
    of wrapping around to arr[-1]. An empty or all-NaN array is returned
    unchanged.
    '''
    nan_mask = np.isnan(arr)
    if nan_mask.all() or not nan_mask.any():
        return arr
    first_valid = int(np.argmax(~nan_mask))
    arr[:first_valid] = arr[first_valid]
    for ind in range(first_valid + 1, len(arr)):
        if nan_mask[ind]:
            arr[ind] = arr[ind - 1]
    return arr


def build_Hfuncs(Hout_cal, fit_freqs = [10.,600.], fpeaks=[400.,400.,200.], \
                 weight_peak=False, weight_lowf=False, lowf_weight_fac=0.1, \
                 lowf_thresh=120., linearize=False, ignore_phase=False,
                 weight_phase=False, plot=False, plot_fits=False, \
                 plot_inits=False, plot_off_diagonal=False, \
                 grid = False, fit_osc_sum=False, deweight_peak=False, \
                 interpolate = False, max_freq=600, num_to_avg=5, \
                 real_unwrap=[[0, 1, 1], [1, 0, 1], [1, 1, 0]], \
                 derpy_unwrap=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], \
                 interps=[[0, 1, 1], [1, 0, 1], [1, 1, 0]], \
                 smoothing=1.0, amp_xlim=(), amp_ylim=(), \
                 phase_xlim=(), phase_ylim=()):
    '''Fit (or interpolate) each component of a measured transfer matrix.

           INPUTS: Hout_cal, dict mapping frequency -> complex transfer
                       matrix indexed [resp, drive] (resp, drive in 0, 1, 2)
                   fit_freqs, (low, high): only frequencies with
                       low <= f <= high are used for fitting
                   fpeaks, per-response fallback guess for the resonance
                       frequency, used when |TF| peaks below 100 Hz
                   weight_peak / deweight_peak, add
                       fac * exp(-(f - fpeak)**2 / 100) to the per-point
                       weights, with fac = -0.7 (weight_peak takes
                       precedence) or +1.0. Residuals are divided by the
                       weights, so a smaller weight makes a point count more.
                   weight_lowf, lowf_weight_fac, lowf_thresh, multiply the
                       weights of points below lowf_thresh by lowf_weight_fac
                   ignore_phase, fit the damped-oscillator model to the
                       magnitude only
                   plot, plot_fits, plot_inits, plot_off_diagonal, grid,
                       amp_xlim, amp_ylim, phase_xlim, phase_ylim: plotting
                       controls. amp_ylim may be one tuple applied to every
                       row, or a list of per-response tuples.
                       plot_off_diagonal also overlays spline fits on
                       off-diagonal interpolated components.
                   real_unwrap, [resp][drive] flags: apply np.unwrap to the
                       phase
                   derpy_unwrap, [resp][drive] flags: subtract 2*pi from
                       phases above pi/4 (applied before real_unwrap)
                   interps, [resp][drive] flags: if true, describe the
                       component with a smoothing spline plus extrapolator
                       instead of a damped-oscillator fit
                   smoothing, smoothing factor 's' for UnivariateSpline
                   linearize, weight_phase, fit_osc_sum, interpolate,
                       max_freq, num_to_avg: accepted for backward
                       compatibility but unused (interpolate is overridden
                       per component by interps)

           OUTPUTS: (fits, interps). fits[resp][drive] is either
                        ((freqs, mag, pts_mag, arb_power_law_mag),
                         (freqs, unphase, pts_phase, arb_power_law_phase))
                    for interpolated components, or
                        (popt_mag, popt_phase, phase0)
                    for damped-oscillator fits. interps is a copy of the
                    input flags. This pair is the Hfunc that make_tf_array()
                    expects.'''

    # Copy the flags so the returned object is never the shared default list
    interps = [list(row) for row in interps]

    ### Sort the calibration frequencies and stack the matrices so that
    ### mats[i] is the transfer matrix measured at keys[i]
    keys = np.array(sorted(Hout_cal.keys()))
    mats = np.array([Hout_cal[freq] for freq in keys])

    fits = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]

    if plot:
        figsize = (10,8)
        f1, axarr1 = plt.subplots(3,3, sharex=True, sharey='row', figsize=figsize)
        f2, axarr2 = plt.subplots(3,3, sharex=True, sharey='row', figsize=figsize)
        data_color, fit_color = ['k', 'C1']

    for drive in [0,1,2]:
        for resp in [0,1,2]:

            interpolate = interps[resp][drive]
            d_unwrap = derpy_unwrap[resp][drive]
            r_unwrap = real_unwrap[resp][drive]

            ### Scale factor for plotted magnitudes of interpolated
            ### components (1, i.e. no rescaling)
            plot_fac = 1

            ### TF magnitude and phase for this component, NaNs filled
            mag = _fill_nans_forward(np.abs(mats[:,resp,drive]))
            phase = _fill_nans_forward(np.angle(mats[:,resp,drive]))

            ### Unwrap the phase: the "derpy" unwrap moves every phase above
            ### pi/4 down by 2*pi; the "real" unwrap is np.unwrap
            if d_unwrap:
                unphase = phase - 2.0 * np.pi * (phase > np.pi / 4.0)
            else:
                unphase = np.copy(phase)

            if r_unwrap:
                unphase = np.unwrap(unphase)

            ### Restrict fitting to fit_freqs[0] <= f <= fit_freqs[1]
            b = (keys >= fit_freqs[0]) & (keys <= fit_freqs[1])

            if interpolate:
                ### Smoothing splines with uniform weights set by the
                ### scatter of the first 10 points in the fit window
                mw = (1.0 / np.std(mag[b][:10])) * np.ones(np.sum(b))
                pw = (1.0 / np.std(unphase[b][:10])) * np.ones(np.sum(b))
                magfunc = interp.UnivariateSpline(keys[b], mag[b], w=mw, k=2, s=smoothing)
                phasefunc = interp.UnivariateSpline(keys[b], unphase[b], w=pw, k=2, s=smoothing)

                ### Extrapolation settings: the Z response (resp == 2)
                ### uses arb_power_law=(True, True) and fewer low-end
                ### points for Z-Z; X/Y responses use (False, True)
                if resp == 2:
                    arb_power_law_mag = (True, True)
                    arb_power_law_phase = (True, True)
                    if drive == 2:
                        pts_mag = (4, 30)
                        pts_phase = (4, 20)
                    else:
                        pts_mag = (10, 30)
                        pts_phase = (10, 20)
                else:
                    arb_power_law_mag = (False, True)
                    arb_power_law_phase = (False, True)
                    pts_mag = (10, 30)
                    pts_phase = (10, 20)

                magfunc2 = make_extrapolator(magfunc, xs=keys[b], ys=mag[b], \
                                                pts=pts_mag, order=(0, 0), \
                                                arb_power_law=arb_power_law_mag)
                phasefunc2 = make_extrapolator(phasefunc, xs=keys[b], ys=unphase[b], \
                                                pts=pts_phase, order=(0, 0), \
                                                arb_power_law=arb_power_law_phase, semilogx=True)

                ### Store the raw data so make_tf_array() can rebuild the
                ### splines and extrapolators
                mag_params = (keys[b], mag[b], pts_mag, arb_power_law_mag)
                phase_params = (keys[b], unphase[b], pts_phase, arb_power_law_phase)
                fits[resp][drive] = (mag_params, phase_params)

                if plot:
                    pts = np.linspace(np.min(keys) / 2., np.max(keys) * 2., len(keys) * 100)

                    axarr1[resp,drive].loglog(keys, mag * plot_fac, 'o', ms=6, color=data_color)
                    axarr2[resp,drive].semilogx(keys, unphase, 'o', ms=6, color=data_color)

                    if plot_fits and ((resp == drive) or plot_off_diagonal):
                        axarr1[resp,drive].loglog(pts, magfunc2(pts) * plot_fac, color=fit_color, \
                                                    linestyle='--', linewidth=2, alpha=1.0)
                        axarr2[resp,drive].semilogx(pts, phasefunc2(pts), color=fit_color, \
                                                    linestyle='--', linewidth=2, alpha=1.0)

            else:
                magscale = np.mean(mag[b])

                fpeak = keys[np.argmax(mag)]
                if fpeak < 100.0:
                    fpeak = fpeaks[resp]

                ### Make initial guess based on high-pressure thermal spectra fits
                if (drive == 2) or (resp == 2):
                    ### Z-direction is considerably different than X or Y
                    g = fpeak * 2.0
                else:
                    g = fpeak * 0.15

                ### Initial amplitude: mean |TF| over the fit-window points
                ### below the one nearest 100 Hz, times (2*pi*fpeak)**2.
                ### Use at least one point so the mean is never taken over
                ### an empty slice (which would give a NaN start value).
                n_lowf = max(1, np.argmin(np.abs(keys[b] - 100.0)))
                amp0 = np.mean(mag[b][:n_lowf]) * ((2.0 * np.pi * fpeak)**2)

                ### Construct initial parameter arrays
                p0_mag = [amp0/magscale, fpeak, g]
                p0_phase = [1., fpeak, g]  ### includes arbitrary smearing amplitude

                ### Per-point weights (residuals are divided by these)
                mag_weights = np.zeros_like(keys) + 1.
                phase_weights = np.zeros(len(keys)) + 1.

                if weight_peak or deweight_peak:
                    fac = -0.7 if weight_peak else 1.0
                    peak_bump = fac * np.exp(-(keys-fpeak)**2 / (2 * 50))
                    mag_weights = mag_weights + peak_bump
                    phase_weights = phase_weights + peak_bump

                if weight_lowf:
                    ind = np.argmin(np.abs(keys - lowf_thresh))
                    mag_weights[:ind] *= lowf_weight_fac
                    phase_weights[:ind] *= lowf_weight_fac

                ### Phase offset: whichever of 0, pi, -pi is closest to the
                ### mean unwrapped phase between ~10 Hz and ~100 Hz
                lowkey = np.argmin(np.abs(keys[b]-10.0))
                highkey = np.argmin(np.abs(keys[b]-100.0))
                avg = np.mean(unphase[b][lowkey:highkey])

                phase0_candidates = np.array([0, np.pi, -1.0*np.pi])
                phase0 = phase0_candidates[np.argmin(np.abs(avg - phase0_candidates))]

                ### Weighted least-squares cost minimized by Minuit. The
                ### closures read the loop-local arrays and are used
                ### immediately below, so late binding is not an issue.
                if ignore_phase:
                    def NLL(amp, f0, g):
                        num = ( bu.damped_osc_amp(keys[b], amp, f0, g) - mag[b]/magscale )**2
                        denom = mag_weights[b]**2
                        return np.sum(num / denom)
                else:
                    def NLL(amp, f0, g):
                        num1 = ( bu.damped_osc_amp(keys[b], amp, f0, g) - mag[b]/magscale )**2
                        num2 = ( bu.damped_osc_phase(keys[b], 1.0, f0, g, phase0=phase0) \
                                        - unphase[b])**2
                        denom1 = mag_weights[b]**2
                        denom2 = phase_weights[b]**2
                        return np.sum( (num1 / denom1) + (num2 / denom2) )

                m = Minuit(NLL,
                           amp = amp0/magscale, # set start parameter
                           limit_amp = (0.0, np.inf),
                           f0 = fpeak, # set start parameter
                           limit_f0 = (0.0, np.inf),
                           g = g, # set start parameter
                           limit_g = (0, np.inf),
                           errordef = 1,
                           print_level = 1,
                           pedantic=False)
                m.migrad(ncall=500000)

                ### Undo the magnitude normalization used during the fit
                popt_mag = [m.values['amp']*magscale, m.values['f0'], m.values['g']]
                popt_phase = [1.0, m.values['f0'], m.values['g']]

                print()
                print(popt_mag)
                print()

                fits[resp][drive] = (popt_mag, popt_phase, phase0)

                if plot:
                    pts = np.linspace(np.min(keys) / 2., np.max(keys) * 2., len(keys) * 100)

                    axarr1[resp,drive].loglog(keys, mag, 'o', ms=6, color=data_color)
                    axarr2[resp,drive].semilogx(keys, unphase, 'o', ms=6, color=data_color)

                    if plot_fits:
                        fitmag = bu.damped_osc_amp(pts, *popt_mag)
                        axarr1[resp,drive].loglog(pts, fitmag, ls='-', \
                                                  color=fit_color, linewidth=2)

                        fitphase = bu.damped_osc_phase(pts, *popt_phase, phase0=phase0)
                        axarr2[resp,drive].semilogx(pts, fitphase, ls='-', \
                                                    color=fit_color, linewidth=2)

                    if plot_inits:
                        maginit = bu.damped_osc_amp(pts, *p0_mag)
                        axarr1[resp,drive].loglog(pts, maginit, ls='-', color='k', linewidth=2)

                        phaseinit = bu.damped_osc_phase(pts, *p0_phase, phase0=phase0)
                        axarr2[resp,drive].semilogx(pts, phaseinit, \
                                                    ls='-', color='k', linewidth=2)

    if plot:

        ax_to_pos = {0: 'X', 1: 'Y', 2: 'Z'}
        for drive in [0,1,2]:
            axarr1[0, drive].set_title("Drive $\\widetilde{{F}}_{:s}$".format(ax_to_pos[drive]))
            axarr1[2, drive].set_xlabel("Frequency [Hz]")
            axarr1[2, drive].set_xticks([1, 10, 100, 1000])
            if amp_xlim:
                axarr1[2, drive].set_xlim(*amp_xlim)

            axarr2[0, drive].set_title("Drive $\\widetilde{{F}}_{:s}$".format(ax_to_pos[drive]))
            axarr2[2, drive].set_xlabel("Frequency [Hz]")
            axarr2[2, drive].set_xticks([1, 10, 100, 1000])
            if phase_xlim:
                axarr2[2, drive].set_xlim(*phase_xlim)

        for response in [0,1,2]:

            mag_major_locator = LogLocator(base=10.0, numticks=30)
            mag_minor_locator = LogLocator(base=10.0, numticks=30)

            axarr1[response, 0].set_ylabel("$| \\widetilde{{R}}_{:s} / \\widetilde{{F}}_i |$ [Arb/N]"\
                                                .format(ax_to_pos[response]))
            axarr1[response, 0].yaxis.set_major_locator(mag_major_locator)
            axarr1[response, 0].yaxis.set_minor_locator(mag_minor_locator)
            axarr1[response, 0].yaxis.set_minor_formatter(NullFormatter())

            if amp_ylim:
                if isinstance(amp_ylim, tuple):
                    axarr1[response, 0].set_ylim(*amp_ylim)
                elif isinstance(amp_ylim, list):
                    axarr1[response, 0].set_ylim(*(amp_ylim[response]))
                else:
                    print('custom y-axis limits provided are not of the right type')

            axarr2[response, 0].set_ylabel("$ \\angle \\, \\widetilde{{R}}_{:s} / \\widetilde{{F}}_i $ [rad]"\
                                                .format(ax_to_pos[response]))
            axarr2[response, 0].set_yticks([-2*np.pi, -1*np.pi, 0, 1*np.pi, 2*np.pi])
            axarr2[response, 0].set_yticklabels(['-2$\\pi$', '-$\\pi$', '0', '$\\pi$', '2$\\pi$'])
            axarr2[response, 0].set_ylim(-1.3*np.pi, 1.3*np.pi)

            if phase_ylim:
                axarr2[response, 0].set_ylim(*phase_ylim)

        f1.suptitle("Magnitude of Transfer Function", fontsize=18)
        f2.suptitle("Phase of Transfer Function", fontsize=18)

        f1.tight_layout()
        f2.tight_layout()

        f1.subplots_adjust(wspace=0.065, hspace=0.1, top=0.9)
        f2.subplots_adjust(wspace=0.065, hspace=0.1, top=0.9)

        if grid:
            for d in [0,1,2]:
                for r in [0,1,2]:
                    axarr1[r,d].grid(True, which='both')
                    axarr2[r,d].grid(True, which='both')

        plt.show()

    return fits, interps
        # NOTE(review): fragment of a loop body whose enclosing loop/function
        # header is not visible here. plot_raw_data, freqs, asd, df, i, inds,
        # fit_debug, fit_freqs and filind all come from that missing context.
        if plot_raw_data:
            plt.loglog(freqs, asd)
            plt.show()

        # Initial guess for bu.damped_osc_amp. Its parameters are
        # (amp, f0, g), matching build_Hfuncs' use of that function. amp is
        # seeded from the std of df.pos_data[i] times df.nsamp; f0 = 300 and
        # g = 100 are fixed starting values.
        p0 = [np.std(df.pos_data[i]) * df.nsamp, 300, 100]

        # The try/except fallback is disabled, so a failed curve_fit
        # propagates its exception instead of silently falling back to p0.
        # try:
        popt, pcov = opti.curve_fit(bu.damped_osc_amp, freqs[inds], asd[inds], p0=p0, \
                                    maxfev=10000)
        # except:
        #     popt = p0

        # Optional visual check of the initial guess vs. the fit result;
        # blocks on input() after showing the plot.
        if fit_debug:
            plt.loglog(freqs, asd)
            plt.loglog(freqs, bu.damped_osc_amp(freqs, *p0), lw=2, ls='--', \
                        color='k', label='init guess')
            plt.loglog(freqs, bu.damped_osc_amp(freqs, *popt), lw=2, ls='--', \
                        color='r', label='fit result')
            plt.legend(fontsize=10)
            plt.show()
            input()

        # Store |f0|: the fitted f0 can come out negative, but only its
        # magnitude enters the model's f0**2 terms (presumably; confirm in bu).
        fit_freqs[filind, i] = np.abs(popt[1])

# Plot the fitted resonance frequencies in columns 0 ('X') and 1 ('Y') of
# the 2-D fit_freqs array against zavg. fit_freqs and zavg are defined
# outside the visible source.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

ax.plot(zavg, fit_freqs[:, 0], label='X freqs')
ax.plot(zavg, fit_freqs[:, 1], label='Y freqs')

ax.set_ylabel('Frequency [Hz]')
    # NOTE(review): fragment of a loop over fit results whose header is not
    # visible. gamma_fit, gamma_calc, gamma_fit_2, gammas, result, fac,
    # colors and resultind come from that missing context.
    # NOTE(review): the label shows gamma_calc / (2*pi) in Hz, but
    # gammas[0] stores gamma_calc undivided. The later gammas[1] / gammas[0]
    # ratio may therefore be off by 2*pi if gamma_calc is an angular rate;
    # confirm units.
    label = '$ \\gamma = {:0.2g}$ Hz, [$\\gamma (p) = {:0.2g}$ Hz]'\
                .format(gamma_fit, gamma_calc / (2.0 * np.pi))
    # label = '$\\omega_0 = {:0.1f}$ Hz'.format(result[6])
    gammas[0].append(gamma_calc)
    gammas[1].append(gamma_fit)
    gammas[2].append(gamma_fit_2)

    # Overlay the measured/fitted ASD (semi-transparent) and the damped
    # oscillator model evaluated at result['params'], both scaled by fac
    freqs = result['freqs']
    fit_asd = result['fit_asd']
    plt.loglog(freqs,
               fit_asd * fac,
               color=colors[resultind],
               alpha=0.7,
               label=label)
    plt.loglog(freqs, bu.damped_osc_amp(freqs, *result['params'])*fac, \
               color=colors[resultind], lw=3)
# Finish the ASD overlay figure
plt.legend(fontsize=10, ncol=2, loc='upper left')
plt.xlabel('Frequency [Hz]')
plt.ylabel('Phase ASD [arb]')
plt.xlim(1.0, 2500)
plt.ylim(1e-7, 3e12)
plt.tight_layout()

# Plot the ratios gammas[1] / gammas[0] and gammas[2] / gammas[0], i.e. each
# fitted damping relative to the calculated one, for every result.
# NOTE(review): gammas[0] holds gamma_calc without the 1 / (2*pi) applied in
# the legend labels; confirm the units agree.
gammas = np.array(gammas)
plt.figure()
plt.plot(gammas[1] / gammas[0])
plt.plot(gammas[2] / gammas[0])

plt.show()