Ejemplo n.º 1
0
def test__pull_axis():
    """Exercise _pull_axis on a three-axis list.

    Checks that pulling an axis not present in the list leaves it unchanged,
    pulling the last axis drops it, and pulling a middle axis drops it while
    the following axis is expected to move down one index.

    Written with direct assertion calls instead of as a yield-based generator
    test: generator tests are no longer collected by pytest (>= 4.0), so the
    yielded checks would silently never run there.
    """
    a = Axis('x', 0, None)
    b = Axis('y', 1, None)
    c = Axis('z', 2, None)
    # t_pos is expected to compare equal to b; t_neg matches no axis in the list.
    t_pos = Axis('y', 1, None)
    t_neg = Axis('x', 5, None)
    axes = [a, b, c]
    nt.assert_true(t_pos in axes)
    nt.assert_false(t_neg in axes)
    nt.assert_equal(axes, _pull_axis(axes, t_neg))
    nt.assert_equal(axes[:-1], _pull_axis(axes, c))
    # After removing 'y' (index 1), 'z' is expected to shift from index 2 to 1.
    new_axes = [a, Axis('z', 1, None)]
    nt.assert_equal(new_axes, _pull_axis(axes, t_pos))
Ejemplo n.º 2
0
def test__pull_axis():
    """Exercise _pull_axis on a three-axis list.

    Checks that pulling an axis not present in the list leaves it unchanged,
    pulling the last axis drops it, and pulling a middle axis drops it while
    the following axis is expected to move down one index.

    Written with direct assertion calls instead of as a yield-based generator
    test: generator tests are no longer collected by pytest (>= 4.0), so the
    yielded checks would silently never run there.
    """
    a = Axis('x', 0, None)
    b = Axis('y', 1, None)
    c = Axis('z', 2, None)
    # t_pos is expected to compare equal to b; t_neg matches no axis in the list.
    t_pos = Axis('y', 1, None)
    t_neg = Axis('x', 5, None)
    axes = [a, b, c]
    nt.assert_true(t_pos in axes)
    nt.assert_false(t_neg in axes)
    nt.assert_equal(axes, _pull_axis(axes, t_neg))
    nt.assert_equal(axes[:-1], _pull_axis(axes, c))
    # After removing 'y' (index 1), 'z' is expected to shift from index 2 to 1.
    new_axes = [a, Axis('z', 1, None)]
    nt.assert_equal(new_axes, _pull_axis(axes, t_pos))
Ejemplo n.º 3
0
def test__pull_axis():
    """Verify _pull_axis against a simple x/y/z axis list.

    Cases covered: an axis absent from the list (result equals the input),
    the final axis (simply dropped), and a middle axis (dropped, with the
    later axis expected to be re-indexed to fill the gap).
    """
    x_ax = Axis("x", 0, None)
    y_ax = Axis("y", 1, None)
    z_ax = Axis("z", 2, None)
    present = Axis("y", 1, None)  # expected to be found in the list
    absent = Axis("x", 5, None)   # same name as x_ax, different index
    axes = [x_ax, y_ax, z_ax]

    nt.assert_true(present in axes)
    nt.assert_false(absent in axes)

    nt.assert_equal(axes, _pull_axis(axes, absent))

    tail_dropped = axes[:-1]
    nt.assert_equal(tail_dropped, _pull_axis(axes, z_ax))

    reindexed = [x_ax, Axis("z", 1, None)]
    nt.assert_equal(reindexed, _pull_axis(axes, present))
Ejemplo n.º 4
0
def test__pull_axis():
    """Verify _pull_axis against a simple x/y/z axis list.

    Cases covered: an axis absent from the list (result equals the input),
    the final axis (simply dropped), and a middle axis (dropped, with the
    later axis expected to be re-indexed to fill the gap).
    """
    x_ax = Axis('x', 0, None)
    y_ax = Axis('y', 1, None)
    z_ax = Axis('z', 2, None)
    present = Axis('y', 1, None)  # expected to be found in the list
    absent = Axis('x', 5, None)   # same name as x_ax, different index
    axes = [x_ax, y_ax, z_ax]

    nt.assert_true(present in axes)
    nt.assert_false(absent in axes)

    nt.assert_equal(axes, _pull_axis(axes, absent))

    tail_dropped = axes[:-1]
    nt.assert_equal(tail_dropped, _pull_axis(axes, z_ax))

    reindexed = [x_ax, Axis('z', 1, None)]
    nt.assert_equal(reindexed, _pull_axis(axes, present))
Ejemplo n.º 5
0
def assert_axes_correct(d_arr, op, axis):
    """Assert that ``d_arr.<op>(axis=axis)`` returns correctly labeled axes.

    Parameters
    ----------
    d_arr : object
        Presumably a DataArray; must expose a sequence ``.axes`` and a
        method named ``op`` accepting an ``axis`` keyword.
    op : str
        Name of the method to call on ``d_arr``.
    axis
        Axis argument passed to the method; resolved to a position with
        ``_names_to_numbers``.

    For ops listed in the module-level ``accumulations`` the result must keep
    every axis of ``d_arr``; for any other op the selected axis is expected to
    be removed (as computed by ``_pull_axis``).

    Raises
    ------
    AssertionError
        If the result's axes differ in number or value from the expected axes.
    """
    from datarray.datarray import _names_to_numbers, _pull_axis
    opr = getattr(d_arr, op)
    d = opr(axis=axis)
    axis_idx = _names_to_numbers(d_arr.axes, [axis])[0]
    if op not in accumulations:
        axes = _pull_axis(d_arr.axes, d_arr.axes[axis_idx])
    else:
        axes = d_arr.axes
    # zip() stops at the shorter sequence, so without the length check a
    # result with missing or extra axes could still pass the comparison.
    assert len(d.axes) == len(axes) and all(
        ax1 == ax2 for ax1, ax2 in zip(d.axes, axes)
    ), 'mislabeled axes from operation %s' % op
Ejemplo n.º 6
0
def assert_axes_correct(d_arr, op, axis):
    """Assert that ``d_arr.<op>(axis=axis)`` returns correctly labeled axes.

    Parameters
    ----------
    d_arr : object
        Presumably a DataArray; must expose a sequence ``.axes`` and a
        method named ``op`` accepting an ``axis`` keyword.
    op : str
        Name of the method to call on ``d_arr``.
    axis
        Axis argument passed to the method; resolved to a position with
        ``_names_to_numbers``.

    For ops listed in the module-level ``accumulations`` the result must keep
    every axis of ``d_arr``; for any other op the selected axis is expected to
    be removed (as computed by ``_pull_axis``).

    Raises
    ------
    AssertionError
        If the result's axes differ in number or value from the expected axes.
    """
    from datarray.datarray import _names_to_numbers, _pull_axis
    opr = getattr(d_arr, op)
    d = opr(axis=axis)
    axis_idx = _names_to_numbers(d_arr.axes, [axis])[0]
    if op not in accumulations:
        axes = _pull_axis(d_arr.axes, d_arr.axes[axis_idx])
    else:
        axes = d_arr.axes
    # zip() stops at the shorter sequence, so without the length check a
    # result with missing or extra axes could still pass the comparison.
    assert len(d.axes) == len(axes) and all(
        ax1 == ax2 for ax1, ax2 in zip(d.axes, axes)
    ), 'mislabeled axes from operation %s' % op