def test():
    """Smoke-test the synthetic dataset: load one batch, print label shapes,
    and dump per-channel label maps plus the real image to PNG files."""
    image_transform = T.Compose([
        T.Resize(256),
        T.ToTensor(),
    ])
    # Nearest-neighbour keeps discrete label values intact when resizing.
    mask_transform = T.Compose([
        T.Resize(256, interpolation=Image.NEAREST),
        T.ToTensor(),
    ])
    ds = SyntheticDataset(label_transform=mask_transform,
                          image_transform=image_transform)

    dataloader = DataLoader(ds, batch_size=10)

    batch = next(iter(dataloader))
    image = batch["image"]
    label = batch["label"]
    inst = batch["inst"]
    grade = batch["grade"]
    label = label.unsqueeze(0)

    print(f"Label shape: {label.shape}")

    label_map = get_label_semantics(label)

    print(f"Label map shape: {label_map.shape}")
    nc = 7
    # One PNG per semantic channel.
    for channel in range(nc):
        save_image(label_map[:, [channel], :, :].float(), f"test_{channel}.png")
    save_image(image, f"test_real.png")
    # Remap class 2 to mid-grey so it is visible in the saved image.
    label[label == 2] = 0.5
    save_image(label, "label.png")
def predict_from_label(model: nn.Module, df: pd.DataFrame) -> np.ndarray:
    """Predicts grades for a dataset of label images using the specified model.

    Args:
        model: Classifier taking a one-hot semantic label map (batch of 1) and
            returning per-class logits.
        df: DataFrame with a "Label" column containing label-image file paths.

    Returns:
        Array of predicted class indices, one per row of ``df``, in positional
        row order (independent of the DataFrame's index values).
    """
    img_size = 512
    transform = T.Compose([
        T.Resize(img_size, interpolation=InterpolationMode.NEAREST),
        T.ToTensor()
    ])

    predictions = np.empty(len(df), dtype=int)
    # Inference mode: fix batchnorm/dropout behaviour.
    model.eval()
    # Use enumerate() for the positional slot: df.iterrows() yields the
    # DataFrame *index label*, which is not guaranteed to be 0..len(df)-1.
    for pos, (_, row) in enumerate(tqdm(df.iterrows(), total=len(df))):
        # Context manager closes the image file handle promptly.
        with Image.open(row["Label"]) as label_img:
            label = transform(label_img).unsqueeze(0)
        label = get_label_semantics(label)
        # No gradients needed at inference time.
        with torch.no_grad():
            pred = model(label)
        pred = torch.argmax(pred)
        predictions[pos] = pred.item()

    return predictions
def train_step(
    epoch: int,
    model: nn.Module,
    optimizer: optim.Optimizer,
    criterion: nn.CrossEntropyLoss,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    log_interval: int,
    val_interval: int,
    logger: ResNetLogger,
):
    """Run one training epoch; log training metrics every ``log_interval``
    iterations and run validation every ``val_interval`` iterations."""
    n_batches = len(train_loader)

    for batch_idx, batch in enumerate(train_loader):
        # Re-enable train mode each iteration since validate() flips to eval.
        model.train()

        # Global iteration counter across epochs.
        iteration = epoch * n_batches + batch_idx

        # The classifier input is the one-hot semantic label map.
        inputs = get_label_semantics(batch["label"]).to(device)
        targets = batch["grade"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if iteration % log_interval == 0:
            predicted = torch.argmax(outputs, dim=1)
            accuracy = torch.eq(predicted, targets).sum().item() / len(targets)
            logger.log_train(
                ResNetTrainMetrics(iteration, epoch, loss.item(), accuracy))

        if iteration % val_interval == 0:
            validate(iteration, model, criterion, val_loader, device, logger)
def validate(
    iteration: int,
    model: nn.Module,
    criterion: nn.CrossEntropyLoss,
    val_loader: DataLoader,
    device: torch.device,
    logger: ResNetLogger,
):
    """Evaluate ``model`` on the validation set and log the per-sample mean
    loss and accuracy at the given iteration."""
    model.eval()
    n_val_samples = 0
    val_loss = 0.0
    val_corrects = 0.0

    for batch in val_loader:
        images, grades = batch["label"], batch["grade"]
        images = get_label_semantics(images)

        images = images.to(device)
        grades = grades.to(device)

        # No autograd graph needed for the forward pass or the loss.
        with torch.no_grad():
            outputs = model(images)
            loss = criterion(outputs, grades)

        preds = torch.argmax(outputs, dim=1)
        batch_size = len(preds)

        # BUG FIX: criterion (CrossEntropyLoss, default reduction="mean")
        # returns the mean loss over the batch. The original accumulated the
        # batch means and divided by the total *sample* count, understating
        # the loss by ~batch_size and biasing it when the last batch is
        # smaller. Scale by batch size so the final division is correct.
        val_loss += loss.item() * batch_size
        val_corrects += torch.sum(torch.eq(preds, grades)).item()
        n_val_samples += batch_size

    val_loss /= n_val_samples
    val_acc = val_corrects / n_val_samples

    metrics = ResNetValidateMetrics(iteration, val_loss, val_acc)
    logger.log_val(metrics)
def main():
    """Evaluate a label-based ResNet grade classifier on the dataset selected
    by ``--dataset`` and write the resulting metrics to a JSON file."""
    # TODO(sonjoonho): Add argument parsing for options.

    out_dir = "results/"
    img_size = 512
    batch_size = 64

    opt = get_args()
    name = opt.name

    output_path = Path(opt.out_dir) / name
    output_path.mkdir(parents=True, exist_ok=True)

    device = get_device()

    model_path = (Path(out_dir) / "resnet_labels" / name / "checkpoints" /
                  "model_latest.pth")
    model = load_label_model(model_path)
    model = model.to(device)
    # Inference mode: fix batchnorm/dropout behaviour for evaluation.
    model.eval()

    # Nearest-neighbour resizing keeps discrete label values intact.
    transform = T.Compose([
        T.Resize(img_size, interpolation=InterpolationMode.NEAREST),
        T.ToTensor(),
    ])

    if opt.dataset == "real":
        test_dataset = CombinedDataset(
            label_transform=transform,
            return_image=False,
            return_inst=False,
            return_transformed=False,
            mode=CombinedDataset.VALIDATION,
        )
        # Evaluate only on the FGADR subset of the combined data.
        test_dataset.df = test_dataset.df[test_dataset.df["Source"] == "FGADR"]
    elif opt.dataset == "copypaste":
        test_dataset = CopyPasteDataset(
            label_transform=transform,
            return_transformed=False,
        )
    else:
        test_dataset = SyntheticDataset(
            opt.dataset,
            label_transform=transform,
            return_image=False,
            return_inst=False,
            return_transformed=False,
        )

    # shuffle=False so batch i fills rows [i*batch_size, i*batch_size+len).
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
        shuffle=False,
    )

    n_val_samples = len(test_dataset)
    predictions = np.empty(n_val_samples, dtype=int)
    actual = np.empty(n_val_samples, dtype=int)

    print(f"Validation samples: {n_val_samples}")

    for i, batch in enumerate(tqdm(test_loader)):
        images, grades = batch["label"], batch["grade"]
        images = get_label_semantics(images)

        images = images.to(device)
        grades = grades.to(device)

        with torch.no_grad():
            outputs = model(images)
        preds = torch.argmax(outputs, dim=1)

        start = i * batch_size
        end = start + images.shape[0]
        predictions[start:end] = preds.cpu().numpy()
        actual[start:end] = grades.cpu().numpy()

    metrics = {
        "accuracy": accuracy_score(actual, predictions),
        "precision": precision_score(actual, predictions, average="macro"),
        "recall": recall_score(actual, predictions, average="macro"),
        "f1": f1_score(actual, predictions, average="macro"),
        "kappa": quadratic_kappa(actual, predictions),
        "tta": opt.tta,
        "tta_runs": opt.tta_runs,
    }
    print("Accuracy: ", metrics["accuracy"])
    print("Precision: ", metrics["precision"])
    print("Recall: ", metrics["recall"])
    print("F1: ", metrics["f1"])
    print("Cohen's", metrics["kappa"])

    time = timestamp()

    # Save the computed metrics. (BUG FIX: the original dumped vars(opt)
    # here, so the metrics-*.json file never contained the metrics.)
    with open(output_path / f"metrics-{time}.json", "w") as f:
        json.dump(metrics, f, indent=4)
def main():
    """Generates synthetic "copy-paste" label maps and instance maps by
    combining anatomy/lesion channels from different source images of the
    same DR grade, writing PNGs under ``<out_dir>/copypaste/{label,inst}``.

    NOTE(review): this shadows the earlier ``def main()`` in this file —
    these look like two concatenated scripts; confirm which entry point is
    intended.
    """
    opt = get_args()

    out_path = Path(opt.out_dir) / "copypaste"
    label_path = out_path / "label"
    label_path.mkdir(parents=True, exist_ok=True)
    inst_path = out_path / "inst"
    inst_path.mkdir(parents=True, exist_ok=True)

    # A non-negative seed requests deterministic generation.
    if opt.seed > -1:
        set_seed(opt.seed)

    # 8 semantic classes plus background; also used as the batch size so each
    # output can draw its channels from distinct source samples.
    nc = 8 + 1

    # Nearest-neighbour resizing keeps discrete label values intact.
    transform = T.Compose([
        T.Resize(opt.img_size, interpolation=T.InterpolationMode.NEAREST),
        T.ToTensor(),
    ])

    # Outputs are split evenly across the 5 DR grades.
    samples_per_grade = opt.n_samples // 5

    ticker = 0
    for dr_grade in range(5):
        dataset = CombinedDataset(
            return_image=False,
            return_transformed=False,
            label_transform=transform,
        )
        # Restrict the pool of source labels to the current grade.
        dataset.df = dataset.df[dataset.df["Grade"] == dr_grade]
        # drop_last=True guarantees every batch has exactly nc samples.
        dataloader = DataLoader(dataset,
                                batch_size=nc,
                                shuffle=True,
                                drop_last=True)
        infinite_loader = infinite(dataloader)

        for _ in tqdm(range(samples_per_grade)):
            batch = next(infinite_loader)
            source_label = batch["label"]

            assert len(source_label) == nc

            source_label = get_label_semantics(source_label)

            _, _, height, width = source_label.shape
            retina = make_circle(height, width)

            # Each structure's channel is taken from a *different* sample in
            # the batch (batch indices 1..7), so the output mixes lesions from
            # several source images. Each slice has shape (1, H, W).
            # NOTE(review): batch indices 0 and 8 are never used — confirm
            # this is intentional.
            od = source_label[[1], Labels.OD.value, :, :]
            ma = source_label[[2], Labels.MA.value, :, :]
            he = source_label[[3], Labels.HE.value, :, :]
            ex = source_label[[4], Labels.EX.value, :, :]
            se = source_label[[5], Labels.SE.value, :, :]
            nv = source_label[[6], Labels.NV.value, :, :]
            irma = source_label[[7], Labels.IRMA.value, :, :]

            # Carve the pasted structures out of the retina disc so the
            # channels stay mutually exclusive before argmax.
            retina = retina - od - ma - he - ex - se - nv - irma

            # Create background by subtracting everything else.
            bg = torch.ones_like(
                retina) - od - ma - he - ex - se - nv - irma - retina
            bg = torch.clamp(bg, 0, 1)

            # Collapse the stacked channels into a single-channel index map;
            # channel order fixes the class index of each structure.
            combined = torch.cat([retina, od, ma, he, ex, se, nv, irma, bg],
                                 dim=0)
            new_label = torch.argmax(combined, dim=0).float().numpy()

            if opt.colour:
                new_label = colour_labels_numpy(new_label)
            else:
                # Set background to 255.
                new_label[new_label == (nc - 1)] = 255.0

            # TODO(sonjoonho): Generate grade more intelligently. This could be done just by
            # sampling from existing images with the desired grade.
            # End-point is inclusive.
            filename = f"copypaste_{dr_grade}_{ticker:05}.png"
            cv2.imwrite(str(label_path / filename), new_label)

            # Instance map: optic-disc pixels become 0, everything else 255
            # (inst = (1 - od) * 255).
            inst = od.squeeze().numpy()
            inst = np.ones_like(inst) - inst
            inst *= 255.0

            cv2.imwrite(str(inst_path / filename), inst)

            ticker += 1