def test_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32")
    out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)], non_sequences=[w, u])
    f = sj.function(u, outputs=out, updates={w: w + 1})
    assert np.array_equal(f(2), np.arange(3))
    assert np.array_equal(f(2), np.zeros(3))
    assert np.array_equal(f(0), -np.arange(3) * 3)
def test_grad_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    out = T.map(lambda a, w, u: w * a * u, (T.range(3),), non_sequences=(w, u))
    g = sj.gradients(out.sum(), w)
    f = sj.function(u, outputs=g)
    assert np.array_equal(f(0), 0)
    assert np.array_equal(f(1), 3)
def build_net(self, Q):
    # ------------------ all inputs ------------------------
    state = T.Placeholder([self.batch_size, self.n_states], "float32", name="s")
    next_state = T.Placeholder([self.batch_size, self.n_states], "float32", name="s_")
    reward = T.Placeholder([self.batch_size], "float32", name="r")  # input reward
    action = T.Placeholder([self.batch_size], "int32", name="a")  # input action

    with symjax.Scope("eval_net"):
        q_eval = Q(state, self.n_actions)
    with symjax.Scope("test_set"):
        q_next = Q(next_state, self.n_actions)  # target network

    # Bellman target, detached so gradients only flow through q_eval
    q_target = reward + self.reward_decay * q_next.max(1)
    q_target = T.stop_gradient(q_target)

    # select the Q-value of the action actually taken in each batch row
    q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)), 1).squeeze(1)
    loss = T.mean((q_target - q_eval_wrt_a) ** 2)
    nn.optimizers.Adam(loss, self.lr)

    self.train = symjax.function(
        state, action, reward, next_state, updates=symjax.get_updates()
    )
    self.q_eval = symjax.function(state, outputs=q_eval)
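# A minimal sketch (not from the source): a linear Q-function builder with the
# (state, n_actions) signature that build_net expects. Only core tensor ops
# are assumed here; a real agent would pass a deeper network instead.
# Usage (hypothetical): agent.build_net(linear_q)
def linear_q(state, n_actions):
    n_states = state.shape[1]
    W = T.Variable(
        0.01 * np.random.randn(n_states, n_actions).astype("float32"), name="W"
    )
    b = T.Variable(np.zeros(n_actions, dtype="float32"), name="b")
    return T.dot(state, W) + b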
import os

import numpy as np

import symjax
import symjax.tensor as T

os.environ["DATASET_PATH"] = "/home/vrael/DATASETS/"
symjax.current_graph().reset()

# two 2d images of the digit "2" from MNIST
mnist = symjax.data.mnist()
images = mnist["train_set/images"][mnist["train_set/labels"] == 2][:2, 0]
images /= images.max()

np.random.seed(0)

# learnable sampling coordinates, initialized to the regular 28x28 grid
coordinates = T.meshgrid(T.range(28), T.range(28))
coordinates = T.Variable(
    T.stack([coordinates[1].flatten(), coordinates[0].flatten()]).astype("float32")
)
interp = T.interpolation.map_coordinates(images[0], coordinates, order=1).reshape(
    (28, 28)
)

# fit the coordinates so that resampling the first image matches the second
loss = ((interp - images[1]) ** 2).mean()
lr = symjax.nn.schedules.PiecewiseConstant(0.05, {5000: 0.01, 8000: 0.005})
symjax.nn.optimizers.Adam(loss, lr)

train = symjax.function(outputs=loss, updates=symjax.get_updates())
rec = symjax.function(outputs=interp)
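# A minimal sketch (not in the original): run the fitting loop defined above
# and periodically report the loss; `rec` then returns the final (28, 28)
# reconstruction sampled at the learned coordinates. The step count is an
# arbitrary choice matching the learning-rate schedule breakpoints.
for step in range(10000):
    l = train()
    if step % 1000 == 0:
        print("step", step, "loss", l)
reconstruction = rec()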
import symjax
import symjax.tensor as T

# map: double every element and differentiate through the mapped function
xx = T.ones(10)
a = T.map(lambda a: a * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])

# scan: running product, returning both the carry and the per-step output
xx = T.ones(10) * 2
a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])

# scan with updates: the carry is written into with a non-sequence input
xx = T.range(5)
uu = T.ones((10, 2))
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
a = T.scan(lambda c, x, p: (T.index_update(c, x, p[x]), 1), vv, xx, [vv])
# a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
# a = T.scan(lambda c, x: (T.square(c), c[0]), uu, xx)
# g = symjax.gradients(a[1][-1], xx)
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())

# fori loop: multiply the carry by the loop index, bounded by a placeholder
b = T.Placeholder((), "int32")
xx = T.ones(1)
a = T.fori_loop(0, b, lambda i, x: i * x, xx)
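# A minimal sketch (not in the original): compile and evaluate the fori_loop
# above. Since the body returns i * x, the first iteration (i = 0) zeroes the
# carry, so the result is [0.] for any bound b >= 1.
f_loop = symjax.function(b, outputs=a)
print(f_loop(4))  # [0.]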
def test_grad_map_v2():
    sj.current_graph().reset()
    out = T.map(lambda a, b: a * b, (T.range(3), T.range(3)))
    f = sj.function(outputs=out)
    assert np.array_equal(f(), np.arange(3) * np.arange(3))
import numpy as np

import symjax
import symjax.tensor as T

SHAPE = (2, 2)

# test clone: substitute nodes in an already-built graph
randn = T.random.randn(SHAPE)
rand = T.random.rand(SHAPE)
out = randn
out2 = out.clone({randn: rand})
out3 = out2.clone({rand: 3})
get_vars = symjax.function(outputs=[out, out2])
for i in range(10):
    print(get_vars())

# test shuffle: permute the rows of a matrix with a random permutation
matrix = T.linspace(0, 1, 16).reshape((4, 4))
smatrix = matrix[T.random.permutation(T.range(4))]
get_shuffle = symjax.function(outputs=smatrix)
for i in range(10):
    print(get_shuffle())

# test a randomly initialized variable: its value is fixed across calls
z = T.Variable(np.random.randn(*SHAPE).astype("float32"), name="z")
get_z = symjax.function(outputs=z)
for i in range(10):
    print(get_z())
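# A minimal sketch (not in the original): out3 above replaced the uniform
# draw with the constant 3, so evaluating it should be deterministic.
get_const = symjax.function(outputs=out3)
print(get_const())  # always 3, no randomness left in the cloned graph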
import symjax as sj
import symjax.tensor as T

w = T.Variable(1.0, dtype="float32")
u = T.Placeholder((), "float32")
out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)], non_sequences=[w, u])
f = sj.function(u, outputs=out, updates={w: w + 1})
print(f(2))  # [0, 1, 2]
print(f(2))  # [0, 0, 0]
print(f(0))  # [0, -3, -6]

w.reset()
out = T.map(lambda a, w, u: w * a * u, [T.range(3)], non_sequences=[w, u])
g = sj.gradients(out.sum(), [w])[0]
f = sj.function(u, outputs=g)
print(f(0))  # 0
print(f(1))  # 3

out = T.map(lambda a, b: a * b, [T.range(3), T.range(3)])
f = sj.function(outputs=out)
print(f())  # [0, 1, 4]