Example #1
    def __init__(self,
                 dataset_name,
                 hidden_size,
                 activation,
                 architecture,
                 inference,
                 epochs,
                 lr,
                 n_samples,
                 warmup,
                 input_shape,
                 output_size,
                 step_size=0.005,
                 num_steps=10):
        super(BNN, self).__init__()
        self.dataset_name = dataset_name
        self.inference = inference
        self.architecture = architecture
        self.epochs = epochs
        self.lr = lr
        self.n_samples = n_samples
        self.warmup = warmup
        self.step_size = step_size
        self.num_steps = num_steps
        self.basenet = NN(dataset_name=dataset_name,
                          input_shape=input_shape,
                          output_size=output_size,
                          hidden_size=hidden_size,
                          activation=activation,
                          architecture=architecture,
                          epochs=epochs,
                          lr=lr)
        print(self.basenet)
        self.name = self.get_name()
Example #2
    def load(self, device, rel_path=TESTS):

        savedir = self.name + "/weights"
        for seed in self.random_seeds:

            net = NN(dataset_name=self.dataset_name,
                     input_shape=self.input_shape,
                     output_size=self.output_size,
                     hidden_size=self.hidden_size,
                     activation=self.activation,
                     architecture=self.architecture,
                     epochs=self.epochs,
                     lr=self.lr)

            net.load(device=device,
                     savedir=savedir,
                     seed=seed,
                     rel_path=rel_path)
            self.ensemble_models[str(seed)] = net
Example #3
    def train(self, x_train, y_train, device):

        for seed in self.random_seeds:

            batch_size = 100  #random.choice([32,64,128,256])
            train_loader = DataLoader(dataset=list(zip(x_train, y_train)),
                                      batch_size=batch_size,
                                      shuffle=True)

            net = NN(dataset_name=self.dataset_name,
                     input_shape=self.input_shape,
                     output_size=self.output_size,
                     hidden_size=self.hidden_size,
                     activation=self.activation,
                     architecture=self.architecture,
                     epochs=self.epochs,
                     lr=self.lr)
            net.train(train_loader=train_loader,
                      device=device,
                      seed=seed,
                      save=False)
            self.ensemble_models[str(seed)] = net
            self.save(seed=seed)
Example #4
class BNN(PyroModule):
    def __init__(self,
                 dataset_name,
                 hidden_size,
                 activation,
                 architecture,
                 inference,
                 epochs,
                 lr,
                 n_samples,
                 warmup,
                 input_shape,
                 output_size,
                 step_size=0.005,
                 num_steps=10):
        super(BNN, self).__init__()
        self.dataset_name = dataset_name
        self.inference = inference
        self.architecture = architecture
        self.epochs = epochs
        self.lr = lr
        self.n_samples = n_samples
        self.warmup = warmup
        self.step_size = step_size
        self.num_steps = num_steps
        self.basenet = NN(dataset_name=dataset_name,
                          input_shape=input_shape,
                          output_size=output_size,
                          hidden_size=hidden_size,
                          activation=activation,
                          architecture=architecture,
                          epochs=epochs,
                          lr=lr)
        print(self.basenet)
        self.name = self.get_name()

    def get_name(self, n_inputs=None):

        name = str(self.dataset_name)+"_bnn_"+str(self.inference)+"_hid="+\
               str(self.basenet.hidden_size)+"_act="+str(self.basenet.activation)+\
               "_arch="+str(self.basenet.architecture)

        if n_inputs:
            name = name + "_inp=" + str(n_inputs)

        if self.inference == "svi":
            return name + "_ep=" + str(self.epochs) + "_lr=" + str(self.lr)
        elif self.inference == "hmc":
            return name+"_samp="+str(self.n_samples)+"_warm="+str(self.warmup)+\
                   "_stepsize="+str(self.step_size)+"_numsteps="+str(self.num_steps)

    def model(self, x_data, y_data):

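        # place an independent standard-normal prior over every parameter
        # of the deterministic base network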
        priors = {}
        for key, value in self.basenet.state_dict().items():
            loc = torch.zeros_like(value)
            scale = torch.ones_like(value)
            prior = Normal(loc=loc, scale=scale)
            priors.update({str(key): prior})

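        # lift the deterministic module into a distribution over networks;
        # note that pyro.random_module is deprecated in recent Pyro releases
        # in favor of pyro.nn.PyroModule, but is kept here as in the original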
        lifted_module = pyro.random_module("module", self.basenet, priors)()

        with pyro.plate("data", len(x_data)):
            logits = lifted_module(x_data)
            lhat = nnf.log_softmax(logits, dim=-1)
            obs = pyro.sample("obs", Categorical(logits=lhat), obs=y_data)

    def guide(self, x_data, y_data=None):

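        # mean-field variational posterior: one learnable Gaussian per weight
        # tensor, with softplus keeping the scales strictly positive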
        dists = {}
        for key, value in self.basenet.state_dict().items():
            loc = pyro.param(f"{key}_loc", torch.randn_like(value))
            scale = pyro.param(f"{key}_scale", torch.randn_like(value))
            distr = Normal(loc=loc, scale=softplus(scale))
            dists.update({str(key): distr})

        lifted_module = pyro.random_module("module", self.basenet, dists)()

        with pyro.plate("data", len(x_data)):
            logits = lifted_module(x_data)
            preds = nnf.softmax(logits, dim=-1)

        return preds

    def save(self):

        name = self.name
        path = TESTS + name + "/"
        filename = name + "_weights"
        os.makedirs(os.path.dirname(path), exist_ok=True)

        if self.inference == "svi":
            self.basenet.to("cpu")
            self.to("cpu")

            param_store = pyro.get_param_store()
            print("\nSaving: ", path + filename + ".pt")
            print(f"\nlearned params = {param_store.get_all_param_names()}")
            param_store.save(path + filename + ".pt")

        elif self.inference == "hmc":
            self.basenet.to("cpu")
            self.to("cpu")

            for key, value in self.posterior_predictive.items():
                torch.save(value.state_dict(),
                           path + filename + "_" + str(key) + ".pt")

                if DEBUG:
                    print(value.state_dict()["model.5.bias"])

    def load(self, device, rel_path=TESTS):
        self.device = device
        self.basenet.device = device
        name = self.name
        path = rel_path + name + "/"
        filename = name + "_weights"

        if self.inference == "svi":
            param_store = pyro.get_param_store()
            param_store.load(path + filename + ".pt")
            for key, value in param_store.items():
                param_store.replace_param(key, value.to(device), value)
            print("\nLoading ", path + filename + ".pt\n")

        elif self.inference == "hmc":

            self.posterior_predictive = {}
            for model_idx in range(self.n_samples):
                net_copy = copy.deepcopy(self.basenet)
                net_copy.load_state_dict(
                    torch.load(path + filename + "_" + str(model_idx) + ".pt"))
                self.posterior_predictive.update({str(model_idx): net_copy})

            if len(self.posterior_predictive) != self.n_samples:
                raise AttributeError("wrong number of posterior models")

        self.to(device)
        self.basenet.to(device)

    def forward(self, inputs, n_samples=10, avg_posterior=False, seeds=None):

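        # draw n_samples networks from the (approximate) posterior and
        # average their softmax outputs into a single predictive distribution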
        if seeds:
            if len(seeds) != n_samples:
                raise ValueError(
                    "Number of seeds should match number of samples.")

        if self.inference == "svi":

            if avg_posterior is True:

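                # single deterministic forward pass through the network built
                # from the learned variational means (the *_loc parameters)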
                guide_trace = poutine.trace(self.guide).get_trace(inputs)

                avg_state_dict = {}
                for key in self.basenet.state_dict().keys():
                    avg_weights = guide_trace.nodes[str(key) + "_loc"]['value']
                    avg_state_dict.update({str(key): avg_weights})

                self.basenet.load_state_dict(avg_state_dict)
                preds = [self.basenet.model(inputs)]

            else:

                preds = []

                if seeds:
                    for seed in seeds:
                        pyro.set_rng_seed(seed)
                        guide_trace = poutine.trace(
                            self.guide).get_trace(inputs)
                        preds.append(guide_trace.nodes['_RETURN']['value'])

                else:

                    for _ in range(n_samples):
                        guide_trace = poutine.trace(
                            self.guide).get_trace(inputs)
                        preds.append(guide_trace.nodes['_RETURN']['value'])

                if DEBUG:
                    print("\nlearned variational params:\n")
                    print(pyro.get_param_store().get_all_param_names())
                    print(
                        list(
                            poutine.trace(
                                self.guide).get_trace(inputs).nodes.keys()))
                    print("\n",
                          pyro.get_param_store()["model.0.weight_loc"][0][:5])
                    print(guide_trace.nodes['module$$$model.0.weight']
                          ["fn"].loc[0][:5])
                    print(
                        "posterior sample: ",
                        guide_trace.nodes['module$$$model.0.weight']['value']
                        [5][0][0])

        elif self.inference == "hmc":

            preds = []
            posterior_predictive = list(self.posterior_predictive.values())

            if seeds is None:
                seeds = range(n_samples)

            for seed in seeds:
                net = posterior_predictive[seed]
                preds.append(net.forward(inputs))

        output_probs = torch.stack(preds).mean(0)
        return output_probs

    def _train_hmc(self, train_loader, n_samples, warmup, step_size, num_steps,
                   device):

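        # each minibatch gets its own MCMC run; batch_samples is sized so
        # that the runs together are meant to cover at least n_samples draws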
        print("\n == HMC training ==")
        pyro.clear_param_store()

        num_batches = int(len(train_loader.dataset) / train_loader.batch_size)
        batch_samples = int(n_samples / num_batches) + 1
        print("\nn_batches=", num_batches, "\tbatch_samples =", batch_samples)

        kernel = HMC(self.model, step_size=step_size, num_steps=num_steps)
        mcmc = MCMC(kernel=kernel,
                    num_samples=batch_samples,
                    warmup_steps=warmup,
                    num_chains=1)

        start = time.time()
        for x_batch, y_batch in train_loader:
            x_batch = x_batch.to(device)
            labels = y_batch.to(device).argmax(-1)
            mcmc.run(x_batch, labels)

        execution_time(start=start, end=time.time())

        self.posterior_predictive = {}
        posterior_samples = mcmc.get_samples(n_samples)
        state_dict_keys = list(self.basenet.state_dict().keys())

        if DEBUG:
            print("\n", list(posterior_samples.values())[-1])

        for model_idx in range(n_samples):
            net_copy = copy.deepcopy(self.basenet)

            model_dict = OrderedDict({})
            for weight_idx, weights in enumerate(posterior_samples.values()):
                model_dict.update(
                    {state_dict_keys[weight_idx]: weights[model_idx]})

            net_copy.load_state_dict(model_dict)
            self.posterior_predictive.update({str(model_idx): net_copy})

        if DEBUG:
            print("\n", weights[model_idx])

        self.save()

    def _train_svi(self, train_loader, epochs, lr, device):
        self.device = device

        print("\n == SVI training ==")

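        # maximize the ELBO with Adam; TraceMeanField_ELBO can use analytic
        # KL terms, matching the fully factorized guide defined above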
        optimizer = pyro.optim.Adam({"lr": lr})
        elbo = TraceMeanField_ELBO()
        svi = SVI(self.model, self.guide, optimizer, loss=elbo)

        loss_list = []
        accuracy_list = []

        start = time.time()
        for epoch in range(epochs):
            loss = 0.0
            correct_predictions = 0.0

            for x_batch, y_batch in train_loader:

                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                labels = y_batch.argmax(-1)
                loss += svi.step(x_data=x_batch, y_data=labels)

                outputs = self.forward(x_batch, n_samples=10)
                predictions = outputs.argmax(dim=-1)
                correct_predictions += (predictions == labels).sum().item()

            if DEBUG:
                print("\n",
                      pyro.get_param_store()["model.0.weight_loc"][0][:5])
                print("\n", predictions[:10], "\n", labels[:10])

            total_loss = loss / len(train_loader.dataset)
            accuracy = 100 * correct_predictions / len(train_loader.dataset)

            print(
                f"\n[Epoch {epoch + 1}]\t loss: {total_loss:.2f} \t accuracy: {accuracy:.2f}",
                end="\t")

            loss_list.append(loss)
            accuracy_list.append(accuracy)

        execution_time(start=start, end=time.time())
        self.save()

        plot_loss_accuracy(dict={
            'loss': loss_list,
            'accuracy': accuracy_list
        },
                           path=TESTS + self.name + "/" + self.name +
                           "_training.png")

    def train(self, train_loader, device):
        self.device = device
        self.basenet.device = device

        self.to(device)
        self.basenet.to(device)

        random.seed(0)
        pyro.set_rng_seed(0)

        if self.inference == "svi":
            self._train_svi(train_loader, self.epochs, self.lr, device)

        elif self.inference == "hmc":
            self._train_hmc(train_loader, self.n_samples, self.warmup,
                            self.step_size, self.num_steps, device)

    def evaluate(self, test_loader, device, n_samples=10, seeds_list=None):
        self.device = device
        self.basenet.device = device
        self.to(device)
        self.basenet.to(device)

        random.seed(0)
        pyro.set_rng_seed(0)

        bnn_seeds = list(
            range(n_samples)) if seeds_list is None else seeds_list

        with torch.no_grad():

            correct_predictions = 0.0
            for x_batch, y_batch in test_loader:

                x_batch = x_batch.to(device)
                outputs = self.forward(x_batch,
                                       n_samples=n_samples,
                                       seeds=bnn_seeds)
                predictions = outputs.argmax(-1)
                labels = y_batch.to(device).argmax(-1)
                correct_predictions += (predictions == labels).sum().item()

            accuracy = 100 * correct_predictions / len(test_loader.dataset)
            print("Accuracy: %.2f%%" % (accuracy))
            return accuracy
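
A minimal usage sketch for this class (hypothetical hyperparameter values; assumes the project's load_dataset helper and the DataLoader setup shown in Example #6):

x_train, y_train, x_test, y_test, inp_shape, out_size = \
    load_dataset(dataset_name="mnist")
train_loader = DataLoader(dataset=list(zip(x_train, y_train)),
                          batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=list(zip(x_test, y_test)))

bnn = BNN(dataset_name="mnist", hidden_size=512, activation="leaky",
          architecture="fc2", inference="svi", epochs=5, lr=0.01,
          n_samples=10, warmup=100, input_shape=inp_shape,
          output_size=out_size)
bnn.train(train_loader=train_loader, device="cpu")
bnn.evaluate(test_loader=test_loader, device="cpu", n_samples=10)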
Example #5
# sample dataset with undersampling
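# NearMiss version 1 keeps the majority-class samples whose average distance
# to their n_neighbors nearest minority-class samples is smallest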
print("Resampling")
undersample = NearMiss(version=1, n_neighbors=5)
X_train, y_train = undersample.fit_resample(X_train, y_train)
X_test_sample, y_test_sample = undersample.fit_resample(X_test, y_test)
print("Completed Resampling")

# expand dims
X_train = np.expand_dims(X_train, axis=2)
X_test = np.expand_dims(X_test, axis=2)
X_test_sample = np.expand_dims(X_test_sample, axis=2)
X_train, y_train = shuffle(X_train, y_train, random_state=42)

# CNN Model
model = NN()
opt = keras.optimizers.Adam(learning_rate=1e-5)
model.compile(optimizer=opt,  # use the configured Adam instance; the string "adam" would ignore the custom learning rate
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# # weights
# weights = class_weight.compute_class_weight('balanced',
#                                             np.unique(y_train),
#                                             y_train)

# print(weights)
weights = [1, 0.000001]
# weights = [50248756, 0.1010]
# weights = [1,1]
# print(weights)
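
# The hand-set weights above are presumably meant for Keras' class_weight
# argument, e.g. (a sketch, assuming integer class labels 0 and 1):
# model.fit(X_train, y_train, class_weight={0: weights[0], 1: weights[1]})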
Example #6
def main(args):

    hyperparams = {"epsilon": 0.3}

    rel_path = DATA if args.savedir == "DATA" else TESTS
    train_inputs = 100 if DEBUG else None

    if args.device == "cuda":
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    if args.model_type == "nn":

        ### NN model
        dataset, hid, activ, arch, ep, lr = saved_NNs[
            "model_" + str(args.model_idx)].values()

        x_train, y_train, x_test, y_test, inp_shape, out_size = \
            load_dataset(dataset_name=dataset, n_inputs=train_inputs)
        train_loader = DataLoader(dataset=list(zip(x_train, y_train)),
                                  shuffle=True)
        test_loader = DataLoader(dataset=list(zip(x_test, y_test)))

        nn = NN(dataset_name=dataset,
                input_shape=inp_shape,
                output_size=out_size,
                hidden_size=hid,
                activation=activ,
                architecture=arch,
                epochs=ep,
                lr=lr)

        if args.train:
            nn.train(train_loader=train_loader, device=args.device)
        else:
            nn.load(device=args.device, rel_path=rel_path)

        if args.test:
            nn.evaluate(test_loader=test_loader, device=args.device)

        ### attack NN
        # convert to tensors unconditionally: attack_evaluation below needs
        # them whether the attack is computed here or loaded from disk
        x_test, y_test = (torch.from_numpy(x_test[:args.n_inputs]),
                          torch.from_numpy(y_test[:args.n_inputs]))

        if args.attack:
            x_attack = attack(net=nn,
                              x_test=x_test,
                              y_test=y_test,
                              dataset_name=dataset,
                              device=args.device,
                              method=args.attack_method,
                              filename=nn.name,
                              hyperparams=hyperparams)
        else:
            x_attack = load_attack(method=args.attack_method,
                                   rel_path=DATA,
                                   filename=nn.name)

        attack_evaluation(net=nn,
                          x_test=x_test,
                          x_attack=x_attack,
                          y_test=y_test,
                          device=args.device)

    elif args.model_type == "bnn":

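        # posterior sample counts used when crafting the attack and when
        # evaluating the defence, respectively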
        bayesian_attack_samples = [10]
        bayesian_defence_samples = [10]

        ### BNN model
        dataset, model = saved_BNNs["model_" + str(args.model_idx)]
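        # HMC triggers a full MCMC pass per batch, so it uses much larger
        # batches than the minibatch-based SVI training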
        batch_size = 5000 if model["inference"] == "hmc" else 128

        x_train, y_train, x_test, y_test, inp_shape, out_size = \
            load_dataset(dataset_name=dataset, n_inputs=train_inputs)
        train_loader = DataLoader(dataset=list(zip(x_train, y_train)),
                                  batch_size=batch_size,
                                  shuffle=True)
        test_loader = DataLoader(dataset=list(zip(x_test, y_test)))

        bnn = BNN(dataset, *list(model.values()), inp_shape, out_size)

        if args.train:
            bnn.train(train_loader=train_loader, device=args.device)
        else:
            bnn.load(device=args.device, rel_path=rel_path)

        if args.test:
            bnn.evaluate(test_loader=test_loader,
                         device=args.device,
                         n_samples=10)

        ### attack BNN
        x_test, y_test = (torch.from_numpy(x_test[:args.n_inputs]),
                          torch.from_numpy(y_test[:args.n_inputs]))

        for attack_samples in bayesian_attack_samples:
            x_attack = attack(net=bnn,
                              x_test=x_test,
                              y_test=y_test,
                              dataset_name=dataset,
                              device=args.device,
                              method=args.attack_method,
                              filename=bnn.name,
                              n_samples=attack_samples,
                              hyperparams=hyperparams)

            for defence_samples in bayesian_defence_samples:
                attack_evaluation(net=bnn,
                                  x_test=x_test,
                                  x_attack=x_attack,
                                  y_test=y_test,
                                  device=args.device,
                                  n_samples=defence_samples)

    elif args.model_type == "avg_ensemble":

        ensemble_size = 10
        n_samples = 10

        dataset, hid, activ, arch, ep, lr = saved_NNs[
            "model_" + str(args.model_idx)].values()

        _, _, x_test, y_test, inp_shape, out_size = \
            load_dataset(dataset_name=dataset, n_inputs=args.n_inputs)
        test_loader = DataLoader(dataset=list(zip(x_test, y_test)))

        x_test, y_test = (torch.from_numpy(x_test[:args.n_inputs]),
                          torch.from_numpy(y_test[:args.n_inputs]))

        ens_net = Ensemble_NN(dataset_name=dataset,
                              input_shape=inp_shape,
                              output_size=out_size,
                              hidden_size=hid,
                              activation=activ,
                              architecture=arch,
                              epochs=ep,
                              lr=lr,
                              ensemble_size=ensemble_size)

        results = torch.empty(size=(n_samples, 3))

        for seed in range(0, n_samples):

            nn = NN(dataset_name=dataset,
                    input_shape=inp_shape,
                    output_size=out_size,
                    hidden_size=hid,
                    activation=activ,
                    architecture=arch,
                    epochs=ep,
                    lr=lr)
            nn.load(device=args.device,
                    rel_path=rel_path,
                    savedir=ens_net.name + "/weights",
                    seed=seed)

            nn_attack = attack(net=nn,
                               x_test=x_test,
                               y_test=y_test,
                               dataset_name=dataset,
                               device=args.device,
                               method=args.attack_method,
                               filename=nn.name,
                               hyperparams=hyperparams)

            test_acc, adv_acc, softmax_rob = attack_evaluation(
                net=nn,
                x_test=x_test,
                x_attack=nn_attack,
                y_test=y_test,
                device=args.device)

            results[seed] = torch.tensor(
                [test_acc, adv_acc, softmax_rob.mean(0)])

        avg_res = results.mean(0)
        print(
            f"\navg test_acc = {avg_res[0]:.2f}\tavg adv_acc = {avg_res[1]:.2f}\tavg avg_softmax_rob = {avg_res[2]:.2f}"
        )

    elif args.model_type == "ensemble":

        ensemble_size = 10
        n_samples = 10

        dataset, hid, activ, arch, ep, lr = saved_NNs[
            "model_" + str(args.model_idx)].values()

        _, _, x_test, y_test, inp_shape, out_size = \
            load_dataset(dataset_name=dataset, n_inputs=args.n_inputs)
        test_loader = DataLoader(dataset=list(zip(x_test, y_test)))

        x_test, y_test = (torch.from_numpy(x_test[:args.n_inputs]),
                          torch.from_numpy(y_test[:args.n_inputs]))

        ens_net = Ensemble_NN(dataset_name=dataset,
                              input_shape=inp_shape,
                              output_size=out_size,
                              hidden_size=hid,
                              activation=activ,
                              architecture=arch,
                              epochs=ep,
                              lr=lr,
                              ensemble_size=ensemble_size)

        ens_net.load(device=args.device, rel_path=rel_path)

        ens_attack = attack(net=ens_net,
                            x_test=x_test,
                            y_test=y_test,
                            dataset_name=dataset,
                            device=args.device,
                            method=args.attack_method,
                            filename=ens_net.name,
                            hyperparams=hyperparams)

        test_acc, adv_acc, softmax_rob = attack_evaluation(net=ens_net,
                                                           x_test=x_test,
                                                           x_attack=ens_attack,
                                                           y_test=y_test,
                                                           device=args.device)

    else:
        raise NotImplementedError()