# Imports and logger assumed from the PyTorch-BigGraph module layout
# (torchbiggraph.*); adjust the paths if your checkout places these names
# elsewhere.
import logging
from typing import Callable, List, Optional

import torch.distributed as td

from torchbiggraph.config import ConfigSchema
from torchbiggraph.distributed import ProcessRanks, init_process_group
from torchbiggraph.parameter_sharing import ParameterServer
from torchbiggraph.types import RANK_ZERO, SINGLE_TRAINER, Rank
from torchbiggraph.util import tag_logs_with_process_name

logger = logging.getLogger("torchbiggraph")


def run_partition_server(
    config: ConfigSchema,
    rank: Rank = RANK_ZERO,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> None:
    tag_logs_with_process_name(f"PartS-{rank}")
    if config.num_partition_servers <= 0:
        raise RuntimeError("Config doesn't require explicit partition servers")
    if not 0 <= rank < config.num_partition_servers:
        raise RuntimeError("Invalid rank for partition server")
    if not td.is_available():
        raise RuntimeError("The installed PyTorch version doesn't provide "
                           "distributed training capabilities.")
    ranks = ProcessRanks.from_num_invocations(config.num_machines,
                                              config.num_partition_servers)
    if subprocess_init is not None:
        subprocess_init()
    init_process_group(
        rank=ranks.partition_servers[rank],
        world_size=ranks.world_size,
        init_method=config.distributed_init_method,
        groups=[ranks.trainers],
    )
    ps = ParameterServer(num_clients=len(ranks.trainers))
    ps.start()
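

# Usage sketch (hypothetical, not part of the module): one such process is
# started per partition-server rank, e.g. from a small launcher script, with
# the same ConfigSchema the trainers use and a reachable
# distributed_init_method.
#
#     config = ...  # fully-populated ConfigSchema shared with the trainers
#     run_partition_server(config, rank=Rank(0))
#
# The call then blocks in ps.start(), serving parameters to the trainer
# clients.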


# Variant without type annotations, log tagging, or a subprocess_init hook.
def run_partition_server(config, rank=0):
    if config.num_partition_servers <= 0:
        raise RuntimeError("Config doesn't require explicit partition servers")
    if not 0 <= rank < config.num_partition_servers:
        raise RuntimeError("Invalid rank for partition server")
    if not td.is_available():
        raise RuntimeError("The installed PyTorch version doesn't provide "
                           "distributed training capabilities.")
    ranks = ProcessRanks.from_num_invocations(
        config.num_machines, config.num_partition_servers)
    init_process_group(
        rank=ranks.partition_servers[rank],
        world_size=ranks.world_size,
        init_method=config.distributed_init_method,
        groups=[ranks.trainers],
    )
    ps = ParameterServer(num_clients=len(ranks.trainers))
    ps.start()


# Variant that additionally sets up process groups shared with the trainers:
# one trainers-only barrier group plus `num_groups_for_partition_server`
# groups that are handed to the ParameterServer.
def run_partition_server(
    config: ConfigSchema,
    rank: Rank = SINGLE_TRAINER,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> None:
    tag_logs_with_process_name(f"PartS-{rank}")
    if config.num_partition_servers <= 0:
        raise RuntimeError("Config doesn't require explicit partition servers")
    if not 0 <= rank < config.num_partition_servers:
        raise RuntimeError("Invalid rank for partition server")
    if not td.is_available():
        raise RuntimeError("The installed PyTorch version doesn't provide "
                           "distributed training capabilities.")
    ranks = ProcessRanks.from_num_invocations(config.num_machines,
                                              config.num_partition_servers)

    num_ps_groups = config.num_groups_for_partition_server
    groups: List[List[int]] = [ranks.trainers]  # barrier group
    groups += [ranks.trainers + ranks.partition_servers] * num_ps_groups  # ps groups
    # Indices into `groups` of the ps groups (everything after the barrier
    # group at index 0); these are passed to the ParameterServer below.
    group_idxs_for_partition_servers = range(1, len(groups))

    if subprocess_init is not None:
        subprocess_init()
    # The wrapper returns the initialized torch.distributed group handles,
    # in the same order as the rank lists passed in.
    groups = init_process_group(
        rank=ranks.partition_servers[rank],
        world_size=ranks.world_size,
        init_method=config.distributed_init_method,
        groups=groups,
    )
    ps = ParameterServer(
        num_clients=len(ranks.trainers),
        group_idxs=group_idxs_for_partition_servers,
        log_stats=True,
    )
    ps.start(groups)
    logger.info("ps.start done")
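

# Illustration (assumption, not part of the module): a minimal sketch of the
# multi-group pattern that the init_process_group wrapper above appears to
# build on, written directly against torch.distributed. The helper name,
# backend choice, and rank layout are made up for the example; the one hard
# requirement is that every process creates the same groups in the same order.
def _init_groups_sketch(
    rank: int, world_size: int, init_method: str, rank_lists: List[List[int]]
) -> List[td.ProcessGroup]:
    # Join the global group first...
    td.init_process_group(
        backend="gloo", init_method=init_method, rank=rank, world_size=world_size
    )
    # ...then carve out one subgroup per rank list (e.g. a trainers-only
    # barrier group plus several trainers+partition-servers groups).
    return [td.new_group(ranks=ranks) for ranks in rank_lists]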