示例#1
0
class EvalEngine:
    """Evaluate repairs against ground-truth ("clean") data.

    Ground truth is loaded into the database with :meth:`load_data` and is then
    compared with the repair pipeline's auxiliary tables
    (``AuxTables.cell_domain``, ``AuxTables.inf_values_dom``,
    ``AuxTables.dk_cells``) and with the initial data (``self.ds.raw_data``).
    """

    def __init__(self, env, dataset):
        """
        :param env: environment/configuration object (stored as ``self.env``).
        :param dataset: dataset wrapper; this class reads its ``engine``,
            ``raw_data``, ``categorical_attrs`` and ``numerical_attrs``.
        """
        self.env = env
        self.ds = dataset

    def load_data(self,
                  name,
                  fpath,
                  tid_col,
                  attr_col,
                  val_col,
                  na_values=None):
        """Load ground-truth data from a CSV file and store it in the database.

        The CSV has one row per cell: a tuple-id column, an attribute column
        and a value column. These are renamed to ``_tid_``, ``_attribute_``
        and ``_value_``, stored as table ``name``, and indexed on ``_tid_``
        and ``_attribute_``.

        :param name: name of the ground-truth table to create.
        :param fpath: path to the CSV file.
        :param tid_col: name of the tuple-id column (cast to int).
        :param attr_col: name of the attribute column.
        :param val_col: name of the ground-truth value column.
        :param na_values: extra strings to treat as NaN (forwarded to
            ``pandas.read_csv``).
        :return: ``(status message, elapsed seconds)``.
        """
        # time.clock() was removed in Python 3.8; perf_counter is the
        # recommended replacement for measuring elapsed time.
        tic = time.perf_counter()
        try:
            raw_data = pd.read_csv(fpath,
                                   na_values=na_values,
                                   dtype=str,
                                   encoding='utf-8')
            # We drop any ground truth values that are NULLs since we follow
            # the closed-world assumption (if it's not there it's wrong).
            # TODO: revisit this once we allow users to specify which
            # attributes may be NULL.
            raw_data.dropna(subset=[val_col], inplace=True)
            raw_data.rename(
                {
                    tid_col: '_tid_',
                    attr_col: '_attribute_',
                    val_col: '_value_'
                },
                axis='columns',
                inplace=True)
            # .copy() so the assignments below act on an independent frame
            # rather than a slice (avoids pandas' SettingWithCopyWarning).
            raw_data = raw_data[['_tid_', '_attribute_', '_value_']].copy()
            raw_data['_tid_'] = raw_data['_tid_'].astype(int)

            # Normalize categorical values: strip whitespace and lower-case.
            categorical_attrs = self.ds.categorical_attrs
            if categorical_attrs:
                cat_cells = raw_data['_attribute_'].isin(categorical_attrs)
                raw_data.loc[cat_cells, '_value_'] = \
                    raw_data.loc[cat_cells, '_value_'].astype(str).str.strip().str.lower()

            self.clean_data = Table(name, Source.DF, df=raw_data)
            self.clean_data.store_to_db(self.ds.engine.engine)
            self.clean_data.create_db_index(self.ds.engine, ['_tid_'])
            self.clean_data.create_db_index(self.ds.engine, ['_attribute_'])
            status = 'DONE Loading {fname}'.format(
                fname=os.path.basename(fpath))
        except Exception:
            logging.error('load_data for table %s', name)
            raise
        toc = time.perf_counter()
        load_time = toc - tic
        return status, load_time

    def eval_report(self, attr=None):
        """
        Returns the experiment results as ``(report_str, report_time, report)``,
        where ``report`` is an EvalReport named tuple.

        :param attr: if attr is not None, compute results for attr:
                        if attr is numerical, then only report rmse
                        if attr is categorical, then report precision, recall etc.
                     if attr is None, compute results for all attrs
        """
        tic = time.perf_counter()
        eval_report_dict = {}
        # attr is not None and is numerical
        # or attr is None (query on all attrs)
        if attr is None or attr in self.ds.numerical_attrs:
            eval_report_dict['rmse'] = self.compute_rmse(attr) or 0.

        if attr is None or attr in self.ds.categorical_attrs:
            # attr is not None and is categorical, or
            # attr is None and we should query on all categorical attrs
            correct_repairs = self.compute_correct_repairs(attr)
            total_repairs = self.compute_total_repairs(attr)
            detected_errors = self.compute_detected_errors(attr)
            total_errors = self.compute_total_errors(attr)
            total_repairs_grdt_correct, total_repairs_grdt_incorrect, \
                total_repairs_grdt = self.compute_total_repairs_grdt(attr)

            eval_report_dict['detected_errors'] = detected_errors
            eval_report_dict['total_errors'] = total_errors
            eval_report_dict['correct_repairs'] = correct_repairs
            eval_report_dict['total_repairs'] = total_repairs
            eval_report_dict['total_repairs_grdt'] = total_repairs_grdt
            eval_report_dict[
                'total_repairs_grdt_correct'] = total_repairs_grdt_correct
            eval_report_dict[
                'total_repairs_grdt_incorrect'] = total_repairs_grdt_incorrect

            eval_report_dict['precision'] = self.compute_precision(
                correct_repairs, total_repairs_grdt)
            eval_report_dict['recall'] = self.compute_recall(
                correct_repairs, total_errors)
            eval_report_dict['repair_recall'] = self.compute_repairing_recall(
                correct_repairs, detected_errors)
            eval_report_dict['f1'] = self.compute_f1(correct_repairs,
                                                     total_errors,
                                                     total_repairs_grdt)
            eval_report_dict['repair_f1'] = self.compute_repairing_f1(
                correct_repairs, detected_errors, total_repairs_grdt)

        # NOTE(review): fields not set above rely on EvalReport's defaults;
        # confirm they are numeric, since every field is formatted with
        # %d / %.2f below.
        report = EvalReport(**eval_report_dict)
        report_str = "Precision = %.2f, Recall = %.2f, Repairing Recall = %.2f, " \
                     "F1 = %.2f, Repairing F1 = %.2f, Detected Errors = %d, " \
                     "Total Errors = %d, Correct Repairs = %d, Total Repairs = %d, " \
                     "Total Repairs on correct cells (Grdth present) = %d, " \
                     "Total Repairs on incorrect cells (Grdth present) = %d, " \
                     "RMSE = %.2f" % (report.precision, report.recall, report.repair_recall, report.f1, report.repair_f1,
                                     report.detected_errors, report.total_errors, report.correct_repairs,
                                     report.total_repairs,
                                     report.total_repairs_grdt_correct, report.total_repairs_grdt_incorrect,
                                     report.rmse)

        if attr:
            report_str = "# Attribute:{};{}".format(attr, report_str)

        toc = time.perf_counter()
        report_time = toc - tic
        return report_str, report_time, report

    # All the compute_xxx count methods below (except compute_rmse) are aimed
    # at categorical attributes: numerical attributes must not be involved in
    # computing precision, recall, etc.

    def get_categorical_clause(self, attr):
        """
        Build the SQL condition restricting ``t1.attribute`` to categorical attrs.

        :param attr: if given, restrict to this attribute only; if falsy,
            restrict to all categorical attributes.
        :return: a string ``"t1.attribute IN ('a1','a2',...)"`` (no leading AND).
        """
        query_attrs = [attr] if attr else self.ds.categorical_attrs
        # Double any single quote so names containing "'" still form a valid
        # SQL string literal.
        quoted = ["'{}'".format(str(a).replace("'", "''")) for a in query_attrs]
        return 't1.attribute IN ({})'.format(','.join(quoted))

    def _attr_clause(self, attr):
        """Return the WHERE condition used by the categorical count queries.

        When ``attr`` is None and the dataset has no numerical attributes,
        every attribute is categorical, so no restriction is needed
        (``"TRUE"``). In every other case -- in particular when a specific
        ``attr`` is requested -- restrict via :meth:`get_categorical_clause`.
        """
        if attr is None and not self.ds.numerical_attrs:
            return "TRUE"
        return self.get_categorical_clause(attr)

    def compute_total_repairs(self, attr=None):
        """
        Count the repairs: cells that were inferred and whose inferred value
        differs from the initial value.

        :param attr: if attr is not None, it must be categorical
        :return: the count as a float (0. if there are no categorical attrs).
        """
        assert attr is None or attr in self.ds.categorical_attrs

        if not self.ds.categorical_attrs:
            return 0.
        attr_clause = self._attr_clause(attr)
        query = "SELECT count(*) FROM " \
                "  (SELECT _vid_ " \
                '     FROM "{}" as t1, "{}" as t2 ' \
                "    WHERE t1._tid_ = t2._tid_ " \
                "      AND t1.attribute = t2.attribute " \
                "      AND t1.init_value != t2.rv_value " \
                "      AND {}) AS t".format(AuxTables.cell_domain.name,
                                            AuxTables.inf_values_dom.name,
                                            attr_clause)
        res = self.ds.engine.execute_query(query)
        return float(res[0][0])

    def compute_total_repairs_grdt(self, attr=None):
        """
        Count the repairs on cells that have a ground-truth value. Otherwise,
        repairs are defined the same as in compute_total_repairs.

        Repairs on correct cells (init value is one of the '|'-separated
        ground-truth values) are distinguished from repairs on incorrect cells.

        :param attr: if attr is not None, it must be categorical
        :return: ``(repairs on correct cells, repairs on incorrect cells, total)``
        """
        assert attr is None or attr in self.ds.categorical_attrs
        if not self.ds.categorical_attrs:
            # Callers unpack three values, so always return a triple.
            return 0., 0., 0.
        attr_clause = self._attr_clause(attr)

        query = """
            SELECT
                (t1.init_value = ANY(string_to_array(regexp_replace(t3._value_,\'[{{\"\"}}]\',\'\',\'gi\'),\'|\'))) AS is_correct,
                count(*)
            FROM   "{}" as t1, "{}" as t2, "{}" as t3
            WHERE  t1._tid_ = t2._tid_
              AND  t1.attribute = t2.attribute
              AND  t1.init_value != t2.rv_value
              AND  t1._tid_ = t3._tid_
              AND  t1.attribute = t3._attribute_
              AND  {}
            GROUP BY is_correct
              """.format(AuxTables.cell_domain.name,
                         AuxTables.inf_values_dom.name, self.clean_data.name,
                         attr_clause)

        res = self.ds.engine.execute_query(query)

        # GROUP BY returns one row per distinct is_correct value in no
        # guaranteed order, so accumulate by key rather than by position.
        # A NULL is_correct (comparison involving NULL) counts as incorrect.
        total_repairs_grdt_correct = 0.
        total_repairs_grdt_incorrect = 0.
        for row in res:
            if row[0]:
                total_repairs_grdt_correct += float(row[1])
            else:
                total_repairs_grdt_incorrect += float(row[1])
        total_repairs_grdt = total_repairs_grdt_correct + total_repairs_grdt_incorrect

        return total_repairs_grdt_correct, total_repairs_grdt_incorrect, total_repairs_grdt

    def compute_total_errors(self, attr=None):
        """
        Count the cells whose initial value is wrong (requires ground truth).

        :param attr: if attr is not None, it must be categorical
        :return: the count as a float.
        """
        assert attr is None or attr in self.ds.categorical_attrs
        if not self.ds.categorical_attrs:
            return 0.

        query_attrs = [attr] if attr else self.ds.categorical_attrs
        queries = [
            errors_template.substitute(init_table=self.ds.raw_data.name,
                                       grdt_table=self.clean_data.name,
                                       attr=query_attr)
            for query_attr in query_attrs
        ]
        results = self.ds.engine.execute_queries(queries)
        return sum((float(res[0][0]) for res in results), 0.0)

    def compute_detected_errors(self, attr=None):
        """
        Count the error cells (init value not among the ground-truth values)
        that were flagged by error detection (present in dk_cells).

        This value is always equal or less than total errors (see
        compute_total_errors).

        :param attr: if attr is not None, it must be categorical
        :return: the count as a float.
        """
        assert attr is None or attr in self.ds.categorical_attrs
        if not self.ds.categorical_attrs:
            return 0.
        attr_clause = self._attr_clause(attr)

        query = "SELECT count(*) FROM " \
                "  (SELECT _vid_ " \
                '   FROM "{}" as t1, "{}" as t2, "{}" as t3 ' \
                "   WHERE t1._tid_ = t2._tid_ " \
                "     AND t1._cid_ = t3._cid_ " \
                "     AND t1.attribute = t2._attribute_ " \
                "     AND NOT t1.init_value = ANY(string_to_array(regexp_replace(t2._value_,\'[{{\"\"}}]\',\'\',\'gi\'),\'|\')) " \
                "     AND {}) AS t".format(AuxTables.cell_domain.name,
                                           self.clean_data.name,
                                           AuxTables.dk_cells.name,
                                           attr_clause)
        res = self.ds.engine.execute_query(query)
        return float(res[0][0])

    def compute_correct_repairs(self, attr=None):
        """
        Count the error cells that were correctly inferred.

        This value is always equal or less than total errors (see
        compute_total_errors).

        :param attr: if attr is not None, it must be categorical
        :return: the count as a float.
        """
        assert attr is None or attr in self.ds.categorical_attrs
        if not self.ds.categorical_attrs:
            return 0.

        query_attrs = [attr] if attr else self.ds.categorical_attrs
        queries = [
            correct_repairs_template.substitute(
                init_table=self.ds.raw_data.name,
                grdt_table=self.clean_data.name,
                attr=query_attr,
                inf_dom=AuxTables.inf_values_dom.name)
            for query_attr in query_attrs
        ]
        results = self.ds.engine.execute_queries(queries)
        return sum((float(res[0][0]) for res in results), 0.0)

    def compute_rmse(self, attr=None):
        """
        Compute the RMS error of the inferred values for numerical attributes.

        :param attr: if given, must be numerical; otherwise all numerical
            attributes are used.
        :return: the first column of the first result row of ``rmse_template``
            (may be None when there are no rows -- callers use ``or 0.``).
        """
        assert attr is None or attr in self.ds.numerical_attrs
        if not self.ds.numerical_attrs:
            return 0.

        query_attrs = [attr] if attr else self.ds.numerical_attrs
        quoted = ["'{}'".format(str(a).replace("'", "''")) for a in query_attrs]
        query_attrs_sql = '({})'.format(','.join(quoted))
        query = rmse_template.substitute(grdt_table=self.clean_data.name,
                                         inf_dom=AuxTables.inf_values_dom.name,
                                         attrs_list=query_attrs_sql)
        res = self.ds.engine.execute_query(query)

        return res[0][0]

    @staticmethod
    def compute_recall(correct_repairs, total_errors):
        """
        Computes the recall (# of correct repairs / # of total errors).
        Returns 0 when there are no errors.
        """
        if total_errors == 0:
            return 0
        return correct_repairs / total_errors

    @staticmethod
    def compute_repairing_recall(correct_repairs, detected_errors):
        """
        Computes the _repairing_ recall (# of correct repairs / # of total
        _detected_ errors). Returns 0 when no errors were detected.
        """
        if detected_errors == 0:
            return 0
        return correct_repairs / detected_errors

    @staticmethod
    def compute_precision(correct_repairs, total_repairs_grdt):
        """
        Computes precision (# correct repairs / # of total repairs w/ ground truth).
        Returns 0 when there are no such repairs.
        """
        if total_repairs_grdt == 0:
            return 0
        return correct_repairs / total_repairs_grdt

    @staticmethod
    def compute_f1(correct_repairs, total_errors, total_repairs_grdt):
        """Harmonic mean of precision and recall (0 when both are 0)."""
        prec = EvalEngine.compute_precision(correct_repairs,
                                            total_repairs_grdt)
        rec = EvalEngine.compute_recall(correct_repairs, total_errors)
        if prec + rec == 0:
            return 0
        return 2 * (prec * rec) / (prec + rec)

    @staticmethod
    def compute_repairing_f1(correct_repairs, detected_errors,
                             total_repairs_grdt):
        """Harmonic mean of precision and repairing recall (0 when both are 0)."""
        prec = EvalEngine.compute_precision(correct_repairs,
                                            total_repairs_grdt)
        rec = EvalEngine.compute_repairing_recall(correct_repairs,
                                                  detected_errors)
        if prec + rec == 0:
            return 0
        return 2 * (prec * rec) / (prec + rec)

    def log_weak_label_stats(self):
        """Log, at DEBUG level, a breakdown of ground-truth cells.

        Cells are grouped by: whether they are absent from dk_cells
        (``is_clean``), their ``fixed`` status, and whether the initial value,
        weak label and inferred value match the ground truth / each other.
        """
        query = """
        select
            (t3._tid_ is NULL) as clean,
            (t1.fixed) as status,
            (t1.init_value =  ANY(string_to_array(regexp_replace(t2._value_,\'[{{\"\"}}]\',\'\',\'gi\'),\'|\'))) as init_eq_grdth,
            (t1.weak_label = ANY(string_to_array(regexp_replace(t2._value_,\'[{{\"\"}}]\',\'\',\'gi\'),\'|\'))) as wl_eq_grdth,
            (t1.weak_label = t4.rv_value) as wl_eq_infer,
            (t4.rv_value = ANY(string_to_array(regexp_replace(t2._value_,\'[{{\"\"}}]\',\'\',\'gi\'),\'|\'))) as infer_eq_grdth,
            count(*) as count
        from
            "{cell_domain}" as t1,
            "{clean_data}" as t2
            left join "{dk_cells}" as t3 on t2._tid_ = t3._tid_ and t2._attribute_ = t3.attribute
            left join "{inf_values_dom}" as t4 on t2._tid_ = t4._tid_ and t2._attribute_ = t4.attribute where t1._tid_ = t2._tid_ and t1.attribute = t2._attribute_
        group by
            clean,
            status,
            init_eq_grdth,
            wl_eq_grdth,
            wl_eq_infer,
            infer_eq_grdth
        """.format(cell_domain=AuxTables.cell_domain.name,
                   clean_data=self.clean_data.name,
                   dk_cells=AuxTables.dk_cells.name,
                   inf_values_dom=AuxTables.inf_values_dom.name)

        res = self.ds.engine.execute_query(query)

        df_stats = pd.DataFrame(res,
                                columns=[
                                    "is_clean", "cell_status", "init = grdth",
                                    "wlabel = grdth", "wlabel = infer",
                                    "infer = grdth", "count"
                                ])
        df_stats = df_stats.sort_values(list(
            df_stats.columns)).reset_index(drop=True)
        # option_context restores the previous display options on exit, even
        # if logging raises. max_colwidth=None means "no limit"; the old -1
        # sentinel is deprecated since pandas 1.0 and rejected by newer pandas.
        with pd.option_context('display.max_columns', None,
                               'display.max_rows', len(df_stats),
                               'display.max_colwidth', None):
            logging.debug(
                "weak label statistics: (cell_status: 0 - none, 1 - wlabelled, 2 - single value)\n%s",
                df_stats)
示例#2
0
class EvalEngine:
    """Evaluate repairs against ground-truth ("clean") data.

    The count methods (``compute_total_repairs``, ``compute_total_repairs_grdt``,
    ``compute_total_errors``, ``compute_detected_errors``,
    ``compute_correct_repairs``) store their results on the instance. The
    ratio methods (precision/recall/F1) read those attributes, so the counts
    must be computed first, as :meth:`evaluate_repairs` does.
    """

    def __init__(self, env, dataset):
        """
        :param env: environment/configuration object (stored as ``self.env``).
        :param dataset: dataset wrapper; this class reads its ``engine``,
            ``raw_data`` and ``get_attributes()``.
        """
        self.env = env
        self.ds = dataset

    def load_data(self,
                  name,
                  f_path,
                  f_name,
                  get_tid,
                  get_attr,
                  get_val,
                  na_values=None):
        """Load ground-truth data from a CSV file and store it in the database.

        :param name: name of the ground-truth table to create.
        :param f_path: directory containing the CSV file.
        :param f_name: CSV file name.
        :param get_tid: row -> tuple id (applied with ``DataFrame.apply(axis=1)``).
        :param get_attr: row -> attribute name.
        :param get_val: row -> ground-truth value.
        :param na_values: extra strings to treat as NaN (forwarded to
            ``pandas.read_csv``); NaNs are then replaced by ``'_nan_'``.
        :return: ``(status message, elapsed seconds)``.
        """
        # time.clock() was removed in Python 3.8.
        tic = time.perf_counter()
        try:
            raw_data = pd.read_csv(os.path.join(f_path, f_name),
                                   na_values=na_values)
            raw_data.fillna('_nan_', inplace=True)
            raw_data['_tid_'] = raw_data.apply(get_tid, axis=1)
            raw_data['_attribute_'] = raw_data.apply(get_attr, axis=1)
            raw_data['_value_'] = raw_data.apply(get_val, axis=1)
            # .copy() so the assignments below act on an independent frame
            # rather than a slice (avoids pandas' SettingWithCopyWarning).
            raw_data = raw_data[['_tid_', '_attribute_', '_value_']].copy()
            # Normalize attributes to lower-case and strip value whitespace.
            # astype(str) first: read_csv may infer numeric dtypes, and
            # calling .lower()/.strip() on a number raises AttributeError.
            raw_data['_attribute_'] = \
                raw_data['_attribute_'].astype(str).str.lower()
            raw_data['_value_'] = raw_data['_value_'].astype(str).str.strip()
            self.clean_data = Table(name, Source.DF, raw_data)
            self.clean_data.store_to_db(self.ds.engine.engine)
            self.clean_data.create_db_index(self.ds.engine, ['_tid_'])
            self.clean_data.create_db_index(self.ds.engine, ['_attribute_'])
            status = 'DONE Loading ' + f_name
        except Exception:
            logging.error('load_data for table %s', name)
            raise
        toc = time.perf_counter()
        load_time = toc - tic
        return status, load_time

    def evaluate_repairs(self):
        """Compute all counts, then return ``(prec, rec, rep_recall, f1, rep_f1)``."""
        self.compute_total_repairs()
        self.compute_total_repairs_grdt()
        self.compute_total_errors()
        self.compute_detected_errors()
        self.compute_correct_repairs()
        prec = self.compute_precision()
        rec = self.compute_recall()
        rep_recall = self.compute_repairing_recall()
        f1 = self.compute_f1()
        rep_f1 = self.compute_repairing_f1()
        return prec, rec, rep_recall, f1, rep_f1

    def eval_report(self):
        """Return ``(report string, elapsed seconds)`` for the current repairs."""
        tic = time.perf_counter()
        prec, rec, rep_recall, f1, rep_f1 = self.evaluate_repairs()
        report = "Precision = %.2f, Recall = %.2f, Repairing Recall = %.2f, F1 = %.2f, Repairing F1 = %.2f, Detected Errors = %d, Total Errors = %d, Correct Repairs = %d, Total Repairs = %d, Total Repairs (Grdth present) = %d" % (
            prec, rec, rep_recall, f1, rep_f1, self.detected_errors,
            self.total_errors, self.correct_repairs, self.total_repairs,
            self.total_repairs_grdt)
        toc = time.perf_counter()
        report_time = toc - tic
        return report, report_time

    def compute_total_repairs(self):
        """Set ``self.total_repairs``: inferred cells whose value differs from init."""
        query = "SELECT count(*) FROM " \
                "(SELECT _vid_ " \
                "FROM %s as t1, %s as t2 " \
                "WHERE t1._tid_ = t2._tid_ " \
                "AND t1.attribute = t2.attribute " \
                "AND t1.init_value != t2.rv_value) AS t" \
                % (AuxTables.cell_domain.name, AuxTables.inf_values_dom.name)
        res = self.ds.engine.execute_query(query)
        self.total_repairs = float(res[0][0])

    def compute_total_repairs_grdt(self):
        """Set ``self.total_repairs_grdt``: repairs on cells with ground truth."""
        query = "SELECT count(*) FROM " \
                "(SELECT _vid_ " \
                "FROM %s as t1, %s as t2, %s as t3 " \
                "WHERE t1._tid_ = t2._tid_ " \
                "AND t1.attribute = t2.attribute " \
                "AND t1.init_value != t2.rv_value " \
                "AND t1._tid_ = t3._tid_ " \
                "AND t1.attribute = t3._attribute_) AS t" \
                % (AuxTables.cell_domain.name, AuxTables.inf_values_dom.name, self.clean_data.name)
        res = self.ds.engine.execute_query(query)
        self.total_repairs_grdt = float(res[0][0])

    def compute_total_errors(self):
        """Set ``self.total_errors``: cells whose initial value is wrong."""
        queries = [
            errors_template.substitute(init_table=self.ds.raw_data.name,
                                       grdt_table=self.clean_data.name,
                                       attr=attr)
            for attr in self.ds.get_attributes()
        ]
        results = self.ds.engine.execute_queries(queries)
        self.total_errors = sum((float(res[0][0]) for res in results), 0.0)

    def compute_total_errors_grdt(self):
        """Alias of :meth:`compute_total_errors` (it ran the identical queries)."""
        self.compute_total_errors()

    def compute_detected_errors(self):
        """Set ``self.detected_errors``: error cells flagged in dk_cells."""
        query = "SELECT count(*) FROM " \
                "(SELECT _vid_ " \
                "FROM %s as t1, %s as t2, %s as t3 " \
                "WHERE t1._tid_ = t2._tid_ AND t1._cid_ = t3._cid_ " \
                "AND t1.attribute = t2._attribute_ " \
                "AND t1.init_value != t2._value_) AS t" \
                % (AuxTables.cell_domain.name, self.clean_data.name, AuxTables.dk_cells.name)
        res = self.ds.engine.execute_query(query)
        self.detected_errors = float(res[0][0])

    def compute_correct_repairs(self):
        """Set ``self.correct_repairs``: error cells that were correctly inferred."""
        queries = [
            correct_repairs_template.substitute(
                init_table=self.ds.raw_data.name,
                grdt_table=self.clean_data.name,
                attr=attr,
                inf_dom=AuxTables.inf_values_dom.name)
            for attr in self.ds.get_attributes()
        ]
        results = self.ds.engine.execute_queries(queries)
        self.correct_repairs = sum((float(res[0][0]) for res in results), 0.0)

    def compute_recall(self):
        """Correct repairs / total errors; 0 when there are no errors."""
        if self.total_errors == 0:
            return 0
        return self.correct_repairs / self.total_errors

    def compute_repairing_recall(self):
        """Correct repairs / detected errors; 0 when none were detected."""
        if self.detected_errors == 0:
            return 0
        return self.correct_repairs / self.detected_errors

    def compute_precision(self):
        """Correct repairs / repairs with ground truth; 0 when there are none."""
        if self.total_repairs_grdt == 0:
            return 0
        return self.correct_repairs / self.total_repairs_grdt

    def compute_f1(self):
        """Harmonic mean of precision and recall; 0 when both are 0."""
        prec = self.compute_precision()
        rec = self.compute_recall()
        if prec + rec == 0:
            return 0
        return 2 * (prec * rec) / (prec + rec)

    def compute_repairing_f1(self):
        """Harmonic mean of precision and repairing recall; 0 when both are 0."""
        prec = self.compute_precision()
        rec = self.compute_repairing_recall()
        if prec + rec == 0:
            return 0
        return 2 * (prec * rec) / (prec + rec)
示例#3
0
class EvalEngine:
    def __init__(self, env, dataset):
        self.env = env
        self.ds = dataset

    def load_data(self,
                  name,
                  fpath,
                  tid_col,
                  attr_col,
                  val_col,
                  na_values=None):
        """
        Load ground-truth (clean) data from a CSV file into a DB table.

        The CSV holds one row per cell: a tuple-id column, an attribute-name
        column and a value column. They are renamed to ``_tid_``,
        ``_attribute_`` and ``_value_``, values are stripped and lowercased,
        and the table is stored in the database and indexed on ``_tid_`` and
        ``_attribute_``. The table object is kept on ``self.clean_data``.

        :param name: name of the table to create.
        :param fpath: path of the CSV file.
        :param tid_col: column holding the tuple id.
        :param attr_col: column holding the attribute name.
        :param val_col: column holding the ground-truth value.
        :param na_values: extra strings treated as missing (forwarded to
            ``pandas.read_csv``).
        :return: ``(status message, elapsed seconds)``.
        :raises: re-raises (after logging) any error while loading/storing.
        """
        # time.clock() was removed in Python 3.8; perf_counter is the
        # documented replacement for measuring elapsed time.
        tic = time.perf_counter()
        try:
            # Read every field as text: with inferred dtypes, numeric-looking
            # values lose their spelling (e.g. '01' -> 1) and the ``.str``
            # accessor below yields NaN for non-string cells (or fails when
            # the whole column is numeric).
            raw_data = pd.read_csv(fpath,
                                   na_values=na_values,
                                   dtype=str,
                                   encoding='utf-8')
            # We drop any ground truth values that are NULLs since we follow
            # the closed-world assumption (if it's not there it's wrong).
            # TODO: revisit this once we allow users to specify which
            # attributes may be NULL.
            raw_data.dropna(subset=[val_col], inplace=True)
            raw_data.fillna(NULL_REPR, inplace=True)
            raw_data.rename(
                {
                    tid_col: '_tid_',
                    attr_col: '_attribute_',
                    val_col: '_value_'
                },
                axis='columns',
                inplace=True)
            # .copy() so the column assignments below act on an independent
            # frame, not on a slice of the full CSV frame.
            raw_data = raw_data[['_tid_', '_attribute_', '_value_']].copy()
            # Keep tuple ids integral (as they were when pandas inferred the
            # dtype) so joins on _tid_ keep working despite dtype=str above.
            raw_data['_tid_'] = raw_data['_tid_'].astype(int)
            # Normalize values: strip surrounding whitespace and lowercase.
            raw_data['_value_'] = raw_data['_value_'].str.strip().str.lower()
            self.clean_data = Table(name, Source.DF, df=raw_data)
            self.clean_data.store_to_db(self.ds.engine.engine)
            self.clean_data.create_db_index(self.ds.engine, ['_tid_'])
            self.clean_data.create_db_index(self.ds.engine, ['_attribute_'])
            status = 'DONE Loading {fname}'.format(
                fname=os.path.basename(fpath))
        except Exception:
            logging.error('load_data for table %s', name)
            raise
        toc = time.perf_counter()
        load_time = toc - tic
        return status, load_time

    def evaluate_repairs(self):
        """
        Run every counting query, then derive the evaluation metrics.

        The counts are memoized on ``self`` before any metric is computed,
        because the metric methods read them. When ``env['verbose']`` is set,
        weak-label statistics are logged as well.

        :return: ``(precision, recall, repairing recall, F1, repairing F1)``.
        """
        count_steps = (self.compute_total_repairs,
                       self.compute_total_repairs_grdt,
                       self.compute_total_errors,
                       self.compute_detected_errors,
                       self.compute_correct_repairs)
        for step in count_steps:
            step()

        metrics = (self.compute_precision(),
                   self.compute_recall(),
                   self.compute_repairing_recall(),
                   self.compute_f1(),
                   self.compute_repairing_f1())

        if self.env['verbose']:
            self.log_weak_label_stats()

        return metrics

    def eval_report(self):
        """
        Evaluate the repairs and summarize the experiment results.

        :return: ``(report, report_time, eval_report)`` where ``report`` is a
            human-readable summary string, ``report_time`` the seconds spent
            producing it and ``eval_report`` an ``EvalReport`` named tuple
            holding the same numbers.
        :raises: re-raises (after logging) any error raised while evaluating.
        """
        # time.clock() was removed in Python 3.8; perf_counter is the
        # documented replacement for measuring elapsed time.
        tic = time.perf_counter()
        try:
            prec, rec, rep_recall, f1, rep_f1 = self.evaluate_repairs()
            report = "Precision = %.2f, Recall = %.2f, Repairing Recall = %.2f, F1 = %.2f, Repairing F1 = %.2f, Detected Errors = %d, Total Errors = %d, Correct Repairs = %d, Total Repairs = %d, Total Repairs on correct cells (Grdth present) = %d, Total Repairs on incorrect cells (Grdth present) = %d" % (
                prec, rec, rep_recall, f1, rep_f1, self.detected_errors,
                self.total_errors, self.correct_repairs, self.total_repairs,
                self.total_repairs_grdt_correct,
                self.total_repairs_grdt_incorrect)
            eval_report = EvalReport(prec, rec, rep_recall, f1, rep_f1,
                                     self.detected_errors, self.total_errors,
                                     self.correct_repairs, self.total_repairs,
                                     self.total_repairs_grdt,
                                     self.total_repairs_grdt_correct,
                                     self.total_repairs_grdt_incorrect)
        except Exception as e:
            # Lazy %-args: same message, formatted only if the record is emitted.
            logging.error("ERROR generating evaluation report %s", e)
            raise

        toc = time.perf_counter()
        report_time = toc - tic
        return report, report_time, eval_report

    def compute_total_repairs(self):
        """
        Memoize on ``self.total_repairs`` (a float) the number of repairs.

        A repair is a cell that appears in the inferred-values table with an
        ``rv_value`` different from the cell's ``init_value`` in the cell
        domain table.
        """
        cell_domain = AuxTables.cell_domain.name
        inferred = AuxTables.inf_values_dom.name
        query = ("SELECT count(*) FROM "
                 "  (SELECT _vid_ "
                 "     FROM {} as t1, {} as t2 "
                 "    WHERE t1._tid_ = t2._tid_ "
                 "      AND t1.attribute = t2.attribute "
                 "      AND t1.init_value != t2.rv_value) AS t").format(cell_domain, inferred)
        rows = self.ds.engine.execute_query(query)
        self.total_repairs = float(rows[0][0])

    def compute_total_repairs_grdt(self):
        """
        compute_total_repairs_grdt memoizes the number of repairs for cells
        that are specified in the clean/ground truth data. Otherwise repairs
        are defined the same as compute_total_repairs.

        We also distinguish between repairs on correct cells and repairs on
        incorrect cells (correct cells are cells where init == ground truth).

        Sets ``total_repairs_grdt_correct``, ``total_repairs_grdt_incorrect``
        and their sum ``total_repairs_grdt`` (all 0 when there are no such
        repairs).
        """
        query = """
        SELECT
            (t1.init_value = t3._value_) AS is_correct,
            count(*)
        FROM   {} as t1, {} as t2, {} as t3
        WHERE  t1._tid_ = t2._tid_
          AND  t1.attribute = t2.attribute
          AND  t1.init_value != t2.rv_value
          AND  t1._tid_ = t3._tid_
          AND  t1.attribute = t3._attribute_
        GROUP BY is_correct
          """.format(AuxTables.cell_domain.name, AuxTables.inf_values_dom.name,
                     self.clean_data.name)
        res = self.ds.engine.execute_query(query)

        # Tally the GROUP BY rows by their is_correct flag rather than by row
        # position: the row order is not guaranteed, and if the comparison
        # ever yields NULL that group must count as "incorrect" instead of
        # being mistaken for the "correct" row.
        correct, incorrect = 0, 0
        for is_correct, count in res or []:
            if is_correct:
                correct += float(count)
            else:
                incorrect += float(count)
        self.total_repairs_grdt_correct = correct
        self.total_repairs_grdt_incorrect = incorrect
        self.total_repairs_grdt = correct + incorrect

    def compute_total_errors(self):
        """
        Memoize on ``self.total_errors`` (a float) how many cells have a
        wrong initial value according to the ground truth.

        One ``errors_template`` query is rendered per attribute; the queries
        are executed together and the count from each result is added up.
        """
        per_attr_queries = [
            errors_template.substitute(init_table=self.ds.raw_data.name,
                                       grdt_table=self.clean_data.name,
                                       attr=attribute)
            for attribute in self.ds.get_attributes()
        ]
        total = 0.0
        for rows in self.ds.engine.execute_queries(per_attr_queries):
            total += float(rows[0][0])
        self.total_errors = total

    def compute_detected_errors(self):
        """
        Memoize on ``self.detected_errors`` (a float) the number of error
        cells that error detection flagged: cells present in ``dk_cells``
        whose initial value differs from the ground truth.

        This value is always equal or less than total errors (see
        compute_total_errors).
        """
        table_names = (AuxTables.cell_domain.name,
                       self.clean_data.name,
                       AuxTables.dk_cells.name)
        template = ("SELECT count(*) FROM "
                    "  (SELECT _vid_ "
                    "   FROM   %s as t1, %s as t2, %s as t3 "
                    "   WHERE  t1._tid_ = t2._tid_ AND t1._cid_ = t3._cid_ "
                    "     AND  t1.attribute = t2._attribute_ "
                    "     AND  t1.init_value != t2._value_) AS t")
        rows = self.ds.engine.execute_query(template % table_names)
        self.detected_errors = float(rows[0][0])

    def compute_correct_repairs(self):
        """
        Memoize on ``self.correct_repairs`` (a float) the number of error
        cells that were correctly inferred.

        This value is always equal or less than total errors (see
        compute_total_errors).
        """
        def render(attribute):
            # One correct-repairs query for a single attribute.
            return correct_repairs_template.substitute(
                init_table=self.ds.raw_data.name,
                grdt_table=self.clean_data.name,
                attr=attribute,
                inf_dom=AuxTables.inf_values_dom.name)

        batch = [render(a) for a in self.ds.get_attributes()]
        tally = 0.0
        for rows in self.ds.engine.execute_queries(batch):
            tally += float(rows[0][0])
        self.correct_repairs = tally

    def compute_recall(self):
        """
        Recall: correct repairs divided by the total number of errors.

        Yields 0 when the ground truth contains no errors.
        """
        errors = self.total_errors
        return 0 if errors == 0 else self.correct_repairs / errors

    def compute_repairing_recall(self):
        """
        Repairing recall: correct repairs divided by the number of errors that
        were *detected* (rather than all errors).

        Yields 0 when nothing was detected.
        """
        detected = self.detected_errors
        return 0 if detected == 0 else self.correct_repairs / detected

    def compute_precision(self):
        """
        Precision: correct repairs divided by the number of repairs made on
        cells that have a ground-truth value.

        Yields 0 when there are no such repairs.
        """
        repairs = self.total_repairs_grdt
        return 0 if repairs == 0 else self.correct_repairs / repairs

    def compute_f1(self):
        """
        F1 score: harmonic mean of precision and recall.

        Yields 0 when precision and recall sum to 0.
        """
        precision, recall = self.compute_precision(), self.compute_recall()
        denominator = precision + recall
        if denominator == 0:
            return 0
        return 2 * (precision * recall) / denominator

    def compute_repairing_f1(self):
        """
        Repairing F1: harmonic mean of precision and repairing recall.

        Yields 0 when precision and repairing recall sum to 0.
        """
        precision = self.compute_precision()
        recall = self.compute_repairing_recall()
        denominator = precision + recall
        if denominator == 0:
            return 0
        return 2 * (precision * recall) / denominator

    def log_weak_label_stats(self):
        """
        Log, at DEBUG level, a breakdown of cells by weak-label agreement.

        Only cells that have a ground-truth value are considered (cell domain
        joined with ``self.clean_data``). Each is classified by: whether it is
        absent from ``dk_cells`` (``is_clean``), its ``fixed`` status, whether
        an inferred value exists, and pairwise equality among its initial
        value, weak label, inferred value and ground truth. The number of
        cells per combination is logged as a sorted DataFrame.
        """
        query = """
        select
            (t3._tid_ is NULL) as clean,
            (t1.fixed) as status,
            (t4._tid_ is NOT NULL) as inferred,
            (t1.init_value = t2._value_) as init_eq_grdth,
            (t1.init_value = t4.rv_value) as init_eq_infer,
            (t1.weak_label = t1.init_value) as wl_eq_init,
            (t1.weak_label = t2._value_) as wl_eq_grdth,
            (t1.weak_label = t4.rv_value) as wl_eq_infer,
            (t2._value_ = t4.rv_value) as infer_eq_grdth,
            count(*) as count
        from
            {cell_domain} as t1,
            {clean_data} as t2
            left join {dk_cells} as t3 on t2._tid_ = t3._tid_ and t2._attribute_ = t3.attribute
            left join {inf_values_dom} as t4 on t2._tid_ = t4._tid_ and t2._attribute_ = t4.attribute where t1._tid_ = t2._tid_ and t1.attribute = t2._attribute_
        group by
            clean,
            status,
            inferred,
            init_eq_grdth,
            init_eq_infer,
            wl_eq_init,
            wl_eq_grdth,
            wl_eq_infer,
            infer_eq_grdth
        """.format(cell_domain=AuxTables.cell_domain.name,
                   clean_data=self.clean_data.name,
                   dk_cells=AuxTables.dk_cells.name,
                   inf_values_dom=AuxTables.inf_values_dom.name)

        res = self.ds.engine.execute_query(query)

        df_stats = pd.DataFrame(res,
                                columns=[
                                    "is_clean", "cell_status", "is_inferred",
                                    "init = grdth", "init = inferred",
                                    "w. label = init", "w. label = grdth",
                                    "w. label = inferred", "infer = grdth",
                                    "count"
                                ])
        df_stats = df_stats.sort_values(list(
            df_stats.columns)).reset_index(drop=True)
        logging.debug("weak label statistics:")
        # Show every row and column. option_context restores the caller's
        # previous display settings on exit, even if logging raises, whereas
        # set_option/reset_option overwrote them with pandas' defaults.
        # display.max_colwidth is no longer forced to -1: modern pandas
        # rejects negative values, and these cells (flags and counts) are
        # far shorter than the default width limit anyway.
        with pd.option_context('display.max_columns', None,
                               'display.max_rows', None):
            logging.debug("%s", df_stats)