def __call__(self, inputs: jnp.ndarray, multiplier: FloatLike = None):
    """Adds bias to `inputs` and optionally multiplies by `multiplier`.

    Args:
      inputs: A Tensor of size `[batch_size, input_size1, ...]`.
      multiplier: A scalar or Tensor which the bias term is multiplied by
        before adding it to `inputs`. Anything which works in the expression
        `bias * multiplier` is acceptable here. This may be useful if you want
        to add a bias in one place and subtract the same bias in another place
        via `multiplier=-1`. When `None` (the default) the bias is added
        unscaled.

    Returns:
      A Tensor of size `[batch_size, input_size1, ...]`.

    Raises:
      ValueError: If `self.output_size` is set and differs from
        `inputs.shape[1:]`.
    """
    # Rank check is delegated to the helper; its failure mode is defined there.
    utils.assert_minimum_rank(inputs, 2)

    input_shape = inputs.shape
    input_size = input_shape[1:]

    # Validate before touching any attribute, so a rejected input does not
    # leave a freshly computed `bias_shape` behind on the module.
    if self.output_size is not None and self.output_size != input_size:
        raise ValueError(
            f"Input shape must be {(-1,) + self.output_size} not {input_shape}")

    self.bias_shape = calculate_bias_shape(input_shape, self.bias_dims)
    self.input_size = input_size

    b = base.get_parameter("b", self.bias_shape, inputs.dtype, init=self.b_init)
    # The bias is expanded to the full input shape *before* the multiplier is
    # applied, so `multiplier` is combined with an `inputs.shape`-sized array.
    b = jnp.broadcast_to(b, input_shape)

    if multiplier is not None:
        b = b * multiplier
    return inputs + b
def __call__(
    self,
    inputs: jnp.ndarray,
    multiplier: Union[float, jnp.ndarray] = None,
) -> jnp.ndarray:
    """Adds the bias parameter to ``inputs``, optionally scaled.

    Args:
      inputs: A Tensor of size ``[batch_size, input_size1, ...]``.
      multiplier: Optional scalar or Tensor that the bias is multiplied by
        before it is added to ``inputs``. Anything valid in the expression
        ``bias * multiplier`` may be passed; for example ``multiplier=-1``
        subtracts the same bias that another call site adds. ``None`` (the
        default) adds the bias unscaled.

    Returns:
      A Tensor of size ``[batch_size, input_size1, ...]``.

    Raises:
      ValueError: If ``self.output_size`` is set and does not equal
        ``inputs.shape[1:]``.
    """
    utils.assert_minimum_rank(inputs, 2)

    # Everything after the leading (batch) axis must match a pinned output size.
    trailing_shape = inputs.shape[1:]
    size_is_pinned = self.output_size is not None
    if size_is_pinned and self.output_size != trailing_shape:
        raise ValueError(
            f"Input shape must be {(-1,) + self.output_size} not {inputs.shape}"
        )

    # Record the shapes seen on this call only once the input was accepted.
    self.bias_shape = calculate_bias_shape(inputs.shape, self.bias_dims)
    self.input_size = trailing_shape

    # Fetch the parameter and expand it to the input shape in one step; the
    # multiplier (if any) is applied to the already-expanded bias.
    bias = jnp.broadcast_to(
        hk.get_parameter("b", self.bias_shape, inputs.dtype, init=self.b_init),
        inputs.shape,
    )

    if multiplier is None:
        return inputs + bias
    return inputs + bias * multiplier