def __getitem__(self, item):
        """Return one sketch sample for the dataset's current mode.

        Looks up the ``item``-th sketch path of the active split, rasterizes
        the coordinates stored for it and applies that split's transform.
        In ``'Train'`` mode the image is additionally mirrored horizontally
        whenever a uniform random draw exceeds 0.5, and the first column of
        the points is mirrored with it.

        Args:
            item: Index into ``self.Train_Sketch`` (``'Train'``) or
                ``self.Test_Sketch`` (``'Test'``).

        Returns:
            dict with the keys ``'sketch_img'`` (transformed image),
            ``'sketch_path'``, ``'sketch_points'`` and ``'sketch_label'``
            (``self.name2num`` of the first ``/``-separated path component).

        Raises:
            ValueError: If ``self.mode`` is neither ``'Train'`` nor ``'Test'``.
        """
        is_train = self.mode == 'Train'
        if is_train:
            sketch_path = self.Train_Sketch[item]
        elif self.mode == 'Test':
            sketch_path = self.Test_Sketch[item]
        else:
            # Previously this fell through to ``return sample`` with the
            # variable unbound and surfaced as an UnboundLocalError.
            raise ValueError(
                "unsupported mode %r; expected 'Train' or 'Test'" %
                (self.mode, ))

        vector_x = self.Coordinate[sketch_path]
        sketch_img, sketch_points = rasterize_Sketch(vector_x)
        sketch_img = Image.fromarray(sketch_img).convert('RGB')

        if is_train:
            # Augmentation: mirror the image and the points together so
            # they stay aligned.
            if random.random() > 0.5:
                sketch_img = F.hflip(sketch_img)
                # x -> 256 - x; presumably 256 is the raster width -- confirm.
                # NOTE(review): this writes in place; if rasterize_Sketch
                # returns a view of the cached coordinates, the cache is
                # modified too -- verify against rasterize_Sketch.
                sketch_points[:, 0] = -sketch_points[:, 0] + 256.
            sketch_img = self.train_transform(sketch_img)
        else:
            sketch_img = self.test_transform(sketch_img)

        return {
            'sketch_img': sketch_img,
            'sketch_path': sketch_path,
            'sketch_points': sketch_points,
            'sketch_label': self.name2num[sketch_path.split('/')[0]]
        }
# Example #2 (0)
    def __getitem__(self, item):
        """Assemble an anchor/positive/negative sketch triplet for ``item``.

        The anchor is the ``item``-th path of ``self.Train_Sketch`` when
        ``self.mode == 'Train'`` and of ``self.Test_Sketch`` otherwise. The
        positive is drawn at random from ``self.seen_dict`` under the
        anchor's class (the first ``/``-separated path component, so it may
        be the anchor itself). The negative is drawn from a randomly chosen
        *other* class of ``self.Seen_Class``.

        Args:
            item: Index into the active sketch list.

        Returns:
            dict holding ``'path'``, ``'label'`` and, for each of the three
            sketches, the transformed raster image, the coordinate vector,
            the stroke count, the per-stroke lengths and the per-stroke
            tensors.

        Raises:
            ValueError: If ``self.hp.data_encoding_type`` is not ``'3point'``
                or ``'5point'``, or if ``self.Seen_Class`` holds no class
                other than the anchor's.
        """
        # The two modes differ only in which path list is indexed.
        sketch_paths = (self.Train_Sketch
                        if self.mode == 'Train' else self.Test_Sketch)
        path = sketch_paths[item]
        anchor_sketch_vector = self.Coordinate[path]
        anchor_sketch = Image.fromarray(
            rasterize_Sketch(anchor_sketch_vector)).convert('RGB')

        class_name = path.split('/')[0]
        sketch_positive_path = random.choice(self.seen_dict[class_name])
        sketch_positive_vector = self.Coordinate[sketch_positive_path]
        sketch_positive = Image.fromarray(
            rasterize_Sketch(sketch_positive_vector)).convert('RGB')

        # ``set(class_name)`` would be the set of the name's *characters*,
        # which never removes the class itself; subtract a one-element set.
        negative_classes = list(set(self.Seen_Class) - {class_name})
        if not negative_classes:
            raise ValueError(
                'cannot sample a negative: no seen class other than %r' %
                (class_name, ))
        possible_negative_class = random.choice(negative_classes)
        sketch_negative_path = random.choice(
            self.seen_dict[possible_negative_class])
        sketch_negative_vector = self.Coordinate[sketch_negative_path]
        sketch_negative = Image.fromarray(
            rasterize_Sketch(sketch_negative_vector)).convert('RGB')

        (anchor_sketch_vector, stroke_wise_split_anchor,
         every_stroke_len_anchor
         ) = self._split_into_strokes(anchor_sketch_vector)
        (sketch_positive_vector, stroke_wise_split_positive,
         every_stroke_len_positive
         ) = self._split_into_strokes(sketch_positive_vector)
        (sketch_negative_vector, stroke_wise_split_negative,
         every_stroke_len_negative
         ) = self._split_into_strokes(sketch_negative_vector)

        # NOTE(review): ``train_transform`` is applied in every mode, exactly
        # as before -- confirm that this is intended for the test split.
        sample = {
            'path': path,
            'label': self.name2num[class_name],
            'anchor_sketch_image': self.train_transform(anchor_sketch),
            'anchor_sketch_vector': anchor_sketch_vector,
            'num_stroke_per_anchor': len(every_stroke_len_anchor),
            'every_stroke_len_anchor': every_stroke_len_anchor,
            'stroke_wise_split_anchor': stroke_wise_split_anchor,
            'sketch_positive': self.train_transform(sketch_positive),
            'sketch_positive_vector': sketch_positive_vector,
            'num_stroke_per_positive': len(every_stroke_len_positive),
            'every_stroke_len_positive': every_stroke_len_positive,
            'stroke_wise_split_positive': stroke_wise_split_positive,
            'sketch_negative': self.train_transform(sketch_negative),
            'sketch_negative_vector': sketch_negative_vector,
            'num_stroke_per_negative': len(every_stroke_len_negative),
            'every_stroke_len_negative': every_stroke_len_negative,
            'stroke_wise_split_negative': stroke_wise_split_negative
        }

        return sample

    def _split_into_strokes(self, sketch_vector):
        """Split one sketch vector into per-stroke tensors.

        Args:
            sketch_vector: Coordinate array of a single sketch.

        Returns:
            ``(vector, strokes, stroke_lengths)``. ``vector`` is the input
            itself for ``'3point'`` and ``self.to_delXY(input)`` for
            ``'5point'``; ``strokes`` is a list of tensors and
            ``stroke_lengths`` the row count of each.

        Raises:
            ValueError: If ``self.hp.data_encoding_type`` is not ``'3point'``
                or ``'5point'``.
        """
        encoding = self.hp.data_encoding_type
        if encoding == '3point':
            flag_column, drop_trailing_chunk = 2, True
        elif encoding == '5point':
            sketch_vector = self.to_delXY(sketch_vector)
            flag_column, drop_trailing_chunk = 3, False
        else:
            raise ValueError(
                'invalid option for --data_encoding_type. Valid options: 3point/5point'
            )

        # Cut after every row whose flag column is non-zero (presumably the
        # end-of-stroke marker -- confirm against the data format).
        chunks = np.split(sketch_vector,
                          np.where(sketch_vector[:, flag_column])[0] + 1,
                          axis=0)
        if drop_trailing_chunk:
            # '3point' discards whatever follows the last flagged row.
            chunks = chunks[:-1]

        strokes = [torch.from_numpy(chunk) for chunk in chunks]
        stroke_lengths = [len(stroke) for stroke in strokes]
        # Sanity check: the strokes must account for every row of the vector.
        assert sum(stroke_lengths) == sketch_vector.shape[0]
        return sketch_vector, strokes, stroke_lengths
            # Progress trace: current epsilon, sample counter and sketch key.
            print(i_rdp, num, key)
            sketch_points = Coordinate[key]
            sketch_points_orig = sketch_points

            # ``np.float`` was only an alias of the builtin ``float`` and was
            # removed in NumPy 1.24; the builtin yields the same dtype.
            sketch_points = sketch_points.astype(float)
            # Disabled rescaling, kept for reference:
            # sketch_points[:, :2] = sketch_points[:, :2] / np.array([800, 800])
            # sketch_points[:, :2] = sketch_points[:, :2] * 256
            sketch_points = np.round(sketch_points)

            # Cut after every row whose third column is non-zero and drop the
            # chunk that follows the last such row.
            all_strokes = np.split(sketch_points,
                                   np.where(sketch_points[:, 2])[0] + 1,
                                   axis=0)[:-1]

            max_points_old.append(sketch_points_orig.shape)

            # Save a raster of the untouched sketch under "<i_rdp>/<num>.jpg".
            sketch_img_orig = rasterize_Sketch(sketch_points_orig)
            sketch_img_orig = Image.fromarray(sketch_img_orig).convert('RGB')
            # sketch_img_orig.show()
            sketch_img_orig.save(str(i_rdp) + '/' + str(num) + '.jpg')

            sketch_points_sampled_new = []
            for stroke in all_strokes:
                # Simplify the x/y columns of the stroke with tolerance i_rdp.
                stroke_new = rdp(stroke[:, :2], epsilon=i_rdp, algo="iter")
                # Re-attach a flag column: zeros, with 1 on the last point.
                stroke_new = np.hstack(
                    (stroke_new, np.zeros((stroke_new.shape[0], 1))))
                stroke_new[-1, -1] = 1.
                # print(stroke_new.shape, stroke.shape)
                sketch_points_sampled_new.append(stroke_new)
            sketch_points_new = np.vstack(sketch_points_sampled_new)

            max_points_new.append(sketch_points_new.shape[0])
# Example #4 (0)
def show(step, batch, model, save_path):
    """Save a picture pairing each anchor stroke with its best positive stroke.

    Runs ``model.Network`` on the anchor and positive branches of ``batch``,
    takes the correlation returned by ``model.neighbour`` for sample 0 and,
    for every anchor stroke, selects the positive stroke with the highest
    value. Each column of the output shows the anchor sketch with that
    stroke in red (top) and the positive sketch with the matched stroke in
    blue (bottom).

    Args:
        step: Value interpolated into the output file name.
        batch: Mapping providing ``'anchor_sketch_vector'``,
            ``'sketch_positive_vector'``, ``'num_stroke_per_*'`` and
            ``'stroke_wise_split_*'`` for anchor and positive.
        model: Object exposing ``Network(batch, type=...)`` and
            ``neighbour(...)``.
        save_path: Directory under which ``<time_id>/Step_<step>.png`` is
            written; ``time_id`` is read from module scope.
    """
    out_dir = save_path + '/' + time_id
    # exist_ok removes the check-then-create race of isdir() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    output_anc, num_stroke_anc = model.Network(batch,
                                               type='anchor')  # b,N1,512
    output_pos, num_stroke_pos = model.Network(batch,
                                               type='positive')  # b,N2,512
    mask_anc, mask_pos = map(make_mask, [num_stroke_anc, num_stroke_pos])
    corr_xpos = model.neighbour(output_anc, output_pos, mask_anc, mask_pos)

    # From the N1 x N2 matrix of one sample: for every anchor stroke, the
    # positive stroke with the highest correlation is the one painted.
    i_sample = 0  # range = 0 to batch-size
    anc_max = torch.argmax(corr_xpos[i_sample], dim=1)

    anc_vec = batch['anchor_sketch_vector']
    anc_img_orig = Image.fromarray(
        255 - rasterize_Sketch(anc_vec[i_sample].numpy())).convert('RGB')
    anc_stroke_num = batch['num_stroke_per_anchor'][i_sample]
    # Strokes of all samples are stored back to back; skip the earlier ones.
    start_index = sum(batch['num_stroke_per_anchor'][:i_sample])
    anc_sample = batch['stroke_wise_split_anchor'][start_index:start_index +
                                                   anc_stroke_num].numpy()

    pos_vec = batch['sketch_positive_vector']
    pos_img_orig = Image.fromarray(
        255 - rasterize_Sketch(pos_vec[i_sample].numpy())).convert('RGB')
    pos_stroke_num = batch['num_stroke_per_positive'][i_sample]
    start_index = sum(batch['num_stroke_per_positive'][:i_sample])
    pos_sample = batch['stroke_wise_split_positive'][start_index:start_index +
                                                     pos_stroke_num].numpy()

    im = Image.new('RGB',
                   (266 * anc_stroke_num, 5 + 261 * 2))  # width x height
    for i_stroke in range(anc_stroke_num):
        anc_img = _overlay_stroke(anc_img_orig, anc_sample[i_stroke], 'red')
        pos_img = _overlay_stroke(pos_img_orig,
                                  pos_sample[anc_max[i_stroke].item()],
                                  'blue')

        im.paste(anc_img, (5 + i_stroke * 266, 5))  # width x height
        im.paste(pos_img, (5 + i_stroke * 266, 266))

    im.save(f'{out_dir}/Step_{step}.png')


def _overlay_stroke(base_img, stroke, colour):
    """Return a copy of ``base_img`` with the pixels of ``stroke`` recoloured."""
    canvas = base_img.copy()
    painter = ImageDraw.Draw(canvas)
    # Pixels that are 0 in the inverted raster are the ones the stroke covers.
    hits = np.where(255 - rasterize_Sketch(stroke) == 0)
    # np.where yields (rows, cols); PIL expects (x, y) = (col, row).
    for x, y in zip(hits[1], hits[0]):
        painter.point((x, y), fill=colour)
    return canvas
    def __getitem__(self, item):
        """Return the sample at index ``item`` for the dataset's current mode.

        * ``"Train"``: a sketch, a photo of the same class and a photo of a
          different class, all flipped together on a random draw above 0.5
          and passed through ``self.train_transform``.
        * ``"Test"``: a sketch and a photo of the same class, passed through
          ``self.test_transform``.
        * ``"Test_photo"``: a single photo with its label and path.

        For ``"Train"``/``"Test"`` with ``self.data != "coordinate"``, and
        for any other mode, an empty dict is returned.

        Args:
            item: Index into ``self.Train_List``, ``self.Test_List`` or
                ``self.test_photo_all``, depending on the mode.

        Returns:
            dict describing the sample (possibly empty, see above).
        """

        def pick(pool):
            # One random element; the index is drawn as randint(0, len - 1).
            return pool[randint(0, len(pool) - 1)]

        mode = self.mode

        if mode == "Test_photo":
            photo_path = self.test_photo_all[item]
            photo = Image.open(photo_path).convert("RGB")
            photo_tensor = self.test_transform(photo)
            # The class is the name of the directory holding the photo.
            photo_class = photo_path.split('/')[-2]
            photo_label = self.name2num[photo_class]
            return {
                "positive_img": photo_tensor,
                "positive_label": photo_label,
                "path": photo_path
            }

        if mode == "Train":
            sketch_key = self.Train_List[item]
            if self.data != "coordinate":
                return {}

            raster, sketch_points = rasterize_Sketch(
                self.Coordinate[sketch_key])
            sketch_image = Image.fromarray(raster).convert("RGB")
            anchor_class = sketch_key.split("/")[0]

            # Positive: any photo listed under the sketch's own class.
            positive_path = pick(self.train_images_classes[anchor_class])

            # Negative: a photo of a randomly chosen different class.
            other_classes = [
                name for name in self.train_images_classes.keys()
                if name != anchor_class
            ]
            negative_class = pick(other_classes)
            negative_path = pick(self.train_images_classes[negative_class])

            positive_img = Image.open(positive_path).convert("RGB")
            negative_img = Image.open(negative_path).convert("RGB")

            # Flip all three images together so the triplet stays consistent.
            # NOTE(review): sketch_points is returned unmirrored even when
            # the sketch image is flipped -- confirm that this is intended.
            if random.random() > 0.5:
                sketch_image = F.hflip(sketch_image)
                positive_img = F.hflip(positive_img)
                negative_img = F.hflip(negative_img)

            sketch_image = self.train_transform(sketch_image)
            positive_img = self.train_transform(positive_img)
            negative_img = self.train_transform(negative_img)

            return {
                "sketch_img": sketch_image,
                "sketch_label": self.name2num[anchor_class],
                "sketch_points": sketch_points,
                "positive_img": positive_img,
                "positive_label": self.name2num[anchor_class],
                "negative_img": negative_img,
                "negative_label": self.name2num[negative_class],
            }

        if mode == "Test":
            sketch_key = self.Test_List[item]
            if self.data != "coordinate":
                return {}

            raster, sketch_points = rasterize_Sketch(
                self.Coordinate[sketch_key])
            sketch_image = Image.fromarray(raster).convert("RGB")
            anchor_class = sketch_key.split("/")[0]

            positive_path = pick(self.test_images_classes[anchor_class])
            positive_img = Image.open(positive_path).convert("RGB")

            sketch_image = self.test_transform(sketch_image)
            positive_image = self.test_transform(positive_img)

            return {
                "sketch_img": sketch_image,
                "sketch_label": self.name2num[anchor_class],
                "sketch_points": sketch_points,
                "positive_img": positive_image,
            }

        # Any other mode yields an empty sample.
        return {}