def test_integer_context_input_throws_error(self):
    """concatenate_context_input must reject a non-float32 context tensor.

    The sequence input is cast to float32, but the context input is left as
    the integer tensor built from ``np.arange``. The call is therefore
    expected to raise ``TypeError`` mentioning the required dtype.
    """
    seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))
    context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10))
    seq_input = math_ops.cast(seq_input, dtype=dtypes.float32)
    # assertRaisesRegexp is a deprecated alias that was removed in
    # Python 3.12; assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(TypeError,
                                'context_input must have dtype float32'):
        sfc.concatenate_context_input(context_input, seq_input)
def test_context_input_throws_error(self, context_input_arg):
    """A float32 context input of the wrong rank must raise ValueError.

    ``context_input_arg`` comes from the test parameterization, which is
    not visible here. Presumably it supplies array-likes whose rank is not 2.
    Both inputs are cast to float32, so only the rank check can fire.
    """
    raw_context = ops.convert_to_tensor(context_input_arg)
    raw_sequence = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))
    sequence = math_ops.cast(raw_sequence, dtype=dtypes.float32)
    context = math_ops.cast(raw_context, dtype=dtypes.float32)
    with self.assertRaisesRegex(ValueError, 'context_input must have rank 2'):
        sfc.concatenate_context_input(context, sequence)
def test_concatenate_context_input(self):
    """Context features are appended to every time step of the sequence.

    With a (2, 3, 2) sequence input and a (2, 5) context input, step ``t`` of
    batch entry ``b`` in the (2, 3, 7) result is ``seq[b, t]`` followed by
    ``ctx[b]``, as spelled out in the expected array below.
    """
    def as_float32(values):
        return math_ops.cast(ops.convert_to_tensor(values), dtype=dtypes.float32)

    seq_tensor = as_float32(np.arange(12).reshape(2, 3, 2))
    ctx_tensor = as_float32(np.arange(10).reshape(2, 5))
    combined = sfc.concatenate_context_input(ctx_tensor, seq_tensor)

    want = np.array(
        [
            [[0, 1, 0, 1, 2, 3, 4],
             [2, 3, 0, 1, 2, 3, 4],
             [4, 5, 0, 1, 2, 3, 4]],
            [[6, 7, 5, 6, 7, 8, 9],
             [8, 9, 5, 6, 7, 8, 9],
             [10, 11, 5, 6, 7, 8, 9]],
        ],
        dtype=np.float32,
    )
    got = self.evaluate(combined)
    self.assertAllEqual(want, got)
def test_sequence_example_into_input_layer(self):
    """Parsed SequenceExamples flow through feature layers into an RNN.

    100 identical serialized examples are parsed and batched by 20. They are
    turned into a dense context tensor and a sequence tensor, concatenated
    with ``sfc.concatenate_context_input``, and fed to an RNN built on
    ``SimpleRNNCell(10)``.
    """
    records = [_make_sequence_example().SerializeToString()] * 100
    ctx_cols, seq_cols = self._build_feature_columns()

    def _parse_example(serialized):
        context, sequence = parsing_ops.parse_single_sequence_example(
            serialized,
            context_features=fc.make_parse_example_spec_v2(ctx_cols),
            sequence_features=fc.make_parse_example_spec_v2(seq_cols))
        # Merge into one dict: both feature layers below read the same mapping.
        context.update(sequence)
        return context

    dataset = (dataset_ops.Dataset.from_tensor_slices(records)
               .map(_parse_example)
               .batch(20))
    # Test on a single batch.
    features = dataset_ops.make_one_shot_iterator(dataset).get_next()

    # The second output (presumably sequence lengths) is unused.
    seq_tensor, _ = ksfc.SequenceFeatures(seq_cols)(features)
    ctx_tensor = dense_features.DenseFeatures(ctx_cols)(features)
    # Tile the context features across the sequence features.
    rnn_input = sfc.concatenate_context_input(ctx_tensor, seq_tensor)
    output = recurrent.RNN(recurrent.SimpleRNNCell(10))(rnn_input)

    with self.cached_session() as sess:
        sess.run(variables.global_variables_initializer())
        features_value = sess.run(features)
        self.assertAllEqual(features_value['int_list'].dense_shape, [20, 3, 6])
        # This run may pull a further batch. All 100 records are identical,
        # so the output shape is unaffected.
        output_value = sess.run(output)
        self.assertAllEqual(output_value.shape, [20, 10])
def test_sequence_example_into_input_layer(self):
    """Parsed SequenceExamples flow through feature layers into an RNN.

    100 identical serialized examples are parsed and batched by 20. They are
    turned into a dense context tensor and a sequence tensor, concatenated
    with ``sfc.concatenate_context_input``, and fed to an RNN built on
    ``SimpleRNNCell(10)``.
    """
    examples = [_make_sequence_example().SerializeToString()] * 100
    ctx_cols, seq_cols = self._build_feature_columns()

    def _parse_example(example):
        ctx, seq = parsing_ops.parse_single_sequence_example(
            example,
            context_features=fc.make_parse_example_spec_v2(ctx_cols),
            sequence_features=fc.make_parse_example_spec_v2(seq_cols))
        # Merge context and sequence features into one dict so both
        # feature layers below can read from the same mapping.
        ctx.update(seq)
        return ctx

    ds = dataset_ops.Dataset.from_tensor_slices(examples)
    ds = ds.map(_parse_example)
    ds = ds.batch(20)

    # Test on a single batch. Use the module-level helper rather than the
    # deprecated ``Dataset.make_one_shot_iterator()`` method, which the TF2
    # ``DatasetV2`` class does not provide.
    features = dataset_ops.make_one_shot_iterator(ds).get_next()

    # Tile the context features across the sequence features.
    sequence_input_layer = sfc.SequenceFeatures(seq_cols)
    seq_layer, _ = sequence_input_layer(features)
    input_layer = fc.DenseFeatures(ctx_cols)
    ctx_layer = input_layer(features)
    input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer)

    rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10))
    output = rnn_layer(input_layer)

    with self.cached_session() as sess:
        sess.run(variables.global_variables_initializer())
        features_r = sess.run(features)
        self.assertAllEqual(features_r['int_list'].dense_shape, [20, 3, 6])
        # This run may pull a further batch. All 100 records are identical,
        # so the output shape is unaffected.
        output_r = sess.run(output)
        self.assertAllEqual(output_r.shape, [20, 10])
def call(self, tensors: Tuple[tf.Tensor, tf.Tensor], **kwargs):
    """Append a context tensor to each step of a sequence tensor.

    Args:
      tensors: ``(context, sequence)`` pair. Both are cast to float32
        before being handed to ``sfc.concatenate_context_input``.
      **kwargs: Accepted and ignored (presumably Keras call-time kwargs).

    Returns:
      Whatever ``sfc.concatenate_context_input`` returns for the float32
      inputs.
    """
    context_tensor, sequence_tensor = tensors
    float_context = tf.cast(context_tensor, tf.float32)
    float_sequence = tf.cast(sequence_tensor, tf.float32)
    return sfc.concatenate_context_input(float_context, float_sequence)