def __init__(self, model, data, covariates=None, *,
             num_warmup=1000, num_samples=1000, num_chains=1,
             dense_mass=False, jit_compile=False, max_tree_depth=10):
    assert data.size(-2) == covariates.size(-2)
    super().__init__()
    self.model = model
    max_plate_nesting = _guess_max_plate_nesting(model, (data, covariates), {})
    self.max_plate_nesting = max(max_plate_nesting, 1)  # force a time plate

    kernel = NUTS(model, full_mass=dense_mass, jit_compile=jit_compile,
                  ignore_jit_warnings=True, max_tree_depth=max_tree_depth,
                  max_plate_nesting=max_plate_nesting)
    mcmc = MCMC(kernel, warmup_steps=num_warmup, num_samples=num_samples,
                num_chains=num_chains)
    mcmc.run(data, covariates)
    # conditions under which MCMC can compute r_hat diagnostics
    if (num_chains == 1 and num_samples >= 4) or (num_chains > 1 and num_samples >= 2):
        mcmc.summary()

    # Inspect the model with a particles plate of size 1, so that we can reshape
    # samples to add any missing plate dim in front.
    with poutine.trace() as tr:
        with pyro.plate("particles", 1, dim=-self.max_plate_nesting - 1):
            model(data, covariates)
    self._trace = tr.trace
    self._samples = mcmc.get_samples()
    self._num_samples = num_samples * num_chains
    for name, node in list(self._trace.nodes.items()):
        if name not in self._samples:
            del self._trace.nodes[name]
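# For orientation, a minimal usage sketch of this constructor. It assumes the
# surrounding class is a forecaster along the lines of Pyro's HMCForecaster;
# `toy_model`, the tensor shapes, and the class name below are illustrative
# assumptions, not taken from this module.
#
#     import torch
#     import pyro
#     import pyro.distributions as dist
#
#     def toy_model(data, covariates):
#         # data: (T, obs_dim), covariates: (T, cov_dim); the constructor
#         # asserts both share the same length T along dim -2.
#         coef = pyro.sample(
#             "coef", dist.Normal(torch.zeros(covariates.size(-1)), 1.0).to_event(1))
#         sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
#         loc = (covariates * coef).sum(-1, keepdim=True)   # (T, 1)
#         pyro.sample("obs", dist.Normal(loc, sigma).to_event(2), obs=data)
#
#     data = torch.randn(100, 1)
#     covariates = torch.randn(100, 3)
#     forecaster = HMCForecaster(toy_model, data, covariates,
#                                num_warmup=200, num_samples=200)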
# Assumes module-level imports of time, torch, pyro, defaultdict (collections),
# and SVI / Trace_ELBO / TraceEnum_ELBO (pyro.infer), plus the local helpers
# train_test_split, p_value_of_slope, and _guess_max_plate_nesting.
def inference(Model, training_data, test_data=None, config=None):
    '''
    A wrapper around Pyro's SVI loop with settings given in `config`.
    Records telemetry during training, including the ELBO loss, the mean negative
    log likelihood (MNLL) on held-out data, gradient norms, and parameter history.

    If `config` includes telemetry from a previous inference run, inference
    continues from that run. If `slope_significance` is set to a value less than 1,
    training halts early when the MNLL converges: a line is fit to the MNLL values
    in a moving window of size `convergence_window`, and convergence is declared
    once the fitted slope is no longer significantly different from zero at level
    `slope_significance`.

    The default config is

        config = dict(
            n_iter=1000,
            learning_rate=0.1,
            beta1=0.9,
            beta2=0.999,
            learning_rate_decay=1.,  # no decay by default
            batch_size=32,
            n_elbo_particles=32,
            n_posterior_samples=1024,
            window=500,
            convergence_window=30,
            slope_significance=0.1,
            track_params=False,
            monitor_gradients=False,
            telemetry=None,
        )

    Example (Model is any object exposing `.model`, `.guide`,
    `.mnll(data, n_samples)` and a writable `.batch_size`):

        telemetry = inference(Model, training_data)
    '''
    if config is None:
        config = dict(
            n_iter=1000,
            learning_rate=0.1,
            beta1=0.9,
            beta2=0.999,
            learning_rate_decay=1.,  # no decay by default
            batch_size=32,
            n_elbo_particles=32,
            n_posterior_samples=1024,
            window=500,
            convergence_window=30,
            slope_significance=0.1,
            track_params=False,
            monitor_gradients=False,
            telemetry=None,
        )
    if test_data is None:
        training_data, test_data = train_test_split(training_data)

    # Per-parameter optimizer settings could instead be supplied via a callable, e.g.
    # def per_param_callable(module_name, param_name):
    #     return {"lr": config['learning_rate'], "betas": (0.90, 0.999)}
    # (see http://pyro.ai/examples/svi_part_i.html)
    model = Model.model
    guide = Model.guide
    optim = pyro.optim.Adam({"lr": config['learning_rate'],
                             "betas": (config['beta1'], config['beta2'])})

    if config['telemetry']:
        # There is telemetry in the config from an interrupted inference run:
        # restore the state of that run and continue training.
        pyro.clear_param_store()
        print('Continuing from previous inference run.')
        telemetry = config['telemetry']
        optim.set_state(telemetry['optimizer_state'])
        pyro.get_param_store().set_state(telemetry['param_store_state'])
        i = len(telemetry['loss'])
        config['n_iter'] += i
        # Initialize any params that are not yet in the telemetry.
        model(training_data)
        guide(training_data)
        for k, v in pyro.get_param_store().items():
            if k not in telemetry['param_history'].keys():
                telemetry['param_history'][k] = v.unsqueeze(0)
    else:
        pyro.clear_param_store()
        telemetry = dict()
        telemetry['gradient_norms'] = defaultdict(list)
        telemetry['loss'] = []
        telemetry['MNLL'] = []
        telemetry['training_duration'] = 0
        # Call model and guide once to populate the param store.
        Model.batch_size = config['batch_size']
        model(training_data)
        guide(training_data)
        # Record the initial parameter values in param_history.
        telemetry['param_history'] = dict(
            {k: v.unsqueeze(0) for k, v in pyro.get_param_store().items()})
        # Record the MNLL at initialization.
        i = 0
        with torch.no_grad():
            telemetry['MNLL'].append(-Model.mnll(test_data, config['n_posterior_samples']))
        print('\n')
        print("NLL after {}/{} iterations is {}".format(
            i, config['n_iter'], telemetry['MNLL'][-1]))

    # Learning rate schedulers: no way found yet to get and set their state for
    # checkpointing, so a plain Adam optimizer is used instead of e.g.
    # scheduler = pyro.optim.ExponentialLR({'optimizer': torch.optim.Adam,
    #     'optim_args': {"lr": config['learning_rate'],
    #                    "betas": (config['beta1'], config['beta2'])},
    #     'gamma': config['learning_rate_decay']})

    max_plate_nesting = _guess_max_plate_nesting(model, (training_data,), {})

    # Look for sample sites with infer={'enumerate': 'parallel'}; if any exist,
    # use TraceEnum_ELBO so discrete latents are enumerated out.
    trace = pyro.poutine.trace(model).get_trace(training_data)
    contains_enumeration = any(
        values['infer'] == {'enumerate': 'parallel'}
        for node, values in trace.nodes.items() if 'infer' in values)
    if contains_enumeration:
        elbo = TraceEnum_ELBO(max_plate_nesting=max_plate_nesting,
                              num_particles=config['n_elbo_particles'],
                              vectorize_particles=True)
    else:
        elbo = Trace_ELBO(max_plate_nesting=max_plate_nesting,
                          num_particles=config['n_elbo_particles'],
                          vectorize_particles=True)
    svi = SVI(model, guide, optim, loss=elbo)

    if config['monitor_gradients']:
        # Register gradient hooks for monitoring.
        for name, value in pyro.get_param_store().named_parameters():
            value.register_hook(
                lambda g, name=name: telemetry['gradient_norms'][name].append(g.norm().item()))

    start = time.time()
    while (p_value_of_slope(telemetry['MNLL'], config['convergence_window'],
                            config['slope_significance']) < config['slope_significance']
           and i < config['n_iter']):
        try:
            loss = svi.step(training_data)
            telemetry['loss'].append(loss)
            if i % config['window'] or i <= config['window']:
                # Only evaluate the held-out MNLL every `window` iterations.
                print('.', end='')
            else:
                with torch.no_grad():
                    telemetry['MNLL'].append(-Model.mnll(test_data, config['n_posterior_samples']))
                print('\n')
                print("NLL after {}/{} iterations is {}".format(
                    i, config['n_iter'], telemetry['MNLL'][-1]))
                print('\n')
            if config['track_params']:
                telemetry['param_history'] = {
                    k: torch.cat([telemetry['param_history'][k], v.unsqueeze(0).detach()], dim=0)
                    for k, v in pyro.get_param_store().items()}
            i += 1
        # except RuntimeError as e:
        #     print(e)
        #     print("There was a runtime error.")
        #     return telemetry
        except KeyboardInterrupt:
            print('\nInterrupted by user after {} iterations.\n'.format(i))
            params = {k: v.detach() for k, v in pyro.get_param_store().items()}
            Model.params = params
            telemetry['training_duration'] += round(time.time() - start)
            telemetry['optimizer_state'] = optim.get_state()
            telemetry['param_store_state'] = pyro.get_param_store().get_state()
            return telemetry

    print('\nConverged in {} iterations.\n'.format(i))
    # (Tensors could be converted to numpy arrays here to consume less disk space.)
    params = {k: v.detach() for k, v in pyro.get_param_store().items()}
    Model.params = params
    telemetry['training_duration'] += round(time.time() - start)
    telemetry['optimizer_state'] = optim.get_state()
    telemetry['param_store_state'] = pyro.get_param_store().get_state()
    return telemetry
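# `p_value_of_slope` is called in the training loop above but is not defined in
# this excerpt. Below is a minimal sketch of what such a helper might look like,
# assuming it fits an ordinary least-squares line to the last `window` values and
# returns the two-sided p-value for the null hypothesis of zero slope (scipy is
# an assumed dependency; the project's actual helper may differ).
import numpy as np
from scipy import stats

def p_value_of_slope(values, window, significance=0.1):
    # `significance` is accepted only to match the call site above; this sketch
    # does not use it and simply reports the p-value.
    if len(values) < window:
        # Too few points yet: report p = 0 so the training loop keeps iterating.
        return 0.0
    y = np.asarray(values[-window:], dtype=float)
    x = np.arange(window, dtype=float)
    return stats.linregress(x, y).pvalue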