Example #1
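All snippets below reference torch and a utility module q, plus embed() in Example #6; a hedged guess at the required imports (the ticktock / SumTingWongException names suggest the qelos library):

# assumed imports for the snippets below (not part of the scraped examples)
import torch
import qelos as q          # assumption: "q" is the qelos utility library
from IPython import embed  # only needed for the NaN check in Example #6
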
def train_epoch(model=None,
                dataloader=None,
                optim=None,
                losses=None,
                device=torch.device("cpu"),
                tt=q.ticktock("-"),
                current_epoch=0,
                max_epochs=0,
                _train_batch=train_batch,
                on_start=tuple(),
                on_end=tuple(),
                print_every_batch=False):
    """
    Performs one epoch of training on the given model, with data from the given dataloader, using the given
    optimizer, with the loss computed from the given losses.
    :param model:       torch.nn.Module to train
    :param dataloader:  dataloader yielding training batches
    :param optim:       torch optimizer
    :param losses:      list of loss wrappers
    :param device:      device to put the model, losses and batches on
    :param tt:          ticktock object used for progress printing
    :param current_epoch:   current epoch number (zero-based)
    :param max_epochs:      total number of epochs
    :param _train_batch:    train batch function, default is train_batch
    :param on_start:    collection of functions called before the epoch starts
    :param on_end:      collection of functions called after the epoch ends
    :param print_every_batch:   if True, print a separate line per batch instead of a live-updating line
    :return:    a string summarizing the epoch's losses
    """
    for loss in losses:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()
        loss.loss.to(device)

    model.to(device)

    [e() for e in on_start]

    q.epoch_reset(model)

    for i, _batch in enumerate(dataloader):
        ttmsg = _train_batch(batch=_batch,
                             model=model,
                             optim=optim,
                             losses=losses,
                             device=device,
                             batch_number=i,
                             max_batches=len(dataloader),
                             current_epoch=current_epoch,
                             max_epochs=max_epochs)
        if print_every_batch:
            tt.msg(ttmsg)
        else:
            tt.live(ttmsg)

    tt.stoplive()
    [e() for e in on_end]
    ttmsg = q.pp_epoch_losses(*losses)
    return ttmsg
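
A minimal driver sketch (not part of the original code) showing how train_epoch and test_epoch (Example #5) could be combined; model, traindl, testdl, optim, trainlosses and testlosses are placeholder objects you construct yourself:

# hypothetical multi-epoch driver; all names below are placeholders
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
max_epochs = 10
for epoch in range(max_epochs):
    trainmsg = train_epoch(model=model, dataloader=traindl, optim=optim,
                           losses=trainlosses, device=device,
                           current_epoch=epoch, max_epochs=max_epochs)
    testmsg = test_epoch(model=model, dataloader=testdl, losses=testlosses,
                         device=device, current_epoch=epoch,
                         max_epochs=max_epochs)
    print("epoch {}: {} | {}".format(epoch + 1, trainmsg, testmsg))
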
Example #2
def adv_train_epoch(model=None,
                    dataloader=None,
                    optim=None,
                    losses=None,
                    advmodel=None,
                    advdataloader=None,
                    advoptim=None,
                    advlosses=None,
                    device=torch.device("cpu"),
                    tt=q.ticktock(" -"),
                    current_epoch=0,
                    max_epochs=0,
                    _train_batch=q.train_batch,
                    _adv_train_batch=q.train_batch,
                    on_start=tuple(),
                    on_end=tuple(),
                    print_every_batch=False,
                    advsteps=1):
    """
    Performs one epoch of adversarial training: for every batch from the main dataloader, 'advsteps' update
    steps are taken on the adversarial model using batches from the adversarial dataloader, followed by one
    update step on the main model.
    :param model:       main torch.nn.Module to train
    :param dataloader:  dataloader yielding main training batches
    :param optim:       torch optimizer for the main model
    :param losses:      list of loss wrappers for the main model
    :param advmodel:    adversarial torch.nn.Module to train
    :param advdataloader:   dataloader yielding adversarial training batches
    :param advoptim:    torch optimizer for the adversarial model
    :param advlosses:   list of loss wrappers for the adversarial model
    :param device:      device to put the models, losses and batches on
    :param tt:          ticktock object used for progress printing
    :param current_epoch:   current epoch number (zero-based)
    :param max_epochs:      total number of epochs
    :param _train_batch:        train batch function for the main model, default is q.train_batch
    :param _adv_train_batch:    train batch function for the adversarial model, default is q.train_batch
    :param on_start:    collection of functions called before the epoch starts
    :param on_end:      collection of functions called after the epoch ends
    :param print_every_batch:   if True, print a separate line per batch instead of a live-updating line
    :param advsteps:    number of adversarial update steps per main batch
    :return:    a string summarizing the epoch's main and adversarial losses
    """
    for loss in losses + advlosses:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()
        loss.loss.to(device)

    model.to(device)
    advmodel.to(device)

    [e() for e in on_start]

    q.epoch_reset(model)
    q.epoch_reset(advmodel)

    for i, _batch in enumerate(dataloader):
        # take 'advsteps' adversarial update steps before every main update
        adviter = iter(advdataloader)
        for j in range(advsteps):
            try:
                _advbatch = next(adviter)
            except StopIteration:
                # adversarial dataloader exhausted: restart it
                adviter = iter(advdataloader)
                _advbatch = next(adviter)
            ttmsg = _adv_train_batch(batch=_advbatch,
                                     model=advmodel,
                                     optim=advoptim,
                                     losses=advlosses,
                                     device=device,
                                     batch_number=j,
                                     max_batches=0,
                                     current_epoch=current_epoch,
                                     max_epochs=0)
            ttmsg = f"adv:  {ttmsg}"
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)
        ttmsg = _train_batch(batch=_batch,
                             model=model,
                             optim=optim,
                             losses=losses,
                             device=device,
                             batch_number=i,
                             max_batches=len(dataloader),
                             current_epoch=current_epoch,
                             max_epochs=max_epochs)
        ttmsg = f"main: {ttmsg}"
        if print_every_batch:
            tt.msg(ttmsg)
        else:
            tt.live(ttmsg)

    tt.stoplive()
    [e() for e in on_end]
    ttmsg = q.pp_epoch_losses(*losses)
    advttmsg = q.pp_epoch_losses(*advlosses)
    ttmsg = f"\n main: {ttmsg}\n adv:  {advttmsg}"
    return ttmsg
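
A hedged call sketch for the adversarial epoch above; the main_* and adv_* objects, device, epoch and max_epochs are placeholders. Each main batch is preceded by advsteps adversarial update steps:

# hypothetical call; all names are placeholders
msg = adv_train_epoch(model=main_model, dataloader=main_dl, optim=main_optim,
                      losses=main_losses,
                      advmodel=adv_model, advdataloader=adv_dl,
                      advoptim=adv_optim, advlosses=adv_losses,
                      device=device, current_epoch=epoch,
                      max_epochs=max_epochs, advsteps=3)
print(msg)
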
Example #3
def train_epoch_distill(model=None,
                        dataloader=None,
                        optim=None,
                        losses=None,
                        device=torch.device("cpu"),
                        tt=q.ticktock("-"),
                        current_epoch=0,
                        max_epochs=0,
                        _train_batch=train_batch_distill,
                        on_start=tuple(),
                        on_end=tuple(),
                        run=False,
                        mbase=None,
                        goldgetter=None):
    """
    Performs one epoch of distillation training on the given (student) model, with data from the given
    dataloader, using the given optimizer, with the loss computed from the given losses against soft targets
    produced by the teacher model (mbase) or by the goldgetter.
    :param model:       student torch.nn.Module to train
    :param dataloader:  dataloader yielding training batches
    :param optim:       torch optimizer
    :param losses:      list of loss wrappers
    :param device:      device to put batches on
    :param tt:          ticktock object used for progress printing
    :param current_epoch:   current epoch number (zero-based)
    :param max_epochs:      total number of epochs
    :param _train_batch:    train batch function, default is train_batch_distill
    :param on_start:    collection of functions called before the epoch starts
    :param on_end:      collection of functions called after the epoch ends
    :param mbase:       teacher model to distill from; not used if goldgetter is given
    :param goldgetter:  takes the gold and produces a soft gold to distill from
    :return:    a string summarizing the epoch's losses
    """
    # if run is False:
    #     kwargs = locals().copy()
    #     return partial(train_epoch, **kwargs)

    for loss in losses:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()

    [e() for e in on_start]

    q.epoch_reset(model)
    if mbase is not None:
        q.epoch_reset(mbase)

    for i, _batch in enumerate(dataloader):
        ttmsg = _train_batch(batch=_batch,
                             model=model,
                             optim=optim,
                             losses=losses,
                             device=device,
                             batch_number=i,
                             max_batches=len(dataloader),
                             current_epoch=current_epoch,
                             max_epochs=max_epochs,
                             run=True,
                             mbase=mbase,
                             goldgetter=goldgetter)
        tt.live(ttmsg)

    tt.stoplive()
    [e() for e in on_end]
    ttmsg = q.pp_epoch_losses(*losses)
    return ttmsg
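
A hedged distillation sketch; student, teacher, dl, optim, distill_losses, device, epoch and max_epochs are placeholders. The teacher is passed as mbase and is run under eval() / no_grad by train_batch_distill (Example #4) to produce the soft targets:

# hypothetical distillation call; all names are placeholders
msg = train_epoch_distill(model=student, dataloader=dl, optim=optim,
                          losses=distill_losses, device=device,
                          current_epoch=epoch, max_epochs=max_epochs,
                          mbase=teacher)
print(msg)
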
Example #4
def train_batch_distill(batch=None,
                        model=None,
                        optim=None,
                        losses=None,
                        device=torch.device("cpu"),
                        batch_number=-1,
                        max_batches=0,
                        current_epoch=0,
                        max_epochs=0,
                        on_start=tuple(),
                        on_before_optim_step=tuple(),
                        on_after_optim_step=tuple(),
                        on_end=tuple(),
                        run=False,
                        mbase=None,
                        goldgetter=None):
    """
    Runs a single batch of SGD on provided batch and settings.
    :param batch:   batch to run on
    :param model:   torch.nn.Module of the model
    :param optim:       torch optimizer
    :param losses:      list of loss wrappers
    :param device:      device
    :param batch_number:    which batch
    :param max_batches:     total number of batches
    :param current_epoch:   current epoch
    :param max_epochs:      total number of epochs
    :param on_start:        collection of functions to call when starting training batch
    :param on_before_optim_step:    collection of functions for before optimization step is taken (gradclip)
    :param on_after_optim_step:     collection of functions for after optimization step is taken
    :param on_end:              collection of functions to call when batch is done
    :param mbase:           teacher model to distill from: takes the inputs and produces the output distributions the student model should match; not used if goldgetter is specified
    :param goldgetter:      takes the gold and produces a soft gold to use as the distillation target
    :return:
    """
    # if run is False:
    #     kwargs = locals().copy()
    #     return partial(train_batch, **kwargs)

    [e() for e in on_start]
    optim.zero_grad()
    model.train()

    batch = (batch, ) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)

    batch_in = batch[:-1]
    gold = batch[-1]

    # run batch_in through teacher model to get teacher output distributions
    if goldgetter is not None:
        softgold = goldgetter(gold)
    elif mbase is not None:
        mbase.eval()
        q.batch_reset(mbase)
        with torch.no_grad():
            softgold = mbase(*batch_in)
    else:
        raise q.SumTingWongException(
            "goldgetter and mbase can not both be None")

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, (softgold, gold))
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)

    cost = trainlosses[0]
    cost.backward()

    [e() for e in on_before_optim_step]
    optim.step()
    [e() for e in on_after_optim_step]

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )

    [e() for e in on_end]
    return ttmsg
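
Instead of a teacher model, train_batch_distill also accepts a goldgetter that maps the hard gold to a soft target. A minimal sketch of one possible goldgetter (label smoothing over a hypothetical number of classes), not defined anywhere in the source:

import torch.nn.functional as F

def smoothed_goldgetter(gold, num_classes=10, eps=0.1):
    # hypothetical goldgetter: turn hard labels of shape (batsize,) into
    # smoothed distributions of shape (batsize, num_classes)
    onehot = F.one_hot(gold, num_classes=num_classes).float()
    return onehot * (1. - eps) + eps / num_classes

# would then be passed as goldgetter=smoothed_goldgetter (with mbase left as None)
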
Example #5
def test_epoch(model=None,
               dataloader=None,
               losses=None,
               device=torch.device("cpu"),
               current_epoch=0,
               max_epochs=0,
               print_every_batch=False,
               on_start=tuple(),
               on_start_batch=tuple(),
               on_end_batch=tuple(),
               on_end=tuple()):
    """
    Performs a test epoch on the given model, with data from the given dataloader, evaluating the given
    losses under torch.no_grad().
    :param model:       torch.nn.Module to evaluate
    :param dataloader:  dataloader yielding test batches
    :param losses:      list of loss wrappers
    :param device:      device to put the losses and batches on
    :param current_epoch:   current epoch number (zero-based)
    :param max_epochs:      total number of epochs
    :param print_every_batch:   if True, print a separate line per batch instead of a live-updating line
    :param on_start:        collection of functions called before the epoch starts
    :param on_start_batch:  collection of functions called before every batch
    :param on_end_batch:    collection of functions called after every batch
    :param on_end:          collection of functions called after the epoch ends
    :return:    a string summarizing the epoch's losses
    """
    tt = q.ticktock("-")
    model.eval()
    q.epoch_reset(model)
    [e() for e in on_start]
    with torch.no_grad():
        for loss_obj in losses:
            loss_obj.push_epoch_to_history()
            loss_obj.reset_agg()
            loss_obj.loss.to(device)
        for i, _batch in enumerate(dataloader):
            [e() for e in on_start_batch]

            _batch = (_batch, ) if not q.issequence(_batch) else _batch
            _batch = q.recmap(
                _batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
            batch = _batch
            numex = batch[0].size(0)

            if q.no_gold(losses):
                batch_in = batch
                gold = None
            else:
                batch_in = batch[:-1]
                gold = batch[-1]

            q.batch_reset(model)
            modelouts = model(*batch_in)

            testlosses = []
            for loss_obj in losses:
                loss_val = loss_obj(modelouts, gold, _numex=numex)
                loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
                testlosses.extend(loss_val)

            ttmsg = "test - Epoch {}/{} - [{}/{}]: {}".format(
                current_epoch + 1, max_epochs, i + 1, len(dataloader),
                q.pp_epoch_losses(*losses))
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)
            [e() for e in on_end_batch]
    tt.stoplive()
    [e() for e in on_end]
    ttmsg = q.pp_epoch_losses(*losses)
    return ttmsg
Example #6
def train_batch(batch=None,
                model=None,
                optim=None,
                losses=None,
                device=torch.device("cpu"),
                batch_number=-1,
                max_batches=0,
                current_epoch=0,
                max_epochs=0,
                on_start=tuple(),
                on_before_optim_step=tuple(),
                on_after_optim_step=tuple(),
                on_end=tuple()):
    """
    Runs a single batch of SGD on provided batch and settings.
    :param batch:  batch to run on
    :param model:   torch.nn.Module of the model
    :param optim:       torch optimizer
    :param losses:      list of loss wrappers
    :param device:      device
    :param batch_number:    which batch
    :param max_batches:     total number of batches
    :param current_epoch:   current epoch
    :param max_epochs:      total number of epochs
    :param on_start:        collection of functions to call when starting training batch
    :param on_before_optim_step:    collection of functions for before optimization step is taken (gradclip)
    :param on_after_optim_step:     collection of functions for after optimization step is taken
    :param on_end:              collection of functions to call when batch is done
    :return:
    """
    [e() for e in on_start]
    optim.zero_grad()
    model.train()

    batch = (batch, ) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
    numex = batch[0].size(0)

    if q.no_gold(losses):
        batch_in = batch
        gold = None
    else:
        batch_in = batch[:-1]
        gold = batch[-1]

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, gold, _numex=numex)
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)

    cost = trainlosses[0]
    # penalties
    penalties = 0
    for loss_obj, trainloss in zip(losses, trainlosses):
        if isinstance(loss_obj.loss, q.loss.PenaltyGetter):
            penalties += trainloss
    cost = cost + penalties

    if torch.isnan(cost).any():
        print("Cost is NaN!")
        embed()  # drop into an interactive IPython shell to inspect (requires: from IPython import embed)

    cost.backward()

    [e() for e in on_before_optim_step]
    optim.step()
    [e() for e in on_after_optim_step]

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )

    [e() for e in on_end]
    return ttmsg
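
The on_before_optim_step hook is where gradient clipping would go, as the docstring hints; a minimal sketch using torch.nn.utils.clip_grad_norm_ (model, batch and the remaining arguments are placeholders):

def clip_hook():
    # hypothetical hook: clip gradients to max-norm 5.0 right before optim.step()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)

msg = train_batch(batch=batch, model=model, optim=optim, losses=losses,
                  device=device, batch_number=i, max_batches=len(dataloader),
                  current_epoch=epoch, max_epochs=max_epochs,
                  on_before_optim_step=[clip_hook])
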
Example #7
def adv_train_epoch(main_model=None,
                    adv_model=None,
                    main_dataloader=None,
                    adv_dataloader=None,
                    main_optim=None,
                    adv_optim=None,
                    main_losses=None,
                    adv_losses=None,
                    adviters=1,
                    device=torch.device("cpu"),
                    tt=q.ticktock(" -"),
                    current_epoch=0,
                    max_epochs=0,
                    _main_train_batch=q.train_batch,
                    _adv_train_batch=q.train_batch,
                    on_start=tuple(),
                    on_end=tuple(),
                    print_every_batch=False):
    """
    Performs one epoch of adversarial training: for every batch from the main dataloader, 'adviters' update
    steps are taken on the adversarial model using batches from the adversarial dataloader, followed by one
    update step on the main model.
    :param main_model:      main torch.nn.Module to train
    :param adv_model:       adversarial torch.nn.Module to train
    :param main_dataloader: dataloader yielding main training batches
    :param adv_dataloader:  dataloader yielding adversarial training batches
    :param main_optim:      torch optimizer for the main model
    :param adv_optim:       torch optimizer for the adversarial model
    :param main_losses:     list of loss wrappers for the main model
    :param adv_losses:      list of loss wrappers for the adversarial model
    :param adviters:        number of adversarial update steps per main batch
    :param device:          device to put the models, losses and batches on
    :param tt:              ticktock object used for progress printing
    :param current_epoch:   current epoch number (zero-based)
    :param max_epochs:      total number of epochs
    :param _main_train_batch:   train batch function for the main model, default is q.train_batch
    :param _adv_train_batch:    train batch function for the adversarial model, default is q.train_batch
    :param on_start:        collection of functions called before the epoch starts
    :param on_end:          collection of functions called after the epoch ends
    :param print_every_batch:   if True, print a separate line per batch instead of a live-updating line
    :return:    a string summarizing the epoch's main and adversarial losses
    """
    for loss in main_losses + adv_losses:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()
        loss.to(device)

    main_model.to(device)
    adv_model.to(device)

    [e() for e in on_start]

    q.epoch_reset(main_model)
    q.epoch_reset(adv_model)
    adv_optim.zero_grad()
    main_optim.zero_grad()

    adv_dl_iter = iter(adv_dataloader)
    k = 0
    for i, main_batch in enumerate(main_dataloader):

        # do 'adviters' of adversarial updates
        j = adviters
        while j > 0:
            try:
                adv_batch = next(adv_dl_iter)
                adv_optim.zero_grad()
                ttmsg = _adv_train_batch(batch=adv_batch,
                                         model=adv_model,
                                         optim=adv_optim,
                                         losses=adv_losses,
                                         device=device,
                                         batch_number=k,
                                         max_batches=len(adv_dataloader),
                                         current_epoch=current_epoch,
                                         max_epochs=max_epochs)
                ttmsg = "adv: " + ttmsg
                if print_every_batch:
                    tt.msg(ttmsg)
                else:
                    tt.live(ttmsg)
                j -= 1
                k += 1
            except StopIteration:
                # adversarial dataloader exhausted: restart it and retry
                adv_dl_iter = iter(adv_dataloader)
                k = 0

        # do main update
        main_optim.zero_grad()
        ttmsg = _main_train_batch(batch=main_batch,
                                  model=main_model,
                                  optim=main_optim,
                                  losses=main_losses,
                                  device=device,
                                  batch_number=i,
                                  max_batches=len(main_dataloader),
                                  current_epoch=current_epoch,
                                  max_epochs=max_epochs)
        ttmsg = "main: " + ttmsg
        if print_every_batch:
            tt.msg(ttmsg)
        else:
            tt.live(ttmsg)

    tt.stoplive()
    [e() for e in on_end]
    ttmsg = "main: " + q.pp_epoch_losses(
        *main_losses) + " -- adv: " + q.pp_epoch_losses(*adv_losses)
    return ttmsg