def config_enumerate(guide=None, default="sequential"):
    """
    Configures each enumerable site in a guide to enumerate with the given
    method, ``site["infer"]["enumerate"] = default``. This can be used as
    either a function::

        guide = config_enumerate(guide)

    or as a decorator::

        @config_enumerate
        def guide1(*args, **kwargs):
            ...

        @config_enumerate(default="parallel")
        def guide2(*args, **kwargs):
            ...

    This does not overwrite existing annotations ``infer={"enumerate": ...}``.

    :param callable guide: a pyro model that will be used as a guide in
        :class:`~pyro.infer.svi.SVI`.
    :param str default: one of "sequential", "parallel", or None.
    :return: an annotated guide
    :rtype: callable
    """
    if default not in ["sequential", "parallel", None]:
        raise ValueError(
            "Invalid default value. Expected 'sequential', 'parallel', or None, "
            "but got {}".format(repr(default)))

    # Support usage as a decorator:
    if guide is None:
        return lambda guide: config_enumerate(guide, default=default)

    return poutine.infer_config(guide, config_fn=_config_enumerate(default))
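# A minimal usage sketch (toy_guide and the site name "z" are ours, not from
# the source): tracing a guide configured by the version above shows the
# annotation that config_enumerate adds to each enumerable sample site.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine


@config_enumerate(default="parallel")
def toy_guide():
    probs = pyro.param("probs", torch.tensor([0.4, 0.6]))
    pyro.sample("z", dist.Categorical(probs))


trace = poutine.trace(toy_guide).get_trace()
print(trace.nodes["z"]["infer"])  # expected: {'enumerate': 'parallel'}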
def config_enumerate(guide=None, default="parallel", expand=False, num_samples=None):
    """
    Configures enumeration for all relevant sites in a guide. This is mainly
    used in conjunction with
    :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO`.

    When configuring for exhaustive enumeration of discrete variables, this
    configures all sample sites whose distribution satisfies
    ``.has_enumerate_support == True``. When configuring for local parallel
    Monte Carlo sampling via ``default="parallel", num_samples=n``, this
    configures all sample sites. This does not overwrite existing annotations
    ``infer={"enumerate": ...}``.

    This can be used as either a function::

        guide = config_enumerate(guide)

    or as a decorator::

        @config_enumerate
        def guide1(*args, **kwargs):
            ...

        @config_enumerate(default="sequential", expand=True)
        def guide2(*args, **kwargs):
            ...

    :param callable guide: a pyro model that will be used as a guide in
        :class:`~pyro.infer.svi.SVI`.
    :param str default: Which enumerate strategy to use, one of "sequential",
        "parallel", or None. Defaults to "parallel".
    :param bool expand: Whether to expand enumerated sample values. See
        :meth:`~pyro.distributions.Distribution.enumerate_support` for
        details. This only applies to exhaustive enumeration, where
        ``num_samples=None``. If ``num_samples`` is not ``None``, then samples
        are always expanded.
    :param num_samples: if not ``None``, use local Monte Carlo sampling rather
        than exhaustive enumeration. This makes sense for both continuous and
        discrete distributions.
    :type num_samples: int or None
    :return: an annotated guide
    :rtype: callable
    """
    if default not in ["sequential", "parallel", None]:
        raise ValueError(
            "Invalid default value. Expected 'sequential', 'parallel', or None, "
            "but got {}".format(repr(default)))
    if expand not in [True, False]:
        raise ValueError(
            "Invalid expand value. Expected True or False, but got {}".format(
                repr(expand)))
    if num_samples is not None:
        if not (isinstance(num_samples, numbers.Number) and num_samples > 0):
            raise ValueError(
                "Invalid num_samples, expected None or positive integer, "
                "but got {}".format(repr(num_samples)))
        if default == "sequential":
            raise ValueError(
                'Local sampling does not support "sequential" sampling; '
                'use "parallel" sampling instead.')

    # Support usage as a decorator:
    if guide is None:
        return lambda guide: config_enumerate(
            guide, default=default, expand=expand, num_samples=num_samples)

    return poutine.infer_config(
        guide, config_fn=_config_enumerate(default, expand, num_samples))
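# A minimal sketch of the local-sampling mode (toy_guide and the site name "x"
# are ours): with num_samples set, config_enumerate annotates every sample
# site, including continuous ones, for local parallel Monte Carlo sampling.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine


@config_enumerate(default="parallel", num_samples=10)
def toy_guide():
    loc = pyro.param("loc", torch.zeros(2))
    pyro.sample("x", dist.Normal(loc, 1.0).to_event(1))


trace = poutine.trace(toy_guide).get_trace()
print(trace.nodes["x"]["infer"])
# expected annotations: enumerate='parallel', expand=False, num_samples=10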
def test_infer_config_sample(self):
    cfg_model = poutine.infer_config(self.model, config_fn=self.config_fn)

    tr = poutine.trace(cfg_model).get_trace()

    assert tr.nodes["a"]["infer"] == {"enumerate": "parallel", "blah": True}
    assert tr.nodes["b"]["infer"] == {"blah": True}
    assert tr.nodes["p"]["infer"] == {}
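# A sketch of fixtures consistent with the assertions above (the actual setUp
# may differ): "a" is pre-annotated for parallel enumeration, "b" is a plain
# sample site, and "p" is a param site that the config function leaves alone.
import torch
import pyro
import pyro.distributions as dist


def model():
    pyro.param("p", torch.zeros(1, requires_grad=True))
    pyro.sample("a", dist.Bernoulli(torch.tensor([0.5])),
                infer={"enumerate": "parallel"})
    pyro.sample("b", dist.Bernoulli(torch.tensor([0.5])))


def config_fn(site):
    # annotate sample sites only; params keep an empty infer dict
    return {"blah": True} if site["type"] == "sample" else {}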
def main(args):
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info("-" * 40)
    model = models[args.model]
    logging.info("Training {} on {} sequences".format(
        model.__name__, len(data["train"]["sequences"])))
    sequences = data["train"]["sequences"]
    lengths = data["train"]["sequence_lengths"]

    # find all the notes that are present at least once in the training set
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim), guide_trace)
        ).get_trace(sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({"lr": args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError("jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
            if msg["infer"].get("enumerate", None) == "parallel"
            else {},
        )
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = Elbo(
            max_plate_nesting=1 if model is model_0 else 2,
            strict_enumeration_warning=(model is not model_7),
            jit_options={"time_compilation": args.time_compilation},
        )
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info("{: >5d}\t{}".format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug("time to compile: {} s.".format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False)
    logging.info("training loss = {}".format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info("-" * 40)
    logging.info("Evaluating on {} test sequences".format(
        len(data["test"]["sequences"])))
    sequences = data["test"]["sequences"][..., present_notes]
    lengths = data["test"]["sequence_lengths"]
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit
    # easier and for numerical stability) this test loss may not be directly
    # comparable to numbers reported on this dataset elsewhere.
    test_loss = elbo.loss(model, guide, sequences, lengths, args=args,
                          include_prior=False)
    logging.info("test loss = {}".format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info("{} capacity = {} parameters".format(model.__name__, capacity))
def main(args):
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # load data
    if args.dataset == "dipper":
        capture_history_file = os.path.dirname(
            os.path.abspath(__file__)) + '/dipper_capture_history.csv'
    elif args.dataset == "vole":
        capture_history_file = os.path.dirname(
            os.path.abspath(__file__)) + '/meadow_voles_capture_history.csv'
    else:
        raise ValueError("Available datasets are 'dipper' and 'vole'.")

    capture_history = torch.tensor(
        np.genfromtxt(capture_history_file, delimiter=',')).float()[:, 1:]
    N, T = capture_history.shape
    print("Loaded {} capture history for {} individuals collected over {} "
          "time periods.".format(args.dataset, N, T))

    if args.dataset == "dipper" and args.model in ["4", "5"]:
        sex_file = os.path.dirname(os.path.abspath(__file__)) + '/dipper_sex.csv'
        sex = torch.tensor(np.genfromtxt(sex_file, delimiter=',')).float()[:, 1]
        print("Loaded dipper sex data.")
    elif args.dataset == "vole" and args.model in ["4", "5"]:
        raise ValueError("Cannot run model_{} on meadow voles data, since we "
                         "lack sex information for these animals.".format(args.model))
    else:
        sex = None

    model = models[args.model]

    # we use poutine.block to only expose the continuous latent variables
    # in the models to AutoDiagonalNormal (all of which begin with 'phi'
    # or 'rho')
    def expose_fn(msg):
        return msg["name"][0:3] in ['phi', 'rho']

    # we use a mean field diagonal normal variational distribution (i.e. guide)
    # for the continuous latent variables.
    guide = AutoDiagonalNormal(poutine.block(model, expose_fn=expose_fn))

    # since we enumerate the discrete random variables,
    # we need to use TraceEnum_ELBO or TraceTMC_ELBO.
    optim = Adam({'lr': args.learning_rate})
    if args.tmc:
        elbo = TraceTMC_ELBO(max_plate_nesting=1)
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
            if msg["infer"].get("enumerate", None) == "parallel"
            else {},
        )
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        elbo = TraceEnum_ELBO(max_plate_nesting=1, num_particles=20,
                              vectorize_particles=True)
        svi = SVI(model, guide, optim, elbo)

    losses = []
    print("Beginning training of model_{} with Stochastic Variational "
          "Inference.".format(args.model))

    for step in range(args.num_steps):
        loss = svi.step(capture_history, sex)
        losses.append(loss)
        if (step % 20 == 0 and step > 0) or step == args.num_steps - 1:
            print("[iteration %03d] loss: %.3f" % (step, np.mean(losses[-20:])))

    # evaluate final trained model
    elbo_eval = TraceEnum_ELBO(max_plate_nesting=1, num_particles=2000,
                               vectorize_particles=True)
    svi_eval = SVI(model, guide, optim, elbo_eval)
    print("Final loss: %.4f" % svi_eval.evaluate_loss(capture_history, sex))
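# A hedged illustration (toy_cjs_model and its site names are hypothetical
# stand-ins for the real models): poutine.block with the expose_fn above lets
# AutoDiagonalNormal fit only the continuous 'phi*'/'rho*' sites, while
# discrete sites stay hidden from the guide and are handled by enumeration.
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDiagonalNormal


def toy_cjs_model():
    pyro.sample("phi", dist.Uniform(0.0, 1.0))
    pyro.sample("rho", dist.Uniform(0.0, 1.0))
    pyro.sample("z", dist.Bernoulli(0.5), infer={"enumerate": "parallel"})


guide = AutoDiagonalNormal(
    poutine.block(toy_cjs_model,
                  expose_fn=lambda msg: msg["name"][0:3] in ["phi", "rho"]))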