def eliminate_rows(self, prob_sc, ind, phis):
    """ eliminate rows of phis and prob_matrix scale """
    length = prob_sc.size()[1]
    mask = (prob_sc[:, :, 0] > 0.85).type(dtype)
    rang = (Variable(torch.range(0, length - 1).unsqueeze(0)
                     .expand_as(mask)).type(dtype))
    ind_sc = torch.sort(rang * (1 - mask) + length * mask, 1)[1]
    # permute prob_sc
    m = mask.unsqueeze(2).expand_as(prob_sc)
    mm = m.clone()
    mm[:, :, 1:] = 0
    prob_sc = (torch.gather(prob_sc * (1 - m) + mm, 1,
                            ind_sc.unsqueeze(2).expand_as(prob_sc)))
    # compose permutations
    ind = torch.gather(ind, 1, ind_sc)
    active = torch.gather(1 - mask, 1, ind_sc)
    # permute phis
    active1 = active.unsqueeze(2).expand_as(phis)
    ind1 = ind.unsqueeze(2).expand_as(phis)
    active2 = active.unsqueeze(1).expand_as(phis)
    ind2 = ind.unsqueeze(1).expand_as(phis)
    phis_out = torch.gather(phis, 1, ind1) * active1
    phis_out = torch.gather(phis_out, 2, ind2) * active2
    return prob_sc, ind, phis_out, active
def sample_relax_given_class(logits, samp):
    cat = Categorical(logits=logits)

    u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels

    b = samp  # torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B, 1)

    # z: relaxed sample conditioned on the given class b (reuses the noise u)
    u_b = torch.gather(input=u, dim=1, index=b.view(B, 1))
    z_tilde_b = -torch.log(-torch.log(u_b))
    z_tilde = -torch.log((-torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)
    z = z_tilde

    # z_tilde: a second conditional sample that keeps u_b at class b
    # but draws fresh noise for the remaining coordinates
    u_b = torch.gather(input=u, dim=1, index=b.view(B, 1))
    z_tilde_b = -torch.log(-torch.log(u_b))
    u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
    z_tilde = -torch.log((-torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)

    return z, z_tilde, logprob
def hard_example_mining(dist_mat, labels, return_inds=False):
    """For each anchor, find the hardest positive and negative sample.

    Args:
      dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
      labels: pytorch LongTensor, with shape [N]
      return_inds: whether to return the indices. Save time if `False`(?)

    Returns:
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
      p_inds: pytorch LongTensor, with shape [N];
        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
      n_inds: pytorch LongTensor, with shape [N];
        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1

    NOTE: Only consider the case in which all labels have same num of samples,
      thus we can cope with all anchors in parallel.
    """
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    # shape [N, N]
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())

    # `dist_ap` means distance(anchor, positive)
    # both `dist_ap` and `relative_p_inds` with shape [N, 1]
    dist_ap, relative_p_inds = torch.max(
        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    # `dist_an` means distance(anchor, negative)
    # both `dist_an` and `relative_n_inds` with shape [N, 1]
    dist_an, relative_n_inds = torch.min(
        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    # shape [N]
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)

    if return_inds:
        # shape [N, N]
        ind = (labels.new().resize_as_(labels)
               .copy_(torch.arange(0, N).long())
               .unsqueeze(0).expand(N, N))
        # shape [N, 1]
        p_inds = torch.gather(
            ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
        n_inds = torch.gather(
            ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
        # shape [N]
        p_inds = p_inds.squeeze(1)
        n_inds = n_inds.squeeze(1)
        return dist_ap, dist_an, p_inds, n_inds

    return dist_ap, dist_an
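# A minimal usage sketch for hard_example_mining (my own example, not from the
# source; shapes and data invented, each identity has the same number of
# samples, as the docstring NOTE requires):
import torch

feats = torch.randn(4, 8)
dist_mat = torch.cdist(feats, feats)   # [N, N] pairwise Euclidean distances
labels = torch.tensor([0, 0, 1, 1])
dist_ap, dist_an = hard_example_mining(dist_mat, labels)
print(dist_ap.shape, dist_an.shape)    # torch.Size([4]) torch.Size([4])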
def sort_by_embeddings(self, Phis, Inputs_N, e):
    ind = torch.sort(e, 1)[1].squeeze()
    for i, phis in enumerate(Phis):
        # rearrange phis
        phis_out = (torch.gather(Phis[i], 1, ind.unsqueeze(2)
                    .expand_as(phis)))
        Phis[i] = (torch.gather(phis_out, 2, ind.unsqueeze(1)
                   .expand_as(phis)))
        # rearrange inputs
        Inputs_N[i] = torch.gather(Inputs_N[i], 1,
                                   ind.unsqueeze(2).expand_as(Inputs_N[i]))
    return Phis, Inputs_N
def proposal_layer(self, rpn_class, rpn_bbox):
    # handling proposals
    scores = rpn_class[:, :, 1]
    # Box deltas [batch, num_rois, 4]
    deltas_mul = Variable(torch.from_numpy(np.reshape(
        self.config.RPN_BBOX_STD_DEV, [1, 1, 4]).astype(np.float32))).cuda()
    deltas = rpn_bbox * deltas_mul

    pre_nms_limit = min(6000, self.anchors.shape[0])
    scores, ix = torch.topk(scores, pre_nms_limit, dim=-1,
                            largest=True, sorted=True)
    ix = torch.unsqueeze(ix, 2)
    ix = torch.cat([ix, ix, ix, ix], dim=2)
    deltas = torch.gather(deltas, 1, ix)

    _anchors = []
    for i in range(self.config.IMAGES_PER_GPU):
        anchors = Variable(torch.from_numpy(
            self.anchors.astype(np.float32))).cuda()
        _anchors.append(anchors)
    anchors = torch.stack(_anchors, 0)

    pre_nms_anchors = torch.gather(anchors, 1, ix)
    refined_anchors = apply_box_deltas_graph(pre_nms_anchors, deltas)

    # Clip to image boundaries. [batch, N, (y1, x1, y2, x2)]
    height, width = self.config.IMAGE_SHAPE[:2]
    window = np.array([0, 0, height, width]).astype(np.float32)
    window = Variable(torch.from_numpy(window)).cuda()
    refined_anchors_clipped = clip_boxes_graph(refined_anchors, window)

    refined_proposals = []
    for i in range(self.config.IMAGES_PER_GPU):
        indices = nms(
            torch.cat([refined_anchors_clipped.data[i], scores.data[i]], 1), 0.7)
        indices = indices[:self.proposal_count]
        indices = torch.stack([indices, indices, indices, indices], dim=1)
        indices = Variable(indices).cuda()
        proposals = torch.gather(refined_anchors_clipped[i], 0, indices)
        padding = self.proposal_count - proposals.size()[0]
        proposals = torch.cat(
            [proposals, Variable(torch.zeros([padding, 4])).cuda()], 0)
        refined_proposals.append(proposals)

    rpn_rois = torch.stack(refined_proposals, 0)
    return rpn_rois
def _score_sentence(self, scores, mask, tags):
    """
    input:
        scores: variable (seq_len, batch, tag_size, tag_size)
        mask: (batch, seq_len)
        tags: tensor (batch, seq_len)
    output:
        score: sum of score for gold sequences within whole batch
    """
    # Gives the score of a provided tag sequence
    batch_size = scores.size(1)
    seq_len = scores.size(0)
    tag_size = scores.size(2)
    ## convert tag value into a new format, recorded label bigram information to index
    new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))
    if self.gpu:
        new_tags = new_tags.cuda()
    for idx in range(seq_len):
        if idx == 0:
            ## start -> first score
            new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0]
        else:
            new_tags[:, idx] = tags[:, idx - 1] * tag_size + tags[:, idx]
    ## transition for label to STOP_TAG
    end_transition = self.transitions[:, STOP_TAG].contiguous().view(
        1, tag_size).expand(batch_size, tag_size)
    ## length for batch, last word position = length - 1
    length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
    ## index the label id of last word
    end_ids = torch.gather(tags, 1, length_mask - 1)
    ## index the transition score for end_id to STOP_TAG
    end_energy = torch.gather(end_transition, 1, end_ids)
    ## convert tag as (seq_len, batch_size, 1)
    new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1)
    ### need convert tags id to search from 400 positions of scores
    tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2,
                             new_tags).view(seq_len, batch_size)  # seq_len * bat_size
    ## mask transpose to (seq_len, batch_size)
    tg_energy = tg_energy.masked_select(mask.transpose(1, 0))

    # ## calculate the score from START_TAG to first label
    # start_transition = self.transitions[START_TAG, :].view(1, tag_size).expand(batch_size, tag_size)
    # start_energy = torch.gather(start_transition, 1, tags[0, :])

    ## add all score together
    # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
    gold_score = tg_energy.sum() + end_energy.sum()
    return gold_score
def word_pre_train_forward(self, sentence, position):
    """
    output of forward language model

    args:
        sentence (char_seq_len, batch_size): char-level representation of sentence
        position (word_seq_len, batch_size): position of blank space in char-level representation of sentence
    """
    embeds = self.char_embeds(sentence)
    d_embeds = self.dropout(embeds)
    lstm_out, hidden = self.forw_char_lstm(d_embeds)

    tmpsize = position.size()
    position = position.unsqueeze(2).expand(tmpsize[0], tmpsize[1],
                                            self.char_hidden_dim)
    select_lstm_out = torch.gather(lstm_out, 0, position)
    d_lstm_out = self.dropout(select_lstm_out).view(-1, self.char_hidden_dim)

    if self.if_highway:
        char_out = self.forw2word(d_lstm_out)
        d_char_out = self.dropout(char_out)
    else:
        d_char_out = d_lstm_out

    pre_score = self.word_pre_train_out(d_char_out)
    return pre_score, hidden
def decode_with_crf(crf, word_reps, mask_v, l_map):
    """ decode with viterbi algorithm and return score """
    seq_len = word_reps.size(0)
    bat_size = word_reps.size(1)
    decoded_crf = crf.decode(word_reps, mask_v)
    scores = crf.cal_score(word_reps).data
    mask_v = mask_v.data
    decoded_crf = decoded_crf.data
    decoded_crf_withpad = torch.cat(
        (torch.cuda.LongTensor(1, bat_size).fill_(l_map['<start>']), decoded_crf), 0)
    decoded_crf_withpad = decoded_crf_withpad.transpose(0, 1).cpu().numpy()
    label_size = len(l_map)

    # encode consecutive label pairs as bigram indices
    bi_crf = []
    cur_len = decoded_crf_withpad.shape[1] - 1
    for i_l in decoded_crf_withpad:
        bi_crf.append([i_l[ind] * label_size + i_l[ind + 1] for ind in range(0, cur_len)]
                      + [i_l[cur_len] * label_size + l_map['<pad>']])
    bi_crf = torch.cuda.LongTensor(bi_crf).transpose(0, 1).unsqueeze(2)

    tg_energy = torch.gather(scores.view(seq_len, bat_size, -1), 2,
                             bi_crf).view(seq_len, bat_size)  # seq_len * bat_size
    tg_energy = tg_energy.transpose(0, 1).masked_select(mask_v.transpose(0, 1))
    tg_energy = tg_energy.cpu().numpy()

    # sum the per-token scores of each sequence
    masks = mask_v.sum(0)
    crf_result_scored_by_crf = []
    start = 0
    for i, mask in enumerate(masks):
        end = start + mask
        crf_result_scored_by_crf.append(tg_energy[start:end].sum())
        start = end
    crf_result_scored_by_crf = np.array(crf_result_scored_by_crf)
    return decoded_crf.cpu().transpose(0, 1).numpy(), crf_result_scored_by_crf
def masked_cross_entropy(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each data in a batch.

    Returns:
        loss: An average loss value masked by the length.
    """
    length = Variable(torch.LongTensor(length)).cuda()

    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = functional.log_softmax(logits_flat, dim=1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss
def forward(self, batch):
    X_data, X_padding_mask, X_lens, X_batch_extend_vocab, X_extra_zeros, context, coverage = \
        self.get_input_from_batch(batch)
    y_data, y_padding_mask, y_max_len, y_lens_var, target_data = \
        self.get_output_from_batch(batch)

    encoder_outputs, encoder_hidden, max_encoder_output = self.encoder(X_data, X_lens)
    s_t_1 = self.reduce_state(encoder_hidden)
    if config.use_maxpool_init_ctx:
        context = max_encoder_output

    step_losses = []
    for di in range(min(y_max_len, self.args.max_decoder_steps)):
        y_t_1 = y_data[:, di]  # Teacher forcing
        final_dist, s_t_1, context, attn_dist, p_gen, coverage = \
            self.decoder(y_t_1, s_t_1, encoder_outputs, X_padding_mask, context,
                         X_extra_zeros, X_batch_extend_vocab, coverage)
        target = target_data[:, di]
        gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
        step_loss = -torch.log(gold_probs + self.args.eps)
        if self.args.is_coverage:
            step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
            step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
        step_mask = y_padding_mask[:, di]
        step_loss = step_loss * step_mask
        step_losses.append(step_loss)

    sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
    batch_avg_loss = sum_losses / y_lens_var
    loss = torch.mean(batch_avg_loss)
    return loss
def forward(self, batch):
    """Forward method receives target-length ordered batches."""
    # Encode image and get initial variables
    img_ctx, c_t, h_t = self.f_init(batch)

    # Fetch embeddings -> (seq_len, batch_size, emb_dim)
    caption = batch[self.tl]

    # n_tokens token processed in this batch
    self.n_tokens = caption.numel()

    # Get embeddings
    embs = self.emb(caption)

    # Accumulators
    loss = 0.0
    self.alphas = []

    # -1: So that we skip the timestep where input is <eos>
    for t in range(caption.shape[0] - 1):
        # NOTE: This is where scheduled sampling will happen
        # Either fetch from self.emb or from log_p
        # Current textual input to decoder: y_t = embs[t]
        log_p, c_t, h_t, _ = self.f_next(img_ctx, embs[t], c_t, h_t)

        # t + 1: We're predicting next token
        # Cumulate losses
        loss += torch.gather(
            log_p, dim=1, index=caption[t + 1].unsqueeze(1)).sum()

    # Return normalized loss
    return loss / self.n_tokens
def forward(self, ctx_dict, y):
    """Computes the softmax outputs given source annotations and
    ground-truth target token indices `y`.

    Arguments:
        ctx_dict(dict): Dictionary of source annotation Variables of shape
            `S*B*ctx_dim`, in an order compatible with ground-truth targets.
        y(Variable): A variable of `T*B` containing ground-truth target
            token indices for the given batch.
    """
    loss = 0.0
    # Convert token indices to embeddings -> T*B*E
    y_emb = self.emb(y)

    # Get initial hidden state
    h = self.f_init(*ctx_dict['txt'])

    # -1: So that we skip the timestep where input is <eos>
    for t in range(y_emb.shape[0] - 1):
        log_p, h = self.f_next(ctx_dict, y_emb[t], h)
        loss += torch.gather(
            log_p, dim=1, index=y[t + 1].unsqueeze(1)).sum()

    return loss
def eval_one_batch(self, batch):
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
        get_input_from_batch(batch, use_cuda)
    dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
        get_output_from_batch(batch, use_cuda)

    encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(enc_batch, enc_lens)
    s_t_1 = self.model.reduce_state(encoder_hidden)

    if config.use_maxpool_init_ctx:
        c_t_1 = max_encoder_output

    step_losses = []
    for di in range(min(max_dec_len, config.max_dec_steps)):
        y_t_1 = dec_batch[:, di]  # Teacher forcing
        final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = \
            self.model.decoder(y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                               c_t_1, extra_zeros, enc_batch_extend_vocab, coverage)
        target = target_batch[:, di]
        gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
        step_loss = -torch.log(gold_probs + config.eps)
        if config.is_coverage:
            step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
            step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
        step_mask = dec_padding_mask[:, di]
        step_loss = step_loss * step_mask
        step_losses.append(step_loss)

    sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
    batch_avg_loss = sum_step_losses / dec_lens_var
    loss = torch.mean(batch_avg_loss)

    return loss.data[0]
def _compute_loss(self, batch, output, target):
    scores = self.generator(self._bottle(output))

    gtruth = target.view(-1)
    if self.confidence < 1:
        tdata = gtruth.data
        mask = torch.nonzero(tdata.eq(self.padding_idx)).squeeze()
        log_likelihood = torch.gather(scores.data, 1, tdata.unsqueeze(1))
        tmp_ = self.one_hot.repeat(gtruth.size(0), 1)
        tmp_.scatter_(1, tdata.unsqueeze(1), self.confidence)
        if mask.dim() > 0:
            log_likelihood.index_fill_(0, mask, 0)
            tmp_.index_fill_(0, mask, 0)
        gtruth = Variable(tmp_, requires_grad=False)

    loss = self.criterion(scores, gtruth)
    if self.confidence < 1:
        # Default: report smoothed ppl.
        # loss_data = -log_likelihood.sum(0)
        loss_data = loss.data.clone()
    else:
        loss_data = loss.data.clone()

    stats = self._stats(loss_data, scores.data, target.view(-1).data)
    return loss, stats
def forward(self, log_prob, y_true, mask):
    mask = mask.float()
    log_P = torch.gather(log_prob.view(-1, log_prob.size(2)), 1,
                         y_true.contiguous().view(-1, 1))  # batch*time x 1
    log_P = log_P.view(y_true.size(0), y_true.size(1))  # batch x time
    log_P = log_P * mask  # batch x time
    sum_log_P = torch.sum(log_P, dim=1) / torch.sum(mask, dim=1)  # batch
    return -sum_log_P
def get_max_q_values_with_target(
    self, q_values, q_values_target, possible_actions_mask
):
    """
    Used in Q-learning update.

    :param q_values: tensor with shape (batch_size, action_dim); q-values
        scored by the online network.
    :param q_values_target: tensor with shape (batch_size, action_dim);
        q-values scored by the target network.
    :param possible_actions_mask: tensor with shape (batch_size, action_dim).
        possible_actions_mask[i][j] = 1 iff the agent can take action j
        from state i.
    """
    # The parametric DQN can create flattened q values so we reshape here.
    q_values = q_values.reshape(possible_actions_mask.shape)
    q_values_target = q_values_target.reshape(possible_actions_mask.shape)

    if self.double_q_learning:
        # Set q-values of impossible actions to a very large negative number.
        inverse_pna = 1 - possible_actions_mask
        impossible_action_penalty = self.ACTION_NOT_POSSIBLE_VAL * inverse_pna
        q_values = q_values + impossible_action_penalty
        # Select max_q action after scoring with online network
        max_q_values, max_indicies = torch.max(q_values, dim=1, keepdim=True)
        # Use q_values from target network for max_q action from online q_network
        # to decouple selection & scoring, preventing overestimation of q-values
        q_values = torch.gather(q_values_target, 1, max_indicies)
        return q_values, max_indicies
    else:
        return self.get_max_q_values(q_values, possible_actions_mask)
def lm_lstm(self, forw_sentence, forw_position, back_sentence, back_position, word_seq):
    '''
    return word representations with character-language-model

    args:
        forw_sentence (char_seq_len, batch_size) : char-level representation of sentence
        forw_position (word_seq_len, batch_size) : position of blank space in char-level representation of sentence
        back_sentence (char_seq_len, batch_size) : char-level representation of sentence (inverse order)
        back_position (word_seq_len, batch_size) : position of blank space in inversed char-level representation of sentence
        word_seq (word_seq_len, batch_size) : word-level representation of sentence
    '''
    self.set_batch_seq_size(forw_position)

    forw_emb = self.char_embeds(forw_sentence)
    back_emb = self.char_embeds(back_sentence)

    d_f_emb = self.dropout(forw_emb)
    d_b_emb = self.dropout(back_emb)

    forw_lstm_out, _ = self.forw_char_lstm(d_f_emb)
    back_lstm_out, _ = self.back_char_lstm(d_b_emb)

    forw_position = forw_position.unsqueeze(2).expand(
        self.word_seq_length, self.batch_size, self.char_hidden_dim)
    select_forw_lstm_out = torch.gather(forw_lstm_out, 0, forw_position)

    back_position = back_position.unsqueeze(2).expand(
        self.word_seq_length, self.batch_size, self.char_hidden_dim)
    select_back_lstm_out = torch.gather(back_lstm_out, 0, back_position)

    fb_lstm_out = self.dropout(
        torch.cat((select_forw_lstm_out, select_back_lstm_out), dim=2))

    if self.if_highway:
        char_out = self.fb2char(fb_lstm_out)
        d_char_out = self.dropout(char_out)
    else:
        d_char_out = fb_lstm_out

    word_emb = self.word_embeds(word_seq)
    d_word_emb = self.dropout(word_emb)

    word_input = torch.cat((d_word_emb, d_char_out), dim=2)

    lstm_out, _ = self.word_lstm_lm(word_input)
    d_lstm_out = self.dropout(lstm_out)

    return d_lstm_out
def NN(epoch, net, lemniscate, trainloader, testloader, recompute_memory=0):
    net.eval()
    net_time = AverageMeter()
    cls_time = AverageMeter()
    losses = AverageMeter()
    correct = 0.
    total = 0
    testsize = testloader.dataset.__len__()

    trainFeatures = lemniscate.memory.t()
    if hasattr(trainloader.dataset, 'imgs'):
        trainLabels = torch.LongTensor(
            [y for (p, y) in trainloader.dataset.imgs]).cuda()
    else:
        trainLabels = torch.LongTensor(trainloader.dataset.train_labels).cuda()

    if recompute_memory:
        transform_bak = trainloader.dataset.transform
        trainloader.dataset.transform = testloader.dataset.transform
        temploader = torch.utils.data.DataLoader(
            trainloader.dataset, batch_size=100, shuffle=False, num_workers=1)
        for batch_idx, (inputs, targets, indexes) in enumerate(temploader):
            inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = Variable(inputs, volatile=True), Variable(targets)
            batchSize = inputs.size(0)
            features = net(inputs)
            trainFeatures[:, batch_idx * batchSize:
                          batch_idx * batchSize + batchSize] = features.data.t()
        trainLabels = torch.LongTensor(temploader.dataset.train_labels).cuda()
        trainloader.dataset.transform = transform_bak

    end = time.time()
    for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        batchSize = inputs.size(0)
        features = net(inputs)
        net_time.update(time.time() - end)
        end = time.time()

        dist = torch.mm(features.data, trainFeatures)
        yd, yi = dist.topk(1, dim=1, largest=True, sorted=True)
        candidates = trainLabels.view(1, -1).expand(batchSize, -1)
        retrieval = torch.gather(candidates, 1, yi)

        retrieval = retrieval.narrow(1, 0, 1).clone().view(-1)
        yd = yd.narrow(1, 0, 1)

        total += targets.size(0)
        correct += retrieval.eq(targets.data).cpu().sum()

        cls_time.update(time.time() - end)
        end = time.time()

        print('Test [{}/{}]\t'
              'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t'
              'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t'
              'Top1: {:.2f}'.format(
                  total, testsize, correct * 100. / total,
                  net_time=net_time, cls_time=cls_time))

    return correct / total
def forward(self, input, target, mask):
    logprob_select = torch.gather(input, 1, target)
    out = torch.masked_select(logprob_select, mask)
    loss = -torch.sum(out) / mask.float().sum()
    return loss
def forward(self, x_de, x_en, update_baseline=True):
    bs = x_de.size(0)
    # x_de is bs,n_de. x_en is bs,n_en
    emb_de = self.embedding_de(x_de)  # bs,n_de,word_dim
    emb_en = self.embedding_en(x_en)  # bs,n_en,word_dim
    h0_enc = torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda()
    c0_enc = torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda()
    h0_dec = torch.zeros(self.n_layers, bs, self.hidden_dim).cuda()
    c0_dec = torch.zeros(self.n_layers, bs, self.hidden_dim).cuda()
    # hidden vars have dimension n_layers*n_directions,bs,hiddensz
    enc_h, _ = self.encoder(emb_de, (Variable(h0_enc), Variable(c0_enc)))
    # enc_h is bs,n_de,hiddensz*n_directions. ordering is different from last week because batch_first=True
    dec_h, _ = self.decoder(emb_en, (Variable(h0_dec), Variable(c0_dec)))
    # dec_h is bs,n_en,hidden_size*n_directions
    # we've gotten our encoder/decoder hidden states so we are ready to do attention
    # first let's get all our scores, which we can do easily since we are using dot-prod attention
    if self.directions == 2:
        scores = torch.bmm(self.dim_reduce(enc_h), dec_h.transpose(1, 2))
        # TODO: any easier ways to reduce dimension?
    else:
        scores = torch.bmm(enc_h, dec_h.transpose(1, 2))
        # (bs,n_de,hiddensz*n_directions) * (bs,hiddensz*n_directions,n_en) = (bs,n_de,n_en)

    reinforce_loss = 0  # we only use this variable for hard attention
    loss = 0
    avg_reward = 0
    # we just iterate to dec_h.size(1)-1, since there's </s> at the end of each sentence
    for t in range(dec_h.size(1) - 1):  # iterate over english words, with teacher forcing
        attn_dist = F.softmax(scores[:, :, t], dim=1)
        # bs,n_de. these are the alphas (attention scores for each german word)
        if self.attn_type == "hard":
            cat = torch.distributions.Categorical(attn_dist)
            attn_samples = cat.sample()  # bs. each element is a sample from categorical distribution
            one_hot = Variable(torch.zeros_like(attn_dist.data).scatter_(
                -1, attn_samples.data.unsqueeze(1), 1).cuda())
            # bs,n_de. made a bunch of one-hot vectors
            context = torch.bmm(one_hot.unsqueeze(1), enc_h).squeeze(1)
            # now we use the one-hot vectors to select correct hidden vectors from enc_h
            # (bs,1,n_de) * (bs,n_de,hiddensz*n_directions) = (bs,1,hiddensz*n_directions).
            # squeeze to bs,hiddensz*n_directions
        else:
            context = torch.bmm(attn_dist.unsqueeze(1), enc_h).squeeze(1)
            # same dimensions: (bs,1,n_de) * (bs,n_de,hiddensz*n_directions) = (bs,1,hiddensz*n_directions)
        # context is bs,hidden_size*n_directions
        # the rnn output and the context together make the decoder "hidden state",
        # which is bs,2*hidden_size*n_directions
        pred = self.vocab_layer(torch.cat([dec_h[:, t, :], context], 1))  # bs,len(EN.vocab)
        y = x_en[:, t + 1]  # bs. these are our labels
        no_pad = (y != pad_token)  # exclude english padding tokens
        reward = torch.gather(pred, 1, y.unsqueeze(1))  # bs,1
        # reward[i,1] = pred[i,y[i]]. this gets log prob of correct word for each batch.
        # similar to -crossentropy
        reward = reward.squeeze(1)[no_pad]  # less than bs
        if self.attn_type == "hard":
            # reinforce rule (just read the formula), with special baseline
            reinforce_loss -= (cat.log_prob(attn_samples[no_pad]) *
                               (reward - self.baseline).detach()).sum()
        loss -= reward.sum()  # minimizing loss is maximizing reward

    no_pad_total = (x_en[:, 1:] != pad_token).data.sum()  # TODO: i think this is right, right?
    loss /= no_pad_total
    reinforce_loss /= no_pad_total
    avg_reward = -loss.data[0]
    if update_baseline:  # update baseline as a moving average
        self.baseline = Variable(0.95 * self.baseline.data + 0.05 * avg_reward)
    return loss, reinforce_loss, avg_reward
def getMAE():
    userIdx = testData.user.values
    itemIdx = testData.item.values
    rates = testData.rate.values
    R = torch.mm(U, P.t())
    # row-major flattened lookup: R[u, i] == R.view(-1)[u * n_items + i].
    # Note this assumes itemIdx covers every item id, so len(set(itemIdx))
    # equals the number of columns of R.
    ratesPred = torch.gather(R.view(1, -1)[0], 0,
                             Variable(torch.LongTensor(
                                 userIdx * len(set(itemIdx)) + itemIdx)))
    diff_op = ratesPred - Variable(torch.FloatTensor(rates))
    MAE = diff_op.abs().mean()
    return MAE.data.numpy()[0]
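# The flattened-index gather above, shown in isolation (my own toy example):
# for a row-major matrix, R[u, i] == R.view(-1)[u * n_cols + i].
import torch

R = torch.arange(12.).view(3, 4)
u = torch.tensor([0, 2])
i = torch.tensor([3, 1])
print(torch.gather(R.view(-1), 0, u * 4 + i))  # tensor([3., 9.])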
def forward(self, article_sents, sent_nums, target):
    enc_out = self._encode(article_sents, sent_nums)
    bs, nt = target.size()
    d = enc_out.size(2)
    # select the encoder states of the target sentences as pointer inputs
    ptr_in = torch.gather(
        enc_out, dim=1, index=target.unsqueeze(2).expand(bs, nt, d)
    )
    output = self._extractor(enc_out, sent_nums, ptr_in)
    return output
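# The gather-with-expanded-index idiom used above, in isolation (toy sizes,
# my own example): pick whole d-dimensional rows by sentence index.
import torch

enc_out = torch.arange(24.).view(2, 3, 4)   # (bs=2, n_sents=3, d=4)
target = torch.tensor([[2, 0], [1, 1]])     # (bs, nt) sentence indices
ptr_in = torch.gather(enc_out, 1, target.unsqueeze(2).expand(2, 2, 4))
print(ptr_in.shape)                          # torch.Size([2, 2, 4])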
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).
    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.view(-1, 1).long()

    # Contribution to the negative log likelihood only comes from the exact indices
    # of the targets, as the target distributions are one-hot. Here we use torch.gather
    # to extract the indices of the num_classes dimension which contribute to the loss.
    # shape : (batch * sequence_length, 1)
    negative_log_likelihood_flat = -torch.gather(log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # shape : (batch_size,)
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)

    if batch_average:
        num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
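# Hedged usage sketch (my own data, not from the source): two sequences of
# length 5 over 7 classes, with the second sequence padded after 3 steps, so
# its last two positions get zero weight.
import torch

logits = torch.randn(2, 5, 7)
targets = torch.randint(0, 7, (2, 5))
weights = torch.tensor([[1., 1., 1., 1., 1.],
                        [1., 1., 1., 0., 0.]])
loss = sequence_cross_entropy_with_logits(logits, targets, weights)
print(loss.item())  # scalar, averaged over the two sequences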
def forward(self, input, target):
    logprob_select = torch.gather(input, 1, target)
    mask = target.data.gt(0)  # generate the mask
    if isinstance(input, Variable):
        mask = Variable(mask, volatile=input.volatile)
    out = torch.masked_select(logprob_select, mask)
    loss = -torch.sum(out)  # negated sum of the masked log-probs
    return loss
def emiss(self, T, idx, ignore_index=None):
    assert len(idx.shape) == 1
    bs = idx.shape[0]
    idx = idx.view(-1, 1).expand(bs, self.ns).unsqueeze(-1)
    emiss = torch.gather(self.emission[T], -1, idx).view(bs, 1, self.ns)
    if ignore_index is None:
        return emiss
    else:
        idx = idx.view(bs, 1, self.ns)
        mask = (idx != ignore_index).float()
        return emiss * mask
def log_sum_exp(vec, m_size):
    """
    calculate log of exp sum

    args:
        vec (batch_size, vanishing_dim, hidden_dim) : input tensor
        m_size : hidden_dim
    return:
        batch_size, hidden_dim
    """
    _, idx = torch.max(vec, 1)  # B * 1 * M
    max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # B * M
    return max_score.view(-1, m_size) + torch.log(torch.sum(
        torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)  # B * M
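# Sanity check (my own, not from the source): the subtract-the-max trick above
# should agree with torch.logsumexp over dim 1.
import torch

vec = torch.randn(3, 5, 4)                   # (batch_size, vanishing_dim, hidden_dim)
out = log_sum_exp(vec, m_size=4)
ref = torch.logsumexp(vec, dim=1)
print(torch.allclose(out, ref, atol=1e-6))   # expected: True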
def _score_candidates(self, cands, xe, encoder_output, hidden):
    # score each candidate separately
    # cands are exs_with_cands x cands_per_ex x words_per_cand
    # cview is total_cands x words_per_cand
    cview = cands.view(-1, cands.size(2))
    cands_xes = xe.expand(xe.size(0), cview.size(0), xe.size(2))
    sz = hidden.size()
    cands_hn = (
        hidden.view(sz[0], sz[1], 1, sz[2])
        .expand(sz[0], sz[1], cands.size(1), sz[2])
        .contiguous()
        .view(sz[0], -1, sz[2])
    )
    sz = encoder_output.size()
    cands_encoder_output = (
        encoder_output.contiguous()
        .view(sz[0], 1, sz[1], sz[2])
        .expand(sz[0], cands.size(1), sz[1], sz[2])
        .contiguous()
        .view(-1, sz[1], sz[2])
    )
    cand_scores = Variable(
        self.cand_scores.resize_(cview.size(0)).fill_(0))
    cand_lengths = Variable(
        self.cand_lengths.resize_(cview.size(0)).fill_(0))

    for i in range(cview.size(1)):
        output = self._apply_attention(cands_xes, cands_encoder_output, cands_hn) \
            if self.use_attention else cands_xes
        output, cands_hn = self.decoder(output, cands_hn)
        preds, scores = self.hidden_to_idx(output, dropout=False)
        cs = cview.select(1, i)
        non_nulls = cs.ne(self.NULL_IDX)
        cand_lengths += non_nulls.long()
        score_per_cand = torch.gather(scores, 1, cs.unsqueeze(1))
        cand_scores += score_per_cand.squeeze() * non_nulls.float()
        cands_xes = self.lt2dec(self.lt(cs).unsqueeze(0))

    # set empty scores to -1, so when divided by 0 they become -inf
    cand_scores -= cand_lengths.eq(0).float()
    # average the scores per token
    cand_scores /= cand_lengths.float()

    cand_scores = cand_scores.view(cands.size(0), cands.size(1))
    srtd_scores, text_cand_inds = cand_scores.sort(1, True)
    text_cand_inds = text_cand_inds.data
    return text_cand_inds
def sample_relax_given_class_k(logits, samp, k):
    cat = Categorical(logits=logits)
    b = samp  # torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B, 1)

    zs = []
    z_tildes = []
    for i in range(k):
        u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = logits + gumbels

        # z: relaxed sample conditioned on class b (reuses the noise u)
        u_b = torch.gather(input=u, dim=1, index=b.view(B, 1))
        z_tilde_b = -torch.log(-torch.log(u_b))
        z_tilde = -torch.log((-torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)
        z = z_tilde

        # z_tilde: second conditional sample that keeps u_b at class b
        # but draws fresh noise for the other coordinates
        u_b = torch.gather(input=u, dim=1, index=b.view(B, 1))
        z_tilde_b = -torch.log(-torch.log(u_b))
        u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
        z_tilde = -torch.log((-torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B, 1), src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    zs = torch.stack(zs)
    z_tildes = torch.stack(z_tildes)
    # average the k conditional samples
    z = torch.mean(zs, dim=0)
    z_tilde = torch.mean(z_tildes, dim=0)

    return z, z_tilde, logprob
def combine_matrices(self, prob_matrix, prob_matrix_scale, perm, last=False):
    # prob_matrix has shape bs x length x length + 1; prepend an extra
    # one-hot row to prob_matrix_scale so the shapes line up for bmm.
    length = prob_matrix_scale.size()[2]
    first = Variable(torch.zeros([self.batch_size, 1, length])).type(dtype)
    first[:, 0, 0] = 1.0
    prob_matrix_scale = torch.cat((first, prob_matrix_scale), 1)
    # argmax
    new_perm = self.discretize(prob_matrix_scale)
    perm = torch.gather(perm, 1, new_perm)
    # combine
    prob_matrix = torch.bmm(prob_matrix_scale, prob_matrix)
    return prob_matrix, prob_matrix_scale, perm
def masked_cross_entropy(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each data in a batch.

    Returns:
        loss: An average loss value masked by the length.

    The code is same as:

        weight = torch.ones(tgt_vocab_size)
        weight[padding_idx] = 0
        criterion = nn.CrossEntropyLoss(weight.cuda(), size_average)
        loss = criterion(logits_flat, target_flat)
    """
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))  ## (3, 16, 50000) => (3*16, 50000)
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = F.log_softmax(logits_flat, dim=1)  ## (3*16, 50000) => (3*16, 50000)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)  ## (3, 16) => (3*16, 1)
    # losses_flat: (batch * max_len, 1)
    ## the -log_prob of target token_id in each batch and time step
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)  ## (3*16, 1)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())  ## (3*16, 1) => (3, 16)
    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))  ## (3, 16)
    # Note: mask needs to be cast to float!
    losses = losses * mask.float()
    ## drop the padded positions so the loss matches the formula
    ## (the denominator must be the true number of words)
    loss = losses.sum() / mask.float().sum()

    # # (batch_size * max_tgt_len,)
    # ## the predicted token_id of max log_prob in each batch and time step
    # pred_flat = log_probs_flat.max(1)[1]  ## (3*16, 1)
    # # (batch_size * max_tgt_len,) => (batch_size, max_tgt_len) => (max_tgt_len, batch_size)
    # pred_seqs = pred_flat.view(*target.size()).transpose(0, 1).contiguous()  ## (3*16, 1) => (3, 16) => (16, 3)
    # # (batch_size, max_len) => (batch_size * max_tgt_len,)
    # mask_flat = mask.view(-1)  ## (3, 16) => (3*16)
    # # `.float()` IS VERY IMPORTANT !!!
    # # https://discuss.pytorch.org/t/batch-size-and-validation-accuracy/4066/3
    # num_corrects = int(pred_flat.eq(target_flat.squeeze(1)).masked_select(mask_flat).float().data.sum())  ## number of correct words
    # num_words = length.data.sum()  ## total number of words in this batch
    # return loss, pred_seqs, num_corrects, num_words

    return loss
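# `sequence_mask` is referenced above but not defined in this snippet; a minimal
# assumed implementation plus a usage sketch (both my own, for illustration):
import torch

def sequence_mask(sequence_length, max_len=None):
    # lengths (batch,) -> boolean mask (batch, max_len), True inside each sequence
    if max_len is None:
        max_len = int(sequence_length.max())
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)

logits = torch.randn(3, 16, 50)            # (batch, max_len, num_classes)
target = torch.randint(0, 50, (3, 16))
length = torch.tensor([16, 12, 7])
print(masked_cross_entropy(logits, target, length).item())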
def decode_z_order(self, batch_size, u_input, u_hiddens, u_input_1hot, u_last_hidden,
                   z_input, turn_states, sample_type, decoder_type,
                   qz_samples=None, qz_hiddens=None, m_input=None, m_hiddens=None,
                   m_input_1hot=None, mask_otlg=False):
    return_gmb = True if 'gumbel' in sample_type else False
    pv_z_pr = turn_states.get('pv_%s_pr' % decoder_type, None)
    pv_z_h = turn_states.get('pv_%s_h' % decoder_type, None)
    pv_z_id = turn_states.get('pv_%s_id' % decoder_type, None)

    z_prob, z_samples, gmb_samples = [], [], []
    log_pz = 0
    for si, sn in enumerate(self.reader.otlg.informable_slots):
        last_hidden = u_last_hidden[:-1]
        # last_hidden = (u_last_hidden[-1] + u_last_hidden[-2]).unsqueeze(0)
        z_eos_idx = self.vocab.encode(self.z_eos_map[sn])
        emb_zt = self.get_first_z_input(sn, batch_size, self.multi_domain)
        zero_vec = cuda_(torch.zeros(batch_size, 1, self.hidden_size))
        selc_read_u = selc_read_m = selc_read_pv_z = zero_vec
        if pv_z_pr is not None:
            b, e = si * self.z_length, (si + 1) * self.z_length
            pv_pr, pv_h, pv_idx = pv_z_pr[:, b:e], pv_z_h[:, b:e], pv_z_id[:, b:e]
        else:
            pv_pr, pv_h, pv_idx = None, None, None

        prev_zt = None
        for t in range(self.z_length):
            if decoder_type == 'pz':
                prob, last_hidden, gru_out, selc_read_u, selc_read_pv_z = \
                    self.pz_decoder(u_input, u_input_1hot, u_hiddens,
                                    pv_z_prob=pv_pr, pv_z_hidden=pv_h, pv_z_idx=pv_idx,
                                    emb_zt=emb_zt, last_hidden=last_hidden,
                                    selc_read_u=selc_read_u,
                                    selc_read_pv_z=selc_read_pv_z)
            else:
                prob, last_hidden, gru_out, selc_read_u, selc_read_m, selc_read_pv_z, gmb_samp = \
                    self.qz_decoder(u_input, u_input_1hot, u_hiddens,
                                    m_input, m_input_1hot, m_hiddens,
                                    pv_z_prob=pv_pr, pv_z_hidden=pv_h, pv_z_idx=pv_idx,
                                    emb_zt=emb_zt, last_hidden=last_hidden,
                                    selc_read_u=selc_read_u, selc_read_m=selc_read_m,
                                    selc_read_pv_z=selc_read_pv_z,
                                    temp=self.gumbel_temp, return_gmb=return_gmb)
            if mask_otlg:
                prob = self.mask_probs(prob, tokens_allow=self.reader.slot_value_mask[sn])

            if sample_type == 'supervised':
                zt = z_input[sn][:, t]
            elif sample_type == 'top1':
                zt = torch.topk(prob, 1)[1]
            elif sample_type == 'topk':
                topk_probs, topk_words = torch.topk(prob.squeeze(1), cfg.topk_num)
                widx = torch.multinomial(topk_probs, 1, replacement=True)
                zt = torch.gather(topk_words, 1, widx)  # [B]
            elif sample_type == 'posterior':
                zt = qz_samples[:, si * self.z_length + t]
            elif 'gumbel' in sample_type:
                zt = torch.argmax(gmb_samp, dim=1)  # [B]
                emb_zt = torch.matmul(gmb_samp, self.embedding.weight).unsqueeze(1)  # [B, 1, H]
                zt, prev_zt, gmb_samp = self.mask_samples(zt, prev_zt, batch_size,
                                                          z_eos_idx, gmb_samp, True)
                gmb_samples.append(gmb_samp)

            if 'gumbel' not in sample_type:
                emb_zt = self.embedding(zt.view(-1, 1))
                prob_zt = torch.gather(prob, 1, zt.view(-1, 1)).squeeze(1)  # [B, 1]
                log_prob_zt = torch.log(prob_zt)
                zt, prev_zt, log_prob_zt = self.mask_samples(zt, prev_zt, batch_size,
                                                             z_eos_idx, log_prob_zt)
                log_pz += log_prob_zt

            z_samples.append(zt.view(-1))
            z_prob.append(prob)

    z_prob = torch.stack(z_prob, dim=1)  # [B*K, Tz, V]
    z_samples = torch.stack(z_samples, dim=1)  # [B*K, Tz]
    if sample_type == 'posterior':
        z_samples, z_hiddens = qz_samples, qz_hiddens
    elif 'gumbel' not in sample_type:
        z_hiddens, z_last_hidden = self.z_encoder(z_samples, input_type='index')
    else:
        z_gumbel = torch.stack(gmb_samples, dim=1)  # [B, Tz, V]
        z_gumbel = torch.matmul(z_gumbel, self.embedding.weight)  # [B, Tz, E]
        z_hiddens, z_last_hidden = self.z_encoder(z_gumbel, input_type='embedding')

    retain = self.prev_z_continuous
    turn_states['pv_%s_h' % decoder_type] = z_hiddens if retain else z_hiddens.detach()
    turn_states['pv_%s_pr' % decoder_type] = z_prob if retain else z_prob.detach()
    turn_states['pv_%s_id' % decoder_type] = z_samples if retain else z_samples.detach()
    return z_prob, z_samples, z_hiddens, turn_states, log_pz
def _get_parallel_step_context(self, embeddings, state, from_depot=False):
    """
    Returns the context per step, optionally for multiple steps at once
    (for efficient evaluation of the model)

    :param embeddings: (batch_size, graph_size, embed_dim)
    :param prev_a: (batch_size, num_steps)
    :param first_a: Only used when num_steps = 1, action of first step or None if first step
    :return: (batch_size, num_steps, context_dim)
    """
    current_node = state.get_current_node()
    batch_size, num_steps = current_node.size()

    if self.is_vrp:
        # Embedding of previous node + remaining capacity
        if from_depot:
            # 1st dimension is node idx, but we do not squeeze it since we want to insert step dimension
            # i.e. we actually want embeddings[:, 0, :][:, None, :] which is equivalent
            return torch.cat(
                (
                    embeddings[:, 0:1, :].expand(batch_size, num_steps, embeddings.size(-1)),
                    # used capacity is 0 after visiting depot
                    self.problem.VEHICLE_CAPACITY - torch.zeros_like(state.used_capacity[:, :, None])
                ),
                -1
            )
        else:
            return torch.cat(
                (
                    torch.gather(
                        embeddings, 1,
                        current_node.contiguous()
                        .view(batch_size, num_steps, 1)
                        .expand(batch_size, num_steps, embeddings.size(-1))
                    ).view(batch_size, num_steps, embeddings.size(-1)),
                    self.problem.VEHICLE_CAPACITY - state.used_capacity[:, :, None]
                ),
                -1
            )
    elif self.is_orienteering or self.is_pctsp:
        return torch.cat(
            (
                torch.gather(
                    embeddings, 1,
                    current_node.contiguous()
                    .view(batch_size, num_steps, 1)
                    .expand(batch_size, num_steps, embeddings.size(-1))
                ).view(batch_size, num_steps, embeddings.size(-1)),
                (
                    state.get_remaining_length()[:, :, None]
                    if self.is_orienteering
                    else state.get_remaining_prize_to_collect()[:, :, None]
                )
            ),
            -1
        )
    else:  # TSP
        if num_steps == 1:
            # We need to special case if we have only 1 step, may be the first or not
            if state.i.item() == 0:
                # First and only step, ignore prev_a (this is a placeholder)
                return self.W_placeholder[None, None, :].expand(
                    batch_size, 1, self.W_placeholder.size(-1))
            else:
                return embeddings.gather(
                    1,
                    torch.cat((state.first_a, current_node), 1)[:, :, None]
                    .expand(batch_size, 2, embeddings.size(-1))
                ).view(batch_size, 1, -1)

        # More than one step, assume always starting with first
        embeddings_per_step = embeddings.gather(
            1,
            current_node[:, 1:, None].expand(batch_size, num_steps - 1, embeddings.size(-1))
        )
        return torch.cat(
            (
                # First step placeholder, cat in dim 1 (time steps)
                self.W_placeholder[None, None, :].expand(batch_size, 1, self.W_placeholder.size(-1)),
                # Second step, concatenate embedding of first with embedding of
                # current/previous (in dim 2, context dim)
                torch.cat(
                    (
                        embeddings_per_step[:, 0:1, :].expand(batch_size, num_steps - 1, embeddings.size(-1)),
                        embeddings_per_step
                    ), 2
                )
            ), 1
        )
def forward(self, x, r_ij, neighbors, pairwise_mask, f_ij=None, softmax=None):
    """Compute convolution block.

    Args:
        x (torch.Tensor): input representation/embedding of atomic
            environments with (N_b, N_a, n_in) shape.
        r_ij (torch.Tensor): interatomic distances of (N_b, N_a, N_nbh) shape.
        neighbors (torch.Tensor): indices of neighbors of (N_b, N_a, N_nbh) shape.
        pairwise_mask (torch.Tensor): mask to filter out non-existing neighbors
            introduced via padding.
        f_ij (torch.Tensor, optional): expanded interatomic distances in a basis.
            If None, r_ij.unsqueeze(-1) is used.

    Returns:
        torch.Tensor: block output with (N_b, N_a, n_out) shape.
    """
    if f_ij is None:
        # shape [batch, num_atoms, num_neighbors (num_atoms-1), gaussian_exp]
        f_ij = r_ij.unsqueeze(-1)

    # -------------NEW----------------#
    if self.n_heads_weights > 0:
        A = self.Attention(x)  # attention in weight generation
        # concatenate multi-headed attention to distances
        f_ij = torch.cat((f_ij, A), dim=3)
        # f_ij = self.dropout(f_ij)
    # --------------------------------#

    # pass expanded interatomic distances through filter block
    W = self.filter_network(f_ij)
    # print(W.shape, 'Wsize')

    # apply cutoff
    if self.cutoff_network is not None:
        C = self.cutoff_network(r_ij)
        W = W * C.unsqueeze(-1)

    # pass initial embeddings through Dense layer (to correct size for number of filters)
    y = self.in2f(x)

    # reshape y for element-wise multiplication by W
    nbh_size = neighbors.size()
    nbh = neighbors.view(-1, nbh_size[1] * nbh_size[2], 1)
    nbh = nbh.expand(-1, -1, y.size(2))
    y = torch.gather(y, 1, nbh)
    y = y.view(nbh_size[0], nbh_size[1], nbh_size[2], -1)
    # print(y.shape, 'yshape')

    # element-wise multiplication, aggregating and Dense layer
    # -----------NEW: attention in convolution--------------#
    if self.n_heads_conv > 0:
        W = self.AttentionConv(x, Weights=W)  # single head to match weight size
    # added softmax
    if softmax is not None:
        W = self.softmax(W)
    # ------------------------------------------------------#

    y = y * W
    y = self.agg(y, pairwise_mask)
    y = self.f2out(y)
    return y
def _run_one_fw(self, pixel_model, pixel_inp, cat_var, target, base_eps, avoid_target=True):
    batch_size, channels, height, width = pixel_inp.size()
    pixel_inp_jpeg = self._jpeg_cat(pixel_inp, cat_var, base_eps, batch_size, height, width)
    s = pixel_model(pixel_inp_jpeg)

    for it in range(self.nb_its):
        loss = self.criterion(s, target)
        loss.backward()
        if avoid_target:
            grad = cat_var.grad.data
        else:
            grad = -cat_var.grad.data

        def where_float(cond, if_true, if_false):
            return cond.float() * if_true + (1 - cond.float()) * if_false

        def where_long(cond, if_true, if_false):
            return cond.long() * if_true + (1 - cond.long()) * if_false

        abs_grad = torch.abs(grad).view(batch_size, -1)
        num_pixels = abs_grad.size()[1]
        sign_grad = torch.sign(grad)

        bound = where_float(sign_grad > 0, self.l1_max - cat_var,
                            cat_var + self.l1_max).view(batch_size, -1)

        k_min = torch.zeros((batch_size, 1), dtype=torch.long,
                            requires_grad=False, device='cuda')
        k_max = torch.ones((batch_size, 1), dtype=torch.long,
                           requires_grad=False, device='cuda') * num_pixels

        # cum_bnd[k] is meant to track the L1 norm we end up with if we take
        # the k indices with the largest gradient magnitude and push them to
        # their boundary values (0 or 255)
        values, indices = torch.sort(abs_grad, descending=True)
        bnd = torch.gather(bound, 1, indices)
        # subtract bnd because we don't want the cumsum to include the final element
        cum_bnd = torch.cumsum(bnd, 1) - bnd

        # this is hard-coded as floor(log_2(256 * 256 * 3))
        for _ in range(17):
            k_mid = (k_min + k_max) // 2
            l1norms = torch.gather(cum_bnd, 1, k_mid)
            k_min = where_long(l1norms > base_eps, k_min, k_mid)
            k_max = where_long(l1norms > base_eps, k_mid, k_max)

        # next want to set the gradient of indices[0:k_min] to their corresponding bound
        magnitudes = torch.zeros((batch_size, num_pixels),
                                 requires_grad=False, device='cuda')
        for bi in range(batch_size):
            magnitudes[bi, indices[bi, :k_min[bi, 0]]] = bnd[bi, :k_min[bi, 0]]
            magnitudes[bi, indices[bi, k_min[bi, 0]]] = base_eps[bi] - cum_bnd[bi, k_min[bi, 0]]

        delta_it = sign_grad * magnitudes.view(cat_var.size())
        # These should always be exactly epsilon
        # l1_check = torch.norm(delta_it.view(batch_size, -1), 1.0, dim=1) / num_pixels
        # print('l1_check: %s' % l1_check)
        cat_var.data = cat_var.data + (delta_it - cat_var.data) / (it + 1.0)

        if it != self.nb_its - 1:
            # self.jpeg scales rounding_vars by base_eps, so we divide to rescale
            # its coordinates to [-1, 1]
            cat_var_temp = cat_var / base_eps[:, None]
            pixel_inp_jpeg = self._jpeg_cat(pixel_inp, cat_var_temp, base_eps,
                                            batch_size, height, width)
            s = pixel_model(pixel_inp_jpeg)
        cat_var.grad.data.zero_()
    return cat_var
def forward(self, obj_heads, reg_heads, cls_heads, batch_anchors):
    device = cls_heads[0].device
    with torch.no_grad():
        filter_scores, filter_score_classes, filter_boxes = [], [], []
        for per_level_obj_pred, per_level_reg_pred, per_level_cls_pred, per_level_anchors in zip(
                obj_heads, reg_heads, cls_heads, batch_anchors):
            per_level_obj_pred = per_level_obj_pred.view(
                per_level_obj_pred.shape[0], -1, per_level_obj_pred.shape[-1])
            per_level_reg_pred = per_level_reg_pred.view(
                per_level_reg_pred.shape[0], -1, per_level_reg_pred.shape[-1])
            per_level_cls_pred = per_level_cls_pred.view(
                per_level_cls_pred.shape[0], -1, per_level_cls_pred.shape[-1])
            per_level_anchors = per_level_anchors.view(
                per_level_anchors.shape[0], -1, per_level_anchors.shape[-1])

            per_level_obj_pred = torch.sigmoid(per_level_obj_pred)
            per_level_cls_pred = torch.sigmoid(per_level_cls_pred)

            # snap per_level_reg_pred from tx,ty,tw,th -> x_center,y_center,w,h -> x_min,y_min,x_max,y_max
            per_level_reg_pred[:, :, 0:2] = (
                torch.sigmoid(per_level_reg_pred[:, :, 0:2]) +
                per_level_anchors[:, :, 0:2]) * per_level_anchors[:, :, 4:5]
            # pred_bboxes_wh = exp(twh) * anchor_wh / stride
            per_level_reg_pred[:, :, 2:4] = torch.exp(
                per_level_reg_pred[:, :, 2:4]
            ) * per_level_anchors[:, :, 2:4] / per_level_anchors[:, :, 4:5]
            per_level_reg_pred[:, :, 0:2] = (
                per_level_reg_pred[:, :, 0:2] - 0.5 * per_level_reg_pred[:, :, 2:4])
            per_level_reg_pred[:, :, 2:4] = (
                per_level_reg_pred[:, :, 2:4] + per_level_reg_pred[:, :, 0:2])
            per_level_reg_pred = per_level_reg_pred.int()
            per_level_reg_pred[:, :, 0] = torch.clamp(per_level_reg_pred[:, :, 0], min=0)
            per_level_reg_pred[:, :, 1] = torch.clamp(per_level_reg_pred[:, :, 1], min=0)
            per_level_reg_pred[:, :, 2] = torch.clamp(per_level_reg_pred[:, :, 2],
                                                      max=self.image_w - 1)
            per_level_reg_pred[:, :, 3] = torch.clamp(per_level_reg_pred[:, :, 3],
                                                      max=self.image_h - 1)

            per_level_scores, per_level_score_classes = torch.max(per_level_cls_pred, dim=2)
            per_level_scores = per_level_scores * per_level_obj_pred.squeeze(-1)
            if per_level_scores.shape[1] >= self.top_n:
                per_level_scores, indexes = torch.topk(per_level_scores,
                                                       self.top_n,
                                                       dim=1,
                                                       largest=True,
                                                       sorted=True)
                per_level_score_classes = torch.gather(per_level_score_classes, 1, indexes)
                per_level_reg_pred = torch.gather(per_level_reg_pred, 1,
                                                  indexes.unsqueeze(-1).repeat(1, 1, 4))

            filter_scores.append(per_level_scores)
            filter_score_classes.append(per_level_score_classes)
            filter_boxes.append(per_level_reg_pred)

        filter_scores = torch.cat(filter_scores, axis=1)
        filter_score_classes = torch.cat(filter_score_classes, axis=1)
        filter_boxes = torch.cat(filter_boxes, axis=1)

        batch_scores, batch_classes, batch_pred_bboxes = [], [], []
        for scores, score_classes, pred_bboxes in zip(
                filter_scores, filter_score_classes, filter_boxes):
            score_classes = score_classes[scores > self.min_score_threshold].float()
            pred_bboxes = pred_bboxes[scores > self.min_score_threshold].float()
            scores = scores[scores > self.min_score_threshold].float()

            one_image_scores = (-1) * torch.ones((self.max_detection_num, ), device=device)
            one_image_classes = (-1) * torch.ones((self.max_detection_num, ), device=device)
            one_image_pred_bboxes = (-1) * torch.ones((self.max_detection_num, 4),
                                                      device=device)

            if scores.shape[0] != 0:
                # Sort boxes
                sorted_scores, sorted_indexes = torch.sort(scores, descending=True)
                sorted_score_classes = score_classes[sorted_indexes]
                sorted_pred_bboxes = pred_bboxes[sorted_indexes]

                keep = nms(sorted_pred_bboxes, sorted_scores, self.nms_threshold)
                keep_scores = sorted_scores[keep]
                keep_classes = sorted_score_classes[keep]
                keep_pred_bboxes = sorted_pred_bboxes[keep]

                final_detection_num = min(self.max_detection_num, keep_scores.shape[0])

                one_image_scores[0:final_detection_num] = keep_scores[0:final_detection_num]
                one_image_classes[0:final_detection_num] = keep_classes[0:final_detection_num]
                one_image_pred_bboxes[0:final_detection_num, :] = keep_pred_bboxes[
                    0:final_detection_num, :]

            one_image_scores = one_image_scores.unsqueeze(0)
            one_image_classes = one_image_classes.unsqueeze(0)
            one_image_pred_bboxes = one_image_pred_bboxes.unsqueeze(0)

            batch_scores.append(one_image_scores)
            batch_classes.append(one_image_classes)
            batch_pred_bboxes.append(one_image_pred_bboxes)

        batch_scores = torch.cat(batch_scores, axis=0)
        batch_classes = torch.cat(batch_classes, axis=0)
        batch_pred_bboxes = torch.cat(batch_pred_bboxes, axis=0)

        # batch_scores shape:[batch_size,max_detection_num]
        # batch_classes shape:[batch_size,max_detection_num]
        # batch_pred_bboxes shape:[batch_size,max_detection_num,4]
        return batch_scores, batch_classes, batch_pred_bboxes
def maskNLLLoss(self, inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(self.device)
    return loss, nTotal.item()
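# Hedged standalone check of the same computation (my own example; `self` and
# device handling dropped, and the gathered column squeezed so a (batch,)
# boolean mask lines up with it):
import torch

inp = torch.softmax(torch.randn(4, 10), dim=1)   # per-class probabilities
target = torch.randint(0, 10, (4,))
mask = torch.tensor([True, True, True, False])
ce = -torch.log(torch.gather(inp, 1, target.view(-1, 1))).squeeze(1)
print(ce.masked_select(mask).mean().item(), int(mask.sum()))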
def forward(self, memory_cells, query, alignments, copy_source):
    """Compute attention with soft-copy.

    Args:
        memory_cells (SequenceBatch): of shape (batch_size, num_cells, memory_dim)
        query (Variable): of shape (batch_size, query_dim)
        alignments (SequenceBatch): int-valued, of shape (batch_size, num_cells).
            If something has no alignment, it will have value 0 in the mask.
        copy_source (Variable): of shape (batch_size, num_candidates)

    This behaves like normal attention, except we boost the exponentiated logits:

        exp_logits[i][j] += copy_source[i][alignments[i][j]]

    This is inspired by:
        "Incorporating Copying Mechanism in Sequence-to-Sequence Learning"
        http://www.aclweb.org/anthology/P16-1154

    Returns:
        AttentionOutput
    """
    if not self._is_subset(alignments.mask, memory_cells.mask):
        raise ValueError('Alignments mask must be a subset of memory cells mask.')

    base_attn = self._base_attention(memory_cells, query)  # AttentionOutput
    exp_logits = torch.exp(base_attn.logits)  # (batch_size, num_cells)

    # y = torch.gather(x, dim=1, index=index)
    # y[i][j] = x[i][index[i][j]]
    boost = torch.gather(copy_source, dim=1, index=alignments.values)

    # no boost for items with no alignment
    boost = boost * alignments.mask

    # boost the exponentiated logits
    boosted_exp_logits = exp_logits + boost

    # normalize to compute final weights
    normalizer = torch.sum(boosted_exp_logits, 1).expand_as(boosted_exp_logits)
    # (batch_size, num_cells)
    weights = boosted_exp_logits / normalizer
    weights = Attention._mask_weights(weights, memory_cells.mask)

    if not np.isfinite(weights.data.sum()):
        raise ValueError('Some attention weights are NaN')
    # TODO(kelvin): need to avoid numerical precision issues
    # TODO(kelvin): need to avoid division by zero

    # compute context
    context = Attention._context_from_weights(weights, memory_cells)

    logits = torch.log(boosted_exp_logits)
    return SoftCopyAttentionOutput(
        weights=weights, context=context, logits=logits,
        orig_logits=base_attn.logits, boost=boost)
def forward(self, *input):
    [token_embeddings, input_mask_variable, conversation_mask,
     max_num_utterances_batch] = input

    conversation_batch_size = int(token_embeddings.shape[0] / max_num_utterances_batch)

    if self.args.fixed_utterance_encoder:
        utterance_encodings = token_embeddings
    else:
        utterance_encodings = self.dialogue_embedder.utterance_encoder(
            token_embeddings, input_mask_variable)

    utterance_encodings = utterance_encodings.view(
        conversation_batch_size, max_num_utterances_batch, utterance_encodings.shape[1])
    utterance_encodings_next = utterance_encodings[:, 1:, :].contiguous()
    utterance_encodings_prev = utterance_encodings[:, 0:-1, :].contiguous()

    conversation_encoded = self.dialogue_embedder([
        token_embeddings, input_mask_variable, conversation_mask,
        max_num_utterances_batch])
    conversation_encoded_forward = conversation_encoded[:, 0, :]
    conversation_encoded_backward = conversation_encoded[:, 1, :]
    # conversation_encoded_forward = conversation_encoded.view(conversation_encoded.shape[0], 1, -1).squeeze(1)
    # conversation_encoded_backward = conversation_encoded.view(conversation_encoded.shape[0], 1, -1).squeeze(1)

    conversation_encoded_forward_reassembled = conversation_encoded_forward.view(
        conversation_batch_size, max_num_utterances_batch,
        conversation_encoded_forward.shape[1])
    conversation_encoded_backward_reassembled = conversation_encoded_backward.view(
        conversation_batch_size, max_num_utterances_batch,
        conversation_encoded_backward.shape[1])

    # Shift to prepare next and previous utterance encodings
    conversation_encoded_current1 = conversation_encoded_forward_reassembled[:, 0:-1, :].contiguous()
    conversation_encoded_next = conversation_encoded_forward_reassembled[:, 1:, :].contiguous()
    conversation_mask_next = conversation_mask[:, 1:].contiguous()

    conversation_encoded_current2 = conversation_encoded_backward_reassembled[:, 1:, :].contiguous()
    conversation_encoded_previous = conversation_encoded_backward_reassembled[:, 0:-1, :].contiguous()
    # conversation_mask_previous = conversation_mask[:, 0:-1].contiguous()

    # Gold Labels
    gold_indices = variable(
        LongTensor(range(conversation_encoded_current1.shape[1]))).view(-1, 1).repeat(
            conversation_batch_size, 1)

    # Linear transformation of both utterance representations
    transformed_current1 = self.current_dl_trasnformer1(conversation_encoded_current1)
    transformed_current2 = self.current_dl_trasnformer2(conversation_encoded_current2)
    transformed_next = self.next_dl_trasnformer(conversation_encoded_next)
    transformed_prev = self.prev_dl_trasnformer(conversation_encoded_previous)
    # transformed_next = self.next_dl_trasnformer(utterance_encodings_next)
    # transformed_prev = self.prev_dl_trasnformer(utterance_encodings_prev)

    # Output layer: Generate Scores for next and prev utterances
    next_logits = torch.bmm(transformed_current1, transformed_next.transpose(2, 1))
    prev_logits = torch.bmm(transformed_current2, transformed_prev.transpose(2, 1))

    # Computing custom masked cross entropy
    next_log_probs = F.log_softmax(next_logits, dim=2)
    prev_log_probs = F.log_softmax(prev_logits, dim=2)

    losses_next = -torch.gather(
        next_log_probs.view(next_log_probs.shape[0] * next_log_probs.shape[1], -1),
        dim=1, index=gold_indices)
    losses_prev = -torch.gather(
        prev_log_probs.view(prev_log_probs.shape[0] * prev_log_probs.shape[1], -1),
        dim=1, index=gold_indices)

    losses_masked = \
        (losses_next.squeeze(1) * conversation_mask_next.view(
            conversation_mask_next.shape[0] * conversation_mask_next.shape[1])) \
        + (losses_prev.squeeze(1) * conversation_mask_next.view(
            conversation_mask_next.shape[0] * conversation_mask_next.shape[1]))

    loss = losses_masked.sum() / (2 * conversation_mask_next.float().sum())

    return loss
def forward(self, im_data, im_info, gt_boxes, num_boxes):
    # shapes: im_data [1,c,w,h], im_info [1,3], gt_boxes [1,20,5], num_boxes [1]
    batch_size = im_data.size(0)  # im_data is the raw image blob [1,3,850,600]

    im_info = im_info.data
    gt_boxes = gt_boxes.data
    num_boxes = num_boxes.data

    # feed image data to base model to obtain base feature map
    base_feat = self.RCNN_base(im_data)

    # feed base feature map to RPN to obtain rois
    rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes, num_boxes)

    # if it is training phase, then use ground truth bboxes for refining
    if self.training:
        roi_data = self.RCNN_proposal_target(rois, gt_boxes, num_boxes)
        rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data

        rois_label = Variable(rois_label.view(-1).long())
        rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
        rois_inside_ws = Variable(rois_inside_ws.view(-1, rois_inside_ws.size(2)))
        rois_outside_ws = Variable(rois_outside_ws.view(-1, rois_outside_ws.size(2)))
    else:
        rois_label = None
        rois_target = None
        rois_inside_ws = None
        rois_outside_ws = None
        rpn_loss_cls = 0
        rpn_loss_bbox = 0

    rois = Variable(rois)
    # At test time, rois has shape [1,300,5]; the first column is all zeros and
    # does not encode a roi label, it only identifies the batch index.
    # gt_boxes has shape (x,5), where x is the number of objects.

    # do roi pooling based on predicted rois
    # POOLING_MODE = align
    pooled_feat = self.RCNN_roi_align(base_feat, rois.view(-1, 5))

    # feed pooled feature to top model
    pooled_feat = self.head_to_tail(pooled_feat)

    # compute bbox offset: box predictions computed from the roi features
    # extracted after roi pooling
    bbox_pred = self.RCNN_bbox_pred(pooled_feat)
    if self.training and not self.class_agnostic:
        # select the corresponding columns according to roi labels
        bbox_pred_view = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
        bbox_pred_select = torch.gather(
            bbox_pred_view, 1,
            rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
        bbox_pred = bbox_pred_select.squeeze(1)

    # compute object classification probability
    cls_score = self.RCNN_cls_score(pooled_feat)
    cls_prob = F.softmax(cls_score, 1)
    # at test time cls_score is [300, 21] and bbox_pred is [300, 84]

    RCNN_loss_cls = 0
    RCNN_loss_bbox = 0

    if self.training:
        # classification loss
        RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)
        # bounding box regression L1 loss
        RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)

    cls_prob = cls_prob.view(batch_size, rois.size(1), -1)  # at test time [1, 300, 21]
    bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)  # at test time [1, 300, 84]

    return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label
def forward(self, source: List[List[str]], target: List[List[str]]) -> torch.Tensor: """ Take a mini-batch of source and target sentences, compute the log-likelihood of target sentences under the language models learned by the NMT system. @param source (List[List[str]]): list of source sentence tokens @param target (List[List[str]]): list of target sentence tokens, wrapped by `<s>` and `</s>` @returns scores (Tensor): a variable/tensor of shape (b, ) representing the log-likelihood of generating the gold-standard target sentence for each example in the input batch. Here b = batch size. """ # Compute sentence lengths source_lengths = [len(s) for s in source] # Convert list of lists into tensors ## A4 code # source_padded = self.vocab.src.to_input_tensor(source, device=self.device) # Tensor: (src_len, b) # target_padded = self.vocab.tgt.to_input_tensor(target, device=self.device) # Tensor: (tgt_len, b) # enc_hiddens, dec_init_state = self.encode(source_padded, source_lengths) # enc_masks = self.generate_sent_masks(enc_hiddens, source_lengths) # combined_outputs = self.decode(enc_hiddens, enc_masks, dec_init_state, target_padded) ## End A4 code ### YOUR CODE HERE for part 1k ### TODO: ### Modify the code lines above as needed to fetch the character-level tensor ### to feed into encode() and decode(). You should: ### - Keep `target_padded` from A4 code above for predictions ### - Add `source_padded_chars` for character level padded encodings for source ### - Add `target_padded_chars` for character level padded encodings for target ### - Modify calls to encode() and decode() to use the character level encodings target_padded = self.vocab.tgt.to_input_tensor(target, device=self.device) source_padded_chars = self.vocab.src.to_input_tensor_char( source, device=self.device) target_padded_chars = self.vocab.tgt.to_input_tensor_char( target, device=self.device) enc_hiddens, dec_init_state = self.encode(source_padded_chars, source_lengths) enc_masks = self.generate_sent_masks(enc_hiddens, source_lengths) combined_outputs = self.decode(enc_hiddens, enc_masks, dec_init_state, target_padded_chars) ### END YOUR CODE P = F.log_softmax(self.target_vocab_projection(combined_outputs), dim=-1) # Zero out probabilities for which we have nothing in the target text target_masks = (target_padded != self.vocab.tgt['<pad>']).float() # Compute log probability of generating true target words target_gold_words_log_prob = torch.gather( P, index=target_padded[1:].unsqueeze(-1), dim=-1).squeeze(-1) * target_masks[1:] scores = target_gold_words_log_prob.sum() # mhahn2 Small modification from A4 code. if self.charDecoder is not None: max_word_len = target_padded_chars.shape[-1] target_words = target_padded[1:].contiguous().view(-1) target_chars = target_padded_chars[1:].view(-1, max_word_len) target_outputs = combined_outputs.view(-1, 256) target_chars_oov = target_chars # torch.index_select(target_chars, dim=0, index=oovIndices) rnn_states_oov = target_outputs # torch.index_select(target_outputs, dim=0, index=oovIndices) oovs_losses = self.charDecoder.train_forward( target_chars_oov.t(), (rnn_states_oov.unsqueeze(0), rnn_states_oov.unsqueeze(0))) scores = scores - oovs_losses return scores
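# Sketch of the gold-word scoring above, with toy sizes. P holds log-probs
# for the tgt_len-1 decoding steps; gathering target_padded[1:] along the
# vocab dim scores each gold word, and the pad mask zeroes padded positions.
import torch
import torch.nn.functional as F

tgt_len, b, V, pad_id = 5, 2, 10, 0
P = F.log_softmax(torch.randn(tgt_len - 1, b, V), dim=-1)
target_padded = torch.randint(1, V, (tgt_len, b))  # row 0 plays the role of <s>
target_masks = (target_padded != pad_id).float()
gold_log_prob = torch.gather(P, index=target_padded[1:].unsqueeze(-1),
                             dim=-1).squeeze(-1) * target_masks[1:]
scores = gold_log_prob.sum(dim=0)  # per-sentence log-likelihood, shape (b,)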
def forward(self, im_data, im_info, gt_boxes, num_boxes): batch_size = im_data.size(0) im_info = im_info.data gt_boxes = gt_boxes.data num_boxes = num_boxes.data # feed image data to base model to obtain base feature map base_feat = self.RCNN_base(im_data) # feed base feature map to RPN to obtain rois rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn( base_feat, im_info, gt_boxes, num_boxes) # if it is training phase, then use ground truth bboxes for refining if self.training: roi_data = self.RCNN_proposal_target(rois, gt_boxes, num_boxes) rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data rois_label = Variable(rois_label.view(-1).long()) rois_target = Variable(rois_target.view(-1, rois_target.size(2))) rois_inside_ws = Variable( rois_inside_ws.view(-1, rois_inside_ws.size(2))) rois_outside_ws = Variable( rois_outside_ws.view(-1, rois_outside_ws.size(2))) else: rois_label = None rois_target = None rois_inside_ws = None rois_outside_ws = None rpn_loss_cls = 0 rpn_loss_bbox = 0 rois = Variable(rois) # do roi pooling based on predicted rois if cfg.POOLING_MODE == 'crop': # pdb.set_trace() # pooled_feat_anchor = _crop_pool_layer(base_feat, rois.view(-1, 5)) grid_xy = _affine_grid_gen(rois.view(-1, 5), base_feat.size()[2:], self.grid_size) grid_yx = torch.stack( [grid_xy.data[:, :, :, 1], grid_xy.data[:, :, :, 0]], 3).contiguous() pooled_feat = self.RCNN_roi_crop(base_feat, Variable(grid_yx).detach()) if cfg.CROP_RESIZE_WITH_MAX_POOL: pooled_feat = F.max_pool2d(pooled_feat, 2, 2) elif cfg.POOLING_MODE == 'align': pooled_feat = self.RCNN_roi_align(base_feat, rois.view(-1, 5)) elif cfg.POOLING_MODE == 'pool': pooled_feat = self.RCNN_roi_pool(base_feat, rois.view(-1, 5)) # feed pooled features to top model pooled_feat = self._head_to_tail(pooled_feat) # compute bbox offset bbox_pred = self.RCNN_bbox_pred(pooled_feat) if self.training and not self.class_agnostic: # select the corresponding columns according to roi labels bbox_pred_view = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4) bbox_pred_select = torch.gather( bbox_pred_view, 1, rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4)) bbox_pred = bbox_pred_select.squeeze(1) # compute object classification probability cls_score = self.RCNN_cls_score(pooled_feat) cls_prob = F.softmax(cls_score, dim=1) RCNN_loss_cls = 0 RCNN_loss_bbox = 0 if self.training: # classification loss RCNN_loss_cls = F.cross_entropy(cls_score, rois_label) # bounding box regression L1 loss RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws) cls_prob = cls_prob.view(batch_size, rois.size(1), -1) bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1) return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label
def _viterbi_decode(self, feats, mask=None): """ Args: feats: size=(batch_size, seq_len, self.target_size+2) mask: size=(batch_size, seq_len) Returns: decode_idx: (batch_size, seq_len), viterbi decode result path_score: size=(batch_size, 1), score of each sentence """ batch_size = feats.size(0) seq_len = feats.size(1) tag_size = feats.size(-1) # print(batch_size, seq_len, tag_size) length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long() mask = mask.transpose(1, 0).contiguous() ins_num = seq_len * batch_size feats = feats.transpose(1, 0).contiguous().view( ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) scores = feats + self.transitions.view( 1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) scores = scores.view(seq_len, batch_size, tag_size, tag_size) seq_iter = enumerate(scores) # record the position of the best score back_points = list() partition_history = list() mask = (1 - mask.long()).byte() try: _, inivalues = seq_iter.__next__() except: _, inivalues = seq_iter.next() partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1) partition_history.append(partition) for idx, cur_values in seq_iter: cur_values = cur_values + partition.contiguous().view( batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) partition, cur_bp = torch.max(cur_values, 1) partition_history.append(partition.unsqueeze(-1)) cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0) back_points.append(cur_bp) partition_history = torch.cat(partition_history).view( seq_len, batch_size, -1).transpose(1, 0).contiguous() last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1 last_partition = torch.gather( partition_history, 1, last_position).view(batch_size, tag_size, 1) last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size).to(device) _, last_bp = torch.max(last_values, 1) pad_zero = Variable(torch.zeros(batch_size, tag_size)).long() pad_zero = pad_zero.to(device) back_points.append(pad_zero) back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size) pointer = last_bp[:, self.END_TAG_IDX] insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size) back_points = back_points.transpose(1, 0).contiguous() back_points.scatter_(1, last_position, insert_last) back_points = back_points.transpose(1, 0).contiguous() decode_idx = Variable(torch.LongTensor(seq_len, batch_size)) decode_idx = decode_idx.to(device) decode_idx[-1] = pointer.data for idx in range(len(back_points)-2, -1, -1): pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) decode_idx[idx] = pointer.view(-1).data path_score = None decode_idx = decode_idx.transpose(1, 0) return path_score, decode_idx
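# Toy backtrace with the same gather pattern as the decode loop above:
# back_points[t][b, j] holds the best previous tag given tag j at step t+1,
# so repeated gathers walk the best path backwards through the sequence.
import torch

seq_len, batch_size, tag_size = 4, 2, 3
back_points = torch.randint(0, tag_size, (seq_len, batch_size, tag_size))
pointer = torch.randint(0, tag_size, (batch_size,))  # best tag at the last step
decode_idx = torch.empty(seq_len, batch_size, dtype=torch.long)
decode_idx[-1] = pointer
for idx in range(seq_len - 2, -1, -1):
    pointer = torch.gather(back_points[idx], 1,
                           pointer.view(batch_size, 1)).view(-1)
    decode_idx[idx] = pointer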
def forward(self, input: torch.Tensor, padding=None): pretrained_indices = torch.gather(self.vocab_to_pretrained, dim=0, index=input) rnn_output = self.model(pretrained_indices) return EmbeddingOutput(all_layers=[rnn_output], last_layer=rnn_output)
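# Minimal sketch of the vocabulary remapping above (hypothetical mapping):
# gather along dim=0 maps each local token id to its id in the pretrained
# embedding table.
import torch

vocab_to_pretrained = torch.tensor([7, 3, 0, 42, 5])
input_ids = torch.tensor([1, 4, 2])
pretrained_indices = torch.gather(vocab_to_pretrained, dim=0, index=input_ids)
# tensor([3, 5, 0]); for a 1-D table this equals vocab_to_pretrained[input_ids]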
def _generate(self, model, sample, prefix_tokens=None, bos_token=None, **kwargs): if not self.retain_dropout: model.eval() # model.forward normally channels prev_output_tokens into the decoder # separately, but SequenceGenerator directly calls model.encoder encoder_input = { k: v for k, v in sample['net_input'].items() if k != 'prev_output_tokens' } src_tokens = encoder_input['src_tokens'] if src_tokens.dim() > 2: src_lengths = encoder_input['src_lengths'] else: src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) input_size = src_tokens.size() # batch dimension goes first followed by source lengths bsz = input_size[0] src_len = input_size[1] beam_size = self.beam_size if self.match_source_len: max_len = src_lengths.max().item() else: max_len = min( int(self.max_len_a * src_len + self.max_len_b), # exclude the EOS marker model.max_decoder_positions() - 1, ) assert self.min_len <= max_len, 'min_len cannot be larger than max_len, please adjust these!' # compute the encoder output for each beam encoder_outs = model.forward_encoder(encoder_input) new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) new_order = new_order.to(src_tokens.device).long() encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) # initialize buffers scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0) scores_buf = scores.clone() tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad) tokens_buf = tokens.clone() tokens[:, 0] = self.eos if bos_token is None else bos_token attn, attn_buf = None, None # The blacklist indicates candidates that should be ignored. # For example, suppose we're sampling and have already finalized 2/5 # samples. Then the blacklist would mark 2 positions as being ignored, # so that we only finalize the remaining 3 samples. blacklist = src_tokens.new_zeros(bsz, beam_size).eq( -1) # forward and backward-compatible False mask # list of completed sentences finalized = [[] for i in range(bsz)] finished = [False for i in range(bsz)] num_remaining_sent = bsz # number of candidate hypos per step cand_size = 2 * beam_size # 2 x beam size in case half are EOS # offset arrays for converting between different indexing schemes bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens) cand_offsets = torch.arange(0, cand_size).type_as(tokens) # helper function for allocating buffers on the fly buffers = {} def buffer(name, type_of=tokens): # noqa if name not in buffers: buffers[name] = type_of.new() return buffers[name] def is_finished(sent, step, unfin_idx): """ Check whether we've finished generation for a given sentence, by comparing the worst score among finalized hypotheses to the best possible score among unfinalized hypotheses. """ assert len(finalized[sent]) <= beam_size if len(finalized[sent]) == beam_size or step == max_len: return True return False def finalize_hypos(step, bbsz_idx, eos_scores): """ Finalize the given hypotheses at this step, while keeping the total number of finalized hypotheses per sentence <= beam_size. Note: the input must be in the desired finalization order, so that hypotheses that appear earlier in the input are preferred to those that appear later. 
Args: step: current time step bbsz_idx: A vector of indices in the range [0, bsz*beam_size), indicating which hypotheses to finalize eos_scores: A vector of the same size as bbsz_idx containing scores for each hypothesis """ assert bbsz_idx.numel() == eos_scores.numel() # clone relevant token and attention tensors tokens_clone = tokens.index_select(0, bbsz_idx) tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS assert not tokens_clone.eq(self.eos).any() tokens_clone[:, step] = self.eos attn_clone = attn.index_select( 0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None # compute scores per token position pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1] pos_scores[:, step] = eos_scores # convert from cumulative to per-position scores pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] # normalize sentence-level scores if self.normalize_scores: eos_scores /= (step + 1)**self.len_penalty cum_unfin = [] prev = 0 for f in finished: if f: prev += 1 else: cum_unfin.append(prev) sents_seen = set() for i, (idx, score) in enumerate( zip(bbsz_idx.tolist(), eos_scores.tolist())): unfin_idx = idx // beam_size sent = unfin_idx + cum_unfin[unfin_idx] sents_seen.add((sent, unfin_idx)) if self.match_source_len and step > src_lengths[unfin_idx]: score = -math.inf def get_hypo(): if attn_clone is not None: # remove padding tokens from attn scores hypo_attn = attn_clone[i] else: hypo_attn = None return { 'tokens': tokens_clone[i], 'score': score, 'attention': hypo_attn, # src_len x tgt_len 'alignment': None, 'positional_scores': pos_scores[i], } if len(finalized[sent]) < beam_size: finalized[sent].append(get_hypo()) newly_finished = [] for sent, unfin_idx in sents_seen: # check termination conditions for this sentence if not finished[sent] and is_finished(sent, step, unfin_idx): finished[sent] = True newly_finished.append(unfin_idx) return newly_finished reorder_state = None batch_idxs = None for step in range(max_len + 1): # one extra step for EOS marker # reorder decoder internal states based on the prev choice of beams if reorder_state is not None: if batch_idxs is not None: # update beam indices to take into account removed sentences corr = batch_idxs - torch.arange( batch_idxs.numel()).type_as(batch_idxs) reorder_state.view(-1, beam_size).add_( corr.unsqueeze(-1) * beam_size) model.reorder_incremental_state(reorder_state) encoder_outs = model.reorder_encoder_out( encoder_outs, reorder_state) lprobs, avg_attn_scores = model.forward_decoder( tokens[:, :step + 1], encoder_outs, temperature=self.temperature, ) lprobs[lprobs != lprobs] = -math.inf lprobs[:, self.pad] = -math.inf # never select pad lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty # handle max length constraint if step >= max_len: lprobs[:, :self.eos] = -math.inf lprobs[:, self.eos + 1:] = -math.inf elif self.eos_factor is not None: # only consider EOS if its score is no less than a specified # factor of the best candidate score disallow_eos_mask = lprobs[:, self. 
eos] < self.eos_factor * lprobs.max( dim=1)[0] lprobs[disallow_eos_mask, self.eos] = -math.inf # handle prefix tokens (possibly with different lengths) if prefix_tokens is not None and step < prefix_tokens.size( 1) and step < max_len: prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat( 1, beam_size).view(-1) prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) prefix_mask = prefix_toks.ne(self.pad) lprobs[prefix_mask] = -math.inf lprobs[prefix_mask] = lprobs[prefix_mask].scatter_( -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]) # if prefix includes eos, then we should make sure tokens and # scores are the same across all beams eos_mask = prefix_toks.eq(self.eos) if eos_mask.any(): # validate that the first beam matches the prefix first_beam = tokens[eos_mask].view( -1, beam_size, tokens.size(-1))[:, 0, 1:step + 1] eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] assert (first_beam == target_prefix).all() def replicate_first_beam(tensor, mask): tensor = tensor.view(-1, beam_size, tensor.size(-1)) tensor[mask] = tensor[mask][:, :1, :] return tensor.view(-1, tensor.size(-1)) # copy tokens, scores and lprobs from the first beam to all beams tokens = replicate_first_beam(tokens, eos_mask_batch_dim) scores = replicate_first_beam(scores, eos_mask_batch_dim) lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim) elif step < self.min_len: # minimum length constraint (does not apply if using prefix_tokens) lprobs[:, self.eos] = -math.inf if self.no_repeat_ngram_size > 0: # for each beam and batch sentence, generate a list of previous ngrams gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)] for bbsz_idx in range(bsz * beam_size): gen_tokens = tokens[bbsz_idx].tolist() for ngram in zip(*[ gen_tokens[i:] for i in range(self.no_repeat_ngram_size) ]): gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \ gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]] # Record attention scores if type(avg_attn_scores) is list: avg_attn_scores = avg_attn_scores[0] if avg_attn_scores is not None: if attn is None: if src_tokens.dim() > 2: attn = scores.new( bsz * beam_size, encoder_outs[0]["encoder_out"][0].size(0), max_len + 2, ) else: attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2) attn_buf = attn.clone() attn[:, :, step + 1].copy_(avg_attn_scores) scores = scores.type_as(lprobs) scores_buf = scores_buf.type_as(lprobs) eos_bbsz_idx = buffer('eos_bbsz_idx') eos_scores = buffer('eos_scores', type_of=scores) self.search.set_src_lengths(src_lengths) if self.no_repeat_ngram_size > 0: def calculate_banned_tokens(bbsz_idx): # before decoding the next token, prevent decoding of ngrams that have already appeared ngram_index = tuple( tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist()) return gen_ngrams[bbsz_idx].get(ngram_index, []) if step + 2 - self.no_repeat_ngram_size >= 0: # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet banned_tokens = [ calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size) ] else: banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)] for bbsz_idx in range(bsz * beam_size): lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf cand_scores, cand_indices, cand_beams = self.search.step( step, lprobs.view(bsz, -1, self.vocab_size), scores.view(bsz, beam_size, -1)[:, :, :step], ) # cand_bbsz_idx contains beam indices for the top candidate # hypotheses, with a range of values: [0, bsz*beam_size), # and 
dimensions: [bsz, cand_size] cand_bbsz_idx = cand_beams.add(bbsz_offsets) # finalize hypotheses that end in eos, except for blacklisted ones # or candidates with a score of -inf eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) eos_mask[:, :beam_size][blacklist] = 0 # only consider eos when it's among the top beam_size indices torch.masked_select( cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size], out=eos_bbsz_idx, ) finalized_sents = set() if eos_bbsz_idx.numel() > 0: torch.masked_select( cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size], out=eos_scores, ) finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores) num_remaining_sent -= len(finalized_sents) assert num_remaining_sent >= 0 if num_remaining_sent == 0: break assert step < max_len if len(finalized_sents) > 0: new_bsz = bsz - len(finalized_sents) # construct batch_idxs which holds indices of batches to keep for the next pass batch_mask = cand_indices.new_ones(bsz) batch_mask[cand_indices.new(finalized_sents)] = 0 batch_idxs = batch_mask.nonzero().squeeze(-1) eos_mask = eos_mask[batch_idxs] cand_beams = cand_beams[batch_idxs] bbsz_offsets.resize_(new_bsz, 1) cand_bbsz_idx = cand_beams.add(bbsz_offsets) cand_scores = cand_scores[batch_idxs] cand_indices = cand_indices[batch_idxs] if prefix_tokens is not None: prefix_tokens = prefix_tokens[batch_idxs] src_lengths = src_lengths[batch_idxs] blacklist = blacklist[batch_idxs] scores = scores.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, -1) scores_buf.resize_as_(scores) tokens = tokens.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, -1) tokens_buf.resize_as_(tokens) if attn is not None: attn = attn.view(bsz, -1)[batch_idxs].view( new_bsz * beam_size, attn.size(1), -1) attn_buf.resize_as_(attn) bsz = new_bsz else: batch_idxs = None # Set active_mask so that values > cand_size indicate eos or # blacklisted hypos and values < cand_size indicate candidate # active hypos. After this, the min values per row are the top # candidate active hypos. 
active_mask = buffer('active_mask') eos_mask[:, :beam_size] |= blacklist torch.add( eos_mask.type_as(cand_offsets) * cand_size, cand_offsets[:eos_mask.size(1)], out=active_mask, ) # get the top beam_size active hypotheses, which are just the hypos # with the smallest values in active_mask active_hypos, new_blacklist = buffer('active_hypos'), buffer( 'new_blacklist') torch.topk(active_mask, k=beam_size, dim=1, largest=False, out=(new_blacklist, active_hypos)) # update blacklist to ignore any finalized hypos blacklist = new_blacklist.ge(cand_size)[:, :beam_size] assert (~blacklist).any(dim=1).all() active_bbsz_idx = buffer('active_bbsz_idx') torch.gather( cand_bbsz_idx, dim=1, index=active_hypos, out=active_bbsz_idx, ) active_scores = torch.gather( cand_scores, dim=1, index=active_hypos, out=scores[:, step].view(bsz, beam_size), ) active_bbsz_idx = active_bbsz_idx.view(-1) active_scores = active_scores.view(-1) # copy tokens and scores for active hypotheses torch.index_select( tokens[:, :step + 1], dim=0, index=active_bbsz_idx, out=tokens_buf[:, :step + 1], ) torch.gather( cand_indices, dim=1, index=active_hypos, out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1], ) if step > 0: torch.index_select( scores[:, :step], dim=0, index=active_bbsz_idx, out=scores_buf[:, :step], ) torch.gather( cand_scores, dim=1, index=active_hypos, out=scores_buf.view(bsz, beam_size, -1)[:, :, step], ) # copy attention for active hypotheses if attn is not None: torch.index_select( attn[:, :, :step + 2], dim=0, index=active_bbsz_idx, out=attn_buf[:, :, :step + 2], ) # swap buffers tokens, tokens_buf = tokens_buf, tokens scores, scores_buf = scores_buf, scores if attn is not None: attn, attn_buf = attn_buf, attn # reorder incremental state in decoder reorder_state = active_bbsz_idx # sort by score descending for sent in range(len(finalized)): finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True) return finalized
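# Sketch of the active-hypothesis bookkeeping above, without the deprecated
# out= buffers: per sentence, topk(largest=False) picks the beam_size
# candidates with the smallest active_mask values, and gather pulls their
# beam indices and scores along. Sizes are toy values.
import torch

bsz, beam_size = 2, 3
cand_size = 2 * beam_size
active_mask = torch.randint(0, 2 * cand_size, (bsz, cand_size))
cand_bbsz_idx = torch.randint(0, bsz * beam_size, (bsz, cand_size))
cand_scores = torch.randn(bsz, cand_size)
new_blacklist, active_hypos = torch.topk(active_mask, k=beam_size, dim=1,
                                         largest=False)
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos).view(-1)
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos).view(-1)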
def decode_a(self, batch_size, u_input, u_hiddens, u_input_1hot, u_last_hidden, a_input, db_vec, filling_vec, sample_type, decoder_type, qa_samples=None, qa_hiddens=None, m_input=None, m_hiddens=None, m_input_1hot=None): return_gmb = True if 'gumbel' in sample_type else False a_prob, a_samples, gmb_samples = [], [], [] log_pa = 0 for si, sn in enumerate(self.reader.act_order): last_hidden = u_last_hidden[:-1] a_eos_idx = self.vocab.encode('<eos_%s>'%sn) a_sos_idx = self.vocab.encode('<go_%s>'%sn) at = cuda_(torch.ones(batch_size, 1)*a_sos_idx).long() emb_at = self.embedding(at) selc_read_m = cuda_(torch.zeros(batch_size, 1, self.hidden_size)) vec_input = torch.cat([db_vec, filling_vec], dim=1) if sn == 'av': vec_input = cuda_(torch.zeros(vec_input.size())) prev_at = None for t in range(self.a_length): if decoder_type == 'pa': prob, last_hidden, gru_out = self.pa_decoder[sn]( u_hiddens, emb_at, vec_input, last_hidden) else: prob, last_hidden, gru_out, selc_read_m, gmb_samp = \ self.qa_decoder[sn](u_hiddens, m_input, m_input_1hot, m_hiddens, emb_at, vec_input, last_hidden, selc_read_m=selc_read_m, temp=self.gumbel_temp, return_gmb=return_gmb) if sample_type == 'supervised': at = a_input[sn][:, t] elif sample_type == 'top1': at = torch.topk(prob, 1)[1] elif sample_type == 'topk': topk_probs, topk_words = torch.topk(prob.squeeze(1), cfg.topk_num) widx = torch.multinomial(topk_probs, 1, replacement=True) at = torch.gather(topk_words, 1, widx) #[B] elif sample_type == 'posterior': at = qa_samples[:, si * self.a_length + t] elif 'gumbel' in sample_type: at = torch.argmax(gmb_samp, dim=1) #[B] emb_at = torch.matmul(gmb_samp, self.embedding.weight).unsqueeze(1) # [B, 1, H] at, prev_at, gmb_samp = self.mask_samples(at, prev_at, batch_size, a_eos_idx, gmb_samp, True) gmb_samples.append(gmb_samp) if 'gumbel' not in sample_type: emb_at = self.embedding(at.view(-1, 1)) prob_at = torch.gather(prob, 1, at.view(-1, 1)).squeeze(1) #[B, 1] log_prob_at = torch.log(prob_at) at, prev_at, log_prob_at = self.mask_samples(at, prev_at, batch_size, a_eos_idx, log_prob_at) log_pa += log_prob_at a_samples.append(at.view(-1)) a_prob.append(prob) a_prob = torch.stack(a_prob, dim=1) a_samples = torch.stack(a_samples, dim=1) # [B,Ta] if sample_type == 'posterior': a_samples, a_hiddens = qa_samples, qa_hiddens elif 'gumbel' not in sample_type: a_hiddens, a_last_hidden = self.a_encoder(a_samples, input_type='index') else: a_gumbel = torch.stack(gmb_samples, dim=1) # [B,Ta, V] a_gumbel = torch.matmul(a_gumbel, self.embedding.weight) # [B,Ta, E] a_hiddens, a_last_hidden = self.a_encoder(a_gumbel, input_type='embedding') return a_prob, a_samples, a_hiddens, log_pa
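# Minimal sketch of the 'topk' sampling branch above: restrict to the k most
# probable tokens, sample among them, then map the sampled positions back to
# vocabulary ids with gather. k is a stand-in for cfg.topk_num.
import torch

B, V, k = 4, 100, 5
prob = torch.softmax(torch.randn(B, V), dim=1)
topk_probs, topk_words = torch.topk(prob, k)               # both (B, k)
widx = torch.multinomial(topk_probs, 1, replacement=True)  # (B, 1) top-k slot
at = torch.gather(topk_words, 1, widx)                     # (B, 1) vocabulary ids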
def _translateBatch(self, batch, prefix_tokens = None): # Batch size is in different location depending on data. # prefix_tokens = None beam_size = self.opt.beam_size bsz = batch_size = batch.size max_len = self.opt.max_sent_length gold_scores = batch.get('source').data.new(batch_size).float().zero_() gold_words = 0 allgold_scores = [] if batch.has_target: # Use the first model to decode model_ = self.models[0] gold_words, gold_scores, allgold_scores = model_.decode(batch) # (3) Start decoding # initialize buffers src = batch.get('source') scores = src.new(bsz * beam_size, max_len + 1).float().fill_(0) scores_buf = scores.clone() tokens = src.new(bsz * beam_size, max_len + 2).long().fill_(self.pad) tokens_buf = tokens.clone() tokens[:, 0].fill_(self.bos) # first token is bos attn, attn_buf = None, None nonpad_idxs = None src_tokens = src.transpose(0, 1) # batch x time src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) blacklist = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask # list of completed sentences finalized = [[] for i in range(bsz)] finished = [False for i in range(bsz)] num_remaining_sent = bsz # number of candidate hypos per step cand_size = 2 * beam_size # 2 x beam size in case half are EOS # offset arrays for converting between different indexing schemes bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens) cand_offsets = torch.arange(0, cand_size).type_as(tokens) # helper function for allocating buffers on the fly buffers = {} def buffer(name, type_of=tokens): # noqa if name not in buffers: buffers[name] = type_of.new() return buffers[name] def is_finished(sent, step, unfinalized_scores=None): """ Check whether we've finished generation for a given sentence, by comparing the worst score among finalized hypotheses to the best possible score among unfinalized hypotheses. """ assert len(finalized[sent]) <= beam_size if len(finalized[sent]) == beam_size: return True return False def finalize_hypos(step, bbsz_idx, eos_scores): """ Finalize the given hypotheses at this step, while keeping the total number of finalized hypotheses per sentence <= beam_size. Note: the input must be in the desired finalization order, so that hypotheses that appear earlier in the input are preferred to those that appear later. 
Args: step: current time step bbsz_idx: A vector of indices in the range [0, bsz*beam_size), indicating which hypotheses to finalize eos_scores: A vector of the same size as bbsz_idx containing scores for each hypothesis """ assert bbsz_idx.numel() == eos_scores.numel() # clone relevant token and attention tensors tokens_clone = tokens.index_select(0, bbsz_idx) tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS assert not tokens_clone.eq(self.eos).any() tokens_clone[:, step] = self.eos attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None # compute scores per token position pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1] pos_scores[:, step] = eos_scores # convert from cumulative to per-position scores pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] # normalize sentence-level scores if self.normalize_scores: eos_scores /= (step + 1) ** self.len_penalty cum_unfin = [] prev = 0 for f in finished: if f: prev += 1 else: cum_unfin.append(prev) sents_seen = set() for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())): unfin_idx = idx // beam_size sent = unfin_idx + cum_unfin[unfin_idx] sents_seen.add((sent, unfin_idx)) # if self.match_source_len and step > src_lengths[unfin_idx]: # score = -math.inf def get_hypo(): if attn_clone is not None: # remove padding tokens from attn scores hypo_attn = attn_clone[i] else: hypo_attn = None # print(hypo_attn.shape) # print(tokens_clone[i]) return { 'tokens': tokens_clone[i], 'score': score, 'attention': hypo_attn, # src_len x tgt_len 'alignment': None, 'positional_scores': pos_scores[i], } if len(finalized[sent]) < beam_size: finalized[sent].append(get_hypo()) newly_finished = [] for sent, unfin_idx in sents_seen: # check termination conditions for this sentence if not finished[sent] and is_finished(sent, step, unfin_idx): finished[sent] = True newly_finished.append(unfin_idx) return newly_finished reorder_state = None batch_idxs = None # initialize the decoder state, including: # - expanding the context over the batch dimension len_src x (B*beam) x H # - expanding the mask over the batch dimension (B*beam) x len_src decoder_states = dict() for i in range(self.n_models): decoder_states[i] = self.models[i].create_decoder_state(batch, beam_size, type=2, buffering=self.buffering) len_context = decoder_states[i].context.size(0) if self.dynamic_max_len: src_len = src.size(0) max_len = math.ceil(int(src_len) * self.dynamic_max_len_scale) # Start decoding for step in range(max_len + 1): # one extra step for EOS marker # reorder decoder internal states based on the prev choice of beams if reorder_state is not None: if batch_idxs is not None: # update beam indices to take into account removed sentences corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs) reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size) for i, model in enumerate(self.models): decoder_states[i]._reorder_incremental_state(reorder_state) decode_input = tokens[:, :step + 1] lprobs, avg_attn_scores = self._decode(decode_input, decoder_states) # avg_attn_scores = None # lprobs[:, self.pad] = -math.inf # never select pad # handle min and max length constraints if step >= max_len: lprobs[:, :self.eos] = -math.inf lprobs[:, self.eos + 1:] = -math.inf elif step < self.min_len: lprobs[:, self.eos] = -math.inf # handle prefix tokens (possibly with different lengths) # prefix_tokens = torch.tensor([[798, 1354]]).type_as(tokens) # prefix_tokens = [[1000, 1354, 
2443, 1475, 1010, 242, 127, 1191, 902, 1808, 1589, 26]] if prefix_tokens is not None: prefix_tokens = torch.tensor(prefix_tokens).type_as(tokens) if step < prefix_tokens.size(1) and step < max_len: prefix_tokens = torch.tensor(prefix_tokens).type_as(tokens) prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) prefix_mask = prefix_toks.ne(self.pad) lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) lprobs[prefix_mask] = lprobs[prefix_mask].scatter( -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] ) # if prefix includes eos, then we should make sure tokens and # scores are the same across all beams eos_mask = prefix_toks.eq(self.eos) if eos_mask.any(): # validate that the first beam matches the prefix first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1] eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] assert (first_beam == target_prefix).all() def replicate_first_beam(tensor, mask): tensor = tensor.view(-1, beam_size, tensor.size(-1)) tensor[mask] = tensor[mask][:, :1, :] return tensor.view(-1, tensor.size(-1)) # copy tokens, scores and lprobs from the first beam to all beams tokens = replicate_first_beam(tokens, eos_mask_batch_dim) scores = replicate_first_beam(scores, eos_mask_batch_dim) lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim) if self.no_repeat_ngram_size > 0: # for each beam and batch sentence, generate a list of previous ngrams gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)] for bbsz_idx in range(bsz * beam_size): gen_tokens = tokens[bbsz_idx].tolist() for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]): gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \ gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]] # Record attention scores if avg_attn_scores is not None: if attn is None: attn = scores.new(bsz * beam_size, len_context , max_len + 2) attn_buf = attn.clone() attn[:, :, step + 1].copy_(avg_attn_scores) scores = scores.type_as(lprobs) scores_buf = scores_buf.type_as(lprobs) eos_bbsz_idx = buffer('eos_bbsz_idx') eos_scores = buffer('eos_scores', type_of=scores) if self.no_repeat_ngram_size > 0: def calculate_banned_tokens(bbsz_idx): # before decoding the next token, prevent decoding of ngrams that have already appeared ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist()) return gen_ngrams[bbsz_idx].get(ngram_index, []) if step + 2 - self.no_repeat_ngram_size >= 0: # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)] else: banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)] for bbsz_idx in range(bsz * beam_size): lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf # print(lprobs.shape) cand_scores, cand_indices, cand_beams = self.search.step( step, lprobs.view(bsz, -1, self.vocab_size), scores.view(bsz, beam_size, -1)[:, :, :step], ) # cand_bbsz_idx contains beam indices for the top candidate # hypotheses, with a range of values: [0, bsz*beam_size), # and dimensions: [bsz, cand_size] cand_bbsz_idx = cand_beams.add(bbsz_offsets) # finalize hypotheses that end in eos (except for blacklisted ones) eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) eos_mask[:, :beam_size][blacklist] = 0 # only consider eos when it's among the top beam_size indices 
torch.masked_select( cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size], out=eos_bbsz_idx, ) finalized_sents = set() if eos_bbsz_idx.numel() > 0: torch.masked_select( cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size], out=eos_scores, ) finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores) num_remaining_sent -= len(finalized_sents) assert num_remaining_sent >= 0 if num_remaining_sent == 0: break assert step < max_len if len(finalized_sents) > 0: new_bsz = bsz - len(finalized_sents) # construct batch_idxs which holds indices of batches to keep for the next pass batch_mask = cand_indices.new_ones(bsz) batch_mask[cand_indices.new(finalized_sents)] = 0 batch_idxs = batch_mask.nonzero(as_tuple=False).squeeze(-1) eos_mask = eos_mask[batch_idxs] cand_beams = cand_beams[batch_idxs] bbsz_offsets.resize_(new_bsz, 1) cand_bbsz_idx = cand_beams.add(bbsz_offsets) cand_scores = cand_scores[batch_idxs] cand_indices = cand_indices[batch_idxs] if prefix_tokens is not None: prefix_tokens = prefix_tokens[batch_idxs] src_lengths = src_lengths[batch_idxs] blacklist = blacklist[batch_idxs] scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) scores_buf.resize_as_(scores) tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) tokens_buf.resize_as_(tokens) if attn is not None: attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1) attn_buf.resize_as_(attn) bsz = new_bsz else: batch_idxs = None # Set active_mask so that values > cand_size indicate eos or # blacklisted hypos and values < cand_size indicate candidate # active hypos. After this, the min values per row are the top # candidate active hypos. active_mask = buffer('active_mask') eos_mask[:, :beam_size] |= blacklist torch.add( eos_mask.type_as(cand_offsets) * cand_size, cand_offsets[:eos_mask.size(1)], out=active_mask, ) # get the top beam_size active hypotheses, which are just the hypos # with the smallest values in active_mask active_hypos, new_blacklist = buffer('active_hypos'), buffer('new_blacklist') torch.topk( active_mask, k=beam_size, dim=1, largest=False, out=(new_blacklist, active_hypos) ) # update blacklist to ignore any finalized hypos blacklist = new_blacklist.ge(cand_size)[:, :beam_size] assert (~blacklist).any(dim=1).all() active_bbsz_idx = buffer('active_bbsz_idx') torch.gather( cand_bbsz_idx, dim=1, index=active_hypos, out=active_bbsz_idx, ) active_scores = torch.gather( cand_scores, dim=1, index=active_hypos, out=scores[:, step].view(bsz, beam_size), ) active_bbsz_idx = active_bbsz_idx.view(-1) active_scores = active_scores.view(-1) # copy tokens and scores for active hypotheses torch.index_select( tokens[:, :step + 1], dim=0, index=active_bbsz_idx, out=tokens_buf[:, :step + 1], ) torch.gather( cand_indices, dim=1, index=active_hypos, out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1], ) if step > 0: torch.index_select( scores[:, :step], dim=0, index=active_bbsz_idx, out=scores_buf[:, :step], ) torch.gather( cand_scores, dim=1, index=active_hypos, out=scores_buf.view(bsz, beam_size, -1)[:, :, step], ) # copy attention for active hypotheses if attn is not None: torch.index_select( attn[:, :, :step + 2], dim=0, index=active_bbsz_idx, out=attn_buf[:, :, :step + 2], ) # swap buffers tokens, tokens_buf = tokens_buf, tokens scores, scores_buf = scores_buf, scores if attn is not None: attn, attn_buf = attn_buf, attn # reorder incremental state in decoder reorder_state = active_bbsz_idx # sort by score descending for sent in range(len(finalized)): 
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True) return finalized, gold_scores, gold_words, allgold_scores
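# Sketch of the prefix forcing used by both generators above: keep only the
# forced token's log-prob in each beam row, masking everything else to -inf
# via gather + scatter. Sizes are toy values.
import math
import torch

rows, V = 4, 10
lprobs = torch.log_softmax(torch.randn(rows, V), dim=-1)
prefix_toks = torch.tensor([3, 3, 7, 7])  # forced token per beam row
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
lprobs = torch.full_like(lprobs, -math.inf).scatter(
    -1, prefix_toks.unsqueeze(-1), prefix_lprobs)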
def forward(self, x, mask, secret=None, globalx=None): """ x: agents, episode, episode_step, neighbors, agent_type m: agents, episode, episode_step, neighbors, agent_type s: agents, episode, episode_step e: 1, episode, episode_step """ x_ci_oh = onehot(x[:,:,:,:,0], mask, self.input_size, self.device) if len(self.global_input_idx) != 0: x_ci_oh[:,:,:,:,-len(self.global_input_idx):] = globalx[:,:,:,self.global_input_idx].unsqueeze(-2) if self.add_catch: catch = (x[:,:,:,:,0] == x[:,:,:,:,1]).type(th.float) x_ci_oh[:,:,:,:,-len(self.global_input_idx)-1] = catch x_ci_oh[mask] = 0 if self.control_type == 'binary': control = binary(secret, self.control_size).type(th.float) if len(self.global_control_idx) != 0: control[:,:,:,-len(self.global_control_idx):] = globalx[:,:,:,self.global_control_idx] data = x_ci_oh, mask, control elif len(self.global_control_idx) != 0: control = globalx[:,:,:,self.global_control_idx] data = x_ci_oh, mask, control else: data = x_ci_oh, mask # # random test # random_agent = th.randint(high=self.n_agents, size=(1,)).item() # random_batch = th.randint(high=x.shape[1], size=(1,)).item() # random_step = th.randint(high=x.shape[2], size=(1,)).item() # random_neighbor = th.randint(high=(x[random_agent,random_batch,random_step,:,0] != -1).sum(), size=(1,)).item() # # x_ci_oh test # color = x[random_agent,random_batch,random_step,random_neighbor,0] # vector = x_ci_oh[random_agent,random_batch,random_step,random_neighbor,:] # assert vector[color] == 1, f"{color} {vector}" # assert vector[:self.n_actions].sum() == 1 # if len(self.global_input_idx) != 0: # random_g_idx = th.randint(high=len(self.global_input_idx), size=(1,)).item() # offset = self.input_size - len(self.global_input_idx) # global_vec = globalx[random_agent,random_batch,random_step,self.global_input_idx] # assert vector[offset+random_g_idx] == global_vec[random_g_idx] # if self.add_catch: # is_catch = (color == x[random_agent,random_batch,random_step,random_neighbor,1]) # idx = - len(self.global_input_idx) - 1 # assert vector[idx] == is_catch # # end x_ci_oh test # # test control # if (secret is not None) or (len(self.global_control_idx) != 0): # secret_size = self.control_size - len(self.global_control_idx) # control_vec = control[random_agent,random_batch,random_step] # if secret is not None: # binary_secret = np.unpackbits(secret[random_agent, random_batch, random_step].numpy().astype(np.uint8))[-secret_size:][::-1].copy() # binary_secret = th.tensor(binary_secret, dtype=th.float) # assert (control_vec[:secret_size] == binary_secret).all() # if len(self.global_control_idx) != 0: # global_vec = globalx[random_agent,random_batch,random_step,self.global_control_idx] # assert (control_vec[-len(self.global_control_idx):] == global_vec).all() # # end test control if self.multi_type == 'shared_weights': data = ( d.reshape(1, d.shape[0]*d.shape[1], *d.shape[2:]) for d in data ) q = [ model(*d) for model, *d in zip(self.models, *data) ] q = th.stack(q) if self.multi_type == 'shared_weights': q = q.reshape(*x.shape[:3], -1) if self.control_type == 'permute': permutations = self.permuations[secret] q = th.gather(q, -1, permutations) return q
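# Sketch of the 'permute' control branch above (hypothetical sizes): each
# secret selects a precomputed permutation of the action dimension, which
# gather then applies to the q-values along the last dim.
import torch as th

n_secrets, n_actions, n_agents = 3, 5, 4
permutations = th.stack([th.randperm(n_actions) for _ in range(n_secrets)])
q = th.randn(n_agents, n_actions)
secret = th.randint(0, n_secrets, (n_agents,))
q_perm = th.gather(q, -1, permutations[secret])  # (n_agents, n_actions)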
def forward(self, cls_heads, reg_heads, center_heads, batch_positions): with torch.no_grad(): device = cls_heads[0].device filter_scores,filter_score_classes,filter_reg_heads,filter_batch_positions=[],[],[],[] for per_level_cls_head, per_level_reg_head, per_level_center_head, per_level_position in zip( cls_heads, reg_heads, center_heads, batch_positions): per_level_cls_head = torch.sigmoid(per_level_cls_head) per_level_reg_head = torch.exp(per_level_reg_head) per_level_center_head = torch.sigmoid(per_level_center_head) per_level_cls_head = per_level_cls_head.view( per_level_cls_head.shape[0], -1, per_level_cls_head.shape[-1]) per_level_reg_head = per_level_reg_head.view( per_level_reg_head.shape[0], -1, per_level_reg_head.shape[-1]) per_level_center_head = per_level_center_head.view( per_level_center_head.shape[0], -1, per_level_center_head.shape[-1]) per_level_position = per_level_position.view( per_level_position.shape[0], -1, per_level_position.shape[-1]) scores, score_classes = torch.max(per_level_cls_head, dim=2) scores = torch.sqrt(scores * per_level_center_head.squeeze(-1)) if scores.shape[1] >= self.top_n: scores, indexes = torch.topk(scores, self.top_n, dim=1, largest=True, sorted=True) score_classes = torch.gather(score_classes, 1, indexes) per_level_reg_head = torch.gather( per_level_reg_head, 1, indexes.unsqueeze(-1).repeat(1, 1, 4)) per_level_position = torch.gather( per_level_position, 1, indexes.unsqueeze(-1).repeat(1, 1, 2)) filter_scores.append(scores) filter_score_classes.append(score_classes) filter_reg_heads.append(per_level_reg_head) filter_batch_positions.append(per_level_position) filter_scores = torch.cat(filter_scores, axis=1) filter_score_classes = torch.cat(filter_score_classes, axis=1) filter_reg_heads = torch.cat(filter_reg_heads, axis=1) filter_batch_positions = torch.cat(filter_batch_positions, axis=1) batch_scores, batch_classes, batch_pred_bboxes = [], [], [] for scores, score_classes, per_image_reg_preds, per_image_points_position in zip( filter_scores, filter_score_classes, filter_reg_heads, filter_batch_positions): pred_bboxes = self.snap_ltrb_reg_heads_to_x1_y1_x2_y2_bboxes( per_image_reg_preds, per_image_points_position) score_classes = score_classes[ scores > self.min_score_threshold].float() pred_bboxes = pred_bboxes[ scores > self.min_score_threshold].float() scores = scores[scores > self.min_score_threshold].float() one_image_scores = (-1) * torch.ones( (self.max_detection_num, ), device=device) one_image_classes = (-1) * torch.ones( (self.max_detection_num, ), device=device) one_image_pred_bboxes = (-1) * torch.ones( (self.max_detection_num, 4), device=device) if scores.shape[0] != 0: # Sort boxes sorted_scores, sorted_indexes = torch.sort(scores, descending=True) sorted_score_classes = score_classes[sorted_indexes] sorted_pred_bboxes = pred_bboxes[sorted_indexes] keep = nms(sorted_pred_bboxes, sorted_scores, self.nms_threshold) keep_scores = sorted_scores[keep] keep_classes = sorted_score_classes[keep] keep_pred_bboxes = sorted_pred_bboxes[keep] final_detection_num = min(self.max_detection_num, keep_scores.shape[0]) one_image_scores[0:final_detection_num] = keep_scores[ 0:final_detection_num] one_image_classes[0:final_detection_num] = keep_classes[ 0:final_detection_num] one_image_pred_bboxes[ 0:final_detection_num, :] = keep_pred_bboxes[ 0:final_detection_num, :] one_image_scores = one_image_scores.unsqueeze(0) one_image_classes = one_image_classes.unsqueeze(0) one_image_pred_bboxes = one_image_pred_bboxes.unsqueeze(0) 
batch_scores.append(one_image_scores) batch_classes.append(one_image_classes) batch_pred_bboxes.append(one_image_pred_bboxes) batch_scores = torch.cat(batch_scores, axis=0) batch_classes = torch.cat(batch_classes, axis=0) batch_pred_bboxes = torch.cat(batch_pred_bboxes, axis=0) # batch_scores shape:[batch_size,max_detection_num] # batch_classes shape:[batch_size,max_detection_num] # batch_pred_bboxes shape[batch_size,max_detection_num,4] return batch_scores, batch_classes, batch_pred_bboxes
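# Sketch of the per-level top-n filtering above: keep the top_n scores per
# image and gather the matching class ids, regression outputs and anchor
# positions with broadcasted indexes. Sizes are toy values.
import torch

B, A, top_n = 2, 8, 3
scores = torch.rand(B, A)
score_classes = torch.randint(0, 20, (B, A))
reg_head = torch.randn(B, A, 4)
positions = torch.randn(B, A, 2)
scores, indexes = torch.topk(scores, top_n, dim=1, largest=True, sorted=True)
score_classes = torch.gather(score_classes, 1, indexes)
reg_head = torch.gather(reg_head, 1, indexes.unsqueeze(-1).repeat(1, 1, 4))
positions = torch.gather(positions, 1, indexes.unsqueeze(-1).repeat(1, 1, 2))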
def decode_z_parallel(self, batch_size, u_input, u_hiddens, u_input_1hot, u_last_hidden, z_input, turn_states, sample_type, decoder_type, qz_samples=None, qz_hiddens=None, m_input=None, m_hiddens=None, m_input_1hot=None, mask_otlg=False): return_gmb = True if 'gumbel' in sample_type else False slot_num = len(self.reader.otlg.informable_slots) pv_z_pr = turn_states.get('pv_%s_pr'%decoder_type, None) pv_z_h = turn_states.get('pv_%s_h'%decoder_type, None) pv_z_id = turn_states.get('pv_%s_id'%decoder_type, None) z_prob, z_samples, gmb_samples = [], [], [] log_pz = 0 last_hidden = u_last_hidden[:-1].repeat(1, slot_num, 1) # [1, B*|slot|, H] u_input = u_input.repeat(slot_num ,1) u_input_1hot = u_input_1hot.repeat(slot_num, 1, 1) # u_input_1hot = get_one_hot_input(u_input, self.vocab_size) u_hiddens = u_hiddens.repeat(slot_num ,1, 1) if decoder_type != 'pz': m_input = m_input.repeat(slot_num ,1) m_input_1hot = m_input_1hot.repeat(slot_num, 1, 1) m_hiddens = m_hiddens.repeat(slot_num ,1, 1) # m_input_1hot = get_one_hot_input(m_input, self.vocab_size) if pv_z_pr is not None: # pv_z_pr = pv_z_pr.transpose(1,0).reshape(self.z_length, -1, cfg.vocab_size).transpose(1,0) # [B*|slot|, T, V] # pv_z_h = pv_z_h.transpose(1,0).reshape(self.z_length, -1, cfg.hidden_size).transpose(1,0) # [B*|slot|, T, H] # pv_z_id = pv_z_id.transpose(1,0).reshape(self.z_length, -1).transpose(1,0) # [B*|slot|, T] pv_z_prob, pv_z_hid, pv_z_idx = [], [], [] for si, sn in enumerate(self.reader.otlg.informable_slots): pv_z_prob.append(pv_z_pr[:, si*self.z_length : (si+1)*self.z_length]) pv_z_hid.append(pv_z_h[:, si*self.z_length : (si+1)*self.z_length]) pv_z_idx.append(pv_z_id[:, si*self.z_length : (si+1)*self.z_length]) pv_z_pr = torch.cat(pv_z_prob, dim=0) pv_z_h = torch.cat(pv_z_hid, dim=0) pv_z_id = torch.cat(pv_z_idx, dim=0) # print(pv_z_pr.size()) # print(pv_z_id) emb_zt, z_eos_idx = [], [] for si, sn in enumerate(self.reader.otlg.informable_slots): z_eos_idx.append(self.vocab.encode(self.z_eos_map[sn])) emb_zt.append(self.get_first_z_input(sn, batch_size, self.multi_domain)) emb_zt = torch.cat(emb_zt, dim=0) if z_input is not None: z_input_cat = [] for si, sn in enumerate(self.reader.otlg.informable_slots): z_input_cat.append(z_input[sn]) z_input_cat = torch.cat(z_input_cat, dim=0) zero_vec = cuda_(torch.zeros(batch_size * slot_num, 1, self.hidden_size)) selc_read_u = selc_read_m = selc_read_pv_z = zero_vec prev_zt = None for t in range(self.z_length): if decoder_type == 'pz': prob, last_hidden, gru_out, selc_read_u, selc_read_pv_z = \ self.pz_decoder(u_input, u_input_1hot, u_hiddens, pv_z_prob=pv_z_pr, pv_z_hidden=pv_z_h, pv_z_idx=pv_z_id, emb_zt=emb_zt, last_hidden=last_hidden, selc_read_u=selc_read_u, selc_read_pv_z=selc_read_pv_z) else: prob, last_hidden, gru_out, selc_read_u, selc_read_m, selc_read_pv_z, gmb_samp = \ self.qz_decoder(u_input, u_input_1hot, u_hiddens, m_input, m_input_1hot, m_hiddens, pv_z_prob=pv_z_pr, pv_z_hidden=pv_z_h, pv_z_idx=pv_z_id, emb_zt=emb_zt, last_hidden=last_hidden, selc_read_u=selc_read_u, selc_read_m=selc_read_m, selc_read_pv_z=selc_read_pv_z, temp=self.gumbel_temp, return_gmb=return_gmb) if mask_otlg: prob = self.mask_probs(prob, tokens_allow=self.reader.slot_value_mask[sn]) if sample_type == 'supervised': # zt = z_input[sn][:, t] zt = z_input_cat[:, t] elif sample_type == 'top1': zt = torch.topk(prob, 1)[1] elif sample_type == 'topk': topk_probs, topk_words = torch.topk(prob.squeeze(1), cfg.topk_num) widx = torch.multinomial(topk_probs, 1, replacement=True) zt = torch.gather(topk_words, 
1, widx) #[B] elif sample_type == 'posterior': zt = qz_samples[:, si * self.z_length + t] elif 'gumbel' in sample_type: zt = torch.argmax(gmb_samp, dim=1) #[B] emb_zt = torch.matmul(gmb_samp, self.embedding.weight).unsqueeze(1) # [B, 1, H] zt, prev_zt, gmb_samp = self.mask_samples(zt, prev_zt, batch_size, z_eos_idx, gmb_samp, True) gmb_samples.append(gmb_samp) if 'gumbel' not in sample_type: emb_zt = self.embedding(zt.view(-1, 1)) prob_zt = torch.gather(prob, 1, zt.view(-1, 1)).squeeze(1) #[B, 1] log_prob_zt = torch.log(prob_zt) zt, prev_zt, log_prob_zt = self.mask_samples(zt, prev_zt, batch_size, z_eos_idx, log_prob_zt) log_pz += log_prob_zt z_samples.append(zt.view(-1)) z_prob.append(prob) z_prob = torch.stack(z_prob, dim=1) # [B*|slot|,Tz,V] z_samples= torch.stack(z_samples, dim=1) # [B*|slot|,Tz] z_prob_col, z_samples_col = [], [] for i in range(slot_num): z_prob_col.append(z_prob[i*batch_size : (i+1)*batch_size]) z_samples_col.append(z_samples[i*batch_size : (i+1)*batch_size]) z_prob = torch.cat(z_prob_col, dim=1) # [B,Tz*|slot|,V] z_samples= torch.cat(z_samples_col, dim=1) # [B,Tz*|slot|] # Tz*|slot|, B if sample_type == 'posterior': z_samples, z_hiddens = qz_samples, qz_hiddens elif 'gumbel' not in sample_type: z_hiddens, z_last_hidden = self.z_encoder(z_samples, input_type='index') else: z_gumbel = torch.stack(gmb_samples, dim=1) # [B, Tz, V] z_gumbel = torch.matmul(z_gumbel, self.embedding.weight) # [B, Tz, E] z_hiddens, z_last_hidden = self.z_encoder(z_gumbel, input_type='embedding') retain = self.prev_z_continuous turn_states['pv_%s_h'%decoder_type] = z_hiddens if retain else z_hiddens.detach() turn_states['pv_%s_pr'%decoder_type] = z_prob if retain else z_prob.detach() turn_states['pv_%s_id'%decoder_type] = z_samples if retain else z_samples.detach() return z_prob, z_samples, z_hiddens, turn_states, log_pz
def forward(self, feats, SISMs): N, C, H, W = feats.shape HW = H * W # Resize SISMs to the same size as the input feats. SISMs = resize(SISMs, [H, W]) # shape=[N, 1, H, W] # NFs: L2-normalized features. NFs = F.normalize(feats, dim=1) # shape=[N, C, H, W] def CFM(SIVs, NFs): # Compute correlation maps [Figure 4] between SIVs and pixel-wise feature vectors in NFs by inner product. # We implement this process by ``F.conv2d()'', which takes SIVs as 1*1 kernels to convolve NFs. correlation_maps = F.conv2d(NFs, weight=SIVs) # shape=[N, N, H, W] # Vectorize and normalize correlation maps. correlation_maps = F.normalize(correlation_maps.reshape(N, N, HW), dim=2) # shape=[N, N, HW] # Compute the weight vectors [Equation 2]. correlation_matrix = torch.matmul(correlation_maps, correlation_maps.permute( 0, 2, 1)) # shape=[N, N, N] weight_vectors = correlation_matrix.sum(dim=2).softmax( dim=1) # shape=[N, N] # Fuse correlation maps with the weight vectors to build co-salient attention (CSA) maps. CSA_maps = torch.sum(correlation_maps * weight_vectors.view(N, N, 1), dim=1) # shape=[N, HW] # Max-min normalize CSA maps. min_value = torch.min(CSA_maps, dim=1, keepdim=True)[0] max_value = torch.max(CSA_maps, dim=1, keepdim=True)[0] CSA_maps = (CSA_maps - min_value) / (max_value - min_value + 1e-12 ) # shape=[N, HW] CSA_maps = CSA_maps.view(N, 1, H, W) # shape=[N, 1, H, W] return CSA_maps def get_SCFs(NFs): NFs = NFs.view(N, C, HW) # shape=[N, C, HW] SCFs = torch.matmul(NFs.permute(0, 2, 1), NFs).view(N, -1, H, W) # shape=[N, HW, H, W] return SCFs # Compute SIVs [Section 3.2, Equation 1]. SIVs = F.normalize((NFs * SISMs).mean(dim=3).mean(dim=2), dim=1).view(N, C, 1, 1) # shape=[N, C, 1, 1] # Compute co-salient attention (CSA) maps [Section 3.3]. CSA_maps = CFM(SIVs, NFs) # shape=[N, 1, H, W] # Compute self-correlation features (SCFs) [Section 3.4]. SCFs = get_SCFs(NFs) # shape=[N, HW, H, W] # Rearrange the channel order of SCFs to obtain RSCFs [Section 3.4]. evidence = CSA_maps.view(N, HW) # shape=[N, HW] indices = torch.argsort(evidence, dim=1, descending=True).view( N, HW, 1, 1).repeat(1, 1, H, W) # shape=[N, HW, H, W] RSCFs = torch.gather(SCFs, dim=1, index=indices) # shape=[N, HW, H, W] cosal_feat = self.conv(RSCFs * CSA_maps) # shape=[N, 128, H, W] return cosal_feat
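# Sketch of the RSCF reordering above: argsort ranks channels by per-pixel
# evidence, and gather rearranges SCFs into that order along dim=1.
import torch

N, H, W = 2, 4, 4
HW = H * W
SCFs = torch.randn(N, HW, H, W)
evidence = torch.rand(N, HW)
indices = torch.argsort(evidence, dim=1, descending=True).view(
    N, HW, 1, 1).repeat(1, 1, H, W)
RSCFs = torch.gather(SCFs, dim=1, index=indices)  # (N, HW, H, W)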
def beam_decode(self, initial_state, encoder_hidden_states, code_hidden_states,
                old_nl_hidden_states, masks, max_out_len, batch_data, code_masks,
                old_nl_masks, device):
    """Beam search. Generates the top K candidate predictions."""
    batch_size = initial_state.shape[0]
    decoded_batch = [list() for _ in range(batch_size)]
    decoded_batch_scores = np.zeros([batch_size, BEAM_SIZE])

    decoder_input = torch.tensor(
        [[self.embedding_store.get_nl_id(START)]] * batch_size, device=device)
    decoder_input = decoder_input.unsqueeze(1)
    decoder_state = initial_state.unsqueeze(1).expand(
        -1, decoder_input.shape[1], -1).reshape(-1, initial_state.shape[-1])

    beam_scores = torch.ones([batch_size, 1], dtype=torch.float32, device=device)
    beam_status = torch.zeros([batch_size, 1], dtype=torch.uint8, device=device)
    beam_predicted_ids = torch.full([batch_size, 1, max_out_len],
                                    self.embedding_store.get_end_id(),
                                    dtype=torch.int64, device=device)

    for i in range(max_out_len):
        beam_size = decoder_input.shape[1]
        if beam_status[:, 0].sum() == batch_size:
            break

        tiled_encoder_states = encoder_hidden_states.unsqueeze(1).expand(-1, beam_size, -1, -1)
        tiled_masks = masks.unsqueeze(1).expand(-1, beam_size, -1, -1)
        tiled_code_hidden_states = code_hidden_states.unsqueeze(1).expand(-1, beam_size, -1, -1)
        tiled_code_masks = code_masks.unsqueeze(1).expand(-1, beam_size, -1, -1)
        tiled_old_nl_hidden_states = old_nl_hidden_states.unsqueeze(1).expand(-1, beam_size, -1, -1)
        tiled_old_nl_masks = old_nl_masks.unsqueeze(1).expand(-1, beam_size, -1, -1)

        flat_decoder_input = decoder_input.reshape(-1, decoder_input.shape[-1])
        flat_encoder_states = tiled_encoder_states.reshape(
            -1, tiled_encoder_states.shape[-2], tiled_encoder_states.shape[-1])
        flat_masks = tiled_masks.reshape(-1, tiled_masks.shape[-2], tiled_masks.shape[-1])
        flat_code_hidden_states = tiled_code_hidden_states.reshape(
            -1, tiled_code_hidden_states.shape[-2], tiled_code_hidden_states.shape[-1])
        flat_code_masks = tiled_code_masks.reshape(
            -1, tiled_code_masks.shape[-2], tiled_code_masks.shape[-1])
        flat_old_nl_hidden_states = tiled_old_nl_hidden_states.reshape(
            -1, tiled_old_nl_hidden_states.shape[-2], tiled_old_nl_hidden_states.shape[-1])
        flat_old_nl_masks = tiled_old_nl_masks.reshape(
            -1, tiled_old_nl_masks.shape[-2], tiled_old_nl_masks.shape[-1])

        decoder_input_embeddings = self.embedding_store.get_nl_embeddings(flat_decoder_input)
        decoder_attention_states, flat_decoder_state, generation_logprobs, copy_logprobs = self.decode(
            decoder_state, decoder_input_embeddings, flat_encoder_states,
            flat_code_hidden_states, flat_old_nl_hidden_states, flat_masks,
            flat_code_masks, flat_old_nl_masks)

        generation_logprobs = generation_logprobs.squeeze(1)
        copy_logprobs = copy_logprobs.squeeze(1)
        generation_logprobs = generation_logprobs.reshape(
            batch_size, beam_size, generation_logprobs.shape[-1])
        copy_logprobs = copy_logprobs.reshape(batch_size, beam_size, copy_logprobs.shape[-1])

        prob_scores = torch.zeros(
            [batch_size, beam_size, generation_logprobs.shape[-1] + copy_logprobs.shape[-1]],
            dtype=torch.float32, device=device)
        prob_scores[:, :, :generation_logprobs.shape[-1]] = torch.exp(generation_logprobs)

        # Factor in the copy scores.
        expanded_token_ids = batch_data.input_ids.unsqueeze(1).expand(-1, beam_size, -1)
        prob_scores += scatter_add(src=torch.exp(copy_logprobs),
                                   index=expanded_token_ids,
                                   out=torch.zeros_like(prob_scores))

        top_scores_per_beam, top_indices_per_beam = torch.topk(prob_scores, k=BEAM_SIZE, dim=-1)
        updated_scores = torch.einsum('eb,ebm->ebm', beam_scores, top_scores_per_beam)
        retained_scores = beam_scores.unsqueeze(-1).expand(-1, -1, top_scores_per_beam.shape[-1])

        # Keep at most one ray corresponding to completed beams.
        end_mask = (torch.arange(beam_size) == 0).type(torch.float32).to(device)
        end_scores = torch.einsum('b,ebm->ebm', end_mask, retained_scores)

        possible_next_scores = torch.where(beam_status.unsqueeze(-1) == 1, end_scores, updated_scores)
        possible_next_status = torch.where(
            top_indices_per_beam == self.embedding_store.get_end_id(),
            torch.ones([batch_size, beam_size, top_scores_per_beam.shape[-1]],
                       dtype=torch.uint8, device=device),
            beam_status.unsqueeze(-1).expand(-1, -1, top_scores_per_beam.shape[-1]))
        possible_beam_predicted_ids = beam_predicted_ids.unsqueeze(2).expand(
            -1, -1, top_scores_per_beam.shape[-1], -1)

        pool_next_scores = possible_next_scores.reshape(batch_size, -1)
        pool_next_status = possible_next_status.reshape(batch_size, -1)
        pool_next_ids = top_indices_per_beam.reshape(batch_size, -1)
        pool_predicted_ids = possible_beam_predicted_ids.reshape(
            batch_size, -1, beam_predicted_ids.shape[-1])

        possible_decoder_state = flat_decoder_state.reshape(
            batch_size, beam_size, flat_decoder_state.shape[-1])
        possible_decoder_state = possible_decoder_state.unsqueeze(2).expand(
            -1, -1, top_scores_per_beam.shape[-1], -1)
        pool_decoder_state = possible_decoder_state.reshape(
            batch_size, -1, possible_decoder_state.shape[-1])

        top_scores, top_indices = torch.topk(pool_next_scores, k=BEAM_SIZE, dim=-1)
        next_step_ids = torch.gather(pool_next_ids, -1, top_indices)
        decoder_state = torch.gather(
            pool_decoder_state, 1,
            top_indices.unsqueeze(-1).expand(-1, -1, pool_decoder_state.shape[-1]))
        decoder_state = decoder_state.reshape(-1, decoder_state.shape[-1])
        beam_status = torch.gather(pool_next_status, -1, top_indices)
        beam_scores = torch.gather(pool_next_scores, -1, top_indices)

        end_tags = torch.full_like(next_step_ids, self.embedding_store.get_end_id())
        next_step_ids = torch.where(beam_status == 1, end_tags, next_step_ids)
        beam_predicted_ids = torch.gather(
            pool_predicted_ids, 1,
            top_indices.unsqueeze(-1).expand(-1, -1, pool_predicted_ids.shape[-1]))
        beam_predicted_ids[:, :, i] = next_step_ids

        unks = torch.full_like(next_step_ids,
                               self.embedding_store.get_nl_id(Vocabulary.get_unk()))
        decoder_input = torch.where(
            next_step_ids < len(self.embedding_store.nl_vocabulary),
            next_step_ids, unks).unsqueeze(-1)

    return beam_predicted_ids, beam_scores
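# A hedged sketch (hypothetical shapes, not the model above) of the core beam
# bookkeeping: flatten [batch, beam, vocab] candidate scores, take a global
# top-k per batch element, then gather the surviving beams' running state with
# the same indices.
import torch

B, K, V = 2, 3, 7
scores = torch.rand(B, K, V)                      # candidate scores per beam
pool = scores.reshape(B, K * V)                   # pool all K*V continuations
top_scores, top_idx = torch.topk(pool, k=K, dim=-1)

beam_idx = top_idx // V                           # which parent beam survived
token_idx = top_idx % V                           # which token extends it

state = torch.randn(B, K, 16)                     # per-beam decoder state
new_state = torch.gather(
    state, 1, beam_idx.unsqueeze(-1).expand(-1, -1, state.shape[-1]))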
def forward(self, im_data, gt_boxes, im_info):
    batch_size = im_data.size(0)
    im_info = im_info.data
    if gt_boxes is not None:
        gt_boxes = gt_boxes.data

    # Feed image data to the base model to obtain the base feature map.
    base_feat = self.RCNN_base(im_data)

    # Feed the base feature map to the RPN to obtain rois.
    rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes)

    # In the training phase, use ground-truth boxes for refining.
    if self.training:
        roi_data = self.RCNN_proposal_target(rois, gt_boxes)
        rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data
        rois_label = Variable(rois_label.view(-1).long())
        rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
        rois_inside_ws = Variable(rois_inside_ws.view(-1, rois_inside_ws.size(2)))
        rois_outside_ws = Variable(rois_outside_ws.view(-1, rois_outside_ws.size(2)))
    else:
        rois_label = None
        rois_target = None
        rois_inside_ws = None
        rois_outside_ws = None
        rpn_loss_cls = 0
        rpn_loss_bbox = 0

    rois = Variable(rois)

    # Do RoI pooling based on the predicted rois.
    if cfg.pooling_mode == 'align':
        pooled_feat = roi_align(base_feat, rois.view(-1, 5),
                                (cfg.pool_size, cfg.pool_size), 1.0 / 16)
    elif cfg.pooling_mode == 'pool':
        pooled_feat = roi_pool(base_feat, rois.view(-1, 5),
                               (cfg.pool_size, cfg.pool_size), 1.0 / 16)

    # Feed pooled features to the top model.
    pooled_feat = self._head_to_tail(pooled_feat)

    # Compute bbox offsets.
    bbox_pred = self.RCNN_bbox_pred(pooled_feat)
    if self.training and not self.class_agnostic:
        # Select the corresponding columns according to roi labels.
        bbox_pred_view = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
        bbox_pred_select = torch.gather(
            bbox_pred_view, 1,
            rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
        bbox_pred = bbox_pred_select.squeeze(1)

    # Compute object classification probability.
    cls_score = self.RCNN_cls_score(pooled_feat)
    cls_prob = F.softmax(cls_score, 1)

    RCNN_loss_cls = 0
    RCNN_loss_bbox = 0
    if self.training:
        # Classification loss.
        RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)
        # Bounding-box regression smooth L1 loss.
        RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target,
                                         rois_inside_ws, rois_outside_ws)

    cls_prob = cls_prob.view(batch_size, rois.size(1), -1)
    bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)

    return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label
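# A minimal sketch of the class-specific box selection used above: from a
# [R, num_classes, 4] prediction, gather the 4 regression values belonging to
# each RoI's ground-truth class. Shapes are illustrative only.
import torch

R, num_classes = 5, 21
bbox_pred = torch.randn(R, num_classes, 4)
labels = torch.randint(0, num_classes, (R,))
index = labels.view(R, 1, 1).expand(R, 1, 4)
per_class_boxes = torch.gather(bbox_pred, 1, index).squeeze(1)  # shape [R, 4]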
def forward(self, x, loc, src, H, W, N_grid):
    if self.pre_conv is not None:
        x_map = token2map(x, loc, [H, W], self.inter_kernel, self.inter_sigma)
        x_map = self.pre_conv(x_map)
        x1 = map2token(x_map, loc)
        x = self.pre_conv2(x)
        x = x + x1

    B, N, C = x.shape
    x = self.conf_norm(x)

    # Confidence-based sampling.
    conf = self.conf(x)

    # Down-sample.
    if self.sample_ratio < 1:
        sample_num = max(math.ceil((N - N_grid) * self.sample_ratio), 0)
        x_grid, loc_grid = x[:, :N_grid, :], loc[:, :N_grid, :]
        x_ada, loc_ada = x[:, N_grid:, :], loc[:, N_grid:, :]
        conf_ada = conf[:, N_grid:, :]
        index_down = gumble_top_k(conf_ada, sample_num, dim=1, T=1)
        loc_down = torch.gather(loc_ada, 1, index_down.expand([B, sample_num, 2]))
        x_down = torch.gather(x_ada, 1, index_down.expand([B, sample_num, C]))
        x_down = torch.cat([x_grid, x_down], dim=1)
        loc_down = torch.cat([loc_grid, loc_down], dim=1)
    else:
        x_down, loc_down = x, loc

    # Extra points for low-level features.
    if self.extra_ratio > 0:
        # High-resolution grid.
        conf_map = token2map(conf, loc, [H, W], self.inter_kernel, self.inter_sigma)
        loc_extra = get_grid_loc(B, self.HR_res[0], self.HR_res[1], device=x.device)
        conf_extra = map2token(conf_map, loc_extra)
        extra_num = int(N * self.extra_ratio)
        index_extra = gumble_top_k(conf_extra, extra_num, dim=1, T=1)
        loc_extra = torch.gather(loc_extra, 1, index_extra.expand([B, extra_num, 2]))
        x_extra = inter_points(x, loc, loc_extra)
        conf_extra = inter_points(conf, loc, loc_extra)
        if self.use_local:
            local = extract_local_feature(src, loc_extra, self.local_kernel)
            local = local.flatten(2)
            local = self.local_conv1(local, x_extra)
            local = self.local_act1(local)
            local = self.local_conv2(local, x_extra)
            local = self.local_act2(local)
            x_extra = x_extra + local
        x = torch.cat([x, x_extra], dim=1)
        loc = torch.cat([loc, loc_extra], dim=1)
        conf = torch.cat([conf, conf_extra], dim=1)

    # Attention block.
    x_down = self.norm1(x_down)
    x_down = x_down + self.drop_path(self.attn(x_down, x, loc, H, W, conf))
    x_down = self.norm2(x_down)
    kernel_size = self.attn.sr_ratio + 1
    if self.sample_ratio <= 0.25:
        H, W = H // 2, W // 2
    x_down = x_down + self.drop_path(self.mlp(x_down, loc_down, H, W, kernel_size, 2))

    # `vis` is assumed to be a module-level debug flag.
    if vis and self.extra_ratio > 0:
        import matplotlib.pyplot as plt
        IMAGENET_DEFAULT_MEAN = torch.tensor([0.485, 0.456, 0.406],
                                             device=src.device)[None, :, None, None]
        IMAGENET_DEFAULT_STD = torch.tensor([0.229, 0.224, 0.225],
                                            device=src.device)[None, :, None, None]
        src = src * IMAGENET_DEFAULT_STD + IMAGENET_DEFAULT_MEAN
        for i in range(1):
            img = src[i].permute(1, 2, 0).detach().cpu()

            ax = plt.subplot(1, 3, 1)
            ax.clear()
            conf_map = token2map(conf, loc, [H, W], self.inter_kernel, self.inter_sigma)
            conf_map = F.interpolate(conf_map, self.HR_res, mode='bilinear')
            ax.imshow(conf_map[i, 0].detach().cpu())

            ax = plt.subplot(1, 3, 2)
            ax.clear()
            ax.imshow(img, extent=[0, 1, 0, 1])
            loc_show = (loc + 1) * 0.5
            loc_grid = loc_show[i, :N_grid].detach().cpu().numpy()
            ax.scatter(loc_grid[:, 0], 1 - loc_grid[:, 1], c='blue', s=0.5)
            loc_ada = loc_show[i, N_grid:].detach().cpu().numpy()
            ax.scatter(loc_ada[:, 0], 1 - loc_ada[:, 1], c='red', s=0.5)

            ax = plt.subplot(1, 3, 3)
            ax.clear()
            ax.imshow(img, extent=[0, 1, 0, 1])
            loc_show = (loc_down + 1) * 0.5
            loc_grid = loc_show[i, :N_grid].detach().cpu().numpy()
            ax.scatter(loc_grid[:, 0], 1 - loc_grid[:, 1], c='blue', s=0.5)
            loc_ada = loc_show[i, N_grid:].detach().cpu().numpy()
            ax.scatter(loc_ada[:, 0], 1 - loc_ada[:, 1], c='red', s=0.5)

    return x_down, loc_down
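# A hedged sketch of the confidence-based token sampling above: pick `k` token
# indices (here by plain top-k on a confidence score, standing in for the
# repo's gumble_top_k helper, which is assumed to behave similarly) and gather
# both the tokens and their 2-D locations with the same index tensor.
import torch

B, N, C, k = 2, 10, 16, 4
x = torch.randn(B, N, C)                 # tokens
loc = torch.rand(B, N, 2)                # normalized (x, y) per token
conf = torch.rand(B, N, 1)

idx = torch.topk(conf, k, dim=1).indices          # shape [B, k, 1]
x_down = torch.gather(x, 1, idx.expand(B, k, C))
loc_down = torch.gather(loc, 1, idx.expand(B, k, 2))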
def forward(self, cls_heads, reg_heads, batch_anchors):
    device = cls_heads[0].device
    with torch.no_grad():
        filter_scores, filter_score_classes, filter_reg_heads, filter_batch_anchors = [], [], [], []
        for per_level_cls_head, per_level_reg_head, per_level_anchor in zip(
                cls_heads, reg_heads, batch_anchors):
            scores, score_classes = torch.max(per_level_cls_head, dim=2)
            if scores.shape[1] >= self.top_n:
                scores, indexes = torch.topk(scores, self.top_n, dim=1,
                                             largest=True, sorted=True)
                score_classes = torch.gather(score_classes, 1, indexes)
                per_level_reg_head = torch.gather(
                    per_level_reg_head, 1, indexes.unsqueeze(-1).repeat(1, 1, 4))
                per_level_anchor = torch.gather(
                    per_level_anchor, 1, indexes.unsqueeze(-1).repeat(1, 1, 4))
            filter_scores.append(scores)
            filter_score_classes.append(score_classes)
            filter_reg_heads.append(per_level_reg_head)
            filter_batch_anchors.append(per_level_anchor)

        filter_scores = torch.cat(filter_scores, dim=1)
        filter_score_classes = torch.cat(filter_score_classes, dim=1)
        filter_reg_heads = torch.cat(filter_reg_heads, dim=1)
        filter_batch_anchors = torch.cat(filter_batch_anchors, dim=1)

        batch_scores, batch_classes, batch_pred_bboxes = [], [], []
        for per_image_scores, per_image_score_classes, per_image_reg_heads, per_image_anchors in zip(
                filter_scores, filter_score_classes, filter_reg_heads, filter_batch_anchors):
            pred_bboxes = self.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
                per_image_reg_heads, per_image_anchors)
            score_classes = per_image_score_classes[
                per_image_scores > self.min_score_threshold].float()
            pred_bboxes = pred_bboxes[
                per_image_scores > self.min_score_threshold].float()
            scores = per_image_scores[
                per_image_scores > self.min_score_threshold].float()

            one_image_scores = (-1) * torch.ones((self.max_detection_num,), device=device)
            one_image_classes = (-1) * torch.ones((self.max_detection_num,), device=device)
            one_image_pred_bboxes = (-1) * torch.ones((self.max_detection_num, 4), device=device)

            if scores.shape[0] != 0:
                # Sort boxes by score, then apply NMS.
                sorted_scores, sorted_indexes = torch.sort(scores, descending=True)
                sorted_score_classes = score_classes[sorted_indexes]
                sorted_pred_bboxes = pred_bboxes[sorted_indexes]
                keep = nms(sorted_pred_bboxes, sorted_scores, self.nms_threshold)
                keep_scores = sorted_scores[keep]
                keep_classes = sorted_score_classes[keep]
                keep_pred_bboxes = sorted_pred_bboxes[keep]

                final_detection_num = min(self.max_detection_num, keep_scores.shape[0])
                one_image_scores[0:final_detection_num] = keep_scores[0:final_detection_num]
                one_image_classes[0:final_detection_num] = keep_classes[0:final_detection_num]
                one_image_pred_bboxes[0:final_detection_num, :] = keep_pred_bboxes[0:final_detection_num, :]

            batch_scores.append(one_image_scores.unsqueeze(0))
            batch_classes.append(one_image_classes.unsqueeze(0))
            batch_pred_bboxes.append(one_image_pred_bboxes.unsqueeze(0))

        batch_scores = torch.cat(batch_scores, dim=0)
        batch_classes = torch.cat(batch_classes, dim=0)
        batch_pred_bboxes = torch.cat(batch_pred_bboxes, dim=0)

        # batch_scores shape: [batch_size, max_detection_num]
        # batch_classes shape: [batch_size, max_detection_num]
        # batch_pred_bboxes shape: [batch_size, max_detection_num, 4]
        return batch_scores, batch_classes, batch_pred_bboxes
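# A minimal sketch of the per-level top-n filtering above: torch.topk yields
# indices into the anchor axis, and the same indices (broadcast across the
# last dim) gather the matching classes and 4-d regression outputs.
# Hypothetical shapes.
import torch

B, A, top_n = 2, 100, 10
scores = torch.rand(B, A)
classes = torch.randint(0, 80, (B, A))
reg = torch.randn(B, A, 4)

scores, idx = torch.topk(scores, top_n, dim=1)
classes = torch.gather(classes, 1, idx)
reg = torch.gather(reg, 1, idx.unsqueeze(-1).repeat(1, 1, 4))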
def forward(self, scene_inds, txt_feats, txt_masks, hiddens,
            src_feats, src_masks,  # in case we'd like to try using image contexts as input
            tgt_feats, tgt_masks, sample_mode):
    """
    Args:
        - **scene_inds** (bsize, )
        - **txt_feats** (bsize, nturns, n_feature_dim)
        - **txt_masks** (bsize, nturns)
        - **hiddens** (num_layers, bsize, n_feature_dim)
        - **src_feats** (bsize, nturns, nregions, n_feature_dim)
        - **src_masks** (bsize, nturns, nregions)
        - **tgt_feats** (bsize, nturns, nregions, n_feature_dim)
        - **tgt_masks** (bsize, nturns, nregions)
        - **sample_mode**
            0: top1, 1: multinomial sampling, 2: circular,
            3: fixed indices, 4: random, 5: rollout greedy search

    Returns:
        - **output_feats** (bsize, nturns, (ninsts), n_feature_dim)
        - **next_hiddens** (num_layers, bsize, n_feature_dim)
        - **sample_logits** (bsize, nturns)
        - **sample_indices** (bsize, nturns)
    """
    input_feats = txt_feats
    if not self.cfg.use_txt_context:
        # TODO: do NOT use the updater?
        bsize, nturns, fsize = input_feats.size()
        output_feats = input_feats
        return output_feats, None, None, None
    else:
        if self.cfg.instance_dim < 2:
            # One-dimensional context.
            self.updater.flatten_parameters()
            output_feats, next_hiddens = self.updater(input_feats, hiddens.unsqueeze(0))
            return output_feats, next_hiddens, None, None
        else:
            # Two-dimensional context.
            bsize, nturns, input_dim = input_feats.size()
            bsize, ninsts, hidden_dim = hiddens.size()
            current_hiddens = hiddens
            output_feats, sample_logits, sample_indices, sample_rewards = [], [], [], []
            for i in range(nturns):
                # Search for the instance indices.
                query_feats = input_feats[:, i].unsqueeze(1)
                if self.cfg.rl_finetune > 0:
                    # Learnable policy.
                    if sample_mode < 2:
                        # Inference mode.
                        if i < self.cfg.instance_dim:
                            instance_inds = ((i % self.cfg.instance_dim) *
                                             query_feats.new_ones(bsize)).long()
                            logits = instance_inds.new_ones(bsize, self.cfg.instance_dim).float()
                        else:
                            instance_inds, logits = self.policy(
                                query_feats.detach(), current_hiddens.detach(), sample_mode)
                    elif sample_mode == 5:
                        # Rollout greedy search.
                        instance_inds, rewards = self.rollout_search(
                            i, nturns,
                            input_feats[:, i:].view(bsize, nturns - i, input_dim).detach(),
                            txt_masks[:, i:].view(bsize, nturns - i).detach(),
                            scene_inds, current_hiddens.detach(),
                            tgt_feats.detach(), tgt_masks.detach(),
                            sample_mode=1)
                        # TODO: whether to backprop more
                        _, logits = self.policy(query_feats.detach(),
                                                current_hiddens.detach(), 1)
                        sample_rewards.append(rewards)
                else:
                    # Fixed policies.
                    if sample_mode == 2:
                        instance_inds = ((i % self.cfg.instance_dim) *
                                         query_feats.new_ones(bsize)).long()
                    elif sample_mode == 3:
                        instance_inds = (query_feats.new_zeros(bsize)).long()
                    elif sample_mode == 4:
                        instance_inds = torch.randint(0, self.cfg.instance_dim,
                                                      size=(bsize,)).long()
                        if self.cfg.cuda:
                            instance_inds = instance_inds.cuda()
                    _, logits = self.policy(query_feats.detach(),
                                            current_hiddens.detach(), 1)
                sample_indices.append(instance_inds)
                sample_logits.append(logits)

                # Update the hidden states using the instance indices.
                instance_inds = instance_inds.view(bsize, 1, 1).expand(bsize, 1, hidden_dim)
                sample_hiddens = torch.gather(current_hiddens, 1, instance_inds)
                h = self.updater(query_feats, sample_hiddens)
                next_hiddens = current_hiddens.clone()
                next_hiddens.scatter_(dim=1, index=instance_inds, src=h)
                output_feats.append(next_hiddens)
                current_hiddens = next_hiddens

            sample_indices = torch.stack(sample_indices, 1)
            sample_logits = torch.stack(sample_logits, 1)
            output_feats = torch.stack(output_feats, 1)
            return output_feats, next_hiddens, sample_logits, sample_indices
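# A hedged sketch of the gather-then-scatter hidden-state update above: read one
# instance's hidden vector per batch element, update it, and write it back at
# the same index. Names and sizes are illustrative only.
import torch

B, ninsts, D = 2, 4, 8
hiddens = torch.randn(B, ninsts, D)
inds = torch.randint(0, ninsts, (B,)).view(B, 1, 1).expand(B, 1, D)

picked = torch.gather(hiddens, 1, inds)        # shape [B, 1, D]
updated = picked + 1.0                         # stand-in for the RNN update
new_hiddens = hiddens.clone()
new_hiddens.scatter_(dim=1, index=inds, src=updated)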
def train(self, batch: EpisodeBatch, t_env: int, episode_num: int,
          chosen_index=0, return_q_all=False):
    # Get the relevant quantities.
    rewards = batch["reward"][:, :-1]
    actions = batch["actions"][:, :-1]
    terminated = batch["terminated"][:, :-1].float()
    mask = batch["filled"][:, :-1].float()
    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
    avail_actions = batch["avail_actions"]

    # Calculate estimated Q-values.
    mac_out = []
    if self.args.q_net_ensemble:
        mac_chosen = self.mac[chosen_index]
    else:
        mac_chosen = self.mac
    mac_chosen.init_hidden(batch.batch_size)

    index = th.randint(mac_chosen.n_agents, [batch.batch_size])
    index = F.one_hot(index, mac_chosen.n_agents).to(th.bool)
    rp = random.random() < self.args.contrary_grad_p

    if self.args.random_agent_order:
        enemy_shape = self.scheme["obs"]["vshape"] - mac_chosen.n_agents * 8
        enemy_num = enemy_shape // 8
        assert enemy_num * 8 == enemy_shape
        order_enemy = np.arange(enemy_num)
        order_ally = np.arange(mac_chosen.n_agents - 1)
        np.random.shuffle(order_enemy)
        np.random.shuffle(order_ally)
        agent_order = [order_ally, order_enemy]
    else:
        agent_order = None

    for t in range(batch.max_seq_length):
        if self.args.mac == "robust_mac":
            agent_outs = mac_chosen.forward(batch, t=t, index=index,
                                            contrary_grad=rp, agent_order=agent_order)
        else:
            agent_outs = mac_chosen.forward(batch, t=t, agent_order=agent_order)  # (bs, n, n_actions)
        mac_out.append(agent_outs)  # [t, (bs, n, n_actions)]
    mac_out = th.stack(mac_out, dim=1)  # concat over time: (bs, t, n, n_actions)

    # Pick the Q-values for the actions taken by each agent.
    chosen_action_qvals_bm = th.gather(mac_out[:, :-1], dim=3,
                                       index=actions).squeeze(3)  # (bs, t, n)

    # Calculate the Q-values necessary for the target.
    target_mac_out = []
    if self.args.q_net_ensemble:
        target_mac_chosen = self.target_mac[chosen_index]
    else:
        target_mac_chosen = self.target_mac
    target_mac_chosen.init_hidden(batch.batch_size)  # (bs, n, hidden_size)
    for t in range(batch.max_seq_length):
        target_agent_outs = target_mac_chosen.forward(
            batch, t=t, agent_order=agent_order)  # (bs, n, n_actions)
        target_mac_out.append(target_agent_outs)

    # We don't need the first timestep's Q-value estimate for calculating targets.
    target_mac_out = th.stack(target_mac_out[1:], dim=1)  # (bs, t, n, n_actions)

    # Mask out unavailable actions.
    target_mac_out[avail_actions[:, 1:] == 0] = -9999999

    # Max over target Q-values.
    if self.args.double_q:  # True for QMIX
        # Get actions that maximise live Q (for double Q-learning).
        mac_out_detach = mac_out.clone().detach()
        mac_out_detach[avail_actions == 0] = -9999999
        # (bs, t, n, 1): indices instead of values, discarding t=0.
        cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
        target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)  # (bs, t, n)
    else:
        target_max_qvals = target_mac_out.max(dim=3)[0]

    # Mix.
    loss_ensemble_w = th.tensor(0.0).to(target_max_qvals.device)
    loss_ensemble_b = th.tensor(0.0).to(target_max_qvals.device)
    if self.mixer is not None:
        if return_q_all:
            q_all = chosen_action_qvals_bm.detach().cpu().numpy()
        chosen_mixer = self.mixer_list[chosen_index]
        chosen_target_mixer = self.target_mixer_list[chosen_index]
        if self.args.mixer == "aqmix":
            chosen_action_qvals = chosen_mixer(chosen_action_qvals_bm,
                                               batch["state"][:, :-1], actions.detach())
            target_max_qvals = chosen_target_mixer(target_max_qvals,
                                                   batch["state"][:, 1:],
                                                   cur_max_actions.detach())
        else:
            chosen_action_qvals = chosen_mixer(chosen_action_qvals_bm,
                                               batch["state"][:, :-1])
            target_max_qvals = chosen_target_mixer(target_max_qvals,
                                                   batch["state"][:, 1:])

        if len(self.mixer_list) > 1:
            other_w_list = []
            other_b_list = []
            chosen_action_qvals, w, b = chosen_action_qvals
            target_max_qvals, _, _ = target_max_qvals
            for i in range(len(self.mixer_list)):
                if i != chosen_index:
                    if self.args.mixer == "aqmix":
                        _, other_w, other_b = self.mixer_list[i](
                            chosen_action_qvals_bm, batch["state"][:, :-1], actions.detach())
                    else:
                        _, other_w, other_b = self.mixer_list[i](
                            chosen_action_qvals_bm, batch["state"][:, :-1])
                    other_w_list.append(other_w.detach())
                    other_b_list.append(other_b.detach())
            for other_w, other_b in zip(other_w_list, other_b_list):
                norm_delta_w = th.mean(th.abs((w - other_w).squeeze(2).sum(1)))
                norm_w = th.mean(th.abs(w.detach().squeeze(2).sum(1)))
                norm_w_other = th.mean(th.abs(other_w.squeeze(2).sum(1)))
                loss_ensemble_w += norm_delta_w / (norm_w_other + norm_w + 1e-8)
                norm_delta_b = th.mean(th.abs((b - other_b).view(-1)))
                norm_b = th.mean(th.abs(b.detach().view(-1)))
                norm_b_other = th.mean(th.abs(other_b.view(-1)))
                loss_ensemble_b += norm_delta_b / (norm_b_other + norm_b + 1e-8)
            loss_ensemble_w /= len(self.mixer_list) - 1
            loss_ensemble_b /= len(self.mixer_list) - 1

        if return_q_all:
            mix_q_all = chosen_action_qvals.detach().cpu().numpy()
            termed = terminated.detach().cpu().numpy()  # (bs, t, 1)

    # Calculate 1-step Q-learning targets.
    targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals

    # TD-error (no gradient through the target net), shape (bs, t, 1).
    td_error = chosen_action_qvals - targets.detach()
    mask = mask.expand_as(td_error)

    # Zero out the targets that came from padded data.
    masked_td_error = td_error * mask

    # Normal L2 loss, mean over actual data.
    loss_q = (masked_td_error ** 2).sum() / mask.sum()
    loss = loss_q - self.args.en_w_alpha * loss_ensemble_w - self.args.en_b_alpha * loss_ensemble_b

    # Optimise.
    if self.args.q_net_ensemble:
        current_optim = self.optimiser_list[chosen_index]
        current_para = self.params_list[chosen_index]
    else:
        current_optim = self.optimiser
        current_para = self.params
    current_optim.zero_grad()
    loss.backward()
    grad_norm = th.nn.utils.clip_grad_norm_(current_para, self.args.grad_norm_clip)
    try:
        grad_norm = grad_norm.item()
    except AttributeError:
        pass
    current_optim.step()

    if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
        self._update_targets()
        self.last_target_update_episode = episode_num

    if t_env - self.log_stats_t >= self.args.learner_log_interval:
        self.logger.log_stat("loss", loss.item(), t_env)
        self.logger.log_stat("loss_q", loss_q.item(), t_env)
        self.logger.log_stat("loss_en_w", loss_ensemble_w.item(), t_env)
        self.logger.log_stat("loss_en_b", loss_ensemble_b.item(), t_env)
        self.logger.log_stat("grad_norm", grad_norm, t_env)
        mask_elems = mask.sum().item()
        self.logger.log_stat("td_error_abs",
                             masked_td_error.abs().sum().item() / mask_elems, t_env)
        self.logger.log_stat("q_taken_mean",
                             (chosen_action_qvals * mask).sum().item() /
                             (mask_elems * self.args.n_agents), t_env)
        self.logger.log_stat("target_mean",
                             (targets * mask).sum().item() /
                             (mask_elems * self.args.n_agents), t_env)
        self.log_stats_t = t_env

    if return_q_all:
        return q_all, mix_q_all, termed
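# A minimal sketch of the double-Q target computation above: the live network's
# detached Q-values pick the argmax action, and torch.gather reads that
# action's value from the target network. Shapes are illustrative.
import torch

bs, t, n, n_actions = 2, 5, 3, 4
mac_out = torch.randn(bs, t, n, n_actions)         # live network
target_mac_out = torch.randn(bs, t, n, n_actions)  # target network

cur_max_actions = mac_out.detach().max(dim=3, keepdim=True)[1]
target_max_qvals = torch.gather(target_mac_out, 3, cur_max_actions).squeeze(3)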
def forward(self,
            context_ids: TextFieldTensors,
            query_ids: TextFieldTensors,
            context_lens: torch.Tensor,
            query_lens: torch.Tensor,
            mask_label: Optional[torch.Tensor] = None,
            start_label: Optional[torch.Tensor] = None,
            end_label: Optional[torch.Tensor] = None,
            metadata: List[Dict[str, Any]] = None):
    # Concatenate the context and query for the encoder.
    # Get the indexers first.
    indexers = context_ids.keys()
    dialogue_ids = {}

    # Get the lengths of the context and the query.
    context_len = torch.max(context_lens).item()
    query_len = torch.max(query_lens).item()

    # [B, _len]
    context_mask = get_mask_from_sequence_lengths(context_lens, context_len)
    query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
    for indexer in indexers:
        # Collect the tensors of the context and the query.
        dialogue_ids[indexer] = {}
        for key in context_ids[indexer].keys():
            context = context_ids[indexer][key]
            query = query_ids[indexer][key]
            # Concatenate the context and query along the length dim.
            dialogue = torch.cat([context, query], dim=1)
            dialogue_ids[indexer][key] = dialogue

    # Get the outputs of the dialogue.
    if isinstance(self._text_field_embedder, TextFieldEmbedder):
        embedder_outputs = self._text_field_embedder(dialogue_ids)
    else:
        embedder_outputs = self._text_field_embedder(**dialogue_ids[self._index_name])

    # Split the outputs back into context and query parts.
    # [B, _len, embed_size]
    context_last_layer = embedder_outputs[:, :context_len].contiguous()
    query_last_layer = embedder_outputs[:, context_len:].contiguous()

    # ------- Compute the span prediction results -------
    # For every masked position in the query we want to know what content
    # should be inserted after it, i.e. the start and end positions of the
    # corresponding span in the context.
    # Expand the context to [B, query_len, context_len, embed_size].
    context_last_layer = context_last_layer.unsqueeze(dim=1).expand(
        -1, query_len, -1, -1).contiguous()
    # [B, query_len, context_len]
    context_expand_mask = context_mask.unsqueeze(dim=1).expand(-1, query_len, -1).contiguous()

    span_embed_size = context_last_layer.size(-1)

    if self.training and self._neg_sample_ratio > 0.0:
        # Sample additional negatives among the zero positions of the mask.
        # [B * query_len, ]
        sample_mask_label = mask_label.view(-1)
        # Flattened length and the number of negatives to sample.
        mask_length = sample_mask_label.size(0)
        mask_sum = int(torch.sum(sample_mask_label).item() * self._neg_sample_ratio)
        mask_sum = max(10, mask_sum)
        # Indices of the sampled negatives, clipped to the valid range.
        neg_indexes = torch.randint(low=0, high=mask_length, size=(mask_sum,))
        neg_indexes = neg_indexes[:mask_length]
        # Set the sampled negative positions to 1.
        sample_mask_label[neg_indexes] = 1
        # [B, query_len]
        use_mask_label = sample_mask_label.view(-1, query_len).to(dtype=torch.bool)
        # Filter out the padded part of the query, [B, query_len].
        use_mask_label = use_mask_label & query_mask
        span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
        # Select the usable part of the context.
        # [B_mask, context_len, span_embed_size]
        span_context_matrix = context_last_layer.masked_select(span_mask).view(
            -1, context_len, span_embed_size).contiguous()
        # Select the usable query vectors.
        span_query_vector = query_last_layer.masked_select(
            span_mask.squeeze(dim=-1)).view(-1, span_embed_size).contiguous()
        span_context_mask = context_expand_mask.masked_select(
            span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()
    else:
        use_mask_label = query_mask
        span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
        # Select the usable part of the context.
        # [B_mask, context_len, span_embed_size]
        span_context_matrix = context_last_layer.masked_select(span_mask).view(
            -1, context_len, span_embed_size).contiguous()
        # Select the usable query vectors.
        span_query_vector = query_last_layer.masked_select(
            span_mask.squeeze(dim=-1)).view(-1, span_embed_size).contiguous()
        span_context_mask = context_expand_mask.masked_select(
            span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()

    # Logits of the span over every context position.
    # [B_mask, context_len]
    span_start_probs = self.start_attention(span_query_vector, span_context_matrix,
                                            span_context_mask)
    span_end_probs = self.end_attention(span_query_vector, span_context_matrix,
                                        span_context_mask)
    span_start_logits = torch.log(span_start_probs + self._eps)
    span_end_logits = torch.log(span_end_probs + self._eps)

    # [B_mask, 2]; in the last dim, the first entry is the start position and
    # the second is the end position.
    best_spans = get_best_span(span_start_logits, span_end_logits)
    # Compute the score of each best span.
    best_span_scores = (
        torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) +
        torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
    # [B_mask, ]
    best_span_scores = best_span_scores.squeeze(1)

    # Write the important tensors to the output.
    output_dict = {
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_spans": best_spans,
        "best_span_scores": best_span_scores,
    }

    # If labels exist, use them to compute the loss.
    if start_label is not None:
        loss = self._calc_loss(span_start_logits, span_end_logits, use_mask_label,
                               start_label, end_label, best_spans)
        output_dict["loss"] = loss
    if metadata is not None:
        predict_rewrite_results = self._get_rewrite_result(
            use_mask_label, best_spans, query_lens, context_lens, metadata)
        output_dict['rewrite_results'] = predict_rewrite_results
    return output_dict
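# A hedged sketch of the best-span scoring above: for each row, gather the
# start and end logits at the predicted span boundaries and sum them.
# Hypothetical shapes.
import torch

B, L = 3, 12
start_logits = torch.randn(B, L)
end_logits = torch.randn(B, L)
best_spans = torch.stack([torch.randint(0, L, (B,)),
                          torch.randint(0, L, (B,))], dim=1)  # [B, 2]

span_scores = (torch.gather(start_logits, 1, best_spans[:, 0:1]) +
               torch.gather(end_logits, 1, best_spans[:, 1:2])).squeeze(1)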
def train(self, batch: EpisodeBatch, t_env: int, episode_num: int,
          show_demo=False, save_data=None):
    # Get the relevant quantities.
    rewards = batch["reward"][:, :-1]
    actions = batch["actions"][:, :-1]
    terminated = batch["terminated"][:, :-1].float()
    mask = batch["filled"][:, :-1].float()
    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
    avail_actions = batch["avail_actions"]

    # Calculate estimated Q-values.
    mac_out = []
    self.mac.init_hidden(batch.batch_size)
    for t in range(batch.max_seq_length):
        agent_outs = self.mac.forward(batch, t=t)
        mac_out.append(agent_outs)
    mac_out = torch.stack(mac_out, dim=1)  # concat over time

    # Pick the Q-values for the actions taken by each agent.
    chosen_action_qvals = torch.gather(mac_out[:, :-1], dim=3,
                                       index=actions).squeeze(3)  # remove the last dim

    x_mac_out = mac_out.clone().detach()
    x_mac_out[avail_actions == 0] = -9999999
    max_action_qvals, max_action_index = x_mac_out[:, :-1].max(dim=3)
    max_action_index = max_action_index.detach().unsqueeze(3)
    is_max_action = (max_action_index == actions).int().float()

    if show_demo:
        q_i_data = chosen_action_qvals.detach().cpu().numpy()
        q_data = (max_action_qvals - chosen_action_qvals).detach().cpu().numpy()

    # Calculate the Q-values necessary for the target.
    target_mac_out = []
    self.target_mac.init_hidden(batch.batch_size)
    for t in range(batch.max_seq_length):
        target_agent_outs = self.target_mac.forward(batch, t=t)
        target_mac_out.append(target_agent_outs)

    # We don't need the first timestep's Q-value estimate for calculating targets.
    target_mac_out = torch.stack(target_mac_out[1:], dim=1)  # concat across time

    # Max over target Q-values.
    if self.args.double_q:
        # Get actions that maximise live Q (for double Q-learning).
        mac_out_detach = mac_out.clone().detach()
        mac_out_detach[avail_actions == 0] = -9999999
        cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
        target_max_qvals = torch.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
    else:
        target_max_qvals = target_mac_out.max(dim=3)[0]

    # Mix.
    if self.mixer is not None:
        chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1])
        target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:])

    # Calculate 1-step Q-learning targets.
    targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals

    if show_demo:
        tot_q_data = chosen_action_qvals.detach().cpu().numpy()
        tot_target = targets.detach().cpu().numpy()
        if self.mixer is None:
            tot_q_data = np.mean(tot_q_data, axis=2)
            tot_target = np.mean(tot_target, axis=2)
        print('action_pair_%d_%d' % (save_data[0], save_data[1]),
              np.squeeze(q_data[:, 0]), np.squeeze(q_i_data[:, 0]),
              np.squeeze(tot_q_data[:, 0]), np.squeeze(tot_target[:, 0]))
        self.logger.log_stat('action_pair_%d_%d' % (save_data[0], save_data[1]),
                             np.squeeze(tot_q_data[:, 0]), t_env)
        return

    # TD-error.
    td_error = chosen_action_qvals - targets.detach()
    mask = mask.expand_as(td_error)

    # Zero out the targets that came from padded data.
    masked_td_error = td_error * mask

    # Normal L2 loss, mean over actual data.
    loss = (masked_td_error ** 2).sum() / mask.sum()
    masked_hit_prob = torch.mean(is_max_action, dim=2) * mask
    hit_prob = masked_hit_prob.sum() / mask.sum()

    # Optimise.
    self.optimiser.zero_grad()
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
    self.optimiser.step()

    if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
        self._update_targets()
        self.last_target_update_episode = episode_num

    if t_env - self.log_stats_t >= self.args.learner_log_interval:
        self.logger.log_stat("loss", loss.item(), t_env)
        self.logger.log_stat("hit_prob", hit_prob.item(), t_env)
        self.logger.log_stat("grad_norm", grad_norm, t_env)
        mask_elems = mask.sum().item()
        self.logger.log_stat("td_error_abs",
                             masked_td_error.abs().sum().item() / mask_elems, t_env)
        self.logger.log_stat("q_taken_mean",
                             (chosen_action_qvals * mask).sum().item() /
                             (mask_elems * self.args.n_agents), t_env)
        self.logger.log_stat("target_mean",
                             (targets * mask).sum().item() /
                             (mask_elems * self.args.n_agents), t_env)
        self.log_stats_t = t_env
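# A minimal sketch of the chosen-action value extraction above: actions of
# shape [bs, t, n, 1] index the last dim of [bs, t, n, n_actions] Q-values.
import torch

bs, t, n, n_actions = 2, 5, 3, 4
mac_out = torch.randn(bs, t, n, n_actions)
actions = torch.randint(0, n_actions, (bs, t, n, 1))
chosen_q = torch.gather(mac_out, dim=3, index=actions).squeeze(3)  # [bs, t, n]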
def forward(self, x, mask, gt_target=None, soft_threshold=0.8):
    # Inputs carry no gradient.
    mask = mask.detach()
    x = x.detach()

    adjusted_weight = mask[:, 0:1, :].clone().detach().unsqueeze(0)  # weights for SC
    for i in range(self.num_stages - 1):
        adjusted_weight = torch.cat(
            (adjusted_weight, mask[:, 0:1, :].clone().detach().unsqueeze(0)))

    confidence = []
    feature = []
    if gt_target is not None:
        gt_target = gt_target.unsqueeze(0)

    # Stage 1.
    out1, feature1 = self.stage1(x, mask)
    outputs = out1.unsqueeze(0)
    feature.append(feature1)
    confidence.append(F.softmax(out1, dim=1) * mask[:, 0:1, :])
    if gt_target is None:
        max_conf, _ = torch.max(confidence[0], dim=1)
        max_conf = max_conf.unsqueeze(1).clone().detach()
        decrease_flag = (max_conf > soft_threshold).float()
        increase_flag = mask[:, 0:1, :].clone().detach() - decrease_flag
        # Weight for stage 2.
        adjusted_weight[1] = max_conf.neg().exp() * decrease_flag + max_conf.exp() * increase_flag
    else:
        gt_conf = torch.gather(confidence[0], dim=1, index=gt_target)
        decrease_flag = (gt_conf > soft_threshold).float()
        increase_flag = mask[:, 0:1, :].clone().detach() - decrease_flag
        adjusted_weight[1] = gt_conf.neg().exp() * decrease_flag + gt_conf.exp() * increase_flag

    # Stages 2, ..., n.
    curr_stage = 0
    for s in self.stages:
        curr_stage = curr_stage + 1
        temp = feature[0]
        for i in range(1, len(feature)):
            temp = torch.cat((temp, feature[i]), dim=1) * mask[:, 0:1, :]
        temp = torch.cat((temp, x), dim=1)
        curr_out, curr_feature = s(temp, mask)
        outputs = torch.cat((outputs, curr_out.unsqueeze(0)), dim=0)
        feature.append(curr_feature)
        confidence.append(F.softmax(curr_out, dim=1) * mask[:, 0:1, :])
        if curr_stage == self.num_stages - 1:  # curr_stage starts from 0
            # No need to compute the next stage's confidence when the current
            # stage is the last cascade stage.
            break
        if gt_target is None:
            max_conf, _ = torch.max(confidence[curr_stage], dim=1)
            max_conf = max_conf.unsqueeze(1).clone().detach()
            decrease_flag = (max_conf > soft_threshold).float()
            increase_flag = mask[:, 0:1, :].clone().detach() - decrease_flag
            # Weight for the next stage.
            adjusted_weight[curr_stage + 1] = (max_conf.neg().exp() * decrease_flag +
                                               max_conf.exp() * increase_flag)
        else:
            gt_conf = torch.gather(confidence[curr_stage], dim=1, index=gt_target)
            decrease_flag = (gt_conf > soft_threshold).float()
            increase_flag = mask[:, 0:1, :].clone().detach() - decrease_flag
            adjusted_weight[curr_stage + 1] = (gt_conf.neg().exp() * decrease_flag +
                                               gt_conf.exp() * increase_flag)

    output_weight = adjusted_weight.detach()
    adjusted_weight = adjusted_weight / torch.sum(adjusted_weight, 0)  # normalize among stages
    temp = F.softmax(out1, dim=1) * adjusted_weight[0]
    for i in range(1, self.num_stages):
        temp += F.softmax(outputs[i], dim=1) * adjusted_weight[i]
    confidenceF = temp * mask[:, 0:1, :]  # input of the fusion stage

    # Inner LBP for confidenceF.
    barrier, BGM_output = self.fullBarrier(x)
    if self.use_lbp:
        confidenceF = self.lbp_in(confidenceF, barrier)

    # Fusion stage: gives a more consistent output, since the combination of
    # cascade stages may fluctuate.
    out, _ = self.stageF(confidenceF, mask)  # use the mixture of cascade stages

    # Final LBP for the output.
    if self.use_lbp:
        for i in range(self.num_soft_lbp):
            out = self.lbp_out(out, barrier)

    # torch.clamp for training stability.
    confidence_last = torch.clamp(F.softmax(out, dim=1),
                                  min=1e-4, max=1 - 1e-4) * mask[:, 0:1, :]
    outputs = torch.cat((outputs, confidence_last.unsqueeze(0)), dim=0)
    return outputs, BGM_output, output_weight
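# A hedged sketch of the confidence reweighting above: gather each frame's
# probability at the ground-truth class, then shrink the weight where the
# stage is already confident and grow it where it is not. All shapes and the
# threshold are illustrative.
import torch
import torch.nn.functional as F

B, n_classes, T = 2, 5, 10
logits = torch.randn(B, n_classes, T)
gt = torch.randint(0, n_classes, (B, 1, T))

conf = F.softmax(logits, dim=1)
gt_conf = torch.gather(conf, dim=1, index=gt)           # [B, 1, T]
decrease = (gt_conf > 0.8).float()
weight = gt_conf.neg().exp() * decrease + gt_conf.exp() * (1 - decrease)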
def forward(self, heatmap_heads, offset_heads, wh_heads):
    with torch.no_grad():
        device = heatmap_heads.device
        heatmap_heads = torch.sigmoid(heatmap_heads)

        batch_scores, batch_classes, batch_pred_bboxes = [], [], []
        for per_image_heatmap_heads, per_image_offset_heads, per_image_wh_heads in zip(
                heatmap_heads, offset_heads, wh_heads):
            # Filter and keep points whose value is larger than the surrounding 8 points.
            per_image_heatmap_heads = self.nms(per_image_heatmap_heads)
            topk_score, topk_indexes, topk_classes, topk_ys, topk_xs = self.get_topk(
                per_image_heatmap_heads, K=self.topk)

            per_image_offset_heads = per_image_offset_heads.permute(
                1, 2, 0).contiguous().view(-1, 2)
            per_image_offset_heads = torch.gather(per_image_offset_heads, 0,
                                                  topk_indexes.repeat(1, 2))
            topk_xs = topk_xs + per_image_offset_heads[:, 0:1]
            topk_ys = topk_ys + per_image_offset_heads[:, 1:2]

            per_image_wh_heads = per_image_wh_heads.permute(1, 2, 0).contiguous().view(-1, 2)
            per_image_wh_heads = torch.gather(per_image_wh_heads, 0,
                                              topk_indexes.repeat(1, 2))

            topk_bboxes = torch.cat([
                topk_xs - per_image_wh_heads[:, 0:1] / 2,
                topk_ys - per_image_wh_heads[:, 1:2] / 2,
                topk_xs + per_image_wh_heads[:, 0:1] / 2,
                topk_ys + per_image_wh_heads[:, 1:2] / 2,
            ], dim=1)
            topk_bboxes = topk_bboxes * self.stride
            topk_bboxes[:, 0] = torch.clamp(topk_bboxes[:, 0], min=0)
            topk_bboxes[:, 1] = torch.clamp(topk_bboxes[:, 1], min=0)
            topk_bboxes[:, 2] = torch.clamp(topk_bboxes[:, 2], max=self.image_w - 1)
            topk_bboxes[:, 3] = torch.clamp(topk_bboxes[:, 3], max=self.image_h - 1)

            one_image_scores = (-1) * torch.ones((self.max_detection_num,), device=device)
            one_image_classes = (-1) * torch.ones((self.max_detection_num,), device=device)
            one_image_pred_bboxes = (-1) * torch.ones((self.max_detection_num, 4), device=device)

            topk_classes = topk_classes[topk_score > self.min_score_threshold].float()
            topk_bboxes = topk_bboxes[topk_score > self.min_score_threshold].float()
            topk_score = topk_score[topk_score > self.min_score_threshold].float()

            final_detection_num = min(self.max_detection_num, topk_score.shape[0])
            one_image_scores[0:final_detection_num] = topk_score[0:final_detection_num]
            one_image_classes[0:final_detection_num] = topk_classes[0:final_detection_num]
            one_image_pred_bboxes[0:final_detection_num, :] = topk_bboxes[0:final_detection_num, :]

            batch_scores.append(one_image_scores.unsqueeze(0))
            batch_classes.append(one_image_classes.unsqueeze(0))
            batch_pred_bboxes.append(one_image_pred_bboxes.unsqueeze(0))

        batch_scores = torch.cat(batch_scores, dim=0)
        batch_classes = torch.cat(batch_classes, dim=0)
        batch_pred_bboxes = torch.cat(batch_pred_bboxes, dim=0)

        # batch_scores shape: [batch_size, topk]
        # batch_classes shape: [batch_size, topk]
        # batch_pred_bboxes shape: [batch_size, topk, 4]
        return batch_scores, batch_classes, batch_pred_bboxes
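# A minimal sketch of the CenterNet-style decoding above: flatten the offset
# and width-height heads to [H*W, 2] and gather the rows at the top-k heatmap
# indices. The index must be shaped [K, 1] so repeat(1, 2) aligns it with the
# two channels. Shapes are illustrative.
import torch

H, W, K = 8, 8, 5
heatmap = torch.rand(H * W)
offset = torch.randn(2, H, W).permute(1, 2, 0).reshape(-1, 2)
wh = torch.randn(2, H, W).permute(1, 2, 0).reshape(-1, 2)

topk_scores, topk_indexes = torch.topk(heatmap, K)
topk_indexes = topk_indexes.unsqueeze(-1)               # [K, 1]
topk_offset = torch.gather(offset, 0, topk_indexes.repeat(1, 2))
topk_wh = torch.gather(wh, 0, topk_indexes.repeat(1, 2))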