コード例 #1
0
    def test_from_config(self, units, sparse_combiner, trainable, name):
        """Round-trips a `_LinearModelLayer` through get_config/from_config.

        A layer is built over three feature columns (numeric 'a',
        vocabulary-list categorical 'b', hash-bucket categorical 'c'),
        serialized with `get_config`, and rebuilt with `from_config`. The
        rebuilt layer must report the same layer settings, and its columns
        are checked positionally against the ones it was built from.

        Args:
          units: Value passed as the layer's `units` argument.
          sparse_combiner: Value passed as the layer's `sparse_combiner`.
          trainable: Value passed as the layer's `trainable` flag.
          name: Value passed as the layer's `name`.
        """
        # NOTE(review): the arguments presumably come from a parameterized
        # test decorator that is not visible in this excerpt.
        vocabulary = ('1', '2', '3')
        numeric_a = fc.numeric_column('a')
        vocab_b = fc.categorical_column_with_vocabulary_list(
            'b', vocabulary_list=vocabulary)
        hashed_c = fc.categorical_column_with_hash_bucket(
            key='c', hash_bucket_size=3)
        source_layer = fc._LinearModelLayer(
            [numeric_a, vocab_b, hashed_c],
            units=units,
            sparse_combiner=sparse_combiner,
            trainable=trainable,
            name=name)
        serialized = source_layer.get_config()

        rebuilt = fc._LinearModelLayer.from_config(serialized)

        # Layer-level settings must match what the layer was built with.
        self.assertEqual(rebuilt.name, source_layer.name)
        self.assertEqual(rebuilt._units, units)
        self.assertEqual(rebuilt._sparse_combiner, sparse_combiner)
        self.assertEqual(rebuilt.trainable, trainable)

        # Columns are compared by position: 'a', then 'b', then 'c'.
        rebuilt_columns = rebuilt._feature_columns
        self.assertLen(rebuilt_columns, 3)
        self.assertEqual(rebuilt_columns[0].name, 'a')
        self.assertEqual(rebuilt_columns[1].vocabulary_list, vocabulary)
        # hash_bucket_size=3 on column 'c' is expected back as num_buckets.
        self.assertEqual(rebuilt_columns[2].num_buckets, 3)
コード例 #2
0
  def test_get_config(self, units, sparse_combiner, trainable, name):
    """Checks what `_LinearModelLayer.get_config` records.

    The layer is built over a numeric column 'a' and an identity categorical
    column 'b'. Its config must echo the layer name, `trainable`, `units` and
    `sparse_combiner`, and hold one serialized entry per feature column whose
    'class_name' identifies the column type.

    Args:
      units: Value passed as the layer's `units` argument.
      sparse_combiner: Value passed as the layer's `sparse_combiner`.
      trainable: Value passed as the layer's `trainable` flag.
      name: Value passed as the layer's `name`.
    """
    # NOTE(review): the arguments presumably come from a parameterized
    # test decorator that is not visible in this excerpt.
    feature_columns = [
        fc.numeric_column('a'),
        fc.categorical_column_with_identity(key='b', num_buckets=3),
    ]
    linear_layer = fc._LinearModelLayer(
        feature_columns,
        units=units,
        sparse_combiner=sparse_combiner,
        trainable=trainable,
        name=name)
    layer_config = linear_layer.get_config()

    # Scalar layer settings must match the constructor arguments.
    self.assertEqual(layer_config['name'], linear_layer.name)
    self.assertEqual(layer_config['trainable'], trainable)
    self.assertEqual(layer_config['units'], units)
    self.assertEqual(layer_config['sparse_combiner'], sparse_combiner)

    # Feature columns are checked in order, by their serialized class name.
    serialized_columns = layer_config['feature_columns']
    self.assertLen(serialized_columns, 2)
    expected_class_names = ('NumericColumn', 'IdentityCategoricalColumn')
    for index, class_name in enumerate(expected_class_names):
      self.assertEqual(serialized_columns[index]['class_name'], class_name)