class nice(nn.Module):
    """NICE-style flow: three additive coupling layers followed by a learned
    per-dimension log-scale.

    ``forward(x)`` runs the couplings cp1 -> cp2 -> cp3 and multiplies by
    ``exp(scale)``; ``forward(x, inv=True)`` divides by ``exp(scale)`` and runs
    the couplings backwards (cp3 -> cp2 -> cp1), passing ``True`` to each.
    """

    def __init__(self, im_size):
        super().__init__()
        # Coupling layers (hidden width 1024); registration order is kept so
        # that parameter/state_dict ordering is unchanged.
        self.cp1 = additive_alternant_coupling_layer(im_size, 1024)
        self.cp2 = additive_alternant_coupling_layer(im_size, 1024)
        self.cp3 = additive_alternant_coupling_layer(im_size, 1024)
        # Log of the diagonal scaling applied after the last coupling layer.
        self.scale = Parameter(torch.zeros(im_size))
        self.to(DEVICE)

    def forward(self, x, inv=False):
        if inv:
            # Undo the scaling first, then invert the couplings in reverse order.
            z = x * torch.exp(-self.scale)
            for coupling in (self.cp3, self.cp2, self.cp1):
                z = coupling(z, True)
            return z
        h = x
        for coupling in (self.cp1, self.cp2, self.cp3):
            h = coupling(h)
        return torch.exp(self.scale) * h

    def log_logistic(self, h):
        """Element-wise log-density of the standard logistic distribution."""
        return -F.softplus(h) - F.softplus(-h)

    def train_loss(self, h):
        """Negative log-likelihood: logistic prior summed over dim 1 and
        averaged over the batch, plus the log-determinant ``scale.sum()``."""
        log_prior = self.log_logistic(h).sum(1).mean()
        return -(log_prior + self.scale.sum())
class PopArt(Module):
    """PopArt http://papers.nips.cc/paper/6076-learning-values-across-many-orders-of-magnitude

    Keeps moving estimates of the mean and second moment of the targets and,
    once ``updates >= start_pop``, rescales the wrapped output layer(s) so that
    ``unnormalize(layer(x))`` is unchanged by each statistics update.

    Args:
        output_layer: a layer with ``weight`` and ``bias`` (the weight's first
            dimension must match the bias), or a tuple/list/``ModuleList`` of
            such layers whose bias shapes are identical.
        beta: step size of the moving averages.
        zero_debias: if True the step size is ``max(1 / (updates + 1), beta)``,
            so early estimates are exact running means of all data seen.
        start_pop: number of ``update`` calls before the output layers start
            being rescaled.
    """

    def __init__(self, output_layer, beta: float = 0.0003, zero_debias: bool = True, start_pop: int = 8):
        # zero_debias=True and start_pop=8 seem to improve things a little but (False, 0) works as well
        super().__init__()
        self.start_pop = start_pop
        self.beta = beta
        self.zero_debias = zero_debias
        self.output_layers = (output_layer
                              if isinstance(output_layer, (tuple, list, torch.nn.ModuleList))
                              else (output_layer,))
        shape = self.output_layers[0].bias.shape
        device = self.output_layers[0].bias.device
        assert all(shape == x.bias.shape for x in self.output_layers)
        # Statistics live in non-trainable Parameters so they follow .to()/state_dict.
        self.mean = Parameter(torch.zeros(shape, device=device), requires_grad=False)
        self.mean_square = Parameter(torch.ones(shape, device=device), requires_grad=False)
        self.std = Parameter(torch.ones(shape, device=device), requires_grad=False)
        self.updates = 0

    @torch.no_grad()
    def update(self, targets):
        """Fold a batch of ``targets`` (batch along dim 0) into the statistics.

        Returns the targets normalized with the *new* statistics.
        """
        beta = max(1 / (self.updates + 1), self.beta) if self.zero_debias else self.beta
        # note that for beta = 1/self.updates the resulting mean, std would be the true mean and std over all past data
        new_mean = (1 - beta) * self.mean + beta * targets.mean(0)
        new_mean_square = (1 - beta) * self.mean_square + beta * (targets * targets).mean(0)
        # Clamp the variance *before* sqrt: E[x^2] - E[x]^2 can come out slightly
        # negative through rounding, sqrt would then yield NaN and clamp() does
        # not repair NaN, which would poison the layers below.
        new_var = (new_mean_square - new_mean * new_mean).clamp(min=0)
        new_std = new_var.sqrt().clamp(0.0001, 1e6)
        if self.updates >= self.start_pop:
            # Preserve outputs: W' = W * std / std',  b' = (std * b + mean - mean') / std'
            for layer in self.output_layers:
                layer.weight *= (self.std / new_std)[:, None]
                layer.bias *= self.std
                layer.bias += self.mean - new_mean
                layer.bias /= new_std
        self.mean.copy_(new_mean)
        self.mean_square.copy_(new_mean_square)
        self.std.copy_(new_std)
        self.updates += 1
        return self.normalize(targets)

    def normalize(self, x):
        """Map raw targets to normalized space."""
        return (x - self.mean) / self.std

    def unnormalize(self, x):
        """Map normalized network outputs back to the raw target scale."""
        return x * self.std + self.mean

    def normalize_sum(self, s):
        """normalize x.sum(1) preserving relative weightings between elements"""
        return (s - self.mean.sum()) / self.std.norm()
class SimplePolicyContinuous(nn.Module):
    """Gaussian policy for continuous actions.

    The mean comes from a small MLP (two tanh hidden layers of width 128 and a
    bias-free linear head initialised to zero, so the initial mean is 0). The
    standard deviation does not depend on the state: it is
    ``exp(hiddenSigma.sum())``, a single scalar shared by every action
    dimension.

    Architecture taken from:
    https://github.com/lantunes/mountain-car-continuous/blob/master/rl/reinforce/agent.py
    """

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.output_size = output_size
        self.affine1Mu = nn.Linear(input_size, 128)
        self.affine2Mu = nn.Linear(128, 128)
        self.affine3Mu = nn.Linear(128, output_size, bias=False)
        # Zero the hidden biases and the output weights (no RNG is consumed
        # here, so construction order of the Linear layers is what matters).
        nn.init.zeros_(self.affine1Mu.bias)
        nn.init.zeros_(self.affine2Mu.bias)
        nn.init.zeros_(self.affine3Mu.weight)
        # Must be a Parameter (not a plain tensor) so the optimizer sees it;
        # otherwise sigma would never be trained.
        self.hiddenSigma = Parameter(torch.zeros(32), requires_grad=True)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, sigma)``; ``sigma`` is a 0-dim tensor."""
        hidden = x
        for layer in (self.affine1Mu, self.affine2Mu):
            hidden = torch.tanh(layer(hidden))
        mu = self.affine3Mu(hidden)
        # The parameters act as a log standard deviation: logs are
        # unconstrained on (-inf, inf) while stds must be non-negative, so
        # exponentiating keeps sigma positive without extra constraints.
        # Src: https://spinningup.openai.com/en/latest/spinningup/rl_intro.html
        log_sigma = self.hiddenSigma.sum()
        return mu, log_sigma.exp()
class ProbalisticLinear(Module):
    """Bias-free linear layer whose weight rows are rescaled to sum to one.

    Output unit ``k`` uses ``weight[k] / weight[k].sum()``. The rows are not
    constrained to be non-negative. ``eps`` is stored and shown in
    ``extra_repr`` but is NOT applied in ``forward`` (the clamp that would use
    it is disabled), so a row whose sum is zero yields inf/NaN.
    """

    __constants__ = ['bias']

    def __init__(self, in_features, out_features, eps):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.eps = eps
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform initialisation (same scheme as ``nn.Linear``)."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    @weak_script_method
    def forward(self, input):
        row_totals = self.weight.sum(1, keepdim=True)
        # NOTE(review): the author left ``.clamp(min=self.eps)`` on the row sums disabled.
        normalized_weight = self.weight / row_totals
        return F.linear(input, normalized_weight)

    def extra_repr(self):
        template = 'in_features={}, out_features={} (eps={})'
        return template.format(self.in_features, self.out_features, self.eps)
class FilterStripe(nn.Conv2d):
    """Bias-free ``nn.Conv2d`` whose kernel is masked by a learnable "Filter Skeleton".

    ``FilterSkeleton`` has shape ``(out_channels, kH, kW)``: one scalar per
    kernel position ("stripe") of every output filter. In normal mode the
    effective kernel is ``weight * FilterSkeleton`` broadcast over the input
    channels. After :meth:`_break` the surviving stripes are stored as 1x1
    kernels and ``forward`` rebuilds the spatial convolution by shifting and
    summing their outputs (``BrokenTarget`` is then not ``None``).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels, out_channels, kernel_size, stride,
                                           kernel_size // 2, groups=1, bias=False)
        # Set by _break(): count of kept stripes per kernel position, shape (kH, kW).
        self.BrokenTarget = None
        self.FilterSkeleton = Parameter(
            torch.ones(self.out_channels, self.kernel_size[0], self.kernel_size[1]),
            requires_grad=True)

    def forward(self, x):
        if self.BrokenTarget is None:
            return F.conv2d(x, self.weight * self.FilterSkeleton.unsqueeze(1), stride=self.stride,
                            padding=self.padding, groups=self.groups)
        # Broken mode: allocate the output on the input's device and dtype
        # (a hard-coded float32/default-CUDA buffer breaks for other dtypes/devices).
        out = x.new_zeros(x.shape[0], self.FilterSkeleton.shape[0],
                          int(np.ceil(x.shape[2] / self.stride[0])),
                          int(np.ceil(x.shape[3] / self.stride[1])))
        # One 1x1 conv evaluates every kept stripe; channels are ordered by (i, j, filter).
        x = F.conv2d(x, self.weight)
        l, h = 0, 0
        for i in range(self.BrokenTarget.shape[0]):
            for j in range(self.BrokenTarget.shape[1]):
                stripe_mask = self.FilterSkeleton[:, i, j]
                n_kept = int(stripe_mask.sum().item())
                h += n_kept
                # Shift the (i, j) stripe outputs into place, subsample by stride,
                # and accumulate into the filters that own them.
                out[:, stripe_mask] += self.shift(x[:, l:h], i, j)[:, :, ::self.stride[0], ::self.stride[1]]
                l += n_kept
        return out

    def prune_in(self, in_mask=None):
        """Keep only the input channels selected by boolean ``in_mask``.

        ``weight`` has layout ``(out_channels, in_channels, kH, kW)``, so the
        mask indexes dim 1. ``None`` (the default) keeps every channel.
        """
        if in_mask is None:
            return
        self.weight = Parameter(self.weight[:, in_mask])
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):
        """Drop filters whose skeleton has no entry with ``|value| > threshold``.

        At least one filter is always kept. Returns the boolean mask of kept filters.
        """
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0
        if out_mask.sum() == 0:
            out_mask[0] = True
        self.weight = Parameter(self.weight[out_mask])
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask], requires_grad=True)
        self.out_channels = out_mask.sum().item()
        return out_mask

    def _break(self, threshold):
        """Fold the skeleton into the weights and switch to stripe-wise (1x1) mode.

        Stripes with ``|skeleton| <= threshold`` are removed (at least one is
        kept); ``FilterSkeleton`` becomes a frozen boolean mask.
        """
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))
        self.FilterSkeleton = Parameter((self.FilterSkeleton.abs() > threshold), requires_grad=False)
        if self.FilterSkeleton.sum() == 0:
            self.FilterSkeleton.data[0][0][0] = True
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)
        self.kernel_size = (1, 1)
        # (out, in, kH, kW) -> (kH, kW, out, in) -> one 1x1 kernel per stripe,
        # keeping only stripes whose mask entry is True (same (i, j, out) order).
        self.weight = Parameter(
            self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1)[
                self.FilterSkeleton.permute(1, 2, 0).reshape(-1)])

    def update_skeleton(self, sr, threshold):
        """Add the L1 subgradient ``sr * sign(skeleton)`` to the skeleton's grad,
        then zero skeleton entries (and their grads) with ``|value| <= threshold``.

        Returns a boolean mask of filters that still have a surviving stripe.
        """
        self.FilterSkeleton.grad.data.add_(sr * torch.sign(self.FilterSkeleton.data))
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)
        self.FilterSkeleton.grad.data.mul_(mask)
        out_mask = mask.sum(dim=(1, 2)) != 0
        return out_mask

    def shift(self, x, i, j):
        """Translate ``x`` so a stripe at kernel position (i, j) lines up with the output.

        Uses the kernel height (``BrokenTarget.shape[0]``) for the vertical
        offset and the kernel width (``BrokenTarget.shape[1]``) for the
        horizontal one; negative pads crop.
        """
        center_h = self.BrokenTarget.shape[0] // 2
        center_w = self.BrokenTarget.shape[1] // 2
        return F.pad(x, (center_w - j, j - center_w, center_h - i, i - center_h), 'constant', 0)

    def extra_repr(self):
        s = ('{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        return s.format(**self.__dict__)
class PGDAttack(BaseAttack):
    """Spectral attack for graph data.

    Learns a continuous perturbation vector ``adj_changes`` (one entry per
    strictly-lower-triangular node pair) by gradient ascent on
    ``loss_weight * task_loss + regularization_weight * reg_loss``, where
    ``reg_loss`` measures how much the normalized-adjacency spectrum moves.
    The continuous solution is then discretised into edge flips, either by
    random sampling or by one of the greedy rules.
    """

    # Accepted values for attack(distance_type=...) and attack(sample_type=...).
    _DISTANCE_TYPES = ('l2', 'normDiv', 'gf')
    _SAMPLE_TYPES = ('sample', 'greedy', 'greedy2', 'greedy3')

    def __init__(self, model=None, nnodes=None, loss_type='CE', feature_shape=None, attack_structure=True,
                 attack_features=False, loss_weight=1.0, regularization_weight=0.0, device='cpu'):
        super(PGDAttack, self).__init__(model, nnodes, attack_structure, attack_features, device)
        assert attack_structure or attack_features, 'attack_feature or attack_structure cannot be both False'
        if attack_features:
            # Previously ``assert True, ...`` which never fired, so a feature
            # attack was silently accepted and then ignored.
            raise NotImplementedError('Current Spectral Attack does not support attack feature')
        self.loss_type = loss_type
        self.modified_adj = None
        self.modified_features = None
        self.loss_weight = loss_weight
        self.regularization_weight = regularization_weight
        if attack_structure:
            assert nnodes is not None, 'Please give nnodes='
            # One entry per unordered node pair; integer arithmetic avoids
            # float rounding of n * (n - 1) / 2 for large graphs.
            self.adj_changes = Parameter(torch.FloatTensor(nnodes * (nnodes - 1) // 2))
            torch.nn.init.uniform_(self.adj_changes, 0.0, 0.001)
        self.complementary = None

    def set_model(self, model):
        """Register the surrogate model attacked by ``attack``/``random_sample``."""
        self.surrogate = model

    @staticmethod
    def _eigh(matrix):
        """Eigen-decompose a symmetric matrix; eigenvalues in ascending order.

        Uses ``torch.linalg.eigh`` (lower triangle, like ``torch.symeig``'s
        default) when available and falls back to ``torch.symeig`` on old
        PyTorch releases, where ``torch.linalg.eigh`` does not exist.
        """
        linalg = getattr(torch, 'linalg', None)
        if linalg is not None and hasattr(linalg, 'eigh'):
            return linalg.eigh(matrix)
        return torch.symeig(matrix, eigenvectors=True)

    def _greedy_select(self, sample_type, n_perturbations):
        """Dispatch to greedy / greedy2 / greedy3 according to ``sample_type``."""
        selectors = {'greedy': self.greedy, 'greedy2': self.greedy2, 'greedy3': self.greedy3}
        selectors[sample_type](n_perturbations)

    def attack(self, ori_features, ori_adj, labels, idx_target, n_perturbations, att_lr, epochs=200,
               distance_type='l2', sample_type='sample', opt_type='max', verbose=True, **kwargs):
        """Generate perturbations on the input graph.

        Args:
            ori_features: node features passed to the surrogate model.
            ori_adj: dense adjacency tensor of the clean graph.
            labels, idx_target: labels and the node indices whose loss is attacked.
            n_perturbations: edge-flip budget.
            att_lr: base step size; the step at epoch t is ``att_lr / sqrt(t + 1)``.
            epochs: number of gradient steps.
            distance_type: spectral regulariser, one of 'l2', 'normDiv', 'gf'.
            sample_type: 'sample' (projection + random sampling) or 'greedy',
                'greedy2', 'greedy3'.
            opt_type, kwargs: accepted for interface compatibility; unused.

        Raises:
            ValueError: if ``distance_type`` or ``sample_type`` is unknown
                (checked before any optimisation is done).
        """
        if distance_type not in self._DISTANCE_TYPES:
            raise ValueError(f'unknown distance metric: {distance_type}')
        if sample_type not in self._SAMPLE_TYPES:
            raise ValueError(f"unkown sample type {sample_type}")

        victim_model = self.surrogate
        self.sparse_features = sp.issparse(ori_features)
        ori_adj_norm = utils.normalize_adj_tensor(ori_adj, device=self.device)
        ori_e, ori_v = self._eigh(ori_adj_norm)
        # Norm of the clean spectrum; constant for the whole attack.
        eigen_norm = self.norm = torch.norm(ori_e)
        l, r, m = 0, 0, 0
        victim_model.eval()
        for t in tqdm(range(epochs)):
            modified_adj = self.get_modified_adj(ori_adj)
            adj_norm = utils.normalize_adj_tensor(modified_adj, device=self.device)
            # forward of gcn need to normalize adj first
            output = victim_model(ori_features, adj_norm)
            task_loss = self._loss(output[idx_target], labels[idx_target])

            # Spectral distance terms; they stay zero when regularization is disabled.
            eigen_mse = torch.tensor(0)
            eigen_self = torch.tensor(0)
            eigen_gf = torch.tensor(0)
            if self.regularization_weight != 0:
                modified_adj_noise = modified_adj
                adj_norm_noise = utils.normalize_adj_tensor(modified_adj_noise, device=self.device)
                e, v = self._eigh(adj_norm_noise)
                eigen_mse = torch.norm(ori_e - e)
                eigen_self = torch.norm(e)
                # low-rank loss in GF-attack: keep the 128 smallest eigenvalues
                idx = torch.argsort(e)[:128]
                mask = torch.zeros_like(e).bool()
                mask[idx] = True
                eigen_gf = torch.pow(torch.norm(e * mask, p=2), 2) * torch.pow(
                    torch.norm(torch.matmul(v.detach() * mask, ori_features), p=2), 2)

            if distance_type == 'l2':
                reg_loss = eigen_mse / eigen_norm
            elif distance_type == 'normDiv':
                reg_loss = eigen_self / eigen_norm
            else:  # 'gf' (validated above)
                reg_loss = eigen_gf

            if verbose and t % 20 == 0:
                _, acc_target = calc_acc(output, labels, idx_target)
                print(
                    '-- Epoch {}, '.format(t),
                    'ptb budget/true = {:.1f}/{:.1f}'.format(
                        n_perturbations, torch.clamp(self.adj_changes, 0, 1).sum()),
                    'l/r/m = {:.4f}/{:.4f}/{:.4f}'.format(l, r, m),
                    'class loss = {:.4f} | '.format(task_loss.item()),
                    'reg loss = {:.4f} | '.format(reg_loss.item()),
                    'mse_norm = {:4f} | '.format(eigen_norm),
                    'eigen_mse = {:.4f} | '.format(eigen_mse),
                    'eigen_self = {:.4f} | '.format(eigen_self),
                    'acc/mis = {:.4f}/{:.4f}'.format(acc_target, 1 - acc_target))

            self.loss = self.loss_weight * task_loss + self.regularization_weight * reg_loss
            adj_grad = torch.autograd.grad(self.loss, self.adj_changes)[0]
            if self.loss_type in ('CE', 'CW'):
                # Gradient ascent with a 1/sqrt(t+1) decaying step size.
                lr = att_lr / np.sqrt(t + 1)
                self.adj_changes.data.add_(lr * adj_grad)

            if verbose and t % 20 == 0:
                print('budget/true={:.1f}/{:.1f}'.format(
                    n_perturbations, torch.clamp(self.adj_changes, 0, 1).sum()))
            if sample_type == 'sample':
                l, r, m = self.projection(n_perturbations)
            else:
                self._greedy_select(sample_type, n_perturbations)
            if verbose and t % 20 == 0:
                print('budget/true={:.1f}/{:.1f}'.format(
                    n_perturbations, torch.clamp(self.adj_changes, 0, 1).sum()))

        # Discretise the continuous perturbation.
        if sample_type == 'sample':
            self.random_sample(ori_adj, ori_features, labels, idx_target, n_perturbations)
        else:
            self._greedy_select(sample_type, n_perturbations)
        print("final ptb budget/true= {:.1f}/{:.1f}".format(n_perturbations, self.adj_changes.sum()))
        self.modified_adj = self.get_modified_adj(ori_adj).detach()
        self.check_adj_tensor(self.modified_adj)

        # for sanity check: keep the spectra of the clean and the perturbed graph
        # (the clean spectrum was already computed from the unchanged ori_adj).
        adj_norm = utils.normalize_adj_tensor(self.modified_adj, device=self.device)
        e, v = self._eigh(adj_norm)
        self.adj = ori_adj.detach()
        self.labels = labels.detach()
        self.ori_e = ori_e
        self.ori_v = ori_v
        self.e = e
        self.v = v

    def greedy(self, n_perturbations):
        """Set the ``n_perturbations`` largest entries of ``adj_changes`` to 1, the rest to 0."""
        s = self.adj_changes.cpu().detach().numpy()
        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        max_index = (-s_vec).argsort()[:n_perturbations]
        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0
        best_s = np.reshape(mask, s.shape)
        self.adj_changes.data.copy_(torch.clamp(torch.tensor(best_s), min=0, max=1))

    def greedy3(self, n_perturbations):
        """Set the ``n_perturbations`` *smallest* entries of ``adj_changes`` to 1, the rest to 0."""
        s = self.adj_changes.cpu().detach().numpy()
        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        max_index = (s_vec).argsort()[:n_perturbations]
        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0
        best_s = np.reshape(mask, s.shape)
        self.adj_changes.data.copy_(torch.clamp(torch.tensor(best_s), min=0, max=1))

    def greedy2(self, n_perturbations):
        """Add Gaussian noise to ``adj_changes`` and keep the ``n_perturbations``
        entries with the largest absolute noisy value (set to 1, rest to 0)."""
        # .copy(): on CPU .numpy() shares memory with the parameter; without it
        # the in-place ``+=`` below would mutate adj_changes directly.
        s = self.adj_changes.cpu().detach().numpy().copy()
        l = s.min()
        r = s.max()
        noise = np.random.normal((l + r) / 2, 0.4 * (r - l), s.shape)
        s += noise
        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        max_index = (-np.absolute(s_vec)).argsort()[:n_perturbations]
        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0
        best_s = np.reshape(mask, s.shape)
        self.adj_changes.data.copy_(torch.clamp(torch.tensor(best_s), min=0, max=1))

    def random_sample(self, ori_adj, ori_features, labels, idx_target, n_perturbations):
        """Draw K Bernoulli(adj_changes) samples within budget and keep the one
        with the highest attack loss; the winner is written into ``adj_changes``.

        Assumes ``adj_changes`` already lies in [0, 1] (e.g. after ``projection``).
        """
        K = 10
        best_loss = -1000
        best_s = None
        victim_model = self.surrogate
        with torch.no_grad():
            # .copy() is essential: on CPU, .numpy() shares storage with
            # adj_changes, which is overwritten by each candidate below. Without
            # the copy every later draw would use the previous 0/1 sample as its
            # probabilities, collapsing the K samples into one.
            s = self.adj_changes.detach().cpu().numpy().copy()
            for _ in range(K):
                sampled = np.random.binomial(1, s)
                # Rejection sampling until the draw respects the budget.
                while sampled.sum() > n_perturbations:
                    sampled = np.random.binomial(1, s)
                self.adj_changes.data.copy_(torch.tensor(sampled))
                modified_adj = self.get_modified_adj(ori_adj)
                adj_norm = utils.normalize_adj_tensor(modified_adj, device=self.device)
                output = victim_model(ori_features, adj_norm)
                loss = self._loss(output[idx_target], labels[idx_target])
                # ``best_s is None`` guarantees a result even if every loss is
                # below the initial threshold (or NaN).
                if best_s is None or best_loss < loss:
                    best_loss = loss
                    best_s = sampled
            self.adj_changes.data.copy_(torch.tensor(best_s))

    def get_modified_adj(self, ori_adj):
        """Return ``ori_adj`` with the current (continuous) edge flips applied.

        ``complementary = 1 - I - 2 * ori_adj`` is +1 off the diagonal where
        there is no edge and -1 where there is one, so ``complementary * m``
        adds absent edges and removes present ones.
        """
        if self.complementary is None:
            self.complementary = (torch.ones_like(ori_adj) - torch.eye(self.nnodes).to(self.device)
                                  - ori_adj) - ori_adj
        m = torch.zeros((self.nnodes, self.nnodes)).to(self.device)
        tril_indices = torch.tril_indices(row=self.nnodes, col=self.nnodes, offset=-1)
        m[tril_indices[0], tril_indices[1]] = self.adj_changes
        m = m + m.t()
        modified_adj = self.complementary * m + ori_adj
        return modified_adj

    def add_random_noise(self, ori_adj):
        """Return ``ori_adj`` plus symmetric uniform noise in [0, 1e-4)."""
        noise = 1e-4 * torch.rand(self.nnodes, self.nnodes).to(self.device)
        return (noise + torch.transpose(noise, 0, 1)) / 2.0 + ori_adj

    def projection2(self, n_perturbations):
        """Clamp every entry of ``adj_changes`` to [0, n_perturbations / n]."""
        n = self.adj_changes.numel()
        self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data, min=0, max=n_perturbations / n))
        return 0, 0, 0

    def projection(self, n_perturbations):
        """Project ``adj_changes`` onto {x in [0, 1]^n : sum(x) <= n_perturbations}.

        When the clamped mass exceeds the budget, the shift ``miu`` is found by
        bisection so that ``clamp(adj_changes - miu, 0, 1)`` sums to the budget.
        Returns ``(left, right, miu)`` of the bisection for logging, or zeros
        when no shift was needed.
        """
        l, r, m = 0, 0, 0
        if torch.clamp(self.adj_changes, 0, 1).sum() > n_perturbations:
            left = (self.adj_changes).min()
            right = self.adj_changes.max()
            miu = self.bisection(left, right, n_perturbations, epsilon=1e-5)
            l = left.cpu().detach()
            r = right.cpu().detach()
            m = miu.cpu().detach()
            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data - miu, min=0, max=1))
        else:
            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data, min=0, max=1))
        return l, r, m

    def _loss(self, output, labels):
        """Attack objective on the surrogate's output.

        'CE': negative log-likelihood (``F.nll_loss``; assumes log-probabilities).
        'CW': negative mean of the non-negative margin between the true class
        and the best other class.

        Raises:
            ValueError: for any other ``loss_type``.
        """
        if self.loss_type == "CE":
            return F.nll_loss(output, labels)
        if self.loss_type == "CW":
            onehot = utils.tensor2onehot(labels)
            best_second_class = (output - 1000 * onehot).argmax(1).detach()
            rows = np.arange(len(output))
            margin = output[rows, labels] - output[rows, best_second_class]
            k = 0
            return -torch.clamp(margin, min=k).mean()
        raise ValueError(f'unknown loss type: {self.loss_type}')

    def bisection(self, a, b, n_perturbations, epsilon):
        """Find ``miu`` in [a, b] with ``clamp(adj_changes - miu, 0, 1).sum() == n_perturbations``
        to within ``epsilon``."""
        def func(x):
            return torch.clamp(self.adj_changes - x, 0, 1).sum() - n_perturbations

        miu = a
        while (b - a) >= epsilon:
            miu = (a + b) / 2
            f_mid = func(miu)
            # Check if middle point is root
            if f_mid == 0.0:
                b = miu
                break
            # Decide the side to repeat the steps
            if f_mid * func(a) < 0:
                b = miu
            else:
                a = miu
        return miu
class SafetyNet(nn.Module):
    """Threshold-learning model with per-category inventory, cancel, price and
    demand parameters, optionally choosing thresholds by solving an LP/QP.

    Thresholds form a ``(nKnapsackCategories, nThresholds)`` matrix whose rows
    are normalized to sum to one. In parametric mode (``parametric_knapsack``)
    the matrix is the solution of ``QPFunctionJK``; otherwise it is a
    trainable Parameter. ``forward`` returns an objective term, constraint
    terms and log-probabilities for the batch (see the return statement).

    NOTE(review): the domain meaning of the constants (1/50000, 1/7) and of the
    helper functions (PoissonFunction, cancel_rate_belief_cXt, QPFunctionJK,
    normalize_JK, CumSumNoGrad) is defined elsewhere; not verified here.
    """

    def __init__(self, nKnapsackCategories, nThresholds, starting_thresholds, nineq=1, neq=0, eps=1e-8,
                 cancel_rate_target=.05, cancel_rate_evaluation=.05, accept_rate_target=.75,
                 accept_rate_evaluation=.75, cancel_initializer=.02, inventory_initializer=3,
                 cancel_coef_initializer=-.2, cancel_intercept_initializer=.3, price_initializer=1,
                 parametric_knapsack=False, knapsack_type=None):
        super().__init__()
        self.nKnapsackCategories = nKnapsackCategories
        self.nThresholds = nThresholds
        self.nineq = nineq
        self.neq = neq
        self.eps = eps
        self.cancel_rate_evaluation = cancel_rate_evaluation
        self.accept_rate_evaluation = accept_rate_evaluation
        # NOTE(review): this shares storage with starting_thresholds (and with
        # thresholds_raw_matrix below), so in-place normalize_thresholds() also
        # changes the benchmark. Confirm whether that aliasing is intended.
        self.benchmark_thresholds = Variable(starting_thresholds)
        self.accept_rate_param = Parameter(accept_rate_target * torch.ones(1))
        self.cancel_rate_param = Parameter(cancel_rate_target * torch.ones(1))
        self.inventory_initializer = inventory_initializer
        self.parametric_knapsack = parametric_knapsack
        self.h = Variable(torch.ones(self.nineq))
        n_vars = self.nKnapsackCategories * self.nThresholds
        # -I x <= 0: all LP variables must be non-negative.
        self.PosValMatrix = -1 * Variable(torch.eye(n_vars))
        self.PosValVector = Variable(torch.zeros(n_vars))
        # Equality constraints A x = b: the variables of each category (one
        # block of nThresholds columns per row) sum to one. Not learned.
        A = torch.zeros(self.nKnapsackCategories, n_vars)
        for row in range(self.nKnapsackCategories):
            A[row][self.nThresholds * row:self.nThresholds * (row + 1)] = 1
        self.A = Variable(A)
        self.b = Variable(torch.ones(self.nKnapsackCategories))
        self.Q_zeros = Variable(torch.zeros(n_vars, n_vars))
        # Threshold levels 0 .. nThresholds-1.
        # NOTE(review): torch.arange returns an integer tensor on modern PyTorch.
        self.thresholds = Variable(torch.arange(0, self.nThresholds))
        # Parametric mode recomputes the thresholds from the LP each forward;
        # otherwise they are learned directly.
        if self.parametric_knapsack:
            self.thresholds_raw_matrix = Variable(starting_thresholds)
        else:
            self.thresholds_raw_matrix = Parameter(starting_thresholds)
        self.thresholds_raw_matrix_norm = torch.div(
            self.thresholds_raw_matrix,
            torch.sum(self.thresholds_raw_matrix, dim=1).unsqueeze(1).expand_as(self.thresholds_raw_matrix))
        # "_opt" parameters feed the LP; "_est" parameters are fitted to the batch.
        # Inventory distribution parameters
        self.inventory_lam_opt = Parameter(torch.ones(self.nKnapsackCategories) * inventory_initializer)
        # Cancel distribution parameters
        self.cancel_coef_opt = Parameter(torch.ones(self.nKnapsackCategories) * cancel_coef_initializer)
        self.cancel_intercept_opt = Parameter(torch.ones(self.nKnapsackCategories) * cancel_intercept_initializer)
        self.prices_opt = Parameter(torch.ones(self.nKnapsackCategories) * price_initializer)
        self.demand_distribution_opt = Parameter(
            torch.ones(self.nKnapsackCategories) * (1.0 / self.nKnapsackCategories))
        self.inventory_lam_est = Parameter(torch.ones(self.nKnapsackCategories) * inventory_initializer)
        # Cancel distribution parameters
        self.cancel_coef_est = Parameter(torch.ones(self.nKnapsackCategories) * cancel_coef_initializer)
        self.cancel_intercept_est = Parameter(torch.ones(self.nKnapsackCategories) * cancel_intercept_initializer)
        self.prices_est = Parameter(torch.ones(self.nKnapsackCategories) * price_initializer)
        self.demand_distribution_est = Parameter(
            torch.ones(self.nKnapsackCategories) * (1.0 / self.nKnapsackCategories))

    def normalize_thresholds(self):
        """Clamp ``thresholds_raw_matrix`` to [eps, 1 - eps] and row-normalize it in place."""
        self.thresholds_raw_matrix.data.clamp_(min=self.eps, max=1.0 - self.eps)
        # keepdim keeps a (C, 1) column that broadcasts across thresholds; the
        # previous squeeze()+ger() failed when there was a single category.
        row_sums = self.thresholds_raw_matrix.data.sum(dim=1, keepdim=True)
        self.thresholds_raw_matrix.data.div_(row_sums)

    def normalize_demand_params(self):
        """Clamp both demand distributions to >= eps and renormalize each to sum to one."""
        self.demand_distribution_est.data.clamp_(min=self.eps)
        self.demand_distribution_est.data.div_(self.demand_distribution_est.data.sum())
        self.demand_distribution_opt.data.clamp_(min=self.eps)
        self.demand_distribution_opt.data.div_(self.demand_distribution_opt.data.sum())

    def forward(self, category, inv_count, price, cancel, collection_thresholds):
        """Compute losses for a batch.

        Assumed shapes (TODO confirm against callers): ``category`` (B, C)
        one-hot rows, ``inv_count`` (B, T), ``price`` and ``cancel`` (B,),
        ``collection_thresholds`` (B, T).
        """
        C, T = self.nKnapsackCategories, self.nThresholds
        self.lp_infeasible = 0
        # Cancel-rate slopes are constrained to be non-positive.
        self.cancel_coef_neg_est = self.cancel_coef_est.clamp(max=0)
        self.cancel_coef_neg_opt = self.cancel_coef_opt.clamp(max=0)
        self.nBatch = category.size(0)

        # Compute everything that does not depend on the thresholds first, so
        # the learned parameters can feed the LP.
        self.inventory_distribution_raw_est = PoissonFunction(C, T, verbose=-1)(
            self.inventory_lam_est, self.thresholds) + self.eps
        self.inventory_distribution_batch_by_threshold_est = torch.mm(
            category, self.inventory_distribution_raw_est) + self.eps
        self.inventory_distribution_raw_opt = PoissonFunction(C, T, verbose=-1)(
            self.inventory_lam_opt, self.thresholds) + self.eps
        self.inventory_distribution_batch_by_threshold_opt = torch.mm(
            category, self.inventory_distribution_raw_opt) + self.eps

        # Cancel probability by inventory level.
        thresholds_cXt = self.thresholds.unsqueeze(0).expand(C, T)
        self.belief_cancel_rate_cXt_est = cancel_rate_belief_cXt(
            self.cancel_coef_neg_est, self.cancel_intercept_est, thresholds_cXt)
        self.belief_cancel_rate_cXt_opt = cancel_rate_belief_cXt(
            self.cancel_coef_neg_opt, self.cancel_intercept_opt, thresholds_cXt)
        belief_fill_rate_cXt_opt = 1 - self.belief_cancel_rate_cXt_opt
        price_cXt_opt = self.prices_opt.unsqueeze(1).expand(C, T)
        self.belief_total_demand_cXt_est = self.inventory_distribution_raw_est * (
            self.demand_distribution_est.unsqueeze(1).expand(C, T))
        self.belief_total_demand_cXt_opt = self.inventory_distribution_raw_opt * (
            self.demand_distribution_opt.unsqueeze(1).expand(C, T))

        if self.parametric_knapsack:
            self.belief_total_demand_opt = torch.sum(self.belief_total_demand_cXt_opt)
            self.belief_total_cancels_cXt_opt = self.belief_cancel_rate_cXt_opt * self.belief_total_demand_cXt_opt
            self.belief_total_fills_cXt_opt = belief_fill_rate_cXt_opt * self.belief_total_demand_cXt_opt
            total_demand_cXt = self.belief_total_demand_opt.expand(C, T)
            # Row-wise suffix sums (entries at threshold t and above): row total
            # minus cumsum plus the current entry. unsqueeze(1) is required so
            # the (C,) row totals broadcast along the threshold axis.
            cancels = self.belief_total_cancels_cXt_opt
            self.knapsack_cancels_matrix = torch.div(
                torch.sum(cancels, dim=1).unsqueeze(1).expand_as(cancels)
                - torch.cumsum(cancels, dim=1) + cancels,
                total_demand_cXt)
            fills = self.belief_total_fills_cXt_opt
            self.knapsack_fills_matrix = torch.div(
                torch.sum(fills, dim=1).unsqueeze(1).expand_as(fills)
                - torch.cumsum(fills, dim=1) + fills,
                total_demand_cXt)
            self.knapsack_revenues_matrix = self.knapsack_fills_matrix * price_cXt_opt
            self.knapsack_cancels = self.knapsack_cancels_matrix.view(1, -1)
            self.knapsack_fills = self.knapsack_fills_matrix.view(1, -1)
            self.knapsack_revenues = self.knapsack_revenues_matrix.view(-1)
            # Tiny ridge term on an otherwise zero quadratic.
            Q = self.Q_zeros + self.eps * Variable(torch.eye(C * T))
            self.inequalityMatrix = torch.cat(
                (self.knapsack_cancels, -1 * self.knapsack_fills, self.PosValMatrix))
            # Right-hand sides: cancels/fills achieved by the benchmark thresholds.
            self.knapsack_cancels_RHS = torch.sum(self.knapsack_cancels_matrix * self.benchmark_thresholds)
            self.knapsack_fills_RHS = torch.sum(self.knapsack_fills_matrix * self.benchmark_thresholds)
            self.inequalityVector = torch.cat(
                (self.knapsack_cancels_RHS * self.h, -1 * self.knapsack_fills_RHS * self.h, self.PosValVector))
            try:
                thresholds_raw = QPFunctionJK(verbose=1)(
                    Q, -1 * self.knapsack_revenues, self.inequalityMatrix, self.inequalityVector, self.A, self.b)
                self.thresholds_raw_matrix = thresholds_raw.view(C, -1)
            except AssertionError:
                # Keep the previous thresholds when the solver fails.
                print("Error solving LP, likely infeasible")
                self.lp_infeasible = 1

        positive_thresholds = F.relu(self.thresholds_raw_matrix) + self.eps
        if not isinstance(self.thresholds_raw_matrix, Parameter):
            # Only replace the attribute when it is a plain tensor (parametric
            # mode). Assigning a non-Parameter to a registered Parameter
            # attribute raises TypeError in nn.Module.__setattr__.
            self.thresholds_raw_matrix = positive_thresholds
        self.thresholds_raw_matrix_norm = normalize_JK(positive_thresholds, dim=1)
        # Accept probability by (category, threshold) under the learned thresholds,
        # obtained either through direct optimization or through solving the LP.
        accept_probability_cXt = torch.cumsum(self.thresholds_raw_matrix_norm, dim=1)

        accept_probability_collection_bXt = torch.cumsum(collection_thresholds, dim=1)
        reject_probability_collection_bXt = 1 - accept_probability_collection_bXt
        accept_percent_collection_bXt = (accept_probability_collection_bXt
                                         * self.inventory_distribution_batch_by_threshold_est)
        # Believed acceptance rate under the collection thresholds, per batch row.
        accept_percent_collection_b_vector = torch.sum(accept_percent_collection_bXt, dim=1).squeeze()
        reject_percent_collection_b_vector = 1 - accept_percent_collection_b_vector
        self.batch_total_demand_b_vector = (1 / accept_percent_collection_b_vector)
        reject_percent_collection_expanded_bXt = reject_percent_collection_b_vector.unsqueeze(1).expand(
            self.nBatch, T)
        self.truncated_orders_distribution_bXt = torch.div(
            reject_probability_collection_bXt * self.inventory_distribution_batch_by_threshold_est,
            reject_percent_collection_expanded_bXt + self.eps)
        truncated_demand_b_vector = self.batch_total_demand_b_vector - 1
        truncated_demand_bXt = truncated_demand_b_vector.unsqueeze(1).expand(
            self.nBatch, T) * self.truncated_orders_distribution_bXt
        batch_total_demand_bXt = truncated_demand_bXt + inv_count
        self.batch_total_demand_cXt = torch.mm(category.t(), batch_total_demand_bXt)
        self.estimated_batch_total_demand = torch.sum(self.batch_total_demand_b_vector)

        # How well the inventory distributions explain the batch.
        accept_probability_batch_by_threshold = CumSumNoGrad(verbose=-1)(collection_thresholds) + self.eps
        self.inventory_distribution_batch_by_thresholds = torch.mm(category, self.inventory_distribution_raw_est)
        arrival_unnormed = self.inventory_distribution_batch_by_thresholds * accept_probability_batch_by_threshold
        arrival_probability_batch_by_threshold = torch.div(
            arrival_unnormed, torch.sum(arrival_unnormed, dim=1).unsqueeze(1).expand_as(arrival_unnormed))
        log_arrival_prob = torch.log(arrival_probability_batch_by_threshold + self.eps)

        # How well the cancel parameters explain the batch.
        self.belief_cancel_rate_bXt = torch.mm(category, self.belief_cancel_rate_cXt_est)
        belief_fill_rate_bXt = 1 - self.belief_cancel_rate_bXt
        self.belief_cancel_rate_b_vector = torch.sum(self.belief_cancel_rate_bXt * inv_count, dim=1).squeeze()
        belief_fill_rate_b_vector = 1 - self.belief_cancel_rate_b_vector
        log_cancel_prob = torch.log(
            torch.cat((belief_fill_rate_b_vector.unsqueeze(1), self.belief_cancel_rate_b_vector.unsqueeze(1)), 1)
            + self.eps)
        self.belief_category_dist_bXc = self.demand_distribution_est.unsqueeze(0).expand(self.nBatch, C)
        log_category_prob = torch.log(self.belief_category_dist_bXc + self.eps)

        # Estimated effects of truncated (unobserved) orders under the parameterized thresholds.
        accept_probability_using_threshold_params_bXt = torch.mm(category, accept_probability_cXt)
        truncated_accept_estimate = truncated_demand_bXt * accept_probability_using_threshold_params_bXt
        truncated_cancel_estimate = truncated_accept_estimate * self.belief_cancel_rate_bXt
        truncated_fill_estimate = truncated_accept_estimate * belief_fill_rate_bXt
        truncated_revenue_estimate = truncated_fill_estimate * (price.unsqueeze(1).expand(self.nBatch, T))
        truncated_revenue_estimate_sum = torch.sum(truncated_revenue_estimate)
        self.truncated_cancel_estimate_sum = torch.sum(truncated_cancel_estimate)
        truncated_fill_estimate_sum = torch.sum(truncated_fill_estimate)
        self.truncated_accept_estimate_sum = torch.sum(truncated_accept_estimate)

        # Observed batch outcomes weighted by the parameterized accept probabilities.
        fill = 1 - cancel
        batch_cancel_bXt = cancel.unsqueeze(1).expand(
            self.nBatch, T) * inv_count * accept_probability_using_threshold_params_bXt
        batch_fill_bXt = fill.unsqueeze(1).expand(
            self.nBatch, T) * inv_count * accept_probability_using_threshold_params_bXt
        batch_cancel_b_vector = torch.sum(batch_cancel_bXt, dim=1).squeeze()
        batch_fill_b_vector = torch.sum(batch_fill_bXt, dim=1).squeeze()
        batch_accept_b_vector = torch.sum(
            inv_count * accept_probability_using_threshold_params_bXt, dim=1).squeeze()
        batch_revenue_b_vector = price * batch_fill_b_vector
        self.batch_fill_sum = torch.sum(batch_fill_b_vector, dim=0)
        self.batch_revenue_sum = torch.sum(batch_revenue_b_vector, dim=0)
        self.batch_cancel_sum = torch.sum(batch_cancel_b_vector, dim=0)
        self.batch_accept_sum = torch.sum(batch_accept_b_vector, dim=0)

        total_accept = self.truncated_accept_estimate_sum + self.batch_accept_sum
        new_objective_loss = -(1.0 / 50000) * (truncated_revenue_estimate_sum + self.batch_revenue_sum)
        new_cancel_constraint_loss = (self.truncated_cancel_estimate_sum + self.batch_cancel_sum
                                      - total_accept * self.cancel_rate_evaluation)
        new_accept_constraint_loss = (1.0 / 7.0) * (
            total_accept * self.accept_rate_evaluation - truncated_fill_estimate_sum - self.batch_fill_sum)
        observed_cancel_constraint_loss = self.batch_cancel_sum - (self.batch_accept_sum) * self.cancel_rate_evaluation
        observed_accept_constraint_loss = (1.0 / 7.0) * (
            self.batch_accept_sum * self.accept_rate_evaluation - self.batch_fill_sum)
        return (new_objective_loss, new_cancel_constraint_loss, new_accept_constraint_loss,
                arrival_probability_batch_by_threshold, log_arrival_prob, log_cancel_prob, log_category_prob,
                self.estimated_batch_total_demand, observed_cancel_constraint_loss,
                observed_accept_constraint_loss, self.lp_infeasible)
class FilterStripe(nn.Conv2d):
    """Conv2d with a learnable FilterSkeleton (FS) that scales each kernel stripe.

    ``FilterSkeleton`` has shape ``(out_channels, k, k)``. Entry ``[o, i, j]``
    scales kernel position ``(i, j)`` of output filter ``o`` across all input
    channels.

    In normal mode the layer is a plain convolution with weight
    ``weight * FilterSkeleton[:, None]``.

    After ``_break(threshold)`` the layer is in "broken" mode. Every surviving
    stripe ``(o, i, j)`` becomes its own 1x1 filter. ``forward`` shifts each
    1x1 response into its kernel position and accumulates it into output
    channel ``o``.

    Only square kernels are supported: ``kernel_size`` must be an int, because
    it is also used to compute the ``kernel_size // 2`` padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels, out_channels, kernel_size, stride,
                                           kernel_size // 2, groups=1, bias=False)
        # Set by _break(): number of surviving stripes per kernel position, shape (k, k).
        self.BrokenTarget = None
        # FS layer: one learnable scale per (output filter, kernel row, kernel column).
        self.FilterSkeleton = Parameter(
            torch.ones(self.out_channels, self.kernel_size[0], self.kernel_size[1]),
            requires_grad=True)

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(N, in_channels, H, W)``."""
        if self.BrokenTarget is None:
            # Normal mode: fold the skeleton into the kernel (broadcast over input channels).
            return F.conv2d(x, self.weight * self.FilterSkeleton.unsqueeze(1),
                            stride=self.stride, padding=self.padding, groups=self.groups)

        # Broken mode. The output is ceil(H / s) x ceil(W / s), which matches the
        # normal-mode output for an odd k with padding k // 2.
        # new_zeros keeps x's dtype and device, so the index-put below cannot
        # hit a dtype or device mismatch.
        out = x.new_zeros(x.shape[0], self.FilterSkeleton.shape[0],
                          int(np.ceil(x.shape[2] / self.stride[0])),
                          int(np.ceil(x.shape[3] / self.stride[1])))
        # One 1x1 response channel per surviving stripe. Channels are ordered
        # by kernel position (i, j), then by output filter (see _break).
        x = F.conv2d(x, self.weight)
        start = end = 0
        for i in range(self.BrokenTarget.shape[0]):
            for j in range(self.BrokenTarget.shape[1]):
                keep = self.FilterSkeleton[:, i, j]  # bool mask over output filters
                end += int(keep.sum().item())
                shifted = self.shift(x[:, start:end], i, j)
                out[:, keep] += shifted[:, :, ::self.stride[0], ::self.stride[1]]
                start = end
        return out

    def prune_in(self, in_mask=None):
        """Keep only the input channels selected by ``in_mask``.

        ``self.weight`` has shape ``(out_channels, in_channels, kH, kW)``, so
        the mask indexes dimension 1. Passing ``None`` (the default) leaves
        the layer unchanged.
        """
        if in_mask is None:
            return
        self.weight = Parameter(self.weight[:, in_mask])
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):
        """Drop output filters whose skeleton entries all have ``|FS| <= threshold``.

        At least one filter (index 0) is always kept. Returns the boolean
        mask of kept output filters.
        """
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0
        if not out_mask.any():
            # Never let the layer become empty.
            out_mask[0] = True
        self.weight = Parameter(self.weight[out_mask])
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask], requires_grad=True)
        self.out_channels = out_mask.sum().item()
        return out_mask

    def _break(self, threshold):
        """Switch to broken mode, keeping only stripes with ``|FS| > threshold``."""
        # Fold the skeleton scales into the weights: (out, in, k, k).
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))
        # The skeleton becomes a frozen boolean keep-mask of stripes.
        self.FilterSkeleton = Parameter((self.FilterSkeleton.abs() > threshold), requires_grad=False)
        if self.FilterSkeleton.sum() == 0:
            # Keep at least one stripe.
            self.FilterSkeleton.data[0][0][0] = True
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)
        self.kernel_size = (1, 1)
        # Reorder the weights to (k, k, out, in) and flatten them into 1x1
        # filters ordered by (i, j, out). Keep only the surviving stripes, in
        # the same order that forward() walks them.
        self.weight = Parameter(
            self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1)[
                self.FilterSkeleton.permute(1, 2, 0).reshape(-1)])

    def update_skeleton(self, sr, threshold):
        """Apply L1 regularisation to the skeleton and zero small entries.

        Call this after backward(). It does three things:

        1. Adds ``sr * sign(FS)`` (the L1 sub-gradient) to ``FS.grad``.
        2. Zeroes the FS entries with ``|FS| <= threshold``, and their gradients.
        3. Returns the per-filter mask of filters that still have a live stripe.
        """
        self.FilterSkeleton.grad.data.add_(sr * torch.sign(self.FilterSkeleton.data))
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)
        self.FilterSkeleton.grad.data.mul_(mask)
        out_mask = mask.sum(dim=(1, 2)) != 0
        return out_mask

    def shift(self, x, i, j):
        """Move a 1x1 response so that it lines up with kernel position ``(i, j)``.

        Negative pad amounts crop, so the spatial size of ``x`` is preserved.
        """
        rows, cols = self.BrokenTarget.shape
        return F.pad(x, (cols // 2 - j, j - cols // 2, rows // 2 - i, i - rows // 2), 'constant', 0)

    def extra_repr(self):
        s = ('{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        return s.format(**self.__dict__)
class Frankenstein(nn.Module):
    """Recurrent controller with an external memory, in the style of the
    Differentiable Neural Computer.

    The attribute and method names follow that paper's terminology: usage
    vector, allocation weighting, temporal link matrix, precedence weighting
    and read modes.

    One call to ``forward`` processes a single timestep:

    1. The input is concatenated with the previous read vectors and fed to
       ``L`` stacked ``RNN_Unit`` layers.
    2. The flattened hidden state is projected to an output (``W_y``) and to
       an interface vector (``W_E``) of size ``E_t = W*R + 3*W + 5*R + 3``.
    3. The interface vector drives one memory write and ``R`` memory reads.
    4. The read vectors are projected by ``W_r`` and added to the output.

    All recurrent state (hidden state, memory, usage, linkage, precedence,
    last read/write weightings, last read vector) is kept in
    ``requires_grad=False`` Parameters rebuilt from ``.data``. It is
    therefore detached between timesteps. Buffers are allocated with
    ``.cuda()``, so a GPU is required.

    Args:
        x: input size forwarded to every ``RNN_Unit``.
        h: hidden size of each controller layer.
        L: number of controller layers.
        v_t: output size.
        W: memory word size.
        R: number of read heads.
        N: number of memory locations.
        bs: batch size. It is fixed because the state buffers are sized by it.
        reset: if True, ``new_sequence_reset`` also resets the memory state.
        palette: if truthy, memory resets restore the stored ``initialz``
            memory instead of drawing a fresh random one.
    """

    def __init__(self, x=47764, h=128, L=16, v_t=3620, W=32, R=8, N=512, bs=1, reset=True, palette=False):
        super(Frankenstein, self).__init__()
        self.reset = reset
        # debugging usages
        self.last_state_dict = None

        # --- hyperparameters ---
        self.x = x
        self.h = h
        self.L = L
        self.v_t = v_t
        self.W = W
        self.R = R
        self.N = N
        self.bs = bs
        # Interface vector layout (see forward):
        #   read keys (W*R), read strengths (R), write key (W), write strength (1),
        #   erase vector (W), write vector (W), free gates (R),
        #   allocation gate (1), write gate (1), read modes (3*R).
        self.E_t = W * R + 3 * W + 5 * R + 3

        # --- controller ---
        self.RNN_list = nn.ModuleList()
        for _ in range(self.L):
            self.RNN_list.append(RNN_Unit(self.x, self.R, self.W, self.h, self.bs))
        self.hidden_previous_timestep = Parameter(torch.Tensor(self.bs, self.L, self.h).cuda(),
                                                  requires_grad=False)
        self.W_y = Parameter(torch.Tensor(self.L * self.h, self.v_t).cuda())
        self.W_E = Parameter(torch.Tensor(self.L * self.h, self.E_t).cuda())
        # NOTE(review): b_y and b_E are initialised but never added in forward(); confirm intended.
        self.b_y = Parameter(torch.Tensor(self.v_t).cuda())
        self.b_E = Parameter(torch.Tensor(self.E_t).cuda())

        # --- memory ---
        # precedence p, (bs, N), should be simplex bounded
        self.precedence_weighting = Parameter(torch.Tensor(self.bs, self.N).cuda(), requires_grad=False)
        # temporal link matrix, (bs, N, N)
        self.temporal_memory_linkage = Parameter(torch.Tensor(self.bs, self.N, self.N).cuda(),
                                                 requires_grad=False)
        # memory M, (N, W); shared across the batch
        self.memory = Parameter(torch.Tensor(self.N, self.W).cuda(), requires_grad=False)
        # previous read weightings, (bs, N, R)
        self.last_read_weightings = Parameter(torch.Tensor(self.bs, self.N, self.R).cuda(),
                                              requires_grad=False)
        # usage u_t, (bs, N)
        self.last_usage_vector = Parameter(torch.Tensor(self.bs, self.N).cuda(), requires_grad=False)
        # previous write weighting, needed by the usage-vector update
        self.last_write_weighting = Parameter(torch.Tensor(self.bs, self.N).cuda(), requires_grad=False)

        self.palette = None
        if palette:
            stdv = 1.0
            self.memory.data.uniform_(-stdv, stdv)
            self.palette = palette
            # Shares storage with self.memory.data, so the in-place
            # re-randomisation in reset_parameters() also changes it.
            self.initialz = self.memory.data

        self.first_t_flag = True

        # --- computer ---
        # previous read vectors, (bs, W, R)
        self.last_read_vector = Parameter(torch.Tensor(self.bs, self.W, self.R).cuda(), requires_grad=False)
        self.W_r = Parameter(torch.Tensor(self.W * self.R, self.v_t).cuda())

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every weight and zero every recurrent-state buffer."""
        # controller
        for module in self.RNN_list:
            # this should iterate over RNN_Units only
            module.reset_parameters()
        self.hidden_previous_timestep.zero_()
        stdv = 1.0 / math.sqrt(self.v_t)
        self.W_y.data.uniform_(-stdv, stdv)
        self.b_y.data.uniform_(-stdv, stdv)
        stdv = 1.0 / math.sqrt(self.E_t)
        self.W_E.data.uniform_(-stdv, stdv)
        self.b_E.data.uniform_(-stdv, stdv)

        # memory
        self.precedence_weighting.zero_()
        self.last_usage_vector.zero_()
        self.last_read_weightings.zero_()
        self.last_write_weighting.zero_()
        self.temporal_memory_linkage.zero_()
        # memory must be initialized like this, otherwise usage vector will be stuck at zero.
        stdv = 1.0
        self.memory.data.uniform_(-stdv, stdv)
        self.first_t_flag = True

        # computer (see paper, paragraph 2 page 7)
        self.last_read_vector.zero_()
        stdv = 1.0 / math.sqrt(self.v_t)
        self.W_r.data.uniform_(-stdv, stdv)

    def new_sequence_reset(self):
        """Reset recurrent state before a new input sequence.

        The controller hidden state and the last read vector are always
        zeroed. When ``self.reset`` is set, the memory is restored (palette
        mode) or re-randomised, and the usage, precedence, linkage and
        read/write weightings are zeroed.

        Open question from the original author: should memory be reset at
        every new sequence? If memory is not reset per sequence, perhaps it
        should never be reset at all.

        NOTE(review): W_y, b_y, W_E, b_E and W_r are re-wrapped in new
        Parameter objects. They share storage with the old ones, but an
        optimizer constructed earlier still holds the old objects, and those
        no longer receive ``.grad``. Confirm that the optimizer is rebuilt
        after calling this.
        """
        # controller
        self.hidden_previous_timestep = Parameter(torch.Tensor(self.bs, self.L, self.h).zero_().cuda(),
                                                  requires_grad=False)
        for RNN in self.RNN_list:
            RNN.new_sequence_reset()
        self.W_y = Parameter(self.W_y.data)
        self.b_y = Parameter(self.b_y.data)
        self.W_E = Parameter(self.W_E.data)
        self.b_E = Parameter(self.b_E.data)

        # memory
        if self.reset:
            if self.palette:
                self.memory.data = self.initialz
            else:
                # Reset the memory altogether.
                # TODO: should memory be reset to a fixed state instead? There are good arguments for it.
                # No gradient carries over: requires_grad on this Parameter is False.
                stdv = 1.0
                self.memory.data.uniform_(-stdv, stdv)
            self.last_usage_vector.zero_()
            self.precedence_weighting.zero_()
            self.temporal_memory_linkage.zero_()
            self.last_read_weightings.zero_()
            self.last_write_weighting.zero_()
        self.first_t_flag = True

        # computer
        self.last_read_vector = Parameter(torch.Tensor(self.bs, self.W, self.R).zero_().cuda(),
                                          requires_grad=False)
        self.W_r = Parameter(self.W_r.data)

    def forward(self, input):
        """Run one timestep.

        ``input`` is ``(bs, features)``; it is concatenated with the flattened
        previous read vectors. Returns the ``(bs, v_t)`` output.

        Raises:
            ValueError: if a NaN appears in the input, the controller output,
                the read vectors or the final output.
        """
        if (input != input).any():
            raise ValueError("We have NAN in inputs")
        input_x_t = torch.cat((input, self.last_read_vector.view(self.bs, -1)), dim=1)

        # --- controller ---
        hidden_previous_layer = Variable(torch.Tensor(self.bs, self.h).zero_().cuda())
        hidden_this_timestep = Variable(torch.Tensor(self.bs, self.L, self.h).cuda())
        for i in range(self.L):
            hidden_output = self.RNN_list[i](input_x_t, self.hidden_previous_timestep[:, i, :],
                                             hidden_previous_layer)
            if (hidden_output != hidden_output).any():
                raise ValueError("We have NAN in controller output.")
            hidden_this_timestep[:, i, :] = hidden_output
            hidden_previous_layer = hidden_output

        flat_hidden = hidden_this_timestep.view((self.bs, self.L * self.h))
        output = torch.matmul(flat_hidden, self.W_y)
        interface_input = torch.matmul(flat_hidden, self.W_E)
        # this detaches hidden from previous hidden.
        self.hidden_previous_timestep = Parameter(hidden_this_timestep.data, requires_grad=False)

        # --- interface: slice interface_input in the order documented in __init__ ---
        last_index = self.W * self.R
        # Read keys: the address keys, not the contents. (bs, W, R)
        read_keys = interface_input[:, 0:last_index].contiguous().view(self.bs, self.W, self.R)
        # Read strengths, (bs, R). 1 - logsigmoid(z) = 1 + softplus(-z), which lies in (1, inf).
        # This differs slightly from the paper's oneplus.
        read_strengths = interface_input[:, last_index:last_index + self.R]
        last_index = last_index + self.R
        read_strengths = 1 - nn.functional.logsigmoid(read_strengths)
        # Write key, (bs, W)
        write_key = interface_input[:, last_index:last_index + self.W]
        last_index = last_index + self.W
        # Write strength, (bs, 1), same squashing as the read strengths
        write_strength = interface_input[:, last_index:last_index + 1]
        last_index = last_index + 1
        write_strength = 1 - nn.functional.logsigmoid(write_strength)
        # Erase vector, (bs, W), in (0, 1)
        erase_vector = interface_input[:, last_index:last_index + self.W]
        last_index = last_index + self.W
        erase_vector = torch.sigmoid(erase_vector)
        # Write vector, (bs, W)
        write_vector = interface_input[:, last_index:last_index + self.W]
        last_index = last_index + self.W
        # Free gates, (bs, R), in (0, 1)
        free_gates = interface_input[:, last_index:last_index + self.R]
        last_index = last_index + self.R
        free_gates = torch.sigmoid(free_gates)
        # Allocation gate, (bs, 1)
        allocation_gate = interface_input[:, last_index:last_index + 1]
        last_index = last_index + 1
        allocation_gate = torch.sigmoid(allocation_gate)
        # Write gate, (bs, 1)
        write_gate = interface_input[:, last_index:last_index + 1]
        last_index = last_index + 1
        write_gate = torch.sigmoid(write_gate)
        # Read modes, (bs, R, 3), softmax over (backward, content, forward)
        read_modes = interface_input[:, last_index:last_index + self.R * 3]
        read_modes = read_modes.contiguous().view(self.bs, self.R, 3)
        read_modes = nn.functional.softmax(read_modes, dim=2)

        # --- memory ---
        memory_retention = self.memory_retention(free_gates)
        # usage vector update must be called before allocation weighting.
        self.update_usage_vector(memory_retention)
        allocation_weighting = self.allocation_weighting()
        write_weighting = self.write_weighting(write_key, write_strength, allocation_gate, write_gate,
                                               allocation_weighting)
        self.write_to_memory(write_weighting, erase_vector, write_vector)
        # The linkage update reads the *previous* precedence, so it must run before the precedence update.
        self.update_temporal_linkage_matrix(write_weighting)
        self.update_precedence_weighting(write_weighting)

        forward_weighting = self.forward_weighting()
        backward_weighting = self.backward_weighting()

        read_weightings = self.read_weightings(forward_weighting, backward_weighting, read_keys,
                                               read_strengths, read_modes)
        # Read from memory last.
        # NOTE(review): read_vector is rebuilt from .data, so output2 carries no
        # gradient back into interface_input (and hence W_E). Confirm this is intended.
        read_vector = Parameter(self.read_memory(read_weightings).data, requires_grad=False)
        if (read_vector != read_vector).any():
            # this is a problem! TODO
            raise ValueError("nan is found.")

        # --- back to computer ---
        output2 = output + torch.matmul(read_vector.view(self.bs, self.W * self.R), self.W_r)

        # update the last weightings
        self.last_read_vector = read_vector
        self.last_read_weightings = Parameter(read_weightings.data, requires_grad=False)
        self.last_write_weighting = Parameter(write_weighting.data, requires_grad=False)

        self.first_t_flag = False

        if debug:
            test_simplex_bound(self.last_read_weightings)
            test_simplex_bound(self.last_write_weighting)

        if (output2 != output2).any():
            raise ValueError("nan is found.")

        return output2

    def write_content_weighting(self, write_key, key_strength, eps=1e-8):
        """Content-based write addressing.

        :param write_key: k, (bs, W), the desired content
        :param key_strength: beta, (bs, 1), in [1, inf)
        :param eps: lower clamp on the norm product, to avoid division by zero
        :return: (bs, N), a softmax over locations of beta * cosine(memory row, write_key)
        """
        # (bs, N)
        innerprod = torch.matmul(write_key, self.memory.t())
        # (N)
        memnorm = torch.norm(self.memory, 2, 1)
        # (bs)
        writenorm = torch.norm(write_key, 2, 1)
        # (N, bs)
        normalizer = torch.ger(memnorm, writenorm)
        similarties = innerprod / normalizer.t().clamp(min=eps)
        similarties = similarties * key_strength.expand(-1, self.N)
        normalized = softmax(similarties, dim=1)
        if debug:
            if (normalized != normalized).any():
                task_dir = os.path.dirname(abspath(__file__))
                save_dir = Path(task_dir) / "saves" / "keykey.pkl"
                # Close the dump file before raising.
                with save_dir.open('wb') as dump_file:
                    pickle.dump((write_key.cpu(), key_strength.cpu()), dump_file)
                raise ValueError("NA found in write content weighting")
        return normalized

    def read_content_weighting(self, read_keys, key_strengths, eps=1e-8):
        """Content-based read addressing for all R read heads.

        :param read_keys: (bs, W, R), one key per read head
        :param key_strengths: beta, (bs, R), in [1, inf)
        :param eps: lower clamp on the norm product
        :return: (bs, N, R), a softmax over locations (dim=1) of beta * cosine similarity
        """
        # (bs, N, R)
        innerprod = torch.matmul(self.memory.unsqueeze(0), read_keys)
        # memory[n] is the n-th row (location), so the norm is taken over dim 1 (the word).
        mem_norm = torch.norm(self.memory, p=2, dim=1)
        read_norm = torch.norm(read_keys, p=2, dim=1)
        # (N, 1) @ (bs, 1, R) -> (bs, N, R)
        mem_norm = mem_norm.unsqueeze(1)
        read_norm = read_norm.unsqueeze(1)
        normalizer = torch.matmul(mem_norm, read_norm)
        similarties = innerprod / normalizer.clamp(min=eps)
        weighted = similarties * key_strengths.unsqueeze(1).expand(-1, self.N, -1)
        ret = softmax(weighted, dim=1)
        return ret

    def memory_retention(self, free_gate):
        """Memory retention psi = prod over heads of (1 - f_i * w^r_i).

        :param free_gate: f, (bs, R), in [0, 1]
        :return: psi, (bs, N), in [0, 1]. It uses the previous step's read weightings.
        """
        # (bs, N, R)
        inside_bracket = 1 - self.last_read_weightings * free_gate.unsqueeze(1).expand(-1, self.N, -1)
        ret = torch.prod(inside_bracket, 2)
        return ret

    def update_usage_vector(self, memory_retention):
        """Update and store the usage u_t = (u + w - u*w) * psi.

        On the first timestep this returns zeros and leaves
        ``self.last_usage_vector`` untouched.

        :param memory_retention: psi, (bs, N), in [0, 1]
        :return: u_t, (bs, N), in [0, 1]
        """
        if self.first_t_flag:
            ret = Parameter(torch.Tensor(self.bs, self.N).zero_().cuda(), requires_grad=False)
            return ret
        ret = (self.last_usage_vector + self.last_write_weighting
               - self.last_usage_vector * self.last_write_weighting) * memory_retention
        # Stored detached. Open question: should the write weighting be back-propagated here?
        self.last_usage_vector = Parameter(ret.data, requires_grad=False)
        return ret

    def allocation_weighting(self):
        """Allocation weighting a_t from the current usage vector.

        Locations are sorted by ascending usage (the free list). The k-th
        least-used location receives (1 - its usage) times the product of the
        usages of all locations less used than it. The result is written back
        to each location's natural index.

        :return: a_t, (bs, N)
        """
        # Not the last usage in time: update_usage_vector has already run this step.
        sorted_usage, free_list = self.last_usage_vector.sort(dim=1)
        # Exclusive cumulative product: entry k is the product of usages 0..k-1.
        cum_prod = torch.cumprod(sorted_usage, 1)
        cum_prod = torch.cat([sorted_usage.new_ones(self.bs, 1), cum_prod], 1)[:, :-1]
        sorted_allocation = (1 - sorted_usage) * cum_prod
        # Sorted slot k belongs to location free_list[k], so scatter it back.
        # A gather with free_list would apply the inverse permutation instead.
        ret = torch.zeros_like(sorted_allocation).scatter(1, free_list, sorted_allocation)
        if debug:
            if (ret != ret).any():
                raise ValueError("NA found in allocation weighting")
        return ret

    def write_weighting(self, write_key, write_strength, allocation_gate, write_gate, allocation_weighting):
        """Blend allocation-based and content-based write addressing.

        :param write_key: k^w_t, (bs, W)
        :param write_strength: beta, (bs, 1); larger values sharpen the content weighting
        :param allocation_gate: g^a_t, (bs, 1), balances allocation against content
        :param write_gate: g^w_t, (bs, 1), overall write strength
        :param allocation_weighting: a_t, (bs, N)
        :return: write weighting, (bs, N)
        """
        content_weighting = self.write_content_weighting(write_key, write_strength)
        write_weighting = write_gate * (allocation_gate * allocation_weighting
                                        + (1 - allocation_gate) * content_weighting)
        if debug:
            test_simplex_bound(write_weighting, 1)
        return write_weighting

    def update_precedence_weighting(self, write_weighting):
        """p_t = (1 - sum(w)) * p_{t-1} + w, stored detached.

        :param write_weighting: (bs, N)
        :return: self.precedence_weighting, (bs, N)
        """
        # Must be torch.sum, not Python's builtin sum().
        sum_ww = torch.sum(write_weighting, dim=1)
        self.precedence_weighting = Parameter(
            ((1 - sum_ww).unsqueeze(1) * self.precedence_weighting + write_weighting).data,
            requires_grad=False)
        if debug:
            test_simplex_bound(self.precedence_weighting, 1)
        return self.precedence_weighting

    def update_temporal_linkage_matrix(self, write_weighting):
        """L[i, j] = (1 - w_i - w_j) * L[i, j] + w_i * p_j, with a zero diagonal.

        This uses the precedence from the previous step. On the first
        timestep it is a no-op.

        :param write_weighting: (bs, N)
        :return: self.temporal_memory_linkage, (bs, N, N)
        """
        # TODO: understand mathematically why this update maintains the simplex bound.
        if self.first_t_flag:
            return self.temporal_memory_linkage
        ww_j = write_weighting.unsqueeze(1).expand(-1, self.N, -1)
        ww_i = write_weighting.unsqueeze(2).expand(-1, -1, self.N)
        p_j = self.precedence_weighting.unsqueeze(1).expand(-1, self.N, -1)
        batch_temporal_memory_linkage = self.temporal_memory_linkage.expand(self.bs, -1, -1)
        newtml = Parameter(((1 - ww_j - ww_i) * batch_temporal_memory_linkage + ww_i * p_j).data,
                           requires_grad=False)
        # Zero the diagonal: a location never links to itself.
        idx = torch.arange(self.N, device=newtml.device)
        newtml[:, idx, idx] = 0
        if debug:
            try:
                test_simplex_bound(newtml, 1)
                test_simplex_bound(newtml.transpose(1, 2), 1)
            except ValueError:
                traceback.print_exc()
                print("precedence close to one?", self.precedence_weighting.sum() > 1)
                raise
        self.temporal_memory_linkage = Parameter(newtml.data, requires_grad=False)
        return self.temporal_memory_linkage

    def backward_weighting(self):
        """:return: backward weighting b^i_t = L @ w^r_{t-1}, (bs, N, R)"""
        ret = torch.matmul(self.temporal_memory_linkage, self.last_read_weightings)
        if debug:
            test_simplex_bound(ret, 1)
        return ret

    def forward_weighting(self):
        """:return: forward weighting f^i_t = L^T @ w^r_{t-1}, (bs, N, R)"""
        ret = torch.matmul(self.temporal_memory_linkage.transpose(1, 2), self.last_read_weightings)
        if debug:
            test_simplex_bound(ret, 1)
        return ret

    # TODO sparse update, skipped because it's for performance improvement.
    def read_weightings(self, forward_weighting, backward_weighting, read_keys, read_strengths, read_modes):
        """Mix the backward, content and forward weightings using the read modes.

        :param forward_weighting: (bs, N, R)
        :param backward_weighting: (bs, N, R)
        :param read_keys: (bs, W, R)
        :param read_strengths: (bs, R)
        :param read_modes: pi, (bs, R, 3), ordered (backward, content, forward)
        :return: read weightings w^r_t, (bs, N, R)
        """
        content_weighting = self.read_content_weighting(read_keys, read_strengths)
        if debug:
            test_simplex_bound(content_weighting, 1)
            test_simplex_bound(backward_weighting, 1)
            test_simplex_bound(forward_weighting, 1)
        # (bs, 3, N, R)
        all_weightings = torch.stack([backward_weighting, content_weighting, forward_weighting], dim=1)
        # (bs, R, N, 3): torch.matmul batches over every dimension except the last two
        all_weightings = all_weightings.permute(0, 3, 2, 1)
        # (bs, R, 3, 1)
        read_modes = read_modes.unsqueeze(3)
        # (bs, R, N, 1) -> (bs, N, R)
        read_weightings = torch.matmul(all_weightings, read_modes).squeeze(3).transpose(1, 2)
        if debug:
            # if the second test passes, how come the first one does not?
            test_simplex_bound(self.last_read_weightings, 1)
            test_simplex_bound(read_weightings, 1)
        if (read_weightings != read_weightings).any():
            raise ValueError("NAN is found")
        return read_weightings

    def read_memory(self, read_weightings):
        """Read vectors M^T @ w^r.

        :param read_weightings: (bs, N, R)
        :return: (bs, W, R)
        """
        return torch.matmul(self.memory.t(), read_weightings)

    def write_to_memory(self, write_weighting, erase_vector, write_vector):
        """Erase-then-add memory write: M <- M * (1 - w e^T) + w v^T.

        The per-sample memories are averaged over the batch, because
        ``self.memory`` is a single (N, W) matrix shared by the batch. The
        result is stored detached.

        :param write_weighting: (bs, N)
        :param erase_vector: e_t, (bs, W), in [0, 1]
        :param write_vector: v_t, (bs, W)
        """
        term1_2 = torch.matmul(write_weighting.unsqueeze(2), erase_vector.unsqueeze(1))
        term1 = self.memory.unsqueeze(0) * (1 - term1_2)
        term2 = torch.matmul(write_weighting.unsqueeze(2), write_vector.unsqueeze(1))
        self.memory = Parameter(torch.mean(term1 + term2, dim=0).data, requires_grad=False)