Example #1
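# Benchmark helper: fits a linear model x.W + b by mean squared error, attaches
# the requested optimizer ("SGD" or "Adam"), and returns the per-step losses.
# Assumes module-level constants BS (batch size) and D (input dimension) and
# that `optimizers` refers to symjax.nn.optimizers.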
def SJ(x, y, N, lr, model, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)

    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1, ).astype("float32"))

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output)**2).mean()

    if model == "SGD":
        optimizers.SGD(sj_loss, lr)
    elif model == "Adam":
        optimizers.Adam(sj_loss, lr)
    train = symjax.function(sj_input,
                            sj_output,
                            outputs=sj_loss,
                            updates=symjax.get_updates())

    losses = []
    for i in tqdm(range(N)):
        losses.append(train(x, y))

    return losses
Example #2
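# Timing variant of the benchmark above: compiles an Adam training step,
# optionally moves the data onto the device first with jax.device_put, and
# returns the wall-clock time for N updates. BS, D and lr are assumed to be
# module-level constants.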
def SJ(x, y, N, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)

    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(
        np.random.randn(
            1,
        ).astype("float32")
    )

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    optimizers.Adam(sj_loss, lr)

    train = symjax.function(sj_input, sj_output, updates=symjax.get_updates())

    if preallocate:
        import jax

        x = jax.device_put(x)
        y = jax.device_put(y)

    t = time.time()
    for i in range(N):
        train(x, y)

    return time.time() - t
Example #3
 def updates(self):
     return symjax.get_updates(scope=self._scope_name)
     # NOTE: the fallback below is unreachable because of the return above
     if hasattr(self, "_updates"):
         return self._updates
     else:
         self._updates = {}
         return self._updates
Example #4
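# Tests BatchNormalization in training mode: each batch output should be the
# standardized batch, and the layer's running mean (bn.avg_mean) should follow
# an exponential moving average of the batch means.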
def test_bn():
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((1,), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(input, [1], deterministic=deterministic)

    update = sj.function(input, deterministic, outputs=bn, updates=sj.get_updates())
    get_stats = sj.function(input, outputs=bn.avg_mean)

    data = np.random.randn(50, DIM) * 4 + 2

    true_means = []
    actual_means = []

    for i in range(10):
        batch = data[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
        output = update(batch, 0)
        assert np.allclose(
            output, (batch - batch.mean(0)) / (1e-4 + batch.std(0)), 1e-4
        )
        actual_means.append(get_stats(batch))
        if i == 0:
            true_means.append(batch.mean(0))
        else:
            true_means.append(0.9 * true_means[-1] + 0.1 * batch.mean(0))

    true_means = np.array(true_means)
    actual_means = np.array(actual_means).squeeze()

    assert np.allclose(true_means, actual_means, 1e-4)
Example #5
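# Applies ExponentialMovingAverage (decay 0.9, optionally debiased) to a stream
# of scalars and returns the sequence of smoothed values.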
def SJ_EMA(X, debias=True):
    symjax.current_graph().reset()
    x = T.Placeholder((), "float32", name="x")
    value = symjax.nn.schedules.ExponentialMovingAverage(x, 0.9,
                                                         debias=debias)[0]
    train = symjax.function(x, outputs=value, updates=symjax.get_updates())
    outputs = []
    for i in range(len(X)):
        outputs.append(train(X[i]))
    return outputs
Example #6
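# Tests SimpleMovingAverage with a window of 3: each output should equal the
# mean of the last (up to three) inputs.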
def test_sma():
    symjax.current_graph().reset()
    a = symjax.tensor.Placeholder((4, ), "float32")
    sma, var = symjax.nn.schedules.SimpleMovingAverage(a, 3)
    f = symjax.function(a, outputs=[sma, var], updates=symjax.get_updates())

    data = np.random.randn(4, 4)
    current = [data[0], data[:2].mean(0), data[:3].mean(0), data[1:4].mean(0)]

    for i in range(data.shape[0]):
        out = f(data[i])
        assert np.allclose(out[0], current[i])
Example #7
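# Tests ExponentialMovingAverage without debiasing: feeding a constant 1, the
# average follows 0.9 * previous + 0.1, and the second output holds the value
# from before the update.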
def test_ema():
    symjax.current_graph().reset()
    a = symjax.tensor.Placeholder((), "float32")
    ema, var = symjax.nn.schedules.ExponentialMovingAverage(a,
                                                            0.9,
                                                            debias=False)
    # t = symjax.get_variables("*num_steps*", trainable=False)
    f = symjax.function(a, outputs=[ema, var], updates=symjax.get_updates())
    current = 0

    for i in range(10):
        out = f(1)
        assert np.allclose(out[1], current)
        current = 0.9 * current + 0.1 * 1
        assert np.allclose(out[0], current)
Example #8
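    # Builds the DQN graph: an evaluation network and a target network with the
    # same architecture, a TD target with stopped gradients, and an Adam update
    # on the squared TD error.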
    def build_net(self, Q):
        # ------------------ all inputs ------------------------
        state = T.Placeholder([self.batch_size, self.n_states],
                              "float32",
                              name="s")
        next_state = T.Placeholder([self.batch_size, self.n_states],
                                   "float32",
                                   name="s_")
        reward = T.Placeholder(
            [
                self.batch_size,
            ],
            "float32",
            name="r",
        )  # input reward
        action = T.Placeholder(
            [
                self.batch_size,
            ],
            "int32",
            name="a",
        )  # input Action

        with symjax.Scope("eval_net"):
            q_eval = Q(state, self.n_actions)
        with symjax.Scope("test_set"):
            q_next = Q(next_state, self.n_actions)

        q_target = reward + self.reward_decay * q_next.max(1)
        q_target = T.stop_gradient(q_target)

        a_indices = T.stack([T.range(self.batch_size), action], axis=1)
        q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)),
                                         1).squeeze(1)
        loss = T.mean((q_target - q_eval_wrt_a)**2)
        nn.optimizers.Adam(loss, self.lr)

        self.train = symjax.function(state,
                                     action,
                                     reward,
                                     next_state,
                                     updates=symjax.get_updates())
        self.q_eval = symjax.function(state, outputs=q_eval)
Example #9
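    # REINFORCE agent: the loss is the negative log-probability of the taken
    # actions weighted by the discounted rewards, optimized by Adam inside a
    # dedicated scope so that only this optimizer's updates are fetched.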
    def __init__(
        self,
        state_shape,
        actions_shape,
        n_episodes,
        episode_length,
        actor,
        lr=1e-3,
        gamma=0.99,
    ):
        self.actor = actor
        self.gamma = gamma
        self.lr = lr
        self.episode_length = episode_length
        self.n_episodes = n_episodes
        self.batch_size = episode_length * n_episodes

        states = T.Placeholder((self.batch_size, ) + state_shape, "float32")
        actions = T.Placeholder((self.batch_size, ) + actions_shape, "float32")
        discounted_rewards = T.Placeholder((self.batch_size, ), "float32")

        self.actor = actor(states, distribution="gaussian")

        logprobs = self.actor.actions.log_prob(actions)
        actor_loss = -(logprobs * discounted_rewards).sum() / n_episodes

        with symjax.Scope("REINFORCE_optimizer"):
            nn.optimizers.Adam(
                actor_loss,
                lr,
                params=self.actor.params(True),
            )

        # create the update function
        self._train = symjax.function(
            states,
            actions,
            discounted_rewards,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/REINFORCE_optimizer"),
        )
Example #10
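# Variant of the batch-normalization test with momentum 0.99 and epsilon 1e-3,
# reading the running mean directly from the graph's variables.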
def test_bn():
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(input, [1], deterministic=deterministic)

    update = sj.function(input,
                         deterministic,
                         outputs=bn,
                         updates=sj.get_updates())
    avmean = symjax.get_variables(trainable=None)[-3]
    print(avmean)
    get_stats = sj.function(input, outputs=avmean[0])

    data = np.random.randn(50, DIM) * 4 + 2

    true_means = [np.zeros(DIM)]
    actual_means = [np.zeros(DIM)]

    for i in range(10):
        batch = data[BATCH_SIZE * i:BATCH_SIZE * (i + 1)]

        output = update(batch, 0)
        assert np.allclose(
            output,
            (batch - batch.mean(0)) / np.sqrt(0.001 + batch.var(0)),
            1e-4,
        )

        actual_means.append(get_stats(batch))
        true_means.append(0.99 * true_means[-1] + 0.01 * batch.mean(0))

    true_means = np.array(true_means)
    actual_means = np.array(actual_means).squeeze()
    assert np.allclose(true_means, actual_means)
Example #11
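# Builds a ResNet classifier on 32x32 RGB inputs with a sparse softmax
# cross-entropy loss and an Adam-driven training function.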
import symjax
from symjax import nn
import symjax.tensor as T
import numpy as np

input = T.Placeholder((64, 3, 32, 32), "float32")
label = T.Placeholder((64, ), "int32")
deterministic = T.Placeholder((), "bool")

block = nn.models.ResidualBlockv1(activation=nn.relu, dropout_rate=0.1)
transformed = nn.models.ResNet(
    input,
    widths=[32, 32, 64, 64, 128, 128],
    strides=[1, 1, 2, 1, 2, 1],
    block=block,
    deterministic=deterministic,
)

classifier = nn.layers.Dense(transformed, 10)
loss = nn.losses.sparse_softmax_crossentropy_logits(label, classifier).mean()

nn.optimizers.Adam(loss, 0.001)

train = symjax.function(input,
                        label,
                        deterministic,
                        outputs=loss,
                        updates=symjax.get_updates())
Example #12
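# Trains either a small residual CNN or an MLP (selected by `mlp`) for five
# epochs and prints the training and test accuracy of each epoch.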
def classif_sj(train_x, train_y, test_x, test_y, mlp=True):
    symjax.current_graph().reset()
    from symjax import nn

    batch_size = 128

    input = T.Placeholder((batch_size, 3, 32, 32), "float32")
    labels = T.Placeholder((batch_size, ), "int32")
    deterministic = T.Placeholder((), "bool")

    if not mlp:
        out = nn.relu(nn.layers.Conv2D(input, 32, (3, 3)))
        for i in range(3):
            for j in range(3):
                conv = nn.layers.Conv2D(out, 32 * (i + 1), (3, 3), pad="SAME")
                bn = nn.layers.BatchNormalization(conv, [1],
                                                  deterministic=deterministic)
                bn = nn.relu(bn)
                conv = nn.layers.Conv2D(bn, 32 * (i + 1), (3, 3), pad="SAME")
                bn = nn.layers.BatchNormalization(conv, [1],
                                                  deterministic=deterministic)
                out = out + bn

            out = nn.layers.Pool2D(out, (2, 2), pool_type="AVG")
            out = nn.layers.Conv2D(out, 32 * (i + 2), (1, 1))
        # out = out.mean((2, 3))
        out = nn.layers.Pool2D(out, out.shape.get()[-2:], pool_type="AVG")
    else:
        out = input
        for i in range(6):
            out = nn.layers.Dense(out, 4000)
            out = nn.relu(
                nn.layers.BatchNormalization(out, [1],
                                             deterministic=deterministic))

    outputs = nn.layers.Dense(out, 10)

    loss = nn.losses.sparse_softmax_crossentropy_logits(labels, outputs).mean()
    nn.optimizers.Adam(loss, 0.001)

    accu = T.equal(outputs.argmax(1), labels).astype("float32").mean()

    train = symjax.function(
        input,
        labels,
        deterministic,
        outputs=[loss, accu, outputs],
        updates=symjax.get_updates(),
    )
    test = symjax.function(input, labels, deterministic, outputs=accu)

    for epoch in range(5):
        accu = 0
        for x, y in symjax.data.utils.batchify(train_x,
                                               train_y,
                                               batch_size=batch_size,
                                               option="random"):
            accu += train(x, y, 0)[1]

        print("training", accu / (len(train_x) // batch_size))

        accu = 0
        for x, y in symjax.data.utils.batchify(test_x,
                                               test_y,
                                               batch_size=batch_size,
                                               option="continuous"):
            accu += test(x, y, 1)
        print(accu / (len(test_x) // batch_size))
Example #13
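    # PPO-style agent: samples actions from a categorical or Gaussian policy,
    # clips the probability ratio, and optimizes the surrogate objective plus
    # the critic's squared error with a single Adam optimizer.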
    def __init__(
        self,
        state_dim,
        action_dim,
        lr,
        gamma,
        K_epochs,
        eps_clip,
        actor,
        critic,
        batch_size,
        continuous=True,
    ):
        self.lr = lr
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        state = T.Placeholder((batch_size, ) + state_dim, "float32")

        reward = T.Placeholder((batch_size, ), "float32")
        old_action_logprobs = T.Placeholder((batch_size, ), "float32")

        logits = actor(state)

        if not continuous:
            given_action = T.Placeholder((batch_size, ), "int32")
            dist = Categorical(logits=logits)
        else:
            mean = T.tanh(logits[:, :logits.shape[1] // 2])
            std = T.exp(logits[:, logits.shape[1] // 2:])
            given_action = T.Placeholder((batch_size, action_dim), "float32")
            dist = MultivariateNormal(mean=mean, diag_std=std)

        sample = dist.sample()
        sample_logprobs = dist.log_prob(sample)

        self._act = symjax.function(state, outputs=[sample, sample_logprobs])

        given_action_logprobs = dist.log_prob(given_action)

        # Finding the ratio (pi_theta / pi_theta__old):
        ratios = T.exp(sample_logprobs - old_action_logprobs)
        ratios = T.clip(ratios, None, 1 + self.eps_clip)

        state_value = critic(state)
        advantages = reward - T.stop_gradient(state_value)

        loss = (-T.mean(ratios * advantages) + 0.5 * T.mean(
            (state_value - reward)**2) - 0.0 * dist.entropy().mean())

        print(loss)
        nn.optimizers.Adam(loss, self.lr)

        self.learn = symjax.function(
            state,
            given_action,
            reward,
            old_action_logprobs,
            outputs=T.mean(loss),
            updates=symjax.get_updates(),
        )
Example #14
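    # Actor-critic agent: actor and critic losses each get their own Adam
    # optimizer in a dedicated scope, so each update function applies only its
    # optimizer's updates.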
    def __init__(
        self,
        state_shape,
        actions_shape,
        n_episodes,
        episode_length,
        actor,
        critic,
        lr=1e-3,
        gamma=0.99,
        train_v_iters=10,
    ):
        self.actor = actor
        self.critic = critic
        self.gamma = gamma
        self.lr = lr
        self.episode_length = episode_length
        self.n_episodes = n_episodes
        self.batch_size = episode_length * n_episodes
        self.train_v_iters = train_v_iters

        states = T.Placeholder((self.batch_size, ) + state_shape, "float32")
        actions = T.Placeholder((self.batch_size, ) + actions_shape, "float32")
        discounted_rewards = T.Placeholder((self.batch_size, ), "float32")
        advantages = T.Placeholder((self.batch_size, ), "float32")

        self.actor = actor(states, distribution="gaussian")
        self.critic = critic(states)

        logprobs = self.actor.actions.log_prob(actions)
        actor_loss = -(logprobs * advantages).sum() / n_episodes
        critic_loss = 0.5 * (
            (discounted_rewards - self.critic.q_values)**2).mean()

        with symjax.Scope("actor_optimizer"):
            nn.optimizers.Adam(
                actor_loss,
                lr,
                params=self.actor.params(True),
            )
        with symjax.Scope("critic_optimizer"):
            nn.optimizers.Adam(
                critic_loss,
                lr,
                params=self.critic.params(True),
            )

        # create the update function
        self._train_actor = symjax.function(
            states,
            actions,
            advantages,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/actor_optimizer"),
        )
        # create the update function
        self._train_critic = symjax.function(
            states,
            discounted_rewards,
            outputs=critic_loss,
            updates=symjax.get_updates(scope="*/critic_optimizer"),
        )
Example #15
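# Cross-checks a SymJAX conv + batch-norm + dense network against an equivalent
# Keras model: predictions (train and eval mode), losses, gradients, and SGD
# steps are compared over ten random batches.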
def test_learn_bn():

    symjax.current_graph().reset()
    import tensorflow as tf
    from tensorflow.keras import layers
    import symjax.nn as nn

    np.random.seed(0)

    batch_size = 128

    W = np.random.randn(5, 5, 3, 2)
    W2 = np.random.randn(2 * 28 * 28, 10)

    inputs = layers.Input(shape=(3, 32, 32))
    out = layers.Permute((2, 3, 1))(inputs)
    out = layers.Conv2D(2,
                        5,
                        activation="linear",
                        kernel_initializer=lambda *args, **kwargs: W)(out)
    out = layers.BatchNormalization(-1)(out)
    out = layers.Activation("relu")(out)
    out = layers.Flatten()(out)
    out = layers.Dense(10,
                       activation="linear",
                       kernel_initializer=lambda *args, **kwargs: W2)(out)
    model = tf.keras.Model(inputs, out)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

    input = T.Placeholder((batch_size, 3, 32, 32), "float32")
    label = T.Placeholder((batch_size, ), "int32")
    deterministic = T.Placeholder((), "bool")

    conv = nn.layers.Conv2D(input, 2, (5, 5), W=W.transpose((3, 2, 0, 1)))
    out = nn.layers.BatchNormalization(conv, [1], deterministic=deterministic)
    out = nn.relu(out)
    out = nn.layers.Dense(out.transpose((0, 2, 3, 1)), 10, W=W2.T)
    loss = nn.losses.sparse_softmax_crossentropy_logits(label, out).mean()
    # V2 = T.Variable(W2)
    # loss = input.sum() * V2.sum()
    # out=loss
    nn.optimizers.SGD(loss, 0.001)
    f = symjax.function(input, label, deterministic, outputs=[loss, out])
    g = symjax.function(
        input,
        label,
        deterministic,
        outputs=symjax.gradients(loss, symjax.get_variables(trainable=True)),
    )
    train = symjax.function(
        input,
        label,
        deterministic,
        outputs=loss,
        updates=symjax.get_updates(),
    )

    for epoch in range(10):
        # generate some random inputs and labels
        x = np.random.randn(batch_size, 3, 32, 32)
        y = np.random.randint(0, 10, size=batch_size)

        # test predictions during testing mode
        preds = model(x, training=False)
        nb = np.isclose(preds, f(x, y, 1)[1], atol=1e-3).mean()
        print("preds not training", nb)
        assert nb > 0.8

        # test prediction during training mode
        # now get the gradients
        with tf.GradientTape() as tape:
            preds = model(x, training=True)
            loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(y, preds))

        nb = np.isclose(preds, f(x, y, 0)[1], atol=1e-3).mean()
        print("preds training", nb)
        assert nb > 0.8

        # test loss function during training
        losss = f(x, y, 0)[0]
        print("losses", losss, loss)
        assert np.abs(losss - loss) < 1e-3

        # test the gradients
        grads = tape.gradient(loss, model.trainable_variables)
        sj_grads = g(x, y, 0)

        for tf_g, sj_g in zip(grads, sj_grads):
            if sj_g.ndim == 4:
                sj_g = sj_g.transpose((2, 3, 1, 0))
            else:
                sj_g = sj_g.transpose()

            print(sj_g.shape, tf_g.shape)
            nb = np.isclose(
                np.reshape(sj_g, -1),
                np.reshape(tf_g, -1),
                atol=1e-3,
            ).mean()
            print("grads training", nb)
            assert nb >= 0.5

        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train(x, y, 0)
Example #16
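# Compares symjax.nn.schedules.ExponentialMovingAverage against a NumPy
# reference on a three-channel signal and plots both traces.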
import matplotlib.pyplot as plt
import symjax.tensor as T
import symjax
import numpy as np

w = T.Placeholder((3, ), "float32", name="w")
alpha = 0.5
new_value, var = symjax.nn.schedules.ExponentialMovingAverage(w, alpha)

train = symjax.function(w, outputs=new_value, updates=symjax.get_updates())

data = np.stack([np.ones(200), np.random.randn(200), np.zeros(200)], 1)
cost = list()
true_ema = [data[0]]
aa = 0.5
for j, i in enumerate(data):
    cost.append(train(i))
    true_ema.append(aa * true_ema[-1] + (1 - aa) * i)
cost = np.asarray(cost)
true = np.asarray(true_ema)[1:]
print("% close values:", 100 * np.mean(np.isclose(cost, true)))

plt.subplot(311)
plt.plot(data[:, 0])
plt.plot(cost[:, 0])
plt.plot(true[:, 0])

plt.subplot(312)
plt.plot(data[:, 1], label="true")
plt.plot(cost[:, 1], label="symjax")
plt.plot(true[:, 1], label="np")
Example #17
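    # On-policy actor-critic for a Gym environment: separate Adam optimizers for
    # the policy and the value function, and a single-state `act` function
    # obtained by cloning the batched graph.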
    def __init__(
        self,
        env,
        actor,
        critic,
        lr=1e-4,
        batch_size=32,
        train_pi_iters=80,
        train_v_iters=80,
    ):

        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_min = env.action_space.low
            action_max = env.action_space.high
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.continuous = continuous
        self.lr = lr
        self.train_pi_iters = train_pi_iters
        self.train_v_iters = train_v_iters
        self.extras = {}

        state_ph = T.Placeholder((batch_size, num_states), "float32")
        rew_ph = T.Placeholder((batch_size, ), "float32")

        with symjax.Scope("actor"):
            logits = actor(state_ph)
            if not continuous:
                pi = Categorical(logits=logits)
            else:
                logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )
                pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

            actions = pi.sample()  # pi
            actions_log_prob = pi.log_prob(actions)  # logp

        with symjax.Scope("critic"):
            critic_value = critic(state_ph)

        # AC objectives
        diff = rew_ph - critic_value
        actor_loss = -(actions_log_prob * diff).mean()
        critic_loss = nn.losses.squared_differences(rew_ph,
                                                    critic_value).mean()

        with symjax.Scope("update_pi"):
            nn.optimizers.Adam(
                actor_loss,
                self.lr,
                params=symjax.get_variables(scope="/actor/"),
            )
        with symjax.Scope("update_v"):
            nn.optimizers.Adam(
                critic_loss,
                self.lr,
                params=symjax.get_variables(scope="/critic/"),
            )

        self.learn_pi = symjax.function(
            state_ph,
            rew_ph,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="/update_pi/"),
        )
        self.learn_v = symjax.function(
            state_ph,
            rew_ph,
            outputs=critic_loss,
            updates=symjax.get_updates(scope="/update_v/*"),
        )

        single_state = T.Placeholder((1, num_states), "float32")
        single_action = actions.clone({state_ph: single_state})[0]
        single_v = critic_value.clone({state_ph: single_state})

        self._act = symjax.function(
            single_state,
            outputs=[single_action, single_v],
        )
Example #18
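    # PPO agent with a target ("old") policy: computes the clipped surrogate
    # loss, the critic loss, and monitoring quantities (clip fraction,
    # approximate KL), all optimized jointly with Adam.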
    def __init__(
        self,
        state_shape,
        actions_shape,
        batch_size,
        actor,
        critic,
        lr=1e-3,
        K_epochs=80,
        eps_clip=0.2,
        gamma=0.99,
        entropy_beta=0.01,
    ):
        self.actor = actor
        self.critic = critic
        self.gamma = gamma
        self.lr = lr
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        states = T.Placeholder((batch_size, ) + state_shape,
                               "float32",
                               name="states")
        actions = T.Placeholder((batch_size, ) + actions_shape,
                                "float32",
                                name="actions")
        rewards = T.Placeholder((batch_size, ),
                                "float32",
                                name="discounted_rewards")
        advantages = T.Placeholder((batch_size, ),
                                   "float32",
                                   name="advantages")

        self.target_actor = actor(states, distribution="gaussian")
        self.actor = actor(states, distribution="gaussian")
        self.critic = critic(states)

        # Finding the ratio (pi_theta / pi_theta__old) and
        # surrogate Loss https://arxiv.org/pdf/1707.06347.pdf
        with symjax.Scope("policy_loss"):
            ratios = T.exp(
                self.actor.actions.log_prob(actions) -
                self.target_actor.actions.log_prob(actions))
            ratios = T.clip(ratios, 0, 10)
            clipped_ratios = T.clip(ratios, 1 - self.eps_clip,
                                    1 + self.eps_clip)

            surr1 = advantages * ratios
            surr2 = advantages * clipped_ratios

            actor_loss = -(T.minimum(surr1, surr2)).mean()

        with symjax.Scope("monitor"):
            clipfrac = (((ratios > (1 + self.eps_clip)) |
                         (ratios <
                          (1 - self.eps_clip))).astype("float32").mean())
            approx_kl = (self.target_actor.actions.log_prob(actions) -
                         self.actor.actions.log_prob(actions)).mean()

        with symjax.Scope("critic_loss"):
            critic_loss = T.mean((rewards - self.critic.q_values)**2)

        with symjax.Scope("entropy"):
            entropy = self.actor.actions.entropy().mean()

        loss = actor_loss + critic_loss  # - entropy_beta * entropy

        with symjax.Scope("optimizer"):
            nn.optimizers.Adam(
                loss,
                lr,
                params=self.actor.params(True) + self.critic.params(True),
            )

        # create the update function
        self._train = symjax.function(
            states,
            actions,
            rewards,
            advantages,
            outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
            updates=symjax.get_updates(scope="*optimizer"),
        )

        # initialize target as current
        self.update_target(1)
Example #19
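    # PPO agent for a Gym environment: keeps an "old_actor" copy of the policy
    # for the probability ratio, trains policy and value function with separate
    # Adam optimizers, and exposes single-state act/value functions.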
    def __init__(
        self,
        env,
        actor,
        critic,
        lr=1e-4,
        batch_size=32,
        n=1,
        clip_ratio=0.2,
        entcoeff=0.01,
        target_kl=0.01,
        train_pi_iters=4,
        train_v_iters=4,
    ):

        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_min = env.action_space.low
            action_max = env.action_space.high
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.continuous = continuous
        self.lr = lr
        self.train_pi_iters = train_pi_iters
        self.train_v_iters = train_v_iters
        self.clip_ratio = clip_ratio
        self.target_kl = target_kl
        self.extras = {"logprob": ()}
        self.entcoeff = entcoeff

        state_ph = T.Placeholder((batch_size, num_states), "float32")
        ret_ph = T.Placeholder((batch_size, ), "float32")
        adv_ph = T.Placeholder((batch_size, ), "float32")
        act_ph = T.Placeholder((batch_size, num_actions), "float32")

        with symjax.Scope("actor"):
            logits = actor(state_ph)
            if continuous:
                logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )
        with symjax.Scope("old_actor"):
            old_logits = actor(state_ph)
            if continuous:
                old_logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )

        if not continuous:
            pi = Categorical(logits=logits)
        else:
            pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

        actions = T.clip(pi.sample(), -2, 2)  # pi

        actor_params = symjax.get_variables(scope="/actor/")
        old_actor_params = symjax.get_variables(scope="/old_actor/")

        self.update_target = symjax.function(
            updates={o: a
                     for o, a in zip(old_actor_params, actor_params)})

        # PPO objectives
        # pi(a|s) / pi_old(a|s)
        pi_log_prob = pi.log_prob(act_ph)
        old_pi_log_prob = pi_log_prob.clone({
            logits: old_logits,
            logstd: old_logstd
        })

        ratio = T.exp(pi_log_prob - old_pi_log_prob)
        surr1 = ratio * adv_ph
        surr2 = adv_ph * T.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)

        pi_loss = -T.minimum(surr1, surr2).mean()
        # ent_loss = pi.entropy().mean() * self.entcoeff

        with symjax.Scope("critic"):
            v = critic(state_ph)
        # critic loss
        v_loss = ((ret_ph - v)**2).mean()

        # Info (useful to watch during learning)

        # a sample estimate for KL
        approx_kl = (old_pi_log_prob - pi_log_prob).mean()
        # a sample estimate for entropy
        # approx_ent = -logprob_given_actions.mean()
        # clipped = T.logical_or(
        #     ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
        # )
        # clipfrac = clipped.astype("float32").mean()

        with symjax.Scope("update_pi"):
            print(len(actor_params), "actor parameters")
            nn.optimizers.Adam(
                pi_loss,
                self.lr,
                params=actor_params,
            )
        with symjax.Scope("update_v"):
            critic_params = symjax.get_variables(scope="/critic/")
            print(len(critic_params), "critic parameters")

            nn.optimizers.Adam(
                v_loss,
                self.lr,
                params=critic_params,
            )

        self.get_params = symjax.function(outputs=critic_params)

        self.learn_pi = symjax.function(
            state_ph,
            act_ph,
            adv_ph,
            outputs=[pi_loss, approx_kl],
            updates=symjax.get_updates(scope="/update_pi/"),
        )
        self.learn_v = symjax.function(
            state_ph,
            ret_ph,
            outputs=v_loss,
            updates=symjax.get_updates(scope="/update_v/"),
        )

        single_state = T.Placeholder((1, num_states), "float32")
        single_v = v.clone({state_ph: single_state})
        single_sample = actions.clone({state_ph: single_state})

        self._act = symjax.function(single_state, outputs=single_sample)
        self._get_v = symjax.function(single_state, outputs=single_v)

        single_action = T.Placeholder((1, num_actions), "float32")
Example #20
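    # DDPG-style agent: a deterministic actor scaled to the action range, a Q
    # critic, action gradients for the actor update, and soft target-network
    # updates controlled by `tau`.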
    def __init__(
        self,
        env_fn,
        actor,
        critic,
        gamma=0.99,
        tau=0.01,
        lr=1e-3,
        batch_size=32,
        epsilon=0.1,
        epsilon_decay=1 / 1000,
        min_epsilon=0.01,
        reward=None,
    ):

        # comment out this line if you don't want to record a video of the agent
        # if save_folder is not None:
        # test_env = gym.wrappers.Monitor(test_env)

        # get size of state space and action space
        # NOTE: `env` is assumed to be defined here (e.g. created from `env_fn`)
        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_max = env.action_space.high[0]
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.gamma = gamma
        self.continuous = continuous
        self.observ_min = np.clip(env.observation_space.low, -20, 20)
        self.observ_max = np.clip(env.observation_space.high, -20, 20)
        self.env = env
        self.reward = reward

        # state
        state = T.Placeholder((batch_size, num_states), "float32")
        gradients = T.Placeholder((batch_size, num_actions), "float32")
        action = T.Placeholder((batch_size, num_actions), "float32")
        target = T.Placeholder((batch_size, 1), "float32")

        with symjax.Scope("actor_critic"):

            scaled_out = action_max * actor(state)
            Q = critic(state, action)

        a_loss = -T.sum(gradients * scaled_out)
        q_loss = T.mean((Q - target)**2)

        nn.optimizers.Adam(a_loss + q_loss, lr)

        self.update = symjax.function(
            state,
            action,
            target,
            gradients,
            outputs=[a_loss, q_loss],
            updates=symjax.get_updates(),
        )
        g = symjax.gradients(T.mean(Q), [action])[0]
        self.get_gradients = symjax.function(state, action, outputs=g)

        # also create the target variants
        with symjax.Scope("actor_critic_target"):
            scaled_out_target = action_max * actor(state)
            Q_target = critic(state, action)

        self.actor_predict = symjax.function(state, outputs=scaled_out)
        self.actor_predict_target = symjax.function(state,
                                                    outputs=scaled_out_target)
        self.critic_predict = symjax.function(state, action, outputs=Q)
        self.critic_predict_target = symjax.function(state,
                                                     action,
                                                     outputs=Q_target)

        t_params = symjax.get_variables(scope="/actor_critic_target/*")
        params = symjax.get_variables(scope="/actor_critic/*")
        replacement = {
            t: tau * e + (1 - tau) * t
            for t, e in zip(t_params, params)
        }
        self.update_target = symjax.function(updates=replacement)

        single_state = T.Placeholder((1, num_states), "float32")
        if not continuous:
            # NOTE: `clean_action` is not defined in this snippet
            scaled_out = clean_action.argmax(-1)

        self.act = symjax.function(single_state,
                                   outputs=scaled_out.clone(
                                       {state: single_state})[0])
Example #21
"""
demonstration on how to compute a gradient and apply a basic gradient update
rule to minimize some loss function

"""

import symjax
import symjax.tensor as T
import matplotlib.pyplot as plt

# GRADIENT DESCENT
z = T.Variable(3.0, dtype="float32")
loss = (z - 1) ** 2
g_z = symjax.gradients(loss, [z])[0]
symjax.current_graph().add_updates({z: z - 0.1 * g_z})

train = symjax.function(outputs=[loss, z], updates=symjax.get_updates())

losses = list()
values = list()
for i in range(200):
    if (i + 1) % 50 == 0:
        symjax.reset_variables("*")
    a, b = train()
    losses.append(a)
    values.append(b)

plt.figure()

plt.subplot(121)
plt.plot(losses, "-x")
plt.ylabel("loss")
Example #22
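    # DDPG agent: the critic is trained on externally supplied TD targets, the
    # actor is trained with the critic's action gradients, and each optimizer
    # lives in its own scope so its updates can be fetched separately.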
    def __init__(
        self,
        state_shape,
        actions_shape,
        batch_size,
        actor,
        critic,
        lr=1e-3,
        gamma=0.99,
        tau=0.01,
    ):

        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.batch_size = batch_size

        states = T.Placeholder((batch_size, ) + state_shape, "float32")
        actions = T.Placeholder((batch_size, ) + actions_shape, "float32")
        self.critic = critic(states, actions)
        self.target_critic = critic(states, actions)

        # create critic loss
        targets = T.Placeholder(self.critic.q_values.shape, "float32")
        critic_loss = ((self.critic.q_values - targets)**2).mean()

        # create optimizer
        with symjax.Scope("critic_optimizer"):
            nn.optimizers.Adam(critic_loss,
                               lr,
                               params=self.critic.params(True))

        # create the update function
        self._train_critic = symjax.function(
            states,
            actions,
            targets,
            outputs=critic_loss,
            updates=symjax.get_updates(scope="*/critic_optimizer"),
        )

        # now create utility function to get the gradients
        grad = symjax.gradients(self.critic.q_values.sum(), actions)
        self._get_critic_gradients = symjax.function(states,
                                                     actions,
                                                     outputs=grad)

        # create actor loss
        self.actor = actor(states)
        self.target_actor = actor(states)
        gradients = T.Placeholder(actions.shape, "float32")
        actor_loss = -(self.actor.actions * gradients).mean()

        # create optimizer
        with symjax.Scope("actor_optimizer"):
            nn.optimizers.Adam(actor_loss, lr, params=self.actor.params(True))

        # create the update function
        self._train_actor = symjax.function(
            states,
            gradients,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/actor_optimizer"),
        )

        # initialize both networks as the same
        self.update_target(1)