Example #1
def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.dataset + "-" + args.task + "-" + args.norm

    if args.dataset == "shapenet":
        args.num_class = 16
    else:
        args.num_class = 40

    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('/data-x/g12/zhangjie/3dIP/exp/v2')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('pruning')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark + "_" + timestr)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG_curve'''
    title = args.dataset + "-" + args.task + "-" + args.norm + "-" + "Pruning"
    logger_loss = Logger(os.path.join(log_dir, 'log_loss.txt'), title=title)
    logger_loss.set_names(['Valid Public Loss', 'Valid Private Loss'])
    logger_acc = Logger(os.path.join(log_dir, 'log_acc.txt'), title=title)
    logger_acc.set_names(['Valid Public Acc.', 'Valid Private Acc.'])
    '''LOG'''  # create the log file
    logger = logging.getLogger("Model")  # logger name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)  # minimum log level
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # attach the log file handler
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load pruning test dataset ...')
    if args.dataset == "shapenet":
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=True,
                                                batchsize=args.batch_size)
    else:
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=False,
                                                batchsize=args.batch_size)

    log_string('Load finished ...')
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)

    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('prun0.py', str(experiment_dir))
    shutil.copytree('./models/layers', str(experiment_dir) + "/layers")
    shutil.copytree('./data', str(experiment_dir) + "/data")
    shutil.copytree('./utils', str(experiment_dir) + "/utils")

    classifier = MODEL.get_model(num_class, channel=3).cuda()

    pprint(classifier)

    pth_dir = '/data-x/g12/zhangjie/3dIP/exp/v2/classification/' + args.dataset + "-" \
              + args.task + "-" + args.norm + "/checkpoints/best_model.pth"
    log_string('pre-trained model chk pth: %s' % pth_dir)

    checkpoint = torch.load(pth_dir)
    model_dict = checkpoint['model_state_dict']
    print('Total : {}'.format(len(model_dict)))
    print("best epoch", checkpoint['epoch'])
    classifier.load_state_dict(model_dict)
    classifier.cuda()

    p_num = get_parameter_number(classifier)
    log_string('Original trainable parameter: %s' % p_num)
    '''TESTING ORIGINAL'''
    logger.info('Test original model...')

    with torch.no_grad():
        _, instance_acc, class_acc = test(classifier,
                                          testDataLoader,
                                          num_class=args.num_class,
                                          ind=0)
        _, instance_acc2, class_acc2 = test(classifier,
                                            testDataLoader,
                                            num_class=args.num_class,
                                            ind=1)
        log_string(
            'Original Instance Public Accuracy: %f, Class Public Accuracy: %f'
            % (instance_acc, class_acc))
        log_string(
            'Original Instance Private Accuracy: %f, Class Private Accuracy: %f'
            % (instance_acc2, class_acc2))
    '''PRUNING'''
    logger.info('Start testing of pruning...')

    for perc in [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
        time_start = datetime.datetime.now()
        classifier.load_state_dict(model_dict)
        p_num = get_parameter_number(classifier)
        log_string('Original trainable parameter: %s' % p_num)
        '''Testing pruning model'''
        logger.info('Testing pruning model--%d%%' % perc)
        pruning_net(classifier, perc)
        classifier.cuda()
        p_num = get_parameter_number(classifier)
        log_string('Pruning %02d%% -- trainable parameter: %s' % (perc, p_num))

        with torch.no_grad():
            val_loss1, test_instance_acc1, class_acc1 = test(
                classifier, testDataLoader, num_class=args.num_class, ind=0)
            val_loss2, test_instance_acc2, class_acc2 = test(
                classifier, testDataLoader, num_class=args.num_class, ind=1)
            log_string(
                'Pruning %02d%%-Test Instance Public Accuracy: %f, Class Public Accuracy: %f'
                % (perc, test_instance_acc1, class_acc1))
            log_string(
                'Pruning %02d%%-Test Instance Private Accuracy: %f, Class Private Accuracy: %f'
                % (perc, test_instance_acc2, class_acc2))
            # val_loss = (val_loss1 + val_loss2)/2
            # test_instance_acc = (test_instance_acc1 + test_instance_acc2)/2

        logger_loss.append([val_loss1, val_loss2])
        logger_acc.append([test_instance_acc1, test_instance_acc2])

        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time : %s S' % (time_span_str))

    logger_loss.close()
    logger_loss.plot_prun()
    savefig(os.path.join(log_dir, 'log_loss.eps'))
    logger_acc.close()
    logger_acc.plot_prun()
    savefig(os.path.join(log_dir, 'log_acc.eps'))

    logger.info('End of pruning...')
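
A note on dependencies: this script calls two helpers that are not shown in any of these examples, `get_parameter_number` and `pruning_net`. Below is a minimal sketch consistent with the call sites above, assuming global L1-magnitude unstructured pruning via `torch.nn.utils.prune`; the authors' actual pruning criterion and counting convention may differ.

import torch.nn as nn
import torch.nn.utils.prune as prune


def get_parameter_number(model: nn.Module) -> dict:
    """Count total and trainable parameters of a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total, 'Trainable': trainable}


def pruning_net(model: nn.Module, perc: int) -> None:
    """Zero out `perc` percent of conv/linear weights by global L1 magnitude."""
    if perc == 0:
        return
    parameters_to_prune = [
        (m, 'weight') for m in model.modules()
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear))
    ]
    prune.global_unstructured(parameters_to_prune,
                              pruning_method=prune.L1Unstructured,
                              amount=perc / 100.0)
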
Example #2
def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.dataset + "-" + args.task + "-" + args.norm

    if args.dataset == "shapenet":
        args.num_class = 16
    else:
        args.num_class = 40

    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    # experiment_dir = Path('./exp/v1/')
    experiment_dir = Path('/data-x/g12/zhangjie/3dIP/exp/v1')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark + "_" + timestr)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG_curve'''
    title = args.dataset + "-" + args.task + "-" + args.norm
    logger_loss = Logger(os.path.join(log_dir, 'log_loss_v1.txt'), title=title)
    logger_loss.set_names(
        ['Train Loss', 'Valid Clean Loss', 'Valid Trigger Loss'])
    logger_acc = Logger(os.path.join(log_dir, 'log_acc_v1.txt'), title=title)
    logger_acc.set_names(
        ['Train Acc.', 'Valid Clean Acc.', 'Valid Trigger Acc.'])
    '''LOG'''  # create the log file
    logger = logging.getLogger("Model")  # logger name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)  # minimum log level
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # attach the log file handler
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    if args.dataset == "shapenet":
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=True,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=True,
                                                batchsize=args.batch_size)
        triggerDataLoader = getData2.get_dataLoader(Shapenet=True,
                                                    batchsize=args.batch_size)
    else:
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=False,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=False,
                                                batchsize=args.batch_size)
        triggerDataLoader = getData2.get_dataLoader(Shapenet=False,
                                                    batchsize=args.batch_size)

    wminputs, wmtargets = [], []
    for wm_idx, (wminput, wmtarget) in enumerate(triggerDataLoader):
        wminputs.append(wminput)
        wmtargets.append(wmtarget)
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)

    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('train_1_cls.py', str(experiment_dir))
    shutil.copy('./data/getData.py', str(experiment_dir))
    shutil.copy('./data/getData2.py', str(experiment_dir))
    shutil.copytree('./models/layers', str(experiment_dir) + "/layers")

    classifier = MODEL.get_model(num_class, channel=3).cuda()
    # classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    pprint(classifier)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    mean_loss = []
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        time_start = datetime.datetime.now()
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))

        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            wm_id = np.random.randint(len(wminputs))
            points = torch.cat(
                [points, wminputs[(wm_id + batch_id) % len(wminputs)]],
                dim=0)  # randomly pick a watermark batch and concatenate it with the clean batch
            target = torch.cat(
                [target, wmtargets[(wm_id + batch_id) % len(wminputs)]], dim=0)

            points = points.data.numpy()
            points = provider.random_point_dropout(
                points)  # provider implements point-cloud augmentations; dropped points are set to the first point
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])  # random scaling
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])  # random shift
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()

            optimizer.zero_grad()
            classifier = classifier.train()
            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1

            mean_loss.append(loss.item() / float(points.size()[0]))

        train_loss = np.mean(mean_loss)
        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        with torch.no_grad():
            val_loss, instance_acc, class_acc = test(classifier,
                                                     testDataLoader,
                                                     num_class=args.num_class)
            val_loss2, instance_acc2, class_acc2 = test(
                classifier, triggerDataLoader, num_class=args.num_class)

            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string('Test Clean Instance Accuracy: %f, Class Accuracy: %f' %
                       (instance_acc, class_acc))
            log_string('Best Clean Instance Accuracy: %f, Class Accuracy: %f' %
                       (best_instance_acc, best_class_acc))
            log_string(
                'Test Trigger Accuracy: %f, Trigger Class Accuracy: %f' %
                (instance_acc2, class_acc2))

            if (instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                log_string('best_epoch %s' % str(best_epoch))
                state = {
                    'epoch': best_epoch,
                    'clean instance_acc': instance_acc,
                    'clean class_acc': class_acc,
                    'trigger instance_acc': instance_acc2,
                    'trigger class_acc': class_acc2,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

        logger_loss.append([train_loss, val_loss, val_loss2])
        logger_acc.append([train_instance_acc, instance_acc, instance_acc2])

        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time : %s S' % (time_span_str))

    logger_loss.close()
    logger_loss.plot()
    savefig(os.path.join(log_dir, 'log_loss_v3.eps'))
    logger_acc.close()
    logger_acc.plot()
    savefig(os.path.join(log_dir, 'log_acc_v3.eps'))

    log_string('best_epoch %s' % str(best_epoch))
    logger.info('End of training...')
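
The training loops above (and the pruning script in Example #1) call a `test(...)` helper that returns `(loss, instance accuracy, class accuracy)` but is never defined in these examples. A minimal sketch inferred from the call sites, following the usual PointNet evaluation loop; the `ind` forward-pass argument and the NLL loss are assumptions:

import numpy as np
import torch
import torch.nn.functional as F


def test(classifier, loader, num_class=40, ind=None):
    """Return (mean NLL loss, instance accuracy, mean per-class accuracy)."""
    classifier = classifier.eval()
    class_acc = np.zeros((num_class, 2))  # per class: [sum of accuracies, batch count]
    losses, mean_correct = [], []

    with torch.no_grad():
        for points, target in loader:
            points = points.transpose(2, 1).cuda()
            target = target.cuda()
            if ind is None:
                pred, _ = classifier(points)
            else:
                pred, _ = classifier(points, ind=ind)
            losses.append(F.nll_loss(pred, target.long()).item())
            pred_choice = pred.data.max(1)[1]
            for cat in np.unique(target.cpu()):
                mask = target == cat
                acc = pred_choice[mask].eq(target[mask].long()).cpu().sum()
                class_acc[cat, 0] += acc.item() / float(mask.sum().item())
                class_acc[cat, 1] += 1
            correct = pred_choice.eq(target.long()).cpu().sum()
            mean_correct.append(correct.item() / float(points.size(0)))

    seen = class_acc[:, 1] > 0  # average only over classes that appeared
    per_class = class_acc[seen, 0] / class_acc[seen, 1]
    return np.mean(losses), np.mean(mean_correct), np.mean(per_class)
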
Example #3
def main():
    args = parser.parse_args()
    # torch.manual_seed(args.seed)
    # torch.cuda.manual_seed(args.seed)
    # np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.dataset + "-" + args.task + "-" + args.norm

    if args.dataset == "shapenet":
        args.num_class = 16
    else:
        args.num_class = 40

    if 'fake2-' in args.type or 'fake3-' in args.type:
        args.flipperc = 0
        print('No Flip')

    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    if args.task == 'baseline':
        experiment_dir_root = Path('/data-x/g12/zhangjie/3dIP/baseline')
    else:
        experiment_dir_root = Path('/data-x/g12/zhangjie/3dIP/ours')

    experiment_dir_root.mkdir(exist_ok=True)
    experiment_dir = experiment_dir_root.joinpath('ambiguity_attack')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark)
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.type)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    logger = logging.getLogger("Model")  #log name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  #log file name
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    if args.dataset == "shapenet":
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=True,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=True,
                                                batchsize=args.batch_size)
    else:
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=False,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=False,
                                                batchsize=args.batch_size)

    log_string('Finished ...')
    log_string('Load model ...')
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)

    #  copy model file to exp dir
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('attack2.py', str(experiment_dir))

    classifier = MODEL.get_model(num_class, channel=3).cuda()
    # criterion = MODEL.get_loss().cuda()
    criterion = nn.NLLLoss().cuda()

    sd = experiment_dir_root.joinpath('classification')
    sd.mkdir(exist_ok=True)
    sd = sd.joinpath(str(args.remark))
    sd.mkdir(exist_ok=True)
    sd = sd.joinpath('checkpoints/best_model.pth')

    checkpoint = torch.load(sd)
    classifier.load_state_dict(checkpoint['model_state_dict'])

    for param in classifier.parameters():
        param.requires_grad_(False)

    origpassport = []
    fakepassport = []

    for n, m in classifier.named_modules():
        if n in ['convp1', 'convp2', 'convp3', 'p1', 'p2', 'fc3']:
            key, skey = m.__getattr__('key_private').data.clone(
            ), m.__getattr__('skey_private').data.clone()
            origpassport.append(key.cuda())
            origpassport.append(skey.cuda())

            m.__delattr__('key_private')  # delete the original private passport attributes
            m.__delattr__('skey_private')

            # fake passport initialised as pure random noise
            if 'fake2-' in args.type:
                m.register_parameter(
                    'key_private',
                    nn.Parameter(torch.randn(*key.size()) * 0.001,
                                 requires_grad=True))
                m.register_parameter(
                    'skey_private',
                    nn.Parameter(torch.randn(*skey.size()) * 0.001,
                                 requires_grad=True))

            # fake passport: slight perturbation of the original
            else:
                m.register_parameter(
                    'key_private',
                    nn.Parameter(key.clone() +
                                 torch.randn(*key.size()) * 0.001,
                                 requires_grad=True))
                m.register_parameter(
                    'skey_private',
                    nn.Parameter(skey.clone() +
                                 torch.randn(*skey.size()) * 0.001,
                                 requires_grad=True))

            fakepassport.append(m.__getattr__('key_private'))
            fakepassport.append(m.__getattr__('skey_private'))

            if args.task == 'ours':
                if args.type != 'fake2':

                    for layer in m.fc.modules():
                        if isinstance(layer, nn.Linear):
                            nn.init.xavier_normal_(layer.weight)

                    for i in m.fc.parameters():
                        i.requires_grad = True

    if args.flipperc != 0:
        log_string(f'Reverse {args.flipperc * 100:.2f}% of binary signature')

        for name, m in classifier.named_modules():
            if name in ['convp1', 'convp2', 'convp3', 'p1', 'p2', 'p3']:
                mflip = args.flipperc
                oldb = m.sign_loss_private.b
                newb = oldb.clone()
                npidx = np.arange(len(oldb))  # signature bit length
                randsize = int(oldb.view(-1).size(0) * mflip)
                randomidx = np.random.choice(npidx, randsize,
                                             replace=False)  # randomly choose bits
                newb[randomidx] = oldb[randomidx] * -1  # flip the chosen bits
                m.sign_loss_private.set_b(newb)

    classifier.cuda()
    optimizer = torch.optim.SGD(fakepassport,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.0005)

    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    scheduler = None

    def run_cs():
        cs = []

        for d1, d2 in zip(origpassport, fakepassport):
            d1 = d1.view(d1.size(0), -1)
            d2 = d2.view(d2.size(0), -1)

            cs.append(F.cosine_similarity(d1, d2).item())

        return cs

    classifier.train()
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct2 = []
    mean_loss2 = []
    start_epoch = 0

    mse_criterion = nn.MSELoss()
    cs_criterion = nn.CosineSimilarity()
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        time_start = datetime.datetime.now()
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        optimizer.zero_grad()
        signacc_meter = 0
        signloss_meter = 0

        if scheduler is not None:
            scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :,
                                                                         0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()

            # loss define
            pred2, _ = classifier(points, ind=1)
            loss2 = criterion(pred2, target.long())
            mean_loss2.append(loss2.item() / float(points.size()[0]))
            pred_choice2 = pred2.data.max(1)[1]
            correct2 = pred_choice2.eq(target.long().data).cpu().sum()
            mean_correct2.append(correct2.item() / float(points.size()[0]))

            signacc = torch.tensor(0.).cuda()
            count = 0
            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    signacc += m.get_acc()
                    count += 1
            try:
                signacc_meter += signacc.item() / count
            except ZeroDivisionError:  # no SignLoss modules in the model
                pass

            sign_loss = torch.tensor(0.).cuda()
            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    sign_loss += m.loss
            signloss_meter += sign_loss

            loss = loss2
            maximizeloss = torch.tensor(0.).cuda()
            mseloss = torch.tensor(0.).cuda()
            csloss = torch.tensor(0.).cuda()

            for l, r in zip(origpassport, fakepassport):
                mse = mse_criterion(l, r)
                cs = cs_criterion(l.view(1, -1), r.view(1, -1)).mean()
                csloss += cs
                mseloss += mse
                maximizeloss += 1 / mse

            if 'fake2-' in args.type:
                loss.backward()  # fake2: only the cross-entropy loss is backpropagated
            elif 'fake3-' in args.type:
                (loss + maximizeloss).backward()  # fake3: csloss is not backpropagated
            else:
                (loss + maximizeloss +
                 1000 * sign_loss).backward()  # fake3_S: the sign loss is also backpropagated
                # (loss + 1000 * sign_loss).backward()

            torch.nn.utils.clip_grad_norm_(fakepassport, 2)

            optimizer.step()
            global_step += 1

        signacc = signacc_meter / len(trainDataLoader)
        log_string('Train Sign Accuracy: %f' % signacc)

        signloss = signloss_meter / len(trainDataLoader)
        log_string('Train Sign Loss: %f' % signloss)

        train_instance_acc2 = np.mean(mean_correct2)
        log_string('Train Instance Private Accuracy: %f' % train_instance_acc2)

        with torch.no_grad():
            cs = run_cs()
            log_string(
                f'Cosine Similarity of Real and Maximize passport: {sum(cs) / len(origpassport):.4f}'
            )
            val_loss2, test_instance_acc2, class_acc2, singloss2, signacc2 = test(
                classifier, testDataLoader, num_class=args.num_class, ind=1)

            log_string(
                'Test Instance Private Accuracy: %f, Class Private Accuracy: %f'
                % (test_instance_acc2, class_acc2))
            log_string('Test Private Sign Accuracy: %f' % (signacc2))

            test_instance_acc = test_instance_acc2
            class_acc = class_acc2

            if (test_instance_acc >= best_instance_acc):
                best_instance_acc = test_instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string(
                'Test Instance Average Accuracy: %f, Class Average Accuracy: %f'
                % (test_instance_acc, class_acc))
            log_string(
                'Best Instance Average Accuracy: %f, Class Average Accuracy: %f'
                % (best_instance_acc, best_class_acc))

            if (test_instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_attack_model.pth'
                log_string('Saving at %s' % savepath)
                log_string('best_epoch %s' % str(best_epoch))
                state = {
                    'epoch': best_epoch,
                    'instance_acc': test_instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'origpassport': origpassport,
                    'fakepassport': fakepassport
                }
                torch.save(state, savepath)
            global_epoch += 1

        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time : %s S' % (time_span_str))

    log_string('best_epoch %s' % str(best_epoch))

    logger.info('End of training...')
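
Example #3 reads `m.loss` and `m.get_acc()` from `SignLoss` modules and rewrites the target signature through `m.sign_loss_private.set_b(...)`, but the module itself is not included here. A minimal sketch consistent with those call sites, assuming a hinge penalty that pushes the signs of the passport scale factors toward a binary signature; the margin (0.1) and the loss weighting are assumptions:

import torch
import torch.nn as nn


class SignLoss(nn.Module):
    def __init__(self, b, alpha=0.1):
        super().__init__()
        self.alpha = alpha                   # loss weight
        self.register_buffer('b', b.sign())  # target signature in {-1, +1}
        self.loss = 0
        self.acc = 0

    def set_b(self, b):
        self.b = b.sign().to(self.b.device)

    def reset(self):
        self.loss = 0
        self.acc = 0

    def add(self, scale):
        """Called by the host layer each forward pass with its scale factors."""
        scale = scale.view(-1)
        # hinge: penalise scale factors whose sign disagrees with the signature
        self.loss = self.loss + self.alpha * torch.relu(0.1 - self.b * scale).sum()
        self.acc = (scale.sign() == self.b).float().mean()

    def get_acc(self):
        return self.acc
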
Example #4
def main(args):
    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./exp/')
    # experiment_dir = Path('/data-x/g12/zhangjie/pointnet/exp/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
        experiment_dir.mkdir(exist_ok=True)
        experiment_dir = experiment_dir.joinpath(args.remark + "_" + timestr)

    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''  # create the log file
    logger = logging.getLogger("Model")  # logger name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)  # minimum log level
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # attach the log file handler
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    # DATA_PATH = 'data/modelnet40_normal_resampled/'
    #
    # TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train',
    #                                                  normal_channel=args.normal)
    # TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test',
    #                                                 normal_channel=args.normal)
    # trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=4)
    # testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)

    trainDataLoader = getData.get_dataLoader(train=True)
    testDataLoader = getData.get_dataLoader(train=False)
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)
    # dynamically import the model module selected by args.model, instead of importing every model up front

    # copy model files to the experiment dir
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))

    classifier = MODEL.get_model(num_class, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    pprint(classifier)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))

        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            # print("POINTS",points.shape)  [24Batch,1024,6]
            # print("TARGET",target.shape)  [24Batch,1]
            points = points.data.numpy()
            points = provider.random_point_dropout(
                points)  #provider是自己写的一个对点云操作的函数,随机dropout,置为第一个点的值
            points[:, :,
                   0:3] = provider.random_scale_point_cloud(points[:, :,
                                                                   0:3])  #点的放缩
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :,
                                                                  0:3])  #点的偏移
            points = torch.Tensor(points)
            # target = target[:, 0]

            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            classifier = classifier.train()
            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1

        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(),
                                           testDataLoader,
                                           num_class=args.num_class)

            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' %
                       (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' %
                       (best_instance_acc, best_class_acc))

            if (instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
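
Every training loop in these examples augments the batch with three functions from `provider`. For reference, minimal versions matching the call signatures used above; the default ranges follow the common PointNet data-augmentation conventions and may differ from the authors' `provider.py`:

import numpy as np


def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """Randomly drop points; dropped points are replaced by the first point."""
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio
        drop_idx = np.where(np.random.random(batch_pc.shape[1]) <= dropout_ratio)[0]
        if len(drop_idx) > 0:
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]
    return batch_pc


def random_scale_point_cloud(batch_xyz, scale_low=0.8, scale_high=1.25):
    """Scale each cloud by a random factor."""
    scales = np.random.uniform(scale_low, scale_high, batch_xyz.shape[0])
    return batch_xyz * scales[:, np.newaxis, np.newaxis]


def shift_point_cloud(batch_xyz, shift_range=0.1):
    """Shift each cloud by a random xyz offset."""
    shifts = np.random.uniform(-shift_range, shift_range, (batch_xyz.shape[0], 3))
    return batch_xyz + shifts[:, np.newaxis, :]
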
Example #5
def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.dataset + "-" + args.task + "-" + args.norm

    if args.dataset == "shapenet":
        args.num_class = 16
    else:
        args.num_class = 40

    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    # experiment_dir = Path('./exp/v3/')
    if args.task == 'baseline':
        experiment_dir = Path('/data-x/g12/zhangjie/3dIP/baseline')
    else:
        experiment_dir = Path('/data-x/g12/zhangjie/3dIP/ours')

    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark + "_" + timestr)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG_curve'''
    title = args.dataset + "-" + args.task + "-" + args.norm
    logger_loss = Logger(os.path.join(log_dir, 'log_loss_v3.txt'), title=title)
    logger_loss.set_names([
        'Train Pub&Pri Loss',
        'Train Public Loss',
        'Train Private Loss',
        'Valid Pub-Clean Loss',
        'Valid Pub-Trigger Loss',
        'Valid Pri-Clean Loss',
        'Valid Pri-Trigger Loss',
    ])
    logger_acc = Logger(os.path.join(log_dir, 'log_acc_v3.txt'), title=title)
    logger_acc.set_names([
        'Train Pub-Combine Acc.', 'Valid Pub-Clean Acc.',
        'Valid Pub-Trigger Acc.', 'Train Pri-Combine Acc.',
        'Valid Pri-Clean Acc.', 'Valid Pri-Trigger Acc.'
    ])
    '''LOG'''  # create the log file
    logger = logging.getLogger("Model")  # logger name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)  # minimum log level
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # attach the log file handler
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    if args.dataset == "shapenet":
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=True,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=True,
                                                batchsize=args.batch_size)
        triggerDataLoader = getData2.get_dataLoader(Shapenet=True,
                                                    T1=args.T1,
                                                    batchsize=args.batch_size)
    else:
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=False,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=False,
                                                batchsize=args.batch_size)
        triggerDataLoader = getData2.get_dataLoader(Shapenet=False,
                                                    T1=args.T1,
                                                    batchsize=args.batch_size)

    wminputs, wmtargets = [], []
    for wm_idx, (wminput, wmtarget) in enumerate(triggerDataLoader):
        wminputs.append(wminput)
        wmtargets.append(wmtarget)
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)

    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('train_3_cls.py', str(experiment_dir))
    shutil.copy('./data/getData.py', str(experiment_dir))
    shutil.copy('./data/getData2.py', str(experiment_dir))
    shutil.copytree('./models/layers', str(experiment_dir) + "/layers")

    classifier = MODEL.get_model(num_class, channel=3).cuda()
    criterion = MODEL.get_loss().cuda()

    pprint(classifier)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    mean_correct2 = []
    mean_loss = []
    mean_loss1 = []
    mean_loss2 = []
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        time_start = datetime.datetime.now()
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))

        scheduler.step()
        wm_id = np.random.randint(len(wminputs))

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            points = torch.cat(
                [points, wminputs[(wm_id + batch_id) % len(wminputs)]],
                dim=0)  # randomly pick a watermark batch and concatenate it with the clean batch
            target = torch.cat(
                [target, wmtargets[(wm_id + batch_id) % len(wminputs)]], dim=0)
            points = points.data.numpy()
            points = provider.random_point_dropout(
                points)  # provider implements point-cloud augmentations; dropped points are set to the first point
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])  # random scaling
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])  # random shift
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()

            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    m.reset()

            loss1 = torch.tensor(0.).cuda()
            loss2 = torch.tensor(0.).cuda()
            sign_loss = torch.tensor(0.).cuda()

            for ind in range(2):
                if ind == 0:
                    pred, trans_feat = classifier(points, ind=ind)
                    loss1 = criterion(pred, target.long(), trans_feat)
                    mean_loss1.append(loss1.item() / float(points.size()[0]))
                    pred_choice = pred.data.max(1)[1]
                    correct = pred_choice.eq(target.long().data).cpu().sum()
                    mean_correct.append(correct.item() /
                                        float(points.size()[0]))

                else:
                    pred2, trans_feat2 = classifier(points, ind=ind)
                    loss2 = criterion(pred2, target.long(), trans_feat2)
                    mean_loss2.append(loss2.item() / float(points.size()[0]))
                    pred_choice2 = pred2.data.max(1)[1]
                    correct2 = pred_choice2.eq(target.long().data).cpu().sum()
                    mean_correct2.append(correct2.item() /
                                         float(points.size()[0]))

            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    sign_loss += m.loss

            loss = args.beta * loss1 + loss2 + sign_loss
            mean_loss.append(loss.item() / float(points.size()[0]))

            loss.backward()
            optimizer.step()
            global_step += 1

        train_instance_acc = np.mean(mean_correct)
        train_instance_acc2 = np.mean(mean_correct2)

        train_loss = np.mean(mean_loss)
        train_loss1 = np.mean(mean_loss1)
        train_loss2 = np.mean(mean_loss2)
        log_string('Train Combine Public Accuracy: %f' % train_instance_acc)
        log_string('Train Combine Private Accuracy: %f' % train_instance_acc2)

        sign_acc = torch.tensor(0.).cuda()
        count = 0

        for m in classifier.modules():
            if isinstance(m, SignLoss):
                sign_acc += m.acc
                count += 1

        if count != 0:
            sign_acc /= count

        log_string('Sign Accuracy: %f' % sign_acc)

        res = {}
        avg_private = 0
        count_private = 0

        with torch.no_grad():
            if args.task == 'ours':
                for name, m in classifier.named_modules():
                    if name in [
                            'convp1', 'convp2', 'convp3', 'p1', 'p2', 'p3'
                    ]:
                        signbit, _ = m.get_scale(ind=1)
                        signbit = signbit.view(-1).sign()
                        privatebit = m.b

                        detection = (
                            signbit == privatebit).float().mean().item()
                        res['private_' + name] = detection
                        avg_private += detection
                        count_private += 1

            elif args.task == 'baseline':
                for name, m in classifier.named_modules():
                    if name in [
                            'convp1', 'convp2', 'convp3', 'p1', 'p2', 'p3'
                    ]:
                        signbit = m.get_scale(ind=1).view(-1).sign()
                        privatebit = m.b

                        detection = (
                            signbit == privatebit).float().mean().item()
                        res['private_' + name] = detection
                        avg_private += detection
                        count_private += 1

            log_string('Private Sign Detection Accuracy: %f' %
                       (avg_private / count_private * 100))

            for ind in range(2):
                if ind == 0:
                    val_loss1, test_instance_acc1, class_acc1 = test(
                        classifier,
                        testDataLoader,
                        num_class=args.num_class,
                        ind=0)
                    val_loss_wm1, instance_acc_wm, class_acc_wm = test(
                        classifier,
                        triggerDataLoader,
                        num_class=args.num_class,
                        ind=0)
                else:
                    val_loss2, test_instance_acc2, class_acc2 = test(
                        classifier,
                        testDataLoader,
                        num_class=args.num_class,
                        ind=1)
                    val_loss_wm2, instance_acc_wm2, class_acc_wm2 = test(
                        classifier,
                        triggerDataLoader,
                        num_class=args.num_class,
                        ind=1)

            log_string(
                'Test Clean Public Accuracy: %f, Class Public Accuracy: %f' %
                (test_instance_acc1, class_acc1))
            log_string(
                'Test Clean Private Accuracy: %f, Class Private Accuracy: %f' %
                (test_instance_acc2, class_acc2))
            log_string(
                'Test Trigger Public Accuracy: %f, Trigger Class Public Accuracy: %f'
                % (instance_acc_wm, class_acc_wm))
            log_string(
                'Test Trigger Private Accuracy: %f, Trigger Class Private Accuracy: %f'
                % (instance_acc_wm2, class_acc_wm2))

            test_instance_acc = (test_instance_acc1 + test_instance_acc2) / 2
            class_acc = (class_acc1 + class_acc2) / 2

            if (test_instance_acc >= best_instance_acc):
                best_instance_acc = test_instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string(
                'Test Combine Average Accuracy: %f, Class Average Accuracy: %f'
                % (test_instance_acc, class_acc))
            log_string(
                'Best Combine Average Accuracy: %f, Class Average Accuracy: %f'
                % (best_instance_acc, best_class_acc))

            if (test_instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                log_string('best_epoch %s' % str(best_epoch))
                state = {
                    'epoch': best_epoch,
                    'instance_acc': test_instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

        logger_loss.append([
            train_loss, train_loss1, train_loss2, val_loss1, val_loss_wm1,
            val_loss2, val_loss_wm2
        ])
        logger_acc.append([
            train_instance_acc, test_instance_acc1, instance_acc_wm,
            train_instance_acc2, test_instance_acc2, instance_acc_wm2
        ])

        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time : %s S' % (time_span_str))

    logger_loss.close()
    logger_loss.plot()
    savefig(os.path.join(log_dir, 'log_loss_v3.eps'))
    logger_acc.close()
    logger_acc.plot()
    savefig(os.path.join(log_dir, 'log_acc_v3.eps'))

    log_string('best_epoch %s' % str(best_epoch))
    logger.info('End of training...')
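
The curve `Logger` (and the free function `savefig`) used throughout these scripts is also external. A minimal sketch supporting the methods called above (`set_names`, `append`, `plot`, `plot_prun`, `close`); the tab-separated file format and the fixed pruning x-axis are assumptions:

import matplotlib.pyplot as plt


def savefig(fname, dpi=150):
    plt.savefig(fname, dpi=dpi)


class Logger:
    def __init__(self, fpath, title=None):
        self.file = open(fpath, 'w')
        self.title = title or ''
        self.names = []
        self.numbers = {}

    def set_names(self, names):
        self.names = names
        self.numbers = {name: [] for name in names}
        self.file.write('\t'.join(names) + '\n')
        self.file.flush()

    def append(self, values):
        assert len(values) == len(self.names)
        for name, value in zip(self.names, values):
            self.numbers[name].append(float(value))
        self.file.write('\t'.join('%.6f' % float(v) for v in values) + '\n')
        self.file.flush()

    def plot(self, names=None):
        plt.figure()
        for name in names or self.names:
            plt.plot(range(len(self.numbers[name])), self.numbers[name], label=name)
        plt.title(self.title)
        plt.legend()
        plt.grid(True)

    def plot_prun(self, names=None):
        # x-axis is the pruning percentage rather than the epoch index
        plt.figure()
        for name in names or self.names:
            ys = self.numbers[name]
            xs = [10 * i for i in range(len(ys))]  # 0%, 10%, ..., 100%
            plt.plot(xs, ys, label=name)
        plt.title(self.title)
        plt.legend()
        plt.grid(True)

    def close(self):
        self.file.close()
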
Example #6
def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.dataset + "-" + args.task + "-" + args.norm

    if args.dataset =="shapenet":
        args.num_class=16
    else:
        args.num_class=40

    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    if args.task == 'baseline':
        experiment_dir_root = Path('/data-x/g12/zhangjie/3dIP/baseline')
    else:
        experiment_dir_root = Path('/data-x/g12/zhangjie/3dIP/ours')
    experiment_dir_root.mkdir(exist_ok=True)
    experiment_dir = experiment_dir_root.joinpath('pruning')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG_curve'''
    title = ''

    logger_loss = Logger(os.path.join(log_dir, 'log_loss_v3.txt'), title=title)
    logger_loss.set_names([
        'Valid Pub-Clean Loss', 'Valid Pub-Trigger Loss',
        'Valid Pri-Clean Loss', 'Valid Pri-Trigger Loss'
    ])
    logger_acc = Logger(os.path.join(log_dir, 'log_acc_v3.txt'), title=title)
    logger_acc.set_names([
        'Model for Releasing', 'Model for Verification', 'Trigger', 'Signature'
    ])

    '''LOG'''  # create the log file
    logger = logging.getLogger("Model")  # logger name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)  # minimum log level
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # attach the log file handler
    log_string('PARAMETER ...')
    log_string(args)

    '''DATA LOADING'''
    log_string('Load pruning test dataset ...')
    if args.dataset == "shapenet":
        testDataLoader = getData.get_dataLoader(train=False, Shapenet=True, batchsize=args.batch_size)
        triggerDataLoader = getData2.get_dataLoader(Shapenet=True, T1=args.T1, batchsize=args.batch_size)
    else:
        testDataLoader = getData.get_dataLoader(train=False, Shapenet=False, batchsize=args.batch_size)
        triggerDataLoader = getData2.get_dataLoader(Shapenet=False, T1=args.T1, batchsize=args.batch_size)

    log_string('Load finished ...')

    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)

    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('prun3.py', str(experiment_dir))

    classifier = MODEL.get_model(num_class, channel=3).cuda()

    # pprint(classifier)

    sd = experiment_dir_root.joinpath('classification')
    sd.mkdir(exist_ok=True)
    sd = sd.joinpath(str(args.remark))
    sd.mkdir(exist_ok=True)
    sd = sd.joinpath('checkpoints/best_model.pth')

    log_string('pre-trained model chk pth: %s' % sd)

    checkpoint = torch.load(sd)
    model_dict = checkpoint['model_state_dict']
    print('Total : {}'.format(len(model_dict)))
    print("best epoch", checkpoint['epoch'])
    classifier.load_state_dict(model_dict)
    classifier.cuda()

    p_num = get_parameter_number(classifier)
    log_string('Original trainable parameter: %s' % p_num)

    '''TESTING ORIGINAL'''
    logger.info('Test original model...')

    with torch.no_grad():

        _, test_instance_acc1, class_acc1, _, _ = test(classifier, testDataLoader, num_class=args.num_class, ind=0)
        _, instance_acc_wm, class_acc_wm, _, _ = test(classifier, triggerDataLoader, num_class=args.num_class, ind=0)
        _, test_instance_acc2, class_acc2, signloss2, signacc2 = test(classifier, testDataLoader, num_class=args.num_class, ind=1)
        _, instance_acc_wm2, class_acc_wm2, _, _ = test(classifier, triggerDataLoader, num_class=args.num_class, ind=1)

    log_string('Test Clean Public Accuracy: %f, Class Public Accuracy: %f' % (test_instance_acc1, class_acc1))
    log_string('Test Clean Private Accuracy: %f, Class Private Accuracy: %f' % (test_instance_acc2, class_acc2))
    log_string('Test Private Sign Accuracy: %f' % (signacc2))
    log_string('Test Trigger Public Accuracy: %f, Trigger Class Public Accuracy: %f' % (instance_acc_wm, class_acc_wm))
    log_string('Test Trigger Private Accuracy: %f, Trigger Class Private Accuracy: %f' % (instance_acc_wm2, class_acc_wm2))

    '''PRUNING'''
    logger.info('Start testing of pruning...')

    for perc in [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
        time_start = datetime.datetime.now()
        classifier.load_state_dict(model_dict)
        p_num = get_parameter_number(classifier)
        log_string('Original trainable parameter: %s' % p_num)

        '''Testing pruning model'''
        logger.info('Testing pruning model--%d%%' % perc)
        pruning_net(classifier, perc)
        classifier.cuda()
        p_num = get_parameter_number(classifier)
        log_string('Pruning %02d%% -- trainable parameter: %s' % (perc, p_num))

        with torch.no_grad():
            val_loss1, test_instance_acc1, class_acc1, _, _ = test(classifier, testDataLoader, num_class=args.num_class, ind=0)
            val_loss_wm1, instance_acc_wm, class_acc_wm, _, _ = test(classifier, triggerDataLoader, num_class=args.num_class, ind=0)
            val_loss2, test_instance_acc2, class_acc2, signloss2, signacc2 = test(classifier, testDataLoader, num_class=args.num_class, ind=1)
            val_loss_wm2, instance_acc_wm2, class_acc_wm2, _, _ = test(classifier, triggerDataLoader, num_class=args.num_class, ind=1)

        log_string('Pruning %d%% Test Clean Public Accuracy: %f, Class Public Accuracy: %f' % (perc, test_instance_acc1, class_acc1))
        log_string('Pruning %d%% Test Clean Private Accuracy: %f, Class Private Accuracy: %f' % (perc, test_instance_acc2, class_acc2))
        log_string('Pruning %d%% Test Private Sign Accuracy: %f' % (perc, signacc2))
        log_string('Pruning %d%% Test Trigger Public Accuracy: %f, Trigger Class Public Accuracy: %f' % (perc, instance_acc_wm, class_acc_wm))
        log_string('Pruning %d%% Test Trigger Private Accuracy: %f, Trigger Class Private Accuracy: %f' % (perc, instance_acc_wm2, class_acc_wm2))

        logger_loss.append([val_loss1, val_loss_wm1, val_loss2, val_loss_wm2])
        logger_acc.append([test_instance_acc1 * 100, test_instance_acc2 * 100, instance_acc_wm * 100, signacc2 * 100])
        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time: %s s' % time_span_str)

    logger_loss.close()
    logger_loss.plot_prun()
    savefig(os.path.join(log_dir, 'log_loss.eps'))
    logger_acc.close()
    logger_acc.plot_prun()
    acc_name = args.remark + '_prun.eps'
    savefig(os.path.join(log_dir, acc_name))

    logger.info('End of pruning...')


def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.dataset + "-" + args.task + "-" + args.norm

    if args.dataset == "shapenet":
        args.num_class = 16
    else:
        args.num_class = 40

    def log_string(msg):  # avoid shadowing the built-in `str`
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('/data-x/g12/zhangjie/3dIP/exp3.0/v2')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark + "_" + timestr)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG_curve'''
    title = args.dataset + "-" + args.task + "-" + args.norm
    logger_loss = Logger(os.path.join(log_dir, 'log_loss.txt'), title=title)
    logger_loss.set_names([
        'Train AVE Loss', 'Train Public Loss', 'Train Private Loss',
        'Valid AVE Loss', 'Valid Public Loss', 'Valid Private Loss'
    ])
    logger_acc = Logger(os.path.join(log_dir, 'log_acc.txt'), title=title)
    logger_acc.set_names([
        'Train Public Acc.', 'Train Private Acc.', 'Test Public Acc.',
        'Test Private Acc.'
    ])
    '''LOG'''  # create the log file
    logger = logging.getLogger("Model")  # name of the logger
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)  # minimum level written to the file
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # attach the file handler to the logger
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    if args.dataset == "shapenet":
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=True,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=True,
                                                batchsize=args.batch_size)
    else:
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=False,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=False,
                                                batchsize=args.batch_size)

    log_string('Finished ...')
    log_string('Load model ...')
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)
    # Dynamic import: which models/<name>.py gets loaded is decided at runtime
    # from args.model, instead of importing every model unconditionally.

    # Copy the model and training files into the experiment directory.
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('train_2_cls.py', str(experiment_dir))
    shutil.copy('./data/getData.py', str(experiment_dir))
    shutil.copytree('./models/layers', str(experiment_dir) + "/layers")

    classifier = MODEL.get_model(num_class, channel=3).cuda()
    criterion = MODEL.get_loss().cuda()
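    # The model's forward() takes an `ind` flag selecting between two
    # inference branches (0 = public, 1 = private); the training loop below
    # exercises both.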

    pprint(classifier)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:  # a bare except would also swallow KeyboardInterrupt
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
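    # StepLR multiplies the learning rate by gamma=0.7 every 20 epochs.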
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    mean_correct2 = []
    mean_loss = []
    mean_loss1 = []
    mean_loss2 = []
    '''TRANING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        time_start = datetime.datetime.now()
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        # Reset the per-epoch running statistics; without this the reported
        # training means would accumulate over all previous epochs.
        mean_correct, mean_correct2 = [], []
        mean_loss, mean_loss1, mean_loss2 = [], [], []

        scheduler.step()  # stepped at epoch start, following the pre-1.1 PyTorch convention
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            points = points.data.numpy()
            # provider implements custom point-cloud augmentations: random
            # dropout (dropped points are set to the first point's value),
            # random per-cloud scaling, and random shifting.
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()

            # Reset the per-batch SignLoss accumulators before the forward passes.
            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    m.reset()

            loss1 = torch.tensor(0.).cuda()
            loss2 = torch.tensor(0.).cuda()
            sign_loss = torch.tensor(0.).cuda()

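            # Two forward passes per batch: ind=0 yields the public-branch
            # loss (loss1), ind=1 the private-branch loss (loss2).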
            for ind in range(2):
                if ind == 0:
                    pred, trans_feat = classifier(points, ind=ind)
                    loss1 = criterion(pred, target.long(), trans_feat)
                    mean_loss1.append(loss1.item() / float(points.size()[0]))
                    pred_choice = pred.data.max(1)[1]
                    correct = pred_choice.eq(target.long().data).cpu().sum()
                    mean_correct.append(correct.item() /
                                        float(points.size()[0]))

                else:
                    pred2, trans_feat2 = classifier(points, ind=ind)
                    loss2 = criterion(pred2, target.long(), trans_feat2)
                    mean_loss2.append(loss2.item() / float(points.size()[0]))
                    pred_choice2 = pred2.data.max(1)[1]
                    correct2 = pred_choice2.eq(target.long().data).cpu().sum()
                    mean_correct2.append(correct2.item() /
                                         float(points.size()[0]))

            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    sign_loss += m.loss

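            # Combined objective: beta-weighted public loss + private loss
            # + the accumulated SignLoss term enforcing the watermark signs.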
            loss = args.beta * loss1 + loss2 + sign_loss
            mean_loss.append(loss.item() / float(points.size()[0]))

            loss.backward()
            optimizer.step()
            global_step += 1

        train_instance_acc = np.mean(mean_correct)
        train_instance_acc2 = np.mean(mean_correct2)
        train_instance_acc_ave = (train_instance_acc + train_instance_acc2) / 2
        train_loss = np.mean(mean_loss) / 2
        train_loss1 = np.mean(mean_loss1)
        train_loss2 = np.mean(mean_loss2)

        log_string('Train Instance Public Accuracy: %f' % train_instance_acc)
        log_string('Train Instance Private Accuracy: %f' % train_instance_acc2)

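        # Average the sign (watermark) accuracy across all SignLoss modules.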
        sign_acc = torch.tensor(0.).cuda()
        count = 0

        for m in classifier.modules():
            if isinstance(m, SignLoss):
                sign_acc += m.acc
                count += 1

        if count != 0:
            sign_acc /= count

        log_string('Sign Accuracy: %f' % sign_acc)

        with torch.no_grad():
            for ind in range(2):
                if ind == 0:
                    val_loss1, test_instance_acc1, class_acc1 = test(
                        classifier,
                        testDataLoader,
                        num_class=args.num_class,
                        ind=0)
                else:
                    val_loss2, test_instance_acc2, class_acc2 = test(
                        classifier,
                        testDataLoader,
                        num_class=args.num_class,
                        ind=1)

            log_string(
                'Test Instance Public Accuracy: %f, Class Public Accuracy: %f'
                % (test_instance_acc1, class_acc1))
            log_string(
                'Test Instance Private Accuracy: %f, Class Private Accuracy: %f'
                % (test_instance_acc2, class_acc2))

            val_loss = (val_loss1 + val_loss2) / 2
            test_instance_acc = (test_instance_acc1 + test_instance_acc2) / 2
            class_acc = (class_acc1 + class_acc2) / 2

            if (test_instance_acc >= best_instance_acc):
                best_instance_acc = test_instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string(
                'Test Instance Average Accuracy: %f, Class Average Accuracy: %f'
                % (test_instance_acc, class_acc))
            log_string(
                'Best Instance Average Accuracy: %f, Class Average Accuracy: %f'
                % (best_instance_acc, best_class_acc))

            if (test_instance_acc >= best_instance_acc):  # holds exactly when the check above just recorded a new best
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                log_string('best_epoch %s' % str(best_epoch))
                state = {
                    'epoch': best_epoch,
                    'instance_acc': test_instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

        logger_loss.append([
            train_loss, train_loss1, train_loss2, val_loss, val_loss1,
            val_loss2
        ])
        logger_acc.append([
            train_instance_acc, train_instance_acc2, test_instance_acc1,
            test_instance_acc2
        ])

        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time: %s s' % time_span_str)

    logger_loss.close()
    logger_loss.plot()
    savefig(os.path.join(log_dir, 'log_loss.eps'))
    logger_acc.close()
    logger_acc.plot()
    savefig(os.path.join(log_dir, 'log_acc.eps'))

    log_string('best_epoch %s' % str(best_epoch))

    logger.info('End of training...')