model.load_state_dict(state_dict) except RuntimeError: # eliminate 'module.' in keys state_dict = {k[7:]: v for k, v in state_dict.items()} model.load_state_dict(state_dict) # distributed mode on multiple GPUs! # much faster than nn.DataParallel model = DistributedDataParallel( model.cuda(), device_ids=[args.local_rank]) # setup attack settings if args.adv_func == 'logits': adv_func = LogitsAdvLoss(kappa=args.kappa) else: adv_func = CrossEntropyAdvLoss() dist_func = L2Dist() # hyper-parameters from their official tensorflow code attacker = CWPerturb(model, adv_func, dist_func, attack_lr=args.attack_lr, init_weight=10., max_weight=80., binary_step=args.binary_step, num_iter=args.num_iter) # attack test_set = ModelNet40Attack(args.data_root, num_points=args.num_points, normalize=True) test_sampler = DistributedSampler(test_set, shuffle=False) test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False,
def _elapsed(st_time):
    """Return the wall-clock time since ``st_time`` formatted as 'HHh MMm SSs'."""
    return time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time))


def _to_cuda(data):
    """Wrap each tensor of a dataloader sample in ``Variable`` and move it to the GPU."""
    return tuple(Variable(t).cuda() for t in data)


def _load_datasets(opt):
    """Build the train/test datasets and dataloaders for ``opt.dataset``.

    ``opt.refine_start`` is forwarded to the dataset constructors, which is why
    this is called both at start-up and again when training switches to the
    refiner stage. Also stores the dataset's symmetry list and mesh-point count
    on ``opt`` and prints a summary.

    Returns:
        (dataset, dataloader, test_dataset, testdataloader)
    """
    dataset_cls = {'ycb': PoseDataset_ycb, 'linemod': PoseDataset_linemod}[opt.dataset]
    dataset = dataset_cls('train', opt.num_points, True, opt.dataset_root,
                          opt.noise_trans, opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True,
                                             num_workers=opt.workers)
    test_dataset = dataset_cls('test', opt.num_points, False, opt.dataset_root,
                               0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False,
                                                 num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()
    print(
        '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
        .format(len(dataset), len(test_dataset), opt.num_points_mesh, opt.sym_list))
    return dataset, dataloader, test_dataset, testdataloader


def main():
    """Train PoseNet / PoseRefineNet on clean and FGM-perturbed point clouds.

    Reads and mutates the module-level ``opt`` (presumably the parsed CLI
    arguments; confirm against the caller). For every training frame the
    estimator is run on the original point cloud and on an adversarial copy
    produced by ``attack`` with an FGM attacker. Both losses are backpropagated.
    Gradients are accumulated over ``opt.batch_size`` frames
    (``2 * opt.batch_size`` samples) per optimizer step. After each epoch the
    model is evaluated on the clean test set, and the best checkpoint is saved.
    When the best test distance drops below ``opt.decay_margin`` or
    ``opt.refine_margin``, training moves to the LR-decay and refiner stages.
    """
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.dataset == 'ycb':
        opt.num_objects = 21  # number of object classes in the dataset
        opt.num_points = 1000  # number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  # folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  # folder to save logs
        opt.repeat_epoch = 1  # number of repeat times for one epoch training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return

    model = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    model.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        model.load_state_dict(torch.load('{0}'.format(opt.resume_posenet)))

    if opt.resume_refinenet != '':
        # Resuming the refiner implies both later stages are already active.
        refiner.load_state_dict(torch.load('{0}'.format(opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        # Clamp to 1: a zero batch size would divide by zero in the step check.
        opt.batch_size = max(1, int(opt.batch_size / opt.iteration))
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    dataset, dataloader, test_dataset, testdataloader = _load_datasets(opt)

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)

    # np.Inf was removed in NumPy 2.0; np.inf works on every version.
    best_test = np.inf

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))

    st_time = time.time()

    # FGM attack with an L2 budget of delta * sqrt(N * d), with N points in d = 3 dims.
    # The pose loss itself is used as the adversarial objective.
    delta = 0.08
    budget = delta * np.sqrt(opt.num_points * 3)
    # NOTE(review): the attacker keeps this `criterion` even after it is rebuilt
    # at the refine-stage switch below; confirm the two are interchangeable.
    attacker = FGM(model, adv_func=criterion, budget=budget, dist_metric='l2')

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%d' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(_elapsed(st_time) + ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0
        # Each frame contributes two samples (clean + adversarial), so one
        # optimizer step covers opt.batch_size frames = 2 * opt.batch_size samples.
        samples_per_step = 2 * opt.batch_size
        if opt.refine_start:
            model.eval()
            refiner.train()
        else:
            model.train()

        optimizer.zero_grad()

        for _ in range(opt.repeat_epoch):
            for data in dataloader:
                points, choose, img, target, model_points, idx = _to_cuda(data)
                # `attack` returns a NumPy array holding the perturbed point cloud.
                atck_pc = torch.from_numpy(
                    attack(attacker, model, img, points, choose, idx,
                           model_points, target, opt.w, opt.refine_start)).cuda()

                pred_r, pred_t, pred_c, emb = model(img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(
                    pred_r, pred_t, pred_c, target, model_points, idx, points,
                    opt.w, opt.refine_start)

                pred_r_atck, pred_t_atck, pred_c_atck, emb_atck = model(
                    img, atck_pc, choose, idx)
                loss_atck, dis_atck, new_points_atck, new_target_atck = criterion(
                    pred_r_atck, pred_t_atck, pred_c_atck, target, model_points,
                    idx, atck_pc, opt.w, opt.refine_start)

                if opt.refine_start:
                    for _ in range(opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        pred_r_atck, pred_t_atck = refiner(new_points_atck, emb_atck, idx)
                        dis, new_points, new_target = criterion_refine(
                            pred_r, pred_t, new_target, model_points, idx, new_points)
                        dis_atck, new_points_atck, new_target_atck = criterion_refine(
                            pred_r_atck, pred_t_atck, new_target_atck, model_points,
                            idx, new_points_atck)
                        dis.backward()
                        dis_atck.backward()
                else:
                    loss.backward()
                    loss_atck.backward()

                train_dis_avg += dis.item()
                train_dis_avg += dis_atck.item()
                train_count += 2

                # Step once per full batch of samples. Checking only
                # `% opt.batch_size` stepped twice per logged batch and halved
                # the reported average.
                if train_count % samples_per_step == 0:
                    logger.info(
                        'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'.format(
                            _elapsed(st_time), epoch,
                            int(train_count / samples_per_step), train_count,
                            train_dis_avg / samples_per_step))
                    optimizer.step()
                    optimizer.zero_grad()
                    train_dis_avg = 0

                # Periodic checkpoint, every 1000 frames (2000 samples). The
                # previous `or` made this condition always true, so it saved
                # on every frame.
                if train_count != 0 and train_count % 2000 == 0:
                    if opt.refine_start:
                        torch.save(
                            refiner.state_dict(),
                            '{0}/pose_refine_model_current_attack.pth'.format(opt.outf))
                    else:
                        torch.save(
                            model.state_dict(),
                            '{0}/pose_model_current_attack_{1}.pth'.format(opt.outf, epoch))

        print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(epoch))

        logger = setup_logger(
            'epoch%d_test' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(_elapsed(st_time) + ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        model.eval()
        refiner.eval()

        # Evaluation runs on clean (unperturbed) test inputs only.
        for data in testdataloader:
            points, choose, img, target, model_points, idx = _to_cuda(data)
            pred_r, pred_t, pred_c, emb = model(img, points, choose, idx)
            _, dis, new_points, new_target = criterion(
                pred_r, pred_t, pred_c, target, model_points, idx, points,
                opt.w, opt.refine_start)
            if opt.refine_start:
                for _ in range(opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(
                        pred_r, pred_t, new_target, model_points, idx, new_points)

            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} dis:{2}'.format(
                _elapsed(st_time), test_count, dis))

            test_count += 1

        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(
            _elapsed(st_time), epoch, test_dis))
        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}_attack.pth'.format(opt.outf, epoch, test_dis))
            else:
                torch.save(
                    model.state_dict(),
                    '{0}/pose_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
            print(epoch, '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

        # Stage 2: decay the learning rate and the loss weight `w` once.
        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(model.parameters(), lr=opt.lr)

        # Stage 3: switch to training the refiner. The datasets depend on
        # refine_start, so they and the losses are rebuilt.
        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.batch_size = max(1, int(opt.batch_size / opt.iteration))
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

            dataset, dataloader, test_dataset, testdataloader = _load_datasets(opt)

            criterion = Loss(opt.num_points_mesh, opt.sym_list)
            criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)