def _valid_epoch(self, epoch):
    """Validation after training an epoch.

    Runs the SDN and the score network on the first batch of
    ``self.data_loader.test_loader`` without gradients and computes the
    validation KL loss. On epochs 20, 50, 70 and 99 it also prints the policy
    "classification" accuracy against the argmax of the true posterior.

    :param epoch: current epoch index (selects when the accuracy is printed).
    :return: the same log dict twice. It holds ``'val loss'`` plus either
        ``'sto q'`` (batch mean of the stochastic policy) or ``'det q'``
        (batch mean of the deterministic policy), depending on
        ``self.args.stochastic``.
    """
    self.score_net.eval()
    # NOTE(review): only the first test batch is used for this validation.
    x, y = next(iter(self.data_loader.test_loader))
    x = x.to(self.device)
    y = y.to(self.device)
    with torch.no_grad():
        xhs = self.sdn_model(x)
        # Placeholder: the score net is not given real internal feature maps.
        internal_fm = torch.rand(2, 2)
        scores = self.score_net(x, internal_fm, xhs)
        stop_idx = self.q_posterior(self.args.policy_type, scores,
                                    stochastic=False, device=self.device)
        q = self.q_posterior(self.args.policy_type, scores,
                             stochastic=True, device=self.device)

        if epoch in (20, 50, 70, 99):
            p_true, _ = self.true_posterior(self.args, xhs, y)
            # The policy only covers the exits in nz_post (the same restriction
            # forward_kl_loss applies), so the target must live in that index
            # space too; an argmax over all exits would misalign the classes.
            p_nz = torch.stack([p_true[:, t] for t in self.nz_post.values()],
                               dim=1)
            p_true_b = max_onehot(p_nz, dim=-1, device=self.device)
            p_true_idx = torch.argmax(p_true_b, dim=-1)
            print('Here is the policy classification validation accuracy:')
            print(data.accuracy(q, p_true_idx))

        # validation loss
        if self.args.kl_type == 'forward':
            loss, _ = self.forward_kl_loss(y, xhs, scores, p_det=False)
        else:
            assert self.args.kl_type == 'backward'
            loss, _ = self.backward_kl_loss(y, xhs, scores)

    if self.args.stochastic:
        log = {'val loss': loss, 'sto q': torch.mean(q, dim=0)}
    else:
        log = {'val loss': loss, 'det q': torch.mean(stop_idx, dim=0)}
    return log, log
def forward_kl_loss(self, y, xhs, scores, p_det=True):
    """Forward-KL (cross-entropy) loss between the true exit posterior and the policy.

    The true posterior returned by ``self.true_posterior`` is restricted to
    the exits listed in ``self.nz_post`` and optionally hardened to a one-hot
    target. The loss is ``-sum_t p_t * log q_t``, averaged over the batch,
    where ``log q`` comes from ``self.log_q_posterior``.

    :param y: batch targets; forwarded to ``self.true_posterior`` and used
        for the batch-size sanity checks.
    :param xhs: SDN exit outputs, forwarded to ``self.true_posterior``.
    :param scores: score-network outputs, forwarded to ``self.log_q_posterior``.
    :param p_det: if True, replace the target by ``max_onehot`` of it.
    :return: the scalar loss, returned twice (callers unpack ``loss, _``).
    """
    batch_size = y.shape[0]
    # true posterior
    p_true, _ = self.true_posterior(self.args, xhs, y)
    assert batch_size == p_true.shape[0]
    # keep only the exits that carry non-negligible posterior mass
    p = torch.stack([p_true[:, t] for t in self.nz_post.values()], dim=1)
    if p_det:
        p = max_onehot(p, dim=-1, device=self.device)
    # parametric log posterior
    log_q_pi = self.log_q_posterior(self.args.policy_type, scores)
    assert batch_size == log_q_pi.shape[0]
    # Compute once; both return slots carry the same value.
    loss = -torch.sum(p * log_q_pi, dim=-1).mean()
    return loss, loss
def q_posterior(type, scores, stochastic=True, device='cuda'):
    """Turn raw policy scores into a distribution (or hard choice) over exits.

    :param type: policy type.
        - ``'multiclass'`` and ``'confidence'`` treat ``scores`` as logits
          over exits.
        - ``'sequential'`` treats ``sigmoid(scores[:, i])`` as the probability
          of continuing past exit ``i``. The probability of stopping at exit
          ``i`` is then ``(1 - pi_i) * prod_{j<i} pi_j``. The last column
          holds ``prod_j pi_j``, the probability of reaching the final exit.
    :param scores: 2-D tensor ``(batch, num_scores)``.
    :param stochastic: if True return probabilities, otherwise a hard decision
        (one-hot via ``max_onehot`` or thresholded continue gates at 0.5).
    :param device: device forwarded to ``max_onehot``. The sequential branch
        creates its tensors on ``scores``' own device.
    :return: ``(batch, num_scores)`` for multiclass/confidence,
        ``(batch, num_scores + 1)`` for sequential.
    :raises ValueError: if ``type`` is not a supported policy type.
    """
    if type in ('multiclass', 'confidence'):
        if stochastic:
            return F.softmax(scores, dim=-1)
        return max_onehot(scores, dim=-1, device=device)
    if type == 'sequential':
        batch_size, _ = scores.shape
        pi = torch.sigmoid(scores)  # probability of continuing past each exit
        if not stochastic:
            pi = (pi > 0.5).float()
        ones = torch.ones(batch_size, 1, dtype=pi.dtype, device=pi.device)
        # reach[:, t] = prod_{j<t} pi_j  (probability of arriving at exit t)
        reach = torch.cat([ones, torch.cumprod(pi, dim=1)], dim=1)
        # stop[:, t] = 1 - pi_t, and the final exit always stops
        stop = torch.cat([1 - pi, ones], dim=1)
        return reach * stop
    raise ValueError('unknown policy type: {}'.format(type))
def _update_policy(self, x, y):
    """Take one optimizer step on the score network and return its loss.

    The target is the true exit posterior from ``PolicyKL.true_posterior``.
    It is restricted to the exits in ``self.nz_post`` and hardened with
    ``max_onehot``. The loss is the batch-mean cross-entropy
    ``-sum_t p_t * log q_t`` against ``PolicyKL.log_q_posterior``, and
    ``self.optimizer2`` performs the update.

    :param x: input batch, passed as-is (not moved to a device here) to
        ``self.model`` and ``self.score_net``.
    :param y: batch targets, used only by ``PolicyKL.true_posterior``.
    :return: the scalar loss tensor.
    """
    self.model.eval()
    self.score_net.train()

    exit_outputs = self.model(x)
    dummy_fm = torch.rand(2, 2)  # stand-in for internal feature maps
    policy_scores = self.score_net(x, dummy_fm, exit_outputs)
    self.optimizer2.zero_grad()

    posterior, _ = PolicyKL.true_posterior(self.args, exit_outputs, y)
    kept_columns = [posterior[:, exit_id] for exit_id in self.nz_post.values()]
    # NOTE(review): other call sites pass device=... to max_onehot; this one
    # relies on max_onehot's default device.
    target = max_onehot(torch.stack(kept_columns, dim=1), dim=-1)
    log_q = PolicyKL.log_q_posterior(self.args.policy_type, policy_scores)
    loss = -(target * log_q).sum(dim=-1).mean()

    loss.backward()
    self.optimizer2.step()
    return loss
def _train_epoch(self, epoch):
    """Train the score network for one pass over ``self.train_data_generator``.

    Every batch runs the SDN (without gradients), scores its exits with
    ``self.score_net`` and minimizes the forward or backward KL loss
    (``self.args.kl_type``) with ``self.optimizer``. On epochs 20, 50, 70 and
    99 it additionally prints the policy "classification" accuracy on the
    first training batch.

    :param epoch: current epoch index (used for logging and the accuracy report).
    :return: dict with ``'epo'`` and ``'train loss'``, the mean loss over the
        batches of this epoch (0.0 if the loader produced no batches).
    """
    # n outputs, n-1 nets
    self.score_net.train()
    total_loss = 0.0
    num_batches = 0
    for i, (x, y) in enumerate(self.train_data_generator):
        x = x.to(self.device)
        y = y.to(self.device)
        # The SDN is not being trained (policy_training builds the optimizer
        # from score_net parameters only), so don't build a graph through it
        # or accumulate gradients into its parameters. Gradients w.r.t. the
        # score net are unaffected because xhs does not depend on them.
        with torch.no_grad():
            xhs = self.sdn_model(x)
        # Placeholder: the score net is not given real internal feature maps.
        internal_fm = torch.rand(2, 2)
        scores = self.score_net(x, internal_fm, xhs)
        self.optimizer.zero_grad()
        # loss
        if self.args.kl_type == 'forward':
            loss, _ = self.forward_kl_loss(y, xhs, scores, p_det=False)
        else:
            assert self.args.kl_type == 'backward'
            loss, _ = self.backward_kl_loss(y, xhs, scores)
        # backward
        loss.backward()
        self.optimizer.step()
        if i % self.args.iters_per_eval == 0:
            print('Epoch: {}, Step: {}, Loss: {}'.format(epoch, i, loss))
        total_loss += loss.item()
        num_batches += 1

    if epoch in (20, 50, 70, 99):
        self.score_net.eval()
        x, y = next(iter(self.train_data_generator))
        x = x.to(self.device)
        y = y.to(self.device)
        with torch.no_grad():
            xhs = self.sdn_model(x)
            internal_fm = torch.rand(2, 2)
            scores = self.score_net(x, internal_fm, xhs)
            q_p = self.q_posterior(self.args.policy_type, scores,
                                   stochastic=True, device=self.device)
            p_true, _ = self.true_posterior(self.args, xhs, y)
            # Compare in the nz_post index space the policy outputs live in
            # (same restriction as forward_kl_loss).
            p_nz = torch.stack([p_true[:, t] for t in self.nz_post.values()],
                               dim=1)
            p_true_b = max_onehot(p_nz, dim=-1, device=self.device)
            p_true_idx = torch.argmax(p_true_b, dim=-1)
        print('Here is the policy classification training accuracy:')
        print(data.accuracy(q_p, p_true_idx))

    # Average over the number of batches (the last 0-based index is one less).
    log = {'epo': epoch, 'train loss': total_loss / max(num_batches, 1)}
    return log
def _collect_validation_outputs(sdn_model, val_loader, device):
    """Run the SDN over ``val_loader`` and gather what the policy setup needs.

    :return: ``(x, categories, xhs_all, y_all, p_true)``.
        - ``x``: last input batch (on ``device``).
        - ``categories``: size of the last dim of the final exit output.
        - ``xhs_all``: list of per-batch exit outputs.
        - ``y_all``: list of per-batch targets.
        - ``p_true``: the true posteriors of all batches concatenated along dim 0.
    """
    xhs_all, y_all, p_true_all = [], [], []
    x, categories = None, None
    for x, y in val_loader:
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            xhs = sdn_model(x)
        categories = xhs[-1].shape[-1]
        p_true, _ = PolicyKL.true_posterior(cmd_args, xhs, y)
        xhs_all.append(xhs)
        y_all.append(y)
        p_true_all.append(p_true)
    return x, categories, xhs_all, y_all, torch.cat(p_true_all, dim=0)


def _select_nonzero_exits(p_true_mean, num_output, min_mass=0.001):
    """Map consecutive policy slots to exits whose mean posterior exceeds ``min_mass``.

    :return: ``(train_post, nz_post)``. ``nz_post`` maps slot -> exit index
        for every selected exit; ``train_post`` is the same without the last
        selected exit.
    :raises ValueError: if no exit exceeds ``min_mass``.
    """
    selected = [t for t in range(num_output) if p_true_mean[t] > min_mass]
    if not selected:
        raise ValueError(
            'no exit has mean true posterior above {}'.format(min_mass))
    nz_post = dict(enumerate(selected))
    train_post = dict(enumerate(selected[:-1]))
    return train_post, nz_post


def _report_confidence_baselines(xhs_all, y_all):
    """Print per-exit accuracy and a max-confidence early-exit baseline.

    For every threshold, each sample exits at the first classifier whose top
    softmax probability exceeds it (falling back to the last classifier).
    The threshold -1 instead picks the single most confident classifier.
    """
    y_all = torch.cat(y_all, dim=-1)
    # batches of per-exit outputs -> one concatenated tensor per exit
    logits = [torch.cat(per_exit, dim=0) for per_exit in zip(*xhs_all)]
    for i, exit_logits in enumerate(logits):
        print('The {}th classifier performance:'.format(i))
        prec1, prec5 = data.accuracy(exit_logits, y_all, topk=(1, 5))
        print('Top1 Test accuracy: {}'.format(prec1))
        print('Top5 Test accuracy: {}'.format(prec5))

    probs = [F.softmax(z, dim=-1) for z in logits]
    # one column per exit: top softmax probability of each sample
    max_confidences = torch.stack([torch.max(p, dim=-1)[0] for p in probs],
                                  dim=-1)
    probs_stack = torch.stack(probs, dim=1)  # (samples, exits, ...)
    rows = torch.arange(probs_stack.shape[0], device=probs_stack.device)
    thresholds = [
        0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.999,
        -1
    ]
    for threshold in thresholds:
        if threshold == -1:
            index = torch.argmax(max_confidences, dim=-1).cpu().numpy()
        else:
            mask = (max_confidences > threshold).to(int).cpu().numpy()
            mask[:, -1] = 1  # the final exit always accepts
            # np.argmax returns the first maximum -> first confident exit
            index = np.argmax(mask, axis=1)
        exit_idx = torch.as_tensor(index, dtype=torch.long,
                                   device=probs_stack.device)
        # Pick each sample's output at its chosen exit; works for any
        # number of classes (no hard-coded class count).
        results = probs_stack[rows, exit_idx]
        prec1, prec5 = data.accuracy(results, y_all, topk=(1, 5))
        print('htreshold: ', threshold)
        print('Top1 Test accuracy: {}'.format(prec1))
        print('Top5 Test accuracy: {}'.format(prec5))


def _build_score_net(x, internal_fm, train_post, categories):
    """Instantiate the score network selected by ``cmd_args.model_type``.

    :raises ValueError: for an unsupported ``cmd_args.model_type``.
    """
    model_type = cmd_args.model_type
    if model_type == 'multiclass':
        return MulticlassNetImage(cmd_args, x, internal_fm, train_post,
                                  category=categories)
    if model_type in ('sequential', 'confidence'):
        # NOTE(review): 'sequential' reuses MNIconfidence; the original code
        # carried a "Sequential model to be implemented" note.
        return MNIconfidence(cmd_args, x, internal_fm, train_post,
                             category=categories, share=cmd_args.share,
                             net_size=cmd_args.net_size)
    if model_type == 'imiconfidence':
        return Imiconfidence(cmd_args, x, internal_fm, train_post,
                             category=categories, share=cmd_args.share,
                             net_size=cmd_args.net_size)
    raise ValueError('unknown model_type: {}'.format(model_type))


def policy_training(models_path, device='cpu', task='tinyimagenet',
                    network='vgg16bn', sdn_variant='sdn_ic_only_ic1',
                    add_trigger=False):
    """Estimate exit posteriors, then train (optionally) and test the policy net.

    Steps:
      1. Load the SDN checkpoint ``'<task>_<network>_<sdn_variant>'``.
      2. Run it over the validation set, average the true exit posterior and
         keep the exits with non-negligible mass.
      3. Print per-exit accuracy and a confidence-threshold baseline.
      4. Build the score net chosen by ``cmd_args.model_type``, train it with
         ``PolicyKL`` when ``cmd_args.phase == 'train'``, reload the best
         checkpoint from ``cmd_args.save_dir`` and run ``PolicyKL.test``.

    Other configurations used previously:
      - tasks: ``'cifar10'``, ``'cifar100'``;
      - networks: ``'resnet56'``, ``'wideresnet32_4'``, ``'mobilenet'``;
      - variants: ``'sdn_ic_only'``, ``'sdn_ic_only_ic1_ds'``,
        ``'sdn_sdn_training'``, ``'sdn_sdn_training_ds'``,
        ``'sdn_sdn_training_ic14_ds'``.

    :param models_path: directory passed to ``arcs.load_model``.
    :param device: torch device for the SDN, the score net and all tensors.
    :param task: dataset/task part of the SDN checkpoint name.
    :param network: backbone part of the SDN checkpoint name.
    :param sdn_variant: suffix of the SDN checkpoint name.
    :param add_trigger: forwarded to ``af.get_dataset``. The original code
        noted True is for the backdoored network.
    """
    sdn_name = task + '_' + network + '_' + sdn_variant
    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'], add_trigger=add_trigger)
    sdn_model.eval()

    x, categories, xhs_all, y_all, p_true = _collect_validation_outputs(
        sdn_model, dataset.val_loader, device)
    p_det = max_onehot(p_true, dim=-1, device=device)
    p_true = torch.mean(p_true, dim=0)

    # find positions with nonzero posterior
    train_post, nz_post = _select_nonzero_exits(p_true, cmd_args.num_output)
    print('val p true:['
          + ','.join(['%0.3f' % p_true[t] for t in nz_post.values()]) + ']')
    p_det = torch.mean(p_det, dim=0)
    print('val p true det:['
          + ','.join(['%0.3f' % p_det[t] for t in nz_post.values()]) + ']')

    # check the performance based on confidence score
    _report_confidence_baselines(xhs_all, y_all)

    # Placeholder feature maps: the score nets are not fed real ones.
    internal_fm = [torch.rand(2, 2) for _ in range(cmd_args.num_output)]
    # initialize nets with nonzero posterior
    score_net = _build_score_net(x, internal_fm, train_post, categories)
    score_net.to(device)

    # train
    if cmd_args.phase == 'train':
        optimizer = optim.Adam(list(score_net.parameters()),
                               lr=cmd_args.learning_rate,
                               weight_decay=cmd_args.weight_decay)
        milestones = [10, 20, 40, 60, 80]
        gammas = [0.4, 0.2, 0.2, 0.2, 0.2]
        scheduler = MultiStepMultiLR(optimizer, milestones=milestones,
                                     gammas=gammas)
        trainer = PolicyKL(args=cmd_args, sdn_model=sdn_model,
                           score_net=score_net, train_post=train_post,
                           nz_post=nz_post, optimizer=optimizer,
                           data_loader=dataset, device=device,
                           scheduler=scheduler, sdn_name=sdn_name)
        trainer.train()

    # test
    dump = cmd_args.save_dir + '/{}_best_val_policy.dump'.format(sdn_name)
    print('Loading model...')
    # map_location lets a checkpoint saved on GPU be loaded on `device`.
    score_net.load_state_dict(torch.load(dump, map_location=device))
    PolicyKL.test(args=cmd_args, score_net=score_net, sdn_model=sdn_model,
                  data_loader=dataset.test_loader, nz_post=nz_post,
                  device=device)
    print(cmd_args.save_dir)