Esempio n. 1
0
def train_epoch(logger, loader, model, optimizer, scheduler):
    """Run one training pass over ``loader``, then step the LR scheduler once.

    For each batch this clears gradients, moves the batch to ``cfg.device``,
    runs the model, back-propagates the loss returned by ``compute_loss`` and
    takes an optimizer step. It then reports per-batch statistics to
    ``logger``.

    ``compute_loss(pred, true, batch)`` may return either
    ``(total_loss, pred_score)`` or
    ``(total_loss, pred_score, loss_main, loss_reg)``. In the two-element
    case the main and regularisation components are logged as 0.

    Args:
        logger: object exposing ``update_stats(...)``.
        loader: iterable of batches that support ``.to(device)``.
        model: callable as ``model(batch) -> (pred, true)``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler. Its last LR is logged per batch, and it is
            stepped once after the whole epoch.
    """
    model.train()
    tic = time.time()
    for batch in loader:
        optimizer.zero_grad()
        batch.to(torch.device(cfg.device))
        pred, true = model(batch)
        outputs = compute_loss(pred, true, batch)
        if len(outputs) != 2:
            total_loss, pred_score, loss_main, loss_reg = outputs
        else:
            total_loss, pred_score = outputs
            # No loss decomposition available: report zeros for both parts.
            loss_main = torch.tensor(0)
            loss_reg = torch.tensor(0)
        total_loss.backward()
        optimizer.step()
        current_lr = scheduler.get_last_lr()[0]
        logger.update_stats(
            true=true.detach().cpu(),
            pred=pred_score.detach().cpu(),
            loss=total_loss.item(),
            lr=current_lr,
            time_used=time.time() - tic,
            params=cfg.params,
            loss_main=loss_main.item(),
            loss_reg=loss_reg.item(),
        )
        # Restart the per-batch timer after logging.
        tic = time.time()
    scheduler.step()
Esempio n. 2
0
def eval_epoch(logger, loader, model):
    """Evaluate ``model`` on every batch of ``loader`` and log per-batch stats.

    The model is switched to eval mode. The forward passes run under
    ``torch.no_grad()`` so that no autograd graph is built. Nothing here
    calls ``backward()``, so building the graph only wasted memory and time.

    Args:
        logger: object exposing ``update_stats(...)``.
        loader: iterable of batches that support ``.to(device)``.
        model: callable as ``model(batch) -> (pred, true)``.
    """
    model.eval()
    # The target device does not change inside the loop; build it once.
    device = torch.device(cfg.device)
    time_start = time.time()
    with torch.no_grad():
        for batch in loader:
            batch.to(device)
            pred, true = model(batch)
            loss, pred_score = compute_loss(pred, true)
            logger.update_stats(true=true.detach().cpu(),
                                pred=pred_score.detach().cpu(),
                                loss=loss.item(),
                                lr=0,  # no optimizer during evaluation
                                time_used=time.time() - time_start,
                                params=cfg.params)
            time_start = time.time()
Esempio n. 3
0
def train_epoch(logger, loader, model, optimizer, scheduler):
    """Train ``model`` for one epoch over ``loader``.

    Each batch is moved to ``cfg.device``, scored by the model, and its loss
    (from ``compute_loss(pred, true)``) is back-propagated and applied by
    ``optimizer``. Per-batch statistics are forwarded to
    ``logger.update_stats``. The scheduler is stepped once, after the last
    batch.

    Args:
        logger: object exposing ``update_stats(...)``.
        loader: iterable of batches that support ``.to(device)``.
        model: callable as ``model(batch) -> (pred, true)``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler whose last LR is logged per batch.
    """
    model.train()
    tic = time.time()
    for batch in loader:
        optimizer.zero_grad()
        batch.to(torch.device(cfg.device))
        pred, true = model(batch)
        loss, pred_score = compute_loss(pred, true)
        loss.backward()
        optimizer.step()
        stats = dict(
            true=true.detach().cpu(),
            pred=pred_score.detach().cpu(),
            loss=loss.item(),
            lr=scheduler.get_last_lr()[0],
            time_used=time.time() - tic,
            params=cfg.params,
        )
        logger.update_stats(**stats)
        tic = time.time()
    scheduler.step()
Esempio n. 4
0
def eval_epoch(logger, loader, model):
    """Evaluate ``model`` on ``loader`` and log per-batch stats.

    ``compute_loss(pred, true, batch)`` may return either
    ``(total_loss, pred_score)`` or
    ``(total_loss, pred_score, loss_main, loss_reg)``. When the loss is not
    decomposed, ``loss_main`` and ``loss_reg`` are logged as 0, so
    ``update_stats`` always receives numbers instead of failing on
    ``None.item()``.

    Args:
        logger: object exposing ``update_stats(...)``.
        loader: iterable of batches that support ``.to(device)``.
        model: callable as ``model(batch) -> (pred, true)``.
    """
    model.eval()
    # The target device does not change inside the loop; build it once.
    device = torch.device(cfg.device)
    time_start = time.time()
    # NOTE(review): consider wrapping this loop in torch.no_grad() if
    # compute_loss / the regulariser does not need autograd during eval.
    for batch in loader:
        batch.to(device)
        pred, true = model(batch)
        loss_ret = compute_loss(pred, true, batch)
        # TODO: loss-tuple unpacking duplicates the training loop; factor out.
        if len(loss_ret) == 2:
            total_loss, pred_score = loss_ret
            # No decomposition available: log 0 for both components.
            loss_main_val = loss_reg_val = 0
        else:
            total_loss, pred_score, loss_main, loss_reg = loss_ret
            loss_main_val = loss_main.item()
            loss_reg_val = loss_reg.item()
        logger.update_stats(true=true.detach().cpu(),
                            pred=pred_score.detach().cpu(),
                            loss=total_loss.item(),
                            lr=0,  # no optimizer during evaluation
                            time_used=time.time() - time_start,
                            params=cfg.params,
                            loss_main=loss_main_val,
                            loss_reg=loss_reg_val)
        time_start = time.time()