Beispiel #1
0
 def test_dtype_failure(self):
     """Check that float-typed ``indices`` make ``take_along_axis`` raise ``IndexError``."""
     # Both inputs share the same rank-3, non-broadcastable FLOATX tensor type.
     float_rank3 = tt.TensorType(FLOATX, [False] * 3)
     arr = float_rank3("arr")
     indices = float_rank3("indices")
     # Attach small all-zero test values so graph construction can be checked eagerly.
     for var in (arr, indices):
         var.tag.test_value = np.zeros((1,) * var.ndim, dtype=FLOATX)
     with pytest.raises(IndexError):
         take_along_axis(arr, indices)
Beispiel #2
0
    def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs):
        """Build a mixture whose components come from a single batched distribution.

        Parameters
        ----------
        w : array_like or tensor
            Mixture weights, converted with ``tt.as_tensor_variable``. The
            trailing axis of ``w`` is aligned with ``mixture_axis`` of the
            component shape; its leading axes are left-padded to broadcast
            against the component batch shape.
        comp_dists : Distribution
            One ``Distribution`` instance whose ``mixture_axis`` enumerates
            the mixture components.
        mixture_axis : int, default -1
            Axis of ``comp_dists.shape`` that indexes the components.
            Negative values count from the end.
        *args, **kwargs
            Forwarded to the parent ``__init__``. ``dtype`` defaults to
            ``comp_dists.dtype``. ``defaults`` (a list of attribute names) is
            copied, then extended with ``"mean"`` (non-discrete components
            only) and ``"mode"`` if they are not already present.

        Raises
        ------
        TypeError
            If ``comp_dists`` is not a ``Distribution`` instance.
        ValueError
            If ``mixture_axis`` is out of bounds for ``comp_dists.shape``.
        """
        self.w = tt.as_tensor_variable(w)
        if not isinstance(comp_dists, Distribution):
            raise TypeError(
                "The MixtureSameFamily distribution only accepts Distribution "
                f"instances as its components. Got {type(comp_dists)} instead."
            )
        self.comp_dists = comp_dists

        comp_shape = to_tuple(comp_dists.shape)
        ndim = len(comp_shape)

        # Normalise a negative axis, but remember what the caller passed so the
        # error message reports the caller's value rather than the shifted one.
        requested_axis = mixture_axis
        if mixture_axis < 0:
            mixture_axis += ndim
        if not 0 <= mixture_axis < ndim:
            raise ValueError(
                "`mixture_axis` is supposed to be in shape of components' distribution. "
                f"Got {requested_axis} axis instead out of the bounds."
            )

        # The mixture's own shape is the component shape without the mixture axis.
        self.shape = comp_shape[:mixture_axis] + comp_shape[mixture_axis + 1 :]
        self.mixture_axis = mixture_axis
        kwargs.setdefault("dtype", self.comp_dists.dtype)

        # Copy so that a caller-supplied ``defaults`` list is not mutated below.
        defaults = list(kwargs.pop("defaults", []))

        # Number of component axes that follow the mixture axis (event axes).
        n_event = ndim - mixture_axis - 1
        # Right-pad so the weights' last axis lands on ``mixture_axis``, then
        # left-pad the remaining batch axes so ``_w`` has the component rank.
        _w = tt.shape_padleft(
            tt.shape_padright(self.w, n_event),
            ndim - self.w.ndim - n_event,
        )

        # Compute the mode so we don't always have to pass a testval: take the
        # mode of the highest-weighted component. argmax runs along the mixture
        # axis (not over the flattened tensor), so batched weights give one
        # component index per batch entry instead of a flat index.
        mode = take_along_axis(
            self.comp_dists.mode,
            tt.argmax(_w, axis=mixture_axis, keepdims=True),
            axis=mixture_axis,
        )
        # Drop the length-1 mixture axis left over from keepdims.
        self.mode = mode[(..., 0) + (slice(None),) * n_event]

        if not all_discrete(comp_dists):
            mean = tt.as_tensor_variable(self.comp_dists.mean)
            # Weighted sum over components. NOTE(review): weights are not
            # normalised here; presumably callers pass w summing to one.
            self.mean = (_w * mean).sum(axis=mixture_axis)
            if "mean" not in defaults:
                defaults.append("mean")
        if "mode" not in defaults:
            defaults.append("mode")

        super().__init__(defaults=defaults, *args, **kwargs)
Beispiel #3
0
 def test_axis_failure(self, axis):
     """Check that ``take_along_axis`` rejects ``axis`` with ``ValueError``.

     ``axis`` is supplied by the test's parametrization (not visible here);
     presumably it holds axes invalid for a (3, 1)-shaped input — confirm.
     """
     input_arr, input_idx = self.get_input_tensors((3, 1))
     with pytest.raises(ValueError):
         take_along_axis(input_arr, input_idx, axis=axis)
Beispiel #4
0
 def _output_tensor(self, arr, indices, axis):
     """Return ``take_along_axis(arr, indices, axis)``, the graph under test."""
     return take_along_axis(arr, indices, axis=axis)