def get_params_shard_from_pb(model_pb, shard_index, shard_num):
    """Get parameters, including variable values and embedding tables,
    from a model protobuf.

    Args:
        model_pb: A Model protobuf instance.
        shard_index: Model shard index.
        shard_num: The total number of model shards.

    Returns:
        non_embedding_vars: A Python dict in which the key is a variable
            name and the value is a `tf.Variable` object.
        embedding_table_values: A Python dict in which the key is an
            embedding table name and the value is a tuple with 2 elements.
            value[0] is the indices and value[1] is the corresponding
            embedding vectors.
    """
    non_embedding_vars = {}
    embedding_table_values = {}
    for tensor_pb in model_pb.param:
        tensor = Tensor.from_tensor_pb(tensor_pb)
        if tensor.indices is not None:
            embedding_table_values.setdefault(tensor.name, ([], []))
            for embedding_id, vector in zip(tensor.indices, tensor.values):
                if int_to_id(embedding_id, shard_num) == shard_index:
                    embedding_table_values[tensor.name][0].append(embedding_id)
                    embedding_table_values[tensor.name][1].append(vector)
        else:
            if string_to_id(tensor.name, shard_num) == shard_index:
                non_embedding_vars[tensor.name] = tf.Variable(
                    initial_value=tensor.values, trainable=True
                )
    return non_embedding_vars, embedding_table_values
def _get_params_shard_from_pb(model_pb, shard_index, shard_num):
    """Get parameters, including variable values and embedding tables,
    from a model protobuf.

    Args:
        model_pb: A Model protobuf instance.
        shard_index: Model shard index.
        shard_num: The total number of model shards.

    Returns:
        non_embedding_vars: A Python dict in which the key is a variable
            name and the value is a `tf.Variable` object.
        embedding_table_values: A Python dict in which the key is an
            embedding table name and the value is a tuple with 2 elements.
            value[0] is the indices and value[1] is the corresponding
            embedding vectors.
    """
    non_embedding_vars = {}
    embedding_table_values = {}
    for name, pb in model_pb.dense_parameters.items():
        if string_to_id(name, shard_num) == shard_index:
            non_embedding_vars[name] = tf.Variable(
                initial_value=pb_to_ndarray(pb), trainable=True
            )
    for name, pb in model_pb.embedding_tables.items():
        embedding_table_values.setdefault(name, ([], []))
        t = pb_to_indexed_slices(pb)
        for embedding_id, vector in zip(t.indices, t.values):
            if int_to_id(embedding_id, shard_num) == shard_index:
                embedding_table_values[name][0].append(embedding_id)
                embedding_table_values[name][1].append(vector)
    return non_embedding_vars, embedding_table_values
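# Both versions above shard parameters by hashing: string_to_id() places a
# dense variable by its name, and int_to_id() places each embedding row by
# its integer id. The real helpers live elsewhere in the repo; the sketch
# below is only a hypothetical illustration, assuming a stable hash modulo
# the shard count so every process computes the same placement.
import hashlib


def string_to_id(name, shard_num):
    # Hypothetical sketch: stable digest of the parameter name, then modulo.
    digest = hashlib.md5(name.encode("utf-8")).hexdigest()
    return int(digest, 16) % shard_num


def int_to_id(embedding_id, shard_num):
    # Hypothetical sketch: embedding ids are already integers, so a plain
    # modulo spreads rows across shards.
    return int(embedding_id) % shard_num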
def init_ps_var_partition(self):
    """Partition non-embedding variables to PS instances by hashing
    the variable name."""
    ps_vars = {}
    for v in self._non_embed_vars.values():
        if v.name not in self._var_to_ps:
            self._var_to_ps[v.name] = string_to_id(v.name, self._ps_num)
        ps_id = self._var_to_ps[v.name]
        if ps_id not in ps_vars:
            ps_vars[ps_id] = [v]
        else:
            ps_vars[ps_id].append(v)
    self._ps_vars = ps_vars
def partition_dense_parameters(self, param_names):
    """Partition dense parameters to PS instances by
    ps_id = string_to_id(param_name).
    """
    for name in param_names:
        if name not in self.parameter_to_ps:
            self.parameter_to_ps[name] = string_to_id(name, self.ps_num)
            ps_id = self.parameter_to_ps[name]
            if ps_id not in self.ps_to_parameter:
                self.ps_to_parameter[ps_id] = [name]
            else:
                self.ps_to_parameter[ps_id].append(name)
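# A standalone walk-through of the same name-to-PS hashing used by
# init_ps_var_partition() and partition_dense_parameters(), reusing the
# string_to_id sketch above. The parameter names and ps_num value here are
# made up for illustration; in the repo these mappings live on the PS client.
param_names = ["dense/kernel:0", "dense/bias:0", "dense_1/kernel:0"]
parameter_to_ps = {}
ps_to_parameter = {}
for name in param_names:
    ps_id = string_to_id(name, 2)
    parameter_to_ps[name] = ps_id
    ps_to_parameter.setdefault(ps_id, []).append(name)

# parameter_to_ps maps each dense parameter name to one PS id;
# ps_to_parameter is the inverse grouping used to batch requests per PS.
print(parameter_to_ps)
print(ps_to_parameter)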
def test_compare_onebatch_train(self):
    model_def = "mnist_functional_api.mnist_functional_api.custom_model"
    self._create_pserver(model_def, 2)
    images, labels = get_random_batch(self._batch_size)
    # TODO(yunjian.lmh): test optimizer wrapper
    arguments = [
        "--worker_id",
        0,
        "--job_type",
        elasticdl_pb2.TRAINING,
        "--minibatch_size",
        self._batch_size,
        "--model_zoo",
        self._model_zoo_path,
        "--model_def",
        model_def,
        "--distribution_strategy",
        DistributionStrategy.PARAMETER_SERVER,
    ]
    args = parse_worker_args(arguments)

    tf.keras.backend.clear_session()
    tf.random.set_seed(22)
    worker = Worker(args, ps_channels=self._channels)
    worker._run_model_call_before_training(images)
    worker.get_model()
    w_loss, w_grads = worker.training_process_eagerly(images, labels)
    worker.report_gradient(w_grads)

    tf.keras.backend.clear_session()
    tf.random.set_seed(22)
    (
        model,
        dataset_fn,
        loss_fn,
        opt_fn,
        eval_metrics_fn,
        prediction_outputs_processor,
        create_data_reader_fn,
        callback_list,
    ) = get_model_spec(
        model_zoo=self._model_zoo_path,
        model_def=model_def,
        dataset_fn="dataset_fn",
        model_params=None,
        loss="loss",
        optimizer="optimizer",
        eval_metrics_fn="eval_metrics_fn",
        prediction_outputs_processor="PredictionOutputsProcessor",
        custom_data_reader="custom_data_reader",
        callbacks="callbacks",
    )

    with tf.GradientTape() as tape:
        output = model.call(images, training=True)
        labels = tf.reshape(labels, [-1])
        loss = loss_fn(labels, output)
    grads = tape.gradient(loss, model.trainable_variables)
    opt_fn().apply_gradients(zip(grads, model.trainable_variables))

    for v in model.trainable_variables:
        ps_id = string_to_id(v.name, len(self._channels))
        ps_v = self._pservers[ps_id].parameters.get_non_embedding_param(
            v.name
        )
        np.testing.assert_array_equal(ps_v.numpy(), v.numpy())