Example #1
    def forward(self, input_):
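        # NOTE: self.__(tensor, name) is assumed to be a project-specific debug helper that
        # logs the tensor (e.g. its shape) under the given name and returns it unchanged.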
        idxs, inputs, targets = input_
        story, question = inputs
        story = self.__( story, 'story')
        question = self.__(question, 'question')

        batch_size, story_size  = story.size()
        batch_size, question_size = question.size()
        
        story  = self.__( self.embed(story),  'story_emb')
        question = self.__( self.embed(question), 'question_emb')

        story  = story.transpose(1,0)
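        # The encoder consumes (seq_len, batch, emb) input, hence the transpose above;
        # init_hidden supplies the initial hidden state for the given batch size.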
        story, _  = self.__(  self.encode_story(
            story,
            init_hidden(batch_size, self.encode_story)), 'C'
        )
        
        question  = question.transpose(1,0)
        question, _ = self.__(  self.encode_question(
            question,
            init_hidden(batch_size, self.encode_question)), 'Q'
        )

        story = self.__( story.transpose(0,1), 'story')
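        # Broadcast the final question state across story timesteps and merge it with the
        # story through two additive linear projections.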
        question_state = self.__( question[-1], 'question')
        question_state = self.__( question_state.unsqueeze(1).expand_as(story), 'question')
        merged = self.__(   self.attn_story(story)
                          + self.attn_question(question_state),  'merged' )
        
        return merged.transpose(0,1)
Example #2
    def forward(self, input_):
        idxs, inputs, targets = input_
        story, question = inputs
        story = self.__( story, 'story')
        question = self.__(question, 'question')

        batch_size, story_size  = story.size()
        batch_size, question_size = question.size()
        
        story  = self.__( self.embed(story),  'story_emb')
        question = self.__( self.embed(question), 'question_emb')

        story  = story.transpose(1,0)
        story, _  = self.__(  self.encode_story(
            story,
            init_hidden(batch_size, self.encode_story)), 'C'
        )
        
        question  = question.transpose(1,0)
        question, _ = self.__(  self.encode_question(
            question,
            init_hidden(batch_size, self.encode_question)), 'Q'
        )

        # Recurrent reasoning state: control (c), read (r) and memory (m) vectors,
        # with control and memory initialised to zeros and refined over the loop below.
        c, m, r = [], [], []
        c.append(Var(np.zeros((batch_size, 2 * self.hidden_size))))
        m.append(Var(np.zeros((batch_size, 2 * self.hidden_size))))
        qi = self.dropout(self.produce_qi(torch.cat([question[-1], m[-1]], dim=-1)))

        qattns, sattns, mattns = [], [], []
        
        for i in range(config.HPCONFIG.reasoning_steps):

            # Update the control state from the question, conditioned on the previous memory.
            ci, qattn = self.control(c[-1], qi, question, m[-1])
            ci = self.dropout(ci)

            # Read from the encoded story, conditioned on the new control state.
            ri, sattn = self.read(m[-1], ci, story)
            ri = self.dropout(ri)

            # Write the read vector into memory.
            mi, mattn = self.write( m[-1], ri, ci, c, m )
            mi = self.dropout(mi)

            qi = self.dropout(self.produce_qi(torch.cat([qi, m[-1]], dim=-1)))
            
            c.append(ci)
            r.append(ri)
            m.append(mi)

            qattns.append(qattn)
            sattns.append(sattn)
            mattns.append(mattn)
            
        #projected_output = self.__( F.relu(self.project(torch.cat([qi, mi], dim=-1))), 'projected_output')
        return (self.__( F.log_softmax(self.answer(mi), dim=-1), 'return val'),
                (
                    torch.stack(sattns),
                    torch.stack(qattns),
                    torch.stack(mattns)
                )
        )
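A hypothetical usage sketch for this forward pass; model, story, question, idxs and targets below are assumed names, not taken from the source:

    # Unpack the answer log-probabilities and the per-step attention maps.
    log_probs, (story_attns, question_attns, memory_attns) = model((idxs, (story, question), targets))
    # log_probs comes from F.log_softmax(self.answer(mi), dim=-1); each *_attns tensor
    # stacks one attention map per reasoning step along dim 0 via torch.stack.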
Example #3
    def forward(self, input_):
        idxs, inputs, targets = input_
        story, question = inputs
        story = self.__( story, 'story')
        question = self.__(question, 'question')

        batch_size, story_size  = story.size()
        batch_size, question_size = question.size()
        
        story  = self.__( self.embed(story),  'story_emb')
        question = self.__( self.embed(question), 'question_emb')

        story  = story.transpose(1,0)
        story, _  = self.__(  self.encode_story(
            story,
            init_hidden(batch_size, self.encode_story)), 'C'
        )
        
        question  = question.transpose(1,0)
        question, _ = self.__(  self.encode_question(
            question,
            init_hidden(batch_size, self.encode_question)), 'Q'
        )

        story = F.tanh(story)
        question = F.tanh(question)

        # Stage 1: attend over question tokens to pool a single question representation.
        attended_question_dist = self.__( self.attend_question(question.view(-1, 2 * self.hidden_size)), 'attended_question_dist')
        attended_question_dist = self.__( attended_question_dist.view(question_size, batch_size, 1), 'attended_question_dist')

        if self.config.HPCONFIG.ACTIVATION == 'softmax':
            attended_question_dist = F.softmax(attended_question_dist, dim=0)
        elif self.config.HPCONFIG.ACTIVATION == 'sigmoid':
            attended_question_dist = F.sigmoid(attended_question_dist)
            
        attended_question_repr = self.__( (attended_question_dist.expand_as(question) * question).sum(dim=0), 'attended_question_repr')
        attended_question_repr = self.__( attended_question_repr.unsqueeze(dim=1), 'attended_question_repr' )

        # Stage 2: use the pooled question representation to attend over story timesteps.
        story_ = self.__( story.transpose(0, 1).transpose(1, 2), 'story_')

        # bmm: (batch, 1, 2*hidden) x (batch, 2*hidden, story_size) -> (batch, 1, story_size),
        # then transposed back to the (seq, batch, 1) layout used above.
        attended_story_dist = self.__( torch.bmm(attended_question_repr, story_), 'attended_story_dist')
        attended_story_dist = self.__( attended_story_dist.transpose(1, 2).transpose(0, 1), 'attended_story_dist')
        
        if self.config.HPCONFIG.ACTIVATION == 'softmax':
            attended_story_dist = F.softmax(attended_story_dist, dim=0)
        elif self.config.HPCONFIG.ACTIVATION == 'sigmoid':
            attended_story_dist = F.sigmoid(attended_story_dist)

        attended_story_repr = (attended_story_dist.expand_as(story) * story).sum(dim=0)
        
        attended_repr = F.tanh(attended_story_repr)
        return (self.__( F.log_softmax(self.answer(attended_repr), dim=-1), 'return val'),
                attended_story_dist.transpose(0,1),
                attended_question_dist.transpose(0,1))
Example #4
    def forward(self, context, question):

        context = LongVar(context)
        question = LongVar(question)

        batch_size, context_size = context.size()
        _, question_size = question.size()

        context = self.__(self.embed(context), 'context_emb')
        question = self.__(self.embed(question), 'question_emb')

        context = context.transpose(1, 0)
        C, _ = self.__(
            self.encode(context, init_hidden(batch_size, self.encode)), 'C')
        C = self.__(C.transpose(1, 0), 'C')
        # Append a learned sentinel vector so the attention can choose to attend to nothing.
        s = self.__(self.sentinel(batch_size), 's')
        C = self.__(torch.cat([C, s], dim=1), 'C')

        question = question.transpose(1, 0)
        Q, _ = self.__(
            self.encode(question, init_hidden(batch_size, self.encode)), 'Q')
        Q = self.__(Q.transpose(1, 0), 'Q')
        s = self.__(self.sentinel(batch_size), 's')
        Q = self.__(torch.cat([Q, s], dim=1), 'Q')

        # Project the question encoding through a linear + tanh layer before the affinity
        # computation; the original left squashedQ/transformedQ unused, so the projection is
        # wired through here on the assumption that this was the intent.
        squashedQ = self.__(Q.view(batch_size * (question_size + 1), -1),
                            'squashedQ')
        transformedQ = self.__(F.tanh(self.linear(squashedQ)), 'transformedQ')
        Q = self.__(transformedQ.view(batch_size, question_size + 1, -1), 'Q')

        affinity = self.__(torch.bmm(C, Q.transpose(1, 2)), 'affinity')
        affinity = F.softmax(affinity, dim=-1)
        context_attn = self.__(affinity.transpose(1, 2), 'context_attn')
        question_attn = self.__(affinity, 'question_attn')

        context_question = self.__(torch.bmm(C.transpose(1, 2), question_attn),
                                   'context_question')
        context_question = self.__(
            torch.cat([Q, context_question.transpose(1, 2)], -1),
            'context_question')

        attn_cq = self.__(
            torch.bmm(context_question.transpose(1, 2), context_attn),
            'attn_cq')
        attn_cq = self.__(attn_cq.transpose(1, 2).transpose(0, 1), 'attn_cq')
        hidden = self.__(init_hidden(batch_size, self.attend), 'hidden')
        final_repr, _ = self.__(self.attend(attn_cq, hidden), 'final_repr')
        final_repr = self.__(final_repr.transpose(0, 1), 'final_repr')
        return final_repr[:, :-1]  #exclude sentinel
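A shape walkthrough of the coattention above, written as comments; h stands for the encoder's per-timestep output size (assumed to be the same for context and question):

    # C: (batch_size, context_size+1, h)    context encoding plus sentinel column
    # Q: (batch_size, question_size+1, h)   question encoding plus sentinel column
    # affinity         = bmm(C, Q^T)                           -> (batch, context_size+1, question_size+1)
    # context_question = bmm(C^T, question_attn)               -> (batch, h, question_size+1)
    #   then cat([Q, context_question^T], -1)                  -> (batch, question_size+1, 2h)
    # attn_cq          = bmm(context_question^T, context_attn) -> (batch, 2h, context_size+1)
    # The attend RNN runs over the context_size+1 timesteps; the final slice drops the sentinel.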
Example #5
    def forward(self, input_):
        idxs, inputs, targets = input_
        context, question, _ = inputs
        context = self.__(context, 'context')
        question = self.__(question, 'question')

        batch_size, context_size = context.size()
        batch_size, question_size = question.size()

        context = self.__(self.embed(context), 'context_emb')
        question = self.__(self.embed(question), 'question_emb')

        context = context.transpose(1, 0)
        C, _ = self.__(
            self.encode_context(context,
                                init_hidden(batch_size, self.encode_context)),
            'C')

        question = question.transpose(1, 0)
        Q, _ = self.__(
            self.encode_question(question,
                                 init_hidden(batch_size,
                                             self.encode_question)), 'Q')
        return F.tanh(C), F.tanh(Q)
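Shape note for the returned encodings, assuming bidirectional encoders (not confirmed by this excerpt):

    # C: (context_size, batch_size, 2*hidden)   tanh-squashed context encoding
    # Q: (question_size, batch_size, 2*hidden)  tanh-squashed question encoding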