Example #1
def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')

    # ENCODER
    enc = encoder(x, Ds[0])
    mu = enc[-1][:, :Ds[0]]
    logvar = enc[-1][:, Ds[0]:]
    var = T.exp(logvar)
 
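    # reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)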
    z = mu + T.exp(0.5 * logvar) * T.random.randn((batch_size, Ds[0]))
    z_ph = T.Placeholder((batch_size, Ds[0]), 'float32')

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)

    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    logvar_x = T.Variable(T.zeros(1), name='logvar_x') 
    var_x = T.exp(logvar_x)

    h, h_ph = [z], [z_ph]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
        h_ph.append(T.matmul(h_ph[-1], w.transpose()) + b)
        h_ph.append(h_ph[-1] * relu_mask(h_ph[-1], leakiness))

    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    h_ph.append(T.matmul(h_ph[-1], Ws[-1].transpose()) + bs[-1])

    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
            + sum([T.mean(v**2) for v in bs[:-1]], 0.) / cov_b 
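    # `kl` below is the negative KL(q(z|x) || N(0, I)) and `px` the Gaussian
    # reconstruction log-likelihood (up to a constant), so `loss` is the
    # negative ELBO plus the weight prior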
    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    px = - 0.5 * (logvar_x + ((x - h[-1])**2 / var_x)).sum(1)
    loss = - (px + kl).mean() + prior

    variables = Ws + bs + sj.layers.get_variables(enc) + [logvar_x]
    opti = sj.optimizers.Adam(loss, lr, params=variables)

    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=h_ph[-1])
    params = sj.function(outputs=Ws + bs + [T.exp(logvar_x) * T.ones(Ds[-1])])
    get_varx = sj.function(outputs=var_x)


    output = {'train': train, 'g':g, 'params':params}
    output['model'] = 'VAE'
    output['varx'] = get_varx
    output['kwargs'] = {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler,
                    'prior': sj.function(outputs=prior)}
    def sample(n):
        samples = []
        for i in range(n // batch_size):
            samples.append(g(np.random.randn(batch_size, Ds[0])))
        return np.concatenate(samples)
    output['sample'] = sample
    return output
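A minimal usage sketch for the factory above (hedged: `encoder`, `init_weights`, `relu_mask`, `cov_W` and `cov_b` are helpers from the surrounding module and are assumed importable; the dimensions and the `data` array are purely illustrative):

import numpy as np

# hypothetical setup: latent dim 2, one hidden layer of 64 units, 784-dim observations
vae = create_vae(batch_size=32, Ds=[2, 64, 784], seed=0)
for step in range(1000):
    batch = data[np.random.choice(len(data), 32)].astype('float32')  # (32, 784)
    nll = vae['train'](batch)       # one Adam step, returns the training loss
samples = vae['sample'](128)        # (128, 784) array of generated samples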
Example #2
def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
    """Normalizes an array by subtracting mean and dividing by sqrt(var)."""
    if mean is None:
        mean = T.mean(x, axis, keepdims=True)
    if variance is None:
        # this definition is traditionally seen as less accurate than jnp.var's
        # mean((x - mean(x))**2) but may be faster and even, given typical
        # activation distributions and low-precision arithmetic, more accurate
        # when used in neural network normalization layers
        variance = T.mean(T.square(x), axis, keepdims=True) - T.square(mean)
    return (x - mean) * T.rsqrt(variance + epsilon)
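A small usage sketch for `normalize` (hedged: it assumes `T` is `symjax.tensor` and that the graph is compiled with `symjax.function`, as in the other examples on this page; shapes are illustrative):

import numpy as np
import symjax
import symjax.tensor as T

x = T.Placeholder((32, 128), 'float32')
y = normalize(x, axis=-1)            # zero-mean, unit-variance rows
f = symjax.function(x, outputs=y)
out = f(np.random.randn(32, 128).astype('float32'))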
Example #3
    def forward(self, input, deterministic=None):

        if deterministic is None:
            deterministic = self.deterministic
        dirac = T.cast(deterministic, 'float32')
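        # dirac = 1 (deterministic / test time) selects the moving averages,
        # dirac = 0 (training) selects the batch statistics computed below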

        self.mean = T.mean(input, self.axis, keepdims=True)
        self.var = T.var(input, self.axis, keepdims=True)
        if len(self.updates.keys()) == 0:
            self.avgmean, upm, step = T.ExponentialMovingAverage(
                self.mean, self.beta1)
            self.avgvar, upv, step = T.ExponentialMovingAverage(
                self.var,
                self.beta2,
                step=step,
                init=numpy.ones(self.var.shape).astype('float32'))
            self.add_variable(self.avgmean)
            self.add_variable(self.avgvar)
            self.add_update(upm)
            self.add_update(upv)

        self.usemean = self.mean * (1 - dirac) + self.avgmean * dirac
        self.usevar = self.var * (1 - dirac) + self.avgvar * dirac
        return self.W * (input - self.usemean) / \
            (T.sqrt(self.usevar) + self.const) + self.b
Example #4
    def build_net(self, Q):
        # ------------------ all inputs ------------------------
        state = T.Placeholder([self.batch_size, self.n_states],
                              "float32",
                              name="s")
        next_state = T.Placeholder([self.batch_size, self.n_states],
                                   "float32",
                                   name="s_")
        reward = T.Placeholder(
            [
                self.batch_size,
            ],
            "float32",
            name="r",
        )  # input reward
        action = T.Placeholder(
            [
                self.batch_size,
            ],
            "int32",
            name="a",
        )  # input Action

        with symjax.Scope("eval_net"):
            q_eval = Q(state, self.n_actions)
        with symjax.Scope("test_set"):
            q_next = Q(next_state, self.n_actions)

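        # Bellman target r + gamma * max_a' Q(s', a'); stop_gradient keeps it
        # constant with respect to the eval-net parameters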
        q_target = reward + self.reward_decay * q_next.max(1)
        q_target = T.stop_gradient(q_target)

        a_indices = T.stack([T.range(self.batch_size), action], axis=1)
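        # select Q(s, a) of the taken action for every sample in the batch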
        q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)),
                                         1).squeeze(1)
        loss = T.mean((q_target - q_eval_wrt_a)**2)
        nn.optimizers.Adam(loss, self.lr)

        self.train = symjax.function(state,
                                     action,
                                     reward,
                                     next_state,
                                     updates=symjax.get_updates())
        self.q_eval = symjax.function(state, outputs=q_eval)
Example #5
    def forward(
        self,
        input,
        axis,
        deterministic,
        const=1e-4,
        beta1=0.9,
        beta2=0.9,
        W=T.ones,
        b=T.zeros,
        trainable_W=True,
        trainable_b=True,
    ):

        self.beta1 = beta1
        self.beta2 = beta2
        self.const = const
        self.axis = axis
        self.deterministic = deterministic

        parameter_shape = [
            input.shape[i] if i in axis else 1 for i in range(input.ndim)
        ]
        reduce_axes = [i for i in range(input.ndim) if i not in axis]

        self.create_variable("W", W, parameter_shape, trainable=trainable_W)
        self.create_variable("b", b, parameter_shape, trainable=trainable_b)

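        # batch statistics over the non-parameter axes; the moving averages
        # below replace them whenever `deterministic` is true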
        input_mean = T.mean(input, reduce_axes, keepdims=True)
        input_inv_std = 1 / (T.std(input, reduce_axes, keepdims=True) + const)

        self.avg_mean = schedules.ExponentialMovingAverage(input_mean,
                                                           beta1)[1]
        self.avg_inv_std = schedules.ExponentialMovingAverage(
            input_inv_std, beta2)[1]

        use_mean = T.where(deterministic, self.avg_mean, input_mean)
        use_inv_std = T.where(deterministic, self.avg_inv_std, input_inv_std)
        W = self.W or 1.0
        b = self.b if self.b is not None else 0.0
        return W * (input - use_mean) * use_inv_std + b
Example #6
def fn(window):
    # the first input of the function is the current index of the for loop;
    # the remaining inputs are the (ordered) sequences and non_sequences
    # values

    return T.mean(window)
Example #7
def create_fns(batch_size, R, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               var_x=1):

    alpha = T.Placeholder((1,), 'float32')
    x = T.Placeholder((Ds[0],), 'float32')
    X = T.Placeholder((batch_size, Ds[-1]), 'float32')

    signs = T.Placeholder((np.sum(Ds[1:-1]),), 'float32')
    SIGNS = T.Placeholder((R, np.sum(Ds[1:-1])), 'float32')
    
    m0 = T.Placeholder((batch_size, R), 'float32')
    m1 = T.Placeholder((batch_size, R, Ds[0]), 'float32')
    m2 = T.Placeholder((batch_size, R, Ds[0], Ds[0]), 'float32')

    Ws, vs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w, name='W' + str(l)) for l, w in enumerate(Ws)]
    vs = [T.Variable(v, name='v' + str(l)) for l, v in enumerate(vs)]

    var_x = T.Variable(T.ones(Ds[-1]) * var_x)
    var_z = T.Variable(T.ones(Ds[0]))

    # create the placeholders
    Ws_ph = [T.Placeholder(w.shape, w.dtype) for w in Ws]
    vs_ph = [T.Placeholder(v.shape, v.dtype) for v in vs]
    var_x_ph = T.Placeholder(var_x.shape, var_x.dtype)

    ############################################################################
    # Compute the output of g(x)
    ############################################################################

    maps = [x]
    xsigns = []
    masks = []
    
    for w, v in zip(Ws[:-1], vs[:-1]):
        
        pre_activation = T.matmul(w, maps[-1]) + v
        xsigns.append(T.sign(pre_activation))
        masks.append(relu_mask(pre_activation, leakiness))
        maps.append(pre_activation * masks[-1])

    xsigns = T.concatenate(xsigns)
    maps.append(T.matmul(Ws[-1], maps[-1]) + vs[-1])

    ############################################################################
    # compute the masks and then the per layer affine mappings
    ############################################################################

    cumulative_units = np.cumsum([0] + Ds[1:])
    xqs = relu_mask([xsigns[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    qs = relu_mask([signs[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    Qs = relu_mask([SIGNS[:, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)

    Axs, bxs = get_Abs(Ws, vs, xqs)
    Aqs, bqs = get_Abs(Ws, vs, qs)
    AQs, bQs = get_Abs(Ws, vs, Qs)

    all_bxs = T.hstack(bxs[:-1]).transpose()
    all_Axs = T.hstack(Axs[:-1])[0]

    all_bqs = T.hstack(bqs[:-1]).transpose()
    all_Aqs = T.hstack(Aqs[:-1])[0]

    x_inequalities = T.hstack([all_Axs, all_bxs]) * xsigns[:, None]
    q_inequalities = T.hstack([all_Aqs, all_bqs]) * signs[:, None]

    ############################################################################
    # loss (E-step NLL)
    ############################################################################

    Bm0 = T.einsum('nd,Nn->Nd', bQs[-1], m0)
    B2m0 = T.einsum('nd,Nn->Nd', bQs[-1] ** 2, m0)
    Am1 = T.einsum('nds,Nns->Nd', AQs[-1], m1)
    ABm1 = T.einsum('nds,nd,Nns->Nd', AQs[-1], bQs[-1], m1)
    Am2ATdiag = T.diagonal(T.einsum('nds,Nnsc,npc->Ndp', AQs[-1], m2, AQs[-1]),
                        axis1=1, axis2=2)
    xAm1Bm0 = X * (Am1 + Bm0)

    M2diag = T.diagonal(m2.sum(1), axis1=1, axis2=2)
    
    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
            + sum([T.mean(v**2) for v in vs[:-1]], 0.) / cov_b
    loss = - 0.5 * (T.log(var_x).sum() + T.log(var_z).sum()\
            + (M2diag / var_z).sum(1).mean() + ((X ** 2 - 2 * xAm1Bm0 + B2m0\
            + Am2ATdiag + 2 * ABm1) / var_x).sum(1).mean())

    mean_loss = - (loss + 0.5 * prior)
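    # despite the variable name, a plain SGD optimizer is used here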
    adam = sj.optimizers.SGD(mean_loss, 0.001, params=Ws + vs)

    ############################################################################
    # update of var_x
    ############################################################################

    update_varx = (X ** 2 - 2 * xAm1Bm0 + B2m0 + Am2ATdiag + 2 * ABm1).mean()\
                    * T.ones(Ds[-1])
    update_varz = M2diag.mean() * T.ones(Ds[0])

    ############################################################################
    # update for the biases (derived assuming an isotropic covariance matrix)
    ############################################################################

    FQ = get_forward(Ws, Qs)
    update_vs = {}
    for i in range(len(vs)):
        
        if i < len(vs) - 1:
            # now we forward each bias to the x-space except the ith
            separated_bs = bQs[-1] - T.einsum('nds,s->nd', FQ[i], vs[i])
            # compute the residual and apply sigma
            residual = (X[:, None, :] - separated_bs) * m0[:, :, None]\
                                         - T.einsum('nds,Nns->Nnd', AQs[-1], m1)
            back_error = T.einsum('nds,nd->s', FQ[i], residual.mean(0))
            whiten = T.einsum('ndc,nds,n->cs', FQ[i] , FQ[i], m0.mean(0))\
                        + T.eye(back_error.shape[0]) / (Ds[i] * cov_b)
            update_vs[vs[i]] = T.linalg.solve(whiten, back_error)
        else:
            back_error = (X - (Am1 + Bm0) + vs[-1])
            update_vs[vs[i]] = back_error.mean(0)

    ############################################################################
    # update for the slopes (derived assuming an isotropic covariance matrix)
    ############################################################################

    update_Ws = {}
    for i in range(len(Ws)):
        
        U = T.einsum('nds,ndc->nsc', FQ[i], FQ[i])
        if i == 0:
            V = m2.mean(0)
        else:
            V1 = T.einsum('nd,nq,Nn->ndq', bQs[i-1], bQs[i-1], m0)
            V2 = T.einsum('nds,nqc,Nnsc->ndq', AQs[i-1], AQs[i-1], m2)
            V3 = T.einsum('nds,nq,Nns->ndq', AQs[i-1], bQs[i-1], m1)
            Q = T.einsum('nd,nq->ndq', Qs[i - 1], Qs[i - 1])
            V = Q * (V1 + V2 + V3 + V3.transpose((0, 2, 1))) / batch_size

        whiten = T.stack([T.kron(U[n], V[n]) for n in range(V.shape[0])]).sum(0)
        whiten = whiten + T.eye(whiten.shape[-1]) / (Ds[i]*Ds[i+1]*cov_W)
        # compute the residual (bottom up)
        if i == len(Ws) - 1:
            bottom_up = (X[:, None, :] - vs[-1])
        else:
            if i == 0:
                residual = (X[:, None, :] - bQs[-1])
            else:
                residual = (X[:, None, :] - bQs[-1]\
                            + T.einsum('nds,ns->nd', FQ[i - 1], bQs[i-1]))
            bottom_up = T.einsum('ndc,Nnd->Nnc', FQ[i], residual)

        # compute the top down vector
        if i == 0:
            top_down = m1
        else:
            top_down = Qs[i - 1] * (T.einsum('nds,Nns->Nnd', AQs[i - 1], m1) +\
                               T.einsum('nd,Nn->Nnd', bQs[i - 1], m0))

        vector = T.einsum('Nnc,Nns->cs', bottom_up, top_down) / batch_size
        condition = T.diagonal(whiten)
        update_W = T.linalg.solve(whiten, vector.reshape(-1)).reshape(Ws[i].shape)
        update_Ws[Ws[i]] = update_W

    ############################################################################
    # create the io functions
    ############################################################################

    params = sj.function(outputs=Ws + vs + [var_x])
    ll = T.Placeholder((), 'int32')
    selector = T.one_hot(ll, len(vs))
    for i in range(len(vs)):
        update_vs[vs[i]] = ((1 - alpha) * vs[i] + alpha * update_vs[vs[i]])\
                            * selector[i] + vs[i] * (1 - selector[i])
    for i in range(len(Ws)):
        update_Ws[Ws[i]] = ((1 - alpha) * Ws[i] + alpha * update_Ws[Ws[i]])\
                            * selector[i] + Ws[i] * (1 - selector[i])

    output = {'train':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                  updates=adam.updates),
              'update_var':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                        updates = {var_x: update_varx}),
              'update_vs':sj.function(alpha, ll, SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                      updates = update_vs),
              'loss':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'update_Ws':sj.function(alpha, ll, SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                      updates = update_Ws),
              'signs2Ab': sj.function(signs, outputs=[Aqs[-1][0], bqs[-1][0]]),
              'signs2ineq': sj.function(signs, outputs=q_inequalities),
              'g': sj.function(x, outputs=maps[-1]),
              'input2all': sj.function(x, outputs=[maps[-1], Axs[-1][0],
                                       bxs[-1][0], x_inequalities, xsigns]),
              'get_nll': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'assign': sj.function(*Ws_ph, *vs_ph, var_x_ph,
                                    updates=dict(zip(Ws + vs + [var_x],
                                             Ws_ph + vs_ph + [var_x_ph]))),
              'varx': sj.function(outputs=var_x),
              'prior': sj.function(outputs=prior),
              'varz': sj.function(outputs=var_z),
              'params': params,
#              'probed' : sj.function(SIGNS, X, m0, m1, m2, outputs=probed),
              'input2signs': sj.function(x, outputs=xsigns),
              'S' : Ds[0], 'D':  Ds[-1], 'R': R, 'model': 'EM', 'L':len(Ds)-1,
              'kwargs': {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler}}
 
    def sample(n):
        samples = []
        for i in range(n):
            samples.append(output['g'](np.random.randn(Ds[0])))
        return np.array(samples)
    
    output['sample'] = sample

    return output
Example #8
    def __init__(
        self,
        state_dim,
        action_dim,
        lr,
        gamma,
        K_epochs,
        eps_clip,
        actor,
        critic,
        batch_size,
        continuous=True,
    ):
        self.lr = lr
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        state = T.Placeholder((batch_size, ) + state_dim, "float32")

        reward = T.Placeholder((batch_size, ), "float32")
        old_action_logprobs = T.Placeholder((batch_size, ), "float32")

        logits = actor(state)

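        # discrete actions: categorical over the logits; continuous actions:
        # the head is split into a tanh-squashed mean and a log-std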
        if not continuous:
            given_action = T.Placeholder((batch_size, ), "int32")
            dist = Categorical(logits=logits)
        else:
            mean = T.tanh(logits[:, :logits.shape[1] // 2])
            std = T.exp(logits[:, logits.shape[1] // 2:])
            given_action = T.Placeholder((batch_size, action_dim), "float32")
            dist = MultivariateNormal(mean=mean, diag_std=std)

        sample = dist.sample()
        sample_logprobs = dist.log_prob(sample)

        self._act = symjax.function(state, outputs=[sample, sample_logprobs])

        given_action_logprobs = dist.log_prob(given_action)

        # Finding the ratio (pi_theta / pi_theta__old) for the stored action:
        ratios = T.exp(given_action_logprobs - old_action_logprobs)
        ratios = T.clip(ratios, None, 1 + self.eps_clip)

        state_value = critic(state)
        advantages = reward - T.stop_gradient(state_value)

        loss = (-T.mean(ratios * advantages) + 0.5 * T.mean(
            (state_value - reward)**2) - 0.0 * dist.entropy().mean())

        print(loss)
        nn.optimizers.Adam(loss, self.lr)

        self.learn = symjax.function(
            state,
            given_action,
            reward,
            old_action_logprobs,
            outputs=T.mean(loss),
            updates=symjax.get_updates(),
        )
Example #9
    def __init__(
        self,
        env_fn,
        actor,
        critic,
        gamma=0.99,
        tau=0.01,
        lr=1e-3,
        batch_size=32,
        epsilon=0.1,
        epsilon_decay=1 / 1000,
        min_epsilon=0.01,
        reward=None,
    ):

        # comment out this line if you don't want to record a video of the agent
        # if save_folder is not None:
        # test_env = gym.wrappers.Monitor(test_env)

        # build the environment from the provided constructor
        env = env_fn()

        # get size of state space and action space
        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_max = env.action_space.high[0]
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.gamma = gamma
        self.continuous = continuous
        self.observ_min = np.clip(env.observation_space.low, -20, 20)
        self.observ_max = np.clip(env.observation_space.high, -20, 20)
        self.env = env
        self.reward = reward

        # state
        state = T.Placeholder((batch_size, num_states), "float32")
        gradients = T.Placeholder((batch_size, num_actions), "float32")
        action = T.Placeholder((batch_size, num_actions), "float32")
        target = T.Placeholder((batch_size, 1), "float32")

        with symjax.Scope("actor_critic"):

            scaled_out = action_max * actor(state)
            Q = critic(state, action)

        a_loss = -T.sum(gradients * scaled_out)
        q_loss = T.mean((Q - target)**2)

        nn.optimizers.Adam(a_loss + q_loss, lr)

        self.update = symjax.function(
            state,
            action,
            target,
            gradients,
            outputs=[a_loss, q_loss],
            updates=symjax.get_updates(),
        )
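        # dQ/da, fed back through the `gradients` placeholder to drive the
        # actor update (deterministic policy gradient)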
        g = symjax.gradients(T.mean(Q), [action])[0]
        self.get_gradients = symjax.function(state, action, outputs=g)

        # also create the target variants
        with symjax.Scope("actor_critic_target"):
            scaled_out_target = action_max * actor(state)
            Q_target = critic(state, action)

        self.actor_predict = symjax.function(state, outputs=scaled_out)
        self.actor_predict_target = symjax.function(state,
                                                    outputs=scaled_out_target)
        self.critic_predict = symjax.function(state, action, outputs=Q)
        self.critic_predict_target = symjax.function(state,
                                                     action,
                                                     outputs=Q_target)

        t_params = symjax.get_variables(scope="/actor_critic_target/*")
        params = symjax.get_variables(scope="/actor_critic/*")
        replacement = {
            t: tau * e + (1 - tau) * t
            for t, e in zip(t_params, params)
        }
        self.update_target = symjax.function(updates=replacement)

        single_state = T.Placeholder((1, num_states), "float32")
        if not continuous:
            # pick the greedy discrete action from the actor's output
            scaled_out = scaled_out.argmax(-1)

        self.act = symjax.function(single_state,
                                   outputs=scaled_out.clone(
                                       {state: single_state})[0])
Example #10
    def __init__(
        self,
        state_shape,
        actions_shape,
        batch_size,
        actor,
        critic,
        lr=1e-3,
        K_epochs=80,
        eps_clip=0.2,
        gamma=0.99,
        entropy_beta=0.01,
    ):
        self.actor = actor
        self.critic = critic
        self.gamma = gamma
        self.lr = lr
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        states = T.Placeholder((batch_size, ) + state_shape,
                               "float32",
                               name="states")
        actions = T.Placeholder((batch_size, ) + actions_shape,
                                "float32",
                                name="actions")
        rewards = T.Placeholder((batch_size, ),
                                "float32",
                                name="discounted_rewards")
        advantages = T.Placeholder((batch_size, ),
                                   "float32",
                                   name="advantages")

        self.target_actor = actor(states, distribution="gaussian")
        self.actor = actor(states, distribution="gaussian")
        self.critic = critic(states)

        # Finding the ratio (pi_theta / pi_theta__old) and
        # surrogate Loss https://arxiv.org/pdf/1707.06347.pdf
        with symjax.Scope("policy_loss"):
            ratios = T.exp(
                self.actor.actions.log_prob(actions) -
                self.target_actor.actions.log_prob(actions))
            ratios = T.clip(ratios, 0, 10)
            clipped_ratios = T.clip(ratios, 1 - self.eps_clip,
                                    1 + self.eps_clip)

            surr1 = advantages * ratios
            surr2 = advantages * clipped_ratios

            actor_loss = -(T.minimum(surr1, surr2)).mean()

        with symjax.Scope("monitor"):
            clipfrac = (((ratios > (1 + self.eps_clip)) |
                         (ratios <
                          (1 - self.eps_clip))).astype("float32").mean())
            approx_kl = (self.target_actor.actions.log_prob(actions) -
                         self.actor.actions.log_prob(actions)).mean()

        with symjax.Scope("critic_loss"):
            critic_loss = T.mean((rewards - self.critic.q_values)**2)

        with symjax.Scope("entropy"):
            entropy = self.actor.actions.entropy().mean()

        loss = actor_loss + critic_loss  # - entropy_beta * entropy

        with symjax.Scope("optimizer"):
            nn.optimizers.Adam(
                loss,
                lr,
                params=self.actor.params(True) + self.critic.params(True),
            )

        # create the update function
        self._train = symjax.function(
            states,
            actions,
            rewards,
            advantages,
            outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
            updates=symjax.get_updates(scope="*optimizer"),
        )

        # initialize target as current
        self.update_target(1)