Example #1
File: train.py Project: zeta1999/PointGLR
def validate(train_dataloader, test_dataloader, encoder, args):
    global svm_best_acc40
    encoder.eval()

    test_features = []
    test_label = []

    train_features = []
    train_label = []

    PointcloudRotate = d_utils.PointcloudRotate()  # instantiated but never used in this function

    # feature extraction
    with torch.no_grad():
        for j, data in enumerate(train_dataloader, 0):
            points, target = data
            points, target = points.cuda(), target.cuda()

            # NOTE: the train split is resampled to a fixed 1024 points here,
            # while the test loop below uses args.num_points.
            num_points = 1024

            fps_idx = pointnet2_utils.furthest_point_sample(points, num_points)  # (B, npoint)
            points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()

            feature = encoder(points, get_feature=True)
            target = target.view(-1)

            train_features.append(feature.data)
            train_label.append(target.data)

        for j, data in enumerate(test_dataloader, 0):
            points, target = data
            points, target = points.cuda(), target.cuda()

            fps_idx = pointnet2_utils.furthest_point_sample(points, args.num_points)  # (B, npoint)
            points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()

            feature = encoder(points, get_feature=True)
            target = target.view(-1)
            test_label.append(target.data)
            test_features.append(feature.data)

        train_features = torch.cat(train_features, dim=0)
        train_label = torch.cat(train_label, dim=0)
        test_features = torch.cat(test_features, dim=0)
        test_label = torch.cat(test_label, dim=0)

    # train svm
    svm_acc = evaluate_svm(train_features.data.cpu().numpy(),
                           train_label.data.cpu().numpy(),
                           test_features.data.cpu().numpy(),
                           test_label.data.cpu().numpy())

    if svm_acc > svm_best_acc40:
        svm_best_acc40 = svm_acc

    encoder.train()
    print('ModelNet 40 results: svm acc=', svm_acc, 'best svm acc=', svm_best_acc40)
    print(args.name, args.arch)

    return svm_acc
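
`evaluate_svm` is called above but not defined in this snippet. A minimal sketch of the kind of helper implied here, assuming scikit-learn's `LinearSVC` (the classifier and hyperparameters in the actual project may differ):

import numpy as np
from sklearn.svm import LinearSVC

def evaluate_svm(train_features, train_labels, test_features, test_labels):
    # Fit a linear SVM on frozen encoder features, report test accuracy.
    clf = LinearSVC()
    clf.fit(train_features, train_labels)
    pred = clf.predict(test_features)
    return np.mean(pred == test_labels)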
Example #2
import argparse
import os

from torch.utils.data import DataLoader
from torchvision import transforms

import data.data_utils as d_utils
from data.ModelNet40Loader import ModelNet40Cls

# NOTE: the snippet begins mid-file; the imports above and the argparse stub
# below are reconstructed from usage and may differ from the original source.


def parse_args():
    parser = argparse.ArgumentParser()
    # ... (argument definitions truncated in the source) ...
    return parser.parse_args()


lr_clip = 1e-5
bnm_clip = 1e-2

if __name__ == "__main__":
    args = parse_args()

    BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')

    # Renamed from `transforms` to avoid shadowing the torchvision module.
    point_transforms = transforms.Compose([
        d_utils.PointcloudToTensor(),
        d_utils.PointcloudScale(),
        d_utils.PointcloudRotate(),
        d_utils.PointcloudRotatePerturbation(),
        d_utils.PointcloudTranslate(),
        d_utils.PointcloudJitter(),
        d_utils.PointcloudRandomInputDropout()
    ])

    test_set = ModelNet40Cls(args.num_points,
                             BASE_DIR,
                             transforms=point_transforms,
                             train=False)
    test_loader = DataLoader(test_set,
                             batch_size=args.batch_size,
                             shuffle=True,  # note: the source shuffles the test split; shuffle=False is more typical
                             num_workers=2,
                             pin_memory=True)
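
A quick sanity check (an addition, not part of the original script) to confirm the loader output, assuming the dataset yields (points, label) pairs as in Example #1:

    # Pull one batch and inspect shapes; expect points of shape
    # (batch_size, num_points, 3) and one integer class label per shape.
    points, target = next(iter(test_loader))
    print(points.shape, target.shape)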
Example #3
def train(ss_dataloader, train_dataloader, test_dataloader, encoder, decoder,
          optimizer, lr_scheduler, bnm_scheduler, args, num_batch,
          begin_epoch):
    # initialize augmentations
    PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate()
    PointcloudRotate = d_utils.PointcloudRotate()
    metric_criterion = MetricLoss()
    chamfer_criterion = ChamferLoss()
    global svm_best_acc40
    batch_count = 0
    encoder.train()
    decoder.train()

    for epoch in range(begin_epoch, args.epochs):
        np.random.seed()
        for i, data in enumerate(ss_dataloader, 0):
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)
            points = data
            points = points.cuda()  # torch.autograd.Variable is deprecated; plain tensors suffice

            # data augmentation
            sampled_points = 1200
            has_normal = (points.size(2) > 3)

            if has_normal:
                normals = points[:, :, 3:6].contiguous()
            points = points[:, :, 0:3].contiguous()

            fps_idx = pointnet2_utils.furthest_point_sample(points, sampled_points)  # (B, npoint)
            # Oversample with FPS, then randomly keep args.num_points of the
            # sampled_points indices as a data-augmentation step.
            fps_idx = fps_idx[:, np.random.choice(sampled_points, args.num_points, False)]
            points_gt = pointnet2_utils.gather_operation(
                points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)
            if has_normal:
                normals = pointnet2_utils.gather_operation(
                    normals.transpose(1, 2).contiguous(), fps_idx)
            points = PointcloudScaleAndTranslate(points_gt.data)

            # optimize
            optimizer.zero_grad()

            features1, fuse_global, normals_pred = encoder(points)
            global_feature1 = features1[2].squeeze(2)
            refs1 = features1[0:2]
            recon1 = decoder(fuse_global).transpose(1, 2)  # bs, np, 3

            loss_metric = metric_criterion(global_feature1, refs1)
            loss_recon = chamfer_criterion(recon1, points_gt)
            if has_normal:
                loss_normals = NormalLoss(normals_pred, normals)
            else:
                loss_normals = normals_pred.new(1).fill_(0)
            loss = loss_recon + loss_metric + loss_normals
            loss.backward()
            optimizer.step()
            if i % args.print_freq_iter == 0:
                print('[epoch %3d: %3d/%3d] \t metric/chamfer/normal loss: %0.6f/%0.6f/%0.6f \t lr: %0.5f'
                      % (epoch + 1, i, num_batch, loss_metric.item(),
                         loss_recon.item(), loss_normals.item(),
                         lr_scheduler.get_lr()[0]))
            batch_count += 1

            # validation
            if args.evaluate and batch_count % int(args.val_freq_epoch * num_batch) == 0:
                svm_acc40 = validate(train_dataloader, test_dataloader,
                                     encoder, args)

                save_dict = {
                    # After training one epoch, the resume start_epoch is epoch + 1.
                    'epoch': epoch + 1,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'encoder_state_dict': encoder.state_dict(),
                    'decoder_state_dict': decoder.state_dict(),
                    'svm_best_acc40': svm_best_acc40,
                }
                checkpoint_name = './ckpts/' + args.name + '.pth'
                torch.save(save_dict, checkpoint_name)
                if svm_acc40 == svm_best_acc40:
                    checkpoint_name = './ckpts/' + args.name + '_best.pth'
                    torch.save(save_dict, checkpoint_name)
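
`MetricLoss`, `ChamferLoss`, and `NormalLoss` are used above but not defined in this snippet. As one illustration, a minimal sketch of an orientation-invariant normal loss based on cosine similarity (an assumption about its form; the project's actual loss may differ):

import torch.nn.functional as F

def NormalLoss(normals_pred, normals_gt):
    # Both tensors are assumed channel-first, shape (B, 3, N).
    # abs() makes the loss invariant to normal sign flips.
    cos_sim = F.cosine_similarity(normals_pred, normals_gt, dim=1)
    return (1.0 - cos_sim.abs()).mean()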
Example #4
# NOTE: the numpy, torch, and torchvision imports are reconstructed from
# usage; the original snippet omits them.
import numpy as np
import torch
from torchvision import transforms

from data.ModelNet40Loader import ModelNet40Cls
import data.data_utils as d_utils
from models.pointnet2_ssg_sem import PointNet2SemSegSSG

if __name__ == '__main__':
    
    # device = torch.device('cuda:0')
    torch.cuda.set_device(1)
    num_points_per_shape = 1024

    # build dataloaders
    transforms_with_aug = transforms.Compose(
        [
            d_utils.PointcloudToTensor(),
            d_utils.PointcloudRotate(axis=np.array([1, 0, 0])),
            d_utils.PointcloudScale(),
            d_utils.PointcloudTranslate(),
            d_utils.PointcloudJitter(),
        ]
    )  # random rotation, scaling, translation, and jitter (additive noise)
    transform = transforms.Compose(
        [
            d_utils.PointcloudToTensor(),
        ]
    )
    trainset = ModelNet40Cls(num_points_per_shape, train=True, transforms=transforms_with_aug)
    testset = ModelNet40Cls(num_points_per_shape, train=False, transforms=transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
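
A hypothetical continuation (not in the original snippet) that sanity-checks one training batch; the PointNet2SemSegSSG import suggests a model is built next, but its constructor arguments are not shown here, so only the data path is exercised:

    # Inspect one augmented batch; expect (32, 1024, 3) points and 32 labels.
    points, labels = next(iter(trainloader))
    print('train batch:', points.shape, labels.shape)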