Code example #1
def train(args):
    """Train a semantic-segmentation model (PointNet or PointNet++) on S3DIS.

    Args:
        args: parsed CLI namespace. Fields read here: model_name, augment,
            batch_size, workers, gpu, pretrain, optimizer, learning_rate,
            decay_rate, epoch.

    Side effects: creates ./experiment/semseg/<model_name>/ and saves a
    checkpoint there whenever the test mean-IoU improves.
    """
    experiment_dir = mkdir('./experiment/')
    checkpoints_dir = mkdir('./experiment/semseg/%s/'%(args.model_name))
    train_data, train_label, test_data, test_label = _load()

    dataset = S3DISDataLoader(train_data, train_label, data_augmentation = args.augment)
    dataloader = DataLoader(dataset, batch_size=args.batch_size,shuffle=True, num_workers=args.workers)

    test_dataset = S3DISDataLoader(test_data, test_label)
    testdataloader = DataLoader(test_dataset, batch_size=args.batch_size,shuffle=True, num_workers=args.workers)

    num_classes = 13  # S3DIS semantic classes
    if args.model_name == 'pointnet':
        model = PointNetSeg(num_classes, feature_transform=True, input_dims = 9)
    else:
        model = PointNet2SemSeg(num_classes, feature_dims = 6)

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model).cuda()
    log.debug('Using gpu:',args.gpu)

    if args.pretrain is not None:
        log.debug('Use pretrain model...')
        model.load_state_dict(torch.load(args.pretrain))
        # Checkpoint filenames end in "-<epoch>.pth"; resume from that epoch.
        init_epoch = int(args.pretrain[:-4].split('-')[-1])
        log.debug('start epoch from', init_epoch)
    else:
        log.debug('Training from scratch')
        init_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    LEARNING_RATE_CLIP = 1e-5  # floor below which the lr is not allowed to decay

    history = {'loss':[]}
    best_acc = 0
    best_meaniou = 0

    for epoch in range(init_epoch,args.epoch):
        # Clamp the scheduled lr to the floor before this epoch's updates.
        lr = max(optimizer.param_groups[0]['lr'],LEARNING_RATE_CLIP)

        log.info(job='semseg',model=args.model_name,gpu=args.gpu,epoch='%d/%s' % (epoch, args.epoch),lr=lr)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Set training mode once per epoch (was redundantly done per batch).
        model.train()
        for points, target in tqdm(dataloader, total=len(dataloader), smoothing=0.9, dynamic_ncols=True):
            points, target = points.float(), target.long()
            points = points.transpose(2, 1)  # (B, N, C) -> (B, C, N) as the models expect
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            if args.model_name == 'pointnet':
                pred, trans_feat = model(points)
            else:
                pred = model(points)

            # Flatten (B, N, num_classes) predictions and (B, N) labels for NLL loss.
            pred = pred.contiguous().view(-1, num_classes)
            target = target.view(-1, 1)[:, 0]
            loss = F.nll_loss(pred, target)

            # PointNet regularizer keeps the learned feature transform near-orthogonal.
            if args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001

            history['loss'].append(loss.cpu().data.numpy())
            loss.backward()
            optimizer.step()

        # PyTorch >= 1.1: step the scheduler AFTER the epoch's optimizer updates;
        # stepping it first (as before) skips the first lr value and triggers a warning.
        scheduler.step()

        log.debug('clear cuda cache')
        torch.cuda.empty_cache()

        test_metrics, cat_mean_iou = test_semseg(
            model.eval(),
            testdataloader,
            label_id_to_name,
            args.model_name,
            num_classes,
        )
        mean_iou = np.mean(cat_mean_iou)

        # Checkpoint only on mean-IoU improvement; accuracy is tracked for reporting.
        save_model = False
        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']

        if mean_iou > best_meaniou:
            best_meaniou = mean_iou
            save_model = True

        if save_model:
            fn_pth = 'semseg-%s-%.5f-%04d.pth' % (args.model_name, best_meaniou, epoch)
            log.info('Save model...',fn = fn_pth)
            torch.save(model.state_dict(), os.path.join(checkpoints_dir, fn_pth))
            log.warn(cat_mean_iou)
        else:
            log.info('No need to save model')
            log.warn(cat_mean_iou)

        log.warn('Curr',accuracy=test_metrics['accuracy'], meanIOU=mean_iou)
        log.warn('Best',accuracy=best_acc, meanIOU=best_meaniou)
Code example #2
File: pcdseg.py  Project: jtpils/PointNet12
def train(args):
    """Train a semantic-segmentation model (PointNet or PointNet++) on SemanticKITTI.

    Args:
        args: parsed CLI namespace. Fields read here: model_name, subset
            ('inview' or 'all'), batch_size, workers, gpu, pretrain,
            optimizer, learning_rate, epoch.

    Raises:
        ValueError: if args.subset is not 'inview' or 'all'.

    Side effects: creates experiment/<model_name>/ and saves a checkpoint
    there whenever the test mIoU improves.
    """
    experiment_dir = mkdir('experiment/')
    checkpoints_dir = mkdir('experiment/%s/' % (args.model_name))

    kitti_utils = Semantic_KITTI_Utils(KITTI_ROOT, subset=args.subset)
    class_names = kitti_utils.class_names
    num_classes = kitti_utils.num_classes

    # Points sampled per frame differ between the camera-view subset and full scans.
    if args.subset == 'inview':
        train_npts = 8000
        test_npts = 24000
    elif args.subset == 'all':
        train_npts = 50000
        test_npts = 100000
    else:
        # Fail fast: an unrecognized subset previously surfaced only as a
        # NameError on train_npts further down.
        raise ValueError(
            "Unknown subset %r (expected 'inview' or 'all')" % args.subset)

    log.info(subset=args.subset, train_npts=train_npts, test_npts=test_npts)

    dataset = SemKITTI_Loader(KITTI_ROOT,
                              train_npts,
                              train=True,
                              subset=args.subset)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.workers,
                            pin_memory=True)

    test_dataset = SemKITTI_Loader(KITTI_ROOT,
                                   test_npts,
                                   train=False,
                                   subset=args.subset)
    # Half the train batch size: test frames carry more points per sample.
    testdataloader = DataLoader(test_dataset,
                                batch_size=int(args.batch_size / 2),
                                shuffle=False,
                                num_workers=args.workers,
                                pin_memory=True)

    if args.model_name == 'pointnet':
        model = PointNetSeg(num_classes, input_dims=4, feature_transform=True)
    else:
        model = PointNet2SemSeg(num_classes, feature_dims=1)

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-4)

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model)
    model.cuda()
    log.info('Using gpu:', args.gpu)

    if args.pretrain is not None:
        log.info('Use pretrain model...')
        model.load_state_dict(torch.load(args.pretrain))
        # Checkpoint filenames end in "-<epoch>.pth"; resume from that epoch.
        init_epoch = int(args.pretrain[:-4].split('-')[-1])
        log.info('Restart training', epoch=init_epoch)
    else:
        log.msg('Training from scratch')
        init_epoch = 0

    # Hoisted out of the batch loop: the loss module is stateless here and
    # does not need to be re-constructed every iteration.
    criterion = nn.CrossEntropyLoss()

    best_acc = 0
    best_miou = 0

    for epoch in range(init_epoch, args.epoch):
        model.train()
        lr = calc_decay(args.learning_rate, epoch)
        log.info(subset=args.subset,
                 model=args.model_name,
                 gpu=args.gpu,
                 epoch=epoch,
                 lr=lr)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        for points, target in tqdm(dataloader,
                                   total=len(dataloader),
                                   smoothing=0.9,
                                   dynamic_ncols=True):
            # (B, N, C) -> (B, C, N) as the models expect.
            points = points.float().transpose(2, 1).cuda()
            target = target.long().cuda()

            if args.model_name == 'pointnet':
                logits, trans_feat = model(points)
            else:
                logits = model(points)

            # CrossEntropyLoss expects (B, num_classes, N) vs (B, N) targets.
            logits = logits.transpose(2, 1)
            loss = criterion(logits, target)

            # PointNet regularizer keeps the learned feature transform near-orthogonal.
            if args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        torch.cuda.empty_cache()

        acc, miou = test_kitti_semseg(model.eval(), testdataloader,
                                      args.model_name, num_classes,
                                      class_names)

        # Checkpoint only on mIoU improvement; accuracy is tracked for reporting.
        save_model = False
        if acc > best_acc:
            best_acc = acc

        if miou > best_miou:
            best_miou = miou
            save_model = True

        if save_model:
            fn_pth = '%s-%s-%.5f-%04d.pth' % (args.model_name, args.subset,
                                              best_miou, epoch)
            log.info('Save model...', fn=fn_pth)
            torch.save(model.state_dict(),
                       os.path.join(checkpoints_dir, fn_pth))
        else:
            log.info('No need to save model')

        log.warn('Curr', accuracy=acc, mIOU=miou)
        log.warn('Best', accuracy=best_acc, mIOU=best_miou)