def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    # flatten the spatial dimensions into a single feature axis
    x = einops.rearrange(x, "batch height width -> batch (height width)")
    x = eg.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)

    # parameterize the approximate posterior q(z|x)
    mean = eg.Linear(self.latent_size, name="linear_mean")(x)
    log_stddev = eg.Linear(self.latent_size, name="linear_std")(x)
    stddev = jnp.exp(log_stddev)

    # KL regularization term, registered as an auxiliary loss
    self.kl_loss = KLDivergence(weight=2e-1)(mean=mean, std=stddev)

    # reparameterization trick: z = mean + stddev * eps, eps ~ N(0, I)
    z = mean + stddev * jax.random.normal(self.next_key(), mean.shape)
    return z
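# For context, a matching decoder could look like the minimal sketch below.
# This is an illustrative assumption, not part of the original module:
# `self.hidden_size` is reused from above, while `self.output_shape` is a
# hypothetical field holding the (height, width) of the original images.
def decode(self, z: jnp.ndarray) -> jnp.ndarray:
    x = eg.Linear(self.hidden_size)(z)
    x = jax.nn.relu(x)
    x = eg.Linear(np.prod(self.output_shape))(x)
    x = jax.nn.sigmoid(x)
    return x.reshape((-1, *self.output_shape))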
def __call__(self, image: jnp.ndarray):
    # scale pixel values to [0, 1]
    x = image.astype(jnp.float32) / 255.0
    x = eg.Flatten()(x)

    # encoder: n1 -> n2 bottleneck
    x = eg.Linear(self.n1)(x)
    x = jax.nn.relu(x)
    x = eg.Linear(self.n2)(x)
    x = jax.nn.relu(x)

    # decoder: expand back to the flattened image size
    x = eg.Linear(self.n1)(x)
    x = jax.nn.relu(x)
    x = eg.Linear(np.prod(image.shape[-2:]))(x)

    # map back to pixel range and restore the input shape
    x = jax.nn.sigmoid(x) * 255
    x = x.reshape(image.shape)
    return x
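# A minimal training sketch for the autoencoder above. The module name
# `AutoEncoder`, its sizes, and the `x_train` array are assumptions for
# illustration; the loss and optimizer mirror the ones used elsewhere here.
model = eg.Model(
    module=AutoEncoder(n1=256, n2=64),
    loss=eg.losses.MeanSquaredError(),
    optimizer=optax.adamw(1e-3),
)
model.fit(x=x_train, y=x_train, epochs=10, batch_size=32)  # reconstruct the input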
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    # Normalize the input
    x = x.astype(jnp.float32) / 255.0

    # Block 1
    x = eg.Conv(32, [3, 3], strides=[2, 2])(x)
    x = eg.Dropout(0.05)(x)
    x = jax.nn.relu(x)

    # Block 2
    x = eg.Conv(64, [3, 3], strides=[2, 2])(x)
    x = eg.BatchNorm()(x)
    x = eg.Dropout(0.1)(x)
    x = jax.nn.relu(x)

    # Block 3
    x = eg.Conv(128, [3, 3], strides=[2, 2])(x)

    # Global average pooling
    x = x.mean(axis=(1, 2))

    # Classification layer
    x = eg.Linear(10)(x)

    return x
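# Hedged usage sketch: wrapping the classifier above in an eg.Model. The
# module name `CNN` is an assumption; the loss, metric, and optimizer mirror
# the combination used in the tests below. Labels are integer class ids.
model = eg.Model(
    module=CNN(),
    loss=eg.losses.Crossentropy(),
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.adamw(1e-3),
)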
def test_evaluate(self):
    class mse(eg.Loss):
        def call(self, target, preds):
            return jnp.mean((target - preds) ** 2)

    class mae(eg.Metric):
        value: eg.MetricState = eg.MetricState.node(
            default=jnp.array(0.0, jnp.float32)
        )

        def update(self, target, preds):
            # store the running value so that compute() can report it
            self.value = jnp.mean(jnp.abs(target - preds))

        def compute(self) -> tp.Any:
            return self.value

    model = eg.Model(
        module=eg.Linear(1),
        loss=dict(a=mse()),
        metrics=dict(b=mae()),
        optimizer=optax.adamw(1e-3),
        eager=True,
    )

    X = np.random.uniform(size=(5, 2))
    y = np.random.uniform(size=(5, 1))

    logs = model.evaluate(x=X, y=y)

    assert "a/mse_loss" in logs
    assert "b/mae" in logs
    assert "loss" in logs
def test_predict(self):
    model = eg.Model(module=eg.Linear(1))

    X = np.random.uniform(size=(5, 2))

    y_pred = model.predict(X)

    assert y_pred.shape == (5, 1)
def __call__(self, x: jnp.ndarray):
    # normalize
    x = x.astype(jnp.float32) / 255.0

    # base
    x = ConvBlock()(x, 32, (3, 3))
    x = ConvBlock()(x, 64, (3, 3), stride=2)
    x = ConvBlock()(x, 64, (3, 3), stride=2)
    x = ConvBlock()(x, 128, (3, 3), stride=2)

    # GlobalAveragePooling2D
    x = jnp.mean(x, axis=(1, 2))

    # classification head (a Linear after global pooling, equivalent to a 1x1 conv)
    x = eg.Linear(10)(x)

    return x
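# `ConvBlock` is not defined in this snippet. A plausible implementation that
# matches the call signature used above, ConvBlock()(x, features, kernel,
# stride=...), might look like this sketch (an assumption, not the original):
class ConvBlock(eg.Module):
    def __call__(self, x, features, kernel, stride=1):
        x = eg.Conv(features, list(kernel), strides=[stride, stride])(x)
        x = eg.BatchNorm()(x)
        return jax.nn.relu(x)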
def test_distributed_init(self):
    n_devices = jax.device_count()
    batch_size = 5 * n_devices

    x = np.random.uniform(size=(batch_size, 1))
    y = 1.4 * x + 0.1 * np.random.uniform(size=(batch_size, 2))

    model = eg.Model(
        eg.Linear(2),
        loss=[eg.losses.MeanSquaredError()],
    )

    model = model.distributed()
    model.init_on_batch(x)

    # distributed() replicates the parameters, adding a leading device axis
    assert model.module.kernel.shape == (n_devices, 1, 2)
    assert model.module.bias.shape == (n_devices, 2)
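# Why the leading device axis matters: pmap-style training expects every batch
# to carry one slice per device. A minimal (hypothetical) helper to shard a
# batch whose size is divisible by the device count:
def shard(batch: np.ndarray) -> np.ndarray:
    n = jax.device_count()
    return batch.reshape(n, batch.shape[0] // n, *batch.shape[1:])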
def test_saved_model(self):
    with TemporaryDirectory() as model_dir:
        model = eg.Model(module=eg.Linear(4))

        x = np.random.uniform(size=(5, 6))

        model.saved_model(x, model_dir, batch_size=[1, 2, 4, 8])

        output = str(sh.ls(model_dir))

        assert "saved_model.pb" in output
        assert "variables" in output

        # the exported model must load back with TensorFlow
        saved_model = tf.saved_model.load(model_dir)
        assert saved_model is not None
def test_saved_model_poly(self):
    with TemporaryDirectory() as model_dir:
        model = eg.Model(module=eg.Linear(4))

        x = np.random.uniform(size=(5, 6)).astype(np.float32)

        # batch_size=None exports a polymorphic (variable) batch dimension
        model.saved_model(x, model_dir, batch_size=None)

        output = str(sh.ls(model_dir))

        assert "saved_model.pb" in output
        assert "variables" in output

        saved_model = tf.saved_model.load(model_dir)

        # change the batch size
        x = np.random.uniform(size=(3, 6)).astype(np.float32)
        y = saved_model(x)

        assert y.shape == (3, 4)
def test_cloudpickle(self):
    model = eg.Model(
        module=eg.Linear(10),
        loss=[
            eg.losses.Crossentropy(),
            eg.regularizers.L2(1e-4),
        ],
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adamw(1e-3),
        eager=True,
    )

    X = np.random.uniform(size=(5, 2))
    y = np.random.randint(10, size=(5,))

    y0 = model.predict(X)

    with TemporaryDirectory() as model_dir:
        model.save(model_dir)
        newmodel = eg.load(model_dir)

    y1 = newmodel.predict(X)

    assert np.all(y0 == y1)
def __call__(self, x):
    x = eg.Linear(10)(x)
    # block gradients from flowing back through this module's output
    x = jax.lax.stop_gradient(x)
    return x
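# Plain-JAX sketch of what stop_gradient does: the stopped value behaves as a
# constant under differentiation, so upstream parameters receive zero gradient
# through this path. Self-contained example (names here are illustrative):
import jax
import jax.numpy as jnp

def f(w, x):
    return jnp.sum(jax.lax.stop_gradient(w * x))

print(jax.grad(f)(2.0, 3.0))  # 0.0: the gradient is blocked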