def __call__(self, boxlists):
    """Assign each box to an FPN pyramid level index (0-based).

    Arguments:
        boxlists (list[BoxList])
    """
    # Box scale s = sqrt(area), gathered across all boxlists.
    areas = cat([boxlist.area() for boxlist in boxlists])
    scale = torch.sqrt(areas)
    # Eqn.(1) in FPN paper: k = floor(k0 + log2(s / s0)); eps guards log2(0).
    levels = torch.floor(self.lvl0 + torch.log2(scale / self.s0 + self.eps))
    levels = torch.clamp(levels, min=self.k_min, max=self.k_max)
    # Shift so the finest assigned level maps to index 0.
    return levels.to(torch.int64) - self.k_min
def __init__(self, output_size, scales, sampling_ratio):
    """
    Arguments:
        output_size (list[tuple[int]] or list[int]): output size for the pooled region
        scales (list[float]): scales for each Pooler
        sampling_ratio (int): sampling ratio for ROIAlign
    """
    super(Pooler, self).__init__()
    # One ROIAlign module per feature-map scale.
    self.poolers = nn.ModuleList([
        ROIAlign(output_size,
                 spatial_scale=scale,
                 sampling_ratio=sampling_ratio)
        for scale in scales
    ])
    self.output_size = output_size
    # get the levels in the feature map by leveraging the fact that the network always
    # downsamples by a factor of 2 at each level.
    lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
    lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
    self.map_levels = LevelMapper(lvl_min, lvl_max)
def normalize_state_reward(self, val):
    """Return log2(val) / 15 elementwise, with zero entries kept at zero."""
    was_zero = val == 0
    scaled = torch.log2(val) / 15
    # log2(0) is -inf; restore the zero positions explicitly.
    scaled[was_zero] = 0
    return scaled
def check_separability_plus(pathdir, filename):
    """Test whether the target function in ``pathdir/filename`` is additively
    separable, i.e. f(x) ~= g(x_I) + h(x_rest) for some split of the input
    variables into index groups I and rest.

    The data file is whitespace-separated with the dependent variable in the
    last column.  A trained ``SimpleNet`` for this file must already exist at
    ``results/NN_trained_models/models/<filename>.h5``.

    Returns:
        tuple: ``(min_error, best_i, best_j, best_mu, best_sigma)`` where
        ``best_i``/``best_j`` are the two index groups of the best split and
        ``mu``/``sigma`` summarize the log2 error distribution.
        ``(-1, -1, -1, -1, -1)`` when separation is impossible (single
        variable) or any step fails.
    """
    try:
        pathdir_weights = "results/NN_trained_models/models/"

        # load the data
        n_variables = np.loadtxt(pathdir + filename, dtype='str').shape[1] - 1
        variables = np.loadtxt(pathdir + filename, usecols=(0, ))

        if n_variables == 1:
            # if there is just one variable you have nothing to separate.
            # BUG FIX: return a 5-tuple like every other exit path so that
            # callers can always unpack five values.
            print(filename, "just one variable for ADD")
            return (-1, -1, -1, -1, -1)

        for j in range(1, n_variables):
            v = np.loadtxt(pathdir + filename, usecols=(j, ))
            variables = np.column_stack((variables, v))

        f_dependent = np.loadtxt(pathdir + filename, usecols=(n_variables, ))
        f_dependent = np.reshape(f_dependent, (len(f_dependent), 1))

        factors = torch.from_numpy(variables).float()
        product = torch.from_numpy(f_dependent).float()
        if is_cuda:
            factors = factors.cuda()
            product = product.cuda()

        # load the trained model and put it in evaluation mode
        model = SimpleNet(n_variables)
        if is_cuda:
            model = model.cuda()
        model.load_state_dict(torch.load(pathdir_weights + filename + ".h5"))
        model.eval()

        with torch.no_grad():
            # baseline: every input pinned at its median
            fact_vary = factors.clone()
            for k in range(len(factors[0])):
                fact_vary[:, k] = torch.full((len(factors), ),
                                             torch.median(factors[:, k]))

            # loop through all index combinations
            var_indices_list = np.arange(0, n_variables, 1)
            min_error = 1000
            best_i = []
            best_j = []
            best_mu = 0
            best_sigma = 0
            for i in range(1, n_variables):
                for j in combinations(var_indices_list, i):
                    fact_vary_one = factors.clone()
                    fact_vary_rest = factors.clone()
                    rest_indx = list(
                        filter(lambda x: x not in j, var_indices_list))
                    # vary only the chosen subset j (pin the rest at medians)
                    for t1 in rest_indx:
                        fact_vary_one[:, t1] = torch.full(
                            (len(factors), ), torch.median(factors[:, t1]))
                    # vary only the complement (pin subset j at medians)
                    for t2 in j:
                        fact_vary_rest[:, t2] = torch.full(
                            (len(factors), ), torch.median(factors[:, t2]))
                    # check if the equation is separable: for an additive
                    # split, f(x) + f(medians) == f(one varied) + f(rest varied)
                    sm = model(fact_vary_one) + model(fact_vary_rest)
                    list_errs = 2 * abs(product - sm + model(fact_vary))
                    error = torch.median(list_errs)
                    mu = torch.mean(torch.log2(1 + list_errs * 2**30))
                    sigma = torch.std(torch.log2(1 + list_errs * 2**30))
                    if error < min_error:
                        min_error = error
                        best_i = j
                        best_j = rest_indx
                        best_mu = mu
                        best_sigma = sigma

        return min_error, best_i, best_j, best_mu, best_sigma

    except Exception as e:
        print(e)
        return (-1, -1, -1, -1, -1)
def train(train_loader, val_loader, model, criterion, optimizer, lr_scheduler,
          start_iter, tb_logger):
    """Run the low-precision training loop, emulating `emulate_node` workers
    on a single node: gradients of `emulate_node` consecutive micro-batches
    are buffered, scaled/quantized, summed, and applied as one step.

    NOTE(review): this source arrived with its line structure flattened; the
    nesting below (one optimizer step / log / validation per completed group
    of `emulate_node` micro-batches) is the reconstruction consistent with
    `emulate_step` being reset inside the `emulate_node == emulate_step`
    branch — confirm against the original repository.
    """
    global args, rank, world_size, best_prec1, emulate_node
    global grad_exp, grad_man, param_exp, param_man

    batch_time = AverageMeter(args.print_freq)
    data_time = AverageMeter(args.print_freq)
    losses = AverageMeter(args.print_freq)

    model.train()

    end = time.time()
    curr_step = start_iter
    emulate_step = 0

    # One momentum slot per (FP32) master parameter, used by the LARS branch.
    momentum_buffer = []
    for master_p in master_params:
        momentum_buffer.append(torch.zeros_like(master_p))
    # Per-parameter list of per-micro-batch gradients.
    grad_buffer = []
    for param_g in model.parameters():
        grad_buffer.append([])

    for i, (input, target) in enumerate(train_loader):
        emulate_step += 1
        # curr_step counts *emulated* optimization steps, not micro-batches.
        if emulate_step == emulate_node:
            curr_step += 1
        if curr_step > args.max_iter:
            break
        current_lr = adjust_learning_rate(optimizer, curr_step)

        target = target.cuda()
        input_var = input.cuda()
        data_time.update(time.time() - end)

        output = model(input_var, rank)
        # Scale the loss so that summing over world_size * emulate_node
        # gradient shards yields the true mean.
        loss = criterion(output, target) / (world_size * emulate_node)

        reduced_loss = loss.data.clone()
        if args.dist:
            dist.all_reduce(reduced_loss)
        losses.update(float(reduced_loss.item()))

        model.zero_grad()
        loss.backward()
        # Stash this micro-batch's gradients; they are combined later.
        for idx, param in enumerate(model.parameters()):
            if param.grad is not None:
                grad_buffer[idx].append(param.grad.detach().clone().data)
        model.zero_grad()

        if emulate_node == emulate_step:
            emulate_step = 0
            # reduce all gradients with low precision
            for idx, param in enumerate(model.parameters()):
                if param.grad is not None:
                    if emulate_node == 1:
                        param.grad.data.copy_(grad_buffer[idx][0])
                        continue
                    # find maximum exponent across the buffered gradients
                    max_exp = -100
                    for val in grad_buffer[idx]:
                        t_exp = torch.log2(
                            torch.abs(val * args.emulate_node).max()).ceil(
                            ).detach().cpu().numpy()
                        if t_exp > max_exp:
                            max_exp = t_exp
                    # Shift gradients up to use the full exponent range of the
                    # grad_exp-bit format (Adaptive Precision Scaling).
                    upper_bound = 2**(args.grad_exp - 1) - 1
                    shift_factor = upper_bound - max_exp
                    if max_exp == -100 or not args.use_APS:
                        shift_factor = 0
                    for grad in grad_buffer[idx]:
                        grad.data.copy_(
                            float_quantize(grad * (2**shift_factor),
                                           args.grad_exp, args.grad_man))
                    # as we use a single node to emulate multi-node, we should
                    # first accumulate gradients within a single node and then
                    # communicate them in the distributed system
                    res = torch.zeros_like(grad_buffer[idx][0])
                    for val in grad_buffer[idx]:
                        res = float_quantize(res + val, args.grad_exp,
                                             args.grad_man)
                    # Undo the exponent shift after the quantized accumulation.
                    param.grad.data.copy_(res.data / (2**shift_factor))
            grad_buffer = []
            for param_g in model.parameters():
                grad_buffer.append([])

            if args.dist:
                sum_gradients(model,
                              use_APS=args.use_APS,
                              use_kahan=args.use_kahan,
                              grad_exp=args.grad_exp,
                              grad_man=args.grad_man)
            # Propagate low-precision model grads into the FP32 master copies.
            for model_p, master_p in zip(model_params, master_params):
                if model_p.grad is not None:
                    master_p.backward(model_p.grad.float())

            # update parameters
            if args.use_lars:
                for idx, master_p in enumerate(master_params):
                    if master_p.grad is not None:
                        # LARS: layer-wise trust ratio ||w|| / (||g|| + wd*||w||)
                        local_lr = master_p.norm(2) /\
                            (master_p.grad.data.norm(2)
                             + args.weight_decay * master_p.norm(2))
                        lars_coefficient = 0.001
                        local_lr = local_lr * lars_coefficient
                        momentum_buffer[idx] = args.momentum * momentum_buffer[idx].data \
                            + current_lr \
                            * local_lr \
                            * (master_p.grad.data
                               + args.weight_decay * master_p.data)
                        update = momentum_buffer[idx]
                        master_p.data.copy_(master_p - update)
            else:
                optimizer.step()
            # Copy updated master weights back into the working model.
            for model_p, master_p in zip(model_params, master_params):
                model_p.data.copy_(master_p.data)
            optimizer.zero_grad()

            batch_time.update(time.time() - end)
            end = time.time()

            if (curr_step == 1
                    or curr_step % args.print_freq == 0) and rank == 0:
                if tb_logger:
                    tb_logger.add_scalar('loss_train', losses.avg, curr_step)
                    tb_logger.add_scalar('lr', current_lr, curr_step)
                print('Iter: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'LR {lr:.4f}'.format(curr_step,
                                           args.max_iter,
                                           batch_time=batch_time,
                                           data_time=data_time,
                                           loss=losses,
                                           lr=current_lr))

            if curr_step % args.val_freq == 0 and curr_step != 0:
                val_loss, prec1, prec5 = validate(val_loader, model, criterion)
                if tb_logger:
                    tb_logger.add_scalar('loss_val', val_loss, curr_step)
                    tb_logger.add_scalar('acc1_val', prec1, curr_step)
                    tb_logger.add_scalar('acc5_val', prec5, curr_step)
                if rank == 0:
                    # remember best prec@1 and save checkpoint
                    is_best = prec1 > best_prec1
                    best_prec1 = max(prec1, best_prec1)
                    save_checkpoint(
                        {
                            'step': curr_step,
                            'arch': args.arch,
                            'state_dict': model.state_dict(),
                            'best_prec1': best_prec1,
                            'optimizer': optimizer.state_dict(),
                        }, is_best,
                        args.save_path + '/ckpt_' + str(curr_step))

    del momentum_buffer
    # Final validation after the training loop ends.
    val_loss, prec1, prec5 = validate(val_loader, model, criterion)
def main():
    """Analyze saved activation tensors (.npy files in args.dir): compute
    per-layer activation sparsity and bit sparsity, print a summary table,
    and write per-instance bit sparsity to act_bs_analysis.csv."""
    args = parser.parse_args()
    assert os.path.exists(args.dir), '{} does not exist'.format(args.dir)
    if args.complement:
        print('Use twos complement representation')
    else:
        print('Use true form')
    fs = os.listdir(args.dir)
    activations = {}
    # Load every .npy dump in the directory as a tensor keyed by filename.
    for f in fs:
        if '.npy' in f:
            activations[f] = torch.from_numpy(
                np.load(os.path.join(args.dir, f)))
    col_keys = ['Layer', 'Activation sparsity', 'Bit sparsity']
    data = []
    # total_cnt = 0
    total_conv_cnt = 0
    total_weight_cnt = 0
    total_weight_conv_cnt = 0
    total_bit_cnt = 0
    total_bit_conv_cnt = 0
    instance_bs_dict = {}
    batch_size = -1
    for k, v in activations.items():
        batch_size = v.shape[0]
        v_reshape = v.view(batch_size, -1)
        # Integer length from the 33rd-percentile per-sample max value.
        il = torch.log2(v_reshape.max(1)[0].sort()[0][int(batch_size / 3)]) + 1
        il = math.ceil(il - 1e-5)
        radix_position = 8 - il
        print(radix_position)
        # NOTE(review): the computed radix_position is printed but then
        # overridden with a fixed 7 — looks like a debugging/experiment
        # leftover; confirm which value is intended.
        radix_position = 7
        _, value_int = truncation(v, radix_position)
        cnt_sum = v.view(-1).shape[0]
        # total_cnt += cnt_sum
        total_weight_cnt += (value_int.float().abs() > 0).sum().float()
        bit_cnt = count_bit(value_int, complement=args.complement)
        total_bit_cnt += bit_cnt.sum().float()
        # Fraction of exactly-zero activations.
        value_sparsity = 1 - (v_reshape.float().abs() >
                              0).sum().float() / cnt_sum
        bit_sparsity = bit_sparse(bit_cnt, args.complement)
        # Bit sparsity per sample in the batch, for the CSV report.
        instance_bs = []
        for i in range(batch_size):
            instance_bs.append(bit_sparse(bit_cnt[i], args.complement).item())
        instance_bs_dict[k] = instance_bs
        total_conv_cnt += cnt_sum
        total_weight_conv_cnt += (value_int.float().abs() > 0).sum().float()
        total_bit_conv_cnt += bit_cnt.sum().float()
        data.append([
            k, '{:.3f}'.format(value_sparsity), '{:.3f}'.format(bit_sparsity)
        ])
    pandas.set_option('display.width', 5000)
    df = pandas.DataFrame(data=data, columns=col_keys)
    print(df)
    # Build the per-instance table: one row per batch element, one column
    # per layer file.
    instance_data = []
    instance_keys = ['instance_id']
    for k, v in instance_bs_dict.items():
        instance_keys.append(k)
    for i in range(batch_size):
        in_data = [i]
        for layer_id in instance_keys[1:]:
            in_data.append('{:.3f}'.format(instance_bs_dict[layer_id][i]))
        instance_data.append(in_data)
    instance_df = pandas.DataFrame(data=instance_data, columns=instance_keys)
    # print(instance_df)
    instance_df.to_csv(os.path.join(args.dir, 'act_bs_analysis.csv'),
                       index=None)
    print('act_bs_analysis.csv has been saved in {}'.format(args.dir))
def shift(x, ceil=False):
    """Divide x by the power of two nearest its largest absolute entry.

    With ceil=True the exponent is rounded up instead of to nearest.
    """
    exponent = torch.log2(x.abs().max())
    rounded = torch.ceil(exponent) if ceil else torch.round(exponent)
    return x / 2.**rounded
def ndcg(gt_item, pred_items):
    """NDCG of a single ground-truth item against a ranked prediction tensor.

    Returns 1 / log2(rank + 2) as a tensor when the item is present,
    otherwise the int 0.
    """
    if gt_item not in pred_items:
        return 0
    rank = pred_items.tolist().index(gt_item)
    position = torch.tensor(rank, dtype=torch.float32)
    return torch.reciprocal(torch.log2(position + 2))
def kl_div(d1, d2):
    """
    Compute KL-Divergence (base 2) between d1 and d2, summed over axis 1.
    """
    ratio_logs = d1 * torch.log2(d1 / d2)
    # Entries where d1 == 0 contribute 0 by convention (masks NaN from 0*log 0).
    safe_terms = torch.where(d1 != 0, ratio_logs, torch.zeros_like(d1))
    return torch.sum(safe_terms, axis=1)
def quanz_minmax_param(m_min, m_max, bit, is_scale, is_offset):
    r'''Compute quantization parameters from observed min/max values.

    Args:
        m_min (torch.tensor): minimum of the original data
        m_max (torch.tensor): maximum of the original data
        bit (torch.tensor): bit width used for quantization
        is_scale (bool): whether the quantization uses a scale
        is_offset (bool): whether the quantization uses an offset
            (asymmetric); otherwise the range is symmetric around zero

    Returns:
        m_position (torch.tensor): quantization parameter position
        m_scale (torch.tensor): quantization parameter scale
        m_offset (torch.tensor): quantization parameter offset
    '''
    device = m_max.device
    o_max = m_max.clone().detach()
    o_min = m_min.clone().detach()
    n = torch.tensor(bit).float().to(device)
    if is_offset:
        # Ensure the max is never below 0 and the min never above 0 so the
        # representable interval always contains zero.
        m_max = torch.max(
            o_max, torch.zeros_like(o_max, dtype=torch.float32,
                                    device=device))
        m_min = torch.min(
            o_min, torch.zeros_like(o_min, dtype=torch.float32,
                                    device=device))
        interval = m_max - m_min
        # if interval is zero, the offset will inf, use 0.0 instead
        m_offset = torch.where(
            interval > 0.0,
            m_round(-torch.pow(2.0, n - 1.0) - m_min * 2**1.0 *
                    (torch.pow(2.0, n - 1.0) - 2**-1.0) / interval),
            torch.zeros_like(interval, dtype=torch.float32, device=device))
        # position = floor(log2(interval)) - (n - 1): exponent aligning the
        # n-bit integer grid with the data range.
        m_position = torch.where(
            interval > 0.0,
            torch.floor(torch.log2(interval)) - (n - 1.0),
            torch.zeros_like(interval, dtype=torch.float32, device=device))
        if is_scale:
            # if interval is zero, the scale will inf, use 0.0 instead
            m_scale = torch.where(
                interval > 0.0,
                (torch.pow(2.0, m_position + n) - torch.pow(2.0, m_position))
                / interval,
                torch.ones_like(interval, dtype=torch.float32, device=device))
        else:
            m_scale = torch.ones_like(interval,
                                      dtype=torch.float32,
                                      device=device)
    else:
        # Symmetric case: take the larger magnitude of the two half-axes.
        m_max = torch.max(o_max, -o_min)
        m_offset = torch.zeros_like(m_max, dtype=torch.float32, device=device)
        m_position = torch.where(
            m_max > 0.0,
            torch.floor(torch.log2(m_max)) - (n - 2.0),
            torch.zeros_like(m_max, dtype=torch.float32, device=device))
        if is_scale:
            # if interval is zero, the scale will inf, use 0.0 instead
            m_scale = torch.where(
                m_max > 0.0,
                (torch.pow(2.0, m_position + n - 1.0) -
                 torch.pow(2.0, m_position)) / m_max,
                torch.ones_like(m_max, dtype=torch.float32, device=device))
        else:
            m_scale = torch.ones_like(m_max,
                                      dtype=torch.float32,
                                      device=device)
    return m_position, m_scale, m_offset
def get_MIs(
    X,
    y,
    noise_amp_all,
    group_sizes,
    mode="xn-y",
    noise_type="uniform-series",
    estimate_method="k-nearest",
):
    """Estimate mutual information between (noised) input groups and targets.

    X, y: 3-D tensors of shape (batch, K, *); X's last dim is split into
    `num_models` groups of `group_sizes` columns.  For each noise amplitude
    in `noise_amp_all`, Gaussian noise scaled by X's per-feature std is added
    and MI is estimated per group, either with the k-nearest estimator
    (`ee.mi`) or a Gaussian closed form.

    mode selects the pair: "xn-y" (noised X vs y), "x-y" (X vs y),
    "xn-x" (noised X vs X).

    Returns an np.ndarray of shape (num, num_models) for "uniform-series"
    noise, or (num, K, num_models) for "fully-random" noise.

    NOTE(review): assumes group_sizes is an int whenever the slicing
    `j * group_sizes:(j + 1) * group_sizes` runs — the list-valued branch
    only changes num_models; confirm with callers.
    """
    assert len(X.shape) == len(y.shape) == 3
    _, K, N = X.shape
    if isinstance(group_sizes, int):
        num_models = int(N / group_sizes)
    else:
        num_models = len(group_sizes)
    num = noise_amp_all.size(0)
    if noise_type == "uniform-series":
        MI = np.zeros((num, num_models))
    elif noise_type == "fully-random":
        MI = np.zeros((num, K, num_models))
    else:
        raise
    X_std = X.std(0)
    is_cuda = X.is_cuda
    device = torch.device("cuda" if is_cuda else "cpu")
    if noise_type == "uniform-series":
        for i in range(num):
            # One noise amplitude per group, expanded across group columns.
            noise_amp_core = X_std * expand_tensor(noise_amp_all[i].to(device),
                                                   -1, group_sizes)
            X_tilde = X + torch.randn(X.size()).to(device) * noise_amp_core
            if mode == "xn-y":
                arg1 = y
                arg2 = X_tilde
            elif mode == "x-y":
                arg1 = y
                arg2 = X
            elif mode == "xn-x":
                arg1 = X_tilde
                arg2 = X
            else:
                raise
            for j in range(num_models):
                if mode == "xn-x":
                    if estimate_method == "k-nearest":
                        MI[i, j] = ee.mi(
                            to_np_array(arg1[:, :, j * group_sizes:(j + 1) *
                                             group_sizes].contiguous().view(
                                                 arg1.size(0), -1)),
                            to_np_array(arg2[:, :, j * group_sizes:(j + 1) *
                                             group_sizes].contiguous().view(
                                                 arg2.size(0), -1)))
                    elif estimate_method == "Gaussian":
                        # I(X_tilde; X) = H(X_tilde) - H(noise) for additive
                        # Gaussian noise.
                        entropy_X_tilde = get_entropy_Gaussian(
                            arg1[:, :, j * group_sizes:(j + 1) *
                                 group_sizes].contiguous().view(
                                     arg1.size(0), -1),
                            is_diagonal=False)
                        KM = group_sizes * K
                        entropy_noise = (
                            KM / float(2) * np.log2(2 * np.pi * np.e) +
                            torch.log2(noise_amp_core[:, j]).sum())
                        MI[i, j] = entropy_X_tilde - entropy_noise
                    else:
                        raise
                else:
                    if estimate_method == "k-nearest":
                        MI[i, j] = ee.mi(
                            to_np_array(arg1[:, :, i * group_sizes:(i + 1) *
                                             group_sizes].contiguous().view(
                                                 arg1.size(0), -1)),
                            to_np_array(arg2[:, :, j * group_sizes:(j + 1) *
                                             group_sizes].contiguous().view(
                                                 arg2.size(0), -1)))
                    elif estimate_method == "Gaussian":
                        MI[i, j] = get_entropy_Gaussian(
                            arg1[:, :, i * group_sizes:(i + 1) *
                                 group_sizes].contiguous().view(
                                     arg1.size(0), -1),
                            arg2[:, :, j * group_sizes:(j + 1) *
                                 group_sizes].contiguous().view(
                                     arg2.size(0), -1))
                    else:
                        raise
    elif noise_type == "fully-random":
        for i in range(num):
            # Independent noise amplitude per element (no group expansion).
            noise_amp_core = X_std * noise_amp_all[i].to(device)
            X_tilde = X + torch.randn(X.size()).to(device) * noise_amp_core
            if mode == "xn-y":
                arg1 = y
                arg2 = X_tilde
            elif mode == "x-y":
                arg1 = y
                arg2 = X
            elif mode == "xn-x":
                arg1 = X_tilde
                arg2 = X
            else:
                raise
            for k in range(K):
                for j in range(num_models):
                    if mode == "xn-x":
                        MI[i, k, j] = ee.mi(
                            to_np_array(arg1[:, k, j * group_sizes:(j + 1) *
                                             group_sizes]),
                            to_np_array(arg2[:, k, j * group_sizes:(j + 1) *
                                             group_sizes]))
                    else:
                        MI[i, k, j] = ee.mi(
                            to_np_array(arg1[:, :, i * group_sizes:(i + 1) *
                                             group_sizes].contiguous().view(
                                                 arg1.size(0), -1)),
                            to_np_array(arg2[:, k, j * group_sizes:(j + 1) *
                                             group_sizes]))
    else:
        raise
    return MI
def shift(x, ceil=True):
    """Round every entry of x to a power of two (up by default,
    to nearest with ceil=False)."""
    # TODO: edge case, when x contains 0
    log_x = torch.log2(x)
    rounding = torch.ceil if ceil else torch.round
    return 2.**rounding(log_x)
def pow_2_round(dims):
    """Round each entry of `dims` to the nearest power of two (float tensor)."""
    exponents = torch.round(torch.log2(dims.type(torch.float)))
    return 2 ** exponents
def forward(self, input):
    """Fixed-point quantized conv2d with straight-through-estimator backprop
    and bit-sparsity-aware weight pruning during training."""
    # Quantization disabled: plain float convolution.
    if self.radix_position is None:
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
    # Signed n-bit integer range.
    Qn = -2**(self.nbits - 1)
    Qp = 2**(self.nbits - 1) - 1
    if self.init_state == 0:
        # First call: derive the radix position from the current weight
        # magnitude (integer length il = bits before the binary point).
        il = torch.log2(self.weight.abs().max()) + 1
        il = math.ceil(il - 1e-5)
        self.radix_position.data.fill_(self.nbits - il)
        print('Initialize radix position of {} with {}'.format(
            self._get_name(), int(self.radix_position.item())))
        alpha = 2**self.radix_position
        w_int = round_pass((self.weight * alpha).clamp(Qn, Qp))
        w_q = w_int / alpha
        self.weight_int.data.copy_(w_int)
        self.weight_old.data.copy_(self.weight)
        if self.training:
            if self.expected_bit_sparsity is None:
                # Target bit sparsity is derived from the initial weights.
                bit_cnt = count_bit(w_int)
                original_bit_sparsity = bit_sparse(bit_cnt)
                self.expected_bit_sparsity = self.expected_bit_sparsity_func(
                    original_bit_sparsity)
                print('original: {:.3f} expected: {:.3f}'.format(
                    original_bit_sparsity, self.expected_bit_sparsity))
            # Mask of weights that survived quantization (non-zero ints).
            self.mask = (w_int.abs() > 0).float()
        self.init_state.fill_(1)
    else:
        # quantize weight
        alpha = 2**self.radix_position
        w_int = round_pass((self.weight * alpha).clamp(Qn, Qp))
        w_q = w_int / alpha
        if self.training:
            bit_cnt_old = count_bit(self.weight_int)
            # bit_sparsity_new = bit_sparse(bit_cnt_new, self.complement)
            bit_sparsity_old = bit_sparse(bit_cnt_old)
            if bit_sparsity_old < self.expected_bit_sparsity:
                # need bit pruning: reject weight updates that would add bits
                bit_cnt_new = count_bit(w_int)
                bit_increase = bit_cnt_new - bit_cnt_old
                case = (bit_increase > 0)  # todo: bug always False
                w_q = torch.where(
                    case,
                    self.weight_int.float() * 2**(-self.radix_position), w_q)
                # don't work
                self.weight.data.copy_(
                    torch.where(case, self.weight_old, self.weight))
                self.weight_old.data.copy_(self.weight)
                self.weight_int.data.copy_(w_q * 2**self.radix_position)
            else:
                # don't need bit pruning
                # print('do not need bit pruning')
                # use new weights
                self.weight_old.data.copy_(self.weight)
                self.weight_int.data.copy_(w_int)
        else:
            # inference: always use the stored integer weights
            w_q = self.weight_int.data.float() * 2**(-self.radix_position)
    # weight has no grad, why? (update optimizer after wrapper.replacement)
    weight_mask = FunctionStopGradient.apply(self.weight, self.mask)
    # weight_mask = self.weight * self.mask
    # STE for quantized weight.
    weight_bp = w_q.detach() + weight_mask - weight_mask.detach()
    out = F.conv2d(input, weight_bp, self.bias, self.stride, self.padding,
                   self.dilation, self.groups)
    return out
def projection_linf(self, points_to_project, w_hyperplane, b_hyperplane):
    """Project each point onto the hyperplane {x : <w, x> = b}, minimizing
    the L-inf distance subject to the box constraint x in [0, 1]^d.

    Arguments:
        points_to_project: (batch, d) points to project
        w_hyperplane: (batch, d) hyperplane normals
        b_hyperplane: (batch,) hyperplane offsets

    Returns:
        (batch, d) displacement d such that points + d is the projection.
    """
    t = points_to_project.clone()
    w = w_hyperplane.clone()
    b = b_hyperplane.clone()

    # Flip hyperplanes so every point starts on the negative side.
    ind2 = ((w * t).sum(1) - b < 0).nonzero().squeeze()
    ind2 = self.check_shape(ind2)
    w[ind2] *= -1
    b[ind2] *= -1

    c5 = (w < 0).float()
    a = torch.ones(t.shape).to(self.device)
    d = (a * c5 - t) * (w != 0).float()
    a -= a * (1 - c5)
    # Distance of each coordinate to its box face in the moving direction.
    p = torch.ones(t.shape).to(self.device) * c5 - t * (2 * c5 - 1)
    indp = torch.argsort(p, dim=1)

    b = b - (w * t).sum(1)
    b0 = (w * d).sum(1)
    b1 = b0.clone()

    counter = 0
    indp2 = indp.unsqueeze(-1).flip(dims=(1, 2)).squeeze()
    u = torch.arange(0, w.shape[0])
    ws = w[u.unsqueeze(1), indp2]
    bs2 = -ws * d[u.unsqueeze(1), indp2]

    # Cumulative sums over coordinates sorted by saturation order.
    s = torch.cumsum(ws.abs(), dim=1)
    sb = torch.cumsum(bs2, dim=1) + b0.unsqueeze(1)

    c = b - b1 > 0
    b2 = sb[u, -1] - s[u, -1] * p[u, indp[u, 0]]
    # c_l: points whose projection saturates every coordinate;
    # c2: points needing a binary search for the saturation threshold.
    c_l = (b - b2 > 0).nonzero().squeeze()
    c2 = ((b - b1 > 0) * (b - b2 <= 0)).nonzero().squeeze()
    c_l = self.check_shape(c_l)
    c2 = self.check_shape(c2)

    # Binary search over sorted coordinates for the optimal multiplier.
    lb = torch.zeros(c2.shape[0])
    ub = torch.ones(c2.shape[0]) * (w.shape[1] - 1)
    nitermax = torch.ceil(torch.log2(torch.tensor(w.shape[1]).float()))
    counter2 = torch.zeros(lb.shape).long()

    while counter < nitermax:
        counter4 = torch.floor((lb + ub) / 2)
        counter2 = counter4.long()
        indcurr = indp[c2, -counter2 - 1]
        b2 = sb[c2, counter2] - s[c2, counter2] * p[c2, indcurr]
        c = b[c2] - b2 > 0
        ind3 = c.nonzero().squeeze()
        ind32 = (~c).nonzero().squeeze()
        ind3 = self.check_shape(ind3)
        ind32 = self.check_shape(ind32)
        lb[ind3] = counter4[ind3]
        ub[ind32] = counter4[ind32]
        counter += 1

    lb = lb.long()
    counter2 = 0

    # BUG FIX: the original tested `c_l.nelement != 0`, which compares the
    # bound method object to 0 and is therefore always True.  Call the
    # method, matching projection_l2.  (With empty c_l the block was a
    # harmless no-op, so this only skips wasted work — results unchanged.)
    if c_l.nelement() != 0:
        lmbd_opt = (torch.max(
            (b[c_l] - sb[c_l, -1]) / (-s[c_l, -1]),
            torch.zeros(sb[c_l, -1].shape).to(self.device))).unsqueeze(-1)
        d[c_l] = (2 * a[c_l] - 1) * lmbd_opt

    lmbd_opt = (torch.max(
        (b[c2] - sb[c2, lb]) / (-s[c2, lb]),
        torch.zeros(sb[c2, lb].shape).to(self.device))).unsqueeze(-1)
    d[c2] = torch.min(lmbd_opt, d[c2]) * c5[c2]\
        + torch.max(-lmbd_opt, d[c2]) * (1 - c5[c2])

    return d * (w != 0).float()
def projection_l2(self, points_to_project, w_hyperplane, b_hyperplane):
    """Project each point onto the hyperplane {x : <w, x> = b}, minimizing
    the L2 distance subject to the box constraint x in [0, 1]^d.

    Arguments:
        points_to_project: (batch, d) points to project
        w_hyperplane: (batch, d) hyperplane normals
        b_hyperplane: (batch,) hyperplane offsets

    Returns:
        (batch, d) displacement d such that points + d is the projection.
    """
    t = points_to_project.clone()
    w = w_hyperplane.clone()
    b = b_hyperplane.clone()

    # Flip hyperplanes so every point starts on the negative side.
    c = (w * t).sum(1) - b
    ind2 = (c < 0).nonzero().squeeze()
    ind2 = self.check_shape(ind2)
    w[ind2] *= -1
    c[ind2] *= -1

    u = torch.arange(0, w.shape[0]).unsqueeze(1)

    # Per-coordinate multiplier at which the box constraint becomes active;
    # clipped to +/-1e12 to avoid inf from division by tiny weights.
    r = torch.max(t / w, (t - 1) / w)
    u2 = torch.ones(r.shape).to(self.device)
    r = torch.min(r, 1e12 * u2)
    r = torch.max(r, -1e12 * u2)
    r[w.abs() < 1e-8] = 1e12
    r[r == -1e12] = -r[r == -1e12]
    rs, indr = torch.sort(r, dim=1)
    rs2 = torch.cat(
        (rs[:, 1:], torch.zeros(rs.shape[0], 1).to(self.device)), 1)
    rs[rs == 1e12] = 0
    rs2[rs2 == 1e12] = 0

    # Cumulative squared-weight mass over coordinates in activation order.
    w3 = w**2
    w3s = w3[u, indr]
    w5 = w3s.sum(dim=1, keepdim=True)
    ws = w5 - torch.cumsum(w3s, dim=1)
    d = -(r * w).clone()
    d = d * (w.abs() > 1e-8).float()
    s = torch.cat(
        ((-w5.squeeze() * rs[:, 0]).unsqueeze(1), torch.cumsum(
            (-rs2 + rs) * ws, dim=1) - w5 * rs[:, 0].unsqueeze(-1)), 1)

    # c4: unconstrained projection already feasible; c2: needs binary search.
    c4 = (s[:, 0] + c < 0)
    c3 = ((d * w).sum(dim=1) + c > 0)
    c6 = c4.nonzero().squeeze()
    c2 = ((1 - c4.float()) * (1 - c3.float())).nonzero().squeeze()
    c6 = self.check_shape(c6)
    c2 = self.check_shape(c2)

    # Binary search over sorted activation thresholds.
    counter = 0
    lb = torch.zeros(c2.shape[0])
    ub = torch.ones(c2.shape[0]) * (w.shape[1] - 1)
    nitermax = torch.ceil(torch.log2(torch.tensor(w.shape[1]).float()))
    counter2 = torch.zeros(lb.shape).long()

    while counter < nitermax:
        counter4 = torch.floor((lb + ub) / 2)
        counter2 = counter4.long()
        c3 = s[c2, counter2] + c[c2] > 0
        ind3 = c3.nonzero().squeeze()
        ind32 = (~c3).nonzero().squeeze()
        ind3 = self.check_shape(ind3)
        ind32 = self.check_shape(ind32)
        lb[ind3] = counter4[ind3]
        ub[ind32] = counter4[ind32]
        counter += 1

    lb = lb.long()

    alpha = torch.zeros([1])

    if c6.nelement() != 0:
        # Unconstrained case: closed-form Lagrange multiplier.
        alpha = c[c6] / w5[c6].squeeze(-1)
        d[c6] = -alpha.unsqueeze(-1) * w[c6]

    if c2.nelement() != 0:
        alpha = (s[c2, lb] + c[c2]) / ws[c2, lb] + rs[c2, lb]
        # Guard against division by a zero residual weight mass.
        if torch.sum(ws[c2, lb] == 0) > 0:
            ind = (ws[c2, lb] == 0).nonzero().squeeze().long()
            ind = self.check_shape(ind)
            alpha[ind] = 0
        c5 = (alpha.unsqueeze(-1) > r[c2]).float()
        d[c2] = d[c2] * c5 - alpha.unsqueeze(-1) * w[c2] * (1 - c5)

    return d * (w.abs() > 1e-8).float()
def inner_train(self, batch_preds, batch_stds, **kwargs):
    '''
    per-query training process (LambdaLoss-style NDCG surrogate)
    :param batch_preds: [batch, ranking_size] each row represents the relevance predictions for documents within a ltr_adhoc
    :param batch_stds: [batch, ranking_size] each row represents the standard relevance grades for documents within a ltr_adhoc
    :return: batch_loss (scalar tensor); also performs one optimizer step
    '''
    label_type = kwargs['label_type']
    assert label_type == LABEL_TYPE.MultiLabel

    # Ensure labels are in descending order; reorder predictions to match.
    if 'presort' in kwargs and kwargs['presort']:
        target_batch_preds, target_batch_stds = batch_preds, batch_stds
    else:
        target_batch_stds, batch_sorted_inds = torch.sort(batch_stds,
                                                          dim=1,
                                                          descending=True)
        target_batch_preds = torch.gather(batch_preds,
                                          dim=1,
                                          index=batch_sorted_inds)

    batch_preds_sorted, batch_preds_sorted_inds = torch.sort(
        target_batch_preds, dim=1, descending=True
    )  # sort documents according to the predicted relevance
    batch_stds_sorted_via_preds = torch.gather(
        target_batch_stds, dim=1, index=batch_preds_sorted_inds
    )  # reorder batch_stds correspondingly so as to make it consistent.
    # BTW, batch_stds[batch_preds_sorted_inds] only works with 1-D tensor

    batch_std_ranks = torch.arange(target_batch_preds.size(1)).type(
        torch.cuda.FloatTensor) if self.gpu else torch.arange(
            target_batch_preds.size(1)).type(torch.FloatTensor)
    dists_1D = 1.0 / torch.log2(
        batch_std_ranks + 2.0)  # discount co-efficients

    # ideal dcg values based on optimal order
    batch_idcgs = torch_dcg_at_k(batch_sorted_labels=target_batch_stds,
                                 gpu=self.gpu)

    if label_type == LABEL_TYPE.MultiLabel:
        batch_gains = torch.pow(2.0, batch_stds_sorted_via_preds) - 1.0
    elif label_type == LABEL_TYPE.Permutation:
        batch_gains = batch_stds_sorted_via_preds
    else:
        raise NotImplementedError

    batch_n_gains = batch_gains / batch_idcgs  # normalised gains

    # Pairwise weights depend on the chosen LambdaLoss variant.
    if 'NDCG_Loss1' == self.loss_type:
        power_weights = ndcg_loss1_power_weights(batch_n_gains=batch_n_gains,
                                                 discounts=dists_1D)
    elif 'NDCG_Loss2' == self.loss_type:
        power_weights = ndcg_loss2_power_weights(batch_n_gains=batch_n_gains,
                                                 discounts=dists_1D)
    elif 'NDCG_Loss2++' == self.loss_type:
        power_weights = ndcg_loss2plusplus_power_weights(
            batch_n_gains=batch_n_gains, discounts=dists_1D, mu=self.mu)

    batch_pred_diffs = (
        torch.unsqueeze(batch_preds_sorted, dim=2) -
        torch.unsqueeze(batch_preds_sorted, dim=1)).clamp(
            min=-1e8, max=1e8)  # computing pairwise differences, i.e., s_i - s_j
    batch_pred_diffs[torch.isnan(batch_pred_diffs)] = 0.

    # Pairwise probabilities raised to the lambda weights; epsilon-clamped
    # twice so the log2 below never sees 0.
    weighted_probas = (torch.sigmoid(self.sigma * batch_pred_diffs).clamp(
        min=epsilon)**power_weights).clamp(min=epsilon)
    log_weighted_probas = torch.log2(weighted_probas)

    # mask for truncation based on cutoff k
    trunc_mask = torch.zeros(
        (target_batch_preds.shape[1], target_batch_preds.shape[1]),
        dtype=torch.bool,
        device=self.device)
    trunc_mask[:self.k, :self.k] = 1

    if self.loss_type in ['NDCG_Loss2', 'NDCG_Loss2++']:
        batch_std_diffs = torch.unsqueeze(
            batch_stds_sorted_via_preds, dim=2) - torch.unsqueeze(
                batch_stds_sorted_via_preds,
                dim=1)  # standard pairwise differences, i.e., S_{ij}
        # Only pairs where the first document is strictly more relevant.
        padded_pairs_mask = batch_std_diffs > 0
        padded_log_weighted_probas = log_weighted_probas[padded_pairs_mask
                                                         & trunc_mask]
    else:
        padded_log_weighted_probas = log_weighted_probas[trunc_mask[
            None, :, :]]

    batch_loss = -torch.sum(padded_log_weighted_probas)

    self.optimizer.zero_grad()
    batch_loss.backward()
    self.optimizer.step()

    return batch_loss
def logisticloss(D):
    """ k-way logistic loss """
    summed_exp = torch.exp(D).squeeze(-1).sum(-1)
    return torch.log2(1 + summed_exp)
def get_complexity(state, obs, sentid):
    ''' Generates complexity output for given state, observation, and sentid.

    Prints one csep-joined line per (non-EOS) token: word, sentid, position,
    word length, surprisal, entropy, entropy reduction, and optionally the
    model's top guesses with scores/ratios/probabilities.
    '''
    # Convert nats to bits: log2(exp(x)) = x / ln(2).
    Hs = torch.log2(torch.exp(torch.squeeze(apply(get_entropy, state))))
    surps = torch.log2(torch.exp(apply(get_surps, state)))

    if args.guess:
        guesses = apply(get_guesses, state)
        guessscores = apply(get_guessscores, state)

    for corpuspos, targ in enumerate(obs):
        word = corpus.dictionary.idx2word[int(targ)]
        if word == '<eos>':
            # don't output the complexity of EOS
            continue
        # Surprisal of the actual next token under the model.
        surp = surps[corpuspos][int(targ)]
        if args.guess:
            outputguesses = []
            for guess_ix in range(args.guessn):
                outputguesses.append(corpus.dictionary.idx2word[int(
                    guesses[corpuspos][guess_ix])])
                if args.guessscores:
                    # output raw scores
                    outputguesses.append("{:.3f}".format(
                        float(guessscores[corpuspos][guess_ix])))
                elif args.guessratios:
                    # output scores (ratio of score(x)/score(best guess)
                    outputguesses.append("{:.3f}".format(
                        float(guessscores[corpuspos][guess_ix]) /
                        float(guessscores[corpuspos][0])))
                elif args.guessprobs:
                    # output probabilities
                    # Currently normalizes probs over N-best list;
                    # ideally it'd normalize to probs before getting the N-best
                    outputguesses.append("{:.3f}".format(
                        math.exp(
                            float(
                                nn.functional.log_softmax(
                                    guessscores[corpuspos],
                                    dim=0)[guess_ix]))))
            outputguesses = args.csep.join(outputguesses)
            print(
                args.csep.join([
                    str(word),
                    str(sentid),
                    str(corpuspos),
                    str(len(word)),
                    str(float(surp)),
                    str(float(Hs[corpuspos])),
                    # Entropy reduction, floored at 0.
                    str(
                        max(
                            0,
                            float(Hs[max(corpuspos - 1, 0)]) -
                            float(Hs[corpuspos]))),
                    str(outputguesses)
                ]))
        else:
            print(
                args.csep.join([
                    str(word),
                    str(sentid),
                    str(corpuspos),
                    str(len(word)),
                    str(float(surp)),
                    str(float(Hs[corpuspos])),
                    str(
                        max(
                            0,
                            float(Hs[max(corpuspos - 1, 0)]) -
                            float(Hs[corpuspos])))
                ]))