예제 #1
0
def fetch_all(user_email: str, consistent: bool = False) -> List[Subscription]:
    """List every subscription item stored under a user.

    Args:
        user_email: Email address that identifies the user.
        consistent: If true, request a strongly consistent read.

    Returns:
        One entry per item found, carrying the item's sort key as
        `ProjectDomain` and its `IsActive` flag.

    Raises:
        `db.DatabaseError` if there was an error connecting to the Database.

    """
    # TODO (abiro) add group subs as well
    partition = db.PartitionKey(ent.User, user_email)
    sub_prefix = db.PrefixSortKey(ent.Sub)
    # Only the two attributes copied into the result are requested.
    items = get_table().query_prefix(
        partition,
        sub_prefix,
        consistent=consistent,
        attributes=['SK', 'IsActive'],
    )
    subscriptions: List[Subscription] = [
        {'ProjectDomain': item['SK'], 'IsActive': item['IsActive']}
        for item in items
    ]
    return subscriptions
예제 #2
0
def is_valid(user_email: str, project_domain: str) -> bool:
    """Tell whether a user holds an active subscription to a project.

    Both the user's own subscription key and the group subscription key for
    the project are looked up under the user's partition.

    Args:
        user_email: The user's email address.
        project_domain: The project's domain name.

    Returns:
        True if at least one matching item has a truthy `IsActive`
        attribute, False otherwise (including when nothing matches).

    """
    partition = db.PartitionKey(ent.User, user_email)
    user_sub_key = db.SortKey(ent.Sub, project_domain)
    group_sub_key = db.SortKey(ent.GroupSub, project_domain)

    # NOTE(review): if `cond` is boto3's conditions module, DynamoDB key
    # conditions do not accept OR -- confirm `query` supports this form.
    in_partition = cond.Key('PK').eq(str(partition))
    is_user_sub = cond.Key('SK').eq(str(user_sub_key))
    is_group_sub = cond.Key('SK').eq(str(group_sub_key))
    key_condition = in_partition & (is_user_sub | is_group_sub)

    query_arg = db.QueryArg(key_condition, attributes=['IsActive'])
    items = get_table().query(query_arg)
    return any(item['IsActive'] for item in items)
예제 #3
0
def fetch(group_name: str, project_domain: str, consistent: bool = False) \
        -> Optional[GroupSubAttributes]:
    """Look up a group's subscription attributes for a project.

    Args:
        group_name: The group's name.
        project_domain: The project's domain name.
        consistent: If true, request a strongly consistent read.

    Returns:
        The item returned by the table lookup (restricted to `IsActive`),
        or None when the lookup returns None.

    Raises:
        `db.DatabaseError` if there was an error connecting to the Database.

    """
    group_key = db.PartitionKey(ent.Group, group_name)
    sub_key = db.SortKey(ent.GroupSub, project_domain)
    item = get_table().get(
        group_key,
        sub_key,
        consistent=consistent,
        attributes=['IsActive'],
    )
    if item is None:
        return None
    return cast(GroupSubAttributes, item)
예제 #4
0
def delete(user_email: str, session_id: bytes) -> None:
    """Remove a session's items from the database.

    Args:
        user_email: User's email in Cognito.
        session_id: The session id.

    Raises:
        `db.DatabaseError` if there was an error connecting to the database.

    """
    hashed_id = _hex_hash(session_id)
    partition = db.PartitionKey(ent.Session, hashed_id)
    session_key = db.SortKey(ent.Session, hashed_id)
    user_key = db.SortKey(ent.User, user_email)

    # Two items share the session partition: the session item itself and the
    # session-user item. Both are deleted in one `transact_write_items` call.
    table = get_table()
    removals = [
        db.DeleteArg(partition, session_key),
        db.DeleteArg(partition, user_key),
    ]
    table.transact_write_items(removals)
예제 #5
0
def _get_user_create_op(user_email: str, project_domain: str) -> db.InsertArg:
    """Build the insert operation for a user's active project subscription."""
    user_key = db.PartitionKey(ent.User, user_email)
    sub_key = db.SortKey(ent.Sub, project_domain)
    # New subscriptions start out active.
    active: SubAttributes = {'IsActive': True}
    return db.InsertArg(user_key, sub_key, active)
예제 #6
0
def create(user_email: str, session_id: bytes) -> None:
    """Persist a new session id for a user.

    Args:
        user_email: User's email in Cognito.
        session_id: New session id.

    Raises:
        `db.ConditionalCheckFailedError` if the session id already exists.
        `db.DatabaseError` if there was an error connecting to the database.

    """
    session_attrs: SessionAttributes = {'ExpiresAt': _get_session_ttl()}

    hashed_id = _hex_hash(session_id)
    partition = db.PartitionKey(ent.Session, hashed_id)
    session_key = db.SortKey(ent.Session, hashed_id)
    user_key = db.SortKey(ent.User, user_email)

    # A standalone session item is inserted next to the session-user item so
    # that a duplicate session id cannot be stored. A condition can not do
    # this job, because it would need the full (composite) primary key, which
    # is not known without querying the database first.
    table = get_table()
    inserts = [
        db.InsertArg(partition, session_key, attributes=session_attrs),
        db.InsertArg(partition, user_key, attributes=session_attrs),
    ]
    table.transact_write_items(inserts)
예제 #7
0
def _get_trial_end_op(user_email: str, project_domain: str, trial_days: int) \
        -> db.InsertArg:
    """Build the insert operation that records a trial's end.

    The item is partitioned by the value `get_trial_end_date(trial_days)`
    returns.
    """
    end_date = get_trial_end_date(trial_days)
    # Email and domain are combined so each item gets a unique sort key.
    unique_id = f'{user_email}|{project_domain}'
    return db.InsertArg(
        db.PartitionKey(ent.TrialEnd, end_date),
        db.SortKey(ent.TrialEnd, unique_id),
    )
예제 #8
0
def fetch(user_email: str, project_domain: str, consistent: bool = False) \
        -> Optional[SubAttributes]:
    """Look up a user's subscription attributes for a project.

    Args:
        user_email: Email address that identifies the user.
        project_domain: Domain name that identifies the project.
        consistent: If true, request a strongly consistent read.

    Returns:
        The item returned by the table lookup (restricted to `IsActive`),
        or None when the lookup returns None.

    Raises:
        `db.DatabaseError` if there was an error connecting to the Database.

    """
    user_key = db.PartitionKey(ent.User, user_email)
    sub_key = db.SortKey(ent.Sub, project_domain)
    item = get_table().get(
        user_key,
        sub_key,
        consistent=consistent,
        attributes=['IsActive'],
    )
    if item is None:
        return None
    return cast(SubAttributes, item)
예제 #9
0
    def setUp(self):
        """Install a mock client and build the keys used by the tests."""
        super().setUp()

        # The patched `_client` entry hands out this mock when called.
        client = MagicMock()
        self._client = client
        self._mocks['_client'].return_value = client

        partition_key = db.PartitionKey(User, '*****@*****.**')
        sort_key = db.SortKey(Subscription, 'docs.example.com')
        prefix_key = db.PrefixSortKey(Subscription)
        self._pk = partition_key
        self._sk = sort_key
        self._sk_prefix = prefix_key
예제 #10
0
 def setUp(self):
     """Add a second key pair, the key list and the table under test."""
     super().setUp()
     second_pk = db.PartitionKey(User, '*****@*****.**')
     second_sk = db.SortKey(Subscription, 'docs.bar.com')
     self._pk_2 = second_pk
     self._sk_2 = second_sk
     # `self._pk` and `self._sk` are not assigned here; they are read from
     # the instance.
     first_key = db.PrimaryKey(self._pk, self._sk)
     second_key = db.PrimaryKey(second_pk, second_sk)
     self._keys = [first_key, second_key]
     table_name = 'my-table'
     self._table_name = table_name
     self._table = Table(table_name)
예제 #11
0
def exists(project_domain: str) -> bool:
    """Report whether a project item is present in the database.

    Args:
        project_domain: The project's domain name.

    Returns:
        True if the table lookup returns a truthy item, False otherwise.

    """
    # The project item uses its domain for both partition and sort key.
    partition = db.PartitionKey(ent.Project, project_domain)
    sort = db.SortKey(ent.Project, project_domain)
    item = get_table().get(partition, sort)
    return True if item else False
예제 #12
0
def is_owner(group_name: str, user_email: str) -> bool:
    """Tell whether a user owns a group.

    Args:
        group_name: Name that identifies the group.
        user_email: The user's email address.

    Returns:
        True if the group item's `OwnerEmail` equals `user_email`; False if
        it differs or the lookup returns None.

    """
    # The group item uses its name for both partition and sort key.
    partition = db.PartitionKey(ent.Group, group_name)
    sort = db.SortKey(ent.Group, group_name)
    group = get_table().get(partition, sort, attributes=['OwnerEmail'])
    if group is None:
        return False
    return group['OwnerEmail'] == user_email
예제 #13
0
def fetch(project_domain: str) -> Optional[ProjectAttributes]:
    """Look up a project's attributes by its domain name.

    Args:
        project_domain: The project's domain name.

    Returns:
        The item returned by the table lookup (restricted to `TrialDays`),
        or None when the lookup returns None.

    """
    # The project item uses its domain for both partition and sort key.
    partition = db.PartitionKey(ent.Project, project_domain)
    sort = db.SortKey(ent.Project, project_domain)
    item = get_table().get(partition, sort, attributes=['TrialDays'])
    if item is None:
        return None
    return cast(ProjectAttributes, item)
예제 #14
0
def delete(user_email: str, project_domain: str) -> None:
    """Deactivate a user's subscription to a project.

    The subscription item is not removed from the database; its `IsActive`
    attribute is set to false instead.

    Args:
        user_email: Email address that identifies the user.
        project_domain: Domain name that identifies the project.

    Raises:
        `db.DatabaseError` if there was an error connecting to the Database.

    """
    # TODO (abiro) Stripe logic
    user_key = db.PartitionKey(ent.User, user_email)
    sub_key = db.SortKey(ent.Sub, project_domain)
    deactivated: SubAttributes = {'IsActive': False}
    get_table().update_attributes(user_key, sub_key, deactivated)
예제 #15
0
def fetch_user_email(session_id: bytes) -> Optional[str]:
    """Look up which user a session belongs to.

    Args:
        session_id: The session id.

    Returns:
        The `SK` attribute of the first user item found under the session's
        partition, or None if the query result is empty.

    Raises:
        `db.DatabaseError` if there was an error connecting to the database.

    """
    hashed_id = _hex_hash(session_id)
    partition = db.PartitionKey(ent.Session, hashed_id)
    user_prefix = db.PrefixSortKey(ent.User)
    matches = get_table().query_prefix(partition, user_prefix,
                                       attributes=['SK'])
    if not matches:
        return None
    return cast(str, matches[0]['SK'])
예제 #16
0
def recreate(user_email: str, project_domain: str) -> None:
    """Reactivate a lapsed subscription to a project for a user.

    There is no trial in this case; the item's `IsActive` attribute is set
    to true.

    Args:
        user_email: Email address that identifies the user.
        project_domain: Domain name that identifies the project.

    Raises:
        `db.DatabaseError` if there was an error connecting to the Database.

    """
    # TODO (abiro) Stripe logic
    # TODO (abiro) Update instead of Put
    user_key = db.PartitionKey(ent.User, user_email)
    sub_key = db.SortKey(ent.Sub, project_domain)
    reactivated: SubAttributes = {'IsActive': True}
    get_table().update_attributes(user_key, sub_key, reactivated)
예제 #17
0
                )
            else:
                scan = table.scan(ProjectionExpression='PK,SK')

            for item in scan['Items']:
                batch.delete_item(Key={'PK': item['PK'], 'SK': item['SK']})


# Integration-test script: the statements below run in order at module level
# against the table named by TABLE_NAME.
logging.info('Starting integration tests')

# We clear the DB instead of recreating it to save time.
_clear_db(TABLE_NAME)
table = db.Table(TABLE_NAME)

# Users: partition and sort key are built from the same entity type and value.
pk_alice = db.PartitionKey(User, '*****@*****.**')
sk_alice = db.SortKey(User, '*****@*****.**')

# Products: only a partition key is defined for the product at this point.
pk_book = db.PartitionKey(Product, 'book')

# Orders: one partition key (order-1) and two sort keys (order-1, order-2).
pk_order1 = db.PartitionKey(Order, '2020-02-21|order-1')
sk_order1 = db.SortKey(Order, '2020-02-21|order-1')
sk_order2 = db.SortKey(Order, '2020-02-21|order-2')

# Insert the user item without passing any attributes.
logging.info('Testing insert')
table.insert(pk_alice, sk_alice)

# Update the item inserted above with a nested (dict-valued) attribute.
logging.info('Testing update_attributes')
table.update_attributes(pk_alice, sk_alice, {'MyJson': {'A': 1}})
예제 #18
0
 def setUp(self):
     """Build the partition/sort key pair used by the tests."""
     partition_key = db.PartitionKey(User, '*****@*****.**')
     sort_key = db.SortKey(Subscription, 'docs.example.com')
     self._pk = partition_key
     self._sk = sort_key
예제 #19
0
 def setUp(self):
     """Build the key pair, table name and primary index for the tests."""
     partition_key = db.PartitionKey(User, '*****@*****.**')
     sort_key = db.SortKey(Subscription, 'mitpress.mit.edu')
     self._pk = partition_key
     self._sk = sort_key
     self._table_name = 'my-table'
     primary_index = db.PrimaryGlobalIndex()
     self._primary_index = primary_index
예제 #20
0
def _get_group_create_op(group_name: str, project_domain: str) -> db.InsertArg:
    """Build the insert operation for a group's active project subscription."""
    group_key = db.PartitionKey(ent.Group, group_name)
    sub_key = db.SortKey(ent.GroupSub, project_domain)
    # New group subscriptions start out active.
    active: GroupSubAttributes = {'IsActive': True}
    return db.InsertArg(group_key, sub_key, attributes=active)