Example 1
    def update(self, update_value, fast=True):
        """assign a new value for the variable"""
        if fast:
            symjax.current_graph().nodes[self]["value"] = update_value

        new_value = symjax.current_graph().get(update_value)

        if self.shape != jax.numpy.shape(new_value):
            warnings.warn(
                "Variable and update of {} are not the same shape "
                "(expected {}, got {})... attempting to cast".format(
                    self, self.shape, jax.numpy.shape(new_value)))
            new_value = jax.numpy.reshape(new_value, self.shape)

        if hasattr(new_value, "dtype"):
            ntype = new_value.dtype
        else:
            ntype = type(new_value)
        if self.dtype != ntype:
            warnings.warn(
                "Variable and update of {} are not the same dtype "
                "(expected {}, got {})... attempting to cast".format(
                    self, self.dtype, ntype))
            new_value = jax.numpy.asarray(new_value).astype(self.dtype)

        symjax.current_graph().nodes[self]["value"] = new_value
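A minimal usage sketch for the method above, assuming it is bound to symjax.tensor.Variable as in the other snippets on this page; the variable name and values are purely illustrative.

import numpy as np
import symjax
import symjax.tensor as T

symjax.current_graph().reset()
w = T.Variable(np.zeros((2, 3), dtype="float32"), name="w")

# slow path: shape and dtype are validated and cast if necessary
w.update(np.ones((2, 3)), fast=False)

# fast path: the raw value is stored without any checks
w.update(np.full((2, 3), 2.0, dtype="float32"), fast=True)
print(w.value)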
Example 2
def test_bn():
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((1,), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(input, [1], deterministic=deterministic)

    update = sj.function(input, deterministic, outputs=bn, updates=sj.get_updates())
    get_stats = sj.function(input, outputs=bn.avg_mean)

    data = np.random.randn(50, DIM) * 4 + 2

    true_means = []
    actual_means = []

    for i in range(10):
        batch = data[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
        output = update(batch, 0)
        assert np.allclose(
            output, (batch - batch.mean(0)) / (1e-4 + batch.std(0)), 1e-4
        )
        actual_means.append(get_stats(batch))
        if i == 0:
            true_means.append(batch.mean(0))
        else:
            true_means.append(0.9 * true_means[-1] + 0.1 * batch.mean(0))

    true_means = np.array(true_means)
    actual_means = np.array(actual_means).squeeze()

    assert np.allclose(true_means, actual_means, 1e-4)
Example 3
def SJ(x, y, N, lr, model, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)

    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1, ).astype("float32"))

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output)**2).mean()

    if model == "SGD":
        optimizers.SGD(sj_loss, lr)
    elif model == "Adam":
        optimizers.Adam(sj_loss, lr)
    train = symjax.function(sj_input,
                            sj_output,
                            outputs=sj_loss,
                            updates=symjax.get_updates())

    losses = []
    for i in tqdm(range(N)):
        losses.append(train(x, y))

    return losses
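A possible driver for the benchmark above; BS, D, np, tqdm and the symjax imports are assumed to be module-level names in the original script, and the sizes and hyper-parameters below are illustrative.

BS, D = 32, 10
x = np.random.randn(BS, D).astype("float32")
y = np.random.randn(BS, 1).astype("float32")

# compare the two optimizers on the same data
sgd_losses = SJ(x, y, N=200, lr=0.01, model="SGD")
adam_losses = SJ(x, y, N=200, lr=0.01, model="Adam")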
Example 4
def test_update():
    sj.current_graph().reset()
    w = symjax.tensor.zeros(10)
    for i in range(10):
        w = symjax.tensor.index_update(w, i, i)
    f = symjax.function(outputs=w)
    assert np.array_equal(f(), np.arange(10))
    w2 = symjax.tensor.zeros(10)
    for i in range(10):
        w2 = symjax.tensor.index_update(w2, (i, ), i)
    f = symjax.function(outputs=w2)
    assert np.array_equal(f(), np.arange(10))

    w3 = symjax.tensor.zeros(10)
    for i in range(10):
        w3 = symjax.tensor.index_update(w3, symjax.tensor.index[i], i)
    f = symjax.function(outputs=w3)
    assert np.array_equal(f(), np.arange(10))

    w4 = symjax.tensor.Variable(symjax.tensor.zeros(10))
    i = symjax.tensor.Variable(0, dtype="int32")
    update = symjax.tensor.index_update(w4, i, i)
    f = symjax.function(updates={w4: update, i: i + 1})
    for i in range(10):
        f()
    assert np.array_equal(w4.value, np.arange(10))
Example 5
    def __init__(self,
                 *args,
                 _jax_function,
                 _shapes,
                 _dtypes,
                 name=None,
                 **kwargs):
        if name is None:
            name = _jax_function.__name__

        name, scope = symjax.current_graph()._get_name_scope(name, self)

        Tensor.__init__(
            self,
            *args,
            _attrs={
                "name": name,
                "scope": scope,
                "_dtype": MultiOutputOp,
                "jax_function": _jax_function,
                "root": False,
            },
            **kwargs,
        )

        for i, child in enumerate(self):
            symjax.current_graph().add_edge(
                self,
                child,
                name="parent_index" + str(i),
            )
            symjax.current_graph().nodes[child]["parent"] = self
Example 6
    def __init__(
        self,
        *args,
        _jax_function=None,
        _shape=None,
        _dtype=None,
        name=None,
        **kwargs,
    ):

        if self in symjax.current_graph().nodes:
            return

        if name is None:
            name = _jax_function.__name__

        name, scope = symjax.current_graph()._get_name_scope(name, self)

        super().__init__(
            *args,
            _attrs={
                "name": name,
                "scope": scope,
                "_shape": _shape,
                "_dtype": _dtype,
                "jax_function": _jax_function,
                "root": False,
            },
            **kwargs,
        )
Example 7
    def __init__(self,
                 *args,
                 _jax_function,
                 _shapes,
                 _dtypes,
                 name=None,
                 **kwargs):

        if name is None:
            name = _jax_function.__name__
        self._name = name
        symjax.current_graph().add(
            self,
            jax_function=_jax_function,
            args=args,
            kwargs=kwargs,
            **kwargs,
        )
        for i, (shape, dtype) in enumerate(zip(_shapes, _dtypes)):
            OpTupleItem(
                shape,
                dtype,
                index=i,
                parent=self,
                name=name + "[{}]".format(i),
            )
Example 8
def SJ(x, y, N, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)

    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(
        np.random.randn(
            1,
        ).astype("float32")
    )

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    optimizers.Adam(sj_loss, lr)

    train = symjax.function(sj_input, sj_output, updates=symjax.get_updates())

    if preallocate:
        import jax

        x = jax.device_put(x)
        y = jax.device_put(y)

    t = time.time()
    for i in range(N):
        train(x, y)

    return time.time() - t
Example 9
def SJ_EMA(X, debias=True):
    symjax.current_graph().reset()
    x = T.Placeholder((), "float32", name="x")
    value = symjax.nn.schedules.ExponentialMovingAverage(x, 0.9,
                                                         debias=debias)[0]
    train = symjax.function(x, outputs=value, updates=symjax.get_updates())
    outputs = []
    for i in range(len(X)):
        outputs.append(train(X[i]))
    return outputs
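For reference, a small call of the helper above on a toy sequence; the values are illustrative and only reuse names defined in the snippet.

X = np.array([1.0, 2.0, 3.0, 4.0], dtype="float32")
print(SJ_EMA(X, debias=True))   # debiased running average of X
print(SJ_EMA(X, debias=False))  # first outputs are pulled toward 0 by the bias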
Example 10
def test_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32")
    out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)],
                non_sequences=[w, u])
    f = sj.function(u, outputs=out, updates={w: w + 1})
    assert np.array_equal(f(2), np.arange(3))
    assert np.array_equal(f(2), np.zeros(3))
    assert np.array_equal(f(0), -np.arange(3) * 3)
Example 11
    def __init__(self, *args, inplace_copy=None, **kwargs):

        if inplace_copy is not None:
            return
        if "_attrs" in kwargs:
            if "_shape" in kwargs["_attrs"]:
                self._shape = kwargs["_attrs"]["_shape"]
            if "_dtype" in kwargs["_attrs"]:
                self._dtype = kwargs["_attrs"]["_dtype"]
        symjax.current_graph()._add(self, *args, **kwargs)
Example 12
def test_updating_variables():
    sj.current_graph().reset()
    w1 = symjax.tensor.Variable(1.0, dtype="float32")
    input = symjax.tensor.Placeholder((), "float32")
    update = w1 + input + 1
    f = symjax.function(input, updates={w1: update})

    assert w1.value == 1.0
    f(10)
    assert w1.value == 12.0
Example 13
def test_grad_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    out = T.map(lambda a, w, u: w * a * u, (T.range(3), ),
                non_sequences=(w, u))
    g = sj.gradients(out.sum(), w)
    f = sj.function(u, outputs=g)

    assert np.array_equal(f(0), 0)
    assert np.array_equal(f(1), 3)
Example 14
def test_seed():
    a = T.random.randn((), seed=10)
    b = T.random.randn(())
    c = T.random.randn((), seed=10)
    f = symjax.function(outputs=[a, b, c])
    result1 = f()
    result2 = f()
    print(result1)
    print(result2)
    assert result1[0] == result1[2]
    assert result1[0] != result1[1]

    assert result2[0] == result2[2]
    assert result2[0] != result1[0]

    a = T.random.randn((), seed=10)
    b = T.random.randn(())
    c = T.random.randn((), seed=10)
    f = symjax.function(outputs=[a, b, c])
    result12 = f()
    result22 = f()
    assert result12[0] == result12[2]
    assert result12[0] != result12[1]
    assert result22[0] == result22[2]
    assert result22[0] != result12[0]

    assert np.isclose(result1[0], result12[0])
    assert np.isclose(result1[2], result12[2])
    assert not np.isclose(result1[1], result12[1])

    assert np.isclose(result2[0], result22[0])
    assert np.isclose(result2[2], result22[2])
    assert not np.isclose(result2[1], result22[1])

    symjax.current_graph().reset()

    a = T.random.randn((), seed=10)
    b = T.random.randn(())
    c = T.random.randn((), seed=10)
    f = symjax.function(outputs=[a, b, c])
    result12 = f()
    result22 = f()
    assert result12[0] == result12[2]
    assert result12[0] != result12[1]
    assert result22[0] == result22[2]
    assert result22[0] != result12[0]

    assert np.isclose(result1[0], result12[0])
    assert np.isclose(result1[2], result12[2])
    assert not np.isclose(result1[1], result12[1])

    assert np.isclose(result2[0], result22[0])
    assert np.isclose(result2[2], result22[2])
    assert not np.isclose(result2[1], result22[1])
Example 15
def test_clone_0():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    with sj.Scope("placing"):
        u = T.Placeholder((), "float32", name="u")
    value = 2 * w * u
    c = value.clone({w: u})
    f = sj.function(u, outputs=value)
    g = sj.function(u, outputs=c)

    assert np.array_equal([f(1), g(1), f(2), g(2)], [2, 2, 4, 8])
Example 16
def test_sma():
    symjax.current_graph().reset()
    a = symjax.tensor.Placeholder((4, ), "float32")
    sma, var = symjax.nn.schedules.SimpleMovingAverage(a, 3)
    f = symjax.function(a, outputs=[sma, var], updates=symjax.get_updates())

    data = np.random.randn(4, 4)
    current = [data[0], data[:2].mean(0), data[:3].mean(0), data[1:4].mean(0)]

    for i in range(data.shape[0]):
        out = f(data[i])
        assert np.allclose(out[0], current[i])
Example 17
def test_vectorize():
    sj.current_graph().reset()
    x = symjax.tensor.Placeholder((0, 2), "float32")
    w = symjax.tensor.Variable(1.0, dtype="float32")
    p = x.sum(1)

    f = symjax.function(x, outputs=p, updates={w: x.sum()})

    assert np.array_equal(f(np.ones((1, 2))), [2.0])
    assert w.value == 2.0
    assert np.array_equal(f(np.ones((2, 2))), [2.0, 2.0])
    assert w.value == 4.0
Example 18
    def __init__(self, _shape, _dtype, name=None, **kwargs):

        self._shape = tuple(_shape)
        self._dtype = jax.dtypes.dtype(_dtype)

        if name is not None:
            assert "/" not in name
            self._name = name
        else:
            self._name = "unnamed"

        symjax.current_graph().add(self, **kwargs)
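A short sketch of how a placeholder built by the constructor above is fed at call time, mirroring the calling pattern used by the tests on this page.

import numpy as np
import symjax
import symjax.tensor as T

symjax.current_graph().reset()
x = T.Placeholder((3,), "float32", name="x")
f = symjax.function(x, outputs=2 * x)
print(f(np.arange(3).astype("float32")))  # -> [0. 2. 4.]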
Example 19
def test_global_pool():
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 4096
    DIM = 8
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")

    output = nn.layers.Dense(input, 64)
    output = nn.layers.Dense(output, output.shape[-1] * 2)
    output = nn.layers.Dense(output, output.shape[-1] * 2)
    get = sj.function(input, outputs=output)
    assert get(np.ones((BATCH_SIZE, DIM))).shape == (BATCH_SIZE, 64 * 4)
Example 20
def test_while():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    v = T.Placeholder((), "float32")
    out = T.while_loop(
        lambda i, u: i[0] + u < 5,
        lambda i: (i[0] + 1.0, i[0]**2),
        (w, 1.0),
        non_sequences_cond=(v, ),
    )
    f = sj.function(v, outputs=out)
    assert np.array_equal(np.array(f(0)), [5, 16])
    assert np.array_equal(f(2), [3, 4])
Example 21
def test_cond2():
    sj.current_graph().reset()
    v = T.ones((10, 10))
    u = T.Placeholder((), "int32")
    out = T.cond(
        u > 0,
        lambda u: 4 * u,
        lambda u: u,
        true_inputs=(v, ),
        false_inputs=(2 * v, ),
    )
    f = sj.function(u, outputs=out)
    assert np.array_equal(f(1), 4 * np.ones((10, 10)))
    assert np.array_equal(f(0), 2 * np.ones((10, 10)))
Example 22
def test_accessing_variables():
    sj.current_graph().reset()
    w1 = symjax.tensor.Variable(1.0, trainable=True)
    w2 = symjax.tensor.Variable(1.0, trainable=True)
    w3 = symjax.tensor.Variable(1.0, trainable=False)

    v = symjax.get_variables("*", trainable=True)
    assert w1 in v and w2 in v and w3 not in v

    v = symjax.get_variables("*", trainable=False)
    assert w1 not in v and w2 not in v and w3 in v

    v = symjax.get_variables("*test")
    assert len(v) == 0
Example 23
def test_ema():
    symjax.current_graph().reset()
    a = symjax.tensor.Placeholder((), "float32")
    ema, var = symjax.nn.schedules.ExponentialMovingAverage(a,
                                                            0.9,
                                                            debias=False)
    # t = symjax.get_variables("*num_steps*", trainable=False)
    f = symjax.function(a, outputs=[ema, var], updates=symjax.get_updates())
    current = 0

    for i in range(10):
        out = f(1)
        assert np.allclose(out[1], current)
        current = 0.9 * current + 0.1 * 1
        assert np.allclose(out[0], current)
Example 24
def test_reset():
    sj.current_graph().reset()
    w = symjax.tensor.Variable(1.0, name="w", dtype="float32")
    x = symjax.tensor.Variable(2.0, name="x", dtype="float32")
    f = symjax.function(outputs=[w, x], updates={w: w + 1, x: x + 1})
    for i in range(10):
        print(i)
        assert np.array_equal(np.array(f()), np.array([1, 2]) + i)

    # reset only the w variable
    symjax.reset_variables("*w")
    assert np.array_equal(np.array(f()), np.array([1, 2 + i + 1]))
    # reset all variables
    symjax.reset_variables("*")
    assert np.array_equal(np.array(f()), np.array([1, 2]))
Example 25
    def __new__(
        cls,
        *args,
        _jax_function,
        _shapes,
        _dtypes,
        name=None,
        **kwargs,
    ):
        scope = symjax.current_graph().scope.absolute_name
        items = []

        for i, (shape, dtype) in enumerate(zip(_shapes, _dtypes)):
            items.append(
                OpItem(
                    _attrs={
                        "name": _jax_function.__name__ + "[{}]".format(i),
                        "scope": scope,
                        "jax_function": _jax_function,
                        "root": False,
                        "parent_index": i,
                        "_shape": shape,
                        "_dtype": dtype,
                    }))
        return super(MultiOutputOp, cls).__new__(cls, tuple(items))
Example 26
    def __init__(self,
                 *args,
                 _jax_function,
                 _shape,
                 _dtype,
                 _seed,
                 name=None,
                 **kwargs):

        if name is None:
            name = _jax_function.__name__

        name, scope = symjax.current_graph()._get_name_scope(name, self)

        _seed = ord(os.urandom(1)) if _seed is None else _seed
        seed_op = Seed(_seed)

        Tensor.__init__(
            self,
            seed_op,
            *args,
            _attrs={
                "name": name,
                "scope": scope,
                "_shape": _shape,
                "_dtype": _dtype,
                "jax_function": _jax_function,
                "root": False,
            },
            **kwargs,
        )
Example 27
    def __init__(self, *args, name=None, **kwargs):

        if name is None:
            name = self.__NAME__
        with symjax.Scope(name):
            self.create_updates(*args, **kwargs)
            self._scope_name = symjax.current_graph().scope_name
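The constructor above is the common entry point for the optimizers used elsewhere on this page (optimizers.SGD, optimizers.Adam); below is a minimal sketch of the calling pattern, with the import path assumed to be symjax.nn.optimizers and all shapes and learning rates illustrative.

import numpy as np
import symjax
import symjax.tensor as T
from symjax.nn import optimizers

symjax.current_graph().reset()
x = T.Placeholder((8, 3), "float32")
w = T.Variable(np.zeros((3, 1), dtype="float32"))
loss = ((x.dot(w)) ** 2).mean()

optimizers.SGD(loss, 0.01)  # registers its update rules in the current graph
step = symjax.function(x, outputs=loss, updates=symjax.get_updates())
step(np.ones((8, 3), dtype="float32"))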
Example 28
    def __init__(
        self,
        initializer,
        name="unnamed",
        trainable=True,
        shape=None,
        dtype=None,
    ):

        if trainable and dtype == "bool":
            raise RuntimeError("error impossible learning with dtype bool")

        assert not isvar(shape)

        name, scope = symjax.current_graph()._get_name_scope(name, self)
        value = self._reset(initializer, shape, dtype)
        shape = tuple(shape or value.shape)
        dtype = jax.numpy.dtype(dtype) if dtype is not None else value.dtype

        super().__init__(
            _attrs={
                "name": name,
                "scope": scope,
                "_shape": shape,
                "_dtype": dtype,
                "trainable": trainable,
                "initializer": initializer,
                "root": True,
                "value": value,
            })
Example 29
def create_variable(
    name,
    tensor_or_func,
    shape,
    trainable,
    inplace=False,
    dtype="float32",
    preprocessor=None,
):
    if tensor_or_func is None:
        return None

    if inplace:
        assert not callable(tensor_or_func)
        return tensor_or_func

    variable = T.Variable(
        tensor_or_func,
        name=name,
        shape=symjax.current_graph().get(shape),
        dtype=dtype,
        trainable=trainable,
    )
    if preprocessor is not None:
        return preprocessor(variable)
    else:
        return variable
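A hedged example of calling the helper above; the names and values are assumptions, and only the signature shown in the snippet is relied on.

W = create_variable(
    "W",
    np.random.randn(4, 2).astype("float32"),
    shape=(4, 2),
    trainable=True,
)
b = create_variable("b", None, shape=(2,), trainable=True)  # returns None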
Example 30
    def update(self, other_seed=None):
        """

        update the seed either by splitting the current one
        effectively generating a new random seed
        or by using a given one

        """
        if other_seed is not None:
            if len(other_seed) != 2:
                raise RuntimeError(
                    f"given updated seed {other_seed} is not valid")
            symjax.current_graph().nodes[self]["value"] = other_seed
        else:
            new_key = jax.random.split(self.value, 1)[0]
            symjax.current_graph().nodes[self]["value"] = new_key