Exemple #1
0
def test_mx_switch() -> None:
    """
    Check that ``mx_switch`` picks, element-wise, the value of the first
    branch whose condition is true, and falls back to the default value where
    no condition holds.
    """
    first_case = (mx.nd.array([[1, 1, 0, 0]]), mx.nd.array([[1, 1, 1, 1]]))
    second_case = (mx.nd.array([[1, 0, 1, 0]]), mx.nd.array([[2, 2, 2, 2]]))
    fallback = mx.nd.array([[3, 3, 3, 3]])
    expected = mx.nd.array([1.0, 1.0, 2.0, 3.0])

    result = mx_switch(mx.nd, first_case, second_case, fallback)
    # Element-wise comparison; every position must match the expectation.
    matches = (result == expected).asnumpy()
    assert matches.all()
Exemple #2
0
    def impute_target_if_unobserved(
        self,
        F,
        output,
        scale,
        current_target,
        current_observed_indicator,
        is_pad,
    ) -> Tensor:
        """
        Fill in the target at unrolling step i where it was not observed.

        Entries flagged as observed keep their original target value. Of the
        remaining entries, those flagged as padding become zero. Everything
        else is replaced by a value drawn from the distribution predicted from
        the RNN output at this step (samples averaged over axis 0).

        Parameters
        ----------
        F
            Operator namespace (e.g. ``mx.nd``); used for ``zeros_like`` and
            passed through to ``mx_switch``.
        output
            RNN outputs used to construct the distribution at unrolling step i.
        scale
            Scale of the time series.
        current_target
            Tensor containing the current target.
        current_observed_indicator
            Tensor containing the current observed-value indicator.
        is_pad
            Tensor containing the current padding indicator.

        Returns
        -------
        Tensor
            Target with unobserved entries imputed and padded entries zeroed.
        """
        distribution = self.distr_output.distribution(
            self.proj_distr_args(output), scale=scale
        )

        # Draw and average the imputation samples with autograd recording
        # paused, so the imputed values are not part of the recorded graph.
        with autograd.pause():
            draws = distribution.sample(
                num_samples=self.num_imputation_samples, dtype=self.dtype
            )
            imputed = draws.mean(axis=0)

        # Precedence: observed -> keep target; else padded -> 0; else imputed.
        observed_branch = (current_observed_indicator, current_target)
        padded_branch = (is_pad, F.zeros_like(imputed))
        return mx_switch(F, observed_branch, padded_branch, imputed)