Esempio n. 1
0
def main(config):
    """Run the scratch-detection UNet over every file in ``config.test_path``.

    Each regular file in the directory is processed in sorted-name order:

    1. It is opened as RGB and transformed with ``data_transforms``.
    2. It is converted to a single-channel tensor normalised with mean 0.5 / std 0.5.
    3. It is passed through the network.
    4. The sigmoid output is thresholded at 0.4, and the binary mask is written as a
       PNG to ``<config.output_dir>/mask``.
    5. The transformed input image is written to ``<config.output_dir>/input``
       under the same base name. The original extension is replaced by ``.png``.

    Args:
        config: options object. Fields read here:
            ``GPU``: device passed to ``.to()``. The model and inputs are only moved
            when ``GPU > 0``.
            ``test_path``: directory of input images.
            ``output_dir``: root output directory.
            ``input_size``: forwarded to ``data_transforms``.
    """
    model = networks.UNet(
        in_channels=1,
        out_channels=1,
        depth=4,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    )

    # Load weights onto the CPU first so the checkpoint can be read regardless
    # of which device it was saved from. NOTE: torch.load unpickles the file;
    # only point this at trusted checkpoints.
    checkpoint_path = "./checkpoints/detection/FT_Epoch_latest.pt"
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    # NOTE(review): with GPU == 0 the model (and inputs below) stay on the CPU;
    # confirm that index 0 is meant to mean "CPU" here.
    if config.GPU > 0:
        model.to(config.GPU)
    model.eval()

    imagelist = sorted(os.listdir(config.test_path))

    save_url = os.path.join(config.output_dir)
    input_dir = os.path.join(save_url, "input")
    output_dir = os.path.join(save_url, "mask")
    mkdir_if_not(save_url)
    mkdir_if_not(input_dir)
    mkdir_if_not(output_dir)

    # Transform objects are loop-invariant; build them once.
    to_tensor = tv.transforms.ToTensor()
    normalize = tv.transforms.Normalize([0.5], [0.5])

    for image_name in imagelist:
        scratch_file = os.path.join(config.test_path, image_name)
        if not os.path.isfile(scratch_file):
            continue  # skip sub-directories and other non-regular entries

        # Close the underlying file handle as soon as the RGB copy is made.
        with Image.open(scratch_file) as raw_image:
            rgb_image = raw_image.convert("RGB")

        transformed_image_PIL = data_transforms(rgb_image, config.input_size)

        gray_image = transformed_image_PIL.convert("L")
        input_tensor = torch.unsqueeze(normalize(to_tensor(gray_image)), 0)
        if config.GPU > 0:
            input_tensor = input_tensor.to(config.GPU)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            P = torch.sigmoid(model(input_tensor)).cpu()

        # splitext handles extensions of any length (".jpeg", ".tiff", none).
        out_name = os.path.splitext(image_name)[0] + ".png"
        tv.utils.save_image(
            (P >= 0.4).float(),
            os.path.join(output_dir, out_name),
            nrow=1,
            padding=0,
            # The mask is already 0/1. Min-max normalisation would turn an
            # all-ones mask (min == max) into an all-black image.
            normalize=False,
        )
        transformed_image_PIL.save(os.path.join(input_dir, out_name))
Esempio n. 2
0
def main(config, input_images, image_names):
    """Run the scratch-detection UNet over in-memory images.

    This is a variant that takes images directly instead of listing
    ``config.test_path``. Each image is processed as follows:

    1. It is transformed with ``data_transforms``.
    2. It is converted to a single-channel tensor normalised with mean 0.5 / std 0.5.
    3. It is passed through the network.
    4. The sigmoid output is thresholded at 0.4 and handed to ``save_image``.

    Args:
        config: options object. Fields read here:
            ``GPU``: passed to ``.to()`` for both the model and inputs.
            ``output_dir``: root used to build the mask path.
            ``input_size``: forwarded to ``data_transforms``.
        input_images: sequence of images. Presumably PIL images, since the
            transformed result is ``.convert``-ed. Verify with the caller.
        image_names: names paired positionally with ``input_images``. Only the
            base name (extension stripped) is used to build the mask path.
            Must have at least ``len(input_images)`` entries.

    Returns:
        ``(transformed_inputs, masks)``:
            ``transformed_inputs``: the transformed PIL images.
            ``masks``: whatever ``save_image`` returned for each image.

    Raises:
        IndexError: if ``image_names`` is shorter than ``input_images``.
    """
    model = networks.UNet(
        in_channels=1,
        out_channels=1,
        depth=4,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    )

    # Load weights onto the CPU first, then move the model to the configured
    # device. NOTE: torch.load unpickles the file; only use trusted checkpoints.
    checkpoint_path = "./checkpoints/detection/FT_Epoch_latest.pt"
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    model.to(config.GPU)
    model.eval()

    # NOTE(review): no directory is created here; the path is only passed to
    # save_image. Confirm whether save_image writes to disk.
    output_dir = os.path.join(config.output_dir, "mask")

    # Transform objects are loop-invariant; build them once.
    to_tensor = tv.transforms.ToTensor()
    normalize = tv.transforms.Normalize([0.5], [0.5])

    transformed_inputs = []
    masks = []
    for i, source_image in enumerate(input_images):
        image_name = image_names[i]

        transformed_image_PIL = data_transforms(source_image, config.input_size)

        gray_image = transformed_image_PIL.convert("L")
        input_tensor = torch.unsqueeze(normalize(to_tensor(gray_image)), 0)
        input_tensor = input_tensor.to(config.GPU)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            P = torch.sigmoid(model(input_tensor)).cpu()

        # splitext handles extensions of any length (".jpeg", ".tiff", none).
        mask_path = os.path.join(
            output_dir, os.path.splitext(image_name)[0] + ".png"
        )
        masks.append(
            save_image(
                (P >= 0.4).float(),
                mask_path,
                nrow=1,
                padding=0,
                normalize=True,
            )
        )
        transformed_inputs.append(transformed_image_PIL)

    return transformed_inputs, masks