Code example #1
File: helpers.py Project: bkj/ParlAI
    def forward(self, token_ids, segment_ids, attention_mask):
        output_bert, output_pooler = self.bert_model(token_ids, segment_ids, attention_mask)
        
        # output_bert is a list of the hidden states of all 12 layers (for BERT base).
        layer_of_interest = output_bert[self.layer_pulled]
        dtype = next(self.parameters()).dtype
        if self.add_transformer_layer:
            # Follow up by yet another transformer layer
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = (
                (~extended_attention_mask).to(dtype) * neginf(dtype)
            )
            embedding_layer = self.additional_transformer_layer(
                layer_of_interest, extended_attention_mask)
        else:
            embedding_layer = layer_of_interest
        
        if self.aggregation == "mean":
            # take the mean of all outputs except CLS;
            # masked elements are excluded from the average via the attention mask below
            outputs_of_interest = embedding_layer[:, 1:, :]
            mask = attention_mask[:, 1:].type_as(embedding_layer).unsqueeze(2)
            summed_embeddings = torch.sum(outputs_of_interest * mask, dim=1)
            nb_elems = torch.sum(attention_mask[:, 1:].type(dtype), dim=1).unsqueeze(1)
            embeddings = summed_embeddings / nb_elems
            
        elif self.aggregation == "max":
            # take the max of all outputs except CLS; masked positions are pushed to ~-inf first
            outputs_of_interest = embedding_layer[:, 1:, :]
            mask = (~attention_mask[:, 1:]).type(dtype).unsqueeze(2) * neginf(dtype)
            embeddings, _ = torch.max(outputs_of_interest + mask, dim=1)
            
        else:
            # simplest option: use the output of the CLS token as the embedding
            embeddings = embedding_layer[:, 0, :]
            
# >> BKJ
        if self.output_linear is not None:
            result = self.output_linear(embeddings)
        else:
            result = embeddings
# <<

        # Sort of a hack to make it work with distributed training: this way the
        # pooler layer is used for grad computation, even though it does not change
        # anything... in practice, it just adds a very cheap (768*768) x (768*batchsize) matmul
# >> BKJ
        # result += 0 * torch.sum(output_pooler)
# <<
        
        return result
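
All of these snippets call a neginf helper that is not shown in the excerpts. A minimal sketch consistent with how it is used (a large but finite negative value that is still representable in the given dtype, so float16 masking does not overflow) might look like this; the actual ParlAI helper may differ:

import torch

def neginf(dtype: torch.dtype) -> float:
    # Sketch (assumption): return a huge finite negative number that fits in the
    # given dtype, so masked logits get ~0 probability after softmax without
    # producing inf/nan, even in float16.
    if dtype is torch.float16:
        return -65504.0   # most negative normal float16 value
    return -1e20
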
Code example #2
File: modules.py Project: zhengwu34/ParlAI
    def forward(self, query_embs, in_mem_embs, out_mem_embs, pad_mask):
        """
        Compute MemNN Hop step.

        :param query_embs:
            (bsz x esz) embedding of queries

        :param in_mem_embs:
            (bsz x esz x num_mems) embeddings of the memories used to compute
            the attention weights

        :param out_mem_embs:
            (bsz x num_mems x esz) embeddings of the memories used to build the
            output state

        :param pad_mask:
            (bsz x num_mems) optional mask indicating which memory slots correspond
            to padding

        :returns:
            (bsz x esz) output state
        """
        # rotate query embeddings
        attn = torch.bmm(query_embs.unsqueeze(1), in_mem_embs).squeeze(1)
        if pad_mask is not None:
            attn[pad_mask] = neginf(attn.dtype)
        probs = self.softmax(attn)
        memory_output = torch.bmm(probs.unsqueeze(1), out_mem_embs).squeeze(1)
        output = memory_output + self.rotate(query_embs)
        return output
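
To make the masking step concrete, here is a small self-contained illustration of the same pattern with made-up shapes and a literal -1e20 standing in for neginf: padded memory slots are pushed to a huge negative score, so they receive roughly zero attention probability after the softmax.

import torch
import torch.nn.functional as F

bsz, num_mems, esz = 2, 4, 8
query_embs = torch.randn(bsz, esz)
in_mem_embs = torch.randn(bsz, esz, num_mems)            # memories arranged for bmm
pad_mask = torch.tensor([[False, False, True, True],     # last two slots are padding
                         [False, False, False, False]])

attn = torch.bmm(query_embs.unsqueeze(1), in_mem_embs).squeeze(1)  # (bsz, num_mems)
attn[pad_mask] = -1e20            # stand-in for neginf(attn.dtype)
probs = F.softmax(attn, dim=1)    # padded memories get ~0 probability
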
Code example #3
    def forward(self, query, key=None, value=None, mask=None):
        """Forward pass."""
        # TODO: there are a lot of parameters to document here.

        # Input is [B, query_len, dim]
        # Mask is [B, key_len] (selfattn) or [B, key_len, key_len] (enc attn)
        batch_size, query_len, dim = query.size()
        assert (dim == self.dim
                ), 'Dimensions do not match: {} query vs {} configured'.format(
                    dim, self.dim)
        assert mask is not None, 'Mask is None, please specify a mask'
        n_heads = self.n_heads
        dim_per_head = dim // n_heads
        scale = math.sqrt(dim_per_head)

        def prepare_head(tensor):
            # input is [batch_size, seq_len, n_heads * dim_per_head]
            # output is [batch_size * n_heads, seq_len, dim_per_head]
            bsz, seq_len, _ = tensor.size()
            tensor = tensor.view(batch_size, tensor.size(1), n_heads,
                                 dim_per_head)
            tensor = (tensor.transpose(1, 2).contiguous().view(
                batch_size * n_heads, seq_len, dim_per_head))
            return tensor

        # q, k, v are the transformed values
        if key is None and value is None:
            # self attention
            key = value = query
        elif value is None:
            # key and value are the same, but query differs
            # (e.g. encoder attention, where key and value both come from the encoder)
            value = key
        _, key_len, dim = key.size()

        q = prepare_head(self.q_lin(query))
        k = prepare_head(self.k_lin(key))
        v = prepare_head(self.v_lin(value))

        dot_prod = q.div_(scale).bmm(k.transpose(1, 2))
        # [B * n_heads, query_len, key_len]
        attn_mask = ((mask == 0).view(batch_size, 1, -1, key_len).repeat(
            1, n_heads, 1, 1).expand(batch_size, n_heads, query_len,
                                     key_len).view(batch_size * n_heads,
                                                   query_len, key_len))
        assert attn_mask.shape == dot_prod.shape
        dot_prod.masked_fill_(attn_mask, neginf(dot_prod.dtype))

        attn_weights = F.softmax(dot_prod, dim=-1).type_as(query)
        attn_weights = self.attn_dropout(attn_weights)  # --attention-dropout

        attentioned = attn_weights.bmm(v)
        attentioned = (attentioned.type_as(query).view(
            batch_size, n_heads, query_len, dim_per_head).transpose(
                1, 2).contiguous().view(batch_size, query_len, dim))

        out = self.out_lin(attentioned)

        return out
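
The subtle part above is the mask handling: a [B, key_len] padding mask (nonzero for real tokens, 0 for padding) is broadcast into one boolean mask per head and query position, and the masked positions are filled with neginf before the softmax. A standalone illustration of just that step, with hypothetical sizes and -1e20 standing in for neginf:

import torch
import torch.nn.functional as F

B, n_heads, query_len, key_len = 2, 4, 3, 5
dot_prod = torch.randn(B * n_heads, query_len, key_len)
mask = torch.ones(B, key_len, dtype=torch.long)
mask[:, -2:] = 0                                 # pretend the last two keys are padding

attn_mask = ((mask == 0)
             .view(B, 1, -1, key_len)
             .repeat(1, n_heads, 1, 1)
             .expand(B, n_heads, query_len, key_len)
             .view(B * n_heads, query_len, key_len))
dot_prod.masked_fill_(attn_mask, -1e20)          # stand-in for neginf(dot_prod.dtype)
attn_weights = F.softmax(dot_prod, dim=-1)       # padded keys receive ~0 attention
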
Code example #4
File: modules.py Project: Arran5353/dialog-probing
    def forward(self, src_tokens, know_tokens, ck_mask, cs_ids, use_cs_ids):
        # encode the context, pretty basic
        context_encoded, context_mask = self.transformer(src_tokens)

        # make all the knowledge into a 2D matrix to encode
        N, K, Tk = know_tokens.size()
        know_flat = know_tokens.reshape(-1, Tk)
        know_encoded, know_mask = self.transformer(know_flat)

        # compute our sentence embeddings for context and knowledge
        context_use = universal_sentence_embedding(context_encoded,
                                                   context_mask)
        know_use = universal_sentence_embedding(know_encoded, know_mask)

        # remash it back into the shape we need
        know_use = know_use.reshape(N, know_tokens.size(1), self.embed_dim)
        context_use /= np.sqrt(self.embed_dim)
        know_use /= np.sqrt(self.embed_dim)

        ck_attn = th.bmm(know_use, context_use.unsqueeze(-1)).squeeze(-1)
        # fill with near -inf
        ck_attn.masked_fill_(~ck_mask, neginf(context_encoded.dtype))

        if not use_cs_ids:
            # if we're not given the true chosen_sentence (test time), pick our
            # best guess
            _, cs_ids = ck_attn.max(1)

        # pick the true chosen sentence. remember that TransformerEncoder outputs
        #   (batch, time, embed)
        # but because know_encoded is flattened, it's really
        #   (N * K, T, D)
        # We need to compute the offsets of the chosen_sentences
        cs_offsets = th.arange(N, device=cs_ids.device) * K + cs_ids
        cs_encoded = know_encoded[cs_offsets]
        # but padding is (N * K, T)
        cs_mask = know_mask[cs_offsets]

        # finally, concatenate it all
        full_enc = th.cat([cs_encoded, context_encoded], dim=1)
        full_mask = th.cat([cs_mask, context_mask], dim=1)

        # also return the knowledge selection mask for the loss
        return full_enc, full_mask, ck_attn
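
The universal_sentence_embedding helper is not included in this excerpt. A plausible sketch, assuming it does masked sum pooling over the time dimension normalized by (the square root of) the sentence length, could look like the following; the actual implementation may differ in details:

import torch as th

def universal_sentence_embedding(sentences, mask, sqrt=True):
    # Sketch (assumption): pool a (batch, time, embed) tensor into (batch, embed)
    # by summing the non-padded timesteps and dividing by the sentence length
    # (or its square root).
    # (batch, embed, time) x (batch, time, 1) -> (batch, embed, 1)
    sentence_sums = th.bmm(
        sentences.permute(0, 2, 1), mask.float().unsqueeze(-1)
    ).squeeze(-1)
    divisor = mask.sum(dim=1).view(-1, 1).float()
    if sqrt:
        divisor = divisor.sqrt()
    return sentence_sums / divisor
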
Code example #5
    def advance(self, logprobs):
        """Advance the beam one step."""
        current_length = len(self.all_scores) - 1
        if current_length < self.min_length:
            # penalize all eos probs to make it decode longer
            for hyp_id in range(logprobs.size(0)):
                logprobs[hyp_id][self.eos] = neginf(logprobs.dtype)

        if self.scores is None:
            self.scores = torch.zeros(1).type_as(logprobs).to(logprobs.device)

        hyp_ids, tok_ids, self.scores = self.select_paths(
            logprobs, self.scores)
        self.all_scores.append(self.scores)

        self.outputs.append(tok_ids)
        self.bookkeep.append(hyp_ids)
        self.partial_hyps = [
            self.partial_hyps[hyp_ids[i]] + [tok_ids[i].item()]
            for i in range(self.beam_size)
        ]

        # check new hypos for the eos token; if any are found, add them to finished
        for hypid in range(self.beam_size):
            if self.outputs[-1][hypid] == self.eos:
                # this hypothesis is finished; add it to the finished list
                eostail = _HypothesisTail(
                    timestep=len(self.outputs) - 1,
                    hypid=hypid,
                    score=self.scores[hypid],
                    tokenid=self.eos,
                )
                self.finished.append(eostail)
                self.n_best_counter += 1

        if self.outputs[-1][0] == self.eos:
            self.eos_top = True
            if self.eos_top_ts is None:
                self.eos_top_ts = len(self.outputs) - 1
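
A toy illustration of the min-length trick at the top of advance(): until the beam has produced min_length tokens, the eos column of the log-probabilities is forced to a huge negative value so eos can never win the top-k selection (sizes are made up, and -1e20 stands in for neginf):

import torch

beam_size, vocab_size, eos = 3, 10, 2
logprobs = torch.randn(beam_size, vocab_size).log_softmax(dim=-1)
logprobs[:, eos] = -1e20                       # stand-in for neginf(logprobs.dtype)
assert (logprobs.argmax(dim=-1) != eos).all()  # eos can no longer be selected
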
Code example #6
File: torch_generator_agent.py Project: zwcdp/KBRD
    def advance(self, softmax_probs):
        """Advance the beam one step."""
        voc_size = softmax_probs.size(-1)
        current_length = len(self.all_scores) - 1
        if current_length < self.min_length:
            # penalize all eos probs to make it decode longer
            for hyp_id in range(softmax_probs.size(0)):
                softmax_probs[hyp_id][self.eos] = neginf(softmax_probs.dtype)
        if len(self.bookkeep) == 0:
            # on the first step, only the first hypo is taken into account, since
            # all hypos are identical initially
            beam_scores = softmax_probs[0]
        else:
            # we need to sum up hypo scores and curr softmax scores before topk
            # [beam_size, voc_size]
            beam_scores = (softmax_probs +
                           self.scores.unsqueeze(1).expand_as(softmax_probs))
            for i in range(self.outputs[-1].size(0)):
                if self.block_ngram > 0:
                    current_hypo = self.partial_hyps[i][1:]
                    current_ngrams = []
                    for ng in range(self.block_ngram):
                        ngrams = Beam.find_ngrams(current_hypo, ng)
                        if len(ngrams) > 0:
                            current_ngrams.extend(ngrams)
                    counted_ngrams = Counter(current_ngrams)
                    if any(v > 1 for k, v in counted_ngrams.items()):
                        # block this hypothesis hard
                        beam_scores[i] = neginf(softmax_probs.dtype)

                # if the previous output token of this hypo was eos,
                # penalize all word probs so this hypo is never extended further
                if self.outputs[-1][i] == self.eos:
                    # beam_scores[i] is voc_size array for i-th hypo
                    beam_scores[i] = neginf(softmax_probs.dtype)

        flatten_beam_scores = beam_scores.view(-1)  # [beam_size * voc_size]
        with torch.no_grad():
            best_scores, best_idxs = torch.topk(flatten_beam_scores,
                                                self.beam_size,
                                                dim=-1)

        self.scores = best_scores
        self.all_scores.append(self.scores)
        # recover the backtracking hypothesis id via integer division by voc_size
        hyp_ids = best_idxs // voc_size
        # and the actual word id from the remainder of the same division
        tok_ids = best_idxs % voc_size

        self.outputs.append(tok_ids)
        self.bookkeep.append(hyp_ids)
        self.partial_hyps = [
            self.partial_hyps[hyp_ids[i]] + [tok_ids[i].item()]
            for i in range(self.beam_size)
        ]

        # check new hypos for the eos token; if any are found, add them to finished
        for hypid in range(self.beam_size):
            if self.outputs[-1][hypid] == self.eos:
                # this hypothesis is finished; add it to the finished list
                eostail = self.HypothesisTail(timestep=len(self.outputs) - 1,
                                              hypid=hypid,
                                              score=self.scores[hypid],
                                              tokenid=self.eos)
                self.finished.append(eostail)
                self.n_best_counter += 1

        if self.outputs[-1][0] == self.eos:
            self.eos_top = True
            if self.eos_top_ts is None:
                self.eos_top_ts = len(self.outputs) - 1
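
The backtracking step above recovers (hypothesis, token) pairs from a top-k taken over the flattened (beam_size * voc_size) score matrix. A small self-contained illustration of that decomposition (note the floor division: plain / became true division in newer PyTorch releases, which is why the snippet above uses //):

import torch

beam_size, voc_size = 2, 5
beam_scores = torch.randn(beam_size, voc_size)
best_scores, best_idxs = torch.topk(beam_scores.view(-1), beam_size, dim=-1)
hyp_ids = best_idxs // voc_size   # which hypothesis each winner extends
tok_ids = best_idxs % voc_size    # which vocabulary token it adds
assert torch.equal(beam_scores[hyp_ids, tok_ids], best_scores)
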