예제 #1
0
def vis(args):
    """Interactively show test-set point clouds together with the model's prediction.

    Loads the test split with ``load_data(root, train=False)``, builds either
    ``PointNetCls`` (``args.model_name == 'pointnet'``) or ``PointNet2ClsMsg``,
    wraps it in ``torch.nn.DataParallel`` and loads the state dict stored at
    ``args.pretrain``. Then, for every test sample, it logs the ground-truth
    and predicted class names and opens an Open3D window with the points.
    Pressing space calls ``exit()``; closing the window advances to the next
    sample.

    Args:
        args: Parsed command-line namespace. Reads ``model_name``,
            ``feature_transform``, ``gpu`` and ``pretrain``.

    Returns:
        None. Logs an error and returns early when ``args.pretrain`` is None.
    """
    # NOTE(review): ``root``, ``load_data``, ``class_names`` and ``log`` are
    # module-level names defined outside this block.
    test_data, test_label = load_data(root, train=False)
    log.info(test_data=test_data.shape, test_label=test_label.shape)

    log.debug('Building Model', args.model_name)
    if args.model_name == 'pointnet':
        num_class = 40
        model = PointNetCls(num_class, args.feature_transform).cuda()
    else:
        model = PointNet2ClsMsg().cuda()

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model)
    model.cuda()
    log.info('Using multi GPU:', args.gpu)

    if args.pretrain is None:
        log.err('No pretrain model')
        return

    log.debug('Loading pretrain model...')
    checkpoint = torch.load(args.pretrain)
    model.load_state_dict(checkpoint)
    model.eval()

    log.info('Press space to exit, press Q for next frame')

    for idx in range(test_data.shape[0]):
        # Slice (not index) so the leading batch dimension of size 1 is kept.
        point_np = test_data[idx:idx + 1]
        gt = test_label[idx][0]

        points = torch.from_numpy(point_np)
        # Swap the last two axes (channels-first layout for the network).
        points = points.transpose(2, 1).cuda()

        # Pure inference: do not build an autograd graph for every frame.
        with torch.no_grad():
            pred, _trans_feat = model(points)
        pred_choice = pred.max(1)[1]
        log.info(gt=class_names[gt],
                 pred_choice=class_names[pred_choice.cpu().numpy().item()])

        point_cloud = open3d.geometry.PointCloud()
        point_cloud.points = open3d.utility.Vector3dVector(point_np[0])

        # Named ``viewer`` so it does not shadow this function's own name.
        viewer = open3d.visualization.VisualizerWithKeyCallback()
        viewer.create_window()
        viewer.get_render_option().background_color = np.asarray([0, 0, 0])
        viewer.add_geometry(point_cloud)
        # Key code 32 is the space bar: abort the whole visualisation.
        viewer.register_key_callback(32, lambda _viewer: exit())
        viewer.run()
        viewer.destroy_window()
예제 #2
0
def evaluate(args):
    """Log the test accuracy of a pretrained classifier on the ModelNet40 test split.

    The data is read from ``experiment/data/modelnet40_ply_hdf5_2048/``. The
    network is ``PointNetCls`` (40 classes) when ``args.model_name`` is
    ``'pointnet'``, otherwise ``PointNet2ClsMsg``. It is wrapped in
    ``torch.nn.DataParallel`` on the GPU and given the state dict at
    ``args.pretrain``.

    Args:
        args: Parsed command-line namespace. Reads ``batch_size``,
            ``model_name``, ``feature_transform``, ``gpu`` and ``pretrain``.

    Returns:
        None. Logs an error and returns early when ``args.pretrain`` is None.
    """
    data_root = 'experiment/data/modelnet40_ply_hdf5_2048/'
    test_data, test_label = load_data(data_root, train=False)
    test_loader = torch.utils.data.DataLoader(
        ModelNetDataLoader(test_data, test_label),
        batch_size=args.batch_size,
        shuffle=False,
    )

    log.debug('Building Model', args.model_name)
    if args.model_name != 'pointnet':
        net = PointNet2ClsMsg()
    else:
        net = PointNetCls(40, args.feature_transform)

    torch.backends.cudnn.benchmark = True
    net = torch.nn.DataParallel(net).cuda()
    log.debug('Using gpu:', args.gpu)

    if args.pretrain is None:
        log.err('No pretrain model')
        return

    log.debug('Loading pretrain model...')
    net.load_state_dict(torch.load(args.pretrain))

    # ``eval()`` returns the module itself, switched to inference mode.
    accuracy = test_clf(net.eval(), test_loader)
    log.msg(Test_Accuracy='%.5f' % (accuracy))
# import torch.nn.functional as F
import sys
from os import path

import numpy as np
import torch
import torch.multiprocessing as mp

from model.pointnet import PointNetCls

# Put the directory two levels above this file on the import path. Resolve it
# from this file's real location: passing the literal string "__file__" to
# abspath would resolve relative to the current working directory instead.
try:
    _this_file = __file__
except NameError:  # e.g. interactive session or embedded interpreter
    _this_file = "__file__"  # fall back to the CWD-relative resolution
sys.path.append(path.dirname(path.dirname(path.abspath(_this_file))))
# NOTE(review): this runs after ``from model.pointnet import PointNetCls``
# has already executed, so it cannot help that import; confirm whether the
# import is expected to resolve without this path entry.

# torch.cuda.manual_seed(1)  # don't delete

grasp_points_num = 1024
# Created once at import time and shared by the functions below.
# NOTE(review): k=2 presumably is the number of output classes — confirm.
model = PointNetCls(num_points=grasp_points_num, input_chann=3, k=2)
model.cuda()

# This module is only used for inference: eval mode and autograd disabled globally.
model.eval()
torch.set_grad_enabled(False)


def load_weight(weights_path):
    """Load the state dict saved at *weights_path* into the module-level ``model``.

    A progress line is printed before loading.

    Args:
        weights_path: Path handed to ``torch.load``.
    """
    message = '[Python] Load weight: {}'.format(weights_path)
    print(message)
    state = torch.load(weights_path)
    model.load_state_dict(state)


def classify_pcs(local_pcs, output_cls=0):
    """Point-cloud classification entry point (original note: "FPS:650").

    As written, the body only prints a progress banner and returns ``None``.
    Neither ``local_pcs`` nor ``output_cls`` is used.

    NOTE(review): the classification logic appears to be missing from this
    snippet; confirm against the complete source.
    """
    banner = "[Python] Classify point clouds..."
    print(banner)
def main():
    """Cluster nuScenes radar detections, classify each cluster and plot the result.

    Configuration comes from the module-level ``args`` namespace. Steps:

    1. Restrict ``CUDA_VISIBLE_DEVICES`` (when ``args.cuda``) and pick the device.
    2. Build the classifier named by ``args.model_name`` and load its weights from
       ``<sys.path[0]>/experiment/checkpoints/<args.exp_name>.pth``.
    3. For each nuScenes sample, in random order:
       - merge the radar sweeps of ``args.sensors``;
       - build ground-truth boxes for the (category, attribute) pairs of interest;
       - keep only moving radar points and cluster them with DBSCAN;
       - classify every cluster;
       - display everything next to the six camera images (one blocking
         ``plt.show()`` per sample).

    Raises:
        Exception: if ``args.model_name`` is neither ``'pointnet'`` nor
            ``'pointnet2'``, or if no checkpoint exists at the expected path.
    """
    # ``radar_pc`` is published as a module global; the annotation helpers
    # below read it at call time.
    global radar_pc
    from sklearn.cluster import DBSCAN

    ''' --- SELECT DEVICES --- '''
    # Restrict the visible GPUs first so the choice applies to all CUDA work.
    if args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
            str(x) for x in args.gpudevice)
    device = torch.device("cuda" if args.cuda else "cpu")

    ''' --- INIT NETWORK MODEL --- '''
    projdir = sys.path[0]
    # Join separate components so the path is valid on every OS (a hard-coded
    # 'experiment\\checkpoints' becomes one oddly named directory on POSIX).
    saveloadpath = os.path.join(projdir, 'experiment', 'checkpoints',
                                args.exp_name + '.pth')
    Path(os.path.dirname(saveloadpath)).mkdir(exist_ok=True, parents=True)
    # Load selected network model and put it on the right device
    if args.model_name == 'pointnet':
        classifier = PointNetCls(dim=args.pointCoordDim,
                                 num_class=len(args.categories),
                                 feature_transform=args.feature_transform)
    elif args.model_name == 'pointnet2':
        classifier = PointNet2ClsMsg(dim=args.pointCoordDim,
                                     num_class=len(args.categories))
    else:
        raise Exception(
            'Argument "model_name" does not match existent networks')
    classifier = classifier.to(device)

    ''' --- LOAD NETWORK (REQUIRED) --- '''
    if not os.path.exists(saveloadpath):
        raise Exception('Model in: {} does not exists'.format(saveloadpath))
    print('Using pretrained model found...')
    # map_location allows a checkpoint saved on GPU to be loaded on a CPU-only run.
    checkpoint = torch.load(saveloadpath, map_location=device)
    classifier.load_state_dict(checkpoint['model_state_dict'])
    classifier = classifier.eval()

    ''' --- INIT DATA TRANSFORMATIONS --- '''
    # Per the original note: dataloader returns (B:batch, S:seq, C:features, N:points)
    dataTransformations = transforms.Compose([
        ToSeries(),
        DataAugmentation(),
        Resampling(maxPointsPerFrame=10),
        ToTensor()
    ])

    # init dataset
    nusc = NuScenes(version=args.nuscenes_eval_dir,
                    dataroot=args.nuscenes_dir,
                    verbose=True)

    def get_ann_properties(nusc, ann_token: str):
        """Print and return (category name, first attribute name or '') of an annotation."""
        ann_rec = nusc.get('sample_annotation', ann_token)
        category = ann_rec['category_name']
        attr = [
            nusc.get('attribute', attr_token)['name']
            for attr_token in ann_rec['attribute_tokens']
        ]
        if not attr:
            attr = ['']
        print('object category:{:30} attr:{:40} center:{:20}'.format(
            category, str(attr),
            str(ann_rec['translation'] - radar_pc.A_cs_2_gl[:3, 3])))
        # Only the first attribute is used, even if an object has several.
        return category, attr[0]

    def get_obj_label(ann_token):
        """Index of the first desired (category, attribute) pair matching, or None."""
        cat, att = get_ann_properties(nusc, ann_token)
        for label, (desired_cat, desired_att) in enumerate(
                zip(args.categories, args.attributes)):
            if cat in desired_cat and att in desired_att:
                return label
        return None

    OBJCOLORS = ['red', 'green', 'blue', 'grey', 'orange', 'white']
    cameras = [
        'CAM_BACK_LEFT', 'CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
        'CAM_BACK_RIGHT', 'CAM_BACK'
    ]

    ''' Iterate over samples in random order '''
    sample_order = list(range(len(nusc.sample)))
    random.shuffle(sample_order)
    for i in sample_order:
        sample_rec = nusc.sample[i]
        sample_token = sample_rec['token']

        # Read all sensors' data and merge it into a common reference frame.
        sensorlist = [
            MyRadarPointCloud.from_sample_token(nusc, sample_token, sensor)
            for sensor in args.sensors
        ]
        radar_pc = MyRadarPointCloud.merge(sensorlist, 0)

        # Keep only annotations of interest, labelling each one exactly once.
        print('\n\nGround-truth objects on this frame:')
        ann_tokens, labels = [], []
        for ann_token in nusc.get('sample', sample_token)['anns']:
            label = get_obj_label(ann_token)
            if label is not None:
                ann_tokens.append(ann_token)
                labels.append(label)

        # create bounding boxes from annotations
        ann_boxes = []
        for ann_token, label in zip(ann_tokens, labels):
            box = radar_pc.box(ann_token)
            box.label = label
            box.name = args.objectNames[label]
            ann_boxes.append(box)

        ''' APPLY DBSCAN ON EACH SCENE'''
        points_scene = radar_pc.points  # points_gl = <5,N>
        # Filter out slow points: keep those whose speed (norm of rows 2:4)
        # exceeds 0.25 m/s.
        idx_moving = np.linalg.norm(points_scene[2:4, :].T, axis=1) > 0.25
        points_scene = points_scene[:, idx_moving]

        ''' For each cluster, run network and predict class '''
        pred_results = []
        if points_scene.shape[1] > 0:  # DBSCAN.fit rejects an empty sample set
            clustering = DBSCAN(eps=3, min_samples=1).fit(points_scene[:2, :].T)
            cluster_ids = np.unique(clustering.labels_)
            # Visit every cluster id (including the largest one), skipping
            # DBSCAN's noise label -1.
            with torch.no_grad():
                for cluster_idx in cluster_ids[cluster_ids >= 0]:
                    points_obj = points_scene[:, clustering.labels_ == cluster_idx]
                    features = dataTransformations(points_obj)
                    features = torch.tensor(features).type(
                        torch.FloatTensor).unsqueeze(0).to(device)
                    pred = classifier(features).argmax().item()
                    pred_results.append((points_obj, pred))

        ''' RENDER '''
        fig = plt.figure(constrained_layout=True)
        gs = GridSpec(2, 6, figure=fig)
        gs.update(wspace=0, hspace=0)
        # axs[0], axs[1]: bottom half (two bird's-eye plots); axs[2:]: camera row.
        axs = [fig.add_subplot(gs[1, :3]), fig.add_subplot(gs[1, 3:])]
        axs += [fig.add_subplot(gs[0, col]) for col in range(6)]

        # render annotation boxes
        for box in ann_boxes:
            color = OBJCOLORS[box.label]
            box.render(axs[0], colors=(color, ) * 3)
            box.render(axs[1], colors=(color, ) * 3)
        # render cars
        radar_pc.car.render(axs[0], colors=('orange', 'k', 'k'))
        radar_pc.car.render(axs[1], colors=('orange', 'k', 'k'))
        # render pointcloud
        radar_pc._render_pc(axs[0], radar_pc.points, color_channel='k')
        for cluster_points, pred in pred_results:
            # NOTE(review): class index 5 is not drawn — presumably a
            # background/"other" class; confirm against args.objectNames.
            if pred == 5:
                continue
            axs[1].scatter(cluster_points[0, :],
                           cluster_points[1, :],
                           s=30,
                           c=OBJCOLORS[pred],
                           edgecolors='k',
                           linewidths=1,
                           zorder=100)
            # An ellipse needs at least two points.
            if cluster_points.shape[1] > 1:
                confidence_ellipse(axs[1],
                                   cluster_points[0, :],
                                   cluster_points[1, :],
                                   edgecolor=OBJCOLORS[pred],
                                   facecolor=OBJCOLORS[pred],
                                   alpha=0.5)
        # render camera images
        for ax_idx, cam_name in enumerate(cameras):
            cam_token = sample_rec['data'][cam_name]
            cam_rec = nusc.get('sample_data', cam_token)
            img_filename = os.path.join(nusc.dataroot, cam_rec['filename'])
            img = mpimg.imread(img_filename)
            cam_ax = axs[ax_idx + 2]
            cam_ax.get_xaxis().set_visible(False)
            cam_ax.get_yaxis().set_visible(False)
            cam_ax.imshow(img)
            cam_ax.set_title(cam_name)
        # format
        legend_elements = [
            Line2D([0], [0], color=color, lw=4, label=name)
            for color, name in zip(OBJCOLORS, args.objectNames)
        ]
        axs[0].legend(handles=legend_elements,
                      loc='upper right',
                      prop={'size': 8})
        axs[1].legend(handles=legend_elements,
                      loc='upper right',
                      prop={'size': 8})
        axs[0].axis('equal')
        axs[1].axis('equal')
        mng = plt.get_current_fig_manager()
        try:
            mng.window.showMaximized()
        except AttributeError:
            # Best effort: only Qt-based backends expose window.showMaximized().
            pass
        plt.tight_layout()
        lims = (-60, 60)
        axs[0].set_xlim(lims)
        axs[0].set_ylim(lims)
        axs[0].set_title(
            'Radar detection points and ground truth bounding boxes of MOVING objects'
        )
        axs[1].set_xlim(lims)
        axs[1].set_ylim(lims)
        axs[1].set_title(
            'Segmented segmentation of detection points (DBSCAN + PointNet classifier)\n3-sigma ellipses for each cluster and ground-truth bbs'
        )
        plt.tight_layout()
        plt.show()

    # NOTE(review): ``tb_writer`` is never created in this function; close it
    # only if some other module code defined it, instead of raising NameError.
    writer = globals().get('tb_writer')
    if writer is not None:
        writer.close()
예제 #5
0
def main(args):
    """Train a PointNet / PointNet++ classifier on the 20-class ObjectNN HDF5 data.

    Creates ``./experiment/<model>ObjectNNClf-<timestamp>/`` with ``checkpoints/``
    and ``logs/`` sub-directories. Trains for epochs ``start_epoch .. args.epoch-1``;
    when ``args.pretrain`` is set, matching tensors are loaded from that
    checkpoint first. A checkpoint is saved whenever the test accuracy
    matches or beats the best so far, once the epoch index exceeds 5.

    Args:
        args: Parsed command-line namespace. Reads ``gpu``, ``rotation``,
            ``model_name``, ``batchsize``, ``feature_transform``, ``pretrain``,
            ``optimizer``, ``learning_rate``, ``decay_rate``, ``epoch`` and
            ``train_metric``.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    '''HYPER PARAMETER'''
    # Fail fast, before creating directories or loading data; otherwise
    # ``optimizer`` would be unbound when the scheduler is created.
    if args.optimizer not in ('SGD', 'Adam'):
        raise ValueError('Unsupported optimizer: %r' % (args.optimizer,))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # datapath = './data/ModelNet/'
    datapath = './data/objecnn20_data_hdf5_2048/'
    if args.rotation is not None:
        # Two 2-digit bounds read from characters [0:2] and [3:5].
        ROTATION = (int(args.rotation[0:2]), int(args.rotation[3:5]))
    else:
        ROTATION = None

    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    file_dir = Path(str(experiment_dir) + '/%sObjectNNClf-' % args.model_name + str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    # Use the ``args`` passed in; re-parsing here would silently discard them.
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(str(log_dir) + '/train_%s_ObjectNNClf.txt' % args.model_name)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info('---------------------------------------------------TRANING---------------------------------------------------')
    logger.info('PARAMETER ...')
    logger.info(args)

    '''DATA LOADING'''
    logger.info('Load dataset ...')
    train_data, train_label, test_data, test_label = load_data(datapath, classification=True)
    logger.info("The number of training data is: %d", train_data.shape[0])
    logger.info("The number of test data is: %d", test_data.shape[0])
    trainDataset = ObjectNNDataLoader(train_data, train_label, rotation=ROTATION)
    if ROTATION is not None:
        print('The range of training rotation is', ROTATION)
    testDataset = ObjectNNDataLoader(test_data, test_label, rotation=ROTATION)

    trainDataLoader = torch.utils.data.DataLoader(trainDataset, batch_size=args.batchsize, shuffle=True)
    testDataLoader = torch.utils.data.DataLoader(testDataset, batch_size=args.batchsize, shuffle=False)

    '''MODEL LOADING'''
    num_class = 20
    if args.model_name == 'pointnet':
        classifier = PointNetCls(num_class, args.feature_transform).cuda()
    else:
        classifier = PointNet2ClsMsg().cuda()
    if args.pretrain is not None:
        print('Use pretrain model...')
        logger.info('Use pretrain model')
        checkpoint = torch.load(args.pretrain)
        start_epoch = checkpoint['epoch']
        model_dict = classifier.state_dict()
        # Partial load: keep only tensors whose name AND shape match the
        # current model. Then a checkpoint with a different classification
        # head (e.g. another class count) does not make load_state_dict fail.
        pretrained_dict = {
            k: v for k, v in checkpoint['model_state_dict'].items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        skipped = len(checkpoint['model_state_dict']) - len(pretrained_dict)
        if skipped:
            logger.info('Skipped %d pretrained tensors (unknown name or shape mismatch)', skipped)
        model_dict.update(pretrained_dict)
        classifier.load_state_dict(model_dict)
    else:
        print('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
    else:  # 'Adam' (validated above)
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    # Halve the learning rate every 20 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    global_epoch = 0
    global_step = 0
    best_tst_accuracy = 0.0
    blue = lambda x: '\033[94m' + x + '\033[0m'

    '''TRANING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        print('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        logger.info('Epoch %d (%d/%s):', global_epoch + 1, epoch + 1, args.epoch)

        classifier = classifier.train()
        loss = None  # stays None if the training loader yields no batches
        for _batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            pred, trans_feat, _global_feature = classifier(points)
            loss = F.nll_loss(pred, target.long())
            if args.feature_transform and args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()
            global_step += 1
        # Since PyTorch 1.1 the LR scheduler must step after the epoch's
        # optimizer steps; stepping first skips the initial learning rate.
        scheduler.step()

        # Always evaluate in eval mode, even when the train metric is skipped.
        classifier = classifier.eval()
        train_acc = test(classifier, trainDataLoader) if args.train_metric else None
        acc = test(classifier, testDataLoader)

        if loss is not None:
            print('\r Loss: %f' % loss.data)
            logger.info('Loss: %.2f', loss.data)
        if args.train_metric:
            print('Train Accuracy: %f' % train_acc)
            logger.info('Train Accuracy: %f', (train_acc))
        print('\r Test %s: %f' % (blue('Accuracy'), acc))
        logger.info('Test Accuracy: %f', acc)

        if (acc >= best_tst_accuracy) and epoch > 5:
            best_tst_accuracy = acc
            logger.info('Save model...')
            save_checkpoint(
                epoch + 1,  # true epoch count, also correct when resuming
                train_acc if args.train_metric else 0.0,
                acc,
                classifier,
                optimizer,
                str(checkpoints_dir),
                args.model_name)
            print('Saving model....')
        global_epoch += 1
    print('Best Accuracy: %f' % best_tst_accuracy)

    logger.info('End of training...')
예제 #6
0
파일: ex_clf.py 프로젝트: GryffindorLi/MPAN
def main(testDataset):
    """Evaluate a classifier on *testDataset* and return what ``test`` extracts.

    Options come from ``parse_args()``. The model is ``PointNetCls``
    (40 classes) when ``args.model_name`` is ``'pointnet'``, otherwise
    ``PointNet2ClsMsg``. It is optionally wrapped in ``DataParallel`` over
    ``args.multi_gpu`` and loaded from ``args.pretrain`` when given. The
    test accuracy is printed and logged to
    ``./experiment/logs/train_<model>_<timestamp>.txt``.

    Args:
        testDataset: A ``torch.utils.data.Dataset`` to evaluate on (iterated in order).

    Returns:
        The second value returned by ``test(classifier, loader)`` —
        presumably the extracted features (TODO confirm against ``test``).
    """
    args = parse_args()
    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else '0,1,2,3'

    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = Path('./experiment/checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = Path('./experiment/logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    logger = logging.getLogger("PointNet2")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('./experiment/logs/train_%s_' % args.model_name + str(datetime.datetime.now().strftime('%Y-%m-%d %H-%M')) + '.txt')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info('---------------------------------------------------TRANING---------------------------------------------------')
    logger.info('PARAMETER ...')
    logger.info(args)

    '''DATA LOADING'''
    logger.info('Load dataset ...')
    testDataLoader = torch.utils.data.DataLoader(testDataset, batch_size=args.batchsize, shuffle=False)

    '''MODEL LOADING'''
    num_class = 40
    if args.model_name == 'pointnet':
        classifier = PointNetCls(num_class, args.feature_transform)
    else:
        classifier = PointNet2ClsMsg()

    '''GPU selection and multi-GPU'''
    if args.multi_gpu is not None:
        device_ids = [int(x) for x in args.multi_gpu.split(',')]
        torch.backends.cudnn.benchmark = True
        classifier.cuda(device_ids[0])
        classifier = torch.nn.DataParallel(classifier, device_ids=device_ids)
    else:
        classifier.cuda()

    ''' Use the pretrained model '''
    if args.pretrain is not None:
        print('Use pretrain model...')
        logger.info('Use pretrain model')
        checkpoint = torch.load(args.pretrain)
        # NOTE(review): with multi_gpu the keys must carry DataParallel's
        # "module." prefix; confirm how the checkpoint was saved.
        classifier.load_state_dict(checkpoint['model_state_dict'])
    else:
        # NOTE(review): evaluation then runs on randomly initialised weights.
        print('No existing model, starting training from scratch...')

    blue = lambda x: '\033[94m' + x + '\033[0m'

    logger.info('Start training...')
    # Evaluate in inference mode regardless of args.train_metric.
    classifier.eval()
    if args.train_metric:
        # No training loader is built in this function, so the train metric
        # cannot be computed (it previously raised NameError).
        logger.warning('train_metric requested but no training data is loaded; skipping Train Accuracy')
    acc, fts = test(classifier, testDataLoader)
    print('\r Test %s: %f' % (blue('Accuracy'), acc))
    logger.info('Test Accuracy: %f', acc)
    return fts
예제 #7
0
def main(args):
    """Report the accuracy of a pretrained classifier on the ModelNet40 test split.

    The model is ``PointNetCls`` (40 classes) for ``args.model_name == 'pointnet'``,
    otherwise ``PointNet2ClsMsg``. Weights come from ``args.pretrain`` (required).
    Results are printed and logged to
    ``./experiment/logs/test_<model>_<timestamp>.txt``.

    Args:
        args: Parsed command-line namespace. Reads ``gpu``, ``rotation``,
            ``model_name``, ``batchsize``, ``feature_transform`` and ``pretrain``.

    Returns:
        None. Prints a hint and returns early when ``args.pretrain`` is None.
    """
    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # datapath = './data/ModelNet/'
    datapath = './data/modelnet40_ply_hdf5_2048/'
    if args.rotation is not None:
        # Two 2-digit bounds read from characters [0:2] and [3:5].
        ROTATION = (int(args.rotation[0:2]), int(args.rotation[3:5]))
    else:
        ROTATION = None
    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = Path('./experiment/checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = Path('./experiment/logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    # Use the ``args`` passed in; re-parsing here would silently discard them.
    logger = logging.getLogger("PointNet2")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(
        './experiment/logs/test_%s_' % args.model_name +
        str(datetime.datetime.now().strftime('%Y-%m-%d %H-%M')) + '.txt')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------Test---------------------------------------------------'
    )
    logger.info('PARAMETER ...')
    logger.info(args)

    # Testing requires weights: fail fast before loading any data.
    if args.pretrain is None:
        print('Please Input the pretrained model ***.pth')
        return

    '''DATA LOADING'''
    logger.info('Load dataset ...')
    train_data, train_label, test_data, test_label = load_data(
        datapath, classification=True)
    logger.info("The number of training data is: %d", train_data.shape[0])
    logger.info("The number of test data is: %d", test_data.shape[0])
    if ROTATION is not None:
        print('The range of training rotation is', ROTATION)
    testDataset = ModelNetDataLoader(test_data, test_label, rotation=ROTATION)
    testDataLoader = torch.utils.data.DataLoader(testDataset,
                                                 batch_size=args.batchsize,
                                                 shuffle=False)  # keep order
    '''MODEL LOADING'''
    num_class = 40
    if args.model_name == 'pointnet':
        classifier = PointNetCls(num_class, args.feature_transform).cuda()
    else:
        classifier = PointNet2ClsMsg().cuda()
    print('Use pretrain model...')
    logger.info('Use pretrain model')
    checkpoint = torch.load(args.pretrain)
    classifier.load_state_dict(checkpoint['model_state_dict'])

    blue = lambda x: '\033[94m' + x + '\033[0m'
    '''TESTING'''
    logger.info('Start testing...')

    classifier.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        acc = test(classifier, testDataLoader)

    print('\r Test %s: %f' % (blue('Accuracy'), acc))
    logger.info('Test Accuracy: %f', acc)

    logger.info('End of testing...')
예제 #8
0
def main(args):
    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # datapath = './data/ModelNet/'
    datapath = './data/objecnn20_data_hdf5_2048/'
    if args.rotation is not None:
        ROTATION = (int(args.rotation[0:2]), int(args.rotation[3:5]))
    else:
        ROTATION = None
    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    file_dir = Path(
        str(experiment_dir) + '/Test_%sObjectNNClf-' % args.model_name +
        str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(
        str(log_dir) + '/test_%s_ObjectNNClf.txt' % args.model_name)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------Test---------------------------------------------------'
    )
    logger.info('PARAMETER ...')
    logger.info(args)
    '''DATA LOADING'''
    logger.info('Load dataset ...')
    database_data, database_label, query_data, query_label = load_data(
        datapath, classification=True)

    print(">>>>>>>>>database_data:", database_data.shape)
    print(">>>>>>>>>query_data:", query_data.shape)

    logger.info("The number of database_data data is: %d",
                database_data.shape[0])
    logger.info("The number of query_data data is: %d", query_data.shape[0])

    ###################### 加载 database 和 query ######################
    databaseDataset = TestQueryObjectNNDataLoader(database_data,
                                                  database_label,
                                                  rotation=ROTATION)
    if ROTATION is not None:
        print('The range of training rotation is', ROTATION)
    queryDataset = TestQueryObjectNNDataLoader(query_data,
                                               query_label,
                                               rotation=ROTATION)

    databaseDataLoader = torch.utils.data.DataLoader(databaseDataset,
                                                     batch_size=args.batchsize,
                                                     shuffle=False)
    queryDataLoader = torch.utils.data.DataLoader(queryDataset,
                                                  batch_size=args.batchsize,
                                                  shuffle=False)  # 不打乱
    '''MODEL LOADING'''
    num_class = 20
    ###################### PointNetCls ######################
    classifier = PointNetCls(num_class, args.feature_transform).cuda(
    ) if args.model_name == 'pointnet' else PointNet2ClsMsg().cuda()
    # classifier = PointNetCls(num_class,args.feature_transform) if args.model_name == 'pointnet' else PointNet2ClsMsg()

    if args.pretrain is not None:
        print('Use pretrain model...')
        logger.info('Use pretrain model')
        checkpoint = torch.load(args.pretrain)
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
    else:
        print('Please Input the pretrained model ***.pth')
        return

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.5)
    global_epoch = 0
    global_step = 0
    best_tst_accuracy = 0.0
    blue = lambda x: '\033[94m' + x + '\033[0m'
    '''QueryING'''
    logger.info('Start query...')

    scheduler.step()

    ###################### Query ######################
    classifier.eval()

    _acc, database_feature_martix = getGlobalFeature('database',
                                                     classifier.eval(),
                                                     databaseDataLoader)

    _acc, query_feature_martix = getGlobalFeature('query', classifier.eval(),
                                                  queryDataLoader)