示例#1
0
def test_load_drawings():
    """Check that ``loaded_drawings`` tracks both explicit and on-demand loads.

    Two groups are loaded explicitly; then a drawing is requested from a
    third group ("angel") that was never passed to ``load_drawings``. The
    assertions expect that group to be appended to ``loaded_drawings``,
    keeping load order.
    """
    quickdraw = QuickDrawData()

    quickdraw.load_drawings(["anvil", "ant"])
    expected = ["anvil", "ant"]
    assert quickdraw.loaded_drawings == expected

    # Requesting a drawing from an unloaded group is expected to load it too.
    quickdraw.get_drawing("angel")
    expected = expected + ["angel"]
    assert quickdraw.loaded_drawings == expected
示例#2
0
class QuickDrawDataset(data.Dataset):
    """Map-style dataset yielding Quick, Draw! images, one class per index.

    Item ``idx`` comes from class ``classes[idx % len(classes)]``, so
    consecutive indices cycle through the classes in order. The label is the
    class's position in ``classes``.

    Args:
        root: Directory passed to ``QuickDrawData`` as ``cache_dir``.
        classes: Non-empty sequence of Quick, Draw! category names to load.
        transform: Optional callable applied to each image; skipped when falsy.
        max_drawings: Forwarded to ``QuickDrawData`` as ``max_drawings``
            (default 10000, the value previously hard-coded).
        length: Value reported by ``len()``, i.e. samples per epoch
            (default 10000, the value previously hard-coded).

    Raises:
        ValueError: if ``classes`` is empty.
    """

    def __init__(self, root, classes, transform, max_drawings=10000, length=10000):
        if len(classes) == 0:
            # Without this guard the failure would only surface later, as a
            # ZeroDivisionError from ``idx % len(self.classes)``.
            raise ValueError("classes must contain at least one category name")
        self.classes = classes
        # labels[i] is the integer label for classes[i].
        self.labels = torch.arange(len(classes))
        self.transform = transform
        self.length = length
        # NOTE(review): recognized=True presumably keeps only drawings the game
        # recognized, and max_drawings presumably caps drawings loaded per
        # category — confirm against the quickdraw package docs.
        self.qdd = QuickDrawData(
            recognized=True, max_drawings=max_drawings, cache_dir=root
        )
        self.qdd.load_drawings(classes)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the class selected by ``idx``."""
        class_idx = idx % len(self.classes)
        name = self.classes[class_idx]
        label = self.labels[class_idx]
        # NOTE(review): no drawing index is passed, so this presumably returns a
        # random drawing of ``name``; the same idx may then give different
        # images across calls — confirm if reproducibility matters.
        img = self.qdd.get_drawing(name).image
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the configured number of samples per epoch."""
        return self.length
示例#3
0
 def get_drawing(self, recognized=True):
     """Pick a random Quick, Draw! drawing and render it stroke by stroke.

     Args:
         recognized: Required value of the drawing's ``recognized`` flag
             (expected to be a bool). Random categories are tried until a
             drawing with a matching flag is found.

     Returns:
         A ``(frames, drawing_data)`` tuple. ``frames`` holds one RGB image
         per stroke: frame ``i`` shows strokes ``0..i`` drawn in black on a
         white canvas sized like ``drawing_data.get_image()``.
         ``drawing_data`` is the object returned by
         ``QuickDrawData.get_drawing``.
     """
     dirname = os.path.dirname(__file__)
     cache_dir = os.path.join(dirname, '.quickdrawcache')
     # Forward the filter to the loader (same keyword the dataset code uses)
     # so a matching drawing is selected at load time.
     qd = QuickDrawData(recognized=recognized, max_drawings=1, cache_dir=cache_dir)
     # Safety net: keep drawing random categories until the flag matches.
     while True:
         group_name = random.choice(qd.drawing_names)
         qd.load_drawings([group_name])
         drawing_data = qd.get_drawing(group_name)
         if drawing_data.recognized == recognized:
             break

     drawing_size = drawing_data.get_image().size
     canvas = Image.new('RGB', drawing_size, 'white')
     pen = ImageDraw.Draw(canvas)
     frames = []
     for stroke in drawing_data.strokes:
         pen.line(stroke, 'black')
         # Snapshot the canvas so later strokes do not alter this frame.
         frames.append(canvas.copy())
     return frames, drawing_data