Code example #1
0
def test_weibull_likelihood(rate: float, shape: float,
                            hybridize: bool) -> None:
    """
    Check that SGD maximum-likelihood fitting recovers Weibull parameters.

    NUM_SAMPLES draws are taken from Weibull(rate, shape). The optimizer is
    started from biases that sit below the true values by a relative offset
    of START_TOL_MULTIPLE * TOL. The fitted rate and shape must then land
    within a relative tolerance TOL of the true values.
    """
    # Every draw shares the same true parameters.
    base = mx.nd.zeros((NUM_SAMPLES, ))
    samples = Weibull(base + rate, base + shape).sample()

    # Initial biases (in inv_softplus space) deliberately offset from the
    # truth, so the optimizer has to move to pass the tolerance checks.
    init_biases = [
        inv_softplus(true_value - START_TOL_MULTIPLE * TOL * true_value)
        for true_value in (rate, shape)
    ]

    rate_hat, shape_hat = maximum_likelihood_estimate_sgd(
        WeibullOutput(),
        samples,
        init_biases=init_biases,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    print("rate:", rate_hat, "shape:", shape_hat)

    rate_error = np.abs(rate_hat - rate)
    assert rate_error < TOL * rate, (
        f"rate did not match: rate = {rate}, rate_hat = {rate_hat}"
    )
    shape_error = np.abs(shape_hat - shape)
    assert shape_error < TOL * shape, (
        f"shape did not match: shape = {shape}, shape_hat = {shape_hat}"
    )
Code example #2
0
File: test_deeptpp.py  Project: slowjazz/gluon-ts
def test_prediction_network_output():
    """
    Smoke-test the output contract of DeepTPPPredictionNetwork.

    The network is run on a batch of two sequences. The test checks:

    - the shapes of the sampled targets and valid lengths;
    - that sampled marks are valid indices in [0, num_marks);
    - that sampled inter-arrival times are non-negative and sum to less
      than the prediction interval length.
    """
    mx.rnd.seed(seed_state=1234)
    net = DeepTPPPredictionNetwork(
        num_marks=5,
        time_distr_output=WeibullOutput(),
        interval_length=1.0,
        prediction_interval_length=10.0,
    )
    net.initialize()

    ia_times = nd.array([[0.1, 0.2, 0.1, 0.12], [0.3, 0.15, 0.1, 0.12]])
    marks = nd.array([[1, 2, 0, 2], [0, 0, 1, 2]])
    valid_length = nd.array([3, 4])

    pred_target, pred_valid_length = net(
        nd.stack(ia_times, marks, axis=-1), valid_length
    )

    n_samples = net.num_parallel_samples
    batch_size = ia_times.shape[0]

    # Expected layout:
    # (num_parallel_samples, batch_size, max_sequence_length, 2),
    # where the last axis holds (ia_time, mark).
    assert pred_target.ndim == 4
    assert pred_target.shape[:2] == (n_samples, batch_size)
    assert pred_target.shape[3] == 2
    # Expected layout: (num_parallel_samples, batch_size).
    assert pred_valid_length.ndim == 2
    assert pred_valid_length.shape == (n_samples, batch_size)

    sampled = pred_target.asnumpy()
    sampled_ia_times = sampled[..., 0]
    sampled_marks = sampled[..., 1]

    assert 0 <= sampled_marks.min()
    assert sampled_marks.max() < net.num_marks
    assert (sampled_ia_times >= 0).all()
    # Past valid_length the ia_times are zero (see DeepTPPPredictionNetwork),
    # so each per-sequence sum is the total sampled elapsed time.
    assert (sampled_ia_times.sum(axis=-1)
            < net.prediction_interval_length).all()
Code example #3
0
File: test_deeptpp.py  Project: slowjazz/gluon-ts
def test_log_likelihood(ia_times, marks, valid_length, num_marks, loglike):
    """
    Check the per-sequence log-likelihood of DeepTPPTrainingNetwork.

    ``ia_times`` and ``marks`` are stacked along a new last axis to form the
    network input. The output must contain one value per sequence and must
    match ``loglike`` according to ``_allclose``.
    """
    mx.rnd.seed(seed_state=1234)

    net = DeepTPPTrainingNetwork(
        num_marks=num_marks,
        interval_length=2,
        time_distr_output=WeibullOutput(),
    )
    net.initialize()

    target = nd.stack(ia_times, marks, axis=-1)
    predicted = net(target, valid_length)

    num_sequences = ia_times.shape[0]
    assert predicted.shape == (num_sequences, )
    assert _allclose(loglike, predicted)
Code example #4
0
    def __init__(
        self,
        prediction_interval_length: float,
        context_interval_length: float,
        num_marks: int,
        time_distr_output: TPPDistributionOutput = WeibullOutput(),
        embedding_dim: int = 5,
        trainer: Trainer = Trainer(hybridize=False),
        num_hidden_dimensions: int = 10,
        num_parallel_samples: int = 100,
        num_training_instances: int = 100,
        freq: str = "H",
        batch_size: int = 32,
    ) -> None:
        """
        Validate and store the DeepTPP estimator hyperparameters.

        Parameters
        ----------
        prediction_interval_length
            Length of the interval to predict; must be > 0.
        context_interval_length
            Length of the context interval; must be > 0. ``None`` is also
            accepted and falls back to ``prediction_interval_length``.
        num_marks
            Number of distinct marks; must be > 0.
        time_distr_output
            Distribution output used for the time component (presumably the
            inter-arrival times -- confirm against the network code).
        embedding_dim
            Embedding dimension; must be > 0.
        trainer
            Trainer passed to the base class. Hybridization must be disabled.
        num_hidden_dimensions
            Hidden size; must be > 0.
        num_parallel_samples
            Number of samples drawn at prediction time; must be > 0.
        num_training_instances
            Number of training instances; must be > 0.
        freq
            Frequency string (stored as-is).
        batch_size
            Batch size passed to the base class.

        Raises
        ------
        AssertionError
            If any of the checks below fail.
        """
        # NOTE(review): `WeibullOutput()` and `Trainer(hybridize=False)` in
        # the signature are evaluated once, when the def executes, so every
        # estimator built with the defaults shares those same instances.
        # Consider `None` defaults (with `Optional[...]` annotations) and
        # constructing fresh objects here.
        # NOTE(review): `context_interval_length` is annotated `float`, but
        # `None` is handled below; `Optional[float]` would match the code.
        assert (
            not trainer.hybridize
        ), "DeepTPP currently only supports the non-hybridized training"

        super().__init__(trainer=trainer, batch_size=batch_size)

        assert (
            prediction_interval_length > 0
        ), "The value of `prediction_interval_length` should be > 0"
        assert (
            context_interval_length is None or context_interval_length > 0
        ), "The value of `context_interval_length` should be > 0"
        assert (
            num_hidden_dimensions > 0
        ), "The value of `num_hidden_dimensions` should be > 0"
        assert (
            num_parallel_samples > 0
        ), "The value of `num_parallel_samples` should be > 0"
        assert num_marks > 0, "The value of `num_marks` should be > 0"
        assert (
            num_training_instances > 0
        ), "The value of `num_training_instances` should be > 0"
        # embedding_dim was previously the only size hyperparameter left
        # unchecked; validate it the same way as the others.
        assert embedding_dim > 0, "The value of `embedding_dim` should be > 0"

        self.num_hidden_dimensions = num_hidden_dimensions
        self.prediction_interval_length = prediction_interval_length
        # Fall back to the prediction interval when no context length is given.
        self.context_interval_length = (
            context_interval_length
            if context_interval_length is not None
            else prediction_interval_length
        )
        self.num_marks = num_marks
        self.time_distr_output = time_distr_output
        self.embedding_dim = embedding_dim
        self.num_parallel_samples = num_parallel_samples
        self.num_training_instances = num_training_instances
        self.freq = freq
Code example #5
0
        DirichletMultinomialOutput(dim=3, n_trials=5),
        DirichletOutput(dim=4),
        EmpiricalDistributionOutput(num_samples=10,
                                    distr_output=GaussianOutput()),
        GammaOutput(),
        GaussianOutput(),
        GenParetoOutput(),
        LaplaceOutput(),
        LogitNormalOutput(),
        LoglogisticOutput(),
        LowrankMultivariateGaussianOutput(dim=5, rank=2),
        MultivariateGaussianOutput(dim=4),
        NegativeBinomialOutput(),
        OneInflatedBetaOutput(),
        PiecewiseLinearOutput(num_pieces=10),
        PoissonOutput(),
        StudentTOutput(),
        UniformOutput(),
        WeibullOutput(),
        ZeroAndOneInflatedBetaOutput(),
        ZeroInflatedBetaOutput(),
        ZeroInflatedNegativeBinomialOutput(),
        ZeroInflatedPoissonOutput(),
    ],
)
def test_distribution_output_serde(distr_output: DistributionOutput):
    """
    Round-trip ``distr_output`` through ``encode``/``decode``.

    The decoded object must be an instance of the original's type and must
    produce the same ``dump_json`` output as the original.
    """
    roundtripped = decode(encode(distr_output))

    original_type = type(distr_output)
    assert isinstance(roundtripped, original_type)

    roundtripped_json = dump_json(roundtripped)
    assert roundtripped_json == dump_json(distr_output)