def test_fast_bw():
    """Smoke-test the Theano ``FastBaumWelchOp`` on a tiny random batch.

    The test proceeds in four steps:

    1. Builds a ``FastBwFsaShared`` automaton consisting of an infinite loop
       on state 0 over ``n_classes`` emission labels.
    2. Expands the automaton for ``n_batch`` sequences.
    3. Feeds random ``am_scores`` (treated as -log scores) through a compiled
       ``theano.function`` wrapping the op.
    4. Prints the second op output (``obs_scores``).

    Nothing is asserted. The test only checks that the op can be built,
    compiled and evaluated without raising.
    """
    print("Make op...")
    from NativeOp import FastBaumWelchOp
    bw_op = FastBaumWelchOp().make_op()
    # Op inputs, in order:
    # (am_scores, edges, weights, start_end_states, float_idx, state_buffer)
    print("Op:", bw_op)

    n_batch, seq_len, n_classes = 3, 5, 5

    from Fsa import FastBwFsaShared
    shared_fsa = FastBwFsaShared()
    shared_fsa.add_inf_loop(state_idx=0, num_emission_labels=n_classes)
    batch_fsa = shared_fsa.get_fast_bw_fsa(n_batch=n_batch)

    # ``view`` reinterprets the raw bits as float32 rather than casting the
    # values. That lets these arrays be passed through the float-typed
    # placeholders declared below.
    # NOTE(review): presumably the op views them back as ints internally — confirm.
    edges_value = batch_fsa.edges.view("float32")
    weights_value = batch_fsa.weights
    start_end_value = batch_fsa.start_end_states.view("float32")
    am_scores_value = numpy.random.normal(
        size=(seq_len, n_batch, n_classes)).astype("float32")  # in -log space
    float_idx_value = numpy.ones((seq_len, n_batch), dtype="float32")
    # The buffer width is derived from the largest index in row 1 of
    # start_end_states (see get_automata_for_batch).
    max_state = batch_fsa.start_end_states[1].max()
    state_buffer_value = numpy.zeros((2, max_state + 1), dtype="float32")

    # Symbolic inputs, listed in the order the op expects them.
    symbolic_inputs = [
        T.ftensor3(name="am_scores"),
        T.fmatrix(name="edges"),
        T.fvector(name="weights"),
        T.fmatrix(name="start_end_states"),
        T.fmatrix(name="float_idx"),
        T.fmatrix(name="state_buffer"),
    ]

    print("Construct call...")
    fwdbwd, obs_scores = bw_op(*symbolic_inputs)
    evaluate = theano.function(inputs=symbolic_inputs,
                               outputs=[fwdbwd, obs_scores])
    print("Done.")

    print("Eval:")
    _fwdbwd_value, score_value = evaluate(
        am_scores_value, edges_value, weights_value,
        start_end_value, float_idx_value, state_buffer_value)
    print("score:", score_value)
def test_FastBaumWelch():
    """Smoke-test the TensorFlow ``fast_baum_welch`` wrapper on a tiny batch.

    The test proceeds in four steps:

    1. Builds a ``FastBwFsaShared`` automaton consisting of an infinite loop
       on state 0 over ``n_classes`` emission labels.
    2. Expands the automaton for ``n_batch`` sequences.
    3. Wraps the resulting arrays as TF constants, together with random
       ``am_scores`` (treated as -log scores, cast to float32).
    4. Evaluates both op outputs with ``session.run`` and prints the second
       one (``obs_scores``).

    Nothing is asserted. The test only checks that the op compiles and runs
    without raising.
    """
    print("Make op...")
    # Will be cached, used inside :func:`fast_baum_welch`.
    compiled_op = make_fast_baum_welch_op(compiler_opts={"verbose": True})
    print("Op:", compiled_op)

    n_batch, seq_len, n_classes = 3, 5, 10

    from Fsa import FastBwFsaShared
    shared_fsa = FastBwFsaShared()
    shared_fsa.add_inf_loop(state_idx=0, num_emission_labels=n_classes)
    batch_fsa = shared_fsa.get_fast_bw_fsa(n_batch=n_batch)

    # Keyword arguments for fast_baum_welch. The values are built in this
    # order: the FSA tensors first, then the random scores and the mask.
    bw_kwargs = {
        "edges": tf.constant(batch_fsa.edges, dtype=tf.int32),
        "weights": tf.constant(batch_fsa.weights, dtype=tf.float32),
        "start_end_states": tf.constant(batch_fsa.start_end_states,
                                        dtype=tf.int32),
        # in -log space
        "am_scores": tf.constant(
            numpy.random.normal(size=(seq_len, n_batch, n_classes)),
            dtype=tf.float32),
        "float_idx": tf.ones((seq_len, n_batch), dtype=tf.float32),
    }

    print("Construct call...")
    fwdbwd, obs_scores = fast_baum_welch(**bw_kwargs)
    print("Done.")

    print("Eval:")
    # Fetch both outputs so the complete op is executed.
    _fwdbwd_value, score = session.run([fwdbwd, obs_scores])
    print("score:", score)