示例#1
0
def ScaleAndShift_test_simple_3():
    """Check ScaleAndShift on a recording with one point offset in y.

    The only point is (x=0, y=10, time=0); after the preprocessing queue
    has been applied the point list is expected to hold a single point
    with x, y and time all equal to 0.
    """
    queue = [preprocessing.ScaleAndShift()]
    recording = HandwrittenData('[[{"x":0, "y":10, "time": 0}]]')
    recording.preprocessing(queue)

    expectation = [[{"x": 0, "y": 0, "time": 0}]]
    s = recording.get_pointlist()
    assert s == expectation, "Got: %s; expected %s" % (s, expectation)
示例#2
0
def test_ScaleAndShift_test_simple_4():
    """Check ScaleAndShift on a recording with one point offset in time.

    The only point is (x=0, y=0, time=10); after the preprocessing queue
    has been applied the point list is expected to hold a single point
    with x, y and time all equal to 0.
    """
    queue = [preprocessing.ScaleAndShift()]
    recording = HandwrittenData('[[{"x":0, "y":0, "time": 10}]]')
    recording.preprocessing(queue)

    expectation = [[{"x": 0, "y": 0, "time": 0}]]
    s = recording.get_pointlist()
    assert s == expectation, f"Got: {s}; expected {expectation}"
def main(dataset='all'):
    """
    Analyze the stroke counts of the recordings of every symbol.

    For each formula returned by ``get_formulas`` the recordings accepted
    for that formula with ``wild_point_count=0`` and ``has_correction=0``
    are fetched from the ``mysql_online`` database. Stroke count statistics
    are logged, recordings whose stroke count is not one of the modes are
    handed to ``print_exceptions``, and the per-symbol stroke count
    frequencies are written to ``prob_stroke_count_by_symbol.yml``.

    Parameters
    ----------
    dataset : string
        Either 'all' or a path to a yaml symbol file.
    """
    cfg = utils.get_database_configuration()
    mysql = cfg['mysql_online']
    connection = pymysql.connect(host=mysql['host'],
                                 user=mysql['user'],
                                 passwd=mysql['passwd'],
                                 db=mysql['db'],
                                 cursorclass=pymysql.cursors.DictCursor)

    # The formula id is bound by the driver at execute() time instead of
    # being interpolated into the statement text, so the statement is
    # loop-invariant and the id is always escaped.
    sql = ("SELECT `wm_raw_draw_data`.`id`, `data`, `is_in_testset`, "
           "`wild_point_count`, `missing_line`, `user_id`, "
           "`display_name` "
           "FROM `wm_raw_draw_data` "
           "JOIN `wm_users` ON "
           "(`wm_users`.`id` = `wm_raw_draw_data`.`user_id`) "
           "WHERE `accepted_formula_id` = %s "
           "AND wild_point_count=0 "
           "AND has_correction=0 "
           # "AND `display_name` LIKE 'MfrDB::%%'"
           )
    prob = {}

    try:
        cursor = connection.cursor()

        # TODO: no formulas, only single-symbol ones.
        formulas = get_formulas(cursor, dataset)

        # Go through each formula and download every raw_data instance
        for formula in formulas:
            stroke_counts = []
            recordings = []
            cursor.execute(sql, (formula['id'],))
            raw_datasets = cursor.fetchall()
            logging.info("%s (%i)",
                         formula['formula_in_latex'],
                         len(raw_datasets))
            for raw_data in raw_datasets:
                # Best effort: a recording that cannot be loaded is logged
                # and skipped so it does not abort the whole analysis.
                try:
                    handwriting = HandwrittenData(
                        raw_data['data'],
                        formula['id'],
                        raw_data['id'],
                        formula['formula_in_latex'],
                        raw_data['wild_point_count'],
                        raw_data['missing_line'],
                        raw_data['user_id'])
                    stroke_counts.append(len(handwriting.get_pointlist()))
                    recordings.append(handwriting)
                except Exception as e:
                    logging.info("Raw data id: %s", raw_data['id'])
                    logging.info(e)

            if not stroke_counts:
                logging.debug("No recordings for symbol "
                              "'http://www.martin-thoma.de/"
                              "write-math/symbol/?id=%s'.",
                              formula['id'])
                continue

            logging.info("\t[%i - %i]",
                         min(stroke_counts),
                         max(stroke_counts))
            logging.info("\tMedian: %0.2f\tMean: %0.2f\tstd: %0.2f",
                         numpy.median(stroke_counts),
                         numpy.mean(stroke_counts),
                         numpy.std(stroke_counts))

            # Make prob: stroke count -> number of recordings, most
            # frequent stroke count first.
            key = formula['formula_in_latex']
            prob[key] = dict(Counter(stroke_counts).most_common())

            # Outliers
            modes = get_modes(stroke_counts)
            logging.info("\tModes: %s", modes)
            exceptions = []
            # stroke_counts[i] was computed from recordings[i] above, so
            # the point lists do not have to be fetched again.
            for rec, stroke_count in zip(recordings, stroke_counts):
                if stroke_count in modes:
                    continue
                url = (("http://www.martin-thoma.de/"
                        "write-math/view/?raw_data_id=%i - "
                        "%i strokes") % (rec.raw_data_id, stroke_count))
                dist = get_dist(stroke_count, modes)
                exceptions.append((url, stroke_count, dist))
            print_exceptions(exceptions, max_print=10)
    finally:
        # Release the database connection even if the analysis fails.
        connection.close()

    write_prob(prob, "prob_stroke_count_by_symbol.yml")
def main(dataset='all'):
    """
    Analyze the stroke counts of the recordings of every symbol.

    For each formula returned by ``get_formulas`` the recordings accepted
    for that formula with ``wild_point_count=0`` and ``has_correction=0``
    are fetched from the ``mysql_online`` database. Stroke count statistics
    are logged, recordings whose stroke count is not one of the modes are
    handed to ``print_exceptions``, and the per-symbol stroke count
    frequencies are written to ``prob_stroke_count_by_symbol.yml``.

    Parameters
    ----------
    dataset : string
        Either 'all' or a path to a yaml symbol file.
    """
    cfg = utils.get_database_configuration()
    mysql = cfg['mysql_online']
    connection = pymysql.connect(host=mysql['host'],
                                 user=mysql['user'],
                                 passwd=mysql['passwd'],
                                 db=mysql['db'],
                                 cursorclass=pymysql.cursors.DictCursor)

    # The formula id is bound by the driver at execute() time instead of
    # being interpolated into the statement text, so the statement is
    # loop-invariant and the id is always escaped.
    sql = (
        "SELECT `wm_raw_draw_data`.`id`, `data`, `is_in_testset`, "
        "`wild_point_count`, `missing_line`, `user_id`, "
        "`display_name` "
        "FROM `wm_raw_draw_data` "
        "JOIN `wm_users` ON "
        "(`wm_users`.`id` = `wm_raw_draw_data`.`user_id`) "
        "WHERE `accepted_formula_id` = %s "
        "AND wild_point_count=0 "
        "AND has_correction=0 "
        # "AND `display_name` LIKE 'MfrDB::%%'"
    )
    prob = {}

    try:
        cursor = connection.cursor()

        # TODO: no formulas, only single-symbol ones.
        formulas = get_formulas(cursor, dataset)

        # Go through each formula and download every raw_data instance
        for formula in formulas:
            stroke_counts = []
            recordings = []
            cursor.execute(sql, (formula['id'],))
            raw_datasets = cursor.fetchall()
            logging.info("%s (%i)", formula['formula_in_latex'],
                         len(raw_datasets))
            for raw_data in raw_datasets:
                # Best effort: a recording that cannot be loaded is logged
                # and skipped so it does not abort the whole analysis.
                try:
                    handwriting = HandwrittenData(
                        raw_data['data'], formula['id'], raw_data['id'],
                        formula['formula_in_latex'],
                        raw_data['wild_point_count'],
                        raw_data['missing_line'], raw_data['user_id'])
                    stroke_counts.append(len(handwriting.get_pointlist()))
                    recordings.append(handwriting)
                except Exception as e:
                    logging.info("Raw data id: %s", raw_data['id'])
                    logging.info(e)

            if not stroke_counts:
                logging.debug(
                    "No recordings for symbol "
                    "'http://www.martin-thoma.de/"
                    "write-math/symbol/?id=%s'.", formula['id'])
                continue

            logging.info("\t[%i - %i]", min(stroke_counts),
                         max(stroke_counts))
            logging.info("\tMedian: %0.2f\tMean: %0.2f\tstd: %0.2f",
                         numpy.median(stroke_counts),
                         numpy.mean(stroke_counts), numpy.std(stroke_counts))

            # Make prob: stroke count -> number of recordings, most
            # frequent stroke count first.
            key = formula['formula_in_latex']
            prob[key] = dict(Counter(stroke_counts).most_common())

            # Outliers
            modes = get_modes(stroke_counts)
            logging.info("\tModes: %s", modes)
            exceptions = []
            # stroke_counts[i] was computed from recordings[i] above, so
            # the point lists do not have to be fetched again.
            for rec, stroke_count in zip(recordings, stroke_counts):
                if stroke_count in modes:
                    continue
                url = (("http://www.martin-thoma.de/"
                        "write-math/view/?raw_data_id=%i - "
                        "%i strokes") % (rec.raw_data_id, stroke_count))
                dist = get_dist(stroke_count, modes)
                exceptions.append((url, stroke_count, dist))
            print_exceptions(exceptions, max_print=10)
    finally:
        # Release the database connection even if the analysis fails.
        connection.close()

    write_prob(prob, "prob_stroke_count_by_symbol.yml")