Пример #1
0
def main():
    """Train RSCNN_SSN on ModelNet40 classification.

    Reads hyperparameters from the YAML file given by --config, merges the
    'common' section into the argparse namespace, builds train/test loaders,
    the model, optimizer and schedulers, optionally restores a checkpoint,
    then hands everything to the module-level train() loop.
    """
    args = parser.parse_args()
    with open(args.config) as f:
        # safe_load: plain yaml.load without an explicit Loader is deprecated
        # and can construct arbitrary Python objects from untrusted input.
        config = yaml.safe_load(f)
    print("\n**************************")
    # Config values are injected into `args` so downstream code treats them
    # exactly like command-line options.
    for k, v in config['common'].items():
        setattr(args, k, v)
        print('\n[%s]:'%(k), v)
    print("\n**************************\n")

    # exist_ok=True ignores only "directory already exists"; genuine failures
    # (e.g. permission errors) now propagate instead of being swallowed.
    os.makedirs(args.save_path, exist_ok=True)

    train_transforms = transforms.Compose([
        d_utils.PointcloudToTensor()
    ])
    test_transforms = transforms.Compose([
        d_utils.PointcloudToTensor()
    ])

    train_dataset = ModelNet40Cls(num_points = args.num_points, root = args.data_root, transforms=train_transforms)
    train_dataloader = DataLoader(
        train_dataset, 
        batch_size=args.batch_size,
        shuffle=True, 
        num_workers=int(args.workers)
    )

    test_dataset = ModelNet40Cls(num_points = args.num_points, root = args.data_root, transforms=test_transforms, train=False)
    test_dataloader = DataLoader(
        test_dataset, 
        batch_size=args.batch_size,
        shuffle=False, 
        num_workers=int(args.workers)
    )

    model = RSCNN_SSN(num_classes = args.num_classes, input_channels = args.input_channels, relation_prior = args.relation_prior, use_xyz = True)
    model.cuda()
    optimizer = optim.Adam(
        model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)

    # Step-decay LR, clipped below at lr_clip; BN momentum decays on the same
    # schedule, clipped at bnm_clip.
    lr_lbmd = lambda e: max(args.lr_decay**(e // args.decay_step), args.lr_clip / args.base_lr)
    bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay**(e // args.decay_step), args.bnm_clip)
    lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
    bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd)

    # `!=` compares values; `is not` compares identity and only worked by
    # accident of string interning (SyntaxWarning on Python >= 3.8).
    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    criterion = nn.CrossEntropyLoss()
    num_batch = len(train_dataset)/args.batch_size

    # training
    train(train_dataloader, test_dataloader, model, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch)
Пример #2
0
def main():
    """Train a selectable point-cloud classifier on ModelNet40.

    Loads hyperparameters from the YAML file named by --config, builds one
    training loader plus two test loaders (z-rotation and SO(3) evaluation
    sets), instantiates the model chosen by args.model, optionally restores a
    checkpoint, and delegates to the module-level train() loop.
    """
    args = parser.parse_args()
    with open(args.config) as f:
        # safe_load: plain yaml.load without an explicit Loader is deprecated
        # and can construct arbitrary Python objects from untrusted input.
        config = yaml.safe_load(f)
    print("\n**************************")
    for k, v in config['common'].items():
        setattr(args, k, v)
        print('\n[%s]:' % (k), v)
    print("\n**************************\n")

    # exist_ok=True ignores only "directory already exists"; genuine failures
    # (e.g. permission errors) now propagate instead of being swallowed.
    os.makedirs(args.save_path, exist_ok=True)

    train_dataset = ModelNet40Cls(num_points=args.num_points,
                                  root=args.data_root,
                                  transforms=None)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=int(args.workers),
                                  pin_memory=True)

    test_dataset_z = ModelNet40Cls(num_points=args.num_points,
                                   root=args.data_root,
                                   transforms=None,
                                   train=False)
    test_dataloader_z = DataLoader(test_dataset_z,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=int(args.workers),
                                   pin_memory=True)

    test_dataset_so3 = ModelNet40Cls(num_points=args.num_points,
                                     root=args.data_root,
                                     transforms=None,
                                     train=False)
    test_dataloader_so3 = DataLoader(test_dataset_so3,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=int(args.workers),
                                     pin_memory=True)
    # Model selection; the RS-CNN variants are wrapped in DataParallel.
    if args.model == "pointnet2_ssn":
        model = PointNet2_SSN(num_classes=args.num_classes)
        model.cuda()
    elif args.model == "rscnn_ssn":
        model = RSCNN_SSN(num_classes=args.num_classes)
        model.cuda()
        model = torch.nn.DataParallel(model)
    elif args.model == "rscnn_msn":
        model = RSCNN_MSN(num_classes=args.num_classes)
        model.cuda()
        model = torch.nn.DataParallel(model)
    else:
        print("Doesn't support this model")
        return

    optimizer = optim.Adam(model.parameters(),
                           lr=args.base_lr,
                           weight_decay=args.weight_decay)
    # Step-decay LR, clipped below at lr_clip; BN momentum decays on the same
    # schedule, clipped at bnm_clip.
    lr_lbmd = lambda e: max(args.lr_decay**(e // args.decay_step), args.lr_clip
                            / args.base_lr)
    bnm_lmbd = lambda e: max(
        args.bn_momentum * args.bn_decay**
        (e // args.decay_step), args.bnm_clip)
    lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
    bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd)

    # `!=` compares values; `is not` compares identity and only worked by
    # accident of string interning (SyntaxWarning on Python >= 3.8).
    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % args.checkpoint)

    criterion = nn.CrossEntropyLoss()
    num_batch = len(train_dataset) / args.batch_size

    # training
    train(train_dataloader, test_dataloader_z, test_dataloader_so3, model,
          criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch)
Пример #3
0
def main():
    """Evaluate a trained RSCNN_SSN on ModelNet40 with multi-vote testing.

    For each of NUM_REPEAT passes over the test set, every cloud is
    FPS-subsampled NUM_VOTE times (with random scaling after the first vote),
    softmax predictions are averaged across votes, and the best accuracy over
    all repeats is reported.
    """
    args = parser.parse_args()
    with open(args.config) as f:
        # safe_load: plain yaml.load without an explicit Loader is deprecated
        # and can construct arbitrary Python objects from untrusted input.
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)

    test_transforms = transforms.Compose([d_utils.PointcloudToTensor()])

    test_dataset = ModelNet40Cls(num_points=args.num_points,
                                 root=args.data_root,
                                 transforms=test_transforms,
                                 train=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=int(args.workers),
                                 pin_memory=True)

    model = RSCNN_SSN(num_classes=args.num_classes,
                      input_channels=args.input_channels,
                      relation_prior=args.relation_prior,
                      use_xyz=True)
    model.cuda()

    # `!=` compares values; `is not` compares identity and only worked by
    # accident of string interning (SyntaxWarning on Python >= 3.8).
    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    # evaluate
    PointcloudScale = d_utils.PointcloudScale()  # initialize random scaling
    model.eval()
    global_acc = 0
    for i in range(NUM_REPEAT):
        preds = []
        labels = []
        with torch.no_grad():
            for j, data in enumerate(test_dataloader, 0):
                points, target = data
                points, target = points.cuda(), target.cuda()

                # farthest point sampling down to a 1200-point pool (B, npoint)
                fps_idx = pointnet2_utils.furthest_point_sample(
                    points, 1200)
                pred = 0
                for v in range(NUM_VOTE):
                    # Each vote draws args.num_points indices (no replacement)
                    # from the FPS pool, so votes see different subsets.
                    new_fps_idx = fps_idx[:,
                                          np.random.
                                          choice(1200, args.num_points, False)]
                    new_points = pointnet2_utils.gather_operation(
                        points.transpose(1, 2).contiguous(),
                        new_fps_idx).transpose(1, 2).contiguous()
                    if v > 0:
                        # augment all votes after the first with random scaling
                        new_points.data = PointcloudScale(new_points.data)
                    pred += F.softmax(model(new_points), dim=1)
                pred /= NUM_VOTE
                target = target.view(-1)
                _, pred_choice = torch.max(pred.data, -1)

                preds.append(pred_choice)
                labels.append(target.data)

            preds = torch.cat(preds, 0)
            labels = torch.cat(labels, 0)
            acc = (preds == labels).sum().item() / labels.numel()
            if acc > global_acc:
                global_acc = acc
            print('Repeat %3d \t Acc: %0.6f' % (i + 1, acc))
    print('\nBest voting acc: %0.6f' % (global_acc))
Пример #4
0
def main():
    """Train RSCNN_SSN on ModelNet40 with file-based logging.

    Reads hyperparameters from the YAML file named by --config, creates a
    timestamped output directory under args.save_path, sets up the global
    logger, builds loaders/model/optimizer/schedulers, optionally restores a
    checkpoint, then delegates to the module-level train() loop.
    """
    global logger

    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)
    # Merge every key under 'common' into the argparse namespace so config
    # values behave exactly like command-line options.
    for k, v in config['common'].items():
        setattr(args, k, v)

    output_dir = args.save_path
    if output_dir:
        import time
        msg = 'init_train'
        # Timestamped run directory, e.g. train_06_30_12_00_00_init_train
        output_dir = os.path.join(
            output_dir, "train_{}_{}".format(time.strftime("%m_%d_%H_%M_%S"),
                                             msg))
        os.makedirs(output_dir)

    logger = get_logger("RS-CNN", output_dir, prefix="train")
    logger.info("Running with config:\n{}".format(args))

    train_transforms = transforms.Compose([d_utils.PointcloudToTensor()])
    test_transforms = transforms.Compose([d_utils.PointcloudToTensor()])

    train_dataset = ModelNet40Cls(num_points=args.num_points,
                                  root=args.data_root,
                                  transforms=train_transforms)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=int(args.workers),
                                  pin_memory=True)

    test_dataset = ModelNet40Cls(num_points=args.num_points,
                                 root=args.data_root,
                                 transforms=test_transforms,
                                 train=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=int(args.workers),
                                 pin_memory=True)

    model = RSCNN_SSN(num_classes=args.num_classes,
                      input_channels=args.input_channels,
                      relation_prior=args.relation_prior,
                      use_xyz=True)
    model.cuda()
    optimizer = optim.Adam(model.parameters(),
                           lr=args.base_lr,
                           weight_decay=args.weight_decay)

    # Step-decay LR, clipped below at lr_clip; BN momentum decays on the same
    # schedule, clipped at bnm_clip.
    lr_lbmd = lambda e: max(args.lr_decay**(e // args.decay_step), args.lr_clip
                            / args.base_lr)
    bnm_lmbd = lambda e: max(
        args.bn_momentum * args.bn_decay**
        (e // args.decay_step), args.bnm_clip)
    lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
    bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd)

    # `!=` compares values; `is not` compares identity and only worked by
    # accident of string interning (SyntaxWarning on Python >= 3.8).
    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        logger.info('Load model successfully: %s' % (args.checkpoint))

    criterion = nn.CrossEntropyLoss()
    num_batch = len(train_dataset) / args.batch_size

    # training
    train(train_dataloader, test_dataloader, model, criterion, optimizer,
          lr_scheduler, bnm_scheduler, args, num_batch)
Пример #5
0
def main():
    """Closed-set face identification on the Bosphorus dataset.

    Uses a trained RSCNN_SSN as a frozen feature extractor (its FC head is
    replaced by an identity-initialized linear map), embeds the gallery and
    every probe batch, and matches each probe to the gallery entry with the
    highest cosine similarity. Reports rank-1 accuracy.
    """
    args = parser.parse_args()
    with open(args.config) as f:
        # safe_load: plain yaml.load without an explicit Loader is deprecated
        # and can construct arbitrary Python objects from untrusted input.
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)

    test_transforms = transforms.Compose([d_utils.PointcloudToTensor()])

    test_dataset = Bosphorus_eval(num_points=args.num_points,
                                  root=args.data_root,
                                  transforms=test_transforms)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=int(args.workers))

    model = RSCNN_SSN(num_classes=args.num_classes,
                      input_channels=args.input_channels,
                      relation_prior=args.relation_prior,
                      use_xyz=True)
    model.cuda()

    # `!=` compares values; `is not` compares identity and only worked by
    # accident of string interning (SyntaxWarning on Python >= 3.8).
    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    # model is used for feature extraction, so no need FC layers:
    # replace the classifier head with an identity-initialized linear map
    # and freeze all parameters.
    model.FC_layer = nn.Linear(1024, 1024, bias=False).cuda()
    for para in model.parameters():
        para.requires_grad = False
    nn.init.eye_(model.FC_layer.weight)

    # evaluate
    model.eval()
    global_acc = 0
    with torch.no_grad():
        Total_samples = 0
        Correct = 0
        # Embed the full gallery once; L2-normalize so the dot product below
        # is cosine similarity.
        gallery_points, gallery_labels = test_dataset.get_gallery()
        gallery_points, gallery_labels = gallery_points.cuda(
        ), gallery_labels.cuda()
        gallery_points = Variable(gallery_points)
        gallery_pred = model(gallery_points)
        print(gallery_pred.size())
        gallery_pred = F.normalize(gallery_pred)

        for j, data in enumerate(test_dataloader, 0):
            probe_points, probe_labels = data
            probe_points, probe_labels = probe_points.cuda(
            ), probe_labels.cuda()
            probe_points = Variable(probe_points)

            # get feature vector for the probe batch from the model
            probe_pred = model(probe_points)
            probe_pred = F.normalize(probe_pred)

            # Broadcast to (probe_num, gallery_num, C) so the elementwise
            # product summed over C gives all pairwise cosine similarities.
            probe_tmp = probe_pred.unsqueeze(1).expand(probe_pred.shape[0],
                                                       gallery_pred.shape[0],
                                                       probe_pred.shape[1])
            gallery_tmp = gallery_pred.unsqueeze(0).expand(
                probe_pred.shape[0], gallery_pred.shape[0],
                gallery_pred.shape[1])
            results = torch.sum(torch.mul(probe_tmp, gallery_tmp),
                                dim=2)  # cosine distance
            results = torch.argmax(results, dim=1)

            Total_samples += probe_points.shape[0]
            for i in np.arange(0, results.shape[0]):
                if gallery_labels[results[i]] == probe_labels[i]:
                    Correct += 1
        print('Total_samples:{}'.format(Total_samples))
        acc = float(Correct / Total_samples)
        if acc > global_acc:
            global_acc = acc
        # NOTE(review): `i` here is the leftover index from the inner matching
        # loop, not a repeat counter — it raises NameError on an empty
        # dataloader and the "Repeat" label is misleading; confirm intent.
        print('Repeat %3d \t Acc: %0.6f' % (i + 1, acc))
    print('\nBest voting acc: %0.6f' % (global_acc))