def create_model_fns(hdim):
    """Builds the table of fitting callables, keyed by method name.

    Each callable in the table accepts a single argument `s`, reads
    `s.outputs`, and forwards it to a fit routine in `arma` or `lds` along
    with `None` and `hdim`. The `None` occupies the second positional slot of
    those routines (presumably the input sequence -- confirm against the
    `arma` / `lds` signatures).

    Args:
      hdim: Guessed hidden dimension for model fitting.

    Returns:
      A collections.OrderedDict mapping method names to model_fns. Entries
      are inserted in this order: 'AR', 'ARMA_OLS', 'ARMA'; then 'LDS' and
      'ARMA_MLE' if FLAGS.include_slow_methods is truthy; then 'LDS_MLE' if
      FLAGS.include_LDS_MLE is truthy. Each model_fn returns whatever the
      underlying fit routine returns.
    """
    # NOTE(review): another function named `create_model_fns` appears later
    # in this file as seen during review; in Python the later definition
    # rebinds the name -- confirm which one is meant to be live.
    # pylint: disable=g-long-lambda
    model_fns = collections.OrderedDict([
        # Disabled baseline that used the raw outputs directly:
        #   ('raw_output', lambda s: _replace_nan_with_0(s.outputs)),
        # Pure autoregressive fit.
        ('AR', lambda s: arma.fit_ar(s.outputs, None, hdim)),
        # Iterated regression, first unregularized ...
        ('ARMA_OLS', lambda s: arma.fit_arma_iter(s.outputs, None, hdim)),
        # ... then with an L2 penalty.
        ('ARMA', lambda s: arma.fit_arma_iter(
            s.outputs, None, hdim, l2_reg=0.01)),
    ])

    if FLAGS.include_slow_methods:
        model_fns.update([
            # FLAGS.LDS_GIBBS_num_update_samples is looked up each time the
            # callable runs, not when this table is built.
            ('LDS', lambda s: lds.fit_lds_gibbs(
                s.outputs,
                None,
                hdim,
                num_update_samples=FLAGS.LDS_GIBBS_num_update_samples)),
            ('ARMA_MLE', lambda s: arma.fit_arma_mle(s.outputs, None, hdim)),
        ])

    if FLAGS.include_LDS_MLE:
        model_fns.update([
            ('LDS_MLE', lambda s: lds.fit_lds_mle(s.outputs, None, hdim)),
        ])

    return model_fns
def create_model_fns(hdim):
    """Builds the table of fitting callables, keyed by method name.

    Each callable in the table accepts two positional arguments `(o, i)` --
    the output sequence and the input sequence -- and forwards them, together
    with `hdim`, to a fit routine in `arma` or `lds`.

    Args:
      hdim: Guessed hidden dimension for model fitting.

    Returns:
      A collections.OrderedDict mapping method names to model_fns. Entries
      are inserted in this order: 'AR', 'ARMA_OLS', 'ARMA_RLS'; then
      'LDS_GIBBS' if FLAGS.include_LDS_GIBBS is truthy; 'ARMA_MLE' if
      FLAGS.include_ARMA_MLE is truthy; 'LDS_MLE' if FLAGS.include_LDS_MLE is
      truthy. Each model_fn returns whatever the underlying fit routine
      returns.
    """
    # NOTE(review): an earlier function with this same name appears in this
    # file as seen during review; this definition rebinds it -- confirm the
    # earlier one is meant to be superseded.
    # pylint: disable=g-long-lambda
    model_fns = collections.OrderedDict([
        # Disabled baseline that used the raw outputs directly:
        #   ('raw_output', lambda o, i: o),
        # Pure autoregressive fit.
        ('AR', lambda o, i: arma.fit_ar(o, i, hdim)),
        # Iterated regression, first unregularized ...
        ('ARMA_OLS', lambda o, i: arma.fit_arma_iter(o, i, hdim)),
        # ... then with an L2 penalty.
        ('ARMA_RLS', lambda o, i: arma.fit_arma_iter(
            o, i, hdim, l2_reg=0.01)),
        # Disabled variants that fit a model and then take the roots of the
        # AR params via arma.get_eig_from_arparams:
        #   ('AR_roots', lambda o, i: arma.get_eig_from_arparams(
        #       arma.fit_ar(o, i, hdim))),
        #   ('ARMA_OLS_roots', lambda o, i: arma.get_eig_from_arparams(
        #       arma.fit_arma_iter(o, i, hdim))),
        #   ('ARMA_RLS_roots_0.01', lambda o, i: arma.get_eig_from_arparams(
        #       arma.fit_arma_iter(o, i, hdim, l2_reg=0.01))),
    ])

    if FLAGS.include_LDS_GIBBS:
        # FLAGS.LDS_GIBBS_num_update_samples is looked up each time the
        # callable runs, not when this table is built.
        model_fns.update([
            ('LDS_GIBBS', lambda o, i: lds.fit_lds_gibbs(
                o, i, hdim,
                num_update_samples=FLAGS.LDS_GIBBS_num_update_samples)),
        ])

    if FLAGS.include_ARMA_MLE:
        model_fns.update([
            ('ARMA_MLE', lambda o, i: arma.fit_arma_mle(o, i, hdim)),
        ])

    if FLAGS.include_LDS_MLE:
        model_fns.update([
            ('LDS_MLE', lambda o, i: lds.fit_lds_mle(o, i, hdim)),
        ])

    return model_fns