def test_brute_force_topk_sampler_success(self):
    """Checks brute_force_topk_sampler for each similarity-type spelling.

    Every similarity type is passed once as its string name and once as the
    cs_config_pb2 enum constant. Both spellings are compared against the same
    expected text-format config.
    """
    cosine_text = """ similarity_type: COSINE """
    dot_product_text = """ similarity_type: DOT_PRODUCT """
    # (argument passed to the builder, expected text-format proto)
    cases = (
        ('COSINE', cosine_text),
        (cs_config_pb2.COSINE, cosine_text),
        ('DOT_PRODUCT', dot_product_text),
        (cs_config_pb2.DOT_PRODUCT, dot_product_text),
    )
    for similarity, expected_text in cases:
        self.assertProtoEquals(
            expected_text,
            cs_config_builder.brute_force_topk_sampler(similarity))
def test_brute_force_topk_sampler_failed(self):
    """Checks that brute_force_topk_sampler raises ValueError on bad input.

    The inputs covered are:
      * the UNKNOWN enum constant;
      * an unrecognized string;
      * a proto message instance (SampleContext) instead of a similarity type;
      * the integer 999, presumably not a valid enum value.
    """
    invalid_inputs = (
        cs_config_pb2.UNKNOWN,
        'Unknown type string',
        cs_config_pb2.SampleContext(),
        999,
    )
    for bad_input in invalid_inputs:
        with self.assertRaises(ValueError):
            cs_config_builder.brute_force_topk_sampler(bad_input)
def test_brute_force_topk(self):
    """Runs a DOT_PRODUCT brute-force top-k lookup against the knowledge bank.

    The test stores three 2-d embeddings under 'emb' and then scores two query
    vectors against them. The returned keys and logits are expected in
    descending order of dot product.
    """
    sampler = cs_config_builder.brute_force_topk_sampler('DOT_PRODUCT')
    cs_config = cs_config_builder.build_candidate_sampler_config(sampler)
    de_config = test_util.default_de_config(2, cs_config=cs_config)

    # Add a few embeddings into the knowledge bank. Each stored vector is a
    # positive multiple of [1, 2], so the ranking flips with the query's sign.
    stored_keys = ['key1', 'key2', 'key3']
    stored_embeddings = tf.constant([[2.0, 4.0], [4.0, 8.0], [8.0, 16.0]])
    de_ops.dynamic_embedding_update(
        stored_keys, stored_embeddings, de_config, 'emb',
        service_address=self._kbs_address)

    queries = [[1.0, 2.0], [-1.0, -2.0]]
    keys, logits = cs_ops.top_k(
        queries, 3, de_config, 'emb', service_address=self._kbs_address)

    # [1, 2] . [2, 4] = 10, [1, 2] . [4, 8] = 20, [1, 2] . [8, 16] = 40.
    # The negated query negates every score, which reverses the order.
    expected_keys = [[b'key3', b'key2', b'key1'], [b'key1', b'key2', b'key3']]
    expected_logits = [[40, 20, 10], [-10, -20, -40]]
    self.assertAllEqual(keys.numpy(), expected_keys)
    self.assertAllClose(logits.numpy(), expected_logits)
def test_build_candidate_sampler_config_with_log_uniform_success(self):
    """Checks that build_candidate_sampler_config packs sampler configs.

    Both the brute-force top-k config and the log-uniform sampler config must
    appear under the `extension` field, keyed by the config's type URL.
    """
    # This method was previously also named
    # `test_build_candidate_sampler_config_success`, the same as a later method
    # in this class. The later definition silently replaced it, so these
    # assertions never ran. It now has a unique name.
    # NOTE(review): confirm cs_config_builder.log_uniform_sampler still exists;
    # if it was removed in favor of negative_sampler, delete this test instead.
    self.assertProtoEquals(
        """ extension { [type.googleapis.com/carls.candidate_sampling.BruteForceTopkSamplerConfig] { similarity_type: COSINE } } """,
        cs_config_builder.build_candidate_sampler_config(
            cs_config_builder.brute_force_topk_sampler('COSINE')))
    self.assertProtoEquals(
        """ extension { [type.googleapis.com/carls.candidate_sampling.LogUniformSamplerConfig] { unique: true } } """,
        cs_config_builder.build_candidate_sampler_config(
            cs_config_builder.log_uniform_sampler(True)))
def test_build_candidate_sampler_config_success(self):
    """Checks that build_candidate_sampler_config packs sampler configs.

    Both the brute-force top-k config and the negative sampler config must
    appear under the `extension` field, keyed by the config's type URL.
    """
    # Each sampler config is built lazily inside the loop. This keeps the
    # original call order: build, then assert, then build the next one.
    cases = (
        (""" extension { [type.googleapis.com/carls.candidate_sampling.BruteForceTopkSamplerConfig] { similarity_type: COSINE } } """,
         lambda: cs_config_builder.brute_force_topk_sampler('COSINE')),
        (""" extension { [type.googleapis.com/carls.candidate_sampling.NegativeSamplerConfig] { unique: true sampler: UNIFORM } } """,
         lambda: cs_config_builder.negative_sampler(True, 'UNIFORM')),
    )
    for expected_text, make_sampler_config in cases:
        self.assertProtoEquals(
            expected_text,
            cs_config_builder.build_candidate_sampler_config(
                make_sampler_config()))