Example #1
    def forward(self, context, question):
        # Coattention forward pass (in the style of Dynamic Coattention
        # Networks): fuse a context and a question into a question-aware
        # context representation. `self.__` appears to be an identity
        # logging helper: it tags a tensor and returns it unchanged.

        # Wrap the raw index batches as LongTensors.
        context = LongVar(context)
        question = LongVar(question)

        batch_size, context_size = context.size()   # (batch, context_len)
        _, question_size = question.size()           # (batch, question_len)

        # Token embeddings: (batch, seq_len, emb_dim).
        context = self.__(self.embed(context), 'context_emb')
        question = self.__(self.embed(question), 'question_emb')

        # Encode the context (sequence-first for the RNN), restore batch-first,
        # then append a learned sentinel vector so attention can choose to
        # attend to "nothing": C becomes (batch, context_size + 1, hidden).
        context = context.transpose(1, 0)
        C, _ = self.__(
            self.encode(context, init_hidden(batch_size, self.encode)), 'C')
        C = self.__(C.transpose(1, 0), 'C')
        s = self.__(self.sentinel(batch_size), 's')
        C = self.__(torch.cat([C, s], dim=1), 'C')

        # Same encoding plus sentinel for the question:
        # Q becomes (batch, question_size + 1, hidden).
        question = question.transpose(1, 0)
        Q, _ = self.__(
            self.encode(question, init_hidden(batch_size, self.encode)), 'Q')
        Q = self.__(Q.transpose(1, 0), 'Q')
        s = self.__(self.sentinel(batch_size), 's')
        Q = self.__(torch.cat([Q, s], dim=1), 'Q')

        # Non-linear projection of the question encoding, Q = tanh(W Q' + b):
        # flatten, project, reshape. (The original computed squashedQ and
        # transformedQ but then reshaped the unprojected Q, silently dropping
        # the projection.)
        squashedQ = self.__(Q.view(batch_size * (question_size + 1), -1),
                            'squashedQ')
        transformedQ = self.__(torch.tanh(self.linear(squashedQ)),
                               'transformedQ')
        Q = self.__(transformedQ.view(batch_size, question_size + 1, -1), 'Q')

        # Affinity between every context and question position:
        # (batch, context_size + 1, question_size + 1).
        affinity = self.__(torch.bmm(C, Q.transpose(1, 2)), 'affinity')
        # Two separate normalisations, as in the coattention formulation:
        # context_attn is summed over question positions in the later bmm, so
        # it is normalised along the question axis; question_attn is summed
        # over context positions, so it must be normalised along the context
        # axis. (The original reused a single question-axis softmax for both.)
        context_attn = self.__(
            F.softmax(affinity, dim=-1).transpose(1, 2), 'context_attn')
        question_attn = self.__(F.softmax(affinity, dim=1), 'question_attn')

        # Summarise the context for every question position, then concatenate
        # with Q: context_question is (batch, question_size + 1, 2 * hidden).
        context_question = self.__(torch.bmm(C.transpose(1, 2), question_attn),
                                   'context_question')
        context_question = self.__(
            torch.cat([Q, context_question.transpose(1, 2)], -1),
            'context_question')

        # Carry the question-aware summaries back onto the context positions
        # and fuse them with a final recurrent pass over the context.
        attn_cq = self.__(
            torch.bmm(context_question.transpose(1, 2), context_attn),
            'attn_cq')
        attn_cq = self.__(attn_cq.transpose(1, 2).transpose(0, 1), 'attn_cq')
        hidden = self.__(init_hidden(batch_size, self.attend), 'hidden')
        final_repr, _ = self.__(self.attend(attn_cq, hidden), 'final_repr')
        final_repr = self.__(final_repr.transpose(0, 1), 'final_repr')
        return final_repr[:, :-1]  # exclude the sentinel position
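
The forward pass above leans on several project helpers that are not shown: LongVar, init_hidden, the sentinel module, and the self.__ logging wrapper. Below is a minimal, hypothetical sketch of what such helpers could look like; the names and behaviour are assumptions made for illustration, not the project's actual code.

import torch
import torch.nn as nn

def LongVar(array):
    # Assumed helper: wrap raw indices as a LongTensor (the real version
    # may also move the tensor to the configured device).
    return torch.as_tensor(array, dtype=torch.long)

def init_hidden(batch_size, module):
    # Assumed helper: zero-initialised hidden state shaped for `module`.
    # For an LSTM the real version would return an (h0, c0) tuple.
    num_directions = 2 if module.bidirectional else 1
    return torch.zeros(module.num_layers * num_directions,
                       batch_size, module.hidden_size)

class Sentinel(nn.Module):
    # Assumed helper: a single learned vector, expanded to
    # (batch, 1, hidden) so it can be concatenated to an encoded sequence
    # as an "attend to nothing" slot.
    def __init__(self, hidden_size):
        super().__init__()
        self.vector = nn.Parameter(torch.zeros(1, 1, hidden_size))

    def forward(self, batch_size):
        return self.vector.expand(batch_size, 1, -1)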
Example #2
def batchop(datapoints, VOCAB, config, *args, **kwargs):
    # Build a language-model batch: sample ids, inputs, and next-token targets.
    indices = [d.id for d in datapoints]
    sequence = []

    for d in datapoints:
        # Map every token to its vocabulary index.
        sequence.append([VOCAB[w] for w in d.sequence])

    # Pad to a rectangle, then make the batch sequence-first: (T, B).
    sequence = LongVar(config, pad_seq(sequence))
    sequence = sequence.transpose(1, 0)

    # Shift by one time step: positions 0..T-2 are the inputs,
    # positions 1..T-1 are the targets the model must predict.
    batch = indices, sequence[:-1], sequence[1:]
    return batch
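
To make the shift-by-one behaviour concrete, here is a small usage sketch, assuming the batchop above is in scope. Sample, pad_seq, the two-argument LongVar, and the toy VOCAB are hypothetical stand-ins for the project's own datapoint class, padding helper, and vocabulary.

from collections import namedtuple

import torch

Sample = namedtuple('Sample', ['id', 'sequence'])

def pad_seq(seqs, pad=0):
    # Stand-in for the project's padding helper: right-pad with `pad`
    # so every sequence has the same length.
    max_len = max(len(s) for s in seqs)
    return [s + [pad] * (max_len - len(s)) for s in seqs]

def LongVar(config, array):
    # Stand-in matching this example's two-argument helper; `config`
    # presumably selects the device in the real code.
    return torch.as_tensor(array, dtype=torch.long)

VOCAB = {'<pad>': 0, 'the': 1, 'cat': 2, 'sat': 3}
data = [Sample(0, ['the', 'cat', 'sat']), Sample(1, ['the', 'cat'])]

ids, inputs, targets = batchop(data, VOCAB, config=None)
# inputs  is (T-1, B): the token fed to the model at each time step
# targets is (T-1, B): the same sequence shifted one step, i.e. the labels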