Exemplo n.º 1
0
    def query(
        self,
        table_name,
        hash_key_dict,
        range_comparison,
        range_value_dicts,
        limit,
        exclusive_start_key,
        scan_index_forward,
        projection_expression,
        index_name=None,
        expr_names=None,
        expr_values=None,
        filter_expression=None,
        **filter_kwargs
    ):
        """Run a Query against ``table_name`` by delegating to its table.

        The raw hash-key dict and each raw range-value dict are wrapped in
        ``DynamoType``. The filter expression string is compiled with
        ``get_filter_expression`` using the placeholder name/value maps. The
        remaining arguments are forwarded to ``table.query`` unchanged.

        Returns:
            ``(None, None)`` when ``table_name`` is not present in
            ``self.tables``; otherwise whatever ``table.query`` returns.
        """
        target_table = self.tables.get(table_name)
        if not target_table:
            return None, None

        wrapped_hash = DynamoType(hash_key_dict)
        wrapped_ranges = list(map(DynamoType, range_value_dicts))
        compiled_filter = get_filter_expression(
            filter_expression, expr_names, expr_values
        )

        return target_table.query(
            wrapped_hash,
            range_comparison,
            wrapped_ranges,
            limit,
            exclusive_start_key,
            scan_index_forward,
            projection_expression,
            index_name,
            compiled_filter,
            **filter_kwargs
        )
Exemplo n.º 2
0
    def put_item(
        self,
        item_attrs,
        expected=None,
        condition_expression=None,
        expression_attribute_names=None,
        expression_attribute_values=None,
        overwrite=False,
    ):
        """Store ``item_attrs`` as an item in this table and return it.

        Args:
            item_attrs: Attribute map of the new item. It must contain the
                hash key attribute and, when the table has a range key, the
                range key attribute.
            expected: Legacy ``Expected`` conditions, checked against the
                current item via ``get_expected``. If it supplies a ``Value``
                for the range key attribute, that value (not the new item's
                range value) selects which existing item is looked up.
            condition_expression: ``ConditionExpression`` evaluated against
                the current item.
            expression_attribute_names: Placeholder names for the condition.
            expression_attribute_values: Placeholder values for the condition.
            overwrite: When true, both condition checks are skipped.

        Returns:
            The newly stored ``Item``.

        Raises:
            KeyError: If a key attribute is missing from ``item_attrs``.
            ConditionalCheckFailed: If ``expected`` or the condition
                expression does not hold for the current item.
        """
        if self.hash_key_attr not in item_attrs.keys():
            raise KeyError(
                "One or more parameter values were invalid: Missing the key " +
                self.hash_key_attr + " in the item")
        hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
        if self.has_range_key:
            if self.range_key_attr not in item_attrs.keys():
                raise KeyError(
                    "One or more parameter values were invalid: Missing the key "
                    + self.range_key_attr + " in the item")
            range_value = DynamoType(item_attrs.get(self.range_key_attr))
        else:
            range_value = None

        # By default the existing item is looked up by the new item's own key;
        # a range-key ``Value`` in ``expected`` redirects the lookup.
        lookup_range_value = range_value
        if expected is None:
            expected = {}
        else:
            expected_range_value = expected.get(self.range_key_attr,
                                                {}).get("Value")
            if expected_range_value is not None:
                lookup_range_value = DynamoType(expected_range_value)
        current = self.get_item(hash_value, lookup_range_value)
        item = Item(hash_value, self.hash_key_type, range_value,
                    self.range_key_type, item_attrs)

        if not overwrite:
            if not get_expected(expected).expr(current):
                raise ConditionalCheckFailed
            condition_op = get_filter_expression(
                condition_expression,
                expression_attribute_names,
                expression_attribute_values,
            )
            if not condition_op.expr(current):
                raise ConditionalCheckFailed

        # Compare with None explicitly: where a key-value is stored must not
        # depend on the truthiness of the DynamoType wrapping it.
        if range_value is not None:
            self.items[hash_value][range_value] = item
        else:
            self.items[hash_value] = item

        if self.stream_shard is not None:
            self.stream_shard.add(current, item)

        return item
Exemplo n.º 3
0
    def scan(
        self,
        table_name,
        filters,
        limit,
        exclusive_start_key,
        filter_expression,
        expr_names,
        expr_values,
        index_name,
        projection_expression,
    ):
        """Run a Scan against ``table_name`` by delegating to its table.

        Args:
            table_name: Name of the table to scan.
            filters: Legacy ``ScanFilter`` mapping of attribute name to
                ``(comparison_operator, comparison_values)``; the raw values
                are wrapped in ``DynamoType``.
            limit: Forwarded to ``table.scan``.
            exclusive_start_key: Forwarded to ``table.scan``.
            filter_expression: Filter expression string, compiled with
                ``get_filter_expression``.
            expr_names: Placeholder attribute names (may be None).
            expr_values: Placeholder attribute values.
            index_name: Forwarded to ``table.scan``.
            projection_expression: Comma-separated projection, or None.

        Returns:
            ``(None, None, None)`` when the table is not in ``self.tables``;
            otherwise whatever ``table.scan`` returns.
        """
        table = self.tables.get(table_name)
        if not table:
            return None, None, None

        scan_filters = {}
        for key, (comparison_operator, comparison_values) in filters.items():
            dynamo_types = [DynamoType(value) for value in comparison_values]
            scan_filters[key] = (comparison_operator, dynamo_types)

        filter_expression = get_filter_expression(
            filter_expression, expr_names, expr_values
        )

        # ProjectionExpression and ExpressionAttributeNames are both optional:
        # only rewrite a real projection, substituting any ``#name``
        # placeholders and stripping spaces.
        if projection_expression:
            names = expr_names or {}
            projection_expression = ",".join(
                names.get(attr, attr)
                for attr in projection_expression.replace(" ", "").split(",")
            )

        return table.scan(
            scan_filters,
            limit,
            exclusive_start_key,
            filter_expression,
            index_name,
            projection_expression,
        )
Exemplo n.º 4
0
    def delete_item(
        self,
        table_name,
        key,
        expression_attribute_names=None,
        expression_attribute_values=None,
        condition_expression=None,
    ):
        """Delete the item identified by ``key`` from ``table_name``.

        The optional condition expression is compiled with
        ``get_filter_expression`` and evaluated against the item currently
        stored under ``key`` before deleting.

        Returns:
            None when the table is not found; otherwise the result of
            ``table.delete_item``.

        Raises:
            ConditionalCheckFailed: If the condition does not hold for the
                current item.
        """
        target = self.get_table(table_name)
        if not target:
            return None

        hash_value, range_value = self.get_keys_value(target, key)
        existing = target.get_item(hash_value, range_value)

        condition = get_filter_expression(
            condition_expression,
            expression_attribute_names,
            expression_attribute_values,
        )
        if condition.expr(existing):
            return target.delete_item(hash_value, range_value)
        raise ConditionalCheckFailed
Exemplo n.º 5
0
    def transact_write_items(self, transact_items):
        """Apply TransactWriteItems entries with all-or-nothing semantics.

        Each entry is a dict expected to hold one of the keys
        ``ConditionCheck``, ``Put``, ``Delete`` or ``Update``; they are
        checked in that order and the first match wins. Every entry is
        attempted, even after an earlier one fails. For each entry, ``None``
        is recorded on success and the class name of the raised exception on
        failure.

        Raises:
            TransactionCanceledException: If any entry failed. Before raising,
                ``self.tables`` is restored to the snapshot taken before the
                first entry ran. The exception is given the list of per-entry
                results (``None`` for entries that succeeded).
        """

        def condition_kwargs(params):
            # Optional condition parameters shared by every operation type,
            # keyed by the keyword names the backend methods accept.
            return {
                "condition_expression": params.get("ConditionExpression", None),
                "expression_attribute_names": params.get(
                    "ExpressionAttributeNames", None),
                "expression_attribute_values": params.get(
                    "ExpressionAttributeValues", None),
            }

        # Snapshot the tables so a failed transaction can be rolled back.
        original_table_state = copy.deepcopy(self.tables)
        errors = []
        for transact_item in transact_items:
            try:
                if "ConditionCheck" in transact_item:
                    params = transact_item["ConditionCheck"]
                    key = params["Key"]
                    table_name = params["TableName"]
                    conditions = condition_kwargs(params)
                    current = self.get_item(table_name, key)

                    condition_op = get_filter_expression(
                        conditions["condition_expression"],
                        conditions["expression_attribute_names"],
                        conditions["expression_attribute_values"],
                    )
                    if not condition_op.expr(current):
                        raise ConditionalCheckFailed()
                elif "Put" in transact_item:
                    params = transact_item["Put"]
                    attrs = params["Item"]
                    table_name = params["TableName"]
                    self.put_item(table_name, attrs, **condition_kwargs(params))
                elif "Delete" in transact_item:
                    params = transact_item["Delete"]
                    key = params["Key"]
                    table_name = params["TableName"]
                    self.delete_item(table_name, key, **condition_kwargs(params))
                elif "Update" in transact_item:
                    params = transact_item["Update"]
                    key = params["Key"]
                    table_name = params["TableName"]
                    update_expression = params["UpdateExpression"]
                    self.update_item(
                        table_name,
                        key,
                        update_expression=update_expression,
                        **condition_kwargs(params)
                    )
                else:
                    raise ValueError
                errors.append(None)
            except Exception as e:
                # Deliberately broad: any failure becomes a cancellation reason
                # (its class name) and processing continues with the next
                # entry, so every entry gets a result.
                errors.append(type(e).__name__)
        if any(errors):
            # Roll back to the snapshot, then report every entry's result.
            self.tables = original_table_state
            raise TransactionCanceledException(errors)
Exemplo n.º 6
0
    def update_item(
        self,
        table_name,
        key,
        update_expression,
        expression_attribute_names,
        expression_attribute_values,
        attribute_updates=None,
        expected=None,
        condition_expression=None,
    ):
        """Update (creating if absent) one item of ``table_name`` and return it.

        Args:
            table_name: Name of the table holding the item.
            key: Either a dict keyed by the table's hash (and optionally range)
                key attribute names, or a raw value used directly as the hash
                key.
            update_expression: ``UpdateExpression`` string. When given it is
                parsed, validated and executed. Otherwise
                ``attribute_updates`` is applied instead.
            expression_attribute_names: Placeholder names for the expressions.
            expression_attribute_values: Placeholder values for the expressions.
            attribute_updates: Legacy ``AttributeUpdates`` mapping.
            expected: Legacy ``Expected`` conditions.
            condition_expression: ``ConditionExpression`` string.

        Returns:
            The updated item.

        Raises:
            ConditionalCheckFailed: If ``expected`` or the condition expression
                does not hold for the current item.
            ItemSizeToUpdateTooLarge: If executing the update expression grows
                the item beyond the size limit.
        """
        table = self.get_table(table_name)

        # Parse the expression before touching any state, so a malformed
        # expression is rejected before conditions run or an item is created.
        update_expression_ast = None
        if update_expression:
            update_expression_ast = UpdateExpressionParser.make(
                update_expression)

        if all([table.hash_key_attr in key, table.range_key_attr in key]):
            # Covers cases where table has hash and range keys, ``key`` param
            # will be a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = DynamoType(key[table.range_key_attr])
        elif table.hash_key_attr in key:
            # Covers tables that have a range key where ``key`` param is a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = None
        else:
            # Covers other cases
            hash_value = DynamoType(key)
            range_value = None

        item = table.get_item(hash_value, range_value)
        # Pre-update snapshot for the stream record.
        orig_item = copy.deepcopy(item)

        if not expected:
            expected = {}

        if not get_expected(expected).expr(item):
            raise ConditionalCheckFailed
        condition_op = get_filter_expression(
            condition_expression,
            expression_attribute_names,
            expression_attribute_values,
        )
        if not condition_op.expr(item):
            raise ConditionalCheckFailed

        # Update does not fail on new items, so create one
        if item is None:
            data = {table.hash_key_attr: {hash_value.type: hash_value.value}}
            if range_value is not None:
                data.update({
                    table.range_key_attr: {
                        range_value.type: range_value.value
                    }
                })

            table.put_item(data)
            item = table.get_item(hash_value, range_value)

        if update_expression:
            validated_ast = UpdateExpressionValidator(
                update_expression_ast,
                expression_attribute_names=expression_attribute_names,
                expression_attribute_values=expression_attribute_values,
                item=item,
            ).validate()
            try:
                UpdateExpressionExecutor(validated_ast, item,
                                         expression_attribute_names).execute()
            except ItemSizeTooLarge:
                raise ItemSizeToUpdateTooLarge()
        else:
            item.update_with_attribute_updates(attribute_updates)
        if table.stream_shard is not None:
            table.stream_shard.add(orig_item, item)
        return item
Exemplo n.º 7
0
    def update_item(
        self,
        table_name,
        key,
        update_expression,
        attribute_updates,
        expression_attribute_names,
        expression_attribute_values,
        expected=None,
        condition_expression=None,
    ):
        """Update (creating if absent) one item of ``table_name`` and return it.

        Args:
            table_name: Name of the table holding the item.
            key: Either a dict keyed by the table's hash (and optionally range)
                key attribute names, or a raw value used directly as the hash
                key.
            update_expression: ``UpdateExpression`` string, applied with
                ``item.update`` when given.
            attribute_updates: Legacy ``AttributeUpdates`` mapping, used when
                ``update_expression`` is empty.
            expression_attribute_names: Placeholder names for the expressions.
            expression_attribute_values: Placeholder values for the expressions.
            expected: Legacy ``Expected`` conditions.
            condition_expression: ``ConditionExpression`` string.

        Returns:
            The updated item.

        Raises:
            ValueError: "The conditional request failed" when ``expected`` or
                the condition expression does not hold for the current item.
        """
        table = self.get_table(table_name)

        if all([table.hash_key_attr in key, table.range_key_attr in key]):
            # Covers cases where table has hash and range keys, ``key`` param
            # will be a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = DynamoType(key[table.range_key_attr])
        elif table.hash_key_attr in key:
            # Covers tables that have a range key where ``key`` param is a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = None
        else:
            # Covers other cases
            hash_value = DynamoType(key)
            range_value = None

        item = table.get_item(hash_value, range_value)
        # Pre-update snapshot for the stream record.
        orig_item = copy.deepcopy(item)

        if not expected:
            expected = {}

        if not get_expected(expected).expr(item):
            raise ValueError("The conditional request failed")
        condition_op = get_filter_expression(
            condition_expression,
            expression_attribute_names,
            expression_attribute_values,
        )
        if not condition_op.expr(item):
            raise ValueError("The conditional request failed")

        # Update does not fail on new items, so create one
        if item is None:
            data = {table.hash_key_attr: {hash_value.type: hash_value.value}}
            # Explicit None check: key presence must not hinge on the
            # truthiness of the DynamoType wrapper.
            if range_value is not None:
                data.update(
                    {table.range_key_attr: {range_value.type: range_value.value}}
                )

            table.put_item(data)
            item = table.get_item(hash_value, range_value)

        if update_expression:
            item.update(
                update_expression,
                expression_attribute_names,
                expression_attribute_values,
            )
        else:
            item.update_with_attribute_updates(attribute_updates)
        if table.stream_shard is not None:
            table.stream_shard.add(orig_item, item)
        return item