def check_dynamics_operator_symbolic():
    """
    Print the dynamics operator of 2-, 3- and 4-pool models using symbols.

    `w_c` and `w1` are created as sympy symbols; `w_c` is used as the `w0`
    of every pool.  For each model, the model itself, its `m_eq`, `_k_op`
    and `_l_op` attributes and the result of
    `dynamics_operator(spin_model, w_c + 10.0, w1)` are printed.

    Notes: import pi, sin and cos from sympy

    Returns:
        None
    """
    # todo: make it flexible and working
    w_c, w1 = sym.symbols('w_c w1')

    # NOTE(review): `mc` has two entries for every model below, although the
    # 3- and 4-pool models pass three and four values for `w0`, `r1` and
    # `r2` - confirm that SpinModel accepts this.
    model_kwargs = (
        # 2-pool model
        dict(s0=100,
             mc=(1.0, 0.152),
             w0=((w_c, ) * 2),
             r1=(1.8, 1.0),
             r2=(32.2581, 8.4746e4),
             k=(0.05, ),
             approx=(None, 'superlorentz_approx')),
        # 3-pool model
        dict(s0=100,
             mc=(1.0, 0.152),
             w0=((w_c, ) * 3),
             r1=(1.8, 1.0, 1.2),
             r2=(32.2581, 8.4746e4, 30.0),
             k=(0.05, 0.5, 0.1),
             approx=(None, 'superlorentz_approx', None)),
        # 4-pool model
        dict(s0=100,
             mc=(1.0, 0.152),
             w0=((w_c, ) * 4),
             r1=(1.8, 1.0, 1.2, 2.0),
             r2=(32.2581, 8.4746e4, 30.0, 60.0),
             k=(0.05, 0.5, 0.1, 0.001, 0.4, 0.2),
             approx=(None, 'superlorentz_approx', None, 'gauss')),
    )

    # Each model is built right before it is printed, so a failure in a
    # later model does not suppress the output of the earlier ones.
    for kwargs in model_kwargs:
        spin_model = SpinModel(**kwargs)

        print(spin_model)
        print(spin_model.m_eq)
        print(spin_model._k_op)
        print(spin_model._l_op)
        print(dynamics_operator(spin_model, w_c + 10.0, w1))
def check_dynamics_operator():
    """
    Print the dynamics operator of numeric 2-, 3- and 4-pool models.

    The `w0` of every pool is `GAMMA['1H'] * B0`.  For each model, the model
    itself, its `m_eq`, `_k_op` and `_l_op` attributes and the result of
    `dynamics_operator()` called with that frequency plus 10.0 and an
    amplitude of 1.0 are printed.

    Notes: import pi, sin and cos from numpy

    Returns:
        None
    """
    carrier = GAMMA['1H'] * B0
    amplitude = 1.0

    # Keyword sets for the 2-, 3- and 4-pool models, in printing order.
    pool_setups = (
        {
            's0': 100,
            'mc': (1.0, 0.152),
            'w0': (carrier, ) * 2,
            'r1': (1.8, 1.0),
            'r2': (32.2581, 8.4746e4),
            'k': (0.3456, ),
            'approx': (None, 'superlorentz_approx'),
        },
        {
            'm0': [frac * 100.0 for frac in (1.0, 0.152, 0.3)],
            'w0': (carrier, ) * 3,
            'r1': (1.8, 1.0, 1.2),
            'r2': (32.2581, 8.4746e4, 30.0),
            'k': (0.05, 0.5, 0.1),
            'approx': (None, 'superlorentz_approx', None),
        },
        {
            'm0': [frac * 100.0 for frac in (1.0, 0.152, 0.3, 0.01)],
            'w0': (carrier, ) * 4,
            'r1': (1.8, 1.0, 1.2, 2.0),
            'r2': (32.2581, 8.4746e4, 30.0, 60.0),
            'k': (0.05, 0.5, 0.1, 0.001, 0.4, 0.2),
            'approx': (None, 'superlorentz_approx', None, 'gauss'),
        },
    )

    for setup in pool_setups:
        # Construct the model immediately before reporting on it.
        model = SpinModel(**setup)

        print(model)
        print(model.m_eq)
        print(model._k_op)
        print(model._l_op)
        print(dynamics_operator(model, carrier + 10.0, amplitude))
def check_mt_sequence():
    """
    Check the MT sequence on a 2-pool model.

    A kernel made of delays, spoilers and two shaped pulses is repeated 300
    times.  The repeated sequence, its propagator for the model and its
    signal for the model are printed.
    """
    carrier = GAMMA['1H'] * B0

    two_pool = SpinModel(
        s0=100,
        mc=(1.0, 0.152),
        w0=(carrier, carrier),
        r1=(1.8, 1.0),
        r2=(32.2581, 8.4746e4),
        k=(0.3456, ),
        approx=(None, 'superlorentz_approx'),
    )

    n_repeats = 300

    # Building blocks of a single repetition, in playing order.
    kernel_blocks = [
        Delay(10.0e-3),
        Spoiler(1.0),
        Pulse.shaped(
            40.0e-3, 220.0, 4000, 'gauss', None, carrier + 50.0, 'poly',
            {'fit_order': 5}),
        Delay(20.0e-3),
        Spoiler(1.0),
        Pulse.shaped(10.0e-6, 90.0, 1, 'rect', None),
        Delay(30.0e-3),
    ]
    kernel = PulseSequence(kernel_blocks, b0=3.0)
    repeated = PulseSequenceRepeated(kernel, n_repeats)

    # The signal is computed before anything is printed.
    result = repeated.signal(two_pool)

    print(repeated)
    print(repeated.propagator(two_pool))
    print(result)
 def mt_signal(x_arr, s0, mc_a, r1a, r2a, r2b, k_ab):
     """
     Compute the signal of `mt_flash` for each (freq, flip_angle) row.

     `w_c` and `mt_flash` are not parameters: they are read from the
     surrounding scope, and `mt_flash` is modified through
     `set_flip_angle()` and `set_freq()` for every row.

     Args:
         x_arr (ndarray): 2D array; each row is unpacked as
             (freq, flip_angle).
         s0: `s0` of the 2-pool model.
         mc_a: First entry of `mc`; the second entry is `1.0 - mc_a`.
         r1a: First entry of `r1`; the second entry is fixed to 1.0.
         r2a: First entry of `r2`.
         r2b: Second entry of `r2`.
         k_ab: Single entry of `k`.

     Returns:
         ndarray[float]: One signal value per row of `x_arr`.
     """
     spin_model = SpinModel(
         s0=s0,
         mc=(mc_a, 1.0 - mc_a),
         # The second pool uses `w_c` scaled by (1 - 3.5e-6); the
         # alternative with identical frequencies is `(w_c,) * 2`.
         w0=(w_c, w_c * (1 - 3.5e-6)),
         r1=(r1a, 1.0),
         r2=(r2a, r2b),
         k=(k_ab, ),
         approx=(None, 'superlorentz_approx'))
     # Force a float buffer: inheriting the dtype of `x_arr` would truncate
     # the signals whenever integer-valued inputs are passed.
     y_arr = np.zeros_like(x_arr[:, 0], dtype=float)
     for i, (freq, flip_angle) in enumerate(x_arr):
         mt_flash.set_flip_angle(flip_angle)
         mt_flash.set_freq(freq)
         y_arr[i] = mt_flash.signal(spin_model)
     return y_arr
def check_z_spectrum_sparse(
        spin_model=SpinModel(
            s0=1e8,
            mc=(0.8681, 0.1319),
            w0=((GAMMA['1H'] * 7.0, ) * 2),
            r1=(1.8, 1.0),
            r2=(32.2581, 8.4746e4),
            k=(0.3456, ),
            approx=(None, 'superlorentz_approx')),
        frequencies=np.round(np.geomspace(50, 10000, 32)),
        amplitudes=np.round(np.linspace(1, 5000, 24)),
        plot_data=True,
        save_file=None):
    """
    Check the calculation of Z-spectra on a sparse grid.

    The signal of a `MultiMtVarMGESS` sequence is computed for every
    combination of frequency and flip angle, where the flip angles are
    `amplitudes * 11.799 / 50.0`.

    Args:
        spin_model (SpinModel): Model passed to the sequence; its first
            `w0` entry is used as `w_c`.
        frequencies (ndarray[float]): Values forming the first axis of the
            result.
        amplitudes (ndarray[float]): Values scaled to flip angles, forming
            the second axis of the result.
        plot_data (bool): Draw the result as a 3D surface if True.
        save_file (string): If truthy, target passed to `np.savez()` for
            storing `frequencies`, `amplitudes` and the data.

    Returns:
        tuple: (data, frequencies, flip_angles), where `data` has shape
            (len(frequencies), len(flip_angles)).
    """
    print('Checking Z-spectrum (sparse)')
    carrier = spin_model.w0[0]

    flip_angles = amplitudes * 11.799 / 50.0

    blocks = [
        MagnetizationPreparation.shaped(
            10.0e-3, 90.0, 4000, 'gauss', {}, carrier, 'poly',
            {'fit_order': 3}),
        Delay(1.0e-3),
        Spoiler(1.0),
        PulseExc.shaped(2.1e-3, 15.0, 1, 'rect', {}),
        ReadOut(),
        Spoiler(1.0),
    ]
    # One (frequency, flip angle) pair per grid point, frequency-major.
    prep_grid = [
        (offset, angle) for offset in frequencies for angle in flip_angles
    ]
    sequence = MultiMtVarMGESS(
        pulses=blocks,
        tes=5.0e-3,
        tr=70.0e-3,
        n_r=300,
        w_c=carrier,
        preps=prep_grid)
    flat_signal = sequence.signal(spin_model)
    data = flat_signal.reshape((len(frequencies), len(flip_angles)))

    # plot results
    if plot_data:
        sns.set_style('whitegrid')
        angle_grid, log_freq_grid = np.meshgrid(
            flip_angles, np.log10(frequencies))
        axes = plt.figure().add_subplot(111, projection='3d')
        axes.set_xlabel('Pulse Amplitude (flip angle) / deg')
        axes.set_ylabel('Frequency offset / Hz (log10 scale)')
        axes.set_zlabel('Signal Intensity / arb. units')
        surface_style = {
            'cmap': mpl.cm.plasma,
            'rstride': 1,
            'cstride': 1,
            'linewidth': 0.01,
            'antialiased': False,
        }
        axes.plot_surface(angle_grid, log_freq_grid, data, **surface_style)
    if save_file:
        np.savez(save_file, frequencies, amplitudes, data)
    return data, frequencies, flip_angles
def check_approx_propagator(
        spin_model=SpinModel(s0=100,
                             mc=(0.8681, 0.1319),
                             w0=((GAMMA['1H'] * 7.0, ) * 2),
                             r1=(1.8, 1.0),
                             r2=(32.2581, 8.4746e4),
                             k=(0.3456, ),
                             approx=(None, 'superlorentz_approx')),
        flip_angle=90.0):
    """
    Test the approximation of propagators - for speeding up.

    For every pulse shape, the propagator computed in 'exact' mode is used
    as reference.  Each approximation mode is then run with empty keyword
    arguments and with every combination of its parameter values, and one
    line is printed per run: shape, mode, keyword arguments, relative error
    and elapsed time.  The relative error is the sum of the absolute
    differences to the reference divided by the sum of the absolute values
    of the reference.

    Args:
        spin_model (SpinModel): Model the propagators are computed for; its
            first `w0` entry is passed to the pulses.
        flip_angle (float): Flip angle passed to every pulse.

    Returns:
        None
    """
    w_c = spin_model.w0[0]

    # Approximation modes and the parameter values to scan for each.
    modes = {
        'linear': {
            'num_samples': tuple(range(10, 20, 5))
        },
        'interp': {
            'method': ('linear', 'cubic'),
            'num_samples': tuple(range(10, 20, 3))
        },
        'reduced': {
            'num_resamples': tuple(range(10, 20, 5))
        },
        'poly': {
            'fit_order': tuple(range(3, 6))
        }
    }

    shapes = {
        'gauss': {},
        'lorentz': {},
        'sinc': {},
        'fermi': {},
        # 'random': {},
        'cos_sin': {},
    }

    # Reference propagators, one per pulse shape.
    exact_p_ops = {}
    for shape, shape_kwargs in shapes.items():
        pulse = Pulse.shaped(40.0e-3, flip_angle, 4000, shape, shape_kwargs,
                             w_c, 'exact', {})
        exact_p_ops[shape] = pulse.propagator(spin_model)

    for shape, shape_kwargs in shapes.items():
        for mode, mode_params in modes.items():
            names = list(mode_params)
            # Start with the mode's defaults (no keyword arguments), then
            # scan the full grid of parameter combinations.
            kwargs_items = [{}]
            kwargs_items.extend(
                dict(zip(names, values))
                for values in itertools.product(*mode_params.values()))
            for kws in kwargs_items:
                pulse = Pulse.shaped(40.0e-3, flip_angle, 4000, shape,
                                     shape_kwargs, w_c, mode, kws)
                begin_time = datetime.datetime.now()
                p_op = pulse.propagator(spin_model)
                # Subtracting two datetimes already yields a timedelta;
                # wrapping it in timedelta() raises TypeError.
                elapsed = datetime.datetime.now() - begin_time
                rel_error = np.sum(np.abs(exact_p_ops[shape] - p_op)) / \
                            np.sum(np.abs(exact_p_ops[shape]))
                print('{:>8s}, {:>8s}, {:>48s},\t{:.3e}, {}'.format(
                    shape, mode, str(kws), rel_error, elapsed))