예제 #1
0
class ParticipantVarsFromREST(AnyModel):
    """Participant vars submitted through the REST API, identified by
    (room_name, participant_label) and stored as JSON text."""

    participant_label = Column(st.String(255))
    room_name = Column(st.String(255))
    # JSON-serialized dict; read/written through the ``vars`` property below
    _json_data = Column(st.Text)

    @property
    def vars(self):
        """Deserialize the stored JSON text into a Python object."""
        return json.loads(self._json_data)

    @vars.setter
    def vars(self, value):
        # serialize on assignment so the DB column always holds valid JSON
        self._json_data = json.dumps(value)
예제 #2
0
class ChatMessage(AnyModel):
    """A single message posted to a chat channel by a participant."""

    class Meta:
        # composite index: channel history is fetched in chronological order
        index_together = ['channel', 'timestamp']

    # the name "channel" here is unrelated to Django channels
    channel = Column(st.String(255))
    participant_id = Column(st.Integer, ForeignKey('otree_participant.id'))
    participant = relationship("Participant")
    # display name shown in the chat; may differ from participant.label
    nickname = Column(st.String(255))

    # call it 'body' instead of 'message' or 'content' because those terms
    # are already used by channels
    body = Column(st.Text)
    # Unix epoch seconds, captured at row creation
    timestamp = Column(st.Float, default=time.time)
예제 #3
0
class TaskQueueMessage(AnyModel):
    """A queued task: a method name plus JSON-encoded keyword arguments."""

    # name of the method to invoke
    method = Column(st.String(50))
    # JSON-encoded kwargs for the call; decoded by ``kwargs()``
    kwargs_json = Column(st.Text)
    # Unix epoch seconds; presumably when the task should run or was
    # enqueued -- confirm against the enqueueing code
    epoch_time = Column(st.Integer)

    def kwargs(self) -> dict:
        """Decode and return the stored keyword arguments."""
        return json.loads(self.kwargs_json)
예제 #4
0
    def test_success(self):
        """A String(50) column converts to a CharField with max_length 50."""
        source_col = Column("name", sqltypes.String(50))

        field_name, converted = fields.to_django_field(TestTable, source_col)

        assert field_name == "name"
        assert isinstance(converted, models.CharField)
        assert converted.max_length == 50
예제 #5
0
 def from_sqlalchemy_table(cls,
                           table: sqlalchemy.Table) -> "MockTableSchema":
     """Build a schema from a SQLAlchemy table, mapping Enum columns to
     String(255) and keeping every other column type as-is."""
     return cls({
         col.name: (sqltypes.String(255)
                    if isinstance(col.type, sqltypes.Enum)
                    else col.type)
         for col in table.columns
     })
예제 #6
0
def visit_column_comment(element, compiler, **kw):
    """Render a ``COMMENT ON COLUMN`` DDL statement for *element*.

    A ``None`` comment is rendered as an empty string literal rather than
    SQL NULL.
    """
    raw_comment = element.comment if element.comment is not None else ""
    # quote/escape the comment as a SQL string literal for this dialect
    literal = compiler.sql_compiler.render_literal_value(
        raw_comment,
        sqltypes.String(),
    )
    return (
        f"COMMENT ON COLUMN {element.table_name}."
        f"{element.column_name} IS {literal}"
    )
예제 #7
0
 def from_big_query_schema_fields(
         cls, bq_schema: List[bigquery.SchemaField]) -> "MockTableSchema":
     """Build a schema from BigQuery schema fields, translating each
     supported BigQuery type to its SQLAlchemy equivalent.

     Raises ValueError for any BigQuery type not in the mapping.
     """
     type_names = bigquery.enums.SqlTypeNames
     # non-STRING types map directly to SQLAlchemy type classes
     direct_mapping = {
         type_names.INTEGER: sqltypes.Integer,
         type_names.FLOAT: sqltypes.Float,
         type_names.DATE: sqltypes.Date,
         type_names.BOOLEAN: sqltypes.Boolean,
     }
     converted = {}
     for field in bq_schema:
         field_type = type_names(field.field_type)
         if field_type is type_names.STRING:
             # STRING gets a bounded String instance rather than a class
             converted[field.name] = sqltypes.String(255)
         elif field_type in direct_mapping:
             converted[field.name] = direct_mapping[field_type]
         else:
             raise ValueError(
                 f"Unhandled big query field type '{field_type}' for attribute '{field.name}'"
             )
     return cls(converted)
예제 #8
0
class RoomToSession(AnyModel, MixinSessionFK):
    """Links a room to its current session."""

    # unique=True: a room can host at most one session at a time
    # (unique also implies a DB index)
    room_name = Column(st.String(255), unique=True)
예제 #9
0
def convert_sqla_type_for_dialect(
        coltype: TypeEngine,
        dialect: Dialect,
        strip_collation: bool = True,
        convert_mssql_timestamp: bool = True,
        expand_for_scrubbing: bool = False) -> TypeEngine:
    """
    Converts an SQLAlchemy column type from one SQL dialect to another.

    Args:
        coltype: SQLAlchemy column type in the source dialect

        dialect: destination :class:`Dialect`

        strip_collation: remove any ``COLLATION`` information?

        convert_mssql_timestamp:
            since you cannot write to a SQL Server ``TIMESTAMP`` field, setting
            this option to ``True`` (the default) converts such types to
            something equivalent but writable.

        expand_for_scrubbing:
            The purpose of expand_for_scrubbing is that, for example, a
            ``VARCHAR(200)`` field containing one or more instances of
            ``Jones``, where ``Jones`` is to be replaced with ``[XXXXXX]``,
            will get longer (by an unpredictable amount). So, better to expand
            to unlimited length.

    Returns:
        an SQLAlchemy column type instance, in the destination dialect

    """
    assert coltype is not None

    # noinspection PyUnresolvedReferences
    to_mysql = dialect.name == SqlaDialectName.MYSQL
    # noinspection PyUnresolvedReferences
    to_mssql = dialect.name == SqlaDialectName.MSSQL
    typeclass = type(coltype)

    # -------------------------------------------------------------------------
    # Text
    # -------------------------------------------------------------------------
    # NOTE: the order of these isinstance checks matters -- they run from the
    # most specific string type to the most general (String), and several
    # branches deliberately fall through to later ones.
    if isinstance(coltype, sqltypes.Enum):
        return sqltypes.String(length=coltype.length)
    if isinstance(coltype, sqltypes.UnicodeText):
        # Unbounded Unicode text.
        # Includes derived classes such as mssql.base.NTEXT.
        return sqltypes.UnicodeText()
    if isinstance(coltype, sqltypes.Text):
        # Unbounded text, more generally. (UnicodeText inherits from Text.)
        # Includes sqltypes.TEXT.
        return sqltypes.Text()
    # Everything inheriting from String has a length property, but can be None.
    # There are types that can be unlimited in SQL Server, e.g. VARCHAR(MAX)
    # and NVARCHAR(MAX), that MySQL needs a length for. (Failure to convert
    # gives e.g.: 'NVARCHAR requires a length on dialect mysql'.)
    if isinstance(coltype, sqltypes.Unicode):
        # Includes NVARCHAR(MAX) in SQL -> NVARCHAR() in SQLAlchemy.
        if (coltype.length is None and to_mysql) or expand_for_scrubbing:
            return sqltypes.UnicodeText()
        # Otherwise: Unicode inherits from String, so control falls through
        # to the String branch below (for collation stripping etc.).
    # The most general case; will pick up any other string types.
    if isinstance(coltype, sqltypes.String):
        # Includes VARCHAR(MAX) in SQL -> VARCHAR() in SQLAlchemy
        if (coltype.length is None and to_mysql) or expand_for_scrubbing:
            return sqltypes.Text()
        if strip_collation:
            return remove_collation(coltype)
        return coltype

    # -------------------------------------------------------------------------
    # Binary
    # -------------------------------------------------------------------------
    # (no binary-type conversions are currently needed)

    # -------------------------------------------------------------------------
    # BIT
    # -------------------------------------------------------------------------
    if typeclass == mssql.base.BIT and to_mysql:
        # MySQL BIT objects have a length attribute.
        return mysql.base.BIT()

    # -------------------------------------------------------------------------
    # TIMESTAMP
    # -------------------------------------------------------------------------
    is_mssql_timestamp = isinstance(coltype, MSSQL_TIMESTAMP)
    if is_mssql_timestamp and to_mssql and convert_mssql_timestamp:
        # You cannot write explicitly to a TIMESTAMP field in SQL Server; it's
        # used for autogenerated values only.
        # - http://stackoverflow.com/questions/10262426/sql-server-cannot-insert-an-explicit-value-into-a-timestamp-column  # noqa
        # - https://social.msdn.microsoft.com/Forums/sqlserver/en-US/5167204b-ef32-4662-8e01-00c9f0f362c2/how-to-tranfer-a-column-with-timestamp-datatype?forum=transactsql  # noqa
        #   ... suggesting BINARY(8) to store the value.
        # MySQL is more helpful:
        # - http://stackoverflow.com/questions/409286/should-i-use-field-datetime-or-timestamp  # noqa
        return mssql.base.BINARY(8)

    # -------------------------------------------------------------------------
    # Some other type
    # -------------------------------------------------------------------------
    # Any type not matched above is passed through unchanged.
    return coltype
예제 #10
0
class Company(Base):
    """A tradable company whose stock price follows a random walk, advanced
    one month at a time by :meth:`iterate`."""

    __tablename__ = 'company'

    id = Column(t.Integer, primary_key=True)
    full_name = Column(t.String, nullable=False)
    abbv = Column(t.String(4), nullable=False)

    # rich = Column(t.Boolean, default=False)
    stock_price = Column(t.Float)
    # price_diff = previous price - current price (positive when the price
    # FELL this month; display code negates it to show the change)
    price_diff = Column(t.Float, default=0)
    months = Column(t.Integer, default=0)  # age

    # temporary event bonus added to increase_chance while active
    event_increase = Column(t.Integer, default=0)
    event_months_remaining = Column(t.Integer, default=0)

    increase_chance = Column(t.Integer, default=50)  # percentage
    max_increase = Column(t.Float, default=0.3)
    max_decrease = Column(t.Float, default=0.35)

    bankrupt = Column(t.Boolean, default=False)

    shares = relationship("Shares", backref="company", passive_deletes=True)

    # NOTE: a no-op __init__ override (which only forwarded to
    # super().__init__) was removed; the inherited constructor is identical.

    def iterate(self):
        """Advance the company by one month: roll for a rise or fall, apply
        it, age the company, and flag bankruptcy at a price <= 0.5.

        Returns the resulting ``price_diff``, or ``None`` if the company is
        (or just went) bankrupt.
        """
        if self.bankrupt:
            return

        if self.event_increase is None:
            self.event_increase = 0

        # randrange(1, 100) rolls 1..99 inclusive.
        # NOTE(review): with the default increase_chance of 50 this gives a
        # ~49.5% chance of a rise, not exactly 50% -- confirm intended.
        if random.randrange(1,
                            100) < self.increase_chance + self.event_increase:
            amount = random.uniform(1, 1 + self.max_increase)
        else:
            amount = random.uniform(1 - self.max_decrease, 1)

        # count down any active event; clear the bonus when it runs out
        if self.event_months_remaining is not None:
            self.event_months_remaining -= 1
            if self.event_months_remaining <= 0:
                self.event_increase = 0

        initial_price = self.stock_price
        self.stock_price = self.stock_price * amount
        self.price_diff = initial_price - self.stock_price

        self.months += 1

        if self.stock_price <= 0.5:
            self.bankrupt = True
            return

        return self.price_diff

    @classmethod
    def create(cls, starting_price, name=None, **kwargs):
        """Alternate constructor: *name* is an (abbreviation, full name)
        pair; defaults to a placeholder when omitted."""
        if name is None:
            name = ["dflt", "default"]
        return cls(
            stock_price=starting_price,
            abbv=name[0],
            full_name=name[1],
            **kwargs,
        )

    @classmethod
    def find_by_abbreviation(cls, abbreviation: str, session):
        """Return the first company with this abbreviation, or None."""
        return session.query(cls).filter_by(abbv=abbreviation).first()

    @hybrid_property
    def stocks_bought(self):
        # Python-side: sum share amounts over the loaded relationship.
        # noinspection PyTypeChecker
        return sum(share.amount for share in self.shares)

    @stocks_bought.expression
    def stocks_bought(self):
        # SQL-side equivalent of the property above, for use in queries.
        return select([
            func.sum(Shares.amount)
        ]).where(Shares.company_id == self.id).label("stocks_bought")

    @property
    def announcement_description(self):
        # stock_price + price_diff equals the pre-iteration price, so this
        # shows the percentage change for the latest month.
        # NOTE(review): raises ZeroDivisionError if that prior price is 0 --
        # presumably prevented upstream; confirm.
        return f"{self.abbv.upper()}[{self.stock_price:,.1f}{-(self.price_diff/(self.stock_price+self.price_diff)*100):+.1f}%]"

    def __str__(self):
        years = int(self.months / 12)
        months = self.months % 12
        return f"Name: '{self.abbv}' aka '{self.full_name}' | stock_price: {self.stock_price:,.2f} | " \
               f"price change: {-(self.price_diff/(self.stock_price+self.price_diff)*100):+.1f}% | " \
               f"lifespan: {years} {'years' if not years == 1 else 'year'} " \
               f"and {months} {'months' if not months == 1 else 'month'} | Stocks Bought: {self.stocks_bought}"

    def __repr__(self):
        return self._repr(
            **self._getattrs("id", "abbv", "stocks_bought", "stock_price"), )
예제 #11
0
def test_select_star(mocker: MockFixture, app_context: AppContext) -> None:
    """
    Test the ``select_star`` method.

    The method removes pseudo-columns from structures inside arrays. While these
    pseudo-columns show up as "columns" for metadata reasons, we can't select them
    in the query, as opposed to fields from non-array structures.
    """
    from superset.db_engine_specs.bigquery import BigQueryEngineSpec

    # every column shares the same metadata apart from name and type
    shared_meta = {
        "nullable": True,
        "comment": None,
        "default": None,
        "precision": None,
        "scale": None,
        "max_length": None,
    }
    cols = [
        {"name": "trailer", "type": sqltypes.ARRAY(sqltypes.JSON()),
         **shared_meta},
        {"name": "trailer.key", "type": sqltypes.String(), **shared_meta},
        {"name": "trailer.value", "type": sqltypes.String(), **shared_meta},
        {"name": "trailer.email", "type": sqltypes.String(), **shared_meta},
    ]

    # mock the database so we can compile the query
    mock_db = mocker.MagicMock()
    mock_db.compile_sqla_query = lambda query: str(
        query.compile(dialect=BigQueryDialect()))

    mock_engine = mocker.MagicMock()
    mock_engine.dialect = BigQueryDialect()

    sql = BigQueryEngineSpec.select_star(
        database=mock_db,
        table_name="my_table",
        engine=mock_engine,
        schema=None,
        limit=100,
        show_cols=True,
        indent=True,
        latest_partition=False,
        cols=cols,
    )

    # only the array itself is selected; its pseudo-columns are dropped
    expected = """SELECT `trailer` AS `trailer`
FROM `my_table`
LIMIT :param_1"""
    assert sql == expected
예제 #12
0
class Participant(otree.database.SSPPGModel, MixinVars):
    """A person taking part in a Session: tracks their identity codes,
    their position in the page sequence, and their payoff."""

    __tablename__ = 'otree_participant'

    session_id = Column(st.Integer, ForeignKey('otree_session.id'))
    session = relationship("Session", back_populates="pp_set")

    label = Column(st.String(50), nullable=True,)

    id_in_session = Column(st.Integer, nullable=True)

    payoff = Column(CurrencyType, default=0)

    time_started = Column(st.DateTime, nullable=True)
    mturk_assignment_id = Column(st.String(50), nullable=True)
    mturk_worker_id = Column(st.String(50), nullable=True)

    # 1-based position in the page sequence; 0 means "not started yet"
    # (see initialize())
    _index_in_pages = Column(st.Integer, default=0, index=True)

    def _numeric_label(self):
        """the human-readable version."""
        return 'P{}'.format(self.id_in_session)

    _monitor_note = Column(st.String(300), nullable=True)

    code = Column(
        st.String(16),
        default=random_chars_8,
        # set non-nullable, until we make our CharField non-nullable
        nullable=False,
        # unique implies DB index
        unique=True,
    )

    # useful when we don't want to load the whole session just to get the code
    _session_code = Column(st.String(16))

    visited = Column(st.Boolean, default=False, index=True,)

    # stores when the page was first visited
    _last_page_timestamp = Column(st.Integer, nullable=True)

    _last_request_timestamp = Column(st.Integer, nullable=True)

    is_on_wait_page = Column(st.Boolean, default=False)

    # these are both for the admin
    # In the changelist, simply call these "page" and "app"
    _current_page_name = Column(st.String(200), nullable=True)
    _current_app_name = Column(st.String(200), nullable=True)

    # only to be displayed in the admin participants changelist
    _round_number = Column(st.Integer, nullable=True)

    _current_form_page_url = Column(st.String(500))

    _max_page_index = Column(st.Integer,)

    _is_bot = Column(st.Boolean, default=False)
    # can't start with an underscore because used in template
    # can't end with underscore because it's a django field (fields.E001)
    is_browser_bot = Column(st.Boolean, default=False)

    # note: FloatField wrapper, unlike the raw Column(...) fields above
    _timeout_expiration_time = otree.database.FloatField()
    _timeout_page_index = Column(st.Integer,)

    # group-by-arrival-time (gbat) bookkeeping
    _gbat_is_waiting = Column(st.Boolean, default=False)
    _gbat_page_index = Column(st.Integer,)
    _gbat_grouped = Column(st.Boolean,)

    def _current_page(self):
        """Progress string like '3/10' for the admin monitor."""
        # don't put 'pages' because that causes wrapping which takes more space
        # since it's longer than the header
        return f'{self._index_in_pages}/{self._max_page_index}'

    # because variables used in templates can't start with an underscore
    def current_page_(self):
        return self._current_page()

    def get_players(self):
        """Used to calculate payoffs"""
        lst = []
        app_sequence = self.session.config['app_sequence']
        for app in app_sequence:
            models_module = otree.common.get_models_module(app)
            players = models_module.Player.objects_filter(participant=self).order_by(
                'round_number'
            )
            lst.extend(list(players))
        return lst

    def _url_i_should_be_on(self):
        """URL for this participant's current position in the session."""
        if not self.visited:
            return self._start_url()
        if self._index_in_pages <= self._max_page_index:
            return url_i_should_be_on(
                self.code, self._session_code, self._index_in_pages
            )
        # past the last page: show the out-of-range notification
        return '/OutOfRangeNotification/' + self.code

    def _start_url(self):
        return otree.common.participant_start_url(self.code)

    def payoff_in_real_world_currency(self):
        return self.payoff.to_real_world_currency(self.session)

    def payoff_plus_participation_fee(self):
        return self.session._get_payoff_plus_participation_fee(self.payoff)

    def _get_current_player(self):
        """Look up the Player object for this participant's current page."""
        lookup = get_page_lookup(self._session_code, self._index_in_pages)
        models_module = otree.common.get_models_module(lookup.app_name)
        PlayerClass = getattr(models_module, 'Player')
        return PlayerClass.objects_get(
            participant=self, round_number=lookup.round_number
        )

    def initialize(self, participant_label):
        """in a separate function so that we can call it individually,
        e.g. from advance_last_place_participants"""
        pp = self
        # idempotent: only runs for participants who have not started yet
        if pp._index_in_pages == 0:
            pp._index_in_pages = 1
            pp.visited = True

            # participant.label might already have been set
            pp.label = pp.label or participant_label

            # default to Central European Time
            pp.time_started = datetime.datetime.utcnow() + datetime.timedelta(hours=1)
            pp._last_page_timestamp = time()
            player = pp._get_current_player()
            player.start()
예제 #13
0
def StringField(**kwargs):
    """Build a string column; 'max_length' (default 10000) sets the
    String length and all kwargs are forwarded to wrap_column."""
    length = kwargs.get('max_length', 10000)
    return wrap_column(st.String(length=length), **kwargs)
예제 #14
0
class Session(otree.database.SSPPGModel, MixinVars):
    """One run of an oTree session: its config, participants, MTurk
    metadata, and admin-report bookkeeping."""

    __tablename__ = 'otree_session'

    # full session config dict (pickled)
    config: dict = Column(otree.database._PickleField, default=dict)

    pp_set = relationship("Participant",
                          back_populates="session",
                          lazy='dynamic')
    # label of this session instance
    label = Column(st.String, nullable=True)

    code = Column(
        st.String(16),
        default=random_chars_8,
        nullable=False,
        # unique implies a DB index
        unique=True,
    )

    mturk_HITId = Column(st.String(300), nullable=True)
    mturk_HITGroupId = Column(st.String(300), nullable=True)

    is_mturk = Column(st.Boolean, default=False)

    def mturk_num_workers(self):
        """Number of MTurk workers to recruit for this session."""
        assert self.is_mturk
        return int(self.num_participants /
                   settings.MTURK_NUM_PARTICIPANTS_MULTIPLE)

    mturk_use_sandbox = Column(st.Boolean, default=True)

    # use Float instead of DateTime because DateTime
    # is a pain to work with (e.g. naive vs aware datetime objects)
    # and there is no need here for DateTime
    mturk_expiration = Column(st.Float, nullable=True)
    mturk_qual_id = Column(st.String(50), default='')

    archived = Column(
        st.Boolean,
        default=False,
        index=True,
    )

    comment = Column(st.Text)

    _anonymous_code = Column(
        st.String(20),
        default=random_chars_10,
        nullable=False,
        index=True,
    )

    is_demo = Column(st.Boolean, default=False)

    # semicolon-separated caches built by _set_admin_report_app_names()
    _admin_report_app_names = Column(st.Text, default='')
    _admin_report_num_rounds = Column(st.String(255), default='')

    num_participants = Column(st.Integer)

    def __unicode__(self):
        return self.code

    @property
    def participation_fee(self):
        '''This method is deprecated from public API,
        but still useful internally (like data export)'''
        return self.config['participation_fee']

    @property
    def real_world_currency_per_point(self):
        '''This method is deprecated from public API,
        but still useful internally (like data export)'''
        return self.config['real_world_currency_per_point']

    @property
    def use_browser_bots(self):
        return self.config.get('use_browser_bots', False)

    def mock_exogenous_data(self):
        '''
        It's for any exogenous data:
        - participant labels (which are not passed in through REST API)
        - participant vars
        - session vars (if we enable that)
        '''
        if self.config.get('mock_exogenous_data'):
            import shared_out as user_utils

            user_utils.mock_exogenous_data(self)

    def get_subsessions(self):
        """All subsessions across the app sequence, in round order."""
        lst = []
        app_sequence = self.config['app_sequence']
        for app in app_sequence:
            models_module = otree.common.get_models_module(app)
            subsessions = models_module.Subsession.objects_filter(
                session=self).order_by('round_number')
            lst.extend(list(subsessions))
        return lst

    def get_participants(self):
        return list(self.pp_set.order_by('id_in_session'))

    def mturk_worker_url(self):
        # different HITs
        # get the same preview page, because they are lumped into the same
        # "hit group". This is not documented, but it seems HITs are lumped
        # if a certain subset of properties are the same:
        # https://forums.aws.amazon.com/message.jspa?messageID=597622#597622
        # this seems like the correct design; the only case where this will
        # not work is if the HIT was deleted from the server, but in that case,
        # the HIT itself should be canceled.

        # 2018-06-04:
        # the format seems to have changed to this:
        # https://worker.mturk.com/projects/{group_id}/tasks?ref=w_pl_prvw
        # but the old format still works.
        # it seems I can't replace groupId by hitID, which i would like to do
        # because it's more precise.
        subdomain = "workersandbox" if self.mturk_use_sandbox else 'www'
        return "https://{}.mturk.com/mturk/preview?groupId={}".format(
            subdomain, self.mturk_HITGroupId)

    def mturk_is_expired(self):
        # mturk_expiration is stored as a Unix timestamp (a Float column --
        # see its definition above), so compare against time.time().
        return self.mturk_expiration and self.mturk_expiration < time.time()

    def mturk_is_active(self):
        """HIT exists and has not yet expired."""
        return self.mturk_HITId and not self.mturk_is_expired()

    def advance_last_place_participants(self):
        """the problem with using the test client to make get/post requests is
        (1) this request already has the global asyncio.lock
        (2) there are apparently some issues with async/await and event loops.
        """
        from otree.lookup import get_page_lookup
        from otree.api import WaitPage, Page

        participants = self.get_participants()

        # in case some participants haven't started
        unvisited_participants = False
        for p in participants:
            if p._index_in_pages == 0:
                # BUGFIX: the flag was previously never set, which made the
                # early return below unreachable.
                unvisited_participants = True
                p.initialize(None)

        if unvisited_participants:
            # that's it -- just visit the start URL, advancing by 1
            return

        last_place_page_index = min([p._index_in_pages for p in participants])
        last_place_participants = [
            p for p in participants
            if p._index_in_pages == last_place_page_index
        ]
        for p in last_place_participants:
            page_index = p._index_in_pages
            if page_index >= p._max_page_index:
                return
            page = get_page_lookup(
                self.code,
                page_index).page_class.instantiate_without_request()
            page.set_attributes(p)

            if isinstance(page, Page):

                from starlette.datastructures import FormData

                # simulate a form submission with the admin secret code and
                # the timeout flag set, as if the page had timed out
                page._is_frozen = False
                page._form_data = FormData({
                    otree.constants.admin_secret_code:
                    ADMIN_SECRET_CODE,
                    otree.constants.timeout_happened:
                    '1',
                })
                # TODO: should we also call .get() so that _update_monitor_table will also get run?
                resp = page.post()
                if resp.status_code >= 400:
                    msg = (
                        f'Submitting page {p._current_form_page_url} failed, '
                        f'returned HTTP status code {resp.status_code}. '
                        'Check the logs')
                    raise AssertionError(msg)
            else:
                # it's possible that the slowest user is on a wait page,
                # especially if their browser is closed.
                # because they were waiting for another user who then
                # advanced past the wait page, but they were never
                # advanced themselves.
                resp = page.inner_dispatch(request=None)

            # do the auto-advancing here,
            # rather than in increment_index_in_pages,
            # because it's only needed here.
            otree.channels.utils.sync_group_send(group=auto_advance_group(
                p.code),
                                                 data={'auto_advanced': True})

    def get_room(self):
        from otree.room import ROOM_DICT

        try:
            room_name = RoomToSession.objects_get(session=self).room_name
            return ROOM_DICT[room_name]
        except NoResultFound:
            # this session is not attached to any room
            return None

    def _get_payoff_plus_participation_fee(self, payoff):
        '''For a participant who has the given payoff,
        return their payoff_plus_participation_fee
        Useful to define it here, for data export
        '''

        return self.config[
            'participation_fee'] + payoff.to_real_world_currency(self)

    def _set_admin_report_app_names(self):
        """Cache which apps have an admin_report template, and their
        round counts, as semicolon-separated strings."""
        admin_report_app_names = []
        num_rounds_list = []
        for app_name in self.config['app_sequence']:
            models_module = otree.common.get_models_module(app_name)
            app_label = get_app_label_from_name(app_name)
            try:
                get_template_name_if_exists([
                    f'{app_label}/admin_report.html',
                    f'{app_label}/AdminReport.html'
                ])
            except TemplateLoadError:
                # app has no admin report template; skip it
                pass
            else:
                admin_report_app_names.append(app_name)
                num_rounds_list.append(models_module.Constants.num_rounds)

        self._admin_report_app_names = ';'.join(admin_report_app_names)
        self._admin_report_num_rounds = ';'.join(
            str(n) for n in num_rounds_list)

    def _admin_report_apps(self):
        return self._admin_report_app_names.split(';')

    def _admin_report_num_rounds_list(self):
        return [int(num) for num in self._admin_report_num_rounds.split(';')]

    def has_admin_report(self):
        return bool(self._admin_report_app_names)
class CorrectionsOutputViewTest(BaseViewTest):
    """Tests the Corrections output view."""

    # Schema of the mocked `metric_calculator` input table: the shared
    # metric-calculator columns (METRIC_CALCULATOR_SCHEMA, defined elsewhere)
    # plus the three extra columns the rows below append via
    # `+ (compare_date_partition, compare_value, state_code)` tuples.
    INPUT_SCHEMA = MockTableSchema({
        **METRIC_CALCULATOR_SCHEMA.data_types,
        "compare_date_partition":
        sqltypes.Date(),
        "compare_value":
        sqltypes.Numeric(),
        "state_code":
        sqltypes.String(255),
    })

    def test_recent_population(self) -> None:
        """Tests the basic use case of calculating population"""
        # Arrange
        # Mock `source_materialized`: three sources, one per state (XX/YY/ZZ).
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="source_materialized",
            mock_schema=MockTableSchema.from_sqlalchemy_table(
                schema.Source.__table__),
            mock_data=pd.DataFrame([[1, "XX"], [2, "YY"], [3, "ZZ"]],
                                   columns=["id", "name"]),
        )
        # Mock `report_materialized`: one report per source, with the publish
        # dates and URLs the expected output below should echo back.
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="report_materialized",
            mock_schema=MockTableSchema.from_sqlalchemy_table(
                schema.Report.__table__),
            mock_data=pd.DataFrame(
                [
                    [
                        1,
                        1,
                        "_",
                        "All",
                        "2021-01-01",
                        "xx.gov",
                        "MANUALLY_ENTERED",
                        "John",
                    ],
                    [
                        2,
                        2,
                        "_",
                        "All",
                        "2021-01-02",
                        "yy.gov",
                        "MANUALLY_ENTERED",
                        "Jane",
                    ],
                    [
                        3,
                        3,
                        "_",
                        "All",
                        "2021-01-02",
                        "zz.gov",
                        "MANUALLY_ENTERED",
                        "Jude",
                    ],
                ],
                columns=[
                    "id",
                    "source_id",
                    "type",
                    "instance",
                    "publish_date",
                    "url",
                    "acquisition_method",
                    "acquired_by",
                ],
            ),
        )
        # Mock `metric_calculator`: two monthly INSTANT data points per state
        # (2020-11-30 and 2020-12-31).  The trailing `(None, None, "US_..")`
        # tuple fills the compare_date_partition / compare_value / state_code
        # columns declared in INPUT_SCHEMA (no comparisons in this test).
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="metric_calculator",
            mock_schema=self.INPUT_SCHEMA,
            mock_data=pd.DataFrame(
                [
                    row(
                        1,
                        "2021-01-01",
                        "2020-11-30",
                        (FakeState("US_XX"), ),
                        ["A", "B", "A"],
                        3000,
                        measurement_type="INSTANT",
                    ) + (None, None, "US_XX"),
                    row(
                        1,
                        "2021-01-01",
                        "2020-12-31",
                        (FakeState("US_XX"), ),
                        ["B", "B", "C"],
                        4000,
                        measurement_type="INSTANT",
                    ) + (None, None, "US_XX"),
                    row(
                        2,
                        "2021-01-01",
                        "2020-11-30",
                        (FakeState("US_YY"), ),
                        ["A", "B", "A"],
                        1000,
                        measurement_type="INSTANT",
                    ) + (None, None, "US_YY"),
                    row(
                        2,
                        "2021-01-01",
                        "2020-12-31",
                        (FakeState("US_YY"), ),
                        ["B", "B", "C"],
                        1020,
                        measurement_type="INSTANT",
                    ) + (None, None, "US_YY"),
                    row(
                        3,
                        "2021-01-01",
                        "2020-11-30",
                        (FakeState("US_ZZ"), ),
                        ["A", "B", "A"],
                        400,
                        measurement_type="INSTANT",
                    ) + (None, None, "US_ZZ"),
                    row(
                        3,
                        "2021-01-01",
                        "2020-12-31",
                        (FakeState("US_ZZ"), ),
                        ["C", "C", "B"],
                        500,
                        measurement_type="INSTANT",
                    ) + (None, None, "US_ZZ"),
                ],
                columns=self.INPUT_SCHEMA.data_types.keys(),
            ),
        )

        # Act
        # Build a prison-population metric aggregated by state and run the
        # Corrections output view over the mocked input view.
        dimensions = ["state_code", "metric", "year", "month"]
        prison_population_metric = metric_calculator.CalculatedMetric(
            system=schema.System.CORRECTIONS,
            metric=schema.MetricType.POPULATION,
            filtered_dimensions=[manual_upload.PopulationType.PRISON],
            aggregated_dimensions={
                "state_code":
                metric_calculator.Aggregation(dimension=manual_upload.State,
                                              comprehensive=False)
            },
            output_name="POP",
        )
        results = self.query_view_for_builder(
            corrections_metrics.CorrectionsOutputViewBuilder(
                dataset_id="fake-dataset",
                metric_to_calculate=prison_population_metric,
                input_view=SimpleBigQueryViewBuilder(
                    dataset_id="justice_counts",
                    view_id="metric_calculator",
                    description="metric_calculator view",
                    view_query_template="",
                ),
            ),
            data_types={
                "year": int,
                "month": int,
                "value": int
            },
            dimensions=dimensions,
        )

        # Assert
        # One output row per state per month.  The final four comparison
        # columns (compared_to_year, compared_to_month, value_change,
        # percentage_change) are all None since no comparison data was given.
        # Note the raw source categories are deduplicated in the output
        # (e.g. ["A", "B", "A"] -> ["A", "B"]).
        expected = pd.DataFrame(
            [
                [
                    "US_XX",
                    "POP",
                    2020,
                    11,
                    datetime.date.fromisoformat("2020-11-30"),
                    "XX",
                    "xx.gov",
                    "_",
                    datetime.date.fromisoformat("2021-01-01"),
                    "INSTANT",
                    ["A", "B"],
                    3000,
                ] + [None] * 4,
                [
                    "US_XX",
                    "POP",
                    2020,
                    12,
                    datetime.date.fromisoformat("2020-12-31"),
                    "XX",
                    "xx.gov",
                    "_",
                    datetime.date.fromisoformat("2021-01-01"),
                    "INSTANT",
                    ["B", "C"],
                    4000,
                ] + [None] * 4,
                [
                    "US_YY",
                    "POP",
                    2020,
                    11,
                    datetime.date.fromisoformat("2020-11-30"),
                    "YY",
                    "yy.gov",
                    "_",
                    datetime.date.fromisoformat("2021-01-02"),
                    "INSTANT",
                    ["A", "B"],
                    1000,
                ] + [None] * 4,
                [
                    "US_YY",
                    "POP",
                    2020,
                    12,
                    datetime.date.fromisoformat("2020-12-31"),
                    "YY",
                    "yy.gov",
                    "_",
                    datetime.date.fromisoformat("2021-01-02"),
                    "INSTANT",
                    ["B", "C"],
                    1020,
                ] + [None] * 4,
                [
                    "US_ZZ",
                    "POP",
                    2020,
                    11,
                    datetime.date.fromisoformat("2020-11-30"),
                    "ZZ",
                    "zz.gov",
                    "_",
                    datetime.date.fromisoformat("2021-01-02"),
                    "INSTANT",
                    ["A", "B"],
                    400,
                ] + [None] * 4,
                [
                    "US_ZZ",
                    "POP",
                    2020,
                    12,
                    datetime.date.fromisoformat("2020-12-31"),
                    "ZZ",
                    "zz.gov",
                    "_",
                    datetime.date.fromisoformat("2021-01-02"),
                    "INSTANT",
                    ["B", "C"],
                    500,
                ] + [None] * 4,
            ],
            columns=[
                "state_code",
                "metric",
                "year",
                "month",
                "date_reported",
                "source_name",
                "source_url",
                "report_name",
                "date_published",
                "measurement_type",
                "raw_source_categories",
                "value",
                "compared_to_year",
                "compared_to_month",
                "value_change",
                "percentage_change",
            ],
        )
        expected = expected.set_index(dimensions)
        assert_frame_equal(expected, results)

    def test_comparisons(self) -> None:
        """Tests that percentage change is correct, or null when the prior value was zero"""
        # Arrange
        # Single source / single report is enough for this test.
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="source_materialized",
            mock_schema=MockTableSchema.from_sqlalchemy_table(
                schema.Source.__table__),
            mock_data=pd.DataFrame([[1, "XX"]], columns=["id", "name"]),
        )
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="report_materialized",
            mock_schema=MockTableSchema.from_sqlalchemy_table(
                schema.Report.__table__),
            mock_data=pd.DataFrame(
                [[
                    1,
                    1,
                    "_",
                    "All",
                    "2021-01-01",
                    "xx.gov",
                    "MANUALLY_ENTERED",
                    "John",
                ]],
                columns=[
                    "id",
                    "source_id",
                    "type",
                    "instance",
                    "publish_date",
                    "url",
                    "acquisition_method",
                    "acquired_by",
                ],
            ),
        )
        # Three yearly values for the same state: 2 (2020), 0 (2021), 3 (2022).
        # The trailing tuples populate (compare_date_partition, compare_value,
        # state_code): 2021 compares against value 2, 2022 against value 0,
        # and 2020 has no comparison.
        self.create_mock_bq_table(
            dataset_id="justice_counts",
            table_id="metric_calculator",
            mock_schema=self.INPUT_SCHEMA,
            mock_data=pd.DataFrame(
                [
                    row(1, "2021-01-01", "2022-01-01",
                        (FakeState("US_XX"), ), [], 3) +
                    (datetime.date.fromisoformat("2021-02-01"), 0, "US_XX"),
                    row(1, "2021-01-01", "2021-01-01",
                        (FakeState("US_XX"), ), [], 0) +
                    (datetime.date.fromisoformat("2020-02-01"), 2, "US_XX"),
                    row(1, "2021-01-01", "2020-01-01",
                        (FakeState("US_XX"), ), [], 2) + (None, None, "US_XX"),
                ],
                columns=self.INPUT_SCHEMA.data_types.keys(),
            ),
        )

        # Act
        # Admissions metric aggregated by state, run over the mocked input.
        dimensions = ["state_code", "metric", "year", "month"]
        parole_population = metric_calculator.CalculatedMetric(
            system=schema.System.CORRECTIONS,
            metric=schema.MetricType.ADMISSIONS,
            filtered_dimensions=[],
            aggregated_dimensions={
                "state_code":
                metric_calculator.Aggregation(dimension=manual_upload.State,
                                              comprehensive=False)
            },
            output_name="ADMISSIONS",
        )
        results = self.query_view_for_builder(
            corrections_metrics.CorrectionsOutputViewBuilder(
                dataset_id="fake-dataset",
                metric_to_calculate=parole_population,
                input_view=SimpleBigQueryViewBuilder(
                    dataset_id="justice_counts",
                    view_id="metric_calculator",
                    description="metric_calculator view",
                    view_query_template="",
                ),
            ),
            data_types={
                "year": int,
                "month": int,
                "value": int
            },
            dimensions=dimensions,
        )

        # Assert
        # 2020 row: no comparison columns.  2021 row: 2 -> 0 gives
        # value_change -2 and percentage_change -1.00.  2022 row: 0 -> 3 gives
        # value_change 3 but percentage_change None (division by a zero prior).
        expected = pd.DataFrame(
            [
                [
                    "US_XX",
                    "ADMISSIONS",
                    2020,
                    1,
                    datetime.date.fromisoformat("2020-01-31"),
                    "XX",
                    "xx.gov",
                    "_",
                    datetime.date.fromisoformat("2021-01-01"),
                    "INSTANT",
                    [],
                    2,
                    None,
                    None,
                    None,
                    None,
                ],
                [
                    "US_XX",
                    "ADMISSIONS",
                    2021,
                    1,
                    datetime.date.fromisoformat("2021-01-31"),
                    "XX",
                    "xx.gov",
                    "_",
                    datetime.date.fromisoformat("2021-01-01"),
                    "INSTANT",
                    [],
                    0,
                    2020,
                    1,
                    -2,
                    -1.00,
                ],
                # Percentage change is None as prior value was 0
                [
                    "US_XX",
                    "ADMISSIONS",
                    2022,
                    1,
                    datetime.date.fromisoformat("2022-01-31"),
                    "XX",
                    "xx.gov",
                    "_",
                    datetime.date.fromisoformat("2021-01-01"),
                    "INSTANT",
                    [],
                    3,
                    2021,
                    1,
                    3,
                    None,
                ],
            ],
            columns=[
                "state_code",
                "metric",
                "year",
                "month",
                "date_reported",
                "source_name",
                "source_url",
                "report_name",
                "date_published",
                "measurement_type",
                "raw_source_categories",
                "value",
                "compared_to_year",
                "compared_to_month",
                "value_change",
                "percentage_change",
            ],
        )
        expected = expected.set_index(dimensions)
        assert_frame_equal(expected, results)