예제 #1
0
  def test_timeseries_classification_sequential_tf_rnn(self):
    """Trains a stacked LSTM -> GRU Sequential model on toy 2-class data.

    The model is built and compiled inside `base_layer.keras_style_scope()`
    from `rnn_cell` cells and fit for 15 epochs on 100 samples of shape
    (4, 10). It is then checked three ways:

    - the final recorded `val_acc` is above 0.7;
    - `evaluate` reproduces that recorded `val_acc`;
    - `predict` returns one 2-wide row per sample.
    """
    np.random.seed(1337)
    (inputs, labels), _ = testing_utils.get_test_data(
        train_samples=100, test_samples=0, input_shape=(4, 10), num_classes=2)
    labels = np_utils.to_categorical(labels)

    with base_layer.keras_style_scope():
      model = keras.models.Sequential()
      # The first layer returns the whole sequence so the second RNN layer
      # can consume it step by step.
      model.add(
          keras.layers.RNN(
              rnn_cell.LSTMCell(5),
              return_sequences=True,
              input_shape=inputs.shape[1:]))
      # The output GRU is sized to the one-hot label width, with softmax.
      model.add(
          keras.layers.RNN(
              rnn_cell.GRUCell(
                  labels.shape[-1], activation='softmax', dtype=tf.float32)))
      adam = keras.optimizer_v2.adam.Adam(0.005)
      model.compile(
          loss='categorical_crossentropy',
          optimizer=adam,
          metrics=['acc'],
          run_eagerly=testing_utils.should_run_eagerly())

    history = model.fit(
        inputs, labels,
        epochs=15,
        batch_size=10,
        validation_data=(inputs, labels),
        verbose=2)
    final_val_acc = history.history['val_acc'][-1]
    self.assertGreater(final_val_acc, 0.7)

    # Validation data equals the training data, so a fresh evaluate() should
    # match the last epoch's recorded validation accuracy.
    _, evaluated_acc = model.evaluate(inputs, labels)
    self.assertAlmostEqual(final_val_acc, evaluated_acc)

    predictions = model.predict(inputs)
    self.assertEqual(predictions.shape, (inputs.shape[0], 2))
예제 #2
0
  def testDeviceWrapper(self):
    """Checks DeviceWrapper tracking, config export, and device placement.

    The wrapper must expose the wrapped GRU cell as its only trackable child
    under the key "cell". `get_config()` must not raise. The output of a
    forward call must be placed on the requested CPU device.
    """
    inputs = tf.zeros([1, 3])
    state = tf.zeros([1, 3])
    gru = rnn_cell_impl.GRUCell(3)
    wrapped = rnn_cell_wrapper_v2.DeviceWrapper(gru, "/cpu:0")

    self.assertDictEqual({"cell": gru}, wrapped._trackable_children())
    wrapped.get_config()  # Only checks that serialization does not raise.

    outputs, _ = wrapped(inputs, state)
    self.assertIn("cpu:0", outputs.device.lower())
예제 #3
0
    def testDeviceWrapper(self):
        """Checks DeviceWrapper tracking, config export, and device placement.

        The wrapper must expose the wrapped GRU cell as its only trackable
        child under the key "cell". `get_config()` must not raise. The output
        of a forward call must be placed on the requested CPU device.
        """
        wrapper_type = rnn_cell_wrapper_v2.DeviceWrapper
        x = tf.zeros([1, 3])
        m = tf.zeros([1, 3])
        cell = rnn_cell_impl.GRUCell(3)
        wrapped_cell = wrapper_type(cell, "/cpu:0")
        # Introspect via `_trackable_children()`, like the other wrapper tests.
        # The exact dict comparison pins the key ("cell"), the identity of the
        # tracked object, and that it is the only child.
        self.assertDictEqual({"cell": cell},
                             wrapped_cell._trackable_children())
        wrapped_cell.get_config()  # Should not throw an error

        outputs, _ = wrapped_cell(x, m)
        self.assertIn("cpu:0", outputs.device.lower())
    def testResidualWrapper(self):
        """ResidualWrapper adds the input to the cell output and keeps state.

        A GRU cell with constant 0.5 initializers is run twice on the same
        all-ones input: once directly and once through the wrapper. The
        wrapped output must equal the plain output plus the input, and the
        new states of the two runs must match.
        """
        ones_input = tf.convert_to_tensor(
            np.array([[1., 1., 1.]]), dtype="float32")
        prev_state = tf.convert_to_tensor(
            np.array([[0.1, 0.1, 0.1]]), dtype="float32")
        gru = rnn_cell_impl.GRUCell(
            3,
            kernel_initializer=tf.compat.v1.constant_initializer(0.5),
            bias_initializer=tf.compat.v1.constant_initializer(0.5))
        plain_out, plain_state = gru(ones_input, prev_state)

        residual = rnn_cell_wrapper_v2.ResidualWrapper(gru)
        self.assertDictEqual({"cell": gru}, residual._trackable_children())
        residual.get_config()  # Only checks that serialization does not raise.

        res_out, res_state = residual(ones_input, prev_state)
        self.evaluate([tf.compat.v1.global_variables_initializer()])
        (plain_out_val, res_out_val,
         plain_state_val, res_state_val) = self.evaluate(
             [plain_out, res_out, plain_state, res_state])
        # Residual connection: wrapped output == plain output + input (ones).
        self.assertAllClose(res_out_val, plain_out_val + [1., 1., 1.])
        # The wrapper passes the cell's new state through unchanged.
        self.assertAllClose(plain_state_val, res_state_val)
예제 #5
0
  def testResidualWrapperWithSlice(self):
    """ResidualWrapper honours a custom residual_fn that slices the input.

    The 5-wide all-ones input is cut down to its first 3 columns before being
    added to the 3-unit GRU output. The wrapped output must therefore equal
    the plain output plus ones, while the state is passed through unchanged.
    """
    five_ones = tf.convert_to_tensor(
        np.array([[1., 1., 1., 1., 1.]]), dtype="float32")
    prev_state = tf.convert_to_tensor(
        np.array([[0.1, 0.1, 0.1]]), dtype="float32")
    gru = rnn_cell_impl.GRUCell(
        3, kernel_initializer=tf.compat.v1.constant_initializer(0.5),
        bias_initializer=tf.compat.v1.constant_initializer(0.5))
    plain_out, plain_state = gru(five_ones, prev_state)

    def add_first_three_columns(inp, out):
      # Keep every row but only columns 0..2 so the shape matches `out`.
      return tf.slice(inp, [0, 0], [-1, 3]) + out

    residual = rnn_cell_wrapper_v2.ResidualWrapper(
        gru, add_first_three_columns)
    res_out, res_state = residual(five_ones, prev_state)

    self.evaluate([tf.compat.v1.global_variables_initializer()])
    values = self.evaluate([plain_out, res_out, plain_state, res_state])
    # Sliced residual: the first three input columns (ones) are added.
    self.assertAllClose(values[1], values[0] + [1., 1., 1.])
    # The wrapper passes the cell's new state through unchanged.
    self.assertAllClose(values[2], values[3])