def natural_lds_estep_general_autograd(natparam, node_params):
    """Differentiate the LDS log normalizer with respect to ``natparam``.

    The log normalizer is obtained from ``natural_filter_forward_general``.
    It is then passed through ``vgrad``, so the result is whatever ``vgrad``
    produces for it. That is presumably ``(lognorm, gradient)`` — TODO
    confirm against ``vgrad``'s definition. For an exponential family, the
    gradient of the log normalizer w.r.t. the natural parameters gives the
    expected sufficient statistics. That is the assumed purpose of this
    "E step" helper; verify it against the callers.

    Args:
        natparam: a two-element sequence ``(init_params, pair_params)``. It is
            unpacked inside the differentiated function, so it must have
            exactly two entries.
        node_params: passed through unchanged to
            ``natural_filter_forward_general``; it is treated as a constant
            (not differentiated).

    Returns:
        The result of ``vgrad(lds_log_normalizer)(natparam)``.
    """
    def lds_log_normalizer(theta):
        # Unpack inside the differentiated function so the split is traced.
        init_params, pair_params = theta
        _, lognorm = natural_filter_forward_general(init_params, pair_params, node_params)
        return lognorm

    return vgrad(lds_log_normalizer)(natparam)
def EM_update(params):
    """Perform one EM iteration and return the re-estimated parameters.

    E step: take logs of every entry of ``params``. Then evaluate
    ``log_partition_function`` on those log-parameters and ``data`` through
    ``vgrad``, which yields the log-likelihood and the expected statistics.
    If ``callback`` is truthy, it is invoked with that log-likelihood and
    the incoming ``params``.

    M step: apply ``normalize`` to each expected statistic.

    Args:
        params: iterable of parameter arrays; each is passed to ``np.log``.

    Returns:
        A list with ``normalize`` applied to each expected statistic.
    """
    log_params = [np.log(p) for p in params]
    # E step
    loglike, expected_stats = vgrad(log_partition_function)(log_params, data)
    if callback:
        # Reports the likelihood of the parameters *before* this update.
        callback(loglike, params)
    # M step
    return [normalize(stat) for stat in expected_stats]
def EM_update(params):
    """Perform one EM iteration and return the re-estimated parameters.

    E step: take logs of every entry of ``params``. Then evaluate
    ``log_partition_function`` on those log-parameters and ``data`` through
    ``vgrad``, which yields the log-likelihood and the expected statistics.
    If ``callback`` is truthy, it is invoked with that log-likelihood and
    the incoming ``params``.

    M step: apply ``normalize`` to each expected statistic.

    Args:
        params: iterable of parameter arrays; each is passed to ``np.log``.

    Returns:
        A list with ``normalize`` applied to each expected statistic.
    """
    # Materialize with list(): on Python 3, map() is a one-shot lazy
    # iterator. Passing it to vgrad would hand the differentiator a `map`
    # object instead of a sequence of arrays.
    natural_params = list(map(np.log, params))
    loglike, E_stats = vgrad(log_partition_function)(natural_params, data)  # E step
    if callback:
        callback(loglike, params)
    # Return a list rather than an iterator, so the result can be fed back
    # in as `params` and still be intact when passed to `callback` above.
    return list(map(normalize, E_stats))  # M step