def testResidualWrapperSerialization(self):
    """Round-trip ResidualWrapper through get_config() / from_config().

    Three variants are exercised: the default residual function, a lambda
    residual function, and a locally defined named residual function. For
    the custom functions, the reconstructed wrapper's `_residual_fn` is
    called directly to confirm it still computes the expected value.
    """
    wrapper_cls = cell_wrappers.ResidualWrapper
    lstm_cell = layers.LSTMCell(10)

    # Default residual function: the config must be stable across a round trip.
    default_wrapper = wrapper_cls(lstm_cell)
    default_config = default_wrapper.get_config()
    restored = wrapper_cls.from_config(default_config)
    self.assertDictEqual(default_config, restored.get_config())
    self.assertIsInstance(restored, wrapper_cls)

    # Lambda residual function: 1 + 1 + 2 == 4 after reconstruction.
    lambda_wrapper = wrapper_cls(lstm_cell, residual_fn=lambda i, o: i + i + o)
    restored = wrapper_cls.from_config(lambda_wrapper.get_config())
    self.assertEqual(restored._residual_fn(1, 2), 4)

    def residual_fn(inputs, outputs):
        return inputs * 3 + outputs

    # Named residual function: 1 * 3 + 2 == 5 after reconstruction.
    named_wrapper = wrapper_cls(lstm_cell, residual_fn=residual_fn)
    restored = wrapper_cls.from_config(named_wrapper.get_config())
    self.assertEqual(restored._residual_fn(1, 2), 5)
def testDeviceWrapperSerialization(self):
    """Round-trip DeviceWrapper through get_config() / from_config().

    The wrapper is built around an LSTMCell with the device string
    "/cpu:0"; the reconstructed wrapper must report an equal config and be
    an instance of the same wrapper class.
    """
    wrapper_cls = cell_wrappers.DeviceWrapper
    lstm_cell = layers.LSTMCell(10)

    original = wrapper_cls(lstm_cell, "/cpu:0")
    original_config = original.get_config()
    restored = wrapper_cls.from_config(original_config)

    self.assertDictEqual(original_config, restored.get_config())
    self.assertIsInstance(restored, wrapper_cls)
def testDropoutWrapperWithKerasLSTMCell(self):
    """Check that DropoutWrapper rejects `layers.LSTMCell` and `LSTMCellV2`.

    Constructing the wrapper around either cell is expected to raise a
    ValueError whose message matches "does not work with ".
    """
    wrapper_cls = cell_wrappers.DropoutWrapper

    # First cell flavour: layers.LSTMCell.
    lstm_cell = layers.LSTMCell(10)
    with self.assertRaisesRegex(ValueError, "does not work with "):
        wrapper_cls(lstm_cell)

    # Second cell flavour: layers.LSTMCellV2.
    lstm_cell_v2 = layers.LSTMCellV2(10)
    with self.assertRaisesRegex(ValueError, "does not work with "):
        wrapper_cls(lstm_cell_v2)