def create_job(kwargs):
    """Fit an rARHMM on a random 80% of the trajectories and score it on the rest.

    Args:
        kwargs: dict of job arguments. Model keys: ``nb_states``,
            ``trans_type``, ``obs_prior``, ``trans_prior``, ``trans_kwargs``.
            EM keys: ``obs``, ``act`` (indexable sequences, one entry per
            trajectory, indexed in parallel), ``prec``, ``nb_iter``,
            ``obs_mstep_kwargs``, ``trans_mstep_kwargs``, ``process_id``.
            The caller's dict is left unmodified.

    Returns:
        tuple ``(rarhmm, all_ll, score)``: the fitted model, ``log_norm`` over
        all trajectories, and ``(all_ll - train_ll)`` divided by the number of
        stacked rows that are not in the training split.

    Raises:
        KeyError: if a required argument is missing from ``kwargs``.
        ValueError: if ``obs`` holds fewer than two trajectories, so the 80%
            split would leave no training data.
    """
    # Pop from a shallow copy so a caller reusing the same dict across jobs
    # does not hit KeyError on the second call.
    kwargs = dict(kwargs)

    # model arguments
    nb_states = kwargs.pop('nb_states')
    trans_type = kwargs.pop('trans_type')
    obs_prior = kwargs.pop('obs_prior')
    trans_prior = kwargs.pop('trans_prior')
    trans_kwargs = kwargs.pop('trans_kwargs')

    # em arguments
    obs = kwargs.pop('obs')
    act = kwargs.pop('act')
    prec = kwargs.pop('prec')
    nb_iter = kwargs.pop('nb_iter')
    obs_mstep_kwargs = kwargs.pop('obs_mstep_kwargs')
    trans_mstep_kwargs = kwargs.pop('trans_mstep_kwargs')
    process_id = kwargs.pop('process_id')

    # Random 80/20 split over whole trajectories (sampled without replacement).
    nb_traj = len(obs)
    nb_train_traj = int(0.8 * nb_traj)
    if nb_train_traj == 0:
        raise ValueError(
            "create_job needs at least 2 trajectories to build a training "
            "split, got {}".format(nb_traj))

    # Sorting the sampled indices keeps training trajectories in their
    # original order and avoids an O(n) array-membership test per trajectory.
    train_idx = np.sort(npr.choice(a=nb_traj, size=nb_train_traj, replace=False))
    train_obs = [obs[i] for i in train_idx]
    train_act = [act[i] for i in train_idx]

    # Dimensions are read from the last axis of the first training trajectory.
    dm_obs = train_obs[0].shape[-1]
    dm_act = train_act[0].shape[-1]

    rarhmm = rARHMM(nb_states, dm_obs, dm_act, trans_type=trans_type,
                    obs_prior=obs_prior, trans_prior=trans_prior,
                    trans_kwargs=trans_kwargs)
    rarhmm.initialize(train_obs, train_act)
    rarhmm.em(train_obs, train_act, nb_iter=nb_iter, prec=prec,
              obs_mstep_kwargs=obs_mstep_kwargs,
              trans_mstep_kwargs=trans_mstep_kwargs,
              process_id=process_id)

    # Row counts after stacking trajectories along the first axis.
    nb_train = np.vstack(train_obs).shape[0]
    nb_all = np.vstack(obs).shape[0]

    train_ll = rarhmm.log_norm(train_obs, train_act)
    all_ll = rarhmm.log_norm(obs, act)

    # NOTE(review): assumes log_norm sums log-likelihood over the given
    # trajectories, so this is the held-out log-likelihood per row — confirm.
    score = (all_ll - train_ll) / (nb_all - nb_train)
    return rarhmm, all_ll, score
def fit_rarhmm_job(args):
    """Build an rARHMM from ``options`` and fit it with EM on one dataset.

    Args:
        args: tuple ``(obs, act, options)``. ``options`` must provide
            ``nb_states``, ``dm_obs``, ``dm_act``, ``obs_prior``,
            ``trans_type``, ``trans_prior``, ``trans_kwargs``, ``initialize``,
            ``obs_mstep_kwargs`` and ``trans_mstep_kwargs``. It may also
            provide ``nb_iter`` (default 100), ``prec`` (default 1e-4) and
            ``verbose`` (default True) to tune the EM run.

    Returns:
        The fitted rARHMM instance.

    Raises:
        KeyError: if a required option is missing.
    """
    obs, act, options = args

    rarhmm = rARHMM(options['nb_states'], options['dm_obs'], options['dm_act'],
                    obs_prior=options['obs_prior'],
                    trans_type=options['trans_type'],
                    trans_prior=options['trans_prior'],
                    trans_kwargs=options['trans_kwargs'])

    # Initialization is optional; when skipped, EM starts from the state
    # the constructor leaves the model in.
    if options['initialize']:
        rarhmm.initialize(obs, act)

    # EM settings were previously hard-coded; the defaults below keep that
    # behaviour while letting callers override them (same key names as
    # create_job uses for its EM arguments).
    rarhmm.em(obs=obs, act=act,
              nb_iter=options.get('nb_iter', 100),
              prec=options.get('prec', 1e-4),
              verbose=options.get('verbose', True),
              obs_mstep_kwargs=options['obs_mstep_kwargs'],
              trans_mstep_kwargs=options['trans_mstep_kwargs'])
    return rarhmm
# Configure, fit and inspect a 2-state rARHMM with a 'neural' transition model.
# NOTE(review): dm_obs, dm_act, obs, act, nb_rollouts, np, npr, plt and rARHMM
# are expected to be defined earlier in the script — not visible in this snippet.
nb_states = 2

# Observation prior, passed to rARHMM(obs_prior=...); the M-step is told to use it.
# NOTE(review): the factor 23 in nu0/psi0 is not explained here — confirm intent.
obs_prior = {'mu0': 0., 'sigma0': 1e64, 'nu0': (dm_obs + 1) * 23, 'psi0': 1e-16 * 23}
obs_mstep_kwargs = {'use_prior': True}

# Transition model configuration.
trans_type = 'neural'
trans_prior = {'l2_penalty': 1e-32, 'alpha': 1, 'kappa': 50}
# NOTE(review): norm mean/std have length 3; presumably this must match the
# transition network's input size — confirm against dm_obs/dm_act.
trans_kwargs = {'hidden_layer_sizes': (16,),
                'norm': {'mean': np.array([0., 0., 0.]), 'std': np.array([1., 1., 1.])}}
# Optimizer settings for the transition M-step.
trans_mstep_kwargs = {'nb_iter': 25, 'batch_size': 256, 'lr': 1e-3}

rarhmm = rARHMM(nb_states, dm_obs, dm_act, trans_type=trans_type,
                obs_prior=obs_prior, trans_prior=trans_prior,
                trans_kwargs=trans_kwargs)
rarhmm.initialize(obs, act)

# Run EM; the returned values are plotted below as a convergence curve.
lls = rarhmm.em(obs, act, nb_iter=100, prec=1e-4, verbose=True,
                obs_mstep_kwargs=obs_mstep_kwargs,
                trans_mstep_kwargs=trans_mstep_kwargs)

plt.figure(figsize=(5, 5))
plt.plot(lls)
plt.show()

# NOTE(review): this figure has nothing drawn on it within this snippet;
# presumably later code plots idx/state into it.
plt.figure(figsize=(8, 8))
# Random rollout index; not used within this snippet.
idx = npr.choice(nb_rollouts)
# Viterbi decoding over all trajectories; the first return value is discarded.
_, state = rarhmm.viterbi(obs, act)
import torch

# Seed both RNGs so the synthetic data and the fit are reproducible.
# npr.seed is required here: npr.randn(1337) would only draw 1337 samples
# and leave the generator unseeded.
npr.seed(1337)
torch.manual_seed(1337)

sns.set_style("white")
sns.set_context("talk")

color_names = [
    "windows blue",
    "red",
    "amber",
    "faded green",
    "dusty purple",
    "orange",
]
colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)

# Ground-truth model used to generate synthetic trajectories.
true_rarhmm = rARHMM(nb_states=3, dm_obs=2, trans_type='poly')

# trajectory lengths
T = [1250, 1150, 1025]
true_z, x = true_rarhmm.sample(horizon=T)
true_ll = true_rarhmm.log_norm(x)

# Fit a fresh model with the same structure to the sampled observations.
rarhmm = rARHMM(nb_states=3, dm_obs=2, trans_type='poly', preprocess=True)
rarhmm.initialize(x)
# NOTE(review): prec=0. presumably disables the convergence early-stop so all
# nb_iter iterations run — confirm against rARHMM.em.
lls = rarhmm.em(x, nb_iter=100, prec=0.)
print("true_ll=", true_ll, "hmm_ll=", lls[-1])

plt.figure(figsize=(5, 5))
# Horizontal reference line at the generating model's log-likelihood.
plt.plot(np.ones(len(lls)) * true_ll, '-r')