示例#1
0
    def project(self, x, y):
        """Project data.

        Rows of `y` may contain missing observations, marked by NaNs. Rows
        are grouped by their pattern of availability (which entries of the
        row are not NaN), every group is projected separately with
        `self._project_pattern`, and the per-group results are concatenated
        along the first axis. When nothing is missing, a single projection
        of all the data is performed instead.

        Note that, when data is missing, the results are concatenated group
        by group, so they follow the order of the patterns rather than the
        order of the rows of `x` and `y`.

        Args:
            x (matrix): Locations of data.
            y (matrix): Observations of data.

        Returns:
            tuple: Tuple containing the locations of the projection,
                the projection, weights associated with the projection, and
                a regularisation term.
        """
        n = B.shape(x)[0]
        available = ~B.isnan(B.to_numpy(y))

        # Optimise the case where all data is available.
        if B.all(available):
            return self._project_pattern(x, y, (True,) * self.p)

        # Extract the distinct patterns of availability.
        patterns = list(set(map(tuple, available)))

        if len(patterns) > 30:
            warnings.warn(
                f"Detected {len(patterns)} patterns, which is more "
                f"than 30 and can be slow.",
                category=UserWarning,
            )

        # Per pattern, find data points that belong to it. A dictionary maps
        # every pattern to its position in `patterns`, which avoids a linear
        # scan through `patterns` for every single data point.
        pattern_position = {
            pattern: position for position, pattern in enumerate(patterns)
        }
        patterns_inds = [[] for _ in patterns]
        for i in range(n):
            patterns_inds[pattern_position[tuple(available[i])]].append(i)

        # Per pattern, perform the projection.
        proj_xs = []
        proj_ys = []
        proj_ws = []
        total_reg = 0

        for pattern, pattern_inds in zip(patterns, patterns_inds):
            proj_x, proj_y, proj_w, reg = self._project_pattern(
                B.take(x, pattern_inds), B.take(y, pattern_inds), pattern
            )

            proj_xs.append(proj_x)
            proj_ys.append(proj_y)
            proj_ws.append(proj_w)
            # The regularisation terms of the groups are summed.
            total_reg = total_reg + reg

        return (
            B.concat(*proj_xs, axis=0),
            B.concat(*proj_ys, axis=0),
            B.concat(*proj_ws, axis=0),
            total_reg,
        )
示例#2
0
def _per_output(x, y, w):
    """Iterate over the columns of `y`, dropping missing observations.

    Args:
        x: Inputs, indexed along the first axis by a boolean mask.
        y: Observations with one column per output; NaNs mark missing values.
        w: Weights, indexed column by column in the same way as `y`.

    Yields:
        tuple: For every column of `y`, the entries of `x`, of that column of
            `y`, and of the matching column of `w` at which the observation
            is not NaN.
    """
    num_outputs = B.shape(y)[1]

    for col in range(num_outputs):
        observations = y[:, col]
        weights = w[:, col]

        # Mask selecting the observations that are actually present.
        present = ~B.isnan(observations)

        yield x[present], observations[present], weights[present]
示例#3
0
文件: isnan.py 项目: wesselb/matrix
def isnan(a: AbstractMatrix):
    """Apply `B.isnan` to the dense form of a matrix.

    If `structured(a)` holds, a `ToDenseWarning` is issued first, because the
    matrix is converted to dense before the check.

    Args:
        a (AbstractMatrix): Matrix to check.

    Returns:
        The result of `B.isnan` applied to `B.dense(a)`.
    """
    if structured(a):
        message = f'Applying "isnan" to {a}: converting to dense.'
        warn_upmodule(message, category=ToDenseWarning)
    dense_a = B.dense(a)
    return B.isnan(dense_a)
示例#4
0
def check_function(
    f,
    args_spec,
    kw_args_spec=None,
    assert_dtype=True,
    skip=None,
    contains_nans=None,
):
    """Check that a function produces consistent output. Moreover, if the first
    argument is a data type, check that the result is exactly of that type.

    The function is called with every combination of the forms of the
    arguments and of the keyword arguments (plus once without any keyword
    arguments), and every result is compared against the result of the first
    combination of arguments.

    Args:
        f (function): Function to check.
        args_spec (iterable): Specifications of the positional arguments.
            Every element must provide a method `forms()`.
        kw_args_spec (dict, optional): Specifications of the keyword
            arguments, by name. Every value must provide a method `forms()`.
        assert_dtype (bool, optional): Passed on to `approx` when comparing
            results. Defaults to `True`.
        skip (list, optional): Types to skip. Calls in which any argument
            matches one of these types (or the `plum.List`/`plum.Tuple`
            wrapping of it) are not performed.
        contains_nans (bool, optional): If given, assert that whether the
            result contains NaNs is equal to this value.
    """
    if skip is None:
        skip = []
    kw_args_spec = {} if kw_args_spec is None else kw_args_spec

    def _with_containers(t):
        # Match the type itself as well as its list and tuple wrappings.
        return plum.Union(t, plum.List(t), plum.Tuple(t))

    def _count_matched(args, types):
        # Number of types in `types` that at least one argument matches.
        return sum(any(isinstance(arg, t) for arg in args) for t in types)

    # All combinations of keyword arguments, followed by the default call.
    kw_forms = [[(k, v) for v in vs.forms()] for k, vs in kw_args_spec.items()]
    kw_args_prod = [dict(combination) for combination in product(*kw_forms)]
    kw_args_prod.append({})

    # All combinations of positional arguments.
    args_prod = list(product(*[arg.forms() for arg in args_spec]))

    # Framework types: calls that mix two or more of these are skipped.
    fw_types = [
        _with_containers(t)
        for t in [B.AGNumeric, B.TorchNumeric, B.TFNumeric, B.JAXNumeric]
    ]

    # Other types: calls involving any of these are skipped entirely.
    skip_types = [_with_containers(t) for t in skip]

    reference_args = args_prod[0]

    for kw_args in kw_args_prod:
        # The first combination of arguments serves as the reference.
        first_result = f(*reference_args, **kw_args)

        # A data type as first argument must be the data type of the result.
        if isinstance(reference_args[0], B.DType):
            assert B.dtype(first_result) is reference_args[0]

        for args in args_prod:
            fw_count = _count_matched(args, fw_types)
            skip_count = _count_matched(args, skip_types)

            if fw_count >= 2 or skip_count >= 1:
                log.debug(
                    f"Skipping call with arguments {args} and keyword "
                    f"arguments {kw_args}."
                )
                continue

            # Compare this call against the reference.
            log.debug(
                f"Call with arguments {args} and keyword arguments {kw_args}."
            )
            result = f(*args, **kw_args)
            approx(first_result, result, assert_dtype=assert_dtype)

            # Again check the data type, now for this particular call.
            if isinstance(args[0], B.DType):
                assert B.dtype(result) is args[0]

            # Check the presence of NaNs, if requested.
            if contains_nans is not None:
                assert B.any(B.isnan(result)) == contains_nans
示例#5
0
 def f(x):
     """Sum the entries of `x` that are not NaN."""
     not_nan = ~B.isnan(x)
     # Convert the mask with `B.jit_to_numpy` before using it as an index.
     mask = B.jit_to_numpy(not_nan)
     return B.sum(x[mask])