def test_sep_ignored_in_hashed_out(self):
  """Verifies that two strong-hash crosses with identical salts agree."""
  feature_a = self._sparse_tensor(
      [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']])
  feature_b = self._sparse_tensor([['batch1-FC2-F1']])
  feature_c = self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])

  def run_cross():
    # Same inputs, same salt: both invocations must hash identically.
    return gen_sparse_ops.sparse_cross_hashed(
        indices=[feature_a.indices, feature_b.indices, feature_c.indices],
        values=[feature_a.values, feature_b.values, feature_c.values],
        shapes=[
            feature_a.dense_shape, feature_b.dense_shape,
            feature_c.dense_shape
        ],
        dense_inputs=[],
        strong_hash=True,
        num_buckets=1000,
        salt=[137, 173])

  first = sparse_tensor.SparseTensor(*run_cross())
  second = sparse_tensor.SparseTensor(*run_cross())
  with self.cached_session():
    result_1 = self.evaluate(first)
    result_2 = self.evaluate(second)
    self.assertAllEqual(result_1.indices, result_2.indices)
    self.assertAllEqual(result_1.values, result_2.values)
def test_hashed_zero_bucket_no_hash_key(self):
  """Checks weak-hash output with zero buckets ignores the salt."""
  inp_1 = self._sparse_tensor([['batch1-FC1-F1']])
  inp_2 = self._sparse_tensor([['batch1-FC2-F1']])
  inp_3 = self._sparse_tensor([['batch1-FC3-F1']])

  def run_cross(salt):
    # Weak (non-strong) hashing with the given salt and no bucketing.
    result = gen_sparse_ops.sparse_cross_hashed(
        indices=[inp_1.indices, inp_2.indices, inp_3.indices],
        values=[inp_1.values, inp_2.values, inp_3.values],
        shapes=[inp_1.dense_shape, inp_2.dense_shape, inp_3.dense_shape],
        dense_inputs=[],
        num_buckets=0,
        salt=salt,
        strong_hash=False)
    return sparse_tensor.SparseTensor(*result)

  # Check actual hashed output to prevent unintentional hashing changes.
  expected_out = self._sparse_tensor([[9186962005966787372]])
  with self.cached_session():
    self._assert_sparse_tensor_equals(expected_out,
                                      self.evaluate(run_cross([1, 1])))
  # salt is not being used when `strong_hash` is False.
  with self.cached_session():
    self._assert_sparse_tensor_equals(expected_out,
                                      self.evaluate(run_cross([137, 173])))
def _process_input_list(self, inputs):
  """Crosses and hashes a list of sparse and/or dense inputs.

  Args:
    inputs: A list of `SparseTensor`s and/or dense `Tensor`s to cross.
      `RaggedTensor` inputs are not supported.

  Returns:
    A `SparseTensor` containing the hashed cross, or a dense `Tensor`
    when every input was dense.

  Raises:
    ValueError: If any input is a `RaggedTensor`.
  """
  # TODO(momernick): support ragged_cross_hashed with corrected fingerprint
  # and siphash.
  if any(isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs):
    raise ValueError('Hashing with ragged input is not supported yet.')
  sparse_inputs = [
      inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor)
  ]
  dense_inputs = [
      inp for inp in inputs
      if not isinstance(inp, sparse_tensor.SparseTensor)
  ]
  # Idiom fix: `True if not sparse_inputs else False` simplified.
  # A dense output is produced only when no input was sparse.
  all_dense = not sparse_inputs
  indices = [sp_inp.indices for sp_inp in sparse_inputs]
  values = [sp_inp.values for sp_inp in sparse_inputs]
  shapes = [sp_inp.dense_shape for sp_inp in sparse_inputs]
  indices_out, values_out, shapes_out = gen_sparse_ops.sparse_cross_hashed(
      indices=indices,
      values=values,
      shapes=shapes,
      dense_inputs=dense_inputs,
      num_buckets=self.num_bins,
      strong_hash=self.strong_hash,
      salt=self.salt)
  sparse_out = sparse_tensor.SparseTensor(indices_out, values_out,
                                          shapes_out)
  if all_dense:
    # Preserve the original convention: all-dense inputs -> dense output.
    return sparse_ops.sparse_tensor_to_dense(sparse_out)
  return sparse_out
def test_hashed_3x1x2(self):
  """Tests 3x1x2 permutation with hashed output."""
  col_1 = self._sparse_tensor(
      [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']])
  col_2 = self._sparse_tensor([['batch1-FC2-F1']])
  col_3 = self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
  cross_indices, cross_values, cross_shapes = (
      gen_sparse_ops.sparse_cross_hashed(
          indices=[col_1.indices, col_2.indices, col_3.indices],
          values=[col_1.values, col_2.values, col_3.values],
          shapes=[col_1.dense_shape, col_2.dense_shape, col_3.dense_shape],
          dense_inputs=[],
          num_buckets=1000,
          salt=[137, 173],
          strong_hash=False))
  cross = sparse_tensor.SparseTensor(cross_indices, cross_values,
                                     cross_shapes)
  with self.cached_session():
    result = self.evaluate(cross)
    # 3 * 1 * 2 = 6 crossed values, all located in row 0.
    self.assertEqual(6, len(result.values))
    self.assertAllEqual([[0, i] for i in range(6)], result.indices)
    # Every bucket id must fall within [0, 1000).
    self.assertTrue(all(0 <= x < 1000 for x in result.values))
    # With this salt there happen to be no collisions among the 6 crosses.
    self.assertTrue(len(result.values) == len(set(result.values)))
def test_hashed_output(self):
  """Pins the exact weak-hash bucket for a fixed 1x1x1 cross."""
  sp_1 = self._sparse_tensor([['batch1-FC1-F1']])
  sp_2 = self._sparse_tensor([['batch1-FC2-F1']])
  sp_3 = self._sparse_tensor([['batch1-FC3-F1']])
  result = gen_sparse_ops.sparse_cross_hashed(
      indices=[sp_1.indices, sp_2.indices, sp_3.indices],
      values=[sp_1.values, sp_2.values, sp_3.values],
      shapes=[sp_1.dense_shape, sp_2.dense_shape, sp_3.dense_shape],
      dense_inputs=[],
      num_buckets=100,
      salt=[137, 173],
      strong_hash=False)
  # Check actual hashed output to prevent unintentional hashing changes.
  expected_out = self._sparse_tensor([[79]])
  out = sparse_tensor.SparseTensor(*result)
  with self.cached_session():
    self._assert_sparse_tensor_equals(expected_out, self.evaluate(out))
def test_hashed_has_no_collision(self):
  """Tests that fingerprint concatenation has no collisions."""
  # The last 10 bits of 359 and 359 + 1024 are identical, yet none of
  # the resulting crosses should collide even with 1024 buckets.
  t1 = constant_op.constant([[359], [359 + 1024]], dtype=dtypes.int64)
  t2 = constant_op.constant(
      [list(range(10)), list(range(10))], dtype=dtypes.int64)
  cross_indices, cross_values, cross_shapes = (
      gen_sparse_ops.sparse_cross_hashed(
          indices=[],
          values=[],
          shapes=[],
          dense_inputs=[t2, t1],
          num_buckets=1024,
          salt=[137, 173],
          strong_hash=False))
  cross = sparse_tensor.SparseTensor(cross_indices, cross_values,
                                     cross_shapes)
  cross_dense = sparse_ops.sparse_tensor_to_dense(cross)
  with session.Session():
    dense_values = self.evaluate(cross_dense)
    # Row-wise comparison: every crossed bucket must differ between rows.
    self.assertTrue(numpy.not_equal(dense_values[0], dense_values[1]).all())