Example #1
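# Evaluate a pretrained part-segmentation checkpoint (PointNet or PointNet++)
# on the test split; requires args.pretrain to point at saved weights.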
def evaluate(args):
    cache = _load(root)
    norm = args.model_name == 'pointnet'
    test_ds = PartNormalDataset(root, cache, npoints=2048, split='test')
    testdataloader = DataLoader(test_ds,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=int(args.workers))
    log.info("The number of test data is:", len(test_ds))

    log.info('Building Model', args.model_name)
    num_classes = 16
    num_part = 50
    if args.model_name == 'pointnet2':
        model = PointNet2PartSegMsg_one_hot(num_part)
    else:
        model = PointNetDenseCls(cat_num=num_classes, part_num=num_part)

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model).cuda()
    log.debug('Using gpu:', args.gpu)

    if args.pretrain is None:
        log.err('No pretrain model')
        return

    log.debug('Loading pretrain model...')
    state_dict = torch.load(args.pretrain)
    model.load_state_dict(state_dict)

    log.info('Testing pretrain model...')

    test_metrics, test_hist_acc, cat_mean_iou = test_partseg(
        model.eval(),
        testdataloader,
        label_id_to_name,
        args.model_name,
        num_part,
    )

    log.info('test_hist_acc', len(test_hist_acc))
    log.info(cat_mean_iou)
    log.info('Test Accuracy', '%.5f' % test_metrics['accuracy'])
    log.info('Class avg mIOU:', '%.5f' % test_metrics['class_avg_iou'])
    log.info('Instance avg mIOU:', '%.5f' % test_metrics['inctance_avg_iou'])

Example #2
# Full training entry point: sets up experiment and log directories, trains
# with per-epoch evaluation, and saves a checkpoint whenever test accuracy
# improves.
def main(args):
    os.environ[
        "CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else '0,1,2,3'
    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    file_dir = experiment_dir.joinpath(
        datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(
        str(log_dir) + '/train_%s_partseg.txt' % args.model_name)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------TRAINING---------------------------------------------------'
    )
    logger.info('PARAMETER ...')
    logger.info(args)

    TRAIN_DATASET = PartNormalDataset(npoints=2048, split='trainval')
    dataloader = torch.utils.data.DataLoader(TRAIN_DATASET,
                                             batch_size=args.batchsize,
                                             shuffle=True,
                                             num_workers=int(args.workers))
    TEST_DATASET = PartNormalDataset(npoints=2048, split='test')
    testdataloader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=10,
                                                 shuffle=True,
                                                 num_workers=int(args.workers))
    print("The number of training data is:", len(TRAIN_DATASET))
    logger.info("The number of training data is:%d", len(TRAIN_DATASET))
    print("The number of test data is:", len(TEST_DATASET))
    logger.info("The number of test data is:%d", len(TEST_DATASET))
    num_classes = 16
    num_part = 50
    blue = lambda x: '\033[94m' + x + '\033[0m'
    if args.model_name == 'pointnet2':
        model = PointNet2PartSeg_msg_one_hot(num_part)
    else:
        model = PointNetDenseCls(cat_num=num_classes, part_num=num_part)

    if args.pretrain is not None:
        model.load_state_dict(torch.load(args.pretrain))
        print('load model %s' % args.pretrain)
        logger.info('load model %s' % args.pretrain)
    else:
        print('Training from scratch')
        logger.info('Training from scratch')
    pretrain = args.pretrain
    init_epoch = int(pretrain[-14:-11]) if args.pretrain is not None else 0
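    # NOTE: the slice assumes checkpoint names of the form
    # '<model>_<epoch:%.3d>_<acc:%.4f>.pth' (as saved below), where the epoch
    # digits occupy characters -14:-11; it breaks if the naming changes.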

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        raise ValueError('unsupported optimizer: %s' % args.optimizer)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.5)
    '''GPU selection and multi-GPU'''
    if args.multi_gpu is not None:
        device_ids = [int(x) for x in args.multi_gpu.split(',')]
        torch.backends.cudnn.benchmark = True
        model.cuda(device_ids[0])
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    else:
        model.cuda()
    criterion = PointNetLoss()
    LEARNING_RATE_CLIP = 1e-5

    history = defaultdict(list)
    best_acc = 0
    best_class_avg_iou = 0
    best_inctance_avg_iou = 0

    for epoch in range(init_epoch, args.epoch):
        scheduler.step()
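        # NOTE: since PyTorch 1.1, scheduler.step() should be called after
        # optimizer.step(); calling it at the top of the epoch follows the
        # pre-1.1 API and skips the first scheduled value on newer versions.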
        lr = max(optimizer.param_groups[0]['lr'], LEARNING_RATE_CLIP)
        print('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        for i, data in tqdm(enumerate(dataloader, 0),
                            total=len(dataloader),
                            smoothing=0.9):
            points, label, target, norm_plt = data
            # Variable() has been a no-op since PyTorch 0.4; tensors suffice.
            points, label, target = points.float(), label.long(), target.long()
            points = points.transpose(2, 1)
            norm_plt = norm_plt.transpose(2, 1)
            points, label, target, norm_plt = (points.cuda(),
                                               label.squeeze().cuda(),
                                               target.cuda(), norm_plt.cuda())
            optimizer.zero_grad()
            model = model.train()
            if args.model_name == 'pointnet':
                labels_pred, seg_pred, trans_feat = model(
                    points, to_categorical(label, 16))
                seg_pred = seg_pred.contiguous().view(-1, num_part)
                target = target.view(-1, 1)[:, 0]
                loss, seg_loss, label_loss = criterion(labels_pred, label,
                                                       seg_pred, target,
                                                       trans_feat)
            else:
                seg_pred = model(points, norm_plt, to_categorical(label, 16))
                seg_pred = seg_pred.contiguous().view(-1, num_part)
                target = target.view(-1, 1)[:, 0]
                loss = F.nll_loss(seg_pred, target)

            history['loss'].append(loss.detach().cpu().numpy())
            loss.backward()
            optimizer.step()

        forpointnet2 = args.model_name == 'pointnet2'
        test_metrics, test_hist_acc, cat_mean_iou = test_partseg(
            model, testdataloader, seg_label_to_cat, num_part, forpointnet2)

        print(
            'Epoch %d %s accuracy: %f  Class avg mIOU: %f   Instance avg mIOU: %f'
            %
            (epoch, blue('test'), test_metrics['accuracy'],
             test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))

        logger.info(
            'Epoch %d %s Accuracy: %f  Class avg mIOU: %f   Instance avg mIOU: %f'
            %
            (epoch, blue('test'), test_metrics['accuracy'],
             test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))
        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']
            torch.save(
                model.state_dict(), '%s/%s_%.3d_%.4f.pth' %
                (checkpoints_dir, args.model_name, epoch, best_acc))
            logger.info(cat_mean_iou)
            logger.info('Save model..')
            print('Save model..')
            print(cat_mean_iou)
        if test_metrics['class_avg_iou'] > best_class_avg_iou:
            best_class_avg_iou = test_metrics['class_avg_iou']
        if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou:
            best_inctance_avg_iou = test_metrics['inctance_avg_iou']
        print('Best accuracy is: %.5f' % best_acc)
        logger.info('Best accuracy is: %.5f' % best_acc)
        print('Best class avg mIOU is: %.5f' % best_class_avg_iou)
        logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
        print('Best instance avg mIOU is: %.5f' % best_inctance_avg_iou)
        logger.info('Best instance avg mIOU is: %.5f' % best_inctance_avg_iou)
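
For reference, main() assumes a parse_args() exposing the attributes it reads
(model_name, batchsize, workers, epoch, learning_rate, decay_rate, optimizer,
pretrain, gpu, multi_gpu). A hypothetical reconstruction follows; the flag
names are inferred from the code and the defaults are illustrative, not the
original parser's:

import argparse

def parse_args():
    # Hypothetical parser: each option mirrors an attribute main() reads.
    parser = argparse.ArgumentParser('PointNet/PointNet++ part segmentation')
    parser.add_argument('--model_name', type=str, default='pointnet',
                        help="'pointnet' or 'pointnet2'")
    parser.add_argument('--batchsize', type=int, default=32)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--decay_rate', type=float, default=1e-4)
    parser.add_argument('--optimizer', type=str, default='Adam',
                        help="'SGD' or 'Adam'")
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--multi_gpu', type=str, default=None,
                        help="e.g. '0,1,2,3'")
    return parser.parse_args()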
Example #3
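# Train a part-segmentation model (PointNet or PointNet++) from scratch or
# from a checkpoint, saving weights whenever instance-average mIoU improves.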
def train(args):
    experiment_dir = mkdir('./experiment/')
    checkpoints_dir = mkdir('./experiment/partseg/%s/' % (args.model_name))
    cache = _load(root)

    norm = args.model_name == 'pointnet'
    npoints = 2048
    train_ds = PartNormalDataset(root,
                                 cache,
                                 npoints=npoints,
                                 split='trainval',
                                 data_augmentation=args.augment)
    dataloader = DataLoader(train_ds,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=int(args.workers))

    test_ds = PartNormalDataset(root, cache, npoints=npoints, split='test')
    testdataloader = DataLoader(test_ds,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=int(args.workers))

    num_classes = 16
    num_part = 50
    log.info(len_training=len(train_ds), len_testing=len(test_ds))
    log.info(num_classes=num_classes, num_part=num_part)

    if args.model_name == 'pointnet':
        model = PointNetDenseCls(cat_num=num_classes, part_num=num_part)
    else:
        model = PointNet2PartSegMsg_one_hot(num_part)

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model).cuda()
    log.debug('Using gpu:', args.gpu)

    if args.pretrain is not None and args.pretrain != 'None':
        log.debug('Use pretrain model...')
        model.load_state_dict(torch.load(args.pretrain))
        init_epoch = int(args.pretrain[:-4].split('-')[-1])
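        # NOTE: recovers the epoch from checkpoint names of the form
        # 'partseg-<model>-<iou:%.5f>-<epoch:%04d>.pth' saved further below.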
        log.debug('start epoch from', init_epoch)
    else:
        log.debug('Training from scratch')
        init_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        raise ValueError('unsupported optimizer: %s' % args.optimizer)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.5)

    history = {'loss': []}
    best_acc = 0
    best_class_avg_iou = 0
    best_inctance_avg_iou = 0
    LEARNING_RATE_CLIP = 1e-5

    # criterion = PointNetLoss()
    def feature_transform_reguliarzer(trans):
        d = trans.size()[1]
        I = torch.eye(d)[None, :, :]
        if trans.is_cuda:
            I = I.cuda()
        # Penalize ||A @ A.T - I|| so the predicted feature transform stays
        # close to orthogonal (the PointNet feature-transform regularizer).
        loss = torch.mean(
            torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I,
                       dim=(1, 2)))
        return loss

    def PointNet_Loss(labels_pred, label, seg_pred, seg, trans_feat):
        mat_diff_loss_scale = 0.001
        weight = 1
        seg_loss = F.nll_loss(seg_pred, seg)
        mat_diff_loss = feature_transform_reguliarzer(trans_feat)
        label_loss = F.nll_loss(labels_pred, label)
        loss = weight * seg_loss + (
            1 - weight) * label_loss + mat_diff_loss * mat_diff_loss_scale
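        # NOTE: with weight = 1, the (1 - weight) * label_loss term vanishes;
        # the total is seg_loss plus the scaled transform regularizer.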
        return loss, seg_loss, label_loss

    for epoch in range(init_epoch, args.epoch):
        scheduler.step()
        lr = max(optimizer.param_groups[0]['lr'], LEARNING_RATE_CLIP)
        log.info(job='partseg',
                 model=args.model_name,
                 gpu=args.gpu,
                 epoch='%d/%s' % (epoch, args.epoch),
                 lr=lr)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        for i, data in tqdm(enumerate(dataloader, 0),
                            total=len(dataloader),
                            smoothing=0.9):
            points, label, target, norm_plt = data
            points, label, target = points.float(), label.long(), target.long()
            points = points.transpose(2, 1)
            norm_plt = norm_plt.transpose(2, 1)
            points, label, target, norm_plt = (points.cuda(),
                                               label.squeeze().cuda(),
                                               target.cuda(), norm_plt.cuda())
            optimizer.zero_grad()
            model = model.train()

            if args.model_name == 'pointnet':
                labels_pred, seg_pred, trans_feat = model(
                    points, to_categorical(label, 16))
                seg_pred = seg_pred.contiguous().view(-1, num_part)
                target = target.view(-1, 1)[:, 0]
                # loss, seg_loss, label_loss = criterion(labels_pred, label, seg_pred, target, trans_feat)
                loss, seg_loss, label_loss = PointNet_Loss(
                    labels_pred, label, seg_pred, target, trans_feat)
            else:
                seg_pred = model(points, norm_plt, to_categorical(label, 16))
                seg_pred = seg_pred.contiguous().view(-1, num_part)
                target = target.view(-1, 1)[:, 0]
                loss = F.nll_loss(seg_pred, target)

            history['loss'].append(loss.detach().cpu().numpy())
            loss.backward()
            optimizer.step()

        log.debug('clear cuda cache')
        torch.cuda.empty_cache()

        test_metrics, test_hist_acc, cat_mean_iou = test_partseg(
            model.eval(),
            testdataloader,
            label_id_to_name,
            args.model_name,
            num_part,
        )

        save_model = False
        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']

        if test_metrics['class_avg_iou'] > best_class_avg_iou:
            best_class_avg_iou = test_metrics['class_avg_iou']

        if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou:
            best_inctance_avg_iou = test_metrics['inctance_avg_iou']
            save_model = True

        if save_model:
            fn_pth = 'partseg-%s-%.5f-%04d.pth' % (
                args.model_name, best_inctance_avg_iou, epoch)
            log.info('Save model...', fn=fn_pth)
            torch.save(model.state_dict(),
                       os.path.join(checkpoints_dir, fn_pth))
            log.info(cat_mean_iou)
        else:
            log.info('No need to save model')

        log.warn('Curr',
                 accuracy=test_metrics['accuracy'],
                 class_avg_mIOU=test_metrics['class_avg_iou'],
                 instance_avg_mIOU=test_metrics['inctance_avg_iou'])

        log.warn('Best',
                 accuracy=best_acc,
                 class_avg_mIOU=best_class_avg_iou,
                 instance_avg_mIOU=best_inctance_avg_iou)
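
Examples #2 and #3 call a to_categorical helper that these snippets do not
define. A minimal sketch, consistent with the call sites (one-hot encoding a
1-D label tensor, as in to_categorical(label, 16)), might look like this; the
original helper may differ:

import torch

def to_categorical(y, num_classes):
    # Hypothetical reconstruction: indexing an identity matrix with the
    # integer label tensor yields one-hot rows on the same device as y.
    return torch.eye(num_classes, device=y.device)[y]

Here to_categorical(label, 16) returns a float tensor of shape
(batch_size, 16), which the models consume as the per-shape one-hot
class vector.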