  def RNN(self, inputs):
    """Run a (bi)directional RNN over the inputs; return the top-layer
    states and the final states."""

    input_size = inputs.get_shape().as_list()[-1]
    cell = self.recur_cell(self._config, input_size=input_size, moving_params=self.moving_params)
    lengths = tf.reshape(tf.to_int64(self.sequence_lengths), [-1])

    # Disable dropout at eval time (i.e. when moving_params is set)
    if self.moving_params is None:
      ff_keep_prob = self.ff_keep_prob
      recur_keep_prob = self.recur_keep_prob
    else:
      ff_keep_prob = 1
      recur_keep_prob = 1

    if self.recur_bidir:
      top_recur, fw_recur, bw_recur = rnn.dynamic_bidirectional_rnn(
        cell, cell, inputs, lengths,
        ff_keep_prob=ff_keep_prob, recur_keep_prob=recur_keep_prob,
        dtype=tf.float32)
      # Each final state packs [cell; output]; keep only the output halves
      fw_cell, fw_out = tf.split(axis=1, num_or_size_splits=2, value=fw_recur)
      bw_cell, bw_out = tf.split(axis=1, num_or_size_splits=2, value=bw_recur)
      end_recur = tf.concat(axis=1, values=[fw_out, bw_out])
      top_recur.set_shape([tf.Dimension(None), tf.Dimension(None), tf.Dimension(2*self.recur_size)])
    else:
      top_recur, end_recur = rnn.dynamic_rnn(
        cell, inputs, lengths,
        ff_keep_prob=ff_keep_prob, recur_keep_prob=recur_keep_prob,
        dtype=tf.float32)
      top_recur.set_shape([tf.Dimension(None), tf.Dimension(None), tf.Dimension(self.recur_size)])
    return top_recur, end_recur
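  # Illustrative usage sketch (hypothetical, not from the source): with
  # recur_bidir=True, a hypothetical `model` instance, and LSTM-style final
  # states of shape [batch, 2*recur_size], the call would look like:
  #
  #   top_recur, end_recur = model.RNN(embeddings)
  #   # top_recur: [batch_size, bucket_size, 2*recur_size] per-token states
  #   # end_recur: [batch_size, 2*recur_size] concatenated final fw/bw outputs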
  def RNN(self, inputs, fw_keep_mask=None, bw_keep_mask=None):
    """Run a (bi)directional RNN over the inputs; at train time, also add
    orthogonality, recurrent-activation, and covariance regularizers."""

    batch_size = tf.shape(inputs)[0]
    input_size = inputs.get_shape().as_list()[-1]
    cell = self.recur_cell(self._config, input_size=input_size, moving_params=self.moving_params)
    lengths = tf.reshape(tf.to_int64(self.sequence_lengths), [-1])

    # Disable recurrent dropout at eval time (i.e. when moving_params is set)
    if self.moving_params is None:
      recur_keep_prob = self.recur_keep_prob
    else:
      recur_keep_prob = 1

    def add_ortho_losses(scope_name):
      """Penalize each recur_size-wide block W of the recurrent weight
      matrix for deviating from orthonormality: l2_loss(W^T W - I)."""
      with tf.variable_scope(scope_name, reuse=True):
        matrix = tf.get_variable('Weights')
        n_splits = matrix.get_shape().as_list()[-1] // self.recur_size
        I = tf.diag(tf.ones([self.recur_size]))
        for W in tf.split(axis=1, num_or_size_splits=n_splits, value=matrix):
          WTWmI = tf.matmul(W, W, transpose_a=True) - I
          tf.add_to_collection('ortho_losses', tf.nn.l2_loss(WTWmI))

    if self.recur_bidir:
      top_recur, fw_recur, bw_recur = rnn.dynamic_bidirectional_rnn(
        cell, cell, inputs, lengths,
        fw_keep_mask=fw_keep_mask, bw_keep_mask=bw_keep_mask,
        recur_keep_prob=recur_keep_prob, dtype=tf.float32)
      # Each final state packs [cell; output]; keep only the output halves
      fw_cell, fw_out = tf.split(axis=1, num_or_size_splits=2, value=fw_recur)
      bw_cell, bw_out = tf.split(axis=1, num_or_size_splits=2, value=bw_recur)
      end_recur = tf.concat(axis=1, values=[fw_out, bw_out])
      top_recur.set_shape([tf.Dimension(None), tf.Dimension(None), tf.Dimension(2*self.recur_size)])
      if self.moving_params is None:
        for direction in ('FW', 'BW'):
          if self.recur_cell.__name__ != 'GRUCell':
            add_ortho_losses("BiRNN_%s/%s/Linear" % (direction, self.recur_cell.__name__))
          else:
            # GRU cells keep their weights in two separate scopes
            for name in ['Gates', 'Candidate']:
              add_ortho_losses("BiRNN_%s/GRUCell/%s/Linear" % (direction, name))
    else:
      top_recur, end_recur = rnn.dynamic_rnn(
        cell, inputs, lengths,
        ff_keep_mask=fw_keep_mask,
        recur_keep_prob=recur_keep_prob, dtype=tf.float32)
      top_recur.set_shape([tf.Dimension(None), tf.Dimension(None), tf.Dimension(self.recur_size)])
      if self.moving_params is None:
        if self.recur_cell.__name__ != 'GRUCell':
          add_ortho_losses("%s/Linear" % self.recur_cell.__name__)
        else:
          for name in ['Gates', 'Candidate']:
            add_ortho_losses("GRUCell/%s/Linear" % name)

    if self.moving_params is None:
      tf.add_to_collection('recur_losses', self.recur_loss(top_recur))
      tf.add_to_collection('covar_losses', self.covar_loss(top_recur))
    return top_recur, end_recur
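  # Minimal standalone sketch of the orthogonality regularizer accumulated
  # above (assumptions: TF 1.x API; a recurrent weight matrix of shape
  # [rows, n_splits*recur_size]; `ortho_penalty` is a hypothetical name).
  # Each recur_size-wide block W contributes l2_loss(W^T W - I), which is
  # zero exactly when W's columns are orthonormal:
  #
  #   import tensorflow as tf
  #
  #   def ortho_penalty(matrix, recur_size):
  #     n_splits = matrix.get_shape().as_list()[-1] // recur_size
  #     I = tf.diag(tf.ones([recur_size]))
  #     losses = []
  #     for W in tf.split(axis=1, num_or_size_splits=n_splits, value=matrix):
  #       losses.append(tf.nn.l2_loss(tf.matmul(W, W, transpose_a=True) - I))
  #     return tf.add_n(losses)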