Example 1
    def grad(self, inputs, grads):
        (k, x) = inputs
        (gz,) = grads
        return [
            # d/dk gammaincc(k, x), via the dedicated gammaincc_der helper
            gz * gammaincc_der(k, x),
            # d/dx gammaincc(k, x) = -x**(k - 1) * exp(-x) / gamma(k)
            gz * -exp(-x + (k - 1) * log(x) - gammaln(k)),
        ]
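As a quick standalone check (not part of the snippet above; assumes NumPy and SciPy are available), the closed-form x-gradient can be compared against a central finite difference of scipy.special.gammaincc. The values of k, x, and eps are arbitrary test choices:

import numpy as np
from scipy.special import gammaincc, gammaln

k, x, eps = 2.5, 1.3, 1e-6
# Central finite difference of the regularized upper incomplete gamma in x
numeric = (gammaincc(k, x + eps) - gammaincc(k, x - eps)) / (2 * eps)
# Closed form used by the grad method: -x**(k - 1) * exp(-x) / gamma(k)
analytic = -np.exp(-x + (k - 1) * np.log(x) - gammaln(k))
assert np.isclose(numeric, analytic)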
Example 2
    def grad(self, inp, grads):
        a, b, x = inp
        (gz,) = grads

        return [
            # d/da and d/db of betainc(a, b, x), via a dedicated helper
            gz * betainc_der(a, b, x, True),
            gz * betainc_der(a, b, x, False),
            # d/dx betainc(a, b, x) = x**(a - 1) * (1 - x)**(b - 1) / beta(a, b)
            gz * exp(
                log1p(-x) * (b - 1) + log(x) * (a - 1)
                - (gammaln(a) + gammaln(b) - gammaln(a + b))
            ),
        ]
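The x-gradient term can likewise be sanity-checked numerically (assumes SciPy; the values of a, b, x, and eps are arbitrary) against a finite difference of scipy.special.betainc:

import numpy as np
from scipy.special import betainc, gammaln

a, b, x, eps = 2.0, 3.0, 0.4, 1e-6
# Central finite difference of the regularized incomplete beta in x
numeric = (betainc(a, b, x + eps) - betainc(a, b, x - eps)) / (2 * eps)
# Same log-space closed form as the third gradient term above
analytic = np.exp(
    np.log1p(-x) * (b - 1) + np.log(x) * (a - 1)
    - (gammaln(a) + gammaln(b) - gammaln(a + b))
)
assert np.isclose(numeric, analytic)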
Example 3
    def L_op(self, inputs, outputs, grads):
        (x,) = inputs
        (gz,) = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            # Discrete outputs have zero gradient everywhere
            if x.type in discrete_types:
                return [x.zeros_like(dtype=config.floatX)]
            else:
                return [x.zeros_like()]

        # d/dx erfc(x) = -2 / sqrt(pi) * exp(-x**2)
        cst = np.asarray(
            2.0 / np.sqrt(np.pi), dtype=upcast(x.type.dtype, gz.type.dtype)
        )
        return (-gz * cst * exp(-x * x),)
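The same derivative identity can be verified independently of the snippet (assumes SciPy; x and eps are arbitrary):

import numpy as np
from scipy.special import erfc

x, eps = 0.7, 1e-6
# Central finite difference of erfc at x
numeric = (erfc(x + eps) - erfc(x - eps)) / (2 * eps)
# Closed form used above: -2 / sqrt(pi) * exp(-x**2)
analytic = -2.0 / np.sqrt(np.pi) * np.exp(-x * x)
assert np.isclose(numeric, analytic)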
Example 4
def test_jax_Composite(x, y, x_val, y_val):
    # x, y are symbolic variables and x_val, y_val their test values,
    # supplied by pytest parametrization (omitted from this excerpt)
    x_s = aes.float64("x")
    y_s = aes.float64("y")

    # Fuse the scalar expression into a single elemwise Composite Op
    comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)]))

    out = comp_op(x, y)

    out_fg = FunctionGraph([x, y], [out])

    test_input_vals = [
        x_val.astype(config.floatX),
        y_val.astype(config.floatX),
    ]
    # Run the graph through both the JAX and Python backends and compare
    _ = compare_jax_and_py(out_fg, test_input_vals)
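For reference, a minimal pure-JAX sketch of the expression the Composite is expected to compute (assumes only jax is installed; composite_ref and the input values are illustrative, not part of the test suite):

import jax.numpy as jnp

def composite_ref(x, y):
    # Same scalar expression as the Composite above
    return x + y * 2 + jnp.exp(x - y)

print(composite_ref(jnp.array([1.0, 2.0]), jnp.array([0.5, 1.5])))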