Example #1
    def embed_body(self, dcmt, target_embeded):

        dcmts, ndocs, nsents, dcmt_nwords = dcmt

        doc_mask = sequence_mask(ndocs, device=self.config.gpu)
        doc_sent_mask = sequence_mask(nsents, device=self.config.gpu)
        doc_word_mask = sequence_mask(dcmt_nwords, device=self.config.gpu)

        batch_size, max_num_doc, max_num_sent, max_seqlen = dcmts.size()
        dcmts_reshaped = dcmts.view(-1, max_seqlen)

        #(batch_size * max_num_doc * max_num_sent, max_seqlen, embed_dim)
        dcmts_embeded = self.wembed(dcmts_reshaped)

        # _, sent_reprs = self.w2s(dcmts_embeded, \
        #   mask = doc_word_mask.view(-1, doc_word_mask.size(-1), 1))

        #(batch_size * max_num_doc * max_num_sent, max_seqlen, hidden_dim * 2)

        hiddens, _ = self.w2s(dcmts_embeded)
        target_embeded_expand = target_embeded.repeat(max_num_sent, 1, 1)
        sent_reprs, _ = attention(target_embeded_expand, hiddens, hiddens,
                                  mask=doc_word_mask.view(-1, 1, doc_word_mask.size(-1)),
                                  dropout=self.dropout, scale=True)

        #(batch_size * max_num_doc, max_num_sent, hidden_dim * 2)
        sent_reprs = sent_reprs.view(-1, max_num_sent, sent_reprs.size(-1))
        doc_sent_mask = doc_sent_mask.view(-1, doc_sent_mask.size(-1), 1)
        return sent_reprs, doc_sent_mask, doc_mask
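
All of the snippets on this page assume a sequence_mask(lengths, max_len) helper that turns a tensor of sequence lengths into a boolean padding mask. The exact signature varies by project (some variants take device=, pad=, or dtype= keyword arguments, as later examples show); a minimal sketch of the common core, assuming a 1-D lengths tensor:

import torch

def sequence_mask(lengths, max_len=None, dtype=torch.bool):
    # lengths: (batch,) tensor of true sequence lengths
    if max_len is None:
        max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    # broadcast (1, max_len) against (batch, 1): True at valid positions
    mask = steps.unsqueeze(0) < lengths.unsqueeze(1)
    return mask.to(dtype)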
Example #2
    def embed_doc(self, target, lead, dcmt):

        tgt_hiddens, tgt_output, tgt_nword, tgt_mask = self.embed_sent(
            target, self.t2v)

        #(batch_size, max_num_sent, max_num_word)
        text, ndoc, nword = dcmt

        max_num_word = text.size(-1)
        #(batch_size, max_num_sent, 1)
        doc_mask = sequence_mask(ndoc, device=self.config.gpu).unsqueeze(-1)
        #(batch_size * max_num_sent, max_num_word, 1)
        word_mask = sequence_mask(nword, device=self.config.gpu).view(
            -1, max_num_word, 1)

        #(batch_size * max_num_sent, max_num_words, embed_dim)
        text_embed = self.embed(text.view(-1, max_num_word))

        #(batch_size * max_num_sent, output_dim)
        _, sent_repr = self.w2s(text_embed, mask=word_mask, init=tgt_output[1])

        sent_repr = self.dropout(sent_repr)
        sent_repr = sent_repr.view(*text.size()[:2], -1)

        doc_hiddens, doc_repr = self.s2d(sent_repr, mask=doc_mask, init=tgt_output[1])
        doc_repr = self.dropout(doc_repr)

        return doc_repr, doc_hiddens
Example #3
    def embed_body(self, target, dcmt):

        dcmts, ndocs, nsents, dcmt_nwords = dcmt

        doc_mask = sequence_mask(ndocs, device=self.config.gpu)
        doc_sent_mask = sequence_mask(nsents, device=self.config.gpu)
        doc_word_mask = sequence_mask(dcmt_nwords, device=self.config.gpu)

        batch_size, max_num_doc, max_num_sent, max_seqlen = dcmts.size()
        dcmts_reshaped = dcmts.view(-1, max_seqlen)

        #(batch_size * max_num_doc, embed_dim)
        target_embed = self.embed_target(target)

        #(batch_size * max_num_doc * max_num_sent, max_seqlen, embed_dim)
        dcmts_embeded = self.wembed(dcmts_reshaped)

        #(batch_size * max_num_doc * max_num_sent, hidden_dim * 2)
        _, sent_reprs = self.w2s(dcmts_embeded, init=target_embed,
                                 mask=doc_word_mask.view(-1, doc_word_mask.size(-1), 1))
        sent_reprs = self.dropout(sent_reprs)

        #(batch_size * max_num_doc, max_num_sent, hidden_dim * 2)
        sent_reprs = sent_reprs.view(-1, max_num_sent, sent_reprs.size(-1))
        doc_sent_mask = doc_sent_mask.view(-1, doc_sent_mask.size(-1), 1)

        #(batch_size * max_num_doc, hidden_dim * 2)
        _, doc_reprs = self.s2d(sent_reprs, mask=doc_sent_mask)

        doc_reprs = doc_reprs.view(batch_size, max_num_doc, -1)
        doc_reprs = self.dropout(doc_reprs)

        return doc_reprs, doc_mask
Example #4
    def embed_body(self, tgt_output, dcmt):

        #(batch_size, max_num_sent, max_num_word)
        text, ndoc, nword = dcmt

        max_num_word = text.size(-1)
        #(batch_size, max_num_sent, 1)
        sent_mask = sequence_mask(ndoc, device=self.config.gpu).unsqueeze(-1)
        #(batch_size * max_num_sent, max_num_word, 1)
        word_mask = sequence_mask(nword, device=self.config.gpu).view(
            -1, max_num_word, 1)

        #(batch_size * max_num_sent, max_num_words, embed_dim)
        text_embed = self.embed(text.view(-1, max_num_word))

        #(batch_size * max_num_sent, output_dim)
        _, sent_repr = self.w2s(text_embed, mask=word_mask, init=tgt_output[1])

        sent_repr = self.dropout(sent_repr)
        sent_repr = sent_repr.view(*text.size()[:2], -1)
        ctx_sent_repr, _ = self.s2d(sent_repr, mask=sent_mask)
        ctx_sent_repr = self.dropout(ctx_sent_repr)

        return ctx_sent_repr, sent_mask
Example #5
    def forward(self, x, length):
        """Maps input to last hidden state, to pooler_output, to prediction
        Args:
            x (torch.LongTensor): input of shape (batch_size, seq_length)
            length (torch.LongTensor): input of shape (batch_size, )

        Returns:
            x (torch.FloatTensor): logits of shape (batch_size, NUM_CLASS)

        """
        x_mask = sequence_mask(length, pad=0, dtype=torch.float)  # (batch_size, max_length)
        # TODO: clean this hack
        try:  # bert, roberta
            _, cls, hidden_states = self.model(x, attention_mask=x_mask)
            # hidden_states : length 13 tuple of tensors (batch_size, max_length, hidden_size)
            if len(self.layer) == 1:
                x = self.time_pooling(cls, hidden_states[self.layer[0]], length)
            else:
                x = self.layer_pooling([self.time_pooling(cls, hidden_states[layer], length)
                                        for layer in self.layer])
        except Exception:  # xlm, xlnet
            x = self.model(x, attention_mask=x_mask)  # (batch_size, seq_length, hidden_size)
            x = x[0]
        x = self.out(x)  # (batch_size, NUM_CLASS)
        return x
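
time_pooling and layer_pooling are attributes of the surrounding module and are not shown here. A plausible time_pooling is length-masked mean pooling over tokens; a minimal sketch under that assumption (the name mean_time_pooling is hypothetical, and this variant ignores the cls argument):

def mean_time_pooling(cls, hidden, lengths):
    # hidden: (batch_size, max_length, hidden_size); lengths: (batch_size,)
    mask = sequence_mask(lengths, max_len=hidden.size(1)).unsqueeze(-1).float()
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)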
Example #6
 def embed_sent(self, sent, encoder, h0=None):
     snt, stn_nword = sent
     word_mask = sequence_mask(stn_nword,
                               device=self.config.gpu).unsqueeze(-1)
     snt_embed = self.embed(snt)
     hiddens, output = encoder(snt_embed, init=h0, mask=word_mask)
     return hiddens, output, stn_nword, word_mask
Example #7
 def forward(self, x, target, length):
     """
     Args:
         x: A Variable containing a FloatTensor of size
             (batch, max_len, dim) which contains the
             unnormalized probability for each class.
         target: A Variable containing a LongTensor of size
             (batch, max_len, dim) which contains the index of the true
             class for each corresponding step.
         length: A Variable containing a LongTensor of size (batch,)
             which contains the length of each data in a batch.
     Returns:
         loss: An average loss value in range [0, 1] masked by the length.
     """
     # mask: (batch, max_len, 1)
     target.requires_grad = False
     mask = sequence_mask(sequence_length=length,
                          max_len=target.size(1)).unsqueeze(2).float()
     if self.seq_len_norm:
         norm_w = mask / mask.sum(dim=1, keepdim=True)
         out_weights = norm_w.div(target.shape[0] * target.shape[2])
         mask = mask.expand_as(x)
         # loss = functional.mse_loss(
         #     x * mask, target * mask, reduction='none')
         loss = nn.MSELoss(reduction='none')(x * mask, target * mask)
         loss = loss.mul(out_weights.to(loss.device)).sum()
     else:
         mask = mask.expand_as(x)
         # loss = functional.mse_loss(
         #     x * mask, target * mask, reduction='sum')
         loss = nn.MSELoss(reduction='sum')(x * mask, target * mask)
         loss = loss / mask.sum()
     return loss
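
A hedged usage sketch for this criterion, assuming the surrounding class is named MSELossMasked and takes the seq_len_norm flag in its constructor (both names are assumptions here):

criterion = MSELossMasked(seq_len_norm=False)  # assumed class name
x = torch.randn(4, 10, 80)                     # predicted frames
target = torch.randn(4, 10, 80)                # ground-truth frames
length = torch.tensor([10, 7, 5, 3])           # valid steps per example
loss = criterion(x, target, length)            # scalar averaged over the mask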
Example #8
 def select_m_actions(self, probs, lengths, sql_labels):
     batch_size = probs.size(0)
     max_len = probs.size(1)
     probs = probs.transpose(0, 1).contiguous()
     if self.args.model == 'gate':
         probs = F.softmax(probs, dim=-1)
     m_log_probs = torch.FloatTensor(batch_size, self.args.m).to(device=self.args.device)
     m_rewards = torch.FloatTensor(batch_size, self.args.m).to(device=self.args.device)
     for index in range(self.args.m):
         # notice the max_len
         actions = torch.LongTensor(batch_size, max_len).to(device=self.args.device)
         log_probs = torch.FloatTensor(batch_size, max_len).to(device=self.args.device)
         for ti in range(max_len):
             actions[:, ti], log_probs[:, ti] = self.select_action_step(probs[ti])
         # mask
         mask = sequence_mask(lengths, max_len).to(device=self.args.device)
         actions.data.masked_fill_(~mask, -100)
         log_probs.data.masked_fill_(~mask, 0.0)
         # compute rewards
         rewards = self.compute_rewards(actions, sql_labels, mode='rewards')
         m_log_probs[:, index] = torch.sum(log_probs, dim=1)
         m_rewards[:, index] = rewards
     m_rewards -= m_rewards.mean(dim=-1).view(batch_size, 1)
     return m_log_probs, m_rewards
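
select_action_step is defined elsewhere in the class; a minimal sketch, assuming it samples one action per batch element from a row of per-step probabilities and returns the sampled actions together with their log-probabilities:

def select_action_step(self, step_probs):
    # step_probs: (batch_size, num_actions), already normalized
    dist = torch.distributions.Categorical(probs=step_probs)
    actions = dist.sample()                 # (batch_size,)
    return actions, dist.log_prob(actions)  # both (batch_size,)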
Example #9
    def _feature(self, inputs, lengths, tags_one_hot=None):
        """"""
        w_lengths, word_sort_ind = lengths.sort(dim=0, descending=True)
        # reorder the batch by descending length so sequences can be packed
        inputs = inputs[word_sort_ind].to(device)
        if tags_one_hot is not None:
            tags_one_hot = tags_one_hot[word_sort_ind].byte().to(device)

        # compute features
        inputs_emb = self.embeddings(inputs)
        w = self.dropout(inputs_emb)

        # Pack padded sequence
        w = torch.nn.utils.rnn.pack_padded_sequence(
            w, list(w_lengths), batch_first=True
        )  # packed sequence of word_emb_dim + 2 * char_rnn_dim, with real sequence lengths

        # LSTM
        w, _ = self.BiLSTM(
            w)  # packed sequence of word_rnn_dim, with real sequence lengths
        # Unpack packed sequence

        w, _ = torch.nn.utils.rnn.pad_packed_sequence(
            w, batch_first=True
        )  # (batch_size, max_word_len_in_batch, word_rnn_dim)

        w = self.dropout(w)

        mask = sequence_mask(w_lengths).float()

        crf_scores = self.crf_layer(w)
        return crf_scores, tags_one_hot, mask, w_lengths, word_sort_ind
Example #10
 def _qe_masking(qe):
     mask = utils.sequence_mask(
         torch.arange(qe.size()[-1] - 1,
                      qe.size()[-1] - qe.size()[-2] - 1, -1).to(qe.device),
         qe.size()[-1])
     mask = ~mask
     return mask.to(qe.dtype) * qe
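
For intuition: with qe of shape (1, 3, 4), the lengths passed to sequence_mask are [3, 2, 1], so after inversion row 0 keeps only the last column, row 1 the last two, and row 2 the last three. The surviving upper-right staircase corresponds to the valid relative positions in Music-Transformer-style relative attention skewing.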
Example #11
 def _test_step(self, batch):
     question_inds = batch["question_inds"]
     seq_length = batch["seq_length"]
     image_feat = batch["image_feat"]
     question_mask = sequence_mask(seq_length)
     outputs = self.offline_model(question_inds, question_mask, image_feat)
     return outputs
Example #12
def train_encoder(mdl, crit, optim, sch, stat):
    """Train REL or EXT model"""
    logger.info(f'*** Epoch {stat.epoch} ***')
    mdl.train()
    it = DataLoader(load_dataset(args.dir_data, 'train'), args.model_type,
                    args.batch_size, args.max_ntokens_src, spt_ids_B,
                    spt_ids_C, eos_mapping)
    for batch in it:
        optim.zero_grad()  # reset gradients accumulated from the previous step
        _, logits = mdl(batch)
        mask_inp = utils.sequence_mask(batch.src_lens, batch.inp.size(1))
        loss = crit(logits, batch.tgt, mask_inp)
        loss.backward()
        stat.update(loss,
                    'train',
                    args.model_type,
                    logits=logits,
                    labels=batch.tgt)
        torch.nn.utils.clip_grad_norm_(mdl.parameters(), 5)
        optim.step()
        if stat.steps == 0:
            continue
        if stat.steps % args.log_interval == 0:
            stat.lr = optim.param_groups[0]['lr']
            stat.report()
            sch.step(stat.avg_train_loss)
        if stat.steps % args.valid_interval == 0:
            valid_ret(mdl, crit, optim, stat)
Example #13
    def embed_body(self, dcmt):

        text, ndoc, nword = dcmt
        max_num_word = text.size(-1)
        #(batch_size, max_num_sent, 1)
        sent_mask = sequence_mask(ndoc, device=self.config.gpu).unsqueeze(-1)
        #(batch_size * max_num_sent, max_num_word, 1)
        word_mask = sequence_mask(nword, device=self.config.gpu).view(
            -1, max_num_word, 1)
        #(batch_size * max_num_sent, max_num_words, embed_dim)
        text_embed = self.embed(text.view(-1, max_num_word))

        #(batch_size * max_num_sent, max_num_word, output_dim)
        sent_hiddens, _ = self.w2s(text_embed, mask=word_mask)

        return sent_hiddens, sent_mask, word_mask
Example #14
    def embed_lead(self, leads):

        leads, nleads, lead_nwords = leads
        max_num_lead = lead_nwords.size(-1)

        # (batch_size, max_num_doc)
        lead_mask = sequence_mask(nleads, device=self.config.gpu)
        # (batch_size, max_num_doc, max_seqlen)
        lead_word_mask = sequence_mask(lead_nwords, device=self.config.gpu)
        lead_word_mask = lead_word_mask.view(-1, lead_word_mask.size(-1), 1)

        leads = leads.view(-1, leads.size(-1))
        #(batch_size * max_num_doc, max_seqlen, embed_dim)
        leads_embeded = self.wembed(leads)

        #(batch_size * max_num_doc, max_seqlen, hidden_dim)
        lead_hiddens, _ = self.w2s(leads_embeded, mask=lead_word_mask)
        return lead_hiddens, lead_word_mask, max_num_lead
Example #15
    def train_emb(self,
                  images,
                  captions,
                  lengths,
                  ids=None,
                  target_align=None,
                  lengths_whole=None,
                  epoch=None,
                  *args):
        """ one training step given images and captions """
        self.Eiters += 1
        self.logger.update('Eit', self.Eiters)
        self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
        lengths = torch.Tensor(lengths).long()
        lengths_whole = torch.Tensor(lengths_whole).long()
        if torch.cuda.is_available():
            lengths = lengths.cuda()
            lengths_whole = lengths_whole.cuda()
        lengths = lengths_whole
        # compute the embeddings
        (img_emb, cap_span_features, left_span_features, right_span_features,
         word_embs, tree_indices, probs,
         span_bounds) = self.forward_emb(images, captions, lengths,
                                         target_align, lengths_whole)

        # measure accuracy and record loss
        cum_reward, matching_loss = self.forward_reward(
            img_emb, cap_span_features, left_span_features,
            right_span_features, word_embs, lengths, span_bounds,
            lengths_whole)
        probs = torch.cat(probs,
                          dim=0).reshape(-1, lengths.size(0)).transpose(0, 1)
        masks = sequence_mask(lengths - 1, lengths.max(0)[0] - 1).float()
        rl_loss = torch.sum(-masks * torch.log(probs) * cum_reward.detach())

        loss = rl_loss + matching_loss * self.vse_loss_alpha
        loss = loss / cum_reward.shape[0]
        self.logger.update('Loss', float(loss), img_emb.size(0))
        self.logger.update('MatchLoss',
                           float(matching_loss / cum_reward.shape[0]),
                           img_emb.size(0))
        self.logger.update('RL-Loss', float(rl_loss / cum_reward.shape[0]),
                           img_emb.size(0))

        # compute gradient and do SGD step
        self.optimizer.zero_grad()
        loss.backward()
        if self.grad_clip > 0:
            clip_grad_norm_(self.params, self.grad_clip)
        self.optimizer.step()

        # clean up
        if epoch > 0:
            del cum_reward
            del tree_indices
            del probs
            del cap_span_features
            del span_bounds
Example #16
 def select_max_action(self, probs, lengths, sql_labels):
     batch_size, max_len = probs.size(0), probs.size(1)
     actions = torch.max(probs, 2)[1]
     # mask
     mask = sequence_mask(lengths, max_len).to(device=self.args.device)
     actions.data.masked_fill_(~mask, -100)
     # compute acc
     b_error_1, b_error_2, b_error_3, b_error_4, rewards = self.compute_rewards(
         actions, sql_labels, mode='acc')
     return actions, rewards, b_error_1, b_error_2, b_error_3, b_error_4
Example #17
    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_hidden: torch.Tensor,
        encoder_lengths: torch.Tensor,
    ):
        """
        Args:
            decoder_hidden (torch.Tensor): Query vector
                ``(batch, hidden_dim)``.
            encoder_hidden (torch.Tensor): Sequence of sources
                ``(batch, src_len, hidden_dim)``.
            encoder_lengths (torch.Tensor): The source sequence length
                ``(batch,)``.

        Returns:
            attn_h (torch.Tensor): The attentional hidden state
                ``(batch, hidden_dim)``.
        """
        tgt_batch, tgt_dim = decoder_hidden.shape
        src_batch, src_len, src_dim = encoder_hidden.shape

        assert src_batch == tgt_batch
        assert src_dim == tgt_dim

        # align_scores: (batch, src_len)
        align_scores = self.score(encoder_hidden, decoder_hidden)

        if encoder_lengths is not None:
            mask = sequence_mask(
                encoder_lengths, max_len=align_scores.shape[1]
            )
            align_scores.masked_fill_(~mask, -float("inf"))

        # align_vector:  (batch, src_len)
        align_vector = F.softmax(align_scores, dim=1)

        # (batch, 1, src_len) x (batch, src_len, hidden_dim)
        #  --> (batch, 1, hidden_dim)
        # context_vector: (batch, hidden_dim)
        context_vector = torch.bmm(
            align_vector.unsqueeze(1), encoder_hidden
        ).squeeze(1)

        # concat_c_h: (batch, 2 * hidden_dim)
        concat_c_h = torch.cat([context_vector, decoder_hidden], dim=1)

        # attentional hidden state: (batch, hidden_dim)
        attn_h = torch.tanh(self.w_c(concat_c_h))
        attn_h = self.dropout(attn_h)

        return attn_h
Example #18
File: model.py Project: hwijeen/CPTG
 def _tighten(self, hy, y):
     """
     pad tokens after EOS and mask hiddens after EOS
     hy: (B, MAXLEN+1, 700)
     y: (B, MAXLEN+1)
     """
     lengths = get_actual_lengths(y)
     mask = sequence_mask(lengths)
     y = y[:, :mask.size(1)] # truncate unnecessarily generated part
     hy = hy[:, :mask.size(1)]
     y.masked_fill_((mask!=1), PAD_IDX) # this does not backprop
     hy = hy * (mask.unsqueeze(-1)).float()
     hy, y, lengths = sort_by_length(hy, y, lengths)
     return hy, y, lengths
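
get_actual_lengths is not shown; a minimal sketch, assuming it returns, for each generated row, the position of the first EOS token plus one (and the full width when no EOS was emitted), with EOS_IDX an assumed constant alongside PAD_IDX:

def get_actual_lengths(y, eos_idx=EOS_IDX):
    # y: (B, MAXLEN+1) generated token ids
    is_eos = (y == eos_idx)
    has_eos = is_eos.any(dim=1)
    first_eos = is_eos.float().argmax(dim=1) + 1  # index of first EOS, inclusive
    full = torch.full_like(first_eos, y.size(1))
    return torch.where(has_eos, first_eos, full)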
Example #19
def masked_cross_entropy(logits,
                         target,
                         length,
                         per_example=False,
                         decode=False):
    """
    Args:
        logits (Variable, FloatTensor): [batch, max_len, num_classes]
            - unnormalized probability for each class
        target (Variable, LongTensor): [batch, max_len]
            - index of true class for each corresponding step
        length (Variable, LongTensor): [batch]
            - length of each data in a batch
    Returns:
        loss:
            - if per_example: per-example summed losses, shape [batch]
            - else: a tuple (loss, n_words) of the total masked loss and
              the total number of words, for computing the average
    """
    batch_size, max_len, num_classes = logits.size()

    # [batch_size * max_len, num_classes]
    logits_flat = logits.view(-1, num_classes)

    # [batch_size * max_len, num_classes]
    log_probs_flat = F.log_softmax(logits_flat, dim=1)

    # [batch_size * max_len, 1]
    target_flat = target.view(-1, 1)

    # Negative log-likelihood: -sum(1 * log P(target) + 0 * log P(non-target)) = -sum(log P(target))
    # [batch_size * max_len, 1]
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)

    # [batch_size, max_len]
    losses = losses_flat.view(batch_size, max_len)

    # [batch_size, max_len]
    mask = sequence_mask(sequence_length=length, max_len=max_len)

    # Apply masking on loss
    losses = losses * mask.float()

    # word-wise cross entropy
    # loss = losses.sum() / length.float().sum()

    if per_example:
        # loss: [batch_size]
        return losses.sum(1)
    else:
        loss = losses.sum()
        return loss, length.float().sum()
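
A short usage sketch: with per_example=False the function returns the summed loss together with the token count, so the caller averages explicitly:

logits = torch.randn(4, 12, 1000)         # (batch, max_len, num_classes)
target = torch.randint(0, 1000, (4, 12))  # (batch, max_len)
length = torch.tensor([12, 9, 5, 2])
loss_sum, n_words = masked_cross_entropy(logits, target, length)
loss = loss_sum / n_words                 # average per-word cross entropy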
Example #20
    def embed_target(self, target):
        #target(batch_size, seqlen)
        #target_nwords(batch_size)
        target, target_nwords = target
        #(batch_size, seqlen)
        target_word_mask = sequence_mask(target_nwords, device=self.config.gpu)

        #(batch_size, seqlen, embed_dim)
        # target_embed = self.wembed(target)
        target_embed = self.tembed(target)
        target_word_mask = target_word_mask.unsqueeze(-1)
        target_embed = ((target_embed * target_word_mask).sum(dim=1)
                        / target_word_mask.sum(dim=1))

        return target_embed
Example #21
    def forward(self, padded_input, input_lengths):
        """
        Args:
            padded_input: N x T x D
            input_lengths: N

        Returns:
            alphas: N x T, sigmoid weights with padded positions zeroed
        """
        x, input_lengths = self.conv(padded_input, input_lengths)
        x = self.dropout(x)
        alphas = self.linear(x).squeeze(-1)
        alphas = torch.sigmoid(alphas)
        pad_mask = sequence_mask(input_lengths)

        return alphas * pad_mask
Example #22
    def initialize(self, memory_bank, src_lengths, field_signals, device=None):
        """Initialize search state for each batch input"""

        # Repeat state in beam_size
        def fn_map_state(state, dim):
            return tile(state, self.beam_size, dim=dim)

        src_max_len = memory_bank.size(1)
        memory_bank = tile(memory_bank, self.beam_size)
        memory_pad_mask = tile(~sequence_mask(src_lengths, src_max_len),
                               self.beam_size)
        self.memory_lengths = tile(src_lengths, self.beam_size)
        mb_device = memory_bank.device
        self.device = mb_device if device is None else device
        self.field_signals = field_signals
        self.alive_seq = field_signals.repeat_interleave(self.beam_size)\
            .unsqueeze(-1).to(self.device)
        self.is_finished = torch.zeros([self.batch_size, self.beam_size],
                                       dtype=torch.uint8,
                                       device=self.device)
        self.best_scores = torch.full([self.batch_size],
                                      -1e10,
                                      dtype=torch.float,
                                      device=self.device)
        self._beam_offset = torch.arange(0,
                                         self.batch_size * self.beam_size,
                                         step=self.beam_size,
                                         dtype=torch.long,
                                         device=self.device)
        # Give full probability to the first beam on the first step; with no
        # prior information, choose any (the first beam)
        self.topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1),
            device=self.device).repeat(self.batch_size)
        # buffers for the topk scores and 'backpointer'
        self.topk_scores = torch.empty((self.batch_size, self.beam_size),
                                       dtype=torch.float,
                                       device=self.device)
        self.topk_ids = torch.empty((self.batch_size, self.beam_size),
                                    dtype=torch.long,
                                    device=self.device)
        self._batch_index = torch.empty([self.batch_size, self.beam_size],
                                        dtype=torch.long,
                                        device=self.device)
        return fn_map_state, memory_bank, memory_pad_mask
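
tile is defined elsewhere (this is OpenNMT-style beam search); its contract is to repeat each batch entry beam_size times along a given dimension, keeping the copies of each example adjacent. A minimal sketch under that assumption:

def tile(x, count, dim=0):
    # (..., batch, ...) -> (..., batch * count, ...)
    return x.repeat_interleave(count, dim=dim)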
Example #23
 def training_step(self, batch, batch_idx, use_sharpen=True):
     question_inds = batch["question_inds"]
     seq_length = batch["seq_length"]
     image_feat = batch["image_feat"]
     answer_idx = batch.get("answer_idx", None)
     gt_layout = batch.get("layout_inds", None)
     bbox_ind = batch.get("bbox_ind", None)
     bbox_gt = batch.get("bbox_batch", None)
     bbox_offset = batch.get("bbox_offset", None)
     question_mask = sequence_mask(seq_length)
     outputs = self.online_model(question_inds, question_mask, image_feat)
     loss = torch.tensor(0.0, device=self.device, dtype=torch.float)
     # we support training on vqa only, loc only, or both, depending on these flags.
     if self.cfg.MODEL.BUILD_VQA and answer_idx is not None:
         loss += self.vqa_loss(outputs["logits"], answer_idx)
         self.train_acc(F.softmax(outputs["logits"], dim=1), answer_idx)
         self.log("train/vqa_acc", self.train_acc)
     if self.cfg.MODEL.BUILD_LOC and bbox_ind is not None:
         loss += self.loc_loss(
             outputs["loc_scores"], outputs["bbox_offset_fcn"], bbox_ind, bbox_offset
         )
         feat_h, feat_w, _, _, stride_h, stride_w = self.img_sizes
         bbox_pred = batch_feat_grid2bbox(
             torch.argmax(outputs["loc_scores"], 1),
             outputs["bbox_offset"],
             stride_h,
             stride_w,
             feat_h,
             feat_w,
         )
         accuracy = torch.mean(
             (
                 batch_bbox_iou(bbox_pred, bbox_gt) >= self.cfg.TRAIN.BBOX_IOU_THRESH
             ).float()
         )
         self.log("train/loc_acc", accuracy)
     if self.cfg.TRAIN.USE_SHARPEN_LOSS and use_sharpen:
         loss += self.sharpen_loss(outputs["module_logits"])
     if self.cfg.TRAIN.USE_GT_LAYOUT:
         loss += self.gt_loss(outputs["module_logits"], gt_layout)
     self.log("train/loss", loss)
     # technically this means the offline model is behind, but it's fine.
     accumulate(self.offline_model, self.online_model)
     return loss
Example #24
    def __init__(self, batch, model_type, device='cuda'):
        self.batch_size = len(batch)

        pad_ = partial(pad_sequence, batch_first=True)
        self.inp = pad_([torch.tensor(x[0]) for x in batch]).to(device)
        lens = [
            next((i for i, v in enumerate(s) if v == 0), len(s))
            for s in self.inp
        ]
        self.src_lens = torch.LongTensor(lens).to(device)
        self.mask_inp = sequence_mask(self.src_lens, self.inp.size(1))
        self.segs = pad_([torch.tensor(x[1]) for x in batch]).to(device)
        if model_type == 'rel':
            self.tgt = torch.tensor([x[2] for x in batch]).to(device)
        elif model_type in ['ext', 'abs']:
            self.tgt = pad_([torch.tensor(x[2]) for x in batch]).to(device)

        self.qid = [x[3] for x in batch]
        self.did = [x[4] for x in batch]
Example #25
 def select_action(self, probs, lengths, sql_labels):
     batch_size = probs.size(0)
     # notice the max_len
     max_len = probs.size(1)
     actions = torch.LongTensor(batch_size, max_len).to(device=self.args.device)
     log_probs = torch.FloatTensor(batch_size, max_len).to(device=self.args.device)
     probs = probs.transpose(0, 1).contiguous()
     for ti in range(max_len):
         actions[:, ti], log_probs[:, ti] = self.select_action_step(probs[ti])
     # mask
     mask = sequence_mask(lengths, max_len).to(device=self.args.device)
     actions.data.masked_fill_(~mask, -100)
     log_probs.data.masked_fill_(~mask, 0.0)
     # compute rewards; (batch_size)
     rewards = self.compute_rewards(actions, sql_labels, mode='rewards')
     return torch.sum(log_probs, dim=1), rewards
Example #26
 def forward(self, x: Tensor, x_lens: Tensor = None):
     """
     Args:
         x : input of shape `(batch_sz, seq_len, n_features)`
         x_lens : lengths of x of shape `(batch_sz)`
     """
     x_proj = torch.tanh(self.proj(x))
     x_queries_sim = self.queries(x_proj)
     if x_lens is not None:
         masks = sequence_mask(x_lens).unsqueeze(-1)
         # attn_w: (batch_sz, seq_len, n_head)
         attn_w = softmax_with_mask(x_queries_sim,
                                    masks.expand_as(x_queries_sim),
                                    dim=1)
     else:
         attn_w = F.softmax(x_queries_sim, dim=1)
     # x_attended: (batch_sz, n_head, n_features)
     x_attended = attn_w.transpose(2, 1) @ x
     self.attn_w = attn_w
     return self.pool(x_attended), attn_w
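
softmax_with_mask is not shown; a minimal sketch, assuming it suppresses padded positions before normalizing:

import torch.nn.functional as F

def softmax_with_mask(x, mask, dim=1):
    # mask: same shape as x (or broadcastable), True/1 at valid positions
    x = x.masked_fill(~mask.bool(), float('-inf'))
    return F.softmax(x, dim=dim)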
Example #27
def loglik_ordinal(batch_data, list_type, theta, normalization_params):
    output = dict()
    epsilon = 1e-6

    # Data outputs
    data, missing_mask = batch_data
    missing_mask = missing_mask.float()
    batch_size = data.size()[0]

    # We need to force that the outputs of the network increase with the categories
    partition_param, mean_param = theta
    mean_value = torch.reshape(mean_param, [-1, 1])
    theta_values = torch.cumsum(
        torch.clamp(nn.Softplus()(partition_param), epsilon, 1e20), 1)
    sigmoid_est_mean = nn.Sigmoid()(theta_values - mean_value)
    mean_probs = torch.cat(
        [sigmoid_est_mean,
         torch.ones([batch_size, 1]).float()], 1) - torch.cat(
             [torch.zeros([batch_size, 1]).float(), sigmoid_est_mean], 1)

    mean_probs = torch.clamp(mean_probs, epsilon, 1.0)

    # Code needed to compute samples from an ordinal distribution
    true_values = one_hot(torch.sum(data.int(), 1) - 1, int(list_type['dim']))

    # Compute loglik
    # log_p_x = -nn.softmax_cross_entropy_with_logits_v2(logits=torch.log(mean_probs),
    #                                                       labels=tf.stop_gradient(true_values))
    log_p_x = -torch.nn.CrossEntropyLoss()(mean_probs,
                                           true_values)  # .detach() ???

    output['log_p_x'] = torch.mul(log_p_x, missing_mask)
    output['log_p_x_missing'] = torch.mul(log_p_x, 1.0 - missing_mask)
    output['params'] = mean_probs
    output['samples'] = sequence_mask(1 + td.Categorical(
        logits=torch.log(torch.clamp(mean_probs, epsilon, 1e20))).sample(),
                                      int(list_type['dim']),
                                      dtype=torch.float32)

    return output
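
Note the final trick: sequence_mask is used here as a thermometer encoder for ordinal samples. A draw of k from the categorical distribution yields, via sequence_mask(1 + k, dim), a vector whose first k + 1 entries are 1 and the rest 0, which is the standard cumulative encoding of an ordinal value.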
Example #28
def eval_data(dataset, process = 0):
    all_result = []
    all_loss = []
    process = min(process, len(dataset))
    for number, [length, traj, index] in enumerate(dataset):
        traj = traj.transpose(0, 1)
        fake_input = cuda(torch.zeros((traj.shape[0], traj.shape[1], 0)).float())
        model.eval()
        result = model(traj, length, fake_input)
        raw_output = model.get_result(traj, length).cpu().detach()
        output = raw_output.clone()  # copy, then restore original dataset order
        for num in range(len(raw_output)):
            output[index[num]] = raw_output[num]
        all_result.append(output)
        mask = sequence_mask(length, args.max_length).transpose(0, 1)
        eval_loss = loss(result, traj, dim=2) * mask
        eval_loss = eval_loss.sum(dim=0) / length.float()
        all_loss.append(eval_loss.cpu().detach())
        if process > 0 and number % (len(dataset) // process) == 0:
            print('encoding %d / %d' % (number, len(dataset)))
    all_result = torch.cat(all_result)
    all_loss = torch.cat(all_loss)
    return all_result, all_loss.mean().item()
Example #29
 def forward(self, x, target, length):
     """
     Args:
         x: A Variable containing a FloatTensor of size
             (batch, max_len) which contains the
             unnormalized probability for each class.
         target: A Variable containing a LongTensor of size
             (batch, max_len) which contains the index of the true
             class for each corresponding step.
         length: A Variable containing a LongTensor of size (batch,)
             which contains the length of each data in a batch.
     Returns:
         loss: An average loss value masked by the length.
     """
     # mask: (batch, max_len)
     target.requires_grad = False
     mask = sequence_mask(sequence_length=length,
                          max_len=target.size(1)).float()
     # loss = functional.binary_cross_entropy_with_logits(
     #     x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
     loss = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight,
                                 reduction='sum')(x * mask, target * mask)
     loss = loss / mask.sum()
     return loss
Example #30
    def forward(self, batch):
        if self.general_config.embedding_model.find('elmo') >= 0:
            batch_size, passage_max_len, other = list(
                batch['passage_ids'].size())
        else:
            batch_size, passage_max_len = list(batch['passage_ids'].size())
        assert passage_max_len % 10 == 0

        if self.general_config.embedding_model.find('elmo') >= 0:
            passage_ids = batch['passage_ids'].view(
                batch_size * 10, passage_max_len // 10,
                other)  # [batch*10, passage/10, other]
        else:
            passage_ids = batch['passage_ids'].view(
                batch_size * 10,
                passage_max_len // 10)  # [batch*10, passage/10]

        passage_repre = self.get_repre(
            passage_ids)  # [batch*10, passage/10, elmo_emb]
        passage_repre, _ = self.passage_encoder(
            passage_repre)  # [batch*10, passage/10, lstm_emb]
        emb_size = utils.shape(passage_repre, 2)
        passage_repre = passage_repre.contiguous().view(
            batch_size, passage_max_len, emb_size)

        question_repre = self.get_repre(
            batch['question_ids'])  # [batch, question, elmo_emb]
        question_repre, _ = self.question_encoder(
            question_repre)  # [batch, question, lstm_emb]

        # modeling question
        batch_size = len(batch['ids'])
        question_starts = torch.zeros(batch_size, 1,
                                      dtype=torch.long).cuda()  # [batch, 1]
        question_ends = batch['question_lens'].view(batch_size,
                                                    1) - 1  # [batch, 1]
        question_types = torch.zeros(batch_size, 1,
                                     dtype=torch.long).cuda()  # [batch, 1]
        question_mask_float = torch.ones(
            batch_size, 1, dtype=torch.float).cuda()  # [batch, 1]
        question_emb = self.get_mention_embedding(
            question_repre, question_starts, question_ends, question_types,
            question_mask_float).squeeze(dim=1)  # [batch, emb]

        # modeling mentions
        mention_starts = batch['mention_starts']
        mention_ends = batch['mention_ends']
        mention_types = batch['mention_types']
        mention_nums = batch['mention_nums']

        mention_max_num = utils.shape(mention_starts, 1)
        mention_mask = utils.sequence_mask(mention_nums, mention_max_num)
        mention_emb = self.get_mention_embedding(passage_repre, mention_starts,
                                                 mention_ends, mention_types,
                                                 mention_mask.float())

        if self.general_config.mention_compress_size > 0:
            question_emb = self.mention_compressor(question_emb)
            mention_emb = self.mention_compressor(mention_emb)

        matching_results = []
        rst_seq = self.perform_matching(mention_emb, question_emb)
        matching_results.append(rst_seq)

        # graph encoding
        if self.general_config.graph_encoding in ('GCN', 'GRN'):
            edges = batch['edges']  # [batch, mention, edge]
            edge_nums = batch['edge_nums']  # [batch, mention]
            edge_max_num = utils.shape(edges, 2)
            edge_mask = utils.sequence_mask(
                edge_nums.view(batch_size * mention_max_num),
                edge_max_num).view(batch_size, mention_max_num,
                                   edge_max_num)  # [batch, mention, edge]
            assert not (edge_mask &
                        (~mention_mask.unsqueeze(dim=2))).any().item()

            for i in range(self.general_config.graph_encoding_steps):
                mention_emb_new = self.graph_encoder(mention_emb,
                                                     mention_mask.float(),
                                                     edges, edge_mask.float())
                mention_emb = mention_emb_new + mention_emb if self.general_config.graph_residual else mention_emb_new
                rst_graph = self.perform_matching(mention_emb, question_emb)
                matching_results.append(rst_graph)

        if len(matching_results) > 1:
            assert len(matching_results) == self.general_config.graph_encoding_steps + 1
            matching_results = torch.stack(
                matching_results, dim=2)  # [batch, mention, graph_step+1]
            logits = self.matching_integrater(matching_results).squeeze(
                dim=2)  # [batch, mention]
        else:
            assert len(matching_results) == 1
            logits = matching_results[0]  # [batch, mention]

        candidates = batch['candidates']
        candidate_num = batch['candidate_num']
        candidate_appear_num = batch['candidate_appear_num']
        _, cand_max_num, cand_pos_max_num = list(candidates.size())

        candidate_mask = utils.sequence_mask(candidate_num,
                                             cand_max_num)  # [batch, cand]
        candidate_appear_mask = utils.sequence_mask(
            candidate_appear_num.view(batch_size * cand_max_num),
            cand_pos_max_num).view(batch_size, cand_max_num,
                                   cand_pos_max_num)  # [batch, cand, pos]
        assert not (candidate_appear_mask &
                    (~candidate_mask.unsqueeze(dim=2))).any().item()

        # ideas to get 'candidate_appear_dist'

        ## idea 1
        #candidate_appear_logits = (utils.batch_gather(logits, candidates) + \
        #        candidate_appear_mask.float().log()).view(batch_size, cand_max_num * cand_pos_max_num) # [batch, cand * pos]
        #candidate_appear_logits = torch.clamp(candidate_appear_logits, -1e1, 1e1) # [batch, cand * pos]
        #candidate_appear_dist = F.softmax(candidate_appear_logits, dim=1).view(batch_size,
        #        cand_max_num, cand_pos_max_num) # [batch, cand, pos]

        ## idea 2
        #candidate_appear_dist = torch.clamp(utils.batch_gather(logits, candidates).exp() * \
        #        candidate_appear_mask.float(), 1e-6, 1e6).view(batch_size, cand_max_num * cand_pos_max_num) # [batch, cand * pos]
        #candidate_appear_dist = candidate_appear_dist / candidate_appear_dist.sum(dim=1, keepdim=True)
        #candidate_appear_dist = candidate_appear_dist.view(batch_size, cand_max_num, cand_pos_max_num)

        ## idea 3
        #candidate_appear_dist = F.softmax(utils.batch_gather(logits, candidates).view(batch_size,
        #        cand_max_num * cand_pos_max_num), dim=1) # [batch, cand * pos]
        #candidate_appear_dist = torch.clamp(candidate_appear_dist * candidate_appear_mask.view(batch_size,
        #        cand_max_num * cand_pos_max_num).float(), 1e-8, 1.0) # [batch, cand * pos]
        #candidate_appear_dist = (candidate_appear_dist / candidate_appear_dist.sum(dim=1, keepdim=True)).view(batch_size,
        #        cand_max_num, cand_pos_max_num) # [batch, cand, pos]

        ## get 'candidate_dist', which is common for idea 1, 2 and 3
        #if not (candidate_appear_dist > 0).all().item():
        #    print(candidate_appear_dist)
        #    assert False
        #candidate_dist = candidate_appear_dist.sum(dim=2) # [batch, cand]

        # original impl
        mention_dist = F.softmax(logits, dim=1)
        if utils.contain_nan(mention_dist):
            print(logits)
            print(mention_dist)
            assert False
        candidate_appear_dist = utils.batch_gather(
            mention_dist, candidates) * candidate_appear_mask.float()
        candidate_dist = candidate_appear_dist.sum(
            dim=2) * candidate_mask.float()
        candidate_dist = utils.clip_and_normalize(candidate_dist, 1e-6)
        assert not utils.contain_nan(candidate_dist)
        # end of original impl

        candidate_logits = candidate_dist.log()  # [batch, cand]
        predictions = candidate_logits.argmax(dim=1)  # [batch]
        if not (predictions < candidate_num).all().item():
            print(candidate_dist)
            print(candidate_num)
            assert False

        if 'refs' not in batch or batch['refs'] is None:
            return {'predictions': predictions}

        refs = batch['refs']
        loss = nn.CrossEntropyLoss()(candidate_logits, refs)
        right_count = (predictions == refs).sum()
        return {
            'predictions': predictions,
            'loss': loss,
            'right_count': right_count
        }
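
A recurring pattern in this forward pass: to mask a nested structure with variable inner counts (edges per mention, appearances per candidate), the per-row counts are flattened into one long batch, passed through sequence_mask, and the result is reshaped back. A minimal sketch of the pattern with made-up numbers:

counts = torch.tensor([[2, 0], [3, 1]])  # (batch=2, K=2) inner counts
mask = sequence_mask(counts.view(-1), 3).view(2, 2, 3)  # (batch, K, max_inner)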