Example #1
  def beam_generate(self, h, c, tembs, vembs, gembs, nerd, beamsz, k):
    #h,c,tembs,vembs,gembs,rembs = self.encode_inputs(title,entities,graph)
    #h,c,tembs,vembs,gembs = self.encode_inputs(title,entities,graph)
    # Pair each attention module with its encoder outputs, skipping inputs that are absent.
    embs = [x for x in [(self.t_attn, tembs), (self.g_attn, gembs), (self.e_attn, vembs)] if x[1] is not None]

    # Start every hypothesis with the start token; `last` is the context fed to the decoder
    # (initialized from the encoder state, later the attention context).
    outp = torch.LongTensor(vembs[0].size(0), 1).fill_(self.starttok).cuda()
    last = h.transpose(0, 1)
    outputs = []
    beam = None
    for i in range(self.maxlen):
      # Map tokens that point at copied vertices back into the word vocabulary, then embed.
      outp = self.emb_w_vertex(outp.clone(), nerd)
      enc = self.Embedding(outp)
      # Decoder input is the token embedding concatenated with the previous context vector.
      decin = torch.cat((enc, last), 2)
      decout, (h, c) = self.dlstm(decin, (h, c))
      # Hierarchical attention over the encoded inputs; vweight are the copy weights over vertices.
      last, vweight, _ = self.hierattn(decout, embs)
      # Switch between generating from the vocabulary and copying (pointer-generator style).
      scalar = torch.sigmoid(self.switch(h))
      outs = torch.cat((decout, last), 2)
      decoded = self.outlin(outs.contiguous().view(-1, self.args.hsz * 2))
      decoded = decoded.view(outs.size(0), outs.size(1), self.args.ntoks)
      decoded = torch.softmax(decoded, 2)
      # Zero out the first two (special) token ids so they are never selected.
      decoded[:, :, 0].fill_(0)
      decoded[:, :, 1].fill_(0)
      scalars = scalar.transpose(0, 1)
      # Weight the vocabulary distribution by (1 - switch) and the copy weights by the switch,
      # then concatenate them into one extended distribution.
      decoded = torch.mul(decoded, 1 - scalars.expand_as(decoded))
      vweights = torch.mul(vweight, scalars.expand_as(vweight))
      decoded = torch.cat([decoded, vweights], 2)

      # Add a small epsilon before the log to avoid log(0), then take the per-step top-k.
      zero_vec = 1e-6 * torch.ones_like(decoded)
      decoded += zero_vec
      decoded = decoded.log()
      scores, words = decoded.topk(dim=2, k=k)
      #scores = scores.transpose(0,1); words = words.transpose(0,1)
      if not beam:
        # First step: initialize the beam from the top-k tokens and replicate the decoder state.
        beam = Beam(words.squeeze(), scores.squeeze(), [h for i in range(beamsz)],
                    [c for i in range(beamsz)], [last for i in range(beamsz)], beamsz, k)
        beam.endtok = self.endtok
        # Repeat each encoder output pair once per live hypothesis.
        newembs = []
        for a, x in embs:
          tmp = (x[0].repeat(len(beam.beam), 1, 1), x[1].repeat(len(beam.beam), 1))
          newembs.append((a, tmp))
        embs = newembs
      else:
        # Later steps: advance the beam; update returns False once every hypothesis has ended.
        if not beam.update(scores, words, h, c, last):
          break
        # Trim the repeated encoder outputs to the hypotheses that are still alive.
        newembs = []
        for a, x in embs:
          tmp = (x[0][:len(beam.beam), :, :], x[1][:len(beam.beam)])
          newembs.append((a, tmp))
        embs = newembs
      # Feed the surviving hypotheses' last tokens and decoder states into the next step.
      outp = beam.getwords()
      h = beam.geth()
      c = beam.getc()
      last = beam.getlast()

    return beam
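
The switch step above scales the vocabulary softmax by (1 - scalar) and the attention weights over vertices by scalar, then concatenates the two into one extended distribution. A minimal standalone sketch of just that mixing, with made-up shapes (not the repository's code):

# Illustrative sketch of the switch/copy mixing in Example 1; all shapes are made up.
import torch

vocab_probs = torch.softmax(torch.randn(4, 1, 100), dim=2)   # (beam, 1, ntoks)
copy_weights = torch.softmax(torch.randn(4, 1, 20), dim=2)   # (beam, 1, n_vertices)
p_copy = torch.sigmoid(torch.randn(4, 1, 1))                 # switch output per hypothesis

extended = torch.cat([vocab_probs * (1 - p_copy),            # generate from the vocabulary
                      copy_weights * p_copy], dim=2)         # copy a source vertex
scores, words = (extended + 1e-6).log().topk(k=4, dim=2)     # same top-k step as above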
Example #2
    def beam_generate(self, batch, beam_size, k):
        batch = batch.input
        encoder_output, context = self.encoder(batch[0], batch[1])
        # Merge the forward and backward directions of the bidirectional encoder's (h, c) states.
        hidden = []
        for i in range(len(context)):
            each = context[i]
            hidden.append(torch.cat([each[0:each.size(0):2], each[1:each.size(0):2]], 2))
        hx = hidden[0]
        cx = hidden[1]
        # Decoding starts from token id 2 (presumably the start-of-sequence symbol in this vocabulary).
        recent_token = torch.LongTensor(1).fill_(2).to(self.device)
        beam = None
        for i in range(1000):  # hard upper bound on the decoded length
            embedded = self.decoder.embedding(recent_token.type(dtype=torch.long).to(self.device))
            #(beam_size, embedding_size)
            embedded = embedded.unsqueeze(0).permute(1, 0, 2)  # -> (beam_size, 1, embedding_size)
            output, (hx, cx) = self.decoder.rnn(embedded, (hx.contiguous(), cx.contiguous()))
            hx = hx.permute(1, 0, -1)
            cx = cx.permute(1, 0, -1)
            output = self.decoder.out(output.contiguous())  # (beam_size, 1, target_vocab_size)
            output = self.softmax(output)
            # Zero out the first three (special) token ids so they are never proposed.
            output[:, :, 0].fill_(0)
            output[:, :, 1].fill_(0)
            output[:, :, 2].fill_(0)
            decoded = output.log().to(self.device)
            scores, words = decoded.topk(dim=-1, k=k)  # (beam_size, 1, k), (beam_size, 1, k)
            # .to() is not in-place; keep the returned tensors.
            scores = scores.to(self.device)
            words = words.to(self.device)

            if not beam:
                # First step: build the beam from the initial top-k and replicate the decoder state.
                beam = Beam(words.squeeze(), scores.squeeze(), [hx] * beam_size, [cx] * beam_size,
                            beam_size, k, self.decoder.output_vocab_size, self.device)
                beam.endtok = 5  # model-specific special-token ids
                beam.eostok = 3
            else:
                # Later steps: advance the beam; stop once no hypothesis can be extended.
                if not beam.update(scores, words, hx, cx):
                    break

            # Feed the surviving hypotheses' last tokens and states into the next step.
            recent_token = beam.getwords().view(-1)  # (beam_size, )
            hx = beam.get_h().permute(1, 0, -1)
            cx = beam.get_c().permute(1, 0, -1)
            #context = beam.get_context()
        
        return beam
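
The slicing at the top of Example 2 interleaves the forward and backward layers of a bidirectional LSTM state before decoding. A small shape check of just that step (all sizes are illustrative, not from the source):

# Shape check for the bidirectional-state merge in Example 2; sizes are made up.
import torch

num_layers, batch, hidden = 2, 4, 8
each = torch.randn(num_layers * 2, batch, hidden)    # (layers * directions, batch, hidden)
merged = torch.cat([each[0:each.size(0):2],          # forward direction of every layer
                    each[1:each.size(0):2]], 2)      # backward direction of every layer
print(merged.shape)                                  # torch.Size([2, 4, 16])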
Example #3
def beam_decode(decoder_context,
                decoder_hidden,
                encoder_outputs,
                max_len,
                beam_size=5):
    # Relies on module-level `decoder`, `output_lang`, `device`, and `Language` objects.
    batch_size = decoder_hidden.size(1)  # decoder_hidden: [num_layers, batch_size, hidden_size]
    vocab_size = output_lang.n_words
    # decoder_input: [batch_size x beam_size], every hypothesis starts from the SOS token
    decoder_input = torch.ones(batch_size * beam_size, dtype=torch.long, device=device) * Language.sos_token

    # [num_layers, batch_size x beam_size, hidden_size]
    decoder_hidden = decoder_hidden.repeat(1, beam_size, 1)
    decoder_context = decoder_context.repeat(1, beam_size, 1)

    encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)

    # [batch_size] [0, beam_size * 1, ..., beam_size * (batch_size - 1)]
    batch_position = torch.arange(0, batch_size, dtype=torch.long, device=device) * beam_size

    score = torch.ones(batch_size * beam_size, device=device) * -float('inf')
    score.index_fill_(0, torch.arange(0, batch_size, dtype=torch.long, device=device) * beam_size, 0.0)

    # Initialize Beam that stores decisions for backtracking
    beam = Beam(
        batch_size,
        beam_size,
        max_len,
        batch_position,
        Language.eos_token
    )

    for i in range(max_len):
        decoder_output, decoder_context, decoder_hidden, _ = decoder(decoder_input,
                                                                    decoder_context,
                                                                    decoder_hidden,
                                                                    encoder_outputs)
        # output: [1, batch_size * beam_size, vocab_size]
        # -> [batch_size * beam_size, vocab_size]
        log_prob = decoder_output

        # score: [batch_size * beam_size, vocab_size]
        score = score.view(-1, 1) + log_prob

        # score [batch_size, beam_size]
        score, top_k_idx = score.view(batch_size, -1).topk(beam_size, dim=1)

        # decoder_input: [batch_size x beam_size]
        decoder_input = (top_k_idx % vocab_size).view(-1)

        # beam_idx: [batch_size, beam_size]; integer division recovers which beam each candidate extends
        beam_idx = top_k_idx // vocab_size

        # top_k_pointer: [batch_size * beam_size]
        top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1)

        # [num_layers, batch_size * beam_size, hidden_size]
        decoder_hidden = decoder_hidden.index_select(1, top_k_pointer)
        decoder_context = decoder_context.index_select(1, top_k_pointer)

        # Update sequence scores at beam
        beam.update(score.clone(), top_k_pointer, decoder_input)

        # Erase scores for EOS so that they are not expanded
        # [batch_size, beam_size]
        eos_idx = decoder_input.data.eq(Language.eos_token).view(batch_size, beam_size)

        if eos_idx.any():
            score.data.masked_fill_(eos_idx, -float('inf'))

    prediction, final_score, length = beam.backtrack()
    return prediction, final_score, length
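
The index arithmetic in Example 3 (batch_position, top_k_pointer) is easiest to follow on concrete numbers. A tiny self-contained check with made-up sizes:

# Worked example of the beam bookkeeping in Example 3; all numbers are illustrative.
import torch

batch_size, beam_size, vocab_size = 2, 3, 10
batch_position = torch.arange(batch_size) * beam_size          # tensor([0, 3])

# Pretend log-probability scores for every (beam, word) candidate of each batch element.
score = torch.randn(batch_size * beam_size, vocab_size)
top_score, top_k_idx = score.view(batch_size, -1).topk(beam_size, dim=1)

next_tokens = (top_k_idx % vocab_size).view(-1)                # which word was chosen
beam_idx = top_k_idx // vocab_size                             # which beam it extends
top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1)
# top_k_pointer indexes rows of the flat [batch * beam, ...] state tensors,
# which is exactly what decoder_hidden.index_select(1, top_k_pointer) does above.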