def do_nccl_all_reduce_start(self, cmd, *args, **kwargs):
    """Issue the 'fn_nccl_all_reduce_start' RPC through ``do_rpc_call_wkrs``.

    ``cmd`` and every extra keyword argument are forwarded unchanged;
    extra positional arguments are accepted but ignored.

    Returns whatever ``do_rpc_call_wkrs`` returns.
    """
    ilog.debug(500002, '>>> do_nccl_all_reduce_start')
    return self.do_rpc_call_wkrs('fn_nccl_all_reduce_start',
                                 cmd=cmd, **kwargs)
def do_dt_sampler_sync_start(self, cmd, *args, **kwargs):
    """Issue the 'fn_dt_sampler_sync_start' RPC through ``do_rpc_call_wkrs``.

    ``cmd`` is passed as a keyword together with any additional keyword
    arguments; ``*args`` is accepted for signature compatibility and
    not used.

    Returns the result of ``do_rpc_call_wkrs``.
    """
    ilog.debug(500002, '>>> do_dt_sampler_sync_start')
    rpc_result = self.do_rpc_call_wkrs(
        'fn_dt_sampler_sync_start', cmd=cmd, **kwargs)
    return rpc_result
def do_one_mp_step_calc_loss(self, cmd, pre_ret, *args, **kwargs):
    """Issue the 'fn_one_mp_step_calc_loss' point-to-point RPC.

    The call goes through ``do_rpc_call_wkrs_p2p`` from the worker rank
    mapped to net-map order 0 to the worker rank mapped to the order
    reported in ``pre_ret['net_map_order']``, and carries ``cmd``,
    ``self.mp_target_shps`` (as ``target_shps``) and any extra keyword
    arguments.

    Raises:
        AssertionError: if the reported order differs from
            ``self.one_mp_step_order_i`` or ``self._is_last_part()`` is
            falsy.

    Returns whatever ``do_rpc_call_wkrs_p2p`` returns.
    """
    ilog.debug(ILOG_I_MP_SVR_SCH, '>>> do_one_mp_step_calc_loss')
    # NOTE: an earlier variant sent this to a single worker through
    # do_rpc_call_wkr; it has been superseded by the p2p call below.
    target_shps = self.mp_target_shps
    # TODO: decide how to handle an empty/missing target_shps.
    order = pre_ret.get('net_map_order')
    assert order == self.one_mp_step_order_i, 'wrong pre_net_map_order'
    assert self._is_last_part(), 'not is the last part'
    src_rank = self._one_mp_step_get_worker_rank_from_net_map(0)
    dst_rank = self._one_mp_step_get_worker_rank_from_net_map(order)
    return self.do_rpc_call_wkrs_p2p(
        src_rank, dst_rank, 'fn_one_mp_step_calc_loss',
        cmd=cmd, target_shps=target_shps, **kwargs)
def do_cre_distdt_indices(self, wkr_rank=0, cfg=None, *args, **kwargs):
    """Issue the 'fn_cre_distdt_indices' RPC to a single worker.

    Args:
        wkr_rank: rank used to look up the worker URL via
            ``self.worker_url``; defaults to 0.
        cfg: configuration mapping sent as the ``cfg`` keyword of the
            RPC.  ``None`` (the default) is replaced by a fresh empty
            dict on every call, so no dict object is shared between
            calls.
        *args, **kwargs: accepted for signature compatibility; not
            forwarded.

    Returns whatever ``do_rpc_call`` returns.
    """
    ilog.debug(500002, '>>> do_cre_distdt_indices')
    if cfg is None:
        # Avoid the shared-mutable-default pitfall of ``cfg={}``.
        cfg = {}
    wkr = self.worker_url(wkr_rank)
    cli = self.worker_clis[wkr]
    return self.do_rpc_call(cli, 'fn_cre_distdt_indices', cfg=cfg)
def do_cre_nccl_nuid(self, wkr_rank=0, cfg=None, *args, **kwargs):
    """Issue the 'fn_cre_nccl_nuid' RPC to a single worker.

    Args:
        wkr_rank: rank used to look up the worker URL via
            ``self.worker_url``; defaults to 0.
        cfg: configuration mapping sent as the ``cfg`` keyword of the
            RPC.  ``None`` (the default) is replaced by a fresh empty
            dict on every call, so no dict object is shared between
            calls.
        *args, **kwargs: accepted for signature compatibility; not
            forwarded.

    Returns whatever ``do_rpc_call`` returns.
    """
    ilog.debug(500002, '>>> do_cre_nccl_nuid')
    if cfg is None:
        # Avoid the shared-mutable-default pitfall of ``cfg={}``.
        cfg = {}
    wkr = self.worker_url(wkr_rank)
    cli = self.worker_clis[wkr]
    return self.do_rpc_call(cli, 'fn_cre_nccl_nuid', cfg=cfg)
def do_one_mp_step_backward(self, cmd, *args, **kwargs):
    """Issue the 'fn_one_mp_step_backward' RPC to one worker.

    The target rank is whatever
    ``_one_mp_step_get_worker_rank_from_net_map()`` returns when called
    without arguments.  ``cmd`` and extra keyword arguments are
    forwarded; ``*args`` is ignored.

    Returns the result of ``do_rpc_call_wkr``.
    """
    ilog.debug(ILOG_I_MP_SVR_SCH, '>>> do_one_mp_step_backward')
    target_rank = self._one_mp_step_get_worker_rank_from_net_map()
    return self.do_rpc_call_wkr(
        target_rank, 'fn_one_mp_step_backward', cmd=cmd, **kwargs)
def fn_cre_nccl(self, msg, *args, **kwargs):
    """RPC handler: queue state 2 with the request payload and reply 'ok'.

    ``msg`` is unpacked with ``msg_unpck``; its 'msg' entry (or an
    empty dict when missing/falsy) is queued via ``q_put_state``.

    Returns the packed 'fn_cre_nccl' / 'ok' reply.
    """
    ilog.debug(ILOG_I_MP_SVR_WKR, '>>> fn_cre_nccl')
    payload = self.msg_unpck(msg).get('msg') or {}
    self.q_put_state(2, payload)
    reply = self.mk_rpc_ret('fn_cre_nccl', 'ok')
    return self.msg_pck(reply)
def fn_reset_distdt_indices(self, msg, *args, **kwargs):
    """RPC handler: queue state 52 with the request payload and reply 'ok'.

    The payload is the 'msg' entry of the unpacked request, or an empty
    dict when that entry is missing or falsy.

    Returns the packed 'fn_reset_distdt_indices' / 'ok' reply.
    """
    ilog.debug(500001, '>>> fn_reset_distdt_indices')
    request = self.msg_unpck(msg)
    payload = request.get('msg') or {}
    self.q_put_state(52, payload)
    reply = self.mk_rpc_ret('fn_reset_distdt_indices', 'ok')
    return self.msg_pck(reply)
def fn_next_epoch_start(self, msg, *args, **kwargs):
    """RPC handler: queue state 13 with the request payload and reply 'ok'.

    The payload is the 'msg' entry of the unpacked request, or an empty
    dict when that entry is missing or falsy.

    Returns the packed 'fn_next_epoch_start' / 'ok' reply.
    """
    ilog.debug(500001, '>>> fn_next_epoch_start')
    payload = self.msg_unpck(msg).get('msg') or {}
    self.q_put_state(13, payload)
    reply = self.mk_rpc_ret('fn_next_epoch_start', 'ok')
    return self.msg_pck(reply)
def fn_init_nccl_comm(self, msg, *args, **kwargs):
    """RPC handler: queue state 4 with the request payload and reply 'ok'.

    The payload is the 'msg' entry of the unpacked request, or an empty
    dict when that entry is missing or falsy.

    Returns the packed 'fn_init_nccl_comm' / 'ok' reply.
    """
    ilog.debug(500001, '>>> fn_init_nccl_comm')
    request = self.msg_unpck(msg)
    payload = request.get('msg') or {}
    self.q_put_state(4, payload)
    reply = self.mk_rpc_ret('fn_init_nccl_comm', 'ok')
    return self.msg_pck(reply)
def fn_one_mp_step_backward(self, msg, *args, **kwargs):
    """RPC handler: queue the backward-step state and reply 'ok'.

    The queued state is ``Q_STATE_HEAD_MP_WKR + 225``; its argument is
    the 'msg' entry of the unpacked request, or an empty dict when that
    entry is missing or falsy.

    Returns the packed 'fn_one_mp_step_backward' / 'ok' reply.
    """
    ilog.debug(ILOG_I_MP_SVR_WKR, '>>> fn_one_mp_step_backward')
    payload = self.msg_unpck(msg).get('msg') or {}
    self.q_put_state(Q_STATE_HEAD_MP_WKR + 225, payload)
    reply = self.mk_rpc_ret('fn_one_mp_step_backward', 'ok')
    return self.msg_pck(reply)
def fn_stop_train(self, msg, *args, **kwargs):
    """RPC handler: queue state 20 with the request payload and reply 'ok'.

    The payload is the 'msg' entry of the unpacked request, or an empty
    dict when that entry is missing or falsy.

    Returns the packed 'fn_stop_train' / 'ok' reply.
    """
    ilog.debug(500001, '>>> fn_stop_train')
    request = self.msg_unpck(msg)
    payload = request.get('msg') or {}
    # NOTE: the handler only queues the state; a direct
    # self.trainer_stop() call here was disabled.
    self.q_put_state(20, payload)
    reply = self.mk_rpc_ret('fn_stop_train', 'ok')
    return self.msg_pck(reply)
def fn_start_epoch(self, msg, *args, **kwargs):
    """RPC handler: queue state 5 with the request payload and reply 'ok'.

    The payload is the 'msg' entry of the unpacked request, or an empty
    dict when that entry is missing or falsy.

    Returns the packed 'fn_start_epoch' / 'ok' reply.
    """
    ilog.debug(500001, '>>> fn_start_epoch')
    payload = self.msg_unpck(msg).get('msg') or {}
    # NOTE: the handler only queues the state; a direct
    # self.trnr.start_epoch() call here was disabled.
    self.q_put_state(5, payload)
    reply = self.mk_rpc_ret('fn_start_epoch', 'ok')
    return self.msg_pck(reply)
def fn_nccl_all_reduce_start(self, msg, *args, **kwargs):
    """RPC handler: queue state 12 with the request payload and reply 'ok'.

    The payload is the 'msg' entry of the unpacked request, or an empty
    dict when that entry is missing or falsy.

    Returns the packed 'fn_nccl_all_reduce_start' / 'ok' reply.
    """
    ilog.debug(500001, '>>> fn_nccl_all_reduce_start')
    request = self.msg_unpck(msg)
    payload = request.get('msg') or {}
    # NOTE: the handler only queues the state; direct calls into
    # self.trnr (do_params_nccl_all_reduce / check_msg) were disabled.
    self.q_put_state(12, payload)
    reply = self.mk_rpc_ret('fn_nccl_all_reduce_start', 'ok')
    return self.msg_pck(reply)
def fn_one_step_start(self, msg, *args, **kwargs):
    """RPC handler: queue state 11 with the request payload and reply 'ok'.

    The payload is the 'msg' entry of the unpacked request, or an empty
    dict when that entry is missing or falsy.

    Returns the packed 'fn_one_step_start' / 'ok' reply.
    """
    ilog.debug(500001, '>>> fn_one_step_start')
    payload = self.msg_unpck(msg).get('msg') or {}
    # NOTE: the handler only queues the state; direct calls into
    # self.trnr (one_step / check_msg) were disabled.
    self.q_put_state(11, payload)
    reply = self.mk_rpc_ret('fn_one_step_start', 'ok')
    return self.msg_pck(reply)
def fn_nccl_all_reduce(self, msg, *args, **kwargs):
    """RPC handler: run ``self.trnr.check_msg`` on the payload and reply.

    Unlike the ``*_start`` handlers this one does not queue a state; it
    calls the trainer directly and returns its result as ``r`` in the
    packed 'fn_nccl_all_reduce' / 'ok' reply.

    NOTE(review): the original carried a bare '-x-' marker on this
    method; its meaning is not documented here -- confirm whether the
    handler is still in use.
    """
    ilog.debug(500001, '>>> fn_nccl_all_reduce')
    payload = self.msg_unpck(msg).get('msg') or {}
    check_result = self.trnr.check_msg(payload)
    reply = self.mk_rpc_ret('fn_nccl_all_reduce', 'ok', r=check_result)
    return self.msg_pck(reply)
def fn_cre_nccl_nuid(self, msg, *args, **kwargs):
    """RPC handler: create an NCCL unique id and return it serialized.

    The payload ('msg' entry of the unpacked request, or an empty dict)
    is handed to ``self.trnr.cre_nccl_nuid``; the result is serialized
    with ``msgpck.pkl_dumps`` and returned as ``nuid`` in the packed
    'fn_cre_nccl_nuid' / 'ok' reply.  No state is queued.
    """
    ilog.debug(500001, '>>> fn_cre_nccl_nuid')
    request = self.msg_unpck(msg)
    payload = request.get('msg') or {}
    raw_nuid = self.trnr.cre_nccl_nuid(payload)
    packed_nuid = msgpck.pkl_dumps(raw_nuid)
    reply = self.mk_rpc_ret('fn_cre_nccl_nuid', 'ok', nuid=packed_nuid)
    return self.msg_pck(reply)
def fn_next_epoch(self, msg, *args, **kwargs):
    """RPC handler: run ``self.trnr.check_msg`` on the payload and reply.

    No state is queued; the trainer result is returned as ``r`` in the
    packed 'fn_next_epoch' / 'ok' reply.
    """
    ilog.debug(500001, '>>> fn_next_epoch')
    payload = self.msg_unpck(msg).get('msg') or {}
    check_result = self.trnr.check_msg(payload)
    reply = self.mk_rpc_ret('fn_next_epoch', 'ok', r=check_result)
    return self.msg_pck(reply)
def fn_one_mp_step_backwarded(self, msg, *args, **kwargs):
    """RPC handler: queue the 'backwarded' state with sender info.

    The unpacked request supplies 'role', 'rank' and 'msg' (the latter
    replaced by an empty dict when missing/falsy).  They are queued as
    ``q_put_state(Q_STATE_HEAD_MP_SCH + 226, msg, rank, role)``.

    Returns the packed 'fn_one_mp_step_backwarded' / 'ok' reply.
    """
    ilog.debug(ILOG_I_MP_SVR_SCH, '>>> fn_one_mp_step_backwarded')
    envelope = self.msg_unpck(msg)
    sender_role = envelope.get('role')
    sender_rank = envelope.get('rank')
    payload = envelope.get('msg') or {}
    # TODO: validate the sender/payload before queueing.
    self.q_put_state(Q_STATE_HEAD_MP_SCH + 226,
                     payload, sender_rank, sender_role)
    reply = self.mk_rpc_ret('fn_one_mp_step_backwarded', 'ok')
    return self.msg_pck(reply)
def fn_cre_distdt_indices(self, msg, *args, **kwargs):
    """RPC handler: create distributed-data indices and return them.

    The payload ('msg' entry of the unpacked request, or an empty dict)
    is handed to ``self.trnr.cre_distdt_indices``; the result is placed
    as ``dt_indices`` in the packed 'fn_cre_distdt_indices' / 'ok'
    reply without extra serialization.  No state is queued.
    """
    ilog.debug(500001, '>>> fn_cre_distdt_indices')
    payload = self.msg_unpck(msg).get('msg') or {}
    indices = self.trnr.cre_distdt_indices(payload)
    reply = self.mk_rpc_ret('fn_cre_distdt_indices', 'ok',
                            dt_indices=indices)
    return self.msg_pck(reply)
def fn_next_epoch_end(self, msg, *args, **kwargs):
    """RPC handler: queue state 13 with the payload and sender info.

    The unpacked request supplies 'role', 'rank' and 'msg' (the latter
    replaced by an empty dict when missing/falsy).  They are queued as
    ``q_put_state(13, msg, rank, role)``.

    Returns the packed 'fn_next_epoch_end' / 'ok' reply.
    """
    ilog.debug(500002, '>>> fn_next_epoch_end')
    envelope = self.msg_unpck(msg)
    sender_role = envelope.get('role')
    sender_rank = envelope.get('rank')
    payload = envelope.get('msg') or {}
    # TODO: validate the sender/payload before queueing.
    self.q_put_state(13, payload, sender_rank, sender_role)
    reply = self.mk_rpc_ret('fn_next_epoch_end', 'ok')
    return self.msg_pck(reply)
def fn_worker_reg(self, msg, *args, **kwargs):
    """RPC handler: queue state 41 with the payload and sender info.

    The unpacked request supplies 'role', 'rank' and 'msg' (the latter
    replaced by an empty dict when missing/falsy).  They are queued as
    ``q_put_state(41, msg, rank, role)``.

    Returns the packed 'fn_worker_reg' / 'ok' reply.
    """
    ilog.debug(ILOG_I_MP_SVR_SCH, '>>> fn_worker_reg')
    envelope = self.msg_unpck(msg)
    sender_role = envelope.get('role')
    sender_rank = envelope.get('rank')
    payload = envelope.get('msg') or {}
    # TODO: validate the sender/payload before queueing.
    self.q_put_state(41, payload, sender_rank, sender_role)
    reply = self.mk_rpc_ret('fn_worker_reg', 'ok')
    return self.msg_pck(reply)
def do_one_mp_step_trans_grads(self, cmd, pre_ret, *args, **kwargs):
    """Issue the 'fn_one_mp_step_trans_grads' point-to-point RPC.

    ``pre_ret`` supplies 'grads_shps' and 'net_map_order'.  The call
    goes through ``do_rpc_call_wkrs_p2p`` from the worker rank mapped
    to that order to the rank mapped to ``order - 1``, carrying
    ``cmd``, ``grads_shps`` and any extra keyword arguments.

    Raises:
        AssertionError: if the reported order differs from
            ``self.one_mp_step_order_i``.

    Returns whatever ``do_rpc_call_wkrs_p2p`` returns.
    """
    ilog.debug(ILOG_I_MP_SVR_SCH, '>>> do_one_mp_step_trans_grads')
    grads_shps = pre_ret.get('grads_shps')
    # TODO: decide how to handle an empty/missing grads_shps.
    order = pre_ret.get('net_map_order')
    assert order == self.one_mp_step_order_i, 'wrong pre_net_map_order'
    src_rank = self._one_mp_step_get_worker_rank_from_net_map(order)
    dst_rank = self._one_mp_step_get_worker_rank_from_net_map(order - 1)
    return self.do_rpc_call_wkrs_p2p(
        src_rank, dst_rank, 'fn_one_mp_step_trans_grads',
        cmd=cmd, grads_shps=grads_shps, **kwargs)
def check_state(self, *args, **kwargs):
    """Fetch one queued state and perform the action associated with it.

    ``q_get_state(timeout=1)`` must yield a 3-item ``(state, args,
    kwargs)`` value.  Actions per state (labels are the strings this
    file passes to ``q_put_state``):

    * 40 -- poll ``svr.check_workers_reg()`` and sleep one second.
    * 41 -- a worker registered (args: msg, rank, role); start training
      via ``init_train()`` once ``svr.check_workers_reg`` reports
      truthy, otherwise ``wait_for_workers_reg()``.
    * 1 -- ``svr.do_init_train()`` then queue 2 ('cre_nccl').
    * 2 -- ``svr.do_cre_nccl()`` then queue 3 ('cre_nccl_nuid').
    * 3 -- obtain the nuid, ``svr.do_init_nccl_comm(nuid)``, then queue
      30 ('weights_sync').
    * 51 -- obtain dt_indices, ``svr.do_reset_distdt_indices``, then
      queue 6 ('one_step').
    * 4, 5 -- no action.
    * 30 / 32 / 6 / 7 / 8 -- call the matching ``svr.do_*_start`` with
      ``args[0]`` as cmd and ``typ=self.nccl_allreduce_typ``.
    * 31 / 33 / 11 / 12 / 13 -- a worker reported a result (args: msg,
      rank, role); ``svr.trainer_check_ret`` picks the next state,
      which is queued.  State 33 then sleeps 1 ms.
    * 20 -- sleep 10 ms, ``svr.do_stop_train()``, then ``stop()``.
    * 0 and any other value -- queue 20 ('train_stop') when
      ``svr.check_train_stop()`` is truthy.

    ``*args`` / ``**kwargs`` are accepted and ignored.  Returns None.
    """
    svr_state, sa, _skw = self.q_get_state(timeout=1)
    ilog.debug(500003, '=== svr_state', svr_state)
    svr = self.svr
    typ = self.nccl_allreduce_typ

    if svr_state == 40:
        svr.check_workers_reg()
        # TODO: if all workers registered, queue state 1 ('init_train').
        self.g_sleep(1.0)
    elif svr_state == 41:
        if sa:
            msg, rank, role = sa
            if svr.check_workers_reg(rank, role, msg):
                # all workers have already registered
                self.init_train()
            else:
                self.wait_for_workers_reg()
    elif svr_state == 1:
        svr.do_init_train()
        self.q_put_state(2, 'cre_nccl')
    elif svr_state == 2:
        svr.do_cre_nccl()
        self.q_put_state(3, 'cre_nccl_nuid')
    elif svr_state == 3:
        reply = svr.do_cre_nccl_nuid()
        # TODO: check reply.get('ret')
        nuid = reply.get('msg', {}).get('nuid')
        svr.do_init_nccl_comm(nuid)
        self.q_put_state(30, 'weights_sync')
    elif svr_state == 51:  # do every epoch
        reply = svr.do_cre_distdt_indices()  # NOTE: only call rank-0
        # TODO: check reply.get('ret')
        dt_indices = reply.get('msg', {}).get('dt_indices')
        svr.do_reset_distdt_indices(dt_indices)
        self.q_put_state(6, 'one_step')
    elif svr_state in (4, 5):
        pass
    elif svr_state == 30:
        svr.do_weights_sync_start(sa[0], typ=typ)
    elif svr_state == 32:
        svr.do_dt_sampler_sync_start(sa[0], typ=typ)
    elif svr_state == 6:
        # NOTE: the wait for the zrpc call to return (for
        # dt_sampler_sync) happens in state 33, not here.
        svr.do_one_step_start(sa[0], typ=typ)
    elif svr_state == 7:
        svr.do_nccl_all_reduce_start(sa[0], typ=typ)
    elif svr_state == 8:
        svr.do_next_epoch_start(sa[0], typ=typ)
    elif svr_state in (31, 33, 11, 12, 13):
        # A worker reported back; let the server decide the next state.
        if sa:
            msg, rank, role = sa
            next_state, tag = svr.trainer_check_ret(
                msg.get('ret'), rank, role)
            self.q_put_state(next_state, tag)
            if svr_state == 33:
                # NOTE: wait (for dt_sampler_sync) <3.3>
                self.g_sleep(0.001)
    elif svr_state == 20:
        # NOTE: wait the zrpc call to return
        self.g_sleep(0.01)
        svr.do_stop_train()
        self.stop()  # service loop stop
    else:
        # State 0 and every unlisted state: only watch for a stop request.
        if svr.check_train_stop():
            self.q_put_state(20, 'train_stop')
def do_stop_train(self, *args, **kwargs):
    """Issue the 'fn_stop_train' RPC through ``do_rpc_call_wkrs``.

    No arguments are forwarded; ``*args`` / ``**kwargs`` are accepted
    and ignored.  Returns whatever ``do_rpc_call_wkrs`` returns.
    """
    ilog.debug(500002, '>>> do_stop_train')
    return self.do_rpc_call_wkrs('fn_stop_train')
def do_cre_nccl(self, cfg=None, *args, **kwargs):
    """Issue the 'fn_cre_nccl' RPC through ``do_rpc_call_wkrs``.

    Args:
        cfg: configuration mapping sent as the ``cfg`` keyword of the
            RPC.  ``None`` (the default) is replaced by a fresh empty
            dict on every call, so no dict object is shared between
            calls.
        *args, **kwargs: accepted for signature compatibility; not
            forwarded.

    Returns whatever ``do_rpc_call_wkrs`` returns.
    """
    ilog.debug(500002, '>>> do_cre_nccl')
    if cfg is None:
        # Avoid the shared-mutable-default pitfall of ``cfg={}``.
        cfg = {}
    return self.do_rpc_call_wkrs('fn_cre_nccl', cfg=cfg)
def do_next_epoch(self, cmd, *args, **kwargs):
    """Issue the 'fn_next_epoch' RPC through ``do_rpc_call_wkrs``.

    NOTE: not used now.

    ``cmd`` and extra keyword arguments are forwarded; ``*args`` is
    ignored.  Returns whatever ``do_rpc_call_wkrs`` returns.
    """
    ilog.debug(500002, '>>> do_next_epoch')
    rpc_result = self.do_rpc_call_wkrs('fn_next_epoch', cmd=cmd, **kwargs)
    return rpc_result
def do_next_epoch_start(self, cmd, *args, **kwargs):
    """Issue the 'fn_next_epoch_start' RPC through ``do_rpc_call_wkrs``.

    ``cmd`` and extra keyword arguments are forwarded; ``*args`` is
    ignored.  Returns whatever ``do_rpc_call_wkrs`` returns.
    """
    ilog.debug(500002, '>>> do_next_epoch_start')
    return self.do_rpc_call_wkrs('fn_next_epoch_start',
                                 cmd=cmd, **kwargs)
def do_reset_distdt_indices(self, dt_indices, *args, **kwargs):
    """Issue the 'fn_reset_distdt_indices' RPC through ``do_rpc_call_wkrs``.

    Only ``dt_indices`` is forwarded (as a keyword); ``*args`` and
    ``**kwargs`` are accepted and ignored.

    Returns whatever ``do_rpc_call_wkrs`` returns.
    """
    ilog.debug(500002, '>>> do_reset_distdt_indices')
    rpc_result = self.do_rpc_call_wkrs(
        'fn_reset_distdt_indices', dt_indices=dt_indices)
    return rpc_result
def do_init_nccl_comm(self, nuid, *args, **kwargs):
    """Issue the 'fn_init_nccl_comm' RPC through ``do_rpc_call_wkrs``.

    Only ``nuid`` is forwarded (as a keyword); ``*args`` and
    ``**kwargs`` are accepted and ignored.

    Returns whatever ``do_rpc_call_wkrs`` returns.
    """
    ilog.debug(500002, '>>> do_init_nccl_comm')
    return self.do_rpc_call_wkrs('fn_init_nccl_comm', nuid=nuid)