Example #1
def test_stop():
    a = symjax.tensor.ones(())
    b = a + a**2
    g = symjax.gradients(b, [a])[0]
    f = symjax.function(outputs=g)
    assert f() == 3
    b = a + symjax.tensor.stop_gradient(a**2)
    g = symjax.gradients(b, [a])[0]
    f = symjax.function(outputs=g)
    assert f() == 1
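The same semantics can be checked directly against JAX (which SymJAX wraps); a minimal sketch, not part of the original test:

import jax

# d/da (a + a**2) = 1 + 2a, i.e. 3 at a = 1
print(jax.grad(lambda a: a + a ** 2)(1.0))  # 3.0

# stop_gradient treats a**2 as a constant, leaving only the linear term
print(jax.grad(lambda a: a + jax.lax.stop_gradient(a ** 2))(1.0))  # 1.0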
Example #2
def test_g():
    a = symjax.tensor.ones(())
    b = symjax.tensor.Variable(1.0)
    l = a * b
    g = symjax.gradients(l, [a])[0]
    f = symjax.function(outputs=g, updates={b: b + 1.0})
    assert f() == 1
    assert f() == 2
    assert f() == 3
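Since l = a * b, the gradient with respect to a is simply b, and the update b ← b + 1 is applied after each call, so successive calls return b's value at call time: 1, 2, 3.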
Example #3
def test_grad():
    w = tt.Placeholder((), "float32")
    v = tt.Variable(1.0, dtype="float32")
    x = w * v + 2
    #    symjax.nn.optimizers.Adam(x, 0.001)
    g = symjax.gradients(x.sum(), [v])[0]
    f = symjax.function(w, outputs=g)
    assert f(1) == 1
    assert f(10) == 10
Example #4
def test_vectorize_sgd():
    sj.current_graph().reset()
    x = symjax.tensor.Placeholder((0, 2), "float32")
    y = symjax.tensor.Placeholder((0, ), "float32")

    w = symjax.tensor.Variable((1, 1), dtype="float32")
    loss = ((x.dot(w) - y)**2).mean()

    g = symjax.gradients(loss, [w])[0]

    other_g = symjax.gradients(x.dot(w).sum(), [w])[0]

    f = symjax.function(x, y, outputs=loss, updates={w: w - 0.1 * g})
    other_f = symjax.function(x, outputs=other_g)

    L = [10]
    for i in range(10):
        L.append(f(np.ones((i + 1, 2)), -1 * np.ones(i + 1)))
        assert L[-1] < L[-2]
        assert np.array_equal(other_f(np.ones((i + 1, 2))), [i + 1.0, i + 1.0])
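The asserted values can be reproduced by hand; a NumPy-only sketch of the first update step, using the same shapes and learning rate as the test above:

import numpy as np

w = np.array([1.0, 1.0])
x = np.ones((1, 2))
y = -np.ones(1)

pred = x @ w                          # 2.0
loss = ((pred - y) ** 2).mean()       # (2 - (-1))**2 = 9.0
grad = 2 * x.T @ (pred - y) / len(y)  # [6., 6.]
w = w - 0.1 * grad                    # [0.4, 0.4]

# the gradient of x.dot(w).sum() w.r.t. w is the column sum of x
print(x.sum(0))                       # [1., 1.] for a single row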
Example #5
def test_grad_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    out = T.map(lambda a, w, u: w * a * u, (T.range(3), ),
                non_sequences=(w, u))
    g = sj.gradients(out.sum(), w)
    f = sj.function(u, outputs=g)

    assert np.array_equal(f(0), 0)
    assert np.array_equal(f(1), 3)
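Worked out by hand: out.sum() = Σ_{a=0}^{2} w·a·u = 3·w·u, so the derivative with respect to w is 3u, which matches f(0) == 0 and f(1) == 3.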
Example #6
def test_clone_base():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    w2 = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    uu = T.Placeholder((), "float32", name="uu")

    aa = T.Placeholder((), "float32")
    bb = T.Placeholder((), "float32")

    l = 2 * w * u * w2
    g = sj.gradients(l, w)
    guu = T.clone(l, {u: uu})
    guuu = T.clone(l, {w: uu})

    f = sj.function(u, outputs=g, updates={w2: w2 + 1})
    fuu = sj.function(uu, outputs=guu, updates={w2: w2 + 1})
    fuuu = sj.function(u, uu, outputs=guuu, updates={w2: w2 + 1})

    #    print(f(2))
    assert np.array_equal(f(2), 4.0)
    assert np.array_equal(fuu(1), 4)
    assert np.array_equal(fuuu(0, 0), 0)
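To see why the assertions hold: l = 2·w·u·w2, so g = dl/dw = 2·u·w2. The first call f(2) runs with w2 = 1, giving 2·2·1 = 4 (and then increments w2 to 2); fuu(1) evaluates the clone with u replaced by uu and w2 now equal to 2, giving 2·1·2 = 4; fuuu(0, 0) evaluates l with w replaced by uu = 0, giving 0.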
Example #7
    def __init__(
        self,
        env_fn,
        actor,
        critic,
        gamma=0.99,
        tau=0.01,
        lr=1e-3,
        batch_size=32,
        epsilon=0.1,
        epsilon_decay=1 / 1000,
        min_epsilon=0.01,
        reward=None,
    ):

        # build the environment from its constructor
        env = env_fn()

        # comment out this line if you don't want to record a video of the agent
        # if save_folder is not None:
        #     test_env = gym.wrappers.Monitor(test_env)

        # get size of state space and action space
        num_states = env.observation_space.shape[0]
        continuous = isinstance(env.action_space, gym.spaces.box.Box)

        if continuous:
            num_actions = env.action_space.shape[0]
            action_max = env.action_space.high[0]
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.gamma = gamma
        self.continuous = continuous
        self.observ_min = np.clip(env.observation_space.low, -20, 20)
        self.observ_max = np.clip(env.observation_space.high, -20, 20)
        self.env = env
        self.reward = reward

        # state
        state = T.Placeholder((batch_size, num_states), "float32")
        gradients = T.Placeholder((batch_size, num_actions), "float32")
        action = T.Placeholder((batch_size, num_actions), "float32")
        target = T.Placeholder((batch_size, 1), "float32")

        with symjax.Scope("actor_critic"):

            scaled_out = action_max * actor(state)
            Q = critic(state, action)

        a_loss = -T.sum(gradients * scaled_out)
        q_loss = T.mean((Q - target)**2)

        nn.optimizers.Adam(a_loss + q_loss, lr)

        self.update = symjax.function(
            state,
            action,
            target,
            gradients,
            outputs=[a_loss, q_loss],
            updates=symjax.get_updates(),
        )
        g = symjax.gradients(T.mean(Q), [action])[0]
        self.get_gradients = symjax.function(state, action, outputs=g)

        # also create the target variants
        with symjax.Scope("actor_critic_target"):
            scaled_out_target = action_max * actor(state)
            Q_target = critic(state, action)

        self.actor_predict = symjax.function(state, outputs=scaled_out)
        self.actor_predict_target = symjax.function(state,
                                                    outputs=scaled_out_target)
        self.critic_predict = symjax.function(state, action, outputs=Q)
        self.critic_predict_target = symjax.function(state,
                                                     action,
                                                     outputs=Q_target)

        t_params = symjax.get_variables(scope="/actor_critic_target/*")
        params = symjax.get_variables(scope="/actor_critic/*")
        replacement = {
            t: tau * e + (1 - tau) * t
            for t, e in zip(t_params, params)
        }
        self.update_target = symjax.function(updates=replacement)

        single_state = T.Placeholder((1, num_states), "float32")
        if not continuous:
            # `clean_action` is undefined in this excerpt; the discrete case
            # presumably takes the argmax of the actor output
            scaled_out = scaled_out.argmax(-1)

        self.act = symjax.function(single_state,
                                   outputs=scaled_out.clone(
                                       {state: single_state})[0])
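The replacement dictionary is the usual soft (Polyak) target update, θ_target ← τ·θ + (1 − τ)·θ_target; with τ = 0.01 the target networks trail the online networks slowly, which is what stabilizes the bootstrapped critic targets.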
Example #8
z = T.Variable(np.random.randn(*SHAPE).astype("float32"), name="z")
get_z = symjax.function(outputs=z)

for i in range(10):
    print(get_z())

# print(f_shuffle()[0])

w = T.Placeholder(SHAPE, "float32", name="w")
noise = T.random.uniform(SHAPE, dtype="float32")
y = T.cos(symjax.nn.activations.leaky_relu(z, 0.3) + w + noise)
cost = T.pool(y, (2, 2))
cost = T.sum(cost)

grads = symjax.gradients(cost, [w, z], [1])

print(cost.get({w: np.random.randn(*SHAPE)}))
noise.seed = 20
print(cost.get({w: np.random.randn(*SHAPE)}))
noise.seed = 40
print(cost.get({w: np.random.randn(*SHAPE)}))

updates = {z: z - 0.01 * grads[0]}
fn1 = symjax.function(w, outputs=[cost])
fn2 = symjax.function(w, outputs=[cost], updates=updates)
print(fn1(np.random.randn(*SHAPE)))
print(fn1(np.random.randn(*SHAPE)))

cost = list()
for i in range(1000):
    # (the loop body is truncated in the original excerpt; presumably it
    # records the cost while applying the update through fn2)
    cost.append(fn2(np.random.randn(*SHAPE))[0])
Example #9
def test_learn_bn():

    symjax.current_graph().reset()
    import tensorflow as tf
    from tensorflow.keras import layers
    import symjax.nn as nn

    np.random.seed(0)

    batch_size = 128

    W = np.random.randn(5, 5, 3, 2)
    W2 = np.random.randn(2 * 28 * 28, 10)

    inputs = layers.Input(shape=(3, 32, 32))
    out = layers.Permute((2, 3, 1))(inputs)
    out = layers.Conv2D(2,
                        5,
                        activation="linear",
                        kernel_initializer=lambda *args, **kwargs: W)(out)
    out = layers.BatchNormalization(-1)(out)
    out = layers.Activation("relu")(out)
    out = layers.Flatten()(out)
    out = layers.Dense(10,
                       activation="linear",
                       kernel_initializer=lambda *args, **kwargs: W2)(out)
    model = tf.keras.Model(inputs, out)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

    input = T.Placeholder((batch_size, 3, 32, 32), "float32")
    label = T.Placeholder((batch_size, ), "int32")
    deterministic = T.Placeholder((), "bool")

    conv = nn.layers.Conv2D(input, 2, (5, 5), W=W.transpose((3, 2, 0, 1)))
    out = nn.layers.BatchNormalization(conv, [1], deterministic=deterministic)
    out = nn.relu(out)
    out = nn.layers.Dense(out.transpose((0, 2, 3, 1)), 10, W=W2.T)
    loss = nn.losses.sparse_softmax_crossentropy_logits(label, out).mean()
    # V2 = T.Variable(W2)
    # loss = input.sum() * V2.sum()
    # out=loss
    nn.optimizers.SGD(loss, 0.001)
    f = symjax.function(input, label, deterministic, outputs=[loss, out])
    g = symjax.function(
        input,
        label,
        deterministic,
        outputs=symjax.gradients(loss, symjax.get_variables(trainable=True)),
    )
    train = symjax.function(
        input,
        label,
        deterministic,
        outputs=loss,
        updates=symjax.get_updates(),
    )

    for epoch in range(10):
        # generate some random inputs and labels
        x = np.random.randn(batch_size, 3, 32, 32)
        y = np.random.randint(0, 10, size=batch_size)

        # test predictions during testing mode
        preds = model(x, training=False)
        nb = np.isclose(preds, f(x, y, 1)[1], atol=1e-3).mean()
        print("preds not training", nb)
        assert nb > 0.8

        # test prediction during training mode
        # now get the gradients
        with tf.GradientTape() as tape:
            preds = model(x, training=True)
            loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(y, preds))

        nb = np.isclose(preds, f(x, y, 0)[1], atol=1e-3).mean()
        print("preds training", nb)
        assert nb > 0.8

        # test loss function during training
        losss = f(x, y, 0)[0]
        print("losses", losss, loss)
        assert np.abs(losss - loss) < 1e-3

        # test the gradients
        grads = tape.gradient(loss, model.trainable_variables)
        sj_grads = g(x, y, 0)

        for tf_g, sj_g in zip(grads, sj_grads):
            if sj_g.ndim == 4:
                sj_g = sj_g.transpose((2, 3, 1, 0))
            else:
                sj_g = sj_g.transpose()

            print(sj_g.shape, tf_g.shape)
            nb = np.isclose(
                np.reshape(sj_g, -1),
                np.reshape(tf_g, -1),
                atol=1e-3,
            ).mean()
            print("grads training", nb)
            assert nb >= 0.5

        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train(x, y, 0)
Example #10
print(f(np.ones((1, 2))))
print(w.value)
print(f(np.ones((2, 2))))
print(w.value)
# [2.]
# 2.0
# [2. 2.]
# 4.0

x = symjax.tensor.Placeholder((0, 2), "float32")
y = symjax.tensor.Placeholder((0,), "float32")
w = symjax.tensor.Variable((1, 1), dtype="float32")

loss = ((x.dot(w) - y) ** 2).mean()

g = symjax.gradients(loss, [w])[0]

other_g = symjax.gradients(x.dot(w).sum(), [w])[0]
f = symjax.function(x, y, outputs=loss, updates={w: w - 0.1 * g})
other_f = symjax.function(x, outputs=other_g)
for i in range(10):
    print(f(np.ones((i + 1, 2)), -1 * np.ones(i + 1)))
    print(other_f(np.ones((i + 1, 2))))

# 9.0
# [1. 1.]
# 3.2399998
# [2. 2.]
# 1.1663998
# [3. 3.]
# 0.419904
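The printed losses follow a closed form: with x all ones of shape (n, 2) and y = −1, the prediction is s = w1 + w2 and the loss is (s + 1)² regardless of n. One update gives s ← s − 0.4·(s + 1), so (s + 1) shrinks by a factor 0.6 per call and the loss by 0.36: 9 → 3.24 → 1.1664 → 0.4199. The second print is the gradient of x.dot(w).sum(), i.e. the column sums of x: [i + 1, i + 1].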
Example #11
import symjax
import symjax.tensor as T
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

matplotlib.use("Agg")

###### DERIVATIVE OF GAUSSIAN EXAMPLE

t = T.Placeholder((1000, ), "float32")
print(t)
f = T.meshgrid(t, t)
f = T.exp(-(t**2))
u = f.sum()
g = symjax.gradients(u, [t])
g2 = symjax.gradients(g[0].sum(), [t])
g3 = symjax.gradients(g2[0].sum(), [t])

dog = symjax.function(t, outputs=[g[0], g2[0], g3[0]])

plt.plot(np.array(dog(np.linspace(-10, 10, 1000))).T)

###### GRADIENT DESCENT
z = T.Variable(3.0)
loss = z**2
g_z = symjax.gradients(loss, [z])
print(loss, z)
train = symjax.function(outputs=[loss, z], updates={z: z - 0.1 * g_z[0]})

losses = list()
Example #12
w = T.Variable(1.0, dtype="float32")
u = T.Placeholder((), "float32")
out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)], non_sequences=[w, u])
f = sj.function(u, outputs=out, updates={w: w + 1})
print(f(2))
# [0, 1, 2]
print(f(2))
# [0, 0, 0]
print(f(0))
# [0, -3, -6]


w.reset()
out = T.map(lambda a, w, u: w * a * u, [T.range(3)], non_sequences=[w, u])
g = sj.gradients(out.sum(), [w])[0]
f = sj.function(u, outputs=g)

print(f(0))
# 0
print(f(1))
# 3


out = T.map(lambda a, b: a * b, [T.range(3), T.range(3)])
f = sj.function(outputs=out)

print(f())
# [0, 1, 4]
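Note the ordering: updates run after a call's outputs are computed, so the first f(2) still sees w = 1 and returns (2 − 1)·a = [0, 1, 2]; the second sees w = 2, giving zeros; and the third, with u = 0 and w = 3, gives (0 − 3)·a = [0, −3, −6].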

Example #13
    def __init__(
        self,
        state_shape,
        actions_shape,
        batch_size,
        actor,
        critic,
        lr=1e-3,
        gamma=0.99,
        tau=0.01,
    ):

        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.batch_size = batch_size

        states = T.Placeholder((batch_size, ) + state_shape, "float32")
        actions = T.Placeholder((batch_size, ) + actions_shape, "float32")
        self.critic = critic(states, actions)
        self.target_critic = critic(states, actions)

        # create critic loss
        targets = T.Placeholder(self.critic.q_values.shape, "float32")
        critic_loss = ((self.critic.q_values - targets)**2).mean()

        # create optimizer
        with symjax.Scope("critic_optimizer"):
            nn.optimizers.Adam(critic_loss,
                               lr,
                               params=self.critic.params(True))

        # create the update function
        self._train_critic = symjax.function(
            states,
            actions,
            targets,
            outputs=critic_loss,
            updates=symjax.get_updates(scope="*/critic_optimizer"),
        )

        # now create utility function to get the gradients
        grad = symjax.gradients(self.critic.q_values.sum(), actions)
        self._get_critic_gradients = symjax.function(states,
                                                     actions,
                                                     outputs=grad)

        # create actor loss
        self.actor = actor(states)
        self.target_actor = actor(states)
        gradients = T.Placeholder(actions.shape, "float32")
        actor_loss = -(self.actor.actions * gradients).mean()

        # create optimizer
        with symjax.Scope("actor_optimizer"):
            nn.optimizers.Adam(actor_loss, lr, params=self.actor.params(True))

        # create the update function
        self._train_actor = symjax.function(
            states,
            gradients,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/actor_optimizer"),
        )

        # initialize both networks as the same
        self.update_target(1)
Example #14
# (assumed setup, not shown in this excerpt: a random tensor, a placeholder
# that replaces it via clone, and a variable starting at 1.0 so that the
# printed values 0., 3., 10. follow)
randn = T.random.randn(())
rand = T.Placeholder((), "float32")
value = T.Variable(1.0, dtype="float32")

out1 = randn * value
out2 = out1.clone({randn: rand})

f = symjax.function(rand, outputs=out2, updates={value: 2 + value})

for i in range(3):
    print(f(i))
# 0.
# 3.
# 10.

# we create a simple computational graph
var = T.Variable(T.random.randn((16, 8), seed=10))
loss = ((var - T.ones_like(var))**2).sum()
g = symjax.gradients(loss, [var])
opt = symjax.optimizers.SGD(loss, 0.01, params=var)

f = symjax.function(outputs=loss, updates=opt.updates)

for i in range(10):
    print(f())
# 240.96829
# 231.42595
# 222.26149
# 213.45993
# 205.00691
# 196.88864
# 189.09186
# 181.60382
# 174.41231
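These losses follow the closed-form SGD decay on a quadratic: loss = Σ(var − 1)², and each step scales the residual var − 1 by 1 − 2·lr = 0.98, so the loss decays by 0.98² ≈ 0.9604 per iteration (240.968 × 0.9604 ≈ 231.426, and so on).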
Example #15
def test_pymc():
    class RandomVariable(symjax.tensor.Variable):
        def __init__(self, name, shape, observed):
            if observed is None:
                super().__init__(np.zeros(shape), name=name)
            else:
                super().__init__(observed, name=name, trainable=False)

        def logp(self, value):
            raise NotImplementedError()

        def random(self, sample_shape):
            raise NotImplementedError()

        @property
        def logpt(self):
            return self.logp(self)

    class Normal(RandomVariable):
        def __init__(self, name, mu, sigma, shape=None, observed=None):
            self.mu = mu
            self.sigma = sigma
            super().__init__(name, shape, observed)

        def logp(self, value):
            tau = self.sigma**-2.0
            return (-tau *
                    (value - self.mu)**2 + tt.log(tau / np.pi / 2.0)) / 2.0

        def random(self, sample_shape):
            return np.random.randn(sample_shape) * self.sigma + self.mu

    x = Normal("x", 0, 10.0)
    s = Normal("s", 0.0, 5.0)
    y = Normal("y", x, tt.exp(s))

    assert symjax.current_graph().get(y) == 0.0

    #################
    model_logpt = x.logpt + s.logpt + y.logpt

    f = symjax.function(x, s, y, outputs=model_logpt)

    normal_loglike = jsp.stats.norm.logpdf

    def f_(x, s, y):
        return (normal_loglike(x, 0.0, 10.0) + normal_loglike(s, 0.0, 5.0) +
                normal_loglike(y, x, jnp.exp(s)))

    for i in range(10):
        x_val = np.random.randn() * 10.0
        s_val = np.random.randn() * 5.0
        y_val = np.random.randn() * 0.1 + x_val

        np.testing.assert_allclose(f(x_val, s_val, y_val),
                                   f_(x_val, s_val, y_val),
                                   rtol=1e-06)

    model_dlogpt = symjax.gradients(model_logpt, [x, s, y])

    f_with_grad = symjax.function(x, s, y, outputs=[model_logpt, model_dlogpt])

    f_with_grad(x_val, s_val, y_val)

    grad_fn = jax.grad(f_, argnums=[0, 1, 2])
    f_(x_val, s_val, y_val), grad_fn(x_val, s_val, y_val)
Example #16
import jax
import numpy as np
import sys
sys.path.insert(0, "../")

import symjax
import symjax.tensor as T

# map
xx = T.ones(10)
a = T.map(lambda a: a * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])

# scan
xx = T.ones(10) * 2
a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])

# scan with updates
xx = T.range(5)
uu = T.ones((10, 2))
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
a = T.scan(lambda c, x, p: (T.index_update(c, x, p[x]), 1), vv, xx, [vv])
#a = T.scan(lambda c, x: (c*x,c*x), T.ones(1), xx)
#a = T.scan(lambda c, x: (T.square(c),c[0]), uu, xx)
#g = symjax.gradients(a[1][-1],xx)
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())
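For the plain scan above, the carry accumulates products, so with xx filled with 2s the final output a[1][-1] is 2^10 = 1024 and its gradient with respect to each x_i is 1024 / 2 = 512.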
Example #17
import sys
sys.path.insert(0, "../")
import symjax
import symjax.tensor as T

# create our variable to be optimized
mu = T.Variable(T.random.normal((), seed=1))

# create our cost
cost = T.exp(-(mu-1)**2)

# get the gradient, notice that it is itself a tensor that can then
# be manipulated as well
g = symjax.gradients(cost, mu)
print(g)

# (Tensor: shape=(), dtype=float32)

# create the compiled function that will compute the cost and apply
# the update onto the variable
f = symjax.function(outputs=cost, updates={mu: mu - 0.2 * g})

for i in range(10):
    print(f())

# 0.008471076
# 0.008201109
# 0.007946267
# 0.007705368
# 0.0074773384
# 0.007261208
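For this cost the gradient has the closed form g = −2·(μ − 1)·exp(−(μ − 1)²), so the update μ ← μ − 0.2·g moves μ away from 1 and the cost decays towards 0, consistent with the printed values.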
Example #18
File: plot_sgd.py  Project: SymJAX/SymJAX
"""
Basic gradient descent (and reset)
==================================

Demonstration of how to compute a gradient and apply a basic gradient update
rule to minimize some loss function.

"""

import symjax
import symjax.tensor as T
import matplotlib.pyplot as plt

# GRADIENT DESCENT
z = T.Variable(3.0, dtype="float32")
loss = (z - 1) ** 2
g_z = symjax.gradients(loss, [z])[0]
symjax.current_graph().add_updates({z: z - 0.1 * g_z})

train = symjax.function(outputs=[loss, z], updates=symjax.get_updates())

losses = list()
values = list()
for i in range(200):
    if (i + 1) % 50 == 0:
        symjax.reset_variables("*")
    a, b = train()
    losses.append(a)
    values.append(b)

plt.figure()
Example #19
File: scan.py  Project: ml-lab/SymJAX
# (the opening of `func` is truncated in the original excerpt; a standard
# carry/step signature is assumed)
def func(carry, x):
    return carry + 1, 0


output, _ = T.scan(func, T.zeros(1), T.ones(10), length=10)

f = sj.function(outputs=output)
print(f())
# [10.]

# example of simple RNN

w = T.Placeholder((3, 10), 'float32')
h = T.random.randn((3, 3))
b = T.random.randn((3, ))
t_steps = 100
X = T.random.randn((t_steps, 10))


def rnn_cell(carry, x, w):
    output = T.sigmoid(T.matmul(w, x) + T.matmul(carry, h) + b)
    return output, output


last, hidden = T.scan(rnn_cell, T.zeros(3), X, constants=(w, ))

g = sj.gradients(hidden.sum(), w)
print(g.get({w: np.ones((3, 10))}))
f = sj.function(w, outputs=hidden)
print(f(np.ones((3, 10))))
#