示例#1
0
class WorkerAssistant_IntegrationTestCase(unittest.TestCase):
    """
    Integration tests for ``WorkerAssistant`` against the autotest DynamoDB tasks table.
    """
    TEST_CONFIG = TEST_TASK_CLIENT_CONFIG


    @classmethod
    def setUpClass(cls):
        """
        Give this class its own copy of the task-client config with ``init_clients`` set to ``['DynamoDb']``.

        A new dict is built instead of assigning into ``cls.TEST_CONFIG``, because that attribute starts out as the
        module-level ``TEST_TASK_CLIENT_CONFIG`` object and mutating it would leak into every other user of it.
        """
        cls.TEST_CONFIG = {**cls.TEST_CONFIG, 'init_clients': ['DynamoDb']}


    def setUp(self):
        """
        Prepare per-test state: a copy of the config, a patched ``sosw.app.get_config``, a clean tasks table,
        a ``DynamoDbClient`` and the ``WorkerAssistant`` under test.

        The table name and hash key are kept on ``self`` because cleanup reads them. A test that needs different
        values is responsible for updating these attributes.
        """
        self.config = self.TEST_CONFIG.copy()

        self.patcher = patch("sosw.app.get_config")
        self.get_config_patch = self.patcher.start()
        # Registered as a cleanup so the patch is undone even if the rest of setUp raises:
        # unittest skips tearDown when setUp fails, but it always runs cleanups.
        self.addCleanup(self.patcher.stop)

        self.table_name = self.config['dynamo_db_config']['table_name']
        self.HASH_KEY = ('task_id', 'S')

        self.clean_task_tables()

        self.dynamo_client = DynamoDbClient(config=self.config['dynamo_db_config'])

        self.assistant = WorkerAssistant(custom_config={'test': 1})


    def tearDown(self):
        # The get_config patch is stopped by the cleanup registered in setUp.
        self.clean_task_tables()


    def clean_task_tables(self):
        """Clean the tasks table with ``clean_dynamo_table``, passing only the hash key name as the item key."""
        clean_dynamo_table(self.table_name, (self.HASH_KEY[0],))


    def test_mark_task_as_completed(self):
        """``mark_task_as_completed`` stores a ``completed_at`` timestamp close to the current time."""
        _ = self.assistant.get_db_field_name
        task_id = '123'

        initial_task = {_('task_id'): task_id, _('labourer_id'): 'lab', _('greenfield'): 8888, _('attempts'): 2}
        self.dynamo_client.put(initial_task)

        # Allow one minute of slack on either side for clock skew and slow round trips.
        now = datetime.datetime.now()
        between_times = (
            (now - datetime.timedelta(minutes=1)).timestamp(),
            (now + datetime.timedelta(minutes=1)).timestamp(),
        )

        self.assistant.mark_task_as_completed(task_id)

        rows = self.dynamo_client.get_by_query({_('task_id'): task_id})
        self.assertEqual(len(rows), 1, msg=f"Expected exactly one task with id {task_id}, got {rows}")
        # Read the column through the same field-name mapping used to write the task.
        completed_at = rows[0][_('completed_at')]

        self.assertTrue(between_times[0] <= completed_at <= between_times[1],
                        msg=f"NOT {between_times[0]} <= {completed_at} <= {between_times[1]}")
示例#2
0
class DynamodbClientIntegrationTestCase(unittest.TestCase):
    """
    Integration tests for ``DynamoDbClient`` against the ``autotest_dynamo_db`` table.

    The table is keyed by ``hash_col`` (type S) and ``range_col`` (type N). Several tests also read the table
    through a raw boto3 client, so they check what was actually stored independently of the client under test.
    """
    TEST_CONFIG = {
        'row_mapper': {
            'lambda_name': 'S',
            'invocation_id': 'S',
            'en_time': 'N',
            'hash_col': 'S',
            'range_col': 'N',
            'other_col': 'S',
            'new_col': 'S',
            'some_col': 'S',
            'some_counter': 'N',
            'some_bool': 'BOOL',
            'some_map': 'M',
        },
        'required_fields': ['lambda_name'],
        'table_name': 'autotest_dynamo_db',
        'hash_key': 'hash_col'
    }

    @classmethod
    def setUpClass(cls):
        # Start from a clean table in case a previous run left rows behind.
        # NOTE(review): relies on clean_dynamo_table()'s default arguments targeting the autotest table — confirm.
        clean_dynamo_table()

    def setUp(self):
        self.HASH_COL = 'hash_col'
        self.HASH_KEY = (self.HASH_COL, 'S')

        self.RANGE_COL = 'range_col'
        self.RANGE_COL_TYPE = 'N'
        # (attribute name, DynamoDB type), the same shape as HASH_KEY.
        self.RANGE_KEY = (self.RANGE_COL, self.RANGE_COL_TYPE)

        self.KEYS = (self.HASH_COL, self.RANGE_COL)
        self.table_name = self.TEST_CONFIG['table_name']
        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)

    def tearDown(self):
        # Tests may reassign table_name / KEYS (see the begins_with test), so clean whichever table they used.
        clean_dynamo_table(self.table_name, self.KEYS)

    def _fetch_row(self, row):
        """
        Read the item identified by ``row``'s hash and range values directly with boto3, as a plain dict.

        This bypasses the read methods of the client under test, so write methods can be checked independently.
        The raw item is converted with ``DynamoDbClient.dynamo_to_dict``.
        Raises KeyError if the item does not exist.
        """
        client = boto3.client('dynamodb')
        item = client.get_item(
            Key={
                self.HASH_COL: {'S': row[self.HASH_COL]},
                self.RANGE_COL: {self.RANGE_COL_TYPE: str(row[self.RANGE_COL])},
            },
            TableName=self.table_name,
        )['Item']
        return self.dynamo_client.dynamo_to_dict(item)

    def test_put(self):
        row = {
            self.HASH_COL: 'cat',
            self.RANGE_COL: '123',
            'some_bool': True,
            'some_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
        }

        client = boto3.client('dynamodb')

        # Make sure the item does not exist yet, so the check below really verifies this put().
        client.delete_item(
            TableName=self.table_name,
            Key={
                self.HASH_COL: {'S': str(row[self.HASH_COL])},
                self.RANGE_COL: {self.RANGE_COL_TYPE: str(row[self.RANGE_COL])},
            })

        self.dynamo_client.put(row, self.table_name)

        # Read back with raw boto3 to see the exact typed DynamoDB representation that was stored.
        result = client.scan(
            TableName=self.table_name,
            FilterExpression="hash_col = :hash_col AND range_col = :range_col",
            ExpressionAttributeValues={
                ':hash_col': {'S': row[self.HASH_COL]},
                ':range_col': {self.RANGE_COL_TYPE: str(row[self.RANGE_COL])},
            })

        items = result['Items']
        self.assertEqual(1, len(items))

        expected = [{
            'hash_col': {'S': 'cat'},
            'range_col': {'N': '123'},
            'some_bool': {'BOOL': True},
            'some_map': {
                'M': {
                    'a': {'N': '1'},
                    'b': {'S': 'b1'},
                    'c': {'M': {'test': {'BOOL': True}}},
                }
            },
        }]
        self.assertEqual(expected, items)

    def test_put__create(self):
        row = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}

        self.dynamo_client.put(row, self.table_name)

        # A second put of the same keys must be rejected when overwriting is disabled.
        with self.assertRaises(self.dynamo_client.dynamo_client.exceptions.ConditionalCheckFailedException):
            self.dynamo_client.put(row, self.table_name, overwrite_existing=False)

    def test_update__updates(self):
        keys = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}
        row = {self.HASH_COL: 'cat', self.RANGE_COL: '123', 'some_col': 'no', 'other_col': 'foo'}
        attributes_to_update = {'some_col': 'yes', 'new_col': 'yup'}

        self.dynamo_client.put(row, self.table_name)

        # First check that the row we are trying to update is PUT correctly.
        initial_row = self._fetch_row(row)

        self.assertIsNotNone(initial_row)
        self.assertEqual(initial_row['some_col'], 'no')
        self.assertEqual(initial_row['other_col'], 'foo')

        self.dynamo_client.update(keys, attributes_to_update, table_name=self.table_name)

        updated_row = self._fetch_row(row)

        self.assertIsNotNone(updated_row)
        self.assertEqual(updated_row['some_col'], 'yes', msg="Updated field not really updated")
        self.assertEqual(updated_row['new_col'], 'yup', msg="New field was not created")
        self.assertEqual(updated_row['other_col'], 'foo',
                         msg="This field should be preserved, update() damaged it")

    def test_update__increment(self):
        keys = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}
        row = {self.HASH_COL: 'cat', self.RANGE_COL: '123', 'some_col': 'no', 'some_counter': 10}
        # The increment is deliberately given as a string; the result below must still be numeric.
        attributes_to_increment = {'some_counter': '1'}

        self.dynamo_client.put(row, self.table_name)

        self.dynamo_client.update(keys, {}, attributes_to_increment=attributes_to_increment,
                                  table_name=self.table_name)

        updated_row = self._fetch_row(row)

        self.assertIsNotNone(updated_row)
        self.assertEqual(updated_row['some_counter'], 11)

    def test_update__increment_2(self):
        keys = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}
        row = {self.HASH_COL: 'cat', self.RANGE_COL: '123', 'some_col': 'no', 'some_counter': 10}
        attributes_to_increment = {'some_counter': 5}

        self.dynamo_client.put(row, self.table_name)

        self.dynamo_client.update(keys, {}, attributes_to_increment=attributes_to_increment,
                                  table_name=self.table_name)

        updated_row = self._fetch_row(row)

        self.assertIsNotNone(updated_row)
        self.assertEqual(updated_row['some_counter'], 15)

    def test_update__increment_no_default(self):
        keys = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}
        # No 'some_counter' in the initial row: incrementing a missing attribute should start from zero.
        row = {self.HASH_COL: 'cat', self.RANGE_COL: '123', 'some_col': 'no'}
        attributes_to_increment = {'some_counter': '3'}

        self.dynamo_client.put(row, self.table_name)

        self.dynamo_client.update(keys, {}, attributes_to_increment=attributes_to_increment,
                                  table_name=self.table_name)

        updated_row = self._fetch_row(row)

        self.assertIsNotNone(updated_row)
        self.assertEqual(updated_row['some_counter'], 3)

    def test_update__condition_expression(self):
        keys = {self.HASH_COL: 'slime', self.RANGE_COL: '41'}
        row = {self.HASH_COL: 'slime', self.RANGE_COL: '41', 'some_col': 'no'}

        self.dynamo_client.put(row, self.table_name)

        # Should fail because the condition expression does not match the stored value ('no').
        with self.assertRaises(self.dynamo_client.dynamo_client.exceptions.ConditionalCheckFailedException):
            self.dynamo_client.update(keys, {},
                                      attributes_to_increment={'some_counter': '3'},
                                      condition_expression='some_col = yes',
                                      table_name=self.table_name)

        # Should pass
        self.dynamo_client.update(keys, {},
                                  attributes_to_increment={'some_counter': '3'},
                                  condition_expression='some_col = no',
                                  table_name=self.table_name)

        updated_row = self._fetch_row(row)
        self.assertEqual(updated_row['some_counter'], 3)

    def test_patch(self):
        keys = {self.HASH_COL: 'slime', self.RANGE_COL: '41'}
        row = {self.HASH_COL: 'slime', self.RANGE_COL: '41', 'some_col': 'no'}

        # Should fail because row doesn't exist
        with self.assertRaises(self.dynamo_client.dynamo_client.exceptions.ConditionalCheckFailedException):
            self.dynamo_client.patch(keys, attributes_to_update={'some_col': 'yes'}, table_name=self.table_name)

        # Create the row
        self.dynamo_client.put(row, self.table_name)
        # Should pass because the row exists now
        self.dynamo_client.patch(keys, attributes_to_update={'some_col': 'yes'}, table_name=self.table_name)

        updated_row = self._fetch_row(row)
        self.assertEqual(updated_row['some_col'], 'yes')

    def test_get_by_query__primary_index(self):
        keys = {self.HASH_COL: 'cat', self.RANGE_COL: '123'}
        row = {self.HASH_COL: 'cat', self.RANGE_COL: 123, 'some_col': 'test', 'some_bool': True}
        self.dynamo_client.put(row, self.table_name)

        result = self.dynamo_client.get_by_query(keys=keys)

        self.assertEqual(len(result), 1)
        # Dict equality catches both missing attributes and unexpected extra ones.
        self.assertEqual(row, result[0])

    def test_get_by_query__primary_index__gets_multiple(self):
        row = {self.HASH_COL: 'cat', self.RANGE_COL: 123, 'some_col': 'test'}
        self.dynamo_client.put(row, self.table_name)

        row2 = {self.HASH_COL: 'cat', self.RANGE_COL: 1234, 'some_col': 'test2'}
        self.dynamo_client.put(row2, self.table_name)

        result = self.dynamo_client.get_by_query(keys={self.HASH_COL: 'cat'})

        self.assertEqual(len(result), 2)
        # Order-insensitive: each stored row must come back exactly once, with exactly its attributes.
        self.assertCountEqual([row, row2], result)

    def test_get_by_query__secondary_index(self):
        keys = {self.HASH_COL: 'cat', 'other_col': 'abc123'}
        row = {self.HASH_COL: 'cat', self.RANGE_COL: 123, 'other_col': 'abc123'}
        self.dynamo_client.put(row, self.table_name)

        result = self.dynamo_client.get_by_query(keys=keys, index_name='autotest_index')

        self.assertEqual(len(result), 1)
        self.assertEqual(row, result[0])

    def test_get_by_query__comparison(self):
        keys = {self.HASH_COL: 'cat', self.RANGE_COL: '300'}
        row1 = {self.HASH_COL: 'cat', self.RANGE_COL: 123, 'other_col': 'abc123'}
        row2 = {self.HASH_COL: 'cat', self.RANGE_COL: 456, 'other_col': 'abc123'}
        self.dynamo_client.put(row1, self.table_name)
        self.dynamo_client.put(row2, self.table_name)

        result = self.dynamo_client.get_by_query(keys=keys, comparisons={self.RANGE_COL: '<='})

        self.assertEqual(len(result), 1)
        self.assertEqual(result[0], row1)

    def test_get_by_query__comparison_between(self):
        # Put sample data
        for i in range(10):
            self.dynamo_client.put({self.HASH_COL: 'cat', self.RANGE_COL: i}, self.table_name)

        keys = {self.HASH_COL: 'cat', 'st_between_range_col': '3', 'en_between_range_col': '6'}
        # DynamoDB BETWEEN is inclusive on both ends. Comparing the exact set (not just "all within range")
        # also fails when the query returns nothing.
        expected = list(range(3, 7))

        result = self.dynamo_client.get_by_query(keys=keys, comparisons={self.RANGE_COL: 'between'})
        self.assertEqual(sorted(x[self.RANGE_COL] for x in result), expected)

        result = self.dynamo_client.get_by_query(keys=keys)
        self.assertEqual(sorted(x[self.RANGE_COL] for x in result), expected,
                         msg="Failed if unspecified comparison. Should be automatic for :st_between_...")

    def test_get_by_query__filter_expression(self):
        """
        This _integration_ test runs multiple checks with same sample data for several comparators.
        Have a look at the manual if required:
        https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html
        """

        # Put sample data: rows 0-2 without `mark`, rows 3-5 with mark = 1, row 6 with mark = 0, row 7 with 'a'.
        for i in range(3):
            self.dynamo_client.put({self.HASH_COL: 'cat', self.RANGE_COL: i}, self.table_name)
        for i in range(3, 6):
            self.dynamo_client.put({self.HASH_COL: 'cat', self.RANGE_COL: i, 'mark': 1}, self.table_name)
        self.dynamo_client.put({self.HASH_COL: 'cat', self.RANGE_COL: 6, 'mark': 0}, self.table_name)
        self.dynamo_client.put({self.HASH_COL: 'cat', self.RANGE_COL: 7, 'mark': 'a'}, self.table_name)

        # The key condition range_col <= 4 matches five of the eight rows (0-4).
        # The filter expression then drops rows 0-2, because they don't have `mark = 1`.
        keys = {self.HASH_COL: 'cat', self.RANGE_COL: 4}
        expected_marked = [
            {self.HASH_COL: 'cat', self.RANGE_COL: 3, 'mark': 1},
            {self.HASH_COL: 'cat', self.RANGE_COL: 4, 'mark': 1},
        ]

        result = self.dynamo_client.get_by_query(
            keys=keys,
            comparisons={self.RANGE_COL: '<='},
            fetch_all_fields=True,
            filter_expression='mark = 1')
        self.assertEqual(result, expected_marked)

        # In the same test we check also some comparator _functions_.
        result = self.dynamo_client.get_by_query(
            keys=keys,
            comparisons={self.RANGE_COL: '<='},
            fetch_all_fields=True,
            filter_expression='attribute_exists mark')
        self.assertEqual(result, expected_marked)

        result = self.dynamo_client.get_by_query(
            keys=keys,
            comparisons={self.RANGE_COL: '<='},
            fetch_all_fields=True,
            filter_expression='attribute_not_exists mark')
        self.assertEqual([x[self.RANGE_COL] for x in result], list(range(3)))

    def test_get_by_query__comparison_begins_with(self):
        # Switch the whole test (and tearDown cleanup) to a table that has a string range key.
        self.table_name = 'autotest_config_component'
        self.HASH_KEY = ('env', 'S')
        self.RANGE_KEY = ('config_name', 'S')
        self.KEYS = ('env', 'config_name')
        config = {
            'row_mapper': {
                'env': 'S',
                'config_name': 'S',
                'config_value': 'S'
            },
            'required_fields': ['env', 'config_name', 'config_value'],
            'table_name': 'autotest_config_component',
            # The hash key of this table is `env`, not the default table's `hash_col`.
            'hash_key': self.HASH_KEY[0]
        }

        self.dynamo_client = DynamoDbClient(config=config)

        row1 = {'env': 'cat', 'config_name': 'testzing', 'config_value': 'abc123'}
        row2 = {'env': 'cat', 'config_name': 'dont_get_this', 'config_value': 'abc123'}
        row3 = {'env': 'cat', 'config_name': 'testzer', 'config_value': 'abc124'}
        for row in (row1, row2, row3):
            self.dynamo_client.put(row, self.table_name)

        keys = {'env': 'cat', 'config_name': 'testz'}
        result = self.dynamo_client.get_by_query(
            keys=keys,
            table_name=self.table_name,
            comparisons={'config_name': 'begins_with'})

        self.assertEqual(len(result), 2)

        self.assertIn(row1, result)
        self.assertIn(row3, result)

    def test_get_by_query__max_items(self):
        # This test can also be used for benchmarking: raise INITIAL_TASKS manually (e.g. to 500).
        INITIAL_TASKS = 5

        for x in range(1000, 1000 + INITIAL_TASKS):
            row = {self.HASH_COL: "key", self.RANGE_COL: x}
            self.dynamo_client.put(row, self.table_name)
            if INITIAL_TASKS > 10:
                # Sleep a little to fit the Write Capacity (10 WCU) of autotest table.
                time.sleep(0.1)

        n = 3
        st = time.perf_counter()
        result = self.dynamo_client.get_by_query({self.HASH_COL: 'key'}, table_name=self.table_name, max_items=n)
        bm = time.perf_counter() - st
        logging.info(f"Benchmark (n={n}): {bm}")

        self.assertEqual(len(result), n)
        # NOTE(review): wall-clock bound on a round trip to a remote service; may be flaky on slow networks.
        self.assertLess(bm, 0.1)

        # Check unspecified limit.
        result = self.dynamo_client.get_by_query({self.HASH_COL: 'key'}, table_name=self.table_name)
        self.assertEqual(len(result), INITIAL_TASKS)

    def test_get_by_query__return_count(self):
        rows = [
            {self.HASH_COL: 'cat1', self.RANGE_COL: 121, 'some_col': 'test1'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 122, 'some_col': 'test2'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 123, 'some_col': 'test3'},
        ]

        for x in rows:
            self.dynamo_client.put(x, table_name=self.table_name)

        result = self.dynamo_client.get_by_query({self.HASH_COL: 'cat1'}, table_name=self.table_name,
                                                 return_count=True)

        self.assertEqual(result, 3)

    def test_get_by_query__reverse(self):
        rows = [
            {self.HASH_COL: 'cat1', self.RANGE_COL: 121, 'some_col': 'test1'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 122, 'some_col': 'test2'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 123, 'some_col': 'test3'},
        ]

        for x in rows:
            self.dynamo_client.put(x, table_name=self.table_name)

        result = self.dynamo_client.get_by_query({self.HASH_COL: 'cat1'}, table_name=self.table_name, desc=True)

        # Descending order by range key: the highest range value comes first.
        self.assertEqual(result[0], rows[-1])

    def test_get_by_scan__all(self):
        rows = [
            {self.HASH_COL: 'cat1', self.RANGE_COL: 121, 'some_col': 'test1'},
            {self.HASH_COL: 'cat2', self.RANGE_COL: 122, 'some_col': 'test2'},
            {self.HASH_COL: 'cat3', self.RANGE_COL: 123, 'some_col': 'test3'},
        ]
        for x in rows:
            self.dynamo_client.put(x, self.table_name)

        result = self.dynamo_client.get_by_scan()

        self.assertEqual(len(result), 3)

        for r in rows:
            self.assertIn(r, result, msg=f"row not in result from dynamo scan: {r}")

    def test_get_by_scan__with_filter(self):
        rows = [
            {self.HASH_COL: 'cat1', self.RANGE_COL: 121, 'some_col': 'test1'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 122, 'some_col': 'test2'},
            {self.HASH_COL: 'cat2', self.RANGE_COL: 122, 'some_col': 'test2'},
        ]
        for x in rows:
            self.dynamo_client.put(x, self.table_name)

        scan_filter = {'some_col': 'test2'}

        result = self.dynamo_client.get_by_scan(attrs=scan_filter)

        self.assertEqual(len(result), 2)

        for r in rows[1:]:
            self.assertIn(r, result, msg=f"row not in result from dynamo scan: {r}")

    def test_batch_get_items(self):
        rows = [
            {self.HASH_COL: 'cat1', self.RANGE_COL: 121, 'some_col': 'test1'},
            {self.HASH_COL: 'cat1', self.RANGE_COL: 122, 'some_col': 'test2'},
            {self.HASH_COL: 'cat2', self.RANGE_COL: 122, 'some_col': 'test2'},
        ]
        for x in rows:
            self.dynamo_client.put(x, self.table_name)

        # The middle key does not exist and must simply be absent from the result.
        keys_list_query = [
            {self.HASH_COL: 'cat1', self.RANGE_COL: 121},
            {self.HASH_COL: 'doesnt_exist', self.RANGE_COL: 40},
            {self.HASH_COL: 'cat2', self.RANGE_COL: 122},
        ]

        result = self.dynamo_client.batch_get_items_one_table(keys_list_query)

        self.assertEqual(len(result), 2)

        self.assertIn(rows[0], result)
        self.assertIn(rows[2], result)

    def test_delete(self):
        self.dynamo_client.put({self.HASH_COL: 'cat1', self.RANGE_COL: 123})
        self.dynamo_client.put({self.HASH_COL: 'cat2', self.RANGE_COL: 234})

        self.dynamo_client.delete(keys={self.HASH_COL: 'cat1', self.RANGE_COL: '123'})

        items = self.dynamo_client.get_by_scan()

        self.assertEqual(len(items), 1)
        self.assertEqual(items[0], {self.HASH_COL: 'cat2', self.RANGE_COL: 234})

    def test_get_table_keys(self):
        result1 = self.dynamo_client.get_table_keys()
        self.assertEqual(result1, ('hash_col', 'range_col'))

        result2 = self.dynamo_client.get_table_keys(table_name=self.table_name)
        self.assertEqual(result2, ('hash_col', 'range_col'))

    def test_get_table_indexes(self):
        indexes = self.dynamo_client.get_table_indexes()
        expected = {
            'autotest_index': {
                'projection_type': 'ALL',
                'hash_key': 'hash_col',
                'range_key': 'other_col',
                'provisioned_throughput': {
                    'write_capacity': 1,
                    'read_capacity': 1
                }
            }
        }
        self.assertDictEqual(expected, indexes)

    def test_batch_get_items_one_table(self):
        # If you want to stress test batch_get_items_one_table, use bigger numbers
        num_of_items = 5
        query_from = 2
        query_till = 4
        expected_items = query_till - query_from

        # Write items
        operations = []
        query_keys = []
        for i in range(num_of_items):
            item = {self.HASH_COL: f'cat{i % 2}', self.RANGE_COL: i}
            operations.append({'Put': self.dynamo_client.build_put_query(item)})
            query_keys.append(item)
        for operations_chunk in chunks(operations, 10):
            self.dynamo_client.dynamo_client.transact_write_items(TransactItems=operations_chunk)
            time.sleep(1)  # cause the table has 10 write/sec capacity

        # Batch get items
        results = self.dynamo_client.batch_get_items_one_table(keys_list=query_keys[query_from:query_till])
        self.assertEqual(expected_items, len(results))
示例#3
0
class TaskManager_IntegrationTestCase(unittest.TestCase):
    """
    Integration tests for ``TaskManager`` running against the autotest DynamoDB tables:
    the main tasks table, the closed tasks table and the retry tasks table.
    """

    TEST_CONFIG = TEST_TASK_CLIENT_CONFIG
    LABOURER = Labourer(
        id='some_function',
        arn='arn:aws:lambda:us-west-2:000000000000:function:some_function')

    @classmethod
    def setUpClass(cls):
        """
        Restrict the clients initialised by the tested components to DynamoDb.
        """
        # NOTE: TEST_CONFIG is the shared TEST_TASK_CLIENT_CONFIG object, so this assignment mutates it
        # for every other user of that config as well.
        cls.TEST_CONFIG['init_clients'] = ['DynamoDb']

    def setUp(self):
        """
        We keep copies of main parameters here, because they may differ from test to test and cleanup needs them.
        This is responsibility of the test author to update these values if required from test.
        """
        self.config = self.TEST_CONFIG.copy()

        self.HASH_KEY = ('task_id', 'S')
        self.RANGE_KEY = ('labourer_id', 'S')
        self.table_name = self.config['dynamo_db_config']['table_name']
        self.completed_tasks_table = self.config['sosw_closed_tasks_table']
        self.retry_tasks_table = self.config['sosw_retry_tasks_table']

        # Start every test from empty tables, in case a previous run left data behind.
        self.clean_task_tables()

        self.dynamo_client = DynamoDbClient(
            config=self.config['dynamo_db_config'])
        self.manager = TaskManager(custom_config=self.config)
        # The ecology client is not under test here; individual tests configure its return values.
        self.manager.ecology_client = MagicMock()

        self.labourer = deepcopy(self.LABOURER)

    def tearDown(self):
        """Remove everything the test wrote to the task tables."""
        self.clean_task_tables()

    def clean_task_tables(self):
        """Empty the main, closed and retry task tables, each cleaned by its own key columns."""
        clean_dynamo_table(self.table_name, (self.HASH_KEY[0], ))
        clean_dynamo_table(self.completed_tasks_table, ('task_id', ))
        clean_dynamo_table(self.retry_tasks_table, ('labourer_id', 'task_id'))

    def setup_tasks(self,
                    status='available',
                    mutiple_labourers=False,
                    count_tasks=3):
        """
        Put some fake scheduled tasks for some workers into the appropriate table.

        :param status: One of 'available', 'invoked', 'expired', 'running', 'closed', 'failed'.
                       'closed' and 'failed' tasks are written to the closed tasks table,
                       all others to the main tasks table.
        :param mutiple_labourers: If True, create tasks for workers 42, 43 and 44 instead of ``self.LABOURER``.
        :param count_tasks: Number of tasks to create per worker.
        :raises ValueError: If ``status`` is not one of the supported values.
        """
        _ = self.manager.get_db_field_name
        _cfg = self.manager.config.get

        table = _cfg('dynamo_db_config')['table_name'] if status not in ['closed', 'failed'] \
            else _cfg('sosw_closed_tasks_table')

        # Each getter receives the row built so far. The ones referring to `worker_id` read it from this
        # function's scope at call time; it is bound by the `for worker_id in workers` loop below.
        MAP = {
            'available': {
                self.RANGE_KEY[0]: lambda x: str(worker_id),
                _('greenfield'): lambda x: round(1000 + random.randrange(0, 100000, 1000)),
                _('attempts'): lambda x: 0,
            },
            'invoked': {
                self.RANGE_KEY[0]: lambda x: str(worker_id),
                _('greenfield'): lambda x: round(time.time()) + _cfg('greenfield_invocation_delta'),
                _('attempts'): lambda x: 1,
            },
            'expired': {
                self.RANGE_KEY[0]: lambda x: str(worker_id),
                _('greenfield'): lambda x: (round(time.time()) + _cfg('greenfield_invocation_delta')
                                            - random.randint(1000, 10000)),
                _('attempts'): lambda x: 1,
            },
            'running': {
                self.RANGE_KEY[0]: lambda x: str(worker_id),
                _('greenfield'): lambda x: (round(time.time()) + _cfg('greenfield_invocation_delta')
                                            - random.randint(1, 900)),
                _('attempts'): lambda x: 1,
            },
            'closed': {
                _('greenfield'): lambda x: (round(time.time()) + _cfg('greenfield_invocation_delta')
                                            - random.randint(1000, 10000)),
                _('labourer_id_task_status'): lambda x: f"{self.LABOURER.id}_1",
                _('completed_at'): lambda x: (x[_('greenfield')] - _cfg('greenfield_invocation_delta')
                                              + random.randint(10, 300)),
                _('closed_at'): lambda x: x[_('completed_at')] + random.randint(1, 60),
                _('attempts'): lambda x: 3,
            },
            'failed': {
                _('greenfield'): lambda x: (round(time.time()) + _cfg('greenfield_invocation_delta')
                                            - random.randint(1000, 10000)),
                _('labourer_id_task_status'): lambda x: f"{self.LABOURER.id}_0",
                _('closed_at'): lambda x: x[_('greenfield')] + 900 + random.randint(1, 60),
                _('attempts'): lambda x: 3,
            },
        }

        if status not in MAP:
            raise ValueError(f"Unsupported `status`: {status}. Should be one of: {', '.join(MAP)}.")

        workers = [self.LABOURER.id] if not mutiple_labourers else range(
            42, 45)
        for worker_id in workers:

            for i in range(count_tasks):
                row = {
                    self.HASH_KEY[0]:
                    f"task_id_{worker_id}_{i}_{str(uuid.uuid4())[:8]}",  # Task ID
                }

                # Getters are applied in declaration order, because some read fields set by earlier
                # ones (e.g. 'closed_at' is derived from 'completed_at').
                for field, getter in MAP[status].items():
                    row[field] = getter(row)

                print(f"Putting {row} to {table}")
                self.dynamo_client.put(row, table_name=table)
                time.sleep(
                    0.1
                )  # Sleep a little to fit the Write Capacity (10 WCU) of autotest table.

    def test_get_next_for_labourer(self):
        """By default a single task ID of this labourer is returned."""
        self.setup_tasks()

        result = self.manager.get_next_for_labourer(self.LABOURER,
                                                    only_ids=True)

        self.assertEqual(len(result), 1, "Returned more than one task")
        self.assertIn(f'task_id_{self.LABOURER.id}_', result[0])

    def test_get_next_for_labourer__multiple(self):
        """A large ``cnt`` returns all available tasks of this labourer and no others."""
        self.setup_tasks()

        result = self.manager.get_next_for_labourer(self.LABOURER,
                                                    cnt=5000,
                                                    only_ids=True)

        self.assertEqual(len(result), 3,
                         "Should be just 3 tasks for this worker in setup")
        self.assertTrue(
            all(f'task_id_{self.LABOURER.id}_' in task for task in result),
            "Returned some tasks of other Workers")

    def test_get_next_for_labourer__not_take_invoked(self):
        """Already invoked tasks are not returned as next tasks."""
        self.setup_tasks()
        self.setup_tasks(status='invoked')

        result = self.manager.get_next_for_labourer(self.LABOURER,
                                                    cnt=50,
                                                    only_ids=True)

        self.assertEqual(
            len(result), 3,
            "Should be just 3 tasks for this worker in setup. The other 3 are invoked."
        )
        self.assertTrue(
            all(f'task_id_{self.LABOURER.id}_' in task for task in result),
            "Returned some tasks of other Workers")

    def test_get_next_for_labourer__full_tasks(self):
        """Without ``only_ids`` full task dicts are returned, limited to ``cnt``."""
        self.setup_tasks()

        result = self.manager.get_next_for_labourer(self.LABOURER, cnt=2)

        self.assertEqual(len(result), 2, "Should be just 2 tasks as requested")

        for task in result:
            # The messages must be passed as the `msg` argument; a trailing `, "msg"` after the call
            # would only build a discarded tuple.
            self.assertIn(f'task_id_{self.LABOURER.id}_', task['task_id'],
                          "Returned some tasks of other Workers")
            self.assertEqual(self.LABOURER.id, task['labourer_id'],
                             "Returned some tasks of other Workers")

    def register_labourers(self):
        """Register only ``self.LABOURER`` with the manager and return the registered labourers."""
        self.manager.get_labourers = MagicMock(return_value=[self.LABOURER])
        return self.manager.register_labourers()

    def test_mark_task_invoked(self):
        """``mark_task_invoked`` moves the task's greenfield to roughly now + invocation delta."""
        greenfield = round(time.time() - random.randint(100, 1000))
        invocation_delta = self.manager.config['greenfield_invocation_delta']
        self.register_labourers()

        row = {
            self.HASH_KEY[0]: f"task_id_{self.LABOURER.id}_256",  # Task ID
            self.RANGE_KEY[0]: self.LABOURER.id,  # Worker ID
            'greenfield': greenfield
        }
        self.dynamo_client.put(row)

        # Do the actual tested job
        self.manager.mark_task_invoked(self.LABOURER, row)
        time.sleep(1)
        result = self.dynamo_client.get_by_query(
            {self.HASH_KEY[0]: f"task_id_{self.LABOURER.id}_256"},
            strict=False)

        # Compare with an absolute tolerance. Rounding both sides to hundreds was flaky whenever
        # the two values fell on different sides of a multiple of 100.
        expected_greenfield = float(int(time.time()) + invocation_delta)
        actual_greenfield = float(result[0]['greenfield'])
        self.assertAlmostEqual(expected_greenfield, actual_greenfield, delta=100)

    def test_get_invoked_tasks_for_labourer(self):
        """Only tasks in the 'invoked' state are counted as invoked."""
        self.register_labourers()

        self.setup_tasks(status='running')
        self.setup_tasks(status='expired')
        self.setup_tasks(status='invoked')
        self.assertEqual(
            len(self.manager.get_invoked_tasks_for_labourer(self.LABOURER)), 3)

    def test_get_running_tasks_for_labourer(self):
        """Only tasks in the 'running' state are counted as running."""
        self.register_labourers()

        self.setup_tasks(status='available')
        self.setup_tasks(status='running')
        self.setup_tasks(status='expired')
        self.assertEqual(
            len(self.manager.get_running_tasks_for_labourer(self.LABOURER)), 3)

    def test_get_expired_tasks_for_labourer(self):
        """Only tasks in the 'expired' state are counted as expired."""
        self.register_labourers()

        self.setup_tasks(status='running')
        self.setup_tasks(status='expired')
        self.assertEqual(
            len(self.manager.get_expired_tasks_for_labourer(self.LABOURER)), 3)

    # @unittest.skip("Function currently depricated")
    # def test_close_task(self):
    #     _ = self.manager.get_db_field_name
    #     # Create task with id=123
    #     task = {_('task_id'): '123', _('labourer_id'): 'lambda1', _('greenfield'): 8888, _('attempts'): 2,
    #             _('completed_at'): 123123}
    #     self.dynamo_client.put(task)
    #
    #     # Call
    #     self.manager.close_task(task_id='123', labourer_id='lambda1')
    #
    #     # Get from db, check
    #     tasks = self.dynamo_client.get_by_query({_('task_id'): '123'})
    #     self.assertEqual(len(tasks), 1)
    #     task_result = tasks[0]
    #
    #     expected_result = task.copy()
    #
    #     for k in ['task_id', 'labourer_id', 'greenfield', 'attempts']:
    #         assert expected_result[k] == task_result[k]
    #
    #     self.assertTrue(_('closed_at') in task_result, msg=f"{_('closed_at')} not in task_result {task_result}")
    #     self.assertTrue(time.time() - 360 < task_result[_('closed_at')] < time.time())

    def test_archive_task(self):
        """An archived task leaves the tasks table and reappears in the closed table with ``closed_at``."""
        _ = self.manager.get_db_field_name
        # Create task with id=123
        task = {
            _('task_id'): '123',
            _('labourer_id'): 'lambda1',
            _('greenfield'): 8888,
            _('attempts'): 2
        }
        self.dynamo_client.put(task)

        # Call
        self.manager.archive_task('123')

        # Check the task isn't in the tasks db, but is in the completed_tasks table
        tasks = self.dynamo_client.get_by_query({_('task_id'): '123'})
        self.assertEqual(len(tasks), 0)

        completed_tasks = self.dynamo_client.get_by_query(
            {_('task_id'): '123'}, table_name=self.completed_tasks_table)
        self.assertEqual(len(completed_tasks), 1)
        completed_task = completed_tasks[0]

        for k in task.keys():
            self.assertEqual(task[k], completed_task[k])
        # The only field archiving may add is `closed_at`.
        for k in completed_task.keys():
            if k != _('closed_at'):
                self.assertEqual(task[k], completed_task[k])

        self.assertTrue(
            time.time() - 360 < completed_task[_('closed_at')] < time.time())

    def test_move_task_to_retry_table(self):
        """A task moved to the retry table gets ``desired_launch_time`` of about now + delay."""
        _ = self.manager.get_db_field_name
        labourer_id = 'lambda1'
        task = {
            _('task_id'): '123',
            _('labourer_id'): labourer_id,
            _('greenfield'): 8888,
            _('attempts'): 2
        }
        delay = 300

        self.dynamo_client.put(task)

        self.manager.move_task_to_retry_table(task, delay)

        result_tasks = self.dynamo_client.get_by_query({_('task_id'): '123'})
        self.assertEqual(len(result_tasks), 0)

        result_retry_tasks = self.dynamo_client.get_by_query(
            {_('labourer_id'): labourer_id}, table_name=self.retry_tasks_table)
        self.assertEqual(len(result_retry_tasks), 1)
        result = first_or_none(result_retry_tasks)

        for k in task:
            self.assertEqual(task[k], result[k])
        # The only field the move may add is `desired_launch_time`.
        for k in result:
            if k != _('desired_launch_time'):
                self.assertEqual(result[k], task[k])

        self.assertTrue(
            time.time() + delay -
            60 < result[_('desired_launch_time')] < time.time() + delay + 60)

    def test_get_tasks_to_retry_for_labourer(self):
        """Only retry tasks whose desired launch time has already come are returned."""
        tasks = RETRY_TASKS.copy()
        # Add tasks to retry table
        for task in tasks:
            self.dynamo_client.put(task, self.config['sosw_retry_tasks_table'])

        # Call
        with patch('time.time') as t:
            t.return_value = 9500
            labourer = self.manager.register_labourers()[0]

        result_tasks = self.manager.get_tasks_to_retry_for_labourer(labourer,
                                                                    limit=20)

        self.assertEqual(len(result_tasks), 2)

        # Check it only gets tasks with timestamp <= now
        self.assertIn(tasks[0], result_tasks)
        self.assertIn(tasks[1], result_tasks)

    def test_retry_tasks(self):
        """Due retry tasks move back into the tasks table with a greenfield older than any regular task."""
        _ = self.manager.get_db_field_name

        with patch('time.time') as t:
            t.return_value = 9500
            labourer = self.manager.register_labourers()[0]

        self.manager.get_oldest_greenfield_for_labourer = Mock(
            return_value=8888)

        # Add tasks to tasks_table
        regular_tasks = [
            {
                _('labourer_id'): labourer.id,
                _('task_id'): '11',
                _('arn'): 'some_arn',
                _('payload'): {},
                _('greenfield'): 8888
            },
            {
                _('labourer_id'): labourer.id,
                _('task_id'): '22',
                _('arn'): 'some_arn',
                _('payload'): {},
                _('greenfield'): 9999
            },
        ]
        for task in regular_tasks:
            self.dynamo_client.put(task)

        # Add tasks to retry_table
        retry_tasks = RETRY_TASKS.copy()

        for task in retry_tasks:
            self.dynamo_client.put(
                task, table_name=self.config['sosw_retry_tasks_table'])

        retry_table_items = self.dynamo_client.get_by_scan(
            table_name=self.retry_tasks_table)
        self.assertEqual(len(retry_table_items), len(retry_tasks))

        # Use get_tasks_to_retry_for_labourer to get tasks
        tasks = self.manager.get_tasks_to_retry_for_labourer(labourer)

        # Call
        self.manager.retry_tasks(labourer, tasks)

        # Check removed 2 out of 3 tasks from retry queue. One is desired to be launched later.
        retry_table_items = self.dynamo_client.get_by_scan(
            table_name=self.retry_tasks_table)
        self.assertEqual(len(retry_table_items), 1)

        # Check tasks moved to `tasks_table` with lowest greenfields
        tasks_table_items = self.dynamo_client.get_by_scan()
        for x in tasks_table_items:
            print(x)
        self.assertEqual(len(tasks_table_items), 4)

        for reg_task in regular_tasks:
            self.assertIn(reg_task, tasks_table_items)

        for retry_task in retry_tasks:
            try:
                matching = next(x for x in tasks_table_items
                                if x[_('task_id')] == retry_task[_('task_id')])
            except StopIteration:
                print(
                    f"Task not retried {retry_task}. Probably not yet desired."
                )
                continue

            for k in retry_task.keys():
                if k not in [_('greenfield'), _('desired_launch_time')]:
                    self.assertEqual(retry_task[k], matching[k])

            for k in matching.keys():
                if k != _('greenfield'):
                    self.assertEqual(retry_task[k], matching[k])

            print(
                f"New greenfield of a retried task: {matching[_('greenfield')]}"
            )
            self.assertTrue(matching[_('greenfield')] < min(
                [x[_('greenfield')] for x in regular_tasks]))

    # Pass the fake version as `new` so that `boto3.__version__` really is the string '1.9.53'.
    # With `return_value=...` the attribute became a MagicMock (and the mock was injected as an extra argument).
    @patch.object(boto3, '__version__', '1.9.53')
    def test_retry_tasks__old_boto(self):
        """Same as ``test_retry_tasks`` while boto3 reports an old version string."""
        self.test_retry_tasks()

    def test_get_oldest_greenfield_for_labourer__get_newest_greenfield_for_labourer(
            self):
        """The oldest/newest greenfield helpers return the min/max greenfield of the labourer's tasks."""
        with patch('time.time') as t:
            t.return_value = 9500
            labourer = self.manager.register_labourers()[0]

        greenfields = []
        for i in range(5):  # Ran this with range(1000), it passes :)
            gf = random.randint(10000, 20000)
            greenfields.append(gf)
            row = {
                'labourer_id': f"{labourer.id}",
                'task_id': f"task-{i}",
                'greenfield': gf
            }
            self.dynamo_client.put(row)
            time.sleep(
                0.1
            )  # Sleep a little to fit the Write Capacity (10 WCU) of autotest table.

        result = self.manager.get_oldest_greenfield_for_labourer(labourer)
        self.assertEqual(min(greenfields), result)

        newest = self.manager.get_newest_greenfield_for_labourer(labourer)
        self.assertEqual(max(greenfields), newest)

    def test_get_length_of_queue_for_labourer(self):
        """The queue length equals the number of tasks stored for the labourer."""
        labourer = Labourer(id='some_lambda', arn='some_arn')

        num_of_tasks = 3  # Ran this with 464 tasks and it worked

        for i in range(num_of_tasks):
            row = {
                'labourer_id': labourer.id,
                'task_id': f"task-{i}",
                'greenfield': i
            }
            self.dynamo_client.put(row)
            time.sleep(
                0.1
            )  # Sleep a little to fit the Write Capacity (10 WCU) of autotest table.

        queue_len = self.manager.get_length_of_queue_for_labourer(labourer)

        self.assertEqual(queue_len, num_of_tasks)

    def test_get_average_labourer_duration__calculates_average__only_failing_tasks(
            self):
        """With only failed tasks the average equals the labourer's max duration."""
        self.manager.ecology_client.get_max_labourer_duration.return_value = 900
        some_labourer = self.register_labourers()[0]

        self.setup_tasks(status='failed', count_tasks=15)

        self.assertEqual(
            900, self.manager.get_average_labourer_duration(some_labourer))

    def test_get_average_labourer_duration__calculates_average(self):
        """With mixed closed and failed tasks the average lies between 10 and the max duration."""
        self.manager.ecology_client.get_max_labourer_duration.return_value = 900
        some_labourer = self.register_labourers()[0]

        self.setup_tasks(status='closed', count_tasks=15)
        self.setup_tasks(status='failed', count_tasks=15)

        average = self.manager.get_average_labourer_duration(some_labourer)
        self.assertLessEqual(average, 900)
        self.assertGreaterEqual(average, 10)
示例#4
0
class DynamoConfigTestCase(unittest.TestCase):
    """Integration tests for ``DynamoConfig`` backed by the ``autotest_config_component`` table."""

    TEST_CONFIG = {
        'row_mapper': {
            'env': 'S',
            'config_name': 'S',
            'config_value': 'S'
        },
        'required_fields': ['env', 'config_name', 'config_value'],
        'table_name': 'autotest_config_component'
    }

    def setUp(self):
        """Create a raw DynamoDb client for seeding rows and the ``DynamoConfig`` under test."""
        config = self.TEST_CONFIG.copy()
        self.dynamo_client = DynamoDbClient(config)
        self.dynamo_config = DynamoConfig(test=True)

    def tearDown(self):
        """Remove every row written by the test."""
        # Reuse the configured table name instead of repeating the literal, so the two cannot drift apart.
        clean_dynamo_table(self.TEST_CONFIG['table_name'],
                           keys=('env', 'config_name'))

    @unittest.skip("TODO need normal patching")
    def test_get_config__json(self):
        """A JSON ``config_value`` is returned decoded."""
        row = {
            'env': 'production',
            'config_name': 'sophie_test',
            'config_value': '{"a": 1}'
        }
        self.dynamo_client.put(row)

        config = self.dynamo_config.get_config('sophie_test', "production")

        self.assertEqual(config, {'a': 1})

    @unittest.skip("TODO need normal patching")
    def test_get_config__str(self):
        """A non-JSON ``config_value`` is returned as a plain string."""
        def get_by_query(*args, **kwargs):
            return [{
                'env': 'production',
                'config_name': 'sophie_test2',
                'config_value': 'some text'
            }]

        self.dynamo_config.dynamo_client = FakeDynamo
        with patch.object(FakeDynamo, 'get_by_query', new=get_by_query):
            config = self.dynamo_config.get_config('sophie_test2',
                                                   "production")
            self.assertEqual(config, 'some text')

    def test_get_config__doesnt_exist(self):
        """A missing config name yields an empty dict."""
        config = self.dynamo_config.get_config('sophie_test', "production")
        self.assertEqual(config, {})

    def test_get_credentials_by_prefix(self):
        """Only the 'testing_*' rows are returned, keyed by the name without that prefix."""
        SAMPLES = [{
            'env': 'dev',
            'config_name': 'testing_zz1',
            'config_value': '{"a": 1}'
        }, {
            'env': 'dev',
            'config_name': 'testing_zz2',
            'config_value': 'zz2_value'
        }, {
            'env': 'dev',
            'config_name': 'dont_get_this',
            'config_value': '{"b": 2}'
        }, {
            'env': 'dev',
            'config_name': 'testingab2',
            'config_value': 'some text'
        }]

        for row in SAMPLES:
            self.dynamo_client.put(row)

        result = self.dynamo_config.get_credentials_by_prefix('testing')

        # 'testingab2' also starts with 'testing' but is expected to be excluded.
        self.assertEqual(len(result), 2)
        self.assertIn('zz1', result)
        self.assertEqual(result['zz2'], 'zz2_value')

    def test_update_config(self):
        """``update_config`` writes the value under the 'dev' env."""
        KEY = 'testing_update_method'
        VALUE = 'exists'
        self.dynamo_config.update_config(name=KEY, val=VALUE)
        result = self.dynamo_client.get_by_query(keys={
            'env': 'dev',
            'config_name': KEY
        })
        self.assertEqual(
            {
                'env': 'dev',
                'config_name': KEY,
                'config_value': VALUE
            }, result[0])
示例#5
0
class dynamodb_client_UnitTestCase(unittest.TestCase):
    # Client config for the DynamoDbClient built in setUp (boto3.client is patched there).
    # 'row_mapper' maps attribute names to DynamoDB type descriptors ('S', 'N', 'BOOL', 'M', 'L').
    TEST_CONFIG = {
        'row_mapper': {
            'lambda_name': 'S',
            'invocation_id': 'S',
            'en_time': 'N',
            'hash_col': 'S',
            'range_col': 'N',
            'other_col': 'S',
            'new_col': 'S',
            'some_col': 'S',
            'some_counter': 'N',
            'some_bool': 'BOOL',
            'some_bool2': 'BOOL',
            'some_map': 'M',
            'some_list': 'L'
        },
        # NOTE(review): presumably validated by the client on writes — TODO confirm.
        'required_fields': ['lambda_name'],
        'table_name': 'autotest_dynamo_db',
        # Removed by test_create__raises__if_no_hash_col_configured to make `create` raise AssertionError.
        'hash_key': 'hash_col',
    }

    def setUp(self):
        """Patch ``boto3.client`` so the DynamoDbClient under test talks to a MagicMock, not AWS."""
        self.HASH_KEY = ('hash_col', 'S')
        self.RANGE_KEY = ('range_col', 'N')
        self.KEYS = (self.HASH_KEY[0], self.RANGE_KEY[0])
        self.table_name = 'autotest_dynamo_db'

        # Fake low-level client; its paginator is a separate mock so tests can script pages.
        self.paginator_mock = MagicMock()
        self.dynamo_mock = MagicMock()
        self.dynamo_mock.get_paginator.return_value = self.paginator_mock

        self.patcher = patch("boto3.client")
        self.boto3_client_patch = self.patcher.start()
        self.boto3_client_patch.return_value = self.dynamo_mock

        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)

    def tearDown(self):
        """Undo the ``boto3.client`` patch installed in setUp."""
        boto3_patcher = self.patcher
        boto3_patcher.stop()

    def test_create__raises__if_no_hash_col_configured(self):
        """``create`` must refuse to run when the client config has no ``hash_key``."""
        bad_config = deepcopy(self.TEST_CONFIG)
        del bad_config['hash_key']

        dynamo_client = DynamoDbClient(config=bad_config)

        # Row keys are attribute *names*; HASH_KEY / RANGE_KEY are (name, type) tuples, so index [0].
        row = {self.HASH_KEY[0]: 'cat', self.RANGE_KEY[0]: '123'}
        self.assertRaises(AssertionError, dynamo_client.create, row,
                          self.table_name)

    def test_create__calls_boto_client(self):
        """Writing a row issues exactly one ``put_item`` call on the boto3 client."""
        self.dynamo_mock.put_item.assert_not_called()

        # Row keys are attribute *names*; HASH_KEY / RANGE_KEY are (name, type) tuples, so index [0].
        self.dynamo_client.put({
            self.HASH_KEY[0]: 'cat',
            self.RANGE_KEY[0]: '123'
        }, self.table_name)
        self.dynamo_mock.put_item.assert_called_once()

    def test_dict_to_dynamo_strict(self):
        """Strict conversion types each mapped field by row_mapper, recursing into maps and lists."""
        dict_row = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456,
            'some_bool': True,
            'some_bool2': 'True',
            'some_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
            'some_list': ['x', 'y'],
        }
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row)

        expected = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'some_bool': {'BOOL': True},
            'some_bool2': {'BOOL': True},  # the string 'True' is expected to be coerced for a BOOL column
            'some_map': {'M': {
                'a': {'N': '1'},
                'b': {'S': 'b1'},
                'c': {'M': {'test': {'BOOL': True}}},
            }},
            'some_list': {'L': [{'S': 'x'}, {'S': 'y'}]},
        }
        for field_name, expected_value in expected.items():
            self.assertDictEqual(expected_value, dynamo_row[field_name])

    def test_dict_to_dynamo_not_strict(self):
        """Non-strict conversion infers DynamoDB types for fields absent from row_mapper."""
        dict_row = {
            'name': 'cat',
            'age': 3,
            'other_bool': False,
            'other_bool2': 'False',
            'other_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
            'some_list': ['x', 'y'],
        }
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row, strict=False)

        # 'other_bool2' is intentionally not checked.
        expected = {
            'name': {'S': 'cat'},
            'age': {'N': '3'},
            'other_bool': {'BOOL': False},
            'other_map': {'M': {
                'a': {'N': '1'},
                'b': {'S': 'b1'},
                'c': {'M': {'test': {'BOOL': True}}},
            }},
            'some_list': {'L': [{'S': 'x'}, {'S': 'y'}]},
        }
        for field_name, expected_value in expected.items():
            self.assertDictEqual(expected_value, dynamo_row[field_name])

    def test_dict_to_dynamo__not_strict__map_type(self):
        """A nested dict whose keys contain '/' and '*' (MIME types) is serialised as a DynamoDB map."""
        mimetypes = {
            'image/webp': 1,
            'image/apng': 1,
            'image/*': 1,
            '*/*': 0.8
        }
        dict_row = {
            'accept_mimetypes': mimetypes,
        }
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row, strict=False)
        logging.info(f"dynamo_row: {dynamo_row}")

        # Assert the structure explicitly; comparing against an empty `expected` dict would check nothing.
        self.assertIn('accept_mimetypes', dynamo_row)
        self.assertIn('M', dynamo_row['accept_mimetypes'])
        dynamo_map = dynamo_row['accept_mimetypes']['M']
        self.assertEqual(set(mimetypes), set(dynamo_map))
        for int_key in ('image/webp', 'image/apng', 'image/*'):
            self.assertDictEqual({'N': '1'}, dynamo_map[int_key])

    def test_dict_to_dynamo_prefix(self):
        """``add_prefix`` is prepended to every attribute name of the converted row."""
        dict_row = {'hash_col': 'cat', 'range_col': '123', 'some_col': 'no'}
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row, add_prefix="#")

        expected = {
            '#hash_col': {'S': 'cat'},
            '#range_col': {'N': '123'},  # typed as 'N' by row_mapper even though the input is a str
            '#some_col': {'S': 'no'},
        }
        for prefixed_name, expected_value in expected.items():
            self.assertDictEqual(expected_value, dynamo_row[prefixed_name])

    def test_dynamo_to_dict(self):
        """Strict conversion back to Python: unmapped attributes are dropped and numbers are not Decimal."""
        dynamo_row = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'extra_key': {'N': '42'},  # not in row_mapper, so it must not appear in the result
            'some_bool': {'BOOL': False},
            'some_map': {'M': {
                'a': {'N': '1'},
                'b': {'S': 'b1'},
                'c': {'M': {'test': {'BOOL': True}}},
            }},
            'some_list': {'L': [{'S': 'x'}, {'S': 'y'}]},
        }
        dict_row = self.dynamo_client.dynamo_to_dict(dynamo_row)

        expected = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456,
            'some_bool': False,
            'some_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
            'some_list': ['x', 'y'],
        }
        self.assertDictEqual(expected, dict_row)

        # Top-level values and values inside the map must be plain Python types, never Decimal.
        for value in [*dict_row.values(), *dict_row['some_map'].values()]:
            self.assertNotIsInstance(value, Decimal)

    def test_dynamo_to_dict_no_strict_row_mapper(self):
        """With ``fetch_all_fields=True`` attributes missing from row_mapper are converted by their own type."""
        dynamo_row = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'extra_key_n': {'N': '42'},
            'extra_key_s': {'S': 'wowie'},
            'other_bool': {'BOOL': True},
        }
        dict_row = self.dynamo_client.dynamo_to_dict(dynamo_row, fetch_all_fields=True)

        expected = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456,
            'extra_key_n': 42,
            'extra_key_s': 'wowie',
            'other_bool': True,
        }
        self.assertDictEqual(dict_row, expected)

        # Numbers must come back as plain Python numbers, not Decimal.
        for value in dict_row.values():
            self.assertNotIsInstance(value, Decimal)

    def test_dynamo_to_dict__dont_json_loads(self):
        """With ``dont_json_loads_results`` set, JSON-looking strings are returned untouched."""
        config = self.TEST_CONFIG.copy()
        config['dont_json_loads_results'] = True
        self.dynamo_client = DynamoDbClient(config=config)

        dynamo_row = {
            'hash_col': {'S': 'aaa'},
            'range_col': {'N': '123'},
            'other_col': {'S': '{"how many": 300}'},
            'duck_quack': {'S': '{"quack": "duck"}'},
        }
        mapped_fields = {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': '{"how many": 300}',
        }

        # fetch_all_fields=True also returns 'duck_quack', which is not in row_mapper.
        res = self.dynamo_client.dynamo_to_dict(dynamo_row, fetch_all_fields=True)
        self.assertDictEqual(res, {**mapped_fields, 'duck_quack': '{"quack": "duck"}'})

        # fetch_all_fields=False keeps only the row_mapper fields.
        res = self.dynamo_client.dynamo_to_dict(dynamo_row, fetch_all_fields=False)
        self.assertDictEqual(res, mapped_fields)

    def test_dynamo_to_dict__do_json_loads(self):
        """With ``dont_json_loads_results`` disabled, JSON strings in 'S' columns are decoded into dicts."""
        self.dynamo_client = DynamoDbClient(config={**self.TEST_CONFIG, 'dont_json_loads_results': False})

        raw_item = {
            'hash_col': {'S': 'aaa'},
            'range_col': {'N': '123'},
            'other_col': {'S': '{"how many": 300}'},
            'duck_quack': {'S': '{"quack": "duck"}'},
        }
        every_column = {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': {"how many": 300},
            'duck_quack': {"quack": "duck"},
        }
        # Without fetch_all_fields, `duck_quack` is expected to be left out.
        mapped_columns = {name: value for name, value in every_column.items() if name != 'duck_quack'}

        self.assertDictEqual(self.dynamo_client.dynamo_to_dict(raw_item, fetch_all_fields=True), every_column)
        self.assertDictEqual(self.dynamo_client.dynamo_to_dict(raw_item, fetch_all_fields=False), mapped_columns)

    def test_dynamo_to_dict__mapping_doesnt_match__raises(self):
        """
        If the value type stored in the DB doesn't match the type declared in row_mapper, dynamo_to_dict must raise
        ValueError whose message names the column, the declared type and the real type.
        """
        dynamo_row = {
            'hash_col': {'S': 'aaa'},
            'range_col': {'N': '123'},
            'other_col': {'N': '111'},  # In the row_mapper, other_col is of type 'S'
        }

        # The return value is irrelevant: the call is expected to raise.
        with self.assertRaises(ValueError) as ctx:
            self.dynamo_client.dynamo_to_dict(dynamo_row)

        self.assertEqual(
            "'other_col' is expected to be of type 'S' in row_mapper, but real value is of type 'N'",
            str(ctx.exception))

    def test_get_by_query__validates_comparison(self):
        """An unknown comparison operator for a key is rejected with AssertionError."""
        with self.assertRaises(AssertionError):
            self.dynamo_client.get_by_query(keys={'k': '1'}, comparisons={'k': 'unsupported'})

    def test_get_by_query__between(self):
        """``st_between_``/``en_between_`` key prefixes produce a single BETWEEN condition on the range key."""
        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)
        query_keys = {
            'hash_col': 'cat',
            'st_between_range_col': '3',
            'en_between_range_col': '6',
        }

        self.dynamo_client.get_by_query(keys=query_keys)

        _, paginate_kwargs = self.paginator_mock.paginate.call_args
        # One placeholder for the hash key plus two for the BETWEEN bounds.
        self.assertEqual(len(paginate_kwargs['ExpressionAttributeValues']), 3)
        self.assertIn('range_col between :st_between_range_col and :en_between_range_col',
                      paginate_kwargs['KeyConditionExpression'])

    def test_get_by_query__return_count(self):
        """With return_count=True, get_by_query sums ``Count`` over every page yielded by the paginator."""
        pages = [
            {'Count': 24, 'LastEvaluatedKey': 'bzz'},
            {'Count': 12},
        ]
        self.paginator_mock.paginate.return_value = pages
        self.dynamo_client.dynamo_client.get_paginator.return_value = self.paginator_mock

        total = self.dynamo_client.get_by_query(keys={'a': 'b'}, return_count=True)

        self.assertEqual(total, 36, f"Result from 2 pages should be 24 + 12, but we received: {total}")
        # The count must have been gathered through the (mocked) paginator.
        self.dynamo_client.dynamo_client.get_paginator.assert_called()

    def test_get_by_query__expr_attr(self):
        """Names passed in ``expr_attrs_names`` are aliased with '#' in both the names map and the key condition."""
        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)
        query_keys = {
            'st_between_range_col': '3',
            'en_between_range_col': '6',
            'session': 'ses1',
        }

        self.dynamo_client.get_by_query(keys=query_keys, expr_attrs_names=['range_col', 'session'])

        _, paginate_kwargs = self.paginator_mock.paginate.call_args
        for alias in ('#range_col', '#session'):
            self.assertIn(alias, paginate_kwargs['ExpressionAttributeNames'])
        self.assertIn(
            '#range_col between :st_between_range_col and :en_between_range_col AND #session = :session',
            paginate_kwargs['KeyConditionExpression'])

    def test__parse_filter_expression(self):
        """Valid filter strings are parsed into an expression with placeholders plus its attribute values."""
        cases = {
            'key = 42': ("key = :filter_key", {":filter_key": {'N': '42'}}),
            # Leading, trailing and repeated whitespace must be tolerated.
            '   key    = 42  ': ("key = :filter_key", {":filter_key": {'N': '42'}}),
            'cat = meaw': ("cat = :filter_cat", {":filter_cat": {'S': 'meaw'}}),
            'magic between 41 and 42': (
                "magic between :st_between_magic and :en_between_magic",
                {":st_between_magic": {'N': '41'}, ":en_between_magic": {'N': '42'}}),
            'attribute_not_exists boo': ("attribute_not_exists (boo)", {}),
        }

        for data, expected in cases.items():
            # subTest reports every failing input instead of stopping at the first one.
            with self.subTest(data=data):
                self.assertEqual(self.dynamo_client._parse_filter_expression(data), expected)

    def test__parse_filter_expression__raises(self):
        """Malformed filter expressions are rejected with AssertionError or ValueError."""
        cases = [
            # Invalid input types.
            {'k': 1},
            [1, 2],
            None,
            # Invalid operators.
            'key == 42',
            'foo ~ 1',
            'foo3 <=> 0',
            # Invalid `between` formats.
            'key between 42',
            'key between 23, 25',
            'key between [23, 25]',
            'key 23 between 21',
        ]

        for data in cases:
            # subTest reports every input that fails to raise, not just the first one.
            with self.subTest(data=data):
                with self.assertRaises((AssertionError, ValueError)):
                    self.dynamo_client._parse_filter_expression(data)

    def test_create__calls_put(self):
        """create() delegates to put() with ``overwrite_existing=False`` and no explicit table name."""
        new_row = {'hash_col': 'cat', 'range_key': 'test', 'another_col': 'wow'}
        self.dynamo_client.put = MagicMock(return_value=None)

        self.dynamo_client.create(new_row)

        self.dynamo_client.put.assert_called_once_with(new_row, None, overwrite_existing=False)

    def test_batch_get_items_one_table__strict(self):
        """Strict mode (fetch_all_fields=False): ``unknown_col`` is dropped from the returned rows."""
        stored_item = {
            'hash_col': {'S': 'b'},
            'range_col': {'N': '10'},
            'unknown_col': {'S': 'not_strict'},
        }
        self.dynamo_client.dynamo_client.batch_get_item = Mock(
            return_value={'Responses': {'autotest_dynamo_db': [stored_item]}})

        fetched = self.dynamo_client.batch_get_items_one_table(keys_list=[{'hash_col': 'b'}],
                                                               fetch_all_fields=False)

        self.assertEqual(fetched, [{'hash_col': 'b', 'range_col': 10}])

    def test_batch_get_items_one_table__not_strict(self):
        """Not strict (fetch_all_fields=True): every stored column, including ``unknown_col``, is returned."""
        stored_item = {
            'hash_col': {'S': 'b'},
            'range_col': {'N': '10'},
            'unknown_col': {'S': 'not_strict'},
        }
        self.dynamo_client.dynamo_client.batch_get_item = Mock(
            return_value={'Responses': {'autotest_dynamo_db': [stored_item]}})

        fetched = self.dynamo_client.batch_get_items_one_table(keys_list=[{'hash_col': 'b'}],
                                                               fetch_all_fields=True)

        self.assertEqual(fetched, [{'hash_col': 'b', 'range_col': 10, 'unknown_col': 'not_strict'}])

    def test_get_by_query__max_items_and_count__raises(self):
        """Combining ``max_items`` with ``return_count`` is rejected with an explanatory message."""
        expected_msg = "DynamoDbCLient.get_by_query does not support `max_items` and `return_count` together"

        with self.assertRaises(Exception) as ctx:
            self.dynamo_client.get_by_query({'hash_col': 'key'}, table_name=self.table_name,
                                            max_items=3, return_count=True)

        self.assertEqual(ctx.exception.args[0], expected_msg)

    def test_patch__transfers_attrs_to_remove(self):
        """
        patch() forwards all its arguments, including ``attributes_to_remove``, to update() as keywords and adds an
        ``attribute_exists`` condition on the hash key. Checked for both keyword and positional calls.
        """
        keys = {'hash_col': 'a'}
        attributes_to_update = {'some_col': 'b'}
        attributes_to_increment = {'some_counter': 3}
        table_name = 'the_table'
        attributes_to_remove = ['remove_me']

        expected_update_kwargs = dict(
            keys=keys,
            attributes_to_update=attributes_to_update,
            attributes_to_increment=attributes_to_increment,
            table_name=table_name,
            attributes_to_remove=attributes_to_remove,
            condition_expression='attribute_exists hash_col',
        )

        # Keyword-argument call.
        self.dynamo_client.update = Mock()
        self.dynamo_client.patch(keys=keys, attributes_to_update=attributes_to_update,
                                 attributes_to_increment=attributes_to_increment, table_name=table_name,
                                 attributes_to_remove=attributes_to_remove)
        self.dynamo_client.update.assert_called_once_with(**expected_update_kwargs)

        # Positional-argument call.
        self.dynamo_client.update = Mock()
        self.dynamo_client.patch(keys, attributes_to_update, attributes_to_increment, table_name,
                                 attributes_to_remove)
        self.dynamo_client.update.assert_called_once_with(**expected_update_kwargs)

    def test_sleep_db__get_capacity_called(self):
        """sleep_db calls ``describe_table`` on the boto3 client exactly once."""
        self.dynamo_client.dynamo_client = MagicMock()
        just_now = datetime.datetime.now()

        self.dynamo_client.sleep_db(last_action_time=just_now, action='write', table_name='autotest_new')

        self.dynamo_client.dynamo_client.describe_table.assert_called_once()

    def test_sleep_db__wrong_action(self):
        """An unknown action name ('call') makes sleep_db raise KeyError."""
        with self.assertRaises(KeyError):
            self.dynamo_client.sleep_db(last_action_time=datetime.datetime.now(), action='call')

    @patch.object(time, 'sleep')
    def test_sleep_db__fell_asleep__provisioned(self, mock_sleep):
        """
        Test for table if BillingMode is PROVISIONED: sleep_db sleeps once, for about ``1 / write capacity`` seconds
        minus the time already elapsed since the last action.
        """
        # NOTE: renamed from `test_sleep_db__fell_asleep`, which a later method of the class reused, so this test
        # was shadowed and never ran.
        capacity = {'read': 10, 'write': 5}
        self.dynamo_client.get_capacity = MagicMock(return_value=capacity)

        time_between_ms = 100
        last_action_time = datetime.datetime.now() - datetime.timedelta(milliseconds=time_between_ms)
        self.dynamo_client.sleep_db(last_action_time=last_action_time, action='write')

        # Check that went to sleep exactly once.
        self.assertEqual(mock_sleep.call_count, 1)
        args, _ = mock_sleep.call_args

        max_sleep_sec = 1 / capacity['write']
        # The elapsed time is in milliseconds and the sleep duration in seconds, so convert before subtracting
        # (subtracting raw milliseconds made the lower bound negative and the check meaningless).
        # The extra 0.02 sec allows for code execution time between building `last_action_time` and the sleep call.
        min_sleep_sec = max_sleep_sec - time_between_ms / 1000 - 0.02
        self.assertGreater(args[0], min_sleep_sec)
        self.assertLess(args[0], max_sleep_sec)

    @patch.object(time, 'sleep')
    def test_sleep_db__ppr_doesnt_sleep(self, mock_sleep):
        """
        Test for table if BillingMode is PAY_PER_REQUEST (capacity reported as 0): sleep_db never sleeps.
        """
        # NOTE: this used to share the name `test_sleep_db__fell_asleep` with the PROVISIONED test defined earlier
        # in the class, silently shadowing it. A distinct name lets both tests run.
        self.dynamo_client.get_capacity = MagicMock(return_value={'read': 0, 'write': 0})

        # Neither an action that has just happened...
        self.dynamo_client.sleep_db(last_action_time=datetime.datetime.now(), action='write')

        # ...nor one from a moment ago may trigger a sleep.
        time_between_ms = 100
        last_action_time = datetime.datetime.now() - datetime.timedelta(milliseconds=time_between_ms)
        self.dynamo_client.sleep_db(last_action_time=last_action_time, action='write')

        self.assertEqual(mock_sleep.call_count, 0)

    @patch.object(time, 'sleep')
    def test_sleep_db__(self, mock_sleep):
        """No sleep happens when the last action is older than ``1 / write capacity`` seconds."""
        self.dynamo_client.get_capacity = MagicMock(return_value={'read': 10, 'write': 5})

        # 900 ms ago is longer than 1/5 sec, so no throttling sleep is expected.
        long_ago = datetime.datetime.now() - datetime.timedelta(milliseconds=900)
        self.dynamo_client.sleep_db(last_action_time=long_ago, action='write')

        self.assertEqual(mock_sleep.call_count, 0)

    @patch.object(time, 'sleep')
    def test_sleep_db__returns_none_for_on_demand(self, mock_sleep):
        """A table whose description carries no provisioned throughput never makes sleep_db sleep."""
        self.dynamo_client.dynamo_client = MagicMock()
        self.dynamo_client.dynamo_client.describe_table.return_value = {'TableName': 'autotest_OnDemand'}

        # Even a very recent previous action (10 ms ago) must not cause a sleep.
        recent_action = datetime.datetime.now() - datetime.timedelta(milliseconds=10)
        self.dynamo_client.sleep_db(last_action_time=recent_action, action='write',
                                    table_name='autotest_OnDemand')

        self.assertEqual(mock_sleep.call_count, 0, "Should not have called time.sleep")

    def test_on_demand_provisioned_throughput__get_capacity(self):
        """get_capacity returns None when the table description has no provisioned throughput."""
        boto_client = MagicMock()
        boto_client.describe_table.return_value = {'TableName': 'autotest_OnDemand'}
        self.dynamo_client.dynamo_client = boto_client

        self.assertIsNone(self.dynamo_client.get_capacity(table_name='autotest_OnDemand'))

    def test_on_demand_provisioned_throughput__get_table_indexes(self):
        """An index of a table described without throughput data gets no ``ProvisionedThroughput`` entry."""
        global_index = {
            'IndexName': 'IndexA',
            'KeySchema': [{'AttributeName': 'SomeAttr', 'KeyType': 'HASH'}],
            'Projection': {'ProjectionType': 'ALL'},
        }
        boto_client = MagicMock()
        boto_client.describe_table.return_value = {
            'Table': {
                'TableName': 'autotest_OnDemandTable',
                'LocalSecondaryIndexes': [],
                'GlobalSecondaryIndexes': [global_index],
            }
        }
        self.dynamo_client.dynamo_client = boto_client

        indexes = self.dynamo_client.get_table_indexes(table_name='autotest_OnDemandTable')

        self.assertIsNone(indexes['IndexA'].get('ProvisionedThroughput'))

    def test_get_table_indexes__ppr(self):
        """ Check return value of get_table_indexes function in case table BillingMode is PAY_PER_REQUEST """
        self.dynamo_client._describe_table = Mock(return_value=PPR_DESCRIBE_TABLE)

        def index_description(hash_key):
            # Every index in the PAY_PER_REQUEST fixture is expected with zero provisioned capacity.
            return {
                'projection_type': 'ALL',
                'hash_key': hash_key,
                'provisioned_throughput': {'write_capacity': 0, 'read_capacity': 0},
            }

        expected_indexes = {name: index_description(name) for name in ('session', 'session_id')}

        self.assertEqual(expected_indexes, self.dynamo_client.get_table_indexes('actions'))

    def test_get_table_indexes__pt(self):
        """ Check return value of get_table_indexes function in case table BillingMode is PROVISIONED """
        self.dynamo_client._describe_table = Mock(return_value=PT_DESCRIBE_TABLE)

        def index_description(hash_key):
            # Every index in the PROVISIONED fixture is expected with 10 write / 100 read capacity.
            return {
                'projection_type': 'ALL',
                'hash_key': hash_key,
                'provisioned_throughput': {'write_capacity': 10, 'read_capacity': 100},
            }

        expected_indexes = {name: index_description(name) for name in ('name', 'city')}

        self.assertEqual(expected_indexes, self.dynamo_client.get_table_indexes('partners'))
示例#6
0
class dynamodb_client_UnitTestCase(unittest.TestCase):
    """
    Unit tests for DynamoDbClient. ``boto3.client`` is patched in setUp, so the client under test talks to
    MagicMocks instead of AWS.
    """

    # row_mapper: column name -> DynamoDB type tag ('S', 'N', 'BOOL', 'M', 'L') used for conversions.
    TEST_CONFIG = {
        'row_mapper': {
            'lambda_name': 'S',
            'invocation_id': 'S',
            'en_time': 'N',
            'hash_col': 'S',
            'range_col': 'N',
            'other_col': 'S',
            'new_col': 'S',
            'some_col': 'S',
            'some_counter': 'N',
            'some_bool': 'BOOL',
            'some_bool2': 'BOOL',
            'some_map': 'M',
            'some_list': 'L'
        },
        'required_fields': ['lambda_name'],
        'table_name': 'autotest_dynamo_db'
    }

    def setUp(self):
        self.HASH_KEY = ('hash_col', 'S')
        self.RANGE_KEY = ('range_col', 'N')
        self.KEYS = ('hash_col', 'range_col')
        self.table_name = 'autotest_dynamo_db'

        # Every boto3.client(...) call returns `dynamo_mock`, whose paginator is `paginator_mock`.
        self.patcher = patch("boto3.client")
        self.paginator_mock = MagicMock()
        self.dynamo_mock = MagicMock()
        self.dynamo_mock.get_paginator.return_value = self.paginator_mock

        self.boto3_client_patch = self.patcher.start()
        # Registered immediately after start() so the patch is undone even if the rest of setUp raises
        # (tearDown is not called in that case, so stopping it there could leak the patch into other tests).
        self.addCleanup(self.patcher.stop)
        self.boto3_client_patch.return_value = self.dynamo_mock

        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)

    def test_dict_to_dynamo_strict(self):
        """In strict mode values are converted to the row_mapper types, e.g. the string 'True' becomes BOOL True."""
        dict_row = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456,
            'some_bool': True,
            'some_bool2': 'True',
            'some_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
            'some_list': ['x', 'y'],
        }
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row)
        expected = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'some_bool': {'BOOL': True},
            'some_bool2': {'BOOL': True},
            'some_map': {'M': {'a': {'N': '1'}, 'b': {'S': 'b1'}, 'c': {'M': {'test': {'BOOL': True}}}}},
            'some_list': {'L': [{'S': 'x'}, {'S': 'y'}]},
        }
        for key, expected_value in expected.items():
            with self.subTest(key=key):
                self.assertDictEqual(expected_value, dynamo_row[key])

    def test_dict_to_dynamo_not_strict(self):
        """With strict=False, columns unknown to row_mapper are converted by inspecting their Python values."""
        dict_row = {
            'name': 'cat',
            'age': 3,
            'other_bool': False,
            'other_bool2': 'False',
            'other_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
            'some_list': ['x', 'y'],
        }
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row, strict=False)
        expected = {
            'name': {'S': 'cat'},
            'age': {'N': '3'},
            'other_bool': {'BOOL': False},
            'other_map': {'M': {'a': {'N': '1'}, 'b': {'S': 'b1'}, 'c': {'M': {'test': {'BOOL': True}}}}},
            'some_list': {'L': [{'S': 'x'}, {'S': 'y'}]},
        }
        for key, expected_value in expected.items():
            with self.subTest(key=key):
                self.assertDictEqual(expected_value, dynamo_row[key])

    def test_dict_to_dynamo__not_strict__map_type(self):
        """
        With strict=False, a dict value of a column unknown to row_mapper is serialized as a DynamoDB map ('M') with
        the same keys and numeric ('N') entries, including a float.
        """
        # NOTE: this test previously compared against an empty `expected` dict, so it asserted nothing at all.
        dict_row = {
            'accept_mimetypes': {'image/webp': 1, 'image/apng': 1, 'image/*': 1, '*/*': 0.8},
        }
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row, strict=False)

        self.assertIn('accept_mimetypes', dynamo_row)
        serialized = dynamo_row['accept_mimetypes']
        self.assertEqual(['M'], list(serialized), f"Expected a map type, got: {serialized}")
        self.assertEqual(set(dict_row['accept_mimetypes']), set(serialized['M']))
        for mimetype, value in serialized['M'].items():
            with self.subTest(mimetype=mimetype):
                self.assertIn('N', value)

    def test_dict_to_dynamo_prefix(self):
        """``add_prefix`` is prepended to every key of the converted row."""
        dict_row = {'hash_col': 'cat', 'range_col': '123', 'some_col': 'no'}
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row, add_prefix="#")
        expected = {
            '#hash_col': {'S': 'cat'},
            '#range_col': {'N': '123'},
            '#some_col': {'S': 'no'},
        }
        for key, expected_value in expected.items():
            with self.subTest(key=key):
                self.assertDictEqual(expected_value, dynamo_row[key])

    def test_dynamo_to_dict(self):
        """By default only row_mapper columns are returned (extra_key is dropped) and numbers are never Decimal."""
        dynamo_row = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'extra_key': {'N': '42'},
            'some_bool': {'BOOL': False},
            'some_map': {'M': {'a': {'N': '1'}, 'b': {'S': 'b1'}, 'c': {'M': {'test': {'BOOL': True}}}}},
            'some_list': {'L': [{'S': 'x'}, {'S': 'y'}]},
        }
        dict_row = self.dynamo_client.dynamo_to_dict(dynamo_row)
        expected = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456,
            'some_bool': False,
            'some_map': {'a': 1, 'b': 'b1', 'c': {'test': True}},
            'some_list': ['x', 'y'],
        }
        self.assertDictEqual(expected, dict_row)
        for value in dict_row.values():
            self.assertNotIsInstance(value, Decimal)
        for value in dict_row['some_map'].values():
            self.assertNotIsInstance(value, Decimal)

    def test_dynamo_to_dict_no_strict_row_mapper(self):
        """With fetch_all_fields=True, columns missing from row_mapper are converted by their type tag."""
        dynamo_row = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'extra_key_n': {'N': '42'},
            'extra_key_s': {'S': 'wowie'},
            'other_bool': {'BOOL': True},
        }
        dict_row = self.dynamo_client.dynamo_to_dict(dynamo_row, fetch_all_fields=True)
        expected = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456,
            'extra_key_n': 42,
            'extra_key_s': 'wowie',
            'other_bool': True,
        }
        self.assertDictEqual(dict_row, expected)
        for value in dict_row.values():
            self.assertNotIsInstance(value, Decimal)

    def test_dynamo_to_dict__dont_json_loads(self):
        """With ``dont_json_loads_results`` enabled, JSON-looking strings are returned verbatim."""
        config = self.TEST_CONFIG.copy()
        config['dont_json_loads_results'] = True
        self.dynamo_client = DynamoDbClient(config=config)

        dynamo_row = {
            'hash_col': {'S': 'aaa'},
            'range_col': {'N': '123'},
            'other_col': {'S': '{"how many": 300}'},
            'duck_quack': {'S': '{"quack": "duck"}'},
        }

        res = self.dynamo_client.dynamo_to_dict(dynamo_row, fetch_all_fields=True)
        expected = {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': '{"how many": 300}',
            'duck_quack': '{"quack": "duck"}',
        }
        self.assertDictEqual(res, expected)

        # `duck_quack` is not in row_mapper, so it is dropped without fetch_all_fields.
        res = self.dynamo_client.dynamo_to_dict(dynamo_row, fetch_all_fields=False)
        expected = {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': '{"how many": 300}',
        }
        self.assertDictEqual(res, expected)

    def test_dynamo_to_dict__do_json_loads(self):
        """With ``dont_json_loads_results`` disabled, JSON strings in 'S' columns are decoded into dicts."""
        config = self.TEST_CONFIG.copy()
        config['dont_json_loads_results'] = False
        self.dynamo_client = DynamoDbClient(config=config)

        dynamo_row = {
            'hash_col': {'S': 'aaa'},
            'range_col': {'N': '123'},
            'other_col': {'S': '{"how many": 300}'},
            'duck_quack': {'S': '{"quack": "duck"}'},
        }

        res = self.dynamo_client.dynamo_to_dict(dynamo_row, fetch_all_fields=True)
        expected = {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': {"how many": 300},
            'duck_quack': {"quack": "duck"},
        }
        self.assertDictEqual(res, expected)

        # `duck_quack` is not in row_mapper, so it is dropped without fetch_all_fields.
        res = self.dynamo_client.dynamo_to_dict(dynamo_row, fetch_all_fields=False)
        expected = {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': {"how many": 300},
        }
        self.assertDictEqual(res, expected)

    def test_get_by_query__validates_comparison(self):
        """An unknown comparison operator is rejected with AssertionError."""
        self.assertRaises(AssertionError,
                          self.dynamo_client.get_by_query,
                          keys={'k': '1'},
                          comparisons={'k': 'unsupported'})

    def test_get_by_query__between(self):
        """``st_between_``/``en_between_`` key prefixes produce a BETWEEN condition on the range key."""
        keys = {
            'hash_col': 'cat',
            'st_between_range_col': '3',
            'en_between_range_col': '6'
        }

        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)

        self.dynamo_client.get_by_query(keys=keys)

        args, kwargs = self.paginator_mock.paginate.call_args

        # One placeholder for the hash key plus two for the BETWEEN bounds.
        self.assertEqual(len(kwargs['ExpressionAttributeValues']), 3)
        self.assertIn(
            'range_col between :st_between_range_col and :en_between_range_col',
            kwargs['KeyConditionExpression'])

    def test_get_by_query__return_count(self):
        """With return_count=True the ``Count`` of every paginator page is summed."""
        # Make sure dynamo paginator is mocked.
        self.paginator_mock.paginate.return_value = [
            {'Count': 24, 'LastEvaluatedKey': 'bzz'},
            {'Count': 12},
        ]
        self.dynamo_client.dynamo_client.get_paginator.return_value = self.paginator_mock

        result = self.dynamo_client.get_by_query(keys={'a': 'b'}, return_count=True)

        self.assertEqual(
            result, 36,
            f"Result from 2 pages should be 24 + 12, but we received: {result}"
        )

        # Make sure the paginator was called
        self.dynamo_client.dynamo_client.get_paginator.assert_called()

    def test__parse_filter_expression(self):
        """Valid filter strings are parsed into an expression with placeholders plus its attribute values."""
        tests = {
            'key = 42': ("key = :filter_key", {":filter_key": {'N': '42'}}),
            # Leading, trailing and repeated whitespace must be tolerated.
            '   key    = 42  ': ("key = :filter_key", {":filter_key": {'N': '42'}}),
            'cat = meaw': ("cat = :filter_cat", {":filter_cat": {'S': 'meaw'}}),
            'magic between 41 and 42': (
                "magic between :st_between_magic and :en_between_magic",
                {":st_between_magic": {'N': '41'}, ":en_between_magic": {'N': '42'}}),
            'attribute_not_exists boo': ("attribute_not_exists (boo)", {}),
        }

        for data, expected in tests.items():
            # subTest reports every failing input instead of stopping at the first one.
            with self.subTest(data=data):
                self.assertEqual(self.dynamo_client._parse_filter_expression(data), expected)

    def test__parse_filter_expression__raises(self):
        """Malformed filter expressions are rejected with AssertionError or ValueError."""
        tests = [
            # Invalid input types.
            {'k': 1},
            [1, 2],
            None,
            # Invalid operators.
            'key == 42',
            'foo ~ 1',
            'foo3 <=> 0',
            # Invalid `between` formats.
            'key between 42',
            'key between 23, 25',
            'key between [23, 25]',
            'key 23 between 21',
        ]

        for data in tests:
            with self.subTest(data=data):
                with self.assertRaises((AssertionError, ValueError)):
                    self.dynamo_client._parse_filter_expression(data)

    def test_create__calls_put(self):
        """create() delegates to put() with ``overwrite_existing=False``."""
        row = {'hash_col': 'cat', 'range_key': 'test', 'another_col': 'wow'}
        self.dynamo_client.put = MagicMock(return_value=None)

        self.dynamo_client.create(row)

        self.dynamo_client.put.assert_called_once_with(row, None, overwrite_existing=False)

    def test_batch_get_items_one_table__strict(self):
        """Strict - returns only fields that are in the row mapper."""
        db_items = [{
            'hash_col': {'S': 'b'},
            'range_col': {'N': '10'},
            'unknown_col': {'S': 'not_strict'},
        }]
        db_result = {'Responses': {'autotest_dynamo_db': db_items}}
        self.dynamo_client.dynamo_client.batch_get_item = Mock(return_value=db_result)

        result = self.dynamo_client.batch_get_items_one_table(keys_list=[{'hash_col': 'b'}],
                                                              fetch_all_fields=False)

        self.assertEqual(result, [{'hash_col': 'b', 'range_col': 10}])

    def test_batch_get_items_one_table__not_strict(self):
        """Not strict - returns all fields, including those missing from the row mapper."""
        db_items = [{
            'hash_col': {'S': 'b'},
            'range_col': {'N': '10'},
            'unknown_col': {'S': 'not_strict'},
        }]
        db_result = {'Responses': {'autotest_dynamo_db': db_items}}
        self.dynamo_client.dynamo_client.batch_get_item = Mock(return_value=db_result)

        result = self.dynamo_client.batch_get_items_one_table(keys_list=[{'hash_col': 'b'}],
                                                              fetch_all_fields=True)

        self.assertEqual(result, [{'hash_col': 'b', 'range_col': 10, 'unknown_col': 'not_strict'}])

    def test_get_by_query__max_items_and_count__raises(self):
        """Combining ``max_items`` with ``return_count`` is rejected with an explanatory message."""
        with self.assertRaises(Exception) as e:
            self.dynamo_client.get_by_query({'hash_col': 'key'},
                                            table_name=self.table_name,
                                            max_items=3,
                                            return_count=True)
        expected_msg = "DynamoDbCLient.get_by_query does not support `max_items` and `return_count` together"
        self.assertEqual(e.exception.args[0], expected_msg)
示例#7
0
class dynamodb_client_IntegrationTestCase(unittest.TestCase):
    """
    Integration tests for ``DynamoDbClient`` that run against real DynamoDB tables.

    Most tests use ``autotest_dynamo_db`` (hash key ``hash_col``, range key ``range_col``). Some tests also query
    the ``autotest_index`` index and the ``autotest_config_component`` table. Nothing in this class creates these
    tables, so they are expected to exist already.
    """
    TEST_CONFIG = {
        'row_mapper': {
            'lambda_name': 'S',
            'invocation_id': 'S',
            'en_time': 'N',
            'hash_col': 'S',
            'range_col': 'N',
            'other_col': 'S',
            'new_col': 'S',
            'some_col': 'S',
            'some_counter': 'N'
        },
        'required_fields': ['lambda_name'],
        'table_name': 'autotest_dynamo_db'
    }

    @classmethod
    def setUpClass(cls):
        # Called without arguments, so it relies on clean_dynamo_table's defaults
        # (presumably the autotest table - TODO confirm).
        clean_dynamo_table()

    def setUp(self):
        """
        Keep the table name and key schema on ``self`` so ``tearDown`` cleans the right table.
        A test that switches to another table must update these attributes.
        """
        self.HASH_KEY = ('hash_col', 'S')
        self.RANGE_KEY = ('range_col', 'N')
        self.KEYS = ('hash_col', 'range_col')
        self.table_name = 'autotest_dynamo_db'
        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)

    def tearDown(self):
        clean_dynamo_table(self.table_name, self.KEYS)

    def _get_row(self, hash_value, range_value):
        """
        Read one item from ``self.table_name`` with a raw boto3 ``get_item`` and convert it with
        ``DynamoDbClient.dynamo_to_dict``.

        Reading through boto3 keeps the assertions independent of DynamoDbClient's own read methods.

        :param hash_value: Value for the ``hash_col`` string attribute.
        :param range_value: Value for the ``range_col`` number attribute (passed through ``str()``).
        :raises KeyError: If ``get_item`` returns no ``Item`` (the row does not exist).
        """
        client = boto3.client('dynamodb')
        item = client.get_item(
            Key={
                'hash_col': {'S': hash_value},
                'range_col': {'N': str(range_value)},
            },
            TableName=self.table_name,
        )['Item']
        return self.dynamo_client.dynamo_to_dict(item)

    def test_put(self):
        row = {'hash_col': 'cat', 'range_col': '123'}

        client = boto3.client('dynamodb')

        # Make sure the row does not exist before the put.
        client.delete_item(TableName=self.table_name,
                           Key={
                               'hash_col': {'S': str(row['hash_col'])},
                               'range_col': {'N': str(row['range_col'])},
                           })

        self.dynamo_client.put(row, self.table_name)

        # Read back with a raw boto3 scan so the check does not depend on DynamoDbClient's read methods.
        result = client.scan(
            TableName=self.table_name,
            FilterExpression="hash_col = :hash_col AND range_col = :range_col",
            ExpressionAttributeValues={
                ':hash_col': {'S': row['hash_col']},
                ':range_col': {'N': str(row['range_col'])},
            })

        self.assertGreater(len(result['Items']), 0)

    def test_update__updates(self):
        keys = {'hash_col': 'cat', 'range_col': '123'}
        row = {
            'hash_col': 'cat',
            'range_col': '123',
            'some_col': 'no',
            'other_col': 'foo'
        }
        attributes_to_update = {'some_col': 'yes', 'new_col': 'yup'}

        self.dynamo_client.put(row, self.table_name)

        # First check that the row we are trying to update is PUT correctly.
        initial_row = self._get_row(row['hash_col'], row['range_col'])

        self.assertIsNotNone(initial_row)
        self.assertEqual(initial_row['some_col'], 'no')
        self.assertEqual(initial_row['other_col'], 'foo')

        self.dynamo_client.update(keys, attributes_to_update, table_name=self.table_name)

        updated_row = self._get_row(row['hash_col'], row['range_col'])

        self.assertIsNotNone(updated_row)
        self.assertEqual(updated_row['some_col'], 'yes', msg="Updated field not really updated")
        self.assertEqual(updated_row['new_col'], 'yup', msg="New field was not created")
        self.assertEqual(updated_row['other_col'], 'foo',
                         msg="This field should be preserved, update() damaged it")

    def test_update__increment(self):
        keys = {'hash_col': 'cat', 'range_col': '123'}
        row = {
            'hash_col': 'cat',
            'range_col': '123',
            'some_col': 'no',
            'some_counter': 10
        }
        # The increment is given as a string on purpose: it must still be added numerically.
        attributes_to_increment = {'some_counter': '1'}

        self.dynamo_client.put(row, self.table_name)

        self.dynamo_client.update(
            keys, {},
            attributes_to_increment=attributes_to_increment,
            table_name=self.table_name)

        updated_row = self._get_row(row['hash_col'], row['range_col'])

        self.assertIsNotNone(updated_row)
        self.assertEqual(updated_row['some_counter'], 11)

    def test_update__increment_2(self):
        keys = {'hash_col': 'cat', 'range_col': '123'}
        row = {
            'hash_col': 'cat',
            'range_col': '123',
            'some_col': 'no',
            'some_counter': 10
        }
        attributes_to_increment = {'some_counter': 5}

        self.dynamo_client.put(row, self.table_name)

        self.dynamo_client.update(
            keys, {},
            attributes_to_increment=attributes_to_increment,
            table_name=self.table_name)

        updated_row = self._get_row(row['hash_col'], row['range_col'])

        self.assertIsNotNone(updated_row)
        self.assertEqual(updated_row['some_counter'], 15)

    def test_update__increment_no_default(self):
        keys = {'hash_col': 'cat', 'range_col': '123'}
        # `some_counter` is absent, so the increment has to create it.
        row = {'hash_col': 'cat', 'range_col': '123', 'some_col': 'no'}
        attributes_to_increment = {'some_counter': '3'}

        self.dynamo_client.put(row, self.table_name)

        self.dynamo_client.update(
            keys, {},
            attributes_to_increment=attributes_to_increment,
            table_name=self.table_name)

        updated_row = self._get_row(row['hash_col'], row['range_col'])

        self.assertIsNotNone(updated_row)
        self.assertEqual(updated_row['some_counter'], 3)

    def test_update__condition_expression(self):
        keys = {'hash_col': 'slime', 'range_col': '41'}
        row = {'hash_col': 'slime', 'range_col': '41', 'some_col': 'no'}

        self.dynamo_client.put(row, self.table_name)

        # Should fail because conditional expression does not match
        self.assertRaises(self.dynamo_client.dynamo_client.exceptions.ConditionalCheckFailedException,
                          self.dynamo_client.update,
                          keys, {},
                          attributes_to_increment={'some_counter': '3'},
                          condition_expression='some_col = yes',
                          table_name=self.table_name)

        # Should pass
        self.dynamo_client.update(
            keys, {},
            attributes_to_increment={'some_counter': '3'},
            condition_expression='some_col = no',
            table_name=self.table_name)

        updated_row = self._get_row(row['hash_col'], row['range_col'])
        self.assertEqual(updated_row['some_counter'], 3)

    def test_get_by_query__primary_index(self):
        keys = {'hash_col': 'cat', 'range_col': '123'}
        row = {'hash_col': 'cat', 'range_col': 123, 'some_col': 'test'}
        self.dynamo_client.put(row, self.table_name)

        result = self.dynamo_client.get_by_query(keys=keys)

        self.assertEqual(len(result), 1)
        self.assertDictEqual(result[0], row)

    def test_get_by_query__primary_index__gets_multiple(self):
        row = {'hash_col': 'cat', 'range_col': 123, 'some_col': 'test'}
        self.dynamo_client.put(row, self.table_name)

        row2 = {'hash_col': 'cat', 'range_col': 1234, 'some_col': 'test2'}
        self.dynamo_client.put(row2, self.table_name)

        result = self.dynamo_client.get_by_query(keys={'hash_col': 'cat'})

        self.assertEqual(len(result), 2)

        by_range = {item['range_col']: item for item in result}
        self.assertDictEqual(by_range[row['range_col']], row)
        self.assertDictEqual(by_range[row2['range_col']], row2)

    def test_get_by_query__secondary_index(self):
        keys = {'hash_col': 'cat', 'other_col': 'abc123'}
        row = {'hash_col': 'cat', 'range_col': 123, 'other_col': 'abc123'}
        self.dynamo_client.put(row, self.table_name)

        result = self.dynamo_client.get_by_query(keys=keys, index_name='autotest_index')

        self.assertEqual(len(result), 1)
        self.assertDictEqual(result[0], row)

    def test_get_by_query__comparison(self):
        keys = {'hash_col': 'cat', 'range_col': '300'}
        row1 = {'hash_col': 'cat', 'range_col': 123, 'other_col': 'abc123'}
        row2 = {'hash_col': 'cat', 'range_col': 456, 'other_col': 'abc123'}
        self.dynamo_client.put(row1, self.table_name)
        self.dynamo_client.put(row2, self.table_name)

        result = self.dynamo_client.get_by_query(keys=keys, comparisons={'range_col': '<='})

        self.assertEqual(len(result), 1)
        self.assertEqual(result[0], row1)

    def test_get_by_query__comparison_between(self):
        # Put sample data: range_col 0 - 9.
        for range_value in range(10):
            self.dynamo_client.put({'hash_col': 'cat', 'range_col': range_value}, self.table_name)

        keys = {
            'hash_col': 'cat',
            'st_between_range_col': '3',
            'en_between_range_col': '6'
        }
        # DynamoDB BETWEEN is inclusive on both ends. Compare exact values rather than `all(...)`,
        # which would also pass on an empty result.
        expected_range_values = [3, 4, 5, 6]

        result = self.dynamo_client.get_by_query(keys=keys, comparisons={'range_col': 'between'})
        self.assertEqual(sorted(x['range_col'] for x in result), expected_range_values)

        # Without explicit `comparisons` the client should infer `between` from the st_between_/en_between_ keys.
        result = self.dynamo_client.get_by_query(keys=keys)
        self.assertEqual(sorted(x['range_col'] for x in result), expected_range_values,
                         msg="Failed if unspecified comparison. Should be automatic for :st_between_...")

    def test_get_by_query__filter_expression(self):
        """
        This _integration_ test runs multiple checks with same sample data for several comparators.
        Have a look at the manual if required:
        https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html
        """

        # Put sample data: rows 0-2 have no `mark`, rows 3-5 have mark=1, row 6 has mark=0, row 7 has mark='a'.
        for range_value in range(3):
            self.dynamo_client.put({'hash_col': 'cat', 'range_col': range_value}, self.table_name)
        for range_value in range(3, 6):
            self.dynamo_client.put({'hash_col': 'cat', 'range_col': range_value, 'mark': 1}, self.table_name)
        self.dynamo_client.put({'hash_col': 'cat', 'range_col': 6, 'mark': 0}, self.table_name)
        self.dynamo_client.put({'hash_col': 'cat', 'range_col': 7, 'mark': 'a'}, self.table_name)

        expected_marked = [
            {'hash_col': 'cat', 'range_col': 3, 'mark': 1},
            {'hash_col': 'cat', 'range_col': 4, 'mark': 1},
        ]

        # The key condition `range_col <= 4` matches five of the eight rows (0 - 4).
        # The filter expression then drops rows 0 - 2 because they don't have `mark = 1`.
        keys = {'hash_col': 'cat', 'range_col': 4}
        result = self.dynamo_client.get_by_query(
            keys=keys,
            comparisons={'range_col': '<='},
            strict=False,
            filter_expression='mark = 1')

        self.assertEqual(result, expected_marked)

        # In the same test we check also some comparator _functions_.
        result = self.dynamo_client.get_by_query(
            keys=keys,
            comparisons={'range_col': '<='},
            strict=False,
            filter_expression='attribute_exists mark')

        self.assertEqual(result, expected_marked)

        result = self.dynamo_client.get_by_query(
            keys=keys,
            comparisons={'range_col': '<='},
            strict=False,
            filter_expression='attribute_not_exists mark')

        self.assertEqual([x['range_col'] for x in result], [0, 1, 2])

    def test_get_by_query__comparison_begins_with(self):
        # Switch to a table with a string range key; tearDown uses these attributes to clean it.
        self.table_name = 'autotest_config_component'
        self.HASH_KEY = ('env', 'S')
        self.RANGE_KEY = ('config_name', 'S')
        self.KEYS = ('env', 'config_name')
        config = {
            'row_mapper': {
                'env': 'S',
                'config_name': 'S',
                'config_value': 'S'
            },
            'required_fields': ['env', 'config_name', 'config_value'],
            'table_name': 'autotest_config_component'
        }

        self.dynamo_client = DynamoDbClient(config=config)

        row1 = {'env': 'cat', 'config_name': 'testzing', 'config_value': 'abc123'}
        row2 = {'env': 'cat', 'config_name': 'dont_get_this', 'config_value': 'abc123'}
        row3 = {'env': 'cat', 'config_name': 'testzer', 'config_value': 'abc124'}
        for row in (row1, row2, row3):
            self.dynamo_client.put(row, self.table_name)

        keys = {'env': 'cat', 'config_name': 'testz'}
        result = self.dynamo_client.get_by_query(
            keys=keys,
            table_name=self.table_name,
            comparisons={'config_name': 'begins_with'})

        self.assertEqual(len(result), 2)
        self.assertIn(row1, result)
        self.assertIn(row3, result)

    def test_get_by_query__max_items(self):
        # This function can also be used for some benchmarking, just change to bigger amounts manually.
        INITIAL_TASKS = 5  # Change to 500 to run benchmarking, and uncomment raise at the end of the test.

        for x in range(1000, 1000 + INITIAL_TASKS):
            row = {'hash_col': 'key', 'range_col': x}
            self.dynamo_client.put(row, self.table_name)
            if INITIAL_TASKS > 10:
                # Sleep a little to fit the Write Capacity (10 WCU) of autotest table.
                time.sleep(0.1)

        st = time.perf_counter()
        result = self.dynamo_client.get_by_query({'hash_col': 'key'},
                                                 table_name=self.table_name,
                                                 max_items=3)
        bm = time.perf_counter() - st
        print(f"Benchmark: {bm}")

        self.assertEqual(len(result), 3)
        # NOTE: wall-clock check against a real service; sensitive to network latency.
        self.assertLess(bm, 0.1)

        # Check unspecified limit.
        result = self.dynamo_client.get_by_query({'hash_col': 'key'}, table_name=self.table_name)
        self.assertEqual(len(result), INITIAL_TASKS)

        # Benchmarking
        if INITIAL_TASKS >= 500:
            st = time.perf_counter()
            result = self.dynamo_client.get_by_query(
                {'hash_col': 'key'}, table_name=self.table_name, max_items=499)
            bm = time.perf_counter() - st
            print(f"Benchmark: {bm}")
            self.assertLess(bm, 0.1)

            self.assertEqual(len(result), 499)
            # Uncomment this see benchmark
            # self.assertEqual(1, 2)

    def test_get_by_query__return_count(self):
        rows = [
            {'hash_col': 'cat1', 'range_col': 121, 'some_col': 'test1'},
            {'hash_col': 'cat1', 'range_col': 122, 'some_col': 'test2'},
            {'hash_col': 'cat1', 'range_col': 123, 'some_col': 'test3'},
        ]

        for x in rows:
            self.dynamo_client.put(x, table_name=self.table_name)

        result = self.dynamo_client.get_by_query({'hash_col': 'cat1'},
                                                 table_name=self.table_name,
                                                 return_count=True)

        self.assertEqual(result, 3)

    def test_get_by_query__reverse(self):
        rows = [
            {'hash_col': 'cat1', 'range_col': 121, 'some_col': 'test1'},
            {'hash_col': 'cat1', 'range_col': 122, 'some_col': 'test2'},
            {'hash_col': 'cat1', 'range_col': 123, 'some_col': 'test3'},
        ]

        for x in rows:
            self.dynamo_client.put(x, table_name=self.table_name)

        result = self.dynamo_client.get_by_query({'hash_col': 'cat1'},
                                                 table_name=self.table_name,
                                                 desc=True)

        # With desc=True the highest range key comes first.
        self.assertEqual(result[0], rows[-1])

    def test_get_by_scan__all(self):
        rows = [
            {'hash_col': 'cat1', 'range_col': 121, 'some_col': 'test1'},
            {'hash_col': 'cat2', 'range_col': 122, 'some_col': 'test2'},
            {'hash_col': 'cat3', 'range_col': 123, 'some_col': 'test3'},
        ]
        for x in rows:
            self.dynamo_client.put(x, self.table_name)

        result = self.dynamo_client.get_by_scan()

        self.assertEqual(len(result), 3)

        for r in rows:
            self.assertIn(r, result, msg=f"row not in result from dynamo scan: {r}")

    def test_get_by_scan__with_filter(self):
        rows = [
            {'hash_col': 'cat1', 'range_col': 121, 'some_col': 'test1'},
            {'hash_col': 'cat1', 'range_col': 122, 'some_col': 'test2'},
            {'hash_col': 'cat2', 'range_col': 122, 'some_col': 'test2'},
        ]
        for x in rows:
            self.dynamo_client.put(x, self.table_name)

        scan_filter = {'some_col': 'test2'}  # Named to avoid shadowing the builtin `filter`.

        result = self.dynamo_client.get_by_scan(attrs=scan_filter)

        self.assertEqual(len(result), 2)

        for r in rows[1:]:
            self.assertIn(r, result, msg=f"row not in result from dynamo scan: {r}")

    def test_batch_get_items(self):
        rows = [
            {'hash_col': 'cat1', 'range_col': 121, 'some_col': 'test1'},
            {'hash_col': 'cat1', 'range_col': 122, 'some_col': 'test2'},
            {'hash_col': 'cat2', 'range_col': 122, 'some_col': 'test2'},
        ]
        for x in rows:
            self.dynamo_client.put(x, self.table_name)

        # One of the requested keys does not exist and must simply be absent from the result.
        keys_list_query = [
            {'hash_col': 'cat1', 'range_col': 121},
            {'hash_col': 'doesnt_exist', 'range_col': 40},
            {'hash_col': 'cat2', 'range_col': 122},
        ]

        result = self.dynamo_client.batch_get_items_one_table(keys_list_query)

        self.assertEqual(len(result), 2)

        self.assertIn(rows[0], result)
        self.assertIn(rows[2], result)

    def test_delete(self):
        self.dynamo_client.put({'hash_col': 'cat1', 'range_col': 123})
        self.dynamo_client.put({'hash_col': 'cat2', 'range_col': 234})

        self.dynamo_client.delete(keys={'hash_col': 'cat1', 'range_col': '123'})

        items = self.dynamo_client.get_by_scan()

        self.assertEqual(len(items), 1)
        self.assertEqual(items[0], {'hash_col': 'cat2', 'range_col': 234})
示例#8
0
class dynamodb_client_UnitTestCase(unittest.TestCase):
    """
    Unit tests for ``DynamoDbClient`` with ``boto3.client`` patched out.

    ``boto3.client`` returns ``self.dynamo_mock``, whose ``get_paginator`` returns ``self.paginator_mock``,
    so tests can inspect the arguments the client passes to the paginator.
    """
    TEST_CONFIG = {
        'row_mapper': {
            'lambda_name': 'S',
            'invocation_id': 'S',
            'en_time': 'N',
            'hash_col': 'S',
            'range_col': 'N',
            'other_col': 'S',
            'new_col': 'S',
            'some_col': 'S',
            'some_counter': 'N'
        },
        'required_fields': ['lambda_name'],
        'table_name': 'autotest_dynamo_db'
    }

    def setUp(self):
        self.HASH_KEY = ('hash_col', 'S')
        self.RANGE_KEY = ('range_col', 'N')
        self.KEYS = ('hash_col', 'range_col')
        self.table_name = 'autotest_dynamo_db'

        self.patcher = patch("boto3.client")
        self.dynamo_mock = MagicMock()
        self.paginator_mock = MagicMock()

        self.boto3_client_patch = self.patcher.start()
        self.boto3_client_patch.return_value = self.dynamo_mock
        self.dynamo_mock.get_paginator.return_value = self.paginator_mock

        self.dynamo_client = DynamoDbClient(config=self.TEST_CONFIG)

    def tearDown(self):
        self.patcher.stop()

    def test_dict_to_dynamo_strict(self):
        dict_row = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456
        }
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row)
        expected = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
        }
        for key, value in expected.items():
            with self.subTest(key=key):
                self.assertDictEqual(value, dynamo_row[key])

    def test_dict_to_dynamo_not_strict(self):
        # Keys outside the row_mapper: the type is presumably inferred from the Python value.
        dict_row = {'name': 'cat', 'age': 3}
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row, strict=False)
        expected = {'name': {'S': 'cat'}, 'age': {'N': '3'}}
        for key, value in expected.items():
            with self.subTest(key=key):
                self.assertDictEqual(value, dynamo_row[key])

    def test_dict_to_dynamo_prefix(self):
        dict_row = {'hash_col': 'cat', 'range_col': '123', 'some_col': 'no'}
        dynamo_row = self.dynamo_client.dict_to_dynamo(dict_row, add_prefix="#")
        expected = {
            '#hash_col': {'S': 'cat'},
            '#range_col': {'N': '123'},
            '#some_col': {'S': 'no'},
        }
        for key, value in expected.items():
            with self.subTest(key=key):
                self.assertDictEqual(value, dynamo_row[key])

    def test_dynamo_to_dict(self):
        # `extra_key` is not in the row_mapper, so the default (strict) conversion drops it.
        dynamo_row = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'extra_key': {'N': '42'},
        }
        dict_row = self.dynamo_client.dynamo_to_dict(dynamo_row)
        expected = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456
        }
        self.assertDictEqual(dict_row, expected)

    def test_dynamo_to_dict_no_strict_row_mapper(self):
        dynamo_row = {
            'lambda_name': {'S': 'test_name'},
            'invocation_id': {'S': 'test_id'},
            'en_time': {'N': '123456'},
            'extra_key_n': {'N': '42'},
            'extra_key_s': {'S': 'wowie'},
        }
        dict_row = self.dynamo_client.dynamo_to_dict(dynamo_row, strict=False)
        expected = {
            'lambda_name': 'test_name',
            'invocation_id': 'test_id',
            'en_time': 123456,
            'extra_key_n': 42,
            'extra_key_s': 'wowie'
        }
        self.assertDictEqual(dict_row, expected)

    def test_dynamo_to_dict__dont_json_loads(self):
        config = {**self.TEST_CONFIG, 'dont_json_loads_results': True}

        self.dynamo_client = DynamoDbClient(config=config)

        dynamo_row = {
            'hash_col': {'S': 'aaa'},
            'range_col': {'N': '123'},
            'other_col': {'S': '{"how many": 300}'},
            'duck_quack': {'S': '{"quack": "duck"}'},
        }
        res = self.dynamo_client.dynamo_to_dict(dynamo_row, strict=False)
        expected = {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': '{"how many": 300}',
            'duck_quack': '{"quack": "duck"}'
        }
        self.assertDictEqual(res, expected)

        res = self.dynamo_client.dynamo_to_dict(dynamo_row, strict=True)
        expected = {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': '{"how many": 300}'
        }
        self.assertDictEqual(res, expected)

    def test_dynamo_to_dict__do_json_loads(self):
        config = {**self.TEST_CONFIG, 'dont_json_loads_results': False}

        self.dynamo_client = DynamoDbClient(config=config)

        dynamo_row = {
            'hash_col': {'S': 'aaa'},
            'range_col': {'N': '123'},
            'other_col': {'S': '{"how many": 300}'},
            'duck_quack': {'S': '{"quack": "duck"}'},
        }
        res = self.dynamo_client.dynamo_to_dict(dynamo_row, strict=False)
        expected = {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': {"how many": 300},
            'duck_quack': {"quack": "duck"},
        }
        self.assertDictEqual(res, expected)

        res = self.dynamo_client.dynamo_to_dict(dynamo_row, strict=True)
        expected = {
            'hash_col': 'aaa',
            'range_col': 123,
            'other_col': {"how many": 300},
        }
        self.assertDictEqual(res, expected)

    def test_get_by_query__validates_comparison(self):
        self.assertRaises(AssertionError,
                          self.dynamo_client.get_by_query,
                          keys={'k': '1'},
                          comparisons={'k': 'unsupported'})

    def test_get_by_query__between(self):
        keys = {
            'hash_col': 'cat',
            'st_between_range_col': '3',
            'en_between_range_col': '6'
        }

        self.dynamo_client.get_by_query(keys=keys)

        # The query is sent through the mocked paginator; inspect the keyword arguments it received.
        _, kwargs = self.paginator_mock.paginate.call_args

        # Presumably the hash key value plus the two `between` bounds.
        self.assertEqual(len(kwargs['ExpressionAttributeValues']), 3)
        self.assertIn(
            'range_col between :st_between_range_col and :en_between_range_col',
            kwargs['KeyConditionExpression'])

    def test__parse_filter_expression(self):
        TESTS = {
            'key = 42': ("key = :filter_key", {":filter_key": {'N': '42'}}),
            '   key    = 42  ': ("key = :filter_key", {":filter_key": {'N': '42'}}),
            'cat = meaw': ("cat = :filter_cat", {":filter_cat": {'S': 'meaw'}}),
            'magic between 41 and 42': (
                "magic between :st_between_magic and :en_between_magic",
                {
                    ":st_between_magic": {'N': '41'},
                    ":en_between_magic": {'N': '42'},
                }),
            'attribute_not_exists boo': ("attribute_not_exists (boo)", {}),
        }

        for data, expected in TESTS.items():
            with self.subTest(data=data):
                self.assertEqual(self.dynamo_client._parse_filter_expression(data), expected)

    def test__parse_filter_expression__raises(self):
        TESTS = [
            # Invalid input types
            {'k': 1},
            [1, 2],
            None,
            # Invalid operators
            'key == 42',
            'foo ~ 1',
            'foo3 <=> 0',
            # Invalid between formats
            'key between 42',
            'key between 23, 25',
            'key between [23, 25]',
            'key 23 between 21',
        ]

        for data in TESTS:
            with self.subTest(data=data):
                self.assertRaises((AssertionError, ValueError),
                                  self.dynamo_client._parse_filter_expression,
                                  data)
示例#9
0
class WorkerAssistant_IntegrationTestCase(unittest.TestCase):
    TEST_CONFIG = TEST_TASK_CLIENT_CONFIG

    @classmethod
    def setUpClass(cls):
        """
        Prepare class-wide fixtures shared by every test in this case.

        * Restrict ``init_clients`` to ``['DynamoDb']``. ``TEST_CONFIG`` is an alias
          of the module-level ``TEST_TASK_CLIENT_CONFIG``, so it is re-bound to a
          copy first instead of being modified in place; mutating it would leak
          the setting into every other user of that shared dict.
        * Install a stand-in ``global_vars.lambda_context``: an ``AttrDict`` built
          from ``MetaHandler.CONTEXT_FIELDS_MAPPINGS`` with keys and values swapped.
        """
        cls.TEST_CONFIG = {**cls.TEST_CONFIG, 'init_clients': ['DynamoDb']}
        global_vars.lambda_context = attrdict.AttrDict(
            {v: k
             for k, v in MetaHandler.CONTEXT_FIELDS_MAPPINGS.items()})

    def setUp(self):
        """
        Build fresh per-test fixtures:

        * a private copy of the class config,
        * a patched ``sosw.app.get_config`` that returns ``{}``,
        * an emptied task table,
        * a ``DynamoDbClient`` and the ``WorkerAssistant`` under test.

        We keep copies of main parameters here, because they may differ from test
        to test and cleanup needs them. It is the test author's responsibility to
        update these values from the test if required.
        """
        import copy

        # Deep copy: ``dict.copy()`` is shallow, so nested sections such as
        # ``dynamo_db_config`` would still be shared with the class-level config
        # and an edit made by one test would leak into the following tests.
        self.config = copy.deepcopy(self.TEST_CONFIG)

        # The patcher is stopped again in tearDown.
        self.patcher = patch("sosw.app.get_config")
        self.get_config_patch = self.patcher.start()
        self.get_config_patch.return_value = {}

        self.table_name = self.config['dynamo_db_config']['table_name']
        # (attribute name, DynamoDB type) of the task table's hash key.
        self.HASH_KEY = ('task_id', 'S')

        # Start every test from an empty task table.
        self.clean_task_tables()

        self.dynamo_client = DynamoDbClient(
            config=self.config['dynamo_db_config'])

        self.assistant = WorkerAssistant(
            custom_config=TEST_WORKER_ASSISTANT_CONFIG)

    def tearDown(self):
        """Stop the ``get_config`` patch first, then wipe the task table."""
        for cleanup_step in (self.patcher.stop, self.clean_task_tables):
            cleanup_step()

    def clean_task_tables(self):
        """
        Wipe ``self.table_name`` through ``clean_dynamo_table``.

        Only the hash-key attribute name (not its type) is passed as the key
        tuple; presumably that is what the helper needs to address each item.
        """
        hash_key_name = self.HASH_KEY[0]
        clean_dynamo_table(self.table_name, (hash_key_name, ))

    def test_mark_task_as_completed(self):
        """
        After ``mark_task_as_completed`` the stored task's ``completed_at`` must
        fall within one minute either side of the time the test ran.
        """
        field = self.assistant.get_db_field_name
        task_id = '123'

        self.dynamo_client.put({
            field('task_id'): task_id,
            field('labourer_id'): 'lab',
            field('greenfield'): 8888,
            field('attempts'): 2,
        })

        # Tolerance window around "now", expressed as epoch timestamps.
        one_minute = datetime.timedelta(minutes=1)
        window_start = (datetime.datetime.now() - one_minute).timestamp()
        window_end = (datetime.datetime.now() + one_minute).timestamp()
        between_times = (window_start, window_end)

        self.assistant.mark_task_as_completed(task_id)

        matching_tasks = self.dynamo_client.get_by_query({field('task_id'): task_id})
        changed_task = matching_tasks[0]

        self.assertTrue(
            between_times[0] <= changed_task['completed_at'] <= between_times[1],
            msg=f"NOT {between_times[0]} <= {changed_task['completed_at']} <= {between_times[1]}"
        )

    def test_mark_task_as_failed(self):
        _ = self.assistant.get_db_field_name
        task_id = '123'

        initial_task = {
            _('task_id'): task_id,
            _('labourer_id'): 'lab',
            _('greenfield'): 8888,
            _('attempts'): 0
        }
        self.dynamo_client.put(initial_task)

        self.assistant.mark_task_as_failed(task_id, result={'some_info': 33})

        changed_task = self.dynamo_client.get_by_query(
            {_('task_id'): task_id}, fetch_all_fields=True)[0]

        self.assertEqual(1, changed_task['failed_attempts'])
        self.assertEqual(33, changed_task['result_some_info'])