Example #1
def train(td,
          tl,
          ed,
          nd,
          optm,
          lrsch,
          model,
          lossf,
          mv_device,
          logger,
          done_tokens,
          multi_gpu,
          tokens_optm=32768,
          nreport=None,
          save_every=None,
          chkpf=None,
          chkpof=None,
          statesf=None,
          num_checkpoint=1,
          cur_checkid=0,
          report_eva=True,
          remain_steps=None,
          save_loss=False,
          save_checkp_epoch=False,
          scaler=None):

    sum_loss = part_loss = 0.0
    sum_wd = part_wd = 0
    _done_tokens, _cur_checkid, _cur_rstep = done_tokens, cur_checkid, remain_steps
    _use_amp, ndata = scaler is not None, len(tl)
    model.train()
    cur_b, _ls = 1, {} if save_loss else None
    src_grp, mt_grp, tgt_grp = td["src"], td["mt"], td["tgt"]
    for i_d in tqdm(tl):
        seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
        seq_mt = torch.from_numpy(mt_grp[i_d][:]).long()
        seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
        lo = seq_o.size(1) - 1
        if mv_device:
            seq_batch = seq_batch.to(mv_device)
            seq_mt = seq_mt.to(mv_device)
            seq_o = seq_o.to(mv_device)

        oi = seq_o.narrow(1, 0, lo)
        ot = seq_o.narrow(1, 1, lo).contiguous()
        with autocast(enabled=_use_amp):
            output = model(seq_batch, seq_mt, oi)
            loss = lossf(output, ot)
            if multi_gpu:
                loss = loss.sum()
        loss_add = loss.data.item()

        if scaler is None:
            loss.backward()
        else:
            scaler.scale(loss).backward()

        wd_add = ot.ne(pad_id).int().sum().item()
        loss = output = oi = ot = seq_batch = seq_o = None
        sum_loss += loss_add
        if save_loss:
            _ls[i_d] = loss_add / wd_add  # record per-batch average token loss
        sum_wd += wd_add
        _done_tokens += wd_add

        if _done_tokens >= tokens_optm:
            if multi_gpu:
                model.collect_gradients()
            optm_step(optm, scaler)
            optm.zero_grad(set_to_none=True)
            if multi_gpu:
                model.update_replicas()
            _done_tokens = 0
            if _cur_rstep is not None:
                if save_checkp_epoch and (save_every is not None) and (
                        _cur_rstep % save_every
                        == 0) and (chkpf is not None) and (_cur_rstep > 0):
                    if num_checkpoint > 1:
                        _fend = "_%d.h5" % (_cur_checkid)
                        _chkpf = chkpf[:-3] + _fend
                        if chkpof is not None:
                            _chkpof = chkpof[:-3] + _fend
                        _cur_checkid = (_cur_checkid + 1) % num_checkpoint
                    else:
                        _chkpf = chkpf
                        _chkpof = chkpof
                    save_model(model, _chkpf, multi_gpu, logger)
                    if chkpof is not None:
                        h5save(optm.state_dict(), _chkpof)
                    if statesf is not None:
                        save_states(statesf, tl[cur_b - 1:])
                _cur_rstep -= 1
                if _cur_rstep <= 0:
                    break
            lrsch.step()

        if nreport is not None:
            part_loss += loss_add
            part_wd += wd_add
            if cur_b % nreport == 0:
                if report_eva:
                    _leva, _eeva = eva(ed, nd, model, lossf, mv_device,
                                       multi_gpu, _use_amp)
                    logger.info(
                        "Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f"
                        % (part_wd, part_loss / part_wd, _leva, _eeva))
                    free_cache(mv_device)
                    model.train()
                else:
                    logger.info("Average loss over %d tokens: %.3f" %
                                (part_wd, part_loss / part_wd))
                part_loss = 0.0
                part_wd = 0

        if save_checkp_epoch and (_cur_rstep is None) and (
                save_every is not None) and (cur_b % save_every == 0) and (
                    chkpf is not None) and (cur_b < ndata):
            if num_checkpoint > 1:
                _fend = "_%d.h5" % (_cur_checkid)
                _chkpf = chkpf[:-3] + _fend
                if chkpof is not None:
                    _chkpof = chkpof[:-3] + _fend
                _cur_checkid = (_cur_checkid + 1) % num_checkpoint
            else:
                _chkpf = chkpf
                _chkpof = chkpof
            save_model(model, _chkpf, multi_gpu, logger)
            if chkpof is not None:
                h5save(optm.state_dict(), _chkpof)
            if statesf is not None:
                save_states(statesf, tl[cur_b - 1:])
        cur_b += 1
    if part_wd != 0.0:
        logger.info("Average loss over %d tokens: %.3f" %
                    (part_wd, part_loss / part_wd))
    return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
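
This variant steps the optimizer through an `optm_step` helper so the same code path covers plain FP32 training and mixed precision with a `torch.cuda.amp.GradScaler`. The helper is not part of the listing; a minimal sketch of what it is assumed to do:

# Minimal sketch of the optm_step helper used above (an assumption; the real
# helper lives elsewhere in the repository): step the optimizer directly when
# no GradScaler is given, otherwise step through the scaler and update it.
def optm_step(optm, scaler=None):
    if scaler is None:
        optm.step()
    else:
        scaler.step(optm)
        scaler.update()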
Example #2
def train(td,
          tl,
          ed,
          nd,
          optm,
          lrsch,
          model,
          lossf,
          mv_device,
          logger,
          done_tokens,
          multi_gpu,
          tokens_optm=32768,
          nreport=None,
          save_every=None,
          chkpf=None,
          chkpof=None,
          statesf=None,
          num_checkpoint=1,
          cur_checkid=0,
          report_eva=True,
          remain_steps=None,
          save_loss=False,
          save_checkp_epoch=False,
          use_amp=False):

    sum_loss = 0.0
    sum_wd = 0
    part_loss = 0.0
    part_wd = 0
    _done_tokens = done_tokens
    model.train()
    cur_b = 1
    ndata = len(tl)
    _cur_checkid = cur_checkid
    _cur_rstep = remain_steps
    _ls = {} if save_loss else None
    src_grp, tgt_grp = td["src"], td["tgt"]
    for i_d in tqdm(tl):
        seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
        seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
        lo = seq_o.size(1) - 1
        if mv_device:
            seq_batch = seq_batch.to(mv_device)
            seq_o = seq_o.to(mv_device)

        oi = seq_o.narrow(1, 0, lo)
        ot = seq_o.narrow(1, 1, lo).contiguous()
        output = model(seq_batch, oi)
        loss = lossf(output, ot)
        if multi_gpu:
            loss = loss.sum()
        loss_add = loss.data.item()

        # Scaling the summed loss down by the number of tokens, as advised by
        # https://mp.weixin.qq.com/s/qAHZ4L5qK3rongCIIq5hQw, does not seem
        # reasonable here, so it is left disabled.
        #loss /= wd_add
        if use_amp:
            with amp.scale_loss(loss, optm) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        wd_add = ot.ne(0).int().sum().item()  # count non-padding target tokens (padding index assumed to be 0)
        loss = output = oi = ot = seq_batch = seq_o = None
        sum_loss += loss_add
        if save_loss:
            _ls[i_d] = loss_add / wd_add  # record per-batch average token loss
        sum_wd += wd_add
        _done_tokens += wd_add

        if _done_tokens >= tokens_optm:
            if multi_gpu:
                model.collect_gradients()
                optm.step()
                optm.zero_grad()
                model.update_replicas()
            else:
                optm.step()
                optm.zero_grad()
            _done_tokens = 0
            if _cur_rstep is not None:
                if save_checkp_epoch and (save_every is not None) and (
                        _cur_rstep % save_every
                        == 0) and (chkpf is not None) and (_cur_rstep > 0):
                    if num_checkpoint > 1:
                        _fend = "_%d.h5" % (_cur_checkid)
                        _chkpf = chkpf[:-3] + _fend
                        if chkpof is not None:
                            _chkpof = chkpof[:-3] + _fend
                        _cur_checkid = (_cur_checkid + 1) % num_checkpoint
                    else:
                        _chkpf = chkpf
                        _chkpof = chkpof
                    save_model(model, _chkpf, multi_gpu, logger)
                    if chkpof is not None:
                        h5save(optm.state_dict(), _chkpof)
                    if statesf is not None:
                        save_states(statesf, tl[cur_b - 1:])
                _cur_rstep -= 1
                if _cur_rstep <= 0:
                    break
            lrsch.step()

        if nreport is not None:
            part_loss += loss_add
            part_wd += wd_add
            if cur_b % nreport == 0:
                if report_eva:
                    _leva, _eeva = eva(ed, nd, model, lossf, mv_device,
                                       multi_gpu)
                    logger.info(
                        "Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f"
                        % (part_wd, part_loss / part_wd, _leva, _eeva))
                    free_cache(mv_device)
                    model.train()
                else:
                    logger.info("Average loss over %d tokens: %.3f" %
                                (part_wd, part_loss / part_wd))
                part_loss = 0.0
                part_wd = 0

        if save_checkp_epoch and (_cur_rstep is None) and (
                save_every is not None) and (cur_b % save_every == 0) and (
                    chkpf is not None) and (cur_b < ndata):
            if num_checkpoint > 1:
                _fend = "_%d.h5" % (_cur_checkid)
                _chkpf = chkpf[:-3] + _fend
                if chkpof is not None:
                    _chkpof = chkpof[:-3] + _fend
                _cur_checkid = (_cur_checkid + 1) % num_checkpoint
            else:
                _chkpf = chkpf
                _chkpof = chkpof
            #save_model(model, _chkpf, isinstance(model, nn.DataParallel), logger)
            save_model(model, _chkpf, multi_gpu, logger)
            if chkpof is not None:
                h5save(optm.state_dict(), _chkpof)
            if statesf is not None:
                save_states(statesf, tl[cur_b - 1:])
        cur_b += 1
    if part_wd != 0.0:
        logger.info("Average loss over %d tokens: %.3f" %
                    (part_wd, part_loss / part_wd))
    return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
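
This older variant relies on NVIDIA apex for mixed precision (`amp.scale_loss`) rather than `torch.cuda.amp`. A minimal sketch of the setup such a loop assumes before `train` is called; the `opt_level` value is illustrative:

# Sketch of the apex AMP setup assumed by the use_amp branch above: the model
# and optimizer are wrapped by amp.initialize() once, and amp.scale_loss()
# then produces a scaled loss for backward inside the training loop.
from apex import amp

model, optm = amp.initialize(model, optm, opt_level="O1")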
Example #3
def train(td,
          tl,
          ed,
          nd,
          optm,
          lrsch,
          model,
          lossf,
          mv_device,
          logger,
          done_tokens,
          multi_gpu,
          tokens_optm=32768,
          nreport=None,
          save_every=None,
          chkpf=None,
          chkpof=None,
          statesf=None,
          num_checkpoint=1,
          cur_checkid=0,
          report_eva=True,
          remain_steps=None,
          save_loss=False,
          save_checkp_epoch=False,
          scaler=None):

    sum_loss = part_loss = 0.0
    sum_wd = part_wd = 0
    _done_tokens, _cur_checkid, _cur_rstep = done_tokens, cur_checkid, remain_steps
    _use_amp, ndata = scaler is not None, len(tl)
    model.train()
    cur_b, _ls = 1, {} if save_loss else None

    global grad_mon, update_angle, enc_layer, log_dyn_p, log_dynb, wkdir
    _log_f_dynbatch = open(wkdir + "dynbatch.log", "ab")
    _log_f_dynbatch.write("ES\n".encode("utf-8"))

    src_grp, tgt_grp = td["src"], td["tgt"]
    for i_d in tqdm(tl):
        seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
        seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
        lo = seq_o.size(1) - 1
        if mv_device:
            seq_batch = seq_batch.to(mv_device)
            seq_o = seq_o.to(mv_device)

        oi = seq_o.narrow(1, 0, lo)
        ot = seq_o.narrow(1, 1, lo).contiguous()
        with autocast(enabled=_use_amp):
            output = model(seq_batch, oi)
            loss = lossf(output, ot)
            if multi_gpu:
                loss = loss.sum()
        loss_add = loss.data.item()

        if scaler is None:
            loss.backward()
        else:
            scaler.scale(loss).backward()

        wd_add = ot.ne(pad_id).int().sum().item()
        loss = output = oi = ot = seq_batch = seq_o = None
        sum_loss += loss_add
        if save_loss:
            _ls[i_d] = loss_add / wd_add  # record per-batch average token loss
        sum_wd += wd_add
        _done_tokens += wd_add

        if grad_mon.prev_grad is None:
            if log_dynb:
                dyn_sel_ind = grad_mon.sel_ind
                dyn_sel_layer, dyn_sel_enc = dyn_sel_ind % enc_layer, dyn_sel_ind < enc_layer
                _log_f_dynbatch.write(
                    ("%s%d %d\n" % ("E" if dyn_sel_enc else "D", dyn_sel_layer,
                                    wd_add)).encode("utf-8"))
        _perform_dyn_optm_step, _cos_sim_l = grad_mon.update(
            model.module if multi_gpu else model)
        _cos_sim = None if _cos_sim_l is None else _cos_sim_l[0]
        if log_dynb and (_cos_sim_l is not None):
            _log_f_dynbatch.write(
                ("%d %s\n" %
                 (wd_add, " ".join(["%.2f" % (_cu, )
                                    for _cu in _cos_sim_l]))).encode("utf-8"))

        if _perform_dyn_optm_step or (_done_tokens >= tokens_optm):
            if not _perform_dyn_optm_step:
                grad_mon.reset()
            _do_optm_step = True if _cos_sim is None else (
                _cos_sim <= update_angle)
            if _do_optm_step:
                if log_dynb:
                    _log_f_dynbatch.write(
                        ("%d\n" % (_done_tokens, )).encode("utf-8"))
                if multi_gpu:
                    model.collect_gradients()
                optm_step(optm, scaler)
                optm.zero_grad(set_to_none=True)
                if multi_gpu:
                    model.update_replicas()
                lrsch.step()
            else:
                if log_dynb:
                    _log_f_dynbatch.write(
                        ("D %d\n" % (_done_tokens, )).encode("utf-8"))
                if multi_gpu:
                    model.reset_grad()
                else:
                    optm.zero_grad(set_to_none=True)
            log_dynb = random() <= log_dyn_p
            _done_tokens = 0
            if _cur_rstep is not None:
                if save_checkp_epoch and (save_every is not None) and (
                        _cur_rstep % save_every
                        == 0) and (chkpf is not None) and (_cur_rstep > 0):
                    if num_checkpoint > 1:
                        _fend = "_%d.h5" % (_cur_checkid)
                        _chkpf = chkpf[:-3] + _fend
                        if chkpof is not None:
                            _chkpof = chkpof[:-3] + _fend
                        _cur_checkid = (_cur_checkid + 1) % num_checkpoint
                    else:
                        _chkpf = chkpf
                        _chkpof = chkpof
                    save_model(model, _chkpf, multi_gpu, logger)
                    if chkpof is not None:
                        h5save(optm.state_dict(), _chkpof)
                    if statesf is not None:
                        save_states(statesf, tl[cur_b - 1:])
                if _do_optm_step:
                    _cur_rstep -= 1
                    if _cur_rstep <= 0:
                        break

        if nreport is not None:
            part_loss += loss_add
            part_wd += wd_add
            if cur_b % nreport == 0:
                if report_eva:
                    _leva, _eeva = eva(ed, nd, model, lossf, mv_device,
                                       multi_gpu, _use_amp)
                    logger.info(
                        "Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f"
                        % (part_wd, part_loss / part_wd, _leva, _eeva))
                    free_cache(mv_device)
                    model.train()
                else:
                    logger.info("Average loss over %d tokens: %.3f" %
                                (part_wd, part_loss / part_wd))
                logger.info("Dynb: %s" % (" ".join([
                    "%.2f" % (_tmpu, ) for _tmpu in grad_mon.recorder.get_w()
                ]), ))
                part_loss = 0.0
                part_wd = 0

        if save_checkp_epoch and (_cur_rstep is None) and (
                save_every is not None) and (cur_b % save_every == 0) and (
                    chkpf is not None) and (cur_b < ndata):
            if num_checkpoint > 1:
                _fend = "_%d.h5" % (_cur_checkid)
                _chkpf = chkpf[:-3] + _fend
                if chkpof is not None:
                    _chkpof = chkpof[:-3] + _fend
                _cur_checkid = (_cur_checkid + 1) % num_checkpoint
            else:
                _chkpf = chkpf
                _chkpof = chkpof
            save_model(model, _chkpf, multi_gpu, logger)
            if chkpof is not None:
                h5save(optm.state_dict(), _chkpof)
            if statesf is not None:
                save_states(statesf, tl[cur_b - 1:])
        cur_b += 1
    if part_wd != 0.0:
        logger.info("Average loss over %d tokens: %.3f" %
                    (part_wd, part_loss / part_wd))
    logger.info(
        "Dynb: %s" %
        (" ".join(["%.2f" % (_tmpu, )
                   for _tmpu in grad_mon.recorder.get_w()]), ))

    _log_f_dynbatch.close()

    return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
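
This variant adds a gradient monitor (`grad_mon`) for dynamic batch sizing: an optimizer step is only applied when the monitor requests one or when the cosine similarity between accumulated gradients stays within `update_angle`. The monitor is defined elsewhere in the repository; a hypothetical stub of the interface the loop relies on (names chosen to match the calls above; the `recorder.get_w()` logging hook is omitted):

# Hypothetical stub of the grad_mon interface assumed by the loop above.
# update() inspects the module's current gradients, compares them with the
# previously cached gradient, and returns (perform_step, cos_sim_list);
# cos_sim_list is None when there is no previous gradient to compare against.
class GradientMonitorStub:

    def __init__(self):
        self.prev_grad = None
        self.sel_ind = 0

    def update(self, module):
        # cache the current gradients and compute cosine similarities here
        return False, None

    def reset(self):
        # forget the cached gradient so accumulation starts fresh
        self.prev_grad = None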
Example #4
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False):

	sum_loss = 0.0
	sum_wd = 0
	part_loss = 0.0
	part_wd = 0
	_done_tokens = done_tokens
	model.train()
	cur_b = 1
	ndata = len(tl)
	_cur_checkid = cur_checkid
	_cur_rstep = remain_steps
	_ls = {} if save_loss else None

	global grad_mon, update_angle

	src_grp, tgt_grp = td["src"], td["tgt"]
	for i_d in tqdm(tl):
		seq_batch = torch.from_numpy(src_grp[i_d][:]).long()
		seq_o = torch.from_numpy(tgt_grp[i_d][:]).long()
		lo = seq_o.size(1) - 1
		if mv_device:
			seq_batch = seq_batch.to(mv_device)
			seq_o = seq_o.to(mv_device)

		oi = seq_o.narrow(1, 0, lo)
		ot = seq_o.narrow(1, 1, lo).contiguous()
		output = model(seq_batch, oi)
		loss = lossf(output, ot)
		if multi_gpu:
			loss = loss.sum()
		loss_add = loss.data.item()

		if use_amp:
			with amp.scale_loss(loss, optm) as scaled_loss:
				scaled_loss.backward()
		else:
			loss.backward()

		wd_add = ot.ne(0).int().sum().item()  # count non-padding target tokens (padding index assumed to be 0)
		loss = output = oi = ot = seq_batch = seq_o = None
		sum_loss += loss_add
		if save_loss:
			_ls[i_d] = loss_add / wd_add  # record per-batch average token loss
		sum_wd += wd_add
		_done_tokens += wd_add

		_perform_dyn_optm_step, _cos_sim = grad_mon.update(model.module if multi_gpu else model)

		if _perform_dyn_optm_step or (_done_tokens >= tokens_optm):
			if not _perform_dyn_optm_step:
				grad_mon.reset()
			_do_optm_step = True if _cos_sim is None else (_cos_sim <= update_angle)
			if _do_optm_step:
				if multi_gpu:
					model.collect_gradients()
					optm.step()
					optm.zero_grad()
					model.update_replicas()
				else:
					optm.step()
					optm.zero_grad()
				lrsch.step()
			else:
				if multi_gpu:
					#optm.zero_grad()
					model.reset_grad()
				else:
					optm.zero_grad()
			_done_tokens = 0
			if _cur_rstep is not None:
				if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
					if num_checkpoint > 1:
						_fend = "_%d.h5" % (_cur_checkid)
						_chkpf = chkpf[:-3] + _fend
						if chkpof is not None:
							_chkpof = chkpof[:-3] + _fend
						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
					else:
						_chkpf = chkpf
						_chkpof = chkpof
					save_model(model, _chkpf, multi_gpu, logger)
					if chkpof is not None:
						h5save(optm.state_dict(), _chkpof)
					if statesf is not None:
						save_states(statesf, tl[cur_b - 1:])
				if _do_optm_step:
					_cur_rstep -= 1
					if _cur_rstep <= 0:
						break

		if nreport is not None:
			part_loss += loss_add
			part_wd += wd_add
			if cur_b % nreport == 0:
				if report_eva:
					_leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu)
					logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva))
					free_cache(mv_device)
					model.train()
				else:
					logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))
				part_loss = 0.0
				part_wd = 0

		if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata):
			if num_checkpoint > 1:
				_fend = "_%d.h5" % (_cur_checkid)
				_chkpf = chkpf[:-3] + _fend
				if chkpof is not None:
					_chkpof = chkpof[:-3] + _fend
				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
			else:
				_chkpf = chkpf
				_chkpof = chkpof
			save_model(model, _chkpf, multi_gpu, logger)
			if chkpof is not None:
				h5save(optm.state_dict(), _chkpof)
			if statesf is not None:
				save_states(statesf, tl[cur_b - 1:])
		cur_b += 1
	if part_wd != 0.0:
		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd))

	return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
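
Across all four variants, `train` returns the running state (average token loss, `done_tokens`, `cur_checkid`, `remain_steps`, and the optional per-batch loss map) so the caller can thread it through successive epochs. A minimal sketch of such an epoch driver, assuming the data handles (`td`, `tl`, `ed`, `nd`), helpers, and hyperparameters (`training_steps`, `max_epochs`, `tokens_optm`, `nreport`, `cuda_device`) are set up as in the surrounding repository:

# Sketch of an epoch driver that threads the returned state back into train(),
# so gradient accumulation, checkpoint rotation and the remaining step budget
# are not reset at epoch boundaries. Names such as training_steps, max_epochs
# and cuda_device are illustrative assumptions.
done_tokens, cur_checkid, remain_steps = 0, 0, training_steps
for epoch in range(1, max_epochs + 1):
    avg_loss, done_tokens, cur_checkid, remain_steps, _ = train(
        td, tl, ed, nd, optm, lrsch, model, lossf, cuda_device, logger,
        done_tokens, multi_gpu, tokens_optm=tokens_optm, nreport=nreport,
        cur_checkid=cur_checkid, remain_steps=remain_steps)
    logger.info("Epoch %d done, average training loss %.3f" % (epoch, avg_loss))
    if (remain_steps is not None) and (remain_steps <= 0):
        break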