Example #1
import numpy as np
import torch

# Relies on module-level globals defined elsewhere in the script:
# device, pad_left, fine_size, val_data, lovasz_hinge,
# mixed_dice_bce_loss and do_kaggle_metric.
def test(test_loader, model, losstype="lovasz"):
    running_loss = 0.0
    predicts = []
    truths = []

    model.eval()
    for inputs, masks in test_loader:
        inputs, masks = inputs.to(device), masks.to(device)
        with torch.no_grad():
            outputs = model(inputs)
            # Crop the padded prediction back to the fine_size region.
            outputs = outputs[:, :, pad_left:pad_left + fine_size,
                              pad_left:pad_left + fine_size].contiguous()
            if losstype != "bce_dice":
                loss = lovasz_hinge(outputs.squeeze(1), masks.squeeze(1))
                predicts.append(torch.sigmoid(outputs).detach().cpu().numpy())
            else:
                outputs = torch.sigmoid(outputs)
                loss = mixed_dice_bce_loss(outputs, masks)
                predicts.append(outputs.detach().cpu().numpy())
        truths.append(masks.detach().cpu().numpy())
        running_loss += loss.item() * inputs.size(0)

    predicts = np.concatenate(predicts).squeeze()
    truths = np.concatenate(truths).squeeze()
    # Kaggle-style IoU precision at a 0.5 threshold.
    precision, _, _ = do_kaggle_metric(predicts, truths, 0.5)
    precision = precision.mean()
    # test_loader is expected to iterate over the global val_data dataset,
    # so the loss is averaged over its length.
    epoch_loss = running_loss / len(val_data)
    return epoch_loss, precision
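
For context, here is a minimal call-site sketch; it is not part of the original example and assumes a DataLoader built over the same val_data dataset that test() divides by, a model already moved to device, and an arbitrary batch size.

# Hypothetical usage; val_data, model and device are assumed to exist.
from torch.utils.data import DataLoader

val_loader = DataLoader(val_data, batch_size=16, shuffle=False)
val_loss, val_precision = test(val_loader, model, losstype="lovasz")
print("val loss: {:.4f}, precision: {:.4f}".format(val_loss, val_precision))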
Example #2
import torch

# Relies on module-level globals defined elsewhere in the script:
# device, train_data, optimizer, lovasz_hinge and mixed_dice_bce_loss.
def train(train_loader, model, losstype="lovasz"):
    running_loss = 0.0
    data_size = len(train_data)

    model.train()
    # for inputs, masks, labels in progress_bar(train_loader, parent=mb):
    for inputs, masks, labels in train_loader:
        # labels are moved to the device but not used in the losses below.
        inputs, masks, labels = (inputs.to(device), masks.to(device),
                                 labels.to(device))
        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            logit = model(inputs)
            # The model returns multiple outputs; only the first
            # (the segmentation logit) is used for the loss.
            logit = logit[0]
            if losstype != "bce_dice":
                loss = lovasz_hinge(logit.squeeze(1), masks.squeeze(1))
            else:
                logit = torch.sigmoid(logit)
                loss = mixed_dice_bce_loss(logit, masks)
            loss.backward()
            optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        # mb.child.comment = 'loss: {}'.format(loss.item())
    epoch_loss = running_loss / data_size
    return epoch_loss
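
Taken together, the two functions are usually driven by a simple epoch loop. The sketch below is an assumption rather than part of the original source: train_loader, val_loader, val_data, model and the loss/metric helpers are expected to exist at script scope, and Adam with lr=1e-4 and 50 epochs are placeholder choices. Note that train() reads the global optimizer, so it must be created before the loop starts.

# Hypothetical driver loop; every name except train() and test() is assumed.
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 50  # placeholder

for epoch in range(num_epochs):
    train_loss = train(train_loader, model, losstype="lovasz")
    val_loss, val_precision = test(val_loader, model, losstype="lovasz")
    print("epoch {}: train {:.4f} | val {:.4f} | precision {:.4f}".format(
        epoch, train_loss, val_loss, val_precision))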