def __init__(self, input, p, deterministic, seed=None):
    """Apply dropout to ``input`` through a symbolic ``T.cond``.

    When ``deterministic`` is true the input passes through unchanged.
    Otherwise it is multiplied by a Bernoulli mask drawn with ``p=1 - p`` and
    divided by ``max(1e-4, 1 - p)``.

    Args:
        input: tensor to apply dropout to; its ``shape`` sizes the mask.
        p: presumably the drop probability (the mask is drawn with
            ``1 - p``) — TODO confirm against ``T.random.bernoulli``.
        deterministic: predicate selecting the identity branch when true.
        seed: forwarded to ``T.random.bernoulli``.

    Returns:
        The ``T.cond`` output tensor.
    """
    # NOTE(review): returning a value from __init__ only works if the
    # enclosing layer machinery invokes it as a plain function — confirm.
    keep_prob = 1 - p
    keep_mask = T.random.bernoulli(shape=input.shape, p=keep_prob, seed=seed)

    def _identity(x):
        return x

    def _apply_mask(x, mask):
        # The 1e-4 floor keeps the rescale finite when p == 1.
        return x * mask / T.maximum(1e-4, keep_prob)

    return T.cond(
        deterministic,
        _identity,
        _apply_mask,
        (input,),
        (input, keep_mask),
    )
def __init__(
    self,
    input,
    axis,
    deterministic,
    const=0.001,
    beta_1=0.99,
    beta_2=0.99,
    W=T.ones,
    b=T.zeros,
    trainable_W=True,
    trainable_b=True,
):
    """Batch-normalize ``input``, keeping separate statistics along ``axis``.

    Mean and variance are reduced over every axis *not* listed in ``axis``.
    The scale ``W`` and shift ``b`` have the size of ``input`` along ``axis``
    and size 1 elsewhere.

    Args:
        input: tensor to normalize.
        axis: an axis index, or an iterable of axis indices, that keep their
            own statistics and parameters. Negative indices count from the
            end.
        deterministic: ``T.cond`` predicate. When true the running (EMA)
            mean/variance are used; otherwise the current batch statistics.
        const: epsilon forwarded to ``nn.normalize``.
        beta_1: decay for the mean's ``ExponentialMovingAverage``.
        beta_2: decay for the variance's ``ExponentialMovingAverage``.
        W: scale initializer, forwarded to ``create_variable``.
        b: shift initializer, forwarded to ``create_variable``.
        trainable_W: forwarded as ``trainable`` when creating ``W``.
        trainable_b: forwarded as ``trainable`` when creating ``b``.

    Returns:
        The normalized tensor, multiplied by ``W`` and shifted by ``b`` for
        whichever of them is not None after ``create_variable``.
    """
    import numbers

    # Accept a single int, and resolve negative indices. Previously an int
    # raised TypeError (``i in axis``) and negative axes matched nothing.
    if isinstance(axis, numbers.Integral):
        axis = (axis,)
    axis = {a + input.ndim if a < 0 else a for a in axis}

    parameter_shape = [
        input.shape[i] if i in axis else 1 for i in range(input.ndim)
    ]
    r_axes = [i for i in range(input.ndim) if i not in axis]

    W = create_variable("W", W, parameter_shape, trainable=trainable_W)
    b = create_variable("b", b, parameter_shape, trainable=trainable_b)

    input_mean = input.mean(r_axes, keepdims=True)
    # This definition is traditionally seen as less accurate than jnp.var's
    # mean((x - mean(x))**2) but may be faster and even, given typical
    # activation distributions and low-precision arithmetic, more accurate
    # when used in neural network normalization layers.
    # Cancellation can make E[x^2] - E[x]^2 slightly negative, so clamp at
    # zero before it reaches the running average and nn.normalize.
    input_var = T.maximum(
        0.0, (input**2).mean(r_axes, keepdims=True) - input_mean**2
    )

    # Element [1] of each ExponentialMovingAverage result is used as the
    # running statistic — presumably the averaged value; confirm in schedules.
    avg_mean = schedules.ExponentialMovingAverage(
        input_mean, beta_1, debias=False, name="mean_ema"
    )[1]
    avg_var = schedules.ExponentialMovingAverage(
        input_var,
        beta_2,
        init=T.ones_like(input_var, detach=True),
        debias=False,
        name="var_ema",
    )[1]

    def _normalize(x, m, v, c):
        return nn.normalize(x, mean=m, variance=v, epsilon=c)

    # Both branches normalize the same way; only the statistics fed differ.
    output = T.cond(
        deterministic,
        _normalize,
        _normalize,
        (input, avg_mean, avg_var, const),
        (input, input_mean, input_var, const),
    )

    if W is not None:
        output = W * output
    if b is not None:
        output = output + b
    return output
def test_cond2():
    """T.cond with a single keyword input per branch.

    The true branch scales its input by 4; the false branch returns its
    input (which is 2 * ones) unchanged.
    """
    sj.current_graph().reset()
    ones = T.ones((10, 10))
    flag = T.Placeholder((), "int32")
    out = T.cond(
        flag > 0,
        lambda t: 4 * t,
        lambda t: t,
        true_inputs=(ones,),
        false_inputs=(2 * ones,),
    )
    f = sj.function(flag, outputs=out)
    grid = np.ones((10, 10))
    assert np.array_equal(f(1), 4 * grid)
    assert np.array_equal(f(0), 2 * grid)
def fn(prev_vals, idx, data_idxs, DATAPH, DATAPRIMEPH, qtph, qtprimeph): xTP = DATAPRIMEPH[data_idxs[0][0]] #N1 xT = DATAPH[data_idxs[0][1]] #N2 xINNER = T.inner(xT, xTP) #N1 x N2 prev_lambda = prev_vals[0] prev_phi = prev_vals[1] prev_K = prev_vals[2] prev_theta = prev_vals[3] ## not boundary condition # print(qtph[data_idxs[0][1] - 1]) # print(qtprimeph[data_idxs[0][0] - 1]) S, D = self.alg2_VT(qtph[data_idxs[0][1] - 1], qtprimeph[data_idxs[0][0] - 1] ,prev_lambda) new_lambda = self.sw ** 2 * S + self.su ** 2 * xINNER + self.sb ** 2 ## took out an X new_phi = new_lambda + self.sw ** 2 * prev_phi * D # new_phi = prev_phi # new_lambda = prev_lambda #compute kernels S_kernel, D_kernel = self.alg2_VT(qtph[data_idxs[0][1]], qtprimeph[data_idxs[0][0]], new_lambda) new_K = prev_K + self.sv**2 * S_kernel#get current lamnda, get current qtph and qtprimeph new_theta = prev_theta + self.sv**2 * S_kernel + self.sv**2 * D_kernel * new_phi #TODO # ret_K = prev_K # ret_theta = prev_theta equal_check = lambda e: T.equal(idx, e) equal_result = sum(np.vectorize(equal_check)(self.ends_of_calced_diags)) > 0 def true_f(k,t, qp, gph, dataph, dataprimeph, di): xTP_NEXT = dataprimeph[di[1][0]] xT_NEXT = dataph[di[1][1]] xINNER_NEXT = T.inner(xT_NEXT, xTP_NEXT) new_bc = self.make_boundary_condition(xINNER_NEXT) ret_lambda = ret_phi = new_bc S_bc_kernel, D_bc_kernel = self.alg2_VT(qp[di[1][1]], gph[di[1][0]], ret_lambda) ret_K = k + self.sv**2 * S_bc_kernel#get current lamnda, get current qtph and qtprimeph ret_theta = t + self.sv**2 * S_bc_kernel + self.sv**2 * D_bc_kernel * ret_phi #TODO return ret_lambda, ret_phi, ret_K, ret_theta false_f = lambda l,p,k,t: (l,p,k,t) ret_lambda, ret_phi, ret_K, ret_theta = T.cond(equal_result, true_f, false_f, [new_K, new_theta, qtph, qtprimeph, DATAPH, DATAPRIMEPH, data_idxs], [new_lambda, new_phi, new_K, new_theta]) to_carry = create_T_list([ret_lambda, ret_phi, ret_K, ret_theta]) # print('got poast second create list') return to_carry, np.array(())
def test_cond3():
    """T.cond with two inputs per branch.

    The true branch multiplies its inputs and the false branch adds them.
    """
    sj.current_graph().reset()
    threes = T.ones((10, 10)) * 3
    flag = T.Placeholder((), "int32")
    out = T.cond(
        flag > 0,
        lambda lhs, rhs: lhs * rhs,
        lambda lhs, rhs: lhs + rhs,
        true_inputs=(2 * T.ones((10, 10)), threes),
        false_inputs=(2 * T.ones((10, 10)), threes),
    )
    f = sj.function(flag, outputs=out)
    grid = np.ones((10, 10))
    # 2 * 3 on the true branch, 2 + 3 on the false branch.
    assert np.array_equal(f(1), 6 * grid)
    assert np.array_equal(f(0), 5 * grid)
def test_cond5():
    """T.cond reading a Variable that is incremented after every call."""
    sj.current_graph().reset()
    threes = T.ones((10, 10)) * 3
    W = T.Variable(1)
    flag = T.Placeholder((), "int32")
    out = T.cond(
        flag > 0,
        lambda w, rows: w * rows[0],
        lambda w, rows: w + rows[1],
        true_inputs=(W, threes),
        false_inputs=(W, threes),
    )
    f = sj.function(flag, outputs=out, updates={W: W + 1})
    row = np.ones(10)
    # W reads 1, 2, 3 on the three successive calls.
    assert np.array_equal(f(1), 3 * row)  # 1 * 3
    assert np.array_equal(f(0), 5 * row)  # 2 + 3
    assert np.array_equal(f(1), 9 * row)  # 3 * 3
# Scratch examples for symjax control flow (scan, fori_loop, cond).
# NOTE(review): ``a`` and ``vvar`` used on the next statement must come from
# earlier in the file; the scan that would define ``a`` here is disabled.
# a = T.scan(lambda c, x: (T.square(c), c[0]), uu, xx)
# g = symjax.gradients(a[1][-1], xx)
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())

# fori loop
b = T.Placeholder((), 'int32')
xx = T.ones(1)
a = T.fori_loop(0, b, lambda i, x: i * x, xx)
f = symjax.function(b, outputs=a)
print(f(0), f(1), f(2), f(3))

# COND example 1
# Signature: T.cond(predicate, true_fun, false_fun, true_inputs, false_inputs).
value = T.Placeholder((), np.float32)
# ``w`` is fed and passed to symjax.function below, so it has to be a
# placeholder matching the float32 arange(3) values fed to it.
w = T.Placeholder((3,), np.float32)
w_value = jax.numpy.arange(3).astype('float32')
output = T.cond(
    value < 0,
    lambda a, b: a * b,
    lambda a, b: a * b,
    (value, w),
    (value, w),
)
print(output.get({value: -1., w: w_value}))
print(output.get({value: 1., w: w_value}))
fn = symjax.function(value, w, outputs=[output])
print(fn(-1., w_value))
print(fn(1., w_value))

# COND example 2: scale by 10 for negative values, by 20 otherwise.
value = T.Placeholder((), np.float32)
output = T.cond(
    value < 0,
    lambda a: a * 10,
    lambda a: a * 20,
    (value,),
    (value,),
)
print(output.get({value: -1.}))
print(output.get({value: 1.}))
fn = symjax.function(value, outputs=[output])
print(fn(-1.))
print(fn(1.))