Example #1
 def eliminate_rows(self, prob_sc, ind, phis):
     """ eliminate rows of phis and prob_matrix scale """
     length = prob_sc.size()[1]
     mask = (prob_sc[:, :, 0] > 0.85).type(dtype)
     rang = (Variable(torch.range(0, length - 1).unsqueeze(0)
             .expand_as(mask)).
             type(dtype))
     ind_sc = torch.sort(rang * (1-mask) + length * mask, 1)[1]
     # permute prob_sc
     m = mask.unsqueeze(2).expand_as(prob_sc)
     mm = m.clone()
     mm[:, :, 1:] = 0
     prob_sc = (torch.gather(prob_sc * (1 - m) + mm, 1,
                ind_sc.unsqueeze(2).expand_as(prob_sc)))
     # compose permutations
     ind = torch.gather(ind, 1, ind_sc)
     active = torch.gather(1-mask, 1, ind_sc)
     # permute phis
     active1 = active.unsqueeze(2).expand_as(phis)
     ind1 = ind.unsqueeze(2).expand_as(phis)
     active2 = active.unsqueeze(1).expand_as(phis)
     ind2 = ind.unsqueeze(1).expand_as(phis)
     phis_out = torch.gather(phis, 1, ind1) * active1
     phis_out = torch.gather(phis_out, 2, ind2) * active2
     return prob_sc, ind, phis_out, active
Example #2
def sample_relax_given_class(logits, samp):

    cat = Categorical(logits=logits)

    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels

    b = samp #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)


    u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
    z_tilde_b = -torch.log(-torch.log(u_b))
    
    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)


    z = z_tilde

    u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
    z_tilde_b = -torch.log(-torch.log(u_b))
    
    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

    return z, z_tilde, logprob
Example #3
def hard_example_mining(dist_mat, labels, return_inds=False):
  """For each anchor, find the hardest positive and negative sample.
  Args:
    dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
    labels: pytorch LongTensor, with shape [N]
    return_inds: whether to return the indices. Save time if `False`(?)
  Returns:
    dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
    dist_an: pytorch Variable, distance(anchor, negative); shape [N]
    p_inds: pytorch LongTensor, with shape [N]; 
      indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
    n_inds: pytorch LongTensor, with shape [N];
      indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
  NOTE: Only consider the case in which all labels have same num of samples, 
    thus we can cope with all anchors in parallel.
  """

  assert len(dist_mat.size()) == 2
  assert dist_mat.size(0) == dist_mat.size(1)
  N = dist_mat.size(0)

  # shape [N, N]
  is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
  is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())

  # `dist_ap` means distance(anchor, positive)
  # both `dist_ap` and `relative_p_inds` with shape [N, 1]
  dist_ap, relative_p_inds = torch.max(
    dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
  # `dist_an` means distance(anchor, negative)
  # both `dist_an` and `relative_n_inds` with shape [N, 1]
  dist_an, relative_n_inds = torch.min(
    dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
  # shape [N]
  dist_ap = dist_ap.squeeze(1)
  dist_an = dist_an.squeeze(1)

  if return_inds:
    # shape [N, N]
    ind = (labels.new().resize_as_(labels)
           .copy_(torch.arange(0, N).long())
           .unsqueeze( 0).expand(N, N))
    # shape [N, 1]
    p_inds = torch.gather(
      ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
    n_inds = torch.gather(
      ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
    # shape [N]
    p_inds = p_inds.squeeze(1)
    n_inds = n_inds.squeeze(1)
    return dist_ap, dist_an, p_inds, n_inds

  return dist_ap, dist_an
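A minimal smoke-test sketch for the function above (dummy features and identity labels; the names feats/labels are illustrative, not from the original repository):

import torch

# 4 samples, 2 identities with 2 samples each, so every anchor has the same
# number of positives and negatives, as the NOTE in the docstring requires.
feats = torch.randn(4, 8)
labels = torch.tensor([0, 0, 1, 1])
dist_mat = torch.cdist(feats, feats)          # pairwise distances, shape [4, 4]
dist_ap, dist_an = hard_example_mining(dist_mat, labels)
print(dist_ap.shape, dist_an.shape)           # torch.Size([4]) torch.Size([4])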
Example #4
 def sort_by_embeddings(self, Phis, Inputs_N, e):
     ind = torch.sort(e, 1)[1].squeeze()
     for i, phis in enumerate(Phis):
         # rearrange phis
         phis_out = (torch.gather(Phis[i], 1, ind.unsqueeze(2)
                     .expand_as(phis)))
         Phis[i] = (torch.gather(phis_out, 2, ind.unsqueeze(1)
                    .expand_as(phis)))
         # rearrange inputs
         Inputs_N[i] = torch.gather(Inputs_N[i], 1,
                                    ind.unsqueeze(2).expand_as(Inputs_N[i]))
     return Phis, Inputs_N
Example #5
    def proposal_layer(self, rpn_class, rpn_bbox):
        # handling proposals
        scores = rpn_class[:, :, 1]
        # Box deltas [batch, num_rois, 4]
        deltas_mul = Variable(torch.from_numpy(np.reshape(
            self.config.RPN_BBOX_STD_DEV, [1, 1, 4]).astype(np.float32))).cuda()
        deltas = rpn_bbox * deltas_mul

        pre_nms_limit = min(6000, self.anchors.shape[0])

        scores, ix = torch.topk(scores, pre_nms_limit, dim=-1,
                                largest=True, sorted=True)


        ix = torch.unsqueeze(ix, 2)
        ix = torch.cat([ix, ix, ix, ix], dim=2)
        deltas = torch.gather(deltas, 1, ix)

        _anchors = []
        for i in range(self.config.IMAGES_PER_GPU):
            anchors = Variable(torch.from_numpy(
                self.anchors.astype(np.float32))).cuda()
            _anchors.append(anchors)
        anchors = torch.stack(_anchors, 0) 
    
        pre_nms_anchors = torch.gather(anchors, 1, ix)
        refined_anchors = apply_box_deltas_graph(pre_nms_anchors, deltas)

        # Clip to image boundaries. [batch, N, (y1, x1, y2, x2)]
        height, width = self.config.IMAGE_SHAPE[:2]
        window = np.array([0, 0, height, width]).astype(np.float32)
        window = Variable(torch.from_numpy(window)).cuda()

        refined_anchors_clipped = clip_boxes_graph(refined_anchors, window)

        refined_proposals = []
        for i in range(self.config.IMAGES_PER_GPU):
            indices = nms(
                torch.cat([refined_anchors_clipped.data[i], scores.data[i]], 1), 0.7)
            indices = indices[:self.proposal_count]
            indices = torch.stack([indices, indices, indices, indices], dim=1)
            indices = Variable(indices).cuda()
            proposals = torch.gather(refined_anchors_clipped[i], 0, indices)
            padding = self.proposal_count - proposals.size()[0]
            proposals = torch.cat(
                [proposals, Variable(torch.zeros([padding, 4])).cuda()], 0)
            refined_proposals.append(proposals)

        rpn_rois = torch.stack(refined_proposals, 0)

        return rpn_rois
Example #6
    def _score_sentence(self, scores, mask, tags):
        """
            input:
                scores: variable (seq_len, batch, tag_size, tag_size)
                mask: (batch, seq_len)
                tags: tensor  (batch, seq_len)
            output:
                score: sum of score for gold sequences within whole batch
        """
        # Gives the score of a provided tag sequence
        batch_size = scores.size(1)
        seq_len = scores.size(0)
        tag_size = scores.size(2)
        ## convert tag value into a new format, recorded label bigram information to index  
        new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))
        if self.gpu:
            new_tags = new_tags.cuda()
        for idx in range(seq_len):
            if idx == 0:
                ## start -> first score
                new_tags[:,0] =  (tag_size - 2)*tag_size + tags[:,0]

            else:
                new_tags[:,idx] =  tags[:,idx-1]*tag_size + tags[:,idx]

        ## transition for label to STOP_TAG
        end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
        ## length for batch,  last word position = length - 1
        length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
        ## index the label id of last word
        end_ids = torch.gather(tags, 1, length_mask - 1)

        ## index the transition score for end_id to STOP_TAG
        end_energy = torch.gather(end_transition, 1, end_ids)

        ## convert tag as (seq_len, batch_size, 1)
        new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1)
        ### need convert tags id to search from 400 positions of scores
        tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)  # seq_len * bat_size
        ## mask transpose to (seq_len, batch_size)
        tg_energy = tg_energy.masked_select(mask.transpose(1,0))
        
        # ## calculate the score from START_TAG to first label
        # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
        # start_energy = torch.gather(start_transition, 1, tags[0,:])

        ## add all score together
        # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
        gold_score = tg_energy.sum() + end_energy.sum()
        return gold_score
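The new_tags trick above flattens the (tag_size, tag_size) score table so that a single torch.gather picks score[prev_tag, cur_tag] at each position; a minimal standalone sketch of that indexing (dummy sizes, not the original model):

import torch

batch_size, tag_size = 2, 5
scores_t = torch.randn(batch_size, tag_size, tag_size)    # scores at one time step
prev_tags = torch.tensor([1, 3])
cur_tags = torch.tensor([2, 0])
flat_idx = prev_tags * tag_size + cur_tags                 # index into the flattened table
picked = torch.gather(scores_t.view(batch_size, -1), 1, flat_idx.unsqueeze(1))
# picked[b, 0] == scores_t[b, prev_tags[b], cur_tags[b]]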
Example #7
    def word_pre_train_forward(self, sentence, position):
        """
        output of forward language model

        args:
            sentence (char_seq_len, batch_size): char-level representation of sentence
            position (word_seq_len, batch_size): position of blank space in char-level representation of sentence

        """

        embeds = self.char_embeds(sentence)
        d_embeds = self.dropout(embeds)
        lstm_out, hidden = self.forw_char_lstm(d_embeds)

        tmpsize = position.size()
        position = position.unsqueeze(2).expand(tmpsize[0], tmpsize[1], self.char_hidden_dim)
        select_lstm_out = torch.gather(lstm_out, 0, position)
        d_lstm_out = self.dropout(select_lstm_out).view(-1, self.char_hidden_dim)

        if self.if_highway:
            char_out = self.forw2word(d_lstm_out)
            d_char_out = self.dropout(char_out)
        else:
            d_char_out = d_lstm_out

        pre_score = self.word_pre_train_out(d_char_out)
        return pre_score, hidden
Example #8
def decode_with_crf(crf, word_reps, mask_v, l_map):
    """
    decode with viterbi algorithm and return score

    """

    seq_len = word_reps.size(0)
    bat_size = word_reps.size(1)
    decoded_crf = crf.decode(word_reps, mask_v)
    scores = crf.cal_score(word_reps).data
    mask_v = mask_v.data
    decoded_crf = decoded_crf.data
    decoded_crf_withpad = torch.cat((torch.cuda.LongTensor(1,bat_size).fill_(l_map['<start>']), decoded_crf), 0)
    decoded_crf_withpad = decoded_crf_withpad.transpose(0,1).cpu().numpy()
    label_size = len(l_map)

    bi_crf = []
    cur_len = decoded_crf_withpad.shape[1]-1
    for i_l in decoded_crf_withpad:
        bi_crf.append([i_l[ind] * label_size + i_l[ind + 1] for ind in range(0, cur_len)] + [
            i_l[cur_len] * label_size + l_map['<pad>']])
    bi_crf = torch.cuda.LongTensor(bi_crf).transpose(0,1).unsqueeze(2)

    tg_energy = torch.gather(scores.view(seq_len, bat_size, -1), 2, bi_crf).view(seq_len, bat_size)  # seq_len * bat_size
    tg_energy = tg_energy.transpose(0,1).masked_select(mask_v.transpose(0,1))
    tg_energy = tg_energy.cpu().numpy()
    masks = mask_v.sum(0)
    crf_result_scored_by_crf = []
    start = 0
    for i, mask in enumerate(masks):
        end = start + mask
        crf_result_scored_by_crf.append(tg_energy[start:end].sum())
        start = end
    crf_result_scored_by_crf = np.array(crf_result_scored_by_crf)
    return decoded_crf.cpu().transpose(0,1).numpy(), crf_result_scored_by_crf
Example #9
def masked_cross_entropy(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each data in a batch.

    Returns:
        loss: An average loss value masked by the length.
    """

    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = functional.log_softmax(logits_flat, dim=1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss
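The snippet relies on a sequence_mask helper that is not shown here; a minimal sketch of one possible implementation (an assumption, not necessarily the original project's version):

import torch

def sequence_mask(sequence_length, max_len=None):
    # sequence_length: LongTensor of shape (batch,); returns a (batch, max_len) boolean mask
    if max_len is None:
        max_len = int(sequence_length.max())
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len, device=sequence_length.device)
    seq_range = seq_range.unsqueeze(0).expand(batch_size, max_len)
    return seq_range < sequence_length.unsqueeze(1).expand(batch_size, max_len)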
Example #10
    def forward(self, batch):
        X_data, X_padding_mask, X_lens, X_batch_extend_vocab, X_extra_zeros, context, coverage = self.get_input_from_batch(batch)
        y_data, y_padding_mask, y_max_len, y_lens_var, target_data = self.get_output_from_batch(batch)

        encoder_outputs, encoder_hidden, max_encoder_output = self.encoder(X_data, X_lens)
        s_t_1 = self.reduce_state(encoder_hidden)
        if config.use_maxpool_init_ctx:
            context = max_encoder_output

        step_losses = []
        for di in range(min(y_max_len, self.args.max_decoder_steps)):
            y_t_1 = y_data[:, di]  # Teacher forcing
            final_dist, s_t_1, context, attn_dist, p_gen, coverage = self.decoder(y_t_1, s_t_1,
                                                                                        encoder_outputs, X_padding_mask, context,
                                                                                        X_extra_zeros, X_batch_extend_vocab,
                                                                                        coverage)
            target = target_data[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + self.args.eps)
            if self.args.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
            step_mask = y_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / y_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss
Example #11
    def forward(self, batch):
        """Forward method receives target-length ordered batches."""

        # Encode image and get initial variables
        img_ctx, c_t, h_t = self.f_init(batch)

        # Fetch embeddings -> (seq_len, batch_size, emb_dim)
        caption = batch[self.tl]

        # n_tokens token processed in this batch
        self.n_tokens = caption.numel()

        # Get embeddings
        embs = self.emb(caption)

        # Accumulators
        loss = 0.0
        self.alphas = []

        # -1: So that we skip the timestep where input is <eos>
        for t in range(caption.shape[0] - 1):
            # NOTE: This is where scheduled sampling will happen
            # Either fetch from self.emb or from log_p
            # Current textual input to decoder: y_t = embs[t]
            log_p, c_t, h_t, _ = self.f_next(img_ctx, embs[t], c_t, h_t)

            # t + 1: We're predicting next token
            # Cumulate losses
            loss += torch.gather(
                log_p, dim=1, index=caption[t + 1].unsqueeze(1)).sum()

        # Return normalized loss
        return loss / self.n_tokens
Example #12
    def forward(self, ctx_dict, y):
        """Computes the softmax outputs given source annotations `ctxs` and
        ground-truth target token indices `y`.

        Arguments:
            ctxs(Variable): A variable of `S*B*ctx_dim` representing the source
                annotations in an order compatible with ground-truth targets.
            y(Variable): A variable of `T*B` containing ground-truth target
                token indices for the given batch.
        """

        loss = 0.0
        # Convert token indices to embeddings -> T*B*E
        y_emb = self.emb(y)

        # Get initial hidden state
        h = self.f_init(*ctx_dict['txt'])

        # -1: So that we skip the timestep where input is <eos>
        for t in range(y_emb.shape[0] - 1):
            log_p, h = self.f_next(ctx_dict, y_emb[t], h)
            loss += torch.gather(
                log_p, dim=1, index=y[t + 1].unsqueeze(1)).sum()

        return loss
Example #13
    def eval_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1,attn_dist, p_gen, coverage = self.model.decoder(y_t_1, s_t_1,
                                                                encoder_outputs, enc_padding_mask, c_t_1,
                                                                extra_zeros, enc_batch_extend_vocab, coverage)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.data[0]
Example #14
    def _compute_loss(self, batch, output, target):
        scores = self.generator(self._bottle(output))

        gtruth = target.view(-1)
        if self.confidence < 1:
            tdata = gtruth.data
            mask = torch.nonzero(tdata.eq(self.padding_idx)).squeeze()
            log_likelihood = torch.gather(scores.data, 1, tdata.unsqueeze(1))
            tmp_ = self.one_hot.repeat(gtruth.size(0), 1)
            tmp_.scatter_(1, tdata.unsqueeze(1), self.confidence)
            if mask.dim() > 0:
                log_likelihood.index_fill_(0, mask, 0)
                tmp_.index_fill_(0, mask, 0)
            gtruth = Variable(tmp_, requires_grad=False)
        loss = self.criterion(scores, gtruth)
        if self.confidence < 1:
            # Default: report smoothed ppl.
            # loss_data = -log_likelihood.sum(0)
            loss_data = loss.data.clone()
        else:
            loss_data = loss.data.clone()

        stats = self._stats(loss_data, scores.data, target.view(-1).data)

        return loss, stats
Example #15
 def forward(self, log_prob, y_true, mask):
     mask = mask.float()
     log_P = torch.gather(log_prob.view(-1, log_prob.size(2)), 1, y_true.contiguous().view(-1, 1))  # batch*time x 1
     log_P = log_P.view(y_true.size(0), y_true.size(1))  # batch x time
     log_P = log_P * mask  # batch x time
     sum_log_P = torch.sum(log_P, dim=1) / torch.sum(mask, dim=1)  # batch
     return -sum_log_P
Example #16
    def get_max_q_values_with_target(
        self, q_values, q_values_target, possible_actions_mask
    ):
        """
        Used in Q-learning update.
        :param states: Numpy array with shape (batch_size, state_dim). Each row
            contains a representation of a state.
        :param possible_actions_mask: Numpy array with shape (batch_size, action_dim).
            possible_actions[i][j] = 1 iff the agent can take action j from
            state i.
        :param double_q_learning: bool to use double q-learning
        """

        # The parametric DQN can create flattened q values so we reshape here.
        q_values = q_values.reshape(possible_actions_mask.shape)
        q_values_target = q_values_target.reshape(possible_actions_mask.shape)

        if self.double_q_learning:
            # Set q-values of impossible actions to a very large negative number.
            inverse_pna = 1 - possible_actions_mask
            impossible_action_penalty = self.ACTION_NOT_POSSIBLE_VAL * inverse_pna
            q_values = q_values + impossible_action_penalty
            # Select max_q action after scoring with online network
            max_q_values, max_indicies = torch.max(q_values, dim=1, keepdim=True)
            # Use q_values from target network for max_q action from online q_network
            # to decouple selection & scoring, preventing overestimation of q-values
            q_values = torch.gather(q_values_target, 1, max_indicies)
            return q_values, max_indicies
        else:
            return self.get_max_q_values(q_values, possible_actions_mask)
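The gather above is the core of the double Q-learning target: actions are chosen by the online network but scored by the target network. A minimal standalone sketch with dummy tensors (names are illustrative):

import torch

q_online = torch.randn(5, 3)                              # online net, (batch, actions)
q_target = torch.randn(5, 3)                              # target net, (batch, actions)
best_actions = q_online.argmax(dim=1, keepdim=True)       # selected by the online net
target_values = torch.gather(q_target, 1, best_actions)   # scored by the target net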
Example #17
    def lm_lstm(self, forw_sentence, forw_position, back_sentence, back_position, word_seq):
        '''
        return word representations with character-language-model

        args:
            forw_sentence (char_seq_len, batch_size) : char-level representation of sentence
            forw_position (word_seq_len, batch_size) : position of blank space in char-level representation of sentence
            back_sentence (char_seq_len, batch_size) : char-level representation of sentence (inverse order)
            back_position (word_seq_len, batch_size) : position of blank space in inversed char-level representation of sentence
            word_seq (word_seq_len, batch_size) : word-level representation of sentence

        '''

        self.set_batch_seq_size(forw_position)

        forw_emb = self.char_embeds(forw_sentence)
        back_emb = self.char_embeds(back_sentence)

        d_f_emb = self.dropout(forw_emb)
        d_b_emb = self.dropout(back_emb)

        forw_lstm_out, _ = self.forw_char_lstm(d_f_emb)

        back_lstm_out, _ = self.back_char_lstm(d_b_emb)

        forw_position = forw_position.unsqueeze(2).expand(self.word_seq_length, self.batch_size, self.char_hidden_dim)
        select_forw_lstm_out = torch.gather(forw_lstm_out, 0, forw_position)

        back_position = back_position.unsqueeze(2).expand(self.word_seq_length, self.batch_size, self.char_hidden_dim)
        select_back_lstm_out = torch.gather(back_lstm_out, 0, back_position)

        fb_lstm_out = self.dropout(torch.cat((select_forw_lstm_out, select_back_lstm_out), dim=2))
        if self.if_highway:
            char_out = self.fb2char(fb_lstm_out)
            d_char_out = self.dropout(char_out)
        else:
            d_char_out = fb_lstm_out

        word_emb = self.word_embeds(word_seq)
        d_word_emb = self.dropout(word_emb)

        word_input = torch.cat((d_word_emb, d_char_out), dim=2)

        lstm_out, _ = self.word_lstm_lm(word_input)
        d_lstm_out = self.dropout(lstm_out)

        return d_lstm_out
Example #18
def NN(epoch, net, lemniscate, trainloader, testloader, recompute_memory=0):
    net.eval()
    net_time = AverageMeter()
    cls_time = AverageMeter()
    losses = AverageMeter()
    correct = 0.
    total = 0
    testsize = testloader.dataset.__len__()

    trainFeatures = lemniscate.memory.t()
    if hasattr(trainloader.dataset, 'imgs'):
        trainLabels = torch.LongTensor([y for (p, y) in trainloader.dataset.imgs]).cuda()
    else:
        trainLabels = torch.LongTensor(trainloader.dataset.train_labels).cuda()

    if recompute_memory:
        transform_bak = trainloader.dataset.transform
        trainloader.dataset.transform = testloader.dataset.transform
        temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=100, shuffle=False, num_workers=1)
        for batch_idx, (inputs, targets, indexes) in enumerate(temploader):
            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = Variable(inputs, volatile=True), Variable(targets)
            batchSize = inputs.size(0)
            features = net(inputs)
            trainFeatures[:, batch_idx*batchSize:batch_idx*batchSize+batchSize] = features.data.t()
        trainLabels = torch.LongTensor(temploader.dataset.train_labels).cuda()
        trainloader.dataset.transform = transform_bak
    
    end = time.time()
    for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        batchSize = inputs.size(0)
        features = net(inputs)
        net_time.update(time.time() - end)
        end = time.time()

        dist = torch.mm(features.data, trainFeatures)

        yd, yi = dist.topk(1, dim=1, largest=True, sorted=True)
        candidates = trainLabels.view(1,-1).expand(batchSize, -1)
        retrieval = torch.gather(candidates, 1, yi)

        retrieval = retrieval.narrow(1, 0, 1).clone().view(-1)
        yd = yd.narrow(1, 0, 1)

        total += targets.size(0)
        correct += retrieval.eq(targets.data).cpu().sum()
        
        cls_time.update(time.time() - end)
        end = time.time()

        print('Test [{}/{}]\t'
              'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t'
              'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t'
              'Top1: {:.2f}'.format(
              total, testsize, correct*100./total, net_time=net_time, cls_time=cls_time))

    return correct/total
Example #19
    def forward(self, input, target, mask):

        logprob_select = torch.gather(input, 1, target)

        out = torch.masked_select(logprob_select, mask)

        loss = -torch.sum(out) / mask.float().sum()
        return loss
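The same gather-then-mask pattern, written out standalone with dummy tensors (illustrative only, not part of the original module):

import torch

log_prob = torch.log_softmax(torch.randn(4, 6), dim=1)    # (batch, vocab) log-probabilities
target = torch.tensor([[2], [0], [5], [1]])                # gold indices, (batch, 1)
mask = torch.tensor([[True], [True], [False], [True]])     # False marks padding
picked = torch.gather(log_prob, 1, target)                 # log-prob of each gold token
loss = -torch.masked_select(picked, mask).sum() / mask.float().sum()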
Example #20
 def forward(self, x_de, x_en, update_baseline=True):
     bs = x_de.size(0)
     # x_de is bs,n_de. x_en is bs,n_en
     emb_de = self.embedding_de(x_de) # bs,n_de,word_dim
     emb_en = self.embedding_en(x_en) # bs,n_en,word_dim
     h0_enc = torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda()
     c0_enc = torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda()
     h0_dec = torch.zeros(self.n_layers, bs, self.hidden_dim).cuda()
     c0_dec = torch.zeros(self.n_layers, bs, self.hidden_dim).cuda()
     # hidden vars have dimension n_layers*n_directions,bs,hiddensz
     enc_h, _ = self.encoder(emb_de, (Variable(h0_enc), Variable(c0_enc)))
     # enc_h is bs,n_de,hiddensz*n_directions. ordering is different from last week because batch_first=True
     dec_h, _ = self.decoder(emb_en, (Variable(h0_dec), Variable(c0_dec)))
     # dec_h is bs,n_en,hidden_size*n_directions
     # we've gotten our encoder/decoder hidden states so we are ready to do attention
     # first let's get all our scores, which we can do easily since we are using dot-prod attention
     if self.directions == 2:
         scores = torch.bmm(self.dim_reduce(enc_h), dec_h.transpose(1,2))
         # TODO: any easier ways to reduce dimension?
     else:
         scores = torch.bmm(enc_h, dec_h.transpose(1,2))
     # (bs,n_de,hiddensz*n_directions) * (bs,hiddensz*n_directions,n_en) = (bs,n_de,n_en)
     reinforce_loss = 0 # we only use this variable for hard attention
     loss = 0
     avg_reward = 0
     # we just iterate to dec_h.size(1)-1, since there's </s> at the end of each sentence
     for t in range(dec_h.size(1)-1): # iterate over english words, with teacher forcing
         attn_dist = F.softmax(scores[:, :, t],dim=1) # bs,n_de. these are the alphas (attention scores for each german word)
         if self.attn_type == "hard":
             cat = torch.distributions.Categorical(attn_dist) 
             attn_samples = cat.sample() # bs. each element is a sample from categorical distribution
             one_hot = Variable(torch.zeros_like(attn_dist.data).scatter_(-1, attn_samples.data.unsqueeze(1), 1).cuda()) # bs,n_de
             # made a bunch of one-hot vectors
             context = torch.bmm(one_hot.unsqueeze(1), enc_h).squeeze(1)
             # now we use the one-hot vectors to select correct hidden vectors from enc_h
             # (bs,1,n_de) * (bs,n_de,hiddensz*n_directions) = (bs,1,hiddensz*n_directions). squeeze to bs,hiddensz*n_directions
         else:
             context = torch.bmm(attn_dist.unsqueeze(1), enc_h).squeeze(1) # same dimensions
             # (bs,1,n_de) * (bs,n_de,hiddensz*n_directions) = (bs,1,hiddensz*n_directions)
         # context is bs,hidden_size*n_directions
         # the rnn output and the context together make the decoder "hidden state", which is bs,2*hidden_size*n_directions
         pred = self.vocab_layer(torch.cat([dec_h[:,t,:], context], 1)) # bs,len(EN.vocab)
         y = x_en[:, t+1] # bs. these are our labels
         no_pad = (y != pad_token) # exclude english padding tokens
         reward = torch.gather(pred, 1, y.unsqueeze(1)) # bs,1
         # reward[i,1] = pred[i,y[i]]. this gets log prob of correct word for each batch. similar to -crossentropy
         reward = reward.squeeze(1)[no_pad] # less than bs
         if self.attn_type == "hard":
             reinforce_loss -= (cat.log_prob(attn_samples[no_pad]) * (reward-self.baseline).detach()).sum() 
             # reinforce rule (just read the formula), with special baseline
         loss -= reward.sum() # minimizing loss is maximizing reward
     no_pad_total = (x_en[:,1:] != pad_token).data.sum() # TODO: i think this is right, right?
     loss /= no_pad_total
     reinforce_loss /= no_pad_total
     avg_reward = -loss.data[0]
     if update_baseline: # update baseline as a moving average
         self.baseline = Variable(0.95*self.baseline.data + 0.05*avg_reward)
     return loss, reinforce_loss,avg_reward
Example #21
def getMAE():
    userIdx = testData.user.values
    itemIdx = testData.item.values
    rates = testData.rate.values
    R = torch.mm(U,P.t())
    ratesPred = torch.gather(R.view(1,-1)[0],0,Variable(torch.LongTensor (userIdx * len(set(itemIdx)) + itemIdx)))
    diff_op = ratesPred - Variable(torch.FloatTensor(rates))
    MAE = diff_op.abs().mean()
    return MAE.data.numpy()[0]
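The example above indexes a flattened rating matrix with user * num_items + item; a tensor-only sketch of that pattern (dummy sizes, illustrative names):

import torch

U = torch.randn(3, 4)                     # user factors
P = torch.randn(5, 4)                     # item factors
R = torch.mm(U, P.t())                    # full rating matrix, (num_users, num_items)
users = torch.tensor([0, 2, 1])
items = torch.tensor([4, 0, 3])
preds = torch.gather(R.view(-1), 0, users * R.size(1) + items)
# preds[k] == R[users[k], items[k]]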
Example #22
 def forward(self, article_sents, sent_nums, target):
     enc_out = self._encode(article_sents, sent_nums)
     bs, nt = target.size()
     d = enc_out.size(2)
     ptr_in = torch.gather(
         enc_out, dim=1, index=target.unsqueeze(2).expand(bs, nt, d)
     )
     output = self._extractor(enc_out, sent_nums, ptr_in)
     return output
Example #23
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).

    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.view(-1, 1).long()

    # Contribution to the negative log likelihood only comes from the exact indices
    # of the targets, as the target distributions are one-hot. Here we use torch.gather
    # to extract the indices of the num_classes dimension which contribute to the loss.
    # shape : (batch * sequence_length, 1)
    negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # shape : (batch_size,)
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)

    if batch_average:
        num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
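A minimal call sketch for the function above (dummy shapes; all-ones weights simply disable the masking):

import torch

logits = torch.randn(2, 5, 10)             # (batch, sequence_length, num_classes)
targets = torch.randint(0, 10, (2, 5))     # (batch, sequence_length)
weights = torch.ones(2, 5)                 # 1 = keep timestep, 0 = padding
loss = sequence_cross_entropy_with_logits(logits, targets, weights)
print(loss)                                # scalar, since batch_average=True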
Example #24
    def forward(self, input, target):
        logprob_select = torch.gather(input, 1, target)

        mask = target.data.gt(0)  # generate the mask
        if isinstance(input, Variable):
            mask = Variable(mask, volatile=input.volatile)
        
        out = torch.masked_select(logprob_select, mask)

        loss = -torch.sum(out)  # total loss summed over non-masked positions
        return loss
Example #25
 def emiss(self, T, idx, ignore_index=None):
     assert len(idx.shape) == 1
     bs = idx.shape[0]
     idx = idx.view(-1, 1).expand(bs, self.ns).unsqueeze(-1)
     emiss = torch.gather(self.emission[T], -1, idx).view(bs, 1, self.ns)
     if ignore_index is None:
         return emiss
     else:
         idx = idx.view(bs, 1, self.ns)
         mask = (idx != ignore_index).float()
         return emiss * mask
Example #26
def log_sum_exp(vec, m_size):
    """
    calculate log of exp sum
    args:
        vec (batch_size, vanishing_dim, hidden_dim) : input tensor
        m_size : hidden_dim
    return:
        batch_size, hidden_dim
    """
    _, idx = torch.max(vec, 1)  # B * 1 * M
    max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # B * M
    return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)  # B * M
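A quick numerical check of the snippet above against torch.logsumexp (a sketch; shapes follow the docstring):

import torch

vec = torch.randn(4, 7, 7)                 # (batch_size, vanishing_dim, hidden_dim)
out = log_sum_exp(vec, vec.size(2))
ref = torch.logsumexp(vec, dim=1)          # same reduction over the vanishing dim
print(torch.allclose(out, ref, atol=1e-6))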
Example #27
    def _score_candidates(self, cands, xe, encoder_output, hidden):
        # score each candidate separately

        # cands are exs_with_cands x cands_per_ex x words_per_cand
        # cview is total_cands x words_per_cand
        cview = cands.view(-1, cands.size(2))
        cands_xes = xe.expand(xe.size(0), cview.size(0), xe.size(2))
        sz = hidden.size()
        cands_hn = (
            hidden.view(sz[0], sz[1], 1, sz[2])
            .expand(sz[0], sz[1], cands.size(1), sz[2])
            .contiguous()
            .view(sz[0], -1, sz[2])
        )

        sz = encoder_output.size()
        cands_encoder_output = (
            encoder_output.contiguous()
            .view(sz[0], 1, sz[1], sz[2])
            .expand(sz[0], cands.size(1), sz[1], sz[2])
            .contiguous()
            .view(-1, sz[1], sz[2])
        )

        cand_scores = Variable(
                    self.cand_scores.resize_(cview.size(0)).fill_(0))
        cand_lengths = Variable(
                    self.cand_lengths.resize_(cview.size(0)).fill_(0))

        for i in range(cview.size(1)):
            output = self._apply_attention(cands_xes, cands_encoder_output, cands_hn) \
                    if self.use_attention else cands_xes

            output, cands_hn = self.decoder(output, cands_hn)
            preds, scores = self.hidden_to_idx(output, dropout=False)
            cs = cview.select(1, i)
            non_nulls = cs.ne(self.NULL_IDX)
            cand_lengths += non_nulls.long()
            score_per_cand = torch.gather(scores, 1, cs.unsqueeze(1))
            cand_scores += score_per_cand.squeeze() * non_nulls.float()
            cands_xes = self.lt2dec(self.lt(cs).unsqueeze(0))

        # set empty scores to -1, so when divided by 0 they become -inf
        cand_scores -= cand_lengths.eq(0).float()
        # average the scores per token
        cand_scores /= cand_lengths.float()

        cand_scores = cand_scores.view(cands.size(0), cands.size(1))
        srtd_scores, text_cand_inds = cand_scores.sort(1, True)
        text_cand_inds = text_cand_inds.data

        return text_cand_inds
Example #28
def sample_relax_given_class_k(logits, samp, k):

    cat = Categorical(logits=logits)
    b = samp #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)

    zs = []
    z_tildes = []
    for i in range(k):

        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = logits + gumbels

        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))
        
        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

        z = z_tilde

        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))
        
        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    zs= torch.stack(zs)
    z_tildes= torch.stack(z_tildes)
    
    z = torch.mean(zs, dim=0)
    z_tilde = torch.mean(z_tildes, dim=0)

    return z, z_tilde, logprob
Example #29
 def combine_matrices(self, prob_matrix, prob_matrix_scale,
                      perm, last=False):
     # prob_matrix shape is bs x length x length + 1. Add extra column.
     length = prob_matrix_scale.size()[2]
     first = Variable(torch.zeros([self.batch_size, 1, length])).type(dtype)
     first[:, 0, 0] = 1.0
     prob_matrix_scale = torch.cat((first, prob_matrix_scale), 1)
     # argmax
     new_perm = self.discretize(prob_matrix_scale)
     perm = torch.gather(perm, 1, new_perm)
     # combine
     prob_matrix = torch.bmm(prob_matrix_scale, prob_matrix)
     return prob_matrix, prob_matrix_scale, perm
Example #30
def masked_cross_entropy(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each data in a batch.
    Returns:
        loss: An average loss value masked by the length.
        
    The code is the same as:

    weight = torch.ones(tgt_vocab_size)
    weight[padding_idx] = 0
    criterion = nn.CrossEntropyLoss(weight.cuda(), size_average)
    loss = criterion(logits_flat, target_flat)
    """
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))       ## (3, 16, 50000) => (3*16, 50000)
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = F.log_softmax(logits_flat, dim=1)   ## (3*16, 50000) => (3*16, 50000)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)                     ## (3, 16) => (3*16, 1)
    # losses_flat: (batch * max_len, 1)                  ## the -log_prob of target token_id in each batch and time step
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)          ## (3*16, 1)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())                                      ## (3*16, 1) => (3, 16)
    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))           ## (3, 16)
    # Note: mask need to bed casted to float!
    losses = losses * mask.float()
    loss = losses.sum() / mask.float().sum()             ## drop the padded positions so the denominator is the true num_words required by the loss formula
    
    # # (batch_size * max_tgt_len,)                        ## the predicted token_id of max log_prob in each batch and time step
    # pred_flat = log_probs_flat.max(1)[1]                 ## (3*16, 1)
    # # (batch_size * max_tgt_len,) => (batch_size, max_tgt_len) => (max_tgt_len, batch_size)
    # pred_seqs = pred_flat.view(*target.size()).transpose(0,1).contiguous()         ## (3*16, 1) => (3, 16) => (16, 3)
    # # (batch_size, max_len) => (batch_size * max_tgt_len,)
    # mask_flat = mask.view(-1)                            ## (3, 16) => (3*16)
    
    # # `.float()` IS VERY IMPORTANT !!!
    # # https://discuss.pytorch.org/t/batch-size-and-validation-accuracy/4066/3
    # num_corrects = int(pred_flat.eq(target_flat.squeeze(1)).masked_select(mask_flat).float().data.sum())  ## number of correctly predicted tokens
    # num_words = length.data.sum()     ## total number of tokens in this batch

    # return loss, pred_seqs, num_corrects, num_words
    return loss
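A rough standalone check of the equivalence claimed in the docstring, with padding_idx = 0 and dummy shapes assumed (a sketch, not the original project's test code):

import torch
import torch.nn as nn
import torch.nn.functional as F

logits_flat = torch.randn(6, 5)
target_flat = torch.tensor([1, 2, 0, 3, 0, 4])   # 0 marks padding
weight = torch.ones(5)
weight[0] = 0
loss_a = nn.CrossEntropyLoss(weight=weight)(logits_flat, target_flat)

log_probs = F.log_softmax(logits_flat, dim=1)
nll = -torch.gather(log_probs, 1, target_flat.unsqueeze(1)).squeeze(1)
mask = (target_flat != 0).float()
loss_b = (nll * mask).sum() / mask.sum()
print(torch.allclose(loss_a, loss_b))            # expected: True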
Example #31
    def decode_z_order(self, batch_size, u_input, u_hiddens, u_input_1hot, u_last_hidden, z_input,
                          turn_states, sample_type, decoder_type, qz_samples=None, qz_hiddens=None,
                          m_input=None, m_hiddens=None, m_input_1hot=None, mask_otlg=False):
        return_gmb = True if 'gumbel' in sample_type else False
        pv_z_pr = turn_states.get('pv_%s_pr'%decoder_type, None)
        pv_z_h = turn_states.get('pv_%s_h'%decoder_type, None)
        pv_z_id = turn_states.get('pv_%s_id'%decoder_type, None)
        z_prob, z_samples, gmb_samples = [], [], []
        log_pz = 0
        for si, sn in enumerate(self.reader.otlg.informable_slots):
            last_hidden = u_last_hidden[:-1]
            # last_hidden = (u_last_hidden[-1] + u_last_hidden[-2]).unsqueeze(0)
            z_eos_idx = self.vocab.encode(self.z_eos_map[sn])
            emb_zt = self.get_first_z_input(sn, batch_size, self.multi_domain)
            zero_vec = cuda_(torch.zeros(batch_size, 1, self.hidden_size))
            selc_read_u = selc_read_m = selc_read_pv_z = zero_vec
            if pv_z_pr is not None:
                b, e = si * self.z_length, (si+1) * self.z_length
                pv_pr, pv_h, pv_idx = pv_z_pr[:, b:e], pv_z_h[:, b:e], pv_z_id[:, b:e]
            else:
                pv_pr, pv_h, pv_idx = None, None, None
            prev_zt = None
            for t in range(self.z_length):
                if decoder_type == 'pz':
                    prob, last_hidden, gru_out, selc_read_u, selc_read_pv_z = \
                        self.pz_decoder(u_input, u_input_1hot, u_hiddens,
                                            pv_z_prob=pv_pr, pv_z_hidden=pv_h, pv_z_idx=pv_idx,
                                            emb_zt=emb_zt,  last_hidden=last_hidden,
                                            selc_read_u=selc_read_u, selc_read_pv_z=selc_read_pv_z)
                else:
                    prob, last_hidden, gru_out, selc_read_u, selc_read_m, selc_read_pv_z, gmb_samp = \
                        self.qz_decoder(u_input, u_input_1hot, u_hiddens, m_input, m_input_1hot, m_hiddens,
                                                pv_z_prob=pv_pr, pv_z_hidden=pv_h, pv_z_idx=pv_idx,
                                                emb_zt=emb_zt, last_hidden=last_hidden, selc_read_u=selc_read_u,
                                                selc_read_m=selc_read_m, selc_read_pv_z=selc_read_pv_z,
                                                temp=self.gumbel_temp, return_gmb=return_gmb)
                if mask_otlg:
                    prob = self.mask_probs(prob, tokens_allow=self.reader.slot_value_mask[sn])

                if sample_type == 'supervised':
                    zt = z_input[sn][:, t]
                elif sample_type == 'top1':
                    zt = torch.topk(prob, 1)[1]
                elif sample_type == 'topk':
                    topk_probs, topk_words = torch.topk(prob.squeeze(1), cfg.topk_num)
                    widx = torch.multinomial(topk_probs, 1, replacement=True)
                    zt = torch.gather(topk_words, 1, widx)      #[B]
                elif sample_type == 'posterior':
                    zt = qz_samples[:, si * self.z_length + t]
                elif 'gumbel' in sample_type:
                    zt = torch.argmax(gmb_samp, dim=1)   #[B]
                    emb_zt = torch.matmul(gmb_samp, self.embedding.weight).unsqueeze(1) # [B, 1, H]
                    zt, prev_zt, gmb_samp = self.mask_samples(zt, prev_zt, batch_size, z_eos_idx, gmb_samp, True)
                    gmb_samples.append(gmb_samp)

                if 'gumbel' not in sample_type:
                    emb_zt = self.embedding(zt.view(-1, 1))
                    prob_zt = torch.gather(prob, 1, zt.view(-1, 1)).squeeze(1) #[B, 1]
                    log_prob_zt = torch.log(prob_zt)
                    zt, prev_zt, log_prob_zt = self.mask_samples(zt, prev_zt, batch_size, z_eos_idx, log_prob_zt)
                    log_pz += log_prob_zt
                z_samples.append(zt.view(-1))
                z_prob.append(prob)

        z_prob = torch.stack(z_prob, dim=1)  # [B*K,Tz,V]
        z_samples= torch.stack(z_samples, dim=1)  # [B*K,Tz]
        if sample_type == 'posterior':
            z_samples, z_hiddens = qz_samples, qz_hiddens
        elif 'gumbel' not in sample_type:
            z_hiddens, z_last_hidden = self.z_encoder(z_samples, input_type='index')
        else:
            z_gumbel = torch.stack(gmb_samples, dim=1)   # [B,Tz, V]
            z_gumbel = torch.matmul(z_gumbel, self.embedding.weight)     # [B,Tz, E]
            z_hiddens, z_last_hidden = self.z_encoder(z_gumbel, input_type='embedding')

        retain = self.prev_z_continuous
        turn_states['pv_%s_h'%decoder_type] = z_hiddens if retain else z_hiddens.detach()
        turn_states['pv_%s_pr'%decoder_type] = z_prob if retain else z_prob.detach()
        turn_states['pv_%s_id'%decoder_type] = z_samples if retain else z_samples.detach()

        return z_prob, z_samples, z_hiddens, turn_states, log_pz
Example #32
    def _get_parallel_step_context(self, embeddings, state, from_depot=False):
        """
        Returns the context per step, optionally for multiple steps at once (for efficient evaluation of the model)
        
        :param embeddings: (batch_size, graph_size, embed_dim)
        :param prev_a: (batch_size, num_steps)
        :param first_a: Only used when num_steps = 1, action of first step or None if first step
        :return: (batch_size, num_steps, context_dim)
        """

        current_node = state.get_current_node()
        batch_size, num_steps = current_node.size()

        if self.is_vrp:
            # Embedding of previous node + remaining capacity
            if from_depot:
                # 1st dimension is node idx, but we do not squeeze it since we want to insert step dimension
                # i.e. we actually want embeddings[:, 0, :][:, None, :] which is equivalent
                return torch.cat(
                    (
                        embeddings[:, 0:1, :].expand(batch_size, num_steps,
                                                     embeddings.size(-1)),
                        # used capacity is 0 after visiting depot
                        self.problem.VEHICLE_CAPACITY -
                        torch.zeros_like(state.used_capacity[:, :, None])),
                    -1)
            else:
                return torch.cat((torch.gather(
                    embeddings, 1,
                    current_node.contiguous().view(
                        batch_size, num_steps, 1).expand(
                            batch_size, num_steps, embeddings.size(-1))).view(
                                batch_size, num_steps, embeddings.size(-1)),
                                  self.problem.VEHICLE_CAPACITY -
                                  state.used_capacity[:, :, None]), -1)
        elif self.is_orienteering or self.is_pctsp:
            return torch.cat(
                (torch.gather(
                    embeddings, 1,
                    current_node.contiguous().view(
                        batch_size, num_steps, 1).expand(
                            batch_size, num_steps, embeddings.size(-1))).view(
                                batch_size, num_steps, embeddings.size(-1)),
                 (state.get_remaining_length()[:, :,
                                               None] if self.is_orienteering
                  else state.get_remaining_prize_to_collect()[:, :, None])),
                -1)
        else:  # TSP

            if num_steps == 1:  # We need to special case if we have only 1 step, may be the first or not
                if state.i.item() == 0:
                    # First and only step, ignore prev_a (this is a placeholder)
                    return self.W_placeholder[None, None, :].expand(
                        batch_size, 1, self.W_placeholder.size(-1))
                else:
                    return embeddings.gather(
                        1,
                        torch.cat((state.first_a, current_node),
                                  1)[:, :,
                                     None].expand(batch_size, 2,
                                                  embeddings.size(-1))).view(
                                                      batch_size, 1, -1)
            # More than one step, assume always starting with first
            embeddings_per_step = embeddings.gather(
                1, current_node[:, 1:, None].expand(batch_size, num_steps - 1,
                                                    embeddings.size(-1)))
            return torch.cat(
                (
                    # First step placeholder, cat in dim 1 (time steps)
                    self.W_placeholder[None, None, :].expand(
                        batch_size, 1, self.W_placeholder.size(-1)),
                    # Second step, concatenate embedding of first with embedding of current/previous (in dim 2, context dim)
                    torch.cat((embeddings_per_step[:, 0:1, :].expand(
                        batch_size, num_steps - 1,
                        embeddings.size(-1)), embeddings_per_step), 2)),
                1)
Example #33
    def forward(self,
                x,
                r_ij,
                neighbors,
                pairwise_mask,
                f_ij=None,
                softmax=None):
        """Compute convolution block.

        Args:
            x (torch.Tensor): input representation/embedding of atomic environments
                with (N_b, N_a, n_in) shape.
            r_ij (torch.Tensor): interatomic distances of (N_b, N_a, N_nbh) shape.
            neighbors (torch.Tensor): indices of neighbors of (N_b, N_a, N_nbh) shape.
            pairwise_mask (torch.Tensor): mask to filter out non-existing neighbors
                introduced via padding.
            f_ij (torch.Tensor, optional): expanded interatomic distances in a basis.
                If None, r_ij.unsqueeze(-1) is used.

        Returns:
            torch.Tensor: block output with (N_b, N_a, n_out) shape.

        """
        if f_ij is None:
            f_ij = r_ij.unsqueeze(
                -1
            )  #shape [batch, num_atoms, num_neighbors (num_atoms-1),gauusian_exp]

        #-------------NEW----------------#
        if self.n_heads_weights > 0:
            A = self.Attention(x)  #attention in weight generation
            #concatenate multi-headed attention to distances
            f_ij = torch.cat((f_ij, A), dim=3)
            #f_ij = self.dropout(f_ij)
        #--------------------------------#

        # pass expanded interactomic distances through filter block
        W = self.filter_network(f_ij)
        #print(W.shape, 'Wsize')
        # apply cutoff
        if self.cutoff_network is not None:
            C = self.cutoff_network(r_ij)
            W = W * C.unsqueeze(-1)

        # pass initial embeddings through Dense layer (to correct size for number of filters)
        y = self.in2f(x)
        # reshape y for element-wise multiplication by W
        nbh_size = neighbors.size()
        nbh = neighbors.view(-1, nbh_size[1] * nbh_size[2], 1)
        nbh = nbh.expand(-1, -1, y.size(2))
        y = torch.gather(y, 1, nbh)
        y = y.view(nbh_size[0], nbh_size[1], nbh_size[2], -1)
        #print(y.shape, 'yshape')
        # element-wise multiplication, aggregating and Dense layer
        #-----------NEW:attention in convolution--------------#
        if self.n_heads_conv > 0:
            W = self.AttentionConv(
                x, Weights=W)  #single head to match weight size
        #added softmax
        if softmax is not None:
            W = self.softmax(W)
        #------------------------------------------------------#
        y = y * W
        y = self.agg(y, pairwise_mask)
        y = self.f2out(y)
        return y
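The gather in this block expands a (N_b, N_a, N_nbh) neighbor-index tensor so each atom can pull its neighbors' feature vectors; a minimal standalone sketch of that pattern (dummy shapes, not tied to the original module):

import torch

N_b, N_a, N_nbh, n_feat = 2, 4, 3, 8
y = torch.randn(N_b, N_a, n_feat)                        # per-atom features
neighbors = torch.randint(0, N_a, (N_b, N_a, N_nbh))     # neighbor indices
nbh = neighbors.view(N_b, N_a * N_nbh, 1).expand(N_b, N_a * N_nbh, n_feat)
gathered = torch.gather(y, 1, nbh).view(N_b, N_a, N_nbh, n_feat)
# gathered[b, i, k] == y[b, neighbors[b, i, k]]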
Example #34
    def _run_one_fw(self, pixel_model, pixel_inp, cat_var, target, base_eps, avoid_target=True):
        batch_size, channels, height, width = pixel_inp.size()
        pixel_inp_jpeg = self._jpeg_cat(pixel_inp, cat_var, base_eps, batch_size, height, width)
        s = pixel_model(pixel_inp_jpeg)

        for it in range(self.nb_its):
            loss = self.criterion(s, target)
            loss.backward()

            if avoid_target:
                grad = cat_var.grad.data
            else:
                grad = -cat_var.grad.data

            def where_float(cond, if_true, if_false):
                return cond.float() * if_true + (1-cond.float()) * if_false

            def where_long(cond, if_true, if_false):
                return cond.long() * if_true + (1-cond.long()) * if_false

            abs_grad = torch.abs(grad).view(batch_size, -1)
            num_pixels = abs_grad.size()[1]
            sign_grad = torch.sign(grad)

            bound = where_float(sign_grad > 0, self.l1_max - cat_var, cat_var + self.l1_max).view(batch_size, -1)
                
            k_min = torch.zeros((batch_size,1), dtype=torch.long, requires_grad=False, device='cuda')
            k_max = torch.ones((batch_size,1), dtype=torch.long, requires_grad=False, device='cuda') * num_pixels
                
            # cum_bnd[k] is meant to track the L1 norm we end up with if we take 
            # the k indices with the largest gradient magnitude and push them to their boundary values (0 or 255)
            values, indices = torch.sort(abs_grad, descending=True)
            bnd = torch.gather(bound, 1, indices)
            # subtract bnd because we don't want the cumsum to include the final element
            cum_bnd = torch.cumsum(bnd, 1) - bnd
                
            # this is hard-coded as floor(log_2(256 * 256 * 3))
            for _ in range(17):
                k_mid = (k_min + k_max) // 2
                l1norms = torch.gather(cum_bnd, 1, k_mid)
                k_min = where_long(l1norms > base_eps, k_min, k_mid)
                k_max = where_long(l1norms > base_eps, k_mid, k_max)
                
            # next want to set the gradient of indices[0:k_min] to their corresponding bound
            magnitudes = torch.zeros((batch_size, num_pixels), requires_grad=False, device='cuda')
            for bi in range(batch_size):
                magnitudes[bi, indices[bi, :k_min[bi,0]]] = bnd[bi, :k_min[bi,0]]
                magnitudes[bi, indices[bi, k_min[bi,0]]] = base_eps[bi] - cum_bnd[bi, k_min[bi,0]]
                
            delta_it = sign_grad * magnitudes.view(cat_var.size())
            # These should always be exactly epsilon
            # l1_check = torch.norm(delta_it.view(batch_size, -1), 1.0, dim=1) / num_pixels
            # print('l1_check: %s' % l1_check)
            cat_var.data = cat_var.data + (delta_it - cat_var.data) / (it + 1.0)

            if it != self.nb_its - 1:
                # self.jpeg scales rounding_vars by base_eps, so we divide to rescale
                # its coordinates to [-1, 1]
                cat_var_temp = cat_var / base_eps[:, None]
                pixel_inp_jpeg = self._jpeg_cat(pixel_inp, cat_var_temp, base_eps, batch_size, height, width)
                s = pixel_model(pixel_inp_jpeg)
            cat_var.grad.data.zero_()
        return cat_var
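# A tiny standalone sketch (toy numbers, independent of the attack class above)
# of the sort/gather/cumsum bookkeeping used for the L1 step: sort coordinates
# by |grad|, reorder the per-coordinate bounds with torch.gather so they stay
# aligned with the sorted order, then cumsum to see how much L1 budget the
# largest-gradient coordinates would consume before each position.
import torch

abs_grad = torch.tensor([[0.1, 0.9, 0.4],
                         [0.7, 0.2, 0.5]])
bound = torch.tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])

values, indices = torch.sort(abs_grad, descending=True)
bnd = torch.gather(bound, 1, indices)   # bounds in |grad|-descending order
cum_bnd = torch.cumsum(bnd, 1) - bnd    # L1 already spent before each coordinate
print(cum_bnd)                          # [[0., 2., 5.], [0., 4., 10.]]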
Example #35
0
    def forward(self, obj_heads, reg_heads, cls_heads, batch_anchors):
        device = cls_heads[0].device
        with torch.no_grad():
            filter_scores, filter_score_classes, filter_boxes = [], [], []
            for per_level_obj_pred, per_level_reg_pred, per_level_cls_pred, per_level_anchors in zip(
                    obj_heads, reg_heads, cls_heads, batch_anchors):
                per_level_obj_pred = per_level_obj_pred.view(
                    per_level_obj_pred.shape[0], -1,
                    per_level_obj_pred.shape[-1])
                per_level_reg_pred = per_level_reg_pred.view(
                    per_level_reg_pred.shape[0], -1,
                    per_level_reg_pred.shape[-1])
                per_level_cls_pred = per_level_cls_pred.view(
                    per_level_cls_pred.shape[0], -1,
                    per_level_cls_pred.shape[-1])
                per_level_anchors = per_level_anchors.view(
                    per_level_anchors.shape[0], -1,
                    per_level_anchors.shape[-1])

                per_level_obj_pred = torch.sigmoid(per_level_obj_pred)
                per_level_cls_pred = torch.sigmoid(per_level_cls_pred)

                # snap per_level_reg_pred from tx,ty,tw,th -> x_center,y_center,w,h -> x_min,y_min,x_max,y_max
                per_level_reg_pred[:, :, 0:2] = (
                    torch.sigmoid(per_level_reg_pred[:, :, 0:2]) +
                    per_level_anchors[:, :, 0:2]) * per_level_anchors[:, :,
                                                                      4:5]
                # pred_bboxes_wh=exp(twh)*anchor_wh/stride
                per_level_reg_pred[:, :, 2:4] = torch.exp(
                    per_level_reg_pred[:, :, 2:4]
                ) * per_level_anchors[:, :, 2:4] / per_level_anchors[:, :, 4:5]

                per_level_reg_pred[:, :, 0:2] = (per_level_reg_pred[:, :, 0:2] -
                                                 0.5 * per_level_reg_pred[:, :, 2:4])
                per_level_reg_pred[:, :, 2:4] = (per_level_reg_pred[:, :, 2:4] +
                                                 per_level_reg_pred[:, :, 0:2])
                per_level_reg_pred = per_level_reg_pred.int()
                per_level_reg_pred[:, :, 0] = torch.clamp(per_level_reg_pred[:, :, 0],
                                                          min=0)
                per_level_reg_pred[:, :, 1] = torch.clamp(per_level_reg_pred[:, :, 1],
                                                          min=0)
                per_level_reg_pred[:, :, 2] = torch.clamp(per_level_reg_pred[:, :, 2],
                                                          max=self.image_w - 1)
                per_level_reg_pred[:, :, 3] = torch.clamp(per_level_reg_pred[:, :, 3],
                                                          max=self.image_h - 1)

                per_level_scores, per_level_score_classes = torch.max(
                    per_level_cls_pred, dim=2)
                per_level_scores = per_level_scores * per_level_obj_pred.squeeze(
                    -1)
                if per_level_scores.shape[1] >= self.top_n:
                    per_level_scores, indexes = torch.topk(per_level_scores,
                                                           self.top_n,
                                                           dim=1,
                                                           largest=True,
                                                           sorted=True)
                    per_level_score_classes = torch.gather(
                        per_level_score_classes, 1, indexes)
                    per_level_reg_pred = torch.gather(
                        per_level_reg_pred, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 4))

                filter_scores.append(per_level_scores)
                filter_score_classes.append(per_level_score_classes)
                filter_boxes.append(per_level_reg_pred)

            filter_scores = torch.cat(filter_scores, axis=1)
            filter_score_classes = torch.cat(filter_score_classes, axis=1)
            filter_boxes = torch.cat(filter_boxes, axis=1)

            batch_scores, batch_classes, batch_pred_bboxes = [], [], []
            for scores, score_classes, pred_bboxes in zip(
                    filter_scores, filter_score_classes, filter_boxes):
                score_classes = score_classes[
                    scores > self.min_score_threshold].float()
                pred_bboxes = pred_bboxes[
                    scores > self.min_score_threshold].float()
                scores = scores[scores > self.min_score_threshold].float()

                one_image_scores = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_classes = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_pred_bboxes = (-1) * torch.ones(
                    (self.max_detection_num, 4), device=device)

                if scores.shape[0] != 0:
                    # Sort boxes
                    sorted_scores, sorted_indexes = torch.sort(scores,
                                                               descending=True)
                    sorted_score_classes = score_classes[sorted_indexes]
                    sorted_pred_bboxes = pred_bboxes[sorted_indexes]

                    keep = nms(sorted_pred_bboxes, sorted_scores,
                               self.nms_threshold)
                    keep_scores = sorted_scores[keep]
                    keep_classes = sorted_score_classes[keep]
                    keep_pred_bboxes = sorted_pred_bboxes[keep]

                    final_detection_num = min(self.max_detection_num,
                                              keep_scores.shape[0])

                    one_image_scores[0:final_detection_num] = keep_scores[
                        0:final_detection_num]
                    one_image_classes[0:final_detection_num] = keep_classes[
                        0:final_detection_num]
                    one_image_pred_bboxes[
                        0:final_detection_num, :] = keep_pred_bboxes[
                            0:final_detection_num, :]

                one_image_scores = one_image_scores.unsqueeze(0)
                one_image_classes = one_image_classes.unsqueeze(0)
                one_image_pred_bboxes = one_image_pred_bboxes.unsqueeze(0)

                batch_scores.append(one_image_scores)
                batch_classes.append(one_image_classes)
                batch_pred_bboxes.append(one_image_pred_bboxes)

            batch_scores = torch.cat(batch_scores, axis=0)
            batch_classes = torch.cat(batch_classes, axis=0)
            batch_pred_bboxes = torch.cat(batch_pred_bboxes, axis=0)

            # batch_scores shape:[batch_size,max_detection_num]
            # batch_classes shape:[batch_size,max_detection_num]
            # batch_pred_bboxes shape[batch_size,max_detection_num,4]
            return batch_scores, batch_classes, batch_pred_bboxes
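# Standalone sketch (made-up sizes) of the top-k filtering above: torch.topk
# returns per-image indexes along the anchor dimension, and torch.gather reuses
# those indexes so the kept classes and boxes stay aligned with the kept scores.
import torch

B, A, top_n = 2, 6, 3
scores = torch.rand(B, A)
classes = torch.randint(0, 20, (B, A))
boxes = torch.rand(B, A, 4)

scores, indexes = torch.topk(scores, top_n, dim=1, largest=True, sorted=True)
classes = torch.gather(classes, 1, indexes)
boxes = torch.gather(boxes, 1, indexes.unsqueeze(-1).repeat(1, 1, 4))
assert boxes.shape == (B, top_n, 4)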
Example #36
0
 def maskNLLLoss(self, inp, target, mask):
     nTotal = mask.sum()
     crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
     loss = crossEntropy.masked_select(mask).mean()
     loss = loss.to(self.device)
     return loss, nTotal.item()
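# Hedged usage sketch for maskNLLLoss above: inp is assumed to hold per-token
# probabilities (not logits), target the gold word ids, and mask a boolean
# tensor marking non-padded positions; all names and shapes here are invented.
import torch

inp = torch.softmax(torch.randn(4, 10), dim=1)      # (batch, vocab) probabilities
target = torch.randint(0, 10, (4,))
mask = torch.tensor([True, True, False, True])

cross_entropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1))).squeeze(1)
loss = cross_entropy.masked_select(mask).mean()     # average loss over real tokens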
Example #37
0
    def forward(self, memory_cells, query, alignments, copy_source):
        """Compute attention with soft-copy.

        Args:
            memory_cells (SequenceBatch): of shape (batch_size, num_cells, memory_dim)
            query (Variable): of shape (batch_size, query_dim)
            alignments (SequenceBatch): int-valued, of shape
                (batch_size, num_cells). If something has no alignment, it will
                have value 0 in the mask.
            copy_source (Variable): of shape (batch_size, num_candidates)

        This behaves like normal attention, except we boost the
        exponentiated logits:

        exp_logits[i][j] += copy_source[i][alignments[i][j]]

        This is inspired by:
            "Incorporating Copying Mechanism in Sequence-to-Sequence Learning"
            http://www.aclweb.org/anthology/P16-1154

        Returns:
            AttentionOutput
        """
        if not self._is_subset(alignments.mask, memory_cells.mask):
            raise ValueError(
                'Alignments mask must be a subset of memory cells mask.')

        base_attn = self._base_attention(memory_cells,
                                         query)  # AttentionOutput

        exp_logits = torch.exp(base_attn.logits)  # (batch_size, num_cells)

        # y = torch.gather(x, dim=1, index=index)
        # y[i][j] = x[i][index[i][j]]
        boost = torch.gather(copy_source, dim=1, index=alignments.values)

        # no boost for items with no alignment
        boost = boost * alignments.mask

        # boost the exponentiated logits
        boosted_exp_logits = exp_logits + boost

        # normalize to compute final weights
        normalizer = torch.sum(boosted_exp_logits,
                               1).expand_as(boosted_exp_logits)
        # (batch_size, num_cells)

        weights = boosted_exp_logits / normalizer
        weights = Attention._mask_weights(weights, memory_cells.mask)

        if not np.isfinite(weights.data.sum()):
            raise ValueError('Some attention weights are NaN')
            # TODO(kelvin): need to avoid numerical precision issues
            # TODO(kelvin): need to avoid division by zero

        # compute context
        context = Attention._context_from_weights(weights, memory_cells)

        logits = torch.log(boosted_exp_logits)
        return SoftCopyAttentionOutput(weights=weights,
                                       context=context,
                                       logits=logits,
                                       orig_logits=base_attn.logits,
                                       boost=boost)
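# Tiny numeric sketch of the boost lookup described in the docstring above:
# boost[i][j] = copy_source[i][alignments[i][j]] is exactly torch.gather along
# dim=1. The values below are made up.
import torch

copy_source = torch.tensor([[0.0, 0.5, 1.0],
                            [2.0, 0.0, 0.1]])
alignments = torch.tensor([[2, 0, 0, 1],
                           [1, 1, 2, 0]])
boost = torch.gather(copy_source, dim=1, index=alignments)
# boost -> [[1.0, 0.0, 0.0, 0.5],
#           [0.0, 0.0, 0.1, 2.0]]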
Example #38
0
    def forward(self, *input):
        [
            token_embeddings, input_mask_variable, conversation_mask,
            max_num_utterances_batch
        ] = input
        conversation_batch_size = int(token_embeddings.shape[0] /
                                      max_num_utterances_batch)

        if self.args.fixed_utterance_encoder:
            utterance_encodings = token_embeddings
        else:
            utterance_encodings = self.dialogue_embedder.utterance_encoder(
                token_embeddings, input_mask_variable)
        utterance_encodings = utterance_encodings.view(
            conversation_batch_size, max_num_utterances_batch,
            utterance_encodings.shape[1])
        utterance_encodings_next = utterance_encodings[:, 1:, :].contiguous()
        utterance_encodings_prev = utterance_encodings[:, 0:-1, :].contiguous()

        conversation_encoded = self.dialogue_embedder([
            token_embeddings, input_mask_variable, conversation_mask,
            max_num_utterances_batch
        ])

        conversation_encoded_forward = conversation_encoded[:, 0, :]
        conversation_encoded_backward = conversation_encoded[:, 1, :]
        #conversation_encoded_forward = conversation_encoded.view(conversation_encoded.shape[0], 1, -1).squeeze(1)
        #conversation_encoded_backward = conversation_encoded.view(conversation_encoded.shape[0], 1, -1).squeeze(1)

        conversation_encoded_forward_reassembled = conversation_encoded_forward.view(
            conversation_batch_size, max_num_utterances_batch,
            conversation_encoded_forward.shape[1])
        conversation_encoded_backward_reassembled = conversation_encoded_backward.view(
            conversation_batch_size, max_num_utterances_batch,
            conversation_encoded_backward.shape[1])

        # Shift to prepare next and previous utterance encodings
        conversation_encoded_current1 = conversation_encoded_forward_reassembled[:, 0:-1, :].contiguous()
        conversation_encoded_next = conversation_encoded_forward_reassembled[:, 1:, :].contiguous()
        conversation_mask_next = conversation_mask[:, 1:].contiguous()

        conversation_encoded_current2 = conversation_encoded_backward_reassembled[:, 1:, :].contiguous()
        conversation_encoded_previous = conversation_encoded_backward_reassembled[:, 0:-1, :].contiguous()
        # conversation_mask_previous = conversation_mask[:, 0:-1].contiguous()

        # Gold Labels
        gold_indices = variable(
            LongTensor(range(conversation_encoded_current1.shape[1]))).view(
                -1, 1).repeat(conversation_batch_size, 1)

        # Linear transformation of both utterance representations
        transformed_current1 = self.current_dl_trasnformer1(
            conversation_encoded_current1)
        transformed_current2 = self.current_dl_trasnformer2(
            conversation_encoded_current2)

        transformed_next = self.next_dl_trasnformer(conversation_encoded_next)
        transformed_prev = self.prev_dl_trasnformer(
            conversation_encoded_previous)
        # transformed_next = self.next_dl_trasnformer(utterance_encodings_next)
        # transformed_prev = self.prev_dl_trasnformer(utterance_encodings_prev)

        # Output layer: Generate Scores for next and prev utterances
        next_logits = torch.bmm(transformed_current1,
                                transformed_next.transpose(2, 1))
        prev_logits = torch.bmm(transformed_current2,
                                transformed_prev.transpose(2, 1))

        # Computing custom masked cross entropy
        next_log_probs = F.log_softmax(next_logits, dim=2)
        prev_log_probs = F.log_softmax(prev_logits, dim=2)

        losses_next = -torch.gather(next_log_probs.view(
            next_log_probs.shape[0] * next_log_probs.shape[1], -1),
                                    dim=1,
                                    index=gold_indices)
        losses_prev = -torch.gather(prev_log_probs.view(
            prev_log_probs.shape[0] * prev_log_probs.shape[1], -1),
                                    dim=1,
                                    index=gold_indices)

        losses_masked = (losses_next.squeeze(1) * conversation_mask_next.view(conversation_mask_next.shape[0]*conversation_mask_next.shape[1]))\
            + (losses_prev.squeeze(1) * conversation_mask_next.view(conversation_mask_next.shape[0]*conversation_mask_next.shape[1]))

        loss = losses_masked.sum() / (2 * conversation_mask_next.float().sum())

        return loss
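# Standalone sketch (invented shapes) of the masked cross-entropy computed
# above: gather the log-probability of the gold index in each row, negate,
# then mask out padded rows before normalizing.
import torch
import torch.nn.functional as F

logits = torch.randn(6, 5)                    # (rows, candidates)
gold_indices = torch.randint(0, 5, (6, 1))
mask = torch.tensor([1., 1., 1., 0., 1., 1.])

log_probs = F.log_softmax(logits, dim=1)
losses = -torch.gather(log_probs, dim=1, index=gold_indices).squeeze(1)
loss = (losses * mask).sum() / mask.sum()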
Example #39
0
    def forward(self, im_data, im_info, gt_boxes, num_boxes):
        # shape:  im_data [1,c,w,h]  im_info[1,3]   gt_boxes[1,20,5]  num_boxes[1]
        batch_size = im_data.size(0)
        # im_data is the raw image blob, e.g. [1, 3, 850, 600]

        im_info = im_info.data
        gt_boxes = gt_boxes.data
        num_boxes = num_boxes.data

        # feed image data to base model to obtain base feature map
        base_feat = self.RCNN_base(im_data)
        # feed base feature map to RPN to obtain rois
        rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes, num_boxes)

        # if it is the training phase, use ground truth bboxes for refining
        if self.training:
            roi_data = self.RCNN_proposal_target(rois, gt_boxes, num_boxes)
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data

            rois_label = Variable(rois_label.view(-1).long())
            rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
            rois_inside_ws = Variable(rois_inside_ws.view(-1, rois_inside_ws.size(2)))
            rois_outside_ws = Variable(rois_outside_ws.view(-1, rois_outside_ws.size(2)))
        else:
            rois_label = None
            rois_target = None
            rois_inside_ws = None
            rois_outside_ws = None
            rpn_loss_cls = 0
            rpn_loss_bbox = 0

        rois = Variable(rois)
        # At test time rois has shape [1, 300, 5]; the first column is all zeros
        # and is just the batch index, not the roi label. gt_boxes has shape (x, 5), where x is the number of objects.
        # do roi pooling based on predicted rois
        # POOLING_MODE = align
        pooled_feat = self.RCNN_roi_align(base_feat, rois.view(-1, 5))
        #feed pooled feature to top model
        pooled_feat = self.head_to_tail(pooled_feat)

        # compute bbox offsets from the RoI features extracted after RoI pooling
        bbox_pred = self.RCNN_bbox_pred(pooled_feat)
        if self.training and not self.class_agnostic:
            # select the corresponding columns according to roi labels
            bbox_pred_view = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(bbox_pred_view, 1, rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)

        # compute object classification probability
        cls_score = self.RCNN_cls_score(pooled_feat)
        cls_prob = F.softmax(cls_score, 1)
        # at test time cls_score is [300, 21] and bbox_pred is [300, 84]
        RCNN_loss_cls = 0
        RCNN_loss_bbox = 0

        if self.training:
            # classification loss
            RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)

            # bounding box regression L1 loss
            RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)

        cls_prob = cls_prob.view(batch_size, rois.size(1), -1)  # [1, 300, 21] at test time
        bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)  # [1, 300, 84] at test time

        return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label
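# Toy sketch of the class-specific box selection above: bbox_pred stores four
# regression values per class, and torch.gather picks the 4-vector matching
# each RoI's label. Sizes here are illustrative only.
import torch

num_rois, num_classes = 3, 21
bbox_pred = torch.randn(num_rois, num_classes * 4)
rois_label = torch.tensor([5, 0, 20])

bbox_pred_view = bbox_pred.view(num_rois, num_classes, 4)
index = rois_label.view(num_rois, 1, 1).expand(num_rois, 1, 4)
bbox_pred_select = torch.gather(bbox_pred_view, 1, index).squeeze(1)   # (num_rois, 4)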
Example #40
0
    def forward(self, source: List[List[str]],
                target: List[List[str]]) -> torch.Tensor:
        """ Take a mini-batch of source and target sentences, compute the log-likelihood of
        target sentences under the language models learned by the NMT system.

        @param source (List[List[str]]): list of source sentence tokens
        @param target (List[List[str]]): list of target sentence tokens, wrapped by `<s>` and `</s>`

        @returns scores (Tensor): a variable/tensor of shape (b, ) representing the
                                    log-likelihood of generating the gold-standard target sentence for
                                    each example in the input batch. Here b = batch size.
        """
        # Compute sentence lengths
        source_lengths = [len(s) for s in source]

        # Convert list of lists into tensors

        ## A4 code
        # source_padded = self.vocab.src.to_input_tensor(source, device=self.device)   # Tensor: (src_len, b)
        # target_padded = self.vocab.tgt.to_input_tensor(target, device=self.device)   # Tensor: (tgt_len, b)

        # enc_hiddens, dec_init_state = self.encode(source_padded, source_lengths)
        # enc_masks = self.generate_sent_masks(enc_hiddens, source_lengths)
        # combined_outputs = self.decode(enc_hiddens, enc_masks, dec_init_state, target_padded)
        ## End A4 code

        ### YOUR CODE HERE for part 1k
        ### TODO:
        ###     Modify the code lines above as needed to fetch the character-level tensor
        ###     to feed into encode() and decode(). You should:
        ###     - Keep `target_padded` from A4 code above for predictions
        ###     - Add `source_padded_chars` for character level padded encodings for source
        ###     - Add `target_padded_chars` for character level padded encodings for target
        ###     - Modify calls to encode() and decode() to use the character level encodings

        target_padded = self.vocab.tgt.to_input_tensor(target,
                                                       device=self.device)
        source_padded_chars = self.vocab.tgt.to_input_tensor_char(
            source, device=self.device)
        target_padded_chars = self.vocab.tgt.to_input_tensor_char(
            target, device=self.device)
        enc_hiddens, dec_init_state = self.encode(source_padded_chars,
                                                  source_lengths)
        enc_masks = self.generate_sent_masks(enc_hiddens, source_lengths)
        combined_outputs = self.decode(enc_hiddens, enc_masks, dec_init_state,
                                       target_padded_chars)

        ### END YOUR CODE

        P = F.log_softmax(self.target_vocab_projection(combined_outputs),
                          dim=-1)

        # Zero out probabilities for which we have nothing in the target text
        target_masks = (target_padded != self.vocab.tgt['<pad>']).float()

        # Compute log probability of generating true target words
        target_gold_words_log_prob = torch.gather(
            P, index=target_padded[1:].unsqueeze(-1),
            dim=-1).squeeze(-1) * target_masks[1:]
        scores = target_gold_words_log_prob.sum(
        )  # mhahn2 Small modification from A4 code.

        if self.charDecoder is not None:
            max_word_len = target_padded_chars.shape[-1]

            target_words = target_padded[1:].contiguous().view(-1)
            target_chars = target_padded_chars[1:].view(-1, max_word_len)
            target_outputs = combined_outputs.view(-1, 256)

            target_chars_oov = target_chars  # torch.index_select(target_chars, dim=0, index=oovIndices)
            rnn_states_oov = target_outputs  # torch.index_select(target_outputs, dim=0, index=oovIndices)
            oovs_losses = self.charDecoder.train_forward(
                target_chars_oov.t(),
                (rnn_states_oov.unsqueeze(0), rnn_states_oov.unsqueeze(0)))
            scores = scores - oovs_losses

        return scores
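# Minimal sketch (random tensors, invented sizes) of the scoring step above:
# gather the log-probability assigned to each gold target word, zero out the
# padding positions, and sum over time to get one log-likelihood per sentence.
import torch
import torch.nn.functional as F

tgt_len, batch, vocab, pad_idx = 5, 2, 7, 0
P = F.log_softmax(torch.randn(tgt_len - 1, batch, vocab), dim=-1)
target_padded = torch.randint(1, vocab, (tgt_len, batch))
target_masks = (target_padded != pad_idx).float()

gold_log_prob = torch.gather(P, index=target_padded[1:].unsqueeze(-1),
                             dim=-1).squeeze(-1) * target_masks[1:]
scores = gold_log_prob.sum(dim=0)    # shape (batch,), one score per sentence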
Example #41
0
    def forward(self, im_data, im_info, gt_boxes, num_boxes):
        batch_size = im_data.size(0)

        im_info = im_info.data
        gt_boxes = gt_boxes.data
        num_boxes = num_boxes.data

        # feed image data to base model to obtain base feature map
        base_feat = self.RCNN_base(im_data)

        # feed base feature map to RPN to obtain rois
        rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(
            base_feat, im_info, gt_boxes, num_boxes)

        # if it is the training phase, use ground truth bboxes for refining
        if self.training:
            roi_data = self.RCNN_proposal_target(rois, gt_boxes, num_boxes)
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data

            rois_label = Variable(rois_label.view(-1).long())
            rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
            rois_inside_ws = Variable(
                rois_inside_ws.view(-1, rois_inside_ws.size(2)))
            rois_outside_ws = Variable(
                rois_outside_ws.view(-1, rois_outside_ws.size(2)))
        else:
            rois_label = None
            rois_target = None
            rois_inside_ws = None
            rois_outside_ws = None
            rpn_loss_cls = 0
            rpn_loss_bbox = 0

        rois = Variable(rois)
        # do roi pooling based on predicted rois

        if cfg.POOLING_MODE == 'crop':
            # pdb.set_trace()
            # pooled_feat_anchor = _crop_pool_layer(base_feat, rois.view(-1, 5))
            grid_xy = _affine_grid_gen(rois.view(-1, 5),
                                       base_feat.size()[2:], self.grid_size)
            grid_yx = torch.stack(
                [grid_xy.data[:, :, :, 1], grid_xy.data[:, :, :, 0]],
                3).contiguous()
            pooled_feat = self.RCNN_roi_crop(base_feat,
                                             Variable(grid_yx).detach())
            if cfg.CROP_RESIZE_WITH_MAX_POOL:
                pooled_feat = F.max_pool2d(pooled_feat, 2, 2)
        elif cfg.POOLING_MODE == 'align':
            pooled_feat = self.RCNN_roi_align(base_feat, rois.view(-1, 5))
        elif cfg.POOLING_MODE == 'pool':
            pooled_feat = self.RCNN_roi_pool(base_feat, rois.view(-1, 5))

        # feed pooled features to top model
        pooled_feat = self._head_to_tail(pooled_feat)

        # compute bbox offset
        bbox_pred = self.RCNN_bbox_pred(pooled_feat)
        if self.training and not self.class_agnostic:
            # select the corresponding columns according to roi labels
            bbox_pred_view = bbox_pred.view(bbox_pred.size(0),
                                            int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(
                bbox_pred_view, 1,
                rois_label.view(rois_label.size(0), 1,
                                1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)

        # compute object classification probability
        cls_score = self.RCNN_cls_score(pooled_feat)
        cls_prob = F.softmax(cls_score, dim=1)

        RCNN_loss_cls = 0
        RCNN_loss_bbox = 0

        if self.training:
            # classification loss
            RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)

            # bounding box regression L1 loss
            RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target,
                                             rois_inside_ws, rois_outside_ws)

        cls_prob = cls_prob.view(batch_size, rois.size(1), -1)
        bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)

        return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label
Example #42
0
    def _viterbi_decode(self, feats, mask=None):
        """
        Args:
            feats: size=(batch_size, seq_len, self.target_size+2)
            mask: size=(batch_size, seq_len)
        Returns:
            decode_idx: (batch_size, seq_len), viterbi decode result
            path_score: size=(batch_size, 1), score for each sentence
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(-1)
        #print(batch_size ,seq_len,tag_size)
        length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
        mask = mask.transpose(1, 0).contiguous()
        ins_num = seq_len * batch_size
        feats = feats.transpose(1, 0).contiguous().view(
            ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)

        scores = feats + self.transitions.view(
            1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)

        seq_iter = enumerate(scores)
        # record the position of the best score
        back_points = list()
        partition_history = list()
        mask = (1 - mask.long()).byte()
        try:
            _, inivalues = seq_iter.__next__()
        except:
            _, inivalues = seq_iter.next()
        partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
        partition_history.append(partition)

        for idx, cur_values in seq_iter:
            cur_values = cur_values + partition.contiguous().view(
                batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            partition, cur_bp = torch.max(cur_values, 1)
            partition_history.append(partition.unsqueeze(-1))

            cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
            back_points.append(cur_bp)

        partition_history = torch.cat(partition_history).view(
            seq_len, batch_size, -1).transpose(1, 0).contiguous()

        last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
        last_partition = torch.gather(
            partition_history, 1, last_position).view(batch_size, tag_size, 1)

        last_values = last_partition.expand(batch_size, tag_size, tag_size) + \
            self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size).to(device)
        _, last_bp = torch.max(last_values, 1)
        pad_zero = Variable(torch.zeros(batch_size, tag_size)).long()
        pad_zero = pad_zero.to(device)
        back_points.append(pad_zero)
        back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)

        pointer = last_bp[:, self.END_TAG_IDX]
        insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)
        back_points = back_points.transpose(1, 0).contiguous()

        back_points.scatter_(1, last_position, insert_last)

        back_points = back_points.transpose(1, 0).contiguous()

        decode_idx = Variable(torch.LongTensor(seq_len, batch_size))
        decode_idx = decode_idx.to(device)
        decode_idx[-1] = pointer.data
        for idx in range(len(back_points)-2, -1, -1):
            pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
            decode_idx[idx] = pointer.view(-1).data
        path_score = None
        decode_idx = decode_idx.transpose(1, 0)
        return path_score, decode_idx
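# Toy sketch of the back-pointer walk at the end of _viterbi_decode: at every
# earlier time step, torch.gather looks up, per sentence, the previous tag on
# the best path that ends in the currently selected tag. Shapes are invented.
import torch

seq_len, batch_size, tag_size = 4, 2, 3
back_points = torch.randint(0, tag_size, (seq_len, batch_size, tag_size))
pointer = torch.randint(0, tag_size, (batch_size,))

decode_idx = torch.empty(seq_len, batch_size, dtype=torch.long)
decode_idx[-1] = pointer
for idx in range(seq_len - 2, -1, -1):
    pointer = torch.gather(back_points[idx], 1,
                           pointer.view(batch_size, 1)).view(-1)
    decode_idx[idx] = pointer
decode_idx = decode_idx.transpose(1, 0)   # (batch_size, seq_len)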
Example #43
0
 def forward(self, input: torch.Tensor, padding=None):
     pretrained_indices = torch.gather(self.vocab_to_pretrained,
                                       dim=0,
                                       index=input)
     rnn_output = self.model(pretrained_indices)
     return EmbeddingOutput(all_layers=[rnn_output], last_layer=rnn_output)
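# Small sketch of the vocabulary remapping above: vocab_to_pretrained is
# assumed to be a 1-D lookup table (model vocab id -> pretrained vocab id),
# and torch.gather along dim=0 applies it elementwise to a 1-D batch of ids.
import torch

vocab_to_pretrained = torch.tensor([0, 3, 1, 7, 2])   # toy mapping
token_ids = torch.tensor([4, 1, 1, 0])
pretrained_ids = torch.gather(vocab_to_pretrained, dim=0, index=token_ids)
# pretrained_ids -> tensor([2, 3, 3, 0])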
Example #44
0
    def _generate(self,
                  model,
                  sample,
                  prefix_tokens=None,
                  bos_token=None,
                  **kwargs):
        if not self.retain_dropout:
            model.eval()

        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v
            for k, v in sample['net_input'].items()
            if k != 'prev_output_tokens'
        }

        src_tokens = encoder_input['src_tokens']
        if src_tokens.dim() > 2:
            src_lengths = encoder_input['src_lengths']
        else:
            src_lengths = (src_tokens.ne(self.eos)
                           & src_tokens.ne(self.pad)).long().sum(dim=1)
        input_size = src_tokens.size()
        # batch dimension goes first followed by source lengths
        bsz = input_size[0]
        src_len = input_size[1]
        beam_size = self.beam_size

        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                model.max_decoder_positions() - 1,
            )
        assert self.min_len <= max_len, 'min_len cannot be larger than max_len, please adjust these!'

        # compute the encoder output for each beam
        encoder_outs = model.forward_encoder(encoder_input)
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)

        # initialize buffers
        scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.new(bsz * beam_size,
                                max_len + 2).long().fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos if bos_token is None else bos_token
        attn, attn_buf = None, None

        # The blacklist indicates candidates that should be ignored.
        # For example, suppose we're sampling and have already finalized 2/5
        # samples. Then the blacklist would mark 2 positions as being ignored,
        # so that we only finalize the remaining 3 samples.
        blacklist = src_tokens.new_zeros(bsz, beam_size).eq(
            -1)  # forward and backward-compatible False mask

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) *
                        beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfin_idx):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size or step == max_len:
                return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step +
                                        2]  # skip the first index, which is EOS
            assert not tokens_clone.eq(self.eos).any()
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(
                0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1)**self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(
                    zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                if self.match_source_len and step > src_lengths[unfin_idx]:
                    score = -math.inf

                def get_hypo():

                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i]
                    else:
                        hypo_attn = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': None,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfin_idx):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(
                        batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size)
                model.reorder_incremental_state(reorder_state)
                encoder_outs = model.reorder_encoder_out(
                    encoder_outs, reorder_state)

            lprobs, avg_attn_scores = model.forward_decoder(
                tokens[:, :step + 1],
                encoder_outs,
                temperature=self.temperature,
            )
            lprobs[lprobs != lprobs] = -math.inf  # replace NaNs so they are never selected

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # handle max length constraint
            if step >= max_len:
                lprobs[:, :self.eos] = -math.inf
                lprobs[:, self.eos + 1:] = -math.inf
            elif self.eos_factor is not None:
                # only consider EOS if its score is no less than a specified
                # factor of the best candidate score
                disallow_eos_mask = (lprobs[:, self.eos] <
                                     self.eos_factor * lprobs.max(dim=1)[0])
                lprobs[disallow_eos_mask, self.eos] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            if prefix_tokens is not None and step < prefix_tokens.size(
                    1) and step < max_len:
                prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(
                    1, beam_size).view(-1)
                prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
                prefix_mask = prefix_toks.ne(self.pad)
                lprobs[prefix_mask] = -math.inf
                lprobs[prefix_mask] = lprobs[prefix_mask].scatter_(
                    -1, prefix_toks[prefix_mask].unsqueeze(-1),
                    prefix_lprobs[prefix_mask])
                # if prefix includes eos, then we should make sure tokens and
                # scores are the same across all beams
                eos_mask = prefix_toks.eq(self.eos)
                if eos_mask.any():
                    # validate that the first beam matches the prefix
                    first_beam = tokens[eos_mask].view(
                        -1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
                    eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
                    target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
                    assert (first_beam == target_prefix).all()

                    def replicate_first_beam(tensor, mask):
                        tensor = tensor.view(-1, beam_size, tensor.size(-1))
                        tensor[mask] = tensor[mask][:, :1, :]
                        return tensor.view(-1, tensor.size(-1))

                    # copy tokens, scores and lprobs from the first beam to all beams
                    tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
                    scores = replicate_first_beam(scores, eos_mask_batch_dim)
                    lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim)
            elif step < self.min_len:
                # minimum length constraint (does not apply if using prefix_tokens)
                lprobs[:, self.eos] = -math.inf

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    for ngram in zip(*[
                            gen_tokens[i:]
                            for i in range(self.no_repeat_ngram_size)
                    ]):
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
                                gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if type(avg_attn_scores) is list:
                avg_attn_scores = avg_attn_scores[0]
            if avg_attn_scores is not None:
                if attn is None:
                    if src_tokens.dim() > 2:
                        attn = scores.new(
                            bsz * beam_size,
                            encoder_outs[0]["encoder_out"][0].size(0),
                            max_len + 2,
                        )
                    else:
                        attn = scores.new(bsz * beam_size, src_tokens.size(1),
                                          max_len + 2)
                    attn_buf = attn.clone()
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)

            self.search.set_src_lengths(src_lengths)

            if self.no_repeat_ngram_size > 0:

                def calculate_banned_tokens(bbsz_idx):
                    # before decoding the next token, prevent decoding of ngrams that have already appeared
                    ngram_index = tuple(
                        tokens[bbsz_idx, step + 2 -
                               self.no_repeat_ngram_size:step + 1].tolist())
                    return gen_ngrams[bbsz_idx].get(ngram_index, [])

                if step + 2 - self.no_repeat_ngram_size >= 0:
                    # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
                    banned_tokens = [
                        calculate_banned_tokens(bbsz_idx)
                        for bbsz_idx in range(bsz * beam_size)
                    ]
                else:
                    banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]

                for bbsz_idx in range(bsz * beam_size):
                    lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf

            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos, except for blacklisted ones
            # or candidates with a score of -inf
            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
            eos_mask[:, :beam_size][blacklist] = 0

            # only consider eos when it's among the top beam_size indices
            torch.masked_select(
                cand_bbsz_idx[:, :beam_size],
                mask=eos_mask[:, :beam_size],
                out=eos_bbsz_idx,
            )

            finalized_sents = set()
            if eos_bbsz_idx.numel() > 0:
                torch.masked_select(
                    cand_scores[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_scores,
                )
                finalized_sents = finalize_hypos(step, eos_bbsz_idx,
                                                 eos_scores)
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < max_len

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                blacklist = blacklist[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos or
            # blacklisted hypos and values < cand_size indicate candidate
            # active hypos. After this, the min values per row are the top
            # candidate active hypos.
            active_mask = buffer('active_mask')
            eos_mask[:, :beam_size] |= blacklist
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, new_blacklist = buffer('active_hypos'), buffer(
                'new_blacklist')
            torch.topk(active_mask,
                       k=beam_size,
                       dim=1,
                       largest=False,
                       out=(new_blacklist, active_hypos))

            # update blacklist to ignore any finalized hypos
            blacklist = new_blacklist.ge(cand_size)[:, :beam_size]
            assert (~blacklist).any(dim=1).all()

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx,
                dim=1,
                index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1],
                dim=0,
                index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices,
                dim=1,
                index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step],
                    dim=0,
                    index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2],
                    dim=0,
                    index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent],
                                     key=lambda r: r['score'],
                                     reverse=True)
        return finalized
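# Tiny standalone sketch (two sentences, beam of two) of the candidate
# bookkeeping above: after the search step returns (bsz, cand_size) tensors,
# torch.gather with the surviving hypothesis indexes pulls out the scores and
# next tokens of the active beams, keeping all the tensors aligned.
import torch

bsz, beam_size = 2, 2
cand_size = 2 * beam_size
cand_scores = torch.randn(bsz, cand_size)
cand_indices = torch.randint(0, 10, (bsz, cand_size))   # candidate next tokens
active_hypos = torch.tensor([[0, 2], [1, 3]])           # which candidates survive

active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
active_tokens = torch.gather(cand_indices, dim=1, index=active_hypos)
assert active_scores.shape == (bsz, beam_size)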
Example #45
0
    def decode_a(self, batch_size, u_input, u_hiddens, u_input_1hot, u_last_hidden, a_input,
                          db_vec, filling_vec, sample_type, decoder_type, qa_samples=None,
                           qa_hiddens=None, m_input=None, m_hiddens=None, m_input_1hot=None):
        return_gmb = True if 'gumbel' in sample_type else False
        a_prob, a_samples, gmb_samples = [], [], []
        log_pa = 0
        for si, sn in enumerate(self.reader.act_order):
            last_hidden = u_last_hidden[:-1]
            a_eos_idx = self.vocab.encode('<eos_%s>'%sn)
            a_sos_idx = self.vocab.encode('<go_%s>'%sn)
            at = cuda_(torch.ones(batch_size, 1)*a_sos_idx).long()
            emb_at = self.embedding(at)
            selc_read_m = cuda_(torch.zeros(batch_size, 1, self.hidden_size))
            vec_input = torch.cat([db_vec, filling_vec], dim=1)
            if sn == 'av':
                vec_input = cuda_(torch.zeros(vec_input.size()))
            prev_at = None
            for t in range(self.a_length):
                if decoder_type == 'pa':
                    prob, last_hidden, gru_out = self.pa_decoder[sn](
                                u_hiddens, emb_at, vec_input, last_hidden)
                else:
                    prob, last_hidden, gru_out, selc_read_m, gmb_samp = \
                        self.qa_decoder[sn](u_hiddens,
                                                        m_input, m_input_1hot, m_hiddens,
                                                        emb_at, vec_input, last_hidden,
                                                        selc_read_m=selc_read_m, temp=self.gumbel_temp,
                                                        return_gmb=return_gmb)
                if sample_type == 'supervised':
                    at = a_input[sn][:, t]
                elif sample_type == 'top1':
                    at = torch.topk(prob, 1)[1]
                elif sample_type == 'topk':
                    topk_probs, topk_words = torch.topk(prob.squeeze(1), cfg.topk_num)
                    widx = torch.multinomial(topk_probs, 1, replacement=True)
                    at = torch.gather(topk_words, 1, widx)      # [B, 1]
                elif sample_type == 'posterior':
                    at = qa_samples[:, si * self.a_length + t]
                elif 'gumbel' in sample_type:
                    at = torch.argmax(gmb_samp, dim=1)   #[B]
                    emb_at = torch.matmul(gmb_samp, self.embedding.weight).unsqueeze(1) # [B, 1, H]
                    at, prev_at, gmb_samp = self.mask_samples(at, prev_at, batch_size, a_eos_idx, gmb_samp, True)
                    gmb_samples.append(gmb_samp)

                if 'gumbel' not in sample_type:
                    emb_at = self.embedding(at.view(-1, 1))
                    prob_at = torch.gather(prob, 1, at.view(-1, 1)).squeeze(1) #[B, 1]
                    log_prob_at = torch.log(prob_at)
                    at, prev_at, log_prob_at = self.mask_samples(at, prev_at, batch_size, a_eos_idx, log_prob_at)
                    log_pa += log_prob_at
                a_samples.append(at.view(-1))
                a_prob.append(prob)
        a_prob = torch.stack(a_prob, dim=1)
        a_samples = torch.stack(a_samples, dim=1)  # [B,Ta]

        if sample_type == 'posterior':
            a_samples, a_hiddens = qa_samples, qa_hiddens
        elif 'gumbel' not in sample_type:
            a_hiddens, a_last_hidden = self.a_encoder(a_samples, input_type='index')
        else:
            a_gumbel = torch.stack(gmb_samples, dim=1)   # [B,Ta, V]
            a_gumbel = torch.matmul(a_gumbel, self.embedding.weight)     # [B,Ta, E]
            a_hiddens, a_last_hidden = self.a_encoder(a_gumbel, input_type='embedding')

        return a_prob, a_samples, a_hiddens, log_pa
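# Standalone sketch of the 'topk' sampling branch above: keep only the k most
# likely tokens, sample a column inside that top-k slice with multinomial, and
# map it back to a vocabulary id with torch.gather. Sizes are made up.
import torch

prob = torch.softmax(torch.randn(3, 50), dim=1)             # (batch, vocab)
topk_probs, topk_words = torch.topk(prob, 5)
widx = torch.multinomial(topk_probs, 1, replacement=True)   # column within top-k
at = torch.gather(topk_words, 1, widx)                      # (batch, 1) token ids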
Example #46
0
    def _translateBatch(self, batch, prefix_tokens = None):

        # The batch size lives in a different place depending on the dataset.
        # prefix_tokens = None
        beam_size = self.opt.beam_size
        bsz = batch_size = batch.size

        max_len = self.opt.max_sent_length

        gold_scores = batch.get('source').data.new(batch_size).float().zero_()
        gold_words = 0
        allgold_scores = []

        if batch.has_target:
            # Use the first model to decode
            model_ = self.models[0]

            gold_words, gold_scores, allgold_scores = model_.decode(batch)

        #  (3) Start decoding

        # initialize buffers
        src = batch.get('source')
        scores = src.new(bsz * beam_size, max_len + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0].fill_(self.bos)  # first token is bos
        attn, attn_buf = None, None
        nonpad_idxs = None
        src_tokens = src.transpose(0, 1)  # batch x time
        src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
        blacklist = src_tokens.new_zeros(bsz, beam_size).eq(-1)  # forward and backward-compatible False mask

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.
            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.
            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is BOS
            assert not tokens_clone.eq(self.eos).any()
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1) ** self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                # if self.match_source_len and step > src_lengths[unfin_idx]:
                #     score = -math.inf

                def get_hypo():

                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i]
                    else:
                        hypo_attn = None
                    # print(hypo_attn.shape)
                    # print(tokens_clone[i])
                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': None,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfin_idx):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None

        # initialize the decoder state, including:
        # - expanding the context over the batch dimension len_src x (B*beam) x H
        # - expanding the mask over the batch dimension    (B*beam) x len_src
        decoder_states = dict()
        for i in range(self.n_models):
            decoder_states[i] = self.models[i].create_decoder_state(batch, beam_size, type=2, buffering=self.buffering)
            len_context = decoder_states[i].context.size(0)

        if self.dynamic_max_len:
            src_len = src.size(0)
            max_len = math.ceil(int(src_len) * self.dynamic_max_len_scale)

        # Start decoding
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    decoder_states[i]._reorder_incremental_state(reorder_state)

            decode_input = tokens[:, :step + 1]
            lprobs, avg_attn_scores = self._decode(decode_input, decoder_states)
            # avg_attn_scores = None

            # lprobs[:, self.pad] = -math.inf  # never select pad

            # handle min and max length constraints
            if step >= max_len:
                lprobs[:, :self.eos] = -math.inf
                lprobs[:, self.eos + 1:] = -math.inf
            elif step < self.min_len:
                lprobs[:, self.eos] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            # prefix_tokens = torch.tensor([[798, 1354]]).type_as(tokens)
            # prefix_tokens = [[1000, 1354, 2443, 1475, 1010,  242,  127, 1191,  902, 1808, 1589,   26]]
            if prefix_tokens is not None:
                prefix_tokens = torch.tensor(prefix_tokens).type_as(tokens)
                if step < prefix_tokens.size(1) and step < max_len:
                    prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
                    prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
                    prefix_mask = prefix_toks.ne(self.pad)
                    lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)

                    lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
                        -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
                    )

                    # if prefix includes eos, then we should make sure tokens and
                    # scores are the same across all beams
                    eos_mask = prefix_toks.eq(self.eos)
                    if eos_mask.any():
                        # validate that the first beam matches the prefix
                        first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
                        eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
                        target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
                        assert (first_beam == target_prefix).all()

                        def replicate_first_beam(tensor, mask):
                            tensor = tensor.view(-1, beam_size, tensor.size(-1))
                            tensor[mask] = tensor[mask][:, :1, :]
                            return tensor.view(-1, tensor.size(-1))

                        # copy tokens, scores and lprobs from the first beam to all beams
                        tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
                        scores = replicate_first_beam(scores, eos_mask_batch_dim)
                        lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim)

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
                            gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, len_context, max_len + 2)
                    attn_buf = attn.clone()
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)

            if self.no_repeat_ngram_size > 0:
                def calculate_banned_tokens(bbsz_idx):
                    # before decoding the next token, prevent decoding of ngrams that have already appeared
                    ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
                    return gen_ngrams[bbsz_idx].get(ngram_index, [])

                if step + 2 - self.no_repeat_ngram_size >= 0:
                    # only ban n-grams once at least no_repeat_ngram_size tokens have been generated
                    banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
                else:
                    banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]

                for bbsz_idx in range(bsz * beam_size):
                    lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
            # print(lprobs.shape)
            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos (except for blacklisted ones)

            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)

            eos_mask[:, :beam_size][blacklist] = 0

            # only consider eos when it's among the top beam_size indices
            torch.masked_select(
                cand_bbsz_idx[:, :beam_size],
                mask=eos_mask[:, :beam_size],
                out=eos_bbsz_idx,
            )

            finalized_sents = set()
            if eos_bbsz_idx.numel() > 0:

                torch.masked_select(
                    cand_scores[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_scores,
                )
                finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores)
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < max_len

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero(as_tuple=False).squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                blacklist = blacklist[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos or
            # blacklisted hypos and values < cand_size indicate candidate
            # active hypos. After this, the min values per row are the top
            # candidate active hypos.
            active_mask = buffer('active_mask')
            eos_mask[:, :beam_size] |= blacklist
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, new_blacklist = buffer('active_hypos'), buffer('new_blacklist')
            torch.topk(
                active_mask, k=beam_size, dim=1, largest=False,
                out=(new_blacklist, active_hypos)
            )

            # update blacklist to ignore any finalized hypos
            blacklist = new_blacklist.ge(cand_size)[:, :beam_size]
            assert (~blacklist).any(dim=1).all()

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)


        return finalized, gold_scores, gold_words, allgold_scores
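Most of the bookkeeping in _translateBatch is index arithmetic: a topk over pooled candidates followed by torch.gather calls that copy the surviving hypotheses. A minimal standalone sketch of that step with assumed sizes (a plain topk over the scores stands in for the active_mask logic used above):

import torch

bsz, beam_size = 2, 3
cand_size = 2 * beam_size                                   # 2x beam in case half are EOS

cand_scores = torch.randn(bsz, cand_size)                   # scores of all candidates
cand_indices = torch.randint(0, 100, (bsz, cand_size))      # token ids of all candidates
cand_beams = torch.randint(0, beam_size, (bsz, cand_size))  # beam each candidate extends

# flat row index into the (bsz * beam_size) token/score buffers
bbsz_offsets = (torch.arange(bsz) * beam_size).unsqueeze(1)
cand_bbsz_idx = cand_beams + bbsz_offsets                   # [bsz, cand_size]

# keep the beam_size best candidates per sentence and gather their data
_, active_hypos = torch.topk(cand_scores, k=beam_size, dim=1)
active_bbsz_idx = torch.gather(cand_bbsz_idx, 1, active_hypos).view(-1)  # rows to copy
active_tokens = torch.gather(cand_indices, 1, active_hypos)              # next token per survivor
active_scores = torch.gather(cand_scores, 1, active_hypos)               # running score per survivor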
Пример #47
0
    def forward(self, x, mask, secret=None, globalx=None):
        """
            x: agents, episode, episode_step, neighbors, agent_type
            m: agents, episode, episode_step, neighbors, agent_type
            s: agents, episode, episode_step
            e: 1, episode, episode_step
        """
        x_ci_oh = onehot(x[:,:,:,:,0], mask, self.input_size, self.device) 

        if len(self.global_input_idx) != 0:
            x_ci_oh[:,:,:,:,-len(self.global_input_idx):] = globalx[:,:,:,self.global_input_idx].unsqueeze(-2)

        if self.add_catch:
            catch = (x[:,:,:,:,0] == x[:,:,:,:,1]).type(th.float)
            x_ci_oh[:,:,:,:,-len(self.global_input_idx)-1] = catch
        x_ci_oh[mask] = 0

        if self.control_type == 'binary':
            control = binary(secret, self.control_size).type(th.float)
            if len(self.global_control_idx) != 0:
                control[:,:,:,-len(self.global_control_idx):] = globalx[:,:,:,self.global_control_idx]
            data = x_ci_oh, mask, control

        elif len(self.global_control_idx) != 0:
            control = globalx[:,:,:,self.global_control_idx]
            data = x_ci_oh, mask, control
        else:
            data = x_ci_oh, mask

        # # random test
        # random_agent = th.randint(high=self.n_agents, size=(1,)).item()
        # random_batch = th.randint(high=x.shape[1], size=(1,)).item()
        # random_step = th.randint(high=x.shape[2], size=(1,)).item()
        # random_neighbor = th.randint(high=(x[random_agent,random_batch,random_step,:,0] != -1).sum(), size=(1,)).item()
        
        # # x_ci_oh test
        # color = x[random_agent,random_batch,random_step,random_neighbor,0]
        # vector = x_ci_oh[random_agent,random_batch,random_step,random_neighbor,:]
        # assert vector[color] == 1, f"{color} {vector}"
        # assert vector[:self.n_actions].sum() == 1

        # if len(self.global_input_idx) != 0:
        #     random_g_idx = th.randint(high=len(self.global_input_idx), size=(1,)).item()
        #     offset = self.input_size - len(self.global_input_idx)
        #     global_vec = globalx[random_agent,random_batch,random_step,self.global_input_idx]
        #     assert vector[offset+random_g_idx] == global_vec[random_g_idx]

        # if self.add_catch:
        #     is_catch = (color == x[random_agent,random_batch,random_step,random_neighbor,1])
        #     idx = - len(self.global_input_idx) - 1
        #     assert vector[idx] == is_catch
        # # end x_ci_oh test

        # # test control
        # if (secret is not None) or (len(self.global_control_idx) != 0):
        #     secret_size = self.control_size - len(self.global_control_idx)
        #     control_vec = control[random_agent,random_batch,random_step]
        # if secret is not None:
        #     binary_secret = np.unpackbits(secret[random_agent, random_batch, random_step].numpy().astype(np.uint8))[-secret_size:][::-1].copy()
        #     binary_secret = th.tensor(binary_secret, dtype=th.float)
        #     assert (control_vec[:secret_size] == binary_secret).all()
        # if len(self.global_control_idx) != 0:
        #     global_vec = globalx[random_agent,random_batch,random_step,self.global_control_idx]
        #     assert (control_vec[-len(self.global_control_idx):] == global_vec).all()
        # # end test control

        if self.multi_type == 'shared_weights':
            data = (
                d.reshape(1, d.shape[0]*d.shape[1], *d.shape[2:])
                for d in data
            )

        q = [
            model(*d)
            for model, *d in zip(self.models, *data)
        ]
        q = th.stack(q)

        if self.multi_type == 'shared_weights':
            q = q.reshape(*x.shape[:3], -1)

        if self.control_type == 'permute':
            permutations = self.permuations[secret]
            q = th.gather(q, -1, permutations)

        return q
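The control_type == 'permute' branch relies on torch.gather permuting only the last (action) dimension of the Q-values. A minimal standalone sketch with assumed shapes:

import torch

n_actions = 4
q = torch.randn(2, 5, n_actions)                       # [batch, step, action] Q-values

# one permutation of the action indices per (batch, step) entry
perm = torch.stack([torch.randperm(n_actions) for _ in range(2 * 5)]).view(2, 5, n_actions)

q_permuted = torch.gather(q, -1, perm)                 # q_permuted[b, t, i] == q[b, t, perm[b, t, i]]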
Пример #48
0
    def forward(self, cls_heads, reg_heads, center_heads, batch_positions):
        with torch.no_grad():
            device = cls_heads[0].device

            filter_scores, filter_score_classes, filter_reg_heads, filter_batch_positions = [], [], [], []
            for per_level_cls_head, per_level_reg_head, per_level_center_head, per_level_position in zip(
                    cls_heads, reg_heads, center_heads, batch_positions):
                per_level_cls_head = torch.sigmoid(per_level_cls_head)
                per_level_reg_head = torch.exp(per_level_reg_head)
                per_level_center_head = torch.sigmoid(per_level_center_head)

                per_level_cls_head = per_level_cls_head.view(
                    per_level_cls_head.shape[0], -1,
                    per_level_cls_head.shape[-1])
                per_level_reg_head = per_level_reg_head.view(
                    per_level_reg_head.shape[0], -1,
                    per_level_reg_head.shape[-1])
                per_level_center_head = per_level_center_head.view(
                    per_level_center_head.shape[0], -1,
                    per_level_center_head.shape[-1])
                per_level_position = per_level_position.view(
                    per_level_position.shape[0], -1,
                    per_level_position.shape[-1])

                scores, score_classes = torch.max(per_level_cls_head, dim=2)
                scores = torch.sqrt(scores * per_level_center_head.squeeze(-1))
                if scores.shape[1] >= self.top_n:
                    scores, indexes = torch.topk(scores,
                                                 self.top_n,
                                                 dim=1,
                                                 largest=True,
                                                 sorted=True)
                    score_classes = torch.gather(score_classes, 1, indexes)
                    per_level_reg_head = torch.gather(
                        per_level_reg_head, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 4))
                    per_level_position = torch.gather(
                        per_level_position, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 2))
                filter_scores.append(scores)
                filter_score_classes.append(score_classes)
                filter_reg_heads.append(per_level_reg_head)
                filter_batch_positions.append(per_level_position)

            filter_scores = torch.cat(filter_scores, axis=1)
            filter_score_classes = torch.cat(filter_score_classes, axis=1)
            filter_reg_heads = torch.cat(filter_reg_heads, axis=1)
            filter_batch_positions = torch.cat(filter_batch_positions, axis=1)

            batch_scores, batch_classes, batch_pred_bboxes = [], [], []
            for scores, score_classes, per_image_reg_preds, per_image_points_position in zip(
                    filter_scores, filter_score_classes, filter_reg_heads,
                    filter_batch_positions):
                pred_bboxes = self.snap_ltrb_reg_heads_to_x1_y1_x2_y2_bboxes(
                    per_image_reg_preds, per_image_points_position)

                score_classes = score_classes[
                    scores > self.min_score_threshold].float()
                pred_bboxes = pred_bboxes[
                    scores > self.min_score_threshold].float()
                scores = scores[scores > self.min_score_threshold].float()

                one_image_scores = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_classes = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_pred_bboxes = (-1) * torch.ones(
                    (self.max_detection_num, 4), device=device)

                if scores.shape[0] != 0:
                    # Sort boxes
                    sorted_scores, sorted_indexes = torch.sort(scores,
                                                               descending=True)
                    sorted_score_classes = score_classes[sorted_indexes]
                    sorted_pred_bboxes = pred_bboxes[sorted_indexes]

                    keep = nms(sorted_pred_bboxes, sorted_scores,
                               self.nms_threshold)
                    keep_scores = sorted_scores[keep]
                    keep_classes = sorted_score_classes[keep]
                    keep_pred_bboxes = sorted_pred_bboxes[keep]

                    final_detection_num = min(self.max_detection_num,
                                              keep_scores.shape[0])

                    one_image_scores[0:final_detection_num] = keep_scores[
                        0:final_detection_num]
                    one_image_classes[0:final_detection_num] = keep_classes[
                        0:final_detection_num]
                    one_image_pred_bboxes[
                        0:final_detection_num, :] = keep_pred_bboxes[
                            0:final_detection_num, :]

                one_image_scores = one_image_scores.unsqueeze(0)
                one_image_classes = one_image_classes.unsqueeze(0)
                one_image_pred_bboxes = one_image_pred_bboxes.unsqueeze(0)

                batch_scores.append(one_image_scores)
                batch_classes.append(one_image_classes)
                batch_pred_bboxes.append(one_image_pred_bboxes)

            batch_scores = torch.cat(batch_scores, axis=0)
            batch_classes = torch.cat(batch_classes, axis=0)
            batch_pred_bboxes = torch.cat(batch_pred_bboxes, axis=0)

            # batch_scores shape:[batch_size,max_detection_num]
            # batch_classes shape:[batch_size,max_detection_num]
            # batch_pred_bboxes shape[batch_size,max_detection_num,4]
            return batch_scores, batch_classes, batch_pred_bboxes
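The per-level filtering above is a topk on the scores followed by torch.gather calls that pull the matching class ids, box regressions and point positions with the same indexes. A minimal standalone sketch with assumed sizes:

import torch

B, num_points, top_n = 2, 100, 10
scores = torch.rand(B, num_points)
score_classes = torch.randint(0, 80, (B, num_points))
reg_heads = torch.rand(B, num_points, 4)          # l, t, r, b regressions
positions = torch.rand(B, num_points, 2)          # x, y point positions

scores, indexes = torch.topk(scores, top_n, dim=1, largest=True, sorted=True)
score_classes = torch.gather(score_classes, 1, indexes)
reg_heads = torch.gather(reg_heads, 1, indexes.unsqueeze(-1).repeat(1, 1, 4))
positions = torch.gather(positions, 1, indexes.unsqueeze(-1).repeat(1, 1, 2))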
Пример #49
0
    def decode_z_parallel(self, batch_size, u_input, u_hiddens, u_input_1hot, u_last_hidden, z_input,
                          turn_states, sample_type, decoder_type, qz_samples=None, qz_hiddens=None,
                          m_input=None, m_hiddens=None, m_input_1hot=None, mask_otlg=False):
        return_gmb = True if 'gumbel' in sample_type else False

        slot_num = len(self.reader.otlg.informable_slots)
        pv_z_pr = turn_states.get('pv_%s_pr'%decoder_type, None)
        pv_z_h = turn_states.get('pv_%s_h'%decoder_type, None)
        pv_z_id = turn_states.get('pv_%s_id'%decoder_type, None)
        z_prob, z_samples, gmb_samples = [], [], []
        log_pz = 0

        last_hidden = u_last_hidden[:-1].repeat(1, slot_num, 1)   # [1, B*|slot|, H]
        u_input = u_input.repeat(slot_num, 1)
        u_input_1hot = u_input_1hot.repeat(slot_num, 1, 1)
        # u_input_1hot = get_one_hot_input(u_input, self.vocab_size)

        u_hiddens = u_hiddens.repeat(slot_num, 1, 1)
        if decoder_type != 'pz':
            m_input = m_input.repeat(slot_num, 1)
            m_input_1hot = m_input_1hot.repeat(slot_num, 1, 1)
            m_hiddens = m_hiddens.repeat(slot_num, 1, 1)
            # m_input_1hot = get_one_hot_input(m_input, self.vocab_size)
        if pv_z_pr is not None:
            # pv_z_pr = pv_z_pr.transpose(1,0).reshape(self.z_length, -1, cfg.vocab_size).transpose(1,0)    # [B*|slot|, T, V]
            # pv_z_h = pv_z_h.transpose(1,0).reshape(self.z_length, -1, cfg.hidden_size).transpose(1,0)    # [B*|slot|, T, H]
            # pv_z_id = pv_z_id.transpose(1,0).reshape(self.z_length, -1).transpose(1,0)    # [B*|slot|, T]
            pv_z_prob, pv_z_hid, pv_z_idx = [], [], []
            for si, sn in enumerate(self.reader.otlg.informable_slots):
                pv_z_prob.append(pv_z_pr[:, si*self.z_length : (si+1)*self.z_length])
                pv_z_hid.append(pv_z_h[:, si*self.z_length : (si+1)*self.z_length])
                pv_z_idx.append(pv_z_id[:, si*self.z_length : (si+1)*self.z_length])
            pv_z_pr = torch.cat(pv_z_prob, dim=0)
            pv_z_h = torch.cat(pv_z_hid, dim=0)
            pv_z_id = torch.cat(pv_z_idx, dim=0)
            # print(pv_z_pr.size())
            # print(pv_z_id)

        emb_zt, z_eos_idx = [], []
        for si, sn in enumerate(self.reader.otlg.informable_slots):
            z_eos_idx.append(self.vocab.encode(self.z_eos_map[sn]))
            emb_zt.append(self.get_first_z_input(sn, batch_size, self.multi_domain))
        emb_zt = torch.cat(emb_zt, dim=0)

        if z_input is not None:
            z_input_cat = []
            for si, sn in enumerate(self.reader.otlg.informable_slots):
                z_input_cat.append(z_input[sn])
            z_input_cat = torch.cat(z_input_cat, dim=0)

        zero_vec = cuda_(torch.zeros(batch_size * slot_num, 1, self.hidden_size))
        selc_read_u = selc_read_m = selc_read_pv_z = zero_vec

        prev_zt = None
        for t in range(self.z_length):
            if decoder_type == 'pz':
                prob, last_hidden, gru_out, selc_read_u, selc_read_pv_z = \
                    self.pz_decoder(u_input, u_input_1hot, u_hiddens,
                                        pv_z_prob=pv_z_pr, pv_z_hidden=pv_z_h, pv_z_idx=pv_z_id,
                                        emb_zt=emb_zt,  last_hidden=last_hidden,
                                        selc_read_u=selc_read_u, selc_read_pv_z=selc_read_pv_z)
            else:

                prob, last_hidden, gru_out, selc_read_u, selc_read_m, selc_read_pv_z, gmb_samp = \
                    self.qz_decoder(u_input, u_input_1hot, u_hiddens, m_input, m_input_1hot, m_hiddens,
                                            pv_z_prob=pv_z_pr, pv_z_hidden=pv_z_h, pv_z_idx=pv_z_id,
                                            emb_zt=emb_zt, last_hidden=last_hidden, selc_read_u=selc_read_u,
                                            selc_read_m=selc_read_m, selc_read_pv_z=selc_read_pv_z,
                                            temp=self.gumbel_temp, return_gmb=return_gmb)
            if mask_otlg:
                prob = self.mask_probs(prob, tokens_allow=self.reader.slot_value_mask[sn])

            if sample_type == 'supervised':
                # zt = z_input[sn][:, t]
                zt = z_input_cat[:, t]
            elif sample_type == 'top1':
                zt = torch.topk(prob, 1)[1]
            elif sample_type == 'topk':
                topk_probs, topk_words = torch.topk(prob.squeeze(1), cfg.topk_num)
                widx = torch.multinomial(topk_probs, 1, replacement=True)
                zt = torch.gather(topk_words, 1, widx)      #[B]
            elif sample_type == 'posterior':
                zt = qz_samples[:, si * self.z_length + t]
            elif 'gumbel' in sample_type:
                zt = torch.argmax(gmb_samp, dim=1)   #[B]
                emb_zt = torch.matmul(gmb_samp, self.embedding.weight).unsqueeze(1) # [B, 1, H]
                zt, prev_zt, gmb_samp = self.mask_samples(zt, prev_zt, batch_size, z_eos_idx, gmb_samp, True)
                gmb_samples.append(gmb_samp)

            if 'gumbel' not in sample_type:
                emb_zt = self.embedding(zt.view(-1, 1))
                prob_zt = torch.gather(prob, 1, zt.view(-1, 1)).squeeze(1) #[B, 1]
                log_prob_zt = torch.log(prob_zt)
                zt, prev_zt, log_prob_zt = self.mask_samples(zt, prev_zt, batch_size, z_eos_idx, log_prob_zt)
                log_pz += log_prob_zt
                z_samples.append(zt.view(-1))
                z_prob.append(prob)

        z_prob = torch.stack(z_prob, dim=1)  # [B*|slot|,Tz,V]
        z_samples = torch.stack(z_samples, dim=1)  # [B*|slot|,Tz]
        z_prob_col, z_samples_col = [], []
        for i in range(slot_num):
            z_prob_col.append(z_prob[i*batch_size : (i+1)*batch_size])
            z_samples_col.append(z_samples[i*batch_size : (i+1)*batch_size])
        z_prob = torch.cat(z_prob_col, dim=1)  # [B,Tz*|slot|,V]
        z_samples = torch.cat(z_samples_col, dim=1)  # [B,Tz*|slot|]

        #    Tz*|slot|, B

        if sample_type == 'posterior':
            z_samples, z_hiddens = qz_samples, qz_hiddens
        elif 'gumbel' not in sample_type:
            z_hiddens, z_last_hidden = self.z_encoder(z_samples, input_type='index')
        else:
            z_gumbel = torch.stack(gmb_samples, dim=1)   # [B, Tz, V]
            z_gumbel = torch.matmul(z_gumbel, self.embedding.weight)     # [B, Tz, E]
            z_hiddens, z_last_hidden = self.z_encoder(z_gumbel, input_type='embedding')

        retain = self.prev_z_continuous
        turn_states['pv_%s_h'%decoder_type] = z_hiddens if retain else z_hiddens.detach()
        turn_states['pv_%s_pr'%decoder_type] = z_prob if retain else z_prob.detach()
        turn_states['pv_%s_id'%decoder_type] = z_samples if retain else z_samples.detach()

        return z_prob, z_samples, z_hiddens, turn_states, log_pz
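decode_z_parallel decodes all informable slots in parallel by stacking them along the batch axis, then regroups the results along the time axis at the end. A minimal standalone sketch of that regrouping, with assumed sizes:

import torch

B, slot_num, Tz, V = 2, 3, 4, 7
z_prob = torch.rand(B * slot_num, Tz, V)          # decoder output, slots stacked on the batch dim

# split per slot and concatenate along the time dim: [B*|slot|, Tz, V] -> [B, Tz*|slot|, V]
z_prob_col = [z_prob[i * B:(i + 1) * B] for i in range(slot_num)]
z_prob_regrouped = torch.cat(z_prob_col, dim=1)

assert z_prob_regrouped.shape == (B, Tz * slot_num, V)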
Пример #50
0
    def forward(self, feats, SISMs):
        N, C, H, W = feats.shape
        HW = H * W

        # Resize SISMs to the same size as the input feats.
        SISMs = resize(SISMs, [H, W])  # shape=[N, 1, H, W]

        # NFs: L2-normalized features.
        NFs = F.normalize(feats, dim=1)  # shape=[N, C, H, W]

        def CFM(SIVs, NFs):
            # Compute correlation maps [Figure 4] between SIVs and pixel-wise feature vectors in NFs by inner product.
            # We implement this process by ``F.conv2d()'', which takes SIVs as 1*1 kernels to convolve NFs.
            correlation_maps = F.conv2d(NFs, weight=SIVs)  # shape=[N, N, H, W]

            # Vectorize and normalize correlation maps.
            correlation_maps = F.normalize(correlation_maps.reshape(N, N, HW),
                                           dim=2)  # shape=[N, N, HW]

            # Compute the weight vectors [Equation 2].
            correlation_matrix = torch.matmul(correlation_maps,
                                              correlation_maps.permute(
                                                  0, 2, 1))  # shape=[N, N, N]
            weight_vectors = correlation_matrix.sum(dim=2).softmax(
                dim=1)  # shape=[N, N]

            # Fuse correlation maps with the weight vectors to build co-salient attention (CSA) maps.
            CSA_maps = torch.sum(correlation_maps *
                                 weight_vectors.view(N, N, 1),
                                 dim=1)  # shape=[N, HW]

            # Max-min normalize CSA maps.
            min_value = torch.min(CSA_maps, dim=1, keepdim=True)[0]
            max_value = torch.max(CSA_maps, dim=1, keepdim=True)[0]
            CSA_maps = (CSA_maps - min_value) / (max_value - min_value + 1e-12
                                                 )  # shape=[N, HW]
            CSA_maps = CSA_maps.view(N, 1, H, W)  # shape=[N, 1, H, W]
            return CSA_maps

        def get_SCFs(NFs):
            NFs = NFs.view(N, C, HW)  # shape=[N, C, HW]
            SCFs = torch.matmul(NFs.permute(0, 2, 1),
                                NFs).view(N, -1, H, W)  # shape=[N, HW, H, W]
            return SCFs

        # Compute SIVs [Section 3.2, Equation 1].
        SIVs = F.normalize((NFs * SISMs).mean(dim=3).mean(dim=2),
                           dim=1).view(N, C, 1, 1)  # shape=[N, C, 1, 1]

        # Compute co-salient attention (CSA) maps [Section 3.3].
        CSA_maps = CFM(SIVs, NFs)  # shape=[N, 1, H, W]

        # Compute self-correlation features (SCFs) [Section 3.4].
        SCFs = get_SCFs(NFs)  # shape=[N, HW, H, W]

        # Rearrange the channel order of SCFs to obtain RSCFs [Section 3.4].
        evidence = CSA_maps.view(N, HW)  # shape=[N, HW]
        indices = torch.argsort(evidence, dim=1, descending=True).view(
            N, HW, 1, 1).repeat(1, 1, H, W)  # shape=[N, HW, H, W]
        RSCFs = torch.gather(SCFs, dim=1, index=indices)  # shape=[N, HW, H, W]
        cosal_feat = self.conv(RSCFs * CSA_maps)  # shape=[N, 128, H, W]
        return cosal_feat
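The RSCF step above sorts the HW channels of the self-correlation features by the co-salient attention evidence: torch.argsort gives a per-image channel order, which is broadcast to [N, HW, H, W] and fed to torch.gather. A minimal standalone sketch with assumed sizes:

import torch

N, H, W = 2, 4, 4
HW = H * W
SCFs = torch.rand(N, HW, H, W)          # self-correlation features
evidence = torch.rand(N, HW)            # per-channel evidence (flattened CSA maps)

order = torch.argsort(evidence, dim=1, descending=True)     # [N, HW]
indices = order.view(N, HW, 1, 1).repeat(1, 1, H, W)        # [N, HW, H, W]
RSCFs = torch.gather(SCFs, dim=1, index=indices)            # channels reordered by evidence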
Пример #51
0
    def beam_decode(self, initial_state, encoder_hidden_states,
                    code_hidden_states, old_nl_hidden_states, masks,
                    max_out_len, batch_data, code_masks, old_nl_masks, device):
        """Beam search. Generates the top K candidate predictions."""
        batch_size = initial_state.shape[0]
        decoded_batch = [list() for _ in range(batch_size)]
        decoded_batch_scores = np.zeros([batch_size, BEAM_SIZE])

        decoder_input = torch.tensor(
            [[self.embedding_store.get_nl_id(START)]] * batch_size,
            device=device)
        decoder_input = decoder_input.unsqueeze(1)
        decoder_state = initial_state.unsqueeze(1).expand(
            -1, decoder_input.shape[1], -1).reshape(-1,
                                                    initial_state.shape[-1])

        beam_scores = torch.ones([batch_size, 1],
                                 dtype=torch.float32,
                                 device=device)
        beam_status = torch.zeros([batch_size, 1],
                                  dtype=torch.uint8,
                                  device=device)
        beam_predicted_ids = torch.full([batch_size, 1, max_out_len],
                                        self.embedding_store.get_end_id(),
                                        dtype=torch.int64,
                                        device=device)

        for i in range(max_out_len):
            beam_size = decoder_input.shape[1]
            if beam_status[:, 0].sum() == batch_size:
                break

            tiled_encoder_states = encoder_hidden_states.unsqueeze(1).expand(
                -1, beam_size, -1, -1)
            tiled_masks = masks.unsqueeze(1).expand(-1, beam_size, -1, -1)
            tiled_code_hidden_states = code_hidden_states.unsqueeze(1).expand(
                -1, beam_size, -1, -1)
            tiled_code_masks = code_masks.unsqueeze(1).expand(
                -1, beam_size, -1, -1)
            tiled_old_nl_hidden_states = old_nl_hidden_states.unsqueeze(
                1).expand(-1, beam_size, -1, -1)
            tiled_old_nl_masks = old_nl_masks.unsqueeze(1).expand(
                -1, beam_size, -1, -1)

            flat_decoder_input = decoder_input.reshape(-1,
                                                       decoder_input.shape[-1])
            flat_encoder_states = tiled_encoder_states.reshape(
                -1, tiled_encoder_states.shape[-2],
                tiled_encoder_states.shape[-1])
            flat_masks = tiled_masks.reshape(-1, tiled_masks.shape[-2],
                                             tiled_masks.shape[-1])
            flat_code_hidden_states = tiled_code_hidden_states.reshape(
                -1, tiled_code_hidden_states.shape[-2],
                tiled_code_hidden_states.shape[-1])
            flat_code_masks = tiled_code_masks.reshape(
                -1, tiled_code_masks.shape[-2], tiled_code_masks.shape[-1])
            flat_old_nl_hidden_states = tiled_old_nl_hidden_states.reshape(
                -1, tiled_old_nl_hidden_states.shape[-2],
                tiled_old_nl_hidden_states.shape[-1])
            flat_old_nl_masks = tiled_old_nl_masks.reshape(
                -1, tiled_old_nl_masks.shape[-2], tiled_old_nl_masks.shape[-1])

            decoder_input_embeddings = self.embedding_store.get_nl_embeddings(
                flat_decoder_input)
            decoder_attention_states, flat_decoder_state, generation_logprobs, copy_logprobs = self.decode(
                decoder_state, decoder_input_embeddings, flat_encoder_states,
                flat_code_hidden_states, flat_old_nl_hidden_states, flat_masks,
                flat_code_masks, flat_old_nl_masks)

            generation_logprobs = generation_logprobs.squeeze(1)
            copy_logprobs = copy_logprobs.squeeze(1)

            generation_logprobs = generation_logprobs.reshape(
                batch_size, beam_size, generation_logprobs.shape[-1])
            copy_logprobs = copy_logprobs.reshape(batch_size, beam_size,
                                                  copy_logprobs.shape[-1])

            prob_scores = torch.zeros([
                batch_size, beam_size,
                generation_logprobs.shape[-1] + copy_logprobs.shape[-1]
            ],
                                      dtype=torch.float32,
                                      device=device)
            prob_scores[:, :, :generation_logprobs.shape[-1]] = torch.exp(
                generation_logprobs)

            # Factoring in the copy scores
            expanded_token_ids = batch_data.input_ids.unsqueeze(1).expand(
                -1, beam_size, -1)
            prob_scores += scatter_add(src=torch.exp(copy_logprobs),
                                       index=expanded_token_ids,
                                       out=torch.zeros_like(prob_scores))

            top_scores_per_beam, top_indices_per_beam = torch.topk(prob_scores,
                                                                   k=BEAM_SIZE,
                                                                   dim=-1)

            updated_scores = torch.einsum('eb,ebm->ebm', beam_scores,
                                          top_scores_per_beam)
            retained_scores = beam_scores.unsqueeze(-1).expand(
                -1, -1, top_scores_per_beam.shape[-1])

            # Trying to keep at most one ray corresponding to completed beams
            end_mask = (torch.arange(beam_size) == 0).type(
                torch.float32).to(device)
            end_scores = torch.einsum('b,ebm->ebm', end_mask, retained_scores)

            possible_next_scores = torch.where(
                beam_status.unsqueeze(-1) == 1, end_scores, updated_scores)
            possible_next_status = torch.where(
                top_indices_per_beam == self.embedding_store.get_end_id(),
                torch.ones(
                    [batch_size, beam_size, top_scores_per_beam.shape[-1]],
                    dtype=torch.uint8,
                    device=device),
                beam_status.unsqueeze(-1).expand(
                    -1, -1, top_scores_per_beam.shape[-1]))

            possible_beam_predicted_ids = beam_predicted_ids.unsqueeze(
                2).expand(-1, -1, top_scores_per_beam.shape[-1], -1)
            pool_next_scores = possible_next_scores.reshape(batch_size, -1)
            pool_next_status = possible_next_status.reshape(batch_size, -1)
            pool_next_ids = top_indices_per_beam.reshape(batch_size, -1)
            pool_predicted_ids = possible_beam_predicted_ids.reshape(
                batch_size, -1, beam_predicted_ids.shape[-1])

            possible_decoder_state = flat_decoder_state.reshape(
                batch_size, beam_size, flat_decoder_state.shape[-1])
            possible_decoder_state = possible_decoder_state.unsqueeze(
                2).expand(-1, -1, top_scores_per_beam.shape[-1], -1)
            pool_decoder_state = possible_decoder_state.reshape(
                batch_size, -1, possible_decoder_state.shape[-1])

            top_scores, top_indices = torch.topk(pool_next_scores,
                                                 k=BEAM_SIZE,
                                                 dim=-1)
            next_step_ids = torch.gather(pool_next_ids, -1, top_indices)

            decoder_state = torch.gather(
                pool_decoder_state, 1,
                top_indices.unsqueeze(-1).expand(-1, -1,
                                                 pool_decoder_state.shape[-1]))
            decoder_state = decoder_state.reshape(-1, decoder_state.shape[-1])
            beam_status = torch.gather(pool_next_status, -1, top_indices)
            beam_scores = torch.gather(pool_next_scores, -1, top_indices)

            end_tags = torch.full_like(next_step_ids,
                                       self.embedding_store.get_end_id())
            next_step_ids = torch.where(beam_status == 1, end_tags,
                                        next_step_ids)

            beam_predicted_ids = torch.gather(
                pool_predicted_ids, 1,
                top_indices.unsqueeze(-1).expand(-1, -1,
                                                 pool_predicted_ids.shape[-1]))
            beam_predicted_ids[:, :, i] = next_step_ids

            unks = torch.full_like(
                next_step_ids,
                self.embedding_store.get_nl_id(Vocabulary.get_unk()))
            decoder_input = torch.where(
                next_step_ids < len(self.embedding_store.nl_vocabulary),
                next_step_ids, unks).unsqueeze(-1)

        return beam_predicted_ids, beam_scores
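After pooling beam_size * BEAM_SIZE candidates per example, beam_decode keeps the top BEAM_SIZE with one topk and reuses the same indices in several torch.gather calls, expanding them when the gathered tensor has a trailing feature dimension. A minimal standalone sketch with assumed sizes:

import torch

batch_size, pool, beam, hidden = 2, 12, 3, 8
pool_next_scores = torch.rand(batch_size, pool)
pool_next_ids = torch.randint(0, 50, (batch_size, pool))
pool_decoder_state = torch.rand(batch_size, pool, hidden)

top_scores, top_indices = torch.topk(pool_next_scores, k=beam, dim=-1)
next_step_ids = torch.gather(pool_next_ids, -1, top_indices)             # [batch, beam]
decoder_state = torch.gather(
    pool_decoder_state, 1,
    top_indices.unsqueeze(-1).expand(-1, -1, hidden))                    # [batch, beam, hidden]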
Пример #52
0
    def forward(self, im_data, gt_boxes, im_info):
        batch_size = im_data.size(0)
        im_info = im_info.data


        if gt_boxes is not None:
            gt_boxes = gt_boxes.data

        # feed image data to base model to obtain base feature map
        base_feat = self.RCNN_base(im_data)

        # feed base feature map to RPN to obtain rois
        rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes)

        # if it is training phase, then use ground truth bboxes for refining
        if self.training:
            roi_data = self.RCNN_proposal_target(rois, gt_boxes)
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data

            rois_label = Variable(rois_label.view(-1).long())
            rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
            rois_inside_ws = Variable(rois_inside_ws.view(-1, rois_inside_ws.size(2)))
            rois_outside_ws = Variable(rois_outside_ws.view(-1, rois_outside_ws.size(2)))
        else:
            rois_label = None
            rois_target = None
            rois_inside_ws = None
            rois_outside_ws = None
            rpn_loss_cls = 0
            rpn_loss_bbox = 0

        rois = Variable(rois)
        # do roi pooling based on predicted rois

        if cfg.pooling_mode == 'align':
            # pooled_feat = self.RCNN_roi_align(feature_map, rois.view(-1, 5))
            pooled_feat = roi_align(base_feat, rois.view(-1, 5), (cfg.pool_size, cfg.pool_size), 1.0/16)
        elif cfg.pooling_mode == 'pool':
            #pooled_feat = self.RCNN_roi_pool(feature_map, rois.view(-1, 5))
            pooled_feat = roi_pool(base_feat, rois.view(-1, 5), (cfg.pool_size, cfg.pool_size), 1.0/16)

        # feed pooled features to top model
        pooled_feat = self._head_to_tail(pooled_feat)

        # compute bbox offset
        bbox_pred = self.RCNN_bbox_pred(pooled_feat)
        if self.training and not self.class_agnostic:
            # select the corresponding columns according to roi labels
            bbox_pred_view = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(bbox_pred_view, 1, rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)

        # compute object classification probability
        cls_score = self.RCNN_cls_score(pooled_feat)
        cls_prob = F.softmax(cls_score, 1)

        RCNN_loss_cls = 0
        RCNN_loss_bbox = 0

        if self.training:
            # classification loss
            RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)

            # bounding box regression L1 loss
            RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)


        cls_prob = cls_prob.view(batch_size, rois.size(1), -1)
        bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)

        return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label
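With class-specific box regression, bbox_pred holds 4 offsets per class and torch.gather selects the offsets of each RoI's assigned label. A minimal standalone sketch with assumed sizes:

import torch

num_rois, num_classes = 6, 5
bbox_pred = torch.rand(num_rois, num_classes * 4)           # 4 offsets per class
rois_label = torch.randint(0, num_classes, (num_rois,))     # assigned class per RoI

bbox_pred_view = bbox_pred.view(num_rois, num_classes, 4)
index = rois_label.view(num_rois, 1, 1).expand(num_rois, 1, 4)
bbox_pred_select = torch.gather(bbox_pred_view, 1, index)   # [num_rois, 1, 4]
bbox_pred = bbox_pred_select.squeeze(1)                     # [num_rois, 4]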
Пример #53
0
    def forward(self, x, loc, src, H, W, N_grid):
        if self.pre_conv is not None:
            x_map = token2map(x, loc, [H, W], self.inter_kernel,
                              self.inter_sigma)
            x_map = self.pre_conv(x_map)
            x1 = map2token(x_map, loc)
            x = self.pre_conv2(x)
            x = x + x1

        B, N, C = x.shape
        x = self.conf_norm(x)

        # confidence based sampling
        conf = self.conf(x)

        # down sample
        if self.sample_ratio < 1:
            sample_num = max(math.ceil((N - N_grid) * self.sample_ratio), 0)
            x_grid, loc_grid = x[:, :N_grid, :], loc[:, :N_grid, :]
            x_ada, loc_ada = x[:, N_grid:, :], loc[:, N_grid:, :]
            conf_ada = conf[:, N_grid:, :]

            index_down = gumble_top_k(conf_ada, sample_num, dim=1, T=1)
            loc_down = torch.gather(loc_ada, 1,
                                    index_down.expand([B, sample_num, 2]))
            x_down = torch.gather(x_ada, 1,
                                  index_down.expand([B, sample_num, C]))
            x_down = torch.cat([x_grid, x_down], dim=1)
            loc_down = torch.cat([loc_grid, loc_down], dim=1)
        else:
            x_down, loc_down = x, loc

        # extra points for low-level feature
        if self.extra_ratio > 0:
            # high res grid
            conf_map = token2map(conf, loc, [H, W], self.inter_kernel,
                                 self.inter_sigma)
            loc_extra = get_grid_loc(B,
                                     self.HR_res[0],
                                     self.HR_res[1],
                                     device=x.device)
            conf_extra = map2token(conf_map, loc_extra)

            extra_num = int(N * self.extra_ratio)
            index_extra = gumble_top_k(conf_extra, extra_num, dim=1, T=1)
            loc_extra = torch.gather(loc_extra, 1,
                                     index_extra.expand([B, extra_num, 2]))
            x_extra = inter_points(x, loc, loc_extra)
            conf_extra = inter_points(conf, loc, loc_extra)

            if self.use_local:
                local = extract_local_feature(src, loc_extra,
                                              self.local_kernel)
                local = local.flatten(2)
                local = self.local_conv1(local, x_extra)
                local = self.local_act1(local)
                local = self.local_conv2(local, x_extra)
                local = self.local_act2(local)
                x_extra = x_extra + local

            x = torch.cat([x, x_extra], dim=1)
            loc = torch.cat([loc, loc_extra], dim=1)
            conf = torch.cat([conf, conf_extra], dim=1)

        # attention block
        x_down = self.norm1(x_down)
        x_down = x_down + self.drop_path(self.attn(x_down, x, loc, H, W, conf))

        x_down = self.norm2(x_down)
        kernel_size = self.attn.sr_ratio + 1
        if self.sample_ratio <= 0.25:
            H, W = H // 2, W // 2
        x_down = x_down + self.drop_path(
            self.mlp(x_down, loc_down, H, W, kernel_size, 2))

        if vis and self.extra_ratio > 0:
            import matplotlib.pyplot as plt
            IMAGENET_DEFAULT_MEAN = torch.tensor([0.485, 0.456, 0.406],
                                                 device=src.device)[None, :,
                                                                    None, None]
            IMAGENET_DEFAULT_STD = torch.tensor([0.229, 0.224, 0.225],
                                                device=src.device)[None, :,
                                                                   None, None]
            src = src * IMAGENET_DEFAULT_STD + IMAGENET_DEFAULT_MEAN
            # for i in range(x.shape[0]):
            for i in range(1):
                img = src[i].permute(1, 2, 0).detach().cpu()

                ax = plt.subplot(1, 3, 1)
                ax.clear()
                conf_map = token2map(conf, loc, [H, W], self.inter_kernel,
                                     self.inter_sigma)
                conf_map = F.interpolate(conf_map,
                                         self.HR_res,
                                         mode='bilinear')
                # conf_map = token2map(conf, loc, self.HR_res, 1 + (self.inter_kernel-1) * self.HR_res[0] // H, self.inter_sigma)

                ax.imshow(conf_map[i, 0].detach().cpu())

                ax = plt.subplot(1, 3, 2)
                ax.clear()
                ax.imshow(img, extent=[0, 1, 0, 1])
                loc_show = loc
                loc_show = (loc_show + 1) * 0.5
                loc_grid = loc_show[i, :N_grid].detach().cpu().numpy()
                ax.scatter(loc_grid[:, 0], 1 - loc_grid[:, 1], c='blue', s=0.5)
                loc_ada = loc_show[i, N_grid:].detach().cpu().numpy()
                ax.scatter(loc_ada[:, 0], 1 - loc_ada[:, 1], c='red', s=0.5)

                ax = plt.subplot(1, 3, 3)
                ax.clear()
                ax.imshow(img, extent=[0, 1, 0, 1])
                loc_show = loc_down
                loc_show = (loc_show + 1) * 0.5
                loc_grid = loc_show[i, :N_grid].detach().cpu().numpy()
                ax.scatter(loc_grid[:, 0], 1 - loc_grid[:, 1], c='blue', s=0.5)
                loc_ada = loc_show[i, N_grid:].detach().cpu().numpy()
                ax.scatter(loc_ada[:, 0], 1 - loc_ada[:, 1], c='red', s=0.5)

        return x_down, loc_down
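The confidence-based down-sampling keeps the indices of the selected tokens and expands them over the trailing feature dimension, so one torch.gather picks the 2-D locations and another the features of the kept tokens. A minimal standalone sketch with assumed shapes (a plain topk stands in for the gumble_top_k helper used above):

import torch

B, N, C, sample_num = 2, 16, 8, 6
x = torch.rand(B, N, C)                 # token features
loc = torch.rand(B, N, 2)               # token locations
conf = torch.rand(B, N, 1)              # per-token confidence

index_down = torch.topk(conf, sample_num, dim=1)[1]                  # [B, sample_num, 1]
loc_down = torch.gather(loc, 1, index_down.expand(B, sample_num, 2))
x_down = torch.gather(x, 1, index_down.expand(B, sample_num, C))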
Пример #54
0
    def forward(self, cls_heads, reg_heads, batch_anchors):
        device = cls_heads[0].device
        with torch.no_grad():
            filter_scores, filter_score_classes, filter_reg_heads, filter_batch_anchors = [], [], [], []
            for per_level_cls_head, per_level_reg_head, per_level_anchor in zip(
                    cls_heads, reg_heads, batch_anchors):
                scores, score_classes = torch.max(per_level_cls_head, dim=2)
                if scores.shape[1] >= self.top_n:
                    scores, indexes = torch.topk(scores,
                                                 self.top_n,
                                                 dim=1,
                                                 largest=True,
                                                 sorted=True)
                    score_classes = torch.gather(score_classes, 1, indexes)
                    per_level_reg_head = torch.gather(
                        per_level_reg_head, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 4))
                    per_level_anchor = torch.gather(
                        per_level_anchor, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 4))

                filter_scores.append(scores)
                filter_score_classes.append(score_classes)
                filter_reg_heads.append(per_level_reg_head)
                filter_batch_anchors.append(per_level_anchor)

            filter_scores = torch.cat(filter_scores, axis=1)
            filter_score_classes = torch.cat(filter_score_classes, axis=1)
            filter_reg_heads = torch.cat(filter_reg_heads, axis=1)
            filter_batch_anchors = torch.cat(filter_batch_anchors, axis=1)

            batch_scores, batch_classes, batch_pred_bboxes = [], [], []
            for per_image_scores, per_image_score_classes, per_image_reg_heads, per_image_anchors in zip(
                    filter_scores, filter_score_classes, filter_reg_heads,
                    filter_batch_anchors):
                pred_bboxes = self.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
                    per_image_reg_heads, per_image_anchors)
                score_classes = per_image_score_classes[
                    per_image_scores > self.min_score_threshold].float()
                pred_bboxes = pred_bboxes[
                    per_image_scores > self.min_score_threshold].float()
                scores = per_image_scores[
                    per_image_scores > self.min_score_threshold].float()

                one_image_scores = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_classes = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_pred_bboxes = (-1) * torch.ones(
                    (self.max_detection_num, 4), device=device)

                if scores.shape[0] != 0:
                    # Sort boxes
                    sorted_scores, sorted_indexes = torch.sort(scores,
                                                               descending=True)
                    sorted_score_classes = score_classes[sorted_indexes]
                    sorted_pred_bboxes = pred_bboxes[sorted_indexes]

                    keep = nms(sorted_pred_bboxes, sorted_scores,
                               self.nms_threshold)
                    keep_scores = sorted_scores[keep]
                    keep_classes = sorted_score_classes[keep]
                    keep_pred_bboxes = sorted_pred_bboxes[keep]

                    final_detection_num = min(self.max_detection_num,
                                              keep_scores.shape[0])

                    one_image_scores[0:final_detection_num] = keep_scores[
                        0:final_detection_num]
                    one_image_classes[0:final_detection_num] = keep_classes[
                        0:final_detection_num]
                    one_image_pred_bboxes[
                        0:final_detection_num, :] = keep_pred_bboxes[
                            0:final_detection_num, :]

                one_image_scores = one_image_scores.unsqueeze(0)
                one_image_classes = one_image_classes.unsqueeze(0)
                one_image_pred_bboxes = one_image_pred_bboxes.unsqueeze(0)

                batch_scores.append(one_image_scores)
                batch_classes.append(one_image_classes)
                batch_pred_bboxes.append(one_image_pred_bboxes)

            batch_scores = torch.cat(batch_scores, dim=0)
            batch_classes = torch.cat(batch_classes, dim=0)
            batch_pred_bboxes = torch.cat(batch_pred_bboxes, dim=0)

            # batch_scores shape:[batch_size,max_detection_num]
            # batch_classes shape:[batch_size,max_detection_num]
            # batch_pred_bboxes shape[batch_size,max_detection_num,4]
            return batch_scores, batch_classes, batch_pred_bboxes
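
A minimal stand-alone sketch of the per-level top-n filtering pattern above, where torch.topk selects anchors and torch.gather keeps the matching classes, regression heads and anchors. All shapes, names and the top_n value here are illustrative assumptions, not the original module's configuration:

import torch

batch_size, num_anchors, num_classes, top_n = 2, 100, 20, 10
cls_head = torch.rand(batch_size, num_anchors, num_classes)
reg_head = torch.rand(batch_size, num_anchors, 4)
anchors = torch.rand(batch_size, num_anchors, 4)

# per-anchor best score and its class id
scores, score_classes = torch.max(cls_head, dim=2)                # [B, A]

# keep the top_n anchors per image and gather the tensors that belong to them
scores, indexes = torch.topk(scores, top_n, dim=1, largest=True, sorted=True)
score_classes = torch.gather(score_classes, 1, indexes)           # [B, top_n]
reg_head = torch.gather(reg_head, 1, indexes.unsqueeze(-1).repeat(1, 1, 4))
anchors = torch.gather(anchors, 1, indexes.unsqueeze(-1).repeat(1, 1, 4))

print(scores.shape, score_classes.shape, reg_head.shape, anchors.shape)
# torch.Size([2, 10]) torch.Size([2, 10]) torch.Size([2, 10, 4]) torch.Size([2, 10, 4])
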
Example #55
    def forward(
            self,
            scene_inds,
            txt_feats,
            txt_masks,
            hiddens,
            src_feats,
            src_masks,  # in case we'd like to try using image contexts as input 
            tgt_feats,
            tgt_masks,
            sample_mode):
        """
        Args:
            - **scene_inds** (bsize, )
            - **txt_feats**  (bsize, nturns, n_feature_dim)
            - **txt_masks**  (bsize, nturns)
            - **hiddens**    (num_layers, bsize, n_feature_dim)
            - **src_feats**  (bsize, nturns, nregions, n_feature_dim)
            - **src_masks**  (bsize, nturns, nregions)
            - **tgt_feats**  (bsize, nturns, nregions, n_feature_dim)
            - **tgt_masks**  (bsize, nturns, nregions)
            - **sample_mode**
                0: top1
                1: multinomial sampling
                2: circular
                3: fixed indices
                4: random
                5: rollout greedy search
        Returns
            - **output_feats** (bsize, nturns, (ninsts), n_feature_dim)
            - **next_hiddens** (num_layers, bsize, n_feature_dim)
            - **sample_logits** (bsize, nturns)
            - **sample_indices** (bsize, nturns)
        """
        input_feats = txt_feats

        if not self.cfg.use_txt_context:
            #TODO: do NOT use updater?
            bsize, nturns, fsize = input_feats.size()
            output_feats = input_feats  # For paragraph model
            # output_feats = self.updater(input_feats, hiddens.view(bsize, 1, fsize).expand(bsize, nturns, fsize))
            return output_feats, None, None, None
        else:
            if self.cfg.instance_dim < 2:
                # one dimensional context
                self.updater.flatten_parameters()
                output_feats, next_hiddens = self.updater(
                    input_feats, hiddens.unsqueeze(0))
                return output_feats, next_hiddens, None, None
            else:
                # two dimensional context
                bsize, nturns, input_dim = input_feats.size()
                bsize, ninsts, hidden_dim = hiddens.size()
                current_hiddens = hiddens
                output_feats, sample_logits, sample_indices, sample_rewards = [], [], [], []
                for i in range(nturns):
                    #######################################################
                    # search for the instance indices
                    #######################################################
                    query_feats = input_feats[:, i].unsqueeze(1)
                    if self.cfg.rl_finetune > 0:
                        #######################################################
                        # Learnable policy
                        #######################################################
                        if sample_mode < 2:
                            ###################################################
                            ## Inference mode
                            ###################################################
                            if i < self.cfg.instance_dim:
                                instance_inds = (
                                    (i % self.cfg.instance_dim) *
                                    query_feats.new_ones(bsize)).long()
                                logits = instance_inds.new_ones(
                                    bsize, self.cfg.instance_dim).float()
                            else:
                                instance_inds, logits = self.policy(
                                    query_feats.detach(),
                                    current_hiddens.detach(), sample_mode)
                        elif sample_mode == 5:
                            ###################################################
                            # rollout greedy search
                            ###################################################
                            instance_inds, rewards = \
                                self.rollout_search(
                                    i, nturns,
                                    input_feats[:, i:].view(bsize, nturns-i, input_dim).detach(),
                                    txt_masks[:, i:].view(bsize, nturns-i).detach(),
                                    scene_inds,
                                    current_hiddens.detach(),
                                    tgt_feats.detach(), tgt_masks.detach(),
                                    sample_mode=1)
                            ########################################################################
                            # TODO: whether to backprop more
                            ########################################################################
                            _, logits = self.policy(query_feats.detach(),
                                                    current_hiddens.detach(),
                                                    1)
                            sample_rewards.append(rewards)
                    else:
                        #######################################################
                        # Fixed policies
                        #######################################################
                        if sample_mode == 2:
                            instance_inds = (
                                (i % self.cfg.instance_dim) *
                                query_feats.new_ones(bsize)).long()
                        elif sample_mode == 3:
                            instance_inds = (
                                query_feats.new_zeros(bsize)).long()
                        elif sample_mode == 4:
                            instance_inds = torch.randint(
                                0, self.cfg.instance_dim,
                                size=(bsize, )).long()
                            if self.cfg.cuda:
                                instance_inds = instance_inds.cuda()
                        _, logits = self.policy(query_feats.detach(),
                                                current_hiddens.detach(), 1)

                    sample_indices.append(instance_inds)
                    sample_logits.append(logits)
                    #######################################################
                    # update the hidden states using the instance indices
                    #######################################################
                    instance_inds = instance_inds.view(bsize, 1, 1).expand(
                        bsize, 1, hidden_dim)
                    sample_hiddens = torch.gather(current_hiddens, 1,
                                                  instance_inds)
                    h = self.updater(query_feats, sample_hiddens)
                    next_hiddens = current_hiddens.clone()
                    next_hiddens.scatter_(dim=1, index=instance_inds, src=h)
                    output_feats.append(next_hiddens)
                    current_hiddens = next_hiddens
                sample_indices = torch.stack(sample_indices, 1)
                sample_logits = torch.stack(sample_logits, 1)
                output_feats = torch.stack(output_feats, 1)
                return output_feats, next_hiddens, sample_logits, sample_indices
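
A small sketch of the gather/scatter bookkeeping above, where one instance slot of the hidden state is read with torch.gather, updated, and written back with scatter_. The sizes and the toy update in place of self.updater are assumptions for illustration only:

import torch

bsize, ninsts, hidden_dim = 4, 3, 8
current_hiddens = torch.randn(bsize, ninsts, hidden_dim)
query_feats = torch.randn(bsize, 1, hidden_dim)

# one chosen instance index per sample, expanded to address a full hidden vector
instance_inds = torch.randint(0, ninsts, (bsize,))
index = instance_inds.view(bsize, 1, 1).expand(bsize, 1, hidden_dim)

# read the selected slot: [bsize, 1, hidden_dim]
sample_hiddens = torch.gather(current_hiddens, 1, index)

# toy update standing in for self.updater
h = torch.tanh(sample_hiddens + query_feats)

# write the updated slot back without touching the other instances
next_hiddens = current_hiddens.clone()
next_hiddens.scatter_(dim=1, index=index, src=h)
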
Example #56
    def train(self,
              batch: EpisodeBatch,
              t_env: int,
              episode_num: int,
              chosen_index=0,
              return_q_all=False):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        mac_out = []
        if self.args.q_net_ensemble:
            mac_chosen = self.mac[chosen_index]
        else:
            mac_chosen = self.mac
        mac_chosen.init_hidden(batch.batch_size)
        #self.mac.init_latent(batch.batch_size)
        index = th.randint(mac_chosen.n_agents, [batch.batch_size])
        index = F.one_hot(index, mac_chosen.n_agents).to(th.bool)
        rp = random.random() < self.args.contrary_grad_p
        if self.args.random_agent_order:
            enemy_shape = self.scheme["obs"]["vshape"] - mac_chosen.n_agents * 8
            enemy_num = enemy_shape // 8
            assert enemy_num * 8 == enemy_shape
            order_enemy = np.arange(enemy_num)
            order_ally = np.arange(mac_chosen.n_agents - 1)
            np.random.shuffle(order_enemy)
            np.random.shuffle(order_ally)
            agent_order = [order_ally, order_enemy]
        else:
            agent_order = None
        for t in range(batch.max_seq_length):
            if self.args.mac == "robust_mac":
                agent_outs = mac_chosen.forward(batch,
                                                t=t,
                                                index=index,
                                                contrary_grad=rp,
                                                agent_order=agent_order)
            else:
                agent_outs = mac_chosen.forward(
                    batch, t=t, agent_order=agent_order)  #(bs,n,n_actions)
            mac_out.append(agent_outs)  #[t,(bs,n,n_actions)]
        mac_out = th.stack(mac_out, dim=1)  # Concat over time
        #(bs,t,n,n_actions), Q values of n_actions

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals_bm = th.gather(mac_out[:, :-1],
                                           dim=3,
                                           index=actions).squeeze(
                                               3)  # Remove the last dim
        # (bs,t,n) Q value of an action

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        if self.args.q_net_ensemble:
            target_mac_chosen = self.target_mac[chosen_index]
        else:
            target_mac_chosen = self.target_mac
        target_mac_chosen.init_hidden(batch.batch_size)  # (bs,n,hidden_size)
        #self.target_mac.init_latent(batch.batch_size)

        for t in range(batch.max_seq_length):
            target_agent_outs = target_mac_chosen.forward(
                batch, t=t, agent_order=agent_order)  #(bs,n,n_actions)
            target_mac_out.append(target_agent_outs)  #[t,(bs,n,n_actions)]

        # We don't need the first timestep's Q-Value estimate for calculating targets
        target_mac_out = th.stack(
            target_mac_out[1:],
            dim=1)  # Concat across time, dim=1 is time index
        #(bs,t,n,n_actions)

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, 1:] == 0] = -9999999  # Q values

        # Max over target Q-Values
        if self.args.double_q:  # True for QMix
            # Get actions that maximise live Q (for double q-learning)
            mac_out_detach = mac_out.clone().detach(
            )  #return a new Tensor, detached from the current graph
            mac_out_detach[avail_actions == 0] = -9999999
            # (bs,t,n,n_actions), discard t=0
            cur_max_actions = mac_out_detach[:, 1:].max(
                dim=3, keepdim=True)[1]  # indices instead of values
            # (bs,t,n,1)
            target_max_qvals = th.gather(target_mac_out, 3,
                                         cur_max_actions).squeeze(3)
            # (bs,t,n,n_actions) ==> (bs,t,n,1) ==> (bs,t,n) max target-Q
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        loss_ensemble_w = th.tensor(0.0).to(target_max_qvals.device)
        loss_ensemble_b = th.tensor(0.0).to(target_max_qvals.device)
        if self.mixer is not None:
            if return_q_all:
                q_all = chosen_action_qvals_bm.detach().cpu().numpy()
            # chosen_index = random.randint(0,len(self.mixer_list)-1)
            chosen_mixer = self.mixer_list[chosen_index]
            chosen_target_mixer = self.target_mixer_list[chosen_index]
            if self.args.mixer == "aqmix":
                chosen_action_qvals = chosen_mixer(chosen_action_qvals_bm,
                                                   batch["state"][:, :-1],
                                                   actions.detach())
                target_max_qvals = chosen_target_mixer(
                    target_max_qvals, batch["state"][:, 1:],
                    cur_max_actions.detach())
            else:
                chosen_action_qvals = chosen_mixer(chosen_action_qvals_bm,
                                                   batch["state"][:, :-1])
                target_max_qvals = chosen_target_mixer(target_max_qvals,
                                                       batch["state"][:, 1:])
            if len(self.mixer_list) > 1:
                other_w_list = []
                other_b_list = []
                chosen_action_qvals, w, b = chosen_action_qvals
                target_max_qvals, _, _ = target_max_qvals
                for i in range(len(self.mixer_list)):
                    if i != chosen_index:
                        if self.args.mixer == "aqmix":
                            _, other_w, other_b = self.mixer_list[i](
                                chosen_action_qvals_bm, batch["state"][:, :-1],
                                actions.detach())
                        else:
                            _, other_w, other_b = self.mixer_list[i](
                                chosen_action_qvals_bm, batch["state"][:, :-1])
                        other_w_list.append(other_w.detach())
                        other_b_list.append(other_b.detach())
                for other_w, other_b in zip(other_w_list, other_b_list):
                    norm_delta_w = th.mean(
                        th.abs((w - other_w).squeeze(2).sum(1)))
                    norm_w = th.mean(th.abs(w.detach().squeeze(2).sum(1)))
                    norm_w_other = th.mean(th.abs(other_w.squeeze(2).sum(1)))
                    loss_ensemble_w += norm_delta_w / (norm_w_other + norm_w +
                                                       1e-8)
                    norm_delta_b = th.mean(th.abs((b - other_b).view(-1)))
                    norm_b = th.mean(th.abs(b.detach().view(-1)))
                    norm_b_other = th.mean(th.abs(other_b.view(-1)))
                    loss_ensemble_b += norm_delta_b / (norm_b_other + norm_b +
                                                       1e-8)
                loss_ensemble_w /= len(self.mixer_list) - 1
                loss_ensemble_b /= len(self.mixer_list) - 1
            if return_q_all:
                mix_q_all = chosen_action_qvals.detach().cpu().numpy()
                termed = terminated.detach().cpu().numpy()
            # (bs,t,1)

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 -
                                               terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach()
                    )  # no gradient through target net
        # (bs,t,1)

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss_q = (masked_td_error**2).sum() / mask.sum()
        loss = loss_q - self.args.en_w_alpha * loss_ensemble_w - self.args.en_b_alpha * loss_ensemble_b
        # Optimise
        if self.args.q_net_ensemble:
            current_optim = self.optimiser_list[chosen_index]
            current_para = self.params_list[chosen_index]
        else:
            current_optim = self.optimiser
            current_para = self.params
        current_optim.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(
            current_para, self.args.grad_norm_clip)  # max_norm
        try:
            grad_norm = grad_norm.item()
        except AttributeError:
            # older PyTorch versions already return the clipped norm as a float
            pass
        current_optim.step()

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("loss_q", loss_q.item(), t_env)
            self.logger.log_stat("loss_en_w", loss_ensemble_w.item(), t_env)
            self.logger.log_stat("loss_en_b", loss_ensemble_b.item(), t_env)

            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env
        if return_q_all:
            return q_all, mix_q_all, termed
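
A compact sketch of the th.gather(..., dim=3) step above that picks, for every agent and timestep, the Q-value of the action actually taken. The tensor sizes are illustrative assumptions:

import torch as th

bs, t, n_agents, n_actions = 2, 5, 3, 6
mac_out = th.randn(bs, t, n_agents, n_actions)                 # Q-values per action
actions = th.randint(0, n_actions, (bs, t - 1, n_agents, 1))   # actions taken

# gather along the action dimension, then drop it: [bs, t-1, n_agents]
chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)
print(chosen_action_qvals.shape)  # torch.Size([2, 4, 3])
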
Example #57
    def forward(self,
                context_ids: TextFieldTensors,
                query_ids: TextFieldTensors,
                context_lens: torch.Tensor,
                query_lens: torch.Tensor,
                mask_label: Optional[torch.Tensor] = None,
                start_label: Optional[torch.Tensor] = None,
                end_label: Optional[torch.Tensor] = None,
                metadata: List[Dict[str, Any]] = None):
        # concat the context and query to the encoder
        # get the indexers first
        indexers = context_ids.keys()
        dialogue_ids = {}

        # get the lengths of the context and query
        context_len = torch.max(context_lens).item()
        query_len = torch.max(query_lens).item()

        # [B, _len]
        context_mask = get_mask_from_sequence_lengths(context_lens,
                                                      context_len)
        query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
        for indexer in indexers:
            # get the various variables of context and query
            dialogue_ids[indexer] = {}
            for key in context_ids[indexer].keys():
                context = context_ids[indexer][key]
                query = query_ids[indexer][key]
                # concat the context and query in the length dim
                dialogue = torch.cat([context, query], dim=1)
                dialogue_ids[indexer][key] = dialogue

        # get the outputs of the dialogue
        if isinstance(self._text_field_embedder, TextFieldEmbedder):
            embedder_outputs = self._text_field_embedder(dialogue_ids)
        else:
            embedder_outputs = self._text_field_embedder(
                **dialogue_ids[self._index_name])

        # get the outputs of the query and context
        # [B, _len, embed_size]
        context_last_layer = embedder_outputs[:, :context_len].contiguous()
        query_last_layer = embedder_outputs[:, context_len:].contiguous()

        # ------- compute the span prediction results -------
        # for every masked token position in the query we want to know which content
        # should be appended after it, i.e. the start and end positions of the
        # corresponding span in the context
        # likewise, expand the context to [b, query_len, context_len, embed_size]
        context_last_layer = context_last_layer.unsqueeze(dim=1).expand(
            -1, query_len, -1, -1).contiguous()
        # [b, query_len, context_len]
        context_expand_mask = context_mask.unsqueeze(dim=1).expand(
            -1, query_len, -1).contiguous()

        # concatenate the three parts above
        # this covers every position in the query
        span_embed_size = context_last_layer.size(-1)

        if self.training and self._neg_sample_ratio > 0.0:
            # sample negatives from the positions where the mask is 0
            # [B*query_len, ]
            sample_mask_label = mask_label.view(-1)
            # get the flattened length and the number of negative samples to draw
            mask_length = sample_mask_label.size(0)
            mask_sum = int(
                torch.sum(sample_mask_label).item() * self._neg_sample_ratio)
            mask_sum = max(10, mask_sum)
            # get the indices of the negative samples to draw
            neg_indexes = torch.randint(low=0,
                                        high=mask_length,
                                        size=(mask_sum, ))
            # keep the indices within the valid length
            neg_indexes = neg_indexes[:mask_length]
            # set the mask to 1 at the sampled negative positions
            sample_mask_label[neg_indexes] = 1
            # [B, query_len]
            use_mask_label = sample_mask_label.view(
                -1, query_len).to(dtype=torch.bool)
            # filter out the padded part of the query, [B, query_len]
            use_mask_label = use_mask_label & query_mask
            span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # select the usable content of the context
            # [B_mask, context_len, span_embed_size]
            span_context_matrix = context_last_layer.masked_select(
                span_mask).view(-1, context_len, span_embed_size).contiguous()
            # select the usable vectors of the query
            span_query_vector = query_last_layer.masked_select(
                span_mask.squeeze(dim=-1)).view(-1,
                                                span_embed_size).contiguous()
            span_context_mask = context_expand_mask.masked_select(
                span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()
        else:
            use_mask_label = query_mask
            span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # select the usable content of the context
            # [B_mask, context_len, span_embed_size]
            span_context_matrix = context_last_layer.masked_select(
                span_mask).view(-1, context_len, span_embed_size).contiguous()
            # select the usable vectors of the query
            span_query_vector = query_last_layer.masked_select(
                span_mask.squeeze(dim=-1)).view(-1,
                                                span_embed_size).contiguous()
            span_context_mask = context_expand_mask.masked_select(
                span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()

        # get the logits of the span over each context position
        # [B_mask, context_len]
        span_start_probs = self.start_attention(span_query_vector,
                                                span_context_matrix,
                                                span_context_mask)
        span_end_probs = self.end_attention(span_query_vector,
                                            span_context_matrix,
                                            span_context_mask)

        span_start_logits = torch.log(span_start_probs + self._eps)
        span_end_logits = torch.log(span_end_probs + self._eps)

        # [B_mask, 2]; in the last dim the first entry is the start position, the second the end position
        best_spans = get_best_span(span_start_logits, span_end_logits)
        # compute the score of each best_span
        best_span_scores = (
            torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) +
            torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
        # [B_mask, ]
        best_span_scores = best_span_scores.squeeze(1)

        # write the important information into the output
        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_spans": best_spans,
            "best_span_scores": best_span_scores
        }

        # if labels are provided, use them to compute the loss
        if start_label is not None:
            loss = self._calc_loss(span_start_logits, span_end_logits,
                                   use_mask_label, start_label, end_label,
                                   best_spans)
            output_dict["loss"] = loss
        if metadata is not None:
            predict_rewrite_results = self._get_rewrite_result(
                use_mask_label, best_spans, query_lens, context_lens, metadata)
            output_dict['rewrite_results'] = predict_rewrite_results
        return output_dict
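
A short sketch of the best-span score lookup above: for each row, torch.gather pulls out the start and end logits at the predicted positions. The shapes and the toy best_spans tensor are assumptions for illustration:

import torch

num_rows, context_len = 3, 12
span_start_logits = torch.randn(num_rows, context_len)
span_end_logits = torch.randn(num_rows, context_len)
# pretend get_best_span returned these (start, end) indices
best_spans = torch.tensor([[2, 5], [0, 3], [7, 11]])

best_span_scores = (
    torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) +
    torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
best_span_scores = best_span_scores.squeeze(1)  # [num_rows]
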
Example #58
    def train(self,
              batch: EpisodeBatch,
              t_env: int,
              episode_num: int,
              show_demo=False,
              save_data=None):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = torch.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = torch.gather(mac_out[:, :-1],
                                           dim=3,
                                           index=actions).squeeze(
                                               3)  # Remove the last dim

        x_mac_out = mac_out.clone().detach()
        x_mac_out[avail_actions == 0] = -9999999
        max_action_qvals, max_action_index = x_mac_out[:, :-1].max(dim=3)

        max_action_index = max_action_index.detach().unsqueeze(3)
        is_max_action = (max_action_index == actions).int().float()

        if show_demo:
            q_i_data = chosen_action_qvals.detach().cpu().numpy()
            q_data = (max_action_qvals -
                      chosen_action_qvals).detach().cpu().numpy()

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)

        # We don't need the first timestep's Q-Value estimate for calculating targets
        target_mac_out = torch.stack(target_mac_out[1:],
                                     dim=1)  # Concat across time

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = torch.gather(target_mac_out, 3,
                                            cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals,
                                             batch["state"][:, :-1])
            target_max_qvals = self.target_mixer(target_max_qvals,
                                                 batch["state"][:, 1:])

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 -
                                               terminated) * target_max_qvals

        if show_demo:
            tot_q_data = chosen_action_qvals.detach().cpu().numpy()
            tot_target = targets.detach().cpu().numpy()
            if self.mixer is None:
                tot_q_data = np.mean(tot_q_data, axis=2)
                tot_target = np.mean(tot_target, axis=2)

            print('action_pair_%d_%d' % (save_data[0], save_data[1]),
                  np.squeeze(q_data[:, 0]), np.squeeze(q_i_data[:, 0]),
                  np.squeeze(tot_q_data[:, 0]), np.squeeze(tot_target[:, 0]))
            self.logger.log_stat(
                'action_pair_%d_%d' % (save_data[0], save_data[1]),
                np.squeeze(tot_q_data[:, 0]), t_env)
            return

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error**2).sum() / mask.sum()

        masked_hit_prob = torch.mean(is_max_action, dim=2) * mask
        hit_prob = masked_hit_prob.sum() / mask.sum()

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.params,
                                                   self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("hit_prob", hit_prob.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env
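
A minimal sketch of the double Q-learning target selection above: the online network's argmax action indices are used to gather values from the target network. The dimensions and random inputs are illustrative assumptions:

import torch

bs, t, n_agents, n_actions = 2, 5, 3, 6
mac_out = torch.randn(bs, t, n_agents, n_actions)                # online Q-values
target_mac_out = torch.randn(bs, t - 1, n_agents, n_actions)     # target Q, timesteps 1..t-1
avail_actions = torch.randint(0, 2, (bs, t, n_agents, n_actions))

mac_out_detach = mac_out.clone().detach()
mac_out_detach[avail_actions == 0] = -9999999                    # mask unavailable actions
cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]   # [bs, t-1, n, 1]
target_max_qvals = torch.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
print(target_max_qvals.shape)  # torch.Size([2, 4, 3])
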
Example #59
    def forward(self, x, mask, gt_target=None, soft_threshold=0.8):
        mask.require_grad = False
        x.require_grad = False
        adjusted_weight = mask[:, 0:1, :].clone().detach().unsqueeze(
            0)  # weights for SC
        for i in range(self.num_stages - 1):
            adjusted_weight = torch.cat(
                (adjusted_weight, mask[:,
                                       0:1, :].clone().detach().unsqueeze(0)))
        #print(adjusted_weight.size())
        confidence = []
        feature = []
        if gt_target is not None:
            gt_target = gt_target.unsqueeze(0)

        # stage 1
        out1, feature1 = self.stage1(x, mask)
        outputs = out1.unsqueeze(0)
        feature.append(feature1)
        confidence.append(F.softmax(out1, dim=1) * mask[:, 0:1, :])
        confidence[0].require_grad = False

        if gt_target is None:
            max_conf, _ = torch.max(confidence[0], dim=1)
            max_conf = max_conf.unsqueeze(1).clone().detach()
            max_conf.require_grad = False
            decrease_flag = (max_conf > soft_threshold).float()
            increase_flag = mask[:, 0:1, :].clone().detach() - decrease_flag
            adjusted_weight[1] = max_conf.neg().exp(
            ) * decrease_flag + max_conf.exp() * increase_flag  # for stage 2
        else:
            gt_conf = torch.gather(confidence[0], dim=1, index=gt_target)
            decrease_flag = (gt_conf > soft_threshold).float()
            increase_flag = mask[:, 0:1, :].clone().detach() - decrease_flag
            adjusted_weight[1] = gt_conf.neg().exp(
            ) * decrease_flag + gt_conf.exp() * increase_flag

        # stage 2,...,n
        curr_stage = 0
        for s in self.stages:
            curr_stage = curr_stage + 1
            temp = feature[0]
            for i in range(1, len(feature)):
                temp = torch.cat((temp, feature[i]), dim=1) * mask[:, 0:1, :]
            temp = torch.cat((temp, x), dim=1)
            curr_out, curr_feature = s(temp, mask)
            outputs = torch.cat((outputs, curr_out.unsqueeze(0)), dim=0)
            feature.append(curr_feature)
            confidence.append(F.softmax(curr_out, dim=1) * mask[:, 0:1, :])
            confidence[curr_stage].require_grad = False
            if curr_stage == self.num_stages - 1:  # curr_stage starts from 0
                break  # don't need to compute the next stage's confidence when current stage = last cascade stage

            if gt_target is None:
                max_conf, _ = torch.max(confidence[curr_stage], dim=1)
                max_conf = max_conf.unsqueeze(1).clone().detach()
                max_conf.require_grad = False
                decrease_flag = (max_conf > soft_threshold).float()
                increase_flag = mask[:,
                                     0:1, :].clone().detach() - decrease_flag
                adjusted_weight[
                    curr_stage +
                    1] = max_conf.neg().exp() * decrease_flag + max_conf.exp(
                    ) * increase_flag  # output the weight for the next stage
            else:
                gt_conf = torch.gather(confidence[curr_stage],
                                       dim=1,
                                       index=gt_target)
                decrease_flag = (gt_conf > soft_threshold).float()
                increase_flag = mask[:,
                                     0:1, :].clone().detach() - decrease_flag
                adjusted_weight[curr_stage + 1] = gt_conf.neg().exp(
                ) * decrease_flag + gt_conf.exp() * increase_flag

        output_weight = adjusted_weight.detach()
        output_weight.require_grad = False
        adjusted_weight = adjusted_weight / torch.sum(
            adjusted_weight, 0)  # normalization among stages
        temp = F.softmax(out1, dim=1) * adjusted_weight[0]
        for i in range(1, self.num_stages):
            temp += F.softmax(outputs[i], dim=1) * adjusted_weight[i]
        confidenceF = temp * mask[:, 0:1, :]  # input of fusion stage

        #  Inner LBP for confidenceF
        barrier, BGM_output = self.fullBarrier(x)
        if self.use_lbp:
            confidenceF = self.lbp_in(confidenceF, barrier)

        #  fusion stage: produces a more consistent output, since the combination of cascade stages may fluctuate considerably
        out, _ = self.stageF(confidenceF,
                             mask)  # use mixture of cascade stages

        #  Final LBP for output
        if self.use_lbp:
            for i in range(self.num_soft_lbp):
                out = self.lbp_out(out, barrier)

        confidence_last = torch.clamp(
            F.softmax(out, dim=1), min=1e-4, max=1 -
            1e-4) * mask[:, 0:1, :]  # torch.clamp for training stability
        outputs = torch.cat((outputs, confidence_last.unsqueeze(0)), dim=0)
        return outputs, BGM_output, output_weight
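
A minimal sketch of the gt_conf lookup above: torch.gather along the class dimension picks, for every frame, the predicted probability of the labelled class, which then drives the soft re-weighting of the next stage (the frame mask is omitted here for brevity). Sizes and the random gt_target are illustrative assumptions:

import torch
import torch.nn.functional as F

batch, num_classes, length = 2, 5, 10
logits = torch.randn(batch, num_classes, length)
gt_target = torch.randint(0, num_classes, (batch, 1, length))    # class id per frame

confidence = F.softmax(logits, dim=1)                            # [batch, num_classes, length]
gt_conf = torch.gather(confidence, dim=1, index=gt_target)       # [batch, 1, length]

soft_threshold = 0.8
decrease_flag = (gt_conf > soft_threshold).float()
increase_flag = 1.0 - decrease_flag
adjusted_weight = gt_conf.neg().exp() * decrease_flag + gt_conf.exp() * increase_flag
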
Example #60
    def forward(self, heatmap_heads, offset_heads, wh_heads):
        with torch.no_grad():
            device = heatmap_heads.device
            heatmap_heads = torch.sigmoid(heatmap_heads)

            batch_scores, batch_classes, batch_pred_bboxes = [], [], []
            for per_image_heatmap_heads, per_image_offset_heads, per_image_wh_heads in zip(
                    heatmap_heads, offset_heads, wh_heads):
                # filter and keep points whose value is larger than the surrounding 8 points
                per_image_heatmap_heads = self.nms(per_image_heatmap_heads)
                topk_score, topk_indexes, topk_classes, topk_ys, topk_xs = self.get_topk(
                    per_image_heatmap_heads, K=self.topk)

                per_image_offset_heads = per_image_offset_heads.permute(
                    1, 2, 0).contiguous().view(-1, 2)
                per_image_offset_heads = torch.gather(
                    per_image_offset_heads, 0, topk_indexes.repeat(1, 2))
                topk_xs = topk_xs + per_image_offset_heads[:, 0:1]
                topk_ys = topk_ys + per_image_offset_heads[:, 1:2]

                per_image_wh_heads = per_image_wh_heads.permute(
                    1, 2, 0).contiguous().view(-1, 2)
                per_image_wh_heads = torch.gather(per_image_wh_heads, 0,
                                                  topk_indexes.repeat(1, 2))

                topk_bboxes = torch.cat([
                    topk_xs - per_image_wh_heads[:, 0:1] / 2,
                    topk_ys - per_image_wh_heads[:, 1:2] / 2,
                    topk_xs + per_image_wh_heads[:, 0:1] / 2,
                    topk_ys + per_image_wh_heads[:, 1:2] / 2
                ],
                                        dim=1)

                topk_bboxes = topk_bboxes * self.stride

                topk_bboxes[:, 0] = torch.clamp(topk_bboxes[:, 0], min=0)
                topk_bboxes[:, 1] = torch.clamp(topk_bboxes[:, 1], min=0)
                topk_bboxes[:, 2] = torch.clamp(topk_bboxes[:, 2],
                                                max=self.image_w - 1)
                topk_bboxes[:, 3] = torch.clamp(topk_bboxes[:, 3],
                                                max=self.image_h - 1)

                one_image_scores = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_classes = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_pred_bboxes = (-1) * torch.ones(
                    (self.max_detection_num, 4), device=device)

                topk_classes = topk_classes[
                    topk_score > self.min_score_threshold].float()
                topk_bboxes = topk_bboxes[
                    topk_score > self.min_score_threshold].float()
                topk_score = topk_score[
                    topk_score > self.min_score_threshold].float()

                final_detection_num = min(self.max_detection_num,
                                          topk_score.shape[0])

                one_image_scores[0:final_detection_num] = topk_score[
                    0:final_detection_num]
                one_image_classes[0:final_detection_num] = topk_classes[
                    0:final_detection_num]
                one_image_pred_bboxes[0:final_detection_num, :] = topk_bboxes[
                    0:final_detection_num, :]

                batch_scores.append(one_image_scores.unsqueeze(0))
                batch_classes.append(one_image_classes.unsqueeze(0))
                batch_pred_bboxes.append(one_image_pred_bboxes.unsqueeze(0))

            batch_scores = torch.cat(batch_scores, dim=0)
            batch_classes = torch.cat(batch_classes, dim=0)
            batch_pred_bboxes = torch.cat(batch_pred_bboxes, dim=0)

            # batch_scores shape:[batch_size,topk]
            # batch_classes shape:[batch_size,topk]
            # batch_pred_bboxes shape[batch_size,topk,4]
            return batch_scores, batch_classes, batch_pred_bboxes
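
A short sketch of the offset/size lookup above: the per-pixel heads are flattened to [H*W, 2] and torch.gather picks the rows belonging to the top-k keypoints. The feature-map size, K and random inputs are illustrative assumptions:

import torch

h, w, K = 8, 8, 5
heatmap = torch.rand(h, w)
offset_head = torch.rand(2, h, w)                                # (dx, dy) per location

topk_score, topk_indexes = torch.topk(heatmap.view(-1), K)
topk_indexes = topk_indexes.unsqueeze(-1)                        # [K, 1]
topk_ys = torch.div(topk_indexes, w, rounding_mode='floor').float()
topk_xs = (topk_indexes % w).float()

offset_head = offset_head.permute(1, 2, 0).contiguous().view(-1, 2)     # [H*W, 2]
topk_offsets = torch.gather(offset_head, 0, topk_indexes.repeat(1, 2))  # [K, 2]

# refine the keypoint coordinates with the gathered offsets
topk_xs = topk_xs + topk_offsets[:, 0:1]
topk_ys = topk_ys + topk_offsets[:, 1:2]
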