def __call__(
    self,
    input_ids: jnp.ndarray,
    input_mask: jnp.ndarray,
    type_ids: jnp.ndarray,
    *,
    deterministic: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Runs the embedding stack, every encoder layer and the pooler.

    Args:
      input_ids: Token ids; passed to both the word and the position
        embedding modules.
      input_mask: Mask over the inputs; cast to int32 and handed to every
        encoder layer.
      type_ids: Token type ids; passed to the type embedding module.
      deterministic: Forwarded unchanged to the embedding dropout and to each
        encoder layer (presumably disables dropout when True -- confirm
        against the layer definitions).

    Returns:
      A `(hidden_states, pooled_output)` tuple: the output of the last
      encoder layer, and `tanh(pooler(hidden_states[:, 0]))`.
    """
    # The three embedding contributions are summed before normalization.
    summed = (
        self.word_embeddings(input_ids)
        + self.position_embeddings(input_ids)
        + self.type_embeddings(type_ids)
    )
    normalized = self.embeddings_layer_norm(summed)
    hidden_states = self.embeddings_dropout(
        normalized, deterministic=deterministic
    )

    int_mask = input_mask.astype(jnp.int32)
    for layer in self.encoder_layers:
        hidden_states = layer(
            hidden_states, int_mask, deterministic=deterministic
        )

    # Pool from index 0 along the second axis (presumably the leading
    # classification token -- confirm with the tokenizer setup).
    leading = hidden_states[:, 0]
    return hidden_states, jnp.tanh(self.pooler(leading))
def mean_absolute_percentage_error(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean absolute percentage error (MAPE) between labels and predictions.

    The absolute error is divided by the magnitude of the true value (floored at
    `types.EPSILON`), averaged over the last dimension and multiplied by 100.

    ```python
    loss = 100 * mean(abs((y_pred - y_true) / maximum(abs(y_true), EPSILON)), axis=-1)
    ```

    Usage:

    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_absolute_percentage_error(y_true, y_pred)

    assert loss.shape == (2,)
    ```

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean absolute percentage error values. shape = `[batch_size, d0, .. dN-1]`.
    """
    # Labels are brought to the prediction dtype before any arithmetic.
    labels = y_true.astype(y_pred.dtype)
    denominator = jnp.maximum(jnp.abs(labels), types.EPSILON)
    relative_error = jnp.abs((y_pred - labels) / denominator)
    return 100.0 * jnp.mean(relative_error, axis=-1)
def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray:
    """Small CNN: three stride-2 3x3 convolutions, spatial mean, 10-unit head.

    Args:
      x: Input batch; cast to float32 and divided by 255 (assumes 8-bit
        pixel values laid out with the spatial axes at positions 1 and 2 --
        confirm with the caller).
      training: Dropout layers receive `deterministic=not training` and the
        batch-norm layer receives `use_running_average=not training`.

    Returns:
      The output of the final 10-unit dense layer.
    """
    inference = not training

    # Divide by 255 before the first convolution.
    h = x.astype(jnp.float32) / 255.0

    # Stage 1: conv -> dropout -> relu
    h = linen.Conv(32, [3, 3], strides=[2, 2])(h)
    h = linen.Dropout(0.05, deterministic=inference)(h)
    h = jax.nn.relu(h)

    # Stage 2: conv -> batch norm -> dropout -> relu
    h = linen.Conv(64, [3, 3], strides=[2, 2])(h)
    h = linen.BatchNorm(use_running_average=inference)(h)
    h = linen.Dropout(0.1, deterministic=inference)(h)
    h = jax.nn.relu(h)

    # Stage 3: a bare convolution, no activation.
    h = linen.Conv(128, [3, 3], strides=[2, 2])(h)

    # Average over axes 1 and 2, then classify.
    pooled = h.mean(axis=(1, 2))
    return linen.Dense(10)(pooled)
def call(self, image: jnp.ndarray, training: bool):
    """CNN classifier built from four conv blocks, a spatial mean and a linear head.

    Arguments:
        image: Input batch; cast to float32 and divided by 255 (assumes 8-bit
            pixel values -- confirm with the caller).
        training: Passed to the batch-normalization and dropout layers of
            every conv block.

    Returns:
        The output of the final 10-unit linear layer.
    """

    @elegy.to_module
    def ConvBlock(x, units, kernel, stride=1):
        # conv -> batch norm -> dropout -> relu
        x = elegy.nn.Conv2D(units, kernel, stride=stride, padding="same")(x)
        x = elegy.nn.BatchNormalization()(x, training)
        x = elegy.nn.Dropout(0.2)(x, training)
        return jax.nn.relu(x)

    x = image.astype(jnp.float32) / 255.0

    # (units, stride) for each block; the first block keeps the resolution.
    for units, stride in ((32, 1), (64, 2), (64, 2), (128, 2)):
        x = ConvBlock()(x, units, [3, 3], stride=stride)

    # Global average pooling over axes 1 and 2.
    x = jnp.mean(x, axis=(1, 2))

    # Linear classification head.
    return elegy.nn.Linear(10)(x)
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Small CNN: three stride-2 3x3 convolutions, spatial mean, 10-unit head.

    Args:
      x: Input batch; cast to float32 and divided by 255 (assumes 8-bit
        pixel values with the spatial axes at positions 1 and 2 -- confirm
        with the caller).

    Returns:
      The output of the final 10-unit linear layer.
    """
    # Divide by 255 before the first convolution.
    h = x.astype(jnp.float32) / 255.0

    # Stage 1: conv -> dropout -> relu
    h = eg.Conv(32, [3, 3], strides=[2, 2])(h)
    h = eg.Dropout(0.05)(h)
    h = jax.nn.relu(h)

    # Stage 2: conv -> batch norm -> dropout -> relu
    h = eg.Conv(64, [3, 3], strides=[2, 2])(h)
    h = eg.BatchNorm()(h)
    h = eg.Dropout(0.1)(h)
    h = jax.nn.relu(h)

    # Stage 3: a bare convolution, no activation.
    h = eg.Conv(128, [3, 3], strides=[2, 2])(h)

    # Average over axes 1 and 2, then classify.
    pooled = h.mean(axis=(1, 2))
    return eg.Linear(10)(pooled)
def precision(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    threshold: jnp.ndarray,
    class_id: jnp.ndarray,
    sample_weight: jnp.ndarray,
    true_positives: ReduceConfusionMatrix,
    false_positives: ReduceConfusionMatrix,
) -> jnp.ndarray:
    """Computes `tp / (tp + fp)` from thresholded predictions.

    Arguments:
        y_true: Ground truth values.
        y_pred: Predicted scores; binarized with `y_pred > threshold` and then
            cast to the dtype of `y_true`.
        threshold: Decision threshold applied to `y_pred`.
        class_id: Currently unused (see the TODO below).
        sample_weight: Forwarded to both confusion-matrix reducers.
        true_positives: Callable invoked with `y_true`, `y_pred` and
            `sample_weight` keyword arguments; its result is the numerator.
        false_positives: Callable invoked with the same keyword arguments.

    Returns:
        The precision, with NaN results (e.g. `0 / 0`) replaced through
        `jnp.nan_to_num`.
    """
    # TODO: class_id behavior
    binarized = (y_pred > threshold).astype(jnp.float32)
    if binarized.dtype != y_true.dtype:
        binarized = binarized.astype(y_true.dtype)

    reducer_inputs = dict(
        y_true=y_true, y_pred=binarized, sample_weight=sample_weight
    )
    tp = true_positives(**reducer_inputs)
    fp = false_positives(**reducer_inputs)

    return jnp.nan_to_num(jnp.divide(tp, tp + fp))
def quantize(x: jnp.ndarray, to_type: str):
    """Converts `x` to one of the supported storage dtypes.

    Args:
      x: Array to convert.
      to_type: One of "float16", "float32", "uint16" or "uint8".

    Returns:
      A `(to_type, converted)` pair. Float targets are a plain `astype`;
      integer targets go through `int_quantize_jit` with the maximum value
      representable by the target type (255 or 65535).
    """
    assert to_type in ["float16", "float32", "uint16", "uint8"]

    # Float targets need no scaling.
    if "int" not in to_type:
        return to_type, x.astype(to_type)

    bits = 8 if to_type == "uint8" else 16
    max_int = 2 ** bits - 1
    return to_type, int_quantize_jit(x, max_int, to_type)
def fit(
    self,
    df: pd.DataFrame,
    sampler: str = "NUTS",
    rng_key: np.ndarray = None,
    sampler_kwargs: typing.Dict[str, typing.Any] = None,
    **mcmc_kwargs,
):
    """Fit the model to a DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        Source dataframe. Stored on the instance as ``self.df``.

    sampler : str
        Numpyro sampler name, case-insensitive: "NUTS" or "HMC". Default NUTS.

    rng_key : two-element ndarray.
        Optional rng key, cast to uint32 when given. If omitted, a key is
        obtained from ``self.split_rand_key()``.

    sampler_kwargs :
        Passed to the numpyro sampler selected, on top of ``model=self.model``.

    **mcmc_kwargs :
        Passed to numpyro.infer.MCMC, overriding the defaults
        ``num_warmup=500`` and ``num_samples=1000``.

    Returns
    -------
    Model
        The fitted model (``self``).

    Raises
    ------
    exceptions.AlreadyFittedError
        If the model has already been fitted.
    ValueError
        If ``sampler`` is neither NUTS nor HMC.

    """
    if self.fitted:
        raise exceptions.AlreadyFittedError(self)

    sampler_name = sampler.upper()
    if sampler_name not in ("NUTS", "HMC"):
        raise ValueError("Invalid sampler, try NUTS or HMC.")
    kernel_cls = getattr(infer, sampler_name)

    # Keep a reference to the data the model was fitted on.
    self.df = df

    # Caller-supplied options take precedence over the defaults.
    run_options = {"num_warmup": 500, "num_samples": 1000, **mcmc_kwargs}
    kernel_options = {"model": self.model, **(sampler_kwargs or {})}
    mcmc = infer.MCMC(kernel_cls(**kernel_options), **run_options)

    if rng_key is None:
        key = self.split_rand_key()
    else:
        key = rng_key.astype("uint32")
    mcmc.run(key, df=df)

    # Record the draws and mark the model as fitted.
    self.samples = mcmc.get_samples(group_by_chain=True)
    self.fitted = True

    return self
def __call__(self, x: jnp.ndarray):
    """MLP: flatten, two relu hidden layers (`n1`, `n2`), 10 outputs.

    Args:
      x: Input batch; cast to float32 and divided by 255 (assumes 8-bit
        pixel values -- confirm with the caller). Every axis after the first
        is merged into one.

    Returns:
      The output of the final 10-unit linear layer.
    """
    scaled = x.astype(jnp.float32) / 255.0
    h = einops.rearrange(scaled, "batch ... -> batch (...)")

    # Hidden layers share the same linear -> relu shape.
    for width in (self.n1, self.n2):
        h = jax.nn.relu(eg.nn.Linear(width)(h))

    return eg.nn.Linear(10)(h)
def sample_posterior_predictive(
    self,
    df: pd.DataFrame,
    hdpi: bool = False,
    hdpi_interval: float = 0.9,
    rng_key: np.ndarray = None,
) -> typing.Union[pd.Series, pd.DataFrame]:
    """Obtain samples from the posterior predictive.

    Parameters
    ----------
    df : pd.DataFrame
        Source dataframe.

    hdpi : bool
        Option to include lower/upper bound of the highest posterior density
        interval. Returns a dataframe if true, a series if false.
        Default False.

    hdpi_interval : float
        HDPI width. Default 0.9.

    rng_key : two-element ndarray.
        Optional rng key, cast to uint32 when given. If omitted, a key is
        obtained from ``self.split_rand_key()``.

    Returns
    -------
    pd.Series or pd.DataFrame
        Forecasts (mean of the predictive draws over axis 0), indexed like
        ``df``. A series named after the dv if no HDPI; a dataframe with the
        extra columns ``hdpi_lower`` and ``hdpi_upper`` otherwise.

    Raises
    ------
    exceptions.NullDataFound
        If the transformed dataframe contains columns with null data.

    """
    if rng_key is None:
        key = self.split_rand_key()
    else:
        key = rng_key.astype("uint32")

    # Refuse to predict on data that has nulls after transformation.
    null_cols = columns_with_null_data(self.transform(df))
    if null_cols:
        raise exceptions.NullDataFound(*null_cols)

    predictive = infer.Predictive(self.model, self.samples_flat)
    draws = predictive(key, df=df)[self.dv]

    if not hdpi:
        return pd.Series(draws.mean(axis=0), index=df.index, name=self.dv)

    bounds = diagnostics.hpdi(draws, hdpi_interval)
    columns = {
        self.dv: draws.mean(axis=0),
        "hdpi_lower": bounds[0, :],
        "hdpi_upper": bounds[1, :],
    }
    return pd.DataFrame(columns, index=df.index)
def accuracy(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """Element-wise match indicator between labels and predictions.

    Arguments:
        y_true: Ground truth values.
        y_pred: Predicted values; cast to the dtype of `y_true` before the
            comparison.

    Returns:
        A float32 array holding 1.0 where the (cast) prediction equals the
        label and 0.0 elsewhere. No reduction is applied.
    """
    # NOTE: no shape-compatibility check is performed here; the arrays are
    # compared under the usual broadcasting rules.
    predictions = y_pred.astype(y_true.dtype)
    return jnp.equal(y_true, predictions).astype(jnp.float32)
def __call__(self, x: jnp.ndarray) -> VAEOutput:
    """Encodes `x`, samples a latent, decodes it and samples an image.

    Args:
      x: Input batch; cast to float32 before encoding.

    Returns:
      `VAEOutput(image, mean, stddev, logits)` where `image` is a Bernoulli
      sample drawn with probabilities `sigmoid(logits)`.
    """
    inputs = x.astype(jnp.float32)
    mean, stddev = Encoder(self._hidden_size, self._latent_size)(inputs)

    # Reparameterized latent sample: mean + stddev * standard normal noise.
    noise = jax.random.normal(hk.next_rng_key(), mean.shape)
    latent = mean + stddev * noise

    logits = Decoder(self._hidden_size, self._output_shape)(latent)
    probabilities = jax.nn.sigmoid(logits)
    sampled_image = jax.random.bernoulli(hk.next_rng_key(), probabilities)

    return VAEOutput(sampled_image, mean, stddev, logits)
def __call__(self, x: jnp.ndarray) -> dict:
    """Encodes `x`, decodes the latent and samples a binary image.

    Args:
      x: Input batch; cast to float32 before encoding.

    Returns:
      A dict with keys `image` (Bernoulli sample), `logits` (decoder output)
      and `det_image` (`sigmoid(logits)`, the sampling probabilities).
    """
    key_seq = eg.KeySeq()

    latent = Encoder(self.hidden_size, self.latent_size)(x.astype(jnp.float32))
    logits = Decoder(self.hidden_size, self.output_shape)(latent)

    probabilities = jax.nn.sigmoid(logits)
    sampled_image = jax.random.bernoulli(key_seq(), probabilities)

    return {"image": sampled_image, "logits": logits, "det_image": probabilities}
def __call__(self, x: jnp.ndarray):
    """MLP: flatten, two relu hidden layers (`n1`, `n2`), 10 outputs.

    Args:
      x: Input batch; cast to float32 and divided by 255 (assumes 8-bit
        pixel values -- confirm with the caller).

    Returns:
      The output of the final 10-unit linear layer.
    """
    h = eg.Flatten()(x.astype(jnp.float32) / 255.0)

    # Hidden layers share the same linear -> relu shape.
    for width in (self.n1, self.n2):
        h = jax.nn.relu(eg.Linear(width)(h))

    return eg.Linear(10)(h)
def call(self, image: jnp.ndarray, training: bool):
    """Three-layer MLP classifier with batch normalization after the first layer.

    Arguments:
        image: Input batch; cast to float32, divided by 255 and reshaped to
            `[batch_size, -1]` (assumes 8-bit pixel values -- confirm with
            the caller).
        training: Forwarded to the batch-normalization layer so that it
            follows the mode requested by the caller.

    Returns:
        The output of the final 10-unit linear layer.
    """
    x = image.astype(jnp.float32) / 255.0
    x = jnp.reshape(x, [x.shape[0], -1])

    x = elegy.nn.Linear(self.n1)(x)
    # Pass the mode explicitly: without it the `training` argument of this
    # method would be ignored and the layer would fall back to its own default.
    x = elegy.nn.BatchNormalization()(x, training)
    x = jax.nn.relu(x)

    x = elegy.nn.Linear(self.n2)(x)
    x = jax.nn.relu(x)

    return elegy.nn.Linear(10)(x)
def call(self, image: jnp.ndarray):
    """MLP: flatten, two relu hidden layers (`n1`, `n2`), 10 outputs.

    Arguments:
        image: Input batch; cast to float32 and divided by 255 (assumes 8-bit
            pixel values -- confirm with the caller).

    Returns:
        A dict with the single key `outputs` holding the result of the final
        10-unit linear layer.
    """
    scaled = image.astype(jnp.float32) / 255.0

    layers = [
        hk.Flatten(),
        hk.Linear(self.n1),
        jax.nn.relu,
        hk.Linear(self.n2),
        jax.nn.relu,
        hk.Linear(10),
    ]
    network = hk.Sequential(layers)

    return {"outputs": network(scaled)}
def call(self, image: jnp.ndarray):
    """MLP: flatten, two relu hidden layers (`n1`, `n2`), 10 outputs.

    Arguments:
        image: Input batch; cast to float32 and divided by 255 (assumes 8-bit
            pixel values -- confirm with the caller).

    Returns:
        The output of the final 10-unit linear layer.
    """
    scaled = image.astype(jnp.float32) / 255.0

    layers = (
        elegy.nn.Flatten(),
        elegy.nn.Linear(self.n1),
        jax.nn.relu,
        elegy.nn.Linear(self.n2),
        jax.nn.relu,
        elegy.nn.Linear(10),
    )
    network = elegy.nn.sequential(*layers)

    return network(scaled)
def __call__(self, image: jnp.ndarray):
    """Dense autoencoder: n1 -> n2 -> n1 bottleneck, output shaped like `image`.

    Args:
      image: Input batch; cast to float32 and divided by 255 (assumes 8-bit
        pixel values -- confirm with the caller).

    Returns:
      The reconstruction, `sigmoid(...) * 255`, reshaped to `image.shape`.
    """
    x = image.astype(jnp.float32) / 255.0
    x = eg.Flatten()(x)

    # Size the output layer from the flattened input rather than from the
    # last two image axes, so the final reshape also works for inputs with
    # more than two non-batch axes (e.g. a channel axis). Assumes eg.Flatten
    # merges every non-batch axis into the last one -- TODO confirm.
    n_features = x.shape[-1]

    x = eg.Linear(self.n1)(x)
    x = jax.nn.relu(x)
    x = eg.Linear(self.n2)(x)
    x = jax.nn.relu(x)
    x = eg.Linear(self.n1)(x)
    x = jax.nn.relu(x)

    x = eg.Linear(n_features)(x)
    # Map back to the 0..255 range of the (assumed) input pixels.
    x = jax.nn.sigmoid(x) * 255
    return x.reshape(image.shape)
def call(self, image: jnp.ndarray):
    """Dense autoencoder: n1 -> n2 -> n1 bottleneck, output shaped like `image`.

    Arguments:
        image: Input batch; cast to float32 and divided by 255 (assumes 8-bit
            pixel values -- confirm with the caller).

    Returns:
        The sigmoid reconstruction reshaped to the input shape and multiplied
        by 255.
    """
    scaled = image.astype(jnp.float32) / 255.0
    flat = elegy.nn.Flatten()(scaled)

    # The last layer restores the flattened input width.
    autoencoder = elegy.nn.sequential(
        elegy.nn.Linear(self.n1),
        jax.nn.relu,
        elegy.nn.Linear(self.n2),
        jax.nn.relu,
        elegy.nn.Linear(self.n1),
        jax.nn.relu,
        elegy.nn.Linear(flat.shape[-1]),
        jax.nn.sigmoid,
    )
    reconstruction = autoencoder(flat)

    return reconstruction.reshape(scaled.shape) * 255
def call(self, x: jnp.ndarray):
    """Patch-based transformer classifier.

    The input is cut into 7x7 patches, each patch is linearly embedded, a
    zero "predict" token is prepended, learned positional embeddings are
    added, and the transformer encoder output at the predict-token position
    feeds a 10-unit linear head.

    Arguments:
        x: Input of shape `[batch, height, width]` with height and width
            divisible by 7; cast to float32 and divided by 255 (assumes
            8-bit pixel values -- confirm with the caller).

    Returns:
        Logits of shape `[batch, 10]`.
    """
    n_examples = x.shape[0]

    # Scale, then split into flattened 7x7 patches.
    pixels = x.astype(jnp.float32) / 255.0
    patches = einops.rearrange(
        pixels,
        "batch (h1 h2) (w1 w2) -> batch (h1 w1) (h2 w2)",
        h2=7,
        w2=7,
    )
    tokens = elegy.nn.Linear(self.size)(patches)

    # Prepend the all-zeros predict token.
    predict_slot = jnp.zeros(shape=[n_examples, 1, self.size])
    tokens = jnp.concatenate([predict_slot, tokens], axis=1)

    # One learned embedding per token position, shared across the batch.
    def init_positions():
        return elegy.initializers.TruncatedNormal()(tokens.shape[-2:], jnp.float32)

    positions = self.add_parameter("positional_embeddings", init_positions)
    positions = einops.repeat(positions, "... -> batch ...", batch=n_examples)
    tokens = tokens + positions

    def make_layer():
        return elegy.nn.transformers.TransformerEncoderLayer(
            head_size=self.size,
            num_heads=self.num_heads,
            dropout=self.dropout,
        )

    encoded = elegy.nn.transformers.TransformerEncoder(
        make_layer,
        num_layers=self.num_layers,
    )(tokens)

    # Classify from the predict-token position only.
    logits = elegy.nn.Linear(10)(encoded[:, 0])
    return logits
def __call__(self, state: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
    """Maps a (state, action) pair to a single output value.

    Args:
      state: An input state; cast to float32 and flattened.
      action: An action; flattened and concatenated after the state.

    Returns:
      The output of a 1-unit dense layer applied after `num_layers` hidden
      dense + relu layers of `hidden_units` units each.
    """
    init = jax.nn.initializers.glorot_uniform()

    # Both inputs are flattened to 1-D, so this operates on a single
    # (unbatched) example.
    flat_state = state.astype(jnp.float32).reshape(-1)
    flat_action = action.reshape(-1)
    features = jnp.concatenate((flat_state, flat_action))

    for _ in range(self.num_layers):
        hidden = nn.Dense(features=self.hidden_units, kernel_init=init)
        features = nn.relu(hidden(features))

    return nn.Dense(features=1, kernel_init=init)(features)
def _samplewise_log_loss(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """Log loss of a single sample, summed over its classes.

    Based on:
    https://github.com/scikit-learn/scikit-learn/blob/ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn/metrics/_classification.py#L2123

    Arguments:
        y_true: Label(s) of one sample: either a scalar (binary problem) or a
            vector of per-class targets.
        y_pred: Predicted probability/probabilities, laid out like `y_true`.

    Returns:
        The scalar `sum(y_true * -log(y_pred))`, with `y_pred` cast to
        float32 and clipped to `[1e-15, 1 - 1e-15]`.
    """
    # A 0-d input is a binary problem: lift it to shape (1, 1).
    if y_true.ndim == 0:
        y_true = y_true.reshape(1)[:, jnp.newaxis]
        y_pred = y_pred.reshape(1)[:, jnp.newaxis]

    # A single value is expanded to the two-class form [1 - p, p] so the
    # loss accounts for both classes.
    if y_true.shape[0] == 1:
        y_true = jnp.append(1 - y_true, y_true)
        y_pred = jnp.append(1 - y_pred, y_pred)

    # Clip the probabilities before taking the log.
    eps = 1e-15
    clipped = jnp.clip(y_pred.astype(jnp.float32), eps, 1 - eps)

    return jnp.sum(y_true * -jnp.log(clipped))
def actor(self, state: jnp.ndarray, key: jnp.ndarray) -> SacActorOutput:
    """Calls the SAC actor network.

    This can be called using network_def.apply(..., method=network_def.actor).

    Args:
      state: An input state; cast to float32 and flattened.
      key: A PRNGKey to use to sample an action from the actor's output
        distribution.

    Returns:
      A `SacActorOutput` built, in this order, from the mode of the action
      distribution, a sampled action, and the log-probability of the sampled
      action. Both actions are reshaped to `self.action_shape`.
    """
    features = state.astype(jnp.float32).reshape(-1)
    for layer in self._actor_layers:
        features = nn.relu(layer(features))

    # The final layer emits location and (log) scale side by side. Only a
    # diagonal covariance is produced, since a full covariance matrix would
    # be hard to keep PSD.
    loc, log_scale = jnp.split(self._actor_final_layer(features), 2)
    # Exponentiate so the scale is strictly positive.
    scale_diag = jnp.exp(log_scale)
    dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

    if self.action_limits is not None:
        low = jnp.asarray(self.action_limits[0], dtype=jnp.float32)
        high = jnp.asarray(self.action_limits[1], dtype=jnp.float32)
        center = (low + high) / 2.0
        half_range = (high - low) / 2.0
        # Squash the unbounded mode into [low, high] with tanh.
        mode = half_range * jnp.tanh(dist.mode()) + center
        dist = _transform_distribution(dist, center, half_range)
    else:
        mode = dist.mode()

    sampled_action = dist.sample(seed=key)
    log_probability = dist.log_prob(sampled_action)

    return SacActorOutput(
        jnp.reshape(mode, self.action_shape),
        jnp.reshape(sampled_action, self.action_shape),
        log_probability,
    )
def __call__(self, x: jnp.ndarray):
    """CNN classifier: four conv blocks, a spatial mean and a linear head.

    Args:
      x: Input batch; cast to float32 and divided by 255 (assumes 8-bit
        pixel values with the spatial axes at positions 1 and 2 -- confirm
        with the caller).

    Returns:
      The output of the final 10-unit linear layer.
    """
    h = x.astype(jnp.float32) / 255.0

    # First block uses ConvBlock's default stride; the rest downsample by 2.
    h = ConvBlock()(h, 32, (3, 3))
    for units in (64, 64, 128):
        h = ConvBlock()(h, units, (3, 3), stride=2)

    # Global average pooling over axes 1 and 2.
    pooled = jnp.mean(h, axis=(1, 2))

    # Linear classification head.
    return eg.Linear(10)(pooled)
def learning_schedule(
    global_step: jnp.ndarray,
    batch_size: int,
    base_learning_rate: float,
    total_steps: int,
    warmup_steps: int,
) -> float:
    """Linear warmup followed by cosine decay, with batch-size scaling.

    Args:
      global_step: Current training step.
      batch_size: Batch size; the base rate is scaled by `batch_size / 256`.
      base_learning_rate: Learning rate for a batch size of 256.
      total_steps: Total number of training steps.
      warmup_steps: Number of linear warmup steps; 0 disables warmup.

    Returns:
      The warmup rate while `global_step < warmup_steps`, otherwise the value
      of `_cosine_decay` over the remaining steps.
    """
    scaled_lr = base_learning_rate * batch_size / 256.

    if warmup_steps > 0:
        # Ramp linearly from 0 up to the scaled rate.
        progress = global_step.astype(jnp.float32) / int(warmup_steps)
        warmup_lr = progress * scaled_lr
    else:
        warmup_lr = scaled_lr

    cosine_lr = _cosine_decay(
        global_step - warmup_steps, total_steps - warmup_steps, scaled_lr
    )
    return jnp.where(global_step < warmup_steps, warmup_lr, cosine_lr)
def _encode_bow(self, bow: jnp.ndarray) -> jnp.ndarray:
    """Encode the bag-of-words into tensors that can be used by the transformer.

    Args:
      bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector.

    Returns:
      embeddings: [batch_size, bow_n_tokens, bow_embedding_dim] tensor.
    """
    n_rows = bow.shape[0]
    flat_width = self._bow_embedding_dim * self._bow_n_tokens

    # Project every bow vector to all token embeddings at once: [B, D * n].
    projected = hk.Linear(flat_width)(bow.astype(jnp.float32))
    normalized = transformer_block.layer_norm(jax.nn.gelu(projected))

    # Split the flat projection into one embedding per token.
    token_shape = [n_rows, self._bow_n_tokens, self._bow_embedding_dim]
    return jnp.reshape(normalized, token_shape)
def __call__(self, x: jnp.ndarray) -> VAEOutput:
    """Builds the variational and likelihood distributions and samples an image.

    Args:
      x: Input batch; cast to float32 before encoding.

    Returns:
      `VAEOutput(variational_distrib, likelihood_distrib, image)` where
      `image` is a sample from the likelihood distribution.
    """
    inputs = x.astype(jnp.float32)

    # q(z|x): diagonal Gaussian parameterized by the encoder.
    loc, scale = Encoder(self._hidden_size, self._latent_size)(inputs)
    posterior = distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale)
    latent = posterior.sample(seed=hk.next_rng_key())

    # p(x|z): independent Bernoullis over the decoder logits; every axis of
    # the output shape is treated as part of a single event.
    pixel_logits = Decoder(self._hidden_size, self._output_shape)(latent)
    event_ndims = len(self._output_shape)
    likelihood = distrax.Independent(
        distrax.Bernoulli(logits=pixel_logits),
        reinterpreted_batch_ndims=event_ndims,
    )

    # Draw an image from the likelihood.
    sampled_image = likelihood.sample(seed=hk.next_rng_key())

    return VAEOutput(posterior, likelihood, sampled_image)
def learning_schedule(
    global_step: jnp.ndarray,
    base_learning_rate: float,
    total_steps: int,
    warmup_steps: int,
    use_schedule: bool,
) -> float:
    """Optional linear warmup followed by cosine decay.

    Args:
      global_step: Current training step.
      base_learning_rate: Peak learning rate.
      total_steps: Total number of training steps.
      warmup_steps: Number of linear warmup steps; 0 disables warmup.
      use_schedule: When False, `base_learning_rate` is returned unchanged.

    Returns:
      `base_learning_rate` if the schedule is disabled; otherwise the warmup
      rate while `global_step < warmup_steps` and the value of
      `_cosine_decay` over the remaining steps afterwards.
    """
    # Constant rate requested: nothing to compute.
    if not use_schedule:
        return base_learning_rate

    if warmup_steps > 0:
        # Ramp linearly from 0 up to the base rate.
        progress = global_step.astype(jnp.float32) / int(warmup_steps)
        warmup_lr = progress * base_learning_rate
    else:
        warmup_lr = base_learning_rate

    cosine_lr = _cosine_decay(
        global_step - warmup_steps,
        total_steps - warmup_steps,
        base_learning_rate,
    )
    return jnp.where(global_step < warmup_steps, warmup_lr, cosine_lr)
def mean_squared_logarithmic_error(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean squared logarithmic error between labels and predictions.

    Both inputs are floored at `types.EPSILON` before the logarithm is taken:

    ```python
    loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)
    ```

    Usage:

    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_squared_logarithmic_error(y_true, y_pred)

    assert loss.shape == (2,)

    first_log = jnp.log(jnp.maximum(y_true, types.EPSILON) + 1.0)
    second_log = jnp.log(jnp.maximum(y_pred, types.EPSILON) + 1.0)
    assert jnp.array_equal(loss, jnp.mean(jnp.square(first_log - second_log), axis=-1))
    ```

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared logarithmic error values. shape = `[batch_size, d0, .. dN-1]`.
    """
    # Labels are brought to the prediction dtype before any arithmetic.
    labels = y_true.astype(y_pred.dtype)

    log_true, log_pred = (
        jnp.log(jnp.maximum(values, types.EPSILON) + 1.0)
        for values in (labels, y_pred)
    )

    return jnp.mean(jnp.square(log_true - log_pred), axis=-1)
def mean_squared_error(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean squared error between labels and predictions.

    The squared difference between the inputs is averaged over the last
    dimension:

    ```python
    loss = mean(square(y_true - y_pred), axis=-1)
    ```

    Usage:

    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_squared_error(y_true, y_pred)

    assert loss.shape == (2,)
    assert jnp.array_equal(loss, jnp.mean(jnp.square(y_true - y_pred), axis=-1))
    ```

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared error values. shape = `[batch_size, d0, .. dN-1]`.
    """
    # Labels are brought to the prediction dtype before subtracting.
    residual = y_pred - y_true.astype(y_pred.dtype)
    return jnp.mean(jnp.square(residual), axis=-1)