Example #1
    def forward(self, src, src_len, tgt, tgt_len, split='train'):
        
        n_batch = src.size(0)

        # Create padding masks for the source and target sequences
        src_pad_mask = get_pad_mask(src, self.pad).to(self.device) #(B, S)
        tgt_pad_mask = get_pad_mask(tgt, self.pad).to(self.device) #(B, T)
        pad_mask = (src_pad_mask, tgt_pad_mask)

        attn_mask = generate_square_subsequent_mask(self.max_tgt_len)

        src = self.embed(src.to(self.device))
        tgt = self.embed(tgt.to(self.device))

        # (batch, seq, emb) -> (seq, batch, emb): the encoder/decoder expect sequence-first inputs
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)

        enc_output, attn = self.encoder(src, pad_mask=pad_mask)
        output, attn = self.decoder(tgt, enc_output, pad_mask=pad_mask, cross_mask=attn_mask)

        # back to (batch, seq, emb), rescaled by sqrt(self.scale)
        output = output.permute(1, 0, 2) * (self.scale ** 0.5)

        logits = F.log_softmax(self.final(output), dim=-1)
        generations = torch.argmax(logits, dim=-1)
        
        return logits, generations
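
All of the snippets in this listing call a generate_square_subsequent_mask helper. As a reference point only (the repositories above may ship their own variant), here is a minimal sketch equivalent to the helper from the PyTorch transformer tutorial, returning an additive mask with 0 on and below the diagonal and -inf above it:

import torch

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # 0.0 where attention is allowed (j <= i), -inf where position i
    # would attend to a future position j > i
    mask = torch.triu(torch.ones(sz, sz), diagonal=1)
    return mask.masked_fill(mask == 1, float('-inf'))
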
Example #2
def generate_lm_batch(batch, pad_idx, in_chat=False):
    if not in_chat:
        # fixed-length sequences: (batch, seq) -> (seq, batch)
        batch = torch.tensor(batch).T
    else:
        # variable-length sequences: pad to (max_len, batch)
        batch = pad_sequence(list(map(torch.tensor, batch)),
                             padding_value=pad_idx)
    # shift by one position: predict token t+1 from tokens up to t
    x = batch[:-1]
    y = batch[1:]
    # causal mask over the input length and a (batch, seq) padding mask
    x_mask = utils.generate_square_subsequent_mask(x.shape[0])
    x_pad_mask = (x == pad_idx).T

    return LMFeature(x=x, y=y, x_mask=x_mask, x_pad_mask=x_pad_mask)
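
A minimal usage sketch, assuming hypothetical token ids and that utils and LMFeature come from the surrounding project:

# hypothetical token-id sequences; 0 is assumed to be the padding index
batch = [[2, 11, 12, 13, 3], [2, 21, 22, 3]]
feat = generate_lm_batch(batch, pad_idx=0, in_chat=True)
# feat.x          : (max_len - 1, batch)        shifted inputs
# feat.y          : (max_len - 1, batch)        next-token targets
# feat.x_mask     : (max_len - 1, max_len - 1)  causal mask
# feat.x_pad_mask : (batch, max_len - 1)        True at padded positions
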
Example #3
def generate_lm_batch(batch, vocab, is_mlm, in_chat=False):
    pad_idx = vocab.stoi(utils.PAD)
    mask_idx = vocab.stoi(utils.MASK)

    if not in_chat:
        x_mlm = None
        x_mlm_pad_mask = None
        batch = torch.tensor(batch).T
    elif not is_mlm:
        x_mlm = None
        x_mlm_pad_mask = None
        batch = pad_sequence(list(map(torch.tensor, batch)), padding_value=pad_idx)
    else:
        x_mlm = []
        for v in batch:
            l = len(v)
            if l == 3:
                # bos + w + eos
                x_mlm.append(v)
                continue
            # default: mask a span of one token
            mask = 1
            if l > 5:
                # for longer sequences the span length is 2 with
                # probability 99/100 and 0 with probability 1/100
                mask = min(2, random.randint(0, 99))
                if mask == 1:
                    mask = 2
            # replace `mask` consecutive tokens with a single MASK token;
            # bos/eos are preserved because start is drawn from [1, l - mask - 1]
            start = random.randint(1, l-mask-1)
            x_mlm.append(v[0:start] + [mask_idx] + v[start+mask:])

        batch = pad_sequence(list(map(torch.tensor, batch)), padding_value=pad_idx)
        x_mlm = pad_sequence(list(map(torch.tensor, x_mlm)), padding_value=pad_idx)
        x_mlm_pad_mask = (x_mlm == pad_idx).T

    x = batch[:-1]
    y = batch[1:]
    x_mask = utils.generate_square_subsequent_mask(x.shape[0])
    x_pad_mask = (x == pad_idx).T

    return LMFeature(x=x, y=y, x_mlm=x_mlm, 
            x_mask=x_mask, x_pad_mask=x_pad_mask,
            x_mlm_pad_mask=x_mlm_pad_mask)
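
A small hand-worked illustration of the span replacement above, with hypothetical token ids:

bos, eos, MASK = 2, 3, 4            # hypothetical special-token ids
v = [bos, 11, 12, 13, 14, eos]      # l = 6 > 5, so mask is almost always 2
start, mask = 2, 2
assert v[0:start] + [MASK] + v[start + mask:] == [bos, 11, MASK, 14, eos]
# tokens 12 and 13 collapse into a single MASK; bos and eos can never be
# replaced because start is drawn from [1, l - mask - 1]
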
Example #4
def generate_batch(batch, pad_idx):
    context, segs, personas_no_tag, tags, resp, persona, lm = zip(*batch)

    fn = lambda x: list(map(torch.tensor, x))
    context_pad = pad_sequence(fn(context), padding_value=pad_idx)
    segs_pad = pad_sequence(fn(segs), padding_value=pad_idx)
    tags = itertools.chain(*tags)
    tags_pad = pad_sequence(fn(tags), padding_value=pad_idx)
    tags_pad = tags_pad.view(-1, 2, int(tags_pad.shape[1] / 2)).transpose(1, 0)
    resp_pad = pad_sequence(fn(resp), padding_value=pad_idx)
    src_pad_mask = (context_pad == pad_idx).T
    tgt_pad_mask = (resp_pad == pad_idx).T
    tgt_mask = utils.generate_square_subsequent_mask(resp_pad.shape[0])
    persona_pad = pad_sequence(fn(persona), padding_value=pad_idx)
    persona_pad_mask = (persona_pad == pad_idx).T
    # batch_size X n_persona X 2
    tmp = list(
        map(lambda x: pad_sequence(fn(x), padding_value=pad_idx),
            personas_no_tag))
    # n_persona X batch_size X 2 --> 2 X n_persona X batch_size
    personas_no_tag_pad = pad_sequence(tmp,
                                       padding_value=pad_idx).permute(2, 0, 1)

    lm = generate_lm_batch(lm, pad_idx, in_chat=True)

    return ChatFeature(
        context=context_pad,
        segs=segs_pad,
        personas_no_tag=personas_no_tag_pad,
        tags=tags_pad,
        resp=resp_pad,
        persona=persona_pad,
        context_pad_mask=src_pad_mask,
        resp_mask=tgt_mask,
        resp_pad_mask=tgt_pad_mask,
        persona_pad_mask=persona_pad_mask,
        lm=lm,
    )
Example #5
def generate_batch(batch, vocab, persona_vocab, is_mlm):
    pad_idx = vocab.stoi(utils.PAD)
    persona_pad_idx = pad_idx
    if persona_vocab is not None:
        persona_pad_idx = persona_vocab.stoi(utils.PAD)
    context, segs, personas_no_tag, tags, resp, persona, lm = zip(*batch)

    # post: the tokens following the last SEP in the context
    post = []
    cls_idx = vocab.stoi(utils.CLS)
    sep_idx = vocab.stoi(utils.SEP)
    for v in context:
        post_start = list(reversed(v)).index(sep_idx)
        if _USE_BERT_FEATURE:
            # cls_idx for bert-like sentence rep, xlnet is last token
            post.append([cls_idx] + v[-post_start:])
        else:
            post.append(v[-post_start:-1])

    fn = lambda x: list(map(torch.tensor, x)) 
    context_pad = pad_sequence(fn(context), padding_value=pad_idx)
    post_pad = pad_sequence(fn(post), padding_value=pad_idx)
    segs_pad = pad_sequence(fn(segs), padding_value=pad_idx)
    tags = itertools.chain(*tags)
    tags_pad = pad_sequence(fn(tags), padding_value=persona_pad_idx)
    tags_pad = tags_pad.view(-1, 2, int(tags_pad.shape[1]/2)).transpose(1, 0)
    resp_pad = pad_sequence(fn(resp), padding_value=pad_idx)
    persona_pad = pad_sequence(fn(persona), padding_value=persona_pad_idx)

    context_pad_mask = (context_pad == pad_idx).T
    post_pad_mask = (post_pad == pad_idx).T
    tags_pad_mask = (tags_pad == persona_pad_idx).T
    resp_mask = utils.generate_square_subsequent_mask(resp_pad.shape[0])
    resp_pad_mask = (resp_pad == pad_idx).T
    persona_pad_mask = (persona_pad == persona_pad_idx).T
    # batch_size X n_persona X 2 
    tmp = list(map(lambda x: pad_sequence(fn(x), padding_value=persona_pad_idx),
            personas_no_tag))
    # n_persona X batch_size X 2 --> 2 X n_persona X batch_size
    personas_no_tag_pad = pad_sequence(tmp, padding_value=persona_pad_idx).permute(2, 0, 1) 
    personas_no_tag_pad_mask = (personas_no_tag_pad == persona_pad_idx).T

    lm = generate_lm_batch(lm, vocab, is_mlm, in_chat=True)

    return ChatFeature(
            context=context_pad,
            post=post_pad,
            segs=segs_pad,
            personas_no_tag=personas_no_tag_pad,
            tags=tags_pad,

            resp=resp_pad,
            persona=persona_pad,

            context_pad_mask=context_pad_mask,
            post_pad_mask=post_pad_mask,
            resp_mask=resp_mask,
            resp_pad_mask=resp_pad_mask,
            persona_pad_mask=persona_pad_mask,

            personas_no_tag_pad_mask=personas_no_tag_pad_mask,
            tags_pad_mask=tags_pad_mask,

            lm=lm,
    )
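
Both collate functions (Examples #4 and #5) rely on pad_sequence defaulting to a (max_len, batch) layout, which is why the padding masks are transposed with .T into the (batch, seq) shape expected for key_padding_mask. A quick shape sketch with hypothetical token ids, assuming 0 is the padding index:

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([2, 11, 12, 3]), torch.tensor([2, 21, 3])]
padded = pad_sequence(seqs, padding_value=0)   # (max_len, batch) = (4, 2)
pad_mask = (padded == 0).T                     # (batch, max_len), True at pads
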
Example #6
    def forward(
        self,
        input: torch.Tensor,
        input_mask: torch.Tensor,
        node_hidden: torch.Tensor,
        prev_action_hidden: torch.Tensor,
        prev_action_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        input: (batch, input_seq_len, hidden_dim)
        input_mask: (batch, input_seq_len)
        node_hidden: (batch, num_node, hidden_dim)
        prev_action_hidden: (batch, prev_action_len, hidden_dim)
        prev_action_mask: (batch, prev_action_len)

        output: (batch, input_seq_len, hidden_dim)
        """
        # calculate attention mask for decoding
        # this is the mask that prevents MultiheadAttention
        # from attending to future values
        input_seq_len = input.size(1)
        attn_mask = generate_square_subsequent_mask(input_seq_len).to(
            input.device)
        # (input_seq_len, input_seq_len)

        # add the positional encodings
        pos_encoded_input = self.pos_encoder(input)

        # self attention layer
        input_residual = pos_encoded_input
        # MultiheadAttention expects batch dim to be 1 for q, k, v
        # but 0 for key_padding_mask, so we need to transpose
        transposed_pos_encoded_input = pos_encoded_input.transpose(0, 1)
        input_attn, _ = self.self_attn(
            transposed_pos_encoded_input,
            transposed_pos_encoded_input,
            transposed_pos_encoded_input,
            key_padding_mask=input_mask == 0,
            attn_mask=attn_mask,
        )
        input_attn = input_attn.transpose(0, 1)
        input_attn *= input_mask.unsqueeze(-1)
        input_attn += input_residual
        # (batch, input_seq_len, hidden_dim)

        # calculate self attention for the nodes and previous action
        # strictly speaking, we should calculate attention masks for these
        # based on input_mask, but due to this bug:
        # https://github.com/pytorch/pytorch/issues/41508
        # it returns nan's if we apply attention masks. So let's just skip it.
        # It's OK, b/c we apply input_mask when we combine these.
        # apply layer norm to the input self attention output to calculate the query
        query = self.self_attn_layer_norm(input_attn).transpose(0, 1)
        # (input_seq_len, batch, hidden_dim)

        # self attention for the nodes
        # no key_padding_mask, since we use all the nodes
        # (batch * num_heads, input_seq_len, num_node)
        transposed_node_hidden = node_hidden.transpose(0, 1)
        node_attn, _ = self.node_attn(query, transposed_node_hidden,
                                      transposed_node_hidden)
        node_attn = node_attn.transpose(0, 1)
        # (batch, input_seq_len, hidden_dim)

        # self attention for the previous action
        # key_padding_mask is from prev_action_mask
        # (batch * num_heads, input_seq_len, prev_action_len)
        transposed_prev_action_hidden = prev_action_hidden.transpose(0, 1)
        prev_action_attn, _ = self.prev_action_attn(
            query,
            transposed_prev_action_hidden,
            transposed_prev_action_hidden,
            key_padding_mask=prev_action_mask == 0,
        )
        prev_action_attn = prev_action_attn.transpose(0, 1)
        # (batch, input_seq_len, hidden_dim)

        # combine self attention for the previous action and nodes with
        # input self attention
        combined_self_attn = self.combine_node_prev_action(
            torch.cat([prev_action_attn, node_attn], dim=-1))
        combined_self_attn *= input_mask.unsqueeze(-1)
        combined_self_attn += input_attn
        # (batch, input_seq_len, hidden_dim)

        # linear layer
        output = self.linear_layer_norm(combined_self_attn)
        output = self.linear_layers(output)
        output += combined_self_attn
        # (batch, input_seq_len, hidden_dim)

        return output
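
The transposes and masks in this block follow the default (seq, batch, hidden) conventions of torch.nn.MultiheadAttention. Here is a self-contained sketch of the same attn_mask / key_padding_mask interplay, with hypothetical dimensions rather than this module's actual configuration:

import torch
import torch.nn as nn

batch, seq_len, hidden_dim, num_heads = 2, 5, 16, 4
x = torch.randn(seq_len, batch, hidden_dim)               # sequence-first layout
pad_mask = torch.zeros(batch, seq_len, dtype=torch.bool)  # True = ignore that key
causal = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

attn = nn.MultiheadAttention(hidden_dim, num_heads)
out, _ = attn(x, x, x, key_padding_mask=pad_mask, attn_mask=causal)
# out: (seq_len, batch, hidden_dim); each position attends only to itself
# and earlier positions
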
Example #7
def test_generate_subsequent_mask(size):
    mask = generate_square_subsequent_mask(size)
    # assert that the sum of tril and triu is the original mask
    assert mask.equal(torch.tril(mask) + torch.triu(mask, diagonal=1))
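
Assuming the usual 0 / -inf convention for the mask (some projects return a boolean mask instead, in which case this does not apply), a stricter, hypothetical test could verify that exactly the future positions are blocked:

def test_mask_blocks_future_positions(size=4):
    mask = generate_square_subsequent_mask(size)
    future = torch.triu(torch.ones(size, size), diagonal=1).bool()
    # entries strictly above the diagonal must be -inf, everything else 0
    assert torch.isneginf(mask[future]).all()
    assert (mask[~future] == 0).all()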