Example No. 1
    def __init__(self, model, data, covariates=None, *,
                 num_warmup=1000, num_samples=1000, num_chains=1,
                 dense_mass=False, jit_compile=False, max_tree_depth=10):
        assert data.size(-2) == covariates.size(-2)
        super().__init__()
        self.model = model
        max_plate_nesting = _guess_max_plate_nesting(model, (data, covariates), {})
        self.max_plate_nesting = max(max_plate_nesting, 1)  # force a time plate

        kernel = NUTS(model, full_mass=dense_mass, jit_compile=jit_compile, ignore_jit_warnings=True,
                      max_tree_depth=max_tree_depth, max_plate_nesting=max_plate_nesting)
        mcmc = MCMC(kernel, warmup_steps=num_warmup, num_samples=num_samples, num_chains=num_chains)
        mcmc.run(data, covariates)
        # only print the summary (which includes r_hat) when enough samples/chains are available to compute it
        if (num_chains == 1 and num_samples >= 4) or (num_chains > 1 and num_samples >= 2):
            mcmc.summary()

        # inspect the model with a particles plate of size 1, so that we can reshape
        # samples to add any missing plate dims in front.
        with poutine.trace() as tr:
            with pyro.plate("particles", 1, dim=-self.max_plate_nesting - 1):
                model(data, covariates)

        self._trace = tr.trace
        self._samples = mcmc.get_samples()
        self._num_samples = num_samples * num_chains
        for name, node in list(self._trace.nodes.items()):
            if name not in self._samples:
                del self._trace.nodes[name]
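
Both examples call a helper named _guess_max_plate_nesting that is not shown in the listing. Below is a minimal sketch of what such a helper could look like, assuming it only needs to trace the model once and inspect the vectorized plates around its sample sites; the actual helper shipped with Pyro (in pyro.contrib.forecast) may differ in its details.

import pyro.poutine as poutine

def _guess_max_plate_nesting(model, args, kwargs):
    # Hypothetical sketch, not the library implementation.
    # Trace one execution of the model; block() shields any enclosing handlers
    # from the sites generated by this throwaway run.
    with poutine.block(), poutine.trace() as tr:
        model(*args, **kwargs)
    # Collect the (negative) dims of every vectorized plate enclosing a sample site.
    dims = [frame.dim
            for site in tr.trace.nodes.values()
            if site["type"] == "sample"
            for frame in site["cond_indep_stack"]
            if frame.vectorized]
    # Plates count dims from the right, so the deepest nesting level is -min(dims).
    return -min(dims) if dims else 0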
Example No. 2
def inference(Model, training_data, test_data=None, config=None):
    '''
    A wrapper function that runs Pyro's SVI steps with the settings given in config.
    Records telemetry during training, including the ELBO loss, the mean negative log likelihood (MNLL) on held-out data, gradient norms, and the parameter history.
    If config includes telemetry from a previous inference run, inference continues from that run.
    If slope_significance is set to a value less than 1, training halts early once the MNLL converges.
    Convergence is assessed by linear regression over a moving window of the last convergence_window MNLL values: training continues while the p-value of the estimated slope under the null hypothesis of zero slope is below slope_significance, and stops once the slope is no longer significantly different from zero.

    The default config is:
    config = dict(
            n_iter = 1000,
            learning_rate = 0.1, 
            beta1 = 0.9,
            beta2 = 0.999,
            learning_rate_decay = 1., # no decay by default
            batch_size = 32, 
            n_elbo_particles = 32, 
            n_posterior_samples = 1024,
            window = 500,
            convergence_window = 30,
            slope_significance = 0.1,
            track_params = False,
            monitor_gradients = False,
            telemetry = None
        )

    Example: 

    '''
    #initcopy = clone_init(init)
    if config is None:
        config = dict(
                n_iter = 1000,
                learning_rate = 0.1, 
                beta1 = 0.9,
                beta2 = 0.999,
                learning_rate_decay = 1., # no decay by default
                batch_size = 32, 
                n_elbo_particles = 32, 
                n_posterior_samples = 1024,
                window = 500,
                convergence_window = 30,
                slope_significance = 0.1,
                track_params = False,
                monitor_gradients = False,
                telemetry = None,
            )

    if test_data is None:
        training_data, test_data = train_test_split(training_data)
    #def per_param_callable(module_name, param_name):
    #    return {"lr": config['learning_rate'], "betas": (0.90, 0.999)} # from http://pyro.ai/examples/svi_part_i.html
    model = Model.model
    guide = Model.guide

    optim = pyro.optim.Adam({"lr": config['learning_rate'], "betas": (config['beta1'], config['beta2'])})
    
    # if there is previous telemetry in the config from an interrupted inference run
    # restore the state of that inference and continue training
    if config['telemetry']:
        pyro.clear_param_store()
        print('Continuing from previous inference run.')
        telemetry = config['telemetry']
        optim.set_state(telemetry['optimizer_state'])
        pyro.get_param_store().set_state(telemetry['param_store_state'])
        i = len(telemetry['loss'])
        config['n_iter'] += i
        # run the model and guide once to initialize any params missing from the saved telemetry
        model(training_data)
        guide(training_data)
        for k,v in pyro.get_param_store().items():
            if k not in telemetry['param_history'].keys():
                telemetry['param_history'][k] = v.unsqueeze(0)
    else:
        pyro.clear_param_store()
        telemetry = dict()
        telemetry['gradient_norms'] = defaultdict(list)
        telemetry['loss'] = []
        telemetry['MNLL'] = []
        telemetry['training_duration'] = 0
        # call model and guide to populate param store
        #model(training_data, config['batch_size'], init)
        #guide(training_data, config['batch_size'], init)
        Model.batch_size = config['batch_size']
        model(training_data)
        guide(training_data)
        # record init in param_history
        telemetry['param_history'] = dict({k:v.unsqueeze(0) for k,v in pyro.get_param_store().items()})
        # record MNLL at init
        i = 0
        with torch.no_grad():
            #mnll = compute_mnll(model, guide, test_data, n_samples=config['n_posterior_samples'])
            telemetry['MNLL'].append(-Model.mnll(test_data, config['n_posterior_samples']))
            print('\n')
            print("NLL after {}/{} iterations is {}".format(i,config['n_iter'], telemetry['MNLL'][-1]))

    # Learning rate schedulers
    # Haven't found a way to get and set its state for checkpointing
    #optim = torch.optim.Adam
    #scheduler = pyro.optim.ExponentialLR({'optimizer': optim, 'optim_args': {"lr": config['learning_rate'], "betas": (beta1, beta2)}, 'gamma': config['learning_rate_decay']})
    #scheduler = pyro.optim.ExponentialLR({'optimizer': optim, 'optim_args': per_param_callable, 'gamma': config['learning_rate_decay']})
    
    max_plate_nesting = _guess_max_plate_nesting(model, (training_data,), {})
    #print("Guessed that model has max {} nested plates.".format(max_plate_nesting)) 

    # look for sample sites marked for parallel enumeration (infer={'enumerate': 'parallel'})
    trace = pyro.poutine.trace(model).get_trace(training_data)
    contains_enumeration = any(values['infer'] == {'enumerate': 'parallel'}
                               for node, values in trace.nodes.items() if 'infer' in values)

    if contains_enumeration:
        elbo = TraceEnum_ELBO(max_plate_nesting=max_plate_nesting, num_particles=config['n_elbo_particles'], vectorize_particles=True)
    else:
        elbo = Trace_ELBO(max_plate_nesting=max_plate_nesting, num_particles=config['n_elbo_particles'], vectorize_particles=True)
    #svi = SVI(model, guide, scheduler, loss=elbo)
    svi = SVI(model, guide, optim, loss=elbo)

    if config['monitor_gradients']:
        # register gradient hooks for monitoring
        for name, value in pyro.get_param_store().named_parameters():
            value.register_hook(lambda g, name=name: telemetry['gradient_norms'][name].append(g.norm().item()))
    start = time.time()
    
#    with torch.no_grad():
#        mnll = compute_mnll(model, guide, test_data, init, n_samples=config['n_posterior_samples'])
#        telemetry['MNLL'].append(-mnll)
#        print("NLL at init is {}".format(mnlls[-1]))

    while (p_value_of_slope(telemetry['MNLL'], config['convergence_window'], config['slope_significance'])
           < config['slope_significance'] and i < config['n_iter']):
        try:
            loss = svi.step(training_data)
            telemetry['loss'].append(loss)
            # print a progress dot; at every window-th iteration beyond the first
            # config['window'] iterations, evaluate the held-out MNLL instead
            if i % config['window'] != 0 or i <= config['window']:
                print('.', end='')
                #scheduler.step()
            else:
                with torch.no_grad():
                    #mnll = compute_mnll(model, guide, test_data, n_samples=config['n_posterior_samples'])
                    telemetry['MNLL'].append(-Model.mnll(test_data, config['n_posterior_samples']))
                    print('\n')
                    print("NLL after {}/{} iterations is {}".format(i,config['n_iter'], telemetry['MNLL'][-1]))
                print('\n')
                #print('\nSetting number of posterior samples to {}'.format(config['n_posterior_samples']), end='')
                #print('\nSetting batch size to {}'.format(config['batch_size']), end='')
            if config['track_params']:
                telemetry['param_history'] = {k:torch.cat([telemetry['param_history'][k],v.unsqueeze(0).detach()],dim=0) for k,v in pyro.get_param_store().items()}
            i += 1
#        except RuntimeError as e:
#            print(e)
#            print("There was a runtime error.")
#            return telemetry
        except KeyboardInterrupt:
            print('\nInterrupted by user after {} iterations.\n'.format(i))
            params = {k:v.detach() for k,v in pyro.get_param_store().items()}
            Model.params = params
            telemetry['training_duration'] += round(time.time() - start)
            telemetry['optimizer_state'] = optim.get_state()
            telemetry['param_store_state'] = pyro.get_param_store().get_state()
            return telemetry
    print('\nStopped after {} iterations.\n'.format(i))

    # make all pytorch tensors into np arrays, which consume less disk space
    #param_history = dict(zip(param_history.keys(),map(lambda x: x.detach().numpy(), param_history.values())))
    
    params = {k:v.detach() for k,v in pyro.get_param_store().items()}
    Model.params = params
    telemetry['training_duration'] += round(time.time() - start)
    telemetry['optimizer_state'] = optim.get_state()
    telemetry['param_store_state'] = pyro.get_param_store().get_state()
    return telemetry
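
The training loop above relies on a p_value_of_slope helper that is also not shown. The sketch below illustrates one way such a convergence test could work, assuming scipy is available: fit an ordinary least-squares line to the most recent convergence_window MNLL values and return the two-sided p-value for the slope under the null hypothesis of zero slope. Returning 0.0 while the window is not yet full keeps the loop running; the third argument is accepted only to mirror the call site.

from scipy import stats

def p_value_of_slope(values, window, significance=None):
    # Hypothetical sketch, not necessarily the author's implementation.
    # Too few recorded MNLL values yet: report 0.0 so the caller keeps training.
    if len(values) < window:
        return 0.0
    recent = values[-window:]
    # Regress the recent values against their index and test slope == 0.
    result = stats.linregress(range(window), recent)
    return result.pvalue

With a definition like this, the while loop stops as soon as the slope over the last convergence_window MNLL evaluations is no longer significantly different from zero, or when n_iter is reached.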