def _dataset_key(path):
    """Build the '<adversary>_<epsilon>' dictionary key for an adversarial-example file.

    The adversary name is the text before the first '/' in the path and the
    epsilon is the last '_'-separated token with '.npy' removed, e.g.
    'fgsm/fgsm_cifar10_examples_x_10000_0.1.npy' -> 'fgsm_0.1'.
    Assumes the directory name itself contains no '_'.
    """
    parts = path.split('_')
    return parts[0].split('/')[0] + '_' + parts[-1].split('.npy')[0]


def _train_epochs():
    """Run `n_epochs` passes of `svi.step` over every batch of `train_loader`.

    After each epoch, prints the epoch's wall-clock time and the summed loss
    divided by `n_train_batches * batch_size`.
    """
    for epoch in range(n_epochs):
        epoch_loss = 0
        start = time.time()
        for data in train_loader:
            # Move inputs and (long-typed) labels to the GPU in place; svi.step
            # receives the whole batch list.
            data[0] = Variable(data[0].cuda())
            data[1] = Variable(data[1].long().cuda())
            epoch_loss += svi.step(data)
        print(time.time() - start)
        print("[iteration %04d] loss: %.4f" %
              (epoch + 1, epoch_loss / float(n_train_batches * batch_size)))


def _collect_eval_datasets():
    """Return {key: [inputs, labels]} for the clean test set and each adversarial file found.

    The clean test set is stored under 'RegularImages_0.0'; adversarial files
    are keyed by `_dataset_key`.
    """
    datasets = {'RegularImages_0.0': [test.test_data, test.test_labels]}

    # FGSM examples reuse the clean test labels.
    fgsm_labels = test.test_labels
    for path in glob.glob('fgsm/fgsm_cifar10_examples_x_10000_*'):
        datasets[_dataset_key(path)] = [torch.from_numpy(np.load(path)), fgsm_labels]

    # Gaussian labels: argmax over axis 1 of the first 1000 rows
    # (presumably one-hot encoded -- TODO confirm).
    gaussian_paths = glob.glob('gaussian/cifar_gaussian_adv_x_*')
    gaussian_labels = torch.from_numpy(
        np.argmax(np.load('gaussian/cifar_gaussian_adv_y.npy')[0:1000],
                  axis=1))
    for path in gaussian_paths:
        datasets[_dataset_key(path)] = [torch.from_numpy(np.load(path)), gaussian_labels]

    return datasets


def _evaluate_dataset(key, X, y, n_samples=100):
    """Evaluate `n_samples` models drawn from `guide` on one dataset.

    For each draw, records the model outputs and the accuracy of their argmax
    against `y`. Then computes per-point uncertainty measures via `Uncertainty`,
    plots them into 'Results_CIFAR10_PYRO', and prints the accuracy mean/std.

    Args:
        key: '<adversary>_<epsilon>' dataset key; epsilon must parse as a float.
        X: input tensor for the whole dataset (converted to float and moved to GPU).
        y: label tensor aligned with X.
        n_samples: number of models sampled from the guide.

    Returns:
        dict with accuracy 'mean'/'std' and [mean, std] pairs for
        'variationratio', 'predictiveEntropy' and 'mutualInformation'.
    """
    parts = key.split('_')
    adversary_type = parts[0]
    epsilon = parts[1]

    x_data, y_data = Variable(X.float().cuda()), Variable(y.cuda())
    n_points = y_data.data.size()[0]
    # Ground-truth labels do not change between draws; copy them to host once.
    true_labels = np.int32(y_data.data.cpu().numpy().ravel())

    accs = []
    samples = np.zeros((n_points, n_samples, outputs))
    for i in range(n_samples):
        sampled_model = guide(None)
        pred = sampled_model(x_data)
        samples[:, i, :] = pred.data.cpu().numpy()
        _, out = torch.max(pred, 1)
        correct = np.squeeze(out.data.cpu().numpy()) == true_labels
        accs.append(np.count_nonzero(correct) / float(n_points))
    accs = np.array(accs)

    variation_ratio = []
    mutual_information = []
    predictive_entropy = []
    predictions = []
    # Each entry is one data point's (n_samples, outputs) block of sampled outputs.
    for entry in samples:
        variation_ratio.append(Uncertainty.variation_ratio(entry))
        mutual_information.append(Uncertainty.mutual_information(entry))
        predictive_entropy.append(Uncertainty.predictive_entropy(entry))
        predictions.append(np.max(entry.mean(axis=0), axis=0))

    # NOTE: the 'varation_ratio' spelling is kept as-is; plot_uncertainty may
    # read this key.
    uncertainty = {
        'varation_ratio': np.array(variation_ratio),
        'predictive_entropy': np.array(predictive_entropy),
        'mutual_information': np.array(mutual_information),
    }
    predictions = np.array(predictions)

    Uncertainty.plot_uncertainty(uncertainty,
                                 predictions,
                                 adversarial_type=adversary_type,
                                 epsilon=float(epsilon),
                                 directory='Results_CIFAR10_PYRO')

    print('Accuracy mean: {}, Accuracy std: {}'.format(
        accs.mean(), accs.std()))
    return {
        'mean': accs.mean(),
        'std': accs.std(),
        'variationratio': [uncertainty['varation_ratio'].mean(),
                           uncertainty['varation_ratio'].std()],
        'predictiveEntropy': [uncertainty['predictive_entropy'].mean(),
                              uncertainty['predictive_entropy'].std()],
        'mutualInformation': [uncertainty['mutual_information'].mean(),
                              uncertainty['mutual_information'].std()],
    }


def main():
    """Train the Pyro model with SVI, save it, then evaluate on clean and adversarial data.

    Clears the Pyro param store and trains for `n_epochs`. Saves the param
    store to 'Pyro_model'. Evaluates every dataset from
    `_collect_eval_datasets` and saves the per-dataset summaries with
    `np.save` to 'PyroBNN_accuracies_CIFAR10'.
    """
    pyro.clear_param_store()
    _train_epochs()
    pyro.get_param_store().save('Pyro_model')

    datasets = _collect_eval_datasets()
    print(datasets.keys())
    print(
        '################################################################################'
    )
    accuracies = {}
    # .items() works on both Python 2 and 3 (.iteritems() is Python 2 only).
    for key, (X, y) in datasets.items():
        print(key)
        accuracies[key] = _evaluate_dataset(key, X, y)

    np.save('PyroBNN_accuracies_CIFAR10', accuracies)
# Add the Gaussian-noise adversarial inputs to `datasets` (not defined in this
# block; it must already exist at module scope), keyed '<adversary>_<epsilon>'.
# The glob requires an '_<epsilon>' suffix, matching main(). This avoids picking
# up an unsuffixed file such as 'cifar_gaussian_adv_x.npy', whose key would be
# 'gaussian_x' with a non-numeric epsilon token.
gaussian = glob.glob('gaussian/cifar_gaussian_adv_x_*')
# Labels: argmax over axis 1 of the first 1000 rows of the y array
# (presumably one-hot encoded -- TODO confirm).
gaussian_labels = torch.from_numpy(
    np.argmax(np.load('gaussian/cifar_gaussian_adv_y.npy')[0:1000], axis=1))
for file in gaussian:
    parts = file.split('_')
    # e.g. 'gaussian/cifar_gaussian_adv_x_0.1.npy' -> 'gaussian_0.1'
    key = parts[0].split('/')[0] + '_' + parts[-1].split('.npy')[0]

    datasets[key] = [torch.from_numpy(np.load(file)), gaussian_labels]

print(datasets.keys())
print(
    '################################################################################'
)
accuracies = {}
for key, value in datasets.iteritems():
    print(key)
    parts = key.split('_')
    adversary_type = parts[0]
    epsilon = parts[1]
    data = value
    X, y = data[0], data[1]
    x_data, y_data = Variable(X.float().cuda()), Variable(y.cuda())
    T = 50

    accs = []
    samples = np.zeros((y_data.data.size()[0], T, outputs))
    for i in range(T):
        pred = createProbabilityOfClasses(x_data, samples=20)
        samples[:, i, :] = pred.data.cpu().numpy()
        _, out = torch.max(pred, 1)