예제 #1
0
    def merge(ops):
        """Merge several ``DotInc`` operators into a single operator.

        If every operator reads the same ``X`` signal, the ``A`` and ``Y``
        signals are merged and a single ``DotInc`` is returned. Otherwise all
        ``X`` signals must be distinct objects; ``X`` and ``Y`` are merged and
        the ``A`` matrices become the diagonal blocks of a block-sparse (BSR)
        matrix applied by a single ``BsrDotInc``.

        Parameters
        ----------
        ops : list
            Non-empty list of operators exposing ``A``, ``X``, ``Y`` signals
            and a ``tag``.

        Returns
        -------
        tuple
            ``(merged_op, replacements)``, where ``replacements`` is the
            combined mapping from original signals to the views of the merged
            signals that replace them.

        Raises
        ------
        NotImplementedError
            In the BSR case, if each ``A`` is a scalar or a vector
            (``A.ndim < 2``).
        """
        # Simple merge if all X are the same.
        first_X = ops[0].X
        if all(o.X is first_X for o in ops):
            A, A_sigr = SigMerger.merge([o.A for o in ops])
            Y, Y_sigr = SigMerger.merge([o.Y for o in ops])
            return (op.DotInc(A, first_X, Y),
                    Merger.merge_dicts(A_sigr, Y_sigr))

        # BSR merging requires every X to be a distinct signal object;
        # comparing identities through a set is O(n) rather than O(n**2).
        assert len({id(o.X) for o in ops}) == len(ops)

        # BSR merge if X differ
        X, X_sigr = SigMerger.merge([o.X for o in ops])
        Y, Y_sigr = SigMerger.merge([o.Y for o in ops])

        # Construct sparse A representation: data[i] is the i-th block.
        data = np.array([o.A.initial_value for o in ops], dtype=rc.float_dtype)
        if data.ndim < 3:
            # ``data`` stacks the A's along a new leading axis, so
            # data.ndim < 3 means each A is a scalar or a vector.
            raise NotImplementedError("A.ndim should be >= 2")
        # One block per block-row, placed on the diagonal (block i -> column i).
        indptr = np.arange(len(ops) + 1, dtype=rc.int_dtype)
        indices = np.arange(len(ops), dtype=rc.int_dtype)
        name = "bsr_merged<{first}, ..., {last}>".format(first=ops[0].A.name,
                                                         last=ops[-1].A.name)
        readonly = all(o.A.readonly for o in ops)
        A = Signal(data, name=name, readonly=readonly)
        # Size in bytes of one block inside the merged buffer.
        block_nbytes = A.itemsize * np.prod(A.shape[1:])
        A_sigr = {}
        for i, o in enumerate(ops):
            s = o.A
            A_sigr[s] = Signal(
                data[i],
                name="%s[%i]" % (s.name, i),
                base=A,
                offset=i * block_nbytes,
            )
            # Exact equality, but NaN entries compare equal.
            assert np.allclose(s.initial_value,
                               A_sigr[s].initial_value,
                               atol=0,
                               rtol=0,
                               equal_nan=True)
            # Scalars were rejected above, so shapes must match exactly.
            assert s.shape == A_sigr[s].shape

        reshape = op.reshape_dot(
            ops[0].A.initial_value,
            ops[0].X.initial_value,
            ops[0].Y.initial_value,
            tag=ops[0].tag,
        )
        return (
            op.BsrDotInc(A,
                         X,
                         Y,
                         indices=indices,
                         indptr=indptr,
                         reshape=reshape),
            Merger.merge_dicts(X_sigr, Y_sigr, A_sigr),
        )
예제 #2
0
    def merge(ops):
        """Merge several ``DotInc`` operators into a single operator.

        If every operator reads the same ``X`` signal, the ``A`` and ``Y``
        signals are merged and a single ``DotInc`` is returned. Otherwise all
        ``X`` signals must be distinct objects; ``X`` and ``Y`` are merged and
        the ``A`` values become the diagonal blocks of a block-sparse (BSR)
        matrix applied by a single ``BsrDotInc``.

        Parameters
        ----------
        ops : list
            Non-empty list of operators exposing ``A``, ``X``, ``Y`` signals
            and a ``tag``.

        Returns
        -------
        tuple
            ``(merged_op, replacements)``, where ``replacements`` is the
            combined mapping from original signals to the views of the merged
            signals that replace them.
        """
        # Simple merge if all X are the same.
        first_X = ops[0].X
        if all(o.X is first_X for o in ops):
            A, A_sigr = SigMerger.merge([o.A for o in ops])
            Y, Y_sigr = SigMerger.merge([o.Y for o in ops])
            return (operator.DotInc(A, first_X, Y),
                    Merger.merge_dicts(A_sigr, Y_sigr))

        # BSR merging requires every X to be a distinct signal object;
        # comparing identities through a set is O(n) rather than O(n**2).
        assert len({id(o.X) for o in ops}) == len(ops)

        # BSR merge if X differ
        X, X_sigr = SigMerger.merge([o.X for o in ops])
        Y, Y_sigr = SigMerger.merge([o.Y for o in ops])

        # Construct sparse A representation: data[i] is the i-th block,
        # padded so that each block is at least 2-D.
        data = np.array([o.A.initial_value for o in ops])
        if data.ndim == 1:
            # scalar A -> 1x1 blocks
            data = data.reshape((data.size, 1, 1))
        elif data.ndim == 2:
            # vector A of length k -> k x 1 blocks
            # NOTE(review): the per-op views built below then have shape
            # (k, 1) while each original A has shape (k,), which the shape
            # assertion further down appears to reject -- confirm.
            data = data.reshape(data.shape + (1, ))
        # One block per block-row, placed on the diagonal (block i -> column i).
        indptr = np.arange(len(ops) + 1, dtype=int)
        indices = np.arange(len(ops), dtype=int)
        name = 'bsr_merged<{first}, ..., {last}>'.format(first=ops[0].A.name,
                                                         last=ops[-1].A.name)
        readonly = all(o.A.readonly for o in ops)
        A = Signal(data, name=name, readonly=readonly)
        # Size in bytes of one block inside the merged buffer.
        block_nbytes = A.itemsize * np.prod(A.shape[1:])
        A_sigr = {}
        for i, o in enumerate(ops):
            s = o.A
            A_sigr[s] = Signal(data[i],
                               name="%s[%i]" % (s.name, i),
                               base=A,
                               offset=i * block_nbytes)
            # Exact equality, but NaN entries compare equal (a plain ``==``
            # check fails whenever A contains NaN).
            assert np.allclose(s.initial_value, A_sigr[s].initial_value,
                               atol=0, rtol=0, equal_nan=True)
            # Scalar A's are stored as 1x1 blocks, so () vs (1, 1) is allowed.
            assert s.shape == A_sigr[s].shape or (s.shape == () and
                                                  A_sigr[s].shape == (1, 1))

        reshape = operator.reshape_dot(ops[0].A.initial_value,
                                       ops[0].X.initial_value,
                                       ops[0].Y.initial_value,
                                       tag=ops[0].tag)
        return (operator.BsrDotInc(A,
                                   X,
                                   Y,
                                   indices=indices,
                                   indptr=indptr,
                                   reshape=reshape),
                Merger.merge_dicts(X_sigr, Y_sigr, A_sigr))
예제 #3
0
파일: optimizer.py 프로젝트: nengo/nengo
    def merge(ops):
        """Merge several ``DotInc`` operators into a single operator.

        If every operator reads the same ``X`` signal, the ``A`` and ``Y``
        signals are merged and a single ``DotInc`` is returned. Otherwise all
        ``X`` signals must be distinct objects; ``X`` and ``Y`` are merged and
        the ``A`` values become the diagonal blocks of a block-sparse (BSR)
        matrix applied by a single ``BsrDotInc``.

        Parameters
        ----------
        ops : list
            Non-empty list of operators exposing ``A``, ``X``, ``Y`` signals
            and a ``tag``.

        Returns
        -------
        tuple
            ``(merged_op, replacements)``, where ``replacements`` is the
            combined mapping from original signals to the views of the merged
            signals that replace them.
        """
        # Simple merge if all X are the same.
        first_X = ops[0].X
        if all(o.X is first_X for o in ops):
            A, A_sigr = SigMerger.merge([o.A for o in ops])
            Y, Y_sigr = SigMerger.merge([o.Y for o in ops])
            return (operator.DotInc(A, first_X, Y),
                    Merger.merge_dicts(A_sigr, Y_sigr))

        # BSR merging requires every X to be a distinct signal object;
        # comparing identities through a set is O(n) rather than O(n**2).
        assert len({id(o.X) for o in ops}) == len(ops)

        # BSR merge if X differ
        X, X_sigr = SigMerger.merge([o.X for o in ops])
        Y, Y_sigr = SigMerger.merge([o.Y for o in ops])

        # Construct sparse A representation: data[i] is the i-th block,
        # padded so that each block is at least 2-D.
        data = np.array([o.A.initial_value for o in ops])
        if data.ndim == 1:
            # scalar A -> 1x1 blocks
            data = data.reshape((data.size, 1, 1))
        elif data.ndim == 2:
            # vector A of length k -> k x 1 blocks
            # NOTE(review): the per-op views built below then have shape
            # (k, 1) while each original A has shape (k,), which the shape
            # assertion further down appears to reject -- confirm.
            data = data.reshape(data.shape + (1,))
        # One block per block-row, placed on the diagonal (block i -> column i).
        indptr = np.arange(len(ops) + 1, dtype=int)
        indices = np.arange(len(ops), dtype=int)
        name = 'bsr_merged<{first}, ..., {last}>'.format(
            first=ops[0].A.name, last=ops[-1].A.name)
        readonly = all(o.A.readonly for o in ops)
        A = Signal(data, name=name, readonly=readonly)
        # Size in bytes of one block inside the merged buffer.
        block_nbytes = A.itemsize * np.prod(A.shape[1:])
        A_sigr = {}
        for i, o in enumerate(ops):
            s = o.A
            A_sigr[s] = Signal(data[i], name="%s[%i]" % (s.name, i), base=A,
                               offset=i * block_nbytes)
            # Exact equality, but NaN entries compare equal (a plain ``==``
            # check fails whenever A contains NaN).
            assert np.allclose(s.initial_value, A_sigr[s].initial_value,
                               atol=0, rtol=0, equal_nan=True)
            # Scalar A's are stored as 1x1 blocks, so () vs (1, 1) is allowed.
            assert s.shape == A_sigr[s].shape or (
                s.shape == () and A_sigr[s].shape == (1, 1))

        reshape = operator.reshape_dot(
            ops[0].A.initial_value, ops[0].X.initial_value,
            ops[0].Y.initial_value, tag=ops[0].tag)
        return (
            operator.BsrDotInc(
                A, X, Y, indices=indices, indptr=indptr, reshape=reshape),
            Merger.merge_dicts(X_sigr, Y_sigr, A_sigr))
예제 #4
0
def test_reshape_dot(rng):
    """Exercise each shape-handling branch of ``reshape_dot``.

    Every case is ``(A, X, Y, expected)``. ``expected`` is either the boolean
    that ``reshape_dot`` must return, or ``BuildError`` when the call must
    raise a shape-mismatch ``BuildError``. Cases run in the listed order.
    """
    scalar = np.array(1)
    v0, v1, v2, v3 = (np.ones(n) for n in range(4))
    m11 = np.ones((1, 1))
    m23 = np.ones((2, 3))
    m33 = np.ones((3, 3))

    cases = [
        # branch: A is a scalar
        (scalar, scalar, scalar, True),
        (scalar, v2, v2, False),
        (scalar, v3, v1, BuildError),
        # branch: X is a scalar
        (v1, scalar, v1, True),
        (v2, scalar, v2, False),
        (v2, scalar, v1, BuildError),
        # branch: X is 1-D
        (v0, v0, v0, False),
        (v1, v1, v1, True),
        (v1, v1, scalar, True),
        (v2, v2, v2, False),
        (m23, v3, v2, False),
        (m23, v2, v2, BuildError),
        (m23, v3, v1, BuildError),
        # branch: everything else (X has two or more dimensions)
        (m11, m11, m11, True),
        (m23, m33, m23, False),
        (m11, m23, m23, BuildError),
        (m23, m33, m33, BuildError),
        (m23, m33, v2, BuildError),
    ]

    for A, X, Y, expected in cases:
        if expected is BuildError:
            with pytest.raises(BuildError, match="shape mismatch"):
                reshape_dot(A=A, X=X, Y=Y)
        else:
            assert reshape_dot(A=A, X=X, Y=Y) is expected