예제 #1
0
    def update_memory(self, mem, z, write_vec):
        """Blend a write vector into the memory using per-slot weights.

        Every slot ``i`` is updated with the erase/add rule::

            M[:, i] = M[:, i] * (1 - z[:, i]) + z[:, i] * write_vec

        so a slot with weight 1 is fully overwritten by ``write_vec`` and a
        slot with weight 0 is left untouched.  The previous implementation
        scaled the added content by ``1 - z`` as well, which wiped the
        attended slots and wrote into the unattended ones.

        Args:
            mem: ``(M, mask)`` pair.  ``M`` is the memory, expected as
                (batch, slots, dim); ``mask`` is passed through unchanged.
            z: per-slot write weights, expected as (batch, slots).
            write_vec: content to write, expected as (batch, dim).

        Returns:
            ``(new_M, mask)``.  ``new_M`` is a new tensor; the incoming
            ``M`` is not modified in place.
        """
        M, mask = mem

        # (batch, slots, 1): broadcasts over the feature dimension of M.
        gate = z.unsqueeze(2)
        # (batch, 1, dim): broadcasts over the slot dimension of M.
        content = write_vec.unsqueeze(1)

        M = M * (1 - gate) + gate * content

        return M, mask
예제 #2
0
    def forward(self, emb_utts, hidden, mem, M_que):
        """Run the read / compose / write loop over a sequence of embeddings.

        For each step along dim 0 of ``emb_utts``:

        1. the read LSTM cell turns the input into a key ``hr``;
        2. ``cos(hr, M)`` scores the key against every memory slot, masked
           slots are set to ``-inf`` and ``self.softmax`` turns the scores
           into slot weights ``z``;
        3. the weighted memory read ``z @ M`` is concatenated with ``hr`` and
           passed through ``self.compose``;
        4. the write LSTM cell produces ``hw`` from the composed vector;
        5. the memory is updated with the erase/add rule
           ``M = M * (1 - z) + z * hw``.

        Step 5 previously scaled the written content by ``1 - z`` instead of
        ``z``, which erased the attended slots and wrote into the others.

        Args:
            emb_utts: embedded inputs, split step by step along dim 0
                (assumed (seq_len, batch, emb_dim) -- TODO confirm).
            hidden: ``((hr, cr), (hw, cw))`` states of the read and write
                LSTM cells.
            mem: ``(M, mask)``; ``M`` is a 3-D (batch, slots, dim) memory and
                ``mask`` marks the slots to exclude from the softmax.
            M_que: not used by this method; kept so callers need no change.

        Returns:
            ``(outputs, ((hr, cr), (hw, cw)), M)`` where ``outputs`` stacks
            the per-step ``hw`` along a new dim 0 and ``M`` is the memory
            after the last step.
        """
        recording = self.net_data is not None
        if recording:
            Z = []

        M, mask = mem

        outputs = []
        ((hr, cr), (hw, cw)) = hidden

        # Zero tensor shaped like hr, concatenated to every input when
        # input feeding is enabled.
        # NOTE(review): ``out`` is never reassigned inside the loop, so the
        # fed-back vector is always zero -- confirm whether it was meant to
        # carry the previous step's output.
        out = hr.clone()
        out.data.zero_()

        for w in emb_utts.split(1):
            w = w.squeeze(0)
            if self.input_feed:
                w = torch.cat((w, out), 1)

            # Read: build the key and address the memory with it.
            hr, cr = self.read_lstm(w, (hr, cr))
            hr = self.dropout(hr)

            sim = cos(hr, M)
            # -inf scores are meant to give masked slots zero weight.
            z = self.softmax(sim.masked_fill_(mask, float('-inf')))

            if recording:
                Z.append(z.data.squeeze())

            # Compose: combine the key with the weighted memory read.
            m = z.unsqueeze(1).bmm(M)
            cattet = torch.cat([hr, m.squeeze(1)], 1)
            comp = self.compose(cattet)

            # Write: produce the new content and blend it into the memory.
            hw, cw = self.write_lstm(comp, (hw, cw))
            hw = self.dropout(hw)

            gate = z.unsqueeze(2)        # (batch, slots, 1)
            content = hw.unsqueeze(1)    # (batch, 1, dim)
            M = M * (1 - gate) + gate * content

            outputs.append(hw)

        if recording:
            self.net_data['z'] += [torch.stack(Z)]

        return torch.stack(outputs), ((hr, cr), (hw, cw)), M