Example #1
    def _worker(model,
                fname,
                sub_module=False,
                logger=None,
                para_lock=None,
                log_success=None):

        success = True
        _msave = model.module if sub_module else model
        try:
            if para_lock is None:
                h5save([t.data for t in _msave.parameters()],
                       fname,
                       h5args=h5args)
            else:
                with para_lock:
                    h5save([t.data for t in _msave.parameters()],
                           fname,
                           h5args=h5args)
        except Exception as e:
            if logger is None:
                print(e)
            else:
                logger.info(str(e))
            success = False
        if success and (logger is not None) and (log_success is not None):
            logger.info(log_success)
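
Example #1 leaves the dispatch of _worker to the caller. Below is a minimal sketch of how such a worker could be run on a background thread with a shared lock; torch.save stands in for the project's h5save, and the thread/lock handling is an assumption about intended usage, not code from the source.

import threading

import torch
from torch import nn

def async_save(model, fname, para_lock=None):
    # collect parameter tensors on CPU and write them out, optionally under a lock
    params = [p.data.cpu() for p in model.parameters()]
    if para_lock is None:
        torch.save(params, fname)
    else:
        with para_lock:
            torch.save(params, fname)

model = nn.Linear(4, 4)
lock = threading.Lock()
saver = threading.Thread(target=async_save, args=(model, "checkpoint.pt", lock))
saver.start()
saver.join()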
Example #2
def handle(srcf, rsf, h5args=h5zipargs):

    if srcf == rsf:
        h5save(h5load(srcf, restore_list=False), rsf, h5args=h5args)
    else:
        with h5File(srcf, "r") as sfg, h5File(rsf, 'w') as rfg:
            handle_group(sfg, rfg, h5args=h5args)
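
handle_group itself is not shown in Example #2. Judging only from how it is called, a plausible reading is a recursive walk that rewrites every dataset with the supplied h5args (compression settings). The sketch below spells out that assumption with plain h5py; it is not the project's actual implementation.

import h5py

h5zipargs = {"compression": "gzip", "compression_opts": 9, "shuffle": True}

def copy_group(src_group, dst_group, h5args=h5zipargs):
    # recursively mirror sub-groups and rewrite datasets with the given storage options
    for name, item in src_group.items():
        if isinstance(item, h5py.Group):
            copy_group(item, dst_group.create_group(name), h5args=h5args)
        else:
            dst_group.create_dataset(name, data=item[()], **h5args)
    # carry over group-level attributes as well
    for k, v in src_group.attrs.items():
        dst_group.attrs[k] = v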
Example #3
def handle(srcfl, rsf):

    rsm = h5load(srcfl[0])

    src_type = [para.dtype for para in rsm]
    map_type = [
        secure_type_map[para.dtype] if para.dtype in secure_type_map else None
        for para in rsm
    ]
    sec_rsm = [
        para if typ is None else para.to(typ)
        for para, typ in zip(rsm, map_type)
    ]

    nmodel = 1
    for modelf in srcfl[1:]:
        for basep, mpload, typ in zip(sec_rsm, h5load(modelf), map_type):
            basep.add_(mpload if typ is None else mpload.to(typ))
        nmodel += 1
    nmodel = float(nmodel)
    for basep in sec_rsm:
        basep.div_(nmodel)

    rsm = [
        para if mtyp is None else para.to(styp)
        for para, mtyp, styp in zip(sec_rsm, map_type, src_type)
    ]

    h5save(rsm, rsf, h5args=h5zipargs)
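
The same checkpoint-averaging pattern as Example #3, written as a self-contained sketch: torch.load stands in for h5load, and secure_type_map is written out by hand (the project's map may differ). Upcasting to a wider dtype before the in-place accumulation avoids overflow and precision loss while summing.

import torch

secure_type_map = {torch.float16: torch.float64, torch.float32: torch.float64,
                   torch.uint8: torch.int64, torch.int8: torch.int64,
                   torch.int16: torch.int64, torch.int32: torch.int64}

def average_checkpoints(files):
    # load the first checkpoint (a flat list of tensors) and upcast where needed
    base = torch.load(files[0], map_location="cpu")
    src_type = [p.dtype for p in base]
    map_type = [secure_type_map.get(p.dtype) for p in base]
    acc = [p if t is None else p.to(t) for p, t in zip(base, map_type)]
    # accumulate the remaining checkpoints in place, then divide by the count
    for f in files[1:]:
        for a, p, t in zip(acc, torch.load(f, map_location="cpu"), map_type):
            a.add_(p if t is None else p.to(t))
    n = float(len(files))
    for a in acc:
        a.div_(n)  # fine for float accumulators; integer tensors would need an integer divide
    # cast back to the original dtypes
    return [a if t is None else a.to(s) for a, t, s in zip(acc, map_type, src_type)]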
Example #4
def handle(srcf, rsf, h5args=h5zipargs):

	if srcf == rsf:
		h5save(h5load(srcf, restore_list=False), rsf, h5args=h5args)
	else:
		# context managers guarantee both files are closed even if handle_group raises
		with h5py.File(srcf, "r") as sfg, h5py.File(rsf, "w") as rfg:
			handle_group(sfg, rfg, h5args=h5args)
Example #5
def handle(vcbf, embf, rsf):

    vcb, nwd = ldvocab(vcbf)
    emb = load_emb_txt(vcb, embf)
    unkemb = emb.get("<unk>", torch.zeros(next(iter(emb.values())).size(0)))
    vcb = reverse_dict(vcb)
    rs = []
    for i in range(nwd):
        rs.append(emb.get(vcb[i], unkemb))
    h5save(torch.stack(rs, 0), rsf)
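
A toy version of the embedding-table construction in Example #5, with the project helpers (ldvocab, load_emb_txt, reverse_dict) replaced by plain dictionaries; row i of the stacked result is the pretrained vector of the word with index i, falling back to the <unk> vector for words without a pretrained embedding.

import torch

emb = {"<unk>": torch.zeros(4), "hello": torch.randn(4), "world": torch.randn(4)}
index_to_word = {0: "<unk>", 1: "hello", 2: "world", 3: "unseen"}

unkemb = emb["<unk>"]
rows = [emb.get(index_to_word[i], unkemb) for i in range(len(index_to_word))]
emb_table = torch.stack(rows, 0)  # (vocab_size, emb_dim), here (4, 4)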
Example #6
def save_model(model, fname, sub_module=False, logger=None):

	_msave = model.module if sub_module else model
	try:
		h5save([t.data for t in _msave.parameters()], fname)
	except Exception as e:
		if logger is None:
			print(e)
		else:
			logger.info(str(e))
Example #7
def save_model(model, fname, sub_module=False, print_func=print, mtyp=None, h5args=h5modelwargs):

	_msave = model.module if sub_module else model
	try:
		h5save([t.data for t in _msave.parameters()], fname, h5args=h5args)
		if mtyp is not None:
			save_model_cleaner(fname, mtyp)
	except Exception as e:
		if print_func is not None:
			print_func(str(e))
Example #8
	def _worker(model, fname, sub_module=False, print_func=print, mtyp=None, para_lock=None, log_success=None):

		success = True
		_msave = model.module if sub_module else model
		try:
			if para_lock is None:
				h5save([t.data for t in _msave.parameters()], fname, h5args=h5args)
				if mtyp is not None:
					save_model_cleaner(fname, mtyp)
			else:
				with para_lock:
					h5save([t.data for t in _msave.parameters()], fname, h5args=h5args)
					if mtyp is not None:
						save_model_cleaner(fname, mtyp)
		except Exception as e:
			if print_func is not None:
				print_func(str(e))
			success = False
		if success and (print_func is not None) and (log_success is not None):
			print_func(str(log_success))
Example #9
def handle(srcfl, rsf):

    type_map = {
        torch.float16: torch.float64,
        torch.float32: torch.float64,
        torch.uint8: torch.int64,
        torch.int8: torch.int64,
        torch.int16: torch.int64,
        torch.int32: torch.int64
    }
    type_map[mask_tensor_type] = torch.int64

    rsm = h5load(srcfl[0])

    src_type = [para.dtype for para in rsm]
    map_type = [
        type_map[para.dtype] if para.dtype in type_map else None
        for para in rsm
    ]
    sec_rsm = [
        para if typ is None else para.to(typ)
        for para, typ in zip(rsm, map_type)
    ]

    nmodel = 1
    for modelf in srcfl[1:]:
        for basep, mpload, typ in zip(sec_rsm, h5load(modelf), map_type):
            basep.add_(mpload if typ is None else mpload.to(typ))
        nmodel += 1
    nmodel = float(nmodel)
    for basep in sec_rsm:
        basep.div_(nmodel)

    rsm = [
        para if mtyp is None else para.to(styp)
        for para, mtyp, styp in zip(sec_rsm, map_type, src_type)
    ]

    h5save(rsm, rsf, h5args=h5zipargs)
Example #10
def train(td,
          tl,
          ed,
          nd,
          optm,
          lrsch,
          model,
          lossf,
          mv_device,
          logger,
          done_tokens,
          multi_gpu,
          tokens_optm=32768,
          nreport=None,
          save_every=None,
          chkpf=None,
          chkpof=None,
          statesf=None,
          num_checkpoint=1,
          cur_checkid=0,
          report_eva=True,
          remain_steps=None,
          save_loss=False,
          save_checkp_epoch=False,
          scaler=None):

    sum_loss = part_loss = 0.0
    sum_wd = part_wd = 0
    _done_tokens, _cur_checkid, _cur_rstep = done_tokens, cur_checkid, remain_steps
    _use_amp, ndata = scaler is not None, len(tl)
    model.train()
    cur_b, _ls = 1, {} if save_loss else None
    src_grp, mt_grp, tgt_grp = td["src"], td["mt"], td["tgt"]
    for i_d in tqdm(tl):
        seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
        seq_mt = torch.from_numpy(mt_grp[i_d][:]).long()
        seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
        lo = seq_o.size(1) - 1
        if mv_device:
            seq_batch = seq_batch.to(mv_device)
            seq_mt = seq_mt.to(mv_device)
            seq_o = seq_o.to(mv_device)

        oi = seq_o.narrow(1, 0, lo)
        ot = seq_o.narrow(1, 1, lo).contiguous()
        with autocast(enabled=_use_amp):
            output = model(seq_batch, seq_mt, oi)
            loss = lossf(output, ot)
            if multi_gpu:
                loss = loss.sum()
        loss_add = loss.data.item()

        if scaler is None:
            loss.backward()
        else:
            scaler.scale(loss).backward()

        wd_add = ot.ne(pad_id).int().sum().item()
        loss = output = oi = ot = seq_batch = seq_mt = seq_o = None
        sum_loss += loss_add
        if save_loss:
            _ls[i_d] = loss_add / wd_add
        sum_wd += wd_add
        _done_tokens += wd_add

        if _done_tokens >= tokens_optm:
            if multi_gpu:
                model.collect_gradients()
            optm_step(optm, scaler)
            optm.zero_grad(set_to_none=True)
            if multi_gpu:
                model.update_replicas()
            _done_tokens = 0
            if _cur_rstep is not None:
                if save_checkp_epoch and (save_every is not None) and (
                        _cur_rstep % save_every
                        == 0) and (chkpf is not None) and (_cur_rstep > 0):
                    if num_checkpoint > 1:
                        _fend = "_%d.h5" % (_cur_checkid)
                        _chkpf = chkpf[:-3] + _fend
                        if chkpof is not None:
                            _chkpof = chkpof[:-3] + _fend
                        _cur_checkid = (_cur_checkid + 1) % num_checkpoint
                    else:
                        _chkpf = chkpf
                        _chkpof = chkpof
                    save_model(model, _chkpf, multi_gpu, logger)
                    if chkpof is not None:
                        h5save(optm.state_dict(), _chkpof)
                    if statesf is not None:
                        save_states(statesf, tl[cur_b - 1:])
                _cur_rstep -= 1
                if _cur_rstep <= 0:
                    break
            lrsch.step()

        if nreport is not None:
            part_loss += loss_add
            part_wd += wd_add
            if cur_b % nreport == 0:
                if report_eva:
                    _leva, _eeva = eva(ed, nd, model, lossf, mv_device,
                                       multi_gpu, _use_amp)
                    logger.info(
                        "Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f"
                        % (part_wd, part_loss / part_wd, _leva, _eeva))
                    free_cache(mv_device)
                    model.train()
                else:
                    logger.info("Average loss over %d tokens: %.3f" %
                                (part_wd, part_loss / part_wd))
                part_loss = 0.0
                part_wd = 0

        if save_checkp_epoch and (_cur_rstep is None) and (
                save_every is not None) and (cur_b % save_every == 0) and (
                    chkpf is not None) and (cur_b < ndata):
            if num_checkpoint > 1:
                _fend = "_%d.h5" % (_cur_checkid)
                _chkpf = chkpf[:-3] + _fend
                if chkpof is not None:
                    _chkpof = chkpof[:-3] + _fend
                _cur_checkid = (_cur_checkid + 1) % num_checkpoint
            else:
                _chkpf = chkpf
                _chkpof = chkpof
            save_model(model, _chkpf, multi_gpu, logger)
            if chkpof is not None:
                h5save(optm.state_dict(), _chkpof)
            if statesf is not None:
                save_states(statesf, tl[cur_b - 1:])
        cur_b += 1
    if part_wd != 0.0:
        logger.info("Average loss over %d tokens: %.3f" %
                    (part_wd, part_loss / part_wd))
    return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
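
The core update schedule of Example #10, stripped of checkpointing and logging: gradients accumulate across batches until roughly tokens_optm target tokens have been seen, then a single (optionally AMP-scaled) optimizer step is taken. The toy model and random data below are placeholders for illustration only.

import torch
from torch import nn

model = nn.Linear(8, 8)
optm = torch.optim.SGD(model.parameters(), lr=0.1)
lossf = nn.MSELoss(reduction="sum")
scaler = None  # e.g. torch.cuda.amp.GradScaler() when training with AMP on a GPU

tokens_optm, done_tokens = 64, 0
for _ in range(10):
    x, y = torch.randn(16, 8), torch.randn(16, 8)
    loss = lossf(model(x), y)
    (loss if scaler is None else scaler.scale(loss)).backward()
    done_tokens += x.size(0)  # the real code counts non-padding target tokens instead
    if done_tokens >= tokens_optm:
        if scaler is None:
            optm.step()
        else:
            scaler.step(optm)
            scaler.update()
        optm.zero_grad(set_to_none=True)
        done_tokens = 0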
Example #11
        tminerr, done_tokens, cur_checkid, remain_steps, _ = train(
            td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel,
            lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm,
            batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint,
            cur_checkid, report_eva, remain_steps, False, False, scaler)
        vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu,
                           use_amp)
        logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" %
                    (tminerr, vloss, vprec))
        save_model(
            mymodel,
            wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec),
            multi_gpu, logger)
        if save_optm_state:
            h5save(
                optimizer.state_dict(), wkdir +
                "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec))
        logger.info("New best model saved")

if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0:
    dss_ws = int(cnfg.dss_ws * ntrain)
    _Dws = {}
    _prev_Dws = {}
    _crit_inc = {}
    if cnfg.dss_rm is not None and cnfg.dss_rm > 0.0 and cnfg.dss_rm < 1.0:
        dss_rm = int(cnfg.dss_rm * ntrain * (1.0 - cnfg.dss_ws))
    else:
        dss_rm = 0
else:
    dss_ws = 0
    dss_rm = 0
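
For concreteness, the dynamic-sampling sizes computed above behave like this with made-up values: given ntrain = 1000 training batches, cnfg.dss_ws = 0.3 and cnfg.dss_rm = 0.1, the window and removal counts come out as shown below; if either setting is missing or outside (0, 1), both fall back to 0.

ntrain = 1000                            # hypothetical number of training batches
dss_ws_ratio, dss_rm_ratio = 0.3, 0.1    # hypothetical cnfg.dss_ws / cnfg.dss_rm
dss_ws = int(dss_ws_ratio * ntrain)                         # 300
dss_rm = int(dss_rm_ratio * ntrain * (1.0 - dss_ws_ratio))  # 70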
Example #12
def train(td,
          tl,
          ed,
          nd,
          optm,
          lrsch,
          model,
          lossf,
          mv_device,
          logger,
          done_tokens,
          multi_gpu,
          tokens_optm=32768,
          nreport=None,
          save_every=None,
          chkpf=None,
          chkpof=None,
          statesf=None,
          num_checkpoint=1,
          cur_checkid=0,
          report_eva=True,
          remain_steps=None,
          save_loss=False,
          save_checkp_epoch=False,
          scaler=None):

    sum_loss = part_loss = 0.0
    sum_wd = part_wd = 0
    _done_tokens, _cur_checkid, _cur_rstep = done_tokens, cur_checkid, remain_steps
    _use_amp, ndata = scaler is not None, len(tl)
    model.train()
    cur_b, _ls = 1, {} if save_loss else None

    global grad_mon, update_angle, enc_layer, log_dyn_p, log_dynb, wkdir
    _log_f_dynbatch = open(wkdir + "dynbatch.log", "ab")
    _log_f_dynbatch.write("ES\n".encode("utf-8"))

    src_grp, tgt_grp = td["src"], td["tgt"]
    for i_d in tqdm(tl):
        seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
        seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
        lo = seq_o.size(1) - 1
        if mv_device:
            seq_batch = seq_batch.to(mv_device)
            seq_o = seq_o.to(mv_device)

        oi = seq_o.narrow(1, 0, lo)
        ot = seq_o.narrow(1, 1, lo).contiguous()
        with autocast(enabled=_use_amp):
            output = model(seq_batch, oi)
            loss = lossf(output, ot)
            if multi_gpu:
                loss = loss.sum()
        loss_add = loss.data.item()

        if scaler is None:
            loss.backward()
        else:
            scaler.scale(loss).backward()

        wd_add = ot.ne(pad_id).int().sum().item()
        loss = output = oi = ot = seq_batch = seq_o = None
        sum_loss += loss_add
        if save_loss:
            _ls[i_d] = loss_add / wd_add
        sum_wd += wd_add
        _done_tokens += wd_add

        if grad_mon.prev_grad is None:
            if log_dynb:
                dyn_sel_ind = grad_mon.sel_ind
                dyn_sel_layer, dyn_sel_enc = dyn_sel_ind % enc_layer, dyn_sel_ind < enc_layer
                _log_f_dynbatch.write(
                    ("%s%d %d\n" % ("E" if dyn_sel_enc else "D", dyn_sel_layer,
                                    wd_add)).encode("utf-8"))
        _perform_dyn_optm_step, _cos_sim_l = grad_mon.update(
            model.module if multi_gpu else model)
        _cos_sim = None if _cos_sim_l is None else _cos_sim_l[0]
        if log_dynb and (_cos_sim_l is not None):
            _log_f_dynbatch.write(
                ("%d %s\n" %
                 (wd_add, " ".join(["%.2f" % (_cu, )
                                    for _cu in _cos_sim_l]))).encode("utf-8"))

        if _perform_dyn_optm_step or (_done_tokens >= tokens_optm):
            if not _perform_dyn_optm_step:
                grad_mon.reset()
            _do_optm_step = True if _cos_sim is None else (
                _cos_sim <= update_angle)
            if _do_optm_step:
                if log_dynb:
                    _log_f_dynbatch.write(
                        ("%d\n" % (_done_tokens, )).encode("utf-8"))
                if multi_gpu:
                    model.collect_gradients()
                optm_step(optm, scaler)
                optm.zero_grad(set_to_none=True)
                if multi_gpu:
                    model.update_replicas()
                lrsch.step()
            else:
                if log_dynb:
                    _log_f_dynbatch.write(
                        ("D %d\n" % (_done_tokens, )).encode("utf-8"))
                if multi_gpu:
                    model.reset_grad()
                else:
                    optm.zero_grad(set_to_none=True)
            log_dynb = random() <= log_dyn_p
            _done_tokens = 0
            if _cur_rstep is not None:
                if save_checkp_epoch and (save_every is not None) and (
                        _cur_rstep % save_every
                        == 0) and (chkpf is not None) and (_cur_rstep > 0):
                    if num_checkpoint > 1:
                        _fend = "_%d.h5" % (_cur_checkid)
                        _chkpf = chkpf[:-3] + _fend
                        if chkpof is not None:
                            _chkpof = chkpof[:-3] + _fend
                        _cur_checkid = (_cur_checkid + 1) % num_checkpoint
                    else:
                        _chkpf = chkpf
                        _chkpof = chkpof
                    save_model(model, _chkpf, multi_gpu, logger)
                    if chkpof is not None:
                        h5save(optm.state_dict(), _chkpof)
                    if statesf is not None:
                        save_states(statesf, tl[cur_b - 1:])
                if _do_optm_step:
                    _cur_rstep -= 1
                    if _cur_rstep <= 0:
                        break

        if nreport is not None:
            part_loss += loss_add
            part_wd += wd_add
            if cur_b % nreport == 0:
                if report_eva:
                    _leva, _eeva = eva(ed, nd, model, lossf, mv_device,
                                       multi_gpu, _use_amp)
                    logger.info(
                        "Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f"
                        % (part_wd, part_loss / part_wd, _leva, _eeva))
                    free_cache(mv_device)
                    model.train()
                else:
                    logger.info("Average loss over %d tokens: %.3f" %
                                (part_wd, part_loss / part_wd))
                logger.info("Dynb: %s" % (" ".join([
                    "%.2f" % (_tmpu, ) for _tmpu in grad_mon.recorder.get_w()
                ]), ))
                part_loss = 0.0
                part_wd = 0

        if save_checkp_epoch and (_cur_rstep is None) and (
                save_every is not None) and (cur_b % save_every == 0) and (
                    chkpf is not None) and (cur_b < ndata):
            if num_checkpoint > 1:
                _fend = "_%d.h5" % (_cur_checkid)
                _chkpf = chkpf[:-3] + _fend
                if chkpof is not None:
                    _chkpof = chkpof[:-3] + _fend
                _cur_checkid = (_cur_checkid + 1) % num_checkpoint
            else:
                _chkpf = chkpf
                _chkpof = chkpof
            save_model(model, _chkpf, multi_gpu, logger)
            if chkpof is not None:
                h5save(optm.state_dict(), _chkpof)
            if statesf is not None:
                save_states(statesf, tl[cur_b - 1:])
        cur_b += 1
    if part_wd != 0.0:
        logger.info("Average loss over %d tokens: %.3f" %
                    (part_wd, part_loss / part_wd))
    logger.info(
        "Dynb: %s" %
        (" ".join(["%.2f" % (_tmpu, )
                   for _tmpu in grad_mon.recorder.get_w()]), ))

    _log_f_dynbatch.close()

    return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
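
grad_mon in Example #12 is defined outside the snippet. The sketch below is an assumption about its core idea: flatten the current gradients, compare them with the previously seen gradients via cosine similarity, and let the loop above decide whether to step the optimizer once the returned similarity falls to or below update_angle. The real monitor also tracks per-layer selection and statistics (sel_ind, recorder) and returns a forced-step flag, all omitted here.

import torch
import torch.nn.functional as F

class GradCosineMonitor:
    def __init__(self):
        self.prev_grad = None

    def update(self, model):
        # flatten all existing gradients into a single vector
        cur = torch.cat([p.grad.detach().flatten()
                         for p in model.parameters() if p.grad is not None])
        cos_sim = None
        if self.prev_grad is not None:
            cos_sim = F.cosine_similarity(cur, self.prev_grad, dim=0).item()
        self.prev_grad = cur.clone()
        # the project version also returns whether to force an optimizer step
        return cos_sim

    def reset(self):
        self.prev_grad = None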
Example #13
def train(td,
          tl,
          ed,
          nd,
          optm,
          lrsch,
          model,
          lossf,
          mv_device,
          logger,
          done_tokens,
          multi_gpu,
          tokens_optm=32768,
          nreport=None,
          save_every=None,
          chkpf=None,
          chkpof=None,
          statesf=None,
          num_checkpoint=1,
          cur_checkid=0,
          report_eva=True,
          remain_steps=None,
          save_loss=False,
          save_checkp_epoch=False,
          use_amp=False):

    sum_loss = 0.0
    sum_wd = 0
    part_loss = 0.0
    part_wd = 0
    _done_tokens = done_tokens
    model.train()
    cur_b = 1
    ndata = len(tl)
    _cur_checkid = cur_checkid
    _cur_rstep = remain_steps
    _ls = {} if save_loss else None
    src_grp, tgt_grp = td["src"], td["tgt"]
    for i_d in tqdm(tl):
        seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
        seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
        lo = seq_o.size(1) - 1
        if mv_device:
            seq_batch = seq_batch.to(mv_device)
            seq_o = seq_o.to(mv_device)

        oi = seq_o.narrow(1, 0, lo)
        ot = seq_o.narrow(1, 1, lo).contiguous()
        output = model(seq_batch, oi)
        loss = lossf(output, ot)
        if multi_gpu:
            loss = loss.sum()
        loss_add = loss.data.item()

        # Scaling the summed loss down by the token count (as suggested in https://mp.weixin.qq.com/s/qAHZ4L5qK3rongCIIq5hQw) is intentionally not applied here.
        #loss /= wd_add
        if use_amp:
            with amp.scale_loss(loss, optm) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        wd_add = ot.ne(0).int().sum().item()
        loss = output = oi = ot = seq_batch = seq_o = None
        sum_loss += loss_add
        if save_loss:
            _ls[i_d] = loss_add / wd_add
        sum_wd += wd_add
        _done_tokens += wd_add

        if _done_tokens >= tokens_optm:
            if multi_gpu:
                model.collect_gradients()
                optm.step()
                optm.zero_grad()
                model.update_replicas()
            else:
                optm.step()
                optm.zero_grad()
            _done_tokens = 0
            if _cur_rstep is not None:
                if save_checkp_epoch and (save_every is not None) and (
                        _cur_rstep % save_every
                        == 0) and (chkpf is not None) and (_cur_rstep > 0):
                    if num_checkpoint > 1:
                        _fend = "_%d.h5" % (_cur_checkid)
                        _chkpf = chkpf[:-3] + _fend
                        if chkpof is not None:
                            _chkpof = chkpof[:-3] + _fend
                        _cur_checkid = (_cur_checkid + 1) % num_checkpoint
                    else:
                        _chkpf = chkpf
                        _chkpof = chkpof
                    save_model(model, _chkpf, multi_gpu, logger)
                    if chkpof is not None:
                        h5save(optm.state_dict(), _chkpof)
                    if statesf is not None:
                        save_states(statesf, tl[cur_b - 1:])
                _cur_rstep -= 1
                if _cur_rstep <= 0:
                    break
            lrsch.step()

        if nreport is not None:
            part_loss += loss_add
            part_wd += wd_add
            if cur_b % nreport == 0:
                if report_eva:
                    _leva, _eeva = eva(ed, nd, model, lossf, mv_device,
                                       multi_gpu)
                    logger.info(
                        "Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f"
                        % (part_wd, part_loss / part_wd, _leva, _eeva))
                    free_cache(mv_device)
                    model.train()
                else:
                    logger.info("Average loss over %d tokens: %.3f" %
                                (part_wd, part_loss / part_wd))
                part_loss = 0.0
                part_wd = 0

        if save_checkp_epoch and (_cur_rstep is None) and (
                save_every is not None) and (cur_b % save_every == 0) and (
                    chkpf is not None) and (cur_b < ndata):
            if num_checkpoint > 1:
                _fend = "_%d.h5" % (_cur_checkid)
                _chkpf = chkpf[:-3] + _fend
                if chkpof is not None:
                    _chkpof = chkpof[:-3] + _fend
                _cur_checkid = (_cur_checkid + 1) % num_checkpoint
            else:
                _chkpf = chkpf
                _chkpof = chkpof
            #save_model(model, _chkpf, isinstance(model, nn.DataParallel), logger)
            save_model(model, _chkpf, multi_gpu, logger)
            if chkpof is not None:
                h5save(optm.state_dict(), _chkpof)
            if statesf is not None:
                save_states(statesf, tl[cur_b - 1:])
        cur_b += 1
    if part_wd != 0.0:
        logger.info("Average loss over %d tokens: %.3f" %
                    (part_wd, part_loss / part_wd))
    return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
Example #14
def torch_to_h5(srcf, rsf, h5args=h5zipargs):
	h5save(torch.load(srcf, map_location='cpu'), rsf, h5args=h5args)
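
h5save and h5load themselves are project utilities that are not part of these snippets. As a rough illustration of the conversion in Example #14, a minimal pair for the special case of a flat list of tensors could look like the following; the real serialization code handles nested containers and more dtypes, so treat this purely as a sketch.

import h5py
import torch

def h5save_list(tensors, fname, h5args=None):
    # write each tensor as a numbered dataset, applying optional storage options
    h5args = {} if h5args is None else h5args
    with h5py.File(fname, "w") as f:
        for i, t in enumerate(tensors):
            f.create_dataset(str(i), data=t.detach().cpu().numpy(), **h5args)

def h5load_list(fname):
    with h5py.File(fname, "r") as f:
        return [torch.from_numpy(f[str(i)][()]) for i in range(len(f))]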
Example #15
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False):

	sum_loss = 0.0
	sum_wd = 0
	part_loss = 0.0
	part_wd = 0
	_done_tokens = done_tokens
	model.train()
	cur_b = 1
	ndata = len(tl)
	_cur_checkid = cur_checkid
	_cur_rstep = remain_steps
	_ls = {} if save_loss else None

	global grad_mon, update_angle

	src_grp, tgt_grp = td["src"], td["tgt"]
	for i_d in tqdm(tl):
		seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
		seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
		lo = seq_o.size(1) - 1
		if mv_device:
			seq_batch = seq_batch.to(mv_device)
			seq_o = seq_o.to(mv_device)

		oi = seq_o.narrow(1, 0, lo)
		ot = seq_o.narrow(1, 1, lo).contiguous()
		output = model(seq_batch, oi)
		loss = lossf(output, ot)
		if multi_gpu:
			loss = loss.sum()
		loss_add = loss.data.item()

		if use_amp:
			with amp.scale_loss(loss, optm) as scaled_loss:
				scaled_loss.backward()
		else:
			loss.backward()

		wd_add = ot.ne(0).int().sum().item()
		loss = output = oi = ot = seq_batch = seq_o = None
		sum_loss += loss_add
		if save_loss:
			_ls[i_d] = loss_add / wd_add
		sum_wd += wd_add
		_done_tokens += wd_add

		_perform_dyn_optm_step, _cos_sim = grad_mon.update(model.module if multi_gpu else model)

		if _perform_dyn_optm_step or (_done_tokens >= tokens_optm):
			if not _perform_dyn_optm_step:
				grad_mon.reset()
			_do_optm_step = True if _cos_sim is None else (_cos_sim <= update_angle)
			if _do_optm_step:
				if multi_gpu:
					model.collect_gradients()
					optm.step()
					optm.zero_grad()
					model.update_replicas()
				else:
					optm.step()
					optm.zero_grad()
				lrsch.step()
			else:
				if multi_gpu:
					#optm.zero_grad()
					model.reset_grad()
				else:
					optm.zero_grad()
			_done_tokens = 0
			if _cur_rstep is not None:
				if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
					if num_checkpoint > 1:
						_fend = "_%d.h5" % (_cur_checkid)
						_chkpf = chkpf[:-3] + _fend
						if chkpof is not None:
							_chkpof = chkpof[:-3] + _fend
						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
					else:
						_chkpf = chkpf
						_chkpof = chkpof
					save_model(model, _chkpf, multi_gpu, logger)
					if chkpof is not None:
						h5save(optm.state_dict(), _chkpof)
					if statesf is not None:
						save_states(statesf, tl[cur_b - 1:])
				if _do_optm_step:
					_cur_rstep -= 1
					if _cur_rstep <= 0:
						break

		if nreport is not None:
			part_loss += loss_add
			part_wd += wd_add
			if cur_b % nreport == 0:
				if report_eva:
					_leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu)
					logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva))
					free_cache(mv_device)
					model.train()
				else:
					logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
				part_loss = 0.0
				part_wd = 0

		if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata):
			if num_checkpoint > 1:
				_fend = "_%d.h5" % (_cur_checkid)
				_chkpf = chkpf[:-3] + _fend
				if chkpof is not None:
					_chkpof = chkpof[:-3] + _fend
				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
			else:
				_chkpf = chkpf
				_chkpof = chkpof
			save_model(model, _chkpf, multi_gpu, logger)
			if chkpof is not None:
				h5save(optm.state_dict(), _chkpof)
			if statesf is not None:
				save_states(statesf, tl[cur_b - 1:])
		cur_b += 1
	if part_wd != 0.0:
		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))

	return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls