Example #1
def SJ(x, y, N, lr, model, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)

    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1, ).astype("float32"))

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output)**2).mean()

    if model == "SGD":
        optimizers.SGD(sj_loss, lr)
    elif model == "Adam":
        optimizers.Adam(sj_loss, lr)
    train = symjax.function(sj_input,
                            sj_output,
                            outputs=sj_loss,
                            updates=symjax.get_updates())

    losses = []
    for i in tqdm(range(N)):
        losses.append(train(x, y))

    return losses
Example #2
def SJ(x, y, N, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)

    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(
        np.random.randn(
            1,
        ).astype("float32")
    )

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    optimizers.Adam(sj_loss, lr)

    train = symjax.function(sj_input, sj_output, updates=symjax.get_updates())

    if preallocate:
        import jax

        x = jax.device_put(x)
        y = jax.device_put(y)

    t = time.time()
    for i in range(N):
        train(x, y)

    return time.time() - t
Example #3
def test_bn():
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((1,), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(input, [1], deterministic=deterministic)

    update = sj.function(input, deterministic, outputs=bn, updates=sj.get_updates())
    get_stats = sj.function(input, outputs=bn.avg_mean)

    data = np.random.randn(50, DIM) * 4 + 2

    true_means = []
    actual_means = []

    for i in range(10):
        batch = data[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
        output = update(batch, 0)
        assert np.allclose(
            output, (batch - batch.mean(0)) / (1e-4 + batch.std(0)), 1e-4
        )
        actual_means.append(get_stats(batch))
        if i == 0:
            true_means.append(batch.mean(0))
        else:
            true_means.append(0.9 * true_means[-1] + 0.1 * batch.mean(0))

    true_means = np.array(true_means)
    actual_means = np.array(actual_means).squeeze()

    assert np.allclose(true_means, actual_means, 1e-4)
Example #4
def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')

    # ENCODER
    enc = encoder(x, Ds[0])
    mu = enc[-1][:, :Ds[0]]
    logvar = enc[-1][:, Ds[0]:]
    var = T.exp(logvar)
 
    z = mu + T.exp(0.5 * logvar) * T.random.randn((batch_size, Ds[0]))
    z_ph = T.Placeholder((batch_size, Ds[0]), 'float32')

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)

    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    logvar_x = T.Variable(T.zeros(1), name='logvar_x') 
    var_x = T.exp(logvar_x)

    h, h_ph = [z], [z_ph]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
        h_ph.append(T.matmul(h_ph[-1], w.transpose()) + b)
        h_ph.append(h_ph[-1] * relu_mask(h_ph[-1], leakiness))

    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    h_ph.append(T.matmul(h_ph[-1], Ws[-1].transpose()) + bs[-1])

    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
            + sum([T.mean(v**2) for v in bs[:-1]], 0.) / cov_b 
    # "kl" here is the negative KL divergence between the approximate posterior and the unit Gaussian prior
    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    # Gaussian reconstruction log-likelihood (up to an additive constant) with learned observation variance
    px = - 0.5 * (logvar_x + ((x - h[-1])**2 / var_x)).sum(1)
    loss = - (px + kl).mean() + prior

    variables = Ws + bs + sj.layers.get_variables(enc) + [logvar_x]
    opti = sj.optimizers.Adam(loss, lr, params=variables)

    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=h_ph[-1])
    params = sj.function(outputs = Ws + bs + [T.exp(logvar_x) * T.ones(Ds[-1])])
    get_varx = sj.function(outputs = var_x)


    output = {'train': train, 'g':g, 'params':params}
    output['model'] = 'VAE'
    output['varx'] = get_varx
    output['kwargs'] = {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler,
                    'prior': sj.function(outputs=prior)}
    def sample(n):
        samples = []
        for i in range(n // batch_size):
            samples.append(g(np.random.randn(batch_size, Ds[0])))
        return np.concatenate(samples)
    output['sample'] = sample
    return output
Example #5
    def update_target(self, tau=None):

        if not hasattr(self, "_update_target"):
            with symjax.Scope("update_target"):
                targets = []
                currents = []
                if hasattr(self, "target_actor"):
                    targets += self.target_actor.params(True)
                    currents += self.actor.params(True)
                if hasattr(self, "target_critic"):
                    targets += self.target_critic.params(True)
                    currents += self.critic.params(True)

                _tau = T.Placeholder((), "float32")
                # soft (Polyak) update: move each target parameter toward its
                # current counterpart by a factor _tau
                updates = {
                    t: t * (1 - _tau) + a * _tau
                    for t, a in zip(targets, currents)
                }
            self._update_target = symjax.function(_tau, updates=updates)

        if tau is None:
            if not hasattr(self, "tau"):
                raise RuntimeError("tau must be specified")
            tau = self.tau
        self._update_target(tau)
Example #6
    def make_inputs(self, dim1idx, dim2idx, jmode=False):
        # revelation: we don't actually need the diagonal number, just the
        # length, so we no longer need an arange
        # print("dim1idx", dim1idx)
        test = jnp.min(jnp.array([self.dim_1, self.dim_2])) - (dim1idx + dim2idx)
        # print("test", test)
        return T.Placeholder((test,), "int32")
Example #7
def RNTK_function(N, length, param):
    DATA = T.Placeholder((N, length), 'float32')
    RNTK, GP = RNTK_first(DATA[:, 0], param['sigmaw'], param['sigmau'], param['sigmab'],
                          param['sigmah'], param['L'], param['Lf'], param['sigmav'])
    v, _ = T.scan(
        lambda a, b: RNTK_middle(a, b, param['sigmaw'], param['sigmau'], param['sigmab'],
                                 param['L'], param['Lf'], param['sigmav']),
        sequences=[T.transpose(DATA[:, 1:])],
        init=T.stack([RNTK, GP]),
    )
    RNTK_last, RNTK_avg = RNTK_output(v, param['sigmav'], param['L'], param['Lf'], length)
    f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
    return f
Example #8
    def RNTK_function(self):
        print(f"N, {self.N}, length, {self.length}")
        DATA = T.Placeholder((self.N, self.length), 'float32')
        RNTK, GP = self.RNTK_first(DATA[:, 0])
        v, _ = T.scan(lambda a, b: self.RNTK_middle(a, b),
                      sequences=[T.transpose(DATA[:, 1:])],
                      init=T.stack([RNTK, GP]))
        RNTK_last, RNTK_avg = self.RNTK_output(v)
        f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
        return RNTK_last, RNTK_avg
Example #9
    def __init__(self, states, actions=None):
        self.state_shape = states.shape[1:]
        state = T.Placeholder((1, ) + states.shape[1:],
                              "float32",
                              name="critic_state")
        if actions is not None:
            self.action_shape = actions.shape[1:]
            action = T.Placeholder((1, ) + actions.shape[1:],
                                   "float32",
                                   name="critic_action")
            action_shape = action.shape[1:]

            with symjax.Scope("critic"):
                q_values = self.create_network(states, actions)
                if q_values.ndim == 2:
                    assert q_values.shape[1] == 1
                    q_values = q_values[:, 0]
                q_value = q_values.clone({states: state, actions: action})
                self._params = symjax.get_variables(
                    trainable=None, scope=symjax.current_graph().scope_name)

            inputs = [states, actions]
            input = [state, action]
            self.actions = actions
            self.action = action

        else:
            with symjax.Scope("critic"):
                q_values = self.create_network(states)
                if q_values.ndim == 2:
                    assert q_values.shape[1] == 1
                    q_values = q_values[:, 0]
                q_value = q_values.clone({states: state})
                self._params = symjax.get_variables(
                    trainable=None, scope=symjax.current_graph().scope_name)

            inputs = [states]
            input = [state]

        self.q_values = q_values
        self.state = state
        self.states = states

        self._get_q_values = symjax.function(*inputs, outputs=q_values)
        self._get_q_value = symjax.function(*input, outputs=q_value[0])
Example #10
def test_grad():
    w = tt.Placeholder((), "float32")
    v = tt.Variable(1.0, dtype="float32")
    x = w * v + 2
    #    symjax.nn.optimizers.Adam(x, 0.001)
    g = symjax.gradients(x.sum(), [v])[0]
    f = symjax.function(w, outputs=g)
    assert f(1) == 1
    assert f(10) == 10
Example #11
    def build_net(self, Q):
        # ------------------ all inputs ------------------------
        state = T.Placeholder([self.batch_size, self.n_states],
                              "float32",
                              name="s")
        next_state = T.Placeholder([self.batch_size, self.n_states],
                                   "float32",
                                   name="s_")
        reward = T.Placeholder(
            [
                self.batch_size,
            ],
            "float32",
            name="r",
        )  # input reward
        action = T.Placeholder(
            [
                self.batch_size,
            ],
            "int32",
            name="a",
        )  # input Action

        with symjax.Scope("eval_net"):
            q_eval = Q(state, self.n_actions)
        with symjax.Scope("test_set"):
            q_next = Q(next_state, self.n_actions)

        q_target = reward + self.reward_decay * q_next.max(1)
        q_target = T.stop_gradient(q_target)

        a_indices = T.stack([T.range(self.batch_size), action], axis=1)
        q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)),
                                         1).squeeze(1)
        loss = T.mean((q_target - q_eval_wrt_a)**2)
        nn.optimizers.Adam(loss, self.lr)

        self.train = symjax.function(state,
                                     action,
                                     reward,
                                     next_state,
                                     updates=symjax.get_updates())
        self.q_eval = symjax.function(state, outputs=q_eval)
Example #12
def SJ_EMA(X, debias=True):
    symjax.current_graph().reset()
    x = T.Placeholder((), "float32", name="x")
    value = symjax.nn.schedules.ExponentialMovingAverage(x, 0.9,
                                                         debias=debias)[0]
    train = symjax.function(x, outputs=value, updates=symjax.get_updates())
    outputs = []
    for i in range(len(X)):
        outputs.append(train(X[i]))
    return outputs
Example #13
def create_func(dic, printbool=False):
    N = int(dic["n_patrons1="])
    ti_length = int(dic["n_entradasTi="])
    ti_prime_length = int(dic["n_entradasTiP="])
    DATA = T.Placeholder((N, ti_length), 'float32', name="X")
    DATAPRIME = T.Placeholder((N, ti_prime_length), 'float32', name="Xprime")
    # x = DATA[:,0]
    # X = x*x[:, None]
    # n = X.shape[0]

    rntkod = RNTK(dic, DATA, DATAPRIME) #could be flipped 

    start = time.time()
    kernels_ema = rntkod.create_func_for_diag()
    diag_func = symjax.function(DATA, DATAPRIME, outputs=kernels_ema)
    if printbool:
        print("time to create symjax", time.time() - start)

    return diag_func, rntkod
Example #14
def test_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32")
    out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)],
                non_sequences=[w, u])
    f = sj.function(u, outputs=out, updates={w: w + 1})
    assert np.array_equal(f(2), np.arange(3))
    assert np.array_equal(f(2), np.zeros(3))
    assert np.array_equal(f(0), -np.arange(3) * 3)
Example #15
def test_clone_0():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    with sj.Scope("placing"):
        u = T.Placeholder((), "float32", name="u")
    value = 2 * w * u
    c = value.clone({w: u})
    f = sj.function(u, outputs=value)
    g = sj.function(u, outputs=c)

    assert np.array_equal([f(1), g(1), f(2), g(2)], [2, 2, 4, 8])
Example #16
def test_grad_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    out = T.map(lambda a, w, u: w * a * u, (T.range(3), ),
                non_sequences=(w, u))
    g = sj.gradients(out.sum(), w)
    f = sj.function(u, outputs=g)

    assert np.array_equal(f(0), 0)
    assert np.array_equal(f(1), 3)
Example #17
    def __init__(
        self,
        state_shape,
        actions_shape,
        n_episodes,
        episode_length,
        actor,
        lr=1e-3,
        gamma=0.99,
    ):
        self.actor = actor
        self.gamma = gamma
        self.lr = lr
        self.episode_length = episode_length
        self.n_episodes = n_episodes
        self.batch_size = episode_length * n_episodes

        states = T.Placeholder((self.batch_size, ) + state_shape, "float32")
        actions = T.Placeholder((self.batch_size, ) + actions_shape, "float32")
        discounted_rewards = T.Placeholder((self.batch_size, ), "float32")

        self.actor = actor(states, distribution="gaussian")

        logprobs = self.actor.actions.log_prob(actions)
        actor_loss = -(logprobs * discounted_rewards).sum() / n_episodes

        with symjax.Scope("REINFORCE_optimizer"):
            nn.optimizers.Adam(
                actor_loss,
                lr,
                params=self.actor.params(True),
            )

        # create the update function
        self._train = symjax.function(
            states,
            actions,
            discounted_rewards,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/REINFORCE_optimizer"),
        )
Example #18
def test_global_pool():
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 4096
    DIM = 8
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")

    output = nn.layers.Dense(input, 64)
    output = nn.layers.Dense(output, output.shape[-1] * 2)
    output = nn.layers.Dense(output, output.shape[-1] * 2)
    get = sj.function(input, outputs=output)
    assert get(np.ones((BATCH_SIZE, DIM))).shape == (BATCH_SIZE, 64 * 4)
Example #19
def test_dropout():
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 4096
    DIM = 8
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((), "bool", name="deterministic")

    bn = nn.layers.Dropout(input, p=0.2, deterministic=deterministic)

    update = sj.function(input, deterministic, outputs=bn)

    data = np.ones((BATCH_SIZE, DIM))

    output1 = update(data, 0)
    output2 = update(data, 0)
    output3 = update(data, 1)

    assert not np.allclose(output1, output2, 1e-1)
    assert np.allclose(output1.mean(0) / 2 + output2.mean(0) / 2, 1, 0.08)
    assert np.all(output3)
Example #20
def test_flip():
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 2048
    DIM = 8
    input = T.Placeholder((BATCH_SIZE, DIM, DIM), "float32", name="input")
    deterministic = T.Placeholder((1,), "bool", name="deterministic")

    bn = nn.layers.RandomFlip(input, axis=2, p=0.5, deterministic=deterministic)

    update = sj.function(input, deterministic, outputs=bn)

    data = np.ones((BATCH_SIZE, DIM, DIM))
    data[:, :, : DIM // 2] = 0

    output1 = update(data, 0)
    output2 = update(data, 0)
    output3 = update(data, 1)

    assert not np.allclose(output1, output2, 1e-1)
    assert np.allclose(output1.mean(0) / 2 + output2.mean(0) / 2, 0.5, 0.05)
    assert np.allclose(data, output3, 1e-6)
Example #21
def test_bn():
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(input, [1], deterministic=deterministic)

    update = sj.function(input,
                         deterministic,
                         outputs=bn,
                         updates=sj.get_updates())
    avmean = symjax.get_variables(trainable=None)[-3]
    print(avmean)
    get_stats = sj.function(input, outputs=avmean[0])

    data = np.random.randn(50, DIM) * 4 + 2

    true_means = [np.zeros(DIM)]
    actual_means = [np.zeros(DIM)]

    for i in range(10):
        batch = data[BATCH_SIZE * i:BATCH_SIZE * (i + 1)]

        output = update(batch, 0)
        assert np.allclose(
            output,
            (batch - batch.mean(0)) / np.sqrt(0.001 + batch.var(0)),
            1e-4,
        )

        actual_means.append(get_stats(batch))
        true_means.append(0.99 * true_means[-1] + 0.01 * batch.mean(0))

    true_means = np.array(true_means)
    actual_means = np.array(actual_means).squeeze()
    assert np.allclose(true_means, actual_means)
Example #22
def test_while():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    v = T.Placeholder((), "float32")
    out = T.while_loop(
        lambda i, u: i[0] + u < 5,
        lambda i: (i[0] + 1.0, i[0]**2),
        (w, 1.0),
        non_sequences_cond=(v, ),
    )
    f = sj.function(v, outputs=out)
    assert np.array_equal(np.array(f(0)), [5, 16])
    assert np.array_equal(f(2), [3, 4])
Example #23
def test_clone_base():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    w2 = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    uu = T.Placeholder((), "float32", name="uu")

    aa = T.Placeholder((), "float32")
    bb = T.Placeholder((), "float32")

    l = 2 * w * u * w2
    g = sj.gradients(l, w)
    guu = T.clone(l, {u: uu})
    guuu = T.clone(l, {w: uu})

    f = sj.function(u, outputs=g, updates={w2: w2 + 1})
    fuu = sj.function(uu, outputs=guu, updates={w2: w2 + 1})
    fuuu = sj.function(u, uu, outputs=guuu, updates={w2: w2 + 1})

    #    print(f(2))
    assert np.array_equal(f(2), 4.0)
    assert np.array_equal(fuu(1), 4)
    assert np.array_equal(fuuu(0, 0), 0)
Example #24
def test_cond2():
    sj.current_graph().reset()
    v = T.ones((10, 10))
    u = T.Placeholder((), "int32")
    out = T.cond(
        u > 0,
        lambda u: 4 * u,
        lambda u: u,
        true_inputs=(v, ),
        false_inputs=(2 * v, ),
    )
    f = sj.function(u, outputs=out)
    assert np.array_equal(f(1), 4 * np.ones((10, 10)))
    assert np.array_equal(f(0), 2 * np.ones((10, 10)))
Example #25
def create_glo(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               GLO=False):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')
    z = T.Variable(T.random.randn((batch_size, Ds[0])))
    logvar_x = T.Variable(T.ones(1))

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    h = [z]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])

    # LOSS
    prior = sum([T.sum(w**2) for w in Ws], 0.) / cov_W + sum([T.sum(v**2) for v in bs[:-1]], 0.) / cov_b
    if GLO:
        loss = T.sum((x - h[-1])**2) / batch_size + prior
        variables = Ws + bs
    else:
        loss = (Ds[-1] * logvar_x.sum()
                + T.sum((x - h[-1])**2 / T.exp(logvar_x)) / batch_size
                + (z**2).sum() / batch_size + prior)
        variables = Ws + bs

    prior = sum([(b**2).sum() for b in bs], 0.) / cov_b\
            + sum([(w**2).sum() for w in Ws], 0.) / cov_W
 
    opti = sj.optimizers.Adam(loss + prior, lr, params=variables)
    infer = sj.optimizers.Adam(loss, lr, params=[z])

    estimate = sj.function(x, outputs=z, updates=infer.updates)
    train = sj.function(x, outputs=loss, updates=opti.updates)
    lossf = sj.function(x, outputs=loss)
    params = sj.function(outputs = Ws + bs + [T.ones(Ds[-1]) * T.exp(logvar_x)])

    output = {'train': train, 'estimate':estimate, 'params':params}
    output['reset'] = lambda v: z.assign(v)
    if GLO:
        output['model'] = 'GLO'
    else:
        output['model'] = 'HARD'
    output['loss'] = lossf
    output['kwargs'] = {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler}
    return output
Example #26
def create_func(dic, printbool=False):
    N = int(dic["n_patrons1="])
    length = int(dic["n_entradas="])
    DATA = T.Placeholder((N, length), 'float32', name="X")
    x = DATA[:, 0]
    X = x * x[:, None]
    n = X.shape[0]
    print(n, N)

    rntkod = RNTK(dic, X, n)  #could be flipped

    start = time.time()
    lin_ema = rntkod.create_func_for_diag()
    diag_func = symjax.function(DATA, outputs=lin_ema)
    if printbool:
        print("time to create symjax", time.time() - start)

    return diag_func, rntkod
Example #27
def test_cond3():
    sj.current_graph().reset()
    v = T.ones((10, 10)) * 3
    u = T.Placeholder((), "int32")
    out = T.cond(
        u > 0,
        lambda a, u: a * u,
        lambda a, u: a + u,
        true_inputs=(
            2 * T.ones((10, 10)),
            v,
        ),
        false_inputs=(
            2 * T.ones((10, 10)),
            v,
        ),
    )
    f = sj.function(u, outputs=out)
    assert np.array_equal(f(1), 6 * np.ones((10, 10)))
    assert np.array_equal(f(0), 5 * np.ones((10, 10)))
Example #28
    def __init__(self, states, actions_distribution=None, name="actor"):

        self.state_shape = states.shape[1:]
        state = T.Placeholder((1, ) + states.shape[1:], "float32")
        self.actions_distribution = actions_distribution

        with symjax.Scope(name):
            if actions_distribution == symjax.probabilities.Normal:

                means, covs = self.create_network(states)

                actions = actions_distribution(means, cov=covs)
                samples = actions.sample()
                samples_log_prob = actions.log_prob(samples)

                action = symjax.probabilities.MultivariateNormal(
                    means.clone({states: state}),
                    cov=covs.clone({states: state}),
                )
                sample = action.sample()
                sample_log_prob = action.log_prob(sample)

                self._get_actions = symjax.function(
                    states, outputs=[samples, samples_log_prob])
                self._get_action = symjax.function(
                    state,
                    outputs=[sample[0], sample_log_prob[0]],
                )
            elif actions_distribution is None:
                actions = self.create_network(states)
                action = actions.clone({states: state})

                self._get_actions = symjax.function(states, outputs=actions)
                self._get_action = symjax.function(state, outputs=action[0])

            self._params = symjax.get_variables(
                trainable=None, scope=symjax.current_graph().scope_name)
        self.actions = actions
        self.state = state
        self.action = action
Example #29
def test_cond5():
    sj.current_graph().reset()
    v = T.ones((10, 10)) * 3
    W = T.Variable(1)
    u = T.Placeholder((), "int32")
    out = T.cond(
        u > 0,
        lambda a, u: a * u[0],
        lambda a, u: a + u[1],
        true_inputs=(
            W,
            v,
        ),
        false_inputs=(
            W,
            v,
        ),
    )
    f = sj.function(u, outputs=out, updates={W: W + 1})
    assert np.array_equal(f(1), 3 * np.ones(10))
    assert np.array_equal(f(0), 5 * np.ones(10))
    assert np.array_equal(f(1), 9 * np.ones(10))
Example #30
"""
We then demonstrate how to do a simple for loop and then a while loop.
"""

import matplotlib.pyplot as plt

import symjax
import symjax.tensor as T
import numpy as np


# suppose we are given a time series and we want to compute an
# exponential moving average; the EMA coefficient alpha is provided
# by the user

signal = T.Placeholder((512,), "float32", name="signal")
alpha = T.Placeholder((), "float32", "alpha")

# to use a scan function one needs a function to be applied at each step,
# in our case an exponential moving average update.
# This function should output the new value of the carry as well as an
# additional output; in our case the carry (the EMA) is also what we want to
# output at each time step


def fn(at, xt, alpha):
    # the function's first input is the carry; then come the (ordered)
    # values from sequences and non_sequences, similar to Theano
    EMA = at * alpha + (1 - alpha) * xt
    return EMA, EMA
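
# A minimal sketch, not part of the original snippet: one way `fn` could be
# wired into T.scan and compiled, following the `sequences`/`init` pattern of
# Example #7 above, with `alpha` captured by a closure rather than passed as a
# non-sequence. Which element of the returned pair holds the per-step outputs
# (versus the final carry) is an assumption here.
emas, _ = T.scan(lambda at, xt: fn(at, xt, alpha), sequences=[signal[1:]], init=signal[0])
compute_ema = symjax.function(signal, alpha, outputs=emas)
values = compute_ema(np.random.randn(512).astype("float32"), 0.9)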