import math
import operator

import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

# Project helpers used below (cuda, xlen_to_inv_mask, gumbel_softmax,
# gumbel_softmax_hard, first_appear_indice) are assumed to come from the repo's utilities.


def multi_decode(self, src_hids, src_lens):
    # src_hids : list of (batch_size, x_seq_len, D_hid * n_dir)
    # src_lens : list of (batch_size)
    batch_size = src_hids[0].size(0)
    x_seq_lens = [src_hid.size(1) for src_hid in src_hids]
    src_masks = [xlen_to_inv_mask(src_len, seq_len=x_seq_len)
                 for (src_len, x_seq_len) in zip(src_lens, x_seq_lens)]  # (batch_size, x_seq_len)
    y_seq_len = self.max_len_gen

    z_ts = []
    for src_len, src_hid in zip(src_lens, src_hids):
        h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
        z_t = torch.cumsum(src_hid, dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1)  # (batch_size, D_hid * n_dir)
        z_t = torch.div(z_t, src_len[:, None].float())
        z_ts.append(z_t)

    y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
    y_emb = self.emb(y_emb)  # (batch_size, D_emb)
    input_feeds = [y_emb.data.new(batch_size, self.D_hid).zero_() for ii in range(5)]

    done = cuda(torch.zeros(batch_size).long())
    eos_tensor = cuda(torch.zeros(batch_size).fill_(self.eos_token)).long()

    msg = []
    for idx in range(y_seq_len):
        probs = []
        for which in range(5):
            input = torch.cat([y_emb, input_feeds[which]], dim=1)  # (batch_size, D_emb + D_hid)
            trg_hid = self.layers[0](input, z_ts[which])  # (batch_size, D_hid)
            z_ts[which] = trg_hid  # (batch_size, D_hid)
            out, _ = self.attention(trg_hid, src_hids[which], src_masks[which])  # (batch_size, D_hid)
            input_feeds[which] = out
            logit = self.out(out)  # (batch_size, voc_sz_trg)
            probs.append(F.softmax(logit, -1))

        prob = torch.stack(probs, dim=-1).mean(-1)
        tokens = prob.max(dim=1)[1]  # (batch_size)
        msg.append(tokens)

        is_next_eos = (tokens == eos_tensor).long()  # (batch_size)
        done = (done + is_next_eos).clamp(min=0, max=1).long()
        if done.sum() == batch_size:
            break

        y_emb = self.emb(tokens)  # (batch_size, D_emb)

    msg = torch.stack(msg, dim=1)  # (batch_size, y_seq_len)
    return msg
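# Minimal standalone sketch (not part of the model) of the cumsum/gather idiom used
# throughout this file to mean-pool each source sequence up to its true length:
# cumulative sum along time, gather the prefix sum at index (len - 1), then divide
# by the length. The tensors below are toy inputs made up for illustration.
import torch


def mean_pool_by_length(src_hid, src_len):
    # src_hid : (batch_size, x_seq_len, D)   src_len : (batch_size)
    batch_size = src_hid.size(0)
    h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
    z = torch.cumsum(src_hid, dim=1)                          # running sums over time
    z = z.gather(dim=1, index=h_index).view(batch_size, -1)   # sum of the first src_len steps
    return z / src_len[:, None].float()                       # mean over valid positions only


if __name__ == "__main__":
    hid = torch.randn(2, 5, 4)
    lens = torch.tensor([3, 5])
    pooled = mean_pool_by_length(hid, lens)
    # equivalent check against an explicit masked mean
    assert torch.allclose(pooled[0], hid[0, :3].mean(dim=0), atol=1e-6)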
def forward(self, h, h_len, y):
    # h     : (batch_size, x_seq_len, D_hid * n_dir)
    # h_len : (batch_size)
    # y     : (batch_size, y_seq_len)
    batch_size, x_seq_len = h.size()[:2]
    y_seq_len = y.size()[1]
    xmask = xlen_to_inv_mask(h_len)  # (batch_size, x_seq_len)

    # last hidden state of the reverse RNN
    h_index = (h_len - 1)[:, None, None].repeat(1, 1, h.size(2))
    z_t = torch.cumsum(h, dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
    z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1)  # (batch_size, D_hid * n_dir)
    z_t = torch.div(z_t, h_len[:, None].float())
    z_t = self.hs_ht0(z_t)  # (batch_size, n_layers * D_hid)
    z_t = z_t.view(batch_size, self.n_layers, self.D_hid).transpose(0, 1).contiguous()
    # (n_layers, batch_size, D_hid)

    y_emb = self.emb(y)  # (batch_size, y_seq_len, D_emb)
    y_emb = F.dropout(y_emb, p=self.drop_ratio, training=self.training)

    ctx_h = h  # (batch_size, x_seq_len, D_hid * n_dir)
    logits = []
    for idx in range(y_seq_len):
        # in  (batch_size, 1, D_emb)
        # z_t (n_layers, batch_size, D_hid)
        _, z_t_ = self.rnn(y_emb[:, idx:idx + 1, :], z_t)
        # out (batch_size, 1, D_hid)
        # z_t (n_layers, batch_size, D_hid)

        ctx_z_t_ = z_t_.transpose(0, 1).contiguous().view(batch_size, -1)  # (batch_size, n_layers * D_hid)
        ctx_s = self.z_to_att(ctx_z_t_)[:, None, :]  # (batch_size, 1, D_att)
        ctx_y = self.y_to_att(y_emb[:, idx:idx + 1, :])  # (batch_size, 1, D_att)
        ctx = F.tanh(ctx_h + ctx_s + ctx_y)  # (batch_size, x_seq_len, D_att)
        score = self.att_to_score(ctx).view(batch_size, -1)  # (batch_size, x_seq_len)
        score.masked_fill_(xmask, -float('inf'))
        score = F.softmax(score, dim=1)

        c_t = torch.mul(h, score[:, :, None])  # (batch_size, x_seq_len, D_hid * n_dir)
        c_t = torch.sum(c_t, 1)  # (batch_size, D_hid * n_dir)

        # in  (batch_size, 1, D_hid * n_dir)
        # z_t (n_layers, batch_size, D_hid)
        out, z_t = self.rnn2(c_t[:, None, :], z_t_)
        # out (batch_size, 1, D_hid)
        # z_t (n_layers, batch_size, D_hid)

        #fin_y = self.y_to_out( y_emb[:,idx,:] )  # (batch_size, D_out)
        fin_c = self.c_to_out(c_t)  # (batch_size, D_out)
        fin_s = self.z_to_out(out.view(-1, self.D_hid))  # (batch_size, D_out)
        fin = F.tanh(fin_c + fin_s)

        logit = self.out(fin)  # (batch_size, voc_sz_trg)
        logits.append(logit)

    # logits : list of (batch_size, voc_sz_trg) vectors
    ans = torch.stack(logits, dim=1)  # (batch_size, y_seq_len, voc_sz_trg)
    return ans.view(-1, ans.size(2))
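# Minimal sketch of the masked additive-attention scoring used in the decoding loop
# above, factored out as a free function. It assumes inv_mask follows the convention
# implied by the code (1 at padded source positions, 0 at valid ones); att_to_score is
# any module mapping D_att -> 1. Names and toy inputs below are illustrative only.
import torch
import torch.nn as nn
import torch.nn.functional as F


def additive_attention_weights(ctx_h, ctx_s, ctx_y, att_to_score, inv_mask):
    # ctx_h : (batch, x_seq_len, D_att); ctx_s, ctx_y : (batch, 1, D_att)
    score = att_to_score(torch.tanh(ctx_h + ctx_s + ctx_y)).squeeze(-1)  # (batch, x_seq_len)
    score = score.masked_fill(inv_mask.bool(), -float('inf'))            # drop padded steps
    return F.softmax(score, dim=1)                                       # (batch, x_seq_len)


if __name__ == "__main__":
    batch, x_seq_len, D_att = 2, 4, 8
    att_to_score = nn.Linear(D_att, 1)
    ctx_h = torch.randn(batch, x_seq_len, D_att)
    ctx_s = torch.randn(batch, 1, D_att)
    ctx_y = torch.randn(batch, 1, D_att)
    inv_mask = torch.tensor([[0, 0, 1, 1], [0, 0, 0, 0]])  # row 0 has two padded steps
    w = additive_attention_weights(ctx_h, ctx_s, ctx_y, att_to_score, inv_mask)
    assert torch.allclose(w.sum(dim=1), torch.ones(batch))
    assert w[0, 2:].sum() == 0  # no weight on padded positions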
def forward(self, src_hid, src_len, trg_tok):
    # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
    # src_len : (batch_size)
    # trg_tok : (batch_size, y_seq_len)
    batch_size, x_seq_len = src_hid.size()[:2]
    src_mask = xlen_to_inv_mask(src_len, seq_len=x_seq_len)  # (batch_size, x_seq_len)
    y_seq_len = trg_tok.size()[1]

    h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
    z_t = torch.cumsum(src_hid, dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
    z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1)  # (batch_size, D_hid * n_dir)
    z_t = torch.div(z_t, src_len[:, None].float())  # (batch_size, D_hid)

    y_emb = self.emb(trg_tok)  # (batch_size, y_seq_len, D_emb)
    y_emb = F.dropout(y_emb, p=self.drop_ratio, training=self.training)

    outs = []
    prev_trg_hids = [z_t for i in range(self.n_layers)]
    input_feed = y_emb.data.new(batch_size, self.D_hid).zero_()

    for idx in range(y_seq_len):
        if self.input_feeding:
            input = torch.cat([y_emb[:, idx, :], input_feed], dim=1)  # (batch_size, D_emb + D_hid)
        else:
            input = y_emb[:, idx, :]

        for i, rnn in enumerate(self.layers):
            trg_hid = rnn(input, prev_trg_hids[i])  # (batch_size, D_hid)
            input = F.dropout(trg_hid, p=self.drop_ratio, training=self.training)
            prev_trg_hids[i] = trg_hid

        if self.attention is not None:
            out, attn_scores = self.attention(trg_hid, src_hid, src_mask)  # (batch_size, D_hid)
        else:
            out = trg_hid

        input_feed = out
        outs.append(out)

    x = torch.cat(outs, dim=1).view(batch_size, y_seq_len, self.D_hid)
    x = self.out(x).view(-1, self.voc_sz_trg)
    return x
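# The functions in this file rely on a project helper, xlen_to_inv_mask, that is not
# shown here. A plausible reconstruction from its call sites (its output is used with
# masked_fill_ to suppress padded positions, so 1/True must mark padding) is sketched
# below; the real helper may differ in dtype or signature.
import torch


def xlen_to_inv_mask_sketch(lens, seq_len=None):
    # lens : (batch_size) true lengths; returns (batch_size, seq_len) with 1 at padded slots
    seq_len = int(lens.max().item()) if seq_len is None else seq_len
    positions = torch.arange(seq_len, device=lens.device)[None, :]  # (1, seq_len)
    return (positions >= lens[:, None]).long()


if __name__ == "__main__":
    print(xlen_to_inv_mask_sketch(torch.tensor([2, 4]), seq_len=4))
    # tensor([[0, 0, 1, 1],
    #         [0, 0, 0, 0]])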
def forward(self, batch, en_lm=None, all_img=None, ranker=None):
    """ Return all stuff related to reinforce """
    results = {}
    rewards = {}

    # Speak fr -> en first
    en_msg, en_msg_len = self.fr_en_speak(batch, is_training=True)
    grounding_results, grounding_rewards = self.get_grounding(
        en_msg, en_msg_len, batch, en_lm=en_lm, all_img=all_img, ranker=ranker)
    results.update(grounding_results)
    rewards.update(grounding_rewards)

    # Speak De and get reward
    (de, de_len) = batch.de
    batch_size = len(batch)
    de_input, de_target = de[:, :-1], de[:, 1:].contiguous().view(-1)
    de_logits, _ = self.en_de(en_msg, en_msg_len, de_input)  # (batch_size * de_seq_len, vocab_size)
    de_nll = F.cross_entropy(de_logits, de_target, ignore_index=0, reduction='none')
    results['ce_loss'] = de_nll.mean()
    de_nll = de_nll.view(batch_size, -1).sum(dim=1) / (de_len - 1).float()  # (batch_size)
    rewards['ce'] = -1 * de_nll.detach()  # NOTE Experiment 1 : Reward = NLL_DE

    # Entropy
    if len(self.fr_en.dec.neg_Hs) > 0:
        neg_Hs = self.fr_en.dec.neg_Hs  # (batch_size, en_msg_len)
        neg_Hs = neg_Hs.mean()  # (1,)
        results["neg_Hs"] = neg_Hs

    # Reward shaping
    R = rewards['ce']
    if self.train_en_lm:
        R += rewards['lm'] * self.en_lm_nll_co
    if self.train_ranker:
        R += rewards['img_pred'] * self.img_pred_loss_co

    if not self.fix_fr2en:
        R_b = self.fr_en.dec.R_b  # (batch_size, en_msg_len)
        en_mask = xlen_to_inv_mask(en_msg_len, R_b.size(1))

        b_loss = (R[:, None] - R_b) ** 2  # (batch_size, en_msg_len)
        b_loss.masked_fill_(en_mask.bool(), 0)  # (batch_size, en_msg_len)
        b_loss = b_loss.sum(dim=1) / en_msg_len.float()  # (batch_size)
        b_loss = b_loss.mean()  # (1,)

        pg_loss = -1 * self.fr_en.dec.log_probs  # (batch_size, en_msg_len)
        pg_loss = (R[:, None] - R_b).detach() * pg_loss  # (batch_size, en_msg_len)
        pg_loss.masked_fill_(en_mask.bool(), 0)  # (batch_size, en_msg_len)
        pg_loss = pg_loss.sum(dim=1) / en_msg_len.float()  # (batch_size)
        pg_loss = pg_loss.mean()  # (1,)
        results.update({"pg_loss": pg_loss, "b_loss": b_loss})

    return results
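# Standalone sketch of the policy-gradient and baseline losses computed above, with
# dummy tensors in place of the decoder's log_probs / R_b buffers. It assumes the same
# conventions as the code: en_mask is 1 at padded message positions, R is a per-sentence
# reward, and R_b is a per-token value-function baseline.
import torch


def pg_and_baseline_loss(log_probs, R, R_b, en_mask, en_msg_len):
    # log_probs, R_b, en_mask : (batch, msg_len)   R, en_msg_len : (batch)
    advantage = R[:, None] - R_b
    b_loss = advantage.pow(2).masked_fill(en_mask.bool(), 0).sum(dim=1) / en_msg_len.float()
    pg_loss = (-log_probs * advantage.detach()).masked_fill(en_mask.bool(), 0)
    pg_loss = pg_loss.sum(dim=1) / en_msg_len.float()
    return pg_loss.mean(), b_loss.mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    log_probs = torch.randn(3, 5)
    R, R_b = torch.randn(3), torch.randn(3, 5)
    en_msg_len = torch.tensor([3, 5, 4])
    en_mask = (torch.arange(5)[None, :] >= en_msg_len[:, None]).long()
    print(pg_and_baseline_loss(log_probs, R, R_b, en_mask, en_msg_len))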
def beam_search(self, h, h_len, width):
    # h : (batch_size, x_seq_len, D_hid * n_dir)
    voc_size, batch_size, x_seq_len = self.voc_sz_trg, h.size()[0], h.size()[1]
    live = [[(0.0, [self.init_token], 2)] for ii in range(batch_size)]
    dead = [[] for ii in range(batch_size)]
    n_dead = [0 for ii in range(batch_size)]
    xmask = xlen_to_inv_mask(h_len)[:, None, :]  # (batch_size, 1, x_seq_len)

    max_len_gen = self.max_len_gen if self.msg_len_ratio < 0.0 \
        else int(x_seq_len * self.msg_len_ratio)
    max_len_gen = np.clip(max_len_gen, 2, self.max_len_gen).item()

    h_index = (h_len - 1)[:, None, None].repeat(1, 1, h.size(2))
    z_t = torch.cumsum(h, dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
    z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1)  # (batch_size, D_hid * n_dir)
    z_t = torch.div(z_t, h_len[:, None].float())
    z_t = self.ctx_to_z0(z_t)  # (batch_size, n_layers * D_hid)
    z_t = z_t.view(batch_size, self.n_layers, self.D_hid).transpose(0, 1).contiguous()
    # (n_layers, batch_size, D_hid)

    input = cuda(torch.full((batch_size, 1), self.init_token).long())
    input = self.emb(input)  # (batch_size, 1, D_emb)

    h_big = h.contiguous().view(-1, self.D_hid * self.n_dir)  # (batch_size * x_seq_len, D_hid * n_dir)
    ctx_h = self.h_to_att(h_big).view(batch_size, 1, x_seq_len, self.D_att)  # NOTE (batch_size, 1, x_seq_len, D_att)

    for tidx in range(max_len_gen):
        cwidth = 1 if tidx == 0 else width

        # input (batch_size * width, 1, D_emb)
        # z_t   (n_layers, batch_size * width, D_hid)
        _, z_t_ = self.rnn(input, z_t)
        # out   (batch_size * width, 1, D_hid)
        # z_t   (n_layers, batch_size * width, D_hid)

        ctx_z_t_ = z_t_.transpose(0, 1).contiguous().view(batch_size * cwidth, -1)  # (batch_size * width, n_layers * D_hid)
        ctx_y = self.y_to_att(input.view(-1, self.D_emb)).view(batch_size, cwidth, 1, self.D_att)
        ctx_s = self.z_to_att(ctx_z_t_).view(batch_size, cwidth, 1, self.D_att)
        ctx = F.tanh(ctx_y + ctx_s + ctx_h)  # (batch_size, cwidth, x_seq_len, D_att)
        ctx = ctx.view(batch_size * cwidth * x_seq_len, self.D_att)

        score = self.att_to_score(ctx).view(batch_size, -1, x_seq_len)  # (batch_size, cwidth, x_seq_len)
        score.masked_fill_(xmask.repeat(1, cwidth, 1), -float('inf'))
        score = F.softmax(score.view(-1, x_seq_len), dim=1).view(batch_size, -1, x_seq_len)
        score = score.view(batch_size, cwidth, x_seq_len, 1)  # (batch_size, width, x_seq_len, 1)

        c_t = torch.mul(h.view(batch_size, 1, x_seq_len, -1), score)  # (batch_size, width, x_seq_len, D_hid * n_dir)
        c_t = torch.sum(c_t, 2).view(batch_size * cwidth, -1)  # (batch_size * width, D_hid * n_dir)

        # c_t (batch_size * width, 1, D_hid * n_dir)
        # z_t (n_layers, batch_size * width, D_hid)
        out, z_t = self.rnn2(c_t[:, None, :], z_t_)
        # out (batch_size * width, 1, D_hid)
        # z_t (n_layers, batch_size * width, D_hid)

        fin_y = self.y_to_out(input.view(-1, self.D_emb))  # (batch_size * width, D_out)
        fin_c = self.c_to_out(c_t)  # (batch_size * width, D_out)
        fin_s = self.z_to_out(out.view(-1, self.D_hid))  # (batch_size * width, D_out)
        fin = F.tanh(fin_y + fin_c + fin_s)

        cur_prob = F.log_softmax(self.out(fin.view(-1, self.D_out)), dim=1)\
            .view(batch_size, cwidth, voc_size).data  # (batch_size, width, voc_sz_trg)
        pre_prob = cuda(torch.FloatTensor([[x[0] for x in ee] for ee in live])
                        .view(batch_size, cwidth, 1))  # (batch_size, width, 1)
        total_prob = cur_prob + pre_prob  # (batch_size, cwidth, voc_size)
        total_prob = total_prob.view(batch_size, -1)

        _, topi_s = total_prob.topk(width, dim=1)
        topv_s = cur_prob.view(batch_size, -1).gather(1, topi_s)  # (batch_size, width)

        new_live = [[] for ii in range(batch_size)]
        for bidx in range(batch_size):
            n_live = width - n_dead[bidx]
            if n_live > 0:
                tis = topi_s[bidx][:n_live].cpu().numpy().tolist()
                tvs = topv_s[bidx][:n_live].cpu().numpy().tolist()
                for eidx, (topi, topv) in enumerate(zip(tis, tvs)):  # NOTE max width times
                    if topi % voc_size == self.eoz_token:
                        dead[bidx].append(
                            (live[bidx][topi // voc_size][0] + topv,
                             live[bidx][topi // voc_size][1] + [topi % voc_size],
                             topi))
                        n_dead[bidx] += 1
                    else:
                        new_live[bidx].append(
                            (live[bidx][topi // voc_size][0] + topv,
                             live[bidx][topi // voc_size][1] + [topi % voc_size],
                             topi))
            while len(new_live[bidx]) < width:
                new_live[bidx].append((-99999999999, [0], 0))
        live = new_live

        if n_dead == [width for ii in range(batch_size)]:
            break

        in_vocab_idx = [[x[2] % voc_size for x in ee] for ee in live]  # NOTE batch_size first
        input = self.emb(cuda(torch.LongTensor(in_vocab_idx)).view(-1))\
            .view(-1, 1, self.D_emb)  # input (batch_size * width, 1, D_emb)

        in_width_idx = [[x[2] // voc_size + bbidx * cwidth for x in ee]
                        for bbidx, ee in enumerate(live)]  # live : (batch_size, width)
        z_t = z_t.index_select(1, cuda(torch.LongTensor(in_width_idx).view(-1)))\
            .view(self.n_layers, batch_size * width, self.D_hid)
        # h_0 (n_layers, batch_size * width, D_hid)

    for bidx in range(batch_size):
        if n_dead[bidx] < width:
            for didx in range(width - n_dead[bidx]):
                (a, b, c) = live[bidx][didx]
                dead[bidx].append((a, b, c))

    dead_ = [[(a / (math.pow(5 + len(b), self.beam_alpha) / math.pow(5 + 1, self.beam_alpha)), b, c)
              for (a, b, c) in ee] for ee in dead]

    ans = []
    for dd_ in dead_:
        dd = sorted(dd_, key=operator.itemgetter(0), reverse=True)
        ans.append(dd[0][1])
    return ans
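# Sketch of the length normalisation applied to finished beam hypotheses above: each
# accumulated log probability is divided by ((5 + len) ** alpha) / ((5 + 1) ** alpha),
# the GNMT-style penalty, before the best hypothesis is picked. Toy values only.
import math
import operator


def rank_hypotheses(hypotheses, beam_alpha):
    # hypotheses : list of (logprob_sum, token_list, last_index) as built by beam_search
    rescored = [(score / (math.pow(5 + len(toks), beam_alpha) / math.pow(5 + 1, beam_alpha)),
                 toks, idx) for (score, toks, idx) in hypotheses]
    return sorted(rescored, key=operator.itemgetter(0), reverse=True)


if __name__ == "__main__":
    hyps = [(-3.8, [2, 7, 9, 3], 0), (-3.5, [2, 7, 3], 0)]
    best = rank_hypotheses(hyps, beam_alpha=1.0)[0]
    print(best[1])  # the longer hypothesis wins here despite its lower raw score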
def send(self, h, h_len, y_len, send_method):
    # h     : (batch_size, x_seq_len, D_hid * n_dir)
    # h_len : (batch_size)
    batch_size, x_seq_len = h.size()[:2]
    xmask = xlen_to_inv_mask(h_len)  # (batch_size, x_seq_len)

    max_len_gen = self.max_len_gen if self.msg_len_ratio < 0.0 \
        else int(y_len.max().item() * self.msg_len_ratio)
    max_len_gen = np.clip(max_len_gen, 2, self.max_len_gen).item()

    h_index = (h_len - 1)[:, None, None].repeat(1, 1, h.size(2))
    z_t = torch.cumsum(h, dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
    z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1)  # (batch_size, D_hid * n_dir)
    z_t = torch.div(z_t, h_len[:, None].float())
    z_t = self.ctx_to_z0(z_t)  # (batch_size, n_layers * D_hid)
    z_t = z_t.view(batch_size, self.n_layers, self.D_hid).transpose(0, 1).contiguous()
    # (n_layers, batch_size, D_hid)

    y_emb = cuda(torch.full((batch_size, 1), self.init_token).long())
    y_emb = self.emb(y_emb)  # (batch_size, 1, D_emb)
    y_emb = F.dropout(y_emb, p=self.drop_ratio, training=self.training)

    h_big = h.view(-1, h.size(2))  # (batch_size * x_seq_len, D_hid * n_dir)
    ctx_h = self.h_to_att(h_big).view(batch_size, x_seq_len, self.D_att)

    done = cuda(torch.zeros(batch_size).long())
    seq_lens = cuda(torch.zeros(batch_size).fill_(max_len_gen).long())  # (batch_size)
    max_seq_lens = cuda(torch.zeros(batch_size).fill_(max_len_gen).long())  # (batch_size)
    eos_tensor = cuda(torch.zeros(batch_size).fill_(self.eoz_token)).long()

    msg = []
    self.log_probs = []  # (batch_size, seq_len)
    batch_logits = 0

    for idx in range(max_len_gen):
        # in  (batch_size, 1, D_emb)
        # z_t (n_layers, batch_size, D_hid)
        _, z_t_ = self.rnn(y_emb, z_t)
        # out (batch_size, 1, D_hid)
        # z_t (n_layers, batch_size, D_hid)

        ctx_z_t_ = z_t_.transpose(0, 1).contiguous().view(batch_size, self.n_layers * self.D_hid)  # (batch_size, n_layers * D_hid)
        ctx_y = self.y_to_att(y_emb.view(batch_size, self.D_emb))[:, None, :]
        ctx_s = self.z_to_att(ctx_z_t_)[:, None, :]
        ctx = F.tanh(ctx_y + ctx_s + ctx_h)
        ctx = ctx.view(batch_size * x_seq_len, self.D_att)

        score = self.att_to_score(ctx).view(batch_size, -1)  # (batch_size, x_seq_len)
        score.masked_fill_(xmask, -float('inf'))
        score = F.softmax(score, dim=1)
        score = score[:, :, None]  # (batch_size, x_seq_len, 1)

        c_t = torch.mul(h, score)  # (batch_size, x_seq_len, D_hid * n_dir)
        c_t = torch.sum(c_t, 1)  # (batch_size, D_hid * n_dir)

        # in  (batch_size, 1, D_hid * n_dir)
        # z_t (n_layers, batch_size, D_hid)
        out, z_t = self.rnn2(c_t[:, None, :], z_t_)
        # out (batch_size, 1, D_hid)
        # z_t (n_layers, batch_size, D_hid)

        fin_y = self.y_to_out(y_emb.view(batch_size, self.D_emb))  # (batch_size, D_out)
        fin_c = self.c_to_out(c_t)  # (batch_size, D_out)
        fin_s = self.z_to_out(out.view(batch_size, self.D_hid))  # (batch_size, D_out)
        fin = F.tanh(fin_y + fin_c + fin_s)

        logit = self.out(fin)  # (batch_size, voc_sz_trg)
        if send_method == "argmax":
            tokens = logit.data.max(dim=1)[1]  # (batch_size)
            tokens_idx = tokens
        elif send_method == "gumbel":
            tokens = gumbel_softmax_hard(logit, self.temp, self.st)\
                .view(batch_size, self.voc_sz_trg)  # (batch_size, voc_sz_trg)
            tokens_idx = (tokens * cuda(torch.arange(self.voc_sz_trg))[None, :]).sum(dim=1).long()
        elif send_method == "reinforce":
            cat = Categorical(logits=logit)
            tokens = cat.sample()
            tokens_idx = tokens
            self.log_probs.append(cat.log_prob(tokens))

        if self.entropy:
            batch_logits += (logit * (1 - done)[:, None].float()).sum(dim=0)

        msg.append(tokens.unsqueeze(dim=1))

        is_next_eos = (tokens_idx == eos_tensor).long()  # (batch_size) (1 if eos, 0 otherwise)
        done = (done + is_next_eos).clamp(min=0, max=1).long()

        new_seq_lens = max_seq_lens.clone().masked_fill_(mask=is_next_eos.bool(),
                                                         value=float(idx + 1))
        # either max_len_gen or idx + 1 if the next token is eos
        # max_seq_lens : (batch_size)
        seq_lens = torch.min(seq_lens, new_seq_lens)

        if done.sum() == batch_size:
            break

        y_emb = self.emb(tokens)[:, None, :]  # (batch_size, 1, D_emb)

    if self.msg_len_ratio > 0.0:
        seq_lens = torch.clamp((y_len.float() * self.msg_len_ratio).floor_(), 1, len(msg)).long()

    msg = torch.cat(msg, dim=1)
    # (batch_size, seq_len) if argmax or reinforce
    # (batch_size, seq_len, voc_sz_trg) if gumbel

    if send_method == "reinforce":
        # sum of per-token log probs yields the log prob of the whole message
        self.log_probs = torch.stack(self.log_probs, dim=1)  # (batch_size, seq_len)

    results = {"msg": msg, "seq_lens": seq_lens}
    if send_method == "reinforce" and self.entropy:
        results.update({"batch_logits": batch_logits})
    return results
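# gumbel_softmax_hard above is a project helper not shown in this file. A common
# straight-through formulation consistent with how its output is used here (a one-hot-like
# (batch, voc_sz_trg) tensor that still carries gradients) is sketched below; the actual
# helper may differ, e.g. in how self.st toggles the straight-through path.
import torch
import torch.nn.functional as F


def gumbel_softmax_hard_sketch(logits, temp, straight_through=True):
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + gumbel) / temp, dim=-1)         # relaxed sample
    if not straight_through:
        return y_soft
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)   # discrete one-hot
    return (y_hard - y_soft).detach() + y_soft                   # forward: hard, backward: soft


if __name__ == "__main__":
    logits = torch.randn(2, 6, requires_grad=True)
    y = gumbel_softmax_hard_sketch(logits, temp=0.5)
    y.sum().backward()  # gradients flow through the soft sample
    print(y, logits.grad.shape)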
def beam(self, src_hid, src_len):
    # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
    # src_len : (batch_size)
    batch_size, x_seq_len = src_hid.size()[:2]
    src_mask = xlen_to_inv_mask(src_len, seq_len=x_seq_len)  # (batch_size, x_seq_len)
    voc_size, width = self.voc_sz_trg, self.beam_width
    y_seq_len = self.max_len_gen

    h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
    z_t = torch.cumsum(src_hid, dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
    z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1)  # (batch_size, D_hid * n_dir)
    z_t = torch.div(z_t, src_len[:, None].float())

    y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
    y_emb = self.emb(y_emb)  # (batch_size, D_emb)
    input_feed = y_emb.data.new(batch_size, self.D_hid).zero_()

    live = [[(0.0, [self.init_token], 2)] for ii in range(batch_size)]
    dead = [[] for ii in range(batch_size)]
    n_dead = [0 for ii in range(batch_size)]

    src_hid_ = src_hid
    src_mask_ = src_mask

    for idx in range(y_seq_len):
        cwidth = 1 if idx == 0 else width

        input = torch.cat([y_emb, input_feed], dim=1)  # (batch_size * width, D_emb + D_hid)
        trg_hid = self.layers[0](input, z_t)  # (batch_size * width, D_hid)
        z_t = trg_hid  # (batch_size * width, D_hid)

        out, attn_scores = self.attention(trg_hid, src_hid_, src_mask_)  # (batch_size * width, D_hid)
        input_feed = out  # (batch_size * width, D_hid)

        logit = self.out(out)  # (batch_size * width, voc_sz_trg)
        cur_prob = F.log_softmax(logit, dim=1).view(batch_size, cwidth, self.voc_sz_trg)
        pre_prob = cuda(torch.FloatTensor([[x[0] for x in ee] for ee in live])
                        .view(batch_size, cwidth, 1))  # (batch_size, width, 1)
        total_prob = cur_prob + pre_prob  # (batch_size, cwidth, voc_sz)
        total_prob = total_prob.view(batch_size, -1)  # (batch_size, cwidth * voc_sz)

        topi_s = total_prob.topk(width, dim=1)[1]
        topv_s = cur_prob.view(batch_size, -1).gather(1, topi_s)

        new_live = [[] for ii in range(batch_size)]
        for bidx in range(batch_size):
            n_live = width - n_dead[bidx]
            if n_live > 0:
                tis = topi_s[bidx][:n_live].cpu().numpy().tolist()
                tvs = topv_s[bidx][:n_live].cpu().numpy().tolist()
                for eidx, (topi, topv) in enumerate(zip(tis, tvs)):
                    if topi % voc_size == self.eos_token:
                        dead[bidx].append(
                            (live[bidx][topi // voc_size][0] + topv,
                             live[bidx][topi // voc_size][1] + [topi % voc_size],
                             topi))
                        n_dead[bidx] += 1
                    else:
                        new_live[bidx].append(
                            (live[bidx][topi // voc_size][0] + topv,
                             live[bidx][topi // voc_size][1] + [topi % voc_size],
                             topi))
            while len(new_live[bidx]) < width:
                new_live[bidx].append((-99999999999, [0], 0))
        live = new_live

        if n_dead == [width for ii in range(batch_size)]:
            break

        in_vocab_idx = np.array([[x[2] % voc_size for x in ee] for ee in live])  # NOTE batch_size first
        y_emb = self.emb(cuda(torch.LongTensor(in_vocab_idx)).view(-1))\
            .view(-1, self.D_emb)  # (batch_size * width, D_emb)

        in_width_idx = np.array([[x[2] // voc_size + bbidx * cwidth for x in ee]
                                 for bbidx, ee in enumerate(live)])
        in_width_idx = cuda(torch.LongTensor(in_width_idx).view(-1))
        z_t = z_t.index_select(0, in_width_idx).view(batch_size * width, self.D_hid)
        input_feed = input_feed.index_select(0, in_width_idx).view(batch_size * width, self.D_hid)
        src_hid_ = src_hid_.index_select(0, in_width_idx).view(batch_size * width, x_seq_len, self.D_hid)
        src_mask_ = src_mask_.index_select(0, in_width_idx).view(batch_size * width, x_seq_len)

    for bidx in range(batch_size):
        if n_dead[bidx] < width:
            for didx in range(width - n_dead[bidx]):
                (a, b, c) = live[bidx][didx]
                dead[bidx].append((a, b, c))

    dead_ = [[(a / (math.pow(5 + len(b), self.beam_alpha) / math.pow(5 + 1, self.beam_alpha)), b, c)
              for (a, b, c) in ee] for ee in dead]

    ans = [[], [], [], [], []]
    for dd_ in dead_:
        dd = sorted(dd_, key=operator.itemgetter(0), reverse=True)
        for idx in range(5):
            ans[idx].append(dd[idx][1])
        #ans.append( dd[0][1] )
    return ans
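# Small sketch of the index bookkeeping shared by both beam routines above: topk is run
# over the flattened (beam * vocab) scores, so each selected index encodes the parent
# beam as topi // voc_size and the emitted token as topi % voc_size. Toy numbers only.
import torch


def split_beam_indices(total_prob, width, voc_size):
    # total_prob : (batch, cwidth * voc_size) flattened scores
    topv, topi = total_prob.topk(width, dim=1)
    parent_beam = topi // voc_size    # which live hypothesis each candidate extends
    token_id = topi % voc_size        # which vocabulary item it appends
    return topv, parent_beam, token_id


if __name__ == "__main__":
    voc_size, width = 5, 2
    total_prob = torch.tensor([[-1.0, -0.2, -3.0, -9.0, -4.0,    # beam 0
                                -2.0, -0.1, -8.0, -7.0, -6.0]])  # beam 1
    _, parents, tokens = split_beam_indices(total_prob, width, voc_size)
    print(parents.tolist(), tokens.tolist())  # [[1, 0]] [[1, 1]]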
def send(self, src_hid, src_len, trg_len, send_method,
         value_fn=None, gumbel_temp=1, dot_token=None):
    # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
    # src_len : (batch_size)
    batch_size, x_seq_len = src_hid.size()[:2]
    src_mask = xlen_to_inv_mask(src_len, seq_len=x_seq_len)  # (batch_size, x_seq_len)
    y_seq_len = math.floor(trg_len.max().item() * self.msg_len_ratio) \
        if self.msg_len_ratio > 0 else self.max_len_gen

    h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
    z_t = torch.cumsum(src_hid, dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
    z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1)  # (batch_size, D_hid * n_dir)
    z_t = torch.div(z_t, src_len[:, None].float())

    y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
    y_emb = self.emb(y_emb)  # (batch_size, D_emb)
    #y_emb = F.dropout( y_emb, p=self.drop_ratio, training=self.training )

    prev_trg_hids = [z_t for _ in range(self.n_layers)]
    trg_hid = prev_trg_hids[0]
    input_feed = y_emb.data.new(batch_size, self.D_hid).zero_()

    done = cuda(torch.zeros(batch_size).long())
    seq_lens = cuda(torch.zeros(batch_size).fill_(y_seq_len).long())  # (batch_size)
    max_seq_lens = cuda(torch.zeros(batch_size).fill_(y_seq_len).long())  # (batch_size)
    eos_tensor = cuda(torch.zeros(batch_size).fill_(self.eos_token)).long()

    self.log_probs, self.R_b, self.neg_Hs, self.gumbel_tokens = [], [], [], []
    msg = []

    for idx in range(y_seq_len):
        if self.input_feeding:
            input = torch.cat([y_emb, input_feed], dim=1)  # (batch_size, D_emb + D_hid)
        else:
            input = y_emb

        if send_method == "reinforce" and value_fn:
            if self.input_feeding:
                value_fn_input = [y_emb, input_feed]
            else:
                value_fn_input = [y_emb, trg_hid]
            value_fn_input = torch.cat(value_fn_input, dim=1).detach()
            self.R_b.append(value_fn(value_fn_input).view(-1))  # (batch_size)

        for i, rnn in enumerate(self.layers):
            trg_hid = rnn(input, prev_trg_hids[i])  # (batch_size, D_hid)
            #input = F.dropout(trg_hid, p=self.drop_ratio, training=self.training)
            input = trg_hid
            prev_trg_hids[i] = trg_hid

        if self.attention is not None:
            out, attn_scores = self.attention(trg_hid, src_hid, src_mask)  # (batch_size, D_hid)
        else:
            out = trg_hid
        input_feed = out

        logit = self.out(out)  # (batch_size, voc_sz_trg)
        if send_method == "argmax":
            tokens = logit.max(dim=1)[1]  # (batch_size)
            tok_dist = Categorical(logits=logit)
            self.neg_Hs.append(-1 * tok_dist.entropy())
        elif send_method == "reinforce":
            tok_dist = Categorical(logits=logit)
            tokens = tok_dist.sample()
            self.log_probs.append(tok_dist.log_prob(tokens))  # (batch_size)
            self.neg_Hs.append(-1 * tok_dist.entropy())
        elif send_method == "gumbel":
            y = gumbel_softmax(logit, gumbel_temp)
            tok_dist = Categorical(probs=y)
            self.neg_Hs.append(-1 * tok_dist.entropy())
            tokens = torch.argmax(y, dim=1)
            tokens_oh = cuda(torch.zeros(y.size())).scatter_(1, tokens.unsqueeze(-1), 1)
            self.gumbel_tokens.append((tokens_oh - y).detach() + y)
        else:
            raise ValueError

        msg.append(tokens)

        is_next_eos = (tokens == eos_tensor).long()  # (batch_size)
        new_seq_lens = max_seq_lens.clone().masked_fill_(
            mask=is_next_eos.bool(), value=float(idx + 1))  # NOTE idx+1 ensures this is a valid length
        seq_lens = torch.min(seq_lens, new_seq_lens)  # (contains lengths)

        done = (done + is_next_eos).clamp(min=0, max=1).long()
        if done.sum() == batch_size:
            break

        y_emb = self.emb(tokens)  # (batch_size, D_emb)

    msg = torch.stack(msg, dim=1)  # (batch_size, y_seq_len)
    self.neg_Hs = torch.stack(self.neg_Hs, dim=1)  # (batch_size, y_seq_len)
    if send_method == "reinforce":
        self.log_probs = torch.stack(self.log_probs, dim=1)  # (batch_size, y_seq_len)
        self.R_b = torch.stack(self.R_b, dim=1) if value_fn else self.R_b  # (batch_size, y_seq_len)
    if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
        self.gumbel_tokens = torch.stack(self.gumbel_tokens, dim=1)

    result = {"msg": msg.clone(), "new_seq_lens": seq_lens.clone()}

    # Trim sequence length at the first dot token
    if dot_token is not None:
        seq_lens = first_appear_indice(msg, dot_token) + 1

    # Jason's trick: trim the sentence length based on the ground-truth length
    if self.msg_len_ratio > 0:
        en_ref_len = torch.floor(trg_len.float() * self.msg_len_ratio).long()
        seq_lens = torch.min(seq_lens, en_ref_len)

    # Keep lengths at least min_len_gen and at most the generated length
    seq_lens = torch.max(seq_lens, seq_lens.new(seq_lens.size()).fill_(self.min_len_gen))
    seq_lens = torch.min(seq_lens, seq_lens.new(seq_lens.size()).fill_(msg.shape[1]))

    # Make sure the message ends with EOS
    ends_with_eos = (msg.gather(dim=1, index=(seq_lens - 1)[:, None]).view(-1) == eos_tensor).long()
    eos_or_pad = ends_with_eos * self.pad_token + (1 - ends_with_eos) * self.eos_token
    msg = torch.cat([msg, msg.new(batch_size, 1).fill_(self.pad_token)], dim=1)
    if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
        pad_gumbel_tokens = cuda(torch.zeros(msg.shape[0], 1, self.gumbel_tokens.shape[2]))
        pad_gumbel_tokens[:, :, self.pad_token] = 1
        self.gumbel_tokens = torch.cat([self.gumbel_tokens, pad_gumbel_tokens], dim=1)
    # (batch_size, y_seq_len + 1)
    msg.scatter_(dim=1, index=seq_lens[:, None], src=eos_or_pad[:, None])
    if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
        batch_ids = cuda(torch.arange(msg.shape[0]))
        self.gumbel_tokens[batch_ids, seq_lens] = 0.
        self.gumbel_tokens[batch_ids, seq_lens, eos_or_pad] = 1.
    seq_lens = seq_lens + (1 - ends_with_eos)

    msg_mask = xlen_to_inv_mask(seq_lens, seq_len=msg.size(1))  # (batch_size, y_seq_len + 1)
    msg.masked_fill_(msg_mask.bool(), self.pad_token)

    result.update({"msg": msg, "new_seq_lens": seq_lens})
    return result
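# first_appear_indice above is another project helper not defined in this file. Based
# on how it is used (per-row index of the first dot_token in msg, so that +1 gives a
# length), a plausible reconstruction is sketched below; if a row never contains the
# token, this version falls back to the last position.
import torch


def first_appear_indice_sketch(msg, token):
    # msg : (batch_size, seq_len) of token ids
    hits = (msg == token)
    # argmax returns the first maximal position, i.e. the first hit when any exists
    first = hits.long().argmax(dim=1)
    no_hit = ~hits.any(dim=1)
    first[no_hit] = msg.size(1) - 1
    return first


if __name__ == "__main__":
    msg = torch.tensor([[4, 9, 7, 9], [1, 2, 3, 5]])
    print(first_appear_indice_sketch(msg, token=9))  # tensor([1, 3])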