Code example #1
File: hmm.py Project: synabreu/numpyro
def semi_supervised_hmm(transition_prior, emission_prior,
                        supervised_categories, supervised_words,
                        unsupervised_words):
    num_categories, num_words = transition_prior.shape[
        0], emission_prior.shape[0]
    transition_prob = numpyro.sample(
        'transition_prob',
        dist.Dirichlet(
            np.broadcast_to(transition_prior,
                            (num_categories, num_categories))))
    emission_prob = numpyro.sample(
        'emission_prob',
        dist.Dirichlet(
            np.broadcast_to(emission_prior, (num_categories, num_words))))

    # model the supervised data; we make no assumption about the first
    # supervised category, i.e. we place a flat/uniform prior on it.
    numpyro.sample('supervised_categories',
                   dist.Categorical(
                       transition_prob[supervised_categories[:-1]]),
                   obs=supervised_categories[1:])
    numpyro.sample('supervised_words',
                   dist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)

    # computes log prob of unsupervised data
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # inject log_prob into the potential function
    numpyro.factor('forward_log_prob', log_prob)
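
A minimal usage sketch (not part of the source file): fitting this model with NUTS, using data from the simulate_data helper shown in example #6. All sizes and MCMC settings below are assumptions.

from jax import random
from numpyro.infer import MCMC, NUTS

(transition_prior, emission_prior, _, _, supervised_categories,
 supervised_words, unsupervised_words) = simulate_data(
     random.PRNGKey(0), num_categories=3, num_words=10,
     num_supervised_data=100, num_unsupervised_data=500)

mcmc = MCMC(NUTS(semi_supervised_hmm), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(1), transition_prior, emission_prior,
         supervised_categories, supervised_words, unsupervised_words)
mcmc.print_summary()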
Code example #2
def dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 2 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            beta = numpyro.sample("beta",
                                  dist.Dirichlet(jnp.ones(num_classes)))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c",
                           dist.Categorical(pi),
                           infer={"enumerate": "parallel"})

        # here we use Vindex to allow broadcasting for the second index `c`
        # ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
        with numpyro.plate("position", num_positions):
            numpyro.sample("y",
                           dist.Categorical(Vindex(beta)[positions, c, :]),
                           obs=annotations)
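
A hedged sketch of fitting this model: NumPyro's NUTS marginalizes the enumerated discrete site `c` automatically (funsor must be installed). The toy `positions` and `annotations` arrays below are assumptions, not data from the source project.

import numpy as np
from jax import random
from numpyro.infer import MCMC, NUTS

positions = np.array([0, 1, 2, 3, 4])            # annotator id for each position
annotations = np.random.randint(0, 4, (50, 5))   # 50 items, 4 classes (toy)
mcmc = MCMC(NUTS(dawid_skene), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), positions, annotations)
mcmc.print_summary()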
Code example #3
File: utils.py Project: dimarkov/pybefit
def nondynamic_mixture_model(beliefs, y, mask):
    M, T, _ = beliefs[0].shape

    tau = .5
    weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))

    assert weights.shape == (M, )

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 5.))

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))

    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2, p[..., :1] / 2 + p[..., 1:2], p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    U = jnp.log(p0)

    def transition_fn(carry, t):

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs).mask(mask[t])
        npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))

        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, None, jnp.arange(T))
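
The pybefit snippets in this listing (examples #3, #11, #16, and #17) rely on names that utils.py presumably sets up; a plausible reconstruction of that context (an assumption based on usage, not the actual file header):

import jax.numpy as jnp
from jax import nn
import numpyro as npyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

# `logits` is a project-specific helper that maps beliefs to categorical
# logits; its implementation is not shown in this listing.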
Code example #4
File: hmm_enum.py Project: dirmeier/numpyro
def model_3(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
        probs_y = numpyro.sample(
            "probs_y",
            # probs_y is indexed below as probs_y[w, x, tones], so both leading
            # axes must have size hidden_dim (cf. model_4 in the same file)
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3),
        )

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                # Note the broadcasting tricks here: to index probs_y on tensors w and x,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y[w, x, tones]),
                                   obs=y)
        return (w, x, t + 1), None

    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))
Code example #5
File: hmm.py Project: neerajprad/numpyro
def semi_supervised_hmm(transition_prior, emission_prior,
                        supervised_categories, supervised_words,
                        unsupervised_words):
    num_categories, num_words = transition_prior.shape[0], emission_prior.shape[0]
    transition_prob = sample('transition_prob', dist.Dirichlet(
        np.broadcast_to(transition_prior, (num_categories, num_categories))))
    emission_prob = sample('emission_prob', dist.Dirichlet(
        np.broadcast_to(emission_prior, (num_categories, num_words))))

    # model the supervised data; we make no assumption about the first
    # supervised category, i.e. we place a flat/uniform prior on it.
    sample('supervised_categories', dist.Categorical(transition_prob[supervised_categories[:-1]]),
           obs=supervised_categories[1:])
    sample('supervised_words', dist.Categorical(emission_prob[supervised_categories]),
           obs=supervised_words)

    # computes log prob of unsupervised data
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # inject log_prob into the potential function
    # NB: this is a trick (it uses the invalid value `0` as a Multinomial
    # observation) to add an extra term to the potential energy.
    return sample('forward_log_prob', dist.Multinomial(logits=-log_prob), obs=0)
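
On NumPyro versions that provide numpyro.factor (as used in example #1 above), this Multinomial trick can be replaced by a direct, equivalent factor term:

    numpyro.factor('forward_log_prob', log_prob)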
Code example #6
File: hmm.py Project: synabreu/numpyro
def simulate_data(rng_key, num_categories, num_words, num_supervised_data,
                  num_unsupervised_data):
    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)

    transition_prior = np.ones(num_categories)
    emission_prior = np.repeat(0.1, num_words)

    transition_prob = dist.Dirichlet(transition_prior).sample(
        key=rng_key_transition, sample_shape=(num_categories, ))
    emission_prob = dist.Dirichlet(emission_prior).sample(
        key=rng_key_emission, sample_shape=(num_categories, ))

    start_prob = np.repeat(1. / num_categories, num_categories)
    categories, words = [], []
    for t in range(num_supervised_data + num_unsupervised_data):
        rng_key, rng_key_transition, rng_key_emission = random.split(
            rng_key, 3)
        if t == 0 or t == num_supervised_data:
            category = dist.Categorical(start_prob).sample(
                key=rng_key_transition)
        else:
            category = dist.Categorical(
                transition_prob[category]).sample(key=rng_key_transition)
        word = dist.Categorical(
            emission_prob[category]).sample(key=rng_key_emission)
        categories.append(category)
        words.append(word)

    # split into supervised and unsupervised data
    categories, words = np.stack(categories), np.stack(words)
    supervised_categories = categories[:num_supervised_data]
    supervised_words = words[:num_supervised_data]
    unsupervised_words = words[num_supervised_data:]
    return (transition_prior, emission_prior, transition_prob, emission_prob,
            supervised_categories, supervised_words, unsupervised_words)
Code example #7
File: hmm_enum.py Project: dirmeier/numpyro
def model_4(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).expand_by(
                [hidden_dim]).to_event(2),
        )
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3),
        )

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
                x = numpyro.sample(
                    "x", dist.Categorical(Vindex(probs_x)[w, x_prev]))
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y[w, x, tones]),
                                   obs=y)
        return (w, x, t + 1), None

    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))
Code example #8
    def model(data):
        with numpyro.plate("states", dim):
            transition = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
            emission_loc = numpyro.sample("emission_loc", dist.Normal(0, 1))
            emission_scale = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

        trans_prob = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
        for t, y in markov(enumerate(data)):
            x = numpyro.sample("x_{}".format(t), dist.Categorical(trans_prob))
            numpyro.sample("y_{}".format(t), dist.Normal(emission_loc[x], emission_scale[x]), obs=y)
            trans_prob = transition[x]
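
This snippet reads `dim` and `markov` from its enclosing scope; a plausible setup (an assumption, not shown in the source) is:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import markov

dim = 3  # number of hidden states (hypothetical)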
Code example #9
def simulate_data(
    rng_key: np.ndarray,
    num_categories: int,
    num_words: int,
    num_supervised: int,
    num_unsupervised: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:

    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)

    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)

    transition_prob = dist.Dirichlet(transition_prior).sample(
        rng_key_transition, sample_shape=(num_categories, ))
    emission_prob = dist.Dirichlet(emission_prior).sample(
        rng_key_emission, sample_shape=(num_categories, ))

    start_prob = jnp.repeat(1.0 / num_categories, num_categories)
    category = 0
    categories = []
    words = []

    for t in range(num_supervised + num_unsupervised):
        rng_key, rng_key_transition, rng_key_emission = random.split(
            rng_key, 3)
        if t == 0 or t == num_supervised:
            category = dist.Categorical(start_prob).sample(rng_key_transition)
        else:
            category = dist.Categorical(
                transition_prob[category]).sample(rng_key_transition)

        word = dist.Categorical(
            emission_prob[category]).sample(rng_key_emission)
        categories.append(category)
        words.append(word)

    # Split data into supervised and unsupervised
    categories = jnp.stack(categories)
    words = jnp.stack(words)

    supervised_categories = categories[:num_supervised]
    supervised_words = words[:num_supervised]
    unsupervised_words = words[num_supervised:]

    return (
        transition_prob,
        emission_prob,
        supervised_categories,
        supervised_words,
        unsupervised_words,
    )
Code example #10
File: numpyro_hmm.py Project: dirknbr/hmm_num_pyro
def unsupervised_hmm(words):
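    # NB: num_categories, transition_prior, and emission_prior are assumed to
    # be defined at module level in the source project.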
    with numpyro.plate("prob_plate", num_categories):
        transition_prob = numpyro.sample("transition_prob", dist.Dirichlet(transition_prior))
        emission_prob = numpyro.sample("emission_prob", dist.Dirichlet(emission_prior))

    transition_log_prob = jnp.log(transition_prob)
    emission_log_prob = jnp.log(emission_prob)
    log_prob = emission_log_prob[:, words[0]]
    for t in range(1, len(words)):
        log_prob = forward_log_prob(log_prob, words[t], transition_log_prob, emission_log_prob)
    prob = jnp.exp(logsumexp(log_prob, 0))
    # a trick to inject an additional log_prob into the model's log_prob
    numpyro.sample("forward_prob", dist.Bernoulli(prob), obs=jnp.array(1.))
Code example #11
File: utils.py Project: dimarkov/pybefit
def gammadyn_mixture_model(beliefs, y, mask):
    M, T, _ = beliefs[0].shape

    tau = .5
    weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))

    assert weights.shape == (M, )

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))
    mu = jnp.log(jnp.exp(gamma) - 1)

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2, p[..., :1] / 2 + p[..., 1:2], p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    scale = npyro.sample('scale', dist.Gamma(1., 1.))
    rho = npyro.sample('rho', dist.Beta(1., 2.))
    sigma = jnp.sqrt(-(1 - rho**2) / (2 * jnp.log(rho))) * scale

    U = jnp.log(p0)

    def transition_fn(carry, t):
        x_prev = carry

        gamma_dyn = npyro.deterministic('gamma_dyn', nn.softplus(mu + x_prev))

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma_dyn, -1), jnp.expand_dims(U, -2))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs).mask(mask[t])
        npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))

        with npyro.handlers.reparam(
                config={"x_next": npyro.infer.reparam.TransformReparam()}):
            affine = dist.transforms.AffineTransform(rho * x_prev, sigma)
            x_next = npyro.sample(
                'x_next',
                dist.TransformedDistribution(dist.Normal(0., 1.), affine))

        return (x_next), None

    x0 = jnp.zeros(1)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (x0), jnp.arange(T))
Code example #12
def item_difficulty(annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        eta = numpyro.sample(
            "eta",
            dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample(
            "Chi",
            dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c",
                           dist.Categorical(pi),
                           infer={"enumerate": "parallel"})

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta",
                                   dist.Normal(eta[c], chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", annotations.shape[-1]):
            numpyro.sample("y",
                           dist.Categorical(logits=theta),
                           obs=annotations)
Code example #13
File: annotation.py Project: while519/numpyro
def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample("Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
Code example #14
File: annotation.py Project: while519/numpyro
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad a 0 for the fixed last term of `beta`
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
Code example #15
File: hmm_enum.py Project: dirmeier/numpyro
def model_1(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    with mask(mask=include_prior):
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                with numpyro.plate("tones", data_dim, dim=-1):
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                   obs=y)
        return (x, t + 1), None

    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # NB: swapaxes moves the time dimension of `sequences` to the front so scan iterates over time
    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))
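
A sketch of how these hmm_enum models are typically run; `args` mirrors the script's command-line namespace, and the settings and data shapes below are toy assumptions:

import argparse
from jax import random
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS

args = argparse.Namespace(hidden_dim=16)   # hypothetical setting
sequences = jnp.zeros((10, 50, 88))        # (num_sequences, max_length, data_dim), toy
lengths = jnp.full((10,), 50)
mcmc = MCMC(NUTS(model_1), num_warmup=300, num_samples=300)
mcmc.run(random.PRNGKey(0), sequences, lengths, args)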
Code example #16
File: utils.py Project: dimarkov/pybefit
def fulldyn_mixture_model(beliefs, y, mask):
    M, T, N, _ = beliefs[0].shape

    c0 = beliefs[-1]

    tau = .5
    with npyro.plate('N', N):
        weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))

        assert weights.shape == (N, M)

        mu = npyro.sample('mu', dist.Normal(5., 5.))

        lam12 = npyro.sample('lam12',
                             dist.HalfCauchy(1.).expand([2]).to_event(1))
        lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

        _lam34 = jnp.expand_dims(lam34, -1)
        lam0 = npyro.deterministic(
            'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

        eta = npyro.sample('eta', dist.Beta(1, 10))

        scale = npyro.sample('scale', dist.HalfNormal(1.))
        theta = npyro.sample('theta', dist.HalfCauchy(5.))
        rho = jnp.exp(-theta)
        sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale

    x0 = jnp.zeros(N)

    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams', lam_prev +
            nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(
            mask[t][..., None])
        with npyro.plate('subjects', N):
            y = npyro.sample(
                'y', dist.MixtureSameFamily(mixing_dist, component_dist))
            noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
Code example #17
File: utils.py Project: dimarkov/pybefit
def nondyn_single_model(beliefs, y, mask):
    T, _ = beliefs[0].shape

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 3.))

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2, p[..., :1] / 2 + p[..., 1:2], p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    U = jnp.log(p0)

    def transition_fn(carry, t):

        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))
        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))

        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, None, jnp.arange(T))
Code example #18
File: hmm_enum.py Project: dirmeier/numpyro
def model_6(sequences, lengths, args, include_prior=False):
    num_sequences, max_length, data_dim = sequences.shape

    with mask(mask=include_prior):
        # Explicitly parameterize the full tensor of transition probabilities, which
        # has hidden_dim cubed entries.
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).expand(
                [args.hidden_dim, args.hidden_dim]).to_event(2),
        )

        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        x_prev, x_curr, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
                x_prev, x_curr = x_curr, numpyro.sample(
                    "x", dist.Categorical(probs_x_t))
                with numpyro.plate("tones", data_dim, dim=-1):
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    numpyro.sample("y", dist.Bernoulli(probs_y_t), obs=y)
        return (x_prev, x_curr, t + 1), None

    x_prev = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_curr = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (x_prev, x_curr, 0),
         jnp.swapaxes(sequences, 0, 1),
         history=2)
Code example #19
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators):
        epsilon = numpyro.sample("epsilon",
                                 dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample(
            "c",
            dist.DiscreteUniform(0, num_classes - 1),
            infer={"enumerate": "parallel"},
        )

        with numpyro.plate("position", num_positions):
            s = numpyro.sample(
                "s",
                dist.Bernoulli(1 - theta[positions]),
                infer={"enumerate": "parallel"},
            )
            probs = jnp.where(s[..., None] == 0, nn.one_hot(c, num_classes),
                              epsilon[positions])
            numpyro.sample("y", dist.Categorical(probs), obs=annotations)
Code example #20
 def gmm(data):
     mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
     with numpyro.plate("num_clusters", K, dim=-1):
         cluster_means = numpyro.sample("cluster_means", dist.Normal(jnp.arange(K), 1.))
     with numpyro.plate("data", data.shape[0], dim=-1):
         assignments = numpyro.sample("assignments", dist.Categorical(mix_proportions))
         numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data)
Code example #21
def multinomial(annotations):
    """
    This model corresponds to the plate diagram in Figure 1 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with numpyro.plate("position", num_positions):
            numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations)
Code example #22
def create_model(yT, yC, num_components):
    # Constants
    nC = yC.shape[0]
    nT = yT.shape[0]
    zC = jnp.isinf(yC).sum().item()
    zT = jnp.isinf(yT).sum().item()
    yT_finite = yT[~jnp.isinf(yT)]
    yC_finite = yC[~jnp.isinf(yC)]
    K = num_components
    
    p = numpyro.sample('p', dist.Beta(.5, .5))
    gammaC = numpyro.sample('gammaC', dist.Beta(1, 1))
    gammaT = numpyro.sample('gammaT', dist.Beta(1, 1))

    etaC = numpyro.sample('etaC', dist.Dirichlet(jnp.ones(K) / K))
    etaT = numpyro.sample('etaT', dist.Dirichlet(jnp.ones(K) / K))
    
    with numpyro.plate('mixture_components', K):
        nu = numpyro.sample('nu', dist.LogNormal(3.5, 0.5))
        mu = numpyro.sample('mu', dist.Normal(0, 3))
        sigma = numpyro.sample('sigma', dist.LogNormal(0, .5))
        phi = numpyro.sample('phi', dist.Normal(0, 3))


    gammaT_star = simulate_data.compute_gamma_T_star(gammaC, gammaT, p)
    etaT_star = simulate_data.compute_eta_T_star(etaC, etaT, p, gammaC, gammaT,
                                                 gammaT_star)

    with numpyro.plate('y_C', nC - zC):
        numpyro.sample('finite_obs_C',
                       Mix(nu[None, :],
                           mu[None, :],
                           sigma[None, :],
                           phi[None, :],
                           etaC[None, :]), obs=yC_finite[:, None])

    with numpyro.plate('y_T', nT - zT):
        numpyro.sample('finite_obs_T',
                       Mix(nu[None, :],
                           mu[None, :],
                           sigma[None, :],
                           phi[None, :],
                           etaT_star[None, :]), obs=yT_finite[:, None])

    numpyro.sample('N_C', dist.Binomial(nC, gammaC), obs=zC)
    numpyro.sample('N_T', dist.Binomial(nT, gammaT_star), obs=zT)
Code example #23
def gmm(data, num_components=3):
    mus = numpyro.sample('mus', dist.Normal(jnp.zeros(num_components),
                                            jnp.ones(num_components) * 100.).to_event(1))
    sigmas = numpyro.sample('sigmas', dist.HalfNormal(jnp.ones(num_components) * 100.).to_event(1))
    mixture_probs = numpyro.sample('mixture_probs', dist.Dirichlet(
        jnp.ones(num_components) / num_components))
    with numpyro.plate('data', len(data), dim=-1):
        z = numpyro.sample('z', dist.Categorical(mixture_probs))
        numpyro.sample('ll', dist.Normal(mus[z], sigmas[z]), obs=data)
Code example #24
    def simplex_proposal(self, rng_key, x, grad_, hess_):
        max_non_diag_hess = np.max(hess_[np.logical_not(
            np.eye(hess_.shape[0], dtype=bool))].reshape(hess_.shape[0], -1),
                                   axis=1)
        concentration = 1 - x**2 * (np.diag(hess_) - max_non_diag_hess)

        dist_ = dist.Dirichlet(concentration=concentration + W_CORRECTION)

        return dist_.sample(rng_key).reshape(x.shape), dist_
Code example #25
File: test_inspect.py Project: pyro-ppl/numpyro
 def model():
     a = numpyro.sample("a", dist.Normal(0, 1))
     b = numpyro.sample("b", dist.Normal(a[..., None], jnp.ones(3)).to_event(1))
     c = numpyro.sample(
         "c", dist.MultivariateNormal(jnp.zeros(3) + a[..., None], jnp.eye(3))
     )
     with numpyro.plate("i", 2):
         d = numpyro.sample("d", dist.Dirichlet(jnp.exp(b + c)))
         numpyro.sample("e", dist.Categorical(logits=d), obs=jnp.array([0, 0]))
     return a, b, c, d
Code example #26
def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    if nn_framework == "flax":
        encoder = flax_module(
            "encoder",
            FlaxEncoder(
                hyperparams["vocab_size"],
                hyperparams["num_topics"],
                hyperparams["hidden"],
                hyperparams["dropout_rate"],
            ),
            input_shape=(1, hyperparams["vocab_size"]),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    elif nn_framework == "haiku":
        encoder = haiku_module(
            "encoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(
                HaikuEncoder(
                    hyperparams["vocab_size"],
                    hyperparams["num_topics"],
                    hyperparams["hidden"],
                    hyperparams["dropout_rate"],
                )),
            input_shape=(1, hyperparams["vocab_size"]),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    with numpyro.plate("documents",
                       docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)

        if nn_framework == "flax":
            concentration = encoder(batch_docs,
                                    is_training,
                                    rngs={"dropout": numpyro.prng_key()})
        elif nn_framework == "haiku":
            concentration = encoder(numpyro.prng_key(), batch_docs,
                                    is_training)

        numpyro.sample("theta", dist.Dirichlet(concentration))
Code example #27
def semi_supervised_hmm(
    num_categories: int,
    num_words: int,
    supervised_categories: jnp.ndarray,
    supervised_words: jnp.ndarray,
    unsupervised_words: jnp.ndarray,
) -> None:

    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)

    transition_prob = numpyro.sample(
        "transition_prob",
        dist.Dirichlet(
            jnp.broadcast_to(transition_prior,
                             (num_categories, num_categories))),
    )
    emission_prob = numpyro.sample(
        "emission_prob",
        dist.Dirichlet(
            jnp.broadcast_to(emission_prior, (num_categories, num_words))),
    )

    numpyro.sample(
        "supervised_categories",
        dist.Categorical(transition_prob[supervised_categories[:-1]]),
        obs=supervised_categories[1:],
    )
    numpyro.sample(
        "supervised_words",
        dist.Categorical(emission_prob[supervised_categories]),
        obs=supervised_words,
    )

    transition_log_prob = jnp.log(transition_prob)
    emission_log_prob = jnp.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    numpyro.factor("forward_log_prob", log_prob)
Code example #28
def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    if nn_framework == "flax":
        decoder = flax_module(
            "decoder",
            FlaxDecoder(hyperparams["vocab_size"],
                        hyperparams["dropout_rate"]),
            input_shape=(1, hyperparams["num_topics"]),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    elif nn_framework == "haiku":
        decoder = haiku_module(
            "decoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(
                HaikuDecoder(hyperparams["vocab_size"],
                             hyperparams["dropout_rate"])),
            input_shape=(1, hyperparams["num_topics"]),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    with numpyro.plate("documents",
                       docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        theta = numpyro.sample(
            "theta", dist.Dirichlet(jnp.ones(hyperparams["num_topics"])))

        if nn_framework == "flax":
            logits = decoder(theta,
                             is_training,
                             rngs={"dropout": numpyro.prng_key()})
        elif nn_framework == "haiku":
            logits = decoder(numpyro.prng_key(), theta, is_training)

        total_count = batch_docs.sum(-1)
        numpyro.sample("obs",
                       dist.Multinomial(total_count, logits=logits),
                       obs=batch_docs)
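
Examples #26 and #28 look like the guide/model pair from NumPyro's ProdLDA example; a sketch of training them together with SVI (the document-term matrix and all hyperparameter values below are assumptions):

import jax.numpy as jnp
import numpyro
from jax import random
from numpyro.infer import SVI, TraceMeanField_ELBO

docs = jnp.zeros((500, 1000))  # (num_docs, vocab_size), toy placeholder
hyperparams = {"vocab_size": 1000, "num_topics": 20, "hidden": 100,
               "dropout_rate": 0.2, "batch_size": 32}
svi = SVI(model, guide, numpyro.optim.Adam(1e-3), TraceMeanField_ELBO())
svi_result = svi.run(random.PRNGKey(0), 3000, docs, hyperparams,
                     is_training=True, nn_framework="flax")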
Code example #29
def model_1(
    sequences: Optional[np.ndarray] = None,
    lengths: Optional[np.ndarray] = None,
    hidden_dim: int = 16,
    batch: int = 100,
    seq_len: int = 0,
    data_dim: int = 10,
    future_steps: int = 0,
) -> None:

    if sequences is not None:
        assert lengths is not None
        batch, seq_len, data_dim = sequences.shape

        future = np.zeros((batch, future_steps, data_dim))
        sequences = np.concatenate([sequences, future], axis=1)
    else:
        lengths = np.zeros((batch))

    probs_x = numpyro.sample(
        "probs_x",
        dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
    probs_y = numpyro.sample(
        "probs_y",
        dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2))

    def transition_fn(
            carry: Tuple[jnp.ndarray, jnp.ndarray], y: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        """One time step funciton."""

        x_prev, t = carry
        with numpyro.plate("sequence", batch, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                with numpyro.plate("tones", data_dim, dim=-1):
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                   obs=y)
        return (x, t + 1), None

    x_init = jnp.zeros((batch, 1), dtype=jnp.int32)

    if sequences is not None:
        # scan over time steps: data shape = (seq_len, batch, data_dim)
        scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))
    else:
        scan(transition_fn, (x_init, 0), None, length=seq_len + future_steps)
Code example #30
File: gaussian_mixture_model.py Project: DPBayes/d3p
def guide(k, obs=None, num_obs_total=None, d=None):
    # guide for the Gaussian mixture model: learns the variational parameters
    if obs is not None:
        assert(jnp.ndim(obs) == 2)
        _, d = jnp.shape(obs)
    else:
        assert(num_obs_total is not None)
        assert(d is not None)

    alpha_log = param('alpha_log', jnp.zeros(k))
    alpha = jnp.exp(alpha_log)
    pis = sample('pis', dist.Dirichlet(alpha))

    mus_loc = param('mus_loc', jnp.zeros((k, d)))
    mus = sample('mus', dist.Normal(mus_loc, 1.))
    sigs = sample('sigs', dist.InverseGamma(1., 1.), obs=jnp.ones_like(mus))
    return pis, mus, sigs