Example #1
0
def create_job(kwargs):
    """Fit one rARHMM on a random 80% subset of the trajectories and score it.

    The trajectories in ``kwargs['obs']`` / ``kwargs['act']`` are split by
    index: ``int(0.8 * len(obs))`` indices are drawn without replacement
    for training, the remaining trajectories are held out.

    Args:
        kwargs: mapping that must contain the keys ``nb_states``,
            ``trans_type``, ``obs_prior``, ``trans_prior``, ``trans_kwargs``
            (forwarded to the ``rARHMM`` constructor), ``obs``, ``act``
            (indexable sequences of per-trajectory arrays), ``prec``,
            ``nb_iter``, ``obs_mstep_kwargs``, ``trans_mstep_kwargs`` and
            ``process_id`` (forwarded to ``rARHMM.em``). The mapping is only
            read; it is not modified, so it can be reused by the caller.

    Returns:
        Tuple ``(rarhmm, all_ll, score)`` where ``rarhmm`` is the fitted
        model, ``all_ll`` is ``rarhmm.log_norm`` evaluated on all
        trajectories and ``score`` is ``all_ll`` minus the same quantity on
        the training trajectories, divided by the number of held-out rows
        (rows counted after stacking the trajectories with ``np.vstack``).
    """
    # model arguments
    nb_states = kwargs['nb_states']
    trans_type = kwargs['trans_type']
    obs_prior = kwargs['obs_prior']
    trans_prior = kwargs['trans_prior']
    trans_kwargs = kwargs['trans_kwargs']

    # em arguments
    obs = kwargs['obs']
    act = kwargs['act']
    prec = kwargs['prec']
    nb_iter = kwargs['nb_iter']
    obs_mstep_kwargs = kwargs['obs_mstep_kwargs']
    trans_mstep_kwargs = kwargs['trans_mstep_kwargs']

    process_id = kwargs['process_id']

    # Same RNG call as before, so seeded runs reproduce the same split.
    nb_traj = len(obs)
    train_idx = npr.choice(a=nb_traj, size=int(0.8 * nb_traj), replace=False)

    # Keep the training trajectories in their original (ascending) order.
    train_order = sorted(train_idx.tolist())
    train_obs = [obs[i] for i in train_order]
    train_act = [act[i] for i in train_order]

    dm_obs = train_obs[0].shape[-1]
    dm_act = train_act[0].shape[-1]

    rarhmm = rARHMM(nb_states,
                    dm_obs,
                    dm_act,
                    trans_type=trans_type,
                    obs_prior=obs_prior,
                    trans_prior=trans_prior,
                    trans_kwargs=trans_kwargs)
    rarhmm.initialize(train_obs, train_act)

    rarhmm.em(train_obs,
              train_act,
              nb_iter=nb_iter,
              prec=prec,
              obs_mstep_kwargs=obs_mstep_kwargs,
              trans_mstep_kwargs=trans_mstep_kwargs,
              process_id=process_id)

    # Row counts (time steps summed over trajectories) for normalisation.
    nb_train = np.vstack(train_obs).shape[0]
    nb_all = np.vstack(obs).shape[0]

    train_ll = rarhmm.log_norm(train_obs, train_act)
    all_ll = rarhmm.log_norm(obs, act)

    # Held-out contribution is obtained by difference rather than by
    # evaluating the model on the held-out trajectories directly.
    score = (all_ll - train_ll) / (nb_all - nb_train)

    return rarhmm, all_ll, score
    def fit_rarhmm_job(args):
        """Build an rARHMM from an options mapping and fit it with EM.

        Args:
            args: triple ``(obs, act, options)``. ``options`` is indexed
                with the keys ``nb_states``, ``dm_obs``, ``dm_act``,
                ``obs_prior``, ``trans_type``, ``trans_prior``,
                ``trans_kwargs``, ``initialize``, ``obs_mstep_kwargs`` and
                ``trans_mstep_kwargs``.

        Returns:
            The ``rARHMM`` instance after ``em`` has been run on it.
        """
        obs, act, options = args

        # Positional constructor arguments, then the keyword ones.
        sizes = [options[key] for key in ('nb_states', 'dm_obs', 'dm_act')]
        model_settings = {
            key: options[key]
            for key in ('obs_prior', 'trans_type',
                        'trans_prior', 'trans_kwargs')
        }
        rarhmm = rARHMM(*sizes, **model_settings)

        # Data-driven initialisation is optional and controlled by the caller.
        if options['initialize']:
            rarhmm.initialize(obs, act)

        # Iteration budget, tolerance and verbosity are fixed for this job.
        em_settings = dict(obs=obs, act=act,
                           nb_iter=100, prec=1e-4, verbose=True)
        em_settings['obs_mstep_kwargs'] = options['obs_mstep_kwargs']
        em_settings['trans_mstep_kwargs'] = options['trans_mstep_kwargs']
        rarhmm.em(**em_settings)

        return rarhmm
Example #3
0
    # NOTE(review): fragment of a larger script -- dm_obs, dm_act, obs, act
    # and nb_rollouts are defined above the visible part.

    # Number of states handed to the rARHMM constructor.
    nb_states = 2

    # Settings forwarded as `obs_prior`; the key names suggest prior
    # hyperparameters of the observation model -- confirm against rARHMM.
    obs_prior = {'mu0': 0., 'sigma0': 1e64, 'nu0': (dm_obs + 1) * 23, 'psi0': 1e-16 * 23}
    obs_mstep_kwargs = {'use_prior': True}

    # Transition model selection and its settings.
    trans_type = 'neural'
    trans_prior = {'l2_penalty': 1e-32, 'alpha': 1, 'kappa': 50}
    # 'norm' holds three-entry mean/std arrays (zeros / ones here);
    # presumably one entry per transition-model input -- TODO confirm.
    trans_kwargs = {'hidden_layer_sizes': (16,),
                    'norm': {'mean': np.array([0., 0., 0.]),
                             'std': np.array([1., 1., 1.])}}
    trans_mstep_kwargs = {'nb_iter': 25, 'batch_size': 256, 'lr': 1e-3}

    rarhmm = rARHMM(nb_states, dm_obs, dm_act,
                    trans_type=trans_type,
                    obs_prior=obs_prior,
                    trans_prior=trans_prior,
                    trans_kwargs=trans_kwargs)
    rarhmm.initialize(obs, act)

    # The value returned by em() is plotted below as a curve.
    lls = rarhmm.em(obs, act, nb_iter=100, prec=1e-4, verbose=True,
                    obs_mstep_kwargs=obs_mstep_kwargs,
                    trans_mstep_kwargs=trans_mstep_kwargs)

    plt.figure(figsize=(5, 5))
    plt.plot(lls)
    plt.show()

    plt.figure(figsize=(8, 8))
    # Random draw based on nb_rollouts; idx is used past the visible part.
    idx = npr.choice(nb_rollouts)
    # viterbi() returns two values; only the second one is kept.
    _, state = rarhmm.viterbi(obs, act)
Example #4
0
import torch

# Seed both random number generators so the experiment is reproducible.
# (npr.randn(1337) would only draw and discard 1337 samples, not seed.)
npr.seed(1337)
torch.manual_seed(1337)

sns.set_style("white")
sns.set_context("talk")

color_names = [
    "windows blue", "red", "amber", "faded green", "dusty purple", "orange"
]

colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)

# Model used to generate the synthetic data.
true_rarhmm = rARHMM(nb_states=3, dm_obs=2, trans_type='poly')

# trajectory lengths
T = [1250, 1150, 1025]

true_z, x = true_rarhmm.sample(horizon=T)
true_ll = true_rarhmm.log_norm(x)

# Second model with the same configuration (plus preprocess=True), fitted
# to the sampled observations.
rarhmm = rARHMM(nb_states=3, dm_obs=2, trans_type='poly', preprocess=True)
rarhmm.initialize(x)

# NOTE(review): prec=0. presumably disables early stopping so that all
# 100 iterations run -- confirm against rARHMM.em.
lls = rarhmm.em(x, nb_iter=100, prec=0.)
print("true_ll=", true_ll, "hmm_ll=", lls[-1])

plt.figure(figsize=(5, 5))
# Horizontal red reference line at the generating model's value.
plt.plot(np.ones(len(lls)) * true_ll, '-r')