Example No. 1
def create_attention_mask_from_input_mask(from_tensor, to_mask):
    """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
    from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]

    to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
    to_seq_length = to_shape[1]

    to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
                      dtype=from_tensor.dtype)

    # We don't assume that `from_tensor` is a mask (although it could be). We
    # don't actually care if we attend *from* padding tokens (only *to* padding
    # tokens), so we create a tensor of all ones.
    #
    # `broadcast_ones` = [batch_size, from_seq_length, 1]
    broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1],
                             dtype=from_tensor.dtype)

    # Here we broadcast along two dimensions to create the mask.
    mask = broadcast_ones * to_mask

    return mask
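
Below is a minimal, self-contained sketch of the broadcasting trick this helper relies on, using plain TensorFlow ops and hypothetical toy shapes (batch_size=2, from_seq_length=4, to_seq_length=3):

import tensorflow as tf

# [batch, to_seq] padding mask expanded to a [batch, from_seq, to_seq] attention mask.
to_mask = tf.constant([[1, 1, 0], [1, 0, 0]])                    # [batch, to_seq]
to_mask = tf.cast(tf.reshape(to_mask, [2, 1, 3]), tf.float32)    # [batch, 1, to_seq]
broadcast_ones = tf.ones([2, 4, 1], dtype=tf.float32)            # [batch, from_seq, 1]
mask = broadcast_ones * to_mask                                  # [batch, from_seq, to_seq]
print(mask.shape)  # (2, 4, 3)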
Example No. 2
    def call(self, inputs):
        vocab_dists = inputs[0]
        attn_dists = inputs[1]
        p_gens = inputs[2]
        input_ids = inputs[3]
        max_oov_size = self.max_oov_size

        vocab_dists = [
            p_gen * dist for (p_gen, dist) in zip(p_gens, vocab_dists)
        ]
        attn_dists = [(1 - p_gen) * dist
                      for (p_gen, dist) in zip(p_gens, attn_dists)]

        batch_size = tf_utils.get_shape_list(vocab_dists[0])[0]

        # Concatenate some zeros to each vocabulary dist, to hold the probabilities for in-article OOV words
        extended_vsize = self.vocab_size + max_oov_size  # the maximum (over the batch) size of the extended vocabulary
        extra_zeros = tf.zeros((batch_size, max_oov_size))
        vocab_dists_extended = [
            tf.concat(axis=1, values=[dist, extra_zeros])
            for dist in vocab_dists
        ]  # list length max_dec_steps of shape (batch_size, extended_vsize)

        # Project the values in the attention distributions onto the appropriate entries in the final distributions
        # This means that if a_i = 0.1 and the ith encoder word is w, and w has index 500 in the vocabulary, then we add 0.1 onto the 500th entry of the final distribution
        # This is done for each decoder timestep.
        # This is fiddly; we use tf.scatter_nd to do the projection
        batch_nums = tf.range(0, limit=batch_size)  # shape (batch_size)
        batch_nums = tf.expand_dims(batch_nums, 1)  # shape (batch_size, 1)
        attn_len = tf_utils.get_shape_list(input_ids)[
            1]  # number of states we attend over
        batch_nums = tf.tile(batch_nums,
                             [1, attn_len])  # shape (batch_size, attn_len)
        indices = tf.stack((batch_nums, input_ids),
                           axis=2)  # shape (batch_size, enc_t, 2)
        shape = [batch_size, extended_vsize]

        attn_dists_projected = [
            tf.scatter_nd(indices, copy_dist, shape)
            for copy_dist in attn_dists
        ]  # list length max_dec_steps (batch_size, extended_vsize)

        # Add the vocab distributions and the copy distributions together to get the final distributions
        # final_dists is a list length max_dec_steps; each entry is a tensor shape (batch_size, extended_vsize) giving the final distribution for that decoder timestep
        # Note that for decoder timesteps and examples corresponding to a [PAD] token, this is junk - ignore.
        final_dists = [
            vocab_dist + copy_dist
            for (vocab_dist,
                 copy_dist) in zip(vocab_dists_extended, attn_dists_projected)
        ]

        return final_dists
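
A standalone sketch of the tf.scatter_nd projection used above, with hypothetical toy sizes (batch_size=2, attn_len=3, extended_vsize=6); note that attention weights for repeated token ids are summed:

import tensorflow as tf

input_ids = tf.constant([[4, 1, 5], [0, 2, 2]])                  # encoder token ids
copy_dist = tf.constant([[0.5, 0.3, 0.2], [0.1, 0.4, 0.5]])      # attention weights per position
batch_nums = tf.tile(tf.expand_dims(tf.range(2), 1), [1, 3])     # (batch_size, attn_len)
indices = tf.stack((batch_nums, input_ids), axis=2)              # (batch_size, attn_len, 2)
projected = tf.scatter_nd(indices, copy_dist, [2, 6])            # (batch_size, extended_vsize)
print(projected.numpy())  # weights for the duplicated id 2 are summed to 0.9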
Example No. 3
    def call(self, inputs):
        """Implements call() for the layer."""
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        word_embeddings = unpacked_inputs[0]
        token_type_ids = unpacked_inputs[1]
        input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
        batch_size = input_shape[0]
        seq_length = input_shape[1]
        width = input_shape[2]

        output = word_embeddings
        if self.use_type_embeddings:
            flat_token_type_ids = tf.reshape(token_type_ids, [-1])
            one_hot_ids = tf.one_hot(flat_token_type_ids,
                                     depth=self.token_type_vocab_size,
                                     dtype=self.dtype)
            token_type_embeddings = tf.matmul(one_hot_ids,
                                              self.type_embeddings)
            token_type_embeddings = tf.reshape(token_type_embeddings,
                                               [batch_size, seq_length, width])
            output += token_type_embeddings

        if self.use_position_embeddings:
            position_embeddings = tf.expand_dims(tf.slice(
                self.position_embeddings, [0, 0], [seq_length, width]),
                                                 axis=0)

            output += position_embeddings

        output = self.output_layer_norm(output)
        output = self.output_dropout(output)

        return output
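
The one-hot matmul used for the token-type embeddings above is equivalent to a plain gather; a toy sketch with hypothetical sizes (token_type_vocab_size=2, width=4):

import tensorflow as tf

type_embeddings = tf.random.normal([2, 4])
flat_token_type_ids = tf.constant([0, 0, 1, 1])
one_hot_ids = tf.one_hot(flat_token_type_ids, depth=2, dtype=tf.float32)
via_matmul = tf.matmul(one_hot_ids, type_embeddings)             # one-hot rows select embedding rows
via_gather = tf.gather(type_embeddings, flat_token_type_ids)
print(bool(tf.reduce_all(tf.abs(via_matmul - via_gather) < 1e-6)))  # True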
Example No. 4
    def call(self, inputs):
        """Implements call() for the layer."""
        input_shape = tf_utils.get_shape_list(inputs)
        flat_input = tf.reshape(inputs, [-1])
        output = tf.gather(self.embeddings, flat_input)
        output = tf.reshape(output, input_shape + [self.embedding_size])
        return output
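
A quick sketch (toy numbers, hypothetical vocab_size=10 and embedding_size=8) showing that the final reshape restores the original input shape with the embedding dimension appended, regardless of the input's rank:

import tensorflow as tf

embeddings = tf.random.normal([10, 8])
for inputs in (tf.constant([1, 2, 3]), tf.constant([[1, 2, 3], [4, 5, 6]])):
    input_shape = inputs.shape.as_list()
    output = tf.reshape(tf.gather(embeddings, tf.reshape(inputs, [-1])),
                        input_shape + [8])
    print(output.shape)  # (3, 8) then (2, 3, 8)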
Example No. 5
def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions.

  Args:
      sequence_tensor: Sequence output of `BertModel` layer of shape
        (`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
        hidden units of `BertModel` layer.
      positions: Position ids of tokens in the sequence to mask for pretraining,
        with dimension (batch_size, max_predictions_per_seq) where
        `max_predictions_per_seq` is the maximum number of tokens to mask out
        and predict per sequence.

  Returns:
      Masked out sequence tensor of shape (batch_size * max_predictions_per_seq,
      num_hidden).
  """
    sequence_shape = tf_utils.get_shape_list(sequence_tensor,
                                             name='sequence_output_tensor')
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    flat_offsets = tf.keras.backend.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.keras.backend.reshape(
        sequence_tensor, [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

    return output_tensor
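
A standalone sketch of the flat-offset trick above, with hypothetical toy sizes (batch_size=2, seq_length=4, width=3, two positions per example):

import tensorflow as tf

sequence_tensor = tf.reshape(tf.range(24, dtype=tf.float32), [2, 4, 3])
positions = tf.constant([[0, 2], [1, 3]])                                 # positions to gather per example
flat_offsets = tf.reshape(tf.range(0, 2, dtype=tf.int32) * 4, [-1, 1])    # [[0], [4]]
flat_positions = tf.reshape(positions + flat_offsets, [-1])               # [0, 2, 5, 7]
flat_sequence = tf.reshape(sequence_tensor, [8, 3])
print(tf.gather(flat_sequence, flat_positions).shape)                     # (4, 3)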
Example No. 6
    def call(self, inputs):

        enc_states = inputs[0]
        answer_ids = inputs[1]
        dec_mask = inputs[2]
        batch_size = tf_utils.get_shape_list(answer_ids)[0]
        emb_dec_inputs = [
            self.embedding_lookup(x) for x in tf.unstack(answer_ids, axis=1)
        ]  # list length max_dec_steps containing shape (batch_size, emb_size)

        dec_initial_state = tf.zeros([batch_size, self.config.hidden_size])

        #prev_coverage = self.prev_coverage if self.config.mode == "decode" and self.config.use_coverage else None

        decoder_outputs, dec_out_state, attn_dists, p_gens, coverage = self.decoder(
            emb_dec_inputs, dec_initial_state, enc_states, dec_mask)

        vocab_dists = self.output_projector(decoder_outputs)

        if self.config.use_pointer_gen:
            final_dists = self.final_distribution(vocab_dists, attn_dists,
                                                  p_gens)
        else:  # final distribution is just vocabulary distribution
            final_dists = vocab_dists

        return final_dists, attn_dists
Example No. 7
    def call(self, inputs):
        input_ids, input_mask, answer_ids, answer_mask = inputs
        emb_enc_inputs = self.embedding_lookup(input_ids)
        # tensor with shape (batch_size, max_seq_length, emb_size)

        emb_dec_inputs = [
            self.embedding_lookup(x) for x in tf.unstack(answer_ids, axis=1)
        ]
        # list length max_dec_steps containing shape (batch_size, emb_size)
        if not self.training:
            emb_dec_inputs = emb_dec_inputs[:1]
            # we only have the [START] token

        enc_outputs, enc_state = self.encoder(emb_enc_inputs, input_mask)

        _enc_states = enc_outputs

        _dec_in_state = enc_state

        atten_len = tf_utils.get_shape_list(input_ids)[1]
        batch_size = tf_utils.get_shape_list(input_ids)[0]

        prev_coverage = None  # self.prev_coverage #if self.config.mode == "decode" and self.config.use_coverage  else None

        if self.training:
            decoder_outputs, _dec_out_state, attn_dists, p_gens, coverage = self.decoder(
                emb_dec_inputs,
                _dec_in_state,
                _enc_states,
                input_mask,
                prev_coverage=prev_coverage)
            # if mode == "decoder":
            #     return (decoder_outputs, _dec_out_state, attn_dists, p_gens, coverage)

            vocab_dists = self.output_projector(decoder_outputs)

            if self.config.use_pointer_gen:
                final_dists = self.final_distribution(vocab_dists, attn_dists,
                                                      p_gens, input_ids)
            else:  # final distribution is just vocabulary distribution
                final_dists = vocab_dists

            def _mask_and_avg(values, padding_mask):
                """Applies mask to values then returns overall average (a scalar)

                Args:
                  values: a list length max_dec_steps containing arrays shape (batch_size).
                  padding_mask: tensor shape (batch_size, max_dec_steps) containing 1s and 0s.

                Returns:
                  a scalar
                """
                padding_mask = tf.cast(padding_mask, tf.dtypes.float32)
                dec_lens = (tf.reduce_sum(padding_mask,
                                          axis=1))  # shape batch_size. float32
                values_per_step = [
                    v * padding_mask[:, dec_step]
                    for dec_step, v in enumerate(values)
                ]
                values_per_ex = sum(
                    values_per_step
                ) / dec_lens  # shape (batch_size); normalized value for each batch member
                return tf.reduce_mean(values_per_ex)  # overall average

            def _coverage_loss(attn_dists, padding_mask):
                """Calculates the coverage loss from the attention distributions.

                Args:
                  attn_dists: The attention distributions for each decoder timestep. A list length max_dec_steps containing shape (batch_size, attn_length)
                  padding_mask: shape (batch_size, max_dec_steps).

                Returns:
                  coverage_loss: scalar
                """
                coverage = tf.zeros_like(
                    attn_dists[0]
                )  # shape (batch_size, attn_length). Initial coverage is zero.
                covlosses = [
                ]  # Coverage loss per decoder timestep. Will be list length max_dec_steps containing shape (batch_size).
                for a in attn_dists:
                    covloss = tf.reduce_sum(
                        tf.minimum(a, coverage),
                        [1])  # calculate the coverage loss for this step
                    covlosses.append(covloss)
                    coverage += a  # update the coverage vector
                coverage_loss = _mask_and_avg(covlosses, padding_mask)
                return coverage_loss

            # add the coverage loss
            self.add_loss(_coverage_loss(attn_dists, answer_mask))

            # the main loss will be based on predictions, we will let the training loop handle it.

            return final_dists
        else:

            def sort_hyps(hyps):
                """Return a list of Hypothesis objects, sorted by descending average log probability"""
                return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True)

            # we do a beam search on step by step decoding
            UNKNOWN_TOKEN = 0
            START_TOKEN = 104
            STOP_TOKEN = 105

            hyps = [
                Hypothesis(
                    tokens=[START_TOKEN],  # answer_ids[0] is the [START]
                    log_probs=[0.0],
                    state=_dec_in_state,
                    attn_dists=[],
                    p_gens=[],
                    coverage=tf.zeros(
                        atten_len)  # zero vector of length attention_length
                ) for _ in range(batch_size)
            ]

            results = [
            ]  # this will contain finished hypotheses (those that have emitted the [STOP] token)

            steps = 0
            while steps < self.config.max_dec_steps and len(
                    results) < self.config.beam_size:
                latest_tokens = [h.latest_token for h in hyps
                                 ]  # latest token produced by each hypothesis
                latest_tokens = [
                    t if t < self.config.vocab_size else UNKNOWN_TOKEN
                    for t in latest_tokens
                ]  # change any in-article temporary OOV ids to [UNK] id, so that we can look up word embeddings
                states = [h.state for h in hyps
                          ]  # list of current decoder states of the hypotheses
                prev_coverage = [h.coverage for h in hyps
                                 ]  # list of coverage vectors (or None)

                # Run one step of the decoder to get the new info
                decoder_outputs, new_states, attn_dists, p_gens, new_coverage = self.decoder(
                    latest_tokens,
                    _dec_in_state,
                    states,
                    input_mask,
                    prev_coverage=prev_coverage)

                vocab_dists = self.output_projector(decoder_outputs)

                if self.config.use_pointer_gen:
                    final_dists = self.final_distribution(
                        vocab_dists, attn_dists, p_gens, input_ids)
                else:  # final distribution is just vocabulary distribution
                    final_dists = vocab_dists

                assert len(
                    final_dists
                ) == 1  # final_dists is a singleton list containing shape (batch_size, extended_vsize)
                final_dists = final_dists[0]
                topk_probs, topk_ids = tf.nn.top_k(final_dists,
                                                   self.config.batch_size * 2)
                # take the k largest probs. note batch_size=beam_size in decode mode

                topk_log_probs = tf.math.log(topk_probs)

                # Extend each hypothesis and collect them all in all_hyps
                all_hyps = []
                num_orig_hyps = 1 if steps == 0 else len(hyps)
                # On the first step, we only had one original hypothesis (the initial hypothesis).
                # On subsequent steps, all original hypotheses are distinct.

                for i in range(num_orig_hyps):
                    h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], \
                                                                     new_coverage[i]
                    # take the ith hypothesis and new decoder state info

                    for j in range(self.config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                        # Extend the ith hypothesis with the jth option
                        new_hyp = h.extend(token=topk_ids[i, j],
                                           log_prob=topk_log_probs[i, j],
                                           state=new_state,
                                           attn_dist=attn_dist,
                                           p_gen=p_gen,
                                           coverage=new_coverage_i)
                        all_hyps.append(new_hyp)

                # Filter and collect any hypotheses that have produced the end token.
                hyps = []  # will contain hypotheses for the next step
                for h in sort_hyps(all_hyps):  # in order of most likely h
                    if h.latest_token == STOP_TOKEN:  # if stop token is reached...
                        # If this hypothesis is sufficiently long, put in results. Otherwise discard.
                        if steps >= self.config.min_dec_steps:
                            results.append(h)
                    else:  # hasn't reached stop token, so continue to extend this hypothesis
                        hyps.append(h)
                    if len(hyps) == self.config.beam_size or len(
                            results) == self.config.beam_size:
                        # Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop.
                        break

                steps += 1

                # At this point, either we've got beam_size results, or we've reached maximum decoder steps

            if len(
                    results
            ) == 0:  # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results
                results = hyps

                # Sort hypotheses by average log probability
            hyps_sorted = sort_hyps(results)

            # Return the hypothesis with highest average log prob
            return hyps_sorted[0]
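
A toy, self-contained sketch of the coverage loss computed inside the training branch above (hypothetical values: one example, two decoder steps, two attention positions):

import tensorflow as tf

attn_dists = [tf.constant([[0.6, 0.4]]), tf.constant([[0.5, 0.5]])]   # per-step attention
padding_mask = tf.constant([[1.0, 1.0]])                              # no padded decoder steps
coverage = tf.zeros_like(attn_dists[0])
covlosses = []
for a in attn_dists:
    covlosses.append(tf.reduce_sum(tf.minimum(a, coverage), [1]))     # overlap with coverage so far
    coverage += a
dec_lens = tf.reduce_sum(padding_mask, axis=1)
values_per_step = [v * padding_mask[:, t] for t, v in enumerate(covlosses)]
print(float(tf.reduce_mean(sum(values_per_step) / dec_lens)))         # 0.45 for this toy example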
Example No. 8
    def call(self, inputs):
        # unpacked_inputs = tf_utils.unpack_inputs(inputs)

        decoder_inputs = inputs[0]
        initial_state = inputs[1]
        encoder_states = inputs[2]
        enc_padding_mask = inputs[3]
        prev_coverage = inputs[4]

        outputs = []
        attn_dists = []
        p_gens = []

        encoder_states = tf.expand_dims(
            encoder_states,
            axis=2)  # now is shape (batch_size, attn_len, 1, attn_size)

        encoder_features = self.encoder_layer(
            encoder_states
        )  # shape (batch_size,attn_length,1,attention_vec_size)
        state = initial_state  # [initial_state,initial_state]
        # state=[initial_state]*2
        batch_size = tf_utils.get_shape_list(encoder_states)[0]

        coverage = prev_coverage  # initialize coverage to None or whatever was passed in

        context_vector = tf.zeros([batch_size, self.vector_size])
        context_vector.set_shape([
            None, self.vector_size
        ])  # Ensure the second shape of attention vectors is set.
        if self.initial_state_attention:  # true in decode mode
            # Re-calculate the context vector from the previous step so that we can pass it through a linear layer
            # with this step's input to get a modified version of the input
            context_vector, _, coverage = self.attention_layer(
                encoder_features=encoder_features,
                decoder_state=state,
                coverage=coverage,
                input_mask=enc_padding_mask)
            # in decode mode, this is what updates the coverage vector

        for i, inp in enumerate(decoder_inputs):

            # Merge input and previous attentions into one vector x of the same size as inp
            input_size = inp.get_shape().with_rank(2)[1]

            if input_size is None:
                raise ValueError("Could not infer input size from input: %s" %
                                 inp.name)

            x = self.linear([[inp], [context_vector]])

            # Run the decoder RNN cell. cell_output = decoder state
            # print(i, x, state)
            cell_output, state = self.lstm_layer(x, state)

            # Run the attention mechanism.
            if i == 0 and self.initial_state_attention:  # always true in decode mode
                context_vector, attn_dist, _ = self.attention_layer(
                    encoder_features=encoder_features,
                    decoder_state=state,
                    coverage=coverage,
                    input_mask=enc_padding_mask
                )  # don't allow coverage to update
            else:
                context_vector, attn_dist, coverage = self.attention_layer(
                    encoder_features=encoder_features,
                    decoder_state=state,
                    coverage=coverage,
                    input_mask=enc_padding_mask)
            attn_dists.append(attn_dist)

            # Calculate p_gen
            if self.pointer_gen:
                p_gen = self.linear2([[context_vector], [state[0]], [state[1]],
                                      [x]])
                # Tensor shape (batch_size, 1)
                p_gen = tf.sigmoid(p_gen)
                p_gens.append(p_gen)

            # Concatenate the cell_output (= decoder state) and the context vector, and pass them through a linear layer
            # This is V[s_t, h*_t] + b in the paper
            output = self.linear([[cell_output], [context_vector]])
            outputs.append(output)

        # If using coverage, reshape it
        if coverage is not None:
            coverage = tf.reshape(coverage, [batch_size, -1])

        return outputs, state, attn_dists, p_gens, coverage
Example No. 9
    def call(self, inputs):

        input_ids, input_mask, answer_ids, answer_mask = inputs
        emb_enc_inputs = self.embedding_lookup(input_ids)
        # tensor with shape (batch_size, max_seq_length, emb_size)

        emb_dec_inputs = [
            self.embedding_lookup(x) for x in tf.unstack(answer_ids, axis=1)
        ]
        # list length max_dec_steps containing shape (batch_size, emb_size)

        enc_outputs, enc_state = self.encoder(emb_enc_inputs, input_mask)

        _enc_states = enc_outputs

        _dec_in_state = enc_state

        atten_len = tf_utils.get_shape_list(input_ids)[1]
        batch_size = tf_utils.get_shape_list(input_ids)[0]

        UNKNOWN_TOKEN = 0
        START_TOKEN = 104
        STOP_TOKEN = 105

        prev_coverage = self.prev_coverage  #if self.config.mode == "decode" and self.config.use_coverage  else None
        hyps = [
            Hypothesis(
                tokens=[START_TOKEN],  #answer_ids[0] is the [START]
                log_probs=[0.0],
                state=_dec_in_state,
                attn_dists=[],
                p_gens=[],
                coverage=tf.zeros(
                    atten_len)  # zero vector of length attention_length
            ) for _ in range(batch_size)
        ]

        results = [
        ]  # this will contain finished hypotheses (those that have emitted the [STOP] token)

        steps = 0
        while steps < self.config.max_dec_steps and len(
                results) < self.config.beam_size:
            latest_tokens = [h.latest_token for h in hyps
                             ]  # latest token produced by each hypothesis
            latest_tokens = [
                t if t < self.config.vocab_size else UNKNOWN_TOKEN
                for t in latest_tokens
            ]  # change any in-article temporary OOV ids to [UNK] id, so that we can look up word embeddings
            states = [h.state for h in hyps
                      ]  # list of current decoder states of the hypotheses
            prev_coverage = [h.coverage for h in hyps
                             ]  # list of coverage vectors (or None)

            # Run one step of the decoder to get the new info
            decoder_outputs, new_states, attn_dists, p_gens, new_coverage = self.decoder(
                latest_tokens,
                _dec_in_state,
                states,
                input_mask,
                prev_coverage=prev_coverage)

            vocab_dists = self.output_projector(decoder_outputs)

            if self.config.use_pointer_gen:
                final_dists = self.final_distribution(vocab_dists, attn_dists,
                                                      p_gens, input_ids)
            else:  # final distribution is just vocabulary distribution
                final_dists = vocab_dists

            assert len(
                final_dists
            ) == 1  # final_dists is a singleton list containing shape (batch_size, extended_vsize)
            final_dists = final_dists[0]
            topk_probs, topk_ids = tf.nn.top_k(final_dists,
                                               self.config.batch_size * 2)
            # take the k largest probs. note batch_size=beam_size in decode mode

            topk_log_probs = tf.math.log(topk_probs)

            # Extend each hypothesis and collect them all in all_hyps
            all_hyps = []
            num_orig_hyps = 1 if steps == 0 else len(hyps)
            # On the first step, we only had one original hypothesis (the initial hypothesis).
            # On subsequent steps, all original hypotheses are distinct.

            for i in range(num_orig_hyps):
                h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], \
                                                                 new_coverage[i]
                # take the ith hypothesis and new decoder state info

                for j in range(self.config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    # Extend the ith hypothesis with the jth option
                    new_hyp = h.extend(token=topk_ids[i, j],
                                       log_prob=topk_log_probs[i, j],
                                       state=new_state,
                                       attn_dist=attn_dist,
                                       p_gen=p_gen,
                                       coverage=new_coverage_i)
                    all_hyps.append(new_hyp)

            # Filter and collect any hypotheses that have produced the end token.
            hyps = []  # will contain hypotheses for the next step
            for h in self.sort_hyps(all_hyps):  # in order of most likely h
                if h.latest_token == STOP_TOKEN:  # if stop token is reached...
                    # If this hypothesis is sufficiently long, put in results. Otherwise discard.
                    if steps >= self.config.min_dec_steps:
                        results.append(h)
                else:  # hasn't reached stop token, so continue to extend this hypothesis
                    hyps.append(h)
                if len(hyps) == self.config.beam_size or len(
                        results) == self.config.beam_size:
                    # Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop.
                    break

            steps += 1

            # At this point, either we've got beam_size results, or we've reached maximum decoder steps

        if len(
                results
        ) == 0:  # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results
            results = hyps

            # Sort hypotheses by average log probability
        hyps_sorted = self.sort_hyps(results)

        # Return the hypothesis with highest average log prob
        return hyps_sorted[0]
Example No. 10
    def call(self, inputs):
        # unpacked_inputs = tf_utils.unpack_inputs(inputs)
        decoder_inputs = inputs[0]
        initial_state = inputs[1]
        encoder_states = inputs[2]
        enc_padding_mask = inputs[3]
        prev_coverage = inputs[4]

        outputs = []
        attn_dists = []
        p_gens = []

        encoder_states = tf.expand_dims(
            encoder_states,
            axis=2)  # now is shape (batch_size, attn_len, 1, attn_size)

        encoder_features = self.encoder_layer(
            encoder_states
        )  # shape (batch_size,attn_length,1,attention_vec_size)
        state = initial_state  # [initial_state,initial_state]
        # state=[initial_state]*2
        batch_size = tf_utils.get_shape_list(encoder_states)[0]

        coverage = prev_coverage  # initialize coverage to None or whatever was passed in

        context_vector = tf.zeros([batch_size, self.vector_size])

        # Merge input and previous attentions into one vector x of the same size as the input
        input_size = decoder_inputs.get_shape().with_rank(2)[1]

        if input_size is None:
            raise ValueError("Could not infer input size from input: %s" %
                             decoder_inputs.name)

        x = self.linear([[decoder_inputs], [context_vector]])

        # Run the decoder RNN cell. cell_output = decoder state
        # print(i, x, state)
        cell_output, state = self.lstm_layer(x, state)

        # Run the attention mechanism.
        context_vector, attn_dist, coverage = self.attention_layer(
            encoder_features=encoder_features,
            decoder_state=state,
            coverage=coverage,
            input_mask=enc_padding_mask)
        attn_dists.append(attn_dist)

        # Calculate p_gen
        if self.pointer_gen:
            p_gen = self.linear2([[context_vector], [state[0]], [state[1]],
                                  [x]])
            # Tensor shape (batch_size, 1)
            p_gen = tf.sigmoid(p_gen)
            p_gens.append(p_gen)

        # Concatenate the cell_output (= decoder state) and the context vector, and pass them through a linear layer
        # This is V[s_t, h*_t] + b in the paper
        output = self.linear([[cell_output], [context_vector]])

        outputs.append(output)

        return outputs, state, attn_dists, p_gens, coverage
Example No. 11
    def call(self, inputs):

        encoder_features = inputs[0]
        batch_size = tf_utils.get_shape_list(encoder_features)[0]

        decoder_states = inputs[1]
        input_mask = inputs[2]
        coverage = inputs[3]

        decoder_features = self.linear_layer([decoder_states])
        decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1),
                                          1)

        # reshape to (batch_size, 1, 1, attention_vec_size)

        def masked_attention(e):
            """Take softmax of e then apply enc_padding_mask and re-normalize"""
            attn_dist = tf.nn.softmax(
                e)  # take softmax. shape (batch_size, attn_length)
            attn_dist *= tf.dtypes.cast(input_mask, tf.float32)  # apply mask
            masked_sums = tf.reduce_sum(attn_dist,
                                        axis=1)  # shape (batch_size)
            return attn_dist / tf.reshape(masked_sums, [-1, 1])  # re-normalize

        if self.use_coverage and coverage is not None:  # non-first step of coverage
            # Multiply coverage vector by w_c to get coverage_features.
            coverage_features = self.coverage_layer(
                coverage
            )  # c has shape (batch_size, attn_length, 1, attention_vec_size)

            # Calculate v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + b_attn)
            e = tf.reduce_sum(self.v *
                              tf.tanh(encoder_features + decoder_features +
                                      coverage_features),
                              [2, 3])  # shape (batch_size,attn_length)

            # Calculate attention distribution
            attn_dist = masked_attention(e)

            # Update coverage vector
            coverage += tf.reshape(attn_dist, [batch_size, -1, 1, 1])
        else:
            # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn)
            e = tf.reduce_sum(self.v *
                              tf.tanh(encoder_features + decoder_features),
                              [2, 3])  # calculate e

            # Calculate attention distribution
            attn_dist = masked_attention(e)

            if self.use_coverage:  # first step of training
                coverage = tf.expand_dims(tf.expand_dims(attn_dist, 2),
                                          2)  # initialize coverage

        # Calculate the context vector from attn_dist and encoder_states
        context_vector = tf.reduce_sum(
            tf.reshape(attn_dist, [batch_size, -1, 1, 1]) * encoder_features,
            [1, 2])  # shape (batch_size, attn_size).

        context_vector = tf.reshape(context_vector, [-1, self.vector_size])

        return context_vector, attn_dist, coverage
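
A minimal standalone sketch of the masked_attention re-normalization defined above, with hypothetical values (one example, three attention positions, last one padded):

import tensorflow as tf

e = tf.constant([[1.0, 2.0, 3.0]])            # raw attention scores
input_mask = tf.constant([[1, 1, 0]])         # last position is padding
attn_dist = tf.nn.softmax(e) * tf.cast(input_mask, tf.float32)   # zero out padded positions
attn_dist /= tf.reshape(tf.reduce_sum(attn_dist, axis=1), [-1, 1])  # re-normalize
print(attn_dist.numpy())  # probabilities over the unmasked positions sum to 1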