def finetunning(self, x_spt, y_spt, x_qry, y_qry):
    """
    :param x_spt: [setsz, c_, h, w]
    :param y_spt: [setsz]
    :param x_qry: [querysz, c_, h, w]
    :param y_qry: [querysz]
    :return:
    """
    assert len(x_spt.shape) == 4

    querysz = x_qry.size(0)
    corrects = [0 for _ in range(self.update_step_test + 1)]
    need_adv = True
    beta = 0
    tradesloss = 0
    optimizer = torch.optim.SGD(self.net.parameters(), lr=self.update_lr, momentum=0.9, weight_decay=5e-4)
    eps, step = (8, 10)
    corrects_adv = [0 for _ in range(self.update_step_test + 1)]
    corrects_adv_prior = [0 for _ in range(self.update_step_test + 1)]

    # to avoid corrupting the running_mean/variance and bn_weight/bias statistics,
    # we fine-tune on a copy of self.net instead of self.net itself
    net = deepcopy(self.net)

    # 1. run the i-th task and compute loss for k=0
    logits = net(x_spt)
    loss = F.cross_entropy(logits, y_spt)
    # tradesloss = self.trades_loss(net, net.parameters(), cifar, device=torch.device('cuda:0'), epsilon=eps)
    grad = torch.autograd.grad(loss + beta * tradesloss, net.parameters())
    fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())))

    # PGD AT: build adversarial query examples against the adapted weights
    if need_adv:
        at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step, DEVICE=self.device)
        # data = x_spt
        # label = y_spt
        # optimizer.zero_grad()
        # adv_inp = at.attack(self.net, self.net.parameters(), data, label)
        # logits = self.net(adv_inp, self.net.parameters(), bn_training=True)
        # loss = F.cross_entropy(logits, label)
        # grad = torch.autograd.grad(loss, self.net.parameters())
        # fast_weights_adv = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
        data = x_qry
        label = y_qry
        optimizer.zero_grad()
        adv_inp_adv = at.attack(net, fast_weights, data, label)

    # this is the loss and accuracy before first update
    with torch.no_grad():
        # [setsz, nway]
        logits_q = net(x_qry, net.parameters(), bn_training=True)
        # [setsz]
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        # find the correct index
        corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero()
        # scalar
        correct = torch.eq(pred_q, y_qry).sum().item()
        corrects[0] = corrects[0] + correct

    # PGD AT: robust accuracy of the unadapted weights
    if need_adv:
        data = x_qry
        label = y_qry
        optimizer.zero_grad()
        adv_inp = at.attack(net, net.parameters(), data, label)
        with torch.no_grad():
            logits_q_adv = net(adv_inp, net.parameters(), bn_training=True)
            pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
            correct_adv = torch.eq(pred_q_adv, label).sum().item()
            correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
            corrects_adv[0] = corrects_adv[0] + correct_adv
            # corrects_adv_prior[0] = corrects_adv_prior[0] + correct_adv_prior / len(corr_ind)

    # this is the loss and accuracy after the first update
    with torch.no_grad():
        # [setsz, nway]
        logits_q = net(x_qry, fast_weights, bn_training=True)
        # [setsz]
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        # find the correct index
        corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero()
        # scalar
        correct = torch.eq(pred_q, y_qry).sum().item()
        corrects[1] = corrects[1] + correct

    # PGD AT: robust accuracy after the first update
    if need_adv:
        logits_q_adv = net(adv_inp_adv, fast_weights, bn_training=True)
        pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
        correct_adv = torch.eq(pred_q_adv, label).sum().item()
        correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
        corrects_adv[1] = corrects_adv[1] + correct_adv
        # corrects_adv_prior[1] = corrects_adv_prior[1] + correct_adv_prior / len(corr_ind)

    for k in range(1, self.update_step_test):
        # 1. run the i-th task and compute loss for k=1~K-1
        logits = net(x_spt, fast_weights, bn_training=True)
        loss = F.cross_entropy(logits, y_spt)
        # tradesloss = self.trades_loss(net, fast_weights, x_spt, device=torch.device('cuda:0'), epsilon=eps)
        # 2. compute grad on theta_pi
        grad = torch.autograd.grad(loss + beta * tradesloss, fast_weights)
        # 3. theta_pi = theta_pi - train_lr * grad
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

        logits_q = net(x_qry, fast_weights, bn_training=True)
        # loss_q will be overwritten and just keep the loss_q on last update step.
        loss_q = F.cross_entropy(logits_q, y_qry)

        # PGD AT: attack the final adapted weights only
        if need_adv and k == self.update_step_test - 1:
            at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step, DEVICE=self.device)
            # data = x_spt
            # label = y_spt
            # optimizer.zero_grad()
            # adv_inp = at.attack(self.net, fast_weights_adv, data, label)
            # logits = self.net(adv_inp, fast_weights_adv, bn_training=True)
            # loss = F.cross_entropy(logits, label)
            # grad = torch.autograd.grad(loss, fast_weights_adv)
            # fast_weights_adv = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights_adv)))
            data = x_qry
            label = y_qry
            optimizer.zero_grad()
            adv_inp_adv = at.attack(net, fast_weights, data, label)
            logits_q_adv = net(adv_inp_adv, fast_weights, bn_training=True)
            loss_q_adv = F.cross_entropy(logits_q_adv, label)

        with torch.no_grad():
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # find the correct index
            corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero()
            correct = torch.eq(pred_q, y_qry).sum().item()  # convert to numpy
            corrects[k + 1] = corrects[k + 1] + correct

        # PGD AT
        if need_adv:
            pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
            correct_adv = torch.eq(pred_q_adv, label).sum().item()
            correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
            corrects_adv[k + 1] = corrects_adv[k + 1] + correct_adv
            # corrects_adv_prior[k + 1] = corrects_adv_prior[k + 1] + correct_adv_prior / len(corr_ind)

    del net

    accs = np.array(corrects) / querysz
    accs_adv = np.array(corrects_adv) / querysz
    accs_adv_prior = np.array(corrects_adv_prior)
    return accs, accs_adv, accs_adv_prior
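# The fine-tuning code above assumes a PGD attack helper whose attack(net, weights, x, y)
# call runs an L_inf projected-gradient attack through the functional forward
# net(x, weights, bn_training=True). The class below is only a minimal sketch of that
# assumed interface (the constructor arguments eps / sigma / nb_iter / DEVICE mirror the
# calls above); it is not the repository's actual PGD implementation, and it reuses the
# torch / F imports already present in this file.
class PGDSketch:
    def __init__(self, eps=8 / 255.0, sigma=2 / 255.0, nb_iter=10, DEVICE=None):
        self.eps = eps          # L_inf radius
        self.sigma = sigma      # per-step size
        self.nb_iter = nb_iter  # number of attack iterations
        self.device = DEVICE

    def attack(self, net, weights, x, y):
        # random start inside the eps-ball, then iterated signed-gradient steps with projection
        x_adv = (x.detach() + torch.empty_like(x).uniform_(-self.eps, self.eps)).clamp(0.0, 1.0)
        for _ in range(self.nb_iter):
            x_adv.requires_grad_(True)
            loss = F.cross_entropy(net(x_adv, weights, bn_training=True), y)
            grad = torch.autograd.grad(loss, x_adv, only_inputs=True)[0]
            with torch.no_grad():
                x_adv = x_adv + self.sigma * grad.sign()
                x_adv = torch.min(torch.max(x_adv, x - self.eps), x + self.eps).clamp(0.0, 1.0)
        return x_adv.detach()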
        checkpoint['model'] = {
            k: v
            for k, v in checkpoint['model'].items()
            if model.state_dict()[k].numel() == v.numel()
        }
        model.load_state_dict(checkpoint['model'], strict=True)
    elif load_name.endswith('.npy'):
        checkpoint = np.load(load_name, allow_pickle=True).item()
        model_dict = {
            k: torch.from_numpy(checkpoint[k])
            for k in checkpoint.keys()
            if model.state_dict()[k].numel() == torch.from_numpy(checkpoint[k]).numel()
        }
        model.load_state_dict(model_dict, strict=True)

    model.cuda().eval()
    model_adv = PGD(model)

    files = sorted(glob.glob('%s/*.jpg' % source))
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]

    for img_name_index in trange(len(files)):
        img_path = files[img_name_index]
        img_name = img_path.split('/')[-1]
        original_image = Image.open(img_path)
        im = np.asarray(original_image)
        img_rgb = im.copy()
        im2show = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        if len(im.shape) == 2:
            im = im[:, :, np.newaxis]
            im = np.concatenate((im, im, im), axis=2)
def finetunning(self, x_spt, y_spt, x_qry, y_qry, net):
    """
    :param x_spt: [setsz, c_, h, w]
    :param y_spt: [setsz]
    :param x_qry: [querysz, c_, h, w]
    :param y_qry: [querysz]
    :return:
    """
    assert len(x_spt.shape) == 4

    # build a narrower student network and copy the meta-learned parameters into it
    configtest = [('conv2d', [32, 3, 3, 3, 1, 0]),
                  ('relu', [True]),
                  ('bn', [32]),
                  ('max_pool2d', [2, 2, 0]),
                  ('conv2d', [32, 32, 3, 3, 1, 0]),
                  ('relu', [True]),
                  ('bn', [32]),
                  ('max_pool2d', [2, 2, 0]),
                  ('conv2d', [32, 32, 3, 3, 1, 0]),
                  ('relu', [True]),
                  ('bn', [32]),
                  ('max_pool2d', [2, 2, 0]),
                  ('conv2d', [32, 32, 3, 3, 1, 0]),
                  ('relu', [True]),
                  ('bn', [32]),
                  ('max_pool2d', [2, 1, 0]),
                  ('flatten', []),
                  ('linear', [5, 32 * 5 * 5])]
    studentnet = Learner(configtest, 3, 84)  # .to('cuda:3')
    for i in range(0, 16):
        studentnet.parameters()[i] = net.parameters()[i]
    studentnet.to('cuda:3')

    querysz = x_qry.size(0)
    corrects = [0 for _ in range(self.update_step_test + 1)]
    need_adv = True
    optimizer = torch.optim.SGD(studentnet.parameters(), lr=self.update_lr, momentum=0.9, weight_decay=5e-4)
    eps, step = (2, 10)
    corrects_adv = [0 for _ in range(self.update_step_test + 1)]
    corrects_adv_prior = [0 for _ in range(self.update_step_test + 1)]

    # to avoid corrupting the running_mean/variance and bn_weight/bias statistics,
    # we fine-tune on a copied model instead of self.net
    # net = deepcopy(self.net)

    # 1. run the i-th task and compute loss for k=0
    logits = studentnet(x_spt)
    loss = F.cross_entropy(logits, y_spt)
    grad = torch.autograd.grad(loss, studentnet.parameters())
    fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, studentnet.parameters())))

    # PGD AT: build adversarial query examples against the adapted weights
    if need_adv:
        at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)
        # data = x_spt
        # label = y_spt
        # optimizer.zero_grad()
        # adv_inp = at.attack(self.net, self.net.parameters(), data, label)
        # logits = self.net(adv_inp, self.net.parameters(), bn_training=True)
        # loss = F.cross_entropy(logits, label)
        # grad = torch.autograd.grad(loss, self.net.parameters())
        # fast_weights_adv = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
        data = x_qry
        label = y_qry
        optimizer.zero_grad()
        adv_inp_adv = at.attack(studentnet, fast_weights, data, label)

    # this is the loss and accuracy before first update
    with torch.no_grad():
        # [setsz, nway]
        logits_q = studentnet(x_qry, studentnet.parameters(), bn_training=True)
        # [setsz]
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        # find the correct index
        corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero()
        # scalar
        correct = torch.eq(pred_q, y_qry).sum().item()
        corrects[0] = corrects[0] + correct

    # PGD AT: robust accuracy of the unadapted weights
    if need_adv:
        data = x_qry
        label = y_qry
        optimizer.zero_grad()
        adv_inp = at.attack(studentnet, studentnet.parameters(), data, label)
        with torch.no_grad():
            logits_q_adv = studentnet(adv_inp, studentnet.parameters(), bn_training=True)
            pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
            correct_adv = torch.eq(pred_q_adv, label).sum().item()
            correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
            corrects_adv[0] = corrects_adv[0] + correct_adv
            corrects_adv_prior[0] = corrects_adv_prior[0] + correct_adv_prior / len(corr_ind)

    # this is the loss and accuracy after the first update
    with torch.no_grad():
        # [setsz, nway]
        logits_q = studentnet(x_qry, fast_weights, bn_training=True)
        # [setsz]
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        # find the correct index
        corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero()
        # scalar
        correct = torch.eq(pred_q, y_qry).sum().item()
        corrects[1] = corrects[1] + correct

    # PGD AT: robust accuracy after the first update
    if need_adv:
        logits_q_adv = studentnet(adv_inp_adv, fast_weights, bn_training=True)
        pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
        correct_adv = torch.eq(pred_q_adv, label).sum().item()
        correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
        corrects_adv[1] = corrects_adv[1] + correct_adv
        corrects_adv_prior[1] = corrects_adv_prior[1] + correct_adv_prior / len(corr_ind)

    for k in range(1, self.update_step_test):
        # 1. run the i-th task and compute loss for k=1~K-1
        logits = studentnet(x_spt, fast_weights, bn_training=True)
        loss = F.cross_entropy(logits, y_spt)
        # 2. compute grad on theta_pi
        grad = torch.autograd.grad(loss, fast_weights)
        # 3. theta_pi = theta_pi - train_lr * grad
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

        logits_q = studentnet(x_qry, fast_weights, bn_training=True)
        # loss_q will be overwritten and just keep the loss_q on last update step.
        loss_q = F.cross_entropy(logits_q, y_qry)

        # PGD AT
        if need_adv:
            at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)
            # data = x_spt
            # label = y_spt
            # optimizer.zero_grad()
            # adv_inp = at.attack(self.net, fast_weights_adv, data, label)
            # logits = self.net(adv_inp, fast_weights_adv, bn_training=True)
            # loss = F.cross_entropy(logits, label)
            # grad = torch.autograd.grad(loss, fast_weights_adv)
            # fast_weights_adv = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights_adv)))
            data = x_qry
            label = y_qry
            optimizer.zero_grad()
            adv_inp_adv = at.attack(studentnet, fast_weights, data, label)
            logits_q_adv = studentnet(adv_inp_adv, fast_weights, bn_training=True)
            loss_q_adv = F.cross_entropy(logits_q_adv, label)

        with torch.no_grad():
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # find the correct index
            corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero()
            correct = torch.eq(pred_q, y_qry).sum().item()  # convert to numpy
            corrects[k + 1] = corrects[k + 1] + correct

        # PGD AT
        if need_adv:
            pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
            correct_adv = torch.eq(pred_q_adv, label).sum().item()
            correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
            corrects_adv[k + 1] = corrects_adv[k + 1] + correct_adv
            corrects_adv_prior[k + 1] = corrects_adv_prior[k + 1] + correct_adv_prior / len(corr_ind)

    del studentnet

    accs = np.array(corrects) / querysz
    accs_adv = np.array(corrects_adv) / querysz
    accs_adv_prior = np.array(corrects_adv_prior)
    return accs, accs_adv, accs_adv_prior
def finetunning(self, x_spt, y_spt, x_qry, y_qry):
    """
    :param x_spt: [setsz, c_, h, w]
    :param y_spt: [setsz]
    :param x_qry: [querysz, c_, h, w]
    :param y_qry: [querysz]
    :return:
    """
    assert len(x_spt.shape) == 4

    querysz = x_qry.size(0)
    corrects = [0 for _ in range(self.update_step_test + 1)]
    need_adv = False
    optimizer = torch.optim.SGD(self.net.parameters(), lr=self.update_lr, momentum=0.9, weight_decay=5e-4)
    eps, step = (2.0, 10)
    corrects_adv = [0 for _ in range(self.update_step_test + 1)]
    corrects_adv_prior = [0 for _ in range(self.update_step_test + 1)]

    # to avoid corrupting the running_mean/variance and bn_weight/bias statistics,
    # we fine-tune on a copy of self.net instead of self.net itself
    net = deepcopy(self.net)

    # 1. run the i-th task and compute loss for k=0
    logits = net(x_spt)
    loss = F.cross_entropy(logits, y_spt)
    grad = torch.autograd.grad(loss, net.parameters())
    fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())))

    # FGSM-style adversarial fine-tuning on the support set (single step with random start)
    if need_adv:
        data = x_spt
        label = y_spt
        net.eval()
        data.requires_grad = True
        global_noise_data = torch.zeros(list(data.size())).cuda()
        global_noise_data.uniform_(-eps / 255.0, eps / 255.0)
        logits = net(data, fast_weights, bn_training=True)
        loss = F.cross_entropy(logits, label)
        grad_sign = torch.autograd.grad(loss, data, only_inputs=True, retain_graph=False)[0].sign()
        adv_inp = data + 1.25 * eps / 255.0 * grad_sign
        adv_inp.clamp_(0, 1.0)
        net.train()
        logits = net(adv_inp, fast_weights, bn_training=True)
        loss = F.cross_entropy(logits, label)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(map(lambda p: p[1] - self.lr_adv * p[0], zip(grad, fast_weights)))

    # PGD attack on the query set against the adapted weights
    if True:
        at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)
        data = x_qry
        label = y_qry
        optimizer.zero_grad()
        adv_inp_adv = at.attack(net, fast_weights, data, label)

    # this is the loss and accuracy before first update
    with torch.no_grad():
        # [setsz, nway]
        logits_q = net(x_qry, net.parameters(), bn_training=True)
        # [setsz]
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        # find the correct index
        corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero()
        # scalar
        correct = torch.eq(pred_q, y_qry).sum().item()
        corrects[0] = corrects[0] + correct

    # PGD AT: robust accuracy of the unadapted weights
    if True:
        data = x_qry
        label = y_qry
        optimizer.zero_grad()
        adv_inp = at.attack(net, net.parameters(), data, label)
        with torch.no_grad():
            logits_q_adv = net(adv_inp, net.parameters(), bn_training=True)
            pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
            correct_adv = torch.eq(pred_q_adv, label).sum().item()
            correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
            corrects_adv[0] = corrects_adv[0] + correct_adv
            corrects_adv_prior[0] = corrects_adv_prior[0] + correct_adv_prior / len(corr_ind)

    # this is the loss and accuracy after the first update
    with torch.no_grad():
        # [setsz, nway]
        logits_q = net(x_qry, fast_weights, bn_training=True)
        # [setsz]
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        # find the correct index
        corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero()
        # scalar
        correct = torch.eq(pred_q, y_qry).sum().item()
        corrects[1] = corrects[1] + correct

    # PGD AT: robust accuracy after the first update
    if True:
        logits_q_adv = net(adv_inp_adv, fast_weights, bn_training=True)
        pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
        correct_adv = torch.eq(pred_q_adv, label).sum().item()
        correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
        corrects_adv[1] = corrects_adv[1] + correct_adv
        corrects_adv_prior[1] = corrects_adv_prior[1] + correct_adv_prior / len(corr_ind)

    for k in range(1, self.update_step_test):
        # 1. run the i-th task and compute loss for k=1~K-1
        logits = net(x_spt, fast_weights, bn_training=True)
        loss = F.cross_entropy(logits, y_spt)
        # 2. compute grad on theta_pi
        grad = torch.autograd.grad(loss, fast_weights)
        # 3. theta_pi = theta_pi - train_lr * grad
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

        logits_q = net(x_qry, fast_weights, bn_training=True)
        # loss_q will be overwritten and just keep the loss_q on last update step.
        loss_q = F.cross_entropy(logits_q, y_qry)

        # FGSM-style adversarial fine-tuning on the support set (single step with random start)
        if need_adv:
            data = x_spt
            label = y_spt
            net.eval()
            data.requires_grad = True
            global_noise_data = torch.zeros(list(data.size())).cuda()
            global_noise_data.uniform_(-eps / 255.0, eps / 255.0)
            logits = net(data, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, label)
            grad_sign = torch.autograd.grad(loss, data, only_inputs=True, retain_graph=False)[0].sign()
            adv_inp = data + 1.25 * eps / 255.0 * grad_sign
            adv_inp.clamp_(0, 1.0)
            net.train()
            logits = net(adv_inp, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, label)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = list(map(lambda p: p[1] - self.lr_adv * p[0], zip(grad, fast_weights)))

        # PGD attack on the query set against the current fast weights
        if True:
            at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)
            data = x_qry
            label = y_qry
            optimizer.zero_grad()
            adv_inp_adv = at.attack(net, fast_weights, data, label)
            logits_q_adv = net(adv_inp_adv, fast_weights, bn_training=True)
            loss_q_adv = F.cross_entropy(logits_q_adv, label)

        with torch.no_grad():
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # find the correct index
            corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero()
            correct = torch.eq(pred_q, y_qry).sum().item()  # convert to numpy
            corrects[k + 1] = corrects[k + 1] + correct

        # PGD AT
        if True:
            pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
            correct_adv = torch.eq(pred_q_adv, label).sum().item()
            correct_adv_prior = torch.eq(pred_q_adv[corr_ind], label[corr_ind]).sum().item()
            corrects_adv[k + 1] = corrects_adv[k + 1] + correct_adv
            corrects_adv_prior[k + 1] = corrects_adv_prior[k + 1] + correct_adv_prior / len(corr_ind)

    del net

    accs = np.array(corrects) / querysz
    accs_adv = np.array(corrects_adv) / querysz
    accs_adv_prior = np.array(corrects_adv_prior)
    return accs, accs_adv, accs_adv_prior
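# A hedged usage sketch for the finetunning() variants above: evaluate on a set of test
# tasks and average the per-step clean and adversarial accuracies. The names db_test,
# maml and device are illustrative assumptions, not defined by the code above; each task
# is passed without a leading task dimension.
accs_all, accs_adv_all = [], []
for x_spt, y_spt, x_qry, y_qry in db_test:
    x_spt, y_spt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device)
    x_qry, y_qry = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
    accs, accs_adv, _ = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
    accs_all.append(accs)
    accs_adv_all.append(accs_adv)
print('clean accuracy per update step:', np.array(accs_all).mean(axis=0))
print('robust accuracy per update step:', np.array(accs_adv_all).mean(axis=0))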
def train():
    cfg = args.cfg
    data = args.data
    weights = args.weights
    epochs = args.epochs
    batch_size = args.bs
    resume = args.resume
    # adversarial-training options
    adv = args.adv
    imgs_weight = args.iw
    num_steps = args.num_steps
    step_size = args.step_size
    kdfa = args.kdfa
    ssfa = args.ssfa
    beta = args.beta
    gamma = args.gamma
    kdfa_weights = args.kdfa_weights
    tod = args.tod
    kdfa_cfg = cfg

    img_size = 416
    rect = False
    multi_scale = False
    accumulate = 1
    scale_factor = 0.5
    num_workers = min([os.cpu_count(), batch_size, 16])

    path = 'weights/'
    if not os.path.exists(path):
        os.makedirs(path)
    wdir = path + os.sep  # weights dir
    last = wdir + 'last.pt'
    tb_writer = SummaryWriter()

    # Initialize
    init_seeds(seed=3)
    if multi_scale:
        img_sz_min = round(img_size / 32 / 1.5) + 1
        img_sz_max = round(img_size / 32 * 1.5) - 1
        img_size = img_sz_max * 32  # initiate with maximum multi_scale size
        print('Using multi-scale %g - %g' % (img_sz_min * 32, img_size))

    # Configure run
    data_dict = parse_data_cfg(data)
    train_path = data_dict['train']

    # Initialize model
    model = Darknet(cfg, arc='default').cuda().train()
    hyp = model.hyp

    # Optimizer
    pg0, pg1 = [], []  # optimizer parameter groups
    for k, v in dict(model.named_parameters()).items():
        if 'Conv2d.weight' in k:
            pg1 += [v]  # parameter group 1 (apply weight_decay)
        else:
            pg0 += [v]  # parameter group 0
    optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    del pg0, pg1

    start_epoch = 0
    if weights.endswith('.pt'):  # pytorch format
        # possible weights are 'last.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc.
        chkpt = torch.load(weights)
        # load model, keeping only tensors whose shapes match the current model
        chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()}
        model.load_state_dict(chkpt['model'], strict=True)
        if resume:
            if chkpt['optimizer'] is not None:
                optimizer.load_state_dict(chkpt['optimizer'])
            start_epoch = chkpt['epoch'] + 1
        del chkpt
    elif weights.endswith('.weights'):  # darknet format
        # possible weights are 'yolov3.weights', 'yolov3-tiny.conv.15', 'darknet53.conv.74' etc.
        print('inherit model weights')
        if 'yolo' in weights:
            model.load_darknet_weights(weights)
            print(' inherit model weights from yolo pretrained weights')
        else:
            load_darknet_weights(model, weights)
            print(' do not inherit model weights from yolo pretrained weights')

    if adv:
        model_adv = PGD(model)
    if kdfa:
        model_t = Darknet(kdfa_cfg, arc='default').cuda().eval()
        print('inherit kdfa_weights')
        if 'yolo' in kdfa_weights:
            model_t.load_darknet_weights(kdfa_weights)
            print(' inherit model weights from yolo pretrained weights')
        else:
            load_darknet_weights(model_t, kdfa_weights)
            print(' do not inherit model weights from yolo pretrained weights')
        for param_k in model_t.parameters():
            param_k.requires_grad = False

    # Dataset
    dataset = LoadImagesAndLabels(
        train_path,
        img_size,
        batch_size,
        augment=True,
        hyp=hyp,  # augmentation hyperparameters
        rect=rect,  # rectangular training
        image_weights=False,
        cache_labels=True if epochs > 10 else False,
        cache_images=False)

    # Dataloader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=not rect,  # Shuffle=True unless rectangular training is used
        pin_memory=True,
        collate_fn=dataset.collate_fn,
        drop_last=True)

    nb = len(dataloader)
    t0 = time.time()
    print('Starting %g for %g epochs...' % (start_epoch, epochs))

    for epoch in range(start_epoch, epochs + 1):  # epoch -----------------------------------------------------------
        if epoch == int(epochs * 2 / 3) + 1:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.1
                print('param_group lr annealing: ', param_group['lr'])

        print(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'clean', 'adv', 'kd', 'ss', 'total'))
        mloss = torch.zeros(5).cuda()  # mean losses: 'clean', 'adv', 'kd', 'ss', 'total'
        loss_ss = torch.zeros(1).cuda()
        loss_kd = torch.zeros(1).cuda()
        loss_clean = torch.zeros(1).cuda()
        loss_adv = torch.zeros(1).cuda()
        loss = torch.zeros(1).cuda()
        pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
        for i, (imgs, targets, paths, _) in pbar:  # batch ----------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.cuda()
            targets = targets.cuda()

            # Multi-Scale training
            if multi_scale:
                if ni / accumulate % 10 == 0:  # adjust (67% - 150%) every 10 batches
                    img_size = random.randrange(img_sz_min, img_sz_max + 1) * 32
                sf = img_size / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / 32.) * 32 for x in imgs.shape[2:]]  # new shape (stretched to 32-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            if adv:
                if tod:
                    imgs_adv, loss_clean = model_adv.adv_sample_train_tod(imgs, targets, step_size=step_size,
                                                                          num_steps=num_steps, all_bp=True,
                                                                          sf=imgs_weight * scale_factor)
                    pred = model(imgs_adv, fa=False)
                    loss_adv, loss_items = compute_loss(pred, targets, model)
                    loss_adv *= (1 - imgs_weight)
                else:
                    imgs_adv, ssfa_out, loss_clean = model_adv.adv_sample_train(imgs, targets, step_size=step_size,
                                                                                all_bp=True,
                                                                                sf=imgs_weight * scale_factor,
                                                                                num_steps=num_steps)
                    pred, fa_out = model(imgs_adv, fa=True)
                    fa_out_norm = F.normalize(fa_out, dim=1)
                    loss_adv, loss_items = compute_loss(pred, targets, model)
                    loss_adv *= (1 - imgs_weight)
                    if kdfa:
                        # knowledge-distillation feature alignment against the frozen teacher
                        kdfa_out = model_t(imgs, fa=True, only_fa=True)
                        kdfa_out_norm = F.normalize(kdfa_out, dim=1)
                        kd_sim = torch.einsum('nc,nc->n', [fa_out_norm, kdfa_out_norm])
                        kd_sim.data.clamp_(-1., 1.)
                        loss_kd = (1. - kd_sim).mean().view(-1) * beta
                    if ssfa:
                        # self-supervised feature alignment between clean and adversarial features
                        ssfa_out_norm = F.normalize(ssfa_out, dim=1)
                        ss_sim = torch.einsum('nc,nc->n', [fa_out_norm, ssfa_out_norm])
                        ss_sim.data.clamp_(-1., 1.)
                        loss_ss = (1 - ss_sim).mean().view(-1) * gamma
            else:
                pred = model(imgs, fa=False)
                loss_adv, loss_items = compute_loss(pred, targets, model)

            loss_kd *= scale_factor
            loss_ss *= scale_factor
            loss_adv *= scale_factor
            loss_items = torch.cat((loss_clean, loss_adv, loss_kd, loss_ss,
                                    (loss_clean + loss_adv + loss_kd + loss_ss))).detach()
            loss = loss_adv + loss_kd + loss_ss
            if not torch.isfinite(loss):
                print('WARNING: non-finite loss, ending training ', loss_items)
                break

            loss.backward()

            # Accumulate gradient for x batches before optimizing
            if ni % accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Print batch results
            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0  # (GB)
            script = ('%10s' * 2 + '%10.3g' * 5) % ('%g/%g' % (epoch, epochs), '%.3gG' % mem, *mloss)
            pbar.set_description(script)
            # end batch --------------------------------------------------------------------------------------------

        # Write Tensorboard results
        x = list(mloss.cpu().numpy())  # + list(results) + list([thre]) + list([prune_ratio]) + list([par_prune_ratio])
        titles = ['Loss_clean', 'Loss_adv', 'Loss_kd', 'Loss_ss', 'Train_loss']
        for xi, title in zip(x, titles):
            tb_writer.add_scalar(title, xi, epoch)

        chkpt = {'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
        torch.save(chkpt, last)
        if epoch > 0 and (epoch) % 5 == 0:
            torch.save(chkpt, wdir + 'backup%g.pt' % epoch)
            if epoch == epochs:
                convert(cfg=cfg, weights=wdir + 'backup%g.pt' % epoch)
        del chkpt

        time_consume = '%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)
        print(time_consume)
        # end epoch ------------------------------------------------------------------------------------------------

    torch.cuda.empty_cache()
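# train() above reads its configuration from a module-level `args` object. Below is a
# minimal argparse sketch consistent with the attributes it accesses (cfg, data, weights,
# epochs, bs, resume, adv, iw, num_steps, step_size, kdfa, ssfa, beta, gamma,
# kdfa_weights, tod); the default values are illustrative assumptions only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg')
parser.add_argument('--data', type=str, default='data/voc.data')
parser.add_argument('--weights', type=str, default='weights/darknet53.conv.74')
parser.add_argument('--epochs', type=int, default=30)
parser.add_argument('--bs', type=int, default=16)
parser.add_argument('--resume', action='store_true')
parser.add_argument('--adv', action='store_true', help='enable adversarial training')
parser.add_argument('--iw', type=float, default=0.5, help='clean-image loss weight')
parser.add_argument('--num_steps', type=int, default=3)
parser.add_argument('--step_size', type=float, default=0.01)
parser.add_argument('--kdfa', action='store_true', help='knowledge-distillation feature alignment')
parser.add_argument('--ssfa', action='store_true', help='self-supervised feature alignment')
parser.add_argument('--beta', type=float, default=1.0)
parser.add_argument('--gamma', type=float, default=1.0)
parser.add_argument('--kdfa_weights', type=str, default='')
parser.add_argument('--tod', action='store_true')
args = parser.parse_args()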
def forward(self, x_spt, y_spt, x_qry, y_qry, x_nat):
    """
    :param x_spt: [b, setsz, c_, h, w]
    :param y_spt: [b, setsz]
    :param x_qry: [b, querysz, c_, h, w]
    :param y_qry: [b, querysz]
    :return:
    """
    task_num, setsz, c_, h, w = x_spt.size()
    querysz = x_qry.size(1)

    losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i
    corrects = [0 for _ in range(self.update_step + 1)]
    need_adv = False
    beta = 2.5

    # AT
    optimizer = torch.optim.SGD(self.net.parameters(), lr=self.update_lr, momentum=0.9, weight_decay=5e-4)
    optimizertrade = torch.optim.SGD(self.net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    eps, step = (8.0, 10)
    losses_q_adv = [0 for _ in range(self.update_step + 1)]
    corrects_adv = [0 for _ in range(self.update_step + 1)]

    for i in range(task_num):
        x_q = x_qry[i].view(-1, 3, 32, 32)
        x_s = x_spt[i].view(-1, 3, 32, 32)
        if x_nat is not None:
            x_unlab = x_nat[i].view(-1, 3, 32, 32)

        # 1. run the i-th task and compute loss for k=0
        logits = self.net(x_spt[i], vars=None, bn_training=True)
        loss = F.cross_entropy(logits, y_spt[i])
        grad = torch.autograd.grad(loss, self.net.parameters())
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

        # PGD AT
        if need_adv:
            at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step, DEVICE=self.device)
            # data = x_spt[i]
            # label = y_spt[i]
            # optimizer.zero_grad()
            # adv_inp = at.attack(self.net, self.net.parameters(), data, label)
            # logits = self.net(adv_inp, self.net.parameters(), bn_training=True)
            # loss = F.cross_entropy(logits, label)
            # grad = torch.autograd.grad(loss, self.net.parameters())
            # fast_weights_adv = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
            # # print(fast_weights_adv - self.net.parameters())
            data = x_qry[i]
            label = y_qry[i]
            optimizer.zero_grad()
            adv_inp_adv = at.attack(self.net, fast_weights, data, label)
            optimizer.zero_grad()
            self.net.train()

        # this is the loss and accuracy before first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
            loss_q = F.cross_entropy(logits_q, y_qry[i])
            # tradesloss = self.trades_loss(self.net, self.net.parameters(), optimizertrade, x_nat, device=self.device, epsilon=eps)
            losses_q[0] += loss_q  # + beta * tradesloss
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry[i]).sum().item()
            corrects[0] = corrects[0] + correct

        # PGD AT
        if need_adv:
            data = x_qry[i]
            label = y_qry[i]
            optimizer.zero_grad()
            adv_inp = at.attack(self.net, self.net.parameters(), data, label)
            optimizer.zero_grad()
            self.net.train()
            with torch.no_grad():
                logits_q_adv = self.net(adv_inp, self.net.parameters(), bn_training=True)
                loss_q_adv = F.cross_entropy(logits_q_adv, label)
                losses_q_adv[0] += loss_q_adv
                pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
                correct_adv = torch.eq(pred_q_adv, label).sum().item()
                corrects_adv[0] = corrects_adv[0] + correct_adv

        # this is the loss and accuracy after the first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
            loss_q = F.cross_entropy(logits_q, y_qry[i])
            # tradesloss = self.trades_loss(self.net, fast_weights, optimizertrade, x_nat, device=self.device, epsilon=eps)
            losses_q[1] += loss_q  # + beta * tradesloss
            # [setsz]
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry[i]).sum().item()
            corrects[1] = corrects[1] + correct

        # PGD AT
        if need_adv:
            logits_q_adv = self.net(adv_inp_adv, fast_weights, bn_training=True)
            loss_q_adv = F.cross_entropy(logits_q_adv, label)
            losses_q_adv[1] += loss_q_adv
            pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
            correct_adv = torch.eq(pred_q_adv, label).sum().item()
            corrects_adv[1] = corrects_adv[1] + correct_adv

        for k in range(1, self.update_step):
            # 1. run the i-th task and compute loss for k=1~K-1
            logits = self.net(x_spt[i], fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, y_spt[i])
            # 2. compute grad on theta_pi
            grad = torch.autograd.grad(loss, fast_weights)
            # 3. theta_pi = theta_pi - train_lr * grad
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

            logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
            # loss_q will be overwritten and just keep the loss_q on last update step.
            loss_q = F.cross_entropy(logits_q, y_qry[i])

            if k == self.update_step - 1:
                # TRADES regularisation on the last inner step: build x_natt from support,
                # query and (optional) unlabeled natural images
                if x_nat is None:
                    x_natt = x_q
                else:
                    x_natt = torch.cat((x_q, x_unlab), 0)
                x_natt = torch.cat((x_s, x_natt), 0)
                criterion_kl = nn.KLDivLoss(size_average=False)
                self.net.eval()
                # global global_noise_data
                global_noise_data = torch.zeros(list(x_natt.size())).to(self.device)
                global_noise_data.uniform_(-eps / 255.0, eps / 255.0)
                noise_batch = Variable(global_noise_data[0:x_natt.size(0)], requires_grad=True).to(self.device)
                x_adv = x_natt + noise_batch
                x_adv.clamp_(0, 1.0)
                log1 = self.net(x_adv, fast_weights)
                log2 = self.net(x_natt, fast_weights)
                # log22 = F.softmax(log2, dim=1).argmax(dim=1)
                # loss_kl = F.cross_entropy(log1, log22)
                loss_kl = criterion_kl(F.log_softmax(log1, dim=1), F.softmax(log2, dim=1))
                loss_kl.backward()
                # grad = torch.autograd.grad(loss_kl, [noise_batch])[0]
                global_noise_data = global_noise_data + 1.25 * eps / 255.0 * torch.sign(noise_batch.grad)
                global_noise_data.clamp_(-eps / 255.0, eps / 255.0)
                noise_batch = Variable(global_noise_data[0:x_natt.size(0)], requires_grad=False).to(self.device)
                x_adv = x_natt + noise_batch
                x_adv.clamp_(0, 1.0)
                self.net.train()
                # zero gradient
                optimizer.zero_grad()
                # calculate robust loss
                x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
                logits = self.net(x_natt, fast_weights)
                adv_logits = self.net(x_adv, fast_weights)
                tradesloss = (1.0 / x_natt.size(0)) * criterion_kl(F.log_softmax(adv_logits, dim=1),
                                                                   F.softmax(logits, dim=1))
            else:
                tradesloss = 0
                # tradesloss = self.trades_loss(self.net, fast_weights, optimizertrade, x_nat, device=self.device, epsilon=eps)
            losses_q[k + 1] += loss_q + beta * tradesloss

            # PGD AT
            # if need_adv:
            #     at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)
            #     # data = x_spt[i]
            #     # label = y_spt[i]
            #     # optimizer.zero_grad()
            #     # adv_inp = at.attack(self.net, fast_weights_adv, data, label)
            #     # logits = self.net(adv_inp, fast_weights_adv, bn_training=True)
            #     # loss = F.cross_entropy(logits, label)
            #     # grad = torch.autograd.grad(loss, fast_weights_adv)
            #     # fast_weights_adv = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights_adv)))
            #     data = x_qry[i]
            #     label = y_qry[i]
            #     optimizer.zero_grad()
            #     adv_inp_adv = at.attack(self.net, fast_weights, data, label)
            #     optimizer.zero_grad()
            #     logits_q_adv = self.net(adv_inp_adv, fast_weights, bn_training=True)
            #     loss_q_adv = F.cross_entropy(logits_q_adv, label)
            #     losses_q_adv[k + 1] += loss_q_adv

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()  # convert to numpy
                corrects[k + 1] = corrects[k + 1] + correct

            # PGD AT
            if need_adv:
                pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
                correct_adv = torch.eq(pred_q_adv, label).sum().item()
                corrects_adv[k + 1] = corrects_adv[k + 1] + correct_adv

    # end of all tasks
    # sum over all losses on query set across all tasks
    loss_q = losses_q[-1] / task_num
    loss_q_adv = losses_q_adv[-1] / task_num

    # optimize theta parameters
    self.meta_optim.zero_grad()
    loss_q.backward()
    # print('meta update')
    # for p in self.net.parameters()[:5]:
    #     print(torch.norm(p).item())
    self.meta_optim.step()

    # self.meta_optim.zero_grad()
    # loss_q_adv.backward()
    # self.meta_optim.step()

    accs = np.array(corrects) / (querysz * task_num)
    accs_adv = np.array(corrects_adv) / (querysz * task_num)
    return accs, accs_adv
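# The commented-out self.trades_loss(...) calls in forward() suggest a standalone TRADES
# regulariser. The helper below is a minimal sketch consistent with the inline computation
# above (one FGSM-style step on the KL divergence from a random start, then the KL robust
# term); its name and signature are assumptions, not the original API, and it uses
# torch.autograd.grad on the noise instead of backward() so the model gradients stay untouched.
def trades_loss_sketch(net, weights, x_nat, eps=8.0):
    criterion_kl = nn.KLDivLoss(reduction='sum')
    net.eval()
    noise = torch.zeros_like(x_nat).uniform_(-eps / 255.0, eps / 255.0).requires_grad_(True)
    x_adv = (x_nat + noise).clamp(0.0, 1.0)
    loss_kl = criterion_kl(F.log_softmax(net(x_adv, weights), dim=1),
                           F.softmax(net(x_nat, weights), dim=1))
    grad_noise = torch.autograd.grad(loss_kl, [noise], only_inputs=True)[0]
    noise = (noise.detach() + 1.25 * eps / 255.0 * grad_noise.sign()).clamp(-eps / 255.0, eps / 255.0)
    x_adv = (x_nat + noise).clamp(0.0, 1.0).detach()
    net.train()
    logits = net(x_nat, weights)
    adv_logits = net(x_adv, weights)
    return (1.0 / x_nat.size(0)) * criterion_kl(F.log_softmax(adv_logits, dim=1),
                                                F.softmax(logits, dim=1))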
def test_adv(
        cfg,
        data,
        weights=None,
        batch_size=4,
        step_size=0.01,
        num_steps=3,
        test_type='test',
        iou_thres=0.5,
        nms_thres=0.5,
        conf_thres=0.001,
        img_size=416,
):
    data = parse_data_cfg(data)
    nc = int(data['classes'])  # number of classes
    if test_type == 'valid':
        test_path = data['valid']  # path to test images
    elif test_type == 'test':
        test_path = data['test']
    print('test_path:', test_path)

    # Initialize model
    model = Darknet(cfg, img_size).cuda().eval()
    if weights.endswith('.pt'):  # pytorch format
        chkpt = torch.load(weights)
        chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()}
        model.load_state_dict(chkpt['model'], strict=True)
        del chkpt
    elif weights.endswith('.weights'):  # darknet format
        print('inherit model weights')
        if 'yolo' in weights:
            model.load_darknet_weights(weights)
            print(' inherit model weights from yolo pretrained weights')
        else:
            load_darknet_weights(model, weights)
            print(' do not inherit model weights from yolo pretrained weights')

    model_adv = PGD(model)

    dataset = LoadImagesAndLabels(
        test_path,
        img_size=img_size,
        batch_size=batch_size,
        augment=False,
        hyp=None,  # augmentation hyperparameters
        rect=True,  # rectangular inference
        image_weights=False,
        cache_labels=False,
        cache_images=False)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=min([os.cpu_count(), batch_size, 16]),
        shuffle=False,
        pin_memory=True,
        collate_fn=dataset.collate_fn)

    # Run inference
    seen = 0
    p, r, f1, mp, mr, map, mf1 = 0., 0., 0., 0., 0., 0., 0.
    loss = torch.zeros(3)
    jdict, stats, ap, ap_class = [], [], [], []
    s = ('%20s' + '%10s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP', 'F1')
    for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
        imgs = imgs.cuda()
        targets = targets.cuda()
        if num_steps * step_size > 0:
            input_img = model_adv.adv_sample_infer(imgs, targets, step_size=step_size, num_steps=num_steps)
        else:
            input_img = imgs
        _, _, height, width = imgs.shape  # batch size, channels, height, width

        # Run model
        inf_out, train_out = model(input_img)  # inference and training outputs
        if hasattr(model, 'hyp'):  # if model has loss hyperparameters
            loss += compute_loss(train_out, targets, model)[1][:3].cpu()  # GIoU, obj, cls

        # Run NMS
        output = non_max_suppression(inf_out, conf_thres=conf_thres, nms_thres=nms_thres)

        # Statistics per image
        for si, pred in enumerate(output):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            if pred is None:
                if nl:
                    stats.append(([], torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Clip boxes to image bounds
            clip_coords(pred, (height, width))

            # Assign all predictions as incorrect
            correct = [0] * len(pred)
            if nl:
                detected = []
                tcls_tensor = labels[:, 0]
                seen += 1

                # target boxes
                tbox = xywh2xyxy(labels[:, 1:5])
                tbox[:, [0, 2]] *= width
                tbox[:, [1, 3]] *= height

                # Search for correct predictions
                for i, (*pbox, pconf, pcls_conf, pcls) in enumerate(pred):
                    # Break if all targets already located in image
                    if len(detected) == nl:
                        break
                    # Continue if predicted class not among image classes
                    if pcls.item() not in tcls:
                        continue
                    # Best iou, index between pred and targets
                    m = (pcls == tcls_tensor).nonzero().view(-1)
                    iou, bi = bbox_iou(pbox, tbox[m]).max(0)
                    # If iou > threshold and class is correct mark as correct
                    if iou > iou_thres and m[bi] not in detected:  # and pcls == tcls[bi]
                        correct[i] = 1
                        detected.append(m[bi])

            # Append statistics (correct, conf, pcls, tcls)
            stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))

    # Compute statistics
    stats = [np.concatenate(x, 0) for x in list(zip(*stats))]  # to numpy
    if len(stats):
        p, r, ap, f1, ap_class = ap_per_class(*stats)
        mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)

    # Print results
    pf = '%20s' + '%10.3g' * 6  # print format
    print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1))
    return map
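# A hedged usage sketch for test_adv(): the cfg / data / weights paths are placeholders.
# With num_steps * step_size > 0 the images are perturbed by the PGD wrapper before
# inference, so passing num_steps=0 would give the clean mAP instead.
if __name__ == '__main__':
    robust_map = test_adv('cfg/yolov3.cfg', 'data/voc.data', weights='weights/last.pt',
                          batch_size=4, step_size=0.01, num_steps=3, test_type='test')
    print('mAP under the PGD attack: %.3f' % robust_map)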
    x_train = x_train.astype('float32') / 255.
    x_test = x_test.astype('float32') / 255.
    y_train = keras.utils.to_categorical(y_train, 10)
    y_test = keras.utils.to_categorical(y_test, 10)
    return x_train, x_test, y_train, y_test


path = "./mnist.npz"
x_train, x_test, y_train, y_test = load_mnist(path)

# load your model
model = keras.models.load_model("./Lenet5_mnist.h5")

fgsm = FGSM(model, ep=0.3, isRand=True)
pgd = PGD(model, ep=0.3, epochs=10, isRand=True)

# generate adversarial examples at once.
advs, labels, fols, ginis = fgsm.generate(x_train, y_train)
np.savez('./FGSM_TrainFull.npz', advs=advs, labels=labels, fols=fols, ginis=ginis)

advs, labels, fols, ginis = pgd.generate(x_train, y_train)
np.savez('./PGD_TrainFull.npz', advs=advs, labels=labels, fols=fols, ginis=ginis)

advs, labels, _, _ = fgsm.generate(x_test, y_test)
np.savez('./FGSM_Test.npz', advs=advs, labels=labels)

advs, labels, _, _ = pgd.generate(x_test, y_test)
np.savez('./PGD_Test.npz', advs=advs, labels=labels)
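# A hedged follow-up sketch: reload one of the .npz files saved above and measure accuracy
# on the stored adversarial examples (keys `advs` / `labels` as written by np.savez). This
# assumes the Keras model was compiled with an accuracy metric.
adv_data = np.load('./PGD_Test.npz')
advs, labels = adv_data['advs'], adv_data['labels']
loss, acc = model.evaluate(advs, labels, batch_size=128, verbose=0)
print('accuracy on the saved PGD test set: %.4f' % acc)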
def forward(self, x_spt, y_spt, x_qry, y_qry):
    """
    :param x_spt: [b, setsz, c_, h, w]
    :param y_spt: [b, setsz]
    :param x_qry: [b, querysz, c_, h, w]
    :param y_qry: [b, querysz]
    :return:
    """
    task_num, setsz, c_, h, w = x_spt.size()
    querysz = x_qry.size(1)

    losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i
    corrects = [0 for _ in range(self.update_step + 1)]
    need_adv = False

    # AT
    optimizer = torch.optim.SGD(self.net.parameters(), lr=self.update_lr, momentum=0.9, weight_decay=5e-4)
    eps, step = (4.0, 10)
    losses_q_adv = [0 for _ in range(self.update_step + 1)]
    corrects_adv = [0 for _ in range(self.update_step + 1)]

    for i in range(task_num):
        # 1. run the i-th task and compute loss for k=0
        logits = self.net(x_spt[i], vars=None, bn_training=True)
        loss = F.cross_entropy(logits, y_spt[i])
        grad = torch.autograd.grad(loss, self.net.parameters())
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

        # PGD AT
        if need_adv:
            at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)
            # data = x_spt[i]
            # label = y_spt[i]
            # optimizer.zero_grad()
            # adv_inp = at.attack(self.net, self.net.parameters(), data, label)
            # logits = self.net(adv_inp, self.net.parameters(), bn_training=True)
            # loss = F.cross_entropy(logits, label)
            # grad = torch.autograd.grad(loss, self.net.parameters())
            # fast_weights_adv = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
            # # print(fast_weights_adv - self.net.parameters())
            data = x_qry[i]
            label = y_qry[i]
            optimizer.zero_grad()
            adv_inp_adv = at.attack(self.net, fast_weights, data, label)
            optimizer.zero_grad()

        # this is the loss and accuracy before first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
            loss_q = F.cross_entropy(logits_q, y_qry[i])
            losses_q[0] += loss_q
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry[i]).sum().item()
            corrects[0] = corrects[0] + correct

        # PGD AT
        if need_adv:
            data = x_qry[i]
            label = y_qry[i]
            optimizer.zero_grad()
            adv_inp = at.attack(self.net, self.net.parameters(), data, label)
            optimizer.zero_grad()
            with torch.no_grad():
                logits_q_adv = self.net(adv_inp, self.net.parameters(), bn_training=True)
                loss_q_adv = F.cross_entropy(logits_q_adv, label)
                losses_q_adv[0] += loss_q_adv
                pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
                correct_adv = torch.eq(pred_q_adv, label).sum().item()
                corrects_adv[0] = corrects_adv[0] + correct_adv

        # this is the loss and accuracy after the first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
            loss_q = F.cross_entropy(logits_q, y_qry[i])
            losses_q[1] += loss_q
            # [setsz]
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry[i]).sum().item()
            corrects[1] = corrects[1] + correct

        # PGD AT
        if need_adv:
            logits_q_adv = self.net(adv_inp_adv, fast_weights, bn_training=True)
            loss_q_adv = F.cross_entropy(logits_q_adv, label)
            losses_q_adv[1] += loss_q_adv
            pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
            correct_adv = torch.eq(pred_q_adv, label).sum().item()
            corrects_adv[1] = corrects_adv[1] + correct_adv

        for k in range(1, self.update_step):
            # 1. run the i-th task and compute loss for k=1~K-1
            logits = self.net(x_spt[i], fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, y_spt[i])
            # 2. compute grad on theta_pi
            grad = torch.autograd.grad(loss, fast_weights)
            # 3. theta_pi = theta_pi - train_lr * grad
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

            logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
            # loss_q will be overwritten and just keep the loss_q on last update step.
            loss_q = F.cross_entropy(logits_q, y_qry[i])
            losses_q[k + 1] += loss_q

            # PGD AT
            if need_adv:
                at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)
                # data = x_spt[i]
                # label = y_spt[i]
                # optimizer.zero_grad()
                # adv_inp = at.attack(self.net, fast_weights_adv, data, label)
                # logits = self.net(adv_inp, fast_weights_adv, bn_training=True)
                # loss = F.cross_entropy(logits, label)
                # grad = torch.autograd.grad(loss, fast_weights_adv)
                # fast_weights_adv = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights_adv)))
                data = x_qry[i]
                label = y_qry[i]
                optimizer.zero_grad()
                adv_inp_adv = at.attack(self.net, fast_weights, data, label)
                optimizer.zero_grad()
                logits_q_adv = self.net(adv_inp_adv, fast_weights, bn_training=True)
                loss_q_adv = F.cross_entropy(logits_q_adv, label)
                losses_q_adv[k + 1] += loss_q_adv

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()  # convert to numpy
                corrects[k + 1] = corrects[k + 1] + correct

            # PGD AT
            if need_adv:
                pred_q_adv = F.softmax(logits_q_adv, dim=1).argmax(dim=1)
                correct_adv = torch.eq(pred_q_adv, label).sum().item()
                corrects_adv[k + 1] = corrects_adv[k + 1] + correct_adv

    # end of all tasks
    # sum over all losses on query set across all tasks
    loss_q = losses_q[-1] / task_num
    loss_q_adv = losses_q_adv[-1] / task_num

    # optimize theta parameters
    self.meta_optim.zero_grad()
    loss_q.backward()
    # print('meta update')
    # for p in self.net.parameters()[:5]:
    #     print(torch.norm(p).item())
    self.meta_optim.step()

    # self.meta_optim.zero_grad()
    # loss_q_adv.backward()
    # self.meta_optim.step()

    accs = np.array(corrects) / (querysz * task_num)
    accs_adv = np.array(corrects_adv) / (querysz * task_num)
    return accs, accs_adv
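# A minimal meta-training loop sketch (hedged): MiniImagenet, maml, device and the loader
# arguments are illustrative assumptions; forward() above expects batched tasks of shape
# [task_num, setsz, c_, h, w].
from torch.utils.data import DataLoader

db = DataLoader(MiniImagenet('data/miniimagenet', mode='train', n_way=5, k_shot=1,
                             k_query=15, batchsz=10000, resize=84),
                batch_size=4, shuffle=True, num_workers=2, pin_memory=True)
for step_i, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
    x_spt, y_spt = x_spt.to(device), y_spt.to(device)
    x_qry, y_qry = x_qry.to(device), y_qry.to(device)
    accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry)  # invokes forward()
    if step_i % 50 == 0:
        print('step:', step_i, 'clean acc:', accs, 'robust acc:', accs_adv)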
elif weights.endswith('.npy'):
    checkpoint = np.load(weights, allow_pickle=True).item()
    model_dict = {
        k: torch.from_numpy(checkpoint[k])
        for k in checkpoint.keys()
        if model.state_dict()[k].numel() == torch.from_numpy(checkpoint[k]).numel()
    }
    model.load_state_dict(model_dict, strict=True)

model.cuda().train()
del checkpoint
print('load model successfully!')

iters_per_epoch = int(train_size / batch_size)

if adv:
    model_adv = PGD(model)
if kdfa:
    model_t = resnet(imdb.classes, 50, pretrained=False, class_agnostic=False)
    if kdfa_weights.endswith('.pt'):
        checkpoint = torch.load(kdfa_weights)
        checkpoint['model'] = {
            k: v
            for k, v in checkpoint['model'].items()
            if model_t.state_dict()[k].numel() == v.numel()
        }
        model_t.load_state_dict(checkpoint['model'], strict=True)
    elif kdfa_weights.endswith('.npy'):
        checkpoint = np.load(kdfa_weights, allow_pickle=True).item()