コード例 #1
0
def test_fast_bw_uniform():
  """
  Run the Theano FastBaumWelchOp on a strictly left-to-right FSA with uniform AM scores.

  For each class c the FSA has a forward edge c -> c+1 and a self-loop on
  state c+1, both emitting class c. All emission probabilities are 1/n_classes,
  so the soft alignment depends only on the FSA topology. For the default
  sizes (seq_len=7, n_classes=5) the per-frame scores and the alignment are
  compared against fixed reference values, identical for every batch entry.
  """
  print("Make op...")
  from NativeOp import FastBaumWelchOp
  # Argument order of the op: am_scores, edges, weights, start_end_states, float_idx, state_buffer.
  bw_op = FastBaumWelchOp().make_op()
  print("Op:", bw_op)
  n_batch, seq_len, n_classes = 3, 7, 5

  from Fsa import FastBwFsaShared
  topology = FastBwFsaShared()
  for label in range(n_classes):
    topology.add_edge(label, label + 1, emission_idx=label)  # step forward into state label+1
    topology.add_edge(label + 1, label + 1, emission_idx=label)  # self-loop on state label+1
  # Every class has to be emitted at least once, which needs at least n_classes frames.
  assert n_classes <= seq_len
  batch_fsa = topology.get_fast_bw_fsa(n_batch=n_batch)

  # The matching Theano placeholders below are float32 matrices, so the FSA arrays are
  # reinterpreted bit-for-bit as float32 (numpy .view; values are not converted).
  print("edges:")
  print(batch_fsa.edges)
  edges = batch_fsa.edges.view("float32")
  weights = batch_fsa.weights
  print("start_end_states:")
  print(batch_fsa.start_end_states)
  start_end_states = batch_fsa.start_end_states.view("float32")

  # Uniform emission probability 1/n_classes everywhere, stored as -log(p).
  uniform_prob = numpy.ones((seq_len, n_batch, n_classes), dtype="float32") * numpy.float32(1.0 / n_classes)
  am_scores = -numpy.log(uniform_prob)
  float_idx = numpy.ones((seq_len, n_batch), dtype="float32")  # all-ones: presumably every frame is valid
  last_state_idx = numpy.max(batch_fsa.start_end_states[1])  # see get_automata_for_batch
  state_buffer = numpy.zeros((2, last_state_idx + 1), dtype="float32")

  # (numeric value, symbolic placeholder) pairs, listed in the op's argument order.
  feeds = [
    (am_scores, T.ftensor3(name="am_scores")),
    (edges, T.fmatrix(name="edges")),
    (weights, T.fvector(name="weights")),
    (start_end_states, T.fmatrix(name="start_end_states")),
    (float_idx, T.fmatrix(name="float_idx")),
    (state_buffer, T.fmatrix(name="state_buffer")),
  ]
  symbols = [sym for _, sym in feeds]
  print("Construct call...")
  fwdbwd_sym, obs_scores_sym = bw_op(*symbols)
  compiled = theano.function(inputs=symbols, outputs=[fwdbwd_sym, obs_scores_sym])
  print("Done.")
  print("Eval:")
  fwdbwd, score = compiled(*[value for value, _ in feeds])

  print("score:")
  print(repr(score))
  assert_equal(score.shape, (seq_len, n_batch))
  bw = numpy.exp(-fwdbwd)  # back from -log space to probabilities
  print("Baum-Welch soft alignment:")
  print(repr(bw))
  assert_equal(bw.shape, (seq_len, n_batch, n_classes))

  if seq_len == n_classes:
    # One frame per class leaves no room for self-loops: the alignment is the diagonal.
    print("Extra check identity...")
    eye = numpy.identity(n_classes)
    for b in range(n_batch):
      assert_almost_equal(eye, bw[:, b])

  if seq_len == 7 and n_classes == 5:
    print("Extra check ref_align (7,5)...")
    assert_allclose(score, 8.55801582, rtol=1e-5)  # compared against one value for every frame and batch entry
    single_ref = numpy.array(
      [[[1., 0., 0., 0., 0.]],
       [[0.33333316, 0.66666663, 0., 0., 0.]],
       [[0.06666669, 0.53333354, 0.40000018, 0., 0.]],
       [[0., 0.20000014, 0.60000014, 0.19999999, 0.]],
       [[0., 0., 0.39999962, 0.53333312, 0.06666663]],
       [[0., 0., 0., 0.66666633, 0.33333316]],
       [[0., 0., 0., 0., 0.99999982]]], dtype=numpy.float32)
    assert_equal(single_ref.shape, (seq_len, 1, n_classes))
    ref_align = numpy.tile(single_ref, (1, n_batch, 1))  # same reference for each sequence
    assert_equal(ref_align.shape, bw.shape)
    sq_diff = numpy.square(ref_align - bw)
    print("mean square diff:", numpy.mean(sq_diff))
    print("max square diff:", numpy.max(sq_diff))
    assert_allclose(ref_align, bw, rtol=1e-5)
  print("Done.")
コード例 #2
0
ファイル: test_TFNativeOp.py プロジェクト: wbengine/returnn
def test_fast_bw_uniform():
    """
    Run TF ``fast_baum_welch`` on a strictly left-to-right FSA with uniform AM scores.

    For each class c the FSA has a forward edge c -> c+1 and a self-loop on
    state c+1, both emitting class c. All emission probabilities are
    1/n_classes, so the soft alignment depends only on the FSA topology. For
    the default sizes (seq_len=7, n_classes=5) the per-frame scores and the
    alignment are compared against fixed reference values, identical for every
    batch entry.
    """
    print("Make op...")
    # Compiled verbosely up front; the op is cached and reused inside :func:`fast_baum_welch`.
    # Op arguments: (am_scores, edges, weights, start_end_states, float_idx, state_buffer)
    bw_op = make_fast_baum_welch_op(compiler_opts=dict(verbose=True))
    print("Op:", bw_op)
    n_batch, seq_len, n_classes = 3, 7, 5

    from Fsa import FastBwFsaShared
    topology = FastBwFsaShared()
    for label in range(n_classes):
        topology.add_edge(label, label + 1, emission_idx=label)  # step forward into state label+1
        topology.add_edge(label + 1, label + 1, emission_idx=label)  # self-loop on state label+1
    # Every class has to be emitted at least once, which needs at least n_classes frames.
    assert n_classes <= seq_len
    batch_fsa = topology.get_fast_bw_fsa(n_batch=n_batch)

    edges = tf.constant(batch_fsa.edges, dtype=tf.int32)
    weights = tf.constant(batch_fsa.weights, dtype=tf.float32)
    start_end_states = tf.constant(batch_fsa.start_end_states, dtype=tf.int32)
    # Uniform emission probability 1/n_classes everywhere, stored as -log(p).
    neg_log_uniform = -numpy.log(
        numpy.ones((seq_len, n_batch, n_classes), dtype="float32") * numpy.float32(1.0 / n_classes))
    am_scores = tf.constant(neg_log_uniform, dtype=tf.float32)
    # All-ones frame mask: every sequence uses the full seq_len. A variable-length mask
    # could be built with TFUtil.sequence_mask_time_major instead.
    float_idx = tf.ones((seq_len, n_batch), dtype=tf.float32)

    print("Construct call...")
    fwdbwd_tensor, obs_scores_tensor = fast_baum_welch(
        am_scores=am_scores,
        float_idx=float_idx,
        edges=edges,
        weights=weights,
        start_end_states=start_end_states)
    print("Done.")
    print("Eval:")
    fwdbwd, score = session.run([fwdbwd_tensor, obs_scores_tensor])

    print("score:")
    print(repr(score))
    assert_equal(score.shape, (seq_len, n_batch))
    bw = numpy.exp(-fwdbwd)  # back from -log space to probabilities
    print("Baum-Welch soft alignment:")
    print(repr(bw))
    assert_equal(bw.shape, (seq_len, n_batch, n_classes))

    if seq_len == n_classes:
        # One frame per class leaves no room for self-loops: the alignment is the diagonal.
        print("Extra check identity...")
        eye = numpy.identity(n_classes)
        for b in range(n_batch):
            assert_almost_equal(eye, bw[:, b])

    if seq_len == 7 and n_classes == 5:
        print("Extra check ref_align (7,5)...")
        # Compared against one value for every frame and batch entry.
        assert_allclose(score, 8.55801582, rtol=1e-5)
        single_ref = numpy.array(
            [[[1., 0., 0., 0., 0.]],
             [[0.33333316, 0.66666663, 0., 0., 0.]],
             [[0.06666669, 0.53333354, 0.40000018, 0., 0.]],
             [[0., 0.20000014, 0.60000014, 0.19999999, 0.]],
             [[0., 0., 0.39999962, 0.53333312, 0.06666663]],
             [[0., 0., 0., 0.66666633, 0.33333316]],
             [[0., 0., 0., 0., 0.99999982]]], dtype=numpy.float32)
        assert_equal(single_ref.shape, (seq_len, 1, n_classes))
        ref_align = numpy.tile(single_ref, (1, n_batch, 1))  # same reference for each sequence
        assert_equal(ref_align.shape, bw.shape)
        sq_diff = numpy.square(ref_align - bw)
        print("mean square diff:", numpy.mean(sq_diff))
        print("max square diff:", numpy.max(sq_diff))
        assert_allclose(ref_align, bw, rtol=1e-5)
    print("Done.")