예제 #1
0
    def get_trainer_send_context(self):
        """Collect the send-side communication contexts for this trainer.

        In GEO mode every pair in ``self.origin_sparse_pairs`` yields one
        ``CommContext`` that combines the parameter's split layout with the
        gradient's origin variable names, and a step context from
        ``self._step_ctx`` is always appended.

        Otherwise one gradient context is built for every pair in
        ``self.merged_dense_pairs`` and ``self.merged_sparse_pairs``; the
        step context is appended only when ``self.is_async_mode()`` is true.

        Returns:
            dict: context name -> context object.
        """
        contexts = {}
        distributed_names = get_sparse_tablenames(self.origin_main_program,
                                                  True)
        # Only sparse entries advance this counter; dense ones do not.
        sparse_count = 0

        if self.is_geo_mode():
            for param, grad in self.origin_sparse_pairs:
                distributed = param.name in distributed_names

                param_ctx = self.build_ctx(param, self.param_var_mapping,
                                           False, True, True, distributed)
                grad_ctx = self.build_ctx(grad, self.grad_var_mapping, True,
                                          True, True, distributed)

                # Everything comes from the parameter context except the
                # origin variable names, which come from the gradient.
                geo_ctx = CommContext(
                    param_ctx.var_name(),
                    param_ctx.split_varnames(),
                    param_ctx.split_endpoints(),
                    param_ctx.sections(),
                    grad_ctx.origin_varnames(),
                    param_ctx.trainer_id(),
                    param_ctx.aggregate(),
                    param_ctx.is_sparse(),
                    param_ctx.is_distributed())

                contexts[geo_ctx.var_name()] = geo_ctx
                sparse_count += 1

            step_name, step_ctx = self._step_ctx(sparse_count)
            contexts[step_name] = step_ctx
            return contexts

        for dense_pair in self.merged_dense_pairs:
            dense_grad = dense_pair[1]
            dense_ctx = self.build_ctx(dense_grad, self.grad_var_mapping,
                                       True, False, True)
            contexts[dense_ctx.var_name()] = dense_ctx

        for sparse_pair in self.merged_sparse_pairs:
            sparse_param = sparse_pair[0]
            sparse_grad = sparse_pair[1]
            distributed = sparse_param.merged_var.name in distributed_names

            sparse_ctx = self.build_ctx(sparse_grad, self.grad_var_mapping,
                                        True, True, True, distributed)
            contexts[sparse_ctx.var_name()] = sparse_ctx
            sparse_count += 1

        if self.is_async_mode():
            step_name, step_ctx = self._step_ctx(sparse_count)
            contexts[step_name] = step_ctx

        return contexts
예제 #2
0
    def get_the_one_send_context(self,
                                 split_dense_table=False,
                                 use_origin_program=False,
                                 ep_list=None):
        """Build the send contexts for dense and sparse gradients.

        Args:
            split_dense_table: forwarded to ``get_dense_send_context``.
            use_origin_program: when true, read the ``origin_merged_*``
                pair lists instead of the ``merged_*`` ones.
            ep_list: endpoints used for every sparse context; defaults to
                ``["127.0.0.1:6071"]`` when ``None``.

        Returns:
            dict: context name -> ``CommContext``. Dense contexts are added
            by ``get_dense_send_context``, then one context per merged
            sparse pair, then a step context when ``self.tensor_table_dict``
            is non-empty and the role maker reports a worker.
        """
        if ep_list is None:
            ep_list = ["127.0.0.1:6071"]
        send_ctx = {}
        trainer_id = self.get_role_id()
        idx = 0

        if use_origin_program:
            merged_dense_pairs = self.origin_merged_dense_pairs
            merged_sparse_pairs = self.origin_merged_sparse_pairs
        else:
            merged_dense_pairs = self.merged_dense_pairs
            merged_sparse_pairs = self.merged_sparse_pairs

        # get_dense_send_context returns the next free index rather than a
        # count, so the result is assigned instead of accumulated.
        idx = self.get_dense_send_context(send_ctx, idx, merged_dense_pairs,
                                          trainer_id, split_dense_table)

        distributed_varnames = get_sparse_tablenames(self.origin_main_program,
                                                     True)
        for param, grad in merged_sparse_pairs:
            grad_name = grad.merged_var.name
            param_name = param.merged_var.name

            # One "<param>.block<i>" name per endpoint.
            splited_varname = [
                "{}.block{}".format(param_name, i)
                for i in range(len(ep_list))
            ]

            is_distributed = param_name in distributed_varnames

            var = self.origin_main_program.global_block().vars[grad_name]

            # Distributed tables report a leading dimension of 0.
            shape = list(var.shape)
            if is_distributed:
                shape[0] = 0

            sparse_ctx = CommContext(grad_name, splited_varname, ep_list,
                                     shape, [grad_name], trainer_id, True,
                                     True, is_distributed, idx, False)

            idx += 1
            send_ctx[sparse_ctx.var_name()] = sparse_ctx

        if len(self.tensor_table_dict) > 0 and self.role_maker._is_worker():
            name, ctx = self._step_ctx(idx)
            send_ctx[name] = ctx

        return send_ctx
예제 #3
0
 def _step_ctx(self, idx):
     """Create the communication context for the step counter.

     Args:
         idx: passed through unchanged to ``CommContext``.

     Returns:
         tuple: ``(STEP_COUNTER, ctx)``, where ``ctx`` has one entry named
         ``STEP_COUNTER`` with section size 1 for each endpoint returned
         by ``self.get_ps_endpoints()``.
     """
     counter_name = STEP_COUNTER
     role_id = self.get_role_id()
     ps_endpoints = self.get_ps_endpoints()
     server_count = len(ps_endpoints)

     step_ctx = CommContext(
         counter_name,
         [counter_name] * server_count,
         ps_endpoints,
         [1] * server_count,
         [counter_name],
         role_id,
         True,
         False,
         False,
         idx,
         True)
     return counter_name, step_ctx
예제 #4
0
 def get_dense_send_context(self,
                            send_ctx,
                            idx,
                            merged_dense_pairs,
                            trainer_id,
                            split_dense_table=False):
     """Register dense-gradient send contexts in ``send_ctx``.

     Args:
         send_ctx: dict that receives the new contexts (mutated in place).
         idx: index given to the first context created here.
         merged_dense_pairs: sequence of pairs whose second element
             exposes ``merged_var.name``.
         trainer_id: trainer id stored in every context created here.
         split_dense_table: when false, all gradients are folded into one
             ``"Dense@Grad"`` context whose size is the sum of their
             element counts; when true, each gradient gets its own
             context keyed by its name.

     Returns:
         The next unused index: ``idx`` unchanged when there are no pairs,
         otherwise ``idx`` plus the number of contexts added.
     """
     if len(merged_dense_pairs) < 1:
         return idx

     def _numel(var_name):
         # Element count of the named variable in the origin program.
         # The initial value 1 keeps a zero-dimensional shape from
         # raising TypeError inside reduce.
         var = self.origin_main_program.global_block().vars[var_name]
         return reduce(lambda x, y: x * y, var.shape, 1)

     aggregate = True

     if not split_dense_table:
         origin_varnames = []
         var_numel = 0
         for merged in merged_dense_pairs:
             origin_varname = merged[1].merged_var.name
             origin_varnames.append(origin_varname)
             var_numel += _numel(origin_varname)

         grad_name = "Dense@Grad"
         # The caller-supplied trainer_id is used here, exactly as in the
         # split branch below, instead of being overwritten.
         send_ctx[grad_name] = CommContext(
             grad_name, [grad_name], ["127.0.0.1:6071"], [var_numel],
             origin_varnames, trainer_id, aggregate, False, False, idx,
             False)
         return idx + 1

     for merged in merged_dense_pairs:
         grad_name = merged[1].merged_var.name
         send_ctx[grad_name] = CommContext(
             grad_name, [grad_name], ["127.0.0.1:6071"],
             [_numel(grad_name)], [grad_name], trainer_id, aggregate,
             False, False, idx, False)
         idx += 1
     return idx
예제 #5
0
    def build_ctx(self,
                  vars,
                  mapping,
                  is_grad,
                  is_sparse,
                  is_send,
                  is_distributed=False):
        """Assemble a ``CommContext`` for a variable from its slices.

        Args:
            vars: a ``MergedVariable`` (name taken from ``merged_var.name``,
                origin names from ``ordered_vars``) or any object exposing
                ``name``.
            mapping: lookup from that name to its slices; each slice must
                expose ``name`` and ``shape``.
            is_grad: enables the ``.trainer_<id>`` suffix in sync mode with
                more than one trainer.
            is_sparse: forwarded to ``CommContext``.
            is_send: in GEO mode, selects the ``.delta``-suffixed names.
            is_distributed: forwarded to ``CommContext``.

        Returns:
            CommContext: built from the slice names, their endpoints, their
            leading dimensions and the origin variable names.
        """

        def locate_slices(pieces):
            # Returns, for the given slices: the names to send under, the
            # endpoints that hold them, and each slice's leading dimension.
            send_names, endpoints, first_dims = [], [], []

            for piece in pieces:
                if self.is_geo_mode():
                    if is_send:
                        send_name = "{}.delta".format(piece.name)
                    else:
                        send_name = piece.name
                elif (is_grad and self.is_sync_mode() and
                      self.get_trainers() > 1):
                    send_name = "{}.trainer_{}".format(piece.name,
                                                       self.get_role_id())
                else:
                    send_name = piece.name
                send_names.append(send_name)

                first_dims.append(piece.shape[0])

                # An endpoint is recorded once per slice if any of its
                # params or grads carries the slice's name.
                for ep, pairs in self.param_grad_ep_mapping.items():
                    params, grads = pairs["params"], pairs["grads"]
                    if any(piece.name == held.name
                           for held in params + grads):
                        endpoints.append(ep)
            return send_names, endpoints, first_dims

        is_merged = isinstance(vars, MergedVariable)
        name = vars.merged_var.name if is_merged else vars.name

        names, eps, sections = locate_slices(mapping[name])

        if is_merged:
            origin_varnames = [member.name for member in vars.ordered_vars]
        else:
            origin_varnames = [vars.name]

        role_id = self.get_role_id()
        return CommContext(name, names, eps, sections, origin_varnames,
                           role_id, True, is_sparse, is_distributed)
예제 #6
0
    def get_the_one_trainer_send_context(self, split_dense_table):
        """Build the trainer's send contexts.

        Outside GEO mode this delegates to
        ``self.get_the_one_send_context(split_dense_table)``. In GEO mode
        one sparse context is created per pair in
        ``self.merged_sparse_pairs``, followed by a step context when
        ``self.tensor_table_dict`` is non-empty and the role maker reports
        a worker.

        Args:
            split_dense_table: only used on the non-GEO path, where it is
                forwarded unchanged.

        Returns:
            dict: context name -> ``CommContext``.

        Raises:
            ValueError: in GEO mode when no sparse pair produced a context.
        """
        if not self.is_geo_mode():
            return self.get_the_one_send_context(split_dense_table)

        send_ctx = {}
        trainer_id = self.get_role_id()
        idx = 0

        distributed_varnames = get_sparse_tablenames(
            self.origin_main_program, True)
        for param, grad in self.merged_sparse_pairs:
            grad_name = grad.merged_var.name
            param_name = param.merged_var.name
            is_distributed = param_name in distributed_varnames

            var = self.origin_main_program.global_block().vars[grad_name]
            # Product of every dimension after the first. The initial
            # value 1 makes a one-dimensional shape yield 1 instead of
            # raising TypeError on the empty slice.
            var_numel = reduce(lambda x, y: x * y, var.shape[1:], 1)

            sparse_ctx = CommContext(grad_name, [grad_name],
                                     ["127.0.0.1:6071"], [var_numel],
                                     [grad_name], trainer_id, True, True,
                                     is_distributed, idx, False)
            idx += 1
            send_ctx[sparse_ctx.var_name()] = sparse_ctx

        if not send_ctx:
            raise ValueError(
                "GeoSGD require sparse parameters in your net.")

        if len(self.tensor_table_dict) > 0 and self.role_maker._is_worker():
            name, ctx = self._step_ctx(idx)
            send_ctx[name] = ctx

        return send_ctx