def test_restore_knowledge_bank(self):
    """Checks that restoring a saved knowledge bank reverts a later update."""
    saved_value = [4.0, 5.0]
    overwritten_value = [10.0, 20.0]

    # Store the initial embedding and save the knowledge bank to disk.
    de_ops.dynamic_embedding_update(
        ['first'],
        tf.constant(saved_value),
        self._config,
        'emb',
        service_address=self._kbs_address)
    saved_paths = io_ops.save_knowledge_bank(
        FLAGS.test_tmpdir, self._kbs_address)
    self.assertLen(saved_paths, 1)

    # Overwrite the embedding after the save.
    de_ops.dynamic_embedding_update(
        ['first'],
        tf.constant(overwritten_value),
        self._config,
        'emb',
        service_address=self._kbs_address)

    # The lookup must now return the overwritten value.
    current = de_ops.dynamic_embedding_lookup(
        ['first'],
        self._config,
        'emb',
        service_address=self._kbs_address)
    self.assertAllClose(current.numpy(), [overwritten_value])

    # Restore 'emb' from the path produced by the earlier save.
    saved_path = saved_paths[0].numpy()[0]
    io_ops.restore_knowledge_bank(
        self._config,
        'emb',
        saved_path,
        service_address=self._kbs_address)

    # The lookup must be back to the value that was saved.
    restored = de_ops.dynamic_embedding_lookup(
        ['first'],
        self._config,
        'emb',
        service_address=self._kbs_address)
    self.assertAllClose(restored.numpy(), [saved_value])
def test_compute_sampled_logits_grad(self):
    """Checks the gradients produced through `compute_sampled_logits`.

    Verifies both the embeddings read back from the knowledge bank after the
    gradient computation and the gradient w.r.t. the dense weights.
    """
    sampler = cs_config_builder.negative_sampler(
        unique=True, algorithm='UNIFORM')
    cs_config = cs_config_builder.build_candidate_sampler_config(sampler)
    de_config = test_util.default_de_config(3, cs_config=cs_config)

    # Seed the knowledge bank with three embeddings.
    initial_embeddings = tf.constant([[1.0, 2.0, 3.0],
                                      [4.0, 5.0, 6.0],
                                      [7.0, 8.0, 9.0]])
    de_ops.dynamic_embedding_update(
        ['key1', 'key2', 'key3'],
        initial_embeddings,
        de_config,
        'emb',
        service_address=self._kbs_address)

    # A simple one layer NN model.
    #   Input data: x = [[1, 2], [3, 4]].
    #   Weights from input to logit output layer: W = [[1, 2, 3], [4, 5, 6]].
    #   Input activation at output layer i = x*W = [[9, 12, 15], [19, 26, 33]].
    # The logits output is E*i, where E are the embeddings of the output
    # keys, i.e., E = [[1, 2, 3], [4, 5, 6], [7, 8, 9]], so the logits are
    # [[78, 186, 294], [170, 404, 638]].
    #
    # With the loss defined as L = tf.reduce_sum(Logits):
    #   dL/dE = sum_by_key(i) = [[28, 38, 48], [28, 38, 48], [28, 38, 48]].
    # So the expected new embeddings are
    #   E - 0.1 * dL/dE = [[-1.8, -1.8, -1.8], [1.2, 1.2, 1.2], [4.2, 4.2, 4.2]].
    weights = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
    inputs = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    with tf.GradientTape() as tape:
        activation = tf.matmul(inputs, weights)
        logits, _labels, _keys, _mask, _weights = (
            cs_ops.compute_sampled_logits(
                [['key1', ''], ['key2', 'key3']],
                activation,
                3,
                de_config,
                'emb',
                service_address=self._kbs_address))
        loss = tf.reduce_sum(logits)

    # Applies the gradient descent.
    grads = tape.gradient(loss, weights)

    # The embeddings as updated by the knowledge bank.
    updated_embedding = de_ops.dynamic_embedding_lookup(
        ['key1', 'key2', 'key3'],
        de_config,
        'emb',
        service_address=self._kbs_address)
    expected_embedding = [[-1.8, -1.8, -1.8],
                          [1.2, 1.2, 1.2],
                          [4.2, 4.2, 4.2]]
    self.assertAllClose(updated_embedding, expected_embedding)

    # The gradient w.r.t. the weight W is calculated as
    #   dL/dW = dL/di * di/dW = sum_by_dim(E) * x =
    #   [12, 15, 18] * [[4, 4, 4], [6, 6, 6]] = [[48, 60, 72], [72, 90, 108]]
    expected_grads = [[48, 60, 72], [72, 90, 108]]
    self.assertAllClose(grads, expected_grads)
def testUpdate_2DInput(self, skip_gradient):
    """Exercises lookup/update with 2-D keys and with [N1, N2, 1] keys.

    Args:
      skip_gradient: forwarded as `skip_gradient_update` to every lookup.
    """
    initializer = self._config.knowledge_bank_config.initializer
    for component in (1, 2):
        initializer.default_embedding.value.append(component)

    def lookup_all():
        # Every lookup in this test uses the same 2-D keys and options.
        return de_ops.dynamic_embedding_lookup(
            [['first', 'second'], ['third', '']],
            self._config,
            'emb',
            service_address=self._kbs_address,
            skip_gradient_update=skip_gradient)

    # Before any update, non-empty keys get the default embedding and the
    # empty key gets zeros.
    embedding = lookup_all()
    self.assertAllClose(embedding.numpy(),
                        [[[1, 2], [1, 2]], [[1, 2], [0, 0]]])

    # The values for an empty key should be ignored.
    first_values = tf.constant([[[2.0, 4.0], [4.0, 8.0]],
                                [[8.0, 16.0], [16.0, 32.0]]])
    update_res = de_ops.dynamic_embedding_update(
        [['first', 'second'], ['third', '']],
        first_values,
        self._config,
        'emb',
        service_address=self._kbs_address)
    expected_first = [[[2, 4], [4, 8]], [[8, 16], [0, 0]]]
    self.assertAllClose(update_res.numpy(), expected_first)
    embedding = lookup_all()
    self.assertAllClose(embedding.numpy(), expected_first)

    # Allows keys' shape to be [N1, N2, 1] and values shape to be [N1, N2, D].
    second_values = tf.constant([[[3.0, 5.0], [5.0, 9.0]],
                                 [[9.0, 17.0], [17.0, 33.0]]])
    update_res = de_ops.dynamic_embedding_update(
        [[['first'], ['second']], [['third'], ['']]],
        second_values,
        self._config,
        'emb',
        service_address=self._kbs_address)
    expected_second = [[[3, 5], [5, 9]], [[9, 17], [0, 0]]]
    self.assertAllClose(update_res.numpy(), expected_second)
    embedding = lookup_all()
    self.assertAllClose(embedding.numpy(), expected_second)
def testUpdate_1DInput(self, use_kbs_address, skip_gradient):
    """Exercises lookup/update with 1-D keys, with or without an address.

    Args:
      use_kbs_address: when true, the update and the final lookup are given
        'localhost:<port>' built from `self._service_server.port()`;
        otherwise they are given an empty address string.
      skip_gradient: forwarded as `skip_gradient_update` to every lookup.
    """
    initializer = self._config.knowledge_bank_config.initializer
    for component in (1, 2):
        initializer.default_embedding.value.append(component)

    # The first lookup is issued without any address argument.
    embedding = de_ops.dynamic_embedding_lookup(
        ['first'],
        self._config,
        'emb',
        skip_gradient_update=skip_gradient)
    self.assertAllClose(embedding.numpy(), [[1, 2]])

    kbs_address = (
        'localhost:%d' % self._service_server.port()
        if use_kbs_address else '')

    new_value = tf.constant([[2.0, 4.0]])
    update_res = de_ops.dynamic_embedding_update(
        ['first'], new_value, self._config, 'emb', kbs_address)
    self.assertAllClose(update_res.numpy(), [[2, 4]])

    # The updated value must be visible through the same address.
    embedding = de_ops.dynamic_embedding_lookup(
        ['first'],
        self._config,
        'emb',
        kbs_address,
        skip_gradient_update=skip_gradient)
    self.assertAllClose(embedding.numpy(), [[2, 4]])
def testUpdate_2DInput(self, skip_gradient):
    """Exercises lookup/update with 2-D keys and no explicit address.

    Args:
      skip_gradient: forwarded as `skip_gradient_update` to every lookup.
    """
    default_embedding = (
        self._config.knowledge_bank_config.initializer.default_embedding)
    default_embedding.value.append(1)
    default_embedding.value.append(2)

    # Before any update, non-empty keys get the default embedding and the
    # empty key gets zeros.
    initial = de_ops.dynamic_embedding_lookup(
        [['first', 'second'], ['third', '']],
        self._config,
        'emb',
        skip_gradient_update=skip_gradient)
    self.assertAllClose(initial.numpy(),
                        [[[1, 2], [1, 2]], [[1, 2], [0, 0]]])

    # The values for an empty key should be ignored.
    new_values = tf.constant([[[2.0, 4.0], [4.0, 8.0]],
                              [[8.0, 16.0], [16.0, 32.0]]])
    expected = [[[2, 4], [4, 8]], [[8, 16], [0, 0]]]
    update_res = de_ops.dynamic_embedding_update(
        [['first', 'second'], ['third', '']],
        new_values,
        self._config,
        'emb')
    self.assertAllClose(update_res.numpy(), expected)

    # A fresh lookup must return the stored values.
    after_update = de_ops.dynamic_embedding_lookup(
        [['first', 'second'], ['third', '']],
        self._config,
        'emb',
        skip_gradient_update=skip_gradient)
    self.assertAllClose(after_update.numpy(), expected)
def testUpdate_1DInput(self, use_kbs_address, skip_gradient):
    """Exercises lookup/update with 1-D keys and with [N, 1] keys.

    Args:
      use_kbs_address: not read by this test body; every call uses
        `self._kbs_address`.
      skip_gradient: forwarded as `skip_gradient_update` to every lookup.
    """
    initializer = self._config.knowledge_bank_config.initializer
    for component in (1, 2):
        initializer.default_embedding.value.append(component)

    def lookup_first():
        # Every lookup in this test reads the single key 'first'.
        return de_ops.dynamic_embedding_lookup(
            ['first'],
            self._config,
            'emb',
            service_address=self._kbs_address,
            skip_gradient_update=skip_gradient)

    # Before any update the default embedding is returned.
    embedding = lookup_first()
    self.assertAllClose(embedding.numpy(), [[1, 2]])

    # Update with keys of shape [N].
    update_res = de_ops.dynamic_embedding_update(
        ['first'],
        tf.constant([[2.0, 4.0]]),
        self._config,
        'emb',
        service_address=self._kbs_address)
    self.assertAllClose(update_res.numpy(), [[2, 4]])
    embedding = lookup_first()
    self.assertAllClose(embedding.numpy(), [[2, 4]])

    # Allows keys' shape to be [N, 1] and values shape to be [N, D].
    update_res = de_ops.dynamic_embedding_update(
        [['first']],
        tf.constant([[4.0, 5.0]]),
        self._config,
        'emb',
        service_address=self._kbs_address)
    self.assertAllClose(update_res.numpy(), [[4, 5]])
    embedding = lookup_first()
    self.assertAllClose(embedding.numpy(), [[4, 5]])
def test_brute_force_topk(self):
    """Checks `top_k` with a brute-force 'DOT_PRODUCT' sampler config."""
    sampler = cs_config_builder.brute_force_topk_sampler('DOT_PRODUCT')
    cs_config = cs_config_builder.build_candidate_sampler_config(sampler)
    de_config = test_util.default_de_config(2, cs_config=cs_config)

    # Seed the knowledge bank with three embeddings.
    stored_embeddings = tf.constant([[2.0, 4.0], [4.0, 8.0], [8.0, 16.0]])
    de_ops.dynamic_embedding_update(
        ['key1', 'key2', 'key3'],
        stored_embeddings,
        de_config,
        'emb',
        service_address=self._kbs_address)

    # Query with two vectors pointing in opposite directions.
    queries = [[1.0, 2.0], [-1.0, -2.0]]
    keys, logits = cs_ops.top_k(
        queries, 3, de_config, 'emb', service_address=self._kbs_address)

    # Dot products with the first query are 10, 20, 40 for key1..key3 and are
    # negated for the second query, so the expected key order is reversed.
    expected_keys = [[b'key3', b'key2', b'key1'],
                     [b'key1', b'key2', b'key3']]
    expected_logits = [[40, 20, 10], [-10, -20, -40]]
    self.assertAllEqual(keys.numpy(), expected_keys)
    self.assertAllClose(logits.numpy(), expected_logits)
def test_save_knowledge_bank(self):
    """Checks the paths returned by `save_knowledge_bank`.

    Saves after adding one variable, again after adding a second one, and
    finally with `var_names` restricted to a single variable, asserting the
    number and the shape of the returned metadata paths each time.
    """
    # Function-scope import: only needed to escape the temp dir for regexes.
    import re

    def metadata_path_pattern(var_name):
        # FLAGS.test_tmpdir is an arbitrary filesystem path, so it has to be
        # escaped before being embedded in a regular expression; otherwise
        # regex metacharacters in the path would be read as pattern syntax.
        return (re.escape(FLAGS.test_tmpdir) +
                '/knowledge_bank_data_[0-9]+_[0-9]+_[0-9]+' +
                '/' + var_name + '/embedding_store_meta_data.pbtxt')

    # Adds an embedding with values [4, 5].
    pattern1 = metadata_path_pattern('emb1')
    de_ops.dynamic_embedding_update(
        ['first'],
        tf.constant([4.0, 5.0]),
        self._config,
        'emb1',
        service_address=self._kbs_address)
    saved_paths = io_ops.save_knowledge_bank(
        FLAGS.test_tmpdir, self._kbs_address)
    self.assertLen(saved_paths, 1)
    self.assertRegex(saved_paths[0].numpy()[0].decode(), pattern1)

    # Add another embedding data.
    pattern2 = metadata_path_pattern('emb2')
    de_ops.dynamic_embedding_update(
        ['first'],
        tf.constant([5.0, 6.0]),
        self._config,
        'emb2',
        service_address=self._kbs_address)
    saved_paths = io_ops.save_knowledge_bank(
        FLAGS.test_tmpdir, self._kbs_address)
    self.assertLen(saved_paths, 2)
    self.assertRegex(saved_paths[0].numpy()[0].decode(), pattern1)
    self.assertRegex(saved_paths[1].numpy()[0].decode(), pattern2)

    # Only save selected embedding.
    new_saved_paths = io_ops.save_knowledge_bank(
        FLAGS.test_tmpdir, self._kbs_address, var_names=['emb2'])
    self.assertLen(new_saved_paths, 1)
    self.assertRegex(new_saved_paths[0].numpy()[0].decode(), pattern2)
    # NOTE(review): saved_paths[0] is the 'emb1' path (it matched pattern1
    # above), so it already differs from this 'emb2' path by variable name.
    # Possibly saved_paths[1] was intended -- confirm before changing.
    self.assertNotEqual(new_saved_paths[0].numpy()[0],
                        saved_paths[0].numpy()[0])
def update(self, neighbor_ids, neighbor_state):
    """Writes newly computed neighbor state into the neighbor cache.

    Delegates to `de_ops.dynamic_embedding_update`, passing this object's
    config, key feature name, service address and timeout.

    Args:
      neighbor_ids: a string Tensor of shape [batch_size] representing the
        ids of a neighborhood.
      neighbor_state: a Tensor of shape [batch_size, ...] representing newly
        computed neighbor state (e.g. embeddings, logits) that should be
        stored in the neighbor cache.

    Returns:
      A `Tensor` of shape [batch_size, config.embedding_dimension].
    """
    updated_state = de_ops.dynamic_embedding_update(
        neighbor_ids,
        neighbor_state,
        self._config,
        self._key_feature_name,
        self._service_address,
        self._timeout_ms)
    return updated_state
def test_compute_sampled_logits(self):
    """Checks keys, weights, labels and logits from `compute_sampled_logits`."""
    sampler = cs_config_builder.negative_sampler(
        unique=True, algorithm='UNIFORM')
    cs_config = cs_config_builder.build_candidate_sampler_config(sampler)
    de_config = test_util.default_de_config(3, cs_config=cs_config)

    # Seed the knowledge bank with three embeddings.
    stored_embeddings = tf.constant([[1.0, 2.0, 3.0],
                                     [4.0, 5.0, 6.0],
                                     [7.0, 8.0, 9.0]])
    de_ops.dynamic_embedding_update(
        ['key1', 'key2', 'key3'],
        stored_embeddings,
        de_config,
        'emb',
        service_address=self._kbs_address)

    # Sample logits.
    logits, labels, keys, mask, weights = cs_ops.compute_sampled_logits(
        [['key1', ''], ['key2', 'key3']],
        tf.constant([[2.0, 4.0, 1], [-2.0, -4.0, 1]]),
        3,
        de_config,
        'emb',
        service_address=self._kbs_address)

    # Expected results:
    # - Example one returns one positive key {'key1'} and two negative keys
    #   {'key2', 'key3'}.
    # - Example two returns two positive keys {'key2', 'key3'} and one
    #   negative key {'key1'}.
    expected_weights = {
        b'key1': [1, 2, 3],
        b'key2': [4, 5, 6],
        b'key3': [7, 8, 9],
    }
    expected_labels = [
        {b'key1': 1, b'key2': 0, b'key3': 0},
        {b'key1': 0, b'key2': 1, b'key3': 1},
    ]
    # Logit for example one:
    # - 'key1': [2, 4, 1] * [1, 2, 3] = 13
    # - 'key2': [2, 4, 1] * [4, 5, 6] = 34
    # - 'key3': [2, 4, 1] * [7, 8, 9] = 55
    # Logit for example two:
    # - 'key1': [-2, -4, 1] * [1, 2, 3] = -7
    # - 'key2': [-2, -4, 1] * [4, 5, 6] = -22
    # - 'key3': [-2, -4, 1] * [7, 8, 9] = -37
    expected_logits = [
        {b'key1': 13, b'key2': 34, b'key3': 55},
        {b'key1': -7, b'key2': -22, b'key3': -37},
    ]

    # Convert each output once instead of on every loop iteration.
    keys_np = keys.numpy()
    mask_np = mask.numpy()
    weights_np = weights.numpy()
    labels_np = labels.numpy()
    logits_np = logits.numpy()

    # Check keys and weights.
    for example in range(2):
        self.assertEqual(1, mask_np[example])
        # Every stored key must appear among the returned keys.
        for expected_key in expected_weights:
            self.assertIn(expected_key, keys_np[example])
        # The returned key order is not fixed, so expectations are looked up
        # by the key found at each position.
        for position in range(3):
            key = keys_np[example][position]
            self.assertAllClose(expected_weights[key],
                                weights_np[example][position])
            self.assertAllClose(expected_labels[example][key],
                                labels_np[example][position])
            self.assertAllClose(expected_logits[example][key],
                                logits_np[example][position])