コード例 #1
0
    def testResidualWrapperWithSlice(self):
        """ResidualWrapper with a custom residual_fn combining input and output.

        The 5-wide input is sliced to its first 3 columns (matching the
        3-unit GRU output) before being added; the state must pass through
        unchanged.
        """
        wrapper_type = rnn_cell_wrapper_v2.ResidualWrapper
        inputs_np = np.array([[1., 1., 1., 1., 1.]])
        state_np = np.array([[0.1, 0.1, 0.1]])
        x = ops.convert_to_tensor_v2_with_dispatch(inputs_np, dtype="float32")
        m = ops.convert_to_tensor_v2_with_dispatch(state_np, dtype="float32")
        base_cell = rnn_cell_impl.GRUCell(
            3,
            kernel_initializer=init_ops.constant_initializer(0.5),
            bias_initializer=init_ops.constant_initializer(0.5))
        g, m_new = base_cell(x, m)

        def residual_with_slice_fn(inp, out):
            # Trim the input to the output's width so the shapes line up.
            return array_ops.slice(inp, [0, 0], [-1, 3]) + out

        residual_cell = wrapper_type(base_cell, residual_with_slice_fn)
        g_res, m_new_res = residual_cell(x, m)
        self.evaluate([variables_lib.global_variables_initializer()])
        fetched = self.evaluate([g, g_res, m_new, m_new_res])
        res_g, res_g_res, res_m_new, res_m_new_res = fetched
        # Output = plain GRU output + the sliced (all-ones) input.
        self.assertAllClose(res_g_res, res_g + [1., 1., 1.])
        # The wrapper forwards the state untouched.
        self.assertAllClose(res_m_new, res_m_new_res)
コード例 #2
0
    def testDeviceWrapper(self):
        """DeviceWrapper tracks the wrapped cell and places its ops on a device.

        The wrapper must expose exactly one checkpoint dependency, named
        "cell", pointing at the wrapped GRU cell; calling it must yield
        outputs placed on the requested CPU device.
        """
        wrapper_type = rnn_cell_wrapper_v2.DeviceWrapper
        inputs = array_ops.zeros([1, 3])
        state = array_ops.zeros([1, 3])
        gru_cell = rnn_cell_impl.GRUCell(3)
        wrapped_cell = wrapper_type(gru_cell, "/cpu:0")
        # Single-element pattern: fails unless there is exactly one dependency.
        [(dep_name, dep_obj)] = wrapped_cell._checkpoint_dependencies
        wrapped_cell.get_config()  # must not raise
        self.assertIs(dep_obj, gru_cell)
        self.assertEqual("cell", dep_name)

        outputs, _ = wrapped_cell(inputs, state)
        device_str = outputs.device.lower()
        self.assertIn("cpu:0", device_str)
コード例 #3
0
    def testDeviceWrapper(self):
        """DeviceWrapper tracks the wrapped cell and places its ops on a device.

        The wrapper's trackable children must include the wrapped GRU cell
        under the key "cell", and its outputs must land on the CPU device.
        """
        wrapper_type = rnn_cell_wrapper_v2.DeviceWrapper
        inputs = array_ops.zeros([1, 3])
        state = array_ops.zeros([1, 3])
        gru_cell = rnn_cell_impl.GRUCell(3)
        wrapped_cell = wrapper_type(gru_cell, "/cpu:0")
        tracked = wrapped_cell._trackable_children()
        wrapped_cell.get_config()  # must not raise
        self.assertIn("cell", tracked)
        self.assertIs(tracked["cell"], gru_cell)

        outputs, _ = wrapped_cell(inputs, state)
        device_str = outputs.device.lower()
        self.assertIn("cpu:0", device_str)
コード例 #4
0
    def test_timeseries_classification_sequential_tf_rnn(self):
        """tf RNN cells wrapped in keras.layers.RNN learn a toy classification.

        Builds a Sequential model (LSTMCell layer returning sequences, then a
        softmax GRUCell layer) under ``keras_style_scope``, trains it, and
        checks validation accuracy, agreement between the last recorded
        ``val_acc`` and ``model.evaluate``, and the prediction shape.
        """
        np.random.seed(1337)
        # Single source of truth for the class count: used both to generate
        # the data and to check the prediction width below.
        num_classes = 2
        (x_train, y_train), _ = testing_utils.get_test_data(
            train_samples=100,
            test_samples=0,
            input_shape=(4, 10),
            num_classes=num_classes)
        y_train = np_utils.to_categorical(y_train)

        with base_layer.keras_style_scope():
            model = keras.models.Sequential()
            model.add(
                keras.layers.RNN(rnn_cell.LSTMCell(5),
                                 return_sequences=True,
                                 input_shape=x_train.shape[1:]))
            # The last cell's unit count equals the categorical label width,
            # so its softmax output serves as the class distribution.
            model.add(
                keras.layers.RNN(
                    rnn_cell.GRUCell(y_train.shape[-1],
                                     activation='softmax',
                                     dtype=dtypes.float32)))
            model.compile(loss='categorical_crossentropy',
                          optimizer=keras.optimizer_v2.adam.Adam(0.005),
                          metrics=['acc'],
                          run_eagerly=testing_utils.should_run_eagerly())

        history = model.fit(x_train,
                            y_train,
                            epochs=15,
                            batch_size=10,
                            validation_data=(x_train, y_train),
                            verbose=2)
        self.assertGreater(history.history['val_acc'][-1], 0.7)
        _, val_acc = model.evaluate(x_train, y_train)
        self.assertAlmostEqual(history.history['val_acc'][-1], val_acc)
        predictions = model.predict(x_train)
        self.assertEqual(predictions.shape, (x_train.shape[0], num_classes))
コード例 #5
0
    def testResidualWrapper(self):
        """Default ResidualWrapper adds the input to the cell output.

        Also verifies the wrapper exposes the base cell as its sole checkpoint
        dependency under the name "cell", that get_config works, and that the
        state is passed through unchanged.
        """
        wrapper_type = rnn_cell_wrapper_v2.ResidualWrapper
        x = ops.convert_to_tensor_v2_with_dispatch(
            np.array([[1., 1., 1.]]), dtype="float32")
        m = ops.convert_to_tensor_v2_with_dispatch(
            np.array([[0.1, 0.1, 0.1]]), dtype="float32")
        base_cell = rnn_cell_impl.GRUCell(
            3,
            kernel_initializer=init_ops.constant_initializer(0.5),
            bias_initializer=init_ops.constant_initializer(0.5))
        g, m_new = base_cell(x, m)
        residual_cell = wrapper_type(base_cell)
        # Single-element pattern: fails unless there is exactly one dependency.
        [(dep_name, dep_obj)] = residual_cell._checkpoint_dependencies
        residual_cell.get_config()  # must not raise
        self.assertIs(dep_obj, base_cell)
        self.assertEqual("cell", dep_name)

        g_res, m_new_res = residual_cell(x, m)
        self.evaluate([variables_lib.global_variables_initializer()])
        res_g, res_g_res, res_m_new, res_m_new_res = self.evaluate(
            [g, g_res, m_new, m_new_res])
        # Output = plain GRU output + the (all-ones) input.
        self.assertAllClose(res_g_res, res_g + [1., 1., 1.])
        # The wrapper forwards the state untouched.
        self.assertAllClose(res_m_new, res_m_new_res)