def embed_body(self, dcmt, target_embeded):
    dcmts, ndocs, nsents, dcmt_nwords = dcmt
    doc_mask = sequence_mask(ndocs, device=self.config.gpu)
    doc_sent_mask = sequence_mask(nsents, device=self.config.gpu)
    doc_word_mask = sequence_mask(dcmt_nwords, device=self.config.gpu)
    batch_size, max_num_doc, max_num_sent, max_seqlen = dcmts.size()
    dcmts_reshaped = dcmts.view(-1, max_seqlen)
    # (batch_size * max_num_doc * max_num_sent, max_seqlen, embed_dim)
    dcmts_embeded = self.wembed(dcmts_reshaped)
    # (batch_size * max_num_doc * max_num_sent, max_seqlen, hidden_dim * 2)
    hiddens, _ = self.w2s(dcmts_embeded)
    # attend over the word hiddens with the target as query
    target_embeded_expand = target_embeded.repeat(max_num_sent, 1, 1)
    sent_reprs, _ = attention(target_embeded_expand, hiddens, hiddens,
                              mask=doc_word_mask.view(-1, 1, doc_word_mask.size(-1)),
                              dropout=self.dropout, scale=True)
    # (batch_size * max_num_doc, max_num_sent, hidden_dim * 2)
    sent_reprs = sent_reprs.view(-1, max_num_sent, sent_reprs.size(-1))
    doc_sent_mask = doc_sent_mask.view(-1, doc_sent_mask.size(-1), 1)
    return sent_reprs, doc_sent_mask, doc_mask
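# All snippets in this section depend on a `sequence_mask` helper that turns
# per-example lengths into a padding mask. The exact signature differs between
# the source repos (device=, pad=, and dtype= keyword variants all appear
# above); this is a minimal sketch under those assumptions, not any repo's
# actual implementation:

import torch

def sequence_mask(lengths, max_len=None, device=None, dtype=torch.bool):
    """mask[..., j] is True iff j < lengths[...] (broadcast over leading dims)."""
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(
        max_len, device=device if device is not None else lengths.device)
    # (max_len,) < (..., 1) broadcasts to (..., max_len)
    return (positions < lengths.unsqueeze(-1)).to(dtype)

# Broadcasting over the leading dimensions is what lets snippets like the one
# above mask (batch, doc, sent)-shaped word counts in a single call.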
def embed_doc(self, target, lead, dcmt):
    tgt_hiddens, tgt_output, tgt_nword, tgt_mask = \
        self.embed_sent(target, self.t2v)
    # (batch_size, max_num_sent, max_num_word)
    text, ndoc, nword = dcmt
    max_num_word = text.size(-1)
    # (batch_size, max_num_sent, 1)
    doc_mask = sequence_mask(ndoc, device=self.config.gpu).unsqueeze(-1)
    # (batch_size * max_num_sent, max_num_word, 1)
    word_mask = sequence_mask(nword, device=self.config.gpu).view(
        -1, max_num_word, 1)
    # (batch_size * max_num_sent, max_num_word, embed_dim)
    text_embed = self.embed(text.view(-1, max_num_word))
    # (batch_size * max_num_sent, output_dim)
    _, sent_repr = self.w2s(text_embed, mask=word_mask, init=tgt_output[1])
    sent_repr = self.dropout(sent_repr)
    sent_repr = sent_repr.view(*text.size()[:2], -1)
    doc_hiddens, doc_repr = self.s2d(sent_repr, mask=doc_mask, init=tgt_output[1])
    doc_repr = self.dropout(doc_repr)
    return doc_repr, doc_hiddens
def embed_body(self, target, dcmt):
    dcmts, ndocs, nsents, dcmt_nwords = dcmt
    doc_mask = sequence_mask(ndocs, device=self.config.gpu)
    doc_sent_mask = sequence_mask(nsents, device=self.config.gpu)
    doc_word_mask = sequence_mask(dcmt_nwords, device=self.config.gpu)
    batch_size, max_num_doc, max_num_sent, max_seqlen = dcmts.size()
    dcmts_reshaped = dcmts.view(-1, max_seqlen)
    # (batch_size * max_num_doc, embed_dim)
    target_embed = self.embed_target(target)
    # (batch_size * max_num_doc * max_num_sent, max_seqlen, embed_dim)
    dcmts_embeded = self.wembed(dcmts_reshaped)
    # (batch_size * max_num_doc * max_num_sent, hidden_dim * 2)
    _, sent_reprs = self.w2s(dcmts_embeded, init=target_embed,
                             mask=doc_word_mask.view(-1, doc_word_mask.size(-1), 1))
    sent_reprs = self.dropout(sent_reprs)
    # (batch_size * max_num_doc, max_num_sent, hidden_dim * 2)
    sent_reprs = sent_reprs.view(-1, max_num_sent, sent_reprs.size(-1))
    doc_sent_mask = doc_sent_mask.view(-1, doc_sent_mask.size(-1), 1)
    # (batch_size * max_num_doc, hidden_dim * 2)
    _, doc_reprs = self.s2d(sent_reprs, mask=doc_sent_mask)
    doc_reprs = doc_reprs.view(batch_size, max_num_doc, -1)
    doc_reprs = self.dropout(doc_reprs)
    return doc_reprs, doc_mask
def embed_body(self, tgt_output, dcmt):
    # (batch_size, max_num_sent, max_num_word)
    text, ndoc, nword = dcmt
    max_num_word = text.size(-1)
    # (batch_size, max_num_sent, 1)
    sent_mask = sequence_mask(ndoc, device=self.config.gpu).unsqueeze(-1)
    # (batch_size * max_num_sent, max_num_word, 1)
    word_mask = sequence_mask(nword, device=self.config.gpu).view(
        -1, max_num_word, 1)
    # (batch_size * max_num_sent, max_num_word, embed_dim)
    text_embed = self.embed(text.view(-1, max_num_word))
    # (batch_size * max_num_sent, output_dim)
    _, sent_repr = self.w2s(text_embed, mask=word_mask, init=tgt_output[1])
    sent_repr = self.dropout(sent_repr)
    sent_repr = sent_repr.view(*text.size()[:2], -1)
    ctx_sent_repr, _ = self.s2d(sent_repr, mask=sent_mask)
    ctx_sent_repr = self.dropout(ctx_sent_repr)
    return ctx_sent_repr, sent_mask
def forward(self, x, length):
    """Maps input to last hidden state, to pooler_output, to prediction.

    Args:
        x (torch.LongTensor): input of shape (batch_size, seq_length)
        length (torch.LongTensor): input of shape (batch_size,)

    Returns:
        x (torch.FloatTensor): logits of shape (batch_size, NUM_CLASS)
    """
    # (batch_size, max_length)
    x_mask = sequence_mask(length, pad=0, dtype=torch.float)
    # TODO: clean this hack
    try:
        # bert, roberta
        _, cls, hidden_states = self.model(x, attention_mask=x_mask)
        # hidden_states: length-13 tuple of tensors (batch_size, max_length, hidden_size)
        if len(self.layer) == 1:
            x = self.time_pooling(cls, hidden_states[self.layer[0]], length)
        else:
            x = self.layer_pooling([
                self.time_pooling(cls, hidden_states[layer], length)
                for layer in self.layer
            ])
    except ValueError:
        # xlm, xlnet: no pooler output, so the 3-way unpack above fails
        x = self.model(x, attention_mask=x_mask)
        # (batch_size, seq_length, hidden_size)
        x = x[0]
    x = self.out(x)  # (batch_size, NUM_CLASS)
    return x
def embed_sent(self, sent, encoder, h0=None):
    snt, snt_nword = sent
    word_mask = sequence_mask(snt_nword, device=self.config.gpu).unsqueeze(-1)
    snt_embed = self.embed(snt)
    hiddens, output = encoder(snt_embed, init=h0, mask=word_mask)
    return hiddens, output, snt_nword, word_mask
def forward(self, x, target, length):
    """
    Args:
        x: A FloatTensor of size (batch, max_len, dim) which contains
            the predicted values for each step.
        target: A FloatTensor of size (batch, max_len, dim) which contains
            the target values for each corresponding step.
        length: A LongTensor of size (batch,) which contains the length
            of each sequence in the batch.

    Returns:
        loss: An average loss value masked by the length.
    """
    # mask: (batch, max_len, 1)
    target.requires_grad = False
    mask = sequence_mask(sequence_length=length,
                         max_len=target.size(1)).unsqueeze(2).float()
    if self.seq_len_norm:
        norm_w = mask / mask.sum(dim=1, keepdim=True)
        out_weights = norm_w.div(target.shape[0] * target.shape[2])
        mask = mask.expand_as(x)
        # masked positions contribute (0 - 0)^2 = 0 to the loss
        loss = functional.mse_loss(x * mask, target * mask, reduction='none')
        loss = loss.mul(out_weights.to(loss.device)).sum()
    else:
        mask = mask.expand_as(x)
        loss = functional.mse_loss(x * mask, target * mask, reduction='sum')
        loss = loss / mask.sum()
    return loss
def select_m_actions(self, probs, lengths, sql_labels):
    batch_size = probs.size(0)
    max_len = probs.size(1)
    probs = probs.transpose(0, 1).contiguous()
    if self.args.model == 'gate':
        probs = F.softmax(probs, dim=-1)
    m_log_probs = torch.FloatTensor(batch_size, self.args.m).to(device=self.args.device)
    m_rewards = torch.FloatTensor(batch_size, self.args.m).to(device=self.args.device)
    for index in range(self.args.m):
        # actions are sampled out to max_len, then masked back to each true length
        actions = torch.LongTensor(batch_size, max_len).to(device=self.args.device)
        log_probs = torch.FloatTensor(batch_size, max_len).to(device=self.args.device)
        for ti in range(max_len):
            actions[:, ti], log_probs[:, ti] = self.select_action_step(probs[ti])
        # mask out padded positions
        mask = sequence_mask(lengths, max_len).to(device=self.args.device)
        actions.data.masked_fill_(~mask, -100)
        log_probs.data.masked_fill_(~mask, 0.0)
        # compute rewards
        rewards = self.compute_rewards(actions, sql_labels, mode='rewards')
        m_log_probs[:, index] = torch.sum(log_probs, dim=1)
        m_rewards[:, index] = rewards
    # baseline: subtract the per-example mean reward over the m samples
    m_rewards -= m_rewards.mean(dim=-1).view(batch_size, 1)
    return m_log_probs, m_rewards
def _feature(self, inputs, lengths, tags_one_hot=None):
    """Sort the batch by length, run the BiLSTM, and compute CRF emission scores."""
    w_lengths, word_sort_ind = lengths.sort(dim=0, descending=True)
    # gather inputs in the sorted order
    inputs = inputs[word_sort_ind].to(device)
    if tags_one_hot is not None:
        tags_one_hot = tags_one_hot[word_sort_ind].byte().to(device)
    # compute features
    inputs_emb = self.embeddings(inputs)
    w = self.dropout(inputs_emb)
    # pack padded sequence: word_emb_dim + 2 * char_rnn_dim, with real sequence lengths
    w = torch.nn.utils.rnn.pack_padded_sequence(w, list(w_lengths), batch_first=True)
    # LSTM: packed sequence of word_rnn_dim, with real sequence lengths
    w, _ = self.BiLSTM(w)
    # unpack: (batch_size, max_word_len_in_batch, word_rnn_dim)
    w, _ = torch.nn.utils.rnn.pad_packed_sequence(w, batch_first=True)
    w = self.dropout(w)
    mask = sequence_mask(w_lengths).float()
    crf_scores = self.crf_layer(w)
    return crf_scores, tags_one_hot, mask, w_lengths, word_sort_ind
def _qe_masking(qe):
    """Zero out invalid relative positions in the (query, relative-key) logits."""
    mask = utils.sequence_mask(
        torch.arange(qe.size()[-1] - 1,
                     qe.size()[-1] - qe.size()[-2] - 1, -1).to(qe.device),
        qe.size()[-1])
    mask = ~mask
    return mask.to(qe.dtype) * qe
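# A quick sanity check of the staircase mask `_qe_masking` produces for the
# relative-position logits, assuming `utils.sequence_mask` behaves like the
# sketch near the top of this section:

qe = torch.ones(3, 4)      # (num_queries, num_relative_positions)
print(_qe_masking(qe))
# tensor([[0., 0., 0., 1.],
#         [0., 0., 1., 1.],
#         [0., 1., 1., 1.]])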
def _test_step(self, batch):
    question_inds = batch["question_inds"]
    seq_length = batch["seq_length"]
    image_feat = batch["image_feat"]
    question_mask = sequence_mask(seq_length)
    outputs = self.offline_model(question_inds, question_mask, image_feat)
    return outputs
def train_encoder(mdl, crit, optim, sch, stat):
    """Train REL or EXT model."""
    logger.info(f'*** Epoch {stat.epoch} ***')
    mdl.train()
    it = DataLoader(load_dataset(args.dir_data, 'train'), args.model_type,
                    args.batch_size, args.max_ntokens_src,
                    spt_ids_B, spt_ids_C, eos_mapping)
    for batch in it:
        # clear gradients accumulated from the previous step
        optim.zero_grad()
        _, logits = mdl(batch)
        mask_inp = utils.sequence_mask(batch.src_lens, batch.inp.size(1))
        loss = crit(logits, batch.tgt, mask_inp)
        loss.backward()
        stat.update(loss, 'train', args.model_type, logits=logits, labels=batch.tgt)
        torch.nn.utils.clip_grad_norm_(mdl.parameters(), 5)
        optim.step()
        if stat.steps == 0:
            continue
        if stat.steps % args.log_interval == 0:
            stat.lr = optim.param_groups[0]['lr']
            stat.report()
            sch.step(stat.avg_train_loss)
        if stat.steps % args.valid_interval == 0:
            valid_ret(mdl, crit, optim, stat)
def embed_body(self, dcmt):
    # (batch_size, max_num_sent, max_num_word)
    text, ndoc, nword = dcmt
    max_num_word = text.size(-1)
    # (batch_size, max_num_sent, 1)
    sent_mask = sequence_mask(ndoc, device=self.config.gpu).unsqueeze(-1)
    # (batch_size * max_num_sent, max_num_word, 1)
    word_mask = sequence_mask(nword, device=self.config.gpu).view(
        -1, max_num_word, 1)
    # (batch_size * max_num_sent, max_num_word, embed_dim)
    text_embed = self.embed(text.view(-1, max_num_word))
    # (batch_size * max_num_sent, max_num_word, output_dim)
    sent_hiddens, _ = self.w2s(text_embed, mask=word_mask)
    return sent_hiddens, sent_mask, word_mask
def embed_lead(self, leads):
    leads, nleads, lead_nwords = leads
    max_num_lead = lead_nwords.size(-1)
    # (batch_size, max_num_lead)
    lead_mask = sequence_mask(nleads, device=self.config.gpu)
    # (batch_size, max_num_lead, max_seqlen)
    lead_word_mask = sequence_mask(lead_nwords, device=self.config.gpu)
    lead_word_mask = lead_word_mask.view(-1, lead_word_mask.size(-1), 1)
    leads = leads.view(-1, leads.size(-1))
    # (batch_size * max_num_lead, max_seqlen, embed_dim)
    leads_embeded = self.wembed(leads)
    # (batch_size * max_num_lead, max_seqlen, hidden_dim)
    lead_hiddens, _ = self.w2s(leads_embeded, mask=lead_word_mask)
    return lead_hiddens, lead_word_mask, max_num_lead
def train_emb(self, images, captions, lengths, ids=None, target_align=None,
              lengths_whole=None, epoch=None, *args):
    """One training step given images and captions."""
    self.Eiters += 1
    self.logger.update('Eit', self.Eiters)
    self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
    lengths = torch.Tensor(lengths).long()
    lengths_whole = torch.Tensor(lengths_whole).long()
    if torch.cuda.is_available():
        lengths = lengths.cuda()
        lengths_whole = lengths_whole.cuda()
    lengths = lengths_whole

    # compute the embeddings
    img_emb, cap_span_features, left_span_features, right_span_features, \
        word_embs, tree_indices, probs, span_bounds = self.forward_emb(
            images, captions, lengths, target_align, lengths_whole)

    # measure accuracy and record loss
    cum_reward, matching_loss = self.forward_reward(
        img_emb, cap_span_features, left_span_features, right_span_features,
        word_embs, lengths, span_bounds, lengths_whole)
    probs = torch.cat(probs, dim=0).reshape(-1, lengths.size(0)).transpose(0, 1)
    masks = sequence_mask(lengths - 1, lengths.max(0)[0] - 1).float()
    rl_loss = torch.sum(-masks * torch.log(probs) * cum_reward.detach())
    loss = rl_loss + matching_loss * self.vse_loss_alpha
    loss = loss / cum_reward.shape[0]
    self.logger.update('Loss', float(loss), img_emb.size(0))
    self.logger.update('MatchLoss',
                       float(matching_loss / cum_reward.shape[0]), img_emb.size(0))
    self.logger.update('RL-Loss',
                       float(rl_loss / cum_reward.shape[0]), img_emb.size(0))

    # compute gradient and do SGD step
    self.optimizer.zero_grad()
    loss.backward()
    if self.grad_clip > 0:
        clip_grad_norm_(self.params, self.grad_clip)
    self.optimizer.step()

    # clean up
    if epoch is not None and epoch > 0:
        del cum_reward, tree_indices, probs, cap_span_features, span_bounds
def select_max_action(self, probs, lengths, sql_labels):
    batch_size, max_len = probs.size(0), probs.size(1)
    actions = torch.max(probs, 2)[1]
    # mask out padded positions
    mask = sequence_mask(lengths, max_len).to(device=self.args.device)
    actions.data.masked_fill_(~mask, -100)
    # compute acc
    b_error_1, b_error_2, b_error_3, b_error_4, rewards = self.compute_rewards(
        actions, sql_labels, mode='acc')
    return actions, rewards, b_error_1, b_error_2, b_error_3, b_error_4
def forward(
    self,
    decoder_hidden: torch.Tensor,
    encoder_hidden: torch.Tensor,
    encoder_lengths: torch.Tensor,
):
    """
    Args:
        decoder_hidden (torch.Tensor): Query vector ``(batch, hidden_dim)``.
        encoder_hidden (torch.Tensor): Sequence of sources ``(batch, src_len, hidden_dim)``.
        encoder_lengths (torch.Tensor): The source sequence lengths ``(batch,)``.

    Returns:
        attn_h (torch.Tensor): The attentional hidden state ``(batch, hidden_dim)``.
    """
    tgt_batch, tgt_dim = decoder_hidden.shape
    src_batch, src_len, src_dim = encoder_hidden.shape
    assert src_batch == tgt_batch
    assert src_dim == tgt_dim
    # align_scores: (batch, src_len)
    align_scores = self.score(encoder_hidden, decoder_hidden)
    if encoder_lengths is not None:
        mask = sequence_mask(encoder_lengths, max_len=align_scores.shape[1])
        align_scores.masked_fill_(~mask, -float("inf"))
    # align_vector: (batch, src_len)
    align_vector = F.softmax(align_scores, dim=1)
    # (batch, 1, src_len) x (batch, src_len, hidden_dim) -> (batch, 1, hidden_dim)
    # context_vector: (batch, hidden_dim)
    context_vector = torch.bmm(
        align_vector.unsqueeze(1), encoder_hidden).squeeze(1)
    # concat_c_h: (batch, 2 * hidden_dim)
    concat_c_h = torch.cat([context_vector, decoder_hidden], dim=1)
    # attentional hidden state: (batch, hidden_dim)
    attn_h = torch.tanh(self.w_c(concat_c_h))
    attn_h = self.dropout(attn_h)
    return attn_h
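# The -inf fill before the softmax is what guarantees exactly zero attention
# weight on padding. A tiny check using the `sequence_mask` sketch above; note
# that `1 - mask` on a bool tensor raises in modern PyTorch, which is why the
# code above uses `~mask`:

scores = torch.tensor([[1.0, 2.0, 0.5]])
pad_mask = sequence_mask(torch.tensor([2]), max_len=3)   # [[True, True, False]]
scores = scores.masked_fill(~pad_mask, -float("inf"))
weights = F.softmax(scores, dim=1)                       # padded position gets weight 0.0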
def _tighten(self, hy, y):
    """Pad tokens after EOS and mask hiddens after EOS.

    hy: (B, MAXLEN+1, 700)
    y: (B, MAXLEN+1)
    """
    lengths = get_actual_lengths(y)
    mask = sequence_mask(lengths)
    y = y[:, :mask.size(1)]    # truncate the unnecessarily generated part
    hy = hy[:, :mask.size(1)]
    y.masked_fill_(mask != 1, PAD_IDX)  # this does not backprop
    hy = hy * mask.unsqueeze(-1).float()
    hy, y, lengths = sort_by_length(hy, y, lengths)
    return hy, y, lengths
def masked_cross_entropy(logits, target, length, per_example=False, decode=False):
    """
    Args:
        logits (FloatTensor): [batch, max_len, num_classes] - unnormalized
            score for each class
        target (LongTensor): [batch, max_len] - index of the true class for
            each corresponding step
        length (LongTensor): [batch] - length of each sequence in the batch

    Returns:
        If per_example, the per-example loss [batch_size]; otherwise the
        summed loss together with the total number of real target tokens.
    """
    # `decode` is accepted but unused here
    batch_size, max_len, num_classes = logits.size()
    # [batch_size * max_len, num_classes]
    logits_flat = logits.view(-1, num_classes)
    log_probs_flat = F.log_softmax(logits_flat, dim=1)
    # [batch_size * max_len, 1]
    target_flat = target.view(-1, 1)
    # negative log-likelihood: -sum { 1*log P(target) + 0*log P(non-target) }
    #                        = -sum( log P(target) )
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # [batch_size, max_len]
    losses = losses_flat.view(batch_size, max_len)
    mask = sequence_mask(sequence_length=length, max_len=max_len)
    # zero out the loss on padded positions
    losses = losses * mask.float()
    if per_example:
        # loss: [batch_size]
        return losses.sum(1)
    else:
        loss = losses.sum()
        return loss, length.float().sum()
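# Assumed toy usage of masked_cross_entropy: the non-per-example branch returns
# the summed loss plus the token count, so the caller does the averaging.

logits = torch.randn(2, 5, 10)            # (batch, max_len, num_classes)
target = torch.randint(0, 10, (2, 5))     # (batch, max_len)
length = torch.tensor([5, 3])
loss_sum, n_tokens = masked_cross_entropy(logits, target, length)
mean_loss = loss_sum / n_tokens           # average over real (unpadded) tokens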
def embed_target(self, target):
    # target: (batch_size, seqlen); target_nwords: (batch_size,)
    target, target_nwords = target
    # (batch_size, seqlen)
    target_word_mask = sequence_mask(target_nwords, device=self.config.gpu)
    # (batch_size, seqlen, embed_dim)
    target_embed = self.tembed(target)
    target_word_mask = target_word_mask.unsqueeze(-1)
    # masked mean over the target words
    target_embed = torch.sum(target_embed * target_word_mask, dim=1) \
        / torch.sum(target_word_mask, dim=1)
    return target_embed
def forward(self, padded_input, input_lengths):
    """
    Args:
        padded_input: N x T x D
        input_lengths: N

    Returns:
        alphas: N x T gating weights, zeroed on padded frames
    """
    x, input_lengths = self.conv(padded_input, input_lengths)
    x = self.dropout(x)
    alphas = self.linear(x).squeeze(-1)
    alphas = torch.sigmoid(alphas)
    pad_mask = sequence_mask(input_lengths)
    return alphas * pad_mask
def initialize(self, memory_bank, src_lengths, field_signals, device=None):
    """Initialize search state for each batch input."""
    # repeat state in beam_size
    def fn_map_state(state, dim):
        return tile(state, self.beam_size, dim=dim)

    src_max_len = memory_bank.size(1)
    memory_bank = tile(memory_bank, self.beam_size)
    memory_pad_mask = tile(~sequence_mask(src_lengths, src_max_len), self.beam_size)
    self.memory_lengths = tile(src_lengths, self.beam_size)
    mb_device = memory_bank.device
    if device is None:
        self.device = mb_device
    self.field_signals = field_signals
    self.alive_seq = field_signals.repeat_interleave(self.beam_size) \
        .unsqueeze(-1).to(self.device)
    self.is_finished = torch.zeros([self.batch_size, self.beam_size],
                                   dtype=torch.uint8, device=self.device)
    self.best_scores = torch.full([self.batch_size], -1e10,
                                  dtype=torch.float, device=self.device)
    self._beam_offset = torch.arange(0, self.batch_size * self.beam_size,
                                     step=self.beam_size, dtype=torch.long,
                                     device=self.device)
    # give full probability to the first beam on the first step; with no
    # prior information, choose any (the first beam)
    self.topk_log_probs = torch.tensor(
        [0.0] + [float("-inf")] * (self.beam_size - 1),
        device=self.device).repeat(self.batch_size)
    # buffers for the topk scores and 'backpointer'
    self.topk_scores = torch.empty((self.batch_size, self.beam_size),
                                   dtype=torch.float, device=self.device)
    self.topk_ids = torch.empty((self.batch_size, self.beam_size),
                                dtype=torch.long, device=self.device)
    self._batch_index = torch.empty([self.batch_size, self.beam_size],
                                    dtype=torch.long, device=self.device)
    return fn_map_state, memory_bank, memory_pad_mask
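# `initialize` leans on a `tile` helper that replicates each batch entry
# beam_size times, keeping all hypotheses for one example contiguous. A minimal
# stand-in (an assumption, not necessarily the repo's implementation):

def tile(x, count, dim=0):
    """(batch, ...) -> (batch * count, ...) with each entry repeated `count` times."""
    return x.repeat_interleave(count, dim=dim)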
def training_step(self, batch, batch_idx, use_sharpen=True):
    question_inds = batch["question_inds"]
    seq_length = batch["seq_length"]
    image_feat = batch["image_feat"]
    answer_idx = batch.get("answer_idx", None)
    gt_layout = batch.get("layout_inds", None)
    bbox_ind = batch.get("bbox_ind", None)
    bbox_gt = batch.get("bbox_batch", None)
    bbox_offset = batch.get("bbox_offset", None)
    question_mask = sequence_mask(seq_length)
    outputs = self.online_model(question_inds, question_mask, image_feat)
    loss = torch.tensor(0.0, device=self.device, dtype=torch.float)
    # we support training on VQA only, loc only, or both, depending on these flags
    if self.cfg.MODEL.BUILD_VQA and answer_idx is not None:
        loss += self.vqa_loss(outputs["logits"], answer_idx)
        self.train_acc(F.softmax(outputs["logits"], dim=1), answer_idx)
        self.log("train/vqa_acc", self.train_acc)
    if self.cfg.MODEL.BUILD_LOC and bbox_ind is not None:
        loss += self.loc_loss(
            outputs["loc_scores"], outputs["bbox_offset_fcn"], bbox_ind, bbox_offset
        )
        feat_h, feat_w, _, _, stride_h, stride_w = self.img_sizes
        bbox_pred = batch_feat_grid2bbox(
            torch.argmax(outputs["loc_scores"], 1),
            outputs["bbox_offset"],
            stride_h, stride_w, feat_h, feat_w,
        )
        accuracy = torch.mean(
            (batch_bbox_iou(bbox_pred, bbox_gt)
             >= self.cfg.TRAIN.BBOX_IOU_THRESH).float()
        )
        self.log("train/loc_acc", accuracy)
    if self.cfg.TRAIN.USE_SHARPEN_LOSS and use_sharpen:
        loss += self.sharpen_loss(outputs["module_logits"])
    if self.cfg.TRAIN.USE_GT_LAYOUT:
        loss += self.gt_loss(outputs["module_logits"], gt_layout)
    self.log("train/loss", loss)
    # technically this means the offline model is one step behind, but that's fine
    accumulate(self.offline_model, self.online_model)
    return loss
def __init__(self, batch, model_type, device='cuda'):
    self.batch_size = len(batch)
    pad_ = partial(pad_sequence, batch_first=True)
    self.inp = pad_([torch.tensor(x[0]) for x in batch]).to(device)
    # length = position of the first pad (0) token, or the full width if none
    lens = [next((i for i, v in enumerate(s) if v == 0), len(s)) for s in self.inp]
    self.src_lens = torch.LongTensor(lens).to(device)
    self.mask_inp = sequence_mask(self.src_lens, self.inp.size(1))
    self.segs = pad_([torch.tensor(x[1]) for x in batch]).to(device)
    if model_type == 'rel':
        self.tgt = torch.tensor([x[2] for x in batch]).to(device)
    elif model_type in ['ext', 'abs']:
        self.tgt = pad_([torch.tensor(x[2]) for x in batch]).to(device)
    self.qid = [x[3] for x in batch]
    self.did = [x[4] for x in batch]
def select_action(self, probs, lengths, sql_labels):
    batch_size = probs.size(0)
    # actions are sampled out to max_len, then masked back to each true length
    max_len = probs.size(1)
    actions = torch.LongTensor(batch_size, max_len).to(device=self.args.device)
    log_probs = torch.FloatTensor(batch_size, max_len).to(device=self.args.device)
    probs = probs.transpose(0, 1).contiguous()
    for ti in range(max_len):
        actions[:, ti], log_probs[:, ti] = self.select_action_step(probs[ti])
    # mask out padded positions
    mask = sequence_mask(lengths, max_len).to(device=self.args.device)
    actions.data.masked_fill_(~mask, -100)
    log_probs.data.masked_fill_(~mask, 0.0)
    # compute rewards; (batch_size)
    rewards = self.compute_rewards(actions, sql_labels, mode='rewards')
    return torch.sum(log_probs, dim=1), rewards
def forward(self, x: Tensor, x_lens: Tensor = None):
    """
    Args:
        x : input of shape `(batch_sz, seq_len, n_features)`
        x_lens : lengths of x of shape `(batch_sz,)`
    """
    x_proj = torch.tanh(self.proj(x))
    x_queries_sim = self.queries(x_proj)
    if x_lens is not None:
        masks = sequence_mask(x_lens).unsqueeze(-1)
        # attn_w: (batch_sz, seq_len, n_head)
        attn_w = softmax_with_mask(x_queries_sim,
                                   masks.expand_as(x_queries_sim), dim=1)
    else:
        attn_w = F.softmax(x_queries_sim, dim=1)
    # x_attended: (batch_sz, n_head, n_features)
    x_attended = attn_w.transpose(2, 1) @ x
    self.attn_w = attn_w
    return self.pool(x_attended), attn_w
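# `softmax_with_mask` is not shown in this section; a plausible implementation
# (an assumption about the helper, not the repo's actual code):

def softmax_with_mask(x, mask, dim=-1):
    # masked-out positions get -inf, hence exactly zero weight after softmax
    x = x.masked_fill(~mask.bool(), -float("inf"))
    return F.softmax(x, dim=dim)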
def loglik_ordinal(batch_data, list_type, theta, normalization_params):
    output = dict()
    epsilon = 1e-6

    # data outputs
    data, missing_mask = batch_data
    missing_mask = missing_mask.float()
    batch_size = data.size()[0]

    # we need to force that the outputs of the network increase with the categories
    partition_param, mean_param = theta
    mean_value = torch.reshape(mean_param, [-1, 1])
    theta_values = torch.cumsum(
        torch.clamp(nn.Softplus()(partition_param), epsilon, 1e20), 1)
    sigmoid_est_mean = nn.Sigmoid()(theta_values - mean_value)
    mean_probs = torch.cat(
        [sigmoid_est_mean, torch.ones([batch_size, 1]).float()], 1) - torch.cat(
        [torch.zeros([batch_size, 1]).float(), sigmoid_est_mean], 1)
    mean_probs = torch.clamp(mean_probs, epsilon, 1.0)

    # thermometer-coded data -> category index -> one-hot labels
    true_values = one_hot(torch.sum(data.int(), 1) - 1, int(list_type['dim']))

    # per-example log-likelihood; mirrors the TF original
    # -softmax_cross_entropy_with_logits_v2(logits=log(mean_probs), labels=true_values)
    log_p_x = torch.sum(
        true_values * torch.log_softmax(torch.log(mean_probs), dim=1), dim=1)

    output['log_p_x'] = torch.mul(log_p_x, missing_mask)
    output['log_p_x_missing'] = torch.mul(log_p_x, 1.0 - missing_mask)
    output['params'] = mean_probs
    # draw category samples and thermometer-encode them via sequence_mask
    output['samples'] = sequence_mask(
        1 + td.Categorical(
            logits=torch.log(torch.clamp(mean_probs, epsilon, 1e20))).sample(),
        int(list_type['dim']), dtype=torch.float32)
    return output
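# Why `mean_probs` above is a valid distribution: the snippet implements a
# cumulative-link ordinal model, P(x <= k) = sigmoid(theta_k - mean), and takes
# adjacent differences of the CDF. A self-contained check (names illustrative):

thresholds = torch.nn.functional.softplus(torch.randn(1, 4)).cumsum(1)  # increasing
mean = torch.randn(1, 1)
cdf = torch.sigmoid(thresholds - mean)                  # P(x <= k), increasing in k
probs = torch.cat([cdf, torch.ones(1, 1)], 1) - torch.cat([torch.zeros(1, 1), cdf], 1)
assert (probs >= 0).all() and torch.allclose(probs.sum(1), torch.ones(1))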
def eval_data(dataset, process=0):
    all_result = []
    all_loss = []
    process = min(process, len(dataset))
    model.eval()
    for number, [length, traj, index] in enumerate(dataset):
        traj = traj.transpose(0, 1)
        fake_input = cuda(torch.zeros((traj.shape[0], traj.shape[1], 0)).float())
        result = model(traj, length, fake_input)
        raw_output = model.get_result(traj, length).cpu().detach()
        # restore the original sample order
        output = raw_output.clone()
        for num in range(len(raw_output)):
            output[index[num]] = raw_output[num]
        all_result.append(output)
        mask = sequence_mask(length, args.max_length).transpose(0, 1)
        eval_loss = loss(result, traj, dim=2) * mask
        eval_loss = eval_loss.sum(dim=0) / length.float()
        all_loss.append(eval_loss.cpu().detach())
        if process > 0 and number % (len(dataset) // process) == 0:
            print('encoding %d / %d' % (number, len(dataset)))
    all_result = torch.cat(all_result)
    all_loss = torch.cat(all_loss)
    return all_result, all_loss.mean().item()
def forward(self, x, target, length):
    """
    Args:
        x: A FloatTensor of size (batch, max_len) which contains the
            unnormalized score (logit) for each step.
        target: A FloatTensor of size (batch, max_len) which contains the
            binary label for each corresponding step.
        length: A LongTensor of size (batch,) which contains the length of
            each sequence in the batch.

    Returns:
        loss: An average loss value masked by the length.
    """
    # mask: (batch, max_len)
    target.requires_grad = False
    mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
    loss = functional.binary_cross_entropy_with_logits(
        x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
    loss = loss / mask.sum()
    return loss
def forward(self, batch):
    if self.general_config.embedding_model.find('elmo') >= 0:
        batch_size, passage_max_len, other = list(batch['passage_ids'].size())
    else:
        batch_size, passage_max_len = list(batch['passage_ids'].size())
    assert passage_max_len % 10 == 0

    if self.general_config.embedding_model.find('elmo') >= 0:
        # [batch*10, passage/10, other]
        passage_ids = batch['passage_ids'].view(
            batch_size * 10, passage_max_len // 10, other)
    else:
        # [batch*10, passage/10]
        passage_ids = batch['passage_ids'].view(
            batch_size * 10, passage_max_len // 10)

    passage_repre = self.get_repre(passage_ids)  # [batch*10, passage/10, elmo_emb]
    passage_repre, _ = self.passage_encoder(passage_repre)  # [batch*10, passage/10, lstm_emb]
    emb_size = utils.shape(passage_repre, 2)
    passage_repre = passage_repre.contiguous().view(batch_size, passage_max_len, emb_size)

    question_repre = self.get_repre(batch['question_ids'])  # [batch, question, elmo_emb]
    question_repre, _ = self.question_encoder(question_repre)  # [batch, question, lstm_emb]

    # modeling question
    batch_size = len(batch['ids'])
    question_starts = torch.zeros(batch_size, 1, dtype=torch.long).cuda()  # [batch, 1]
    question_ends = batch['question_lens'].view(batch_size, 1) - 1  # [batch, 1]
    question_types = torch.zeros(batch_size, 1, dtype=torch.long).cuda()  # [batch, 1]
    question_mask_float = torch.ones(batch_size, 1, dtype=torch.float).cuda()  # [batch, 1]
    question_emb = self.get_mention_embedding(
        question_repre, question_starts, question_ends, question_types,
        question_mask_float).squeeze(dim=1)  # [batch, emb]

    # modeling mentions
    mention_starts = batch['mention_starts']
    mention_ends = batch['mention_ends']
    mention_types = batch['mention_types']
    mention_nums = batch['mention_nums']
    mention_max_num = utils.shape(mention_starts, 1)
    mention_mask = utils.sequence_mask(mention_nums, mention_max_num)
    mention_emb = self.get_mention_embedding(
        passage_repre, mention_starts, mention_ends, mention_types,
        mention_mask.float())

    if self.general_config.mention_compress_size > 0:
        question_emb = self.mention_compressor(question_emb)
        mention_emb = self.mention_compressor(mention_emb)

    matching_results = []
    rst_seq = self.perform_matching(mention_emb, question_emb)
    matching_results.append(rst_seq)

    # graph encoding
    if self.general_config.graph_encoding in ('GCN', 'GRN'):
        edges = batch['edges']  # [batch, mention, edge]
        edge_nums = batch['edge_nums']  # [batch, mention]
        edge_max_num = utils.shape(edges, 2)
        edge_mask = utils.sequence_mask(
            edge_nums.view(batch_size * mention_max_num),
            edge_max_num).view(batch_size, mention_max_num, edge_max_num)  # [batch, mention, edge]
        assert not (edge_mask & (~mention_mask.unsqueeze(dim=2))).any().item()
        for i in range(self.general_config.graph_encoding_steps):
            mention_emb_new = self.graph_encoder(
                mention_emb, mention_mask.float(), edges, edge_mask.float())
            mention_emb = (mention_emb_new + mention_emb
                           if self.general_config.graph_residual else mention_emb_new)
            rst_graph = self.perform_matching(mention_emb, question_emb)
            matching_results.append(rst_graph)

    if len(matching_results) > 1:
        assert len(matching_results) == self.general_config.graph_encoding_steps + 1
        matching_results = torch.stack(matching_results, dim=2)  # [batch, mention, graph_step+1]
        logits = self.matching_integrater(matching_results).squeeze(dim=2)  # [batch, mention]
    else:
        assert len(matching_results) == 1
        logits = matching_results[0]  # [batch, mention]

    candidates, candidate_num, candidate_appear_num = \
        batch['candidates'], batch['candidate_num'], batch['candidate_appear_num']
    _, cand_max_num, cand_pos_max_num = list(candidates.size())
    candidate_mask = utils.sequence_mask(candidate_num, cand_max_num)  # [batch, cand]
    candidate_appear_mask = utils.sequence_mask(
        candidate_appear_num.view(batch_size * cand_max_num),
        cand_pos_max_num).view(batch_size, cand_max_num, cand_pos_max_num)  # [batch, cand, pos]
    assert not (candidate_appear_mask & (~candidate_mask.unsqueeze(dim=2))).any().item()

    # Earlier ideas computed 'candidate_appear_dist' directly from the logits
    # (a joint softmax over cand*pos, exp-normalization, and post-softmax
    # clamping); they were abandoned in favor of the implementation below.
    mention_dist = F.softmax(logits, dim=1)
    if utils.contain_nan(mention_dist):
        print(logits)
        print(mention_dist)
        assert False
    candidate_appear_dist = utils.batch_gather(
        mention_dist, candidates) * candidate_appear_mask.float()
    candidate_dist = candidate_appear_dist.sum(dim=2) * candidate_mask.float()
    candidate_dist = utils.clip_and_normalize(candidate_dist, 1e-6)
    assert not utils.contain_nan(candidate_dist)

    candidate_logits = candidate_dist.log()  # [batch, cand]
    predictions = candidate_logits.argmax(dim=1)  # [batch]
    if not (predictions < candidate_num).all().item():
        print(candidate_dist)
        print(candidate_num)
        assert False

    if 'refs' not in batch or batch['refs'] is None:
        return {'predictions': predictions}

    refs = batch['refs']
    loss = nn.CrossEntropyLoss()(candidate_logits, refs)
    right_count = (predictions == refs).sum()
    return {'predictions': predictions, 'loss': loss, 'right_count': right_count}