Example #1
0
def validate(epoch):
    """Run one validation epoch: classify SRCNN-restored images and log metrics.

    Evaluates `srcnn` + `classifier` (module-level models) over `val_dataloader`
    with gradients disabled, then logs averaged loss and accuracy to wandb.

    Args:
        epoch: current epoch index; used as the wandb logging step so
            validation metrics align with the `step=epoch` logging in train().

    Returns:
        Validation accuracy: correct predictions / dataset size.
    """
    srcnn.eval()
    classifier.eval()
    val_loss = 0
    acc = 0

    # Build the loss module once instead of re-instantiating it per batch.
    ce_loss = nn.CrossEntropyLoss()

    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            # `corruption` is unpacked but unused here — the corruption label
            # presumably matters to other parts of the pipeline (TODO confirm).
            image, label, corruption = batch

            image = image.cuda()
            label = label.cuda()

            # with corruption detection: restore first, then classify the output
            model_out, residual = srcnn(image)
            pred = classifier(model_out)

            loss = ce_loss(pred, label)

            hit = count_match(pred, label)

            val_loss += loss.item()
            acc += hit

    # step=epoch keeps these metrics on the same wandb step axis as train().
    wandb.log({
        "val_loss": val_loss / len(val_dataloader),  # mean per-batch loss
        "acc": acc / len(val_dataloader.dataset)     # per-sample accuracy
    }, step=epoch)
    return acc / len(val_dataloader.dataset)
Example #2
0
def train(epoch):
    """Run one training epoch over `tng_dataloader` and log metrics to wandb.

    Restores each corrupted image with `srcnn`, classifies the restored output,
    and backpropagates the classification loss. The reconstruction MSE is
    computed and logged for monitoring only — it does NOT contribute to the
    gradient (``loss = clf_loss`` below).

    Args:
        epoch: current epoch index; used for console output and as the wandb
            logging step.

    Side effects: updates `srcnn` parameters via `optimizer`, steps
    `lr_scheduler`, and logs images + scalar metrics to wandb.
    """
    epoch_loss = 0
    mse_total = 0.
    clf_total = 0
    srcnn.train()
    hit = 0

    # Build the reconstruction loss once; `criterion` (classification loss)
    # is already a module-level object.
    mse_criterion = nn.MSELoss()

    for iteration, batch in enumerate(tng_dataloader, 1):
        start = time.time()
        # `inp` avoids shadowing the `input` builtin.
        inp, target, label, corruption_idx = batch
        inp = inp.cuda()
        target = target.cuda()
        label = label.cuda()

        optimizer.zero_grad()
        model_out, residual = srcnn(inp)
        clf_pred = classifier(model_out)

        clf_loss = criterion(clf_pred, label)

        hit += count_match(clf_pred, label)
        mse_loss = mse_criterion(model_out, target)

        # Only the classification loss drives the gradient; MSE is tracked
        # purely for logging.
        loss = clf_loss

        clf_total += clf_loss.item()
        mse_total += mse_loss.item()
        epoch_loss += clf_loss.item() + mse_loss.item()
        loss.backward()
        optimizer.step()

        print("===> Epoch[{}]({}/{}): clf_loss : {:.4f}, iter: {:.3f}".format(
            epoch, iteration, len(tng_dataloader), clf_loss.item(),
            time.time() - start))
        # wandb.log({"iteration_loss": clf_loss.item()})

    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(
        epoch, epoch_loss / len(tng_dataloader)))

    # Log the last batch's first sample: input / restored / residual / target.
    plot_images_to_wandb([inp[0], model_out[0], residual[0], target[0]],
                         "Comparison",
                         step=epoch)
    wandb.log(
        {
            "train_acc": hit / len(tng_dataloader.dataset),      # per-sample
            "clf_loss": clf_total / len(tng_dataloader),         # per-batch mean
            "mse_loss": mse_total / len(tng_dataloader),         # per-batch mean
            "total_loss": epoch_loss / len(tng_dataloader),
        },
        step=epoch)
    lr_scheduler.step()