def iou_match(yx_min, yx_max, data):
    """Match every predicted anchor box to its highest-IoU ground-truth box.

    Args:
        yx_min, yx_max: predicted box corners; ``.size()`` must unpack to
            (batch_size, cells, num_anchors, _) and they are viewed as
            (batch_size, -1, 2) for the IoU computation.
        data: dict holding at least 'yx_min', 'yx_max' and 'cls' ground-truth
            tensors whose dim 0 is the batch.

    Returns:
        (iou_matrix, iou, index, matched):
            iou_matrix -- IoU reshaped to (batch_size, cells, num_anchors, -1)
            iou, index -- max and argmax of iou_matrix over its last dim
            matched    -- dict mapping 'yx_min', 'yx_max', 'cls' to the
                          ground-truth entries selected by ``index``; 2-D
                          tensors become (batch_size, cells, num_anchors),
                          3-D ones (batch_size, cells, num_anchors, -1), any
                          other rank is passed through unchanged.
    """
    batch_size, cells, num_anchors, _ = yx_min.size()
    pred_min = yx_min.view(batch_size, -1, 2)
    pred_max = yx_max.view(batch_size, -1, 2)
    iou_matrix = utils.iou.torch.batch_iou_matrix(
        pred_min, pred_max, data['yx_min'], data['yx_max'])
    iou_matrix = iou_matrix.view(batch_size, cells, num_anchors, -1)
    iou, index = iou_matrix.max(-1)
    per_sample_index = torch.unbind(index.view(batch_size, -1))

    def gather(tensor):
        # Pick, for each sample, the ground-truth rows chosen by the argmax.
        picked = [sample[idx] for sample, idx
                  in zip(torch.unbind(tensor, 0), per_sample_index)]
        return torch.stack(picked)

    matched = {}
    for key in ('yx_min', 'yx_max', 'cls'):
        tensor = data[key]
        rank = len(tensor.size())
        if rank == 2:
            tensor = gather(tensor).view(batch_size, cells, num_anchors)
        elif rank == 3:
            tensor = gather(tensor).view(batch_size, cells, num_anchors, -1)
        matched[key] = tensor
    return iou_matrix, iou, index, matched
def func(module, x):
    """The actual wrapped operation.

    Applies the layer obtained from ``Layer.resolve(layer)`` to every slice of
    ``x`` along dim 0 and stacks the results back along dim 0. ``layer`` is
    not a parameter; it is presumably captured from an enclosing scope
    (TODO confirm against the surrounding decorator/factory).
    """
    outputs = []
    for sample in torch.unbind(x, 0):
        # Resolved per slice, exactly as before.
        outputs.append(Layer.resolve(layer)(module, sample))
    return torch.stack(tuple(outputs), 0)
def output_batch(self, ner_model, documents, fout):
    """Decode a whole corpus with ``ner_model`` and write it to ``fout``.

    For each document a ``-DOCSTART-`` header is written, then each of its
    entries is decoded in chunks of ``self.batch_size`` via
    ``self.apply_model`` and written with ``self.decode_str``, each followed by
    a blank line.

    args:
        ner_model: sequence labeling model (put into eval mode here)
        documents (list): list of documents; each document is a list of
            per-sentence features accepted by ``self.apply_model``
        fout: output file (text stream)
    """
    ner_model.eval()
    progress = tqdm(range(0, len(documents)), mininterval=1,
                    desc=' - Process', leave=False, file=sys.stdout)
    for doc_idx in progress:
        fout.write('-DOCSTART- -DOCSTART- -DOCSTART-\n\n')
        features = documents[doc_idx]
        n_features = len(features)
        for start in range(0, n_features, self.batch_size):
            end = min(n_features, start + self.batch_size)
            # The model output is split along dim 1: one entry per sentence.
            batch_labels = torch.unbind(
                self.apply_model(ner_model, features[start:end]), 1)
            for offset, feature in enumerate(features[start:end]):
                # Trim the predicted labels to the sentence length.
                labels = batch_labels[offset][0:len(feature)]
                fout.write(self.decode_str(feature, labels) + '\n\n')
def fit_positive(rows, cols, yx_min, yx_max, anchors):
    """Build the positive-assignment mask for ground-truth boxes.

    Each valid box (``yx_min < yx_max`` on both axes) is assigned to the grid
    cell containing its center (``floor(center)`` -> (i, j), flattened as
    ``i * cols + j``) and to the anchor whose box, centered on the same
    point, has the highest IoU with it.

    Args:
        rows, cols: grid size; the mask has ``rows * cols`` cells.
        yx_min, yx_max: (batch_size, num, 2) box corners; assumed to be in
            grid-cell units since they are floored into cell indices
            (TODO confirm against caller).
        anchors: (num_anchors, 2) anchor sizes.

    Returns:
        ByteTensor of shape (batch_size, rows * cols, num_anchors) with 1 at
        every (cell, anchor) pair assigned to a valid box.
    """
    device_id = anchors.get_device() if torch.cuda.is_available() else None
    batch_size, _, _ = yx_min.size()
    num_anchors, _ = anchors.size()
    # Boolean mask. torch.prod over a bool comparison returns int64 in modern
    # PyTorch, and indexing with an int64 0/1 tensor selects positions 0/1
    # instead of masking -- .all() keeps it a true boolean mask.
    valid = (yx_min < yx_max).all(-1)
    center = (yx_min + yx_max) / 2
    ij = torch.floor(center)
    i, j = torch.unbind(ij.long(), -1)
    index = i * cols + j
    anchors2 = anchors / 2
    # IoU of each box against each anchor, both centered at the origin.
    iou_matrix = utils.iou.torch.iou_matrix(
        (yx_min - center).view(-1, 2), (yx_max - center).view(-1, 2),
        -anchors2, anchors2).view(batch_size, -1, num_anchors)
    _, index_anchor = iou_matrix.max(-1)
    cells = rows * cols
    positives = []
    for valid_b, index_b, anchor_b in zip(torch.unbind(valid),
                                          torch.unbind(index),
                                          torch.unbind(index_anchor)):
        cell_idx = index_b[valid_b]
        anchor_idx = anchor_b[valid_b]
        mask = utils.ensure_device(
            torch.ByteTensor(cells, num_anchors).zero_(), device_id)
        mask[cell_idx, anchor_idx] = 1
        positives.append(mask)
    return torch.stack(positives)
def calc_acc_batch(self, decoded_data, target_data):
    """Accumulate token-level accuracy statistics for one batch.

    args:
        decoded_data: predicted label ids; split along dim 1, one slice per
            sentence (so presumably (seq_len, batch_size) -- TODO confirm)
        target_data: ground-truth labels; split along dim 0, one row per
            sentence
    """
    for decoded, target in zip(torch.unbind(decoded_data, 1),
                               torch.unbind(target_data, 0)):
        gold = self.packer.convert_for_eval(target)
        # Only the first `length` positions are real tokens (rest is padding).
        length = utils.find_length_from_labels(gold, self.l_map)
        gold_ids = gold[:length].numpy()
        predicted_ids = decoded[:length].numpy()
        self.total_labels += length
        self.correct_labels += np.sum(np.equal(predicted_ids, gold_ids))
def calc_f1_batch(self, decoded_data, target_data):
    """Accumulate the counts needed for accuracy and F1 over one batch.

    args:
        decoded_data: predicted label ids; split along dim 1, one slice per
            sentence (so presumably (seq_len, batch_size) -- TODO confirm)
        target_data: ground-truth labels; split along dim 0, one row per
            sentence
    """
    for decoded, target in zip(torch.unbind(decoded_data, 1),
                               torch.unbind(target_data, 0)):
        gold = self.packer.convert_for_eval(target)
        # Only the first `length` positions are real tokens (rest is padding).
        length = utils.find_length_from_labels(gold, self.l_map)
        gold = gold[:length]
        best_path = decoded[:length]
        counts = self.eval_instance(best_path.numpy(), gold.numpy())
        correct_i, total_i, gold_i, guess_i, overlap_i = counts
        self.correct_labels += correct_i
        self.total_labels += total_i
        self.gold_count += gold_i
        self.guess_count += guess_i
        self.overlap_count += overlap_i
def plot_samples():
    """Redraw the figure: generative plot, z samples and q(z) parameters.

    Relies on the module-level ``model``, ``plt`` and ``plot_generative``.
    ``model.z`` is used for the samples when present, otherwise
    ``model.sample()``.
    """
    try:
        samples = model.z.data.numpy().T
    except AttributeError:
        samples = model.sample()
    plt.clf()
    plot_generative()
    plt.scatter(samples[0], samples[1], alpha=0.6)
    plt.xlabel('mu')
    plt.ylabel('logvar')
    plt.title('qz')
    # Overlay the variational parameters: column 0 vs column 1 of qz.mu.
    mu_col, logvar_col = torch.unbind(model.qz.mu, 1)
    plt.scatter(mu_col.data.numpy(), logvar_col.data.numpy(),
                color='black', alpha=0.3)
    plt.pause(0.01)
def forward(self, x):
    """Encode each padded id sequence into a single vector with the LSTM.

    Args:
        x: integer tensor of shape (B, T). Ids > 0 count as real tokens
            (``x.gt(0)``); padding is assumed to be trailing zeros, and every
            row must contain at least one non-zero id.

    Returns:
        Tensor of shape (B, H_total): the LSTM's final hidden states
        concatenated over the first dim of ``hidden`` (layers x directions),
        returned in the original row order.
    """
    # Mask of real (non-padding) positions.
    mask = x.gt(0)
    # Sort rows by length, longest first (required by pack_padded_sequence).
    lens, indices = torch.sort(mask.sum(dim=1), descending=True)
    # Permutation that undoes the sort.
    _, inverse_indices = indices.sort()
    # Longest length in the batch; trim padding beyond it.
    max_len = int(lens[0])
    x = x[indices, :max_len]
    x = self.embed(x)
    # Lengths must live on the CPU (PyTorch >= 1.7); .cpu() is a no-op there.
    x = pack_padded_sequence(x, lens.cpu(), True)
    _, (hidden, _) = self.lstm(x)
    # Concatenate final hidden states of all layers/directions per row.
    reprs = torch.cat(torch.unbind(hidden), dim=1)
    # Restore the original row order.
    return reprs[inverse_indices]
def beam_search(self, init_state, init_logprobs, *args, **kwargs):
    """Diverse beam search with ``group_size`` groups of beams.

    Groups are decoded in a staggered way: group ``divm`` advances at outer
    steps ``divm .. seq_length + divm - 1``. Before each step a group's
    log-probs are penalised by ``diversity_lambda`` for every word chosen by
    earlier groups at the same local time step.

    Args:
        init_state: sequence of state tensors; stacked, then chunked into
            ``group_size`` pieces along dim 2 (presumably the beam dim of
            (layers, beam, hidden) states -- TODO confirm).
        init_logprobs: initial log-probs, chunked along dim 0 per group.
        *args: extra inputs for ``self.get_logprobs_state``; each non-None
            one is chunked along dim 0 per group.
        **kwargs: must contain ``opt`` (dict) read for ``beam_size`` (10),
            ``group_size`` (1), ``diversity_lambda`` (0.5),
            ``decoding_constraint`` (0) and ``length_penalty`` ('').

    Returns:
        list of dicts with keys 'seq', 'logps', 'unaug_p' and 'p': the best
        ``beam_size // group_size`` finished beams of each group (sorted by
        'p' within a group), concatenated group after group.
    """

    def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
        """Penalise in place the words chosen by earlier groups at this step.

        Returns a copy of ``logprobsf`` taken before the penalty is applied.
        """
        local_time = t - divm
        unaug_logprobsf = logprobsf.clone()
        for prev_choice in range(divm):
            prev_decisions = beam_seq_table[prev_choice][local_time]
            for sub_beam in range(bdash):
                for prev_labels in range(bdash):
                    logprobsf[sub_beam][prev_decisions[prev_labels]] = \
                        logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
        return unaug_logprobsf

    def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq,
                  beam_seq_logprobs, beam_logprobs_sum, state):
        """One step of classical beam search inside a group.

        logprobsf: log-probs after the diversity penalty (used for ranking)
        unaug_logprobsf: log-probs before the penalty (stored per word)
        beam_size: number of beams in this group
        t: local time step of the group
        beam_seq / beam_seq_logprobs: (seq_length, beam_size) word ids and
            per-word log-probs decoded so far
        beam_logprobs_sum: (beam_size,) running joint log-prob per beam
        state: list of recurrent state tensors whose dim 1 indexes beams

        Returns the updated beam_seq, beam_seq_logprobs, beam_logprobs_sum,
        the re-ordered state, and the sorted candidate list.
        """
        ys, ix = torch.sort(logprobsf, 1, True)
        candidates = []
        cols = min(beam_size, ys.size(1))
        rows = beam_size
        if t == 0:
            # All beams are identical at the first step: expand only one.
            rows = 1
        for c in range(cols):  # each (sorted) word position
            for q in range(rows):  # each beam to expand
                local_logprob = ys[q, c].item()
                candidate_logprob = beam_logprobs_sum[q] + local_logprob
                local_unaug_logprob = unaug_logprobsf[q, ix[q, c]]
                candidates.append({
                    'c': ix[q, c],
                    'q': q,
                    'p': candidate_logprob,
                    'r': local_unaug_logprob,
                })
        candidates = sorted(candidates, key=lambda x: -x['p'])

        new_state = [_.clone() for _ in state]
        if t >= 1:
            # Snapshots of the history, used as reference while forking beams.
            beam_seq_prev = beam_seq[:t].clone()
            beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
        for vix in range(beam_size):
            v = candidates[vix]
            # Fork beam q into slot vix.
            if t >= 1:
                beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
            for state_ix in range(len(new_state)):
                new_state[state_ix][:, vix] = state[state_ix][:, v['q']]
            beam_seq[t, vix] = v['c']  # chosen word
            beam_seq_logprobs[t, vix] = v['r']  # un-penalised log-prob
            beam_logprobs_sum[vix] = v['p']  # new joint log-prob
        state = new_state
        return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates

    opt = kwargs['opt']
    beam_size = opt.get('beam_size', 10)
    group_size = opt.get('group_size', 1)
    diversity_lambda = opt.get('diversity_lambda', 0.5)
    decoding_constraint = opt.get('decoding_constraint', 0)
    length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
    bdash = beam_size // group_size  # beams per group

    # Per-group buffers.
    beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_()
                      for _ in range(group_size)]
    beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_()
                               for _ in range(group_size)]
    beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
    done_beams_table = [[] for _ in range(group_size)]
    state_table = [list(torch.unbind(_))
                   for _ in torch.stack(init_state).chunk(group_size, 2)]
    logprobs_table = list(init_logprobs.chunk(group_size, 0))

    # Split every extra argument per group.
    args = list(args)
    args = [_.chunk(group_size) if _ is not None else [None] * group_size
            for _ in args]
    args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

    for t in range(self.seq_length + group_size - 1):
        for divm in range(group_size):
            if t >= divm and t <= self.seq_length + divm - 1:
                logprobsf = logprobs_table[divm].data.float()
                # Optionally forbid repeating the previous word.
                if decoding_constraint and t - divm > 0:
                    prev_words = beam_seq_table[divm][t - divm - 1].unsqueeze(1)
                    # Index must live on the same device as logprobsf
                    # (was hard-coded .cuda(), which crashed on CPU).
                    logprobsf.scatter_(1, prev_words.to(logprobsf.device),
                                       float('-inf'))
                # Suppress the last vocabulary entry (UNK per the original code).
                logprobsf[:, logprobsf.size(1) - 1] = \
                    logprobsf[:, logprobsf.size(1) - 1] - 1000
                # add_diversity modifies logprobsf in place and returns the
                # un-penalised copy used when recording candidates.
                unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t,
                                                divm, diversity_lambda, bdash)

                (beam_seq_table[divm],
                 beam_seq_logprobs_table[divm],
                 beam_logprobs_sum_table[divm],
                 state_table[divm],
                 candidates_divm) = beam_step(logprobsf, unaug_logprobsf, bdash,
                                              t - divm,
                                              beam_seq_table[divm],
                                              beam_seq_logprobs_table[divm],
                                              beam_logprobs_sum_table[divm],
                                              state_table[divm])

                # Finish beams that emitted token 0 or ran out of time.
                for vix in range(bdash):
                    if (beam_seq_table[divm][t - divm, vix] == 0
                            or t == self.seq_length + divm - 1):
                        final_beam = {
                            'seq': beam_seq_table[divm][:, vix].clone(),
                            'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                            'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                            'p': beam_logprobs_sum_table[divm][vix].item(),
                        }
                        final_beam['p'] = length_penalty(t - divm + 1,
                                                         final_beam['p'])
                        done_beams_table[divm].append(final_beam)
                        # Don't continue beams from finished sequences.
                        beam_logprobs_sum_table[divm][vix] = -1000

                # Advance this group one step, on the model's device.
                it = beam_seq_table[divm][t - divm]
                it = it.to(logprobs_table[divm].device)
                logprobs_table[divm], state_table[divm] = self.get_logprobs_state(
                    it, *(args[divm] + [state_table[divm]]))

    # Keep the best bdash finished beams of each group, best first.
    done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash]
                        for i in range(group_size)]
    return [beam for group in done_beams_table for beam in group]
def run_SAM(self, df_data, skeleton=None, **kwargs):
    """Execute the SAM model.

    Trains one generator per variable (``SAM_generators``) adversarially
    against a shared discriminator (``SAM_discriminator``) with an L1
    penalty on the generators' input filters, and accumulates those filters
    during the last ``self.test`` epochs.

    :param df_data: pandas DataFrame, one column per variable.
    :param skeleton: optional array; every (i, j) where ``1 - skeleton`` is
        non-zero adds ``i`` to ``zero_components[j]`` for ``SAM_generators``.
        When None, variable ``i`` gets ``[i]``.
    :param kwargs: reads ``gpu`` (False), ``gpu_no`` (0),
        ``categorical_variables`` (None -> all numerical), ``verbose`` (True),
        ``plot`` (False), ``plot_generated_pair`` (False) and
        ``activation_function``. All kwargs are forwarded to
        ``SAM_generators``; all but ``activation_function`` to
        ``SAM_discriminator``.
    :return: numpy array of shape (cols, cols): the absolute filter weights
        summed over every batch of the test epochs, divided by ``self.test``.
    """
    gpu = kwargs.get('gpu', False)
    gpu_no = kwargs.get('gpu_no', 0)
    categorical_variables = kwargs.get('categorical_variables', None)
    verbose = kwargs.get('verbose', True)
    plot = kwargs.get("plot", False)
    plot_generated_pair = kwargs.get("plot_generated_pair", False)

    d_str = "Epoch: {} -- Disc: {} -- Gen: {} -- L1: {}"

    if categorical_variables is None:
        warnings.warn("Dataset considered as numerical")
        categorical_variables = [False] * len(df_data.columns)

    # One block of columns per variable: one-hot for categoricals, the raw
    # column otherwise. `.values` replaces `.as_matrix()` (removed in pandas 1.0).
    onehotdata = []
    for i, var_is_categorical in enumerate(categorical_variables):
        if var_is_categorical:
            onehotdata.append(pd.get_dummies(df_data.iloc[:, i]).values)
        else:
            onehotdata.append(df_data.iloc[:, [i]].values)
    cat_size = [block.shape[1] for block in onehotdata]
    # Column offset of each variable's block in the concatenated data.
    offsets = [sum(cat_size[:idx]) for idx in range(len(cat_size))]

    df_data = np.concatenate(onehotdata, 1)
    data = df_data.astype('float32')
    self.data = df_data
    data = th.from_numpy(data)
    if self.batchsize == -1:
        self.batchsize = data.shape[0]
    rows, cols = data.size()
    # With categorical data the number of variables differs from the number
    # of columns: cols counts variables from here on.
    cols = len(cat_size)

    # Inputs each generator must ignore.
    if skeleton is not None:
        zero_components = [[] for _ in range(cols)]
        for i, j in zip(*((1 - skeleton).nonzero())):
            zero_components[j].append(i)
    else:
        zero_components = [[i] for i in range(cols)]

    self.sam = SAM_generators((rows, cols), cat_size, zero_components,
                              batch_norm=True, nh=self.nh,
                              batch_size=self.batchsize, **kwargs)
    # The discriminator sets its own activation; drop the caller's one.
    disc_kwargs = {k: v for k, v in kwargs.items()
                   if k != "activation_function"}
    self.discriminator_sam = SAM_discriminator(
        [sum(cat_size), self.dnh, self.dnh, 1], batch_norm=True,
        activation_function=th.nn.LeakyReLU,
        activation_argument=0.2, **disc_kwargs)

    if gpu:
        self.sam = self.sam.cuda(gpu_no)
        self.discriminator_sam = self.discriminator_sam.cuda(gpu_no)
        data = data.cuda(gpu_no)

    criterion = th.nn.BCEWithLogitsLoss()
    g_optimizer = th.optim.Adam(self.sam.parameters(), lr=self.lr)
    d_optimizer = th.optim.Adam(self.discriminator_sam.parameters(), lr=self.dlr)

    true_variable = th.ones(self.batchsize, 1)
    false_variable = th.zeros(self.batchsize, 1)
    causal_filters = th.zeros(cols, cols)

    if gpu:
        true_variable = true_variable.cuda(gpu_no)
        false_variable = false_variable.cuda(gpu_no)
        causal_filters = causal_filters.cuda(gpu_no)

    data_iterator = DataLoader(data, batch_size=self.batchsize,
                               shuffle=True, drop_last=True)

    # Loss-curve plotting state (only used when plot=True).
    ax = None
    adv_plt, gen_plt, l1_plt = [], [], []

    for epoch in range(self.train + self.test):
        for i_batch, batch in enumerate(data_iterator):
            # Per-variable column blocks of the real batch.
            batch_vectors = [batch[:, start:start + size]
                             for start, size in zip(offsets, cat_size)]

            g_optimizer.zero_grad()
            d_optimizer.zero_grad()

            # 1. Train the discriminator on real vs. partially generated data.
            generated_variables = self.sam(batch)
            generator_outputs = []
            disc_losses = []
            for i in range(cols):
                # Real data with variable i replaced by its generated version.
                generator_output = th.cat(
                    batch_vectors[:i] + [generated_variables[i]]
                    + batch_vectors[i + 1:], 1)
                generator_outputs.append(generator_output)
                disc_output_detached = self.discriminator_sam(
                    generator_output.detach())
                disc_losses.append(criterion(disc_output_detached, false_variable))

            true_output = self.discriminator_sam(batch)
            adv_loss = sum(disc_losses) / cols + \
                criterion(true_output, true_variable)
            adv_loss.backward()
            d_optimizer.step()

            # 2. Train the generators. The discriminator is re-run after its
            # step: backpropagating through activations computed before the
            # in-place optimizer update raises an autograd version error.
            gen_losses = [criterion(self.discriminator_sam(out), true_variable)
                          for out in generator_outputs]
            gen_loss = sum(gen_losses)

            # 3. Filter regularisation.
            filters = th.stack(
                [block.filter._filter[0, :-1].abs() for block in self.sam.blocks], 1)
            l1_reg = self.l1 * filters.sum()
            loss = gen_loss + l1_reg

            if verbose and not epoch % 20:
                print(str(i_batch) + " " + d_str.format(
                    epoch, adv_loss.cpu().item(),
                    gen_loss.cpu().item() / cols, l1_reg.cpu().item()))
            loss.backward()

            # Accumulate filter magnitudes during the test epochs.
            if epoch >= self.train:
                causal_filters.add_(filters.data)
            g_optimizer.step()

            if plot:
                adv_plt.append(adv_loss.cpu().item())
                gen_plt.append(gen_loss.cpu().item() / cols)
                l1_plt.append(l1_reg.cpu().item())
                if i_batch == 0:
                    if ax is None:
                        plt.ion()
                        ax = plt.figure().add_subplot(1, 1, 1)
                    ax.clear()
                    ax.set_xlabel("Epoch")
                    ax.set_ylabel("Losses")
                    steps = range(len(adv_plt))
                    ax.plot(steps, adv_plt, "r-", linewidth=1.5,
                            markersize=4, label="Discriminator")
                    ax.plot(steps, gen_plt, "g-", linewidth=1.5,
                            markersize=4, label="Generators")
                    ax.plot(steps, l1_plt, "b-", linewidth=1.5,
                            markersize=4, label="L1-Regularization")
                    ax.legend()
                    plt.pause(0.0001)

            if plot_generated_pair and i_batch == 0:
                if epoch == 0:
                    plt.ion()
                to_print = [[0, 1]]
                plt.clf()
                for (i, j) in to_print:
                    plt.scatter(generated_variables[i].data.cpu().numpy(),
                                batch.data.cpu().numpy()[:, j], label="Y -> X")
                    plt.scatter(batch.data.cpu().numpy()[:, i],
                                generated_variables[j].data.cpu().numpy(),
                                label="X -> Y")
                    plt.scatter(batch.data.cpu().numpy()[:, i],
                                batch.data.cpu().numpy()[:, j],
                                label="original data")
                plt.legend()
                plt.pause(0.01)

    return causal_filters.div_(self.test).cpu().numpy()
def apply(func, apply_dimension, dim=0):
    '''Apply ``func`` to every slice of a tensor along ``dim`` and restack.

    Args:
        func: callable taking one slice (the tensor with ``dim`` removed) and
            returning a tensor; all returned tensors must share one shape.
        apply_dimension: the input tensor to slice (name kept for callers).
        dim: dimension to slice along and to stack the results on. Defaults
            to 0, the previous hard-coded behaviour.

    Returns:
        ``torch.stack`` of the per-slice results along ``dim``.
    '''
    output_list = [func(m) for m in torch.unbind(apply_dimension, dim=dim)]
    return torch.stack(output_list, dim=dim)
def _take_step(self, indices, context1, context1_):
    """Run one SAC-style update of policy, Q-functions, V-function and encoder.

    In addition to the usual losses, the policy is regularised by the KL
    divergence between its Gaussian under ``context1`` and under
    ``context1_`` (the positive-sample context).

    Args:
        indices: task indices; ``len(indices)`` is used as the task count.
        context1: context for the main update.
        context1_: positive-sample context for the policy KL term.
    """
    num_tasks = len(indices)

    # data is (task, batch, feat)
    obs, actions, rewards, next_obs, terms = self.sample_data(indices)
    t, b, _ = obs.size()

    # run inference in networks
    policy_outputs, task_z = self.agent(obs, context1)
    new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]
    policy_mean = policy_mean.view(t, b, -1)
    policy_log_std = policy_log_std.view(t, b, -1)

    # positive samples: same observations under the second context
    policy_outputs_1, _ = self.agent(obs, context1_)
    _, policy_mean_1, policy_log_std_1, _ = policy_outputs_1[:4]
    policy_mean_1 = policy_mean_1.view(t, b, -1)  # (task, batch, feat)
    policy_log_std_1 = policy_log_std_1.view(t, b, -1)  # (task, batch, feat)

    # Policy KL divergence, element-wise over (task, batch, feat) then
    # averaged. Every (task, batch) slice has the same size, so this equals
    # the mean over tasks of per-task means that the previous per-element
    # loop computed (that loop also crashed on torch.stack of nested lists).
    policy1 = torch.distributions.Normal(policy_mean, torch.exp(policy_log_std))
    policy2 = torch.distributions.Normal(policy_mean_1, torch.exp(policy_log_std_1))
    kl_policy_final = torch.distributions.kl.kl_divergence(policy1, policy2).mean()
    kl_policy_loss = 1 * kl_policy_final

    self.encoder_optimizer.zero_grad()
    self.policy_optimizer.zero_grad()
    # Graph is reused by the policy loss below, hence retain_graph.
    kl_policy_loss.backward(retain_graph=True)

    # flattens out the task dimension
    obs = obs.view(t * b, -1)
    actions = actions.view(t * b, -1)
    next_obs = next_obs.view(t * b, -1)

    # Q and V networks
    # encoder will only get gradients from Q nets
    q1_pred = self.qf1(obs, actions, task_z)
    q2_pred = self.qf2(obs, actions, task_z)
    v_pred = self.vf(obs, task_z.detach())
    # get targets for use in V and Q updates
    with torch.no_grad():
        target_v_values = self.target_vf(next_obs, task_z)

    # KL constraint on z if probabilistic
    if self.use_information_bottleneck:
        kl_div = self.agent.compute_kl_div()
        kl_loss = self.kl_lambda * kl_div
        kl_loss.backward(retain_graph=True)

    # qf and encoder update (note encoder does not get grads from policy or vf)
    self.qf1_optimizer.zero_grad()
    self.qf2_optimizer.zero_grad()
    rewards_flat = rewards.view(self.batch_size * num_tasks, -1)
    # scale rewards for Bellman update
    rewards_flat = rewards_flat * self.reward_scale
    terms_flat = terms.view(self.batch_size * num_tasks, -1)
    q_target = rewards_flat + (1. - terms_flat) * self.discount * target_v_values
    qf_loss = torch.mean((q1_pred - q_target) ** 2) + torch.mean((q2_pred - q_target) ** 2)
    qf_loss.backward()
    self.qf1_optimizer.step()
    self.qf2_optimizer.step()
    self.encoder_optimizer.step()

    # compute min Q on the new actions
    min_q_new_actions = self._min_q(obs, new_actions, task_z)

    # vf update
    v_target = min_q_new_actions - log_pi
    vf_loss = self.vf_criterion(v_pred, v_target.detach())
    self.vf_optimizer.zero_grad()
    vf_loss.backward()
    self.vf_optimizer.step()
    self._update_target_network()

    # policy update
    # n.b. policy update includes dQ/da
    log_policy_target = min_q_new_actions
    policy_loss = (log_pi - log_policy_target).mean()

    mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
    std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
    pre_tanh_value = policy_outputs[-1]
    pre_activation_reg_loss = self.policy_pre_activation_weight * (
        (pre_tanh_value ** 2).sum(dim=1).mean())
    policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
    policy_loss = policy_loss + policy_reg_loss

    # No zero_grad here: the policy-KL gradients from above are kept and
    # stepped together with the policy loss.
    policy_loss.backward()
    self.policy_optimizer.step()

    # save some statistics for eval
    if self.eval_statistics is None:
        # eval should set this to None.
        # this way, these statistics are only computed for one batch.
        self.eval_statistics = OrderedDict()
        if self.use_information_bottleneck:
            z_mean = np.mean(np.abs(ptu.get_numpy(self.agent.z_means[0])))
            z_sig = np.mean(ptu.get_numpy(self.agent.z_vars[0]))
            self.eval_statistics['Z mean train'] = z_mean
            self.eval_statistics['Z variance train'] = z_sig
        # NOTE(review): previously read an undefined name `loss` (NameError);
        # the policy-KL term between the two contexts is logged here instead.
        self.eval_statistics['Contrastive Loss'] = np.mean(
            ptu.get_numpy(kl_policy_loss))
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V Predictions',
            ptu.get_numpy(v_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
def gradient3DForBboxFace(self, emb3D_scenes, bbox, scores):
    """Gradient-smoothness loss on the faces of each positively scored box.

    For every box with ``score > 0``, the spatial gradient of the scene is
    penalised on one-voxel slabs at each of the six faces (plus the slab just
    inside/after each face when it stays within the grid), using
    ``self.get_gradient_loss_on_bbox_surface``.

    Args:
        emb3D_scenes: scene embeddings, expected B x C x D x H x W.
        bbox: per-batch boxes; each box unbinds into (lower, upper) corners,
            each ordered (x, y, z).
        scores: per-batch box scores; only boxes with score > 0 are used.

    Returns:
        Scalar tensor: mean over the batch of the per-scene face losses. It
        stays attached to the autograd graph.
    """
    dz_batch, dy_batch, dx_batch = utils_basic.gradient3D(
        emb3D_scenes, absolute=False, square=False)
    bbox = torch.clamp(bbox, min=0)
    # Largest valid voxel index along each axis.
    z_lim, y_lim, x_lim = hyp.Z2 - 1, hyp.Y2 - 1, hyp.X2 - 1
    gs_loss_list = []  # gradient smoothness loss per scene
    for index_batch in range(emb3D_scenes.shape[0]):
        gsloss = 0
        dz = dz_batch[index_batch:index_batch + 1]
        dy = dy_batch[index_batch:index_batch + 1]
        dx = dx_batch[index_batch:index_batch + 1]
        for index_box, box in enumerate(bbox[index_batch]):
            if not scores[index_batch][index_box] > 0:
                continue
            lower, upper = torch.unbind(box)
            xmin, ymin, zmin = [max(torch.floor(v).to(torch.int32), 0)
                                for v in lower]
            # Clamp each upper coordinate by the limit of its own axis
            # (previously x was clamped by the z limit and z by the x limit).
            xmax, ymax, zmax = [min(torch.ceil(v).to(torch.int32), lim)
                                for v, lim in zip(upper, (x_lim, y_lim, z_lim))]
            spans = {'z': (zmin, zmax), 'y': (ymin, ymax), 'x': (xmin, xmax)}
            # Faces in the order zmin, zmax, ymin, ymax, xmin, xmax.
            for axis, grad, lim in (('z', dz, z_lim), ('y', dy, y_lim),
                                    ('x', dx, x_lim)):
                for face in spans[axis]:
                    slabs = [(face, face + 1)]
                    if face < lim:
                        slabs.append((face + 1, face + 2))
                    for lo, hi in slabs:
                        region = dict(spans)
                        region[axis] = (lo, hi)
                        gsloss = gsloss + self.get_gradient_loss_on_bbox_surface(
                            grad, *region['z'], *region['y'], *region['x'])
        gs_loss_list.append(gsloss)
    # torch.tensor(list_of_tensors) would copy the values and drop the
    # autograd graph; stacking keeps the loss differentiable. Scenes with no
    # scored boxes contribute a constant 0.
    per_scene = [torch.as_tensor(g, dtype=emb3D_scenes.dtype,
                                 device=emb3D_scenes.device).reshape(())
                 for g in gs_loss_list]
    return torch.stack(per_scene).mean()
def train_point(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """Train a point-based or sparse-voxel segmentation model.

    Runs ``len(data_loader) // config.iter_size`` optimizer steps per epoch
    (gradients accumulated over ``config.iter_size`` sub-batches) until
    ``config.max_iter`` iterations, logging loss/score every
    ``config.stat_freq`` and checkpointing every ``config.save_freq``
    iterations. Finally saves the model, evaluates it on ``val_data_loader``
    and checkpoints it as "best_val" if it improves.

    Args:
        model: network; fed a dense (B, C, N) tensor when
            ``config.pure_point`` is set, otherwise an ME sparse tensor.
        data_loader: training loader yielding (points, target, sample_weight);
            its iterator is created once and never reset, so it is presumably
            infinite (TODO confirm sampler).
        val_data_loader: loader passed to ``test_points`` at the end.
        config: run configuration.
        transform_data_fn: unused here; kept for interface compatibility.
    """
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=-1)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            # Parameters whose names contain 'map' are not restored.
            d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
            model.load_state_dict(d)
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = iter(data_loader)
    while is_training:
        num_class = 20
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            for sub_iter in range(config.iter_size):
                # Time data loading + input preparation (tic was missing).
                data_timer.tic()
                # next() works on all DataLoader iterator versions; .next() does not.
                points, target, sample_weight = next(data_iter)

                if config.pure_point:
                    sinput = points.transpose(1, 2).to(device).float()
                else:
                    # Coordinates are xyz divided by the voxel size; all point
                    # channels (xyz included) are used as features.
                    voxel_size = config.voxel_size
                    coords = torch.unbind(points[:, :, :3] / voxel_size, dim=0)
                    feats = torch.unbind(points[:, :, :], dim=0)
                    coords, feats = ME.utils.sparse_collate(coords, feats)

                    # Random shift, presumably to make the network invariant
                    # to even/odd coordinates.
                    coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                    points_ = ME.TensorField(features=feats.float(),
                                             coordinates=coords, device=device)
                    sinput = points_.sparse()

                data_time += data_timer.toc(False)
                B, npoint = target.shape

                soutput = model(sinput)
                if config.pure_point:
                    soutput = soutput.reshape([B * npoint, -1])
                else:
                    # Map voxel outputs back onto the original points.
                    soutput = soutput.slice(points_).F

                # Labels are shifted by one so that 0 becomes the ignored -1.
                target = (target - 1).view(-1).long().to(device)

                # catch NAN (interactive debugging hook)
                if torch.isnan(soutput).sum() > 0:
                    import ipdb
                    ipdb.set_trace()

                loss = criterion(soutput, target)
                if torch.isnan(loss).sum() > 0:
                    import ipdb
                    ipdb.set_trace()

                # NOTE(review): criterion uses the default 'mean' reduction, so
                # `loss` is already a scalar and this only rescales it by the
                # mean sample weight -- confirm whether reduction='none' was meant.
                loss = (loss * sample_weight.to(device)).mean()

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput, target)
            score = precision_at_one(pred, target, ignore_label=-1)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # Per-class intersection and union counts for the epoch IoU.
            for cls_idx in range(num_class):
                total_correct_class[cls_idx] += ((pred == cls_idx) & (target == cls_idx)).sum()
                total_iou_deno_class[cls_idx] += (((pred == cls_idx) & (target >= 0))
                                                  | (target == cls_idx)).sum()

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter, len(data_loader) // config.iter_size,
                    losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occasional mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter, save_inter=True)

            # End of iteration
            curr_iter += 1

        IoU = total_correct_class / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % (IoU.mean() * 100.))
        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
    # val_miou was previously used without being assigned (NameError);
    # assumes test_points returns the validation mIoU -- TODO confirm.
    val_miou = test_points(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config,
                   best_val_miou, best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
def __call__(self, boxes, scores, indices, stage):
    """Perform NMS-based selection of detections

    Detections are first filtered per class (score threshold, non-empty box,
    per-class NMS), then concatenated over classes and finally capped at
    ``self.max_predictions`` by score.

    Parameters
    ----------
    boxes : sequence of torch.Tensor
        Sequence of N tensors of class-specific bounding boxes with shapes M_i x C x 4, entries can be None
    scores : sequence of torch.Tensor
        Sequence of N tensors of class probabilities with shapes M_i x (C + 1), entries can be None.
        Column 0 is skipped (presumably background); column c + 1 is paired with class id c.
    indices : sequence of torch.Tensor
        Sequence of N tensors iterated in lockstep with `boxes` and `scores`. Each is unbound along
        dim 1, one slice per class, so it presumably has shape M_i x C (TODO confirm with caller).
    stage : key into ``self.nms_threshold``
        Selects the IoU threshold passed to ``nms``.

    Returns
    -------
    bbx_pred : PackedSequence
        A sequence of N tensors of bounding boxes with shapes S_i x 4, entries are None for images in which no
        detection can be kept according to the selection parameters
    cls_pred : PackedSequence
        A sequence of N tensors of thing class predictions with shapes S_i, entries are None for images in which no
        detection can be kept according to the selection parameters
    obj_pred : PackedSequence
        A sequence of N tensors of detection confidences with shapes S_i, entries are None for images in which no
        detection can be kept according to the selection parameters
    indices_pred : PackedSequence
        A sequence of N tensors holding the `indices` entries of the kept detections, entries are None under the
        same conditions as above
    """
    bbx_pred, cls_pred, obj_pred, indices_pred = [], [], [], []
    for bbx_i, obj_i, index_i in zip(boxes, scores, indices):
        try:
            # Empty signals "no detections for this image"; handled below by appending None.
            if bbx_i is None or obj_i is None:
                raise Empty

            # Do NMS separately for each class
            bbx_pred_i, cls_pred_i, obj_pred_i, indices_pred_i = [], [], [], []
            # [1:] drops score column 0, so cls_id (0-based) pairs with score column cls_id + 1.
            for cls_id, (bbx_cls_i, obj_cls_i, index_cls_i) in enumerate(
                    zip(torch.unbind(bbx_i, dim=1), torch.unbind(obj_i, dim=1)[1:], torch.unbind(index_i, dim=1))):
                # Filter out low-scoring predictions
                idx = obj_cls_i > self.score_threshold
                if not idx.any().item():
                    continue
                bbx_cls_i = bbx_cls_i[idx]
                obj_cls_i = obj_cls_i[idx]
                index_cls_i = index_cls_i[idx]

                # Filter out empty predictions: keep boxes whose coordinate 2 exceeds
                # coordinate 0 and coordinate 3 exceeds coordinate 1 (positive extent on both axes)
                idx = (bbx_cls_i[:, 2] > bbx_cls_i[:, 0]) & (
                    bbx_cls_i[:, 3] > bbx_cls_i[:, 1])
                if not idx.any().item():
                    continue
                bbx_cls_i = bbx_cls_i[idx]
                obj_cls_i = obj_cls_i[idx]
                index_cls_i = index_cls_i[idx]

                # Do NMS; idx indexes the boxes that survive suppression
                idx = nms(bbx_cls_i.contiguous(), obj_cls_i.contiguous(), threshold=self.nms_threshold[stage], n_max=-1)
                if idx.numel() == 0:
                    continue
                bbx_cls_i = bbx_cls_i[idx]
                obj_cls_i = obj_cls_i[idx]
                index_cls_i = index_cls_i[idx]

                # Save remaining outputs
                bbx_pred_i.append(bbx_cls_i)
                cls_pred_i.append(
                    bbx_cls_i.new_full((bbx_cls_i.size(0), ), cls_id, dtype=torch.long))
                obj_pred_i.append(obj_cls_i)
                indices_pred_i.append(index_cls_i)

            # Compact predictions from the classes
            if len(bbx_pred_i) == 0:
                raise Empty
            bbx_pred_i = torch.cat(bbx_pred_i, dim=0)
            cls_pred_i = torch.cat(cls_pred_i, dim=0)
            obj_pred_i = torch.cat(obj_pred_i, dim=0)
            indices_pred_i = torch.cat(indices_pred_i, dim=0)

            # Do post-NMS selection (if needed): keep the max_predictions highest-scoring detections
            if bbx_pred_i.size(0) > self.max_predictions:
                _, idx = obj_pred_i.topk(self.max_predictions)
                bbx_pred_i = bbx_pred_i[idx]
                cls_pred_i = cls_pred_i[idx]
                obj_pred_i = obj_pred_i[idx]
                indices_pred_i = indices_pred_i[idx]

            # Save results
            bbx_pred.append(bbx_pred_i)
            cls_pred.append(cls_pred_i)
            obj_pred.append(obj_pred_i)
            indices_pred.append(indices_pred_i)
        except Empty:
            bbx_pred.append(None)
            cls_pred.append(None)
            obj_pred.append(None)
            indices_pred.append(None)

    return PackedSequence(bbx_pred), PackedSequence(
        cls_pred), PackedSequence(obj_pred), PackedSequence(indices_pred)
def unbind(self, dim=0):
    """Split the wrapped tensor along ``dim``.

    Returns a tuple with one ``CUDALongTensor`` wrapper per slice produced by
    ``torch.unbind``.
    """
    slices = torch.unbind(self._tensor, dim)
    return tuple(map(CUDALongTensor, slices))
def apply_along_axis(func, M, dim):
    """Apply ``func`` to every slice of ``M`` along ``dim`` and restack.

    Slices come from ``torch.unbind(M, dim)``. The per-slice results are
    stacked back along the same ``dim`` and moved to ``M``'s device.
    """
    results = []
    for piece in torch.unbind(M, dim):
        results.append(func(piece))
    stacked = torch.stack(results, dim)
    return stacked.to(device=M.device)
import numpy as np
import cv2

# Manual visual check of homo_warping2 on one batch of the DTU training set.
# Writes the reference image and the first source image of the first batch
# element to ../tmp/ as PNG files.
MVSDataset = find_dataset_def("dtu_yao")
dataset = MVSDataset("/data1/Dataset/mvs_training/dtu/", '../lists/dtu/train.txt', 'train', 3, 256)
dataloader = DataLoader(dataset, batch_size=2)
item = next(iter(dataloader))

# Keep every 4th element of the last two axes (presumably H and W — confirm) and move to the GPU.
imgs = item["imgs"][:, :, :, ::4, ::4].cuda()
proj_matrices = item["proj_matrices"].cuda()
mask = item["mask"].cuda()
depth = item["depth"].cuda()
depth_values = item["depth_values"].cuda()

# Split along dim 1 (presumably the view axis); entry 0 is used as the reference, the rest as sources.
imgs = torch.unbind(imgs, 1)
proj_matrices = torch.unbind(proj_matrices, 1)
ref_img, src_imgs = imgs[0], imgs[1:]
ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]

# NOTE(review): warped_imgs is computed but never written or inspected below.
warped_imgs = homo_warping2(src_imgs[0], src_projs[0], ref_proj, depth_values)

# Save batch element 0: permute moves channels last, [:, :, ::-1] reverses the
# channel order (presumably RGB -> BGR for OpenCV) and * 255 assumes values in [0, 1].
cv2.imwrite(
    '../tmp/ref.png',
    ref_img.permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255)
cv2.imwrite(
    '../tmp/src.png',
    src_imgs[0].permute(
        [0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255)
basename = 'test_split_size', fn = lambda x: bf.split(x, 3, 0), torch_fn = lambda x: torch.split(x, 3, 0), inputs = utils.random_inputs([ [(9)], [(9,2)] ]) ) utils.add_tests(ShapeTestCase, basename = 'test_split_indices', fn = lambda x: bf.split(x, [2, 7], 0), torch_fn = lambda x: torch.split(x, [2, 5, 2], 0), inputs = utils.random_inputs([ [(9)], [(9,2)] ]) ) utils.add_tests(ShapeTestCase, basename = 'test_unstack', fn = lambda x: bf.unstack(x, 0), torch_fn = lambda x: torch.unbind(x, 0), inputs = utils.random_inputs([ [(5,3)], [(5,2,3)] ]) )
def forward(self, x, h0, x_mask):
    # Multi-turn answer-span decoder. Each of self.num_turn turns scores span
    # starts (self.attn_b) and ends (self.attn_e) and updates the state h0 with
    # self.rnn. The self.answer_opt branches differ in what conditions the end
    # attention and what is fed to the RNN.
    # Returns (start_scores, end_scores): the log of turn-averaged masked
    # probabilities when self.mem_type == 1, else the last turn's raw scores.
    start_scores_list = []
    end_scores_list = []
    for turn in range(self.num_turn):
        st_scores = self.attn_b(x, h0, x_mask)
        start_scores_list.append(st_scores)
        if self.answer_opt == 3:
            # Summary of x weighted by softmax(start scores) over dim 1.
            ptr_net_b = torch.bmm(F.softmax(st_scores, 1).unsqueeze(1), x).squeeze(1)
            ptr_net_b = self.dropout(ptr_net_b)
            xb = ptr_net_b if self.proj is None else self.proj(ptr_net_b)
            # End attention is conditioned on state + start summary.
            end_scores = self.attn_e(x, h0 + xb, x_mask)
            ptr_net_e = torch.bmm(
                F.softmax(end_scores, 1).unsqueeze(1), x).squeeze(1)
            # RNN input is the mean of the start and end summaries.
            ptr_net_in = (ptr_net_b + ptr_net_e) / 2.0
            h0 = self.dropout(h0)
            h0 = self.rnn(ptr_net_in, h0)
        elif self.answer_opt == 2:
            ptr_net_b = torch.bmm(F.softmax(st_scores, 1).unsqueeze(1), x).squeeze(1)
            ptr_net_b = self.dropout(ptr_net_b)
            xb = ptr_net_b if self.proj is None else self.proj(ptr_net_b)
            # End attention is conditioned on the start summary only.
            end_scores = self.attn_e(x, xb, x_mask)
            ptr_net_e = torch.bmm(
                F.softmax(end_scores, 1).unsqueeze(1), x).squeeze(1)
            ptr_net_in = ptr_net_e
            h0 = self.dropout(h0)
            h0 = self.rnn(ptr_net_in, h0)
        elif self.answer_opt == 1:
            ptr_net_b = torch.bmm(F.softmax(st_scores, 1).unsqueeze(1), x).squeeze(1)
            ptr_net_b = self.dropout(ptr_net_b)
            h0 = self.dropout(h0)
            ptr_net_in = ptr_net_b
            # Update the state with the start summary first; the end attention
            # then uses the updated state.
            h1 = self.rnn(ptr_net_in, h0)
            end_scores = self.attn_e(x, h1, x_mask)
            h0 = h1
        else:
            # End attention uses the current state; the end summary updates it.
            end_scores = self.attn_e(x, h0, x_mask)
            ptr_net_e = torch.bmm(
                F.softmax(end_scores, 1).unsqueeze(1), x).squeeze(1)
            ptr_net_in = ptr_net_e
            h0 = self.dropout(h0)
            h0 = self.rnn(ptr_net_in, h0)
        end_scores_list.append(end_scores)

    if self.mem_type == 1:
        # One mask column per turn from generate_mask (presumably random turn
        # dropout while self.training — confirm in generate_mask), applied to
        # each turn's softmax-normalized scores before averaging over turns.
        mask = generate_mask(self.alpha.data.new(x.size(0), self.num_turn), self.mem_random_drop, self.training)
        mask = [m.contiguous() for m in torch.unbind(mask, 1)]
        start_scores_list = [
            mask[idx].view(x.size(0), 1).expand_as(inp) * F.softmax(inp, 1)
            for idx, inp in enumerate(start_scores_list)
        ]
        end_scores_list = [
            mask[idx].view(x.size(0), 1).expand_as(inp) * F.softmax(inp, 1)
            for idx, inp in enumerate(end_scores_list)
        ]
        start_scores = torch.stack(start_scores_list, 2)
        end_scores = torch.stack(end_scores_list, 2)
        start_scores = torch.mean(start_scores, 2)
        end_scores = torch.mean(end_scores, 2)
        # Masked positions are filled with SMALL_POS_NUM (presumably a small
        # positive constant) so the log below stays finite there.
        start_scores.data.masked_fill_(x_mask.data, SMALL_POS_NUM)
        end_scores.data.masked_fill_(x_mask.data, SMALL_POS_NUM)
        start_scores = torch.log(start_scores)
        end_scores = torch.log(end_scores)
    else:
        # Only the final turn's raw scores are returned.
        start_scores = start_scores_list[-1]
        end_scores = end_scores_list[-1]
    return start_scores, end_scores
def predict(self, loader):
    """ Predict for an input.

    Runs ``self.model`` (in eval mode, without gradients) over every batch of
    ``loader`` and collects per-batch results.

    Args
    ----
    loader : PyTorch DataLoader. Its dataset (or the wrapped dataset of a
        Subset, or the first dataset of a ConcatDataset) must expose
        ``n_frames``.

    Returns
    -------
    Six lists with one entry per batch: predictions (masked or phase-applied
    complex spectrograms), complex targets (``self._to_complex(y)``), the
    batch ``c`` values, the batch ``t`` values, masks as numpy arrays, and
    track indices.
    """
    self.model.eval()
    all_preds = []
    all_ys = []
    all_cs = []
    all_ts = []
    all_ms = []
    all_idx = []
    # Read n_frames from the underlying dataset, unwrapping Subset / ConcatDataset.
    if isinstance(loader.dataset, torch.utils.data.Subset):
        n_frames = loader.dataset.dataset.n_frames
    elif isinstance(loader.dataset, torch.utils.data.ConcatDataset):
        n_frames = loader.dataset.datasets[0].n_frames
    else:
        n_frames = loader.dataset.n_frames
    with torch.no_grad():
        for batch_samples in tqdm(loader):
            # prepare training sample
            X = batch_samples['X']
            if X.dim() == 4:
                full_track = False
                # batch_size x in_channels x 1025 x 129
            else:
                # Full track: an extra split axis (dim 1) is folded into the batch.
                bs = X.size(0)
                ns = X.size(1)
                full_track = True
                # batch_size * splits x in_channels x 1025 x 129
                X = X.view(bs * ns, self.in_channels, self.n_fft, n_frames)
            # batch_size x in_channels x 1025 x 129 x 2
            X_complex = batch_samples['X_complex']
            # NOTE(review): bs / ns are only assigned in the full-track branch above;
            # if X is 4-D but X_complex is not 5-D this raises NameError.
            if X_complex.dim() != 5:
                # batch_size * splits x in_channels x 1025 x 129 x 2
                X_complex = X_complex.view(
                    bs * ns, self.out_channels, self.n_fft, n_frames, 2)
            # batch_size x nclasses x in_channels x 1025 x time samples x 2
            y = batch_samples['y_complex']
            # batch_size x nclasses
            cs = batch_samples['c']
            # batch_size x 1
            ts = batch_samples['t']
            track_idx = batch_samples['track_idx']
            if self.USE_CUDA:
                X = X.cuda()
                X_complex = X_complex.cuda()
                y = y.cuda()
            # Run the model on chunks of at most 4 rows (hard-coded; presumably to bound memory).
            if X.size(0) > 4:
                X_list = torch.split(X, 4, dim=0)
            else:
                X_list = [X]
            masks_list = []
            pred_list = []
            for X in X_list:
                # detach hidden state
                self.model.detach_hidden(X.size(0))
                # forward pass
                preds, mask = self.model(X)
                masks_list += [mask]
                pred_list += [preds]
            mask = torch.cat(masks_list, dim=0)
            preds = torch.cat(pred_list, dim=0)
            if full_track:
                # Undo the split folding: restore (bs, ns, ...) and concatenate
                # the splits along the time axis.
                # batch size x nclasses x in_channels x 1025 x time samples
                if self.regression:
                    preds = preds.view(
                        bs, ns, self.n_classes, self.out_channels, self.n_fft, n_frames)
                    preds = torch.unbind(preds, dim=1)
                    preds = torch.cat(preds, dim=4)
                else:
                    mask = mask.view(
                        bs, ns, self.n_classes, self.out_channels, self.n_fft, n_frames)
                    mask = torch.unbind(mask, dim=1)
                    mask = torch.cat(mask, dim=4)
                # batch_size x in_channels x 1025 x time samples x 2
                X_complex = X_complex.view(
                    bs, ns, self.out_channels, self.n_fft, n_frames, 2)
                X_complex = torch.unbind(X_complex, dim=1)
                X_complex = torch.cat(X_complex, dim=3)
            # convert to complex
            # batch size x nclasses x in_channels x 1025 x time samples x 2
            X_complex = X_complex.unsqueeze(1).repeat(
                1, self.n_classes, 1, 1, 1, 1)
            X_complex = self._to_complex(X_complex)
            if self.regression:
                # Regression predicts magnitudes; reuse the mixture's phase.
                _, X_phase = magphase(X_complex)
                preds = preds.cpu().numpy() * X_phase
            else:
                # Masking: apply the predicted mask to the complex mixture.
                preds = mask.cpu().numpy() * X_complex
            # batch size x nclasses x in_channels x 1025 x time samples
            ys = self._to_complex(y)
            all_preds += [preds]
            all_ys += [ys]
            all_cs += [cs]
            all_ts += [ts]
            all_ms += [mask.cpu().numpy()]
            all_idx += [track_idx]
    return all_preds, all_ys, all_cs, all_ts, all_ms, all_idx
def _conductance_grads(self, forward_fn, input, target_ind=None):
    """Compute per-hidden-unit input gradients used for layer conductance.

    The input batch is repeated ``layer_units`` times (``repeat_interleave`` on
    dim 0). A backward hook on ``self.layer``'s output replaces the incoming
    gradient with one-hot rows, so row ``k * layer_units + j`` of the input
    gradient is the gradient that flows through hidden unit ``j`` only.

    Args:
        forward_fn: Callable mapping an input batch to the model output.
        input: Input tensor. It must require grad, because gradients are taken
            with respect to its ``repeat_interleave`` expansion.
        target_ind: Optional column index of the output to attribute. ``None``
            uses the whole output.

    Returns:
        Tuple ``(input_grads, output_mid_grads, layer_units)``:

        - the gradient w.r.t. the expanded input (``input.shape[0] * layer_units`` rows);
        - the unmodified gradient of the output w.r.t. the layer output, one row per original example;
        - the number of units in the layer.
    """
    with torch.autograd.set_grad_enabled(True):
        # First pass: capture the layer output to learn its shape.
        saved_tensor = None

        def forward_hook(module, inp, out):
            nonlocal saved_tensor
            saved_tensor = out

        hook = self.layer.register_forward_hook(forward_hook)
        try:
            forward_fn(input)
        finally:
            # Remove the hook even if the forward pass raises.
            hook.remove()

        # Dimension 0 of the layer output is the batch dimension; the product
        # of the remaining dimensions is the total number of hidden units.
        layer_size = tuple(saved_tensor.size())[1:]
        layer_units = int(np.prod(layer_size))

        # Backward hook: record the true gradient at the layer output, then
        # override it with one-hot rows so each copy of an example only
        # propagates the gradient of a single hidden unit back to the input.
        saved_grads = None

        def backward_hook(grads):
            nonlocal saved_grads
            saved_grads = grads
            # Build the replacement on the gradient's device and dtype;
            # autograd rejects a hook result whose type differs from the
            # gradient it replaces.
            scatter_indices = torch.arange(
                layer_units, device=grads.device).view((1, ) + layer_size)
            to_return = torch.zeros(
                (layer_units, ) + layer_size, dtype=grads.dtype,
                device=grads.device).scatter(0, scatter_indices, 1)
            to_repeat = [1] * len(to_return.shape)
            to_repeat[0] = grads.shape[0] // to_return.shape[0]
            return to_return.repeat(to_repeat)

        # Forward hook whose only job is to attach backward_hook to the layer
        # output tensor of the second pass.
        back_hook = None

        def forward_hook_register_back(module, inp, out):
            nonlocal back_hook
            back_hook = out.register_hook(backward_hook)

        # Expand input to include layer_units copies of each input.
        # This allows obtaining gradient with respect to each hidden unit
        # in one pass.
        expanded_input = torch.repeat_interleave(input, layer_units, dim=0)
        hook = self.layer.register_forward_hook(forward_hook_register_back)
        try:
            output = forward_fn(expanded_input)
        finally:
            hook.remove()

        try:
            output = output[:, target_ind] if target_ind is not None else output
            input_grads = torch.autograd.grad(torch.unbind(output), expanded_input)
        finally:
            if back_hook is not None:
                back_hook.remove()

        # saved_grads has one row per expanded example; keep rows
        # 0, layer_units, 2 * layer_units, ... (one copy per original example).
        output_mid_grads = torch.index_select(
            saved_grads,
            0,
            torch.arange(0, input_grads[0].shape[0], layer_units,
                         device=saved_grads.device),
        )
    return input_grads[0], output_mid_grads, layer_units
def build(self, im_batch):
    ''' Decomposes a batch of images into a complex steerable pyramid.
    The pyramid typically has ~4 levels and 4-8 orientations.

    Args:
        im_batch (torch.Tensor): Batch of grayscale images of shape [N,1,H,W],
            dtype torch.float32, on ``self.device``.

    Returns:
        pyramid: list containing torch.Tensor objects storing the pyramid.
            Entry 0 is the real part of the high-pass residual; the rest come
            from ``self._build_levels``.

    Raises:
        AssertionError: if the device, dtype, rank or channel count is wrong.
        RuntimeError: if the image is too small for ``self.height`` levels.
    '''
    assert im_batch.device == self.device, 'Devices invalid (pyr = {}, batch = {})'.format(
        self.device, im_batch.device)
    assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32'
    assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]'
    assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image'

    im_batch = im_batch.squeeze(1)  # flatten channels dim
    # NOTE(review): after squeeze(1) the shape is [N, H, W], so shape[2] is the
    # width and shape[1] the height, the reverse of the names assigned here.
    # The size check below is symmetric, but prepare_grid(height, width) and
    # the [None, :, :, None] mask broadcast against batch_dft may only agree
    # for square images. Confirm against math_utils.prepare_grid.
    height, width = im_batch.shape[2], im_batch.shape[1]

    # Check whether image size is sufficient for number of levels
    if self.height > int(np.floor(np.log2(min(width, height))) - 2):
        raise RuntimeError(
            'Cannot build {} levels, image too small.'.format(self.height))

    # Prepare a grid
    log_rad, angle = math_utils.prepare_grid(height, width)

    # Radial transition function (a raised cosine in log-frequency):
    Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
    Yrcos = np.sqrt(Yrcos)
    # Yrcos**2 + YIrcos**2 == 1 by construction, so the low-/high-pass masks
    # below are complementary in power (up to pointOp's lookup/interpolation).
    YIrcos = np.sqrt(1 - Yrcos**2)
    lo0mask = pointOp(log_rad, YIrcos, Xrcos)
    hi0mask = pointOp(log_rad, Yrcos, Xrcos)

    # Note that we expand dims to support broadcasting later
    lo0mask = torch.from_numpy(lo0mask).float()[None, :, :, None].to(self.device)
    hi0mask = torch.from_numpy(hi0mask).float()[None, :, :, None].to(self.device)

    # Fourier transform (2D) and shifting
    # NOTE(review): torch.rfft / torch.ifft belong to the pre-1.8 FFT API
    # (replaced by torch.fft); this method requires an older PyTorch.
    batch_dft = torch.rfft(im_batch, signal_ndim=2, onesided=False)
    batch_dft = math_utils.batch_fftshift2d(batch_dft)

    # Low-pass
    lo0dft = batch_dft * lo0mask

    # Start recursively building the pyramids
    coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height - 1)

    # High-pass
    hi0dft = batch_dft * hi0mask
    hi0 = math_utils.batch_ifftshift2d(hi0dft)
    hi0 = torch.ifft(hi0, signal_ndim=2)
    # The old FFT API stores (real, imag) in the last axis; keep the real part.
    hi0_real = torch.unbind(hi0, -1)[0]
    coeff.insert(0, hi0_real)
    return coeff
def compute_pr_curves(class_hist, total_hist):
    """
    Computes precision recall curves from the true sample / total
    sample histogram tensors. The histogram tensors are num_bins x num_classes
    and each column represents a histogram over
    prediction_probabilities.

    The two tensors should have the same dimensions.
    The two tensors should have nonnegative integer values.

    Returns a dict with keys "prec", "recall" and "ap", each a list with one
    entry per class that has at least one positive sample. Classes without
    positives are skipped, so list positions need not equal class indices.

    - "prec" and "recall" hold 1-D double tensors ordered from the highest
      score bin to the lowest. Points with non-positive recall and
      consecutive duplicate (recall, precision) points are removed.
    - "ap" holds ``calc_ap(prec, recall)``, the average precision (AUPRC),
      for each class.

    Raises AssertionError if the inputs are not same-shaped 2-D int64 tensors
    with class_hist <= total_hist element-wise.
    """
    assert torch.is_tensor(class_hist) and torch.is_tensor(
        total_hist), "Both arguments must be tensors"
    assert (class_hist.dtype == torch.int64
            and total_hist.dtype == torch.int64), "Both arguments must contain int64 values"
    assert (len(class_hist.size()) == 2
            and len(total_hist.size()) == 2), "Both arguments must have 2 dimensions, (score_bin, class)"
    assert (class_hist.size() == total_hist.size()), """
        For compute_pr_curve, arguments must be of same size.
        class_hist.size(): %s
        total_hist.size(): %s
        """ % (
        str(class_hist.size()),
        str(total_hist.size()),
    )
    assert (class_hist > total_hist).sum() == 0, (
        "Invalid. Class histogram must be less than or equal to total histogram"
    )

    num_bins = class_hist.size()[0]
    # Cumsum from highest bucket to lowest
    cum_class_hist = torch.cumsum(torch.flip(class_hist, dims=(0, )), dim=0).double()
    cum_total_hist = torch.cumsum(torch.flip(total_hist, dims=(0, )), dim=0).double()
    class_totals = cum_class_hist[-1, :]

    prec_t = cum_class_hist / cum_total_hist
    recall_t = cum_class_hist / class_totals

    prec = torch.unbind(prec_t, dim=1)
    recall = torch.unbind(recall_t, dim=1)
    assert len(prec) == len(
        recall
    ), "The number of precision curves does not match the number of recall curves"

    final_prec = []
    final_recall = []
    final_ap = []
    for c, (prec_curve, recall_curve) in enumerate(zip(prec, recall)):
        assert (
            recall_curve.size()[0] == num_bins and prec_curve.size()[0] == num_bins
        ), "Precision and recall curves do not have the correct number of entries"

        # Check if any samples from class were seen
        if class_totals[c] == 0:
            continue

        # Remove invalid and duplicate points. Collect Python floats and build
        # each tensor once, instead of growing it with torch.cat per point
        # (quadratic in num_bins).
        kept_recall = []
        kept_prec = []
        prev_r, prev_p = -1.0, 1.1
        for r, p in zip(recall_curve.tolist(), prec_curve.tolist()):
            # Remove points on PR curve that are invalid
            if r <= 0:
                continue
            # Remove duplicates (due to empty buckets):
            if r == prev_r and p == prev_p:
                continue
            kept_recall.append(r)
            kept_prec.append(p)
            prev_r, prev_p = r, p

        new_recall_curve = torch.tensor(kept_recall, dtype=torch.double)
        new_prec_curve = torch.tensor(kept_prec, dtype=torch.double)

        ap = calc_ap(new_prec_curve, new_recall_curve)
        final_prec.append(new_prec_curve)
        final_recall.append(new_recall_curve)
        final_ap.append(ap)

    return {"prec": final_prec, "recall": final_recall, "ap": final_ap}
def forward(
    self,
    input_ids,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    start_positions=None,
    end_positions=None,
    answer_masks=None,
):
    r"""
    start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_answers)` or :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
        Labels for position (index) of the start of each labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (`sequence_length`).
        Position outside of the sequence are not taken into account for computing the loss.
        A 1-D tensor is treated as a single answer per example.
    end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_answers)` or :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
        Labels for position (index) of the end of each labelled span, same conventions as ``start_positions``.
    answer_masks (:obj:`torch.Tensor` of shape :obj:`(batch_size, num_answers)`, `optional`, defaults to :obj:`None`):
        Per-answer weights; each answer's per-example start/end loss is multiplied by its entry.
        When ``None`` every answer gets weight 1.

    The caller's ``start_positions`` / ``end_positions`` tensors are not modified.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``start_positions`` and ``end_positions`` are provided):
            Sum over answers of the mask-weighted per-example start and end cross-entropies,
            averaged over the batch and divided by 2.
        start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-start scores (before SoftMax).
        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-end scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        # The checkpoint roberta-large is not fine-tuned for question answering. Please see the
        # examples/run_squad.py example to see how to fine-tune a model to a question answering task.

        from transformers import RobertaTokenizer, RobertaForQuestionAnswering
        import torch

        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaForQuestionAnswering.from_pretrained('roberta-base')
        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        input_ids = tokenizer.encode(question, text)
        start_scores, end_scores = model(torch.tensor([input_ids]))
        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
        answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])

    """
    outputs = self.roberta(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
    )

    sequence_output = outputs[0]

    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    outputs = (
        start_logits,
        end_logits,
    ) + outputs[2:]
    if start_positions is not None and end_positions is not None:

        def _to_2d(t):
            # Drop a trailing singleton dim (e.g. added when gathering across
            # GPUs) and promote a 1-D tensor to (batch, 1), so every tensor can
            # be unbound along dim 1. A (batch, 1) tensor is left as is.
            if t.dim() > 2:
                t = t.squeeze(-1)
            if t.dim() == 1:
                t = t.unsqueeze(1)
            return t

        start_positions = _to_2d(start_positions)
        end_positions = _to_2d(end_positions)
        if answer_masks is None:
            answer_masks = torch.ones_like(start_positions, dtype=start_logits.dtype)
        else:
            answer_masks = _to_2d(answer_masks)

        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        # Out-of-place clamp so the caller's tensors are not modified.
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        # Per-example losses so each answer can be weighted by its mask.
        loss_fct = CrossEntropyLoss(ignore_index=ignored_index, reduction='none')
        start_losses = [
            loss_fct(start_logits, _start_positions) * _span_mask
            for _start_positions, _span_mask in zip(
                torch.unbind(start_positions, dim=1), torch.unbind(answer_masks, dim=1))
        ]
        end_losses = [
            loss_fct(end_logits, _end_positions) * _span_mask
            for _end_positions, _span_mask in zip(
                torch.unbind(end_positions, dim=1), torch.unbind(answer_masks, dim=1))
        ]
        total_loss = sum(start_losses + end_losses)
        total_loss = torch.mean(total_loss) / 2
        outputs = (total_loss, ) + outputs

    return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
def __init__(
        self,
        actions,
        actions_logp,
        actions_entropy,
        dones,
        behaviour_logits,
        target_logits,
        discount,
        rewards,
        values,
        bootstrap_value,
        dist_class,
        valid_mask,
        vf_loss_coeff=0.5,
        entropy_coeff=0.01,
        clip_rho_threshold=1.0,
        clip_pg_rho_threshold=1.0,
):
    """Policy gradient loss with vtrace importance weighting.

    VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
    batch_size. The reason we need to know `B` is for V-trace to properly
    handle episode cut boundaries.

    V-trace is computed on detached CPU copies of its inputs. Its results are
    moved to the device of ``behaviour_logits[0]`` before the policy and
    value losses are formed. The caller's ``behaviour_logits`` and
    ``target_logits`` lists are not modified.

    Args:
        actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
        actions_logp: A float32 tensor of shape [T, B].
        actions_entropy: A float32 tensor of shape [T, B].
        dones: A bool tensor of shape [T, B].
        behaviour_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes
            [T, B, ACTION_SPACE[0]],
            ...,
            [T, B, ACTION_SPACE[-1]]
        target_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes
            [T, B, ACTION_SPACE[0]],
            ...,
            [T, B, ACTION_SPACE[-1]]
        discount: A float32 scalar.
        rewards: A float32 tensor of shape [T, B].
        values: A float32 tensor of shape [T, B].
        bootstrap_value: A float32 tensor of shape [B].
        dist_class: action distribution class for logits.
        valid_mask: A bool tensor of valid RNN input elements (#2992).
    """
    # Device the losses are computed on (the device of the logits).
    device = behaviour_logits[0].device

    # Compute vtrace on the CPU for the better perf. Build new lists of
    # detached CPU copies instead of overwriting the caller's list entries.
    behaviour_logits = [t.data.cpu() for t in behaviour_logits]
    target_logits = [t.data.cpu() for t in target_logits]
    for behaviour_t, target_t in zip(behaviour_logits, target_logits):
        # Make sure tensor ranks are as expected.
        # The rest will be checked by from_action_log_probs.
        assert len(behaviour_t.size()) == 3
        assert len(target_t.size()) == 3

    # 1.0 where the step is not terminal, 0.0 where it is: the discount is cut at episode ends.
    reverse_dones = (dones.float() == torch.zeros_like(dones.float()))
    self.vtrace_returns = vtrace.multi_from_logits(
        behaviour_policy_logits=behaviour_logits,
        target_policy_logits=target_logits,
        actions=torch.unbind(actions.data.cpu(), dim=2),
        # Moved to the CPU like every other v-trace input.
        discounts=reverse_dones.float().cpu() * discount,
        rewards=rewards.data.cpu(),
        values=values.data.cpu(),
        bootstrap_value=bootstrap_value.data.cpu(),
        dist_class=dist_class,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)
    self.value_targets = self.vtrace_returns.vs

    # Bring the CPU v-trace results back to the loss device before mixing
    # them with on-device tensors.
    value_targets = self.value_targets.to(device)
    pg_advantages = self.vtrace_returns.pg_advantages.to(device)

    valid_mask = (valid_mask == torch.ones_like(valid_mask).to(device))

    # The policy gradients loss
    self.pi_loss = -torch.sum(
        (actions_logp * pg_advantages)[valid_mask])
    self.valid_mask = valid_mask
    self.values = values
    self.logp_min = actions_logp.min()
    self.logp_mean = actions_logp.mean()
    self.logp_max = actions_logp.max()
    self.ad_min = self.vtrace_returns.pg_advantages.min()
    self.ad_mean = self.vtrace_returns.pg_advantages.mean()
    self.ad_max = self.vtrace_returns.pg_advantages.max()

    # The baseline loss
    delta = (values - value_targets)[valid_mask]
    self.vf_loss = 0.5 * (delta**2).sum()

    # The entropy loss
    self.entropy = torch.sum(actions_entropy[valid_mask])

    # The summed weighted loss
    self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                       self.entropy * entropy_coeff)
def main():
    """Run segmentation inference on a single image or on the validation loader.

    With ``args.image`` set, that image is resized to the training crop size,
    normalized and passed to ``inference``; the resulting mask is saved under
    ``args.output``. Otherwise every batch of the validation loader is passed
    to ``inference``.
    """
    args = parse_arguments()
    # Close the config file once it has been parsed.
    with open(args.config) as config_file:
        config = json.load(config_file)

    # data loader params
    loader = get_instance(datasets, 'dataset', 'val_args', config)
    crop_size = config['dataset']['train_args']['crop_size']
    to_tensor = transforms.Compose([
        transforms.Resize((crop_size, crop_size)),
        transforms.ToTensor()
    ])
    restore_transform = transforms.Compose([
        DeNormalize(config['dataset']['mean'], config['dataset']['std']),
    ])
    palette = loader.dataset.palette

    model = get_instance(models, 'model', 'args', config)
    checkpoint = torch.load(args.weight)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint.keys():
        checkpoint = checkpoint['state_dict']
    # Keys containing 'module' (DataParallel checkpoints) need a wrapped model.
    if 'module' in list(checkpoint.keys())[0] and not isinstance(
            model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    print(model)

    if not os.path.exists(args.output):
        os.makedirs(args.output)

    with torch.no_grad():
        image_path = args.image
        if image_path is not None:
            image_name = str(image_path).split('/')[-1].split('.')[0]
            image = Image.open(image_path).convert('RGB')
            input_tensor = normalize(to_tensor(image)).unsqueeze(0)
            # NOTE(review): inference gets 4 arguments here but 5 (including a
            # label) in the loader branch below; confirm its signature.
            mask = inference(model, input_tensor, image, palette)
            mask_path = os.path.join(args.output, image_name + config['name'] + '.png')
            save_mask(mask, mask_path)
        else:
            # Plain iteration ends cleanly when the loader is exhausted; the
            # ``.next()`` method is not available on the iterators of recent
            # PyTorch DataLoaders.
            for batch in loader:
                images = batch['img']
                label = batch['label']
                if config['dataset']['name'] == "SeqLane":
                    images = torch.unbind(images, dim=0)
                    image = images[0][-1]
                    image_name = batch['images_name'][-1][0]
                else:
                    image = images[0]
                    image_name = batch['img_name'][0]
                mask = inference(model, batch['img'], restore_transform(image), label[0], palette)
                print(image_name)
                # NOTE(review): mask_path is computed but the mask is not saved
                # in this branch; confirm whether save_mask(mask, mask_path) is missing.
                mask_path = os.path.join(args.output, image_name[-1])
# Manual smoke test of the point-cloud ops. Assumes `pc` (a numpy point array)
# and `cuda0` (a CUDA device) are defined earlier in the script — confirm.
save_ply(pc, "./input.ply", colors=None, normals=None)
# Add a batch dim and transpose so points lie along the last axis.
pc = torch.from_numpy(pc).requires_grad_().to(cuda0).unsqueeze(0)
pc = pc.transpose(2, 1)
# test furthest point
idx, sampled_pc = furthest_point_sample(pc, 1250)
output = sampled_pc.transpose(2, 1).cpu().squeeze()
save_ply(output.detach(), "./output.ply", colors=None, normals=None)
# test KNN
knn_points, _, _ = group_knn(10, sampled_pc, pc, NCHW=True)  # B, C, M, K
# Label every neighbour with the index of its query (sampled) point so the
# neighbourhoods can be coloured in the output file.
labels = torch.arange(
    0, knn_points.size(2)).unsqueeze_(0).unsqueeze_(0).unsqueeze_(
        -1)  # 1, 1, M, 1
labels = labels.expand(knn_points.size(0), -1, -1,
                       knn_points.size(3))  # B, 1, M, K
# B, C, P
# unbind + cat along the last axis flattens the K neighbours into the point axis.
labels = torch.cat(torch.unbind(labels, dim=-1), dim=-1).squeeze().detach().cpu().numpy()
knn_points = torch.cat(torch.unbind(knn_points, dim=-1), dim=-1).transpose(
    2, 1).squeeze(0).detach().cpu().numpy()
save_ply_property(knn_points, labels, "./knn_output.ply", cmap_name='jet')

from torch.autograd import gradcheck
# test = gradcheck(furthest_point_sample, [pc, 1250], eps=1e-6, atol=1e-4)
# print(test)
# gradcheck compares against finite differences, hence the float64 input.
test = gradcheck(gather_points, [pc.to(dtype=torch.float64), idx], eps=1e-6, atol=1e-4)
print(test)
def apply(func, M):
    """Apply ``func`` to each slice of ``M`` along dim 0 and stack the results along dim 0."""
    mapped = list(map(func, torch.unbind(M, dim=0)))
    return torch.stack(mapped, dim=0)
def future_v(self, x, boundry_params, segment_params):
    """Predict next-step segment velocities with a METANET-style speed update.

    Args:
        x: Tensor indexed as ``x[:, :, feature]`` (batch, segment, feature),
            read at ``self.rho_index`` (density), ``self.q_index`` (flow),
            ``self.v_index`` (velocity) and ``self.r_index`` (the ramp term
            used in the last correction).
        boundry_params: Tensor unbound along dim 2 into ``(v0, q0, rhoNp1)``:
            upstream velocity, upstream flow and downstream density.
        segment_params: Tensor unbound along dim 2 into ``(vf, a_var, rhocr,
            g, omegar, omegas, epsq, epsv)``. Only vf, a_var, rhocr and epsv
            enter the velocity update; the others are logged or printed.

    Returns:
        Tensor of next-step velocities with the shape of ``x[:, :, 0]``.
    """
    TINY = 1e-6
    v0, q0, rhoNp1 = torch.unbind(boundry_params, dim=2)
    vf, a_var, rhocr, g, omegar, omegas, epsq, epsv = torch.unbind(
        segment_params, dim=2)

    # Best-effort parameter histograms: logging errors are printed, not raised.
    try:
        if self.print_count % 100 == 0:
            # Each parameter is logged under its own name.
            for name, value in (("vf", vf), ("a_var", a_var), ("rhocr", rhocr),
                                ("g", g), ("omegar", omegar), ("omegas", omegas)):
                wandb.log({name: wandb.Histogram(value.cpu().detach().numpy())})
    except Exception as e:
        print(e)

    current_densities = x[:, :, self.rho_index]
    current_flows = x[:, :, self.q_index]
    if self.calculate_velocity:
        # v = q / (rho * lambda); the lambda_index column of segment_fixed
        # presumably holds the lane count. TINY guards against division by 0.
        current_velocities = current_flows / (
            current_densities * self.segment_fixed[:, self.lambda_index] + TINY)
        if self.print_count % 1000 == 0:
            for name, value in (("current_velocities", current_velocities),
                                ("current_densities", current_densities),
                                ("current_flows", current_flows)):
                wandb.log({name: wandb.Histogram(value.cpu().detach().numpy())})
    else:
        current_velocities = x[:, :, self.v_index]
    current_velocities = torch.clamp(current_velocities, min=5, max=200)

    # Upstream velocity for each segment (v0 feeds the first one) and
    # downstream density (rhoNp1 follows the last one).
    prev_velocities = torch.cat([v0, current_velocities[:, :-1]], dim=1)
    next_densities = torch.cat([current_densities[:, 1:], rhoNp1], dim=1)

    # Equilibrium speed: vf * exp(-(1 / a) * (rho / rhocr) ** a).
    stat_speed = vf * torch.exp(
        torch.div(-1, a_var + TINY) * torch.pow(
            torch.div(current_densities, rhocr + TINY) + TINY, a_var))

    if self.print_count % 1000 == 0:
        wandb.log({
            "stat_speed": wandb.Histogram(stat_speed.cpu().detach().numpy())
        })
        print("stat speed", stat_speed.size(), stat_speed.min().item(),
              stat_speed.mean().item(), stat_speed.max().item())
        print("v0,q0,rhoNN", v0[0].item(), q0[0].item(), rhoNp1[0].item(),
              v0.mean().item(), q0.mean().item(), rhoNp1.mean().item())
        print("q1", x[0, 0, self.q_index].item(), x[:, 0, self.q_index].mean().item())
        print("vf, a, rhocr,g, omegar, omegas, epsq, epsv", vf[0].mean().item(),
              a_var[0].mean().item(), rhocr[0].mean().item(), g[0].mean().item(),
              omegar[0].mean().item(), omegas[0].mean().item(),
              epsq[0].mean().item(), epsv[0].mean().item())

    # Relaxation toward stat_speed, convection from upstream, anticipation of
    # downstream density, ramp term, plus the additive noise epsv.
    return current_velocities + (torch.div(self.T, self.tau + TINY)) * (stat_speed - current_velocities) \
        + (torch.div(self.T, self.Delta) * current_velocities * (prev_velocities - current_velocities)) \
        - (torch.div(self.nu * self.T, (self.tau * self.Delta)) * torch.div((next_densities - current_densities), (current_densities + self.kappa))) \
        - (torch.div((self.delta * self.T), (self.Delta * self.lambda_var)) * torch.div((x[:, :, self.r_index] * current_velocities), (current_densities + self.kappa))) \
        + epsv
def forward(self, x):
    """Return the rows of ``self.I`` selected by the indices in ``x``.

    ``torch.index_select`` along dim 0 already returns a new tensor, so no
    extra copy is made. An empty index tensor yields an empty result instead
    of failing in ``torch.stack`` on an empty list.
    """
    # A tiny random perturbation (torch.rand(...) * 1e-5) was considered here
    # and is intentionally not applied.
    return torch.index_select(self.I, 0, x)
def forward(self, h_o_prev, y_e_prev, C_o, pos, phase, memor):
    """
    One tracker step over the O tracker slots.

    The slots are reordered by proximity of their ``memor`` entries to
    ``pos`` (unless "no_rep" is configured). In the 'obs' phase, each slot's
    hidden state is updated with ``self.ntm_cell`` and the original slot order
    is restored; otherwise the previous hidden state is reused. Outputs come
    from ``self.generate_outputs``, and with "act" configured the outputs are
    gated and ``memor`` is updated.

    h_o_prev: N * O * dim_h_o
    y_e_prev: N * O * dim_y_e (not used by the active code below)
    C_o: N * C2_1 * C2_2
    pos: presumably N * 2 (it is expanded to N * O * 2 below) -- TODO confirm
    phase: the NTM update and the memor update only run when == 'obs'
    memor: m[i] is fed to cKDTree and queried with a point from pos, so
        presumably N * O * 2 -- TODO confirm

    Returns (memor, h_o, y_e, y_l, y_p, Y_s, Y_a).
    """
    o = self.o
    # Per-sample permutation: perm_mat[i, j, r[j]] = 1, so perm_mat.bmm(h)
    # puts in slot j the entry r[j] (cKDTree.query returns neighbours sorted
    # by distance, i.e. slot j gets the j-th nearest memory to pos[i]).
    perm_mat = torch.zeros(o.N, o.O, o.O)
    m = memor.cpu().detach()
    p = pos.cpu()
    for i in range(o.N):
        tree = cKDTree(m[i])
        # NOTE(review): k=10 is hard-coded while the assignment below indexes
        # with torch.arange(o.O); this only lines up when o.O == 10 -- confirm.
        r = tree.query(p[i], k=10)[1].reshape(-1)
        #r = np.array(r.tolist() + [sum(range(o.O)) - r.sum()])
        #print(r, p[i].tolist(), m[i].tolist())
        # Debug output when r does not sum like a permutation of 0..O-1
        # (a sum check only; it cannot catch every non-permutation).
        if(r.sum() != sum(range(o.O))):
            print(m[i])
            print(p[i])
            print(r)
        perm_mat[i, torch.arange(o.O), r] = 1
    perm_mat = perm_mat.cuda()
    # Transpose of a permutation matrix is its inverse: undoes the reorder.
    perm_mat_inv = perm_mat.transpose(1, 2)
    #pos = pos / 120
    #inp = pos.unsqueeze(1).expand(o.N, o.O, 2).contiguous().view(-1, 2)
    #y_e_prev = self.linear_pos_to_conf(inp).tanh().abs().view(o.N, o.O, o.dim_y_e)
    #if "no_rep" not in o.exp_config:
    #    if o.train == 0:
    #        print(y_e_prev[0].view(-1))
    #    delta = torch.arange(0, o.O).type(torch.FloatTensor).cuda().unsqueeze(0) * 0.0001 # 1 * O
    #    y_e_prev_mdf = y_e_prev.squeeze(2).round() - Variable(delta)
    #    perm_mat = self.permutation_matrix_calculator(y_e_prev_mdf) # N * O * O
    #    perm_mat_inv = perm_mat.transpose(1, 2) # N * O * O
    if phase == 'obs':
        #if "no_tem" in o.exp_config:
        #    h_o_prev = Variable(torch.zeros_like(h_o_prev).cuda())W3
        #    y_e_prev = Variable(torch.zeros_like(y_e_prev).cuda())
        # Sort h_o_prev (proximity order) before the sequential slot update
        if "no_rep" not in o.exp_config:
            h_o_prev = perm_mat.bmm(h_o_prev) # N * O * dim_h_o
            # y_e_prev = perm_mat.bmm(y_e_prev) # N * O * dim_y_e
            # if o.train == 0:
            #     print('Shuffled:')
            #     print(y_e_prev[0].view(-1))
        # Update h_o one slot at a time; C_o is threaded through the loop, so
        # each slot's cell call receives the C_o returned for the previous slot.
        h_o_prev_split = torch.unbind(h_o_prev, 1) # N * dim_h_o
        h_o_split = {}
        k_split = {}
        r_split = {}
        for i in range(0, o.O):
            self.ntm_cell.i = i
            h_o_split[i], C_o, k_split[i], r_split[i] = self.ntm_cell(h_o_prev_split[i], C_o, pos, phase)
        h_o = torch.stack(tuple(h_o_split.values()), dim=1) # N * O * dim_h_o
        # NOTE(review): k and r are computed (and reordered below) but not
        # returned or stored by this method.
        k = torch.stack(tuple(k_split.values()), dim=1) # N * O * C2_2
        r = torch.stack(tuple(r_split.values()), dim=1) # N * O * C2_2
        att = self.ntm_cell.att
        mem = self.ntm_cell.mem
        # Recover the original order of h_o (and of h_o_prev, which the
        # "act" gating below mixes with h_o)
        if "no_rep" not in o.exp_config:
            #perm_mat_inv = perm_mat.transpose(1, 2) # N * O * O
            h_o = perm_mat_inv.bmm(h_o) # N * O * dim_h_o
            k = perm_mat_inv.bmm(k) # N * O * dim_c_2
            r = perm_mat_inv.bmm(r) # N * O * dim_c_2
            # att/mem are reordered with the permutation of sample self.ntm_cell.n only
            att = perm_mat_inv.data[self.ntm_cell.n].mm(att.view(o.O, -1)).view(o.O, -1, self.ntm_cell.wa) # O * ha * wa
            mem = perm_mat_inv.data[self.ntm_cell.n].mm(mem.view(o.O, -1)).view(o.O, -1, self.ntm_cell.wa) # O * ha * wa
            h_o_prev = perm_mat_inv.bmm(h_o_prev)
            #y_e_prev = perm_mat_inv.bmm(y_e_prev)
        if o.v > 0:
            # Keep the attention/memory of step self.t for visualisation buffers
            self.att[self.t].copy_(att)
            self.mem[self.t].copy_(mem)
    else:
        # Outside the observation phase the hidden state is carried over unchanged
        h_o = h_o_prev
    y_e, y_l, y_p, Y_s, Y_a = self.generate_outputs(h_o, pos)
    # adaptive computation time
    if "act" in o.exp_config:
        #y_e_perm = perm_mat.bmm(y_e).round() # N * O * dim_y_e
        #y_e_mask = y_e_prev.round() + y_e_perm # N * O * dim_y_e
        # y_e_mask = y_e_perm
        #y_e_mask = perm_mat.bmm(y_e).round() # N * O * dim_y_e
        #y_e_mask = y_e_mask.lt(0.5).type_as(y_e_mask)
        #y_e_mask = y_e_mask.cumsum(1)
        #y_e_mask = y_e_mask.lt(0.5).type_as(y_e_mask)
        #ones = Variable(torch.ones(y_e_mask.size(0), 1, o.dim_y_e).cuda()) # N * 1 * dim_y_e
        #y_e_mask = torch.cat((ones, y_e_mask[:, 0:o.O-1]), dim=1)
        #y_e_mask = perm_mat_inv.bmm(y_e_mask) # N * O * dim_y_e
        # Slots whose y_e rounds to 0 keep their previous hidden state and
        # have their outputs zeroed.
        y_e_mask = y_e.round() # N * O * dim_y_e
        y_e_inv_mask = y_e.round().lt(0.5).type_as(y_e) # N * O * dim_y_e
        h_o = y_e_mask * (h_o - h_o_prev) + h_o_prev # N * O * dim_h_o
        # h_o = y_e_mask * h_o # N * O * dim_h_o
        y_e = y_e_mask * y_e # N * O * dim_y_e
        y_p = y_e_mask * y_p # N * O * dim_y_p
        Y_a = y_e_mask.view(-1, o.O, o.dim_y_e, 1, 1) * Y_a # N * O * D * h * w
        # Active slots: memor becomes pos plus the last two y_p entries scaled
        # by (H // 2, W // 2); inactive slots keep their memor value.
        # NOTE(review): this update needs y_e_mask / y_e_inv_mask, so memor
        # only changes when "act" is enabled -- confirm that is intended.
        global_agent_pos = y_e_mask * pos.unsqueeze(1).expand(o.N, o.O, 2)
        global_tracker_offset = y_p[:, :, -2:] * torch.Tensor([o.H // 2, o.W // 2]).cuda()
        if phase == 'obs':
            memor = (y_e_inv_mask * memor) + global_agent_pos + global_tracker_offset
    if self.t == o.T - 1:
        print(y_e.data.view(-1, o.O)[0:1, 0:min(o.O, 10)])
    return memor, h_o, y_e, y_l, y_p, Y_s, Y_a
def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
    """Recursively build the oriented band-pass levels of the pyramid.

    Args:
        lodft: centred (fft-shifted) spectrum with real/imag stacked in the
            last dim; sliced below as ``lodft[:, rows, cols, :]``.
        log_rad: log-radius frequency grid matching ``lodft``'s spatial size.
        angle: angular frequency grid matching ``lodft``'s spatial size.
        Xrcos, Yrcos: radial transition lookup table passed to ``pointOp``.
        height: number of entries still to produce (band levels + low-pass).

    Returns:
        list: ``height - 1`` lists of ``self.nbands`` complex (real/imag
        stacked) band tensors, followed by the real low-pass residual.
    """
    if height <= 1:
        # Final entry: the remaining low-pass, back in the spatial domain
        lo0 = math_utils.batch_ifftshift2d(lodft)
        lo0 = torch.ifft(lo0, signal_ndim=2)
        lo0_real = torch.unbind(lo0, -1)[0]
        coeff = [lo0_real]
    else:
        Xrcos = Xrcos - np.log2(self.scale_factor)

        # ---------------- Orientation bandpass ----------------
        himask = pointOp(log_rad, Yrcos, Xrcos)
        himask = torch.from_numpy(himask[None, :, :, None]).float().to(self.device)

        order = self.nbands - 1
        const = np.power(2, 2 * order) * np.square(
            factorial(order)) / (self.nbands * factorial(2 * order))
        Ycosn = 2 * np.sqrt(const) * np.power(np.cos(
            self.Xcosn), order) * (np.abs(self.alpha) < np.pi / 2)  # [n,]

        # Loop through all orientation bands
        orientations = []
        for b in range(self.nbands):
            anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi * b / self.nbands)
            anglemask = anglemask[None, :, :, None]  # for broadcasting
            anglemask = torch.from_numpy(anglemask).float().to(self.device)

            # Bandpass filtering
            banddft = lodft * anglemask * himask

            # Multiply by the complex constant:
            # (x+yi)(u+vi) = (xu-yv) + (xv+yu)i
            banddft = torch.unbind(banddft, -1)
            banddft_real = self.complex_fact_construct.real * banddft[
                0] - self.complex_fact_construct.imag * banddft[1]
            banddft_imag = self.complex_fact_construct.real * banddft[
                1] + self.complex_fact_construct.imag * banddft[0]
            banddft = torch.stack((banddft_real, banddft_imag), -1)

            band = math_utils.batch_ifftshift2d(banddft)
            band = torch.ifft(band, signal_ndim=2)
            orientations.append(band)

        # ---------------- Subsample lowpass ----------------
        # Spatial size only (skip batch and the real/imag dim)
        dims = np.array(lodft.shape[1:3])

        # Both are arrays of size 2 (row, col)
        low_ind_start = (np.ceil((dims + 0.5) / 2) - np.ceil((np.ceil(
            (dims - 0.5) / 2) + 0.5) / 2)).astype(int)
        low_ind_end = (low_ind_start + np.ceil(
            (dims - 0.5) / 2)).astype(int)

        # Subsampling indices
        log_rad = log_rad[low_ind_start[0]:low_ind_end[0],
                          low_ind_start[1]:low_ind_end[1]]
        angle = angle[low_ind_start[0]:low_ind_end[0],
                      low_ind_start[1]:low_ind_end[1]]

        # Actual subsampling (central crop of the centred spectrum)
        lodft = lodft[:, low_ind_start[0]:low_ind_end[0],
                      low_ind_start[1]:low_ind_end[1], :]

        # Filtering. Take |.| before sqrt, as _reconstruct_levels does: if
        # rounding pushes Yrcos**2 slightly above 1, sqrt of the negative
        # value would otherwise yield NaN and poison every lower level.
        YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
        lomask = pointOp(log_rad, YIrcos, Xrcos)
        lomask = torch.from_numpy(lomask[None, :, :, None]).float()
        lomask = lomask.to(self.device)

        # Multiplying the spectrum == convolution in the spatial domain
        lodft = lomask * lodft

        # ---------------- Recursion next level ----------------
        coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height - 1)
        coeff.insert(0, orientations)

    return coeff
def forward(self, x):
    """Scale ``x`` by the expanded learnable filter gated by ``hard_filter``.

    Column ``j`` of ``self._filter`` is tiled ``self.cat[j]`` times; the tiles
    are concatenated into a single row vector, multiplied element-wise by
    ``self.hard_filter`` and broadcast (``expand_as``) over ``x``.
    """
    pieces = []
    for column, count in zip(th.unbind(self._filter, 1), self.cat):
        pieces.append(column.repeat(count))
    expanded = th.cat(pieces, 0).unsqueeze(0)
    gate = self.hard_filter * expanded
    return x * gate.expand_as(x)
def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
    """Recursively rebuild the centred spectrum from pyramid coefficients.

    Args:
        coeff: levels as produced by ``_build_levels``: every entry but the
            last is a list of ``self.nbands`` complex (real/imag stacked)
            band tensors; the last entry is the real low-pass residual.
        log_rad, angle: frequency grids matching the spatial size of the
            bands in ``coeff[0]``.
        Xrcos, Yrcos: radial transition lookup table passed to ``pointOp``.

    Returns:
        Centred (fft-shifted) spectrum with real/imag stacked in the last
        dim, shaped like ``coeff[0][0]`` (or like the FFT of ``coeff[0]``
        when only the low-pass residual is left).
    """
    if len(coeff) == 1:
        # Final level: the real low-pass residual, taken to the frequency domain
        dft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
        dft = math_utils.batch_fftshift2d(dft)
        return dft

    Xrcos = Xrcos - np.log2(self.scale_factor)

    # ---------------- Orientation residue ----------------
    himask = pointOp(log_rad, Yrcos, Xrcos)
    himask = torch.from_numpy(himask[None, :, :, None]).float().to(self.device)

    lutsize = 1024
    Xcosn = np.pi * np.array(range(-(2 * lutsize + 1), (lutsize + 2))) / lutsize
    order = self.nbands - 1
    const = np.power(2, 2 * order) * np.square(
        factorial(order)) / (self.nbands * factorial(2 * order))
    Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order)

    orientdft = torch.zeros_like(coeff[0][0])
    for b in range(self.nbands):
        anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b / self.nbands)
        anglemask = anglemask[None, :, :, None]  # for broadcasting
        anglemask = torch.from_numpy(anglemask).float().to(self.device)

        banddft = torch.fft(coeff[0][b], signal_ndim=2)
        banddft = math_utils.batch_fftshift2d(banddft)
        banddft = banddft * anglemask * himask

        # Multiply by the complex constant: (x+yi)(u+vi) = (xu-yv) + (xv+yu)i
        banddft = torch.unbind(banddft, -1)
        banddft_real = self.complex_fact_reconstruct.real * banddft[
            0] - self.complex_fact_reconstruct.imag * banddft[1]
        banddft_imag = self.complex_fact_reconstruct.real * banddft[
            1] + self.complex_fact_reconstruct.imag * banddft[0]
        banddft = torch.stack((banddft_real, banddft_imag), -1)

        orientdft = orientdft + banddft

    # ------- Lowpass component: upsampled and filtered -------
    dims = np.array(coeff[0][0].shape[1:3])

    lostart = (np.ceil((dims + 0.5) / 2) - np.ceil((np.ceil(
        (dims - 0.5) / 2) + 0.5) / 2)).astype(np.int32)
    loend = lostart + np.ceil((dims - 0.5) / 2).astype(np.int32)

    nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]]
    nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]]
    YIrcos = np.sqrt(np.abs(1 - Yrcos**2))

    # Filtering mask (one pointOp lookup is enough)
    lomask = pointOp(nlog_rad, YIrcos, Xrcos)
    lomask = torch.from_numpy(lomask[None, :, :, None])
    lomask = lomask.float().to(self.device)

    # Recursive call for image reconstruction of the coarser levels
    nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle)

    # Place the filtered coarser spectrum in the centre of a full-size spectrum
    resdft = torch.zeros_like(coeff[0][0]).to(self.device)
    resdft[:, lostart[0]:loend[0], lostart[1]:loend[1], :] = nresdft * lomask

    return resdft + orientdft
def _compute_dplstm_grad_sample(
    layer: DPLSTM, A: torch.Tensor, B: torch.Tensor, batch_dim: int = 0
) -> None:
    """
    Computes per sample gradients for ``DPLSTM`` layer

    First runs a backward pass through time: ``layer.cells[t].backward`` is
    called from the last time step down to the first. Then, for every step,
    the outer product of the cell's ``dgates_t`` with the step input (and
    with the previous hidden state) is accumulated into per-sample gradients
    of the four LSTM parameters. Those gradients are handed to
    ``_create_or_extend_grad_sample``; nothing is returned.

    Parameters
    ----------
    layer : opacus.layers.dp_lstm.DPLSTM
        Layer
    A : torch.Tensor
        Activations. Unbound along dim 1 into time steps; each step is used
        as a 2-D (batch, features) matrix in the einsum below.
    B : torch.Tensor
        Backpropagations. Unbound along dim 1 like ``A``; ``B.shape[0]`` is
        used as the batch size.
    batch_dim : int, optional
        Batch dimension position, forwarded to
        ``_create_or_extend_grad_sample``.
        NOTE(review): A and B are always split along dim 1 and the batch size
        is always read from dim 0, regardless of this value -- confirm
        callers only pass 0.
    """
    lstm_params = [
        layer.weight_ih_l0,
        layer.weight_hh_l0,
        layer.bias_ih_l0,
        layer.bias_hh_l0,
    ]
    lstm_out_dim = layer.hidden_size

    x = torch.unbind(A, dim=1)  # per-time-step inputs
    hooks_delta = torch.unbind(B, dim=1)  # per-time-step output gradients

    SEQ_LENGTH = len(x)
    BATCH_SIZE = B.shape[0]

    # Zero initial hidden / cell state used for step 0
    h_init = torch.zeros(1, BATCH_SIZE, lstm_out_dim, device=A.device)
    c_init = torch.zeros(1, BATCH_SIZE, lstm_out_dim, device=A.device)

    # Gradient flowing into h_t from step t+1; nothing flows into the last step
    delta_h = {}
    delta_h[SEQ_LENGTH - 1] = 0
    # Placeholders for the "next cell" quantities at the last step
    f_last = 0
    dc_last = 0
    # Backward through time: cell t needs f_t / dc_t of cell t+1 and c_t of
    # cell t-1; its return value is the incoming gradient for step t-1
    # (delta_h[-1], written at t == 0, is never read).
    for t in range(SEQ_LENGTH - 1, -1, -1):
        f_next = f_last if t == SEQ_LENGTH - 1 else layer.cells[t + 1].f_t
        dc_next = dc_last if t == SEQ_LENGTH - 1 else layer.cells[t + 1].dc_t
        c_prev = c_init if t == 0 else layer.cells[t - 1].c_t
        delta_h[t - 1] = layer.cells[t].backward(
            x[t], delta_h[t], hooks_delta[t], f_next, dc_next, c_prev
        )

    # Accumulate per-sample gradients over time; the int 0 start becomes a
    # tensor on the first +=.
    grad_sample = {param: 0 for param in lstm_params}
    for t in range(0, SEQ_LENGTH):
        h_prev = h_init[0, :] if t == 0 else layer.cells[t - 1].h_t[0, :]
        # "ij,ik->ijk": per-sample outer product (batch, gates) x (batch, in)
        grad_sample[layer.weight_ih_l0] += torch.einsum(
            "ij,ik->ijk", layer.cells[t].dgates_t, x[t]
        )
        grad_sample[layer.weight_hh_l0] += torch.einsum(
            "ij,ik->ijk", layer.cells[t].dgates_t, h_prev
        )
        # Both bias vectors receive the same per-sample gradient dgates_t
        grad_sample[layer.bias_ih_l0] += layer.cells[t].dgates_t
        grad_sample[layer.bias_hh_l0] += layer.cells[t].dgates_t

    for param, grad_value in grad_sample.items():
        # pyre-ignore[6]
        _create_or_extend_grad_sample(param, grad_value, batch_dim)
def RadiallyAverageFourierTransform(x, dim=2):
    """Radially average the Fourier magnitude of ``x``.

    ``SpaceToFourier(x, signal_dim=dim)`` is split along its last axis into
    real and imaginary parts; the magnitude ``sqrt(re**2 + im**2)`` is then
    passed to ``RadiallyAverage`` together with ``dim``.
    """
    spectrum = SpaceToFourier(x, signal_dim=dim)
    re_part, im_part = torch.unbind(spectrum, dim=-1)
    magnitude = (re_part ** 2 + im_part ** 2).sqrt()
    return RadiallyAverage(magnitude, dim)
def batch_ifftshift2d(x):
    """Shift a real/imag-stacked batch back from centred frequency layout.

    ``x`` is split along its last axis into real and imaginary parts (the
    two-way unpacking requires that axis to have size 2). Each part is
    rolled with ``roll_n`` by ``size // 2`` along every dimension except the
    batch dimension 0, from the last dimension down to 1. The parts are then
    re-stacked on a new last axis. The roll direction is whatever ``roll_n``
    implements.
    """
    real, imag = torch.unbind(x, -1)
    for axis in reversed(range(1, real.dim())):
        real = roll_n(real, axis=axis, n=real.size(axis) // 2)
        imag = roll_n(imag, axis=axis, n=imag.size(axis) // 2)
    # Last dim has size 2 again: (real, imag)
    return torch.stack((real, imag), -1)
def forward(self, v0):
    """Pair ``self.h0`` with ``v0``, run ``self.model`` and split the result.

    The two tensors are stacked along dim 1; the model output is unbound
    along dim 1 again into ``(h1, v1)``.
    """
    paired = torch.stack([self.h0, v0], dim=1)
    output = self.model(paired)
    hidden, visible = output.unbind(dim=1)
    return hidden, visible