예제 #1
0
def test_parse_shape_symbolic():
    """Exercise parse_shape on symbolic inputs with known and unknown dims.

    Backends come from collect_test_backends (symbolic, without and with
    layers); keras is skipped. For every other backend, placeholders are
    created with fully known, partially known and fully unknown shapes
    (mxnet.symbol only gets the fully known one). Each placeholder is parsed
    with 'a b c d', symbolic sizes are evaluated against a zero array of
    shape (10, 20, 30, 40), and the parsed sizes are then reused as axis
    lengths for a rearrange whose result is checked.
    """
    backends = collect_test_backends(symbolic=True, layers=False)
    backends += collect_test_backends(symbolic=True, layers=True)
    for backend in backends:
        if backend.framework_name == 'keras':
            # keras shape vars can only be used inside layers, which need a
            # special compile step, so this backend is skipped here
            continue
        print('special shape parsing for', backend.framework_name)
        declared_shapes = [
            [10, 20, 30, 40],
            [10, 20, None, None],
            [None, None, None, None],
        ]
        input_symbols = [backend.create_symbol(dims) for dims in declared_shapes]
        if backend.framework_name in ['mxnet.symbol']:
            # mxnet can't normally run inference, so keep only the fully
            # known shape
            input_symbols = [backend.create_symbol([10, 20, 30, 40])]

        for input_symbol in input_symbols:
            def evaluate(placeholder, bound_symbol=input_symbol):
                # every evaluation gets its own zero-filled (10, 20, 30, 40) feed
                feed = [(bound_symbol, numpy.zeros([10, 20, 30, 40]))]
                return backend.eval_symbol(placeholder, feed)

            sizes = {}
            for axis, size in parse_shape(input_symbol, 'a b c d').items():
                # plain ints are already concrete; anything else is symbolic
                sizes[axis] = size if isinstance(size, int) else evaluate(size)
            print(sizes)

            axis_lengths = parse_shape(input_symbol, 'a b c1 _')
            result = evaluate(rearrange(
                input_symbol,
                'a b (c1 c2) (d1 d2) -> (a b d1) c1 (c2 d2)',
                **axis_lengths,
                d2=2,
            ))
            print(result.shape)
            assert result.shape == (10 * 20 * 20, 30, 1 * 2)
            assert numpy.allclose(result, 0)
예제 #2
0
파일: test_ops.py 프로젝트: ml-lab/einops
    def test6(x):
        """Flatten x, shrink its channel axis, then restore layout via parse_shape.

        The asserts below expect the batch, height and width of x to be
        10, 30 and 40, and a channel axis that halves to 10. Returns the
        restored (b, c2, h, w) tensor.
        """
        flat = rearrange(x, 'b c h w -> (b h w) c')
        # keep every other channel: a stand-in for a dot product that only
        # changes the size of the second axis
        flat = flat[:, ::2]
        assert flat.shape == (10 * 30 * 40, 10)

        restored = rearrange(flat, '(b h w) c2 -> b c2 h w',
                             **parse_shape(x, 'b _ h w'))
        assert restored.shape == (10, 10, 30, 40)
        return restored
예제 #3
0
    def test_parse_shape_symbolic(self, shape):
        """Parse the shape of a symbolic input declared with ``shape``.

        ``shape`` is the declared placeholder shape; it may contain ``None``
        for unknown dims (presumably supplied by a parametrize decorator
        outside this view — confirm). For ``mxnet.symbol`` it is replaced by
        a fully known shape. Symbolic sizes and the rearranged result are
        evaluated against a zero array of shape (10, 20, 30, 40).
        """
        print('special shape parsing for', self.backend.framework_name)
        if self.backend.framework_name in ['mxnet.symbol']:
            # mxnet can't normally run inference
            shape = [10, 20, 30, 40]
        input_symbol = self.backend.create_symbol(shape)

        def evaluate(placeholder):
            # a fresh zero-filled (10, 20, 30, 40) feed for each evaluation
            return self.backend.eval_symbol(
                placeholder, [(input_symbol, numpy.zeros([10, 20, 30, 40]))])

        # Parsed sizes live under their own name so that ``shape`` keeps
        # meaning the declared input shape for the whole test.
        parsed_sizes = {}
        for name, symbol in parse_shape(input_symbol, 'a b c d').items():
            parsed_sizes[name] = symbol if isinstance(symbol, int) else evaluate(symbol)
        print(parsed_sizes)

        result_placeholder = rearrange(
            input_symbol,
            'a b (c1 c2) (d1 d2) -> (a b d1) c1 (c2 d2)',
            **parse_shape(input_symbol, 'a b c1 _'),
            d2=2,
        )
        result = evaluate(result_placeholder)
        print(result.shape)
        assert result.shape == (10 * 20 * 20, 30, 1 * 2)
        assert numpy.allclose(result, 0)
예제 #4
0
 def test_ellipsis(self, static_shape: List[int], shape: List[Optional[int]],
                   pattern: str, expected: Dict[str, int]):
     """Parse ``pattern`` from a symbolic input and compare with ``expected``.

     ``shape`` is the declared placeholder shape (may contain ``None``).
     ``static_shape`` is the concrete shape of the zero array fed whenever
     a parsed size is symbolic rather than a plain int. The pattern
     presumably contains an ellipsis, going by the test name.
     """
     if self.backend.framework_name in ['mxnet.symbol']:
         # mxnet can't normally run inference, so declare the static shape
         declared = static_shape
     else:
         declared = shape
     input_symbol = self.backend.create_symbol(declared)

     def size_of(symbol):
         # ints are concrete already; symbolic sizes are evaluated on zeros
         if isinstance(symbol, int):
             return symbol
         return self.backend.eval_symbol(
             symbol, [(input_symbol, numpy.zeros(static_shape))])

     out_shape = {name: size_of(symbol)
                  for name, symbol in parse_shape(input_symbol, pattern).items()}
     assert out_shape == expected
예제 #5
0
def test_parse_shape_imperative():
    """Compare parse_shape on a numpy array with parse_shape on a backend tensor.

    For every imperative backend (without and with layers), each pattern in
    the table is parsed from a zero array of shape (10, 20, 30, 40) and from
    ``backend.from_numpy`` of that array. Both must equal the expected dict.
    Axes named ``_`` are not reported.
    """
    backends = collect_test_backends(symbolic=False, layers=False)
    backends += collect_test_backends(symbolic=False, layers=True)
    cases = [
        ('a b c d', dict(a=10, b=20, c=30, d=40)),
        ('_ _ _ _', dict()),
        ('_ _ _ hello', dict(hello=40)),
        ('_ _ a1 a1a111a', dict(a1=30, a1a111a=40)),
    ]
    for backend in backends:
        print('Shape parsing for ', backend.framework_name)
        x = numpy.zeros([10, 20, 30, 40])
        for position, (pattern, expected) in enumerate(cases):
            via_numpy = parse_shape(x, pattern)
            via_backend = parse_shape(backend.from_numpy(x), pattern)
            assert via_numpy == via_backend == expected
            if position == 0:
                # a single wrong size must make the comparison fail
                assert via_numpy != dict(a=1, b=20, c=30, d=40) != via_backend
예제 #6
0
 def test_underscore_one(self):
     """Name only the last axis; numpy and backend parsing must agree."""
     pattern = '_ _ _ hello'
     from_array = parse_shape(self.x, pattern)
     from_tensor = parse_shape(self.backend.from_numpy(self.x), pattern)
     assert from_array == from_tensor == dict(hello=40)
예제 #7
0
 def test_parse_shape_imperative(self):
     """Name all four axes of ``self.x``; numpy and backend parsing must agree.

     The expected sizes imply ``self.x`` has shape (10, 20, 30, 40).
     """
     print('Shape parsing for ', self.backend.framework_name)
     pattern = 'a b c d'
     via_numpy = parse_shape(self.x, pattern)
     via_backend = parse_shape(self.backend.from_numpy(self.x), pattern)
     assert via_numpy == via_backend == dict(a=10, b=20, c=30, d=40)
     # one wrong size must break equality with either result
     assert via_numpy != dict(a=1, b=20, c=30, d=40) != via_backend
예제 #8
0
 def test_ellipsis(self, shape: List[int], pattern: str,
                   expected: Dict[str, int]):
     """Parse ``pattern`` from a ones array and from its backend tensor.

     Both results must equal ``expected``.
     """
     data = numpy.ones(shape)
     from_array = parse_shape(data, pattern)
     from_tensor = parse_shape(self.backend.from_numpy(data), pattern)
     assert from_array == from_tensor == expected
예제 #9
0
    def test_repeating(self):
        """A pattern that repeats axis names must raise EinopsError.

        The check runs for both the numpy array and its backend tensor.
        """
        # factories, so that from_numpy runs inside the assert_raises block
        input_factories = (
            lambda: self.x,
            lambda: self.backend.from_numpy(self.x),
        )
        for make_input in input_factories:
            with assert_raises(einops.EinopsError):
                parse_shape(make_input(), 'a a b b')
예제 #10
0
 def test_underscore_several(self):
     """Skip the first two axes with ``_``; numpy and backend parsing must agree."""
     pattern = '_ _ a1 a1a111a'
     from_array = parse_shape(self.x, pattern)
     from_tensor = parse_shape(self.backend.from_numpy(self.x), pattern)
     assert from_array == from_tensor == dict(a1=30, a1a111a=40)