def create_loss(self, coarse, fine, gt, alpha):
    """Build the two-stage loss: chamfer(coarse, gt) + alpha * chamfer(fine, gt).

    Each of the coarse, fine and total terms gets a training summary and a
    validation summary. The values returned by ``add_valid_summary`` are
    collected and returned to the caller (presumably update ops for streaming
    validation metrics — confirm against the helper).

    Args:
        coarse: coarse-stage prediction, compared against ``gt``.
        fine: fine-stage prediction, compared against ``gt``.
        gt: ground-truth cloud.
        alpha: weight of the fine-stage term in the total.

    Returns:
        (loss, [update_coarse, update_fine, update_loss])
    """
    def _log_both(train_tag, valid_tag, value):
        # The return value of add_train_summary is not used; only the
        # validation handle is handed back.
        add_train_summary(train_tag, value)
        return add_valid_summary(valid_tag, value)

    coarse_cd = chamfer(coarse, gt)
    coarse_update = _log_both('train/coarse_loss', 'valid/coarse_loss', coarse_cd)

    fine_cd = chamfer(fine, gt)
    fine_update = _log_both('train/fine_loss', 'valid/fine_loss', fine_cd)

    total = coarse_cd + alpha * fine_cd
    total_update = _log_both('train/loss', 'valid/loss', total)
    return total, [coarse_update, fine_update, total_update]
def create_loss(self, coarse, fine, gt, alpha):
    """Build the two-stage loss: EMD for the coarse stage, Chamfer for the fine stage.

    The coarse prediction is compared with ``earth_mover`` against the first
    ``coarse.shape[1]`` points of ``gt``, so both clouds have the same size.
    The fine prediction is compared with ``chamfer`` against the full ``gt``.
    Total = coarse term + alpha * fine term. Every term is written to a
    training summary and a validation summary.

    Args:
        coarse: coarse-stage prediction; its axis 1 length picks the gt prefix size.
        fine: fine-stage prediction.
        gt: ground-truth cloud.
        alpha: weight of the fine-stage term.

    Returns:
        (loss, [update_coarse, update_fine, update_loss]) where the updates
        are the values returned by ``add_valid_summary``.
    """
    # NOTE(review): taking a prefix is only a fair subsample if gt's points
    # are stored in random order — confirm with the data pipeline.
    n_coarse = coarse.shape[1]
    gt_prefix = gt[:, :n_coarse, :]

    coarse_emd = earth_mover(coarse, gt_prefix)
    add_train_summary('train/coarse_loss', coarse_emd)
    coarse_update = add_valid_summary('valid/coarse_loss', coarse_emd)

    fine_cd = chamfer(fine, gt)
    add_train_summary('train/fine_loss', fine_cd)
    fine_update = add_valid_summary('valid/fine_loss', fine_cd)

    total = coarse_emd + alpha * fine_cd
    add_train_summary('train/loss', total)
    total_update = add_valid_summary('valid/loss', total)

    updates = [coarse_update, fine_update, total_update]
    return total, updates
def create_loss(self, coarse_highres, coarse, fine, gt, theta):
    """Build the multi-term completion loss with a repulsion regularizer.

    total = 0.5 * chamfer(coarse_highres, gt)
          + chamfer(coarse, gt)
          + theta * chamfer(fine, gt)
          + 0.2 * get_repulsion_loss4(coarse)

    Only the coarse, fine and total terms are written to summaries. The
    high-res and repulsion terms contribute to the total but are not logged.

    Args:
        coarse_highres: auxiliary high-resolution coarse prediction.
        coarse: coarse-stage prediction; also fed to the repulsion term.
        fine: fine-stage prediction.
        gt: ground-truth cloud.
        theta: weight of the fine-stage Chamfer term.

    Returns:
        (loss, loss_fine, [update_coarse, update_fine, update_loss]) where
        the updates are the values returned by ``add_valid_summary``.
    """
    highres_term = chamfer(coarse_highres, gt)

    coarse_term = chamfer(coarse, gt)
    add_train_summary('train/coarse_loss', coarse_term)
    coarse_update = add_valid_summary('valid/coarse_loss', coarse_term)

    fine_term = chamfer(fine, gt)
    add_train_summary('train/fine_loss', fine_term)
    fine_update = add_valid_summary('valid/fine_loss', fine_term)

    # The repulsion regularizer is applied to the coarse output only.
    repulsion_term = get_repulsion_loss4(coarse)

    total = 0.5 * highres_term + coarse_term + theta * fine_term + 0.2 * repulsion_term
    add_train_summary('train/loss', total)
    total_update = add_valid_summary('valid/loss', total)

    return total, fine_term, [coarse_update, fine_update, total_update]
def create_loss(self, coarse, fine, gt, alpha):
    """Build the joint geometry + semantic loss for a coarse/fine completion.

    Geometry terms use only the xyz channels (0:3):
      * coarse: 10 * earth_mover(coarse, gt prefix of size coarse.shape[1])
      * fine:   10 * chamfer(fine, gt)
    Semantic terms: each predicted point is matched to a ground-truth point
    via ``tf_nndistance.nn_distance``. The matched point's label is one-hot
    encoded. It is compared with the softmax of the prediction's channels 3:
    using a weighted per-class binary cross-entropy (0.9 on the positive
    term, 1 - 0.9 on the negative term).

    The per-point loss is averaged over points, summed over the batch, and
    added to the matching geometry term.

    Args:
        coarse: coarse prediction; assumed (B, N_coarse, 3 + 12) — TODO confirm.
        fine: fine prediction; assumed (B, N_fine, 3 + 12) — TODO confirm.
        gt: ground truth; xyz in channels 0:3 and an encoded label in channel 3.
        alpha: weight of the fine-stage loss in the total.

    Returns:
        (loss, [update_coarse, update_fine, update_loss]) where the updates
        are the values returned by ``add_valid_summary``.
    """
    num_classes = 12
    pos_weight = 0.9

    def _semantic_loss(pred, target, nn_idx):
        """Return the weighted per-class BCE between ``pred`` and matched target labels.

        ``nn_idx[b, n]`` is used as the index into ``target[b]`` of the point
        whose label supervises ``pred[b, n]``. Returns the loss averaged over
        points and summed over the batch.
        """
        batch_size = tf.shape(nn_idx)[0]
        num_points = tf.shape(nn_idx)[1]
        # Pair every neighbour index with its batch index, so one gather_nd
        # looks up all samples at once. This replaces a Python loop that
        # needed a static batch size; gather_idx has shape (B, N, 2).
        batch_idx = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, num_points])
        gather_idx = tf.stack([batch_idx, nn_idx], axis=-1)
        # NOTE(review): channel 3 appears to store the class id scaled by
        # 1 / (80 * 12). The product is kept as `* 80 * 12` (not `* 960`) so
        # float rounding before int truncation matches the original exactly.
        labels = tf.cast(target[:, :, 3] * 80 * 12, tf.int32)
        sem_gt = tf.one_hot(tf.gather_nd(labels, gather_idx), num_classes, dtype=tf.float32)
        sem_pred = tf.nn.softmax(pred[:, :, 3:], -1)
        # Sum over the class axis; the 1e-6 keeps log() away from 0.
        per_point = -tf.reduce_sum(
            pos_weight * sem_gt * tf.log(1e-6 + sem_pred)
            + (1 - pos_weight) * (1 - sem_gt) * tf.log(1e-6 + 1 - sem_pred),
            axis=2)
        # Mean over points per sample, then SUM over the batch. This matches
        # the former per-sample `+=` accumulation.
        # NOTE(review): the semantic term therefore grows with batch size,
        # unlike the geometry terms — confirm this is intended.
        return tf.reduce_sum(tf.reduce_mean(per_point, axis=1))

    # NOTE(review): a prefix is only a fair subsample if gt points are in random order.
    gt_ds = gt[:, :coarse.shape[1], :]
    loss_coarse = 10 * earth_mover(coarse[:, :, 0:3], gt_ds[:, :, 0:3])
    _, coarse_nn_idx, _, _ = tf_nndistance.nn_distance(coarse[:, :, 0:3], gt_ds[:, :, 0:3])
    loss_coarse += _semantic_loss(coarse, gt_ds, coarse_nn_idx)
    add_train_summary('train/coarse_loss', loss_coarse)
    update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

    loss_fine = 10 * chamfer(fine[:, :, 0:3], gt[:, :, 0:3])
    _, fine_nn_idx, _, _ = tf_nndistance.nn_distance(fine[:, :, 0:3], gt[:, :, 0:3])
    loss_fine += _semantic_loss(fine, gt, fine_nn_idx)
    add_train_summary('train/fine_loss', loss_fine)
    update_fine = add_valid_summary('valid/fine_loss', loss_fine)

    loss = loss_coarse + alpha * loss_fine
    add_train_summary('train/loss', loss)
    update_loss = add_valid_summary('valid/loss', loss)
    return loss, [update_coarse, update_fine, update_loss]
def compute_loss(self, inputs, gt_pose, est_pose):
    """Return the pose-regression loss from equation (1) of IT-Net.

    ``inputs`` is transformed twice: once with ``est_pose``, and once with
    the inverse of ``gt_pose``. The squared difference between the two copies
    is summed over axis 2, averaged over axis 1, then averaged over axis 0.

    Args:
        inputs: point clouds; assumed (B, N, 3) — TODO confirm.
        gt_pose: ground-truth pose, body -> world coords. It is inverted
            before being applied to ``inputs``.
        est_pose: estimated pose, world -> body coords.

    Returns:
        (loss, update_loss): the scalar loss and the value returned by
        ``add_valid_summary``.
    """
    # Per the pose conventions above, both copies should end up in the same
    # frame, so they can be compared point by point.
    pred_points = transform_tf(inputs, est_pose)
    gt_inverse = tf.linalg.inv(gt_pose)
    true_points = transform_tf(inputs, gt_inverse)

    residual = pred_points - true_points
    sq_dist = tf.reduce_sum(tf.square(residual), axis=2)
    per_sample = tf.reduce_mean(sq_dist, axis=1)
    loss = tf.reduce_mean(per_sample, axis=0)

    add_train_summary('train/loss', loss)
    return loss, add_valid_summary('valid/loss', loss)
def create_loss(self, outputs, gt):
    """Build a single-term loss: chamfer(outputs, gt).

    The loss is written to a training summary and a validation summary.

    Args:
        outputs: predicted cloud.
        gt: ground-truth cloud.

    Returns:
        (loss, update_loss): the loss and the value returned by
        ``add_valid_summary``.
    """
    cd = chamfer(outputs, gt)
    add_train_summary('train/loss', cd)
    valid_update = add_valid_summary('valid/loss', cd)
    return cd, valid_update