Example #1
def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')

    # ENCODER
    enc = encoder(x, Ds[0])
    mu = enc[-1][:, :Ds[0]]
    logvar = enc[-1][:, Ds[0]:]
    var = T.exp(logvar)
 
    z = mu + T.exp(0.5 * logvar) * T.random.randn((batch_size, Ds[0]))
    z_ph = T.Placeholder((batch_size, Ds[0]), 'float32')

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)

    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    logvar_x = T.Variable(T.zeros(1), name='logvar_x') 
    var_x = T.exp(logvar_x)

    h, h_ph = [z], [z_ph]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
        h_ph.append(T.matmul(h_ph[-1], w.transpose()) + b)
        h_ph.append(h_ph[-1] * relu_mask(h_ph[-1], leakiness))

    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    h_ph.append(T.matmul(h_ph[-1], Ws[-1].transpose()) + bs[-1])

    # cov_W and cov_b are module-level prior variances for weights and biases
    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
            + sum([T.mean(v**2) for v in bs[:-1]], 0.) / cov_b
    # note: kl here is the *negative* KL divergence KL(q(z|x) || N(0, I)),
    # so px + kl is the ELBO and the loss below is its negation
    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    px = - 0.5 * (logvar_x + ((x - h[-1])**2 / var_x)).sum(1)
    loss = - (px + kl).mean() + prior

    variables = Ws + bs + sj.layers.get_variables(enc) + [logvar_x]
    opti = sj.optimizers.Adam(loss, lr, params=variables)

    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=h_ph[-1])
    params = sj.function(outputs=Ws + bs + [T.exp(logvar_x) * T.ones(Ds[-1])])
    get_varx = sj.function(outputs=var_x)


    output = {'train': train, 'g':g, 'params':params}
    output['model'] = 'VAE'
    output['varx'] = get_varx
    output['kwargs'] = {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler,
                    'prior': sj.function(outputs=prior)}
    def sample(n):
        samples = []
        for i in range(n // batch_size):
            samples.append(g(np.random.randn(batch_size, Ds[0])))
        return np.concatenate(samples)
    output['sample'] = sample
    return output
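A hedged usage sketch for the factory above. The `encoder`, `init_weights`, `relu_mask`, `cov_W` and `cov_b` helpers are assumed to live in the surrounding module, and `data` is a hypothetical `(N, Ds[-1])` float32 array:

# hypothetical usage -- the helper functions and `data` come from elsewhere
vae = create_vae(batch_size=64, Ds=[2, 32, 784], seed=0)

for epoch in range(10):
    for i in range(len(data) // 64):
        loss = vae['train'](data[i * 64:(i + 1) * 64])

samples = vae['sample'](256)  # 256 decoder samples, z drawn from the prior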
Example #2
def create_glo(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               GLO=False):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')
    z = T.Variable(T.random.randn((batch_size, Ds[0])))
    logvar_x = T.Variable(T.ones(1))

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    h = [z]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])

    # LOSS
    prior = sum([T.sum(w**2) for w in Ws], 0.) / cov_W\
            + sum([T.sum(v**2) for v in bs[:-1]], 0.) / cov_b
    if GLO:
        loss = T.sum((x - h[-1])**2) / batch_size + prior
        variables = Ws + bs
    else:
        loss = Ds[-1] * logvar_x.sum()\
               + T.sum((x - h[-1])**2 / T.exp(logvar_x)) / batch_size\
               + (z**2).sum() / batch_size + prior
        variables = Ws + bs

    # a second prior (over *all* biases this time) is added on top of `loss`,
    # which already contains the first prior in both branches above
    prior = sum([(b**2).sum() for b in bs], 0.) / cov_b\
            + sum([(w**2).sum() for w in Ws], 0.) / cov_W

    opti = sj.optimizers.Adam(loss + prior, lr, params=variables)
    infer = sj.optimizers.Adam(loss, lr, params=[z])

    estimate = sj.function(x, outputs=z, updates=infer.updates)
    train = sj.function(x, outputs=loss, updates=opti.updates)
    lossf = sj.function(x, outputs=loss)
    params = sj.function(outputs=Ws + bs + [T.ones(Ds[-1]) * T.exp(logvar_x)])

    output = {'train': train, 'estimate':estimate, 'params':params}
    output['reset'] = lambda v: z.assign(v)
    if GLO:
        output['model'] = 'GLO'
    else:
        output['model'] = 'HARD'
    output['loss'] = lossf
    output['kwargs'] = {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler}
    return output
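A similar hedged sketch for this factory (same assumptions about the surrounding helpers; `data` is hypothetical). `estimate` optimizes the latent codes `z` with the decoder frozen, while `train` updates the decoder weights:

import numpy as np

glo = create_glo(batch_size=64, Ds=[2, 32, 784], seed=0, GLO=True)

glo['reset'](np.random.randn(64, 2).astype('float32'))  # re-seed the latents
for step in range(100):
    glo['estimate'](data)      # inner loop: update z only
    loss = glo['train'](data)  # outer loop: update the decoder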
Example #3
def softmax(x, axis=-1):
    r"""Softmax function.

    Computes the function which rescales elements to the range :math:`[0, 1]`
    such that the elements along :code:`axis` sum to :math:`1`.

    .. math ::
      \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Args:
      axis: the axis or axes along which the softmax should be computed. The
        output sums to :math:`1` along these axes. Either an integer or a
        tuple of integers.
    """
    unnormalized = T.exp(x - T.stop_gradient(x.max(axis, keepdims=True)))
    return unnormalized / unnormalized.sum(axis, keepdims=True)
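The `T.stop_gradient(x.max(...))` shift is the usual overflow guard: subtracting the per-axis maximum leaves the softmax value unchanged while keeping every exponent non-positive. A minimal NumPy sketch of the same trick:

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])

naive = np.exp(x) / np.exp(x).sum()  # exp(1000) overflows -> [nan, nan, nan]

shifted = np.exp(x - x.max())        # exponents are now <= 0
stable = shifted / shifted.sum()     # -> [0.090, 0.245, 0.665]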
Example #4
def log_softmax(x, axis=-1):
    r"""Log-Softmax function.

    Computes the logarithm of the :code:`softmax` function, which rescales
    elements to the range :math:`[-\infty, 0)`.

    .. math ::
      \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
      \right)

    Args:
      axis: the axis or axes along which the :code:`log_softmax` should be
        computed. Either an integer or a tuple of integers.
    """
    shifted = x - T.stop_gradient(x.max(axis, keepdims=True))
    return shifted - T.log(T.sum(T.exp(shifted), axis, keepdims=True))
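Working on the shifted values directly, rather than taking `T.log` of the softmax, also avoids underflow: `log(softmax(x))` returns `-inf` for very negative logits, while the shifted form recovers the exact value. A small NumPy check:

import numpy as np

def log_softmax_np(x):
    shifted = x - x.max()
    return shifted - np.log(np.exp(shifted).sum())

x = np.array([-1000.0, 0.0])
print(np.log(np.exp(x) / np.exp(x).sum()))  # -> [-inf, 0.]
print(log_softmax_np(x))                    # -> [-1000., ~0.]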
Example #5
def generate_gaussian_filterbank(N, M, J, f0, f1, modes=1):

    # gaussian parameters
    freqs = get_scaled_freqs(f0, f1, J)
    freqs *= (J - 1) * 10

    if modes > 1:
        other_modes = np.random.randint(0, J, J * (modes - 1))
        freqs = np.concatenate([freqs, freqs[other_modes]])

    # create the mean vectors
    mu_init = np.stack([freqs, 0.1 * np.random.randn(J * modes)], 1)
    mu = T.Variable(mu_init.astype('float32'), name='mu')

    # create the covariance matrix
    cor = T.Variable(0.01 * np.random.randn(J * modes).astype('float32'),
                     name='cor')

    sigma_init = np.stack([freqs / 6, 1. + 0.01 * np.random.randn(J * modes)],
                          1)
    sigma = T.Variable(sigma_init.astype('float32'), name='sigma')

    # create the mixing coefficients
    mixing = T.Variable(np.ones((modes, 1, 1)).astype('float32'))

    # now apply our parametrization
    coeff = T.stop_gradient(T.sqrt((T.abs(sigma) + 0.1).prod(1))) * 0.95
    Id = T.eye(2)
    cov = Id * T.expand_dims((T.abs(sigma)+0.1),1) +\
            T.flip(Id, 0) * (T.tanh(cor) * coeff).reshape((-1, 1, 1))
    cov_inv = T.linalg.inv(cov)

    # get the gaussian filters
    time = T.linspace(-5, 5, M)
    freq = T.linspace(0, J * 10, N)
    x, y = T.meshgrid(time, freq)
    grid = T.stack([y.flatten(), x.flatten()], 1)
    centered = grid - T.expand_dims(mu, 1)
    # evaluate the (unnormalized) gaussians over the whole grid
    gaussian = T.exp(-(T.matmul(centered, cov_inv)**2).sum(-1))
    norm = T.linalg.norm(gaussian, 2, 1, keepdims=True)
    gaussian_2d = T.abs(mixing) * T.reshape(gaussian / norm, (J, modes, N, M))
    return gaussian_2d.sum(1, keepdims=True), mu, cor, sigma, mixing
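The `coeff`/`T.tanh(cor)` construction keeps each 2x2 covariance positive-definite, so the `T.linalg.inv` call is safe: for a symmetric matrix [[a, c], [c, b]] with a, b > 0, positive-definiteness is exactly c**2 < a*b, and here |c| is capped at 95% of sqrt(a*b). A standalone NumPy check of that bound:

import numpy as np

sigma = np.random.randn(1000, 2)
cor = 5 * np.random.randn(1000)

diag = np.abs(sigma) + 0.1                       # (a, b), both > 0
c = np.tanh(cor) * 0.95 * np.sqrt(diag.prod(1))  # off-diagonal entry

assert np.all(c**2 < diag.prod(1))               # always positive-definite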
Example #6
def create_fns(input,
               in_signs,
               Ds,
               x,
               m0,
               m1,
               m2,
               batch_in_signs,
               alpha=0.1,
               sigma=1,
               sigma_x=1,
               lr=0.0002):

    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])
    BS = batch_in_signs.shape[0]
    Ws = [
        T.Variable(sj.initializers.glorot((j, i)) * sigma)
        for j, i in zip(Ds[1:], Ds[:-1])
    ]
    bs = [T.Variable(sj.initializers.he((j,)) * sigma) for j in Ds[1:-1]]\
                + [T.Variable(T.zeros((Ds[-1],)))]

    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]

    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]

    batch_A_q = [T.eye(Ds[0]) * T.ones((BS, 1, 1))]
    batch_B_q = [T.zeros((BS, Ds[0]))]

    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]

    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1., alpha)
    batch_in_masks = T.where(
        T.concatenate([T.ones((BS, Ds[0])), batch_in_signs], 1) > 0, 1., alpha)

    for w, b in zip(Ws[:-1], bs[:-1]):

        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., alpha))

        maps.append(pre_activation * masks[-1])

    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1], cumulative_units[1:],
                                   Ws, bs, masks):

        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)

        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)

        batch_A_q.append(
            T.matmul(w * batch_in_masks[:, None, start:end], batch_A_q[-1]))
        batch_B_q.append((w * batch_in_masks[:, None, start:end]\
                            * batch_B_q[-1][:, None, :]).sum(2) + b)

    batch_B_q = batch_B_q[-1]
    batch_A_q = batch_A_q[-1]

    signs = T.concatenate(signs)

    inequalities = T.hstack(
        [T.concatenate(B_w[1:-1])[:, None],
         T.vstack(A_w[1:-1])]) * signs[:, None]

    inequalities_code = T.hstack(
        [T.concatenate(B_q[1:-1])[:, None],
         T.vstack(A_q[1:-1])]) * in_signs[:, None]

    #### loss
    log_sigma2 = T.Variable(sigma_x)
    sigma2 = T.exp(log_sigma2)

    Am1 = T.einsum('qds,nqs->nqd', batch_A_q, m1)
    Bm0 = T.einsum('qd,nq->nd', batch_B_q, m0)
    B2m0 = T.einsum('nq,qd->n', m0, batch_B_q**2)
    AAm2 = T.einsum('qds,qdu,nqup->nsp', batch_A_q, batch_A_q, m2)

    inner = -(x * (Am1.sum(1) + Bm0)).sum(1) + (Am1 * batch_B_q).sum((1, 2))

    loss_2 = (x**2).sum(1) + B2m0 + T.trace(AAm2, axis1=1, axis2=2).squeeze()

    loss_z = T.trace(m2.sum(1), axis1=1, axis2=2).squeeze()

    cst = 0.5 * (Ds[0] + Ds[-1]) * T.log(2 * np.pi)

    loss = cst + 0.5 * Ds[-1] * log_sigma2 + inner / sigma2\
            + 0.5 * loss_2 / sigma2 + 0.5 * loss_z

    mean_loss = loss.mean()
    opt = sj.optimizers.NesterovMomentum(mean_loss, Ws + bs, lr, 0.9)

    train_f = sj.function(batch_in_signs,
                          x,
                          m0,
                          m1,
                          m2,
                          outputs=mean_loss,
                          updates=opt.updates)
    f = sj.function(input,
                    outputs=[maps[-1], A_w[-1], B_w[-1], inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])
    return f, g, h, all_g, train_f, sigma2
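The `A_w`/`B_w` loop above composes the exact affine map a leaky-ReLU network computes on a fixed sign pattern: the masks enter on the input side of each weight matrix, so `A <- (w * m) @ A` and `B <- (w * m) @ B + b` track the pre-activations layer by layer. A standalone NumPy sketch of the same recursion, with hypothetical shapes:

import numpy as np

rng = np.random.RandomState(0)
Ds = [3, 5, 4, 2]
Ws = [rng.randn(o, i) for i, o in zip(Ds[:-1], Ds[1:])]
bs = [rng.randn(o) for o in Ds[1:]]
alpha = 0.1

x = rng.randn(Ds[0])
A, B = np.eye(Ds[0]), np.zeros(Ds[0])
h, mask = x, np.ones(Ds[0])               # mask of the current layer's input
for W, b in zip(Ws, bs):
    A = (W * mask) @ A                    # A <- W diag(mask) A
    B = (W * mask) @ B + b                # B <- W diag(mask) B + b
    pre = W @ h + b
    mask = np.where(pre > 0, 1.0, alpha)  # next layer's input mask
    h = pre * mask

# the last layer is linear, so its pre-activation is the network output
np.testing.assert_allclose(A @ x + B, pre, rtol=1e-5)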
Example #7
    def __init__(
        self,
        env,
        actor,
        critic,
        lr=1e-4,
        batch_size=32,
        n=1,
        clip_ratio=0.2,
        entcoeff=0.01,
        target_kl=0.01,
        train_pi_iters=4,
        train_v_iters=4,
    ):

        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_min = env.action_space.low
            action_max = env.action_space.high
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.continuous = continuous
        self.lr = lr
        self.train_pi_iters = train_pi_iters
        self.train_v_iters = train_v_iters
        self.clip_ratio = clip_ratio
        self.target_kl = target_kl
        self.extras = {"logprob": ()}
        self.entcoeff = entcoeff

        state_ph = T.Placeholder((batch_size, num_states), "float32")
        ret_ph = T.Placeholder((batch_size, ), "float32")
        adv_ph = T.Placeholder((batch_size, ), "float32")
        act_ph = T.Placeholder((batch_size, num_actions), "float32")

        with symjax.Scope("actor"):
            logits = actor(state_ph)
            if continuous:
                logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )
        with symjax.Scope("old_actor"):
            old_logits = actor(state_ph)
            if continuous:
                old_logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )

        if not continuous:
            pi = Categorical(logits=logits)
        else:
            pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

        actions = T.clip(pi.sample(), -2, 2)  # clipped action sample

        actor_params = symjax.get_variables(scope="/actor/")
        old_actor_params = symjax.get_variables(scope="/old_actor/")

        self.update_target = symjax.function(
            updates={o: a
                     for o, a in zip(old_actor_params, actor_params)})

        # PPO objectives
        # pi(a|s) / pi_old(a|s)
        pi_log_prob = pi.log_prob(act_ph)
        old_pi_log_prob = pi_log_prob.clone({
            logits: old_logits,
            logstd: old_logstd
        })

        ratio = T.exp(pi_log_prob - old_pi_log_prob)
        surr1 = ratio * adv_ph
        surr2 = adv_ph * T.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)

        pi_loss = -T.minimum(surr1, surr2).mean()
        # ent_loss = pi.entropy().mean() * self.entcoeff

        with symjax.Scope("critic"):
            v = critic(state_ph)
        # critic loss
        v_loss = ((ret_ph - v)**2).mean()

        # Info (useful to watch during learning)

        # a sample estimate for KL
        approx_kl = (old_pi_log_prob - pi_log_prob).mean()
        # a sample estimate for entropy
        # approx_ent = -logprob_given_actions.mean()
        # clipped = T.logical_or(
        #     ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
        # )
        # clipfrac = clipped.astype("float32").mean()

        with symjax.Scope("update_pi"):
            print(len(actor_params), "actor parameters")
            nn.optimizers.Adam(
                pi_loss,
                self.lr,
                params=actor_params,
            )
        with symjax.Scope("update_v"):
            critic_params = symjax.get_variables(scope="/critic/")
            print(len(critic_params), "critic parameters")

            nn.optimizers.Adam(
                v_loss,
                self.lr,
                params=critic_params,
            )

        self.get_params = symjax.function(outputs=critic_params)

        self.learn_pi = symjax.function(
            state_ph,
            act_ph,
            adv_ph,
            outputs=[pi_loss, approx_kl],
            updates=symjax.get_updates(scope="/update_pi/"),
        )
        self.learn_v = symjax.function(
            state_ph,
            ret_ph,
            outputs=v_loss,
            updates=symjax.get_updates(scope="/update_v/"),
        )

        single_state = T.Placeholder((1, num_states), "float32")
        single_v = v.clone({state_ph: single_state})
        single_sample = actions.clone({state_ph: single_state})

        self._act = symjax.function(single_state, outputs=single_sample)
        self._get_v = symjax.function(single_state, outputs=single_v)

        single_action = T.Placeholder((1, num_actions), "float32")
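A minimal sketch of the `clone()` substitution this constructor leans on: the same symbolic expression is re-evaluated with a placeholder swapped out, so `old_pi_log_prob` and the single-state functions share all parameters with the original graph by construction.

import symjax
import symjax.tensor as T

a = T.Placeholder((4,), "float32")
b = T.Placeholder((4,), "float32")

y = T.exp(a) + 1
y_on_b = y.clone({a: b})  # same expression, evaluated on b instead of a

f = symjax.function(a, b, outputs=[y, y_on_b])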
Example #8
    def __init__(
        self,
        state_dim,
        action_dim,
        lr,
        gamma,
        K_epochs,
        eps_clip,
        actor,
        critic,
        batch_size,
        continuous=True,
    ):
        self.lr = lr
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        state = T.Placeholder((batch_size, ) + state_dim, "float32")

        reward = T.Placeholder((batch_size, ), "float32")
        old_action_logprobs = T.Placeholder((batch_size, ), "float32")

        logits = actor(state)

        if not continuous:
            given_action = T.Placeholder((batch_size, ), "int32")
            dist = Categorical(logits=logits)
        else:
            mean = T.tanh(logits[:, :logits.shape[1] // 2])
            std = T.exp(logits[:, logits.shape[1] // 2:])
            given_action = T.Placeholder((batch_size, action_dim), "float32")
            dist = MultivariateNormal(mean=mean, diag_std=std)

        sample = dist.sample()
        sample_logprobs = dist.log_prob(sample)

        self._act = symjax.function(state, outputs=[sample, sample_logprobs])

        given_action_logprobs = dist.log_prob(given_action)

        # Finding the ratio (pi_theta / pi_theta__old); note that the clip
        # below is one-sided (upper bound only)
        ratios = T.exp(sample_logprobs - old_action_logprobs)
        ratios = T.clip(ratios, None, 1 + self.eps_clip)

        state_value = critic(state)
        advantages = reward - T.stop_gradient(state_value)

        loss = (-T.mean(ratios * advantages) + 0.5 * T.mean(
            (state_value - reward)**2) - 0.0 * dist.entropy().mean())

        print(loss)
        nn.optimizers.Adam(loss, self.lr)

        self.learn = symjax.function(
            state,
            given_action,
            reward,
            old_action_logprobs,
            outputs=T.mean(loss),
            updates=symjax.get_updates(),
        )
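For comparison with the one-sided clip above, a hedged NumPy sketch of the standard two-sided clipped surrogate from the PPO paper (https://arxiv.org/abs/1707.06347):

import numpy as np

def ppo_clip_surrogate(ratio, advantage, eps=0.2):
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1 - eps, 1 + eps) * advantage
    # the elementwise minimum removes any incentive to push the ratio
    # outside [1 - eps, 1 + eps] in either direction
    return np.minimum(unclipped, clipped).mean()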
Example #9
import sys

sys.path.insert(0, "../")
import symjax
import symjax.tensor as T

# create our variable to be optimized
mu = T.Variable(T.random.normal((1,), seed=1))
cost = T.exp(-((mu - 1) ** 2))
lr = symjax.schedules.PiecewiseConstant(0.01, {100: 0.003, 150: 0.001})
opt = symjax.optimizers.Adam(cost, lr, params=[mu])
print(opt.updates)
f = symjax.function(outputs=cost, updates=opt.updates)

for k in range(4):
    for i in range(10):
        print(f())
    print("done")
    for v in opt.variables + [mu]:
        v.reset()
    lr.reset()

# 0.008471076
# 0.008201109
# 0.007946267
# 0.007705368
# 0.0074773384
# 0.007261208
# 0.0070561105
# 0.006861261
# 0.006675923
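A plain-Python rendering of the learning-rate schedule used above, assuming the usual boundary semantics (the value switches at each listed step); this is a hypothetical re-implementation for illustration only:

def piecewise_constant(step, value, boundaries):
    # hypothetical re-implementation, not the symjax class
    for boundary, new_value in sorted(boundaries.items()):
        if step >= boundary:
            value = new_value
    return value

assert piecewise_constant(0, 0.01, {100: 0.003, 150: 0.001}) == 0.01
assert piecewise_constant(120, 0.01, {100: 0.003, 150: 0.001}) == 0.003
assert piecewise_constant(200, 0.01, {100: 0.003, 150: 0.001}) == 0.001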
Example #10
import sys

sys.path.insert(0, "../")

import symjax
import symjax.tensor as T
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

matplotlib.use("Agg")

###### DERIVATIVE OF GAUSSIAN EXAMPLE

t = T.Placeholder((1000, ), "float32")
print(t)
f = T.exp(-(t**2))
u = f.sum()
g = symjax.gradients(u, [t])
g2 = symjax.gradients(g[0].sum(), [t])
g3 = symjax.gradients(g2[0].sum(), [t])

dog = symjax.function(t, outputs=[g[0], g2[0], g3[0]])

plt.plot(np.array(dog(np.linspace(-10, 10, 1000))).T)

###### GRADIENT DESCENT
z = T.Variable(3.0)
loss = z**2
g_z = symjax.gradients(loss, [z])
print(loss, z)
train = symjax.function(outputs=[loss, z], updates={z: z - 0.1 * g_z[0]})
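A short usage sketch for the update above, assuming (as the decreasing costs in the previous example suggest) that outputs are evaluated before the update is applied; each call performs z <- z - 0.1 * 2z, so z shrinks by a factor of 0.8:

for i in range(5):
    loss_value, z_value = train()
    print(loss_value, z_value)  # z: 3.0, 2.4, 1.92, 1.536, 1.2288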
Example #11
    def __init__(
        self,
        state_shape,
        actions_shape,
        batch_size,
        actor,
        critic,
        lr=1e-3,
        K_epochs=80,
        eps_clip=0.2,
        gamma=0.99,
        entropy_beta=0.01,
    ):
        self.actor = actor
        self.critic = critic
        self.gamma = gamma
        self.lr = lr
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        states = T.Placeholder((batch_size, ) + state_shape,
                               "float32",
                               name="states")
        actions = T.Placeholder((batch_size, ) + actions_shape,
                                "float32",
                                name="actions")
        rewards = T.Placeholder((batch_size, ),
                                "float32",
                                name="discounted_rewards")
        advantages = T.Placeholder((batch_size, ),
                                   "float32",
                                   name="advantages")

        self.target_actor = actor(states, distribution="gaussian")
        self.actor = actor(states, distribution="gaussian")
        self.critic = critic(states)

        # Finding the ratio (pi_theta / pi_theta__old) and
        # surrogate Loss https://arxiv.org/pdf/1707.06347.pdf
        with symjax.Scope("policy_loss"):
            ratios = T.exp(
                self.actor.actions.log_prob(actions) -
                self.target_actor.actions.log_prob(actions))
            ratios = T.clip(ratios, 0, 10)
            clipped_ratios = T.clip(ratios, 1 - self.eps_clip,
                                    1 + self.eps_clip)

            surr1 = advantages * ratios
            surr2 = advantages * clipped_ratios

            actor_loss = -(T.minimum(surr1, surr2)).mean()

        with symjax.Scope("monitor"):
            clipfrac = (((ratios > (1 + self.eps_clip)) |
                         (ratios <
                          (1 - self.eps_clip))).astype("float32").mean())
            approx_kl = (self.target_actor.actions.log_prob(actions) -
                         self.actor.actions.log_prob(actions)).mean()

        with symjax.Scope("critic_loss"):
            critic_loss = T.mean((rewards - self.critic.q_values)**2)

        with symjax.Scope("entropy"):
            entropy = self.actor.actions.entropy().mean()

        loss = actor_loss + critic_loss  # - entropy_beta * entropy

        with symjax.Scope("optimizer"):
            nn.optimizers.Adam(
                loss,
                lr,
                params=self.actor.params(True) + self.critic.params(True),
            )

        # create the update function
        self._train = symjax.function(
            states,
            actions,
            rewards,
            advantages,
            outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
            updates=symjax.get_updates(scope="*optimizer"),
        )

        # initialize target as current
        self.update_target(1)
Example #12
def test_pymc():
    class RandomVariable(symjax.tensor.Variable):
        def __init__(self, name, shape, observed):
            if observed is None:
                super().__init__(np.zeros(shape), name=name)
            else:
                super().__init__(observed, name=name, trainable=False)

        def logp(self, value):
            raise NotImplementedError()

        def random(self, sample_shape):
            raise NotImplementedError()

        @property
        def logpt(self):
            return self.logp(self)

    class Normal(RandomVariable):
        def __init__(self, name, mu, sigma, shape=None, observed=None):
            self.mu = mu
            self.sigma = sigma
            super().__init__(name, shape, observed)

        def logp(self, value):
            tau = self.sigma**-2.0
            return (-tau *
                    (value - self.mu)**2 + tt.log(tau / np.pi / 2.0)) / 2.0

        def random(self, sample_shape):
            return np.random.randn(sample_shape) * self.sigma + self.mu

    x = Normal("x", 0, 10.0)
    s = Normal("s", 0.0, 5.0)
    y = Normal("y", x, tt.exp(s))

    assert symjax.current_graph().get(y) == 0.0

    #################
    model_logpt = x.logpt + s.logpt + y.logpt

    f = symjax.function(x, s, y, outputs=model_logpt)

    normal_loglike = jsp.stats.norm.logpdf

    def f_(x, s, y):
        return (normal_loglike(x, 0.0, 10.0) + normal_loglike(s, 0.0, 5.0) +
                normal_loglike(y, x, jnp.exp(s)))

    for i in range(10):
        x_val = np.random.randn() * 10.0
        s_val = np.random.randn() * 5.0
        y_val = np.random.randn() * 0.1 + x_val

        np.testing.assert_allclose(f(x_val, s_val, y_val),
                                   f_(x_val, s_val, y_val),
                                   rtol=1e-06)

    model_dlogpt = symjax.gradients(model_logpt, [x, s, y])

    f_with_grad = symjax.function(x, s, y, outputs=[model_logpt, model_dlogpt])

    f_with_grad(x_val, s_val, y_val)

    grad_fn = jax.grad(f_, argnums=[0, 1, 2])
    f_(x_val, s_val, y_val), grad_fn(x_val, s_val, y_val)