Example #1
  def beam_generate(self,h,c,tembs,vembs,gembs,nerd,beamsz,k):
    # h, c, tembs, vembs, gembs are produced upstream by self.encode_inputs(title, entities, graph)
    embs = [x for x in [(self.t_attn,tembs),(self.g_attn,gembs),(self.e_attn,vembs)] if x[1] is not None]

    outp = torch.LongTensor(vembs[0].size(0),1).fill_(self.starttok).cuda()
    last = h.transpose(0,1)
    outputs = []
    beam = None
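    # Decode one token per step, expanding the beam until maxlen steps or until no live hypotheses remain.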
    for i in range(self.maxlen):
      outp = self.emb_w_vertex(outp.clone(),nerd)
      enc = self.Embedding(outp)
      decin = torch.cat((enc,last),2)
      decout,(h,c) = self.dlstm(decin,(h,c))
      last, vweight, _ = self.hierattn(decout,embs)
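      # hierattn attends over the title/graph/entity encodings; vweight is the attention (copy) distribution over entity vertices.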
      scalar = torch.sigmoid(self.switch(h))
      outs = torch.cat((decout,last),2)
      decoded = self.outlin(outs.contiguous().view(-1, self.args.hsz*2))
      decoded = decoded.view(outs.size(0), outs.size(1), self.args.ntoks)
      decoded = torch.softmax(decoded,2)
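      # Zero out indices 0 and 1 (special tokens, e.g. pad/unk) so they are never selected.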
      decoded[:,:,0].fill_(0)
      decoded[:,:,1].fill_(0)
      scalars = scalar.transpose(0,1)
      decoded = torch.mul(decoded,1-scalars.expand_as(decoded))
      vweights = torch.mul(vweight,scalars.expand_as(vweight))
      decoded = torch.cat([decoded,vweights],2)

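      # Add a small epsilon before the log so masked (zero) probabilities do not become -inf.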
      zero_vec = 1e-6*torch.ones_like(decoded)
      decoded += zero_vec
      decoded = decoded.log()
      scores, words = decoded.topk(dim=2,k=k)
      #scores = scores.transpose(0,1); words = words.transpose(0,1)
      if not beam:
        beam = Beam(words.squeeze(),scores.squeeze(),[h for i in range(beamsz)],
                  [c for i in range(beamsz)],[last for i in range(beamsz)],beamsz,k)
        beam.endtok = self.endtok
        newembs = []
        for a,x in embs:
          tmp = (x[0].repeat(len(beam.beam),1,1),x[1].repeat(len(beam.beam),1))
          newembs.append((a,tmp))
        embs = newembs
      else:
        if not beam.update(scores,words,h,c,last):
          break
        newembs = []
        for a,x in embs:
          tmp = (x[0][:len(beam.beam),:,:],x[1][:len(beam.beam)])
          newembs.append((a,tmp))
        embs = newembs
      outp = beam.getwords()
      h = beam.geth()
      c = beam.getc()
      last = beam.getlast()

    return beam
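The step that distinguishes Example #1 is the copy/generate switch: a sigmoid gate splits probability mass between the vocabulary softmax and the attention weights over entity vertices, and the two pieces are concatenated into one extended distribution before the top-k step. Below is a minimal, self-contained sketch of that mixing step; the tensor names and sizes are illustrative assumptions and are not taken from the repository.

import torch

# Illustrative shapes (assumptions): batch of 2, vocabulary of 10, 4 entity vertices.
batch, vocab, n_vertices = 2, 10, 4

vocab_logits = torch.randn(batch, 1, vocab)      # decoder output projected onto the vocabulary
copy_weights = torch.softmax(torch.randn(batch, 1, n_vertices), dim=2)  # attention over vertices
switch_logit = torch.randn(batch, 1, 1)          # plays the role of self.switch(h) above

p_vocab = torch.softmax(vocab_logits, dim=2)
p_copy = torch.sigmoid(switch_logit)             # gate: fraction of mass given to copying

# Extended distribution: the first `vocab` entries generate words, the rest copy vertices.
extended = torch.cat([p_vocab * (1 - p_copy), copy_weights * p_copy], dim=2)
print(extended.sum(dim=2))                       # ~1.0 for every batch element

# Same epsilon-then-log trick as in the example before taking the top-k.
scores, words = (extended + 1e-6).log().topk(k=5, dim=2)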
Example #2
    def beam_generate(self, batch, beam_size, k):
        batch = batch.input
        encoder_output, context = self.encoder(batch[0], batch[1])
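        # context holds the encoder's final (h, c) states; merge the forward and backward directions of the bidirectional encoder along the hidden dimension.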
        hidden = []
        for i in range(len(context)):
            each = context[i]
            hidden.append(torch.cat([each[0:each.size(0):2], each[1:each.size(0):2]], 2))
        hx = hidden[0]
        cx = hidden[1]
        recent_token = torch.LongTensor(1, ).fill_(2).to(self.device)
        beam = None
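        # Decode for at most 1000 steps, growing the beam until no live hypotheses remain.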
        for i in range(1000):
            embedded = self.decoder.embedding(recent_token.type(dtype = torch.long).to(self.device))
            #(beam_size, embedding_size)
            embedded = embedded.unsqueeze(0).permute(1, 0, 2)
            output, (hx, cx) = self.decoder.rnn(embedded, (hx.contiguous(), cx.contiguous()))
            hx = hx.permute(1, 0, -1)
            cx = cx.permute(1, 0, -1)
            output = self.decoder.out(output.contiguous()) #(beam_size, 1, target_vocab_size)
            output = self.softmax(output)
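            # Zero the probabilities of indices 0-2 (special tokens) so log(0) = -inf keeps them out of the top-k.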
            output[:, :, 0].fill_(0)
            output[:, :, 1].fill_(0)
            output[:, :, 2].fill_(0)
            decoded = output.log().to(self.device)
            scores, words = decoded.topk(dim = -1, k = k) #(beam_size, 1, k) (beam_size, 1, k)
            scores = scores.to(self.device)
            words = words.to(self.device)

            if not beam:
                beam = Beam(words.squeeze(), scores.squeeze(), [hx] * beam_size, [cx] * beam_size, beam_size, k, self.decoder.output_vocab_size, self.device)
                beam.endtok = 5
                beam.eostok = 3
            else:
                if not beam.update(scores, words, hx, cx): break
            
            recent_token = beam.getwords().view(-1) #(beam_size, )
            hx = beam.get_h().permute(1, 0, -1)
            cx = beam.get_c().permute(1, 0, -1)
            #context = beam.get_context()
        
        return beam
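Both examples delegate the hypothesis bookkeeping to a Beam helper whose implementation is not shown: it keeps beamsz partial hypotheses, extends each with its top-k continuations, and reports via update() whether any live hypotheses remain. The toy loop below is a hypothetical, self-contained illustration of that pattern over a stand-in scoring function; it is not the Beam class used by either repository, and the start/end token ids (2 and 5) are only assumptions carried over from the snippets above.

import torch

def toy_beam_search(step_logp, beamsz=3, k=3, endtok=5, maxlen=20):
    """step_logp(prefix) -> (vocab,) log-probabilities for the next token.
    Purely illustrative; a real decoder also carries hidden states per hypothesis."""
    live = [([2], 0.0)]                     # (token sequence, cumulative log-prob); 2 = assumed start token
    done = []
    for _ in range(maxlen):
        candidates = []
        for seq, score in live:
            top_lp, top_ix = step_logp(seq).topk(k)
            for lp, ix in zip(top_lp.tolist(), top_ix.tolist()):
                candidates.append((seq + [ix], score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        live = []
        for seq, score in candidates:
            if seq[-1] == endtok:
                done.append((seq, score))   # hypothesis finished: retire it
            elif len(live) < beamsz:
                live.append((seq, score))   # keep at most beamsz live hypotheses
        if not live:                        # mirrors `if not beam.update(...): break`
            break
    return done or live

# Usage with a stand-in "decoder": every prefix receives the same random distribution.
vocab_size = 11
logits = torch.randn(vocab_size)
print(toy_beam_search(lambda seq: torch.log_softmax(logits, dim=0))[0])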