def test_arithmetic_diff():
    """Forward-mode derivatives of sums, products, quotients and long products of x."""
    x = autodiff.Variable("x")
    # Differentiate along x with unit weight.
    direction = {"x": 1}

    expr = x + x
    assert expr.forward_diff(direction, dict(x=5)) == 2
    assert expr.forward_diff(direction, dict(x=1)) == 2

    expr = (x + autodiff.Constant(10)) * (x + autodiff.Constant(5))
    # D_x[expr] = D_x[x ** 2 + 15 x + 50] = 2 * x + 15
    assert expr.forward_diff(direction, dict(x=5)) == 25
    assert expr.forward_diff(direction, dict(x=0)) == 15

    expr = (x + autodiff.Constant(10)) / (x * x)
    # D_x[expr] = D_x[(x + 10) / (x ** 2)] = - 1 / x**2 - 20 / x ** 3
    assert expr.forward_diff(direction, dict(x=1)) == -21
    assert expr.forward_diff(direction, dict(x=2)) == -0.25 - 20 / 8

    # x multiplied by itself 50 more times: expr = x ** 51,
    # so D_x[expr] = 51 * x ** 50.
    expr = x
    for _ in range(50):
        expr = expr * x
    assert expr.forward_diff(direction, dict(x=0)) == 0
    assert expr.forward_diff(direction, dict(x=1)) == 51
    assert expr.forward_diff(direction, dict(x=2)) == 51 * 2**50
def test_constant_expr():
    """Expressions built only from constants evaluate with no bindings at all."""
    ten = autodiff.Constant(10)
    two = autodiff.Constant(2)

    # Leaf constants evaluate to their own values.
    assert ten.eval({}) == 10
    assert two.eval({}) == 2

    # Every binary operator applied to two distinct constants.
    assert (ten + two).eval({}) == 12
    assert (ten * two).eval({}) == 20
    assert (ten - two).eval({}) == 8
    assert (ten / two).eval({}) == 5
    assert (ten**two).eval({}) == 100

    # Chains that reuse one node; Python groups - and / left to right,
    # e.g. (10 - 10) - 10 == -10.
    assert (ten + ten + ten).eval({}) == 30
    assert (ten - ten - ten).eval({}) == -10
    assert (ten * ten * ten).eval({}) == 1000
    assert (ten / ten / ten).eval({}) == 0.1
def test_power_diff():
    """Forward-mode derivatives of powers with constant and variable exponents."""
    x = autodiff.Variable("x")
    # Differentiate along x with unit weight.
    direction = {"x": 1}

    # D_x[x ** 2] = 2 * x
    expr = x**autodiff.Constant(2)
    assert expr.forward_diff(direction, dict(x=0)) == 0
    assert expr.forward_diff(direction, dict(x=1)) == 2
    assert expr.forward_diff(direction, dict(x=2)) == 4

    # D_x[e ** x] = e ** x * ln(e) = e ** x. These values pass through
    # exp/log, and different but equivalent evaluation orders can differ in
    # the last bits, so compare with a tolerance rather than exact equality.
    expr = autodiff.Constant(math.e)**x
    assert math.isclose(expr.forward_diff(direction, dict(x=0)), 1)
    assert math.isclose(expr.forward_diff(direction, dict(x=1)), math.e)
    assert math.isclose(expr.forward_diff(direction, dict(x=2)), math.e**2)

    # Repeatedly raising to the power x; at x = 1 the derivative is exactly 1.
    expr = x
    for _ in range(100):
        expr = expr**x
    assert expr.forward_diff(direction, dict(x=1)) == 1
def loss_func(self, feed_dict=None):
    """Build the softmax cross-entropy loss expression for one batch.

    For every row ``i`` the expression is::

        -predicted_y[i, true_y[i]] + log(sum_j exp(predicted_y[i, j]))

    No reduction over the batch is applied here.

    Parameters
    ----------
    feed_dict: dict
        Must contain ``"predicted_y"`` (its ``.shape[0]`` is used as the batch
        size; presumably a (batch, classes) array of unnormalized scores --
        TODO confirm) and ``"true_y"`` (class indices; ``.ravel()`` is applied
        before indexing). ``self`` is not used.

    Returns
    -------
    The ``ad`` expression node for the loss.

    Raises
    ------
    KeyError
        If either key is missing (including when ``feed_dict`` is omitted).
    """
    # A None default avoids the shared-mutable-default pitfall; an omitted
    # argument still fails with KeyError, exactly as with the old `{}`.
    if feed_dict is None:
        feed_dict = {}
    predicted_y = feed_dict["predicted_y"]
    batch_size = predicted_y.shape[0]
    # -predicted_y[i, true_y[i]]: the (negated) score of the true class in row i,
    # selected with a (row indices, column indices) index tuple.
    neg_true_score = ad.negative(
        ad.__getitem__(
            ad.Placeholder(predicted_y),
            ad.Constant(
                tuple([range(batch_size), feed_dict["true_y"].ravel()]))))
    # log(sum_j exp(predicted_y[i, j])) per row.
    # NOTE(review): computed without max-subtraction, so very large scores can
    # overflow exp(); consider a stable log-sum-exp if ad provides the ops.
    log_partition = ad.log(
        ad.sum(ad.exp(ad.Placeholder(predicted_y)), axis=1))
    return ad.add(neg_true_score, log_partition)
def test_variable_expr():
    """Expressions over a variable use its bound value; an unbound one raises KeyError."""
    var = autodiff.Variable("x")
    ten = autodiff.Constant(10)

    # Plain lookup of the bound value.
    assert var.eval(dict(x=5)) == 5

    # Each binary operator with a constant right-hand side.
    assert (var + ten).eval(dict(x=5)) == 15
    assert (var - ten).eval(dict(x=5)) == -5
    assert (var * ten).eval(dict(x=5)) == 50
    assert (var / ten).eval(dict(x=5)) == 0.5
    assert (var**ten).eval(dict(x=2)) == 2**10

    # No binding for "x" at all.
    with pytest.raises(KeyError):
        var.eval({})
def make_graph(func, *inputs):
    """
    Trace ``func`` with ``make_jaxpr`` and rebuild the trace as an AutoHOOT graph.

    Parameters
    ----------
    func:
        The function to trace; it is traced via ``make_jaxpr(func)(*inputs)``.
    inputs:
        Example input tensors for ``func``. Apart from being passed to the
        tracer, only their ``.shape`` is read.

    Returns
    -------
    out_list:
        The AutoHOOT nodes for the traced outputs, in ``jaxpr.outvars`` order.
    variable_list:
        One ``ad.Variable`` per traced input, in ``jaxpr.invars`` order.
    """
    jaxpr, consts = make_jaxpr(func)(*inputs)
    # Node lookup keyed by the string name of each jaxpr variable.
    nodes = {}

    # Traced inputs become named Variables shaped like the matching input.
    variable_list = []
    for idx, invar in enumerate(jaxpr.invars):
        var_name = str(invar)
        input_node = ad.Variable(name=var_name, shape=list(inputs[idx].shape))
        nodes[var_name] = input_node
        variable_list.append(input_node)

    # Captured constants become Constant nodes that carry their values.
    for idx, constvar in enumerate(jaxpr.constvars):
        const_value = consts[idx]
        nodes[str(constvar)] = ad.Constant(name=str(constvar),
                                           shape=list(const_value.shape),
                                           value=const_value)

    # Replay the equations in order. Operands are looked up by name, so each
    # must already come from an input, a constant, or an earlier equation.
    for eqn in jaxpr.eqns:
        # Only single-output primitives are supported.
        assert len(eqn.outvars) == 1
        result_name = str(eqn.outvars[0])
        operand_nodes = [nodes[str(arg)] for arg in eqn.invars]
        nodes[result_name] = parse_jax_operator(eqn.primitive, eqn.params,
                                                operand_nodes)

    out_list = [nodes[str(outvar)] for outvar in jaxpr.outvars]
    return out_list, variable_list
def loss_func(self, feed_dict=None):
    """Build the squared-error loss expression ``(predicted_y - true_y) ** 2 / 2``.

    No reduction over elements is applied here.

    Parameters
    ----------
    feed_dict: dict
        Must contain ``'predicted_y'`` and ``'true_y'``; each is wrapped in an
        ``ad.Placeholder``. ``self`` is not used.

    Returns
    -------
    The ``ad`` expression node for the loss.

    Raises
    ------
    KeyError
        If either key is missing (including when ``feed_dict`` is omitted).
    """
    # A None default avoids the shared-mutable-default pitfall; an omitted
    # argument still fails with KeyError, exactly as with the old `{}`.
    if feed_dict is None:
        feed_dict = {}
    diff = ad.subtract(ad.Placeholder(feed_dict['predicted_y']),
                       ad.Placeholder(feed_dict['true_y']))
    return ad.multiply(ad.power(diff, ad.Constant(2)), ad.Constant(1 / 2))
def test4(x):
    """Build the expression ``x ** 2 / 2``.

    ``x`` is passed straight to ``ad.Variable`` (presumably the variable name).
    """
    var = ad.Variable(x)
    squared = ad.power(var, ad.Constant(2))
    return ad.multiply(squared, ad.Constant(1 / 2))
def test3(x):
    """Build the expression ``x ** 3``.

    ``x`` is passed straight to ``ad.Variable`` (presumably the variable name).
    """
    base = ad.Variable(x)
    exponent = ad.Constant(3)
    return ad.power(base, exponent)
def test(x, y, z, w):
    """Build the expression ``2 * (x * y + maximum(z, w))``.

    Each argument is passed straight to ``ad.Variable`` (presumably variable
    names).
    """
    product = ad.multiply(ad.Variable(x), ad.Variable(y))
    larger = ad.maximum(ad.Variable(z), ad.Variable(w))
    return ad.multiply(ad.add(product, larger), ad.Constant(2))
def tanh(x):
    """Build the expression for the hyperbolic tangent of ``x``.

    Uses ``tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))``. The factor 2 matters:
    ``(1 - exp(-x)) / (1 + exp(-x))`` equals ``tanh(x / 2)``, not ``tanh(x)``.

    ``x`` is passed straight to ``ad.Variable`` (presumably the variable name).
    """

    def _exp_neg_2x():
        # exp(-2x); built separately for numerator and denominator, matching
        # the tree shape of the original expression.
        return ad.exp(ad.negative(ad.multiply(ad.Variable(x), ad.Constant(2))))

    numerator = ad.subtract(ad.Constant(1), _exp_neg_2x())
    denominator = ad.add(ad.Constant(1), _exp_neg_2x())
    return ad.divide(numerator, denominator)