Example #1
  def _feedforward_unit(self, state, arrays, network_states, batch_size,
                        during_training):
    """Constructs a single instance of a feed-forward cell.

    Given an input state and access to the arrays storing activations, this
    function encapsulates creation of a single network unit. This will *not*
    create new variables.

    Args:
      state: MasterState for the state that will be used to extract features.
      arrays: List of TensorArrays corresponding to network outputs from this
        component. These are used for recurrent link features; the arrays from
        other components are used for stack-prop style connections.
      network_states: NetworkState object containing the TensorArrays from
        *all* components.
      batch_size: int Tensor with the current batch size.
      during_training: Whether to build a unit for training (vs inference).

    Returns:
      List of tensors generated by the underlying network implementation.
    """
    with tf.variable_scope(self.name, reuse=True):
      fixed_embeddings = []
      for channel_id, feature_spec in enumerate(self.spec.fixed_feature):
        fixed_embedding = network_units.fixed_feature_lookup(
            self, state, channel_id, batch_size)
        if feature_spec.is_constant:
          fixed_embedding.tensor = tf.stop_gradient(fixed_embedding.tensor)
        fixed_embeddings.append(fixed_embedding)

      linked_embeddings = []
      for channel_id, feature_spec in enumerate(self.spec.linked_feature):
        if feature_spec.source_component == self.name:
          # Recurrent feature: pull from the local arrays.
          index = self.network.get_layer_index(feature_spec.source_layer)
          source_array = arrays[index]
          source_layer_size = self.network.layers[index].dim
          linked_embeddings.append(
              network_units.activation_lookup_recurrent(
                  self, state, channel_id, source_array, source_layer_size,
                  batch_size))
        else:
          # Stackprop style feature: pull from another component's arrays.
          source = self.master.lookup_component[feature_spec.source_component]
          source_tensor = network_states[source.name].activations[
              feature_spec.source_layer]
          source_layer_size = source.network.get_layer_size(
              feature_spec.source_layer)
          linked_embeddings.append(
              network_units.activation_lookup_other(
                  self, state, channel_id, source_tensor.dynamic_tensor,
                  source_layer_size, batch_size))

      context_tensor_arrays = []
      for context_layer in self.network.context_layers:
        index = self.network.get_layer_index(context_layer.name)
        context_tensor_arrays.append(arrays[index])

      return self.network.create(fixed_embeddings, linked_embeddings,
                                 context_tensor_arrays, during_training)
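In the fixed-feature loop above, channels marked `is_constant` are wrapped in `tf.stop_gradient`, so backprop never updates their embedding matrices. Below is a minimal, standalone sketch of that freezing pattern in TF 1.x graph mode, matching the style of the example; the embedding variable and lookup indices here are hypothetical stand-ins for a fixed channel, not part of the component.

import tensorflow as tf  # TF 1.x graph-mode API, as in the example above

# Hypothetical embedding matrix and lookup standing in for a fixed channel.
embedding = tf.get_variable('embedding_matrix', shape=[100, 32])
looked_up = tf.nn.embedding_lookup(embedding, [3, 7, 9])

# Freezing the channel: gradients stop here and never reach `embedding`.
frozen = tf.stop_gradient(looked_up)
loss = tf.reduce_sum(frozen)

# tf.gradients returns [None] for `embedding` -- no update will be applied.
print(tf.gradients(loss, [embedding]))  # => [None]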
Example #2
  def _feedforward_unit(self, state, arrays, network_states, stride,
                        during_training):
    """Constructs a single instance of a feed-forward cell.

    Given an input state and access to the arrays storing activations, this
    function encapsulates creation of a single network unit. This will *not*
    create new variables.

    Args:
      state: MasterState for the state that will be used to extract features.
      arrays: List of TensorArrays corresponding to network outputs from this
        component. These are used for recurrent link features; the arrays from
        other components are used for stack-prop style connections.
      network_states: NetworkState object containing the TensorArrays from
        *all* components.
      stride: int Tensor with the current beam * batch size.
      during_training: Whether to build a unit for training (vs inference).

    Returns:
      List of tensors generated by the underlying network implementation.
    """
    with tf.variable_scope(self.name, reuse=True):
      fixed_embeddings = []
      for channel_id, feature_spec in enumerate(self.spec.fixed_feature):
        fixed_embedding = network_units.fixed_feature_lookup(
            self, state, channel_id, stride)
        if feature_spec.is_constant:
          fixed_embedding.tensor = tf.stop_gradient(fixed_embedding.tensor)
        fixed_embeddings.append(fixed_embedding)

      linked_embeddings = []
      for channel_id, feature_spec in enumerate(self.spec.linked_feature):
        if feature_spec.source_component == self.name:
          # Recurrent feature: pull from the local arrays.
          index = self.network.get_layer_index(feature_spec.source_layer)
          source_array = arrays[index]
          source_layer_size = self.network.layers[index].dim
          linked_embeddings.append(
              network_units.activation_lookup_recurrent(
                  self, state, channel_id, source_array, source_layer_size,
                  stride))
        else:
          # Stackprop style feature: pull from another component's arrays.
          source = self.master.lookup_component[feature_spec.source_component]
          source_tensor = network_states[source.name].activations[
              feature_spec.source_layer]
          source_layer_size = source.network.get_layer_size(
              feature_spec.source_layer)
          linked_embeddings.append(
              network_units.activation_lookup_other(
                  self, state, channel_id, source_tensor.dynamic_tensor,
                  source_layer_size))

      context_tensor_arrays = []
      for context_layer in self.network.context_layers:
        index = self.network.get_layer_index(context_layer.name)
        context_tensor_arrays.append(arrays[index])

      if self.spec.attention_component:
        logging.info('%s component has attention over %s', self.name,
                     self.spec.attention_component)
        source = self.master.lookup_component[self.spec.attention_component]
        network_state = network_states[self.spec.attention_component]
        with tf.control_dependencies(
            [tf.assert_equal(state.current_batch_size, 1)]):
          attention_tensor = tf.identity(
              network_state.activations['layer_0'].bulk_tensor)
      else:
        attention_tensor = None

      return self.network.create(fixed_embeddings, linked_embeddings,
                                 context_tensor_arrays, attention_tensor,
                                 during_training)
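Relative to Example #1, this version takes `stride` (beam * batch size) rather than `batch_size` and additionally wires an optional attention tensor into the network. The attention branch only supports a batch size of 1: the `tf.control_dependencies` / `tf.assert_equal` pair ties that runtime check to the `tf.identity` read, so fetching the attention tensor fails fast on larger batches. A minimal sketch of this gating pattern (TF 1.x; the placeholder names are illustrative, not part of the component):

import tensorflow as tf  # TF 1.x graph-mode API, as in the example above

batch_size = tf.placeholder(tf.int32, shape=[], name='current_batch_size')
activations = tf.placeholder(tf.float32, shape=[None, 64], name='layer_0')

# The assertion only runs when a dependent op runs; tf.identity creates that
# dependency, so reading `gated` with batch_size != 1 raises at session.run.
with tf.control_dependencies([tf.assert_equal(batch_size, 1)]):
  gated = tf.identity(activations)

with tf.Session() as sess:
  ok = sess.run(gated, {batch_size: 1, activations: [[0.0] * 64]})  # passes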