Ejemplo n.º 1
0
    def test_generate_trimap(self):
        """Check generate_trimap output, cross-check np.where, and dump a crop.

        Reads a fixed foreground/mask pair and builds a trimap from the mask.
        It then:

        * asserts the trimap shape;
        * verifies that a hand-rolled count of ``unknown_code`` pixels agrees
          with ``np.where``;
        * checks that a randomly drawn index from the ``np.where`` result points
          at an unknown pixel;
        * writes the cropped image/trimap/alpha PNGs under ``temp/``.
        """
        fg_image = cv.imread('fg/1-1252426161dfXY.jpg')
        mask = cv.imread('mask/1-1252426161dfXY.jpg', 0)
        tri = generate_trimap(mask)
        self.assertEqual(tri.shape, (615, 410))

        # Manual pixel-by-pixel count, used to validate np.where below.
        rows, cols = tri.shape[:2]
        manual_count = sum(
            1
            for r in range(rows)
            for c in range(cols)
            if tri[r, c] == unknown_code
        )
        row_idx, col_idx = np.where(tri == unknown_code)
        total_unknown = len(row_idx)
        self.assertEqual(manual_count, total_unknown)

        # A random entry of the np.where result must be an unknown pixel.
        pick = random.choice(range(total_unknown))
        self.assertEqual(tri[row_idx[pick], col_idx[pick]], unknown_code)

        crop_x, crop_y = random_choice(tri)
        outputs = (
            ('temp/test_generate_trimap_image.png', safe_crop(fg_image, crop_x, crop_y)),
            ('temp/test_generate_trimap_trimap.png', safe_crop(tri, crop_x, crop_y)),
            ('temp/test_generate_trimap_alpha.png', safe_crop(mask, crop_x, crop_y)),
        )
        for out_path, arr in outputs:
            cv.imwrite(out_path, arr)
    def __getitem__(self, i):
        """Build one (input, target) training pair for foreground ``self.fgs[i]``.

        Steps:

        1. Draw a random background index and build the sample with ``process``.
        2. Take a random square crop (320, 480 or 640) at a position chosen by
           ``random_choice`` on the full-size trimap.
        3. Regenerate the trimap from the cropped alpha.
        4. Flip horizontally with probability 1/2.

        Returns:
            x: float tensor of shape (4, im_size, im_size). Channels 0-2 hold the
               transformed image; channel 3 holds trimap / 255.
            y: float32 array of shape (2, im_size, im_size). Channel 0 holds
               alpha / 255.; channel 1 is 1.0 where the trimap equals 128.
        """
        fg_idx = self.fgs[i]
        bg_idx = np.random.randint(num_bgs)
        img, alpha, _fg, _bg = process(fg_idx, bg_idx)

        # Crop sizes 320 / 480 / 640 are equally likely.
        side = random.choice([(320, 320), (480, 480), (640, 640)])

        # Pick the crop position from the full-size trimap ...
        cx, cy = random_choice(gen_trimap(alpha), side)
        img = safe_crop(img, cx, cy, side)
        alpha = safe_crop(alpha, cx, cy, side)
        # ... then rebuild the trimap so it matches the cropped alpha.
        trimap = gen_trimap(alpha)

        # Horizontal flip with probability 1/2.
        if np.random.random_sample() > 0.5:
            img, trimap, alpha = np.fliplr(img), np.fliplr(trimap), np.fliplr(alpha)

        inputs = torch.zeros((4, im_size, im_size), dtype=torch.float)
        inputs[0:3, :, :] = self.transformer(transforms.ToPILImage()(img))
        inputs[3, :, :] = torch.from_numpy(trimap.copy() / 255.)

        targets = np.empty((2, im_size, im_size), dtype=np.float32)
        targets[0, :, :] = alpha / 255.
        targets[1, :, :] = np.equal(trimap, 128).astype(np.float32)

        return inputs, targets
Ejemplo n.º 3
0
    def __getitem__(self, i):
        """Return an (image tensor, alpha map) pair for sample ``self.names[i]``.

        The name has the form ``'<fg>_<bg>.<ext>'``. The two integer indices
        select entries of ``fg_files`` and ``bg_files``, which are passed to
        ``process``. In this variant ``random_choice`` returns both the crop
        position and the crop size, chosen from 320/480/640. No trimap is
        produced.

        Returns:
            x: ``self.transformer`` applied to the channel-reversed (RGB) crop.
            y: alpha / 255. as a float array.
        """
        stem = self.names[i].split('.')[0]
        parts = stem.split('_')
        fg_index, bg_index = int(parts[0]), int(parts[1])
        img, alpha, _, _ = process(fg_files[fg_index], bg_files[bg_index])

        candidate_sizes = [(320, 320), (480, 480), (640, 640)]
        cx, cy, size = random_choice(img, candidate_sizes)
        img = safe_crop(img, cx, cy, size)
        alpha = safe_crop(alpha, cx, cy, size)

        # Horizontal flip with probability 1/2.
        if np.random.random_sample() > 0.5:
            img, alpha = np.fliplr(img), np.fliplr(alpha)

        rgb = img[..., ::-1]  # reverse channel axis -> RGB
        x = self.transformer(transforms.ToPILImage()(rgb))
        y = alpha / 255.

        return x, y
    def __getitem__(self, i):
        """Build one training sample, reusing a single foreground per batch.

        When ``i`` falls on a batch boundary (``i % args.batch_size == 0``), or
        nothing has been cached yet, the raw alpha and foreground for
        ``self.fgs[i]`` are loaded and cached on ``self``. A per-batch resize
        decision (probability 0.25) is cached with them.

        Every sample then draws a fresh random background. It is built with
        ``process`` from the *cached* foreground index, so the composite image
        and the cached alpha refer to the same foreground.

        Returns:
            x: float tensor of shape (4, im_size, im_size). Channels 0-2 hold the
               transformed image; channel 3 holds trimap / 255.
            y: float32 array of shape (2, im_size, im_size). Channel 0 holds
               alpha / 255.; channel 1 is 1.0 where the trimap equals 128.
        """
        fcount = self.fgs[i]

        # Load a new foreground at each batch boundary, and also when nothing has
        # been cached yet (e.g. the first index requested is not a multiple of
        # the batch size). Otherwise the else-branch would raise AttributeError.
        if i % args.batch_size == 0 or getattr(self, 'current_alpha', None) is None:
            self.current_index = fcount
            alpha = get_raw("a", fcount)
            # Collapse to 2-D (drops a trailing singleton channel axis, if any).
            alpha = np.reshape(alpha, (alpha.shape[0], alpha.shape[1]))
            fg = get_raw("fg", fcount)
            self.current_fg = fg
            self.current_alpha = alpha
            self.is_resize = bool(np.random.rand() < 0.25)
        else:
            alpha = self.current_alpha

        bcount = np.random.randint(num_bgs)
        # Composite the cached foreground (not self.fgs[i]) so img matches alpha.
        img, _, _, _ = process(self.current_index, bcount)

        if self.is_resize:
            # Stretch composite and alpha to 640x640; aspect ratio is not kept.
            # NOTE(review): maybe_random_interp presumably randomises the
            # interpolation with INTER_NEAREST as a fallback -- confirm.
            interpolation = maybe_random_interp(cv.INTER_NEAREST)
            img = cv.resize(img, (640, 640), interpolation=interpolation)
            alpha = cv.resize(alpha, (640, 640), interpolation=interpolation)

        # crop size 320:640:480 = 1:1:1
        different_sizes = [(320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)

        # Choose the crop position on the full-size trimap ...
        trimap = gen_trimap(alpha)
        x, y = random_choice(trimap, crop_size)
        img = safe_crop(img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)

        # ... then regenerate the trimap for the cropped alpha.
        trimap = gen_trimap(alpha)

        # Flip array left to right randomly (prob=1:1)
        if np.random.random_sample() > 0.5:
            img = np.fliplr(img)
            trimap = np.fliplr(trimap)
            alpha = np.fliplr(alpha)

        x = torch.zeros((4, im_size, im_size), dtype=torch.float)
        img = transforms.ToPILImage()(img)
        img = self.transformer(img)
        x[0:3, :, :] = img
        x[3, :, :] = torch.from_numpy(trimap.copy() / 255.)

        y = np.empty((2, im_size, im_size), dtype=np.float32)
        y[0, :, :] = alpha / 255.
        mask = np.equal(trimap, 128).astype(np.float32)
        y[1, :, :] = mask

        return x, y
Ejemplo n.º 5
0
    def __getitem__(self, i):
        """Build one training sample plus the pieces needed for a compositional loss.

        ``self.names[i]`` has the form ``'<fg>_<bg>.<ext>'``. The two integer
        indices are passed to ``process``.

        Returns:
            x: float tensor of shape (4, im_size, im_size). Channels 0-2 hold the
               channel-reversed (RGB) transformed image; channel 3 holds
               trimap / 255.
            y: float32 array of shape (2, im_size, im_size). Channel 0 holds
               alpha / 255.; channel 1 is 1.0 where the trimap equals 128.
            img, fg, bg: the cropped/flipped composite, foreground and
               background. Each is passed through ``self.transformer`` *without*
               the channel reversal applied to ``x``.
        """
        name = self.names[i]
        stem = name.split('.')[0]
        fcount = int(stem.split('_')[0])
        bcount = int(stem.split('_')[1])
        img, alpha, fg, bg = process(fcount, bcount)

        # crop size 320:640:480 = 1:1:1 (drawn once).
        different_sizes = [(320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)

        # Choose the crop position on the full-size trimap, then crop every
        # layer at the same position.
        trimap = gen_trimap(alpha)
        x, y = random_choice(trimap, crop_size)
        img = safe_crop(img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)
        fg = safe_crop(fg, x, y, crop_size)
        bg = safe_crop(bg, x, y, crop_size)

        # Regenerate the trimap so it matches the cropped alpha.
        trimap = gen_trimap(alpha)

        # Flip array left to right randomly (prob=1:1)
        if np.random.random_sample() > 0.5:
            img = np.fliplr(img)
            trimap = np.fliplr(trimap)
            alpha = np.fliplr(alpha)
            fg = np.fliplr(fg)
            bg = np.fliplr(bg)

        # np.fliplr returns negative-stride views; take contiguous copies once so
        # torch.from_numpy / PIL can consume them.
        img = img.copy()
        trimap = trimap.copy()
        alpha = alpha.copy()
        fg = fg.copy()
        bg = bg.copy()

        x = torch.zeros((4, im_size, im_size), dtype=torch.float)
        image = img[..., ::-1].copy()  # reverse channel axis -> RGB
        image = self.transformer(transforms.ToPILImage()(image))
        x[0:3, :, :] = image
        x[3, :, :] = torch.from_numpy(trimap / 255.)

        y = np.empty((2, im_size, im_size), dtype=np.float32)
        y[0, :, :] = alpha / 255.
        y[1, :, :] = np.equal(trimap, 128).astype(np.float32)

        img = self.transformer(transforms.ToPILImage()(img))
        fg = self.transformer(transforms.ToPILImage()(fg))
        bg = self.transformer(transforms.ToPILImage()(bg))

        return x, y, img, fg, bg
    def __getitem__(self, i):
        """Load a portrait image/alpha pair and build one training sample.

        ``self.names[i]`` is split on whitespace. Token 0 is the image file and
        token 1 is the alpha file; both are relative to the portrait data root
        below. The alpha file is read with ``cv.IMREAD_UNCHANGED`` and its
        channel 3 is used as alpha, so it is assumed to be a 4-channel image
        (TODO confirm).

        Returns:
            x: float tensor of shape (4, im_size, im_size). Channels 0-2 hold the
               channel-reversed (RGB) transformed image; channel 3 holds
               trimap / 255.
            y: float32 array of shape (2, im_size, im_size). Channel 0 holds
               alpha / 255.; channel 1 is 1.0 where the trimap equals 128.

        Raises:
            FileNotFoundError: if OpenCV cannot read either file.
        """
        data_root = '/scratch/mw3706/dim/Deep_Image_Matting_Reproduce/pspnet/data/portrait/'
        name = self.names[i].split()
        im_name = name[0]
        alpha_name = name[1]

        im_path = os.path.join(data_root, im_name)
        alpha_path = os.path.join(data_root, alpha_name)
        img = cv.imread(im_path)
        alpha = cv.imread(alpha_path, cv.IMREAD_UNCHANGED)
        # cv.imread returns None (instead of raising) for a missing or
        # unreadable file; fail here with a clear message.
        if img is None:
            raise FileNotFoundError('cannot read image: {}'.format(im_path))
        if alpha is None:
            raise FileNotFoundError('cannot read alpha: {}'.format(alpha_path))
        alpha = alpha[:, :, 3]

        # crop size 320:640:480 = 1:1:1
        different_sizes = [(320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)

        # Choose the crop position on the full-size trimap ...
        trimap = gen_trimap(alpha)
        x, y = random_choice(trimap, crop_size)
        img = safe_crop(img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)

        # ... then regenerate the trimap for the cropped alpha.
        trimap = gen_trimap(alpha)

        # Flip array left to right randomly (prob=1:1)
        if np.random.random_sample() > 0.5:
            img = np.fliplr(img)
            trimap = np.fliplr(trimap)
            alpha = np.fliplr(alpha)

        x = torch.zeros((4, im_size, im_size), dtype=torch.float)
        img = img[..., ::-1]  # RGB
        img = transforms.ToPILImage()(img)
        img = self.transformer(img)
        x[0:3, :, :] = img
        x[3, :, :] = torch.from_numpy(trimap.copy() / 255.)

        y = np.empty((2, im_size, im_size), dtype=np.float32)
        y[0, :, :] = alpha / 255.
        mask = np.equal(trimap, 128).astype(np.float32)
        y[1, :, :] = mask

        return x, y
Ejemplo n.º 7
0
    def __getitem__(self, idx):
        """Assemble batch ``idx`` for a model trained with extra compositional targets.

        Batch ``idx`` covers ``self.names[idx * batch_size:]``, up to
        ``batch_size`` entries. Each name has the form ``'<fg>_<bg>.<ext>'``.

        Returns:
            batch_x: float32 array of shape (length, img_rows, img_cols, 4).
                Channels 0-2 hold image / 255.; channel 3 holds trimap / 255.
            batch_y: float32 array of shape (length, img_rows, img_cols, 13):
                * 0: alpha / 255.
                * 1: 1.0 where trimap == 128
                * 2:5: image / 255.
                * 5:8: fg / 255.
                * 8:11: bg / 255.
                * 11: unused (always 0)
                * 12: trimap / 255.
        """
        i = idx * batch_size

        length = min(batch_size, (len(self.names) - i))
        batch_x = np.empty((length, img_rows, img_cols, 4), dtype=np.float32)
        # np.zeros, not np.empty: channel 11 is never written below and would
        # otherwise carry uninitialised memory into the targets.
        batch_y = np.zeros((length, img_rows, img_cols, 13), dtype=np.float32)

        for i_batch in range(length):
            name = self.names[i + i_batch]
            stem = name.split('.')[0]
            fcount = int(stem.split('_')[0])
            bcount = int(stem.split('_')[1])
            im_name = fg_files[fcount]
            bg_name = bg_files[bcount]
            image, alpha, fg, bg = process(im_name, bg_name)

            # crop size 320:640:480 = 1:1:1
            different_sizes = [(320, 320), (480, 480), (640, 640)]
            crop_size = random.choice(different_sizes)

            # Choose the crop position on the full-size trimap, then crop every
            # layer at the same position.
            trimap = generate_trimap(alpha)
            x, y = random_choice(trimap, crop_size)
            image = safe_crop(image, x, y, crop_size)
            alpha = safe_crop(alpha, x, y, crop_size)
            fg = safe_crop(fg, x, y, crop_size)
            bg = safe_crop(bg, x, y, crop_size)

            # Regenerate the trimap for the cropped alpha.
            trimap = generate_trimap(alpha)

            # Flip array left to right randomly (prob=1:1)
            if np.random.random_sample() > 0.5:
                image = np.fliplr(image)
                trimap = np.fliplr(trimap)
                alpha = np.fliplr(alpha)
                fg = np.fliplr(fg)
                bg = np.fliplr(bg)

            batch_x[i_batch, :, :, 0:3] = image / 255.
            batch_x[i_batch, :, :, 3] = trimap / 255.

            mask = np.equal(trimap, 128).astype(np.float32)
            batch_y[i_batch, :, :, 0] = alpha / 255.
            batch_y[i_batch, :, :, 1] = mask
            batch_y[i_batch, :, :, 2:5] = image / 255.
            batch_y[i_batch, :, :, 5:8] = fg / 255.
            batch_y[i_batch, :, :, 8:11] = bg / 255.
            batch_y[i_batch, :, :, 12] = trimap / 255.

        return batch_x, batch_y
Ejemplo n.º 8
0
    def __getitem__(self, idx):
        """Assemble batch ``idx`` of (image + trimap, alpha + unknown-mask) arrays.

        ``self.names`` is shuffled once at the start of each pass (``idx == 0``).
        Batch ``idx`` then takes up to ``batch_size`` consecutive entries
        starting at ``idx * batch_size``. Each entry is passed to ``process`` as
        both the image name and the background name.

        Returns:
            batch_x: float32 array of shape (length, img_rows, img_cols, 4).
                Channels 0-2 hold image / 255.; channel 3 holds trimap / 255.
            batch_y: float32 array of shape (length, img_rows, img_cols, 2).
                Channel 0 holds alpha / 255.; channel 1 is 1.0 where
                trimap == 128.
        """
        # Shuffle once per epoch. Shuffling on every call lets batches of the
        # same epoch overlap while other samples are never visited.
        if idx == 0:
            np.random.shuffle(self.names)
        i = idx * batch_size

        length = min(batch_size, (len(self.names) - i))
        batch_x = np.empty((length, img_rows, img_cols, 4), dtype=np.float32)
        batch_y = np.empty((length, img_rows, img_cols, 2), dtype=np.float32)

        for i_batch in range(length):
            sample_name = self.names[i + i_batch]
            image, alpha, fg, bg = process(sample_name, sample_name)

            # crop size 320:640:480 = 1:1:1
            different_sizes = [(320, 320), (480, 480), (640, 640)]
            crop_size = random.choice(different_sizes)

            # Choose the crop position on the full-size trimap ...
            trimap = generate_trimap_withmask(alpha)
            x, y = random_choice(trimap, crop_size)
            image = safe_crop(image, x, y, crop_size)
            alpha = safe_crop(alpha, x, y, crop_size)

            # ... then regenerate the trimap for the cropped alpha.
            trimap = generate_trimap_withmask(alpha)

            # Flip array left to right randomly (prob=1:1)
            if np.random.random_sample() > 0.5:
                image = np.fliplr(image)
                trimap = np.fliplr(trimap)
                alpha = np.fliplr(alpha)

            batch_x[i_batch, :, :, 0:3] = image / 255.
            batch_x[i_batch, :, :, 3] = trimap / 255.

            mask = np.equal(trimap, 128).astype(np.float32)
            batch_y[i_batch, :, :, 0] = alpha / 255.
            batch_y[i_batch, :, :, 1] = mask

        print('*********batch done!', idx, length)

        return batch_x, batch_y
Ejemplo n.º 9
0
 def test_flip(self):
     """Crop a sample around a chosen point, mirror it, and write the results to temp/."""
     fg_image = cv.imread('fg/1-1252426161dfXY.jpg')
     mask = cv.imread('mask/1-1252426161dfXY.jpg', 0)
     tri = generate_trimap(mask)
     cx, cy = random_choice(tri)

     # Crop all three layers first, then mirror each crop left-to-right.
     crops = [safe_crop(layer, cx, cy) for layer in (fg_image, tri, mask)]
     mirrored = [np.fliplr(c) for c in crops]

     out_paths = (
         'temp/test_flip_image.png',
         'temp/test_flip_trimap.png',
         'temp/test_flip_alpha.png',
     )
     for path, arr in zip(out_paths, mirrored):
         cv.imwrite(path, arr)
Ejemplo n.º 10
0
    def __getitem__(self, i):
        """Build one training sample from the pair (``self.fgs[i]``, ``self.bgs[i]``).

        If ``args.data_augumentation`` is set, ``self._composite_fg`` may replace
        the image and alpha *before* the trimap is built. The crop position is
        therefore chosen on a trimap that matches the alpha actually used.

        After the last index, ``self.fgs`` and ``self.bgs`` are reshuffled
        independently in place.

        Returns:
            x: float tensor of shape (4, im_size, im_size). Channels 0-2 hold the
               channel-reversed (RGB) transformed image; channel 3 holds
               trimap / 255.
            y: float32 array of shape (2, im_size, im_size). Channel 0 holds
               alpha / 255.; channel 1 is 1.0 where the trimap equals 128.
        """
        fcount = self.fgs[i]
        bcount = self.bgs[i]
        img, alpha, fg, bg = process(fcount, bcount)

        # crop size 320:640:480 = 1:1:1
        different_sizes = [(320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)

        if args.data_augumentation:
            img, alpha = self._composite_fg(img, alpha, fg, bg, i)

        # Build the trimap from the (possibly augmented) alpha so random_choice
        # picks a crop position that matches it.
        trimap = gen_trimap(alpha)

        x, y = random_choice(trimap, crop_size)
        img = safe_crop(img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)

        # Regenerate the trimap for the cropped alpha.
        trimap = gen_trimap(alpha)

        # Flip array left to right randomly (prob=1:1)
        if np.random.random_sample() > 0.5:
            img = np.fliplr(img)
            trimap = np.fliplr(trimap)
            alpha = np.fliplr(alpha)

        x = torch.zeros((4, im_size, im_size), dtype=torch.float)
        img = img[..., ::-1]  # RGB
        img = transforms.ToPILImage()(img)
        img = self.transformer(img)
        x[0:3, :, :] = img
        x[3, :, :] = torch.from_numpy(trimap.copy() / 255.)

        y = np.empty((2, im_size, im_size), dtype=np.float32)
        y[0, :, :] = alpha / 255.
        mask = np.equal(trimap, 128).astype(np.float32)
        y[1, :, :] = mask

        # After the last sample, reshuffle foregrounds and backgrounds
        # independently so the next pass uses new pairings.
        if i >= self.__len__() - 1:
            random.shuffle(self.fgs)
            random.shuffle(self.bgs)

        return x, y
Ejemplo n.º 11
0
    def __getitem__(self, idx):
        """Assemble batch ``idx`` of (image + trimap, alpha + unknown-mask) arrays.

        Batch ``idx`` covers ``self.names[idx * batch_size:]``, up to
        ``batch_size`` entries. Each name has the form ``'<fg>_<bg>.<ext>'`` and
        selects entries of ``fg_files`` and ``bg_files``.

        Returns:
            batch_x: float32 array of shape (length, img_rows, img_cols, 4).
                Channels 0-2 hold image / 255.; channel 3 holds trimap / 255.
            batch_y: float32 array of shape (length, img_rows, img_cols, 2).
                Channel 0 holds alpha / 255.; channel 1 is 1.0 where
                trimap == 128.
        """
        i = idx * batch_size

        length = min(batch_size, (len(self.names) - i))
        batch_x = np.empty((length, img_rows, img_cols, 4), dtype=np.float32)
        batch_y = np.empty((length, img_rows, img_cols, 2), dtype=np.float32)

        for i_batch in range(length):
            name = self.names[i + i_batch]
            stem = name.split('.')[0]
            fcount = int(stem.split('_')[0])
            bcount = int(stem.split('_')[1])
            im_name = fg_files[fcount]
            bg_name = bg_files[bcount]
            # fg / bg are not part of this batch's outputs, so they are not cropped.
            image, alpha, _, _ = process(im_name, bg_name)

            # crop size 320:640:480 = 1:1:1
            different_sizes = [(320, 320), (480, 480), (640, 640)]
            crop_size = random.choice(different_sizes)

            # Choose the crop position on the full-size trimap ...
            trimap = generate_trimap(alpha)
            x, y = random_choice(trimap, crop_size)
            image = safe_crop(image, x, y, crop_size)
            alpha = safe_crop(alpha, x, y, crop_size)

            # ... then regenerate the trimap for the cropped alpha.
            trimap = generate_trimap(alpha)

            # Flip array left to right randomly (prob=1:1)
            if np.random.random_sample() > 0.5:
                image = np.fliplr(image)
                trimap = np.fliplr(trimap)
                alpha = np.fliplr(alpha)

            batch_x[i_batch, :, :, 0:3] = image / 255.
            batch_x[i_batch, :, :, 3] = trimap / 255.

            mask = np.equal(trimap, 128).astype(np.float32)
            batch_y[i_batch, :, :, 0] = alpha / 255.
            batch_y[i_batch, :, :, 1] = mask

        return batch_x, batch_y
Ejemplo n.º 12
0
 def test_resize(self):
     """Crop merged/0_0.png and its alpha at a fixed 480x480 size and dump the results.

     The alpha returned by ``get_alpha`` is pasted into the top-left corner of a
     zero float32 canvas matching the merged image's height and width. All
     layers are then cropped at a position chosen by ``random_choice``.
     """
     sample = '0_0.png'
     merged = cv.imread(os.path.join('merged', sample))
     height, width = merged.shape[:2]

     raw_alpha = get_alpha(sample)
     alpha_h, alpha_w = raw_alpha.shape[:2]
     canvas = np.zeros((height, width), np.float32)
     canvas[:alpha_h, :alpha_w] = raw_alpha

     tri = generate_trimap(canvas)
     # Crop size is fixed at 480x480 in this test (no random size mix).
     size = (480, 480)
     cx, cy = random_choice(tri, size)
     merged_crop = safe_crop(merged, cx, cy, size)
     tri_crop = safe_crop(tri, cx, cy, size)
     alpha_crop = safe_crop(canvas, cx, cy, size)

     cv.imwrite('temp/test_resize_image.png', merged_crop)
     cv.imwrite('temp/test_resize_trimap.png', tri_crop)
     cv.imwrite('temp/test_resize_alpha.png', alpha_crop)
Ejemplo n.º 13
0
    def __getitem__(self, i):
        """Build one training sample for ``self.names[i]`` (form ``'<fg>_<bg>.<ext>'``).

        Returns:
            x: float tensor of shape (4, im_size, im_size). Channels 0-2 hold the
               transformed image; channel 3 holds trimap / 255.
            y: float32 array of shape (2, im_size, im_size). Channel 0 holds
               alpha / 255.; channel 1 is 1.0 where the trimap equals 128.
        """
        name = self.names[i]
        stem = name.split('.')[0]
        fcount = int(stem.split('_')[0])
        bcount = int(stem.split('_')[1])
        im_name = fg_files[fcount]
        bg_name = bg_files[bcount]
        img, alpha, fg, bg = process(im_name, bg_name)

        # crop size 320:640:480 = 1:1:1
        different_sizes = [(320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)

        # Choose the crop position on the full-size trimap ...
        trimap = generate_trimap(alpha)
        x, y = random_choice(trimap, crop_size)
        img = safe_crop(img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)

        # ... then regenerate the trimap for the cropped alpha.
        trimap = generate_trimap(alpha)

        # Flip array left to right randomly (prob=1:1)
        if np.random.random_sample() > 0.5:
            img = np.fliplr(img)
            trimap = np.fliplr(trimap)
            alpha = np.fliplr(alpha)

        x = torch.zeros((4, im_size, im_size), dtype=torch.float)
        img = transforms.ToPILImage()(img)
        img = self.transformer(img)
        x[0:3, :, :] = img
        # Divide in NumPy (float) before converting. Dividing an integer torch
        # tensor by 255. truncates to 0/1 on older PyTorch releases.
        x[3, :, :] = torch.from_numpy(trimap.copy() / 255.)

        y = np.empty((2, im_size, im_size), dtype=np.float32)
        y[0, :, :] = alpha / 255.
        mask = np.equal(trimap, 128).astype(np.float32)
        y[1, :, :] = mask

        return x, y
Ejemplo n.º 14
0
        bg_h, bg_w = bgr_img.shape[:2]
        print('bg_h, bg_w: ' + str((bg_h, bg_w)))

        # Load the ground-truth alpha for this test image.
        a = get_alpha_test(image_name)
        a_h, a_w = a.shape[:2]
        print('a_h, a_w: ' + str((a_h, a_w)))

        # Paste the alpha into the top-left corner of a zero float32 canvas the
        # size of the composite, so both can be cropped with the same offsets.
        alpha = np.zeros((bg_h, bg_w), np.float32)
        alpha[0:a_h, 0:a_w] = a
        trimap = generate_trimap(alpha)
        # (320, 320) is listed three times, so random.choice picks it with
        # probability 3/5; 480 and 640 each get 1/5.
        different_sizes = [(320, 320), (320, 320), (320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)
        x, y = random_choice(trimap, crop_size)
        print('x, y: ' + str((x, y)))

        # Crop image, alpha and trimap at the same position and save them as
        # uint8 PNGs under images/ for inspection.
        bgr_img = safe_crop(bgr_img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)
        trimap = safe_crop(trimap, x, y, crop_size)
        cv.imwrite('images/{}_image.png'.format(i), np.array(bgr_img).astype(np.uint8))
        cv.imwrite('images/{}_trimap.png'.format(i), np.array(trimap).astype(np.uint8))
        cv.imwrite('images/{}_alpha.png'.format(i), np.array(alpha).astype(np.uint8))

        # Network input: bgr_img / 255. in channels 0-2, trimap / 255. in
        # channel 3. NOTE(review): assumes safe_crop returns arrays of shape
        # (img_rows, img_cols) whatever crop_size was drawn -- confirm.
        x_test = np.empty((1, img_rows, img_cols, 4), dtype=np.float32)
        x_test[0, :, :, 0:3] = bgr_img / 255.
        x_test[0, :, :, 3] = trimap / 255.

        # Ground truth: channel 0 is alpha / 255.; channel 1 is the scaled
        # trimap (trimap / 255.), not an unknown-region mask.
        y_true = np.empty((1, img_rows, img_cols, 2), dtype=np.float32)
        y_true[0, :, :, 0] = alpha / 255.
        y_true[0, :, :, 1] = trimap / 255.

        y_pred = final.predict(x_test)
Ejemplo n.º 15
0
        print('Start processing image: {}'.format(filename))
        # NOTE(review): this x_test is reallocated further below before it is
        # used, so this allocation has no effect.
        x_test = np.empty((1, img_rows, img_cols, 4), dtype=np.float32)
        bgr_img = cv.imread(os.path.join(out_test_path, filename))
        bg_h, bg_w = bgr_img.shape[:2]
        print(bg_h, bg_w)
        # Paste the ground-truth alpha into the top-left corner of a zero
        # float32 canvas the size of the composite image.
        a = get_alpha_test(image_name)
        a_h, a_w = a.shape[:2]
        print(a_h, a_w)
        alpha = np.zeros((bg_h, bg_w), np.float32)
        alpha[0:a_h, 0:a_w] = a
        trimap = generate_trimap(alpha)
        # (320, 320) appears three times: picked with probability 3/5.
        different_sizes = [(320, 320), (320, 320), (320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)
        x, y = random_choice(trimap, crop_size)
        print(x, y)
        # Crop all layers at the same position and save uint8 PNGs for inspection.
        bgr_img = safe_crop(bgr_img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)
        trimap = safe_crop(trimap, x, y, crop_size)
        cv.imwrite('images/{}_image.png'.format(i), np.array(bgr_img).astype(np.uint8))
        cv.imwrite('images/{}_trimap.png'.format(i), np.array(trimap).astype(np.uint8))
        cv.imwrite('images/{}_alpha.png'.format(i), np.array(alpha).astype(np.uint8))

        # NOTE(review): 320 is hard-coded here; presumably it must equal
        # img_rows/img_cols and the output size of safe_crop -- confirm.
        x_test = np.empty((1, 320, 320, 4), dtype=np.float32)
        x_test[0, :, :, 0:3] = bgr_img / 255.
        x_test[0, :, :, 3] = trimap / 255.

        # Ground truth: alpha / 255. and the scaled trimap (trimap / 255.).
        y_true = np.empty((1, 320, 320, 2), dtype=np.float32)
        y_true[0, :, :, 0] = alpha / 255.
        y_true[0, :, :, 1] = trimap / 255.

        y_pred = final.predict(x_test)
Ejemplo n.º 16
0
    def __getitem__(self, idx):
        """Return batch ``idx`` of (image, trimap-target) arrays.

        ``idx`` is the batch number. Samples start at ``idx * batch_size`` in
        ``self.names`` and are processed in order. For each image ``xx.jpg``:

        1. Load it from ``rgb_image_path`` and its mask ``xx.png`` from
           ``mask_img_path``.
        2. Randomly rescale image and mask together (scale range noted as
           0.5-2.0).
        3. Generate a random trimap from the mask (values 0/128/255).
        4. Take a 512x512 crop at a position chosen by ``random_choice``, so
           that the crop contains unknown pixels. ``safe_crop`` is noted to
           resize the crop to (img_rows, img_cols).
        5. Flip horizontally with probability 1/2.

        Returns:
            batch_x: float32 array of shape (length, img_rows, img_cols, 3),
                holding image / 255.
            batch_y: uint8 array of shape (length, img_rows, img_cols,
                num_classes), produced by ``make_trimap_for_batch_y``.
        """
        start = idx * batch_size
        # The last batch may be shorter than batch_size.
        length = min(batch_size, (len(self.names) - start))
        batch_x = np.empty((length, img_rows, img_cols, 3), dtype=np.float32)
        batch_y = np.empty((length, img_rows, img_cols, num_classes),
                           dtype=np.uint8)

        crop_size = (512, 512)
        for offset in range(length):
            img_name = self.names[start + offset]  # e.g. xx.jpg
            stem = os.path.splitext(img_name)[0]

            image = cv2.imread(os.path.join(rgb_image_path, img_name), 1)
            mask = cv2.imread(os.path.join(mask_img_path, stem + '.png'), 0)

            image, mask = random_rescale_image_and_mask(image, mask)
            trimap = generate_random_trimap(mask)

            x, y = random_choice(trimap, crop_size)
            image = safe_crop(image, x, y, crop_size)
            trimap = safe_crop(trimap, x, y, crop_size)

            # Horizontal flip with probability 1/2.
            if np.random.random_sample() > 0.5:
                image, trimap = np.fliplr(image), np.fliplr(trimap)

            batch_x[offset] = image / 255.0
            batch_y[offset] = make_trimap_for_batch_y(trimap)

        return batch_x, batch_y