def getProb(input_event, output_event, category='xNeed'):
    batch = set_atomic_inputs(
        input_event, output_event, category, data_loader, text_encoder)

    start_idx = data_loader.max_event + \
        data.atomic_data.num_delimiter_tokens["category"]

    XMB = batch["sequences"][:, :start_idx]
    MMB = batch["attention_mask"][:, :start_idx]

    XMB = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, XMB.unsqueeze(-1))

    # Accumulate the log-likelihood of output_event one token at a time.
    beam_ll = 0
    for w in output_event.split():
        lm_probs = F.log_softmax(
            model(XMB.unsqueeze(1), sequence_mask=MMB), dim=-1)
        dist = lm_probs[:, -1, :].squeeze()

        word = w + '</w>'
        if word not in data_loader.vocab_encoder:
            # Out-of-vocabulary token: return a large negative score.
            return -1000

        tok_ll = dist[data_loader.vocab_encoder[word]]
        next_tok = torch.tensor(
            [[data_loader.vocab_encoder[word]]],
            dtype=torch.long, device=MMB.device)
        beam_ll += tok_ll

        # Teacher-force the gold token and advance the position ids.
        next_pos = XMB[:, -1:, 1] + 1
        next_x = torch.cat((next_tok, next_pos), -1).unsqueeze(1)
        XMB = torch.cat((XMB, next_x), 1)
        MMB = torch.cat(
            [MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

    return beam_ll
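# A minimal, self-contained sketch of the scoring idea above, with a
# hypothetical TinyLM standing in for the COMET transformer: append each
# gold token in turn and accumulate its log-probability under the model.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLM(nn.Module):
    """Hypothetical stand-in: embeds token ids and projects to logits."""
    def __init__(self, vocab_size=100, dim=32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, vocab_size)

    def forward(self, x):  # x: (batch, seq) -> (batch, seq, vocab)
        return self.proj(self.emb(x))

def sequence_logprob(model, context_ids, target_ids):
    """Sum log P(target_t | context, target_<t), teacher-forcing each token."""
    ids = context_ids.clone()
    total = 0.0
    for tok in target_ids.tolist():
        logprobs = F.log_softmax(model(ids.unsqueeze(0))[0, -1], dim=-1)
        total += logprobs[tok].item()                # score the gold token
        ids = torch.cat([ids, torch.tensor([tok])])  # then feed it back in
    return total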
def generate_sequence(self, batch, model, data_loader, start_idx, end_len):
    XMB = batch["sequences"][:, :start_idx]
    MMB = batch["attention_mask"][:, :start_idx]

    XMB = model_utils.prepare_position_embeddings(
        self.opt, data_loader.vocab_encoder, XMB.unsqueeze(-1))

    # Greedy decoding: take the argmax token at every step.
    lm_probs = F.log_softmax(
        model(XMB.unsqueeze(1), sequence_mask=MMB), dim=-1)
    values, indices = lm_probs[:, -1, :].max(dim=-1)

    seqs = indices.clone().unsqueeze(1)
    loss = values
    counts = 1

    next_pos = XMB[:, -1:, 1] + 1
    next_x = torch.cat((indices.view(-1, 1), next_pos), -1).unsqueeze(1)
    XMB = torch.cat((XMB, next_x), 1)
    MMB = torch.cat(
        [MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

    for step in range(self.opt.eval.smax):
        lm_probs = F.log_softmax(
            model(XMB.unsqueeze(1), sequence_mask=MMB), dim=-1)

        # Greedy step (argmax over the vocabulary), not sampling.
        values, next_idx = lm_probs[:, -1, :].max(dim=-1)

        loss += values
        counts += 1

        next_idx = next_idx.unsqueeze(1)
        seqs = torch.cat([seqs, next_idx], 1)

        if (next_idx.item() == self.end_token) or (step == end_len - 1):
            break

        XMB, MMB = self.append_batch(XMB, next_idx, MMB)

    beams = []
    for beam in seqs:
        beams.append(" ".join("".join(
            [data_loader.vocab_decoder[tok.item()].replace(
                '</w>', ' ').replace('\n', '')
             for tok in beam if tok != self.end_token]).split()))

    sampling_result = {
        "sequence": beams[0],
        "beams": beams,
        "beam_losses": [loss.item()],
        "loss": loss.item(),
        "beam_lengths": [counts],
        "length": counts
    }

    return sampling_result
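# For contrast with the beam-search and top-k variants below, a minimal
# greedy-decoding sketch (hypothetical `model` maps (1, seq) ids to
# (1, seq, vocab) logits, as in the TinyLM sketch above).
import torch
import torch.nn.functional as F

def greedy_decode(model, context_ids, end_token, max_len):
    ids, out, score = context_ids.clone(), [], 0.0
    for _ in range(max_len):
        logprobs = F.log_softmax(model(ids.unsqueeze(0))[0, -1], dim=-1)
        val, idx = logprobs.max(dim=-1)  # argmax token at each step
        score += val.item()
        out.append(idx.item())
        if idx.item() == end_token:      # stop on the end token
            break
        ids = torch.cat([ids, idx.view(1)])
    return out, score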
def batch_conceptnet_generate(opt, nums, losses, batch_variables,
                              eval_mode=False, tracking_mode=False):
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]
    category = batch_variables["category"]

    batch, reset = data_loader.sample_batch(
        split, bs=opt.train.dynamic.bs, cat=category)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    # Shift inputs by one position to obtain next-token targets.
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # Set loss names: negative samples get their own counters in eval mode.
    if not eval_mode or batch_variables["category"] == "positive":
        micro_name = "total_micro"
        macro_name = "total_macro"
    else:
        micro_name = "negative_micro"
        macro_name = "negative_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(
        losses, nums, micro_name, macro_name, bs, length, final_loss, split)

    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    if tracking_mode:
        outputs["tracking"] = final_loss.squeeze().tolist()

    return outputs
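# A small worked sketch of the masked-loss bookkeeping above (toy tensors;
# `loss` plays the role of a hypothetical per-token NLL of shape (bs, seq),
# `loss_mask` marks which target positions count toward the loss).
import torch

loss = torch.rand(4, 7)
loss_mask = (torch.rand(4, 7) > 0.3).float()

length = loss_mask.sum(1).clamp(min=1)  # scored tokens per example
summed = (loss * loss_mask).sum(1)      # micro-style: summed token NLL
normalized = summed / length            # macro-style: per-token average
batch_loss = normalized.sum()           # what the optimizer would see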
def batch_atomic_generate(opt, nums, losses, batch_variables,
                          eval_mode=False, comet_loss=False):
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]

    batch, reset = data_loader.sample_batch(split, bs=opt.train.dynamic.bs)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    if comet_loss:
        # Only score the final object span: zero the loss mask for every
        # position before the last object in the sequence.
        len_before_lastobj = batch["len_before_lastobj"]
        for ii in range(loss_mask.size(0)):
            loss_mask[ii][:len_before_lastobj[ii] - 1] = 0

    # Shift inputs by one position to obtain next-token targets.
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # Set loss names
    micro_name = "total_micro"
    macro_name = "total_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(
        losses, nums, micro_name, macro_name, bs, length, final_loss, split)

    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    return outputs
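# A toy illustration of the comet_loss masking above: zero the loss mask
# for every position before the final object span, so only the effect
# tokens contribute to the loss (the lengths here are made up).
import torch

loss_mask = torch.ones(2, 8)
len_before_lastobj = torch.tensor([3, 5])
for i in range(loss_mask.size(0)):
    loss_mask[i][:len_before_lastobj[i] - 1] = 0
# Row 0 now scores positions 2..7; row 1 scores positions 4..7.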
def generate_sequence(self, batch, model, data_loader, start_idx, end_len):
    # start_idx = context_size_event + 1  (= max_e1 + max_r)
    # end_idx = context_size_effect - 1   (= max_e2)
    XMB = batch["sequences"][:, :start_idx]
    MMB = batch["attention_mask"][:, :start_idx]

    XMB = model_utils.prepare_position_embeddings(
        self.opt, data_loader.vocab_encoder, XMB.unsqueeze(-1))

    tokens = []
    beam_losses = []

    # Beam Search
    beam_lls, beam_toks, beam_seqs = None, None, None
    lm_probs = F.log_softmax(
        model(XMB.unsqueeze(1), sequence_mask=MMB), dim=-1)
    dist = lm_probs[:, -1, :].squeeze()
    beam_lls, beam_toks = dist.topk(self.opt.eval.bs)
    beam_losses.append(beam_lls)

    ended = (beam_toks == self.end_token).float()
    counts = (2 - ended)
    beam_toks = beam_toks.unsqueeze(1)
    beam_seqs = beam_toks.clone()

    # Replicate the context once per beam.
    XMB = XMB.repeat(self.opt.eval.bs, 1, 1)
    MMB = MMB.repeat(self.opt.eval.bs, 1)
    next_pos = XMB[:, -1:, 1] + 1
    next_x = torch.cat((beam_toks, next_pos), -1).unsqueeze(1)
    XMB = torch.cat((XMB, next_x), 1)
    MMB = torch.cat(
        [MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

    for _ in range(end_len):
        # Compute distribution for current beam
        lm_probs = F.log_softmax(
            model(XMB.unsqueeze(1), sequence_mask=MMB), dim=-1)
        dist = lm_probs[:, -1, :].squeeze()

        # Get hypothesis tokens for distribution
        hyp_beam_lls, hyp_beam_toks = dist.topk(self.opt.eval.bs)

        # Compute masks and expand beam
        expanded_ended = ended.unsqueeze(1).repeat(1, self.opt.eval.bs)
        hypothesis_mask = expanded_ended * self.kill_mask + (
            1 - expanded_ended)

        paper_results = False
        if paper_results:
            # Results from paper with slightly buggy beam search
            current_beam_lls = beam_lls.unsqueeze(1).repeat(
                1, self.opt.eval.bs).view(self.opt.eval.bs**2)
        else:
            # Current beam search implementation
            current_beam_lls = beam_losses[-1].unsqueeze(1).repeat(
                1, self.opt.eval.bs).view(self.opt.eval.bs**2)

        # Compute losses of hypotheses, masking those that have ended
        hyp_beam_lls = (hyp_beam_lls.view(self.opt.eval.bs**2) *
                        hypothesis_mask.view(-1)) + current_beam_lls

        # Get normalizer for sequences
        temp_counts = counts.unsqueeze(1).repeat(
            1, self.opt.eval.bs).view(self.opt.eval.bs**2)

        # Select best beams by length-normalized log-likelihood
        beam_lls, top_beam_idxs = (hyp_beam_lls / temp_counts).topk(
            self.opt.eval.bs)

        # Update placements in beam based on selection
        beam_losses = [i.index_select(0, top_beam_idxs // self.opt.eval.bs)
                       for i in beam_losses]
        ended = ended.index_select(0, top_beam_idxs // self.opt.eval.bs)
        counts = temp_counts.index_select(0, top_beam_idxs)

        # Save beam losses
        beam_losses.append(beam_lls * counts)

        # Update beam tokens: finished beams keep emitting the end token
        ended_mask = (1 - ended).long()
        end_replacement = (self.end_token * ended).long()
        next_toks = hyp_beam_toks.view(-1)[top_beam_idxs]
        beam_toks = next_toks * ended_mask + end_replacement

        # Update ended and counts
        ended = ended + (beam_toks == self.end_token).float() * (1 - ended)
        counts = counts + (1 - ended)

        # Update beam sequences
        beam_seqs = beam_seqs.t().repeat(
            self.opt.eval.bs, 1).t().contiguous().view(
                self.opt.eval.bs**2, -1)[top_beam_idxs]
        beam_seqs = torch.cat((beam_seqs, beam_toks.unsqueeze(1)), dim=1)

        # Expand each context bs times (interleaved) so every hypothesis
        # has its own copy, then keep the contexts of the selected beams.
        XMB = XMB.transpose(0, 1).transpose(1, 2).repeat(
            self.opt.eval.bs, 1, 1).transpose(2, 1).transpose(
                1, 0).contiguous().view(
                    self.opt.eval.bs**2,
                    XMB.size(1), XMB.size(2))[top_beam_idxs]

        XMB, MMB = self.append_batch(XMB, beam_toks, MMB)

        if (beam_toks == self.end_token).sum().item() == self.opt.eval.bs:
            break

    beams = []
    for beam in beam_seqs:
        beams.append(" ".join("".join(
            [data_loader.vocab_decoder[tok.item()].replace(
                '</w>', ' ').replace('\n', '')
             for tok in beam if tok != self.end_token]).split()))

    sampling_result = {
        "sequence": beams[0],
        "beams": beams,
        "beam_losses": beam_lls.tolist(),
        "loss": beam_lls[0].item(),
        "beam_lengths": counts.tolist(),
        "length": counts[0].item()
    }

    return sampling_result
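# A minimal, self-contained beam-search sketch of the same idea, reusing the
# hypothetical TinyLM stand-in from the getProb sketch above. This is an
# illustration, not the exact COMET variant, which length-normalizes in-loop
# and handles ended beams with a kill mask.
import torch
import torch.nn.functional as F

def beam_search(model, context_ids, beam_size, end_token, max_len):
    """Keep the `beam_size` best hypotheses by per-token log-likelihood."""
    logprobs = F.log_softmax(model(context_ids.unsqueeze(0))[0, -1], dim=-1)
    scores, toks = logprobs.topk(beam_size)  # seed the beams
    beams = [(s.item(), context_ids.tolist() + [t.item()])
             for s, t in zip(scores, toks)]
    for _ in range(max_len - 1):
        candidates = []
        for score, seq in beams:
            if seq[-1] == end_token:
                candidates.append((score, seq))  # finished beams persist
                continue
            lp = F.log_softmax(model(torch.tensor([seq]))[0, -1], dim=-1)
            vals, idxs = lp.topk(beam_size)      # expand each live beam
            for v, i in zip(vals, idxs):
                candidates.append((score + v.item(), seq + [i.item()]))
        # Rank by log-likelihood normalized by generated length.
        gen_len = lambda seq: len(seq) - context_ids.numel()
        candidates.sort(key=lambda c: c[0] / gen_len(c[1]), reverse=True)
        beams = candidates[:beam_size]
        if all(seq[-1] == end_token for _, seq in beams):
            break
    return beams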
def generate_sequence(self, batch, model, data_loader, start_idx, end_len):
    # start_idx = context_size_event + 1  (= max_e1 + max_r)
    # end_idx = context_size_effect - 1   (= max_e2)
    XMB = batch["sequences"][:, :start_idx]
    MMB = batch["attention_mask"][:, :start_idx]

    XMB = model_utils.prepare_position_embeddings(
        self.opt, data_loader.vocab_encoder, XMB.unsqueeze(-1))

    lm_probs = F.log_softmax(
        model(XMB.unsqueeze(1), sequence_mask=MMB), dim=-1)
    values, indices = lm_probs[:, -1, :].topk(self.opt.eval.k)
    seqs = indices.t().clone()

    losses = -values.view(-1, 1)

    ended = (seqs == self.end_token).float()
    counts = (1 - ended)

    # Replicate the context once per sampled continuation.
    XMB = XMB.repeat(self.opt.eval.k, 1, 1)
    MMB = MMB.repeat(self.opt.eval.k, 1)
    next_pos = XMB[:, -1:, 1] + 1
    next_x = torch.cat(
        (indices.view(self.opt.eval.k, -1), next_pos), -1).unsqueeze(1)
    XMB = torch.cat((XMB, next_x), 1)
    MMB = torch.cat(
        [MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

    # Sample from top k
    for _ in range(end_len):
        lm_probs = F.log_softmax(
            model(XMB.unsqueeze(1), sequence_mask=MMB), dim=-1)

        # Sample the next token from the top-k candidates.
        values, indices = lm_probs[:, -1, :].topk(self.opt.eval.k)
        choice = torch.multinomial(values.exp(), 1)
        next_idx = indices.gather(-1, choice)

        # Finished sequences keep emitting the end token.
        ended = ended + (next_idx == self.end_token).float() * (1 - ended)
        next_idx = next_idx * (
            1 - ended).long() + ended.long() * self.end_token
        counts += (1 - ended)

        seqs = torch.cat([seqs, next_idx], 1)

        if ended.sum().item() == self.opt.eval.k:
            break

        losses -= values.gather(-1, choice) * (1 - ended)

        XMB, MMB = self.append_batch(XMB, next_idx, MMB)

    beams = []
    for beam in seqs:
        beams.append(" ".join("".join(
            [data_loader.vocab_decoder[tok.item()].replace(
                '</w>', ' ').replace('\n', '')
             for tok in beam if tok != self.end_token]).split()))

    sampling_result = {
        "sequence": beams[0],
        "beams": beams,
        "beam_losses": losses.squeeze().tolist(),
        "loss": losses[0].item(),
        "beam_lengths": counts.long().squeeze().tolist(),
        "length": counts[0].long().item()
    }

    return sampling_result
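# The core top-k step in isolation: restrict to the k most likely tokens,
# renormalize over them, and draw one (a sketch; `logits` is a hypothetical
# (vocab,) tensor of raw model scores).
import torch
import torch.nn.functional as F

def sample_top_k(logits, k):
    values, indices = logits.topk(k)
    probs = F.softmax(values, dim=-1)     # renormalize over the top k
    choice = torch.multinomial(probs, 1)  # sample within the top k
    return indices[choice].item()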
batch, reset = data_loader.sample_batch(split=split, bs=1, idxs=[idx])

# Split the sequence into context (e1 + relation) and reference (e2).
XMB = batch["sequences"][:, :data_loader.max_e1 + data_loader.max_r]
Ref = batch["sequences"][:, data_loader.max_e1 + data_loader.max_r:]
MMB = batch["attention_mask"][:, :data_loader.max_e1 + data_loader.max_r]

# Decode the subject entity (e1) and the relation for display.
init = "".join([text_encoder.decoder[i].replace('</w>', ' ').replace(
    "<blank>", "___ ")
    for i in XMB[:, :data_loader.max_e1].squeeze().tolist() if i])

start = data_loader.max_e1
end = data_loader.max_e1 + data_loader.max_r

attr = "".join([text_encoder.decoder[i].replace('</w>', ' ')
                for i in XMB[:, start:end].squeeze(0).tolist() if i]).strip()

XMB = model_utils.prepare_position_embeddings(
    opt, text_encoder.encoder, XMB.unsqueeze(-1))

sequence_all["e1"] = init
sequence_all["r"] = attr
sequence_all["key"] = batch["key"]

tokens = []
beam_losses = []

# Beam Search: seed the beams with the top `args.beam` first tokens.
beam_lls, beam_toks, beam_seqs = None, None, None
lm_probs = F.log_softmax(
    lm_model(XMB.unsqueeze(1), sequence_mask=MMB), dim=-1)
dist = lm_probs[:, -1, :].squeeze()
beam_lls, beam_toks = dist.topk(args.beam)
beam_losses.append(beam_lls)
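# A toy illustration of the '</w>' detokenization used throughout: BPE
# sub-tokens mark word ends with '</w>', so replacing it with a space and
# re-splitting recovers the surface string (toy vocabulary below).
decoder = {7: "per", 8: "sonx</w>", 9: "goes</w>", 10: "home</w>"}
toks = [7, 8, 9, 10]
text = " ".join("".join(
    decoder[t].replace("</w>", " ") for t in toks).split())
assert text == "personx goes home"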