Example #1
0
  Returns:
    A tensor of shape `[B, H, W, 1]` containing the mattes.

  Raises:
    ValueError: If `image`, `coeff_mul`, or `coeff_add` are not of rank 4. If
    the last dimension of `coeff_add` is not 1. If the batch dimensions of
    `image`, `coeff_mul`, and `coeff_add` do not match.
  """
    with tf.compat.v1.name_scope(name, "matting_reconstruct",
                                 [image, coeff_mul, coeff_add]):
        image = tf.convert_to_tensor(value=image)
        coeff_mul = tf.convert_to_tensor(value=coeff_mul)
        coeff_add = tf.convert_to_tensor(value=coeff_add)

        shape.check_static(image, has_rank=4)
        shape.check_static(coeff_mul, has_rank=4)
        shape.check_static(coeff_add, has_rank=4, has_dim_equals=(-1, 1))
        shape.compare_batch_dimensions(tensors=(image, coeff_mul),
                                       last_axes=-1,
                                       broadcast_compatible=False)
        shape.compare_batch_dimensions(tensors=(image, coeff_add),
                                       last_axes=-2,
                                       broadcast_compatible=False)

        return tfg_vector.dot(coeff_mul, image) + coeff_add


# API contains all public functions and classes.
# `__all__` controls what `from <module> import *` exports; here the list is
# produced by export_api instead of being maintained by hand.
# NOTE(review): presumably get_functions_and_classes() introspects the calling
# module, so this assignment should stay at the bottom of the file, after every
# definition it is meant to pick up — confirm in export_api.
__all__ = export_api.get_functions_and_classes()
Example #2
0
 def test_get_functions_and_classes(self):
     """Checks that export_api.get_functions_and_classes() does not raise.

     Any exception escaping the call is reported as a test failure whose
     message embeds the exception's string form.
     """
     # Deliberately broad: the only property under test is "no exception at all".
     try:
         export_api.get_functions_and_classes()
     except Exception as caught:  # pylint: disable=broad-except
         # Fail from inside the handler so the original exception stays
         # attached as context in the reported traceback.
         failure_message = "Exception raised: %s" % str(caught)
         self.fail(failure_message)