def element_center_lowres_grid_direct_loss(model_config, training_example,
                                           structured_implicit):
  """Penalizes the mean ground-truth SDF sampled at the element centers.

  The ground-truth grid is interpolated at every element center. The plain,
  signed mean of those samples is scaled by `model_config.hparams.gd`, and a
  constant 1e-5 is added. Because the mean is signed (neither squared nor
  clamped), minimizing it pushes centers toward lower SDF values. Presumably
  those are inside the shape if the grid uses a negative-inside convention;
  TODO confirm.

  Args:
    model_config: Config object. `hparams.gd` is the loss weight, and the
      object is also forwarded to `summarize.summarize_loss`.
    training_example: Supplies `grid` and `world2grid` for interpolation.
    structured_implicit: Supplies `element_centers`.

  Returns:
    The weighted scalar loss tensor.
  """
  centers = structured_implicit.element_centers
  sampled_sdf, _ = interpolate_util.interpolate(
      training_example.grid, centers, training_example.world2grid)
  weight = model_config.hparams.gd
  loss = weight * tf.reduce_mean(sampled_sdf) + 1e-5
  summarize.summarize_loss(model_config, loss, 'lowres_grid_direct_loss')
  return loss
def element_center_lowres_grid_squared_loss(model_config, training_example,
                                            structured_implicit):
  """Signed squared penalty on the ground-truth SDF at the element centers.

  The ground-truth grid is interpolated at every element center. Each sample
  s then contributes sign(s) * (s + 1e-4)**2. The mean of these terms is
  scaled by `model_config.hparams.gs`, and a constant 1e-5 is added.

  Args:
    model_config: Config object. `hparams.gs` is the loss weight, and the
      object is also forwarded to `summarize.summarize_loss`.
    training_example: Supplies `grid` and `world2grid` for interpolation.
    structured_implicit: Supplies `element_centers`.

  Returns:
    The weighted scalar loss tensor.
  """
  centers = structured_implicit.element_centers
  sdf_at_centers, _ = interpolate_util.interpolate(
      training_example.grid, centers, training_example.world2grid)
  # The sign comes from the raw sample, but the square uses the sample shifted
  # by 1e-4. Positive samples raise the loss, negative ones lower it, and
  # exact zeros contribute nothing (sign(0) == 0).
  signed_square = tf.sign(sdf_at_centers) * tf.square(sdf_at_centers + 1e-04)
  loss = model_config.hparams.gs * tf.reduce_mean(signed_square) + 1e-5
  # NOTE(review): the summary tag says "magnitude" while the function is named
  # "squared". It is left unchanged in case logged dashboards depend on it.
  summarize.summarize_loss(model_config, loss, 'lowres_grid_magnitude_loss')
  return loss
def element_center_lowres_grid_inside_loss(model_config, training_example,
                                           structured_implicit):
  """Loss that element centers should lie within a voxel of the GT inside.

  The ground-truth grid is interpolated at every element center. Samples at
  or below `model_config.hparams.igt` are replaced by 0. Samples above that
  threshold keep their value, so the penalty has a hard cutoff at the
  threshold. The mean of the squared (shifted) values is scaled by
  `model_config.hparams.ig`, and a constant 1e-5 is added.

  Args:
    model_config: Config object. `hparams.igt` is the cutoff and `hparams.ig`
      is the loss weight. The object is also forwarded to
      `summarize.summarize_loss`.
    training_example: Supplies `grid` and `world2grid` for interpolation.
    structured_implicit: Supplies `element_centers`.

  Returns:
    The weighted scalar loss tensor.
  """
  centers = structured_implicit.element_centers
  sdf_at_centers, _ = interpolate_util.interpolate(
      training_example.grid, centers, training_example.world2grid)
  cutoff = model_config.hparams.igt
  penalized = tf.where_v2(sdf_at_centers > cutoff, sdf_at_centers, 0.0)
  # The 1e-4 shift is applied after zeroing, so even zeroed samples add
  # (1e-4)**2 to the mean.
  squared = tf.square(penalized + 1e-04)
  loss = model_config.hparams.ig * tf.reduce_mean(squared) + 1e-05
  summarize.summarize_loss(model_config, loss, 'lowres_grid_inside_loss')
  return loss
def smooth_element_center_lowres_grid_inside_loss(model_config,
                                                  training_example,
                                                  structured_implicit):
  """Offset version of element_center_lowres_grid_inside_loss by voxel width.

  The ground-truth grid is interpolated at every element center. Instead of
  a hard cutoff, each sample is replaced by max(s - igt, 0). Only the amount
  by which a sample exceeds `model_config.hparams.igt` is penalized, so the
  term is continuous at the threshold. The mean of the squared (shifted)
  values is scaled by `model_config.hparams.ig`, and a constant 1e-5 is
  added. `igt` presumably holds a voxel width, per the summary line;
  TODO confirm.

  Args:
    model_config: Config object. `hparams.igt` is the offset and `hparams.ig`
      is the loss weight. The object is also forwarded to
      `summarize.summarize_loss`.
    training_example: Supplies `grid` and `world2grid` for interpolation.
    structured_implicit: Supplies `element_centers`.

  Returns:
    The weighted scalar loss tensor.
  """
  centers = structured_implicit.element_centers
  sdf_at_centers, _ = interpolate_util.interpolate(
      training_example.grid, centers, training_example.world2grid)
  excess = tf.maximum(sdf_at_centers - model_config.hparams.igt, 0.0)
  squared = tf.square(excess + 1e-04)
  loss = model_config.hparams.ig * tf.reduce_mean(squared) + 1e-05
  # NOTE(review): this reuses the summary tag of
  # element_center_lowres_grid_inside_loss, so enabling both would log two
  # losses under one name.
  summarize.summarize_loss(model_config, loss, 'lowres_grid_inside_loss')
  return loss