class ParticipantVarsFromREST(AnyModel):
    participant_label = Column(st.String(255))
    room_name = Column(st.String(255))
    _json_data = Column(st.Text)

    @property
    def vars(self):
        return json.loads(self._json_data)

    @vars.setter
    def vars(self, value):
        self._json_data = json.dumps(value)
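# A minimal usage sketch (not from the original source): the `vars` property
# round-trips an arbitrary dict through the `_json_data` text column, so callers
# can treat the row as if it stored a dict directly. Assumes AnyModel is a
# standard declarative base that accepts keyword construction.
entry = ParticipantVarsFromREST(participant_label="p1", room_name="econ_lab")
entry.vars = {"treatment": "high", "bonus": 2.5}
assert entry.vars == {"treatment": "high", "bonus": 2.5}
assert entry._json_data == '{"treatment": "high", "bonus": 2.5}'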
class ChatMessage(AnyModel):
    class Meta:
        index_together = ['channel', 'timestamp']

    # the name "channel" here is unrelated to Django channels
    channel = Column(st.String(255))
    participant_id = Column(st.Integer, ForeignKey('otree_participant.id'))
    participant = relationship("Participant")
    nickname = Column(st.String(255))

    # call it 'body' instead of 'message' or 'content' because those terms
    # are already used by channels
    body = Column(st.Text)
    timestamp = Column(st.Float, default=time.time)
class TaskQueueMessage(AnyModel):
    method = Column(st.String(50))
    kwargs_json = Column(st.Text)
    epoch_time = Column(st.Integer)

    def kwargs(self) -> dict:
        return json.loads(self.kwargs_json)
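# A minimal sketch (assumption, not from the original source) of the JSON
# round-trip performed by kwargs(): kwargs_json holds a serialized dict that
# is decoded on demand.
msg = TaskQueueMessage(
    method="send_reminder_email",
    kwargs_json=json.dumps({"participant_code": "abc123"}),
    epoch_time=0,
)
assert msg.kwargs() == {"participant_code": "abc123"}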
def test_success(self):
    column = Column("name", sqltypes.String(50))
    name, django_field = fields.to_django_field(TestTable, column)
    assert isinstance(django_field, models.CharField)
    assert django_field.max_length == 50
    assert name == "name"
def from_sqlalchemy_table(cls, table: sqlalchemy.Table) -> "MockTableSchema":
    data_types = {}
    for column in table.columns:
        if isinstance(column.type, sqltypes.Enum):
            data_types[column.name] = sqltypes.String(255)
        else:
            data_types[column.name] = column.type
    return cls(data_types)
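# A hedged usage sketch: Enum columns are widened to String(255) so the mock
# schema does not need to know the enum members. The table below is
# hypothetical, and this assumes the method is registered as a classmethod on
# MockTableSchema (the decorator is not shown in the snippet above).
metadata = sqlalchemy.MetaData()
people = sqlalchemy.Table(
    "people",
    metadata,
    sqlalchemy.Column("id", sqltypes.Integer()),
    sqlalchemy.Column("status", sqltypes.Enum("ACTIVE", "INACTIVE", name="status")),
)
mock_schema = MockTableSchema.from_sqlalchemy_table(people)
# mock_schema.data_types["status"] is sqltypes.String(255);
# mock_schema.data_types["id"] keeps the original Integer type.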
def visit_column_comment(element, compiler, **kw):
    ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
    comment = compiler.sql_compiler.render_literal_value(
        (element.comment if element.comment is not None else ""),
        sqltypes.String(),
    )
    return ddl.format(
        table_name=element.table_name,
        column_name=element.column_name,
        comment=comment,
    )
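# A minimal sketch of the emitted DDL (assumptions: `element` only needs the
# three attributes used above, and the compiler renders string literals with
# single quotes). For a comment on person.notes, the visitor would return
# roughly:
#   COMMENT ON COLUMN person.notes IS 'Free-text clinical notes'
# and an empty string literal is rendered when element.comment is None.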
def from_big_query_schema_fields(
        cls, bq_schema: List[bigquery.SchemaField]) -> "MockTableSchema":
    data_types = {}
    for field in bq_schema:
        field_type = bigquery.enums.SqlTypeNames(field.field_type)
        if field_type is bigquery.enums.SqlTypeNames.STRING:
            data_type = sqltypes.String(255)
        elif field_type is bigquery.enums.SqlTypeNames.INTEGER:
            data_type = sqltypes.Integer
        elif field_type is bigquery.enums.SqlTypeNames.FLOAT:
            data_type = sqltypes.Float
        elif field_type is bigquery.enums.SqlTypeNames.DATE:
            data_type = sqltypes.Date
        elif field_type is bigquery.enums.SqlTypeNames.BOOLEAN:
            data_type = sqltypes.Boolean
        else:
            raise ValueError(
                f"Unhandled big query field type '{field_type}' "
                f"for attribute '{field.name}'"
            )
        data_types[field.name] = data_type
    return cls(data_types)
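# A hedged usage sketch: each BigQuery field type maps to an SQLAlchemy type,
# with STRING capped at 255 characters for the mock schema. SchemaField uses
# only the (name, field_type) form of the google-cloud-bigquery API, and this
# again assumes the method is exposed as a classmethod on MockTableSchema.
fields = [
    bigquery.SchemaField("state_code", "STRING"),
    bigquery.SchemaField("value", "INTEGER"),
    bigquery.SchemaField("date_reported", "DATE"),
]
mock_schema = MockTableSchema.from_big_query_schema_fields(fields)
# -> {"state_code": String(255), "value": Integer, "date_reported": Date}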
class RoomToSession(AnyModel, MixinSessionFK):
    room_name = Column(st.String(255), unique=True)
def convert_sqla_type_for_dialect(
        coltype: TypeEngine,
        dialect: Dialect,
        strip_collation: bool = True,
        convert_mssql_timestamp: bool = True,
        expand_for_scrubbing: bool = False) -> TypeEngine:
    """
    Converts an SQLAlchemy column type from one SQL dialect to another.

    Args:
        coltype: SQLAlchemy column type in the source dialect

        dialect: destination :class:`Dialect`

        strip_collation: remove any ``COLLATION`` information?

        convert_mssql_timestamp: since you cannot write to a SQL Server
            ``TIMESTAMP`` field, setting this option to ``True`` (the default)
            converts such types to something equivalent but writable.

        expand_for_scrubbing: The purpose of expand_for_scrubbing is that, for
            example, a ``VARCHAR(200)`` field containing one or more instances
            of ``Jones``, where ``Jones`` is to be replaced with ``[XXXXXX]``,
            will get longer (by an unpredictable amount). So, better to expand
            to unlimited length.

    Returns:
        an SQLAlchemy column type instance, in the destination dialect
    """
    assert coltype is not None

    # noinspection PyUnresolvedReferences
    to_mysql = dialect.name == SqlaDialectName.MYSQL
    # noinspection PyUnresolvedReferences
    to_mssql = dialect.name == SqlaDialectName.MSSQL
    typeclass = type(coltype)

    # -------------------------------------------------------------------------
    # Text
    # -------------------------------------------------------------------------
    if isinstance(coltype, sqltypes.Enum):
        return sqltypes.String(length=coltype.length)
    if isinstance(coltype, sqltypes.UnicodeText):
        # Unbounded Unicode text.
        # Includes derived classes such as mssql.base.NTEXT.
        return sqltypes.UnicodeText()
    if isinstance(coltype, sqltypes.Text):
        # Unbounded text, more generally. (UnicodeText inherits from Text.)
        # Includes sqltypes.TEXT.
        return sqltypes.Text()
    # Everything inheriting from String has a length property, but can be None.
    # There are types that can be unlimited in SQL Server, e.g. VARCHAR(MAX)
    # and NVARCHAR(MAX), that MySQL needs a length for. (Failure to convert
    # gives e.g.: 'NVARCHAR requires a length on dialect mysql'.)
    if isinstance(coltype, sqltypes.Unicode):
        # Includes NVARCHAR(MAX) in SQL -> NVARCHAR() in SQLAlchemy.
        if (coltype.length is None and to_mysql) or expand_for_scrubbing:
            return sqltypes.UnicodeText()
    # The most general case; will pick up any other string types.
    if isinstance(coltype, sqltypes.String):
        # Includes VARCHAR(MAX) in SQL -> VARCHAR() in SQLAlchemy
        if (coltype.length is None and to_mysql) or expand_for_scrubbing:
            return sqltypes.Text()
        if strip_collation:
            return remove_collation(coltype)
        return coltype

    # -------------------------------------------------------------------------
    # Binary
    # -------------------------------------------------------------------------

    # -------------------------------------------------------------------------
    # BIT
    # -------------------------------------------------------------------------
    if typeclass == mssql.base.BIT and to_mysql:
        # MySQL BIT objects have a length attribute.
        return mysql.base.BIT()

    # -------------------------------------------------------------------------
    # TIMESTAMP
    # -------------------------------------------------------------------------
    is_mssql_timestamp = isinstance(coltype, MSSQL_TIMESTAMP)
    if is_mssql_timestamp and to_mssql and convert_mssql_timestamp:
        # You cannot write explicitly to a TIMESTAMP field in SQL Server; it's
        # used for autogenerated values only.
        # - http://stackoverflow.com/questions/10262426/sql-server-cannot-insert-an-explicit-value-into-a-timestamp-column  # noqa
        # - https://social.msdn.microsoft.com/Forums/sqlserver/en-US/5167204b-ef32-4662-8e01-00c9f0f362c2/how-to-tranfer-a-column-with-timestamp-datatype?forum=transactsql  # noqa
        #   ... suggesting BINARY(8) to store the value.
        # MySQL is more helpful:
        # - http://stackoverflow.com/questions/409286/should-i-use-field-datetime-or-timestamp  # noqa
        return mssql.base.BINARY(8)

    # -------------------------------------------------------------------------
    # Some other type
    # -------------------------------------------------------------------------
    return coltype
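# A hedged usage sketch (not from the original source): converting an
# unbounded NVARCHAR (SQL Server's NVARCHAR(MAX)) for a MySQL destination.
# Assumes SqlaDialectName.MYSQL is simply the string "mysql", matching
# SQLAlchemy's dialect.name, and that `mssql`/`mysql` are the SQLAlchemy
# dialect modules already referenced above.
converted = convert_sqla_type_for_dialect(
    coltype=mssql.NVARCHAR(),            # length is None, i.e. NVARCHAR(MAX)
    dialect=mysql.base.MySQLDialect(),   # destination; dialect.name == "mysql"
)
# MySQL cannot emit NVARCHAR without a length, so the function falls back to
# unbounded Unicode text.
assert isinstance(converted, sqltypes.UnicodeText)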
class Company(Base):
    __tablename__ = 'company'

    id = Column(t.Integer, primary_key=True)
    full_name = Column(t.String, nullable=False)
    abbv = Column(t.String(4), nullable=False)
    # rich = Column(t.Boolean, default=False)
    stock_price = Column(t.Float)
    price_diff = Column(t.Float, default=0)
    months = Column(t.Integer, default=0)  # age
    event_increase = Column(t.Integer, default=0)
    event_months_remaining = Column(t.Integer, default=0)
    increase_chance = Column(t.Integer, default=50)  # percentage
    max_increase = Column(t.Float, default=0.3)
    max_decrease = Column(t.Float, default=0.35)
    bankrupt = Column(t.Boolean, default=False)

    shares = relationship("Shares", backref="company", passive_deletes=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def iterate(self):
        if self.bankrupt:
            return
        if self.event_increase is None:
            self.event_increase = 0
        if random.randrange(1, 100) < self.increase_chance + self.event_increase:
            amount = random.uniform(1, 1 + self.max_increase)
        else:
            amount = random.uniform(1 - self.max_decrease, 1)
        if self.event_months_remaining is not None:
            self.event_months_remaining -= 1
            if self.event_months_remaining <= 0:
                self.event_increase = 0
        initial_price = self.stock_price
        self.stock_price = self.stock_price * amount
        self.price_diff = initial_price - self.stock_price
        self.months += 1
        if self.stock_price <= 0.5:
            self.bankrupt = True
            return
        return self.price_diff

    @classmethod
    def create(cls, starting_price, name=None, **kwargs):
        if name is None:
            name = ["dflt", "default"]
        return cls(
            stock_price=starting_price,
            abbv=name[0],
            full_name=name[1],
            **kwargs,
        )

    @classmethod
    def find_by_abbreviation(cls, abbreviation: str, session):
        return session.query(cls).filter_by(abbv=abbreviation).first()

    @hybrid_property
    def stocks_bought(self):
        # noinspection PyTypeChecker
        return sum(share.amount for share in self.shares)

    @stocks_bought.expression
    def stocks_bought(self):
        return select([
            func.sum(Shares.amount)
        ]).where(Shares.company_id == self.id).label("stocks_bought")

    @property
    def announcement_description(self):
        return (
            f"{self.abbv.upper()}"
            f"[{self.stock_price:,.1f}"
            f"{-(self.price_diff / (self.stock_price + self.price_diff) * 100):+.1f}%]"
        )

    def __str__(self):
        years = int(self.months / 12)
        months = self.months % 12
        return f"Name: '{self.abbv}' aka '{self.full_name}' | stock_price: {self.stock_price:,.2f} | " \
               f"price change: {-(self.price_diff/(self.stock_price+self.price_diff)*100):+.1f}% | " \
               f"lifespan: {years} {'years' if not years == 1 else 'year'} " \
               f"and {months} {'months' if not months == 1 else 'month'} | Stocks Bought: {self.stocks_bought}"

    def __repr__(self):
        return self._repr(
            **self._getattrs("id", "abbv", "stocks_bought", "stock_price"),
        )
def test_select_star(mocker: MockFixture, app_context: AppContext) -> None:
    """
    Test the ``select_star`` method.

    The method removes pseudo-columns from structures inside arrays. While
    these pseudo-columns show up as "columns" for metadata reasons, we can't
    select them in the query, as opposed to fields from non-array structures.
    """
    from superset.db_engine_specs.bigquery import BigQueryEngineSpec

    cols = [
        {
            "name": "trailer",
            "type": sqltypes.ARRAY(sqltypes.JSON()),
            "nullable": True,
            "comment": None,
            "default": None,
            "precision": None,
            "scale": None,
            "max_length": None,
        },
        {
            "name": "trailer.key",
            "type": sqltypes.String(),
            "nullable": True,
            "comment": None,
            "default": None,
            "precision": None,
            "scale": None,
            "max_length": None,
        },
        {
            "name": "trailer.value",
            "type": sqltypes.String(),
            "nullable": True,
            "comment": None,
            "default": None,
            "precision": None,
            "scale": None,
            "max_length": None,
        },
        {
            "name": "trailer.email",
            "type": sqltypes.String(),
            "nullable": True,
            "comment": None,
            "default": None,
            "precision": None,
            "scale": None,
            "max_length": None,
        },
    ]

    # mock the database so we can compile the query
    database = mocker.MagicMock()
    database.compile_sqla_query = lambda query: str(
        query.compile(dialect=BigQueryDialect()))

    engine = mocker.MagicMock()
    engine.dialect = BigQueryDialect()

    sql = BigQueryEngineSpec.select_star(
        database=database,
        table_name="my_table",
        engine=engine,
        schema=None,
        limit=100,
        show_cols=True,
        indent=True,
        latest_partition=False,
        cols=cols,
    )
    assert (
        sql
        == """SELECT `trailer` AS `trailer`
FROM `my_table`
LIMIT :param_1"""
    )
class Participant(otree.database.SSPPGModel, MixinVars):
    __tablename__ = 'otree_participant'

    session_id = Column(st.Integer, ForeignKey('otree_session.id'))
    session = relationship("Session", back_populates="pp_set")

    label = Column(st.String(50), nullable=True)
    id_in_session = Column(st.Integer, nullable=True)
    payoff = Column(CurrencyType, default=0)

    time_started = Column(st.DateTime, nullable=True)
    mturk_assignment_id = Column(st.String(50), nullable=True)
    mturk_worker_id = Column(st.String(50), nullable=True)

    _index_in_pages = Column(st.Integer, default=0, index=True)

    def _numeric_label(self):
        """the human-readable version."""
        return 'P{}'.format(self.id_in_session)

    _monitor_note = Column(st.String(300), nullable=True)

    code = Column(
        st.String(16),
        default=random_chars_8,
        # set non-nullable, until we make our CharField non-nullable
        nullable=False,
        # unique implies DB index
        unique=True,
    )

    # useful when we don't want to load the whole session just to get the code
    _session_code = Column(st.String(16))

    visited = Column(st.Boolean, default=False, index=True)

    # stores when the page was first visited
    _last_page_timestamp = Column(st.Integer, nullable=True)
    _last_request_timestamp = Column(st.Integer, nullable=True)

    is_on_wait_page = Column(st.Boolean, default=False)

    # these are both for the admin
    # In the changelist, simply call these "page" and "app"
    _current_page_name = Column(st.String(200), nullable=True)
    _current_app_name = Column(st.String(200), nullable=True)

    # only to be displayed in the admin participants changelist
    _round_number = Column(st.Integer, nullable=True)

    _current_form_page_url = Column(st.String(500))

    _max_page_index = Column(st.Integer)

    _is_bot = Column(st.Boolean, default=False)

    # can't start with an underscore because used in template
    # can't end with underscore because it's a django field (fields.E001)
    is_browser_bot = Column(st.Boolean, default=False)

    _timeout_expiration_time = otree.database.FloatField()
    _timeout_page_index = Column(st.Integer)

    _gbat_is_waiting = Column(st.Boolean, default=False)
    _gbat_page_index = Column(st.Integer)
    _gbat_grouped = Column(st.Boolean)

    def _current_page(self):
        # don't put 'pages' because that causes wrapping which takes more space
        # since it's longer than the header
        return f'{self._index_in_pages}/{self._max_page_index}'

    # because variables used in templates can't start with an underscore
    def current_page_(self):
        return self._current_page()

    def get_players(self):
        """Used to calculate payoffs"""
        lst = []
        app_sequence = self.session.config['app_sequence']
        for app in app_sequence:
            models_module = otree.common.get_models_module(app)
            players = models_module.Player.objects_filter(participant=self).order_by(
                'round_number'
            )
            lst.extend(list(players))
        return lst

    def _url_i_should_be_on(self):
        if not self.visited:
            return self._start_url()
        if self._index_in_pages <= self._max_page_index:
            return url_i_should_be_on(
                self.code, self._session_code, self._index_in_pages
            )
        return '/OutOfRangeNotification/' + self.code

    def _start_url(self):
        return otree.common.participant_start_url(self.code)

    def payoff_in_real_world_currency(self):
        return self.payoff.to_real_world_currency(self.session)

    def payoff_plus_participation_fee(self):
        return self.session._get_payoff_plus_participation_fee(self.payoff)

    def _get_current_player(self):
        lookup = get_page_lookup(self._session_code, self._index_in_pages)
        models_module = otree.common.get_models_module(lookup.app_name)
        PlayerClass = getattr(models_module, 'Player')
        return PlayerClass.objects_get(
            participant=self, round_number=lookup.round_number
        )

    def initialize(self, participant_label):
        """in a separate function so that we can call it individually,
        e.g. from advance_last_place_participants"""
        pp = self
        if pp._index_in_pages == 0:
            pp._index_in_pages = 1
            pp.visited = True

            # participant.label might already have been set
            pp.label = pp.label or participant_label

            # default to Central European Time
            pp.time_started = datetime.datetime.utcnow() + datetime.timedelta(hours=1)
            pp._last_page_timestamp = time()
            player = pp._get_current_player()
            player.start()
def StringField(**kwargs):
    return wrap_column(
        st.String(length=kwargs.get('max_length', 10000)),
        **kwargs,
    )
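# A minimal sketch of the length handling (assumption: wrap_column forwards
# the type to a Column):
#   StringField(max_length=50)  -> st.String(length=50)
#   StringField()               -> st.String(length=10000), the fallback length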
class Session(otree.database.SSPPGModel, MixinVars):
    __tablename__ = 'otree_session'

    config: dict = Column(otree.database._PickleField, default=dict)

    pp_set = relationship("Participant", back_populates="session", lazy='dynamic')

    # label of this session instance
    label = Column(st.String, nullable=True)

    code = Column(
        st.String(16),
        default=random_chars_8,
        nullable=False,
        unique=True,
    )

    mturk_HITId = Column(st.String(300), nullable=True)
    mturk_HITGroupId = Column(st.String(300), nullable=True)

    is_mturk = Column(st.Boolean, default=False)

    def mturk_num_workers(self):
        assert self.is_mturk
        return int(self.num_participants / settings.MTURK_NUM_PARTICIPANTS_MULTIPLE)

    mturk_use_sandbox = Column(st.Boolean, default=True)

    # use Float instead of DateTime because DateTime
    # is a pain to work with (e.g. naive vs aware datetime objects)
    # and there is no need here for DateTime
    mturk_expiration = Column(st.Float, nullable=True)
    mturk_qual_id = Column(st.String(50), default='')

    archived = Column(
        st.Boolean,
        default=False,
        index=True,
    )

    comment = Column(st.Text)

    _anonymous_code = Column(
        st.String(20),
        default=random_chars_10,
        nullable=False,
        index=True,
    )

    is_demo = Column(st.Boolean, default=False)

    _admin_report_app_names = Column(st.Text, default='')
    _admin_report_num_rounds = Column(st.String(255), default='')

    num_participants = Column(st.Integer)

    def __unicode__(self):
        return self.code

    @property
    def participation_fee(self):
        '''This method is deprecated from public API,
        but still useful internally (like data export)'''
        return self.config['participation_fee']

    @property
    def real_world_currency_per_point(self):
        '''This method is deprecated from public API,
        but still useful internally (like data export)'''
        return self.config['real_world_currency_per_point']

    @property
    def use_browser_bots(self):
        return self.config.get('use_browser_bots', False)

    def mock_exogenous_data(self):
        '''
        It's for any exogenous data:
        - participant labels (which are not passed in through REST API)
        - participant vars
        - session vars (if we enable that)
        '''
        if self.config.get('mock_exogenous_data'):
            import shared_out as user_utils

            user_utils.mock_exogenous_data(self)

    def get_subsessions(self):
        lst = []
        app_sequence = self.config['app_sequence']
        for app in app_sequence:
            models_module = otree.common.get_models_module(app)
            subsessions = models_module.Subsession.objects_filter(
                session=self).order_by('round_number')
            lst.extend(list(subsessions))
        return lst

    def get_participants(self):
        return list(self.pp_set.order_by('id_in_session'))

    def mturk_worker_url(self):
        # different HITs get the same preview page, because they are lumped
        # into the same "hit group". This is not documented, but it seems HITs
        # are lumped if a certain subset of properties are the same:
        # https://forums.aws.amazon.com/message.jspa?messageID=597622#597622
        # this seems like the correct design; the only case where this will
        # not work is if the HIT was deleted from the server, but in that case,
        # the HIT itself should be canceled.

        # 2018-06-04:
        # the format seems to have changed to this:
        # https://worker.mturk.com/projects/{group_id}/tasks?ref=w_pl_prvw
        # but the old format still works.
        # it seems I can't replace groupId by hitID, which i would like to do
        # because it's more precise.
        subdomain = "workersandbox" if self.mturk_use_sandbox else 'www'
        return "https://{}.mturk.com/mturk/preview?groupId={}".format(
            subdomain, self.mturk_HITGroupId)

    def mturk_is_expired(self):
        # mturk_expiration is stored as a Unix timestamp (float),
        # so compare it against time.time().
        return self.mturk_expiration and self.mturk_expiration < time.time()

    def mturk_is_active(self):
        return self.mturk_HITId and not self.mturk_is_expired()

    def advance_last_place_participants(self):
        """the problem with using the test client to make get/post requests is
        (1) this request already has the global asyncio.lock
        (2) there are apparently some issues with async/await and event loops.
        """
        from otree.lookup import get_page_lookup
        from otree.api import WaitPage, Page

        participants = self.get_participants()

        # in case some participants haven't started
        unvisited_participants = False
        for p in participants:
            if p._index_in_pages == 0:
                unvisited_participants = True
                p.initialize(None)

        if unvisited_participants:
            # that's it -- just visit the start URL, advancing by 1
            return

        last_place_page_index = min([p._index_in_pages for p in participants])
        last_place_participants = [
            p for p in participants if p._index_in_pages == last_place_page_index
        ]

        for p in last_place_participants:
            page_index = p._index_in_pages
            if page_index >= p._max_page_index:
                return

            page = get_page_lookup(
                self.code, page_index
            ).page_class.instantiate_without_request()
            page.set_attributes(p)

            if isinstance(page, Page):
                from starlette.datastructures import FormData

                page._is_frozen = False
                page._form_data = FormData({
                    otree.constants.admin_secret_code: ADMIN_SECRET_CODE,
                    otree.constants.timeout_happened: '1',
                })
                # TODO: should we also call .get() so that
                # _update_monitor_table will also get run?
                resp = page.post()
                if resp.status_code >= 400:
                    msg = (
                        f'Submitting page {p._current_form_page_url} failed, '
                        f'returned HTTP status code {resp.status_code}. '
                        'Check the logs'
                    )
                    raise AssertionError(msg)
            else:
                # it's possible that the slowest user is on a wait page,
                # especially if their browser is closed.
                # because they were waiting for another user who then
                # advanced past the wait page, but they were never
                # advanced themselves.
                resp = page.inner_dispatch(request=None)

            # do the auto-advancing here,
            # rather than in increment_index_in_pages,
            # because it's only needed here.
            otree.channels.utils.sync_group_send(
                group=auto_advance_group(p.code), data={'auto_advanced': True}
            )

    def get_room(self):
        from otree.room import ROOM_DICT

        try:
            room_name = RoomToSession.objects_get(session=self).room_name
            return ROOM_DICT[room_name]
        except NoResultFound:
            return None

    def _get_payoff_plus_participation_fee(self, payoff):
        '''For a participant who has the given payoff,
        return their payoff_plus_participation_fee.
        Useful to define it here, for data export.
        '''
        return self.config[
            'participation_fee'] + payoff.to_real_world_currency(self)

    def _set_admin_report_app_names(self):
        admin_report_app_names = []
        num_rounds_list = []

        for app_name in self.config['app_sequence']:
            models_module = otree.common.get_models_module(app_name)
            app_label = get_app_label_from_name(app_name)
            try:
                get_template_name_if_exists([
                    f'{app_label}/admin_report.html',
                    f'{app_label}/AdminReport.html',
                ])
            except TemplateLoadError:
                pass
            else:
                admin_report_app_names.append(app_name)
                num_rounds_list.append(models_module.Constants.num_rounds)

        self._admin_report_app_names = ';'.join(admin_report_app_names)
        self._admin_report_num_rounds = ';'.join(str(n) for n in num_rounds_list)

    def _admin_report_apps(self):
        return self._admin_report_app_names.split(';')

    def _admin_report_num_rounds_list(self):
        return [int(num) for num in self._admin_report_num_rounds.split(';')]

    def has_admin_report(self):
        return bool(self._admin_report_app_names)
class CorrectionsOutputViewTest(BaseViewTest):
    """Tests the Corrections output view."""

    INPUT_SCHEMA = MockTableSchema({
        **METRIC_CALCULATOR_SCHEMA.data_types,
        "compare_date_partition": sqltypes.Date(),
        "compare_value": sqltypes.Numeric(),
        "state_code": sqltypes.String(255),
    })

    def test_recent_population(self) -> None:
        """Tests the basic use case of calculating population"""
        # Arrange
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="source_materialized",
            mock_schema=MockTableSchema.from_sqlalchemy_table(
                schema.Source.__table__),
            mock_data=pd.DataFrame(
                [[1, "XX"], [2, "YY"], [3, "ZZ"]], columns=["id", "name"]),
        )
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="report_materialized",
            mock_schema=MockTableSchema.from_sqlalchemy_table(
                schema.Report.__table__),
            mock_data=pd.DataFrame(
                [
                    [1, 1, "_", "All", "2021-01-01", "xx.gov", "MANUALLY_ENTERED", "John"],
                    [2, 2, "_", "All", "2021-01-02", "yy.gov", "MANUALLY_ENTERED", "Jane"],
                    [3, 3, "_", "All", "2021-01-02", "zz.gov", "MANUALLY_ENTERED", "Jude"],
                ],
                columns=[
                    "id", "source_id", "type", "instance", "publish_date",
                    "url", "acquisition_method", "acquired_by",
                ],
            ),
        )
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="metric_calculator",
            mock_schema=self.INPUT_SCHEMA,
            mock_data=pd.DataFrame(
                [
                    row(1, "2021-01-01", "2020-11-30", (FakeState("US_XX"),),
                        ["A", "B", "A"], 3000, measurement_type="INSTANT")
                    + (None, None, "US_XX"),
                    row(1, "2021-01-01", "2020-12-31", (FakeState("US_XX"),),
                        ["B", "B", "C"], 4000, measurement_type="INSTANT")
                    + (None, None, "US_XX"),
                    row(2, "2021-01-01", "2020-11-30", (FakeState("US_YY"),),
                        ["A", "B", "A"], 1000, measurement_type="INSTANT")
                    + (None, None, "US_YY"),
                    row(2, "2021-01-01", "2020-12-31", (FakeState("US_YY"),),
                        ["B", "B", "C"], 1020, measurement_type="INSTANT")
                    + (None, None, "US_YY"),
                    row(3, "2021-01-01", "2020-11-30", (FakeState("US_ZZ"),),
                        ["A", "B", "A"], 400, measurement_type="INSTANT")
                    + (None, None, "US_ZZ"),
                    row(3, "2021-01-01", "2020-12-31", (FakeState("US_ZZ"),),
                        ["C", "C", "B"], 500, measurement_type="INSTANT")
                    + (None, None, "US_ZZ"),
                ],
                columns=self.INPUT_SCHEMA.data_types.keys(),
            ),
        )

        # Act
        dimensions = ["state_code", "metric", "year", "month"]
        prison_population_metric = metric_calculator.CalculatedMetric(
            system=schema.System.CORRECTIONS,
            metric=schema.MetricType.POPULATION,
            filtered_dimensions=[manual_upload.PopulationType.PRISON],
            aggregated_dimensions={
                "state_code": metric_calculator.Aggregation(
                    dimension=manual_upload.State, comprehensive=False)
            },
            output_name="POP",
        )
        results = self.query_view_for_builder(
            corrections_metrics.CorrectionsOutputViewBuilder(
                dataset_id="fake-dataset",
                metric_to_calculate=prison_population_metric,
                input_view=SimpleBigQueryViewBuilder(
                    dataset_id="justice_counts",
                    view_id="metric_calculator",
                    description="metric_calculator view",
                    view_query_template="",
                ),
            ),
            data_types={"year": int, "month": int, "value": int},
            dimensions=dimensions,
        )

        # Assert
        expected = pd.DataFrame(
            [
                ["US_XX", "POP", 2020, 11, datetime.date.fromisoformat("2020-11-30"),
                 "XX", "xx.gov", "_", datetime.date.fromisoformat("2021-01-01"),
                 "INSTANT", ["A", "B"], 3000] + [None] * 4,
                ["US_XX", "POP", 2020, 12, datetime.date.fromisoformat("2020-12-31"),
                 "XX", "xx.gov", "_", datetime.date.fromisoformat("2021-01-01"),
                 "INSTANT", ["B", "C"], 4000] + [None] * 4,
                ["US_YY", "POP", 2020, 11, datetime.date.fromisoformat("2020-11-30"),
                 "YY", "yy.gov", "_", datetime.date.fromisoformat("2021-01-02"),
                 "INSTANT", ["A", "B"], 1000] + [None] * 4,
                ["US_YY", "POP", 2020, 12, datetime.date.fromisoformat("2020-12-31"),
                 "YY", "yy.gov", "_", datetime.date.fromisoformat("2021-01-02"),
                 "INSTANT", ["B", "C"], 1020] + [None] * 4,
                ["US_ZZ", "POP", 2020, 11, datetime.date.fromisoformat("2020-11-30"),
                 "ZZ", "zz.gov", "_", datetime.date.fromisoformat("2021-01-02"),
                 "INSTANT", ["A", "B"], 400] + [None] * 4,
                ["US_ZZ", "POP", 2020, 12, datetime.date.fromisoformat("2020-12-31"),
                 "ZZ", "zz.gov", "_", datetime.date.fromisoformat("2021-01-02"),
                 "INSTANT", ["B", "C"], 500] + [None] * 4,
            ],
            columns=[
                "state_code", "metric", "year", "month", "date_reported",
                "source_name", "source_url", "report_name", "date_published",
                "measurement_type", "raw_source_categories", "value",
                "compared_to_year", "compared_to_month", "value_change",
                "percentage_change",
            ],
        )
        expected = expected.set_index(dimensions)
        assert_frame_equal(expected, results)

    def test_comparisons(self) -> None:
        """Tests that percentage change is correct, or null when the prior value was zero"""
        # Arrange
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="source_materialized",
            mock_schema=MockTableSchema.from_sqlalchemy_table(
                schema.Source.__table__),
            mock_data=pd.DataFrame([[1, "XX"]], columns=["id", "name"]),
        )
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="report_materialized",
            mock_schema=MockTableSchema.from_sqlalchemy_table(
                schema.Report.__table__),
            mock_data=pd.DataFrame(
                [[1, 1, "_", "All", "2021-01-01", "xx.gov", "MANUALLY_ENTERED", "John"]],
                columns=[
                    "id", "source_id", "type", "instance", "publish_date",
                    "url", "acquisition_method", "acquired_by",
                ],
            ),
        )
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="metric_calculator",
            mock_schema=self.INPUT_SCHEMA,
            mock_data=pd.DataFrame(
                [
                    row(1, "2021-01-01", "2022-01-01", (FakeState("US_XX"),), [], 3)
                    + (datetime.date.fromisoformat("2021-02-01"), 0, "US_XX"),
                    row(1, "2021-01-01", "2021-01-01", (FakeState("US_XX"),), [], 0)
                    + (datetime.date.fromisoformat("2020-02-01"), 2, "US_XX"),
                    row(1, "2021-01-01", "2020-01-01", (FakeState("US_XX"),), [], 2)
                    + (None, None, "US_XX"),
                ],
                columns=self.INPUT_SCHEMA.data_types.keys(),
            ),
        )

        # Act
        dimensions = ["state_code", "metric", "year", "month"]
        parole_population = metric_calculator.CalculatedMetric(
            system=schema.System.CORRECTIONS,
            metric=schema.MetricType.ADMISSIONS,
            filtered_dimensions=[],
            aggregated_dimensions={
                "state_code": metric_calculator.Aggregation(
                    dimension=manual_upload.State, comprehensive=False)
            },
            output_name="ADMISSIONS",
        )
        results = self.query_view_for_builder(
            corrections_metrics.CorrectionsOutputViewBuilder(
                dataset_id="fake-dataset",
                metric_to_calculate=parole_population,
                input_view=SimpleBigQueryViewBuilder(
                    dataset_id="justice_counts",
                    view_id="metric_calculator",
                    description="metric_calculator view",
                    view_query_template="",
                ),
            ),
            data_types={"year": int, "month": int, "value": int},
            dimensions=dimensions,
        )

        # Assert
        expected = pd.DataFrame(
            [
                ["US_XX", "ADMISSIONS", 2020, 1, datetime.date.fromisoformat("2020-01-31"),
                 "XX", "xx.gov", "_", datetime.date.fromisoformat("2021-01-01"),
                 "INSTANT", [], 2, None, None, None, None],
                ["US_XX", "ADMISSIONS", 2021, 1, datetime.date.fromisoformat("2021-01-31"),
                 "XX", "xx.gov", "_", datetime.date.fromisoformat("2021-01-01"),
                 "INSTANT", [], 0, 2020, 1, -2, -1.00],
                # Percentage change is None as prior value was 0
                ["US_XX", "ADMISSIONS", 2022, 1, datetime.date.fromisoformat("2022-01-31"),
                 "XX", "xx.gov", "_", datetime.date.fromisoformat("2021-01-01"),
                 "INSTANT", [], 3, 2021, 1, 3, None],
            ],
            columns=[
                "state_code", "metric", "year", "month", "date_reported",
                "source_name", "source_url", "report_name", "date_published",
                "measurement_type", "raw_source_categories", "value",
                "compared_to_year", "compared_to_month", "value_change",
                "percentage_change",
            ],
        )
        expected = expected.set_index(dimensions)
        assert_frame_equal(expected, results)