Example #1
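All of the snippets below are tests excerpted from the Elegy codebase. They assume imports along the following lines (inferred from usage, not copied from the original files; generalize and generalize_optimizer appear to be Elegy-internal helpers):

    import jax
    import jax.numpy as jnp
    import numpy as np
    import optax
    import haiku
    import elegy
    from flax import linen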
    def test_connects(self):
        transformer_model = elegy.nn.Transformer(
            head_size=64,
            num_heads=4,
            num_encoder_layers=2,
            num_decoder_layers=2,
        )

        src = np.random.uniform(size=(5, 32, 64))
        tgt = np.random.uniform(size=(5, 32, 64))

        _, params, collections = transformer_model.init(rng=elegy.RNGSeq(42))(
            src, tgt)
        out, params, collections = transformer_model.apply(
            params, collections, rng=elegy.RNGSeq(420))(src, tgt)
Example #2
    def test_basic(self):
        class M(linen.Module):
            @linen.compact
            def __call__(self, x):

                initialized = self.has_variable("batch_stats", "n")

                vn = self.variable("batch_stats", "n", lambda: 0)

                w = self.param("w", lambda key: 2.0)

                if initialized:
                    vn.value += 1

                return x * w

        gm = generalize(M())
        rng = elegy.RNGSeq(42)

        y_true, params, states = gm.init(rng)(x=3.0, y=1)

        assert y_true == 6
        assert params["w"] == 2
        assert states["batch_stats"]["n"] == 0

        params = params.copy(dict(w=10.0))
        y_true, params, states = gm.apply(params,
                                          states,
                                          training=True,
                                          rng=rng)(x=3.0, y=1)

        assert y_true == 30
        assert params["w"] == 10
        assert states["batch_stats"]["n"] == 1
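For comparison, the same linen module can be driven with Flax's standard init/apply (a sketch that bypasses the generalize wrapper; plain Flax Linen API):

    variables = M().init(jax.random.PRNGKey(42), 3.0)
    # has_variable is False while "n" is being created, so n starts at 0:
    # variables == {"params": {"w": 2.0}, "batch_stats": {"n": 0}}
    y, mutated = M().apply(variables, 3.0, mutable=["batch_stats"])
    # y == 6.0 and mutated["batch_stats"]["n"] == 1, mirroring the assertions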
Example #3
    def test_basic(self):
        class M(elegy.Module):
            def call(self, x):
                n = self.add_parameter("n", lambda: 0, trainable=False)
                w = self.add_parameter("w", lambda: 2.0)

                self.update_parameter("n", n + 1)

                key = self.next_key()  # draw a key from the RNG sequence (unused below)

                return x * w

        gm = generalize(M())
        rng = elegy.RNGSeq(42)

        y_true, params, states = gm.init(rng)(x=3.0, y=1)

        assert y_true == 6
        assert params["w"] == 2
        assert states["states"]["n"] == 0

        params["w"] = 10.0
        y_true, params, states = gm.apply(params, states, training=True, rng=rng)(
            x=3.0, y=1
        )

        assert y_true == 30
        assert params["w"] == 10
        assert states["states"]["n"] == 1
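The assertions pin down two behaviors of the low-level elegy.Module API: a parameter added with trainable=False lands under states["states"] rather than params, and the update_parameter call only takes effect on apply, so n is still 0 right after init and becomes 1 after the first apply.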
Example #4
    def test_basic(self):
        class M(haiku.Module):
            def __call__(self, x):

                n = haiku.get_state(
                    "n", shape=[], dtype=jnp.int32, init=lambda *args: np.array(0)
                )
                w = haiku.get_parameter("w", [], init=lambda *args: np.array(2.0))

                haiku.set_state("n", n + 1)

                return x * w

        def f(x, initializing, rng):
            return M()(x)

        gm = elegy.HaikuModule(f)
        rng = elegy.RNGSeq(42)

        y_true, params, states = gm.init(rng)(x=3.0, y=1, rng=None, initializing=True)

        assert y_true == 6
        assert params["m"]["w"] == 2
        assert states["m"]["n"] == 0

        params = haiku.data_structures.to_mutable_dict(params)
        params["m"]["w"] = np.array(10.0)
        y_true, params, states = gm.apply(params, states, training=True, rng=rng)(
            x=3.0, y=1, rng=None, initializing=True
        )

        assert y_true == 30
        assert params["m"]["w"] == 10
        assert states["m"]["n"] == 1
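For comparison, the same module can be driven with Haiku's standard transform_with_state API (a sketch without the elegy wrapper; the "m" scope comes from Haiku's default snake_cased module naming):

    transformed = haiku.transform_with_state(lambda x: M()(x))
    params, state = transformed.init(jax.random.PRNGKey(42), 3.0)
    # params == {"m": {"w": 2.0}}; init reports the initial state, so n == 0
    y, state = transformed.apply(params, state, None, 3.0)
    # y == 6.0 and state["m"]["n"] == 1, mirroring the assertions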
Example #5
    def test_losses(self):
        def loss_fn():
            return 3.0

        losses = elegy.model.model.Losses(
            dict(a=dict(b=[loss_fn, loss_fn], c=loss_fn)))

        rng = elegy.RNGSeq(42)
        hooks_losses = dict(x=0.3, y=4.5)

        with elegy.hooks.context(losses=True):
            elegy.hooks.add_loss("d", 1.0)
            aux_losses = elegy.hooks.get_losses()
            _, logs, states = losses.init(aux_losses, rng)()

        with elegy.hooks.context(losses=True):
            elegy.hooks.add_loss("d", 1.0)
            aux_losses = elegy.hooks.get_losses()
            loss, logs, states = losses.apply(aux_losses, states)()

        assert loss == 10

        assert len(losses.losses) == 3
        assert "a/b/loss_fn" in losses.losses
        assert "a/b/loss_fn_1" in losses.losses
        assert "a/c/loss_fn" in losses.losses

        assert len(logs) == 5
        assert "loss" in logs
        assert "a/b/loss_fn" in logs
        assert "a/b/loss_fn_1" in logs
        assert "a/c/loss_fn" in logs
        assert "d_loss" in logs
Example #6
    def test_optimizer_epoch(self):
        optax_op = optax.adam(1e-3)
        lr_schedule = lambda step, epoch: epoch

        optimizer = elegy.Optimizer(optax_op,
                                    lr_schedule=lr_schedule,
                                    steps_per_epoch=2)

        params = np.random.uniform(size=(3, 4))
        grads = np.random.uniform(size=(3, 4))
        rng = elegy.RNGSeq(42)

        optimizer_states = optimizer.init(
            rng=rng,
            net_params=params,
        )

        assert jnp.allclose(optimizer.current_lr(optimizer_states), 0)
        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)

        assert jnp.allclose(optimizer.current_lr(optimizer_states), 0)
        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)

        assert jnp.allclose(optimizer.current_lr(optimizer_states), 1)
        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)

        assert jnp.allclose(optimizer.current_lr(optimizer_states), 1)
        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)
Example #7
    def test_connects(self):

        x = np.random.uniform(-1, 1, size=(4, 3))
        linear = elegy.nn.Linear(5)

        y_pred = linear.call_with_defaults(rng=elegy.RNGSeq(42))(x)

        assert y_pred.shape == (4, 5)
Example #8
    def test_module_system_docs(self):
        class Linear(elegy.Module):
            def __init__(self, n_out):
                super().__init__()
                self.n_out = n_out

            def call(self, x):
                w = self.add_parameter(
                    "w",
                    lambda: elegy.initializers.RandomUniform()(
                        shape=[x.shape[-1], self.n_out],
                        dtype=jnp.float32,
                    ),
                )
                b = self.add_parameter("b",
                                       lambda: jnp.zeros(shape=[self.n_out]))

                return jnp.dot(x, w) + b

        class MLP(elegy.Module):
            def call(self, x):
                x = Linear(64)(x)
                x = jax.nn.relu(x)
                x = Linear(32)(x)
                x = jax.nn.relu(x)
                x = Linear(1)(x)
                return x

        def loss_fn(parameters, x, y):
            y_pred, _ = mlp.apply(dict(parameters=parameters))(x)
            return jnp.mean(jnp.square(y - y_pred))

        def update(parameters, x, y):
            loss, gradients = jax.value_and_grad(loss_fn)(parameters, x, y)
            parameters = jax.tree_multimap(lambda p, g: p - 0.01 * g,
                                           parameters, gradients)

            return loss, parameters

        x = np.random.uniform(size=(15, 3))
        y = np.random.uniform(size=(15, 1))
        mlp = MLP()

        y_pred, collections = mlp.init(rng=elegy.RNGSeq(42))(x)

        parameters = collections["parameters"]

        update_jit = jax.jit(update)

        for step in range(1):
            loss, parameters = update_jit(parameters, x, y)

        mlp.set_default_parameters(dict(parameters=parameters))
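Two notes on this walkthrough: jax.tree_multimap was the multi-tree map of the era and has since been deprecated in favor of jax.tree_util.tree_map, and set_default_parameters apparently stores the trained collections on the module so that later default-style calls (as with call_with_defaults in Examples #7 and #10) need not pass parameters explicitly.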
Example #9
def test_basic():

    w = 2.0
    grads = 1.5
    lr = 1.0
    rng = elegy.RNGSeq(42)

    go = generalize_optimizer(optax.sgd(lr))

    states = go.init(rng, w)
    w, states = go.apply(w, grads, states, rng)

    assert w == 0.5
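The expected value is plain SGD arithmetic: w ← w − lr·g = 2.0 − 1.0 × 1.5 = 0.5. For comparison, the same update in raw optax (standard optax API, bypassing the generalize_optimizer wrapper):

    opt = optax.sgd(1.0)
    opt_state = opt.init(2.0)
    updates, opt_state = opt.update(1.5, opt_state)
    assert optax.apply_updates(2.0, updates) == 0.5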
Example #10
    def test_connects(self):

        y = elegy.nn.Sequential(
            lambda: [
                elegy.nn.Flatten(),
                elegy.nn.Linear(5),
                jax.nn.relu,
                elegy.nn.Linear(2),
            ]
        ).call_with_defaults(rng=elegy.RNGSeq(42))(jnp.ones([10, 3]))

        assert y.shape == (10, 2)

        y = elegy.nn.Sequential(
            lambda: [
                elegy.nn.Flatten(),
                elegy.nn.Linear(5),
                jax.nn.relu,
                elegy.nn.Linear(2),
            ]
        ).call_with_defaults(rng=elegy.RNGSeq(42), training=False)(jnp.ones([10, 3]))

        assert y.shape == (10, 2)
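Shape-wise: Flatten keeps the leading batch axis (so the [10, 3] input is unchanged here), Linear(5) maps it to (10, 5), and Linear(2) yields the asserted (10, 2); the second stack just repeats the check with training=False.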
Example #11
File: model.py  Project: Dave0995/elegy
    def init_step(self, x):
        rng = elegy.RNGSeq(0)
        gx, g_params, g_states = self.generator.init(rng=rng)(x)
        dx, d_params, d_states = self.discriminator.init(rng=rng)(gx)

        g_optimizer_states = self.g_optimizer.init(g_params)
        d_optimizer_states = self.d_optimizer.init(d_params)

        return elegy.States(
            g_states=g_states,
            d_states=d_states,
            g_params=g_params,
            d_params=d_params,
            g_opt_states=g_optimizer_states,
            d_opt_states=d_optimizer_states,
            rng=rng,
            step=0,
        )
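Note the wiring: the generator's init output gx is fed directly into the discriminator's init, so the discriminator's shapes are derived from real generator output, and everything the training loop needs is bundled into a single elegy.States record.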
Example #12
    def test_di(self):

        m = elegy.nn.Sequential(
            lambda: [
                elegy.nn.Flatten(),
                elegy.nn.Linear(2),
            ]
        )

        y = elegy.inject_dependencies(
            m.call_with_defaults(rng=elegy.RNGSeq(42), training=False), signature_f=m
        )(
            jnp.ones([5, 3]),
            a=1,
            b=2,
        )

        assert y.shape == (5, 2)
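The point of the test: inject_dependencies evidently filters the supplied arguments against the target's signature (here taken from m via signature_f), so the extra a=1 and b=2 keyword arguments are dropped rather than raising a TypeError.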
Example #13
    def test_optimizer_chain(self):

        optimizer = elegy.Optimizer(
            optax.sgd(0.1),
            optax.clip(0.5),
        )

        params = np.zeros(shape=(3, 4))
        grads = np.ones(shape=(3, 4)) * 100_000
        rng = elegy.RNGSeq(42)

        optimizer_states = optimizer.init(
            rng=rng,
            net_params=params,
        )

        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)

        assert np.all(-0.5 <= params) and np.all(params <= 0.5)
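elegy.Optimizer with several arguments presumably builds an optax.chain out of them. A sketch of the same check in raw optax (standard API): with this ordering the gradients are first scaled by sgd(0.1) and the resulting updates then clipped, so every parameter lands at -0.5, inside the asserted band:

    op = optax.chain(optax.sgd(0.1), optax.clip(0.5))
    state = op.init(np.zeros((3, 4)))
    updates, state = op.update(np.ones((3, 4)) * 100_000, state)
    new_params = optax.apply_updates(np.zeros((3, 4)), updates)  # all entries -0.5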
Example #14
    def test_metrics(self):
        class M(elegy.Module):
            def call(self, x):
                n = self.add_parameter("n", lambda: 0, trainable=False)
                self.update_parameter("n", n + 1)
                return x

        metrics = elegy.model.model.Metrics(dict(a=dict(b=[M(), M()], c=M())))

        rng = elegy.RNGSeq(42)
        x = np.random.uniform(size=(5, 7, 7))

        with elegy.hooks.context(metrics=True):
            elegy.hooks.add_metric("d", 10)
            aux_metrics = elegy.hooks.get_metrics()
            logs, states = metrics.init(aux_metrics, rng)(x, training=True)

        with elegy.hooks.context(metrics=True):
            elegy.hooks.add_metric("d", 10)
            aux_metrics = elegy.hooks.get_metrics()
            logs, states = metrics.apply(aux_metrics, rng,
                                         states)(x, training=True)

        assert len(metrics.metrics) == 3
        assert "a/b/m" in metrics.metrics
        assert "a/b/m_1" in metrics.metrics
        assert "a/c/m" in metrics.metrics

        assert len(logs) == 4
        assert "a/b/m" in logs
        assert "a/b/m_1" in logs
        assert "a/c/m" in logs
        assert "d" in logs

        assert len(states) == 3
        assert "a/b/m" in states
        assert "a/b/m_1" in states
        assert "a/c/m" in states
Example #15
    def test_optimizer(self):
        optax_op = optax.adam(1e-3)
        lr_schedule = lambda step, epoch: step / 3

        optimizer = elegy.Optimizer(optax_op, lr_schedule=lr_schedule)

        params = np.random.uniform(size=(3, 4))
        grads = np.random.uniform(size=(3, 4))
        rng = elegy.RNGSeq(42)

        optimizer_states = optimizer.init(rng, params)
        assert jnp.allclose(optimizer.current_lr(optimizer_states), 0 / 3)

        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)
        assert jnp.allclose(optimizer.current_lr(optimizer_states), 1 / 3)

        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)
        assert jnp.allclose(optimizer.current_lr(optimizer_states), 2 / 3)

        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)
        assert jnp.allclose(optimizer.current_lr(optimizer_states), 3 / 3)
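Here lr_schedule sees the raw step counter, so current_lr advances by 1/3 on every apply call, exactly as the assertions trace from 0/3 to 3/3.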
Example #16
    def test_dropout_connects(self):
        elegy.nn.Dropout(0.25).call_with_defaults(rng=elegy.RNGSeq(42))(
            jnp.ones([3, 3]), training=True
        )
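With training=True the dropout mask is actually sampled, which is why an RNG sequence must be supplied.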