def test_disconnected_paths(self):
    """Check that a gradient taken only through a disconnected path raises.

    ``grad`` must raise ``DisconnectedInputError`` when the cost depends on
    the ``wrt`` variable solely through ``disconnected_grad``. It must succeed
    when at least one other, connected path exists.
    """
    x = matrix("x")

    # The only path from x to the cost goes through disconnected_grad, so
    # this MUST raise a DisconnectedInputError.
    # This also raises an additional warning from gradients.py.
    with pytest.raises(DisconnectedInputError):
        grad(disconnected_grad(x).sum(), x)

    # x also reaches the cost directly, so this MUST NOT raise.
    grad((x + disconnected_grad(x)).sum(), x)

    a = matrix("a")
    b = matrix("b")
    y = a + disconnected_grad(b)

    # b reaches y only through disconnected_grad: MUST raise.
    # This also raises an additional warning from gradients.py.
    with pytest.raises(DisconnectedInputError):
        grad(y.sum(), b)

    # a is connected to y directly: MUST NOT raise.
    grad(y.sum(), a)
def test_connection_pattern(self):
    """The op built by ``gradient.disconnected_grad`` reports ``[[False]]``.

    That is, its connection pattern marks the single output as not connected
    to the single input.
    """
    source = aesara.tensor.matrix("x")
    blocked = gradient.disconnected_grad(source)
    apply_node = blocked.owner
    pattern = apply_node.op.connection_pattern(apply_node)
    assert pattern == [[False]]
def test_op_removed(self):
    """Check that the disconnected_grad op does not survive compilation.

    After compiling ``x * disconnected_grad(x)``, the function graph must
    contain no node whose op is ``disconnected_grad_``.
    """
    x = matrix("x")
    y = x * disconnected_grad(x)
    f = aesara.function([x], y)
    # Compare against the Op instance ``disconnected_grad_``, not against
    # ``disconnected_grad``. The latter is a wrapper function, so it can never
    # appear as a node's ``op`` and checking for it would not test anything.
    assert disconnected_grad_ not in [node.op for node in f.maker.fgraph.toposort()]
def test_grad(self):
    """Compare aesara's gradient of each expression to a hand-derived one.

    Each case pairs an expression containing ``disconnected_grad`` with its
    expected gradient w.r.t. ``x``. The wrapped sub-expression is treated as a
    constant during differentiation. Both sides are compiled and evaluated
    on the same random input.
    """
    data = np.asarray(self.rng.randn(5, 5), dtype=config.floatX)
    x = matrix("x")

    cases = (
        # d/dx [x * c] with c = x held constant -> c, i.e. x
        (x * disconnected_grad(x), x),
        # d/dx [x * c] with c = exp(x) held constant -> exp(x)
        (x * disconnected_grad(exp(x)), exp(x)),
        # d/dx [x**2 * c] with c = x held constant -> 2*x*c, i.e. 2*x**2
        (x**2 * disconnected_grad(x), 2 * x**2),
    )

    for expression, expected in cases:
        actual_fn = aesara.function(
            [x], grad(expression.sum(), x), on_unused_input="ignore"
        )
        expected_fn = aesara.function([x], expected, on_unused_input="ignore")
        assert np.allclose(actual_fn(data), expected_fn(data))
def incsubtensor_logp(op, var, rvs_to_values, indexed_rv_var, rv_values, *indices, **kwargs):
    """Return ``logpt`` of ``indexed_rv_var`` at a value patched with ``rv_values``.

    Steps:

    1. A clone of the graph producing ``indexed_rv_var`` is built.
    2. The clone is wrapped in ``disconnected_grad``.
    3. ``rv_values`` is written into it at the positions given by ``indices``
       (interpreted through ``op.idx_list`` when the op has one).
    4. The result is used as the value at which ``logpt(indexed_rv_var, ...)``
       is evaluated.

    ``var`` and ``rvs_to_values`` are accepted but not used in this body.
    NOTE(review): presumably they exist to satisfy a shared logp-dispatch
    signature; the registration decorators are not visible here, so confirm.

    Extra keyword arguments are forwarded unchanged to ``logpt``.
    """
    index = indices_from_subtensor(getattr(op, "idx_list", None), indices)

    # Clone the graph that produces indexed_rv_var. Constants are left out of
    # the explicit input list, and inputs are not copied (copy_inputs=False).
    non_constant_inputs = tuple(
        inp
        for inp in graph_inputs((indexed_rv_var,))
        if not isinstance(inp, Constant)
    )
    _, cloned_outputs = clone(
        non_constant_inputs,
        (indexed_rv_var,),
        copy_inputs=False,
        copy_orphans=False,
    )
    (base_clone,) = cloned_outputs

    # Block gradients through the clone, then overwrite the indexed entries.
    patched_values = at.set_subtensor(disconnected_grad(base_clone)[index], rv_values)
    return logpt(indexed_rv_var, patched_values, **kwargs)
def test_connection_pattern(self):
    """The ``disconnected_grad`` op reports its output as disconnected.

    The expected pattern ``[[False]]`` means one input, one output, and no
    connection between them.
    """
    # NOTE(review): another method named test_connection_pattern appears in
    # this chunk. If both live in the same class, this later definition
    # shadows the earlier one, so confirm that is intended.
    apply_node = disconnected_grad(matrix("x")).owner
    assert apply_node.op.connection_pattern(apply_node) == [[False]]