Example 1
    def test_fn(
        self,
        x: tp.Any = (),
        y: tp.Any = None,
        sample_weight: tp.Optional[np.ndarray] = None,
        class_weight: tp.Optional[np.ndarray] = None,
        get_gradients: bool = False,
    ) -> tp.Tuple[np.ndarray, tp.Dict, tp.Optional[tp.Dict]]:

        if get_gradients:
            # Differentiate loss_fn with respect to the parameters of
            # self.module; the auxiliary outputs ride along with the loss.
            (loss, y_pred, total_loss_logs), grads = module.value_and_grad(
                self.loss_fn, modules=self.module
            )(x, y, sample_weight, class_weight)
        else:
            # Evaluation-only path: compute the loss without gradients.
            grads = None
            loss, y_pred, total_loss_logs = self.loss_fn(
                x, y, sample_weight, class_weight
            )

        logs = self.metrics(
            total_loss_logs,
            x=x,
            y_true=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            class_weight=class_weight,
            training=module.is_training(),
            parameters=self.module.get_parameters(trainable=True),
            states=self.module.get_parameters(trainable=False),
        )

        return loss, logs, grads
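
The `(loss, aux), grads` return convention above mirrors `jax.value_and_grad` with `has_aux=True`, which `module.value_and_grad` presumably wraps. A minimal self-contained sketch of the same pattern in plain JAX (the linear model and data are made up for illustration):

    import jax
    import jax.numpy as jnp

    def loss_fn(params, x, y):
        # Toy linear model; any function differentiable in `params` works.
        y_pred = x @ params["w"] + params["b"]
        loss = jnp.mean((y_pred - y) ** 2)
        # With has_aux=True, only the first element is differentiated;
        # the auxiliary outputs (predictions, logs) ride along.
        return loss, (y_pred, {"mse_loss": loss})

    params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
    x, y = jnp.ones((8, 3)), jnp.ones((8, 1))

    (loss, (y_pred, logs)), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    # grads has the same pytree structure as params: {"w": ..., "b": ...}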
Example 2
    def call(
        self,
        x: np.ndarray,
        training: tp.Optional[bool] = None,
        rng: tp.Optional[np.ndarray] = None,
    ) -> jnp.ndarray:
        """
        Arguments:
            x: The value to be dropped out.
            training: Whether training is currently happening.
            rng: Optional RNGKey.
        Returns:
            x but dropped out and scaled by `1 / (1 - rate)`.
        """
        if training is None:
            training = module.is_training()

        return hk.dropout(
            rng=rng if rng is not None else module.next_rng_key(),
            rate=self.rate if training else 0.0,
            x=x,
        )
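
For reference, a plain-JAX sketch of what inverted dropout computes (`hk.dropout` above is Haiku's implementation of the same idea; the shapes and rate here are arbitrary):

    import jax
    import jax.numpy as jnp

    def dropout(rng, rate, x):
        # Keep each element with probability (1 - rate) and scale survivors
        # by 1 / (1 - rate) so the expected output matches the input.
        keep = jax.random.bernoulli(rng, 1.0 - rate, shape=x.shape)
        return jnp.where(keep, x / (1.0 - rate), 0.0)

    x = jnp.ones((4, 4))
    y = dropout(jax.random.PRNGKey(0), 0.5, x)  # ~half zeros, survivors equal 2.0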
Example 3
    def loss_fn(
        self,
        x: tp.Any = (),
        y: tp.Any = None,
        sample_weight: tp.Optional[np.ndarray] = None,
        class_weight: tp.Optional[np.ndarray] = None,
    ):
        y_pred = self.predict_fn(x)

        if self.loss is not None:
            loss_logs = self.loss(
                x=x,
                y_true=y,
                y_pred=y_pred,
                sample_weight=sample_weight,
                class_weight=class_weight,
                training=module.is_training(),
                parameters=self.module.get_parameters(trainable=True),
                states=self.module.get_parameters(trainable=False),
            )
        else:
            loss_logs = {}

        hooks_losses_logs = module.get_losses()

        if hooks_losses_logs is None:
            hooks_losses_logs = {}

        loss = sum(loss_logs.values()) + sum(hooks_losses_logs.values())

        total_loss_logs = {}
        total_loss_logs.update(hooks_losses_logs)
        total_loss_logs.update(loss_logs)
        total_loss_logs["loss"] = loss

        return loss, y_pred, total_loss_logs
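
Note the merge order in `total_loss_logs`: hook losses are written first, so explicitly configured losses overwrite them on a key collision. A tiny sketch with hypothetical log values:

    # Hypothetical stand-ins for self.loss(...) and module.get_losses().
    loss_logs = {"crossentropy_loss": 0.7}
    hooks_losses_logs = {"l2_regularization_loss": 0.1}

    loss = sum(loss_logs.values()) + sum(hooks_losses_logs.values())  # 0.8

    total_loss_logs = {}
    total_loss_logs.update(hooks_losses_logs)  # hook losses first...
    total_loss_logs.update(loss_logs)          # ...then explicit losses, which win ties
    total_loss_logs["loss"] = loss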
Example 4
    def predict_fn(self, x: tp.Any = ()):
        # Unpack `x` into positional and keyword arguments, then call the
        # model, forwarding only the arguments its signature declares.
        x_args, x_kwargs = utils.get_input_args(x, training=module.is_training())
        y_pred = utils.inject_dependencies(self)(*x_args, **x_kwargs)

        return y_pred
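
`utils.get_input_args` and `utils.inject_dependencies` are framework helpers; judging by their use here, the second performs signature-based dependency injection, passing the model only the keyword arguments it declares. A simplified, hypothetical reimplementation:

    import inspect
    import typing as tp

    def inject_dependencies(fn: tp.Callable) -> tp.Callable:
        # Forward only the keyword arguments that `fn`'s signature accepts,
        # silently dropping the rest.
        accepted = set(inspect.signature(fn).parameters)

        def wrapper(*args, **kwargs):
            kwargs = {k: v for k, v in kwargs.items() if k in accepted}
            return fn(*args, **kwargs)

        return wrapper

    def model(x, training=False):
        return x

    y = inject_dependencies(model)(1.0, training=True, rng=None)  # `rng` is dropped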
Example 5
    def call(
        self,
        inputs: jnp.ndarray,
        training: tp.Optional[bool] = None,
        test_local_stats: bool = False,
        scale: tp.Optional[jnp.ndarray] = None,
        offset: tp.Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Computes the normalized version of the input.

        Args:
            inputs: An array, where the data format is ``[..., C]``.
            training: Whether training is currently happening.
            test_local_stats: Whether local stats are used when training=False.
            scale: An array up to n-D. The shape of this tensor must be broadcastable
                to the shape of ``inputs``. This is the scale applied to the normalized
                inputs. This cannot be passed in if the module was constructed with
                ``create_scale=True``.
            offset: An array up to n-D. The shape of this tensor must be broadcastable
                to the shape of ``inputs``. This is the offset applied to the normalized
                inputs. This cannot be passed in if the module was constructed with
                ``create_offset=True``.

        Returns:
            The array, normalized across all but the last dimension.
        """
        if training is None:
            training = module.is_training()

        if self.create_scale and scale is not None:
            raise ValueError("Cannot pass `scale` at call time if `create_scale=True`.")
        if self.create_offset and offset is not None:
            raise ValueError(
                "Cannot pass `offset` at call time if `create_offset=True`."
            )

        channel_index = self.channel_index
        if channel_index < 0:
            channel_index += inputs.ndim

        if self.axis is not None:
            axis = self.axis
        else:
            axis = [i for i in range(inputs.ndim) if i != channel_index]

        if training or test_local_stats:
            cross_replica_axis = self.cross_replica_axis
            if self.cross_replica_axis:
                mean = jnp.mean(inputs, axis, keepdims=True)
                mean = jax.lax.pmean(mean, cross_replica_axis)
                mean_of_squares = jnp.mean(inputs ** 2, axis, keepdims=True)
                mean_of_squares = jax.lax.pmean(mean_of_squares, cross_replica_axis)
                var = mean_of_squares - mean ** 2
            else:
                mean = jnp.mean(inputs, axis, keepdims=True)
                # This uses E[(X - E[X])^2].
                # TODO(tycai): Consider the faster, but possibly less stable
                # E[X^2] - E[X]^2 method.
                var = jnp.var(inputs, axis, keepdims=True)
        else:
            mean = self.mean_ema.average
            var = self.var_ema.average

        if training:
            self.mean_ema(mean)
            self.var_ema(var)

        w_shape = [1 if i in axis else inputs.shape[i] for i in range(inputs.ndim)]
        w_dtype = inputs.dtype

        if self.create_scale:
            scale = self.add_parameter("scale", w_shape, w_dtype, self.scale_init)
        elif scale is None:
            scale = np.ones([], dtype=w_dtype)

        if self.create_offset:
            offset = self.add_parameter("offset", w_shape, w_dtype, self.offset_init)
        elif offset is None:
            offset = np.zeros([], dtype=w_dtype)

        inv = scale * jax.lax.rsqrt(var + self.eps)
        return (inputs - mean) * inv + offset
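
Stripped of the EMA state, cross-replica reductions, and parameter creation, the core computation is: normalize over every axis except the channel axis, then apply scale and offset. A minimal sketch assuming the default channels-last layout:

    import jax
    import jax.numpy as jnp

    def batch_norm(x, scale, offset, eps=1e-5):
        # Normalize over all axes except the trailing channel axis,
        # matching the `axis` computed above when channel_index == ndim - 1.
        axis = tuple(range(x.ndim - 1))
        mean = jnp.mean(x, axis, keepdims=True)
        var = jnp.var(x, axis, keepdims=True)
        inv = scale * jax.lax.rsqrt(var + eps)
        return (x - mean) * inv + offset

    x = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
    y = batch_norm(x, scale=jnp.ones((4,)), offset=jnp.zeros((4,)))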