예제 #1
0
def test_sigmerger_check():
    """Exercise ``SigMerger.check`` on mergeable and non-mergeable lists.

    Each case builds a list of signals and asserts whether ``check``
    accepts it. The cases cover scalars, dense arrays along either axis,
    shape and dtype mismatches, and views of shared or distinct bases.
    They also cover sparse signals and the same signal listed twice.
    """
    check = SigMerger.check

    def dense(shape, **kwargs):
        # A new signal backed by an uninitialized numpy array.
        return Signal(np.empty(shape, **kwargs))

    # Scalars: equal values are mergeable, differing values are not.
    assert check([Signal(0), Signal(0)])
    assert not check([Signal(0), Signal(1)])

    # (1, 2) stacked on (2, 2) works along the default (first) axis.
    assert check([dense((1, 2)), dense((2, 2))])

    # (2, 1) next to (2, 2) only lines up along the second axis.
    for axis, mergeable in ((1, True), (0, False)):
        merged = check([dense((2, 1)), dense((2, 2))], axis=axis)
        assert bool(merged) is mergeable

    # Signals of different rank are rejected.
    assert not check([dense((2, )), dense((2, 2))])

    # So are signals whose dtypes differ.
    assert not check([dense(2, dtype=int), dense(2, dtype=float)])

    base_a = dense(5)
    base_b = dense(5)

    # A base signal cannot be merged with a view of itself.
    assert not check([base_a, base_a[:3]])
    # Views taken from two different bases are rejected...
    assert not check([base_a[:2], base_b[2:]])
    # ...while contiguous slices of a single base are accepted.
    assert check([base_a[:2], base_a[2:]])

    # Sparse-matrix signals are not mergeable.
    sparse = [Signal(SparseMatrix([[0, 0]], 1.0, (1, 1))) for _ in range(2)]
    assert not check(sparse)

    # Listing the very same signal object twice is rejected.
    repeated = Signal(0)
    assert not check([repeated, repeated])
예제 #2
0
def test_transforms():
    """Check init args, ``repr`` and ``str`` for the transform classes.

    Covers Dense, Convolution, Sparse, SparseMatrix, ChannelShape and
    NoTransform. Dense and Sparse are compared against a literal repr
    instead of going through ``check_repr``, because their matrices are
    usually too big for that.
    """
    # --- Dense: repr reports only the shape, not the init matrix.
    check_init_args(Dense, ["shape", "init"])
    assert repr(Dense((1, 2), init=[[1, 1]])) == "Dense(shape=(1, 2))"

    # --- Convolution: repr includes the optional arguments set below.
    conv_params = ["n_filters", "input_shape", "kernel_size", "strides",
                   "padding", "channels_last", "init"]
    check_init_args(Convolution, conv_params)
    conv_cases = [
        (
            dict(n_filters=3, input_shape=(1, 2, 3)),
            "Convolution(n_filters=3, input_shape=(1, 2, 3))",
        ),
        (
            dict(n_filters=3, input_shape=(1, 2, 3), kernel_size=(3, 2)),
            "Convolution(n_filters=3, input_shape=(1, 2, 3), "
            "kernel_size=(3, 2))",
        ),
        (
            dict(n_filters=3, input_shape=(1, 2, 3), channels_last=False),
            "Convolution(n_filters=3, input_shape=(1, 2, 3), "
            "channels_last=False)",
        ),
    ]
    # Run check_repr on every configuration first, then pin the exact text.
    for conv_kwargs, _ in conv_cases:
        check_repr(Convolution(**conv_kwargs))
    for conv_kwargs, expected in conv_cases:
        assert repr(Convolution(**conv_kwargs)) == expected

    # --- Sparse: like Dense, repr reports only the shape.
    check_init_args(Sparse, ["shape", "indices", "init"])
    sparse_transforms = (
        Sparse((1, 1), indices=[[1, 1], [1, 1]]),
        Sparse((1, 1), indices=[[1, 1], [1, 1], [1, 1]], init=2),
    )
    for transform in sparse_transforms:
        assert repr(transform) == "Sparse(shape=(1, 1))"

    # --- SparseMatrix
    check_init_args(SparseMatrix, ["indices", "data", "shape"])
    check_repr(
        SparseMatrix(indices=[[1, 2], [3, 4]], data=[5, 6], shape=(7, 8)))
    matrix_repr = repr(SparseMatrix(((1, 2), (3, 4)), (5, 6), (7, 8)))
    matrix_expected = (
        "SparseMatrix(indices=array([[1, 2],\n       [3, 4]]), "
        "data=array([5, 6]), shape=(7, 8))")
    # Drop an explicit ", dtype=int64" before comparing. It presumably
    # appears where int64 is not numpy's default integer, so stripping it
    # keeps the expected text platform-independent.
    assert matrix_repr.replace(", dtype=int64", "") == matrix_expected

    # --- ChannelShape
    check_init_args(ChannelShape, ["shape", "channels_last"])
    check_repr(ChannelShape(shape=(1, 2, 3), channels_last=True))
    default_layout = ChannelShape((1, 2, 3))
    channels_first = ChannelShape((1, 2, 3), channels_last=False)
    assert repr(default_layout) == (
        "ChannelShape(shape=(1, 2, 3), channels_last=True)")
    assert repr(channels_first) == (
        "ChannelShape(shape=(1, 2, 3), channels_last=False)")
    # str tags the channel axis with "ch=" in whichever position it sits:
    # last by default, first when channels_last=False.
    assert str(default_layout) == "(1, 2, ch=3)"
    assert str(channels_first) == "(ch=1, 2, 3)"

    # --- NoTransform
    check_init_args(NoTransform, ["size_in"])
    check_repr(NoTransform(size_in=1))
    for size_in in range(2):
        expected = "NoTransform(size_in=%d)" % size_in
        assert repr(NoTransform(size_in)) == expected