Example #1
 def _test_create_net_max_q_parametric_action(self, normalize):
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         action_normalization_parameters=self.get_action_normalization_parameters(),
         include_possible_actions=True,
         normalize=normalize,
         max_num_actions=2,
     )
     expected_input_record = schema.Struct(
         ("state_features", map_schema()),
         ("next_state_features", map_schema()),
         ("action", map_schema()),
         ("next_action", map_schema()),
         ("not_terminal", schema.Scalar()),
         ("possible_actions", schema.List(map_schema())),
         ("possible_actions_mask", schema.List(schema.Scalar())),
         ("possible_next_actions", schema.List(map_schema())),
         ("possible_next_actions_mask", schema.List(schema.Scalar())),
     )
     expected_output_record = schema.Struct(
         ("state_features", schema.Scalar()),
         ("next_state_features", schema.Scalar()),
         ("action", schema.Scalar()),
         ("next_action", schema.Scalar()),
         ("not_terminal", schema.Scalar()),
         ("possible_actions", schema.Scalar()),
         ("possible_actions_mask", schema.Scalar()),
         ("possible_next_actions", schema.Scalar()),
         ("possible_next_actions_mask", schema.Scalar()),
     )
     self.check_create_net_spec(extractor, expected_input_record,
                                expected_output_record)
Example #2
 def testStructIndexing(self):
     s = schema.Struct(('field1', schema.Scalar(dtype=np.int32)),
                       ('field2', schema.List(schema.Scalar(dtype=str))))
     self.assertEqual(s['field2'], s.field2)
     self.assertEqual(s['field2'], schema.List(schema.Scalar(dtype=str)))
     self.assertEqual(
         s['field2', 'field1'],
         schema.Struct(
             ('field2', schema.List(schema.Scalar(dtype=str))),
             ('field1', schema.Scalar(dtype=np.int32)),
         ))
Example #3
 def testInitShouldSetFieldOffsets(self):
     f = schema.Field([
         schema.Scalar(dtype=np.int32),
         schema.Struct(
             ('field1', schema.Scalar(dtype=np.int32)),
             ('field2', schema.List(schema.Scalar(dtype=str))),
         ),
         schema.Scalar(dtype=np.int32),
         schema.Struct(('field3', schema.Scalar(dtype=np.int32)),
                       ('field4', schema.List(schema.Scalar(dtype=str)))),
         schema.Scalar(dtype=np.int32),
     ])
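     # Each Scalar occupies one blob slot and each List two (lengths and
     # values), so each nested Struct spans three slots: hence the offsets.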
     self.assertListEqual(f._field_offsets, [0, 1, 4, 5, 8, 9])
Example #4
    def testPreservesMetadata(self):
        s = schema.Struct(
            ('a', schema.Scalar(np.float32)),
            ('b',
             schema.Scalar(np.int32,
                           metadata=schema.Metadata(categorical_limit=5))),
            ('c',
             schema.List(
                 schema.Scalar(
                     np.int32,
                     metadata=schema.Metadata(categorical_limit=6)))))
        # attach metadata to lengths field
        s.c.lengths.set_metadata(schema.Metadata(categorical_limit=7))

        self.assertEqual(None, s.a.metadata)
        self.assertEqual(5, s.b.metadata.categorical_limit)
        self.assertEqual(6, s.c.value.metadata.categorical_limit)
        self.assertEqual(7, s.c.lengths.metadata.categorical_limit)
        sc = s.clone()
        self.assertEqual(None, sc.a.metadata)
        self.assertEqual(5, sc.b.metadata.categorical_limit)
        self.assertEqual(6, sc.c.value.metadata.categorical_limit)
        self.assertEqual(7, sc.c.lengths.metadata.categorical_limit)
        sv = schema.from_blob_list(s, [
            np.array([3.4]),
            np.array([2]),
            np.array([3]),
            np.array([1, 2, 3])
        ])
        self.assertEqual(None, sv.a.metadata)
        self.assertEqual(5, sv.b.metadata.categorical_limit)
        self.assertEqual(6, sv.c.value.metadata.categorical_limit)
        self.assertEqual(7, sv.c.lengths.metadata.categorical_limit)
Example #5
 def testPicklable(self):
     s = schema.Struct(('field1', schema.Scalar(dtype=np.int32)),
                       ('field2', schema.List(schema.Scalar(dtype=str))))
     s2 = pickle.loads(pickle.dumps(s))
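     # the unpickled copy must expose the same fields as the original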
     for r in (s, s2):
         self.assertIsInstance(r.field1, schema.Scalar)
         self.assertIsInstance(r.field2, schema.List)
         self.assertIsNone(getattr(r, 'non_existent', None))
Example #6
 def testListInStructIndexing(self):
     a = schema.List(schema.Scalar(dtype=str))
     s = schema.Struct(('field1', schema.Scalar(dtype=np.int32)),
                       ('field2', a))
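     # nested fields are addressable through colon-separated paths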
     self.assertEqual(s['field2:lengths'], a.lengths)
     self.assertEqual(s['field2:values'], a.items)
     with self.assertRaises(KeyError):
         s['field2:values:non_existent']
     with self.assertRaises(KeyError):
         s['field2:non_existent']
Example #7
 def __init__(self, model, input_record, name='bpr_loss', **kwargs):
     super(BPRLoss, self).__init__(model, name, input_record, **kwargs)
     assert schema.is_schema_subset(
         schema.Struct(
             ('pos_prediction', schema.Scalar()),
             ('neg_prediction', schema.List(np.float32)),
         ), input_record)
     self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])
     self.output_schema = schema.Scalar(
         np.float32, self.get_next_blob_reference('output'))
Example #8
 def test_create_net_max_q_discrete_action(self):
     extractor = TrainingFeatureExtractor(
         state_normalization_parameters=self.get_state_normalization_parameters(),
         max_q_learning=True,
     )
     expected_input_record = schema.Struct(
         ("state_features", map_schema()),
         ("next_state_features", map_schema()),
         ("action", schema.Scalar()),
         ("possible_next_actions", schema.List(schema.Scalar())),
     )
     expected_output_record = schema.Struct(
         ("state", schema.Scalar()),
         ("next_state", schema.Scalar()),
         ("action", schema.Scalar()),
         ("possible_next_actions", schema.List(schema.Scalar())),
     )
     self.check_create_net_spec(extractor, expected_input_record,
                                expected_output_record)
Example #9
 def testMarginRankLoss(self):
     input_record = self.new_record(
         schema.Struct(
             ('pos_prediction', schema.Scalar((np.float32, (1, )))),
             ('neg_prediction', schema.List(np.float32)),
         ))
     pos_items = np.array([0.1, 0.2, 0.3], dtype=np.float32)
     neg_lengths = np.array([1, 2, 3], dtype=np.int32)
     neg_items = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=np.float32)
     schema.FeedRecord(input_record, [pos_items, neg_lengths, neg_items])
     loss = self.model.MarginRankLoss(input_record)
     self.run_train_net_forward_only()
     self.assertEqual(schema.Scalar((np.float32, tuple())), loss)
Example #10
    def __init__(self,
                 model,
                 input_record,
                 seed=0,
                 modulo=None,
                 use_hashing=True,
                 name='sparse_feature_hash',
                 **kwargs):
        super(SparseFeatureHash, self).__init__(model, name, input_record,
                                                **kwargs)

        self.seed = seed
        self.use_hashing = use_hashing
        if schema.equal_schemas(input_record, IdList):
            self.modulo = modulo or self.extract_hash_size(
                input_record.items.metadata)
            metadata = schema.Metadata(
                categorical_limit=self.modulo,
                feature_specs=input_record.items.metadata.feature_specs,
            )
            hashed_indices = schema.Scalar(
                np.int64, self.get_next_blob_reference("hashed_idx"))
            hashed_indices.set_metadata(metadata)
            self.output_schema = schema.List(
                values=hashed_indices,
                lengths_blob=input_record.lengths,
            )
        elif schema.equal_schemas(input_record, IdScoreList):
            self.modulo = modulo or self.extract_hash_size(
                input_record.keys.metadata)
            metadata = schema.Metadata(
                categorical_limit=self.modulo,
                feature_specs=input_record.keys.metadata.feature_specs,
            )
            hashed_indices = schema.Scalar(
                np.int64, self.get_next_blob_reference("hashed_idx"))
            hashed_indices.set_metadata(metadata)
            self.output_schema = schema.Map(
                keys=hashed_indices,
                values=input_record.values,
                lengths_blob=input_record.lengths,
            )
        else:
            assert False, "Input type must be one of (IdList, IdScoreList)"

        assert self.modulo >= 1, 'Unexpected modulo: {}'.format(self.modulo)

        # operators in this layer do not have CUDA implementation yet.
        # In addition, since the sparse feature keys that we are hashing are
        # typically on CPU originally, it makes sense to have this layer on CPU.
        self.tags.update([Tags.CPU_ONLY])
Example #11
    def __init__(self,
                 model,
                 input_record,
                 seed,
                 name='sparse_feature_hash',
                 **kwargs):
        super(SparseFeatureHash, self).__init__(model, name, input_record,
                                                **kwargs)

        self.seed = seed
        self.lengths_blob = schema.Scalar(
            np.int32,
            model.net.NextScopedBlob(name + "_lengths"),
        )

        if schema.equal_schemas(input_record, IdList):
            self.modulo = self.extract_hash_size(input_record.items.metadata)
            metadata = schema.Metadata(
                categorical_limit=self.modulo,
                feature_specs=input_record.items.metadata.feature_specs,
            )
            hashed_indices = schema.Scalar(
                np.int64, model.net.NextScopedBlob(name + "_hashed_idx"))
            hashed_indices.set_metadata(metadata)
            self.output_schema = schema.List(
                values=hashed_indices,
                lengths_blob=self.lengths_blob,
            )
        elif schema.equal_schemas(input_record, IdScoreList):
            self.values_blob = schema.Scalar(
                np.float32,
                model.net.NextScopedBlob(name + "_values"),
            )
            self.modulo = self.extract_hash_size(input_record.keys.metadata)
            metadata = schema.Metadata(
                categorical_limit=self.modulo,
                feature_specs=input_record.keys.metadata.feature_specs,
            )
            hashed_indices = schema.Scalar(
                np.int64, model.net.NextScopedBlob(name + "_hashed_idx"))
            hashed_indices.set_metadata(metadata)
            self.output_schema = schema.Map(
                keys=hashed_indices,
                values=self.values_blob,
                lengths_blob=self.lengths_blob,
            )
        else:
            assert False, "Input type must be one of (IdList, IdScoreList)"
Example #12
 def testFromColumnList(self):
     st = schema.Struct(('a', schema.Scalar()),
                        ('b', schema.List(schema.Scalar())),
                        ('c', schema.Map(schema.Scalar(), schema.Scalar())))
     columns = st.field_names()
     # test that recovery works for arbitrary order
     for _ in range(10):
         some_blobs = [core.BlobReference('blob:' + x) for x in columns]
         rec = schema.from_column_list(columns, col_blobs=some_blobs)
         self.assertTrue(rec.has_blobs())
         self.assertEqual(sorted(st.field_names()),
                          sorted(rec.field_names()))
         self.assertEqual(
             [str(blob) for blob in rec.field_blobs()],
             [str('blob:' + name) for name in rec.field_names()])
         random.shuffle(columns)
Example #13
 def __init__(self, model, input_record, name='margin_rank_loss',
              margin=0.1, **kwargs):
     super(MarginRankLoss, self).__init__(model, name, input_record, **kwargs)
     assert margin >= 0, 'For hinge loss, margin must be non-negative'
     self._margin = margin
     assert schema.is_schema_subset(
         schema.Struct(
             ('pos_prediction', schema.Scalar()),
             ('neg_prediction', schema.List(np.float32)),
         ),
         input_record
     )
     self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])
     self.output_schema = schema.Scalar(
         np.float32,
         self.get_next_blob_reference('output'))
Example #14
    def __init__(self,
                 model,
                 input_record,
                 seed=0,
                 modulo=None,
                 use_hashing=True,
                 name='sparse_feature_hash',
                 **kwargs):
        super(SparseFeatureHash, self).__init__(model, name, input_record,
                                                **kwargs)

        self.seed = seed
        self.use_hashing = use_hashing
        if schema.equal_schemas(input_record, IdList):
            self.modulo = modulo or self.extract_hash_size(
                input_record.items.metadata)
            metadata = schema.Metadata(
                categorical_limit=self.modulo,
                feature_specs=input_record.items.metadata.feature_specs,
            )
            hashed_indices = schema.Scalar(
                np.int64, self.get_next_blob_reference("hashed_idx"))
            hashed_indices.set_metadata(metadata)
            self.output_schema = schema.List(
                values=hashed_indices,
                lengths_blob=input_record.lengths,
            )
        elif schema.equal_schemas(input_record, IdScoreList):
            self.modulo = modulo or self.extract_hash_size(
                input_record.keys.metadata)
            metadata = schema.Metadata(
                categorical_limit=self.modulo,
                feature_specs=input_record.keys.metadata.feature_specs,
            )
            hashed_indices = schema.Scalar(
                np.int64, self.get_next_blob_reference("hashed_idx"))
            hashed_indices.set_metadata(metadata)
            self.output_schema = schema.Map(
                keys=hashed_indices,
                values=input_record.values,
                lengths_blob=input_record.lengths,
            )
        else:
            assert False, "Input type must be one of (IdList, IdScoreList)"

        assert self.modulo >= 1, 'Unexpected modulo: {}'.format(self.modulo)
Example #15
    def __init__(self, model, input_record, name='merged'):
        super(MergeIdLists, self).__init__(model, name, input_record)
        assert all(schema.equal_schemas(x, IdList) for x in input_record), \
            "Inputs to MergeIdLists should all be IdLists."

        assert all(record.items.metadata is not None
                   for record in self.input_record), \
            "Features without metadata are not supported"

        merge_dim = max(
            get_categorical_limit(record) for record in self.input_record)
        assert merge_dim is not None, "Unbounded features are not supported"

        self.output_schema = schema.NewRecord(
            model.net,
            schema.List(
                schema.Scalar(
                    np.int64,
                    blob=model.net.NextBlob(name),
                    metadata=schema.Metadata(categorical_limit=merge_dim))))
Example #16
    def testMergeIdListsLayer(self, num_inputs, batch_size):
        inputs = []
        for _ in range(num_inputs):
            lengths = np.random.randint(5, size=batch_size).astype(np.int32)
            size = lengths.sum()
            values = np.random.randint(1, 10, size=size).astype(np.int64)
            inputs.append(lengths)
            inputs.append(values)
        input_schema = schema.Tuple(*[
            schema.List(
                schema.Scalar(dtype=np.int64,
                              metadata=schema.Metadata(categorical_limit=20)))
            for _ in range(num_inputs)
        ])

        input_record = schema.NewRecord(self.model.net, input_schema)
        schema.FeedRecord(input_record, inputs)
        output_schema = self.model.MergeIdLists(input_record)
        assert schema.equal_schemas(output_schema,
                                    IdList,
                                    check_field_names=False)
Example #17
    def testSparseLookup(self):
        record = schema.NewRecord(self.model.net, schema.Struct(
            ('sparse', schema.Struct(
                ('sparse_feature_0', schema.List(
                    schema.Scalar(np.int64,
                                  metadata=schema.Metadata(categorical_limit=1000)))),
            )),
        ))
        embedding_dim = 64
        embedding_after_pooling = self.model.SparseLookup(
            record.sparse.sparse_feature_0, [embedding_dim], 'Sum')
        self.model.output_schema = embedding_after_pooling
        self.assertEqual(
            schema.Scalar((np.float32, (embedding_dim, ))),
            embedding_after_pooling
        )

        train_init_net, train_net = self.get_training_nets()

        init_ops = self.assertNetContainOps(
            train_init_net,
            [
                OpSpec("UniformFill", None, None),
                OpSpec("ConstantFill", None, None),
            ]
        )
        sparse_lookup_op_spec = OpSpec(
            'SparseLengthsSum',
            [
                init_ops[0].output[0],
                record.sparse.sparse_feature_0.items(),
                record.sparse.sparse_feature_0.lengths(),
            ],
            [embedding_after_pooling()]
        )
        self.assertNetContainOps(train_net, [sparse_lookup_op_spec])

        predict_net = self.get_predict_net()
        self.assertNetContainOps(predict_net, [sparse_lookup_op_spec])
Example #18
    def testGatherRecord(self):
        indices = np.array([1, 3, 4], dtype=np.int32)
        dense = np.array(range(20), dtype=np.float32).reshape(10, 2)
        lengths = np.array(range(10), dtype=np.int32)
        items = np.array(range(lengths.sum()), dtype=np.int64)
        items_lengths = np.array(range(lengths.sum()), dtype=np.int32)
        items_items = np.array(range(items_lengths.sum()), dtype=np.int64)
        record = self.new_record(
            schema.Struct(
                ('dense', schema.Scalar(np.float32)),
                ('sparse',
                 schema.Struct(
                     ('list', schema.List(np.int64)),
                     ('list_of_list', schema.List(schema.List(np.int64))),
                 )), ('empty_struct', schema.Struct())))
        indices_record = self.new_record(schema.Scalar(np.int32))
        input_record = schema.Struct(
            ('indices', indices_record),
            ('record', record),
        )
        schema.FeedRecord(input_record, [
            indices, dense, lengths, items, lengths, items_lengths, items_items
        ])
        gathered_record = self.model.GatherRecord(input_record)
        self.assertTrue(schema.equal_schemas(gathered_record, record))

        self.run_train_net_forward_only()
        gathered_dense = workspace.FetchBlob(gathered_record.dense())
        np.testing.assert_array_equal(
            np.concatenate([dense[i:i + 1] for i in indices]), gathered_dense)
        gathered_lengths = workspace.FetchBlob(
            gathered_record.sparse.list.lengths())
        np.testing.assert_array_equal(
            np.concatenate([lengths[i:i + 1] for i in indices]),
            gathered_lengths)
        gathered_items = workspace.FetchBlob(
            gathered_record.sparse.list.items())
        offsets = lengths.cumsum() - lengths
        np.testing.assert_array_equal(
            np.concatenate(
                [items[offsets[i]:offsets[i] + lengths[i]] for i in indices]),
            gathered_items)

        gathered_items_lengths = workspace.FetchBlob(
            gathered_record.sparse.list_of_list.items.lengths())
        np.testing.assert_array_equal(
            np.concatenate([
                items_lengths[offsets[i]:offsets[i] + lengths[i]]
                for i in indices
            ]), gathered_items_lengths)

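        # For each outer row, accumulate the offset and total length of its
        # nested items within items_items so the list-of-list gather can be
        # checked against a reference computed by hand.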
        nested_offsets = []
        nested_lengths = []
        nested_offset = 0
        j = 0
        for l in lengths:
            nested_offsets.append(nested_offset)
            nested_length = 0
            for _i in range(l):
                nested_offset += items_lengths[j]
                nested_length += items_lengths[j]
                j += 1
            nested_lengths.append(nested_length)

        gathered_items_items = workspace.FetchBlob(
            gathered_record.sparse.list_of_list.items.items())
        np.testing.assert_array_equal(
            np.concatenate([
                items_items[nested_offsets[i]:nested_offsets[i] +
                            nested_lengths[i]] for i in indices
            ]), gathered_items_items)
Example #19
## @package layers
# Module caffe2.python.layers.layers
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema, scope
from caffe2.python.layers.tags import TagContext

from collections import namedtuple
import numpy as np

# Some types to simplify descriptions of things traveling between ops
IdList = schema.List(np.int64)
IdScoreList = schema.Map(np.int64, np.float32)


def get_categorical_limit(record):
    if schema.equal_schemas(record, IdList):
        key = 'items'
    elif schema.equal_schemas(record, IdScoreList, check_field_types=False):
        key = 'keys'
    else:
        raise NotImplementedError()
    assert record[key].metadata is not None, (
        "Blob {} doesn't have metadata".format(str(record[key]())))
    return record[key].metadata.categorical_limit


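For reference, a minimal sketch of what the two record types above expand to,
assuming only that the caffe2 Python package is importable; the field names
follow from the List indexing examples earlier in this section:

    import numpy as np
    from caffe2.python import schema

    # An IdList is a schema.List: a `lengths` scalar plus a `values` scalar
    # holding int64 ids (also reachable through the `.items` alias).
    id_list = schema.List(np.int64)
    print(id_list.field_names())  # ['lengths', 'values']

    # An IdScoreList is a schema.Map from int64 keys to float32 scores; the
    # layers above reach its parts through `.keys` and `.values`.
    id_score_list = schema.Map(np.int64, np.float32)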
Example #20
    def create_net(self):
        net = core.Net("feature_extractor")
        init_net = core.Net("feature_extractor_init")
        missing_scalar = self.create_const(init_net, "MISSING_SCALAR",
                                           MISSING_VALUE)

        action_schema = (
            map_schema() if self.sorted_action_features else schema.Scalar()
        )

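        # In max-Q learning the next action is a variable-length list of
        # candidate actions rather than a single action.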
        if self.max_q_learning:
            next_action_field = InputColumn.POSSIBLE_NEXT_ACTIONS
            next_action_schema = schema.List(action_schema)
        else:
            next_action_field = InputColumn.NEXT_ACTION
            next_action_schema = action_schema

        input_schema = schema.Struct(
            (InputColumn.STATE_FEATURES, map_schema()),
            (InputColumn.NEXT_STATE_FEATURES, map_schema()),
            (InputColumn.ACTION, action_schema),
            (next_action_field, next_action_schema),
        )

        input_record = net.set_input_record(input_schema)

        state = self.extract_float_features(
            net,
            "state",
            input_record[InputColumn.STATE_FEATURES],
            self.sorted_state_features,
            missing_scalar,
        )
        next_state = self.extract_float_features(
            net,
            "next_state",
            input_record[InputColumn.NEXT_STATE_FEATURES],
            self.sorted_state_features,
            missing_scalar,
        )

        action = input_record.action
        next_action = input_record[next_action_field]
        if self.max_q_learning:
            next_action = next_action["values"]
        if self.sorted_action_features:
            action = self.extract_float_features(net, "action", action,
                                                 self.sorted_action_features,
                                                 missing_scalar)
            next_action = self.extract_float_features(
                net,
                next_action_field,
                next_action,
                self.sorted_action_features,
                missing_scalar,
            )

        if self.max_q_learning:
            next_action_output = schema.List(
                next_action,
                lengths_blob=input_record.possible_next_actions.lengths,
            )
        else:
            next_action_output = next_action

        net.set_output_record(
            schema.Struct(
                ("state", state),
                ("action", action),
                ("next_state", next_state),
                (next_action_field, next_action_output),
            ))

        return FeatureExtractorNet(net, init_net)
Example #21
    def create_net(self):
        net = core.Net("feature_extractor")
        init_net = core.Net("feature_extractor_init")
        missing_scalar = self.create_const(init_net, "MISSING_SCALAR",
                                           MISSING_VALUE)

        action_schema = (
            map_schema() if self.sorted_action_features else schema.Scalar()
        )

        input_schema = schema.Struct(
            (InputColumn.STATE_FEATURES, map_schema()),
            (InputColumn.NEXT_STATE_FEATURES, map_schema()),
            (InputColumn.ACTION, action_schema),
            (InputColumn.NEXT_ACTION, action_schema),
            (InputColumn.NOT_TERMINAL, schema.Scalar()),
            (InputColumn.TIME_DIFF, schema.Scalar()),
        )
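        # When possible actions are included, their masks are always part of
        # the input; the dense feature lists are only needed for parametric
        # actions (i.e. when sorted_action_features is set).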
        if self.include_possible_actions:
            input_schema += schema.Struct(
                (InputColumn.POSSIBLE_ACTIONS_MASK, schema.List(
                    schema.Scalar())),
                (InputColumn.POSSIBLE_NEXT_ACTIONS_MASK,
                 schema.List(schema.Scalar())),
            )
            if self.sorted_action_features is not None:
                input_schema += schema.Struct(
                    (InputColumn.POSSIBLE_ACTIONS, schema.List(map_schema())),
                    (InputColumn.POSSIBLE_NEXT_ACTIONS,
                     schema.List(map_schema())),
                )

        input_record = net.set_input_record(input_schema)

        state = self.extract_float_features(
            net,
            "state",
            input_record[InputColumn.STATE_FEATURES],
            self.sorted_state_features,
            missing_scalar,
        )
        next_state = self.extract_float_features(
            net,
            "next_state",
            input_record[InputColumn.NEXT_STATE_FEATURES],
            self.sorted_state_features,
            missing_scalar,
        )

        if self.sorted_action_features:
            action = self.extract_float_features(
                net,
                InputColumn.ACTION,
                input_record[InputColumn.ACTION],
                self.sorted_action_features,
                missing_scalar,
            )
            next_action = self.extract_float_features(
                net,
                InputColumn.NEXT_ACTION,
                input_record[InputColumn.NEXT_ACTION],
                self.sorted_action_features,
                missing_scalar,
            )
            if self.include_possible_actions:
                possible_action_features = self.extract_float_features(
                    net,
                    InputColumn.POSSIBLE_ACTIONS,
                    input_record[InputColumn.POSSIBLE_ACTIONS]["values"],
                    self.sorted_action_features,
                    missing_scalar,
                )
                possible_next_action_features = self.extract_float_features(
                    net,
                    InputColumn.POSSIBLE_NEXT_ACTIONS,
                    input_record[InputColumn.POSSIBLE_NEXT_ACTIONS]["values"],
                    self.sorted_action_features,
                    missing_scalar,
                )
        else:
            action = input_record[InputColumn.ACTION]
            next_action = input_record[InputColumn.NEXT_ACTION]

        if self.normalize:
            C2.set_net_and_init_net(net, init_net)
            state, _ = PreprocessorNet().normalize_dense_matrix(
                state,
                self.sorted_state_features,
                self.state_normalization_parameters,
                blobname_prefix="state",
                split_expensive_feature_groups=True,
            )
            next_state, _ = PreprocessorNet().normalize_dense_matrix(
                next_state,
                self.sorted_state_features,
                self.state_normalization_parameters,
                blobname_prefix="next_state",
                split_expensive_feature_groups=True,
            )
            if self.sorted_action_features is not None:
                action, _ = PreprocessorNet().normalize_dense_matrix(
                    action,
                    self.sorted_action_features,
                    self.action_normalization_parameters,
                    blobname_prefix="action",
                    split_expensive_feature_groups=True,
                )
                next_action, _ = PreprocessorNet().normalize_dense_matrix(
                    next_action,
                    self.sorted_action_features,
                    self.action_normalization_parameters,
                    blobname_prefix="next_action",
                    split_expensive_feature_groups=True,
                )
                if self.include_possible_actions:
                    possible_action_features, _ = PreprocessorNet().normalize_dense_matrix(
                        possible_action_features,
                        self.sorted_action_features,
                        self.action_normalization_parameters,
                        blobname_prefix="possible_action",
                        split_expensive_feature_groups=True,
                    )
                    possible_next_action_features, _ = PreprocessorNet().normalize_dense_matrix(
                        possible_next_action_features,
                        self.sorted_action_features,
                        self.action_normalization_parameters,
                        blobname_prefix="possible_next_action",
                        split_expensive_feature_groups=True,
                    )
            C2.set_net_and_init_net(None, None)

        output_schema = schema.Struct(
            (InputColumn.STATE_FEATURES, state),
            (InputColumn.NEXT_STATE_FEATURES, next_state),
            (InputColumn.ACTION, action),
            (InputColumn.NEXT_ACTION, next_action),
            (InputColumn.NOT_TERMINAL, input_record[InputColumn.NOT_TERMINAL]),
            (InputColumn.TIME_DIFF, input_record[InputColumn.TIME_DIFF]),
        )

        if self.include_possible_actions:
            # Drop the "lengths" blob from possible_actions_mask since we know
            # it's just a list of [max_num_actions, max_num_actions, ...]
            output_schema += schema.Struct(
                (
                    InputColumn.POSSIBLE_ACTIONS_MASK,
                    input_record[InputColumn.POSSIBLE_ACTIONS_MASK]["values"],
                ),
                (
                    InputColumn.POSSIBLE_NEXT_ACTIONS_MASK,
                    input_record[InputColumn.POSSIBLE_NEXT_ACTIONS_MASK]
                    ["values"],
                ),
            )
            if self.sorted_action_features is not None:
                output_schema += schema.Struct(
                    (InputColumn.POSSIBLE_ACTIONS, possible_action_features),
                    (InputColumn.POSSIBLE_NEXT_ACTIONS,
                     possible_next_action_features),
                )

        net.set_output_record(output_schema)
        return FeatureExtractorNet(net, init_net)