    def multi_decode(self, src_hids, src_lens):
        # src_hids : list of (batch_size, x_seq_len, D_hid * n_dir)
        # src_lens : list of (batch_size)
        batch_size = src_hids[0].size(0)
        x_seq_lens = [src_hid.size(1) for src_hid in src_hids]
        src_masks = [ xlen_to_inv_mask(src_len, seq_len=x_seq_len) for (src_len, x_seq_len) \
                     in zip(src_lens, x_seq_lens) ] # (batch_size, x_seq_len)
        y_seq_len = self.max_len_gen
        z_ts = []

        for src_len, src_hid in zip(src_lens, src_hids):
            h_index = (src_len - 1)[:, None,
                                    None].repeat(1, 1, src_hid.size(2))
            z_t = torch.cumsum(src_hid,
                               dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
            z_t = z_t.gather(dim=1, index=h_index).view(
                batch_size, -1)  # (batch_size, D_hid * n_dir)
            z_t = torch.div(z_t, src_len[:, None].float())
            z_ts.append(z_t)

        y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
        y_emb = self.emb(y_emb)  # (batch_size, D_emb)
        input_feeds = [
            y_emb.data.new(batch_size, self.D_hid).zero_() for ii in range(5)
        ]

        done = cuda(torch.zeros(batch_size).long())
        eos_tensor = cuda(torch.zeros(batch_size).fill_(self.eos_token)).long()

        msg = []
        for idx in range(y_seq_len):
            probs = []
            for which in range(5):
                input = torch.cat([y_emb, input_feeds[which]],
                                  dim=1)  # (batch_size, D_emb + D_hid)
                trg_hid = self.layers[0](
                    input, z_ts[which])  # (batch_size, D_hid)
                z_ts[which] = trg_hid  # (batch_size, D_hid)

                out, _ = self.attention(
                    trg_hid, src_hids[which],
                    src_masks[which])  # (batch_size, D_hid)
                input_feeds[which] = out
                logit = self.out(out)  # (batch_size, voc_sz_trg)
                probs.append(F.softmax(logit, -1))

            prob = torch.stack(probs, dim=-1).mean(-1)
            tokens = prob.max(dim=1)[1]  # (batch_size)
            msg.append(tokens)

            is_next_eos = (tokens == eos_tensor).long()  # (batch_size)
            done = (done + is_next_eos).clamp(min=0, max=1).long()
            if done.sum() == batch_size:
                break

            y_emb = self.emb(tokens)  # (batch_size, D_emb)

        msg = torch.stack(msg, dim=1)  # (batch_size, y_seq_len)
        return msg
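    # Illustrative standalone sketch (not used by the model): the z_t initialisation
    # above mean-pools the encoder states over each sequence's true length by taking
    # a cumulative sum, gathering the entry at position src_len - 1, and dividing by
    # src_len. The toy check below verifies that this equals an explicit masked mean.
    def _sketch_mean_pool_by_cumsum():
        import torch
        batch_size, x_seq_len, d = 3, 5, 4
        src_hid = torch.randn(batch_size, x_seq_len, d)
        src_len = torch.tensor([5, 2, 3])

        # cumsum + gather at (len - 1), then divide by len
        h_index = (src_len - 1)[:, None, None].repeat(1, 1, d)
        z_t = torch.cumsum(src_hid, dim=1).gather(dim=1, index=h_index).view(batch_size, d)
        z_t = z_t / src_len[:, None].float()

        # reference: explicit masked mean over the valid positions
        mask = (torch.arange(x_seq_len)[None, :] < src_len[:, None]).float()  # (batch_size, x_seq_len)
        ref = (src_hid * mask[:, :, None]).sum(dim=1) / src_len[:, None].float()

        assert torch.allclose(z_t, ref, atol=1e-6)
        return z_t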
    def forward(self, h, h_len, y):
        # h : (batch_size, x_seq_len, D_hid * n_dir)
        # h_len : (batch_size)
        # y : (batch_size, y_seq_len)
        batch_size, x_seq_len = h.size()[:2]
        y_seq_len = y.size()[1]
        xmask = xlen_to_inv_mask(h_len) # (batch_size, x_seq_len)

        # last hidden state of the reverse RNN
        h_index = (h_len - 1)[:,None,None].repeat(1, 1, h.size(2))
        z_t = torch.cumsum(h, dim=1) # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1) # (batch_size, D_hid * n_dir)
        z_t = torch.div( z_t, h_len[:, None].float() )
        z_t = self.hs_ht0( z_t ) # (batch_size, n_layers * D_hid)

        y_emb = self.emb(y) # (batch_size, y_seq_len, D_emb)
        y_emb = F.dropout( y_emb, p=self.drop_ratio, training=self.training )

        ctx_h = self.w_ht( h ) # (batch_size, x_seq_len, D_att), project encoder states into the attention space

        logits = []
        for idx in range(y_seq_len):

            # in (batch_size, 1, D_emb)
            # z_t (n_layers, batch_size, D_hid)
            _, z_t_ = self.rnn( y_emb[:, idx:idx+1, :], z_t )
            # out (batch_size, 1, D_hid)
            # z_t (n_layers, batch_size, D_hid)
            ctx_z_t_ = z_t_.transpose(0,1).contiguous().view(batch_size, -1) \
                    # (batch_size, n_layers * D_hid)
            ctx_s = self.z_to_att( ctx_z_t_ )[:,None,:] # (batch_size, 1, D_att)
            ctx_y = self.y_to_att( y_emb[:, idx:idx+1, :] ) # (batch_size, 1, D_att)
            ctx = F.tanh(ctx_h + ctx_s + ctx_y) # (batch_size, x_seq_len, D_att)

            score = self.att_to_score(ctx).view(batch_size, -1) # (batch_size, x_seq_len)
            score.masked_fill_(xmask, -float('inf'))
            score = F.softmax( score, dim=1 )

            c_t = torch.mul( h, score[:,:,None] ) # (batch_size, x_seq_len, D_hid * n_dir)
            c_t = torch.sum( c_t, 1) # (batch_size, D_hid * n_dir)
            # in (batch_size, 1, D_hid * n_dir)
            # z_t (n_layers, batch_size, D_hid)
            out, z_t = self.rnn2( c_t[:,None,:], z_t_ )
            # out (batch_size, 1, D_hid)
            # z_t (n_layers, batch_size, D_hid)

            #fin_y = self.y_to_out( y_emb[:,idx,:] ) # (batch_size, D_out)
            fin_c = self.c_to_out( c_t ) # (batch_size, D_out)
            fin_s = self.z_to_out( out.view(-1, self.D_hid) ) # (batch_size, D_out)
            fin = F.tanh( fin_c + fin_s )

            logit = self.out( fin ) # (batch_size, voc_sz_trg)
            logits.append( logit )
            # logits : list of (batch_size, voc_sz_trg) vectors

        ans = torch.stack(logits, dim=1) # (batch_size, y_seq_len, voc_sz_trg)
        return ans.view(-1, ans.size(2))
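    # Illustrative standalone sketch (not part of the model): the per-step attention in
    # forward() above is additive -- project encoder states and decoder state into a
    # shared attention space, tanh, score, mask the padded source positions with -inf,
    # and softmax. All module names and sizes below are hypothetical toys.
    def _sketch_masked_additive_attention():
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        batch_size, x_seq_len, d_h, d_att = 2, 4, 6, 5
        h = torch.randn(batch_size, x_seq_len, d_h)          # encoder states
        z = torch.randn(batch_size, d_h)                      # current decoder state
        lens = torch.tensor([4, 2])

        h_to_att = nn.Linear(d_h, d_att)
        z_to_att = nn.Linear(d_h, d_att)
        att_to_score = nn.Linear(d_att, 1)

        ctx = torch.tanh(h_to_att(h) + z_to_att(z)[:, None, :])        # (batch, x_seq_len, d_att)
        score = att_to_score(ctx).squeeze(-1)                          # (batch, x_seq_len)
        pad_mask = torch.arange(x_seq_len)[None, :] >= lens[:, None]   # True on padding
        score = score.masked_fill(pad_mask, -float('inf'))
        attn = F.softmax(score, dim=1)                                 # (batch, x_seq_len)
        c_t = (h * attn[:, :, None]).sum(dim=1)                        # context vector
        return c_t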
    def forward(self, src_hid, src_len, trg_tok):
        # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
        # src_len : (batch_size)
        # trg_tok : (batch_size, y_seq_len)
        batch_size, x_seq_len = src_hid.size()[:2]
        src_mask = xlen_to_inv_mask(
            src_len, seq_len=x_seq_len)  # (batch_size, x_seq_len)
        y_seq_len = trg_tok.size()[1]

        h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
        z_t = torch.cumsum(src_hid,
                           dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1,
                         index=h_index).view(batch_size,
                                             -1)  # (batch_size, D_hid * n_dir)
        z_t = torch.div(z_t, src_len[:, None].float())  # (batch_size, D_hid * n_dir)

        y_emb = self.emb(trg_tok)  # (batch_size, y_seq_len, D_emb)
        y_emb = F.dropout(y_emb, p=self.drop_ratio, training=self.training)

        outs = []
        prev_trg_hids = [z_t for i in range(self.n_layers)]
        input_feed = y_emb.data.new(batch_size, self.D_hid).zero_()

        for idx in range(y_seq_len):
            if self.input_feeding:
                input = torch.cat([y_emb[:, idx, :], input_feed],
                                  dim=1)  # (batch_size, D_emb + D_hid)
            else:
                input = y_emb[:, idx, :]

            for i, rnn in enumerate(self.layers):
                trg_hid = rnn(input, prev_trg_hids[i])  # (batch_size, D_hid)
                input = F.dropout(trg_hid,
                                  p=self.drop_ratio,
                                  training=self.training)
                prev_trg_hids[i] = trg_hid

            if self.attention is not None:
                out, attn_scores = self.attention(
                    trg_hid, src_hid, src_mask)  # (batch_size, D_hid)
            else:
                out = trg_hid

            input_feed = out
            outs.append(out)

        x = torch.cat(outs, dim=1).view(batch_size, y_seq_len, self.D_hid)
        x = self.out(x).view(-1, self.voc_sz_trg)

        return x
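    # Illustrative standalone sketch (not part of the model): one teacher-forced decode
    # step with input feeding, as in forward() above -- the previous attentional output
    # is concatenated to the target embedding before the stacked GRUCells. All sizes
    # and names here are hypothetical.
    def _sketch_input_feeding_step():
        import torch
        import torch.nn as nn
        batch_size, d_emb, d_hid, n_layers = 2, 8, 16, 2
        layers = nn.ModuleList(
            [nn.GRUCell(d_emb + d_hid, d_hid)] +
            [nn.GRUCell(d_hid, d_hid) for _ in range(n_layers - 1)])

        y_emb_t = torch.randn(batch_size, d_emb)              # embedding of target token at step t
        input_feed = torch.zeros(batch_size, d_hid)           # previous attentional output
        prev_hids = [torch.zeros(batch_size, d_hid) for _ in range(n_layers)]

        inp = torch.cat([y_emb_t, input_feed], dim=1)         # (batch_size, d_emb + d_hid)
        for i, rnn in enumerate(layers):
            hid = rnn(inp, prev_hids[i])                      # (batch_size, d_hid)
            prev_hids[i] = hid
            inp = hid                                         # feed up the stack
        return hid, prev_hids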
    def forward(self, batch, en_lm=None, all_img=None, ranker=None):
        """ Return all stuff related to reinforce """
        results = {}
        rewards = {}

        # Speak fr en first
        en_msg, en_msg_len = self.fr_en_speak(batch, is_training=True)
        grounding_results, grounding_rewards = self.get_grounding(
            en_msg,
            en_msg_len,
            batch,
            en_lm=en_lm,
            all_img=all_img,
            ranker=ranker)
        results.update(grounding_results)
        rewards.update(grounding_rewards)

        # Speak De and get reward
        (de, de_len) = batch.de
        batch_size = len(batch)
        de_input, de_target = de[:, :-1], de[:, 1:].contiguous().view(-1)
        de_logits, _ = self.en_de(
            en_msg, en_msg_len,
            de_input)  # (batch_size * (de_seq_len - 1), voc_sz_trg)
        de_nll = F.cross_entropy(de_logits,
                                 de_target,
                                 ignore_index=0,
                                 reduction='none')
        results['ce_loss'] = de_nll.mean()
        de_nll = de_nll.view(
            batch_size, -1).sum(dim=1) / (de_len - 1).float()  # (batch_size)
        rewards['ce'] = -1 * de_nll.detach(
        )  # NOTE Experiment 1 : Reward = NLL_DE

        # Entropy
        if len(self.fr_en.dec.neg_Hs) > 0:  # skip when no entropies were recorded
            neg_Hs = self.fr_en.dec.neg_Hs  # (batch_size, en_msg_len)
            neg_Hs = neg_Hs.mean()  # (1,)
            results["neg_Hs"] = neg_Hs

        # Reward shaping
        R = rewards['ce']
        if self.train_en_lm:
            R += rewards['lm'] * self.en_lm_nll_co
        if self.train_ranker:
            R += rewards['img_pred'] * self.img_pred_loss_co

        if not self.fix_fr2en:
            R_b = self.fr_en.dec.R_b  # (batch_size, en_msg_len)
            en_mask = xlen_to_inv_mask(en_msg_len, R_b.size(1))
            b_loss = ((R[:, None] - R_b)**2)  # (batch_size, en_msg_len)
            b_loss.masked_fill_(en_mask.bool(), 0)  # (batch_size, en_msg_len)
            b_loss = b_loss.sum(dim=1) / (en_msg_len).float()  # (batch_size)
            b_loss = b_loss.mean()  # (1,)
            pg_loss = -1 * self.fr_en.dec.log_probs  # (batch_size, en_msg_len)
            pg_loss = (R[:, None] -
                       R_b).detach() * pg_loss  # (batch_size, en_msg_len)
            pg_loss.masked_fill_(en_mask.bool(), 0)  # (batch_size, en_msg_len)
            pg_loss = pg_loss.sum(dim=1) / (en_msg_len).float()  # (batch_size)
            pg_loss = pg_loss.mean()  # (1,)
            results.update({"pg_loss": pg_loss, "b_loss": b_loss})
        return results
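    # Illustrative standalone sketch (not part of the model): the policy-gradient and
    # baseline losses computed above, on toy tensors. R is a per-sentence reward, R_b a
    # per-token baseline prediction, log_probs the per-token log-probabilities of the
    # sampled message, and the padding mask zeroes out positions past each length.
    def _sketch_reinforce_with_baseline():
        import torch
        batch_size, msg_len = 3, 5
        R = torch.randn(batch_size)                        # sentence-level reward
        R_b = torch.randn(batch_size, msg_len)             # per-token baseline prediction
        log_probs = torch.randn(batch_size, msg_len)       # log pi(token_t | ...)
        lens = torch.tensor([5, 3, 4])
        pad_mask = torch.arange(msg_len)[None, :] >= lens[:, None]   # True on padding

        # baseline regression loss: (R - R_b)^2, averaged over valid tokens
        b_loss = (R[:, None] - R_b) ** 2
        b_loss = b_loss.masked_fill(pad_mask, 0).sum(dim=1) / lens.float()
        b_loss = b_loss.mean()

        # REINFORCE loss: -(R - R_b).detach() * log_prob, averaged the same way
        pg_loss = (R[:, None] - R_b).detach() * (-log_probs)
        pg_loss = pg_loss.masked_fill(pad_mask, 0).sum(dim=1) / lens.float()
        pg_loss = pg_loss.mean()
        return pg_loss, b_loss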
    def beam_search(self, h, h_len, width): # (batch_size, x_seq_len, D_hid * n_dir)
        voc_size, batch_size, x_seq_len = self.voc_sz_trg, h.size()[0], h.size()[1]
        live = [ [ ( 0.0, [ self.init_token ], 2 ) ] for ii in range(batch_size) ]
        dead = [ [] for ii in range(batch_size) ]
        n_dead = [0 for ii in range(batch_size)]
        xmask = xlen_to_inv_mask(h_len)[:,None,:] # (batch_size, 1, x_seq_len)
        max_len_gen = self.max_len_gen if self.msg_len_ratio < 0.0 else int(x_seq_len * self.msg_len_ratio)
        max_len_gen = np.clip(max_len_gen, 2, self.max_len_gen).item()

        h_index = (h_len - 1)[:,None,None].repeat(1, 1, h.size(2))
        z_t = torch.cumsum(h, dim=1) # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1) # (batch_size, D_hid * n_dir)
        z_t = torch.div( z_t, h_len[:, None].float() )
        z_t = self.ctx_to_z0(z_t) # (batch_size, n_layers * D_hid)
        z_t = z_t.view(batch_size, self.n_layers, self.D_hid).transpose(0,1).contiguous() # (n_layers, batch_size, D_hid)

        input = cuda( torch.full((batch_size, 1), self.init_token).long() )
        input = self.emb( input ) # (batch_size, 1, D_emb)

        h_big = h.contiguous().view(-1, self.D_hid * self.n_dir ) \
                # (batch_size * x_seq_len, D_hid * n_dir)
        ctx_h = self.h_to_att( h_big ).view(batch_size, 1, x_seq_len, self.D_att) \
                # NOTE (batch_size, 1, x_seq_len, D_att)

        for tidx in range(max_len_gen):
            cwidth = 1 if tidx == 0 else width
            # input (batch_size * width, 1, D_emb)
            # z_t (n_layers, batch_size * width, D_hid)
            _, z_t_ = self.rnn( input, z_t )
            # out (batch_size * width, 1, D_hid)
            # z_t (n_layers, batch_size * width, D_hid)
            ctx_z_t_ = z_t_.transpose(0,1).contiguous().view(batch_size * cwidth, -1) \
                    # (batch_size * width, n_layers * D_hid)

            ctx_y = self.y_to_att( input.view(-1, self.D_emb) ).view(batch_size, cwidth, 1, self.D_att)
            ctx_s = self.z_to_att( ctx_z_t_ ).view(batch_size, cwidth, 1, self.D_att)
            ctx = F.tanh(ctx_y + ctx_s + ctx_h) # (batch_size, cwidth, x_seq_len, D_att)
            ctx = ctx.view(batch_size * cwidth * x_seq_len, self.D_att)

            score = self.att_to_score(ctx).view(batch_size, -1, x_seq_len) # (batch_size, cwidth, x_seq_len)
            score.masked_fill_(xmask.repeat(1, cwidth, 1), -float('inf'))
            score = F.softmax( score.view(-1, x_seq_len), dim=1 ).view(batch_size, -1, x_seq_len)
            score = score.view(batch_size, cwidth, x_seq_len, 1) # (batch_size, width, x_seq_len, 1)

            c_t = torch.mul( h.view(batch_size, 1, x_seq_len, -1), score ) # (batch_size, width, x_seq_len, D_hid * n_dir)
            c_t = torch.sum( c_t, 2).view(batch_size * cwidth, -1) # (batch_size * width, D_hid * n_dir)
            # c_t (batch_size * width, 1, D_hid * n_dir)
            # z_t (n_layers, batch_size * width, D_hid)
            out, z_t = self.rnn2( c_t[:,None,:], z_t_ )
            # out (batch_size * width, 1, D_hid)
            # z_t (n_layers, batch_size * width, D_hid)

            fin_y = self.y_to_out( input.view(-1, self.D_emb) ) # (batch_size * width, D_out)
            fin_c = self.c_to_out( c_t ) # (batch_size * width, D_out)
            fin_s = self.z_to_out( out.view(-1, self.D_hid) ) # (batch_size * width, D_out)
            fin = F.tanh( fin_y + fin_c + fin_s )

            cur_prob = F.log_softmax( self.out( fin.view(-1, self.D_out) ), dim=1 )\
                    .view(batch_size, cwidth, voc_size).data # (batch_size, width, voc_sz_trg)
            pre_prob = cuda( torch.FloatTensor( [ [ x[0] for x in ee ] for ee in live ] ).view(batch_size, cwidth, 1) ) \
                    # (batch_size, width, 1)
            total_prob = cur_prob + pre_prob # (batch_size, cwidth, voc_size)
            total_prob = total_prob.view(batch_size, -1)

            _, topi_s = total_prob.topk(width, dim=1)
            topv_s = cur_prob.view(batch_size, -1).gather(1, topi_s)
            # (batch_size, width)

            new_live = [ [] for ii in range(batch_size) ]
            for bidx in range(batch_size):
                n_live = width - n_dead[bidx]
                if n_live > 0:
                    tis = topi_s[bidx][:n_live]
                    tvs = topv_s[bidx][:n_live]
                    for eidx, (topi, topv) in enumerate(zip(tis, tvs)): # NOTE max width times
                        if topi % voc_size == self.eoz_token :
                            dead[bidx].append( (  live[bidx][ topi // voc_size ][0] + topv, \
                                                  live[bidx][ topi // voc_size ][1] + [ topi % voc_size ],\
                                                  topi) )
                            n_dead[bidx] += 1
                        else:
                            new_live[bidx].append( (    live[bidx][ topi // voc_size ][0] + topv, \
                                                        live[bidx][ topi // voc_size ][1] + [ topi % voc_size ],\
                                                        topi) )
                while len(new_live[bidx]) < width:
                    new_live[bidx].append( (    -99999999999, \
                                                [0],\
                                                0) )
            live = new_live

            if n_dead == [width for ii in range(batch_size)]:
                break

            in_vocab_idx = [ [ x[2] % voc_size for x in ee ] for ee in live ] # NOTE batch_size first
            input = self.emb( cuda( torch.LongTensor( in_vocab_idx ) ).view(-1) )\
                    .view(-1, 1, self.D_emb) # input (batch_size * width, 1, D_emb)

            in_width_idx = [ [ x[2] // voc_size + bbidx * cwidth for x in ee ] for bbidx, ee in enumerate(live) ] \
                    # live : (batch_size, width)
            z_t = z_t.index_select( 1, cuda( torch.LongTensor( in_width_idx ).view(-1) ) ).\
                    view(self.n_layers, batch_size * width, self.D_hid)
            # h_0 (n_layers, batch_size * width, D_hid)

        for bidx in range(batch_size):
            if n_dead[bidx] < width:
                for didx in range( width - n_dead[bidx] ):
                    (a, b, c) = live[bidx][didx]
                    dead[bidx].append( (a, b, c)  )

        dead_ = [ [ ( a / ( math.pow(5+len(b), self.beam_alpha) / math.pow(5+1, self.beam_alpha) ) , b, c) for (a,b,c) in ee] for ee in dead]
        ans = []
        for dd_ in dead_:
            dd = sorted( dd_, key=operator.itemgetter(0), reverse=True )
            ans.append( dd[0][1] )
        return ans
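    # Illustrative standalone sketch (not part of the model): the length penalty applied
    # to finished beam hypotheses above is the GNMT-style normaliser
    # lp(y) = ((5 + |y|) / (5 + 1)) ** alpha, and each accumulated log-probability is
    # divided by it before the final ranking.
    def _sketch_length_normalised_score(log_prob, tokens, beam_alpha):
        import math
        lp = math.pow(5 + len(tokens), beam_alpha) / math.pow(5 + 1, beam_alpha)
        return log_prob / lp
    # e.g. _sketch_length_normalised_score(-4.2, [2, 17, 9, 3], 0.6)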
    def send(self, h, h_len, y_len, send_method):
        # h : (batch_size, x_seq_len, D_hid * n_dir)
        # h_len : (batch_size)
        batch_size, x_seq_len = h.size()[:2]
        xmask = xlen_to_inv_mask(h_len) # (batch_size, x_seq_len)
        max_len_gen = self.max_len_gen if self.msg_len_ratio < 0.0 else int(y_len.max().item() * self.msg_len_ratio)
        max_len_gen = np.clip(max_len_gen, 2, self.max_len_gen).item()

        h_index = (h_len - 1)[:,None,None].repeat(1, 1, h.size(2))
        z_t = torch.cumsum(h, dim=1) # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1) # (batch_size, D_hid * n_dir)
        z_t = torch.div( z_t, h_len[:, None].float() )
        z_t = self.ctx_to_z0(z_t) # (batch_size, n_layers * D_hid)
        z_t = z_t.view(batch_size, self.n_layers, self.D_hid).transpose(0,1).contiguous() # (n_layers, batch_size, D_hid)

        y_emb = cuda( torch.full((batch_size, 1), self.init_token).long() )
        y_emb = self.emb( y_emb ) # (batch_size, 1, D_emb)
        y_emb = F.dropout( y_emb, p=self.drop_ratio, training=self.training )

        h_big = h.view(-1, h.size(2) ) \
                # (batch_size * x_seq_len, D_hid * n_dir) 
        ctx_h = self.h_to_att( h_big ).view(batch_size, x_seq_len, self.D_att)

        done = cuda( torch.zeros(batch_size).long() )
        seq_lens = cuda( torch.zeros(batch_size).fill_(max_len_gen).long() ) # (batch_size)
        max_seq_lens = cuda( torch.zeros(batch_size).fill_(max_len_gen).long() ) # (batch_size)
        eos_tensor = cuda( torch.zeros(batch_size).fill_( self.eoz_token )).long()
        msg = []
        self.log_probs = [] # (batch_size, seq_len)
        batch_logits = 0

        for idx in range(max_len_gen):
            # in (batch_size, 1, D_emb)
            # z_t (n_layers, batch_size, D_hid)
            _, z_t_ = self.rnn( y_emb, z_t )
            # out (batch_size, 1, D_hid)
            # z_t (n_layers, batch_size, D_hid)
            ctx_z_t_ = z_t_.transpose(0,1).contiguous().view(batch_size, self.n_layers * self.D_hid) \
                    # (batch_size, n_layers * D_hid)

            ctx_y = self.y_to_att( y_emb.view(batch_size, self.D_emb) )[:,None,:]
            ctx_s = self.z_to_att( ctx_z_t_ )[:,None,:]
            ctx = F.tanh(ctx_y + ctx_s + ctx_h)
            ctx = ctx.view(batch_size * x_seq_len, self.D_att)

            score = self.att_to_score(ctx).view(batch_size, -1) # (batch_size, x_seq_len)
            score.masked_fill_(xmask, -float('inf'))
            score = F.softmax( score, dim=1 )
            score = score[:,:,None] # (batch_size, x_seq_len, 1)

            c_t = torch.mul( h, score ) # (batch_size, x_seq_len, D_hid * n_dir)
            c_t = torch.sum( c_t, 1) # (batch_size, D_hid * n_dir)
            # in (batch_size, 1, D_hid * n_dir)
            # z_t (n_layers, batch_size, D_hid)
            out, z_t = self.rnn2( c_t[:,None,:], z_t_ )
            # out (batch_size, 1, D_hid)
            # z_t (n_layers, batch_size, D_hid)

            fin_y = self.y_to_out( y_emb.view(batch_size, self.D_emb) ) # (batch_size, D_out)
            fin_c = self.c_to_out( c_t ) # (batch_size, D_out)
            fin_s = self.z_to_out( out.view(batch_size, self.D_hid) ) # (batch_size, D_out)
            fin = F.tanh( fin_y + fin_c + fin_s )

            logit = self.out( fin ) # (batch_size, voc_sz_trg)

            if send_method == "argmax":
                tokens = logit.data.max(dim=1)[1] # (batch_size)
                tokens_idx = tokens

            elif send_method == "gumbel":
                tokens = gumbel_softmax_hard(logit, self.temp, self.st)\
                        .view(batch_size, self.voc_sz_trg) # (batch_size, voc_sz_trg)
                tokens_idx = (tokens * cuda(torch.arange(self.voc_sz_trg))[None,:]).sum(dim=1).long()

            elif send_method == "reinforce":
                cat = Categorical(logits=logit)
                tokens = cat.sample()
                tokens_idx = tokens
                self.log_probs.append( cat.log_prob(tokens) )

                if self.entropy:
                    batch_logits += (logit * (1-done)[:,None].float()).sum(dim=0)

            msg.append(tokens.unsqueeze(dim=1))

            is_next_eos = ( tokens_idx == eos_tensor ).long() # (batch_size), 1 if eos, 0 otherwise
            done = (done + is_next_eos).clamp(min=0, max=1).long()
            new_seq_lens = max_seq_lens.clone().masked_fill_(mask=is_next_eos.bool(), \
                                                             value=float(idx+1)) # either max or idx+1 if next is eos
            # max_seq_lens : (batch_size)
            seq_lens = torch.min(seq_lens, new_seq_lens)

            if done.sum() == batch_size:
                break

            y_emb = self.emb(tokens)[:,None,:] # (batch_size, 1, D_emb)

        if self.msg_len_ratio > 0.0 :
            seq_lens = torch.clamp( (y_len.float() * self.msg_len_ratio).floor_(), 1, len(msg) ).long()

        msg = torch.cat(msg, dim=1) # (batch_size, seq_len) if argmax/reinforce; (batch_size, seq_len, voc_sz_trg) if gumbel
        if send_method == "reinforce": # want to sum per-token log prob to yield log prob for the whole message sentence
            self.log_probs = torch.stack(self.log_probs, dim=1) # (batch_size, seq_len)

        results = {"msg":msg, "seq_lens":seq_lens}
        if send_method == "reinforce" and self.entropy:
            results.update( {"batch_logits":batch_logits} )
        return results
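    # Illustrative standalone sketch (not part of the model): the "gumbel" branch above
    # relies on a straight-through Gumbel-softmax (gumbel_softmax_hard, defined
    # elsewhere). One common formulation: sample a soft distribution with Gumbel noise,
    # take the argmax as a one-hot, and pass gradients through the soft sample.
    def _sketch_straight_through_gumbel(logits, temperature=1.0):
        import torch
        import torch.nn.functional as F
        # soft sample: softmax((logits + Gumbel noise) / temperature)
        gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
        y_soft = F.softmax((logits + gumbel) / temperature, dim=-1)
        # hard one-hot of the argmax, with a straight-through gradient estimator
        index = y_soft.argmax(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
        return (y_hard - y_soft).detach() + y_soft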
    def beam(self, src_hid, src_len):
        # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
        # src_len : (batch_size)
        batch_size, x_seq_len = src_hid.size()[:2]
        src_mask = xlen_to_inv_mask(
            src_len, seq_len=x_seq_len)  # (batch_size, x_seq_len)
        voc_size, width = self.voc_sz_trg, self.beam_width
        y_seq_len = self.max_len_gen

        h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
        z_t = torch.cumsum(src_hid,
                           dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1,
                         index=h_index).view(batch_size,
                                             -1)  # (batch_size, D_hid * n_dir)
        z_t = torch.div(z_t, src_len[:, None].float())

        y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
        y_emb = self.emb(y_emb)  # (batch_size, D_emb)

        input_feed = y_emb.data.new(batch_size, self.D_hid).zero_()

        live = [[(0.0, [self.init_token], 2)] for ii in range(batch_size)]
        dead = [[] for ii in range(batch_size)]
        n_dead = [0 for ii in range(batch_size)]
        src_hid_ = src_hid
        src_mask_ = src_mask

        for idx in range(y_seq_len):
            cwidth = 1 if idx == 0 else width
            input = torch.cat([y_emb, input_feed],
                              dim=1)  # (batch_size * width, D_emb + D_hid)
            trg_hid = self.layers[0](input, z_t)  # (batch_size * width, D_hid)
            z_t = trg_hid  # (batch_size * width, D_hid)

            out, attn_scores = self.attention(
                trg_hid, src_hid_, src_mask_)  # (batch_size * width, D_hid)
            input_feed = out  # (batch_size * width, D_hid)

            logit = self.out(out)  # (batch_size * width, voc_sz_trg)
            cur_prob = F.log_softmax(logit,
                                     dim=1).view(batch_size, cwidth,
                                                 self.voc_sz_trg)
            pre_prob = cuda( torch.FloatTensor( [ [ x[0] for x in ee ] for ee in live ] )\
                            .view(batch_size, cwidth, 1) ) # (batch_size, width, 1)
            total_prob = cur_prob + pre_prob  # (batch_size, cwidth, voc_sz)
            total_prob = total_prob.view(batch_size,
                                         -1)  # (batch_size, cwidth * voc_sz)

            topi_s = total_prob.topk(width, dim=1)[1]
            topv_s = cur_prob.view(batch_size, -1).gather(1, topi_s)

            new_live = [[] for ii in range(batch_size)]
            for bidx in range(batch_size):
                n_live = width - n_dead[bidx]
                if n_live > 0:
                    tis = topi_s[bidx][:n_live].cpu().numpy().tolist()
                    tvs = topv_s[bidx][:n_live].cpu().numpy().tolist()
                    for eidx, (topi, topv) in enumerate(zip(tis, tvs)):
                        if topi % voc_size == self.eos_token:
                            dead[bidx].append(
                                (live[bidx][topi // voc_size][0] + topv,
                                 live[bidx][topi // voc_size][1] +
                                 [topi % voc_size], topi))
                            n_dead[bidx] += 1
                        else:
                            new_live[bidx].append(
                                (live[bidx][topi // voc_size][0] + topv,
                                 live[bidx][topi // voc_size][1] +
                                 [topi % voc_size], topi))
                while len(new_live[bidx]) < width:
                    new_live[bidx].append((-99999999999, [0], 0))
            live = new_live
            if n_dead == [width for ii in range(batch_size)]:
                break

            in_vocab_idx = np.array([[x[2] % voc_size for x in ee]
                                     for ee in live])  # NOTE batch_size first
            y_emb = self.emb( cuda( torch.LongTensor( in_vocab_idx ) ).view(-1) )\
                    .view(-1, self.D_emb) # input (batch_size * width, 1, D_emb)

            in_width_idx = np.array( [ [ x[2] // voc_size + bbidx * cwidth for x in ee ] \
                            for bbidx, ee in enumerate(live) ] )
            in_width_idx = cuda(torch.LongTensor(in_width_idx).view(-1))
            z_t = z_t.index_select(0,
                                   in_width_idx).view(batch_size * width,
                                                      self.D_hid)
            input_feed = input_feed.index_select(0, in_width_idx).view(
                batch_size * width, self.D_hid)
            src_hid_ = src_hid_.index_select(0, in_width_idx).view(
                batch_size * width, x_seq_len, -1)  # keep D_hid * n_dir in the last dim
            src_mask_ = src_mask_.index_select(0, in_width_idx).view(
                batch_size * width, x_seq_len)

        for bidx in range(batch_size):
            if n_dead[bidx] < width:
                for didx in range(width - n_dead[bidx]):
                    (a, b, c) = live[bidx][didx]
                    dead[bidx].append((a, b, c))

        dead_ = [ [ ( a / ( math.pow(5+len(b), self.beam_alpha) / math.pow(5+1, self.beam_alpha) ) , b, c)\
                   for (a,b,c) in ee] for ee in dead]
        ans = [[], [], [], [], []]  # top-5 hypotheses per example (assumes beam width >= 5)
        for dd_ in dead_:
            dd = sorted(dd_, key=operator.itemgetter(0), reverse=True)
            for idx in range(5):
                ans[idx].append(dd[idx][1])
            #ans.append( dd[0][1] )
        return ans
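    # Illustrative standalone sketch (not part of the model): the beam bookkeeping above
    # flattens the (width, voc_size) score grid per example, so a top-k index `topi`
    # encodes the parent hypothesis as topi // voc_size and the emitted token as
    # topi % voc_size. A toy example with width 2 and a vocabulary of 3:
    def _sketch_beam_index_arithmetic():
        import torch
        voc_size = 3
        total_prob = torch.tensor([[0.1, 0.7, 0.2],     # scores continuing hypothesis 0
                                   [0.6, 0.1, 0.3]])    # scores continuing hypothesis 1
        flat = total_prob.view(-1)                       # (width * voc_size,)
        topv, topi = flat.topk(2)
        parents = [int(i) // voc_size for i in topi]     # which hypothesis each pick extends
        tokens = [int(i) % voc_size for i in topi]       # which token each pick emits
        return parents, tokens                           # -> [0, 1], [1, 0]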
    def send(self,
             src_hid,
             src_len,
             trg_len,
             send_method,
             value_fn=None,
             gumbel_temp=1,
             dot_token=None):
        # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
        # src_len : (batch_size)
        batch_size, x_seq_len = src_hid.size()[:2]
        src_mask = xlen_to_inv_mask(
            src_len, seq_len=x_seq_len)  # (batch_size, x_seq_len)
        y_seq_len = math.floor(trg_len.max().item() * self.msg_len_ratio) \
            if self.msg_len_ratio > 0 else self.max_len_gen

        h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
        z_t = torch.cumsum(src_hid,
                           dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1,
                         index=h_index).view(batch_size,
                                             -1)  # (batch_size, D_hid * n_dir)
        z_t = torch.div(z_t, src_len[:, None].float())

        y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
        y_emb = self.emb(y_emb)  # (batch_size, D_emb)
        #y_emb = F.dropout( y_emb, p=self.drop_ratio, training=self.training )

        prev_trg_hids = [z_t for _ in range(self.n_layers)]
        trg_hid = prev_trg_hids[0]
        input_feed = y_emb.data.new(batch_size, self.D_hid).zero_()

        done = cuda(torch.zeros(batch_size).long())
        seq_lens = cuda(
            torch.zeros(batch_size).fill_(y_seq_len).long())  # (max_len)
        max_seq_lens = cuda(
            torch.zeros(batch_size).fill_(y_seq_len).long())  # (max_len)
        eos_tensor = cuda(torch.zeros(batch_size).fill_(self.eos_token)).long()

        self.log_probs, self.R_b, self.neg_Hs, self.gumbel_tokens = [], [], [], []
        msg = []

        for idx in range(y_seq_len):
            if self.input_feeding:
                input = torch.cat([y_emb, input_feed],
                                  dim=1)  # (batch_size, D_emb + D_hid)
            else:
                input = y_emb

            if send_method == "reinforce" and value_fn:
                if self.input_feeding:
                    value_fn_input = [y_emb, input_feed]
                else:
                    value_fn_input = [y_emb, trg_hid]
                value_fn_input = torch.cat(value_fn_input, dim=1).detach()
                self.R_b.append(
                    value_fn(value_fn_input).view(-1))  # (batch_size)

            for i, rnn in enumerate(self.layers):
                trg_hid = rnn(input, prev_trg_hids[i])  # (batch_size, D_hid)
                #input = F.dropout(trg_hid, p=self.drop_ratio, training=self.training)
                input = trg_hid
                prev_trg_hids[i] = trg_hid

            if self.attention is not None:
                out, attn_scores = self.attention(
                    trg_hid, src_hid, src_mask)  # (batch_size, D_hid)
            else:
                out = trg_hid

            input_feed = out
            logit = self.out(out)  # (batch_size, voc_sz_trg)

            if send_method == "argmax":
                tokens = logit.max(dim=1)[1]  # (batch_size)
                tok_dist = Categorical(logits=logit)
                self.neg_Hs.append(-1 * tok_dist.entropy())

            elif send_method == "reinforce":
                tok_dist = Categorical(logits=logit)
                tokens = tok_dist.sample()
                self.log_probs.append(
                    tok_dist.log_prob(tokens))  # (batch_size)
                self.neg_Hs.append(-1 * tok_dist.entropy())

            elif send_method == "gumbel":
                y = gumbel_softmax(logit, gumbel_temp)
                tok_dist = Categorical(probs=y)
                self.neg_Hs.append(-1 * tok_dist.entropy())
                tokens = torch.argmax(y, dim=1)
                tokens_oh = cuda(torch.zeros(y.size())).scatter_(
                    1, tokens.unsqueeze(-1), 1)
                self.gumbel_tokens.append((tokens_oh - y).detach() + y)
            else:
                raise ValueError

            msg.append(tokens)
            is_next_eos = (tokens == eos_tensor).long()  # (batch_size)
            new_seq_lens = max_seq_lens.clone().masked_fill_(
                mask=is_next_eos.bool(),
                value=float(idx +
                            1))  # NOTE idx+1 ensures this is valid length
            seq_lens = torch.min(seq_lens, new_seq_lens)  # (contains lengths)

            done = (done + is_next_eos).clamp(min=0, max=1).long()
            if done.sum() == batch_size:
                break

            y_emb = self.emb(tokens)  # (batch_size, D_emb)

        msg = torch.stack(msg, dim=1)  # (batch_size, y_seq_len)
        self.neg_Hs = torch.stack(self.neg_Hs,
                                  dim=1)  # (batch_size, y_seq_len)
        if send_method == "reinforce":
            self.log_probs = torch.stack(self.log_probs,
                                         dim=1)  # (batch_size, y_seq_len)
            self.R_b = torch.stack(
                self.R_b,
                dim=1) if value_fn else self.R_b  # (batch_size, y_seq_len)

        if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
            self.gumbel_tokens = torch.stack(self.gumbel_tokens, dim=1)

        result = {"msg": msg.clone(), "new_seq_lens": seq_lens.clone()}

        # Trim sequence length at the first dot token, if one is provided
        if dot_token is not None:
            seq_lens = first_appear_indice(msg, dot_token) + 1

        # Jason's trick: trim the sentence length based on the ground-truth length
        if self.msg_len_ratio > 0:
            en_ref_len = torch.floor(trg_len.float() *
                                     self.msg_len_ratio).long()
            seq_lens = torch.min(seq_lens, en_ref_len)

        # Clamp lengths to the valid range [min_len_gen, msg length]
        seq_lens = torch.max(
            seq_lens,
            seq_lens.new(seq_lens.size()).fill_(self.min_len_gen))
        seq_lens = torch.min(seq_lens,
                             seq_lens.new(seq_lens.size()).fill_(msg.shape[1]))

        # Ensure every message ends with EOS (append EOS if needed, PAD otherwise)
        ends_with_eos = (msg.gather(
            dim=1, index=(seq_lens - 1)[:,
                                        None]).view(-1) == eos_tensor).long()
        eos_or_pad = ends_with_eos * self.pad_token + (
            1 - ends_with_eos) * self.eos_token
        msg = torch.cat(
            [msg, msg.new(batch_size, 1).fill_(self.pad_token)], dim=1)
        if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
            pad_gumbel_tokens = cuda(
                torch.zeros(msg.shape[0], 1, self.gumbel_tokens.shape[2]))
            pad_gumbel_tokens[:, :, self.pad_token] = 1
            self.gumbel_tokens = torch.cat(
                [self.gumbel_tokens, pad_gumbel_tokens], dim=1)

        # (batch_size, y_seq_len + 1)
        msg.scatter_(dim=1, index=seq_lens[:, None], src=eos_or_pad[:, None])
        if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
            batch_ids = cuda(torch.arange(msg.shape[0]))
            self.gumbel_tokens[batch_ids, seq_lens] = 0.
            self.gumbel_tokens[batch_ids, seq_lens, eos_or_pad] = 1.

        seq_lens = seq_lens + (1 - ends_with_eos)
        msg_mask = xlen_to_inv_mask(
            seq_lens, seq_len=msg.size(1))  # (batch_size, y_seq_len + 1)
        msg.masked_fill_(msg_mask.bool(), self.pad_token)
        result.update({"msg": msg, "new_seq_lens": seq_lens})
        return result
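    # Illustrative standalone sketch (not part of the model): the message finalisation
    # at the end of send() appends one extra column, writes EOS at position seq_len for
    # messages that did not already end in EOS, and pads everything past the final
    # length. The token ids below (eos=3, pad=0) are hypothetical.
    def _sketch_finalise_message():
        import torch
        eos_token, pad_token = 3, 0
        msg = torch.tensor([[7, 5, 3, 9],      # already ends with EOS at position 2
                            [4, 8, 6, 2]])     # never produced EOS
        seq_lens = torch.tensor([3, 4])
        batch_size = msg.size(0)

        ends_with_eos = (msg.gather(dim=1, index=(seq_lens - 1)[:, None]).view(-1) == eos_token).long()
        eos_or_pad = ends_with_eos * pad_token + (1 - ends_with_eos) * eos_token

        msg = torch.cat([msg, msg.new_full((batch_size, 1), pad_token)], dim=1)
        msg.scatter_(dim=1, index=seq_lens[:, None], src=eos_or_pad[:, None])
        seq_lens = seq_lens + (1 - ends_with_eos)          # count the appended EOS

        pad_mask = torch.arange(msg.size(1))[None, :] >= seq_lens[:, None]
        msg = msg.masked_fill(pad_mask, pad_token)
        return msg, seq_lens                                # -> [[7, 5, 3, 0, 0], [4, 8, 6, 2, 3]]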