def testMutableHashTableIsLocal(self):
  """A default RunConfig keeps the hash table and its lookup op local.

  Builds a MutableHashTable and a lookup under the replica device setter
  derived from `run_config.RunConfig()` and checks that neither the table
  resource nor the lookup result is pinned to any device ('').
  """
  setter = estimator._get_replica_device_setter(run_config.RunConfig())
  with ops.device(setter):
    table = lookup_ops.MutableHashTable(
        dtypes.string, dtypes.int64,
        constant_op.constant([-1, -1], dtypes.int64))
    looked_up = table.lookup(constant_op.constant(['brain', 'salad', 'tank']))
  # Table resource first, then the lookup output (same order as before).
  for placed_on in (table._table_ref.device, looked_up.device):
    self.assertDeviceEqual('', placed_on)
def testMutableHashTableIsOnPs(self):
  """With one PS task in TF_CONFIG, the table and its lookup go to that task.

  TF_CONFIG is patched to describe a cluster with a single fake PS task.
  The test then checks that both the table resource and the lookup output
  are placed on '/job:ps/task:0'.
  """
  fake_cluster = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}}
  patched_env = {'TF_CONFIG': json.dumps(fake_cluster)}
  # The RunConfig is built while TF_CONFIG is patched (presumably it parses
  # the cluster spec at construction time -- confirm against RunConfig).
  with test.mock.patch.dict('os.environ', patched_env):
    config = run_config.RunConfig()

  ps_device = '/job:ps/task:0'
  with ops.device(estimator._get_replica_device_setter(config)):
    table = lookup_ops.MutableHashTable(
        dtypes.string, dtypes.int64,
        constant_op.constant([-1, -1], dtypes.int64))
    looked_up = table.lookup(constant_op.constant(['brain', 'salad', 'tank']))
  self.assertDeviceEqual(ps_device, table._table_ref.device)
  self.assertDeviceEqual(ps_device, looked_up.device)
def __init__(self, container, examples, variables, options):  # pylint: disable=unused-argument
  """Create a new sdca optimizer.

  Args:
    container: Obsolete and unused; kept only so existing callers that pass
      it positionally keep working.
    examples: Dict that must provide 'example_labels', 'example_weights',
      'example_ids', 'sparse_features' and 'dense_features'. The two
      *_features entries are additionally validated with `_assertList`.
    variables: Dict that must provide 'sparse_features_weights' and
      'dense_features_weights', both validated with `_assertList`.
    options: Dict that must provide 'loss_type' (one of 'logistic_loss',
      'squared_loss', 'hinge_loss') plus 'symmetric_l1_regularization' and
      'symmetric_l2_regularization', which must be non-negative.

  Raises:
    ValueError: If any argument is empty, the loss type is unsupported, or a
      regularization strength is negative. Missing or ill-typed entries are
      reported by `_assertSpecified` / `_assertList`.
  """
  # TODO(andreasst): get rid of obsolete container parameter
  if not examples or not variables or not options:
    raise ValueError('All arguments must be specified.')

  # Validate that the option keys exist *before* reading any of them, so a
  # missing 'loss_type' is reported by the validator instead of escaping as
  # a bare KeyError from the lookup below.
  self._assertSpecified([
      'loss_type', 'symmetric_l2_regularization',
      'symmetric_l1_regularization'
  ], options)

  supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss')
  loss_type = options['loss_type']
  if loss_type not in supported_losses:
    # Format the message; passing two positional args to ValueError would
    # render as a tuple, e.g. "('Unsupported loss_type: ', 'foo')".
    raise ValueError('Unsupported loss_type: %s' % (loss_type,))

  self._assertSpecified([
      'example_labels', 'example_weights', 'example_ids', 'sparse_features',
      'dense_features'
  ], examples)
  self._assertList(['sparse_features', 'dense_features'], examples)

  self._assertSpecified(
      ['sparse_features_weights', 'dense_features_weights'], variables)
  self._assertList(['sparse_features_weights', 'dense_features_weights'],
                   variables)

  for name in ['symmetric_l1_regularization', 'symmetric_l2_regularization']:
    value = options[name]
    if value < 0.0:
      raise ValueError('%s should be non-negative. Found (%f)' %
                       (name, value))

  self._examples = examples
  self._variables = variables
  self._options = options
  self._create_slots()
  # String-keyed table of 4-float values; the meaning of the four slots is
  # defined by the code that reads/writes this table (not visible here).
  self._hashtable = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32,
                                                [0.0, 0.0, 0.0, 0.0])
def __init__(self,
             key_dtype,
             value_dtype,
             default_value,
             num_shards=1,
             name='ShardedMutableHashTable'):
  """Creates a hash table backed by `num_shards` MutableHashTable shards.

  Args:
    key_dtype: Key dtype, forwarded to the base class and to every shard.
    value_dtype: Value dtype, forwarded to the base class and to every shard.
    default_value: Value for missing keys, forwarded to every shard.
    num_shards: Number of underlying MutableHashTables; must be >= 1.
    name: Name scope for the table. Shard i (1-based) is named
      '<name>-<i>-of-<num_shards>'.

  Raises:
    ValueError: If `num_shards` is less than 1. Without this check, an empty
      shard list would fail later with an opaque IndexError.
  """
  if num_shards < 1:
    raise ValueError('num_shards should be at least 1. Found (%s)' %
                     (num_shards,))
  with ops.name_scope(name, 'sharded_mutable_hash_table') as scope:
    super(_ShardedMutableHashTable, self).__init__(key_dtype, value_dtype,
                                                   scope)
    self._table_shards = [
        lookup_ops.MutableHashTable(
            key_dtype=key_dtype,
            value_dtype=value_dtype,
            default_value=default_value,
            name='%s-%d-of-%d' % (name, i + 1, num_shards))
        for i in range(num_shards)
    ]
    # All shards share the same dtypes/default, so shard 0's value shape
    # stands for the whole table.
    # TODO(andreasst): add a value_shape() method to LookupInterface
    # pylint: disable=protected-access
    self._value_shape = self._table_shards[0]._value_shape
    # pylint: enable=protected-access
def __init__(self,
             key_dtype,
             value_dtype,
             default_value,
             num_shards=1,
             name=None):
  """Creates a hash table backed by `num_shards` MutableHashTable shards.

  Args:
    key_dtype: Key dtype, forwarded to the base class and to every shard.
    value_dtype: Value dtype, forwarded to the base class and to every shard.
    default_value: Value for missing keys, forwarded to every shard.
    num_shards: Number of underlying MutableHashTables; must be >= 1.
    name: Optional name scope. Every shard is created with this same `name`
      (the graph presumably uniquifies the resulting op names -- confirm).

  Raises:
    ValueError: If `num_shards` is less than 1. Without this check, an empty
      shard list would fail later with an opaque IndexError.
  """
  if num_shards < 1:
    raise ValueError('num_shards should be at least 1. Found (%s)' %
                     (num_shards,))
  with ops.op_scope([], name, 'sharded_mutable_hash_table') as scope:
    super(_ShardedMutableHashTable, self).__init__(key_dtype, value_dtype,
                                                   scope)
    # TODO(andreasst): add placement hints once bug 30002625 is fixed.
    self._table_shards = [
        lookup_ops.MutableHashTable(
            key_dtype=key_dtype,
            value_dtype=value_dtype,
            default_value=default_value,
            name=name)
        for _ in range(num_shards)
    ]
    # All shards share the same dtypes/default, so shard 0's value shape
    # stands for the whole table.
    # TODO(andreasst): add a value_shape() method to LookupInterface
    # pylint: disable=protected-access
    self._value_shape = self._table_shards[0]._value_shape
    # pylint: enable=protected-access