def testMutableHashTableIsLocal(self):
   """A table built under the default RunConfig's device setter has no device."""
   # Build the replica device setter from a default RunConfig up front.
   replica_setter = estimator._get_replica_device_setter(run_config.RunConfig())
   with ops.device(replica_setter):
     missing_key_value = constant_op.constant([-1, -1], dtypes.int64)
     hash_table = lookup_ops.MutableHashTable(
         dtypes.string, dtypes.int64, missing_key_value)
     query = constant_op.constant(['brain', 'salad', 'tank'])
     result = hash_table.lookup(query)
   # Both the table resource and the lookup op must be left unplaced ('').
   self.assertDeviceEqual('', hash_table._table_ref.device)
   self.assertDeviceEqual('', result.device)
  def testMutableHashTableIsOnPs(self):
    """With one PS task in TF_CONFIG, table ops land on /job:ps/task:0."""
    cluster_spec = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}}
    patched_env = {'TF_CONFIG': json.dumps(cluster_spec)}
    # The RunConfig is constructed while TF_CONFIG is patched; presumably it
    # reads the cluster description at construction time.
    with test.mock.patch.dict('os.environ', patched_env):
      ps_config = run_config.RunConfig()

    expected_device = '/job:ps/task:0'
    with ops.device(estimator._get_replica_device_setter(ps_config)):
      missing_key_value = constant_op.constant([-1, -1], dtypes.int64)
      hash_table = lookup_ops.MutableHashTable(
          dtypes.string, dtypes.int64, missing_key_value)
      query = constant_op.constant(['brain', 'salad', 'tank'])
      result = hash_table.lookup(query)
    self.assertDeviceEqual(expected_device, hash_table._table_ref.device)
    self.assertDeviceEqual(expected_device, result.device)
# Example #3 (score: 0)
    def __init__(self, container, examples, variables, options):  # pylint: disable=unused-argument
        """Create a new sdca optimizer.

        Args:
          container: Unused (obsolete); kept so the signature stays stable.
          examples: Collection of example inputs; presumably a dict. Must
            provide 'example_labels', 'example_weights', 'example_ids',
            'sparse_features' and 'dense_features'. The two feature entries
            are additionally checked with `_assertList`.
          variables: Collection of model variables; presumably a dict. Must
            provide 'sparse_features_weights' and 'dense_features_weights',
            both checked with `_assertList`.
          options: Dict-like options. 'loss_type' must be one of
            'logistic_loss', 'squared_loss' or 'hinge_loss'.
            'symmetric_l1_regularization' and 'symmetric_l2_regularization'
            must be non-negative.

        Raises:
          ValueError: If any of `examples`, `variables`, `options` is empty or
            falsy, if the loss type is unsupported, or if a regularization
            strength is negative. The `_assertSpecified` / `_assertList`
            helpers may raise for missing or ill-typed entries.
        """
        # TODO(andreasst): get rid of obsolete container parameter

        if not examples or not variables or not options:
            raise ValueError('All arguments must be specified.')

        supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss')
        loss_type = options['loss_type']
        if loss_type not in supported_losses:
            # Format the value into the message. Passing it as a second
            # ValueError argument would make str(error) render a tuple.
            raise ValueError('Unsupported loss_type: %s' % (loss_type,))

        self._assertSpecified([
            'example_labels', 'example_weights', 'example_ids',
            'sparse_features', 'dense_features'
        ], examples)
        self._assertList(['sparse_features', 'dense_features'], examples)

        self._assertSpecified(
            ['sparse_features_weights', 'dense_features_weights'], variables)
        self._assertList(['sparse_features_weights', 'dense_features_weights'],
                         variables)

        self._assertSpecified([
            'loss_type', 'symmetric_l2_regularization',
            'symmetric_l1_regularization'
        ], options)

        for name in ('symmetric_l1_regularization',
                     'symmetric_l2_regularization'):
            value = options[name]
            if value < 0.0:
                raise ValueError('%s should be non-negative. Found (%f)' %
                                 (name, value))

        self._examples = examples
        self._variables = variables
        self._options = options
        self._create_slots()
        # String-keyed table of float32 values. Keys not yet inserted look up
        # as four zeros. NOTE(review): the meaning of the four slots is not
        # visible here; presumably per-example solver state, so confirm.
        self._hashtable = lookup_ops.MutableHashTable(dtypes.string,
                                                      dtypes.float32,
                                                      [0.0, 0.0, 0.0, 0.0])
# Example #4 (score: 0)
 def __init__(self,
              key_dtype,
              value_dtype,
              default_value,
              num_shards=1,
              name='ShardedMutableHashTable'):
   """Builds `num_shards` MutableHashTable shards under one name scope.

   Args:
     key_dtype: Key dtype, forwarded to the base class and to every shard.
     value_dtype: Value dtype, forwarded to the base class and to every shard.
     default_value: Passed as `default_value` to every shard.
     num_shards: Number of shard tables to create; must be at least 1.
     name: Name scope for this table. Shard `i` (1-based) is named
       '<name>-<i>-of-<num_shards>'.

   Raises:
     ValueError: If `num_shards` is less than 1.
   """
   # Fail fast. With no shards, `self._table_shards[0]` below would raise an
   # opaque IndexError after the base class had already been initialized.
   if num_shards < 1:
     raise ValueError('num_shards must be at least 1, got %s' % (num_shards,))
   with ops.name_scope(name, 'sharded_mutable_hash_table') as scope:
     super(_ShardedMutableHashTable, self).__init__(key_dtype, value_dtype,
                                                    scope)
     self._table_shards = [
         lookup_ops.MutableHashTable(
             key_dtype=key_dtype,
             value_dtype=value_dtype,
             default_value=default_value,
             name='%s-%d-of-%d' % (name, i + 1, num_shards))
         for i in range(num_shards)
     ]
     # TODO(andreasst): add a value_shape() method to LookupInterface
     # pylint: disable=protected-access
     self._value_shape = self._table_shards[0]._value_shape
# Example #5 (score: 0)
 def __init__(self,
              key_dtype,
              value_dtype,
              default_value,
              num_shards=1,
              name=None):
   """Builds `num_shards` MutableHashTable shards under one op scope.

   Args:
     key_dtype: Key dtype, forwarded to the base class and to every shard.
     value_dtype: Value dtype, forwarded to the base class and to every shard.
     default_value: Passed as `default_value` to every shard.
     num_shards: Number of shard tables to create; must be at least 1.
     name: Optional name for the op scope; it is also passed unchanged as the
       `name` of every shard.

   Raises:
     ValueError: If `num_shards` is less than 1.
   """
   # Fail fast. With no shards, `self._table_shards[0]` below would raise an
   # opaque IndexError after the base class had already been initialized.
   if num_shards < 1:
     raise ValueError('num_shards must be at least 1, got %s' % (num_shards,))
   with ops.op_scope([], name, 'sharded_mutable_hash_table') as scope:
     super(_ShardedMutableHashTable, self).__init__(key_dtype, value_dtype,
                                                    scope)
     # TODO(andreasst): add placement hints once bug 30002625 is fixed.
     # NOTE(review): every shard receives the identical `name`. Presumably
     # the graph uniquifies the resulting op names, so confirm.
     self._table_shards = [
         lookup_ops.MutableHashTable(
             key_dtype=key_dtype,
             value_dtype=value_dtype,
             default_value=default_value,
             name=name)
         for _ in range(num_shards)
     ]
     # TODO(andreasst): add a value_shape() method to LookupInterface
     # pylint: disable=protected-access
     self._value_shape = self._table_shards[0]._value_shape