Beispiel #1
0
def test_average_quats():
    """Check running quaternion averages and sign-flip symmetry."""
    inv_sqrt2 = 1. / np.sqrt(2.)
    tol = 1e-7
    quats = np.array([
        [0, inv_sqrt2, inv_sqrt2],
        [0, inv_sqrt2, inv_sqrt2],
        [0, inv_sqrt2, 0],
        [0, 0, inv_sqrt2],
        [inv_sqrt2, 0, 0],
    ], float)
    # In MATLAB:
    # quats = [[0, sq2, sq2, 0]; [0, sq2, sq2, 0];
    #          [0, sq2, 0, sq2]; [0, 0, sq2, sq2]; [sq2, 0, 0, sq2]];
    expected = [
        quats[0],
        quats[0],
        [0, 0.788675134594813, 0.577350269189626],
        [0, 0.657192299694123, 0.657192299694123],
        [0.100406058540540, 0.616329446922803, 0.616329446922803],
    ]

    def _check_running_averages():
        # The average of the first ``n_use`` rows must match entry
        # ``n_use - 1`` of the reference list (so averaging the first two
        # identical rows gives the same thing as the first row alone).
        for n_use, want in enumerate(expected, 1):
            assert_allclose(_average_quats(quats[:n_use]), want, atol=tol)

    _check_running_averages()

    # Negating row 1 gives the same quaternion (hidden value is zero here),
    # so neither the rotation matrix nor any running average may change.
    quats[1] *= -1
    rot_first, rot_second = quat_to_rot(quats[:2])
    assert_allclose(rot_first, rot_second, atol=tol)
    _check_running_averages()

    # Assert some symmetry: whenever the real part is ~0 the sign can flip
    extras = [[inv_sqrt2, inv_sqrt2, 0]] + list(np.eye(3))
    n_flippable = 0
    for quat in np.concatenate((quats, expected, extras)):
        if not np.isclose(_quat_real(quat), 0., atol=tol):
            continue
        n_flippable += 1
        assert_allclose(_angle_between_quats(quat, -quat), 0., atol=tol)
        rot_pos, rot_neg = quat_to_rot(np.array((quat, -quat)))
        assert_allclose(rot_pos, rot_neg, atol=tol)
    assert n_flippable == 4 + len(extras)
Beispiel #2
0
def test_quaternions():
    """Test quaternion calculations."""
    # Rotations to round-trip: the identity, the 3x3 rotation part of each
    # test file's ``dev_head_t``, and some nasty numerical cases.
    rots = [np.eye(3)]
    for fname in (test_fif_fname, ctf_fname, hp_fif_fname):
        rots.append(read_info(fname)['dev_head_t']['trans'][:3, :3])
    nasty_cases = (
        [
            [-0.99978541, -0.01873462, -0.00898756],
            [-0.01873462, 0.62565561, 0.77987608],
            [-0.00898756, 0.77987608, -0.62587152],
        ],
        [
            [0.62565561, -0.01873462, 0.77987608],
            [-0.01873462, -0.99978541, -0.00898756],
            [0.77987608, -0.00898756, -0.62587152],
        ],
        [
            [-0.99978541, -0.00898756, -0.01873462],
            [-0.00898756, -0.62587152, 0.77987608],
            [-0.01873462, 0.77987608, 0.62565561],
        ],
    )
    for matrix in nasty_cases:
        rots.append(np.array(matrix))

    tols = dict(rtol=1e-5, atol=1e-5)
    for rot in rots:
        # plain 3x3 input
        assert_allclose(rot, quat_to_rot(rot_to_quat(rot)), **tols)
        # same matrix with two leading singleton dimensions
        stacked = rot[np.newaxis, np.newaxis, :, :]
        assert_allclose(stacked, quat_to_rot(rot_to_quat(stacked)), **tols)

    # let's make sure our angle function works in some reasonable way
    for ii in range(3):
        for jj in range(3):
            a, b = np.zeros(3), np.zeros(3)
            a[ii] = 1.
            b[jj] = 1.
            if ii == jj:
                expected = 0.
            else:
                expected = np.pi
            assert_allclose(_angle_between_quats(a, b), expected, atol=1e-5)

    # Two half-turn matrices must be pi away from the all-zero quaternion
    no_rotation = np.zeros(3)
    y_180 = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1.]])
    y_180_angle = _angle_between_quats(rot_to_quat(y_180), no_rotation)
    assert_allclose(y_180_angle, np.pi)
    h_180_attitude_90 = np.array([[0, 1, 0], [1, 0, 0], [0, 0, -1.]])
    h_180_angle = _angle_between_quats(rot_to_quat(h_180_attitude_90),
                                       np.zeros(3))
    assert_allclose(h_180_angle, np.pi)
Beispiel #3
0
def _rand_affine(rng):
    """Return a random 4x4 affine drawn from ``rng``.

    Two ``rng.randn(3)`` draws are made, in this order: one that is rescaled
    to norm 1/5 and passed to ``quat_to_rot`` for the rotation block, and one
    that is divided by 5 and used as the translation column.
    """
    rot_quat = rng.randn(3)
    rot_quat /= 5 * np.linalg.norm(rot_quat)
    offset = rng.randn(3) / 5.
    out = np.eye(4)
    out[:3, 3] = offset
    out[:3, :3] = quat_to_rot(rot_quat)
    return out
Beispiel #4
0
def test_euler(quats):
    """Test euler transformations."""
    # quaternion -> Euler angles -> quaternion must round-trip
    angles = _quat_to_euler(quats)
    assert_allclose(quats, _euler_to_quat(angles), atol=1e-14)
    # both representations must yield the same rotation matrices
    rot_from_quat = quat_to_rot(quats)
    rot_from_euler = []
    for angle_set in angles:
        rot_from_euler.append(rotation(*angle_set)[:3, :3])
    rot_from_euler = np.array(rot_from_euler)
    assert_allclose(rot_from_quat, rot_from_euler, atol=1e-14)
def test_fit_matched_points(quats, scaling, do_scale):
    """Test analytical least-squares matched point fitting.

    For each quaternion in ``quats`` a random point cloud is rotated, scaled
    by ``scaling`` and translated, and the fit returned by
    ``_check_fit_matched_points`` is compared with the known parameters.
    One point is then corrupted to check that (a) zero-weighting it restores
    the exact fit and (b) leaving it in, or up-weighting it, degrades the fit
    within the bounds asserted below.

    NOTE(review): ``quats``, ``scaling`` and ``do_scale`` are presumably
    supplied by fixtures/parametrization defined elsewhere -- confirm.
    """
    if scaling != 1 and not do_scale:
        return  # no need to test this, it will not be good
    rng = np.random.RandomState(0)
    fro = rng.randn(10, 3)
    translation = rng.randn(3)
    for qi, quat in enumerate(quats):
        # ground-truth destination points: scale * R(quat) @ fro + translation
        to = scaling * np.dot(quat_to_rot(quat), fro.T).T + translation
        for corrupted in (False, True):
            # mess up a point
            if corrupted:
                # in-place edit: the corruption stays in ``to`` for the
                # second loop below
                to[0, 2] += 100
                weights = np.ones(len(to))
                weights[0] = 0  # ignore the corrupted point entirely
            else:
                weights = None
            est, scale_est = _check_fit_matched_points(
                fro, to, weights=weights, do_scale=do_scale)
            # est[:3] is compared with the quaternion, est[3:] with the
            # translation
            assert_allclose(scale_est, scaling, rtol=1e-5)
            assert_allclose(est[:3], quat, atol=1e-14)
            assert_allclose(est[3:], translation, atol=1e-14)
        # if we don't adjust for the corruption above, it should get worse
        angle = dist = None
        for weighted in (False, True):
            if not weighted:
                weights = None
                dist_bounds = (5, 20)
                if scaling == 1:
                    angle_bounds = (5, 95)
                    angtol, dtol, stol = 1, 15, 3
                else:
                    angle_bounds = (5, 105)
                    angtol, dtol, stol = 20, 15, 3
            else:
                weights = np.ones(len(to))
                weights[0] = 10  # weighted=True here means "make it worse"
                # ``angle`` and ``dist`` were set by the unweighted pass
                angle_bounds = (angle, 180)  # unweighted values as new min
                dist_bounds = (dist, 100)
                if scaling == 1:
                    # XXX this angtol is not great but there is a hard to
                    # identify linalg/angle calculation bug on Travis...
                    angtol, dtol, stol = 180, 70, 3
                else:
                    angtol, dtol, stol = 50, 70, 3
            est, scale_est = _check_fit_matched_points(
                fro, to, weights=weights, do_scale=do_scale,
                angtol=angtol, dtol=dtol, stol=stol)
            # the corrupted fit must no longer match the true parameters ...
            assert not np.allclose(est[:3], quat, atol=1e-5)
            assert not np.allclose(est[3:], translation, atol=1e-5)
            # ... and its angular (degrees) and translational errors must
            # fall strictly inside the bounds chosen above
            angle = np.rad2deg(_angle_between_quats(est[:3], quat))
            assert_array_less(angle_bounds[0], angle)
            assert_array_less(angle, angle_bounds[1])
            dist = np.linalg.norm(est[3:] - translation)
            assert_array_less(dist_bounds[0], dist)
            assert_array_less(dist, dist_bounds[1])
def test_quaternions():
    """Test quaternion calculations."""
    nasty_cases = [
        np.array([
            [-0.99978541, -0.01873462, -0.00898756],
            [-0.01873462, 0.62565561, 0.77987608],
            [-0.00898756, 0.77987608, -0.62587152],
        ]),
        np.array([
            [0.62565561, -0.01873462, 0.77987608],
            [-0.01873462, -0.99978541, -0.00898756],
            [0.77987608, -0.00898756, -0.62587152],
        ]),
        np.array([
            [-0.99978541, -0.00898756, -0.01873462],
            [-0.00898756, -0.62587152, 0.77987608],
            [-0.01873462, 0.77987608, 0.62565561],
        ]),
    ]
    # identity first, then the rotation part of each file's ``dev_head_t``,
    # then the nasty numerical cases
    rots = [np.eye(3)]
    for fname in (test_fif_fname, ctf_fname, hp_fif_fname):
        rots.append(read_info(fname)['dev_head_t']['trans'][:3, :3])
    rots.extend(nasty_cases)

    for rot in rots:
        # round-trip the plain matrix, then the same matrix with two
        # leading singleton dimensions
        for candidate in (rot, rot[np.newaxis, np.newaxis, :, :]):
            round_trip = quat_to_rot(rot_to_quat(candidate))
            assert_allclose(candidate, round_trip, rtol=1e-5, atol=1e-5)

    # let's make sure our angle function works in some reasonable way
    for ii in range(3):
        for jj in range(3):
            a, b = np.zeros(3), np.zeros(3)
            a[ii] = 1.
            b[jj] = 1.
            if ii == jj:
                expected = 0.
            else:
                expected = np.pi
            assert_allclose(_angle_between_quats(a, b), expected, atol=1e-5)
Beispiel #7
0
def test_average_quats():
    """Check running quaternion averages before and after a sign flip."""
    inv_sqrt2 = 1. / np.sqrt(2.)
    tol = 1e-7
    quats = np.array([
        [0, inv_sqrt2, inv_sqrt2],
        [0, inv_sqrt2, inv_sqrt2],
        [0, inv_sqrt2, 0],
        [0, 0, inv_sqrt2],
        [inv_sqrt2, 0, 0],
    ], float)
    # In MATLAB:
    # quats = [[0, sq2, sq2, 0]; [0, sq2, sq2, 0];
    #          [0, sq2, 0, sq2]; [0, 0, sq2, sq2]; [sq2, 0, 0, sq2]];
    expected = [
        quats[0],
        quats[0],
        [0, 0.788675134594813, 0.577350269189626],
        [0, 0.657192299694123, 0.657192299694123],
        [0.100406058540540, 0.616329446922803, 0.616329446922803],
    ]

    def _check_running_averages():
        # The average of the first ``n_use`` rows must match entry
        # ``n_use - 1`` of the reference list (so averaging the first two
        # identical rows gives the same thing as the first row alone).
        for n_use, want in enumerate(expected, 1):
            assert_allclose(_average_quats(quats[:n_use]), want, atol=tol)

    _check_running_averages()

    # Negating row 1 gives the same quaternion (hidden value is zero here),
    # so neither the rotation matrix nor any running average may change.
    quats[1] *= -1
    rot_first, rot_second = quat_to_rot(quats[:2])
    assert_allclose(rot_first, rot_second, atol=tol)
    _check_running_averages()
Beispiel #8
0
def test_get_chpi():
    """Test CHPI position computation."""
    # Every get_chpi_positions call runs inside a recording
    # ``catch_warnings`` context (deprecation).

    # positions from the head-position file, including quaternions
    with warnings.catch_warnings(record=True):  # deprecation
        from_file = get_chpi_positions(hp_fname, return_quat=True)
    trans0, rot0, _, quat0 = from_file
    assert_allclose(rot0[0], quat_to_rot(quat0[0]))
    trans0 = trans0[:-1]
    rot0 = rot0[:-1]

    # positions from the raw instance
    raw = Raw(hp_fif_fname)
    with warnings.catch_warnings(record=True):  # deprecation
        from_raw = get_chpi_positions(raw)
    trans1, rot1, _ = from_raw
    trans1 = trans1[2:]
    rot1 = rot1[2:]
    # these will not be exact because they don't use equiv. time points
    assert_allclose(trans0, trans1, atol=1e-5, rtol=1e-1)
    assert_allclose(rot0, rot1, atol=1e-6, rtol=1e-1)

    # run through input checking
    raw_no_chpi = Raw(test_fif_fname)
    bad_calls = (
        (TypeError, (1,), {}),
        (ValueError, (hp_fname, [1]), {}),
        (RuntimeError, (raw_no_chpi,), {}),
        (ValueError, (raw,), dict(t_step='foo')),
        (IOError, ('foo',), {}),
    )
    with warnings.catch_warnings(record=True):  # deprecation
        for exc_type, args, kwargs in bad_calls:
            assert_raises(exc_type, get_chpi_positions, *args, **kwargs)
Beispiel #9
0
def test_get_chpi():
    """Test CHPI position computation."""
    # Every get_chpi_positions call runs inside a recording
    # ``catch_warnings`` context (deprecation).

    # positions from the head-position file, including quaternions
    with warnings.catch_warnings(record=True):  # deprecation
        from_file = get_chpi_positions(hp_fname, return_quat=True)
    trans0, rot0, _, quat0 = from_file
    assert_allclose(rot0[0], quat_to_rot(quat0[0]))
    trans0 = trans0[:-1]
    rot0 = rot0[:-1]

    # positions from the raw instance
    raw = Raw(hp_fif_fname)
    with warnings.catch_warnings(record=True):  # deprecation
        from_raw = get_chpi_positions(raw)
    trans1, rot1, _ = from_raw
    trans1 = trans1[2:]
    rot1 = rot1[2:]
    # these will not be exact because they don't use equiv. time points
    assert_allclose(trans0, trans1, atol=1e-5, rtol=1e-1)
    assert_allclose(rot0, rot1, atol=1e-6, rtol=1e-1)

    # run through input checking
    raw_no_chpi = Raw(test_fif_fname)
    bad_calls = (
        (TypeError, (1,), {}),
        (ValueError, (hp_fname, [1]), {}),
        (RuntimeError, (raw_no_chpi,), {}),
        (ValueError, (raw,), dict(t_step='foo')),
        (IOError, ('foo',), {}),
    )
    with warnings.catch_warnings(record=True):  # deprecation
        for exc_type, args, kwargs in bad_calls:
            assert_raises(exc_type, get_chpi_positions, *args, **kwargs)
def test_average_quats():
    """Check running quaternion averages before and after a sign flip."""
    inv_sqrt2 = 1. / np.sqrt(2.)
    tol = 1e-7
    quats = np.array([
        [0, inv_sqrt2, inv_sqrt2],
        [0, inv_sqrt2, inv_sqrt2],
        [0, inv_sqrt2, 0],
        [0, 0, inv_sqrt2],
        [inv_sqrt2, 0, 0],
    ], float)
    # In MATLAB:
    # quats = [[0, sq2, sq2, 0]; [0, sq2, sq2, 0];
    #          [0, sq2, 0, sq2]; [0, 0, sq2, sq2]; [sq2, 0, 0, sq2]];
    expected = [
        quats[0],
        quats[0],
        [0, 0.788675134594813, 0.577350269189626],
        [0, 0.657192299694123, 0.657192299694123],
        [0.100406058540540, 0.616329446922803, 0.616329446922803],
    ]

    def _check_running_averages():
        # The average of the first ``n_use`` rows must match entry
        # ``n_use - 1`` of the reference list (so averaging the first two
        # identical rows gives the same thing as the first row alone).
        for n_use, want in enumerate(expected, 1):
            assert_allclose(_average_quats(quats[:n_use]), want, atol=tol)

    _check_running_averages()

    # Negating row 1 gives the same quaternion (hidden value is zero here),
    # so neither the rotation matrix nor any running average may change.
    quats[1] *= -1
    rot_first, rot_second = quat_to_rot(quats[:2])
    assert_allclose(rot_first, rot_second, atol=tol)
    _check_running_averages()
# Keep only the samples whose goodness of fit exceeds GOF_LIMIT.
ind_ok = np.nonzero(g > GOF_LIMIT)[0]
times_ok = times_all[ind_ok]

# The quaternions give the device -> head transformation: y = Rx + T,
# where y is in head coords and x in device coords.
# To get the position of the head origin in device coords:
#     y = 0 -> Rx = -T -> x = -R.T * T
# Rotation of the head coord system:
#     y = Rx + T -> x = R.T * y - R.T * T  -> rotation given by R.T
# The result agrees with maxfilter (note that maxfilter plots the coords of
# its sss expansion origin, not of the head origin!).

# rotation quaternions -> rotation matrices (one per retained sample)
Rq, _ = raw[picks_chpi_r, :]
Rq = Rq[:, ind_ok].T
R = quat_to_rot(Rq)

# translation vectors (one row per retained sample)
T, _ = raw[picks_chpi_t, :]
T = T[:, ind_ok].T

# head origin in device coords (see above): Y[k] = -R[k].T . T[k]
nrot = R.shape[0]
Y = np.zeros((nrot, 3))
A = np.zeros((nrot, 3))
for k in np.arange(nrot):
    Y[k] = -R[k].T.dot(T[k])

# Rotation angle change between consecutive samples, in degrees, taken
# directly from the quaternion data.
#Rq = lowpass(Rq, raw.info['sfreq'], args.lpcorner, axis=0)
dA = _angle_between_quats(Rq[1:], Rq[:-1]) / np.pi * 180
# Time stamps of the retained samples.
times_ok = times_all[ind_ok]

# The quaternions give the device -> head transformation: y = Rx + T,
# where y is in head coords and x in device coords.
# To get the position of the head origin in device coords:
#     y = 0 -> Rx = -T -> x = -R.T * T
# Rotation of the head coord system:
#     y = Rx + T -> x = R.T * y - R.T * T  -> rotation given by R.T
# The result agrees with maxfilter (note that maxfilter plots the coords of
# its sss expansion origin, not of the head origin!).

# rotation quaternions -> rotation matrices (one per retained sample)
Rq, _ = raw[picks_chpi_r, :]
Rq = Rq[:, ind_ok].T
R = quat_to_rot(Rq)

# translation vectors (one row per retained sample)
T, _ = raw[picks_chpi_t, :]
T = T[:, ind_ok].T

# head origin in device coords (see above): Y[k] = -R[k].T . T[k]
nrot = R.shape[0]
Y = np.zeros((nrot, 3))
A = np.zeros((nrot, 3))
for k in np.arange(nrot):
    Y[k] = -R[k].T.dot(T[k])

# Rotation angle change between consecutive samples, in degrees, taken
# directly from the quaternion data.
#Rq = lowpass(Rq, raw.info['sfreq'], args.lpcorner, axis=0)
dA = _angle_between_quats(Rq[1:], Rq[:-1]) / np.pi * 180
Beispiel #13
0
def contAvg_headpos(condition, method='median', folder=None, summary=False):
    """Calculate the average dewar-to-head transformation.

    The average is based on the continuous head position estimated from
    MaxFilter. Input files are looked up in ``<folder>/quat_files`` via
    ``find_condition_files`` and the result is written to
    ``<folder>/trans_files/<condition>-trans.fif``.

    Parameters
    ----------
    condition : str
        String containing part of the common filename, e.g. "task" for files
        task-1.fif, task-2.fif, etc. Consistent naming of files is mandatory!
    method : str
        How to calculate the average: "mean" or "median" (default).
        Case-insensitive.
    folder : str | None
        Path to input files. ``None`` (or any empty value) means the current
        working directory.
    summary : bool
        If True, also call ``plot_movement`` and ``total_dist_moved`` on the
        loaded data before averaging.

    Returns
    -------
    mean_trans : MNE-Python transform object | None
        The ``dev_head_t`` entry of the loaded data with its 4x4 ``'trans'``
        matrix replaced by the average. ``None`` if the output file already
        exists (a ``RuntimeWarning`` is issued and nothing is computed).

    Raises
    ------
    RuntimeError
        If ``method`` is not "mean"/"median", if ``condition`` is empty, or
        if no matching input files are found.
    """
    # Check that the method works
    method = method.lower()
    if method not in ('median', 'mean'):
        raise RuntimeError(
            'Wrong method. Must be either \"mean\" or "median"!')
    if not condition:
        raise RuntimeError('You must provide a conditon!')

    # Get and set folders
    rawdir = folder if folder else getcwd()  # [!] Match up with bash script !
    print(rawdir)
    quatdir = op.join(rawdir, 'quat_files')

    mean_trans_folder = op.join(rawdir, 'trans_files')
    if not op.exists(mean_trans_folder):  # Make sure output folder exists
        mkdir(mean_trans_folder)

    mean_trans_file = op.join(mean_trans_folder, condition + '-trans.fif')
    if op.isfile(mean_trans_file):
        # Never overwrite an existing result; the caller has to delete it.
        warnings.warn(
            '"%s" already exists in %s. Delete if you want to rerun' %
            (mean_trans_file, mean_trans_folder), RuntimeWarning)
        return

    files2combine = sorted(find_condition_files(quatdir, condition))
    if not files2combine:
        raise RuntimeError('No files called \"%s\" found in %s' %
                           (condition, quatdir))

    def _split_index(fname):
        # Numeric suffix of a split file, e.g. 3 for "name-3.fif".
        # Assuming consistent naming!!!
        return int(re.split(r'-|\.fif', fname)[-2])

    # Collect the split files belonging to each input, for reporting only.
    allfiles = []
    for ff in files2combine:
        fl = ff.split('_')[0]
        tmplist = sorted(
            f for f in listdir(quatdir) if fl in f and '_quat' in f)
        # Fix order: the file without a number sorts last but comes first
        # in acquisition order; the numbered ones follow numerically.
        # NOTE(review): groups without numbered splits are never added to
        # ``allfiles``; this only affects the printout below -- confirm that
        # this is intended.
        if len(tmplist) > 1 and any("-" in f for f in tmplist):
            allfiles.append(tmplist[-1])
            allfiles.extend(sorted(tmplist[:-1], key=_split_index))

    if len(allfiles) > 1:
        print('Files used for average head pos:')
        for ib, fname in enumerate(allfiles, 1):
            print('{:d}: {:s}'.format(ib, fname))
    else:
        print('Will find average head pos in %s' % files2combine)

    # LOAD DATA
    # Use files2combine instead of allfiles as MNE will find split files
    # automatically.
    raw = None
    for ffs in files2combine:
        part = read_raw_fif(op.join(quatdir, ffs),
                            preload=True,
                            allow_maxshield=True).pick_types(meg=False,
                                                             chpi=True)
        if raw is None:
            raw = part
        else:
            raw.append(part)

    quat, times = raw.get_data(return_times=True)
    gof = quat[6]  # Goodness of fit channel

    # In case "record raw" started before "cHPI": drop everything before the
    # first sample with a goodness of fit above 0.98
    if np.any(gof < 0.98):
        begsam = np.argmax(gof > 0.98)
        raw.crop(tmin=raw.times[begsam])
        quat = quat[:, begsam:].copy()
        times = times[begsam:].copy()

    # Make summaries
    if summary:
        plot_movement(quat, times, dirname=rawdir, identifier=condition)
        total_dist_moved(quat,
                         times,
                         write=True,
                         dirname=rawdir,
                         identifier=condition)

    # Get continuous transformation
    print('Reading transformation. This will take a while...')
    n_times = len(times)
    H = np.empty([4, 4, n_times])  # Initiate transforms
    init_rot_angles = np.empty([n_times, 3])

    for i in range(n_times):
        Hi = np.eye(4, 4)
        Hi[0:3, 3] = quat[3:6, i]  # rows 3-5: translation
        Hi[:3, :3] = quat_to_rot(quat[0:3, i])  # rows 0-2: quaternion
        init_rot_angles[i, :] = rotation_angles(Hi[:3, :3])
        assert (np.sum(Hi[-1]) == 1.0)  # sanity check result
        H[:, :, i] = Hi

    # ``method`` was validated above, so it is exactly one of the two.
    average = np.mean if method == 'mean' else np.median
    H_mean = average(H, axis=2)  # stack, then average over new dim
    # stack, then average, then make new xfm
    mean_rot_xfm = rotation3d(*tuple(average(init_rot_angles, axis=0)))

    # The rotation block comes from the averaged rotation angles rather than
    # from the element-wise average of the matrices.
    H_mean[:3, :3] = mean_rot_xfm
    assert (np.sum(H_mean[-1]) == 1.0)  # sanity check result

    # Create the mean structure and save as .fif
    mean_trans = raw.info['dev_head_t']  # use the last info as a template
    mean_trans['trans'] = H_mean.copy()

    # Write file
    write_trans(mean_trans_file, mean_trans)
    print("Wrote " + mean_trans_file)

    return mean_trans
Beispiel #14
0
def contAvg_headpos(condition, method='median', folder=None):
    """Calculate the average dewar-to-head transformation.

    The average is based on the continuous head position estimated from
    MaxFilter. Every file in ``<folder>/quat_files`` whose name contains both
    ``condition`` and ``'_quat'`` is loaded (in sorted name order) and the
    result is written to ``<folder>/trans_files/<condition>-trans.fif``.

    Parameters
    ----------
    condition : str
        String containing part of the common filename, e.g. "task" for files
        task-1.fif, task-2.fif, etc. Consistent naming of files is mandatory!
    method : str
        How to calculate the average: "mean" or "median" (default).
        Case-insensitive.
    folder : str | None
        Path to input files. ``None`` (or any empty value) means the current
        working directory.

    Returns
    -------
    mean_trans : MNE-Python transform object
        The ``dev_head_t`` entry of the loaded data with its 4x4 ``'trans'``
        matrix replaced by the average.

    Raises
    ------
    RuntimeError
        If ``method`` is not "mean"/"median", if ``condition`` is empty, if
        the output file already exists, or if no matching input files are
        found.
    """
    # Check that the method works (non-string values fall through to the
    # RuntimeError below instead of failing on ``.lower()``)
    if isinstance(method, str):
        method = method.lower()
    if method not in ('median', 'mean'):
        raise RuntimeError(
            'Wrong method. Must be either \"mean\" or "median"!')
    if not condition:
        raise RuntimeError('You must provide a conditon!')

    # Get and set folders
    rawdir = folder if folder else getcwd()  # [!] Match up with bash script !
    print(rawdir)
    quatdir = op.join(rawdir, 'quat_files')

    mean_trans_folder = op.join(rawdir, 'trans_files')
    if not op.exists(mean_trans_folder):  # Make sure output folder exists
        mkdir(mean_trans_folder)

    mean_trans_file = op.join(mean_trans_folder, condition + '-trans.fif')
    if op.isfile(mean_trans_file):
        # Never overwrite an existing result; the caller has to delete it.
        raise RuntimeError(
            '"%s" already exists in %s. Delete if you want to rerun' %
            (mean_trans_file, mean_trans_folder))

    # listdir() returns names in arbitrary order; sort them so that the
    # concatenation order (and hence the crop point below) is deterministic.
    files2combine = sorted(
        f for f in listdir(quatdir) if condition in f and '_quat' in f)

    if not files2combine:
        raise RuntimeError('No files called \"%s\" found in %s' %
                           (condition, quatdir))
    if len(files2combine) > 1:
        print('Files used for average head pos:')
        for ib, fname in enumerate(files2combine, 1):
            print('{:d}: {:s}'.format(ib, fname))
    else:
        print('Will find average head pos in %s' % files2combine)

    # LOAD DATA
    raw = None
    for ffs in files2combine:
        part = read_raw_fif(op.join(quatdir, ffs),
                            preload=True,
                            allow_maxshield=True).pick_types(meg=False,
                                                             chpi=True)
        if raw is None:
            raw = part
        else:
            raw.append(part)

    quat, times = raw.get_data(return_times=True)
    gof = quat[6]  # Goodness of fit channel

    # In case "record raw" started before "cHPI": drop everything before the
    # first sample with a goodness of fit above 0.98
    if np.any(gof < 0.98):
        begsam = np.argmax(gof > 0.98)
        raw.crop(tmin=raw.times[begsam])
        quat = quat[:, begsam:].copy()
        times = times[begsam:].copy()

    # Get continuous transformation
    print('Reading transformation. This will take a while...')
    n_times = len(times)
    H = np.empty([4, 4, n_times])  # Initiate transforms
    init_rot_angles = np.empty([n_times, 3])

    for i in range(n_times):
        Hi = np.eye(4, 4)
        Hi[0:3, 3] = quat[3:6, i]  # rows 3-5: translation
        Hi[:3, :3] = quat_to_rot(quat[0:3, i])  # rows 0-2: quaternion
        init_rot_angles[i, :] = rotation_angles(Hi[:3, :3])
        assert (np.sum(Hi[-1]) == 1.0)  # sanity check result
        H[:, :, i] = Hi

    # ``method`` was validated above, so it is exactly one of the two.
    average = np.mean if method == 'mean' else np.median
    H_mean = average(H, axis=2)  # stack, then average over new dim
    # stack, then average, then make new xfm
    mean_rot_xfm = rotation3d(*tuple(average(init_rot_angles, axis=0)))

    # The rotation block comes from the averaged rotation angles rather than
    # from the element-wise average of the matrices.
    H_mean[:3, :3] = mean_rot_xfm
    assert (np.sum(H_mean[-1]) == 1.0)  # sanity check result

    # Create the mean structure and save as .fif
    mean_trans = raw.info['dev_head_t']  # use the last info as a template
    mean_trans['trans'] = H_mean.copy()

    # Write file
    write_trans(mean_trans_file, mean_trans)
    print("Wrote " + mean_trans_file)

    return mean_trans