# Example #1
 def backward_cell(self):
     """Construct the multilayer RNN cell for the backward-direction RNN.

     Builds the cell inside a dedicated "backward_cell" variable scope so
     its weights are namespaced separately from the forward RNN's.
     """
     with tf.compat.v1.variable_scope("backward_cell") as scope:
         return graph_utils.create_multilayer_cell(
             self.rnn_cell,
             scope,
             self.dim,
             self.num_layers,
             self.input_keep,
             self.output_keep,
             variational_recurrent=self.variational_recurrent_dropout)
# Example #2
    def char_channel_embeddings(self, channel_inputs):
        """
        Generate token representations by character composition.

        :param channel_inputs: batch input char indices
                [[batch, token_size], [batch, token_size], ...]
        :return: embeddings_char [source_vocab_size, char_channel_dim]
        """
        # Stack all channel inputs along the batch axis, then slice out one
        # [batch] tensor per character position.
        stacked = tf.concat(axis=0, values=channel_inputs)
        char_columns = tf.split(axis=1,
                                num_or_size_splits=self.max_source_token_size,
                                value=stacked)
        char_ids = [tf.squeeze(col, 1) for col in char_columns]
        char_vectors = [
            tf.nn.embedding_lookup(params=self.char_embeddings(), ids=ids)
            for ids in char_ids
        ]

        if self.sc_char_composition != 'rnn':
            raise NotImplementedError

        # Run a character-level RNN; reuse its variables after the first call.
        with tf.compat.v1.variable_scope("encoder_char_rnn",
                               reuse=self.char_rnn_vars) as scope:
            rnn_cell = graph_utils.create_multilayer_cell(
                self.sc_char_rnn_cell, scope,
                self.sc_char_dim, self.sc_char_rnn_num_layers,
                variational_recurrent=self.variational_recurrent_dropout)
            rnn_outputs, rnn_states = graph_utils.RNNModel(
                rnn_cell, char_vectors, dtype=tf.float32)
            self.char_rnn_vars = True

        # Use the final RNN state as the token representation and split it
        # back into one tensor per original input channel.
        final_state = tf.reshape(rnn_states[-1],
                                 [len(channel_inputs), -1, self.sc_char_dim])
        per_channel = tf.split(axis=0,
                               num_or_size_splits=len(channel_inputs),
                               value=final_state)
        return [tf.squeeze(chunk, 0) for chunk in per_channel]
# Example #3
 def horizontal_cell(self):
     """Cell that controls transition from left sibling to right sibling.

     Returns the (cell, scope) pair so callers can reuse the variable scope.
     """
     with tf.compat.v1.variable_scope("horizontal_cell") as scope:
         cell = graph_utils.create_multilayer_cell(
             self.rnn_cell, scope, self.dim, self.num_layers,
             self.tg_input_keep, self.tg_output_keep)
     return cell, scope
# Example #4
 def vertical_cell(self):
     """Cell that controls transition from parent to child.

     Returns the (cell, scope) pair so callers can reuse the variable scope.
     """
     with tf.compat.v1.variable_scope("vertical_cell") as scope:
         return (graph_utils.create_multilayer_cell(self.rnn_cell,
                                                    scope,
                                                    self.dim,
                                                    self.num_layers,
                                                    self.tg_input_keep,
                                                    self.tg_output_keep),
                 scope)
# Example #5
 def decoder_cell(self):
     """Construct the multilayer RNN cell for the decoder.

     When CopyNet is enabled the decoder input concatenates an extra
     context vector, so the cell's input dimension is doubled.
     """
     input_size = self.dim * 2 if self.copynet else self.dim
     scope_name = self.scope + "_decoder_cell"
     with tf.compat.v1.variable_scope(scope_name) as scope:
         decoder_cell = graph_utils.create_multilayer_cell(
             self.rnn_cell, scope, self.dim, self.num_layers,
             self.input_keep, self.output_keep,
             variational_recurrent=self.variational_recurrent_dropout,
             input_dim=input_size)
     return decoder_cell