import shutil
import time
from pathlib import Path

import torch
from absl import flags  # assumed: the FLAGS / setup_arch_flags pattern matches absl-py
from torch.utils.data import DataLoader

# NOTE: the project-local imports below are assumptions about this repo's
# module layout; point them at wherever these names actually live.
from admm_research.dataset import (MedicalImageDataset, PILaugment, augment,
                                   segment_transform)
from admm_research.method import AdmmGCSize, AdmmSize
from admm_research.arch import get_arch
from admm_research.loss import get_loss_fn
from admm_research.utils import (Writter_tf, extract_from_big_dict, iterator_,
                                 show_img_mask_weakmask)


def test_admm_size():
    # Unregister every absl flag so setup_arch_flags() can re-register its
    # flags without a DuplicateFlagError when several tests run in one process.
    for name in list(flags.FLAGS):
        delattr(flags.FLAGS, name)
    root_dir = '../dataset/ACDC-2D-All'
    train_dataset = MedicalImageDataset(root_dir, 'train', transform=segment_transform((200, 200)), augment=augment)
    val_dataset = MedicalImageDataset(root_dir, 'val', transform=segment_transform((200, 200)), augment=None)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    # Register the ENet/ADMM architecture flags, then collect them into dicts.
    AdmmSize.setup_arch_flags()
    hparams = flags.FLAGS.flag_values_dict()
    arch_hparam = extract_from_big_dict(hparams, AdmmSize.arch_hparam_keys)
    torchnet = get_arch('enet', arch_hparam)
    # Down-weight the background class (index 0) relative to the foreground.
    weight = torch.Tensor([0.1, 1])
    criterion = get_loss_fn('cross_entropy', weight=weight)

    test_admm = AdmmSize(torchnet, hparams)

    val_score = test_admm.evaluate(val_loader)
    print(val_score)

    # Run a few ADMM update steps on batches whose full and weak ground-truth
    # masks are both non-empty; empty masks are skipped.
    for i, (img, gt, wgt, _) in enumerate(train_loader):
        if gt.sum() == 0 or wgt.sum() == 0:
            continue
        test_admm.reset(img)
        for _ in range(3):
            test_admm.update((img, gt, wgt), criterion)
        if i >= 3:
            break
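

# The flag-reset loop above also opens test_admm_gc_size; a helper like this
# sketch could factor it out (`reset_flags` is a hypothetical name, not a
# function defined in this repo):
def reset_flags():
    """Unregister every absl flag so arch flags can be re-registered cleanly."""
    for name in list(flags.FLAGS):
        delattr(flags.FLAGS, name)
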
def test_admm_gc_size():
    # Reset the absl flags again so AdmmGCSize.setup_arch_flags() starts clean.
    for name in list(flags.FLAGS):
        delattr(flags.FLAGS, name)
    root_dir = '../dataset/ACDC-2D-All'
    train_dataset = MedicalImageDataset(root_dir, 'train', transform=segment_transform((128, 128)), augment=augment)
    val_dataset = MedicalImageDataset(root_dir, 'val', transform=segment_transform((128, 128)), augment=None)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    AdmmGCSize.setup_arch_flags()
    hparams = flags.FLAGS.flag_values_dict()
    torchnet = get_arch('enet', {'num_classes': 2})
    # Optionally warm-start from a pretrained checkpoint:
    # torchnet.load_state_dict(torch.load(
    #     '/Users/jizong/workspace/DGA1033/checkpoints/weakly/enet_fdice_0.8906.pth',
    #     map_location=lambda storage, loc: storage))

    # Zero weight on the background class: the loss is driven entirely by the
    # foreground in this weakly supervised setting.
    weight = torch.Tensor([0, 1])
    criterion = get_loss_fn('cross_entropy', weight=weight)

    test_admm = AdmmGCSize(torchnet, hparams)

    val_score = test_admm.evaluate(val_loader)
    print(val_score)

    # Same smoke loop as in test_admm_size, with two update steps per batch.
    for i, (img, gt, wgt, _) in enumerate(train_loader):
        if gt.sum() == 0 or wgt.sum() == 0:
            continue
        test_admm.reset(img)
        for _ in range(2):
            test_admm.update((img, gt, wgt), criterion)
        if i >= 4:
            break
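

# `extract_from_big_dict` (used in test_admm_size) is assumed to select the
# architecture-related keys out of the full flag dict; a minimal sketch of
# that behaviour:
def extract_from_big_dict_sketch(big_dict, keys):
    """Return the sub-dict of `big_dict` restricted to the given keys."""
    return {k: big_dict[k] for k in keys if k in big_dict}
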
def test_prostate_dataloader():
    root_dir = '../dataset/PROSTATE'
    train_dataset = MedicalImageDataset(root_dir,
                                        'train',
                                        transform=segment_transform(
                                            (128, 128)),
                                        augment=augment)
    val_dataset = MedicalImageDataset(root_dir,
                                      'val',
                                      transform=segment_transform((128, 128)),
                                      augment=None)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    assert len(train_dataset) == len(train_dataset.imgs)
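    # A further sanity check (assumption: this loader yields the same
    # (img, gt, wgt, filename) tuples as the ACDC loaders in the other tests):
    img, _, _, _ = next(iter(train_loader))
    assert img.shape[-2:] == (128, 128)

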
def test_dataloader():
    root_dir = '../dataset/ACDC-2D-All'
    train_dataset = MedicalImageDataset(root_dir,
                                        'train',
                                        transform=segment_transform(
                                            (500, 500)),
                                        augment=augment)
    val_dataset = MedicalImageDataset(root_dir,
                                      'val',
                                      transform=segment_transform((500, 500)),
                                      augment=None)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Visually spot-check a handful of training samples that have both full
    # and weak masks; a few images are enough for a smoke test.
    for i, (img, gt, wgt, _) in enumerate(train_loader):
        if gt.sum() <= 0 or wgt.sum() <= 0:
            continue
        show_img_mask_weakmask(img.numpy(), gt.numpy(), wgt.numpy())
        if i >= 5:
            break
    assert len(train_dataset) == len(train_dataset.imgs)
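

# `show_img_mask_weakmask` comes from this repo's utilities; the function below
# is a hypothetical matplotlib equivalent, sketched only to document the
# expected batched numpy inputs (it is not the repo's implementation):
def show_img_mask_weakmask_sketch(img, gt, wgt):
    """Show image, full mask, and weak mask of batch element 0 side by side."""
    import matplotlib.pyplot as plt
    _, axes = plt.subplots(1, 3, figsize=(9, 3))
    for ax, (arr, title) in zip(axes, [(img, 'img'), (gt, 'gt'), (wgt, 'weak gt')]):
        ax.imshow(arr[0].squeeze(), cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    plt.show()
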
def test_visualization():
    tensorboard_dir = 'runs_test'
    root_dir = '../dataset/ACDC-2D-All'
    train_dataset = MedicalImageDataset(root_dir, 'train', transform=segment_transform((200, 200)), augment=augment)
    val_dataset = MedicalImageDataset(root_dir, 'val', transform=segment_transform((200, 200)), augment=None)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    torchnet = get_arch('enet', {'num_classes': 2})
    writer = Writter_tf(tensorboard_dir, torchnet, 40)

    for i in range(2):
        writer.add_images(train_loader, i)

    # Clean up the TensorBoard run directory created by this test.
    shutil.rmtree(tensorboard_dir)
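

# Note on the cleanup above: if add_images raises, the run directory is left
# behind. Wrapping the body in try/finally (or using pytest's tmp_path fixture)
# keeps the test hermetic, e.g.:
#
#     writer = Writter_tf(tensorboard_dir, torchnet, 40)
#     try:
#         for i in range(2):
#             writer.add_images(train_loader, i)
#     finally:
#         shutil.rmtree(tensorboard_dir, ignore_errors=True)
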
def test_iter():
    dataroot: str = '../dataset/ACDC-all'
    train_set = MedicalImageDataset(
        dataroot,
        'train',
        subfolders=['img', 'gt'],
        transform=segment_transform((256, 256)),
        augment=PILaugment,
        equalize=None,
        pin_memory=False,
        metainfo="[classSizeCalulator, {'C':4,'foldernames':['gt','gt2']}]")
    train_loader_1 = DataLoader(train_set, batch_size=10, num_workers=8)
    train_loader_2 = DataLoader(train_set, batch_size=10, num_workers=8)

    n_batches_1 = len(train_loader_1)
    n_batches_2 = len(train_loader_2)
    assert n_batches_1 == n_batches_2

    train_loader_1_ = iterator_(train_loader_1)
    train_loader_2_ = iterator_(train_loader_2)

    output_list1 = []
    output_list2 = []

    # Iterate one full epoch plus one extra step: the wrapped iterators are
    # expected to restart instead of raising StopIteration.
    for _ in range(n_batches_2 + 1):
        data1 = next(train_loader_1_)
        data2 = next(train_loader_2_)
        output_list1.extend(data1[2])
        output_list2.extend(data2[2])

    # Both loaders wrap the same dataset, so each must have seen every filename.
    expected = {Path(x).stem for x in train_set.filenames['img']}
    assert set(output_list1) == expected
    assert set(output_list2) == expected
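

# `iterator_` is this repo's DataLoader wrapper; given that the loop above
# calls next() n_batches + 1 times without hitting StopIteration, it is
# presumably an endless, restarting iterator. A minimal sketch of that idea:
class IteratorSketch:
    """Endlessly re-iterate a DataLoader, restarting it when exhausted."""

    def __init__(self, loader):
        self.loader = loader
        self.it = iter(loader)

    def __next__(self):
        try:
            return next(self.it)
        except StopIteration:
            self.it = iter(self.loader)  # re-shuffles if shuffle=True
            return next(self.it)
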
def test_dataset():
    dataroot: str = '../dataset/ACDC-all'
    train_set = MedicalImageDataset(
        dataroot,
        'train',
        subfolders=['img', 'gt'],
        transform=segment_transform((256, 256)),
        augment=PILaugment,
        equalize=None,
        pin_memory=False,
        metainfo="[classSizeCalulator, {'C':4,'foldernames':['gt','gt2']}]")
    train_loader = DataLoader(train_set, batch_size=10, num_workers=8)

    # Print a few batches as a smoke test; `metainfo` carries the class-size
    # information requested through the metainfo argument above.
    for i, (imgs, metainfo, filenames) in enumerate(train_loader):
        print(imgs)
        print(metainfo)
        print(filenames)
        time.sleep(2)
        if i >= 2:
            break
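

# `classSizeCalulator` (named in the metainfo string above, spelling as in the
# repo) presumably records per-class pixel counts from the ground-truth
# folders; a hedged sketch of that computation for a single C-class label map:
def class_sizes_sketch(gt, C=4):
    """Return a length-C tensor of per-class pixel counts."""
    return torch.stack([(gt == c).sum() for c in range(C)])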