def main(args, stage=None, pretrain_model_path=None, e_init=0, e_final=2):

    # Load the pickled class-label mapping file
    class_dict_fname = F_CLASS_DICT_PKL
    print(class_dict_fname)
    with open(class_dict_fname, "rb") as f:
        class_dict = pickle.load(f)

    print("CLASS DICT: {}".format(class_dict))

    # Map numeric class indices --> semantic class names
    seg_classes = class_dict
    seg_label_to_cat = {i: cat for i, cat in enumerate(seg_classes.values())}

    print('SEG LABEL', seg_label_to_cat)

    # Load the per-class count file used to derive loss weights
    with open(F_CLASS_WEIGHTS_PKL, "rb") as f:
        class_weights = pickle.load(f)
    print('SEG CLASSES', seg_classes)
    COUNTS = np.array(list(class_weights.values()))
    weight_normalizer = np.max(COUNTS)

    # Inverse-frequency weighting: rarer classes receive larger weights
    weights = []
    for count in COUNTS:
        if count != 0:
            weights.append(weight_normalizer / count)
        else:
            weights.append(0)

    # Cap weights at THRESHOLD so rare classes do not dominate the loss
    WEIGHTS_NP = np.array(weights)
    WEIGHTS_NP[WEIGHTS_NP > THRESHOLD] = THRESHOLD
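    # Worked example (illustrative numbers, not taken from the dataset):
    #   COUNTS = [1000, 10, 0] -> weight_normalizer = 1000
    #   raw weights = [1.0, 100.0, 0]; with THRESHOLD = 50 -> [1.0, 50.0, 0.0]
    # Zero-count classes keep weight 0, so they never contribute to the loss.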

    print("WEIGHTS ARE: {}".format(WEIGHTS_NP))

    # Convert the weights to a PyTorch tensor
    weights = torch.from_numpy(WEIGHTS_NP.astype('float32'))

    if USE_CLI:
        gpu = args.gpu
        multi_gpu = args.multi_gpu
        batch_size = args.batch_size
        batchsize = args.batchsize
        model_name = args.model_name
        optimizer = args.optimizer
        learning_rate = args.learning_rate
        pretrain = args.pretrain
        decay_rate = args.decay_rate
        epochs = args.epochs
    else:
        gpu = GPU
        multi_gpu = MULTI_GPU
        batch_size = BATCH_SIZE
        batchsize = BATCH_SIZE
        model_name = MODEL_NAME
        optimizer = OPTIMIZER
        learning_rate = LEARNING_RATE
        pretrain = PRETRAIN
        decay_rate = DECAY_RATE
        epochs = EPOCHS

    os.environ[
        "CUDA_VISIBLE_DEVICES"] = gpu if multi_gpu is None else '0,1,2,3'
    '''CREATE DIR'''
    experiment_dir = Path('./experiment/{}'.format(EXPERIMENT_HEADER))
    experiment_dir.mkdir(exist_ok=True)
    file_dir = Path(
        str(experiment_dir) + '/%sSemSeg-' % model_name +
        str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    if USE_CLI:
        args = parse_args()
        logger = logging.getLogger(model_name)
    else:
        logger = logging.getLogger(MODEL_NAME)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    if USE_CLI:
        file_handler = logging.FileHandler(
            str(log_dir) + '/train_%s_semseg.txt' % args.model_name)
    else:
        file_handler = logging.FileHandler(
            str(log_dir) + '/train_%s_semseg.txt' % MODEL_NAME)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------TRAINING---------------------------------------------------'
    )
    if USE_CLI:
        logger.info('PARAMETER ...')
        logger.info(args)
    print('Load data...')

    # Load the pickled dataset
    if USE_CLI:
        f_in = args.data_path
    else:
        f_in = DATA_PATH

    with open(f_in, "rb") as f:
        DATA = pickle.load(f)

    # Adjust labels according to the current training stage
    labels_fpath = os.path.join("..", "data",
                                "stage_{}_labels.pkl".format(stage))

    # Load the stage-specific label converter
    with open(labels_fpath, "rb") as f:
        labels = pickle.load(f)
        print(type(labels), len(labels))

    # Swap in the stage-specific labels for every sample
    keys = list(DATA.keys())
    for key in keys:
        DATA[key]['labels'] = labels[key]

    random_seed = 42
    indices = list(range(len(DATA)))
    np.random.seed(random_seed)
    np.random.shuffle(indices)
    TEST_SPLIT = 0.2
    test_index = int(np.floor(TEST_SPLIT * len(DATA)))
    print("test split index is: {}".format(test_index))
    train_indices, test_indices = indices[test_index:], indices[:test_index]
    if USE_CLI:
        print("LEN TRAIN: {}, LEN TEST: {}, EPOCHS: {}, OPTIMIZER: {}, "
              "DECAY_RATE: {}, LEARNING RATE: {}, DATA PATH: {}".format(
                  len(train_indices), len(test_indices), epochs,
                  args.optimizer, args.decay_rate, args.learning_rate,
                  args.data_path))
    else:
        print("LEN TRAIN: {}, LEN TEST: {}, EPOCHS: {}, OPTIMIZER: {}, "
              "DECAY_RATE: {}, LEARNING RATE: {}, DATA PATH: {}".format(
                  len(train_indices), len(test_indices), e_final - e_init,
                  OPTIMIZER, DECAY_RATE, LEARNING_RATE, DATA_PATH))

    # Creating PT data samplers and loaders:
    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)
    print("INTERSECTION OF TRAIN/TEST (should be 0): {}".format(
        len(set(train_indices).intersection(set(test_indices)))))

    # Training dataset
    dataset = A2D2DataLoader(DATA)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batchsize,
                                             shuffle=False,
                                             sampler=train_sampler,
                                             collate_fn=collate_fn)
    # Test dataset
    test_dataset = A2D2DataLoader(DATA)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=batchsize,
                                                 shuffle=False,
                                                 sampler=test_sampler,
                                                 collate_fn=collate_fn)
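    # Note: shuffle must stay False when a sampler is supplied --
    # SubsetRandomSampler already randomizes the order within each split.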

    num_classes = NUM_CLASSES

    blue = lambda x: '\033[94m' + x + '\033[0m'
    model = PointNet2SemSeg(
        num_classes) if model_name == 'pointnet2' else PointNetSeg(
            num_classes, feature_transform=True, semseg=True)

    if pretrain_model_path is not None:
        model.load_state_dict(torch.load(pretrain_model_path))
        print('load model %s' % pretrain_model_path)
        logger.info('load model %s' % pretrain_model_path)
    else:
        print('Training from scratch')
        logger.info('Training from scratch')
    init_epoch = e_init

    if optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=decay_rate)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(optimizer))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.5)
    LEARNING_RATE_CLIP = 1e-5
    '''GPU selection and multi-GPU'''
    if multi_gpu is not None:
        device_ids = [int(x) for x in multi_gpu.split(',')]
        torch.backends.cudnn.benchmark = True
        model.cuda(device_ids[0])
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    else:
        model.cuda()

    history = defaultdict(lambda: list())
    best_acc = 0
    best_meaniou = 0
    graph_losses = []
    steps = []
    step = 0
    save_path = None  # set when the first best-accuracy checkpoint is written
    print("NUMBER OF EPOCHS IS: {}".format(e_final - e_init))
    for epoch in range(e_init, e_final):
        scheduler.step()
        lr = max(optimizer.param_groups[0]['lr'], LEARNING_RATE_CLIP)
        print('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        counter = 0
        # Init confusion matrix
        if USE_CONMAT:
            conf_matrix = torch.zeros(NUM_CLASSES, NUM_CLASSES)
        for points, targets in tqdm(dataloader):
            points, targets = Variable(points.float()), Variable(
                targets.long())
            points = points.transpose(2, 1)
            points, targets = points.cuda(), targets.cuda()
            weights = weights.cuda()
            optimizer.zero_grad()  # clear accumulated gradients
            model = model.train()
            if model_name == 'pointnet':
                pred, trans_feat = model(points)
            else:
                # Channels: xyz_norm (first three) | rgb_norm (second three)
                pred = model(points[:, :3, :], points[:, 3:, :])
            if USE_CONMAT:
                conf_matrix = confusion_matrix(pred, targets, conf_matrix)
            pred = pred.contiguous().view(-1, num_classes)
            targets = targets.view(-1, 1)[:, 0]
            loss = F.nll_loss(pred, targets,
                              weight=weights)  # Add class weights from dataset
            if model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001
            graph_losses.append(loss.cpu().data.numpy())
            steps.append(step)
            if counter % 100 == 0:
                print("LOSS IS: {}".format(loss.cpu().data.numpy()))
            history['loss'].append(loss.cpu().data.numpy())
            loss.backward()
            optimizer.step()
            counter += 1
            step += 1
        if USE_CONMAT:
            print("CONFUSION MATRIX: \n {}".format(conf_matrix))
        pointnet2 = model_name == 'pointnet2'
        test_metrics, test_hist_acc, cat_mean_iou = test_semseg(
            model.eval(), testdataloader, seg_label_to_cat,
            num_classes=num_classes, pointnet2=pointnet2)
        mean_iou = np.mean(cat_mean_iou)
        print('Epoch %d  %s accuracy: %f  meanIOU: %f' %
              (epoch, blue('test'), test_metrics['accuracy'], mean_iou))
        logger.info('Epoch %d  %s accuracy: %f  meanIOU: %f' %
                    (epoch, 'test', test_metrics['accuracy'], mean_iou))
        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']
            save_path = '%s/%s_%.3d_%.4f_stage_%s.pth' % (
                checkpoints_dir, model_name, epoch, best_acc, stage)
            torch.save(model.state_dict(), save_path)
            logger.info(cat_mean_iou)
            logger.info('Save model..')
            print('Save model..')
            print(cat_mean_iou)
        if mean_iou > best_meaniou:
            best_meaniou = mean_iou
        print('Best accuracy is: %.5f' % best_acc)
        logger.info('Best accuracy is: %.5f' % best_acc)
        print('Best meanIOU is: %.5f' % best_meaniou)
        logger.info('Best meanIOU is: %.5f' % best_meaniou)
        if USE_CONMAT:
            logger.info('Confusion matrix is: \n {}'.format(conf_matrix))

        # Plot loss vs. steps
        plt.plot(steps, graph_losses)
        plt.xlabel("Batched Steps (Batch Size = {})".format(batch_size))
        plt.ylabel("Multiclass NLL Loss")
        plt.title("NLL Loss vs. Number of Batched Steps")

        # Make directory for loss and other plots
        graphs_dir = os.path.join(experiment_dir, "graphs")
        os.makedirs(graphs_dir, exist_ok=True)

        # Save and close figure
        plt.savefig(os.path.join(graphs_dir, "losses.png"))
        plt.clf()
    return save_path
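
A minimal sketch of how this staged entry point might be driven. The stage list, epoch ranges, and the chaining of the returned `save_path` into `pretrain_model_path` are assumptions suggested by the signature, not confirmed usage.

# Hypothetical driver: each stage warm-starts from the checkpoint
# returned by the previous stage (stage/epoch values are illustrative).
if __name__ == '__main__':
    cli_args = parse_args() if USE_CLI else None
    ckpt = None
    for stage, (e_init, e_final) in enumerate([(0, 2), (2, 4)]):
        ckpt = main(cli_args, stage=stage, pretrain_model_path=ckpt,
                    e_init=e_init, e_final=e_final)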
# Example #2
def train(args):
    experiment_dir = mkdir('./experiment/')
    checkpoints_dir = mkdir('./experiment/semseg/%s/'%(args.model_name))
    train_data, train_label, test_data, test_label = _load()

    dataset = S3DISDataLoader(train_data, train_label, data_augmentation=args.augment)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)

    test_dataset = S3DISDataLoader(test_data, test_label)
    testdataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)

    num_classes = 13
    if args.model_name == 'pointnet':
        model = PointNetSeg(num_classes, feature_transform=True, input_dims = 9)
    else:
        model = PointNet2SemSeg(num_classes, feature_dims = 6)

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model).cuda()
    log.debug('Using gpu:',args.gpu)

    if args.pretrain is not None:
        log.debug('Use pretrain model...')
        model.load_state_dict(torch.load(args.pretrain))
        init_epoch = int(args.pretrain[:-4].split('-')[-1])
        log.debug('start epoch from', init_epoch)
    else:
        log.debug('Training from scratch')
        init_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate)
            
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    LEARNING_RATE_CLIP = 1e-5

    history = {'loss':[]}
    best_acc = 0
    best_meaniou = 0

    for epoch in range(init_epoch,args.epoch):
        scheduler.step()
        lr = max(optimizer.param_groups[0]['lr'],LEARNING_RATE_CLIP)

        log.info(job='semseg',model=args.model_name,gpu=args.gpu,epoch='%d/%s' % (epoch, args.epoch),lr=lr)
        
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        
        for points, target in tqdm(dataloader, total=len(dataloader), smoothing=0.9, dynamic_ncols=True):
            points, target = points.float(), target.long()
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            model = model.train()

            if args.model_name == 'pointnet':
                pred, trans_feat = model(points)
            else:
                pred = model(points)

            pred = pred.contiguous().view(-1, num_classes)
            target = target.view(-1, 1)[:, 0]
            loss = F.nll_loss(pred, target)

            if args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001

            history['loss'].append(loss.cpu().data.numpy())
            loss.backward()
            optimizer.step()
        
        log.debug('clear cuda cache')
        torch.cuda.empty_cache()

        test_metrics, cat_mean_iou = test_semseg(
            model.eval(), 
            testdataloader, 
            label_id_to_name,
            args.model_name,
            num_classes,
        )
        mean_iou = np.mean(cat_mean_iou)

        save_model = False
        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']
        
        if mean_iou > best_meaniou:
            best_meaniou = mean_iou
            save_model = True
        
        if save_model:
            fn_pth = 'semseg-%s-%.5f-%04d.pth' % (args.model_name, best_meaniou, epoch)
            log.info('Save model...',fn = fn_pth)
            torch.save(model.state_dict(), os.path.join(checkpoints_dir, fn_pth))
            log.warn(cat_mean_iou)
        else:
            log.info('No need to save model')
            log.warn(cat_mean_iou)

        log.warn('Curr',accuracy=test_metrics['accuracy'], meanIOU=mean_iou)
        log.warn('Best',accuracy=best_acc, meanIOU=best_meaniou)
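
The `args` namespace above implies roughly the following CLI. This is a hypothetical reconstruction assuming `parse_args` follows argparse conventions; the option names come from the `args.*` attribute accesses, while the defaults are guesses.

import argparse

def parse_args():
    # Hypothetical CLI matching the attribute accesses in train() above.
    parser = argparse.ArgumentParser('semseg training')
    parser.add_argument('--model_name', default='pointnet2')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--decay_rate', type=float, default=1e-4)
    parser.add_argument('--optimizer', default='Adam')
    parser.add_argument('--pretrain', default=None)
    parser.add_argument('--gpu', default='0')
    parser.add_argument('--augment', action='store_true')
    return parser.parse_args()

# Example #3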
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else '0,1,2,3'
    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    file_dir = Path(str(experiment_dir) +'/%sSemSeg-'%args.model_name+ str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    args = parse_args()
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(str(log_dir) + '/train_%s_semseg.txt'%args.model_name)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info('---------------------------------------------------TRAINING---------------------------------------------------')
    logger.info('PARAMETER ...')
    logger.info(args)
    print('Load data...')
    train_data, train_label, test_data, test_label = recognize_all_data(test_area = 5)

    dataset = S3DISDataLoader(train_data,train_label)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batchsize,
                                             shuffle=True, num_workers=int(args.workers))
    test_dataset = S3DISDataLoader(test_data,test_label)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batchsize,
                                                 shuffle=True, num_workers=int(args.workers))

    num_classes = 13
    blue = lambda x: '\033[94m' + x + '\033[0m'
    model = PointNet2SemSeg(num_classes) if args.model_name == 'pointnet2' else PointNetSeg(num_classes,feature_transform=True,semseg = True)

    if args.pretrain is not None:
        model.load_state_dict(torch.load(args.pretrain))
        print('load model %s'%args.pretrain)
        logger.info('load model %s'%args.pretrain)
    else:
        print('Training from scratch')
        logger.info('Training from scratch')
    pretrain = args.pretrain
    init_epoch = int(pretrain[-14:-11]) if args.pretrain is not None else 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    LEARNING_RATE_CLIP = 1e-5

    '''GPU selection and multi-GPU'''
    if args.multi_gpu is not None:
        device_ids = [int(x) for x in args.multi_gpu.split(',')]
        torch.backends.cudnn.benchmark = True
        model.cuda(device_ids[0])
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    else:
        model.cuda()

    history = defaultdict(lambda: list())
    best_acc = 0
    best_meaniou = 0

    for epoch in range(init_epoch,args.epoch):
        scheduler.step()
        lr = max(optimizer.param_groups[0]['lr'],LEARNING_RATE_CLIP)
        print('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        for i, data in tqdm(enumerate(dataloader, 0),total=len(dataloader),smoothing=0.9):
            points, target = data
            points, target = Variable(points.float()), Variable(target.long())
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            model = model.train()
            if args.model_name == 'pointnet':
                pred, trans_feat = model(points)
            else:
                pred = model(points[:,:3,:],points[:,3:,:])
            pred = pred.contiguous().view(-1, num_classes)
            target = target.view(-1, 1)[:, 0]
            loss = F.nll_loss(pred, target)
            if args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001
            history['loss'].append(loss.cpu().data.numpy())
            loss.backward()
            optimizer.step()
        pointnet2 = args.model_name == 'pointnet2'
        test_metrics, test_hist_acc, cat_mean_iou = test_semseg(
            model.eval(), testdataloader, seg_label_to_cat,
            num_classes=num_classes, pointnet2=pointnet2)
        mean_iou = np.mean(cat_mean_iou)
        print('Epoch %d  %s accuracy: %f  meanIOU: %f' % (
                 epoch, blue('test'), test_metrics['accuracy'],mean_iou))
        logger.info('Epoch %d  %s accuracy: %f  meanIOU: %f' % (
                 epoch, 'test', test_metrics['accuracy'],mean_iou))
        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']
            torch.save(model.state_dict(), '%s/%s_%.3d_%.4f.pth' % (checkpoints_dir,args.model_name, epoch, best_acc))
            logger.info(cat_mean_iou)
            logger.info('Save model..')
            print('Save model..')
            print(cat_mean_iou)
        if mean_iou > best_meaniou:
            best_meaniou = mean_iou
        print('Best accuracy is: %.5f'%best_acc)
        logger.info('Best accuracy is: %.5f'%best_acc)
        print('Best meanIOU is: %.5f'%best_meaniou)
        logger.info('Best meanIOU is: %.5f'%best_meaniou)
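
A note on the `init_epoch` recovery near the top of this function: the slice `pretrain[-14:-11]` only works for checkpoint names shaped like the ones this script writes ('%s_%.3d_%.4f.pth', i.e. model, zero-padded epoch, accuracy). A quick check of that assumption with an illustrative path:

# Illustrative checkpoint name following the save pattern above.
path = 'checkpoints/pointnet2_042_0.8512.pth'
# '..._042_0.8512.pth'[-14:-11] picks out the zero-padded epoch field.
assert path[-14:-11] == '042'
print(int(path[-14:-11]))  # 42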
# Example #4
def main(args):
    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # datapath = './data/ModelNet/'  
    datapath = './data/objecnn20_data_hdf5_2048/'
    if args.rotation is not None:
        ROTATION = (int(args.rotation[0:2]),int(args.rotation[3:5]))
    else:
        ROTATION = None

    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    file_dir = Path(str(experiment_dir) +'/%sObjectNNClf-'%args.model_name+ str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    args = parse_args()
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(str(log_dir) + '/train_%s_ObjectNNClf.txt'%args.model_name)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info('---------------------------------------------------TRAINING---------------------------------------------------')
    logger.info('PARAMETER ...')
    logger.info(args)

    '''DATA LOADING'''
    logger.info('Load dataset ...')
    train_data, train_label, test_data, test_label = load_data(datapath, classification=True)
    logger.info("The number of training data is: %d",train_data.shape[0])
    logger.info("The number of test data is: %d", test_data.shape[0])
    trainDataset = ObjectNNDataLoader(train_data, train_label, rotation=ROTATION)
    if ROTATION is not None:
        print('The range of training rotation is',ROTATION)
    testDataset = ObjectNNDataLoader(test_data, test_label, rotation=ROTATION)

    trainDataLoader = torch.utils.data.DataLoader(trainDataset, batch_size=args.batchsize, shuffle=True)
    testDataLoader = torch.utils.data.DataLoader(testDataset, batch_size=args.batchsize, shuffle=False)

    '''MODEL LOADING'''
    num_class = 20
    classifier = PointNetCls(num_class,args.feature_transform).cuda() if args.model_name == 'pointnet' else PointNet2ClsMsg().cuda()
    if args.pretrain is not None:
        print('Use pretrain model...')
        logger.info('Use pretrain model')
        checkpoint = torch.load(args.pretrain)
        start_epoch = checkpoint['epoch']
        # classifier.load_state_dict(checkpoint['model_state_dict'])
        # print(checkpoint['model_state_dict'])
        model_dict = classifier.state_dict()
        pretrained_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if k in model_dict}
        model_dict.update(pretrained_dict)
        classifier.load_state_dict(model_dict)
    else:
        print('No existing model, starting training from scratch...')
        start_epoch = 0


    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)  # adjust the learning rate
    global_epoch = 0
    global_step = 0
    best_tst_accuracy = 0.0
    blue = lambda x: '\033[94m' + x + '\033[0m'

    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch,args.epoch):
        print('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        logger.info('Epoch %d (%d/%s):' ,global_epoch + 1, epoch + 1, args.epoch)

        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()
            pred, trans_feat, global_feature = classifier(points)
            loss = F.nll_loss(pred, target.long())
            if args.feature_transform and args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()
            global_step += 1

        train_acc = test(classifier.eval(), trainDataLoader) if args.train_metric else None
        acc = test(classifier, testDataLoader)


        print('\r Loss: %f' % loss.data)
        logger.info('Loss: %.2f', loss.data)
        if args.train_metric:
            print('Train Accuracy: %f' % train_acc)
            logger.info('Train Accuracy: %f', (train_acc))
        print('\r Test %s: %f' % (blue('Accuracy'),acc))
        logger.info('Test Accuracy: %f', acc)

        if (acc >= best_tst_accuracy) and epoch > 5:
            best_tst_accuracy = acc
            logger.info('Save model...')
            save_checkpoint(
                global_epoch + 1,
                train_acc if args.train_metric else 0.0,
                acc,
                classifier,
                optimizer,
                str(checkpoints_dir),
                args.model_name)
            print('Saving model....')
        global_epoch += 1
    print('Best Accuracy: %f'%best_tst_accuracy)

    logger.info('End of training...')
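
`save_checkpoint` is imported from the project's utilities and is not shown here. Below is a plausible minimal implementation matching the call site above; the filename pattern is an assumption, while the 'epoch' and 'model_state_dict' keys are implied by the pretrain-loading code earlier in this function.

import os
import torch

def save_checkpoint(epoch, train_accuracy, test_accuracy, model, optimizer,
                    path, model_name):
    # Bundle everything needed to resume training into one file.
    state = {
        'epoch': epoch,
        'train_acc': train_accuracy,
        'test_acc': test_accuracy,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(state, os.path.join(path, '%s-%04d.pth' % (model_name, epoch)))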
# Example #5
def train(args):
    experiment_dir = mkdir('experiment/')
    checkpoints_dir = mkdir('experiment/%s/' % (args.model_name))

    kitti_utils = Semantic_KITTI_Utils(KITTI_ROOT, subset=args.subset)
    class_names = kitti_utils.class_names
    num_classes = kitti_utils.num_classes

    if args.subset == 'inview':
        train_npts = 8000
        test_npts = 24000
    elif args.subset == 'all':
        train_npts = 50000
        test_npts = 100000
    else:
        raise ValueError('Unknown subset: {}'.format(args.subset))

    log.info(subset=args.subset, train_npts=train_npts, test_npts=test_npts)

    dataset = SemKITTI_Loader(KITTI_ROOT,
                              train_npts,
                              train=True,
                              subset=args.subset)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.workers,
                            pin_memory=True)

    test_dataset = SemKITTI_Loader(KITTI_ROOT,
                                   test_npts,
                                   train=False,
                                   subset=args.subset)
    testdataloader = DataLoader(test_dataset,
                                batch_size=int(args.batch_size / 2),
                                shuffle=False,
                                num_workers=args.workers,
                                pin_memory=True)

    if args.model_name == 'pointnet':
        model = PointNetSeg(num_classes, input_dims=4, feature_transform=True)
    else:
        model = PointNet2SemSeg(num_classes, feature_dims=1)

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-4)

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model)
    model.cuda()
    log.info('Using gpu:', args.gpu)

    if args.pretrain is not None:
        log.info('Use pretrain model...')
        model.load_state_dict(torch.load(args.pretrain))
        init_epoch = int(args.pretrain[:-4].split('-')[-1])
        log.info('Restart training', epoch=init_epoch)
    else:
        log.msg('Training from scratch')
        init_epoch = 0

    best_acc = 0
    best_miou = 0

    for epoch in range(init_epoch, args.epoch):
        model.train()
        lr = calc_decay(args.learning_rate, epoch)
        log.info(subset=args.subset,
                 model=args.model_name,
                 gpu=args.gpu,
                 epoch=epoch,
                 lr=lr)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        for points, target in tqdm(dataloader,
                                   total=len(dataloader),
                                   smoothing=0.9,
                                   dynamic_ncols=True):
            points = points.float().transpose(2, 1).cuda()
            target = target.long().cuda()

            if args.model_name == 'pointnet':
                logits, trans_feat = model(points)
            else:
                logits = model(points)

            #logits = logits.contiguous().view(-1, num_classes)
            #target = target.view(-1, 1)[:, 0]
            #loss = F.nll_loss(logits, target)

            logits = logits.transpose(2, 1)
            loss = nn.CrossEntropyLoss()(logits, target)

            if args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        torch.cuda.empty_cache()

        acc, miou = test_kitti_semseg(model.eval(), testdataloader,
                                      args.model_name, num_classes,
                                      class_names)

        save_model = False
        if acc > best_acc:
            best_acc = acc

        if miou > best_miou:
            best_miou = miou
            save_model = True

        if save_model:
            fn_pth = '%s-%s-%.5f-%04d.pth' % (args.model_name, args.subset,
                                              best_miou, epoch)
            log.info('Save model...', fn=fn_pth)
            torch.save(model.state_dict(),
                       os.path.join(checkpoints_dir, fn_pth))
        else:
            log.info('No need to save model')

        log.warn('Curr', accuracy=acc, mIOU=miou)
        log.warn('Best', accuracy=best_acc, mIOU=best_miou)
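
`calc_decay` is defined elsewhere in this project and replaces the StepLR scheduler used by the other examples. A sketch consistent with a step-style decay; the halving interval and floor mirror the StepLR(step_size=20, gamma=0.5) and LEARNING_RATE_CLIP used elsewhere, but are assumptions here.

def calc_decay(init_lr, epoch):
    # Hypothetical: halve the learning rate every 20 epochs, with a floor.
    return max(init_lr * (0.5 ** (epoch // 20)), 1e-5)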
# Example #6
def train(args):
    experiment_dir = mkdir('./experiment/')
    checkpoints_dir = mkdir('./experiment/clf/%s/' % (args.model_name))
    train_data, train_label, test_data, test_label = load_data(
        'experiment/data/modelnet40_ply_hdf5_2048/')

    trainDataset = ModelNetDataLoader(train_data,
                                      train_label,
                                      data_augmentation=args.augment)
    trainDataLoader = DataLoader(trainDataset,
                                 batch_size=args.batch_size,
                                 shuffle=True)

    testDataset = ModelNetDataLoader(test_data, test_label)
    testDataLoader = torch.utils.data.DataLoader(testDataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False)

    log.info('Building Model', args.model_name)
    if args.model_name == 'pointnet':
        num_class = 40
        model = PointNetCls(num_class, args.feature_transform).cuda()
    else:
        model = PointNet2ClsMsg().cuda()

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model).cuda()
    log.debug('Using gpu:', args.gpu)

    if args.pretrain is not None:
        log.info('Use pretrain model...')
        state_dict = torch.load(args.pretrain)
        model.load_state_dict(state_dict)
        init_epoch = int(args.pretrain[:-4].split('-')[-1])
        log.info('start epoch from', init_epoch)
    else:
        log.info('Training from scratch')
        init_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.5)
    LEARNING_RATE_CLIP = 1e-5

    global_epoch = 0
    global_step = 0
    best_tst_accuracy = 0.0

    log.info('Start training...')
    for epoch in range(init_epoch, args.epoch):
        scheduler.step()
        lr = max(optimizer.param_groups[0]['lr'], LEARNING_RATE_CLIP)

        log.debug(job='clf',
                  model=args.model_name,
                  gpu=args.gpu,
                  epoch='%d/%s' % (epoch, args.epoch),
                  lr=lr)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            model = model.train()
            pred, trans_feat = model(points)
            loss = F.nll_loss(pred, target.long())
            if args.feature_transform and args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()
            global_step += 1

        log.debug('clear cuda cache')
        torch.cuda.empty_cache()

        acc = test_clf(model, testDataLoader)
        log.info(loss='%.5f' % (loss.data))
        log.info(Test_Accuracy='%.5f' % acc)

        if acc >= best_tst_accuracy:
            best_tst_accuracy = acc
            fn_pth = 'clf-%s-%.5f-%04d.pth' % (args.model_name, acc, epoch)
            log.debug('Saving model....', fn_pth)
            torch.save(model.state_dict(),
                       os.path.join(checkpoints_dir, fn_pth))
        global_epoch += 1

    log.info(Best_Accuracy=best_tst_accuracy)
    log.info('End of training...')
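
`test_clf` comes from the project's evaluation utilities. A minimal sketch of what it likely computes (overall top-1 accuracy over the loader), written under that assumption.

import torch

def test_clf(model, loader):
    # Hypothetical evaluation loop: fraction of correctly classified clouds.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for points, target in loader:
            target = target[:, 0]
            points, target = points.transpose(2, 1).cuda(), target.cuda()
            pred = model(points)
            if isinstance(pred, tuple):  # pointnet also returns trans_feat
                pred = pred[0]
            correct += (pred.max(1)[1] == target.long()).sum().item()
            total += target.numel()
    return correct / total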
# Example #7
    def forward(self, pred, target, trans_feat, weight):
        loss = F.nll_loss(pred, target, weight=weight)
        mat_diff_loss = feature_transform_reguliarzer(trans_feat)
        total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale
        return total_loss
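
A sketch of the loss module this forward method likely belongs to; the class name and constructor are assumptions, while the body reproduces the method above.

import torch.nn as nn
import torch.nn.functional as F

class WeightedSegLoss(nn.Module):
    # Hypothetical wrapper: weighted NLL loss plus the feature-transform
    # regularizer, scaled by mat_diff_loss_scale.
    def __init__(self, mat_diff_loss_scale=0.001):
        super().__init__()
        self.mat_diff_loss_scale = mat_diff_loss_scale

    def forward(self, pred, target, trans_feat, weight):
        loss = F.nll_loss(pred, target, weight=weight)
        mat_diff_loss = feature_transform_reguliarzer(trans_feat)
        return loss + mat_diff_loss * self.mat_diff_loss_scale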
def main(args):
    os.environ[
        "CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else '0,1,2,3'
    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    file_dir = Path(
        str(experiment_dir) + '/%sScanNetSemSeg-' % args.model_name +
        str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(
        str(log_dir) + '/train_%s_semseg.txt' % args.model_name)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------TRAINING---------------------------------------------------'
    )
    logger.info('PARAMETER ...')
    logger.info(args)
    print('Load data...')

    dataset = ScannetDataset(root='./data', split='train', npoints=8192)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batchsize,
                                             shuffle=True,
                                             num_workers=int(args.workers))
    test_dataset = ScannetDatasetWholeScene(root='./data',
                                            split='test',
                                            npoints=8192)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=args.batchsize,
                                                 shuffle=False,
                                                 num_workers=int(args.workers))

    num_classes = 21
    model = PointNet2SemSeg(
        num_classes) if args.model_name == 'pointnet2' else PointNetSeg(
            num_classes, feature_transform=True, semseg=True)
    loss_function = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='none')
    #loss_function = torch.nn.CrossEntropyLoss(reduction='none')

    if args.pretrain is not None:
        model.load_state_dict(torch.load(args.pretrain))
        print('load model %s' % args.pretrain)
        logger.info('load model %s' % args.pretrain)
    else:
        print('Training from scratch')
        logger.info('Training from scratch')
    pretrain = args.pretrain
    init_epoch = int(pretrain[-14:-11]) if args.pretrain is not None else 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=100,
                                                gamma=0.5)
    LEARNING_RATE_CLIP = 1e-5
    '''GPU selection and multi-GPU'''
    if args.multi_gpu is not None:
        device_ids = [int(x) for x in args.multi_gpu.split(',')]
        torch.backends.cudnn.benchmark = True
        model.cuda(device_ids[0])
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    else:
        model.cuda()

    history = defaultdict(lambda: list())
    best_acc = 0
    best_acc_epoch = 0
    best_mIoU = 0
    best_mIoU_epoch = 0

    for epoch in range(init_epoch, args.epoch):
        scheduler.step()
        lr = max(optimizer.param_groups[0]['lr'], LEARNING_RATE_CLIP)
        print('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        train_loss_sum = 0.0
        train_acc_sum = 0.0
        for i, data in enumerate(dataloader):
            points, target, sample_weights = data
            points, target = points.float(), target.long()
            points = points.transpose(2, 1)
            points, target, sample_weights = points.cuda(), target.cuda(
            ), sample_weights.cuda()
            optimizer.zero_grad()
            model = model.train()
            if args.model_name == 'pointnet':
                pred, trans_feat = model(points)
            else:
                #pred = model(points,None)
                pred = model(points[:, :3, :], points[:, 3:, :])
                #pred = model(points[:,:3,:],None)
            pred = pred.contiguous().view(-1, num_classes)
            target = target.view(pred.size(0))
            weights = sample_weights.view(pred.size(0))
            loss = loss_function(pred, target)
            loss = loss * weights
            loss = torch.mean(loss)
            if args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001
            history['loss'].append(loss.item())
            train_loss_sum += loss.item()
            loss.backward()
            optimizer.step()
            # Train acc
            pred_val = torch.argmax(pred, 1)
            correct = torch.sum(
                ((pred_val == target) & (target > 0) & (weights > 0)).float())
            seen = torch.sum(((target > 0) & (weights > 0)).float()) + 1e-08
            train_acc = correct / seen if seen != 0 else correct
            train_acc_sum += train_acc.item()
            if (i + 1) % 5 == 0:
                print(
                    '[Epoch %d/%d] [Iteration %d/%d] TRAIN acc/loss: %f/%f ' %
                    (epoch + 1, args.epoch, i + 1, len(dataloader),
                     train_acc.item(), loss.item()))
                logger.info(
                    '[Epoch %d/%d] [Iteration %d/%d] TRAIN acc/loss: %f/%f ' %
                    (epoch + 1, args.epoch, i + 1, len(dataloader),
                     train_acc.item(), loss.item()))
        train_loss_avg = train_loss_sum / len(dataloader)
        train_acc_avg = train_acc_sum / len(dataloader)
        history['train_acc'].append(train_acc_avg)
        print('[Epoch %d/%d] TRAIN acc/loss: %f/%f ' %
              (epoch + 1, args.epoch, train_acc_avg, train_loss_avg))
        logger.info('[Epoch %d/%d] TRAIN acc/loss: %f/%f ' %
                    (epoch + 1, args.epoch, train_acc_avg, train_loss_avg))

        #Test acc
        test_losses = []
        total_correct = 0
        total_seen = 0
        total_correct_class = [0 for _ in range(num_classes)]
        total_seen_class = [0 for _ in range(num_classes)]
        total_intersection_class = [0 for _ in range(num_classes)]
        total_union_class = [0 for _ in range(num_classes)]

        total_correct_vox = 0
        total_seen_vox = 0
        total_seen_class_vox = [0 for _ in range(num_classes)]
        total_correct_class_vox = [0 for _ in range(num_classes)]
        total_intersection_class_vox = [0 for _ in range(num_classes)]
        total_union_class_vox = [0 for _ in range(num_classes)]

        labelweights = np.zeros(num_classes)
        labelweights_vox = np.zeros(num_classes)

        for j, data in enumerate(testdataloader):
            points, target, sample_weights = data
            points, target, sample_weights = points.float(), target.long(
            ), sample_weights.float()
            points = points.transpose(2, 1)
            points, target, sample_weights = points.cuda(), target.cuda(
            ), sample_weights.cuda()
            model = model.eval()
            if args.model_name == 'pointnet':
                pred, trans_feat = model(points)
            else:
                with torch.no_grad():
                    #pred = model(points,None)
                    pred = model(points[:, :3, :], points[:, 3:, :])
                    #pred = model(points[:,:3,:],None)
            pred_2d = pred.contiguous().view(-1, num_classes)
            target_1d = target.view(pred_2d.size(0))
            weights_1d = sample_weights.view(pred_2d.size(0))
            loss = loss_function(pred_2d, target_1d)
            loss = loss * weights_1d
            loss = torch.mean(loss)
            if args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001
            test_losses.append(loss.item())
            #first convert torch tensor to numpy array
            pred_np = pred.cpu().numpy()  #[B,N,C]
            target_np = target.cpu().numpy()  #[B,N]
            weights_np = sample_weights.cpu().numpy()  #[B,N]
            points_np = points.transpose(2, 1).cpu().numpy()  #[B,N,3]
            # point wise acc
            pred_val = np.argmax(pred_np, 2)  #[B,N]
            correct = np.sum((pred_val == target_np) & (target_np > 0)
                             & (weights_np > 0))
            total_correct += correct
            total_seen += np.sum((target_np > 0) & (weights_np > 0))

            tmp, _ = np.histogram(target_np, range(num_classes + 1))
            labelweights += tmp

            # point wise acc and IoU per class
            for l in range(num_classes):
                total_seen_class[l] += np.sum((target_np == l)
                                              & (weights_np > 0))
                total_correct_class[l] += np.sum((pred_val == l)
                                                 & (target_np == l)
                                                 & (weights_np > 0))
                total_intersection_class[l] += np.sum((pred_val == l)
                                                      & (target_np == l)
                                                      & (weights_np > 0))
                total_union_class[l] += np.sum((
                    (pred_val == l) | (target_np == l)) & (weights_np > 0))

            # voxel wise acc
            for b in range(target_np.shape[0]):
                _, uvlabel, _ = point_cloud_label_to_surface_voxel_label_fast(
                    points_np[b, weights_np[b, :] > 0, :],
                    np.concatenate(
                        (np.expand_dims(target_np[b, weights_np[b, :] > 0], 1),
                         np.expand_dims(pred_val[b, weights_np[b, :] > 0], 1)),
                        axis=1),
                    res=0.02)
                total_correct_vox += np.sum((uvlabel[:, 0] == uvlabel[:, 1])
                                            & (uvlabel[:, 0] > 0))
                total_seen_vox += np.sum(uvlabel[:, 0] > 0)
                tmp, _ = np.histogram(uvlabel[:, 0], range(num_classes + 1))
                labelweights_vox += tmp
                # voxel wise acc and IoU per class
                for l in range(num_classes):
                    total_seen_class_vox[l] += np.sum(uvlabel[:, 0] == l)
                    total_correct_class_vox[l] += np.sum((uvlabel[:, 0] == l)
                                                         & (uvlabel[:,
                                                                    1] == l))
                    total_intersection_class_vox[l] += np.sum(
                        (uvlabel[:, 0] == l) & (uvlabel[:, 1] == l))
                    total_union_class_vox[l] += np.sum((uvlabel[:, 0] == l)
                                                       | (uvlabel[:, 1] == l))

        test_loss = np.mean(test_losses)
        test_point_acc = total_correct / float(total_seen)
        history['test_point_acc'].append(test_point_acc)
        test_voxel_acc = total_correct_vox / float(total_seen_vox)
        history['test_voxel_acc'].append(test_voxel_acc)
        test_avg_class_point_acc = np.mean(
            np.array(total_correct_class[1:]) /
            (np.array(total_seen_class[1:], dtype=np.float64) + 1e-6))
        history['test_avg_class_point_acc'].append(test_avg_class_point_acc)
        test_avg_class_voxel_acc = np.mean(
            np.array(total_correct_class_vox[1:]) /
            (np.array(total_seen_class_vox[1:], dtype=np.float64) + 1e-6))
        history['test_avg_class_voxel_acc'].append(test_avg_class_voxel_acc)
        test_avg_class_point_IoU = np.mean(
            np.array(total_intersection_class[1:]) /
            (np.array(total_union_class[1:], dtype=np.float64) + 1e-6))
        history['test_avg_class_point_IoU'].append(test_avg_class_point_IoU)
        test_avg_class_voxel_IoU = np.mean(
            np.array(total_intersection_class_vox[1:]) /
            (np.array(total_union_class_vox[1:], dtype=np.float64) + 1e-6))
        history['test_avg_class_voxel_IoU'].append(test_avg_class_voxel_IoU)
        labelweights = labelweights[1:].astype(np.float32) / np.sum(
            labelweights[1:].astype(np.float32))
        labelweights_vox = labelweights_vox[1:].astype(np.float32) / np.sum(
            labelweights_vox[1:].astype(np.float32))
        #caliweights = np.array([0.388,0.357,0.038,0.033,0.017,0.02,0.016,0.025,0.002,0.002,0.002,0.007,0.006,0.022,0.004,0.0004,0.003,0.002,0.024,0.029])
        #test_cali_voxel_acc = np.average(np.array(total_correct_class_vox[1:])/(np.array(total_seen_class_vox[1:],dtype=np.float)+1e-6),weights=caliweights)
        #history['test_cali_voxel_acc'].append(test_cali_voxel_acc)
        #test_cali_point_acc = np.average(np.array(total_correct_class[1:])/(np.array(total_seen_class[1:],dtype=np.float)+1e-6),weights=caliweights)
        #history['test_cali_point_acc'].append(test_cali_point_acc)

        print('[Epoch %d/%d] TEST acc/loss: %f/%f ' %
              (epoch + 1, args.epoch, test_voxel_acc, test_loss))
        logger.info('[Epoch %d/%d] TEST acc/loss: %f/%f ' %
                    (epoch + 1, args.epoch, test_voxel_acc, test_loss))
        print('Whole scene point wise accuracy: %f' % (test_point_acc))
        logger.info('Whole scene point wise accuracy: %f' % (test_point_acc))
        print('Whole scene voxel wise accuracy: %f' % (test_voxel_acc))
        logger.info('Whole scene voxel wise accuracy: %f' % (test_voxel_acc))
        print('Whole scene class averaged point wise accuracy: %f' %
              (test_avg_class_point_acc))
        logger.info('Whole scene class averaged point wise accuracy: %f' %
                    (test_avg_class_point_acc))
        print('Whole scene class averaged voxel wise accuracy: %f' %
              (test_avg_class_voxel_acc))
        logger.info('Whole scene class averaged voxel wise accuracy: %f' %
                    (test_avg_class_voxel_acc))
        #print('Whole scene calibrated point wise accuracy: %f' % (test_cali_point_acc))
        #logger.info('Whole scene calibrated point wise accuracy: %f' % (test_cali_point_acc))
        #print('Whole scene calibrated voxel wise accuracy: %f' % (test_cali_voxel_acc))
        #logger.info('Whole scene calibrated voxel wise accuracy: %f' % (test_cali_voxel_acc))
        print('Whole scene class averaged point wise IoU: %f' %
              (test_avg_class_point_IoU))
        logger.info('Whole scene class averaged point wise IoU: %f' %
                    (test_avg_class_point_IoU))
        print('Whole scene class averaged voxel wise IoU: %f' %
              (test_avg_class_voxel_IoU))
        logger.info('Whole scene class averaged voxel wise IoU: %f' %
                    (test_avg_class_voxel_IoU))

        per_class_voxel_str = 'voxel based --------\n'
        for l in range(1, num_classes):
            per_class_voxel_str += 'class %d weight: %f, acc: %f, IoU: %f;\n' % (
                l, labelweights_vox[l - 1], total_correct_class_vox[l] / float(
                    total_seen_class_vox[l]), total_intersection_class_vox[l] /
                (float(total_union_class_vox[l]) + 1e-6))
        logger.info(per_class_voxel_str)

        per_class_point_str = 'point based --------\n'
        for l in range(1, num_classes):
            per_class_point_str += 'class %d weight: %f, acc: %f, IoU: %f;\n' % (
                l, labelweights[l - 1], total_correct_class[l] /
                float(total_seen_class[l]), total_intersection_class[l] /
                (float(total_union_class[l]) + 1e-6))
        logger.info(per_class_point_str)

        if (epoch + 1) % 5 == 0:
            torch.save(
                model.state_dict(), '%s/%s_%.3d.pth' %
                (checkpoints_dir, args.model_name, epoch + 1))
            logger.info('Save model..')
            print('Save model..')
        if test_voxel_acc > best_acc:
            best_acc = test_voxel_acc
            best_acc_epoch = epoch + 1
            torch.save(
                model.state_dict(), '%s/%s_%.3d_%.4f_bestacc.pth' %
                (checkpoints_dir, args.model_name, epoch + 1, best_acc))
            logger.info('Save best acc model..')
            print('Save best acc model..')
        if test_avg_class_voxel_IoU > best_mIoU:
            best_mIoU = test_avg_class_voxel_IoU
            best_mIoU_epoch = epoch + 1
            torch.save(
                model.state_dict(), '%s/%s_%.3d_%.4f_bestmIoU.pth' %
                (checkpoints_dir, args.model_name, epoch + 1, best_mIoU))
            logger.info('Save best mIoU model..')
            print('Save best mIoU model..')
    print('Best voxel wise accuracy is %f at epoch %d.' %
          (best_acc, best_acc_epoch))
    logger.info('Best voxel wise accuracy is %f at epoch %d.' %
                (best_acc, best_acc_epoch))
    print('Best class averaged voxel wise IoU is %f at epoch %d.' %
          (best_mIoU, best_mIoU_epoch))
    logger.info('Best class averaged voxel wise IoU is %f at epoch %d.' %
                (best_mIoU, best_mIoU_epoch))
    plot_loss_curve(history['loss'], str(log_dir))
    plot_acc_curve(history['train_acc'], history['test_voxel_acc'],
                   str(log_dir))
    plot_acc_curve(history['train_acc'], history['test_avg_class_voxel_IoU'],
                   str(log_dir))
    print('FINISH.')
    logger.info('FINISH')
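
`plot_loss_curve` and `plot_acc_curve` come from the project's plotting helpers. A minimal matplotlib sketch consistent with the calls above; the output filenames and axis labels are assumptions.

import os
import matplotlib.pyplot as plt

def plot_loss_curve(losses, save_dir):
    # Hypothetical helper: per-iteration training loss.
    plt.plot(losses)
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.savefig(os.path.join(save_dir, 'loss_curve.png'))
    plt.clf()

def plot_acc_curve(train_acc, test_acc, save_dir):
    # Hypothetical helper: train vs. test metric per epoch.
    plt.plot(train_acc, label='train')
    plt.plot(test_acc, label='test')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy / IoU')
    plt.legend()
    plt.savefig(os.path.join(save_dir, 'acc_curve.png'))
    plt.clf()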