コード例 #1
0
def download_quickdraw_414k(output_folder="./", remove_sourcefile=True):
    """Download, checksum-verify and extract the QuickDraw-414k dataset.

    Args:
        output_folder: Directory the archives are downloaded to and extracted
            into.
        remove_sourcefile: Forwarded to ``extract_files`` for every archive
            (presumably deletes the archive after extraction — confirm in
            ``extract_files``).

    Returns:
        None. If a download or extraction step raises, the exception is
        printed and the function returns early.

    Raises:
        AssertionError: If a downloaded archive's MD5 digest differs from the
            entry in ``MD5SUM_TABLE``.
    """
    # (Google Drive URL, archive file name) for each part of the dataset.
    archives = [
        ("https://drive.google.com/uc?id=1q933KpmJGkfStgIbwMfgfls1_1ZJVFyd",
         "picture_files.tar.gz"),
        ("https://drive.google.com/uc?id=1Vrf1ouhtWYJp4XKa6jestLGY3aVlxfLC",
         "coordinate_files.tar.gz"),
    ]
    success = '\033[32;1m' + 'succesfully' + '\033[0m'

    print(torchsketch())
    time.sleep(2)

    print_banner("Downloading started......")

    try:
        for url, archive_name in archives:
            gdown.download(url,
                           os.path.join(output_folder, archive_name),
                           quiet=False)
    except Exception as e:
        print(e)
        return

    print("\n")
    print_banner("quickdraw_414k is downloaded {}!".format(success))

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would silently skip integrity verification.
    for _, archive_name in archives:
        archive_path = os.path.join(output_folder, archive_name)
        if md5sum(archive_path) != MD5SUM_TABLE[archive_name]:
            raise AssertionError(archive_name + ' md5 checksum error')

    print_banner("Md5 checksum passed {}!".format(success))

    print_banner("Extracting started......")

    try:
        for _, archive_name in archives:
            extract_files(file_name=os.path.join(output_folder, archive_name),
                          output_folder=output_folder,
                          remove_sourcefile=remove_sourcefile)
    except Exception as e:
        print(e)
        return

    print_banner("quickdraw_414k is extracted {}!".format(success))

    print(torchsketch())
コード例 #2
0
def randomly_remove_strokes_4_svg(svg_url,
                                  output_folder,
                                  remove_ratio=0.3,
                                  verbose=True):
    """Hide a random subset of the strokes (``<path>`` elements) of an SVG.

    ``int(stroke_count * remove_ratio)`` randomly chosen paths get
    ``visibility="hidden"``. The result is written to
    ``<output_folder>/<name>_randomly_partially_removed.svg``.

    Args:
        svg_url: Path of the input SVG file.
        output_folder: Output directory; passed to ``reset_dir`` before
            writing.
        remove_ratio: Fraction of strokes to hide. Must lie in the open
            interval (0.0, 1.0).
        verbose: When truthy, print progress details.

    Returns:
        The number of strokes in the input sketch, or 0 if ``remove_ratio``
        is out of range.
    """
    if remove_ratio <= 0.0 or remove_ratio >= 1.0:
        print(
            "The input parameter of remove_ratio should be defined in (0.0, 1.0)."
        )
        print("Please choose a proper remove_ratio value, and retry this API.")
        print(torchsketch())
        return 0

    doc = minidom.parse(svg_url)
    paths = doc.getElementsByTagName("path")

    assert type(paths) == xml.dom.minicompat.NodeList
    stroke_count = len(paths)
    remove_count = int(stroke_count * remove_ratio)

    if verbose:
        print("{} has {} storkes.".format(svg_url, stroke_count))

    reset_dir(output_folder)

    svg_filename = svg_url.split("/")[-1][:-4]

    # Reseeds NumPy's global RNG from the wall clock, so every call selects a
    # different subset. permutation(n) equals arange(n) followed by shuffle.
    np.random.seed(int(time.time()))
    random_list = np.random.permutation(stroke_count)[:remove_count]

    try:
        # Reuse the NodeList parsed above instead of re-querying the DOM for
        # every selected stroke.
        for i in random_list:
            paths[int(i)].setAttribute("visibility", "hidden")

        output_path = (output_folder + "/" + svg_filename + "_" +
                       "randomly_partially_removed" + ".svg")
        # `with` closes the handle even if writexml raises.
        with open(output_path, "w") as output:
            doc.writexml(output)

    except Exception as e:
        print(e)

    if verbose:
        print("Details are as follows.")
        print("{} strokes have been removed RANDOMLY.".format(remove_count))
        print("The produced sketch is stored in {}".format(output_folder))
        print(torchsketch())

    return stroke_count
コード例 #3
0
def download_tu_berlin(output_folder="./", remove_sourcefile=True):
    """Download, checksum-verify and extract the TU-Berlin sketch dataset.

    Args:
        output_folder: Directory the archives are downloaded to and extracted
            into.
        remove_sourcefile: Forwarded to ``extract_files`` for every archive
            (presumably deletes the archive after extraction — confirm in
            ``extract_files``).

    Returns:
        None. If a download or extraction step raises, the exception is
        printed and the function returns early.

    Raises:
        AssertionError: If a downloaded archive's MD5 digest differs from the
            entry in ``MD5SUM_TABLE``.
    """
    # (download URL, archive file name) for each part of the dataset.
    archives = [
        ("http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/sketches_svg.zip",
         "sketches_svg.zip"),
        ("http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/sketches_png.zip",
         "sketches_png.zip"),
    ]
    success = '\033[32;1m' + 'succesfully' + '\033[0m'

    print(torchsketch())
    time.sleep(2)

    print_banner("Downloading started......")

    try:
        for url, archive_name in archives:
            wget.download(url, out=os.path.join(output_folder, archive_name))
    except Exception as e:
        print(e)
        return

    print("\n")
    print_banner("TU-Berlin is downloaded {}!".format(success))

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would silently skip integrity verification.
    for _, archive_name in archives:
        archive_path = os.path.join(output_folder, archive_name)
        if md5sum(archive_path) != MD5SUM_TABLE[archive_name]:
            raise AssertionError(archive_name + ' md5 checksum error')

    print_banner("Md5 checksum passed {}!".format(success))

    print_banner("Extracting started......")

    try:
        for _, archive_name in archives:
            extract_files(file_name=os.path.join(output_folder, archive_name),
                          output_folder=output_folder,
                          remove_sourcefile=remove_sourcefile)
    except Exception as e:
        print(e)
        return

    print_banner("TU-Berlin is extracted {}!".format(success))

    print(torchsketch())
コード例 #4
0
def convert_svgs_2_pngs(svg_url_list, output_folder=None):
    """Convert each SVG file in ``svg_url_list`` to PNG.

    Every path is handed, one at a time and in order, to
    ``convert_svg_2_png`` together with ``output_folder``. Afterwards the
    ``torchsketch()`` banner is printed.

    Args:
        svg_url_list: Iterable of SVG file paths.
        output_folder: Forwarded unchanged to ``convert_svg_2_png``
            (the meaning of ``None`` is defined there — confirm).

    Returns:
        None.
    """
    for current_svg in svg_url_list:
        convert_svg_2_png(current_svg, output_folder)

    print(torchsketch())
コード例 #5
0
def convert_colors_4_svgs(svg_url_list, output_folder, colors=None, verbose=True):
    """Apply ``convert_colors_4_svg`` to every SVG in ``svg_url_list``.

    Args:
        svg_url_list: Iterable of SVG file paths.
        output_folder: Forwarded to ``convert_colors_4_svg``.
        colors: Color names forwarded to ``convert_colors_4_svg``; defaults
            to ``["red"]``.
        verbose: Forwarded to ``convert_colors_4_svg``.

    Returns:
        None. The ``torchsketch()`` banner is printed at the end.
    """
    # None sentinel instead of a list-literal default, which would be a single
    # list object shared across every call.
    if colors is None:
        colors = ["red"]

    for svg_url in svg_url_list:
        convert_colors_4_svg(svg_url, output_folder, colors, verbose)

    print(torchsketch())
コード例 #6
0
def download_qmul_chair(output_folder="./", remove_sourcefile=True):
    """Download, checksum-verify and extract the QMUL chair dataset.

    Args:
        output_folder: Directory the archive is downloaded to and extracted
            into.
        remove_sourcefile: Forwarded to ``extract_files`` (presumably deletes
            the archive after extraction — confirm in ``extract_files``).

    Returns:
        None. If the download or extraction raises, the exception is printed
        and the function returns early.

    Raises:
        AssertionError: If the downloaded archive's MD5 digest differs from
            ``MD5SUM_TABLE["chairs.zip"]``.
    """
    # (Google Drive URL, archive file name) for each part of the dataset.
    archives = [
        ("https://drive.google.com/uc?id=1xlBhTJQwtssi8oJHbYbqVer8ZmyEG3wq",
         "chairs.zip"),
    ]
    success = '\033[32;1m' + 'succesfully' + '\033[0m'

    print(torchsketch())
    time.sleep(2)

    print_banner("Downloading started......")

    try:
        for url, archive_name in archives:
            gdown.download(url,
                           os.path.join(output_folder, archive_name),
                           quiet=False)
    except Exception as e:
        print(e)
        return

    print("\n")
    print_banner("QMUL chair is downloaded {}!".format(success))

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would silently skip integrity verification.
    for _, archive_name in archives:
        archive_path = os.path.join(output_folder, archive_name)
        if md5sum(archive_path) != MD5SUM_TABLE[archive_name]:
            raise AssertionError(archive_name + ' md5 checksum error')

    print_banner("Md5 checksum passed {}!".format(success))

    print_banner("Extracting started......")

    try:
        for _, archive_name in archives:
            extract_files(file_name=os.path.join(output_folder, archive_name),
                          output_folder=output_folder,
                          remove_sourcefile=remove_sourcefile)
    except Exception as e:
        print(e)
        return

    print_banner("QMUL chair is extracted {}!".format(success))

    print(torchsketch())
コード例 #7
0
def download_qmul_shoe(output_folder="./", remove_sourcefile=True):
    """Download, checksum-verify and extract the QMUL shoe dataset.

    Args:
        output_folder: Directory the archive is downloaded to and extracted
            into.
        remove_sourcefile: Forwarded to ``extract_files`` (presumably deletes
            the archive after extraction — confirm in ``extract_files``).

    Returns:
        None. If the download or extraction raises, the exception is printed
        and the function returns early.

    Raises:
        AssertionError: If the downloaded archive's MD5 digest differs from
            ``MD5SUM_TABLE["shoes.zip"]``.
    """
    # (Google Drive URL, archive file name) for each part of the dataset.
    archives = [
        ("https://drive.google.com/uc?id=1S9lHUzdgR9yIRuIAE0kTw_odWVzb0Q5l",
         "shoes.zip"),
    ]
    success = '\033[32;1m' + 'succesfully' + '\033[0m'

    print(torchsketch())
    time.sleep(2)

    print_banner("Downloading started......")

    try:
        for url, archive_name in archives:
            gdown.download(url,
                           os.path.join(output_folder, archive_name),
                           quiet=False)
    except Exception as e:
        print(e)
        return

    print("\n")
    print_banner("QMUL shoe is downloaded {}!".format(success))

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would silently skip integrity verification.
    for _, archive_name in archives:
        archive_path = os.path.join(output_folder, archive_name)
        if md5sum(archive_path) != MD5SUM_TABLE[archive_name]:
            raise AssertionError(archive_name + ' md5 checksum error')

    print_banner("Md5 checksum passed {}!".format(success))

    print_banner("Extracting started......")

    try:
        for _, archive_name in archives:
            extract_files(file_name=os.path.join(output_folder, archive_name),
                          output_folder=output_folder,
                          remove_sourcefile=remove_sourcefile)
    except Exception as e:
        print(e)
        return

    print_banner("QMUL shoe is extracted {}!".format(success))

    print(torchsketch())
コード例 #8
0
def convert_stroke_width_4_svgs(svg_url_list,
                                output_folder,
                                stroke_width=6,
                                verbose=True):
    """Run ``convert_stroke_width_4_svg`` on each SVG in ``svg_url_list``.

    ``output_folder``, ``stroke_width`` and ``verbose`` are passed through
    positionally and unchanged for every file. Once all files are processed,
    the ``torchsketch()`` banner is printed.

    Returns:
        None.
    """
    for current_svg in svg_url_list:
        convert_stroke_width_4_svg(current_svg, output_folder,
                                   stroke_width, verbose)

    print(torchsketch())
コード例 #9
0
def convert_colors_4_svg(svg_url,
                         output_folder,
                         colors=None,
                         verbose=True):
    """Write one recolored copy of an SVG sketch per recognised color.

    For every entry of ``colors`` that is a key of ``COLOR_TABLE``, the
    ``stroke`` attribute of every ``<path>`` is set to that color. The
    document is then saved as ``<output_folder>/<name>_<color>.svg``.
    Unrecognised colors are skipped.

    Args:
        svg_url: Path of the input SVG file.
        output_folder: Output directory; created if it does not exist.
        colors: Sized iterable of color names; defaults to ``["red"]``.
        verbose: When truthy, print progress details.

    Returns:
        The number of SVG files written.
    """
    # None sentinel instead of a list-literal default, which would be a single
    # list object shared across every call.
    if colors is None:
        colors = ["red"]

    doc = minidom.parse(svg_url)
    paths = doc.getElementsByTagName("path")

    assert type(paths) == xml.dom.minicompat.NodeList
    stroke_count = len(paths)

    if verbose:
        print("{} has {} storkes.".format(svg_url, stroke_count))

    os.makedirs(output_folder, exist_ok=True)

    produced_count = 0

    svg_filename = svg_url.split("/")[-1][:-4]

    for color in colors:

        if color not in COLOR_TABLE:
            continue

        try:
            # Reuse the parsed NodeList instead of re-querying the DOM per stroke.
            for path in paths:
                path.setAttribute("stroke", color)
        except Exception as e:
            print(e)
            continue

        output_path = (output_folder + "/" + svg_filename + "_" + str(color) +
                       ".svg")
        # `with` closes the handle even if writexml raises.
        with open(output_path, "w") as output:
            doc.writexml(output)

        produced_count += 1

    if verbose:
        print("Details are as follows.")
        print("{} colors have been inputted.".format(len(colors)))
        print(
            "{} colorful sketches have been produced.".format(produced_count))
        print("{} colors can not been recognized.".format(
            len(colors) - produced_count))
        print(torchsketch())

    return produced_count
コード例 #10
0
def count_strokes_4_svg(svg_url, verbose=True):
    """Return how many ``<path>`` elements (strokes) an SVG file contains.

    Args:
        svg_url: Path of the SVG file to parse.
        verbose: When equal to ``True``, print the stroke count followed by
            the ``torchsketch()`` banner.

    Returns:
        The number of ``<path>`` elements found anywhere in the document.
    """
    document = minidom.parse(svg_url)
    stroke_data = list(map(lambda node: node.getAttribute('d'),
                           document.getElementsByTagName('path')))
    # Break the DOM's internal references; only the extracted strings are kept.
    document.unlink()

    assert type(stroke_data) == list
    total = len(stroke_data)

    if verbose == True:
        print("{} has {} storkes.".format(svg_url, total))
        print(torchsketch())

    return total
コード例 #11
0
def randomly_convert_stroke_width_4_svg(svg_url, output_folder, stroke_width=6, selected_strokes=3, verbose=True):
    """Set ``stroke-width`` on a random subset of an SVG sketch's strokes.

    Up to ``selected_strokes`` randomly chosen ``<path>`` elements get
    ``stroke-width=str(stroke_width)``. The result is written to
    ``<output_folder>/<name>_randomly_selected_<selected_strokes>_strokes_stroke_width_<stroke_width>.svg``.

    Args:
        svg_url: Path of the input SVG file.
        output_folder: Output directory; created if it does not exist.
        stroke_width: New stroke width, stored as its ``str()`` form.
        selected_strokes: How many strokes to modify. If the sketch has fewer
            strokes, all of them are modified.
        verbose: When truthy, print progress details.

    Returns:
        ``stroke_width`` unchanged (kept for backward compatibility).
    """
    doc = minidom.parse(svg_url)
    paths = doc.getElementsByTagName("path")

    assert type(paths) == xml.dom.minicompat.NodeList
    stroke_count = len(paths)

    if verbose:
        print("{} has {} storkes.".format(svg_url, stroke_count))

    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(output_folder, exist_ok=True)

    svg_filename = svg_url.split("/")[-1][:-4]

    # Reseeds NumPy's global RNG from the wall clock, so every call selects a
    # different subset. permutation(n) equals arange(n) followed by shuffle.
    np.random.seed(int(time.time()))
    random_list = np.random.permutation(stroke_count)[:selected_strokes]

    try:
        # Reuse the parsed NodeList instead of re-querying the DOM per stroke.
        for i in random_list:
            paths[int(i)].setAttribute("stroke-width", str(stroke_width))
    except Exception as e:
        print(e)

    output_path = (output_folder + "/" + svg_filename + "_randomly_selected_" +
                   str(selected_strokes) + "_strokes_stroke_width_" +
                   str(stroke_width) + ".svg")
    # `with` closes the handle even if writexml raises.
    with open(output_path, "w") as output:
        doc.writexml(output)

    if verbose:
        print("Details are as follows.")
        # Report how many strokes were actually modified; this can be fewer
        # than `selected_strokes` for short sketches.
        print("{} strokes have been selected randomly.".format(len(random_list)))
        print("The stroke widths have been modified as {}.".format(stroke_width))
        print(torchsketch())

    return stroke_width
コード例 #12
0
def convert_svg_2_accumulative_svgs(svg_url, output_folder, verbose=True):
    """Write one SVG per stroke prefix of a sketch.

    File ``<name>_<k>.svg`` (``k`` zero-padded to six digits) shows only the
    first ``k`` ``<path>`` elements in document order. Later paths are hidden
    via ``visibility="hidden"``. Files are produced for k = stroke_count,
    stroke_count - 1, ..., 1. A sketch without strokes yields a single file
    with k = 0.

    Args:
        svg_url: Path of the input SVG file.
        output_folder: Output directory; passed to ``reset_dir`` before
            writing.
        verbose: When truthy, print progress details.

    Returns:
        The number of strokes in the input sketch.
    """
    doc = minidom.parse(svg_url)
    paths = doc.getElementsByTagName("path")

    assert type(paths) == xml.dom.minicompat.NodeList
    stroke_count = len(paths)

    if verbose:
        print("{} has {} storkes.".format(svg_url, stroke_count))

    reset_dir(output_folder)

    svg_filename = svg_url.split("/")[-1][:-4]

    def _write_snapshot(visible_strokes):
        # Serialise the current DOM state; `with` closes the handle even if
        # writexml raises.
        snapshot_path = (output_folder + "/" + svg_filename + "_" +
                         str(visible_strokes).zfill(6) + ".svg")
        with open(snapshot_path, "w") as output:
            doc.writexml(output)

    try:
        _write_snapshot(stroke_count)

        # Hide strokes from last to first; after hiding path i, exactly the
        # first i paths remain visible.
        for i in range(stroke_count - 1, 0, -1):
            paths[i].setAttribute("visibility", "hidden")
            _write_snapshot(i)

    except Exception as e:
        print(e)

    if verbose:
        print("Details are as follows.")
        print("{} svgs have been produced, and stored in {}.".format(
            stroke_count, output_folder))
        print(torchsketch())

    return stroke_count
コード例 #13
0
def convert_svg_2_gif(svg_url, output_folder, gif_fps=3, verbose=True):
    """Render an SVG sketch as an animated GIF that draws it stroke by stroke.

    The per-prefix SVGs produced by ``convert_svg_2_accumulative_svgs`` are
    ordered by the six-digit index before ``.svg`` in their file names. Each
    is rasterised (SVG -> PDF with ``cairosvg.svg2pdf``, then the first page
    via ``convert_from_path`` at 50 dpi), and both intermediate files are
    deleted. The frames are saved to ``<output_folder>/<name>.gif``.

    Args:
        svg_url: Path of the input SVG file.
        output_folder: Working and output directory. Every file listed in it
            is treated as an accumulative SVG frame, so it should contain
            nothing else when frame collection starts.
        gif_fps: Frame rate passed to ``imageio.mimsave``.
        verbose: When truthy, print progress details (including those of
            ``convert_svg_2_accumulative_svgs``).

    Returns:
        The number of frames in the written GIF.
    """
    # Forward the caller's verbosity. It used to be hard-coded to True, so
    # verbose=False still printed the accumulative step's details and banner.
    convert_svg_2_accumulative_svgs(svg_url=svg_url,
                                    output_folder=output_folder,
                                    verbose=verbose)

    if verbose:
        print("Accumulative svgs have been produced.")

    # File names end in "_<6-digit index>.svg"; sort numerically on the index.
    accumulative_svg_list = sorted(os.listdir(output_folder),
                                   key=lambda name: int(name[-10:-4]))

    gif_frames = []

    for accum_svg in accumulative_svg_list:

        accum_svg_url = os.path.join(output_folder, accum_svg)

        converted_pdf_url = accum_svg_url[:-4] + ".pdf"

        cairosvg.svg2pdf(url=accum_svg_url, write_to=converted_pdf_url)

        gif_frames.append(
            convert_from_path(pdf_path=converted_pdf_url, dpi=50)[0])

        os.remove(accum_svg_url)
        os.remove(converted_pdf_url)

    imageio.mimsave(output_folder + "/" + svg_url.split("/")[-1][:-4] + ".gif",
                    gif_frames,
                    fps=gif_fps)
    gif_frame_count = len(gif_frames)

    if verbose:
        print("Intermediate files have been well deleted.")
        print("The achieved gif file has {} frames in total.".format(
            gif_frame_count))
        print(torchsketch())

    return gif_frame_count
コード例 #14
0
def download_sketchy(output_folder="./", remove_sourcefile=True):
    """Download, checksum-verify and extract the Sketchy dataset.

    Args:
        output_folder: Directory the archives are downloaded to and extracted
            into.
        remove_sourcefile: Forwarded to ``extract_files`` for every archive
            (presumably deletes the archive after extraction — confirm in
            ``extract_files``).

    Returns:
        None. If a download or extraction step raises, the exception is
        printed and the function returns early.

    Raises:
        AssertionError: If a downloaded archive's MD5 digest differs from the
            entry in ``MD5SUM_TABLE``.
    """
    # (Google Drive URL, archive file name) for each part of the dataset.
    archives = [
        ("https://drive.google.com/uc?id=0B7ISyeE8QtDdbUpYWV8tcFJlY2M",
         "sketches-06-04.7z"),
        ("https://drive.google.com/uc?id=0B7ISyeE8QtDdTjE1MG9Gcy1kSkE",
         "rendered_256x256.7z"),
        ("https://drive.google.com/uc?id=0B7ISyeE8QtDdaFhqeTZiNVBYZjA",
         "info-06-04.7z"),
    ]
    success = '\033[32;1m' + 'succesfully' + '\033[0m'

    print(torchsketch())
    time.sleep(2)

    print_banner("Downloading started......")

    try:
        for url, archive_name in archives:
            gdown.download(url,
                           os.path.join(output_folder, archive_name),
                           quiet=False)
    except Exception as e:
        print(e)
        return

    print("\n")
    print_banner("Sketchy is downloaded {}!".format(success))

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would silently skip integrity verification.
    for _, archive_name in archives:
        archive_path = os.path.join(output_folder, archive_name)
        if md5sum(archive_path) != MD5SUM_TABLE[archive_name]:
            raise AssertionError(archive_name + ' md5 checksum error')

    print_banner("Md5 checksum passed {}!".format(success))

    print_banner("Extracting started......")

    try:
        for _, archive_name in archives:
            extract_files(file_name=os.path.join(output_folder, archive_name),
                          output_folder=output_folder,
                          remove_sourcefile=remove_sourcefile)
    except Exception as e:
        print(e)
        return

    print_banner("Sketchy is extracted {}!".format(success))

    print(torchsketch())
コード例 #15
0
def remove_strokes_4_svg(svg_url,
                         output_folder,
                         remove_ratio=0.3,
                         verbose=True):
    """Hide the shortest strokes of an SVG sketch.

    Stroke length is the character length of each ``<path>``'s ``d``
    attribute. The ``int(stroke_count * remove_ratio)`` shortest paths get
    ``visibility="hidden"``, with ties broken by document order. The result
    is written to ``<output_folder>/<name>_partially_removed.svg``.

    Args:
        svg_url: Path of the input SVG file.
        output_folder: Output directory; passed to ``reset_dir`` before
            writing.
        remove_ratio: Fraction of strokes to hide. Must lie in the open
            interval (0.0, 1.0).
        verbose: When truthy, print progress details.

    Returns:
        The number of strokes in the input sketch, or 0 if ``remove_ratio``
        is out of range.
    """
    if remove_ratio <= 0.0 or remove_ratio >= 1.0:
        print(
            "The input parameter of remove_ratio should be defined in (0.0, 1.0)."
        )
        print("Please choose a proper remove_ratio value, and retry this API.")
        print(torchsketch())
        return 0

    doc = minidom.parse(svg_url)
    paths = doc.getElementsByTagName("path")

    assert type(paths) == xml.dom.minicompat.NodeList
    stroke_count = len(paths)
    remove_count = int(stroke_count * remove_ratio)

    stroke_lengths = [len(path.getAttribute('d')) for path in paths]

    assert len(stroke_lengths) == stroke_count

    # Select exactly `remove_count` shortest strokes, matching the count
    # reported below and the random variant. The previous test
    # `length <= sorted_lengths[remove_count]` always hid at least one extra
    # stroke, and it raised IndexError for an SVG without paths. sorted() is
    # stable, so equal lengths keep document order.
    shortest_first = sorted(range(stroke_count),
                            key=stroke_lengths.__getitem__)
    strokes_to_hide = shortest_first[:remove_count]

    if verbose:
        print("{} has {} storkes.".format(svg_url, stroke_count))

    reset_dir(output_folder)

    svg_filename = svg_url.split("/")[-1][:-4]

    try:

        for i in strokes_to_hide:
            paths[i].setAttribute("visibility", "hidden")

        output_path = (output_folder + "/" + svg_filename + "_" +
                       "partially_removed" + ".svg")
        # `with` closes the handle even if writexml raises.
        with open(output_path, "w") as output:
            doc.writexml(output)

    except Exception as e:
        print(e)

    if verbose:
        print("Details are as follows.")
        print("{} strokes have been removed.".format(remove_count))
        print("The produced sketch is stored in {}".format(output_folder))
        print(torchsketch())

    return stroke_count
コード例 #16
0
def mark_longest_strokes_4_svg(svg_url, output_folder, mark_longest=True, color="red", verbose=True):
    """Recolor the longest (or shortest) stroke(s) of an SVG sketch.

    Stroke length is the character length of each ``<path>``'s ``d``
    attribute. Every path whose length equals the maximum (or the minimum
    when ``mark_longest`` is false) gets ``stroke=color``, so tied strokes are
    all marked. The result is written to
    ``<output_folder>/<name>_marked_longest_stroke.svg`` or
    ``<output_folder>/<name>_marked_shortest_stroke.svg``.

    Args:
        svg_url: Path of the input SVG file.
        output_folder: Output directory; passed to ``reset_dir`` before
            writing.
        mark_longest: Mark the longest strokes if true, otherwise the
            shortest.
        color: Color name; must be a key of ``color_table``.
        verbose: When truthy, print progress details.

    Returns:
        The number of strokes marked, or 0 if ``color`` is not recognised.
    """
    # NOTE(review): sibling helpers look colors up in `COLOR_TABLE`; confirm
    # that `color_table` is the name actually defined for this module.
    if color not in color_table.keys():
        print("Please choose a recognizable color.")
        return 0

    doc = minidom.parse(svg_url)
    paths = doc.getElementsByTagName("path")

    assert type(paths) == xml.dom.minicompat.NodeList
    stroke_count = len(paths)

    stroke_lengths = [len(path.getAttribute('d')) for path in paths]

    assert len(stroke_lengths) == stroke_count

    # An SVG without <path> elements previously raised IndexError here; now it
    # simply marks nothing.
    if not stroke_lengths:
        threshold_value = None
    elif mark_longest:
        threshold_value = max(stroke_lengths)
    else:
        threshold_value = min(stroke_lengths)

    if verbose:
        print("{} has {} storkes.".format(svg_url, stroke_count))

    reset_dir(output_folder)

    svg_filename = svg_url.split("/")[-1][:-4]
    suffix = "marked_longest_stroke" if mark_longest else "marked_shortest_stroke"

    marked_count = 0

    try:

        for path, length in zip(paths, stroke_lengths):
            if length == threshold_value:
                path.setAttribute("stroke", color)
                marked_count += 1

        output_path = output_folder + "/" + svg_filename + "_" + suffix + ".svg"
        # `with` closes the handle even if writexml raises.
        with open(output_path, "w") as output:
            doc.writexml(output)

    except Exception as e:
        print(e)

    if verbose:
        print("Details are as follows.")
        print("{} strokes have been marked.".format(marked_count))
        print("The produced sketch is stored in {}".format(output_folder))
        print(torchsketch())

    return marked_count