Example #1
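# Assumed context (not shown in this snippet): module-level imports such as
# os, torch, and torch.nn as nn, plus the project helpers get_parser,
# get_logger, and test defined elsewhere in the surrounding file.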
def main():
    global args, logger
    args = get_parser()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointweb_seg':
        from model.pointweb.pointweb_seg import PointWebSeg as Model
    elif args.arch == 'pointweb_cls':
        from model.pointweb.pointweb_cls import PointWebCls as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes, use_xyz=args.use_xyz)
    model = torch.nn.DataParallel(model.cuda())
    logger.info(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    with open(args.names_path) as f:
        names = [line.rstrip('\n') for line in f]
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    test(model, criterion, names)
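A detail worth noting in Example #1: the model is wrapped in torch.nn.DataParallel before load_state_dict is called, so the checkpoint keys are expected to carry the 'module.' prefix that DataParallel adds, and strict=False silently skips any keys that do not match. If a checkpoint was instead saved from an unwrapped model, a minimal sketch like the following (the helper name is hypothetical, not from this project) re-keys the state dict before loading:

def strip_module_prefix(state_dict):
    # Hypothetical helper: drop the 'module.' prefix that
    # torch.nn.DataParallel prepends to every parameter key.
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# Usage sketch, loading into a bare (unwrapped) model:
# checkpoint = torch.load(args.model_path)
# model.load_state_dict(strip_module_prefix(checkpoint['state_dict']))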
Example #2
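# Assumed context (not shown): module-level imports such as os, torch,
# torch.nn as nn, torch.optim.lr_scheduler as lr_scheduler, a tensorboard
# SummaryWriter named writer, and project modules (init, dataset, transform,
# S3DIS, ScanNet, train, validate) from the surrounding file.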
def main():
    init()
    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointweb_seg':
        from model.pointweb.pointweb_seg import PointWebSeg as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes, use_xyz=args.use_xyz)
    if args.sync_bn:
        from util.util import convert_to_syncbn
        from lib.sync_bn import patch_replication_callback
        convert_to_syncbn(model)
        # Note: Example #3 applies this patch after the DataParallel wrap.
        patch_replication_callback(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.step_epoch,
                                    gamma=args.multiplier)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    model = torch.nn.DataParallel(model.cuda())
    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    train_transform = transform.Compose([transform.ToTensor()])
    if args.data_name == 's3dis':
        train_data = S3DIS(split='train',
                           data_root=args.train_full_folder,
                           num_point=args.num_point,
                           test_area=args.test_area,
                           block_size=args.block_size,
                           sample_rate=args.sample_rate,
                           transform=train_transform)
    elif args.data_name == 'scannet':
        train_data = ScanNet(split='train',
                             data_root=args.data_root,
                             num_point=args.num_point,
                             block_size=args.block_size,
                             sample_rate=args.sample_rate,
                             transform=train_transform)
    elif args.data_name == 'modelnet40':
        train_data = dataset.PointData(split='train',
                                       data_root=args.data_root,
                                       data_list=args.train_list,
                                       transform=train_transform,
                                       num_point=args.num_point,
                                       random_index=True)
    else:
        # Without this branch, train_data would be undefined for an
        # unrecognized dataset name and fail later at the DataLoader.
        raise ValueError('{} dataset not supported.'.format(args.data_name))
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.train_workers,
        pin_memory=True)

    val_loader = None
    if args.evaluate:
        val_transform = transform.Compose([transform.ToTensor()])
        val_data = dataset.PointData(split='val',
                                     data_root=args.data_root,
                                     data_list=args.val_list,
                                     transform=val_transform)
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.train_batch_size_val,
            shuffle=False,
            num_workers=args.train_workers,
            pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, criterion, optimizer, epoch)
        epoch_log = epoch + 1
        writer.add_scalar('loss_train', loss_train, epoch_log)
        writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
        writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
        writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if epoch_log % args.save_freq == 0:
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, filename)
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(
                    epoch_log - args.save_freq * 2) + '.pth'
                os.remove(deletename)
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            writer.add_scalar('loss_val', loss_val, epoch_log)
            writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
            writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
            writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
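A note on the epoch loop above: it calls scheduler.step() at the top of each epoch, a pattern from older PyTorch. Since PyTorch 1.1 the documented order is to run the epoch's optimizer.step() calls first and invoke scheduler.step() once at the end of the epoch, as Example #3 below does; stepping first effectively skips the initial learning rate. A minimal sketch of the recommended ordering, where train_one_epoch is a hypothetical stand-in for the train call above:

for epoch in range(args.start_epoch, args.epochs):
    # optimizer.step() runs inside the epoch's training loop
    train_one_epoch(train_loader, model, criterion, optimizer)
    # advance the LR schedule only after the epoch's optimizer steps
    scheduler.step()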
Example #3
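# Assumed context (not shown): the same module-level imports as Example #2,
# plus get_git_commit_id and a dict-style args config supporting args.get(...).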
def main():
    init()
    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointnet2_paconv_seg':
        from model.pointnet2.pointnet2_paconv_seg import PointNet2SSGSeg as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim,
                  k=args.classes,
                  use_xyz=args.use_xyz,
                  args=args)

    best_mIoU = 0.0

    if args.sync_bn:
        from util.util import convert_to_syncbn
        convert_to_syncbn(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.get('lr_multidecay', False):
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(args.epochs * 0.6),
                        int(args.epochs * 0.8)],
            gamma=args.multiplier)
    else:
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=args.step_epoch,
                                        gamma=args.multiplier)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    model = torch.nn.DataParallel(model.cuda())
    if args.sync_bn:
        from lib.sync_bn import patch_replication_callback
        patch_replication_callback(model)
    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            try:
                best_mIoU = checkpoint['val_mIoU']
            except Exception:
                pass
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    if args.get('no_transformation', True):
        train_transform = None
    else:
        train_transform = transform.Compose([
            transform.RandomRotate(along_z=args.get('rotate_along_z', True)),
            transform.RandomScale(scale_low=args.get('scale_low', 0.8),
                                  scale_high=args.get('scale_high', 1.2)),
            transform.RandomJitter(sigma=args.get('jitter_sigma', 0.01),
                                   clip=args.get('jitter_clip', 0.05)),
            transform.RandomDropColor(
                color_augment=args.get('color_augment', 0.0))
        ])
    logger.info(train_transform)
    if args.data_name == 's3dis':
        train_data = S3DIS(split='train',
                           data_root=args.train_full_folder,
                           num_point=args.num_point,
                           test_area=args.test_area,
                           block_size=args.block_size,
                           sample_rate=args.sample_rate,
                           transform=train_transform,
                           fea_dim=args.get('fea_dim', 6),
                           shuffle_idx=args.get('shuffle_idx', False))
    else:
        raise ValueError('{} dataset not supported.'.format(args.data_name))
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.train_workers,
        pin_memory=True,
        drop_last=True)

    val_loader = None
    if args.evaluate:
        val_transform = transform.Compose([transform.ToTensor()])
        if args.data_name == 's3dis':
            val_data = dataset.PointData(split='val',
                                         data_root=args.data_root,
                                         data_list=args.val_list,
                                         transform=val_transform,
                                         norm_as_feat=args.get(
                                             'norm_as_feat', True),
                                         fea_dim=args.get('fea_dim', 6))
        else:
            raise ValueError('{} dataset not supported.'.format(
                args.data_name))

        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.train_batch_size_val,
            shuffle=False,
            num_workers=args.train_workers,
            pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, criterion, optimizer, epoch,
            args.get('correlation_loss', False))
        epoch_log = epoch + 1
        writer.add_scalar('loss_train', loss_train, epoch_log)
        writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
        writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
        writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if epoch_log % args.save_freq == 0:
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'commit_id': get_git_commit_id()
                }, filename)
            if epoch_log / args.save_freq > 2:
                try:
                    deletename = args.save_path + '/train_epoch_' + str(
                        epoch_log - args.save_freq * 2) + '.pth'
                    os.remove(deletename)
                except Exception:
                    logger.info('{} Not found.'.format(deletename))

        if args.evaluate and epoch_log % args.get('eval_freq', 1) == 0:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            writer.add_scalar('loss_val', loss_val, epoch_log)
            writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
            writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
            writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
            if mIoU_val > best_mIoU:
                best_mIoU = mIoU_val
                filename = args.save_path + '/best_train.pth'
                logger.info('Best Model Saving checkpoint to: ' + filename)
                torch.save(
                    {
                        'epoch': epoch_log,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'val_mIoU': best_mIoU,
                        'commit_id': get_git_commit_id()
                    }, filename)
        scheduler.step()
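Example #3 reads optional settings through args.get('key', default), so args behaves like a dictionary rather than a plain argparse.Namespace; this is typical of YAML-backed config objects. The class below is a minimal stand-in illustrating that interface (an assumption for illustration, not the project's actual config type):

class Config(dict):
    # Dict-backed config: cfg.classes and cfg.get('classes') both work.
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

args = Config(classes=13, fea_dim=6, use_xyz=True)
args.get('lr_multidecay', False)  # -> False when the key is absent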