def train_epoch(model=None, dataloader=None, optim=None, losses=None,
                device=torch.device("cpu"), tt=q.ticktock("-"),
                current_epoch=0, max_epochs=0, _train_batch=train_batch,
                on_start=tuple(), on_end=tuple(), print_every_batch=False):
    """
    Run one training epoch: feed every batch of ``dataloader`` to ``_train_batch``.

    Before the first batch, every loss wrapper pushes its aggregate to history
    (tagged with ``current_epoch - 1``), is reset, and has its wrapped ``.loss``
    moved to ``device``; the model is moved to ``device`` as well.

    :param model: model to train; moved to ``device`` and passed to ``q.epoch_reset``
    :param dataloader: iterable of batches; must support ``len()``
    :param optim: optimizer, forwarded untouched to ``_train_batch``
    :param losses: list of loss wrappers
    :param device: device to put the model and losses on (also forwarded to ``_train_batch``)
    :param tt: progress reporter providing ``msg``, ``live`` and ``stoplive``
               (the default instance is created once, at definition time)
    :param current_epoch: index of this epoch, forwarded to ``_train_batch``
    :param max_epochs: total number of epochs, forwarded to ``_train_batch``
    :param _train_batch: train batch function, default is train_batch
    :param on_start: zero-argument callables invoked before the epoch reset
    :param on_end: zero-argument callables invoked after the last batch
    :param print_every_batch: if true, report each batch with ``tt.msg`` instead of ``tt.live``
    :return: the result of ``q.pp_epoch_losses(*losses)``
    """
    # Close out the previous epoch's statistics and prepare the wrappers.
    for wrapper in losses:
        wrapper.push_epoch_to_history(epoch=current_epoch - 1)
        wrapper.reset_agg()
        wrapper.loss.to(device)

    model.to(device)

    for hook in on_start:
        hook()

    q.epoch_reset(model)

    for batch_idx, batch_data in enumerate(dataloader):
        progress = _train_batch(
            batch=batch_data,
            model=model,
            optim=optim,
            losses=losses,
            device=device,
            batch_number=batch_idx,
            max_batches=len(dataloader),
            current_epoch=current_epoch,
            max_epochs=max_epochs,
        )
        report = tt.msg if print_every_batch else tt.live
        report(progress)

    tt.stoplive()

    for hook in on_end:
        hook()

    return q.pp_epoch_losses(*losses)
def adv_train_epoch(model=None, dataloader=None, optim=None, losses=None,
                    advmodel=None, advdataloader=None, advoptim=None, advlosses=None,
                    device=torch.device("cpu"), tt=q.ticktock(" -"),
                    current_epoch=0, max_epochs=0,
                    _train_batch=q.train_batch, _adv_train_batch=q.train_batch,
                    on_start=tuple(), on_end=tuple(), print_every_batch=False,
                    advsteps=1):
    """
    Performs an epoch of adversarial training: for every batch of the main
    dataloader, first runs ``advsteps`` adversarial updates, then one main update.

    NOTE(review): another ``adv_train_epoch`` with a different signature is defined
    further down in this file; if both live at module level, that later definition
    rebinds the name. Confirm which one callers are meant to get.

    :param model: main model
    :param dataloader: main dataloader; one main update per batch, must support ``len()``
    :param optim: optimizer for the main model
    :param losses: list of loss wrappers for the main model
    :param advmodel: adversarial model
    :param advdataloader: dataloader for adversarial updates; iterated continuously
                          across main batches and restarted when exhausted
    :param advoptim: optimizer for the adversarial model
    :param advlosses: list of loss wrappers for the adversarial model
    :param device: device to put models and losses on
    :param tt: progress reporter providing ``msg``, ``live`` and ``stoplive``
    :param current_epoch: index of this epoch
    :param max_epochs: total number of epochs (forwarded to the main update only)
    :param _train_batch: train batch function for the main update
    :param _adv_train_batch: train batch function for the adversarial updates
    :param on_start: zero-argument callables invoked before the epoch reset
    :param on_end: zero-argument callables invoked after the last batch
    :param print_every_batch: if true, report with ``tt.msg`` instead of ``tt.live``
    :param advsteps: number of adversarial updates per main batch
    :return: string with the pretty-printed main and adversarial epoch losses
    """
    for loss in losses + advlosses:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()
        loss.loss.to(device)

    model.to(device)
    advmodel.to(device)

    for hook in on_start:
        hook()

    q.epoch_reset(model)
    q.epoch_reset(advmodel)

    # Create the adversarial iterator once for the whole epoch. Re-creating it for
    # every main batch would restart the adversarial loader each time, making the
    # wrap-around handling below effectively dead.
    adviter = iter(advdataloader)

    for i, _batch in enumerate(dataloader):
        for j in range(advsteps):
            try:
                _advbatch = next(adviter)
            except StopIteration:
                # Adversarial data exhausted: start another pass over it.
                adviter = iter(advdataloader)
                _advbatch = next(adviter)
            # max_batches and max_epochs are deliberately passed as 0 here,
            # exactly as in the original call.
            ttmsg = _adv_train_batch(batch=_advbatch, model=advmodel, optim=advoptim,
                                     losses=advlosses, device=device,
                                     batch_number=j, max_batches=0,
                                     current_epoch=current_epoch, max_epochs=0)
            ttmsg = f"adv: {ttmsg}"
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)

        ttmsg = _train_batch(batch=_batch, model=model, optim=optim, losses=losses,
                             device=device, batch_number=i,
                             max_batches=len(dataloader),
                             current_epoch=current_epoch, max_epochs=max_epochs)
        ttmsg = f"main: {ttmsg}"
        if print_every_batch:
            tt.msg(ttmsg)
        else:
            tt.live(ttmsg)

    tt.stoplive()

    for hook in on_end:
        hook()

    ttmsg = q.pp_epoch_losses(*losses)
    advttmsg = q.pp_epoch_losses(*advlosses)
    ttmsg = f"\n main: {ttmsg}\n adv: {advttmsg}"
    return ttmsg
def train_epoch_distill(model=None, dataloader=None, optim=None, losses=None,
                        device=torch.device("cpu"), tt=q.ticktock("-"),
                        current_epoch=0, max_epochs=0,
                        _train_batch=train_batch_distill,
                        on_start=tuple(), on_end=tuple(),
                        run=False, mbase=None, goldgetter=None):
    """
    Run one distillation training epoch: feed every batch of ``dataloader`` to
    ``_train_batch`` together with ``mbase`` and ``goldgetter``.

    :param model: model being trained; passed to ``q.epoch_reset``
    :param dataloader: iterable of batches; must support ``len()``
    :param optim: optimizer, forwarded to ``_train_batch``
    :param losses: list of loss wrappers
    :param device: device, forwarded to ``_train_batch``
    :param tt: progress reporter providing ``live`` and ``stoplive``
    :param current_epoch: index of this epoch
    :param max_epochs: total number of epochs
    :param _train_batch: train batch function, default is train_batch_distill
    :param on_start: zero-argument callables invoked before the epoch reset
    :param on_end: zero-argument callables invoked after the last batch
    :param run: accepted but never read here; ``_train_batch`` is always called with ``run=True``
    :param mbase: optional second model; epoch-reset when not None and forwarded to ``_train_batch``
    :param goldgetter: forwarded to ``_train_batch``
    :return: the result of ``q.pp_epoch_losses(*losses)``
    """
    # Close out the previous epoch's statistics on every loss wrapper.
    for wrapper in losses:
        wrapper.push_epoch_to_history(epoch=current_epoch - 1)
        wrapper.reset_agg()

    for hook in on_start:
        hook()

    q.epoch_reset(model)
    if mbase is not None:
        q.epoch_reset(mbase)

    for batch_idx, batch_data in enumerate(dataloader):
        progress = _train_batch(
            batch=batch_data,
            model=model,
            optim=optim,
            losses=losses,
            device=device,
            batch_number=batch_idx,
            max_batches=len(dataloader),
            current_epoch=current_epoch,
            max_epochs=max_epochs,
            run=True,
            mbase=mbase,
            goldgetter=goldgetter,
        )
        tt.live(progress)

    tt.stoplive()

    for hook in on_end:
        hook()

    return q.pp_epoch_losses(*losses)
def test_epoch(model=None, dataloader=None, losses=None, device=torch.device("cpu"),
               current_epoch=0, max_epochs=0, print_every_batch=False,
               on_start=tuple(), on_start_batch=tuple(), on_end_batch=tuple(),
               on_end=tuple()):
    """
    Performs a test epoch: runs the model on every batch with gradients disabled
    and feeds the outputs to the given losses.

    The model is switched to eval mode and is not switched back afterwards.

    :param model: model to evaluate; called as ``model(*batch_in)``
    :param dataloader: iterable of batches; must support ``len()``
    :param losses: list of loss wrappers; each is called as
                   ``loss(modelouts, gold, _numex=numex)``
    :param device: device that tensors in the batch and the wrapped losses are moved to
    :param current_epoch: index of this epoch (used in the progress message only)
    :param max_epochs: total number of epochs (used in the progress message only)
    :param print_every_batch: if true, report each batch with ``tt.msg`` instead of ``tt.live``
    :param on_start: zero-argument callables invoked once before the first batch
    :param on_start_batch: zero-argument callables invoked before every batch
    :param on_end_batch: zero-argument callables invoked after every batch
    :param on_end: zero-argument callables invoked once after the last batch
    :return: the result of ``q.pp_epoch_losses(*losses)``
    """
    def _to_device(x):
        # Only tensors are moved; any other batch element is passed through as is.
        return x.to(device) if isinstance(x, torch.Tensor) else x

    tt = q.ticktock("-")
    model.eval()
    q.epoch_reset(model)

    for hook in on_start:
        hook()

    with torch.no_grad():
        for loss_obj in losses:
            loss_obj.push_epoch_to_history()
            loss_obj.reset_agg()
            loss_obj.loss.to(device)

        for i, _batch in enumerate(dataloader):
            for hook in on_start_batch:
                hook()

            # Normalize a single-element batch to a one-tuple, then move tensors.
            batch = _batch if q.issequence(_batch) else (_batch,)
            batch = q.recmap(batch, _to_device)

            # Number of examples = first dimension of the first batch element.
            numex = batch[0].size(0)

            if q.no_gold(losses):
                batch_in = batch
                gold = None
            else:
                # The last batch element is the gold; the rest are model inputs.
                batch_in = batch[:-1]
                gold = batch[-1]

            q.batch_reset(model)
            modelouts = model(*batch_in)

            # The wrappers are called for their effect on the aggregates reported
            # by q.pp_epoch_losses; the per-batch return values are not needed.
            for loss_obj in losses:
                loss_obj(modelouts, gold, _numex=numex)

            ttmsg = "test - Epoch {}/{} - [{}/{}]: {}".format(
                current_epoch + 1, max_epochs, i + 1, len(dataloader),
                q.pp_epoch_losses(*losses))
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)

            for hook in on_end_batch:
                hook()

    tt.stoplive()

    for hook in on_end:
        hook()

    return q.pp_epoch_losses(*losses)
def adv_train_epoch(main_model=None, adv_model=None,
                    main_dataloader=None, adv_dataloader=None,
                    main_optim=None, adv_optim=None,
                    main_losses=None, adv_losses=None,
                    adviters=1,
                    device=torch.device("cpu"), tt=q.ticktock(" -"),
                    current_epoch=0, max_epochs=0,
                    _main_train_batch=q.train_batch, _adv_train_batch=q.train_batch,
                    on_start=tuple(), on_end=tuple(), print_every_batch=False):
    """
    Performs an epoch of adversarial training: for every batch of the main
    dataloader, first runs ``adviters`` adversarial updates, then one main update.

    The adversarial dataloader is iterated continuously across main batches and
    restarted whenever it is exhausted.

    :param main_model: main model
    :param adv_model: adversarial model
    :param main_dataloader: main dataloader; one main update per batch, must support ``len()``
    :param adv_dataloader: dataloader for adversarial updates; must support ``len()``
    :param main_optim: optimizer for the main model
    :param adv_optim: optimizer for the adversarial model
    :param main_losses: list of loss wrappers for the main model
    :param adv_losses: list of loss wrappers for the adversarial model
    :param adviters: number of adversarial updates per main batch
    :param device: device to put models and losses on
    :param tt: progress reporter providing ``msg``, ``live`` and ``stoplive``
    :param current_epoch: index of this epoch
    :param max_epochs: total number of epochs
    :param _main_train_batch: train batch function for the main update
    :param _adv_train_batch: train batch function for the adversarial updates
    :param on_start: zero-argument callables invoked before the epoch reset
    :param on_end: zero-argument callables invoked after the last batch
    :param print_every_batch: if true, report with ``tt.msg`` instead of ``tt.live``
    :return: string with the pretty-printed main and adversarial epoch losses
    :raises ValueError: if ``adv_dataloader`` yields no batches while an
                        adversarial update is required
    """
    for loss in main_losses + adv_losses:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()
        # NOTE(review): the other epoch functions in this file call
        # loss.loss.to(device); confirm the wrappers themselves expose .to().
        loss.to(device)

    main_model.to(device)
    adv_model.to(device)

    for hook in on_start:
        hook()

    q.epoch_reset(main_model)
    q.epoch_reset(adv_model)

    adv_optim.zero_grad()
    main_optim.zero_grad()

    adv_dl_iter = iter(adv_dataloader)
    k = 0   # batch number within the current pass over adv_dataloader

    for i, main_batch in enumerate(main_dataloader):
        # do 'adviters' of adversarial updates
        remaining = adviters
        while remaining > 0:
            # Only next() is guarded, so a StopIteration raised inside
            # _adv_train_batch is not mistaken for loader exhaustion.
            try:
                adv_batch = next(adv_dl_iter)
            except StopIteration:
                adv_dl_iter = iter(adv_dataloader)
                k = 0
                try:
                    adv_batch = next(adv_dl_iter)
                except StopIteration:
                    # A freshly restarted loader with nothing to yield would
                    # otherwise make this loop spin forever.
                    raise ValueError("adv_dataloader yielded no batches") from None

            adv_optim.zero_grad()
            ttmsg = _adv_train_batch(batch=adv_batch, model=adv_model, optim=adv_optim,
                                     losses=adv_losses, device=device,
                                     batch_number=k, max_batches=len(adv_dataloader),
                                     current_epoch=current_epoch, max_epochs=max_epochs)
            ttmsg = "adv: " + ttmsg
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)
            remaining -= 1
            k += 1

        # do main update
        main_optim.zero_grad()
        ttmsg = _main_train_batch(batch=main_batch, model=main_model, optim=main_optim,
                                  losses=main_losses, device=device,
                                  batch_number=i, max_batches=len(main_dataloader),
                                  current_epoch=current_epoch, max_epochs=max_epochs)
        ttmsg = "main: " + ttmsg
        if print_every_batch:
            tt.msg(ttmsg)
        else:
            tt.live(ttmsg)

    tt.stoplive()

    for hook in on_end:
        hook()

    ttmsg = "main: " + q.pp_epoch_losses(*main_losses) \
            + " -- adv: " + q.pp_epoch_losses(*adv_losses)
    return ttmsg