Example #1
def test_model_enumerated_elbo(model, guide, data, history):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        model = infer.config_enumerate(model, default="parallel")
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        # Materialize and clone the gradients now: a lazy generator would only
        # read .grad after the second loss_and_grads call has changed it.
        expected_grads = [
            value.grad.clone()
            for name, value in pyro.get_param_store().named_parameters()]

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data,
                                                     history, True)
        actual_grads = [
            value.grad
            for name, value in pyro.get_param_store().named_parameters()]

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
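For context, tests like this are usually driven by pytest parametrization. A hypothetical decorator stack (MODELS_GUIDES_DATA is an illustrative fixture list, not from the original suite) might look like:

@pytest.mark.parametrize("history", [1, 2, 3])
@pytest.mark.parametrize("model,guide,data", MODELS_GUIDES_DATA)
def test_model_enumerated_elbo(model, guide, data, history):
    ...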
Example #2
def test_guide_enumerated_elbo(model, guide, data, history):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"), \
        pytest.raises(
            NotImplementedError,
            match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"):

        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        # As above, materialize gradients before the second loss_and_grads call.
        expected_grads = [
            value.grad.clone()
            for name, value in pyro.get_param_store().named_parameters()]

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data,
                                                     history, True)
        actual_grads = [
            value.grad
            for name, value in pyro.get_param_store().named_parameters()]

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
Example #3
    def load_checkpoint(
        self,
        path: Optional[Union[str, Path]] = None,
        param_only: bool = False,
        warnings: bool = False,
    ):
        """
        Load checkpoint.

        :param path: Path to the model checkpoint; defaults to ``self.run_path``.
        :param param_only: Load only the parameter store.
        :param warnings: Warn if the loaded model has not been fully trained.
        """
        device = self.device
        path = Path(path) if path else self.run_path
        pyro.clear_param_store()
        checkpoint = torch.load(
            path / f"{self.full_name}-model.tpqr", map_location=device
        )
        pyro.get_param_store().set_state(checkpoint["params"])
        if not param_only:
            self.converged = checkpoint["convergence_status"]
            self._rolling = checkpoint["rolling"]
            self.iter = checkpoint["iter"]
            self.optim.set_state(checkpoint["optimizer"])
            logger.info(
                f"Iteration #{self.iter}. Loaded a model checkpoint from {path}"
            )
        if warnings and not checkpoint["convergence_status"]:
            logger.warning(f"Model at {path} has not been fully trained")
Example #4
def assert_ok(model, guide, elbo, *args, **kwargs):
    """
    Assert that inference works without warnings or errors.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    for i in range(2):
        inference.step(*args, **kwargs)
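A minimal usage sketch, assuming `torch`, `pyro`, `dist`, and `infer` are imported as in the surrounding tests (the model/guide pair is hypothetical):

def toy_model():
    pyro.sample("x", dist.Normal(0., 1.))

def toy_guide():
    loc = pyro.param("loc", torch.tensor(0.))
    pyro.sample("x", dist.Normal(loc, 1.))

assert_ok(toy_model, toy_guide, infer.Trace_ELBO())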
Example #5
def assert_error(model, guide, elbo, match=None):
    """
    Assert that inference fails with an error.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    with pytest.raises((NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError),
                       match=match):
        inference.step()
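A hypothetical usage sketch: the model below observes a value whose shape cannot broadcast against the distribution's batch shape, so the ELBO computation raises one of the caught exception types:

def broken_model():
    pyro.sample("x", dist.Normal(torch.zeros(2), 1.), obs=torch.zeros(3))

def empty_guide():
    pass

assert_error(broken_model, empty_guide, infer.Trace_ELBO())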
Example #6
def assert_warning(model, guide, elbo):
    """
    Assert that inference works but with a warning.
    """
    pyro.get_param_store().clear()
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        inference.step()
        assert len(w), 'No warnings were raised'
        for warning in w:
            print(warning)
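A hypothetical usage sketch, relying on TraceEnum_ELBO's strict enumeration warning (enabled by default) firing when no site is configured for enumeration:

def plain_model():
    pyro.sample("x", dist.Normal(0., 1.))

def plain_guide():
    loc = pyro.param("loc", torch.tensor(0.))
    pyro.sample("x", dist.Normal(loc, 1.))

assert_warning(plain_model, plain_guide, infer.TraceEnum_ELBO(max_plate_nesting=0))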
Example #7
def main(args):
    funsor.set_backend("torch")

    # Define a basic model with a single Normal latent random variable `loc`
    # and a batch of Normally distributed observations.
    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    # Define a guide (i.e. variational distribution) with a Normal
    # distribution over the latent random variable `loc`.
    def guide(data):
        guide_loc = pyro.param("guide_loc", torch.tensor(0.))
        guide_scale = pyro.param("guide_scale", torch.tensor(1.),
                                 constraint=constraints.positive)
        pyro.sample("loc", dist.Normal(guide_loc, guide_scale))

    # Generate some data.
    torch.manual_seed(0)
    data = torch.randn(100) + 3.0

    # Because the API in minipyro matches that of Pyro proper,
    # training code works with generic Pyro implementations.
    with pyro_backend(args.backend), interpretation(MonteCarlo()):
        # Construct an SVI object so we can do variational inference on our
        # model/guide pair.
        Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
        elbo = Elbo()
        adam = optim.Adam({"lr": args.learning_rate})
        svi = infer.SVI(model, guide, adam, elbo)

        # Basic training loop
        pyro.get_param_store().clear()
        for step in range(args.num_steps):
            loss = svi.step(data)
            if args.verbose and step % 100 == 0:
                print("step {} loss = {}".format(step, loss))

        # Report the final values of the variational parameters
        # in the guide after training.
        if args.verbose:
            for name in pyro.get_param_store():
                value = pyro.param(name).data
                print("{} = {}".format(name, value.detach().cpu().numpy()))

        # For this simple (conjugate) model we know the exact posterior. In
        # particular we know that the variational distribution should be
        # centered near 3.0. So let's check this explicitly.
        assert (pyro.param("guide_loc") - 3.0).abs() < 0.1
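A hypothetical CLI wiring for this example, covering exactly the attributes that `main` reads from `args` (defaults are illustrative):

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="minipyro SVI demo")
    parser.add_argument("--backend", default="pyro")
    parser.add_argument("--jit", action="store_true")
    parser.add_argument("--learning-rate", type=float, default=0.05)
    parser.add_argument("--num-steps", type=int, default=1000)
    parser.add_argument("--verbose", action="store_true")
    main(parser.parse_args())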
Example #8
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        num_particles = 1
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        p = 0.2693204236205713  # for which kl(Categorical(q), Categorical(p)) = 0.5
        p = torch.tensor([p, 1 - p])

        def model():
            d = dist.Categorical(p)
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d)
            with context1:
                pyro.sample("x", d)
            with context2:
                pyro.sample("y", d)
            with context1, context2:
                pyro.sample("z", d)

        def guide():
            d = dist.Categorical(pyro.param("q"))
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d, infer={"enumerate": "parallel"})
            with context1:
                pyro.sample("x", d, infer={"enumerate": "parallel"})
            with context2:
                pyro.sample("y", d, infer={"enumerate": "parallel"})
            with context1, context2:
                pyro.sample("z", d, infer={"enumerate": "parallel"})

        kl_node = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(q)),
            torch.distributions.Categorical(funsor.to_data(p)))
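        # One KL term per enumerated sample slot: "w" appears once, "x"
        # outer_dim times, "y" inner_dim times, and "z" outer_dim * inner_dim
        # times, which gives the multiplier below.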
        kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node
        expected_loss = kl
        expected_grad = grad(kl, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(num_particles=num_particles,
                                    vectorize_particles=True,
                                    strict_enumeration_warning=True)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(elbo(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param('q')).grad

        assert ops.allclose(actual_loss, expected_loss, atol=1e-5)
        assert ops.allclose(actual_grad, expected_grad, atol=1e-5)
Example #9
def test_optimizer(backend, optim_name, jit):
    def model(data):
        p = pyro.param("p", torch.tensor(0.5))
        pyro.sample("x", dist.Bernoulli(p), obs=data)

    def guide(data):
        pass

    data = torch.tensor(0.)
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = Elbo(ignore_jit_warnings=True)
        optimizer = getattr(optim, optim_name)({"lr": 1e-6})
        inference = infer.SVI(model, guide, optimizer, elbo)
        for i in range(2):
            inference.step(data)
Example #10
def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4):
    # copied from pyro
    expected_loss, actual_loss = funsor.to_data(expected_loss), funsor.to_data(actual_loss)
    assert ops.allclose(actual_loss, expected_loss, atol=atol, rtol=rtol)
    names = pyro.get_param_store().keys()
    params = []
    for name in names:
        params.append(funsor.to_data(pyro.param(name)).unconstrained())
    actual_grads = grad(actual_loss, params, allow_unused=True, retain_graph=True)
    expected_grads = grad(expected_loss, params, allow_unused=True, retain_graph=True)
    for name, actual_grad, expected_grad in zip(names, actual_grads, expected_grads):
        if actual_grad is None or expected_grad is None:
            continue
        assert ops.allclose(actual_grad, expected_grad, atol=atol, rtol=rtol)
Example #11
def train(model, guide, lr=1e-3, n_steps=1000, jit=True, verbose=False, **kwargs):

    pyro.clear_param_store()
    optimizer = optim.Adam({"lr": lr})
    elbo = (
        infer.JitTraceEnum_ELBO(max_plate_nesting=2)
        if jit
        else infer.TraceEnum_ELBO(max_plate_nesting=2)
    )
    svi = infer.SVI(model, guide, optimizer, elbo)

    for step in range(n_steps):
        svi.step(**kwargs)
        if step % 100 == 99 and verbose:
            values = tuple(f"{k}: {v}" for k, v in pyro.get_param_store().items())
            print(values)
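A self-contained usage sketch, assuming `constraints` is imported from torch.distributions as in the other examples; extra kwargs are forwarded to both model and guide via svi.step (TraceEnum_ELBO will warn that nothing is enumerated, which is harmless here):

def toy_model(data):
    p = pyro.param("p", torch.tensor(0.5), constraint=constraints.unit_interval)
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(p), obs=data)

def toy_guide(data):
    pass

train(toy_model, toy_guide, lr=1e-2, n_steps=200, jit=False, verbose=True,
      data=torch.tensor([0., 0., 1., 1., 1.]))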
Example #12
def main(args):
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        handlers.block(model,
                       expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = handlers.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = handlers.trace(
            handlers.replay(handlers.enum(model, first_available_dim),
                            guide_trace)).get_trace(sequences,
                                                    lengths,
                                                    args=args,
                                                    batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Bind non-PyTorch parameters to make these functions jittable.
    model = functools.partial(model, args=args)
    guide = functools.partial(guide, args=args)

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optimizer = optim.Adam({'lr': args.learning_rate})
    if args.tmc:
        if args.jit and not args.funsor:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        Elbo = infer.JitTraceTMC_ELBO if args.jit else infer.TraceTMC_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = handlers.infer_config(
            model,
            lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
            if msg["infer"].get("enumerate", None) == "parallel"
            else {})
        svi = infer.SVI(tmc_model, guide, optimizer, elbo)
    else:
        Elbo = infer.JitTraceEnum_ELBO if args.jit else infer.TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2,
                    strict_enumeration_warning=True,
                    jit_options={"time_compilation": args.time_compilation})
        svi = infer.SVI(model, guide, optimizer, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug('time to compile: {} s.'.format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           batch_size=sequences.shape[0],
                           include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          batch_size=sequences.shape[0],
                          include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('model_{} capacity = {} parameters'.format(
        args.model, capacity))
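For reference, a hypothetical argparse block covering the attributes this script reads from `args` (flag names inferred from the code above; defaults are illustrative):

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="HMM variants on the JSB chorales")
    parser.add_argument("--model", default="1")
    parser.add_argument("--num-steps", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=0.05)
    parser.add_argument("--truncate", type=int, default=0)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--jit", action="store_true")
    parser.add_argument("--funsor", action="store_true")
    parser.add_argument("--tmc", action="store_true")
    parser.add_argument("--tmc-num-samples", type=int, default=10)
    parser.add_argument("--time-compilation", action="store_true")
    parser.add_argument("--print-shapes", action="store_true")
    main(parser.parse_args())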
Example #13
    def save_checkpoint(self, writer: SummaryWriter):
        """
        Save checkpoint.

        :param writer: SummaryWriter object.
        """
        # save only if no NaN values
        for k, v in pyro.get_param_store().items():
            if torch.isnan(v).any() or torch.isinf(v).any():
                raise ValueError(
                    "Iteration #{}. Detected NaN values in {}".format(self.iter, k)
                )

        # update convergence criteria parameters
        for name in self.conv_params:
            if name == "-ELBO":
                self._rolling["-ELBO"].append(self.iter_loss)
            else:
                self._rolling[name].append(pyro.param(name).item())

        # check convergence status
        self.converged = False
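        # Converged when, for every tracked quantity, the std over the full
        # rolling window is within 5% of the std over its last 50 entries.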
        if len(self._rolling["-ELBO"]) == self._rolling["-ELBO"].maxlen:
            crit = all(
                torch.tensor(self._rolling[p]).std()
                / torch.tensor(self._rolling[p])[-50:].std()
                < 1.05
                for p in self.conv_params
            )
            if crit:
                self.converged = True

        # save the model state
        torch.save(
            {
                "iter": self.iter,
                "params": pyro.get_param_store().get_state(),
                "optimizer": self.optim.get_state(),
                "rolling": self._rolling,
                "convergence_status": self.converged,
            },
            self.run_path / f"{self.full_name}-model.tpqr",
        )

        # save global parameters for tensorboard
        writer.add_scalar("-ELBO", self.iter_loss, self.iter)
        for name, val in pyro.get_param_store().items():
            if val.dim() == 0:
                writer.add_scalar(name, val.item(), self.iter)
            elif val.dim() == 1 and len(val) <= self.S + 1:
                scalars = {str(i): v.item() for i, v in enumerate(val)}
                writer.add_scalars(name, scalars, self.iter)

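        # NOTE: the `False and` guard below deliberately disables this
        # accuracy-reporting block; remove it to re-enable the metrics.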
        if False and self.data.labels is not None:
            pred_labels = (
                self.pspecific_map[self.data.is_ontarget].cpu().numpy().ravel()
            )
            true_labels = self.data.labels["z"].ravel()

            metrics = {}
            with np.errstate(divide="ignore", invalid="ignore"):
                metrics["MCC"] = matthews_corrcoef(true_labels, pred_labels)
            metrics["Recall"] = recall_score(true_labels, pred_labels, zero_division=0)
            metrics["Precision"] = precision_score(
                true_labels, pred_labels, zero_division=0
            )

            neg, pos = {}, {}
            neg["TN"], neg["FP"], pos["FN"], pos["TP"] = confusion_matrix(
                true_labels, pred_labels, labels=(0, 1)
            ).ravel()

            writer.add_scalars("ACCURACY", metrics, self.iter)
            writer.add_scalars("NEGATIVES", neg, self.iter)
            writer.add_scalars("POSITIVES", pos, self.iter)

        logger.debug(f"Iteration #{self.iter}: Successful.")
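A hypothetical usage sketch; `trainer` stands in for an instance of the surrounding class and the log directory is illustrative:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir=trainer.run_path / "tensorboard")
trainer.save_checkpoint(writer)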