def test_brute_force_topk_sampler_success(self):
     """Checks brute_force_topk_sampler for each accepted similarity argument.

     COSINE and DOT_PRODUCT are each passed once as a string and once as the
     corresponding cs_config_pb2 constant; the returned message must compare
     equal to a text proto carrying the matching similarity_type.
     """
     cosine_text = """
   similarity_type: COSINE
 """
     dot_product_text = """
   similarity_type: DOT_PRODUCT
 """
     # (expected text proto, argument handed to the builder), in call order.
     cases = (
         (cosine_text, 'COSINE'),
         (cosine_text, cs_config_pb2.COSINE),
         (dot_product_text, 'DOT_PRODUCT'),
         (dot_product_text, cs_config_pb2.DOT_PRODUCT),
     )
     for expected_text, similarity in cases:
         sampler = cs_config_builder.brute_force_topk_sampler(similarity)
         self.assertProtoEquals(expected_text, sampler)
 def test_brute_force_topk_sampler_failed(self):
     """Checks that brute_force_topk_sampler raises ValueError on bad input.

     Covers four arguments: cs_config_pb2.UNKNOWN, an unrecognised string,
     a cs_config_pb2.SampleContext message, and the integer 999.
     """

     def assert_rejected(make_argument):
         # The argument is produced inside the assertRaises scope, right
         # where the builder is invoked.
         with self.assertRaises(ValueError):
             cs_config_builder.brute_force_topk_sampler(make_argument())

     assert_rejected(lambda: cs_config_pb2.UNKNOWN)
     assert_rejected(lambda: 'Unknown type string')
     assert_rejected(lambda: cs_config_pb2.SampleContext())
     assert_rejected(lambda: 999)
示例#3
0
    def test_brute_force_topk(self):
        """Runs a top-3 lookup against three stored embeddings.

        Builds a DOT_PRODUCT brute-force sampler config, stores three
        two-element embeddings under 'emb', then queries with two vectors
        and checks the returned keys and logits.
        """
        sampler = cs_config_builder.brute_force_topk_sampler('DOT_PRODUCT')
        cs_config = cs_config_builder.build_candidate_sampler_config(sampler)
        # The leading 2 is presumably the embedding dimension (the stored
        # vectors below have two elements) -- TODO confirm against test_util.
        de_config = test_util.default_de_config(2, cs_config=cs_config)

        # Add a few embeddings into knowledge bank.
        stored_keys = ['key1', 'key2', 'key3']
        stored_values = tf.constant([[2.0, 4.0], [4.0, 8.0], [8.0, 16.0]])
        de_ops.dynamic_embedding_update(
            stored_keys,
            stored_values,
            de_config,
            'emb',
            service_address=self._kbs_address)

        queries = [[1.0, 2.0], [-1.0, -2.0]]
        keys, logits = cs_ops.top_k(
            queries, 3, de_config, 'emb', service_address=self._kbs_address)

        # One row per query; the second query is the negation of the first,
        # so its expected ordering is reversed.
        expected_keys = [[b'key3', b'key2', b'key1'],
                         [b'key1', b'key2', b'key3']]
        expected_logits = [[40, 20, 10], [-10, -20, -40]]
        self.assertAllEqual(keys.numpy(), expected_keys)
        self.assertAllClose(logits.numpy(), expected_logits)
    def test_build_candidate_sampler_config_success(self):
        """Checks the config built from brute-force and log-uniform samplers.

        For each sampler, build_candidate_sampler_config must return a
        message equal to a text proto whose `extension` holds the matching
        sampler config.
        """
        expected_brute_force = """
        extension {
          [type.googleapis.com/carls.candidate_sampling.BruteForceTopkSamplerConfig] {
            similarity_type: COSINE
          }
        }
    """
        brute_force = cs_config_builder.brute_force_topk_sampler('COSINE')
        brute_force_config = cs_config_builder.build_candidate_sampler_config(
            brute_force)
        self.assertProtoEquals(expected_brute_force, brute_force_config)

        expected_log_uniform = """
        extension {
          [type.googleapis.com/carls.candidate_sampling.LogUniformSamplerConfig] {
            unique: true
          }
        }
    """
        log_uniform = cs_config_builder.log_uniform_sampler(True)
        log_uniform_config = cs_config_builder.build_candidate_sampler_config(
            log_uniform)
        self.assertProtoEquals(expected_log_uniform, log_uniform_config)
    def test_build_candidate_sampler_config_success(self):
        """Checks the config built from brute-force and negative samplers.

        For each sampler, build_candidate_sampler_config must return a
        message equal to a text proto whose `extension` holds the matching
        sampler config.
        """
        expected_brute_force = """
        extension {
          [type.googleapis.com/carls.candidate_sampling.BruteForceTopkSamplerConfig] {
            similarity_type: COSINE
          }
        }
    """
        brute_force = cs_config_builder.brute_force_topk_sampler('COSINE')
        brute_force_config = cs_config_builder.build_candidate_sampler_config(
            brute_force)
        self.assertProtoEquals(expected_brute_force, brute_force_config)

        expected_negative = """
        extension {
          [type.googleapis.com/carls.candidate_sampling.NegativeSamplerConfig] {
            unique: true
            sampler: UNIFORM
          }
        }
    """
        negative = cs_config_builder.negative_sampler(True, 'UNIFORM')
        negative_config = cs_config_builder.build_candidate_sampler_config(
            negative)
        self.assertProtoEquals(expected_negative, negative_config)