Example #1
    def beam_generate(self, batch, beam_size, k):
        if self.args.title:
            title_encoding, _ = self.title_encoder(batch.title)  # (1, title_len, 500)
            title_mask = self.create_mask(title_encoding.size(), batch.title[1]).unsqueeze(1)  # (1, 1, title_len)

        ents = batch.ent
        ent_num_list = ents[2]
        ents = self.entity_encoder(ents)
        # ents: (1, entity num, 500) / encoded hidden state of entities in batch

        glob, node_embeddings, node_mask = self.graph_encoder(batch.rel[0], batch.rel[1], (ents, ent_num_list))
        hx = glob
        node_mask = node_mask == 0
        node_mask = node_mask.unsqueeze(1)

        cx = hx.clone().detach().requires_grad_(True)  # (beam size, 500)
        context = self.attention_graph(hx.unsqueeze(1), node_embeddings, mask=node_mask).squeeze(1)
        # (1, 500) / c_g
        if self.args.title:
            title_context = self.attention_title(hx.unsqueeze(1), title_encoding, mask=title_mask).squeeze(1)
            # (1, 500) / c_s
            context = torch.cat((context, title_context), 1)  # (beam size, 1000) / cat of c_g, c_s

        recent_token = torch.LongTensor(ents.size(0), 1).fill_(self.starttok).to(self.args.device)
        # recent_token: initially, (1, 1) / start token
        beam = None
        for i in range(self.args.maxlen):
            op = self.trim_entity_index(recent_token.clone(), batch.nerd)
            # strip the copy index from the entity tag to form the previous-token input
            # reason: during training as well, the model was trained mapping output vocab (input) => target vocab (output)
            op = self.embed(op).squeeze(1)  # (beam size, 500)
            prev = torch.cat((context, op), 1)  # (beam size, 1000)
            hx, cx = self.decoder(prev, (hx, cx))
            context = self.attention_graph(hx.unsqueeze(1), node_embeddings, mask=node_mask).squeeze(1)
            if self.args.title:
                title_context = self.attention_title(hx.unsqueeze(1), title_encoding, mask=title_mask).squeeze(1)
                context = torch.cat((context, title_context), 1)  # (beam size, 1000)

            total_context = torch.cat((hx, context), 1).unsqueeze(1)  # (beam size, 1, 1500)
            sampling_prob = torch.sigmoid(self.switch(total_context))  # (beam size, 1, 1)

            out = self.out(total_context)  # (beam size, 1, target vocab size)
            out = torch.softmax(out, 2)
            out = sampling_prob * out

            # compute copy attention
            z = self.mat_attention(total_context, (ents, ent_num_list))
            # z: (beam size, 1, entity num) / copy attention over the entities at the current step, softmaxed

            z = (1 - sampling_prob) * z
            out = torch.cat((out, z), 2)  # (beam size, 1, target vocab size + entity num)
            out[:, :, 0].fill_(0)  # remove probability for special tokens <unk>, <init>
            out[:, :, 1].fill_(0)

            out = out + (1e-6 * torch.ones_like(out))
            decoded = out.log()
            scores, words = decoded.topk(dim=2, k=k)  # (beam size, 1, k), (beam size, 1, k)
            if not beam:
                beam = Beam(words.squeeze(), scores.squeeze(), [hx] * beam_size,
                            [cx] * beam_size, [context] * beam_size,
                            beam_size, k, self.args.output_vocab_size, self.args.device)
                beam.endtok = self.endtok
                beam.eostok = self.eostok
                node_embeddings = node_embeddings.repeat(len(beam.beam), 1, 1)  # (beam size, adjacency matrix len, 500)
                node_mask = node_mask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, adjacency matrix len)
                if self.args.title:
                    title_encoding = title_encoding.repeat(len(beam.beam), 1, 1)  # (beam size, title_len, 500)
                    title_mask = title_mask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, title_len)

                ents = ents.repeat(len(beam.beam), 1, 1)  # (beam size, entity num, 500)
                ent_num_list = ent_num_list.repeat(len(beam.beam))  # (beam size,)
            else:
                # if all beam nodes have ended, stop generating
                if not beam.update(scores, words, hx, cx, context):
                    break
                # if beam size changes (i.e. any of beam ends), change size of weight matrices accordingly
                node_embeddings = node_embeddings[:len(beam.beam)]
                node_mask = node_mask[:len(beam.beam)]
                if self.args.title:
                    title_encoding = title_encoding[:len(beam.beam)]
                    title_mask = title_mask[:len(beam.beam)]
                ents = ents[:len(beam.beam)]
                ent_num_list = ent_num_list[:len(beam.beam)]
            recent_token = beam.getwords()  # (beam size,) / next word for each beam
            hx = beam.get_h()  # (beam size, 500)
            cx = beam.get_c()  # (beam size, 500)
            context = beam.get_context()  # (beam size, 1000)

        return beam
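
The decisive step inside each decoding iteration above is the switch between generating a word from the target vocabulary and copying an entity (`sampling_prob` in Example #1, `s` in Examples #2 and #3). Below is a minimal, standalone sketch of just that mixture, with toy dimensions and freshly initialized layers standing in for the model's trained `self.switch`, `self.out`, and copy attention; every size and name here is illustrative, not the repo's API.

import torch

torch.manual_seed(0)
beam_size, hidden, vocab_size, num_entities = 4, 1500, 10, 6  # toy sizes; the real model uses 500-d states

switch = torch.nn.Linear(hidden, 1)             # stand-in for self.switch
out_proj = torch.nn.Linear(hidden, vocab_size)  # stand-in for self.out

with torch.no_grad():
    total_context = torch.randn(beam_size, 1, hidden)                     # cat of hx and attention context
    s = torch.sigmoid(switch(total_context))                              # (beam, 1, 1) copy-vs-generate gate
    gen = torch.softmax(out_proj(total_context), dim=2)                   # (beam, 1, vocab) generate distribution
    copy = torch.softmax(torch.randn(beam_size, 1, num_entities), dim=2)  # stand-in for the copy attention
    mixed = torch.cat((s * gen, (1 - s) * copy), dim=2)                   # (beam, 1, vocab + entities)
    mixed[:, :, 0] = 0                                                    # never emit <unk>
    mixed[:, :, 1] = 0                                                    # never emit <init>
    scores, words = (mixed + 1e-6).log().topk(k=2, dim=2)                 # top-k log-probs per beam
    print(words.squeeze(1))  # column values >= vocab_size mean "copy entity (index - vocab_size)"
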
Example #2
  def beam_generate(self,b,beamsz,k):
    if self.args.title:
      tencs,_ = self.tenc(b.src)
      tmask = self.maskFromList(tencs.size(),b.src[1]).unsqueeze(1)
    ents = b.ent
    entlens = ents[2]
    ents = self.le(ents)
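    # ents: (1, entity num, 500) / encoded hidden state of entities in b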
    if self.graph:
      gents,glob,grels = self.ge(b.rel[0],b.rel[1],(ents,entlens))
      hx = glob
      #hx = ents.max(dim=1)[0]
      keys,mask = grels
      mask = mask==0
    else:
      mask = self.maskFromList(ents.size(),entlens)
      hx = ents.max(dim=1)[0]
      keys = ents
    mask = mask.unsqueeze(1)
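    # mask: (1, 1, num keys) / attention mask over graph nodes (or over entities when no graph is used)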
    if self.args.plan:
      planlogits = self.splan.plan_decode(hx,keys,mask.clone(),entlens)
      print(planlogits.size())
      sorder = ' '.join([str(x) for x in planlogits.max(1)[1][0].tolist()])
      print(sorder)
      sorder = [x.strip() for x in sorder.split("-1")]
      sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
      mask.fill_(0)
      planplace = torch.zeros(hx.size(0)).long()
      for i,m in enumerate(sorder):
        mask[i][0][m[0]]=1
    else:
      planlogits = None

    cx = torch.tensor(hx)
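    # hx, cx: (1, 500) / initial decoder LSTM state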
    a = self.attn(hx.unsqueeze(1),keys,mask=mask).squeeze(1)
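    # a: (1, 500) / c_g, attended graph context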
    if self.args.title:
      a2 = self.attn2(hx.unsqueeze(1),tencs,mask=tmask).squeeze(1)
      a = torch.cat((a,a2),1)
    outputs = []
    outp = torch.LongTensor(ents.size(0),1).fill_(self.starttok).cuda()
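    # outp: initially (1, 1) / start token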
    beam = None
    for i in range(self.maxlen):
      op = self.emb_w_vertex(outp.clone(),b.nerd)
      if self.args.plan:
        schange = op==self.args.dottok
        if schange.nonzero().size(0)>0:
          print(schange, planplace, sorder)
          planplace[schange.nonzero().squeeze()]+=1
          for j in schange.nonzero().squeeze(1):
            if planplace[j]<len(sorder[j]):
              mask[j] = 0
              m = sorder[j][planplace[j]]
              mask[j][0][sorder[j][planplace[j]]]=1
      op = self.emb(op).squeeze(1)
      prev = torch.cat((a,op),1)
      hx,cx = self.lstm(prev,(hx,cx))
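      # one decoder LSTM step; hx, cx: (beam size, 500)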
      a = self.attn(hx.unsqueeze(1),keys,mask=mask).squeeze(1)
      if self.args.title:
        a2 = self.attn2(hx.unsqueeze(1),tencs,mask=tmask).squeeze(1)
        #a =  a + (self.mix(hx)*a2)
        a = torch.cat((a,a2),1)
      l = torch.cat((hx,a),1).unsqueeze(1)
      s = torch.sigmoid(self.switch(l))
      o = self.out(l)
      o = torch.softmax(o,2)
      o = s*o
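      # s gates between generating from the target vocabulary and copying an entity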
      #compute copy attn
      _, z = self.mattn(l,(ents,entlens))
      #z = torch.softmax(z,2)
      z = (1-s)*z
      o = torch.cat((o,z),2)
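      # o: (beam size, 1, target vocab size + entity num) / generate + copy distribution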
      o[:,:,0].fill_(0)
      o[:,:,1].fill_(0)
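      # indices 0 and 1 (<unk>, <init>) are never generated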
      '''
      if beam:
        for p,q in enumerate(beam.getPrevEnt()):
          o[p,:,q].fill_(0)
        for p,q in beam.getIsStart():
          for r in q:
            o[p,:,r].fill_(0)
      '''

      o = o+(1e-6*torch.ones_like(o))
      decoded = o.log()
      scores, words = decoded.topk(dim=2,k=k)
      if not beam:
        beam = Beam(words.squeeze(),scores.squeeze(),[hx for i in range(beamsz)],
                  [cx for i in range(beamsz)],[a for i in range(beamsz)],beamsz,k,self.args.ntoks)
        beam.endtok = self.endtok
        beam.eostok = self.eostok
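        # expand encoder-side tensors so every beam hypothesis has its own copy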
        keys = keys.repeat(len(beam.beam),1,1)
        mask = mask.repeat(len(beam.beam),1,1)
        if self.args.title:
          tencs = tencs.repeat(len(beam.beam),1,1)
          tmask = tmask.repeat(len(beam.beam),1,1)
        if self.args.plan:
          planplace = planplace.unsqueeze(0).repeat(len(beam.beam),1)
          sorder = sorder*len(beam.beam)

          
        ents = ents.repeat(len(beam.beam),1,1)
        entlens = entlens.repeat(len(beam.beam))
      else:
        if not beam.update(scores,words,hx,cx,a):
          break
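        # beam narrowed (some hypotheses finished); slice tensors down to the surviving beams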
        keys = keys[:len(beam.beam)]
        mask = mask[:len(beam.beam)]
        if self.args.title:
          tencs = tencs[:len(beam.beam)]
          tmask = tmask[:len(beam.beam)]
        if self.args.plan:
          planplace = planplace[:len(beam.beam)]
          sorder = sorder[0]*len(beam.beam)
        ents = ents[:len(beam.beam)]
        entlens = entlens[:len(beam.beam)]
      outp = beam.getwords()
      hx = beam.geth()
      cx = beam.getc()
      a = beam.getlast()

    return beam
Example #3
    def beam_generate(self, b, beamsz, k):
        if self.args.title:
            tencs, _ = self.tenc(b.src)  # (1, title_len, 500)
            tmask = self.maskFromList(tencs.size(), b.src[1]).unsqueeze(
                1)  # (1, 1, title_len)
        ents = b.ent  # tuple of (ent, phlens, elens)
        entlens = ents[2]
        ents = self.le(ents)
        # ents: (1, entity num, 500) / encoded hidden state of entities in b

        if self.graph:
            gents, glob, grels = self.ge(b.rel[0], b.rel[1], (ents, entlens))
            hx = glob
            # hx = ents.max(dim=1)[0]
            keys, mask = grels
            mask = mask == 0
        else:
            mask = self.maskFromList(ents.size(), entlens)
            hx = ents.max(dim=1)[0]
            keys = ents
        mask = mask.unsqueeze(1)
        if self.args.plan:
            planlogits = self.splan.plan_decode(hx, keys, mask.clone(),
                                                entlens)
            print(planlogits.size())
            sorder = ' '.join(
                [str(x) for x in planlogits.max(1)[1][0].tolist()])
            print(sorder)
            sorder = [x.strip() for x in sorder.split("-1")]
            sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
            mask.fill_(0)
            planplace = torch.zeros(hx.size(0)).long()
            for i, m in enumerate(sorder):
                mask[i][0][m[0]] = 1
        else:
            planlogits = None

        cx = torch.tensor(hx)  # (beam size, 500)
        a = self.attn(hx.unsqueeze(1), keys,
                      mask=mask).squeeze(1)  # (1, 500) / c_g
        if self.args.title:
            a2 = self.attn2(hx.unsqueeze(1), tencs,
                            mask=tmask).squeeze(1)  # (1, 500) / c_s
            a = torch.cat((a, a2), 1)  # (beam size, 1000) / c_t

        outp = torch.LongTensor(ents.size(0), 1).fill_(
            self.starttok).cuda()  # initially, (1, 1) / start token
        beam = None
        for i in range(self.maxlen):
            op = self.emb_w_vertex(outp.clone(), b.nerd)
            # strip the index from the tag => because training was also done mapping not-indexed (input) => indexed (output)
            if self.args.plan:
                schange = op == self.args.dottok
                if schange.nonzero().size(0) > 0:
                    print(schange, planplace, sorder)
                    planplace[schange.nonzero().squeeze()] += 1
                    for j in schange.nonzero().squeeze(1):
                        if planplace[j] < len(sorder[j]):
                            mask[j] = 0
                            m = sorder[j][planplace[j]]
                            mask[j][0][sorder[j][planplace[j]]] = 1
            op = self.emb(op).squeeze(1)  # (beam size, 500)
            prev = torch.cat((a, op), 1)  # (beam size, 1000)
            hx, cx = self.lstm(prev, (hx, cx))
            a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
            if self.args.title:
                a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
                a = torch.cat((a, a2), 1)  # (beam size, 1000)
            l = torch.cat((hx, a), 1).unsqueeze(1)  # (beam size, 1, 1500)
            s = torch.sigmoid(self.switch(l))  # (beam size, 1, 1)
            o = self.out(l)  # (beam size, 1, target vocab size)
            o = torch.softmax(o, 2)
            o = s * o
            # compute copy attn
            _, z = self.mattn(l, (ents, entlens))
            # z = torch.softmax(z,2)
            z = (1 - s) * z
            o = torch.cat((o, z),
                          2)  # (beam size, 1, target vocab size + entity num)
            o[:, :, 0].fill_(
                0)  # remove probability for special tokens <unk>, <init>
            o[:, :, 1].fill_(0)
            '''
      if beam:
        for p,q in enumerate(beam.getPrevEnt()):
          o[p,:,q].fill_(0)
        for p,q in beam.getIsStart():
          for r in q:
            o[p,:,r].fill_(0)
      '''

            o = o + (1e-6 * torch.ones_like(o))
            decoded = o.log()
            scores, words = decoded.topk(
                dim=2, k=k)  # (beam size, 1, k), (beam size, 1, k)
            if not beam:
                beam = Beam(words.squeeze(), scores.squeeze(),
                            [hx for i in range(beamsz)],
                            [cx for i in range(beamsz)],
                            [a for i in range(beamsz)], beamsz, k,
                            self.args.ntoks)
                beam.endtok = self.endtok
                beam.eostok = self.eostok
                keys = keys.repeat(len(beam.beam), 1,
                                   1)  # (beam size, adjacency matrix len, 500)
                mask = mask.repeat(
                    len(beam.beam), 1,
                    1)  # (beam size, 1, adjacency matrix len)
                if self.args.title:
                    tencs = tencs.repeat(len(beam.beam), 1,
                                         1)  # (beam size, title_len, 500)
                    tmask = tmask.repeat(len(beam.beam), 1,
                                         1)  # (beam size, 1, title_len)
                if self.args.plan:
                    planplace = planplace.unsqueeze(0).repeat(
                        len(beam.beam), 1)
                    sorder = sorder * len(beam.beam)

                ents = ents.repeat(len(beam.beam), 1,
                                   1)  # (beam size, entity num, 500)
                entlens = entlens.repeat(len(beam.beam))  # (beam size,)
            else:
                if not beam.update(scores, words, hx, cx, a):
                    break
                # if beam size changes (i.e. any of beam ends), change size of weight matrices accordingly
                keys = keys[:len(beam.beam)]
                mask = mask[:len(beam.beam)]
                if self.args.title:
                    tencs = tencs[:len(beam.beam)]
                    tmask = tmask[:len(beam.beam)]
                if self.args.plan:
                    planplace = planplace[:len(beam.beam)]
                    sorder = sorder[0] * len(beam.beam)
                ents = ents[:len(beam.beam)]
                entlens = entlens[:len(beam.beam)]
            outp = beam.getwords()  # (beam size,) / next word for each beam
            hx = beam.geth()  # (beam size, 500)
            cx = beam.getc()  # (beam size, 500)
            a = beam.getlast()  # (beam size, 1000)

        return beam