Beispiel #1
0
def test_ScaleAndShift_test_a():
    """Run ScaleAndShift with default arguments on a two-stroke recording.

    The resulting point list is compared against hard-coded expected
    coordinates with ``testhelper.compare_pointlists``.
    """
    steps = [preprocessing.ScaleAndShift()]
    raw_json = (
        '[[{"x":232,"y":423,"time":1407885913983},'
        '{"x":267,"y":262,"time":1407885914315},'
        '{"x":325,"y":416,"time":1407885914650}],'
        '[{"x":252,"y":355,"time":1407885915675},'
        '{"x":305,"y":351,"time":1407885916361}]]'
    )
    recording = HandwrittenData(raw_json)
    recording.preprocessing(steps)
    actual = recording.get_pointlist()

    # Expected result, one list per stroke.
    first_stroke = [
        {"y": 1.0, "x": 0.0, "time": 0},
        {"y": 0.0, "x": 0.2174, "time": 332},
        {"y": 0.9565, "x": 0.5776, "time": 667},
    ]
    second_stroke = [
        {"y": 0.5776, "x": 0.1242, "time": 1692},
        {"y": 0.5528, "x": 0.4534, "time": 2378},
    ]
    expectation = [first_stroke, second_stroke]

    is_match = testhelper.compare_pointlists(actual, expectation)
    assert is_match, f"Got: {actual}; expected {expectation}"
Beispiel #2
0
def test_ScaleAndShift_test_a_center():
    """Run ScaleAndShift(center=True) on a two-stroke recording.

    The resulting point list is compared against hard-coded expected
    coordinates with ``testhelper.compare_pointlists``.
    """
    steps = [preprocessing.ScaleAndShift(center=True)]
    raw_json = (
        '[[{"y": 1.0, "x": -0.3655913978494625, "time": 0}, '
        '{"y": 0.0, "x": -0.1482000935016364, "time": 332}, '
        '{"y": 0.9565, "x": 0.21204835370333253, "time": 667}], '
        '[{"y": 0.5776, "x": -0.24136779536499045, "time": 1692}, '
        '{"y": 0.5528, "x": 0.08782475121886046, "time": 2378}]]'
    )
    recording = HandwrittenData(raw_json)
    recording.preprocessing(steps)
    actual = recording.get_pointlist()

    # Expected result, one list per stroke.
    first_stroke = [
        {"y": 1.0, "x": -0.2888198757763975, "time": 0},
        {"y": 0.0, "x": -0.07142857142857142, "time": 332},
        {"y": 0.9565, "x": 0.2888198757763975, "time": 667},
    ]
    second_stroke = [
        {"y": 0.5776, "x": -0.16459627329192547, "time": 1692},
        {"y": 0.5528, "x": 0.16459627329192544, "time": 2378},
    ]
    expectation = [first_stroke, second_stroke]

    is_match = testhelper.compare_pointlists(actual, expectation)
    assert is_match, f"Got: {actual}; expected {expectation}"
 def __init__(self, raw_data_json, formula_id, wild_point_count,
              missing_line, has_hook, has_too_long_line, is_image,
              other_problem, has_interrupted_line, raw_data_id, latex):
     """Initialise the recording and remember its review flags.

     ``raw_data_json`` and ``formula_id`` are forwarded to
     ``HandwrittenData.__init__``; every other argument is stored
     unchanged in an attribute of the same name.
     """
     HandwrittenData.__init__(self, raw_data_json, formula_id)

     # Flags handed in by the caller.
     self.wild_point_count = wild_point_count
     self.missing_line = missing_line
     self.has_hook = has_hook
     self.has_too_long_line = has_too_long_line
     self.is_image = is_image
     self.other_problem = other_problem
     self.has_interrupted_line = has_interrupted_line

     # State that always starts out as False.
     self.unaccept = self.ok = False

     self.raw_data_id = raw_data_id
     self.latex = latex
     self.istrash = False
Beispiel #4
0
def test_ScaleAndShift_test_simple_5():
    """ScaleAndShift on the single point (42, 12) must yield all zeros."""
    steps = [preprocessing.ScaleAndShift()]
    recording = HandwrittenData('[[{"x":42, "y":12, "time": 10}]]')
    recording.preprocessing(steps)
    actual = recording.get_pointlist()
    expectation = [[{"x": 0, "y": 0, "time": 0}]]
    assert actual == expectation, f"Got: {actual}; expected {expectation}"
Beispiel #5
0
def ScaleAndShift_test_simple_4():
    """ScaleAndShift on the single point (0, 0) must yield all zeros."""
    steps = [preprocessing.ScaleAndShift()]
    recording = HandwrittenData('[[{"x":0, "y":0, "time": 10}]]')
    recording.preprocessing(steps)
    actual = recording.get_pointlist()
    expectation = [[{"x": 0, "y": 0, "time": 0}]]
    assert actual == expectation, "Got: %s; expected %s" % (actual,
                                                            expectation)
Beispiel #6
0
def get_all_symbols_as_handwriting():
    """Return one HandwrittenData object per file from ``get_all_symbols``."""

    def _read(path):
        # The file is closed again before the content is parsed.
        with open(path) as handle:
            return handle.read()

    return [HandwrittenData(_read(path)) for path in get_all_symbols()]
Beispiel #7
0
 def __init__(self,
              raw_data_json,
              formula_id,
              wild_point_count,
              missing_line,
              has_hook,
              has_too_long_line,
              is_image,
              other_problem,
              has_interrupted_line,
              raw_data_id,
              latex):
     """Set up the base recording and store the review flags.

     ``HandwrittenData.__init__`` receives ``raw_data_json`` and
     ``formula_id``. The remaining arguments become attributes with the
     same name; ``unaccept``, ``ok`` and ``istrash`` start as False.
     """
     HandwrittenData.__init__(self, raw_data_json, formula_id)

     # Values passed in by the caller.
     self.wild_point_count = wild_point_count
     self.missing_line = missing_line
     self.has_hook = has_hook
     self.has_too_long_line = has_too_long_line
     self.is_image = is_image
     self.other_problem = other_problem
     self.has_interrupted_line = has_interrupted_line

     # Initial state, not controlled by the caller.
     self.unaccept = self.ok = False

     self.raw_data_id = raw_data_id
     self.latex = latex
     self.istrash = False
Beispiel #8
0
def get_recordings(mysql, symbol_id):
    """
    Fetch all recordings whose accepted formula is the given symbol.

    Parameters
    ----------
    mysql : dict
        Connection information; the keys 'host', 'user', 'passwd' and 'db'
        are read.
    symbol_id : int
        ID of a symbol on write-math.com

    Returns
    -------
    list :
        A list of HandwrittenData objects. Rows for which building the
        HandwrittenData object raises are logged and left out.
    """
    # The id is bound by the database driver instead of being pasted into
    # the SQL text.
    sql = ("SELECT `id`, `data`, `is_in_testset`, `wild_point_count`, "
           "`missing_line`, `user_id` "
           "FROM `wm_raw_draw_data` "
           "WHERE `accepted_formula_id` = %s")

    connection = pymysql.connect(host=mysql['host'],
                                 user=mysql['user'],
                                 passwd=mysql['passwd'],
                                 db=mysql['db'],
                                 cursorclass=pymysql.cursors.DictCursor)
    try:
        cursor = connection.cursor()
        cursor.execute(sql, (symbol_id,))
        raw_datasets = cursor.fetchall()
    finally:
        # All rows are in memory now; do not keep the connection open.
        connection.close()

    recordings = []
    for raw_data in raw_datasets:
        try:
            handwriting = HandwrittenData(raw_data['data'],
                                          symbol_id,
                                          raw_data['id'],
                                          "no formula in latex",
                                          raw_data['wild_point_count'],
                                          raw_data['missing_line'],
                                          raw_data['user_id'])
            recordings.append(handwriting)
        except Exception as e:
            # Best effort: one broken row must not abort the whole download.
            logging.info("Raw data id: %s", raw_data['id'])
            logging.info(e)
    return recordings
Beispiel #9
0
def _download_handwriting_datasets(cursor, formulas):
    """Fetch every raw recording of ``formulas`` through ``cursor``.

    Returns a tuple ``(handwriting_datasets, formula_id2latex)``. Rows that
    cannot be turned into a HandwrittenData object are logged and skipped.
    """
    # The formula id is bound by the database driver on each execute call.
    sql = (
        "SELECT `wm_raw_draw_data`.`id`, `data`, `is_in_testset`, "
        "`wild_point_count`, `missing_line`, `user_id`, "
        "`display_name` "
        "FROM `wm_raw_draw_data` "
        "JOIN `wm_users` ON "
        "(`wm_users`.`id` = `wm_raw_draw_data`.`user_id`) "
        "WHERE `accepted_formula_id` = %s "
        # "AND `display_name` LIKE 'MfrDB::%%'"
    )
    handwriting_datasets = []
    formula_id2latex = {}

    # Go through each formula and download every raw_data instance
    for formula in formulas:
        formula_id2latex[formula['id']] = formula['formula_in_latex']
        cursor.execute(sql, (formula['id'],))
        raw_datasets = cursor.fetchall()
        logging.info("%s (%i)", formula['formula_in_latex'], len(raw_datasets))
        for raw_data in raw_datasets:
            try:
                handwriting = HandwrittenData(
                    raw_data['data'],
                    formula['id'],
                    raw_data['id'],
                    formula['formula_in_latex'],
                    raw_data['wild_point_count'],
                    raw_data['missing_line'],
                    raw_data['user_id'],
                    user_name=raw_data['display_name'])
                handwriting_datasets.append({
                    'handwriting': handwriting,
                    'id': raw_data['id'],
                    'formula_id': formula['id'],
                    'formula_in_latex': formula['formula_in_latex'],
                    'is_in_testset': raw_data['is_in_testset']
                })
            except Exception as e:
                # Best effort: one broken row must not abort the backup.
                logging.info("Raw data id: %s", raw_data['id'])
                logging.info(e)
    return handwriting_datasets, formula_id2latex


def _archive_renderings(cursor, archive_name="renderings.tar.bz2"):
    """Write the SVG renderings to a temp folder and pack them as bz2 tar.

    Only the rendering with the latest ``creation_time`` per formula is
    selected. The archive is created relative to the current directory.
    """
    logging.info("Start downloading SVG renderings...")
    # NOTE(review): the temporary folder is not removed afterwards, exactly
    # as before - confirm whether cleanup is wanted.
    svgfolder = tempfile.mkdtemp()
    sql = ("SELECT t1.formula_id, t1.svg from wm_renderings t1 "
           "LEFT JOIN wm_renderings t2 ON t1.formula_id = t2.formula_id "
           "AND t1.creation_time < t2.creation_time "
           "WHERE t2.id is null")
    cursor.execute(sql)
    formulas = cursor.fetchall()
    logging.info("Create svg...")
    for formula in formulas:
        filename = os.path.join(svgfolder,
                                "%s.svg" % str(formula['formula_id']))
        with open(filename, 'wb') as temp_file:
            temp_file.write(formula['svg'])
    # Log the name of the file that is really written.
    logging.info("Tar at %s", os.path.abspath(archive_name))

    # The context manager closes the archive even if adding a file fails.
    with tarfile.open(archive_name, "w:bz2") as tar:
        for fn in os.listdir(svgfolder):
            filename = os.path.join(svgfolder, fn)
            if os.path.isfile(filename):
                print(filename)
                tar.add(filename, arcname=os.path.basename(filename))


def main(destination=os.path.join(utils.get_project_root(), "raw-datasets"),
         dataset='all',
         renderings=False):
    """
    Main part of the backup script.

    Downloads all recordings of the selected formulas from the
    'mysql_online' database and pickles them into a time-stamped file
    inside ``destination``.

    Parameters
    ----------
    destination : string
        Folder in which the pickle file is created.
    dataset : string
        Passed on to ``get_formulas``; also part of the file name, with
        '/' replaced by '-'.
    renderings : bool
        If True, the SVG renderings are additionally stored in
        'renderings.tar.bz2' in the current working directory.
    """
    time_prefix = time.strftime("%Y-%m-%d-%H-%M")
    filename = ("%s-handwriting_datasets-%s-raw.pickle" %
                (time_prefix, dataset.replace('/', '-')))
    destination_path = os.path.join(destination, filename)
    logging.info("Data will be written to '%s'", destination_path)

    cfg = utils.get_database_configuration()
    mysql = cfg['mysql_online']
    connection = pymysql.connect(host=mysql['host'],
                                 user=mysql['user'],
                                 passwd=mysql['passwd'],
                                 db=mysql['db'],
                                 cursorclass=pymysql.cursors.DictCursor)
    try:
        cursor = connection.cursor()

        formulas = get_formulas(cursor, dataset)
        logging.info('Received %i formulas.', len(formulas))
        handwriting_datasets, formula_id2latex = \
            _download_handwriting_datasets(cursor, formulas)

        # Close the output file deterministically instead of relying on
        # garbage collection.
        with open(destination_path, "wb") as pickle_file:
            pickle.dump(
                {
                    'handwriting_datasets': handwriting_datasets,
                    'formula_id2latex': formula_id2latex
                }, pickle_file, 2)

        if renderings:
            _archive_renderings(cursor)
    finally:
        connection.close()
def main(dataset='all'):
    """
    Analyze the number of strokes of the recordings of each symbol.

    For every formula returned by ``get_formulas`` the recordings with
    ``wild_point_count=0`` and ``has_correction=0`` are downloaded. Stroke
    count statistics are logged, recordings whose stroke count is not one
    of the modes are handed to ``print_exceptions`` and the stroke count
    frequencies of all symbols are written with ``write_prob`` to
    'prob_stroke_count_by_symbol.yml'.

    Parameters
    ----------
    dataset : string
        Either 'all' or a path to a yaml symbol file.
    """
    # Built once; the formula id is bound by the database driver on each
    # execute call instead of being pasted into the SQL text.
    sql = ("SELECT `wm_raw_draw_data`.`id`, `data`, `is_in_testset`, "
           "`wild_point_count`, `missing_line`, `user_id`, "
           "`display_name` "
           "FROM `wm_raw_draw_data` "
           "JOIN `wm_users` ON "
           "(`wm_users`.`id` = `wm_raw_draw_data`.`user_id`) "
           "WHERE `accepted_formula_id` = %s "
           "AND wild_point_count=0 "
           "AND has_correction=0 "
           # "AND `display_name` LIKE 'MfrDB::%%'"
           )

    cfg = utils.get_database_configuration()
    mysql = cfg['mysql_online']
    connection = pymysql.connect(host=mysql['host'],
                                 user=mysql['user'],
                                 passwd=mysql['passwd'],
                                 db=mysql['db'],
                                 cursorclass=pymysql.cursors.DictCursor)
    prob = {}
    try:
        cursor = connection.cursor()

        # TODO: no formulas, only single-symbol ones.
        formulas = get_formulas(cursor, dataset)

        # Go through each formula and download every raw_data instance
        for formula in formulas:
            # Both lists are appended to together, so they stay aligned.
            stroke_counts = []
            recordings = []
            cursor.execute(sql, (formula['id'],))
            raw_datasets = cursor.fetchall()
            logging.info("%s (%i)",
                         formula['formula_in_latex'],
                         len(raw_datasets))
            for raw_data in raw_datasets:
                try:
                    handwriting = HandwrittenData(raw_data['data'],
                                                  formula['id'],
                                                  raw_data['id'],
                                                  formula['formula_in_latex'],
                                                  raw_data['wild_point_count'],
                                                  raw_data['missing_line'],
                                                  raw_data['user_id'])
                    stroke_counts.append(len(handwriting.get_pointlist()))
                    recordings.append(handwriting)
                except Exception as e:
                    # Best effort: skip rows that cannot be parsed.
                    logging.info("Raw data id: %s", raw_data['id'])
                    logging.info(e)

            if not stroke_counts:
                logging.debug("No recordings for symbol "
                              "'http://www.martin-thoma.de/"
                              "write-math/symbol/?id=%s'.",
                              formula['id'])
                continue

            logging.info("\t[%i - %i]", min(stroke_counts), max(stroke_counts))
            logging.info("\tMedian: %0.2f\tMean: %0.2f\tstd: %0.2f",
                         numpy.median(stroke_counts),
                         numpy.mean(stroke_counts),
                         numpy.std(stroke_counts))

            # Make prob: stroke count -> frequency, most frequent first
            prob[formula['formula_in_latex']] = dict(
                Counter(stroke_counts).most_common())

            # Outliers
            modes = get_modes(stroke_counts)
            logging.info("\tModes: %s", modes)
            exceptions = []
            # Reuse the stroke counts computed during the download.
            for rec, stroke_count in zip(recordings, stroke_counts):
                if stroke_count not in modes:
                    url = (("http://www.martin-thoma.de/"
                            "write-math/view/?raw_data_id=%i - "
                            "%i strokes") % (rec.raw_data_id, stroke_count))
                    dist = get_dist(stroke_count, modes)
                    exceptions.append((url, stroke_count, dist))
            print_exceptions(exceptions, max_print=10)
    finally:
        connection.close()
    write_prob(prob, "prob_stroke_count_by_symbol.yml")
Beispiel #11
0
def width_test():
    """Check that the recording stored as symbol 97705 has width 186."""
    symbol_path = testhelper.get_symbol(97705)
    with open(symbol_path) as handle:
        raw_json = handle.read()
    width = HandwrittenData(raw_json).get_width()
    assert width == 186, "Got %i" % width
Beispiel #12
0
def load_symbol_test():
    """Check that every symbol file of the test helper can be loaded."""
    for path in testhelper.get_all_symbols():
        with open(path) as handle:
            raw_json = handle.read()
        recording = HandwrittenData(raw_json)
        assert isinstance(recording, HandwrittenData)
Beispiel #13
0
def get_symbol_as_handwriting(raw_data_id):
    """Read the file given by ``get_symbol`` into a HandwrittenData object."""
    with open(get_symbol(raw_data_id)) as handle:
        raw_json = handle.read()
    return HandwrittenData(raw_json)
def main(dataset='all'):
    """
    Collect stroke count statistics for the recordings of each symbol.

    The recordings of every formula returned by ``get_formulas`` are
    downloaded, restricted to ``wild_point_count=0`` and
    ``has_correction=0``. Per formula, the stroke count range, median, mean
    and standard deviation are logged and recordings whose stroke count is
    not among the modes are passed to ``print_exceptions``. Finally the
    stroke count frequencies are written with ``write_prob`` to
    'prob_stroke_count_by_symbol.yml'.

    Parameters
    ----------
    dataset : string
        Either 'all' or a path to a yaml symbol file.
    """
    cfg = utils.get_database_configuration()
    mysql = cfg['mysql_online']
    connection = pymysql.connect(host=mysql['host'],
                                 user=mysql['user'],
                                 passwd=mysql['passwd'],
                                 db=mysql['db'],
                                 cursorclass=pymysql.cursors.DictCursor)

    # The statement does not change between formulas; the id is passed as
    # a query parameter so the driver does the quoting.
    sql = (
        "SELECT `wm_raw_draw_data`.`id`, `data`, `is_in_testset`, "
        "`wild_point_count`, `missing_line`, `user_id`, "
        "`display_name` "
        "FROM `wm_raw_draw_data` "
        "JOIN `wm_users` ON "
        "(`wm_users`.`id` = `wm_raw_draw_data`.`user_id`) "
        "WHERE `accepted_formula_id` = %s "
        "AND wild_point_count=0 "
        "AND has_correction=0 "
        # "AND `display_name` LIKE 'MfrDB::%%'"
    )
    prob = {}
    try:
        cursor = connection.cursor()

        # TODO: no formulas, only single-symbol ones.
        formulas = get_formulas(cursor, dataset)

        # Go through each formula and download every raw_data instance
        for formula in formulas:
            # Kept aligned: entry i of one list belongs to entry i of the
            # other.
            stroke_counts = []
            recordings = []
            cursor.execute(sql, (formula['id'],))
            raw_datasets = cursor.fetchall()
            logging.info("%s (%i)", formula['formula_in_latex'],
                         len(raw_datasets))
            for raw_data in raw_datasets:
                try:
                    handwriting = HandwrittenData(
                        raw_data['data'], formula['id'], raw_data['id'],
                        formula['formula_in_latex'],
                        raw_data['wild_point_count'],
                        raw_data['missing_line'], raw_data['user_id'])
                    stroke_counts.append(len(handwriting.get_pointlist()))
                    recordings.append(handwriting)
                except Exception as e:
                    # Best effort: a broken row is logged and skipped.
                    logging.info("Raw data id: %s", raw_data['id'])
                    logging.info(e)

            if not stroke_counts:
                logging.debug(
                    "No recordings for symbol "
                    "'http://www.martin-thoma.de/"
                    "write-math/symbol/?id=%s'.", formula['id'])
                continue

            logging.info("\t[%i - %i]", min(stroke_counts), max(stroke_counts))
            median = numpy.median(stroke_counts)
            logging.info("\tMedian: %0.2f\tMean: %0.2f\tstd: %0.2f", median,
                         numpy.mean(stroke_counts), numpy.std(stroke_counts))

            # Make prob: frequencies ordered from most to least common
            key = formula['formula_in_latex']
            prob[key] = dict(Counter(stroke_counts).most_common())

            # Outliers
            modes = get_modes(stroke_counts)
            logging.info("\tModes: %s", modes)
            exceptions = []
            # The stroke counts from the download are reused here instead
            # of calling get_pointlist() again for every recording.
            for rec, nr_of_strokes in zip(recordings, stroke_counts):
                if nr_of_strokes in modes:
                    continue
                url = (("http://www.martin-thoma.de/"
                        "write-math/view/?raw_data_id=%i - "
                        "%i strokes") % (rec.raw_data_id, nr_of_strokes))
                dist = get_dist(nr_of_strokes, modes)
                exceptions.append((url, nr_of_strokes, dist))
            print_exceptions(exceptions, max_print=10)
    finally:
        connection.close()
    write_prob(prob, "prob_stroke_count_by_symbol.yml")