def test_load_drawings():
    """Check that ``loaded_drawings`` tracks explicit and on-demand loads.

    The names passed to ``load_drawings`` must be reported in order, and a
    later ``get_drawing`` call for a name that was not loaded up front must
    append that name to the list.
    """
    preloaded = ["anvil", "ant"]

    qd = QuickDrawData()
    # Hand over a copy so the expectation list stays independent of
    # whatever the library keeps internally.
    qd.load_drawings(list(preloaded))
    assert qd.loaded_drawings == preloaded

    # Requesting a group that was not loaded yet should register it too.
    qd.get_drawing("angel")
    assert qd.loaded_drawings == preloaded + ["angel"]
class QuickDrawDataset(data.Dataset):
    """Dataset serving Quick, Draw! images for a fixed list of classes.

    Item ``idx`` is assigned the class ``classes[idx % len(classes)]``, so
    consecutive indices cycle through the classes.  The index itself is not
    forwarded to ``QuickDrawData.get_drawing``; which drawing of the class
    is returned is therefore decided by ``QuickDrawData``, not by ``idx``.

    Args:
        root: Directory handed to ``QuickDrawData`` as ``cache_dir``.
        classes: Sequence of drawing names; position in the sequence is
            the integer label.
        transform: Optional callable applied to each image; pass ``None``
            to get the image unchanged.
        max_drawings: Forwarded to ``QuickDrawData`` as ``max_drawings`` and
            also used as the reported dataset length.  Defaults to 10000,
            the value that used to be hard-coded in both places.
    """

    def __init__(self, root, classes, transform, max_drawings=10000):
        super().__init__()
        self.classes = classes
        self.labels = torch.arange(len(classes))
        self.transform = transform
        self.max_drawings = max_drawings
        self.qdd = QuickDrawData(
            recognized=True,
            max_drawings=max_drawings,
            cache_dir=root,
        )
        self.qdd.load_drawings(classes)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the class that ``idx`` maps to.

        Raises:
            IndexError: If ``idx`` is at or beyond ``len(self)``.
        """
        # Without this bound, plain ``for item in dataset`` iteration (which
        # relies on IndexError to stop) would never terminate.
        if idx >= len(self):
            raise IndexError(
                f"index {idx} is out of range for dataset of length {len(self)}"
            )
        slot = idx % len(self.classes)
        img = self.qdd.get_drawing(self.classes[slot]).image
        # Explicit None check: a transform object is a callable, and its
        # truthiness says nothing about whether it should be applied.
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[slot]

    def __len__(self):
        """Return the nominal dataset size (``max_drawings``)."""
        return self.max_drawings
def get_drawing(self, recognized=True):
    """Pick a random Quick, Draw! drawing and render it stroke by stroke.

    A drawing group name is chosen at random, loaded, and one drawing is
    fetched from it; this repeats until the drawing's ``recognized`` flag
    equals the requested value.

    Args:
        recognized: Value the chosen drawing's ``recognized`` attribute
            must have.

    Returns:
        A ``(frames, drawing_data)`` tuple.  ``frames`` is a list with one
        RGB image per stroke, each showing every stroke drawn so far in
        black on a white canvas.  ``drawing_data`` is the object returned by
        ``QuickDrawData.get_drawing`` for the chosen group.
    """
    # The cache lives next to this source file.
    cache_dir = os.path.join(os.path.dirname(__file__), '.quickdrawcache')
    qd = QuickDrawData(max_drawings=1, cache_dir=cache_dir)

    # Keep sampling groups until the fetched drawing has the wanted flag.
    while True:
        group_name = random.choice(qd.drawing_names)
        qd.load_drawings([group_name])
        drawing_data = qd.get_drawing(group_name)
        if drawing_data.recognized == recognized:
            break

    canvas = Image.new('RGB', drawing_data.get_image().size, 'white')
    pen = ImageDraw.Draw(canvas)

    frames = []
    for stroke in drawing_data.strokes:
        pen.line(stroke, 'black')
        # Snapshot the canvas; later strokes keep mutating the original.
        frames.append(copy.deepcopy(canvas))
    return frames, drawing_data