예제 #1
0
  def test_timeseries_classification_sequential_tf_rnn(self):
    """Trains a stacked legacy-cell RNN classifier on synthetic sequences.

    The model is a `Sequential` of an RNN over `rnn_cell.LSTMCell(5)` (which
    returns sequences) followed by an RNN over a softmax `rnn_cell.GRUCell`
    with one unit per class. All of it is built under
    `base_layer.keras_style_scope()`. After 15 epochs the test requires:
      * the final validation accuracy to exceed 0.7,
      * `evaluate` on the same data to agree with that accuracy, and
      * `predict` to return one row of class scores per sample.
    """
    np.random.seed(1337)
    num_classes = 2
    (x_train, y_train), _ = testing_utils.get_test_data(
        train_samples=100,
        test_samples=0,
        input_shape=(4, 10),
        num_classes=num_classes)
    y_train = np_utils.to_categorical(y_train)

    with base_layer.keras_style_scope():
      # Layers, cells and the optimizer are all created inside the scope, and
      # in the same order as they are added to the model.
      model = keras.models.Sequential()
      model.add(keras.layers.RNN(
          rnn_cell.LSTMCell(5),
          return_sequences=True,
          input_shape=x_train.shape[1:]))
      output_cell = rnn_cell.GRUCell(
          y_train.shape[-1], activation='softmax', dtype=dtypes.float32)
      model.add(keras.layers.RNN(output_cell))
      model.compile(
          loss='categorical_crossentropy',
          optimizer=keras.optimizer_v2.adam.Adam(0.005),
          metrics=['acc'],
          run_eagerly=testing_utils.should_run_eagerly())

    history = model.fit(
        x_train, y_train,
        epochs=15,
        batch_size=10,
        validation_data=(x_train, y_train),
        verbose=2)
    last_epoch_val_acc = history.history['val_acc'][-1]
    self.assertGreater(last_epoch_val_acc, 0.7)

    # Validation data equals the training data, so evaluate() on it should
    # reproduce the last epoch's validation accuracy.
    _, evaluated_acc = model.evaluate(x_train, y_train)
    self.assertAlmostEqual(last_epoch_val_acc, evaluated_acc)

    predictions = model.predict(x_train)
    self.assertEqual(predictions.shape, (x_train.shape[0], num_classes))
예제 #2
0
  def testWrapperV2Caller(self, wrapper):
    """Tests that wrapper V2 is using the LayerRNNCell's caller.

    Setup: a two-cell `MultiRNNCell` of `BasicRNNCell(1)` is built under
    `legacy_base_layer.keras_style_scope()`, wrapped with `wrapper`, and
    called once.

    Expectations for the first inner cell:
      * it owns exactly two weights;
      * every weight name contains "_wrapper". This is the test's evidence
        that the wrapper's caller was used.

    Args:
      wrapper: the cell-wrapper class under test. It is called with the base
        cell, and the result is invoked as a cell.
    """
    with legacy_base_layer.keras_style_scope():
      base_cell = rnn_cell_impl.MultiRNNCell(
          [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
    wrapped_cell = wrapper(base_cell)
    inputs = ops.convert_to_tensor_v2_with_dispatch([[1]], dtype=dtypes.float32)
    state = ops.convert_to_tensor_v2_with_dispatch([[1]], dtype=dtypes.float32)
    # One state tensor per inner cell of the two-cell MultiRNNCell.
    _ = wrapped_cell(inputs, [state, state])
    weights = base_cell._cells[0].weights
    self.assertLen(weights, expected_len=2)
    # Assert per weight so a failure names the offending variable instead of
    # the opaque "False is not true" produced by assertTrue(all(...)).
    for weight in weights:
      self.assertIn("_wrapper", weight.name)
예제 #3
0
  def testKerasStyleAddWeight(self):
    """Checks the names given to variables created via `add_variable`.

    The test expects the following behavior:
      * A Keras layer prefixes the variable name with the enclosing
        `ops.name_scope`.
      * A `base_layers.Layer` constructed inside
        `base_layers.keras_style_scope()` does the same.
      * A plain old-style `base_layers.Layer` prefixes the variable name with
        its own layer name instead.
    """

    def _add_zeros_var(target):
      # Each call builds its own zeros initializer.
      return target.add_variable(
          'my_var', [2, 2], initializer=init_ops.zeros_initializer())

    # Keras layer: the surrounding name scope is used.
    keras_layer = keras_base_layer.Layer(name='keras_layer')
    with ops.name_scope('foo', skip_on_eager=False):
      keras_variable = _add_zeros_var(keras_layer)
    self.assertEqual(keras_variable.name, 'foo/my_var:0')

    # Old-style layer: the variable is named after the layer, not 'baz'.
    with ops.name_scope('baz', skip_on_eager=False):
      old_style_layer = base_layers.Layer(name='my_layer')
      legacy_variable = _add_zeros_var(old_style_layer)
    self.assertEqual(legacy_variable.name, 'my_layer/my_var:0')

    # Old-style layer constructed under keras_style_scope: named like Keras.
    with base_layers.keras_style_scope():
      keras_style_layer = base_layers.Layer(name='my_layer')
    with ops.name_scope('bar', skip_on_eager=False):
      scoped_variable = _add_zeros_var(keras_style_layer)
    self.assertEqual(scoped_variable.name, 'bar/my_var:0')
예제 #4
0
  def testKerasStyleAddWeight(self):
    """Checks variable naming for Keras, old-style and Keras-style layers.

    The test expects the following behavior:
      * Under `backend.name_scope`, a Keras layer and a `base_layers.Layer`
        built inside `base_layers.keras_style_scope()` both prefix the new
        variable with the scope name.
      * An ordinary old-style `base_layers.Layer` uses its own layer name
        instead.
      * The Keras-style legacy layer is not instrumented as a Keras layer.
    """

    def _add_zeros_var(target):
      # Each call builds its own zeros initializer.
      return target.add_variable(
          'my_var', [2, 2], initializer=init_ops.zeros_initializer())

    # Keras layer: the surrounding name scope is used.
    keras_layer = keras_base_layer.Layer(name='keras_layer')
    with backend.name_scope('foo'):
      keras_variable = _add_zeros_var(keras_layer)
    self.assertEqual(keras_variable.name, 'foo/my_var:0')

    # Old-style layer: the variable is named after the layer, not 'baz'.
    with backend.name_scope('baz'):
      old_style_layer = base_layers.Layer(name='my_layer')
      legacy_variable = _add_zeros_var(old_style_layer)
    self.assertEqual(legacy_variable.name, 'my_layer/my_var:0')

    # Old-style layer constructed under keras_style_scope.
    with base_layers.keras_style_scope():
      keras_style_layer = base_layers.Layer(name='my_layer')
    # Building it Keras-style must not mark it as an instrumented Keras layer.
    self.assertFalse(keras_style_layer._instrumented_keras_api)
    with backend.name_scope('bar'):
      scoped_variable = _add_zeros_var(keras_style_layer)
    self.assertEqual(scoped_variable.name, 'bar/my_var:0')