def get_trainer_send_context(self):
    # Build the trainer-side send contexts: one CommContext per merged dense
    # gradient and per merged sparse gradient, plus a step-counter context in
    # async mode. In geo mode the contexts are built from the original sparse
    # parameters instead, and the step-counter context is always appended.
    send_ctx = {}
    distributed_varnames = get_sparse_tablenames(self.origin_main_program,
                                                 True)
    idx = 0

    if not self.is_geo_mode():
        for merged in self.merged_dense_pairs:
            grad = merged[1]
            ctx = self.build_ctx(grad, self.grad_var_mapping, True, False,
                                 True)
            send_ctx[ctx.var_name()] = ctx

        for merged in self.merged_sparse_pairs:
            param = merged[0]
            grad = merged[1]

            param_name = param.merged_var.name
            is_distributed = True if param_name in distributed_varnames else False

            ctx = self.build_ctx(grad, self.grad_var_mapping, True, True,
                                 True, is_distributed)
            send_ctx[ctx.var_name()] = ctx
            idx += 1

        if self.is_async_mode():
            name, ctx = self._step_ctx(idx)
            send_ctx[name] = ctx
    else:
        for pairs in self.origin_sparse_pairs:
            param, grad = pairs
            param_name = param.name
            is_distributed = True if param_name in distributed_varnames else False

            param_ctx = self.build_ctx(param, self.param_var_mapping, False,
                                       True, True, is_distributed)
            grad_ctx = self.build_ctx(grad, self.grad_var_mapping, True, True,
                                      True, is_distributed)

            # Route by the parameter's split names/endpoints/sections, but
            # report the gradient's origin varnames.
            ctx = CommContext(param_ctx.var_name(),
                              param_ctx.split_varnames(),
                              param_ctx.split_endpoints(),
                              param_ctx.sections(),
                              grad_ctx.origin_varnames(),
                              param_ctx.trainer_id(),
                              param_ctx.aggregate(),
                              param_ctx.is_sparse(),
                              param_ctx.is_distributed())

            send_ctx[ctx.var_name()] = ctx
            idx += 1

        name, ctx = self._step_ctx(idx)
        send_ctx[name] = ctx

    return send_ctx
def get_the_one_send_context(self,
                             split_dense_table=False,
                             use_origin_program=False,
                             ep_list=None):
    # Dense gradients are handled by get_dense_send_context; every merged
    # sparse gradient gets its own CommContext whose split names follow the
    # "<param>.block<i>" convention, one block per endpoint.
    if ep_list is None:
        ep_list = ["127.0.0.1:6071"]
    send_ctx = {}
    trainer_id = self.get_role_id()
    idx = 0

    merged_dense_pairs = self.origin_merged_dense_pairs if use_origin_program else self.merged_dense_pairs
    merged_sparse_pairs = self.origin_merged_sparse_pairs if use_origin_program else self.merged_sparse_pairs

    idx += self.get_dense_send_context(send_ctx, idx, merged_dense_pairs,
                                       trainer_id, split_dense_table)

    distributed_varnames = get_sparse_tablenames(self.origin_main_program,
                                                 True)
    for merged in merged_sparse_pairs:
        param, grad = merged
        grad_name = grad.merged_var.name
        param_name = param.merged_var.name

        splited_varname = []
        for i in range(len(ep_list)):
            splited_varname.append("{}.block{}".format(param_name, i))

        is_distributed = True if param_name in distributed_varnames else False

        var = self.origin_main_program.global_block().vars[
            grad.merged_var.name]

        # Distributed sparse tables report 0 in the first dim of the shape.
        shape = list(var.shape)
        shape[0] = 0 if is_distributed else shape[0]

        sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape,
                                 [grad_name], trainer_id, True, True,
                                 is_distributed, idx, False)

        idx += 1
        send_ctx[sparse_ctx.var_name()] = sparse_ctx

    if len(self.tensor_table_dict) > 0 and self.role_maker._is_worker():
        name, ctx = self._step_ctx(idx)
        send_ctx[name] = ctx

    return send_ctx
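# NOTE: minimal standalone sketch (not part of the strategy object above) of
# how get_the_one_send_context derives the per-endpoint block names and the
# shape it hands to CommContext for one sparse gradient. The helper name
# `_sketch_sparse_blocks` and the sample values are illustrative only.
def _sketch_sparse_blocks(param_name, shape, ep_list, is_distributed):
    # one "<param>.block<i>" name per parameter-server endpoint
    splited_varname = [
        "{}.block{}".format(param_name, i) for i in range(len(ep_list))
    ]
    # distributed sparse tables report 0 rows in the first dim
    shape = list(shape)
    shape[0] = 0 if is_distributed else shape[0]
    return splited_varname, shape

# e.g. a 1000 x 64 embedding served by two pservers:
# _sketch_sparse_blocks("emb", (1000, 64), ["ep0", "ep1"], False)
# -> (['emb.block0', 'emb.block1'], [1000, 64]); with is_distributed=True the
# shape becomes [0, 64].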
def _step_ctx(self, idx):
    name = STEP_COUNTER
    trainer_id = self.get_role_id()
    endpoints = self.get_ps_endpoints()
    sections = [1] * len(endpoints)
    names = [name] * len(endpoints)
    ctx = CommContext(name, names, endpoints, sections, [name], trainer_id,
                      True, False, False, idx, True)
    return name, ctx
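# NOTE: illustrative sketch of the values _step_ctx feeds into CommContext:
# the step-counter name (the STEP_COUNTER constant used above) is repeated
# once per pserver endpoint, each with a section of length 1. The endpoint
# list and the "step_counter" placeholder name are made up for this example.
def _sketch_step_ctx_args(step_counter_name, endpoints):
    sections = [1] * len(endpoints)
    names = [step_counter_name] * len(endpoints)
    return names, sections

# e.g. _sketch_step_ctx_args("step_counter", ["127.0.0.1:6071", "127.0.0.1:6072"])
# -> (['step_counter', 'step_counter'], [1, 1])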
def get_dense_send_context(self,
                           send_ctx,
                           idx,
                           merged_dense_pairs,
                           trainer_id,
                           split_dense_table=False):
    # Without split_dense_table, all dense gradients are fused into a single
    # "Dense@Grad" context whose one section is their total number of
    # elements; with split_dense_table, each dense gradient keeps its own
    # context. Returns the next free context index.
    if len(merged_dense_pairs) < 1:
        return idx
    if not split_dense_table:
        origin_varnames = []
        var_numel = 0
        for merged in merged_dense_pairs:
            grad = merged[1]
            origin_varnames.append(grad.merged_var.name)
            var = self.origin_main_program.global_block().vars[
                grad.merged_var.name]
            var_numel += reduce(lambda x, y: x * y, var.shape)
        grad_name = "Dense@Grad"
        trainer_id = self.get_role_id()
        aggregate = True
        dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
                                [var_numel], origin_varnames, trainer_id,
                                aggregate, False, False, idx, False)
        send_ctx[grad_name] = dense_ctx
        idx += 1
    else:
        for merged in merged_dense_pairs:
            grad = merged[1]
            origin_varname = grad.merged_var.name
            var = self.origin_main_program.global_block().vars[origin_varname]
            var_numel = reduce(lambda x, y: x * y, var.shape)
            grad_name = origin_varname
            aggregate = True
            dense_ctx = CommContext(grad_name, [grad_name],
                                    ["127.0.0.1:6071"], [var_numel],
                                    [origin_varname], trainer_id, aggregate,
                                    False, False, idx, False)
            send_ctx[grad_name] = dense_ctx
            idx += 1
    return idx
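# NOTE: standalone sketch of the two dense layouts built by
# get_dense_send_context (which, like the code above, relies on
# functools.reduce being imported at module level). The shapes and the
# "dense_grad_<i>" keys below are placeholders; the real code keys the split
# case by each gradient's own variable name.
from functools import reduce


def _sketch_dense_sections(dense_shapes, split_dense_table):
    if not split_dense_table:
        # fused: every dense grad is flattened into one "Dense@Grad" section
        total = sum(reduce(lambda x, y: x * y, shape) for shape in dense_shapes)
        return {"Dense@Grad": [total]}
    # split: one section (its own numel) per dense gradient
    return {
        "dense_grad_{}".format(i): [reduce(lambda x, y: x * y, shape)]
        for i, shape in enumerate(dense_shapes)
    }

# e.g. _sketch_dense_sections([(784, 128), (128,)], split_dense_table=False)
# -> {'Dense@Grad': [100480]}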
def build_ctx(self,
              vars,
              mapping,
              is_grad,
              is_sparse,
              is_send,
              is_distributed=False):
    def get_grad_var_ep(slices):
        # For each split slice: pick the remote variable name (".delta" for
        # geo-mode sends, ".trainer_<id>" for multi-trainer sync gradients),
        # record the slice height as its section, and look up the endpoint
        # that owns it.
        names = []
        eps = []
        sections = []

        for slice in slices:
            if self.is_geo_mode():
                if is_send:
                    names.append("{}.delta".format(slice.name))
                else:
                    names.append(slice.name)
            elif is_grad and self.is_sync_mode() and self.get_trainers() > 1:
                names.append("{}.trainer_{}".format(slice.name,
                                                    self.get_role_id()))
            else:
                names.append(slice.name)

            sections.append(slice.shape[0])

            for ep, pairs in self.param_grad_ep_mapping.items():
                params, grads = pairs["params"], pairs["grads"]
                for var in params + grads:
                    if slice.name == var.name:
                        eps.append(ep)
                        break
        return names, eps, sections

    if isinstance(vars, MergedVariable):
        name = vars.merged_var.name
        slices = mapping[name]
        names, eps, sections = get_grad_var_ep(slices)
        origin_varnames = [var.name for var in vars.ordered_vars]
    else:
        name = vars.name
        slices = mapping[name]
        names, eps, sections = get_grad_var_ep(slices)
        origin_varnames = [vars.name]

    trainer_id = self.get_role_id()
    aggregate = True
    ctx = CommContext(name, names, eps, sections, origin_varnames, trainer_id,
                      aggregate, is_sparse, is_distributed)
    return ctx
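# NOTE: illustrative sketch of the split-variable naming rules applied inside
# build_ctx's get_grad_var_ep. `_sketch_slice_name` is a made-up helper; the
# real logic lives in the nested function above.
def _sketch_slice_name(slice_name, is_geo_mode, is_send, is_grad,
                       is_sync_mode, trainers, trainer_id):
    if is_geo_mode:
        # geo mode sends parameter deltas
        return "{}.delta".format(slice_name) if is_send else slice_name
    if is_grad and is_sync_mode and trainers > 1:
        # each trainer sends its gradient slice under a per-trainer name
        return "{}.trainer_{}".format(slice_name, trainer_id)
    return slice_name

# e.g. a gradient slice "fc_0.w_0.block0" sent by trainer 2 of 4 in sync mode
# maps to "fc_0.w_0.block0.trainer_2"; a geo-mode send maps it to
# "fc_0.w_0.block0.delta".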
def get_the_one_trainer_send_context(self, split_dense_table):
    # Geo mode: one CommContext per merged sparse gradient, sized by the
    # per-row numel (var.shape[1:]); all other modes reuse
    # get_the_one_send_context.
    if self.is_geo_mode():
        send_ctx = {}
        trainer_id = self.get_role_id()
        idx = 0

        distributed_varnames = get_sparse_tablenames(
            self.origin_main_program, True)
        for merged in self.merged_sparse_pairs:
            param, grad = merged
            grad_name = grad.merged_var.name
            param_name = param.merged_var.name
            is_distributed = True if param_name in distributed_varnames else False

            var = self.origin_main_program.global_block().vars[
                grad.merged_var.name]
            var_numel = reduce(lambda x, y: x * y, var.shape[1:])

            sparse_ctx = CommContext(grad_name, [grad_name],
                                     ["127.0.0.1:6071"], [var_numel],
                                     [grad_name], trainer_id, True, True,
                                     is_distributed, idx, False)
            idx += 1
            send_ctx[sparse_ctx.var_name()] = sparse_ctx

        if len(send_ctx) == 0:
            raise ValueError("GeoSGD requires sparse parameters in your net.")

        if len(self.tensor_table_dict) > 0 and self.role_maker._is_worker():
            name, ctx = self._step_ctx(idx)
            send_ctx[name] = ctx

        return send_ctx
    else:
        return self.get_the_one_send_context(split_dense_table)
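# NOTE: small sketch of the sizing used by get_the_one_trainer_send_context in
# geo mode: the section is the numel of a single row (shape[1:]) rather than
# the numel of the whole table. The example shape is made up.
from functools import reduce


def _sketch_geo_sparse_section(shape):
    return reduce(lambda x, y: x * y, shape[1:])

# e.g. a 10000 x 8 embedding table -> a section of 8 elements per row
# print(_sketch_geo_sparse_section((10000, 8)))  # 8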