def load(self, device, rel_path=TESTS):
    """Load each ensemble member from its seed-specific checkpoint."""
    savedir = self.name + "/weights"
    for seed in self.random_seeds:
        net = NN(dataset_name=self.dataset_name, input_shape=self.input_shape,
                 output_size=self.output_size, hidden_size=self.hidden_size,
                 activation=self.activation, architecture=self.architecture,
                 epochs=self.epochs, lr=self.lr)
        net.load(device=device, savedir=savedir, seed=seed, rel_path=rel_path)
        self.ensemble_models[str(seed)] = net

def train(self, x_train, y_train, device):
    """Train one independent network per random seed and store it in the ensemble."""
    for seed in self.random_seeds:
        batch_size = 100  # random.choice([32, 64, 128, 256])
        train_loader = DataLoader(dataset=list(zip(x_train, y_train)),
                                  batch_size=batch_size, shuffle=True)
        net = NN(dataset_name=self.dataset_name, input_shape=self.input_shape,
                 output_size=self.output_size, hidden_size=self.hidden_size,
                 activation=self.activation, architecture=self.architecture,
                 epochs=self.epochs, lr=self.lr)
        net.train(train_loader=train_loader, device=device, seed=seed, save=False)
        self.ensemble_models[str(seed)] = net
        self.save(seed=seed)
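# Usage sketch (illustrative, not part of the original file): how these
# Ensemble_NN methods are typically driven. `load_dataset` comes from the
# project; the dataset name and hyperparameter values below are assumptions.
#
#   x_train, y_train, x_test, y_test, inp_shape, out_size = \
#       load_dataset(dataset_name="mnist", n_inputs=None)
#   ens_net = Ensemble_NN(dataset_name="mnist", input_shape=inp_shape,
#                         output_size=out_size, hidden_size=512,
#                         activation="leaky", architecture="fc2",
#                         epochs=10, lr=0.001, ensemble_size=10)
#   ens_net.train(x_train=x_train, y_train=y_train, device="cuda")  # or:
#   ens_net.load(device="cuda", rel_path=TESTS)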
class BNN(PyroModule):

    def __init__(self, dataset_name, hidden_size, activation, architecture,
                 inference, epochs, lr, n_samples, warmup, input_shape,
                 output_size, step_size=0.005, num_steps=10):
        super(BNN, self).__init__()
        self.dataset_name = dataset_name
        self.inference = inference
        self.architecture = architecture
        self.epochs = epochs
        self.lr = lr
        self.n_samples = n_samples
        self.warmup = warmup
        self.step_size = step_size
        self.num_steps = num_steps
        self.basenet = NN(dataset_name=dataset_name, input_shape=input_shape,
                          output_size=output_size, hidden_size=hidden_size,
                          activation=activation, architecture=architecture,
                          epochs=epochs, lr=lr)
        print(self.basenet)
        self.name = self.get_name()

    def get_name(self, n_inputs=None):
        name = str(self.dataset_name) + "_bnn_" + str(self.inference) + \
               "_hid=" + str(self.basenet.hidden_size) + \
               "_act=" + str(self.basenet.activation) + \
               "_arch=" + str(self.basenet.architecture)
        if n_inputs:
            name = name + "_inp=" + str(n_inputs)
        if self.inference == "svi":
            return name + "_ep=" + str(self.epochs) + "_lr=" + str(self.lr)
        elif self.inference == "hmc":
            return name + "_samp=" + str(self.n_samples) + "_warm=" + str(self.warmup) + \
                   "_stepsize=" + str(self.step_size) + "_numsteps=" + str(self.num_steps)

    def model(self, x_data, y_data):
        # Standard normal prior over every weight tensor of the deterministic base net.
        priors = {}
        for key, value in self.basenet.state_dict().items():
            loc = torch.zeros_like(value)
            scale = torch.ones_like(value)
            prior = Normal(loc=loc, scale=scale)
            priors.update({str(key): prior})
        lifted_module = pyro.random_module("module", self.basenet, priors)()
        with pyro.plate("data", len(x_data)):
            logits = lifted_module(x_data)
            lhat = nnf.log_softmax(logits, dim=-1)
            obs = pyro.sample("obs", Categorical(logits=lhat), obs=y_data)

    def guide(self, x_data, y_data=None):
        # Fully factorized Gaussian variational posterior; softplus keeps scales positive.
        dists = {}
        for key, value in self.basenet.state_dict().items():
            loc = pyro.param(f"{key}_loc", torch.randn_like(value))
            scale = pyro.param(f"{key}_scale", torch.randn_like(value))
            distr = Normal(loc=loc, scale=nnf.softplus(scale))
            dists.update({str(key): distr})
        lifted_module = pyro.random_module("module", self.basenet, dists)()
        with pyro.plate("data", len(x_data)):
            logits = lifted_module(x_data)
            preds = nnf.softmax(logits, dim=-1)
        return preds

    def save(self):
        name = self.name
        path = TESTS + name + "/"
        filename = name + "_weights"
        os.makedirs(os.path.dirname(path), exist_ok=True)

        if self.inference == "svi":
            self.basenet.to("cpu")
            self.to("cpu")
            param_store = pyro.get_param_store()
            print("\nSaving: ", path + filename + ".pt")
            print(f"\nlearned params = {param_store.get_all_param_names()}")
            param_store.save(path + filename + ".pt")

        elif self.inference == "hmc":
            self.basenet.to("cpu")
            self.to("cpu")
            for key, value in self.posterior_predictive.items():
                torch.save(value.state_dict(),
                           path + filename + "_" + str(key) + ".pt")
                if DEBUG:
                    print(value.state_dict()["model.5.bias"])

    def load(self, device, rel_path=TESTS):
        self.device = device
        self.basenet.device = device
        name = self.name
        path = rel_path + name + "/"
        filename = name + "_weights"

        if self.inference == "svi":
            param_store = pyro.get_param_store()
            param_store.load(path + filename + ".pt")
            for key, value in param_store.items():
                param_store.replace_param(key, value.to(device), value)
            print("\nLoading ", path + filename + ".pt\n")

        elif self.inference == "hmc":
            self.posterior_predictive = {}
            for model_idx in range(self.n_samples):
                net_copy = copy.deepcopy(self.basenet)
                net_copy.load_state_dict(
                    torch.load(path + filename + "_" + str(model_idx) + ".pt"))
                self.posterior_predictive.update({model_idx: net_copy})
            if len(self.posterior_predictive) != self.n_samples:
                raise AttributeError("wrong number of posterior models")

        self.to(device)
        self.basenet.to(device)

    def forward(self, inputs, n_samples=10, avg_posterior=False, seeds=None):
        if seeds:
            if len(seeds) != n_samples:
                raise ValueError("Number of seeds should match number of samples.")

        if self.inference == "svi":
            if avg_posterior is True:
                # Single deterministic pass through the learned posterior means.
                guide_trace = poutine.trace(self.guide).get_trace(inputs)
                avg_state_dict = {}
                for key in self.basenet.state_dict().keys():
                    avg_weights = guide_trace.nodes[str(key) + "_loc"]['value']
                    avg_state_dict.update({str(key): avg_weights})
                self.basenet.load_state_dict(avg_state_dict)
                preds = [self.basenet.model(inputs)]
            else:
                # Monte Carlo prediction: one guide sample per seed (or per draw).
                preds = []
                if seeds:
                    for seed in seeds:
                        pyro.set_rng_seed(seed)
                        guide_trace = poutine.trace(self.guide).get_trace(inputs)
                        preds.append(guide_trace.nodes['_RETURN']['value'])
                else:
                    for _ in range(n_samples):
                        guide_trace = poutine.trace(self.guide).get_trace(inputs)
                        preds.append(guide_trace.nodes['_RETURN']['value'])

                if DEBUG:
                    print("\nlearned variational params:\n")
                    print(pyro.get_param_store().get_all_param_names())
                    print(list(poutine.trace(self.guide).get_trace(inputs).nodes.keys()))
                    print("\n", pyro.get_param_store()["model.0.weight_loc"][0][:5])
                    print(guide_trace.nodes['module$$$model.0.weight']["fn"].loc[0][:5])
                    print("posterior sample: ",
                          guide_trace.nodes['module$$$model.0.weight']['value'][5][0][0])

        elif self.inference == "hmc":
            preds = []
            posterior_predictive = list(self.posterior_predictive.values())
            if seeds is None:
                seeds = range(n_samples)
            for seed in seeds:
                net = posterior_predictive[seed]
                preds.append(net.forward(inputs))

        output_probs = torch.stack(preds).mean(0)
        return output_probs

    def _train_hmc(self, train_loader, n_samples, warmup, step_size,
                   num_steps, device):
        print("\n == HMC training ==")
        pyro.clear_param_store()

        num_batches = int(len(train_loader.dataset) / train_loader.batch_size)
        batch_samples = int(n_samples / num_batches) + 1
        print("\nn_batches=", num_batches, "\tbatch_samples =", batch_samples)

        kernel = HMC(self.model, step_size=step_size, num_steps=num_steps)
        mcmc = MCMC(kernel=kernel, num_samples=batch_samples,
                    warmup_steps=warmup, num_chains=1)

        # Run the HMC kernel on each training batch.
        start = time.time()
        for x_batch, y_batch in train_loader:
            x_batch = x_batch.to(device)
            labels = y_batch.to(device).argmax(-1)
            mcmc.run(x_batch, labels)
        execution_time(start=start, end=time.time())

        # Materialize each posterior draw as a full copy of the base network.
        self.posterior_predictive = {}
        posterior_samples = mcmc.get_samples(n_samples)
        state_dict_keys = list(self.basenet.state_dict().keys())

        if DEBUG:
            print("\n", list(posterior_samples.values())[-1])

        for model_idx in range(n_samples):
            net_copy = copy.deepcopy(self.basenet)
            model_dict = OrderedDict({})
            for weight_idx, weights in enumerate(posterior_samples.values()):
                model_dict.update({state_dict_keys[weight_idx]: weights[model_idx]})
            net_copy.load_state_dict(model_dict)
            self.posterior_predictive.update({str(model_idx): net_copy})
            if DEBUG:
                print("\n", weights[model_idx])

        self.save()

    def _train_svi(self, train_loader, epochs, lr, device):
        self.device = device
        print("\n == SVI training ==")

        optimizer = pyro.optim.Adam({"lr": lr})
        elbo = TraceMeanField_ELBO()
        svi = SVI(self.model, self.guide, optimizer, loss=elbo)

        loss_list = []
        accuracy_list = []
        start = time.time()
        for epoch in range(epochs):
            loss = 0.0
            correct_predictions = 0.0
            for x_batch, y_batch in train_loader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                labels = y_batch.argmax(-1)
                loss += svi.step(x_data=x_batch, y_data=labels)

                outputs = self.forward(x_batch, n_samples=10)
                predictions = outputs.argmax(dim=-1)
                correct_predictions += (predictions == labels).sum().item()

            if DEBUG:
                print("\n", pyro.get_param_store()["model.0.weight_loc"][0][:5])
                print("\n", predictions[:10], "\n", labels[:10])

            total_loss = loss / len(train_loader.dataset)
            accuracy = 100 * correct_predictions / len(train_loader.dataset)
            print(f"\n[Epoch {epoch + 1}]\t loss: {total_loss:.2f} "
                  f"\t accuracy: {accuracy:.2f}", end="\t")

            loss_list.append(loss)
            accuracy_list.append(accuracy)

        execution_time(start=start, end=time.time())
        self.save()
        plot_loss_accuracy(dict={'loss': loss_list, 'accuracy': accuracy_list},
                           path=TESTS + self.name + "/" + self.name + "_training.png")

    def train(self, train_loader, device):
        self.device = device
        self.basenet.device = device
        self.to(device)
        self.basenet.to(device)
        random.seed(0)
        pyro.set_rng_seed(0)

        if self.inference == "svi":
            self._train_svi(train_loader, self.epochs, self.lr, device)
        elif self.inference == "hmc":
            self._train_hmc(train_loader, self.n_samples, self.warmup,
                            self.step_size, self.num_steps, device)

    def evaluate(self, test_loader, device, n_samples=10, seeds_list=None):
        self.device = device
        self.basenet.device = device
        self.to(device)
        self.basenet.to(device)
        random.seed(0)
        pyro.set_rng_seed(0)

        bnn_seeds = list(range(n_samples)) if seeds_list is None else seeds_list

        with torch.no_grad():
            correct_predictions = 0.0
            for x_batch, y_batch in test_loader:
                x_batch = x_batch.to(device)
                outputs = self.forward(x_batch, n_samples=n_samples, seeds=bnn_seeds)
                predictions = outputs.argmax(-1)
                labels = y_batch.to(device).argmax(-1)
                correct_predictions += (predictions == labels).sum().item()

            accuracy = 100 * correct_predictions / len(test_loader.dataset)
            print("Accuracy: %.2f%%" % (accuracy))
            return accuracy
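if __name__ == "__main__":
    # Minimal usage sketch (illustrative): trains a small SVI BNN end to end.
    # `load_dataset` is the project's helper; the dataset name, architecture,
    # and hyperparameter values here are assumptions, not original settings.
    from torch.utils.data import DataLoader

    x_train, y_train, x_test, y_test, inp_shape, out_size = \
        load_dataset(dataset_name="mnist", n_inputs=1000)
    train_loader = DataLoader(dataset=list(zip(x_train, y_train)),
                              batch_size=128, shuffle=True)
    test_loader = DataLoader(dataset=list(zip(x_test, y_test)))

    bnn = BNN(dataset_name="mnist", hidden_size=512, activation="leaky",
              architecture="fc2", inference="svi", epochs=5, lr=0.01,
              n_samples=None, warmup=None, input_shape=inp_shape,
              output_size=out_size)
    bnn.train(train_loader=train_loader, device="cpu")
    bnn.evaluate(test_loader=test_loader, device="cpu", n_samples=10)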
# sample dataset with undersampling
print("Resampling")
undersample = NearMiss(version=1, n_neighbors=5)
X_train, y_train = undersample.fit_resample(X_train, y_train)
X_test_sample, y_test_sample = undersample.fit_resample(X_test, y_test)
print("Completed Resampling")

# expand dims
X_train = np.expand_dims(X_train, axis=2)
X_test = np.expand_dims(X_test, axis=2)
X_test_sample = np.expand_dims(X_test_sample, axis=2)

X_train, y_train = shuffle(X_train, y_train, random_state=42)

# CNN Model
model = NN()
opt = keras.optimizers.Adam(learning_rate=1e-5)
# pass the configured optimizer; the string "adam" would silently ignore the 1e-5 learning rate
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

# # weights
# weights = class_weight.compute_class_weight('balanced',
#                                             np.unique(y_train),
#                                             y_train)
# print(weights)
weights = [1, 0.000001]
# weights = [50248756, 0.1010]
# weights = [1,1]
# print(weights)
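# Sketch: the commented-out compute_class_weight call above uses the old
# positional signature; current scikit-learn requires keyword arguments.
# Assuming integer class labels in y_train (not one-hot), the balanced
# weights could be computed as:
#
#   from sklearn.utils import class_weight
#   weights = class_weight.compute_class_weight(class_weight='balanced',
#                                               classes=np.unique(y_train),
#                                               y=y_train)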
def main(args):
    hyperparams = {"epsilon": 0.3}
    rel_path = DATA if args.savedir == "DATA" else TESTS
    train_inputs = 100 if DEBUG else None

    if args.device == "cuda":
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    if args.model_type == "nn":

        ### NN model
        dataset, hid, activ, arch, ep, lr = saved_NNs[
            "model_" + str(args.model_idx)].values()

        x_train, y_train, x_test, y_test, inp_shape, out_size = \
            load_dataset(dataset_name=dataset, n_inputs=train_inputs)
        train_loader = DataLoader(dataset=list(zip(x_train, y_train)), shuffle=True)
        test_loader = DataLoader(dataset=list(zip(x_test, y_test)))

        nn = NN(dataset_name=dataset, input_shape=inp_shape, output_size=out_size,
                hidden_size=hid, activation=activ, architecture=arch,
                epochs=ep, lr=lr)

        if args.train:
            nn.train(train_loader=train_loader, device=args.device)
        else:
            nn.load(device=args.device, rel_path=rel_path)

        if args.test:
            nn.evaluate(test_loader=test_loader, device=args.device)

        ### attack NN
        # tensors are needed in both branches below, so convert up front
        x_test, y_test = (torch.from_numpy(x_test[:args.n_inputs]),
                          torch.from_numpy(y_test[:args.n_inputs]))

        if args.attack:
            x_attack = attack(net=nn, x_test=x_test, y_test=y_test,
                              dataset_name=dataset, device=args.device,
                              method=args.attack_method, filename=nn.name,
                              hyperparams=hyperparams)
        else:
            x_attack = load_attack(method=args.attack_method, rel_path=DATA,
                                   filename=nn.name)

        attack_evaluation(net=nn, x_test=x_test, x_attack=x_attack,
                          y_test=y_test, device=args.device)

    elif args.model_type == "bnn":
        bayesian_attack_samples = [10]
        bayesian_defence_samples = [10]

        ### BNN model
        dataset, model = saved_BNNs["model_" + str(args.model_idx)]
        batch_size = 5000 if model["inference"] == "hmc" else 128

        x_train, y_train, x_test, y_test, inp_shape, out_size = \
            load_dataset(dataset_name=dataset, n_inputs=train_inputs)
        train_loader = DataLoader(dataset=list(zip(x_train, y_train)),
                                  batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(dataset=list(zip(x_test, y_test)))

        bnn = BNN(dataset, *list(model.values()), inp_shape, out_size)

        if args.train:
            bnn.train(train_loader=train_loader, device=args.device)
        else:
            bnn.load(device=args.device, rel_path=rel_path)

        if args.test:
            bnn.evaluate(test_loader=test_loader, device=args.device, n_samples=10)

        ### attack BNN
        x_test, y_test = (torch.from_numpy(x_test[:args.n_inputs]),
                          torch.from_numpy(y_test[:args.n_inputs]))

        for attack_samples in bayesian_attack_samples:
            x_attack = attack(net=bnn, x_test=x_test, y_test=y_test,
                              dataset_name=dataset, device=args.device,
                              method=args.attack_method, filename=bnn.name,
                              n_samples=attack_samples, hyperparams=hyperparams)

            for defence_samples in bayesian_defence_samples:
                attack_evaluation(net=bnn, x_test=x_test, x_attack=x_attack,
                                  y_test=y_test, device=args.device,
                                  n_samples=defence_samples)

    elif args.model_type == "avg_ensemble":
        ensemble_size = 10
        n_samples = 10

        dataset, hid, activ, arch, ep, lr = saved_NNs[
            "model_" + str(args.model_idx)].values()

        _, _, x_test, y_test, inp_shape, out_size = \
            load_dataset(dataset_name=dataset, n_inputs=args.n_inputs)
        test_loader = DataLoader(dataset=list(zip(x_test, y_test)))
        x_test, y_test = (torch.from_numpy(x_test[:args.n_inputs]),
                          torch.from_numpy(y_test[:args.n_inputs]))

        ens_net = Ensemble_NN(dataset_name=dataset, input_shape=inp_shape,
                              output_size=out_size, hidden_size=hid,
                              activation=activ, architecture=arch,
                              epochs=ep, lr=lr, ensemble_size=ensemble_size)

        # attack and evaluate each ensemble member, then average the metrics
        results = torch.empty(size=(n_samples, 3))
        for seed in range(0, n_samples):
            nn = NN(dataset_name=dataset, input_shape=inp_shape,
                    output_size=out_size, hidden_size=hid, activation=activ,
                    architecture=arch, epochs=ep, lr=lr)
            nn.load(device=args.device, rel_path=rel_path,
                    savedir=ens_net.name + "/weights", seed=seed)

            nn_attack = attack(net=nn, x_test=x_test, y_test=y_test,
                               dataset_name=dataset, device=args.device,
                               method=args.attack_method, filename=nn.name,
                               hyperparams=hyperparams)
            test_acc, adv_acc, softmax_rob = attack_evaluation(
                net=nn, x_test=x_test, x_attack=nn_attack,
                y_test=y_test, device=args.device)
            results[seed] = torch.tensor([test_acc, adv_acc, softmax_rob.mean(0)])

        avg_res = results.mean(0)
        print(f"\navg test_acc = {avg_res[0]:.2f}\tavg adv_acc = {avg_res[1]:.2f}\t"
              f"avg softmax_rob = {avg_res[2]:.2f}")

    elif args.model_type == "ensemble":
        ensemble_size = 10
        n_samples = 10

        dataset, hid, activ, arch, ep, lr = saved_NNs[
            "model_" + str(args.model_idx)].values()

        _, _, x_test, y_test, inp_shape, out_size = \
            load_dataset(dataset_name=dataset, n_inputs=args.n_inputs)
        test_loader = DataLoader(dataset=list(zip(x_test, y_test)))
        x_test, y_test = (torch.from_numpy(x_test[:args.n_inputs]),
                          torch.from_numpy(y_test[:args.n_inputs]))

        ens_net = Ensemble_NN(dataset_name=dataset, input_shape=inp_shape,
                              output_size=out_size, hidden_size=hid,
                              activation=activ, architecture=arch,
                              epochs=ep, lr=lr, ensemble_size=ensemble_size)
        ens_net.load(device=args.device, rel_path=rel_path)

        ens_attack = attack(net=ens_net, x_test=x_test, y_test=y_test,
                            dataset_name=dataset, device=args.device,
                            method=args.attack_method, filename=ens_net.name,
                            hyperparams=hyperparams)
        test_acc, adv_acc, softmax_rob = attack_evaluation(
            net=ens_net, x_test=x_test, x_attack=ens_attack,
            y_test=y_test, device=args.device)

    else:
        raise NotImplementedError()
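if __name__ == "__main__":
    # Sketch of a plausible CLI entry point. Flag names are inferred from the
    # attributes main() reads off `args` and may differ from the original script.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default="nn",
                        help="nn, bnn, avg_ensemble or ensemble")
    parser.add_argument("--model_idx", type=int, default=0,
                        help="index into saved_NNs / saved_BNNs")
    parser.add_argument("--n_inputs", type=int, default=100,
                        help="number of test points to attack")
    parser.add_argument("--attack_method", type=str, default="fgsm")
    parser.add_argument("--savedir", type=str, default="DATA",
                        help="DATA or TESTS")
    parser.add_argument("--device", type=str, default="cpu",
                        help="cpu or cuda")
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--test", action="store_true")
    parser.add_argument("--attack", action="store_true")
    main(parser.parse_args())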