def fast_baum_welch_by_sprint_automata(am_scores, float_idx, tags, sprint_opts, tdp_scale=1.0):
    """
    Build per-sequence automata via Sprint for the given batch, then run fast Baum-Welch on them.

    :param tf.Tensor am_scores: (time, batch, dim), in -log space
    :param tf.Tensor float_idx: (time, batch) -> 0 or 1 (index mask, via seq lens)
    :param tf.Tensor tags: (batch,) -> seq name (str)
    :param dict[str] sprint_opts: options forwarded to get_sprint_automata_for_batch_op
    :param float tdp_scale: the automaton edge weights are multiplied by this;
        0 replaces the weights by zeros, 1 leaves them untouched
    :return: (fwdbwd, obs_scores), fwdbwd is (time, batch, dim), obs_scores is (time, batch), in -log space
    :rtype: (tf.Tensor, tf.Tensor)
    """
    # Imported lazily, so the Sprint dependency is only needed when this function is actually used.
    from TFSprint import get_sprint_automata_for_batch_op

    edges, weights, start_end_states = get_sprint_automata_for_batch_op(
        sprint_opts=sprint_opts, tags=tags)

    if tdp_scale == 0:
        # Explicit zeros instead of a multiplication by 0.
        # NOTE(review): presumably avoids 0 * inf = NaN if some weights are infinite — confirm.
        weights = tf.zeros_like(weights)
    elif tdp_scale != 1:
        weights *= tdp_scale

    return fast_baum_welch(
        am_scores=am_scores,
        float_idx=float_idx,
        edges=edges,
        weights=weights,
        start_end_states=start_end_states)
def fast_baum_welch_by_sprint_automata(am_scores, float_idx, tags, sprint_opts, tdp_scale=1.0):
    """
    Get the automata for the batch from Sprint, then run fast Baum-Welch on them.

    NOTE(review): this name is also defined with a ``tdp_scale`` parameter elsewhere in this
    module. Whichever definition comes later wins, so this one accepts ``tdp_scale`` too,
    to keep both signatures compatible.

    :param tf.Tensor am_scores: (time, batch, dim), in -log space
    :param tf.Tensor float_idx: (time, batch) -> 0 or 1 (index mask, via seq lens)
    :param tf.Tensor tags: (batch,) -> seq name (str)
    :param dict[str] sprint_opts: options forwarded to get_sprint_automata_for_batch_op
    :param float tdp_scale: the automaton edge weights are multiplied by this;
        0 replaces the weights by zeros, the default 1 leaves them untouched
    :return: (fwdbwd, obs_scores), fwdbwd is (time, batch, dim), obs_scores is (time, batch), in -log space
    :rtype: (tf.Tensor, tf.Tensor)
    """
    # Imported lazily, so the Sprint dependency is only needed when this function is actually used.
    from TFSprint import get_sprint_automata_for_batch_op
    edges, weights, start_end_states = get_sprint_automata_for_batch_op(
        sprint_opts=sprint_opts, tags=tags)
    if tdp_scale != 1:
        if tdp_scale == 0:
            # Explicit zeros instead of a multiplication by 0.
            # NOTE(review): presumably avoids 0 * inf = NaN if some weights are infinite — confirm.
            weights = tf.zeros_like(weights)
        else:
            weights *= tdp_scale
    return fast_baum_welch(
        am_scores=am_scores, float_idx=float_idx,
        edges=edges, weights=weights, start_end_states=start_end_states)