Example #1
    def _update_model(self):
        assert self._lock.locked()
        grad_var = []

        # (grad, var) pairs, excluding the Keras Embedding layer and the
        # ElasticDL Embedding layer
        for k in self._gradient_sum:
            self._gradient_sum[k] = self._gradient_sum[k] / self._grad_to_wait
            grad_var.append((self._gradient_sum[k], self._model[k]))

        # (grad, var) pair of Keras Embedding layer
        for k in self._gradient_sum_indexed:
            grad_var.append((self._gradient_sum_indexed[k], self._model[k]))

        # (grad, var) pair of ElasticDL Embedding layer
        edl_embedding_offset = len(grad_var)
        unique_ids_list = []
        if self._edl_embedding_gradients:
            for layer_name, grads in self._edl_embedding_gradients.items():
                unique_ids, idx = tf.unique(grads.indices)
                unique_ids_list.append(unique_ids)
                grads_idx_transformed = tf.IndexedSlices(grads.values, idx)
                keys = [
                    Embedding.get_key([layer_name, i])
                    for i in unique_ids.numpy()
                ]
                embeddings, unknown_keys = EmbeddingService.lookup_embedding(
                    embedding_service_endpoint=(
                        self._embedding_service_endpoint),
                    keys=keys,
                )
                if unknown_keys:
                    raise RuntimeError(
                        "Master reviced %d unknown embedding keys: %s ..." %
                        (len(unknown_keys), str(unknown_keys[0])))
                if not embeddings:
                    continue
                embeddings = np.concatenate(embeddings,
                                            axis=0).reshape(len(keys), -1)
                embedding_var = tf.Variable(embeddings)
                grad_var.append((grads_idx_transformed, embedding_var))

        # TODO: support optimizer with slots such as Adam, FTRL
        self._opt.apply_gradients(grad_var)

        # report updated embedding table to EmbeddingService
        self._update_edl_embedding_table(
            zip(
                self._edl_embedding_gradients.keys(),
                unique_ids_list,
                [v for g, v in grad_var[edl_embedding_offset:]],
            ))
        self._update_model_version()
        self._gradient_sum.clear()
        self._gradient_sum_indexed.clear()
        self._edl_embedding_gradients.clear()
        self._grad_n = 0
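
The dedup step above is the core trick. Below is a minimal, self-contained sketch (toy shapes, with plain SGD standing in for the snippet's optimizer): tf.unique collapses repeated ids, and the re-indexed tf.IndexedSlices lets apply_gradients update a dense variable that holds only the touched rows.

import numpy as np
import tensorflow as tf

# Four gradient rows for ids [5, 7, 5, 9]; id 5 appears twice.
grads = tf.IndexedSlices(
    values=tf.ones([4, 3]),
    indices=tf.constant([5, 7, 5, 9]),
)
# unique_ids is [5, 7, 9]; idx maps each row to its compact position.
unique_ids, idx = tf.unique(grads.indices)
grads_idx_transformed = tf.IndexedSlices(grads.values, idx)

# Stand-in for the rows fetched from the embedding store.
embedding_var = tf.Variable(np.zeros((3, 3), dtype=np.float32))
tf.keras.optimizers.SGD(learning_rate=0.1).apply_gradients(
    [(grads_idx_transformed, embedding_var)]
)
print(embedding_var.numpy())  # row 0 (id 5) accumulated two gradient rows
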
Example #2
    def _report_to_kv_store(self):
        """Report updated embedding vectors and slots to kv store."""
        keys = []
        values = []
        for layer, ids in self._unique_ids_all_layers.items():
            value = self._get_embedding_variable(layer).numpy()
            for id, v in zip(ids, value):
                keys.append(Embedding.get_key([layer, id]))
                values.append(v)

            for slot in self._allowed_slot_names:
                value = self._get_slot_variable(layer, slot).numpy()
                for id, v in zip(ids, value):
                    keys.append(Embedding.get_key([layer, slot, id]))
                    values.append(v)

        EmbeddingService.update_embedding(keys, values,
                                          self._kv_store_endpoint)
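
For reference, a toy sketch of the flat key/value layout this builds, assuming Embedding.get_key simply joins its parts with "-" (an assumption about that helper, not something the snippet shows):

layer, ids = "item_emb", [3, 8]
slots = ["momentum"]
keys = []
for id in ids:  # embedding rows first
    keys.append("-".join([layer, str(id)]))
for slot in slots:  # then the rows of each slot
    for id in ids:
        keys.append("-".join([layer, slot, str(id)]))
print(keys)
# ['item_emb-3', 'item_emb-8', 'item_emb-momentum-3', 'item_emb-momentum-8']
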
Example #3
    def test_embedding_service(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            embedding_endpoint = start_redis_instances(temp_dir)
            # start
            embedding_service = EmbeddingService(embedding_endpoint)
            embedding_endpoint = embedding_service._create_redis_cluster()
            # wait for the cluster to be up and running
            time.sleep(1)
            self.assertIsNotNone(embedding_endpoint)
            # connection
            redis_cluster = embedding_service._get_embedding_cluster()
            self.assertIsNotNone(redis_cluster)
            # set value to a key
            self.assertTrue(redis_cluster.set("test_key", "OK", nx=True))
            # set value to a key that already exists
            self.assertIsNone(redis_cluster.set("test_key", "OK", nx=True))
            self.assertEqual(b"OK", redis_cluster.get("test_key"))
            # close
            self.assertTrue(embedding_service.stop_embedding_service())
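
The nx=True behavior this test exercises is plain Redis SET with the NX flag: the write succeeds only when the key is absent. A standalone sketch with the redis-py client, assuming a local server on the default port:

import redis

r = redis.Redis(host="localhost", port=6379)
r.delete("test_key")
print(r.set("test_key", "OK", nx=True))  # True: key was absent
print(r.set("test_key", "NO", nx=True))  # None: key already exists
print(r.get("test_key"))                 # b'OK': the first value is kept
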
Example #4
    def _update_edl_embedding_table(self, name_var_list):
        """
            Put updated embedding vectors' ids and values together
            and use EmbeddingService.update_embedding() to update
            embedding table in the distributed storage
        """
        keys = []
        embeddings = []
        for layer_name, unique_ids, embedding_var in name_var_list:
            keys.extend([
                Embedding.get_key([layer_name, i]) for i in unique_ids.numpy()
            ])
            embeddings.extend([i for i in embedding_var.numpy()])

        if embeddings:
            EmbeddingService.update_embedding(
                keys=keys,
                embedding_vectors=embeddings,
                embedding_service_endpoint=self._embedding_service_endpoint,
            )
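
The (layer_name, unique_ids, embedding_var) triples consumed here mirror the zip(...) at the end of Example #1. A toy illustration of their shapes, with made-up values:

import numpy as np
import tensorflow as tf

name_var_list = zip(
    ["item_emb"],                                 # layer names
    [tf.constant([3, 8])],                        # unique ids per layer
    [tf.Variable(np.zeros((2, 4), np.float32))],  # one updated row per id
)
for layer_name, unique_ids, embedding_var in name_var_list:
    for i, row in zip(unique_ids.numpy(), embedding_var.numpy()):
        print("%s-%d" % (layer_name, i), row.tolist())
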
Example #5
    def lookup_embedding(self,
                         ids,
                         layer_name,
                         initializer="uniform",
                         embedding_table_dim=128):
        keys = [Embedding.get_key([layer_name, id]) for id in ids]
        ES_lookup_embedding = EmbeddingService.lookup_embedding
        embedding_vectors, unknown_keys_index = ES_lookup_embedding(
            keys=keys,
            embedding_service_endpoint=self._embedding_service_endpoint,
        )
        if unknown_keys_index:
            # Initialize unknown_keys' embedding vectors and write into Redis.
            unknown_keys = [keys[index] for index in unknown_keys_index]
            initializer = tf.keras.initializers.get(initializer)
            embedding_vector_init = [
                initializer(shape=[1, embedding_table_dim]).numpy()
                for _ in unknown_keys
            ]
            embedding_vector_init = np.concatenate(embedding_vector_init,
                                                   axis=0)
            EmbeddingService.update_embedding(
                keys=unknown_keys,
                embedding_vectors=embedding_vector_init,
                embedding_service_endpoint=self._embedding_service_endpoint,
                set_if_not_exist=True,
            )
            # Lookup unknown_keys' embedding vectors
            embedding_vectors_new, unknown_keys_idx_new = ES_lookup_embedding(
                keys=unknown_keys,
                embedding_service_endpoint=self._embedding_service_endpoint,
            )
            if unknown_keys_idx_new:
                raise Exception("Update embedding vector: %s failed." % str(
                    [unknown_keys[index] for index in unknown_keys_idx_new]))
            for key_index, vector in zip(unknown_keys_index,
                                         embedding_vectors_new):
                embedding_vectors[key_index] = vector
        embedding_vectors = np.concatenate(embedding_vectors, axis=0)
        return embedding_vectors.reshape((len(keys), embedding_table_dim))
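
The read-initialize-write-reread sequence above guards against races between workers: set_if_not_exist makes the first writer win, and the second lookup returns whatever actually landed in the store. A storage-agnostic sketch of the same pattern, with a plain dict standing in for the kv store:

import numpy as np

store = {"a": np.zeros(4)}  # pretend kv store; key "b" is unknown

def lookup(keys):
    values = [store.get(k) for k in keys]
    unknown = [i for i, v in enumerate(values) if v is None]
    return values, unknown

keys = ["a", "b"]
values, unknown = lookup(keys)
for i in unknown:
    # set-if-not-exist: only the first initializer wins the race
    store.setdefault(keys[i], np.random.uniform(size=4))
values, unknown = lookup(keys)  # re-read to get the winning values
assert not unknown
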
Example #6
    def _lookup_embeddings_and_slots(self, grads_and_vars):
        """Look up embedding vectors and slot values form kv store.

        This function looks up embedding vectors and slot values.
        It initializes unknown slots if they exist.

        Arguments:
            grads_and_vars: A list of (gradient, layer name) pairs.

        Returns:
            A tuple of (`embedding_values`, `slot_values`). `embedding_values`
            is a python dictionary of {layer name: `embedding_vectors`} where
            `embedding_vectors` is a 2D `numpy.ndarray`. `slot_values` is a
            python dictionary of {layer name: {slot name: `slot_values`}}
            where `slot_values` is a 2D `numpy.ndarray`.

        Raises:
            RuntimeError: If any unknown embedding key exists.
        """

        arr = self._generate_lookup_keys(grads_and_vars)
        embed_keys, slot_keys, embed_key_index, slot_key_index = arr

        keys = embed_keys + slot_keys
        embed_keys_num = len(embed_keys)
        values, unknown_keys = EmbeddingService.lookup_embedding(
            keys=keys, embedding_service_endpoint=self._kv_store_endpoint)

        if unknown_keys:
            # raise Error if an unknown embedding key exists
            if unknown_keys[0] < embed_keys_num:
                raise RuntimeError("Failed to get key %s from kv store." %
                                   embed_keys[unknown_keys[0]])

            # initialize unknown slots
            for idx in unknown_keys:
                key = keys[idx]
                layer_name = _get_embedding_layer_name_from_key(key)
                slot_name = _get_slot_name_from_key(key)
                values[idx] = self._initialize_unknown_slot(
                    layer_name, slot_name)

        embed_values = _parse_lookup_values(values[:embed_keys_num],
                                            embed_key_index)
        slot_values = _parse_lookup_values(values[embed_keys_num:],
                                           slot_key_index)
        return embed_values, slot_values
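
The unknown-key check works because the lookup batches embedding keys first and slot keys second, so any unknown index below embed_keys_num must be a missing embedding vector. A toy restatement of that index arithmetic (key strings are illustrative):

embed_keys = ["layer1-0", "layer1-3"]
slot_keys = ["layer1-momentum-0", "layer1-momentum-3"]
keys = embed_keys + slot_keys
embed_keys_num = len(embed_keys)

for idx in [2, 3]:  # pretend both slot keys came back unknown
    if idx < embed_keys_num:
        print("fatal: missing embedding key", keys[idx])
    else:
        print("initialize slot value for", keys[idx])
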
Example #7
    def test_lookup_and_update_embedding(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            embedding_endpoint = start_redis_instances(temp_dir)
            # start
            embedding_service = EmbeddingService(embedding_endpoint)
            embedding_endpoint = embedding_service._create_redis_cluster()
            # wait for cluster up-running
            time.sleep(1)
            origin_data = np.random.rand(100, 10).astype(np.float32)
            keys = ["test_%d" % i for i in range(origin_data.shape[0])]

            EmbeddingService.update_embedding(keys, origin_data,
                                              embedding_endpoint)
            lookup_data, unknown_keys_idx = EmbeddingService.lookup_embedding(
                keys, embedding_endpoint, parse_type=np.float32)
            self.assertTrue(len(unknown_keys_idx) == 0)
            output_length = len(keys)
            lookup_data = np.concatenate(lookup_data, axis=0)
            lookup_data = lookup_data.reshape((output_length, -1))
            self.assertTrue(np.equal(origin_data, lookup_data).all())

            # Test set_if_not_exist
            origin_data_2 = np.random.rand(100, 10).astype(np.float32)
            self.assertFalse(np.equal(origin_data, origin_data_2).all())
            EmbeddingService.update_embedding(keys,
                                              origin_data_2,
                                              embedding_endpoint,
                                              set_if_not_exist=True)
            lookup_data, unknown_keys_idx = EmbeddingService.lookup_embedding(
                keys, embedding_endpoint, parse_type=np.float32)
            lookup_data = np.concatenate(lookup_data, axis=0)
            lookup_data = lookup_data.reshape((output_length, -1))
            self.assertTrue(np.equal(origin_data, lookup_data).all())
            self.assertFalse(np.equal(origin_data_2, lookup_data).all())

            # Test non-exist keys
            keys_do_not_exist = ["test_no_exist_%d" % i for i in range(10)]
            lookup_data, unknown_keys_idx = EmbeddingService.lookup_embedding(
                keys_do_not_exist, embedding_endpoint, parse_type=np.float32)
            self.assertTrue(len(unknown_keys_idx) == 10)
            self.assertTrue(len(lookup_data) == 10)
            # Close
            self.assertTrue(embedding_service.stop_embedding_service())
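
The concatenate-then-reshape pair appears because the lookup returns one small array per key; stacking them rebuilds the original matrix. A minimal check of that round trip:

import numpy as np

origin = np.random.rand(100, 10).astype(np.float32)
rows = [row.reshape(1, -1) for row in origin]  # shape of a lookup result
rebuilt = np.concatenate(rows, axis=0).reshape((len(rows), -1))
assert np.equal(origin, rebuilt).all()
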
Example #8
def main():
    args = parse_args()
    logger = get_logger("master", level=args.log_level.upper())

    # Master addr
    master_ip = os.getenv("MY_POD_IP", "localhost")
    master_addr = "%s:%d" % (master_ip, args.port)

    # Start tensorboard service if required
    if args.tensorboard_log_dir:
        logger.info(
            "Starting tensorboard service with log directory %s",
            args.tensorboard_log_dir,
        )
        tb_service = TensorboardService(args.tensorboard_log_dir, master_ip)
        tb_service.start()
    else:
        tb_service = None

    # Start task queue
    logger.debug(
        "Starting task queue with training data directory %s, "
        "evaluation data directory %s, "
        "and prediction data directory %s",
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
    )
    task_d = _make_task_dispatcher(
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
        args.records_per_task,
        args.num_epochs,
    )
    model_module = load_module(
        get_module_file_path(args.model_zoo, args.model_def)).__dict__
    model_inst = load_model_from_module(args.model_def, model_module,
                                        args.model_params)
    optimizer = model_module[args.optimizer]()

    if all((
            args.training_data_dir,
            args.evaluation_data_dir,
            args.evaluation_throttle_secs or args.evaluation_steps,
    )):
        job_type = JobType.TRAINING_WITH_EVALUATION
    elif all((
            args.evaluation_data_dir,
            not args.training_data_dir,
            not args.prediction_data_dir,
    )):
        job_type = JobType.EVALUATION_ONLY
    elif all((
            args.prediction_data_dir,
            not args.evaluation_data_dir,
            not args.training_data_dir,
    )):
        job_type = JobType.PREDICTION_ONLY
    else:
        job_type = JobType.TRAINING_ONLY

    # Initialize checkpoint service
    if args.checkpoint_steps or job_type == JobType.TRAINING_WITH_EVALUATION:
        logger.info("Starting checkpoint service")
        checkpoint_service = CheckpointService(
            args.checkpoint_dir,
            args.checkpoint_steps,
            args.keep_checkpoint_max,
            job_type == JobType.TRAINING_WITH_EVALUATION,
        )
    else:
        checkpoint_service = None

    # Initialize evaluation service
    evaluation_service = None
    if (job_type == JobType.TRAINING_WITH_EVALUATION
            or job_type == JobType.EVALUATION_ONLY):
        logger.info(
            "Starting evaluation service with throttle seconds %d "
            " and evaluation steps %d",
            args.evaluation_throttle_secs,
            args.evaluation_steps,
        )
        evaluation_service = EvaluationService(
            checkpoint_service,
            tb_service,
            task_d,
            args.evaluation_start_delay_secs,
            args.evaluation_throttle_secs,
            args.evaluation_steps,
            job_type == JobType.EVALUATION_ONLY,
        )
        evaluation_service.start()
        task_d.set_evaluation_service(evaluation_service)

    embedding_service_endpoint = None
    # Search for embedding layers in the model,
    # if found, initialize embedding service
    layers = find_layer(model_inst, Embedding)
    if layers:
        embedding_service = EmbeddingService()
        embedding_service_endpoint = embedding_service.start_embedding_service(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
            resource_request=args.master_resource_request,
            resource_limit=args.master_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
        )
        logger.info("Embedding service start succeeded. The endpoint is %s." %
                    str(embedding_service_endpoint))

    # The master service
    logger.info("Starting master service")
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=64),
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
    master_servicer = MasterServicer(
        args.grads_to_wait,
        args.minibatch_size,
        optimizer,
        task_d,
        init_var=model_inst.trainable_variables if model_inst.built else [],
        checkpoint_filename_for_init=args.checkpoint_filename_for_init,
        checkpoint_service=checkpoint_service,
        evaluation_service=evaluation_service,
        embedding_service_endpoint=embedding_service_endpoint,
    )
    elasticdl_pb2_grpc.add_MasterServicer_to_server(master_servicer, server)
    server.add_insecure_port("[::]:{}".format(args.port))
    server.start()
    logger.info("Server started at port: %d", args.port)

    worker_manager = None
    if args.num_workers:
        assert args.worker_image, "Worker image cannot be empty"

        worker_command = ["python"]
        worker_args = [
            "-m",
            "elasticdl.python.worker.main",
            "--model_zoo",
            args.model_zoo,
            "--master_addr",
            master_addr,
            "--log_level",
            args.log_level,
            "--dataset_fn",
            args.dataset_fn,
            "--loss",
            args.loss,
            "--optimizer",
            args.optimizer,
            "--eval_metrics_fn",
            args.eval_metrics_fn,
            "--model_def",
            args.model_def,
            "--job_type",
            job_type,
            "--minibatch_size",
            str(args.minibatch_size),
            "--embedding_service_endpoint",
            str(embedding_service_endpoint),
        ]

        logger.info(">>> master pod envs argument is %s" % args.envs)
        env_dict = parse_envs(args.envs)
        env = []
        for key in env_dict:
            env.append(V1EnvVar(name=key, value=env_dict[key]))

        worker_manager = WorkerManager(
            task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            command=worker_command,
            args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env,
        )
        worker_manager.update_status(WorkerManagerStatus.PENDING)
        logger.info("Launching %d workers", args.num_workers)
        worker_manager.start_workers()
        worker_manager.update_status(WorkerManagerStatus.RUNNING)

        if tb_service:
            worker_manager.start_tensorboard_service()

    try:
        while True:
            if task_d.finished():
                if worker_manager:
                    worker_manager.update_status(WorkerManagerStatus.FINISHED)
                if args.output:
                    master_servicer.save_latest_checkpoint(args.output)
                break
            time.sleep(30)
    except KeyboardInterrupt:
        logger.warning("Server stopping")

    if evaluation_service:
        logger.info("Stopping evaluation service")
        evaluation_service.stop()

    logger.info("Stopping RPC server")
    server.stop(0)

    # Keep TensorBoard running when all the tasks are finished
    if tb_service:
        logger.info(
            "All tasks finished. Keeping TensorBoard service running...")
        while True:
            if tb_service.is_active():
                time.sleep(10)
            else:
                logger.warning("Unable to keep TensorBoard running. "
                               "It has already terminated")
                break
    logger.info("Master stopped")