コード例 #1
0
    def test_restore_knowledge_bank(self):
        """Checks that restoring brings back the saved embedding value."""
        key = ['first']
        address = self._kbs_address

        def current_value():
            # Reads the embedding currently stored for `key` under 'emb'.
            looked_up = de_ops.dynamic_embedding_lookup(
                key, self._config, 'emb', service_address=address)
            return looked_up.numpy()

        # Store the initial value [4, 5] and snapshot the knowledge bank.
        de_ops.dynamic_embedding_update(key,
                                        tf.constant([4.0, 5.0]),
                                        self._config,
                                        'emb',
                                        service_address=address)
        snapshots = io_ops.save_knowledge_bank(FLAGS.test_tmpdir, address)
        self.assertLen(snapshots, 1)

        # Overwrite the value after the snapshot was taken and confirm that
        # the new value is what a lookup returns.
        de_ops.dynamic_embedding_update(key,
                                        tf.constant([10.0, 20.0]),
                                        self._config,
                                        'emb',
                                        service_address=address)
        self.assertAllClose(current_value(), [[10.0, 20.0]])

        # Restore from the snapshot; the lookup must return the value that
        # was stored when the snapshot was taken.
        io_ops.restore_knowledge_bank(self._config,
                                      'emb',
                                      snapshots[0].numpy()[0],
                                      service_address=address)
        self.assertAllClose(current_value(), [[4.0, 5.0]])
コード例 #2
0
    def test_compute_sampled_logits_grad(self):
        """Checks embedding and weight gradients of compute_sampled_logits."""
        sampler = cs_config_builder.negative_sampler(unique=True,
                                                     algorithm='UNIFORM')
        cs_config = cs_config_builder.build_candidate_sampler_config(sampler)
        de_config = test_util.default_de_config(3, cs_config=cs_config)
        all_keys = ['key1', 'key2', 'key3']

        # Populate the knowledge bank with three known embeddings.
        initial_embeddings = tf.constant([[1.0, 2.0, 3.0],
                                          [4.0, 5.0, 6.0],
                                          [7.0, 8.0, 9.0]])
        de_ops.dynamic_embedding_update(all_keys,
                                        initial_embeddings,
                                        de_config,
                                        'emb',
                                        service_address=self._kbs_address)

        # Model under test: a single linear layer feeding the sampled logits.
        #   x = [[1, 2], [3, 4]]                      (inputs)
        #   W = [[1, 2, 3], [4, 5, 6]]                (weights)
        #   i = x * W = [[9, 12, 15], [19, 26, 33]]   (activation)
        # With E = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] the embeddings of the
        # output keys, the logits are E * i:
        #   [[78, 186, 294], [170, 404, 638]]
        #
        # Using the loss L = tf.reduce_sum(logits):
        #   dL/dE = sum_by_key(i) = [[28, 38, 48]] * 3
        # so the expected new embeddings are
        #   E - 0.1 * dL/dE = [[-1.8] * 3, [1.2] * 3, [4.2] * 3].
        inputs = tf.constant([[1.0, 2.0], [3.0, 4.0]])
        weights = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)

        with tf.GradientTape() as tape:
            # The matmul must be recorded by the tape so that the gradient
            # with respect to `weights` can be computed below.
            activation = tf.matmul(inputs, weights)
            logits, _labels, _keys, _mask, _embs = (
                cs_ops.compute_sampled_logits(
                    [['key1', ''], ['key2', 'key3']],
                    activation,
                    3,
                    de_config,
                    'emb',
                    service_address=self._kbs_address))
            loss = tf.reduce_sum(logits)

        # Computes the gradient w.r.t. the weights. The assertion below
        # expects the stored embeddings to have been updated by this point.
        grads = tape.gradient(loss, weights)

        # Embeddings as currently stored in the knowledge bank.
        updated_embedding = de_ops.dynamic_embedding_lookup(
            all_keys, de_config, 'emb', service_address=self._kbs_address)
        expected_embedding = [[-1.8, -1.8, -1.8],
                              [1.2, 1.2, 1.2],
                              [4.2, 4.2, 4.2]]
        self.assertAllClose(updated_embedding, expected_embedding)

        # Gradient w.r.t. W:
        #   dL/dW = dL/di * di/dW = sum_by_dim(E) * x
        #         = [12, 15, 18] * [[4, 4, 4], [6, 6, 6]]
        #         = [[48, 60, 72], [72, 90, 108]]
        expected_grads = [[48, 60, 72], [72, 90, 108]]
        self.assertAllClose(grads, expected_grads)
コード例 #3
0
    def testUpdate_2DInput(self, skip_gradient):
        """Checks update/lookup round trips for 2-D (and [N1, N2, 1]) keys."""
        # Make [1, 2] the default embedding value in the config.
        default_value = (
            self._config.knowledge_bank_config.initializer.default_embedding
            .value)
        for component in (1, 2):
            default_value.append(component)

        keys_2d = [['first', 'second'], ['third', '']]

        def lookup():
            # Looks up `keys_2d` and returns the result as a numpy array.
            result = de_ops.dynamic_embedding_lookup(
                keys_2d,
                self._config,
                'emb',
                service_address=self._kbs_address,
                skip_gradient_update=skip_gradient)
            return result.numpy()

        # Before any update every non-empty key maps to the default value,
        # and the empty key maps to zeros.
        self.assertAllClose(lookup(), [[[1, 2], [1, 2]], [[1, 2], [0, 0]]])

        # Update with 2-D keys: the value paired with the empty key is
        # expected to be ignored.
        first_values = tf.constant([[[2.0, 4.0], [4.0, 8.0]],
                                    [[8.0, 16.0], [16.0, 32.0]]])
        update_res = de_ops.dynamic_embedding_update(
            keys_2d,
            first_values,
            self._config,
            'emb',
            service_address=self._kbs_address,
        )
        expected_first = [[[2, 4], [4, 8]], [[8, 16], [0, 0]]]
        self.assertAllClose(update_res.numpy(), expected_first)
        self.assertAllClose(lookup(), expected_first)

        # Keys of shape [N1, N2, 1] are accepted together with values of
        # shape [N1, N2, D].
        second_values = tf.constant([[[3.0, 5.0], [5.0, 9.0]],
                                     [[9.0, 17.0], [17.0, 33.0]]])
        update_res = de_ops.dynamic_embedding_update(
            [[['first'], ['second']], [['third'], ['']]],
            second_values,
            self._config,
            'emb',
            service_address=self._kbs_address,
        )
        expected_second = [[[3, 5], [5, 9]], [[9, 17], [0, 0]]]
        self.assertAllClose(update_res.numpy(), expected_second)
        self.assertAllClose(lookup(), expected_second)
コード例 #4
0
    def testUpdate_1DInput(self, use_kbs_address, skip_gradient):
        """Checks a 1-D update/lookup with and without an explicit address."""
        # Make [1, 2] the default embedding value in the config.
        default_value = (
            self._config.knowledge_bank_config.initializer.default_embedding
            .value)
        for component in (1, 2):
            default_value.append(component)

        key = ['first']

        # The initial lookup does not pass a service address and is expected
        # to return the default value.
        initial = de_ops.dynamic_embedding_lookup(
            key, self._config, 'emb', skip_gradient_update=skip_gradient)
        self.assertAllClose(initial.numpy(), [[1, 2]])

        # Either target the test's service server explicitly or pass an
        # empty address.
        kbs_address = ('localhost:%d' % self._service_server.port()
                       if use_kbs_address else '')

        new_value = tf.constant([[2.0, 4.0]])
        update_res = de_ops.dynamic_embedding_update(key, new_value,
                                                     self._config, 'emb',
                                                     kbs_address)
        self.assertAllClose(update_res.numpy(), [[2, 4]])

        # The updated value must be what a subsequent lookup returns.
        refreshed = de_ops.dynamic_embedding_lookup(
            key,
            self._config,
            'emb',
            kbs_address,
            skip_gradient_update=skip_gradient)
        self.assertAllClose(refreshed.numpy(), [[2, 4]])
コード例 #5
0
    def testUpdate_2DInput(self, skip_gradient):
        """Checks a 2-D update/lookup round trip using the default address."""
        # Make [1, 2] the default embedding value in the config.
        default_value = (
            self._config.knowledge_bank_config.initializer.default_embedding
            .value)
        for component in (1, 2):
            default_value.append(component)

        keys_2d = [['first', 'second'], ['third', '']]

        def lookup():
            # Looks up `keys_2d` and returns the result as a numpy array.
            result = de_ops.dynamic_embedding_lookup(
                keys_2d,
                self._config,
                'emb',
                skip_gradient_update=skip_gradient)
            return result.numpy()

        # Before any update every non-empty key maps to the default value,
        # and the empty key maps to zeros.
        self.assertAllClose(lookup(), [[[1, 2], [1, 2]], [[1, 2], [0, 0]]])

        # The value paired with the empty key is expected to be ignored.
        new_values = tf.constant([[[2.0, 4.0], [4.0, 8.0]],
                                  [[8.0, 16.0], [16.0, 32.0]]])
        update_res = de_ops.dynamic_embedding_update(keys_2d, new_values,
                                                     self._config, 'emb')
        expected = [[[2, 4], [4, 8]], [[8, 16], [0, 0]]]
        self.assertAllClose(update_res.numpy(), expected)

        # A fresh lookup must return the updated values.
        self.assertAllClose(lookup(), expected)
コード例 #6
0
    def testUpdate_1DInput(self, use_kbs_address, skip_gradient):
        """Checks update/lookup round trips for 1-D (and [N, 1]) keys."""
        # Make [1, 2] the default embedding value in the config.
        default_value = (
            self._config.knowledge_bank_config.initializer.default_embedding
            .value)
        for component in (1, 2):
            default_value.append(component)

        key = ['first']

        def lookup():
            # Looks up `key` and returns the result as a numpy array.
            result = de_ops.dynamic_embedding_lookup(
                key,
                self._config,
                'emb',
                service_address=self._kbs_address,
                skip_gradient_update=skip_gradient)
            return result.numpy()

        # Before any update the key maps to the default value.
        self.assertAllClose(lookup(), [[1, 2]])

        # Plain 1-D update followed by a lookup of the same key.
        update_res = de_ops.dynamic_embedding_update(
            key,
            tf.constant([[2.0, 4.0]]),
            self._config,
            'emb',
            service_address=self._kbs_address,
        )
        self.assertAllClose(update_res.numpy(), [[2, 4]])
        self.assertAllClose(lookup(), [[2, 4]])

        # Keys of shape [N, 1] are accepted together with values of shape
        # [N, D].
        update_res = de_ops.dynamic_embedding_update(
            [['first']],
            tf.constant([[4.0, 5.0]]),
            self._config,
            'emb',
            service_address=self._kbs_address)
        self.assertAllClose(update_res.numpy(), [[4, 5]])
        self.assertAllClose(lookup(), [[4, 5]])
コード例 #7
0
    def test_brute_force_topk(self):
        """Checks top-k keys and logits from the brute-force sampler."""
        sampler = cs_config_builder.brute_force_topk_sampler('DOT_PRODUCT')
        cs_config = cs_config_builder.build_candidate_sampler_config(sampler)
        de_config = test_util.default_de_config(2, cs_config=cs_config)

        # Populate the knowledge bank with three embeddings.
        stored = tf.constant([[2.0, 4.0], [4.0, 8.0], [8.0, 16.0]])
        de_ops.dynamic_embedding_update(['key1', 'key2', 'key3'],
                                        stored,
                                        de_config,
                                        'emb',
                                        service_address=self._kbs_address)

        # Ask for the top 3 candidates for two opposite query vectors.
        queries = [[1.0, 2.0], [-1.0, -2.0]]
        keys, logits = cs_ops.top_k(queries,
                                    3,
                                    de_config,
                                    'emb',
                                    service_address=self._kbs_address)

        # Dot products with [1, 2] are 10, 20, 40 for key1..key3, so the
        # first query ranks key3 first; the negated query reverses the order.
        expected_keys = [[b'key3', b'key2', b'key1'],
                         [b'key1', b'key2', b'key3']]
        expected_logits = [[40, 20, 10], [-10, -20, -40]]
        self.assertAllEqual(keys.numpy(), expected_keys)
        self.assertAllClose(logits.numpy(), expected_logits)
コード例 #8
0
    def test_save_knowledge_bank(self):
        """Checks the paths returned when saving all or selected embeddings."""
        # Regex prefix shared by the expected paths of both embeddings.
        save_dir_pattern = (FLAGS.test_tmpdir +
                            '/knowledge_bank_data_[0-9]+_[0-9]+_[0-9]+')
        pattern1 = save_dir_pattern + '/emb1/embedding_store_meta_data.pbtxt'
        pattern2 = save_dir_pattern + '/emb2/embedding_store_meta_data.pbtxt'

        def raw_path(paths, index):
            # First element of the `index`-th returned path tensor (bytes).
            return paths[index].numpy()[0]

        # Add embedding 'emb1' with values [4, 5] and save everything.
        de_ops.dynamic_embedding_update(['first'],
                                        tf.constant([4.0, 5.0]),
                                        self._config,
                                        'emb1',
                                        service_address=self._kbs_address)
        saved_paths = io_ops.save_knowledge_bank(FLAGS.test_tmpdir,
                                                 self._kbs_address)
        self.assertLen(saved_paths, 1)
        self.assertRegex(raw_path(saved_paths, 0).decode(), pattern1)

        # Add a second embedding 'emb2'; a full save now returns two paths.
        de_ops.dynamic_embedding_update(['first'],
                                        tf.constant([5.0, 6.0]),
                                        self._config,
                                        'emb2',
                                        service_address=self._kbs_address)
        saved_paths = io_ops.save_knowledge_bank(FLAGS.test_tmpdir,
                                                 self._kbs_address)
        self.assertLen(saved_paths, 2)
        self.assertRegex(raw_path(saved_paths, 0).decode(), pattern1)
        self.assertRegex(raw_path(saved_paths, 1).decode(), pattern2)

        # Save only 'emb2': a single path is returned, and it differs from
        # the first path of the previous full save.
        new_saved_paths = io_ops.save_knowledge_bank(FLAGS.test_tmpdir,
                                                     self._kbs_address,
                                                     var_names=['emb2'])
        self.assertLen(new_saved_paths, 1)
        self.assertRegex(raw_path(new_saved_paths, 0).decode(), pattern2)
        self.assertNotEqual(raw_path(new_saved_paths, 0),
                            raw_path(saved_paths, 0))
    def update(self, neighbor_ids, neighbor_state):
        """Stores newly computed neighbor state in the neighbor cache.

        Forwards to `de_ops.dynamic_embedding_update` using this object's
        config, key feature name, service address and timeout.

        Args:
          neighbor_ids: a string Tensor of shape [batch_size] holding the ids
            of a neighborhood.
          neighbor_state: a Tensor of shape [batch_size, ...] holding the
            newly computed neighbor state (e.g. embeddings, logits) to store
            in the neighbor cache.

        Returns:
          A `Tensor` of shape [batch_size, config.embedding_dimension].
        """
        updated_state = de_ops.dynamic_embedding_update(
            neighbor_ids,
            neighbor_state,
            self._config,
            self._key_feature_name,
            self._service_address,
            self._timeout_ms)
        return updated_state
コード例 #10
0
    def test_compute_sampled_logits(self):
        """Checks logits, labels, keys, mask and embeddings of the sampler."""
        sampler = cs_config_builder.negative_sampler(unique=True,
                                                     algorithm='UNIFORM')
        cs_config = cs_config_builder.build_candidate_sampler_config(sampler)
        de_config = test_util.default_de_config(3, cs_config=cs_config)

        # Populate the knowledge bank with three known embeddings.
        stored = tf.constant([[1.0, 2.0, 3.0],
                              [4.0, 5.0, 6.0],
                              [7.0, 8.0, 9.0]])
        de_ops.dynamic_embedding_update(['key1', 'key2', 'key3'],
                                        stored,
                                        de_config,
                                        'emb',
                                        service_address=self._kbs_address)

        # Sample three candidates per example.
        positive_keys = [['key1', ''], ['key2', 'key3']]
        activations = tf.constant([[2.0, 4.0, 1], [-2.0, -4.0, 1]])
        logits, labels, keys, mask, weights = cs_ops.compute_sampled_logits(
            positive_keys,
            activations,
            3,
            de_config,
            'emb',
            service_address=self._kbs_address)

        # Expected results:
        # - Example one has one positive key {'key1'} and two negative keys
        #   {'key2', 'key3'}.
        # - Example two has two positive keys {'key2', 'key3'} and one
        #   negative key {'key1'}.
        expected_weights = {
            b'key1': [1, 2, 3],
            b'key2': [4, 5, 6],
            b'key3': [7, 8, 9],
        }
        expected_labels = [
            {b'key1': 1, b'key2': 0, b'key3': 0},
            {b'key1': 0, b'key2': 1, b'key3': 1},
        ]
        # Logits are dot products of the activation with each embedding.
        # Example one, activation [2, 4, 1]:
        # - 'key1': [2, 4, 1] * [1, 2, 3] = 13
        # - 'key2': [2, 4, 1] * [4, 5, 6] = 34
        # - 'key3': [2, 4, 1] * [7, 8, 9] = 55
        # Example two, activation [-2, -4, 1]:
        # - 'key1': [-2, -4, 1] * [1, 2, 3] = -7
        # - 'key2': [-2, -4, 1] * [4, 5, 6] = -22
        # - 'key3': [-2, -4, 1] * [7, 8, 9] = -37
        expected_logits = [
            {b'key1': 13, b'key2': 34, b'key3': 55},
            {b'key1': -7, b'key2': -22, b'key3': -37},
        ]

        # Convert once; the sampled key order is not fixed, so every
        # per-position check is keyed by the sampled key.
        mask_np = mask.numpy()
        keys_np = keys.numpy()
        weights_np = weights.numpy()
        labels_np = labels.numpy()
        logits_np = logits.numpy()

        for example in range(2):
            self.assertEqual(1, mask_np[example])

            sampled_keys = keys_np[example]
            # All three stored keys must appear among the sampled keys.
            for expected_key in sorted(expected_weights):
                self.assertIn(expected_key, sampled_keys)

            for position in range(3):
                key = sampled_keys[position]
                self.assertAllClose(expected_weights[key],
                                    weights_np[example][position])
                self.assertAllClose(expected_labels[example][key],
                                    labels_np[example][position])
                self.assertAllClose(expected_logits[example][key],
                                    logits_np[example][position])