Exemple #1
0
def test_guide_enumerated_elbo(model, guide, data, history):
    """Guide-side Markov enumeration is unsupported and must raise NotImplementedError."""
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"), \
        pytest.raises(
            NotImplementedError,
            match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"):

        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        # Snapshot the gradients eagerly: with a lazy generator, ``value.grad``
        # would only be read after the vectorized run below overwrites it,
        # making the comparison trivially pass.
        # (Assumes every param received a gradient — TODO confirm.)
        expected_grads = [
            value.grad.clone()
            for name, value in pyro.get_param_store().named_parameters()]
        # Reset accumulated gradients so the vectorized run starts clean;
        # loss_and_grads does not zero them itself.
        for name, value in pyro.get_param_store().named_parameters():
            value.grad = None

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data,
                                                     history, True)
        actual_grads = [
            value.grad
            for name, value in pyro.get_param_store().named_parameters()]

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
Exemple #2
0
    def load_checkpoint(
        self,
        path: Union[str, Path, None] = None,
        param_only: bool = False,
        warnings: bool = False,
    ):
        """
        Load checkpoint.

        Restores the global Pyro parameter store from the
        ``{full_name}-model.tpqr`` file in ``path`` and, unless
        ``param_only`` is set, also the convergence status, rolling
        statistics, iteration counter, and optimizer state.

        :param path: Path to model checkpoint. Defaults to ``self.run_path``.
        :param param_only: Load only parameters.
        :param warnings: Give warnings if loaded model has not been fully trained.
        """
        # Map checkpoint tensors onto the device this model is configured for.
        device = self.device
        path = Path(path) if path else self.run_path
        # Drop any parameters left over from a previous run before restoring.
        pyro.clear_param_store()
        checkpoint = torch.load(
            path / f"{self.full_name}-model.tpqr", map_location=device
        )
        pyro.get_param_store().set_state(checkpoint["params"])
        if not param_only:
            self.converged = checkpoint["convergence_status"]
            self._rolling = checkpoint["rolling"]
            self.iter = checkpoint["iter"]
            self.optim.set_state(checkpoint["optimizer"])
            logger.info(
                f"Iteration #{self.iter}. Loaded a model checkpoint from {path}"
            )
        # Warn regardless of param_only, based on the stored convergence flag.
        if warnings and not checkpoint["convergence_status"]:
            logger.warning(f"Model at {path} has not been fully trained")
Exemple #3
0
def test_model_enumerated_elbo(model, guide, data, history):
    """TraceMarkovEnum_ELBO (vectorized) must match TraceEnum_ELBO in loss and grads."""
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        model = infer.config_enumerate(model, default="parallel")
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        # Snapshot the gradients eagerly: with a lazy generator, ``value.grad``
        # would only be read after the vectorized run below overwrites it,
        # making the comparison trivially pass.
        # (Assumes every param received a gradient — TODO confirm.)
        expected_grads = [
            value.grad.clone()
            for name, value in pyro.get_param_store().named_parameters()]
        # Reset accumulated gradients so the vectorized run starts clean;
        # loss_and_grads does not zero them itself.
        for name, value in pyro.get_param_store().named_parameters():
            value.grad = None

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data,
                                                     history, True)
        actual_grads = [
            value.grad
            for name, value in pyro.get_param_store().named_parameters()]

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
Exemple #4
0
    def init(
        self,
        lr: float = 0.005,
        nbatch_size: int = 5,
        fbatch_size: int = 512,
        jit: bool = False,
    ) -> None:
        """
        Initialize SVI object.

        Sets up the Adam optimizer and ELBO loss, resuming training state
        from a checkpoint when one is available and starting fresh otherwise.

        :param lr: Learning rate.
        :param nbatch_size: AOI batch size.
        :param fbatch_size: Frame batch size.
        :param jit: Use JIT compiler.
        """
        self.lr = lr
        self.optim_fn = optim.Adam
        self.optim_args = dict(lr=lr, betas=[0.9, 0.999])
        self.optim = self.optim_fn(self.optim_args)

        # Resume from the latest checkpoint if possible; otherwise reset
        # all training state and initialize parameters from scratch.
        try:
            self.load_checkpoint()
        except (FileNotFoundError, TypeError):
            pyro.clear_param_store()
            self.iter = 0
            self.converged = False
            self._rolling = {
                param: deque([], maxlen=100) for param in self.conv_params
            }
            self.init_parameters()

        self.elbo = self.TraceELBO(jit)
        self.svi = infer.SVI(self.model, self.guide, self.optim, loss=self.elbo)

        # Clamp requested batch sizes to the actual data dimensions.
        self.nbatch_size = min(nbatch_size, self.data.Nt)
        self.fbatch_size = min(fbatch_size, self.data.F)
Exemple #5
0
def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs):
    """
    Assert that enumeration runs...

    Runs the (guide, model) pair under both the "pyro" and "contrib.funsor"
    backends, exhaustively exploring discrete branches via LIFO queues, and
    checks after each branch that enumeration dimensions were cleaned up and
    that the pruned traces from the two backends agree.
    """
    with pyro_backend("pyro"):
        pyro.clear_param_store()

    # Default to a trivial guide that samples nothing.
    if guide is None:
        guide = lambda **kwargs: None  # noqa: E731

    # One queue per backend; both start from an empty trace and are
    # extended in lockstep so the two backends explore the same branches.
    q_pyro, q_funsor = LifoQueue(), LifoQueue()
    q_pyro.put(Trace())
    q_funsor.put(Trace())

    while not q_pyro.empty() and not q_funsor.empty():
        # Run one enumeration branch under the reference "pyro" backend.
        with pyro_backend("pyro"):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr_pyro = handlers.trace(
                    handlers.queue(
                        guide,
                        q_pyro,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_pyro = handlers.trace(
                    handlers.replay(model,
                                    trace=guide_tr_pyro)).get_trace(**kwargs)

        # Run the same branch under the "contrib.funsor" backend.
        with pyro_backend("contrib.funsor"):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr_funsor = handlers.trace(
                    handlers.queue(
                        guide,
                        q_funsor,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_funsor = handlers.trace(
                    handlers.replay(model,
                                    trace=guide_tr_funsor)).get_trace(**kwargs)

        # make sure all dimensions were cleaned up
        assert _DIM_STACK.local_frame is _DIM_STACK.global_frame
        assert (not _DIM_STACK.global_frame.name_to_dim
                and not _DIM_STACK.global_frame.dim_to_name)
        assert _DIM_STACK.outermost is None

        # Compare traces with subsample sites removed.
        tr_pyro = prune_subsample_sites(tr_pyro.copy())
        tr_funsor = prune_subsample_sites(tr_funsor.copy())
        _check_traces(tr_pyro, tr_funsor)
Exemple #6
0
def train(model, guide, lr=1e-3, n_steps=1000, jit=True, verbose=False, **kwargs):
    """Run SVI with an enumerating ELBO, optionally printing params every 100 steps."""
    pyro.clear_param_store()
    optimizer = optim.Adam({"lr": lr})
    if jit:
        elbo = infer.JitTraceEnum_ELBO(max_plate_nesting=2)
    else:
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)
    svi = infer.SVI(model, guide, optimizer, elbo)

    for step in range(n_steps):
        svi.step(**kwargs)
        if verbose and step % 100 == 99:
            snapshot = tuple(
                f"{name}: {value}"
                for name, value in pyro.get_param_store().items())
            print(snapshot)
Exemple #7
0
def main(args):
    """Train an HMM variant on the JSB chorales dataset and report train/test loss."""
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    # Select the model variant by index from the module-level registry.
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    # Optionally truncate every sequence to at most args.truncate frames.
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        handlers.block(model,
                       expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = handlers.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = handlers.trace(
            handlers.replay(handlers.enum(model, first_available_dim),
                            guide_trace)).get_trace(sequences,
                                                    lengths,
                                                    args=args,
                                                    batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Bind non-PyTorch parameters to make these functions jittable.
    model = functools.partial(model, args=args)
    guide = functools.partial(guide, args=args)

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optimizer = optim.Adam({'lr': args.learning_rate})
    if args.tmc:
        # Tensor Monte Carlo: sample (rather than fully enumerate) discrete
        # sites, drawing args.tmc_num_samples at each enumerated site.
        if args.jit and not args.funsor:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        Elbo = infer.JitTraceTMC_ELBO if args.jit else infer.TraceTMC_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = handlers.infer_config(model, lambda msg: {
            "num_samples": args.tmc_num_samples,
            "expand": False
        } if msg["infer"].get("enumerate", None) == "parallel" else {}
                                          )  # noqa: E501
        svi = infer.SVI(tmc_model, guide, optimizer, elbo)
    else:
        Elbo = infer.JitTraceEnum_ELBO if args.jit else infer.TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2,
                    strict_enumeration_warning=True,
                    jit_options={"time_compilation": args.time_compilation})
        svi = infer.SVI(model, guide, optimizer, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug('time to compile: {} s.'.format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           batch_size=sequences.shape[0],
                           include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    # Keep only the notes that were present in the training set.
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          batch_size=sequences.shape[0],
                          include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('model_{} capacity = {} parameters'.format(
        args.model, capacity))
Exemple #8
0
def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history,
                           use_replay):
    """
    Check that a vectorized (markov) trace of ``model`` produces the same
    log-prob factors, markov steps, and measure vars as the sequential trace.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        with handlers.enum():
            enum_model = infer.config_enumerate(model, default="parallel")
            # sequential factors
            trace = handlers.trace(enum_model).get_trace(
                weeks_data, days_data, history, False)

            # vectorized trace
            if use_replay:
                # Replay the model against a guide trace so that sites are
                # replayed rather than enumerated.
                guide_trace = handlers.trace(
                    _guide_from_model(model)).get_trace(
                        weeks_data, days_data, history, True)
                vectorized_trace = handlers.trace(
                    handlers.replay(model, trace=guide_trace)).get_trace(
                        weeks_data, days_data, history, True)
            else:
                vectorized_trace = handlers.trace(enum_model).get_trace(
                    weeks_data, days_data, history, True)

        # Collect the sequential log-prob factors, one per (var, timestep).
        factors = list()
        # sequential weeks factors
        for i in range(len(weeks_data)):
            for v in vars1:
                factors.append(trace.nodes["{}_{}".format(
                    v, i)]["funsor"]["log_prob"])
        # sequential days factors
        for j in range(len(days_data)):
            for v in vars2:
                factors.append(trace.nodes["{}_{}".format(
                    v, j)]["funsor"]["log_prob"])

        vectorized_factors = list()
        # vectorized weeks factors
        # The first ``history`` timesteps are traced individually ...
        for i in range(history):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        # ... the rest share one sliced site; substitute concrete timestep
        # names for the slice-indexed funsor inputs to recover factor i.
        for i in range(history, len(weeks_data)):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, slice(history,
                                 len(weeks_data)))]["funsor"]["log_prob"](**{
                                     "weeks":
                                     i - history
                                 }, **{
                                     "{}_{}".format(
                                         k,
                                         slice(history - j,
                                               len(weeks_data) - j)):
                                     "{}_{}".format(k, i - j)
                                     for j in range(history + 1) for k in vars1
                                 }))
        # vectorized days factors
        for i in range(history):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        for i in range(history, len(days_data)):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, slice(history,
                                 len(days_data)))]["funsor"]["log_prob"](**{
                                     "days":
                                     i - history
                                 }, **{
                                     "{}_{}".format(
                                         k,
                                         slice(history - j,
                                               len(days_data) - j)):
                                     "{}_{}".format(k, i - j)
                                     for j in range(history + 1) for k in vars2
                                 }))

        # assert correct factors
        for f1, f2 in zip(factors, vectorized_factors):
            assert_close(f2, f1.align(tuple(f2.inputs)))

        # assert correct step

        expected_measure_vars = frozenset()
        actual_weeks_step = vectorized_trace.nodes["weeks"]["value"]
        # expected step: assume that all but the last var is markov
        expected_weeks_step = frozenset()
        for v in vars1[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                     + tuple("{}_{}".format(v, slice(j, len(weeks_data)-history+j)) for j in range(history+1))
            expected_weeks_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        actual_days_step = vectorized_trace.nodes["days"]["value"]
        # expected step: assume that all but the last var is markov
        expected_days_step = frozenset()
        for v in vars2[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                     + tuple("{}_{}".format(v, slice(j, len(days_data)-history+j)) for j in range(history+1))
            expected_days_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        assert actual_weeks_step == expected_weeks_step
        assert actual_days_step == expected_days_step

        # check measure_vars
        actual_measure_vars = terms_from_trace(
            vectorized_trace)["measure_vars"]
        assert actual_measure_vars == expected_measure_vars
Exemple #9
0
def simulate(
        model,
        N: int,
        F: int,
        C: int = 1,
        P: int = 14,
        seed: int = 0,
        params: dict = None,
) -> CosmosDataset:
    """
    Simulate a new dataset.

    :param model: Tapqir model object (provides ``K``, ``device``, ``model``;
        the previous ``str`` annotation did not match this usage).
    :param N: Number of total AOIs. Half will be on-target and half off-target.
    :param F: Number of frames.
    :param C: Number of color channels.
    :param P: Number of pixels alongs the axis.
    :param seed: Rng seed.
    :param params: A dictionary of fixed parameter values.
    :return: A new simulated data set.
    """
    # Avoid the shared mutable default (``params: dict = dict()``); an empty
    # dict preserves the original behavior (a KeyError on missing keys).
    if params is None:
        params = {}

    pyro.set_rng_seed(seed)
    pyro.clear_param_store()

    # Fixed values to condition the generative model on.
    samples = {}
    samples["gain"] = torch.full((1, 1), params["gain"])
    samples["lamda"] = torch.full((1, 1), params["lamda"])
    samples["proximity"] = torch.full((1, 1), params["proximity"])

    if "pi" in params:
        # Equilibrium simulation: a single on/off probability ``pi``.
        samples["pi"] = torch.tensor([[1 - params["pi"], params["pi"]]])
        samples["background"] = torch.full((1, N, 1), params["background"])
        for k in range(model.K):
            samples[f"width_{k}"] = torch.full((1, N, F), params["width"])
            samples[f"height_{k}"] = torch.full((1, N, F), params["height"])
    else:
        # kinetic simulations: initial and transition probabilities derived
        # from binding (kon) and dissociation (koff) rates.
        samples["init"] = torch.tensor([[
            params["koff"] / (params["kon"] + params["koff"]),
            params["kon"] / (params["kon"] + params["koff"]),
        ]])
        samples["trans"] = torch.tensor([[
            [1 - params["kon"], params["kon"]],
            [params["koff"], 1 - params["koff"]],
        ]])
        for f in range(F):
            samples[f"background_{f}"] = torch.full((1, N, 1),
                                                    params["background"])
            for k in range(model.K):
                samples[f"width_{k}_{f}"] = torch.full((1, N, 1),
                                                       params["width"])
                samples[f"height_{k}_{f}"] = torch.full((1, N, 1),
                                                        params["height"])

    offset = torch.full((3, ), params["offset"])
    # Targets sit at the center pixel; the first half of AOIs are on-target.
    target_locs = torch.full((N, F, 1, 2), (P - 1) / 2)
    is_ontarget = torch.zeros((N, ), dtype=torch.bool)
    is_ontarget[:N // 2] = True
    # placeholder dataset so model.model can run before real data exists
    model.data = CosmosDataset(
        torch.full((N, F, C, P, P), params["background"] + params["offset"]),
        target_locs,
        is_ontarget,
        None,
        offset_samples=offset,
        offset_weights=torch.ones(3) / 3,
        device=model.device,
    )

    # Draw one sample from the unconditioned model, fixing ``samples`` above.
    predictive = Predictive(handlers.uncondition(model.model),
                            posterior_samples=samples,
                            num_samples=1)
    samples = predictive()
    data = torch.zeros(N, F, C, P, P)
    # Ground-truth labels only for the on-target (first N // 2) AOIs.
    labels = np.zeros((N // 2, F, 1),
                      dtype=[("aoi", int), ("frame", int), ("z", bool)])
    labels["aoi"] = np.arange(N // 2).reshape(-1, 1, 1)
    labels["frame"] = np.arange(F).reshape(-1, 1)
    if "pi" in params:
        data[:, :, 0] = samples["data"][0].data.floor()
        labels["z"][:, :, 0] = samples["theta"][0][:N // 2].cpu() > 0
    else:
        # kinetic simulations: data and labels are stored per frame
        for f in range(F):
            data[:, f:f + 1, 0] = samples[f"data_{f}"][0].data.floor()
            labels["z"][:, f:f + 1,
                        0] = samples[f"theta_{f}"][0][:N // 2].cpu() > 0

    return CosmosDataset(
        data.cpu(),
        target_locs.cpu(),
        is_ontarget.cpu(),
        labels,
        offset_samples=offset,
        offset_weights=torch.ones(3) / 3,
        device=model.device,
    )