Example #1
    def __call__(self, inputs: jnp.ndarray, multiplier: Optional[FloatLike] = None):
        """Adds bias to `inputs` and optionally multiplies by `multiplier`.

    Args:
      inputs: A Tensor of size `[batch_size, input_size1, ...]`.
      multiplier: A scalar or Tensor which the bias term is multiplied by before
        adding it to `inputs`. Anything which works in the expression `bias *
        multiplier` is acceptable here. This may be useful if you want to add a
        bias in one place and subtract the same bias in another place via
        `multiplier=-1`.

    Returns:
      A Tensor of size `[batch_size, input_size1, ...]`.
    """
        utils.assert_minimum_rank(inputs, 2)

        input_shape = inputs.shape
        self.bias_shape = calculate_bias_shape(input_shape, self.bias_dims)

        input_size = input_shape[1:]
        if self.output_size is not None and self.output_size != input_size:
            raise ValueError("Input shape must be {} not {}".format(
                (-1, ) + self.output_size, input_shape))

        self.input_size = input_size
        b = base.get_parameter("b",
                               self.bias_shape,
                               inputs.dtype,
                               init=self.b_init)
        b = jnp.broadcast_to(b, inputs.shape)

        if multiplier is not None:
            return inputs + (b * multiplier)
        else:
            return inputs + b
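
For illustration, the broadcast-and-multiply step above can be reproduced in plain JAX outside any Haiku module; the shapes and the reduced bias shape below are made up for this sketch.

import jax.numpy as jnp

inputs = jnp.ones([8, 4, 5])             # [batch_size, input_size1, input_size2]
b = jnp.arange(5.0).reshape(1, 5)        # a bias_shape with size-1 dims where no bias is learned
b = jnp.broadcast_to(b, inputs.shape)    # tile the bias up to the full input shape
out = inputs + (b * -1.0)                # multiplier=-1 subtracts the bias instead of adding it
print(out.shape)                         # (8, 4, 5)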
Example #2
    def __call__(
        self,
        inputs: jnp.ndarray,
        multiplier: Optional[Union[float, jnp.ndarray]] = None,
    ) -> jnp.ndarray:
        """Adds bias to ``inputs`` and optionally multiplies by ``multiplier``.

    Args:
      inputs: A Tensor of size ``[batch_size, input_size1, ...]``.
      multiplier: A scalar or Tensor which the bias term is multiplied by before
        adding it to ``inputs``. Anything which works in the expression ``bias *
        multiplier`` is acceptable here. This may be useful if you want to add a
        bias in one place and subtract the same bias in another place via
        ``multiplier=-1``.

    Returns:
      A Tensor of size ``[batch_size, input_size1, ...]``.
    """
        utils.assert_minimum_rank(inputs, 2)
        if self.output_size is not None and self.output_size != inputs.shape[1:]:
            raise ValueError(
                f"Input shape must be {(-1,) + self.output_size} not {inputs.shape}")

        self.bias_shape = calculate_bias_shape(inputs.shape, self.bias_dims)
        self.input_size = inputs.shape[1:]

        b = hk.get_parameter("b",
                             self.bias_shape,
                             inputs.dtype,
                             init=self.b_init)
        b = jnp.broadcast_to(b, inputs.shape)

        if multiplier is not None:
            b = b * multiplier

        return inputs + b
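
Both snippets appear to be the ``__call__`` of Haiku's ``hk.Bias`` module; assuming that, a minimal end-to-end usage sketch (module names and shapes here are illustrative, not taken from the snippets) could look like:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    bias = hk.Bias()                     # by default, learns one bias value per non-batch element
    y = bias(x)                          # add the bias
    return bias(y, multiplier=-1)        # reuse the same module to subtract the same bias again

forward = hk.transform(forward)
x = jnp.ones([8, 4, 4])                  # [batch_size, input_size1, input_size2]
params = forward.init(jax.random.PRNGKey(42), x)
out = forward.apply(params, None, x)     # round-trips back to the original inputs
print(out.shape)                         # (8, 4, 4)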