    def __init__(self, cfg, in_channels, optimizer=None):
        super().__init__()
        self.cfg = cfg

        # Refinement
        self.transform_pc = TransformPC(cfg)
        self.feature_projection = FeatureProjection(cfg)
        self.pc_encode = PointNet2(cfg)
        self.displacement_net = LinearDisplacementNet(cfg)

        self.optimizer = None if optimizer is None else optimizer(
            self.parameters())

        # emd loss
        self.emd_dist = emd.emdModule()

        if torch.cuda.is_available():
            self.transform_pc = torch.nn.DataParallel(
                self.transform_pc, device_ids=cfg.CONST.DEVICE).cuda()
            self.feature_projection = torch.nn.DataParallel(
                self.feature_projection, device_ids=cfg.CONST.DEVICE).cuda()
            self.pc_encode = torch.nn.DataParallel(
                self.pc_encode, device_ids=cfg.CONST.DEVICE).cuda()
            self.displacement_net = torch.nn.DataParallel(
                self.displacement_net, device_ids=cfg.CONST.DEVICE).cuda()
            self.emd_dist = torch.nn.DataParallel(
                self.emd_dist, device_ids=cfg.CONST.DEVICE).cuda()
            self.cuda()
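
For reference, the emd.emdModule() instance built here is invoked the same way throughout this listing (see the test functions below). A minimal sketch of that call pattern, assuming pred_pc and gt_pc are (B, N, 3) CUDA tensors and emd_dist is the module (or its DataParallel wrapper) created above; the eps/iters values mirror the validation settings used later:

import torch

# Hypothetical helper (not part of the original code) wrapping the EMD call
# pattern used in this listing: per-point assignment cost -> sqrt ->
# mean over points -> mean over the batch.
def emd_loss_value(emd_dist, pred_pc, gt_pc, eps=0.005, iters=50):
    dist, _ = emd_dist(pred_pc, gt_pc, eps=eps, iters=iters)
    return torch.sqrt(dist).mean(1).mean()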
Example #2
    def build_train_loss(self):
        # Set up loss functions
        self.chamfer_dist = torch.nn.DataParallel(
            ChamferDistance().to(self.gpu_ids[0]), device_ids=self.gpu_ids)
        self.chamfer_dist_mean = torch.nn.DataParallel(
            ChamferDistanceMean().to(self.gpu_ids[0]), device_ids=self.gpu_ids)
        self.emd_dist = torch.nn.DataParallel(
            emd.emdModule().to(self.gpu_ids[0]), device_ids=self.gpu_ids)
Example #3
    def __init__(self, input_pcs):
        super().__init__()
        self.pred_pcs = nn.Parameter(input_pcs)
        self.emd_dist = emd.emdModule()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)

        if torch.cuda.is_available():
            # Note: Parameter.cuda() returns a copy, so self.cuda() below is
            # what actually moves pred_pcs (and emd_dist) onto the GPU.
            self.emd_dist.cuda()
            self.cuda()
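
Here the point cloud itself is the trainable parameter, so the EMD loss can refine the points directly. An illustrative optimization loop under that reading, assuming model is an instance of the class above and gt_pc is a (B, N, 3) CUDA tensor (the step count is arbitrary, not from the original code):

# Illustrative only: drive the parameterized point cloud toward gt_pc.
for step in range(300):
    model.optimizer.zero_grad()
    dist, _ = model.emd_dist(model.pred_pcs, gt_pc, eps=0.005, iters=50)
    loss = torch.sqrt(dist).mean(1).mean()
    loss.backward()          # gradients flow straight into pred_pcs
    model.optimizer.step()   # Adam nudges the points toward the target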
Example #4
    def __init__(self, cfg, optimizer=None, scheduler=None):
        super().__init__()
        self.cfg = cfg
        
        # Graphx Reconstructor
        self.reconstructor = Graphx_Rec(
            cfg=cfg,
            in_channels=3,
            in_instances=cfg.GRAPHX.NUM_INIT_POINTS,
            activation=nn.ReLU(),
        )
        
        self.optimizer = None if optimizer is None else optimizer(self.reconstructor.parameters())
        self.scheduler = None if scheduler is None or optimizer is None else scheduler(self.optimizer)
        
        # emd loss
        self.emd_dist = emd.emdModule()

        if torch.cuda.is_available():
            # Reconstructor
            self.reconstructor = torch.nn.DataParallel(self.reconstructor, device_ids=cfg.CONST.DEVICE).cuda()
            # loss
            self.emd_dist = torch.nn.DataParallel(self.emd_dist, device_ids=cfg.CONST.DEVICE).cuda()
            self.cuda()
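
Note that optimizer and scheduler are passed as factories: the optimizer callable receives the reconstructor's parameters and the scheduler callable receives the resulting optimizer. A minimal construction sketch, assuming this is the GRAPHX_REC_MODEL class used in test_rec_net below and that the cfg fields exist as they do there:

# Sketch of constructing the model with factory callables (cfg fields as in
# test_rec_net below; MultiStepLR is torch.optim.lr_scheduler.MultiStepLR).
model = GRAPHX_REC_MODEL(
    cfg=cfg,
    optimizer=lambda params: torch.optim.Adam(
        params,
        lr=cfg.TRAIN.GRAPHX_LEARNING_RATE,
        weight_decay=cfg.TRAIN.GRAPHX_WEIGHT_DECAY),
    scheduler=lambda optim: MultiStepLR(
        optim, milestones=cfg.TRAIN.MILESTONES, gamma=cfg.TRAIN.GAMMA),
)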
Example #5
def test_rec_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.ToTensor(),
    ])
    dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                           test_transforms),
        batch_size=cfg.TEST.BATCH_SIZE,
        num_workers=1,
        pin_memory=True,
        shuffle=False)

    # Set up networks
    # The parameters here need to be set in cfg
    net = GRAPHX_REC_MODEL(
        cfg=cfg,
        optimizer=lambda x: torch.optim.Adam(
            x,
            lr=cfg.TRAIN.GRAPHX_LEARNING_RATE,
            weight_decay=cfg.TRAIN.GRAPHX_WEIGHT_DECAY),
        scheduler=lambda x: MultiStepLR(
            x, milestones=cfg.TRAIN.MILESTONES, gamma=cfg.TRAIN.GAMMA),
    )

    if torch.cuda.is_available():
        net = torch.nn.DataParallel(net).cuda()

    # Load reconstruction weights (encoder, decoder)
    print('[INFO] %s Loading reconstruction weights from %s ...' %
          (dt.now(), cfg.TEST.WEIGHT_PATH))
    rec_checkpoint = torch.load(cfg.TEST.WEIGHT_PATH)
    net.load_state_dict(rec_checkpoint['net'])
    print('[INFO] Best reconstruction result at epoch %d ...' %
          rec_checkpoint['epoch_idx'])

    # Set up loss functions
    emd_dist = emd.emdModule()
    cd = ChamferLoss().cuda()

    # Batch average metrics
    cd_distances = utils.network_utils.AverageMeter()
    emd_distances = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    net.eval()

    n_batches = len(test_data_loader)

    # Testing loop
    for sample_idx, (taxonomy_names, sample_names, rendering_images, model_azi,
                     model_ele, init_point_clouds,
                     ground_truth_point_clouds) in enumerate(test_data_loader):
        with torch.no_grad():
            # Only one image per sample
            rendering_images = torch.squeeze(rendering_images, 1)

            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)
            model_azi = utils.network_utils.var_or_cuda(model_azi)
            model_ele = utils.network_utils.var_or_cuda(model_ele)
            init_point_clouds = utils.network_utils.var_or_cuda(
                init_point_clouds)
            ground_truth_point_clouds = utils.network_utils.var_or_cuda(
                ground_truth_point_clouds)

            #=================================================#
            #           Test the encoder, decoder             #
            #=================================================#
            loss, pred_pc = net.module.valid_step(rendering_images,
                                                  init_point_clouds,
                                                  ground_truth_point_clouds)

            # Compute CD, EMD
            cd_distance = cd(pred_pc, ground_truth_point_clouds
                             ) / cfg.TEST.BATCH_SIZE / cfg.CONST.NUM_POINTS

            # Compute EMD
            emd_loss, _ = emd_dist(pred_pc,
                                   ground_truth_point_clouds,
                                   eps=0.005,
                                   iters=50)
            emd_distance = torch.sqrt(emd_loss).mean(1).mean()

            # Append loss and accuracy to average metrics
            cd_distances.update(cd_distance.item())
            emd_distances.update(emd_distance.item())

            print("Test on [%d/%d] data, CD: %.4f EMD %.4f" %
                  (sample_idx + 1, n_batches, cd_distance.item(),
                   emd_distance.item()))

    # Print results
    print("Reconstruction result:")
    print("CD result: ", cd_distances.avg)
    print("EMD result: ", emd_distances.avg)
    logname = cfg.TEST.RESULT_PATH
    with open(logname, 'a') as f:
        f.write('Reconstruction result: \n')
        f.write("CD result: %.8f \n" % cd_distances.avg)
        f.write("EMD result: %.8f \n" % emd_distances.avg)
Example #6
    def build_val_loss(self):
        self.chamfer_dist_mean = ChamferDistanceMean().cuda()
        self.emd_dist = emd.emdModule().cuda()
Example #7
    def build_val_loss(self):
        # Set up loss functions
        self.chamfer_dist = ChamferDistance().cuda()
        self.chamfer_dist_mean = ChamferDistanceMean().cuda()
        self.emd_dist = emd.emdModule().cuda()
Example #8
    def __init__(self,
                 cfg,
                 optimizer_G=None,
                 scheduler_G=None,
                 optimizer_D=None,
                 scheduler_D=None):

        super().__init__()
        self.cfg = cfg

        # Graphx Generator
        self.model_G = Graphx_Rec(
            cfg=cfg,
            in_channels=3,
            in_instances=cfg.GRAPHX.NUM_INIT_POINTS,
            activation=nn.ReLU(),
        )

        # Projection Discriminator
        self.model_D = ProjectionD(
            num_classes=cfg.DATASET.NUM_CLASSES,
            img_shape=(cfg.RENDER.N_VIEWS + 3, cfg.RENDER.IMG_SIZE,
                       cfg.RENDER.IMG_SIZE),
        )

        # Renderer
        self.renderer = ComputeDepthMaps(
            projection=cfg.RENDER.PROJECTION,
            eyepos_scale=cfg.RENDER.EYEPOS,
            image_size=cfg.RENDER.IMG_SIZE,
        ).float()

        # OptimizerG
        self.optimizer_G = None if optimizer_G is None else optimizer_G(
            self.model_G.parameters())
        self.scheduler_G = None if scheduler_G is None or optimizer_G is None else scheduler_G(
            self.optimizer_G)

        # OptimizerD
        self.optimizer_D = None if optimizer_D is None else optimizer_D(
            self.model_D.parameters())
        self.scheduler_D = None if scheduler_D is None or optimizer_D is None else scheduler_D(
            self.optimizer_D)

        # A dict that stores the per-step losses
        self.loss = {}

        # emd loss
        self.emd_dist = emd.emdModule()

        # GAN criterion
        self.criterionD = torch.nn.MSELoss()

        if torch.cuda.is_available():
            # Generator
            self.model_G = torch.nn.DataParallel(
                self.model_G, device_ids=cfg.CONST.DEVICE).cuda()
            # Discriminator
            self.model_D = torch.nn.DataParallel(
                self.model_D, device_ids=cfg.CONST.DEVICE).cuda()
            # Renderer
            self.renderer = torch.nn.DataParallel(
                self.renderer, device_ids=cfg.CONST.DEVICE).cuda()
            # loss
            self.emd_dist = torch.nn.DataParallel(
                self.emd_dist, device_ids=cfg.CONST.DEVICE).cuda()
            self.criterionD = torch.nn.DataParallel(
                self.criterionD, device_ids=cfg.CONST.DEVICE).cuda()
            self.cuda()
Example #9
def test_refine_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.ToTensor(),
    ])
    dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
    test_data_loader = torch.utils.data.DataLoader(dataset=dataset_loader.get_dataset(
                                                   utils.data_loaders.DatasetType.TEST,  test_transforms),
                                                   batch_size=cfg.TEST.BATCH_SIZE,
                                                   num_workers=1,
                                                   pin_memory=True,
                                                   shuffle=False)

    # Set up networks
    # The parameters here need to be set in cfg
    rec_net = Graphx_Rec(
        cfg=cfg,
        in_channels=3,
        in_instances=cfg.GRAPHX.NUM_INIT_POINTS,
        activation=nn.ReLU(),
    )
        
    # Refine network
    refine_net = GRAPHX_REFINE_MODEL(
        cfg=cfg,
        in_channels=3,
        optimizer=lambda x: torch.optim.Adam(x, lr=cfg.REFINE.LEARNING_RATE)
    )
    

    if torch.cuda.is_available():
        rec_net = torch.nn.DataParallel(rec_net, device_ids=cfg.CONST.DEVICE).cuda()
        refine_net = torch.nn.DataParallel(refine_net, device_ids=cfg.CONST.DEVICE).cuda()
    
    # Load pretrained generator weights
    print('[INFO] %s Recovering generator from %s ...' % (dt.now(), cfg.TEST.GENERATOR_WEIGHTS))
    rec_net_dict = rec_net.state_dict()
    pretrained_dict = torch.load(cfg.TEST.GENERATOR_WEIGHTS)
    pretrained_weight_dict = pretrained_dict['net']
    new_weight_dict = OrderedDict()
    for k, v in pretrained_weight_dict.items():
        if cfg.REFINE.GENERATOR_TYPE == 'REC':
            name = k[21:]  # strip the 'module.reconstructor.' prefix
        elif cfg.REFINE.GENERATOR_TYPE == 'GAN':
            name = k[15:]  # strip the 'module.model_G.' prefix
        else:
            continue  # unknown generator type: skip this key
        if name in rec_net_dict:
            new_weight_dict[name] = v
    
    rec_net_dict.update(new_weight_dict)
    rec_net.load_state_dict(rec_net_dict)

    # Load refiner weights
    print('[INFO] %s Recovering refiner from %s ...' % (dt.now(), cfg.TEST.REFINER_WEIGHTS))
    refine_checkpoint = torch.load(cfg.TEST.REFINER_WEIGHTS)
    refine_net.load_state_dict(refine_checkpoint['net'])
    print('[INFO] Best reconstruction result at epoch %d ...' % refine_checkpoint['epoch_idx'])
    epoch_id = int(refine_checkpoint['epoch_idx'])
    
    rec_net.eval()
    refine_net.eval()

    # Set up loss functions
    emd_dist = emd.emdModule()
    cd = ChamferLoss().cuda()
    
    # Batch average metrics
    cd_distances = utils.network_utils.AverageMeter()
    emd_distances = utils.network_utils.AverageMeter()

    n_batches = len(test_data_loader)

    # Testing loop
    for sample_idx, (taxonomy_names, sample_names, rendering_images, update_images,
                    model_azi, model_ele,
                    init_point_clouds, ground_truth_point_clouds) in enumerate(test_data_loader):
        with torch.no_grad():
            # Only one image per sample
            rendering_images = torch.squeeze(rendering_images, 1)
            update_images = torch.squeeze(update_images, 1)
            
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            update_images = utils.network_utils.var_or_cuda(update_images)
            model_azi = utils.network_utils.var_or_cuda(model_azi)
            model_ele = utils.network_utils.var_or_cuda(model_ele)
            init_point_clouds = utils.network_utils.var_or_cuda(init_point_clouds)
            ground_truth_point_clouds = utils.network_utils.var_or_cuda(ground_truth_point_clouds)
            
            #=================================================#
            #           Test the encoder, decoder             #
            #=================================================#
            # The reconstruction net produces a coarse point cloud
            coarse_pc, _ = rec_net(rendering_images, init_point_clouds)
            # The refinement net produces the refined result
            loss, pred_pc = refine_net.module.valid_step(update_images, coarse_pc, ground_truth_point_clouds, model_azi, model_ele)

            # Compute CD, EMD
            cd_distance = cd(pred_pc, ground_truth_point_clouds) / cfg.TEST.BATCH_SIZE / cfg.CONST.NUM_POINTS
            
            # Compute EMD
            emd_loss, _ = emd_dist(
                pred_pc, ground_truth_point_clouds, eps=0.005, iters=50
            )
            emd_distance = torch.sqrt(emd_loss).mean(1).mean()

            # Append loss and accuracy to average metrics
            cd_distances.update(cd_distance.item())
            emd_distances.update(emd_distance.item())

            print("Test on [%d/%d] data, CD: %.4f EMD %.4f" % (sample_idx + 1,  n_batches, cd_distance.item(), emd_distance.item()))
    
    # Print results
    print("Reconstruction result:")
    print("CD result: ", cd_distances.avg)
    print("EMD result: ", emd_distances.avg)
    logname = cfg.TEST.RESULT_PATH 
    with open(logname, 'a') as f:
        f.write('Reconstruction result: \n')
        f.write("CD result: %.8f \n" % cd_distances.avg)
        f.write("EMD result: %.8f \n" % emd_distances.avg)
Example #10
class Metrics(object):
    ITEMS = [
        {
            "name": "F-Score",
            "enabled": True,
            "eval_func": "cls._get_f_score",
            "is_greater_better": True,
            "init_value": 0,
        },
        {
            "name": "ChamferDistance",
            "enabled": True,
            "eval_func": "cls._get_chamfer_distance",
            "eval_object": ChamferDistanceMean(),
            "is_greater_better": False,
            "init_value": 32767,
        },
        {
            "name": "EMD",
            "enabled": True,
            "eval_func": "cls._get_emd",
            "eval_object": emd.emdModule(),
            "is_greater_better": False,
            "init_value": 32767,
        },
    ]

    @classmethod
    def get(cls, pred, gt):
        _items = cls.items()
        _values = [0] * len(_items)
        for i, item in enumerate(_items):
            eval_func = eval(item["eval_func"])
            _values[i] = eval_func(pred, gt)

        return _values

    @classmethod
    def items(cls):
        return [i for i in cls.ITEMS if i["enabled"]]

    @classmethod
    def names(cls):
        _items = cls.items()
        return [i["name"] for i in _items]

    @classmethod
    def _get_f_score(cls, pred, gt, th=0.01):
        """References: https://github.com/lmb-freiburg/what3d/blob/master/util.py"""
        pred = cls._get_open3d_ptcloud(pred)
        gt = cls._get_open3d_ptcloud(gt)

        dist1 = pred.compute_point_cloud_distance(gt)
        dist2 = gt.compute_point_cloud_distance(pred)

        recall = float(sum(d < th for d in dist2)) / float(len(dist2))
        precision = float(sum(d < th for d in dist1)) / float(len(dist1))
        return 2 * recall * precision / (
            recall + precision) if recall + precision else 0

    @classmethod
    def _get_open3d_ptcloud(cls, tensor):
        tensor = tensor.squeeze().cpu().numpy()
        ptcloud = open3d.geometry.PointCloud()
        ptcloud.points = open3d.utility.Vector3dVector(tensor)

        return ptcloud

    @classmethod
    def _get_chamfer_distance(cls, pred, gt):
        chamfer_distance = cls.ITEMS[1]["eval_object"]
        return chamfer_distance(pred, gt).item() * 1000

    @classmethod
    def _get_emd(cls, pred, gt):
        EMD = cls.ITEMS[2]["eval_object"]
        dist, _ = EMD(pred, gt, eps=0.005, iters=50)  # for val
        # dist, _ = EMD(pred, gt, 0.002, 10000) # final test ?
        emd = torch.sqrt(dist).mean(1).mean()
        return emd.item() * 100

    def __init__(self, metric_name, values):
        self._items = Metrics.items()
        self._values = [item["init_value"] for item in self._items]
        self.metric_name = metric_name

        if type(values).__name__ == "dict":
            metric_indexes = {}
            for idx, item in enumerate(self._items):
                item_name = item["name"]
                metric_indexes[item_name] = idx
            for k, v in values.items():
                if k not in metric_indexes:
                    logger.warning(
                        "Ignoring Metric[Name=%s] because it is disabled." % k)
                    continue
                self._values[metric_indexes[k]] = v
        elif type(values).__name__ == "list":
            self._values = values
        else:
            raise Exception("Unsupported value type: %s" % type(values))

    def state_dict(self):
        _dict = {}
        for i in range(len(self._items)):
            item = self._items[i]["name"]
            value = self._values[i]
            _dict[item] = value

        return _dict

    def __repr__(self):
        return str(self.state_dict())

    def better_than(self, other):
        if other is None:
            return True

        _index = -1
        for i, _item in enumerate(self._items):
            if _item["name"] == self.metric_name:
                _index = i
                break
        if _index == -1:
            raise Exception("Invalid metric name to compare.")

        _metric = self._items[_index]
        _value = self._values[_index]
        other_value = other._values[_index]
        return _value > other_value if _metric[
            "is_greater_better"] else _value < other_value