Example #1
class nice(nn.Module):
    def __init__(self, im_size):
        super(nice, self).__init__()
        self.cp1 = additive_alternant_coupling_layer(im_size, 1024)
        self.cp2 = additive_alternant_coupling_layer(im_size, 1024)
        self.cp3 = additive_alternant_coupling_layer(im_size, 1024)
        self.scale = Parameter(torch.zeros(im_size))
        self.to(DEVICE)

    def forward(self, x, inv=False):
        if not inv:
            cp1 = self.cp1(x)
            cp2 = self.cp2(cp1)
            cp3 = self.cp3(cp2)
            return torch.exp(self.scale) * cp3
        else:
            cp3 = x * torch.exp(-self.scale)
            cp2 = self.cp3(cp3, True)
            cp1 = self.cp2(cp2, True)
            return self.cp1(cp1, True)

    def log_logistic(self, h):
        return -(F.softplus(h) + F.softplus(-h))

    def train_loss(self, h):
        return -(self.log_logistic(h).sum(1).mean() + self.scale.sum())
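
The log_logistic term above is the log-density of a standard logistic prior. As a quick sanity check (my addition, not part of the original example), -(softplus(h) + softplus(-h)) equals log(sigmoid(h) * sigmoid(-h)), i.e. the log of the standard logistic pdf:

import torch
import torch.nn.functional as F

h = torch.randn(5)
via_softplus = -(F.softplus(h) + F.softplus(-h))
via_pdf = torch.log(torch.sigmoid(h) * torch.sigmoid(-h))  # logistic pdf = sigmoid(h) * (1 - sigmoid(h))
print(torch.allclose(via_softplus, via_pdf))  # True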
Example #2
class PopArt(Module):
    """PopArt http://papers.nips.cc/paper/6076-learning-values-across-many-orders-of-magnitude"""

    def __init__(self, output_layer, beta: float = 0.0003, zero_debias: bool = True, start_pop: int = 8):
        # zero_debias=True and start_pop=8 seem to improve things a little, but (False, 0) works as well
        super().__init__()
        self.start_pop = start_pop
        self.beta = beta
        self.zero_debias = zero_debias
        self.output_layers = output_layer if isinstance(output_layer, (tuple, list, torch.nn.ModuleList)) else (output_layer,)
        shape = self.output_layers[0].bias.shape
        device = self.output_layers[0].bias.device
        assert all(shape == x.bias.shape for x in self.output_layers)
        self.mean = Parameter(torch.zeros(shape, device=device), requires_grad=False)
        self.mean_square = Parameter(torch.ones(shape, device=device), requires_grad=False)
        self.std = Parameter(torch.ones(shape, device=device), requires_grad=False)
        self.updates = 0

    @torch.no_grad()
    def update(self, targets):
        beta = max(1 / (self.updates + 1), self.beta) if self.zero_debias else self.beta
        # note: with beta = 1/(updates + 1) (the zero-debias value), the running mean and std equal the exact mean and std over all past data

        new_mean = (1 - beta) * self.mean + beta * targets.mean(0)
        new_mean_square = (1 - beta) * self.mean_square + beta * (targets * targets).mean(0)
        new_std = (new_mean_square - new_mean * new_mean).sqrt().clamp(0.0001, 1e6)

        # assert self.std.shape == (1,), 'this has only been tested in 1D'

        if self.updates >= self.start_pop:
            for layer in self.output_layers:
                layer.weight *= (self.std / new_std)[:, None]
                layer.bias *= self.std
                layer.bias += self.mean - new_mean
                layer.bias /= new_std

        self.mean.copy_(new_mean)
        self.mean_square.copy_(new_mean_square)
        self.std.copy_(new_std)
        self.updates += 1
        return self.normalize(targets)

    def normalize(self, x):
        return (x - self.mean) / self.std

    def unnormalize(self, x):
        return x * self.std + self.mean

    def normalize_sum(self, s):
        """normalize x.sum(1) preserving relative weightings between elements"""
        return (s - self.mean.sum()) / self.std.norm()
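
A minimal usage sketch (my own illustration, assuming only the class above and its torch imports): PopArt trains the critic head against normalized targets while rescaling the head's weights so that unnormalized predictions stay consistent as the running statistics drift.

import torch

head = torch.nn.Linear(16, 1)                   # output layer whose scale PopArt manages
popart = PopArt(head, beta=3e-4)

returns = torch.randn(64, 1) * 100.0 + 50.0     # targets on an arbitrary scale
normed_targets = popart.update(returns)         # update running mean/std, rescale the head
features = torch.randn(64, 16)
loss = ((head(features) - normed_targets) ** 2).mean()    # regress in normalized space
value = popart.unnormalize(head(features))                # predictions on the original scale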
Example #3
class SimplePolicyContinuous(nn.Module):
    """ Simple policy for continuous actions, using a Gaussian normal as the distribution from which
    to sample the actions. The parameters (mean, std) of the distribution are computed from the state
    using a neural network. The architecture is taken from:
    https://github.com/lantunes/mountain-car-continuous/blob/master/rl/reinforce/agent.py
    """
    def __init__(self, input_size: int, output_size: int):
        super(SimplePolicyContinuous, self).__init__()
        self.output_size = output_size

        self.affine1Mu = nn.Linear(input_size, 128)
        self.affine1Mu.bias.data.fill_(0)
        self.affine2Mu = nn.Linear(128, 128)
        self.affine2Mu.bias.data.fill_(0)
        self.affine3Mu = nn.Linear(128, output_size, bias=False)
        self.affine3Mu.weight.data.fill_(0)

        # Important: It must be a Parameter and not a variable! If it's a variable
        # then it won't be part of the parameters given to the optimizer, meaning
        # it will never change. This will mean sigma will never improve and so
        # the results will stay very bad.
        self.hiddenSigma = Parameter(torch.zeros(32), requires_grad=True)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mu = self.affine1Mu(x)
        mu = torch.tanh(mu)
        mu = self.affine2Mu(mu)
        mu = torch.tanh(mu)
        mu = self.affine3Mu(mu)
        """
        This network has decided that the sigma will not depend on the state. 
        A possible question one might be asking is: why is there the exp() call below?
        I found the answer in the Spinning Up website by OpenAI:
        
            "Note that in both cases we output log standard deviations instead of standard deviations directly. 
            This is because log stds are free to take on any values in (-oo, oo), while stds must be nonnegative. 
            It’s easier to train parameters if you don’t have to enforce those kinds of constraints. The standard 
            deviations can be obtained immediately from the log standard deviations by exponentiating them, so 
            we do not lose anything  by representing them this way."
            
        Src: https://spinningup.openai.com/en/latest/spinningup/rl_intro.html
        """
        sigma = self.hiddenSigma.sum()
        sigma = torch.exp(sigma)

        return mu, sigma
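
For context, a hedged sketch of how such a Gaussian policy is typically used in REINFORCE-style training (my illustration, with arbitrary sizes; it assumes the class above and its imports are in scope):

import torch
from torch.distributions import Normal

policy = SimplePolicyContinuous(input_size=2, output_size=1)
state = torch.randn(2)                 # e.g. a MountainCarContinuous observation
mu, sigma = policy(state)
dist = Normal(mu, sigma)
action = dist.sample()
log_prob = dist.log_prob(action)       # -log_prob * return gives the REINFORCE loss term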
Example #4
class ProbalisticLinear(Module):
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, eps):
        super(ProbalisticLinear, self).__init__()
        self.in_features = in_features
        self.eps = eps
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    @weak_script_method
    def forward(self, input):
        weight = self.weight / self.weight.sum(1, keepdim=True)  # .clamp(min=self.eps)
        return F.linear(input, weight)

    def extra_repr(self):
        return 'in_features={}, out_features={} (eps={})'.format(
            self.in_features, self.out_features, self.eps
        )
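
A quick check of the idea (my addition, assuming the class above and its imports): the effective weight is row-normalized, so each output is a weighted combination of the inputs whose coefficients sum to one (they can still be negative, since the clamp is commented out).

import torch

layer = ProbalisticLinear(in_features=4, out_features=2, eps=1e-6)
w = layer.weight / layer.weight.sum(1, keepdim=True)
print(w.sum(1))                            # ~1.0 for every row
print(layer(torch.randn(3, 4)).shape)      # torch.Size([3, 2])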
Example #5
class FilterStripe(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels,
                                           out_channels,
                                           kernel_size,
                                           stride,
                                           kernel_size // 2,
                                           groups=1,
                                           bias=False)
        self.BrokenTarget = None
        self.FilterSkeleton = Parameter(torch.ones(self.out_channels,
                                                   self.kernel_size[0],
                                                   self.kernel_size[1]),
                                        requires_grad=True)

    def forward(self, x):
        if self.BrokenTarget is not None:
            out = torch.zeros(x.shape[0], self.FilterSkeleton.shape[0],
                              int(np.ceil(x.shape[2] / self.stride[0])),
                              int(np.ceil(x.shape[3] / self.stride[1])))
            if x.is_cuda:
                out = out.cuda()
            x = F.conv2d(x, self.weight)
            l, h = 0, 0
            for i in range(self.BrokenTarget.shape[0]):
                for j in range(self.BrokenTarget.shape[1]):
                    h += self.FilterSkeleton[:, i, j].sum().item()
                    out[:, self.FilterSkeleton[:, i, j]] += self.shift(
                        x[:, l:h], i,
                        j)[:, :, ::self.stride[0], ::self.stride[1]]
                    l += self.FilterSkeleton[:, i, j].sum().item()
            return out
        else:
            return F.conv2d(x,
                            self.weight * self.FilterSkeleton.unsqueeze(1),
                            stride=self.stride,
                            padding=self.padding,
                            groups=self.groups)

    def prune_in(self, in_mask=None):
        self.weight = Parameter(self.weight[:, in_mask])
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0
        if out_mask.sum() == 0:
            out_mask[0] = True
        self.weight = Parameter(self.weight[out_mask])
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask],
                                        requires_grad=True)
        self.out_channels = out_mask.sum().item()
        return out_mask

    def _break(self, threshold):
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))
        self.FilterSkeleton = Parameter(
            (self.FilterSkeleton.abs() > threshold), requires_grad=False)
        if self.FilterSkeleton.sum() == 0:
            self.FilterSkeleton.data[0][0][0] = True
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)
        self.kernel_size = (1, 1)
        self.weight = Parameter(
            self.weight.permute(2, 3, 0,
                                1).reshape(-1, self.in_channels, 1,
                                           1)[self.FilterSkeleton.permute(
                                               1, 2, 0).reshape(-1)])

    def update_skeleton(self, sr, threshold):
        self.FilterSkeleton.grad.data.add_(
            sr * torch.sign(self.FilterSkeleton.data))
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)
        self.FilterSkeleton.grad.data.mul_(mask)
        out_mask = mask.sum(dim=(1, 2)) != 0
        return out_mask

    def shift(self, x, i, j):
        return F.pad(x, (self.BrokenTarget.shape[0] // 2 - j,
                         j - self.BrokenTarget.shape[0] // 2,
                         self.BrokenTarget.shape[0] // 2 - i,
                         i - self.BrokenTarget.shape[1] // 2), 'constant', 0)

    def extra_repr(self):
        s = (
            '{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
            ', stride={stride}')
        return s.format(**self.__dict__)
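
A rough sketch of how this stripe-wise pruning layer appears to be driven during training (my reading of the code, not taken from the source project; it assumes the class above and its imports): the dense forward path gates the kernels with FilterSkeleton, and update_skeleton applies the L1 push plus hard threshold after each backward pass.

import torch

conv = FilterStripe(in_channels=8, out_channels=16, kernel_size=3, stride=1)
x = torch.randn(2, 8, 32, 32)
out = conv(x)                                           # dense path: weight * FilterSkeleton
out.sum().backward()
keep = conv.update_skeleton(sr=1e-4, threshold=0.01)    # L1 subgradient + hard threshold
print(out.shape, keep.shape)                            # [2, 16, 32, 32], [16]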
Example #6
class PGDAttack(BaseAttack):
    """
    Spectral attack for graph data
    """
    def __init__(self,
                 model=None,
                 nnodes=None,
                 loss_type='CE',
                 feature_shape=None,
                 attack_structure=True,
                 attack_features=False,
                 loss_weight=1.0,
                 regularization_weight=0.0,
                 device='cpu'):

        super(PGDAttack, self).__init__(model, nnodes, attack_structure,
                                        attack_features, device)

        assert attack_structure or attack_features, 'attack_features and attack_structure cannot both be False'

        self.loss_type = loss_type
        self.modified_adj = None
        self.modified_features = None
        self.loss_weight = loss_weight
        self.regularization_weight = regularization_weight

        if attack_features:
            raise NotImplementedError('Current Spectral Attack does not support attacking features')

        if attack_structure:
            assert nnodes is not None, 'Please give nnodes='
            self.adj_changes = Parameter(
                torch.FloatTensor(int(nnodes * (nnodes - 1) / 2)))
            torch.nn.init.uniform_(self.adj_changes, 0.0, 0.001)
            # self.adj_changes.data.fill_(0)

        self.complementary = None

    def set_model(self, model):
        self.surrogate = model

    def attack(self,
               ori_features,
               ori_adj,
               labels,
               idx_target,
               n_perturbations,
               att_lr,
               epochs=200,
               distance_type='l2',
               sample_type='sample',
               opt_type='max',
               verbose=True,
               **kwargs):
        """
        Generate perturbations on the input graph
        """

        victim_model = self.surrogate

        self.sparse_features = sp.issparse(ori_features)
        # ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)
        ori_adj_norm = utils.normalize_adj_tensor(ori_adj, device=self.device)
        # torch.symeig is deprecated in recent PyTorch; torch.linalg.eigh is the drop-in replacement
        ori_e, ori_v = torch.symeig(ori_adj_norm, eigenvectors=True)

        l, r, m = 0, 0, 0
        victim_model.eval()
        # for t in tqdm(range(epochs), desc='Perturb Adj'):
        for t in tqdm(range(epochs)):
            modified_adj = self.get_modified_adj(ori_adj)
            adj_norm = utils.normalize_adj_tensor(modified_adj,
                                                  device=self.device)
            output = victim_model(
                ori_features,
                adj_norm)  # the GCN forward pass expects a normalized adjacency matrix
            task_loss = self._loss(output[idx_target], labels[idx_target])

            # spectral-distance regularization terms
            eigen_mse = torch.tensor(0)
            eigen_self = torch.tensor(0)
            eigen_gf = torch.tensor(0)
            eigen_norm = self.norm = torch.norm(ori_e)
            if self.regularization_weight != 0:
                # add noise to make the graph asymmetric
                modified_adj_noise = modified_adj
                # modified_adj_noise = self.add_random_noise(modified_adj)
                adj_norm_noise = utils.normalize_adj_tensor(modified_adj_noise,
                                                            device=self.device)
                e, v = torch.symeig(adj_norm_noise, eigenvectors=True)
                eigen_mse = torch.norm(ori_e - e)
                eigen_self = torch.norm(e)

                # low-rank loss in GF-attack
                idx = torch.argsort(e)[:128]
                mask = torch.zeros_like(e).bool()
                mask[idx] = True
                eigen_gf = torch.pow(torch.norm(e * mask, p=2), 2) * torch.pow(
                    torch.norm(torch.matmul(v.detach() * mask, ori_features),
                               p=2), 2)

            reg_loss = 0
            if distance_type == 'l2':
                reg_loss = eigen_mse / eigen_norm
            elif distance_type == 'normDiv':
                reg_loss = eigen_self / eigen_norm
            elif distance_type == 'gf':
                reg_loss = eigen_gf
            else:
                exit(f'unknown distance metric: {distance_type}')

            if verbose and t % 20 == 0:
                loss_target, acc_target = calc_acc(output, labels, idx_target)
                print(
                    '-- Epoch {}, '.format(t),
                    'ptb budget/true = {:.1f}/{:.1f}'.format(
                        n_perturbations,
                        torch.clamp(self.adj_changes, 0, 1).sum()),
                    'l/r/m = {:.4f}/{:.4f}/{:.4f}'.format(l, r, m),
                    'class loss = {:.4f} | '.format(task_loss.item()),
                    'reg loss = {:.4f} | '.format(reg_loss.item()),
                    'mse_norm = {:.4f} | '.format(eigen_norm),
                    'eigen_mse = {:.4f} | '.format(eigen_mse),
                    'eigen_self = {:.4f} | '.format(eigen_self),
                    'acc/mis = {:.4f}/{:.4f}'.format(acc_target,
                                                     1 - acc_target))

            self.loss = self.loss_weight * task_loss + self.regularization_weight * reg_loss

            adj_grad = torch.autograd.grad(self.loss, self.adj_changes)[0]

            if self.loss_type == 'CE':
                lr = att_lr / np.sqrt(t + 1)
                self.adj_changes.data.add_(lr * adj_grad)

            if self.loss_type == 'CW':
                lr = att_lr / np.sqrt(t + 1)
                self.adj_changes.data.add_(lr * adj_grad)

            # return self.adj_changes.cpu().detach().numpy()

            if verbose and t % 20 == 0:
                print('budget/true={:.1f}/{:.1f}'.format(
                    n_perturbations,
                    torch.clamp(self.adj_changes, 0, 1).sum()))

            if sample_type == 'sample':
                l, r, m = self.projection(n_perturbations)
            elif sample_type == 'greedy':
                self.greedy(n_perturbations)
            elif sample_type == 'greedy2':
                self.greedy2(n_perturbations)
            elif sample_type == 'greedy3':
                self.greedy3(n_perturbations)
            else:
                exit(f"unkown sample type {sample_type}")

            if verbose and t % 20 == 0:
                print('budget/true={:.1f}/{:.1f}'.format(
                    n_perturbations,
                    torch.clamp(self.adj_changes, 0, 1).sum()))

        if sample_type == 'sample':
            self.random_sample(ori_adj, ori_features, labels, idx_target,
                               n_perturbations)
        elif sample_type == 'greedy':
            self.greedy(n_perturbations)
        elif sample_type == 'greedy2':
            self.greedy2(n_perturbations)
        elif sample_type == 'greedy3':
            self.greedy3(n_perturbations)
        else:
            exit(f"unkown sample type {sample_type}")

        print("final ptb budget/true= {:.1f}/{:.1f}".format(
            n_perturbations, self.adj_changes.sum()))
        self.modified_adj = self.get_modified_adj(ori_adj).detach()
        self.check_adj_tensor(self.modified_adj)

        # for sanity check
        ori_adj_norm = utils.normalize_adj_tensor(ori_adj, device=self.device)
        ori_e, ori_v = torch.symeig(ori_adj_norm, eigenvectors=True)
        adj_norm = utils.normalize_adj_tensor(self.modified_adj,
                                              device=self.device)
        e, v = torch.symeig(adj_norm, eigenvectors=True)

        self.adj = ori_adj.detach()
        self.labels = labels.detach()
        self.ori_e = ori_e
        self.ori_v = ori_v
        self.e = e
        self.v = v

    def greedy(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        # l = min(s)
        # r = max(s)
        # noise = np.random.normal((l+r)/2, 0.1*(r-l), s.shape)
        # s += noise

        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        # max_index = (-np.absolute(s_vec)).argsort()[:n_perturbations]
        max_index = (-s_vec).argsort()[:n_perturbations]

        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0

        best_s = np.reshape(mask, s.shape)

        self.adj_changes.data.copy_(
            torch.clamp(torch.tensor(best_s), min=0, max=1))

    def greedy3(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        # max_index = (-np.absolute(s_vec)).argsort()[:n_perturbations]
        max_index = (s_vec).argsort()[:n_perturbations]

        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0

        best_s = np.reshape(mask, s.shape)

        self.adj_changes.data.copy_(
            torch.clamp(torch.tensor(best_s), min=0, max=1))

    def greedy2(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        l = min(s)
        r = max(s)
        noise = np.random.normal((l + r) / 2, 0.4 * (r - l), s.shape)
        s += noise

        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        max_index = (-np.absolute(s_vec)).argsort()[:n_perturbations]

        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0

        best_s = np.reshape(mask, s.shape)

        self.adj_changes.data.copy_(
            torch.clamp(torch.tensor(best_s), min=0, max=1))

    def random_sample(self, ori_adj, ori_features, labels, idx_target,
                      n_perturbations):
        K = 10
        best_loss = -1000
        victim_model = self.surrogate
        with torch.no_grad():
            s = self.adj_changes.cpu().detach().numpy()
            for i in range(K):
                sampled = np.random.binomial(1, s)
                # randm = np.random.uniform(size=s.shape[0])
                # sampled = np.where(s > randm, 1, 0)

                # if sampled.sum() > n_perturbations:
                #     continue
                while sampled.sum() > n_perturbations:
                    sampled = np.random.binomial(1, s)
                # if sampled.sum() > n_perturbations:
                #     indices = np.transpose(np.nonzero(sampled))
                #     candidate_idx = [m for m in range(indices.shape[0])]
                #     chosen_idx = np.random.choice(candidate_idx, n_perturbations, replace=False)
                #     chosen_indices = indices[chosen_idx, :]
                #     sampled = np.zeros_like(sampled)
                #     for idx in chosen_indices:
                #         sampled[idx] = 1

                self.adj_changes.data.copy_(torch.tensor(sampled))
                modified_adj = self.get_modified_adj(ori_adj)
                adj_norm = utils.normalize_adj_tensor(modified_adj,
                                                      device=self.device)
                output = victim_model(ori_features, adj_norm)
                loss = self._loss(output[idx_target], labels[idx_target])
                # loss = F.nll_loss(output[idx_target], labels[idx_target])
                # print(loss)
                if best_loss < loss:
                    best_loss = loss
                    best_s = sampled
            self.adj_changes.data.copy_(torch.tensor(best_s))

    def get_modified_adj(self, ori_adj):

        if self.complementary is None:
            self.complementary = (torch.ones_like(ori_adj) - torch.eye(
                self.nnodes).to(self.device) - ori_adj) - ori_adj
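            # (added note) complementary = 1 - I - 2*A, so modified_adj = complementary*m + A
            # adds an edge with weight m[i, j] where A[i, j] = 0 and replaces an existing
            # edge with 1 - m[i, j] where A[i, j] = 1; the diagonal stays zero.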

        m = torch.zeros((self.nnodes, self.nnodes)).to(self.device)
        tril_indices = torch.tril_indices(row=self.nnodes,
                                          col=self.nnodes,
                                          offset=-1)
        m[tril_indices[0], tril_indices[1]] = self.adj_changes
        m = m + m.t()
        modified_adj = self.complementary * m + ori_adj

        return modified_adj

    def add_random_noise(self, ori_adj):
        noise = 1e-4 * torch.rand(self.nnodes, self.nnodes).to(self.device)
        return (noise + torch.transpose(noise, 0, 1)) / 2.0 + ori_adj

    def projection2(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        n = np.squeeze(np.reshape(s, (1, -1))).shape[0]
        self.adj_changes.data.copy_(
            torch.clamp(self.adj_changes.data, min=0, max=n_perturbations / n))
        return 0, 0, 0

    def projection(self, n_perturbations):
        l, r, m = 0, 0, 0
        if torch.clamp(self.adj_changes, 0, 1).sum() > n_perturbations:
            left = (self.adj_changes).min()
            right = self.adj_changes.max()
            miu = self.bisection(left, right, n_perturbations, epsilon=1e-5)
            l = left.cpu().detach()
            r = right.cpu().detach()
            m = miu.cpu().detach()
            self.adj_changes.data.copy_(
                torch.clamp(self.adj_changes.data - miu, min=0, max=1))
        else:
            self.adj_changes.data.copy_(
                torch.clamp(self.adj_changes.data, min=0, max=1))

        return l, r, m

    def _loss(self, output, labels):
        if self.loss_type == "CE":
            loss = F.nll_loss(output, labels)
        if self.loss_type == "CW":
            onehot = utils.tensor2onehot(labels)
            best_second_class = (output - 1000 * onehot).argmax(1).detach()
            margin = output[np.arange(len(output)), labels] - \
                   output[np.arange(len(output)), best_second_class]
            k = 0
            loss = -torch.clamp(margin, min=k).mean()
            # loss = torch.clamp(margin.sum()+50, min=k)
        return loss

    def bisection(self, a, b, n_perturbations, epsilon):
        def func(x):
            return torch.clamp(self.adj_changes - x, 0,
                               1).sum() - n_perturbations

        miu = a
        while ((b - a) >= epsilon):
            miu = (a + b) / 2
            # Check if middle point is root
            if (func(miu) == 0.0):
                b = miu
                break
            # Decide the side to repeat the steps
            if (func(miu) * func(a) < 0):
                b = miu
            else:
                a = miu
        # print("The value of root is : ","%.4f" % miu)
        return miu
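
For intuition, a small self-contained sketch of what projection/bisection above compute (my own illustration, independent of the class): bisection searches for a shift mu such that clamp(s - mu, 0, 1) sums to the perturbation budget.

import torch

def project_to_budget(s, budget, epsilon=1e-5):
    if torch.clamp(s, 0, 1).sum() <= budget:        # already feasible: just clip to [0, 1]
        return torch.clamp(s, 0, 1)
    a, b = s.min(), s.max()
    while (b - a) >= epsilon:                       # bisection on the shift mu
        mu = (a + b) / 2
        if torch.clamp(s - mu, 0, 1).sum() > budget:
            a = mu                                  # clipped sum still above budget: shift more
        else:
            b = mu
    return torch.clamp(s - (a + b) / 2, 0, 1)

s = torch.rand(10) * 2
print(project_to_budget(s, budget=3.0).sum())       # approximately 3.0 (or less)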
Example #7
class SafetyNet(nn.Module):
    def __init__(self,
                 nKnapsackCategories,
                 nThresholds,
                 starting_thresholds,
                 nineq=1,
                 neq=0,
                 eps=1e-8,
                 cancel_rate_target=.05,
                 cancel_rate_evaluation=.05,
                 accept_rate_target=.75,
                 accept_rate_evaluation=.75,
                 cancel_initializer=.02,
                 inventory_initializer=3,
                 cancel_coef_initializer=-.2,
                 cancel_intercept_initializer=.3,
                 price_initializer=1,
                 parametric_knapsack=False,
                 knapsack_type=None):
        super().__init__()
        self.nKnapsackCategories = nKnapsackCategories
        self.nThresholds = nThresholds
        #self.nBatch = nBatch
        self.nineq = nineq
        self.neq = neq
        self.eps = eps
        self.cancel_rate_evaluation = cancel_rate_evaluation
        self.accept_rate_evaluation = accept_rate_evaluation
        self.benchmark_thresholds = Variable(starting_thresholds)
        #self.accept_rate_original=Parameter(accept_rate*torch.ones(1))
        #self.cancel_rate_original=Parameter(cancel_rate*torch.ones(1))
        #self.cancel_rate= self.cancel_rate_original*1.0
        #self.accept_rate=self.accept_rate_original*1.0
        self.accept_rate_param = Parameter(accept_rate_target * torch.ones(1))
        self.cancel_rate_param = Parameter(cancel_rate_target * torch.ones(1))
        self.inventory_initializer = inventory_initializer
        self.parametric_knapsack = parametric_knapsack
        self.h = Variable(torch.ones(self.nineq))
        ##Add matrix to make all variables >=0
        self.PosValMatrix = -1 * Variable(
            torch.eye(self.nKnapsackCategories * self.nThresholds))
        self.PosValVector = Variable(
            torch.zeros(self.nKnapsackCategories * self.nThresholds))

        #Equality constraints. These will be the constraints to choose one variable per category
        ##These will be Variables as they are not something that is estimated by the model
        A = torch.zeros(self.nKnapsackCategories,
                        self.nKnapsackCategories * self.nThresholds)
        for row in range(self.nKnapsackCategories):
            A[row][self.nThresholds * row:self.nThresholds * (row + 1)] = 1
        self.A = Variable(A)
        self.b = Variable(torch.ones(self.nKnapsackCategories))
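        # Illustration (added note): with nKnapsackCategories=2 and nThresholds=3, the
        # equality constraints A x = b put total weight 1 on each category's thresholds:
        #   A = [[1, 1, 1, 0, 0, 0],
        #        [0, 0, 0, 1, 1, 1]],   b = [1, 1]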

        self.Q_zeros = Variable(
            torch.zeros(nKnapsackCategories * nThresholds,
                        nKnapsackCategories * nThresholds))
        #Initialize thresholds
        self.thresholds = Variable(torch.arange(0, self.nThresholds))
        #Initialize cancel and revenue parameters
        if self.parametric_knapsack:
            self.thresholds_raw_matrix = Variable(starting_thresholds)
            #self.cancel_scale = Parameter((torch.rand(self.nKnapsackCategories)+.5)*cancel_initializer)
            #self.cancel_lam = Parameter(torch.ones(self.nKnapsackCategories)*cancel_initializer)
            #self.cancel_spread = Variable(torch.ones(self.nKnapsackCategories))
            #self.revenue_scale = Parameter((torch.rand(self.nKnapsackCategories)+.5))
            #self.revenue_lam = Parameter(torch.ones(self.nKnapsackCategories))
            #self.revenue_spread = Variable(torch.ones(self.nKnapsackCategories))
        else:
            #self.thresholds_raw_matrix = Parameter(torch.ones(self.nKnapsackCategories,self.nThresholds)*(1.0/self.nThresholds))
            self.thresholds_raw_matrix = Parameter(starting_thresholds)
        self.thresholds_raw_matrix_norm = torch.div(
            self.thresholds_raw_matrix,
            torch.sum(self.thresholds_raw_matrix,
                      dim=1).unsqueeze(1).expand_as(
                          self.thresholds_raw_matrix))

        #Inventory distribution parameters
        self.inventory_lam_opt = Parameter(
            torch.ones(self.nKnapsackCategories) * inventory_initializer)
        #Cancel distribution parameters
        self.cancel_coef_opt = Parameter(
            torch.ones(self.nKnapsackCategories) * cancel_coef_initializer)
        self.cancel_intercept_opt = Parameter(
            torch.ones(self.nKnapsackCategories) *
            cancel_intercept_initializer)
        self.prices_opt = Parameter(
            torch.ones(self.nKnapsackCategories) * price_initializer)
        self.demand_distribution_opt = Parameter(
            torch.ones(self.nKnapsackCategories) *
            (1.0 / self.nKnapsackCategories))

        self.inventory_lam_est = Parameter(
            torch.ones(self.nKnapsackCategories) * inventory_initializer)
        #Cancel distribution parameters
        self.cancel_coef_est = Parameter(
            torch.ones(self.nKnapsackCategories) * cancel_coef_initializer)
        self.cancel_intercept_est = Parameter(
            torch.ones(self.nKnapsackCategories) *
            cancel_intercept_initializer)
        self.prices_est = Parameter(
            torch.ones(self.nKnapsackCategories) * price_initializer)
        self.demand_distribution_est = Parameter(
            torch.ones(self.nKnapsackCategories) *
            (1.0 / self.nKnapsackCategories))

    def normalize_thresholds(self):
        self.thresholds_raw_matrix.data.clamp_(min=self.eps,
                                               max=1.0 - self.eps)
        param_sums = torch.ger(
            self.thresholds_raw_matrix.sum(dim=1).squeeze(),
            Variable(torch.ones(self.nThresholds)))
        self.thresholds_raw_matrix.data.div_(param_sums.data)

    def normalize_demand_params(self):
        self.demand_distribution_est.data.clamp_(min=self.eps)
        self.demand_distribution_est.data.div_(
            self.demand_distribution_est.data.sum())
        self.demand_distribution_opt.data.clamp_(min=self.eps)
        self.demand_distribution_opt.data.div_(
            self.demand_distribution_opt.data.sum())

    def forward(self, category, inv_count, price, cancel,
                collection_thresholds):
        #print("collection_thresholds",collection_thresholds)
        self.lp_infeasible = 0
        self.cancel_coef_neg_est = self.cancel_coef_est.clamp(max=0)
        self.cancel_coef_neg_opt = self.cancel_coef_opt.clamp(max=0)
        self.nBatch = category.size(0)
        #x = x.view(nBatch, -1)

        #We want to compute everything we can without thresholds first. This will allow us to use our learned parameters to feed the LP
        self.inventory_distribution_raw_est = PoissonFunction(
            self.nKnapsackCategories, self.nThresholds, verbose=-1)(
                self.inventory_lam_est, self.thresholds) + self.eps
        #self.inventory_distribution_norm_est = normalize_JK(self.inventory_distribution_raw_est,dim=1)
        self.inventory_distribution_batch_by_threshold_est = torch.mm(
            category, self.inventory_distribution_raw_est) + self.eps

        self.inventory_distribution_raw_opt = PoissonFunction(
            self.nKnapsackCategories, self.nThresholds, verbose=-1)(
                self.inventory_lam_opt, self.thresholds) + self.eps
        #self.inventory_distribution_norm_opt = normalize_JK(self.inventory_distribution_raw_opt,dim=1)
        self.inventory_distribution_batch_by_threshold_opt = torch.mm(
            category, self.inventory_distribution_raw_opt) + self.eps

        ##Here we'll calculate cancel probability by inventory
        self.belief_cancel_rate_cXt_est = cancel_rate_belief_cXt(
            self.cancel_coef_neg_est, self.cancel_intercept_est,
            self.thresholds.unsqueeze(0).expand(self.nKnapsackCategories,
                                                self.nThresholds))
        belief_fill_rate_cXt_est = 1 - self.belief_cancel_rate_cXt_est
        price_cXt_est = self.prices_est.unsqueeze(1).expand(
            self.nKnapsackCategories, self.nThresholds)

        ##Here we'll calculate cancel probability by inventory
        self.belief_cancel_rate_cXt_opt = cancel_rate_belief_cXt(
            self.cancel_coef_neg_opt, self.cancel_intercept_opt,
            self.thresholds.unsqueeze(0).expand(self.nKnapsackCategories,
                                                self.nThresholds))
        belief_fill_rate_cXt_opt = 1 - self.belief_cancel_rate_cXt_opt
        price_cXt_opt = self.prices_opt.unsqueeze(1).expand(
            self.nKnapsackCategories, self.nThresholds)

        self.belief_total_demand_cXt_est = self.inventory_distribution_raw_est * (
            self.demand_distribution_est.unsqueeze(1).expand(
                self.nKnapsackCategories, self.nThresholds))
        belief_total_demand_c_vector_est = torch.sum(
            self.belief_total_demand_cXt_est, dim=1)

        self.belief_total_demand_cXt_opt = self.inventory_distribution_raw_opt * (
            self.demand_distribution_opt.unsqueeze(1).expand(
                self.nKnapsackCategories, self.nThresholds))
        belief_total_demand_c_vector_opt = torch.sum(
            self.belief_total_demand_cXt_opt, dim=1)

        if self.parametric_knapsack:

            self.belief_total_demand_opt = torch.sum(
                self.belief_total_demand_cXt_opt)
            self.belief_total_cancels_cXt_opt = self.belief_cancel_rate_cXt_opt * self.belief_total_demand_cXt_opt
            self.belief_total_fills_cXt_opt = belief_fill_rate_cXt_opt * self.belief_total_demand_cXt_opt
            self.knapsack_cancels_matrix = torch.div(
                torch.sum(self.belief_total_cancels_cXt_opt, dim=1).expand_as(
                    self.belief_total_cancels_cXt_opt) -
                torch.cumsum(self.belief_total_cancels_cXt_opt, dim=1) +
                self.belief_total_cancels_cXt_opt,
                self.belief_total_demand_opt.expand(self.nKnapsackCategories,
                                                    self.nThresholds))
            self.knapsack_fills_matrix = torch.div(
                torch.sum(self.belief_total_fills_cXt_opt, dim=1).expand_as(
                    self.belief_total_fills_cXt_opt) -
                torch.cumsum(self.belief_total_fills_cXt_opt, dim=1) +
                self.belief_total_fills_cXt_opt,
                self.belief_total_demand_opt.expand(self.nKnapsackCategories,
                                                    self.nThresholds))
            self.knapsack_revenues_matrix = self.knapsack_fills_matrix * price_cXt_opt
            self.knapsack_cancels = self.knapsack_cancels_matrix.view(1, -1)
            self.knapsack_fills = self.knapsack_fills_matrix.view(1, -1)
            self.knapsack_revenues = self.knapsack_revenues_matrix.view(-1)
            Q = self.Q_zeros + self.eps * Variable(
                torch.eye(self.nKnapsackCategories * self.nThresholds))
            self.inequalityMatrix = torch.cat(
                (self.knapsack_cancels, -1 * self.knapsack_fills,
                 self.PosValMatrix))
            self.knapsack_cancels_RHS = torch.sum(
                self.knapsack_cancels_matrix * self.benchmark_thresholds)
            self.knapsack_fills_RHS = torch.sum(self.knapsack_fills_matrix *
                                                self.benchmark_thresholds)
            #self.inequalityVector = torch.cat((self.cancel_rate_param*self.h,-1*self.accept_rate_param*self.h,self.PosValVector))
            self.inequalityVector = torch.cat(
                (self.knapsack_cancels_RHS * self.h,
                 -1 * self.knapsack_fills_RHS * self.h, self.PosValVector))
            try:
                thresholds_raw = QPFunctionJK(verbose=1)(
                    Q, -1 * self.knapsack_revenues, self.inequalityMatrix,
                    self.inequalityVector, self.A, self.b)
                self.thresholds_raw_matrix = thresholds_raw.view(
                    self.nKnapsackCategories, -1)
                #self.accept_rate=1.0*self.accept_rate_original
                #self.cancel_rate=1.0*self.cancel_rate_original
            except AssertionError:
                print("Error solving LP, likely infeasible")
                self.lp_infeasible = 1
                #print("New Accept and Cancel Rates:",self.accept_rate,self.cancel_rate)
            self.thresholds_raw_matrix = F.relu(
                self.thresholds_raw_matrix) + self.eps
        self.thresholds_raw_matrix_norm = normalize_JK(
            self.thresholds_raw_matrix, dim=1)
        #This cXt matrix shows the probability of accepting an order under the learned thresholds, obtained either through direct optimization or through solving an LP
        accept_probability_cXt = torch.cumsum(
            self.thresholds_raw_matrix_norm, dim=1
        )  #this gives the accept probability by cXt under parameterized thresholds

        #category is BxC matrix, so summing across dim 0 gets the number of accepted orders per category
        accept_probability_collection_bXt = torch.cumsum(collection_thresholds,
                                                         dim=1)
        reject_probability_collection_bXt = 1 - accept_probability_collection_bXt
        accept_percent_collection_bXt = accept_probability_collection_bXt * self.inventory_distribution_batch_by_threshold_est
        accept_percent_collection_b_vector = torch.sum(
            accept_percent_collection_bXt, dim=1
        ).squeeze(
        )  #This is the believed acceptance rate of general orders of the categories corresponding with the batch under the collection thresholds
        reject_percent_collection_b_vector = 1 - accept_percent_collection_b_vector
        self.batch_total_demand_b_vector = (
            1 / accept_percent_collection_b_vector)  #.clamp(min=0,max=100)

        #new to v37
        reject_percent_collection_expanded_bXt = reject_percent_collection_b_vector.unsqueeze(
            1).expand(self.nBatch, self.nThresholds)
        self.truncated_orders_distribution_bXt = torch.div(
            reject_probability_collection_bXt *
            self.inventory_distribution_batch_by_threshold_est,
            reject_percent_collection_expanded_bXt + self.eps)
        truncated_demand_b_vector = self.batch_total_demand_b_vector - 1  #self.belief_total_demand_cXt
        truncated_demand_bXt = truncated_demand_b_vector.unsqueeze(1).expand(
            self.nBatch,
            self.nThresholds) * self.truncated_orders_distribution_bXt
        batch_total_demand_bXt = truncated_demand_bXt + inv_count
        self.batch_total_demand_cXt = torch.mm(category.t(),
                                               batch_total_demand_bXt)
        batch_total_demand_c_vector = torch.sum(self.batch_total_demand_cXt,
                                                dim=1)
        batch_zero_demand_c_vector = 1 - batch_total_demand_c_vector.ge(0)
        #batch_supplement_demand = torch.masked_select(belief_total_demand_c_vector_est,batch_zero_demand_c_vector)
        self.estimated_batch_total_demand = torch.sum(
            self.batch_total_demand_b_vector
        )  #+torch.sum(batch_supplement_demand)

        #Now we want to see how accurate our inventory distributions are for the batch
        accept_probability_batch_by_threshold = CumSumNoGrad(
            verbose=-1)(collection_thresholds) + self.eps
        self.inventory_distribution_batch_by_thresholds = torch.mm(
            category, self.inventory_distribution_raw_est)
        arrival_probability_batch_by_threshold_unnormed = self.inventory_distribution_batch_by_thresholds * accept_probability_batch_by_threshold
        arrival_probability_batch_by_threshold = torch.div(
            arrival_probability_batch_by_threshold_unnormed,
            torch.sum(arrival_probability_batch_by_threshold_unnormed,
                      dim=1).unsqueeze(1).expand_as(
                          arrival_probability_batch_by_threshold_unnormed))
        log_arrival_prob = torch.log(arrival_probability_batch_by_threshold +
                                     self.eps)

        #Like we do for inventory, we want to measure the accuracy of our cancel params for the batch
        self.belief_cancel_rate_bXt = torch.mm(category,
                                               self.belief_cancel_rate_cXt_est)
        belief_fill_rate_bXt = 1 - self.belief_cancel_rate_bXt
        self.belief_cancel_rate_b_vector = torch.sum(
            self.belief_cancel_rate_bXt * inv_count, dim=1).squeeze()
        belief_fill_rate_b_vector = 1 - self.belief_cancel_rate_b_vector
        log_cancel_prob = torch.log(
            torch.cat((belief_fill_rate_b_vector.unsqueeze(1),
                       self.belief_cancel_rate_b_vector.unsqueeze(1)), 1) +
            self.eps)

        self.belief_category_dist_bXc = self.demand_distribution_est.unsqueeze(
            0).expand(self.nBatch, self.nKnapsackCategories)
        log_category_prob = torch.log(self.belief_category_dist_bXc + self.eps)

        ##This is new in v37. We want to combine the actual results observed in the batch but add in estimated effects of truncation

        accept_probability_using_threshold_params_bXt = torch.mm(
            category, accept_probability_cXt)
        truncated_accept_estimate = truncated_demand_bXt * accept_probability_using_threshold_params_bXt  #This is the number of truncated orders we expect to accept (using param thresholds) at each inventory level corresponding to each order in the batch
        truncated_cancel_estimate = truncated_accept_estimate * self.belief_cancel_rate_bXt
        truncated_fill_estimate = truncated_accept_estimate * belief_fill_rate_bXt
        truncated_revenue_estimate = truncated_fill_estimate * (
            price.unsqueeze(1).expand(self.nBatch, self.nThresholds))
        truncated_revenue_estimate_sum = torch.sum(truncated_revenue_estimate)
        self.truncated_cancel_estimate_sum = torch.sum(
            truncated_cancel_estimate)
        truncated_fill_estimate_sum = torch.sum(truncated_fill_estimate)
        self.truncated_accept_estimate_sum = torch.sum(
            truncated_accept_estimate)

        ##This is new in v37. We want to combine the actual results observed in the batch but add in estimated effects of truncation
        fill = 1 - cancel
        batch_cancel_bXt = cancel.unsqueeze(1).expand(
            self.nBatch, self.nThresholds
        ) * inv_count * accept_probability_using_threshold_params_bXt
        batch_fill_bXt = fill.unsqueeze(1).expand(
            self.nBatch, self.nThresholds
        ) * inv_count * accept_probability_using_threshold_params_bXt
        batch_cancel_b_vector = torch.sum(batch_cancel_bXt, dim=1).squeeze()
        batch_fill_b_vector = torch.sum(batch_fill_bXt, dim=1).squeeze()
        batch_accept_b_vector = torch.sum(
            inv_count * accept_probability_using_threshold_params_bXt,
            dim=1).squeeze()
        #print("sanity check",batch_accept_b_vector, batch_fill_b_vector+batch_cancel_b_vector)
        #print("sanity check 2", torch.sum(batch_accept_b_vector), torch.sum(batch_fill_b_vector+batch_cancel_b_vector))

        batch_revenue_b_vector = price * batch_fill_b_vector
        self.batch_fill_sum = torch.sum(batch_fill_b_vector, dim=0)
        self.batch_revenue_sum = torch.sum(batch_revenue_b_vector, dim=0)
        self.batch_cancel_sum = torch.sum(batch_cancel_b_vector, dim=0)
        self.batch_accept_sum = torch.sum(batch_accept_b_vector, dim=0)

        new_objective_loss = -(1.0 / 50000) * (truncated_revenue_estimate_sum +
                                               self.batch_revenue_sum)
        new_cancel_constraint_loss = self.truncated_cancel_estimate_sum + self.batch_cancel_sum - (
            self.truncated_accept_estimate_sum +
            self.batch_accept_sum) * self.cancel_rate_evaluation
        new_accept_constraint_loss = (1.0 / 7.0) * (
            (self.truncated_accept_estimate_sum + self.batch_accept_sum) *
            self.accept_rate_evaluation - truncated_fill_estimate_sum -
            self.batch_fill_sum)
        #new_cancel_constraint_loss = truncated_cancel_estimate_sum+self.batch_cancel_sum-self.estimated_batch_total_demand*self.cancel_rate_param
        #new_accept_constraint_loss = (1.0/7.0)*(self.estimated_batch_total_demand*self.accept_rate_param-truncated_fill_estimate_sum-self.batch_fill_sum)

        observed_cancel_constraint_loss = self.batch_cancel_sum - (
            self.batch_accept_sum) * self.cancel_rate_evaluation
        observed_accept_constraint_loss = (
            1.0 / 7.0) * (self.batch_accept_sum * self.accept_rate_evaluation -
                          self.batch_fill_sum)

        return new_objective_loss, new_cancel_constraint_loss, new_accept_constraint_loss, arrival_probability_batch_by_threshold, log_arrival_prob, log_cancel_prob, log_category_prob, self.estimated_batch_total_demand, observed_cancel_constraint_loss, observed_accept_constraint_loss, self.lp_infeasible
Example #8
class FilterStripe(nn.Conv2d):  # convolution layer + FilterSkeleton (FS) layer
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels, out_channels, kernel_size, stride, kernel_size // 2, groups=1, bias=False)
        self.BrokenTarget = None
        self.FilterSkeleton = Parameter(torch.ones(self.out_channels, self.kernel_size[0], self.kernel_size[1]), requires_grad=True)  # initialize the FS layer

    def forward(self, x):  # forward() is called automatically; x: [N, channels, width, height]
        if self.BrokenTarget is not None:
            # out: [N, channels, width, height]
            out = torch.zeros(x.shape[0], self.FilterSkeleton.shape[0], int(np.ceil(x.shape[2] / self.stride[0])), int(np.ceil(x.shape[3] / self.stride[1])))  # np.ceil() rounds up to the nearest integer
            if x.is_cuda:
                out = out.cuda()
            x = F.conv2d(x, self.weight)  # convolution output
            l, h = 0, 0
            for i in range(self.BrokenTarget.shape[0]):
                for j in range(self.BrokenTarget.shape[1]):
                    h += self.FilterSkeleton[:, i, j].sum().item()  # count the filters kept at kernel position (i, j)
                    out[:, self.FilterSkeleton[:, i, j]] += self.shift(x[:, l:h], i, j)[:, :, ::self.stride[0], ::self.stride[1]]  # scatter the shifted responses into their output channels
                    l += self.FilterSkeleton[:, i, j].sum().item()
            return out  # return the output
        else:
            # unsqueeze(1) inserts a dimension at position 1
            return F.conv2d(x, self.weight * self.FilterSkeleton.unsqueeze(1), stride=self.stride, padding=self.padding, groups=self.groups)

    def prune_in(self, in_mask=None):  # in_mask: boolean mask over input channels
        # self.weight.shape: [out_channels, in_channels, k, k]
        print(self.weight.shape)
        self.weight = Parameter(self.weight[:, in_mask])  # keep only the selected input channels
        print(self.weight)
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):  # threshold: pruning threshold
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0  # output channels with at least one surviving stripe
        if out_mask.sum() == 0:
            print(out_mask.sum())
            out_mask[0] = True
        self.weight = Parameter(self.weight[out_mask])  # mask the convolution kernels
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask], requires_grad=True)  # mask the FS layer
        self.out_channels = out_mask.sum().item()  # update the number of output channels
        return out_mask  # return the mask

    def _break(self, threshold):
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))  # multiply the kernels by the FS layer
        self.FilterSkeleton = Parameter((self.FilterSkeleton.abs() > threshold), requires_grad=False)  # entries above the threshold become True
        if self.FilterSkeleton.sum() == 0:
            self.FilterSkeleton.data[0][0][0] = True
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)
        self.kernel_size = (1, 1)
        # permute() reorders the tensor dimensions.
        # print(self.FilterSkeleton.permute(1, 2, 0).reshape(-1))
        # print(self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1))
        self.weight = Parameter(self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1)[self.FilterSkeleton.permute(1, 2, 0).reshape(-1)])  # keep only the surviving 1x1 stripes
        # print(self.weight)

    def update_skeleton(self, sr, threshold):
        self.FilterSkeleton.grad.data.add_(sr * torch.sign(self.FilterSkeleton.data))  # add the (sub)gradient of the L1 penalty to the FS gradient
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)  # zero out pruned entries
        self.FilterSkeleton.grad.data.mul_(mask)  # zero out their gradients as well
        out_mask = mask.sum(dim=(1, 2)) != 0  # output channels that still have at least one surviving entry
        return out_mask

    def shift(self, x, i, j):
        return F.pad(x, (self.BrokenTarget.shape[0] // 2 - j, j - self.BrokenTarget.shape[0] // 2, self.BrokenTarget.shape[0] // 2 - i, i - self.BrokenTarget.shape[1] // 2), 'constant', 0)

    def extra_repr(self):
        s = ('{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        return s.format(**self.__dict__)
Example #9
class Frankenstein(nn.Module):
    def __init__(self,
                 x=47764,
                 h=128,
                 L=16,
                 v_t=3620,
                 W=32,
                 R=8,
                 N=512,
                 bs=1,
                 reset=True,
                 palette=False):
        super(Frankenstein, self).__init__()

        self.reset = reset
        # debugging usages
        self.last_state_dict = None
        '''PARAMETERS'''
        self.x = x
        self.h = h
        self.L = L
        self.v_t = v_t
        self.W = W
        self.R = R
        self.N = N
        self.bs = bs
        self.E_t = W * R + 3 * W + 5 * R + 3
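        # Interface size breakdown (added note), matching the parsing in forward():
        #   read keys W*R + read strengths R + write key W + write strength 1
        #   + erase vector W + write vector W + free gates R
        #   + allocation gate 1 + write gate 1 + read modes 3*R  =  W*R + 3*W + 5*R + 3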
        '''CONTROLLER'''
        self.RNN_list = nn.ModuleList()
        for _ in range(self.L):
            self.RNN_list.append(
                RNN_Unit(self.x, self.R, self.W, self.h, self.bs))
        self.hidden_previous_timestep = Parameter(torch.Tensor(
            self.bs, self.L, self.h).cuda(),
                                                  requires_grad=False)
        self.W_y = Parameter(torch.Tensor(self.L * self.h, self.v_t).cuda())
        self.W_E = Parameter(torch.Tensor(self.L * self.h, self.E_t).cuda())
        self.b_y = Parameter(torch.Tensor(self.v_t).cuda())
        self.b_E = Parameter(torch.Tensor(self.E_t).cuda())
        '''MEMORY'''
        # p, (N), should be simplex bound
        self.precedence_weighting = Parameter(torch.Tensor(self.bs,
                                                           self.N).cuda(),
                                              requires_grad=False)
        # (N,N)
        self.temporal_memory_linkage = Parameter(torch.Tensor(
            self.bs, self.N, self.N).cuda(),
                                                 requires_grad=False)
        # (N,W)
        self.memory = Parameter(torch.Tensor(self.N, self.W).cuda(),
                                requires_grad=False)
        # (N, R).
        self.last_read_weightings = Parameter(torch.Tensor(
            self.bs, self.N, self.R).cuda(),
                                              requires_grad=False)
        # u_t, (N)
        self.last_usage_vector = Parameter(torch.Tensor(self.bs,
                                                        self.N).cuda(),
                                           requires_grad=False)
        # store last write weightings for the calculation of usage vector
        self.last_write_weighting = Parameter(torch.Tensor(self.bs,
                                                           self.N).cuda(),
                                              requires_grad=False)

        self.palette = None
        if palette:
            stdv = 1.0
            self.memory.data.uniform_(-stdv, stdv)
            self.palette = palette
            self.initialz = self.memory.data

        self.first_t_flag = True
        '''COMPUTER'''
        self.last_read_vector = Parameter(torch.Tensor(self.bs, self.W,
                                                       self.R).cuda(),
                                          requires_grad=False)
        self.W_r = Parameter(torch.Tensor(self.W * self.R, self.v_t).cuda())

        self.reset_parameters()

    def reset_parameters(self):
        # if debug:
        #     print("parameters are reset")
        '''Controller'''
        for module in self.RNN_list:
            # this should iterate over RNN_Units only
            module.reset_parameters()
        self.hidden_previous_timestep.zero_()
        stdv = 1.0 / math.sqrt(self.v_t)
        self.W_y.data.uniform_(-stdv, stdv)
        self.b_y.data.uniform_(-stdv, stdv)
        stdv = 1.0 / math.sqrt(self.E_t)
        self.W_E.data.uniform_(-stdv, stdv)
        self.b_E.data.uniform_(-stdv, stdv)
        '''Memory'''
        self.precedence_weighting.zero_()
        self.last_usage_vector.zero_()
        self.last_read_weightings.zero_()
        self.last_write_weighting.zero_()
        self.temporal_memory_linkage.zero_()
        # memory must be initialized like this, otherwise usage vector will be stuck at zero.
        stdv = 1.0
        self.memory.data.uniform_(-stdv, stdv)
        self.first_t_flag = True
        '''Computer'''
        # see paper, paragraph 2 page 7
        self.last_read_vector.zero_()
        stdv = 1.0 / math.sqrt(self.v_t)
        self.W_r.data.uniform_(-stdv, stdv)

    def new_sequence_reset(self):
        '''
        The biggest question is whether to reset memory every time a new sequence is taken in.
        My take is not to reset the memory, though that may not be the best strategy.
        If memory is not reset at each new sequence, should it ever be reset at all?
        :return:
        '''
        # if debug:
        #     print('new sequence reset')
        '''controller'''
        self.hidden_previous_timestep = Parameter(torch.Tensor(
            self.bs, self.L, self.h).zero_().cuda(),
                                                  requires_grad=False)
        for RNN in self.RNN_list:
            RNN.new_sequence_reset()
        self.W_y = Parameter(self.W_y.data)
        self.b_y = Parameter(self.b_y.data)
        self.W_E = Parameter(self.W_E.data)
        self.b_E = Parameter(self.b_E.data)
        '''memory'''

        if self.reset:
            if self.palette:
                self.memory.data = self.initialz
            else:
                # we will reset the memory altogether.
                # TODO The question is, should we reset the memory to a fixed state? There are good arguments for it.
                stdv = 1.0
                # gradient should not carry over, since at this stage, requires_grad on this parameter should be False.
                self.memory.data.uniform_(-stdv, stdv)
                # TODO is there a reason to reinitialize the parameter object? I don't think so. The graph is not carried over.

            self.last_usage_vector.zero_()
            self.precedence_weighting.zero_()
            self.temporal_memory_linkage.zero_()
            self.last_read_weightings.zero_()
            self.last_write_weighting.zero_()
        # self.last_usage_vector = Parameter(torch.Tensor(self.bs, self.N).zero_().cuda(), requires_grad=False)
        # self.precedence_weighting = Parameter(torch.Tensor(self.bs, self.N).zero_().cuda(),requires_grad=False)
        # self.temporal_memory_linkage = Parameter(torch.Tensor(self.bs, self.N, self.N).zero_().cuda(),requires_grad=False)
        # # with a new sequence, the calculation of forward weighting, for example, still requires the last_read_weighting
        # self.last_read_weightings = Parameter(torch.Tensor(self.bs, self.N, self.R).zero_().cuda(),requires_grad=False)
        # self.last_write_weighting = Parameter(torch.Tensor(self.bs, self.N).zero_().cuda(),requires_grad=False)
        self.first_t_flag = True
        '''computer'''
        self.last_read_vector = Parameter(torch.Tensor(self.bs, self.W,
                                                       self.R).zero_().cuda(),
                                          requires_grad=False)
        self.W_r = Parameter(self.W_r.data)

    def forward(self, input):
        if (input != input).any():
            raise ValueError("We have NAN in inputs")
        input_x_t = torch.cat((input, self.last_read_vector.view(self.bs, -1)),
                              dim=1)
        '''Controller'''
        hidden_previous_layer = Variable(
            torch.Tensor(self.bs, self.h).zero_().cuda())
        hidden_this_timestep = Variable(
            torch.Tensor(self.bs, self.L, self.h).cuda())
        for i in range(self.L):
            hidden_output = self.RNN_list[i](
                input_x_t, self.hidden_previous_timestep[:, i, :],
                hidden_previous_layer)
            if (hidden_output != hidden_output).any():
                raise ValueError("We have NAN in controller output.")
            hidden_this_timestep[:, i, :] = hidden_output
            hidden_previous_layer = hidden_output

        flat_hidden = hidden_this_timestep.view((self.bs, self.L * self.h))
        output = torch.matmul(flat_hidden, self.W_y)
        interface_input = torch.matmul(flat_hidden, self.W_E)
        # this detaches hidden from previous hidden.
        self.hidden_previous_timestep = Parameter(hidden_this_timestep.data,
                                                  requires_grad=False)
        '''interface'''
        last_index = self.W * self.R

        # Read keys, each W dimensions, [W*R] in total
        # no further processing needed
        # these are the keys used for content-based addressing, not the stored contents
        read_keys = interface_input[:, 0:last_index].contiguous().view(
            self.bs, self.W, self.R)

        # Read strengths, [R]
        # mapped to [1, infinity)
        # slightly different from the paper's oneplus(x) = 1 + softplus(x):
        # 1 - logsigmoid(x) = 1 + softplus(-x), which covers the same range with a negated argument
        read_strengths = interface_input[:, last_index:last_index + self.R]
        last_index = last_index + self.R
        read_strengths = 1 - nn.functional.logsigmoid(read_strengths)

        # Write key, [W]
        write_key = interface_input[:, last_index:last_index + self.W]
        last_index = last_index + self.W

        # write strength beta, [1]
        write_strength = interface_input[:, last_index:last_index + 1]
        last_index = last_index + 1
        write_strength = 1 - nn.functional.logsigmoid(write_strength)

        # erase vector, [W], squashed to [0,1]
        erase_vector = interface_input[:, last_index:last_index + self.W]
        last_index = last_index + self.W
        erase_vector = torch.sigmoid(erase_vector)

        # write vector, [W]
        write_vector = interface_input[:, last_index:last_index + self.W]
        last_index = last_index + self.W

        # free gates, one per read head, [R]
        free_gates = interface_input[:, last_index:last_index + self.R]

        last_index = last_index + self.R
        free_gates = torch.sigmoid(free_gates)

        # allocation gate [1]
        allocation_gate = interface_input[:, last_index:last_index + 1]
        last_index = last_index + 1
        allocation_gate = torch.sigmoid(allocation_gate)

        # write gate [1]
        write_gate = interface_input[:, last_index:last_index + 1]
        last_index = last_index + 1
        write_gate = torch.sigmoid(write_gate)

        # read modes [R,3]
        read_modes = interface_input[:, last_index:last_index + self.R * 3]
        read_modes = read_modes.contiguous().view(self.bs, self.R, 3)
        read_modes = nn.functional.softmax(read_modes, dim=2)
        '''memory'''
        memory_retention = self.memory_retention(free_gates)
        # usage vector update must be called before allocation weighting.
        self.update_usage_vector(memory_retention)
        allocation_weighting = self.allocation_weighting()

        write_weighting = self.write_weighting(write_key, write_strength,
                                               allocation_gate, write_gate,
                                               allocation_weighting)
        self.write_to_memory(write_weighting, erase_vector, write_vector)

        # update the temporal linkage matrix and the precedence weighting
        self.update_temporal_linkage_matrix(write_weighting)
        self.update_precedence_weighting(write_weighting)

        forward_weighting = self.forward_weighting()
        backward_weighting = self.backward_weighting()

        read_weightings = self.read_weightings(forward_weighting,
                                               backward_weighting, read_keys,
                                               read_strengths, read_modes)
        # read from memory last, a new modification.
        read_vector = Parameter(self.read_memory(read_weightings).data,
                                requires_grad=False)
        # DEBUG NAN
        if (read_vector != read_vector).any():
            # this is a problem! TODO
            raise ValueError("nan is found.")
        '''back to computer'''
        output2 = output + torch.matmul(
            read_vector.view(self.bs, self.W * self.R), self.W_r)

        # update the last weightings
        self.last_read_vector = read_vector
        self.last_read_weightings = Parameter(read_weightings.data,
                                              requires_grad=False)
        self.last_write_weighting = Parameter(write_weighting.data,
                                              requires_grad=False)

        self.first_t_flag = False

        if debug:
            test_simplex_bound(self.last_read_weightings)
            test_simplex_bound(self.last_write_weighting)
            if (output2 != output2).any():
                raise ValueError("nan is found.")

        return output2
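
    # Illustrative sketch (added for exposition, not part of the original class):
    # the interface vector sliced above has a fixed width determined by W and R.
    # Summing the slices in order (read keys W*R, read strengths R, write key W,
    # write strength 1, erase vector W, write vector W, free gates R,
    # allocation gate 1, write gate 1, read modes 3R) gives W*R + 3*W + 5*R + 3.
    # The helper name `_interface_size` is an assumption made here.
    @staticmethod
    def _interface_size(W, R):
        # total width of the interface vector emitted by the controller
        return W * R + 3 * W + 5 * R + 3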

    def write_content_weighting(self, write_key, key_strength, eps=1e-8):
        '''
        Content-based addressing for the write head, using self.memory (N, W).

        :param write_key: k^w_t, (bs, W), real-valued, the desired content
        :param key_strength: \beta^w_t, (bs, 1), in [1, \infty)
        :return: content weighting C(M, k, \beta), (bs, N), values in (0,1), simplex bounded per row
        '''

        # memory is (N, W)
        # write_key is (bs, W)
        # expected return: (bs, N), the similarity of each key with each memory location

        # (self.bs, self.N)
        innerprod = torch.matmul(write_key, self.memory.t())
        # (self.N,)
        memnorm = torch.norm(self.memory, 2, 1)
        # (self.bs)
        writenorm = torch.norm(write_key, 2, 1)
        # (self.N, self.bs)
        normalizer = torch.ger(memnorm, writenorm)
        similarities = innerprod / normalizer.t().clamp(min=eps)
        similarities = similarities * key_strength.expand(-1, self.N)
        normalized = softmax(similarities, dim=1)
        if debug:
            if (normalized != normalized).any():
                task_dir = os.path.dirname(abspath(__file__))
                save_dir = Path(task_dir) / "saves" / "keykey.pkl"
                pickle.dump((write_key.cpu(), key_strength.cpu()),
                            save_dir.open('wb'))
                raise ValueError("NA found in write content weighting")
        return normalized
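
    # Illustrative sketch (not part of the original class): the same write
    # content weighting expressed with nn.functional.cosine_similarity,
    # assuming the shapes used above (memory (N, W), key (bs, W),
    # strength (bs, 1)). The name `_write_content_weighting_sketch` is
    # hypothetical.
    @staticmethod
    def _write_content_weighting_sketch(memory, key, strength, eps=1e-8):
        bs, N = key.size(0), memory.size(0)
        cos = nn.functional.cosine_similarity(
            memory.unsqueeze(0).expand(bs, -1, -1),  # (bs, N, W)
            key.unsqueeze(1).expand(-1, N, -1),      # (bs, N, W)
            dim=2, eps=eps)                          # -> (bs, N)
        # sharpen by the key strength, then normalize over memory locations
        return nn.functional.softmax(cos * strength, dim=1)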

    def read_content_weighting(self, read_keys, key_strengths, eps=1e-8):
        '''
        Content-based addressing for the read heads, using self.memory (N, W).

        :param read_keys: k^r_t, (bs, W, R), real-valued, the desired contents
        :param key_strengths: \beta^r_t, (bs, R), in [1, \infty)
        :return: content weighting C(M, k, \beta), (bs, N, R), values in (0,1), simplex bounded over N
        '''
        '''
            torch definition
            def cosine_similarity(x1, x2, dim=1, eps=1e-8):
                w12 = torch.sum(x1 * x2, dim)
                w1 = torch.norm(x1, 2, dim)
                w2 = torch.norm(x2, 2, dim)
                return w12 / (w1 * w2).clamp(min=eps)
        '''

        innerprod = torch.matmul(self.memory.unsqueeze(0), read_keys)
        # note: memory[n] indexes the n-th row (one word vector per row),
        # while read_keys stores one key per column, which is easy to confuse
        mem_norm = torch.norm(self.memory, p=2, dim=1)
        read_norm = torch.norm(read_keys, p=2, dim=1)
        mem_norm = mem_norm.unsqueeze(1)
        read_norm = read_norm.unsqueeze(1)
        # (batch_size, locations, read_heads)
        normalizer = torch.matmul(mem_norm, read_norm)

        # if transposed, similarities[0] would refer to the first read key
        similarities = innerprod / normalizer.clamp(min=eps)
        weighted = similarities * key_strengths.unsqueeze(1).expand(
            -1, self.N, -1)
        ret = softmax(weighted, dim=1)
        return ret
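
    # Illustrative sketch (not part of the original class): the read-head
    # version of the cosine weighting, one key per read head, assuming
    # memory (N, W), read_keys (bs, W, R) and key_strengths (bs, R).
    # The name `_read_content_weighting_sketch` is hypothetical.
    @staticmethod
    def _read_content_weighting_sketch(memory, read_keys, key_strengths, eps=1e-8):
        bs, W, R = read_keys.size(0), read_keys.size(1), read_keys.size(2)
        N = memory.size(0)
        cos = nn.functional.cosine_similarity(
            memory.unsqueeze(0).unsqueeze(3).expand(bs, N, W, R),
            read_keys.unsqueeze(1).expand(bs, N, W, R),
            dim=2, eps=eps)                                    # (bs, N, R)
        # sharpen per head, then normalize over the N locations
        return nn.functional.softmax(cos * key_strengths.unsqueeze(1), dim=1)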

    # memory retention: how much of each location survives after the free gates release recently read locations
    def memory_retention(self, free_gate):
        '''

        :param free_gate: f, (R), [0,1], from interface vector
        :param read_weighting: w^r_t, (N, R), simplex bounded,
               note it's from previous timestep.
        :return: \psi, (N), [0,1]
        '''

        # a free gate belongs to a read head.
        # a single read head weighting is a (N) dimensional simplex bounded value

        # (N, R)
        inside_bracket = 1 - self.last_read_weightings * free_gate.unsqueeze(
            1).expand(-1, self.N, -1)
        ret = torch.prod(inside_bracket, 2)
        return ret
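
    # Illustrative sketch (not part of the original class): the retention
    # vector written with broadcasting instead of expand,
    # psi[b, n] = prod_i (1 - f[b, i] * w_r[b, n, i]).
    # The name `_memory_retention_sketch` is hypothetical.
    @staticmethod
    def _memory_retention_sketch(free_gate, last_read_weightings):
        # free_gate: (bs, R); last_read_weightings: (bs, N, R) -> psi: (bs, N)
        return torch.prod(1 - free_gate.unsqueeze(1) * last_read_weightings, dim=2)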

    def update_usage_vector(self, memory_retention):
        '''

        :param memory_retention: \psi_t, (N), simplex bound
        :return: u_t, (N), [0,1], the next usage
        '''
        if self.first_t_flag:
            ret = Parameter(torch.Tensor(self.bs, self.N).zero_().cuda(),
                            requires_grad=False)
            return ret
        ret = (self.last_usage_vector + self.last_write_weighting - self.last_usage_vector * self.last_write_weighting) \
              * memory_retention

        # Should we assign to .data instead, e.g. self.last_usage_vector.data = ret.data?
        # Otherwise the usage vector would drag along its whole computation history,
        # which is probably unnecessary; it is unclear whether the write weighting
        # should be backpropagated through here.
        # The usage vector is reset for every sequence, but should it also be reset every timestep?
        self.last_usage_vector = Parameter(ret.data, requires_grad=False)
        return ret

    def allocation_weighting(self):
        '''
        Sorts the memory locations by usage, then computes the allocation
        weighting in that sorted order.

        For example, the allocation weighting of the third least used location
        is the product of the usages of the least and second least used
        locations, multiplied by (1 - usage of the third least used location).

        Do not confuse the sort order with the memory's natural location order.
        Verify backprop.

        :param usage_vector: u_t, (N), [0,1]
        :return: allocation_weighting: a_t, (N), simplex bound
        '''

        # despite the name, last_usage_vector already holds this timestep's usage, since update_usage_vector() runs before this
        sorted, indices = self.last_usage_vector.sort(dim=1)
        cum_prod = torch.cumprod(sorted, 1)
        # notice the index on the product
        cum_prod = torch.cat(
            [Variable(torch.ones(self.bs, 1).cuda()), cum_prod], 1)[:, :-1]
        sorted_inv = 1 - sorted
        allocation_weighting = sorted_inv * cum_prod
        # shuffle back to the memory's natural order by inverting the sort permutation
        # (gathering with `indices` directly would apply the permutation a second time instead of undoing it)
        _, inverse_indices = indices.sort(dim=1)
        ret = torch.gather(allocation_weighting, 1, inverse_indices)
        if debug:
            if (ret != ret).any():
                raise ValueError("NA found in allocation weighting")
        return ret
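
    # Illustrative sketch (not part of the original class): a loop-based
    # restatement of the allocation weighting, following the paper's
    # definition a[phi[j]] = (1 - u[phi[j]]) * prod_{i<j} u[phi[i]], where
    # phi sorts locations by ascending usage. Slower than the cumprod
    # version above but easier to check. The name is hypothetical.
    @staticmethod
    def _allocation_weighting_reference(usage):
        # usage: (bs, N), values in [0, 1]
        sorted_u, indices = usage.sort(dim=1)
        cols = []
        running = torch.ones_like(sorted_u[:, 0])
        for j in range(usage.size(1)):
            cols.append((1 - sorted_u[:, j]) * running)
            running = running * sorted_u[:, j]
        alloc_sorted = torch.stack(cols, dim=1)
        # map back from sorted order to the memory's natural order
        _, inverse_indices = indices.sort(dim=1)
        return torch.gather(alloc_sorted, 1, inverse_indices)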

    def write_weighting(self, write_key, write_strength, allocation_gate,
                        write_gate, allocation_weighting):
        '''
        calculates the weighting on each memory cell when writing a new value in

        :param memory: M, (N, W), memory block
        :param write_key: k^w_t, (W), R, the key that is to be written
        :param write_strength: \beta^w_t, (1), the larger it is, the more concentrated the content weighting becomes
        :param allocation_gate: g^a_t, (1), interpolates between writing by allocation and writing by content
        :param write_gate: g^w_t, (1), overall strength of the write signal
        :param allocation_weighting: see above.
        :return: write_weighting: (N), simplex bound
        '''
        # measures content similarity
        content_weighting = self.write_content_weighting(
            write_key, write_strength)
        write_weighting = write_gate * (
            allocation_gate * allocation_weighting +
            (1 - allocation_gate) * content_weighting)
        if debug:
            test_simplex_bound(write_weighting, 1)
        return write_weighting

    def update_precedence_weighting(self, write_weighting):
        '''

        :param write_weighting: (N)
        :return: self.precedence_weighting: (N), simplex bound
        '''
        # this is the bug. I called the python default sum() instead of torch.sum()
        # Took me 3 hours.
        # sum_ww=sum(write_weighting,1)
        sum_ww = torch.sum(write_weighting, dim=1)
        self.precedence_weighting = Parameter(
            ((1 - sum_ww).unsqueeze(1) * self.precedence_weighting +
             write_weighting).data,
            requires_grad=False)
        if debug:
            test_simplex_bound(self.precedence_weighting, 1)
        return self.precedence_weighting

    def update_temporal_linkage_matrix(self, write_weighting):
        '''

        :param write_weighting: (N)
        :param precedence_weighting: (N), simplex bound
        :return: updated_temporal_linkage_matrix
        '''

        # TODO We need to mathematically understand why this function will
        # TODO maintain the simplex bound condition.
        if self.first_t_flag:
            return self.temporal_memory_linkage
        else:
            ww_j = write_weighting.unsqueeze(1).expand(-1, self.N, -1)
            ww_i = write_weighting.unsqueeze(2).expand(-1, -1, self.N)
            p_j = self.precedence_weighting.unsqueeze(1).expand(-1, self.N, -1)
            batch_temporal_memory_linkage = self.temporal_memory_linkage.expand(
                self.bs, -1, -1)
            newtml = Parameter(
                ((1 - ww_j - ww_i) * batch_temporal_memory_linkage +
                 ww_i * p_j).data,
                requires_grad=False)
            is_cuda = ww_j.is_cuda
            if is_cuda:
                # index used below to zero out the diagonal of the linkage matrix (a location never links to itself)
                idx = torch.arange(0, self.N, out=torch.cuda.LongTensor())
            else:
                idx = torch.arange(0, self.N, out=torch.LongTensor())
            newtml[:, idx, idx] = 0
            if debug:
                try:
                    test_simplex_bound(newtml, 1)
                    test_simplex_bound(newtml.transpose(1, 2), 1)
                except ValueError:
                    traceback.print_exc()
                    print("precedence close to one?",
                          self.precedence_weighting.sum() > 1)
                    raise
            self.temporal_memory_linkage = Parameter(newtml.data,
                                                     requires_grad=False)
            return self.temporal_memory_linkage
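
    # Illustrative sketch (not part of the original class): the linkage update
    # written with broadcasting instead of the explicit expand calls above,
    # L_t[i, j] = (1 - w[i] - w[j]) * L_{t-1}[i, j] + w[i] * p_{t-1}[j],
    # with the diagonal forced to zero. The name is hypothetical.
    @staticmethod
    def _linkage_update_sketch(prev_linkage, write_weighting, precedence):
        # prev_linkage: (bs, N, N); write_weighting, precedence: (bs, N)
        w_i = write_weighting.unsqueeze(2)   # (bs, N, 1), varies over i
        w_j = write_weighting.unsqueeze(1)   # (bs, 1, N), varies over j
        p_j = precedence.unsqueeze(1)        # (bs, 1, N)
        new_linkage = (1 - w_i - w_j) * prev_linkage + w_i * p_j
        # a location is never temporally linked to itself
        eye = torch.eye(new_linkage.size(1), device=new_linkage.device)
        return new_linkage * (1 - eye)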

    def backward_weighting(self):
        '''
        :return: backward_weighting: b^i_t, (N,R)
        '''
        ret = torch.matmul(self.temporal_memory_linkage,
                           self.last_read_weightings)
        if debug:
            test_simplex_bound(ret, 1)
        return ret

    def forward_weighting(self):
        '''

        :return: forward_weighting: f^i_t, (N,R)
        '''
        ret = torch.matmul(self.temporal_memory_linkage.transpose(1, 2),
                           self.last_read_weightings)
        if debug:
            test_simplex_bound(ret, 1)
        return ret

    # TODO sparse update, skipped because it's for performance improvement.

    def read_weightings(self, forward_weighting, backward_weighting, read_keys,
                        read_strengths, read_modes):
        '''

        :param forward_weighting: (bs,N,R)
        :param backward_weighting: (bs,N,R)
        ****** content_weighting: C, (bs,N,R), (0,1)
        :param read_keys: k^w_t, (bs,W,R)
        :param read_key_strengths: (bs,R)
        :param read_modes: /pi_t^i, (bs,R,3)
        :return: read_weightings: w^r_t, (bs,N,R)

        '''

        content_weighting = self.read_content_weighting(
            read_keys, read_strengths)
        if debug:
            test_simplex_bound(content_weighting, 1)
            test_simplex_bound(backward_weighting, 1)
            test_simplex_bound(forward_weighting, 1)
        # has dimension (bs,3,N,R)
        all_weightings = torch.stack(
            [backward_weighting, content_weighting, forward_weighting], dim=1)
        # permute to dimension (bs,R,N,3)
        all_weightings = all_weightings.permute(0, 3, 2, 1)
        # this is because torch.matmul batches over all dimensions except the last two
        # dimension (bs,R,3,1)
        read_modes = read_modes.unsqueeze(3)
        # dimension (bs,N,R)
        read_weightings = torch.matmul(all_weightings,
                                       read_modes).squeeze(3).transpose(1, 2)
        # last read weightings
        if debug:
            # if the second test passes, how come the first one does not?
            test_simplex_bound(self.last_read_weightings, 1)
            test_simplex_bound(read_weightings, 1)
            if (read_weightings != read_weightings).any():
                raise ValueError("NAN is found")
        return read_weightings
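
    # Illustrative sketch (not part of the original class): the read-mode
    # interpolation written as an explicit weighted sum instead of the
    # permute/matmul trick above,
    # w^r_i = pi_i[1] * b_i + pi_i[2] * c_i + pi_i[3] * f_i.
    # The name is hypothetical.
    @staticmethod
    def _read_weighting_sketch(backward_w, content_w, forward_w, read_modes):
        # backward_w, content_w, forward_w: (bs, N, R); read_modes: (bs, R, 3)
        pi = read_modes.unsqueeze(1)         # (bs, 1, R, 3)
        return (pi[..., 0] * backward_w
                + pi[..., 1] * content_w
                + pi[..., 2] * forward_w)    # (bs, N, R)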

    def read_memory(self, read_weightings):
        '''

        memory: (N, W)
        read_weightings: (bs, N, R)

        :return: read_vectors: [r^i_t], (bs, W, R)
        '''

        return torch.matmul(self.memory.t(), read_weightings)

    def write_to_memory(self, write_weighting, erase_vector, write_vector):
        '''

        :param write_weighting: w^w_t, (bs, N), the strength of writing to each location
        :param erase_vector: e_t, (bs, W), [0,1]
        :param write_vector: v_t, (bs, W), the content to be written
        :return:
        '''
        term1_2 = torch.matmul(write_weighting.unsqueeze(2),
                               erase_vector.unsqueeze(1))
        # term1=self.memory.unsqueeze(0)*Variable(torch.ones((self.bs,self.N,self.W)).cuda()-term1_2.data)
        term1 = self.memory.unsqueeze(0) * (1 - term1_2)
        term2 = torch.matmul(write_weighting.unsqueeze(2),
                             write_vector.unsqueeze(1))
        self.memory = Parameter(torch.mean(term1 + term2, dim=0).data,
                                requires_grad=False)
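
    # Illustrative sketch (not part of the original class): the memory update
    # from the paper, M_t = M_{t-1} * (E - w e^T) + w v^T (elementwise *),
    # kept per batch element. The method above additionally averages the
    # result over the batch so a single memory matrix is shared; this sketch
    # leaves that choice out. The name is hypothetical.
    @staticmethod
    def _write_to_memory_sketch(memory, write_weighting, erase_vector, write_vector):
        # memory: (N, W) or (bs, N, W); write_weighting: (bs, N);
        # erase_vector, write_vector: (bs, W)
        erase = torch.matmul(write_weighting.unsqueeze(2), erase_vector.unsqueeze(1))  # (bs, N, W)
        add = torch.matmul(write_weighting.unsqueeze(2), write_vector.unsqueeze(1))    # (bs, N, W)
        return memory * (1 - erase) + add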