def test_model_enumerated_elbo(model, guide, data, history):
    """Compare TraceMarkovEnum_ELBO against TraceEnum_ELBO for model-side enumeration.

    The model is parallel-enumerated, then the loss and the gradient of every
    parameter in the param store are computed twice:

    - once with the reference ``TraceEnum_ELBO``, passing a trailing ``False``;
    - once with ``TraceMarkovEnum_ELBO``, passing a trailing ``True``.

    The trailing flag is forwarded to model/guide; presumably it selects the
    vectorized Markov code path (TODO confirm against the model fixtures).
    The two losses and the per-parameter gradients must agree.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        model = infer.config_enumerate(model, default="parallel")

        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)

        # Snapshot the reference gradients *now* and reset them. A lazy generator
        # would only be consumed after the second loss_and_grads call, and
        # loss_and_grads accumulates into .grad, so expected and actual would be
        # the very same (accumulated) tensors and the comparison would be vacuous.
        expected_grads = {}
        for name, value in pyro.get_param_store().named_parameters():
            expected_grads[name] = (
                None if value.grad is None else value.grad.detach().clone())
            value.grad = None

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, history, True)
        actual_grads = {
            name: value.grad
            for name, value in pyro.get_param_store().named_parameters()
        }

        assert_close(actual_loss, expected_loss)
        # Compare by parameter name so a missing/extra parameter is not
        # silently hidden by zip() truncation.
        assert actual_grads.keys() == expected_grads.keys()
        for name, expected_grad in expected_grads.items():
            assert_close(actual_grads[name], expected_grad)
def test_guide_enumerated_elbo(model, guide, data, history):
    """Check that TraceMarkovEnum_ELBO rejects guide-side Markov enumeration.

    The reference ``TraceEnum_ELBO`` runs first (trailing flag ``False``) as a
    sanity check that the model/guide pair is valid. Calling
    ``TraceMarkovEnum_ELBO`` (trailing flag ``True``) must then raise
    ``NotImplementedError`` with the guide-side enumeration message.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        # The reference computation is kept outside pytest.raises so that a
        # failure here cannot be mistaken for the expected error.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        elbo.loss_and_grads(model, guide, data, history, False)

        # Only the vectorized ELBO is expected to fail. Loss/gradient comparisons
        # placed after this call could never execute, so none are made here.
        # TODO: once guide-side Markov enumeration is supported, replace this
        # with loss and gradient comparisons against the reference ELBO.
        with pytest.raises(
                NotImplementedError,
                match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"):
            vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
            vectorized_elbo.loss_and_grads(model, guide, data, history, True)
def load_checkpoint(
    self,
    path: Union[str, Path] = None,
    param_only: bool = False,
    warnings: bool = False,
):
    """
    Restore state from the checkpoint ``<dir>/<self.full_name>-model.tpqr``.

    The Pyro param store is always cleared and repopulated from the file.

    :param path: Directory holding the checkpoint; when falsy, ``self.run_path`` is used.
    :param param_only: If True, restore only the param store and leave the
        convergence status, rolling statistics, iteration counter and
        optimizer state untouched.
    :param warnings: If True, log a warning when the checkpoint was saved
        before convergence.
    """
    target_device = self.device
    checkpoint_dir = self.run_path if not path else Path(path)
    pyro.clear_param_store()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(
        checkpoint_dir / f"{self.full_name}-model.tpqr", map_location=target_device
    )
    pyro.get_param_store().set_state(state["params"])

    if not param_only:
        # Restore training bookkeeping alongside the parameters.
        self.converged = state["convergence_status"]
        self._rolling = state["rolling"]
        self.iter = state["iter"]
        self.optim.set_state(state["optimizer"])
        logger.info(
            f"Iteration #{self.iter}. Loaded a model checkpoint from {checkpoint_dir}"
        )

    if warnings and not state["convergence_status"]:
        logger.warning(f"Model at {checkpoint_dir} has not been fully trained")
def assert_ok(model, guide, elbo, *args, **kwargs):
    """
    Check that two SVI steps on ``model``/``guide`` with ``elbo`` run without raising.

    Extra positional and keyword arguments are forwarded to every ``SVI.step`` call.
    """
    pyro.get_param_store().clear()
    # A tiny learning rate: the point is to exercise inference, not to fit.
    optimizer = optim.Adam({"lr": 1e-6})
    svi = infer.SVI(model, guide, optimizer, elbo)
    for _ in range(2):
        svi.step(*args, **kwargs)
def assert_error(model, guide, elbo, match=None):
    """
    Check that a single SVI step fails.

    The step must raise NotImplementedError, UserWarning, KeyError,
    ValueError or RuntimeError. When ``match`` is given, the exception
    message must match that regex.
    """
    accepted = (NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError)
    pyro.get_param_store().clear()
    optimizer = optim.Adam({"lr": 1e-6})
    svi = infer.SVI(model, guide, optimizer, elbo)
    with pytest.raises(accepted, match=match):
        svi.step()
def assert_warning(model, guide, elbo):
    """
    Check that a single SVI step completes but emits at least one warning.

    Every captured warning is printed for inspection.
    """
    pyro.get_param_store().clear()
    optimizer = optim.Adam({"lr": 1e-6})
    svi = infer.SVI(model, guide, optimizer, elbo)
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, even ones that would normally be deduplicated.
        warnings.simplefilter("always")
        svi.step()
        assert caught, 'No warnings were raised'
        for item in caught:
            print(item)
def main(args):
    """
    Fit a one-latent Normal model with SVI on synthetic data and check the result.

    Reads ``args.backend``, ``args.jit``, ``args.learning_rate``,
    ``args.num_steps`` and ``args.verbose``. Asserts at the end that the
    learned ``guide_loc`` lies within 0.1 of 3.0.
    """
    funsor.set_backend("torch")

    # Model: a Normal latent `loc` plus a plate of Normal observations around it.
    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    # Guide: a Normal variational distribution over `loc` with a learnable
    # location and a positivity-constrained learnable scale.
    def guide(data):
        guide_loc = pyro.param("guide_loc", torch.tensor(0.))
        guide_scale = pyro.param("guide_scale", torch.tensor(1.),
                                 constraint=constraints.positive)
        pyro.sample("loc", dist.Normal(guide_loc, guide_scale))

    # Synthetic observations shifted to be centered around 3.0.
    torch.manual_seed(0)
    data = torch.randn(100) + 3.0

    # The minipyro API mirrors Pyro's, so the same training code runs on any backend.
    with pyro_backend(args.backend), interpretation(MonteCarlo()):
        elbo_cls = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
        elbo = elbo_cls()
        svi = infer.SVI(model, guide, optim.Adam({"lr": args.learning_rate}), elbo)

        pyro.get_param_store().clear()
        for step in range(args.num_steps):
            loss = svi.step(data)
            if args.verbose and step % 100 == 0:
                print("step {} loss = {}".format(step, loss))

        if args.verbose:
            # Dump the trained variational parameters.
            for name in pyro.get_param_store():
                value = pyro.param(name).data
                print("{} = {}".format(name, value.detach().cpu().numpy()))

        # The model is conjugate, so the variational distribution should end up
        # centered near the data mean of roughly 3.0.
        assert (pyro.param("guide_loc") - 3.0).abs() < 0.1
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    """
    Check TraceEnum_ELBO on two crossed plates against an analytic KL.

    Every sample site uses the same Categorical prior ``p``, and the guide
    fully enumerates the same Categorical ``q`` at every site. The expected
    loss is therefore one single-site KL times the number of sample sites:
    1 (w) + outer_dim (x) + inner_dim (y) + outer_dim * inner_dim (z).
    The expected gradient w.r.t. ``q`` is the gradient of that loss.
    """
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        num_particles = 1
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        p_head = 0.2693204236205713  # for which kl(Categorical(q), Categorical(p)) = 0.5
        p = torch.tensor([p_head, 1 - p_head])

        def model():
            d = dist.Categorical(p)
            outer = pyro.plate("outer", outer_dim, dim=-1)
            inner = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d)
            with outer:
                pyro.sample("x", d)
            with inner:
                pyro.sample("y", d)
            with outer, inner:
                pyro.sample("z", d)

        def guide():
            d = dist.Categorical(pyro.param("q"))
            outer = pyro.plate("outer", outer_dim, dim=-1)
            inner = pyro.plate("inner", inner_dim, dim=-2)
            # A fresh infer dict per site: handlers may annotate it in place.
            pyro.sample("w", d, infer={"enumerate": "parallel"})
            with outer:
                pyro.sample("x", d, infer={"enumerate": "parallel"})
            with inner:
                pyro.sample("y", d, infer={"enumerate": "parallel"})
            with outer, inner:
                pyro.sample("z", d, infer={"enumerate": "parallel"})

        single_site_kl = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(q)),
            torch.distributions.Categorical(funsor.to_data(p)))
        num_sites = 1 + outer_dim + inner_dim + outer_dim * inner_dim
        expected_loss = num_sites * single_site_kl
        expected_grad = grad(expected_loss, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(num_particles=num_particles,
                                    vectorize_particles=True,
                                    strict_enumeration_warning=True)
        if backend == "pyro":
            elbo = elbo.differentiable_loss
        actual_loss = funsor.to_data(elbo(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param('q')).grad

        assert ops.allclose(actual_loss, expected_loss, atol=1e-5)
        assert ops.allclose(actual_grad, expected_grad, atol=1e-5)
def test_optimizer(backend, optim_name, jit):
    """Smoke-test that optimizer ``optim_name`` can drive two SVI steps on ``backend``."""
    def model(data):
        p = pyro.param("p", torch.tensor(0.5))
        pyro.sample("x", dist.Bernoulli(p), obs=data)

    def guide(data):
        pass  # the model has no latent sites, so the guide is empty

    data = torch.tensor(0.)
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        elbo_cls = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
        elbo = elbo_cls(ignore_jit_warnings=True)
        optimizer_cls = getattr(optim, optim_name)
        svi = infer.SVI(model, guide, optimizer_cls({"lr": 1e-6}), elbo)
        for _ in range(2):
            svi.step(data)
def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4):
    """Assert that two losses, and their gradients w.r.t. every param, agree.

    Both losses are converted with ``funsor.to_data`` first. Gradients are taken
    w.r.t. the unconstrained value of each parameter in the param store.
    Parameters for which either loss has no gradient (``None``) are skipped.
    Adapted from Pyro's test utilities.
    """
    expected_loss = funsor.to_data(expected_loss)
    actual_loss = funsor.to_data(actual_loss)
    assert ops.allclose(actual_loss, expected_loss, atol=atol, rtol=rtol), \
        "loss mismatch:\nexpected: {}\nactual: {}".format(expected_loss, actual_loss)

    names = list(pyro.get_param_store().keys())
    params = [funsor.to_data(pyro.param(name)).unconstrained() for name in names]
    # retain_graph=True: the same graph is differentiated by both grad() calls.
    actual_grads = grad(actual_loss, params, allow_unused=True, retain_graph=True)
    expected_grads = grad(expected_loss, params, allow_unused=True, retain_graph=True)
    for name, actual_grad, expected_grad in zip(names, actual_grads, expected_grads):
        # allow_unused=True yields None for params a loss does not depend on.
        if actual_grad is None or expected_grad is None:
            continue
        assert ops.allclose(actual_grad, expected_grad, atol=atol, rtol=rtol), \
            "gradient mismatch for param {!r}:\nexpected: {}\nactual: {}".format(
                name, expected_grad, actual_grad)
def train(model, guide, lr=1e-3, n_steps=1000, jit=True, verbose=False, **kwargs):
    """
    Fit ``guide`` to ``model`` with SVI, Adam and an enumerating ELBO.

    :param lr: Adam learning rate.
    :param n_steps: number of SVI steps to run.
    :param jit: use ``JitTraceEnum_ELBO`` instead of ``TraceEnum_ELBO``.
    :param verbose: print every param-store entry after each 100th step.
    :param kwargs: keyword arguments forwarded to every ``svi.step`` call.
    """
    pyro.clear_param_store()
    optimizer = optim.Adam({"lr": lr})
    elbo_cls = infer.JitTraceEnum_ELBO if jit else infer.TraceEnum_ELBO
    svi = infer.SVI(model, guide, optimizer, elbo_cls(max_plate_nesting=2))

    for step in range(n_steps):
        svi.step(**kwargs)
        if verbose and (step + 1) % 100 == 0:
            print(tuple(f"{name}: {value}"
                        for name, value in pyro.get_param_store().items()))
def main(args):
    """
    Train and evaluate a polyphonic-music model on the JSB chorales data.

    Uses MAP Baum-Welch: an AutoDelta guide learns point estimates of the
    ``probs_*`` sites, and the hidden state is enumerated out in the model.
    Logs the per-step training loss, then the final train and test losses
    (excluding the prior term, normalized by the number of observations), and
    the number of parameters ("capacity") of the chosen model.
    """
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # model_0 is intended to use max_plate_nesting=1 (presumably one fewer
    # vectorized plate than the other models) and all others 2. Decide this
    # *before* `model` is rebound to a functools.partial below; after that,
    # `model is model_0` is always False.
    max_plate_nesting = 1 if model is model_0 else 2

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        handlers.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        # Enumeration dims start just left of the plate dims:
        # -2 for model_0, -3 otherwise.
        first_available_dim = -(max_plate_nesting + 1)
        guide_trace = handlers.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = handlers.trace(
            handlers.replay(handlers.enum(model, first_available_dim), guide_trace)).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Bind non-PyTorch parameters to make these functions jittable.
    model = functools.partial(model, args=args)
    guide = functools.partial(guide, args=args)

    # Enumeration requires a TraceEnum (or TraceTMC) elbo and declaring the
    # max_plate_nesting computed above.
    optimizer = optim.Adam({'lr': args.learning_rate})
    if args.tmc:
        if args.jit and not args.funsor:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        Elbo = infer.JitTraceTMC_ELBO if args.jit else infer.TraceTMC_ELBO
        elbo = Elbo(max_plate_nesting=max_plate_nesting)
        tmc_model = handlers.infer_config(model, lambda msg: {
            "num_samples": args.tmc_num_samples,
            "expand": False
        } if msg["infer"].get("enumerate", None) == "parallel" else {}
        )  # noqa: E501
        svi = infer.SVI(tmc_model, guide, optimizer, elbo)
    else:
        Elbo = infer.JitTraceEnum_ELBO if args.jit else infer.TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=max_plate_nesting,
                    strict_enumeration_warning=True,
                    jit_options={"time_compilation": args.time_compilation})
        svi = infer.SVI(model, guide, optimizer, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug('time to compile: {} s.'.format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model, guide, sequences, lengths,
                           batch_size=sequences.shape[0], include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model, guide, sequences, lengths,
                          batch_size=sequences.shape[0], include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('model_{} capacity = {} parameters'.format(
        args.model, capacity))
def save_checkpoint(self, writer: SummaryWriter = None):
    """
    Save a checkpoint and, when a writer is provided, log scalars for TensorBoard.

    Steps:

    1. Refuse to save if any param-store value contains NaN or Inf.
    2. Append the current ``-ELBO`` / parameter values to the rolling windows
       for ``self.conv_params`` and update ``self.converged``.
    3. Write iteration, params, optimizer state, rolling windows and
       convergence status to ``<self.run_path>/<self.full_name>-model.tpqr``.
    4. If ``writer`` is not None, log ``-ELBO`` plus every 0-d parameter and
       every 1-d parameter of length <= ``self.S + 1``.

    :param writer: SummaryWriter object used for TensorBoard logging; when
        None, the checkpoint is still saved but nothing is logged.
    :raises ValueError: if any param-store value contains NaN or Inf.
    """
    # save only if no NaN (or Inf) values
    for k, v in pyro.get_param_store().items():
        if torch.isnan(v).any() or torch.isinf(v).any():
            raise ValueError(
                "Iteration #{}. Detected NaN values in {}".format(self.iter, k)
            )

    # update convergence criteria parameters
    for name in self.conv_params:
        if name == "-ELBO":
            self._rolling["-ELBO"].append(self.iter_loss)
        else:
            self._rolling[name].append(pyro.param(name).item())

    def _is_stable(values):
        # Stable when the std over the whole window is < 1.05x the std over
        # its last 50 entries.
        history = torch.tensor(values)
        return history.std() / history[-50:].std() < 1.05

    # check convergence status; only evaluated once the -ELBO window is full
    # (its length has reached .maxlen, presumably a collections.deque)
    self.converged = False
    if len(self._rolling["-ELBO"]) == self._rolling["-ELBO"].maxlen:
        self.converged = all(_is_stable(self._rolling[p]) for p in self.conv_params)

    # save the model state
    torch.save(
        {
            "iter": self.iter,
            "params": pyro.get_param_store().get_state(),
            "optimizer": self.optim.get_state(),
            "rolling": self._rolling,
            "convergence_status": self.converged,
        },
        self.run_path / f"{self.full_name}-model.tpqr",
    )

    # save global parameters for tensorboard (the writer is optional)
    if writer is not None:
        writer.add_scalar("-ELBO", self.iter_loss, self.iter)
        for name, val in pyro.get_param_store().items():
            if val.dim() == 0:
                writer.add_scalar(name, val.item(), self.iter)
            elif val.dim() == 1 and len(val) <= self.S + 1:
                scalars = {str(i): v.item() for i, v in enumerate(val)}
                writer.add_scalars(name, scalars, self.iter)

        # NOTE: classification metrics against ground-truth labels are
        # deliberately disabled via `if False`; kept for easy re-enabling.
        if False and self.data.labels is not None:
            pred_labels = (
                self.pspecific_map[self.data.is_ontarget].cpu().numpy().ravel()
            )
            true_labels = self.data.labels["z"].ravel()

            metrics = {}
            with np.errstate(divide="ignore", invalid="ignore"):
                metrics["MCC"] = matthews_corrcoef(true_labels, pred_labels)
            metrics["Recall"] = recall_score(true_labels, pred_labels, zero_division=0)
            metrics["Precision"] = precision_score(
                true_labels, pred_labels, zero_division=0
            )

            neg, pos = {}, {}
            neg["TN"], neg["FP"], pos["FN"], pos["TP"] = confusion_matrix(
                true_labels, pred_labels, labels=(0, 1)
            ).ravel()

            writer.add_scalars("ACCURACY", metrics, self.iter)
            writer.add_scalars("NEGATIVES", neg, self.iter)
            writer.add_scalars("POSITIVES", pos, self.iter)

    logger.debug(f"Iteration #{self.iter}: Successful.")