class TestFlotillaSchedulerDynamo(unittest.TestCase):
    """Unit tests for FlotillaSchedulerDynamo against mocked DynamoDB tables."""

    def setUp(self):
        # Each table is a MagicMock restricted to the Table spec.
        self.assignments = MagicMock(spec=Table)
        # Attached explicitly: presumably an instance attribute of Table that
        # is therefore absent from the class spec — confirm.
        self.assignments._dynamizer = MagicMock()
        self.regions = MagicMock(spec=Table)
        self.services = MagicMock(spec=Table)
        self.stacks = MagicMock(spec=Table)
        self.status = MagicMock(spec=Table)

        # A service Item whose mapping protocol is delegated to a plain dict,
        # so reads/writes/deletes made by the code under test are visible in
        # self.service_data.
        self.service_item = MagicMock(spec=Item)
        self.service_data = {'service_name': SERVICE, REVISION: 1}
        self.service_item.__getitem__.side_effect = \
            self.service_data.__getitem__
        self.service_item.__setitem__.side_effect = \
            self.service_data.__setitem__
        self.service_item.__contains__.side_effect = \
            self.service_data.__contains__
        self.service_item.items.side_effect = \
            self.service_data.items
        self.service_item.__delitem__.side_effect = \
            self.service_data.__delitem__

        self.db = FlotillaSchedulerDynamo(self.assignments, self.regions,
                                          self.services, self.stacks,
                                          self.status)

    def test_get_all_revision_weights_empty(self):
        # services.scan is unconfigured; a MagicMock iterates as empty.
        weights = self.db.get_all_revision_weights()
        self.assertEqual(0, len(weights))

    def test_get_all_revision_weights(self):
        rev2 = REVISION.replace('a', 'b')
        rev3 = REVISION.replace('a', 'c')
        rev4 = REVISION.replace('a', 'd')
        self.services.scan.return_value = [{
            'service_name': SERVICE,
            'regions': ['us-east-1'],
            REVISION: 1,
            rev2: 2,
        }, {
            'service_name': 'test2',
            rev3: 1,
            rev4: 2,
        }]

        weights = self.db.get_all_revision_weights()

        # One entry per scanned service, keyed by service name.
        self.assertEqual(2, len(weights))
        service_test = weights[SERVICE]
        self.assertEqual(1, service_test[REVISION])
        self.assertEqual(2, service_test[rev2])
        service_test = weights['test2']
        self.assertEqual(1, service_test[rev3])
        self.assertEqual(2, service_test[rev4])

    def test_get_all_revision_weights_ignore_negative(self):
        rev2 = REVISION.replace('a', 'b')
        self.services.scan.return_value = [{
            'service_name': SERVICE,
            'regions': ['us-east-1'],
            REVISION: 1,
            rev2: -2,
        }]

        weights = self.db.get_all_revision_weights()

        # The negatively-weighted revision is expected to be dropped.
        self.assertEqual(1, len(weights[SERVICE]))

    def test_get_service_revisions(self):
        self.services.get_item.return_value = {
            'service_name': SERVICE,
            'regions': ['us-east-1'],
            REVISION: 1,
        }

        weights = self.db.get_revision_weights(SERVICE)

        self.assertEqual(1, len(weights))
        self.assertEqual(1, weights[REVISION])

    def test_get_service_revisions_not_found(self):
        self.services.get_item.side_effect = ItemNotFound()

        weights = self.db.get_revision_weights(SERVICE)

        self.assertEqual(0, len(weights))

    def test_set_assignment(self):
        self.db.set_assignment(SERVICE, INSTANCE_ID, REVISION)

        self.assignments.put_item.assert_called_with(data=ANY, overwrite=True)

    def test_set_assignments(self):
        mock_batch = MagicMock(spec=BatchTable)
        # Make `with table.batch_write() as batch:` yield mock_batch itself,
        # so put_item calls land on mock_batch whether or not the code under
        # test uses the batch as a context manager.
        mock_batch.__enter__.return_value = mock_batch
        self.assignments.batch_write.return_value = mock_batch

        self.db.set_assignments([
            {'instance_id': INSTANCE_ID},
            {'instance_id': INSTANCE_ID},
        ])

        self.assertEqual(2, mock_batch.put_item.call_count)

    def test_get_instance_assignments_empty(self):
        assignments = self.db.get_instance_assignments(SERVICE)
        self.assertEqual(0, len(assignments))

    def test_get_instance_assignments_assigned(self):
        self.status.query_2.return_value = [{
            'instance_id': INSTANCE_ID,
            'status_time': time.time(),
        }]
        self.assignments.batch_get.return_value = [{
            'instance_id': INSTANCE_ID,
            'assignment': REVISION,
        }]

        assignments = self.db.get_instance_assignments(SERVICE)

        self.assertEqual(1, len(assignments[REVISION]))

    def test_get_instance_assignments_unassigned(self):
        # A live instance with no assignment row (batch_get is unconfigured)
        # is expected under the None key.
        self.status.query_2.return_value = [{
            'instance_id': INSTANCE_ID,
            'status_time': time.time(),
        }]

        assignments = self.db.get_instance_assignments(SERVICE)

        self.assertEqual(1, len(assignments[None]))

    def test_get_instance_assignments_garbage_collection(self):
        # Status is older than INSTANCE_EXPIRY, so the instance is expected to
        # be omitted and cleaned up through batch writes on both tables.
        self.status.query_2.return_value = [{
            'instance_id': INSTANCE_ID,
            'status_time': time.time() - (INSTANCE_EXPIRY + 1),
        }]

        assignments = self.db.get_instance_assignments(SERVICE)

        self.assertEqual(0, len(assignments))
        self.status.batch_write.assert_called_with()
        self.assignments.batch_write.assert_called_with()

    def test_get_stacks_empty(self):
        stacks = self.db.get_stacks()
        self.assertEqual(0, len(stacks))

    def test_get_stacks(self):
        self.stacks.scan.return_value = [{'service_name': 'fred'}]

        stacks = self.db.get_stacks()

        self.assertEqual(1, len(stacks))
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual('fred', stacks[0]['service_name'])

    def test_set_stacks(self):
        self.db.set_stacks([{'stack_arn': 'foo'}])
        self.stacks.batch_write.assert_called_with()

    def test_set_stacks_empty(self):
        self.db.set_stacks([])
        self.stacks.batch_write.assert_not_called()

    def test_set_services(self):
        self.db.set_services([{'service_name': 'foo'}])
        self.services.batch_write.assert_called_with()

    def test_set_services_empty(self):
        self.db.set_services([])
        self.services.batch_write.assert_not_called()

    def test_get_region_params(self):
        self.regions.get_item.return_value = {
            'region_name': 'us-east-1',
            'az1': 'us-east-1e',
        }

        region_params = self.db.get_region_params('us-east-1')

        self.assertEqual('us-east-1e', region_params['az1'])

    def test_get_service_status(self):
        unit_key = 'test-%s.service' % REVISION
        self.status.query_2.return_value = [{
            'instance_id': 'i-goodinstance',
            'status_time': time.time(),
            unit_key: '{}',
        }, {
            'instance_id': 'i-anothergood',
            'status_time': time.time(),
            unit_key: '{}',
        }, {
            'instance_id': INSTANCE_ID,
            'status_time': time.time(),
            unit_key: '{}',
        }, {
            'instance_id': 'i-expired',
            'status_time': time.time() - (INSTANCE_EXPIRY + 1),
            unit_key: '{}',
        }]

        status = {
            k: v for k, v in
            self.db.get_service_status(SERVICE, REVISION, INSTANCE_ID)
        }

        # Expects 2 of the 4 rows; presumably the expired instance and the
        # caller's own INSTANCE_ID are excluded — confirm against the
        # implementation.
        self.assertEqual(2, len(status))

    def test_make_only_revision_not_found(self):
        # NOTE(review): get_item is not configured to raise ItemNotFound here;
        # it returns an unconfigured MagicMock. Confirm this exercises the
        # intended "service not found" path.
        self.db.make_only_revision(SERVICE, REVISION)

    def test_make_only_revision_no_changes(self):
        # service_data already holds only REVISION, so nothing should be saved.
        self.services.get_item.return_value = self.service_item

        self.db.make_only_revision(SERVICE, REVISION)

        self.service_item.save.assert_not_called()

    def test_make_only_revision(self):
        new_hash = REVISION.replace('1', '4')
        self.service_data[new_hash] = 1
        self.services.get_item.return_value = self.service_item

        self.db.make_only_revision(SERVICE, new_hash)

        # The old revision is removed from the item and the item is saved.
        self.service_item.save.assert_called_with()
        self.assertNotIn(REVISION, self.service_item)