def log(expr):
    """Natural logarithm, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument the result is a ``GC`` wrapping
    ``se.log(expr.expr)`` whose gradients are each incoming gradient
    divided by ``expr.expr`` (chain rule for the logarithm). Any other
    argument is passed straight to ``se.log``.
    """
    if type(expr) is not GC:
        return se.log(expr)
    inner = expr.expr
    value = se.log(inner)
    grads = {}
    for sym, deriv in expr.gradients.items():
        grads[sym] = deriv / inner
    return GC(value, grads)
def acos(expr):
    """Arccosine, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument every incoming gradient ``d`` becomes
    ``-d / sqrt(1 - expr.expr**2)``; any other argument is passed straight
    to ``se.acos``.
    """
    if type(expr) is not GC:
        return se.acos(expr)
    inner = expr.expr
    value = se.acos(inner)
    denom = se.sqrt(1 - inner**2)
    grads = {sym: -deriv / denom for sym, deriv in expr.gradients.items()}
    return GC(value, grads)
def tan(expr):
    """Tangent, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument every incoming gradient ``d`` becomes
    ``d * (1 + tan(expr.expr)**2)``; any other argument is passed straight
    to ``se.tan``.
    """
    if type(expr) is not GC:
        return se.tan(expr)
    value = se.tan(expr.expr)
    # d/du tan(u) = 1 + tan(u)^2, so the already computed value is reused.
    grads = {}
    for sym, deriv in expr.gradients.items():
        grads[sym] = deriv * (1 + value**2)
    return GC(value, grads)
def cos(expr):
    """Cosine, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument every incoming gradient ``d`` becomes
    ``-sin(expr.expr) * d``; any other argument is passed straight to
    ``se.cos``.
    """
    if type(expr) is not GC:
        return se.cos(expr)
    inner = expr.expr
    value = se.cos(inner)
    neg_sin = -se.sin(inner)
    return GC(value,
              {sym: neg_sin * deriv for sym, deriv in expr.gradients.items()})
def sqrt(expr):
    """Square root, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument every incoming gradient ``d`` becomes
    ``d / (2 * sqrt(expr.expr))``; any other argument is passed straight
    to ``se.sqrt``.
    """
    if type(expr) is not GC:
        return se.sqrt(expr)
    root = se.sqrt(expr.expr)
    grads = {}
    for sym, deriv in expr.gradients.items():
        grads[sym] = deriv / (2 * root)
    return GC(root, grads)
def exp(expr):
    """Exponential, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument every incoming gradient ``d`` becomes
    ``d * exp(expr.expr)``; any other argument is passed straight to
    ``se.exp``.
    """
    if type(expr) is not GC:
        return se.exp(expr)
    value = se.exp(expr.expr)
    # The derivative of exp is exp itself, so the computed value is reused.
    return GC(value,
              {sym: deriv * value for sym, deriv in expr.gradients.items()})
def atanh(expr):
    """Hyperbolic arctangent, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument every incoming gradient ``d`` becomes
    ``d / (1 - expr.expr**2)``; any other argument is passed straight to
    ``se.atanh``.
    """
    if type(expr) is not GC:
        return se.atanh(expr)
    inner = expr.expr
    value = se.atanh(inner)
    denom = 1 - inner**2
    grads = {}
    for sym, deriv in expr.gradients.items():
        grads[sym] = deriv / denom
    return GC(value, grads)
def acosh(expr):
    """Hyperbolic arccosine, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument every incoming gradient ``d`` becomes
    ``d / sqrt(expr.expr**2 - 1)``; any other argument is passed straight
    to ``se.acosh``.
    """
    if type(expr) is not GC:
        return se.acosh(expr)
    inner = expr.expr
    value = se.acosh(inner)
    denom = se.sqrt(inner**2 - 1)
    return GC(value,
              {sym: deriv / denom for sym, deriv in expr.gradients.items()})
def tanh(expr):
    """Hyperbolic tangent, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument every incoming gradient ``d`` becomes
    ``d * (1 - tanh(expr.expr)**2)``; any other argument is passed
    straight to ``se.tanh``.
    """
    if type(expr) is not GC:
        return se.tanh(expr)
    value = se.tanh(expr.expr)
    # d/du tanh(u) = 1 - tanh(u)^2, so the computed value is reused.
    grads = {}
    for sym, deriv in expr.gradients.items():
        grads[sym] = deriv * (1 - value**2)
    return GC(value, grads)
def cosh(expr):
    """Hyperbolic cosine, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument every incoming gradient ``d`` becomes
    ``d * sinh(expr.expr)``; any other argument is passed straight to
    ``se.cosh``.
    """
    if type(expr) is not GC:
        return se.cosh(expr)
    inner = expr.expr
    value = se.cosh(inner)
    slope = se.sinh(inner)
    return GC(value,
              {sym: deriv * slope for sym, deriv in expr.gradients.items()})
def atan(expr):
    """Arctangent, with gradient propagation for ``GC`` inputs.

    For a ``GC`` argument every incoming gradient ``d`` becomes
    ``d / (1 + expr.expr**2)``; any other argument is passed straight to
    ``se.atan``.
    """
    if type(expr) is not GC:
        return se.atan(expr)
    inner = expr.expr
    value = se.atan(inner)
    denom = 1 + inner**2
    grads = {}
    for sym, deriv in expr.gradients.items():
        grads[sym] = deriv / denom
    return GC(value, grads)
def abs(expr):
    """Absolute value via ``fake_abs``, with gradient propagation for ``GC``.

    For a ``GC`` argument the value is ``fake_abs(expr.expr)`` and every
    incoming gradient ``d`` becomes ``d * expr.expr / sqrt(expr.expr**2)``,
    i.e. ``d`` scaled by the sign of the argument. Any other argument is
    passed straight to ``fake_abs``.

    NOTE(review): this intentionally shadows the builtin ``abs`` within the
    module; at ``expr.expr == 0`` the gradient expression divides by zero.
    """
    if type(expr) is not GC:
        return fake_abs(expr)
    inner = expr.expr
    value = fake_abs(inner)
    root = se.sqrt(inner**2)
    grads = {}
    for sym, deriv in expr.gradients.items():
        grads[sym] = deriv * inner / root
    return GC(value, grads)
def get_diff(term, symbols=None):
    """Returns the derivative of a passed expression.

    A bare ``Symbol`` is mapped directly through ``get_diff_symbol``.
    Any other term is wrapped in a ``GC`` (unless it already is one), and
    the result is a sum of ``key * gradient`` products:

    * ``symbols is None``: ``do_full_diff`` is run first and every entry
      of ``gradients`` contributes.
    * otherwise: only the listed symbols that the term contains
      (``s in term``) contribute, each looked up as ``term[s]``.

    An empty selection sums to ``0``.
    """
    if type(term) is Symbol:
        return get_diff_symbol(term)
    wrapped = term if type(term) is GC else GC(term)
    if symbols is not None:
        return sum(s * wrapped[s] for s in symbols if s in wrapped)
    wrapped.do_full_diff()
    return sum(key * grad for key, grad in wrapped.gradients.items())
def greater_than(x, y):
    """Creates a gradient approximating the :math:`x > y` expression.

    The gradient contains a fake derivative mapping the velocity of x to
    True and the velocity of y to False.

    The value is the smooth step ``0.5 * tanh((x - y) * contrast) + 0.5``.
    Its gradients are not derived from that step. Instead, ``y``'s
    derivative terms enter with negated sign and ``x``'s with positive
    sign:

    * a ``Symbol`` argument contributes ``-1`` (for ``y``) or ``+1``
      (for ``x``) on its ``get_diff_symbol`` entry;
    * a symengine expression is wrapped in ``GC``, and a ``GC`` is fully
      differentiated, so ``do_full_diff`` runs on ``GC`` objects passed
      in by the caller;
    * any other argument contributes no gradient entries.
    """
    # y's contribution goes in first, with every gradient negated.
    if type(y) is Symbol:
        fake_diffs = {get_diff_symbol(y): -1}
    else:
        fake_diffs = {}
        if type(y) in symengine_types:
            y = GC(y)
        if type(y) is GC:
            y.do_full_diff()
            fake_diffs = {sym: -grad for sym, grad in y.gradients.items()}

    # x's contribution is then added on top.
    if type(x) is Symbol:
        x_d = get_diff_symbol(x)
        fake_diffs[x_d] = fake_diffs.get(x_d, 0) + 1
    else:
        if type(x) in symengine_types:
            x = GC(x)
        if type(x) is GC:
            x.do_full_diff()
            fake_diffs = merge_gradients_add(fake_diffs, x.gradients)

    return GC(0.5 * tanh((x - y) * contrast) + 0.5, fake_diffs)
def alg_and(x, y):
    r"""Creates a gradient approximating the :math:`x \wedge y` expression
    by means of multiplication.

    x, y are assumed to be boolean approximations resulting in 1 for truth
    and 0 for falsehood.

    When both operands are (or become) ``GC`` objects, the value is
    ``x.expr * y.expr`` and the gradients are combined with
    ``merge_gradients_add``, not via the product rule. A symengine
    operand paired with a ``GC`` is wrapped in ``GC`` and fully
    differentiated. Any other non-``GC`` operand is multiplied into the
    ``GC``'s expression, and only the ``GC``'s gradients are kept. With no
    ``GC`` operand at all, plain ``x * y`` is returned.
    """
    if type(x) is GC:
        if type(y) in symengine_types:
            y = GC(y)
        if type(y) is GC:
            y.do_full_diff()
            return GC(x.expr * y.expr,
                      merge_gradients_add(x.gradients, y.gradients))
        return GC(x.expr * y, x.gradients)
    elif type(y) is GC:
        if type(x) in symengine_types:
            x = GC(x)
        if type(x) is GC:
            x.do_full_diff()
            # Mirror of the branch above: both operands form the product.
            # (This previously computed y.expr * y.expr, dropping x.)
            return GC(x.expr * y.expr,
                      merge_gradients_add(x.gradients, y.gradients))
        return GC(y.expr * x, y.gradients)
    return x * y
def wrap_expr(expr):
    """Return ``expr`` unchanged if it is already a ``GC``; otherwise wrap it
    in a new ``GC``."""
    if type(expr) is GC:
        return expr
    return GC(expr)