示例#1
0
    def testShardedMutableHashTableVectors(self):
        """Vector keys map to vector values across 1, 3 and 10 shards.

        Inserts three 2-element int64 keys with 2-element float32 values,
        then checks that hits return the stored vector and a miss returns
        the table's default vector.
        """
        for shard_count in (1, 3, 10):
            with self.test_session():
                fallback_vector = [-0.1, 0.2]
                sentinel_key = [0, 1]
                stored_keys = constant_op.constant(
                    [[11, 12], [13, 14], [15, 16]], dtypes.int64)
                stored_values = constant_op.constant(
                    [[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], dtypes.float32)
                table = ShardedMutableDenseHashTable(
                    key_dtype=dtypes.int64,
                    value_dtype=dtypes.float32,
                    default_value=fallback_vector,
                    empty_key=sentinel_key,
                    num_shards=shard_count)
                self.assertAllEqual(0, table.size().eval())

                table.insert(stored_keys, stored_values).run()
                self.assertAllEqual(3, table.size().eval())

                # [11, 14] was never inserted, so it must yield the fallback.
                query_keys = constant_op.constant(
                    [[11, 12], [13, 14], [11, 14]], dtypes.int64)
                found = table.lookup(query_keys)
                self.assertAllEqual([3, 2], found.get_shape())
                expected = [[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]]
                self.assertAllClose(expected, found.eval())
示例#2
0
    def testExportSharded(self):
        """export_sharded returns one (keys, values) tensor pair per shard.

        The exported tensors include unused buckets. Those hold empty_key
        (for keys) and 0 (for values), so the test compares set membership
        rather than exact contents or order. The expected shard assignment
        (10 and 12 in shard 0, 11 in shard 1) is what this test pins.
        """
        with self.test_session():
            empty_key = -2
            default_val = -1
            num_shards = 2
            keys = constant_op.constant([10, 11, 12], dtypes.int64)
            values = constant_op.constant([2, 3, 4], dtypes.int64)
            table = ShardedMutableDenseHashTable(dtypes.int64,
                                                 dtypes.int64,
                                                 default_val,
                                                 empty_key,
                                                 num_shards=num_shards)
            self.assertAllEqual(0, table.size().eval())

            table.insert(keys, values).run()
            self.assertAllEqual(3, table.size().eval())

            keys_list, values_list = table.export_sharded()
            # Plain Python ints: use assertEqual, not the array assertion.
            self.assertEqual(num_shards, len(keys_list))
            self.assertEqual(num_shards, len(values_list))

            # Exported keys include empty key buckets set to the empty_key.
            # assertSetEqual compares real sets and reports the difference on
            # failure; assertAllEqual would wrap each set in a 0-d object array.
            self.assertSetEqual({-2, 10, 12},
                                set(keys_list[0].eval().flatten()))
            self.assertSetEqual({-2, 11},
                                set(keys_list[1].eval().flatten()))
            # Exported values include empty value buckets set to 0.
            self.assertSetEqual({0, 2, 4},
                                set(values_list[0].eval().flatten()))
            self.assertSetEqual({0, 3},
                                set(values_list[1].eval().flatten()))
示例#3
0
    def __init__(self, examples, variables, options):
        """Create a new sdca optimizer.

        Args:
          examples: mapping that must provide 'example_labels',
            'example_weights', 'example_ids', 'sparse_features' and
            'dense_features'. The last two are checked with `_assertList`.
          variables: mapping that must provide 'sparse_features_weights' and
            'dense_features_weights', both checked with `_assertList`.
          options: mapping that must provide 'loss_type',
            'symmetric_l2_regularization' and 'symmetric_l1_regularization'.
            'loss_type' must be one of the supported losses below, and both
            regularization strengths must be non-negative.

        Raises:
          ValueError: if any argument is empty/missing, the loss type is
            unsupported, or a regularization strength is negative.
            (`_assertSpecified` / `_assertList` presumably raise ValueError
            too; confirm against their definitions.)
        """

        if not examples or not variables or not options:
            raise ValueError(
                'examples, variables and options must all be specified.')

        self._assertSpecified([
            'example_labels', 'example_weights', 'example_ids',
            'sparse_features', 'dense_features'
        ], examples)
        self._assertList(['sparse_features', 'dense_features'], examples)

        self._assertSpecified(
            ['sparse_features_weights', 'dense_features_weights'], variables)
        self._assertList(['sparse_features_weights', 'dense_features_weights'],
                         variables)

        self._assertSpecified([
            'loss_type', 'symmetric_l2_regularization',
            'symmetric_l1_regularization'
        ], options)

        # Validated only after _assertSpecified has confirmed the key exists,
        # so a missing 'loss_type' no longer surfaces as a bare KeyError.
        supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss',
                            'smooth_hinge_loss', 'poisson_loss')
        loss_type = options['loss_type']
        if loss_type not in supported_losses:
            # Format the value into the message. Passing it as a second
            # ValueError argument would render the message as a tuple.
            raise ValueError('Unsupported loss_type: %s' % (loss_type,))

        for name in ('symmetric_l1_regularization',
                     'symmetric_l2_regularization'):
            value = options[name]
            if value < 0.0:
                raise ValueError('%s should be non-negative. Found (%f)' %
                                 (name, value))

        self._examples = examples
        self._variables = variables
        self._options = options
        self._create_slots()
        # Keys are two int64 words per example; each entry stores four
        # float32 values (defaulting to zeros). What the four slots mean is
        # not visible here; presumably per-example optimizer state, confirm.
        self._hashtable = ShardedMutableDenseHashTable(
            key_dtype=dtypes.int64,
            value_dtype=dtypes.float32,
            num_shards=self._num_table_shards(),
            default_value=[0.0, 0.0, 0.0, 0.0],
            # SdcaFprint never returns 0 or 1 for the low64 bits, so this is a
            # safe empty_key (that will never collide with actual payloads).
            empty_key=[0, 0],
            deleted_key=[1, 1])

        summary.scalar('approximate_duality_gap',
                       self.approximate_duality_gap())
        summary.scalar('examples_seen', self._hashtable.size())
示例#4
0
    def testShardedMutableHashTable(self):
        """Scalar int64 keys map to int64 values across 1, 3 and 10 shards.

        Inserts three key/value pairs, then checks that hits return the
        stored value and a miss returns the table's default value.
        """
        for shard_count in (1, 3, 10):
            with self.test_session():
                fallback = -1
                sentinel_key = 0
                stored_keys = constant_op.constant([11, 12, 13], dtypes.int64)
                stored_values = constant_op.constant([0, 1, 2], dtypes.int64)
                table = ShardedMutableDenseHashTable(
                    key_dtype=dtypes.int64,
                    value_dtype=dtypes.int64,
                    default_value=fallback,
                    empty_key=sentinel_key,
                    num_shards=shard_count)
                self.assertAllEqual(0, table.size().eval())

                table.insert(stored_keys, stored_values).run()
                self.assertAllEqual(3, table.size().eval())

                # 14 was never inserted, so its lookup must yield the fallback.
                query_keys = constant_op.constant([11, 12, 14], dtypes.int64)
                found = table.lookup(query_keys)
                self.assertAllEqual([3], found.get_shape())
                self.assertAllEqual([0, 1, -1], found.eval())