Example #1
import torch
import qelos as q


def eval_loop(model, dataloader, device=torch.device("cpu")):
    tto = q.ticktock("testing")
    tto.tick("testing")
    tt = q.ticktock("-")
    totaltestbats = len(dataloader)
    model.eval()
    q.epoch_reset(model)    # reset epoch-level model state, as in the examples below
    outs = []
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            batch = (batch, ) if not q.issequence(batch) else batch
            batch = q.recmap(
                batch, lambda x: x.to(device)
                if isinstance(x, torch.Tensor) else x)

            q.batch_reset(model)    # reset per-batch model state
            modelouts = model(*batch)

            tt.live("eval - [{}/{}]".format(i + 1, totaltestbats))
            if not q.issequence(modelouts):
                modelouts = (modelouts, )
            if len(outs) == 0:
                # lazily create one accumulator list per model output
                outs = [[] for _ in modelouts]
            for out_e, mout_e in zip(outs, modelouts):
                out_e.append(mout_e)
    ttmsg = "eval done"
    tt.stoplive()
    tt.tock(ttmsg)
    tto.tock("tested")
    # concatenate each output's batches along the batch dimension
    ret = [torch.cat(out_e, 0) for out_e in outs]
    return ret
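
A minimal usage sketch. The toy model and synthetic data here are hypothetical, purely for illustration, and it assumes q.epoch_reset and q.batch_reset are no-ops on modules without such state:

import torch
import qelos as q
from torch.utils.data import DataLoader, TensorDataset

class ToyModel(torch.nn.Module):          # hypothetical stand-in model
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.lin(x)

dl = DataLoader(TensorDataset(torch.randn(100, 4)), batch_size=16)
outs = eval_loop(ToyModel(), dl)
print(outs[0].shape)                      # torch.Size([100, 2]): all batches concatenated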
Example #2
def train_batch_distill(batch=None,
                        model=None,
                        optim=None,
                        losses=None,
                        device=torch.device("cpu"),
                        batch_number=-1,
                        max_batches=0,
                        current_epoch=0,
                        max_epochs=0,
                        on_start=tuple(),
                        on_before_optim_step=tuple(),
                        on_after_optim_step=tuple(),
                        on_end=tuple(),
                        run=False,
                        mbase=None,
                        goldgetter=None):
    """
    Runs a single batch of SGD on provided batch and settings.
    :param _batch:  batch to run on
    :param model:   torch.nn.Module of the model
    :param optim:       torch optimizer
    :param losses:      list of losswrappers
    :param device:      device
    :param batch_number:    which batch
    :param max_batches:     total number of batches
    :param current_epoch:   current epoch
    :param max_epochs:      total number of epochs
    :param on_start:        collection of functions to call when starting training batch
    :param on_before_optim_step:    collection of functions for before optimization step is taken (gradclip)
    :param on_after_optim_step:     collection of functions for after optimization step is taken
    :param on_end:              collection of functions to call when batch is done
    :param mbase:           base model where to distill from. takes inputs and produces output distributions to match by student model. if goldgetter is specified, this is not used.
    :param goldgetter:      takes the gold and produces a softgold
    :return:
    """
    # if run is False:
    #     kwargs = locals().copy()
    #     return partial(train_batch, **kwargs)

    [e() for e in on_start]
    optim.zero_grad()
    model.train()

    batch = (batch, ) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)

    batch_in = batch[:-1]
    gold = batch[-1]

    # run batch_in through teacher model to get teacher output distributions
    if goldgetter is not None:
        softgold = goldgetter(gold)
    elif mbase is not None:
        mbase.eval()
        q.batch_reset(mbase)
        with torch.no_grad():
            softgold = mbase(*batch_in)
    else:
        raise q.SumTingWongException(
            "goldgetter and mbase cannot both be None")

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, (softgold, gold))
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)

    # the first loss is the training objective; any others are tracked as metrics
    cost = trainlosses[0]
    cost.backward()

    [e() for e in on_before_optim_step]
    optim.step()
    [e() for e in on_after_optim_step]

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )

    [e() for e in on_end]
    return ttmsg
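
The heart of the distillation step, isolated as a self-contained sketch: a teacher in the role of mbase produces soft targets under no_grad, and a KL term stands in for whatever the loss wrappers in losses actually compute (those wrappers are not shown in this example):

import torch
import torch.nn.functional as F

teacher = torch.nn.Linear(4, 3)    # stands in for mbase; assume it is already trained
student = torch.nn.Linear(4, 3)
optim = torch.optim.SGD(student.parameters(), lr=0.1)

x = torch.randn(8, 4)
teacher.eval()
with torch.no_grad():              # same role as the mbase branch above
    softgold = teacher(x)

optim.zero_grad()
out = student(x)
# match the teacher's output distribution; a real loss wrapper may also mix in the hard gold
loss = F.kl_div(F.log_softmax(out, -1), F.softmax(softgold, -1), reduction="batchmean")
loss.backward()
optim.step()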
Example #3
def test_epoch(model=None,
               dataloader=None,
               losses=None,
               device=torch.device("cpu"),
               current_epoch=0,
               max_epochs=0,
               print_every_batch=False,
               on_start=tuple(),
               on_start_batch=tuple(),
               on_end_batch=tuple(),
               on_end=tuple()):
    """
    Performs a test epoch. If run=True, runs, otherwise returns partially filled function.
    :param model:
    :param dataloader:
    :param losses:
    :param device:
    :param current_epoch:
    :param max_epochs:
    :param on_start:
    :param on_start_batch:
    :param on_end_batch:
    :param on_end:
    :return:
    """
    tt = q.ticktock("-")
    model.eval()
    q.epoch_reset(model)
    [e() for e in on_start]
    with torch.no_grad():
        for loss_obj in losses:
            # archive the previous epoch's aggregate and start fresh for this one
            loss_obj.push_epoch_to_history()
            loss_obj.reset_agg()
            loss_obj.loss.to(device)
        for i, _batch in enumerate(dataloader):
            [e() for e in on_start_batch]

            batch = (_batch, ) if not q.issequence(_batch) else _batch
            batch = q.recmap(
                batch, lambda x: x.to(device)
                if isinstance(x, torch.Tensor) else x)
            numex = batch[0].size(0)    # number of examples in this batch

            if q.no_gold(losses):
                batch_in = batch
                gold = None
            else:
                batch_in = batch[:-1]
                gold = batch[-1]

            q.batch_reset(model)
            modelouts = model(*batch_in)

            testlosses = []
            for loss_obj in losses:
                loss_val = loss_obj(modelouts, gold, _numex=numex)
                loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
                testlosses.extend(loss_val)

            ttmsg = "test - Epoch {}/{} - [{}/{}]: {}".format(
                current_epoch + 1, max_epochs, i + 1, len(dataloader),
                q.pp_epoch_losses(*losses))
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)
            [e() for e in on_end_batch]
    tt.stoplive()
    [e() for e in on_end]
    ttmsg = q.pp_epoch_losses(*losses)
    return ttmsg
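
For readers outside this codebase, the same evaluation pattern in plain PyTorch, as a minimal sketch: eval mode, no_grad, and per-batch losses weighted by the batch size (presumably what the _numex argument feeds into the loss wrappers' aggregation):

import torch

def simple_test_epoch(model, dataloader, loss_fn, device=torch.device("cpu")):
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for x, gold in dataloader:
            x, gold = x.to(device), gold.to(device)
            numex = x.size(0)                            # cf. _numex above
            total += loss_fn(model(x), gold).item() * numex
            n += numex
    return total / n                                     # epoch average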
Example #4
def train_batch(batch=None,
                model=None,
                optim=None,
                losses=None,
                device=torch.device("cpu"),
                batch_number=-1,
                max_batches=0,
                current_epoch=0,
                max_epochs=0,
                on_start=tuple(),
                on_before_optim_step=tuple(),
                on_after_optim_step=tuple(),
                on_end=tuple()):
    """
    Runs a single batch of SGD on provided batch and settings.
    :param batch:  batch to run on
    :param model:   torch.nn.Module of the model
    :param optim:       torch optimizer
    :param losses:      list of losswrappers
    :param device:      device
    :param batch_number:    which batch
    :param max_batches:     total number of batches
    :param current_epoch:   current epoch
    :param max_epochs:      total number of epochs
    :param on_start:        collection of functions to call when starting training batch
    :param on_before_optim_step:    collection of functions for before optimization step is taken (gradclip)
    :param on_after_optim_step:     collection of functions for after optimization step is taken
    :param on_end:              collection of functions to call when batch is done
    :return:
    """
    [e() for e in on_start]
    optim.zero_grad()
    model.train()

    batch = (batch, ) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
    numex = batch[0].size(0)

    if q.no_gold(losses):
        batch_in = batch
        gold = None
    else:
        batch_in = batch[:-1]
        gold = batch[-1]

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, gold, _numex=numex)
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)

    cost = trainlosses[0]
    # add penalty losses (regularizers) on top of the main training objective
    penalties = 0
    for loss_obj, trainloss in zip(losses, trainlosses):
        if isinstance(loss_obj.loss, q.loss.PenaltyGetter):
            penalties += trainloss
    cost = cost + penalties

    if torch.isnan(cost).any():
        print("Cost is NaN!")
        embed()    # drops into an interactive shell for debugging; needs `from IPython import embed`

    cost.backward()

    [e() for e in on_before_optim_step]
    optim.step()
    [e() for e in on_after_optim_step]

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )

    [e() for e in on_end]
    return ttmsg
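
The hook lists make it easy to slot in gradient clipping, which the docstring hints at for on_before_optim_step. A minimal sketch of the callback pattern itself, with a hypothetical toy model and data:

import torch

model = torch.nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# the kind of zero-argument callback train_batch expects in on_before_optim_step
clip = lambda: torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)

x, gold = torch.randn(8, 4), torch.randn(8, 2)
optim.zero_grad()
cost = torch.nn.functional.mse_loss(model(x), gold)
cost.backward()
clip()         # train_batch would run this via [e() for e in on_before_optim_step]
optim.step()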
Example #5
    def to(self, device):
        for state in self.states:
            state.to(device)    # assumes each state moves itself in place (return value is discarded)
        self.batched_states = q.recmap(self.batched_states,
                                       lambda x: x.to(device))
        return self
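
Because the return value of state.to(device) is discarded in the loop above, each state object must move itself in place; a hypothetical state class satisfying that contract:

import torch

class DecoderState:    # hypothetical; just illustrates the in-place contract
    def __init__(self, h):
        self.h = h

    def to(self, device):
        self.h = self.h.to(device)    # mutate in place instead of returning a moved copy
        return self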
Example #6
    def to(self, device):
        self.nn_states = q.recmap(self.nn_states, lambda x: x.to(device))
        # also sweep along any plain tensor attributes on this object
        for k, v in self.__dict__.items():
            if isinstance(v, torch.Tensor):
                setattr(self, k, v.to(device))
        return self
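
The __dict__ scan generalizes to any object with tensor attributes; a quick check of the pattern on a hypothetical holder object:

import torch

class Holder:
    def to(self, device):
        for k, v in self.__dict__.items():
            if isinstance(v, torch.Tensor):
                setattr(self, k, v.to(device))
        return self

h = Holder()
h.mask = torch.ones(3)
h.name = "not a tensor"    # non-tensor attributes are left untouched
h.to(torch.device("cpu"))
print(h.mask.device)       # cpu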