Пример #1
0
def test_GradientTable():
    """Exercise ``GradientTable`` construction, validation and b0 warning.

    The test covers three things:

    * b-values, b-vectors, the b0 mask and the string summary derived from a
      small gradient array, for ``b0_threshold`` values of 0 and 1;
    * rejection of malformed inputs with ``ValueError``;
    * the ``UserWarning`` expected from ``gradient_table`` when
      ``b0_threshold`` is 200.
    """
    # Row norms are the integers 0, 1, 1, 5 and 13 (3-4-5 and 5-12-13
    # triangles), so the expected b-values can be written down exactly.
    gradients = np.array([[0, 0, 0],
                          [1, 0, 0],
                          [0, 0, 1],
                          [3, 4, 0],
                          [5, 0, 12]], 'float')

    expected_bvals = np.array([0, 1, 1, 5, 13])
    expected_b0s_mask = expected_bvals == 0
    # Adding the boolean mask turns the zero b-value into 1, which avoids a
    # division by zero while leaving the all-zero gradient row unchanged.
    expected_bvecs = gradients / (expected_bvals + expected_b0s_mask)[:, None]

    gt = GradientTable(gradients, b0_threshold=0)
    npt.assert_('B-values shape (5,)' in str(gt))
    npt.assert_array_almost_equal(gt.bvals, expected_bvals)
    npt.assert_array_equal(gt.b0s_mask, expected_b0s_mask)
    npt.assert_array_almost_equal(gt.bvecs, expected_bvecs)
    npt.assert_array_almost_equal(gt.gradients, gradients)

    # With a threshold of 1 the two unit-norm rows are also flagged as b0,
    # while the reported b-values and b-vectors are expected to be unchanged.
    gt = GradientTable(gradients, b0_threshold=1)
    npt.assert_array_equal(gt.b0s_mask, [1, 1, 1, 0, 0])
    npt.assert_array_equal(gt.bvals, expected_bvals)
    npt.assert_array_equal(gt.bvecs, expected_bvecs)

    # Malformed inputs (a negative scalar, an (N, 2) array and a 1-D array)
    # must be rejected with ValueError.
    npt.assert_raises(ValueError, GradientTable, -1)
    npt.assert_raises(ValueError, GradientTable, np.ones((6, 2)))
    npt.assert_raises(ValueError, GradientTable, np.ones((6,)))

    with warnings.catch_warnings(record=True) as l_warns:
        # Force every warning to be recorded: under the default filter a
        # warning already raised from the same location earlier in the run
        # is suppressed, which would make this check order-dependent.
        warnings.simplefilter("always")
        gradient_table(expected_bvals, expected_bvecs, b0_threshold=200)

    # Other warning categories may be emitted too; keep only UserWarning.
    user_messages = [str(w.message) for w in l_warns
                     if issubclass(w.category, UserWarning)]
    npt.assert_(len(user_messages) >= 1,
                'expected at least one UserWarning for b0_threshold=200')
    # Substring match, so extra wording in the warning does not break the test.
    npt.assert_(any('b0_threshold has a value > 199' in text
                    for text in user_messages),
                'b0_threshold warning not found in %r' % (user_messages,))