def eval_loop(model, dataloader, device=torch.device("cpu")):
    """Run ``model`` over every batch of ``dataloader`` without gradients and collect its outputs.

    Each batch is wrapped into a 1-tuple if it is not already a sequence, every
    ``torch.Tensor`` inside it is moved to ``device``, and the batch is passed
    positionally to ``model``. Outputs are gathered per output position and
    concatenated along dimension 0 at the end.

    :param model: torch.nn.Module to evaluate; it is switched to eval mode.
    :param dataloader: iterable of batches; ``len(dataloader)`` is used for progress display.
    :param device: device that tensors in each batch are moved to.
    :return: list with one tensor per model output position, each being ``torch.cat``
             (dim 0) of that output over all batches; an empty list if no batches were seen.
    """
    outer_timer = q.ticktock("testing")
    outer_timer.tick("testing")
    progress = q.ticktock("-")
    num_batches = len(dataloader)

    model.eval()
    epoch_reset(model)

    def _to_device(x):
        # only tensors are moved; any other leaf is passed through untouched
        return x.to(device) if isinstance(x, torch.Tensor) else x

    collected = []
    with torch.no_grad():
        for batch_idx, raw_batch in enumerate(dataloader):
            inputs = raw_batch if q.issequence(raw_batch) else (raw_batch, )
            inputs = q.recmap(inputs, _to_device)
            batch_reset(model)
            outputs = model(*inputs)
            progress.live("eval - [{}/{}]".format(batch_idx + 1, num_batches))
            if not q.issequence(outputs):
                outputs = (outputs, )
            # lazily create one accumulator list per output position
            if not collected:
                collected = [[] for _ in outputs]
            for bucket, piece in zip(collected, outputs):
                bucket.append(piece)

    progress.stoplive()
    progress.tock("eval done")
    outer_timer.tock("tested")
    return [torch.cat(bucket, 0) for bucket in collected]
def train_batch_distill(batch=None, model=None, optim=None, losses=None, device=torch.device("cpu"),
                        batch_number=-1, max_batches=0, current_epoch=0, max_epochs=0,
                        on_start=tuple(), on_before_optim_step=tuple(), on_after_optim_step=tuple(),
                        on_end=tuple(), run=False, mbase=None, goldgetter=None):
    """
    Runs a single batch of SGD for distillation on the provided batch and settings.

    The student ``model`` is trained against soft targets ("softgold") produced either by
    ``goldgetter`` (applied to the gold) or, if ``goldgetter`` is None, by the teacher
    ``mbase`` (applied to the inputs without gradient tracking). Each loss wrapper is
    called as ``loss_obj(modelouts, (softgold, gold))``; only the first produced loss
    value is back-propagated.

    :param batch: batch to run on; the last element is the gold, all preceding elements
                  are model inputs. A non-sequence batch is wrapped into a 1-tuple.
    :param model: torch.nn.Module of the (student) model
    :param optim: torch optimizer
    :param losses: list of losswrappers
    :param device: device that tensors in the batch are moved to
    :param batch_number: which batch (0-based; displayed as +1)
    :param max_batches: total number of batches (display only)
    :param current_epoch: current epoch (0-based; displayed as +1)
    :param max_epochs: total number of epochs (display only)
    :param on_start: collection of functions to call when starting training batch
    :param on_before_optim_step: collection of functions for before optimization step is taken (gradclip)
    :param on_after_optim_step: collection of functions for after optimization step is taken
    :param on_end: collection of functions to call when batch is done
    :param run: unused; kept only so existing callers passing it keep working
    :param mbase: base (teacher) model to distill from. Takes the inputs and produces output
                  distributions to be matched by the student model. Not used if
                  ``goldgetter`` is specified.
    :param goldgetter: takes the gold and produces a softgold
    :return: progress message string for this batch
    :raises q.SumTingWongException: if both ``goldgetter`` and ``mbase`` are None
    """
    # Fail fast: validate before any hook runs or any model/optimizer state is touched.
    if goldgetter is None and mbase is None:
        raise q.SumTingWongException(
            "goldgetter and mbase can not both be None")

    for hook in on_start:
        hook()
    optim.zero_grad()
    model.train()

    batch = (batch, ) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
    batch_in = batch[:-1]
    gold = batch[-1]

    # obtain soft targets: from goldgetter if given, otherwise from the teacher model
    if goldgetter is not None:
        softgold = goldgetter(gold)
    else:
        mbase.eval()
        q.batch_reset(mbase)
        with torch.no_grad():
            softgold = mbase(*batch_in)

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, (softgold, gold))
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)
    # only the first loss value drives the gradient; the rest are tracked by the wrappers
    cost = trainlosses[0]
    cost.backward()

    for hook in on_before_optim_step:
        hook()
    optim.step()
    for hook in on_after_optim_step:
        hook()

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )
    for hook in on_end:
        hook()
    return ttmsg
def test_epoch(model=None, dataloader=None, losses=None, device=torch.device("cpu"),
               current_epoch=0, max_epochs=0, print_every_batch=False,
               on_start=tuple(), on_start_batch=tuple(), on_end_batch=tuple(), on_end=tuple()):
    """
    Performs one test (evaluation) epoch over ``dataloader`` without gradient tracking.

    Before iterating, each loss wrapper gets ``push_epoch_to_history()`` and
    ``reset_agg()`` called, and its ``.loss`` is moved to ``device``. Each batch is
    then moved to ``device``, split into model inputs and gold (unless
    ``q.no_gold(losses)``, in which case the whole batch is input and gold is None),
    run through ``model``, and passed to every loss wrapper.

    :param model: torch.nn.Module to evaluate; switched to eval mode
    :param dataloader: iterable of batches; ``len(dataloader)`` is used in progress messages
    :param losses: list of loss wrappers
    :param device: device that batch tensors and each wrapper's ``.loss`` are moved to
    :param current_epoch: current epoch (0-based; displayed as +1)
    :param max_epochs: total number of epochs (display only)
    :param print_every_batch: if True, emit a message per batch instead of updating a live line
    :param on_start: callables invoked (no arguments) before the epoch starts
    :param on_start_batch: callables invoked (no arguments) at the start of each batch
    :param on_end_batch: callables invoked (no arguments) at the end of each batch
    :param on_end: callables invoked (no arguments) after the epoch
    :return: string summarizing the epoch's losses, as produced by ``q.pp_epoch_losses``
    """
    tt = q.ticktock("-")
    model.eval()
    q.epoch_reset(model)
    for hook in on_start:
        hook()

    with torch.no_grad():
        for loss_obj in losses:
            loss_obj.push_epoch_to_history()
            loss_obj.reset_agg()
            loss_obj.loss.to(device)

        for i, batch in enumerate(dataloader):
            for hook in on_start_batch:
                hook()
            batch = (batch, ) if not q.issequence(batch) else batch
            batch = q.recmap(
                batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
            # assumes the first batch element is a tensor whose dim 0 is the example count
            numex = batch[0].size(0)
            if q.no_gold(losses):
                batch_in, gold = batch, None
            else:
                batch_in, gold = batch[:-1], batch[-1]

            q.batch_reset(model)
            modelouts = model(*batch_in)
            # Return values are not needed here; the wrappers presumably aggregate the
            # batch losses internally (cf. reset_agg / pp_epoch_losses) — TODO confirm.
            for loss_obj in losses:
                loss_obj(modelouts, gold, _numex=numex)

            ttmsg = "test - Epoch {}/{} - [{}/{}]: {}".format(
                current_epoch + 1,
                max_epochs,
                i + 1,
                len(dataloader),
                q.pp_epoch_losses(*losses))
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)
            for hook in on_end_batch:
                hook()

    tt.stoplive()
    for hook in on_end:
        hook()
    return q.pp_epoch_losses(*losses)
def train_batch(batch=None, model=None, optim=None, losses=None, device=torch.device("cpu"),
                batch_number=-1, max_batches=0, current_epoch=0, max_epochs=0,
                on_start=tuple(), on_before_optim_step=tuple(), on_after_optim_step=tuple(),
                on_end=tuple()):
    """
    Runs a single batch of SGD on provided batch and settings.

    The batch is moved to ``device`` and split into model inputs and gold (unless
    ``q.no_gold(losses)``, in which case the whole batch is input and gold is None).
    Every loss wrapper is evaluated. The first produced loss value, plus the first value
    of every wrapper whose ``.loss`` is a ``q.loss.PenaltyGetter``, is back-propagated.

    :param batch: batch to run on; a non-sequence batch is wrapped into a 1-tuple
    :param model: torch.nn.Module of the model
    :param optim: torch optimizer
    :param losses: list of losswrappers
    :param device: device that tensors in the batch are moved to
    :param batch_number: which batch (0-based; displayed as +1)
    :param max_batches: total number of batches (display only)
    :param current_epoch: current epoch (0-based; displayed as +1)
    :param max_epochs: total number of epochs (display only)
    :param on_start: collection of functions to call when starting training batch
    :param on_before_optim_step: collection of functions for before optimization step is taken (gradclip)
    :param on_after_optim_step: collection of functions for after optimization step is taken
    :param on_end: collection of functions to call when batch is done
    :return: progress message string for this batch
    """
    for hook in on_start:
        hook()
    optim.zero_grad()
    model.train()

    batch = (batch, ) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
    # assumes the first batch element is a tensor whose dim 0 is the example count
    numex = batch[0].size(0)
    if q.no_gold(losses):
        batch_in, gold = batch, None
    else:
        batch_in, gold = batch[:-1], batch[-1]

    q.batch_reset(model)
    modelouts = model(*batch_in)

    # Keep each wrapper paired with its own values: a wrapper may return several values,
    # so the flattened list cannot be zipped back against `losses` by position.
    per_loss_vals = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, gold, _numex=numex)
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        per_loss_vals.append((loss_obj, loss_val))
    trainlosses = [val for _, vals in per_loss_vals for val in vals]
    cost = trainlosses[0]

    # penalties: add the primary value of every PenaltyGetter-based wrapper
    # NOTE(review): if losses[0] is itself a PenaltyGetter, its value is counted twice
    # (once as `cost`, once as a penalty), as in the previous implementation — confirm intent.
    penalties = 0
    for loss_obj, loss_vals in per_loss_vals:
        if isinstance(loss_obj.loss, q.loss.PenaltyGetter) and len(loss_vals) > 0:
            penalties += loss_vals[0]
    cost = cost + penalties

    if torch.isnan(cost).any():
        print("Cost is NaN!")
        # Debug hook: `embed` presumably comes from IPython at file level; it opens an
        # interactive shell and will block non-interactive runs — TODO confirm.
        embed()
    cost.backward()

    for hook in on_before_optim_step:
        hook()
    optim.step()
    for hook in on_after_optim_step:
        hook()

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )
    for hook in on_end:
        hook()
    return ttmsg
def to(self, device):
    """Move held states to ``device`` and return ``self`` so calls can be chained.

    Calls ``.to(device)`` on every element of ``self.states``; the return values are
    discarded, so those elements presumably move in place (TODO confirm). Then
    ``self.batched_states`` is replaced by the result of ``q.recmap`` applying
    ``.to(device)`` to its contents.
    """
    for st in self.states:
        st.to(device)

    def _move(leaf):
        return leaf.to(device)

    self.batched_states = q.recmap(self.batched_states, _move)
    return self
def to(self, device):
    """Move ``self.nn_states`` and every tensor-valued instance attribute to ``device``.

    ``self.nn_states`` is replaced by the result of ``q.recmap`` applying ``.to(device)``
    to its contents. Every instance attribute holding a ``torch.Tensor`` is then
    reassigned to its ``.to(device)`` result.

    :param device: target device
    :return: ``self``, so calls can be chained
    """
    self.nn_states = q.recmap(self.nn_states, lambda x: x.to(device))
    # Iterate over a snapshot: setattr() writes back into self.__dict__, and mutating
    # a dict while iterating its live view is fragile (a __setattr__ that adds keys
    # would raise RuntimeError mid-iteration).
    for name, value in list(vars(self).items()):
        if isinstance(value, torch.Tensor):
            setattr(self, name, value.to(device))
    return self