Example #1
 def test_pgtd_active_dups(self):
     if not self.can_run:
         return
     schema = test_schema + "_act_dups"
     tdf_1 = TicDatFactory(t_one=[[],
                                  ["Field One", "Field Two", "Da Active"]],
                           t_two=[[], ["Field One", "Da Active"]])
     dat = tdf_1.TicDat(t_one=[["a", "b", True], ["a", "c", True],
                               ["a", "b", False], ["a", "d", True]],
                        t_two=[["a", True], ["b", False], ["a", False],
                               ["b", False], ["a", False]])
     self.assertTrue(len(dat.t_one) == 4 and len(dat.t_two) == 5)
     tdf_1.pgsql.write_schema(
         self.engine,
         schema,
         include_ancillary_info=False,
         forced_field_types={(t, f): "boolean" if "Active" in f else "text"
                             for t, (pks, dfs) in tdf_1.schema().items()
                             for f in pks + dfs})
     tdf_1.pgsql.write_data(dat, self.engine, schema)
     self.assertTrue(
         tdf_1._same_data(dat,
                          tdf_1.pgsql.create_tic_dat(self.engine, schema),
                          epsilon=1e-8))
     tdf = TicDatFactory(t_one=[["Field One", "Field Two"], []],
                         t_two=[["Field One"], []])
     self.assertTrue(tdf.pgsql.find_duplicates(self.engine, schema))
     self.assertFalse(
         tdf.pgsql.find_duplicates(self.engine,
                                   schema,
                                   active_fld="da_active"))
Example #2
 def test_dups(self):
     if not self.can_run:
         return
     tdf = TicDatFactory(one=[["a"], ["b", "c"]],
                         two=[["a", "b"], ["c"]],
                         three=[["a", "b", "c"], []])
     tdf2 = TicDatFactory(
         **{t: [[], ["a", "b", "c"]]
            for t in tdf.all_tables})
     td = tdf2.TicDat(
         **{
             t: [[1, 2, 1], [1, 2, 2], [2, 1, 3], [2, 2, 3], [1, 2, 2],
                 [5, 1, 2]]
             for t in tdf.all_tables
         })
     tdf2.pgsql.write_schema(self.engine, test_schema)
     tdf2.pgsql.write_data(td, self.engine, test_schema)
     dups = tdf.pgsql.find_duplicates(self.engine, test_schema)
     self.assertTrue(dups == {
         'three': {(1, 2, 2): 2},
         'two': {(1, 2): 3},
         'one': {1: 3, 2: 2}
     })
Example #3
def copyDataDietWeirdCase2(dat):
    tdf = TicDatFactory(**dietSchemaWeirdCase2())
    tmp = copyDataDietWeirdCase(dat)
    rtn = tdf.TicDat(cateGories = tmp.cateGories, foodS = tmp.foodS)
    for (f,c),r in tmp.nutritionquantities.items():
        rtn.nutrition_quantities[f,c] = r
    return rtn
Example #4
    def test_time_stamp(self):
        tdf = TicDatFactory(table=[["Blah"], ["Timed Info"]])
        tdf.set_data_type("table", "Timed Info", nullable=True)
        tdf.set_default_value("table", "Timed Info", None)
        dat = tdf.TicDat()
        dat.table[1] = dateutil.parser.parse("2014-05-01 18:47:05.069722")
        dat.table[2] = dateutil.parser.parse("2014-05-02 18:47:05.178768")
        pgtf = tdf.pgsql
        pgtf.write_schema(self.engine,
                          test_schema,
                          forced_field_types={
                              ('table', 'Blah'): "integer",
                              ('table', 'Timed Info'): "timestamp"
                          })
        pgtf.write_data(dat,
                        self.engine,
                        test_schema,
                        dsn=self.postgresql.dsn())
        dat_2 = pgtf.create_tic_dat(self.engine, test_schema)
        self.assertTrue(tdf._same_data(dat, dat_2))
        self.assertTrue(
            all(
                isinstance(row["Timed Info"], datetime.datetime)
                for row in dat_2.table.values()))
        self.assertFalse(
            any(isinstance(k, datetime.datetime) for k in dat_2.table))

        pdf = PanDatFactory.create_from_full_schema(
            tdf.schema(include_ancillary_info=True))

        def same_data(pan_dat, pan_dat_2):
            df1, df2 = pan_dat.table, pan_dat_2.table
            if list(df1["Blah"]) != list(df2["Blah"]):
                return False
            for dt1, dt2 in zip(df1["Timed Info"], df2["Timed Info"]):
                delta = dt1 - dt2
                if abs(delta.total_seconds()) > 1e-6:
                    return False
            return True

        pan_dat = pdf.pgsql.create_pan_dat(self.engine, test_schema)
        pan_dat_2 = pan_dat_maker(tdf.schema(), dat_2)
        self.assertTrue(same_data(pan_dat, pan_dat_2))
        for df in [_.table for _ in [pan_dat, pan_dat_2]]:
            for i in range(len(df)):
                self.assertFalse(
                    isinstance(df.loc[i, "Blah"], datetime.datetime))
                self.assertTrue(
                    isinstance(df.loc[i, "Timed Info"], datetime.datetime))

        pan_dat.table.loc[1, "Timed Info"] = dateutil.parser.parse(
            "2014-05-02 18:48:05.178768")
        self.assertFalse(same_data(pan_dat, pan_dat_2))
        pdf.pgsql.write_data(pan_dat, self.engine, test_schema)
        pan_dat_2 = pdf.pgsql.create_pan_dat(self.engine, test_schema)
        self.assertTrue(same_data(pan_dat, pan_dat_2))

        dat.table[2] = dateutil.parser.parse("2014-05-02 18:48:05.178768")
        self.assertFalse(tdf._same_data(dat, dat_2))
Example #5
def copyDataDietWeirdCase(dat):
    tdf = TicDatFactory(**dietSchemaWeirdCase())
    rtn = tdf.TicDat()
    for c,r in dat.categories.items():
        rtn.cateGories[c]["miNnutrition"] = r["minNutrition"]
        rtn.cateGories[c]["maXnutrition"] = r["maxNutrition"]
    for f,r in dat.foods.items():
        rtn.foodS[f] = r["cost"]
    for (f,c),r in dat.nutritionQuantities.items():
        rtn.nutritionquantities[f,c] = r["qty"]
    return rtn
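
The copy helpers in Examples #3 and #5 assume a dietSchemaWeirdCase() helper
whose table names use the odd capitalization seen above. A minimal sketch of
what it might return, reconstructed from the field accesses in the helpers
(the primary key field names here are hypothetical):

def dietSchemaWeirdCase():
    # cateGories carries miNnutrition/maXnutrition, foodS a single cost
    # field, and nutritionquantities a single qty field
    return dict(cateGories=[["name"], ["miNnutrition", "maXnutrition"]],
                foodS=[["name"], ["cost"]],
                nutritionquantities=[["food", "category"], ["qty"]])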
Example #6
    def testNullsAndInf(self):
        tdf = TicDatFactory(table=[["field one"], ["field two"]])
        for f in ["field one", "field two"]:
            tdf.set_data_type("table", f, nullable=True)
        dat = tdf.TicDat(
            table=[[None, 100], [200, 109], [0, 300], [300, None], [400, 0]])
        schema = test_schema + "_bool_defaults"
        tdf.pgsql.write_schema(self.engine,
                               schema,
                               include_ancillary_info=False)
        tdf.pgsql.write_data(dat, self.engine, schema)

        dat_1 = tdf.pgsql.create_tic_dat(self.engine, schema)
        self.assertTrue(tdf._same_data(dat, dat_1))

        tdf = TicDatFactory(table=[["field one"], ["field two"]])
        for f in ["field one", "field two"]:
            tdf.set_data_type("table", f, max=float("inf"), inclusive_max=True)
        tdf.set_infinity_io_flag(None)
        dat_inf = tdf.TicDat(table=[[float("inf"), 100], [200, 109], [0, 300],
                                    [300, float("inf")], [400, 0]])
        dat_1 = tdf.pgsql.create_tic_dat(self.engine, schema)

        self.assertTrue(tdf._same_data(dat_inf, dat_1))
        tdf.pgsql.write_data(dat_inf, self.engine, schema)
        dat_1 = tdf.pgsql.create_tic_dat(self.engine, schema)
        self.assertTrue(tdf._same_data(dat_inf, dat_1))

        tdf = TicDatFactory(table=[["field one"], ["field two"]])
        for f in ["field one", "field two"]:
            tdf.set_data_type("table",
                              f,
                              min=-float("inf"),
                              inclusive_min=True)
        tdf.set_infinity_io_flag(None)
        dat_1 = tdf.pgsql.create_tic_dat(self.engine, schema)
        self.assertFalse(tdf._same_data(dat_inf, dat_1))
        dat_inf = tdf.TicDat(table=[[float("-inf"), 100], [200, 109], [0, 300],
                                    [300, -float("inf")], [400, 0]])
        self.assertTrue(tdf._same_data(dat_inf, dat_1))
Example #7
 def test_parameters(self):
     schema = test_schema + "_parameters"
     tdf = TicDatFactory(parameters=[["Key"], ["Value"]])
     tdf.add_parameter("Something", 100)
     tdf.add_parameter("Different",
                       'boo',
                       strings_allowed='*',
                       number_allowed=False)
     dat = tdf.TicDat(
         parameters=[["Something", float("inf")], ["Different", "inf"]])
     tdf.pgsql.write_schema(self.engine, schema)
     tdf.pgsql.write_data(dat, self.engine, schema)
     dat_ = tdf.pgsql.create_tic_dat(self.engine, schema)
     self.assertTrue(tdf._same_data(dat, dat_))
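
As a possible follow-on check (not part of the original test), ticdat can
flatten a parameters table into a plain dict, applying the registered
defaults for any missing keys:

params = tdf.create_full_parameters_dict(dat_)
assert params["Something"] == float("inf") and params["Different"] == "inf"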
Example #8
 def testCircularFks(self):
     schema = test_schema + "_circular_fks"
     tdf = TicDatFactory(table_one=[["A Field"], []],
                         table_two=[["B Field"], []],
                         table_three=[["C Field"], []])
     tdf.add_foreign_key("table_one", "table_two", ["A Field", "B Field"])
     tdf.add_foreign_key("table_two", "table_three", ["B Field", "C Field"])
     tdf.add_foreign_key("table_three", "table_one", ["C Field", "A Field"])
     tdf.pgsql.write_schema(self.engine,
                            schema,
                            include_ancillary_info=False)
     t_ = [["a"], ["b"], ["c"]]
     dat = tdf.TicDat(table_one=t_, table_two=t_, table_three=t_)
     tdf.pgsql.write_data(dat, self.engine, schema)
Example #9
def netflowSolver(modelType):
    tdf = TicDatFactory(**netflowSchema())
    addNetflowForeignKeys(tdf)
    addNetflowDataTypes(tdf)

    dat = tdf.copy_tic_dat(netflowData())
    assert not tdf.find_data_type_failures(dat)
    assert not tdf.find_foreign_key_failures(dat)

    mdl = Model(modelType, "netflow")

    flow = {}
    for h, i, j in dat.cost:
        if (i, j) in dat.arcs:
            flow[h, i, j] = mdl.add_var(name='flow_%s_%s_%s' % (h, i, j))

    flowslice = Slicer(flow)

    for i_, j_ in dat.arcs:
        mdl.add_constraint(mdl.sum(flow[h, i, j]
                                   for h, i, j in flowslice.slice('*', i_, j_))
                           <= dat.arcs[i_, j_]["capacity"],
                           name='cap_%s_%s' % (i_, j_))

    for h_, j_ in set(k for k, v in dat.inflow.items()
                      if abs(v["quantity"]) > 0).union(
            {(h, i) for h, i, j in flow},
            {(h, j) for h, i, j in flow}):
        mdl.add_constraint(
            mdl.sum(flow[h, i, j]
                    for h, i, j in flowslice.slice(h_, '*', j_)) +
            dat.inflow.get((h_, j_), {"quantity": 0})["quantity"] == mdl.sum(
                flow[h, i, j] for h, i, j in flowslice.slice(h_, j_, '*')),
            name='node_%s_%s' % (h_, j_))

    mdl.set_objective(
        mdl.sum(var * dat.cost[h, i, j]["cost"]
                for (h, i, j), var in flow.items()))
    if mdl.optimize():
        solutionFactory = TicDatFactory(
            flow=[["commodity", "source", "destination"], ["quantity"]])
        rtn = solutionFactory.TicDat()
        for (h, i, j), var in flow.items():
            if mdl.get_solution_value(var) > 0:
                rtn.flow[h, i, j] = mdl.get_solution_value(var)
        return rtn, sum(dat.cost[h, i, j]["cost"] * r["quantity"]
                        for (h, i, j), r in rtn.flow.items())
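
A hypothetical invocation of the solver above; ticdat's Model wrapper takes
"gurobi", "cplex" or "xpress" as the model type, assuming the matching solver
is installed:

result = netflowSolver("gurobi")
if result:  # netflowSolver returns None when optimization fails
    solution, total_cost = result
    for (h, i, j), row in solution.flow.items():
        print(h, i, j, row["quantity"])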
Example #10
 def test_wtf(self):
     schema = "wtf"
     tdf = TicDatFactory(
         table_one=[["Cost per Distance", "Cost per Hr. (in-transit)"],
                    ["Stuff"]],
         table_two=[["This", "That"], ["Tho"]])
     tdf.pgsql.write_schema(self.engine, schema)
     data = [["a", "b", 1], ["dd", "ee", 10], ["023", "210", 102.1]]
     tic_dat = tdf.TicDat(table_one=data, table_two=data)
     tdf.pgsql.write_data(tic_dat,
                          self.engine,
                          schema,
                          dsn=self.postgresql.dsn())
     pg_tic_dat = tdf.pgsql.create_tic_dat(self.engine, schema)
     self.assertTrue(tdf._same_data(tic_dat, pg_tic_dat))
Example #11
 def test_missing_tables(self):
     schema = test_schema + "_missing_tables"
     tdf_1 = TicDatFactory(this=[["Something"], ["Another"]])
     pdf_1 = PanDatFactory(**tdf_1.schema())
     tdf_2 = TicDatFactory(
         **dict(tdf_1.schema(), that=[["What", "Ever"], []]))
     pdf_2 = PanDatFactory(**tdf_2.schema())
     dat = tdf_1.TicDat(this=[["a", 2], ["b", 3], ["c", 5]])
     pan_dat = tdf_1.copy_to_pandas(dat, drop_pk_columns=False)
     tdf_1.pgsql.write_schema(self.engine, schema)
     tdf_1.pgsql.write_data(dat, self.engine, schema)
     pg_dat = tdf_2.pgsql.create_tic_dat(self.engine, schema)
     self.assertTrue(tdf_1._same_data(dat, pg_dat))
     pg_pan_dat = pdf_2.pgsql.create_pan_dat(self.engine, schema)
     self.assertTrue(pdf_1._same_data(pan_dat, pg_pan_dat))
Example #12
 def _copy_to_tic_dat(self, pan_dat, keep_generics_as_df=True):
     sch = self.schema()
     if not keep_generics_as_df:
         for t in self.generic_tables:
             sch[t] = [[], list(getattr(pan_dat, t).columns)]
     from ticdat import TicDatFactory
     tdf = TicDatFactory(**sch)
     def df(t):
         rtn = getattr(pan_dat, t)
         if self.primary_key_fields.get(t, ()):
             return rtn.set_index(list(self.primary_key_fields[t]), drop=False)
         if t in self.generic_tables and not keep_generics_as_df:
             return list(map(list, rtn.itertuples(index=False)))
         return rtn
     return tdf.TicDat(**{t: df(t) for t in self.all_tables})
Example #13
def dietSolver(modelType):
    tdf = TicDatFactory(**dietSchema())
    addDietForeignKeys(tdf)
    addDietDataTypes(tdf)

    dat = tdf.copy_tic_dat(dietData())
    assert not tdf.find_data_type_failures(dat)
    assert not tdf.find_foreign_key_failures(dat)

    mdl = Model(modelType, "diet")

    nutrition = {}
    for c, n in dat.categories.items():
        nutrition[c] = mdl.add_var(lb=n["minNutrition"],
                                   ub=n["maxNutrition"],
                                   name=c)

    # Create decision variables for the foods to buy
    buy = {}
    for f in dat.foods:
        buy[f] = mdl.add_var(name=f)

    # Nutrition constraints
    for c in dat.categories:
        mdl.add_constraint(mdl.sum(dat.nutritionQuantities[f, c]["qty"] *
                                   buy[f] for f in dat.foods) == nutrition[c],
                           name=c)

    mdl.set_objective(mdl.sum(buy[f] * r["cost"]
                              for f, r in dat.foods.items()))

    if mdl.optimize():
        solutionFactory = TicDatFactory(parameters=[[], ["totalCost"]],
                                        buyFood=[["food"], ["qty"]],
                                        consumeNutrition=[["category"],
                                                          ["qty"]])
        sln = solutionFactory.TicDat()
        for f, x in buy.items():
            if mdl.get_solution_value(x) > 0.0001:
                sln.buyFood[f] = mdl.get_solution_value(x)
        for c, x in nutrition.items():
            sln.consumeNutrition[c] = mdl.get_solution_value(x)
        return sln, sum(dat.foods[f]["cost"] * r["qty"]
                        for f, r in sln.buyFood.items())
Example #14
 def test_ints_and_strings_and_lists(self):
     if not self.can_run:
         return
     tdf = TicDatFactory(t_one=[[], ["str_field", "int_field"]],
                         t_two=[["str_field", "int_field"], []])
     for t in tdf.all_tables:
         tdf.set_data_type(t,
                           "str_field",
                           strings_allowed=['This', 'That'],
                           number_allowed=False)
         tdf.set_data_type(t, "int_field", must_be_int=True)
     dat = tdf.TicDat(t_one=[["This", 1], ["That", 2], ["This", 111],
                             ["That", 211]],
                      t_two=[["This", 10], ["That", 9]])
     self.assertFalse(tdf.find_data_type_failures(dat))
     self.assertTrue(len(dat.t_one) == 4)
     self.assertTrue(len(dat.t_two) == 2)
     pgtf = tdf.pgsql
     pgtf.write_schema(self.engine, test_schema)
     pgtf.write_data(dat, self.engine, test_schema)
     self.assertFalse(pgtf.find_duplicates(self.engine, test_schema))
     pg_tic_dat = pgtf.create_tic_dat(self.engine, test_schema)
     self.assertTrue(tdf._same_data(dat, pg_tic_dat))
Example #15
 def test_true_false(self):
     if not self.can_run:
         return
     tdf = TicDatFactory(table=[["pkf"], ["df1", "df2"]])
     tdf.set_data_type("table", "df2", min=-float("inf"))
     dat = tdf.TicDat(table=[["d1", True, 100], ["d2", False, 200],
                             ["d3", False, -float("inf")]])
     self.assertTrue(len(dat.table) == 3)
     self.assertFalse(tdf.find_data_type_failures(dat))
     pgtf = tdf.pgsql
     ex = None
     try:
         pgtf.write_data(None, self.engine, test_schema)
     except utils.TicDatError as te:
         ex = str(te)
     self.assertTrue(ex and "Not a valid TicDat object" in ex)
     pgtf.write_schema(self.engine,
                       test_schema,
                       forced_field_types={("table", "df1"): "bool"})
     pgtf.write_data(dat, self.engine, test_schema)
     self.assertFalse(pgtf.find_duplicates(self.engine, test_schema))
     pg_tic_dat = pgtf.create_tic_dat(self.engine, test_schema)
     self.assertTrue(tdf._same_data(dat, pg_tic_dat))
Example #16
class Action(metaclass=ABCMeta):
    """Every action should inherit from this class."""

    @staticmethod
    def _ensure_docs(action):
        if not inspect.getdoc(action):
            raise Exception(
                f"Add a docstring to '{type(action).__name__}' class"
            )
        for method_name in action.method_names:
            method_doc = inspect.getdoc(getattr(action, method_name))
            if not method_doc:
                raise Exception(
                    f"Add a docstring to '{method_name}' method"
                )
            elif (
                    method_name == 'execute_action' and
                    inspect.getdoc(getattr(Action, method_name)) == method_doc
            ):
                raise Exception(
                    "Add a docstring to 'execute_action' method"
                )

    def __new__(action_class, *args, **kwargs):
        action = super().__new__(action_class)
        Action._ensure_docs(action)
        action._data_source_mappings = {'local': {}}
        action.set_local_data_source(
            '',
            (
                Path('../Inputs') if Path('../Inputs').is_dir() else Path('.')
            ).absolute()
        )

        # todo add additional checks like checking for docker.sock
        if len(sys.argv) == 3 and sys.argv[2].endswith('.json'):
            ensure_packages('sqlalchemy')
            from sqlalchemy.engine.url import URL
            from sqlalchemy import create_engine

            scenario_name, config_path = sys.argv[1], sys.argv[2]
            try:
                with open(config_path, 'r') as fp:
                    db_config = json.load(fp)['database']
                    action._enframe_db_url = str(URL(
                        'postgres',
                        username=db_config['dbusername'],
                        password=db_config['dbpassword'],
                        host=db_config['dbserverName'],
                        port=db_config['port'],
                        database=db_config['dbname']
                    ))
            except Exception:
                pass
            else:
                action._is_running_on_enframe = True
                action._enframe_engine = create_engine(
                    action._enframe_db_url
                )
                action._enframe_scenario_name = scenario_name
                action._data_source_mappings['enframe'] = {}
                action.set_enframe_data_source(
                    '',
                    {
                        'db_url': action.enframe_db_url,
                        'db_schema': action.enframe_scenario_name
                    }
                )
                action.set_enframe_data_source('config_schema', {
                    'db_schema': (
                            type(action).__name__.lower()
                            + '_' + action._enframe_scenario_name
                    )
                })
        return action

    @property
    def is_running_on_enframe(self):
        """Indicates whether the action is running on Enframe or locally"""
        return getattr(self, '_is_running_on_enframe', False)

    @is_running_on_enframe.setter
    def is_running_on_enframe(self, value):
        if type(value) is not bool:
            raise ValueError('is_running_on_enframe should be bool')
        setattr(self, '_is_running_on_enframe', value)

    enframe_db_url = property(
        fget=lambda self: getattr(self, '_enframe_db_url', None),
        fset=None, fdel=None,
        doc='URL for Enframe app database of the action'
    )

    enframe_scenario_name = property(
        fget=lambda self: getattr(self, '_enframe_scenario_name', None),
        fset=None, fdel=None,
        doc='Name of the Enframe scenario running the action'
    )

    enframe_connection = property(
        fget=lambda self: getattr(self, '_enframe_engine', None),
        fset=None, fdel=None,
        doc='Connection to Enframe app database'
    )

    @property
    def is_setup_on_enframe(self):
        '''Indicates whether the action UI has been setup on Enframe'''
        try:
            schema_names = chain.from_iterable(
                self.enframe_connection.execute(
                    'SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA;'
                )
            )
            config_schema_name = (
                type(self).__name__.lower() + '_' + self.enframe_scenario_name
            )
            for schema_name in schema_names:
                if schema_name == config_schema_name:
                    return True
        except Exception:
            pass
        return False

    @staticmethod
    def _is_correct_schema(tic_or_pan_dat, tic_or_pan_dat_schema):
        return (
                (type(tic_or_pan_dat) is
                 getattr(tic_or_pan_dat_schema, 'TicDat', None)
                 ) or
                (type(tic_or_pan_dat) is
                 getattr(tic_or_pan_dat_schema, 'PanDat', None)
                 )
        )

    @property
    def config_schema(self):
        '''Configuration schema for the action'''
        return getattr(self, '_config_schema', None)

    @config_schema.setter
    def config_schema(self, tic_or_pan_dat_schema):
        # validate the incoming schema, not the previously stored one
        if (
                type(tic_or_pan_dat_schema) not in
                (TicDatFactory, PanDatFactory, type(None))
        ):
            raise TypeError(
                "'config_schema' should be a TicDatFactory/PanDatFactory "
                'schema representing the configuration tables for your action'
            )

        self._config_schema = tic_or_pan_dat_schema
        if not tic_or_pan_dat_schema:
            return

        schema_info = tic_or_pan_dat_schema.schema()
        for table_name, primary_and_data_fields in schema_info.items():
            for field_name in chain.from_iterable(primary_and_data_fields):
                self._config_schema.set_data_type(
                    table_name, field_name, strings_allowed='*',
                    number_allowed=False
                )

    @property
    def config_defaults(self):
        '''Configuration defaults for the action'''
        return getattr(self, '_config_defaults', None)

    @config_defaults.setter
    def config_defaults(self, tic_or_pan_dat):
        if not tic_or_pan_dat:
            self._config_defaults = tic_or_pan_dat
            return

        if not Action._is_correct_schema(tic_or_pan_dat, self.config_schema):
            raise TypeError(
                "'config_defaults' should be a TicDat/PanDat object "
                'containing the default data for your configuration tables.\n'
                'Use self.config_schema.TicDat or self.config_schema.PanDat '
                'depending on whether self.config_schema is a '
                'TicDatFactory/PanDatFactory schema.'
            )
        self._config_defaults = tic_or_pan_dat

    def setup_enframe_ui(self):
        '''Sets up UI for configuration tables on Enframe'''
        if not self.is_running_on_enframe:
            return

        action_name = type(self).__name__
        action_db_name = action_name.lower()
        action_config_display_name = action_name + ' Configuration'

        enframe_ui_tables = {
            'lkp_data_upload_tables': [
                [],
                ['id', 'scenario_template_id', 'order_id', 'tablename',
                 'displayname', 'columnlist', 'displaylist',
                 'columnlistwithtypes', 'visible', 'type', 'unique_key',
                 'visiblecolumns', 'editablecolumns', 'select_query', 'tag',
                 'columnfloat', 'version', 'filter', 'created_at',
                 'updated_at', 'created_by', 'updated_by'
                 ]
            ],
            'lkp_views': [
                ['id'], ['table_id', 'definition']
            ],
            'projects': [
                ['id'],
                ['scenario_template_id', 'order_id', 'name', 'tag_id',
                 'status', 'version', 'created_at', 'updated_at', 'archived_at',
                 'created_by', 'updated_by', 'archived_by'
                 ]
            ],
            'project_tables': [
                ['id'],
                ['pid', 'name', 'file_name', 'table_name', 'status', 'visible',
                 'type', 'columns', 'created_at', 'updated_at', 'created_by',
                 'updated_by'
                 ]
            ]
        }
        self._enframe_ui = TicDatFactory(**enframe_ui_tables)

        self.set_enframe_data_source('_enframe_ui', {'db_schema': 'public'})
        ui_data = self.read_data('_enframe_ui')

        current_project_tuple = next(filter(
            lambda key_and_row: (
                    key_and_row[1]['name'].lower().replace(' ', '_') ==
                    self.enframe_scenario_name
            ),
            ui_data.projects.items()
        ), None)
        if not current_project_tuple:
            raise Exception(
                "Couldn't find project associated with current scenario"
            )
        project_id = current_project_tuple[0]
        project_version = current_project_tuple[1]['version']
        scenario_template_id = current_project_tuple[1]['scenario_template_id']

        config_schema_name = action_db_name + '_' + self.enframe_scenario_name
        self.enframe_connection.execute(
            f'CREATE SCHEMA IF NOT EXISTS {config_schema_name};'
        )
        self.write_data(self.config_defaults, create_tables=True)
        # todo use superticdat for this

        config_schema_info = self.config_schema.schema()
        for config_table_name in self.config_schema.all_tables:
            # table_db_name = action_db_name + '_' + config_table_name
            table_db_name = config_table_name
            table_display_name = table_db_name.replace('_', ' ').title()
            column_display_names = list(
                chain.from_iterable(config_schema_info[table_db_name])
            )
            column_db_names = [
                display_name.lower().replace(' ', '_')
                for display_name in column_display_names
            ]

            column_db_names_str = ','.join(column_db_names)
            column_db_name_and_types_str = ', '.join(
                f'"{column_name}" text'
                for column_name in column_db_names
            )
            column_db_select_str = ', '.join(
                f'"{column_name}" AS "{column_name}"'
                for column_name in column_db_names
            )

            ui_data.project_tables = {
                key: row
                for key, row in ui_data.project_tables.items()
                if row['pid'] != project_id or
                   row['table_name'] != table_db_name
            }
            next_project_tables_id = max(chain(ui_data.project_tables, [0])) + 1
            ui_data.project_tables[next_project_tables_id] = {
                'pid': project_id,
                'name': table_display_name,
                'file_name': None,
                'table_name': table_db_name,
                'status': 'Uploaded Successfully',
                'visible': 'true',
                'type': 'input_view',
                'columns': column_db_name_and_types_str,
                'created_at': 'NOW()',
                'updated_at': 'NOW()',
                'created_by': 'Administrator',
                'updated_by': 'Administrator'
            }

            ui_data.lkp_data_upload_tables = [
                row for row in ui_data.lkp_data_upload_tables
                if row['scenario_template_id'] != scenario_template_id or
                   row['tablename'] != table_db_name
            ]
            ui_data.lkp_data_upload_tables.append({
                'id': max(
                    chain(
                        (row['id'] for row in ui_data.lkp_data_upload_tables),
                        [0]
                    )
                ) + 1,
                'scenario_template_id': scenario_template_id,
                'order_id': max(
                    chain(
                        (
                            row['order_id']
                            for row in ui_data.lkp_data_upload_tables
                            if row['type'] in ('input', 'input_view')
                        ), [0]
                    )
                ) + 1,
                'tablename': table_db_name,
                'displayname': table_display_name,
                'columnlist': column_db_names_str,
                'displaylist': dict(zip(column_db_names, column_display_names)),
                'columnlistwithtypes': column_db_name_and_types_str,
                'visible': 'true',
                'type': 'input_view',
                'unique_key': '',
                'visiblecolumns': column_db_names_str,
                'editablecolumns': column_db_names_str,
                'select_query': column_db_select_str,
                'tag': action_config_display_name,
                'columnfloat': '{}',
                'version': project_version,
                'filter': '',
                'created_at': 'NOW()',
                'updated_at': 'NOW()',
                'created_by': None,
                'updated_by': None
            })

            self.enframe_connection.execute(
                f'ALTER TABLE {config_schema_name}.{table_db_name} '
                'ADD COLUMN IF NOT EXISTS jqgrid_id INTEGER;'
            )
            self.enframe_connection.execute(
                f'UPDATE {config_schema_name}.{table_db_name} '
                f'SET jqgrid_id = jqgrid.id '
                'FROM (SELECT ROW_NUMBER() OVER () AS id FROM '
                f'{config_schema_name}.{table_db_name}) AS jqgrid'
            )
            self.enframe_connection.execute(
                f'CREATE OR REPLACE VIEW {table_db_name} AS '
                f'SELECT * FROM {config_schema_name}.{table_db_name};'
            )

            # todo change scenario_template_id
            table_id = next(filter(
                lambda row: (
                        row['scenario_template_id'] == scenario_template_id
                        and row['tablename'] == table_db_name
                ),
                ui_data.lkp_data_upload_tables
            ))['id']

            # todo better comparison, i.e. reference the id before it is
            # removed from the lkp_data_upload table
            view_definition = (
                f'SELECT * FROM {config_schema_name}.{table_db_name};'
            )
            ui_data.lkp_views = {
                id: row
                for id, row in ui_data.lkp_views.items()
                if row['definition'] != view_definition
            }
            ui_data.lkp_views[max(chain(ui_data.lkp_views, [0])) + 1] = {
                'table_id': table_id,
                'definition': view_definition
            }

        for row in ui_data.lkp_data_upload_tables:
            for dict_field in ('displaylist', 'columnfloat'):
                row[dict_field] = json.dumps(row[dict_field])

        # todo use superticdat for writing subschema
        del enframe_ui_tables['projects']
        self._enframe_write_ui = TicDatFactory(**enframe_ui_tables)
        self.set_enframe_data_source(
            '_enframe_write_ui', {'db_schema': 'public'}
        )
        ui_write_data = self._enframe_write_ui.TicDat(
            lkp_data_upload_tables=ui_data.lkp_data_upload_tables,
            # todo see if this is needed
            # project_tables=ui_data.project_tables,
            lkp_views=ui_data.lkp_views
        )
        self.write_data(ui_write_data)

    @staticmethod
    def _get_schema_and_table_name(schema_or_table_name):
        if (
                not isinstance(schema_or_table_name, str)
                or len(schema_or_table_name.split('.')) not in (1, 2)
        ):
            raise ValueError(
                f'Check {schema_or_table_name}\n'
                "'schema_or_table_name' should be a str of the form "
                '<schema_name> or <schema_name>.<table_name>'
            )
        schema_or_table_name_split = schema_or_table_name.split('.')
        if len(schema_or_table_name_split) == 1:
            schema_name, table_name = schema_or_table_name_split[0], None
        else:
            schema_name, table_name = schema_or_table_name_split
            if not table_name:
                raise ValueError(
                    f'Check {schema_or_table_name}\n'
                    '<table_name> in <schema_name>.<table_name> cannot be an '
                    'empty string'
                )
        return schema_name, table_name
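
    # Hypothetical examples of the parsing above:
    #   _get_schema_and_table_name('input_schema') -> ('input_schema', None)
    #   _get_schema_and_table_name('input_schema.foods')
    #       -> ('input_schema', 'foods')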

    # todo set_hierarchical_params in case of directory
    # todo check this for file_or_dir
    @staticmethod
    def _get_data_source(
            data_source_mappings, schema_or_table_name,
            include_data_source_type=False
    ):
        def get_return_value(data_source_and_type):
            if not include_data_source_type:
                return data_source_and_type[0]
            return data_source_and_type

        def set_hierarchical_params(
                data_source_and_type, default_data_source_and_type_list
        ):
            def set_default_param(data_source, param):
                if param not in data_source:
                    for (
                            default_data_source, _
                    ) in default_data_source_and_type_list:
                        if param in default_data_source:
                            data_source[param] = default_data_source[param]
                            return

            data_source, data_type = data_source_and_type
            if data_type == 'db':
                set_default_param(data_source, 'db_url')
                set_default_param(data_source, 'db_schema')

        if schema_or_table_name == '' or 'schemas' not in data_source_mappings:
            return get_return_value(data_source_mappings['source'])

        schema_name, table_name = Action._get_schema_and_table_name(
            schema_or_table_name
        )
        if not table_name:
            if schema_name not in data_source_mappings['schemas']:
                data_source_and_type = data_source_mappings['source']
            else:
                data_source_and_type = (
                    data_source_mappings['schemas'][schema_name]['source']
                )
            set_hierarchical_params(
                data_source_and_type, [data_source_mappings['source']]
            )
        else:
            if schema_name not in data_source_mappings['schemas']:
                data_source_and_type = data_source_mappings['source']
            elif (
                    table_name not in
                    data_source_mappings['schemas'][schema_name]['tables']
            ):
                data_source_and_type = (
                    data_source_mappings['schemas'][schema_name]['source']
                )
            else:
                data_source_and_type = (
                    data_source_mappings['schemas']
                    [schema_name]['tables'][table_name]['source']
                )

            # DB table name is the same as table name in TicDat schema
            # by default. Note that this is not so for schema names
            if data_source_and_type[1] == 'db':
                if 'db_table' not in data_source_and_type[0]:
                    data_source_and_type[0]['db_table'] = table_name
            set_hierarchical_params(
                data_source_and_type, [
                    data_source_mappings['schemas'][schema_name]['source'],
                    data_source_mappings['source']
                ]
            )
        return get_return_value(data_source_and_type)

    @staticmethod
    def _get_data_source_type(data_source):
        if isinstance(data_source, str):
            if urlparse(data_source).scheme:
                raise ValueError(
                    f'Check {data_source}\n'
                    "A database 'data_source' should be a dict of the form "
                    '{db_url:..., db_schema:..., db_table:...} with at least '
                    'one of the dict items. All the dict values are strings.\n'
                    "But, if 'db_url' is present then 'db_schema' should also "
                    'be present.'
                )
            return 'file_or_dir'
        elif isinstance(data_source, Path):
            return 'file_or_dir'
        elif isinstance(data_source, dict) and data_source:
            if (
                    data_source.get('db_url', None)
                    and not data_source.get('db_schema', None)
            ):
                raise ValueError(
                    f'Check {data_source}\n'
                    "If 'db_url' is present then 'db_schema' should also be "
                    "present in the 'data_source' dictionary"
                )
            elif (
                    'db_url' in data_source
                    and urlparse(data_source['db_url']).scheme not in (
                            'postgresql', 'postgres'
                    )
            ):
                raise ValueError(
                    f'Check {data_source}\n'
                    'Only PostgreSQL database is supported at the moment'
                )
            return 'db'
        else:
            raise ValueError(
                f'Check {data_source}\n'
                "'data_source' can represent a file/directory or database.\n"
                "A file/directory 'data_source' can be a str or pathlib.Path "
                'object with the file/directory path.\n'
                "A database 'data_source' should be a dict of the form "
                '{db_url:..., db_schema:..., db_table:...} with at least '
                'one of the dict items. All the dict values are strings.\n'
                "But, if 'db_url' is present then 'db_schema' should also "
                'be present.'
            )

    @staticmethod
    def _check_db_data_source(schema_or_table_name, data_source):
        # db_url without db_schema is handled in _get_data_source_type
        if schema_or_table_name == '':
            if not data_source.get('db_schema', None):
                raise ValueError(
                    f'Check {data_source}\n'
                    "'db_schema' is not present in 'data_source' dict"
                )
            if data_source.get('db_table', None):
                raise ValueError(
                    f'Check {data_source}\n'
                    "'db_table' cannot be specified when setting "
                    'data source for all schemas'
                )
            return

        _, table_name = Action._get_schema_and_table_name(schema_or_table_name)
        if not table_name:
            if not data_source.get('db_schema', None):
                raise ValueError(
                    f'Check {data_source}\n'
                    "'db_schema' is not present in 'data_source' dict"
                )
            if data_source.get('db_table', None):
                raise ValueError(
                    f'Check {data_source}\n'
                    "'db_table' cannot be specified when setting "
                    'data source for a schema'
                )
        else:
            if not data_source.get('db_table', None):
                raise ValueError(
                    f'Check {data_source}\n'
                    "'db_table' is not present in 'data_source' dict"
                )

    # todo data sources can also be ticdat/pandat objects
    @staticmethod
    def _set_data_source(
            data_source_mappings, schema_or_table_name, data_source,
            data_source_type=None
    ):
        data_source_type = (
                data_source_type or Action._get_data_source_type(data_source)
        )
        if data_source_type == 'db':
            Action._check_db_data_source(schema_or_table_name, data_source)

        if schema_or_table_name == '':
            data_source_mappings['source'] = (data_source, data_source_type)
        else:
            schema_name, table_name = Action._get_schema_and_table_name(
                schema_or_table_name
            )
            if not table_name:
                data_source_mappings \
                    .setdefault('schemas', {}) \
                    .setdefault(schema_name, {})['source'] = (
                    data_source, data_source_type
                )
            else:
                data_source_mappings \
                    .setdefault('schemas', {}) \
                    .setdefault(schema_name, {}) \
                    .setdefault('tables', {}) \
                    .setdefault(table_name, {})['source'] = (
                    data_source, data_source_type
                )

    def get_enframe_data_source(
            self, schema_or_table_name, include_data_source_type=False
    ):
        '''
        Get the data source being used for a TicDat schema/table when the action
        is running on Enframe
        '''
        return Action._get_data_source(
            self._data_source_mappings['enframe'],
            schema_or_table_name,
            include_data_source_type=include_data_source_type
        )

    def set_enframe_data_source(self, schema_or_table_name, data_source):
        '''
        Set the data source to be used for a TicDat schema/table when the action
        is running on Enframe
        '''
        Action._set_data_source(
            self._data_source_mappings['enframe'],
            schema_or_table_name,
            data_source
        )

    def get_local_data_source(
            self, schema_or_table_name, include_data_source_type=False
    ):
        '''
        Get the data source being used for a TicDat schema/table when the action
        is running locally
        '''
        return Action._get_data_source(
            self._data_source_mappings['local'],
            schema_or_table_name,
            include_data_source_type=include_data_source_type
        )

    def set_local_data_source(self, schema_or_table_name, data_source):
        '''
        Set the data source to be used for a TicDat schema/table when the action
        is running locally
        '''
        Action._set_data_source(
            self._data_source_mappings['local'],
            schema_or_table_name,
            data_source
        )

    @staticmethod
    def _get_data_path_and_type(file_or_dir_path, include_extension=False):
        valid_extensions = [
            'csv', 'json', 'xls', 'xlsx', 'db', 'sql', 'mdb', 'accdb'
        ]
        file_or_dir_path = Path(file_or_dir_path)
        extension = file_or_dir_path.resolve().suffix[1:].lower() or None

        data_path = file_or_dir_path
        if file_or_dir_path.is_file() and extension in valid_extensions:
            data_file_type = extension
            if extension == 'csv':
                data_path = Path(file_or_dir_path).parent
            if extension == 'xlsx':
                data_file_type = 'xls'
            elif extension == 'db':
                # SQLite .db files go through the factory's sql attribute
                data_file_type = 'sql'
            elif extension == 'accdb':
                # Access .accdb files are handled by the mdb attribute
                data_file_type = 'mdb'
        elif file_or_dir_path.is_dir():
            # Assumes CSV files are to be read. Reading multiple schemas
            # from different data sources is handled in read_data
            # methods which will then pass the file path rather than
            # directory path if files other than CSV files are to be read
            # NOTE the same doesn't apply to write_data as it will create
            # a file with the schema name if a particular data source isn't
            # specified.
            data_file_type = 'csv'
        else:
            raise TypeError(
                f'Check {file_or_dir_path}\n'
                'TicDat can only read from the following file types '
                f"{valid_extensions}"
            )
        return (
            (data_path, data_file_type) if not include_extension
            else (data_path, data_file_type, extension)
        )
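
    # Hypothetical examples of the mapping above (assuming the paths exist):
    #   'data/diet.xlsx' -> (Path('data/diet.xlsx'), 'xls')
    #   'data/diet.csv'  -> (Path('data'), 'csv'), since ticdat's csv
    #   readers and writers work on whole directories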

    @staticmethod
    def _read_data_from_file_system(tic_or_pan_dat_schema, file_or_dir_path):
        data_path, data_file_type, extension = Action._get_data_path_and_type(
            file_or_dir_path, include_extension=True
        )
        if type(tic_or_pan_dat_schema) is TicDatFactory:
            # .sql text files need create_tic_dat_from_sql; .db files are
            # read with tdf.sql.create_tic_dat
            read_method_name = (
                'create_tic_dat_from_sql' if extension == 'sql'
                else 'create_tic_dat'
            )
        elif type(tic_or_pan_dat_schema) is PanDatFactory:
            read_method_name = 'create_pan_dat'
        return getattr(
            getattr(tic_or_pan_dat_schema, data_file_type),
            read_method_name
        )(data_path)

    @staticmethod
    def _read_data_from_db(
            tic_or_pan_dat_schema, db_engine_or_url, db_schema
    ):
        ensure_packages('sqlalchemy', 'framework_utils')
        from sqlalchemy.engine import Connectable
        from framework_utils.pgtd import PostgresTicFactory, PostgresPanFactory

        if isinstance(db_engine_or_url, Connectable):
            db_engine = db_engine_or_url
        else:
            from sqlalchemy import create_engine
            db_engine = create_engine(db_engine_or_url)

        if type(tic_or_pan_dat_schema) is TicDatFactory:
            read_method = \
                PostgresTicFactory(tic_or_pan_dat_schema).create_tic_dat
        else:
            read_method = \
                PostgresPanFactory(tic_or_pan_dat_schema).create_pan_dat
        return read_method(db_engine, db_schema)

    # todo use superticdat to read subsets of tables
    def read_data(self, *schema_or_table_names):
        '''
        Read data for a TicDat schema/table from its corresponding data source
        '''
        if not all(
                isinstance(schema_or_table_name, str)
                for schema_or_table_name in schema_or_table_names
        ):
            raise ValueError(
                'Every argument should be str of the form <schema_name> '
                'or <schema_name>.<table_name>\n'
                'Here the schema name is the name of the '
                'TicDatFactory/PanDatFactory instance variable bound to the '
                'action object.\n'
                'The table names are those defined in the '
                'TicDatFactory/PanDatFactory object.'
            )

        get_data_source = (
            self.get_enframe_data_source if self.is_running_on_enframe
            else self.get_local_data_source
        )
        data = {}
        for schema_or_table_name in schema_or_table_names:
            schema_name, table_name = Action._get_schema_and_table_name(
                schema_or_table_name
            )
            if table_name:
                raise NotImplementedError(
                    'Reading from individual tables has not been implemented yet'
                )
            tic_or_pan_dat_schema = getattr(self, schema_name, None)
            if not tic_or_pan_dat_schema:
                raise ValueError(
                    f'Check {schema_name}\n'
                    'Cannot find TicDatFactory/PanDatFactory instance variable '
                    'with specified name'
                )

            data_source, data_source_type = get_data_source(
                schema_or_table_name, include_data_source_type=True
            )
            if data_source_type == 'file_or_dir':
                # todo get appropriate file for schema/table
                if Path(data_source).is_dir():
                    pass
                data[schema_or_table_name] = Action._read_data_from_file_system(
                    tic_or_pan_dat_schema, data_source
                )
            elif data_source_type == 'db':
                if (
                        'db_url' not in data_source
                        or 'db_schema' not in data_source
                ):
                    missing_param = (
                        'db_url' if 'db_url' not in data_source
                        else 'db_schema'
                    )
                    raise ValueError(
                        f"The '{missing_param}' in database data source dict "
                        'must be set before reading from a database.\n'
                        f"Use set_data_source method to set '{missing_param}' "
                        'for any of the following: all schemas, the schema '
                        'to be read or the table to be read.'
                    )
                data[schema_or_table_name] = Action._read_data_from_db(
                    tic_or_pan_dat_schema,
                    (
                        data_source['db_url']
                        if data_source['db_url'] != self.enframe_db_url
                        else self.enframe_connection
                    ),
                    data_source['db_schema']
                )
            setattr(data[schema_or_table_name], '_schema', schema_or_table_name)
        return (
            data if len(schema_or_table_names) > 1
            else next(iter(data.values()))
        )
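
    # Hypothetical usage: action.read_data('input_schema') returns a single
    # TicDat/PanDat object, while action.read_data('this', 'that') returns
    # a dict of objects keyed by the schema names passed in.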

    def _get_tic_or_pan_dat_schema_name(self, tic_or_pan_dat):
        if hasattr(tic_or_pan_dat, '_schema'):
            return tic_or_pan_dat._schema

        for schema_name in self.schema_names:
            tic_or_pan_dat_schema = getattr(self, schema_name)
            if Action._is_correct_schema(tic_or_pan_dat, tic_or_pan_dat_schema):
                return schema_name
        for attr, value in vars(self).items():
            if (
                    type(value) in (TicDatFactory, PanDatFactory)
                    and Action._is_correct_schema(tic_or_pan_dat, value)
            ):
                return attr

        raise ValueError(
            f'Check {tic_or_pan_dat}\n'
            "The given TicDat/PanDat doesn't correspond to any "
            'TicDatFactory/PanDatFactory schema defined as an '
            'instance variable/public property of this action'
        )

    def check_data(self, *tic_or_pan_dats):
        '''
        Check if TicDat/PanDat objects violate the data constraints
        defined on their corresponding TicDatFactory/PanDatFactory schemas
        '''
        for tic_or_pan_dat in tic_or_pan_dats:
            schema_name = self._get_tic_or_pan_dat_schema_name(tic_or_pan_dat)
            tic_or_pan_dat_schema = getattr(self, schema_name)

            # todo see if xls.find_duplicates and so on can be used for
            # ticdatfactory
            if type(tic_or_pan_dat_schema) is PanDatFactory:
                assert tic_or_pan_dat_schema.good_pan_dat_object(
                    tic_or_pan_dat
                )
                assert not tic_or_pan_dat_schema.find_duplicates(
                    tic_or_pan_dat
                )
            else:
                assert tic_or_pan_dat_schema.good_tic_dat_object(
                    tic_or_pan_dat
                )
            assert not tic_or_pan_dat_schema.find_foreign_key_failures(
                tic_or_pan_dat
            )
            assert not tic_or_pan_dat_schema.find_data_type_failures(
                tic_or_pan_dat
            )
            assert not tic_or_pan_dat_schema.find_data_row_failures(
                tic_or_pan_dat
            )

    @staticmethod
    def _write_data_to_file_system(
            tic_or_pan_dat_schema, tic_or_pan_dat, file_or_dir_path
    ):
        data_path, data_file_type, extension = Action._get_data_path_and_type(
            file_or_dir_path, include_extension=True
        )
        kwargs = {}
        if type(tic_or_pan_dat_schema) is TicDatFactory:
            write_method_name = 'write_file'
            if data_file_type == 'sql':
                write_method_name = (
                    'write_db_data' if extension == 'db'
                    else 'write_sql_file'
                )
            kwargs = {'allow_overwrite': True}
        elif type(tic_or_pan_dat_schema) is PanDatFactory:
            write_method_name = 'write_file'
        if data_file_type == 'csv':
            write_method_name = 'write_directory'

        return getattr(
            getattr(tic_or_pan_dat_schema, data_file_type),
            write_method_name
        )(tic_or_pan_dat, data_path, **kwargs)

    @staticmethod
    def _create_tables_in_db(
            tic_or_pan_dat_schema, db_engine_or_url, db_schema,
    ):
        ensure_packages('sqlalchemy', 'framework_utils')
        from sqlalchemy.engine import Connectable
        from framework_utils.pgtd import PostgresTicFactory, PostgresPanFactory

        if isinstance(db_engine_or_url, Connectable):
            db_engine = db_engine_or_url
        else:
            from sqlalchemy import create_engine
            db_engine = create_engine(db_engine_or_url)

        postgres_factory = (
            PostgresTicFactory(tic_or_pan_dat_schema)
            if type(tic_or_pan_dat_schema) is TicDatFactory
            else PostgresPanFactory(tic_or_pan_dat_schema)
        )

        # Swap out __setattr__ (restored below), presumably because the pgtd
        # factories restrict ad-hoc attribute assignment; this allows patching
        # _get_schema_sql to emit CREATE TABLE IF NOT EXISTS.
        original_setattr = type(postgres_factory).__setattr__

        def modified_setattr(self, name, value):
            object.__setattr__(self, name, value)

        original_get_schema_sql = type(postgres_factory)._get_schema_sql

        def modified_get_schema_sql(self, *args, **kwargs):
            return tuple(
                sql_create_statement.replace(
                    'CREATE TABLE', 'CREATE TABLE IF NOT EXISTS'
                )
                for sql_create_statement in original_get_schema_sql(
                    self, *args, **kwargs
                )
            )

        type(postgres_factory).__setattr__ = modified_setattr
        postgres_factory._get_schema_sql = MethodType(
            modified_get_schema_sql, postgres_factory
        )
        postgres_factory.write_schema(db_engine, db_schema)
        postgres_factory._get_schema_sql = MethodType(
            original_get_schema_sql, postgres_factory
        )
        type(postgres_factory).__setattr__ = original_setattr

    # todo move out the common code from read/write db methods
    @staticmethod
    def _write_data_to_db(
            tic_or_pan_dat_schema, tic_or_pan_dat, db_engine_or_url, db_schema,
            **kwargs
    ):
        ensure_packages('sqlalchemy', 'framework_utils')
        from sqlalchemy.engine import Connectable
        from framework_utils.pgtd import PostgresTicFactory, PostgresPanFactory

        if isinstance(db_engine_or_url, Connectable):
            db_engine = db_engine_or_url
        else:
            from sqlalchemy import create_engine
            db_engine = create_engine(db_engine_or_url)

        postgres_factory = (
            PostgresTicFactory(tic_or_pan_dat_schema)
            if type(tic_or_pan_dat_schema) is TicDatFactory
            else PostgresPanFactory(tic_or_pan_dat_schema)
        )

        postgres_factory.write_data(
            tic_or_pan_dat, db_engine, db_schema,
            **kwargs
        )

    # todo check the pandatfactory test
    # todo add functionality to pass in data types for each field
    def write_data(self, *tic_or_pan_dats, **pgtd_write_db_data_kwargs):
        '''
        Write data for a TicDat schema/table to its corresponding data source
        '''
        if not all(
                (
                        'ticdat.ticdatfactory.TicDatFactory.__init__.<locals>.TicDat'
                        in str(type(tic_or_pan_dat))
                        or
                        'ticdat.pandatfactory.PanDatFactory.__init__.<locals>.PanDat'
                        in str(type(tic_or_pan_dat))
                )
                for tic_or_pan_dat in tic_or_pan_dats
        ):
            raise ValueError(
                'Every positional argument must be a TicDat/PanDat object'
            )

        # pop create_tables once, before the loop, so it applies to every dat
        # passed in (popping inside the loop would only honor it for the first)
        create_tables = pgtd_write_db_data_kwargs.pop('create_tables', False)
        get_data_source = (
            self.get_enframe_data_source if self.is_running_on_enframe
            else self.get_local_data_source
        )
        for tic_or_pan_dat in tic_or_pan_dats:
            schema_name = self._get_tic_or_pan_dat_schema_name(tic_or_pan_dat)
            data_source, data_source_type = get_data_source(
                schema_name, include_data_source_type=True
            )
            tic_or_pan_dat_schema = getattr(self, schema_name)

            if data_source_type == 'file_or_dir':
                Action._write_data_to_file_system(
                    tic_or_pan_dat_schema, tic_or_pan_dat, data_source
                )
            elif data_source_type == 'db':
                if (
                        'db_url' not in data_source
                        or 'db_schema' not in data_source
                ):
                    missing_param = (
                        'db_url' if 'db_url' not in data_source
                        else 'db_schema'
                    )
                    raise ValueError(
                        f"'{missing_param}' must be set in the database data "
                        'source dict before writing to a database.\n'
                        'Use the set_data_source method to set it for all '
                        'schemas, for the schema being written, or for the '
                        'table being written.'
                    )
                if create_tables:
                    Action._create_tables_in_db(
                        tic_or_pan_dat_schema,
                        (
                            data_source['db_url']
                            if data_source['db_url'] != self.enframe_db_url
                            else self.enframe_connection
                        ),
                        data_source['db_schema']
                    )
                Action._write_data_to_db(
                    tic_or_pan_dat_schema, tic_or_pan_dat,
                    (
                        data_source['db_url']
                        if data_source['db_url'] != self.enframe_db_url
                        else self.enframe_connection
                    ),
                    data_source['db_schema'],
                    **pgtd_write_db_data_kwargs
                )
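
        # minimal usage sketch (hypothetical Action subclass and data source):
        #
        #   action = MyAction()
        #   action.set_data_source(..., db_url="postgresql://user:pw@host/db",
        #                          db_schema="public")
        #   action.write_data(input_dat, create_tables=True)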

    @abstractmethod
    def execute_action(self):
        """
        This method will be called by a user of the action in order
        to execute its functionality. This will be overridden
        by an implementation of the action functionality.
        """

    @property
    def schema_names(self):
        """
        A list of all the public schemas defined by the action.
        Note that these are all the schemas defined as public
        property attributes of the action.
        """
        # guard against infinite recursion: evaluating this property calls
        # inspect.getmembers, which would evaluate the property again
        caller = inspect.stack()[1]  # capture the stack once; it is expensive
        if (
                caller.function == 'getmembers'
                and caller.filename.endswith('inspect.py')
        ):
            return type(self).schema_names
        else:
            return [
                name
                for name, desc in inspect.getmembers(
                    type(self), inspect.isdatadescriptor
                )
                if not name.startswith('_') and name != 'schema_names'
                   and type(desc.fget(self)) in (TicDatFactory, PanDatFactory)
            ]

    @property
    def method_names(self):
        """A list of all the public methods defined by the action"""
        # guard against infinite recursion: evaluating this property calls
        # inspect.getmembers, which would evaluate the property again
        caller = inspect.stack()[1]  # capture the stack once; it is expensive
        if (
                caller.function == 'getmembers'
                and caller.filename.endswith('inspect.py')
        ):
            return type(self).method_names
        else:
            return [
                name
                for name, _ in inspect.getmembers(self, inspect.ismethod)
                if not name.startswith('_')
            ]
Exemple #17
0
# the start of this snippet was truncated; a plausible reconstruction, assuming
# the standard ticdat diet schema, of the lines the trailing keyword arguments
# belonged to:
diet_schema = TicDatFactory(
    categories=[["Name"], ["Min Nutrition", "Max Nutrition"]],
    foods=[["Name"], ["Cost"]],
    nutrition_quantities=[["Food", "Category"], ["Quantity"]])
diet_schema.set_data_type("categories", "Min Nutrition",
                          min=0, max=float("inf"),
                          inclusive_min=True,
                          inclusive_max=False)
diet_dat = diet_schema.TicDat(
    **{
        'foods': [['hamburger', 2.49], ['salad', 2.49], ['hot dog', 1.5],
                  ['fries', 1.89], ['macaroni', 2.09], ['chicken', 2.89],
                  ['milk', 0.89], ['ice cream', 1.59], ['pizza', 1.99]],
        'categories': [['protein', 91, float("inf")], [
            'calories', 1800, 2200.0
        ], ['fat', 0, 65.0], ['sodium', 0, 1779.0]],
        'nutrition_quantities':
        [['ice cream', 'protein', 8], ['ice cream', 'fat', 10],
         ['fries', 'sodium', 270], ['fries', 'calories', 380],
         ['hamburger', 'fat', 26], ['macaroni', 'sodium', 930],
         ['hot dog', 'sodium', 1800], ['chicken', 'sodium', 1190],
         ['salad', 'calories', 320], ['ice cream', 'calories', 330],
         ['milk', 'sodium', 125], ['salad', 'sodium', 1230],
         ['pizza', 'sodium', 820], ['pizza', 'protein', 15],
         ['pizza', 'calories', 320], ['hamburger', 'calories', 410],
         ['milk', 'fat', 2.5], ['salad', 'protein', 31],
         ['milk', 'protein', 8], ['macaroni', 'fat', 10], ['salad', 'fat', 12],
         ['hot dog', 'fat', 32], ['chicken', 'fat', 10],
         ['chicken', 'protein', 32], ['fries', 'protein', 4],
         ['pizza', 'fat', 12], ['milk', 'calories', 100],
         ['ice cream', 'sodium', 180], ['chicken', 'calories', 420],
         ['hamburger', 'sodium', 730], ['macaroni', 'calories', 320],
         ['fries', 'fat', 19], ['hot dog', 'calories', 560],
         ['hot dog', 'protein', 20], ['macaroni', 'protein', 12],
         ['hamburger', 'protein', 24]]
    })
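
# a quick sanity check one might add here (a sketch; find_data_type_failures is
# the standard ticdat validation hook):
assert not diet_schema.find_data_type_failures(diet_dat)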

Exemple #18
0
def _testFantop(modelType, sqlFile):
    dataFactory = TicDatFactory(parameters=[["Key"], ["Value"]],
                                players=[['Player Name'],
                                         [
                                             'Position',
                                             'Average Draft Position',
                                             'Expected Points', 'Draft Status'
                                         ]],
                                roster_requirements=[['Position'],
                                                     [
                                                         'Min Num Starters',
                                                         'Max Num Starters',
                                                         'Min Num Reserve',
                                                         'Max Num Reserve',
                                                         'Flex Status'
                                                     ]],
                                my_draft_positions=[['Draft Position'], []])

    # add foreign key constraints (optional, but helps with preventing garbage-in, garbage-out)
    dataFactory.add_foreign_key("players", "roster_requirements",
                                ['Position', 'Position'])

    # set data types (optional, but helps with preventing garbage-in, garbage-out)
    dataFactory.set_data_type("parameters",
                              "Key",
                              number_allowed=False,
                              strings_allowed=[
                                  "Starter Weight", "Reserve Weight",
                                  "Maximum Number of Flex Starters"
                              ])
    dataFactory.set_data_type("parameters",
                              "Value",
                              min=0,
                              max=float("inf"),
                              inclusive_min=True,
                              inclusive_max=False)
    dataFactory.set_data_type("players",
                              "Average Draft Position",
                              min=0,
                              max=float("inf"),
                              inclusive_min=False,
                              inclusive_max=False)
    dataFactory.set_data_type("players",
                              "Expected Points",
                              min=-float("inf"),
                              max=float("inf"),
                              inclusive_min=False,
                              inclusive_max=False)
    dataFactory.set_data_type("players",
                              "Draft Status",
                              strings_allowed=[
                                  "Un-drafted", "Drafted By Me",
                                  "Drafted By Someone Else"
                              ])
    for fld in ("Min Num Starters", "Min Num Reserve", "Max Num Reserve"):
        dataFactory.set_data_type("roster_requirements",
                                  fld,
                                  min=0,
                                  max=float("inf"),
                                  inclusive_min=True,
                                  inclusive_max=False,
                                  must_be_int=True)
    dataFactory.set_data_type("roster_requirements",
                              "Max Num Starters",
                              min=0,
                              max=float("inf"),
                              inclusive_min=False,
                              inclusive_max=True,
                              must_be_int=True)
    dataFactory.set_data_type(
        "roster_requirements",
        "Flex Status",
        number_allowed=False,
        strings_allowed=["Flex Eligible", "Flex Ineligible"])
    dataFactory.set_data_type("my_draft_positions",
                              "Draft Position",
                              min=0,
                              max=float("inf"),
                              inclusive_min=False,
                              inclusive_max=False,
                              must_be_int=True)

    solutionFactory = TicDatFactory(my_draft=[[
        'Player Name'
    ], [
        'Draft Position', 'Position', 'Planned Or Actual', 'Starter Or Reserve'
    ]])

    dat = dataFactory.sql.create_tic_dat_from_sql(os.path.join(
        _codeDir(), sqlFile),
                                                  freeze_it=True)

    assert dataFactory.good_tic_dat_object(dat)
    assert not dataFactory.find_foreign_key_failures(dat)
    assert not dataFactory.find_data_type_failures(dat)

    expected_draft_position = {}
    # for our purposes, it's fine to assume all those drafted by someone else
    # are drafted prior to any players drafted by me
    for player_name in sorted(
            dat.players,
            key=lambda _p: {
                "Un-drafted": dat.players[_p]["Average Draft Position"],
                "Drafted By Me": -1,
                "Drafted By Someone Else": -2
            }[dat.players[_p]["Draft Status"]]):
        expected_draft_position[player_name] = len(expected_draft_position) + 1
    assert max(expected_draft_position.values()) == len(
        set(expected_draft_position.values())) == len(dat.players)
    assert min(expected_draft_position.values()) == 1

    already_drafted_by_me = {
        player_name
        for player_name, row in dat.players.items()
        if row["Draft Status"] == "Drafted By Me"
    }
    can_be_drafted_by_me = {
        player_name
        for player_name, row in dat.players.items()
        if row["Draft Status"] != "Drafted By Someone Else"
    }

    m = Model(modelType, 'fantop')
    my_starters = {
        player_name: m.add_var(type="binary", name="starter_%s" % player_name)
        for player_name in can_be_drafted_by_me
    }
    my_reserves = {
        player_name: m.add_var(type="binary", name="reserve_%s" % player_name)
        for player_name in can_be_drafted_by_me
    }

    for player_name in can_be_drafted_by_me:
        if player_name in already_drafted_by_me:
            m.add_constraint(my_starters[player_name] +
                             my_reserves[player_name] == 1,
                             name="already_drafted_%s" % player_name)
        else:
            m.add_constraint(
                my_starters[player_name] + my_reserves[player_name] <= 1,
                name="cant_draft_twice_%s" % player_name)

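    # for the i-th (0-indexed) of my draft positions, at most i of my chosen
    # players can be expected to go before that pick; any more and the plan
    # would require drafting someone who is already off the board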
    for i, draft_position in enumerate(sorted(dat.my_draft_positions)):
        m.add_constraint(m.sum(
            my_starters[player_name] + my_reserves[player_name]
            for player_name in can_be_drafted_by_me
            if expected_draft_position[player_name] < draft_position) <= i,
                         name="at_most_%s_can_be_ahead_of_%s" %
                         (i, draft_position))

    my_draft_size = m.sum(my_starters[player_name] + my_reserves[player_name]
                          for player_name in can_be_drafted_by_me)
    m.add_constraint(my_draft_size >= len(already_drafted_by_me) + 1,
                     name="need_to_extend_by_at_least_one")
    m.add_constraint(my_draft_size <= len(dat.my_draft_positions),
                     name="cant_exceed_draft_total")

    for position, row in dat.roster_requirements.items():
        players = {
            player_name
            for player_name in can_be_drafted_by_me
            if dat.players[player_name]["Position"] == position
        }
        starters = m.sum(my_starters[player_name] for player_name in players)
        reserves = m.sum(my_reserves[player_name] for player_name in players)
        m.add_constraint(starters >= row["Min Num Starters"],
                         name="min_starters_%s" % position)
        m.add_constraint(starters <= row["Max Num Starters"],
                         name="max_starters_%s" % position)
        m.add_constraint(reserves >= row["Min Num Reserve"],
                         name="min_reserve_%s" % position)
        m.add_constraint(reserves <= row["Max Num Reserve"],
                         name="max_reserve_%s" % position)

    if "Maximum Number of Flex Starters" in dat.parameters:
        players = {
            player_name
            for player_name in can_be_drafted_by_me
            if dat.roster_requirements[dat.players[player_name]["Position"]]
            ["Flex Status"] == "Flex Eligible"
        }
        m.add_constraint(
            m.sum(my_starters[player_name] for player_name in players) <=
            dat.parameters["Maximum Number of Flex Starters"]["Value"],
            name="max_flex")

    starter_weight = dat.parameters["Starter Weight"][
        "Value"] if "Starter Weight" in dat.parameters else 1
    reserve_weight = dat.parameters["Reserve Weight"][
        "Value"] if "Reserve Weight" in dat.parameters else 1
    m.set_objective(m.sum(dat.players[player_name]["Expected Points"] *
                          (my_starters[player_name] * starter_weight +
                           my_reserves[player_name] * reserve_weight)
                          for player_name in can_be_drafted_by_me),
                    sense="maximize")

    if not m.optimize():
        return

    sln = solutionFactory.TicDat()

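    # solver binaries come back as floats; treat values within 1e-4 of 1 as selected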
    def almostone(x):
        return abs(m.get_solution_value(x) - 1) < 0.0001

    picked = sorted([
        player_name for player_name in can_be_drafted_by_me
        if almostone(my_starters[player_name])
        or almostone(my_reserves[player_name])
    ],
                    key=lambda _p: expected_draft_position[_p])
    assert len(picked) <= len(dat.my_draft_positions)
    if len(picked) < len(dat.my_draft_positions):
        print(
            "Your model is over-constrained, and thus only a partial draft was possible"
        )

    draft_yield = 0
    for player_name, draft_position in zip(picked,
                                           sorted(dat.my_draft_positions)):
        draft_yield += dat.players[player_name]["Expected Points"] * \
                       (starter_weight if almostone(my_starters[player_name]) else reserve_weight)
        assert draft_position <= expected_draft_position[player_name]
        sln.my_draft[player_name]["Draft Position"] = draft_position
        sln.my_draft[player_name]["Position"] = dat.players[player_name][
            "Position"]
        sln.my_draft[player_name][
            "Planned Or Actual"] = "Actual" if player_name in already_drafted_by_me else "Planned"
        sln.my_draft[player_name]["Starter Or Reserve"] = \
            "Starter" if almostone(my_starters[player_name]) else "Reserve"
    return sln, draft_yield
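
# a minimal invocation sketch (hypothetical solver engine and data file):
#
#   sln, draft_yield = _testFantop("gurobi", "fantop_sample_data.sql")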
Exemple #19
0
                                "optimizer_only_test_input.xlsx",
                                allow_overwrite=True)

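# stitch a combined input schema: reuse tdfSuperBowl's "data" table definition
# (renamed to forecast_sales) and take every other table from spo.input_schema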
sch1, sch2 = [_.schema() for _ in (tdfSuperBowl, spo.input_schema)]

tdfSuperBowl = TicDatFactory(
    forecast_sales=sch1["data"],
    **{k: v
       for k, v in sch2.items() if k != "forecast_sales"})
superBowlDat = tdfSuperBowl.TicDat(
    forecast_sales=superBowlDat.data,
    products={
        '11 Down': 'Clear',
        'AB Root Beer': 'Dark',
        'Alpine Stream': 'Clear',
        'Bright': 'Clear',
        'Crisp Clear': 'Clear',
        'DC Kola': 'Dark',
        'Koala Kola': 'Dark',
        'Mr. Popper': 'Dark',
        'Popsi Kola': 'Dark'
    },
    max_promotions={
        "Clear": 2,
        "Dark": 2
    },
    parameters={"Maximum Total Investment": 750})
tdfSuperBowl.xls.write_file(superBowlDat,
                            "predict_then_optimize_super_bowl.xlsx",
                            allow_overwrite=True)
Exemple #20
0
    def testDateTime(self):
        schema = test_schema + "_datetime"
        tdf = TicDatFactory(table_with_stuffs=[["field one"], ["field two"]],
                            parameters=[["a"], ["b"]])
        tdf.add_parameter("p1", "Dec 15 1970", datetime=True)
        tdf.add_parameter("p2", None, datetime=True, nullable=True)
        tdf.set_data_type("table_with_stuffs", "field one", datetime=True)
        tdf.set_data_type("table_with_stuffs",
                          "field two",
                          datetime=True,
                          nullable=True)

        dat = tdf.TicDat(table_with_stuffs=[[
            dateutil.parser.parse("July 11 1972"), None
        ], [datetime.datetime.now(),
            dateutil.parser.parse("Sept 11 2011")]],
                         parameters=[["p1", "7/11/1911"], ["p2", None]])
        self.assertFalse(
            tdf.find_data_type_failures(dat)
            or tdf.find_data_row_failures(dat))

        tdf.pgsql.write_schema(self.engine, schema)
        tdf.pgsql.write_data(dat, self.engine, schema)
        dat_1 = tdf.pgsql.create_tic_dat(self.engine, schema)
        self.assertFalse(
            tdf._same_data(dat, dat_1, nans_are_same_for_data_rows=True))
        self.assertTrue(
            all(
                len(getattr(dat, t)) == len(getattr(dat_1, t))
                for t in tdf.all_tables))
        self.assertFalse(
            tdf.find_data_type_failures(dat_1)
            or tdf.find_data_row_failures(dat_1))
        self.assertTrue(
            isinstance(dat_1.parameters["p1"]["b"], datetime.datetime))
        self.assertTrue(
            all(
                isinstance(_, datetime.datetime)
                for _ in dat_1.table_with_stuffs))
        self.assertTrue(
            len([_ for _ in dat_1.table_with_stuffs if pd.isnull(_)]) == 0)
        self.assertTrue(
            all(
                isinstance(_, datetime.datetime) or pd.isnull(_)
                for v in dat_1.table_with_stuffs.values() for _ in v.values()))
        self.assertTrue(
            len([
                _ for v in dat_1.table_with_stuffs.values()
                for _ in v.values() if pd.isnull(_)
            ]) == 1)
        pdf = PanDatFactory.create_from_full_schema(
            tdf.schema(include_ancillary_info=True))
        pan_dat = pdf.pgsql.create_pan_dat(self.engine, schema)
        dat_2 = pdf.copy_to_tic_dat(pan_dat)
        # pandas' datetime downcasting can be flaky; normalize the keys to
        # pd.Timestamp so _same_data can line the rows up
        for k in list(dat_2.table_with_stuffs):
            dat_2.table_with_stuffs[pd.Timestamp(
                k)] = dat_2.table_with_stuffs.pop(k)
        self.assertTrue(
            tdf._same_data(dat_1, dat_2, nans_are_same_for_data_rows=True))

        pdf.pgsql.write_data(pan_dat, self.engine, schema)
        dat_3 = pdf.copy_to_tic_dat(
            pdf.pgsql.create_pan_dat(self.engine, schema))
        for k in list(dat_3.table_with_stuffs):
            dat_3.table_with_stuffs[pd.Timestamp(
                k)] = dat_3.table_with_stuffs.pop(k)
        self.assertTrue(
            tdf._same_data(dat_1, dat_3, nans_are_same_for_data_rows=True))
Exemple #21
0
def _pan_dat_maker_from_dict(schema, tic_dat_dict):
    tdf = TicDatFactory(**schema)
    tic_dat = tdf.TicDat(**tic_dat_dict)
    return pan_dat_maker(schema, tic_dat)
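
# usage sketch (hypothetical one-table schema; pan_dat_maker is assumed to be
# defined earlier in this test module):
#
#   pan_dat = _pan_dat_maker_from_dict({"t": [["pk"], ["f"]]}, {"t": [[1, 2]]})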