def model(batch_size, lr, dims, numEpochs, cuda, alpha, pathLoad, pathSave, epochSave, activation, modelType,
             computeEigVectorsOnline, regularizerFcn, _seed, _run):
    """
    Function for creating and training MLPs on MNIST.
    :param batch_size: specifies batch size
    :param rlr: learning rate of stochastic optimizer
    :param dims: A list of N tuples that specifies the input and output sizes for the FC layers. where the last layer is the output layer
    :param numEpochs: number of epochs to train the network for
    :param cuda: boolean variable that will specify whether to use the GPU or nt
    :param alpha: weight for regularizer on spectra. If 0, the regularizer will not be used
    :param pathLoad: path to where MNIST lives
    :param pathSave: path specifying where to save the models
    :param epochSave: integer specifying how often to save loss
    :param activation: string that specified whether to use relu or not
    :param _seed: seed for RNG
    :param _run: Sacred object that logs the relevant data and stores them to a database

    :param computeEigVectorsOnline: online or offline eig estimator
    :param regularizerFcn: function name that computes the discrepancy between the idealized and empirical eigs
    """
    device = 'cuda' if cuda else 'cpu'
    os.makedirs(pathSave, exist_ok=True)
    npr.seed(_seed)
    torch.manual_seed(_seed + 1)
    alpha = alpha * torch.ones(1, device=device)

    "Load in MNIST"
    fracVal = 0.1
    train, val, test = split_mnist(pathLoad, fracVal)
    trainData, trainLabels = train[0], train[1]
    valData, valLabels = val[0], val[1]
    testData, testLabels = test[0], test[1]
    numSamples = trainData.shape[0]

    if modelType == 'mlp':
        model = MLP(dims, activation=activation)  # create an MLP object
    elif modelType == 'cnn':
        model = CNN(dims, activation=activation)  # create a CNN object
    else:
        raise ValueError('WOAHHHHH RELAX -- unknown modelType: {}'.format(modelType))

    model = model.to(device)
    lossFunction = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    "Objects used to store performance metrics while network is training"
    trainSpectra = []  # store the (estimated) spectra of the network at the end of each epoch
    trainLoss = []  # store the training loss (reported at the end of each epoch on the last batch)
    trainRegularizer = []  # store the value of the regularizer during training
    valLoss = []  # validation loss
    valRegularizer = []  # validation regularizer

    "Sample indices for eigenvectors all at once"
    eigBatchIdx = npr.randint(numSamples, size=(numEpochs + 1, batch_size))
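    # One row of indices per eigenvector update: row 0 is used for the initial estimate below,
    # and row (epoch + 1) is used when the estimate is refreshed at the end of each epoch.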

    "Get initial estimate of eigenvectors and check loss"
    with torch.no_grad():
        model.eigVec, loss, spectraTemp, regul = computeEigVectors(model, trainData[eigBatchIdx[0, :], :],
                                                                   trainLabels[eigBatchIdx[0, :]], lossFunction,
                                                                   alpha=alpha, cuda=cuda)
        trainSpectra.append(spectraTemp)  # store computed eigenspectra
        trainLoss.append(loss.cpu().item())  # store training loss
        _run.log_scalar("trainLoss", loss.item())
        _run.log_scalar("trainRegularizer", float(alpha * regul) )
        trainRegularizer.append(alpha * regul)  # store value of regularizer

        "Check on validation set"
        loss, regul = compute_loss(model, valData, valLabels, lossFunction, alpha, cuda=cuda)
        valLoss.append(loss.item())
        _run.log_scalar("valLoss", loss.item())
        valRegularizer.append(regul.item())
        prevVal = loss.item() + (alpha * regul).item()  # use for early stopping
        prevModel = copy.deepcopy(model)

    patience = 0
    howMuchPatience = 4
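    # Early stopping: every `epochSave` epochs the augmented validation loss is compared with the value
    # from the previous check; once it has failed to improve for more than `howMuchPatience` consecutive
    # checks, training stops and the model saved at the previous check is restored.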
    "Train that bad boy!"
    for epoch in tqdm(range(numEpochs), desc="Epochs", ascii=True, position=0, leave=False):
        batches = create_batches(batch_size=batch_size, numSamples=numSamples)  # create indices for batches
        for batch in tqdm(batches, desc='Train Batches', ascii=True, position=1, leave=False):
            optimizer.zero_grad()
            "Compute a forward pass through the network"
            loss, regul = compute_loss(model, trainData[batch, :], trainLabels[batch], lossFunction, alpha, cuda=cuda)
            lossR = loss + alpha * regul  # compute augmented loss function
            lossR.backward()  # backprop!
            optimizer.step()  # take a gradient step

        "Recompute estimated eigenvectors"
        with torch.no_grad():
            model.eigVec, loss, spectraTemp, regul = computeEigVectors(model, trainData[eigBatchIdx[epoch + 1, :], :],
                                                                       trainLabels[eigBatchIdx[epoch + 1, :]],
                                                                       lossFunction, alpha=alpha, cuda=cuda)
            trainSpectra.append(spectraTemp)  # store computed eigenspectra
            trainLoss.append(loss.cpu().item())  # store training loss
            _run.log_scalar("trainLoss", loss.item())
            trainRegularizer.append(float(alpha * regul))  # store value of regularizer
            if (epoch + 1) % epochSave == 0:
                "Check early stopping condition"
                loss, regul = compute_loss(model, valData, valLabels, lossFunction, alpha, cuda=cuda)
                currVal = loss.item() + (alpha * regul).item()
                percentImprove = (currVal - prevVal) / prevVal
                if percentImprove > 0:
                    if patience > howMuchPatience:
                        model = prevModel
                        break
                    else:
                        patience += 1
                else:
                    patience = 0
                prevVal = currVal
                prevModel = copy.deepcopy(model)  # save for early stopping
                valLoss.append(loss.item())
                _run.log_scalar("valLoss", loss.item())
                valRegularizer.append(regul.item())
                _run.log_scalar("valRegularizer", regul.item())


    "Check accuracy on test set"
    outputs = model(testData.to(device))
    softMax = nn.Softmax(dim=1)
    probs = softMax(outputs.cpu())
    numCorrect = torch.sum(torch.argmax(probs, dim=1) == testLabels).detach().numpy() * 1.0
    testResult = numCorrect / testData.shape[0] * 100

    "Collect accuracy on validation set"
    outputs = model(valData.to(device))
    softMax = nn.Softmax(dim=1)
    probs = softMax(outputs).cpu()
    numCorrect = torch.sum(torch.argmax(probs, dim=1) == valLabels).detach().numpy() * 1.0
    valAcc = numCorrect / valData.shape[0] * 100
    _run.log_scalar("valAcc", valAcc.item())


    "Save everything for later analysis"
    model_data = {'parameters': model.cpu().state_dict(),
                  'training': (trainLoss, trainRegularizer, trainSpectra),
                  'val': (valLoss, valRegularizer, valAcc),
                  'test': testResult}

    if modelType == 'cnn':
        dims = dims[1:]  # first number is number of convolutional layers
    path = pathSave + modelType + '_' + activation + '_hidden=('
    for idx in range(len(dims) - 1):
        path = path + str(dims[idx][1]) + ','

    path = path + str(dims[-1][1]) + ')_lr=' + str(lr) + '_alpha=' + str(alpha.item()) + '_batch_size=' \
           + str(batch_size) + '_seed=' + str(_seed) + '_epochs=' + str(numEpochs)
    torch.save(model_data, path)
    _run.add_artifact(path, "model_data.pt", content_type="application/octet-stream")  # saves the data dump as model_data
    # os.system('ls -l --block-size=M {}'.format(path))
    # shutil.rmtree(pathSave)
    # Return the validation accuracy for model comparison and selection
    return valAcc


def advTrain(batch_size, lr, dims, numEpochs, eps, alpha, gradSteps,
             noRestarts, cuda, pathLoad, pathSave, epochSave, activation,
             modelType, computeEigVectorsOnline, regularizerFcn, _seed, _run):
    """
    Function for creating and training NNs on MNIST using adversarial training.
    :param regularizerFcn
    :param computeEigVectorsOnline:
    :param batch_size: specifies batch size
    :param lr: learning rate of stochastic optimizer
    :param dims: A list of N tuples that specifies the input and output sizes for the FC layers. where the last layer
     is the output layer
    :param numEpochs: number of epochs to train the network for
    :param eps: radius of l infinity ball
    :param alpha: learning rate for projected gradient descent. If alpha is 0, then use FGSM
    :param gradSteps: number of gradient steps to take when doing pgd
    :param noRestarts: number of restarts for pgd
    :param cuda: boolean variable that will specify whether to use the GPU or nt
    :param pathLoad: path to where MNIST lives
    :param pathSave: path specifying where to save the models
    :param epochSave: integer specfying how often to save loss
    :param activation: string that specified whether to use relu or not
    :param _seed: seed for RNG
    :param _run: Sacred object that logs the relevant data and stores them to a database
    """
    device = 'cuda' if cuda else 'cpu'
    os.makedirs(pathSave, exist_ok=True)
    npr.seed(_seed)
    torch.manual_seed(_seed + 1)
    alpha = alpha * torch.ones(1, device=device)

    "Load in MNIST"
    fracVal = 0.1
    train, val, test = split_mnist(pathLoad, fracVal)
    trainData, trainLabels = train[0], train[1]
    valData, valLabels = val[0], val[1]
    testData, testLabels = test[0], test[1]
    numSamples = trainData.shape[0]

    if modelType == 'mlp':
        mlp = MLP(dims, activation=activation)  # create an MLP object
    elif modelType == 'cnn':
        mlp = CNN(dims, activation=activation)  # create a CNN object (still stored in the `mlp` variable)
    else:
        raise ValueError('WOAHHHHH RELAX -- unknown modelType: {}'.format(modelType))

    mlp.to(device)
    lossFunction = nn.CrossEntropyLoss()
    optimizer = optim.Adam(mlp.parameters(), lr=lr)

    "Create adversary"
    adv = Adversary(eps=eps,
                    alpha=alpha,
                    gradSteps=gradSteps,
                    noRestarts=noRestarts,
                    cuda=cuda)
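    # The adversary perturbs each input within an l-infinity ball of radius `eps`, taking `gradSteps`
    # PGD steps of size `alpha` with `noRestarts` random restarts (alpha == 0 corresponds to FGSM,
    # per the docstring above).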

    "Objects used to store performance metrics while network is training"
    trainLoss = []  # store the training loss (reported at the end of each epoch on the last batch)

    "Get initial value of loss function"
    tempIdx = npr.randint(numSamples, size=batch_size)
    _, lossAdv = adv.generateAdvImages(trainData[tempIdx, :].to(device),
                                       trainLabels[tempIdx].to(device), mlp,
                                       lossFunction)
    trainLoss.append(float(lossAdv))
    _run.log_scalar("trainLoss", float(lossAdv))
    "Train that bad boy!"
    counter = count(0)
    for epoch in tqdm(range(numEpochs),
                      desc="Epochs",
                      ascii=True,
                      position=0,
                      leave=False):
        batches = create_batches(
            batch_size=batch_size,
            numSamples=numSamples)  # create indices for batches
        for batch in tqdm(batches,
                          desc='Train Batches',
                          ascii=True,
                          position=1,
                          leave=False):
            "Compute a forward pass through the network"
            # Create adversarial images
            xAdv, _ = adv.generateAdvImages(trainData[batch, :].to(device),
                                            trainLabels[batch].to(device), mlp,
                                            lossFunction)
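            # Standard adversarial training: the parameter update below uses only the adversarially
            # perturbed images, not the clean batch.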
            optimizer.zero_grad()
            outputs = mlp(xAdv)  # feed data forward
            loss = lossFunction(outputs,
                                trainLabels.to(device)[batch])  # compute loss
            loss.backward()  # backprop!
            optimizer.step()  # take a gradient step

            if (epoch + 1) % epochSave == 0:
                trainLoss.append(loss.item())  # store training loss
                _run.log_scalar("trainLoss", loss.item())

    "Check accuracy on test set"
    outputs = mlp(testData.to(device))
    softMax = nn.Softmax(dim=1)
    probs = softMax(outputs.cpu())
    numCorrect = torch.sum(
        torch.argmax(probs, dim=1) == testLabels).detach().numpy() * 1.0
    testResult = numCorrect / testData.shape[0] * 100

    "Collect accuracy on validation set"
    outputs = mlp(valData.to(device))

    softMax = nn.Softmax(dim=1)
    probs = softMax(outputs).cpu()
    numCorrect = torch.sum(
        torch.argmax(probs, dim=1) == valLabels).detach().numpy() * 1.0
    valAcc = numCorrect / valData.shape[0] * 100
    _run.log_scalar("valAcc", valAcc.item())

    "Save everything for later analysis"
    model_data = {
        'parameters': mlp.cpu().state_dict(),
        'training': trainLoss,
        'valAcc': valAcc,
        'test': testResult
    }
    if modelType == 'cnn':
        dims = dims[1:]  # first number is number of convolutional layers
    path = pathSave + modelType + '_' + activation + '_hidden=('
    for idx in range(len(dims) - 1):
        path = path + str(dims[idx][1]) + ','

    path = path + str(dims[-1][1]) + ')_lr=' + str(lr) + '_alpha=' + str(alpha.item()) + '_batch_size=' \
           + str(batch_size) + '_seed=' + str(_seed) + '_epochs=' + str(numEpochs)
    torch.save(model_data, path)
    _run.add_artifact(path, "model_data")  # saves the data dump as model_data
    # shutil.rmtree(pathSave)
    # Return the validation accuracy for model comparison and selection
    return valAcc