def train(train_loader):
    """Run one training epoch over ``train_loader``.

    The model is expected to return ``(logits, loss)`` when called with
    ``y`` and ``loss_fn``. Uses the module-level ``model``, ``optimizer``
    and ``device``.

    Returns:
        dict with the batch-averaged 'loss' and the overall 'acc'
        (total correct / total samples).
    """
    pbar = ProgressBar(n_batch=len(train_loader))
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    n_samples = 0
    criterion = nn.CrossEntropyLoss()  # stateless, safe to build once
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits, batch_loss = model(inputs, y=labels, loss_fn=criterion)
        # Index of the max score per row is the predicted class.
        predictions = logits.argmax(dim=1, keepdim=True)
        n_correct = predictions.eq(labels.view_as(predictions)).sum().item()
        batch_loss.backward()
        optimizer.step()
        batch_size = inputs.size(0)
        n_samples += batch_size
        acc_meter.update(n_correct, n=1)
        pbar.batch_step(batch_idx=step,
                        info={
                            'loss': batch_loss.item(),
                            'acc': n_correct / batch_size
                        },
                        bar_type='Training')
        loss_meter.update(batch_loss.item(), n=1)
    print(' ')
    return {'loss': loss_meter.avg, 'acc': acc_meter.sum / n_samples}
# Exemplo n.º 2
def train(train_loader):
    """Train for one epoch with the module-level ``loss_fn``.

    Uses the module-level ``model``, ``optimizer`` and ``device``.

    Returns:
        dict with the batch-averaged 'loss'.
    """
    pbar = ProgressBar(n_batch=len(train_loader))
    loss_meter = AverageMeter()
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
        pbar.batch_step(batch_idx=step,
                        info={'loss': batch_loss.item()},
                        bar_type='Training')
        loss_meter.update(batch_loss.item(), n=1)
    return {'loss': loss_meter.avg}
def test(test_loader):
    """Evaluate the model on ``test_loader`` without gradient tracking.

    The model is expected to return ``(logits, loss)`` when called with
    ``y`` and ``loss_fn``. Uses the module-level ``model`` and ``device``.

    Returns:
        dict with the sample-weighted average 'valid_loss' and the
        overall 'valid_acc' (total correct / total samples).
    """
    pbar = ProgressBar(n_batch=len(test_loader))
    valid_loss = AverageMeter()
    valid_acc = AverageMeter()
    model.eval()
    count = 0
    criterion = nn.CrossEntropyLoss()  # stateless, safe to build once
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output, loss = model(data, y=target, loss_fn=criterion)
            pred = output.argmax(
                dim=1,
                keepdim=True)  # get the index of the max log-probability
            correct = pred.eq(target.view_as(pred)).sum().item()
            # BUG FIX: pass the Python scalar, not the loss tensor, so the
            # meter's running sum/avg stays a plain float (matches the train
            # loops and avoids keeping tensor references alive).
            valid_loss.update(loss.item(), n=data.size(0))
            valid_acc.update(correct, n=1)
            count += data.size(0)
            pbar.batch_step(batch_idx=batch_idx, info={}, bar_type='Testing')
    return {'valid_loss': valid_loss.avg, 'valid_acc': valid_acc.sum / count}