Example 1
0
def get_params_shard_from_pb(model_pb, shard_index, shard_num):
    """Extract the slice of a model's parameters owned by one shard.

    Args:
        model_pb: A Model protobuf instance.
        shard_index: Model shard index.
        shard_num: The total number of model shards.

    Return:
        non_embedding_vars: A Python dict mapping a variable name to a
            `tf.Variable` holding that variable's values.
        embedding_table_values: A Python dict mapping an embedding table
            name to a 2-element tuple of parallel lists:
            element 0 holds indices, element 1 the matching vectors.
    """
    non_embedding_vars = {}
    embedding_table_values = {}

    for tensor_pb in model_pb.param:
        tensor = Tensor.from_tensor_pb(tensor_pb)
        if tensor.indices is None:
            # Dense variable: kept only when its name hashes to this shard.
            if string_to_id(tensor.name, shard_num) == shard_index:
                non_embedding_vars[tensor.name] = tf.Variable(
                    initial_value=tensor.values, trainable=True)
            continue
        # Sparse embedding table: the entry is created for every table seen,
        # but only rows whose id maps to this shard are collected.
        ids, vectors = embedding_table_values.setdefault(tensor.name, ([], []))
        for embedding_id, vector in zip(tensor.indices, tensor.values):
            if int_to_id(embedding_id, shard_num) == shard_index:
                ids.append(embedding_id)
                vectors.append(vector)
    return non_embedding_vars, embedding_table_values
Example 2
0
def _get_params_shard_from_pb(model_pb, shard_index, shard_num):
    """Extract the slice of a model's parameters owned by one shard.

    Args:
        model_pb: A Model protobuf instance.
        shard_index: Model shard index.
        shard_num: The total number of model shards.

    Return:
        non_embedding_vars: A Python dict mapping a variable name to a
            `tf.Variable` holding that variable's values.
        embedding_table_values: A Python dict mapping an embedding table
            name to a 2-element tuple of parallel lists:
            element 0 holds indices, element 1 the matching vectors.
    """
    non_embedding_vars = {}
    embedding_table_values = {}

    # Dense parameters: keep only the names that hash to this shard.
    for name, pb in model_pb.dense_parameters.items():
        if string_to_id(name, shard_num) != shard_index:
            continue
        non_embedding_vars[name] = tf.Variable(
            initial_value=pb_to_ndarray(pb), trainable=True)

    # Embedding tables: an (indices, vectors) entry is created for every
    # table, but only rows whose id maps to this shard are collected.
    for name, pb in model_pb.embedding_tables.items():
        ids, vectors = embedding_table_values.setdefault(name, ([], []))
        slices = pb_to_indexed_slices(pb)
        for embedding_id, vector in zip(slices.indices, slices.values):
            if int_to_id(embedding_id, shard_num) == shard_index:
                ids.append(embedding_id)
                vectors.append(vector)
    return non_embedding_vars, embedding_table_values
Example 3
0
 def init_ps_var_partition(self):
     """Group the non-embedding variables by their assigned PS id.

     Each variable's PS id is cached in ``self._var_to_ps`` (computed
     from the variable name via ``string_to_id`` on first sight) and the
     resulting {ps_id: [variables]} mapping is stored in ``self._ps_vars``.
     """
     partition = {}
     for var in self._non_embed_vars.values():
         # Assign a PS to this variable once; reuse the cached id after.
         if var.name not in self._var_to_ps:
             self._var_to_ps[var.name] = string_to_id(var.name, self._ps_num)
         partition.setdefault(self._var_to_ps[var.name], []).append(var)
     self._ps_vars = partition
Example 4
0
 def partition_dense_parameters(self, param_names):
     """
     Partition dense parameters to PS
     ps_id = string_to_id(param_name)

     Names already present in ``self.parameter_to_ps`` keep their prior
     assignment and are deliberately not re-added to
     ``self.ps_to_parameter``, so repeated calls cannot create duplicates.
     """
     for name in param_names:
         if name in self.parameter_to_ps:
             continue  # already partitioned on an earlier call
         ps_id = string_to_id(name, self.ps_num)
         self.parameter_to_ps[name] = ps_id
         if ps_id in self.ps_to_parameter:
             self.ps_to_parameter[ps_id].append(name)
         else:
             self.ps_to_parameter[ps_id] = [name]
Example 5
0
    def test_compare_onebatch_train(self):
        """Train one batch through the PS-backed Worker and again with a
        plain local model, then assert that the parameters stored on the
        parameter servers equal the locally updated variables.
        """
        model_def = "mnist_functional_api.mnist_functional_api.custom_model"
        self._create_pserver(model_def, 2)
        images, labels = get_random_batch(self._batch_size)
        # TODO(yunjian.lmh): test optimizer wrapper
        arguments = [
            "--worker_id",
            0,
            "--job_type",
            elasticdl_pb2.TRAINING,
            "--minibatch_size",
            self._batch_size,
            "--model_zoo",
            self._model_zoo_path,
            "--model_def",
            model_def,
            "--distribution_strategy",
            DistributionStrategy.PARAMETER_SERVER,
        ]
        args = parse_worker_args(arguments)

        # Fix the seed so both training runs start from identical
        # initial weights.
        tf.keras.backend.clear_session()
        tf.random.set_seed(22)

        # Run 1: train a single batch through the Worker, which pulls the
        # model from the pservers and reports the gradients back to them.
        worker = Worker(args, ps_channels=self._channels)
        worker._run_model_call_before_training(images)
        worker.get_model()
        w_loss, w_grads = worker.training_process_eagerly(images, labels)
        worker.report_gradient(w_grads)

        # Reset and reseed so the local model is initialized identically.
        tf.keras.backend.clear_session()
        tf.random.set_seed(22)

        # Run 2: build the same model locally from the model zoo spec.
        (
            model,
            dataset_fn,
            loss_fn,
            opt_fn,
            eval_metrics_fn,
            prediction_outputs_processor,
            create_data_reader_fn,
            callback_list,
        ) = get_model_spec(
            model_zoo=self._model_zoo_path,
            model_def=model_def,
            dataset_fn="dataset_fn",
            model_params=None,
            loss="loss",
            optimizer="optimizer",
            eval_metrics_fn="eval_metrics_fn",
            prediction_outputs_processor="PredictionOutputsProcessor",
            custom_data_reader="custom_data_reader",
            callbacks="callbacks",
        )

        # Apply one optimizer step on the same batch, entirely locally.
        with tf.GradientTape() as tape:
            output = model.call(images, training=True)
            labels = tf.reshape(labels, [-1])
            loss = loss_fn(labels, output)
        grads = tape.gradient(loss, model.trainable_variables)
        opt_fn().apply_gradients(zip(grads, model.trainable_variables))

        # Each variable lives on the pserver chosen by string_to_id; the
        # lookup here must use the same hashing as the partitioning code.
        for v in model.trainable_variables:
            ps_id = string_to_id(v.name, len(self._channels))
            ps_v = self._pservers[ps_id].parameters.get_non_embedding_param(
                v.name)
            np.testing.assert_array_equal(ps_v.numpy(), v.numpy())