Example #1
  def _create_gates(self, inputs, memory):
    """Create input and forget gates for this step using `inputs` and `memory`.

    Args:
      inputs: Tensor input.
      memory: The current state of memory.

    Returns:
      input_gate: An LSTM-like input gate.
      forget_gate: An LSTM-like forget gate.
    """
    # We'll create the input and forget gates at once. Hence, calculate double
    # the gate size.
    num_gates = 2 * self._calculate_gate_size()

    memory = tf.tanh(memory)
    inputs = basic.BatchFlatten()(inputs)
    gate_inputs = basic.BatchApply(basic.Linear(num_gates), n_dims=1)(inputs)
    gate_inputs = tf.expand_dims(gate_inputs, axis=1)
    gate_memory = basic.BatchApply(basic.Linear(num_gates))(memory)
    gates = tf.split(gate_memory + gate_inputs, num_or_size_splits=2, axis=2)
    input_gate, forget_gate = gates

    input_gate = tf.sigmoid(input_gate + self._input_bias)
    forget_gate = tf.sigmoid(forget_gate + self._forget_bias)

    return input_gate, forget_gate
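The gating arithmetic is easier to follow with concrete shapes. Below is a minimal NumPy sketch (not the Sonnet module itself) of the same broadcast, split, and sigmoid steps; the dimensions and bias values are made up for illustration.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

batch, mem_slots, gate_size = 2, 4, 8
# Projection of the flattened inputs, expanded to broadcast over memory slots.
gate_inputs = np.random.randn(batch, 1, 2 * gate_size)
# Projection of tanh(memory), one row per memory slot.
gate_memory = np.random.randn(batch, mem_slots, 2 * gate_size)

# Broadcasting adds the per-step input projection to every slot; a single split
# then yields both gates, mirroring tf.split(..., num_or_size_splits=2, axis=2).
gates = gate_inputs + gate_memory
input_gate, forget_gate = np.split(gates, 2, axis=2)
input_gate = sigmoid(input_gate + (-2.0))   # illustrative input bias
forget_gate = sigmoid(forget_gate + 1.0)    # illustrative forget bias
print(input_gate.shape, forget_gate.shape)  # (2, 4, 8) (2, 4, 8)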
Example #2
    def _build(self, inputs, memory, treat_input_as_matrix=False):
        """Adds relational memory to the TensorFlow graph.
    Args:
      inputs: Tensor input.
      memory: Memory output from the previous time step.
      treat_input_as_matrix: Optional, whether to treat `input` as a sequence
        of matrices. Defaulta to False, in which case the input is flattened
        into a vector.
    Returns:
      output: This time step's output.
      next_memory: The next version of memory to use.
    """
        if treat_input_as_matrix:
            inputs = basic.BatchFlatten(preserve_dims=2)(inputs)
            inputs_reshape = basic.BatchApply(basic.Linear(self._mem_size),
                                              n_dims=2)(inputs)
        else:
            inputs = basic.BatchFlatten()(inputs)
            inputs = basic.Linear(self._mem_size)(inputs)
            inputs_reshape = tf.expand_dims(inputs, 1)

        memory_plus_input = tf.concat([memory, inputs_reshape], axis=1)
        next_memory = self._attend_over_memory(memory_plus_input)

        n = inputs_reshape.get_shape().as_list()[1]
        next_memory = next_memory[:, :-n, :]

        if self._gate_style == 'unit' or self._gate_style == 'memory':
            self._input_gate, self._forget_gate = self._create_gates(
                inputs_reshape, memory)
            next_memory = self._input_gate * tf.tanh(next_memory)
            next_memory += self._forget_gate * memory

        output = basic.BatchFlatten()(next_memory)
        return output, next_memory
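A NumPy sketch of the memory bookkeeping in this `_build`: the projected inputs are appended to memory as extra rows, attention runs over the concatenation (stubbed out here as the identity, an assumption for illustration only), and the appended rows are then cropped off so memory keeps a fixed number of slots.

import numpy as np

batch, mem_slots, mem_size, n_inputs = 2, 4, 6, 1
memory = np.random.randn(batch, mem_slots, mem_size)
inputs_reshape = np.random.randn(batch, n_inputs, mem_size)

memory_plus_input = np.concatenate([memory, inputs_reshape], axis=1)  # [B, S + n, F]
next_memory = memory_plus_input        # stand-in for self._attend_over_memory(...)
next_memory = next_memory[:, :-n_inputs, :]                           # back to [B, S, F]
print(next_memory.shape)  # (2, 4, 6)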
Example #3
    def _build(self, x, prev_state):
        """Connects the core to the graph.

    Args:
      x: Input `Tensor` of shape `(batch_size, input_size)`.
      prev_state: Previous state. This could be a `Tensor`, or a tuple of
          `Tensor`s.

    Returns:
      The tuple `(output, state)` for this core.

    Raises:
      ValueError: if the `Tensor` `x` does not have rank 2.
    """
        x.get_shape().with_rank(2)
        self._batch_size = x.get_shape().as_list()[0]
        self._dtype = x.dtype

        x_zeros = tf.concat(
            [x, tf.zeros(shape=(self._batch_size, 1), dtype=self._dtype)], 1)
        x_ones = tf.concat(
            [x, tf.ones(shape=(self._batch_size, 1), dtype=self._dtype)], 1)
        # Weights for the halting signal
        halting_linear = basic.Linear(name="halting_linear", output_size=1)

        body = functools.partial(self._body,
                                 halting_linear=halting_linear,
                                 x_ones=x_ones)
        cumul_halting_init = tf.zeros(shape=(self._batch_size, 1),
                                      dtype=self._dtype)
        iteration_init = tf.zeros(shape=(self._batch_size, 1),
                                  dtype=self._dtype)
        core_output_size = [dim.value for dim in self._core.output_size]
        out_init = tf.zeros(shape=(self._batch_size, ) +
                            tuple(core_output_size),
                            dtype=self._dtype)
        cumul_state_init = _nested_zeros_like(prev_state)
        remainder_init = tf.zeros(shape=(self._batch_size, 1),
                                  dtype=self._dtype)
        (unused_final_x, final_out, unused_final_state, final_cumul_state,
         unused_final_halting, final_iteration,
         final_remainder) = tf.while_loop(self._cond, body, [
             x_zeros, out_init, prev_state, cumul_state_init,
             cumul_halting_init, iteration_init, remainder_init
         ])

        act_output = basic.Linear(name="act_output_linear",
                                  output_size=self._output_size)(final_out)

        # Modification: the controller state and the read vectors use the pondering
        # cumulative-weighting scheme; the memory matrix does not, so it keeps
        # temporal continuity across the unrolled time steps.
        controller_state, access_state, read_vectors = final_cumul_state
        final_cumul_state = (controller_state, unused_final_state[1],
                             read_vectors)

        return (act_output, (final_iteration,
                             final_remainder)), final_cumul_state
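The `tf.while_loop` above implements adaptive-computation-time (ACT) style pondering: each iteration emits an output and a halting probability, outputs are accumulated with those probabilities as weights, and the loop stops once the cumulative halting signal crosses a threshold. A toy plain-Python version of that control flow (the step function, threshold, and scalar output are illustrative, not the module's exact formulation):

def act_ponder(step_fn, state, threshold=0.99, max_steps=10):
    """Toy ACT loop: step_fn(state) returns (output, halt_p, new_state)."""
    cumul_halt, cumul_out, remainder = 0.0, 0.0, 1.0
    for step in range(1, max_steps + 1):
        output, halt_p, state = step_fn(state)
        if cumul_halt + halt_p >= threshold or step == max_steps:
            remainder = 1.0 - cumul_halt    # leftover probability mass
            cumul_out += remainder * output
            break
        cumul_out += halt_p * output
        cumul_halt += halt_p
    return cumul_out, state, step, remainder

out, state, steps, rem = act_ponder(lambda s: (1.0, 0.3, s), state=None)
print(out, steps, rem)  # ~1.0, 4 steps, remainder ~0.1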
Example #4
    def _instantiate_layers(self):
        """Instantiates all the linear modules used in the network.

    Layers are instantiated in the constructor, as opposed to the build
    function, because MLP implements the Transposable interface, and the
    transpose function can be called before the module is actually connected
    to the graph and build is called.

    Notice that this is safe since layers in the transposed module are
    instantiated using a lambda returning input_size of the mlp layers, and
    this doesn't have to return sensible values until the original module is
    connected to the graph.
    """

        # Here we are entering the module's variable scope to name our submodules
        # correctly (not to create variables). As such it's safe to not check
        # whether we're in the same graph. This is important if we're constructing
        # the module in one graph and connecting it in another (e.g. with `defun`
        # the module is created in some default graph, and connected to a capturing
        # graph in order to turn it into a graph function).
        with self._enter_variable_scope(check_same_graph=False):
            self._layers = [
                basic.Linear(self._output_sizes[i],
                             name="linear_{}".format(i),
                             initializers=self._initializers,
                             partitioners=self._partitioners,
                             regularizers=self._regularizers,
                             use_bias=self.use_bias)
                for i in xrange(self._num_layers)
            ]
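Why eager instantiation plus lazy input sizes is safe is easier to see in a stripped-down sketch (plain Python, not Sonnet): the transposed module only stores callables, and they are not evaluated until the forward module has been connected and knows its input widths.

class LazyLinear:
    def __init__(self, output_size):
        self.output_size = output_size
        self.input_size = None         # only known once the layer is "connected"

    def __call__(self, in_width):
        self.input_size = in_width
        return self.output_size

output_sizes = [16, 8, 4]
layers = [LazyLinear(s) for s in output_sizes]                           # eager creation
transposed_sizes = [lambda l=l: l.input_size for l in reversed(layers)]  # lazy lookups

width = 32
for layer in layers:                    # connect the forward MLP
    width = layer(width)
print([f() for f in transposed_sizes])  # [8, 16, 32]: layer sizes of the transpose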
Example #5
    def _instantiate_layers(self):
        """Instantiates all the linear modules used in the network.

        Layers are instantiated in the constructor, as opposed to the build
        function, because MLP implements the Transposable interface, and the
        transpose function can be called before the module is actually connected
        to the graph and build is called.

        Notice that this is safe since layers in the transposed module are
        instantiated using a lambda returning input_size of the mlp layers, and
        this doesn't have to return sensible values until the original module is
        connected to the graph.
        """

        with self._enter_variable_scope():
            self._layers = [
                basic.Linear(
                    self._output_sizes[i],
                    name="linear_{}".format(i),
                    initializers=self._initializers,
                    partitioners=self._partitioners,
                    regularizers=self._regularizers,
                    use_bias=self.use_bias,
                ) for i in xrange(self._num_layers)
            ]
Example #6
    def _build(self, input_, prev_state):
        """Connects the VanillaRNN module into the graph.

        If this is not the first time the module has been connected to the graph,
        the Tensors provided as input_ and state must have the same final
        dimension, in order for the existing variables to be the correct size for
        their corresponding multiplications. The batch size may differ for each
        connection.

        Args:
          input_: a 2D Tensor of size [batch_size, input_size].
          prev_state: a 2D Tensor of size [batch_size, hidden_size].

        Returns:
          output: a 2D Tensor of size [batch_size, hidden_size].
          next_state: a Tensor of size [batch_size, hidden_size].

        Raises:
          ValueError: if connecting the module into the graph any time after the
            first time, and the inferred size of the inputs does not match previous
            invocations.
        """
        self._in_to_hidden_linear = basic.Linear(
            self._hidden_size,
            name="in_to_hidden",
            initializers=self._initializers.get("in_to_hidden"),
            partitioners=self._partitioners.get("in_to_hidden"),
            regularizers=self._regularizers.get("in_to_hidden"),
        )

        self._hidden_to_hidden_linear = basic.Linear(
            self._hidden_size,
            name="hidden_to_hidden",
            initializers=self._initializers.get("hidden_to_hidden"),
            partitioners=self._partitioners.get("hidden_to_hidden"),
            regularizers=self._regularizers.get("hidden_to_hidden"),
        )

        in_to_hidden = self._in_to_hidden_linear(input_)
        hidden_to_hidden = self._hidden_to_hidden_linear(prev_state)
        output = self._activation(in_to_hidden + hidden_to_hidden)

        # For VanillaRNN, the next state of the RNN is the same as the output
        return output, output
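The recurrence itself reduces to two matrix multiplications and an activation. A NumPy sketch with made-up sizes; the weight matrices below stand in for the two `Linear` modules.

import numpy as np

batch, input_size, hidden_size = 3, 5, 7
rng = np.random.default_rng(0)
w_in = rng.normal(size=(input_size, hidden_size))    # in_to_hidden weights
w_hh = rng.normal(size=(hidden_size, hidden_size))   # hidden_to_hidden weights

x = rng.normal(size=(batch, input_size))
prev_state = np.zeros((batch, hidden_size))

# output = activation(x W_in + h W_hh); the next state is the output itself.
output = np.tanh(x @ w_in + prev_state @ w_hh)
next_state = output
print(output.shape)  # (3, 7)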
Example #7
    def testModuleInfo_multiple_modules(self):
        # pylint: disable=not-callable
        tf.reset_default_graph()
        dumb = DumbModule(name="dumb")
        dumb_1 = DumbModule(name="dumb")
        linear = basic.Linear(10, name="linear")
        ph_0 = tf.placeholder(dtype=tf.float32, shape=(1, 10))
        dumb(ph_0)
        with tf.name_scope("foo"):
            dumb_1(ph_0)
        linear(ph_0)

        def check():
            sonnet_collection = tf.get_default_graph().get_collection(
                base_info.SONNET_COLLECTION_NAME)
            self.assertEqual(len(sonnet_collection), 3)
            # item 0.
            self.assertEqual(sonnet_collection[0].module_name, "dumb")
            self.assertEqual(sonnet_collection[0].class_name,
                             "{}.DumbModule".format(THIS_MODULE))
            self.assertEqual(sonnet_collection[0].scope_name, "dumb")
            self.assertEqual(len(sonnet_collection[0].connected_subgraphs), 1)
            self.assertEqual(
                sonnet_collection[0].connected_subgraphs[0].name_scope, "dumb")
            # item 1.
            self.assertEqual(sonnet_collection[1].module_name, "dumb_1")
            self.assertEqual(sonnet_collection[1].scope_name, "dumb_1")
            self.assertEqual(sonnet_collection[1].class_name,
                             "{}.DumbModule".format(THIS_MODULE))
            self.assertEqual(sonnet_collection[1].scope_name, "dumb_1")
            self.assertEqual(len(sonnet_collection[1].connected_subgraphs), 1)
            self.assertEqual(
                sonnet_collection[1].connected_subgraphs[0].name_scope,
                "foo/dumb_1")
            # item 2.
            self.assertEqual(sonnet_collection[2].module_name, "linear")
            self.assertEqual(sonnet_collection[2].scope_name, "linear")
            self.assertEqual(sonnet_collection[2].class_name,
                             "{}.Linear".format(LINEAR_MODULE))
            self.assertEqual(sonnet_collection[2].scope_name, "linear")
            self.assertEqual(len(sonnet_collection[2].connected_subgraphs), 1)
            self.assertEqual(
                sonnet_collection[2].connected_subgraphs[0].name_scope,
                "linear")

        check()
        _copy_default_graph()
        check()
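The assertions rely on Sonnet giving the second module constructed with name="dumb" the unique scope name "dumb_1". A toy registry (not Sonnet's actual implementation) showing just that uniquing behaviour:

from collections import defaultdict

class NameRegistry:
    def __init__(self):
        self._counts = defaultdict(int)

    def unique(self, name):
        n = self._counts[name]
        self._counts[name] += 1
        return name if n == 0 else "{}_{}".format(name, n)

reg = NameRegistry()
print(reg.unique("dumb"), reg.unique("dumb"), reg.unique("linear"))
# dumb dumb_1 linear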
Example #8
  def _multihead_attention(self, memory):
    """Perform multi-head attention from 'Attention is All You Need'.

    Implementation of the attention mechanism from
    https://arxiv.org/abs/1706.03762.

    Args:
      memory: Memory tensor to perform attention on.

    Returns:
      new_memory: New memory tensor.
    """
    key_size = self._key_size
    value_size = self._head_size

    qkv_size = 2 * key_size + value_size
    total_size = qkv_size * self._num_heads  # Denoted as F.
    qkv = basic.BatchApply(basic.Linear(total_size))(memory)
    qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)

    mem_slots = memory.get_shape().as_list()[1]  # Denoted as N.

    # [B, N, F] -> [B, N, H, F/H]
    qkv_reshape = basic.BatchReshape([mem_slots, self._num_heads,
                                      qkv_size])(qkv)

    # [B, N, H, F/H] -> [B, H, N, F/H]
    qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
    q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)

    q *= qkv_size ** -0.5
    dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
    weights = tf.nn.softmax(dot_product)

    output = tf.matmul(weights, v)  # [B, H, N, V]

    # [B, H, N, V] -> [B, N, H, V]
    output_transpose = tf.transpose(output, [0, 2, 1, 3])

    # [B, N, H, V] -> [B, N, H * V]
    new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)
    return new_memory
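A NumPy sketch of the shape plumbing in `_multihead_attention`, with random arrays standing in for the `Linear`/`LayerNorm` projections; head counts and sizes are illustrative.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

B, N, H, K, V = 2, 4, 3, 5, 6               # batch, memory slots, heads, key, value
qkv_size = 2 * K + V
rng = np.random.default_rng(0)
qkv = rng.normal(size=(B, N, H, qkv_size))  # after the Linear + BatchReshape
qkv = np.transpose(qkv, (0, 2, 1, 3))       # [B, N, H, F/H] -> [B, H, N, F/H]
q, k, v = np.split(qkv, [K, 2 * K], axis=-1)

q = q * qkv_size ** -0.5                    # same scaling as in the snippet
weights = softmax(q @ np.swapaxes(k, -1, -2))        # [B, H, N, N]
out = weights @ v                                    # [B, H, N, V]
new_memory = np.transpose(out, (0, 2, 1, 3)).reshape(B, N, H * V)
print(new_memory.shape)  # (2, 4, 18)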
Example #9
    def _build(self,
               inputs,
               keep_prob=None,
               is_training=True,
               test_local_stats=True):
        """Connects the AlexNet module into the graph.

    Args:
      inputs: A Tensor of size [batch_size, input_height, input_width,
        input_channels], representing a batch of input images.
      keep_prob: A scalar Tensor representing the dropout keep probability.
      is_training: Boolean to indicate to `snt.BatchNorm` if we are
        currently training. By default `True`.
      test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
        normalization should use local batch statistics at test time.
        By default `True`.

    Returns:
      A Tensor of size [batch_size, output_size], where `output_size` depends
      on the mode the network was constructed in.

    Raises:
      base.IncompatibleShapeError: If any of the input image dimensions
        (input_height, input_width) are too small for the given network mode.
    """

        input_shape = inputs.get_shape().as_list()

        if input_shape[1] < self._min_size or input_shape[2] < self._min_size:
            raise base.IncompatibleShapeError(
                "Image shape too small: ({:d}, {:d}) < {:d}".format(
                    input_shape[1], input_shape[2], self._min_size))

        net = inputs

        for i, params in enumerate(self._conv_layers):
            output_channels, conv_params, max_pooling = params

            kernel_size, stride = conv_params

            conv_mod = conv.Conv2D(name="conv_{}".format(i),
                                   output_channels=output_channels,
                                   kernel_shape=kernel_size,
                                   stride=stride,
                                   padding=conv.VALID,
                                   initializers=self._initializers,
                                   partitioners=self._partitioners,
                                   regularizers=self._regularizers)

            if not self.is_connected:
                self._conv_modules.append(conv_mod)

            net = conv_mod(net)

            if self._use_batch_norm:
                bn = batch_norm.BatchNorm(**self._batch_norm_config)
                net = bn(net, is_training, test_local_stats)

            net = tf.nn.relu(net)

            if max_pooling is not None:
                pooling_kernel_size, pooling_stride = max_pooling
                net = tf.nn.max_pool(
                    net,
                    ksize=[1, pooling_kernel_size, pooling_kernel_size, 1],
                    strides=[1, pooling_stride, pooling_stride, 1],
                    padding=conv.VALID)

        net = basic.BatchFlatten(name="flatten")(net)

        for i, output_size in enumerate(self._fc_layers):
            linear_mod = basic.Linear(name="fc_{}".format(i),
                                      output_size=output_size,
                                      partitioners=self._partitioners)

            if not self.is_connected:
                self._linear_modules.append(linear_mod)

            net = linear_mod(net)

            if self._use_batch_norm:
                bn = batch_norm.BatchNorm(**self._batch_norm_config)
                net = bn(net, is_training, test_local_stats)

            net = tf.nn.relu(net)

            if keep_prob is not None:
                net = tf.nn.dropout(net, keep_prob=keep_prob)

        return net
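The `self._min_size` guard exists because each VALID convolution and pooling shrinks the spatial size, and a too-small image would shrink to nothing. A rough sketch of how such a minimum can be derived by walking a layer spec backwards; the layer tuples below are made up and only mirror the (output_channels, (kernel, stride), pooling) structure iterated over above.

def min_input_size(conv_layers):
    size = 1  # require at least a 1x1 feature map after the last layer
    for _, (kernel, stride), pooling in reversed(conv_layers):
        if pooling is not None:
            pool_kernel, pool_stride = pooling
            size = (size - 1) * pool_stride + pool_kernel
        size = (size - 1) * stride + kernel
    return size

layers = [(64, (11, 4), (3, 2)), (192, (5, 1), (3, 2)), (384, (3, 1), None)]
print(min_input_size(layers))  # smallest height/width that survives all VALID ops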
Example #10
    def _build(self,
               inputs,
               keep_prob=None,
               is_training=None,
               test_local_stats=True):
        """Connects the AlexNet module into the graph.

    The `is_training` flag only controls the batch norm settings; if `False`, it
    does not override any input `keep_prob` to disable dropout. To avoid the
    confusion this may cause, an error is raised if `is_training=False` and
    `keep_prob` would cause dropout to be applied.

    Args:
      inputs: A Tensor of size [batch_size, input_height, input_width,
        input_channels], representing a batch of input images.
      keep_prob: A scalar Tensor representing the dropout keep probability.
        When `is_training=False` this must be None or 1 to give no dropout.
      is_training: Boolean to indicate if we are currently training. Must be
          specified if batch normalization or dropout is used.
      test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
        normalization should use local batch statistics at test time.
        By default `True`.

    Returns:
      A Tensor of size [batch_size, output_size], where `output_size` depends
      on the mode the network was constructed in.

    Raises:
      base.IncompatibleShapeError: If any of the input image dimensions
        (input_height, input_width) are too small for the given network mode.
      ValueError: If `keep_prob` is not None or 1 when `is_training=False`.
      ValueError: If `is_training` is not explicitly specified when using
        batch normalization.
    """
        # Check input shape
        if (self._use_batch_norm
                or keep_prob is not None) and is_training is None:
            raise ValueError(
                "Boolean is_training flag must be explicitly specified "
                "when using batch normalization or dropout.")

        input_shape = inputs.get_shape().as_list()
        if input_shape[1] < self._min_size or input_shape[2] < self._min_size:
            raise base.IncompatibleShapeError(
                "Image shape too small: ({:d}, {:d}) < {:d}".format(
                    input_shape[1], input_shape[2], self._min_size))

        net = inputs

        # Check keep prob
        if keep_prob is not None:
            valid_inputs = tf.logical_or(is_training, tf.equal(keep_prob, 1.))
            keep_prob_check = tf.assert_equal(
                valid_inputs, True,
                message="Input `keep_prob` must be None or 1 if "
                "`is_training=False`.")
            with tf.control_dependencies([keep_prob_check]):
                net = tf.identity(net)

        for i, params in enumerate(self._conv_layers):
            output_channels, conv_params, max_pooling = params

            kernel_size, stride = conv_params

            conv_mod = conv.Conv2D(name="conv_{}".format(i),
                                   output_channels=output_channels,
                                   kernel_shape=kernel_size,
                                   stride=stride,
                                   padding=conv.VALID,
                                   initializers=self._initializers,
                                   partitioners=self._partitioners,
                                   regularizers=self._regularizers)

            if not self.is_connected:
                self._conv_modules.append(conv_mod)

            net = conv_mod(net)

            if self._use_batch_norm:
                bn = batch_norm.BatchNorm(**self._batch_norm_config)
                net = bn(net, is_training, test_local_stats)

            net = tf.nn.relu(net)

            if max_pooling is not None:
                pooling_kernel_size, pooling_stride = max_pooling
                net = tf.nn.max_pool(
                    net,
                    ksize=[1, pooling_kernel_size, pooling_kernel_size, 1],
                    strides=[1, pooling_stride, pooling_stride, 1],
                    padding=conv.VALID)

        net = basic.BatchFlatten(name="flatten")(net)

        for i, output_size in enumerate(self._fc_layers):
            linear_mod = basic.Linear(name="fc_{}".format(i),
                                      output_size=output_size,
                                      initializers=self._initializers,
                                      partitioners=self._partitioners)

            if not self.is_connected:
                self._linear_modules.append(linear_mod)

            net = linear_mod(net)

            if self._use_batch_norm and self._bn_on_fc_layers:
                bn = batch_norm.BatchNorm(**self._batch_norm_config)
                net = bn(net, is_training, test_local_stats)

            net = tf.nn.relu(net)

            if keep_prob is not None:
                net = tf.nn.dropout(net, keep_prob=keep_prob)

        return net
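The contract between `keep_prob` and `is_training` described in the docstring can be summarised as a plain-Python check; this is a sketch of the intent, whereas the code above enforces the second condition in-graph with `tf.assert_equal`.

def check_dropout_args(use_batch_norm, keep_prob, is_training):
    if (use_batch_norm or keep_prob is not None) and is_training is None:
        raise ValueError("is_training must be specified when using "
                         "batch normalization or dropout.")
    if is_training is False and keep_prob is not None and keep_prob != 1.0:
        raise ValueError("keep_prob must be None or 1 when is_training=False.")

check_dropout_args(use_batch_norm=True, keep_prob=0.5, is_training=True)   # ok
try:
    check_dropout_args(use_batch_norm=False, keep_prob=0.5, is_training=False)
except ValueError as e:
    print(e)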
Example #11
    def _build(self,
               inputs,
               state=None,
               condition=None,
               is_training=True,
               final_layer_key_value_inputs=None):
        """Calculates multi-layer self attention and mlp transformation.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, dim_size].
      state: optional list of length num_layers of tensors of shape
        [batch_size, memory_size, dim_size].
      condition: optional tensor to condition on. The shape is shape
        [batch_size, dim_size].
      is_training: If true, dropout is applied.
      final_layer_key_value_inputs: optional Tensor to be used as the key and
        value for the final multi-head attention layer of shape
        [batch_size, num_steps, dim_size]. Useful when the tower is a Seq2Seq
        decoder and it can attend to encoder outputs.

    Returns:
      output: tensor of shape [batch_size, num_steps, output_dim_size].
      state: list of length `num_layers` containing AttentionState tuples.
    """
        # inputs: [B, N, F]
        if final_layer_key_value_inputs is not None and state is not None and len(
                state) == (self._num_layers - 1):
            raise ValueError(
                'When final_layer_key_value_inputs is set, exclude '
                'the state of the last layer.')

        if condition is not None:
            condition_tile = tf.tile(tf.expand_dims(condition, 1),
                                     [1, tf.shape(inputs)[1], 1])
            inputs = tf.concat([inputs, condition_tile], -1)

        # Map inputs to be of `embedding_size` dimension.
        if inputs.get_shape().as_list()[-1] != self._embedding_size:
            inputs = default_mlp([self._embedding_size], activate_final=True)(
                inputs,
                is_training=is_training,
                dropout_keep_prob=1 - self._dropout_rate)

        if state is None:
            memory_sizes = [0]
        elif isinstance(state[0], CompressedMemoryState):
            cm_mem_size = max(_memory_size(s.compressed_memory) for s in state)
            em_mem_size = max(_memory_size(s.episodic_memory) for s in state)
            memory_sizes = [cm_mem_size, em_mem_size]
        else:
            memory_sizes = [max([_memory_size(s) for s in state])]
        chunk_size = inputs.get_shape().as_list()[1]
        self._positional_encodings = []
        # Creates positional encodings for different memory types.
        for i, memory_size in enumerate(memory_sizes):
            seq_len = chunk_size + memory_size
            key_positions = get_position_encodings(
                sequence_length=seq_len,
                hidden_size=inputs.get_shape().as_list()[2],
                clamp_value=self._clamp_time_range,
            )
            if is_training:
                key_positions = tf.nn.dropout(key_positions,
                                              rate=self._dropout_rate)
            key_positions = tf.cast(key_positions, dtype=inputs.dtype)
            query_positions = key_positions[:, -chunk_size:, :]
            self._positional_encodings.append((key_positions, query_positions))

        if self._causal:
            self._mask = create_mask(inputs, state,
                                     self._same_attention_length)

        layer_i_inputs = inputs
        attention_states = []
        key_value_inputs = None

        for i in range(self._num_layers):
            with tf.variable_scope('layer_%d' % i, reuse=tf.AUTO_REUSE):
                multihead_attention, object_mlp = self.get_sublayers(
                    is_training)
                # Multihead attention with residuals.
                state_i = None if state is None else state[i]
                if i == (self._num_layers -
                         1) and final_layer_key_value_inputs is not None:
                    # When final_layer_key_value_inputs is set, the final layer of
                    # attention uses it as the key & value, so no state is needed.
                    key_value_inputs = final_layer_key_value_inputs
                    state_i = None

                attention_outputs, attention_state = multihead_attention(
                    layer_i_inputs,
                    state=state_i,
                    is_training=is_training,
                    dropout_keep_prob=1. - self._dropout_rate,
                    key_value_inputs=key_value_inputs)
                attention_states.append(attention_state)
                # Feed-forward with residuals.
                output = object_mlp(attention_outputs,
                                    is_training=is_training,
                                    dropout_keep_prob=1 - self._dropout_rate)
                layer_i_inputs = output

        if self._output_size is not None:
            output = basic.BatchApply(
                basic.Linear(self._output_size, use_bias=False))(output)

        return output, attention_states
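The per-layer control flow, including the special case where the last layer attends over `final_layer_key_value_inputs` instead of its own state slice, can be reduced to the stub below; the attention and MLP sublayers are stand-in callables, not the real modules.

def tower(inputs, state, num_layers, attention_fn, mlp_fn,
          final_layer_key_value_inputs=None):
    layer_inputs, attention_states = inputs, []
    for i in range(num_layers):
        state_i = None if state is None else state[i]
        kv = None
        if i == num_layers - 1 and final_layer_key_value_inputs is not None:
            kv, state_i = final_layer_key_value_inputs, None  # last layer: external K/V
        out, att_state = attention_fn(layer_inputs, state_i, kv)   # attention + residual
        attention_states.append(att_state)
        layer_inputs = mlp_fn(out)                                 # feed-forward + residual
    return layer_inputs, attention_states

out, states = tower('x', state=None, num_layers=3,
                    attention_fn=lambda x, s, kv: (x, (s, kv)),
                    mlp_fn=lambda x: x)
print(out, len(states))  # x 3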
Example #12
    def _build(self,
               inputs,
               query_inputs=None,
               state=None,
               is_training=False,
               dropout_keep_prob=0.5,
               key_value_inputs=None):
        """Calculates multi-layer self attention.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, output_dim_size]. Inputs
        used as the query, key, and value to the attention layer.
      query_inputs: optional Tensor of shape [batch_size, num_steps,
        output_dim_size]. Query inputs to the attention layer. Set when
        query_inputs is different from the inputs argument.
      state: optional CompressedMemoryState or a Tensor of shape [batch_size,
        memory_size, dim_size] concatenated to the inputs. Set when attend to
        the memory from previous steps.
      is_training: if currently training.
      dropout_keep_prob: dropout rate applied to attention weights.
      key_value_inputs: optional Tensor of shape [batch_size, num_steps,
        output_dim_size]. It is used as the key and value of the multihead
        attention. Set when the key and value are different from the inputs
        argument.

    Returns:
      output: the result Tensor of shape
        [batch_size, num_steps, output_dim_size].
      attention_state: named tuple of AttentionState.
    """
        if key_value_inputs is not None and state is not None:
            raise ValueError(
                'Only one of the key_value_input and state is needed.')
        embedding_size = self._value_size * self._num_heads

        q_inputs = inputs if query_inputs is None else query_inputs
        # Denoted by L. If query_inputs is None, L = N.
        _, query_size = q_inputs.get_shape().as_list()[:2]

        if key_value_inputs is not None:
            k_inputs = key_value_inputs
            v_inputs = k_inputs
        elif state is not None:
            if isinstance(state, CompressedMemoryState):
                state_memory_list = [
                    state.compressed_memory, state.episodic_memory
                ]
            else:
                state_memory_list = [state]

            k_inputs = tf.concat(state_memory_list + [inputs], 1)
            v_inputs = k_inputs
        else:
            k_inputs = inputs
            v_inputs = inputs

        # Batch size denoted by B
        batch_size = tf.shape(inputs)[0]
        # Chunk_size denoted by N
        chunk_size = inputs.get_shape().as_list()[1]
        # Denoted by N + M
        att_size = k_inputs.get_shape().as_list()[1]

        if self._positional_encodings and not self._use_relative_positions:
            if len(self._positional_encodings) != 1:
                raise ValueError(
                    'Absolute positional encodings only supported for 1 memory. '
                    'Found %i.' % len(self._positional_encodings))
            key_positions, query_positions = self._positional_encodings[0]
            k_inputs += key_positions
            q_inputs += query_positions

        # [B, H, L, K]
        q = self.multihead_linear(q_inputs, 'query')
        # [B, H, N + M, K]
        k = self.multihead_linear(k_inputs, 'key')
        # [B, H, N + M, V]
        v = self.multihead_linear(v_inputs, 'value')

        # Scaling the dot-product
        if self._scaling:
            q *= self._key_size**-0.5

        # [B, H, L, N + M]
        if self._use_relative_positions:
            r_w_bias = tf.get_variable('r_w_bias',
                                       [1, self._num_heads, 1, self._key_size],
                                       dtype=inputs.dtype)
            content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
            all_relative_logits = []
            # Loop over multiple positional encodings, for the case of multiple
            # memory types.
            for i, positional_encodings in enumerate(
                    self._positional_encodings):
                key_positions, query_positions = positional_encodings
                if key_positions.get_shape().as_list()[-1] != att_size:
                    key_positions = key_positions[:, -att_size:]  # Crop to layer memory size.
                is_final = i == len(self._positional_encodings) - 1
                suffix = '' if is_final else '_%d' % i
                relative_keys = self.multihead_linear(key_positions,
                                                      name='relative_keys' +
                                                      suffix)
                # [B, H, N, D]
                r_r_bias = tf.get_variable(
                    'r_r_bias' + suffix,
                    [1, self._num_heads, 1, self._key_size],
                    dtype=inputs.dtype)
                relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
                relative_logits = tf.matmul(q + r_r_bias,
                                            relative_keys,
                                            transpose_b=True)
                relative_logits = rel_shift(relative_logits)
                if not is_final:  # Include relative positions for input sequence.
                    relative_logits = relative_logits[:, :, :, :-chunk_size]
                all_relative_logits.append(relative_logits)
            all_relative_logits = tf.concat(all_relative_logits, 3)
            logits = content_logits + all_relative_logits
        else:
            # [B, H, N, N + M]
            logits = tf.matmul(q, k, transpose_b=True)
            content_logits = logits

        if self._mask is not None:
            if self._mask.get_shape().as_list()[-1] != att_size:
                mask = self._mask[:, :, :, -att_size:]
            else:
                mask = self._mask
            logits += mask

        weights = tf.nn.softmax(logits)
        if is_training:
            weights = tf.nn.dropout(weights, dropout_keep_prob)
        # [B, L, H, V], where V is value_size
        output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

        # [B, L, H, V] -> [B, L, HV]
        attended_inputs = basic.BatchReshape([query_size, embedding_size
                                              ])(output_transpose)
        # Apply final mlp to mix information between heads.
        output = basic.BatchApply(
            basic.Linear(embedding_size))(attended_inputs)

        attention_state = AttentionState(queries=q,
                                         keys=k,
                                         values=v,
                                         weights=weights,
                                         logits=content_logits,
                                         embeddings=inputs,
                                         read_words=output)
        return output, attention_state
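`rel_shift` is called above but not defined in the snippet. A common implementation of the relative-shift trick (in the style of Transformer-XL; an assumption about the helper, not its verified source) pads the last axis, reshapes, and slices so each query position lines up with its relative offsets:

import numpy as np

def rel_shift(x):
    # x: [batch, heads, query_len, key_len] relative logits.
    b, h, q, k = x.shape
    pad = np.zeros((b, h, q, 1), dtype=x.dtype)
    x = np.concatenate([pad, x], axis=-1)    # [b, h, q, k + 1]
    x = x.reshape(b, h, k + 1, q)
    return x[:, :, 1:, :].reshape(b, h, q, k)

logits = np.arange(1 * 1 * 3 * 4, dtype=float).reshape(1, 1, 3, 4)
print(rel_shift(logits)[0, 0])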
Example #13
    def _build(self, inputs, state=None, condition=None, is_training=True):
        """Calculates multi-layer self attention and mlp transformation.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, dim_size].
      state: optional tensor of shape [batch_size, memory_size, dim_size].
      condition: optional tensor to condition on. The shape is shape
        [batch_size, dim_size].
      is_training: If true, dropout is applied.

    Returns:
      output: tensor of shape [batch_size, num_steps, output_dim_size].
      state: list of length `num_layers` containing AttentionState tuples.
    """

        # inputs: [B, N, F]
        if condition is not None:
            condition_tile = tf.tile(tf.expand_dims(condition, 1),
                                     [1, tf.shape(inputs)[1], 1])
            inputs = tf.concat([inputs, condition_tile], -1)

        if state is None:
            memory_sizes = [0]
        elif isinstance(state[0], CompressedMemoryState):
            cm_mem_size = max(_memory_size(s.compressed_memory) for s in state)
            em_mem_size = max(_memory_size(s.episodic_memory) for s in state)
            memory_sizes = [cm_mem_size, em_mem_size]
        else:
            memory_sizes = [max([_memory_size(s) for s in state])]
        chunk_size = inputs.get_shape().as_list()[1]
        self._positional_encodings = []
        # Creates positional encodings for different memory types.
        for i, memory_size in enumerate(memory_sizes):
            seq_len = chunk_size + memory_size
            key_positions = get_position_encodings(
                sequence_length=seq_len,
                hidden_size=inputs.get_shape().as_list()[2],
                clamp_value=self._clamp_time_range,
            )
            if is_training:
                key_positions = tf.nn.dropout(key_positions,
                                              rate=self._dropout_rate)
            key_positions = tf.cast(key_positions, dtype=inputs.dtype)
            query_positions = key_positions[:, -chunk_size:, :]
            self._positional_encodings.append((key_positions, query_positions))

        if self._causal:
            self._mask = create_mask(inputs, state,
                                     self._same_attention_length)

        layer_i_inputs = inputs
        attention_states = []
        for i in range(self._num_layers):
            with tf.variable_scope('layer_%d' % i, reuse=tf.AUTO_REUSE):
                multihead_attention, object_mlp = self.get_sublayers(
                    is_training)
                # Multihead attention with residuals.
                state_i = None if state is None else state[i]
                attention_outputs, attention_state = multihead_attention(
                    layer_i_inputs,
                    state=state_i,
                    is_training=is_training,
                    dropout_keep_prob=1. - self._dropout_rate)
                attention_states.append(attention_state)
                # Feed-forward with residuals.
                output = object_mlp(attention_outputs,
                                    is_training=is_training,
                                    dropout_keep_prob=1 - self._dropout_rate)
                layer_i_inputs = output

        if self._output_size is not None:
            output = basic.BatchApply(
                basic.Linear(self._output_size, use_bias=False))(output)

        return output, attention_states
Example #14
    def _build(self,
               inputs,
               query_inputs=None,
               state=None,
               is_training=False,
               dropout_keep_prob=0.5):
        embedding_size = self._value_size * self._num_heads

        q_inputs = inputs if query_inputs is None else query_inputs
        # Denoted by L. If query_inputs is None, L = N.
        _, query_size = q_inputs.get_shape().as_list()[:2]

        if state is not None:
            if isinstance(state, CompressedMemoryState):
                state_memory_list = [
                    state.compressed_memory, state.episodic_memory
                ]
            else:
                state_memory_list = [state]

            k_inputs = tf.concat(state_memory_list + [inputs], 1)
            v_inputs = k_inputs
        else:
            k_inputs = inputs
            v_inputs = inputs

        # Batch size denoted by B
        batch_size = tf.shape(inputs)[0]
        # Chunk_size denoted by N
        chunk_size = inputs.get_shape().as_list()[1]
        # Denoted by N + M
        att_size = k_inputs.get_shape().as_list()[1]

        if self._positional_encodings and not self._use_relative_positions:
            key_positions, query_positions = self._positional_encodings
            k_inputs += key_positions
            q_inputs += query_positions

        # [B, H, L, K]
        q = self.multihead_linear(q_inputs, 'query')
        # [B, H, N + M, K]
        k = self.multihead_linear(k_inputs, 'key')
        # [B, H, N + M, V]
        v = self.multihead_linear(v_inputs, 'value')

        # Scaling the dot-product
        if self._scaling:
            q *= self._key_size**-0.5

        # [B, H, L, N + M]
        if self._use_relative_positions:
            r_w_bias = tf.get_variable('r_w_bias',
                                       [1, self._num_heads, 1, self._key_size],
                                       dtype=inputs.dtype)
            content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
            all_relative_logits = []
            # Loop over multiple positional encodings, for the case of multiple
            # memory types.
            for i, positional_encodings in enumerate(
                    self._positional_encodings):
                key_positions, query_positions = positional_encodings
                if key_positions.get_shape().as_list()[-1] != att_size:
                    key_positions = key_positions[:, -att_size:]  # Crop to layer memory size.
                is_final = i == len(self._positional_encodings) - 1
                suffix = '' if is_final else '_%d' % i
                relative_keys = self.multihead_linear(key_positions,
                                                      name='relative_keys' +
                                                      suffix)
                # [B, H, N, D]
                r_r_bias = tf.get_variable(
                    'r_r_bias' + suffix,
                    [1, self._num_heads, 1, self._key_size],
                    dtype=inputs.dtype)
                relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
                relative_logits = tf.matmul(q + r_r_bias,
                                            relative_keys,
                                            transpose_b=True)
                relative_logits = rel_shift(relative_logits)
                if not is_final:  # Include relative positions for input sequence.
                    relative_logits = relative_logits[:, :, :, :-chunk_size]
                all_relative_logits.append(relative_logits)
            all_relative_logits = tf.concat(all_relative_logits, 3)
            logits = content_logits + all_relative_logits
        else:
            # [B, H, N, N + M]
            logits = tf.matmul(q, k, transpose_b=True)
            content_logits = logits

        if self._mask is not None:
            if self._mask.get_shape().as_list()[-1] != att_size:
                mask = self._mask[:, :, :, -att_size:]
            else:
                mask = self._mask
            logits += mask

        weights = tf.nn.softmax(logits)
        if is_training:
            weights = tf.nn.dropout(weights, dropout_keep_prob)
        # [B, L, H, V], where V is value_size
        output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

        # [B, L, H, V] -> [B, L, HV]
        attended_inputs = basic.BatchReshape([query_size, embedding_size
                                              ])(output_transpose)
        # Apply final mlp to mix information between heads.
        output = basic.BatchApply(
            basic.Linear(embedding_size))(attended_inputs)

        attention_state = AttentionState(queries=q,
                                         keys=k,
                                         values=v,
                                         weights=weights,
                                         logits=content_logits,
                                         embeddings=inputs,
                                         read_words=output)
        return output, attention_state