Example #1
 def __init__(self, input, p, deterministic, seed=None):
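     # the keep-mask is sampled unconditionally; T.cond then selects between
     # the identity branch (deterministic/inference) and the branch that
     # applies the mask and rescales by the keep probability (training),
     # each branch receiving its own input tuple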
     mask = T.random.bernoulli(shape=input.shape, p=1 - p, seed=seed)
     return T.cond(
         deterministic,
         lambda x: x,
         lambda x, mask: x * mask / T.maximum(1e-4, (1 - p)),
         (input, ),
         (input, mask),
     )
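The same dropout pattern can be written directly in plain JAX (which SymJAX builds on). The sketch below is only an illustrative rough equivalent under that assumption, not part of either library: the mask is sampled unconditionally and jax.lax.cond picks between the identity branch and the masked, rescaled branch.

import jax
import jax.numpy as jnp

def dropout(rng, x, p, deterministic):
    # sample the keep-mask once, as in the layer above
    mask = jax.random.bernoulli(rng, 1 - p, x.shape)
    return jax.lax.cond(
        deterministic,
        lambda operands: operands[0],  # identity at inference
        lambda operands: operands[0] * operands[1] / jnp.maximum(1e-4, 1 - p),  # mask + rescale
        (x, mask),
    )

# e.g. dropout(jax.random.PRNGKey(0), jnp.ones((4, 4)), 0.5, False)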
Example #2
    def __init__(
        self,
        input,
        axis,
        deterministic,
        const=0.001,
        beta_1=0.99,
        beta_2=0.99,
        W=T.ones,
        b=T.zeros,
        trainable_W=True,
        trainable_b=True,
    ):

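        # keep the size of the dimensions listed in `axis` and broadcast
        # (size 1) over the rest; the remaining axes (r_axes) are the ones
        # the statistics are reduced over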
        parameter_shape = [
            input.shape[i] if i in axis else 1 for i in range(input.ndim)
        ]
        r_axes = [i for i in range(input.ndim) if i not in axis]

        W = create_variable("W", W, parameter_shape, trainable=trainable_W)
        b = create_variable("b", b, parameter_shape, trainable=trainable_b)

        input_mean = input.mean(r_axes, keepdims=True)
        # this definition is traditionally seen as less accurate than jnp.var's
        # mean((x - mean(x))**2) but may be faster and even, given typical
        # activation distributions and low-precision arithmetic, more accurate
        # when used in neural network normalization layers
        input_var = (input**2).mean(r_axes, keepdims=True) - input_mean**2

        avg_mean = schedules.ExponentialMovingAverage(input_mean,
                                                      beta_1,
                                                      debias=False,
                                                      name="mean_ema")[1]
        avg_var = schedules.ExponentialMovingAverage(
            input_var,
            beta_2,
            init=T.ones_like(input_var, detach=True),
            debias=False,
            name="var_ema",
        )[1]

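        # at inference time (deterministic) normalize with the running
        # averages; during training use the statistics of the current batch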
        output = T.cond(
            deterministic,
            lambda x, m, v, c: nn.normalize(x, mean=m, variance=v, epsilon=c),
            lambda x, m, v, c: nn.normalize(x, mean=m, variance=v, epsilon=c),
            (input, avg_mean, avg_var, const),
            (input, input_mean, input_var, const),
        )
        if b is None and W is not None:
            return W * output
        elif b is not None and W is None:
            return output + b
        elif b is not None and W is not None:
            return W * output + b
        else:
            return output
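The comment about the variance computation in Example #2 refers to the identity Var(x) = E[x**2] - E[x]**2. A quick NumPy check (illustrative only, not part of the example) shows the two formulations agree up to floating-point error:

import numpy as np

x = np.random.randn(10000).astype("float32")
v1 = (x ** 2).mean() - x.mean() ** 2   # formulation used in the layer above
v2 = ((x - x.mean()) ** 2).mean()      # textbook mean((x - mean(x))**2)
assert np.allclose(v1, v2, rtol=1e-3)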
Example #3
def test_cond2():
    sj.current_graph().reset()
    v = T.ones((10, 10))
    u = T.Placeholder((), "int32")
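    # the two branches receive different input tuples (v vs. 2 * v), so the
    # returned value shows which branch was actually evaluated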
    out = T.cond(
        u > 0,
        lambda u: 4 * u,
        lambda u: u,
        true_inputs=(v, ),
        false_inputs=(2 * v, ),
    )
    f = sj.function(u, outputs=out)
    assert np.array_equal(f(1), 4 * np.ones((10, 10)))
    assert np.array_equal(f(0), 2 * np.ones((10, 10)))
Example #4
        def fn(prev_vals, idx, data_idxs, DATAPH, DATAPRIMEPH, qtph, qtprimeph):

            xTP = DATAPRIMEPH[data_idxs[0][0]] #N1
            xT = DATAPH[data_idxs[0][1]] #N2
            xINNER = T.inner(xT, xTP) #N1 x N2

            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            prev_K = prev_vals[2]
            prev_theta = prev_vals[3]
            ## not boundary condition
            # print(qtph[data_idxs[0][1] - 1])
            # print(qtprimeph[data_idxs[0][0] - 1])
            S, D = self.alg2_VT(
                qtph[data_idxs[0][1] - 1], qtprimeph[data_idxs[0][0] - 1], prev_lambda
            )
            new_lambda = self.sw ** 2 * S + self.su ** 2 * xINNER + self.sb ** 2 ## took out an X
            new_phi = new_lambda + self.sw ** 2 * prev_phi * D
            # new_phi = prev_phi
            # new_lambda = prev_lambda

            # compute kernels
            S_kernel, D_kernel = self.alg2_VT(
                qtph[data_idxs[0][1]], qtprimeph[data_idxs[0][0]], new_lambda
            )
            new_K = prev_K + self.sv ** 2 * S_kernel  # uses the current lambda, qtph and qtprimeph
            new_theta = prev_theta + self.sv ** 2 * S_kernel + self.sv ** 2 * D_kernel * new_phi  # TODO
            # ret_K = prev_K
            # ret_theta = prev_theta

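            # symbolic flag: true when idx is the end of one of the
            # pre-computed diagonals, i.e. when a boundary condition applies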
            equal_check = lambda e: T.equal(idx, e)

            equal_result = sum(np.vectorize(equal_check)(self.ends_of_calced_diags)) > 0

            def true_f(k, t, qp, gph, dataph, dataprimeph, di):
                xTP_NEXT = dataprimeph[di[1][0]]
                xT_NEXT = dataph[di[1][1]]
                xINNER_NEXT = T.inner(xT_NEXT, xTP_NEXT)
                new_bc = self.make_boundary_condition(xINNER_NEXT)
                ret_lambda = ret_phi = new_bc 

                S_bc_kernel, D_bc_kernel = self.alg2_VT(qp[di[1][1]], gph[di[1][0]], ret_lambda)
                ret_K = k + self.sv ** 2 * S_bc_kernel  # uses the current lambda, qtph and qtprimeph
                ret_theta = t + self.sv ** 2 * S_bc_kernel + self.sv ** 2 * D_bc_kernel * ret_phi  # TODO

                return ret_lambda, ret_phi, ret_K, ret_theta
            false_f = lambda l, p, k, t: (l, p, k, t)

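            # if this step closes a diagonal, restart lambda/phi from the next
            # pair's boundary condition (true_f); otherwise carry the freshly
            # updated values forward unchanged (false_f)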
            ret_lambda, ret_phi, ret_K, ret_theta = T.cond(
                equal_result,
                true_f,
                false_f,
                [new_K, new_theta, qtph, qtprimeph, DATAPH, DATAPRIMEPH, data_idxs],
                [new_lambda, new_phi, new_K, new_theta],
            )

            to_carry = create_T_list([ret_lambda, ret_phi, ret_K, ret_theta])
            # print('got poast second create list')
            
            return to_carry, np.array(())
Example #5
def test_cond3():
    sj.current_graph().reset()
    v = T.ones((10, 10)) * 3
    u = T.Placeholder((), "int32")
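    # both branches receive the same inputs; only the function applied to
    # them differs (multiply vs. add)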
    out = T.cond(
        u > 0,
        lambda a, u: a * u,
        lambda a, u: a + u,
        true_inputs=(
            2 * T.ones((10, 10)),
            v,
        ),
        false_inputs=(
            2 * T.ones((10, 10)),
            v,
        ),
    )
    f = sj.function(u, outputs=out)
    assert np.array_equal(f(1), 6 * np.ones((10, 10)))
    assert np.array_equal(f(0), 5 * np.ones((10, 10)))
Example #6
def test_cond5():
    sj.current_graph().reset()
    v = T.ones((10, 10)) * 3
    W = T.Variable(1)
    u = T.Placeholder((), "int32")
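    # W is a Variable incremented after every call (updates={W: W + 1} below),
    # so the same branch produces a different value on the second f(1) call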
    out = T.cond(
        u > 0,
        lambda a, u: a * u[0],
        lambda a, u: a + u[1],
        true_inputs=(
            W,
            v,
        ),
        false_inputs=(
            W,
            v,
        ),
    )
    f = sj.function(u, outputs=out, updates={W: W + 1})
    assert np.array_equal(f(1), 3 * np.ones(10))
    assert np.array_equal(f(0), 5 * np.ones(10))
    assert np.array_equal(f(1), 9 * np.ones(10))
Example #7
# imports assumed for this script-style snippet
import numpy as np
import jax
import symjax
import symjax.tensor as T

# fori loop
b = T.Placeholder((), 'int32')
xx = T.ones(1)
a = T.fori_loop(0, b, lambda i, x: i * x, xx)
f = symjax.function(b, outputs=a)
print(f(0), f(1), f(2), f(3))
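# expected values: 1. for f(0) (the body never runs), then 0. for f(1), f(2)
# and f(3), since the first iteration multiplies the carry by i = 0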

# COND example 1
# same call form as the examples above: (pred, true_fn, false_fn, true_inputs, false_inputs)
value = T.Placeholder((), np.float32)
# `w` was not defined in the original snippet; a placeholder is assumed here
w = T.Placeholder((3,), 'float32')
output = T.cond(value < 0, lambda a, b: a * b, lambda a, b: a * b,
                (value, w), (value, w))
print(output.get({value: -1., w: jax.numpy.arange(3).astype('float32')}))
print(output.get({value: 1., w: jax.numpy.arange(3).astype('float32')}))
fn = symjax.function(value, w, outputs=[output])
print(fn(-1., jax.numpy.arange(3).astype('float32')))
print(fn(1., jax.numpy.arange(3).astype('float32')))

# COND example 2
value = T.Placeholder((), np.float32)
output = T.cond(value < 0, lambda a: a * 10, lambda a: a * 20, (value,), (value,))
print(output.get({value: -1.}))
print(output.get({value: 1.}))
fn = symjax.function(value, outputs=[output])
print(fn(-1.))
print(fn(1.))
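# the negative input takes the true branch (value * 10 -> -10.0),
# the positive input the false branch (value * 20 -> 20.0)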