Example #1
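  # Assumes `torch` is imported at module level and that a Beam helper class
  # (providing update/getwords/geth/getc/getlast and an endtok attribute) is
  # available from elsewhere in the codebase.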
  def beam_generate(self,h,c,tembs,vembs,gembs,nerd,beamsz,k):
    # h, c, tembs, vembs, gembs are produced by self.encode_inputs(title, entities, graph)
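    # Pair each attention module with its encoder memory; sources that were not
    # provided (None) are skipped.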
    embs = [x for x in [(self.t_attn,tembs),(self.g_attn,gembs),(self.e_attn,vembs)] if x[1] is not None]

    # Every hypothesis starts from the start token; `last` is the running attention
    # context, initialised here from the encoder hidden state.
    outp = torch.LongTensor(vembs[0].size(0),1).fill_(self.starttok).cuda()
    last = h.transpose(0,1)
    outputs = []
    beam = None
    for i in range(self.maxlen):
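      # emb_w_vertex presumably rewrites generated copy indices back to entity token
      # ids (via nerd) so they can be looked up in the word embedding table.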
      outp = self.emb_w_vertex(outp.clone(),nerd)
      enc = self.Embedding(outp)
      decin = torch.cat((enc,last),2)
      decout,(h,c) = self.dlstm(decin,(h,c))
      last, vweight, _ = self.hierattn(decout,embs)
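      # Copy switch: a sigmoid gate over the hidden state decides how much probability
      # mass goes to the vocabulary versus the attended entities.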
      scalar = torch.sigmoid(self.switch(h))
      outs = torch.cat((decout,last),2)
      decoded = self.outlin(outs.contiguous().view(-1, self.args.hsz*2))
      decoded = decoded.view(outs.size(0), outs.size(1), self.args.ntoks)
      decoded = torch.softmax(decoded,2)
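      # Zero out ids 0 and 1 (presumably pad/unk) so they can never be generated.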
      decoded[:,:,0].fill_(0)
      decoded[:,:,1].fill_(0)
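      # Blend generation and copy: (1 - switch) scales the vocabulary distribution,
      # switch scales the copy weights; concatenating puts copy indices after the vocab.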
      scalars = scalar.transpose(0,1)
      decoded = torch.mul(decoded,1-scalars.expand_as(decoded))
      vweights = torch.mul(vweight,scalars.expand_as(vweight))
      decoded = torch.cat([decoded,vweights],2)

      # Add a tiny epsilon so the log below never sees an exact zero probability.
      zero_vec = 1e-6*torch.ones_like(decoded)
      decoded += zero_vec
      decoded = decoded.log()
      scores, words = decoded.topk(dim=2,k=k)
      #scores = scores.transpose(0,1); words = words.transpose(0,1)
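      # First step: seed the beam with the top-k tokens; afterwards, advance the beam
      # and stop as soon as update() reports nothing left to expand.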
      if beam is None:
        beam = Beam(words.squeeze(),scores.squeeze(),[h for i in range(beamsz)],
                  [c for i in range(beamsz)],[last for i in range(beamsz)],beamsz,k)
        beam.endtok = self.endtok
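        # Tile each encoder memory and its companion tensor (presumably a mask) so
        # every beam hypothesis attends over its own copy.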
        newembs = []
        for a,x in embs:
          tmp = (x[0].repeat(len(beam.beam),1,1),x[1].repeat(len(beam.beam),1))
          newembs.append((a,tmp))
        embs = newembs
      else:
        if not beam.update(scores,words,h,c,last):
          break
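        # Hypotheses can finish between steps; keep encoder memories only for the
        # beams that are still active.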
        newembs = []
        for a,x in embs:
          tmp = (x[0][:len(beam.beam),:,:],x[1][:len(beam.beam)])
          newembs.append((a,tmp))
        embs = newembs
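      # Feed the beam's selected tokens and states back in as the next decoder input.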
      outp = beam.getwords()
      h = beam.geth()
      c = beam.getc()
      last = beam.getlast()

    return beam
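
A minimal usage sketch, assuming the model exposes the encode_inputs call hinted at
in the comment near the top of the method and that generation runs without gradient
tracking; the names model, title, entities, graph and nerd are placeholders rather
than the repository's actual API.

# Hypothetical driver; encode_inputs, the batch fields, and the beam sizes are assumptions.
model.eval()
with torch.no_grad():
    h, c, tembs, vembs, gembs = model.encode_inputs(title, entities, graph)
    beam = model.beam_generate(h, c, tembs, vembs, gembs, nerd, beamsz=4, k=6)
    # the returned Beam object holds the decoded hypotheses (see its accessors above)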