Ejemplo n.º 1
0
def test_map():
    """Check ``T.map`` with non-sequence inputs while a variable is updated.

    The mapped function computes ``(u - w) * a`` for ``a`` in ``range(3)``.
    The compiled function also applies the update ``w <- w + 1`` on every
    call, so successive calls see ``w = 1, 2, 3``.
    """
    sj.current_graph().reset()
    weight = T.Variable(1.0, dtype="float32")
    offset = T.Placeholder((), "float32")

    def scaled_difference(a, w, u):
        return (u - w) * a

    mapped = T.map(scaled_difference, [T.range(3)],
                   non_sequences=[weight, offset])
    f = sj.function(offset, outputs=mapped, updates={weight: weight + 1})

    # (value fed to the placeholder, expected output), in call order; the
    # order matters because each call changes ``weight``.
    calls = (
        (2, np.arange(3)),
        (2, np.zeros(3)),
        (0, -np.arange(3) * 3),
    )
    for fed_value, expected in calls:
        assert np.array_equal(f(fed_value), expected)
Ejemplo n.º 2
0
def test_grad_map():
    """Check the gradient of a summed ``T.map`` output w.r.t. a variable.

    The mapped function is ``w * a * u`` for ``a`` in ``range(3)``, so the
    derivative of the sum with respect to ``w`` is ``u * (0 + 1 + 2)``.
    """
    sj.current_graph().reset()
    weight = T.Variable(1.0, dtype="float32")
    scale = T.Placeholder((), "float32", name="u")

    def product(a, w, u):
        return w * a * u

    mapped = T.map(product, (T.range(3), ), non_sequences=(weight, scale))
    grad_wrt_weight = sj.gradients(mapped.sum(), weight)
    f = sj.function(scale, outputs=grad_wrt_weight)

    # (value fed to the placeholder, expected gradient)
    for fed_value, expected in ((0, 0), (1, 3)):
        assert np.array_equal(f(fed_value), expected)
Ejemplo n.º 3
0
    def build_net(self, Q):
        """Build the Q-learning graph and compile the train/evaluation functions.

        Reads ``self.batch_size``, ``self.n_states``, ``self.n_actions``,
        ``self.reward_decay`` and ``self.lr``. Sets ``self.train`` and
        ``self.q_eval``.

        Parameters
        ----------
        Q : callable
            Network builder, invoked as ``Q(states, self.n_actions)``. Its
            result is reduced with ``.max(1)`` and indexed along axis 1, so it
            is presumably of shape ``(batch_size, n_actions)`` -- TODO confirm.
        """
        # ------------------ all inputs ------------------------
        state = T.Placeholder([self.batch_size, self.n_states],
                              "float32",
                              name="s")
        next_state = T.Placeholder([self.batch_size, self.n_states],
                                   "float32",
                                   name="s_")
        reward = T.Placeholder([self.batch_size], "float32",
                               name="r")  # input reward
        action = T.Placeholder([self.batch_size], "int32",
                               name="a")  # input action

        # The two networks are built under separate scopes.
        # NOTE(review): the next-state network lives in a scope named
        # "test_set"; confirm that this is the name other code expects.
        with symjax.Scope("eval_net"):
            q_eval = Q(state, self.n_actions)
        with symjax.Scope("test_set"):
            q_next = Q(next_state, self.n_actions)

        # Bootstrapped target; no gradient flows through it.
        q_target = reward + self.reward_decay * q_next.max(1)
        q_target = T.stop_gradient(q_target)

        # Pick, for every row, the Q-value of the action that was taken.
        q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)),
                                         1).squeeze(1)
        loss = T.mean((q_target - q_eval_wrt_a)**2)
        nn.optimizers.Adam(loss, self.lr)

        # ``train`` has no outputs: calling it only applies the updates
        # returned by ``symjax.get_updates()``.
        self.train = symjax.function(state,
                                     action,
                                     reward,
                                     next_state,
                                     updates=symjax.get_updates())
        self.q_eval = symjax.function(state, outputs=q_eval)
Ejemplo n.º 4
0
import os

# Set DATASET_PATH before loading any data (machine-specific path);
# presumably read by symjax.data -- TODO confirm.
os.environ["DATASET_PATH"] = "/home/vrael/DATASETS/"

# Start from an empty graph.
symjax.current_graph().reset()


mnist = symjax.data.mnist()
# Keep the first two training samples whose label is 2 and take index 0 of
# their second axis (presumably the channel axis), leaving 2-D images.
images = mnist["train_set/images"][mnist["train_set/labels"] == 2][:2, 0]
# Rescale in place by the largest pixel value.
images /= images.max()

np.random.seed(0)

# Sampling coordinates: the 28x28 grid, flattened and stacked into two rows,
# stored in a Variable (presumably so the optimizer below can adjust them).
coordinates = T.meshgrid(T.range(28), T.range(28))
coordinates = T.Variable(
    T.stack([coordinates[1].flatten(), coordinates[0].flatten()]).astype("float32")
)
# Sample the first image at those coordinates (order=1) and reshape the
# result back to 28x28.
interp = T.interpolation.map_coordinates(images[0], coordinates, order=1).reshape(
    (28, 28)
)

# Mean squared difference between the resampled first image and the second.
loss = ((interp - images[1]) ** 2).mean()

# Piecewise-constant learning rate starting at 0.05; the dict presumably maps
# a step count to the value used from that step on -- TODO confirm.
lr = symjax.nn.schedules.PiecewiseConstant(0.05, {5000: 0.01, 8000: 0.005})
symjax.nn.optimizers.Adam(loss, lr)

# Returns the loss and applies the updates from symjax.get_updates().
train = symjax.function(outputs=loss, updates=symjax.get_updates())

# Returns the current resampled image.
rec = symjax.function(outputs=interp)
Ejemplo n.º 5
0
import symjax.tensor as T

# map: double every entry of a vector, then take the gradient of the summed
# result with respect to the input vector.
xx = T.ones(10)
a = T.map(lambda a: a * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])

# scan: the step function multiplies the carry by the current element and
# emits that product as both new carry and per-step output. The gradient is
# taken on the last entry of ``a[1]`` (presumably the stacked per-step
# outputs, as in jax.lax.scan -- TODO confirm).
xx = T.ones(10) * 2
a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])

# scan with updates: the carry starts from ``vv`` (derived from the variable
# ``vvar``), and ``vv`` is also passed as an extra argument ``p`` to the step
# function. Every call of ``f`` additionally applies ``vvar <- vvar + 1``.
xx = T.range(5)
uu = T.ones((10, 2))
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
a = T.scan(lambda c, x, p: (T.index_update(c, x, p[x]), 1), vv, xx, [vv])
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())

# fori loop: the upper bound is a scalar int32 placeholder.
# (A stray undefined name used to sit above this section; it raised NameError
# and kept the lines below from ever running.)
b = T.Placeholder((), 'int32')
xx = T.ones(1)
a = T.fori_loop(0, b, lambda i, x: i * x, xx)
Ejemplo n.º 6
0
def test_grad_map_v2():
    """Check ``T.map`` over two sequences with an element-wise product.

    Despite the test's name, no gradient is computed here; only the mapped
    output is compared against the NumPy equivalent.
    """
    sj.current_graph().reset()

    def multiply(a, b):
        return a * b

    sequences = (T.range(3), T.range(3))
    f = sj.function(outputs=T.map(multiply, sequences))

    expected = np.arange(3) * np.arange(3)
    assert np.array_equal(f(), expected)
Ejemplo n.º 7
0
# Random tensors and node substitution with ``clone``.
# NOTE(review): SHAPE is only assigned further down in this script (in the
# "random uniform" section); confirm it is defined before this point.
randn = T.random.randn(SHAPE)
rand = T.random.rand(SHAPE)

out = randn
# ``out2`` is ``out`` with ``randn`` swapped for ``rand``; ``out3`` further
# swaps ``rand`` for the constant 3 (built but not evaluated below).
out2 = out.clone({randn: rand})
out3 = out2.clone({rand: 3})
get_vars = symjax.function(outputs=[out, out2])

for i in range(10):
    print(get_vars())

# test shuffle: index the rows of a 4x4 matrix with a random permutation.
# (Stray undefined names used to follow each section; they raised NameError
# and stopped the script before the later sections could run.)
matrix = T.linspace(0, 1, 16).reshape((4, 4))
smatrix = matrix[T.random.permutation(T.range(4))]

get_shuffle = symjax.function(outputs=smatrix)

for i in range(10):
    print(get_shuffle())

# test the random uniform
# The variable is initialised once from NumPy; the compiled function returns
# it without applying any update.
SHAPE = (2, 2)
z = T.Variable(np.random.randn(*SHAPE).astype("float32"), name="z")
get_z = symjax.function(outputs=z)

for i in range(10):
    print(get_z())
Ejemplo n.º 8
0
import symjax as sj
import symjax.tensor as T

# Scalar variable (initial value 1.0) and scalar placeholder shared by the
# examples below.
w = T.Variable(1.0, dtype="float32")
u = T.Placeholder((), "float32")

# Example 1: map (u - w) * a over a in range(3). Every call of ``f`` also
# applies the update w <- w + 1.
out = T.map(
    lambda a, w, u: (u - w) * a,
    [T.range(3)],
    non_sequences=[w, u],
)
f = sj.function(u, outputs=out, updates={w: w + 1})
# Expected output, one line per call:
# [0, 1, 2]
# [0, 0, 0]
# [0, -3, -6]
for u_value in (2, 2, 0):
    print(f(u_value))


# Example 2: gradient of the summed map output with respect to ``w``.
# ``w.reset()`` presumably restores the variable's initial value -- TODO confirm.
w.reset()
out = T.map(
    lambda a, w, u: w * a * u,
    [T.range(3)],
    non_sequences=[w, u],
)
g = sj.gradients(out.sum(), [w])[0]
f = sj.function(u, outputs=g)
# Expected output, one line per call:
# 0
# 3
for u_value in (0, 1):
    print(f(u_value))


# Example 3: map over two sequences at once (element-wise product).
out = T.map(lambda a, b: a * b, [T.range(3), T.range(3)])
f = sj.function(outputs=out)
# Expected output:
# [0, 1, 4]
print(f())