def test_get_specific_drawing():
    """Check the metadata and geometry of the first drawing in the 'anvil' group."""
    group = QuickDrawDataGroup("anvil")
    drawing = group.get_drawing(0)

    # Metadata expected for index 0 of the anvil data, checked in a fixed order.
    expected_fields = {
        "name": "anvil",
        "key_id": 5355190515400704,
        "recognized": True,
        "countrycode": "PL",
        "timestamp": 1488368345,
    }
    for field, expected in expected_fields.items():
        assert getattr(drawing, field) == expected

    # image_data: one stroke, stored as two coordinate lists of 33 points each.
    image_data = drawing.image_data
    assert len(image_data) == 1
    first_stroke = image_data[0]
    assert len(first_stroke) == 2
    assert [len(axis) for axis in first_stroke] == [33, 33]

    # strokes: the same stroke as a list of 33 (x, y) points.
    assert drawing.no_of_strokes == 1
    strokes = drawing.strokes
    assert len(strokes) == 1
    assert len(strokes[0]) == 33
    assert len(strokes[0][0]) == 2

    # Both the default render and a custom-styled render must be Image instances.
    assert isinstance(drawing.image, Image)
    styled = drawing.get_image(stroke_color=(10, 10, 10),
                               stroke_width=4,
                               bg_color=(200, 200, 200))
    assert isinstance(styled, Image)
    def __init__(self,
                 name: str,
                 max_drawings: int = 1000,
                 recognized: Optional[bool] = None,
                 transform: Optional[Callable[[QuickDrawing], torch.Tensor]] = None):
        """Wrap the QuickDrawDataGroup for one category.

        :param name: category name forwarded to QuickDrawDataGroup.
        :param max_drawings: number of drawings to load, forwarded as-is.
        :param recognized: recognized filter forwarded to QuickDrawDataGroup.
        :param transform: callable applied to each drawing; when None, an
            identity function is used so drawings are returned unchanged.
        """
        self.ds = QuickDrawDataGroup(name,
                                     max_drawings=max_drawings,
                                     recognized=recognized)
        # Always store a callable so callers can apply self.transform unconditionally.
        self.transform = transform if transform is not None else (lambda x: x)
def test_get_random_drawing():
    """A randomly chosen anvil drawing must expose well-typed, consistent fields."""
    qdg = QuickDrawDataGroup("anvil")

    # Call get_drawing() without an index to exercise random selection; the
    # previous get_drawing(0) always returned the first drawing.
    d = qdg.get_drawing()
    assert d.name == "anvil"
    assert isinstance(d.key_id, int)
    assert isinstance(d.recognized, bool)
    assert isinstance(d.timestamp, int)
    assert isinstance(d.countrycode, str)

    assert isinstance(d.image_data, list)
    assert len(d.image_data) == d.no_of_strokes
    # Each stroke in image_data is [xs, ys]; there must be one x per y.
    for xs, ys in d.image_data:
        assert len(xs) == len(ys)

    assert isinstance(d.strokes, list)
    assert len(d.strokes) == d.no_of_strokes
    for stroke in d.strokes:
        for point in stroke:
            assert len(point) == 2

    assert isinstance(d.image, Image)
    assert isinstance(
        d.get_image(stroke_color=(10, 10, 10),
                    stroke_width=4,
                    bg_color=(200, 200, 200)), Image)
def test_unrecognized_data():
    """A group filtered with recognized=False must contain only unrecognized drawings."""
    group = QuickDrawDataGroup("anvil", recognized=False)
    assert group.drawing_count == 1000

    # Read every drawing's recognized flag once, then tally both outcomes.
    flags = [drawing.recognized for drawing in group.drawings]
    recognized_total = sum(1 for flag in flags if flag)
    unrecognized_total = len(flags) - recognized_total

    assert recognized_total == 0
    assert unrecognized_total == group.drawing_count
def test_search_drawings():
    """Exercise search_drawings with each single filter and with a compound filter."""
    qdg = QuickDrawDataGroup("anvil")
    # A search with no criteria returns every loaded drawing (1000 by default).
    r = qdg.search_drawings()
    assert len(r) == 1000

    # Recognized filter, both polarities.
    r = qdg.search_drawings(recognized=True)
    for d in r:
        assert d.recognized

    r = qdg.search_drawings(recognized=False)
    for d in r:
        assert not d.recognized

    # Country filter. It must match something: r[0] is used below, and an empty
    # result would otherwise make the loop pass vacuously.
    r = qdg.search_drawings(countrycode="US")
    assert r, "expected at least one US anvil drawing"
    for d in r:
        assert d.countrycode == "US"

    # Pull the first matching drawing's identifiers.
    key_id = r[0].key_id
    timestamp = r[0].timestamp

    # A key_id search must at least find the drawing the key came from.
    r = qdg.search_drawings(key_id=key_id)
    assert r, "key_id search did not find its source drawing"
    for d in r:
        assert d.key_id == key_id

    # Likewise for the timestamp search.
    r = qdg.search_drawings(timestamp=timestamp)
    assert r, "timestamp search did not find its source drawing"
    for d in r:
        assert d.timestamp == timestamp

    # Compound search: recognized AND country code.
    r = qdg.search_drawings(recognized=True, countrycode="US")
    for d in r:
        assert d.recognized
        assert d.countrycode == "US"
    def __init__(self,
                 recognized: Optional[bool] = None,
                 transform: Optional[Callable[[QuickDrawing], torch.Tensor]] = None,
                 max_drawings: int = 1000):
        """Load one QuickDrawDataGroup for every category name QuickDrawData lists.

        :param recognized: recognized filter forwarded to each QuickDrawDataGroup.
        :param transform: callable applied to drawings; when None, an identity
            function is used so drawings are returned unchanged.
        :param max_drawings: number of drawings loaded per category. The default
            of 1000 matches what a QuickDrawDataGroup loads when no
            max_drawings is given.
        """
        self.qd = QuickDrawData()
        self.qd_class_names = self.qd.drawing_names

        # One QuickDrawDataGroup per category name, each holding up to
        # max_drawings drawings.
        self.qd_DataGroups = {
            name: QuickDrawDataGroup(name,
                                     max_drawings=max_drawings,
                                     recognized=recognized)
            for name in self.qd_class_names
        }

        # Always store a callable so callers can apply self.transform unconditionally.
        self.transform = transform if transform is not None else (lambda x: x)
class QuickDrawDataGroupDataset(Dataset):
    """Indexable dataset over the drawings of a single QuickDraw category.

    Items are drawings fetched from a QuickDrawDataGroup by index and passed
    through ``transform``.
    """

    def __init__(self,
                 name: str,
                 max_drawings: int = 1000,
                 recognized: Optional[bool] = None,
                 transform: Optional[Callable[[QuickDrawing], torch.Tensor]] = None):
        """Create the underlying QuickDrawDataGroup.

        :param name: category name forwarded to QuickDrawDataGroup.
        :param max_drawings: number of drawings to load, forwarded as-is.
        :param recognized: recognized filter forwarded to QuickDrawDataGroup.
        :param transform: callable applied to each drawing; when None, an
            identity function is used so drawings are returned unchanged.
        """
        self.ds = QuickDrawDataGroup(name,
                                     max_drawings=max_drawings,
                                     recognized=recognized)
        # Always store a callable so __getitem__ can apply it unconditionally.
        self.transform = transform if transform is not None else (lambda x: x)

    def __getitem__(self, index):
        """Return the drawing at ``index`` after applying ``self.transform``."""
        return self.transform(self.ds.get_drawing(index))

    def __len__(self):
        """Return the number of drawings loaded into the underlying group."""
        return self.ds.drawing_count
Esempio n. 8
0
def main(cache_dir, save_img_dir, img_format, num_workers, categories,
         drawing_recognized):
    """Feed every drawing of ``categories`` to a pool of worker threads.

    Each queued item is ``(category_name, index_within_category, drawing)``.
    The workers run ``worker_thread(work, save_img_dir, img_format)``, which is
    defined elsewhere (presumably it writes each drawing out in ``img_format``
    under ``save_img_dir``; confirm).

    :param cache_dir: cache directory forwarded to QuickDrawDataGroup.
    :param save_img_dir: output directory handed to each worker.
    :param img_format: output format handed to each worker.
    :param num_workers: number of daemon worker threads to start.
    :param categories: iterable of category names to load.
    :param drawing_recognized: recognized filter forwarded to QuickDrawDataGroup.
    """
    # A bounded queue makes put() block while 10000 items are pending. This gives
    # the same back-pressure as polling qsize() with sleep(), without busy-waiting.
    work = Queue(maxsize=10000)

    # Start worker threads; daemon=True so they never keep the process alive.
    for _ in range(num_workers):
        w = Thread(target=worker_thread,
                   args=(work, save_img_dir, img_format),
                   daemon=True)
        w.start()

    # Feed in work. max_drawings=inf loads every available drawing per category.
    for name in categories:
        group = QuickDrawDataGroup(name=name,
                                   recognized=drawing_recognized,
                                   cache_dir=cache_dir,
                                   max_drawings=float('inf'))
        for i, drawing in enumerate(group.drawings):
            work.put((name, i, drawing))

    # Wait until every queued item has been marked done. This relies on
    # worker_thread calling work.task_done() once per item (not visible here).
    work.join()
def test_get_data_group():
    """drawing_count must follow max_drawings, both by default and when set explicitly."""
    cases = (
        ({}, 1000),                      # no max_drawings argument at all
        ({"max_drawings": 2000}, 2000),  # explicit larger limit
    )
    for extra_kwargs, expected_count in cases:
        group = QuickDrawDataGroup("anvil", **extra_kwargs)
        assert group.drawing_count == expected_count
Esempio n. 10
0
from quickdraw import QuickDrawData
from quickdraw import QuickDrawDataGroup
from scipy.spatial.distance import directed_hausdorff
import numpy as np
from itertools import chain

# Module-level data handles used by the experiment below.
qd = QuickDrawData()
# One 'anvil' drawing fetched through QuickDrawData.get_drawing.
anvil = qd.get_drawing("anvil")
# The whole 'anvil' group, for comparing drawings within one category.
anvils = QuickDrawDataGroup("anvil")
# Alternative source for `anvil`: pick it from the group instead.
#anvil = anvils.get_drawing()

# Plan (not implemented yet):
# 1. Iterate over the drawing names and load one drawing group per name
#    (e.g. "airplane"); each group contains many drawings.
# 2. Take the coordinates of the first and second drawings in the group.
# 3. Compute their Hausdorff distance (scipy's directed_hausdorff is imported).
# 4. Record the distance from each drawing to the next one in the same group.
# 5. Report the minimum distance found.

# Version 1 of the implementation follows.


def get_drawing(qd):
    """Return the flattened x and y coordinates of one drawing from the first category.

    The original version looped over ``qd.drawing_names`` but returned inside the
    first iteration, so only the first category was ever used; that is kept
    here, explicitly. (It also crashed: it indexed the name list with a string,
    iterated over an int, and returned undefined names.)

    :param qd: object exposing ``drawing_names`` (an indexable list of category
        names) and ``get_drawing(name)``, e.g. a QuickDrawData instance.
    :return: ``(xcoords, ycoords)`` as NumPy arrays holding the x and y values of
        every stroke concatenated in stroke order, or ``None`` when
        ``qd.drawing_names`` is empty.
    """
    if not qd.drawing_names:
        return None
    drawing = qd.get_drawing(qd.drawing_names[0])
    # Each entry of image_data is one stroke laid out as [xs, ys].
    xcoords = np.array(list(chain.from_iterable(stroke[0] for stroke in drawing.image_data)))
    ycoords = np.array(list(chain.from_iterable(stroke[1] for stroke in drawing.image_data)))
    return xcoords, ycoords
def load_data_by_label(label_name, num_samples):
    """Stack the first ``num_samples`` recognized drawings of ``label_name`` into one tensor.

    Each drawing's image is converted to single-channel "L" mode, turned into a
    tensor with ``transforms.ToTensor()``, and the per-drawing tensors are
    stacked along a new leading dimension with ``torch.stack``.
    """
    group = QuickDrawDataGroup(name=label_name, recognized=True, max_drawings=num_samples)
    to_tensor = transforms.ToTensor()  # stateless, so one instance serves every sample
    tensors = []
    for index in range(num_samples):
        grayscale = group.get_drawing(index).image.convert("L")
        tensors.append(to_tensor(grayscale))
    return torch.stack(tensors)
Esempio n. 12
0
from config import *

import os
from quickdraw import QuickDrawDataGroup

# Export each category listed in categories.txt as PNG files. The first
# train_samples_num drawings go to <DATASET_ROOT>/train/<name>/ and the rest to
# <DATASET_ROOT>/test/<name>/.
with open(os.path.join(DATASET_ROOT, 'categories.txt'), encoding='utf-8') as f:
    for name in f:
        name = name.strip()
        if not name:
            # Skip blank lines (e.g. a trailing newline) instead of loading category ''.
            continue
        data = QuickDrawDataGroup(name,
                                  max_drawings=SAMPLES_IN_CLASS,
                                  recognized=True)

        train_dir = os.path.join(DATASET_ROOT, 'train', name)
        test_dir = os.path.join(DATASET_ROOT, 'test', name)
        for i, sample in enumerate(data.drawings):
            target_dir = train_dir if i < train_samples_num else test_dir
            # makedirs also creates a missing train/ or test/ parent, which
            # os.mkdir cannot; exist_ok makes repeat calls harmless.
            os.makedirs(target_dir, exist_ok=True)
            sample.image.save(os.path.join(target_dir, str(sample.key_id) + '.png'))
def test_drawings():
    """Iterating ``drawings`` on a default anvil group must yield exactly 1000 items."""
    group = QuickDrawDataGroup("anvil")
    total = sum(1 for _ in group.drawings)
    assert total == 1000
from quickdraw import QuickDrawDataGroup

# Load the 'anvil' group, then print its drawing count and one drawing.
anvils = QuickDrawDataGroup("anvil")
print(anvils.drawing_count)
# get_drawing() is called with no index; presumably this picks a random drawing.
print(anvils.get_drawing())
from quickdraw import QuickDrawDataGroup
# Load the 'anvil' group and keep only the drawings whose countrycode is "PL".
anvils = QuickDrawDataGroup("anvil")
results = anvils.search_drawings(countrycode="PL")
Esempio n. 16
0
from quickdraw import QuickDrawDataGroup
from PIL import Image, ImageDraw, ImageOps
import cv2
import numpy as np
from color_it import ColorIt

qdg = QuickDrawDataGroup("circle")
# Look up one specific circle drawing by key_id (the literal is already an int).
results = qdg.search_drawings(key_id=4865367048454144)
if not results:
    # Fail with a clear message instead of an opaque IndexError on results[0].
    raise LookupError("key_id 4865367048454144 not found in the loaded 'circle' drawings")
# NOTE: the name `anvil` is historical; this is the circle drawing found above.
anvil = results[0]

c = ColorIt(anvil)
c.fill(3)
from quickdraw import QuickDrawDataGroup

# Load the 'anvil' group using an explicit cache directory.
# "C:\\path\\to\\cache" is a placeholder; replace it with a real directory.
qdg = QuickDrawDataGroup("anvil", cache_dir="C:\\path\\to\\cache")
print(qdg.drawing_count)
# No index given; presumably returns a random drawing from the group.
print(qdg.get_drawing())
Esempio n. 18
0
from quickdraw import QuickDrawDataGroup

# Print every drawing loaded into the 'anvil' group.
anvils = QuickDrawDataGroup("anvil")
for anvil in anvils.drawings:
    print(anvil)
Esempio n. 19
0
from quickdraw import QuickDrawData
from quickdraw import QuickDrawDataGroup
import json
from trained_qd_model import *

qd = QuickDrawData()

#####

# Plan: load the trained model, take an arbitrary drawing, pass its image
# through the model, read the weights and save them to a file.

# Model loading is not wired up yet:
# model = load_model('trained_quickdraw.model')
# model.summary()

arms = QuickDrawDataGroup("arm")
arm = arms.get_drawing()
# Save the drawing, reload it from disk and write a 92x92 copy.
arm.image.save('arm2.png')
# Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
# The with-block closes the file handle opened by Image.open; resize() returns
# a new, fully loaded image, so it stays usable afterwards.
with Image.open('arm2.png') as im:
    im = im.resize((92, 92), Image.LANCZOS)
im.save('armcopy.png')

img_width = 28
img_height = 28
from quickdraw import QuickDrawDataGroup
from quickdraw import QuickDrawData
import sys
import os

qd = QuickDrawData()

# Batch variant (kept for reference): export 20 drawings of every category
# from index 86 onward.
##for drawing_num, drawing_name in enumerate(qd.drawing_names):
##    if drawing_num < 86:
##        continue
##    print("Saving " + drawing_name + " images")
##    qdg = QuickDrawDataGroup(drawing_name, max_drawings=20)
##    directory_name = 'images/' + drawing_name
##    if not os.path.isdir(os.path.join(os.getcwd(), directory_name)):
##        os.mkdir(directory_name)
##    for drawing_count, drawing in enumerate(qdg.drawings, 1):
##        drawing.image.save(directory_name + "/" + drawing_name + "{:0>4d}.png".format(drawing_count))

# Export up to 300 'face' drawings as images/face/face0001.png, face0002.png, ...
drawing_name = "face"
print("Saving " + drawing_name + " images")
qdg = QuickDrawDataGroup(drawing_name, max_drawings=300)
directory_name = 'images/' + drawing_name
# makedirs also creates the parent 'images/' folder, which os.mkdir could not.
os.makedirs(directory_name, exist_ok=True)
for drawing_count, drawing in enumerate(qdg.drawings, 1):
    drawing.image.save(os.path.join(directory_name,
                                    drawing_name + "{:0>4d}.png".format(drawing_count)))
Esempio n. 21
0
# `args` is defined outside this view. It presumably comes from argparse; it is
# read below for .count, .verbose and .output (TODO confirm).
print(args)
# Every QuickDrawData option is spelled out; drawings are cached under ./.quickdrawcache.
qd = QuickDrawData(recognized=None, max_drawings=1000, refresh_data=False, jit_loading=True, print_messages=False, cache_dir='./.quickdrawcache')

print(f'Total categories count {len(qd.drawing_names)}')
# Progress counter for the verbose output in the loading loop.
currentIteration = 0


def rgb2gray(rgb):
    """Convert an RGB(A) array to grayscale intensities scaled to [0, 1].

    Uses the ITU-R BT.601 luma weights (0.299, 0.587, 0.114). They sum to 1, so
    a pure-white 8-bit pixel maps to 1.0. Channels past the third (e.g. alpha)
    are ignored.

    :param rgb: array whose last axis holds at least the R, G, B channels. The
        division by 255 assumes 8-bit (0-255) values.
    :return: array with the last axis reduced away, clipped to [0, 1].
    """
    # The previous weights (0.298, 0.586, 0.143) summed to 1.027. 0.143 is
    # evidently a typo for 0.114: blue was over-weighted, and bright pixels
    # exceeded 1 and were only masked by the clip.
    return (np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) / 255).clip(0, 1)


all_drawings = []
for name in qd.drawing_names:
    group = QuickDrawDataGroup(name)
    for _ in range(args.count):
        # No index is passed; presumably a random pick, so samples may repeat.
        data_point = group.get_drawing()
        img = data_point.image.resize((32, 32))
        all_drawings.append(rgb2gray(np.array(img)))
        #img.save(f'{args.output}/{name}_{data_point.key_id}.png')
    if args.verbose:
        currentIteration += 1
        # 'cls' only exists on Windows; POSIX terminals use 'clear'.
        os.system("cls" if os.name == "nt" else "clear")
        print(f'Images loaded [{name}]: {currentIteration}/{len(qd.drawing_names)}')

# One grayscale array per sample, in category order. They are 2-D if the images
# are RGB(A), which is assumed here (TODO confirm).
all_drawings = np.array(all_drawings)

np.save(f'{args.output}/image_set_d{len(qd.drawing_names)}_c{args.count}', all_drawings)