Example #1
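The five snippets below exercise `snt.SkipConnectionCore` from Sonnet 1 on TensorFlow 1.x. They are excerpts from larger test and model classes, so they assume a common preamble; a minimal one would be:

import numpy as np
import sonnet as snt
import tensorflow as tf

This first example checks the wrapped core's computation against a NumPy reference.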
  def testComputation(self):
    inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
    prev_state = tf.placeholder(tf.float32,
                                shape=[self.batch_size, self.in_size])

    vanilla_rnn = snt.VanillaRNN(name="rnn", hidden_size=self.in_size)
    residual = snt.SkipConnectionCore(vanilla_rnn, name="skip")

    output, new_state = residual(inputs, prev_state)
    in_to_hid = vanilla_rnn.in_to_hidden_variables
    hid_to_hid = vanilla_rnn.hidden_to_hidden_variables
    with self.test_session() as sess:
      # With random data, check the TF calculation matches the Numpy version.
      input_data = np.random.randn(self.batch_size, self.in_size)
      prev_state_data = np.random.randn(self.batch_size, self.in_size)
      tf.global_variables_initializer().run()

      fetches = [output, new_state, in_to_hid[0], in_to_hid[1],
                 hid_to_hid[0], hid_to_hid[1]]
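      # NB: `output` is rebound below, from the output tensor to the list of
      # fetched NumPy arrays.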
      output = sess.run(fetches,
                        {inputs: input_data, prev_state: prev_state_data})
    output_v, new_state_v, in_to_hid_w, in_to_hid_b = output[:4]
    hid_to_hid_w, hid_to_hid_b = output[4:]

    real_in_to_hid = np.dot(input_data, in_to_hid_w) + in_to_hid_b
    real_hid_to_hid = np.dot(prev_state_data, hid_to_hid_w) + hid_to_hid_b
    vanilla_output = np.tanh(real_in_to_hid + real_hid_to_hid)
    skip_output = np.concatenate((input_data, vanilla_output), -1)

    self.assertAllClose(skip_output, output_v)
    self.assertAllClose(vanilla_output, new_state_v)
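The test confirms that `SkipConnectionCore` concatenates the input onto the wrapped core's output along the last axis, while the recurrent state passes through unchanged. A minimal standalone sketch of the same wiring (the sizes here are illustrative, not from the test):

batch_size, in_size = 2, 3
inputs = tf.placeholder(tf.float32, shape=[batch_size, in_size])
prev_state = tf.placeholder(tf.float32, shape=[batch_size, in_size])
core = snt.SkipConnectionCore(snt.VanillaRNN(hidden_size=in_size))
output, next_state = core(inputs, prev_state)
# output has shape [batch_size, 2 * in_size]; next_state has shape
# [batch_size, in_size], exactly what the wrapped VanillaRNN produces.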
Example #2
    def __init__(self,
                 num_embedding,
                 num_hidden,
                 lstm_depth,
                 output_size,
                 use_dynamic_rnn=True,
                 use_skip_connections=True,
                 name="text_model"):
        """Constructs a `TextModel`.

    Args:
      num_embedding: Size of embedding representation, used directly after the
        one-hot encoded input.
      num_hidden: Number of hidden units in each LSTM layer.
      lstm_depth: Number of LSTM layers.
      output_size: Size of the output layer on top of the DeepRNN.
      use_dynamic_rnn: Whether to use dynamic RNN unrolling. If `False`, it uses
        static unrolling. Default is `True`.
      use_skip_connections: Whether to use skip connections in the
        `snt.DeepRNN`. Default is `True`.
      name: Name of the module.
    """

        super(TextModel, self).__init__(name=name)

        self._num_embedding = num_embedding
        self._num_hidden = num_hidden
        self._lstm_depth = lstm_depth
        self._output_size = output_size
        self._use_dynamic_rnn = use_dynamic_rnn
        self._use_skip_connections = use_skip_connections

        with self._enter_variable_scope():
            self._embed_module = snt.Linear(self._num_embedding,
                                            name="linear_embed")
            self._output_module = snt.Linear(self._output_size,
                                             name="linear_output")
            self._subcores = [
                snt.LSTM(self._num_hidden, name="lstm_{}".format(i))
                for i in range(self._lstm_depth)
            ]
            if self._use_skip_connections:
                skips = []
                current_input_shape = self._num_embedding
                for lstm in self._subcores:
                    input_shape = tf.TensorShape([current_input_shape])
                    skip = snt.SkipConnectionCore(lstm,
                                                  input_shape=input_shape,
                                                  name="skip_{}".format(
                                                      lstm.module_name))
                    skips.append(skip)
                    # SkipConnectionCore concatenates the input with the output, so the
                    # dimensionality increases with depth.
                    current_input_shape += self._num_hidden
                self._subcores = skips
            self._core = snt.DeepRNN(self._subcores,
                                     skip_connections=False,
                                     name="deep_lstm")
Example #3
  def testOutputSize(self):
    inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
    prev_state = tf.placeholder(
        tf.float32, shape=[self.batch_size, self.hidden_size])
    vanilla_rnn = snt.VanillaRNN(self.hidden_size)
    skip_wrapper = snt.SkipConnectionCore(vanilla_rnn, name="skip")

    with self.assertRaises(ValueError):
      _ = skip_wrapper.output_size

    skip_wrapper(inputs, prev_state)
    self.assertAllEqual([self.in_size + self.hidden_size],
                        skip_wrapper.output_size.as_list())

    skip_wrapper = snt.SkipConnectionCore(
        vanilla_rnn, input_shape=(self.in_size,), name="skip")
    self.assertAllEqual([self.in_size + self.hidden_size],
                        skip_wrapper.output_size.as_list())
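As this test shows, `output_size` is only defined once the shape of the skip connection is known: either after the wrapped core has been connected to inputs, or when `input_shape` is passed to the constructor. Accessing it before that raises a `ValueError`.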
Example #4
  def testShape(self):
    inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
    prev_state = tf.placeholder(
        tf.float32, shape=[self.batch_size, self.hidden_size])
    vanilla_rnn = snt.VanillaRNN(self.hidden_size)
    skip_wrapper = snt.SkipConnectionCore(vanilla_rnn, name="skip")
    output, next_state = skip_wrapper(inputs, prev_state)
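    # np.ndarray is used here only as a cheap shape carrier for
    # assertShapeEqual; its (uninitialized) values are never read.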
    output_shape = np.ndarray((self.batch_size,
                               self.in_size + self.hidden_size))
    state_shape = np.ndarray((self.batch_size, self.hidden_size))

    self.assertShapeEqual(output_shape, output)
    self.assertShapeEqual(state_shape, next_state)
Example #5
    def testHeterogeneousState(self):
        """Checks that the shape and type of the initial state are preserved."""

        core = HeterogeneousStateCore(name="rnn", hidden_size=self.hidden_size)
        skip_wrapper = snt.SkipConnectionCore(core, name="skip")

        core_state = core.initial_state(self.batch_size)
        skip_state = skip_wrapper.initial_state(self.batch_size)

        self.assertEqual(core_state[0].shape.as_list(),
                         skip_state[0].shape.as_list())
        self.assertEqual(core_state[1].shape.as_list(),
                         skip_state[1].shape.as_list())
        self.assertEqual(core_state[0].dtype, skip_state[0].dtype)
        self.assertEqual(core_state[1].dtype, skip_state[1].dtype)
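`HeterogeneousStateCore` is a test helper that is not included in the excerpt. A plausible minimal reconstruction, assuming a dummy core whose state tuple mixes dtypes (this class body is an assumption, not the original):

class HeterogeneousStateCore(snt.RNNCore):
    """Sketch: a dummy core whose state is a (float32, int32) tuple."""

    def __init__(self, hidden_size, name="heterogeneous_core"):
        super(HeterogeneousStateCore, self).__init__(name=name)
        self._hidden_size = hidden_size

    @property
    def state_size(self):
        return (tf.TensorShape([self._hidden_size]),
                tf.TensorShape([self._hidden_size]))

    @property
    def output_size(self):
        return tf.TensorShape([self._hidden_size])

    def initial_state(self, batch_size, **unused_kwargs):
        # Heterogeneous on purpose: the two state parts differ in dtype.
        return (tf.zeros([batch_size, self._hidden_size], dtype=tf.float32),
                tf.zeros([batch_size, self._hidden_size], dtype=tf.int32))

    def _build(self, inputs, prev_state):
        return inputs, prev_state

The assertions then verify that `SkipConnectionCore.initial_state` simply delegates to the wrapped core, preserving both the shape and the dtype of each state component.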