def fast_baum_welch_staircase(am_scores, seq_lens, **opts):
    """
    Convenience wrapper around :func:`fast_baum_welch`:
    gets the FSA tensors (edges, weights, start/end states) from
    :func:`tf_fast_bw_fsa_staircase` and a time-major mask built from ``seq_lens``,
    and forwards everything together with ``am_scores``.

    :param tf.Tensor am_scores: (time, batch, dim), in -log space
    :param tf.Tensor seq_lens: (batch,) -> values in [1, ..., dim-1]
    :param opts: passed to :func:`Fsa.fast_bw_fsa_staircase`
    :return: (fwdbwd, obs_scores), fwdbwd is (time, batch, dim), obs_scores is (time, batch), in -log space
    :rtype: (tf.Tensor, tf.Tensor)
    """
    from TFUtil import sequence_mask_time_major
    # The FSA tensors are created first, the mask afterwards (same order as before).
    fsa_edges, fsa_weights, fsa_start_end = tf_fast_bw_fsa_staircase(seq_lens, **opts)
    return fast_baum_welch(
        am_scores=am_scores,
        edges=fsa_edges,
        weights=fsa_weights,
        start_end_states=fsa_start_end,
        float_idx=sequence_mask_time_major(seq_lens))
def __init__(self, unit="lstm", bidirectional=False, direction=None, input_projection=True, **kwargs):
    """
    Creates the recurrent cell(s) inside the "rec" variable scope, applies them to the
    input sequence, stores the result as time-major output (time axis 0, batch axis 1)
    and registers all trainable variables of that scope in ``self.params``.

    :param str unit: the RNNCell/etc name, e.g. "nativelstm". "lstm" and "lstmp" are mapped to "nativelstm"
    :param bool bidirectional: whether we should combine a forward and backward cell
    :param int|None direction: None|1 -> forward, -1 -> backward. must not be combined with bidirectional
    :param bool input_projection: True -> input is multiplied with matrix. False only works if same input dim
    :param dict[str] kwargs: passed on to base class
    """
    super(RecLayer, self).__init__(**kwargs)
    from tensorflow.python.ops import rnn, rnn_cell
    import tensorflow.contrib.rnn as rnn_contrib
    import TFNativeOp
    from TFUtil import swapaxes, dot, sequence_mask_time_major, directed

    if unit in ("lstmp", "lstm"):
        # Candidate LSTM implementations:
        #  * BasicLSTM: official TF, pure TF implementation.
        #  * LSTMBlockFused: from tf.contrib.rnn (CPU and GPU), should be much faster than BasicLSTM.
        #  * NativeLSTM: our own native LSTM (CPU and GPU), should be faster than LSTMBlockFused.
        # The fastest one (NativeLSTM) is the default.
        # Their parameter representations are currently not compatible with each other.
        unit = "nativelstm"
    if direction is not None:
        # An explicit direction only makes sense for a single (non-bidirectional) cell.
        assert not bidirectional
        assert direction in (-1, 1)
    if not self._rnn_cells_dict:
        self._create_rnn_cells_dict()
    cell_factory = self._rnn_cells_dict[unit.lower()]

    # Cell kinds which are driven through the TF rnn API / fused-cell call.
    tf_cell_types = (rnn_cell.RNNCell, rnn_contrib.FusedRNNCell)
    weights_init = tf.contrib.layers.xavier_initializer(seed=self.network.random.randint(2**31))

    with tf.variable_scope("rec", initializer=weights_init) as scope:
        assert isinstance(scope, tf.VariableScope)
        prefix = scope.name + "/"  # e.g. "layer1/rec/"

        # In the bidirectional case, each direction gets half of the output dim.
        cell_dim = self.output.dim
        if bidirectional:
            assert cell_dim % 2 == 0
            cell_dim //= 2
        forward_cell = cell_factory(cell_dim)
        assert isinstance(forward_cell, tf_cell_types + (TFNativeOp.RecSeqCellOp,))  # e.g. BasicLSTMCell
        backward_cell = cell_factory(cell_dim) if bidirectional else None

        seq_input = self.input_data.placeholder  # (batch,time,dim) or (time,batch,dim)
        if not self.input_data.is_time_major:
            assert self.input_data.batch_dim_axis == 0
            assert self.input_data.time_dim_axis == 1
            seq_input = swapaxes(seq_input, 0, 1)  # (time,batch,[dim])
        lengths = self.input_data.size_placeholder[0]
        reverse_time = (direction == -1)

        if isinstance(forward_cell, tf_cell_types):
            assert not self.input_data.sparse
            assert input_projection
            if reverse_time:
                # Backward direction: feed the time-reversed sequence, undo the reversal on the output below.
                seq_input = tf.reverse_sequence(seq_input, seq_lengths=lengths, batch_dim=1, seq_dim=0)
            if isinstance(forward_cell, rnn_cell.RNNCell):  # e.g. BasicLSTMCell
                dyn_rnn_opts = dict(inputs=seq_input, time_major=True, sequence_length=lengths, dtype=tf.float32)
                if bidirectional:
                    # Each part is (time,batch,ydim/2).
                    (out_fw, out_bw), _ = rnn.bidirectional_dynamic_rnn(
                        cell_fw=forward_cell, cell_bw=backward_cell, **dyn_rnn_opts)
                    seq_output = tf.concat(2, (out_fw, out_bw))  # (time,batch,ydim)
                else:
                    # (time,batch,ydim)
                    seq_output, _ = rnn.dynamic_rnn(cell=forward_cell, **dyn_rnn_opts)
            elif isinstance(forward_cell, rnn_contrib.FusedRNNCell):  # e.g. LSTMBlockFusedCell
                if bidirectional:
                    raise NotImplementedError
                # (time,batch,ydim)
                seq_output, _ = forward_cell(inputs=seq_input, sequence_length=lengths, dtype=tf.float32)
            else:
                raise Exception("invalid type: %s" % type(forward_cell))
            if reverse_time:
                seq_output = tf.reverse_sequence(seq_output, seq_lengths=lengths, batch_dim=1, seq_dim=0)
        elif isinstance(forward_cell, TFNativeOp.RecSeqCellOp):
            assert not bidirectional
            cell_in_dim = forward_cell.n_input_dim
            if input_projection:
                proj_matrix = tf.get_variable(
                    name="W", shape=(self.input_data.dim, cell_in_dim), dtype=tf.float32)
                if self.input_data.sparse:
                    seq_input = tf.nn.embedding_lookup(proj_matrix, seq_input)
                else:
                    seq_input = dot(seq_input, proj_matrix)
            else:
                # Without projection, the input must be dense and already have the cell input dim.
                assert not self.input_data.sparse
                assert self.input_data.dim == cell_in_dim
            bias = tf.get_variable(
                name="b", shape=(cell_in_dim,), dtype=tf.float32,
                initializer=tf.constant_initializer(0.0))
            seq_input += bias
            time_mask = sequence_mask_time_major(lengths, maxlen=tf.shape(seq_input)[0])
            # Input and mask are passed through `directed`, and so is the cell output afterwards.
            seq_output = directed(
                forward_cell(inputs=directed(seq_input, direction), index=directed(time_mask, direction)),
                direction)
        else:
            raise Exception("invalid type: %s" % type(forward_cell))

        self.output.time_dim_axis = 0
        self.output.batch_dim_axis = 1
        self.output.placeholder = seq_output

        trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=prefix)
        assert trainable
        # Param key: variable name without the scope prefix and without its last two characters
        # (presumably the ":0" output suffix).
        self.params.update({var.name[len(prefix):-2]: var for var in trainable})