# training loop for a model that consumes a second input stream (td["mt"]) besides the source;
# mixed precision via torch autocast and an optional GradScaler passed in as `scaler`.
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):

	sum_loss = part_loss = 0.0
	sum_wd = part_wd = 0
	_done_tokens, _cur_checkid, _cur_rstep, _use_amp, ndata = done_tokens, cur_checkid, remain_steps, scaler is not None, len(tl)
	model.train()
	cur_b, _ls = 1, {} if save_loss else None

	src_grp, mt_grp, tgt_grp = td["src"], td["mt"], td["tgt"]
	for i_d in tqdm(tl):
		seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
		seq_mt = torch.from_numpy(mt_grp[i_d][:]).long()
		seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
		lo = seq_o.size(1) - 1
		if mv_device:
			seq_batch = seq_batch.to(mv_device)
			seq_mt = seq_mt.to(mv_device)
			seq_o = seq_o.to(mv_device)

		# decoder input is the target shifted right, ot holds the gold tokens to predict
		oi = seq_o.narrow(1, 0, lo)
		ot = seq_o.narrow(1, 1, lo).contiguous()
		with autocast(enabled=_use_amp):
			output = model(seq_batch, seq_mt, oi)
			loss = lossf(output, ot)
			if multi_gpu:
				loss = loss.sum()
		loss_add = loss.data.item()

		if scaler is None:
			loss.backward()
		else:
			scaler.scale(loss).backward()

		# number of non-padding target tokens in this batch
		wd_add = ot.ne(pad_id).int().sum().item()
		loss = output = oi = ot = seq_batch = seq_mt = seq_o = None
		sum_loss += loss_add
		if save_loss:
			_ls[i_d] = loss_add / wd_add
		sum_wd += wd_add
		_done_tokens += wd_add

		# accumulate gradients until at least tokens_optm target tokens have been processed
		if _done_tokens >= tokens_optm:
			if multi_gpu:
				model.collect_gradients()
			optm_step(optm, scaler)
			optm.zero_grad(set_to_none=True)
			if multi_gpu:
				model.update_replicas()
			_done_tokens = 0
			if _cur_rstep is not None:
				if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
					if num_checkpoint > 1:
						_fend = "_%d.h5" % (_cur_checkid)
						_chkpf = chkpf[:-3] + _fend
						if chkpof is not None:
							_chkpof = chkpof[:-3] + _fend
						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
					else:
						_chkpf = chkpf
						_chkpof = chkpof
					save_model(model, _chkpf, multi_gpu, logger)
					if chkpof is not None:
						h5save(optm.state_dict(), _chkpof)
					if statesf is not None:
						save_states(statesf, tl[cur_b - 1:])
				_cur_rstep -= 1
				if _cur_rstep <= 0:
					break
			lrsch.step()

		if nreport is not None:
			part_loss += loss_add
			part_wd += wd_add
			if cur_b % nreport == 0:
				if report_eva:
					_leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp)
					logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva))
					free_cache(mv_device)
					model.train()
				else:
					logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
				part_loss = 0.0
				part_wd = 0

		if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata):
			if num_checkpoint > 1:
				_fend = "_%d.h5" % (_cur_checkid)
				_chkpf = chkpf[:-3] + _fend
				if chkpof is not None:
					_chkpof = chkpof[:-3] + _fend
				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
			else:
				_chkpf = chkpf
				_chkpof = chkpof
			save_model(model, _chkpf, multi_gpu, logger)
			if chkpof is not None:
				h5save(optm.state_dict(), _chkpof)
			if statesf is not None:
				save_states(statesf, tl[cur_b - 1:])
		cur_b += 1

	if part_wd != 0.0:
		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))

	return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
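# `optm_step(optm, scaler)` above comes from the surrounding codebase and is not defined in this file.
# A minimal sketch of what such a helper could look like, assuming `scaler` is either None or a
# torch.cuda.amp.GradScaler; the name `_optm_step_sketch` is hypothetical and used only for illustration.
def _optm_step_sketch(optm, scaler=None):

	if scaler is None:
		# plain full-precision update
		optm.step()
	else:
		# AMP update: scaler.step unscales the gradients and skips the step if they contain inf/NaN,
		# scaler.update then adjusts the loss scale for the next iteration
		scaler.step(optm)
		scaler.update()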
# src/tgt variant of the training loop; mixed precision via NVIDIA apex (use_amp) instead of torch.cuda.amp.
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False):

	sum_loss = 0.0
	sum_wd = 0
	part_loss = 0.0
	part_wd = 0
	_done_tokens = done_tokens
	model.train()
	cur_b = 1
	ndata = len(tl)
	_cur_checkid = cur_checkid
	_cur_rstep = remain_steps
	_ls = {} if save_loss else None

	src_grp, tgt_grp = td["src"], td["tgt"]
	for i_d in tqdm(tl):
		seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
		seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
		lo = seq_o.size(1) - 1
		if mv_device:
			seq_batch = seq_batch.to(mv_device)
			seq_o = seq_o.to(mv_device)

		# decoder input is the target shifted right, ot holds the gold tokens to predict
		oi = seq_o.narrow(1, 0, lo)
		ot = seq_o.narrow(1, 1, lo).contiguous()
		output = model(seq_batch, oi)
		loss = lossf(output, ot)
		if multi_gpu:
			loss = loss.sum()
		loss_add = loss.data.item()

		# scaling the summed loss down by the number of tokens, as advised in
		# https://mp.weixin.qq.com/s/qAHZ4L5qK3rongCIIq5hQw, does not seem reasonable to me:
		#loss /= wd_add
		if use_amp:
			with amp.scale_loss(loss, optm) as scaled_loss:
				scaled_loss.backward()
		else:
			loss.backward()

		# non-padding target tokens in this batch (padding index is 0)
		wd_add = ot.ne(0).int().sum().item()
		loss = output = oi = ot = seq_batch = seq_o = None
		sum_loss += loss_add
		if save_loss:
			_ls[i_d] = loss_add / wd_add
		sum_wd += wd_add
		_done_tokens += wd_add

		# accumulate gradients until at least tokens_optm target tokens have been processed
		if _done_tokens >= tokens_optm:
			if multi_gpu:
				model.collect_gradients()
				optm.step()
				optm.zero_grad()
				model.update_replicas()
			else:
				optm.step()
				optm.zero_grad()
			_done_tokens = 0
			if _cur_rstep is not None:
				if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
					if num_checkpoint > 1:
						_fend = "_%d.h5" % (_cur_checkid)
						_chkpf = chkpf[:-3] + _fend
						if chkpof is not None:
							_chkpof = chkpof[:-3] + _fend
						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
					else:
						_chkpf = chkpf
						_chkpof = chkpof
					save_model(model, _chkpf, multi_gpu, logger)
					if chkpof is not None:
						h5save(optm.state_dict(), _chkpof)
					if statesf is not None:
						save_states(statesf, tl[cur_b - 1:])
				_cur_rstep -= 1
				if _cur_rstep <= 0:
					break
			lrsch.step()

		if nreport is not None:
			part_loss += loss_add
			part_wd += wd_add
			if cur_b % nreport == 0:
				if report_eva:
					_leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu)
					logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva))
					free_cache(mv_device)
					model.train()
				else:
					logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
				part_loss = 0.0
				part_wd = 0

		if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata):
			if num_checkpoint > 1:
				_fend = "_%d.h5" % (_cur_checkid)
				_chkpf = chkpf[:-3] + _fend
				if chkpof is not None:
					_chkpof = chkpof[:-3] + _fend
				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
			else:
				_chkpf = chkpf
				_chkpof = chkpof
			#save_model(model, _chkpf, isinstance(model, nn.DataParallel), logger)
			save_model(model, _chkpf, multi_gpu, logger)
			if chkpof is not None:
				h5save(optm.state_dict(), _chkpof)
			if statesf is not None:
				save_states(statesf, tl[cur_b - 1:])
		cur_b += 1

	if part_wd != 0.0:
		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))

	return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
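# The `use_amp` branch above relies on NVIDIA apex (`amp.scale_loss`), so the model and optimizer have
# to be wrapped by apex beforehand. A minimal, hypothetical setup sketch (the helper name and the
# "O1" opt_level are illustrative assumptions, not taken from this file):
def _amp_initialize_sketch(model, optm, opt_level="O1"):

	from apex import amp

	# returns the patched (model, optimizer) pair; afterwards `with amp.scale_loss(loss, optm)` in the
	# training loop yields a loss whose backward pass produces scaled gradients
	return amp.initialize(model, optm, opt_level=opt_level)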
# dynamic-batching variant: a module-level gradient monitor (grad_mon) decides when the accumulated
# gradients are applied; its decisions are logged to dynbatch.log; mixed precision via torch autocast/GradScaler.
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):

	sum_loss = part_loss = 0.0
	sum_wd = part_wd = 0
	_done_tokens, _cur_checkid, _cur_rstep, _use_amp, ndata = done_tokens, cur_checkid, remain_steps, scaler is not None, len(tl)
	model.train()
	cur_b, _ls = 1, {} if save_loss else None

	global grad_mon, update_angle, enc_layer, log_dyn_p, log_dynb, wkdir

	_log_f_dynbatch = open(wkdir + "dynbatch.log", "ab")
	_log_f_dynbatch.write("ES\n".encode("utf-8"))

	src_grp, tgt_grp = td["src"], td["tgt"]
	for i_d in tqdm(tl):
		seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
		seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
		lo = seq_o.size(1) - 1
		if mv_device:
			seq_batch = seq_batch.to(mv_device)
			seq_o = seq_o.to(mv_device)

		# decoder input is the target shifted right, ot holds the gold tokens to predict
		oi = seq_o.narrow(1, 0, lo)
		ot = seq_o.narrow(1, 1, lo).contiguous()
		with autocast(enabled=_use_amp):
			output = model(seq_batch, oi)
			loss = lossf(output, ot)
			if multi_gpu:
				loss = loss.sum()
		loss_add = loss.data.item()

		if scaler is None:
			loss.backward()
		else:
			scaler.scale(loss).backward()

		# number of non-padding target tokens in this batch
		wd_add = ot.ne(pad_id).int().sum().item()
		loss = output = oi = ot = seq_batch = seq_o = None
		sum_loss += loss_add
		if save_loss:
			_ls[i_d] = loss_add / wd_add
		sum_wd += wd_add
		_done_tokens += wd_add

		# log which layer the monitor currently tracks before its first reference gradient is recorded
		if grad_mon.prev_grad is None:
			if log_dynb:
				dyn_sel_ind = grad_mon.sel_ind
				dyn_sel_layer, dyn_sel_enc = dyn_sel_ind % enc_layer, dyn_sel_ind < enc_layer
				_log_f_dynbatch.write(("%s%d %d\n" % ("E" if dyn_sel_enc else "D", dyn_sel_layer, wd_add)).encode("utf-8"))
		_perform_dyn_optm_step, _cos_sim_l = grad_mon.update(model.module if multi_gpu else model)
		_cos_sim = None if _cos_sim_l is None else _cos_sim_l[0]
		if log_dynb and (_cos_sim_l is not None):
			_log_f_dynbatch.write(("%d %s\n" % (wd_add, " ".join(["%.2f" % (_cu,) for _cu in _cos_sim_l]))).encode("utf-8"))

		# apply the accumulated gradients when either the monitor asks for it or the token budget is reached
		if _perform_dyn_optm_step or (_done_tokens >= tokens_optm):
			if not _perform_dyn_optm_step:
				grad_mon.reset()
			_do_optm_step = True if _cos_sim is None else (_cos_sim <= update_angle)
			if _do_optm_step:
				if log_dynb:
					_log_f_dynbatch.write(("%d\n" % (_done_tokens,)).encode("utf-8"))
				if multi_gpu:
					model.collect_gradients()
				optm_step(optm, scaler)
				optm.zero_grad(set_to_none=True)
				if multi_gpu:
					model.update_replicas()
				lrsch.step()
			else:
				# discard the accumulated gradients instead of applying them
				if log_dynb:
					_log_f_dynbatch.write(("D %d\n" % (_done_tokens,)).encode("utf-8"))
				if multi_gpu:
					model.reset_grad()
				else:
					optm.zero_grad(set_to_none=True)
			log_dynb = random() <= log_dyn_p
			_done_tokens = 0
			if _cur_rstep is not None:
				if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
					if num_checkpoint > 1:
						_fend = "_%d.h5" % (_cur_checkid)
						_chkpf = chkpf[:-3] + _fend
						if chkpof is not None:
							_chkpof = chkpof[:-3] + _fend
						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
					else:
						_chkpf = chkpf
						_chkpof = chkpof
					save_model(model, _chkpf, multi_gpu, logger)
					if chkpof is not None:
						h5save(optm.state_dict(), _chkpof)
					if statesf is not None:
						save_states(statesf, tl[cur_b - 1:])
				if _do_optm_step:
					_cur_rstep -= 1
					if _cur_rstep <= 0:
						break

		if nreport is not None:
			part_loss += loss_add
			part_wd += wd_add
			if cur_b % nreport == 0:
				if report_eva:
					_leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp)
					logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva))
					free_cache(mv_device)
					model.train()
				else:
					logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
				logger.info("Dynb: %s" % (" ".join(["%.2f" % (_tmpu,) for _tmpu in grad_mon.recorder.get_w()]),))
				part_loss = 0.0
				part_wd = 0

		if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata):
			if num_checkpoint > 1:
				_fend = "_%d.h5" % (_cur_checkid)
				_chkpf = chkpf[:-3] + _fend
				if chkpof is not None:
					_chkpof = chkpof[:-3] + _fend
				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
			else:
				_chkpf = chkpf
				_chkpof = chkpof
			save_model(model, _chkpf, multi_gpu, logger)
			if chkpof is not None:
				h5save(optm.state_dict(), _chkpof)
			if statesf is not None:
				save_states(statesf, tl[cur_b - 1:])
		cur_b += 1

	if part_wd != 0.0:
		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
	logger.info("Dynb: %s" % (" ".join(["%.2f" % (_tmpu,) for _tmpu in grad_mon.recorder.get_w()]),))

	_log_f_dynbatch.close()

	return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
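# `grad_mon` above is a module-level gradient monitor whose implementation lives elsewhere in the
# codebase; judging from its use, update() reports whether an optimizer step should be taken together
# with cosine similarities between the current and the previously accumulated gradients. The class
# below is only a hypothetical, simplified illustration of that idea (a single cosine value and one
# threshold), not the repository's actual monitor.
import torch

class _CosGradMonitorSketch:

	def __init__(self, thres=0.5):
		# cosine-similarity threshold below which applying the accumulated gradients is suggested
		self.thres = thres
		self.prev_grad = None

	def update(self, module):
		# flatten the currently accumulated gradients of the monitored module
		_grads = [p.grad.reshape(-1) for p in module.parameters() if p.grad is not None]
		if not _grads:
			return False, None
		_cur = torch.cat(_grads, dim=0)
		if self.prev_grad is None:
			# first call: just record the reference direction
			self.prev_grad = _cur.clone()
			return False, None
		_cos = torch.nn.functional.cosine_similarity(_cur, self.prev_grad, dim=0).item()
		self.prev_grad = _cur.clone()
		# suggest a step once the accumulated direction diverges from the previous one
		return _cos <= self.thres, _cos

	def reset(self):
		self.prev_grad = None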
# dynamic-batching variant with apex-based mixed precision (use_amp): grad_mon decides when the
# accumulated gradients are applied or discarded.
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False):

	sum_loss = 0.0
	sum_wd = 0
	part_loss = 0.0
	part_wd = 0
	_done_tokens = done_tokens
	model.train()
	cur_b = 1
	ndata = len(tl)
	_cur_checkid = cur_checkid
	_cur_rstep = remain_steps
	_ls = {} if save_loss else None

	global grad_mon, update_angle

	src_grp, tgt_grp = td["src"], td["tgt"]
	for i_d in tqdm(tl):
		seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
		seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
		lo = seq_o.size(1) - 1
		if mv_device:
			seq_batch = seq_batch.to(mv_device)
			seq_o = seq_o.to(mv_device)

		# decoder input is the target shifted right, ot holds the gold tokens to predict
		oi = seq_o.narrow(1, 0, lo)
		ot = seq_o.narrow(1, 1, lo).contiguous()
		output = model(seq_batch, oi)
		loss = lossf(output, ot)
		if multi_gpu:
			loss = loss.sum()
		loss_add = loss.data.item()

		if use_amp:
			with amp.scale_loss(loss, optm) as scaled_loss:
				scaled_loss.backward()
		else:
			loss.backward()

		# non-padding target tokens in this batch (padding index is 0)
		wd_add = ot.ne(0).int().sum().item()
		loss = output = oi = ot = seq_batch = seq_o = None
		sum_loss += loss_add
		if save_loss:
			_ls[i_d] = loss_add / wd_add
		sum_wd += wd_add
		_done_tokens += wd_add

		_perform_dyn_optm_step, _cos_sim = grad_mon.update(model.module if multi_gpu else model)

		# apply the accumulated gradients when either the monitor asks for it or the token budget is reached
		if _perform_dyn_optm_step or (_done_tokens >= tokens_optm):
			if not _perform_dyn_optm_step:
				grad_mon.reset()
			_do_optm_step = True if _cos_sim is None else (_cos_sim <= update_angle)
			if _do_optm_step:
				if multi_gpu:
					model.collect_gradients()
					optm.step()
					optm.zero_grad()
					model.update_replicas()
				else:
					optm.step()
					optm.zero_grad()
				lrsch.step()
			else:
				# discard the accumulated gradients instead of applying them
				if multi_gpu:
					#optm.zero_grad()
					model.reset_grad()
				else:
					optm.zero_grad()
			_done_tokens = 0
			if _cur_rstep is not None:
				if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
					if num_checkpoint > 1:
						_fend = "_%d.h5" % (_cur_checkid)
						_chkpf = chkpf[:-3] + _fend
						if chkpof is not None:
							_chkpof = chkpof[:-3] + _fend
						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
					else:
						_chkpf = chkpf
						_chkpof = chkpof
					save_model(model, _chkpf, multi_gpu, logger)
					if chkpof is not None:
						h5save(optm.state_dict(), _chkpof)
					if statesf is not None:
						save_states(statesf, tl[cur_b - 1:])
				if _do_optm_step:
					_cur_rstep -= 1
					if _cur_rstep <= 0:
						break

		if nreport is not None:
			part_loss += loss_add
			part_wd += wd_add
			if cur_b % nreport == 0:
				if report_eva:
					_leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu)
					logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva))
					free_cache(mv_device)
					model.train()
				else:
					logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
				part_loss = 0.0
				part_wd = 0

		if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata):
			if num_checkpoint > 1:
				_fend = "_%d.h5" % (_cur_checkid)
				_chkpf = chkpf[:-3] + _fend
				if chkpof is not None:
					_chkpof = chkpof[:-3] + _fend
				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
			else:
				_chkpf = chkpf
				_chkpof = chkpof
			save_model(model, _chkpf, multi_gpu, logger)
			if chkpof is not None:
				h5save(optm.state_dict(), _chkpof)
			if statesf is not None:
				save_states(statesf, tl[cur_b - 1:])
		cur_b += 1

	if part_wd != 0.0:
		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))

	return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
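# The checkpoint-rotation naming above ("<base>_<id>.h5", cycling over num_checkpoint slots) is
# repeated in every variant of train(); the helper below captures just that naming logic in one place.
# It is a hypothetical refactoring sketch (the name `_rotate_chkp_name_sketch` is not from this file),
# and it returns None for the optimizer checkpoint path when `chkpof` is not configured.
def _rotate_chkp_name_sketch(chkpf, chkpof, cur_checkid, num_checkpoint):

	if num_checkpoint <= 1:
		# single-checkpoint mode: keep the configured file names unchanged
		return chkpf, chkpof, cur_checkid
	# multi-checkpoint mode: replace the ".h5" suffix with "_<id>.h5" and advance the slot index
	_fend = "_%d.h5" % (cur_checkid,)
	_chkpf = chkpf[:-3] + _fend
	_chkpof = None if chkpof is None else chkpof[:-3] + _fend
	return _chkpf, _chkpof, (cur_checkid + 1) % num_checkpoint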