def test():
    """Visually inspect Cityscapes samples and the class map built from their boxes.

    For every batch yielded by ``get_cityscapes_dataloader(args, True)``:

    * call ``get_bbox_instance`` on the first element's instance and label maps;
    * paint a single-channel ``(1, H, W)`` canvas so that every pixel where a
      returned ``'mask'`` is positive holds that entry's ``'cls'`` value (later
      entries overwrite earlier ones where masks overlap);
    * show the first image and the painted canvas with PIL;
    * print the unique values of the instance and label tensors.

    Only the first element of each batch is visualized.
    """
    from train_config import get_config
    from tools.utils_for_weakly import get_bbox_instance
    args = get_config()
    data_loader = get_cityscapes_dataloader(args, True)
    for sample in data_loader:
        image = sample['image']
        instance = sample['instance']
        label = sample['label']
        bbox_instance = get_bbox_instance(instance[0], label[0])
        # Canvas with the same spatial size as the image batch.
        bboxes = torch.zeros(1, image.size(-2), image.size(-1))
        # Use a dedicated name so the box loop cannot shadow an outer loop variable.
        for box in bbox_instance:
            # Add a leading channel dim so the boolean mask indexes the (1, H, W) canvas.
            box_mask = box['mask'].unsqueeze(dim=0)
            bboxes[box_mask > 0] = box['cls']

        transforms.ToPILImage()(image[0]).show()
        transforms.ToPILImage()(bboxes).show()
        print(instance.unique())
        print(label.unique())
        print('Next')
def cat_data_set_test():
    """Print the size of the Cityscapes training set and the image shape of each batch.

    Builds the dataset with ``get_cityscapes_dataset(args, train=True)``,
    prints its length, then walks it once through a shuffled ``DataLoader``
    and prints ``batch['image'].shape`` for every batch.
    """
    from train_config import get_config

    config = get_config()
    dataset = get_cityscapes_dataset(config, train=True)
    print(len(dataset))

    loader = DataLoader(dataset, shuffle=True)
    shapes = (batch['image'].shape for batch in loader)
    for shape in shapes:
        print(shape)
def _test_freeze_batch_norm():
    """Smoke-check batch-norm freezing on the configured model.

    Builds the model from the config and passes it through
    ``models.sync_batchnorm.convert_model``. It then prints the names of the
    model's direct child modules and wraps it in ``nn.DataParallel``. Finally
    it runs ``freeze_batch_norm`` and then ``check_batch_norm_freeze`` on the
    wrapped model.
    """
    from train_config import get_config
    from models import get_model
    from models.sync_batchnorm import convert_model

    net = convert_model(get_model(get_config()))
    # Iterate the raw child registry so every registered name is listed.
    for child_name in net._modules.keys():
        print(child_name)

    wrapped = nn.DataParallel(net)
    freeze_batch_norm(wrapped)
    check_batch_norm_freeze(wrapped)
def test_weakly_dataset():
    """Check that the two training subsets selected by ``data_choose_size`` share no images.

    Builds the training dataset twice from the same config: once with
    ``data_choose_size = 914`` and once with ``-914``. The sign presumably
    selects the complementary part of the split; TODO confirm in
    ``get_cityscapes_dataset``. The check fails if any entry of the first
    dataset's ``image_list`` also appears in the second's.

    Raises:
        Exception: if the two ``image_list``s share at least one entry; the
            message lists how many entries overlap and a few examples.
    """
    from train_config import get_config
    args = get_config()
    args.data_choose_size = 914
    data_set_1 = get_cityscapes_dataset(args, True)
    args.data_choose_size = -914
    data_set_2 = get_cityscapes_dataset(args, True)

    # Check for overlap with one set lookup per image instead of comparing
    # every pair of images.
    # NOTE(review): assumes image_list entries are hashable (e.g. path strings) — confirm.
    images_2 = set(data_set_2.image_list)
    overlap = [img for img in data_set_1.image_list if img in images_2]
    if overlap:
        raise Exception(
            f'Overlap! {len(overlap)} image(s) appear in both subsets, '
            f'e.g. {overlap[:5]}'
        )
def test_label_data_set():
    """Visually inspect a ``BalancedCityscapesDataset`` built for a single label.

    Builds the ``"train"`` split from ``args.cityscapes_data_path`` with:

    * ``crop_size=512``
    * ``fake_size=41``
    * ``label=6``
    * ``choose_size=914``
    * no transform
    * ``repeat=1``

    It prints the dataset length and then calls ``.show()`` on each sample's
    ``'image'`` entry.
    """
    from train_config import get_config
    args = get_config()
    data_set = BalancedCityscapesDataset(crop_size=512,
                                         fake_size=41,
                                         label=6,
                                         root_dir=args.cityscapes_data_path,
                                         type="train",
                                         choose_size=914,
                                         transform=None,
                                         repeat=1)
    num_samples = len(data_set)
    print(num_samples)
    # Index explicitly rather than `for sample in data_set`. Plain iteration
    # over a map-style dataset uses the legacy __getitem__ protocol, which only
    # stops when __getitem__ raises IndexError. It never stops if the dataset
    # wraps or resamples indices.
    for idx in range(num_samples):
        image = data_set[idx]['image']
        # With transform=None the image is presumably a PIL image; confirm.
        image.show()