示例#1
0
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Encode a batch of 2-D inputs into sampled latent vectors.

        Each example is flattened, passed through one hidden ReLU layer, and
        projected by two heads to a mean and a log standard deviation. A KL
        divergence term is assigned to ``self.kl_loss``, and a latent sample is
        drawn as ``mean + stddev * eps`` with ``eps`` standard normal.

        Args:
            x: Array with axes ``(batch, height, width)``, as required by the
                ``einops`` pattern below.

        Returns:
            The sampled latent array, with the same shape as ``mean``.
        """
        flat = einops.rearrange(x, "batch height width -> batch (height width)")
        hidden = jax.nn.relu(eg.Linear(self.hidden_size)(flat))

        # Two heads over the same hidden features, each given an explicit name.
        mean = eg.Linear(self.latent_size, name="linear_mean")(hidden)
        log_stddev = eg.Linear(self.latent_size, name="linear_std")(hidden)
        # This head predicts log(std); exponentiating keeps stddev positive.
        stddev = jnp.exp(log_stddev)

        # KL regularizer with weight 0.2, stored on the module.
        # NOTE(review): presumably collected by the framework as an extra loss — confirm.
        self.kl_loss = KLDivergence(weight=2e-1)(mean=mean, std=stddev)

        # Reparameterization: draw standard-normal noise, then scale and shift it.
        eps = jax.random.normal(self.next_key(), mean.shape)
        return mean + stddev * eps
示例#2
0
    def __call__(self, image: jnp.ndarray):
        """Reconstruct ``image`` through a fully connected bottleneck.

        Pixel values are scaled from ``[0, 255]`` to ``[0, 1]``, flattened, and
        passed through ``n1 -> n2 -> n1`` ReLU layers. The result is projected
        back to one output per input element. A sigmoid scaled by 255 maps the
        output back to the pixel range, and it is reshaped to ``image.shape``.

        Args:
            image: Batch of images. The leading axis is the batch axis; every
                other axis (e.g. height, width and optionally channels) is
                reconstructed.

        Returns:
            Float array with the same shape as ``image``.
        """
        # Values per example. Counting every non-batch axis (not only the last
        # two) keeps the final reshape valid for inputs with a channel axis or
        # for 2-D inputs. It is identical for (batch, height, width) inputs.
        # NOTE(review): assumes eg.Flatten keeps only the leading batch axis — confirm.
        n_features = int(np.prod(image.shape[1:]))

        x = image.astype(jnp.float32) / 255.0
        x = eg.Flatten()(x)
        x = eg.Linear(self.n1)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(self.n2)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(self.n1)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(n_features)(x)
        # Map back to the [0, 255] pixel range.
        x = jax.nn.sigmoid(x) * 255
        x = x.reshape(image.shape)

        return x
示例#3
0
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Classify a batch of images into 10 classes.

        Three 3x3 convolutions with stride 2 (32, 64 and 128 filters)
        downsample the input. The first two stages are followed by dropout and
        ReLU, and the second also by batch norm. Axes 1 and 2 are then averaged
        away and a dense layer produces 10 outputs.

        Args:
            x: Image batch with values in ``[0, 255]``. Axes 1 and 2 are
                averaged as spatial axes; presumably ``(batch, H, W, C)`` — confirm.

        Returns:
            Unnormalized class scores; ``(batch, 10)`` for a 4-D input.
        """
        h = x.astype(jnp.float32) / 255.0  # scale pixels into [0, 1]

        # Stage 1: 32 filters, light dropout, ReLU.
        h = eg.Conv(32, [3, 3], strides=[2, 2])(h)
        h = eg.Dropout(0.05)(h)
        h = jax.nn.relu(h)

        # Stage 2: 64 filters, batch norm, dropout, ReLU.
        h = eg.Conv(64, [3, 3], strides=[2, 2])(h)
        h = eg.BatchNorm()(h)
        h = eg.Dropout(0.1)(h)
        h = jax.nn.relu(h)

        # Stage 3: 128 filters.
        # NOTE(review): unlike stages 1-2 there is no activation here — confirm intentional.
        h = eg.Conv(128, [3, 3], strides=[2, 2])(h)

        # Global average pooling over the spatial axes.
        pooled = jnp.mean(h, axis=(1, 2))

        # Dense head producing the 10 class scores.
        return eg.Linear(10)(pooled)
示例#4
0
    def test_evaluate(self):
        """``Model.evaluate`` reports the named loss, the named metric and the total loss."""

        # Class names are kept lowercase: they presumably determine the
        # "a/mse_loss" and "b/mae" log keys asserted below.
        class mse(eg.Loss):
            def call(self, target, preds):
                return jnp.mean((target - preds)**2)

        class mae(eg.Metric):
            value: eg.MetricState = eg.MetricState.node(
                default=jnp.array(0.0, jnp.float32))

            def update(self, target, preds):
                # Store the batch MAE in the metric state so ``compute`` reports
                # it. Previously the value was only returned, so ``compute``
                # always yielded the default 0.0. It is still returned for any
                # caller that uses the return value.
                self.value = jnp.mean(jnp.abs(target - preds))
                return self.value

            def compute(self) -> tp.Any:
                return self.value

        model = eg.Model(
            module=eg.Linear(1),
            loss=dict(a=mse()),
            metrics=dict(b=mae()),
            optimizer=optax.adamw(1e-3),
            eager=True,
        )

        # 5 samples: 2 input features, 1 regression target each.
        X = np.random.uniform(size=(5, 2))
        y = np.random.uniform(size=(5, 1))

        logs = model.evaluate(x=X, y=y)

        assert "a/mse_loss" in logs
        assert "b/mae" in logs
        assert "loss" in logs
示例#5
0
    def test_predict(self):
        """``Model.predict`` with only a module returns one output row per sample."""

        model = eg.Model(module=eg.Linear(1))

        # 5 samples with 2 features each. No targets are needed for prediction.
        X = np.random.uniform(size=(5, 2))

        y_pred = model.predict(X)

        assert y_pred.shape == (5, 1)
示例#6
0
    def __call__(self, x: jnp.ndarray):
        """Run the convolutional classifier on a batch of images.

        Pixel values are scaled to ``[0, 1]`` and passed through four
        ``ConvBlock`` stages with 3x3 kernels. The stages have 32, 64, 64 and
        128 filters, and the last three use stride 2. The result is averaged
        over axes 1 and 2 and projected to 10 outputs.

        Args:
            x: Image batch with values in ``[0, 255]``; axes 1 and 2 are
                averaged as spatial axes — presumably ``(batch, H, W, C)``, confirm.

        Returns:
            Array with a last axis of size 10 (unnormalized class scores).
        """
        h = x.astype(jnp.float32) / 255.0

        # (filters, extra ConvBlock kwargs) per stage, in creation order.
        # The first stage relies on ConvBlock's default stride.
        stages = (
            (32, {}),
            (64, {"stride": 2}),
            (64, {"stride": 2}),
            (128, {"stride": 2}),
        )
        for filters, extra in stages:
            h = ConvBlock()(h, filters, (3, 3), **extra)

        # Global average pooling over the spatial axes.
        pooled = jnp.mean(h, axis=(1, 2))

        # Dense head with 10 outputs. On pooled features this plays the role
        # of a 1x1 convolution.
        return eg.Linear(10)(pooled)
示例#7
0
    def test_distributed_init(self):
        """Initializing a distributed model gives each parameter a per-device axis.

        After ``distributed()`` and ``init_on_batch``, every parameter of the
        ``Linear(2)`` module has a leading axis of size ``jax.device_count()``.
        """
        n_devices = jax.device_count()
        # 5 samples per device, presumably so the batch splits evenly across devices.
        batch_size = 5 * n_devices

        x = np.random.uniform(size=(batch_size, 1))

        model = eg.Model(
            eg.Linear(2),
            loss=[eg.losses.MeanSquaredError()],
        )

        model = model.distributed()

        # Initialization is driven by inputs only; no targets are passed.
        model.init_on_batch(x)

        # Linear(2) on 1 input feature: kernel (1, 2) and bias (2,), plus the device axis.
        assert model.module.kernel.shape == (n_devices, 1, 2)
        assert model.module.bias.shape == (n_devices, 2)
示例#8
0
    def test_saved_model(self):
        """Export with a fixed list of batch sizes, reload the SavedModel, and run it."""

        with TemporaryDirectory() as model_dir:

            model = eg.Model(module=eg.Linear(4))

            # float32 inputs (as in test_saved_model_poly) so the exported
            # signature dtype matches the later call.
            x = np.random.uniform(size=(5, 6)).astype(np.float32)

            # Presumably exports one concrete signature per listed batch size.
            model.saved_model(x, model_dir, batch_size=[1, 2, 4, 8])

            output = str(sh.ls(model_dir))

            assert "saved_model.pb" in output
            assert "variables" in output

            saved_model = tf.saved_model.load(model_dir)

            # Actually exercise the restored model with an exported batch size.
            # Previously the loaded object was only referenced, never called.
            y = saved_model(x[:4])

            assert y.shape == (4, 4)
示例#9
0
    def test_saved_model_poly(self):
        """A model exported with ``batch_size=None`` runs on a different batch size."""
        n_features, n_outputs = 6, 4

        with TemporaryDirectory() as export_dir:
            model = eg.Model(module=eg.Linear(n_outputs))

            # Export using a 5-row float32 sample; batch_size=None leaves the
            # batch axis unfixed.
            sample = np.random.uniform(size=(5, n_features)).astype(np.float32)
            model.saved_model(sample, export_dir, batch_size=None)

            listing = str(sh.ls(export_dir))
            for artifact in ("saved_model.pb", "variables"):
                assert artifact in listing

            restored = tf.saved_model.load(export_dir)

            # Call with 3 rows instead of the 5 used at export time.
            other = np.random.uniform(size=(3, n_features)).astype(np.float32)
            prediction = restored(other)

            assert prediction.shape == (3, n_outputs)
示例#10
0
    def test_cloudpickle(self):
        """A model saved with ``Model.save`` and reloaded with ``eg.load`` predicts identically."""
        model = eg.Model(
            module=eg.Linear(10),
            loss=[
                eg.losses.Crossentropy(),
                eg.regularizers.L2(1e-4),
            ],
            metrics=eg.metrics.Accuracy(),
            optimizer=optax.adamw(1e-3),
            eager=True,
        )

        # Prediction only needs inputs; no labels are required.
        X = np.random.uniform(size=(5, 2))

        y0 = model.predict(X)

        with TemporaryDirectory() as model_dir:
            model.save(model_dir)
            newmodel = eg.load(model_dir)

        # Predict after the directory has been deleted: the loaded model must
        # not depend on the saved files anymore.
        y1 = newmodel.predict(X)

        # Exact equality that also checks shapes; ``np.all(y0 == y1)`` could
        # broadcast mismatched shapes.
        np.testing.assert_array_equal(y0, y1)
示例#11
0
 def __call__(self, x):
     """Apply a 10-unit linear layer and block gradients through its output.

     ``jax.lax.stop_gradient`` makes the result behave as a constant under
     differentiation. No gradient flows back through this path into the
     layer's parameters or into ``x``.
     """
     frozen = jax.lax.stop_gradient(eg.Linear(10)(x))
     return frozen