    def _symbols_to_logits_fn(self,
                              embedding,
                              vocab_size,
                              mode,
                              output_layer=None,
                              dtype=None):
        embedding_fn = get_embedding_fn(embedding)
        if output_layer is None:
            output_layer = build_output_layer(self.num_units,
                                              vocab_size,
                                              dtype=dtype)

        def _impl(ids, step, cache):
            # Embed only the newest token; earlier steps are read from the cache.
            inputs = embedding_fn(ids[:, -1:])
            # Scale embeddings by sqrt(num_units), as in the Transformer.
            inputs *= self.num_units**0.5
            inputs = self.position_encoder.apply_one(inputs, step + 1)
            outputs = self._self_attention_stack(inputs,
                                                 mode=mode,
                                                 cache=cache,
                                                 memory=cache["memory"],
                                                 memory_sequence_length=None)
            outputs = outputs[:, -1:, :]
            logits = output_layer(outputs)
            return logits, cache

        return _impl
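
The returned `_impl` closure follows the usual step-by-step decoding contract: it receives the ids decoded so far, the current step, and a cache of per-layer state, and returns logits for the next token plus the updated cache. Below is a minimal, framework-free sketch of that contract driven by a greedy loop; `toy_symbols_to_logits`, `greedy_decode`, `VOCAB_SIZE`, and `EOS_ID` are hypothetical stand-ins for illustration, not part of the code above.

import numpy as np

VOCAB_SIZE = 8
EOS_ID = 1

def toy_symbols_to_logits(ids, step, cache):
    # Hypothetical stand-in for `_impl`: look only at the last decoded token,
    # update some running state in the cache, and return logits of shape
    # [batch, 1, vocab_size] together with the updated cache.
    cache["state"] = cache["state"] + ids[:, -1]
    logits = np.random.RandomState(step).randn(ids.shape[0], 1, VOCAB_SIZE)
    return logits, cache

def greedy_decode(start_ids, max_steps=5):
    ids = start_ids[:, None]                         # [batch, 1]
    cache = {"state": np.zeros(len(start_ids))}
    for step in range(max_steps):
        logits, cache = toy_symbols_to_logits(ids, step, cache)
        next_ids = logits[:, -1, :].argmax(axis=-1)  # greedy pick per batch entry
        ids = np.concatenate([ids, next_ids[:, None]], axis=1)
        if (next_ids == EOS_ID).all():
            break
    return ids

print(greedy_decode(np.array([0, 0])))               # [batch, <= max_steps + 1]
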
Example 2
  def _symbols_to_logits_fn(self, embedding, vocab_size, mode, output_layer=None, dtype=None):
    embedding_fn = get_embedding_fn(embedding)
    if self.share_embedding:
      # Tie the output projection to the shared input embedding weights.
      w_embs = reuse_variable("w_embs")
      output_layer = build_linear_shared_weights(
          vocab_size, w_embs, scope="proj_to_vocab_size")
    elif output_layer is None:
      output_layer = build_linear_weight_norm(self.out_embedding_dim, vocab_size,
                                              dropout=self.dropout,
                                              dtype=dtype,
                                              scope="proj_to_vocab_size")

    def _impl(ids, step, cache):
      inputs = embedding_fn(ids[:, -1:])
      if self.position_encoder is not None:
        inputs = self.position_encoder.apply_one(inputs, step + 1)
      outputs = self._cnn_stack(
          inputs,
          memory=cache["memory"],
          mode=mode,
          cache=cache)
      outputs = outputs[:, -1:, :]
      logits = output_layer(outputs)
      return logits, cache

    return _impl
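
The `share_embedding` branch above ties the output projection to the input embedding matrix, so logits come from multiplying decoder outputs with the transposed embedding weights. A minimal NumPy sketch of that weight tying follows; the shapes and the helper names `embed` and `project_to_vocab` are made up, and `build_linear_shared_weights` itself is not reproduced here.

import numpy as np

vocab_size, depth = 100, 16
w_embs = np.random.randn(vocab_size, depth)        # shared embedding matrix

def embed(ids):
    return w_embs[ids]                             # [batch, time, depth]

def project_to_vocab(outputs):
    # Tied output layer: reuse the embedding weights, transposed, instead of
    # training a separate [depth, vocab_size] projection.
    return outputs @ w_embs.T                      # [batch, time, vocab_size]

outputs = np.random.randn(2, 1, depth)             # one decoded timestep
print(project_to_vocab(outputs).shape)             # (2, 1, 100)
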
Example 3
    def dynamic_decode_and_search(self,
                                  embedding,
                                  start_tokens,
                                  end_token,
                                  vocab_size,
                                  initial_state=None,
                                  beam_width=5,
                                  length_penalty=0.0,
                                  maximum_iterations=250,
                                  mode=tf.estimator.ModeKeys.PREDICT,
                                  memory=None,
                                  memory_sequence_length=None):
        # Replicate the encoder outputs and state once per beam hypothesis.
        if initial_state is not None:
            initial_state = tf.contrib.seq2seq.tile_batch(
                initial_state, multiplier=beam_width)
        if memory is not None:
            memory = tf.contrib.seq2seq.tile_batch(memory,
                                                   multiplier=beam_width)
        if memory_sequence_length is not None:
            memory_sequence_length = tf.contrib.seq2seq.tile_batch(
                memory_sequence_length, multiplier=beam_width)

        embedding_fn = get_embedding_fn(embedding)

        def _symbols_to_logits_fn(symbols):
            # No decoding cache here: the full prefix of decoded symbols is
            # re-run through the self-attention stack at every step.
            batch_size = tf.shape(symbols)[0]
            step = tf.shape(symbols)[1]
            sequence_length = tf.fill([batch_size], step)
            outputs = self._self_attention_stack(
                embedding_fn(symbols),
                sequence_length,
                mode=mode,
                memory=memory,
                memory_sequence_length=memory_sequence_length)

            # Only sample the last timestep.
            last_output = tf.slice(outputs, [0, step - 1, 0], [-1, 1, -1])
            logits = tf.layers.dense(last_output, vocab_size)
            return logits

        outputs, log_probs = beam_search(_symbols_to_logits_fn,
                                         start_tokens,
                                         beam_width,
                                         maximum_iterations,
                                         vocab_size,
                                         length_penalty,
                                         eos_id=end_token)
        outputs = tf.slice(outputs, [0, 0, 1], [-1, -1, -1])  # Ignore <s>.

        # Recover lengths by counting non-padding (non-zero) ids.
        lengths = tf.not_equal(outputs, 0)
        lengths = tf.cast(lengths, tf.int32)
        lengths = tf.reduce_sum(lengths, axis=-1)

        return (outputs, None, lengths, log_probs)
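
Two details of this beam search wrapper are easy to miss: `tile_batch` repeats every batch entry `beam_width` times so the encoder memory lines up with the flattened `[batch * beam]` decoding state, and the final lengths are recovered by counting non-padding ids in the padded output. Below is a rough NumPy equivalent of both steps, assuming 0 is the padding id as in the code above; the `tile_batch` defined here is only a sketch of what `tf.contrib.seq2seq.tile_batch` does.

import numpy as np

def tile_batch(t, multiplier):
    # Repeat each batch entry `multiplier` times along axis 0:
    # [a, b] -> [a, a, a, b, b, b] for multiplier=3.
    return np.repeat(t, multiplier, axis=0)

memory = np.arange(6).reshape(2, 3)                # [batch=2, time=3]
print(tile_batch(memory, 3).shape)                 # (6, 3)

outputs = np.array([[[4, 7, 2, 0, 0],              # [batch, beam, max_time]
                     [4, 9, 0, 0, 0]]])
lengths = (outputs != 0).astype(np.int32).sum(axis=-1)
print(lengths)                                     # [[3 2]]
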
Example 4
    def _symbols_to_logits_fn(self, embedding, vocab_size, mode):
        embedding_fn = get_embedding_fn(embedding)

        def _impl(ids, step, cache):
            inputs = embedding_fn(ids[:, -1:])
            inputs = self.position_encoder.apply_one(inputs, step + 1)
            outputs = self._self_attention_stack(
                inputs,
                mode=mode,
                cache=cache,
                memory=cache["memory"],
                memory_sequence_length=cache["memory_sequence_length"])
            outputs = outputs[:, -1:, :]
            logits = tf.layers.dense(outputs, vocab_size)
            return logits, cache

        return _impl
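
`position_encoder.apply_one(inputs, step + 1)` injects the position of the single token being decoded, instead of encoding a whole sequence at once. Assuming a Transformer-style sinusoidal encoder (the actual encoder class is not shown here), a one-position sketch looks like this; `sinusoid` and this `apply_one` are illustrative only.

import numpy as np

def sinusoid(position, depth):
    # Sinusoidal encoding of a single (1-based) position; assumes an even depth.
    i = np.arange(depth // 2)
    angles = position / np.power(10000.0, 2.0 * i / depth)
    enc = np.zeros(depth)
    enc[0::2] = np.sin(angles)
    enc[1::2] = np.cos(angles)
    return enc

def apply_one(inputs, position):
    # inputs: [batch, 1, depth] embeddings of the token decoded at `position`.
    return inputs + sinusoid(position, inputs.shape[-1])

step = 4
inputs = np.zeros((2, 1, 8))
print(apply_one(inputs, step + 1)[0, 0])
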
Example 5
  def _symbols_to_logits_fn(self, embedding, vocab_size, mode):
    embedding_fn = get_embedding_fn(embedding)

    def _impl(ids, step, cache):
      inputs = embedding_fn(ids[:, -1:])
      inputs *= self.num_units**0.5
      inputs = self.position_encoder.apply_one(inputs, step + 1)
      outputs = self._self_attention_stack(
          inputs,
          mode=mode,
          cache=cache,
          memory=cache["memory"],
          memory_sequence_length=None)
      outputs = outputs[:, -1:, :]
      logits = tf.layers.dense(outputs, vocab_size)
      return logits, cache

    return _impl
Example 6
    def dynamic_decode(self,
                       embedding,
                       start_tokens,
                       end_token,
                       vocab_size,
                       initial_state=None,
                       maximum_iterations=250,
                       mode=tf.estimator.ModeKeys.PREDICT,
                       memory=None,
                       memory_sequence_length=None):
        batch_size = tf.shape(start_tokens)[0]
        finished = tf.tile([False], [batch_size])
        step = tf.constant(0)
        inputs = tf.expand_dims(start_tokens, 1)
        lengths = tf.zeros([batch_size], dtype=tf.int32)
        log_probs = tf.zeros([batch_size])

        embedding_fn = get_embedding_fn(embedding)

        def _condition(unused_step, finished, unused_inputs, unused_lengths,
                       unused_log_probs):
            return tf.logical_not(tf.reduce_all(finished))

        def _body(step, finished, inputs, lengths, log_probs):
            # Sequences that already finished do not grow any further.
            inputs_lengths = tf.add(lengths, 1 - tf.cast(finished, tf.int32))

            # Decode inputs.
            outputs = self._self_attention_stack(
                embedding_fn(inputs),
                inputs_lengths,
                mode=mode,
                memory=memory,
                memory_sequence_length=memory_sequence_length)

            # Only sample the last timestep.
            last_output = tf.slice(outputs, [0, step, 0], [-1, 1, -1])
            logits = tf.layers.dense(last_output, vocab_size)
            probs = tf.nn.log_softmax(logits)
            sample_ids = tf.argmax(probs, axis=-1)

            # Accumulate log probabilities.
            sample_probs = tf.reduce_max(probs, axis=-1)
            masked_probs = (tf.squeeze(sample_probs, -1)
                            * (1.0 - tf.cast(finished, tf.float32)))
            log_probs = tf.add(log_probs, masked_probs)

            next_inputs = tf.concat(
                [inputs, tf.cast(sample_ids, tf.int32)], -1)
            next_lengths = inputs_lengths
            next_finished = tf.logical_or(
                finished,
                tf.equal(tf.squeeze(sample_ids, axis=[1]), end_token))
            step = step + 1

            if maximum_iterations is not None:
                next_finished = tf.logical_or(next_finished,
                                              step >= maximum_iterations)

            return step, next_finished, next_inputs, next_lengths, log_probs

        step, _, outputs, lengths, log_probs = tf.while_loop(
            _condition,
            _body,
            loop_vars=(step, finished, inputs, lengths, log_probs),
            shape_invariants=(tf.TensorShape([]),
                              finished.get_shape(),
                              tf.TensorShape([None, None]),
                              lengths.get_shape(),
                              log_probs.get_shape()),
            parallel_iterations=1)

        outputs = tf.slice(outputs, [0, 1], [-1, -1])  # Ignore <s>.

        # Make shape consistent with beam search.
        outputs = tf.expand_dims(outputs, 1)
        lengths = tf.expand_dims(lengths, 1)
        log_probs = tf.expand_dims(log_probs, 1)

        return (outputs, None, lengths, log_probs)
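
The bookkeeping inside `_body` is the part worth studying: lengths and log-probabilities only change for sequences that are not finished, and a sequence becomes finished once it samples `end_token`. Below is a small NumPy re-creation of that masking logic; the function name `update_state` and the toy inputs are made up, while the TensorFlow version above does the same thing with `tf.cast` and `tf.logical_or`.

import numpy as np

def update_state(log_probs, lengths, finished, step_log_probs, sample_ids, end_token):
    not_finished = ~finished
    lengths = lengths + not_finished.astype(np.int32)        # grow active sequences
    log_probs = log_probs + step_log_probs * not_finished    # mask finished ones
    finished = finished | (sample_ids == end_token)          # mark newly finished
    return log_probs, lengths, finished

log_probs = np.zeros(2)
lengths = np.zeros(2, dtype=np.int32)
finished = np.array([False, True])

log_probs, lengths, finished = update_state(
    log_probs, lengths, finished,
    step_log_probs=np.array([-0.1, -0.5]),
    sample_ids=np.array([7, 3]),
    end_token=2)
print(log_probs, lengths, finished)   # only the first (unfinished) entry changed
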