        model.load_state_dict(state_dict)
    except RuntimeError:
        # strip the 'module.' prefix left by DataParallel/DDP checkpoints
        state_dict = {k[7:]: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

    # distributed mode on multiple GPUs!
    # much faster than nn.DataParallel
    model = DistributedDataParallel(
        model.cuda(), device_ids=[args.local_rank])

    # setup attack settings
    if args.adv_func == 'logits':
        adv_func = LogitsAdvLoss(kappa=args.kappa)
    else:
        adv_func = CrossEntropyAdvLoss()
    dist_func = L2Dist()
    # hyper-parameters from their official tensorflow code
    attacker = CWPerturb(model, adv_func, dist_func,
                         attack_lr=args.attack_lr,
                         init_weight=10., max_weight=80.,
                         binary_step=args.binary_step,
                         num_iter=args.num_iter)

    # attack
    test_set = ModelNet40Attack(args.data_root, num_points=args.num_points,
                                normalize=True)
    test_sampler = DistributedSampler(test_set, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=args.batch_size,
                             shuffle=False, num_workers=4,
                             pin_memory=True, drop_last=False,
                             sampler=test_sampler)  # assumed completion: close the truncated call with the DistributedSampler defined above


def main():
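    # seed Python and PyTorch RNGs with a randomly drawn seed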
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.dataset == 'ycb':
        opt.num_objects = 21  # number of object classes in the dataset
        opt.num_points = 1000  # number of points in the input point cloud
        opt.outf = 'trained_models/ycb'  # folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  # folder to save logs
        opt.repeat_epoch = 1  # number of times to repeat one epoch of training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset: {0}'.format(opt.dataset))
        return

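    # PoseNet predicts an initial 6D pose per object; PoseRefineNet iteratively refines it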
    model = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    model.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()
    if opt.resume_posenet != '':
        model.load_state_dict(torch.load(opt.resume_posenet))

    if opt.resume_refinenet != '':
        refiner.load_state_dict(torch.load(opt.resume_refinenet))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    if opt.dataset == 'ycb':
        dataset = PoseDataset_ycb('train', opt.num_points, True,
                                  opt.dataset_root, opt.noise_trans,
                                  opt.refine_start)
    elif opt.dataset == 'linemod':
        dataset = PoseDataset_linemod('train', opt.num_points, True,
                                      opt.dataset_root, opt.noise_trans,
                                      opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)
    if opt.dataset == 'ycb':
        test_dataset = PoseDataset_ycb('test', opt.num_points, False,
                                       opt.dataset_root, 0.0, opt.refine_start)
    elif opt.dataset == 'linemod':
        test_dataset = PoseDataset_linemod('test', opt.num_points, False,
                                           opt.dataset_root, 0.0,
                                           opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()

    print(
        '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
        .format(len(dataset), len(test_dataset), opt.num_points_mesh,
                opt.sym_list))

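    # pose loss for PoseNet and a separate loss for the iterative refiner,
    # both aware of the symmetric-object list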
    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)

    best_test = np.inf

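    # starting from scratch: clear logs left over from previous runs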
    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    # ----- adversarial training setup (FGM attack) -----
    # perturbation budget: delta * sqrt(N * d), with d = 3 coordinates per point
    delta = 0.08
    budget = delta * np.sqrt(opt.num_points * 3)
    # alternative: attack a classification-style adversarial loss instead of the pose loss
    # adv_func = CrossEntropyAdvLoss()
    # attacker = FGM(model, adv_func=adv_func, budget=budget, dist_metric='l2')
    attacker = FGM(model, adv_func=criterion, budget=budget, dist_metric='l2')

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%d' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0
        if opt.refine_start:
            model.eval()
            refiner.train()
        else:
            model.train()
        optimizer.zero_grad()

        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, idx = data
                points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                                 Variable(choose).cuda(), \
                                                                 Variable(img).cuda(), \
                                                                 Variable(target).cuda(), \
                                                                 Variable(model_points).cuda(), \
                                                                 Variable(idx).cuda()

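                # adversarial training: craft an FGM-perturbed copy of the point cloud,
                # then backprop the loss on both the clean and the attacked inputs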
                atck_pc = torch.from_numpy(
                    attack(attacker, model, img, points, choose, idx,
                           model_points, target, opt.w,
                           opt.refine_start)).cuda()
                pred_r, pred_t, pred_c, emb = model(img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(
                    pred_r, pred_t, pred_c, target, model_points, idx, points,
                    opt.w, opt.refine_start)

                pred_r_atck, pred_t_atck, pred_c_atck, emb_atck = model(
                    img, atck_pc, choose, idx)
                loss_atck, dis_atck, new_points_atck, new_target_atck = criterion(
                    pred_r_atck, pred_t_atck, pred_c_atck, target,
                    model_points, idx, atck_pc, opt.w, opt.refine_start)

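                # refine stage: run the refiner for opt.iteration steps on both the
                # clean and the attacked predictions and backprop the refined distances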
                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        pred_r_atck, pred_t_atck = refiner(
                            new_points_atck, emb_atck, idx)
                        dis, new_points, new_target = criterion_refine(
                            pred_r, pred_t, new_target, model_points, idx,
                            new_points)
                        dis_atck, new_points_atck, new_target_atck = criterion_refine(
                            pred_r_atck, pred_t_atck, new_target_atck,
                            model_points, idx, new_points_atck)
                        dis.backward()
                        dis_atck.backward()

                else:
                    loss.backward()
                    loss_atck.backward()

                train_dis_avg += dis.item()
                train_dis_avg += dis_atck.item()
                train_count += 2

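                # dataloader batch_size is 1; gradients are accumulated and the
                # optimizer steps once every opt.batch_size counted samples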
                if train_count % opt.batch_size == 0:
                    logger.info(
                        'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'
                        .format(
                            time.strftime("%Hh %Mm %Ss",
                                          time.gmtime(time.time() - st_time)),
                            epoch, int(train_count / (2 * opt.batch_size)),
                            train_count,
                            train_dis_avg / (2 * (opt.batch_size))))
                    optimizer.step()
                    optimizer.zero_grad()
                    train_dis_avg = 0

                if train_count != 0 and train_count % 2000 == 0:
                    if opt.refine_start:
                        torch.save(
                            refiner.state_dict(),
                            '{0}/pose_refine_model_current_attack.pth'.format(
                                opt.outf))
                    else:
                        torch.save(
                            model.state_dict(),
                            '{0}/pose_model_current_attack_{1}.pth'.format(
                                opt.outf, epoch))

        print(
            '>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(
                epoch))

        logger = setup_logger(
            'epoch%d_test' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        model.eval()
        refiner.eval()

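        # evaluate on clean test data; no attack is applied at test time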
        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, idx = data
            points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                             Variable(choose).cuda(), \
                                                             Variable(img).cuda(), \
                                                             Variable(target).cuda(), \
                                                             Variable(model_points).cuda(), \
                                                             Variable(idx).cuda()
            pred_r, pred_t, pred_c, emb = model(img, points, choose, idx)
            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c,
                                                       target, model_points,
                                                       idx, points, opt.w,
                                                       opt.refine_start)

            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(
                        pred_r, pred_t, new_target, model_points, idx,
                        new_points)

            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} dis:{2}'.format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - st_time)), test_count,
                dis))

            test_count += 1

        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)),
            epoch, test_dis))
        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}_attack.pth'.format(
                        opt.outf, epoch, test_dis))
            else:
                torch.save(
                    model.state_dict(), '{0}/pose_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            print(epoch,
                  '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

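        # once the best test distance drops below the margin, decay the learning rate and loss weight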
        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(model.parameters(), lr=opt.lr)

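        # switch to the refine stage: train the refiner, shrink the effective batch size,
        # and rebuild the datasets with refine_start=True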
        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

            if opt.dataset == 'ycb':
                dataset = PoseDataset_ycb('train', opt.num_points, True,
                                          opt.dataset_root, opt.noise_trans,
                                          opt.refine_start)
            elif opt.dataset == 'linemod':
                dataset = PoseDataset_linemod('train', opt.num_points, True,
                                              opt.dataset_root,
                                              opt.noise_trans,
                                              opt.refine_start)
            dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=1,
                                                     shuffle=True,
                                                     num_workers=opt.workers)
            if opt.dataset == 'ycb':
                test_dataset = PoseDataset_ycb('test', opt.num_points, False,
                                               opt.dataset_root, 0.0,
                                               opt.refine_start)
            elif opt.dataset == 'linemod':
                test_dataset = PoseDataset_linemod('test', opt.num_points,
                                                   False, opt.dataset_root,
                                                   0.0, opt.refine_start)
            testdataloader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=opt.workers)

            opt.sym_list = dataset.get_sym_list()
            opt.num_points_mesh = dataset.get_num_points_mesh()

            print(
                '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
                .format(len(dataset), len(test_dataset), opt.num_points_mesh,
                        opt.sym_list))

            criterion = Loss(opt.num_points_mesh, opt.sym_list)
            criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)