Code example #1
0
def test_fast_bw():
    """Smoke-test the Theano ``FastBaumWelchOp`` on a small shared FSA.

    Compiles the native op and builds a ``FastBwFsaShared`` with an infinite
    loop on state 0 over all emission labels, batched ``n_batch`` times. It then
    runs the op through a compiled ``theano.function`` on random AM scores and
    prints the resulting score output.

    Nothing is asserted; the test only checks that construction and
    evaluation complete without raising.
    """
    print("Make op...")
    from NativeOp import FastBaumWelchOp
    # Positional op inputs:
    # (am_scores, edges, weights, start_end_states, float_idx, state_buffer)
    baum_welch = FastBaumWelchOp().make_op()
    print("Op:", baum_welch)
    batch_size, num_frames, num_labels = 3, 5, 5
    from Fsa import FastBwFsaShared
    shared_fsa = FastBwFsaShared()
    shared_fsa.add_inf_loop(state_idx=0, num_emission_labels=num_labels)
    batched_fsa = shared_fsa.get_fast_bw_fsa(n_batch=batch_size)

    # Concrete input values. The edge and start/end arrays go into float32
    # placeholders below, so they are bit-reinterpreted with .view() rather than
    # cast. NOTE(review): presumably they are 4-byte integer arrays; confirm in Fsa.
    edges_data = batched_fsa.edges.view("float32")
    weights_data = batched_fsa.weights
    start_end_data = batched_fsa.start_end_states.view("float32")
    am_scores_data = numpy.random.normal(
        size=(num_frames, batch_size, num_labels)).astype("float32")  # in -log space
    float_idx_data = numpy.ones((num_frames, batch_size), dtype="float32")
    # The state buffer is sized from the largest index in row 1 of
    # start_end_states (presumably the end states); see get_automata_for_batch.
    max_state = numpy.max(batched_fsa.start_end_states[1])
    state_buffer_data = numpy.zeros((2, max_state + 1), dtype="float32")

    # Symbolic inputs, in the op's positional order. The same list feeds both
    # the op call and the compiled function.
    symbolic_inputs = [
        T.ftensor3(name="am_scores"),
        T.fmatrix(name="edges"),
        T.fvector(name="weights"),
        T.fmatrix(name="start_end_states"),
        T.fmatrix(name="float_idx"),
        T.fmatrix(name="state_buffer"),
    ]

    print("Construct call...")
    fwdbwd, obs_scores = baum_welch(*symbolic_inputs)
    f = theano.function(inputs=symbolic_inputs, outputs=[fwdbwd, obs_scores])
    print("Done.")
    print("Eval:")
    _, score = f(
        am_scores_data, edges_data, weights_data,
        start_end_data, float_idx_data, state_buffer_data)
    print("score:", score)
Code example #2
0
File: test_TFNativeOp.py  Project: wbengine/returnn
def test_FastBaumWelch():
    """Smoke-test the TensorFlow ``fast_baum_welch`` op on a small shared FSA.

    Compiles the native op verbosely and builds a ``FastBwFsaShared`` with an
    infinite loop on state 0 over all emission labels, batched ``n_batch``
    times. It feeds the FSA arrays plus random AM scores into
    ``fast_baum_welch`` and evaluates the outputs with the module-level
    ``session``, printing the resulting score output.

    Nothing is asserted; the test only checks that graph construction and
    evaluation complete without raising.
    """
    print("Make op...")
    # The op is cached and reused internally by :func:`fast_baum_welch`;
    # here it is only built and printed.
    native_op = make_fast_baum_welch_op(compiler_opts={"verbose": True})
    print("Op:", native_op)
    batch_size = 3
    num_frames = 5
    num_labels = 10
    from Fsa import FastBwFsaShared
    shared_fsa = FastBwFsaShared()
    shared_fsa.add_inf_loop(state_idx=0, num_emission_labels=num_labels)
    batched_fsa = shared_fsa.get_fast_bw_fsa(n_batch=batch_size)

    # Graph constants, created in the same order as before.
    edges_t = tf.constant(batched_fsa.edges, dtype=tf.int32)
    weights_t = tf.constant(batched_fsa.weights, dtype=tf.float32)
    start_end_t = tf.constant(batched_fsa.start_end_states, dtype=tf.int32)
    random_scores = numpy.random.normal(size=(num_frames, batch_size, num_labels))
    am_scores_t = tf.constant(random_scores, dtype=tf.float32)  # in -log space
    float_idx_t = tf.ones((num_frames, batch_size), dtype=tf.float32)

    print("Construct call...")
    fwdbwd, obs_scores = fast_baum_welch(
        edges=edges_t,
        weights=weights_t,
        start_end_states=start_end_t,
        am_scores=am_scores_t,
        float_idx=float_idx_t,
    )
    print("Done.")
    print("Eval:")
    results = session.run([fwdbwd, obs_scores])
    score = results[1]
    print("score:", score)