Exemplo n.º 1
0
    def testResidualWrapperWithSlice(self):
        """ResidualWrapper with a custom residual fn that slices the input.

        The input has 5 features while the GRU cell has 3 units, so the
        residual function keeps only the first 3 input columns before adding
        them to the cell output. The returned state must match the unwrapped
        cell's state.
        """
        inputs = tf.convert_to_tensor(
            np.array([[1.0, 1.0, 1.0, 1.0, 1.0]]), dtype="float32"
        )
        state = tf.convert_to_tensor(
            np.array([[0.1, 0.1, 0.1]]), dtype="float32"
        )
        gru = legacy_cells.GRUCell(
            3,
            kernel_initializer=tf.compat.v1.constant_initializer(0.5),
            bias_initializer=tf.compat.v1.constant_initializer(0.5),
        )
        plain_out, plain_state = gru(inputs, state)

        def add_sliced_input(inp, out):
            # All rows, first 3 columns: matches the cell's 3-unit output.
            return tf.slice(inp, [0, 0], [-1, 3]) + out

        residual_cell = cell_wrappers.ResidualWrapper(gru, add_sliced_input)
        wrapped_out, wrapped_state = residual_cell(inputs, state)
        self.evaluate([tf.compat.v1.global_variables_initializer()])
        (
            plain_out_val,
            wrapped_out_val,
            plain_state_val,
            wrapped_state_val,
        ) = self.evaluate([plain_out, wrapped_out, plain_state, wrapped_state])
        # Output = cell output + the first three (all-ones) input columns.
        self.assertAllClose(wrapped_out_val, plain_out_val + [1.0, 1.0, 1.0])
        # The wrapper must not touch the recurrent state.
        self.assertAllClose(plain_state_val, wrapped_state_val)
Exemplo n.º 2
0
    def testDeviceWrapper(self):
        """DeviceWrapper runs the wrapped cell on the requested device.

        Also checks that the wrapped cell is the wrapper's only trackable
        child and that ``get_config`` does not raise.
        """
        inputs, state = tf.zeros([1, 3]), tf.zeros([1, 3])
        gru = legacy_cells.GRUCell(3)
        device_cell = cell_wrappers.DeviceWrapper(gru, "/cpu:0")
        self.assertDictEqual({"cell": gru}, device_cell._trackable_children())
        device_cell.get_config()  # Must not raise.

        cell_output = device_cell(inputs, state)[0]
        # Device strings may differ in case (e.g. "CPU:0"); compare lowercased.
        self.assertIn("cpu:0", cell_output.device.lower())
Exemplo n.º 3
0
    def test_timeseries_classification_sequential_tf_rnn(self):
        """Train a stacked legacy-cell RNN classifier end to end.

        Builds a Sequential model of ``RNN(LSTMCell)`` -> ``RNN(GRUCell)``
        and checks three things:

        - training on the synthetic task exceeds 0.7 validation accuracy;
        - ``evaluate`` reproduces the last validation accuracy from ``fit``;
        - ``predict`` returns one row per sample and one column per class.
        """
        # Single source of truth for the class count; used both to generate
        # the data and to check the prediction shape.
        num_classes = 2
        np.random.seed(1337)
        (x_train, y_train), _ = test_utils.get_test_data(
            train_samples=100,
            test_samples=0,
            input_shape=(4, 10),
            num_classes=num_classes,
        )
        y_train = utils.to_categorical(y_train)

        # Model construction and compile happen inside keras_style_scope.
        # fit/evaluate/predict run outside it, as in the original test.
        with base_layer.keras_style_scope():
            model = keras.models.Sequential()
            model.add(
                keras.layers.RNN(
                    legacy_cells.LSTMCell(5),
                    return_sequences=True,
                    input_shape=x_train.shape[1:],
                )
            )
            model.add(
                keras.layers.RNN(
                    # One unit per one-hot column; softmax makes the final
                    # state usable directly as class probabilities.
                    legacy_cells.GRUCell(
                        y_train.shape[-1],
                        activation="softmax",
                        dtype=tf.float32,
                    )
                )
            )
            model.compile(
                loss="categorical_crossentropy",
                optimizer=keras.optimizers.optimizer_v2.adam.Adam(0.005),
                metrics=["acc"],
                run_eagerly=test_utils.should_run_eagerly(),
            )

        history = model.fit(
            x_train,
            y_train,
            epochs=15,
            batch_size=10,
            validation_data=(x_train, y_train),
            verbose=2,
        )
        self.assertGreater(history.history["val_acc"][-1], 0.7)
        # The final weights are unchanged after fit, so evaluating on the
        # same validation data should reproduce the last val_acc.
        _, val_acc = model.evaluate(x_train, y_train)
        self.assertAlmostEqual(history.history["val_acc"][-1], val_acc)
        predictions = model.predict(x_train)
        self.assertEqual(predictions.shape, (x_train.shape[0], num_classes))
Exemplo n.º 4
0
    def testResidualWrapper(self):
        """ResidualWrapper with the default residual fn.

        Input and cell output both have 3 features. The test checks that the
        wrapped output equals the cell output plus the (all-ones) input, and
        that the state matches the unwrapped cell's state. It also checks
        trackable children and that ``get_config`` does not raise.
        """
        inputs = tf.convert_to_tensor(
            np.array([[1.0, 1.0, 1.0]]), dtype="float32"
        )
        state = tf.convert_to_tensor(
            np.array([[0.1, 0.1, 0.1]]), dtype="float32"
        )
        gru = legacy_cells.GRUCell(
            3,
            kernel_initializer=tf.compat.v1.constant_initializer(0.5),
            bias_initializer=tf.compat.v1.constant_initializer(0.5),
        )
        plain_out, plain_state = gru(inputs, state)
        residual_cell = cell_wrappers.ResidualWrapper(gru)
        self.assertDictEqual(
            {"cell": gru}, residual_cell._trackable_children()
        )
        residual_cell.get_config()  # Must not raise.

        wrapped_out, wrapped_state = residual_cell(inputs, state)
        self.evaluate([tf.compat.v1.global_variables_initializer()])
        (
            plain_out_val,
            wrapped_out_val,
            plain_state_val,
            wrapped_state_val,
        ) = self.evaluate([plain_out, wrapped_out, plain_state, wrapped_state])
        # Output = cell output + input (all ones).
        self.assertAllClose(wrapped_out_val, plain_out_val + [1.0, 1.0, 1.0])
        # The wrapper must not touch the recurrent state.
        self.assertAllClose(plain_state_val, wrapped_state_val)