Example #1
def fast_baum_welch_staircase(am_scores, seq_lens, **opts):
  """
  :param tf.Tensor am_scores: (time, batch, dim), in -log space
  :param tf.Tensor seq_lens: (batch,) -> values in [1, ..., dim-1]
  :param opts: passed to :func:`Fsa.fast_bw_fsa_staircase`
  :return: (fwdbwd, obs_scores), fwdbwd is (time, batch, dim), obs_scores is (time, batch), in -log space
  :rtype: (tf.Tensor, tf.Tensor)
  """
  from TFUtil import sequence_mask_time_major
  edges, weights, start_end_states = tf_fast_bw_fsa_staircase(seq_lens, **opts)
  float_idx = sequence_mask_time_major(seq_lens)
  return fast_baum_welch(
    am_scores=am_scores, edges=edges, weights=weights, start_end_states=start_end_states, float_idx=float_idx)
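A minimal usage sketch (not part of the original example; it assumes TF1-style graph mode and that the function above is importable, e.g. from TFNativeOp as in RETURNN; shapes follow the docstring):

import tensorflow as tf
from TFNativeOp import fast_baum_welch_staircase  # assumption: RETURNN-style module layout

# Hypothetical shapes: (time, batch, dim) acoustic-model scores, already in -log space.
am_scores = tf.placeholder(tf.float32, shape=(None, None, 10), name="am_scores")
seq_lens = tf.placeholder(tf.int32, shape=(None,), name="seq_lens")  # (batch,)
fwdbwd, obs_scores = fast_baum_welch_staircase(am_scores, seq_lens)
# fwdbwd: (time, batch, dim) soft alignment, obs_scores: (time, batch), both in -log space.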
Example #2
 def __init__(self, unit="lstm", bidirectional=False, direction=None, input_projection=True, **kwargs):
   """
   :param str unit: the RNNCell/etc name, e.g. "nativelstm". see comment below
   :param bool bidirectional: whether we should combine a forward and backward cell
   :param int|None direction: None|1 -> forward, -1 -> backward
   :param bool input_projection: True -> input is multiplied with a weight matrix; False works only if the input dim already matches the cell input dim
   :param dict[str] kwargs: passed on to base class
   """
   super(RecLayer, self).__init__(**kwargs)
   from tensorflow.python.ops import rnn, rnn_cell
   import tensorflow.contrib.rnn as rnn_contrib
   import TFNativeOp
   from TFUtil import swapaxes, dot, sequence_mask_time_major, directed
   if unit in ["lstmp", "lstm"]:
     # Some possible LSTM implementations are:
     # * BasicLSTM, via official TF, pure TF implementation
     # * LSTMBlockFused, via tf.contrib.rnn (both CPU and GPU). should be much faster than BasicLSTM
     # * NativeLSTM, our own native LSTM (both CPU and GPU). should be faster than LSTMBlockFused
     # We default to the fastest one, i.e. NativeLSTM.
     # Note that they are currently not compatible with each other, i.e. the parameters are represented differently.
     unit = "nativelstm"
   if direction is not None:
     assert not bidirectional
     assert direction in [-1, 1]
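   # Lazily build the mapping from (lower-cased) unit name to cell class.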
   if not self._rnn_cells_dict:
     self._create_rnn_cells_dict()
   rnn_cell_class = self._rnn_cells_dict[unit.lower()]
   with tf.variable_scope(
         "rec",
         initializer=tf.contrib.layers.xavier_initializer(
           seed=self.network.random.randint(2**31))) as scope:
     assert isinstance(scope, tf.VariableScope)
     scope_name_prefix = scope.name + "/"  # e.g. "layer1/rec/"
     n_hidden = self.output.dim
     if bidirectional:
       assert n_hidden % 2 == 0
       n_hidden //= 2
     cell_fw = rnn_cell_class(n_hidden)
     assert isinstance(cell_fw, (rnn_cell.RNNCell, rnn_contrib.FusedRNNCell, TFNativeOp.RecSeqCellOp))  # e.g. BasicLSTMCell
     if bidirectional:
       cell_bw = rnn_cell_class(n_hidden)
     else:
       cell_bw = None
     x = self.input_data.placeholder  # (batch,time,dim) or (time,batch,dim)
     if not self.input_data.is_time_major:
       assert self.input_data.batch_dim_axis == 0
       assert self.input_data.time_dim_axis == 1
       x = swapaxes(x, 0, 1)   # (time,batch,[dim])
     seq_len = self.input_data.size_placeholder[0]  # (batch,)
     if isinstance(cell_fw, (rnn_cell.RNNCell, rnn_contrib.FusedRNNCell)):
       assert not self.input_data.sparse
       assert input_projection
       if direction == -1:
         x = tf.reverse_sequence(x, seq_lengths=seq_len, batch_dim=1, seq_dim=0)
       if isinstance(cell_fw, rnn_cell.RNNCell):  # e.g. BasicLSTMCell
         if bidirectional:
           # Will get (time,batch,ydim/2).
           (y_fw, y_bw), _ = rnn.bidirectional_dynamic_rnn(
             cell_fw=cell_fw, cell_bw=cell_bw,
             inputs=x, time_major=True, sequence_length=seq_len,
             dtype=tf.float32)
           y = tf.concat((y_fw, y_bw), axis=2)  # (time,batch,ydim)
         else:
           # Will get (time,batch,ydim).
           y, _ = rnn.dynamic_rnn(cell=cell_fw, inputs=x, time_major=True, sequence_length=seq_len, dtype=tf.float32)
       elif isinstance(cell_fw, rnn_contrib.FusedRNNCell):  # e.g. LSTMBlockFusedCell
         if bidirectional:
           raise NotImplementedError
         # Will get (time,batch,ydim).
         y, _ = cell_fw(inputs=x, sequence_length=seq_len, dtype=tf.float32)
       else:
         raise Exception("invalid type: %s" % type(cell_fw))
       if direction == -1:
         y = tf.reverse_sequence(y, seq_lengths=seq_len, batch_dim=1, seq_dim=0)
     elif isinstance(cell_fw, TFNativeOp.RecSeqCellOp):
       assert not bidirectional
       if input_projection:
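         # Project (or, for sparse input, embed) the input to the native cell's expected input dim.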
         W = tf.get_variable(name="W", shape=(self.input_data.dim, cell_fw.n_input_dim), dtype=tf.float32)
         if self.input_data.sparse:
           x = tf.nn.embedding_lookup(W, x)
         else:
           x = dot(x, W)
       else:
         assert not self.input_data.sparse
         assert self.input_data.dim == cell_fw.n_input_dim
       b = tf.get_variable(name="b", shape=(cell_fw.n_input_dim,), dtype=tf.float32, initializer=tf.constant_initializer(0.0))
       x += b
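       # directed() reverses the time axis when direction == -1, so the native cell always
       # sees its input in forward order; the output is reversed back again below.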
       index = sequence_mask_time_major(seq_len, maxlen=tf.shape(x)[0])
       y = cell_fw(inputs=directed(x, direction), index=directed(index, direction))
       y = directed(y, direction)
     else:
       raise Exception("invalid type: %s" % type(cell_fw))
     self.output.time_dim_axis = 0
     self.output.batch_dim_axis = 1
     self.output.placeholder = y
     params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_name_prefix)
     assert params
     self.params.update({p.name[len(scope_name_prefix):-2]: p for p in params})
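A hedged usage sketch (not from the original source): in a RETURNN-style network config, a layer built by this constructor would typically be declared as a dict. The option names below ("class": "rec", "unit", "direction", "n_out", "from") are assumptions based on the constructor arguments above and common RETURNN conventions, not verified against a specific version.

# Hypothetical network definition: two unidirectional native LSTMs, one per direction.
network = {
  "lstm_fwd": {"class": "rec", "unit": "nativelstm", "direction": 1, "n_out": 512, "from": ["data"]},
  "lstm_bwd": {"class": "rec", "unit": "nativelstm", "direction": -1, "n_out": 512, "from": ["data"]},
  "output": {"class": "softmax", "loss": "ce", "from": ["lstm_fwd", "lstm_bwd"]},
}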