def test_guide_enumerated_elbo(model, guide, data, history):
    """
    Guide-side Markov enumeration is not implemented: the vectorized ELBO call
    below is expected to raise ``NotImplementedError``, so the whole body runs
    under ``pytest.raises``.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"), \
            pytest.raises(
                NotImplementedError,
                match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"):

        if history > 1:
            # longer histories are not supported by the vectorized ELBO at all
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        # reference loss/gradients from sequential (unrolled) enumeration
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        expected_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        # vectorized Markov enumeration -- this call is expected to raise
        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, history, True)
        actual_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        # NOTE(review): the assertions below are unreachable whenever the
        # expected NotImplementedError is raised above; if no error is raised,
        # pytest.raises itself fails the test.
        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
def load_checkpoint(
    self,
    path: Union[str, Path, None] = None,
    param_only: bool = False,
    warnings: bool = False,
):
    """
    Load checkpoint.

    Restores the Pyro param store from ``{full_name}-model.tpqr`` and,
    unless ``param_only`` is set, the full training state (convergence
    status, rolling statistics, iteration counter, optimizer state).

    :param path: Path to model checkpoint. Defaults to ``self.run_path``.
    :param param_only: Load only parameters.
    :param warnings: Give warnings if loaded model has not been fully trained.
    """
    device = self.device
    # fall back to the run directory when no explicit path is given
    path = Path(path) if path else self.run_path
    # drop any parameters currently in the global store before restoring
    pyro.clear_param_store()
    checkpoint = torch.load(
        path / f"{self.full_name}-model.tpqr", map_location=device
    )
    pyro.get_param_store().set_state(checkpoint["params"])
    if not param_only:
        # restore full training state so optimization can resume seamlessly
        self.converged = checkpoint["convergence_status"]
        self._rolling = checkpoint["rolling"]
        self.iter = checkpoint["iter"]
        self.optim.set_state(checkpoint["optimizer"])
        logger.info(
            f"Iteration #{self.iter}. Loaded a model checkpoint from {path}"
        )
    if warnings and not checkpoint["convergence_status"]:
        logger.warning(f"Model at {path} has not been fully trained")
def test_model_enumerated_elbo(model, guide, data, history):
    """
    Check that vectorized Markov enumeration (``TraceMarkovEnum_ELBO``)
    produces the same loss and parameter gradients as sequential enumeration
    (``TraceEnum_ELBO``) for the same model/guide pair.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")
        model = infer.config_enumerate(model, default="parallel")

        # reference loss/gradients from sequential (unrolled) enumeration
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        # BUG FIX: snapshot the gradients eagerly. Previously this was a lazy
        # generator that was only consumed after the second loss_and_grads()
        # call, at which point .grad had been accumulated by both backward
        # passes, so expected and actual gradients were always the very same
        # tensors and the comparison below was vacuous.
        expected_grads = tuple(
            value.grad.detach().clone()
            for name, value in pyro.get_param_store().named_parameters())
        # reset accumulated gradients so the vectorized pass starts fresh
        for name, value in pyro.get_param_store().named_parameters():
            value.grad = None

        # vectorized Markov enumeration should match the reference exactly
        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, history, True)
        actual_grads = tuple(
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
def init(
    self,
    lr: float = 0.005,
    nbatch_size: int = 5,
    fbatch_size: int = 512,
    jit: bool = False,
) -> None:
    """
    Initialize SVI object.

    :param lr: Learning rate.
    :param nbatch_size: AOI batch size.
    :param fbatch_size: Frame batch size.
    :param jit: Use JIT compiler.
    """
    # Set up the Adam optimizer with fixed beta coefficients.
    adam_cls = optim.Adam
    self.lr = lr
    self.optim_fn = adam_cls
    self.optim_args = {"lr": lr, "betas": [0.9, 0.999]}
    self.optim = adam_cls(self.optim_args)

    # Resume from a saved checkpoint if one exists; otherwise start a
    # fresh run with newly initialized parameters and counters.
    try:
        self.load_checkpoint()
    except (FileNotFoundError, TypeError):
        pyro.clear_param_store()
        self.iter = 0
        self.converged = False
        rolling = {}
        for conv_param in self.conv_params:
            rolling[conv_param] = deque([], maxlen=100)
        self._rolling = rolling
        self.init_parameters()

    # Build the loss and the SVI driver around model/guide.
    self.elbo = self.TraceELBO(jit)
    self.svi = infer.SVI(self.model, self.guide, self.optim, loss=self.elbo)

    # Batch sizes are capped by the number of available AOIs and frames.
    self.nbatch_size = min(nbatch_size, self.data.Nt)
    self.fbatch_size = min(fbatch_size, self.data.F)
def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs):
    """
    Assert that enumeration runs identically under the "pyro" and
    "contrib.funsor" backends: enumerate all discrete branches of ``guide``
    (via LIFO queues) in lockstep under both backends and compare the
    resulting traces.
    """
    with pyro_backend("pyro"):
        pyro.clear_param_store()

    if guide is None:
        # trivial guide with no sample sites
        guide = lambda **kwargs: None  # noqa: E731

    # one queue of partial traces per backend, enumerated in lockstep
    q_pyro, q_funsor = LifoQueue(), LifoQueue()
    q_pyro.put(Trace())
    q_funsor.put(Trace())

    while not q_pyro.empty() and not q_funsor.empty():

        # reference execution under the standard Pyro backend
        with pyro_backend("pyro"):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr_pyro = handlers.trace(
                    handlers.queue(
                        guide, q_pyro,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_pyro = handlers.trace(
                    handlers.replay(model, trace=guide_tr_pyro)).get_trace(**kwargs)

        # same enumeration step under the funsor backend
        with pyro_backend("contrib.funsor"):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr_funsor = handlers.trace(
                    handlers.queue(
                        guide, q_funsor,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_funsor = handlers.trace(
                    handlers.replay(model, trace=guide_tr_funsor)).get_trace(**kwargs)

            # make sure all dimensions were cleaned up
            assert _DIM_STACK.local_frame is _DIM_STACK.global_frame
            assert (not _DIM_STACK.global_frame.name_to_dim
                    and not _DIM_STACK.global_frame.dim_to_name)
            assert _DIM_STACK.outermost is None

        # compare traces site-by-site after removing subsample sites
        tr_pyro = prune_subsample_sites(tr_pyro.copy())
        tr_funsor = prune_subsample_sites(tr_funsor.copy())
        _check_traces(tr_pyro, tr_funsor)
def train(model, guide, lr=1e-3, n_steps=1000, jit=True, verbose=False, **kwargs):
    """
    Fit ``guide`` to ``model`` with SVI using an enumeration-aware ELBO.

    Clears the param store, runs ``n_steps`` Adam updates, and (when
    ``verbose``) prints a snapshot of the param store every 100 steps.
    """
    pyro.clear_param_store()
    optimizer = optim.Adam({"lr": lr})
    if jit:
        elbo = infer.JitTraceEnum_ELBO(max_plate_nesting=2)
    else:
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)
    svi = infer.SVI(model, guide, optimizer, elbo)
    for step in range(n_steps):
        svi.step(**kwargs)
        # report every 100th step (steps 99, 199, ...)
        if verbose and (step + 1) % 100 == 0:
            snapshot = tuple(
                f"{name}: {value}"
                for name, value in pyro.get_param_store().items()
            )
            print(snapshot)
def main(args):
    """
    Train the selected HMM variant on the JSB chorales dataset with SVI
    (MAP Baum-Welch via an AutoDelta guide over the ``probs_*`` CPTs),
    then report train/test losses and model capacity.
    """
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        handlers.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = handlers.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = handlers.trace(
            handlers.replay(handlers.enum(model, first_available_dim),
                            guide_trace)).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Bind non-PyTorch parameters to make these functions jittable.
    model = functools.partial(model, args=args)
    guide = functools.partial(guide, args=args)

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    # NOTE(review): after the functools.partial rebinding above, ``model`` is a
    # partial object, so ``model is model_0`` below is always False and
    # max_plate_nesting evaluates to 2 for every model -- confirm this is
    # intended for model_0.
    optimizer = optim.Adam({'lr': args.learning_rate})
    if args.tmc:
        if args.jit and not args.funsor:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        Elbo = infer.JitTraceTMC_ELBO if args.jit else infer.TraceTMC_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = handlers.infer_config(model, lambda msg: {
            "num_samples": args.tmc_num_samples,
            "expand": False
        } if msg["infer"].get("enumerate", None) == "parallel" else {})  # noqa: E501
        svi = infer.SVI(tmc_model, guide, optimizer, elbo)
    else:
        Elbo = infer.JitTraceEnum_ELBO if args.jit else infer.TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2,
                    strict_enumeration_warning=True,
                    jit_options={"time_compilation": args.time_compilation})
        svi = infer.SVI(model, guide, optimizer, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug('time to compile: {} s.'.format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model, guide, sequences, lengths,
                           batch_size=sequences.shape[0], include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model, guide, sequences, lengths,
                          batch_size=sequences.shape[0], include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('model_{} capacity = {} parameters'.format(
        args.model, capacity))
def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history,
                           use_replay):
    """
    Compare sequential (unrolled) enumeration against vectorized Markov
    enumeration for a model with two Markov chains ("weeks" over ``vars1``
    and "days" over ``vars2``): per-site log-prob factors must match, and the
    vectorized trace must record the expected Markov step sets and
    measure_vars.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        with handlers.enum():
            enum_model = infer.config_enumerate(model, default="parallel")
            # sequential factors
            trace = handlers.trace(enum_model).get_trace(
                weeks_data, days_data, history, False)

            # vectorized trace
            if use_replay:
                guide_trace = handlers.trace(
                    _guide_from_model(model)).get_trace(
                    weeks_data, days_data, history, True)
                vectorized_trace = handlers.trace(
                    handlers.replay(model, trace=guide_trace)).get_trace(
                    weeks_data, days_data, history, True)
            else:
                vectorized_trace = handlers.trace(enum_model).get_trace(
                    weeks_data, days_data, history, True)

        factors = list()
        # sequential weeks factors
        for i in range(len(weeks_data)):
            for v in vars1:
                factors.append(trace.nodes["{}_{}".format(
                    v, i)]["funsor"]["log_prob"])
        # sequential days factors
        for j in range(len(days_data)):
            for v in vars2:
                factors.append(trace.nodes["{}_{}".format(
                    v, j)]["funsor"]["log_prob"])

        vectorized_factors = list()
        # vectorized weeks factors
        # the first `history` time steps are unrolled individually ...
        for i in range(history):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        # ... and the remaining steps are stored at a single sliced site;
        # substitute concrete time indices into the funsor to recover the
        # per-step factor.
        for i in range(history, len(weeks_data)):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, slice(history, len(weeks_data)))]["funsor"]["log_prob"](**{
                            "weeks": i - history
                        }, **{
                            "{}_{}".format(
                                k, slice(history - j, len(weeks_data) - j)):
                            "{}_{}".format(k, i - j)
                            for j in range(history + 1) for k in vars1
                        }))
        # vectorized days factors (same scheme as weeks above)
        for i in range(history):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        for i in range(history, len(days_data)):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, slice(history, len(days_data)))]["funsor"]["log_prob"](**{
                            "days": i - history
                        }, **{
                            "{}_{}".format(
                                k, slice(history - j, len(days_data) - j)):
                            "{}_{}".format(k, i - j)
                            for j in range(history + 1) for k in vars2
                        }))

        # assert correct factors
        for f1, f2 in zip(factors, vectorized_factors):
            assert_close(f2, f1.align(tuple(f2.inputs)))

        # assert correct step
        expected_measure_vars = frozenset()
        actual_weeks_step = vectorized_trace.nodes["weeks"]["value"]
        # expected step: assume that all but the last var is markov
        expected_weeks_step = frozenset()
        for v in vars1[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                + tuple("{}_{}".format(v, slice(j, len(weeks_data)-history+j))
                        for j in range(history+1))
            expected_weeks_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        actual_days_step = vectorized_trace.nodes["days"]["value"]
        # expected step: assume that all but the last var is markov
        expected_days_step = frozenset()
        for v in vars2[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                + tuple("{}_{}".format(v, slice(j, len(days_data)-history+j))
                        for j in range(history+1))
            expected_days_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        assert actual_weeks_step == expected_weeks_step
        assert actual_days_step == expected_days_step

        # check measure_vars
        actual_measure_vars = terms_from_trace(
            vectorized_trace)["measure_vars"]
        assert actual_measure_vars == expected_measure_vars
def simulate(
    model,
    N: int,
    F: int,
    C: int = 1,
    P: int = 14,
    seed: int = 0,
    params: "dict | None" = None,
) -> CosmosDataset:
    """
    Simulate a new dataset.

    :param model: Tapqir model instance (provides ``K``, ``device``, and the
        ``model`` callable; the previous ``str`` annotation was incorrect).
    :param N: Number of total AOIs. Half will be on-target and half off-target.
    :param F: Number of frames.
    :param C: Number of color channels.
    :param P: Number of pixels alongs the axis.
    :param seed: Rng seed.
    :param params: A dictionary of fixed parameter values.
    :return: A new simulated data set.
    """
    # BUG FIX: replaced the mutable default argument ``params: dict = dict()``
    # with a None sentinel. Behavior is unchanged (the dict was never mutated,
    # and an empty dict still raises KeyError on first access below).
    if params is None:
        params = {}
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()

    # fixed values to condition the generative model on
    samples = {}
    samples["gain"] = torch.full((1, 1), params["gain"])
    samples["lamda"] = torch.full((1, 1), params["lamda"])
    samples["proximity"] = torch.full((1, 1), params["proximity"])
    if "pi" in params:
        # static simulation: a single pi-parameterized spot probability with
        # constant background/width/height across frames
        samples["pi"] = torch.tensor([[1 - params["pi"], params["pi"]]])
        samples["background"] = torch.full((1, N, 1), params["background"])
        for k in range(model.K):
            samples[f"width_{k}"] = torch.full((1, N, F), params["width"])
            samples[f"height_{k}"] = torch.full((1, N, F), params["height"])
    else:
        # kinetic simulations
        samples["init"] = torch.tensor([[
            params["koff"] / (params["kon"] + params["koff"]),
            params["kon"] / (params["kon"] + params["koff"]),
        ]])
        samples["trans"] = torch.tensor([[
            [1 - params["kon"], params["kon"]],
            [params["koff"], 1 - params["koff"]],
        ]])
        for f in range(F):
            samples[f"background_{f}"] = torch.full((1, N, 1), params["background"])
            for k in range(model.K):
                samples[f"width_{k}_{f}"] = torch.full((1, N, 1), params["width"])
                samples[f"height_{k}_{f}"] = torch.full((1, N, 1), params["height"])

    offset = torch.full((3, ), params["offset"])
    target_locs = torch.full((N, F, 1, 2), (P - 1) / 2)
    is_ontarget = torch.zeros((N, ), dtype=torch.bool)
    # first half of AOIs are on-target, second half off-target
    is_ontarget[:N // 2] = True
    # placeholder dataset
    model.data = CosmosDataset(
        torch.full((N, F, C, P, P), params["background"] + params["offset"]),
        target_locs,
        is_ontarget,
        None,
        offset_samples=offset,
        offset_weights=torch.ones(3) / 3,
        device=model.device,
    )

    # sample
    predictive = Predictive(
        handlers.uncondition(model.model),
        posterior_samples=samples,
        num_samples=1,
    )
    samples = predictive()
    data = torch.zeros(N, F, C, P, P)
    # ground-truth labels for the on-target half of the AOIs
    labels = np.zeros(
        (N // 2, F, 1), dtype=[("aoi", int), ("frame", int), ("z", bool)]
    )
    labels["aoi"] = np.arange(N // 2).reshape(-1, 1, 1)
    labels["frame"] = np.arange(F).reshape(-1, 1)
    if "pi" in params:
        data[:, :, 0] = samples["data"][0].data.floor()
        labels["z"][:, :, 0] = samples["theta"][0][:N // 2].cpu() > 0
    else:
        # kinetic simulations: samples are stored per frame
        for f in range(F):
            data[:, f:f + 1, 0] = samples[f"data_{f}"][0].data.floor()
            labels["z"][:, f:f + 1, 0] = samples[f"theta_{f}"][0][:N // 2].cpu() > 0

    return CosmosDataset(
        data.cpu(),
        target_locs.cpu(),
        is_ontarget.cpu(),
        labels,
        offset_samples=offset,
        offset_weights=torch.ones(3) / 3,
        device=model.device,
    )