示例#1
0
class DynamodbClientIntegrationTestCase(unittest.TestCase):
    """Integration tests for ``DynamoDbClient`` that use the ``autotest_dynamo_db`` table."""

    # Configuration passed to ``DynamoDbClient`` in ``setUp``.
    TEST_CONFIG = dict(
        # Attribute name -> DynamoDB type descriptor ('S', 'N', 'BOOL', 'M').
        row_mapper=dict(
            lambda_name='S',
            invocation_id='S',
            en_time='N',
            hash_col='S',
            range_col='N',
            other_col='S',
            new_col='S',
            some_col='S',
            some_counter='N',
            some_bool='BOOL',
            some_map='M',
        ),
        required_fields=['lambda_name'],
        table_name='autotest_dynamo_db',
        hash_key='hash_col',
    )

    @classmethod
    def setUpClass(cls):
        """Call ``clean_dynamo_table`` with its default arguments once, before the class's tests run."""
        # NOTE(review): presumably the defaults target the same table the tests use - confirm.
        clean_dynamo_table()

    def setUp(self):
        """Define key descriptors and the table name, and build a client for each test.

        ``HASH_KEY`` and ``RANGE_KEY`` are ``(attribute name, DynamoDB type)`` pairs;
        ``KEYS`` holds just the two attribute names and is what ``tearDown`` passes
        to ``clean_dynamo_table``.
        """
        self.HASH_COL = 'hash_col'
        self.HASH_KEY = (self.HASH_COL, 'S')

        self.RANGE_COL = 'range_col'
        self.RANGE_COL_TYPE = 'N'
        # Pair the column name with its type, mirroring HASH_KEY
        # (the name used to be paired with itself by mistake).
        self.RANGE_KEY = (self.RANGE_COL, self.RANGE_COL_TYPE)

        self.KEYS = (self.HASH_COL, self.RANGE_COL)
        self.table_name = 'autotest_dynamo_db'
        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)

    def tearDown(self):
        """Run ``clean_dynamo_table`` for the table and key names the finished test used."""
        table, key_names = self.table_name, self.KEYS
        clean_dynamo_table(table, key_names)

    def test_put(self):
        """Write a row with ``put`` and compare the raw item read back through boto3.

        The row mixes a string, a numeric string, a boolean and a nested map; the
        expectation is spelled out in DynamoDB's typed attribute-value format.
        """
        item = {
            self.HASH_COL: 'cat',
            self.RANGE_COL: '123',
            'some_bool': True,
            'some_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
        }

        raw_client = boto3.client('dynamodb')

        # Remove any row with the same primary key before writing.
        primary_key = {
            self.HASH_COL: {'S': str(item[self.HASH_COL])},
            self.RANGE_COL: {self.RANGE_COL_TYPE: str(item[self.RANGE_COL])},
        }
        raw_client.delete_item(TableName=self.table_name, Key=primary_key)

        self.dynamo_client.put(item, self.table_name)

        # Read the row back with a raw scan so the stored representation is checked as-is.
        placeholders = {
            ':hash_col': {'S': item[self.HASH_COL]},
            ':range_col': {self.RANGE_COL_TYPE: str(item[self.RANGE_COL])},
        }
        response = raw_client.scan(
            TableName=self.table_name,
            FilterExpression="hash_col = :hash_col AND range_col = :range_col",
            ExpressionAttributeValues=placeholders,
        )
        stored_items = response['Items']

        self.assertEqual(1, len(stored_items))

        expected_map = {
            'a': {'N': '1'},
            'b': {'S': 'b1'},
            'c': {'M': {'test': {'BOOL': True}}},
        }
        expected = [{
            'hash_col': {'S': 'cat'},
            'range_col': {'N': '123'},
            'some_bool': {'BOOL': True},
            'some_map': {'M': expected_map},
        }]
        self.assertEqual(expected, stored_items)

    def test_put__create(self):
        """A repeated ``put`` with ``overwrite_existing=False`` raises ConditionalCheckFailedException."""
        duplicate = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}

        # First write succeeds and creates the row.
        self.dynamo_client.put(duplicate, self.table_name)

        conflict_error = self.dynamo_client.dynamo_client.exceptions.ConditionalCheckFailedException
        self.assertRaises(conflict_error,
                          self.dynamo_client.put,
                          duplicate,
                          self.table_name,
                          overwrite_existing=False)

    def test_update__updates(self):
        """``update`` changes an existing field, adds a new one and leaves the rest intact."""
        keys = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}
        row = {
            self.HASH_COL: 'cat',
            self.RANGE_COL: '123',
            'some_col': 'no',
            'other_col': 'foo'
        }
        attributes_to_update = {'some_col': 'yes', 'new_col': 'yup'}

        self.dynamo_client.put(row, self.table_name)

        client = boto3.client('dynamodb')

        def fetch_row():
            """Read the row through raw boto3 and convert it to a plain dict."""
            item = client.get_item(
                Key={
                    self.HASH_COL: {
                        'S': row[self.HASH_COL]
                    },
                    self.RANGE_COL: {
                        self.RANGE_COL_TYPE: str(row[self.RANGE_COL])
                    }
                },
                TableName=self.table_name,
            )['Item']
            return self.dynamo_client.dynamo_to_dict(item)

        # First check that the row we are trying to update is PUT correctly.
        initial_row = fetch_row()

        self.assertIsNotNone(initial_row)
        self.assertEqual(initial_row['some_col'], 'no')
        self.assertEqual(initial_row['other_col'], 'foo')

        self.dynamo_client.update(keys,
                                  attributes_to_update,
                                  table_name=self.table_name)

        updated_row = fetch_row()

        self.assertIsNotNone(updated_row)
        # Messages go through `msg=` so they are reported on failure; a trailing
        # `, "text"` after the call only builds a throwaway tuple.
        self.assertEqual(updated_row['some_col'], 'yes',
                         msg="Updated field not really updated")
        self.assertEqual(updated_row['new_col'], 'yup',
                         msg="New field was not created")
        self.assertEqual(updated_row['other_col'], 'foo',
                         msg="This field should be preserved, update() damaged it")

    def test_update__increment(self):
        """Incrementing an existing counter of 10 by the string ``'1'`` yields 11."""
        primary = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}
        stored = dict(primary, some_col='no', some_counter=10)

        self.dynamo_client.put(stored, self.table_name)

        # Nothing to overwrite; only the counter is incremented.
        self.dynamo_client.update(
            primary,
            {},
            attributes_to_increment={'some_counter': '1'},
            table_name=self.table_name,
        )

        raw_client = boto3.client('dynamodb')
        raw_item = raw_client.get_item(
            TableName=self.table_name,
            Key={
                self.HASH_COL: {'S': stored[self.HASH_COL]},
                self.RANGE_COL: {self.RANGE_COL_TYPE: str(stored[self.RANGE_COL])},
            },
        )['Item']
        after = self.dynamo_client.dynamo_to_dict(raw_item)

        self.assertIsNotNone(after)
        self.assertEqual(after['some_counter'], 11)

    def test_update__increment_2(self):
        """Incrementing an existing counter of 10 by the integer ``5`` yields 15."""
        primary = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}
        stored = dict(primary, some_col='no', some_counter=10)
        increment_by = {'some_counter': 5}

        self.dynamo_client.put(stored, self.table_name)
        self.dynamo_client.update(primary, {},
                                  attributes_to_increment=increment_by,
                                  table_name=self.table_name)

        # Verify through raw boto3 rather than through the client under test.
        lookup_key = {
            self.HASH_COL: {'S': stored[self.HASH_COL]},
            self.RANGE_COL: {self.RANGE_COL_TYPE: str(stored[self.RANGE_COL])},
        }
        response = boto3.client('dynamodb').get_item(Key=lookup_key,
                                                     TableName=self.table_name)
        after = self.dynamo_client.dynamo_to_dict(response['Item'])

        self.assertIsNotNone(after)
        self.assertEqual(after['some_counter'], 15)

    def test_update__increment_no_default(self):
        """Incrementing a counter the row does not have yet leaves it equal to the increment (3)."""
        primary = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}
        # The stored row deliberately has no 'some_counter' attribute.
        stored = dict(primary, some_col='no')

        self.dynamo_client.put(stored, self.table_name)
        self.dynamo_client.update(
            primary,
            {},
            attributes_to_increment={'some_counter': '3'},
            table_name=self.table_name,
        )

        raw_client = boto3.client('dynamodb')
        hash_part = {'S': stored[self.HASH_COL]}
        range_part = {self.RANGE_COL_TYPE: str(stored[self.RANGE_COL])}
        raw_item = raw_client.get_item(
            Key={self.HASH_COL: hash_part, self.RANGE_COL: range_part},
            TableName=self.table_name,
        )['Item']
        after = self.dynamo_client.dynamo_to_dict(raw_item)

        self.assertIsNotNone(after)
        self.assertEqual(after['some_counter'], 3)

    def test_update__condition_expression(self):
        """``update`` applies the change only when ``condition_expression`` matches the stored row."""
        primary = {self.HASH_COL: 'slime', self.RANGE_COL: '41'}
        stored = dict(primary, some_col='no')

        self.dynamo_client.put(stored, self.table_name)

        # The stored value is 'no', so a condition on 'yes' must be rejected.
        with self.assertRaises(self.dynamo_client.dynamo_client.exceptions.
                               ConditionalCheckFailedException):
            self.dynamo_client.update(primary, {},
                                      attributes_to_increment={'some_counter': '3'},
                                      condition_expression='some_col = yes',
                                      table_name=self.table_name)

        # The matching condition lets the increment through.
        self.dynamo_client.update(primary, {},
                                  attributes_to_increment={'some_counter': '3'},
                                  condition_expression='some_col = no',
                                  table_name=self.table_name)

        raw_client = boto3.client('dynamodb')
        lookup_key = {
            self.HASH_COL: {'S': stored[self.HASH_COL]},
            self.RANGE_COL: {self.RANGE_COL_TYPE: str(stored[self.RANGE_COL])},
        }
        raw_item = raw_client.get_item(Key=lookup_key, TableName=self.table_name)['Item']

        after = self.dynamo_client.dynamo_to_dict(raw_item)
        self.assertEqual(after['some_counter'], 3)

    def test_patch(self):
        """``patch`` raises for a missing row and updates the field once the row exists."""
        primary = {self.HASH_COL: 'slime', self.RANGE_COL: '41'}
        stored = dict(primary, some_col='no')
        change = {'some_col': 'yes'}

        # Nothing has been written yet, so patching must fail.
        with self.assertRaises(self.dynamo_client.dynamo_client.exceptions.
                               ConditionalCheckFailedException):
            self.dynamo_client.patch(primary,
                                     attributes_to_update={'some_col': 'yes'},
                                     table_name=self.table_name)

        # After the row is created the same patch goes through.
        self.dynamo_client.put(stored, self.table_name)
        self.dynamo_client.patch(primary,
                                 attributes_to_update=change,
                                 table_name=self.table_name)

        raw_client = boto3.client('dynamodb')
        raw_item = raw_client.get_item(
            TableName=self.table_name,
            Key={
                self.HASH_COL: {'S': stored[self.HASH_COL]},
                self.RANGE_COL: {self.RANGE_COL_TYPE: str(stored[self.RANGE_COL])},
            },
        )['Item']

        after = self.dynamo_client.dynamo_to_dict(raw_item)
        self.assertEqual(after['some_col'], 'yes')

    def test_get_by_query__primary_index(self):
        """Querying by hash and range key returns exactly the one stored row."""
        stored = {
            self.HASH_COL: 'cat',
            self.RANGE_COL: 123,
            'some_col': 'test',
            'some_bool': True,
        }
        self.dynamo_client.put(stored, self.table_name)

        found = self.dynamo_client.get_by_query(
            keys={self.HASH_COL: 'cat', self.RANGE_COL: '123'})

        self.assertEqual(len(found), 1)
        fetched = found[0]

        # Compare in both directions: nothing missing and nothing extra.
        for name, value in stored.items():
            self.assertEqual(value, fetched[name])
        for name, value in fetched.items():
            self.assertEqual(stored[name], value)

    def test_get_by_query__primary_index__gets_multiple(self):
        """Querying by hash key alone returns every row stored under that hash key."""
        first = {self.HASH_COL: 'cat', self.RANGE_COL: 123, 'some_col': 'test'}
        self.dynamo_client.put(first, self.table_name)

        second = {self.HASH_COL: 'cat', self.RANGE_COL: 1234, 'some_col': 'test2'}
        self.dynamo_client.put(second, self.table_name)

        found = self.dynamo_client.get_by_query(keys={self.HASH_COL: 'cat'})

        self.assertEqual(len(found), 2)

        # Match each stored row to its counterpart by range key, then compare both ways.
        for stored in (first, second):
            matches = list(
                filter(lambda r: r[self.RANGE_COL] == stored[self.RANGE_COL], found))
            fetched = matches[0]
            for name in stored:
                self.assertEqual(stored[name], fetched[name])
            for name in fetched:
                self.assertEqual(stored[name], fetched[name])

    def test_get_by_query__secondary_index(self):
        """Querying through the ``autotest_index`` index returns the stored row."""
        stored = {
            self.HASH_COL: 'cat',
            self.RANGE_COL: 123,
            'other_col': 'abc123',
        }
        self.dynamo_client.put(stored, self.table_name)

        # The query uses 'other_col' instead of the table's own range key.
        index_keys = {self.HASH_COL: 'cat', 'other_col': 'abc123'}
        found = self.dynamo_client.get_by_query(keys=index_keys,
                                                index_name='autotest_index')

        self.assertEqual(len(found), 1)
        fetched = found[0]
        for name, value in stored.items():
            self.assertEqual(value, fetched[name])
        for name, value in fetched.items():
            self.assertEqual(stored[name], value)

    def test_get_by_query__comparison(self):
        """A ``<=`` comparison on the range key returns only the row at or below the bound."""
        below = {self.HASH_COL: 'cat', self.RANGE_COL: 123, 'other_col': 'abc123'}
        above = {self.HASH_COL: 'cat', self.RANGE_COL: 456, 'other_col': 'abc123'}
        for stored in (below, above):
            self.dynamo_client.put(stored, self.table_name)

        # Bound of 300 sits between the two stored range values.
        bound = {self.HASH_COL: 'cat', self.RANGE_COL: '300'}
        found = self.dynamo_client.get_by_query(
            keys=bound, comparisons={self.RANGE_COL: '<='})

        self.assertEqual(len(found), 1)

        only_row = found[0]
        self.assertEqual(only_row, below)

    def test_get_by_query__comparison_between(self):
        """Query a range of rows with ``between``, first explicitly and then without ``comparisons``.

        The bounds are passed through the ``st_between_range_col`` and
        ``en_between_range_col`` keys; every returned row must have a range value of 3-6.
        """
        # Put sample data: ten rows with range values 0..9 under one hash key.
        for range_value in range(10):
            self.dynamo_client.put({
                self.HASH_COL: 'cat',
                self.RANGE_COL: range_value
            }, self.table_name)

        keys = {
            self.HASH_COL: 'cat',
            'st_between_range_col': '3',
            'en_between_range_col': '6'
        }
        result = self.dynamo_client.get_by_query(
            keys=keys, comparisons={self.RANGE_COL: 'between'})
        self.assertTrue(all(x[self.RANGE_COL] in range(3, 7) for x in result))

        result = self.dynamo_client.get_by_query(keys=keys)
        # The message goes through `msg=`; as a trailing tuple element it was never reported.
        self.assertTrue(
            all(x[self.RANGE_COL] in range(3, 7) for x in result),
            msg="Failed if unspecified comparison. "
                "Should be automatic for :st_between_...")

    def test_get_by_query__filter_expression(self):
        """
        This _integration_ test runs multiple checks with same sample data for several comparators.
        Have a look at the manual if required:
        https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html
        """

        # Put sample data: rows 0-2 without `mark`, rows 3-5 with `mark = 1`,
        # row 6 with `mark = 0` and row 7 with `mark = 'a'`.
        for x in range(3):
            self.dynamo_client.put({
                self.HASH_COL: 'cat',
                self.RANGE_COL: x
            }, self.table_name)
        for x in range(3, 6):
            self.dynamo_client.put(
                {
                    self.HASH_COL: 'cat',
                    self.RANGE_COL: x,
                    'mark': 1
                }, self.table_name)
        self.dynamo_client.put(
            {
                self.HASH_COL: 'cat',
                self.RANGE_COL: 6,
                'mark': 0
            }, self.table_name)
        self.dynamo_client.put(
            {
                self.HASH_COL: 'cat',
                self.RANGE_COL: 7,
                'mark': 'a'
            }, self.table_name)

        # The condition on range_col matches five of the eight rows: 0 - 4.
        # The filter expression drops the first three of them because they don't have `mark = 1`.
        keys = {self.HASH_COL: 'cat', self.RANGE_COL: 4}
        result = self.dynamo_client.get_by_query(
            keys=keys,
            comparisons={self.RANGE_COL: '<='},
            fetch_all_fields=True,
            filter_expression='mark = 1')

        self.assertEqual(len(result), 2)
        self.assertEqual(result[0], {
            self.HASH_COL: 'cat',
            self.RANGE_COL: 3,
            'mark': 1
        })
        self.assertEqual(result[1], {
            self.HASH_COL: 'cat',
            self.RANGE_COL: 4,
            'mark': 1
        })

        # In the same test we check also some comparator _functions_.
        result = self.dynamo_client.get_by_query(
            keys=keys,
            comparisons={self.RANGE_COL: '<='},
            fetch_all_fields=True,
            filter_expression='attribute_exists mark')
        self.assertEqual(len(result), 2)
        self.assertEqual([x[self.RANGE_COL] for x in result], list(range(3, 5)))

        self.assertEqual(result[0], {
            self.HASH_COL: 'cat',
            self.RANGE_COL: 3,
            'mark': 1
        })
        self.assertEqual(result[1], {
            self.HASH_COL: 'cat',
            self.RANGE_COL: 4,
            'mark': 1
        })

        result = self.dynamo_client.get_by_query(
            keys=keys,
            comparisons={self.RANGE_COL: '<='},
            fetch_all_fields=True,
            filter_expression='attribute_not_exists mark')
        self.assertEqual(len(result), 3)
        self.assertEqual([x[self.RANGE_COL] for x in result], list(range(3)))

    def test_get_by_query__comparison_begins_with(self):
        """A ``begins_with`` comparison on a string range key returns only rows with that prefix.

        The test switches to the ``autotest_config_component`` table and overrides
        ``table_name`` / ``KEYS`` on the instance so ``tearDown`` cleans that table.
        """
        self.table_name = 'autotest_config_component'  # This table has a string range key
        self.HASH_KEY = ('env', 'S')
        self.RANGE_KEY = ('config_name', 'S')
        self.KEYS = ('env', 'config_name')
        config = {
            'row_mapper': {
                'env': 'S',
                'config_name': 'S',
                'config_value': 'S'
            },
            'required_fields': ['env', 'config_name', 'config_value'],
            'table_name': 'autotest_config_component',
            # Use this table's hash attribute ('env'); self.HASH_COL still holds
            # 'hash_col', which belongs to the default test table.
            'hash_key': self.HASH_KEY[0]
        }

        self.dynamo_client = DynamoDbClient(config=config)

        row1 = {
            'env': 'cat',
            'config_name': 'testzing',
            'config_value': 'abc123'
        }
        row2 = {
            'env': 'cat',
            'config_name': 'dont_get_this',
            'config_value': 'abc123'
        }
        row3 = {
            'env': 'cat',
            'config_name': 'testzer',
            'config_value': 'abc124'
        }
        self.dynamo_client.put(row1, self.table_name)
        self.dynamo_client.put(row2, self.table_name)
        self.dynamo_client.put(row3, self.table_name)

        keys = {'env': 'cat', 'config_name': 'testz'}
        result = self.dynamo_client.get_by_query(
            keys=keys,
            table_name=self.table_name,
            comparisons={'config_name': 'begins_with'})

        self.assertEqual(len(result), 2)

        # assertIn reports the missing row and the actual result on failure.
        self.assertIn(row1, result)
        self.assertIn(row3, result)

    def test_get_by_query__max_items(self):
        """``max_items`` caps the number of returned rows; without it all rows come back.

        The limited query is also timed and must finish in under 0.1 seconds.
        """
        # This function can also be used for some benchmarking, just change to bigger amounts manually.
        INITIAL_TASKS = 5  # Change to 500 to run benchmarking, and uncomment raise at the end of the test.

        for x in range(1000, 1000 + INITIAL_TASKS):
            # Plain literal: the former f-string had no placeholders.
            row = {self.HASH_COL: 'key', self.RANGE_COL: x}
            self.dynamo_client.put(row, self.table_name)
            if INITIAL_TASKS > 10:
                time.sleep(
                    0.1
                )  # Sleep a little to fit the Write Capacity (10 WCU) of autotest table.

        n = 3
        st = time.perf_counter()
        result = self.dynamo_client.get_by_query({self.HASH_COL: 'key'},
                                                 table_name=self.table_name,
                                                 max_items=n)
        bm = time.perf_counter() - st
        logging.info(f"Benchmark (n={n}): {bm}")

        self.assertEqual(len(result), n)
        self.assertLess(bm, 0.1)

        # Check unspecified limit.
        result = self.dynamo_client.get_by_query({self.HASH_COL: 'key'},
                                                 table_name=self.table_name)
        self.assertEqual(len(result), INITIAL_TASKS)

    def test_get_by_query__return_count(self):
        """With ``return_count=True`` the query returns the number of matching rows (3)."""
        rows = [
            {self.HASH_COL: 'cat1', self.RANGE_COL: 121, 'some_col': 'test1'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 122, 'some_col': 'test2'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 123, 'some_col': 'test3'},
        ]
        for stored in rows:
            self.dynamo_client.put(stored, table_name=self.table_name)

        count = self.dynamo_client.get_by_query(
            {self.HASH_COL: 'cat1'},
            table_name=self.table_name,
            return_count=True,
        )

        self.assertEqual(count, 3)

    def test_get_by_query__reverse(self):
        """With ``desc=True`` the first returned row is the one stored with the highest range key."""
        rows = [
            {self.HASH_COL: 'cat1', self.RANGE_COL: 121, 'some_col': 'test1'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 122, 'some_col': 'test2'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 123, 'some_col': 'test3'},
        ]
        for stored in rows:
            self.dynamo_client.put(stored, table_name=self.table_name)

        found = self.dynamo_client.get_by_query(
            {self.HASH_COL: 'cat1'},
            table_name=self.table_name,
            desc=True,
        )

        last_stored = rows[-1]
        self.assertEqual(found[0], last_stored)

    def test_get_by_scan__all(self):
        """A scan without arguments returns every row written to the table."""
        rows = [{
            self.HASH_COL: 'cat1',
            self.RANGE_COL: 121,
            'some_col': 'test1'
        }, {
            self.HASH_COL: 'cat2',
            self.RANGE_COL: 122,
            'some_col': 'test2'
        }, {
            self.HASH_COL: 'cat3',
            self.RANGE_COL: 123,
            'some_col': 'test3'
        }]
        for x in rows:
            self.dynamo_client.put(x, self.table_name)

        result = self.dynamo_client.get_by_scan()

        self.assertEqual(len(result), 3)

        for r in rows:
            # assertIn instead of a bare assert, which is stripped under `python -O`.
            self.assertIn(r, result,
                          msg=f"row not in result from dynamo scan: {r}")

    def test_get_by_scan__with_filter(self):
        """A scan with ``attrs`` returns only the rows whose attributes match the given values."""
        rows = [
            {
                self.HASH_COL: 'cat1',
                self.RANGE_COL: 121,
                'some_col': 'test1'
            },
            {
                self.HASH_COL: 'cat1',
                self.RANGE_COL: 122,
                'some_col': 'test2'
            },
            {
                self.HASH_COL: 'cat2',
                self.RANGE_COL: 122,
                'some_col': 'test2'
            },
        ]
        for x in rows:
            self.dynamo_client.put(x, self.table_name)

        # Named `scan_filter` so the builtin `filter` is not shadowed.
        scan_filter = {'some_col': 'test2'}

        result = self.dynamo_client.get_by_scan(attrs=scan_filter)

        self.assertEqual(len(result), 2)

        # Only the last two rows carry some_col == 'test2'.
        for r in rows[1:]:
            # assertIn instead of a bare assert, which is stripped under `python -O`.
            self.assertIn(r, result,
                          msg=f"row not in result from dynamo scan: {r}")

    def test_batch_get_items(self):
        """``batch_get_items_one_table`` returns the rows that exist and omits the unknown key."""
        rows = [
            {self.HASH_COL: 'cat1', self.RANGE_COL: 121, 'some_col': 'test1'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 122, 'some_col': 'test2'},
            {self.HASH_COL: 'cat2', self.RANGE_COL: 122, 'some_col': 'test2'},
        ]
        for stored in rows:
            self.dynamo_client.put(stored, self.table_name)

        # Two keys that exist and one that was never written.
        requested_keys = [
            {self.HASH_COL: 'cat1', self.RANGE_COL: 121},
            {self.HASH_COL: 'doesnt_exist', self.RANGE_COL: 40},
            {self.HASH_COL: 'cat2', self.RANGE_COL: 122},
        ]

        found = self.dynamo_client.batch_get_items_one_table(requested_keys)

        self.assertEqual(len(found), 2)

        first_row, _, third_row = rows
        self.assertIn(first_row, found)
        self.assertIn(third_row, found)

    def test_delete(self):
        """``delete`` removes the addressed row and leaves the other row in the table."""
        # No table name is passed to put/delete/scan in this test.
        for hash_value, range_value in (('cat1', 123), ('cat2', 234)):
            self.dynamo_client.put({
                self.HASH_COL: hash_value,
                self.RANGE_COL: range_value
            })

        doomed_key = {self.HASH_COL: 'cat1', self.RANGE_COL: '123'}
        self.dynamo_client.delete(keys=doomed_key)

        remaining = self.dynamo_client.get_by_scan()

        self.assertEqual(len(remaining), 1)
        survivor = {self.HASH_COL: 'cat2', self.RANGE_COL: 234}
        self.assertEqual(remaining[0], survivor)

    def test_get_table_keys(self):
        """``get_table_keys`` returns the same key names with and without an explicit table name."""
        expected_keys = ('hash_col', 'range_col')

        implicit = self.dynamo_client.get_table_keys()
        self.assertEqual(implicit, expected_keys)

        explicit = self.dynamo_client.get_table_keys(table_name=self.table_name)
        self.assertEqual(explicit, expected_keys)

    def test_get_table_indexes(self):
        """``get_table_indexes`` describes the single ``autotest_index`` index of the table."""
        throughput = {'write_capacity': 1, 'read_capacity': 1}
        expected = {
            'autotest_index': {
                'projection_type': 'ALL',
                'hash_key': 'hash_col',
                'range_key': 'other_col',
                'provisioned_throughput': throughput,
            },
        }

        actual = self.dynamo_client.get_table_indexes()

        self.assertDictEqual(expected, actual)

    def test_batch_get_items_one_table(self):
        """Write items transactionally, then batch-get a slice of them and check the count."""
        # If you want to stress test batch_get_items_one_table, use bigger numbers
        num_of_items = 5
        query_from, query_till = 2, 4

        # Build the items and one transactional Put operation per item.
        query_keys = [
            {self.HASH_COL: f'cat{i%2}', self.RANGE_COL: i}
            for i in range(num_of_items)
        ]
        operations = [
            {'Put': self.dynamo_client.build_put_query(item)}
            for item in query_keys
        ]

        # Write in chunks of 10 operations.
        for batch in chunks(operations, 10):
            self.dynamo_client.dynamo_client.transact_write_items(
                TransactItems=batch)
            time.sleep(1)  # cause the table has 10 write/sec capacity

        # Batch get items
        wanted = query_keys[query_from:query_till]
        results = self.dynamo_client.batch_get_items_one_table(keys_list=wanted)
        self.assertEqual(query_till - query_from, len(results))
示例#2
0
class dynamodb_client_UnitTestCase(unittest.TestCase):
    TEST_CONFIG = {
        'row_mapper': {
            'lambda_name': 'S',
            'invocation_id': 'S',
            'en_time': 'N',
            'hash_col': 'S',
            'range_col': 'N',
            'other_col': 'S',
            'new_col': 'S',
            'some_col': 'S',
            'some_counter': 'N',
            'some_bool': 'BOOL',
            'some_bool2': 'BOOL',
            'some_map': 'M',
            'some_list': 'L'
        },
        'required_fields': ['lambda_name'],
        'table_name': 'autotest_dynamo_db',
        'hash_key': 'hash_col',
    }

    def setUp(self):
        """Patch ``boto3.client`` with mocks and build the client under test."""
        self.table_name = 'autotest_dynamo_db'
        self.KEYS = ('hash_col', 'range_col')
        self.HASH_KEY = ('hash_col', 'S')
        self.RANGE_KEY = ('range_col', 'N')

        # Mock chain: boto3.client(...) -> dynamo_mock,
        # dynamo_mock.get_paginator(...) -> paginator_mock.
        self.paginator_mock = MagicMock()
        self.dynamo_mock = MagicMock()
        self.dynamo_mock.get_paginator.return_value = self.paginator_mock

        self.patcher = patch("boto3.client")
        self.boto3_client_patch = self.patcher.start()
        self.boto3_client_patch.return_value = self.dynamo_mock

        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)

    def tearDown(self):
        """Stop the ``boto3.client`` patcher that ``setUp`` started."""
        self.patcher.stop()

    def test_create__raises__if_no_hash_col_configured(self):
        """``create`` raises AssertionError when the config has no ``hash_key``."""
        bad_config = deepcopy(self.TEST_CONFIG)
        del bad_config['hash_key']

        dynamo_client = DynamoDbClient(config=bad_config)

        # Key the row by attribute names. self.HASH_KEY / self.RANGE_KEY are
        # (name, type) tuples, so using them as dict keys builds a row whose
        # keys are tuples rather than column names.
        hash_col, range_col = self.KEYS
        row = {hash_col: 'cat', range_col: '123'}
        self.assertRaises(AssertionError, dynamo_client.create, row,
                          self.table_name)

    def test_create__calls_boto_client(self):
        """``put`` results in exactly one ``put_item`` call on the mocked boto3 client."""
        self.dynamo_mock.put_item.assert_not_called()

        # Key the row by attribute names. self.HASH_KEY / self.RANGE_KEY are
        # (name, type) tuples, so using them as dict keys builds a row whose
        # keys are tuples rather than column names.
        hash_col, range_col = self.KEYS
        row = {hash_col: 'cat', range_col: '123'}

        self.dynamo_client.put(row, self.table_name)
        self.dynamo_mock.put_item.assert_called_once()

    def test_dict_to_dynamo_strict(self):
        """``dict_to_dynamo`` with default arguments yields the expected typed attributes."""
        source_row = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456,
            'some_bool': True,
            'some_bool2': 'True',
            'some_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
            'some_list': ['x', 'y'],
        }

        converted = self.dynamo_client.dict_to_dynamo(source_row)

        expected_map = {
            'a': {'N': '1'},
            'b': {'S': 'b1'},
            'c': {'M': {'test': {'BOOL': True}}},
        }
        expected = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'some_bool': {'BOOL': True},
            # The string 'True' is expected to come back as a real boolean.
            'some_bool2': {'BOOL': True},
            'some_map': {'M': expected_map},
            'some_list': {'L': [{'S': 'x'}, {'S': 'y'}]},
        }
        for attr, expected_value in expected.items():
            self.assertDictEqual(expected_value, converted[attr])

    def test_dict_to_dynamo_not_strict(self):
        """``dict_to_dynamo(..., strict=False)`` yields the expected typed attributes."""
        source_row = {
            'name': 'cat',
            'age': 3,
            'other_bool': False,
            'other_bool2': 'False',
            'other_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
            'some_list': ['x', 'y'],
        }

        converted = self.dynamo_client.dict_to_dynamo(source_row, strict=False)

        # 'other_bool2' is not listed below, so its conversion is not checked.
        expected_map = {
            'a': {'N': '1'},
            'b': {'S': 'b1'},
            'c': {'M': {'test': {'BOOL': True}}},
        }
        expected = {
            'name': {'S': 'cat'},
            'age': {'N': '3'},
            'other_bool': {'BOOL': False},
            'other_map': {'M': expected_map},
            'some_list': {'L': [{'S': 'x'}, {'S': 'y'}]},
        }
        for attr, expected_value in expected.items():
            self.assertDictEqual(expected_value, converted[attr])

    def test_dict_to_dynamo__not_strict__map_type(self):
        """A dict value is converted to an 'M' attribute in non-strict mode.

        With an empty ``expected`` dict the comparison loop never executes and
        nothing is verified, so the structure of the converted map is checked
        explicitly here.
        """
        mimetypes = {
            'image/webp': 1,
            'image/apng': 1,
            'image/*': 1,
            '*/*': 0.8
        }
        dict_row = {'accept_mimetypes': mimetypes}

        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row, strict=False)
        logging.info(f"dynamo_row: {dynamo_row}")

        # The whole value must be wrapped as a single map attribute.
        self.assertIn('accept_mimetypes', dynamo_row)
        self.assertEqual(['M'], list(dynamo_row['accept_mimetypes']))

        # Every source key must survive the conversion.
        converted = dynamo_row['accept_mimetypes']['M']
        self.assertCountEqual(mimetypes.keys(), converted.keys())

        # Integers inside a map become stringified numbers.
        for mimetype in ('image/webp', 'image/apng', 'image/*'):
            self.assertDictEqual({'N': '1'}, converted[mimetype])

        # NOTE(review): assumes floats are typed 'N' as well - confirm.
        # Compared numerically so the check doesn't depend on string formatting.
        self.assertEqual(['N'], list(converted['*/*']))
        self.assertEqual(0.8, float(converted['*/*']['N']))

    def test_dict_to_dynamo_prefix(self):
        """``add_prefix`` is prepended to every attribute name in the result."""
        source_row = {'hash_col': 'cat', 'range_col': '123', 'some_col': 'no'}

        converted = self.dynamo_client.dict_to_dynamo(source_row,
                                                      add_prefix="#")

        expected = {
            '#hash_col': {'S': 'cat'},
            '#range_col': {'N': '123'},
            '#some_col': {'S': 'no'},
        }
        for attr, expected_value in expected.items():
            self.assertDictEqual(expected_value, converted[attr])

    def test_dynamo_to_dict(self):
        """``dynamo_to_dict`` returns plain values, without 'extra_key' and without Decimals."""
        typed_map = {
            'a': {'N': '1'},
            'b': {'S': 'b1'},
            'c': {'M': {'test': {'BOOL': True}}},
        }
        typed_row = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'extra_key': {'N': '42'},
            'some_bool': {'BOOL': False},
            'some_map': {'M': typed_map},
            'some_list': {'L': [{'S': 'x'}, {'S': 'y'}]},
        }

        plain_row = self.dynamo_client.dynamo_to_dict(typed_row)

        # 'extra_key' is absent from the expectation, so it must not be returned.
        expected = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456,
            'some_bool': False,
            'some_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
            'some_list': ['x', 'y'],
        }
        self.assertDictEqual(expected, plain_row)

        # Numbers must not be Decimal at the top level or inside the map.
        for value in plain_row.values():
            self.assertNotIsInstance(value, Decimal)
        for value in plain_row['some_map'].values():
            self.assertNotIsInstance(value, Decimal)

    def test_dynamo_to_dict_no_strict_row_mapper(self):
        """With ``fetch_all_fields=True`` every attribute of the row is returned."""
        typed_row = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'extra_key_n': {'N': '42'},
            'extra_key_s': {'S': 'wowie'},
            'other_bool': {'BOOL': True},
        }

        plain_row = self.dynamo_client.dynamo_to_dict(typed_row,
                                                      fetch_all_fields=True)

        expected = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456,
            'extra_key_n': 42,
            'extra_key_s': 'wowie',
            'other_bool': True,
        }
        self.assertDictEqual(plain_row, expected)

        # Numbers must not be Decimal.
        for value in plain_row.values():
            self.assertNotIsInstance(value, Decimal)

    def test_dynamo_to_dict__dont_json_loads(self):
        """With ``dont_json_loads_results=True`` JSON-looking strings stay strings."""
        config = {**self.TEST_CONFIG, 'dont_json_loads_results': True}
        self.dynamo_client = DynamoDbClient(config=config)

        typed_row = {
            'hash_col': {'S': 'aaa'},
            'range_col': {'N': '123'},
            'other_col': {'S': '{"how many": 300}'},
            'duck_quack': {'S': '{"quack": "duck"}'},
        }

        # All fields requested: 'duck_quack' is included.
        all_fields = self.dynamo_client.dynamo_to_dict(typed_row,
                                                       fetch_all_fields=True)
        self.assertDictEqual(all_fields, {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': '{"how many": 300}',
            'duck_quack': '{"quack": "duck"}',
        })

        # Not all fields requested: 'duck_quack' is expected to be left out.
        known_fields = self.dynamo_client.dynamo_to_dict(
            typed_row, fetch_all_fields=False)
        self.assertDictEqual(known_fields, {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': '{"how many": 300}',
        })

    def test_dynamo_to_dict__do_json_loads(self):
        """With ``dont_json_loads_results=False`` JSON-looking strings are parsed."""
        config = {**self.TEST_CONFIG, 'dont_json_loads_results': False}
        self.dynamo_client = DynamoDbClient(config=config)

        typed_row = {
            'hash_col': {'S': 'aaa'},
            'range_col': {'N': '123'},
            'other_col': {'S': '{"how many": 300}'},
            'duck_quack': {'S': '{"quack": "duck"}'},
        }

        # All fields requested: 'duck_quack' is included.
        all_fields = self.dynamo_client.dynamo_to_dict(typed_row,
                                                       fetch_all_fields=True)
        self.assertDictEqual(all_fields, {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': {"how many": 300},
            'duck_quack': {"quack": "duck"},
        })

        # Not all fields requested: 'duck_quack' is expected to be left out.
        known_fields = self.dynamo_client.dynamo_to_dict(
            typed_row, fetch_all_fields=False)
        self.assertDictEqual(known_fields, {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': {"how many": 300},
        })

    def test_dynamo_to_dict__mapping_doesnt_match__raises(self):
        """A value whose DB type differs from its row_mapper type raises ValueError."""
        mismatched_row = {
            'hash_col': {'S': 'aaa'},
            'range_col': {'N': '123'},
            # In the row_mapper, other_col is of type 'S'
            'other_col': {'N': '111'},
        }
        expected_message = (
            "'other_col' is expected to be of type 'S' in row_mapper, but real value is of type 'N'"
        )

        with self.assertRaises(ValueError) as raised:
            self.dynamo_client.dynamo_to_dict(mismatched_row)

        self.assertEqual(expected_message, str(raised.exception))

    def test_get_by_query__validates_comparison(self):
        """An unsupported comparison value makes ``get_by_query`` raise AssertionError."""
        with self.assertRaises(AssertionError):
            self.dynamo_client.get_by_query(keys={'k': '1'},
                                            comparisons={'k': 'unsupported'})

    def test_get_by_query__between(self):
        """``st_between_*`` / ``en_between_*`` keys produce a ``between`` key condition."""
        query_keys = {
            'hash_col': 'cat',
            'st_between_range_col': '3',
            'en_between_range_col': '6',
        }

        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)
        self.dynamo_client.get_by_query(keys=query_keys)

        # Inspect what was handed to the mocked paginator.
        paginate_kwargs = self.paginator_mock.paginate.call_args[1]

        self.assertEqual(len(paginate_kwargs['ExpressionAttributeValues']), 3)
        self.assertIn(
            'range_col between :st_between_range_col and :en_between_range_col',
            paginate_kwargs['KeyConditionExpression'])

    def test_get_by_query__return_count(self):
        """With ``return_count=True`` the per-page ``Count`` values are summed."""
        # Make sure dynamo paginator is mocked.
        pages = [{'Count': 24, 'LastEvaluatedKey': 'bzz'}, {'Count': 12}]
        self.paginator_mock.paginate.return_value = pages
        boto_client = self.dynamo_client.dynamo_client
        boto_client.get_paginator.return_value = self.paginator_mock

        # Call the manager
        result = self.dynamo_client.get_by_query(keys={'a': 'b'},
                                                 return_count=True)

        # Validate result
        self.assertEqual(
            result, 36,
            f"Result from 2 pages should be 24 + 12, but we received: {result}"
        )

        # Make sure the paginator was called
        self.dynamo_client.dynamo_client.get_paginator.assert_called()

    def test_get_by_query__expr_attr(self):
        """Names listed in ``expr_attrs_names`` are referenced as ``#name`` in the query."""
        query_keys = {
            'st_between_range_col': '3',
            'en_between_range_col': '6',
            'session': 'ses1',
        }

        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)
        self.dynamo_client.get_by_query(
            keys=query_keys, expr_attrs_names=['range_col', 'session'])

        # Inspect what was handed to the mocked paginator.
        paginate_kwargs = self.paginator_mock.paginate.call_args[1]
        attribute_names = paginate_kwargs['ExpressionAttributeNames']

        self.assertIn('#range_col', attribute_names)
        self.assertIn('#session', attribute_names)
        self.assertIn(
            '#range_col between :st_between_range_col and :en_between_range_col AND #session = :session',
            paginate_kwargs['KeyConditionExpression'])

    def test__parse_filter_expression(self):
        """``_parse_filter_expression`` returns the (expression, values) pair for each input."""
        cases = [
            ('key = 42',
             ("key = :filter_key", {":filter_key": {'N': '42'}})),
            # Surrounding and repeated whitespace is expected to be ignored.
            ('   key    = 42  ',
             ("key = :filter_key", {":filter_key": {'N': '42'}})),
            ('cat = meaw',
             ("cat = :filter_cat", {":filter_cat": {'S': 'meaw'}})),
            ('magic between 41 and 42',
             ("magic between :st_between_magic and :en_between_magic", {
                 ":st_between_magic": {'N': '41'},
                 ":en_between_magic": {'N': '42'},
             })),
            ('attribute_not_exists boo',
             ("attribute_not_exists (boo)", {})),
        ]

        for expression, expected in cases:
            parsed = self.dynamo_client._parse_filter_expression(expression)
            self.assertEqual(parsed, expected)

    def test__parse_filter_expression__raises(self):
        """Malformed filter expressions raise AssertionError or ValueError."""
        invalid_types = [{'k': 1}, [1, 2], None]
        invalid_operators = ['key == 42', 'foo ~ 1', 'foo3 <=> 0', 'key between 42']
        invalid_between = [
            'key between 23, 25',
            'key between [23, 25]',
            'key 23 between 21',
        ]

        for bad_input in invalid_types + invalid_operators + invalid_between:
            with self.assertRaises((AssertionError, ValueError)):
                self.dynamo_client._parse_filter_expression(bad_input)

    def test_create__calls_put(self):
        """``create`` delegates to ``put`` with ``overwrite_existing=False``."""
        put_mock = MagicMock(return_value=None)
        self.dynamo_client.put = put_mock
        new_row = {'hash_col': 'cat', 'range_key': 'test', 'another_col': 'wow'}

        self.dynamo_client.create(new_row)

        # No table name was given, so None is expected in its place.
        put_mock.assert_called_once_with(new_row, None,
                                         overwrite_existing=False)

    def test_batch_get_items_one_table__strict(self):
        """With ``fetch_all_fields=False`` the unknown column is not returned."""
        # Strict - returns only fields that are in the row mapper
        stored_item = {
            'hash_col': {'S': 'b'},
            'range_col': {'N': '10'},
            'unknown_col': {'S': 'not_strict'},
        }
        boto_response = {'Responses': {'autotest_dynamo_db': [stored_item]}}
        self.dynamo_client.dynamo_client.batch_get_item = Mock(
            return_value=boto_response)

        fetched = self.dynamo_client.batch_get_items_one_table(
            keys_list=[{'hash_col': 'b'}], fetch_all_fields=False)

        self.assertEqual(fetched, [{'hash_col': 'b', 'range_col': 10}])

    def test_batch_get_items_one_table__not_strict(self):
        """With ``fetch_all_fields=True`` the unknown column is returned too."""
        # Not strict - returns all fields
        stored_item = {
            'hash_col': {'S': 'b'},
            'range_col': {'N': '10'},
            'unknown_col': {'S': 'not_strict'},
        }
        boto_response = {'Responses': {'autotest_dynamo_db': [stored_item]}}
        self.dynamo_client.dynamo_client.batch_get_item = Mock(
            return_value=boto_response)

        fetched = self.dynamo_client.batch_get_items_one_table(
            keys_list=[{'hash_col': 'b'}], fetch_all_fields=True)

        expected_row = {
            'hash_col': 'b',
            'range_col': 10,
            'unknown_col': 'not_strict',
        }
        self.assertEqual(fetched, [expected_row])

    def test_get_by_query__max_items_and_count__raises(self):
        """Passing ``max_items`` together with ``return_count`` raises with a fixed message."""
        expected_msg = "DynamoDbCLient.get_by_query does not support `max_items` and `return_count` together"

        with self.assertRaises(Exception) as raised:
            self.dynamo_client.get_by_query({'hash_col': 'key'},
                                            table_name=self.table_name,
                                            max_items=3,
                                            return_count=True)

        self.assertEqual(raised.exception.args[0], expected_msg)

    def test_patch__transfers_attrs_to_remove(self):
        """``patch`` forwards its arguments to ``update`` plus an ``attribute_exists`` condition."""
        row_keys = {'hash_col': 'a'}
        to_update = {'some_col': 'b'}
        to_increment = {'some_counter': 3}
        target_table = 'the_table'
        to_remove = ['remove_me']

        # What ``update`` must receive in both calling styles.
        expected_update_kwargs = {
            'keys': row_keys,
            'attributes_to_update': to_update,
            'attributes_to_increment': to_increment,
            'table_name': target_table,
            'attributes_to_remove': to_remove,
            'condition_expression': 'attribute_exists hash_col',
        }

        # using kwargs
        self.dynamo_client.update = Mock()
        self.dynamo_client.patch(keys=row_keys,
                                 attributes_to_update=to_update,
                                 attributes_to_increment=to_increment,
                                 table_name=target_table,
                                 attributes_to_remove=to_remove)
        self.dynamo_client.update.assert_called_once_with(
            **expected_update_kwargs)

        # not kwargs
        self.dynamo_client.update = Mock()
        self.dynamo_client.patch(row_keys, to_update, to_increment,
                                 target_table, to_remove)
        self.dynamo_client.update.assert_called_once_with(
            **expected_update_kwargs)

    def test_sleep_db__get_capacity_called(self):
        """``sleep_db`` calls ``describe_table`` on the boto3 client exactly once."""
        boto_mock = MagicMock()
        self.dynamo_client.dynamo_client = boto_mock

        self.dynamo_client.sleep_db(last_action_time=datetime.datetime.now(),
                                    action='write',
                                    table_name='autotest_new')

        boto_mock.describe_table.assert_called_once()

    def test_sleep_db__wrong_action(self):
        """An action other than the supported ones makes ``sleep_db`` raise KeyError."""
        with self.assertRaises(KeyError):
            self.dynamo_client.sleep_db(
                last_action_time=datetime.datetime.now(), action='call')

    @patch.object(time, 'sleep')
    def test_sleep_db__fell_asleep(self, mock_sleep):
        """ Test for table if BillingMode is PROVISIONED """
        self.dynamo_client.get_capacity = MagicMock(return_value={
            'read': 10,
            'write': 5
        })
        # Check that went to sleep
        time_between_ms = 100
        last_action_time = datetime.datetime.now() - datetime.timedelta(
            milliseconds=time_between_ms)
        self.dynamo_client.sleep_db(last_action_time=last_action_time,
                                    action='write')
        self.assertEqual(mock_sleep.call_count, 1)
        args, kwargs = mock_sleep.call_args

        # Should sleep around 1 / capacity second minus "time_between_ms" minus code execution time.
        # time.sleep takes seconds, so the elapsed milliseconds are converted
        # before subtracting; subtracting the raw 100 would make the lower
        # bound negative and the check meaningless.
        seconds_per_write = 1 / self.dynamo_client.get_capacity()['write']
        elapsed_seconds = time_between_ms / 1000
        self.assertGreater(args[0],
                           seconds_per_write - elapsed_seconds - 0.02)
        self.assertLess(args[0], seconds_per_write)

    @patch.object(time, 'sleep')
    def test_sleep_db__pay_per_request__did_not_sleep(self, mock_sleep):
        """ Test for table if BillingMode is PAY_PER_REQUEST """
        # This test has its own name: a second `test_sleep_db__fell_asleep`
        # in the same class would replace the first definition, and that
        # test would never be collected.

        # Zero capacity for both actions stands for a PAY_PER_REQUEST table.
        self.dynamo_client.get_capacity = MagicMock(return_value={
            'read': 0,
            'write': 0
        })
        self.dynamo_client.sleep_db(last_action_time=datetime.datetime.now(),
                                    action='write')
        # Check that didn't go to sleep
        time_between_ms = 100
        last_action_time = datetime.datetime.now() - datetime.timedelta(
            milliseconds=time_between_ms)
        self.dynamo_client.sleep_db(last_action_time=last_action_time,
                                    action='write')
        self.assertEqual(mock_sleep.call_count, 0)

    @patch.object(time, 'sleep')
    def test_sleep_db__(self, mock_sleep):
        """No sleep is expected when the last action was 900 ms ago (write capacity 5)."""
        capacity = {'read': 10, 'write': 5}
        self.dynamo_client.get_capacity = MagicMock(return_value=capacity)

        # Shouldn't go to sleep
        long_ago = datetime.timedelta(milliseconds=900)
        self.dynamo_client.sleep_db(
            last_action_time=datetime.datetime.now() - long_ago,
            action='write')

        # Sleep function should not be called
        self.assertEqual(mock_sleep.call_count, 0)

    @patch.object(time, 'sleep')
    def test_sleep_db__returns_none_for_on_demand(self, mock_sleep):
        """``sleep_db`` must not sleep when ``describe_table`` returns only a table name."""
        boto_mock = MagicMock()
        boto_mock.describe_table.return_value = {
            'TableName': 'autotest_OnDemand'
        }
        self.dynamo_client.dynamo_client = boto_mock

        # Check that it did NOT go to sleep, even right after the last action.
        recent = datetime.timedelta(milliseconds=10)
        self.dynamo_client.sleep_db(
            last_action_time=datetime.datetime.now() - recent,
            action='write',
            table_name='autotest_OnDemand')

        self.assertEqual(mock_sleep.call_count, 0,
                         "Should not have called time.sleep")

    def test_on_demand_provisioned_throughput__get_capacity(self):
        """``get_capacity`` returns None when ``describe_table`` gives only a table name."""
        boto_mock = MagicMock()
        boto_mock.describe_table.return_value = {
            'TableName': 'autotest_OnDemand'
        }
        self.dynamo_client.dynamo_client = boto_mock

        capacity = self.dynamo_client.get_capacity(
            table_name='autotest_OnDemand')

        self.assertIsNone(capacity)

    def test_on_demand_provisioned_throughput__get_table_indexes(self):
        """An index described without throughput data has no 'ProvisionedThroughput' entry."""
        index_description = {
            'IndexName': 'IndexA',
            'KeySchema': [
                {'AttributeName': 'SomeAttr', 'KeyType': 'HASH'},
            ],
            'Projection': {'ProjectionType': 'ALL'},
        }
        boto_mock = MagicMock()
        boto_mock.describe_table.return_value = {
            'Table': {
                'TableName': 'autotest_OnDemandTable',
                'LocalSecondaryIndexes': [],
                'GlobalSecondaryIndexes': [index_description],
            }
        }
        self.dynamo_client.dynamo_client = boto_mock

        indexes = self.dynamo_client.get_table_indexes(
            table_name='autotest_OnDemandTable')

        # NOTE(review): other tests in this class expect get_table_indexes to
        # return snake_case keys ('provisioned_throughput'), so this CamelCase
        # lookup may be None regardless of the result - confirm the intended key.
        self.assertIsNone(indexes['IndexA'].get('ProvisionedThroughput'))

    def test_get_table_indexes__ppr(self):
        """ Check return value of get_table_indexes function in case table BillingMode is PAY_PER_REQUEST """
        self.dynamo_client._describe_table = Mock(
            return_value=PPR_DESCRIBE_TABLE)

        # Both indexes are expected to report zero capacities.
        expected_indexes = {
            index_name: {
                'projection_type': 'ALL',
                'hash_key': index_name,
                'provisioned_throughput': {
                    'write_capacity': 0,
                    'read_capacity': 0,
                },
            }
            for index_name in ('session', 'session_id')
        }

        actual_indexes = self.dynamo_client.get_table_indexes('actions')
        self.assertEqual(expected_indexes, actual_indexes)

    def test_get_table_indexes__pt(self):
        """ Check return value of get_table_indexes function in case table BillingMode is PROVISIONED """
        self.dynamo_client._describe_table = Mock(
            return_value=PT_DESCRIBE_TABLE)

        # Both indexes are expected to report the same capacities.
        expected_indexes = {
            index_name: {
                'projection_type': 'ALL',
                'hash_key': index_name,
                'provisioned_throughput': {
                    'write_capacity': 10,
                    'read_capacity': 100,
                },
            }
            for index_name in ('name', 'city')
        }

        actual_indexes = self.dynamo_client.get_table_indexes('partners')
        self.assertEqual(expected_indexes, actual_indexes)