def _worker(model, fname, sub_module=False, logger=None, para_lock=None, log_success=None):
	"""Save the parameters of *model* to the HDF5 file *fname*.

	When sub_module is True, model.module is saved instead of model
	(the usual multi-GPU wrapper case). If para_lock is given, the save
	runs while holding it. Errors are reported via logger (or print when
	logger is None); on success, log_success is logged when both it and
	logger are provided.
	"""

	saved_ok = True
	target = model.module if sub_module else model

	# single place for the actual dump, used with or without the lock
	def _dump():
		h5save([p.data for p in target.parameters()], fname, h5args=h5args)

	try:
		if para_lock is None:
			_dump()
		else:
			with para_lock:
				_dump()
	except Exception as err:
		saved_ok = False
		if logger is None:
			print(err)
		else:
			logger.info(str(err))
	if saved_ok and (logger is not None) and (log_success is not None):
		logger.info(log_success)
def handle(srcf, rsf, h5args=h5zipargs):
	"""Rewrite the HDF5 file *srcf* into *rsf* using the given h5args.

	If source and destination are the same path, the data is loaded and
	re-saved in place; otherwise the groups are copied file-to-file via
	handle_group.
	"""

	if srcf == rsf:
		# in-place: load everything, then overwrite the same file
		h5save(h5load(srcf, restore_list=False), rsf, h5args=h5args)
		return
	with h5File(srcf, "r") as src_grp, h5File(rsf, "w") as dst_grp:
		handle_group(src_grp, dst_grp, h5args=h5args)
def handle(srcfl, rsf):
	"""Average the checkpoints listed in *srcfl* and save the result to *rsf*.

	Parameters are accumulated in the wider dtypes given by
	secure_type_map (when the original dtype has an entry) and cast back
	to their original dtypes before saving.
	"""

	params = h5load(srcfl[0])
	orig_dtypes = [p.dtype for p in params]
	# accumulation dtype per parameter; None means "accumulate as-is"
	acc_dtypes = [secure_type_map.get(p.dtype, None) for p in params]
	acc = [p if dt is None else p.to(dt) for p, dt in zip(params, acc_dtypes)]
	num_models = 1
	for mfile in srcfl[1:]:
		for base, loaded, dt in zip(acc, h5load(mfile), acc_dtypes):
			base.add_(loaded if dt is None else loaded.to(dt))
		num_models += 1
	scale = float(num_models)
	for base in acc:
		base.div_(scale)
	out = [p if adt is None else p.to(odt) for p, adt, odt in zip(acc, acc_dtypes, orig_dtypes)]
	h5save(out, rsf, h5args=h5zipargs)
def handle(srcf, rsf, h5args=h5zipargs):
	"""Rewrite the HDF5 file *srcf* into *rsf* using the given h5args.

	If source and destination are the same path, the data is loaded and
	re-saved in place; otherwise the groups are copied file-to-file via
	handle_group.
	"""

	if srcf == rsf:
		h5save(h5load(srcf, restore_list=False), rsf, h5args=h5args)
	else:
		# context managers guarantee both files are closed even when
		# handle_group raises (the explicit .close() calls did not)
		with h5py.File(srcf, "r") as sfg, h5py.File(rsf, "w") as rfg:
			handle_group(sfg, rfg, h5args=h5args)
def handle(vcbf, embf, rsf):
	"""Build an embedding matrix for the vocabulary in *vcbf* from the
	text embeddings in *embf*, and save it to *rsf*.

	Words missing from the embedding table fall back to the "<unk>"
	vector when present, else to a zero vector of the embedding width.
	"""

	vcb, nwd = ldvocab(vcbf)
	emb = load_emb_txt(vcb, embf)
	if "<unk>" in emb:
		unkemb = emb["<unk>"]
	else:
		# next(iter(...)) avoids materializing all keys, and the zero
		# vector is only built when actually needed
		unkemb = torch.zeros(next(iter(emb.values())).size(0))
	vcb = reverse_dict(vcb)
	# row i of the output corresponds to vocabulary index i
	rs = [emb.get(vcb[i], unkemb) for i in range(nwd)]
	h5save(torch.stack(rs, 0), rsf)
def save_model(model, fname, sub_module=False, logger=None):
	"""Save the parameters of *model* (or model.module when sub_module)
	to the HDF5 file *fname*; report failures via logger or print."""

	target = model.module if sub_module else model
	try:
		params = [p.data for p in target.parameters()]
		h5save(params, fname)
	except Exception as err:
		if logger is not None:
			logger.info(str(err))
		else:
			print(err)
def save_model(model, fname, sub_module=False, print_func=print, mtyp=None, h5args=h5modelwargs):
	"""Save the parameters of *model* (or model.module when sub_module)
	to *fname*, then optionally run save_model_cleaner on the file when
	mtyp is given. Failures are reported through print_func."""

	target = model.module if sub_module else model
	try:
		params = [p.data for p in target.parameters()]
		h5save(params, fname, h5args=h5args)
		if mtyp is not None:
			save_model_cleaner(fname, mtyp)
	except Exception as err:
		if print_func is not None:
			print_func(str(err))
def _worker(model, fname, sub_module=False, print_func=print, mtyp=None, para_lock=None, log_success=None):
	"""Save the parameters of *model* to *fname* (optionally under
	para_lock), then run save_model_cleaner when mtyp is given.

	Failures are reported through print_func; on success, log_success is
	reported when both it and print_func are provided.
	"""

	saved_ok = True
	target = model.module if sub_module else model

	# save + optional dtype cleanup, shared by the locked and unlocked paths
	def _dump():
		h5save([p.data for p in target.parameters()], fname, h5args=h5args)
		if mtyp is not None:
			save_model_cleaner(fname, mtyp)

	try:
		if para_lock is None:
			_dump()
		else:
			with para_lock:
				_dump()
	except Exception as err:
		saved_ok = False
		if print_func is not None:
			print_func(str(err))
	if saved_ok and (print_func is not None) and (log_success is not None):
		print_func(str(log_success))
def handle(srcfl, rsf):
	"""Average the checkpoints listed in *srcfl* and save the result to *rsf*.

	Parameters are accumulated in wider dtypes (float64 / int64) per the
	local type_map and cast back to their original dtypes before saving.
	"""

	# mask_tensor_type goes last in the literal: if it collides with an
	# earlier key, the later entry wins, same as the original assignment
	type_map = {
		torch.float16: torch.float64,
		torch.float32: torch.float64,
		torch.uint8: torch.int64,
		torch.int8: torch.int64,
		torch.int16: torch.int64,
		torch.int32: torch.int64,
		mask_tensor_type: torch.int64,
	}
	params = h5load(srcfl[0])
	orig_dtypes = [p.dtype for p in params]
	acc_dtypes = [type_map.get(p.dtype, None) for p in params]
	acc = [p if dt is None else p.to(dt) for p, dt in zip(params, acc_dtypes)]
	num_models = 1
	for mfile in srcfl[1:]:
		for base, loaded, dt in zip(acc, h5load(mfile), acc_dtypes):
			base.add_(loaded if dt is None else loaded.to(dt))
		num_models += 1
	scale = float(num_models)
	for base in acc:
		base.div_(scale)
	out = [p if adt is None else p.to(odt) for p, adt, odt in zip(acc, acc_dtypes, orig_dtypes)]
	h5save(out, rsf, h5args=h5zipargs)
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):
	"""Run one training epoch over the batches named in tl.

	Batches are read from the h5 groups td["src"], td["mt"] and
	td["tgt"]; the model is called as model(src, mt, input_tokens)
	(a source + machine-translation conditioned decoder — presumably an
	APE/doc-level setup, TODO confirm). Gradients are accumulated until
	at least tokens_optm target tokens have been seen, then one optimizer
	step is taken (AMP-scaled when scaler is not None). Supports rolling
	checkpoints, optimizer-state and remaining-batch-state saving, and
	periodic loss reporting/validation.

	Returns (mean token loss, leftover accumulated tokens, next
	checkpoint id, remaining steps, per-batch loss dict or None).
	"""

	sum_loss = part_loss = 0.0
	sum_wd = part_wd = 0
	# _use_amp is derived from scaler presence; _cur_rstep counts down to stop
	_done_tokens, _cur_checkid, _cur_rstep, _use_amp, ndata = done_tokens, cur_checkid, remain_steps, scaler is not None, len(tl)
	model.train()
	cur_b, _ls = 1, {} if save_loss else None
	src_grp, mt_grp, tgt_grp = td["src"], td["mt"], td["tgt"]
	for i_d in tqdm(tl):
		seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
		seq_mt = torch.from_numpy(mt_grp[i_d][:]).long()
		seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
		# lo: target length minus one (teacher forcing: input is tgt[:-1], gold is tgt[1:])
		lo = seq_o.size(1) - 1
		if mv_device:
			seq_batch = seq_batch.to(mv_device)
			seq_mt = seq_mt.to(mv_device)
			seq_o = seq_o.to(mv_device)
		oi = seq_o.narrow(1, 0, lo)
		ot = seq_o.narrow(1, 1, lo).contiguous()
		with autocast(enabled=_use_amp):
			output = model(seq_batch, seq_mt, oi)
			loss = lossf(output, ot)
			if multi_gpu:
				# per-device losses are returned; reduce before backward
				loss = loss.sum()
		loss_add = loss.data.item()
		if scaler is None:
			loss.backward()
		else:
			scaler.scale(loss).backward()
		# number of non-pad gold tokens in this batch
		wd_add = ot.ne(pad_id).int().sum().item()
		# release graph/tensors early to cut peak memory
		# NOTE(review): seq_mt is not cleared here, unlike the others — confirm intended
		loss = output = oi = ot = seq_batch = seq_o = None
		sum_loss += loss_add
		if save_loss:
			# NOTE(review): t_d is undefined in this function — this raises
			# NameError whenever save_loss=True; likely should key on i_d alone
			_ls[(i_d, t_d)] = loss_add / wd_add
		sum_wd += wd_add
		_done_tokens += wd_add
		if _done_tokens >= tokens_optm:
			if multi_gpu:
				model.collect_gradients()
			optm_step(optm, scaler)
			optm.zero_grad(set_to_none=True)
			if multi_gpu:
				model.update_replicas()
			_done_tokens = 0
			if _cur_rstep is not None:
				# optional in-training checkpointing, aligned to optimizer steps
				if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
					if num_checkpoint > 1:
						# rolling checkpoint files: suffix cycles through num_checkpoint ids
						_fend = "_%d.h5" % (_cur_checkid)
						_chkpf = chkpf[:-3] + _fend
						if chkpof is not None:
							_chkpof = chkpof[:-3] + _fend
						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
					else:
						_chkpf = chkpf
						_chkpof = chkpof
					save_model(model, _chkpf, multi_gpu, logger)
					if chkpof is not None:
						h5save(optm.state_dict(), _chkpof)
					if statesf is not None:
						# remember which batches are still unprocessed for resume
						save_states(statesf, tl[cur_b - 1:])
				_cur_rstep -= 1
				if _cur_rstep <= 0:
					break
			lrsch.step()
		if nreport is not None:
			part_loss += loss_add
			part_wd += wd_add
			if cur_b % nreport == 0:
				if report_eva:
					_leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp)
					logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva))
					free_cache(mv_device)
					model.train()
				else:
					logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
				part_loss = 0.0
				part_wd = 0
		# per-batch checkpointing path (used when training by epochs, _cur_rstep is None)
		if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata):
			if num_checkpoint > 1:
				_fend = "_%d.h5" % (_cur_checkid)
				_chkpf = chkpf[:-3] + _fend
				if chkpof is not None:
					_chkpof = chkpof[:-3] + _fend
				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
			else:
				_chkpf = chkpf
				_chkpof = chkpof
			save_model(model, _chkpf, multi_gpu, logger)
			if chkpof is not None:
				h5save(optm.state_dict(), _chkpof)
			if statesf is not None:
				save_states(statesf, tl[cur_b - 1:])
		cur_b += 1
	if part_wd != 0.0:
		# flush the last partial report
		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
	return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
# Epoch 0: one training pass resumed from cnt_states, then validation.
# (save_loss and save_checkp_epoch are passed as False, False.)
tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler)
vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp)
logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec))
# checkpoint named after train/valid metrics; optimizer state saved alongside when requested
save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, logger)
if save_optm_state:
	h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec))
logger.info("New best model saved")

# Dynamic sentence sampling configuration: dss_ws / dss_rm are fractions in
# (0, 1) of the training set; converted here to absolute counts.
if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0:
	dss_ws = int(cnfg.dss_ws * ntrain)
	_Dws = {}
	_prev_Dws = {}
	_crit_inc = {}
	if cnfg.dss_rm is not None and cnfg.dss_rm > 0.0 and cnfg.dss_rm < 1.0:
		dss_rm = int(cnfg.dss_rm * ntrain * (1.0 - cnfg.dss_ws))
	else:
		dss_rm = 0
else:
	# sampling disabled
	dss_ws = 0
	dss_rm = 0
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):
	"""Run one training epoch with dynamic-batching controlled by grad_mon.

	Like the plain train loop, but an optimizer step can be triggered
	early by grad_mon (gradient-direction monitor) and is only *counted*
	when the monitored cosine similarity is below update_angle; otherwise
	the accumulated gradients are discarded. Decisions are appended to
	wkdir/dynbatch.log (sampled at probability log_dyn_p via the global
	log_dynb flag, which this function mutates).

	Returns (mean token loss, leftover accumulated tokens, next
	checkpoint id, remaining steps, per-batch loss dict or None).
	"""

	sum_loss = part_loss = 0.0
	sum_wd = part_wd = 0
	_done_tokens, _cur_checkid, _cur_rstep, _use_amp, ndata = done_tokens, cur_checkid, remain_steps, scaler is not None, len(tl)
	model.train()
	cur_b, _ls = 1, {} if save_loss else None
	# module-level state shared with the driver script; log_dynb is reassigned below
	global grad_mon, update_angle, enc_layer, log_dyn_p, log_dynb, wkdir
	# append-mode binary log of dynamic-batch decisions; "ES" marks an epoch start
	_log_f_dynbatch = open(wkdir + "dynbatch.log", "ab")
	_log_f_dynbatch.write("ES\n".encode("utf-8"))
	src_grp, tgt_grp = td["src"], td["tgt"]
	for i_d in tqdm(tl):
		seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
		seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
		# teacher forcing: input is tgt[:-1], gold is tgt[1:]
		lo = seq_o.size(1) - 1
		if mv_device:
			seq_batch = seq_batch.to(mv_device)
			seq_o = seq_o.to(mv_device)
		oi = seq_o.narrow(1, 0, lo)
		ot = seq_o.narrow(1, 1, lo).contiguous()
		with autocast(enabled=_use_amp):
			output = model(seq_batch, oi)
			loss = lossf(output, ot)
			if multi_gpu:
				loss = loss.sum()
		loss_add = loss.data.item()
		if scaler is None:
			loss.backward()
		else:
			scaler.scale(loss).backward()
		# non-pad gold tokens in this batch
		wd_add = ot.ne(pad_id).int().sum().item()
		loss = output = oi = ot = seq_batch = seq_o = None
		sum_loss += loss_add
		if save_loss:
			# NOTE(review): t_d is undefined here — NameError when save_loss=True
			_ls[(i_d, t_d)] = loss_add / wd_add
		sum_wd += wd_add
		_done_tokens += wd_add
		if grad_mon.prev_grad is None:
			# first batch of a new accumulation window: optionally log which
			# layer grad_mon selected ("E"=encoder, "D"=decoder by index)
			if log_dynb:
				dyn_sel_ind = grad_mon.sel_ind
				dyn_sel_layer, dyn_sel_enc = dyn_sel_ind % enc_layer, dyn_sel_ind < enc_layer
				_log_f_dynbatch.write(("%s%d %d\n" % ("E" if dyn_sel_enc else "D", dyn_sel_layer, wd_add)).encode("utf-8"))
		_perform_dyn_optm_step, _cos_sim_l = grad_mon.update(model.module if multi_gpu else model)
		_cos_sim = None if _cos_sim_l is None else _cos_sim_l[0]
		if log_dynb and (_cos_sim_l is not None):
			_log_f_dynbatch.write(("%d %s\n" % (wd_add, " ".join(["%.2f" % (_cu, ) for _cu in _cos_sim_l]))).encode("utf-8"))
		# step either when grad_mon asks for it or when the token budget is full
		if _perform_dyn_optm_step or (_done_tokens >= tokens_optm):
			if not _perform_dyn_optm_step:
				grad_mon.reset()
			# only commit the step if gradient directions still agree
			_do_optm_step = True if _cos_sim is None else (_cos_sim <= update_angle)
			if _do_optm_step:
				if log_dynb:
					_log_f_dynbatch.write(("%d\n" % (_done_tokens, )).encode("utf-8"))
				if multi_gpu:
					model.collect_gradients()
				optm_step(optm, scaler)
				optm.zero_grad(set_to_none=True)
				if multi_gpu:
					model.update_replicas()
				lrsch.step()
			else:
				# "D" record: accumulated gradients dropped without a step
				if log_dynb:
					_log_f_dynbatch.write(("D %d\n" % (_done_tokens, )).encode("utf-8"))
				if multi_gpu:
					model.reset_grad()
				else:
					optm.zero_grad(set_to_none=True)
			# re-sample whether to log the next window
			log_dynb = random() <= log_dyn_p
			_done_tokens = 0
			if _cur_rstep is not None:
				if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
					if num_checkpoint > 1:
						# rolling checkpoint suffix cycles through num_checkpoint ids
						_fend = "_%d.h5" % (_cur_checkid)
						_chkpf = chkpf[:-3] + _fend
						if chkpof is not None:
							_chkpof = chkpof[:-3] + _fend
						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
					else:
						_chkpf = chkpf
						_chkpof = chkpof
					save_model(model, _chkpf, multi_gpu, logger)
					if chkpof is not None:
						h5save(optm.state_dict(), _chkpof)
					if statesf is not None:
						save_states(statesf, tl[cur_b - 1:])
				# only committed optimizer steps count against the step budget
				if _do_optm_step:
					_cur_rstep -= 1
					if _cur_rstep <= 0:
						break
		if nreport is not None:
			part_loss += loss_add
			part_wd += wd_add
			if cur_b % nreport == 0:
				if report_eva:
					_leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp)
					logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva))
					free_cache(mv_device)
					model.train()
				else:
					logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
				logger.info("Dynb: %s" % (" ".join(["%.2f" % (_tmpu, ) for _tmpu in grad_mon.recorder.get_w()]), ))
				part_loss = 0.0
				part_wd = 0
		# per-batch checkpointing path (epoch-based training, _cur_rstep is None)
		if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata):
			if num_checkpoint > 1:
				_fend = "_%d.h5" % (_cur_checkid)
				_chkpf = chkpf[:-3] + _fend
				if chkpof is not None:
					_chkpof = chkpof[:-3] + _fend
				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
			else:
				_chkpf = chkpf
				_chkpof = chkpof
			save_model(model, _chkpf, multi_gpu, logger)
			if chkpof is not None:
				h5save(optm.state_dict(), _chkpof)
			if statesf is not None:
				save_states(statesf, tl[cur_b - 1:])
		cur_b += 1
	if part_wd != 0.0:
		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
		logger.info("Dynb: %s" % (" ".join(["%.2f" % (_tmpu, ) for _tmpu in grad_mon.recorder.get_w()]), ))
	_log_f_dynbatch.close()
	return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False):
	"""Run one training epoch (older variant using apex amp instead of a
	torch.cuda.amp GradScaler).

	Gradients accumulate until tokens_optm target tokens are seen, then
	one optimizer step is taken. Supports rolling checkpoints, optimizer
	state / remaining-batch saving, and periodic reporting/validation.

	Returns (mean token loss, leftover accumulated tokens, next
	checkpoint id, remaining steps, per-batch loss dict or None).
	"""

	sum_loss = 0.0
	sum_wd = 0
	part_loss = 0.0
	part_wd = 0
	_done_tokens = done_tokens
	model.train()
	cur_b = 1
	ndata = len(tl)
	_cur_checkid = cur_checkid
	_cur_rstep = remain_steps
	_ls = {} if save_loss else None
	src_grp, tgt_grp = td["src"], td["tgt"]
	for i_d in tqdm(tl):
		seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
		seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
		# teacher forcing: input is tgt[:-1], gold is tgt[1:]
		lo = seq_o.size(1) - 1
		if mv_device:
			seq_batch = seq_batch.to(mv_device)
			seq_o = seq_o.to(mv_device)
		oi = seq_o.narrow(1, 0, lo)
		ot = seq_o.narrow(1, 1, lo).contiguous()
		output = model(seq_batch, oi)
		loss = lossf(output, ot)
		if multi_gpu:
			loss = loss.sum()
		loss_add = loss.data.item()
		# scale the sum of losses down according to the number of tokens adviced by: https://mp.weixin.qq.com/s/qAHZ4L5qK3rongCIIq5hQw, I think not reasonable.
		#loss /= wd_add
		if use_amp:
			with amp.scale_loss(loss, optm) as scaled_loss:
				scaled_loss.backward()
		else:
			loss.backward()
		# NOTE(review): pad id hard-coded as 0 here (newer variants use pad_id)
		wd_add = ot.ne(0).int().sum().item()
		loss = output = oi = ot = seq_batch = seq_o = None
		sum_loss += loss_add
		if save_loss:
			# NOTE(review): t_d is undefined here — NameError when save_loss=True
			_ls[(i_d, t_d)] = loss_add / wd_add
		sum_wd += wd_add
		_done_tokens += wd_add
		if _done_tokens >= tokens_optm:
			if multi_gpu:
				model.collect_gradients()
				optm.step()
				optm.zero_grad()
				model.update_replicas()
			else:
				optm.step()
				optm.zero_grad()
			_done_tokens = 0
			if _cur_rstep is not None:
				if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
					if num_checkpoint > 1:
						# rolling checkpoint suffix cycles through num_checkpoint ids
						_fend = "_%d.h5" % (_cur_checkid)
						_chkpf = chkpf[:-3] + _fend
						if chkpof is not None:
							_chkpof = chkpof[:-3] + _fend
						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
					else:
						_chkpf = chkpf
						_chkpof = chkpof
					save_model(model, _chkpf, multi_gpu, logger)
					if chkpof is not None:
						h5save(optm.state_dict(), _chkpof)
					if statesf is not None:
						save_states(statesf, tl[cur_b - 1:])
				_cur_rstep -= 1
				if _cur_rstep <= 0:
					break
			lrsch.step()
		if nreport is not None:
			part_loss += loss_add
			part_wd += wd_add
			if cur_b % nreport == 0:
				if report_eva:
					_leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu)
					logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva))
					free_cache(mv_device)
					model.train()
				else:
					logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
				part_loss = 0.0
				part_wd = 0
		# per-batch checkpointing path (epoch-based training, _cur_rstep is None)
		if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata):
			if num_checkpoint > 1:
				_fend = "_%d.h5" % (_cur_checkid)
				_chkpf = chkpf[:-3] + _fend
				if chkpof is not None:
					_chkpof = chkpof[:-3] + _fend
				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
			else:
				_chkpf = chkpf
				_chkpof = chkpof
			#save_model(model, _chkpf, isinstance(model, nn.DataParallel), logger)
			save_model(model, _chkpf, multi_gpu, logger)
			if chkpof is not None:
				h5save(optm.state_dict(), _chkpof)
			if statesf is not None:
				save_states(statesf, tl[cur_b - 1:])
		cur_b += 1
	if part_wd != 0.0:
		# flush the last partial report
		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
	return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
def torch_to_h5(srcf, rsf, h5args=h5zipargs):
	"""Convert a torch-serialized file *srcf* into the project's HDF5
	format at *rsf*, loading onto CPU first."""

	# NOTE(review): torch.load unpickles srcf — only use on trusted files
	state = torch.load(srcf, map_location='cpu')
	h5save(state, rsf, h5args=h5args)
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False):
	"""Run one training epoch with grad_mon-driven dynamic batching
	(older variant using apex amp).

	An optimizer step can be triggered early by grad_mon and is only
	committed when the monitored cosine similarity is <= update_angle;
	otherwise the accumulated gradients are discarded. Only committed
	steps count against the remaining-steps budget.

	Returns (mean token loss, leftover accumulated tokens, next
	checkpoint id, remaining steps, per-batch loss dict or None).
	"""

	sum_loss = 0.0
	sum_wd = 0
	part_loss = 0.0
	part_wd = 0
	_done_tokens = done_tokens
	model.train()
	cur_b = 1
	ndata = len(tl)
	_cur_checkid = cur_checkid
	_cur_rstep = remain_steps
	_ls = {} if save_loss else None
	# module-level dynamic-batching state shared with the driver script
	global grad_mon, update_angle
	src_grp, tgt_grp = td["src"], td["tgt"]
	for i_d in tqdm(tl):
		seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
		seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
		# teacher forcing: input is tgt[:-1], gold is tgt[1:]
		lo = seq_o.size(1) - 1
		if mv_device:
			seq_batch = seq_batch.to(mv_device)
			seq_o = seq_o.to(mv_device)
		oi = seq_o.narrow(1, 0, lo)
		ot = seq_o.narrow(1, 1, lo).contiguous()
		output = model(seq_batch, oi)
		loss = lossf(output, ot)
		if multi_gpu:
			loss = loss.sum()
		loss_add = loss.data.item()
		if use_amp:
			with amp.scale_loss(loss, optm) as scaled_loss:
				scaled_loss.backward()
		else:
			loss.backward()
		# NOTE(review): pad id hard-coded as 0 here (newer variants use pad_id)
		wd_add = ot.ne(0).int().sum().item()
		loss = output = oi = ot = seq_batch = seq_o = None
		sum_loss += loss_add
		if save_loss:
			# NOTE(review): t_d is undefined here — NameError when save_loss=True
			_ls[(i_d, t_d)] = loss_add / wd_add
		sum_wd += wd_add
		_done_tokens += wd_add
		_perform_dyn_optm_step, _cos_sim = grad_mon.update(model.module if multi_gpu else model)
		# step either when grad_mon asks for it or when the token budget is full
		if _perform_dyn_optm_step or (_done_tokens >= tokens_optm):
			if not _perform_dyn_optm_step:
				grad_mon.reset()
			# only commit the step if gradient directions still agree
			_do_optm_step = True if _cos_sim is None else (_cos_sim <= update_angle)
			if _do_optm_step:
				if multi_gpu:
					model.collect_gradients()
					optm.step()
					optm.zero_grad()
					model.update_replicas()
				else:
					optm.step()
					optm.zero_grad()
				lrsch.step()
			else:
				# discard the accumulated gradients without stepping
				if multi_gpu:
					#optm.zero_grad()
					model.reset_grad()
				else:
					optm.zero_grad()
			_done_tokens = 0
			if _cur_rstep is not None:
				if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
					if num_checkpoint > 1:
						# rolling checkpoint suffix cycles through num_checkpoint ids
						_fend = "_%d.h5" % (_cur_checkid)
						_chkpf = chkpf[:-3] + _fend
						if chkpof is not None:
							_chkpof = chkpof[:-3] + _fend
						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
					else:
						_chkpf = chkpf
						_chkpof = chkpof
					save_model(model, _chkpf, multi_gpu, logger)
					if chkpof is not None:
						h5save(optm.state_dict(), _chkpof)
					if statesf is not None:
						save_states(statesf, tl[cur_b - 1:])
				# only committed optimizer steps count against the step budget
				if _do_optm_step:
					_cur_rstep -= 1
					if _cur_rstep <= 0:
						break
		if nreport is not None:
			part_loss += loss_add
			part_wd += wd_add
			if cur_b % nreport == 0:
				if report_eva:
					_leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu)
					logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva))
					free_cache(mv_device)
					model.train()
				else:
					logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
				part_loss = 0.0
				part_wd = 0
		# per-batch checkpointing path (epoch-based training, _cur_rstep is None)
		if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata):
			if num_checkpoint > 1:
				_fend = "_%d.h5" % (_cur_checkid)
				_chkpf = chkpf[:-3] + _fend
				if chkpof is not None:
					_chkpof = chkpof[:-3] + _fend
				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
			else:
				_chkpf = chkpf
				_chkpof = chkpof
			save_model(model, _chkpf, multi_gpu, logger)
			if chkpof is not None:
				h5save(optm.state_dict(), _chkpof)
			if statesf is not None:
				save_states(statesf, tl[cur_b - 1:])
		cur_b += 1
	if part_wd != 0.0:
		# flush the last partial report
		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
	return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls