Example #1
 def RNN(self, inputs):
   """"""
   
   input_size = inputs.get_shape().as_list()[-1]
   cell = self.recur_cell(self._config, input_size=input_size, moving_params=self.moving_params)
   lengths = tf.reshape(tf.to_int64(self.sequence_lengths), [-1])
   
   if self.moving_params is None:
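     # training mode: moving_params is unset, so keep the configured dropout probabilities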
     ff_keep_prob = self.ff_keep_prob
     recur_keep_prob = self.recur_keep_prob
   else:
     ff_keep_prob = 1
     recur_keep_prob = 1
   
   if self.recur_bidir:
     top_recur, fw_recur, bw_recur = rnn.dynamic_bidirectional_rnn(cell, cell, inputs,
                                                                   lengths,
                                                                   ff_keep_prob=ff_keep_prob,
                                                                   recur_keep_prob=recur_keep_prob,
                                                                   dtype=tf.float32)
     fw_cell, fw_out = tf.split(axis=1, num_or_size_splits=2, value=fw_recur)
     bw_cell, bw_out = tf.split(axis=1, num_or_size_splits=2, value=bw_recur)
     end_recur = tf.concat(axis=1, values=[fw_out, bw_out])
     top_recur.set_shape([tf.Dimension(None), tf.Dimension(None), tf.Dimension(2*self.recur_size)])
   else:
     top_recur, end_recur = rnn.dynamic_rnn(cell, inputs,
                                            lengths,
                                            ff_keep_prob=ff_keep_prob,
                                            recur_keep_prob=recur_keep_prob,
                                            dtype=tf.float32)
     top_recur.set_shape([tf.Dimension(None), tf.Dimension(None), tf.Dimension(self.recur_size)])
   return top_recur, end_recur
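
Note that rnn above appears to be a project-local helper module rather than tf.nn: its dynamic_rnn / dynamic_bidirectional_rnn variants accept dropout arguments (ff_keep_prob, recur_keep_prob) that the stock TensorFlow functions do not take. The handling of the final states can be reproduced in isolation; a minimal sketch with dummy tensors (TF 1.x API; batch_size and recur_size are made-up values for the sketch):

import tensorflow as tf  # TF 1.x API, as in the example above

batch_size, recur_size = 4, 8                      # illustrative sizes only
fw_recur = tf.zeros([batch_size, 2 * recur_size])  # stand-in for the forward final state
bw_recur = tf.zeros([batch_size, 2 * recur_size])  # stand-in for the backward final state

# Each final state packs the cell and output halves side by side along axis 1;
# drop the cell halves and concatenate the outputs of both directions.
_, fw_out = tf.split(axis=1, num_or_size_splits=2, value=fw_recur)
_, bw_out = tf.split(axis=1, num_or_size_splits=2, value=bw_recur)
end_recur = tf.concat(axis=1, values=[fw_out, bw_out])
print(end_recur.get_shape())                       # (4, 16), i.e. [batch_size, 2 * recur_size]
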
Example #2
    def RNN(self, inputs, fw_keep_mask=None, bw_keep_mask=None):
        """"""

        batch_size = tf.shape(inputs)[0]
        input_size = inputs.get_shape().as_list()[-1]
        cell = self.recur_cell(self._config,
                               input_size=input_size,
                               moving_params=self.moving_params)
        lengths = tf.reshape(tf.to_int64(self.sequence_lengths), [-1])

        if self.moving_params is None:
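            # training mode: moving_params is unset, so keep the configured recurrent dropout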
            recur_keep_prob = self.recur_keep_prob
        else:
            recur_keep_prob = 1

        if self.recur_bidir:
            top_recur, fw_recur, bw_recur = rnn.dynamic_bidirectional_rnn(
                cell,
                cell,
                inputs,
                lengths,
                fw_keep_mask=fw_keep_mask,
                bw_keep_mask=bw_keep_mask,
                recur_keep_prob=recur_keep_prob,
                dtype=tf.float32)
            fw_cell, fw_out = tf.split(axis=1, num_or_size_splits=2, value=fw_recur)  # final state packs [cell; output]
            bw_cell, bw_out = tf.split(axis=1, num_or_size_splits=2, value=bw_recur)
            end_recur = tf.concat(axis=1, values=[fw_out, bw_out])  # keep only the output halves of both directions
            top_recur.set_shape([
                tf.Dimension(None),
                tf.Dimension(None),
                tf.Dimension(2 * self.recur_size)
            ])
            if self.moving_params is None:
                # While training, penalize each recur_size-wide block W of the recurrent
                # weight matrices for deviating from orthonormality: l2_loss(W^T W - I).
                for direction in ('FW', 'BW'):
                    if self.recur_cell.__name__ != 'GRUCell':
                        with tf.variable_scope("BiRNN_%s/%s/Linear" % (direction, self.recur_cell.__name__),
                                               reuse=True):
                            matrix = tf.get_variable('Weights')
                            n_splits = matrix.get_shape().as_list()[-1] // self.recur_size
                            I = tf.diag(tf.ones([self.recur_size]))
                            for W in tf.split(axis=1, num_or_size_splits=n_splits, value=matrix):
                                WTWmI = tf.matmul(W, W, transpose_a=True) - I
                                tf.add_to_collection('ortho_losses', tf.nn.l2_loss(WTWmI))
                    else:
                        for name in ['Gates', 'Candidate']:
                            with tf.variable_scope("BiRNN_%s/GRUCell/%s/Linear" % (direction, name),
                                                   reuse=True):
                                matrix = tf.get_variable('Weights')
                                n_splits = matrix.get_shape().as_list()[-1] // self.recur_size
                                I = tf.diag(tf.ones([self.recur_size]))
                                for W in tf.split(axis=1, num_or_size_splits=n_splits, value=matrix):
                                    WTWmI = tf.matmul(W, W, transpose_a=True) - I
                                    tf.add_to_collection('ortho_losses', tf.nn.l2_loss(WTWmI))
        else:
            top_recur, end_recur = rnn.dynamic_rnn(
                cell,
                inputs,
                lengths,
                ff_keep_mask=fw_keep_mask,
                recur_keep_prob=recur_keep_prob,
                dtype=tf.float32)
            top_recur.set_shape([
                tf.Dimension(None),
                tf.Dimension(None),
                tf.Dimension(self.recur_size)
            ])
            if self.moving_params is None:
                # Same orthogonality penalty for the unidirectional recurrent weights.
                if self.recur_cell.__name__ != 'GRUCell':
                    with tf.variable_scope("%s/Linear" % (self.recur_cell.__name__),
                                           reuse=True):
                        matrix = tf.get_variable('Weights')
                        n_splits = matrix.get_shape().as_list()[-1] // self.recur_size
                        I = tf.diag(tf.ones([self.recur_size]))
                        for W in tf.split(axis=1, num_or_size_splits=n_splits, value=matrix):
                            WTWmI = tf.matmul(W, W, transpose_a=True) - I
                            tf.add_to_collection('ortho_losses', tf.nn.l2_loss(WTWmI))
                else:
                    for name in ['Gates', 'Candidate']:
                        with tf.variable_scope("GRUCell/%s/Linear" % (name),
                                               reuse=True):
                            matrix = tf.get_variable('Weights')
                            n_splits = matrix.get_shape().as_list()[-1] // self.recur_size
                            I = tf.diag(tf.ones([self.recur_size]))
                            for W in tf.split(axis=1, num_or_size_splits=n_splits, value=matrix):
                                WTWmI = tf.matmul(W, W, transpose_a=True) - I
                                tf.add_to_collection('ortho_losses', tf.nn.l2_loss(WTWmI))

        if self.moving_params is None:
            tf.add_to_collection('recur_losses', self.recur_loss(top_recur))
            tf.add_to_collection('covar_losses', self.covar_loss(top_recur))
        return top_recur, end_recur
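
The loops that populate the 'ortho_losses' collection implement an orthogonality penalty: every recur_size-wide column block W of a recurrent weight matrix is pushed toward orthonormality, with tf.nn.l2_loss(W^T W - I) equal to half the squared Frobenius norm of the deviation. A minimal standalone sketch of the same computation, using a dummy matrix in place of the reused 'Weights' variable:

import tensorflow as tf  # TF 1.x API, as in the example above

recur_size = 8
matrix = tf.ones([recur_size, 4 * recur_size])  # dummy stand-in for a recurrent weight matrix

n_splits = matrix.get_shape().as_list()[-1] // recur_size
I = tf.diag(tf.ones([recur_size]))
ortho_loss = tf.constant(0.)
for W in tf.split(axis=1, num_or_size_splits=n_splits, value=matrix):
    WTWmI = tf.matmul(W, W, transpose_a=True) - I  # deviation from orthonormality
    ortho_loss += tf.nn.l2_loss(WTWmI)             # 0.5 * ||W^T W - I||_F^2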