# 예제 #1 (Example #1)
    def _train_full_data(self,
                         x_data,
                         obs2sample,
                         n_epochs=20000,
                         lr=0.002,
                         progressbar=True,
                         random_seed=1):
        """Fit the model to the full dataset with SVI (no minibatching).

        Parameters
        ----------
        x_data : array-like
            Full observation matrix; moved to the default JAX device.
        obs2sample : array-like
            Observation-to-sample design matrix forwarded to the
            model/guide alongside ``x_data``.
        n_epochs : int
            Number of SVI update steps.
        lr : float
            Adam step size (gradients are also norm-clipped at 200).
        progressbar : bool
            If True, run a Python loop with a per-epoch tqdm bar
            (slower, one dispatch per epoch); if False, run all epochs
            inside a single compiled ``lax.scan`` (faster, loss shown
            only once at the end).
        random_seed : int
            Seed for the PRNG key used to initialise SVI.

        Side effects: sets ``self.state``, ``self.hist`` and
        ``self.state_param``.
        """

        idx = np.arange(x_data.shape[0]).astype("int64")

        # move data to default device
        x_data = device_put(jnp.array(x_data))
        extra_data = {
            'idx': device_put(jnp.array(idx)),
            'obs2sample': device_put(jnp.array(obs2sample))
        }

        # initialise SVI inference method
        svi = SVI(
            self.model.forward,
            self.guide,
            # limit the gradient step from becoming too large
            optim.ClippedAdam(clip_norm=jnp.array(200),
                              step_size=jnp.array(lr)),
            loss=Trace_ELBO())
        init_state = svi.init(random.PRNGKey(random_seed),
                              x_data=x_data,
                              **extra_data)
        self.state = init_state

        if not progressbar:
            # All epochs run inside one compiled scan; the single-step
            # tqdm wrapper exists only to print the final loss line.
            epochs_iterator = tqdm(range(1))
            for e in epochs_iterator:
                state, losses = lax.scan(
                    # BUGFIX: was `self.x_data`, an attribute never set in
                    # this method — use the local device-resident array,
                    # matching the jit branch below.
                    lambda state_1, i: svi.update(
                        state_1, x_data=x_data, **extra_data),
                    # TODO for minibatch DataLoader goes here
                    init_state,
                    jnp.arange(n_epochs))
                epochs_iterator.set_description(
                    'ELBO Loss: ' + '{:.4e}'.format(losses[-1]))

            self.state = state
            self.hist = losses

        else:
            # training using for-loop

            jit_step_update = jit(lambda state_1: svi.update(
                state_1, x_data=x_data, **extra_data))
            # TODO figure out minibatch static_argnums https://github.com/pyro-ppl/numpyro/issues/869

            # Much slower than the scan branch: one dispatch per epoch.
            # NOTE(review): assumes `self.hist` was initialised elsewhere
            # (e.g. in __init__) — the scan branch *replaces* it instead
            # of appending; confirm the intended semantics.
            epochs_iterator = tqdm(range(n_epochs))
            for e in epochs_iterator:
                self.state, loss = jit_step_update(self.state)
                self.hist.append(loss)
                epochs_iterator.set_description('ELBO Loss: ' +
                                                '{:.4e}'.format(loss))

        self.state_param = svi.get_params(self.state).copy()