Example #1
0
def train_epoch(model=None,
                dataloader=None,
                optim=None,
                losses=None,
                device=torch.device("cpu"),
                tt=q.ticktock("-"),
                current_epoch=0,
                max_epochs=0,
                _train_batch=train_batch,
                on_start=tuple(),
                on_end=tuple(),
                print_every_batch=False):
    """
    Performs an epoch of training on given model, with data from given dataloader, using given optimizer,
    with loss computed based on given losses.
    :param model:           model to train; moved to `device` and epoch-reset (q.epoch_reset) before the first batch
    :param dataloader:      iterable of batches; must support len() (used as `max_batches` for `_train_batch`)
    :param optim:           optimizer, passed through to `_train_batch`
    :param losses:          list of loss wrappers; each must provide push_epoch_to_history(), reset_agg()
                            and a `.loss` attribute with a `.to(device)` method
    :param device:          device to put losses, model and batches on
    :param tt:              ticktock used for progress output. NOTE: the default instance is created once,
                            when this function is defined, and is shared by all calls that do not pass one.
    :param current_epoch:   index of the current epoch; epoch `current_epoch - 1` is pushed to every loss's
                            history before training starts
    :param max_epochs:      total number of epochs, passed through to `_train_batch`
    :param _train_batch:    train batch function, default is train_batch. Called once per batch with keyword
                            arguments; its return value is used as the progress message for that batch.
    :param on_start:        iterable of zero-argument callables invoked before the first batch
    :param on_end:          iterable of zero-argument callables invoked after the last batch
    :param print_every_batch: if True, every batch message is emitted with tt.msg; otherwise with tt.live
    :return:                the result of q.pp_epoch_losses(*losses) (the epoch loss summary)
    """
    for loss in losses:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()
        loss.loss.to(device)

    model.to(device)

    # Plain loops: callbacks are run for their side effects only.
    for callback in on_start:
        callback()

    q.epoch_reset(model)

    # Pick the progress sink once instead of branching on every batch.
    report = tt.msg if print_every_batch else tt.live

    for i, _batch in enumerate(dataloader):
        ttmsg = _train_batch(batch=_batch,
                             model=model,
                             optim=optim,
                             losses=losses,
                             device=device,
                             batch_number=i,
                             max_batches=len(dataloader),
                             current_epoch=current_epoch,
                             max_epochs=max_epochs)
        report(ttmsg)

    tt.stoplive()
    for callback in on_end:
        callback()
    return q.pp_epoch_losses(*losses)
def adv_train_epoch(model=None,
                    dataloader=None,
                    optim=None,
                    losses=None,
                    advmodel=None,
                    advdataloader=None,
                    advoptim=None,
                    advlosses=None,
                    device=torch.device("cpu"),
                    tt=q.ticktock(" -"),
                    current_epoch=0,
                    max_epochs=0,
                    _train_batch=q.train_batch,
                    _adv_train_batch=q.train_batch,
                    on_start=tuple(),
                    on_end=tuple(),
                    print_every_batch=False,
                    advsteps=1):
    """
    Performs an epoch of adversarial training. For every batch of the main `dataloader`, first runs
    `advsteps` training steps of `advmodel` on batches drawn from `advdataloader`, then one training
    step of `model` on the main batch.
    :param model:           main model; moved to `device` and epoch-reset before training
    :param dataloader:      iterable of main batches; must support len()
    :param optim:           optimizer for `model`
    :param losses:          list of loss wrappers for `model`
    :param advmodel:        adversarial model; moved to `device` and epoch-reset before training
    :param advdataloader:   iterable of adversarial batches. One iterator is kept for the whole epoch and
                            restarted whenever it is exhausted, so the adversary cycles through the data.
    :param advoptim:        optimizer for `advmodel`
    :param advlosses:       list of loss wrappers for `advmodel`
    :param device:          device to put losses, models and batches on
    :param tt:              ticktock used for progress output. NOTE: the default instance is created once,
                            when this function is defined, and is shared by all calls that do not pass one.
    :param current_epoch:   index of the current epoch; epoch `current_epoch - 1` is pushed to every loss's
                            history before training starts
    :param max_epochs:      total number of epochs, passed to `_train_batch`
    :param _train_batch:    train batch function for the main model, default is q.train_batch
    :param _adv_train_batch: train batch function for the adversarial model, default is q.train_batch
                            (called with max_batches=0 and max_epochs=0)
    :param on_start:        iterable of zero-argument callables invoked before the first batch
    :param on_end:          iterable of zero-argument callables invoked after the last batch
    :param print_every_batch: if True, every step is reported with tt.msg; otherwise with tt.live
    :param advsteps:        number of adversarial steps per main batch
    :return:                string with the epoch loss summaries of the main and the adversarial losses
    :raises ValueError:     if an adversarial step is requested but `advdataloader` yields no batches
    """
    # Unpack into a new list so tuples of losses are accepted as well.
    for loss in [*losses, *advlosses]:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()
        loss.loss.to(device)

    model.to(device)
    advmodel.to(device)

    for callback in on_start:
        callback()

    q.epoch_reset(model)
    q.epoch_reset(advmodel)

    report = tt.msg if print_every_batch else tt.live

    # A single iterator for the whole epoch: recreating it per main batch would make the
    # adversary restart from the beginning of its loader every time (and, for loaders with
    # worker processes, re-spawn the workers on every step).
    adviter = iter(advdataloader)
    for i, _batch in enumerate(dataloader):
        for j in range(advsteps):
            try:
                _advbatch = next(adviter)
            except StopIteration:
                # Exhausted: start a new pass over the adversarial data.
                adviter = iter(advdataloader)
                try:
                    _advbatch = next(adviter)
                except StopIteration:
                    raise ValueError("advdataloader yielded no batches") from None
            ttmsg = _adv_train_batch(batch=_advbatch,
                                     model=advmodel,
                                     optim=advoptim,
                                     losses=advlosses,
                                     device=device,
                                     batch_number=j,
                                     max_batches=0,
                                     current_epoch=current_epoch,
                                     max_epochs=0)
            report(f"adv:  {ttmsg}")
        ttmsg = _train_batch(batch=_batch,
                             model=model,
                             optim=optim,
                             losses=losses,
                             device=device,
                             batch_number=i,
                             max_batches=len(dataloader),
                             current_epoch=current_epoch,
                             max_epochs=max_epochs)
        report(f"main: {ttmsg}")

    tt.stoplive()
    for callback in on_end:
        callback()
    ttmsg = q.pp_epoch_losses(*losses)
    advttmsg = q.pp_epoch_losses(*advlosses)
    return f"\n main: {ttmsg}\n adv:  {advttmsg}"
def train_epoch_distill(model=None,
                        dataloader=None,
                        optim=None,
                        losses=None,
                        device=torch.device("cpu"),
                        tt=q.ticktock("-"),
                        current_epoch=0,
                        max_epochs=0,
                        _train_batch=train_batch_distill,
                        on_start=tuple(),
                        on_end=tuple(),
                        run=False,
                        mbase=None,
                        goldgetter=None):
    """
    Performs an epoch of (distillation) training on given model, with data from given dataloader,
    using given optimizer, with loss computed based on given losses.
    :param model:           model to train; epoch-reset (q.epoch_reset) before the first batch
    :param dataloader:      iterable of batches; must support len() (used as `max_batches` for `_train_batch`)
    :param optim:           optimizer, passed through to `_train_batch`
    :param losses:          list of loss wrappers; each must provide push_epoch_to_history() and reset_agg()
    :param device:          device passed through to `_train_batch`
    :param tt:              ticktock used for progress output. NOTE: the default instance is created once,
                            when this function is defined, and is shared by all calls that do not pass one.
    :param current_epoch:   index of the current epoch; epoch `current_epoch - 1` is pushed to every loss's
                            history before training starts
    :param max_epochs:      total number of epochs, passed through to `_train_batch`
    :param _train_batch:    train batch function, default is train_batch_distill. Called once per batch with
                            keyword arguments; its return value is shown as progress message.
    :param on_start:        iterable of zero-argument callables invoked before the first batch
    :param on_end:          iterable of zero-argument callables invoked after the last batch
    :param run:             accepted but not used by this function; `_train_batch` is always called with run=True
    :param mbase:           optional second model; epoch-reset alongside `model` when given and passed through
                            to `_train_batch` (presumably the base/teacher model for distillation — confirm)
    :param goldgetter:      passed through to `_train_batch` unchanged; its meaning is defined there
    :return:                the result of q.pp_epoch_losses(*losses) (the epoch loss summary)
    """
    for loss in losses:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()
    # NOTE(review): unlike train_epoch, neither the losses nor the model are moved to `device`
    # here — confirm that `_train_batch` takes care of device placement.

    for callback in on_start:
        callback()

    q.epoch_reset(model)
    if mbase is not None:
        q.epoch_reset(mbase)

    for i, _batch in enumerate(dataloader):
        ttmsg = _train_batch(batch=_batch,
                             model=model,
                             optim=optim,
                             losses=losses,
                             device=device,
                             batch_number=i,
                             max_batches=len(dataloader),
                             current_epoch=current_epoch,
                             max_epochs=max_epochs,
                             run=True,
                             mbase=mbase,
                             goldgetter=goldgetter)
        tt.live(ttmsg)

    tt.stoplive()
    for callback in on_end:
        callback()
    return q.pp_epoch_losses(*losses)
Example #4
0
def test_epoch(model=None,
               dataloader=None,
               losses=None,
               device=torch.device("cpu"),
               current_epoch=0,
               max_epochs=0,
               print_every_batch=False,
               on_start=tuple(),
               on_start_batch=tuple(),
               on_end_batch=tuple(),
               on_end=tuple()):
    """
    Performs a test epoch: puts `model` in eval mode and, without gradient tracking, feeds every batch
    of `dataloader` through it and evaluates the given losses on the outputs.
    :param model:           model to evaluate; switched to eval mode and epoch-reset before the first batch,
                            batch-reset (q.batch_reset) before every batch
    :param dataloader:      iterable of batches; must support len() (used in the progress message). A batch
                            that is not a sequence is wrapped in a 1-tuple. Unless q.no_gold(losses) holds,
                            the last element of a batch is used as gold and the rest as model inputs.
    :param losses:          list of loss wrappers; each must provide push_epoch_to_history(), reset_agg(),
                            a `.loss` attribute with `.to(device)`, and be callable as
                            loss(modelouts, gold, _numex=...)
    :param device:          device to put losses and batch tensors on
    :param current_epoch:   index of the current epoch (shown 1-based in the progress message)
    :param max_epochs:      total number of epochs (shown in the progress message)
    :param print_every_batch: if True, every batch message is emitted with tt.msg; otherwise with tt.live
    :param on_start:        zero-argument callables invoked before the first batch
    :param on_start_batch:  zero-argument callables invoked at the start of every batch
    :param on_end_batch:    zero-argument callables invoked at the end of every batch
    :param on_end:          zero-argument callables invoked after the last batch
    :return:                the result of q.pp_epoch_losses(*losses) (the epoch loss summary)
    """
    tt = q.ticktock("-")
    model.eval()
    q.epoch_reset(model)
    for callback in on_start:
        callback()
    report = tt.msg if print_every_batch else tt.live
    with torch.no_grad():
        for loss_obj in losses:
            loss_obj.push_epoch_to_history()
            loss_obj.reset_agg()
            loss_obj.loss.to(device)
        for i, batch in enumerate(dataloader):
            for callback in on_start_batch:
                callback()

            batch = batch if q.issequence(batch) else (batch,)
            batch = q.recmap(
                batch, lambda x: x.to(device)
                if isinstance(x, torch.Tensor) else x)
            # assumes the first batch element is a tensor whose dim 0 is the example count — TODO confirm
            numex = batch[0].size(0)

            if q.no_gold(losses):
                batch_in, gold = batch, None
            else:
                batch_in, gold = batch[:-1], batch[-1]

            q.batch_reset(model)
            modelouts = model(*batch_in)

            # The returned loss values are not used here; only the call itself matters
            # (presumably the wrappers update their running aggregates, cf. reset_agg above).
            for loss_obj in losses:
                loss_obj(modelouts, gold, _numex=numex)

            ttmsg = "test - Epoch {}/{} - [{}/{}]: {}".format(
                current_epoch + 1, max_epochs, i + 1, len(dataloader),
                q.pp_epoch_losses(*losses))
            report(ttmsg)
            for callback in on_end_batch:
                callback()
    tt.stoplive()
    for callback in on_end:
        callback()
    return q.pp_epoch_losses(*losses)
Example #5
0
def adv_train_epoch(main_model=None,
                    adv_model=None,
                    main_dataloader=None,
                    adv_dataloader=None,
                    main_optim=None,
                    adv_optim=None,
                    main_losses=None,
                    adv_losses=None,
                    adviters=1,
                    device=torch.device("cpu"),
                    tt=q.ticktock(" -"),
                    current_epoch=0,
                    max_epochs=0,
                    _main_train_batch=q.train_batch,
                    _adv_train_batch=q.train_batch,
                    on_start=tuple(),
                    on_end=tuple(),
                    print_every_batch=False):
    """
    Performs an epoch of adversarial training. For every batch of `main_dataloader`, first runs
    `adviters` training steps of `adv_model` on batches drawn from `adv_dataloader`, then one training
    step of `main_model` on the main batch.
    :param main_model:      main model; moved to `device` and epoch-reset before training
    :param adv_model:       adversarial model; moved to `device` and epoch-reset before training
    :param main_dataloader: iterable of main batches; must support len()
    :param adv_dataloader:  iterable of adversarial batches; must support len(). One iterator is kept for
                            the whole epoch and restarted whenever it is exhausted.
    :param main_optim:      optimizer for `main_model`; zero_grad() is called before every main step
    :param adv_optim:       optimizer for `adv_model`; zero_grad() is called before every adversarial step
    :param main_losses:     list of loss wrappers for `main_model`
    :param adv_losses:      list of loss wrappers for `adv_model`
    :param adviters:        number of adversarial steps per main batch
    :param device:          device to put losses, models and batches on
    :param tt:              ticktock used for progress output. NOTE: the default instance is created once,
                            when this function is defined, and is shared by all calls that do not pass one.
    :param current_epoch:   index of the current epoch; epoch `current_epoch - 1` is pushed to every loss's
                            history before training starts
    :param max_epochs:      total number of epochs, passed to both batch functions
    :param _main_train_batch: train batch function for the main model, default is q.train_batch
    :param _adv_train_batch:  train batch function for the adversarial model, default is q.train_batch
    :param on_start:        zero-argument callables invoked before the first batch
    :param on_end:          zero-argument callables invoked after the last batch
    :param print_every_batch: if True, every step is reported with tt.msg; otherwise with tt.live
    :return:                string with the epoch loss summaries of the main and the adversarial losses
    :raises ValueError:     if an adversarial step is requested but `adv_dataloader` yields no batches
    """
    # Unpack into a new list so tuples of losses are accepted as well.
    for loss in [*main_losses, *adv_losses]:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()
        # NOTE(review): this calls `.to` on the wrapper itself, whereas train_epoch/test_epoch call
        # `loss.loss.to(device)` — confirm the wrapper type used here provides `.to`.
        loss.to(device)

    main_model.to(device)
    adv_model.to(device)

    for callback in on_start:
        callback()

    q.epoch_reset(main_model)
    q.epoch_reset(adv_model)
    adv_optim.zero_grad()
    main_optim.zero_grad()

    report = tt.msg if print_every_batch else tt.live

    adv_dl_iter = iter(adv_dataloader)
    k = 0   # index of the next adversarial batch within the current pass over adv_dataloader
    for i, main_batch in enumerate(main_dataloader):

        # do `adviters` adversarial updates, cycling through adv_dataloader
        for _ in range(adviters):
            # Keep the try body to next() only, so a StopIteration raised inside the
            # batch function is not mistaken for an exhausted loader.
            try:
                adv_batch = next(adv_dl_iter)
            except StopIteration:
                adv_dl_iter = iter(adv_dataloader)
                k = 0
                try:
                    adv_batch = next(adv_dl_iter)
                except StopIteration:
                    # An empty loader would otherwise make this retry spin forever.
                    raise ValueError("adv_dataloader yielded no batches") from None
            adv_optim.zero_grad()
            ttmsg = _adv_train_batch(batch=adv_batch,
                                     model=adv_model,
                                     optim=adv_optim,
                                     losses=adv_losses,
                                     device=device,
                                     batch_number=k,
                                     max_batches=len(adv_dataloader),
                                     current_epoch=current_epoch,
                                     max_epochs=max_epochs)
            report("adv: " + ttmsg)
            k += 1

        # do main update
        main_optim.zero_grad()
        ttmsg = _main_train_batch(batch=main_batch,
                                  model=main_model,
                                  optim=main_optim,
                                  losses=main_losses,
                                  device=device,
                                  batch_number=i,
                                  max_batches=len(main_dataloader),
                                  current_epoch=current_epoch,
                                  max_epochs=max_epochs)
        report("main: " + ttmsg)

    tt.stoplive()
    for callback in on_end:
        callback()
    return "main: " + q.pp_epoch_losses(
        *main_losses) + " -- adv: " + q.pp_epoch_losses(*adv_losses)