Ejemplo n.º 1
0
    def _trim_results(self, results, limit, exclusive_start_key, scanned_index=None):
        """Paginate ``results``.

        Everything up to and including the item identified by
        ``exclusive_start_key`` is dropped (when such an item is found), then
        the list is cut down to ``limit`` entries.

        :param results: sequence of items exposing ``hash_key``, ``range_key``
            and ``attrs``.
        :param limit: maximum number of items to keep; a falsy value disables
            truncation.
        :param exclusive_start_key: serialized key of the last item of the
            previous page, or ``None``.
        :param scanned_index: name of the index being read, if any; its key
            attributes are added to the returned key.
        :returns: ``(results, last_evaluated_key)``; the key is ``None`` unless
            the list was truncated.
        """
        if exclusive_start_key is not None:
            start_hash = DynamoType(exclusive_start_key.get(self.hash_key_attr))
            start_range = exclusive_start_key.get(self.range_key_attr)
            if start_range is not None:
                start_range = DynamoType(start_range)
            for position, candidate in enumerate(results):
                same_hash = candidate.hash_key == start_hash
                if same_hash and candidate.range_key == start_range:
                    # Resume right after the previously returned item.
                    results = results[position + 1 :]
                    break

        if not (limit and len(results) > limit):
            # Nothing was cut off, so there is no continuation key.
            return results, None

        results = results[:limit]
        tail = results[-1]
        last_evaluated_key = {self.hash_key_attr: tail.hash_key}
        if tail.range_key is not None:
            last_evaluated_key[self.range_key_attr] = tail.range_key

        if scanned_index:
            index_lookup = {
                index["IndexName"]: index for index in self.all_indexes()
            }
            key_schema = index_lookup[scanned_index]["KeySchema"]
            for column in [entry["AttributeName"] for entry in key_schema]:
                last_evaluated_key[column] = tail.attrs[column]

        return results, last_evaluated_key
Ejemplo n.º 2
0
    def query(self,
              table_name,
              hash_key_dict,
              range_comparison,
              range_value_dicts,
              limit,
              exclusive_start_key,
              scan_index_forward,
              projection_expression,
              index_name=None,
              expr_names=None,
              expr_values=None,
              filter_expression=None,
              **filter_kwargs):
        """Look up ``table_name`` and delegate the query to that table.

        The serialized hash key and range values are wrapped in ``DynamoType``
        and the raw filter expression is compiled with
        ``get_filter_expression`` before being handed to ``table.query``.

        :returns: whatever ``table.query`` returns, or ``(None, None, None)``
            when the table does not exist.
        """
        table = self.tables.get(table_name)
        if not table:
            # Three values, matching the arity of the table-level query result,
            # so callers can unpack uniformly before checking for ``None``.
            return None, None, None

        hash_key = DynamoType(hash_key_dict)
        range_values = [
            DynamoType(range_value) for range_value in range_value_dicts
        ]

        filter_expression = get_filter_expression(filter_expression,
                                                  expr_names, expr_values)

        return table.query(hash_key, range_comparison, range_values, limit,
                           exclusive_start_key, scan_index_forward,
                           projection_expression, index_name,
                           filter_expression, **filter_kwargs)
Ejemplo n.º 3
0
    def put_item(
        self,
        item_attrs,
        expected=None,
        condition_expression=None,
        expression_attribute_names=None,
        expression_attribute_values=None,
        overwrite=False,
    ):
        """Store a new item built from ``item_attrs`` and return it.

        The hash key (and the range key, for tables that have one) must be
        present in ``item_attrs``. Unless ``overwrite`` is set, both the
        ``expected`` conditions and the condition expression are evaluated
        against the currently stored item first.

        :raises KeyError: a required key attribute is missing.
        :raises ConditionalCheckFailed: a condition does not hold.
        """
        missing_key_prefix = (
            "One or more parameter values were invalid: Missing the key "
        )

        if self.hash_key_attr not in item_attrs:
            raise KeyError(
                missing_key_prefix + self.hash_key_attr + " in the item")
        hash_value = DynamoType(item_attrs.get(self.hash_key_attr))

        range_value = None
        if self.has_range_key:
            if self.range_key_attr not in item_attrs:
                raise KeyError(
                    missing_key_prefix + self.range_key_attr + " in the item")
            range_value = DynamoType(item_attrs.get(self.range_key_attr))

        # The existing item is looked up with the range value named in
        # ``expected`` when one is given, otherwise with the new item's own.
        lookup_range_value = range_value
        if expected is None:
            expected = {}
        else:
            expected_range = expected.get(self.range_key_attr, {}).get("Value")
            if expected_range is not None:
                lookup_range_value = DynamoType(expected_range)

        current = self.get_item(hash_value, lookup_range_value)
        item = Item(hash_value, self.hash_key_type, range_value,
                    self.range_key_type, item_attrs)

        if not overwrite:
            if not get_expected(expected).expr(current):
                raise ConditionalCheckFailed
            condition_op = get_filter_expression(
                condition_expression,
                expression_attribute_names,
                expression_attribute_values,
            )
            if not condition_op.expr(current):
                raise ConditionalCheckFailed

        if range_value:
            self.items[hash_value][range_value] = item
        else:
            self.items[hash_value] = item

        shard = self.stream_shard
        if shard is not None:
            shard.add(current, item)

        return item
Ejemplo n.º 4
0
 def get_keys_value(self, table, keys):
     """Extract the hash key and (if the table has one) range key from ``keys``.

     :returns: ``(hash_key, range_key)`` as ``DynamoType`` values;
         ``range_key`` is ``None`` for tables without a range key.
     :raises ValueError: the hash key is absent, or the table has a range
         key that is absent.
     """
     keys_complete = table.hash_key_attr in keys and not (
         table.has_range_key and table.range_key_attr not in keys)
     if not keys_complete:
         raise ValueError(
             "Table has a range key, but no range key was passed into get_item"
         )

     hash_key = DynamoType(keys[table.hash_key_attr])
     if table.has_range_key:
         range_key = DynamoType(keys[table.range_key_attr])
     else:
         range_key = None
     return hash_key, range_key
Ejemplo n.º 5
0
    def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
        """Record the key values and types and wrap each attribute in a DynamoType.

        :param attrs: mapping of attribute name to its serialized value.
        """
        self.hash_key, self.hash_key_type = hash_key, hash_key_type
        self.range_key, self.range_key_type = range_key, range_key_type

        # Entries are added one by one through item assignment on the
        # LimitedSizeDict rather than in bulk.
        wrapped_attrs = LimitedSizeDict()
        self.attrs = wrapped_attrs
        for name, raw_value in attrs.items():
            wrapped_attrs[name] = DynamoType(raw_value)
Ejemplo n.º 6
0
    def _trim_results(self,
                      results,
                      limit,
                      exclusive_start_key,
                      scanned_index=None):
        """Paginate ``results`` by start key, item count and total size.

        Everything up to and including the item identified by
        ``exclusive_start_key`` is dropped (when such an item is found). The
        list is then truncated so that it holds at most ``limit`` items and
        stays under the 1MB page size.

        :param results: sequence of items exposing ``hash_key``, ``range_key``,
            ``attrs`` and ``size()``.
        :param limit: maximum number of items to keep; a falsy value means no
            explicit item limit.
        :param exclusive_start_key: serialized key of the last item of the
            previous page, or ``None``.
        :param scanned_index: name of the index being read, if any; its key
            attributes are added to the returned key.
        :returns: ``(results, last_evaluated_key)``; the key is ``None`` unless
            the list was truncated.
        """
        if exclusive_start_key is not None:
            hash_key = DynamoType(exclusive_start_key.get(self.hash_key_attr))
            range_key = exclusive_start_key.get(self.range_key_attr)
            if range_key is not None:
                range_key = DynamoType(range_key)
            for position, candidate in enumerate(results):
                if (candidate.hash_key == hash_key
                        and candidate.range_key == range_key):
                    results = results[position + 1:]
                    break

        last_evaluated_key = None
        size_limit = 1000000  # DynamoDB has a 1MB size limit
        if sum(res.size() for res in results) > size_limit:
            # Count how many leading items fit in one page.
            page_size = 0
            fitting = 0
            for res in results:
                res_size = res.size()
                if page_size + res_size >= size_limit:
                    break
                page_size += res_size
                fitting += 1
            # Always keep at least one item: a count of 0 is falsy and would
            # disable truncation below, returning the whole oversized list
            # and ignoring the caller's limit.
            fitting = max(fitting, 1)
            limit = min(limit, fitting) if limit else fitting

        if limit and len(results) > limit:
            results = results[:limit]
            last_item = results[-1]
            last_evaluated_key = {self.hash_key_attr: last_item.hash_key}
            if last_item.range_key is not None:
                last_evaluated_key[self.range_key_attr] = last_item.range_key

            if scanned_index:
                indexes_by_name = {
                    index["IndexName"]: index for index in self.all_indexes()
                }
                key_schema = indexes_by_name[scanned_index]["KeySchema"]
                for col in [entry["AttributeName"] for entry in key_schema]:
                    last_evaluated_key[col] = last_item.attrs[col]

        return results, last_evaluated_key
Ejemplo n.º 7
0
 def _get_appended_list(self, value, expression_attribute_values):
     """Resolve a ``list_append(target, :values)`` operand.

     ``value`` is returned untouched when it is already a ``DynamoType`` or
     is not a ``list_append`` call. Otherwise the target list is extended in
     place with the elements of the referenced expression value and
     returned.

     :raises ParamValidationError: the resolved target is not a list.
     """
     if type(value) is DynamoType:
         return value

     match = re.match("list_append\\((.+),(.+)\\)", value)
     if not match:
         return value

     appended = expression_attribute_values[match.group(2).strip()]
     target_path = match.group(1)

     if target_path.startswith("if_not_exists"):
         # The first argument may itself be an if_not_exists(...) call.
         target = self._get_default(target_path)
         if not isinstance(target, DynamoType):
             target = DynamoType(expression_attribute_values[target])
     else:
         head, dot, tail = target_path.partition(".")
         target = self.attrs[head]
         if dot:
             # Dotted path: descend into the map to reach the nested list.
             target = target.child_attr(tail)

     if not target.is_list():
         raise ParamValidationError
     target.value.extend([DynamoType(element) for element in appended["L"]])
     return target
Ejemplo n.º 8
0
    def scan(
        self,
        table_name,
        filters,
        limit,
        exclusive_start_key,
        filter_expression,
        expr_names,
        expr_values,
        index_name,
        projection_expression,
    ):
        """Look up ``table_name`` and delegate the scan to that table.

        Comparison values in ``filters`` are wrapped in ``DynamoType``, the
        filter expression is compiled with ``get_filter_expression``, and
        attribute-name placeholders in the projection are replaced through
        ``expr_names``.

        :returns: whatever ``table.scan`` returns, or ``(None, None, None)``
            when the table does not exist.
        """
        table = self.tables.get(table_name)
        if not table:
            return None, None, None

        scan_filters = {
            key: (comparison_operator,
                  [DynamoType(value) for value in comparison_values])
            for key, (comparison_operator, comparison_values) in filters.items()
        }

        compiled_filter = get_filter_expression(
            filter_expression, expr_names, expr_values
        )

        # Strip spaces, then swap each placeholder for its real attribute name.
        projected_names = []
        for attr in projection_expression.replace(" ", "").split(","):
            projected_names.append(expr_names.get(attr, attr))
        resolved_projection = ",".join(projected_names)

        return table.scan(
            scan_filters,
            limit,
            exclusive_start_key,
            compiled_filter,
            index_name,
            resolved_projection,
        )
Ejemplo n.º 9
0
    def query(self,
              hash_key,
              range_comparison,
              range_objs,
              limit,
              exclusive_start_key,
              scan_index_forward,
              projection_expression,
              index_name=None,
              filter_expression=None,
              **filter_kwargs):
        """Return the items matching ``hash_key``, optionally narrowed and paged.

        Candidates are selected by hash key (on the table, or on the index
        named by ``index_name``), narrowed by the range comparison and/or the
        legacy ``filter_kwargs`` conditions, sorted by range key, filtered by
        ``filter_expression``, projected, and finally paged with
        ``_trim_results``.

        :param hash_key: value the (table or index) hash key must equal.
        :param range_comparison: comparison operator applied to the range key,
            or a falsy value for none.
        :param range_objs: operands for ``range_comparison``.
        :param limit: passed through to ``_trim_results``.
        :param exclusive_start_key: passed through to ``_trim_results``.
        :param scan_index_forward: results are reversed when this is ``False``.
        :param projection_expression: if truthy, passed to each result's
            ``filter`` method.
        :param index_name: name of the index to query instead of the table.
        :param filter_expression: object whose ``expr(item)`` decides whether
            an item is kept, or ``None``.
        :param filter_kwargs: per-field conditions, each a dict with
            ``AttributeValueList`` and ``ComparisonOperator`` entries.
        :returns: ``(results, scanned_count, last_evaluated_key)``.
        :raises ValueError: unknown index, index without a HASH key, or a
            range comparison on an index that has no RANGE key.
        """
        results = []

        if index_name:
            # Resolve the named index and pull out its HASH / RANGE schema entries.
            all_indexes = self.all_indexes()
            indexes_by_name = dict((i["IndexName"], i) for i in all_indexes)
            if index_name not in indexes_by_name:
                raise ValueError(
                    "Invalid index: %s for table: %s. Available indexes are: %s"
                    %
                    (index_name, self.name, ", ".join(indexes_by_name.keys())))

            index = indexes_by_name[index_name]
            try:
                index_hash_key = [
                    key for key in index["KeySchema"]
                    if key["KeyType"] == "HASH"
                ][0]
            except IndexError:
                raise ValueError("Missing Hash Key. KeySchema: %s" %
                                 index["KeySchema"])

            try:
                index_range_key = [
                    key for key in index["KeySchema"]
                    if key["KeyType"] == "RANGE"
                ][0]
            except IndexError:
                # A RANGE key is optional for an index.
                index_range_key = None

            # Select items whose index hash attribute equals the requested key.
            possible_results = []
            for item in self.all_items():
                if not isinstance(item, Item):
                    continue
                item_hash_key = item.attrs.get(index_hash_key["AttributeName"])
                if index_range_key is None:
                    if item_hash_key and item_hash_key == hash_key:
                        possible_results.append(item)
                else:
                    # With a RANGE key in the index, items lacking (or having a
                    # falsy) range attribute are left out.
                    item_range_key = item.attrs.get(
                        index_range_key["AttributeName"])
                    if item_hash_key and item_hash_key == hash_key and item_range_key:
                        possible_results.append(item)
        else:
            # No index: match on the table's own hash key.
            possible_results = [
                item for item in list(self.all_items())
                if isinstance(item, Item) and item.hash_key == hash_key
            ]
        if range_comparison:
            if index_name and not index_range_key:
                raise ValueError(
                    "Range Key comparison but no range key found for index: %s"
                    % index_name)

            elif index_name:
                # Compare against the index's range attribute.
                for result in possible_results:
                    if result.attrs.get(
                            index_range_key["AttributeName"]).compare(
                                range_comparison, range_objs):
                        results.append(result)
            else:
                # Compare against the table's range key.
                for result in possible_results:
                    if result.range_key.compare(range_comparison, range_objs):
                        results.append(result)

        if filter_kwargs:
            # NOTE(review): an item is appended once per matching field, and in
            # addition to any range-comparison match above, so the same item can
            # appear several times; a field missing from the item makes
            # ``.get(field)`` return None and ``.compare`` raise AttributeError.
            for result in possible_results:
                for field, value in filter_kwargs.items():
                    dynamo_types = [
                        DynamoType(ele) for ele in value["AttributeValueList"]
                    ]
                    if result.attrs.get(field).compare(
                            value["ComparisonOperator"], dynamo_types):
                        results.append(result)

        if not range_comparison and not filter_kwargs:
            # If we're not filtering on range key or on an index return all
            # values
            results = possible_results

        if index_name:

            if index_range_key:

                # Convert to float if necessary to ensure proper ordering
                def conv(x):
                    return float(x.value) if x.type == "N" else x.value

                # Sort by the index range attribute; items without it get a
                # None sort key.
                results.sort(key=lambda item: conv(item.attrs[index_range_key[
                    "AttributeName"]]) if item.attrs.get(index_range_key[
                        "AttributeName"]) else None)
        else:
            results.sort(key=lambda item: item.range_key)

        if scan_index_forward is False:
            results.reverse()

        # Counts every item in the table, not only the ones matching the key.
        scanned_count = len(list(self.all_items()))

        if filter_expression is not None:
            results = [
                item for item in results if filter_expression.expr(item)
            ]

        # Projection runs on deep copies - presumably so that the stored items
        # are not modified by ``result.filter`` (TODO confirm).
        results = copy.deepcopy(results)
        if projection_expression:
            for result in results:
                result.filter(projection_expression)

        # NOTE(review): ``index_name`` is not forwarded to ``_trim_results``;
        # if that helper accepts the scanned index to build the continuation
        # key, index queries will get a key without the index attributes -
        # confirm whether this is intended.
        results, last_evaluated_key = self._trim_results(
            results, limit, exclusive_start_key)
        return results, scanned_count, last_evaluated_key
Ejemplo n.º 10
0
    def update_item(
        self,
        table_name,
        key,
        update_expression,
        expression_attribute_names,
        expression_attribute_values,
        attribute_updates=None,
        expected=None,
        condition_expression=None,
    ):
        """Update (or create) the item identified by ``key`` and return it.

        The ``expected`` conditions and the condition expression are checked
        against the stored item first. A missing item is created with just
        its key attributes. The update is then applied either through the
        parsed update expression or, when no expression is given, through
        ``attribute_updates``.

        :raises ConditionalCheckFailed: a condition does not hold.
        :raises ItemSizeToUpdateTooLarge: executing the expression raised
            ``ItemSizeTooLarge``.
        """
        table = self.get_table(table_name)

        if update_expression:
            # Parse expression to get validation errors
            update_expression_ast = UpdateExpressionParser.make(
                update_expression)
            # Support spaces between operators in an update expression
            # E.g. `a = b + c` -> `a=b+c`
            update_expression = re.sub(r"\s*([=\+-])\s*", "\\1",
                                       update_expression)

        has_hash_attr = table.hash_key_attr in key
        has_range_attr = table.range_key_attr in key
        if has_hash_attr:
            # ``key`` is a dict holding the hash key and possibly the range key.
            hash_value = DynamoType(key[table.hash_key_attr])
            if has_range_attr:
                range_value = DynamoType(key[table.range_key_attr])
            else:
                range_value = None
        else:
            # Anything else: treat ``key`` itself as the serialized hash value.
            hash_value = DynamoType(key)
            range_value = None

        item = table.get_item(hash_value, range_value)
        orig_item = copy.deepcopy(item)

        expected = expected or {}
        if not get_expected(expected).expr(item):
            raise ConditionalCheckFailed

        condition_op = get_filter_expression(
            condition_expression,
            expression_attribute_names,
            expression_attribute_values,
        )
        if not condition_op.expr(item):
            raise ConditionalCheckFailed

        # Update does not fail on new items, so create one
        if item is None:
            data = {table.hash_key_attr: {hash_value.type: hash_value.value}}
            if range_value:
                data[table.range_key_attr] = {
                    range_value.type: range_value.value
                }

            table.put_item(data)
            item = table.get_item(hash_value, range_value)

        if not update_expression:
            item.update_with_attribute_updates(attribute_updates)
        else:
            validator = UpdateExpressionValidator(
                update_expression_ast,
                expression_attribute_names=expression_attribute_names,
                expression_attribute_values=expression_attribute_values,
                item=item,
            )
            validated_ast = validator.validate()
            try:
                UpdateExpressionExecutor(validated_ast, item,
                                         expression_attribute_names).execute()
            except ItemSizeTooLarge:
                raise ItemSizeToUpdateTooLarge()

        shard = table.stream_shard
        if shard is not None:
            shard.add(orig_item, item)
        return item
Ejemplo n.º 11
0
 def update_with_attribute_updates(self, attribute_updates):
     """Apply legacy ``AttributeUpdates`` actions to ``self.attrs`` in place.

     Supported actions are ``PUT``, ``ADD`` (numbers and string sets) and
     ``DELETE`` (whole attribute, or elements of a string set).

     :param attribute_updates: mapping of attribute name to a dict with an
         ``Action`` entry and, usually, a serialized ``Value`` entry. A
         ``None`` or empty mapping is a no-op.
     :raises NotImplementedError: unknown action, or a value type the action
         does not support.
     """
     if not attribute_updates:
         # Nothing to apply; avoids calling ``.items()`` on ``None``.
         return

     for attribute_name, update_action in attribute_updates.items():
         action = update_action["Action"]
         if action == "DELETE" and "Value" not in update_action:
             # DELETE without a value removes the whole attribute.
             if attribute_name in self.attrs:
                 del self.attrs[attribute_name]
             continue

         value_types = set(update_action["Value"].keys())
         new_value = list(update_action["Value"].values())[0]
         if action == "PUT":
             # TODO deal with other types
             if isinstance(new_value, list):
                 self.attrs[attribute_name] = DynamoType({"L": new_value})
             elif isinstance(new_value, set):
                 self.attrs[attribute_name] = DynamoType({"SS": new_value})
             elif isinstance(new_value, dict):
                 self.attrs[attribute_name] = DynamoType({"M": new_value})
             elif value_types == {"N"}:
                 self.attrs[attribute_name] = DynamoType({"N": new_value})
             elif value_types == {"NULL"}:
                 # Putting NULL removes the attribute.
                 if attribute_name in self.attrs:
                     del self.attrs[attribute_name]
             else:
                 self.attrs[attribute_name] = DynamoType({"S": new_value})
         elif action == "ADD":
             if value_types == {"N"}:
                 # Numbers are added to the current value (0 if absent).
                 existing = self.attrs.get(attribute_name,
                                           DynamoType({"N": "0"}))
                 self.attrs[attribute_name] = DynamoType({
                     "N":
                     str(
                         decimal.Decimal(existing.value) +
                         decimal.Decimal(new_value))
                 })
             elif value_types == {"SS"}:
                 # String sets are unioned with the current set.
                 existing = self.attrs.get(attribute_name,
                                           DynamoType({"SS": {}}))
                 new_set = set(existing.value).union(set(new_value))
                 self.attrs[attribute_name] = DynamoType(
                     {"SS": list(new_set)})
             else:
                 # TODO: implement other data types
                 raise NotImplementedError(
                     "ADD not supported for %s" %
                     ", ".join(update_action["Value"].keys()))
         elif action == "DELETE":
             if value_types == {"SS"}:
                 # The given elements are removed from the current set.
                 existing = self.attrs.get(attribute_name,
                                           DynamoType({"SS": {}}))
                 new_set = set(existing.value).difference(set(new_value))
                 self.attrs[attribute_name] = DynamoType(
                     {"SS": list(new_set)})
             else:
                 raise NotImplementedError(
                     "DELETE not supported for %s" %
                     ", ".join(update_action["Value"].keys()))
         else:
             raise NotImplementedError(
                 "%s action not support for update_with_attribute_updates" %
                 action)
Ejemplo n.º 12
0
    def update(
        self, update_expression, expression_attribute_names, expression_attribute_values
    ):
        """Apply a DynamoDB update expression to this item's attributes in place.

        The expression is split into ``SET`` / ``REMOVE`` / ``ADD`` /
        ``DELETE`` clauses and every comma-separated action of each clause is
        applied to ``self.attrs``.

        :param update_expression: raw update expression string.
        :param expression_attribute_names: mapping of name placeholders to
            attribute names; ``None`` is treated as empty.
        :param expression_attribute_values: mapping of value placeholders to
            serialized values; ``None`` is treated as empty.
        :raises AssertionError: operators and operands do not pair up.
        :raises ValueError: SET targets a nested path whose top-level
            attribute does not exist yet.
        :raises TypeError: an ADD/DELETE operand is not a known placeholder,
            or has an unsupported or mismatched type.
        :raises NotImplementedError: unknown action keyword.
        """
        expression_attribute_names = expression_attribute_names or {}
        expression_attribute_values = expression_attribute_values or {}

        # Update subexpressions are identifiable by the operator keyword, so split on that and
        # get rid of the empty leading string.
        parts = [
            p
            for p in re.split(
                r"\b(SET|REMOVE|ADD|DELETE)\b", update_expression, flags=re.I
            )
            if p
        ]
        # make sure that we correctly found only operator/value pairs
        assert (
            len(parts) % 2 == 0
        ), "Mismatched operators and values in update expression: '{}'".format(
            update_expression
        )
        for action, valstr in zip(parts[::2], parts[1::2]):
            action = action.upper()

            # "Should" retain arguments in side (...)
            values = re.split(r",(?![^(]*\))", valstr)
            for value in values:
                # A Real value
                value = value.lstrip(":").rstrip(",").strip()
                for k, v in expression_attribute_names.items():
                    value = re.sub(r"{0}\b".format(k), v, value)

                if action == "REMOVE":
                    key = value
                    attr, list_index = attribute_is_list(key.split(".")[0])
                    if "." not in key:
                        if list_index:
                            new_list = DynamoType(self.attrs[attr])
                            new_list.delete(None, list_index)
                            self.attrs[attr] = new_list
                        else:
                            self.attrs.pop(value, None)
                    else:
                        # Handle nested dict updates
                        self.attrs[attr].delete(".".join(key.split(".")[1:]))
                elif action == "SET":
                    key, value = value.split("=", 1)
                    key = key.strip()
                    value = value.strip()

                    # check whether key is a list
                    attr, list_index = attribute_is_list(key.split(".")[0])
                    # If value not exists, changes value to a default if needed, else its the same as it was
                    value = self._get_default(value)
                    # If operation == list_append, get the original value and append it
                    value = self._get_appended_list(value, expression_attribute_values)

                    if not isinstance(value, DynamoType):
                        if value in expression_attribute_values:
                            dyn_value = DynamoType(expression_attribute_values[value])
                        else:
                            dyn_value = DynamoType({"S": value})
                    else:
                        dyn_value = value

                    if "." in key and attr not in self.attrs:
                        raise ValueError  # Setting nested attr not allowed if first attr does not exist yet
                    elif attr not in self.attrs:
                        self.attrs[attr] = dyn_value  # set new top-level attribute
                    else:
                        self.attrs[attr].set(
                            ".".join(key.split(".")[1:]), dyn_value, list_index
                        )  # set value recursively

                elif action == "ADD":
                    key, value = value.split(" ", 1)
                    key = key.strip()
                    value_str = value.strip()
                    if value_str in expression_attribute_values:
                        # Look up with the stripped placeholder: the raw
                        # ``value`` may still carry extra whitespace.
                        dyn_value = DynamoType(expression_attribute_values[value_str])
                    else:
                        raise TypeError

                    # Handle adding numbers - value gets added to existing value,
                    # or added to 0 if it doesn't exist yet
                    if dyn_value.is_number():
                        existing = self.attrs.get(key, DynamoType({"N": "0"}))
                        if not existing.same_type(dyn_value):
                            raise TypeError()
                        self.attrs[key] = DynamoType(
                            {
                                "N": str(
                                    decimal.Decimal(existing.value)
                                    + decimal.Decimal(dyn_value.value)
                                )
                            }
                        )

                    # Handle adding sets - value is added to the set, or set is
                    # created with only this value if it doesn't exist yet
                    # New value must be of same set type as previous value
                    elif dyn_value.is_set():
                        key_head = key.split(".")[0]
                        key_tail = ".".join(key.split(".")[1:])
                        if key_head not in self.attrs:
                            self.attrs[key_head] = DynamoType({dyn_value.type: {}})
                        existing = self.attrs.get(key_head)
                        existing = existing.get(key_tail)
                        if existing.value and not existing.same_type(dyn_value):
                            raise TypeError()
                        new_set = set(existing.value or []).union(dyn_value.value)
                        existing.set(
                            key=None,
                            new_value=DynamoType({dyn_value.type: list(new_set)}),
                        )
                    else:  # Number and Sets are the only supported types for ADD
                        raise TypeError

                elif action == "DELETE":
                    key, value = value.split(" ", 1)
                    key = key.strip()
                    value_str = value.strip()
                    if value_str in expression_attribute_values:
                        # Same stripped-placeholder lookup as for ADD.
                        dyn_value = DynamoType(expression_attribute_values[value_str])
                    else:
                        raise TypeError

                    if not dyn_value.is_set():
                        raise TypeError
                    key_head = key.split(".")[0]
                    key_tail = ".".join(key.split(".")[1:])
                    existing = self.attrs.get(key_head)
                    if existing is not None:
                        # Only descend when the top-level attribute exists;
                        # deleting from a missing attribute is a no-op.
                        existing = existing.get(key_tail)
                    if existing:
                        if not existing.same_type(dyn_value):
                            raise TypeError
                        new_set = set(existing.value).difference(dyn_value.value)
                        existing.set(
                            key=None,
                            new_value=DynamoType({existing.type: list(new_set)}),
                        )
                else:
                    raise NotImplementedError(
                        "{} update action not yet supported".format(action)
                    )
Ejemplo n.º 13
0
    def update_item(
        self,
        table_name,
        key,
        update_expression,
        attribute_updates,
        expression_attribute_names,
        expression_attribute_values,
        expected=None,
        condition_expression=None,
    ):
        """Update (or create) the item identified by ``key`` and return it.

        The ``expected`` conditions and the condition expression are checked
        against the stored item first. A missing item is created with just
        its key attributes. The update is then applied through
        ``item.update`` when an update expression is given, otherwise through
        ``item.update_with_attribute_updates``.

        :raises ValueError: a condition does not hold.
        """
        table = self.get_table(table_name)

        has_hash_attr = table.hash_key_attr in key
        has_range_attr = table.range_key_attr in key
        if has_hash_attr:
            # ``key`` is a dict holding the hash key and possibly the range key.
            hash_value = DynamoType(key[table.hash_key_attr])
            if has_range_attr:
                range_value = DynamoType(key[table.range_key_attr])
            else:
                range_value = None
        else:
            # Anything else: treat ``key`` itself as the serialized hash value.
            hash_value = DynamoType(key)
            range_value = None

        item = table.get_item(hash_value, range_value)
        orig_item = copy.deepcopy(item)

        expected = expected or {}
        if not get_expected(expected).expr(item):
            raise ValueError("The conditional request failed")

        condition_op = get_filter_expression(
            condition_expression,
            expression_attribute_names,
            expression_attribute_values,
        )
        if not condition_op.expr(item):
            raise ValueError("The conditional request failed")

        # Update does not fail on new items, so create one
        if item is None:
            data = {table.hash_key_attr: {hash_value.type: hash_value.value}}
            if range_value:
                data[table.range_key_attr] = {
                    range_value.type: range_value.value
                }

            table.put_item(data)
            item = table.get_item(hash_value, range_value)

        if not update_expression:
            item.update_with_attribute_updates(attribute_updates)
        else:
            item.update(
                update_expression,
                expression_attribute_names,
                expression_attribute_values,
            )

        shard = table.stream_shard
        if shard is not None:
            shard.add(orig_item, item)
        return item