Code example #1
    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        conf=None,
        model=None,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(DCD_PSGD, self).__init__(params, defaults)

        # store the full set of training arguments.
        self.conf = conf

        # define the aggregator.
        self.rank = conf.graph.rank
        self.neighbors_info = conf.graph.get_neighborhood()
        self.aggregator = comm.get_aggregators(
            cur_rank=self.rank,
            world=conf.graph.ranks,
            neighbors_info=self.neighbors_info,
            aggregator_type="decentralized",
        )
        self.world_aggregator = comm.get_aggregators(
            cur_rank=self.rank,
            world=conf.graph.ranks,
            neighbors_info=dict(
                (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks),
            aggregator_type="centralized",
        )

        # define param names and init model_hat.
        self.param_names = list(
            enumerate([group["name"] for group in self.param_groups]))
        self.init_neighbor_hat_params()

        # related to sparsification/quantization.
        self.compressor = DCDCompressor(
            aggregator=self.aggregator,
            comm_op=conf.comm_op,
            comm_device=conf.comm_device,
            compress_ratio=conf.compress_ratio,
            quantize_level=conf.quantize_level,
            is_biased=conf.is_biased,
            backend=conf.backend,
            use_ipc=conf.use_ipc,
        )
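Because param_names is built from group["name"], each of these optimizers expects its parameters as named groups rather than a bare model.parameters() iterator. A minimal construction sketch for this example, assuming DCD_PSGD and the project's comm backend are importable and already initialized; the SimpleNamespace stand-in for conf is hypothetical and only lists the attributes the constructor reads, with illustrative values:

from types import SimpleNamespace

import torch

# Hypothetical stand-in for the project's conf object; a real run builds
# this from the parsed training arguments and the communication topology.
graph = SimpleNamespace(
    rank=0,
    ranks=[0, 1, 2, 3],
    n_nodes=4,
    # neighbor rank -> mixing weight (includes the node itself).
    get_neighborhood=lambda: {0: 0.5, 1: 0.25, 3: 0.25},
)
conf = SimpleNamespace(
    graph=graph,
    comm_op="compress_top_k",  # illustrative value
    comm_device="cpu",
    compress_ratio=0.9,
    quantize_level=16,
    is_biased=True,
    backend="mpi",
    use_ipc=False,
)

model = torch.nn.Linear(10, 2)
# named parameter groups, so that group["name"] exists.
param_groups = [
    {"params": [p], "name": name} for name, p in model.named_parameters()
]
optimizer = DCD_PSGD(param_groups, lr=0.1, momentum=0.9, conf=conf)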
Code example #2
File: sgd.py, Project: gessfred/LocalSGD-Code
    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        conf=None,
        model=None,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

        # store the full set of training arguments.
        self.conf = conf
        self.rank = conf.graph.rank
        self.neighbors_info = conf.graph.get_neighborhood()

        # define the aggregator.
        self.decentralized_aggregator = comm.get_aggregators(
            cur_rank=self.rank,
            world=conf.graph.ranks,
            neighbors_info=self.neighbors_info,
            aggregator_type="decentralized",
        )
        self.world_aggregator = comm.get_aggregators(
            conf,
            cur_rank=self.rank,
            world=conf.graph.ranks,
            neighbors_info=dict(
                (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks),
            aggregator_type="centralized",
        )

        # define reducer.
        self.backend = conf.backend

        # define sorted param names.
        self.param_names = list(
            enumerate([group["name"] for group in self.param_groups]))
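Constructed the same way as in example #1, the optimizer then drops into the ordinary PyTorch loop. A minimal single-step sketch, assuming step() keeps the standard signature; the project's version additionally performs the aggregation configured above, which the sketch does not reimplement:

import torch

def train_step(model, optimizer, inputs, targets):
    # standard local computation; any communication happens inside step().
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()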
Code example #3
File: sign_sgd.py, Project: gessfred/LocalSGD-Code
    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        conf=None,
        model=None,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(SignSGD, self).__init__(params, defaults)

        # store the full set of training arguments.
        self.conf = conf
        self.rank = conf.graph.rank
        self.neighbors_info = conf.graph.get_neighborhood()
        self.local_step = conf.local_step
        self.turn_on_local_step_from_epoch = conf.turn_on_local_step_from

        # define the aggregator.
        self.world_aggregator = comm.get_aggregators(
            conf,
            cur_rank=self.rank,
            world=conf.graph.ranks,
            neighbors_info=dict(
                (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks),
            aggregator_type="centralized",
        )

        # define sorted param names.
        self.param_names = list(
            enumerate([group["name"] for group in self.param_groups]))

        # initialize the consensus compressor.
        self.compressor = ExactSignCompressor(
            rank=self.rank,
            world_size=len(conf.graph.ranks),
            majority_vote=conf.majority_vote,
            aggregator=self.world_aggregator,
            comm_op=conf.comm_op,
            comm_device=self.conf.comm_device,
            use_ipc=conf.use_ipc,
        )
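The ExactSignCompressor above is configured with majority_vote, i.e. the signSGD-with-majority-vote scheme: every worker transmits only the sign of each update coordinate, and the centralized aggregator combines them by element-wise majority. A toy illustration of that arithmetic, independent of the project's API:

import torch

# three workers' gradients for the same 4-element parameter.
grads = torch.tensor([
    [ 0.3, -0.1,  0.7, -0.2],
    [-0.4, -0.5,  0.1, -0.9],
    [ 0.2, -0.3, -0.6,  0.8],
])

signs = torch.sign(grads)            # each worker sends one bit per coordinate
vote = torch.sign(signs.sum(dim=0))  # element-wise majority vote
# vote -> tensor([ 1., -1.,  1., -1.])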
Code example #4
    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        conf=None,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(DGC, self).__init__(params, defaults)

        # store the full set of training arguments.
        self.conf = conf
        self.n_nodes = conf.graph.n_nodes
        self.rank = conf.graph.rank

        # define the aggregator.
        self.param_names = list(
            enumerate([group["name"] for group in self.param_groups]))
        self.world_aggregator = get_aggregators(
            cur_rank=self.rank,
            world=conf.graph.ranks,
            neighbors_info=dict(
                (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks),
            aggregator_type="centralized",
        )

        # related to sparsification/quantization.
        self.comm_op = conf.comm_op
        self.comm_device = conf.comm_device
        self.is_compress_op = "compress" in self.comm_op
        self.compress_ratio = conf.compress_ratio
        self.compress_warmup_values = conf.compress_warmup_values
        self.compress_warmup_epochs = conf.compress_warmup_epochs
        self.quantize_level = conf.quantize_level
        self.is_biased = conf.is_biased

        self.clip_grad = conf.clip_grad
        self.clip_grad_val = conf.clip_grad_val
        self.mask_momentum = conf.mask_momentum

        self.init_memory()
        self.init_compression()

        # define compressors.
        if self.is_compress_op:
            self.compressor_fn = SparsificationCompressor()
        else:
            self.compressor_fn = QuantizationCompressor()

        # define reducer.
        self.backend = conf.backend
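compress_warmup_values and compress_warmup_epochs point at the warmup trick from the Deep Gradient Compression paper, where sparsity is ramped up over the first epochs instead of starting at the final compress_ratio. A hedged sketch of one plausible schedule; the exact semantics of the two fields are an assumption, not taken from this file:

def warmup_compress_ratio(epoch, warmup_values, warmup_epochs, final_ratio):
    # step through the warmup values during the first warmup_epochs epochs,
    # then stay at the final ratio.
    if epoch < warmup_epochs:
        return warmup_values[min(epoch, len(warmup_values) - 1)]
    return final_ratio

# the DGC paper ramps sparsity 75% -> 93.75% -> 98.4% -> 99.6% -> 99.9%.
ratios = [0.75, 0.9375, 0.984, 0.996]
assert warmup_compress_ratio(2, ratios, 4, 0.999) == 0.984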
Code example #5
    def _sync_thread_func(
        gossiped_dist_model_flag,
        updated_local_model_flag,
        sync_queue,
        param_names,
        timer,
        neighbors_info,
    ):
        # a utility function.
        def _init_neighbor_hat_params(conf, param_groups, param_names):
            params, params_shapes = comm.get_data(param_groups,
                                                  param_names,
                                                  is_get_grad=False)
            flatten_params = TensorBuffer(params)
            flatten_params.buffer = torch.zeros_like(flatten_params.buffer)

            # init the neighbor_params.
            return (
                {
                    conf.graph.rank: deepcopy(flatten_params),
                    "memory": deepcopy(flatten_params),
                },
                params_shapes,
            )

        # initialization.
        conf, param_groups, n_bits = sync_queue.get()
        neighbor_hat_params, params_shapes = _init_neighbor_hat_params(
            conf, param_groups, param_names)

        # init the distributed world again.
        try:
            dist.init_process_group("mpi")
        except RuntimeError as e:
            print(f"error: {e}")

        # define aggregator and compressor.
        # conf.graph._make_process_group()
        aggregator = comm.get_aggregators(
            cur_rank=conf.graph.rank,
            world=conf.graph.ranks,
            neighbors_info=neighbors_info,
            aggregator_type="decentralized",
            graph=conf.graph,
        )
        compressor = CHOCOCompressor(
            aggregator=aggregator,
            comm_op=conf.comm_op,
            comm_device=conf.comm_device,
            compress_ratio=conf.compress_ratio,
            quantize_level=conf.quantize_level,
            is_biased=conf.is_biased,
            backend=conf.backend,
            use_ipc=conf.use_ipc,
        )

        # the main infinite loop.
        while True:
            if updated_local_model_flag.is_set():
                updated_local_model_flag.clear()

                # recover current params and hat_params
                params, flatten_params, flatten_hat_params = utils.recover_params(
                    param_groups=param_groups,
                    param_names=param_names,
                    rank=conf.graph.rank,
                    neighbor_hat_params=neighbor_hat_params,
                    get_hat_params=True,
                )
                # get updated flatten params.
                utils.update_params_from_neighbor(
                    neighbor_hat_params=neighbor_hat_params,
                    flatten_params=flatten_params,
                    consensus_stepsize=conf.consensus_stepsize,
                    self_rank=conf.graph.rank,
                )
                # update the local model using neighborhood info.
                flatten_params.unpack(params)
                gossiped_dist_model_flag.set()

                # start compress/sync.
                sync_buffer = {
                    "original_shapes": params_shapes,
                    "flatten_params": flatten_params,
                    "flatten_hat_params": flatten_hat_params,
                }
                compressor.pipeline(
                    sync_buffer=sync_buffer,
                    neighbor_hat_params=neighbor_hat_params,
                    neighbors_info=neighbors_info,
                )
                n_bits.data[0] = sync_buffer["n_bits"]
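The point of this producer/consumer split is to hide communication behind computation: the main process publishes fresh local parameters and keeps training, while this background process gossips with neighbors and compresses in parallel, signalling completion through the two events. dist.init_process_group("mpi") is called again here because the spawned process needs its own process group; the try/except tolerates a group that already exists.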
Code example #6
    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        conf=None,
        model=None,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(ParallelCHOCO, self).__init__(params, defaults)

        # store the full set of training arguments.
        self.conf = conf

        # define the aggregator.
        self.rank = conf.graph.rank
        self.neighbors_info = conf.graph.get_neighborhood()
        self.world_aggregator = comm.get_aggregators(
            cur_rank=self.rank,
            world=conf.graph.ranks,
            neighbors_info=dict(
                (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks),
            aggregator_type="centralized",
        )

        # define param names and init model_hat.
        self.param_names = list(
            enumerate([group["name"] for group in self.param_groups]))

        # efficient sync (try to hide the communication cost).
        self.sync_queue = mp.Queue()
        self.gossiped_dist_model_flag = mp.Event()
        self.updated_local_model_flag = mp.Event()
        self.sync_thread = mp.Process(
            target=ParallelCHOCO._sync_thread_func,
            args=(
                self.gossiped_dist_model_flag,
                self.updated_local_model_flag,
                self.sync_queue,
                self.param_names,
                self.conf.timer,
                self.neighbors_info,
            ),
        )

        self.sync_thread.daemon = True
        self.sync_thread.name = "Sync-Thread"
        self.sync_thread.start()
        self.gossiped_dist_model_flag.clear()
        self.updated_local_model_flag.clear()
        self.n_bits = torch.FloatTensor([0])

        # hand the shared state to the sync process via the queue.
        self.sync_queue.put((self.conf, self.param_groups, self.n_bits))
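The step-side half of the handshake is not shown in this excerpt. A hypothetical sketch of what it has to look like, inferred purely from the two events created above (the method name and placement are assumptions):

    def _step_side_sync(self):
        # tell the background process that fresh local parameters are ready.
        self.updated_local_model_flag.set()
        # block until the gossip update has been folded back into the model.
        self.gossiped_dist_model_flag.wait()
        self.gossiped_dist_model_flag.clear()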