Example #1
0
    def _build(self, inputs):
        """Connects the Add module into the graph, with input Tensor `inputs`.

        Args:
          inputs: A Tensor of size `[batch_size, input_size1, ...]`.

        Returns:
          A Tensor of size `[batch_size, input_size1, ...]`.

        Raises:
          base.IncompatibleShapeError: If the input is not a >= 2D `Tensor`.
          base.IncompatibleShapeError: If connecting the module into the graph
              any time after the first time, and the inferred size of the input
              does not match previous invocations.
          base.IncompatibleShapeError: If the `output_shape` has been specified
              but it does not match the `input_shape`.
          base.ParentNotBuiltError: If the module is transposed and the original
              untransposed module has not been built.
        """
        input_shape = tuple(inputs.get_shape().as_list())
        bias_shape = calculate_bias_shape(input_shape, self._bias_dims)

        # Check input always contains a minibatch dimension (rank >= 2).
        if len(input_shape) < 2:
            raise base.IncompatibleShapeError(
                "Rank of input shape must be >=2 not: {}.".format(
                    len(input_shape)))

        # Check previous input size is same as new input size (the leading
        # batch dimension is allowed to vary between calls).
        if (self._input_shape is not None
                and input_shape[1:] != self._input_shape[1:]):
            raise base.IncompatibleShapeError("Input shape has changed.")

        # If transposed, `_output_shape` is a deferred callable; resolve it and
        # make sure the original (untransposed) module has been built.
        if callable(self._output_shape):
            self._output_shape = self._output_shape()
            if self._output_shape is None:
                raise base.ParentNotBuiltError(
                    "Build the original untransposed module before building this one."
                )

        # If output_shape specified, check that it matches input_shape.
        if (self._output_shape is not None
                and self._output_shape[1:] != input_shape[1:]):
            # Bug fix: report the full offending shape; the original passed
            # input_shape[1] (one dimension), producing a misleading message.
            raise base.IncompatibleShapeError(
                "Input shape must be {} not: {}.".format(
                    self._output_shape, input_shape))

        self._input_shape = input_shape

        # Lazily install a default bias initializer if the caller gave none.
        if "b" not in self._initializers:
            self._initializers["b"] = create_bias_initializer(bias_shape)

        dtype = inputs.dtype
        self._b = tf.get_variable(
            "b",
            shape=bias_shape,
            dtype=dtype,
            initializer=self._initializers["b"],
            partitioner=self._partitioners.get("b", None),
            regularizer=self._regularizers.get("b", None))

        outputs = inputs + self._b
        return outputs
Example #2
0
  def _build(self, inputs, multiplier=1):
    """Connects the Add module into the graph, with input Tensor `inputs`.

    Args:
      inputs: A Tensor of size `[batch_size, input_size1, ...]`.
      multiplier: A scalar or Tensor which the bias term is multiplied by
        before adding it to `inputs`. Anything which works in the expression
        `bias * multiplier` is acceptable here. This may be useful if you want
        to add a bias in one place and subtract the same bias in another place
        via `multiplier=-1`.

    Returns:
      A Tensor of size `[batch_size, input_size1, ...]`.

    Raises:
      base.IncompatibleShapeError: If the input is not a >= 2D `Tensor`.
      base.IncompatibleShapeError: If connecting the module into the graph
          any time after the first time, and the inferred size of the input
          does not match previous invocations.
      base.IncompatibleShapeError: If the `output_shape` has been specified
          but it does not match the `input_shape`.
      base.ParentNotBuiltError: If the module is transposed and the original
          untransposed module has not been built.
    """
    input_shape = tuple(inputs.get_shape().as_list())
    bias_shape = calculate_bias_shape(input_shape, self._bias_dims)

    # Check input always contains a minibatch dimension (rank >= 2).
    if len(input_shape) < 2:
      raise base.IncompatibleShapeError(
          "Rank of input shape must be >=2 not: {}.".format(len(input_shape)))

    # Check previous input size is same as new input size (the leading batch
    # dimension is allowed to vary between calls).
    if (self._input_shape is not None and
        input_shape[1:] != self._input_shape[1:]):
      raise base.IncompatibleShapeError("Input shape has changed.")

    # If transposed, `_output_shape` is a deferred callable; resolve it and
    # make sure the original (untransposed) module has been built.
    if callable(self._output_shape):
      self._output_shape = self._output_shape()
      if self._output_shape is None:
        raise base.ParentNotBuiltError(
            "Build the original untransposed module before building this one.")

    # If output_shape specified, check that it matches input_shape.
    if (self._output_shape is not None and
        self._output_shape[1:] != input_shape[1:]):
      # Bug fix: report the full offending shape; the original passed
      # input_shape[1] (one dimension), producing a misleading message.
      raise base.IncompatibleShapeError(
          "Input shape must be {} not: {}.".format(self._output_shape,
                                                   input_shape))

    self._input_shape = input_shape
    dtype = inputs.dtype

    # Lazily install a default bias initializer if the caller gave none.
    if "b" not in self._initializers:
      self._initializers["b"] = create_bias_initializer(bias_shape, dtype)

    self._b = tf.get_variable(
        "b",
        shape=bias_shape,
        dtype=dtype,
        initializer=self._initializers["b"],
        partitioner=self._partitioners.get("b", None),
        regularizer=self._regularizers.get("b", None))

    bias = self._b
    if multiplier != 1:
      bias = bias * multiplier  # pylint: disable=g-no-augmented-assignment
    outputs = inputs + bias
    return outputs