def testShardedMutableHashTableVectors(self):
  """Insert and look up 2-element vector keys/values for several shard counts."""
  for shard_count in [1, 3, 10]:
    with self.test_session():
      fallback = [-0.1, 0.2]
      sentinel_key = [0, 1]
      inserted_keys = constant_op.constant(
          [[11, 12], [13, 14], [15, 16]], dtypes.int64)
      inserted_values = constant_op.constant(
          [[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], dtypes.float32)
      table = ShardedMutableDenseHashTable(
          dtypes.int64,
          dtypes.float32,
          fallback,
          sentinel_key,
          num_shards=shard_count)

      # The table starts empty and holds three entries after the insert.
      self.assertAllEqual(0, table.size().eval())
      table.insert(inserted_keys, inserted_values).run()
      self.assertAllEqual(3, table.size().eval())

      # The last query key was never inserted, so the lookup is expected to
      # yield the default value for that row.
      query_keys = constant_op.constant(
          [[11, 12], [13, 14], [11, 14]], dtypes.int64)
      looked_up = table.lookup(query_keys)
      self.assertAllEqual([3, 2], looked_up.get_shape())
      self.assertAllClose(
          [[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]], looked_up.eval())
def testExportSharded(self):
  """Check the per-shard keys and values returned by export_sharded()."""
  with self.test_session():
    empty_key = -2
    default_val = -1
    num_shards = 2
    keys = constant_op.constant([10, 11, 12], dtypes.int64)
    values = constant_op.constant([2, 3, 4], dtypes.int64)
    table = ShardedMutableDenseHashTable(
        dtypes.int64,
        dtypes.int64,
        default_val,
        empty_key,
        num_shards=num_shards)

    self.assertAllEqual(0, table.size().eval())
    table.insert(keys, values).run()
    self.assertAllEqual(3, table.size().eval())

    keys_list, values_list = table.export_sharded()
    self.assertAllEqual(num_shards, len(keys_list))
    self.assertAllEqual(num_shards, len(values_list))

    # Exported keys include empty key buckets set to the empty_key
    expected_key_sets = ({-2, 10, 12}, {-2, 11})
    for shard, expected in enumerate(expected_key_sets):
      exported = keys_list[shard].eval().flatten()
      self.assertAllEqual(expected, set(exported))

    # Exported values include empty value buckets set to 0
    expected_value_sets = ({0, 2, 4}, {0, 3})
    for shard, expected in enumerate(expected_value_sets):
      exported = values_list[shard].eval().flatten()
      self.assertAllEqual(expected, set(exported))
def __init__(self, examples, variables, options):
  """Create a new sdca optimizer.

  Validates the three inputs, stores them on the instance, creates the
  optimizer slots, builds the sharded hash table used by the optimizer and
  registers two scalar summaries ('approximate_duality_gap' and
  'examples_seen').

  Args:
    examples: Non-empty, subscriptable collection of example inputs. The
      entries 'example_labels', 'example_weights', 'example_ids',
      'sparse_features' and 'dense_features' are checked with
      `_assertSpecified`; 'sparse_features' and 'dense_features' are also
      checked with `_assertList`.
    variables: Non-empty, subscriptable collection of model variables. The
      entries 'sparse_features_weights' and 'dense_features_weights' are
      checked with `_assertSpecified` and `_assertList`.
    options: Non-empty, subscriptable collection of settings. 'loss_type'
      must be one of 'logistic_loss', 'squared_loss', 'hinge_loss',
      'smooth_hinge_loss' or 'poisson_loss'. 'symmetric_l1_regularization'
      and 'symmetric_l2_regularization' must be non-negative.

  Raises:
    ValueError: If any of the arguments is empty/None, if the loss type is
      not supported, or if a regularization value is negative. The
      `_assertSpecified` / `_assertList` helpers may raise as well.
  """
  if not examples or not variables or not options:
    raise ValueError(
        'examples, variables and options must all be specified.')

  supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss',
                      'smooth_hinge_loss', 'poisson_loss')
  if options['loss_type'] not in supported_losses:
    # Build one message string. Passing the value as a second positional
    # argument would make str(error) render as a tuple repr instead.
    raise ValueError(
        'Unsupported loss_type: %s' % (options['loss_type'],))

  self._assertSpecified([
      'example_labels', 'example_weights', 'example_ids', 'sparse_features',
      'dense_features'
  ], examples)
  self._assertList(['sparse_features', 'dense_features'], examples)

  self._assertSpecified(
      ['sparse_features_weights', 'dense_features_weights'], variables)
  self._assertList(
      ['sparse_features_weights', 'dense_features_weights'], variables)

  self._assertSpecified([
      'loss_type', 'symmetric_l2_regularization',
      'symmetric_l1_regularization'
  ], options)

  for name in ['symmetric_l1_regularization', 'symmetric_l2_regularization']:
    value = options[name]
    if value < 0.0:
      raise ValueError(
          '%s should be non-negative. Found (%f)' % (name, value))

  self._examples = examples
  self._variables = variables
  self._options = options
  self._create_slots()
  self._hashtable = ShardedMutableDenseHashTable(
      key_dtype=dtypes.int64,
      value_dtype=dtypes.float32,
      num_shards=self._num_table_shards(),
      default_value=[0.0, 0.0, 0.0, 0.0],
      # SdcaFprint never returns 0 or 1 for the low64 bits, so this is a safe
      # empty_key (that will never collide with actual payloads).
      empty_key=[0, 0],
      deleted_key=[1, 1])

  summary.scalar('approximate_duality_gap', self.approximate_duality_gap())
  summary.scalar('examples_seen', self._hashtable.size())
def testShardedMutableHashTable(self):
  """Insert and look up scalar int64 keys/values for several shard counts."""
  for shard_count in [1, 3, 10]:
    with self.test_session():
      fallback = -1
      sentinel_key = 0
      inserted_keys = constant_op.constant([11, 12, 13], dtypes.int64)
      inserted_values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = ShardedMutableDenseHashTable(
          dtypes.int64,
          dtypes.int64,
          fallback,
          sentinel_key,
          num_shards=shard_count)

      # The table starts empty and holds three entries after the insert.
      self.assertAllEqual(0, table.size().eval())
      table.insert(inserted_keys, inserted_values).run()
      self.assertAllEqual(3, table.size().eval())

      # Key 14 was never inserted, so the lookup is expected to yield the
      # default value in the last position.
      query_keys = constant_op.constant([11, 12, 14], dtypes.int64)
      looked_up = table.lookup(query_keys)
      self.assertAllEqual([3], looked_up.get_shape())
      self.assertAllEqual([0, 1, -1], looked_up.eval())