def _train_full_data(self, x_data, obs2sample, n_epochs=20000, lr=0.002,
                     progressbar=True, random_seed=1):
    """Fit the model to the full data set with stochastic variational inference.

    Builds an ``SVI`` object from ``self.model.forward`` and ``self.guide``
    (``ClippedAdam`` optimiser, ``Trace_ELBO`` loss), initialises it and runs
    ``n_epochs`` update steps on the whole of ``x_data`` (no minibatching).

    Args:
        x_data: Observed data; converted with ``jnp.array`` and moved to the
            default device. Its first dimension defines the ``idx`` array
            passed to the model.
        obs2sample: Array passed through to the model/guide as the
            ``obs2sample`` keyword (converted with ``jnp.array``).
        n_epochs: Number of SVI update steps.
        lr: Step size handed to ``ClippedAdam``.
        progressbar: If true, run a Python loop over a jit-compiled update
            step, refreshing a tqdm bar and appending the loss to
            ``self.hist`` after every step. If false, run all steps in a
            single ``lax.scan`` call and replace ``self.hist`` with the
            resulting array of losses.
        random_seed: Seed for the ``PRNGKey`` used to initialise SVI.

    Side effects:
        Sets ``self.state`` (final SVI state) and ``self.state_param`` (copy
        of the parameters extracted from that state), and updates
        ``self.hist`` as described above.

    Note:
        With ``progressbar=True`` this method appends to ``self.hist``, so
        that attribute must already exist and support ``append`` —
        presumably initialised by the caller/constructor; verify.
    """
    idx = np.arange(x_data.shape[0]).astype("int64")

    # move data to default device
    x_data = device_put(jnp.array(x_data))
    extra_data = {
        'idx': device_put(jnp.array(idx)),
        'obs2sample': device_put(jnp.array(obs2sample)),
    }

    # initialise SVI inference method
    svi = SVI(
        self.model.forward,
        self.guide,
        # limit the gradient step from becoming too large
        optim.ClippedAdam(clip_norm=jnp.array(200), step_size=jnp.array(lr)),
        loss=Trace_ELBO(),
    )
    init_state = svi.init(random.PRNGKey(random_seed), x_data=x_data, **extra_data)
    self.state = init_state

    def step(svi_state):
        # Single full-data update. Both training modes share this closure so
        # they always optimise against the same device-resident `x_data` that
        # was used for `svi.init` (the scan path previously read `self.x_data`).
        return svi.update(svi_state, x_data=x_data, **extra_data)

    if not progressbar:
        # Training in one step: the whole optimisation is a single lax.scan,
        # so the bar has exactly one tick and only reports the final loss.
        epochs_iterator = tqdm(range(1))
        for _ in epochs_iterator:
            # TODO for minibatch DataLoader goes here
            state, losses = lax.scan(
                lambda svi_state, _i: step(svi_state),
                init_state,
                jnp.arange(n_epochs),
            )
            epochs_iterator.set_description(f'ELBO Loss: {losses[-1]:.4e}')

        self.state = state
        self.hist = losses
    else:
        # training using for-loop
        # TODO figure out minibatch static_argnums
        # https://github.com/pyro-ppl/numpyro/issues/869
        jit_step_update = jit(step)

        # very slow: formatting the loss every step forces a host sync
        epochs_iterator = tqdm(range(n_epochs))
        for _ in epochs_iterator:
            self.state, loss = jit_step_update(self.state)
            self.hist.append(loss)
            epochs_iterator.set_description(f'ELBO Loss: {loss:.4e}')

    self.state_param = svi.get_params(self.state).copy()