Example 1
def validate(train_dataloader, test_dataloader, encoder, args):
    global svm_best_acc40
    encoder.eval()

    test_features = []
    test_label = []

    train_features = []
    train_label = []

    PointcloudRotate = d_utils.PointcloudRotate()

    # feature extraction
    with torch.no_grad():
        for j, data in enumerate(train_dataloader, 0):
            points, target = data
            points, target = points.cuda(), target.cuda()

            num_points = 1024

            fps_idx = pointnet2_utils.furthest_point_sample(points, num_points)  # (B, npoint)
            points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()

            feature = encoder(points, get_feature=True)
            target = target.view(-1)

            train_features.append(feature.data)
            train_label.append(target.data)

        for j, data in enumerate(test_dataloader, 0):
            points, target = data
            points, target = points.cuda(), target.cuda()

            fps_idx = pointnet2_utils.furthest_point_sample(points, args.num_points)  # (B, npoint)
            points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()

            feature = encoder(points, get_feature=True)
            target = target.view(-1)
            test_label.append(target.data)
            test_features.append(feature.data)

        train_features = torch.cat(train_features, dim=0)
        train_label = torch.cat(train_label, dim=0)
        test_features = torch.cat(test_features, dim=0)
        test_label = torch.cat(test_label, dim=0)

    # train svm
    svm_acc = evaluate_svm(train_features.data.cpu().numpy(), train_label.data.cpu().numpy(), test_features.data.cpu().numpy(), test_label.data.cpu().numpy())

    if svm_acc > svm_best_acc40:
        svm_best_acc40 = svm_acc

    encoder.train()
    print('ModelNet 40 results: svm acc=', svm_acc, 'best svm acc=', svm_best_acc40)
    print(args.name, args.arch)

    return svm_acc
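
All of the examples in this section share the same subsampling idiom: run furthest_point_sample on a (B, N, 3) cloud to get indices, then call gather_operation on the channel-first view and transpose back. The sketch below distills that pattern into a helper; it assumes the compiled pointnet2_utils CUDA extension used above, and the helper name fps_subsample is our own.

import torch
import pointnet2_utils  # compiled PointNet++ CUDA ops; the exact import path depends on the codebase

def fps_subsample(points: torch.Tensor, npoint: int) -> torch.Tensor:
    """Subsample (B, N, 3) point clouds to (B, npoint, 3) via farthest point sampling."""
    points = points.contiguous().cuda()
    # furthest_point_sample expects (B, N, 3) float32 and returns (B, npoint) int32 indices
    fps_idx = pointnet2_utils.furthest_point_sample(points, npoint)
    # gather_operation works on channel-first (B, C, N), so transpose around it
    gathered = pointnet2_utils.gather_operation(
        points.transpose(1, 2).contiguous(), fps_idx)      # (B, 3, npoint)
    return gathered.transpose(1, 2).contiguous()            # (B, npoint, 3)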
Example 2
def validate(test_dataloader, model, criterion, args, iter): 
    global g_acc
    model.eval()
    losses, preds, labels = [], [], []
    with torch.no_grad():
        for j, data in enumerate(test_dataloader, 0):
            points, target = data
            points, target = points.cuda(), target.cuda()
            points, target = Variable(points), Variable(target)
            
            # farthest point sampling
            fps_idx = pointnet2_utils.furthest_point_sample(points, args.num_points)  # (B, npoint)
            # fps_idx = fps_idx[:, np.random.choice(1200, args.num_points, False)]
            points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()

            pred = model(points)
            target = target.view(-1)
            loss = criterion(pred, target)
            losses.append(loss.data.clone())
            _, pred_choice = torch.max(pred.data, -1)

            
            preds.append(pred_choice)
            labels.append(target.data)
            
        preds = torch.cat(preds, 0)
        labels = torch.cat(labels, 0)
        #print(torch.sum(preds == labels), labels.numel())
        acc = torch.sum(preds == labels).item()/labels.numel()
        print('\nval loss: %0.6f \t acc: %0.6f\n' % (torch.stack(losses).mean().item(), acc))
        if acc > g_acc:
            g_acc = acc
            torch.save(model.state_dict(), '%s/cls_ssn_iter_%d_acc_%0.6f.pth' % (args.save_path, iter, acc))
        model.train()
Example 3
def validate(test_dataloader, model, criterion, args, iter, mode):
    global g_acc
    if mode == 'z':
        aug = d_utils.ZRotate()
    else:
        aug = d_utils.SO3Rotate()
    model.eval()
    losses, preds, labels = [], [], []
    for j, data in enumerate(test_dataloader, 0):
        points, normals, target = data
        points, normals, target = points.cuda(), normals.cuda(), target.cuda()

        # farthest point sampling
        fps_idx = pointnet2_utils.furthest_point_sample(
            points, args.num_points)
        points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx)\
            .transpose(1, 2).contiguous()
        normals = pointnet2_utils.gather_operation(normals.transpose(1, 2).contiguous(), fps_idx)\
            .transpose(1, 2).contiguous()

        points.data, normals.data = aug(points.data, normals.data)

        with torch.no_grad():
            pred = model(points, normals)
            target = target.view(-1)
            loss = criterion(pred, target)
            losses.append(loss.data.clone())
            _, pred_choice = torch.max(pred.data, -1)
            preds.append(pred_choice)
            labels.append(target.data)

    preds = torch.cat(preds, 0)
    labels = torch.cat(labels, 0)
    acc = (preds == labels).sum().item() / labels.numel()
    print(acc)
    if mode == 'z':
        print('z val loss: %0.6f \t acc: %0.6f\n' %
              (torch.stack(losses).mean().item(), acc))
    else:
        print('so3 val loss: %0.6f \t acc: %0.6f\n' %
              (torch.stack(losses).mean().item(), acc))
    if acc > g_acc and mode != 'z':
        g_acc = acc
        torch.save(
            model.state_dict(),
            '%s/zso3ours_iter_%d_acc_%0.6f.pth' % (args.save_path, iter, acc))
    model.train()
Example 4
    def __init__(self,
                 num_points,
                 root,
                 transforms=None,
                 train=True,
                 task='expression'):
        super().__init__()

        self.transforms = transforms
        self.num_points = num_points
        self.root = os.path.abspath(root)
        self.folder = 'Only_pts_BU3DFE'
        self.data_dir = os.path.join(self.root, self.folder)

        self.train = train
        if task == 'expression':
            if self.train:
                self.files, self.labels = _get_data_files(
                    os.path.join(self.data_dir, 'train_expression_6.txt'))
            else:
                self.files, self.labels = _get_data_files(
                    os.path.join(self.data_dir, 'test_expression_6.txt'))
        elif task == 'id':
            if self.train:
                self.files, self.labels = _get_data_files(
                    os.path.join(self.data_dir, 'BU3DFE_id_all.txt'))
            else:
                self.files, self.labels = _get_data_files(
                    os.path.join(self.data_dir, 'train_id.txt'))
        #print(self.files[0:10], self.labels[0:10])
        self.points = []
        np.random.seed(19970513)
        for file in self.files:
            single_p = np.loadtxt(os.path.join(self.root, file))
            single_p = pc_normalize(single_p)
            if self.num_points > single_p.shape[0]:
                idx = np.ones(single_p.shape[0], dtype=np.int32)
                idx[-1] = self.num_points - single_p.shape[0] + 1
                single_p_part = np.repeat(single_p, idx, axis=0)
            else:
                #idxs = np.random.choice(single_p.shape[0], self.num_points, replace=False)
                #single_p_part = single_p[idxs].copy()
                single_p_tensor = torch.from_numpy(single_p).type(
                    torch.FloatTensor).cuda()
                single_p_tensor = single_p_tensor.unsqueeze(
                    0)  # change to (1, N, 3)
                fps_idx = pointnet2_utils.furthest_point_sample(
                    single_p_tensor, self.num_points)  # (1, npoint)
                single_p_tensor = pointnet2_utils.gather_operation(
                    single_p_tensor.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()  # (1, npoint, 3)
                single_p_part = single_p_tensor.squeeze(0).cpu().numpy()
            #print(single_p.shape)
            self.points.append(single_p_part)

        self.points = np.stack(self.points, axis=0)  # need to be tested
        #self.points = np.array(self.points)
        self.labels = np.array(self.labels)
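
Examples 4 and 13 normalize every cloud to exactly num_points: clouds with too few points are padded by repeating the last point, and larger clouds are reduced with GPU farthest point sampling. Below is a minimal sketch of that resampling step under the same assumptions (pointnet2_utils CUDA extension, CUDA available); the helper name resample_to is ours.

import numpy as np
import torch
import pointnet2_utils  # compiled PointNet++ CUDA ops, as in the examples above

def resample_to(points_np: np.ndarray, num_points: int) -> np.ndarray:
    """Return exactly num_points rows from an (N, 3) cloud: repeat-pad or FPS-downsample."""
    n = points_np.shape[0]
    if num_points > n:
        # repeat the last point until the cloud has num_points rows
        repeats = np.ones(n, dtype=np.int64)
        repeats[-1] = num_points - n + 1
        return np.repeat(points_np, repeats, axis=0)
    pts = torch.from_numpy(points_np).float().cuda().unsqueeze(0)       # (1, N, 3)
    fps_idx = pointnet2_utils.furthest_point_sample(pts, num_points)    # (1, num_points)
    pts = pointnet2_utils.gather_operation(
        pts.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()
    return pts.squeeze(0).cpu().numpy()                                  # (num_points, 3)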
Example 5
    def forward(self,
                xyz: torch.Tensor,
                features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the points
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the points

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new points' xyz
        new_features : torch.Tensor
            (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_points descriptors
        """

        new_features_list = []
        xyz_flipped = xyz.transpose(1, 2).contiguous()
        if self.npoint is not None:
            #fps_idx = point_utils.farthest_point_sampling(xyz_flipped, self.npoint)  # (B, npoint)
            fps_idx = pointnet2_utils.furthest_point_sample(
                xyz_flipped, self.npoint)
            """# my code:
            if self.first_layer_: 
                all_feature = xyz_flipped
            else:
                features = features.contiguous() # features: (B, C, N)
                all_feature = torch.cat((xyz_flipped, features), dim=1) #(B,3+C,N)
            all_feature = self.mlp_for_att1(all_feature)
            all_feature = self.mlp_for_att2(all_feature).squeeze(1) # (B,N)
            att_score = F.softmax(all_feature, dim=1)
            fps_idx = torch.topk(att_score, k=self.npoint, dim=1)
            fps_idx = fps_idx.type(torch.cuda.IntTensor)
            """
            new_xyz = pointnet2_utils.gather_operation(xyz_flipped,
                                                       fps_idx).transpose(
                                                           1, 2).contiguous()
            #new_xyz = point_utils.index_points(xyz_flipped, fps_idx).transpose(1, 2).contiguous()
            #fps_idx = fps_idx.data
        else:
            new_xyz = None
            fps_idx = None

        for i in range(len(self.groupers)):
            new_features = self.groupers[i](
                xyz, new_xyz, features,
                fps_idx) if self.npoint is not None else self.groupers[i](
                    xyz, new_xyz, features)  # (B, C, npoint, nsample)
            # print(new_features.shape)
            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint)
            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1)
Example 6
def train(train_dataloader, test_dataloader, model, criterion, optimizer,
          lr_scheduler, bnm_scheduler, args, num_batch):
    PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate(
    )  # initialize augmentation
    global g_acc
    g_acc = 0.91  # only save the model whose acc > 0.91
    batch_count = 0
    model.train()
    for epoch in range(args.epochs):
        for i, data in enumerate(train_dataloader, 0):
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)
            points, target = data
            points, target = points.cuda(), target.cuda()
            points, target = Variable(points), Variable(target)

            # farthest point sampling
            # fps_idx = pointnet2_utils.furthest_point_sample(points, 1200)  # (B, npoint)

            # random sampling
            fps_idx = np.random.randint(0,
                                        points.shape[1],  # high is exclusive, so this covers all indices
                                        size=[points.shape[0], 1200])
            fps_idx = torch.from_numpy(fps_idx).type(torch.IntTensor).cuda()

            fps_idx = fps_idx[:,
                              np.random.choice(1200, args.num_points, False)]
            points = pointnet2_utils.gather_operation(
                points.transpose(1, 2).contiguous(),
                fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)

            # augmentation
            points.data = PointcloudScaleAndTranslate(points.data)

            optimizer.zero_grad()

            pred = model(points)
            target = target.view(-1)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()
            if i % args.print_freq_iter == 0:
                print(
                    '[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %
                    (epoch + 1, i, num_batch, loss.data.clone(),
                     lr_scheduler.get_lr()[0]))
            batch_count += 1

            # validation in between an epoch
            if args.evaluate and batch_count % int(
                    args.val_freq_epoch * num_batch) == 0:
                validate(test_dataloader, model, criterion, args, batch_count)
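
Example 6 swaps FPS for plain random index sampling during training (np.random.randint followed by gather_operation, then a further random choice of args.num_points indices). The same effect can be obtained in pure PyTorch without the CUDA op; the sketch below samples with replacement like the original, and random_subsample is a hypothetical helper name.

import torch

def random_subsample(points: torch.Tensor, npoint: int) -> torch.Tensor:
    """Randomly pick npoint indices per cloud from a (B, N, 3) batch (with replacement)."""
    B, N, _ = points.shape
    idx = torch.randint(0, N, (B, npoint), device=points.device)        # (B, npoint)
    # gather the sampled rows along the point dimension
    return torch.gather(points, 1, idx.unsqueeze(-1).expand(-1, -1, points.size(-1)))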
Example 7
def validate(test_dataloader, model, criterion, args, iter):
    global g_acc
    model.eval()
    losses, preds, labels = [], [], []
    for j, data in enumerate(test_dataloader, 0):
        points, target = data
        points, target = points.cuda(), target.cuda()
        points, target = Variable(points,
                                  volatile=True), Variable(target,
                                                           volatile=True)

        # farthest point sampling
        # fps_idx = pointnet2_utils.furthest_point_sample(points, args.num_points)  # (B, npoint)

        # random sampling
        fps_idx = np.random.randint(0,
                                    points.shape[1],  # high is exclusive, so this covers all indices
                                    size=[points.shape[0], args.num_points])
        fps_idx = torch.from_numpy(fps_idx).type(torch.IntTensor).cuda()

        # fps_idx = fps_idx[:, np.random.choice(1200, args.num_points, False)]
        points = pointnet2_utils.gather_operation(
            points.transpose(1, 2).contiguous(),
            fps_idx).transpose(1, 2).contiguous()

        pred = model(points)
        target = target.view(-1)
        loss = criterion(pred, target)
        losses.append(loss.data.clone())
        _, pred_choice = torch.max(pred.data, -1)

        preds.append(pred_choice)
        labels.append(target.data)

    preds = torch.cat(preds, 0)
    labels = torch.cat(labels, 0)
    acc = (preds == labels).sum().item() / labels.numel()
    print('\nval loss: %0.6f \t acc: %0.6f\n' % (torch.stack(losses).mean().item(), acc))
    if acc > g_acc:
        g_acc = acc
        torch.save(
            model.state_dict(),
            '%s/cls_iter_%d_acc_%0.6f.pth' % (args.save_path, iter, acc))
        print('saved model with accuracy ', acc)
    model.train()
Example 8
    def forward(self,
                xyz: torch.Tensor,
                features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the points
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the points

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new points' xyz
        new_features : torch.Tensor
            (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_points descriptors
        """

        new_features_list = []
        xyz_flipped = xyz.transpose(1, 2).contiguous()
        if self.npoint is not None:
            #fps_idx = point_utils.farthest_point_sampling(xyz_flipped, self.npoint)  # (B, npoint)
            fps_idx = pointnet2_utils.furthest_point_sample(
                xyz_flipped, self.npoint)
            new_xyz = pointnet2_utils.gather_operation(xyz_flipped,
                                                       fps_idx).transpose(
                                                           1, 2).contiguous()
            #new_xyz = point_utils.index_points(xyz_flipped, fps_idx).transpose(1, 2).contiguous()
            #fps_idx = fps_idx.data
        else:
            new_xyz = None
            fps_idx = None

        for i in range(len(self.groupers)):
            new_features = self.groupers[i](
                xyz, new_xyz, features,
                fps_idx) if self.npoint is not None else self.groupers[i](
                    xyz, new_xyz, features)  # (B, C, npoint, nsample)
            # print(new_features.shape)
            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint)
            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1)
Example 9
def train(train_dataloader, test_dataloader_z, test_dataloader_so3, model,
          criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch):
    aug = d_utils.ZRotate()
    global g_acc
    g_acc = 0.88  # only save the model whose acc > g_acc
    batch_count = 0
    model.train()
    for epoch in range(args.epochs):
        for i, data in enumerate(train_dataloader, 0):
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)
            points, normals, target = data
            points, normals, target = points.cuda(), normals.cuda(
            ), target.cuda()
            if args.model == "pointnet2":
                fps_idx = pointnet2_utils.furthest_point_sample(
                    points, 1024)  # (B, npoint)
                points = pointnet2_utils.gather_operation(
                    points.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)
                normals = pointnet2_utils.gather_operation(
                    normals.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()
            else:
                fps_idx = pointnet2_utils.furthest_point_sample(
                    points, 1200)  # (B, npoint)
                fps_idx = fps_idx[:,
                                  np.random.choice(1200, args.num_points, False
                                                   )]
                points = pointnet2_utils.gather_operation(
                    points.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)
                normals = pointnet2_utils.gather_operation(
                    normals.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()
                # RS-CNN applies a scale-and-translate augmentation to the input first
                points.data = d_utils.PointcloudScaleAndTranslate()(
                    points.data)

            points.data, normals.data = aug(points.data, normals.data)

            optimizer.zero_grad()
            pred = model(points, normals)
            target = target.view(-1)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()
            if i % args.print_freq_iter == 0:
                print(
                    '[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %
                    (epoch + 1, i, num_batch, loss.data.clone(),
                     lr_scheduler.get_lr()[0]))
            batch_count += 1

            # validation in between an epoch
            if args.evaluate and batch_count % int(
                    args.val_freq_epoch * num_batch) == 0:
                # validate(test_dataloader_z, model, criterion, args, batch_count, 'z')
                validate(test_dataloader_so3, model, criterion, args,
                         batch_count, 'so3')
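
Examples 3 and 9 rotate both points and normals with d_utils.ZRotate / d_utils.SO3Rotate, whose implementations are not shown here. As a rough illustration (not the d_utils code) of what a z-axis rotation augmentation typically does to a (B, N, 3) batch:

import math
import torch

def rotate_z(points: torch.Tensor, normals: torch.Tensor):
    """Rotate each cloud in a (B, N, 3) batch by a random angle about the z-axis."""
    B = points.size(0)
    angles = torch.rand(B, device=points.device) * 2.0 * math.pi
    c, s = torch.cos(angles), torch.sin(angles)
    zeros, ones = torch.zeros_like(c), torch.ones_like(c)
    # per-cloud rotation matrices, shape (B, 3, 3)
    R = torch.stack([c, -s, zeros,
                     s,  c, zeros,
                     zeros, zeros, ones], dim=1).view(B, 3, 3)
    # coordinates and unit normals are rotated by the same matrix
    return torch.bmm(points, R.transpose(1, 2)), torch.bmm(normals, R.transpose(1, 2))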
Example 10
def main():
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)

    test_transforms = transforms.Compose([d_utils.PointcloudToTensor()])

    test_dataset = ModelNet40Cls(num_points=args.num_points,
                                 root=args.data_root,
                                 transforms=test_transforms,
                                 train=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=int(args.workers),
                                 pin_memory=True)

    model = DensePoint(num_classes=args.num_classes,
                       input_channels=args.input_channels,
                       use_xyz=True)
    model.cuda()

    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    # evaluate
    PointcloudScale = d_utils.PointcloudScale()  # initialize random scaling
    model.eval()
    global_acc = 0
    for i in range(NUM_REPEAT):
        preds = []
        labels = []

        s = time.time()
        for j, data in enumerate(test_dataloader, 0):
            points, target = data
            points, target = points.cuda(), target.cuda()
            points, target = Variable(points,
                                      volatile=True), Variable(target,
                                                               volatile=True)
            # points [batch_size, num_points, dimensions], e.g., [256, 2048, 3]

            # furthest point sampling
            # fps_idx = pointnet2_utils.furthest_point_sample(points, 1200)  # (B, npoint)

            # random sampling
            fps_idx = np.random.randint(0,
                                        points.shape[1],  # high is exclusive, so this covers all indices
                                        size=[points.shape[0], 1200])
            fps_idx = torch.from_numpy(fps_idx).type(torch.IntTensor).cuda()

            pred = 0
            for v in range(NUM_VOTE):
                new_fps_idx = fps_idx[:,
                                      np.random.choice(1200, args.
                                                       num_points, False)]
                new_points = pointnet2_utils.gather_operation(
                    points.transpose(1, 2).contiguous(),
                    new_fps_idx).transpose(1, 2).contiguous()
                if v > 0:
                    new_points.data = PointcloudScale(new_points.data)
                pred += F.softmax(model(new_points), dim=1)
            pred /= NUM_VOTE
            target = target.view(-1)
            _, pred_choice = torch.max(pred.data, -1)

            preds.append(pred_choice)
            labels.append(target.data)
        e = time.time()

        preds = torch.cat(preds, 0)
        labels = torch.cat(labels, 0)
        acc = (preds == labels).sum().item() / labels.numel()
        if acc > global_acc:
            global_acc = acc
        print('Repeat %3d \t Acc: %0.6f' % (i + 1, acc))
        print('time (secs) for 1 epoch: ', (e - s))
    print('\nBest voting acc: %0.6f' % (global_acc))
Example 11
def main():
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)

    test_transforms = transforms.Compose([d_utils.PointcloudToTensor()])

    test_dataset = ModelNet40Cls(num_points=args.num_points,
                                 root=args.data_root,
                                 transforms=test_transforms,
                                 train=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=int(args.workers),
                                 pin_memory=False)

    model = RSCNN_SSN(num_classes=args.num_classes,
                      input_channels=args.input_channels,
                      relation_prior=args.relation_prior,
                      use_xyz=True)
    # for multi GPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
        model = nn.DataParallel(model, device_ids=[0, 1])
        model.to(device)
    elif torch.cuda.is_available() and torch.cuda.device_count() == 1:
        model.cuda()

    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    # evaluate
    PointcloudScale = d_utils.PointcloudScale()  # initialize random scaling
    model.eval()
    global_acc = 0
    for i in range(NUM_REPEAT):
        preds = []
        labels = []
        for j, data in enumerate(test_dataloader, 0):
            points, target = data
            points, target = points.cuda(), target.cuda()
            points, target = Variable(points,
                                      volatile=True), Variable(target,
                                                               volatile=True)

            # farthest point sampling
            fps_idx = pointnet2_utils.furthest_point_sample(
                points, 1200)  # (B, npoint)
            pred = 0
            for v in range(NUM_VOTE):
                new_fps_idx = fps_idx[:,
                                      np.random.choice(1200, args.
                                                       num_points, False)]
                new_points = pointnet2_utils.gather_operation(
                    points.transpose(1, 2).contiguous(),
                    new_fps_idx).transpose(1, 2).contiguous()
                if v > 0:
                    new_points.data = PointcloudScale(new_points.data)
                pred += F.softmax(model(new_points), dim=1)
            pred /= NUM_VOTE
            target = target.view(-1)
            _, pred_choice = torch.max(pred.data, -1)

            preds.append(pred_choice)
            labels.append(target.data)

        preds = torch.cat(preds, 0)
        labels = torch.cat(labels, 0)
        acc = (preds == labels).sum().item() / labels.numel()
        if acc > global_acc:
            global_acc = acc
        print('Repeat %3d \t Acc: %0.6f' % (i + 1, acc))
    print('\nBest voting acc: %0.6f' % (global_acc))
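
Examples 10 and 11 evaluate with vote averaging: each test cloud is subsampled NUM_VOTE times (with random scaling applied to every vote but the first), the softmax outputs are averaged, and the argmax of the averaged distribution is the prediction. The core of that loop, reduced to a sketch (predict_with_votes is our own name; sample_fn and scale_fn stand in for whichever subsampling and scaling the example uses):

import torch
import torch.nn.functional as F

def predict_with_votes(model, points, sample_fn, scale_fn, num_votes=10):
    """Average softmax predictions over num_votes randomly subsampled/scaled views of `points`."""
    pred = 0
    with torch.no_grad():
        for v in range(num_votes):
            view = sample_fn(points)          # e.g. random or FPS subsampling, as above
            if v > 0:
                view = scale_fn(view)         # random scaling on every vote but the first
            pred = pred + F.softmax(model(view), dim=1)
    pred = pred / num_votes
    return pred.argmax(dim=-1)                # (B,) predicted class per cloud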
Example 12
def train(ss_dataloader, train_dataloader, test_dataloader, encoder, decoer,
          optimizer, lr_scheduler, bnm_scheduler, args, num_batch,
          begin_epoch):
    PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate(
    )  # initialize augmentation
    PointcloudRotate = d_utils.PointcloudRotate()
    metric_criterion = MetricLoss()
    chamfer_criterion = ChamferLoss()
    global svm_best_acc40
    batch_count = 0
    encoder.train()
    decoer.train()

    for epoch in range(begin_epoch, args.epochs):
        np.random.seed()
        for i, data in enumerate(ss_dataloader, 0):
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)
            points = data
            points = Variable(points.cuda())

            # data augmentation
            sampled_points = 1200
            has_normal = (points.size(2) > 3)

            if has_normal:
                normals = points[:, :, 3:6].contiguous()
            points = points[:, :, 0:3].contiguous()

            fps_idx = pointnet2_utils.furthest_point_sample(
                points, sampled_points)  # (B, npoint)
            fps_idx = fps_idx[:,
                              np.random.choice(sampled_points, args.
                                               num_points, False)]
            points_gt = pointnet2_utils.gather_operation(
                points.transpose(1, 2).contiguous(),
                fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)
            if has_normal:
                normals = pointnet2_utils.gather_operation(
                    normals.transpose(1, 2).contiguous(), fps_idx)
            points = PointcloudScaleAndTranslate(points_gt.data)

            # optimize
            optimizer.zero_grad()

            features1, fuse_global, normals_pred = encoder(points)
            global_feature1 = features1[2].squeeze(2)
            refs1 = features1[0:2]
            recon1 = decoer(fuse_global).transpose(1, 2)  # bs, np, 3

            loss_metric = metric_criterion(global_feature1, refs1)
            loss_recon = chamfer_criterion(recon1, points_gt)
            if has_normal:
                loss_normals = NormalLoss(normals_pred, normals)
            else:
                loss_normals = normals_pred.new(1).fill_(0)
            loss = loss_recon + loss_metric + loss_normals
            loss.backward()
            optimizer.step()
            if i % args.print_freq_iter == 0:
                print(
                    '[epoch %3d: %3d/%3d] \t metric/chamfer/normal loss: %0.6f/%0.6f/%0.6f \t lr: %0.5f'
                    % (epoch + 1, i, num_batch, loss_metric.item(),
                       loss_recon.item(), loss_normals.item(),
                       lr_scheduler.get_lr()[0]))
            batch_count += 1

            # validation
            if args.evaluate and batch_count % int(
                    args.val_freq_epoch * num_batch) == 0:
                svm_acc40 = validate(train_dataloader, test_dataloader,
                                     encoder, args)

                save_dict = {
                    'epoch': epoch +
                    1,  # after training one epoch, the start_epoch should be epoch+1
                    'optimizer_state_dict': optimizer.state_dict(),
                    'encoder_state_dict': encoder.state_dict(),
                    'decoder_state_dict': decoer.state_dict(),
                    'svm_best_acc40': svm_best_acc40,
                }
                checkpoint_name = './ckpts/' + args.name + '.pth'
                torch.save(save_dict, checkpoint_name)
                if svm_acc40 == svm_best_acc40:
                    checkpoint_name = './ckpts/' + args.name + '_best.pth'
                    torch.save(save_dict, checkpoint_name)
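
Example 12 saves a combined checkpoint dict with the keys 'epoch', 'optimizer_state_dict', 'encoder_state_dict', 'decoder_state_dict', and 'svm_best_acc40'. Resuming is the mirror image; a minimal sketch, assuming the same key names and already constructed encoder, decoder ('decoer' above), and optimizer objects:

import torch

def resume_from_checkpoint(path, encoder, decoder, optimizer):
    """Restore the state saved in Example 12 and return (begin_epoch, svm_best_acc40)."""
    ckpt = torch.load(path, map_location='cuda')
    encoder.load_state_dict(ckpt['encoder_state_dict'])
    decoder.load_state_dict(ckpt['decoder_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    # 'epoch' was stored as epoch + 1, i.e. the epoch to start from
    return ckpt['epoch'], ckpt['svm_best_acc40']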
Example 13
    def __init__(self, num_points, root, transforms=None, task='id'):
        super().__init__()

        self.transforms = transforms
        self.num_points = num_points
        self.root = os.path.abspath(root)
        #self.folder = 'Only_pts_Bos'
        self.folder = 'Bos_Downsample'
        self.data_dir = os.path.join(self.root, self.folder)

        if task == 'id':
            self.probe_files, self.probe_labels = _get_data_files(
                os.path.join(self.data_dir, 'probe.txt'))

            self.gallery_files, self.gallery_labels = _get_data_files(
                os.path.join(self.data_dir, 'gallery.txt'))

        self.probe_points = []
        self.gallery_points = []
        for file in self.probe_files:
            single_p = np.loadtxt(os.path.join(self.root, file), delimiter=',')
            single_p = pc_normalize(single_p)
            if self.num_points > single_p.shape[0]:
                idx = np.ones(single_p.shape[0], dtype=np.int32)
                idx[-1] = self.num_points - single_p.shape[0] + 1
                single_p_part = np.repeat(single_p, idx, axis=0)
            else:
                #idxs = np.random.choice(single_p.shape[0], self.num_points, replace=False)
                #single_p_part = single_p[idxs].copy()
                single_p_tensor = torch.from_numpy(single_p).type(
                    torch.FloatTensor).cuda()
                single_p_tensor = single_p_tensor.unsqueeze(
                    0)  # change to (1, N, 3)
                fps_idx = pointnet2_utils.furthest_point_sample(
                    single_p_tensor, self.num_points)  # (1, npoint)
                single_p_tensor = pointnet2_utils.gather_operation(
                    single_p_tensor.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()  # (1, npoint, 3)
                single_p_part = single_p_tensor.squeeze(0).cpu().numpy()

            #print(single_p.shape)
            self.probe_points.append(single_p_part)

        for file in self.gallery_files:
            single_p = np.loadtxt(os.path.join(self.root, file), delimiter=',')
            single_p = pc_normalize(single_p)
            if self.num_points > single_p.shape[0]:
                idx = np.ones(single_p.shape[0], dtype=np.int32)
                idx[-1] = self.num_points - single_p.shape[0] + 1
                single_p_part = np.repeat(single_p, idx, axis=0)
            else:
                #idxs = np.random.choice(single_p.shape[0], self.num_points, replace=False)
                #single_p_part = single_p[idxs].copy()
                single_p_tensor = torch.from_numpy(single_p).type(
                    torch.FloatTensor).cuda()
                single_p_tensor = single_p_tensor.unsqueeze(
                    0)  # change to (1, N, 3)
                fps_idx = pointnet2_utils.furthest_point_sample(
                    single_p_tensor, self.num_points)  # (1, npoint)
                single_p_tensor = pointnet2_utils.gather_operation(
                    single_p_tensor.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()  # (1, npoint, 3)
                single_p_part = single_p_tensor.squeeze(0).cpu().numpy()

            #print(single_p.shape)
            self.gallery_points.append(single_p_part)

        self.probe_points = np.stack(self.probe_points, axis=0)
        self.gallery_points = np.stack(self.gallery_points,
                                       axis=0)  # need to be tested

        self.probe_labels = np.array(self.probe_labels)
        self.gallery_labels = np.array(self.gallery_labels)