def pull_embedding_vectors(self, layer_name, embedding_ids):
    """Pulls and returns embedding vectors ordered by the embedding ids."""
    ps_ids = {}
    ps_ids_index = {}
    for idx, embedding_id in enumerate(embedding_ids):
        ps_id = int_to_id(embedding_id, self._ps_num)
        ps_ids.setdefault(ps_id, []).append(embedding_id)
        ps_ids_index.setdefault(ps_id, []).append(idx)

    embeddings = []
    index = []
    pb_future_and_id_pairs = []
    for ps_id, embedding_ids in ps_ids.items():
        req = elasticdl_pb2.PullEmbeddingVectorRequest()
        req.name = layer_name
        req.ids.extend(embedding_ids)
        pb_future = self._ps_stubs[ps_id].pull_embedding_vectors.future(req)
        pb_future_and_id_pairs.append((pb_future, ps_id))
    for pb_future, ps_id in pb_future_and_id_pairs:
        pb = pb_future.result()
        embeddings.append(pb_to_ndarray(pb))
        index.extend(ps_ids_index[ps_id])
    embeddings = np.concatenate(embeddings)

    # Adjust the order of embedding vectors to match the request order.
    new_embeddings = np.empty_like(embeddings)
    new_embeddings[index] = embeddings
    return new_embeddings
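The final scatter-back is the subtle part of `pull_embedding_vectors`: the gRPC responses arrive grouped per parameter server, and `index` records where each returned row belongs in the original request. Below is a self-contained sketch of that pattern in plain NumPy, with a hypothetical `int_to_id` standing in for ElasticDL's hash helper (assumed here to be simple modulo routing):

import numpy as np


def int_to_id(embedding_id, ps_num):
    # Assumed routing helper: shard integer ids across pservers by modulo.
    return embedding_id % ps_num


ids = [3, 6, 5, 10, 1, 2]
ps_num = 2

# Group request ids per pserver, remembering each id's original position.
ps_ids, ps_ids_index = {}, {}
for idx, embedding_id in enumerate(ids):
    ps_id = int_to_id(embedding_id, ps_num)
    ps_ids.setdefault(ps_id, []).append(embedding_id)
    ps_ids_index.setdefault(ps_id, []).append(idx)

# Pretend each pserver returns one row per requested id
# (here each row is just the id broadcast to a width-4 vector).
embeddings, index = [], []
for ps_id, shard_ids in ps_ids.items():
    embeddings.append(np.tile(np.array(shard_ids)[:, None], (1, 4)))
    index.extend(ps_ids_index[ps_id])
embeddings = np.concatenate(embeddings)

# Scatter the grouped rows back into the original request order.
ordered = np.empty_like(embeddings)
ordered[index] = embeddings
assert (ordered[:, 0] == np.array(ids)).all()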
def get_params_shard_from_pb(model_pb, shard_index, shard_num):
    """Get parameters, including variable values and the embedding table,
    from a model protobuf.

    Args:
        model_pb: A Model protobuf instance.
        shard_index: Model shard index.
        shard_num: The total number of model shards.

    Return:
        non_embedding_vars: A Python dict in which the key is a variable
            name and the value is a `tf.Variable` object.
        embedding_table_values: A Python dict in which the key is an
            embedding table name and the value is a tuple with 2 elements.
            The value[0] is indices and value[1] is the corresponding
            embedding vectors.
    """
    non_embedding_vars = {}
    embedding_table_values = {}

    for tensor_pb in model_pb.param:
        tensor = Tensor.from_tensor_pb(tensor_pb)
        if tensor.indices is not None:
            embedding_table_values.setdefault(tensor.name, ([], []))
            for embedding_id, vector in zip(tensor.indices, tensor.values):
                if int_to_id(embedding_id, shard_num) == shard_index:
                    embedding_table_values[tensor.name][0].append(
                        embedding_id
                    )
                    embedding_table_values[tensor.name][1].append(vector)
        else:
            if string_to_id(tensor.name, shard_num) == shard_index:
                non_embedding_vars[tensor.name] = tf.Variable(
                    initial_value=tensor.values, trainable=True
                )
    return non_embedding_vars, embedding_table_values
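Both sharding decisions above hinge on two small hash helpers: integer embedding ids route via `int_to_id`, variable names via `string_to_id`. Their exact definitions live in ElasticDL's hash utilities; the following is a plausible sketch only, assuming modulo routing for ids and a stable (process-independent) digest for names:

import hashlib


def int_to_id(number, bucket_num):
    # Assumed: embedding ids are assigned to shards by simple modulo.
    return number % bucket_num


def string_to_id(name, bucket_num):
    # Assumed: variable names are routed via a stable digest so every
    # process agrees on placement. Python's built-in hash() is randomized
    # per process, so a deterministic digest is sketched here instead.
    return int(hashlib.md5(name.encode("utf-8")).hexdigest(), 16) % bucket_num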
def test_scatter_embedding_vector(self):
    vectors = np.array([[1, 2], [3, 4], [5, 6], [1, 7], [3, 9]])
    indices = np.array([0, 1, 2, 3, 4])
    num = 2
    expected_results = {}
    for i, item_id in enumerate(indices):
        ps_id = int_to_id(item_id, num)
        if ps_id not in expected_results:
            expected_results[ps_id] = [
                np.expand_dims(vectors[i, :], axis=0),
                [item_id],
            ]
        else:
            expected_results[ps_id][0] = np.concatenate(
                (
                    expected_results[ps_id][0],
                    np.expand_dims(vectors[i, :], axis=0),
                ),
                axis=0,
            )
            expected_results[ps_id][1].append(item_id)

    results = scatter_embedding_vector(vectors, indices, num)
    for ps_id in range(num):
        np.testing.assert_array_equal(
            results[ps_id][0], expected_results[ps_id][0]
        )
        self.assertListEqual(results[ps_id][1], expected_results[ps_id][1])
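For reference, here is an implementation of `scatter_embedding_vector` consistent with what this test expects; it is a sketch under the test's assumptions, not necessarily ElasticDL's exact code. It returns a dict mapping each pserver id to a `[vectors, ids]` pair, reusing the assumed `int_to_id` helper sketched earlier:

import numpy as np


def scatter_embedding_vector(vectors, indices, ps_num):
    # Group rows of `vectors` by the pserver each id in `indices`
    # routes to; `vectors[i : i + 1]` keeps each row 2-D so the
    # per-shard arrays can be concatenated along axis 0.
    results = {}
    for i, item_id in enumerate(indices):
        ps_id = int_to_id(item_id, ps_num)
        if ps_id not in results:
            results[ps_id] = [vectors[i : i + 1], [item_id]]
        else:
            results[ps_id][0] = np.concatenate(
                (results[ps_id][0], vectors[i : i + 1]), axis=0
            )
            results[ps_id][1].append(item_id)
    return results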
def _get_params_shard_from_pb(model_pb, shard_index, shard_num):
    """Get parameters, including variable values and the embedding table,
    from a model protobuf.

    Args:
        model_pb: A Model protobuf instance.
        shard_index: Model shard index.
        shard_num: The total number of model shards.

    Return:
        non_embedding_vars: A Python dict in which the key is a variable
            name and the value is a `tf.Variable` object.
        embedding_table_values: A Python dict in which the key is an
            embedding table name and the value is a tuple with 2 elements.
            The value[0] is indices and value[1] is the corresponding
            embedding vectors.
    """
    non_embedding_vars = {}
    embedding_table_values = {}

    for name, pb in model_pb.dense_parameters.items():
        if string_to_id(name, shard_num) == shard_index:
            non_embedding_vars[name] = tf.Variable(
                initial_value=pb_to_ndarray(pb), trainable=True
            )
    for name, pb in model_pb.embedding_tables.items():
        embedding_table_values.setdefault(name, ([], []))
        t = pb_to_indexed_slices(pb)
        for embedding_id, vector in zip(t.indices, t.values):
            if int_to_id(embedding_id, shard_num) == shard_index:
                embedding_table_values[name][0].append(embedding_id)
                embedding_table_values[name][1].append(vector)
    return non_embedding_vars, embedding_table_values
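The net effect of the two loops is a disjoint partition: each dense variable lives on exactly one shard (keyed by its name) and each embedding row lives on the shard its id hashes to, so the shards together reconstruct the full model. A small self-contained sketch of that invariant with plain dicts in place of protobufs, using hypothetical toy hash helpers in the spirit of the ones sketched earlier:

import numpy as np


def _int_to_id(i, n):
    # Assumed modulo routing for integer ids, as sketched earlier.
    return i % n


def _string_to_id(name, n):
    # Toy stable string hash, standing in for the real helper.
    return sum(name.encode("utf-8")) % n


dense = {"dense/kernel": np.ones((2, 2)), "dense/bias": np.zeros(2)}
table = {i: np.full(4, float(i)) for i in range(6)}
shard_num = 2

shards = []
for shard_index in range(shard_num):
    shard_dense = {
        name: val
        for name, val in dense.items()
        if _string_to_id(name, shard_num) == shard_index
    }
    shard_rows = {
        i: vec
        for i, vec in table.items()
        if _int_to_id(i, shard_num) == shard_index
    }
    shards.append((shard_dense, shard_rows))

# Disjoint partition: every variable name and embedding id lands on
# exactly one shard.
assert sum(len(d) for d, _ in shards) == len(dense)
assert sum(len(r) for _, r in shards) == len(table)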
def test_worker_pull_embedding(self):
    model_def = "mnist_functional_api.mnist_functional_api.custom_model"
    self._create_pserver(model_def, 2)
    arguments = [
        "--worker_id",
        0,
        "--job_type",
        elasticdl_pb2.TRAINING,
        "--minibatch_size",
        self._batch_size,
        "--model_zoo",
        self._model_zoo_path,
        "--model_def",
        model_def,
        "--distribution_strategy",
        DistributionStrategy.PARAMETER_SERVER,
    ]
    args = parse_worker_args(arguments)
    worker = Worker(args, ps_channels=self._channels)

    # Test looking up embedding vectors that do not exist yet.
    layers = ["test-2", "test-2-slot"]
    ids = [3, 5, 1, 6, 10, 2, 1, 2, 4, 7, 9]
    embedding_table_args = [
        (layers[0], 8, "uniform", False),
        (layers[1], 8, 3.3, True),
    ]

    # Initialize the embedding table objects on each pserver.
    for pserver in self._pservers:
        for layer, table_args in zip(layers, embedding_table_args):
            pserver.parameters.embedding_params[layer] = EmbeddingTable(
                *table_args
            )

    result_dict = {}
    for layer in layers:
        embedding = worker.pull_embedding_vectors(layer, ids)
        result_dict[layer] = embedding

    for layer in layers:
        expected_result = []
        for embedding_id in ids:
            ps_id = int_to_id(embedding_id, len(self._pservers))
            table = self._pservers[ps_id].parameters.embedding_params[layer]
            expected_result.append(table.get([embedding_id]))
        expected_result = np.concatenate(expected_result)
        self.assertTrue(np.allclose(expected_result, result_dict[layer]))
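This test relies on `EmbeddingTable.get` creating vectors lazily on first lookup: the worker's pull creates them, so the expected-result loop reads back the very same values and `np.allclose` passes even with a random initializer. A minimal hypothetical stand-in capturing that behavior, assuming non-slot tables draw from the named TF initializer and slot tables fill with the given constant (as the `("uniform", False)` / `(3.3, True)` arguments above suggest):

import numpy as np
import tensorflow as tf


class LazyEmbeddingTable:
    # Illustrative stand-in for ElasticDL's EmbeddingTable, not its code.
    def __init__(self, name, dim, initializer, is_slot=False):
        self.name = name
        self.dim = dim
        if is_slot:
            # Slot tables (e.g. optimizer state) start at a constant.
            self._init = tf.constant_initializer(float(initializer))
        else:
            self._init = tf.keras.initializers.get(initializer)
        self._vectors = {}

    def get(self, indices):
        # Lazily create a vector the first time an id is looked up,
        # then always return the stored value for that id.
        for i in indices:
            if i not in self._vectors:
                self._vectors[i] = self._init(shape=(self.dim,)).numpy()
        return np.stack([self._vectors[i] for i in indices])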