Example #1
    def forward(self,
                inputs,
                listener_outputs,
                function=F.log_softmax,
                teacher_forcing_ratio=0.90,
                use_beam_search=False):
        batch_size = inputs.size(0)
        max_length = inputs.size(1) - 1  # minus the start of sequence symbol

        decode_results = list()
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        hidden = self._init_state(batch_size)

        if use_beam_search:  # TopK Decoding
            input_ = inputs[:, 0].unsqueeze(1)
            beam = Beam(k=self.k,
                        decoder=self,
                        batch_size=batch_size,
                        max_length=max_length,
                        function=function,
                        device=self.device)
            logits = None
            y_hats = beam.search(input_, listener_outputs)

        else:
            if use_teacher_forcing:  # if teacher_forcing, Infer all at once
                speller_inputs = inputs[inputs != self.eos_id].view(
                    batch_size, -1)
                predicted_softmax, hidden = self.forward_step(
                    input_=speller_inputs,
                    hidden=hidden,
                    listener_outputs=listener_outputs,
                    function=function)

                for di in range(predicted_softmax.size(1)):
                    step_output = predicted_softmax[:, di, :]
                    decode_results.append(step_output)

            else:
                speller_input = inputs[:, 0].unsqueeze(1)

                for di in range(max_length):
                    predicted_softmax, hidden = self.forward_step(
                        input_=speller_input,
                        hidden=hidden,
                        listener_outputs=listener_outputs,
                        function=function)
                    step_output = predicted_softmax.squeeze(1)
                    decode_results.append(step_output)
                    speller_input = decode_results[-1].topk(1)[1]

            logits = torch.stack(decode_results, dim=1).to(self.device)
            y_hats = logits.max(-1)[1]

        return y_hats, logits
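
A minimal, self-contained sketch of the decoding pattern above, with a plain Linear layer standing in for forward_step and illustrative sizes (none of this comes from the Speller's real modules): teacher forcing feeds the whole ground-truth prefix in one call, while the greedy branch feeds each step its own top-1 prediction.

import random

import torch
import torch.nn.functional as F

vocab_size, batch_size, hidden_size, max_length = 10, 2, 8, 5
step = torch.nn.Linear(hidden_size, vocab_size)  # stand-in for forward_step
teacher_forcing_ratio = 0.9
use_teacher_forcing = random.random() < teacher_forcing_ratio

decode_results = []
if use_teacher_forcing:
    # feed the whole SOS-prefixed target sequence in a single call
    states = torch.randn(batch_size, max_length, hidden_size)
    predicted_softmax = F.log_softmax(step(states), dim=-1)  # (B, T, V)
    for di in range(predicted_softmax.size(1)):
        decode_results.append(predicted_softmax[:, di, :])
else:
    # greedy unrolling: the top-1 id of each step would become the next input
    state = torch.randn(batch_size, hidden_size)
    for di in range(max_length):
        step_output = F.log_softmax(step(state), dim=-1)  # (B, V)
        decode_results.append(step_output)
        next_input = decode_results[-1].topk(1)[1]  # (B, 1) ids; the real decoder embeds these

logits = torch.stack(decode_results, dim=1)  # (B, T, V)
y_hats = logits.max(-1)[1]                   # (B, T) predicted token ids
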
Example #2
    def forward(self,
                inputs,
                encoder_outputs,
                function=F.log_softmax,
                teacher_forcing_ratio=0.90,
                use_beam_search=False):
        y_hats, logits = None, None
        decode_results = []
        batch_size = inputs.size(0)
        max_len = inputs.size(1) - 1  # minus the start of sequence symbol
        decoder_hidden = torch.FloatTensor(self.n_layers, batch_size,
                                           self.hidden_size).uniform_(
                                               -0.1, 0.1).to(self.device)
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        if use_beam_search:
            """ Beam-Search Decoding """
            inputs = inputs[:, 0].unsqueeze(1)
            beam = Beam(k=self.k,
                        decoder_hidden=decoder_hidden,
                        decoder=self,
                        batch_size=batch_size,
                        max_len=max_len,
                        function=function,
                        device=self.device)
            y_hats = beam.search(inputs, encoder_outputs)
        else:
            if use_teacher_forcing:
                """ if teacher_forcing, Infer all at once """
                inputs = inputs[:, :-1]
                predicted_softmax = self._forward_step(
                    input=inputs,
                    decoder_hidden=decoder_hidden,
                    encoder_outputs=encoder_outputs,
                    function=function)
                for di in range(predicted_softmax.size(1)):
                    step_output = predicted_softmax[:, di, :]
                    decode_results.append(step_output)
            else:
                input = inputs[:, 0].unsqueeze(1)
                for di in range(max_len):
                    predicted_softmax = self._forward_step(
                        input=input,
                        decoder_hidden=decoder_hidden,
                        encoder_outputs=encoder_outputs,
                        function=function)
                    step_output = predicted_softmax.squeeze(1)
                    decode_results.append(step_output)
                    input = decode_results[-1].topk(1)[1]

            logits = torch.stack(decode_results, dim=1).to(self.device)
            y_hats = logits.max(-1)[1]
        return y_hats, logits
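
The stacking and argmax idiom that closes both examples, isolated on made-up numbers:

import torch

# three per-step log-probability rows over a 4-token vocabulary, batch size 2
decode_results = [torch.rand(2, 4).log_softmax(-1) for _ in range(3)]

logits = torch.stack(decode_results, dim=1)  # (2, 3, 4): batch x steps x vocab
y_hats = logits.max(-1)[1]                   # (2, 3): argmax token id per step

# the greedy feedback used inside the loop is the same argmax, kept as a
# (batch, 1) column via topk
next_input = decode_results[-1].topk(1)[1]   # (2, 1)
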
Example #3
    def beam_search(self, source_tensor, beam_size):
        """
        Use beam search to generate summaries one by one.
        :param source_tensor: (src seq len, batch size); batch size needs to be 1.
        :param beam_size: beam search size
        :return: same as forward
        """
        batch_size = source_tensor.size(1)
        assert batch_size == 1
        # run encoder
        encoder_init_hidden = torch.zeros(1,
                                          batch_size,
                                          self.params.hidden_layer_units,
                                          device=device)
        encoder_word_embeddings = self.embedding(source_tensor)
        encoder_outputs, encoder_hidden = self.encoder(encoder_word_embeddings,
                                                       encoder_init_hidden)

        # build batch of beam size and initialize states
        encoder_outputs = encoder_outputs.expand(-1, beam_size,
                                                 -1).contiguous()
        decoder_hidden_cur_step = encoder_hidden.expand(-1, beam_size,
                                                        -1).contiguous()
        be = Beam(beam_size, self.special_tokens)

        step = 0
        while step <= self.params.max_dec_steps:
            decoder_input_cur_step = be.states[-1]
            decoder_cur_word_embedding = self.embedding(decoder_input_cur_step)
            decoder_output_cur_step, decoder_hidden_cur_step = self.decoder(
                decoder_cur_word_embedding, decoder_hidden_cur_step,
                encoder_outputs)
            if be.advance(decoder_output_cur_step):
                break
            step += 1
        result_tokens = be.trace(0)
        output_token_idx = torch.tensor(result_tokens,
                                        device=device,
                                        dtype=torch.long).unsqueeze(1).expand(
                                            -1, batch_size)
        return output_token_idx, torch.tensor(0., device=device)
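
The "build batch of beam size" step above replicates the single encoded example along the batch axis so the decoder scores every hypothesis in one call; a standalone illustration with arbitrary stand-in sizes:

import torch

beam_size = 4
encoder_outputs = torch.randn(7, 1, 16)  # (src_len, batch=1, hidden)
encoder_hidden = torch.randn(1, 1, 16)   # (layers, batch=1, hidden)

# expand is a view (no copy) until .contiguous() materialises the memory
encoder_outputs = encoder_outputs.expand(-1, beam_size, -1).contiguous()
decoder_hidden = encoder_hidden.expand(-1, beam_size, -1).contiguous()

print(encoder_outputs.shape)  # torch.Size([7, 4, 16])
print(decoder_hidden.shape)   # torch.Size([1, 4, 16])
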
Example #4
def beamsearch(memory,
               model,
               device,
               beam_size=4,
               candidates=1,
               max_seq_length=128,
               sos_token=1,
               eos_token=2):
    # memory: Tx1xE
    model.eval()

    beam = Beam(beam_size=beam_size,
                min_length=0,
                n_top=candidates,
                ranker=None,
                start_token_id=sos_token,
                end_token_id=eos_token)

    with torch.no_grad():
        #        memory = memory.repeat(1, beam_size, 1) # TxNxE
        memory = model.expand_memory(memory, beam_size)

        for _ in range(max_seq_length):

            tgt_inp = beam.get_current_state().transpose(0,
                                                         1).to(device)  # TxN
            decoder_outputs, memory = model.transformer.forward_decoder(
                tgt_inp, memory)

            log_prob = log_softmax(decoder_outputs[:, -1, :].squeeze(0),
                                   dim=-1)
            beam.advance(log_prob.cpu())

            if beam.done():
                break

        scores, ks = beam.sort_finished(minimum=1)

        hypothesises = []
        for i, (times, k) in enumerate(ks[:candidates]):
            hypothesis = beam.get_hypothesis(times, k)
            hypothesises.append(hypothesis)

    return [1] + [int(i) for i in hypothesises[0][:-1]]
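
To make the per-step scoring concrete, the same last-position slice and log-softmax can be run on a dummy decoder output; shapes only, nothing below comes from the real model (and the .squeeze(0) in the example is a no-op whenever beam_size > 1):

import torch
from torch.nn.functional import log_softmax

beam_size, cur_len, vocab = 4, 3, 100
decoder_outputs = torch.randn(beam_size, cur_len, vocab)  # (N, T, V)

# keep only the newest position of every live hypothesis and normalise it;
# beam.advance ranks candidate extensions on these log-probabilities
log_prob = log_softmax(decoder_outputs[:, -1, :], dim=-1)  # (N, V)
assert torch.allclose(log_prob.exp().sum(-1), torch.ones(beam_size))
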
Example #5
  def beam_generate(self, b, beamsz, k):
    if self.args.title:
      tencs, _ = self.tenc(b.src)
      tmask = self.maskFromList(tencs.size(), b.src[1]).unsqueeze(1)
    ents = b.ent
    entlens = ents[2]
    ents = self.le(ents)
    if self.graph:
      gents, glob, grels = self.ge(b.rel[0], b.rel[1], (ents, entlens))
      hx = glob
      #hx = ents.max(dim=1)[0]
      keys, mask = grels
      mask = mask == 0
    else:
      mask = self.maskFromList(ents.size(), entlens)
      hx = ents.max(dim=1)[0]
      keys = ents
    mask = mask.unsqueeze(1)
    if self.args.plan:
      planlogits = self.splan.plan_decode(hx, keys, mask.clone(), entlens)
      print(planlogits.size())
      sorder = ' '.join([str(x) for x in planlogits.max(1)[1][0].tolist()])
      print(sorder)
      sorder = [x.strip() for x in sorder.split("-1")]
      sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
      mask.fill_(0)
      planplace = torch.zeros(hx.size(0)).long()
      for i, m in enumerate(sorder):
        mask[i][0][m[0]] = 1
    else:
      planlogits = None

    cx = hx.clone().detach()  # initialise the LSTM cell state from hx (avoids the torch.tensor(tensor) copy warning)
    a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
    if self.args.title:
      a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
      a = torch.cat((a, a2), 1)
    outputs = []
    outp = torch.LongTensor(ents.size(0), 1).fill_(self.starttok).cuda()
    beam = None
    for i in range(self.maxlen):
      op = self.emb_w_vertex(outp.clone(), b.nerd)
      if self.args.plan:
        schange = op == self.args.dottok
        if schange.nonzero().size(0) > 0:
          print(schange, planplace, sorder)
          planplace[schange.nonzero().squeeze()] += 1
          for j in schange.nonzero().squeeze(1):
            if planplace[j] < len(sorder[j]):
              mask[j] = 0
              m = sorder[j][planplace[j]]
              mask[j][0][sorder[j][planplace[j]]] = 1
      op = self.emb(op).squeeze(1)
      prev = torch.cat((a, op), 1)
      hx, cx = self.lstm(prev, (hx, cx))
      a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
      if self.args.title:
        a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
        #a =  a + (self.mix(hx)*a2)
        a = torch.cat((a, a2), 1)
      l = torch.cat((hx, a), 1).unsqueeze(1)
      s = torch.sigmoid(self.switch(l))
      o = self.out(l)
      o = torch.softmax(o, 2)
      o = s * o
      #compute copy attn
      _, z = self.mattn(l, (ents, entlens))
      #z = torch.softmax(z,2)
      z = (1 - s) * z
      o = torch.cat((o, z), 2)
      o[:, :, 0].fill_(0)
      o[:, :, 1].fill_(0)
      '''
      if beam:
        for p,q in enumerate(beam.getPrevEnt()):
          o[p,:,q].fill_(0)
        for p,q in beam.getIsStart():
          for r in q:
            o[p,:,r].fill_(0)
      '''

      o = o + (1e-6 * torch.ones_like(o))
      decoded = o.log()
      scores, words = decoded.topk(dim=2, k=k)
      if not beam:
        beam = Beam(words.squeeze(), scores.squeeze(),
                    [hx for i in range(beamsz)], [cx for i in range(beamsz)],
                    [a for i in range(beamsz)], beamsz, k, self.args.ntoks)
        beam.endtok = self.endtok
        beam.eostok = self.eostok
        keys = keys.repeat(len(beam.beam), 1, 1)
        mask = mask.repeat(len(beam.beam), 1, 1)
        if self.args.title:
          tencs = tencs.repeat(len(beam.beam), 1, 1)
          tmask = tmask.repeat(len(beam.beam), 1, 1)
        if self.args.plan:
          planplace = planplace.unsqueeze(0).repeat(len(beam.beam), 1)
          sorder = sorder * len(beam.beam)

        ents = ents.repeat(len(beam.beam), 1, 1)
        entlens = entlens.repeat(len(beam.beam))
      else:
        if not beam.update(scores, words, hx, cx, a):
          break
        keys = keys[:len(beam.beam)]
        mask = mask[:len(beam.beam)]
        if self.args.title:
          tencs = tencs[:len(beam.beam)]
          tmask = tmask[:len(beam.beam)]
        if self.args.plan:
          planplace = planplace[:len(beam.beam)]
          sorder = sorder[0] * len(beam.beam)
        ents = ents[:len(beam.beam)]
        entlens = entlens[:len(beam.beam)]
      outp = beam.getwords()
      hx = beam.geth()
      cx = beam.getc()
      a = beam.getlast()

    return beam
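
The repeat-then-truncate bookkeeping around len(beam.beam) above (tile every tensor to the beam width on the first step, slice it down as hypotheses finish) shown in isolation, with arbitrary sizes:

import torch

beamsz = 5
keys = torch.randn(1, 6, 8)       # (batch=1, num_nodes, hidden)
mask = torch.zeros(1, 1, 6)

# first iteration: tile the single example across all live hypotheses
keys = keys.repeat(beamsz, 1, 1)  # (5, 6, 8)
mask = mask.repeat(beamsz, 1, 1)  # (5, 1, 6)

# later iteration: suppose two hypotheses ended, so three remain live
live = 3                          # stand-in for len(beam.beam)
keys, mask = keys[:live], mask[:live]
print(keys.shape, mask.shape)     # torch.Size([3, 6, 8]) torch.Size([3, 1, 6])
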
Example #6
    def beam_generate(self, batch, beam_size, k):
        if self.args.title:
            title_encoding, _ = self.title_encoder(batch.title)  # (1, title_len, 500)
            title_mask = self.create_mask(title_encoding.size(), batch.title[1]).unsqueeze(1)  # (1, 1, title_len)

        ents = batch.ent
        ent_num_list = ents[2]
        ents = self.entity_encoder(ents)
        # ents: (1, entity num, 500) / encoded hidden state of entities in b

        glob, node_embeddings, node_mask = self.graph_encoder(batch.rel[0], batch.rel[1], (ents, ent_num_list))
        hx = glob
        node_mask = node_mask == 0
        node_mask = node_mask.unsqueeze(1)

        cx = hx.clone().detach().requires_grad_(True)  # (beam size, 500)
        context = self.attention_graph(hx.unsqueeze(1), node_embeddings, mask=node_mask).squeeze(1)
        # (1, 500) / c_g
        if self.args.title:
            title_context = self.attention_title(hx.unsqueeze(1), title_encoding, mask=title_mask).squeeze(1)
            # (1, 500) / c_s
            context = torch.cat((context, title_context), 1)  # (beam size, 1000) / cat of c_g, c_s

        recent_token = torch.LongTensor(ents.size(0), 1).fill_(self.starttok).to(self.args.device)
        # recent_token: initially, (1, 1) / start token
        beam = None
        for i in range(self.args.maxlen):
            op = self.trim_entity_index(recent_token.clone(), batch.nerd)
            # strip the index from the tag to build the previous token
            # reason: training also mapped output vocab (input) => target vocab (output).
            op = self.embed(op).squeeze(1)  # (beam size, 500)
            prev = torch.cat((context, op), 1)  # (beam size, 1000)
            hx, cx = self.decoder(prev, (hx, cx))
            context = self.attention_graph(hx.unsqueeze(1), node_embeddings, mask=node_mask).squeeze(1)
            if self.args.title:
                title_context = self.attention_title(hx.unsqueeze(1), title_encoding, mask=title_mask).squeeze(1)
                context = torch.cat((context, title_context), 1)  # (beam size, 1000)

            total_context = torch.cat((hx, context), 1).unsqueeze(1)  # (beam size, 1, 1500)
            sampling_prob = torch.sigmoid(self.switch(total_context))  # (beam size, 1, 1)

            out = self.out(total_context)  # (beam size, 1, target vocab size)
            out = torch.softmax(out, 2)
            out = sampling_prob * out

            # compute copy attention
            z = self.mat_attention(total_context, (ents, ent_num_list))
            # z: (1,  max_abstract_len, max_entity_num) / entities attended on each abstract word, then softmaxed

            z = (1 - sampling_prob) * z
            out = torch.cat((out, z), 2)  # (beam size, 1, target vocab size + entity num)
            out[:, :, 0].fill_(0)  # remove probability for special tokens <unk>, <init>
            out[:, :, 1].fill_(0)

            out = out + (1e-6 * torch.ones_like(out))
            decoded = out.log()
            scores, words = decoded.topk(dim=2, k=k)  # (beam size, 1, k), (beam size, 1, k)
            if not beam:
                beam = Beam(words.squeeze(), scores.squeeze(), [hx] * beam_size,
                            [cx] * beam_size, [context] * beam_size,
                            beam_size, k, self.args.output_vocab_size, self.args.device)
                beam.endtok = self.endtok
                beam.eostok = self.eostok
                node_embeddings = node_embeddings.repeat(len(beam.beam), 1, 1)  # (beam size, adjacency matrix len, 500)
                node_mask = node_mask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, adjacency matrix len) => all 0?
                if self.args.title:
                    title_encoding = title_encoding.repeat(len(beam.beam), 1, 1)  # (beam size, title_len, 500)
                    title_mask = title_mask.repeat(len(beam.beam), 1, 1)  # (1, 1, title_len)

                ents = ents.repeat(len(beam.beam), 1, 1)  # (beam size, entity num, 500)
                ent_num_list = ent_num_list.repeat(len(beam.beam))  # (beam size,)
            else:
                # if all beam nodes have ended, stop generating
                if not beam.update(scores, words, hx, cx, context):
                    break
                # if beam size changes (i.e. any of beam ends), change size of weight matrices accordingly
                node_embeddings = node_embeddings[:len(beam.beam)]
                node_mask = node_mask[:len(beam.beam)]
                if self.args.title:
                    title_encoding = title_encoding[:len(beam.beam)]
                    title_mask = title_mask[:len(beam.beam)]
                ents = ents[:len(beam.beam)]
                ent_num_list = ent_num_list[:len(beam.beam)]
            recent_token = beam.getwords()  # (beam size,) / next word for each beam
            hx = beam.get_h()  # (beam size, 500)
            cx = beam.get_c()  # (beam size, 500)
            context = beam.get_context()  # (beam size, 1000)

        return beam
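
A quick check that the switch/copy mixture used above (sampling_prob times the vocabulary softmax, concatenated with 1 - sampling_prob times the copy attention) still forms a single probability distribution over "vocabulary + entities"; all sizes are made up:

import torch

batch, vocab, n_entities = 2, 7, 3
sampling_prob = torch.sigmoid(torch.randn(batch, 1, 1))
vocab_dist = torch.softmax(torch.randn(batch, 1, vocab), 2)
copy_dist = torch.softmax(torch.randn(batch, 1, n_entities), 2)

out = torch.cat((sampling_prob * vocab_dist, (1 - sampling_prob) * copy_dist), 2)
print(out.shape)    # (2, 1, 10) = (batch, 1, vocab + entities)
print(out.sum(-1))  # ~1.0 everywhere; the 1e-6 floor only guards the later log()
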
Example #7
    def beam_sample(self,
                    src,
                    src_mask,
                    segment_ids,
                    src_len,
                    beam_size=1,
                    eval_=False):

        # (1) Run the encoder on the src.

        lengths, indices = torch.sort(src_len, dim=0, descending=True)
        _, ind = torch.sort(indices)
        src = torch.index_select(src, dim=0, index=indices)
        if self.bert_model is None:
            src = src.t()
            batch_size = src.size(1)
        else:
            batch_size = src.size(0)
        contexts, encState = self.encoder(src, src_mask, segment_ids,
                                          lengths.tolist())

        #  (1b) Initialize for the decoder.
        def var(a):
            return a.clone().detach().requires_grad_(False)

        def rvar(a):
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        if self.config.sgm.cell == 'lstm':
            decState = (rvar(encState[0]), rvar(encState[1]))
        else:
            decState = rvar(encState)

        beam = [
            Beam(beam_size,
                 n_best=1,
                 cuda=self.use_cuda,
                 length_norm=self.config.sgm.length_norm)
            for __ in range(batch_size)
        ]
        # if self.decoder.attention is not None:
        #     self.decoder.attention.init_context(contexts)
        if self.decoder.attention is not None:
            if self.bert_model is None:
                # Repeat everything beam_size times.
                contexts = rvar(contexts)
                contexts = contexts.transpose(0, 1)
            contexts = var(contexts.repeat(beam_size, 1, 1))
            self.decoder.attention.init_context(context=contexts)

        # (2) run the decoder to generate sentences, using beam search.

        for i in range(self.config.sgm.max_time_step):

            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size nxt words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(
                torch.stack([b.getCurrentState()
                             for b in beam]).t().contiguous().view(-1))
            # if self.bert_model is None:
            #     inp = var(torch.stack([b.getCurrentState() for b in beam])
            #               .t().contiguous().view(-1))
            # else:
            #     inp = var(torch.stack([b.getCurrentState() for b in beam]).contiguous().view(-1))
            # Run one step.

            output, decState, attn = self.decoder(inp, decState)
            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.
            output = unbottle(self.log_softmax(output))
            attn = unbottle(attn)
            # beam x tgt_vocab

            # (c) Advance each beam.
            # update state
            for j, b in enumerate(beam):
                b.advance(output[:, j], attn[:, j])
                if self.config.sgm.cell == 'lstm':
                    b.beam_update(decState, j)
                else:
                    b.beam_update_gru(decState, j)

        # (3) Package everything up.
        allHyps, allScores, allAttn = [], [], []
        if eval_:
            allWeight = []

        for j in ind:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            if eval_:
                weight = []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att.max(1)[1])
                if eval_:
                    weight.append(att)
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])
            if eval_:
                allWeight.append(weight[0])

        if eval_:
            return allHyps, allAttn, allWeight

        return allHyps, allAttn
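
The rvar/bottle/unbottle helpers above only move between a flat batch-times-beam axis and separate beam and batch axes; a self-contained illustration with toy sizes (the helper names are reused for clarity, the values are arbitrary):

import torch

batch_size, beam_size, hidden = 2, 3, 4

def rvar(a):      # repeat the batch axis once per beam hypothesis
    return a.repeat(1, beam_size, 1)

def bottle(m):    # (beam, batch, ...) -> flat (beam * batch, ...)
    return m.view(batch_size * beam_size, -1)

def unbottle(m):  # flat (beam * batch, ...) -> (beam, batch, ...)
    return m.view(beam_size, batch_size, -1)

enc_state = torch.arange(batch_size * hidden, dtype=torch.float).view(1, batch_size, hidden)
dec_state = rvar(enc_state)           # (1, batch * beam, hidden)

flat = bottle(dec_state.squeeze(0))   # (beam * batch, hidden)
per_beam = unbottle(flat)             # (beam, batch, hidden)

# every beam copy of batch element j still holds batch element j's state
assert torch.equal(per_beam[0], enc_state[0])
assert torch.equal(per_beam[2], enc_state[0])
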
Example #8
    def beam_generate(self, b, beamsz, k):
        if self.args.title:
            tencs, _ = self.tenc(b.src)  # (1, title_len, 500)
            tmask = self.maskFromList(tencs.size(), b.src[1]).unsqueeze(
                1)  # (1, 1, title_len)
        ents = b.ent  # tuple of (ent, phlens, elens)
        entlens = ents[2]
        ents = self.le(ents)
        # ents: (1, entity num, 500) / encoded hidden state of entities in b

        if self.graph:
            gents, glob, grels = self.ge(b.rel[0], b.rel[1], (ents, entlens))
            hx = glob
            # hx = ents.max(dim=1)[0]
            keys, mask = grels
            mask = mask == 0
        else:
            mask = self.maskFromList(ents.size(), entlens)
            hx = ents.max(dim=1)[0]
            keys = ents
        mask = mask.unsqueeze(1)
        if self.args.plan:
            planlogits = self.splan.plan_decode(hx, keys, mask.clone(),
                                                entlens)
            print(planlogits.size())
            sorder = ' '.join(
                [str(x) for x in planlogits.max(1)[1][0].tolist()])
            print(sorder)
            sorder = [x.strip() for x in sorder.split("-1")]
            sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
            mask.fill_(0)
            planplace = torch.zeros(hx.size(0)).long()
            for i, m in enumerate(sorder):
                mask[i][0][m[0]] = 1
        else:
            planlogits = None

        cx = hx.clone().detach()  # (beam size, 500); avoids the torch.tensor(tensor) copy warning
        a = self.attn(hx.unsqueeze(1), keys,
                      mask=mask).squeeze(1)  # (1, 500) / c_g
        if self.args.title:
            a2 = self.attn2(hx.unsqueeze(1), tencs,
                            mask=tmask).squeeze(1)  # (1, 500) / c_s
            a = torch.cat((a, a2), 1)  # (beam size, 1000) / c_t

        outp = torch.LongTensor(ents.size(0), 1).fill_(
            self.starttok).cuda()  # initially, (1, 1) / start token
        beam = None
        for i in range(self.maxlen):
            op = self.emb_w_vertex(outp.clone(), b.nerd)
            # strip the index from the tag => because training also mapped non-indexed => indexed tokens.
            if self.args.plan:
                schange = op == self.args.dottok
                if schange.nonzero().size(0) > 0:
                    print(schange, planplace, sorder)
                    planplace[schange.nonzero().squeeze()] += 1
                    for j in schange.nonzero().squeeze(1):
                        if planplace[j] < len(sorder[j]):
                            mask[j] = 0
                            m = sorder[j][planplace[j]]
                            mask[j][0][sorder[j][planplace[j]]] = 1
            op = self.emb(op).squeeze(1)  # (beam size, 500)
            prev = torch.cat((a, op), 1)  # (beam size, 1000)
            hx, cx = self.lstm(prev, (hx, cx))
            a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
            if self.args.title:
                a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
                a = torch.cat((a, a2), 1)  # (beam size, 1000)
            l = torch.cat((hx, a), 1).unsqueeze(1)  # (beam size, 1, 1500)
            s = torch.sigmoid(self.switch(l))  # (beam size, 1, 1)
            o = self.out(l)  # (beam size, 1, target vocab size)
            o = torch.softmax(o, 2)
            o = s * o
            # compute copy attn
            _, z = self.mattn(l, (ents, entlens))
            # z = torch.softmax(z,2)
            z = (1 - s) * z
            o = torch.cat((o, z),
                          2)  # (beam size, 1, target vocab size + entity num)
            o[:, :, 0].fill_(
                0)  # remove probability for special tokens <unk>, <init>
            o[:, :, 1].fill_(0)
            '''
      if beam:
        for p,q in enumerate(beam.getPrevEnt()):
          o[p,:,q].fill_(0)
        for p,q in beam.getIsStart():
          for r in q:
            o[p,:,r].fill_(0)
      '''

            o = o + (1e-6 * torch.ones_like(o))
            decoded = o.log()
            scores, words = decoded.topk(
                dim=2, k=k)  # (beam size, 1, k), (beam size, 1, k)
            if not beam:
                beam = Beam(words.squeeze(), scores.squeeze(),
                            [hx for i in range(beamsz)],
                            [cx for i in range(beamsz)],
                            [a for i in range(beamsz)], beamsz, k,
                            self.args.ntoks)
                beam.endtok = self.endtok
                beam.eostok = self.eostok
                keys = keys.repeat(len(beam.beam), 1,
                                   1)  # (beam size, adjacency matrix len, 500)
                mask = mask.repeat(
                    len(beam.beam), 1,
                    1)  # (beam size, 1, adjacency matrix len) => all 0?
                if self.args.title:
                    tencs = tencs.repeat(len(beam.beam), 1,
                                         1)  # (beam size, title_len, 500)
                    tmask = tmask.repeat(len(beam.beam), 1,
                                         1)  # (1, 1, title_len)
                if self.args.plan:
                    planplace = planplace.unsqueeze(0).repeat(
                        len(beam.beam), 1)
                    sorder = sorder * len(beam.beam)

                ents = ents.repeat(len(beam.beam), 1,
                                   1)  # (beam size, entity num, 500)
                entlens = entlens.repeat(len(beam.beam))  # (beam size,)
            else:
                if not beam.update(scores, words, hx, cx, a):
                    break
                # if beam size changes (i.e. any of beam ends), change size of weight matrices accordingly
                keys = keys[:len(beam.beam)]
                mask = mask[:len(beam.beam)]
                if self.args.title:
                    tencs = tencs[:len(beam.beam)]
                    tmask = tmask[:len(beam.beam)]
                if self.args.plan:
                    planplace = planplace[:len(beam.beam)]
                    sorder = sorder[0] * len(beam.beam)
                ents = ents[:len(beam.beam)]
                entlens = entlens[:len(beam.beam)]
            outp = beam.getwords()  # (beam size,) / next word for each beam
            hx = beam.geth()  # (beam size, 500)
            cx = beam.getc()  # (beam size, 500)
            a = beam.getlast()  # (beam size, 1000)

        return beam

    def forward(self,
                inputs=None,
                listener_hidden=None,
                listener_outputs=None,
                function=F.log_softmax,
                teacher_forcing_ratio=0.99):
        y_hats, logit = None, None
        decode_results = []
        # Validate Arguments
        batch_size = inputs.size(0)
        max_length = inputs.size(1) - 1  # minus the start of sequence symbol
        # Initiate Speller Hidden State to zeros  :  LxBxH
        speller_hidden = torch.FloatTensor(self.layer_size, batch_size,
                                           self.hidden_size).uniform_(
                                               -1.0, 1.0)  #.cuda()
        # Decide Use Teacher Forcing or Not
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        if self.use_beam_search:
            """Implementation of Beam-Search Decoding"""
            speller_input = inputs[:, 0].unsqueeze(1)
            beam = Beam(k=self.k,
                        speller_hidden=speller_hidden,
                        decoder=self,
                        batch_size=batch_size,
                        max_len=max_length,
                        decode_func=function)
            y_hats = beam.search(speller_input, listener_outputs)
        else:
            # Manual unrolling is used to support random teacher forcing.
            # If teacher_forcing_ratio is True or False instead of a probability, the unrolling can be done in graph
            if use_teacher_forcing:
                speller_input = inputs[:, :-1]  # except </s>
                """ if teacher_forcing, Infer all at once """
                predicted_softmax = self._forward_step(speller_input,
                                                       speller_hidden,
                                                       listener_outputs,
                                                       function=function)
                """Extract Output by Step"""
                for di in range(predicted_softmax.size(1)):
                    step_output = predicted_softmax[:, di, :]
                    decode_results.append(step_output)
            else:
                speller_input = inputs[:, 0].unsqueeze(1)
                for di in range(max_length):
                    predicted_softmax = self._forward_step(speller_input,
                                                           speller_hidden,
                                                           listener_outputs,
                                                           function=function)
                    # (batch_size, classfication_num)
                    step_output = predicted_softmax.squeeze(1)
                    decode_results.append(step_output)
                    speller_input = decode_results[-1].topk(1)[1]

            logit = torch.stack(decode_results, dim=1).to(self.device)
            y_hats = logit.max(-1)[1]
        print("Speller y_hats ====================")
        print(y_hats)

        # parenthesised so training returns (y_hats, logit) and eval returns y_hats only
        return (y_hats, logit) if self.training else y_hats