def create_model_fns(hdim):
  """Util function to create model fns to fit model params to sequences.

  Args:
    hdim: Guessed hidden dimension for model fitting.

  Returns:
    A collections.OrderedDict mapping method names to model_fns. Each model_fn
    takes a single sequence object `s` (anything exposing an `outputs`
    attribute), fits only `s.outputs` -- None is passed as the input
    sequence -- and returns the fitted model params.
  """
  # NOTE(review): a second `create_model_fns` with an (outputs, inputs)
  # signature appears after this one; if both live in the same module the
  # later definition shadows this one -- confirm which is intended.
  model_fns = collections.OrderedDict()
  # Using raw outputs.
  # model_fns['raw_output'] = lambda s: _replace_nan_with_0(s.outputs)
  # pylint: disable=g-long-lambda
  # Pure AR.
  model_fns['AR'] = lambda s: arma.fit_ar(s.outputs, None, hdim)
  # Iterated regression without regularization and with regularization.
  model_fns['ARMA_OLS'] = lambda s: arma.fit_arma_iter(s.outputs, None, hdim)
  model_fns['ARMA'] = lambda s: arma.fit_arma_iter(
      s.outputs, None, hdim, l2_reg=0.01)
  # Slower fitting methods are opt-in. FLAGS.LDS_GIBBS_num_update_samples is
  # looked up when the LDS model_fn is called, not when this dict is built.
  if FLAGS.include_slow_methods:
    model_fns['LDS'] = lambda s: lds.fit_lds_gibbs(
        s.outputs, None, hdim,
        num_update_samples=FLAGS.LDS_GIBBS_num_update_samples)
    model_fns['ARMA_MLE'] = lambda s: arma.fit_arma_mle(
        s.outputs, None, hdim)
  # NOTE(review): confirm whether this check should only apply when
  # FLAGS.include_slow_methods is also set.
  if FLAGS.include_LDS_MLE:
    model_fns['LDS_MLE'] = lambda s: lds.fit_lds_mle(s.outputs, None, hdim)
  return model_fns
def create_model_fns(hdim):
  """Builds the ordered registry of model-fitting functions.

  Args:
    hdim: Guessed hidden dimension handed to every fitting routine.

  Returns:
    A collections.OrderedDict from method name to model_fn. Every model_fn is
    called as model_fn(o, i) with an output sequence `o` and an input
    sequence `i`, and returns the model params fitted to them. Entries appear
    in the order AR, ARMA_OLS, ARMA_RLS, followed by whichever of LDS_GIBBS,
    ARMA_MLE and LDS_MLE are enabled by their flags.
  """
  # pylint: disable=g-long-lambda
  registry = [
      # Disabled: use the raw outputs directly.
      # ('raw_output', lambda o, i: o),
      # Pure AR.
      ('AR', lambda o, i: arma.fit_ar(o, i, hdim)),
      # Iterated ARMA regression: unregularized (OLS), then l2_reg=0.01 (RLS).
      ('ARMA_OLS', lambda o, i: arma.fit_arma_iter(o, i, hdim)),
      ('ARMA_RLS',
       lambda o, i: arma.fit_arma_iter(o, i, hdim, l2_reg=0.01)),
      # Disabled: fit AR / ARMA and cluster on the roots of the AR params.
      # ('AR_roots', lambda o, i: arma.get_eig_from_arparams(
      #     arma.fit_ar(o, i, hdim))),
      # ('ARMA_OLS_roots', lambda o, i: arma.get_eig_from_arparams(
      #     arma.fit_arma_iter(o, i, hdim))),
      # ('ARMA_RLS_roots_0.01', lambda o, i: arma.get_eig_from_arparams(
      #     arma.fit_arma_iter(o, i, hdim, l2_reg=0.01))),
  ]

  # Optional methods, each gated by its own flag. The include_* flags are
  # checked here, once; FLAGS.LDS_GIBBS_num_update_samples is only looked up
  # when the LDS_GIBBS model_fn actually runs.
  if FLAGS.include_LDS_GIBBS:
    registry.append(
        ('LDS_GIBBS',
         lambda o, i: lds.fit_lds_gibbs(
             o, i, hdim,
             num_update_samples=FLAGS.LDS_GIBBS_num_update_samples)))
  if FLAGS.include_ARMA_MLE:
    registry.append(('ARMA_MLE', lambda o, i: arma.fit_arma_mle(o, i, hdim)))
  if FLAGS.include_LDS_MLE:
    registry.append(('LDS_MLE', lambda o, i: lds.fit_lds_mle(o, i, hdim)))

  # OrderedDict preserves the insertion order of the (name, fn) pairs.
  return collections.OrderedDict(registry)