Example #1
0
def get_emd_loss(pred, gt):
    """Total squared distance between predicted points and their EMD matches.

    Args:
        pred: predicted point cloud, assumed (B, N, C).
        gt: ground-truth point cloud matched against ``pred``; presumably the
            same (B, N, C) layout -- TODO confirm against callers.

    Returns:
        Scalar tensor: the sum over batch, points and coordinates of
        ``(pred - matched_gt) ** 2``, where ``matched_gt`` is ``gt`` gathered
        by the indices returned from ``tf_auctionmatch.auction_match``.
    """
    # matchl_out selects, for each pred point, the gt point it is compared to.
    matchl_out, _ = tf_auctionmatch.auction_match(pred, gt)
    matched_out = tf_sampling.gather_point(gt, matchl_out)
    # A full reduce_sum already totals every axis, so no per-example reshape
    # is needed before summing.
    return tf.reduce_sum((pred - matched_out) ** 2)
Example #2
0
def get_emd_loss(pred, gt, radius):
    """Radius-normalised mean EMD-style loss, plus the matching indices.

    Args:
        pred: predicted point cloud, assumed (B, N, C).
        gt: ground-truth point cloud matched against ``pred``; presumably the
            same layout -- TODO confirm against callers.
        radius: normaliser divided into the per-example (B, 1) mean squared
            error; it broadcasts against that shape.

    Returns:
        ``(emd_loss, matchl_out)``: ``emd_loss`` is a scalar tensor and
        ``matchl_out`` holds, for each pred point, the index of the gt point
        it was matched to.
    """
    # Runtime batch size: the static ``get_shape()[0].value`` is None when the
    # batch dimension is unknown at graph-construction time, which would make
    # the (batch_size, -1) reshape below fail.
    batch_size = tf.shape(pred)[0]
    matchl_out, _ = tf_auctionmatch.auction_match(pred, gt)
    matched_out = tf_sampling.gather_point(gt, matchl_out)
    dist = tf.reshape((pred - matched_out) ** 2, shape=(batch_size, -1))
    dist = tf.reduce_mean(dist, axis=1, keepdims=True)  # (B, 1)
    # NOTE(review): if radius is a (B,) vector, (B, 1) / (B,) broadcasts to
    # (B, B) rather than pairing each example with its own radius -- confirm
    # the intended shape of radius.
    dist_norm = dist / radius

    emd_loss = tf.reduce_mean(dist_norm)
    return emd_loss, matchl_out
Example #3
0
def get_emd_completion_loss(pred, gt, radius=1):
    """Mean EMD-style loss where ``gt`` carries an extra leading grouping axis.

    ``gt`` is indexed up to axis 3, so it is at least rank 4. Both tensors are
    flattened to ``(-1, gt.shape[2], gt.shape[3])`` so that each group is
    matched independently; presumably (B, K, N, C) -> (B*K, N, C) -- TODO
    confirm against callers.

    Args:
        pred: predicted points, reshapeable to the same flattened layout.
        gt: ground-truth points (see above).
        radius: accepted for interface compatibility; currently unused.

    Returns:
        Scalar tensor: the mean over matched sets of the per-set mean of
        ``(pred - matched_gt) ** 2``.
    """
    num_points = gt.get_shape()[2].value
    num_dims = gt.get_shape()[3].value
    pred = tf.reshape(pred, [-1, num_points, num_dims])
    gt = tf.reshape(gt, [-1, num_points, num_dims])
    # Runtime batch size: after the -1 reshape the static leading dimension is
    # None whenever gt's leading dims are not fully known, which would break
    # the (batch_size, -1) reshape below.
    batch_size = tf.shape(gt)[0]
    matchl_out, _ = tf_auctionmatch.auction_match(pred, gt)
    matched_out = gather_point(gt, matchl_out)
    dist = tf.reshape((pred - matched_out) ** 2, shape=(batch_size, -1))
    dist = tf.reduce_mean(dist, axis=1, keepdims=True)
    emd_loss = tf.reduce_mean(dist)
    return emd_loss
def get_rec_metrics(gt_pcl, pred_pcl, batch_size=None, num_points=None):
    """Per-example Chamfer and EMD reconstruction metrics.

    Args:
        gt_pcl: ground-truth points, indexed as (batch, point, coord) by the
            ``gather_nd`` below.
        pred_pcl: predicted points in the same layout as ``gt_pcl``.
        batch_size: rows of the index grid used to gather matched gt points.
            ``None`` (default) reads it from ``pred_pcl``'s shape: static when
            known, otherwise at run time.
        num_points: columns of that index grid; ``None`` behaves as above.

    Returns:
        ``(dists_forward, dists_backward, chamfer_distance, emd)``, each of
        shape (B,). ``dists_*`` are per-example means of ``sqrt`` of the two
        ``nn_distance`` outputs (presumably squared nearest-neighbour
        distances -- confirm). ``chamfer_distance`` is their sum. ``emd`` is the
        mean Euclidean distance from each pred point to its auction-matched gt
        point.
    """
    dists_forward, _, dists_backward, _ = tf_nndistance.nn_distance(
        gt_pcl, pred_pcl)
    dists_forward = tf.reduce_mean(tf.sqrt(dists_forward), axis=1)  # (B, )
    dists_backward = tf.reduce_mean(tf.sqrt(dists_backward), axis=1)  # (B, )
    chamfer_distance = dists_backward + dists_forward

    def _dim(axis):
        # Prefer the static size so downstream tensors keep a known shape;
        # fall back to the runtime size when it is not known statically.
        static = pred_pcl.get_shape()[axis].value
        return static if static is not None else tf.shape(pred_pcl)[axis]

    if batch_size is None:
        batch_size = _dim(0)
    if num_points is None:
        num_points = _dim(1)

    # batch_idx[b, i] == b; paired with the matching index it addresses gt_pcl.
    batch_idx, _ = tf.meshgrid(tf.range(batch_size),
                               tf.range(num_points),
                               indexing='ij')
    # match_idx[b, i] is the gt point assigned to pred point (b, i).
    match_idx, _ = auction_match(pred_pcl, gt_pcl)
    gather_idx = tf.stack((batch_idx, match_idx), -1)
    matched_gt = tf.gather_nd(gt_pcl, gather_idx)
    # (B, N, 3) --> (B, N) Euclidean distances --> (B,) mean per example.
    emd = tf.reduce_mean(
        tf.sqrt(tf.reduce_sum((matched_gt - pred_pcl) ** 2, axis=-1)),
        axis=1)

    return dists_forward, dists_backward, chamfer_distance, emd
Example #5
0
def calculate_emd_error(pred, gt, batch_size=30):
    """Print and plot per-sample EMD and Chamfer errors between two point sets.

    Args:
        pred: indexable collection of predicted clouds; each batch is fed to a
            (None, npoint, 3) float32 placeholder.
        gt: ground truth in the same layout; ``gt.shape[1]`` fixes npoint.
        batch_size: number of samples evaluated per ``sess.run`` call.

    Side effects:
        Prints the average EMD and CD, then shows two histograms (EMD and CD,
        each clipped at 0.2) with matplotlib. Returns None.
    """
    npoint = gt.shape[1]

    # Build the evaluation ops in a private graph so repeated calls do not
    # keep adding placeholders and match ops to the process-wide default graph.
    graph = tf.Graph()
    with graph.as_default():
        pred_pl = tf.placeholder(tf.float32, shape=(None, npoint, 3))
        gt_pl = tf.placeholder(tf.float32, shape=(None, npoint, 3))

        matchl_out, _ = tf_auctionmatch.auction_match(pred_pl, gt_pl)
        matched_out = tf_sampling.gather_point(gt_pl, matchl_out)
        # Mean Euclidean distance between each pred point and its matched gt point.
        EMD_dist = tf.sqrt(tf.reduce_sum((pred_pl - matched_out) ** 2, axis=2))
        EMD_dist = tf.reduce_mean(EMD_dist, axis=1)

        # Chamfer term: no sqrt here, unlike EMD_dist above.
        dists_forward, _, dists_backward, _ = tf_nndistance.nn_distance(gt_pl, pred_pl)
        CD_dist = tf.reduce_mean(dists_forward + dists_backward, axis=1)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False

    num_samples = len(pred)
    EMD_error = np.zeros(num_samples)
    CD_error = np.zeros(num_samples)
    with tf.Session(graph=graph, config=config) as sess:
        for start_idx in range(0, num_samples, batch_size):
            end_idx = min(start_idx + batch_size, num_samples)
            batch_EMD, batch_CD = sess.run(
                [EMD_dist, CD_dist],
                feed_dict={pred_pl: pred[start_idx:end_idx],
                           gt_pl: gt[start_idx:end_idx]})
            EMD_error[start_idx:end_idx] = batch_EMD
            CD_error[start_idx:end_idx] = batch_CD

    # Parenthesised single-argument print works on both Python 2 and 3.
    print("Average EMD distance %s; average CD distance %s" % (EMD_error.mean(), CD_error.mean()))

    # Clip outliers so the histograms stay readable.
    EMD_error[EMD_error > 0.2] = 0.2
    CD_error[CD_error > 0.2] = 0.2

    fig, axes = plt.subplots(2)
    axes[0].hist(EMD_error, 20)
    axes[1].hist(CD_error, 20)
    plt.show()
Example #6
0
    def loss(self, gt, pred):
        """Reconstruction loss between ``pred`` and ``gt``.

        When ``self.emd`` is truthy, returns a (batch, 1) tensor: the
        per-example mean of ``(pred - matched_gt) ** 2``, where ``matched_gt``
        is ``gt`` gathered by the ``auction_match`` indices. Otherwise returns
        the scalar ``K.mean(p1top2) + K.mean(p2top1)`` built from
        ``tf_nndistance.nn_distance(pred, gt)``.

        NOTE(review): the two branches return differently shaped tensors
        ((batch, 1) vs scalar); confirm callers expect that.
        """
        from tf_ops.emd import tf_auctionmatch
        from tf_ops.sampling import tf_sampling
        from structural_losses import tf_nndistance

        if self.emd:
            matchl_out, _ = tf_auctionmatch.auction_match(pred, gt)
            matched_out = tf_sampling.gather_point(gt, matchl_out)
            # Group rows by pred's own runtime batch size. self.batch_size can
            # differ from it (e.g. a smaller final batch). If the element count
            # still divides evenly, the reshape would then silently mix points
            # from different examples.
            batch_size = tf.shape(pred)[0]
            emd_loss = tf.reshape((pred - matched_out) ** 2,
                                  shape=(batch_size, -1))
            return tf.reduce_mean(emd_loss, axis=1, keepdims=True)

        # NOTE(review): the per-point meaning of each nn_distance output
        # (which cloud it is indexed over) is defined by the op -- confirm.
        p1top2, _, p2top1, _ = tf_nndistance.nn_distance(pred, gt)
        cd_loss = K.mean(p1top2) + K.mean(p2top1)
        return cd_loss