Example #1
    def _build(self, inputs, memory, treat_input_as_matrix=False):
        """Adds relational memory to the TensorFlow graph.
    Args:
      inputs: Tensor input.
      memory: Memory output from the previous time step.
      treat_input_as_matrix: Optional, whether to treat `inputs` as a sequence
        of matrices. Defaults to False, in which case the input is flattened
        into a vector.
    Returns:
      output: This time step's output.
      next_memory: The next version of memory to use.
    """
        if treat_input_as_matrix:
            inputs = basic.BatchFlatten(preserve_dims=2)(inputs)
            inputs_reshape = basic.BatchApply(basic.Linear(self._mem_size),
                                              n_dims=2)(inputs)
        else:
            inputs = basic.BatchFlatten()(inputs)
            inputs = basic.Linear(self._mem_size)(inputs)
            inputs_reshape = tf.expand_dims(inputs, 1)

        memory_plus_input = tf.concat([memory, inputs_reshape], axis=1)
        next_memory = self._attend_over_memory(memory_plus_input)

        n = inputs_reshape.get_shape().as_list()[1]
        next_memory = next_memory[:, :-n, :]

        if self._gate_style == 'unit' or self._gate_style == 'memory':
            self._input_gate, self._forget_gate = self._create_gates(
                inputs_reshape, memory)
            next_memory = self._input_gate * tf.tanh(next_memory)
            next_memory += self._forget_gate * memory

        output = basic.BatchFlatten()(next_memory)
        return output, next_memory
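A minimal NumPy sketch of the slot bookkeeping above (shapes are illustrative, and `attended` stands in for `self._attend_over_memory`): the projected input is appended as extra memory rows, attention runs over the widened memory, and the appended rows are then sliced off so the slot count is preserved.

import numpy as np

batch, mem_slots, mem_size, n_inputs = 2, 4, 8, 1
memory = np.zeros((batch, mem_slots, mem_size))
inputs_reshape = np.random.randn(batch, n_inputs, mem_size)

# Append the projected input as extra memory rows.
memory_plus_input = np.concatenate([memory, inputs_reshape], axis=1)
attended = memory_plus_input  # stand-in for self._attend_over_memory(...)
# Drop the appended rows so memory keeps its original number of slots.
next_memory = attended[:, :-n_inputs, :]
assert next_memory.shape == memory.shape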
Example #2
  def _create_gates(self, inputs, memory):
    """Create input and forget gates for this step using `inputs` and `memory`.

    Args:
      inputs: Tensor input.
      memory: The current state of memory.

    Returns:
      input_gate: An LSTM-like input gate.
      forget_gate: An LSTM-like forget gate.
    """
    # We'll create the input and forget gates at once. Hence, calculate double
    # the gate size.
    num_gates = 2 * self._calculate_gate_size()

    memory = tf.tanh(memory)
    inputs = basic.BatchFlatten()(inputs)
    gate_inputs = basic.BatchApply(basic.Linear(num_gates), n_dims=1)(inputs)
    gate_inputs = tf.expand_dims(gate_inputs, axis=1)
    gate_memory = basic.BatchApply(basic.Linear(num_gates))(memory)
    gates = tf.split(gate_memory + gate_inputs, num_or_size_splits=2, axis=2)
    input_gate, forget_gate = gates

    input_gate = tf.sigmoid(input_gate + self._input_bias)
    forget_gate = tf.sigmoid(forget_gate + self._forget_bias)

    return input_gate, forget_gate
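An illustrative NumPy version of the gating arithmetic (shapes and bias values are made up for the example): a single linear output of size `2 * gate_size` is split in two, the input projection broadcasts against the per-slot memory projection, and each half is squashed into a sigmoid gate.

import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

batch, mem_slots, gate_size = 2, 4, 8
gate_inputs = np.random.randn(batch, 1, 2 * gate_size)          # from the flattened inputs
gate_memory = np.random.randn(batch, mem_slots, 2 * gate_size)  # from tanh(memory)

# Broadcast the input term across slots, then split into the two gates.
input_gate, forget_gate = np.split(gate_memory + gate_inputs, 2, axis=2)
input_gate = sigmoid(input_gate + 0.0)    # illustrative input bias
forget_gate = sigmoid(forget_gate + 1.0)  # illustrative forget bias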
Example #3
  def _multihead_attention(self, memory):
    """Perform multi-head attention from 'Attention is All You Need'.

    Implementation of the attention mechanism from
    https://arxiv.org/abs/1706.03762.

    Args:
      memory: Memory tensor to perform attention on.

    Returns:
      new_memory: New memory tensor.
    """
    key_size = self._key_size
    value_size = self._head_size

    qkv_size = 2 * key_size + value_size
    total_size = qkv_size * self._num_heads  # Denote as F.
    qkv = basic.BatchApply(basic.Linear(total_size))(memory)
    qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)

    mem_slots = memory.get_shape().as_list()[1]  # Denoted as N.

    # [B, N, F] -> [B, N, H, F/H]
    qkv_reshape = basic.BatchReshape([mem_slots, self._num_heads,
                                      qkv_size])(qkv)

    # [B, N, H, F/H] -> [B, H, N, F/H]
    qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
    q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)

    q *= qkv_size ** -0.5
    dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
    weights = tf.nn.softmax(dot_product)

    output = tf.matmul(weights, v)  # [B, H, N, V]

    # [B, H, N, V] -> [B, N, H, V]
    output_transpose = tf.transpose(output, [0, 2, 1, 3])

    # [B, N, H, V] -> [B, N, H * V]
    new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)
    return new_memory
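The heart of this method is scaled dot-product attention applied independently per head. A minimal NumPy sketch of that step (shapes are illustrative, and it uses the conventional `key_size ** -0.5` scale where the snippet above uses `qkv_size ** -0.5`):

import numpy as np

def softmax(x, axis=-1):
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

B, H, N, key_size, value_size = 2, 3, 4, 5, 6
q = np.random.randn(B, H, N, key_size)
k = np.random.randn(B, H, N, key_size)
v = np.random.randn(B, H, N, value_size)

q = q * key_size ** -0.5                       # scale queries before the dot product
weights = softmax(q @ np.swapaxes(k, -1, -2))  # [B, H, N, N] attention weights
output = weights @ v                           # [B, H, N, value_size]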
Example #4
    def _affine_grid_warper_inverse(inputs):
      """Assembles network to compute inverse affine transformation.

      Each `inputs` row potentially contains [a, b, tx, c, d, ty]
      corresponding to an affine matrix:

        A = [a, b, tx],
            [c, d, ty]

      We want to generate a tensor containing the coefficients of the
      corresponding inverse affine transformation in a constraints-aware
      fashion.
      Calling M:

        M = [a, b]
            [c, d]

      the affine matrix for the inverse transform is:

         A_inv = [M^(-1), M^(-1) * [-tx, -ty]^T]

      where

        M^(-1) = (ad - bc)^(-1) * [ d, -b]
                                  [-c,  a]

      Args:
        inputs: Tensor containing a batch of transformation parameters.

      Returns:
        A tensorflow graph performing the inverse affine transformation
        parametrized by the input coefficients.
      """
      batch_size = tf.expand_dims(tf.shape(inputs)[0], 0)
      constant_shape = tf.concat([batch_size, tf.convert_to_tensor((1,))], 0)

      index = iter(range(6))
      def get_variable(constraint):
        if constraint is None:
          i = next(index)
          return inputs[:, i:i+1]
        else:
          return tf.fill(constant_shape, tf.constant(constraint,
                                                     dtype=inputs.dtype))

      constraints = chain.from_iterable(self.constraints)
      a, b, tx, c, d, ty = (get_variable(constr) for constr in constraints)

      det = a * d - b * c
      a_inv = d / det
      b_inv = -b / det
      c_inv = -c / det
      d_inv = a / det

      m_inv = basic.BatchReshape(
          [2, 2])(tf.concat([a_inv, b_inv, c_inv, d_inv], 1))

      txy = tf.expand_dims(tf.concat([tx, ty], 1), 2)

      txy_inv = basic.BatchFlatten()(tf.matmul(m_inv, txy))
      tx_inv = txy_inv[:, 0:1]
      ty_inv = txy_inv[:, 1:2]

      inverse_gw_inputs = tf.concat(
          [a_inv, b_inv, -tx_inv, c_inv, d_inv, -ty_inv], 1)

      agw = AffineGridWarper(self.output_shape,
                             self.source_shape)


      return agw(inverse_gw_inputs)  # pylint: disable=not-callable
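The algebra being implemented is the closed-form inverse of a 2-D affine map y = M x + t, whose inverse is x = M^(-1) y - M^(-1) t. A small NumPy check with made-up coefficients:

import numpy as np

# A = [[a, b, tx], [c, d, ty]] maps x to M @ x + t, with M = [[a, b], [c, d]].
a, b, tx, c, d, ty = 2.0, 0.5, 1.0, -0.5, 4.0, -3.0
M = np.array([[a, b], [c, d]])
t = np.array([tx, ty])

det = a * d - b * c
M_inv = np.array([[d, -b], [-c, a]]) / det  # adjugate over determinant
t_inv = -M_inv @ t                          # translation of the inverse map

x = np.array([0.5, -1.0])
y = M @ x + t
assert np.allclose(M_inv @ y + t_inv, x)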
Example #5
    def _build(self,
               inputs,
               keep_prob=None,
               is_training=True,
               test_local_stats=True):
        """Connects the AlexNet module into the graph.

    Args:
      inputs: A Tensor of size [batch_size, input_height, input_width,
        input_channels], representing a batch of input images.
      keep_prob: A scalar Tensor representing the dropout keep probability.
      is_training: Boolean to indicate to `snt.BatchNorm` if we are
        currently training. By default `True`.
      test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
        normalization should use local batch statistics at test time.
        By default `True`.

    Returns:
      A Tensor of size [batch_size, output_size], where `output_size` depends
      on the mode the network was constructed in.

    Raises:
      base.IncompatibleShapeError: If any of the input image dimensions
        (input_height, input_width) are too small for the given network mode.
    """

        input_shape = inputs.get_shape().as_list()

        if input_shape[1] < self._min_size or input_shape[2] < self._min_size:
            raise base.IncompatibleShapeError(
                "Image shape too small: ({:d}, {:d}) < {:d}".format(
                    input_shape[1], input_shape[2], self._min_size))

        net = inputs

        for i, params in enumerate(self._conv_layers):
            output_channels, conv_params, max_pooling = params

            kernel_size, stride = conv_params

            conv_mod = conv.Conv2D(name="conv_{}".format(i),
                                   output_channels=output_channels,
                                   kernel_shape=kernel_size,
                                   stride=stride,
                                   padding=conv.VALID,
                                   initializers=self._initializers,
                                   partitioners=self._partitioners,
                                   regularizers=self._regularizers)

            if not self.is_connected:
                self._conv_modules.append(conv_mod)

            net = conv_mod(net)

            if self._use_batch_norm:
                bn = batch_norm.BatchNorm(**self._batch_norm_config)
                net = bn(net, is_training, test_local_stats)

            net = tf.nn.relu(net)

            if max_pooling is not None:
                pooling_kernel_size, pooling_stride = max_pooling
                net = tf.nn.max_pool(
                    net,
                    ksize=[1, pooling_kernel_size, pooling_kernel_size, 1],
                    strides=[1, pooling_stride, pooling_stride, 1],
                    padding=conv.VALID)

        net = basic.BatchFlatten(name="flatten")(net)

        for i, output_size in enumerate(self._fc_layers):
            linear_mod = basic.Linear(name="fc_{}".format(i),
                                      output_size=output_size,
                                      partitioners=self._partitioners)

            if not self.is_connected:
                self._linear_modules.append(linear_mod)

            net = linear_mod(net)

            if self._use_batch_norm:
                bn = batch_norm.BatchNorm(**self._batch_norm_config)
                net = bn(net, is_training, test_local_stats)

            net = tf.nn.relu(net)

            if keep_prob is not None:
                net = tf.nn.dropout(net, keep_prob=keep_prob)

        return net
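The `_min_size` check guards against inputs that shrink to nothing under the VALID-padded conv and pool layers. A quick sketch of that size arithmetic (the kernel/stride values below are only illustrative, not the module's actual configuration):

def valid_output_size(size, kernel, stride):
  """Spatial size after one VALID-padded conv or max-pool layer."""
  return (size - kernel) // stride + 1

size = 224
size = valid_output_size(size, kernel=11, stride=4)  # 54
size = valid_output_size(size, kernel=3, stride=2)   # 26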
Example #6
    def _build(self,
               inputs,
               keep_prob=None,
               is_training=None,
               test_local_stats=True):
        """Connects the AlexNet module into the graph.

    The `is_training` flag only controls the batch norm settings; if `False`, it
    does not disable dropout by overriding any input `keep_prob`. To avoid the
    confusion this may cause, an error is thrown if `is_training=False` and
    `keep_prob` would cause dropout to be applied.

    Args:
      inputs: A Tensor of size [batch_size, input_height, input_width,
        input_channels], representing a batch of input images.
      keep_prob: A scalar Tensor representing the dropout keep probability.
        When `is_training=False` this must be None or 1 to give no dropout.
      is_training: Boolean to indicate if we are currently training. Must be
          specified if batch normalization or dropout is used.
      test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
        normalization should use local batch statistics at test time.
        By default `True`.

    Returns:
      A Tensor of size [batch_size, output_size], where `output_size` depends
      on the mode the network was constructed in.

    Raises:
      base.IncompatibleShapeError: If any of the input image dimensions
        (input_height, input_width) are too small for the given network mode.
      ValueError: If `keep_prob` is not None or 1 when `is_training=False`.
      ValueError: If `is_training` is not explicitly specified when using
        batch normalization.
    """
        # Check input shape
        if (self._use_batch_norm
                or keep_prob is not None) and is_training is None:
            raise ValueError(
                "Boolean is_training flag must be explicitly specified "
                "when using batch normalization or dropout.")

        input_shape = inputs.get_shape().as_list()
        if input_shape[1] < self._min_size or input_shape[2] < self._min_size:
            raise base.IncompatibleShapeError(
                "Image shape too small: ({:d}, {:d}) < {:d}".format(
                    input_shape[1], input_shape[2], self._min_size))

        net = inputs

        # Check keep prob
        if keep_prob is not None:
            valid_inputs = tf.logical_or(is_training, tf.equal(keep_prob, 1.))
            keep_prob_check = tf.assert_equal(
                valid_inputs,
                True,
                message=
                "Input `keep_prob` must be None or 1 if `is_training=False`.")
            with tf.control_dependencies([keep_prob_check]):
                net = tf.identity(net)

        for i, params in enumerate(self._conv_layers):
            output_channels, conv_params, max_pooling = params

            kernel_size, stride = conv_params

            conv_mod = conv.Conv2D(name="conv_{}".format(i),
                                   output_channels=output_channels,
                                   kernel_shape=kernel_size,
                                   stride=stride,
                                   padding=conv.VALID,
                                   initializers=self._initializers,
                                   partitioners=self._partitioners,
                                   regularizers=self._regularizers)

            if not self.is_connected:
                self._conv_modules.append(conv_mod)

            net = conv_mod(net)

            if self._use_batch_norm:
                bn = batch_norm.BatchNorm(**self._batch_norm_config)
                net = bn(net, is_training, test_local_stats)

            net = tf.nn.relu(net)

            if max_pooling is not None:
                pooling_kernel_size, pooling_stride = max_pooling
                net = tf.nn.max_pool(
                    net,
                    ksize=[1, pooling_kernel_size, pooling_kernel_size, 1],
                    strides=[1, pooling_stride, pooling_stride, 1],
                    padding=conv.VALID)

        net = basic.BatchFlatten(name="flatten")(net)

        for i, output_size in enumerate(self._fc_layers):
            linear_mod = basic.Linear(name="fc_{}".format(i),
                                      output_size=output_size,
                                      initializers=self._initializers,
                                      partitioners=self._partitioners)

            if not self.is_connected:
                self._linear_modules.append(linear_mod)

            net = linear_mod(net)

            if self._use_batch_norm and self._bn_on_fc_layers:
                bn = batch_norm.BatchNorm(**self._batch_norm_config)
                net = bn(net, is_training, test_local_stats)

            net = tf.nn.relu(net)

            if keep_prob is not None:
                net = tf.nn.dropout(net, keep_prob=keep_prob)

        return net
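For reference, `tf.nn.dropout(net, keep_prob=keep_prob)` applies inverted dropout: units are zeroed with probability `1 - keep_prob` and the survivors are rescaled by `1 / keep_prob`, which is why passing a real `keep_prob` with `is_training=False` is rejected above. A NumPy sketch of that behaviour:

import numpy as np

def inverted_dropout(x, keep_prob, rng):
  # Zero each unit with probability 1 - keep_prob and rescale the rest
  # by 1 / keep_prob so the expected activation is unchanged.
  mask = rng.random(x.shape) < keep_prob
  return np.where(mask, x / keep_prob, 0.0)

rng = np.random.default_rng(0)
activations = np.ones((4, 8))
dropped = inverted_dropout(activations, keep_prob=0.5, rng=rng)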