def multi_decode(self, batch):
    """A helper for beam search through the fr->en->de pipeline.

    Beam-decodes English messages from the French source, then decodes
    German from the ensemble of English hypotheses.

    :param batch: torchtext-style batch with .fr/.en/.de (tokens, lengths)
    :return: (en_msgs, de_msg) — list of padded English hypothesis tensors
             and the German decoding produced from them
    """
    (fr, fr_len) = batch.fr
    (en, en_len) = batch.en
    (de, de_len) = batch.de
    batch_size = len(batch)
    # Encode French (drop <BOS>, shorten lengths accordingly)
    fr_hid = self.fr_en.enc(fr[:, 1:], fr_len - 1)
    en_msgs_ = self.fr_en.dec.beam(fr_hid, fr_len - 1)  # (5, batch_size, en_len)
    en_msgs, en_lens = [], []
    for en in en_msgs_:
        # Strip leading <BOS> from each hypothesis, then right-pad with 0
        en = [ee[1:] for ee in en]
        en_len = [len(x) for x in en]
        max_len = max(en_len)
        en_len = cuda(torch.LongTensor(en_len))
        en_lens.append(en_len)
        en = [
            np.lib.pad(xx, (0, max_len - len(xx)),
                       'constant',
                       constant_values=(0, 0)) for xx in en
        ]
        en = cuda(torch.LongTensor(np.array(en)))
        en_msgs.append(en)
    # Encode every English hypothesis set and ensemble-decode German
    en_hids = [
        self.en_de.enc(en_msg, en_len)
        for (en_msg, en_len) in zip(en_msgs, en_lens)
    ]
    de_msg = self.en_de.dec.multi_decode(en_hids, en_lens)
    return en_msgs, de_msg
def valid_model(args, model, dev_it, dev_metrics, iters, loss_names,
                monitor_names, extra_input):
    """Evaluate the (optionally image-conditioned) captioning LM on dev data.

    Accumulates per-batch NLL into `dev_metrics`; when images are used, also
    computes "nll_rnd" — the NLL under randomly shuffled images — as a
    grounding sanity check.
    """
    with torch.no_grad():
        model.eval()
        for j, dev_batch in enumerate(dev_it):
            # Image features for this batch (None when training text-only)
            img_input = None if args.no_img else cuda(
                extra_input["img"]['multi30k'][1].index_select(
                    dim=0, index=dev_batch.idx.cpu()))
            en, en_len = dev_batch.en
            decoded = model(en, img_input)
            R = {}
            # Target is the input shifted left by one; 0 is the pad index
            R["nll"] = F.cross_entropy(decoded,
                                       en[:, 1:].contiguous().view(-1),
                                       ignore_index=0)
            #R["nll_cur"] = F.cross_entropy( decoded, en[:,:-1].contiguous().view(-1), ignore_index=0 )
            if not (img_input is None):
                # Shuffle images across the batch to measure how much the
                # captions actually depend on their own image
                idx = cuda(torch.randperm(img_input.size(0)))
                img_input = img_input.index_select(0, idx)
                decoded = model(en, img_input)
                R["nll_rnd"] = F.cross_entropy(decoded,
                                               en[:, 1:].contiguous().view(-1),
                                               ignore_index=0)
            dev_metrics.accumulate(
                len(dev_batch),
                *[R[name].item() for name in loss_names + monitor_names])
    args.logger.info(dev_metrics)
def multi_decode(self, src_hids, src_lens):
    """Greedy decoding driven by an ensemble of encoder outputs.

    At each step, every (encoder output, recurrent state) pair yields a
    softmax distribution; the distributions are averaged and the joint
    argmax token is fed back to all ensemble members.

    Generalized: the ensemble size follows ``len(src_hids)`` instead of
    being hard-coded to 5 (behavior unchanged for the 5-beam caller).

    :param src_hids: list of (batch_size, x_seq_len_i, D_hid * n_dir)
    :param src_lens: list of (batch_size) source lengths
    :return: (batch_size, y_seq_len) decoded token ids
    """
    n_ensemble = len(src_hids)
    batch_size = src_hids[0].size(0)
    x_seq_lens = [src_hid.size(1) for src_hid in src_hids]
    src_masks = [
        xlen_to_inv_mask(src_len, seq_len=x_seq_len)
        for (src_len, x_seq_len) in zip(src_lens, x_seq_lens)
    ]  # each (batch_size, x_seq_len)
    y_seq_len = self.max_len_gen
    z_ts = []
    for src_len, src_hid in zip(src_lens, src_hids):
        # Mean of encoder states over valid positions (cumsum at last
        # valid index divided by length) as the initial decoder state.
        h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
        z_t = torch.cumsum(src_hid, dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1, index=h_index).view(
            batch_size, -1)  # (batch_size, D_hid * n_dir)
        z_t = torch.div(z_t, src_len[:, None].float())
        z_ts.append(z_t)
    y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
    y_emb = self.emb(y_emb)  # (batch_size, D_emb)
    input_feeds = [
        y_emb.data.new(batch_size, self.D_hid).zero_()
        for _ in range(n_ensemble)
    ]
    done = cuda(torch.zeros(batch_size).long())
    eos_tensor = cuda(torch.zeros(batch_size).fill_(self.eos_token)).long()
    msg = []
    for idx in range(y_seq_len):
        probs = []
        for which in range(n_ensemble):
            input = torch.cat([y_emb, input_feeds[which]],
                              dim=1)  # (batch_size, D_emb + D_hid)
            trg_hid = self.layers[0](input, z_ts[which])  # (batch_size, D_hid)
            z_ts[which] = trg_hid
            out, _ = self.attention(
                trg_hid, src_hids[which], src_masks[which])  # (batch_size, D_hid)
            input_feeds[which] = out
            logit = self.out(out)  # (batch_size, voc_sz_trg)
            probs.append(F.softmax(logit, -1))
        # Average the ensemble's distributions and take the argmax token
        prob = torch.stack(probs, dim=-1).mean(-1)
        tokens = prob.max(dim=1)[1]  # (batch_size)
        msg.append(tokens)
        is_next_eos = (tokens == eos_tensor).long()  # (batch_size)
        done = (done + is_next_eos).clamp(min=0, max=1).long()
        if done.sum() == batch_size:
            break
        y_emb = self.emb(tokens)  # (batch_size, D_emb)
    msg = torch.stack(msg, dim=1)  # (batch_size, y_seq_len)
    return msg
def forward(self, x, x_len, h_0=None):
    """Run the RNN encoder over a token (or one-hot/Gumbel) sequence.

    :param x: (batch_size, x_seq_len) token ids, or
              (batch_size, x_seq_len, vocab_size) soft one-hots (Gumbel)
    :param x_len: (batch_size) valid lengths; NOTE x_len.max() == x.size(1)
    :param h_0: optional initial hidden state
               (n_layers * n_dir, batch_size, D_hid); zeros if None
    :return: (batch_size, x_seq_len, n_dir * D_hid) contiguous outputs
    """
    # NOTE x_len.max() == x.size(1)
    # x : (batch_size, x_seq_len) or (batcg_size, x_seq_len, vocab_size)
    # x_len : (batch_size)
    batch_size, x_seq_len = x.size()[:2]
    # NOTE dim==3 for gumbel softmax
    if h_0 is None:
        h_0 = cuda(
            torch.FloatTensor(self.n_layers * self.n_dir, batch_size,
                              self.D_hid).zero_())
    if len(x.shape) == 2:
        input = self.emb(x)
    elif len(x.shape) == 3:  # Gumbel: soft lookup via matmul with emb matrix
        input = torch.matmul(x, self.emb.weight)
    else:
        raise ValueError
    input = F.dropout(input, p=self.drop_ratio, training=self.training)
    # input (batch_size, x_seq_len, D_emb)
    # h_0 (n_layers * n_dir, batch_size, D_hid)
    output, _ = self.rnn(input, h_0)
    # output (batch_size, x_seq_len, n_dir * D_hid)
    # h_n (n_layers * n_dir, batch_size, D_hid)
    """
    if self.model == "RNN" : # RNN
        out = take_last(output, x_len, self.n_dir, self.D_hid)
        return out # (batch_size, n_dir * D_hid)
    else: # RNNAttn
        #inv_mask = xlen_to_inv_mask(x_len)[:,:,None] # (batch_size, x_seq_len, 1)
        #output.masked_fill_(inv_mask, 0)
    """
    return output.contiguous()  # (batch_size, x_seq_len, n_dir * D_hid)
def batch_cap_rep(self, sents, sent_lens):
    """Encode all captions in fixed-size chunks and L2-normalize.

    :param sents: [NB_SENT, len]
    :param sent_lens: [NB_SENT]
    :return: [NB_SENT, D_hid] normalized caption representations
    """
    chunk = 64
    total = sents.shape[0]
    pieces = []
    # Walk the corpus in strides of `chunk`; the final slice may be short.
    for offset in range(0, total, chunk):
        sent_slice = cuda(sents[offset:offset + chunk])
        len_slice = cuda(sent_lens[offset:offset + chunk])
        pieces.append(self.get_cap_rep(sent_slice, len_slice))
    return normf(torch.cat(pieces, dim=0))
def _make_sure_message_valid(msg, msg_len, init_token):
    """Prepend <BOS> to each message and return the adjusted lengths.

    :param msg: (batch_size, seq_len) message token ids
    :param msg_len: (batch_size) message lengths
    :param init_token: id of the <BOS> token to prepend
    :return: (msg with <BOS> column prepended, msg_len + 1)
    """
    # Add BOS
    msg = torch.cat(
        [cuda(torch.full((msg.shape[0], 1), init_token)).long(), msg], dim=1)
    # Fix: `msg_len += 1` on a tensor adds in place and silently mutates the
    # caller's tensor; use an out-of-place add since the result is returned.
    msg_len = msg_len + 1
    # Make sure padding are all zeros
    #inv_mask = xlen_to_inv_mask(msg_len, seq_len=msg.shape[1])
    #msg.masked_fill_(mask=inv_mask.bool(), value=0)
    return msg, msg_len
def get_cap_rep(self, x, x_len):
    """Encode captions and return the hidden state at the last valid step.

    :param x: (batch_size, seq_len) caption token ids
    :param x_len: (batch_size) valid lengths
    :return: (batch_size, D_hid) representation taken at position x_len - 1
    """
    batch_size, seq_len = x.shape
    hidden = cuda(torch.zeros(self.n_layers, batch_size, self.D_hid))
    emb = F.dropout( self.emb( x ), p=self.drop_ratio, training=self.training )
    # NOTE(review): the comment below says 2 * D_hid but the view uses
    # D_hid — presumably the RNN is unidirectional here; confirm n_dir.
    output, _ = self.rnn(emb, hidden) # (batch, seq_len, 2 * D_hid)
    f_out = output.view(batch_size, seq_len, self.D_hid)
    # Gather the output at each sequence's final valid position
    f_idx = (x_len - 1)[:,None,None].repeat(1, 1, self.D_hid)
    f_out = f_out.gather(dim=1, index=f_idx).view(batch_size, -1) # (batch_size, D_hid)
    return f_out
def valid_model(args, model, valid_img_feats, valid_caps, valid_lens):
    """Evaluate the ranker's image-prediction loss over the validation set.

    For each batch a caption set (out of 5) is picked at random; the model's
    loss ("vse" or "mse") is accumulated into a Metrics object.

    :return: Metrics with the accumulated validation loss
    """
    model.eval()
    batch_size = 32
    start = 0
    val_metrics = Metrics('val_loss', 'loss', data_type="avg")
    with torch.no_grad():
        # Fix: was `start <= ...`, which ran one extra, empty batch when the
        # set size is a multiple of 32 (siblings batch_cap_rep/batch_enc_img
        # use strict `<`).
        while start < valid_img_feats.shape[0]:
            cap_id = random.randint(0, 4)
            end = start + batch_size
            batch_img_feat = cuda(valid_img_feats[start: end])
            batch_ens = cuda(valid_caps[cap_id][start: end])
            batch_lens = cuda(valid_lens[cap_id][start: end])
            # Drop <BOS>, shorten lengths accordingly
            R = model(batch_ens[:, 1:], batch_lens - 1, batch_img_feat)
            if args.img_pred_loss == "vse":
                R['loss'] = R['loss'].sum()
            elif args.img_pred_loss == "mse":
                R['loss'] = R['loss'].mean()
            else:
                raise ValueError
            # Fix: weight by the actual batch size — the last batch may
            # contain fewer than 32 items.
            val_metrics.accumulate(batch_ens.shape[0], R['loss'])
            start = end
    return val_metrics
def valid_model(args, model, dev_it, extra_input):
    """Image/caption retrieval evaluation on the dev iterator.

    Encodes all dev images and captions, ranks captions per image and images
    per caption, and reports recall@{1,5,10} in both directions.

    Generalized: the corpus size is taken from the collected features
    instead of being hard-coded to 5000.

    :return: (cap_r1, cap_r5, cap_r10, img_r1, img_r5, img_r10)
    """
    with torch.no_grad():
        model.eval()
        img_features, cap_features = [], []
        for j, dev_batch in enumerate(dev_it):
            img = cuda(extra_input["img"]["multi30k"][1].index_select(
                dim=0, index=dev_batch.idx.cpu()))  # (batch_size, D_img)
            en, en_len = dev_batch.en
            # Drop <BOS> before encoding captions
            cap_rep = model.get_cap_rep(en[:, 1:], en_len - 1)
            cap_features.append(normf(cap_rep))
            img = model.img_enc(img)  # (batch_size, D_hid)
            img_features.append(normf(img))
        img_features = torch.cat(img_features, dim=0)  # (N, D_hid)
        cap_features = torch.cat(cap_features, dim=0)  # (N, D_hid)
        n = img_features.size(0)  # actual dev-set size (was hard-coded 5000)
        scores = torch.mm(img_features, cap_features.t())  # (N, N)
        # Rank captions for each image (rows) and images for each caption (cols)
        _, cap_idx = torch.sort(scores, dim=1, descending=True)  # (N, N)
        _, img_idx = torch.sort(scores, dim=0, descending=True)  # (N, N)
        query = cuda(torch.arange(n))
        cap_r1, cap_r5, cap_r10 = [retrieval(idx, query, 1) for idx in \
                                   [cap_idx[:, :1], cap_idx[:, :5], cap_idx[:, :10]]]
        img_r1, img_r5, img_r10 = [retrieval(idx, query, 0) for idx in \
                                   [img_idx[:1, :], img_idx[:5, :], img_idx[:10, :]]]
        return (cap_r1, cap_r5, cap_r10, img_r1, img_r5, img_r10)
def batch_enc_img(self, img_feat):
    """Project all image features through the image encoder in chunks.

    :param img_feat: [NB_img, D_img]
    :return: [NB_img, D_hid] L2-normalized encoded images
    """
    chunk = 64
    total = img_feat.shape[0]
    encoded = []
    # Process the features in strides of `chunk`; last slice may be short.
    for offset in range(0, total, chunk):
        feats = cuda(img_feat[offset:offset + chunk])
        feats = F.dropout(feats,
                          p=self.drop_ratio,
                          training=self.training)
        encoded.append(self.img_enc(feats))
    return normf(torch.cat(encoded, dim=0))
def get_retrieve_result(args, model, valid_img_feats, valid_caps, valid_lens):
    """Retrieval metrics on precomputed validation features.

    Images are scored against each of the 5 caption sets separately for
    image retrieval; for caption retrieval the 5 score matrices are
    concatenated and ranked jointly.

    Generalized: the image count is taken from ``valid_img_feats`` instead
    of being hard-coded to 1000 in the remainder mapping.

    :return: (cap_r1, cap_r5, cap_r10, img_r1, img_r5, img_r10)
    """
    with torch.no_grad():
        model.eval()
        n_img = valid_img_feats.shape[0]
        query = cuda(torch.arange(n_img))
        img_features = model.batch_enc_img(valid_img_feats)
        cap_features = []
        for valid_cap, valid_len in zip(valid_caps, valid_lens):
            # Drop <BOS>, shorten lengths accordingly
            cap_features.append(model.batch_cap_rep(valid_cap[:, 1:], valid_len - 1))
        scores = [torch.mm(img_features, cap.t()) for cap in cap_features]  # [ (img N, cap_i N) x 5 ]
        img_idxs = [torch.sort(sc, dim=0, descending=True)[1] for sc in scores]
        # Image retrieval: average recall over the 5 caption sets
        img_r1 = np.mean([retrieval(img_idx[:1, :], query, 0) for img_idx in img_idxs])
        img_r5 = np.mean([retrieval(img_idx[:5, :], query, 0) for img_idx in img_idxs])
        img_r10 = np.mean([retrieval(img_idx[:10, :], query, 0) for img_idx in img_idxs])
        # Caption retrieval: rank all 5N captions jointly, then map each
        # caption index back to its image index via modulo N.
        scores = torch.cat(scores, dim=1)  # (img N, cap 5N)
        cap_idx = torch.sort(scores, dim=1, descending=True)[1]  # (img N, cap 5N)
        cap_idx = torch.remainder(cap_idx, n_img)  # was hard-coded 1000
        cap_r1, cap_r5, cap_r10 = [retrieval(idx, query, 1) for idx in \
                                   [cap_idx[:, :1], cap_idx[:, :5], cap_idx[:, :10]]]
        return cap_r1, cap_r5, cap_r10, img_r1, img_r5, img_r10
def forward(self, x, x_len, img):
    """Compute the image-grounding loss for a batch of captions.

    :param x: (batch_size, seq_len) caption token ids
    :param x_len: (batch_size) valid lengths
    :param img: (batch_size, D_img) image features
    :return: dict with 'loss' of shape (batch_size)
    """
    # x : (batch_size, seq_len)
    # x_len : (batch_size)
    # img : (batch_size, D_img)
    batch_size, seq_len = x_len.size(0), x_len.max().item()
    x_enc = self.get_cap_rep(x, x_len)
    R = {}
    if self.img_pred_loss == "mse":
        # Regress the caption representation onto the raw image features
        x_enc = F.dropout( x_enc, p=self.drop_ratio, training=self.training ) # (batch_size, D_img)
        x_enc = self.img_enc( x_enc )
        #loss = F.mse_loss(x_enc, img, reduction='none')
        loss = ( (x_enc - img) ** 2 ).mean(1)
    elif self.img_pred_loss == "vse":
        # VSE max-margin loss on normalized caption/image embeddings
        img = F.dropout( img, p=self.drop_ratio, training=self.training ) # (batch_size, D_img)
        img = self.img_enc( img ) # (batch_size, D_hid)
        x_enc = normf(x_enc)
        img = normf(img)
        scores = torch.mm( img, x_enc.t() ) # (batch_size, batch_size)
        # Diagonal holds matched (image, caption) similarities
        diagonal = scores.diag().view(batch_size, -1) # (batch_size, 1)
        pos_cap_scores = diagonal.expand_as(scores) # (batch_size, batch_size)
        pos_img_scores = diagonal.t().expand_as(scores) # (batch_size, batch_size)
        # Hinge violations for negative captions / negative images
        cost_cap = (self.margin + scores - pos_cap_scores).clamp(min=0)
        cost_img = (self.margin + scores - pos_img_scores).clamp(min=0)
        mask = cuda( torch.eye( batch_size ) > .5 ) # remove diagonal
        cost_cap = cost_cap.masked_fill(mask, 0.0) # (batch_size, batch_size)
        cost_img = cost_img.masked_fill(mask, 0.0) # (batch_size, batch_size)
        if self.training:
            # Hardest-negative mining at train time
            loss = cost_cap.max(1)[0] + cost_img.max(0)[0] # (batch_size)
        else:
            loss = cost_cap.mean(dim=1) + cost_img.mean(dim=0) # (batch_size)
    R['loss'] = loss
    return R
def forward(self, msg, img):
    """Language-model forward pass, optionally conditioned on an image.

    :param msg: (batch_size, seq_len) token ids, or
                (batch_size, seq_len, vocab_size) soft one-hots (Gumbel)
    :param img: (batch_size, D_img) image features, or None when no_img
    :return: (batch_size * (seq_len - 1), voc_sz) next-token logits
    """
    # msg : (batch_size, seq_len) or (batch_size, seq_len, vocab_size)
    # img : (batch_size, D_img)
    batch_size = msg.size(0)
    if self.no_img:
        assert (img is None)
        hidden = cuda( torch.zeros(self.n_layers, batch_size, self.D_hid) )
    else:
        assert not (img is None)
        # Encode the image into the initial hidden state of every layer
        img = F.dropout( img, p=self.drop_ratio, training=self.training )
        hidden = self.img_enc( img )[None,:,:] # (1, batch_size, D_hid)
        hidden = hidden.repeat(self.n_layers, 1, 1)
    # Teacher forcing: predict msg[t+1] from msg[t]
    input, target = msg[:,:-1], msg[:,1:]
    if len(input.shape) == 2:
        emb = F.dropout(self.encoder( input ), p=self.drop_ratio, training=self.training)
    elif len(input.shape) == 3:
        # Gumbel: soft embedding lookup via matmul with the embedding matrix
        emb = torch.matmul(input, self.encoder.weight)
        emb = F.dropout(emb, p=self.drop_ratio, training=self.training)
    else:
        raise ValueError
    output, _ = self.rnn(emb, (hidden, hidden))
    output = F.dropout( output, p=self.drop_ratio, training=self.training )
    decoded = self.decoder(output).view(-1, self.voc_sz)
    return decoded
def send(self,
         src_hid,
         src_len,
         trg_len,
         send_method,
         value_fn=None,
         gumbel_temp=1,
         dot_token=None):
    """Generate a message from encoder states via argmax/REINFORCE/Gumbel.

    Side effects: stores self.log_probs, self.R_b (value-fn baselines),
    self.neg_Hs (negative entropies) and self.gumbel_tokens for use by the
    training loss.

    :param src_hid: (batch_size, x_seq_len, D_hid * n_dir) encoder outputs
    :param src_len: (batch_size) source lengths
    :param trg_len: (batch_size) reference target lengths (length trick)
    :param send_method: "argmax" | "reinforce" | "gumbel"
    :param value_fn: optional baseline network for REINFORCE
    :param gumbel_temp: Gumbel-softmax temperature
    :param dot_token: if given, trim messages at the first occurrence
    :return: dict with "msg" (padded, EOS-terminated) and "new_seq_lens"
    """
    # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
    # src_len : (batch_size)
    batch_size, x_seq_len = src_hid.size()[:2]
    src_mask = xlen_to_inv_mask(
        src_len, seq_len=x_seq_len)  # (batch_size, x_seq_len)
    y_seq_len = math.floor(trg_len.max().item() * self.msg_len_ratio) \
        if self.msg_len_ratio > 0 else self.max_len_gen
    # Initial decoder state: mean of encoder states over valid positions
    h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
    z_t = torch.cumsum(src_hid,
                       dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
    z_t = z_t.gather(dim=1, index=h_index).view(batch_size,
                                                -1)  # (batch_size, D_hid * n_dir)
    z_t = torch.div(z_t, src_len[:, None].float())
    y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
    y_emb = self.emb(y_emb)  # (batch_size, D_emb)
    #y_emb = F.dropout( y_emb, p=self.drop_ratio, training=self.training )
    prev_trg_hids = [z_t for _ in range(self.n_layers)]
    trg_hid = prev_trg_hids[0]
    input_feed = y_emb.data.new(batch_size, self.D_hid).zero_()
    done = cuda(torch.zeros(batch_size).long())
    seq_lens = cuda(
        torch.zeros(batch_size).fill_(y_seq_len).long())  # (max_len)
    max_seq_lens = cuda(
        torch.zeros(batch_size).fill_(y_seq_len).long())  # (max_len)
    eos_tensor = cuda(torch.zeros(batch_size).fill_(self.eos_token)).long()
    self.log_probs, self.R_b, self.neg_Hs, self.gumbel_tokens = [], [], [], []
    msg = []
    for idx in range(y_seq_len):
        if self.input_feeding:
            input = torch.cat([y_emb, input_feed],
                              dim=1)  # (batch_size, D_emb + D_hid)
        else:
            input = y_emb
        if send_method == "reinforce" and value_fn:
            # Baseline prediction from the (detached) decoder input
            if self.input_feeding:
                value_fn_input = [y_emb, input_feed]
            else:
                value_fn_input = [y_emb, trg_hid]
            value_fn_input = torch.cat(value_fn_input, dim=1).detach()
            self.R_b.append(
                value_fn(value_fn_input).view(-1))  # (batch_size)
        for i, rnn in enumerate(self.layers):
            trg_hid = rnn(input, prev_trg_hids[i])  # (batch_size, D_hid)
            #input = F.dropout(trg_hid, p=self.drop_ratio, training=self.training)
            input = trg_hid
            prev_trg_hids[i] = trg_hid
        if self.attention is not None:
            out, attn_scores = self.attention(
                trg_hid, src_hid, src_mask)  # (batch_size, D_hid)
        else:
            out = trg_hid
        input_feed = out
        logit = self.out(out)  # (batch_size, voc_sz_trg)
        if send_method == "argmax":
            tokens = logit.max(dim=1)[1]  # (batch_size)
            tok_dist = Categorical(logits=logit)
            self.neg_Hs.append(-1 * tok_dist.entropy())
        elif send_method == "reinforce":
            tok_dist = Categorical(logits=logit)
            tokens = tok_dist.sample()
            self.log_probs.append(
                tok_dist.log_prob(tokens))  # (batch_size)
            self.neg_Hs.append(-1 * tok_dist.entropy())
        elif send_method == "gumbel":
            # Straight-through Gumbel: hard one-hot forward, soft backward
            y = gumbel_softmax(logit, gumbel_temp)
            tok_dist = Categorical(probs=y)
            self.neg_Hs.append(-1 * tok_dist.entropy())
            tokens = torch.argmax(y, dim=1)
            tokens_oh = cuda(torch.zeros(y.size())).scatter_(
                1, tokens.unsqueeze(-1), 1)
            self.gumbel_tokens.append((tokens_oh - y).detach() + y)
        else:
            raise ValueError
        msg.append(tokens)
        is_next_eos = (tokens == eos_tensor).long()  # (batch_size)
        new_seq_lens = max_seq_lens.clone().masked_fill_(
            mask=is_next_eos.bool(),
            value=float(idx + 1))  # NOTE idx+1 ensures this is valid length
        seq_lens = torch.min(seq_lens, new_seq_lens)  # (contains lengths)
        done = (done + is_next_eos).clamp(min=0, max=1).long()
        if done.sum() == batch_size:
            break
        y_emb = self.emb(tokens)  # (batch_size, D_emb)
    msg = torch.stack(msg, dim=1)  # (batch_size, y_seq_len)
    self.neg_Hs = torch.stack(self.neg_Hs, dim=1)  # (batch_size, y_seq_len)
    if send_method == "reinforce":
        self.log_probs = torch.stack(self.log_probs,
                                     dim=1)  # (batch_size, y_seq_len)
        self.R_b = torch.stack(
            self.R_b,
            dim=1) if value_fn else self.R_b  # (batch_size, y_seq_len)
    if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
        self.gumbel_tokens = torch.stack(self.gumbel_tokens, dim=1)
    result = {"msg": msg.clone(), "new_seq_lens": seq_lens.clone()}
    # Trim sequence length with first dot tensor
    if dot_token is not None:
        seq_lens = first_appear_indice(msg, dot_token) + 1
        pass
    # Jason's trick on trim the sentence length based on ground truth length
    if self.msg_len_ratio > 0:
        en_ref_len = torch.floor(trg_len.float() *
                                 self.msg_len_ratio).long()
        seq_lens = torch.min(seq_lens, en_ref_len)
    # Make length larger than min len and valid
    seq_lens = torch.max(
        seq_lens,
        seq_lens.new(seq_lens.size()).fill_(self.min_len_gen))
    seq_lens = torch.min(seq_lens,
                         seq_lens.new(seq_lens.size()).fill_(msg.shape[1]))
    # Make sure message is valid: force an EOS at the trimmed end, then pad
    ends_with_eos = (msg.gather(
        dim=1, index=(seq_lens - 1)[:, None]).view(-1) == eos_tensor).long()
    eos_or_pad = ends_with_eos * self.pad_token + (
        1 - ends_with_eos) * self.eos_token
    msg = torch.cat(
        [msg, msg.new(batch_size, 1).fill_(self.pad_token)], dim=1)
    if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
        # Keep the one-hot tensor aligned with the extra pad column
        pad_gumbel_tokens = cuda(
            torch.zeros(msg.shape[0], 1, self.gumbel_tokens.shape[2]))
        pad_gumbel_tokens[:, :, self.pad_token] = 1
        self.gumbel_tokens = torch.cat(
            [self.gumbel_tokens, pad_gumbel_tokens],
            dim=1)  # (batch_size, y_seq_len + 1)
    msg.scatter_(dim=1, index=seq_lens[:, None], src=eos_or_pad[:, None])
    if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
        batch_ids = cuda(torch.arange(msg.shape[0]))
        self.gumbel_tokens[batch_ids, seq_lens] = 0.
        self.gumbel_tokens[batch_ids, seq_lens, eos_or_pad] = 1.
    # Account for the appended EOS where the message did not end with one
    seq_lens = seq_lens + (1 - ends_with_eos)
    msg_mask = xlen_to_inv_mask(
        seq_lens, seq_len=msg.size(1))  # (batch_size, x_seq_len)
    msg.masked_fill_(msg_mask.bool(), self.pad_token)
    result.update({"msg": msg, "new_seq_lens": seq_lens})
    return result
def train_model(args, model):
    """Train the image-caption ranker on Flickr30k.

    Extracts image features with a frozen ResNet-152, builds the BPE caption
    corpora, and optimizes the model's image-prediction ("vse"/"mse") loss,
    periodically running retrieval validation and early stopping on mean R@1.
    """
    # Frozen ResNet-152 feature extractor (classifier head removed)
    resnet = torchvision.models.resnet152(pretrained=True)
    resnet = nn.Sequential(*list(resnet.children())[:-1])
    resnet = nn.DataParallel(resnet).cuda()
    resnet.eval()
    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9,
                               lr=args.lr)
    else:
        raise NotImplementedError
    loss_names, loss_cos = ["loss"], {"loss": 1.0}
    monitor_names = "cap_r1 cap_r5 cap_r10 img_r1 img_r5 img_r10".split()
    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    best = Best(max, 'r1', 'iters', model=model, opt=opt,
                path=args.model_path + args.id_str, \
                gpu=args.gpu, debug=args.debug)
    # Train dataset
    args.logger.info("Loading train imgs...")
    train_dataset = ImageFolderWithPaths(
        os.path.join(args.data_dir, 'flickr30k'), preprocess_rc)
    train_imgs = open(os.path.join(args.data_dir, 'flickr30k/train.txt'),
                      'r').readlines()
    train_imgs = [x.strip() for x in train_imgs if x.strip() != ""]
    # Keep only images listed in train.txt
    train_dataset.samples = [
        x for x in train_dataset.samples
        if x[0].split("/")[-1] in train_imgs
    ]
    train_dataset.imgs = [
        x for x in train_dataset.imgs if x[0].split("/")[-1] in train_imgs
    ]
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8,
                                               pin_memory=False)
    args.logger.info("Train loader built!")
    en_vocab = TextVocab(
        counter=torch.load(os.path.join(args.data_dir, 'bpe/vocab.en.pth')))
    word2idx = en_vocab.stoi
    # 5 reference captions per image; wrap each with <bos>/<eos> and map to ids
    train_en = [
        open(
            os.path.join(args.data_dir, 'flickr30k/caps',
                         'train.{}.bpe'.format(idx + 1))).readlines()
        for idx in range(5)
    ]
    train_en = [[["<bos>"] + sentence.strip().split() + ["<eos>"]
                 for sentence in doc if sentence.strip() != ""]
                for doc in train_en]
    train_en = [[[word2idx[word] for word in sentence] for sentence in doc]
                for doc in train_en]
    args.logger.info("Train corpus built!")
    # Valid dataset
    valid_img_feats = torch.tensor(
        torch.load(os.path.join(args.data_dir, 'flickr30k/val_feat.pth')))
    valid_ens = []
    valid_en_lens = []
    for idx in range(5):
        valid_en = []
        with open(
                os.path.join(args.data_dir, 'flickr30k/caps',
                             'val.{}.bpe'.format(idx + 1))) as f:
            for line in f:
                line = line.strip()
                if line == "":
                    continue
                words = ["<bos>"] + line.split() + ["<eos>"]
                words = [word2idx[word] for word in words]
                valid_en.append(words)
        # Pad
        valid_en_len = [len(sent) for sent in valid_en]
        valid_en = [
            np.lib.pad(xx, (0, max(valid_en_len) - len(xx)),
                       'constant',
                       constant_values=(0, 0)) for xx in valid_en
        ]
        valid_ens.append(torch.tensor(valid_en).long())
        valid_en_lens.append(torch.tensor(valid_en_len).long())
    args.logger.info("Valid corpus built!")
    iters = -1
    should_stop = False
    for epoch in range(999999999):
        if should_stop:
            break
        for idx, (train_img, lab, path) in enumerate(train_loader):
            iters += 1
            if iters > args.max_training_steps:
                should_stop = True
                break
            if iters % args.eval_every == 0:
                # Periodic retrieval + loss validation
                res = get_retrieve_result(args, model,
                                          valid_caps=valid_ens,
                                          valid_lens=valid_en_lens,
                                          valid_img_feats=valid_img_feats)
                val_metrics = valid_model(args, model,
                                          valid_img_feats=valid_img_feats,
                                          valid_caps=valid_ens,
                                          valid_lens=valid_en_lens)
                args.logger.info("[VALID] update {} : {}".format(
                    iters, str(val_metrics)))
                if not args.debug:
                    write_tb(writer, monitor_names, res, iters,
                             prefix="dev/")
                    write_tb(writer, loss_names,
                             [val_metrics.__getattr__(name)
                              for name in loss_names],
                             iters, prefix="dev/")
                # Track mean of caption-R@1 and image-R@1 for model selection
                best.accumulate((res[0] + res[3]) / 2, iters)
                args.logger.info('model:' + args.prefix + args.hp_str)
                args.logger.info('epoch {} iters {}'.format(epoch, iters))
                args.logger.info(best)
                if args.early_stop and (
                        iters - best.iters) // args.eval_every > args.patience:
                    args.logger.info("Early stopping.")
                    return
                model.train()

            def get_lr_anneal(iters):
                # Linear anneal from args.lr down to args.lr_min
                lr_end = args.lr_min
                return max(0, (args.lr - lr_end) *
                           (args.linear_anneal_steps - iters) /
                           args.linear_anneal_steps) + lr_end

            if args.lr_anneal == "linear":
                opt.param_groups[0]['lr'] = get_lr_anneal(iters)
            opt.zero_grad()
            batch_size = len(path)
            # Map each image path to one of its 5 captions, chosen at random
            path = [p.split("/")[-1] for p in path]
            sentence_idx = [train_imgs.index(p) for p in path]
            en = [
                train_en[random.randint(0, 4)][sentence_i]
                for sentence_i in sentence_idx
            ]
            en_len = [len(x) for x in en]
            en = [
                np.lib.pad(xx, (0, max(en_len) - len(xx)),
                           'constant',
                           constant_values=(0, 0)) for xx in en
            ]
            en = cuda(torch.LongTensor(np.array(en).tolist()))
            en_len = cuda(torch.LongTensor(en_len))
            with torch.no_grad():
                # Frozen feature extraction
                train_img = resnet(train_img).view(batch_size, -1)
            # Drop <BOS>, shorten lengths accordingly
            R = model(en[:, 1:], en_len - 1, train_img)
            if args.img_pred_loss == "vse":
                R['loss'] = R['loss'].sum()
            elif args.img_pred_loss == "mse":
                R['loss'] = R['loss'].mean()
            else:
                raise Exception()
            total_loss = 0
            for loss_name in loss_names:
                total_loss += R[loss_name] * loss_cos[loss_name]
            train_metrics.accumulate(
                batch_size, *[R[name].item() for name in loss_names])
            total_loss.backward()
            if args.plot_grad:
                plot_grad(writer, model, iters)
            if args.grad_clip > 0:
                nn.utils.clip_grad_norm_(params, args.grad_clip)
            opt.step()
            if iters % args.eval_every == 0:
                args.logger.info("update {} : {}".format(
                    iters, str(train_metrics)))
            if iters % args.eval_every == 0 and not args.debug:
                write_tb(writer, loss_names,
                         [train_metrics.__getattr__(name)
                          for name in loss_names], \
                         iters, prefix="train/")
                write_tb(writer, ['lr'], [opt.param_groups[0]['lr']],
                         iters, prefix="train/")
                train_metrics.reset()
def train_model(args, model, iterators, extra_input):
    """Train the (optionally image-conditioned) captioning language model.

    Cycles over the configured datasets, optimizes token NLL, periodically
    validates on multi30k, and early-stops on dev NLL.
    """
    (train_its, dev_its) = iterators
    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9,
                               lr=args.lr)
    else:
        raise NotImplementedError
    loss_names, loss_cos = ["nll"], {"nll": 1.0}
    monitor_names = ["nll_rnd"]
    """
    if args.rep_pen_co > 0.0:
        loss_names.append("nll_cur")
        loss_cos["nll_cur"] = -1 * args.rep_pen_co
    else:
        monitor_names.append("nll_cur")
    """
    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    dev_metrics = Metrics('dev_loss', *loss_names, *monitor_names,
                          data_type="avg")
    best = Best(min, 'loss', 'iters', model=model, opt=opt,
                path=args.model_path + args.id_str, \
                gpu=args.gpu, debug=args.debug)
    iters = 0
    should_stop = False
    for epoch in range(999999999):
        if should_stop:
            break
        for dataset in args.dataset.split("_"):
            if should_stop:
                break
            train_it = train_its[dataset]
            for _, train_batch in enumerate(train_it):
                if iters >= args.max_training_steps:
                    args.logger.info(
                        'stopping training after {} training steps'.format(
                            args.max_training_steps))
                    should_stop = True
                    break
                if iters % args.eval_every == 0:
                    # Periodic dev evaluation + checkpointing on best NLL
                    dev_metrics.reset()
                    valid_model(args, model, dev_its['multi30k'],
                                dev_metrics, iters, loss_names,
                                monitor_names, extra_input)
                    if not args.debug:
                        write_tb(writer, loss_names,
                                 [dev_metrics.__getattr__(name)
                                  for name in loss_names], \
                                 iters, prefix="dev/")
                        write_tb(writer, monitor_names,
                                 [dev_metrics.__getattr__(name)
                                  for name in monitor_names], \
                                 iters, prefix="dev/")
                    best.accumulate(dev_metrics.nll, iters)
                    args.logger.info('model:' + args.prefix + args.hp_str)
                    args.logger.info('epoch {} dataset {} iters {}'.format(
                        epoch, dataset, iters))
                    args.logger.info(best)
                    if args.early_stop and (
                            iters -
                            best.iters) // args.eval_every > args.patience:
                        args.logger.info("Early stopping.")
                        return
                    model.train()

                def get_lr_anneal(iters):
                    # Linear anneal from args.lr down to args.lr_min
                    lr_end = args.lr_min
                    return max(0, (args.lr - lr_end) *
                               (args.linear_anneal_steps - iters) /
                               args.linear_anneal_steps) + lr_end

                if args.lr_anneal == "linear":
                    opt.param_groups[0]['lr'] = get_lr_anneal(iters)
                opt.zero_grad()
                batch_size = len(train_batch)
                img_input = None if args.no_img else cuda(
                    extra_input["img"][dataset][0].index_select(
                        dim=0, index=train_batch.idx.cpu()))
                if dataset == "coco":
                    # COCO batches carry 5 caption fields "_1".."_5"
                    en, en_len = train_batch.__dict__[
                        "_" + str(random.randint(1, 5))]
                elif dataset == "multi30k":
                    en, en_len = train_batch.en
                decoded = model(en, img_input)
                R = {}
                # Next-token NLL; 0 is the pad index
                R["nll"] = F.cross_entropy(decoded,
                                           en[:, 1:].contiguous().view(-1),
                                           ignore_index=0)
                #R["nll_cur"] = F.cross_entropy( decoded, en[:,:-1].contiguous().view(-1), ignore_index=0 )
                total_loss = 0
                for loss_name in loss_names:
                    total_loss += R[loss_name] * loss_cos[loss_name]
                train_metrics.accumulate(
                    batch_size, *[R[name].item() for name in loss_names])
                total_loss.backward()
                if args.plot_grad:
                    plot_grad(writer, model, iters)
                if args.grad_clip > 0:
                    nn.utils.clip_grad_norm_(params, args.grad_clip)
                opt.step()
                iters += 1
                if iters % args.eval_every == 0:
                    args.logger.info("update {} : {}".format(
                        iters, str(train_metrics)))
                if iters % args.eval_every == 0 and not args.debug:
                    write_tb(writer, loss_names,
                             [train_metrics.__getattr__(name)
                              for name in loss_names], \
                             iters, prefix="train/")
                    write_tb(writer, ['lr'], [opt.param_groups[0]['lr']],
                             iters, prefix="train/")
                    train_metrics.reset()
def get_grounding(self,
                  en_msg,
                  en_msg_len,
                  batch,
                  en_lm=None,
                  all_img=None,
                  ranker=None,
                  use_gumbel_tokens=False):
    """ Forward speaker with English sentence to get grounding loss and rewards

    :param en_msg: (batch_size, en_msg_len) English message tokens (no <BOS>)
    :param en_msg_len: (batch_size) message lengths
    :param batch: source batch (provides .idx for image lookup)
    :param en_lm: optional English language model for fluency reward
    :param all_img: optional image-feature table indexed by batch.idx
    :param ranker: optional image ranker for grounding reward
    :param use_gumbel_tokens: route through the speaker's Gumbel one-hots
    :return: (results dict of scalar diagnostics, rewards dict of per-example rewards)
    """
    results = {}
    rewards = {}
    batch_size = en_msg.shape[0]
    # NOTE add <BOS> to beginning
    en_msg_ = torch.cat([
        cuda(torch.full((batch_size, 1), self.init_token)).long(), en_msg
    ], dim=1)
    gumbel_tokens = None
    if use_gumbel_tokens:
        # Prepend a <BOS> one-hot column to the stored Gumbel tokens
        gumbel_tokens = self.fr_en.dec.gumbel_tokens
        init_tokens = torch.zeros(
            [gumbel_tokens.shape[0], 1, gumbel_tokens.shape[2]])
        init_tokens = init_tokens.to(device=gumbel_tokens.device)
        init_tokens[:, :, self.init_token] = 1
        gumbel_tokens = torch.cat([init_tokens, gumbel_tokens], dim=1)
    if self.use_en_lm:  # monitor EN LM NLL
        if "wiki" in self.en_lm_dataset:
            if use_gumbel_tokens:
                raise NotImplementedError
            en_nll_lm = en_lm.get_nll(en_msg_)  # (batch_size, en_msg_len)
            if self.train_en_lm:
                en_nll_lm = sum_reward(en_nll_lm,
                                       en_msg_len + 1)  # (batch_size)
                rewards['lm'] = -1 * en_nll_lm.detach()
                # R = R + -1 * en_nll_lm.detach() * self.en_lm_nll_co # (batch_size)
            results.update({"en_nll_lm": en_nll_lm.mean()})
        elif self.en_lm_dataset in ["coco", "multi30k"]:
            if use_gumbel_tokens:
                # train() keeps the graph through the LM for the ST gradient
                en_lm.train()
                en_nll_lm = en_lm.get_loss_oh(gumbel_tokens, None)
                en_lm.eval()
            else:
                en_nll_lm = en_lm.get_loss(
                    en_msg_, None)  # (batch_size, en_msg_len)
            if self.train_en_lm:
                en_nll_lm = sum_reward(en_nll_lm,
                                       en_msg_len + 1)  # (batch_size)
                rewards['lm'] = -1 * en_nll_lm.detach()
            results.update({"en_nll_lm": en_nll_lm.mean()})
        else:
            raise Exception()
    if self.use_ranker:
        # NOTE Experiment 3 : Reward = NLL_DE + NLL_EN_LM + NLL_IMG_PRED
        if use_gumbel_tokens and self.train_ranker:
            raise NotImplementedError
        ranker.eval()
        img = cuda(all_img.index_select(
            dim=0, index=batch.idx.cpu()))  # (batch_size, D_img)
        if self.img_pred_loss == "nll":
            img_pred_loss = ranker.get_loss(
                en_msg_, img)  # (batch_size, en_msg_len)
            img_pred_loss = sum_reward(img_pred_loss,
                                       en_msg_len + 1)  # (batch_size)
        else:
            with torch.no_grad():
                img_pred_loss = ranker(en_msg, en_msg_len, img)["loss"]
        if self.train_ranker:
            rewards['img_pred'] = -1 * img_pred_loss.detach()
        results.update({
            "img_pred_loss_{}".format(self.img_pred_loss):
            img_pred_loss.mean()
        })
        # Get ranker retrieval result
        with torch.no_grad():
            K = 19
            # Randomly select K distractor image
            random_idx = torch.randint(all_img.shape[0],
                                       size=[batch_size, K])
            wrong_img = cuda(
                all_img.index_select(dim=0, index=random_idx.view(-1)))
            wrong_img_feat = ranker.batch_enc_img(wrong_img).view(
                batch_size, K, -1)
            right_img_feat = ranker.batch_enc_img(img)
            # [bsz, K+1, hid_size]
            all_feat = torch.cat(
                [right_img_feat.unsqueeze(1), wrong_img_feat], dim=1)
            # [bsz, hid_size]
            cap_feats = ranker.batch_cap_rep(en_msg, en_msg_len)
            # R@1 accuracy: correct image (slot 0) ranked first
            scores = (cap_feats.unsqueeze(1) * all_feat).sum(-1)
            r1_acc = (torch.argmax(scores, -1) == 0).float().mean()
            results['r1_acc'] = r1_acc
    return results, rewards
def train_model(args, model, iterators, extra_input):
    """Train the ranker on precomputed image features.

    Cycles over the configured datasets, optimizes the model's grounding
    loss, and periodically runs retrieval validation; model selection is on
    the mean of caption-R@1 and image-R@1.
    """
    (train_its, dev_its) = iterators
    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9,
                               lr=args.lr)
    else:
        raise NotImplementedError
    loss_names, loss_cos = ["loss"], {"loss": 1.0}
    monitor_names = "cap_r1 cap_r5 cap_r10 img_r1 img_r5 img_r10".split()
    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    best = Best(max, 'r1', 'iters', model=model, opt=opt,
                path=args.model_path + args.id_str, \
                gpu=args.gpu, debug=args.debug)
    iters = 0
    for epoch in range(999999999):
        for dataset in args.dataset.split("_"):
            train_it = train_its[dataset]
            for _, train_batch in enumerate(train_it):
                iters += 1
                if iters % args.eval_every == 0:
                    # Periodic retrieval validation on multi30k
                    R = valid_model(args, model, dev_its['multi30k'],
                                    extra_input)
                    if not args.debug:
                        write_tb(writer, monitor_names, R, iters,
                                 prefix="dev/")
                    # Mean of caption-R@1 and image-R@1 drives selection
                    best.accumulate((R[0] + R[3]) / 2, iters)
                    args.logger.info('model:' + args.prefix + args.hp_str)
                    args.logger.info('epoch {} dataset {} iters {}'.format(
                        epoch, dataset, iters))
                    args.logger.info(best)
                    if args.early_stop and (
                            iters -
                            best.iters) // args.eval_every > args.patience:
                        args.logger.info("Early stopping.")
                        return
                    model.train()

                def get_lr_anneal(iters):
                    # Linear anneal from args.lr down to args.lr_min
                    lr_end = args.lr_min
                    return max(0, (args.lr - lr_end) *
                               (args.linear_anneal_steps - iters) /
                               args.linear_anneal_steps) + lr_end

                if args.lr_anneal == "linear":
                    opt.param_groups[0]['lr'] = get_lr_anneal(iters)
                opt.zero_grad()
                batch_size = len(train_batch)
                img = extra_input["img"][dataset][0].index_select(
                    dim=0, index=train_batch.idx.cpu())  # (batch_size, D_img)
                # Pick one of the 5 reference captions at random
                en, en_len = train_batch.__dict__["_" +
                                                  str(random.randint(1, 5))]
                # Drop <BOS>, shorten lengths accordingly
                R = model(en[:, 1:], en_len - 1, cuda(img))
                R['loss'] = R['loss'].mean()
                total_loss = 0
                for loss_name in loss_names:
                    total_loss += R[loss_name] * loss_cos[loss_name]
                train_metrics.accumulate(
                    batch_size, *[R[name].item() for name in loss_names])
                total_loss.backward()
                if args.plot_grad:
                    plot_grad(writer, model, iters)
                if args.grad_clip > 0:
                    nn.utils.clip_grad_norm_(params, args.grad_clip)
                opt.step()
                if iters % args.eval_every == 0:
                    args.logger.info("update {} : {}".format(
                        iters, str(train_metrics)))
                if iters % args.eval_every == 0 and not args.debug:
                    write_tb(writer, loss_names,
                             [train_metrics.__getattr__(name)
                              for name in loss_names], \
                             iters, prefix="train/")
                    write_tb(writer, ['lr'], [opt.param_groups[0]['lr']],
                             iters, prefix="train/")
                    train_metrics.reset()
def beam_search(self, h, h_len, width):
    """Batched beam search over this attentional RNN decoder.

    Args:
        h: encoder states, (batch_size, x_seq_len, D_hid * n_dir).
        h_len: source lengths, (batch_size).
        width: beam width.

    Returns:
        list of batch_size token-id lists (each starts with init_token),
        the best length-normalized hypothesis per example.

    Beam bookkeeping: each hypothesis is a tuple
    (cumulative log-prob, token list, flat index), where the flat index
    encodes (parent beam slot * voc_size + token) for backtracking.
    """
    # (batch_size, x_seq_len, D_hid * n_dir)
    voc_size, batch_size, x_seq_len = self.voc_sz_trg, h.size()[0], h.size()[1]
    # One live hypothesis per example to start; "2" is the initial flat index
    # placeholder — TODO confirm it is never dereferenced before being overwritten.
    live = [ [ ( 0.0, [ self.init_token ], 2 ) ] for ii in range(batch_size) ]
    dead = [ [] for ii in range(batch_size) ]        # finished hypotheses
    n_dead = [0 for ii in range(batch_size)]
    xmask = xlen_to_inv_mask(h_len)[:,None,:]        # (batch_size, 1, x_seq_len)
    # Generation budget: fixed max, or proportional to source length, clamped.
    max_len_gen = self.max_len_gen if self.msg_len_ratio < 0.0 else int(x_seq_len * self.msg_len_ratio)
    max_len_gen = np.clip(max_len_gen, 2, self.max_len_gen).item()
    # Initial decoder state: mean of encoder states over valid positions
    # (cumsum gathered at the last valid index, divided by length).
    h_index = (h_len - 1)[:,None,None].repeat(1, 1, h.size(2))
    z_t = torch.cumsum(h, dim=1)                     # (batch_size, x_seq_len, D_hid * n_dir)
    z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1)  # (batch_size, D_hid * n_dir)
    z_t = torch.div( z_t, h_len[:, None].float() )
    z_t = self.ctx_to_z0(z_t)                        # (batch_size, n_layers * D_hid)
    z_t = z_t.view(batch_size, self.n_layers, self.D_hid).transpose(0,1).contiguous()
    input = cuda( torch.full((batch_size, 1), self.init_token).long() )
    input = self.emb( input )                        # (batch_size, 1, D_emb)
    # Precompute the encoder-side attention projection once.
    h_big = h.contiguous().view(-1, self.D_hid * self.n_dir )
    # (batch_size * x_seq_len, D_hid * n_dir)
    ctx_h = self.h_to_att( h_big ).view(batch_size, 1, x_seq_len, self.D_att)
    # NOTE (batch_size, 1, x_seq_len, D_att)
    for tidx in range(max_len_gen):
        # First step has a single hypothesis per example; afterwards `width`.
        cwidth = 1 if tidx == 0 else width
        # input (batch_size * width, 1, D_emb)
        # z_t (n_layers, batch_size * width, D_hid)
        _, z_t_ = self.rnn( input, z_t )
        # z_t_ (n_layers, batch_size * width, D_hid)
        ctx_z_t_ = z_t_.transpose(0,1).contiguous().view(batch_size * cwidth, -1)
        # (batch_size * width, n_layers * D_hid)
        # Additive (MLP) attention: score = v . tanh(W_y y + W_s s + W_h h).
        ctx_y = self.y_to_att( input.view(-1, self.D_emb) ).view(batch_size, cwidth, 1, self.D_att)
        ctx_s = self.z_to_att( ctx_z_t_ ).view(batch_size, cwidth, 1, self.D_att)
        ctx = F.tanh(ctx_y + ctx_s + ctx_h)          # (batch_size, cwidth, x_seq_len, D_att)
        ctx = ctx.view(batch_size * cwidth * x_seq_len, self.D_att)
        score = self.att_to_score(ctx).view(batch_size, -1, x_seq_len)  # (batch_size, cwidth, x_seq_len)
        # Mask padded source positions before the softmax (in-place).
        score.masked_fill_(xmask.repeat(1, cwidth, 1), -float('inf'))
        score = F.softmax( score.view(-1, x_seq_len), dim=1 ).view(batch_size, -1, x_seq_len)
        score = score.view(batch_size, cwidth, x_seq_len, 1)  # (batch_size, width, x_seq_len, 1)
        # Context vector: attention-weighted sum of encoder states.
        c_t = torch.mul( h.view(batch_size, 1, x_seq_len, -1), score )  # (batch_size, width, x_seq_len, D_hid * n_dir)
        c_t = torch.sum( c_t, 2).view(batch_size * cwidth, -1)          # (batch_size * width, D_hid * n_dir)
        # Second RNN consumes the context, continuing from z_t_.
        out, z_t = self.rnn2( c_t[:,None,:], z_t_ )
        # out (batch_size * width, 1, D_hid); z_t (n_layers, batch_size * width, D_hid)
        # Output projection combines input embedding, context, and RNN state.
        fin_y = self.y_to_out( input.view(-1, self.D_emb) )   # (batch_size * width, D_out)
        fin_c = self.c_to_out( c_t )                          # (batch_size * width, D_out)
        fin_s = self.z_to_out( out.view(-1, self.D_hid) )     # (batch_size * width, D_out)
        fin = F.tanh( fin_y + fin_c + fin_s )
        # .data detaches: no gradients flow through beam search.
        cur_prob = F.log_softmax( self.out( fin.view(-1, self.D_out) ), dim=1 )\
            .view(batch_size, cwidth, voc_size).data          # (batch_size, width, voc_sz_trg)
        # Add each live hypothesis's accumulated score.
        pre_prob = cuda( torch.FloatTensor( [ [ x[0] for x in ee ] for ee in live ] ).view(batch_size, cwidth, 1) )
        # (batch_size, width, 1)
        total_prob = cur_prob + pre_prob                      # (batch_size, cwidth, voc_size)
        total_prob = total_prob.view(batch_size, -1)
        # topi_s indexes the flattened (beam, token) space; topv_s is the
        # per-step (not cumulative) log-prob of each selected expansion.
        _, topi_s = total_prob.topk(width, dim=1)
        topv_s = cur_prob.view(batch_size, -1).gather(1, topi_s)  # (batch_size, width)
        new_live = [ [] for ii in range(batch_size) ]
        for bidx in range(batch_size):
            n_live = width - n_dead[bidx]
            if n_live > 0:
                tis = topi_s[bidx][:n_live]
                tvs = topv_s[bidx][:n_live]
                for eidx, (topi, topv) in enumerate(zip(tis, tvs)):  # NOTE max width times
                    # topi // voc_size = parent beam slot; topi % voc_size = token id.
                    if topi % voc_size == self.eoz_token :
                        # NOTE(review): this class uses `eoz_token` while the
                        # `beam` method below uses `eos_token` — presumably two
                        # different decoder classes; verify attribute name.
                        dead[bidx].append( ( live[bidx][ topi // voc_size ][0] + topv, \
                                             live[bidx][ topi // voc_size ][1] + [ topi % voc_size ],\
                                             topi) )
                        n_dead[bidx] += 1
                    else:
                        new_live[bidx].append( ( live[bidx][ topi // voc_size ][0] + topv, \
                                                 live[bidx][ topi // voc_size ][1] + [ topi % voc_size ],\
                                                 topi) )
            # Pad with dummy hypotheses so every example keeps `width` slots.
            while len(new_live[bidx]) < width:
                new_live[bidx].append( ( -99999999999, \
                                         [0],\
                                         0) )
        live = new_live
        if n_dead == [width for ii in range(batch_size)]:
            break  # every example has a full dead set
        # Next-step inputs: embed the chosen tokens...
        in_vocab_idx = [ [ x[2] % voc_size for x in ee ] for ee in live ]  # NOTE batch_size first
        input = self.emb( cuda( torch.LongTensor( in_vocab_idx ) ).view(-1) )\
            .view(-1, 1, self.D_emb)
        # input (batch_size * width, 1, D_emb)
        # ...and gather each survivor's parent state from the flat beam axis.
        in_width_idx = [ [ x[2] // voc_size + bbidx * cwidth for x in ee ] for bbidx, ee in enumerate(live) ]
        # live : (batch_size, width)
        z_t = z_t.index_select( 1, cuda( torch.LongTensor( in_width_idx ).view(-1) ) ).\
            view(self.n_layers, batch_size * width, self.D_hid)
        # z_t (n_layers, batch_size * width, D_hid)
    # Budget exhausted: promote remaining live hypotheses to dead.
    for bidx in range(batch_size):
        if n_dead[bidx] < width:
            for didx in range( width - n_dead[bidx] ):
                (a, b, c) = live[bidx][didx]
                dead[bidx].append( (a, b, c) )
    # Length normalization (GNMT-style): divide by ((5+len)/(5+1))^alpha.
    dead_ = [ [ ( a / ( math.pow(5+len(b), self.beam_alpha) / math.pow(5+1, self.beam_alpha) ) , b, c) for (a,b,c) in ee] for ee in dead]
    ans = []
    for dd_ in dead_:
        dd = sorted( dd_, key=operator.itemgetter(0), reverse=True )
        ans.append( dd[0][1] )  # best hypothesis's token list
    return ans
def send(self, h, h_len, y_len, send_method):
    """Generate a message from encoder states, one token at a time.

    Args:
        h: encoder states, (batch_size, x_seq_len, D_hid * n_dir).
        h_len: source lengths, (batch_size).
        y_len: reference target lengths, used to bound/shape message length.
        send_method: "argmax" | "gumbel" (straight-through one-hot) |
            "reinforce" (sample from Categorical; log-probs stored on self).

    Returns:
        dict with "msg" (generated tokens — ids for argmax/reinforce,
        one-hot rows for gumbel), "seq_lens" (per-example lengths), and,
        for reinforce with entropy enabled, "batch_logits".

    Side effects: sets self.log_probs (stacked for reinforce).
    """
    # h : (batch_size, x_seq_len, D_hid * n_dir)
    # h_len : (batch_size)
    batch_size, x_seq_len = h.size()[:2]
    xmask = xlen_to_inv_mask(h_len)  # (batch_size, x_seq_len)
    # Generation budget: fixed max, or proportional to the longest reference.
    max_len_gen = self.max_len_gen if self.msg_len_ratio < 0.0 else int(y_len.max().item() * self.msg_len_ratio)
    max_len_gen = np.clip(max_len_gen, 2, self.max_len_gen).item()
    # Initial decoder state = mean of encoder states over valid positions.
    h_index = (h_len - 1)[:,None,None].repeat(1, 1, h.size(2))
    z_t = torch.cumsum(h, dim=1)                     # (batch_size, x_seq_len, D_hid * n_dir)
    z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1)  # (batch_size, D_hid * n_dir)
    z_t = torch.div( z_t, h_len[:, None].float() )
    z_t = self.ctx_to_z0(z_t)                        # (batch_size, n_layers * D_hid)
    z_t = z_t.view(batch_size, self.n_layers, self.D_hid).transpose(0,1).contiguous()
    y_emb = cuda( torch.full((batch_size, 1), self.init_token).long() )
    y_emb = self.emb( y_emb )                        # (batch_size, 1, D_emb)
    y_emb = F.dropout( y_emb, p=self.drop_ratio, training=self.training )
    # Precompute the encoder-side attention projection once.
    h_big = h.view(-1, h.size(2) )                   # (batch_size * x_seq_len, D_hid * n_dir)
    ctx_h = self.h_to_att( h_big ).view(batch_size, x_seq_len, self.D_att)
    done = cuda( torch.zeros(batch_size).long() )    # 1 once an example emitted EOS
    seq_lens = cuda( torch.zeros(batch_size).fill_(max_len_gen).long() )      # (batch_size)
    max_seq_lens = cuda( torch.zeros(batch_size).fill_(max_len_gen).long() )  # (batch_size)
    eos_tensor = cuda( torch.zeros(batch_size).fill_( self.eoz_token )).long()
    msg = []
    self.log_probs = []  # per-step log probs (reinforce only), (batch_size, seq_len)
    batch_logits = 0
    for idx in range(max_len_gen):
        # in (batch_size, 1, D_emb); z_t (n_layers, batch_size, D_hid)
        _, z_t_ = self.rnn( y_emb, z_t )
        # z_t_ (n_layers, batch_size, D_hid)
        ctx_z_t_ = z_t_.transpose(0,1).contiguous().view(batch_size, self.n_layers * self.D_hid)
        # (batch_size, n_layers * D_hid)
        # Additive attention over source positions.
        ctx_y = self.y_to_att( y_emb.view(batch_size, self.D_emb) )[:,None,:]
        ctx_s = self.z_to_att( ctx_z_t_ )[:,None,:]
        ctx = F.tanh(ctx_y + ctx_s + ctx_h)
        ctx = ctx.view(batch_size * x_seq_len, self.D_att)
        score = self.att_to_score(ctx).view(batch_size, -1)  # (batch_size, x_seq_len)
        score.masked_fill_(xmask, -float('inf'))  # in-place mask of padding
        score = F.softmax( score, dim=1 )
        score = score[:,:,None]                   # (batch_size, x_seq_len, 1)
        c_t = torch.mul( h, score )               # (batch_size, x_seq_len, D_hid * n_dir)
        c_t = torch.sum( c_t, 1)                  # (batch_size, D_hid * n_dir)
        # Second RNN consumes the context, continuing from z_t_.
        out, z_t = self.rnn2( c_t[:,None,:], z_t_ )
        # out (batch_size, 1, D_hid); z_t (n_layers, batch_size, D_hid)
        # Output projection combines input embedding, context, and RNN state.
        fin_y = self.y_to_out( y_emb.view(batch_size, self.D_emb) )  # (batch_size, D_out)
        fin_c = self.c_to_out( c_t )                                 # (batch_size, D_out)
        fin_s = self.z_to_out( out.view(batch_size, self.D_hid) )    # (batch_size, D_out)
        fin = F.tanh( fin_y + fin_c + fin_s )
        logit = self.out( fin )                   # (batch_size, voc_sz_trg)
        if send_method == "argmax":
            tokens = logit.data.max(dim=1)[1]     # (batch_size), detached
            tokens_idx = tokens
        elif send_method == "gumbel":
            # Straight-through Gumbel-softmax: tokens are (soft) one-hot rows.
            tokens = gumbel_softmax_hard(logit, self.temp, self.st)\
                .view(batch_size, self.voc_sz_trg)  # (batch_size, voc_sz_trg)
            # Recover hard token ids from the one-hot rows.
            tokens_idx = (tokens * cuda(torch.arange(self.voc_sz_trg))[None,:]).sum(dim=1).long()
        elif send_method == "reinforce":
            cat = Categorical(logits=logit)
            tokens = cat.sample()
            tokens_idx = tokens
            self.log_probs.append( cat.log_prob(tokens) )
            if self.entropy:
                # Accumulate logits of still-unfinished examples only
                # (used for an entropy regularizer by the caller).
                batch_logits += (logit * (1-done)[:,None].float()).sum(dim=0)
        msg.append(tokens.unsqueeze(dim=1))
        is_next_eos = ( tokens_idx == eos_tensor ).long()  # (batch_size), 1 if EOS
        done = (done + is_next_eos).clamp(min=0, max=1).long()
        # Record idx+1 as the length for examples whose EOS was just emitted;
        # torch.min keeps the earliest recorded length per example.
        new_seq_lens = max_seq_lens.clone().masked_fill_(mask=is_next_eos.byte(), \
                                                         value=float(idx+1))
        # max_seq_lens : (batch_size)
        seq_lens = torch.min(seq_lens, new_seq_lens)
        if done.sum() == batch_size:
            break  # all examples finished
        y_emb = self.emb(tokens)[:,None,:]  # (batch_size, 1, D_emb)
    if self.msg_len_ratio > 0.0 :
        # Override lengths: proportional to reference length, capped at
        # the number of generated steps.
        seq_lens = torch.clamp( (y_len.float() * self.msg_len_ratio).floor_(), 1, len(msg) ).long()
    msg = torch.cat(msg, dim=1)
    # (batch_size, seq_len) if argmax or reinforce
    # (batch_size, seq_len, voc_sz_trg) if gumbel
    if send_method == "reinforce":
        # Stack per-token log probs so the caller can sum them into a
        # whole-message log probability.
        self.log_probs = torch.stack(self.log_probs, dim=1)  # (batch_size, seq_len)
    results = {"msg":msg, "seq_lens":seq_lens}
    if send_method == "reinforce" and self.entropy:
        results.update( {"batch_logits":batch_logits} )
    return results
def beam(self, src_hid, src_len):
    """Batched beam search returning the top-5 hypotheses per example.

    Args:
        src_hid: encoder states, (batch_size, x_seq_len, D_hid * n_dir).
        src_len: source lengths, (batch_size).

    Returns:
        ans: list of 5 lists; ans[k][b] is the k-th best token list for
        example b (consumed by multi_decode / the 5-way ensemble decoder).

    NOTE(review): the final ranking reads dd[0..4], so this assumes
    self.beam_width >= 5 — verify the configured beam width.

    Beam bookkeeping: each hypothesis is (cumulative log-prob, token list,
    flat index), flat index = parent beam slot * voc_size + token.
    """
    # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
    # src_len : (batch_size)
    batch_size, x_seq_len = src_hid.size()[:2]
    src_mask = xlen_to_inv_mask( src_len, seq_len=x_seq_len)  # (batch_size, x_seq_len)
    voc_size, width = self.voc_sz_trg, self.beam_width
    y_seq_len = self.max_len_gen
    # Initial decoder state = mean of encoder states over valid positions.
    h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
    z_t = torch.cumsum(src_hid, dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
    z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1)  # (batch_size, D_hid * n_dir)
    z_t = torch.div(z_t, src_len[:, None].float())
    y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
    y_emb = self.emb(y_emb)  # (batch_size, D_emb)
    # Input-feeding: previous attention output, zero-initialized.
    input_feed = y_emb.data.new(batch_size, self.D_hid).zero_()
    # "2" is the initial flat-index placeholder — TODO confirm it is never
    # dereferenced before being overwritten on the first expansion.
    live = [[(0.0, [self.init_token], 2)] for ii in range(batch_size)]
    dead = [[] for ii in range(batch_size)]   # finished hypotheses
    n_dead = [0 for ii in range(batch_size)]
    # Working copies, expanded to batch_size * width after the first step.
    src_hid_ = src_hid
    src_mask_ = src_mask
    for idx in range(y_seq_len):
        # First step has a single hypothesis per example; afterwards `width`.
        cwidth = 1 if idx == 0 else width
        input = torch.cat([y_emb, input_feed], dim=1)  # (batch_size * width, D_emb + D_hid)
        trg_hid = self.layers[0](input, z_t)           # (batch_size * width, D_hid)
        z_t = trg_hid                                  # (batch_size * width, D_hid)
        out, attn_scores = self.attention(
            trg_hid, src_hid_, src_mask_)              # (batch_size * width, D_hid)
        input_feed = out                               # (batch_size * width, D_hid)
        logit = self.out(out)                          # (batch_size * width, voc_sz_trg)
        cur_prob = F.log_softmax(logit, dim=1).view(batch_size, cwidth, self.voc_sz_trg)
        # Add each live hypothesis's accumulated score.
        pre_prob = cuda( torch.FloatTensor( [ [ x[0] for x in ee ] for ee in live ] )\
            .view(batch_size, cwidth, 1) )             # (batch_size, width, 1)
        total_prob = cur_prob + pre_prob               # (batch_size, cwidth, voc_sz)
        total_prob = total_prob.view(batch_size, -1)   # (batch_size, cwidth * voc_sz)
        # topi_s indexes the flattened (beam, token) space; topv_s is the
        # per-step (not cumulative) log-prob of each selected expansion.
        topi_s = total_prob.topk(width, dim=1)[1]
        topv_s = cur_prob.view(batch_size, -1).gather(1, topi_s)
        new_live = [[] for ii in range(batch_size)]
        for bidx in range(batch_size):
            n_live = width - n_dead[bidx]
            if n_live > 0:
                tis = topi_s[bidx][:n_live].cpu().numpy().tolist()
                tvs = topv_s[bidx][:n_live].cpu().numpy().tolist()
                for eidx, (topi, topv) in enumerate(zip(tis, tvs)):
                    # topi // voc_size = parent beam slot; topi % voc_size = token.
                    if topi % voc_size == self.eos_token:
                        dead[bidx].append(
                            (live[bidx][topi // voc_size][0] + topv,
                             live[bidx][topi // voc_size][1] + [topi % voc_size],
                             topi))
                        n_dead[bidx] += 1
                    else:
                        new_live[bidx].append(
                            (live[bidx][topi // voc_size][0] + topv,
                             live[bidx][topi // voc_size][1] + [topi % voc_size],
                             topi))
            # Pad with dummy hypotheses so every example keeps `width` slots.
            while len(new_live[bidx]) < width:
                new_live[bidx].append((-99999999999, [0], 0))
        live = new_live
        if n_dead == [width for ii in range(batch_size)]:
            break  # every example has a full dead set
        # Next-step inputs: embed the chosen tokens...
        in_vocab_idx = np.array([[x[2] % voc_size for x in ee]
                                 for ee in live])  # NOTE batch_size first
        y_emb = self.emb( cuda( torch.LongTensor( in_vocab_idx ) ).view(-1) )\
            .view(-1, self.D_emb)
        # y_emb (batch_size * width, D_emb)
        # ...and gather each survivor's parent state/feed/source tensors from
        # the flat beam axis (first step also expands them 1 -> width).
        in_width_idx = np.array( [ [ x[2] // voc_size + bbidx * cwidth for x in ee ] \
                                   for bbidx, ee in enumerate(live) ] )
        in_width_idx = cuda(torch.LongTensor(in_width_idx).view(-1))
        z_t = z_t.index_select(0, in_width_idx).view(batch_size * width, self.D_hid)
        input_feed = input_feed.index_select(0, in_width_idx).view(
            batch_size * width, self.D_hid)
        src_hid_ = src_hid_.index_select(0, in_width_idx).view(
            batch_size * width, x_seq_len, self.D_hid)
        src_mask_ = src_mask_.index_select(0, in_width_idx).view(
            batch_size * width, x_seq_len)
    # Budget exhausted: promote remaining live hypotheses to dead.
    for bidx in range(batch_size):
        if n_dead[bidx] < width:
            for didx in range(width - n_dead[bidx]):
                (a, b, c) = live[bidx][didx]
                dead[bidx].append((a, b, c))
    # Length normalization (GNMT-style): divide by ((5+len)/(5+1))^alpha.
    dead_ = [ [ ( a / ( math.pow(5+len(b), self.beam_alpha) / math.pow(5+1, self.beam_alpha) ) , b, c)\
                for (a,b,c) in ee] for ee in dead]
    # Collect the 5 best hypotheses per example, grouped by rank.
    ans = [[], [], [], [], []]
    for dd_ in dead_:
        dd = sorted(dd_, key=operator.itemgetter(0), reverse=True)
        for idx in range(5):
            ans[idx].append(dd[idx][1])
        #ans.append( dd[0][1] )
    return ans