Пример #1
0
def main():
    """Run the parsing network over every image in ``args.input``.

    For each image a palettised PNG label map ``<stem>.png`` is written to
    ``args.output`` (created if missing). When ``args.logits`` is set, the
    per-pixel argmax label map is additionally saved as ``<stem>.npy`` in the
    same directory.
    """
    args = get_arguments()

    settings = dataset_settings[args.dataset]
    num_classes = settings['num_classes']
    input_size = settings['input_size']

    model = network(num_classes=num_classes, pretrained=None).cuda()
    model = nn.DataParallel(model)
    state_dict = torch.load(args.restore_weight)
    model.load_state_dict(state_dict)
    model.eval()

    # NOTE(review): mean/std are the ImageNet statistics in reversed channel
    # order; presumably SCHPDataset yields BGR images — confirm.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])
    dataset = SCHPDataset(root=args.input, input_size=input_size, transform=transform)
    # Default batch size is 1, so every ``meta[...]`` below holds one sample.
    dataloader = DataLoader(dataset)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.output, exist_ok=True)

    palette = get_palette(num_classes)
    # The upsampling layer is stateless and identical for every image, so it
    # is built once instead of once per iteration.
    upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)

    with torch.no_grad():
        for image, meta in dataloader:
            img_name = meta['name'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = model(image.cuda())
            # Bilinear Upsample works on 4-D NCHW input; drop only the batch
            # axis so a single-class output keeps its channel axis.
            upsample_output = upsample(output).squeeze(0)
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

            # NOTE(review): transform_logits presumably maps the logits back to
            # the original image geometry via center/scale/width/height — confirm.
            logits_result = transform_logits(upsample_output.cpu().numpy(), c, s, w, h, input_size=input_size)
            parsing_result = np.argmax(logits_result, axis=2)

            # splitext copes with extensions of any length (".jpeg", ".tiff"),
            # unlike chopping off a fixed four characters.
            stem = os.path.splitext(img_name)[0]
            parsing_result_path = os.path.join(args.output, stem + '.png')

            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)
            output_img.save(parsing_result_path)

            if args.logits:
                # NOTE(review): despite the flag name this stores the argmax
                # label map, not the raw logits — confirm that is intended.
                np_result_path = os.path.join(args.output, stem + '.npy')
                np.save(np_result_path, parsing_result)
    return
def main():
    """Parse images under ``<root>/<input>`` and save grayscale and colour masks.

    For every image the predicted label map is post-processed by
    ``map_pixel`` and saved twice:

    * ``<root>/<output>/<stem>.png`` holds the label values without a palette.
    * ``<root>/<output_vis>/<stem>.png`` has the colour palette applied.

    When ``args.logits`` is set, the full logits array is also saved as
    ``<root>/<stem>.npy``.
    """
    args = get_arguments()

    settings = dataset_settings[args.dataset]
    num_classes = settings['num_classes']
    input_size = settings['input_size']

    model = network(num_classes=num_classes, pretrained=None).cuda()
    model = nn.DataParallel(model)
    state_dict = torch.load(args.restore_weight)
    model.load_state_dict(state_dict)
    model.eval()

    # NOTE(review): mean/std are the ImageNet statistics in reversed channel
    # order; presumably SCHPDataset yields BGR images — confirm.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485],
                             std=[0.225, 0.224, 0.229])
    ])
    input_dir = os.path.join(args.root, args.input)
    dataset = SCHPDataset(root=input_dir,
                          input_size=input_size,
                          transform=transform)
    # Default batch size is 1, so every ``meta[...]`` below holds one sample.
    dataloader = DataLoader(dataset)

    # The palette is deliberately sized for 20 labels instead of num_classes.
    # NOTE(review): presumably map_pixel remaps predictions into a 20-label
    # scheme — confirm.
    palette = get_palette(20)

    # Output locations do not depend on the image, so create them once.
    outdir_gray = os.path.join(args.root, args.output)
    outdir_vis = os.path.join(args.root, args.output_vis)
    os.makedirs(outdir_gray, exist_ok=True)
    os.makedirs(outdir_vis, exist_ok=True)

    # Stateless and identical for every image: build once.
    upsample = torch.nn.Upsample(size=input_size,
                                 mode='bilinear',
                                 align_corners=True)

    with torch.no_grad():
        for image, meta in dataloader:
            img_name = meta['name'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = model(image.cuda())
            # Bilinear Upsample works on 4-D NCHW input; drop only the batch
            # axis so a single-class output keeps its channel axis.
            upsample_output = upsample(output).squeeze(0)
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

            logits_result = transform_logits(
                upsample_output.cpu().numpy(),
                c,
                s,
                w,
                h,
                input_size=input_size)
            parsing_result = np.argmax(logits_result, axis=2)

            # splitext copes with extensions of any length (".jpeg", ".tiff").
            stem = os.path.splitext(img_name)[0]
            parsing_result_path = os.path.join(outdir_vis, stem + '.png')
            parsing_result_path_gray = os.path.join(outdir_gray, stem + '.png')

            output_img = Image.fromarray(
                np.asarray(parsing_result, dtype=np.uint8))
            # Post-process the label map, save it without a palette, then
            # apply the palette and save the colour visualisation.
            original_mask = map_pixel(output_img, img_name)
            original_mask.save(parsing_result_path_gray)
            original_mask.putpalette(palette)
            original_mask.save(parsing_result_path)

            if args.logits:
                # NOTE(review): logits land in <root> itself rather than in an
                # output directory — confirm this location is intended.
                logits_result_path = os.path.join(args.root, stem + '.npy')
                np.save(logits_result_path, logits_result)
    return
Пример #3
0
def main():
    """Parse every image under ``args.input``, mirroring its tree in ``args.output``.

    Each image at ``<input>/<sub>/<name>.<ext>`` produces a palettised PNG
    ``<output>/<sub>/<name><postfix_filename>.png``. When ``args.logits`` is
    set, a numpy dump is written next to it:

    * with ``args.argmax_logits``, the argmax label map goes to ``….npyc``
      ("c" for compressed);
    * otherwise, the full logits array goes to ``….npy``.
    """
    args = get_arguments()

    settings = dataset_settings[args.dataset]
    num_classes = settings['num_classes']
    input_size = settings['input_size']

    model = network(num_classes=num_classes, pretrained=None).cuda()
    model = nn.DataParallel(model)
    state_dict = torch.load(args.restore_weight)
    model.load_state_dict(state_dict)
    model.eval()

    # NOTE(review): mean/std are the ImageNet statistics in reversed channel
    # order; presumably SCHPDataset yields BGR images — confirm.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485],
                             std=[0.225, 0.224, 0.229])
    ])
    dataset = SCHPDataset(root=args.input,
                          input_size=input_size,
                          transform=transform)
    # Default batch size is 1, so len(dataloader) equals the number of files.
    dataloader = DataLoader(dataset)

    palette = get_palette(num_classes)
    # Stateless and identical for every image: build once.
    upsample = torch.nn.Upsample(size=input_size,
                                 mode='bilinear',
                                 align_corners=True)

    print(f"Found {len(dataloader)} files")
    with torch.no_grad():
        for image, meta in tqdm(dataloader, total=len(dataloader)):
            img_path = meta['path'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = model(image.cuda())
            # Bilinear Upsample works on 4-D NCHW input; drop only the batch
            # axis so a single-class output keeps its channel axis.
            upsample_output = upsample(output).squeeze(0)
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

            logits_result = transform_logits(
                upsample_output.cpu().numpy(),
                c,
                s,
                w,
                h,
                input_size=input_size)

            parsing_result = np.argmax(logits_result, axis=2)

            # relpath strips exactly the input-root prefix. str.replace would
            # also delete later occurrences of args.input inside the path, and
            # every "." when args.input is ".".
            img_subpath = os.path.relpath(img_path, args.input)
            out_stem = os.path.join(
                args.output,
                os.path.splitext(img_subpath)[0] + args.postfix_filename)
            parsing_result_path = out_stem + '.png'
            os.makedirs(os.path.dirname(parsing_result_path), exist_ok=True)

            output_img = Image.fromarray(
                np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)  # map label ids to display colours
            output_img.save(parsing_result_path)
            if args.logits:
                if args.argmax_logits:
                    logits_result_path = out_stem + '.npyc'  # c for compressed
                    result = parsing_result
                else:
                    logits_result_path = out_stem + '.npy'
                    result = logits_result
                # Save through an open handle: given a path string, np.save
                # appends ".npy" to any name not already ending in it, which
                # would turn "x.npyc" into "x.npyc.npy".
                with open(logits_result_path, 'wb') as fh:
                    np.save(fh, result)
    return
def main():
    """Parse every immediate subfolder of ``args.input`` into label/cloth outputs.

    Outputs for every image ``<name>.<ext>`` in ``<input>/<subdir>``:

    * ``<input>/test_label/<name>.png`` holds the raw label map.
    * ``<output>/<name>.npy`` holds the logits, only when ``args.logits`` is set.

    Images whose name contains ``-p-`` additionally get an upper-cloth mask
    built with ``get_palette_mask``. That mask is saved to
    ``<input>/test_edge/<name>.jpg``. It is also used to cut the cloth onto a
    white background, saved to ``<input>/test_color/<name>.png``.
    """
    args = get_arguments()

    settings = dataset_settings[args.dataset]
    num_classes = settings['num_classes']
    input_size = settings['input_size']

    model = network(num_classes=num_classes, pretrained=None).cuda()
    model = nn.DataParallel(model)
    state_dict = torch.load(args.restore_weight)
    model.load_state_dict(state_dict)
    model.eval()

    # NOTE(review): mean/std are the ImageNet statistics in reversed channel
    # order; presumably SCHPDataset yields BGR images — confirm.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485],
                             std=[0.225, 0.224, 0.229])
    ])

    parsing_dir = os.path.join(args.input, 'test_label')
    clothing_dir = os.path.join(args.input, 'test_color')
    clothmask_dir = os.path.join(args.input, 'test_edge')
    output_dirs = {parsing_dir, clothing_dir, clothmask_dir}

    # Only the top level is needed; walking the whole tree is wasted work.
    with os.scandir(args.input) as entries:
        subdirs = sorted(entry.name for entry in entries if entry.is_dir())

    for out_dir in output_dirs:
        os.makedirs(out_dir, exist_ok=True)
    if args.logits and args.output:
        os.makedirs(args.output, exist_ok=True)

    mask_palette = get_palette_mask(num_classes)
    # Stateless and identical for every image: build once.
    upsample = torch.nn.Upsample(size=input_size,
                                 mode='bilinear',
                                 align_corners=True)

    with torch.no_grad():
        for subdir in subdirs:
            input_dir = os.path.join(args.input, subdir)
            # The output folders live inside args.input. Without this check a
            # re-run would parse its own previous outputs and overwrite
            # test_label with parses of label maps.
            if input_dir in output_dirs:
                continue
            dataset = SCHPDataset(root=input_dir,
                                  input_size=input_size,
                                  transform=transform)
            # Default batch size is 1, so every ``meta[...]`` holds one sample.
            dataloader = DataLoader(dataset)

            for image, meta in dataloader:
                img_name = meta['name'][0]
                c = meta['center'].numpy()[0]
                s = meta['scale'].numpy()[0]
                w = meta['width'].numpy()[0]
                h = meta['height'].numpy()[0]

                output = model(image.cuda())
                # Bilinear Upsample works on 4-D NCHW input; drop only the
                # batch axis so a single-class output keeps its channel axis.
                upsample_output = upsample(output).squeeze(0)
                upsample_output = upsample_output.permute(1, 2, 0)  #CHW -> HWC

                logits_result = transform_logits(
                    upsample_output.cpu().numpy(),
                    c,
                    s,
                    w,
                    h,
                    input_size=input_size)
                parsing_result = np.argmax(logits_result, axis=2)

                stem = os.path.splitext(img_name)[0]
                parsing_result_path = os.path.join(parsing_dir, stem + '.png')
                # NOTE(review): the mask is stored as lossy JPEG, which can
                # blur its edges — confirm downstream expects .jpg here.
                cloth_mask_path = os.path.join(clothmask_dir, stem + '.jpg')
                cloth_path = os.path.join(clothing_dir, stem + '.png')

                output_img = Image.fromarray(
                    np.asarray(parsing_result, dtype=np.uint8))
                # Raw label map (all classes), saved before any palette is applied.
                output_img.save(parsing_result_path)

                # Saved for every image, consistent with the other variants;
                # previously the "-p-" filter below skipped this save.
                if args.logits:
                    logits_result_path = os.path.join(args.output,
                                                      stem + '.npy')
                    np.save(logits_result_path, logits_result)

                if '-p-' not in img_name:
                    continue
                # putpalette switches output_img to palette mode in place; the
                # raw label map has already been written above.
                output_img.putpalette(mask_palette)
                RGB_mask = output_img.convert('RGB')
                L_mask = output_img.convert('L')
                original_image_path = os.path.join(input_dir, img_name)
                print('original_image_path is ', original_image_path)
                with Image.open(original_image_path) as original_image:
                    # ImageChops.multiply needs both images in the same mode.
                    masked_image = ImageChops.multiply(
                        original_image.convert('RGB'), RGB_mask)
                # Paste the masked cloth onto white, using the mask as alpha.
                bg_image = Image.new("RGBA", masked_image.size,
                                     (255, 255, 255, 255))
                bg_image.paste(masked_image.convert('RGBA'), (0, 0), L_mask)
                L_mask.save(cloth_mask_path)
                bg_image.save(cloth_path)
    return