def get_mnist_cifar_dl(mnist_classes=(0,1), cifar_classes=None, c=None, bs=256,
                       randomize_mnist=False, randomize_cifar=False, te_bs=100):
    """Build train/test DataLoaders from ``get_mnist_cifar``.

    Args:
        mnist_classes: forwarded to ``get_mnist_cifar``.
        cifar_classes: forwarded to ``get_mnist_cifar``.
        c: forwarded to ``get_mnist_cifar``; ``None`` (default) means a fresh
            ``{0, 1, 2, 3, 4}`` on every call.
        bs: batch size of the training loader.
        randomize_mnist: forwarded to ``get_mnist_cifar``.
        randomize_cifar: forwarded to ``get_mnist_cifar``.
        te_bs: batch size of the test loader (formerly hard-coded to 100).

    Returns:
        ``(tr_dl, te_dl)``: a shuffled training loader and an unshuffled test
        loader, both built with ``utils._to_dl``.
    """
    if c is None:
        # Build the default per call: a set literal in the signature is a
        # single mutable object shared by every call of this function.
        c = {0, 1, 2, 3, 4}
    (Xtr, Ytr), (Xte, Yte) = get_mnist_cifar(mnist_classes=mnist_classes, cifar_classes=cifar_classes,
                                             c=c, randomize_mnist=randomize_mnist, randomize_cifar=randomize_cifar)
    tr_dl = utils._to_dl(Xtr, Ytr, bs=bs, shuffle=True)
    te_dl = utils._to_dl(Xte, Yte, bs=te_bs, shuffle=False)
    return tr_dl, te_dl
    def evaluate_attack(self, dl, model):
        """Attack every batch of ``dl`` with ``self.perturb`` and score ``model``.

        Args:
            dl: loader yielding ``(xb, yb)`` batches; its ``batch_size`` is
                reused for the loader built over the perturbed data.
            model: classifier, moved to ``self.device`` before use.

        Returns:
            dict with:
              'acc', 'loss': Python scalars from
                  ``utils.compute_loss_and_accuracy_from_dl`` (cross-entropy)
                  evaluated on the perturbed data;
              'ta_dl': loader over (perturbed inputs, true labels);
              'Xa': perturbed inputs ``xb + delta``;
              'Ya': true labels;
              'Yh': argmax predictions of ``model`` on the perturbed inputs;
              'P': the perturbations ``delta``.
            All arrays are numpy, with batches concatenated along dim 0.

        Note:
            ``torch.cat`` raises if ``dl`` yields no batches.
        """
        model = model.to(self.device)
        Xa, Ya, Yh, P = [], [], [], []

        for xb, yb in dl:
            xb, yb = xb.to(self.device), yb.to(self.device)
            # Detach so that no autograd graph built inside perturb() survives
            # across batches, and so the .numpy() calls below cannot fail on a
            # tensor that requires grad. This is a no-op if perturb() already
            # returns a detached tensor.
            delta = self.perturb(xb, yb, model).detach()
            xba = xb + delta

            with torch.no_grad():
                yh = torch.argmax(model(xba), dim=1)

            Ya.append(yb.cpu())
            Yh.append(yh.cpu())
            Xa.append(xba.cpu())
            P.append(delta.cpu())

        Xa, Ya, Yh, P = map(torch.cat, [Xa, Ya, Yh, P])
        ta_dl = utils._to_dl(Xa, Ya, dl.batch_size)
        acc, loss = utils.compute_loss_and_accuracy_from_dl(ta_dl,
                                                            model,
                                                            F.cross_entropy,
                                                            device=self.device)
        return {
            'acc': acc.item(),
            'loss': loss.item(),
            'ta_dl': ta_dl,
            'Xa': Xa.numpy(),
            'Ya': Ya.numpy(),
            'Yh': Yh.numpy(),
            'P': P.numpy()
        }
    def _eval(self, dl, model, eps, pos_dir):
        """Perturb every batch of ``dl`` with ``self._perturb`` and score ``model``.

        ``eps`` and ``pos_dir`` are passed unchanged to ``self._perturb``,
        which is expected to return ``(perturbed_batch, perturbation)``.

        Returns a dict whose tensor entries are CPU tensors with batches
        concatenated along dim 0:
          'P' perturbations, 'X' clean inputs, 'Xa' perturbed inputs,
          'Ya' labels; plus 'acc'/'loss' (Python scalars from
          ``utils.compute_loss_and_accuracy_from_dl`` with cross-entropy on
          the perturbed data), 'dl' the loader over (Xa, Ya), and 'pos_dir'
          echoed back.
        """
        model = model.to(self.device)
        clean, adv, labels, perts = [], [], [], []

        for inputs, targets in dl:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            adv_inputs, pert = self._perturb(inputs, targets, eps, pos_dir)
            # Stash everything on the CPU so device memory does not grow.
            for bucket, tensor in ((clean, inputs), (adv, adv_inputs),
                                   (labels, targets), (perts, pert)):
                bucket.append(tensor.cpu())

        X = torch.cat(clean)
        Xa = torch.cat(adv)
        Ya = torch.cat(labels)
        P = torch.cat(perts)

        adv_dl = utils._to_dl(Xa, Ya, dl.batch_size)
        acc, loss = utils.compute_loss_and_accuracy_from_dl(
            adv_dl, model, F.cross_entropy, device=self.device)

        return {
            'P': P,
            'X': X,
            'Xa': Xa,
            'Ya': Ya,
            'acc': acc.item(),
            'loss': loss.item(),
            'dl': adv_dl,
            'pos_dir': pos_dir
        }
def get_randomized_loader(dl, W, coordinates):
    """
    Build a loader whose selected coordinates, measured in the basis given by
    ``W``, are shuffled across examples.

    dl: dataloader whose ``dataset.tensors`` is exactly an (inputs, labels)
        pair. Inputs are indexed as (n_examples, dim); this assumes 2-D
        features (TODO confirm for image-shaped data).
    W: numpy matrix with ``dim`` columns (typically a square rotation).
        Inputs are mapped into the new basis with ``X @ W.T`` and back with
        ``@ W``, which undoes the change of basis only if W is orthogonal.
        ``None`` means the identity.
    coordinates: column indices (in the rotated basis) to shuffle. All of
        them share one row permutation, so their joint distribution is kept
        while their pairing with the other columns and the labels is broken.
    output: new dataloader (same batch size as ``dl``) over the shuffled
        inputs, converted via ``torch.Tensor``, and a copy of the labels.
    """
    inputs, labels = (copy.deepcopy(t) for t in dl.dataset.tensors)
    dim = inputs.shape[1]
    basis = np.eye(dim) if W is None else W

    # Into the rotated frame.
    rotated = torch.Tensor(inputs.numpy().dot(basis.T))

    # A single permutation of the rows, applied to every chosen column.
    perm = torch.randperm(len(rotated))
    for col in coordinates:
        rotated[:, col] = rotated[perm, col]

    # Back to the original frame.
    restored = torch.Tensor(rotated.numpy().dot(basis))
    return utils._to_dl(restored, labels, dl.batch_size)
def get_binary_loader(dl, y1, y2):
    """Return a loader restricted by ``get_binary_datasets`` to labels y1/y2.

    The data is pulled out of ``dl`` as numpy via
    ``utils.extract_numpy_from_loader``, filtered/relabelled by
    ``get_binary_datasets(X, Y, y1, y2)``, and re-wrapped with
    ``utils._to_dl`` using ``dl``'s batch size.
    """
    all_X, all_Y = utils.extract_numpy_from_loader(dl)
    bin_X, bin_Y = get_binary_datasets(all_X, all_Y, y1, y2)
    batch_size = dl.batch_size
    return utils._to_dl(bin_X, bin_Y, bs=batch_size)