def test_get_specific_drawing():
    """The first "anvil" drawing must match its known metadata and geometry."""
    group = QuickDrawDataGroup("anvil")
    drawing = group.get_drawing(0)

    # Known metadata of the first anvil drawing, checked in this order.
    expected_fields = {
        "name": "anvil",
        "key_id": 5355190515400704,
        "recognized": True,
        "countrycode": "PL",
        "timestamp": 1488368345,
    }
    for field, expected in expected_fields.items():
        assert getattr(drawing, field) == expected

    # image_data: a single stroke holding an x list and a y list of 33 values.
    strokes_xy = drawing.image_data
    assert len(strokes_xy) == 1
    only_stroke = strokes_xy[0]
    assert len(only_stroke) == 2
    assert len(only_stroke[0]) == 33
    assert len(only_stroke[1]) == 33

    # strokes: the same single stroke as 33 points; the first point has 2 coords.
    assert drawing.no_of_strokes == 1
    point_lists = drawing.strokes
    assert len(point_lists) == 1
    assert len(point_lists[0]) == 33
    assert len(point_lists[0][0]) == 2

    # Both the default rendering and a custom-styled one are ``Image`` instances.
    assert isinstance(drawing.image, Image)
    custom_render = drawing.get_image(stroke_color=(10, 10, 10),
                                      stroke_width=4,
                                      bg_color=(200, 200, 200))
    assert isinstance(custom_render, Image)
# Example No. 2 (score: 0)
class QuickDrawDataGroupDataset(Dataset):
    """Map-style dataset over the drawings of one Quick, Draw! category.

    Wraps a ``QuickDrawDataGroup`` so it can be indexed and sized like a
    ``torch.utils.data.Dataset``; every item is passed through ``transform``.

    Args:
        name: Category name forwarded to ``QuickDrawDataGroup``.
        max_drawings: Forwarded to ``QuickDrawDataGroup`` as ``max_drawings``.
        recognized: Forwarded to ``QuickDrawDataGroup`` as ``recognized``.
        transform: Callable applied to each drawing returned by
            ``get_drawing``. When ``None``, drawings are returned unchanged.
    """

    @staticmethod
    def _identity(drawing):
        """Return *drawing* unchanged.

        A named function (rather than a lambda stored on the instance) keeps
        the dataset picklable, which ``DataLoader`` requires when
        ``num_workers > 0`` with the "spawn" multiprocessing start method.
        """
        return drawing

    def __init__(self,
                 name: str,
                 max_drawings: int = 1000,
                 recognized: Optional[bool] = None,
                 transform: Optional[Callable[["QuickDrawing"], torch.Tensor]] = None):
        self.ds = QuickDrawDataGroup(name,
                                     max_drawings=max_drawings,
                                     recognized=recognized)
        self.transform = self._identity if transform is None else transform

    def __getitem__(self, index):
        """Return the transformed drawing at position *index*."""
        return self.transform(self.ds.get_drawing(index))

    def __len__(self):
        """Return the wrapped group's ``drawing_count``."""
        return self.ds.drawing_count
def test_get_random_drawing():
    """A drawing fetched without an index has well-typed, self-consistent fields.

    No exact values are asserted here (see ``test_get_specific_drawing`` for
    that), so the drawing is requested with ``get_drawing()`` and no index,
    rather than pinning index 0, which would contradict the test's name.
    """
    qdg = QuickDrawDataGroup("anvil")

    # No index: let the library choose the drawing.
    d = qdg.get_drawing()
    assert d.name == "anvil"
    assert isinstance(d.key_id, int)
    assert isinstance(d.recognized, bool)
    assert isinstance(d.timestamp, int)
    assert isinstance(d.countrycode, str)

    assert isinstance(d.image_data, list)
    assert len(d.image_data) == d.no_of_strokes
    # Each image_data stroke is an (xs, ys) pair of equal length.
    for xs, ys in d.image_data:
        assert len(xs) == len(ys)

    assert isinstance(d.strokes, list)
    assert len(d.strokes) == d.no_of_strokes
    for stroke in d.strokes:
        for point in stroke:
            assert len(point) == 2

    assert isinstance(d.image, Image)
    assert isinstance(
        d.get_image(stroke_color=(10, 10, 10),
                    stroke_width=4,
                    bg_color=(200, 200, 200)), Image)
# Example No. 4 (score: 0)
    # NOTE(review): fragment of a function whose `def` line is not part of this
    # snippet; `qd`, `np`, `xcoords` and `ycoords` all come from outside it.
    for drawing in qd.drawing_names:  # iterate over the category names
        # NOTE(review): `drawing` is already an element of `qd.drawing_names`;
        # indexing the list with it again looks wrong (TypeError if the names
        # are strings) — presumably `qd.get_drawing(drawing)` was intended.
        name = qd.get_drawing(qd.drawing_names[drawing])  # fetches a drawing object, not a name
        # NOTE(review): `len(...)` returns an int, which is not iterable, so this
        # raises TypeError — presumably `range(len(name.image_data))` was meant.
        for i in len(name.image_data):  # intended: loop over the strokes of the drawing
            # NOTE(review): `range(1)` runs the body exactly once; this inner
            # loop has no effect. drawing_1/drawing_2 are overwritten on every
            # stroke and not used in the visible code.
            for j in range(1):
                drawing_1 = np.array(name.image_data[i][0])
                drawing_2 = np.array(name.image_data[i][1])
        # NOTE(review): this `return` is inside the outer loop, so only the first
        # name is processed; `xcoords`/`ycoords` are never assigned in the
        # visible code.
        return (xcoords, ycoords)


""" for stroke in anvil.strokes:
    for x, y in stroke:
        print("x={} y={}".format(x, y)) """

#VERSION2 (for anvil drawings)

draw = anvils.get_drawing(index=2)

#first = anvils.get_drawing(index = 0)
#second = anvils.get_drawing(index = 1)
#drawing_1 = first.strokes
#print("dist = ", directed_hausdorff(u,v))

#print("out = ", list(zip(*x)))
""" compress = [item for sublist in first.strokes for item in sublist]
comp = [item for sublist in second.strokes for item in sublist]
a = np.array(list(zip(*compress)))
b = np.array(list(zip(*comp))) """
""" list(zip(*sum(b,[])))
list(zip(*chain.from_iterable(b))) """

#t = a[:,:25]
def load_data_by_label(label_name, num_samples):
    """Load ``num_samples`` recognized drawings of one category as a stacked tensor.

    Each drawing's ``image`` is converted to single-channel ("L") mode, turned
    into a tensor with ``transforms.ToTensor()``, and the per-drawing tensors
    are stacked along a new leading dimension (presumably giving shape
    ``(num_samples, 1, H, W)`` — confirm against the image size).

    Args:
        label_name: Quick, Draw! category name passed to ``QuickDrawDataGroup``.
        num_samples: Number of drawings to load; also used as ``max_drawings``.

    Returns:
        The ``torch.Tensor`` produced by ``torch.stack``.
    """
    qdg = QuickDrawDataGroup(name=label_name, recognized=True, max_drawings=num_samples)
    # ToTensor is stateless, so one instance serves every image; building it
    # once avoids constructing a new transform per sample.
    to_tensor = transforms.ToTensor()
    return torch.stack([to_tensor(qdg.get_drawing(i).image.convert("L"))
                        for i in range(num_samples)])
from quickdraw import QuickDrawDataGroup

qdg = QuickDrawDataGroup("anvil", cache_dir="C:\\path\\to\\cache")
print(qdg.drawing_count)
print(qdg.get_drawing())
# Example No. 7 (score: 0)
from quickdraw import QuickDrawDataGroup
import json
from trained_qd_model import *

qd = QuickDrawData()

#####

# Planned workflow: load the trained model, get an arbitrary image, pass the
# image into the trained model, get the weights, and save the weights to a file.

# Load model (not enabled yet):
# model = load_model('trained_quickdraw.model')
# model.summary()

arms = QuickDrawDataGroup("arm")
arm = arms.get_drawing()
# Looking a drawing up by key id (not enabled yet):
# x = arms.search_drawings(key_id=int(5778229946220544))
# x[0]
arm.image.save('arm2.png')
# Re-open the saved PNG and upscale it to 92x92. `Image.ANTIALIAS` was removed
# in Pillow 10; `Image.LANCZOS` is the same filter and also exists in older
# Pillow releases. The `with` block closes the file handle of the re-opened PNG.
with Image.open('arm2.png') as saved_arm:
    im = saved_arm.resize((92, 92), Image.LANCZOS)
im.save('armcopy.png')

img_width = 28
img_height = 28

# store the label codes in a dictionary
from quickdraw import QuickDrawDataGroup

anvils = QuickDrawDataGroup("anvil")
print(anvils.drawing_count)
print(anvils.get_drawing())
# Example No. 9 (score: 0)
# Category index: loaded lazily, no console messages, cached locally.
qd = QuickDrawData(
    recognized=None,
    max_drawings=1000,
    refresh_data=False,
    jit_loading=True,
    print_messages=False,
    cache_dir='./.quickdrawcache',
)

print(f'Total categories count {len(qd.drawing_names)}')
# Progress counter for the verbose loading display below.
currentIteration = 0


def rgb2gray(rgb):
    """Convert an 8-bit RGB(A) or single-channel pixel array to luminance in [0, 1].

    Uses the ITU-R BT.601 luma weights (0.299, 0.587, 0.114), the same ones
    Pillow uses for ``convert("L")``. They sum to 1, so a white pixel maps to
    1.0 and mid-greys are not brightened. Any alpha channel is ignored.

    Args:
        rgb: Array-like of 0-255 values, either shaped ``(..., C)`` with
            ``C >= 3`` or a 2-D single-channel image.

    Returns:
        Float array with the channel axis removed (or the 2-D input rescaled),
        clipped to ``[0, 1]``.
    """
    arr = np.asarray(rgb)
    if arr.ndim == 2:
        # Already single-channel: just rescale to [0, 1].
        return (arr / 255).clip(0, 1)
    # Standard BT.601 weights; a blue weight of 0.143 (sum 1.027) would
    # over-brighten every pixel by ~2.7%.
    return (np.dot(arr[..., :3], [0.299, 0.587, 0.114]) / 255).clip(0, 1)


# Sample `args.count` drawings per category, shrink them to 32x32, convert them
# to greyscale in [0, 1] and save the whole stack as one .npy array.
all_drawings = []
for name in qd.drawing_names:
    group = QuickDrawDataGroup(name)
    for _ in range(args.count):
        data_point = group.get_drawing()
        img = data_point.image.resize((32, 32))
        all_drawings.append(rgb2gray(np.array(img)))
        # img.save(f'{args.output}/{name}_{data_point.key_id}.png')
    if args.verbose:
        currentIteration += 1
        # "cls" only exists on Windows; use "clear" elsewhere so the progress
        # display does not print "command not found" on every category.
        os.system("cls" if os.name == "nt" else "clear")
        print(f'Images loaded [{name}]: {currentIteration}/{len(qd.drawing_names)}')

all_drawings = np.array(all_drawings)

np.save(f'{args.output}/image_set_d{len(qd.drawing_names)}_c{args.count}', all_drawings)