示例#1
0
def single_mode_proportional_to_time(**kwargs):
    """Return WaveformModes object with a single nonzero mode, proportional to time

    The waveform output by this function will have just one nonzero mode.  The behavior of that mode will be
    particularly simple; it will just be proportional to time.

    Parameters
    ----------
    s : int, optional
        Spin weight of the waveform field.  Default is -2.
    ell, m : int, optional
        The (ell, m) values of the nonzero mode in the returned waveform.  Default value is (abs(s), -ell).
    ell_min, ell_max : int, optional
        Smallest and largest ell values present in the output.  Default values are abs(s) and 8.
    data_type : int, optional
        Default value is whichever psi_n corresponds to the input spin.  It is important to choose these, rather than
        `h` or `sigma` for the analytical solution to translations, which doesn't account for the direct contribution
        of supertranslations (as opposed to the indirect contribution, which involves moving points around).
    t_0, t_1 : float, optional
        Beginning and end of time.  Default values are -20. and 20.
    dt : float, optional
        Time step.  Default value is 0.1.
    beta : complex, optional
        Constant of proportionality such that nonzero mode is beta*t.  Default is 1.

    Any other keyword arguments are ignored, with a warning listing them.

    Returns
    -------
    scri.WaveformModes
        Inertial-frame waveform with `r_is_scaled_out` and `m_is_scaled_out` set to True.

    Raises
    ------
    ValueError
        If the requested mode cannot be stored in the output, i.e. unless
        `ell_min <= ell <= ell_max` and `abs(m) <= ell`.

    """
    s = kwargs.pop("s", -2)
    ell = kwargs.pop("ell", abs(s))
    m = kwargs.pop("m", -ell)
    ell_min = kwargs.pop("ell_min", abs(s))
    ell_max = kwargs.pop("ell_max", 8)
    # sf.LM_index does not appear to range-check its arguments (assumed; confirm), so an
    # out-of-range (ell, m) could silently write into the wrong column (NumPy also wraps
    # negative indices).  Reject such requests explicitly.
    if not (ell_min <= ell <= ell_max and abs(m) <= ell):
        raise ValueError(
            f"Mode (ell, m)=({ell}, {m}) is not representable with ell_min={ell_min}, ell_max={ell_max}"
        )
    data_type = kwargs.pop("data_type",
                           scri.DataType[scri.SpinWeights.index(s)])
    t_0 = kwargs.pop("t_0", -20.0)
    t_1 = kwargs.pop("t_1", 20.0)
    dt = kwargs.pop("dt", 1.0 / 10.0)
    # `t_1 + dt` makes the end point inclusive (up to floating-point rounding in arange)
    t = np.arange(t_0, t_1 + dt, dt)
    n_times = t.size
    beta = kwargs.pop("beta", 1.0)
    data = np.zeros((n_times, sf.LM_total_size(ell_min, ell_max)),
                    dtype=complex)
    data[:, sf.LM_index(ell, m, ell_min)] = beta * t

    if kwargs:
        import pprint

        warnings.warn(
            f"\nUnused kwargs passed to this function:\n{pprint.pformat(kwargs, width=1)}"
        )

    return scri.WaveformModes(
        t=t,
        data=data,
        ell_min=ell_min,
        ell_max=ell_max,
        frameType=scri.Inertial,
        dataType=data_type,
        r_is_scaled_out=True,
        m_is_scaled_out=True,
    )
示例#2
0
def linear_waveform(begin=-10., end=100., n_times=1000, ell_min=2, ell_max=8):
    """Return a Corotating-frame `h` waveform whose modes are linear in time.

    Mode (ell, m) holds (m - 1j*m) * t.  The frame rotates about a fixed, randomly drawn axis
    by the angle omega*t, with omega chosen so that it completes four full turns over the
    time span [begin, end].

    Note that this reseeds NumPy's global random number generator (with a fixed seed) before
    drawing the rotation axis, so the output is reproducible across runs.
    """
    import zlib

    # Built-in hash() of a str is salted per interpreter process (see PYTHONHASHSEED), so
    # hash('linear_waveform') gave a different seed on every run.  crc32 is stable, and its
    # result (0 <= seed < 2**32) is always an acceptable NumPy seed.
    np.random.seed(zlib.crc32(b"linear_waveform"))
    axis = np.quaternion(0., *np.random.uniform(-1, 1, size=3)).normalized()
    t = np.linspace(begin, end, num=n_times)
    omega = 2 * np.pi * 4 / (t[-1] - t[0])
    frame = np.array([np.exp(axis * (omega * t_i / 2)) for t_i in t])
    lm = np.array([[ell, m] for ell in range(ell_min, ell_max + 1)
                   for m in range(-ell, ell + 1)])
    data = np.empty((t.shape[0], lm.shape[0]), dtype=complex)
    for i, m in enumerate(lm[:, 1]):
        # N.B.: This form is used in test_linear_interpolation; if you
        # change it here, you must change it there.
        data[:, i] = (m - 1j * m) * t
    W = scri.WaveformModes(t=t,
                           frame=frame,
                           data=data,
                           ell_min=min(lm[:, 0]),
                           ell_max=max(lm[:, 0]),
                           history=['# Called from linear_waveform'],
                           frameType=scri.Corotating,
                           dataType=scri.h,
                           r_is_scaled_out=True,
                           m_is_scaled_out=True)
    return W
示例#3
0
def random_waveform(begin=-10.,
                    end=100.,
                    n_times=1000,
                    ell_min=None,
                    ell_max=8,
                    dataType=scri.h):
    """Return a Corotating-frame waveform with random times, frame, and mode data.

    Times are `n_times` sorted uniform draws from [begin, end).  Each frame is a normalized
    quaternion with uniformly drawn components.  Mode data are complex numbers whose real and
    imaginary parts are standard-normal draws.

    Parameters
    ----------
    ell_min : int or None, optional
        Smallest ell stored.  If None (default), uses abs(spin weight of `dataType`).
    ell_max : int, optional
        Largest ell stored.  Default is 8.
    dataType : int, optional
        Default is `scri.h`.

    Note that this reseeds NumPy's global random number generator with a fixed seed, so the
    output is reproducible across runs.
    """
    import zlib

    # Built-in hash() of a str is salted per interpreter process (see PYTHONHASHSEED), so
    # hash('random_waveform') gave a different seed on every run.  crc32 is stable, and its
    # result (0 <= seed < 2**32) is always an acceptable NumPy seed.
    np.random.seed(zlib.crc32(b"random_waveform"))
    spin_weight = scri.SpinWeights[scri.DataType.index(dataType)]
    if ell_min is None:
        ell_min = abs(spin_weight)
    # Number of (ell, m) pairs with ell_min <= ell <= ell_max
    n_modes = (ell_max * (ell_max + 2) - ell_min**2 + 1)
    t = np.sort(np.random.uniform(begin, end, size=n_times))
    frame = np.array([
        np.quaternion(*np.random.uniform(-1, 1, 4)).normalized() for _ in t
    ])
    # Draw (real, imag) pairs and reinterpret the last axis as one complex number
    data = np.random.normal(size=(n_times, n_modes, 2)).view(complex)[:, :, 0]
    W = scri.WaveformModes(t=t,
                           frame=frame,
                           data=data,
                           ell_min=ell_min,
                           ell_max=ell_max,
                           history=['# Called from random_waveform'],
                           frameType=scri.Corotating,
                           dataType=dataType,
                           r_is_scaled_out=True,
                           m_is_scaled_out=False)
    return W
示例#4
0
def test_rotations_of_0_0_mode(Rs):
    """Rotating the decomposition basis must leave a pure (ell, m) = (0, 0) waveform unchanged."""
    repeats = 10
    waveform_in = delta_waveform(0,
                                 0,
                                 begin=-10.0,
                                 end=100.0,
                                 n_times=repeats * len(Rs),
                                 ell_min=0,
                                 ell_max=8)
    assert waveform_in.ensure_validity(alter=False)

    # Each rotor is applied to `repeats` consecutive time steps
    waveform_out = scri.WaveformModes(waveform_in)
    rotors = np.array([rotor for rotor in Rs for _ in range(repeats)])
    waveform_out.rotate_decomposition_basis(rotors)
    assert waveform_out.ensure_validity(alter=False)

    assert np.array_equal(waveform_out.t, waveform_in.t)
    assert np.max(np.abs(waveform_out.frame - rotors)) == 0.0
    assert np.array_equal(waveform_out.data, waveform_in.data)
    assert waveform_out.ell_min == waveform_in.ell_min
    assert waveform_out.ell_max == waveform_in.ell_max
    assert np.array_equal(waveform_out.LM, waveform_in.LM)

    # History entries must agree once the output's name is mapped back to the input's
    # name; entries that are both "# " comments are accepted as-is.
    name_out = f"{type(waveform_out).__name__}_{str(waveform_out.num)}"
    name_in = f"{type(waveform_in).__name__}_{str(waveform_in.num)}"
    for entry_in, entry_out in zip(waveform_in.history, waveform_out.history[:-1]):
        both_comments = entry_in.startswith("# ") and entry_out.startswith("# ")
        assert entry_in == entry_out.replace(name_out, name_in) or both_comments

    for attribute in ("frameType", "dataType", "r_is_scaled_out", "m_is_scaled_out"):
        assert getattr(waveform_out, attribute) == getattr(waveform_in, attribute)
    assert waveform_out.num != waveform_in.num
示例#5
0
def test_pickling():
    """A default WaveformModes survives a pickle round trip exactly, private counter included."""
    import pickle
    original = scri.WaveformModes()
    restored = pickle.loads(pickle.dumps(original))
    # The name-mangled instance counter must be stored on both objects
    for waveform in (original, restored):
        assert '_WaveformBase__num' in waveform.__dict__
    assert original._allclose(restored, rtol=0, atol=0, compare_history_beginnings=True)
示例#6
0
def test_rotations_of_each_mode_individually(Rs):
    """Rotating a single-mode waveform must reproduce the matching Wigner D-matrix elements.

    For every (ell, Mp) with 0 <= ell <= ell_max and -ell <= Mp <= ell, a waveform that is 1 in
    mode (ell, Mp) and 0 elsewhere is rotated by the rotors `Rs` (one per time step).  The
    resulting ell-block must equal the D elements indexed (ell, Mp, m) for m = -ell..ell, and
    every other ell-block must stay exactly zero.
    """
    ell_min = 0
    ell_max = 8  # sf.ell_max is just too much; this test is too slow, and ell=8 should be fine
    R_basis = Rs
    Ds = np.empty((len(Rs), sf.LMpM_total_size(ell_min, ell_max)),
                  dtype=complex)
    for i, R in enumerate(Rs):
        Ds[i, :] = sf.Wigner_D_matrices(R, ell_min, ell_max)
    for ell in range(ell_max + 1):
        first_zeros = np.zeros((len(Rs), sf.LM_total_size(ell_min, ell - 1)),
                               dtype=complex)
        later_zeros = np.zeros((len(Rs), sf.LM_total_size(ell + 1, ell_max)),
                               dtype=complex)
        # Include Mp = ell: range(-ell, ell) skipped the highest-m mode of every ell
        # (and the (0, 0) mode entirely).
        for Mp in range(-ell, ell + 1):
            W_in = delta_waveform(ell,
                                  Mp,
                                  begin=-10.0,
                                  end=100.0,
                                  n_times=len(Rs),
                                  ell_min=ell_min,
                                  ell_max=ell_max)
            # Now, the modes are f^{ell, m} = delta^{ell, m}_{ell, Mp}
            assert W_in.ensure_validity(alter=False)
            W_out = scri.WaveformModes(W_in)
            W_out.rotate_decomposition_basis(Rs)
            assert W_out.ensure_validity(alter=False)
            assert np.array_equal(W_out.t, W_in.t)
            assert np.max(np.abs(W_out.frame - R_basis)) == 0.0
            # Offset of the D elements (ell, Mp, m=-ell..ell) within each row of Ds
            i_D0 = sf.LMpM_index(ell, Mp, -ell, ell_min)
            assert np.array_equal(
                W_out.data[:, :sf.LM_total_size(ell_min, ell - 1)],
                first_zeros)
            if ell < ell_max:
                assert np.array_equal(
                    W_out.data[:,
                               sf.LM_total_size(ell_min, ell - 1):-sf.
                               LM_total_size(ell + 1, ell_max)],
                    Ds[:, i_D0:i_D0 + (2 * ell + 1)],
                )
                assert np.array_equal(
                    W_out.data[:, -sf.LM_total_size(ell + 1, ell_max):],
                    later_zeros)
            else:
                # No higher-ell modes exist, so there is no trailing zero block to slice off
                assert np.array_equal(
                    W_out.data[:, sf.LM_total_size(ell_min, ell - 1):],
                    Ds[:, i_D0:i_D0 + (2 * ell + 1)])
            assert W_out.ell_min == W_in.ell_min
            assert W_out.ell_max == W_in.ell_max
            assert np.array_equal(W_out.LM, W_in.LM)
            # NOTE(review): test_rotations_of_0_0_mode builds these names as
            # f"{name}_{num}" (with an underscore); this form has none and may make the
            # replace a no-op — confirm which matches WaveformBase's naming.
            for h_in, h_out in zip(W_in.history, W_out.history[:-1]):
                assert h_in == h_out.replace(
                    type(W_out).__name__ + str(W_out.num),
                    type(W_in).__name__ + str(W_in.num))
            assert W_out.frameType == W_in.frameType
            assert W_out.dataType == W_in.dataType
            assert W_out.r_is_scaled_out == W_in.r_is_scaled_out
            assert W_out.m_is_scaled_out == W_in.m_is_scaled_out
            assert W_out.num != W_in.num
示例#7
0
def single_mode_constant_rotation(**kwargs):
    """Return WaveformModes object with a single nonzero mode, with phase proportional to time

    The waveform output by this function will have just one nonzero mode.  The behavior of that mode will be fairly
    simple; it will be given by exp(i*omega*t).  Note that omega can be complex, which gives damping.

    Parameters
    ----------
    s : int, optional
        Spin weight of the waveform field.  Default is -2.
    ell, m : int, optional
        The (ell, m) values of the nonzero mode in the returned waveform.  Default value is (abs(s), -ell).
    ell_min, ell_max : int, optional
        Smallest and largest ell values present in the output.  Default values are abs(s) and 8.
    data_type : int, optional
        Default value is whichever psi_n corresponds to the input spin.  It is important to choose these, rather than
        `h` or `sigma` for the analytical solution to translations, which doesn't account for the direct contribution
        of supertranslations (as opposed to the indirect contribution, which involves moving points around).
    t_0, t_1 : float, optional
        Beginning and end of time.  Default values are -20. and 20.
    dt : float, optional
        Time step.  Default value is 0.1.
    omega : complex, optional
        Constant of proportionality such that nonzero mode is exp(i*omega*t).  Note that this can be complex, which
        implies damping.  Default is 0.5.

    Any other keyword arguments are ignored, with a warning listing them.

    Returns
    -------
    scri.WaveformModes
        Inertial-frame waveform with `r_is_scaled_out` and `m_is_scaled_out` set to True.

    Raises
    ------
    ValueError
        If the requested mode cannot be stored in the output, i.e. unless
        `ell_min <= ell <= ell_max` and `abs(m) <= ell`.

    """
    s = kwargs.pop('s', -2)
    ell = kwargs.pop('ell', abs(s))
    m = kwargs.pop('m', -ell)
    ell_min = kwargs.pop('ell_min', abs(s))
    ell_max = kwargs.pop('ell_max', 8)
    # sf.LM_index does not appear to range-check its arguments (assumed; confirm), so an
    # out-of-range (ell, m) could silently write into the wrong column (NumPy also wraps
    # negative indices).  Reject such requests explicitly.
    if not (ell_min <= ell <= ell_max and abs(m) <= ell):
        raise ValueError(
            f"Mode (ell, m)=({ell}, {m}) is not representable with ell_min={ell_min}, ell_max={ell_max}"
        )
    data_type = kwargs.pop('data_type',
                           scri.DataType[scri.SpinWeights.index(s)])
    t_0 = kwargs.pop('t_0', -20.0)
    t_1 = kwargs.pop('t_1', 20.0)
    dt = kwargs.pop('dt', 1. / 10.)
    # `t_1 + dt` makes the end point inclusive (up to floating-point rounding in arange)
    t = np.arange(t_0, t_1 + dt, dt)
    n_times = t.size
    omega = complex(kwargs.pop('omega', 0.5))
    data = np.zeros((n_times, sf.LM_total_size(ell_min, ell_max)),
                    dtype=complex)
    data[:, sf.LM_index(ell, m, ell_min)] = np.exp(1j * omega * t)

    if kwargs:
        import pprint

        warnings.warn(
            f"\nUnused kwargs passed to this function:\n{pprint.pformat(kwargs, width=1)}"
        )

    return scri.WaveformModes(t=t,
                              data=data,
                              ell_min=ell_min,
                              ell_max=ell_max,
                              frameType=scri.Inertial,
                              dataType=data_type,
                              r_is_scaled_out=True,
                              m_is_scaled_out=True)
示例#8
0
def delta_waveform(ell, m, begin=-10., end=100., n_times=1000, ell_min=2, ell_max=8):
    """WaveformModes with 1 in selected slot and 0 elsewhere

    Parameters
    ----------
    ell, m : int
        The mode that is set to 1 at every time step.
    begin, end : float, optional
        First and last time values (both included).  Defaults are -10. and 100.
    n_times : int, optional
        Number of evenly spaced time steps.  Default is 1000.
    ell_min, ell_max : int, optional
        Smallest and largest ell values stored.  Defaults are 2 and 8.

    Returns
    -------
    scri.WaveformModes
        Inertial-frame `psi4` waveform with `r_is_scaled_out=False` and `m_is_scaled_out=True`.

    Raises
    ------
    ValueError
        Unless `ell_min <= ell <= ell_max` and `abs(m) <= ell`.
    """
    # sf.LM_index does not appear to range-check its arguments (assumed; confirm), so an
    # out-of-range (ell, m) could silently put the 1 into the wrong column.
    if not (ell_min <= ell <= ell_max and abs(m) <= ell):
        raise ValueError(
            f"Mode (ell, m)=({ell}, {m}) is not representable with ell_min={ell_min}, ell_max={ell_max}"
        )
    # Number of (ell, m) pairs with ell_min <= ell <= ell_max
    n_modes = (ell_max * (ell_max + 2) - ell_min ** 2 + 1)
    t = np.linspace(begin, end, num=n_times)
    data = np.zeros((n_times, n_modes), dtype=complex)
    data[:, sf.LM_index(ell, m, ell_min)] = 1.0 + 0.0j
    W = scri.WaveformModes(t=t, data=data,  # frame=frame,
                           ell_min=ell_min, ell_max=ell_max,
                           history=['# Called from delta_waveform'],
                           frameType=scri.Inertial, dataType=scri.psi4,
                           r_is_scaled_out=False, m_is_scaled_out=True)
    return W
示例#9
0
def constant_waveform(begin=-10., end=100., n_times=1000, ell_min=2, ell_max=8):
    """Return a Corotating-frame `h` waveform whose modes are constant in time.

    Mode (ell, m) holds the value m - 1j*m at every time step, and every frame entry is
    `quaternion.x`.
    """
    times = np.linspace(begin, end, num=n_times)
    frame = np.array([quaternion.x] * len(times))
    lm = np.array([[ell, m] for ell in range(ell_min, ell_max + 1) for m in range(-ell, ell + 1)])
    m_values = lm[:, 1]
    data = np.empty((times.shape[0], lm.shape[0]), dtype=complex)
    # Broadcast one row of per-mode constants down every time step
    data[:] = m_values - 1j * m_values
    return scri.WaveformModes(t=times, frame=frame, data=data,
                              ell_min=min(lm[:, 0]), ell_max=max(lm[:, 0]),
                              history=['# Called from constant_waveform'],
                              frameType=scri.Corotating, dataType=scri.h,
                              r_is_scaled_out=True, m_is_scaled_out=True)
示例#10
0
def test_empty_WaveformModes():
    """A default-constructed WaveformModes is valid, empty, and carries 'unknown' metadata."""
    empty = scri.WaveformModes()
    assert empty.ensure_validity(alter=False)
    expected_shapes = {"t": (0, ), "frame": (0, ), "data": (0, 0), "LM": (0, 2)}
    for attribute, shape in expected_shapes.items():
        assert getattr(empty, attribute).shape == shape
    assert empty.ell_min == 0
    assert empty.ell_max == -1
    assert empty.frameType == scri.UnknownFrameType
    assert empty.dataType == scri.UnknownDataType
    # Both scaling flags default to falsy
    assert not empty.r_is_scaled_out
    assert not empty.m_is_scaled_out
示例#11
0
def modes_constructor(constructor_statement, data_functor, **kwargs):
    """WaveformModes object filled with data from the input functor

    The keyword arguments listed below are used to build the WaveformModes object, with somewhat more
    convenient defaults than the WaveformModes initializer.  Any other keyword arguments are ignored,
    with a warning listing them.

    Parameters
    ----------
    constructor_statement : str
        This is a string form of the function call used to create the object.  This is passed to the WaveformBase
        initializer as the parameter of the same name.  See the docstring for more information.
    data_functor : function
        Takes a 1-d array of time values and an array of (ell, m) values and returns the complex array of data.
    t : float array, optional
        Time values of the data.  Default is `np.linspace(-10., 100., num=1101)`.
    frame : quaternion array, optional
        Frame of the data.  Default is an empty array.
    frameType : int, optional
        Default is `scri.Inertial`.
    dataType : int, optional
        Default is `scri.h`.
    r_is_scaled_out, m_is_scaled_out : bool, optional
        Both default to True.
    ell_min, ell_max : int, optional
        Smallest and largest ell value present in the data.  Defaults are the absolute value of the spin weight of
        `dataType` (2 for the default `scri.h`) and 8.

    Returns
    -------
    scri.WaveformModes

    """
    t = np.array(kwargs.pop("t", np.linspace(-10.0, 100.0, num=1101)),
                 dtype=float)
    frame = np.array(kwargs.pop("frame", []), dtype=np.quaternion)
    frameType = int(kwargs.pop("frameType", scri.Inertial))
    dataType = int(kwargs.pop("dataType", scri.h))
    r_is_scaled_out = bool(kwargs.pop("r_is_scaled_out", True))
    m_is_scaled_out = bool(kwargs.pop("m_is_scaled_out", True))
    ell_min = int(kwargs.pop("ell_min", abs(scri.SpinWeights[dataType])))
    ell_max = int(kwargs.pop("ell_max", 8))
    if kwargs:
        import pprint

        warnings.warn(
            f"\nUnused kwargs passed to this function:\n{pprint.pformat(kwargs, width=1)}"
        )
    data = data_functor(t, sf.LM_range(ell_min, ell_max))
    w = scri.WaveformModes(
        t=t,
        frame=frame,
        data=data,
        # Record the actual origin (previously mislabeled as constant_waveform)
        history=["# Called from modes_constructor"],
        frameType=frameType,
        dataType=dataType,
        r_is_scaled_out=r_is_scaled_out,
        m_is_scaled_out=m_is_scaled_out,
        constructor_statement=constructor_statement,
        ell_min=ell_min,
        ell_max=ell_max,
    )
    return w
def get_waveforms_from_cce_volume(filename, lmax=8):
    """Extract the CCE waveform quantities from a CCE volume file and write each one out.

    For each of h, psi4, psi3, psi2, psi1, and psi0, the mode data for 0 <= ell <= lmax are
    gathered over all time steps, wrapped in a `scri.WaveformModes`, and passed to
    `scri.SpEC.write_to_h5` with the output name "CCE_" + (the part of `filename` after its
    last "CceVolume").

    Parameters
    ----------
    filename : str
        Path to an HDF5 file containing a `cce_scri_data.vol` group.  Each subgroup of it
        is one time step, with an `observation_value` attribute (the time) and one dataset
        per quantity ("Strain" for h, the `scri.DataNames` entry for the others).
    lmax : int, optional
        Largest ell to extract.  Default is 8.

    Raises
    ------
    KeyError
        If the file has no `cce_scri_data.vol` group.
    ValueError
        If the stored data have fewer columns than needed for `lmax`.
    """
    # scri takes a while to import, so only import it when needed
    import scri
    from scri import h, psi4, psi3, psi2, psi1, psi0

    output_extension = "CCE_" + filename.split("CceVolume")[-1]
    # Modes 0 <= ell <= lmax; each is stored as an interleaved (real, imag) pair at pair
    # index ell*(ell+1)+m, so these are exactly the first n_modes pairs, in (ell, m) order.
    n_modes = (lmax + 1) ** 2

    with h5py.File(filename, "r") as cce_volume_file:
        print("Preparing to extract waveforms from {}".format(filename))
        scri_data = cce_volume_file.get("cce_scri_data.vol")
        if scri_data is None:
            raise KeyError(f'No "cce_scri_data.vol" group found in {filename}')
        time_ids_and_values = sorted(
            ((x, scri_data.get(x).attrs["observation_value"]) for x in scri_data.keys()),
            key=lambda x: x[1],
        )
        time_ids = [time_id for time_id, _ in time_ids_and_values]
        time_set = np.array([value for _, value in time_ids_and_values])

        for data_type in [h, psi4, psi3, psi2, psi1, psi0]:
            # Strain is stored under "Strain" rather than under its scri.DataNames entry
            dataset_name = "Strain" if data_type == h else scri.DataNames[data_type]
            raw_data = np.array(
                [scri_data[time_id][dataset_name][()] for time_id in time_ids])
            # Slicing would silently truncate; indexing used to raise, so check explicitly
            if raw_data.ndim != 2 or raw_data.shape[1] < 2 * n_modes:
                raise ValueError(
                    f"{dataset_name} data of shape {raw_data.shape} cannot hold modes up to ell={lmax}"
                )
            mode_data = (raw_data[:, 0:2 * n_modes:2] +
                         1j * raw_data[:, 1:2 * n_modes:2])

            WM = scri.WaveformModes(
                t=time_set,
                data=mode_data,
                ell_min=0,
                ell_max=lmax,
                dataType=data_type,
            )
            print("Writing {}...".format(WM.data_type_string), end='')
            scri.SpEC.write_to_h5(WM, output_extension)
            print("Done")
示例#13
0
def single_mode_proportional_to_time_supertranslated(**kwargs):
    """Return WaveformModes as in single_mode_proportional_to_time, with analytical supertranslation

    This function constructs the same basic object as `single_mode_proportional_to_time`, but then applies an
    analytical supertranslation.  The arguments to this function are the same as to the other, with two additions:

    Additional parameters
    ---------------------
    supertranslation : complex array, optional
        Spherical-harmonic modes of the supertranslation to apply to the waveform, ordered as
        `sf.LM_range(0, ell_max)`; its length must therefore be (ell_max+1)**2 for some integer ell_max.  Its ell=1
        modes are overwritten by `space_translation` if that is present.  Default is an empty array (no
        supertranslation).
    space_translation : float array of length 3, optional
        The 3-vector representing the displacement to apply to the waveform.  If given, its (negated) ell=1 mode
        representation replaces entries 1:4 of `supertranslation`, which is zero-padded to length 4 if needed.
        By default no space translation is applied.

    Raises
    ------
    ValueError
        If the requested mode cannot be stored in the output (unless `ell_min <= ell <= ell_max` and
        `abs(m) <= ell`), or if `supertranslation` does not have a perfect-square number of elements.

    """
    s = kwargs.pop("s", -2)
    ell = kwargs.pop("ell", abs(s))
    m = kwargs.pop("m", -ell)
    ell_min = kwargs.pop("ell_min", abs(s))
    ell_max = kwargs.pop("ell_max", 8)
    # sf.LM_index does not appear to range-check its arguments (assumed; confirm), so an
    # out-of-range (ell, m) could silently write into the wrong column.  Reject it explicitly.
    if not (ell_min <= ell <= ell_max and abs(m) <= ell):
        raise ValueError(
            f"Mode (ell, m)=({ell}, {m}) is not representable with ell_min={ell_min}, ell_max={ell_max}"
        )
    data_type = kwargs.pop("data_type",
                           scri.DataType[scri.SpinWeights.index(s)])
    t_0 = kwargs.pop("t_0", -20.0)
    t_1 = kwargs.pop("t_1", 20.0)
    dt = kwargs.pop("dt", 1.0 / 10.0)
    t = np.arange(t_0, t_1 + dt, dt)
    n_times = t.size
    beta = kwargs.pop("beta", 1.0)
    data = np.zeros((n_times, sf.LM_total_size(ell_min, ell_max)),
                    dtype=complex)
    data[:, sf.LM_index(ell, m, ell_min)] = beta * t
    # np.array copies, so resizing below never touches the caller's array
    supertranslation = np.array(kwargs.pop("supertranslation",
                                           np.array([], dtype=complex)),
                                dtype=complex)
    if "space_translation" in kwargs:
        if supertranslation.size < 4:
            supertranslation.resize((4, ))
        supertranslation[1:4] = -sf.vector_as_ell_1_modes(
            kwargs.pop("space_translation"))
    # A full set of modes up to some ell_max has (ell_max+1)**2 elements
    supertranslation_ell_max = int(math.sqrt(supertranslation.size) - 1)
    if supertranslation_ell_max * (supertranslation_ell_max +
                                   2) + 1 != supertranslation.size:
        raise ValueError(
            f"Bad number of elements in supertranslation: {supertranslation.size}"
        )
    # Each nonzero supertranslation mode (ellpp, mpp) couples the (ell, m) mode into modes
    # (ellp, m + mpp) through a product of Wigner 3j symbols.
    for i, (ellpp, mpp) in enumerate(sf.LM_range(0, supertranslation_ell_max)):
        if supertranslation[i] != 0.0:
            mp = m + mpp
            for ellp in range(ell_min, min(ell_max, (ell + ellpp)) + 1):
                if ellp >= abs(mp):
                    addition = (beta * supertranslation[i] * math.sqrt(
                        ((2 * ellpp + 1) * (2 * ell + 1) *
                         (2 * ellp + 1)) / (4 * math.pi)) *
                                sf.Wigner3j(ellpp, ell, ellp, 0, -s, s) *
                                sf.Wigner3j(ellpp, ell, ellp, mpp, m, -mp))
                    if (s + mp) % 2 == 1:
                        addition *= -1
                    data[:, sf.LM_index(ellp, mp, ell_min)] += addition

    if kwargs:
        import pprint

        warnings.warn(
            f"\nUnused kwargs passed to this function:\n{pprint.pformat(kwargs, width=1)}"
        )

    return scri.WaveformModes(
        t=t,
        data=data,
        ell_min=ell_min,
        ell_max=ell_max,
        frameType=scri.Inertial,
        dataType=data_type,
        r_is_scaled_out=True,
        m_is_scaled_out=True,
    )
示例#14
0
                                     0) * np.cos(phi) / 40.0
    for ell in range(ell_min, ell_max + 1):
        for m in range(-ell, ell + 1):
            data[:, sf.LM_index(ell, m, ell_min)] = pn_leading_order_amplitude(
                ell, m, x,
                mass_ratio=mass_ratio) * (1 + np.sign(m) * modulation)

    # Apply ringdown (mode amplitudes are constant after t_merger)
    data *= ringdown[:, np.newaxis]

    h_corot = scri.WaveformModes(
        t=t,
        frame=frame,
        data=data,
        ell_min=ell_min,
        ell_max=ell_max,
        frameType=scri.Corotating,
        dataType=data_type,
        r_is_scaled_out=True,
        m_is_scaled_out=True,
    )

    if inertial:
        return h_corot.to_inertial_frame()
    else:
        return h_corot


def pn_leading_order_amplitude(ell, m, x, mass_ratio=1.0):
    """Return the leading-order amplitude of r*h/M in PN theory
def load(file_name, ignore_validation=False, check_md5=True, **kwargs):
    """Load a waveform in RPXMB format

    Parameters
    ----------
    file_name : str
        Relative or absolute path to the input HDF5 file.  If this string contains
        but does not *end* with `'.h5'`, the remainder of the string is taken to be
        the group within the HDF5 file in which the data is stored.  Also note that
        a JSON file is expected in the same location, with `.h5` replaced by
        `.json` (and the corresponding data must be stored under the `group` key if
        relevant).
    ignore_validation : bool or None, optional
        Controls how validation failures are handled: a missing JSON file, an
        unexpected `sxs_format` in the JSON, or mismatched validation keys
        (`h5_file_size`, `n_times`, and `md5sum`).  If `True`, they are silently
        ignored, so the JSON file need not be present.  If `None`, each failure
        issues a warning instead.  If `False` (the default), each failure raises
        `ValueError`.  In that case these keys are all required, with the possible
        exception of `h5_file_size` and `md5sum` if a group is used within the HDF5
        file, or `md5sum` if `check_md5` is `False`.
    check_md5 : bool, optional
        Default is `True`.  See `ignore_validation` for explanation.

    Keyword parameters
    ------------------
    data_type : str, optional
        One of `scri.DataNames`.  Default is "UnknownDataType".
    m_is_scaled_out : bool, optional
        Default is True
    r_is_scaled_out : bool, optional
        Default is True

    Note that the keyword parameters will be overridden by corresponding entries in
    the JSON file, if they exist.  If the JSON file does not exist, any keyword
    parameters not listed above will be passed through as the `json_data` field of
    the returned waveform.

    Returns
    -------
    w : scri.WaveformModes
        Corotating-frame waveform, after `convert_from_conjugate_pairs()` has been
        applied, with the JSON metadata (or the leftover keyword parameters) attached
        as `w.json_data`.
    log_frame : numpy.ndarray
        Float array with one row per time step holding the vector part of the
        logarithm of the frame (a zero scalar part is inserted before exponentiating).

    Raises
    ------
    ValueError
        If the `sxs_format` attribute in the HDF5 file is not recognized (regardless
        of `ignore_validation`), or on any validation failure when
        `ignore_validation` is `False`.

    """
    import os
    import warnings
    import pathlib
    import bz2
    import json
    import numpy as np
    import h5py
    import quaternion
    import scri
    from scri.utilities import xor_timeseries_reverse as unxor
    from sxs.utilities import md5checksum

    # Report a validation failure: ignore it (truthy ignore_validation), warn (None),
    # or raise ValueError (any other falsy value, e.g. the default False).
    def invalid(message):
        if ignore_validation:
            pass
        elif ignore_validation is None:
            warnings.warn(message)
        else:
            raise ValueError(message)

    # Split "path.h5<group>" into the file path and the in-file group; a bare "/"
    # group means the file root.
    group = None
    if ".h5" in file_name and not file_name.endswith(".h5"):
        file_name, group = file_name.split(".h5")
    if group == "/":
        group = None

    h5_path = pathlib.Path(file_name).expanduser().resolve().with_suffix(".h5")
    json_path = h5_path.with_suffix(".json")

    # This will be used for validation
    h5_size = h5_path.stat().st_size

    # Keyword defaults; JSON entries (if present) take precedence below
    data_type = kwargs.pop("data_type", "UnknownDataType")
    m_is_scaled_out = bool(kwargs.pop("m_is_scaled_out", True))
    r_is_scaled_out = bool(kwargs.pop("r_is_scaled_out", True))

    if not json_path.exists():
        invalid(f'\nJSON file "{json_path}" cannot be found, but is expected for this data format.')
        # Without a JSON file, any remaining keyword parameters become the metadata
        json_data = kwargs.copy()
    else:
        with open(json_path) as f:
            json_data = json.load(f)
        if group is not None:
            json_data = json_data[group]

        data_type = json_data.get("data_info", {}).get("data_type", data_type)
        m_is_scaled_out = bool(json_data.get("data_info", {}).get("m_is_scaled_out", m_is_scaled_out))
        r_is_scaled_out = bool(json_data.get("data_info", {}).get("r_is_scaled_out", r_is_scaled_out))

        # Make sure this is our format
        sxs_format = json_data.get("sxs_format", "")
        if sxs_format not in sxs_formats:
            invalid(
                f"\nThe `sxs_format` found in JSON file is '{sxs_format}';\n"
                f"it should be one of\n"
                f"    {sxs_formats}."
            )

        # File-level size and checksum only describe the whole file, so they are only
        # checked when no group was requested
        if group is None:
            # Make sure the expected H5 file size matches the observed value
            json_h5_file_size = json_data.get("validation", {}).get("h5_file_size", 0)
            if json_h5_file_size != h5_size:
                invalid(
                    f"\nMismatch between `validation/h5_file_size` key in JSON file ({json_h5_file_size}) "
                    f'and observed file size ({h5_size}) of "{h5_path}".'
                )

            # Make sure the expected H5 file hash matches the observed value
            if check_md5:
                md5sum = md5checksum(h5_path)
                json_md5sum = json_data.get("validation", {}).get("md5sum", "")
                if json_md5sum != md5sum:
                    invalid(f"\nMismatch between `validation/md5sum` key in JSON file and observed MD5 checksum.")

    dataType = scri.DataType[scri.DataNames.index(data_type)]

    with h5py.File(h5_path, "r") as f:
        if group is not None:
            g = f[group]
        else:
            g = f
        # Make sure this is our format
        sxs_format = g.attrs["sxs_format"]
        if sxs_format not in sxs_formats:
            raise ValueError(
                f'The `sxs_format` found in H5 file is "{sxs_format}"; it should be one of\n'
                f"    {sxs_formats}."
            )

        # Ensure that the 'validation' keys from the JSON file are the same as in this file
        n_times = g.attrs["n_times"]
        json_n_times = json_data.get("validation", {}).get("n_times", 0)
        if json_n_times != n_times:
            invalid(
                f"\nNumber of time steps in H5 file ({n_times}) "
                f"does not match expected value from JSON ({json_n_times})."
            )

        # Read the raw data.  The decompressed buffer holds, in order: n_times 8-byte
        # time values, then n_times * n_modes 16-byte complex values, then the
        # remaining bytes, which are the frame logarithm.
        sizeof_float = 8
        sizeof_complex = 2 * sizeof_float
        ell_min = g.attrs["ell_min"]
        ell_max = g.attrs["ell_max"]
        shuffle_widths = tuple(g.attrs["shuffle_widths"])
        unshuffle = scri.utilities.multishuffle(shuffle_widths, forward=False)
        n_modes = ell_max * (ell_max + 2) - ell_min ** 2 + 1
        i1 = n_times * sizeof_float
        i2 = i1 + n_times * sizeof_complex * n_modes
        uncompressed_data = bz2.decompress(g["data"][...])
        t = np.frombuffer(uncompressed_data[:i1], dtype=np.uint64)
        data = np.frombuffer(uncompressed_data[i1:i2], dtype=np.uint64)
        log_frame = np.frombuffer(uncompressed_data[i2:], dtype=np.uint64)

    # Unshuffle the raw data
    t = unshuffle(t)
    data = unshuffle(data)
    log_frame = unshuffle(log_frame)

    # Reshape and re-interpret the data; storage is column-by-column (time fastest), so
    # reshape with time as the last axis and transpose to (n_times, ...)
    t = t.view(np.float64)
    data = data.reshape((-1, n_times)).T.copy().view(complex)
    log_frame = log_frame.reshape((-1, n_times)).T.copy().view(np.float64)

    # Un-XOR the data (presumably undoing an XOR-based encoding against earlier time
    # steps applied by the writer — confirm against scri.utilities)
    t = unxor(t)
    data = unxor(data)
    log_frame = unxor(log_frame)

    # Prepend a zero scalar part to each row to form a pure-vector quaternion, then
    # exponentiate to recover the frame rotors
    frame = np.exp(quaternion.as_quat_array(np.insert(log_frame, 0, 0.0, axis=1)))

    w = scri.WaveformModes(
        t=t,
        frame=frame,
        data=data,
        frameType=scri.Corotating,
        dataType=dataType,
        m_is_scaled_out=m_is_scaled_out,
        r_is_scaled_out=r_is_scaled_out,
        ell_min=ell_min,
        ell_max=ell_max,
    )
    w.convert_from_conjugate_pairs()
    w.json_data = json_data

    return w, log_frame
def load(file_name, ignore_validation=False, check_md5=True):
    """Load a waveform in RPXMB format from an `.h5` file and its accompanying `.json` file.

    Parameters
    ----------
    file_name : str
        Path to the data.  Its suffix is replaced by `.h5` to find the HDF5 file and by
        `.json` to find the JSON metadata file.
    ignore_validation : bool, optional
        If truthy, validation failures only issue warnings; otherwise they raise
        `ValueError`.  Validation failures are a missing JSON file, an unexpected
        `sxs_format` in the JSON, or mismatched `h5_file_size`, `md5sum`, or `n_times`.
        Default is False.
    check_md5 : bool, optional
        Whether to compare the MD5 checksum of the HDF5 file against the JSON value.
        Default is True.

    Returns
    -------
    w : scri.WaveformModes
        Corotating-frame waveform with `r_is_scaled_out` and `m_is_scaled_out` set to
        True, after `convert_from_conjugate_pairs()` has been applied.  The JSON metadata
        is attached as `w.json_data` (an empty dict if no JSON file was found).  The data
        type comes from the JSON `data_info/data_type` entry, falling back to
        "UnknownDataType".
    log_frame : numpy.ndarray
        Float array with one row per time step holding the vector part of the
        logarithm of the frame.

    Raises
    ------
    ValueError
        If the `sxs_format` attribute in the HDF5 file is not recognized, or on any
        validation failure when `ignore_validation` is falsy.
    """
    import os
    import warnings
    import pathlib
    import bz2
    import json
    import numpy as np
    import h5py
    import quaternion
    import scri
    from scri.utilities import xor_timeseries_reverse as unxor
    from sxs.utilities import md5checksum

    # Report a validation failure: warn if validation is being ignored, raise otherwise
    def invalid(message):
        if ignore_validation:
            warnings.warn(message)
        else:
            raise ValueError(message)

    h5_path = pathlib.Path(file_name).expanduser().resolve().with_suffix(".h5")
    json_path = h5_path.with_suffix(".json")

    # This will be used for validation
    h5_size = os.stat(h5_path).st_size

    # Default data type, used when the JSON file is missing (only tolerated when
    # ignore_validation is truthy).  Previously `dataType` was only bound in the
    # JSON-exists branch, so that case crashed with UnboundLocalError below.
    dataType = scri.DataType[scri.DataNames.index("UnknownDataType")]

    if not json_path.exists():
        invalid(
            f'\nJSON file "{json_path}" cannot be found, but is expected for this data format.'
        )
        json_data = {}
    else:
        with open(json_path) as f:
            json_data = json.load(f)

        dataType = json_data.get("data_info", {}).get("data_type",
                                                      "UnknownDataType")
        dataType = scri.DataType[scri.DataNames.index(dataType)]

        # Make sure this is our format
        sxs_format = json_data.get("sxs_format", "")
        if sxs_format not in sxs_formats:
            invalid(
                f'\nThe `sxs_format` found in JSON file is "{sxs_format}"; it should be one of\n'
                f"    {sxs_formats}.")

        # Make sure the expected H5 file size matches the observed value
        json_h5_file_size = json_data.get("validation",
                                          {}).get("h5_file_size", 0)
        if json_h5_file_size != h5_size:
            invalid(
                f"\nMismatch between `validation/h5_file_size` key in JSON file ({json_h5_file_size}) "
                f'and observed file size ({h5_size}) of "{h5_path}".')

        # Make sure the expected H5 file hash matches the observed value
        if check_md5:
            md5sum = md5checksum(h5_path)
            json_md5sum = json_data.get("validation", {}).get("md5sum", "")
            if json_md5sum != md5sum:
                invalid(
                    f"\nMismatch between `validation/md5sum` key in JSON file and observed MD5 checksum."
                )

    with h5py.File(h5_path, "r") as f:
        # Make sure this is our format
        sxs_format = f.attrs["sxs_format"]
        if sxs_format not in sxs_formats:
            raise ValueError(
                f'The `sxs_format` found in H5 file is "{sxs_format}"; it should be one of\n'
                f"    {sxs_formats}.")

        # Ensure that the 'validation' keys from the JSON file are the same as in this file
        n_times = f.attrs["n_times"]
        json_n_times = json_data.get("validation", {}).get("n_times", 0)
        if json_n_times != n_times:
            invalid(
                f"\nNumber of time steps in H5 file ({n_times}) "
                f"does not match expected value from JSON ({json_n_times}).")

        # Read the raw data.  The decompressed buffer holds, in order: n_times 8-byte
        # time values, then n_times * n_modes 16-byte complex values, then the frame log.
        sizeof_float = 8
        sizeof_complex = 2 * sizeof_float
        ell_min = f.attrs["ell_min"]
        ell_max = f.attrs["ell_max"]
        shuffle_widths = tuple(f.attrs["shuffle_widths"])
        unshuffle = scri.utilities.multishuffle(shuffle_widths, forward=False)
        n_modes = ell_max * (ell_max + 2) - ell_min**2 + 1
        i1 = n_times * sizeof_float
        i2 = i1 + n_times * sizeof_complex * n_modes
        uncompressed_data = bz2.decompress(f["data"][...])
        t = np.frombuffer(uncompressed_data[:i1], dtype=np.uint64)
        data = np.frombuffer(uncompressed_data[i1:i2], dtype=np.uint64)
        log_frame = np.frombuffer(uncompressed_data[i2:], dtype=np.uint64)

    # Unshuffle the raw data
    t = unshuffle(t)
    data = unshuffle(data)
    log_frame = unshuffle(log_frame)

    # Reshape and re-interpret the data
    t = t.view(np.float64)
    data = data.reshape((-1, n_times)).T.copy().view(np.complex128)
    log_frame = log_frame.reshape((-1, n_times)).T.copy().view(np.float64)

    # Un-XOR the data
    t = unxor(t)
    data = unxor(data)
    log_frame = unxor(log_frame)

    # Prepend a zero scalar part and exponentiate to recover the frame rotors
    frame = np.exp(
        quaternion.as_quat_array(np.insert(log_frame, 0, 0.0, axis=1)))

    w = scri.WaveformModes(
        t=t,
        frame=frame,
        data=data,
        frameType=scri.Corotating,
        dataType=dataType,
        m_is_scaled_out=True,
        r_is_scaled_out=True,
        ell_min=ell_min,
        ell_max=ell_max,
    )
    w.convert_from_conjugate_pairs()
    w.json_data = json_data

    return w, log_frame