def test_fn(
    self,
    x: tp.Any = (),
    y: tp.Any = None,
    sample_weight: tp.Optional[np.ndarray] = None,
    class_weight: tp.Optional[np.ndarray] = None,
    get_gradients: bool = False,
) -> tp.Tuple[np.ndarray, tp.Dict, tp.Optional[tp.Dict]]:

    if get_gradients:
        # Differentiate the loss w.r.t. the module's trainable parameters.
        (loss, y_pred, total_loss_logs), grads = module.value_and_grad(
            self.loss_fn, modules=self.module
        )(x, y, sample_weight, class_weight)
    else:
        grads = None
        loss, y_pred, total_loss_logs = self.loss_fn(
            x, y, sample_weight, class_weight
        )

    # Update the metrics with this step's predictions and loss logs.
    logs = self.metrics(
        total_loss_logs,
        x=x,
        y_true=y,
        y_pred=y_pred,
        sample_weight=sample_weight,
        class_weight=class_weight,
        training=module.is_training(),
        parameters=self.module.get_parameters(trainable=True),
        states=self.module.get_parameters(trainable=False),
    )

    return loss, logs, grads
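# For reference: `module.value_and_grad` plays the role of `jax.value_and_grad`
# with `has_aux=True` -- the wrapped loss function returns auxiliary outputs
# alongside the scalar loss, and gradients are taken w.r.t. the trainable
# parameters. A minimal sketch of that pattern in plain JAX (the toy model and
# names below are illustrative, not this library's API):

import jax
import jax.numpy as jnp

def toy_loss_fn(params, x, y):
    y_pred = params["w"] * x                 # toy linear model
    loss = jnp.mean((y_pred - y) ** 2)
    return loss, y_pred                      # (scalar loss, auxiliary output)

params = {"w": jnp.array(2.0)}
x, y = jnp.array([1.0, 2.0]), jnp.array([2.0, 5.0])

# has_aux=True: returns ((loss, aux), grads), with grads w.r.t. `params`.
(loss, y_pred), grads = jax.value_and_grad(toy_loss_fn, has_aux=True)(params, x, y)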
def call(
    self,
    x: np.ndarray,
    training: tp.Optional[bool] = None,
    rng: tp.Optional[np.ndarray] = None,
) -> jnp.ndarray:
    """
    Arguments:
        x: The value to be dropped out.
        training: Whether training is currently happening.
        rng: Optional RNGKey.

    Returns:
        x but dropped out and scaled by `1 / (1 - rate)`.
    """
    if training is None:
        training = module.is_training()

    return hk.dropout(
        rng=rng if rng is not None else module.next_rng_key(),
        rate=self.rate if training else 0.0,
        x=x,
    )
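# `hk.dropout` implements standard inverted dropout: each element is kept with
# probability (1 - rate) and the survivors are rescaled by 1 / (1 - rate) so
# the expected activation is unchanged. A minimal sketch of the same
# computation (illustrative, not Haiku's implementation):

import jax
import jax.numpy as jnp

def dropout_sketch(rng, rate: float, x: jnp.ndarray) -> jnp.ndarray:
    # Sample a keep-mask: True with probability (1 - rate).
    keep = jax.random.bernoulli(rng, 1.0 - rate, shape=x.shape)
    # Zero the dropped elements and rescale the rest.
    return jnp.where(keep, x / (1.0 - rate), 0.0)

y = dropout_sketch(jax.random.PRNGKey(0), rate=0.5, x=jnp.ones((2, 4)))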
def loss_fn(
    self,
    x: tp.Any = (),
    y: tp.Any = None,
    sample_weight: tp.Optional[np.ndarray] = None,
    class_weight: tp.Optional[np.ndarray] = None,
):
    y_pred = self.predict_fn(x)

    if self.loss is not None:
        loss_logs = self.loss(
            x=x,
            y_true=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            class_weight=class_weight,
            training=module.is_training(),
            parameters=self.module.get_parameters(trainable=True),
            states=self.module.get_parameters(trainable=False),
        )
    else:
        loss_logs = {}

    # Losses registered via hooks during the forward pass (may be empty).
    hooks_losses_logs = module.get_losses()

    if hooks_losses_logs is None:
        hooks_losses_logs = {}

    # The total loss is the sum of user-defined and hook losses.
    loss = sum(loss_logs.values()) + sum(hooks_losses_logs.values())

    total_loss_logs = {}
    total_loss_logs.update(hooks_losses_logs)
    total_loss_logs.update(loss_logs)
    total_loss_logs["loss"] = loss

    return loss, y_pred, total_loss_logs
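# Note the merge order above: `hooks_losses_logs` is written first, so on a
# key collision the user's `loss_logs` entry wins, and the "loss" key always
# holds the combined scalar. A toy run of the aggregation (values made up):

loss_logs = {"crossentropy_loss": 0.7}
hooks_losses_logs = {"l2_regularization": 0.1}   # as returned by module.get_losses()

loss = sum(loss_logs.values()) + sum(hooks_losses_logs.values())

total_loss_logs = dict(hooks_losses_logs)
total_loss_logs.update(loss_logs)                # user losses override on collision
total_loss_logs["loss"] = loss                   # {"l2_regularization": 0.1,
                                                 #  "crossentropy_loss": 0.7,
                                                 #  "loss": 0.8}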
def predict_fn(self, x: tp.Any = ()):
    # Unpack `x` into positional and keyword arguments, then call the module,
    # forwarding only the arguments its signature accepts.
    x_args, x_kwargs = utils.get_input_args(x, training=module.is_training())
    y_pred = utils.inject_dependencies(self)(*x_args, **x_kwargs)

    return y_pred
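# `utils.inject_dependencies` (library internals) calls the target while
# forwarding only the keyword arguments its signature declares. A simplified
# sketch of that idea -- an assumption about the helper, not its actual code:

import inspect

def inject_dependencies_sketch(fn):
    params = inspect.signature(fn).parameters

    def wrapped(*args, **kwargs):
        # Drop any keyword argument that `fn` does not declare.
        accepted = {k: v for k, v in kwargs.items() if k in params}
        return fn(*args, **accepted)

    return wrapped

def forward(x, training):
    return x                                     # toy module

# `rng` is silently dropped because `forward` does not declare it.
y = inject_dependencies_sketch(forward)(1.0, training=True, rng=None)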
def call(
    self,
    inputs: jnp.ndarray,
    training: tp.Optional[bool] = None,
    test_local_stats: bool = False,
    scale: tp.Optional[jnp.ndarray] = None,
    offset: tp.Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Computes the normalized version of the input.

    Args:
        inputs: An array, where the data format is ``[..., C]``.
        training: Whether training is currently happening.
        test_local_stats: Whether local stats are used when training=False.
        scale: An array up to n-D. The shape of this tensor must be
            broadcastable to the shape of ``inputs``. This is the scale
            applied to the normalized inputs. This cannot be passed in
            if the module was constructed with ``create_scale=True``.
        offset: An array up to n-D. The shape of this tensor must be
            broadcastable to the shape of ``inputs``. This is the offset
            applied to the normalized inputs. This cannot be passed in
            if the module was constructed with ``create_offset=True``.

    Returns:
        The array, normalized across all but the last dimension.
    """
    if training is None:
        training = module.is_training()

    if self.create_scale and scale is not None:
        raise ValueError(
            "Cannot pass `scale` at call time if `create_scale=True`."
        )
    if self.create_offset and offset is not None:
        raise ValueError(
            "Cannot pass `offset` at call time if `create_offset=True`."
        )

    channel_index = self.channel_index
    if channel_index < 0:
        channel_index += inputs.ndim

    if self.axis is not None:
        axis = self.axis
    else:
        axis = [i for i in range(inputs.ndim) if i != channel_index]

    if training or test_local_stats:
        cross_replica_axis = self.cross_replica_axis
        if self.cross_replica_axis:
            # Average the statistics across devices as well as the batch.
            mean = jnp.mean(inputs, axis, keepdims=True)
            mean = jax.lax.pmean(mean, cross_replica_axis)
            mean_of_squares = jnp.mean(inputs ** 2, axis, keepdims=True)
            mean_of_squares = jax.lax.pmean(mean_of_squares, cross_replica_axis)
            var = mean_of_squares - mean ** 2
        else:
            mean = jnp.mean(inputs, axis, keepdims=True)
            # This uses E[(X - E[X])^2].
            # TODO(tycai): Consider the faster, but possibly less stable
            # E[X^2] - E[X]^2 method.
            var = jnp.var(inputs, axis, keepdims=True)
    else:
        mean = self.mean_ema.average
        var = self.var_ema.average

    if training:
        self.mean_ema(mean)
        self.var_ema(var)

    w_shape = [1 if i in axis else inputs.shape[i] for i in range(inputs.ndim)]
    w_dtype = inputs.dtype

    if self.create_scale:
        scale = self.add_parameter("scale", w_shape, w_dtype, self.scale_init)
    elif scale is None:
        scale = np.ones([], dtype=w_dtype)

    if self.create_offset:
        offset = self.add_parameter("offset", w_shape, w_dtype, self.offset_init)
    elif offset is None:
        offset = np.zeros([], dtype=w_dtype)

    inv = scale * jax.lax.rsqrt(var + self.eps)
    return (inputs - mean) * inv + offset
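# The return line is the usual batch norm transform:
# (inputs - mean) * scale * rsqrt(var + eps) + offset. A minimal check of that
# math with plain JAX (scale=1, offset=0; shapes are illustrative):

import jax
import jax.numpy as jnp
import numpy as np

x = np.random.randn(8, 4, 3).astype(np.float32)   # data in [..., C] layout
axis = (0, 1)                                      # all but the channel axis
eps = 1e-5

mean = jnp.mean(x, axis, keepdims=True)
var = jnp.var(x, axis, keepdims=True)
normalized = (x - mean) * jax.lax.rsqrt(var + eps)

# Each channel now has ~zero mean and ~unit variance.
print(jnp.mean(normalized, axis), jnp.var(normalized, axis))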