def update(self, update_value, fast=True):
    """Assign a new value to the variable.

    With ``fast=True`` the value is stored as-is and the shape/dtype
    checks below are skipped.
    """
    if fast:
        symjax.current_graph().nodes[self]["value"] = update_value
        return

    new_value = symjax.current_graph().get(update_value)

    if self.shape != jax.numpy.shape(new_value):
        warnings.warn(
            "Variable and update of {} ".format(self)
            + "are not the same shape (expected {}, got {})".format(
                self.shape, jax.numpy.shape(new_value)
            )
            + "... attempting to cast"
        )
        new_value = jax.numpy.reshape(new_value, self.shape)

    if hasattr(new_value, "dtype"):
        ntype = new_value.dtype
    else:
        ntype = type(new_value)

    if self.dtype != ntype:
        warnings.warn(
            "Variable and update of {} ".format(self)
            + "are not the same dtype (expected {}, got {})".format(
                self.dtype, ntype
            )
            + "... attempting to cast"
        )
        new_value = jax.numpy.asarray(new_value).astype(self.dtype)

    symjax.current_graph().nodes[self]["value"] = new_value
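# Hedged usage sketch (not part of the library source): how Variable.update
# behaves on the fast and checked paths. Assumes symjax is importable as
# below; the variable name `v` is illustrative.
import numpy as np
import symjax
import symjax.tensor as T

v = T.Variable(np.zeros((2, 2), dtype="float32"), name="v")
v.update(np.ones((2, 2), dtype="float32"))           # fast path: stored as-is
v.update(np.ones(4, dtype="float32"), fast=False)    # checked path: warns on the
                                                     # shape mismatch, reshapes to (2, 2)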
def test_bn():
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((1,), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(input, [1], deterministic=deterministic)
    update = sj.function(
        input, deterministic, outputs=bn, updates=sj.get_updates()
    )
    get_stats = sj.function(input, outputs=bn.avg_mean)

    data = np.random.randn(50, DIM) * 4 + 2

    true_means = []
    actual_means = []
    for i in range(10):
        batch = data[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
        output = update(batch, 0)
        assert np.allclose(
            output, (batch - batch.mean(0)) / (1e-4 + batch.std(0)), 1e-4
        )
        actual_means.append(get_stats(batch))
        if i == 0:
            true_means.append(batch.mean(0))
        else:
            true_means.append(0.9 * true_means[-1] + 0.1 * batch.mean(0))

    true_means = np.array(true_means)
    actual_means = np.array(actual_means).squeeze()
    assert np.allclose(true_means, actual_means, 1e-4)
def SJ(x, y, N, lr, model, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)
    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1).astype("float32"))
    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    if model == "SGD":
        optimizers.SGD(sj_loss, lr)
    elif model == "Adam":
        optimizers.Adam(sj_loss, lr)
    train = symjax.function(
        sj_input, sj_output, outputs=sj_loss, updates=symjax.get_updates()
    )

    losses = []
    for i in tqdm(range(N)):
        losses.append(train(x, y))
    return losses
def test_update():
    sj.current_graph().reset()

    # integer index
    w = symjax.tensor.zeros(10)
    for i in range(10):
        w = symjax.tensor.index_update(w, i, i)
    f = symjax.function(outputs=w)
    assert np.array_equal(f(), np.arange(10))

    # tuple index
    w2 = symjax.tensor.zeros(10)
    for i in range(10):
        w2 = symjax.tensor.index_update(w2, (i,), i)
    f = symjax.function(outputs=w2)
    assert np.array_equal(f(), np.arange(10))

    # symjax.tensor.index helper
    w3 = symjax.tensor.zeros(10)
    for i in range(10):
        w3 = symjax.tensor.index_update(w3, symjax.tensor.index[i], i)
    f = symjax.function(outputs=w3)
    assert np.array_equal(f(), np.arange(10))

    # updating a Variable through the updates mapping
    w4 = symjax.tensor.Variable(symjax.tensor.zeros(10))
    i = symjax.tensor.Variable(0, dtype="int32")
    update = symjax.tensor.index_update(w4, i, i)
    f = symjax.function(updates={w4: update, i: i + 1})
    for i in range(10):
        f()
    assert np.array_equal(w4.value, np.arange(10))
def __init__(self, *args, _jax_function, _shapes, _dtypes, name=None, **kwargs):
    if name is None:
        name = _jax_function.__name__
    name, scope = symjax.current_graph()._get_name_scope(name, self)

    Tensor.__init__(
        self,
        *args,
        _attrs={
            "name": name,
            "scope": scope,
            "_dtype": MultiOutputOp,
            "jax_function": _jax_function,
            "root": False,
        },
        **kwargs,
    )

    for i, child in enumerate(self):
        symjax.current_graph().add_edge(
            self,
            child,
            name="parent_index" + str(i),
        )
        symjax.current_graph().nodes[child]["parent"] = self
def __init__(
    self,
    *args,
    _jax_function=None,
    _shape=None,
    _dtype=None,
    name=None,
    **kwargs,
):
    if self in symjax.current_graph().nodes:
        return
    if name is None:
        name = _jax_function.__name__
    name, scope = symjax.current_graph()._get_name_scope(name, self)

    super().__init__(
        *args,
        _attrs={
            "name": name,
            "scope": scope,
            "_shape": _shape,
            "_dtype": _dtype,
            "jax_function": _jax_function,
            "root": False,
        },
        **kwargs,
    )
def __init__(self, *args, _jax_function, _shapes, _dtypes, name=None, **kwargs):
    if name is None:
        name = _jax_function.__name__
    self._name = name

    symjax.current_graph().add(
        self,
        jax_function=_jax_function,
        args=args,
        kwargs=kwargs,
        **kwargs,
    )

    for i, (shape, dtype) in enumerate(zip(_shapes, _dtypes)):
        OpTupleItem(
            shape,
            dtype,
            index=i,
            parent=self,
            name=name + "[{}]".format(i),
        )
def SJ(x, y, N, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)
    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1).astype("float32"))
    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    optimizers.Adam(sj_loss, lr)
    train = symjax.function(sj_input, sj_output, updates=symjax.get_updates())

    if preallocate:
        import jax

        x = jax.device_put(x)
        y = jax.device_put(y)

    t = time.time()
    for i in range(N):
        train(x, y)
    return time.time() - t
def SJ_EMA(X, debias=True):
    symjax.current_graph().reset()
    x = T.Placeholder((), "float32", name="x")
    value = symjax.nn.schedules.ExponentialMovingAverage(x, 0.9, debias=debias)[0]
    train = symjax.function(x, outputs=value, updates=symjax.get_updates())
    outputs = []
    for i in range(len(X)):
        outputs.append(train(X[i]))
    return outputs
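# Hedged reference sketch (not from the library): a plain-NumPy EMA to compare
# SJ_EMA against. The debiasing term 1 - decay**t is an assumption, following
# the standard Adam-style bias correction; verify against the schedules source.
import numpy as np

def numpy_ema(X, decay=0.9, debias=True):
    ema, outputs = 0.0, []
    for t, x in enumerate(X, start=1):
        ema = decay * ema + (1 - decay) * x
        outputs.append(ema / (1 - decay ** t) if debias else ema)
    return outputs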
def test_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32")
    out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)], non_sequences=[w, u])
    f = sj.function(u, outputs=out, updates={w: w + 1})

    assert np.array_equal(f(2), np.arange(3))       # w == 1: (2 - 1) * a
    assert np.array_equal(f(2), np.zeros(3))        # w == 2: (2 - 2) * a
    assert np.array_equal(f(0), -np.arange(3) * 3)  # w == 3: (0 - 3) * a
def __init__(self, *args, inplace_copy=None, **kwargs):
    if inplace_copy is not None:
        return
    if "_attrs" in kwargs:
        if "_shape" in kwargs["_attrs"]:
            self._shape = kwargs["_attrs"]["_shape"]
        if "_dtype" in kwargs["_attrs"]:
            self._dtype = kwargs["_attrs"]["_dtype"]
    symjax.current_graph()._add(self, *args, **kwargs)
def test_updating_variables():
    sj.current_graph().reset()
    w1 = symjax.tensor.Variable(1.0, dtype="float32")
    input = symjax.tensor.Placeholder((), "float32")
    update = w1 + input + 1
    f = symjax.function(input, updates={w1: update})

    assert w1.value == 1.0
    f(10)
    assert w1.value == 12.0
def test_grad_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    out = T.map(lambda a, w, u: w * a * u, (T.range(3),), non_sequences=(w, u))
    g = sj.gradients(out.sum(), w)
    f = sj.function(u, outputs=g)

    assert np.array_equal(f(0), 0)
    assert np.array_equal(f(1), 3)
def test_seed():
    # same graph: a and c share a fixed seed, b is unseeded
    a = T.random.randn((), seed=10)
    b = T.random.randn(())
    c = T.random.randn((), seed=10)
    f = symjax.function(outputs=[a, b, c])
    result1 = f()
    result2 = f()
    assert result1[0] == result1[2]
    assert result1[0] != result1[1]
    assert result2[0] == result2[2]
    assert result2[0] != result1[0]

    # rebuilding the same ops (without a reset) reproduces the seeded draws
    a = T.random.randn((), seed=10)
    b = T.random.randn(())
    c = T.random.randn((), seed=10)
    f = symjax.function(outputs=[a, b, c])
    result12 = f()
    result22 = f()
    assert result12[0] == result12[2]
    assert result12[0] != result12[1]
    assert result22[0] == result22[2]
    assert result22[0] != result12[0]
    assert np.isclose(result1[0], result12[0])
    assert np.isclose(result1[2], result12[2])
    assert not np.isclose(result1[1], result12[1])
    assert np.isclose(result2[0], result22[0])
    assert np.isclose(result2[2], result22[2])
    assert not np.isclose(result2[1], result22[1])

    # after a graph reset, the seeded draws are still reproducible
    symjax.current_graph().reset()
    a = T.random.randn((), seed=10)
    b = T.random.randn(())
    c = T.random.randn((), seed=10)
    f = symjax.function(outputs=[a, b, c])
    result12 = f()
    result22 = f()
    assert result12[0] == result12[2]
    assert result12[0] != result12[1]
    assert result22[0] == result22[2]
    assert result22[0] != result12[0]
    assert np.isclose(result1[0], result12[0])
    assert np.isclose(result1[2], result12[2])
    assert not np.isclose(result1[1], result12[1])
    assert np.isclose(result2[0], result22[0])
    assert np.isclose(result2[2], result22[2])
    assert not np.isclose(result2[1], result22[1])
def test_clone_0():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    with sj.Scope("placing"):
        u = T.Placeholder((), "float32", name="u")
    value = 2 * w * u
    c = value.clone({w: u})
    f = sj.function(u, outputs=value)
    g = sj.function(u, outputs=c)

    assert np.array_equal([f(1), g(1), f(2), g(2)], [2, 2, 4, 8])
def test_sma():
    symjax.current_graph().reset()
    a = symjax.tensor.Placeholder((4,), "float32")
    sma, var = symjax.nn.schedules.SimpleMovingAverage(a, 3)
    f = symjax.function(a, outputs=[sma, var], updates=symjax.get_updates())

    data = np.random.randn(4, 4)
    # window of 3: partial means until the window fills, then it slides
    current = [data[0], data[:2].mean(0), data[:3].mean(0), data[1:4].mean(0)]

    for i in range(data.shape[0]):
        out = f(data[i])
        assert np.allclose(out[0], current[i])
def test_vectorize():
    sj.current_graph().reset()
    x = symjax.tensor.Placeholder((0, 2), "float32")
    w = symjax.tensor.Variable(1.0, dtype="float32")
    p = x.sum(1)
    f = symjax.function(x, outputs=p, updates={w: x.sum()})

    assert np.array_equal(f(np.ones((1, 2))), [2.0])
    assert w.value == 2.0
    assert np.array_equal(f(np.ones((2, 2))), [2.0, 2.0])
    assert w.value == 4.0
def __init__(self, _shape, _dtype, name=None, **kwargs):
    self._shape = tuple(_shape)
    self._dtype = jax.dtypes.dtype(_dtype)
    if name is not None:
        assert "/" not in name
        self._name = name
    else:
        self._name = "unnamed"
    symjax.current_graph().add(self, **kwargs)
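# Hedged usage sketch (not part of the source): placeholders are fed at call
# time through symjax.function, as the tests in this file do. Assumes symjax
# is importable as below.
import numpy as np
import symjax
import symjax.tensor as T

x = T.Placeholder((2, 3), "float32", name="x")
y = (x ** 2).sum()
f = symjax.function(x, outputs=y)
print(f(np.ones((2, 3))))  # -> 6.0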
def test_global_pool():
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 4096
    DIM = 8
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")

    output = nn.layers.Dense(input, 64)
    output = nn.layers.Dense(output, output.shape[-1] * 2)
    output = nn.layers.Dense(output, output.shape[-1] * 2)

    get = sj.function(input, outputs=output)
    assert get(np.ones((BATCH_SIZE, DIM))).shape == (BATCH_SIZE, 64 * 4)
def test_while():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    v = T.Placeholder((), "float32")
    out = T.while_loop(
        lambda i, u: i[0] + u < 5,
        lambda i: (i[0] + 1.0, i[0] ** 2),
        (w, 1.0),
        non_sequences_cond=(v,),
    )
    f = sj.function(v, outputs=out)

    assert np.array_equal(np.array(f(0)), [5, 16])
    assert np.array_equal(f(2), [3, 4])
def test_cond2():
    sj.current_graph().reset()
    v = T.ones((10, 10))
    u = T.Placeholder((), "int32")
    out = T.cond(
        u > 0,
        lambda u: 4 * u,
        lambda u: u,
        true_inputs=(v,),
        false_inputs=(2 * v,),
    )
    f = sj.function(u, outputs=out)

    assert np.array_equal(f(1), 4 * np.ones((10, 10)))
    assert np.array_equal(f(0), 2 * np.ones((10, 10)))
def test_accessing_variables():
    sj.current_graph().reset()
    w1 = symjax.tensor.Variable(1.0, trainable=True)
    w2 = symjax.tensor.Variable(1.0, trainable=True)
    w3 = symjax.tensor.Variable(1.0, trainable=False)

    v = symjax.get_variables("*", trainable=True)
    assert w1 in v and w2 in v and w3 not in v

    v = symjax.get_variables("*", trainable=False)
    assert w1 not in v and w2 not in v and w3 in v

    v = symjax.get_variables("*test")
    assert len(v) == 0
def test_ema():
    symjax.current_graph().reset()
    a = symjax.tensor.Placeholder((), "float32")
    ema, var = symjax.nn.schedules.ExponentialMovingAverage(a, 0.9, debias=False)
    # t = symjax.get_variables("*num_steps*", trainable=False)
    f = symjax.function(a, outputs=[ema, var], updates=symjax.get_updates())

    current = 0
    for i in range(10):
        out = f(1)
        # var holds the average before the update, ema the value after it
        assert np.allclose(out[1], current)
        current = 0.9 * current + 0.1 * 1
        assert np.allclose(out[0], current)
def test_reset():
    sj.current_graph().reset()
    w = symjax.tensor.Variable(1.0, name="w", dtype="float32")
    x = symjax.tensor.Variable(2.0, name="x", dtype="float32")
    f = symjax.function(outputs=[w, x], updates={w: w + 1, x: x + 1})

    for i in range(10):
        assert np.array_equal(np.array(f()), np.array([1, 2]) + i)

    # reset only the w variable
    symjax.reset_variables("*w")
    assert np.array_equal(np.array(f()), np.array([1, 2 + i + 1]))

    # reset all variables
    symjax.reset_variables("*")
    assert np.array_equal(np.array(f()), np.array([1, 2]))
def __new__(
    cls,
    *args,
    _jax_function,
    _shapes,
    _dtypes,
    name=None,
    **kwargs,
):
    scope = symjax.current_graph().scope.absolute_name
    items = []
    for i, (shape, dtype) in enumerate(zip(_shapes, _dtypes)):
        items.append(
            OpItem(
                _attrs={
                    "name": _jax_function.__name__ + "[{}]".format(i),
                    "scope": scope,
                    "jax_function": _jax_function,
                    "root": False,
                    "parent_index": i,
                    "_shape": shape,
                    "_dtype": dtype,
                }
            )
        )
    return super(MultiOutputOp, cls).__new__(cls, tuple(items))
def __init__(
    self, *args, _jax_function, _shape, _dtype, _seed, name=None, **kwargs
):
    if name is None:
        name = _jax_function.__name__
    name, scope = symjax.current_graph()._get_name_scope(name, self)

    # draw a random seed if none was provided
    _seed = ord(os.urandom(1)) if _seed is None else _seed
    seed_op = Seed(_seed)

    Tensor.__init__(
        self,
        seed_op,
        *args,
        _attrs={
            "name": name,
            "scope": scope,
            "_shape": _shape,
            "_dtype": _dtype,
            "jax_function": _jax_function,
            "root": False,
        },
        **kwargs,
    )
def __init__(self, *args, name=None, **kwargs):
    if name is None:
        name = self.__NAME__
    with symjax.Scope(name):
        self.create_updates(*args, **kwargs)
        self._scope_name = symjax.current_graph().scope_name
def __init__(
    self,
    initializer,
    name="unnamed",
    trainable=True,
    shape=None,
    dtype=None,
):
    if trainable and dtype == "bool":
        raise RuntimeError("variables with dtype bool cannot be trainable")
    assert not isvar(shape)

    name, scope = symjax.current_graph()._get_name_scope(name, self)
    value = self._reset(initializer, shape, dtype)
    shape = tuple(shape or value.shape)
    dtype = jax.numpy.dtype(dtype) if dtype is not None else value.dtype

    super().__init__(
        _attrs={
            "name": name,
            "scope": scope,
            "_shape": shape,
            "_dtype": dtype,
            "trainable": trainable,
            "initializer": initializer,
            "root": True,
            "value": value,
        }
    )
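# Hedged usage sketch (not from the source): constructing variables the way
# the tests in this file do. Whether `initializer` may also be a callable is
# an assumption based on the `_reset(initializer, shape, dtype)` call above.
import numpy as np
import symjax.tensor as T

w = T.Variable(np.zeros((3, 3), dtype="float32"), name="w", trainable=True)
b = T.Variable(0.0, name="b", dtype="float32", trainable=False)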
def create_variable(
    name,
    tensor_or_func,
    shape,
    trainable,
    inplace=False,
    dtype="float32",
    preprocessor=None,
):
    if tensor_or_func is None:
        return None
    if inplace:
        assert not callable(tensor_or_func)
        return tensor_or_func

    variable = T.Variable(
        tensor_or_func,
        name=name,
        shape=symjax.current_graph().get(shape),
        dtype=dtype,
        trainable=trainable,
    )
    if preprocessor is not None:
        return preprocessor(variable)
    else:
        return variable
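# Hedged usage sketch (hypothetical call, not from the source): a layer could
# pass a preprocessor to transform the parameter before use; the squaring
# transform here is purely illustrative.
import numpy as np

scale = create_variable(
    "scale",
    np.ones(4, dtype="float32"),
    shape=(4,),
    trainable=True,
    preprocessor=lambda v: v ** 2,  # any elementwise transform works
)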
def update(self, other_seed=None):
    """Update the seed, either by splitting the current one (effectively
    generating a new random seed) or by using a given one."""
    if other_seed is not None:
        if len(other_seed) != 2:
            raise RuntimeError(
                "given updated seed {} is not valid".format(other_seed)
            )
        symjax.current_graph().nodes[self]["value"] = other_seed
    else:
        new_key = jax.random.split(self.value, 1)[0]
        symjax.current_graph().nodes[self]["value"] = new_key
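# Hedged reference sketch (plain JAX, not symjax): what the else-branch above
# does. jax.random.split deterministically derives fresh keys from a key, and
# a PRNG key is a length-2 array, which is why the check above requires len 2.
import jax

key = jax.random.PRNGKey(10)
new_key = jax.random.split(key, 1)[0]  # same operation as Seed.update
print(key, new_key)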