Example #1
def grad_server_helper(model_creators: List[Callable],
                       optimizer: Any = Adam,
                       learning_rate: float = 1e-3):
    """
    Helper function for creating a tuple of grad servers,
    used by A3C, IMPALA, etc. This function requires all processes
    in the world to enter.

    Warning:
        You should never run this function twice!

    Args:
        model_creators: A list of model creator functions, each corresponding
            to one gradient reduction server.
        optimizer: Optimizer type, default is Adam.
        learning_rate: Learning rate of the optimizer.

    Returns:
        A tuple of accessors to gradient servers; the tuple has the
        same size as ``model_creators``.
    """
    # Note:
    # Passing a list of creator functions instead of a list of models avoids
    # the unnecessary cost of building the models on processes other than the
    # primary reducer.
    DEFAULT_GROUP_NAME = "server_group"

    # create groups first
    world = get_world()
    server_group = world.create_rpc_group(DEFAULT_GROUP_NAME,
                                          world.get_members())

    # create servers
    primary_reducer = world.get_members()[0]
    servers = [
        PushPullGradServerImpl("grad_server_" + str(i),
                               server_group,
                               primary_reducer=primary_reducer)
        for i in range(len(model_creators))
    ]
    if get_cur_name() == primary_reducer:
        for model_creator, server in zip(model_creators, servers):
            model = model_creator()
            server.manage_model(model,
                                optimizer(model.parameters(),
                                          lr=learning_rate))
            server.start()

    server_group.barrier()
    servers = tuple(
        server_group.get_paired("grad_server_" + str(i)).to_here()
        for i in range(len(model_creators))
    )

    # Accessors, rather than the actual implementation instances, are
    # returned here because of __reduce__.
    return servers
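A minimal usage sketch for this helper, assuming the distributed world has already been initialized on every process. The ``Actor`` and ``Critic`` classes below are illustrative stand-ins, not part of the example above; only ``grad_server_helper`` and ``Adam`` come from it.

import torch.nn as nn

class Actor(nn.Module):
    # hypothetical model; each creator function maps to one gradient server
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

class Critic(nn.Module):
    # hypothetical model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 1)

# every process in the world must enter; the models themselves are built
# only on the primary reducer
actor_server, critic_server = grad_server_helper(
    [Actor, Critic],
    optimizer=Adam,        # optimizer class, not an instance
    learning_rate=1e-3,
)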
Example #2
    def init_from_config(cls, config: Union[Dict[str, Any], Config]):
        world = get_world()
        f_config = deepcopy(config["frame_config"])
        apex_group = world.create_rpc_group(
            group_name=f_config["apex_group_name"],
            members=(
                world.get_members()
                if f_config["apex_members"] == "all"
                else f_config["apex_members"]
            ),
        )

        models = assert_and_get_valid_models(f_config["models"])
        model_args = f_config["model_args"]
        model_kwargs = f_config["model_kwargs"]
        models = [
            m(*arg, **kwarg) for m, arg, kwarg in zip(models, model_args, model_kwargs)
        ]
        # wrap models in DistributedDataParallel when running in learner mode
        max_learner_id = f_config["learner_process_number"]

        learner_group = world.create_collective_group(ranks=list(range(max_learner_id)))

        if world.rank < max_learner_id:
            models = [
                DistributedDataParallel(module=m, process_group=learner_group.group)
                for m in models
            ]

        optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
        criterion = assert_and_get_valid_criterion(f_config["criterion"])(
            *f_config["criterion_args"], **f_config["criterion_kwargs"]
        )
        criterion.reduction = "none"
        lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
            f_config["lr_scheduler"]
        )
        servers = model_server_helper(
            model_num=1,
            group_name=f_config["model_server_group_name"],
            members=f_config["model_server_members"],
        )
        del f_config["optimizer"]
        del f_config["criterion"]
        del f_config["lr_scheduler"]
        frame = cls(
            *models,
            optimizer,
            criterion,
            apex_group,
            servers,
            lr_scheduler=lr_scheduler,
            **f_config
        )
        if world.rank >= max_learner_id:
            frame.update = lambda *_, **__: (None, None)
        return frame
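The keys read above suggest a ``frame_config`` of roughly the following shape. This is a hedged sketch; the concrete values are illustrative assumptions, not defaults taken from the library.

config = {
    "frame_config": {
        "apex_group_name": "apex",
        "apex_members": "all",
        "models": ["QNet", "QNet"],      # resolved by assert_and_get_valid_models
        "model_args": [(), ()],
        "model_kwargs": [{}, {}],
        "learner_process_number": 1,     # ranks below this run as learners
        "optimizer": "Adam",
        "criterion": "MSELoss",
        "criterion_args": (),
        "criterion_kwargs": {},
        "lr_scheduler": None,
        "model_server_group_name": "model_server",
        "model_server_members": "all",
        # any remaining keys are forwarded to cls(...) as keyword arguments
    }
}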
Example #3
    def training_step(self, batch, _batch_idx):
        world_inited = get_world() is not None
        model_inited = isinstance(self.nn_model, NNModule)

        if world_inited and get_cur_rank() == 0:
            with open(os.environ["TEST_SAVE_PATH"], "wb") as f:
                pickle.dump([model_inited], f)
        if not world_inited:
            raise RuntimeError("World not initialized.")
        return None
Example #4
    def init_frame(self):
        # Called by the overridden PyTorch Lightning DDP plugins
        if self.frame is None:
            # initialize framework
            self.frame = init_algorithm_from_config(self.config,
                                                    model_device=self.device)
            self._batch_num = self.config["batch_num"].get(self.frame.role, 1)
            # forward models to the launcher module, so that their parameters are registered and handled
            for name, model in zip(self.frame.get_top_model_names(),
                                   self.frame.top_models):
                self.add_module(name, model)

            # Create a group for custom synchronization with a large timeout;
            # otherwise the default barrier in pytorch_lightning after the
            # training step will throw an exception.
            world = get_world()
            self.group = world.create_collective_group(world.get_ranks(),
                                                       timeout=86400)
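From ``self.config["batch_num"].get(self.frame.role, 1)``, ``batch_num`` is evidently a per-role mapping. A hedged sketch of that part of the config; the role names and counts are assumptions, not library defaults.

config = {
    "frame_config": {},                            # consumed by init_algorithm_from_config
    "batch_num": {"learner": 100, "sampler": 1},   # batches per role, default 1
}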
Example #5
def model_server_helper(
    model_num: int,
    group_name: str = "model_server",
    members: Union[str, List[str]] = "all",
):
    """
    Helper function for creating a tuple of model servers,
    used by APEX, etc. This function requires all processes
    in the world to enter.

    Args:
        model_num: The number of models, which corresponds to the number of
            model servers, since each server manages one model.
        group_name: Name of the RPC group on which model servers should be
            registered; the group name should be unique.
        members: Names of the involved RPC processes, or ``"all"`` for all
            processes; in the current implementation only the first process
            serves as the server.

    Returns:
        A tuple of accessors to model servers; the size of the tuple is
        ``model_num``.
    """
    # create groups first
    world = get_world()
    members = world.get_members() if members == "all" else members
    server_group = world.create_rpc_group(group_name, members)

    # create servers
    # In the current implementation, only one process will initialize the servers
    if get_cur_name() == members[0]:
        for i in range(model_num):
            _server = PushPullModelServerImpl("model_server_" + str(i),
                                              server_group)

    server_group.barrier()

    servers = tuple(
        server_group.get_paired("model_server_" + str(i)).to_here()
        for i in range(model_num))

    # Accessors, rather than the actual implementation instances, are
    # returned here because of __reduce__.
    return servers
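A minimal usage sketch; the group name is the default shown above, and a single managed model is assumed. Every process listed in ``members`` must make the same call.

# only members[0] hosts the server implementation; all processes get accessors
(policy_server,) = model_server_helper(
    model_num=1,
    group_name="model_server",
    members="all",
)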
Example #6
    def init_from_config(cls, config: Union[Dict[str, Any], Config]):
        world = get_world()
        f_config = copy.deepcopy(config["frame_config"])
        ars_group = world.create_rpc_group(
            group_name=f_config["ars_group_name"],
            members=(
                world.get_members()
                if f_config["ars_members"] == "all"
                else f_config["ars_members"]
            ),
        )

        models = assert_and_get_valid_models(f_config["models"])
        model_args = f_config["model_args"]
        model_kwargs = f_config["model_kwargs"]
        models = [
            m(*arg, **kwarg) for m, arg, kwarg in zip(models, model_args, model_kwargs)
        ]

        optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
        lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
            f_config["lr_scheduler"]
        )
        servers = model_server_helper(
            model_num=1,
            group_name=f_config["model_server_group_name"],
            members=f_config["model_server_members"],
        )
        del f_config["optimizer"]
        del f_config["lr_scheduler"]
        frame = cls(
            *models,
            optimizer,
            ars_group,
            servers,
            lr_scheduler=lr_scheduler,
            **f_config,
        )
        return frame
Example #7
def model_server_helper(model_num):
    """
    Helper function for creating a tuple of model servers,
    used by APEX, etc. This function requires all processes
    in the world to enter.

    Warning:
        You should never run this function twice!

    Returns:
        A tuple of accessors to model servers; the size of the tuple is
        ``model_num``.
    """
    DEFAULT_GROUP_NAME = "server_group"

    # create groups first
    world = get_world()
    server_group = world.create_rpc_group(DEFAULT_GROUP_NAME,
                                          world.get_members())

    # create servers
    # In the current implementation, only one process will initialize the servers
    if get_cur_name() == world.get_members()[0]:
        for i in range(model_num):
            _server = PushPullModelServerImpl("model_server_" + str(i),
                                              server_group)

    server_group.barrier()

    servers = tuple(
        server_group.get_paired("model_server_" + str(i)).to_here()
        for i in range(model_num)
    )

    # Accessors, rather than the actual implementation instances, are
    # returned here because of __reduce__.
    return servers
Example #8
    def init_from_config(
        cls,
        config: Union[Dict[str, Any], Config],
        model_device: Union[str, t.device] = "cpu",
    ):
        world = get_world()
        f_config = deepcopy(config["frame_config"])
        impala_group = world.create_rpc_group(
            group_name=f_config["impala_group_name"],
            members=(
                world.get_members()
                if f_config["impala_members"] == "all"
                else f_config["impala_members"]
            ),
        )

        models = assert_and_get_valid_models(f_config["models"])
        model_args = f_config["model_args"]
        model_kwargs = f_config["model_kwargs"]
        models = [
            m(*arg, **kwarg).to(model_device)
            for m, arg, kwarg in zip(models, model_args, model_kwargs)
        ]
        # wrap models in DistributedDataParallel when running in learner mode
        max_learner_id = f_config["learner_process_number"]

        learner_group = world.create_collective_group(ranks=list(range(max_learner_id)))

        if world.rank < max_learner_id:
            models = [
                DistributedDataParallel(module=m, process_group=learner_group.group)
                for m in models
            ]

        optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
        criterion = assert_and_get_valid_criterion(f_config["criterion"])(
            *f_config["criterion_args"], **f_config["criterion_kwargs"]
        )
        lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
            f_config["lr_scheduler"]
        )
        servers = model_server_helper(
            model_num=1,
            group_name=f_config["model_server_group_name"],
            members=f_config["model_server_members"],
        )
        del f_config["optimizer"]
        del f_config["criterion"]
        del f_config["lr_scheduler"]
        frame = cls(
            *models,
            optimizer,
            criterion,
            impala_group,
            servers,
            lr_scheduler=lr_scheduler,
            **f_config,
        )
        if world.rank >= max_learner_id:
            frame.role = "sampler"
            frame.update = _disable_update
        else:
            frame.role = "learner"
        return frame
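``_disable_update`` is not shown in this example. Judging from the equivalent ``lambda *_, **__: (None, None)`` used in Example #2, it is presumably a small module-level no-op (module-level so that the frame remains picklable). A hedged sketch:

def _disable_update(*_, **__):
    # samplers never train; mirror the (None, None) return used in Example #2
    return None, None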
Example #9
def grad_server_helper(
    model_creators: List[Callable],
    group_name: str = "grad_server",
    members: Union[str, List[str]] = "all",
    optimizer: Any = Adam,
    learning_rate: Union[float, List[float]] = 1e-3,
    optimizer_kwargs: List[Dict[str, Any]] = None,
    lr_scheduler: Any = None,
    lr_scheduler_args: List[Tuple] = None,
    lr_scheduler_kwargs: List[Dict[str, Any]] = None,
):
    """
    Helper function for creating a tuple of grad servers,
    used by A3C, IMPALA, etc. This function requires all processes
    in the world to enter.

    Args:
        model_creators: A list of model creator functions, each corresponding
            to one gradient reduction server.
        group_name: Name of the RPC group on which gradient servers should be
            registered; the group name should be unique.
        members: Names of the involved RPC processes, or ``"all"`` for all
            processes; they will be used as secondary reducers, and the first
            process will be the primary reducer.
        optimizer: Optimizer class, default is Adam.
        learning_rate: Learning rate of each optimizer, or a single float
            value applied to all of them.
        optimizer_kwargs: Keyword arguments for the optimizer of each model.
        lr_scheduler: Learning rate scheduler class.
        lr_scheduler_args: Positional arguments for the lr_scheduler of each
            optimizer.
        lr_scheduler_kwargs: Keyword arguments for the lr_scheduler of each
            optimizer.

    Returns:
        A tuple of accessors to gradient servers; the tuple has the
        same size as ``model_creators``.
    """
    # Note:
    # Passing a list of creator functions instead of a list of models avoids
    # the unnecessary cost of building the models on processes other than the
    # primary reducer.

    # create groups first
    world = get_world()
    members = world.get_members() if members == "all" else members
    server_group = world.create_rpc_group(group_name, members)

    if isinstance(learning_rate, float):
        learning_rate = [learning_rate] * len(model_creators)

    optimizer_kwargs = optimizer_kwargs or [{}] * len(model_creators)
    lr_scheduler_args = lr_scheduler_args or [()] * len(model_creators)
    lr_scheduler_kwargs = lr_scheduler_kwargs or [{}] * len(model_creators)

    # create servers
    primary_reducer = members[0]
    servers = [
        PushPullGradServerImpl("grad_server_" + str(i),
                               server_group,
                               primary_reducer=primary_reducer)
        for i in range(len(model_creators))
    ]
    if get_cur_name() == primary_reducer:
        for (
                model_creator,
                server,
                optim_kwargs,
                lr,
                lr_sch_args,
                lr_sch_kwargs,
        ) in zip(
                model_creators,
                servers,
                optimizer_kwargs,
                learning_rate,
                lr_scheduler_args,
                lr_scheduler_kwargs,
        ):
            model = model_creator()
            if lr_scheduler is None:
                server.manage_model(
                    model, optimizer(model.parameters(), lr=lr,
                                     **optim_kwargs))
            else:
                # Bind the instance to a new name so that the optimizer class
                # is not shadowed when iterating over multiple models.
                optim_instance = optimizer(
                    model.parameters(), lr=lr, **optim_kwargs
                )
                server.manage_model(
                    model,
                    optim_instance,
                    lr_scheduler(optim_instance, *lr_sch_args, **lr_sch_kwargs),
                )
            server.start()

    server_group.barrier()
    servers = tuple(
        server_group.get_paired("grad_server_" + str(i)).to_here()
        for i in range(len(model_creators)))

    # Accessors, rather than the actual implementation instances, are
    # returned here because of __reduce__.
    return servers
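A usage sketch for the scheduler path, assuming the world is already initialized. ``MyModel``, the Adam betas, and the StepLR settings are illustrative assumptions, not values taken from the example above.

import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

class MyModel(nn.Module):
    # hypothetical model, not part of the example above
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

(server,) = grad_server_helper(
    [MyModel],
    group_name="grad_server",
    members="all",
    optimizer=Adam,
    learning_rate=[1e-3],                     # one value per model creator
    optimizer_kwargs=[{"betas": (0.9, 0.999)}],
    lr_scheduler=StepLR,
    lr_scheduler_args=[(1000,)],              # StepLR step_size
    lr_scheduler_kwargs=[{"gamma": 0.99}],
)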