Code Example #1
    def __init__(self):
        model = get_model()
        print('local_rank', args.node_rank)
        # bind this process to its GPU and wrap the model for distributed data-parallel training
        torch.cuda.set_device(args.node_rank)
        model = model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.node_rank], output_device=args.node_rank)
        print("From rank %d: start training, time:%s" %
              (args.node_rank, time.strftime("%Y-%m-%d %H:%M:%S")))
        criterion = get_loss
        optimizer = get_optimizer(model)
        self.ce_loss = torch.nn.BCELoss()
        self.mse_loss = torch.nn.MSELoss()
        self.wing_loss = get_wing_loss
        self.focal_loss = get_focal_loss
        super(FaceAlign_Trainer, self).__init__(model=model,
                                                criterion=criterion,
                                                optimizer=optimizer,
                                                opt=opt)

        from tensorboardX import SummaryWriter
        self.writer = SummaryWriter(opt.vis_dir)
        self.best_metric = 10000
        # one running-average meter per tracked metric
        for m in all_metrics:
            self.metric_meter[m] = meter.AverageValueMeter()
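The meter module in these examples appears to be the project's own utility rather than torchnet's (example #3 calls update() and reads .avg, whereas torchnet meters use add()/value()). Below is a minimal sketch of such a running-average meter, written as an assumption about the interface and not the project's actual implementation:

class AverageValueMeter:
    """Running average of scalar values (assumed interface: update() / .avg / reset())."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        value = float(value)  # accepts Python floats or 0-d torch tensors
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count

With one meter per metric name, a training loop can call self.metric_meter[m].update(value) after each batch and read self.metric_meter[m].avg at the end of an epoch.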
Code Example #2
    def __init__(self, model):
        criterion = get_loss
        optimizer = get_optimizer(model)
        self.ce_loss = torch.nn.BCELoss()
        self.mse_loss = torch.nn.MSELoss()
        self.wing_loss = get_wing_loss
        self.focal_loss = get_focal_loss
        self.n_iter = 0
        self.n_plot = 0
        super(FaceAlign_Trainer, self).__init__(model=model, criterion=criterion, optimizer=optimizer, opt=opt)
        # only the rank-0 process creates a TensorBoard writer, so a single process writes logs
        if args.rank == 0:
            from tensorboardX import SummaryWriter
            self.writer = SummaryWriter(opt.vis_dir)
        self.best_metric = 10000
        for m in all_metrics:
            self.metric_meter[m] = meter.AverageValueMeter()
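Example #2 differs from #1 mainly in that the SummaryWriter is created only on rank 0, so only one process in a distributed job writes TensorBoard event files (note that self.writer then does not exist on the other ranks, so any later logging has to be guarded the same way). A small sketch of that pattern, assuming torch.distributed has already been initialized; the function name and log directory below are illustrative only:

import torch.distributed as dist
from tensorboardX import SummaryWriter

def make_writer(log_dir):
    # only the rank-0 process owns a writer; other ranks get None and skip logging
    if not dist.is_initialized() or dist.get_rank() == 0:
        return SummaryWriter(log_dir)
    return None

writer = make_writer("./runs/face_align")  # hypothetical log directory
if writer is not None:
    writer.add_scalar("train/loss", 0.5, global_step=0)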
Code Example #3
    def __call__(self, network, P1, LABEL1, P2, LABEL2, local_fix=None, latent=False, latent_P1=None, latent_P2=None,
                 distChamfer=None):
        """
        :param network:
        :param P1: input points 1 ->  batch, num_points, 3
        :param P2: input points 2  ->  batch, num_points, 3
        :param local_fix: Vector storing constants for batch_idx function
        :param latent: Boolean to decide if latent vector should be stored for slight optimisation of runtime
        :param latent_P1: Latent vector 1 precomputed
        :param latent_P2: Latent vector 2 precomputed
        :param distChamfer: Chamfer distance function
        :return: P2 reconstructed from P1, latent vectors, chamfer outputs
        """
        if latent:
            P2_P1, latent_vector_P1, latent_vector_P2 = network.forward_classic_with_latent(
                P1.transpose(2, 1).contiguous(), P2.transpose(2, 1).contiguous(), x_latent=latent_P1,
                y_latent=latent_P2)  # forward pass
        else:
            P2_P1 = network(P1.transpose(2, 1).contiguous(), P2.transpose(2, 1).contiguous())  # forward pass

        P2_P1 = P2_P1.transpose(2, 1).contiguous()

        # reconstruction losses
        dist1, dist2, idx1, idx2 = distChamfer(P2_P1, P2)
        loss = meter.AverageValueMeter()  # running average of the per-part Chamfer losses
        for shape in range(P1.size(0)):
            for label in set(LABEL1[shape].cpu().numpy()):
                # A for loop over the batch is necessary because shape parts do not have the same number of points for each shape.
                try:
                    dist1_label, dist2_label, idx1_label, idx2_label = distChamfer(
                        P2_P1[shape][LABEL1[shape] == label].unsqueeze(0),
                        P2[shape][LABEL2[shape] == label].unsqueeze(0))
                    loss.update(self.chamfer_loss.forward(dist1_label, dist2_label))
                except:
                    # a label present in P1 may have no matching points in P2; skip that part
                    continue

        # TODO: find a good way to make sure points with no matching labels are not accounted for in cycle consistency
        # This is the critical step that allows the batched indexing in the computation of the cycle losses to run smoothly.
        if local_fix is not None:
            idx1 = batch_idx(idx1, local_fix)
            idx2 = batch_idx(idx2, local_fix)

        if latent:
            return P2_P1, dist1, dist2, idx1.long(), idx2.long(), latent_vector_P1, latent_vector_P2, loss.avg
        else:
            return P2_P1, dist1, dist2, idx1.long(), idx2.long(), loss.avg
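The per-shape, per-label loop above cannot be batched because parts contain different numbers of points in each shape, so every label-masked slice has its own length; when a label present in P1 has no points in P2, the masked slice is empty, distChamfer fails, and the except clause skips that part, so loss.avg averages the Chamfer distance only over parts present in both shapes. A toy illustration of the variable-length slicing, using made-up data:

import torch

points = torch.randn(6, 3)                 # one shape with 6 points
labels = torch.tensor([0, 0, 1, 1, 1, 2])  # per-point part labels

for label in labels.unique():
    part = points[labels == label]         # variable-length slice per part
    print(int(label), part.shape)          # part 0 has 2 points, part 1 has 3, part 2 has 1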
Code Example #4
import numpy as np
import open3d
import torch
import copy
import time
import meter

timings_ICP = meter.AverageValueMeter()  # running average of ICP registration times across calls


def ICP(source, target):
    """

    :param source: source point cloud
    :param target:  target point cloud
    :return: source pointcloud registered
    """
    start = time.time()
    use_torch = False  # remember whether the input was a torch tensor so the result can be converted back
    if isinstance(source, torch.Tensor):
        use_torch = True
        source = source.squeeze().cpu().numpy()
    if isinstance(target, torch.Tensor):
        target = target.squeeze().cpu().numpy()

    # legacy open3d API; newer releases use open3d.geometry.PointCloud and open3d.utility.Vector3dVector
    pcd_target = open3d.PointCloud()
    pcd_target.points = open3d.Vector3dVector(target)

    pcd_source = open3d.PointCloud()
    pcd_source.points = open3d.Vector3dVector(source)