Example #1
def train(net, trainloader, optimizer, criterion, device):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    time_cost = datetime.datetime.now()
    for batch_idx, (points, targets) in enumerate(trainloader):
        points = points.data.numpy()
        points = provider.random_point_dropout(points)
        points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
        points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
        points = torch.Tensor(points)
        points = points.transpose(2, 1)
        points, targets = points.to(device), targets.to(device).long()
        optimizer.zero_grad()
        out = net(points)  # the model returns a dict of outputs with a "logits" entry
        loss = criterion(out, targets)  # the criterion here is expected to accept that dict
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = out["logits"].max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
    return {
        "loss": float("%.3f" % (train_loss / (batch_idx + 1))),
        "acc": float("%.3f" % (100. * correct / total)),
        "time": time_cost
    }
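
The `provider` module these examples rely on is not shown. In the widely circulated PointNet/PointNet++ `provider.py`, `random_point_dropout` overwrites a random subset of each cloud's points with the first point (the comments in Examples #8 and #12 below describe exactly this behavior). A minimal sketch under that assumption; defaults may differ per repository:

import numpy as np

def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    # batch_pc: (B, N, C) batch of point clouds, augmented in place
    for b in range(batch_pc.shape[0]):
        # per-cloud dropout ratio drawn from [0, max_dropout_ratio)
        dropout_ratio = np.random.random() * max_dropout_ratio
        drop_idx = np.where(np.random.random(batch_pc.shape[1]) <= dropout_ratio)[0]
        if len(drop_idx) > 0:
            # overwrite dropped points with the first point, keeping the (B, N, C) shape fixed
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]
    return batch_pc

Duplicating the first point instead of removing rows keeps every batch a fixed-size tensor, which is why all the training loops below can apply it without reshaping.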
Example #2
def train_one_epoch(sess, ops, train_writer, dataset, verbose=True):
  """
  Train model for one epoch
  """
  global EPOCH_CNT
  is_training = True

  # Shuffle train samples
  train_idxs = np.arange(0, len(dataset))
  np.random.shuffle(train_idxs)

  num_batches = len(dataset) // FLAGS['BATCH_SIZE'] # integer division; discards the tail if the dataset is not divisible by the batch size

  log_string('[' + str(datetime.now()) + ' | EPOCH ' + str(EPOCH_CNT) + '] Starting training.', printout=False)

  loss_sum, batch_print_steps = 0, 10
  for batch_idx in range(num_batches):
    start_idx, end_idx = batch_idx * FLAGS['BATCH_SIZE'], (batch_idx + 1) * FLAGS['BATCH_SIZE']
    batch_data, batch_label = get_batch(dataset, train_idxs, start_idx, end_idx)
    # Perturb point clouds:
    batch_data[:,:,:3] = provider.jitter_point_cloud(batch_data[:,:,:3])
    batch_data[:,:,:3] = provider.rotate_perturbation_point_cloud(batch_data[:,:,:3])
    batch_data[:,:,:3] = provider.shift_point_cloud(batch_data[:,:,:3])
    batch_data[:,:,:3] = provider.random_point_dropout(batch_data[:,:,:3],
                                                       max_dropout_ratio=FLAGS['MAX_POINT_DROPOUT_RATIO'])
    feed_dict = {ops['pointclouds_pl']: batch_data,
                 ops['labels_pl']: batch_label,
                 ops['is_training_pl']: is_training}
    summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'],
                                                     ops['loss'], ops['pred']], feed_dict=feed_dict)
    train_writer.add_summary(summary, step)
    loss_sum += loss_val
    if (batch_idx + 1) % batch_print_steps == 0:  # so loss_sum always covers exactly batch_print_steps batches
      log_string('[Batch %03d] Mean Loss: %f' % ((batch_idx + 1), (loss_sum / batch_print_steps)), printout=verbose)
      loss_sum = 0
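
Example #2 perturbs the xyz channels with `jitter_point_cloud` and `shift_point_cloud` before dropout. In the common PointNet `provider.py` these amount to clipped per-point Gaussian noise and one uniform translation per cloud; a sketch under that assumption:

import numpy as np

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    # add clipped Gaussian noise to every point; batch_data: (B, N, 3)
    B, N, C = batch_data.shape
    noise = np.clip(sigma * np.random.randn(B, N, C), -clip, clip)
    return batch_data + noise

def shift_point_cloud(batch_data, shift_range=0.1):
    # translate each cloud by a single random offset; batch_data: (B, N, 3)
    B, N, C = batch_data.shape
    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
    for b in range(B):
        batch_data[b, :, :] += shifts[b, :]
    return batch_data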
Example #3
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    for fn in range(len(TRAIN_FILES)):
        log_string('----' + str(fn) + '-----')
        current_data, current_label, normal_data = provider.loadDataFile_with_normal(
            TRAIN_FILES[train_file_idxs[fn]])
        normal_data = normal_data[:, 0:NUM_POINT, :]
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, shuffle_idx = provider.shuffle_data(
            current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)
        normal_data = normal_data[shuffle_idx, ...]

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(
                current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            input_data = np.concatenate(
                (jittered_data, normal_data[start_idx:end_idx, :, :]), 2)
            #random point dropout
            input_data = provider.random_point_dropout(input_data)

            feed_dict = {
                ops['pointclouds_pl']: input_data,
                ops['labels_pl']: current_label[start_idx:end_idx],
                ops['is_training_pl']: is_training,
            }
            summary, step, _, loss_val, pred_val = sess.run(
                [
                    ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                    ops['pred']
                ],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            total_correct += correct
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
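
Example #3 rotates each cloud before jittering. In the reference PointNet `provider.py`, `rotate_point_cloud` applies a random rotation about the vertical (Y) axis; a sketch assuming that convention:

import numpy as np

def rotate_point_cloud(batch_data):
    # rotate each (N, 3) cloud by a random angle about the up (Y) axis
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angle = np.random.uniform() * 2 * np.pi
        c, s = np.cos(angle), np.sin(angle)
        rotation = np.array([[c, 0.0, s],
                             [0.0, 1.0, 0.0],
                             [-s, 0.0, c]])
        rotated[k] = batch_data[k].reshape((-1, 3)).dot(rotation)
    return rotated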
Example #4
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    log_string(str(datetime.now()))

    # Make sure batch data is of same size
    cur_batch_data = np.zeros(
        (BATCH_SIZE, NUM_POINT, TRAIN_DATASET.num_channel()))
    cur_batch_label = np.zeros((BATCH_SIZE), dtype=np.int32)

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    batch_idx = 0
    while TRAIN_DATASET.has_next_batch():
        batch_data, batch_label = TRAIN_DATASET.next_batch(augment=True)
        if FLAGS.dropout:
            batch_data = provider.random_point_dropout(batch_data)
        else:
            assert False  # this script expects point dropout to be enabled
        bsize = batch_data.shape[0]
        cur_batch_data[0:bsize, ...] = batch_data
        cur_batch_label[0:bsize] = batch_label

        feed_dict = {
            ops['pointclouds_pl']: cur_batch_data,
            ops['labels_pl']: cur_batch_label,
            ops['is_training_pl']: is_training,
        }
        summary, step, _, loss_val, pred_val = sess.run(
            [
                ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                ops['pred']
            ],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 1)
        correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])
        total_correct += correct
        total_seen += bsize
        loss_sum += loss_val
        if (batch_idx + 1) % 50 == 0:
            log_string(' ---- batch: %03d ----' % (batch_idx + 1))
            log_string('mean loss: %f' % (loss_sum / 50))
            log_string('accuracy: %f' % (total_correct / float(total_seen)))
            total_correct = 0
            total_seen = 0
            loss_sum = 0
        batch_idx += 1

    TRAIN_DATASET.reset()
Example #5
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train samples
    train_idxs = np.arange(0, len(TRAIN_DATASET))
    np.random.shuffle(train_idxs)
    num_batches = len(TRAIN_DATASET) // BATCH_SIZE

    log_string(str(datetime.now()))

    total_correct = 0
    total_seen = 0
    loss_sum = 0

    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE
        batch_data, batch_label = get_batch(TRAIN_DATASET, train_idxs,
                                            start_idx, end_idx)
        aug_data = augment_batch_data(batch_data)

        # random point dropout, as in Charles R. Qi's reference code
        aug_data = provider.random_point_dropout(aug_data)

        feed_dict = {
            ops['pointclouds_pl']: aug_data,
            ops['labels_pl']: batch_label,
            ops['is_training_pl']: is_training,
        }
        summary, step, _, loss_val, pred_val = sess.run(
            [
                ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                ops['pred']
            ],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 1)
        correct = np.sum(pred_val == batch_label)
        total_correct += correct
        total_seen += BATCH_SIZE
        loss_sum += loss_val

        if (batch_idx + 1) % 50 == 0:
            log_string(' -- %03d / %03d --' % (batch_idx + 1, num_batches))
            log_string('mean loss: %f' % (loss_sum / 50))
            log_string('accuracy: %f' % (total_correct / float(total_seen)))
            total_correct = 0
            total_seen = 0
            loss_sum = 0
Example #6
def train(net, opt, scheduler, train_loader, dev):

    net.train()

    total_loss = 0
    num_batches = 0
    total_correct = 0
    count = 0
    loss_f = nn.CrossEntropyLoss()
    start_time = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label in tq:
            data = data.data.numpy()
            data = provider.random_point_dropout(data)
            data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
            data = torch.tensor(data)
            label = label[:, 0]

            num_examples = label.shape[0]
            data, label = data.to(dev), label.to(dev).squeeze().long()
            opt.zero_grad()
            logits = net(data)
            loss = loss_f(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)

            num_batches += 1
            count += num_examples
            loss = loss.item()
            correct = (preds == label).sum().item()
            total_loss += loss
            total_correct += correct

            tq.set_postfix({
                'AvgLoss': '%.5f' % (total_loss / num_batches),
                'AvgAcc': '%.5f' % (total_correct / count)
            })
    print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
        total_loss / num_batches, total_correct / count,
        time.time() - start_time))
    scheduler.step()
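
`random_scale_point_cloud`, used in Examples #1, #6 and #7, typically multiplies each cloud by one uniform factor. A sketch assuming the usual [0.8, 1.25] range from the PointNet++ `provider.py`:

import numpy as np

def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    # scale each cloud independently by a factor drawn from [scale_low, scale_high)
    B, N, C = batch_data.shape
    scales = np.random.uniform(scale_low, scale_high, B)
    for b in range(B):
        batch_data[b, :, :] *= scales[b]
    return batch_data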
Example #7
def get_features_from_encoder(encoder, loader):
    raw_model = pointnet2_cls_msg_raw.get_model(num_class=40,
                                                normal_channel=True).cuda()
    x_train = []
    y_train = []
    print(type(loader))
    # get the features from the pre-trained model
    for batch_id, data in tqdm(enumerate(loader, 0),
                               total=len(loader),
                               smoothing=0.9):
        points, target = data
        points = points.data.numpy()
        points = provider.random_point_dropout(points)
        points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
        points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
        points = torch.Tensor(points)
        target = target[:, 0]
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        with torch.no_grad():
            raw_feature, raw_cls = raw_model(points)
            feature_vector, cls = encoder(points)
            feature_vector = torch.cat((raw_feature, feature_vector), 1)
            # use extend here: append would insert the whole [12*128] block as a single element
            x_train.extend(feature_vector.cpu().numpy())
            y_train.extend(target.cpu().numpy())
    x_train = np.array(x_train)
    y_train = torch.tensor(y_train)
    print("success feature")

    # for i, (x,y) in enumerate(loader):
    #     # i=i.to(device)
    #     # x=x.to(device)
    #     # y=y.to(device)
    #     x1=torch.tensor([item.cpu().detach().numpy() for item in x1]).cuda()
    #     with torch.no_grad():
    #         feature_vector = encoder(x1)
    #         x_train.extend(feature_vector)
    #         y_train.extend(y.numpy())
    return x_train, y_train
Example #8
def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.ori_dataset + "-" + args.task + "-" + args.norm

    if args.dataset == "shapenet":
        args.num_class = 40
        args.ft_num_class = 16
        args.ori_dataset = "modelnet"
    else:
        args.num_class = 16
        args.ft_num_class = 40
        args.ori_dataset = "shapenet"

    if args.rtll:
        args.ft_type = "RTLL"
    else:
        args.ft_type = "RTAL"

    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))

    if args.task == 'baseline':
        experiment_dir_root = Path('/data-x/g12/zhangjie/3dIP/baseline')
    else:
        experiment_dir_root = Path('/data-x/g12/zhangjie/3dIP/ours')

    experiment_dir_root.mkdir(exist_ok=True)
    experiment_dir = experiment_dir_root.joinpath('fine-tune')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark)
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath("ft_" + args.dataset + "-" +
                                             args.ft_type + "_" + timestr)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG_curve'''
    title = ''
    logger_loss = Logger(os.path.join(log_dir, 'log_loss.txt'), title=title)
    logger_loss.set_names([
        'Train AVE Loss', 'Train Public Loss', 'Train Private Loss',
        'Valid AVE Loss', 'Valid Public Loss', 'Valid Private Loss'
    ])
    logger_acc = Logger(os.path.join(log_dir, 'log_acc.txt'), title=title)
    logger_acc.set_names([
        'Train AVE Acc.', 'Train Public Acc.', 'Train Private Acc.',
        'Valid AVE Acc.', 'Valid Public Acc.', 'Valid Private Acc.',
        'Valid Private Sign Acc.'
    ])
    '''LOG'''  # create the log file
    logger = logging.getLogger("Model")  # logger name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)  # minimum level written to the file
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # attach the log file handler
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load original dataset ...')
    if args.dataset == "shapenet":
        testDataLoader = getData_ft.get_dataLoader(train=False,
                                                   Shapenet=False,
                                                   batchsize=args.batch_size)
    else:
        testDataLoader = getData_ft.get_dataLoader(train=False,
                                                   Shapenet=True,
                                                   batchsize=args.batch_size)

    log_string('Load finished ...')

    log_string('Load fine tune dataset ...')
    if args.dataset == "shapenet":
        ft_trainDataLoader = getData_ft.get_dataLoader(
            train=True, Shapenet=True, batchsize=args.batch_size)
        ft_testDataLoader = getData_ft.get_dataLoader(
            train=False, Shapenet=True, batchsize=args.batch_size)
    else:
        ft_trainDataLoader = getData_ft.get_dataLoader(
            train=True, Shapenet=False, batchsize=args.batch_size)
        ft_testDataLoader = getData_ft.get_dataLoader(
            train=False, Shapenet=False, batchsize=args.batch_size)
    log_string('Load finished ...')
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)

    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('fine_tune2.py', str(experiment_dir))

    classifier = MODEL.get_model(num_class, channel=3).cuda()

    # pprint(classifier)

    sd = experiment_dir_root.joinpath('classification')
    sd.mkdir(exist_ok=True)
    sd = sd.joinpath(args.ori_dataset + "-" + args.task + "-" + args.norm)
    sd.mkdir(exist_ok=True)
    sd = sd.joinpath('checkpoints/best_model.pth')

    log_string('pre-trained model chk pth: %s' % sd)
    checkpoint = torch.load(sd)
    model_dict = checkpoint['model_state_dict']
    p_num = get_parameter_number(classifier)
    log_string('Original trainable parameter: %s' % p_num)
    # print(p_num)
    print("best epoch", checkpoint['epoch'])
    classifier.load_state_dict(model_dict)
    classifier.cuda()
    '''TESTING ORIGINAL'''
    logger.info('Test original model...')

    with torch.no_grad():
        _, instance_acc, class_acc, _, _ = test(classifier,
                                                testDataLoader,
                                                num_class=args.num_class,
                                                ind=0)
        _, instance_acc2, class_acc2, signloss, signacc = test(
            classifier, testDataLoader, num_class=args.num_class, ind=1)
        log_string(
            'Original Instance Public Accuracy: %f, Class Public Accuracy: %f'
            % (instance_acc, class_acc))
        log_string(
            'Original Instance Private Accuracy: %f, Class Private Accuracy: %f'
            % (instance_acc2, class_acc2))
        log_string('Private  Sign Accuracy: %f' % (signacc))

    # fine-tune the last layer
    # classifier, _ = re_initializer_layer(classifier, args.ft_num_class)
    classifier, _ = re_initializer_passport_layer(classifier,
                                                  args.ft_num_class)
    if args.rtll:
        classifier.freeze_hidden_layers()
    elif args.task == 'ours':
        classifier.freeze_passport_layers()
    else:
        pass

    criterion = MODEL.get_loss().cuda()
    # trainable parameter count after fine-tune freezing
    p_num = get_parameter_number(classifier)
    log_string('Fine tune trainable parameter: %s' % p_num)

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    mean_correct2 = []
    mean_loss = []
    mean_loss1 = []
    mean_loss2 = []
    '''FINE TUNING'''
    logger.info('Start fine-tune training...')
    start_epoch = 0

    for epoch in range(start_epoch, args.epoch):
        time_start = datetime.datetime.now()
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))

        scheduler.step()
        for batch_id, data in tqdm(enumerate(ft_trainDataLoader, 0),
                                   total=len(ft_trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            points = points.data.numpy()
            # provider is a hand-written point-cloud utility module; random_point_dropout
            # replaces dropped points with the value of the first point
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])  # random scaling
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])  # random shift
            points = torch.Tensor(points)
            # target = target[:, 0]

            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()

            for ind in range(2):
                if ind == 0:
                    pred, trans_feat = classifier(points, ind=ind)
                    loss1 = criterion(pred, target.long(), trans_feat)
                    mean_loss1.append(loss1.item() / float(points.size()[0]))
                    pred_choice = pred.data.max(1)[1]
                    correct = pred_choice.eq(target.long().data).cpu().sum()
                    mean_correct.append(correct.item() /
                                        float(points.size()[0]))

                else:
                    pred2, trans_feat2 = classifier(points, ind=ind)
                    loss2 = criterion(pred2, target.long(), trans_feat2)
                    mean_loss2.append(loss2.item() / float(points.size()[0]))
                    pred_choice2 = pred2.data.max(1)[1]
                    correct2 = pred_choice2.eq(target.long().data).cpu().sum()
                    mean_correct2.append(correct2.item() /
                                         float(points.size()[0]))

            # loss = args.beta * loss1 +loss2
            loss = loss1
            mean_loss.append(loss.item() / float(points.size()[0]))
            # loss = loss2
            loss.backward()
            optimizer.step()
            global_step += 1

        train_instance_acc = np.mean(mean_correct)
        train_instance_acc2 = np.mean(mean_correct2)
        train_instance_acc_ave = (train_instance_acc + train_instance_acc2) / 2
        train_loss = np.mean(mean_loss) / 2
        train_loss1 = np.mean(mean_loss1)
        train_loss2 = np.mean(mean_loss2)
        log_string('FT-Train Instance Public Accuracy: %f' %
                   train_instance_acc)
        log_string('FT-Train Instance Private Accuracy: %f' %
                   train_instance_acc2)

        with torch.no_grad():
            val_loss1, test_instance_acc1, class_acc1, _, _ = test(
                classifier,
                ft_testDataLoader,
                num_class=args.ft_num_class,
                ind=0)
            val_loss2, test_instance_acc2, class_acc2, signloss, signacc = test(
                classifier,
                ft_testDataLoader,
                num_class=args.ft_num_class,
                ind=1)
            log_string(
                'FT-Test Instance Public Accuracy: %f, Class Public Accuracy: %f'
                % (test_instance_acc1, class_acc1))
            log_string(
                'FT-Test Instance Private Accuracy: %f, Class Private Accuracy: %f'
                % (test_instance_acc2, class_acc2))
            log_string('FT-Test Private  Sign Accuracy: %f' % (signacc))

            # for ind in range(2):
            #     if ind == 0:
            #         val_loss1, test_instance_acc1, class_acc1 = test(classifier, ft_testDataLoader, num_class=args.ft_num_class, ind=0)
            #     else:
            #         val_loss2, test_instance_acc2, class_acc2 = test(classifier, ft_testDataLoader, num_class=args.ft_num_class, ind =1)
            #
            # log_string('FT-Test Instance Public Accuracy: %f, Class Public Accuracy: %f'% (test_instance_acc1, class_acc1))
            # log_string('FT-Test Instance Private Accuracy: %f, Class Private Accuracy: %f'% (test_instance_acc2, class_acc2))

            val_loss = (val_loss1 + val_loss2) / 2
            test_instance_acc = (test_instance_acc1 + test_instance_acc2) / 2
            class_acc = (class_acc1 + class_acc2) / 2

            if (test_instance_acc >= best_instance_acc):
                best_instance_acc = test_instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string(
                'FT-Test Instance Average Accuracy: %f, Class Average Accuracy: %f'
                % (test_instance_acc, class_acc))
            log_string(
                'FT-Best Instance Average Accuracy: %f, Class Average Accuracy: %f'
                % (best_instance_acc, best_class_acc))

            if (test_instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                log_string('best_epoch %s' % str(best_epoch))
                state = {
                    'epoch': best_epoch,
                    'instance_acc': test_instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

        logger_loss.append([
            train_loss, train_loss1, train_loss2, val_loss, val_loss1,
            val_loss2
        ])
        logger_acc.append([
            train_instance_acc_ave, train_instance_acc, train_instance_acc2,
            test_instance_acc, test_instance_acc1, test_instance_acc2, signacc
        ])

        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time : %s S' % (time_span_str))

    logger_loss.close()
    logger_loss.plot()
    savefig(os.path.join(log_dir, 'log_loss.eps'))
    logger_acc.close()
    logger_acc.plot()
    savefig(os.path.join(log_dir, 'log_acc.eps'))

    log_string('best_epoch %s' % str(best_epoch))

    logger.info('End of fine-tuning...')
Example #9
def main(args):
    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('classification')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    # dataset root path
    data_path = '/home/wgk/dataset/Pointnet_Pointnet2_pytorch/modelnet40_normal_resampled/'
    # training dataset
    train_dataset = ModelNetDataLoader(root=data_path,
                                       args=args,
                                       split='train',
                                       process_data=args.process_data)
    # test dataset
    test_dataset = ModelNetDataLoader(root=data_path,
                                      args=args,
                                      split='test',
                                      process_data=args.process_data)
    # training data loader
    trainDataLoader = torch.utils.data.DataLoader(train_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=1,
                                                  drop_last=True)
    testDataLoader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=1)
    '''MODEL LOADING'''
    num_class = args.num_category
    # dynamically import the model module (pointnet2_cls_msg.py by default); on the syntax see https://www.bilibili.com/read/cv5891176/
    model = importlib.import_module(args.model)

    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_classification.py', str(exp_dir))
    # instantiate the classifier (the model file's get_model class)
    classifier = model.get_model(num_class, normal_channel=args.use_normals)
    # instantiate the loss
    criterion = model.get_loss()
    # .apply runs inplace_relu on every submodule of the classifier
    classifier.apply(inplace_relu)

    if not args.use_cpu:
        # move the model and loss to the GPU
        classifier = classifier.cuda()
        criterion = criterion.cuda()

    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    # configure the optimizer
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    '''TRAINING'''
    logger.info('Start training...')
    # train for args.epoch epochs (200 by default)
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        mean_correct = []
        classifier = classifier.train()

        scheduler.step()
        for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0),
                                               total=len(trainDataLoader),
                                               smoothing=0.9):
            optimizer.zero_grad()

            # one batch of data, shaped (batch_size, num_points=1024, channel)
            points = points.data.numpy()
            #print(points.shape)
            points = provider.random_point_dropout(points)
            # randomly scale the point clouds
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            # randomly shift the point clouds
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            # convert the ndarray back to a tensor
            points = torch.Tensor(points)
            # transpose to (batch_size, channel, num_points)
            points = points.transpose(2, 1)

            if not args.use_cpu:
                # move the data to the GPU first
                points, target = points.cuda(), target.cuda()
            # forward pass, returning the predictions
            # pred.shape=[batchsize,40]  trans_feat.shape=[batchsize,1024,1]
            pred, trans_feat = classifier(points)
            #print(pred.shape,trans_feat.shape)
            # the loss could include a regularization term on trans_feat, but none is actually added here
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            #print(pred_choice)

            # compare predictions with the ground truth on the CPU and count the matches
            correct = pred_choice.eq(target.long().data).cpu().sum()
            # fraction predicted correctly; points.size()[0] == batchsize
            mean_correct.append(correct.item() / float(points.size()[0]))
            # loss is still a tensor here, so backward() can be applied
            loss.backward()
            optimizer.step()
            global_step += 1

        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(),
                                           testDataLoader,
                                           num_class=num_class)

            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' %
                       (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' %
                       (best_instance_acc, best_class_acc))

            if (instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
Example #10
def train_one_epoch(sess, ops, logger):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    log_string(str(datetime.now()))

    # Make sure batch data is of same size
    cur_batch_data = np.zeros((BATCH_SIZE, NUM_POINT,
                               TRAIN_DATASET.num_channel()))
    cur_batch_label = np.zeros((BATCH_SIZE), dtype=np.int32)

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    batch_idx = 0
    avg_loss, avg_acc = 0.0, 0.0  # last windowed averages, logged after the loop
    while TRAIN_DATASET.has_next_batch():
        # batch_data, batch_label = TRAIN_DATASET.next_batch(augment=True)
        # assert FLAGS.aug
        batch_data, batch_label = TRAIN_DATASET.next_batch(augment=FLAGS.aug)

        if FLAGS.dropout:
            batch_data = provider.random_point_dropout(batch_data)

        bsize = batch_data.shape[0]
        # offset = np.zeros([bsize, batch_data.shape[1], NUM_NEIGHBORS, 3])
        # for b in range(bsize):
        #     _data = batch_data[b]  # [N, 3]
        #     nbrs = NearestNeighbors(
        #         n_neighbors=NUM_NEIGHBORS, algorithm='ball_tree').fit(_data)
        #     _, idx = nbrs.kneighbors(_data)
        #     offset[b] = _data[idx] - _data[:, np.newaxis, :]

        cur_batch_data[0:bsize, ...] = batch_data
        cur_batch_label[0:bsize] = batch_label

        feed_dict = {
            ops['pointclouds_pl']: cur_batch_data,
            ops['labels_pl']: cur_batch_label,
            ops['is_training_pl']: is_training,
            # ops['offset_pl']: offset
        }
        summary, step, _, loss_val, pred_val = sess.run(
            [
                ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                ops['pred']
            ],
            feed_dict=feed_dict)
        # train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 1)
        correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])
        total_correct += correct
        total_seen += bsize
        loss_sum += loss_val
        if (batch_idx + 1) % 50 == 0:
            log_string(' ---- batch: %03d ----' % (batch_idx + 1))
            avg_loss = loss_sum / 50
            avg_acc = total_correct / float(total_seen)
            log_string('mean loss: %f' % avg_loss)
            log_string('accuracy: %f' % avg_acc)
            total_correct = 0
            total_seen = 0
            loss_sum = 0
        batch_idx += 1

    logger.log('train loss', EPOCH_CNT, avg_loss)
    logger.log('train acc', EPOCH_CNT, avg_acc)
    logger.flush()

    TRAIN_DATASET.reset()
Example #11
def main(args):
    # log both to the log file and to the console
    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu  # GPU id(s) to use (multi-GPU is possible)
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))  # current time
    experiment_dir = Path('./log/')  # wrap the directory in a Path
    experiment_dir.mkdir(exist_ok=True)  # create ./log/; exist_ok=True avoids an error if it exists
    # create the classification directory
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')  # checkpoints directory
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    # INFO level: log info, warning, error and critical messages
    logger.setLevel(logging.INFO)
    # log record format
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # log file path: log_dir/<args.model>.txt
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    DATA_PATH = 'data/modelnet40_normal_resampled/'

    TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH,
                                       npoint=args.num_point,
                                       split='train',
                                       normal_channel=args.normal)
    TEST_DATASET = ModelNetDataLoader(root=DATA_PATH,
                                      npoint=args.num_point,
                                      split='test',
                                      normal_channel=args.normal)
    import torch.utils.data.dataloader
    # loader for TRAIN_DATASET: shuffle=True randomizes order, num_workers enables multi-process loading
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=4)
    # the test set does not need shuffling
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=4)
    '''MODEL LOADING'''
    num_class = 40
    # dynamically import the model module (equivalent to "import <args.model>")
    MODEL = importlib.import_module(args.model)
    # copy the source files (with permissions) into experiment_dir, e.g. 'log/classification/pointnet2_cls_msg'
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))

    # call into the model module (e.g. pointnet_cls.py)
    classifier = MODEL.get_model(num_class, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    # resume training after an interruption
    try:
        # a .pth checkpoint is stored as a dict:
        # the model state_dict maps every layer with trainable parameters (conv, linear, ...) to its weights and biases;
        # the optimizer also has a state_dict holding its state and hyperparameters (lr, momentum, weight_decay, ...)
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        # restore the epoch counter
        start_epoch = checkpoint['epoch']
        # read model_state_dict from the checkpoint and restore the parameters via load_state_dict
        classifier.load_state_dict(checkpoint['model_state_dict'])
        # write to the log
        log_string('Use pretrain model')
    except:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    # create the optimizer; it holds the current parameter state and updates it from the computed gradients
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            # the parameters to optimize
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)
    # decay the learning rate once every step_size epochs
    # see https://blog.csdn.net/qyhaill/article/details/103043637
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    '''TRAINING'''
    # write to the log
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))

        # step the lr scheduler (adjusted once per epoch)
        scheduler.step()
        # total: iteration count (defaults to the iterable's length); smoothing: 0 = average speed, 1 = current speed
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            # unpack the batch from trainDataLoader: points and their target labels
            points, target = data
            # convert the tensor to numpy
            points = points.data.numpy()
            '''data augmentation'''
            # randomly drop points
            points = provider.random_point_dropout(points)
            # random scaling
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            # randomly shift the point cloud
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            # convert back to a tensor
            points = torch.Tensor(points)
            # extract the target labels
            target = target[:, 0]
            # the network expects (N, 3) input, hence the transpose
            points = points.transpose(2, 1)
            # move to GPU memory
            points, target = points.cuda(), target.cuda()
            # zero the gradients
            optimizer.zero_grad()
            # training mode
            classifier = classifier.train()
            # forward pass through pointnet_cls, which returns two values
            pred, trans_feat = classifier(points)
            # loss
            loss = criterion(pred, target.long(), trans_feat)
            # index of the largest score per sample
            pred_choice = pred.data.max(1)[1]
            # compare pred_choice against the target labels (boolean) and sum to count correct predictions
            correct = pred_choice.eq(target.long().data).cpu().sum()
            # correct.item() extracts the scalar; append the batch accuracy to mean_correct
            mean_correct.append(correct.item() / float(points.size()[0]))
            # backpropagation
            loss.backward()
            # update the parameters
            optimizer.step()
            global_step += 1

        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        # no gradients are needed inside this block (test mode)
        with torch.no_grad():
            # classifier.eval() disables dropout and BN updates for testing
            instance_acc, class_acc = test(classifier.eval(), testDataLoader)

            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' %
                       (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' %
                       (best_instance_acc, best_class_acc))

            if (instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
Example #12
def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.dataset + "-" + args.task + "-" + args.norm

    if args.dataset == "shapenet":
        args.num_class = 16
    else:
        args.num_class = 40

    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    # experiment_dir = Path('./exp/v2/')
    experiment_dir = Path('/data-x/g12/zhangjie/3dIP/exp3.0/v2')
    # experiment_dir = Path('/data-x/g10/zhangjie/3D/exp/v2')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark + "_" + timestr)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG_curve'''
    title = args.dataset + "-" + args.task + "-" + args.norm
    logger_loss = Logger(os.path.join(log_dir, 'log_loss.txt'), title=title)
    logger_loss.set_names([
        'Train AVE Loss', 'Train Public Loss', 'Train Private Loss',
        'Valid AVE Loss', 'Valid Public Loss', 'Valid Private Loss'
    ])
    logger_acc = Logger(os.path.join(log_dir, 'log_acc.txt'), title=title)
    logger_acc.set_names([
        'Train Public Acc.', 'Train Private Acc.', 'Test Public Acc.',
        'Test Private Acc.'
    ])
    '''LOG'''  # create the log file
    logger = logging.getLogger("Model")  # logger name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)  # minimum level written to the file
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # attach the log file handler
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    if args.dataset == "shapenet":
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=True,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=True,
                                                batchsize=args.batch_size)
    else:
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=False,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=False,
                                                batchsize=args.batch_size)

    log_string('Finished ...')
    log_string('Load model ...')
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)
    # dynamic import: the model module to load is decided at runtime from args.model instead of importing everything up front

    # copy the model source files into experiment_dir
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('train_2_cls.py', str(experiment_dir))
    shutil.copy('./data/getData.py', str(experiment_dir))
    shutil.copytree('./models/layers', str(experiment_dir) + "/layers")

    classifier = MODEL.get_model(num_class, channel=3).cuda()
    criterion = MODEL.get_loss().cuda()

    pprint(classifier)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    mean_correct2 = []
    mean_loss = []
    mean_loss1 = []
    mean_loss2 = []
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        time_start = datetime.datetime.now()
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))

        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            points = points.data.numpy()
            # provider is a hand-written point-cloud utility module; random_point_dropout
            # replaces dropped points with the value of the first point
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])  # random scaling
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])  # random shift
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()

            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    m.reset()

            loss1 = torch.tensor(0.).cuda()
            loss2 = torch.tensor(0.).cuda()
            sign_loss = torch.tensor(0.).cuda()

            for ind in range(2):
                if ind == 0:
                    pred, trans_feat = classifier(points, ind=ind)
                    loss1 = criterion(pred, target.long(), trans_feat)
                    mean_loss1.append(loss1.item() / float(points.size()[0]))
                    pred_choice = pred.data.max(1)[1]
                    correct = pred_choice.eq(target.long().data).cpu().sum()
                    mean_correct.append(correct.item() /
                                        float(points.size()[0]))

                else:
                    pred2, trans_feat2 = classifier(points, ind=ind)
                    loss2 = criterion(pred2, target.long(), trans_feat2)
                    mean_loss2.append(loss2.item() / float(points.size()[0]))
                    pred_choice2 = pred2.data.max(1)[1]
                    correct2 = pred_choice2.eq(target.long().data).cpu().sum()
                    mean_correct2.append(correct2.item() /
                                         float(points.size()[0]))

            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    sign_loss += m.loss

            loss = args.beta * loss1 + loss2 + sign_loss
            mean_loss.append(loss.item() / float(points.size()[0]))

            # loss = loss2
            loss.backward()
            optimizer.step()
            global_step += 1

        train_instance_acc = np.mean(mean_correct)
        train_instance_acc2 = np.mean(mean_correct2)
        train_instance_acc_ave = (train_instance_acc + train_instance_acc2) / 2
        train_loss = np.mean(mean_loss) / 2
        train_loss1 = np.mean(mean_loss1)
        train_loss2 = np.mean(mean_loss2)

        log_string('Train Instance Public Accuracy: %f' % train_instance_acc)
        log_string('Train Instance Private Accuracy: %f' % train_instance_acc2)

        sign_acc = torch.tensor(0.).cuda()
        count = 0

        for m in classifier.modules():
            if isinstance(m, SignLoss):
                sign_acc += m.acc
                count += 1

        if count != 0:
            sign_acc /= count

        log_string('Sign Accuracy: %f' % sign_acc)

        with torch.no_grad():
            for ind in range(2):
                if ind == 0:
                    val_loss1, test_instance_acc1, class_acc1 = test(
                        classifier,
                        testDataLoader,
                        num_class=args.num_class,
                        ind=0)
                else:
                    val_loss2, test_instance_acc2, class_acc2 = test(
                        classifier,
                        testDataLoader,
                        num_class=args.num_class,
                        ind=1)

            log_string(
                'Test Instance Public Accuracy: %f, Class Public Accuracy: %f'
                % (test_instance_acc1, class_acc1))
            log_string(
                'Test Instance Private Accuracy: %f, Class Private Accuracy: %f'
                % (test_instance_acc2, class_acc2))

            val_loss = (val_loss1 + val_loss2) / 2
            test_instance_acc = (test_instance_acc1 + test_instance_acc2) / 2
            class_acc = (class_acc1 + class_acc2) / 2

            if (test_instance_acc >= best_instance_acc):
                best_instance_acc = test_instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string(
                'Test Instance Average Accuracy: %f, Class Average Accuracy: %f'
                % (test_instance_acc, class_acc))
            log_string(
                'Best Instance Average Accuracy: %f, Class Average Accuracy: %f'
                % (best_instance_acc, best_class_acc))

            if (test_instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                log_string('best_epoch %s' % str(best_epoch))
                state = {
                    'epoch': best_epoch,
                    'instance_acc': test_instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

        logger_loss.append([
            train_loss, train_loss1, train_loss2, val_loss, val_loss1,
            val_loss2
        ])
        logger_acc.append([
            train_instance_acc, train_instance_acc2, test_instance_acc1,
            test_instance_acc2
        ])

        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time: %s s' % time_span_str)

    logger_loss.close()
    logger_loss.plot()
    savefig(os.path.join(log_dir, 'log_loss.eps'))
    logger_acc.close()
    logger_acc.plot()
    savefig(os.path.join(log_dir, 'log_acc.eps'))

    log_string('best_epoch %s' % str(best_epoch))

    logger.info('End of training...')
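
The loop above treats SignLoss as a plug-in module: it sums each instance's .loss into the training objective and averages .acc for the sign-accuracy log. The module itself is not part of the snippet; below is a minimal sketch consistent with that usage, assuming a hinge-style sign-embedding loss on a scale vector. Only the .loss and .acc attribute names are fixed by the code above; everything else is an assumption.

import torch
import torch.nn as nn


class SignLoss(nn.Module):
    """Hinge loss pushing each entry of a scale vector toward a target sign.

    Sketch only: the training loop above requires just the .loss and .acc
    attributes; the actual formulation in the source repo may differ.
    """

    def __init__(self, alpha, b):
        super().__init__()
        self.alpha = alpha                    # weight of the sign loss
        self.register_buffer('b', b.sign())   # target signs in {-1, +1}
        self.loss = torch.tensor(0.)
        self.acc = torch.tensor(0.)

    def forward(self, scale):
        # Margin hinge: becomes zero once every entry agrees with its target sign.
        self.loss = self.alpha * torch.relu(0.1 - self.b * scale).sum()
        self.acc = (scale.sign() == self.b).float().mean()
        return self.loss
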
Example #13
def main(args):
    def log_string(str):
        logger.info(str)
        print(str)

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('reg_seg_heatmap_v3')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    # Construct the dataset
    train_dataset, train_config = construct_dataset(is_train=True)
    # Random split
    train_set_size = int(len(train_dataset) * 0.8)
    valid_set_size = len(train_dataset) - train_set_size
    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_dataset, [train_set_size, valid_set_size])
    # And the dataloader
    trainDataLoader = DataLoader(dataset=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=4)
    validDataLoader = DataLoader(dataset=valid_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4)
    '''MODEL LOADING'''
    out_channel = args.out_channel
    model = importlib.import_module(args.model)
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_pointnet2_reg_seg_heatmap_stepsize.py', str(exp_dir))

    #network = model.get_model(out_channel, normal_channel=args.use_normals)
    network = model.get_model(out_channel)
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss()

    network.apply(inplace_relu)

    if not args.use_cpu:
        network = network.cuda()
        criterion_rmse = criterion_rmse.cuda()
        criterion_cos = criterion_cos.cuda()
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        network.load_state_dict(checkpoint['model_state_dict'])
        log_string('Using pretrained model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(network.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_rot_error = 99.9
    best_xyz_error = 99.9
    best_heatmap_error = 99.9
    best_step_size_error = 99.9
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        train_rot_error = []
        train_xyz_error = []
        train_heatmap_error = []
        train_step_size_error = []
        network = network.train()

        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            optimizer.zero_grad()

            points = data[parameter.pcd_key].numpy()
            points = provider.normalize_data(points)
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :,
                                                                         0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            heatmap_target = data[parameter.heatmap_key]
            segmentation_target = data[parameter.segmentation_key]
            #print('heatmap size', heatmap_target.size())
            #print('segmentation', segmentation_target.size())
            delta_rot = data[parameter.delta_rot_key]
            delta_xyz = data[parameter.delta_xyz_key]
            unit_delta_xyz = data[parameter.unit_delta_xyz_key]
            step_size = data[parameter.step_size_key]

            if not args.use_cpu:
                points = points.cuda()
                delta_rot = delta_rot.cuda()
                delta_xyz = delta_xyz.cuda()
                heatmap_target = heatmap_target.cuda()
                unit_delta_xyz = unit_delta_xyz.cuda()
                step_size = step_size.cuda()

            heatmap_pred, action_pred, step_size_pred = network(points)
            # action control
            delta_rot_pred_6d = action_pred[:, 0:6]
            delta_rot_pred = compute_rotation_matrix_from_ortho6d(
                delta_rot_pred_6d, args.use_cpu)  # batch*3*3
            delta_xyz_pred = action_pred[:, 6:9].view(-1, 3)  # batch*3

            # loss computation
            loss_heatmap = criterion_rmse(heatmap_pred, heatmap_target)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            #loss_t = (1-criterion_cos(delta_xyz_pred, delta_xyz)).mean() + criterion_rmse(delta_xyz_pred, delta_xyz)
            loss_t = (1 - criterion_cos(delta_xyz_pred, unit_delta_xyz)).mean()
            loss_step_size = criterion_bce(step_size_pred, step_size)
            loss = loss_r + loss_t + loss_heatmap + loss_step_size
            loss.backward()
            optimizer.step()
            global_step += 1

            train_rot_error.append(loss_r.item())
            train_xyz_error.append(loss_t.item())
            train_heatmap_error.append(loss_heatmap.item())
            train_step_size_error.append(loss_step_size.item())

        train_rot_error = sum(train_rot_error) / len(train_rot_error)
        train_xyz_error = sum(train_xyz_error) / len(train_xyz_error)
        train_heatmap_error = sum(train_heatmap_error) / len(
            train_heatmap_error)
        train_step_size_error = sum(train_step_size_error) / len(
            train_step_size_error)
        log_string('Train Rotation Error: %f' % train_rot_error)
        log_string('Train Translation Error: %f' % train_xyz_error)
        log_string('Train Heatmap Error: %f' % train_heatmap_error)
        log_string('Train Step size Error: %f' % train_step_size_error)

        with torch.no_grad():
            rot_error, xyz_error, heatmap_error, step_size_error = test(
                network.eval(), validDataLoader, out_channel, criterion_rmse,
                criterion_cos, criterion_bce)

            log_string(
                'Test Rotation Error: %f, Translation Error: %f, Heatmap Error: %f, Step size Error: %f'
                % (rot_error, xyz_error, heatmap_error, step_size_error))
            log_string(
                'Best Rotation Error: %f, Translation Error: %f, Heatmap Error: %f, Step size Error: %f'
                % (best_rot_error, best_xyz_error, best_heatmap_error,
                   best_step_size_error))

            if (rot_error + xyz_error + heatmap_error + step_size_error) < (
                    best_rot_error + best_xyz_error + best_heatmap_error +
                    best_step_size_error):
                best_rot_error = rot_error
                best_xyz_error = xyz_error
                best_heatmap_error = heatmap_error
                best_step_size_error = step_size_error
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'rot_error': rot_error,
                    'xyz_error': xyz_error,
                    'heatmap_error': heatmap_error,
                    'step_size_error': step_size_error,
                    'model_state_dict': network.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
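
Example #13 above relies on a compute_rotation_matrix_from_ortho6d helper to turn the first six entries of the action head into a valid rotation matrix. The standard construction (Zhou et al., "On the Continuity of Rotation Representations in Neural Networks") Gram-Schmidt-orthogonalizes the two predicted 3-vectors; a sketch under that assumption, keeping the use_cpu flag only for signature compatibility with the call site:

import torch
import torch.nn.functional as F


def compute_rotation_matrix_from_ortho6d(ortho6d, use_cpu=False):
    """Map (B, 6) vectors to (B, 3, 3) rotation matrices via Gram-Schmidt.

    Sketch of the standard 6-D construction; use_cpu is unused here and the
    computation stays on whatever device ortho6d lives on.
    """
    x_raw = ortho6d[:, 0:3]
    y_raw = ortho6d[:, 3:6]
    x = F.normalize(x_raw, dim=1)                         # first basis vector
    z = F.normalize(torch.cross(x, y_raw, dim=1), dim=1)  # orthogonal to x and y_raw
    y = torch.cross(z, x, dim=1)                          # completes a right-handed frame
    return torch.stack((x, y, z), dim=2)                  # columns are x, y, z
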
Example #14
def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.dataset + "-" + args.task + "-" + args.norm

    if args.dataset == "shapenet":
        args.num_class = 16
    else:
        args.num_class = 40

    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    # experiment_dir = Path('./exp/v1/')
    experiment_dir = Path('/data-x/g12/zhangjie/3dIP/exp/v1')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark + "_" + timestr)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG_curve'''
    title = args.dataset + "-" + args.task + "-" + args.norm
    logger_loss = Logger(os.path.join(log_dir, 'log_loss_v1.txt'), title=title)
    logger_loss.set_names(
        ['Train Loss', 'Valid Clean Loss', 'Valid Trigger Loss'])
    logger_acc = Logger(os.path.join(log_dir, 'log_acc_v1.txt'), title=title)
    logger_acc.set_names(
        ['Train  Acc.', 'Valid Clean Acc.', 'Valid Trigger Acc.'])
    '''LOG'''  # create the log file
    logger = logging.getLogger("Model")  # logger name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)  # minimum level written to the file
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # attach the file handler
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    if args.dataset == "shapenet":
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=True,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=True,
                                                batchsize=args.batch_size)
        triggerDataLoader = getData2.get_dataLoader(Shapenet=True,
                                                    batchsize=args.batch_size)
    else:
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=False,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=False,
                                                batchsize=args.batch_size)
        triggerDataLoader = getData2.get_dataLoader(Shapenet=False,
                                                    batchsize=args.batch_size)

    wminputs, wmtargets = [], []
    for wm_idx, (wminput, wmtarget) in enumerate(triggerDataLoader):
        wminputs.append(wminput)
        wmtargets.append(wmtarget)
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)

    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('train_1_cls.py', str(experiment_dir))
    shutil.copy('./data/getData.py', str(experiment_dir))
    shutil.copy('./data/getData2.py', str(experiment_dir))
    shutil.copytree('./models/layers', str(experiment_dir) + "/layers")

    classifier = MODEL.get_model(num_class, channel=3).cuda()
    # classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    pprint(classifier)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Using pretrained model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    mean_loss = []
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        time_start = datetime.datetime.now()
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))

        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            wm_id = np.random.randint(len(wminputs))
            points = torch.cat(
                [points, wminputs[(wm_id + batch_id) % len(wminputs)]],
                dim=0)  # randomly pick a watermark batch and concatenate it with the clean batch
            target = torch.cat(
                [target, wmtargets[(wm_id + batch_id) % len(wminputs)]], dim=0)

            points = points.data.numpy()
            # provider is a hand-written point-cloud utility; random dropout
            # overwrites dropped points with the value of the first point
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(
                points[:, :, 0:3])  # random point-cloud scaling
            points[:, :, 0:3] = provider.shift_point_cloud(
                points[:, :, 0:3])  # random point-cloud shift
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()

            optimizer.zero_grad()
            classifier = classifier.train()
            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1

            mean_loss.append(loss.item() / float(points.size()[0]))

        train_loss = np.mean(mean_loss)
        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        with torch.no_grad():
            val_loss, instance_acc, class_acc = test(classifier,
                                                     testDataLoader,
                                                     num_class=args.num_class)
            val_loss2, instance_acc2, class_acc2 = test(
                classifier, triggerDataLoader, num_class=args.num_class)

            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string('Test Clean Instance Accuracy: %f, Class Accuracy: %f' %
                       (instance_acc, class_acc))
            log_string('Best Clean Instance Accuracy: %f, Class Accuracy: %f' %
                       (best_instance_acc, best_class_acc))
            log_string(
                'Test Trigger Accuracy: %f, Trigger Class Accuracy: %f' %
                (instance_acc2, class_acc2))

            if (instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                log_string('best_epoch %s' % str(best_epoch))
                state = {
                    'epoch': best_epoch,
                    'clean instance_acc': instance_acc,
                    'clean class_acc': class_acc,
                    'trigger instance_acc': instance_acc2,
                    'trigger class_acc': class_acc2,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

        logger_loss.append([train_loss, val_loss, val_loss2])
        logger_acc.append([train_instance_acc, instance_acc, instance_acc2])

        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time: %s s' % time_span_str)

    logger_loss.close()
    logger_loss.plot()
    savefig(os.path.join(log_dir, 'log_loss_v3.eps'))
    logger_acc.close()
    logger_acc.plot()
    savefig(os.path.join(log_dir, 'log_acc_v3.eps'))

    log_string('best_epoch %s' % str(best_epoch))
    logger.info('End of training...')
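
The example above (and Example #19 below) records its training curves through a Logger object driven by set_names/append/plot/close plus matplotlib's savefig. The class is not shown in these snippets; a minimal sketch compatible with the calls, writing tab-separated columns with one row per epoch (an assumption, not the original implementation):

import matplotlib.pyplot as plt
from matplotlib.pyplot import savefig  # used unqualified in the snippets


class Logger:
    """Tab-separated metric log: one column per name, one row per epoch."""

    def __init__(self, fpath, title=None):
        self.title = title
        self.names = []
        self.numbers = {}
        self.file = open(fpath, 'w')

    def set_names(self, names):
        self.names = names
        self.numbers = {name: [] for name in names}
        self.file.write('\t'.join(names) + '\n')
        self.file.flush()

    def append(self, values):
        assert len(values) == len(self.names)
        for name, value in zip(self.names, values):
            self.numbers[name].append(float(value))
        self.file.write('\t'.join('%.6f' % float(v) for v in values) + '\n')
        self.file.flush()

    def plot(self):
        # New figure each time: the snippets call plot()/savefig() for the
        # loss and accuracy loggers back to back, after close().
        plt.figure()
        for name in self.names:
            plt.plot(self.numbers[name], label=name)
        plt.legend()
        plt.title(self.title or '')

    def close(self):
        self.file.close()
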
Example #15
def main(args):
    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    if not args.reduced_computation:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
        num_workers = 8
    else:
        num_workers = 0


    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path(args.experiment_dir)
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    tensorboard_dir = experiment_dir.joinpath('tensorboard/')
    tensorboard_dir.mkdir(exist_ok=True)

    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    '''TENSORBOARD'''
    train_writer = SummaryWriter(tensorboard_dir.joinpath("train"))
    val_writer = SummaryWriter(tensorboard_dir.joinpath("validation"))

    '''DATA LOADING'''
    log_string('Load dataset ...')
    DATA_PATH = args.data_path

    class_in_filename = args.data_extension != ".npy"

    TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, split_name=args.split_name, extension=args.data_extension, npoint=args.num_point, split='train',
                                       normal_channel=args.normal, class_in_filename=class_in_filename, uniform=args.uniform, voxel_size=args.reduced_resolution_voxel_size)
    TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, split_name=args.split_name, extension=args.data_extension, npoint=args.num_point, split='validation',
                                      normal_channel=args.normal, class_in_filename=class_in_filename, uniform=args.uniform, voxel_size=args.reduced_resolution_voxel_size)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=num_workers)

    '''MODEL LOADING'''
    num_class = args.num_classes
    MODEL = importlib.import_module(args.model)
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))

    classifier = MODEL.get_model(num_class, normal_channel=args.normal)
    if not args.reduced_computation:
        classifier = classifier.cuda()
    criterion = MODEL.get_loss()
    if not args.reduced_computation:
        criterion = criterion.cuda()

    try:
        checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Using pretrained model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0


    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10, min_lr=0.000001)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []

    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch,args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))

        mean_correct = []
        batch_tqdm = tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9)
        total_loss = 0
        predictions_likelihood_tot = torch.zeros([len(trainDataLoader.dataset), num_class])

        for batch_id, data in batch_tqdm:
            points, target = data
            points = points.data.numpy()
            if args.augment:
                points = provider.random_point_dropout(points)
                points[:,:, 0:3] = provider.random_scale_point_cloud(points[:,:, 0:3])
                points[:,:, 0:3] = provider.shift_point_cloud(points[:,:, 0:3])
            points = torch.Tensor(points)
            target = target[:, 0]

            points = points.transpose(2, 1)
            if not args.reduced_computation:
                points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            classifier = classifier.train()
            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1
            total_loss += loss.item()  # use .item() so the autograd graph is not retained across batches
            mean_loss = total_loss / (batch_id + 1)
            batch_tqdm.set_description(f"loss {mean_loss}, batch ({batch_id}/{len(trainDataLoader)})")
            preds_likelihood = torch.exp(pred)
            predictions_likelihood_tot[batch_id*trainDataLoader.batch_size:(batch_id+1)*trainDataLoader.batch_size] = preds_likelihood

        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)
        train_writer.add_scalar('Loss', mean_loss, epoch)
        train_writer.add_scalar('Accuracy', train_instance_acc, epoch)
        for cls in range(num_class):
            train_writer.add_histogram(f"class_{cls}", predictions_likelihood_tot[:, cls], epoch)

        with torch.no_grad():
            instance_acc, class_acc, val_loss = test(classifier.eval(), testDataLoader, criterion, num_class=num_class)
            scheduler.step(val_loss)
            val_writer.add_scalar('Loss', val_loss, epoch)
            val_writer.add_scalar('Accuracy', instance_acc, epoch)
            val_writer.add_scalar('Class_Accuracy', class_acc, epoch)

            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f'% (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f'% (best_instance_acc, best_class_acc))

            if (instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s'% savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
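
Example #15 calls a test helper that returns (instance_acc, class_acc, val_loss). A sketch of such an evaluation loop, mirroring the accuracy bookkeeping visible in the commented-out test() of Example #16 below and adding the mean loss term; device handling is simplified and assumed:

import numpy as np
import torch
from tqdm import tqdm


def test(model, loader, criterion, num_class=40):
    """Sketch of the evaluation helper assumed above.

    Returns (instance_acc, class_acc, val_loss); per-class accuracy is
    accumulated as [sum of per-batch accuracies, batch count, mean].
    """
    mean_correct, losses = [], []
    class_acc = np.zeros((num_class, 3))
    for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
        target = target[:, 0]
        points = points.transpose(2, 1)
        if next(model.parameters()).is_cuda:  # follow the model's device
            points, target = points.cuda(), target.cuda()
        pred, trans_feat = model(points)
        losses.append(criterion(pred, target.long(), trans_feat).item())
        pred_choice = pred.data.max(1)[1]
        for cat in np.unique(target.cpu()):
            mask = target == cat
            classacc = pred_choice[mask].eq(target[mask].long().data).cpu().sum()
            class_acc[cat, 0] += classacc.item() / float(points[mask].size()[0])
            class_acc[cat, 1] += 1
        correct = pred_choice.eq(target.long().data).cpu().sum()
        mean_correct.append(correct.item() / float(points.size()[0]))
    class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]
    return np.mean(mean_correct), np.mean(class_acc[:, 2]), float(np.mean(losses))
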
Example #16
    def train_pointnet(self, trainDataLoader, testDataLoader):
        def log_string(str):
            logger.info(str)
            print(str)

        '''LOG'''
        logger = logging.getLogger("Model")
        logger.setLevel(logging.INFO)

        #         def test(model, loader, num_class=40):
        #             mean_correct = []
        #             class_acc = np.zeros((num_class,3))
        #             for j, data in tqdm(enumerate(loader), total=len(loader)):
        #                 points, target = data
        #                 target = target[:, 0]
        #                 points = points.transpose(2, 1)
        #                 points, target = points.cuda(), target.cuda()
        #                 classifier = model.eval()
        #                 pred, _ = classifier(points)
        #                 pred_choice = pred.data.max(1)[1]
        #                 for cat in np.unique(target.cpu()):
        #                     classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum()
        #                     class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0])
        #                     class_acc[cat,1]+=1
        #                 correct = pred_choice.eq(target.long().data).cpu().sum()
        #                 mean_correct.append(correct.item()/float(points.size()[0]))
        #             class_acc[:,2] =  class_acc[:,0]/ class_acc[:,1]
        #             class_acc = np.mean(class_acc[:,2])
        #             instance_acc = np.mean(mean_correct)
        #             return instance_acc, class_acc

        scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                    step_size=20,
                                                    gamma=0.7)
        global_epoch = 0
        global_step = 0
        best_instance_acc = 0.0
        best_class_acc = 0.0
        mean_correct = []

        for epoch in range(0, self.max_epochs):
            log_string('Epoch %d (%d/%s):' %
                       (global_epoch + 1, epoch + 1, self.max_epochs))
            scheduler.step()
            # test_acc.get_acc()  # NOTE: test_acc is undefined in this snippet; commented out
            for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                       total=len(trainDataLoader),
                                       smoothing=0.9):
                points, target = data
                points = points.data.numpy()
                points = provider.random_point_dropout(points)
                points[:, :,
                       0:3] = provider.random_scale_point_cloud(points[:, :,
                                                                       0:3])
                points[:, :, 0:3] = provider.shift_point_cloud(points[:, :,
                                                                      0:3])
                points = torch.Tensor(points)
                target = target[:, 0]
                points = points.transpose(2, 1)
                points, target = points.cuda(), target.cuda()
                points1, target1 = data
                points1 = points1.data.numpy()
                points1 = provider.random_point_dropout(points1)
                points1[:, :,
                        0:3] = provider.random_scale_point_cloud(points1[:, :,
                                                                         0:3])
                points1[:, :, 0:3] = provider.shift_point_cloud(points1[:, :,
                                                                        0:3])
                points1 = torch.Tensor(points1)
                target1 = target1[:, 0]
                points1 = points1.transpose(2, 1)
                points1, target1 = points1.cuda(), target1.cuda()
                loss = self.update(points, target, points1, testDataLoader)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self._update_target_network_parameters()
            self.save_model(os.path.join('checkpoints', 'model.pth'))
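train_pointnet above calls self._update_target_network_parameters() after every optimizer step; in BYOL-style self-supervision this is an exponential moving average of the online weights into a frozen target network. A sketch of that update as a free function, with the momentum value assumed:

import torch


@torch.no_grad()
def update_target_network_parameters(online_network, target_network, m=0.996):
    """EMA update: target <- m * target + (1 - m) * online.

    Sketch of what self._update_target_network_parameters() presumably does;
    the object, attribute names, and momentum value are assumptions.
    """
    for online_p, target_p in zip(online_network.parameters(),
                                  target_network.parameters()):
        target_p.data.mul_(m).add_(online_p.data, alpha=1.0 - m)
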
Example #17
def main(args):
    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''SET THE SEED'''
    setup_seed(args.seed)
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    # log_dir = experiment_dir.joinpath('logs/')
    log_dir = experiment_dir.joinpath('./')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA TYPE'''
    if args.use_voxel:
        assert "voxel" in args.dataset
        assert "mink" in args.model
        assert args.voxel_size > 0
    '''AUX SUPERVISION TYPE'''
    if args.aux == "pred":
        assert args.pred_path is not None

    if args.pred_path is not None:
        assert args.aux == "pred"

    args.with_pred = None
    args.with_instance = False
    args.with_seg = False
    if args.aux is not None:
        args.with_aux = True
        assert "scannet" in args.dataset
        if args.aux == "pred":
            args.with_pred = args.pred_path
        elif args.aux == "instance":
            args.with_instance = True
        elif args.aux == "seg":
            args.with_seg = True
        else:
            raise NotImplementedError
    else:
        args.with_aux = False
    '''DATA LOADING'''
    if "modelnet" in args.dataset:
        '''
        ModelNet40 loading; supports both the point and ME-point versions.
        '''
        if not "voxel" in args.dataset:
            log_string('Load dataset {}'.format(args.dataset))
            num_class = 40
            DATA_PATH = './data/modelnet40_normal_resampled/'
            TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH,
                                               npoint=args.num_point,
                                               split='train',
                                               normal_channel=args.normal,
                                               apply_aug=True)
            TEST_DATASET = ModelNetDataLoader(root=DATA_PATH,
                                              npoint=args.num_point,
                                              split='test',
                                              normal_channel=args.normal)

            trainDataLoader = torch.utils.data.DataLoader(
                TRAIN_DATASET,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_worker)
            testDataLoader = torch.utils.data.DataLoader(
                TEST_DATASET,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.num_worker)
        else:
            assert args.dataset == 'modelnet_voxel'
            '''
            use the ModelNet example dataloader from the ME engine;
            however, it is still point-based: it returns point features
            and feeds them into a TensorField, so it is not really a
            voxel ModelNet
            '''
            log_string('Load dataset {}'.format(args.dataset))
            num_class = 40
            DATA_PATH = './data/modelnet40_ply_hdf5_2048'

            trainset = ModelNet40H5(
                phase="train",
                transform=CoordinateTransformation(trans=0.2),
                data_root=DATA_PATH,
            )
            testset = ModelNet40H5(
                phase="test",
                transform=None,  # no transform for test
                data_root=DATA_PATH,
            )

            trainDataLoader = DataLoader(
                trainset,
                num_workers=args.num_worker,
                shuffle=True,
                batch_size=args.batch_size,
                collate_fn=minkowski_collate_fn,
                pin_memory=True,
            )

            testDataLoader = DataLoader(
                testset,
                num_workers=args.num_worker,
                shuffle=False,
                batch_size=args.batch_size,
                collate_fn=minkowski_collate_fn,
                pin_memory=True,
            )

    elif args.dataset == "scanobjnn":
        log_string('Load dataset {}'.format(args.dataset))
        num_class = 15
        DATA_PATH = './data/scanobjnn/main_split_nobg'
        TRAIN_DATASET = ScanObjectNNDataLoader(root=DATA_PATH,
                                               npoint=args.num_point,
                                               split='train',
                                               normal_channel=args.normal)
        TEST_DATASET = ScanObjectNNDataLoader(root=DATA_PATH,
                                              npoint=args.num_point,
                                              split='test',
                                              normal_channel=args.normal)
        trainDataLoader = torch.utils.data.DataLoader(
            TRAIN_DATASET,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_worker)
        testDataLoader = torch.utils.data.DataLoader(
            TEST_DATASET,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_worker)

    elif "scannet" in args.dataset:
        num_class = 21
        if not "voxel" in args.dataset:
            if args.mode == "train":
                trainset = ScannetDataset(
                    root='./data/scannet_v2/scannet_pickles',
                    npoints=args.num_point,
                    split='train',
                    with_seg=args.with_seg,
                    with_instance=args.with_instance,
                    with_pred=args.pred_path)
                trainDataLoader = torch.utils.data.DataLoader(
                    trainset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=args.num_worker,
                    pin_memory=True)
            if args.mode == 'export':
                final_trainset = ScannetDatasetWholeScene_evaluation(
                    root='./data/scannet_v2/scannet_pickles',
                    scene_list_dir='./data/scannet_v2/metadata',
                    split='train',
                    block_points=args.num_point,
                    with_rgb=True,
                    with_norm=True,
                    with_seg=args.with_seg,
                    with_instance=args.with_instance,
                    with_pred=args.pred_path,
                    delta=2.0)
                final_train_loader = torch.utils.data.DataLoader(
                    final_trainset,
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=0,
                    pin_memory=True)


            final_testset = ScannetDatasetWholeScene_evaluation(
                root='./data/scannet_v2/scannet_pickles',
                scene_list_dir='./data/scannet_v2/metadata',
                split='eval',
                block_points=args.num_point,
                with_rgb=True,
                with_norm=True,
                with_seg=args.with_seg,
                with_instance=args.with_instance,
                with_pred=args.pred_path,
                delta=1.0)  # DEBUG: change to 1.0 to acquire proper
            final_test_loader = torch.utils.data.DataLoader(
                final_testset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=0,
                pin_memory=True)

            # generate the trainset as whole_dataset for export
        else:
            trainDataLoader = initialize_data_loader(
                DatasetClass=ScannetSparseVoxelizationDataset,
                data_root='data/scannet_v2/scannet_pickles',
                phase="train",
                threads=4,  # num-workers
                shuffle=True,
                repeat=False,
                augment_data=True,
                batch_size=16,
                limit_numpoints=1200000,
            )

            # TODO: the testloader

    else:
        raise NotImplementedError
    '''MODEL LOADING'''
    # copy files
    if args.mode == "train":
        if not os.path.exists(os.path.join(str(experiment_dir), 'model')):
            os.mkdir(os.path.join(str(experiment_dir), 'model'))
        for filename in os.listdir('./model'):
            if ".py" in filename:
                shutil.copy(os.path.join("./model", filename),
                            os.path.join(str(experiment_dir), 'model'))
        shutil.copy("./train_cls.py", str(experiment_dir))

    if "mink" not in args.model:
        # not using a Minkowski net
        N = args.num_point  # same value for both the seg and cls model variants
        MODEL = importlib.import_module(args.model)
        classifier = MODEL.get_model(num_class,
                                     normal_channel=args.normal,
                                     N=N).cuda()
        criterion = MODEL.get_loss().cuda()
        classifier.loss = criterion
    else:
        '''
        The Voxel-based Networks based on the MinkowskiEngine
        '''
        # TODO: should align with the above and use importlib.import_module; maybe fix later
        # classifier = ResNet14(in_channels=3, out_channels=num_class, D=3)  # D is the conv spatial dimension; 3 means 3-D shapes
        if "pointnet" in args.model:
            classifier = MinkowskiPointNet(in_channel=3,
                                           out_channel=41,
                                           embedding_channel=1024,
                                           dimension=3).cuda()
        elif "trans" in args.model:
            classifier = MinkowskiTransformer(in_channel=3,
                                              out_channel=41,
                                              num_class=num_class,
                                              embedding_channel=1024,
                                              dimension=3).cuda()
        criterion = nn.CrossEntropyLoss().cuda()
        classifier.loss = criterion
    '''Loading existing ckpt'''
    try:
        if args.pretrain:
            # FIXME: currently only loads best_model.pth; should support an arbitrary string, maybe later
            checkpoint = torch.load(
                str(experiment_dir) + '/checkpoints/best_model.pth')
            start_epoch = checkpoint['epoch']
            classifier.load_state_dict(checkpoint['model_state_dict'])
            log_string('Using pretrained model')
            start_epoch = 0  # note: the epoch counter restarts even when a pretrained model is loaded
        else:
            start_epoch = 0
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    elif args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(classifier.parameters(), \
                                    lr=args.learning_rate, momentum=0.9,\
                                    weight_decay=args.decay_rate)
    else:
        raise NotImplementedError

    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    # Use MultiStepLR as in paper, decay by 10 at [120, 160]
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[120, 160], gamma=0.1)

    # FIXME: for scannet, now using cosine annealing
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           args.epoch,
                                                           eta_min=0.0)

    global_epoch = 0
    global_step = 0
    if "scannet" in args.dataset:
        best_mIoU = 0.0
    else:
        best_instance_acc = 0.0
        best_class_acc = 0.0
    mean_correct = []

    # only run for one epoch on the eval-only mode
    if args.mode == "eval" or args.mode == "export":
        assert args.pretrain
        start_epoch = 0
        args.epoch = 1

    # '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))

        scheduler.step()
        log_string('Cur LR: {:.5f}'.format(optimizer.param_groups[0]['lr']))
        # when eval-only, skip the training part

        if args.mode == "train":
            '''The main training-loop'''
            for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                       total=len(trainDataLoader),
                                       smoothing=0.9):
                if not args.use_voxel:
                    if "modelnet" in args.dataset:
                        # use points, normal unpacking
                        points, target = data
                        points = points.data.numpy()
                        points = provider.random_point_dropout(points)
                        # points[:,:, 0:3] = provider.random_scale_point_cloud(points[:,:, 0:3])
                        # points[:,:, 0:3] = provider.shift_point_cloud(points[:,:, 0:3])
                        points = torch.Tensor(points)
                        target = target[:, 0]

                        points = points.transpose(2, 1)
                        points, target = points.cuda(), target.cuda()
                    elif "scanobjnn" in args.dataset:
                        points, target, mask = data
                        points = points.data.numpy()
                        # TODO: move the augmentation into the dataset rather than doing it here
                        # points = provider.random_point_dropout(points)
                        points[:, :, 0:3] = provider.random_scale_point_cloud(
                            points[:, :, 0:3])
                        points[:, :,
                               0:3] = provider.shift_point_cloud(points[:, :,
                                                                        0:3])
                        points = torch.Tensor(points)
                        points = points.transpose(2, 1)
                        points, target = points.cuda(), target.cuda()
                    elif "scannet" in args.dataset:
                        # TODO: fill the scannet loading here
                        # TODO: maybe implement gradient accumulation, or simply skip it
                        if args.aux is not None:
                            points, target, sample_weight, aux = data
                            points = points.float().transpose(1, 2).cuda()
                            target = target.cuda()
                            sample_weight = sample_weight.cuda()
                            aux = aux.cuda()
                        else:
                            points, target, sample_weight = data
                            points = points.float().transpose(1, 2).cuda()
                            target = target.cuda()
                            sample_weight = sample_weight.cuda()

                else:
                    if "modelnet" in args.dataset:
                        # use voxel
                        # points = create_input_batch(data, True, 'cuda', quantization_size=args.voxel_size)
                        data['coordinates'][:, 1:] = data[
                            'coordinates'][:, 1:] / args.voxel_size
                        points = ME.TensorField(
                            coordinates=(data['coordinates'].cuda()),
                            features=data['features'].cuda())
                        target = data['labels'].cuda()
                    elif "scannet" in args.dataset:
                        dat = ME.SparseTensor(features=data[1],
                                              coordinates=data[0]).cuda()
                        target = data[2].cuda()

                optimizer.zero_grad()
                '''save the intermediate attention map'''
                # WARNING: DISABLED FOR NOW!!!
                SAVE_INTERVAL = 50
                NUM_PER_EPOCH = 1

                if (epoch + 1) % SAVE_INTERVAL == 0:
                    if batch_id < NUM_PER_EPOCH:
                        classifier.save_flag = True
                    elif batch_id == NUM_PER_EPOCH:
                        intermediate_dict = classifier.save_intermediate()
                        intermediate_path = os.path.join(
                            experiment_dir, "attn")
                        if not os.path.exists(intermediate_path):
                            os.mkdir(intermediate_path)
                        torch.save(
                            intermediate_dict,
                            os.path.join(intermediate_path,
                                         "epoch_{}".format(epoch)))
                        log_string('Saved Intermediate at {}'.format(epoch))
                    else:
                        classifier.save_flag = False
                else:
                    classifier.save_flag = False

                classifier = classifier.train()
                # when with-instance, use instance label to guide the point-transformer training
                if args.aux is not None:
                    pred = classifier(points, aux)
                else:
                    pred = classifier(points)
                # if use_voxel, get the feature from the SparseTensor
                if args.use_voxel:
                    pred = pred.F
                if 'scannet' in args.dataset:
                    loss = criterion(pred, target.long(), sample_weight)
                else:
                    loss = criterion(pred, target.long())
                loss.backward()
                optimizer.step()
                global_step += 1

                if "scannet" in args.dataset:
                    pred_choice = torch.argmax(pred,
                                               dim=2).cpu().numpy()  # B,N
                    target = target.cpu().numpy()
                    correct = np.sum(pred_choice == target)
                    mean_correct.append(correct / pred_choice.size)
                else:
                    pred_choice = pred.data.max(1)[1]
                    correct = pred_choice.eq(target.long().data).cpu().sum()
                    mean_correct.append(correct.item() /
                                        float(points.size()[0]))

            train_instance_acc = np.mean(mean_correct)
            log_string('Train Instance Accuracy: %f' % train_instance_acc)
            '''TEST'''
            if not "scannet" in args.dataset:
                # WARNING: eval is temporarily disabled for scannet; just test at the end
                if (epoch + 1) % 20 == 0:
                    with torch.no_grad():
                        returned_metric = test(classifier.eval(),
                                               testDataLoader,
                                               num_class=num_class,
                                               log_string=log_string)

                    if 'scannet' in args.dataset:
                        mIoU = returned_metric
                        if (mIoU >= best_mIoU):
                            best_mIoU = mIoU
                            best_epoch = epoch + 1

                        if (mIoU >= best_mIoU):
                            logger.info('Save model...')
                            savepath = str(checkpoints_dir) + '/best_model.pth'
                            log_string('Saving at %s' % savepath)
                            state = {
                                'epoch': best_epoch,
                                'mIoU': mIoU,
                                'model_state_dict': classifier.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                            }
                            torch.save(state, savepath)
                    else:
                        instance_acc, class_acc = returned_metric

                        if (instance_acc >= best_instance_acc):
                            best_instance_acc = instance_acc
                            best_epoch = epoch + 1

                        if (class_acc >= best_class_acc):
                            best_class_acc = class_acc

                        log_string(
                            'Test Instance Accuracy: %f, Class Accuracy: %f' %
                            (instance_acc, class_acc))
                        log_string(
                            'Best Instance Accuracy: %f, Class Accuracy: %f' %
                            (best_instance_acc, best_class_acc))

                        if (instance_acc >= best_instance_acc):
                            logger.info('Save model...')
                            savepath = str(checkpoints_dir) + '/best_model.pth'
                            log_string('Saving at %s' % savepath)
                            state = {
                                'epoch': best_epoch,
                                'instance_acc': instance_acc,
                                'class_acc': class_acc,
                                'model_state_dict': classifier.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                            }
                            torch.save(state, savepath)

        global_epoch += 1

    # final save of the model
    logger.info('Save model...')
    savepath = str(checkpoints_dir) + '/final_model.pth'
    log_string('Saving at %s' % savepath)
    state = {
        'epoch': global_epoch,
        # 'mIoU': mIoU,
        'model_state_dict': classifier.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(state, savepath)

    # for the scannet dataset, test at last
    if args.dataset == 'scannet':
        if not os.path.exists(os.path.join(str(experiment_dir), 'pred')):
            os.mkdir(os.path.join(str(experiment_dir), 'pred'))
        if args.mode == "export":
            test_scannet(args,
                         classifier.eval(),
                         final_test_loader,
                         log_string,
                         with_aux=args.with_aux,
                         save_dir=os.path.join(str(experiment_dir), 'pred'),
                         split='eval')
            test_scannet(args,
                         classifier.eval(),
                         final_train_loader,
                         log_string,
                         with_aux=args.with_aux,
                         save_dir=os.path.join(str(experiment_dir), 'pred'),
                         split='train')
        else:
            test_scannet(args,
                         classifier.eval(),
                         final_test_loader,
                         log_string,
                         with_aux=args.with_aux,
                         split='eval')

    # final save of the model
    logger.info('Save model...')
    savepath = str(checkpoints_dir) + '/best_model.pth'  # note: this overwrites the checkpoint saved as best during training
    log_string('Saving at %s' % savepath)
    state = {
        'epoch': global_epoch,
        # 'mIoU': mIoU,
        'model_state_dict': classifier.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(state, savepath)

    logger.info('End of training...')
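
Nearly every example in this collection augments batches through provider.random_point_dropout, random_scale_point_cloud, and shift_point_cloud (Example #17 keeps only the dropout for ModelNet and only scale/shift for ScanObjectNN). In the widely used Pointnet_Pointnet2_pytorch provider.py these are implemented roughly as below; treat this as a sketch of the assumed behavior rather than the exact source:

import numpy as np


def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """'Drop' a random fraction of points by overwriting them with point 0,
    so the tensor shape stays fixed."""
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio
        drop_idx = np.where(np.random.random(batch_pc.shape[1]) <= dropout_ratio)[0]
        if len(drop_idx) > 0:
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]
    return batch_pc


def random_scale_point_cloud(batch_xyz, scale_low=0.8, scale_high=1.25):
    """Scale each cloud in the batch by one random factor."""
    scales = np.random.uniform(scale_low, scale_high, batch_xyz.shape[0])
    return batch_xyz * scales[:, None, None]


def shift_point_cloud(batch_xyz, shift_range=0.1):
    """Translate each cloud in the batch by one random offset."""
    shifts = np.random.uniform(-shift_range, shift_range,
                               (batch_xyz.shape[0], 3))
    return batch_xyz + shifts[:, None, :]
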
Example #18
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    print(str(datetime.now()))

    # Make sure batch data is of same size
    cur_batch_data = np.zeros(
        (BATCH_SIZE, NUM_POINT, TRAIN_DATASET.num_channel()))
    cur_batch_data_ORIGIN = np.zeros((BATCH_SIZE, NUM_POINT, 3))
    cur_batch_label = np.zeros((BATCH_SIZE), dtype=np.int32)

    total_correct = 0
    total_seen = 0
    reconstruct_sum = 0.
    margin_sum = 0.
    mse_sum = np.array([0. for _ in range(FLAGS.iter_routing)])
    l2_loss_sum = 0.
    loss_sum = 0.
    batch_idx = 0
    while TRAIN_DATASET.has_next_batch():
        batch_data, batch_label = TRAIN_DATASET.next_batch(augment=AUGMENT)
        bsize = batch_data.shape[0]

        cur_batch_data_ORIGIN[0:bsize, ...] = copy.deepcopy(batch_data[:, :,
                                                                       0:3])
        batch_data = provider.random_point_dropout(batch_data,
                                                   max_dropout_ratio=0.5)
        #cur_batch_data_ORIGIN[0:bsize,...] = batch_data[:,:,0:3]
        cur_batch_data[0:bsize, ...] = batch_data
        cur_batch_label[0:bsize] = batch_label

        feed_dict = {
            ops['pointclouds_pl']: cur_batch_data,
            ops['labels_pl']: cur_batch_label,
            ops['is_training_pl']: is_training,
            ops['pointclouds_pl_ORIGIN']: cur_batch_data_ORIGIN
        }
        u_v_list, summary, step, _, loss_val, pred_val, margin_loss, reconstruct_loss, l2_loss = sess.run(
            [
                ops['mse'], ops['merged'], ops['step'], ops['train_op'],
                ops['loss'], ops['pred'], ops['margin_loss'],
                ops['reconstruct_loss'], ops['l2_loss']
            ],
            feed_dict=feed_dict)
        #sess.run(ops['reset_b_IJ'])
        train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 1)
        correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])
        mse_sum += u_v_list
        reconstruct_sum += reconstruct_loss
        margin_sum += margin_loss
        total_correct += correct
        total_seen += bsize
        loss_sum += loss_val
        l2_loss_sum += l2_loss
        if (batch_idx + 1) % 50 == 0:
            mse_str = str([round(x, 3) for x in (mse_sum / 50).tolist()])
            print(' ---- batch: %03d ----' % (batch_idx + 1))
            print('mean loss: %.5f | accuracy: %.5f | margin_loss: %.5f | rec_loss: %.5f | mse: %s | l2_loss: %.4f | step: %d' \
                %(loss_sum/50, total_correct/float(total_seen), margin_sum/50, reconstruct_sum/50, mse_str, l2_loss_sum/50, int(step)))
            total_correct = 0
            total_seen = 0
            loss_sum = 0
            reconstruct_sum = 0.
            margin_sum = 0.
            mse_sum *= 0.
            l2_loss_sum = 0.
        batch_idx += 1

    TRAIN_DATASET.reset()
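Example #18 copies each (possibly short) final batch into fixed-size zero buffers and slices metrics back to bsize. A toy illustration of that padding pattern, with assumed shapes:

import numpy as np

BATCH_SIZE, NUM_POINT, NUM_CHANNEL = 32, 1024, 3
cur_batch_data = np.zeros((BATCH_SIZE, NUM_POINT, NUM_CHANNEL))

batch_data = np.random.rand(20, NUM_POINT, NUM_CHANNEL)  # short last batch
bsize = batch_data.shape[0]
cur_batch_data[0:bsize, ...] = batch_data  # rows bsize..BATCH_SIZE-1 stay zero
# Only the first bsize network outputs are real; the padded rows must be
# excluded from the metrics, exactly as in:
#   correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])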
Example #19
def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.dataset + "-" + args.task + "-" + args.norm

    if args.dataset == "shapenet":
        args.num_class = 16
    else:
        args.num_class = 40

    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    # experiment_dir = Path('./exp/v3/')
    if args.task == 'baseline':
        experiment_dir = Path('/data-x/g12/zhangjie/3dIP/baseline')
    else:
        experiment_dir = Path('/data-x/g12/zhangjie/3dIP/ours')

    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark + "_" + timestr)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG_curve'''
    title = args.dataset + "-" + args.task + "-" + args.norm
    logger_loss = Logger(os.path.join(log_dir, 'log_loss_v3.txt'), title=title)
    logger_loss.set_names([
        'Train Pub&Pri  Loss',
        'Train Public Loss',
        'Train Private Loss',
        'Valid Pub-Clean loss',
        'Valid Pub-Trigger Loss',
        'Valid Pri-Clean Loss',
        'Valid Pri-Trigger Loss',
    ])
    logger_acc = Logger(os.path.join(log_dir, 'log_acc_v3.txt'), title=title)
    logger_acc.set_names([
        'Train Pub-Combine  Acc.', 'Valid Pub-Clean Acc.',
        'Valid Pub-Trigger Acc.', 'Train Pri-Combine  Acc.',
        'Valid Pri-Clean Acc.', 'Valid Pri-Trigger Acc.'
    ])
    '''LOG'''  # create the log file
    logger = logging.getLogger("Model")  # logger name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)  # minimum level written to the file
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # attach the file handler
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    if args.dataset == "shapenet":
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=True,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=True,
                                                batchsize=args.batch_size)
        triggerDataLoader = getData2.get_dataLoader(Shapenet=True,
                                                    T1=args.T1,
                                                    batchsize=args.batch_size)
    else:
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=False,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=False,
                                                batchsize=args.batch_size)
        triggerDataLoader = getData2.get_dataLoader(Shapenet=False,
                                                    T1=args.T1,
                                                    batchsize=args.batch_size)

    wminputs, wmtargets = [], []
    for wm_idx, (wminput, wmtarget) in enumerate(triggerDataLoader):
        wminputs.append(wminput)
        wmtargets.append(wmtarget)
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)

    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('train_3_cls.py', str(experiment_dir))
    shutil.copy('./data/getData.py', str(experiment_dir))
    shutil.copy('./data/getData2.py', str(experiment_dir))
    shutil.copytree('./models/layers', str(experiment_dir) + "/layers")

    classifier = MODEL.get_model(num_class, channel=3).cuda()
    criterion = MODEL.get_loss().cuda()

    pprint(classifier)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    mean_correct2 = []
    mean_loss = []
    mean_loss1 = []
    mean_loss2 = []
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        time_start = datetime.datetime.now()
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))

        scheduler.step()
        wm_id = np.random.randint(len(wminputs))

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            points = torch.cat(
                [points, wminputs[(wm_id + batch_id) % len(wminputs)]],
                dim=0)  # concatenate a randomly chosen watermark batch with the clean batch
            target = torch.cat(
                [target, wmtargets[(wm_id + batch_id) % len(wminputs)]], dim=0)
            points = points.data.numpy()
            # provider is a custom point-cloud utility: random dropout keeps the
            # array shape and overwrites dropped points with the first point
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(
                points[:, :, 0:3])  # random scaling
            points[:, :, 0:3] = provider.shift_point_cloud(
                points[:, :, 0:3])  # random shift
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()

            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    m.reset()

            loss1 = torch.tensor(0.).cuda()
            loss2 = torch.tensor(0.).cuda()
            sign_loss = torch.tensor(0.).cuda()

            for ind in range(2):
                if ind == 0:
                    pred, trans_feat = classifier(points, ind=ind)
                    loss1 = criterion(pred, target.long(), trans_feat)
                    mean_loss1.append(loss1.item() / float(points.size()[0]))
                    pred_choice = pred.data.max(1)[1]
                    correct = pred_choice.eq(target.long().data).cpu().sum()
                    mean_correct.append(correct.item() /
                                        float(points.size()[0]))

                else:
                    pred2, trans_feat2 = classifier(points, ind=ind)
                    loss2 = criterion(pred2, target.long(), trans_feat2)
                    mean_loss2.append(loss2.item() / float(points.size()[0]))
                    pred_choice2 = pred2.data.max(1)[1]
                    correct2 = pred_choice2.eq(target.long().data).cpu().sum()
                    mean_correct2.append(correct2.item() /
                                         float(points.size()[0]))

            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    sign_loss += m.loss

            loss = args.beta * loss1 + loss2 + sign_loss
            mean_loss.append(loss.item() / float(points.size()[0]))

            loss.backward()
            optimizer.step()
            global_step += 1

        train_instance_acc = np.mean(mean_correct)
        train_instance_acc2 = np.mean(mean_correct2)

        train_loss = np.mean(mean_loss)
        train_loss1 = np.mean(mean_loss1)
        train_loss2 = np.mean(mean_loss2)
        log_string('Train Combine Public Accuracy: %f' % train_instance_acc)
        log_string('Train Combine Private Accuracy: %f' % train_instance_acc2)

        sign_acc = torch.tensor(0.).cuda()
        count = 0

        for m in classifier.modules():
            if isinstance(m, SignLoss):
                sign_acc += m.acc
                count += 1

        if count != 0:
            sign_acc /= count

        log_string('Sign Accuracy: %f' % sign_acc)

        res = {}
        avg_private = 0
        count_private = 0

        with torch.no_grad():
            if args.task == 'ours':
                for name, m in classifier.named_modules():
                    if name in [
                            'convp1', 'convp2', 'convp3', 'p1', 'p2', 'p3'
                    ]:
                        signbit, _ = m.get_scale(ind=1)
                        signbit = signbit.view(-1).sign()
                        privatebit = m.b

                        detection = (
                            signbit == privatebit).float().mean().item()
                        res['private_' + name] = detection
                        avg_private += detection
                        count_private += 1

            elif args.task == 'baseline':
                for name, m in classifier.named_modules():
                    if name in [
                            'convp1', 'convp2', 'convp3', 'p1', 'p2', 'p3'
                    ]:
                        signbit = m.get_scale(ind=1).view(-1).sign()
                        privatebit = m.b

                        detection = (
                            signbit == privatebit).float().mean().item()
                        res['private_' + name] = detection
                        avg_private += detection
                        count_private += 1

            log_string('Private Sign Detection Accuracy: %f' %
                       (avg_private / count_private * 100))

            for ind in range(2):
                if ind == 0:
                    val_loss1, test_instance_acc1, class_acc1 = test(
                        classifier,
                        testDataLoader,
                        num_class=args.num_class,
                        ind=0)
                    val_loss_wm1, instance_acc_wm, class_acc_wm = test(
                        classifier,
                        triggerDataLoader,
                        num_class=args.num_class,
                        ind=0)
                else:
                    val_loss2, test_instance_acc2, class_acc2 = test(
                        classifier,
                        testDataLoader,
                        num_class=args.num_class,
                        ind=1)
                    val_loss_wm2, instance_acc_wm2, class_acc_wm2 = test(
                        classifier,
                        triggerDataLoader,
                        num_class=args.num_class,
                        ind=1)

            log_string(
                'Test Clean Public Accuracy: %f, Class Public Accuracy: %f' %
                (test_instance_acc1, class_acc1))
            log_string(
                'Test Clean Private Accuracy: %f, Class Private Accuracy: %f' %
                (test_instance_acc2, class_acc2))
            log_string(
                'Test Trigger Public Accuracy: %f, Trigger Class Public Accuracy: %f'
                % (instance_acc_wm, class_acc_wm))
            log_string(
                'Test Trigger Private Accuracy: %f, Trigger Class Private Accuracy: %f'
                % (instance_acc_wm2, class_acc_wm2))

            test_instance_acc = (test_instance_acc1 + test_instance_acc2) / 2
            class_acc = (class_acc1 + class_acc2) / 2

            if (test_instance_acc >= best_instance_acc):
                best_instance_acc = test_instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string(
                'Test Combine Average Accuracy: %f, Class Average Accuracy: %f'
                % (test_instance_acc, class_acc))
            log_string(
                'Best Combine Average Accuracy: %f, Class Average Accuracy: %f'
                % (best_instance_acc, best_class_acc))

            if (test_instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                log_string('best_epoch %s' % str(best_epoch))
                state = {
                    'epoch': best_epoch,
                    'instance_acc': test_instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

        logger_loss.append([
            train_loss, train_loss1, train_loss2, val_loss1, val_loss_wm1,
            val_loss2, val_loss_wm2
        ])
        logger_acc.append([
            train_instance_acc, test_instance_acc1, instance_acc_wm,
            train_instance_acc2, test_instance_acc2, instance_acc_wm2
        ])

        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time : %s S' % (time_span_str))

    logger_loss.close()
    logger_loss.plot()
    savefig(os.path.join(log_dir, 'log_loss_v3.eps'))
    logger_acc.close()
    logger_acc.plot()
    savefig(os.path.join(log_dir, 'log_acc_v3.eps'))

    log_string('best_epoch %s' % str(best_epoch))
    logger.info('End of training...')
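The comment in Example #19 states that provider.random_point_dropout keeps the array shape and overwrites dropped points with the first point. A sketch consistent with that behavior (the default ratio is an assumption; the authoritative version is each repo's provider.py):

import numpy as np

def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    # batch_pc: (B, N, C). Draw a per-cloud dropout ratio, then overwrite the
    # dropped points with point 0 so the tensor shape stays (B, N, C).
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio
        drop_idx = np.where(np.random.random(batch_pc.shape[1]) <= dropout_ratio)[0]
        if len(drop_idx) > 0:
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]
    return batch_pc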
Example #20
valid_acc_curve = list()

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()  # switch to training mode
    for i, data in enumerate(trainloader):

        # read the batch and apply data augmentation
        points, labels = data['points'], data['label'].squeeze().long()  # labels must be Long
        points = points.data.numpy()
        points = provider.random_point_dropout(points)  # random dropout
        points[:, :, 0:3] = provider.random_scale_point_cloud(
            points[:, :, 0:3])  # random scaling
        points[:, :, 0:3] = provider.shift_point_cloud(
            points[:, :, 0:3])  # random shift
        points = torch.Tensor(points)
        points = points.transpose(2, 1)  # [Batchsize, N, C] -> [Batchsize, C, N]

        # forward
        points = points.to(device)
        labels = labels.to(device)
        if args.pointnet:
            outputs, trans_feat = net(points)
        else:
Example #21
def main():
    args = parser.parse_args()
    # torch.manual_seed(args.seed)
    # torch.cuda.manual_seed(args.seed)
    # np.random.seed(args.seed)

    if args.remark is None:
        args.remark = args.dataset + "-" + args.task + "-" + args.norm

    if args.dataset == "shapenet":
        args.num_class = 16
    else:
        args.num_class = 40

    if 'fake2-' in args.type or 'fake3-' in args.type:
        args.flipperc = 0
        print('No Flip')

    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    if args.task == 'baseline':
        experiment_dir_root = Path('/data-x/g12/zhangjie/3dIP/baseline')
    else:
        experiment_dir_root = Path('/data-x/g12/zhangjie/3dIP/ours')

    experiment_dir_root.mkdir(exist_ok=True)
    experiment_dir = experiment_dir_root.joinpath('ambiguity_attack')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.remark)
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(args.type)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    logger = logging.getLogger("Model")  #log name
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  #log file name
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    if args.dataset == "shapenet":
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=True,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=True,
                                                batchsize=args.batch_size)
    else:
        trainDataLoader = getData.get_dataLoader(train=True,
                                                 Shapenet=False,
                                                 batchsize=args.batch_size)
        testDataLoader = getData.get_dataLoader(train=False,
                                                Shapenet=False,
                                                batchsize=args.batch_size)

    log_string('Finished ...')
    log_string('Load model ...')
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)

    #  copy model file to exp dir
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    shutil.copy('attack2.py', str(experiment_dir))

    classifier = MODEL.get_model(num_class, channel=3).cuda()
    # criterion = MODEL.get_loss().cuda()
    criterion = nn.NLLLoss().cuda()

    sd = experiment_dir_root.joinpath('classification')
    sd.mkdir(exist_ok=True)
    sd = sd.joinpath(str(args.remark))
    sd.mkdir(exist_ok=True)
    sd = sd.joinpath('checkpoints/best_model.pth')

    checkpoint = torch.load(sd)
    classifier.load_state_dict(checkpoint['model_state_dict'])

    for param in classifier.parameters():
        param.requires_grad_(False)

    origpassport = []
    fakepassport = []

    for n, m in classifier.named_modules():
        if n in ['convp1', 'convp2', 'convp3', 'p1', 'p2', 'fc3']:
            key, skey = m.__getattr__('key_private').data.clone(
            ), m.__getattr__('skey_private').data.clone()
            origpassport.append(key.cuda())
            origpassport.append(skey.cuda())

            m.__delattr__('key_private')  # remove the original attributes
            m.__delattr__('skey_private')

            # fake passport initialized as random noise
            if 'fake2-' in args.type:
                # fake random
                m.register_parameter(
                    'key_private',
                    nn.Parameter(torch.randn(*key.size()) * 0.001,
                                 requires_grad=True))
                m.register_parameter(
                    'skey_private',
                    nn.Parameter(torch.randn(*skey.size()) * 0.001,
                                 requires_grad=True))

            else:
                # fake passport: a slight perturbation of the original keys
                m.register_parameter(
                    'key_private',
                    nn.Parameter(key.clone() +
                                 torch.randn(*key.size()) * 0.001,
                                 requires_grad=True))
                m.register_parameter(
                    'skey_private',
                    nn.Parameter(skey.clone() +
                                 torch.randn(*skey.size()) * 0.001,
                                 requires_grad=True))

            fakepassport.append(m.__getattr__('key_private'))
            fakepassport.append(m.__getattr__('skey_private'))

            if args.task == 'ours':
                if args.type != 'fake2':

                    for layer in m.fc.modules():
                        if isinstance(layer, nn.Linear):
                            nn.init.xavier_normal_(layer.weight)

                    for i in m.fc.parameters():
                        i.requires_grad = True

    if args.flipperc != 0:
        log_string(f'Reverse {args.flipperc * 100:.2f}% of binary signature')

        for name, m in classifier.named_modules():
            if name in ['convp1', 'convp2', 'convp3', 'p1', 'p2', 'p3']:
                mflip = args.flipperc
                oldb = m.sign_loss_private.b
                newb = oldb.clone()
                npidx = np.arange(len(oldb))  # indices over the signature bits
                randsize = int(oldb.view(-1).size(0) * mflip)
                randomidx = np.random.choice(npidx, randsize,
                                             replace=False)  # random subset
                newb[randomidx] = oldb[randomidx] * -1  # flip the selected bits
                m.sign_loss_private.set_b(newb)

    classifier.cuda()
    optimizer = torch.optim.SGD(fakepassport,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.0005)

    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    scheduler = None

    def run_cs():
        cs = []

        for d1, d2 in zip(origpassport, fakepassport):
            d1 = d1.view(d1.size(0), -1)
            d2 = d2.view(d2.size(0), -1)

            cs.append(F.cosine_similarity(d1, d2).item())

        return cs

    classifier.train()
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct2 = []
    mean_loss2 = []
    start_epoch = 0

    mse_criterion = nn.MSELoss()
    cs_criterion = nn.CosineSimilarity()
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        time_start = datetime.datetime.now()
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        optimizer.zero_grad()
        signacc_meter = 0
        signloss_meter = 0

        if scheduler is not None:
            scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :,
                                                                         0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()

            # loss define
            pred2, _ = classifier(points, ind=1)
            loss2 = criterion(pred2, target.long())
            mean_loss2.append(loss2.item() / float(points.size()[0]))
            pred_choice2 = pred2.data.max(1)[1]
            correct2 = pred_choice2.eq(target.long().data).cpu().sum()
            mean_correct2.append(correct2.item() / float(points.size()[0]))

            signacc = torch.tensor(0.).cuda()
            count = 0
            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    signacc += m.get_acc()
                    count += 1
            try:
                signacc_meter += signacc.item() / count
            except:
                pass

            sign_loss = torch.tensor(0.).cuda()
            for m in classifier.modules():
                if isinstance(m, SignLoss):
                    sign_loss += m.loss
            signloss_meter += sign_loss

            loss = loss2
            maximizeloss = torch.tensor(0.).cuda()
            mseloss = torch.tensor(0.).cuda()
            csloss = torch.tensor(0.).cuda()

            for l, r in zip(origpassport, fakepassport):
                mse = mse_criterion(l, r)
                cs = cs_criterion(l.view(1, -1), r.view(1, -1)).mean()
                csloss += cs
                mseloss += mse
                maximizeloss += 1 / mse

            if 'fake2-' in args.type:
                (loss).backward()  # fake2: back-propagate only the cross-entropy loss
            elif 'fake3-' in args.type:
                (loss +
                 maximizeloss).backward()  # fake3: csloss is not back-propagated
            else:
                (loss + maximizeloss +
                 1000 * sign_loss).backward()  # fake3_S: also back-propagate the sign loss
                # (loss + 1000 * sign_loss).backward()

            torch.nn.utils.clip_grad_norm_(fakepassport, 2)

            optimizer.step()
            global_step += 1

        signacc = signacc_meter / len(trainDataLoader)
        log_string('Train Sign Accuracy: %f' % signacc)

        signloss = signloss_meter / len(trainDataLoader)
        log_string('Train Sign Loss: %f' % signloss)

        train_instance_acc2 = np.mean(mean_correct2)
        log_string('Train Instance Private Accuracy: %f' % train_instance_acc2)

        with torch.no_grad():
            cs = run_cs()
            log_string(
                f'Cosine Similarity of Real and Maximize passport: {sum(cs) / len(origpassport):.4f}'
            )
            val_loss2, test_instance_acc2, class_acc2, singloss2, signacc2 = test(
                classifier, testDataLoader, num_class=args.num_class, ind=1)

            log_string(
                'Test Instance Private Accuracy: %f, Class Private Accuracy: %f'
                % (test_instance_acc2, class_acc2))
            log_string('Test Private Sign Accuracy: %f' % (signacc2))

            test_instance_acc = test_instance_acc2
            class_acc = class_acc2

            if (test_instance_acc >= best_instance_acc):
                best_instance_acc = test_instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string(
                'Test Instance Average Accuracy: %f, Class Average Accuracy: %f'
                % (test_instance_acc, class_acc))
            log_string(
                'Best Instance Average Accuracy: %f, Class Average Accuracy: %f'
                % (best_instance_acc, best_class_acc))

            if (test_instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_attack_model.pth'
                log_string('Saving at %s' % savepath)
                log_string('best_epoch %s' % str(best_epoch))
                state = {
                    'epoch': best_epoch,
                    'instance_acc': test_instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'origpassport': origpassport,
                    'fakepassport': fakepassport
                }
                torch.save(state, savepath)
            global_epoch += 1

        time_end = datetime.datetime.now()
        time_span_str = str((time_end - time_start).seconds)
        log_string('Epoch time : %s S' % (time_span_str))

    log_string('best_epoch %s' % str(best_epoch))

    logger.info('End of training...')
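In the ambiguity attack of Example #21, minimizing maximizeloss = 1/mse(l, r) pushes the fake passports away from the originals while the task loss keeps accuracy up. A self-contained toy showing that divergence effect on plain tensors (shapes and learning rate are arbitrary, not the repo's API):

import torch
import torch.nn as nn

mse_criterion = nn.MSELoss()
orig = torch.randn(4, 8)
fake = (orig + 0.001 * torch.randn(4, 8)).requires_grad_(True)

opt = torch.optim.SGD([fake], lr=0.1)
for _ in range(100):
    opt.zero_grad()
    maximizeloss = 1.0 / mse_criterion(fake, orig)  # small mse -> large loss
    maximizeloss.backward()
    torch.nn.utils.clip_grad_norm_([fake], 2)  # same clipping as the attack loop
    opt.step()

print(mse_criterion(fake, orig).item())  # mse has grown: fake diverged from orig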
Example #22
def main(args):
    omegaconf.OmegaConf.set_struct(args, False)
    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    logger = logging.getLogger(__name__)

    print(args.pretty())
    '''DATA LOADING'''
    logger.info('Load dataset ...')
    DATA_PATH = hydra.utils.to_absolute_path('modelnet40_normal_resampled/')

    TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH,
                                       npoint=args.num_point,
                                       split='train',
                                       normal_channel=args.normal)
    TEST_DATASET = ModelNetDataLoader(root=DATA_PATH,
                                      npoint=args.num_point,
                                      split='test',
                                      normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=4)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=4)
    '''MODEL LOADING'''
    args.num_class = 40
    args.input_dim = 6 if args.normal else 3
    shutil.copy(
        hydra.utils.to_absolute_path('models/{}/model.py'.format(
            args.model.name)), '.')

    classifier = getattr(
        importlib.import_module('models.{}.model'.format(args.model.name)),
        'PointTransformer')(args).cuda()
    criterion = torch.nn.CrossEntropyLoss()

    try:
        checkpoint = torch.load('best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        logger.info('Use pretrain model')
    except:
        logger.info('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=50,
                                                gamma=0.3)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    best_epoch = 0
    mean_correct = []
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        logger.info('Epoch %d (%d/%s):' %
                    (global_epoch + 1, epoch + 1, args.epoch))

        classifier.train()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :,
                                                                         0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            target = target[:, 0]

            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            pred = classifier(points)
            loss = criterion(pred, target.long())
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1

        scheduler.step()

        train_instance_acc = np.mean(mean_correct)
        logger.info('Train Instance Accuracy: %f' % train_instance_acc)

        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), testDataLoader)

            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            logger.info('Test Instance Accuracy: %f, Class Accuracy: %f' %
                        (instance_acc, class_acc))
            logger.info('Best Instance Accuracy: %f, Class Accuracy: %f' %
                        (best_instance_acc, best_class_acc))

            if (instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = 'best_model.pth'
                logger.info('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
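Every loop here augments only the xyz channels with provider.random_scale_point_cloud and provider.shift_point_cloud: one random scale and one random offset per cloud. A hedged sketch of what these calls plausibly do (parameter defaults are assumptions; the originals live in each repo's provider.py):

import numpy as np

def random_scale_point_cloud(batch_xyz, scale_low=0.8, scale_high=1.25):
    # One scalar scale per cloud, applied to all of its points.
    scales = np.random.uniform(scale_low, scale_high, batch_xyz.shape[0])
    for b in range(batch_xyz.shape[0]):
        batch_xyz[b, :, :] *= scales[b]
    return batch_xyz

def shift_point_cloud(batch_xyz, shift_range=0.1):
    # One random xyz offset per cloud.
    shifts = np.random.uniform(-shift_range, shift_range, (batch_xyz.shape[0], 3))
    for b in range(batch_xyz.shape[0]):
        batch_xyz[b, :, :] += shifts[b]
    return batch_xyz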
Example #23
def main(args):
    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    dataset_name = args.dataset_name
    experiment_dir = experiment_dir.joinpath(
        'classification_{}'.format(dataset_name))
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''TENSORBOARD LOG'''
    writer = SummaryWriter()
    '''DATA LOADING'''
    log_string('Load dataset ...')

    DATA_PATH = os.path.join(ROOT_DIR, 'data', dataset_name)

    print("loading dataset from {}".format(dataset_name))
    if 'modelnet' in dataset_name:
        TRAIN_DATASET = ModelNetDataLoader(DATA_PATH,
                                           split='train',
                                           normal_channel=args.normal)
        TEST_DATASET = ModelNetDataLoader(DATA_PATH,
                                          split='test',
                                          normal_channel=args.normal)
        num_class = 40
    else:
        print(DATA_PATH)
        TRAIN_DATASET = ReplicaDataLoader(DATA_PATH,
                                          split='train',
                                          uniform=True,
                                          normal_channel=False,
                                          rot_transform=True)
        TEST_DATASET = ReplicaDataLoader(DATA_PATH,
                                         split='test',
                                         uniform=True,
                                         normal_channel=False,
                                         rot_transform=False)
        num_class = 31

    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=6)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=6)
    '''MODEL LOADING'''
    print("Number of classes are {:d}".format(num_class))
    MODEL = importlib.import_module(args.model)
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))

    print("Obtain GPU device ")
    train_GPU = True
    device = torch.device("cuda" if (
        torch.cuda.is_available() and train_GPU) else "cpu")
    print(device)
    print("Load the network to the device ")
    classifier = MODEL.get_model(num_class,
                                 normal_channel=args.normal).to(device)
    print("Load the loss to the device ")
    criterion = MODEL.get_loss().to(device)

    if os.path.exists((str(experiment_dir) + '/checkpoints/best_model.pth')):
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])

        # strict set to false to allow using the model trained with modelnet
    else:
        start_epoch = 0
        if dataset_name == 'replica':
            log_string('Use pretrain model of Model net')
            # double check again if there is pretrained modelnet model
            checkpoint = torch.load(
                str(experiment_dir).replace("replica",
                                            'modelnet40_normal_resampled') +
                '/checkpoints/best_model.pth')
            classifier = MODEL.get_model(40,
                                         normal_channel=args.normal).to(device)
            classifier.load_state_dict(checkpoint['model_state_dict'])
            classifier.fc3 = nn.Linear(256, num_class).to(device)
            print(classifier)
        else:
            log_string('No existing model, starting training from scratch...')

    if args.optimizer == 'Adam':
        print("Using Adam opimizer ")
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        loss_array = np.zeros((len(trainDataLoader), 1))
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        classifier.train()  # setting the model to train mode
        print("Clear GPU cache ...")
        torch.cuda.empty_cache()
        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :,
                                                                         0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            target = target[:, 0]

            points = points.transpose(2, 1)
            points, target = points.to(device), target.to(device)

            optimizer.zero_grad()

            pred, trans_feat = classifier(
                points)  # this call was the site of a runtime error in the original run

            loss = criterion(pred, target.long(), trans_feat)
            loss_array[batch_id] = loss.cpu().detach().numpy()

            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1

            train_instance_acc = np.mean(mean_correct)
            log_string('Train Instance Accuracy: %f' % train_instance_acc)
        avg_loss = np.mean(loss_array[:])
        writer.add_scalar("Loss/train", avg_loss, epoch)

        ## This is for validation
        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), testDataLoader,
                                           device, num_class)

            writer.add_scalar("ClassAccuracy/test", class_acc, epoch)
            writer.add_scalar("InstanceAccuracy/test", instance_acc, epoch)

            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' %
                       (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' %
                       (best_instance_acc, best_class_acc))

            if (instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
    writer.flush()
    writer.close()
Example #24
def main(args):
    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('cls')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    # DATA_PATH = 'data/modelnet40_normal_resampled/'
    DATA_PATH = args.data_dir

    # if args.model == 'pointcnn_cls':
    #     trainDataLoader = PyGDataloader(TRAIN_DATASET, args.batch_size, shuffle=True)
    #     testDataLoader = PyGDataloader(TEST_DATASET, args.batch_size, shuffle=False)
    # else:
    TRAIN_DATASET = ClsDataLoader(root=DATA_PATH,
                                  dataset_name=args.dataset_name,
                                  npoint=args.num_point,
                                  split='train',
                                  normal_channel=args.normal)
    TEST_DATASET = ClsDataLoader(root=DATA_PATH,
                                 dataset_name=args.dataset_name,
                                 npoint=args.num_point,
                                 split='test',
                                 normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=args.num_worker,
                                                  drop_last=True)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.num_worker,
                                                 drop_last=True)
    '''MODEL LOADING'''
    num_class = args.num_class
    MODEL = importlib.import_module(args.model)
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))

    classifier = MODEL.get_model(num_class, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    # try:
    #     checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
    #     start_epoch = checkpoint['epoch']
    #     classifier.load_state_dict(checkpoint['model_state_dict'])
    #     log_string('Use pretrain model')
    # except:
    #     log_string('No existing model, starting training from scratch...')
    #     start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/last_model.pth')
        start_epoch = checkpoint['epoch'] + 1
        classifier.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        best_instance_acc = checkpoint['instance_acc']
        best_class_acc = checkpoint['class_acc']
        log_string('Use pretrain model')
    except:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
        best_instance_acc = 0.0
        best_class_acc = 0.0

    global_epoch = 0
    global_step = 0
    mean_correct = []
    '''TRAINING'''
    logger.info('Start training...')
    writer_loss = SummaryWriter(os.path.join(str(log_dir), 'loss'))
    writer_train_instance_accuracy = SummaryWriter(
        os.path.join(str(log_dir), 'train_instance_accuracy'))
    writer_test_instance_accuracy = SummaryWriter(
        os.path.join(str(log_dir), 'test_instance_accuracy'))
    writer_test_class_accuracy = SummaryWriter(
        os.path.join(str(log_dir), 'test_class_accuracy'))
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))

        scheduler.step()
        log_string('lr: %f' % optimizer.param_groups[0]['lr'])
        running_loss = 0.0
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :,
                                                                         0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            if args.model == 'pointcnn_cls' and args.pointcnn_data_aug:
                points = provider.shuffle_points(points)
                points[:, :, 0:3] = provider.rotate_point_cloud(points[:, :,
                                                                       0:3])
                points[:, :, 0:3] = provider.jitter_point_cloud(points[:, :,
                                                                       0:3])
            points = torch.Tensor(points)
            target = target[:, 0]

            points = points.transpose(2, 1)

            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            classifier = classifier.train()
            if args.model == 'pointcnn_cls':
                points = points.transpose(2, 1)
                if args.dataset_name == 'cifar':
                    pos = points.reshape((-1, 6))
                    # normalise rgb
                    pos[:, 3:6] = pos[:, 3:6] / 255.0
                else:
                    pos = points.reshape((-1, 3))
                x = np.arange(0, args.batch_size)
                batch = torch.from_numpy(np.repeat(x, args.num_point)).cuda()
                pred, trans_feat = classifier(pos, batch)
            else:
                pred, trans_feat = classifier(points)

            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1

            running_loss += loss.item()
            if batch_id % 10 == 9:  # print every 10 batches
                niter = epoch * len(trainDataLoader) + batch_id
                writer_loss.add_scalar('Train/loss', loss.item(), niter)

        log_string('Loss: %f' % (running_loss / len(trainDataLoader)))
        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)
        writer_train_instance_accuracy.add_scalar('Train/instance_accuracy',
                                                  train_instance_acc.item(),
                                                  epoch)

        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), testDataLoader,
                                           num_class)
            writer_test_instance_accuracy.add_scalar('Test/instance_accuracy',
                                                     instance_acc.item(),
                                                     epoch)
            writer_test_class_accuracy.add_scalar('Test/class_accuracy',
                                                  class_acc.item(), epoch)

            if instance_acc >= best_instance_acc:
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if class_acc >= best_class_acc:
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' %
                       (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' %
                       (best_instance_acc, best_class_acc))

            logger.info('Save the last model...')
            savepath_last = str(checkpoints_dir) + '/last_model.pth'
            log_string('Saving at %s' % savepath_last)
            state_last = {
                'epoch': epoch,
                'instance_acc': instance_acc,
                'class_acc': class_acc,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }
            torch.save(state_last, savepath_last)

            if instance_acc >= best_instance_acc:
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
    writer_loss.close()
    writer_train_instance_accuracy.close()
    writer_test_instance_accuracy.close()
    writer_test_class_accuracy.close()
Example #25
def get_acc():
    batch_size = 16  # batch size for both DataLoaders below

    config = yaml.load(open("/content/PointNet-BYOL/config/config.yaml", "r"),
                       Loader=yaml.FullLoader)
    # normal_channel must be set to True here; otherwise the input has only 3
    # channels (xyz) and cannot match the model's expected 6 (xyz + normals)
    TRAIN_DATASET = ModelNetDataLoader(
        root='data/modelnet40_normal_resampled/',
        npoint=1024,
        split='train',
        normal_channel=True)
    TEST_DATASET = ModelNetDataLoader(root='data/modelnet40_normal_resampled/',
                                      npoint=1024,
                                      split='test',
                                      normal_channel=True)

    print("Input shape:", len(TRAIN_DATASET))

    train_loader = torch.utils.data.DataLoader(TRAIN_DATASET,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=12)
    test_loader = torch.utils.data.DataLoader(TEST_DATASET,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=12)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = get_model(num_class=40, normal_channel=True)
    #  output_feature_dim = encoder.projetion.net[0].in_features#

    #  load pre-trained parameters
    load_params = torch.load(
        os.path.join('/content/PointNet-BYOL/checkpoints/model.pth'),
        map_location=torch.device(device))

    if 'online_network_state_dict' in load_params:
        encoder.load_state_dict(load_params['online_network_state_dict'])
        print("Parameters successfully loaded.")

    # Put the pretrained encoder on the device in inference mode
    encoder = encoder.to(device)
    encoder.eval()
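
    # Note: eval() only switches BatchNorm/Dropout to inference behavior; it does
    # not freeze the weights. The optimizer below holds only classifier parameters,
    # so the encoder is never updated, but wrapping the encoder call in
    # torch.no_grad() would also avoid building its autograd graph.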

    #  x_train, y_train = get_features_from_encoder(encoder, train_loader)
    #  x_test, y_test = get_features_from_encoder(encoder, test_loader)

    #  x_train = torch.mean(x_train, dim=[2, 3])
    #  x_test = torch.mean(x_test, dim=[2, 3])

    #  print("Training data shape:", x_train.shape, y_train.shape)
    #  print("Testing data shape:", x_test.shape, y_test.shape)
    #  x_train=np.array(x_train)
    #  scaler = preprocessing.StandardScaler()
    #  scaler.fit(x_train)
    #  x_train = scaler.transform(x_train).astype(np.float32)
    #  x_test = scaler.transform(x_test).astype(np.float32)

    #  train_loader, test_loader = create_data_loaders_from_arrays(torch.tensor([item.cpu().detach().numpy() for item in x_train]).cuda(), \
    #  y_train, torch.from_numpy(x_test), y_test)
    #  train_loader, test_loader = create_data_loaders_from_arrays(torch.from_numpy(x_train), \
    #  y_train, torch.from_numpy(x_test), y_test)

    criterion = pointnet2_cls_msg_concat.get_loss().cuda()
    classifier = pointnet2_cls_msg_concat.get_model(
        num_class=40, normal_channel=True).cuda()
    #  optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
    optimizer = torch.optim.Adam(classifier.parameters(),
                                 lr=0.001,
                                 betas=(0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=1e-4)
    eval_every_n_epochs = 1
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    try:
        checkpoint = torch.load(
            '/content/PointNet-BYOL/checkpoints/best_model.pth')
        classifier.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print('Use pretrain model')
    except Exception:
        print('No existing model, starting training from scratch...')

    for epoch in range(20):
        print('Epoch %d ' % (epoch + 1))
        mean_correct = []  # reset per epoch so the accuracy reflects this epoch only

        scheduler.step()
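        # Note: PyTorch >= 1.1 expects scheduler.step() to be called after the
        # epoch's optimizer.step() calls; calling it before the batch loop, as
        # here, shifts the decay schedule forward by one epoch.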
        for batch_id, data in tqdm(enumerate(train_loader, 0),
                                   total=len(train_loader),
                                   smoothing=0.9):
            points, target = data
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()
            feature_vector, cls = encoder(points)
            pred, cls = classifier(points, feature_vector)
            loss = criterion(pred, target.long(), target)
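            # Assumption: this get_loss ignores its third (trans_feat) argument,
            # which would explain passing target in its place; with a loss that
            # uses the feature-transform regularizer this would be a bug.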
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
        train_instance_acc = np.mean(mean_correct)
        print('Train Instance Accuracy: %f' % train_instance_acc)
        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), encoder.eval(),
                                           test_loader)

            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            print('Test Instance Accuracy: %f, Class Accuracy: %f' %
                  (instance_acc, class_acc))
            print('Best Instance Accuracy: %f, Class Accuracy: %f' %
                  (best_instance_acc, best_class_acc))

            if (instance_acc >= best_instance_acc):
                print('Save model...')
                savepath = 'checkpoints' + '/best_model.pth'
                print('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)

    print('End of training...')

    #  train = torch.utils.data.TensorDataset(pred1, target)
    #  train_loader = torch.utils.data.DataLoader(train, batch_size=96, shuffle=True)
    #  for epoch in range(10):
    #     train_acc = []
    #     for x, y in train_loader:
    #         x = x.to(device)
    #         y = y.to(device)
    #         # zero the parameter gradients
    #         optimizer.zero_grad()
    #         classifier = classifier.train()
    #         pred = classifier(x)
    #         predictions = torch.argmax(pred, dim=1)
    #         loss = criterion(pred, y.long(),y)
    #         loss.backward(retain_graph=True)
    #         optimizer.step()

    #     if epoch % eval_every_n_epochs == 0:
    #         train_total,total = 0,0
    #         train_correct,correct = 0,0
    #         for x, y in train_loader:
    #             x = x.to(device)
    #             y = y.to(device)
    #             classifier = classifier.train()
    #             pred = classifier(x)
    #             predictions = torch.argmax(pred, dim=1)

    #             train_total += y.size(0)
    #             train_correct += (predictions == y).sum().item()
    #         for x, y in test_loader:
    #             x = x.to(device)
    #             y = y.to(device)

    #             classifier = classifier.train()
    #             pred = classifier(x)
    #             predictions = torch.argmax(pred, dim=1)

    #             total += y.size(0)
    #             correct += (predictions == y).sum().item()
    #         train_acc=  train_correct / train_total
    #         acc =  correct / total
    #         print(f"Training accuracy: {np.mean(train_acc)}")
    #         print(f"Testing accuracy: {np.mean(acc)}")
    return train_instance_acc
Example #26
def main(args):
    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    '''DATA LOADING'''
    log_string('Load dataset ...')
    DATA_PATH = 'data/modelnet40_normal_resampled/'

    TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train',
                                       normal_channel=args.normal)
    TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test',
                                      normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=4)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)

    '''MODEL LOADING'''
    num_class = 40
    MODEL = importlib.import_module(args.model)
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util_psn.py', str(experiment_dir))

    classifier = MODEL.get_model(num_class, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    try:
        checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0


    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    mean_correct = []
    best_epoch = 0

    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        mean_correct = []  # reset per epoch so the accuracy reflects this epoch only

        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
            points, target = data
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            target = target[:, 0]

            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            classifier = classifier.train()
            pred, trans_feat = classifier(points, False)
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1

        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)


        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), testDataLoader)

            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

            if (instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
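
All of these loops pair Adam (lr=0.001) with StepLR(step_size=20, gamma=0.7), i.e. the learning rate is multiplied by 0.7 every 20 epochs. A self-contained sketch of that schedule (dummy parameter, illustrative only):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

for epoch in range(60):
    optimizer.step()  # one (dummy) update per epoch
    scheduler.step()  # after n calls: lr = 0.001 * 0.7 ** (n // 20)
    if (epoch + 1) % 20 == 0:
        print(epoch + 1, optimizer.param_groups[0]['lr'])
# prints roughly: 20 0.0007, 40 0.00049, 60 0.000343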
Example #27
def main(args):
    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR (set up save directories)'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('regression')
    experiment_dir.mkdir(exist_ok=True)

    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)

    checkpoints_dir = experiment_dir.joinpath('checkpoints')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs')
    log_dir.mkdir(exist_ok=True)
    '''LOG (set up the log file)'''
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    DATA_PATH = './data/nyu_hand_dataset_v2/'
    # This is the part to adapt: a dataset loader written specifically for this dataset
    TRAIN_DATASET = NyuHandDataLoader(root=DATA_PATH,
                                      npoint=args.num_point,
                                      split='train_' + args.shape,
                                      normal_channel=args.normal)
    TEST_DATASET = NyuHandDataLoader(root=DATA_PATH,
                                     npoint=args.num_point,
                                     split='test_' + args.shape,
                                     normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=4)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=4)
    '''MODEL LOADING'''
    # Back up the model code into the experiment directory
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))
    MODEL = importlib.import_module(args.model)  # import the module that defines the model
    classifier = MODEL.get_model(args.num_joint * 3,
                                 normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()  # loss function

    # Try to resume from an existing checkpoint
    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    # Optimizer
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    # StepLR multiplies the learning rate by gamma once every step_size scheduler.step() calls
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_instance_loss = 1000.0
    mean_loss = []  # per-batch training losses
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        mean_loss = []  # reset per epoch so the mean reflects this epoch only

        scheduler.step()  # StepLR: decays the lr by gamma every step_size epochs
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            # Fetch a mini-batch: points holds the data, target the labels
            points, target, scale, centroid = data
            points = points.data.numpy()  # points: [B, Nsample, xyz + normals]
            points = provider.random_point_dropout(
                points, max_dropout_ratio=0.875)  # random point dropout
            points[:, :, 0:3] = provider.random_scale_point_cloud(
                points[:, :, 0:3])  # random scaling
            points[:, :, 0:3] = provider.shift_point_cloud(
                points[:, :, 0:3])  # random shift
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()

            # Forward pass and loss
            optimizer.zero_grad()  # zero the gradients
            classifier.train()  # enable train mode (BatchNorm and Dropout active)
            # pred: network output, [Batch_size, num_joint*3]
            # trans_feat: features after the SA layers, [Batch_size, SA output size, 1]
            pred, trans_feat = classifier(points)

            # Un-normalize pred using the per-sample scale and centroid
            pred_reduction = pred.cpu().data.numpy()
            pred_reduction = pred_reduction * np.tile(
                scale, (args.num_joint * 3, 1)).transpose(1, 0)
            pred_reduction = pred_reduction + np.tile(centroid,
                                                      (1, args.num_joint))
            pred_reduction = torch.Tensor(pred_reduction).cuda()

            # pred carries autograd state besides .data, so overwrite only .data
            # with the rescaled values
            pred.data = pred_reduction.data
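
            # Caveat: assigning through .data bypasses autograd, so the gradient
            # of the loss below treats the rescaled values as if the network had
            # produced them directly; the rescaling itself is not differentiated.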

            loss = criterion(pred, target, trans_feat)
            loss.backward()
            optimizer.step()
            global_step += 1

            mean_loss.append(loss.item())  # record this batch's loss

        # Mean training loss over the epoch
        train_instance_loss = np.mean(mean_loss)
        log_string('The %dth Train Instance Loss: %f' %
                   (epoch + 1, train_instance_loss))

        with torch.no_grad():  # no gradients needed for evaluation
            test_instance_loss = test(classifier.eval(), testDataLoader)

            if (test_instance_loss <= best_instance_loss):
                best_instance_loss = test_instance_loss
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_loss': best_instance_loss,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)

            log_string('The %dth Test Instance Error: %f' %
                       (epoch + 1, test_instance_loss))
            log_string('Best Instance Error: %f \n' % (best_instance_loss))

            global_epoch += 1

    logger.info('End of training...')
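
The test() helper called above is defined elsewhere in the file; a minimal sketch consistent with how it is called, returning the mean loss over the test loader and mirroring the un-normalization done in training (criterion, args, np, and torch are assumed to be in scope, as in the training code; all details assumed):

def test(model, loader):
    # Hypothetical evaluation loop for test(classifier.eval(), testDataLoader)
    losses = []
    for points, target, scale, centroid in loader:
        points = points.transpose(2, 1).cuda()
        target = target.cuda()
        pred, trans_feat = model(points)
        # Same scale/centroid un-normalization as in the training loop
        pred_np = pred.cpu().data.numpy()
        pred_np = pred_np * np.tile(scale, (args.num_joint * 3, 1)).transpose(1, 0)
        pred_np = pred_np + np.tile(centroid, (1, args.num_joint))
        pred.data = torch.Tensor(pred_np).cuda().data
        losses.append(criterion(pred, target, trans_feat).item())
    return float(np.mean(losses))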