Example #1
        def test_specify_shape(self):
            dtype = self.dtype
            if dtype is None:
                dtype = aesara.config.floatX

            rng = np.random.RandomState(utt.fetch_seed())
            x1_1 = np.asarray(rng.uniform(1, 2, [4, 2]), dtype=dtype)
            x1_1 = self.cast_value(x1_1)
            x1_2 = np.asarray(rng.uniform(1, 2, [4, 2]), dtype=dtype)
            x1_2 = self.cast_value(x1_2)
            x2 = np.asarray(rng.uniform(1, 2, [4, 3]), dtype=dtype)
            x2 = self.cast_value(x2)

            # Test that we can replace with values of the same shape
            x1_shared = self.shared_constructor(x1_1)
            x1_specify_shape = specify_shape(x1_shared, x1_1.shape)
            x1_shared.set_value(x1_2)
            assert np.allclose(
                self.ref_fct(x1_shared.get_value(borrow=True)), self.ref_fct(x1_2)
            )
            shape_op_fct = aesara.function([], x1_shared.shape)
            topo = shape_op_fct.maker.fgraph.toposort()
            if aesara.config.mode != "FAST_COMPILE":
                assert len(topo) == 3
                assert isinstance(topo[0].op, Shape_i)
                assert isinstance(topo[1].op, Shape_i)
                assert isinstance(topo[2].op, MakeVector)

            # Test that we forward the input
            specify_shape_fct = aesara.function([], x1_specify_shape)
            assert np.all(self.ref_fct(specify_shape_fct()) == self.ref_fct(x1_2))
            topo_specify = specify_shape_fct.maker.fgraph.toposort()
            assert len(topo_specify) == 2

            # Test that we put the shape info into the graph
            shape_constant_fct = aesara.function([], x1_specify_shape.shape)
            assert np.all(shape_constant_fct() == shape_op_fct())
            topo_cst = shape_constant_fct.maker.fgraph.toposort()
            if aesara.config.mode != "FAST_COMPILE":
                assert len(topo_cst) == 1
                assert topo_cst[0].op == aesara.compile.function.types.deep_copy_op

            # Test that we can take the grad.
            if aesara.sparse.enable_sparse and isinstance(
                x1_specify_shape.type, aesara.sparse.SparseType
            ):
                # SparseVariables don't support sum for now.
                assert not hasattr(x1_specify_shape, "sum")
            else:
                shape_grad = aesara.gradient.grad(x1_specify_shape.sum(), x1_shared)
                shape_constant_fct_grad = aesara.function([], shape_grad)
                # aesara.printing.debugprint(shape_constant_fct_grad)
                shape_constant_fct_grad()

            # Test that we can replace the value with one of a different shape;
            # this raises an error in some of the cases below, but not in all of them
            specify_shape_fct()
            x1_shared.set_value(x2)
            with pytest.raises(AssertionError):
                specify_shape_fct()

            # No assertion is raised, as the Op is removed from the graph
            # when optimizations are applied
            if aesara.config.mode not in ["FAST_COMPILE", "DebugMode", "DEBUG_MODE"]:
                shape_constant_fct()
            else:
                with pytest.raises(AssertionError):
                    shape_constant_fct()
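
A minimal standalone sketch of the behavior exercised above (hedged: it assumes a
working Aesara install with default settings; the variable names are illustrative,
not from the test suite):

    import numpy as np
    import aesara
    from aesara.tensor import specify_shape

    x = aesara.shared(np.ones((4, 2), dtype=aesara.config.floatX))
    x_fixed = specify_shape(x, (4, 2))  # declare the expected shape in the graph
    f = aesara.function([], x_fixed)
    f()  # fine: the shared value matches the declared shape

    x.set_value(np.ones((4, 3), dtype=aesara.config.floatX))
    try:
        f()  # the tests above expect this to fail with an AssertionError
    except AssertionError:
        print("shape mismatch detected at run time")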
Example #2
        def test_specify_shape_inplace(self):
            # Test that specify_shape does not prevent the insertion of inplace Ops

            dtype = self.dtype
            if dtype is None:
                dtype = aesara.config.floatX

            rng = np.random.default_rng(utt.fetch_seed())
            a = np.asarray(rng.uniform(1, 2, [40, 40]), dtype=dtype)
            a = self.cast_value(a)
            a_shared = self.shared_constructor(a)
            b = np.asarray(rng.uniform(1, 2, [40, 40]), dtype=dtype)
            b = self.cast_value(b)
            b_shared = self.shared_constructor(b)
            s = np.zeros((40, 40), dtype=dtype)
            s = self.cast_value(s)
            s_shared = self.shared_constructor(s)
            f = aesara.function(
                [],
                updates=[(s_shared,
                          aesara.tensor.dot(a_shared, b_shared) + s_shared)],
            )
            topo = f.maker.fgraph.toposort()
            f()
            # [Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 2e-06)]
            if aesara.config.mode != "FAST_COMPILE":
                assert sum(
                    node.op.__class__.__name__ in ["Gemm", "GpuGemm", "StructuredDot"]
                    for node in topo
                ) == 1
                assert all(
                    node.op == aesara.tensor.blas.gemm_inplace
                    for node in topo
                    if isinstance(node.op, aesara.tensor.blas.Gemm)
                )
                assert all(
                    node.op.inplace
                    for node in topo
                    if node.op.__class__.__name__ == "GpuGemm"
                )
            # There is no inplace gemm for sparse
            # assert all(node.op.inplace for node in topo if node.op.__class__.__name__ == "StructuredDot")
            s_shared_specify = specify_shape(
                s_shared,
                s_shared.get_value(borrow=True).shape)

            # now test with the specify shape op in the output
            f = aesara.function(
                [],
                s_shared.shape,
                updates=[
                    (s_shared,
                     aesara.tensor.dot(a_shared, b_shared) + s_shared_specify)
                ],
            )
            topo = f.maker.fgraph.toposort()
            shp = f()
            assert np.all(shp == (40, 40))
            if aesara.config.mode != "FAST_COMPILE":
                assert sum(
                    node.op.__class__.__name__ in ["Gemm", "GpuGemm", "StructuredDot"]
                    for node in topo
                ) == 1
                assert all(
                    node.op == aesara.tensor.blas.gemm_inplace
                    for node in topo
                    if isinstance(node.op, aesara.tensor.blas.Gemm)
                )
                assert all(
                    node.op.inplace
                    for node in topo
                    if node.op.__class__.__name__ == "GpuGemm"
                )
            # now test with the specify shape op in the inputs and outputs
            a_shared = specify_shape(a_shared,
                                     a_shared.get_value(borrow=True).shape)
            b_shared = specify_shape(b_shared,
                                     b_shared.get_value(borrow=True).shape)

            f = aesara.function(
                [],
                s_shared.shape,
                updates=[
                    (s_shared,
                     aesara.tensor.dot(a_shared, b_shared) + s_shared_specify)
                ],
            )
            topo = f.maker.fgraph.toposort()
            shp = f()
            assert np.all(shp == (40, 40))
            if aesara.config.mode != "FAST_COMPILE":
                assert sum(
                    node.op.__class__.__name__ in ["Gemm", "GpuGemm", "StructuredDot"]
                    for node in topo
                ) == 1
                assert all(
                    node.op == aesara.tensor.blas.gemm_inplace
                    for node in topo
                    if isinstance(node.op, aesara.tensor.blas.Gemm)
                )
                assert all(
                    node.op.inplace
                    for node in topo
                    if node.op.__class__.__name__ == "GpuGemm"
                )
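
For reference, a standalone sketch of the update pattern compiled above (hedged:
whether the dot/add/update ends up as a single inplace Gemm node depends on the
mode, the backend, and the available BLAS; names are illustrative):

    import numpy as np
    import aesara
    import aesara.tensor as at

    a = aesara.shared(np.ones((40, 40), dtype=aesara.config.floatX))
    b = aesara.shared(np.ones((40, 40), dtype=aesara.config.floatX))
    s = aesara.shared(np.zeros((40, 40), dtype=aesara.config.floatX))

    # s <- dot(a, b) + s, expressed as an update on the shared variable
    f = aesara.function([], updates=[(s, at.dot(a, b) + s)])
    f()
    # Inspect the optimized graph; the test above checks that exactly one
    # Gemm-like node is present and that it runs in place.
    print([type(node.op).__name__ for node in f.maker.fgraph.toposort()])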
Example #3
        def test_specifyshape(self):
            self.check_rop_lop(specify_shape(self.x, self.in_shape), self.in_shape)
Example #4
        def test_specify_shape_partial(self):
            dtype = self.dtype
            if dtype is None:
                dtype = aesara.config.floatX

            rng = np.random.default_rng(utt.fetch_seed())
            x1_1 = np.asarray(rng.uniform(1, 2, [4, 2]), dtype=dtype)
            x1_1 = self.cast_value(x1_1)
            x1_2 = np.asarray(rng.uniform(1, 2, [4, 2]), dtype=dtype)
            x1_2 = self.cast_value(x1_2)
            x2 = np.asarray(rng.uniform(1, 2, [5, 2]), dtype=dtype)
            x2 = self.cast_value(x2)

            # Test that we can replace with values of the same shape
            x1_shared = self.shared_constructor(x1_1)
            x1_specify_shape = specify_shape(
                x1_shared,
                (aet.as_tensor_variable(x1_1.shape[0]), x1_shared.shape[1]),
            )
            x1_shared.set_value(x1_2)
            assert np.allclose(self.ref_fct(x1_shared.get_value(borrow=True)),
                               self.ref_fct(x1_2))
            shape_op_fct = aesara.function([], x1_shared.shape)
            topo = shape_op_fct.maker.fgraph.toposort()
            shape_op_fct()
            if aesara.config.mode != "FAST_COMPILE":
                assert len(topo) == 3
                assert isinstance(topo[0].op, Shape_i)
                assert isinstance(topo[1].op, Shape_i)
                assert isinstance(topo[2].op, MakeVector)

            # Test that we forward the input
            specify_shape_fct = aesara.function([], x1_specify_shape)
            specify_shape_fct()
            # aesara.printing.debugprint(specify_shape_fct)
            assert np.all(
                self.ref_fct(specify_shape_fct()) == self.ref_fct(x1_2))
            topo_specify = specify_shape_fct.maker.fgraph.toposort()
            if aesara.config.mode != "FAST_COMPILE":
                assert len(topo_specify) == 4

            # Test that we put the shape info into the graph
            shape_constant_fct = aesara.function([], x1_specify_shape.shape)
            # aesara.printing.debugprint(shape_constant_fct)
            assert np.all(shape_constant_fct() == shape_op_fct())
            topo_cst = shape_constant_fct.maker.fgraph.toposort()
            if aesara.config.mode != "FAST_COMPILE":
                assert len(topo_cst) == 2

            # Test that we can replace the value with one of a different shape;
            # this raises an error in some of the cases below, but not in all of them
            x1_shared.set_value(x2)
            with pytest.raises(AssertionError):
                specify_shape_fct()

            # No assertion will be raised as the Op is removed from the graph
            if aesara.config.mode not in [
                    "FAST_COMPILE", "DebugMode", "DEBUG_MODE"
            ]:
                shape_constant_fct()
            else:
                with pytest.raises(AssertionError):
                    shape_constant_fct()
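
Finally, a minimal sketch of partial shape specification as exercised above
(hedged: it mixes a constant and a symbolic dimension exactly as the test does;
names are illustrative):

    import numpy as np
    import aesara
    import aesara.tensor as aet
    from aesara.tensor import specify_shape

    v = aesara.shared(np.ones((4, 2), dtype=aesara.config.floatX))
    # Pin only the first dimension to a constant; leave the second symbolic.
    v_fixed = specify_shape(v, (aet.as_tensor_variable(4), v.shape[1]))
    g = aesara.function([], v_fixed.shape)
    print(g())  # expected to print [4 2]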