Code example #1
    def _block_ngrams(
        self, ngram_size: int, logprobs: torch.Tensor, source: torch.LongTensor = None
    ):
        """
        Hard block ngrams from the logprobs, based on the source.

        :param ngram_size:
            The length of ngrams to block. Must be > 0.
        :param logprobs:
            Float or HalfTensor, representing the log-probabilities. This is
            modified in place.
        :param source:
            Source text to grab ngrams from. If None, it uses the current
            hypothesis (i.e. self-blocking).
        """
        for beam_id, hyp in enumerate(self.partial_hyps):
            if len(hyp) < ngram_size - 1:
                continue
            source_ = hyp if source is None else source
            ngrams = self._find_ngrams(source_, ngram_size)
            prefix = hyp[-(ngram_size - 1) :]
            for ngram in ngrams:
                if ngram_size == 1 or prefix == list(ngram[:-1]):
                    logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype)
        return logprobs
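A minimal, self-contained sketch of the same idea for a single hypothesis, using a toy stand-in for ParlAI's neginf helper (the constant and the block_ngrams name below are illustrative assumptions, not ParlAI's actual API):

    import torch

    def neginf(dtype):
        # toy stand-in: a large negative value that is still finite for the dtype
        return -65504.0 if dtype == torch.float16 else -1e20

    def block_ngrams(hyp, logprobs, ngram_size=2):
        # ban any token that would complete an n-gram already present in hyp
        if len(hyp) < ngram_size - 1:
            return logprobs
        ngrams = [tuple(hyp[i:i + ngram_size]) for i in range(len(hyp) - ngram_size + 1)]
        prefix = hyp[-(ngram_size - 1):]
        for ngram in ngrams:
            if ngram_size == 1 or prefix == list(ngram[:-1]):
                logprobs[ngram[-1]] = neginf(logprobs.dtype)
        return logprobs

    hyp = [5, 7, 5]               # the bigram (5, 7) has already been generated
    logprobs = torch.zeros(10)
    block_ngrams(hyp, logprobs)   # after the trailing 5, token 7 is banned
    assert logprobs[7] < -1e15 and logprobs[6] == 0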
Code example #2
    def get_extra_output_from_mask(
        self,
        input: torch.LongTensor,
        encoder_output: torch.Tensor,
        encoder_mask: torch.Tensor,
    ) -> ExtraOutput:
        """
        Use a trainable mask layer to determine which elements of the input to re-attend
        to.

        :param input:
            vectorized input tokens
        :param encoder_output:
            output encodings of input tokens
        :param encoder_mask:
            mask for input

        :return (enc_out, enc_mask):
            return the extra output to which we will be attending (for all layers).
        """
        weights = self.softmax(
            self.mask_dropout(self.mask_linear(encoder_output)).masked_fill_(
                (encoder_mask == 0).view(*encoder_mask.size(),
                                         1).expand(*encoder_output.size()),
                neginf(encoder_output.dtype),
            ),
            dim=1,
        )
        topk = get_topk(self.opt, input.size(-1))
        topk_inds = weights.sum(-1).topk(topk, dim=-1, sorted=False).indices
        new_input = torch.gather(input, dim=-1, index=topk_inds)
        out2 = super().forward(new_input)

        assert isinstance(out2, tuple)
        return (*out2, weights)  # type: ignore
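The interesting step above is re-encoding only the top-k input positions chosen by the learned mask weights. A tiny sketch of that topk-then-gather pattern in isolation (sizes and names here are made up for illustration):

    import torch

    bsz, seq_len, dim = 2, 6, 4
    input_toks = torch.randint(0, 100, (bsz, seq_len))   # vectorized input tokens
    weights = torch.rand(bsz, seq_len, dim)              # per-token mask weights

    topk = 3
    # score each position by summing over the hidden dim, keep the k best positions
    topk_inds = weights.sum(-1).topk(topk, dim=-1, sorted=False).indices   # (bsz, topk)
    # gather the corresponding token ids so only they get re-encoded
    new_input = torch.gather(input_toks, dim=-1, index=topk_inds)          # (bsz, topk)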
Code example #3
    def forward(self, query_embs, in_mem_embs, out_mem_embs, pad_mask):
        """
        Compute MemNN Hop step.

        :param query_embs:
            (bsz x esz) embedding of queries

        :param in_mem_embs:
            bsz list of (num_mems x esz) embedding of memories for activation

        :param out_mem_embs:
            bsz list of (num_mems x esz) embedding of memories for outputs

        :param pad_mask:
            (bsz x num_mems) optional mask indicating which tokens correspond to
            padding

        :returns:
            (bsz x esz) output state
        """
        # attend over the memories with the query embeddings
        attn = torch.bmm(query_embs.unsqueeze(1), in_mem_embs).squeeze(1)
        if pad_mask is not None:
            attn[pad_mask] = neginf(attn.dtype)
        probs = self.softmax(attn)
        memory_output = torch.bmm(probs.unsqueeze(1), out_mem_embs).squeeze(1)
        # rotate the query embeddings and add the attended memory output
        output = memory_output + self.rotate(query_embs)
        return output
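Shape-wise, the hop is a dot-product attention over memory embeddings where padded slots are pushed to near -inf before the softmax. A stripped-down sketch with concrete, made-up shapes, using a literal -1e20 in place of neginf:

    import torch
    import torch.nn.functional as F

    bsz, num_mems, esz = 2, 5, 8
    query_embs = torch.randn(bsz, esz)
    in_mem_embs = torch.randn(bsz, esz, num_mems)    # memory keys, laid out for bmm
    out_mem_embs = torch.randn(bsz, num_mems, esz)   # memory values
    pad_mask = torch.zeros(bsz, num_mems, dtype=torch.bool)
    pad_mask[:, -2:] = True                          # last two memory slots are padding

    attn = torch.bmm(query_embs.unsqueeze(1), in_mem_embs).squeeze(1)       # (bsz, num_mems)
    attn[pad_mask] = -1e20                           # padded memories get ~zero probability
    probs = F.softmax(attn, dim=1)
    memory_output = torch.bmm(probs.unsqueeze(1), out_mem_embs).squeeze(1)  # (bsz, esz)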
Code example #4
    def forward(self, token_ids, segment_ids, attention_mask):
        """
        Forward pass.
        """
        output_bert, output_pooler = self.bert_model(
            token_ids, segment_ids, attention_mask
        )
        # output_bert is a list of 12 (for bert base) layers.
        layer_of_interest = output_bert[self.layer_pulled]
        dtype = next(self.parameters()).dtype
        if self.add_transformer_layer:
            # Follow up by yet another transformer layer
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = (~extended_attention_mask).to(dtype) * neginf(
                dtype
            )
            embedding_layer = self.additional_transformer_layer(
                layer_of_interest, extended_attention_mask
            )
        else:
            embedding_layer = layer_of_interest

        if self.aggregation == "mean":
            # consider the average of all outputs except CLS (masked elements are excluded)
            outputs_of_interest = embedding_layer[:, 1:, :]
            mask = attention_mask[:, 1:].type_as(embedding_layer).unsqueeze(2)
            summed_embeddings = torch.sum(outputs_of_interest * mask, dim=1)
            nb_elems = torch.sum(attention_mask[:, 1:].type(dtype), dim=1).unsqueeze(1)
            embeddings = summed_embeddings / nb_elems
        elif self.aggregation == "max":
            #  consider the max of all the output except CLS
            outputs_of_interest = embedding_layer[:, 1:, :]
            mask = (~attention_mask[:, 1:]).type(dtype).unsqueeze(2) * neginf(dtype)
            embeddings, _ = torch.max(outputs_of_interest + mask, dim=1)
        else:
            # easiest, we consider the output of "CLS" as the embedding
            embeddings = embedding_layer[:, 0, :]

        # We need this in case of dimensionality reduction
        result = self.additional_linear_layer(embeddings)

        # Sort of hack to make it work with distributed: this way the pooler layer
        # is used for grad computation, even though it does not change anything...
        # in practice, it just adds a (768*768) x (768*batchsize) matmul
        result += 0 * torch.sum(output_pooler)
        return result
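The "max" aggregation relies on a standard trick: rather than dropping padded positions, add neginf to them so they can never win the max. In isolation, with toy values and a literal -1e20:

    import torch

    dtype = torch.float32
    x = torch.tensor([[[1.0], [5.0], [3.0]]])                     # (bsz=1, seq=3, dim=1)
    attention_mask = torch.tensor([[1, 1, 0]], dtype=torch.bool)  # last position is padding

    mask = (~attention_mask).to(dtype).unsqueeze(2) * -1e20       # 0 for real tokens, -1e20 for pads
    pooled, _ = torch.max(x + mask, dim=1)                        # tensor([[5.]]); the padded 3.0 is ignored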
Code example #5
File: modules.py Project: dilenadan/ParlAI
 def output(self, tensor):
     """
     Compute output logits.
     """
     # project back to vocabulary
     output = F.linear(tensor, self.embeddings.weight)
     # compatibility with fairseq: fairseq sometimes reuses BOS tokens and
     # we need to force their probability of generation to be 0.
     output[:, :, self.start_idx] = neginf(output.dtype)
     return output
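Example #5 reuses the embedding matrix as the output projection (weight tying), then uses neginf to make the BOS token ungeneratable. A small standalone sketch with made-up sizes and a literal -1e20:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    vocab, dim, bsz, seqlen = 100, 16, 2, 5
    embeddings = nn.Embedding(vocab, dim)
    hidden = torch.randn(bsz, seqlen, dim)           # decoder output states

    # tied output layer: logits share parameters with the input embedding
    logits = F.linear(hidden, embeddings.weight)     # (bsz, seqlen, vocab)
    start_idx = 1                                    # e.g. BOS; force its probability to ~0
    logits[:, :, start_idx] = -1e20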
Code example #6
    def output_choose_knowledge(self, out_tokens):
        # check whether a softmax over the output and the knowledge can pick the gold knowledge
        # encode the context, pretty basic
        # N: batch size, K: number of knowledge sentences, T: time, D: embedding size, Tk: knowledge length
        context_encoded, context_mask = self.transformer(out_tokens)

        # make all the knowledge into a 2D matrix to encode
        N, K, Tk = self.know_tokens.size()
        know_encoded, know_mask = self.transformer(
            self.know_tokens.reshape(-1, Tk))

        # compute our sentence embeddings for context and knowledge
        context_use = universal_sentence_embedding(context_encoded,
                                                   context_mask)
        know_use = universal_sentence_embedding(know_encoded, know_mask)

        # remash it back into the shape we need
        know_use = know_use.reshape(N, self.know_tokens.size(1),
                                    self.embed_dim) / np.sqrt(self.embed_dim)
        context_use /= np.sqrt(self.embed_dim)

        ck_attn = th.bmm(know_use, context_use.unsqueeze(-1)).squeeze(-1)
        # fill with near -inf
        # ~ inverts the mask (bitwise invert: -2^(N-1) to 2^(N-1)-1)
        ck_attn.masked_fill_(~self.ck_mask, neginf(context_encoded.dtype))

        # pick the true chosen sentence. remember that TransformerEncoder outputs
        #   (batch, time, embed)
        # but because know_encoded is a flattened, it's really
        #   (N * K, T, D)
        # We need to compute the offsets of the chosen_sentences
        cs_encoded = None
        softmax_cs_weight = th.nn.functional.softmax(
            (ck_attn * self.knowledge_lamda), dim=1)
        """
        # cs_id is 0; softmax_cs_weight has shape (B, knowledge)
        true_ids_weight = th.zeros(softmax_cs_weight.shape, device=softmax_cs_weight.device, dtype=softmax_cs_weight.dtype)
        for temp in true_ids_weight:
            temp[0] = 1

        loss = softmax_cs_weight - true_ids_weight
        loss = loss * loss 
        loss[loss == 0] = 0.000001
        loss = th.sqrt(loss)
        loss = th.sum(loss) / N
        #print(loss)

        self.know_tokens = None
        self.ck_mask = None
        self.cs_ids = None
        self.use_cs_ids = None
        # also return the knowledge selection mask for the loss
        """
        return softmax_cs_weight
Code example #7
  def forward(self, x, mask):
    x = self.linear(x)
    x = self.act(x)

    attn = self.attn_wei(x).squeeze(-1)
    attn.masked_fill_(~mask, neginf(x.dtype))
    attn = self.softmax(attn)

    x = th.einsum('btd,bt->bd', x, attn)

    x = self.final(x)
    return x
Code example #8
    def modify_logprobs(self, logprobs: torch.Tensor) -> torch.Tensor:
        """
        Modify logprobs in PACER.

        The way it works:

        1. With frequency r, select a token x_i+1 to re-rank.
        2. Generate word probabilities for token x_i+1.
        3. Examine top k words {x_j | score(x_j) in top_k(P(x_i+1 | x_0, ..., x_i))}; use classifier to predict P(a | x_1, ..., x_i, x_j).
        4. Rescore top k words via multiplication, re-normalize, and advance the generation.

        :param logprobs:
            initial token probabilities

        :return modified:
            return the modified log probabilities according to PACER
        """
        if random.random() > self.frequency:
            return logprobs
        vals, inds = logprobs.topk(self.n_toks, dim=-1, sorted=False)
        new_probs = logprobs.clone().fill_(neginf(logprobs.dtype))
        # Construct partial hypotheses for each beam for each top K tokens
        batch_hyps = [
            h
            for i in range(len(self.partial_hyps))
            for h in [
                self.agent._v2t(self.partial_hyps[i][1:] + [ind]) for ind in inds[i]
            ]
        ]
        # Classify all beam outputs
        predictor_outputs = self.classifier.batch_classify(
            [self.context_str] * self.n_toks * logprobs.size(0), batch_hyps
        )
        # Extract RPA scores
        log_predictor_scores = (
            torch.stack(
                [
                    F.log_softmax(pred['sorted_scores'].float(), dim=0)[
                        int(pred['text'] == self.character) - 1
                    ]
                    for pred in predictor_outputs
                ]
            )
            .to(vals.device)
            .view(vals.size())
        )
        # "Multiply" Probabilities (in log space...)
        scores = vals + log_predictor_scores
        for i in range(new_probs.size(0)):
            new_probs[i, inds[i]] = scores[i]
        return F.log_softmax(new_probs, dim=-1, dtype=torch.float32)  # type: ignore
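Stripped of the classifier plumbing, the re-ranking above is: take the top-k token log-probs, add a per-candidate log-score, push everything else to near -inf, and renormalize with log_softmax. A toy sketch of just that arithmetic (the classifier scores are random placeholders):

    import torch
    import torch.nn.functional as F

    beams, vocab, k = 3, 20, 5
    logprobs = F.log_softmax(torch.randn(beams, vocab), dim=-1)

    vals, inds = logprobs.topk(k, dim=-1, sorted=False)
    extra = F.log_softmax(torch.randn(beams, k), dim=-1)   # stand-in for classifier log-scores
    scores = vals + extra                                  # "multiply" probabilities in log space

    new_probs = logprobs.clone().fill_(-1e20)              # non-top-k tokens become unreachable
    for i in range(beams):
        new_probs[i, inds[i]] = scores[i]
    modified = F.log_softmax(new_probs, dim=-1)            # renormalized distribution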
Code example #9
  def forward(self, x, mask):
    # import ipdb; ipdb.set_trace()
    # x: N x T x D
    N, T, D = x.shape
    x = self.linear(x).view(N, T, self.out, D)
    x = self.act(x)

    attn = self.attn_wei(x).squeeze(-1)
    attn.masked_fill_(~mask[:, :, None], neginf(x.dtype))
    attn = self.softmax(attn)

    x = th.einsum('btod,bto->bod', x, attn)
    x = self.proj(x)

    x = th.einsum('bod,vd->bov', x, self.embeddings)
    return x
Code example #10
    def forward(
        self,
        xs: torch.Tensor,
        ys: torch.Tensor,
        mask_ys: Optional[torch.Tensor] = None,
        values: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute attention.

        Attend over ys with query xs to obtain weights, then apply weights to
        values (ys if values is None).

        Args:
            xs: B x query_len x dim (queries)
            ys: B x key_len x dim (keys)
            mask_ys: B x key_len (mask)
            values: B x value_len x dim (values); if None, default to ys
        """
        bsz = xs.size(0)
        y_len = ys.size(1)
        x_len = xs.size(1)
        if self.attn == 'cosine':
            l1 = self.cosine(xs, ys).unsqueeze(self.dim - 1)
        else:
            l1 = torch.bmm(xs, ys.transpose(1, 2))
            if self.attn == 'sqrt':
                d_k = ys.size(-1)
                l1 = l1 / math.sqrt(d_k)
        if mask_ys is not None:
            attn_mask = (mask_ys == 0).view(bsz, 1, y_len)
            attn_mask = attn_mask.repeat(1, x_len, 1)
            l1.masked_fill_(attn_mask, neginf(l1.dtype))
        l2 = F.softmax(l1, dim=self.dim, dtype=torch.float).type_as(l1)
        if values is None:
            values = ys
        lhs_emb = torch.bmm(l2, values)

        # add back the query
        if self.residual:
            lhs_emb = lhs_emb.add(xs)

        if self.get_weights:
            return lhs_emb.squeeze(self.dim - 1), l2
        else:
            return lhs_emb.squeeze(self.dim - 1)
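The masking step expands a per-key mask over every query position before masked_fill_. A sketch of just that broadcast with made-up sizes (a literal -1e20 stands in for neginf):

    import torch

    bsz, x_len, y_len = 2, 3, 4
    l1 = torch.randn(bsz, x_len, y_len)                  # raw attention scores
    mask_ys = torch.tensor([[1, 1, 1, 0],
                            [1, 1, 0, 0]])               # 1 = real key, 0 = padding

    attn_mask = (mask_ys == 0).view(bsz, 1, y_len).repeat(1, x_len, 1)
    l1.masked_fill_(attn_mask, -1e20)                    # every query ignores padded keys
    l2 = torch.softmax(l1, dim=-1)                       # rows sum to 1 over real keys only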
Code example #11
    def forward(self, src_tokens, know_tokens, ck_mask, cs_ids, use_cs_ids):
        # encode the context, pretty basic
        context_encoded, context_mask = self.transformer(src_tokens)

        # make all the knowledge into a 2D matrix to encode
        N, K, Tk = know_tokens.size()
        know_flat = know_tokens.reshape(-1, Tk)
        know_encoded, know_mask = self.transformer(know_flat)

        # compute our sentence embeddings for context and knowledge
        context_use = universal_sentence_embedding(context_encoded,
                                                   context_mask)
        know_use = universal_sentence_embedding(know_encoded, know_mask)

        # remash it back into the shape we need
        know_use = know_use.reshape(N, know_tokens.size(1), self.embed_dim)
        context_use /= np.sqrt(self.embed_dim)
        know_use /= np.sqrt(self.embed_dim)

        ck_attn = th.bmm(know_use, context_use.unsqueeze(-1)).squeeze(-1)
        # fill with near -inf
        ck_attn.masked_fill_(~ck_mask, neginf(context_encoded.dtype))

        if not use_cs_ids:
            # if we're not given the true chosen_sentence (test time), pick our
            # best guess
            _, cs_ids = ck_attn.max(1)

        # pick the true chosen sentence. remember that TransformerEncoder outputs
        #   (batch, time, embed)
        # but because know_encoded is a flattened, it's really
        #   (N * K, T, D)

        # We need to compute the offsets of the chosen_sentences
        cs_offsets = th.arange(N, device=cs_ids.device) * K + cs_ids
        cs_encoded = know_encoded[cs_offsets]
        # but padding is (N * K, T)
        cs_mask = know_mask[cs_offsets]

        # finally, concatenate it all
        full_enc = th.cat([cs_encoded, context_encoded], dim=1)
        full_mask = th.cat([cs_mask, context_mask], dim=1)

        # also return the knowledge selection mask for the loss
        return full_enc, full_mask, ck_attn
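Several of these examples call universal_sentence_embedding(encoded, mask) to collapse per-token encoder states into one vector per sequence. As a hedged sketch (a masked average over time; not necessarily ParlAI's exact implementation), it could look like:

    import torch

    def universal_sentence_embedding(sentences, mask):
        # sentences: (bsz, seq_len, dim); mask: (bsz, seq_len), 1 for real tokens
        sums = torch.bmm(sentences.transpose(1, 2), mask.float().unsqueeze(-1)).squeeze(-1)
        counts = mask.float().sum(dim=1, keepdim=True).clamp(min=1.0)
        return sums / counts                             # (bsz, dim)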
Code example #12
    def forward(self, input):
        """
        Compute scores from inputs.

        :param input: (bsz x seq_len x num_directions * hiddensize) tensor of
                       states, e.g. the output states of an RNN

        :returns: (bsz x seqlen x num_cands) scores for each candidate
        """
        # next compute scores over dictionary
        if self.numsoftmax > 1:
            bsz = input.size(0)
            seqlen = input.size(1) if input.dim() > 1 else 1

            # first compute different softmax scores based on input vec
            # hsz => numsoftmax * esz
            latent = self.latent(input)
            active = self.dropout(self.activation(latent))
            # esz => num_features
            logit = F.linear(active.view(-1, self.esz), self.weight, self.bias)

            # calculate priors: distribution over which softmax scores to use
            # hsz => numsoftmax
            prior_logit = self.prior(input).view(-1, self.numsoftmax)
            # softmax over numsoftmax's
            prior = self.softmax(prior_logit)

            # now combine priors with logits
            prob = self.softmax(logit).view(bsz * seqlen, self.numsoftmax, -1)
            probs = (prob * prior.unsqueeze(2)).sum(1).view(bsz, seqlen, -1)
            scores = probs.log()
        else:
            # hsz => esz, good time for dropout
            e = self.dropout(self.o2e(input))
            # esz => num_features
            scores = F.linear(e, self.weight, self.bias)

        if self.padding_idx >= 0:
            scores[:, :, self.padding_idx] = neginf(scores.dtype)

        return scores
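The numsoftmax > 1 branch is a mixture of softmaxes: K component distributions over the vocabulary are blended with a prior over components, and the log of the blended distribution is returned. The core arithmetic in isolation (made-up sizes):

    import torch
    import torch.nn.functional as F

    positions, K, vocab = 6, 3, 50                   # bsz*seqlen flattened, K mixture components
    logit = torch.randn(positions * K, vocab)        # one softmax per component per position
    prior_logit = torch.randn(positions, K)

    prob = F.softmax(logit, dim=-1).view(positions, K, vocab)
    prior = F.softmax(prior_logit, dim=-1)           # how much to trust each component
    probs = (prob * prior.unsqueeze(2)).sum(1)       # (positions, vocab), still a distribution
    scores = probs.log()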
Code example #13
  def forward(self, src_tokens, know_tokens, ck_mask, res_tokens=None):
    # encode the context, pretty basic
    context_encoded, context_mask = self.transformer(src_tokens)

    # make all the knowledge into a 2D matrix to encode
    # knowledge is intent for customer and tickets for agent
    N, K, Tk = know_tokens.size()
    know_flat = know_tokens.reshape(-1, Tk)
    know_encoded, know_mask = self.knowledge_transformer(know_flat)

    if self.agenttype == 'customer':
      ck_attn = None
      intent_out = None
      name_out = None
      cs_encoded = know_encoded
      cs_mask = know_mask
    elif self.agenttype == 'agent':
      # import ipdb; ipdb.set_trace()

      # compute our sentence embeddings for context and knowledge
      context_use = universal_sentence_embedding(context_encoded, context_mask)
      know_use = universal_sentence_embedding(know_encoded, know_mask)

      # remash it back into the shape we need
      know_use = know_use.reshape(N, K, self.embed_dim)
      # project before calculate attn
      know_use_proj = self.know_use_project(know_use)
      ck_attn = th.bmm(know_use_proj, context_use.unsqueeze(-1)).squeeze(-1)
      ck_attn /= np.sqrt(self.embed_dim)
      # fill with near -inf
      ck_attn.masked_fill_(~ck_mask, neginf(context_encoded.dtype))

      # Compute context knowledge attn prob
      ck_prob = nn.functional.softmax(ck_attn, dim=-1)

      _, cs_ids = ck_attn.max(1)

      # pick the true chosen sentence. remember that TransformerEncoder outputs
      #   (batch, time, embed)
      # but because know_encoded is a flattened, it's really
      #   (N * K, T, D)
      # We need to compute the offsets of the chosen_sentences
      cs_offsets = th.arange(N, device=cs_ids.device) * K + cs_ids
      cs_encoded = know_encoded[cs_offsets]
      # but padding is (N * K, T)
      cs_mask = know_mask[cs_offsets]

      # compute reservation embeddings
      res_encoded, res_mask = self.reservation_transformer(res_tokens)

      # finally, concatenate it all
      cs_encoded = th.cat([know_use, cs_encoded, res_encoded], dim=1)
      cs_mask = th.cat([ck_mask, cs_mask, res_mask], dim=1)

      # intent prediction
      intent_out = self.intent_head(context_encoded, context_mask)
      name_out = self.name_head(context_encoded, context_mask)

    # finally, concatenate it all
    full_enc = th.cat([cs_encoded, context_encoded], dim=1)
    full_mask = th.cat([cs_mask, context_mask], dim=1)

    # also return the knowledge selection mask for the loss
    return full_enc, full_mask, ck_attn, intent_out, name_out
Code example #14
File: modules.py Project: dilenadan/ParlAI
    def forward(  # type: ignore
        # TODO: remove type ignore with pytorch 1.5:
        # https://github.com/pytorch/pytorch/pull/31057
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        mask: torch.Tensor = None,
        incr_state: Optional[Dict[str, torch.Tensor]] = None,
        static_kv: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """
        Forward pass.

        :param query: attention query
        :param key: attention key
        :param value: attention value
        :param mask: tensor in which True means that we are allowing attention and False
          means we are blocking it. Mask is:
          - [B, key_len] (encoder self-attn and decoder enc/dec attn)
          - [B, query_len, key_len] (decoder self-attn)
          - [B, 1, key_len] (decoder self-attn with incr_state caching)
        :param incr_state: dictionary with values representing the previous states of
          the key, value, and mask
        :param static_kv: True if the key and value are held constant during decoding
          (as in encoder/decoder attention)
        :return: (
          final attended tensor,
          new incremental state,
          key/value-multiplied tensor before softmax,
        )
        """

        batch_size, query_len, dim = query.size()
        assert (
            dim == self.dim
        ), 'Dimensions do not match: {} query vs {} configured'.format(dim, self.dim)
        assert mask is not None, 'Mask is None, please specify a mask'
        n_heads = self.n_heads
        dim_per_head = dim // n_heads
        scale = math.sqrt(dim_per_head)

        def prepare_head(tensor):
            # input is [batch_size, seq_len, n_heads * dim_per_head]
            # output is [batch_size * n_heads, seq_len, dim_per_head]
            bsz, seq_len, _ = tensor.size()
            tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head)
            tensor = (
                tensor.transpose(1, 2)
                .contiguous()
                .view(batch_size * n_heads, seq_len, dim_per_head)
            )
            return tensor

        # q, k, v are the transformed values
        if key is None and value is None:
            # self attention
            key = value = query
            _, _key_len, dim = query.size()
        elif value is None:
            # key and value are the same, but query differs
            # self attention
            value = key

        assert key is not None  # let mypy know we sorted this
        _, _key_len, dim = key.size()

        q = prepare_head(self.q_lin(query))
        k = prepare_head(self.k_lin(key))
        v = prepare_head(self.v_lin(value))

        # Prepend incremental states. For each of the key, value, and mask, see if
        # a previous incremental state exists, and if so, reshape it to match the shape
        # of the new state. Concatenate the previous and new states to match what the
        # full state would have been if we had not cached. (If we are using static_kv,
        # these three states are unchanging, so just re-use the cached states.)
        if incr_state is None:
            incr_state = {}
        if 'prev_key' in incr_state:
            prev_key = incr_state['prev_key'].view(
                batch_size * n_heads, -1, dim_per_head
            )
            if static_kv:
                k = prev_key
            else:
                k = torch.cat([prev_key, k], dim=1)
        if 'prev_value' in incr_state:
            prev_value = incr_state['prev_value'].view(
                batch_size * n_heads, -1, dim_per_head
            )
            if static_kv:
                v = prev_value
            else:
                v = torch.cat([prev_value, v], dim=1)
        if 'prev_mask' in incr_state:
            if static_kv:
                mask = incr_state['prev_mask']
            else:
                mask = torch.cat([incr_state['prev_mask'], mask], dim=2)
                # Prepend along the key_len dimension (analogous to
                # incr_state['prev_key'])

        # Save new incremental states. We reshape to allow for reordering along batch
        # dimension.
        new_incr_state = {
            'prev_key': k.view(batch_size, n_heads, -1, dim_per_head),
            'prev_value': v.view(batch_size, n_heads, -1, dim_per_head),
            'prev_mask': mask,
        }

        full_key_len = k.size(1)
        dot_prod = q.div_(scale).bmm(k.transpose(1, 2))
        # [B * n_heads, query_len, key_len]
        attn_mask = (
            (mask == 0)
            .view(batch_size, 1, -1, full_key_len)
            .repeat(1, n_heads, 1, 1)
            .expand(batch_size, n_heads, query_len, full_key_len)
            .view(batch_size * n_heads, query_len, full_key_len)
        )
        assert attn_mask.shape == dot_prod.shape
        dot_prod.masked_fill_(attn_mask, neginf(dot_prod.dtype))

        attn_weights = F.softmax(
            dot_prod, dim=-1, dtype=torch.float  # type: ignore
        ).type_as(query)
        attn_weights = self.attn_dropout(attn_weights)  # --attention-dropout

        attentioned = attn_weights.bmm(v)
        attentioned = (
            attentioned.type_as(query)
            .view(batch_size, n_heads, query_len, dim_per_head)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, query_len, dim)
        )

        out = self.out_lin(attentioned)

        return out, new_incr_state, dot_prod
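The incremental-state handling above amounts to concatenating cached keys/values from earlier timesteps onto the new projections along the time dimension, then re-caching. In isolation, with hypothetical sizes:

    import torch

    batch_size, n_heads, dim_per_head = 2, 4, 8
    prev_len, new_len = 5, 1                              # cache holds 5 steps, now decoding 1 more

    k_new = torch.randn(batch_size * n_heads, new_len, dim_per_head)
    prev_key = torch.randn(batch_size, n_heads, prev_len, dim_per_head)   # cached shape

    k = torch.cat(
        [prev_key.view(batch_size * n_heads, -1, dim_per_head), k_new], dim=1
    )                                                     # (B*H, prev_len + new_len, d)
    new_cache = k.view(batch_size, n_heads, -1, dim_per_head)             # stored for the next step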
Code example #15
File: wdgenerator.py Project: convobox/ParlAI
    def _generate(
        self,
        batch: Batch,
        beam_size: int,
        max_ts: int,
        prefix_tokens: tp.Optional[torch.LongTensor] = None,
    ):
        """
        Generate an output with beam search.

        Depending on the options, this may perform greedy/topk/nucleus generation.

        :param Batch batch:
            Batch structure with input and labels
        :param int beam_size:
            Size of each beam during the search
        :param int max_ts:
            the maximum length of the decoded sequence
        :param prefix_tokens:
            if given, a tensor of tokens that must begin the decoded sequence.

        :return:
            tuple (beam_preds_scores, beams)

            - beam_preds_scores: list of (prediction, score) pairs for each sample in
              Batch
            - beams: list of Beam instances defined in the Beam class; can be used for
              any following postprocessing, e.g. dot logging.
        """
        model = self.model
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model = self.model.module
        encoder_states = model.encoder(*self._encoder_input(batch))
        if batch.text_vec is not None:
            dev = batch.text_vec.device
        else:
            assert batch.label_vec is not None, "need label_vec for _generate"
            dev = batch.label_vec.device

        bsz = (
            len(batch.text_lengths) if batch.text_lengths is not None else len(
                batch.image)  # type: ignore
        )
        if batch.text_vec is not None:
            batchsize = batch.text_vec.size(0)
            beams = [
                self._treesearch_factory(dev).set_context(
                    self._get_context(batch, batch_idx)).set_block_list(
                        self.beam_block_list) for batch_idx in range(batchsize)
            ]
        else:
            beams = [self._treesearch_factory(dev) for _ in range(bsz)]

        # repeat encoder outputs and decoder inputs
        decoder_input = self._get_initial_decoder_input(bsz, beam_size, dev)

        inds = torch.arange(bsz).to(dev).unsqueeze(1).repeat(
            1, beam_size).view(-1)
        encoder_states = model.reorder_encoder_states(encoder_states, inds)
        incr_state = None

        for _ts in range(max_ts):
            if all((b.is_done() for b in beams)):
                # exit early if possible
                break

            score, incr_state = model.decoder(decoder_input, encoder_states,
                                              incr_state)
            # only need the final hidden state to make the word prediction
            score = score[:, -1:, :]
            score = model.output(score)
            # score contains softmax scores for bsz * beam_size samples
            score = score.view(bsz, beam_size, -1)
            if self.temperature != 1.0:
                score.div_(self.temperature)
            # force to fp32 to avoid overflow issues during search calculations
            score = F.log_softmax(score, dim=-1,
                                  dtype=torch.float32)  # type: ignore
            if prefix_tokens is not None and _ts < prefix_tokens.size(1):
                # generate prefix_tokens for every timestep that they exist
                # achieve by setting score of all other tokens to be -inf
                prefix_toks = prefix_tokens[:, _ts].unsqueeze(-1).repeat(
                    1, beam_size)
                prefix_score = score.gather(-1, prefix_toks.unsqueeze(-1))
                prefix_mask = prefix_toks.ne(self.NULL_IDX)
                score[prefix_mask] = neginf(score.dtype)
                score[prefix_mask] = score[prefix_mask].scatter_(
                    -1,
                    prefix_toks[prefix_mask].unsqueeze(-1),
                    prefix_score[prefix_mask],
                )
            for i, b in enumerate(beams):
                if not b.is_done():
                    score_in = score[i]
                    score_in += self._nidf_feats.to(dev)
                    b.advance(score_in)
            incr_state_inds = torch.cat([
                beam_size * i + b.get_backtrack_from_current_step()
                for i, b in enumerate(beams)
            ])
            incr_state = model.reorder_decoder_incremental_state(
                incr_state, incr_state_inds)
            selection = torch.cat([
                b.get_output_from_current_step() for b in beams
            ]).unsqueeze(-1)
            decoder_input = self._get_next_decoder_input(
                decoder_input, selection, incr_state_inds)

        # get all finalized candidates for each sample (and validate them)
        n_best_beam_preds_scores = [b.get_rescored_finished() for b in beams]
        if hasattr(self, '_rerank_beams'):
            n_best_beam_preds_scores = self._rerank_beams(  # type: ignore
                batch, n_best_beam_preds_scores)

        # get the top prediction for each beam (i.e. minibatch sample)
        beam_preds_scores = [
            n_best_list[0] for n_best_list in n_best_beam_preds_scores
        ]
        return beam_preds_scores, beams
Code example #16
    def advance(self, logprobs):
        """
        Advance the beam one step.
        """
        current_length = len(self.all_scores) - 1
        if current_length < self.min_length:
            # penalize all eos probs to make it decode longer
            for hyp_id in range(logprobs.size(0)):
                logprobs[hyp_id][self.eos] = neginf(logprobs.dtype)

        if self.scores is None:
            self.scores = torch.zeros(1).type_as(logprobs).to(logprobs.device)

        # penalize hypotheses ending in EOS on the prior scores (self.scores) level
        # this is related to search which uses prior scores (self.scores) (e.g. beam)
        for hyp_id, token in enumerate(self.outputs[-1]):
            if token == self.eos:
                self.scores[hyp_id] = neginf(self.scores.dtype)

        # beam blocking
        if self.block_ngram > 0:
            logprobs = self._block_ngrams(self.block_ngram, logprobs, None)

        if self.context_block_ngram > 0:
            if self.context is None:
                raise ValueError(
                    "Must use TreeSearch.set_context to use context blocking."
                )
            logprobs = self._block_ngrams(
                self.context_block_ngram, logprobs, self.context
            )

        hyp_ids, tok_ids, self.scores = self.select_paths(logprobs, self.scores)
        # use clone() here to ensure that self.all_scores will not be changed
        # later due to any penalties to self.scores
        self.all_scores.append(self.scores.clone())

        self.outputs.append(tok_ids)
        self.bookkeep.append(hyp_ids)
        self.partial_hyps = [
            self.partial_hyps[hyp_ids[i]] + [tok_ids[i].item()]
            for i in range(self.beam_size)
        ]

        #  check new hypos for eos label, if we have some, add to finished
        for hypid in range(self.beam_size):
            if self.outputs[-1][hypid] == self.eos:
                if self.scores[hypid] == neginf(self.scores.dtype):
                    continue
                #  this is finished hypo, adding to finished
                eostail = _HypothesisTail(
                    timestep=len(self.outputs) - 1,
                    hypid=hypid,
                    score=self.all_scores[-1][hypid],
                    tokenid=self.eos,
                )
                self.finished.append(eostail)
                self.n_best_counter += 1

        if self.outputs[-1][0] == self.eos:
            self.eos_top = True
            if self.eos_top_ts is None:
                self.eos_top_ts = len(self.outputs) - 1
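The min-length handling at the top of advance() is simply: until the hypothesis is long enough, push the EOS column of the log-probs to near -inf so it cannot be selected. A two-line sketch with a literal -1e20:

    import torch

    eos, min_length, current_length = 2, 5, 3
    logprobs = torch.randn(4, 10)                 # (beam_size, vocab)
    if current_length < min_length:
        logprobs[:, eos] = -1e20                  # EOS cannot win until min_length is reached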
Code example #17
    def forward(self, xes, hidden, attn_params):
        """
        Compute attention over attn_params given input and hidden states.

        :param xes:         input state. will be combined with applied
                            attention.
        :param hidden:      hidden state from model. will be used to select
                            states to attend to in from the attn_params.
        :param attn_params: tuple of encoder output states and a mask showing
                            which input indices are nonzero.

        :returns: output, attn_weights
                  output is a new state of same size as input state `xes`.
                  attn_weights are the weights given to each state in the
                  encoder outputs.
        """
        if self.attention == 'none':
            # do nothing, no attention
            return xes, None

        if type(hidden) == tuple:
            # for lstms use the "hidden" state not the cell state
            hidden = hidden[0]
        last_hidden = hidden[-1]  # select hidden state from last RNN layer

        enc_out, attn_mask = attn_params
        bsz, seqlen, hszXnumdir = enc_out.size()
        numlayersXnumdir = last_hidden.size(1)

        if self.attention == 'local':
            # local attention weights aren't based on encoder states
            h_merged = torch.cat((xes.squeeze(1), last_hidden), 1)
            attn_weights = F.softmax(self.attn(h_merged), dim=1)

            # adjust state sizes to the fixed window size
            if seqlen > self.max_length:
                offset = seqlen - self.max_length
                enc_out = enc_out.narrow(1, offset, self.max_length)
                seqlen = self.max_length
            if attn_weights.size(1) > seqlen:
                attn_weights = attn_weights.narrow(1, 0, seqlen)
        else:
            hid = last_hidden.unsqueeze(1)
            if self.attention == 'concat':
                # concat hidden state and encoder outputs
                hid = hid.expand(bsz, seqlen, numlayersXnumdir)
                h_merged = torch.cat((enc_out, hid), 2)
                # then do linear combination of them with activation
                active = torch.tanh(self.attn(h_merged))
                attn_w_premask = self.attn_v(active).squeeze(2)
            elif self.attention == 'dot':
                # dot product between hidden and encoder outputs
                if numlayersXnumdir != hszXnumdir:
                    # enc_out has two directions, so double hid
                    hid = torch.cat([hid, hid], 2)
                enc_t = enc_out.transpose(1, 2)
                attn_w_premask = torch.bmm(hid, enc_t).squeeze(1)
            elif self.attention == 'general':
                # before doing dot product, transform hidden state with linear
                # same as dot if linear is identity
                hid = self.attn(hid)
                enc_t = enc_out.transpose(1, 2)
                attn_w_premask = torch.bmm(hid, enc_t).squeeze(1)

            # calculate activation scores, apply mask if needed
            if attn_mask is not None:
                # remove activation from NULL symbols
                attn_w_premask.masked_fill_((~attn_mask),
                                            neginf(attn_w_premask.dtype))
            attn_weights = F.softmax(attn_w_premask, dim=1)

        # apply the attention weights to the encoder states
        attn_applied = torch.bmm(attn_weights.unsqueeze(1), enc_out)
        # concatenate the input and encoder states
        merged = torch.cat((xes.squeeze(1), attn_applied.squeeze(1)), 1)
        # combine them with a linear layer and tanh activation
        output = torch.tanh(self.attn_combine(merged).unsqueeze(1))

        return output, attn_weights
Code example #18
    def forward(self, src_tokens, know_tokens, ck_mask, cs_ids, use_cs_ids):
        # encode the context, pretty basic
        # N: batch size, K: number of knowledge sentences, T: time, D: embedding size, Tk: knowledge length
        # src_tokens: torch.Size([B, T])
        # cs_ids: e.g. tensor([0, 0, 0, 0], device='cuda:0')
        # use_cs_ids: True during training

        self.know_tokens = know_tokens
        self.ck_mask = ck_mask
        self.cs_ids = cs_ids
        self.use_cs_ids = use_cs_ids

        context_encoded, context_mask = self.transformer(src_tokens)

        # make all the knowledge into a 2D matrix to encode
        N, K, Tk = know_tokens.size()
        know_encoded, know_mask = self.transformer(know_tokens.reshape(-1, Tk))

        # compute our sentence embeddings for context and knowledge
        context_use = universal_sentence_embedding(context_encoded,
                                                   context_mask)
        know_use = universal_sentence_embedding(know_encoded, know_mask)

        # remash it back into the shape we need
        know_use = know_use.reshape(N, know_tokens.size(1),
                                    self.embed_dim) / np.sqrt(self.embed_dim)
        context_use /= np.sqrt(self.embed_dim)

        ck_attn = th.bmm(know_use, context_use.unsqueeze(-1)).squeeze(-1)
        # fill with near -inf
        ck_attn.masked_fill_(~ck_mask, neginf(context_encoded.dtype))

        if self.soft_attention:
            # pick the true chosen sentence. remember that TransformerEncoder outputs
            #   (batch, time, embed)
            # but because know_encoded is a flattened, it's really
            #   (N * K, T, D)
            # We need to compute the offsets of the chosen_sentences
            cs_encoded = None
            softmax_cs_weight = th.nn.functional.softmax(
                (ck_attn * self.knowledge_lamda), dim=1)
            #add
            true_ids_weight = th.zeros(softmax_cs_weight.shape,
                                       device=softmax_cs_weight.device,
                                       dtype=softmax_cs_weight.dtype)
            for temp in true_ids_weight:
                temp[0] = 1

            weight_abs = th.abs(softmax_cs_weight - true_ids_weight)
            weight_abs *= weight_abs
            _, T, D = know_encoded.size()
            # finally, concatenate it all
            full_enc = th.cat([(know_encoded.reshape(
                (N * K, -1)) * th.nn.functional.softmax(
                    (ck_attn * self.knowledge_lamda), dim=1).reshape(
                        -1, 1).expand(N * K, T * D)).reshape(
                            (N, K, T, D)).sum(dim=1), context_encoded],
                              dim=1)
            full_mask = th.cat([
                know_mask[th.arange(N, device=cs_ids.device) * K], context_mask
            ],
                               dim=1)

            # also return the knowledge selection mask for the loss
            return full_enc, full_mask, ck_attn

        else:
            if not use_cs_ids:
                # if we're not given the true chosen_sentence (test time), pick our
                # best guess
                # this is the branch where cs_ids gets used
                _, cs_ids = ck_attn.max(1)
                #_, cs_ids = self.second_max(ck_attn, 1)

            # pick the true chosen sentence. remember that TransformerEncoder outputs
            #   (batch, time, embed)
            # but because know_encoded is a flattened, it's really
            #   (N * K, T, D)
            # We need to compute the offsets of the chosen_sentences
            cs_offsets = th.arange(N, device=cs_ids.device) * K + cs_ids
            cs_encoded = know_encoded[cs_offsets]
            # but padding is (N * K, T)
            cs_mask = know_mask[cs_offsets]

            # finally, concatenate it all
            full_enc = th.cat([cs_encoded, context_encoded], dim=1)
            full_mask = th.cat([cs_mask, context_mask], dim=1)

            # also return the knowledge selection mask for the loss
            return full_enc, full_mask, ck_attn
Code example #19
 def test_neginf(self):
     assert neginf(torch.float32) < -1e15
     assert neginf(torch.float16) > -1e15
     assert neginf(torch.float16) < -1e4
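The test pins down the contract of neginf: far below -1e15 for fp32, but for fp16 a value that is still representable (between -1e15 and -1e4). A plausible implementation consistent with the test (the constants are an assumption, not necessarily ParlAI's exact values):

    import torch

    NEAR_INF = 1e20
    NEAR_INF_FP16 = 65504            # largest finite magnitude representable in float16

    def neginf(dtype: torch.dtype) -> float:
        """Return a near-negative-infinity value that is safe for the given dtype."""
        if dtype is torch.float16:
            return -NEAR_INF_FP16
        return -NEAR_INF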
Code example #20
File: modules.py Project: kevinlim/ParlAI
    def forward(self, query, key=None, value=None, mask=None):
        """
        Forward pass.
        """
        # TODO: there are a lot of parameters to document here.

        # Input is [B, query_len, dim]
        # Mask is [B, key_len] (selfattn) or [B, key_len, key_len] (enc attn)
        batch_size, query_len, dim = query.size()
        assert (dim == self.dim
                ), 'Dimensions do not match: {} query vs {} configured'.format(
                    dim, self.dim)
        assert mask is not None, 'Mask is None, please specify a mask'
        n_heads = self.n_heads
        dim_per_head = dim // n_heads
        scale = math.sqrt(dim_per_head)

        def prepare_head(tensor):
            # input is [batch_size, seq_len, n_heads * dim_per_head]
            # output is [batch_size * n_heads, seq_len, dim_per_head]
            bsz, seq_len, _ = tensor.size()
            tensor = tensor.view(batch_size, tensor.size(1), n_heads,
                                 dim_per_head)
            tensor = (tensor.transpose(1, 2).contiguous().view(
                batch_size * n_heads, seq_len, dim_per_head))
            return tensor

        # q, k, v are the transformed values
        if key is None and value is None:
            # self attention
            key = value = query
        elif value is None:
            # key and value are the same, but query differs
            # self attention
            value = key
        _, key_len, dim = key.size()

        q = prepare_head(self.q_lin(query))
        k = prepare_head(self.k_lin(key))
        v = prepare_head(self.v_lin(value))

        dot_prod = q.div_(scale).bmm(k.transpose(1, 2))
        # [B * n_heads, query_len, key_len]
        attn_mask = ((mask == 0).view(batch_size, 1, -1, key_len).repeat(
            1, n_heads, 1, 1).expand(batch_size, n_heads, query_len,
                                     key_len).view(batch_size * n_heads,
                                                   query_len, key_len))
        assert attn_mask.shape == dot_prod.shape
        dot_prod.masked_fill_(attn_mask, neginf(dot_prod.dtype))

        attn_weights = F.softmax(dot_prod, dim=-1).type_as(query)
        attn_weights = self.attn_dropout(attn_weights)  # --attention-dropout

        attentioned = attn_weights.bmm(v)
        attentioned = (attentioned.type_as(query).view(
            batch_size, n_heads, query_len, dim_per_head).transpose(
                1, 2).contiguous().view(batch_size, query_len, dim))

        out = self.out_lin(attentioned)

        return out