Example #1
    def _build_cell(self,
                    mode,
                    batch_size,
                    initial_state=None,
                    memory=None,
                    memory_sequence_length=None,
                    dtype=None,
                    alignment_history=False):
        # Build one attention mechanism per configured mechanism class, all
        # attending over the same memory.
        attention_mechanisms = [
            _build_attention_mechanism(
                attention_mechanism,
                self.num_units,
                memory,
                memory_sequence_length=memory_sequence_length)
            for attention_mechanism in self.attention_mechanism_class
        ]

        cell = build_cell(self.num_layers,
                          self.num_units,
                          mode,
                          dropout=self.dropout,
                          residual_connections=self.residual_connections,
                          cell_class=self.cell_class,
                          attention_layers=self.attention_layers,
                          attention_mechanisms=attention_mechanisms)

        initial_state = cell.zero_state(batch_size, memory.dtype)

        return cell, initial_state
Example #2
    def _build_cell(self,
                    mode,
                    batch_size,
                    initial_state=None,
                    memory=None,
                    memory_sequence_length=None,
                    dtype=None,
                    alignment_history=False):
        # These arguments are not used by this decoder variant.
        _ = memory_sequence_length
        _ = alignment_history

        if memory is None and dtype is None:
            raise ValueError(
                "dtype argument is required when memory is not set")

        cell = build_cell(self.num_layers,
                          self.num_units,
                          mode,
                          dropout=self.dropout,
                          residual_connections=self.residual_connections,
                          cell_class=self.cell_class)

        initial_state = self._init_state(
            cell.zero_state(batch_size, dtype or memory.dtype),
            initial_state=initial_state)

        return cell, initial_state
Example #3
    def make_inputs(self, features, training=None):
        inputs = features["char_ids"]
        # Collapse the leading dimensions so each character sequence is fed to
        # the RNN independently.
        flat_inputs = tf.reshape(inputs, [-1, tf.shape(inputs)[-1]])
        embeddings = self._embed(flat_inputs, training)
        # Recover each sequence's length by counting its non-zero (non-padding)
        # char ids.
        sequence_length = tf.count_nonzero(flat_inputs, axis=1)

        cell = build_cell(1,
                          self.num_units,
                          tf.estimator.ModeKeys.TRAIN if training else None,
                          dropout=self.dropout,
                          cell_class=self.cell_class)
        rnn_outputs, rnn_state = tf.nn.dynamic_rnn(
            cell,
            embeddings,
            sequence_length=sequence_length,
            dtype=embeddings.dtype)

        if self.encoding == "average":
            encoding = tf.reduce_mean(rnn_outputs, axis=1)
        elif self.encoding == "last":
            encoding = last_encoding_from_state(rnn_state)

        outputs = tf.reshape(encoding,
                             [-1, tf.shape(inputs)[1], self.num_units])
        return outputs
Example #4
 def _build_cell(self, mode):
     return build_cell(self.num_layers,
                       self.num_units,
                       mode,
                       dropout=self.dropout,
                       residual_connections=self.residual_connections,
                       cell_class=self.cell_class)
Example #5
  def _build_cell(self,
                  mode,
                  batch_size,
                  initial_state=None,
                  memory=None,
                  memory_sequence_length=None,
                  dtype=None):
    attention_mechanisms = [
        _build_attention_mechanism(
            attention_mechanism,
            self.num_units,
            memory,
            memory_sequence_length=memory_sequence_length)
        for attention_mechanism in self.attention_mechanism_class]

    cell = build_cell(
        self.num_layers,
        self.num_units,
        mode,
        dropout=self.dropout,
        residual_connections=self.residual_connections,
        cell_class=self.cell_class,
        attention_layers=self.attention_layers,
        attention_mechanisms=attention_mechanisms)

    initial_state = cell.zero_state(batch_size, memory.dtype)

    return cell, initial_state
Example #6
 def _build_cell(self, mode):
   return build_cell(
       self.num_layers,
       self.num_units,
       mode,
       dropout=self.dropout,
       residual_connections=self.residual_connections,
       cell_class=self.cell_class)
Example #7
    def _build_cell(self,
                    mode,
                    batch_size,
                    initial_state=None,
                    memory=None,
                    memory_sequence_length=None):
        _ = memory
        _ = memory_sequence_length

        cell = build_cell(self.num_layers,
                          self.num_units,
                          mode,
                          dropout=self.dropout,
                          residual_connections=self.residual_connections,
                          cell_class=self.cell_class)

        initial_state = self._init_state(
            cell.zero_state(batch_size, tf.float32),
            initial_state=initial_state)

        return cell, initial_state
Example #8
    def transform(self, inputs, mode):
        flat_inputs = tf.reshape(inputs, [-1, tf.shape(inputs)[-1]])
        embeddings = self._embed(flat_inputs, mode)
        sequence_length = tf.count_nonzero(flat_inputs, axis=1)

        cell = build_cell(1,
                          self.num_units,
                          mode,
                          dropout=self.dropout,
                          cell_class=self.cell_class)
        rnn_outputs, rnn_state = tf.nn.dynamic_rnn(
            cell,
            embeddings,
            sequence_length=sequence_length,
            dtype=embeddings.dtype)

        if self.encoding == "average":
            encoding = tf.reduce_mean(rnn_outputs, axis=1)
        elif self.encoding == "last":
            encoding = last_encoding_from_state(rnn_state)

        outputs = tf.reshape(encoding,
                             [-1, tf.shape(inputs)[1], self.num_units])
        return outputs
Example #9
  def _build_cell(self,
                  mode,
                  batch_size,
                  initial_state=None,
                  memory=None,
                  memory_sequence_length=None,
                  dtype=None):
    _ = memory_sequence_length

    if memory is None and dtype is None:
      raise ValueError("dtype argument is required when memory is not set")

    cell = build_cell(
        self.num_layers,
        self.num_units,
        mode,
        dropout=self.dropout,
        residual_connections=self.residual_connections,
        cell_class=self.cell_class)

    initial_state = self._init_state(
        cell.zero_state(batch_size, dtype or memory.dtype), initial_state=initial_state)

    return cell, initial_state
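
All of the snippets above call the same build_cell helper. The minimal sketch below shows that call pattern in isolation; it assumes the function is OpenNMT-tf 1.x's opennmt.utils.cell.build_cell, that tf.contrib.rnn.LSTMCell is a valid cell_class, and that the TensorFlow 1.x APIs seen in the examples (tf.nn.dynamic_rnn, tf.estimator.ModeKeys) are in use. The hyperparameter values are illustrative only, not taken from any of the examples.

import tensorflow as tf
from opennmt.utils.cell import build_cell  # assumed import path (OpenNMT-tf 1.x)

# Build a 2-layer, 256-unit recurrent stack; dropout, residual connections and
# the concrete cell class are configured through keyword arguments, exactly as
# in the examples above.
cell = build_cell(2,
                  256,
                  tf.estimator.ModeKeys.TRAIN,
                  dropout=0.3,
                  residual_connections=True,
                  cell_class=tf.contrib.rnn.LSTMCell)

# The returned cell behaves like any RNNCell, so it can be unrolled with
# tf.nn.dynamic_rnn as in Examples #3 and #8.
inputs = tf.random_normal([8, 20, 256])
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype)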