예제 #1
0
파일: tiled_linear.py 프로젝트: yyht/lamb
 def _merge_input_sizes(self, inputs):
     """Reconciles the declared input sizes with those inferred from `inputs`.

     On the first call (no merged sizes cached yet) the declared and inferred
     sizes are checked for compatibility, merged, and the result is cached in
     `self._merged_input_sizes`. On later calls the inferred sizes are only
     checked against the cached result.

     Args:
       inputs: Forwarded unchanged to `self._inferred_input_sizes`.

     Raises:
       snt_base.IncompatibleShapeError: If the declared and inferred sizes are
         incompatible, if the merged sizes are not fully defined, or if the
         sizes seen on a later call are incompatible with the cached ones.
     """
     inferred = self._inferred_input_sizes(inputs)
     cached = self._merged_input_sizes

     if cached is not None:
         # Not the first build: the new sizes only have to agree with what was
         # merged before.
         if cached.is_compatible_with(inferred):
             return
         raise snt_base.IncompatibleShapeError(
             '{}: Current input sizes {} are different '
             'from first build {}'.format(
                 self.name, inferred.as_list(), cached.as_list()))

     # First build: remember the input sizes (only the last dimension matters
     # for matmul).
     declared = self._declared_input_sizes()
     if not declared.is_compatible_with(inferred):
         raise snt_base.IncompatibleShapeError(
             '{}: Declared input sizes {} are incompatible '
             'with inferred ones {}.'.format(
                 self.scope_name, declared.as_list(), inferred.as_list()))

     merged = declared.merge_with(inferred)
     # Cached before the completeness check, so the attribute is set even when
     # the error below is raised.
     self._merged_input_sizes = merged
     if not merged.is_fully_defined():
         raise snt_base.IncompatibleShapeError(
             '{}: Last input dimensions must be known at module build time.'
             ' Got {}.'.format(self.name, merged.as_list()))
예제 #2
0
    def _build(self, input_batch, is_training, test_local_stats=False):
        """Connects the BatchNormV2 module into the graph.

        Args:
          input_batch: A Tensor of the same dimension as `len(data_format)`.
          is_training: A boolean to indicate if the module should be connected
            in training mode, meaning the moving averages are updated. Can be
            a Tensor.
          test_local_stats: A boolean to indicate if local batch statistics
            should be used when `is_training=False`. If not, moving averages
            are used. By default `False`. Can be a Tensor.

        Returns:
          A tensor with the same shape as `input_batch`.

        Raises:
          base.IncompatibleShapeError: If no data format was configured and the
            input rank is not 2, 3, 4 or 5, or if the length of the data format
            differs from the input rank.
          base.NotSupportedError: If the fused implementation is enabled and
            `input_batch` has data type `tf.bfloat16`.
        """
        input_shape = input_batch.get_shape()

        if not self._data_format:
            # No format configured: pick the channels-last layout for the rank.
            default_formats = {2: "NC", 3: "NWC", 4: "NHWC", 5: "NDHWC"}
            rank = len(input_shape)
            if rank not in default_formats:
                raise base.IncompatibleShapeError(
                    "Input shape {} has too many or too few dimensions.".
                    format(input_shape))
            self._data_format = default_formats[rank]

        channel_index = self._data_format.index("C")
        self._channel_index = channel_index
        # Every axis index of the data format except the channel one.
        self._axis = [
            dim for dim in range(len(self._data_format))
            if dim != channel_index
        ]

        if len(self._data_format) != len(input_shape):
            raise base.IncompatibleShapeError(
                "Incorrect data format {} for input shape {}.".format(
                    self._data_format, input_shape))

        dtype = input_batch.dtype.base_dtype
        if self._fused and dtype == tf.bfloat16:
            raise base.NotSupportedError(
                "Fused batch norm does not support tf.bfloat16.")
        # Maintain moving averages at a minimum precision of tf.float32.
        is_low_precision = dtype in (tf.float16, tf.bfloat16)
        stat_dtype = tf.float32 if is_low_precision else dtype

        self._num_channels = int(input_shape[channel_index])
        # With the channel at index 1 the remaining non-batch dimensions follow
        # it; otherwise they sit between the first and the last dimension.
        if channel_index == 1:
            non_channel_dims = input_shape[2:]
        else:
            non_channel_dims = input_shape[1:-1]
        self._image_shape = [int(dim) for dim in non_channel_dims]

        # All ones except for the channel position.
        expanded_mean_shape = [1] * len(input_shape)
        self._expanded_mean_shape = expanded_mean_shape
        expanded_mean_shape[channel_index] = self._num_channels

        # `|` rather than `or` so that either flag may be a Tensor.
        use_batch_stats = is_training | test_local_stats

        mean, variance = self._build_statistics(input_batch, use_batch_stats,
                                                stat_dtype)

        # Sets up optional gamma and beta parameters.
        self._build_scale_offset(dtype)
        # Sets up the batch normalization op.
        out, mean, variance = self._batch_norm_op(input_batch, mean, variance,
                                                  use_batch_stats, stat_dtype)
        # Sets up the update op.
        update_ops = self._build_update_ops(mean, variance, is_training)

        if not update_ops:
            return out

        # Put update ops in the update ops collection if given, otherwise add
        # them as control dependencies of the output.
        collection = self._update_ops_collection
        if collection:
            for op in update_ops:
                tf.add_to_collection(collection, op)
            return out

        with tf.control_dependencies(update_ops):
            return tf.identity(out)
예제 #3
0
파일: dilation.py 프로젝트: zofuthan/sonnet
    def _build(self, images):
        """Build dilation module.

        Args:
          images: Tensor of shape [batch_size, height, width, depth]
            and dtype float32. Represents a set of images with an arbitrary
            depth. Note that when using the default initializer, depth must
            equal num_output_classes.

        Returns:
          Tensor of shape [batch_size, height, width, num_output_classes] and
            dtype float32. Represents, for each image and pixel, logits for
            per-class predictions.

        Raises:
          IncompatibleShapeError: If images is not rank 4.
          ValueError: If model_size is not one of 'basic' or 'large'.
        """
        num_classes = self._num_output_classes
        model_size = self._model_size

        if len(images.get_shape()) != 4:
            raise base.IncompatibleShapeError(
                "'images' must have shape [batch_size, height, width, depth].")

        # Fill in default initializers for anything the caller did not supply.
        if self.WEIGHTS not in self._initializers:
            if model_size == self.BASIC:
                default_weights = identity_kernel_initializer
            elif model_size == self.LARGE:
                default_weights = noisy_identity_kernel_initializer(
                    num_classes)
            else:
                raise ValueError("Unrecognized model_size: %s" %
                                 self._model_size)
            self._initializers[self.WEIGHTS] = default_weights

        if self.BIASES not in self._initializers:
            self._initializers[self.BIASES] = tf.zeros_initializer()

        # First argument of each of the eight layers: constant for the basic
        # model, growing and then dropping back to num_classes for the large.
        if model_size == self.BASIC:
            widths = [num_classes] * 8
        elif model_size == self.LARGE:
            widths = [
                factor * num_classes for factor in (2, 2, 4, 8, 16, 32, 32)
            ]
            widths.append(num_classes)
        else:
            raise ValueError("Unrecognized model_size: %s" % self._model_size)

        # Remaining positional arguments, identical for both model sizes.
        # NOTE(review): presumably (dilation rate, apply-activation flag, name)
        # -- confirm against `_dilated_conv_layer`.
        shared_layer_args = (
            (1, True, "conv1"),
            (1, True, "conv2"),
            (2, True, "conv3"),
            (4, True, "conv4"),
            (8, True, "conv5"),
            (16, True, "conv6"),
            (1, True, "conv7"),
            (1, False, "conv8"),
        )
        self._conv_modules = [
            self._dilated_conv_layer(width, rate, flag, layer_name)
            for width, (rate, flag, layer_name) in zip(widths,
                                                       shared_layer_args)
        ]

        dilation_mod = sequential.Sequential(self._conv_modules,
                                             name="dilation")
        return dilation_mod(images)
  def _build(self, input_batch, is_training, test_local_stats=True):
    """Connects the BatchNorm module into the graph.

    Args:
      input_batch: A Tensor of arbitrary dimension. By default, the final
        dimension is not reduced over when computing the minibatch statistics.
      is_training: A boolean to indicate if the module should be connected in
        training mode, meaning the moving averages are updated. Can be a
        Tensor.
      test_local_stats: A boolean to indicate if local batch statistics should
        be used when `is_training=False`. If not, moving averages are used.
        By default `True`. Can be a Tensor.

    Returns:
      A tensor with the same shape as `input_batch`.

    Raises:
      base.IncompatibleShapeError: If `axis` is not valid for the
        input shape or has negative entries.
      base.NotSupportedError: If `input_batch` has data type of `tf.float16`.
    """
    input_shape = input_batch.get_shape()

    if self._axis is None:
      # Reduce over all dimensions except the last.
      axis = tuple(range(len(input_shape) - 1))
    else:
      num_axes = len(self._axis)
      rank = len(input_shape)

      if num_axes > rank:
        raise base.IncompatibleShapeError(
            "Too many indices specified in axis: len({}) > len({}).".format(
                self._axis, input_shape))

      if max(self._axis) >= rank:
        raise base.IncompatibleShapeError(
            "One or more index in axis is too large for "
            "input shape: {} >= {:d}.".format(self._axis, rank))

      if min(self._axis) < 0:
        raise base.IncompatibleShapeError(
            "Indices in axis must be non-negative: {} < 0.".format(
                self._axis))

      axis = self._axis

    # See following for important note on accuracy for dtype=tf.float16
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_impl.py#L63
    dtype = input_batch.dtype
    if dtype == tf.float16:
      raise base.NotSupportedError(
          "BatchNorm does not support `tf.float16`, insufficient "
          "precision for calculating sufficient statistics.")

    # Input shape with every reduced dimension collapsed to 1.
    mean_shape = input_shape.as_list()
    self._mean_shape = mean_shape
    for reduced_dim in axis:
      mean_shape[reduced_dim] = 1

    # `|` rather than `or` so that either flag may be a Tensor.
    use_batch_stats = is_training | test_local_stats

    mean, variance = self._build_statistics(input_batch, axis,
                                            use_batch_stats, dtype)

    # Sets up optional gamma and beta parameters.
    self._build_scale_offset(dtype)
    # Sets up the batch normalization op.
    out, mean, variance = self._batch_norm_op(input_batch, mean, variance,
                                              use_batch_stats)
    # Sets up the update op.
    update_ops = self._build_update_ops(mean, variance, is_training)

    if not update_ops:
      return out

    # Put update ops in the update ops collection if given, otherwise add them
    # as control dependencies of the output.
    collection = self._update_ops_collection
    if collection:
      for op in update_ops:
        tf.add_to_collection(collection, op)
      return out

    with tf.control_dependencies(update_ops):
      return tf.identity(out)
예제 #5
0
    def _build(self, inputs):
        """Connects the Add module into the graph, with input Tensor `inputs`.

        Args:
          inputs: A Tensor of size `[batch_size, input_size1, ...]`.

        Returns:
          A Tensor of size `[batch_size, input_size1, ...]`.

        Raises:
          base.IncompatibleShapeError: If the input is not a >= 2D `Tensor`.
          base.IncompatibleShapeError: If connecting the module into the graph
              any time after the first time, and the inferred size of the
              input does not match previous invocations.
          base.IncompatibleShapeError: If the `output_shape` has been specified
              but it does not match the `input_shape`.
          base.ParentNotBuiltError: If the module is a transposed and the
              original untransposed module has not been built.
        """
        input_shape = tuple(inputs.get_shape().as_list())

        # Check always contains minibatched input. Done before the bias shape
        # is derived so that a too-low rank is reported as the documented
        # IncompatibleShapeError rather than whatever `calculate_bias_shape`
        # might raise for such a shape.
        if len(input_shape) < 2:
            raise base.IncompatibleShapeError(
                "Rank of input shape must be >=2 not: {}.".format(
                    len(input_shape)))

        bias_shape = calculate_bias_shape(input_shape, self._bias_dims)

        # Check previous input size is same as new input size (the leading
        # batch dimension is allowed to differ).
        if (self._input_shape is not None
                and input_shape[1:] != self._input_shape[1:]):
            raise base.IncompatibleShapeError("Input shape has changed.")

        # If transposed, make sure that the original Module is built. A
        # callable `_output_shape` is resolved once and its result stored.
        if callable(self._output_shape):
            self._output_shape = self._output_shape()
            if self._output_shape is None:
                raise base.ParentNotBuiltError(
                    "Build the original untransposed module before building this one."
                )

        # If output_shape specified, check that it matches input_shape. The
        # comparison covers every non-batch dimension, so the message reports
        # the whole input shape rather than just its second entry.
        if (self._output_shape is not None
                and self._output_shape[1:] != input_shape[1:]):
            raise base.IncompatibleShapeError(
                "Input shape must be {} not: {}.".format(
                    self._output_shape, input_shape))

        self._input_shape = input_shape

        if "b" not in self._initializers:
            self._initializers["b"] = create_bias_initializer(bias_shape)

        dtype = inputs.dtype
        self._b = tf.get_variable(
            "b",
            shape=bias_shape,
            dtype=dtype,
            initializer=self._initializers["b"],
            partitioner=self._partitioners.get("b", None),
            regularizer=self._regularizers.get("b", None))

        outputs = inputs + self._b
        return outputs
예제 #6
0
    def _build(self, inputs):
        """Connects the Linear module into the graph, with input Tensor `inputs`.

        If this is not the first time the module has been connected to the
        graph, the Tensor provided here must have the same final dimension, in
        order for the existing variables to be the correct size for the
        multiplication. The batch size may differ for each connection.

        Args:
          inputs: A 2D Tensor of size [batch_size, input_size].

        Returns:
          A 2D Tensor of size [batch_size, output_size].

        Raises:
          base.IncompatibleShapeError: If the input is not a 2-D `Tensor` with
              the size of the second dimension specified.
          base.IncompatibleShapeError: If reconnecting an already connected
              module into the graph, and the shape of the input is not
              compatible with previous inputs.
        """
        input_shape = tuple(inputs.get_shape().as_list())
        rank = len(input_shape)

        if rank != 2:
            raise base.IncompatibleShapeError(
                "{}: rank of shape must be 2 not: {}".format(
                    self.scope_name, rank))

        input_size = input_shape[1]
        if input_size is None:
            raise base.IncompatibleShapeError(
                "{}: Input size must be specified at module build time".format(
                    self.scope_name))

        # On reconnection only the input size has to match; the batch
        # dimension is not compared.
        previous_shape = self._input_shape
        if previous_shape is not None and input_size != previous_shape[1]:
            raise base.IncompatibleShapeError(
                "{}: Input shape must be [batch_size, {}] not: [batch_size, {}]"
                .format(self.scope_name, previous_shape[1], input_size))

        self._input_shape = input_shape

        # Install default initializers for anything the caller did not supply.
        initializers = self._initializers
        if "w" not in initializers:
            initializers["w"] = create_linear_initializer(input_size)

        if self._use_bias and "b" not in initializers:
            # NOTE(review): the bias initializer is given the input size, not
            # the bias shape -- confirm this is what
            # `create_bias_initializer` expects.
            initializers["b"] = create_bias_initializer(input_size)

        dtype = inputs.dtype
        self._w = tf.get_variable(
            "w",
            shape=(input_size, self.output_size),
            dtype=dtype,
            initializer=initializers["w"],
            partitioner=self._partitioners.get("w", None),
            regularizer=self._regularizers.get("w", None))
        product = tf.matmul(inputs, self._w)

        if not self._use_bias:
            return product

        self._b = tf.get_variable(
            "b",
            shape=(self.output_size, ),
            dtype=dtype,
            initializer=initializers["b"],
            partitioner=self._partitioners.get("b", None),
            regularizer=self._regularizers.get("b", None))
        return product + self._b
예제 #7
0
  def _build(self, inputs, multiplier=1):
    """Connects the Add module into the graph, with input Tensor `inputs`.

    Args:
      inputs: A Tensor of size `[batch_size, input_size1, ...]`.
      multiplier: A scalar or Tensor which the bias term is multiplied by
        before adding it to `inputs`. Anything which works in the expression
        `bias * multiplier` is acceptable here. This may be useful if you want
        to add a bias in one place and subtract the same bias in another place
        via `multiplier=-1`.

    Returns:
      A Tensor of size `[batch_size, input_size1, ...]`.

    Raises:
      base.IncompatibleShapeError: If the input is not a >= 2D `Tensor`.
      base.IncompatibleShapeError: If connecting the module into the graph
          any time after the first time, and the inferred size of the input does
          not match previous invocations.
      base.IncompatibleShapeError: If the `output_shape` has been specified
          but it does not match the `input_shape`.
      base.ParentNotBuiltError: If the module is a transposed and the original
          untransposed module has not been built.
    """
    input_shape = tuple(inputs.get_shape().as_list())

    # Check always contains minibatched input. Done before the bias shape is
    # derived so that a too-low rank is reported as the documented
    # IncompatibleShapeError rather than whatever `calculate_bias_shape` might
    # raise for such a shape.
    if len(input_shape) < 2:
      raise base.IncompatibleShapeError(
          "Rank of input shape must be >=2 not: {}.".format(len(input_shape)))

    bias_shape = calculate_bias_shape(input_shape, self._bias_dims)

    # Check previous input size is same as new input size (the leading batch
    # dimension is allowed to differ).
    if (self._input_shape is not None and
        input_shape[1:] != self._input_shape[1:]):
      raise base.IncompatibleShapeError("Input shape has changed.")

    # If transposed, make sure that the original Module is built. A callable
    # `_output_shape` is resolved once and its result stored.
    if callable(self._output_shape):
      self._output_shape = self._output_shape()
      if self._output_shape is None:
        raise base.ParentNotBuiltError(
            "Build the original untransposed module before building this one.")

    # If output_shape specified, check that it matches input_shape. The
    # comparison covers every non-batch dimension, so the message reports the
    # whole input shape rather than just its second entry.
    if (self._output_shape is not None and
        self._output_shape[1:] != input_shape[1:]):
      raise base.IncompatibleShapeError(
          "Input shape must be {} not: {}.".format(self._output_shape,
                                                   input_shape))

    self._input_shape = input_shape
    dtype = inputs.dtype

    if "b" not in self._initializers:
      self._initializers["b"] = create_bias_initializer(bias_shape, dtype)

    self._b = tf.get_variable(
        "b",
        shape=bias_shape,
        dtype=dtype,
        initializer=self._initializers["b"],
        partitioner=self._partitioners.get("b", None),
        regularizer=self._regularizers.get("b", None))

    bias = self._b
    # Skip the multiplication for the default multiplier of 1.
    if multiplier != 1:
      bias = bias * multiplier  # pylint: disable=g-no-augmented-assignment
    outputs = inputs + bias
    return outputs
예제 #8
0
파일: alexnet.py 프로젝트: zxshinxz/sonnet
    def _build(self,
               inputs,
               keep_prob=None,
               is_training=True,
               test_local_stats=True):
        """Connects the AlexNet module into the graph.

        Args:
          inputs: A Tensor of size [batch_size, input_height, input_width,
            input_channels], representing a batch of input images.
          keep_prob: A scalar Tensor representing the dropout keep probability.
          is_training: Boolean to indicate to `snt.BatchNorm` if we are
            currently training. By default `True`.
          test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
            normalization should use local batch statistics at test time.
            By default `True`.

        Returns:
          A Tensor of size [batch_size, output_size], where `output_size`
          depends on the mode the network was constructed in.

        Raises:
          base.IncompatibleShapeError: If any of the input image dimensions
            (input_height, input_width) are too small for the given network
            mode.
        """

        def batch_normalize(tensor):
            # A fresh BatchNorm module is created for every layer.
            bn = batch_norm.BatchNorm(**self._batch_norm_config)
            return bn(tensor, is_training, test_local_stats)

        input_shape = inputs.get_shape().as_list()

        if input_shape[1] < self._min_size or input_shape[2] < self._min_size:
            raise base.IncompatibleShapeError(
                "Image shape too small: ({:d}, {:d}) < {:d}".format(
                    input_shape[1], input_shape[2], self._min_size))

        net = inputs

        # Convolutional stack: conv -> (batch norm) -> relu -> (max pool).
        for index, layer_params in enumerate(self._conv_layers):
            output_channels, (kernel_size, stride), max_pooling = layer_params

            conv_mod = conv.Conv2D(name="conv_{}".format(index),
                                   output_channels=output_channels,
                                   kernel_shape=kernel_size,
                                   stride=stride,
                                   padding=conv.VALID,
                                   initializers=self._initializers,
                                   partitioners=self._partitioners,
                                   regularizers=self._regularizers)

            # Modules are only recorded on the first connection.
            if not self.is_connected:
                self._conv_modules.append(conv_mod)

            net = conv_mod(net)

            if self._use_batch_norm:
                net = batch_normalize(net)

            net = tf.nn.relu(net)

            if max_pooling is not None:
                pool_size, pool_stride = max_pooling
                net = tf.nn.max_pool(
                    net,
                    ksize=[1, pool_size, pool_size, 1],
                    strides=[1, pool_stride, pool_stride, 1],
                    padding=conv.VALID)

        net = basic.BatchFlatten(name="flatten")(net)

        # Fully connected stack: linear -> (batch norm) -> relu -> (dropout).
        for index, output_size in enumerate(self._fc_layers):
            linear_mod = basic.Linear(name="fc_{}".format(index),
                                      output_size=output_size,
                                      partitioners=self._partitioners)

            if not self.is_connected:
                self._linear_modules.append(linear_mod)

            net = linear_mod(net)

            if self._use_batch_norm:
                net = batch_normalize(net)

            net = tf.nn.relu(net)

            if keep_prob is not None:
                net = tf.nn.dropout(net, keep_prob=keep_prob)

        return net
예제 #9
0
파일: alexnet.py 프로젝트: zwcdp/sonnet
    def _build(self,
               inputs,
               keep_prob=None,
               is_training=None,
               test_local_stats=True):
        """Connects the AlexNet module into the graph.

        The is_training flag only controls the batch norm settings, if `False`
        it does not force no dropout by overriding any input `keep_prob`. To
        avoid any confusion this may cause, if `is_training=False` and
        `keep_prob` would cause dropout to be applied, an error is thrown.

        Args:
          inputs: A Tensor of size [batch_size, input_height, input_width,
            input_channels], representing a batch of input images.
          keep_prob: A scalar Tensor representing the dropout keep probability.
            When `is_training=False` this must be None or 1 to give no dropout.
          is_training: Boolean to indicate if we are currently training. Must
            be specified if batch normalization or dropout is used.
          test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
            normalization should use local batch statistics at test time.
            By default `True`.

        Returns:
          A Tensor of size [batch_size, output_size], where `output_size`
          depends on the mode the network was constructed in.

        Raises:
          base.IncompatibleShapeError: If any of the input image dimensions
            (input_height, input_width) are too small for the given network
            mode.
          ValueError: If `is_training` is not explicitly specified when using
            batch normalization or dropout.

        The `keep_prob` restriction for `is_training=False` is enforced by a
        graph assertion attached to the input, not by a Python exception at
        build time.
        """

        def batch_normalize(tensor):
            # A fresh BatchNorm module is created for every layer.
            bn = batch_norm.BatchNorm(**self._batch_norm_config)
            return bn(tensor, is_training, test_local_stats)

        # is_training has no usable default once batch norm or dropout is on.
        needs_training_flag = self._use_batch_norm or keep_prob is not None
        if needs_training_flag and is_training is None:
            raise ValueError(
                "Boolean is_training flag must be explicitly specified "
                "when using batch normalization or dropout.")

        # Check input shape.
        input_shape = inputs.get_shape().as_list()
        if input_shape[1] < self._min_size or input_shape[2] < self._min_size:
            raise base.IncompatibleShapeError(
                "Image shape too small: ({:d}, {:d}) < {:d}".format(
                    input_shape[1], input_shape[2], self._min_size))

        net = inputs

        # Check keep prob: the assertion is made a control dependency of the
        # network input.
        if keep_prob is not None:
            valid_inputs = tf.logical_or(is_training, tf.equal(keep_prob, 1.))
            keep_prob_check = tf.assert_equal(
                valid_inputs,
                True,
                message=
                "Input `keep_prob` must be None or 1 if `is_training=False`.")
            with tf.control_dependencies([keep_prob_check]):
                net = tf.identity(net)

        # Convolutional stack: conv -> (batch norm) -> relu -> (max pool).
        for index, layer_params in enumerate(self._conv_layers):
            output_channels, (kernel_size, stride), max_pooling = layer_params

            conv_mod = conv.Conv2D(name="conv_{}".format(index),
                                   output_channels=output_channels,
                                   kernel_shape=kernel_size,
                                   stride=stride,
                                   padding=conv.VALID,
                                   initializers=self._initializers,
                                   partitioners=self._partitioners,
                                   regularizers=self._regularizers)

            # Modules are only recorded on the first connection.
            if not self.is_connected:
                self._conv_modules.append(conv_mod)

            net = conv_mod(net)

            if self._use_batch_norm:
                net = batch_normalize(net)

            net = tf.nn.relu(net)

            if max_pooling is not None:
                pool_size, pool_stride = max_pooling
                net = tf.nn.max_pool(
                    net,
                    ksize=[1, pool_size, pool_size, 1],
                    strides=[1, pool_stride, pool_stride, 1],
                    padding=conv.VALID)

        net = basic.BatchFlatten(name="flatten")(net)

        # Fully connected stack: linear -> (batch norm) -> relu -> (dropout).
        for index, output_size in enumerate(self._fc_layers):
            linear_mod = basic.Linear(name="fc_{}".format(index),
                                      output_size=output_size,
                                      initializers=self._initializers,
                                      partitioners=self._partitioners)

            if not self.is_connected:
                self._linear_modules.append(linear_mod)

            net = linear_mod(net)

            # Batch norm on these layers needs both switches enabled.
            if self._use_batch_norm and self._bn_on_fc_layers:
                net = batch_normalize(net)

            net = tf.nn.relu(net)

            if keep_prob is not None:
                net = tf.nn.dropout(net, keep_prob=keep_prob)

        return net
예제 #10
0
    def _build(self, memory, query, memory_mask=None):
        """Perform a differentiable read.

        Args:
          memory: [batch_size, memory_size, memory_word_size]-shaped Tensor of
            dtype float32. This represents, for each example and memory slot,
            a single embedding to attend over.
          query: [batch_size, query_word_size]-shaped Tensor of dtype float32.
            Represents, for each example, a single embedding representing a
            query.
          memory_mask: None or [batch_size, memory_size]-shaped Tensor of dtype
            bool. An entry of False indicates that a memory slot should not
            enter the resulting weighted sum. If None, all memory is used.

        Returns:
          An AttentionOutput instance containing:
            read: [batch_size, memory_word_size]-shaped Tensor of dtype
              float32. This represents, for each example, a weighted sum of
              the contents of the memory.
            weights: [batch_size, memory_size]-shaped Tensor of dtype float32.
              This represents, for each example and memory slot, the attention
              weights used to compute the read.
            weight_logits: [batch_size, memory_size]-shaped Tensor of dtype
              float32. This represents, for each example and memory slot, the
              logits of the attention weights, that is, `weights` is
              calculated by taking the softmax of the weight logits.

        Raises:
          UnderspecifiedError: if memory_word_size or query_word_size can not
            be inferred.
          IncompatibleShapeError: if memory, query, memory_mask, or output of
            attention_logit_mod do not match expected shapes.
        """
        if len(memory.get_shape()) != 3:
            raise base.IncompatibleShapeError(
                "memory must have shape [batch_size, memory_size, memory_word_size]."
            )

        if len(query.get_shape()) != 2:
            raise base.IncompatibleShapeError(
                "query must have shape [batch_size, query_word_size].")

        if memory_mask is not None and len(memory_mask.get_shape()) != 2:
            raise base.IncompatibleShapeError(
                "memory_mask must have shape [batch_size, memory_size].")

        # Ensure final dimensions are defined, else the attention logit module
        # will be unable to infer input size when constructing variables.
        inferred_memory_word_size = memory.get_shape()[2].value
        inferred_query_word_size = query.get_shape()[1].value
        if inferred_memory_word_size is None or inferred_query_word_size is None:
            raise base.UnderspecifiedError(
                "memory_word_size and query_word_size must be known at graph "
                "construction time.")

        # Dynamic (run-time) sizes; the static ones may be unknown.
        memory_shape = tf.shape(memory)
        batch_size = memory_shape[0]
        memory_size = memory_shape[1]

        query_shape = tf.shape(query)
        query_batch_size = query_shape[0]

        # Transform query to have same number of words as memory.
        #
        # expanded_query: [batch_size, memory_size, query_word_size].
        #
        # `axis` is used instead of the deprecated `dim` keyword of
        # tf.expand_dims.
        expanded_query = tf.tile(tf.expand_dims(query, axis=1),
                                 [1, memory_size, 1])

        # Compute attention weights for each memory slot. The concatenation is
        # gated on memory and query agreeing on the batch size at run time.
        #
        # attention_weight_logits: [batch_size, memory_size]
        with tf.control_dependencies(
            [tf.assert_equal(batch_size, query_batch_size)]):
            concatenated_embeddings = tf.concat(
                values=[memory, expanded_query], axis=2)

        batch_apply_attention_logit = basic.BatchApply(
            self._attention_logit_mod,
            n_dims=2,
            name="batch_apply_attention_logit")
        attention_weight_logits = batch_apply_attention_logit(
            concatenated_embeddings)

        # Note: basic.BatchApply() will automatically reshape the [batch_size *
        # memory_size, 1]-shaped result of self._attention_logit_mod(...) into
        # a [batch_size, memory_size, 1]-shaped Tensor. If
        # self._attention_logit_mod(...) returns something with more
        # dimensions, then attention_weight_logits will have extra dimensions,
        # too.
        if len(attention_weight_logits.get_shape()) != 3:
            raise base.IncompatibleShapeError(
                "attention_weight_logits must be a rank-3 Tensor. Are you sure that "
                "attention_logit_mod() returned [batch_size * memory_size, 1]-shaped"
                " Tensor?")

        # Remove final length-1 dimension.
        attention_weight_logits = tf.squeeze(attention_weight_logits, [2])

        # Mask out ignored memory slots by assigning them very small logits.
        # Ensures that every example has at least one valid memory slot, else
        # we'd end up averaging all memory slots equally.
        if memory_mask is not None:
            num_remaining_memory_slots = tf.reduce_sum(tf.cast(memory_mask,
                                                               dtype=tf.int32),
                                                       axis=[1])
            with tf.control_dependencies(
                [tf.assert_positive(num_remaining_memory_slots)]):
                finfo = np.finfo(np.float32)
                kept_indices = tf.cast(memory_mask, dtype=tf.float32)
                ignored_indices = tf.cast(tf.logical_not(memory_mask),
                                          dtype=tf.float32)
                # Per-slot cap applied with tf.minimum: finfo.max for kept
                # slots (logit unchanged), finfo.min (the most negative
                # float32) for ignored slots.
                logit_cap = finfo.max * kept_indices + finfo.min * ignored_indices
                attention_weight_logits = tf.minimum(attention_weight_logits,
                                                     logit_cap)

        # attention_weight: [batch_size, memory_size, 1].
        attention_weight = tf.reshape(tf.nn.softmax(attention_weight_logits),
                                      shape=[batch_size, memory_size, 1])
        # The multiplication is elementwise and relies on broadcasting the
        # weights across memory_word_size. Then we sum across the memory slots.
        #
        # attended_memory: [batch_size, memory_word_size].
        attended_memory = tf.reduce_sum(memory * attention_weight, axis=[1])

        # Infer shape of result as much as possible.
        inferred_batch_size, _, inferred_memory_word_size = (
            memory.get_shape().as_list())
        attended_memory.set_shape(
            [inferred_batch_size, inferred_memory_word_size])

        return AttentionOutput(read=attended_memory,
                               weights=tf.squeeze(attention_weight, [2]),
                               weight_logits=attention_weight_logits)