コード例 #1
0
def create_learning_fns(hidden_dim):
    """Builds the ordered table of eigenvalue-learning callables.

    Args:
      hidden_dim: Forwarded unchanged as the third positional argument to
        every arma/lds fitting routine invoked by the returned callables.

    Returns:
      A collections.OrderedDict mapping a method name to a callable taking
      two positional arguments `(o, i)` (presumably the output and input
      sequences -- confirm against callers). 'AR' and 'ARMA_RLS' are always
      present; 'ARMA_MLE' and 'LDS_GIBBS' follow, in that order, only when
      FLAGS.include_slow_methods is truthy.
    """

    def _learn_ar(o, i):
        return get_eig_from_arparams(arma.fit_ar(o, i, hidden_dim))

    def _learn_arma_rls(o, i):
        fitted = arma.fit_arma_iter(o, i, hidden_dim, l2_reg=0.01)
        return get_eig_from_arparams(fitted)

    # ARMA_OLS (fit_arma_iter without l2_reg) is not registered because it
    # gives very big errors. Variants with l2_reg of 0.1, 0.005 and 0.001 are
    # likewise not registered; only l2_reg=0.01 is exposed, as 'ARMA_RLS'.
    entries = [('AR', _learn_ar), ('ARMA_RLS', _learn_arma_rls)]

    if FLAGS.include_slow_methods:

        def _learn_arma_mle(o, i):
            fitted = arma.fit_arma_mle(o, i, hidden_dim)
            return get_eig_from_arparams(fitted)

        def _learn_lds_gibbs(o, i):
            # The flag is consulted on every call, not when the table is
            # built. The result is returned as-is, without going through
            # get_eig_from_arparams.
            return lds.fit_lds_gibbs(
                o,
                i,
                hidden_dim,
                num_update_samples=FLAGS.LDS_GIBBS_num_update_samples)

        # LDS_MLE is not registered because it fails convergence too often.
        entries.append(('ARMA_MLE', _learn_arma_mle))
        entries.append(('LDS_GIBBS', _learn_lds_gibbs))

    return collections.OrderedDict(entries)
コード例 #2
0
def create_model_fns(hdim):
    """Builds the ordered table of model-fitting callables.

    Args:
      hdim: Guessed hidden dimension for model fitting.

    Returns:
      A collections.OrderedDict mapping method names to model_fns. Each
      model_fn takes a single argument `s`, reads `s.outputs`, and returns
      whatever the underlying arma/lds fitting routine returns. None is
      always passed as the second positional argument of that routine.
      'AR', 'ARMA_OLS' and 'ARMA' are always present; 'LDS' and 'ARMA_MLE'
      are added when FLAGS.include_slow_methods is truthy, and 'LDS_MLE'
      when FLAGS.include_LDS_MLE is truthy.
    """

    # Pure AR.
    def _fit_ar(s):
        return arma.fit_ar(s.outputs, None, hdim)

    # Iterated regression, first without and then with l2 regularization.
    def _fit_arma_ols(s):
        return arma.fit_arma_iter(s.outputs, None, hdim)

    def _fit_arma_reg(s):
        return arma.fit_arma_iter(s.outputs, None, hdim, l2_reg=0.01)

    # A 'raw_output' entry (s.outputs passed through _replace_nan_with_0)
    # is currently disabled and therefore not registered here.
    entries = [
        ('AR', _fit_ar),
        ('ARMA_OLS', _fit_arma_ols),
        ('ARMA', _fit_arma_reg),
    ]

    if FLAGS.include_slow_methods:

        def _fit_lds_gibbs(s):
            # The sample-count flag is read on every call, not at build time.
            return lds.fit_lds_gibbs(
                s.outputs,
                None,
                hdim,
                num_update_samples=FLAGS.LDS_GIBBS_num_update_samples)

        def _fit_arma_mle(s):
            return arma.fit_arma_mle(s.outputs, None, hdim)

        entries.append(('LDS', _fit_lds_gibbs))
        entries.append(('ARMA_MLE', _fit_arma_mle))

    if FLAGS.include_LDS_MLE:

        def _fit_lds_mle(s):
            return lds.fit_lds_mle(s.outputs, None, hdim)

        entries.append(('LDS_MLE', _fit_lds_mle))

    return collections.OrderedDict(entries)
コード例 #3
0
def create_model_fns(hdim):
    """Builds the ordered table of model-fitting callables.

    Args:
      hdim: Guessed hidden dimension for model fitting.

    Returns:
      A collections.OrderedDict mapping method names to model_fns. Each
      model_fn takes output seq `o` and input seq `i`, and returns the
      fitted model params produced by the underlying arma/lds routine.
      'AR', 'ARMA_OLS' and 'ARMA_RLS' are always present; 'LDS_GIBBS',
      'ARMA_MLE' and 'LDS_MLE' are each added, in that order, only when the
      matching FLAGS.include_* flag is truthy.
    """

    # Pure AR.
    def _fit_ar(o, i):
        return arma.fit_ar(o, i, hdim)

    # Iterated regression without regularization ...
    def _fit_arma_ols(o, i):
        return arma.fit_arma_iter(o, i, hdim)

    # ... and with l2 regularization.
    def _fit_arma_rls(o, i):
        return arma.fit_arma_iter(o, i, hdim, l2_reg=0.01)

    def _fit_lds_gibbs(o, i):
        # The sample-count flag is read on every call, not at build time.
        return lds.fit_lds_gibbs(
            o, i, hdim, num_update_samples=FLAGS.LDS_GIBBS_num_update_samples)

    def _fit_arma_mle(o, i):
        return arma.fit_arma_mle(o, i, hdim)

    def _fit_lds_mle(o, i):
        return lds.fit_lds_mle(o, i, hdim)

    # Entries for 'raw_output' (returning `o` untouched) and for the
    # AR-root based variants ('AR_roots', 'ARMA_OLS_roots',
    # 'ARMA_RLS_roots_0.01', built on arma.get_eig_from_arparams) are
    # currently disabled and therefore not registered here.
    table = collections.OrderedDict()
    table['AR'] = _fit_ar
    table['ARMA_OLS'] = _fit_arma_ols
    table['ARMA_RLS'] = _fit_arma_rls

    # Each optional method is gated by its own flag; all three flags are
    # read, in this order, regardless of the others' values.
    optional_methods = (
        (FLAGS.include_LDS_GIBBS, 'LDS_GIBBS', _fit_lds_gibbs),
        (FLAGS.include_ARMA_MLE, 'ARMA_MLE', _fit_arma_mle),
        (FLAGS.include_LDS_MLE, 'LDS_MLE', _fit_lds_mle),
    )
    for enabled, method_name, fit_fn in optional_methods:
        if enabled:
            table[method_name] = fit_fn

    return table