Example #1
from typing import Dict

from pyspark import RDD

# load_base_torch, handle_model, get_state_dict and Server are helpers from the
# surrounding library; their import paths are not shown in this example.


def train(
    rdd: RDD,
    torch_obj: str,
    server: Server,
    iters: int = 10,
    partition_shuffles: int = 1,
    verbose: int = 1,
    early_stop_patience: int = -1,
    mini_batch: int = -1,
    validation_pct: float = 0.0
) -> Dict:
    try:
        # Load the base torch object; the parameter shapes it also returns are unused here.
        torch_obj, _ = load_base_torch(torch_obj)

        # URL the worker tasks use to reach the running server.
        master_url = str(server.master_url)

        for i in range(partition_shuffles):
            # Train against each partition; the foreach below forces the lazy
            # transformation to execute.
            rdd.mapPartitions(
                lambda x: handle_model(
                    x,
                    torch_obj=torch_obj,
                    master_url=master_url,
                    iters=iters,
                    verbose=verbose,
                    early_stop_patience=early_stop_patience,
                    mini_batch=mini_batch,
                    validation_pct=validation_pct
                )
            ).foreach(lambda x: x)

            # Reshuffle the data between passes if more shuffles remain.
            if partition_shuffles - i > 1:
                num_partitions = rdd.getNumPartitions()
                rdd = rdd.repartition(num_partitions)

        # Pull the trained state dict from the server, then shut the server down.
        state_dict = get_state_dict(master_url)
        server.stop_server()

        return state_dict

    except Exception:
        # Make sure the server is stopped on failure, then re-raise.
        server.stop_server()
        raise
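
A minimal usage sketch for the train entry point above (not taken from the library's docs). data_rdd, serialized_model and server are placeholder names: the RDD is assumed to already hold numeric feature/label rows, and the serialized torch object and the Server instance are assumed to come from the library's own helpers, which are not shown in this listing.

# Hypothetical call; only train() itself is taken from the example above.
state = train(
    rdd=data_rdd,
    torch_obj=serialized_model,
    server=server,
    iters=20,
    mini_batch=64,
    validation_pct=0.1
)
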
Example #2
from multiprocessing import Process
from typing import Dict

from pyspark import RDD

# retrieve_url, load_base_torch, handle_model and mapPartitionsWithIndex are
# helpers from the surrounding library; their import paths are not shown here.


def train_distributed(
    rdd: RDD,
    torch_obj: str,
    iters: int = 10,
    partition_shuffles: int = 1,
    verbose: int = 1,
    mini_batch: int = -1,
    validation_pct: float = 0.0,
    world_size: int = 2,
    device: str = 'cpu',
    early_stop_patience: int = -1
) -> Dict:
    """
    Entry point to train the model in distributed fashion.

    :param rdd: The rdd of data to run on the model.
    :param torch_obj: The torch object as a string that includes the model and param shapes.
    :param master_url: The main url for the driver.
    :param iters: Number of iterations for training.
    :param partition_shuffles: Number of partition shuffles (Need to implement)
    :param verbose: Verbosity of logs
    :param mini_batch: Mini batch for each iteration of training.
    :param validation_pct: How many items to validate
    :param world_size: number of partitions.
    :param device: pytorch device

    :return: The train dict.
    """
    # The main URL for the driver is retrieved here rather than passed in.
    master_url = retrieve_url()

    torch_loaded, params = load_base_torch(torch_obj)

    # Start the driver process.
    p = Process(
        target=handle_model,
        args=(-1, None, params, master_url, iters, world_size, early_stop_patience)
    )
    p.start()

    try:
        state_dict = None
        for i in range(partition_shuffles):

            # Run model with barrier execution mode.
            state_dict = mapPartitionsWithIndex(
                rdd, lambda i, x: handle_model(
                    i,
                    x,
                    torch_obj=torch_loaded,
                    master_url=master_url,
                    iters=iters,
                    verbose=verbose,
                    mini_batch=mini_batch,
                    validation_pct=validation_pct,
                    world_size=world_size,
                    device=device,
                    early_stop_patience=early_stop_patience
                )
            ).collect()

            # Reshuffle the data between passes if more shuffles remain.
            if partition_shuffles - i > 1:
                num_partitions = rdd.getNumPartitions()
                rdd = rdd.repartition(num_partitions)

        # The first collected result holds the trained state dict.
        return state_dict[0]

    finally:
        # Always shut down the driver process.
        p.terminate()
        p.join()
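
A minimal usage sketch for train_distributed, under the same assumptions: data_rdd is a placeholder RDD of numeric feature/label rows, and serialized_model is the string form of the torch object produced by the library's own serializer (neither name appears in the listing).

# Hypothetical call; only train_distributed() itself is taken from the example above.
state = train_distributed(
    rdd=data_rdd,
    torch_obj=serialized_model,
    iters=50,
    world_size=2,
    device='cpu',
    validation_pct=0.1
)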