Esempio n. 1
0
def test_update():
    """Check ``index_update`` with int, tuple and ``index[...]`` indices.

    Also checks that an index held in a ``Variable`` can drive repeated
    in-place updates through ``symjax.function(updates=...)``.
    """
    sj.current_graph().reset()
    w = symjax.tensor.zeros(10)
    for i in range(10):
        w = symjax.tensor.index_update(w, i, i)
    f = symjax.function(outputs=w)
    assert np.array_equal(f(), np.arange(10))

    w2 = symjax.tensor.zeros(10)
    for i in range(10):
        w2 = symjax.tensor.index_update(w2, (i, ), i)
    f = symjax.function(outputs=w2)
    assert np.array_equal(f(), np.arange(10))

    w3 = symjax.tensor.zeros(10)
    for i in range(10):
        w3 = symjax.tensor.index_update(w3, symjax.tensor.index[i], i)
    f = symjax.function(outputs=w3)
    assert np.array_equal(f(), np.arange(10))

    w4 = symjax.tensor.Variable(symjax.tensor.zeros(10))
    # Symbolic counter: each call writes w4[counter] = counter, then
    # increments the counter.
    counter = symjax.tensor.Variable(0, dtype="int32")
    update = symjax.tensor.index_update(w4, counter, counter)
    f = symjax.function(updates={w4: update, counter: counter + 1})
    # ``_`` avoids rebinding the symbolic counter to a Python int.
    for _ in range(10):
        f()
    assert np.array_equal(w4.value, np.arange(10))
Esempio n. 2
0
def test_bn():
    """Compare BatchNormalization outputs and its running mean with NumPy.

    With ``deterministic=0`` each batch is expected to be normalised with its
    own statistics, and ``avg_mean`` is expected to follow an EMA with factor
    0.9 that starts from the first batch mean.
    """
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2
    x = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((1,), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(x, [1], deterministic=deterministic)

    update = sj.function(x, deterministic, outputs=bn, updates=sj.get_updates())
    get_stats = sj.function(x, outputs=bn.avg_mean)

    data = np.random.randn(50, DIM) * 4 + 2

    expected_means = []
    observed_means = []

    for step in range(10):
        lo, hi = BATCH_SIZE * step, BATCH_SIZE * (step + 1)
        batch = data[lo:hi]
        normalized = (batch - batch.mean(0)) / (1e-4 + batch.std(0))
        assert np.allclose(update(batch, 0), normalized, 1e-4)
        observed_means.append(get_stats(batch))

        running = batch.mean(0)
        if expected_means:
            running = 0.9 * expected_means[-1] + 0.1 * running
        expected_means.append(running)

    expected = np.array(expected_means)
    observed = np.array(observed_means).squeeze()
    assert np.allclose(expected, observed, 1e-4)
Esempio n. 3
0
def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):
    """Build a VAE with an MLP decoder and return its compiled functions.

    Parameters
    ----------
    batch_size : int
        Fixed batch size of every placeholder.
    Ds : sequence of int
        Layer widths; ``Ds[0]`` is the latent size, ``Ds[-1]`` the data size.
    seed : int
        Forwarded to ``init_weights`` for the decoder parameters.
    leakiness : float
        Forwarded to ``relu_mask`` for every hidden decoder layer.
    lr : float
        Adam learning rate.
    scaler : float
        Forwarded to ``init_weights``.

    Returns
    -------
    dict
        ``train`` (one Adam step on a data batch, returns the loss), ``g``
        (decode a latent batch), ``params`` (decoder weights, biases and the
        per-dimension output variance), ``varx``, ``sample`` (return exactly
        ``n`` decoded samples from N(0, I) latents), ``model`` and ``kwargs``.
    """
    x = T.Placeholder([batch_size, Ds[-1]], 'float32')

    # ENCODER: the last encoder output is split as [mu | logvar] on axis 1.
    enc = encoder(x, Ds[0])
    mu = enc[-1][:, :Ds[0]]
    logvar = enc[-1][:, Ds[0]:]
    var = T.exp(logvar)

    # Reparameterisation: z = mu + sigma * eps, eps ~ N(0, I).
    z = mu + T.exp(0.5 * logvar) * T.random.randn((batch_size, Ds[0]))
    z_ph = T.Placeholder((batch_size, Ds[0]), 'float32')

    # DECODER, applied to the sampled z and to a free latent placeholder.
    Ws, bs = init_weights(Ds, seed, scaler)

    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    logvar_x = T.Variable(T.zeros(1), name='logvar_x')
    var_x = T.exp(logvar_x)

    h, h_ph = [z], [z_ph]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
        h_ph.append(T.matmul(h_ph[-1], w.transpose()) + b)
        h_ph.append(h_ph[-1] * relu_mask(h_ph[-1], leakiness))

    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    h_ph.append(T.matmul(h_ph[-1], Ws[-1].transpose()) + bs[-1])

    # LOSS
    prior = (sum([T.mean(w**2) for w in Ws], 0.) / cov_W
             + sum([T.mean(v**2) for v in bs[:-1]], 0.) / cov_b)
    # Closed form of -KL(N(mu, var) || N(0, I)), summed over latent dims.
    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    # Gaussian log-likelihood up to an additive constant.
    px = - 0.5 * (logvar_x + ((x - h[-1])**2 / var_x)).sum(1)
    loss = - (px + kl).mean() + prior

    variables = Ws + bs + sj.layers.get_variables(enc) + [logvar_x]
    opti = sj.optimizers.Adam(loss, lr, params=variables)

    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=h_ph[-1])
    params = sj.function(outputs=Ws + bs + [var_x * T.ones(Ds[-1])])
    get_varx = sj.function(outputs=var_x)

    output = {'train': train, 'g': g, 'params': params}
    output['model'] = 'VAE'
    output['varx'] = get_varx
    output['kwargs'] = {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                        'leakiness': leakiness, 'lr': lr, 'scaler': scaler,
                        'prior': sj.function(outputs=prior)}

    def sample(n):
        """Return exactly ``n`` decoded samples (``n <= 0`` gives no rows)."""
        if n <= 0:
            return np.empty((0, Ds[-1]), dtype='float32')
        # ``g`` only accepts full batches: draw ceil(n / batch_size) of them
        # and trim, so neither n < batch_size nor a remainder is lost.
        n_batches = -(-n // batch_size)
        samples = [g(np.random.randn(batch_size, Ds[0]))
                   for _ in range(n_batches)]
        return np.concatenate(samples)[:n]

    output['sample'] = sample
    return output
Esempio n. 4
0
def test_placeholders():
    """Two functions on separate int32 placeholders share one constant."""
    doubled = symjax.tensor.ones(1) * 2
    first_input = symjax.tensor.Placeholder((), "int32")
    times_two = symjax.function(first_input, outputs=first_input * doubled)
    second_input = symjax.tensor.Placeholder((), "int32")
    also_times_two = symjax.function(second_input,
                                     outputs=second_input * doubled)
    assert np.isclose(times_two(1), 2)
    assert np.isclose(also_times_two(2), 4)
Esempio n. 5
0
def create_fns(input, in_signs, Ds):
    """Build a random leaky-ReLU MLP and compile its region-wise affine maps.

    Weights and biases come from ``sj.initializers.he`` with layer widths
    ``Ds`` (input width first).  Hidden units use a slope mask of 1 where the
    pre-activation is positive and 0.1 otherwise.  Two compositions of the
    layers are accumulated as affine maps ``A @ x + B``:

    * ``A_w`` / ``B_w`` use the masks induced by ``input`` itself;
    * ``A_q`` / ``B_q`` use masks derived from ``in_signs`` instead, so they
      describe the region selected by an externally supplied sign pattern.

    Parameters
    ----------
    input : tensor
        Symbolic network input, used as ``T.matmul(w, input)``; presumably a
        vector of length ``Ds[0]`` -- TODO confirm.
    in_signs : tensor
        Symbolic sign pattern for the hidden units; presumably of length
        ``sum(Ds[1:-1])`` given how ``in_masks`` is sliced -- TODO confirm.
    Ds : sequence of int
        Layer widths.

    Returns
    -------
    f : compiled function of ``input`` returning
        ``[output, A_w[-1], B_w[-1], inequalities, signs]``.
    g : compiled function of ``in_signs`` returning ``[A_q[-1], B_q[-1]]``.
    h : compiled function of ``input`` returning the network output.
    all_g : compiled function of ``in_signs`` returning the sign-scaled,
        unnormalised region inequalities ``[B_q | A_q]``.
    """

    # Offsets of each layer's units in the concatenated [input, hidden...]
    # mask vector; consecutive pairs give the slice for one layer's inputs.
    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])

    Ws = [sj.initializers.he((j, i)) for j, i in zip(Ds[1:], Ds[:-1])]
    bs = [sj.initializers.he((j,)) for j in Ds[1:]]

    # Affine maps start as the identity on the input space.
    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]

    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]

    maps = [input]
    signs = []
    # The input "layer" has no activation, hence an all-ones mask.
    masks = [T.ones(Ds[0])]
    # Slope masks built from the supplied signs (ones prepended for the input).
    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1.,
                                     0.1)

    # Forward pass through the hidden layers (leaky ReLU via a slope mask).
    for w, b in zip(Ws[:-1], bs[:-1]):

        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., 0.1))

        maps.append(pre_activation * masks[-1])

    # Output layer is linear (no mask).
    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    # ``w * m`` scales the columns of w by the mask of the layer's inputs,
    # i.e. folds the previous activation slopes into the weight matrix.
    for start, end, w, b, m in zip(cumulative_units[:-1],
                                   cumulative_units[1:], Ws, bs, masks):

        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)

        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)

    # Entries [1:-1] skip the identity seed and the output layer, leaving the
    # hidden pre-activations as affine functions of the input.
    signs = T.concatenate(signs)
    ineq_b = T.concatenate(B_w[1:-1])
    ineq_A = T.vstack(A_w[1:-1])

    # Rows [b | A], oriented by the observed sign and divided by ||A_row||_2.
    inequalities = T.hstack([ineq_b[:, None], ineq_A])
    inequalities = inequalities * signs[:, None] / T.linalg.norm(ineq_A, 2,
                                                         1, keepdims=True)

    # Same construction for the sign-pattern maps, oriented by in_signs
    # (no row normalisation here).
    inequalities_code = T.hstack([T.concatenate(B_q[1:-1])[:, None],
                                  T.vstack(A_q[1:-1])])
    inequalities_code = inequalities_code * in_signs[:, None]

    f = sj.function(input, outputs=[maps[-1], A_w[-1], B_w[-1],
                                    inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])

    return f, g, h, all_g
Esempio n. 6
0
def test_stop():
    """``stop_gradient`` removes the a**2 term from d(a + a**2)/da at a = 1."""
    one = symjax.tensor.ones(())

    full = one + one**2
    grad_full = symjax.function(outputs=symjax.gradients(full, [one])[0])
    assert grad_full() == 3

    stopped = one + symjax.tensor.stop_gradient(one**2)
    grad_stopped = symjax.function(outputs=symjax.gradients(stopped, [one])[0])
    assert grad_stopped() == 1
Esempio n. 7
0
def test_clone_0():
    """Cloning with ``{w: u}`` substitutes the Variable by the placeholder."""
    sj.current_graph().reset()
    weight = T.Variable(1.0, dtype="float32")
    with sj.Scope("placing"):
        u = T.Placeholder((), "float32", name="u")
    value = 2 * weight * u
    # 2 * w * u becomes 2 * u * u after the substitution.
    squared = value.clone({weight: u})
    original_fn = sj.function(u, outputs=value)
    cloned_fn = sj.function(u, outputs=squared)

    results = [original_fn(1), cloned_fn(1), original_fn(2), cloned_fn(2)]
    assert np.array_equal(results, [2, 2, 4, 8])
Esempio n. 8
0
def test_seed():
    """Seeded draws are reproducible across rebuilds; unseeded ones are not.

    Each build creates ``[randn(seed=10), randn(), randn(seed=10)]``.  Within
    a build, the two seeded draws must agree, differ from the unseeded one,
    and change between successive calls.  Across builds (including after a
    graph reset), the seeded draws must match and the unseeded one must not.
    """

    def draw_twice():
        seeded_a = T.random.randn((), seed=10)
        unseeded = T.random.randn(())
        seeded_c = T.random.randn((), seed=10)
        fn = symjax.function(outputs=[seeded_a, unseeded, seeded_c])
        first = fn()
        second = fn()
        return first, second

    def check_within(first, second):
        assert first[0] == first[2]
        assert first[0] != first[1]
        assert second[0] == second[2]
        assert second[0] != first[0]

    def check_across(old, new):
        assert np.isclose(old[0], new[0])
        assert np.isclose(old[2], new[2])
        assert not np.isclose(old[1], new[1])

    result1, result2 = draw_twice()
    print(result1)
    print(result2)
    check_within(result1, result2)

    result12, result22 = draw_twice()
    check_within(result12, result22)
    check_across(result1, result12)
    check_across(result2, result22)

    symjax.current_graph().reset()

    result12, result22 = draw_twice()
    check_within(result12, result22)
    check_across(result1, result12)
    check_across(result2, result22)
Esempio n. 9
0
    def update_target(self, tau=None):
        """Soft-update the target networks towards the current ones.

        Each target parameter ``t`` becomes ``t * (1 - tau) + a * tau``,
        where ``a`` is the matching parameter of ``self.actor`` /
        ``self.critic``.  The update function is compiled on the first call
        and cached as ``self._update_target``.

        Args:
            tau: interpolation factor.  ``None`` falls back to ``self.tau``.
                An explicit ``0.0`` is honoured and leaves the targets
                unchanged.

        Raises:
            RuntimeError: if ``tau`` is ``None`` and ``self`` has no ``tau``.
        """
        if not hasattr(self, "_update_target"):
            with symjax.Scope("update_target"):
                targets = []
                currents = []
                if hasattr(self, "target_actor"):
                    targets += self.target_actor.params(True)
                    currents += self.actor.params(True)
                if hasattr(self, "target_critic"):
                    targets += self.target_critic.params(True)
                    currents += self.critic.params(True)

                _tau = T.Placeholder((), "float32")
                updates = {
                    t: t * (1 - _tau) + a * _tau
                    for t, a in zip(targets, currents)
                }
            self._update_target = symjax.function(_tau, updates=updates)

        if tau is None:
            if not hasattr(self, "tau"):
                raise RuntimeError("tau must be specified")
            # Fall back only on None; ``tau or self.tau`` would discard 0.0.
            tau = self.tau
        self._update_target(tau)
Esempio n. 10
0
def SJ(x, y, N, preallocate=False):
    """Time ``N`` Adam steps of a linear least-squares model in SymJAX.

    Uses the module-level ``BS`` (batch size), ``D`` (input dimension) and
    ``lr`` (learning rate).

    Args:
        x: inputs fed to a ``(BS, D)`` float32 placeholder.
        y: targets fed to a ``(BS, 1)`` float32 placeholder.
        N: number of training steps to time.
        preallocate: if True, move ``x`` and ``y`` with ``jax.device_put``
            before the timed loop (presumably so host-to-device transfers
            are not timed).

    Returns:
        Elapsed seconds for the ``N`` steps.
    """
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)

    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1).astype("float32"))

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    optimizers.Adam(sj_loss, lr)

    train = symjax.function(sj_input, sj_output, updates=symjax.get_updates())

    if preallocate:
        import jax

        x = jax.device_put(x)
        y = jax.device_put(y)

    # perf_counter is monotonic and high-resolution, unlike time.time,
    # which can jump with system clock adjustments.
    start = time.perf_counter()
    for _ in range(N):
        train(x, y)

    return time.perf_counter() - start
Esempio n. 11
0
def SJ(x, y, N, lr, model, preallocate=False):
    """Train a linear least-squares model in SymJAX and record the losses.

    Uses the module-level ``BS`` (batch size) and ``D`` (input dimension).

    Args:
        x: inputs fed to a ``(BS, D)`` float32 placeholder.
        y: targets fed to a ``(BS, 1)`` float32 placeholder.
        N: number of training steps.
        lr: learning rate.
        model: optimizer name, ``"SGD"`` or ``"Adam"``.
        preallocate: if True, move ``x`` and ``y`` with ``jax.device_put``
            before training, as the timing variant of this benchmark does.

    Returns:
        List with the loss returned by each of the ``N`` steps.

    Raises:
        ValueError: if ``model`` is neither ``"SGD"`` nor ``"Adam"``.  An
            unknown name would otherwise register no updates at all.
    """
    if model not in ("SGD", "Adam"):
        raise ValueError(f"unknown model {model!r}; expected 'SGD' or 'Adam'")

    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)

    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1).astype("float32"))

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output)**2).mean()

    if model == "SGD":
        optimizers.SGD(sj_loss, lr)
    else:
        optimizers.Adam(sj_loss, lr)
    train = symjax.function(sj_input,
                            sj_output,
                            outputs=sj_loss,
                            updates=symjax.get_updates())

    if preallocate:
        import jax

        x = jax.device_put(x)
        y = jax.device_put(y)

    return [train(x, y) for _ in tqdm(range(N))]
Esempio n. 12
0
def test_stack():
    """``stack([u, 2u, 3u])`` with u = ones(2) has rows of 1s, 2s and 3s."""
    base = tt.Variable(tt.ones((2, )))
    stacked = tt.stack([base, 2 * base, 3 * base])
    fn = symjax.function(outputs=stacked)
    expected = np.arange(1, 4)[:, None] * np.ones((3, 2))
    assert np.allclose(fn(), expected)
    for _ in range(2):
        print(fn())
Esempio n. 13
0
def test_base():
    """Eager ``get()`` and a compiled function both evaluate sum(ones(10)).

    Previously this only printed values and built a throwaway list, so it
    could not fail on a wrong result.
    """
    a = T.ones((10, ))
    b = a.sum()
    # Evaluate eagerly twice to exercise repeated ``get()`` calls.
    for _ in range(2):
        assert b.get() == 10
    f = symjax.function(outputs=b)
    # Plain loop: the calls are made for their checks, not to build a list.
    for _ in range(100):
        assert f() == 10
Esempio n. 14
0
def create_glo(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               GLO=False):
    """Build a GLO / "HARD" latent model with an MLP decoder.

    The latent codes ``z`` are a free Variable with one row per batch item.
    ``train`` takes one Adam step on the decoder weights and biases.
    ``estimate`` takes one Adam step on ``z`` only and returns it.

    With ``GLO=True`` the loss is the per-batch squared reconstruction error
    plus a weight prior.  Otherwise it adds a Gaussian likelihood scaled by
    ``exp(logvar_x)`` and a quadratic penalty on ``z``.

    Returns a dict with ``train``, ``estimate``, ``params``, ``reset``
    (assign new values to ``z``), ``model``, ``loss`` and ``kwargs``.
    """
    x = T.Placeholder([batch_size, Ds[-1]], 'float32')
    z = T.Variable(T.random.randn((batch_size, Ds[0])))
    logvar_x = T.Variable(T.ones(1))

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    h = [z]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    x_hat = h[-1]

    # LOSS
    weight_prior = (sum([T.sum(w**2) for w in Ws], 0.) / cov_W
                    + sum([T.sum(v**2) for v in bs[:-1]], 0.) / cov_b)
    if GLO:
        loss = T.sum((x - x_hat)**2) / batch_size + weight_prior
    else:
        loss = (Ds[-1] * logvar_x.sum()
                + T.sum((x - x_hat)**2 / T.exp(logvar_x)) / batch_size
                + (z**2).sum() / batch_size
                + weight_prior)
    # NOTE(review): logvar_x is in neither optimizer's params below, so it
    # keeps its initial value -- confirm this is intended for "HARD".
    variables = Ws + bs

    # NOTE(review): ``loss`` already contains ``weight_prior``, yet the
    # training optimizer minimises ``loss + prior`` with a second, slightly
    # different prior that also covers bs[-1].  Weight decay is therefore
    # applied twice -- confirm whether this is intended.
    prior = (sum([(b**2).sum() for b in bs], 0.) / cov_b
             + sum([(w**2).sum() for w in Ws], 0.) / cov_W)

    opti = sj.optimizers.Adam(loss + prior, lr, params=variables)
    infer = sj.optimizers.Adam(loss, lr, params=[z])

    estimate = sj.function(x, outputs=z, updates=infer.updates)
    train = sj.function(x, outputs=loss, updates=opti.updates)
    lossf = sj.function(x, outputs=loss)
    params = sj.function(outputs=Ws + bs + [T.ones(Ds[-1]) * T.exp(logvar_x)])

    output = {
        'train': train,
        'estimate': estimate,
        'params': params,
        'reset': lambda v: z.assign(v),
        'model': 'GLO' if GLO else 'HARD',
        'loss': lossf,
        'kwargs': {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                   'leakiness': leakiness, 'lr': lr, 'scaler': scaler},
    }
    return output
Esempio n. 15
0
 def RNTK_function(self):
     """Build the symbolic RNTK / GP recursion over an (N, length) batch.

     ``self.RNTK_first`` initialises the kernels from the first column of the
     data.  The remaining columns are scanned with ``self.RNTK_middle``, and
     ``self.RNTK_output`` reads out the final values.

     Returns:
         ``(RNTK_last, RNTK_avg)`` as produced by ``self.RNTK_output``.
     """
     print(f"N, {self.N}, length, {self.length}")
     data = T.Placeholder((self.N, self.length), 'float32')
     rntk, gp = self.RNTK_first(data[:, 0])
     # Transposed so that each scan step receives one column (all N rows).
     columns = T.transpose(data[:, 1:])
     states, _ = T.scan(lambda a, b: self.RNTK_middle(a, b),
                        sequences=[columns],
                        init=T.stack([rntk, gp]))
     rntk_last, rntk_avg = self.RNTK_output(states)
     # NOTE(review): this compiled function is built but never returned or
     # stored; either return it or drop it -- confirm intent.
     _compiled = symjax.function(data, outputs=[rntk_last, rntk_avg])
     return rntk_last, rntk_avg
Esempio n. 16
0
def RNTK_function(N, length, param):
    """Compile the RNTK recursion for data of shape ``(N, length)``.

    ``param`` must provide the keys ``sigmaw``, ``sigmau``, ``sigmab``,
    ``sigmah``, ``sigmav``, ``L`` and ``Lf``.

    Returns:
        A compiled ``symjax.function`` mapping the data array to
        ``[RNTK_last, RNTK_avg]``.
    """
    data = T.Placeholder((N, length), 'float32')
    sw = param['sigmaw']
    su = param['sigmau']
    sb = param['sigmab']
    sh = param['sigmah']
    L = param['L']
    Lf = param['Lf']
    sv = param['sigmav']

    rntk, gp = RNTK_first(data[:, 0], sw, su, sb, sh, L, Lf, sv)

    def step(a, b):
        return RNTK_middle(a, b, sw, su, sb, L, Lf, sv)

    # One scan step per remaining column of the data.
    states, _ = T.scan(step,
                       sequences=[T.transpose(data[:, 1:])],
                       init=T.stack([rntk, gp]))
    rntk_last, rntk_avg = RNTK_output(states, sv, L, Lf, length)
    return symjax.function(data, outputs=[rntk_last, rntk_avg])
Esempio n. 17
0
    def __init__(self, states, actions=None):
        """Build the critic graph for a batch and for a single sample.

        ``self.create_network`` is applied to ``states`` (and ``actions``
        when given) inside the ``"critic"`` scope.  A ``(batch, 1)`` output is
        squeezed to ``(batch,)``.  The network is then cloned onto
        placeholders whose leading dimension is 1.  The compiled helpers
        ``_get_q_values`` (batch) and ``_get_q_value`` (first element of the
        single-sample output) are stored on ``self``.

        Args:
            states: state tensor; its leading dimension is replaced by 1 for
                the single-sample placeholder.
            actions: optional action tensor, handled the same way.
        """
        self.state_shape = states.shape[1:]
        state = T.Placeholder((1, ) + states.shape[1:],
                              "float32",
                              name="critic_state")
        # ``actions`` is a tensor: compare with None explicitly, because the
        # truthiness of an array-like is ambiguous (or raises).
        if actions is not None:
            self.action_shape = actions.shape[1:]
            action = T.Placeholder((1, ) + actions.shape[1:],
                                   "float32",
                                   name="critic_action")

            with symjax.Scope("critic"):
                q_values = self.create_network(states, actions)
                if q_values.ndim == 2:
                    assert q_values.shape[1] == 1
                    q_values = q_values[:, 0]
                # Same network with the batch inputs swapped for the
                # single-sample placeholders.
                q_value = q_values.clone({states: state, actions: action})
                self._params = symjax.get_variables(
                    trainable=None, scope=symjax.current_graph().scope_name)

            batch_inputs = [states, actions]
            single_inputs = [state, action]
            self.actions = actions
            self.action = action

        else:
            with symjax.Scope("critic"):
                q_values = self.create_network(states)
                if q_values.ndim == 2:
                    assert q_values.shape[1] == 1
                    q_values = q_values[:, 0]
                q_value = q_values.clone({states: state})
                self._params = symjax.get_variables(
                    trainable=None, scope=symjax.current_graph().scope_name)

            batch_inputs = [states]
            single_inputs = [state]

        self.q_values = q_values
        self.state = state
        self.states = states

        self._get_q_values = symjax.function(*batch_inputs, outputs=q_values)
        self._get_q_value = symjax.function(*single_inputs,
                                            outputs=q_value[0])
Esempio n. 18
0
def test_g():
    """d(a * b)/da equals b, and b increases by 1 after every call."""
    a = symjax.tensor.ones(())
    b = symjax.tensor.Variable(1.0)
    grad_a = symjax.gradients(a * b, [a])[0]
    f = symjax.function(outputs=grad_a, updates={b: b + 1.0})
    for expected in (1, 2, 3):
        assert f() == expected
Esempio n. 19
0
def test_grad():
    """d(w * v + 2)/dv equals the value fed to the placeholder ``w``."""
    w = tt.Placeholder((), "float32")
    v = tt.Variable(1.0, dtype="float32")
    expr = w * v + 2
    dv = symjax.gradients(expr.sum(), [v])[0]
    grad_fn = symjax.function(w, outputs=dv)
    for value in (1, 10):
        assert grad_fn(value) == value
Esempio n. 20
0
    def build_net(self, Q):
        """Build the DQN training graph and compile ``train`` / ``q_eval``.

        ``Q(state, n_actions)`` is called under the ``"eval_net"`` scope for
        the current states and under ``"test_set"`` for the next states; the
        latter presumably has its own parameters -- confirm with ``Q``.  The
        TD target ``reward + reward_decay * max_a Q(next_state, a)`` is
        wrapped in ``stop_gradient``.  The loss is the mean squared error
        against ``Q(state, action)``.  The Adam updates are collected with
        ``symjax.get_updates()``.

        Args:
            Q: callable ``Q(states, n_actions)`` returning per-action values.
        """
        # ------------------ all inputs ------------------------
        state = T.Placeholder([self.batch_size, self.n_states],
                              "float32",
                              name="s")
        next_state = T.Placeholder([self.batch_size, self.n_states],
                                   "float32",
                                   name="s_")
        reward = T.Placeholder([self.batch_size], "float32",
                               name="r")  # input reward
        action = T.Placeholder([self.batch_size], "int32",
                               name="a")  # input Action

        with symjax.Scope("eval_net"):
            q_eval = Q(state, self.n_actions)
        with symjax.Scope("test_set"):
            q_next = Q(next_state, self.n_actions)

        q_target = reward + self.reward_decay * q_next.max(1)
        q_target = T.stop_gradient(q_target)

        # Q-value of the action actually taken, one per batch row.
        q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)),
                                         1).squeeze(1)
        loss = T.mean((q_target - q_eval_wrt_a)**2)
        nn.optimizers.Adam(loss, self.lr)

        self.train = symjax.function(state,
                                     action,
                                     reward,
                                     next_state,
                                     updates=symjax.get_updates())
        self.q_eval = symjax.function(state, outputs=q_eval)
Esempio n. 21
0
def test_updating_variables():
    """The update rule ``w1 <- w1 + input + 1`` is applied once per call."""
    sj.current_graph().reset()
    w1 = symjax.tensor.Variable(1.0, dtype="float32")
    increment = symjax.tensor.Placeholder((), "float32")
    f = symjax.function(increment, updates={w1: w1 + increment + 1})

    assert w1.value == 1.0
    f(10)
    assert w1.value == 12.0
Esempio n. 22
0
def test_map():
    """``T.map`` sees the current ``w`` and ``u`` on every call."""
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32")

    def scaled(a, w, u):
        return (u - w) * a

    out = T.map(scaled, [T.range(3)], non_sequences=[w, u])
    f = sj.function(u, outputs=out, updates={w: w + 1})
    # w is 1, 2 and 3 during the three calls (incremented after each one).
    checks = [(2, np.arange(3)), (2, np.zeros(3)), (0, -np.arange(3) * 3)]
    for u_value, expected in checks:
        assert np.array_equal(f(u_value), expected)
Esempio n. 23
0
def SJ_EMA(X, debias=True):
    """Feed ``X`` one scalar at a time through SymJAX's moving average.

    Args:
        X: iterable of scalar observations.
        debias: forwarded to ``ExponentialMovingAverage``.

    Returns:
        List of the average returned after each observation.  The average
        is built with factor 0.9 and reset with the graph on every call.
    """
    symjax.current_graph().reset()
    x = T.Placeholder((), "float32", name="x")
    value = symjax.nn.schedules.ExponentialMovingAverage(x, 0.9,
                                                         debias=debias)[0]
    train = symjax.function(x, outputs=value, updates=symjax.get_updates())
    # Iterate directly instead of indexing with range(len(X)).
    return [train(sample) for sample in X]
Esempio n. 24
0
    def create_func_for_diag(self,
                             dim1idx,
                             dim2idx,
                             function=False,
                             jmode=False):
        """Scan the (lambda, phi) kernel recursion along one diagonal.

        The scan starts from the boundary condition
        ``sh**2 * sw**2 * I + su**2 * X + sb**2``, stacked twice (lambda at
        index 0, phi at index 1).  Each step computes
        ``S, D = self.VT(lambda)``, then
        ``lambda' = sw**2 * S + su**2 * X + sb**2`` and
        ``phi' = lambda' + sw**2 * phi * D``.

        Args:
            dim1idx, dim2idx: forwarded to ``self.make_inputs``; presumably
                the start coordinates of the diagonal -- TODO confirm.
            function: if True, return a compiled ``symjax.function`` of the
                diagonal input instead of the symbolic tensor.
            jmode: forwarded to ``self.make_inputs``.

        Returns:
            The boundary condition followed by every scan state, stacked
            along a new leading axis, or a compiled function producing it.
        """
        diag = self.make_inputs(dim1idx, dim2idx, jmode=jmode)

        bc = self.sh**2 * self.sw**2 * T.eye(
            self.n, self.n) + (self.su**2) * self.X + self.sb**2
        single_boundary_condition = T.expand_dims(bc, axis=0)
        # Same initial value for lambda (index 0) and phi (index 1).
        boundary_condition = T.concatenate(
            [single_boundary_condition, single_boundary_condition])

        def fn(prev_vals, idx, Xph):
            # ``idx`` is the current element of ``diag`` (the scanned
            # sequence); the recursion itself does not use it.
            prev_lambda = prev_vals[0]
            prev_phi = prev_vals[1]
            S, D = self.VT(prev_lambda)
            new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2
            new_phi = new_lambda + self.sw**2 * prev_phi * D
            lambda_expanded = T.expand_dims(new_lambda, axis=0)
            phi_expanded = T.expand_dims(new_phi, axis=0)
            to_return = T.concatenate([lambda_expanded, phi_expanded])
            # The new state is both the carry and the per-step output.
            return to_return, to_return

        _, all_ema = T.scan(fn,
                            init=boundary_condition,
                            sequences=[diag],
                            non_sequences=[self.X])

        # Prepend the boundary condition so index 0 is the initial state.
        expanded_ema = T.concatenate(
            [T.expand_dims(boundary_condition, axis=0), all_ema])
        if function:
            return symjax.function(diag, outputs=expanded_ema)
        return expanded_ema
Esempio n. 25
0
def test_vectorize_sgd():
    """SGD on placeholders with a free leading axis works for any batch size.

    The loss must decrease at every step.  The gradient of ``sum(x @ w)``
    with respect to ``w`` must equal the column sums of ``x``.
    """
    sj.current_graph().reset()
    x = symjax.tensor.Placeholder((0, 2), "float32")
    y = symjax.tensor.Placeholder((0, ), "float32")

    w = symjax.tensor.Variable((1, 1), dtype="float32")
    prediction = x.dot(w)
    loss = ((prediction - y)**2).mean()

    grad_w = symjax.gradients(loss, [w])[0]
    grad_of_sum = symjax.gradients(prediction.sum(), [w])[0]

    step = symjax.function(x, y, outputs=loss, updates={w: w - 0.1 * grad_w})
    sum_grad = symjax.function(x, outputs=grad_of_sum)

    previous = 10
    for n in range(1, 11):
        current = step(np.ones((n, 2)), -1 * np.ones(n))
        assert current < previous
        previous = current
        assert np.array_equal(sum_grad(np.ones((n, 2))), [float(n), float(n)])
Esempio n. 26
0
def test_pc():
    """PiecewiseConstant goes 0 -> 1 at count 4 and 1 -> 2 at count 8."""
    value, counter = symjax.nn.schedules.PiecewiseConstant(0, {4: 1, 8: 2})
    f = symjax.function(outputs=value, updates={counter: counter + 1})
    for step in range(10):
        expected = 0 if step < 4 else (1 if step < 8 else 2)
        assert np.array_equal(f(), expected)
Esempio n. 27
0
def test_grad_map():
    """d/dw of sum over a in {0, 1, 2} of w * a * u equals 3 * u."""
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")

    def body(a, w, u):
        return w * a * u

    out = T.map(body, (T.range(3), ), non_sequences=(w, u))
    grad = sj.gradients(out.sum(), w)
    f = sj.function(u, outputs=grad)

    for u_value, expected in ((0, 0), (1, 3)):
        assert np.array_equal(f(u_value), expected)
Esempio n. 28
0
def test_vectorize():
    """Row sums and a Variable update both work for varying batch sizes."""
    sj.current_graph().reset()
    x = symjax.tensor.Placeholder((0, 2), "float32")
    w = symjax.tensor.Variable(1.0, dtype="float32")
    row_sums = x.sum(1)

    f = symjax.function(x, outputs=row_sums, updates={w: x.sum()})

    for rows in (1, 2):
        assert np.array_equal(f(np.ones((rows, 2))), [2.0] * rows)
        assert w.value == 2.0 * rows
Esempio n. 29
0
def test_sma():
    """SimpleMovingAverage with window 3 matches NumPy means of recent rows."""
    symjax.current_graph().reset()
    a = symjax.tensor.Placeholder((4, ), "float32")
    sma, var = symjax.nn.schedules.SimpleMovingAverage(a, 3)
    f = symjax.function(a, outputs=[sma, var], updates=symjax.get_updates())

    data = np.random.randn(4, 4)
    # Mean of the last (up to) 3 rows seen so far.
    expected = [data[max(0, k - 2):k + 1].mean(0) for k in range(4)]

    for row, target in zip(data, expected):
        assert np.allclose(f(row)[0], target)
Esempio n. 30
0
def test_global_pool():
    """Three stacked Dense layers widen the features 64 -> 128 -> 256.

    Despite the name, no pooling layer is exercised here.
    """
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 4096
    DIM = 8
    x = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")

    hidden = nn.layers.Dense(x, 64)
    for _ in range(2):
        hidden = nn.layers.Dense(hidden, hidden.shape[-1] * 2)
    get = sj.function(x, outputs=hidden)
    assert get(np.ones((BATCH_SIZE, DIM))).shape == (BATCH_SIZE, 64 * 4)