def test_default_log_group(self):
    """Log output URLs of several shapes resolve to the expected log group."""
    ctx = Bag(session_factory=None,
              options=Bag(account_id='001100', region='us-east-1'),
              policy=Bag(name='test', resource_type='ec2'))

    # (url, expected log group, expected stream name or None to skip)
    cases = [
        ('custodian/xyz', 'custodian/xyz', 'test'),
        ('/custodian/xyz/', 'custodian/xyz', None),
        ('aws://somewhere/out/there', 'somewhere/out/there', None),
        ('aws:///somewhere/out', 'somewhere/out', None),
        ('aws://somewhere', 'somewhere', None),
        ("aws:///somewhere/out?stream={region}/{policy}",
         'somewhere/out', 'us-east-1/test'),
    ]
    for url, expected_group, expected_stream in cases:
        selected = output.log_outputs.select(url, ctx)
        self.assertEqual(selected.log_group, expected_group)
        if expected_stream is not None:
            self.assertEqual(selected.construct_stream_name(), expected_stream)
def eperm(provider, el, r=None):
    """Return the permissions required by policy element ``el``.

    If ``el.permissions`` is already truthy it is returned directly.
    Otherwise a throw-away policy containing the element is loaded
    (without validation) and the loaded element's ``get_permissions()``
    is returned. On any load failure a diagnostic is printed and an
    empty list is returned.

    :param provider: cloud provider prefix, e.g. 'aws', 'gcp', 'azure'.
    :param el: the policy element class/object to inspect.
    :param r: optional resource descriptor with a ``type`` attribute.
    """
    if el.permissions:
        return el.permissions

    element_type = get_element_type(el)

    if r is None or r.type is None:
        # Pick a placeholder resource type so a policy can be constructed
        # for the element; unknown providers leave ``r`` untouched.
        placeholder_type = {
            'aws': 'kinesis',
            'gcp': 'instance',
            'azure': 'vm',
        }.get(provider)
        if placeholder_type is not None:
            r = Bag({'type': placeholder_type})

    loader = PolicyLoader(Config.empty())
    pdata = {
        'name': f'permissions-{r.type}',
        'resource': f'{provider}.{r.type}',
        element_type: get_element_data(element_type, el),
    }

    try:
        pset = loader.load_data({'policies': [pdata]}, ':mem:', validate=False)
    except Exception as e:
        print(f'error loading {el} as {element_type}:{el.type} error: {e} \n {pdata}')
        return []

    loaded_policy = list(pset)[0]
    return get_policy_element(el, loaded_policy).get_permissions()
def test_metrics_output(self):
    """A metric flushed via StackDriverMetrics is readable back as a time series."""
    project_id = 'cloud-custodian'
    factory = self.replay_flight_data('output-metrics', project_id=project_id)
    ctx = Bag(session_factory=factory,
              policy=Bag(name='custodian-works', resource_type='gcp.function'))

    sink = StackDriverMetrics(ctx, Bag())
    sink.put_metric('ResourceCount', 43, 'Count', Scope='Policy')
    sink.flush()

    # When recording live, wait before querying the freshly written point.
    if self.recording:
        time.sleep(42)

    client = factory().client('monitoring', 'v3', 'projects.timeSeries')
    window_start = (datetime.datetime.utcnow()
                    - datetime.timedelta(minutes=5)).isoformat('T') + 'Z'
    window_end = datetime.datetime.utcnow().isoformat('T') + 'Z'
    query = {
        'name': 'projects/{}'.format(project_id),
        'filter': 'metric.type="custom.googleapis.com/custodian/policy/resourcecount"',
        'pageSize': 3,
        'interval_startTime': window_start,
        'interval_endTime': window_end,
    }
    results = client.execute_command('list', query)

    expected_series = {
        'metric': {
            'labels': {
                'policy': 'custodian-works',
                'project_id': 'cloud-custodian',
            },
            'type': 'custom.googleapis.com/custodian/policy/resourcecount',
        },
        'metricKind': 'GAUGE',
        'points': [{
            'interval': {
                'endTime': '2018-08-12T22:30:53.524505Z',
                'startTime': '2018-08-12T22:30:53.524505Z',
            },
            'value': {'int64Value': '43'},
        }],
        'resource': {
            'labels': {'project_id': 'cloud-custodian'},
            'type': 'global',
        },
        'valueType': 'INT64',
    }
    self.assertEqual(results['timeSeries'], [expected_series])
def test_metrics_destination_dims(self):
    """Metric dimensions include policy, resource type, extras, region and account."""
    captured = []

    class Metrics(aws.MetricsOutput):
        # Capture the outgoing metric payloads instead of sending them.
        def _put_metrics(self, ns, metrics):
            captured.extend(metrics)

    conf = Bag({'region': 'us-east-2', 'scheme': 'aws', 'netloc': 'master'})
    ctx = Bag(session_factory=None,
              options=Bag(account_id='001100', region='us-east-1'),
              policy=Bag(name='test', resource_type='ec2'))

    sink = Metrics(ctx, conf)
    sink.put_metric('Calories', 400, 'Count', Scope='Policy', Food='Pizza')
    sink.flush()

    # The timestamp varies per run, so drop it before comparing.
    captured[0].pop('Timestamp')

    expected_dims = [
        {'Name': 'Policy', 'Value': 'test'},
        {'Name': 'ResType', 'Value': 'ec2'},
        {'Name': 'Food', 'Value': 'Pizza'},
        {'Name': 'Region', 'Value': 'us-east-1'},
        {'Name': 'Account', 'Value': '001100'},
    ]
    self.assertEqual(captured, [{
        'Dimensions': expected_dims,
        'MetricName': 'Calories',
        'Unit': 'Count',
        'Value': 400,
    }])
def test_metrics(self):
    """The 'aws' metrics URL selects an aws.MetricsOutput sink that can flush."""
    session_factory = self.replay_flight_data('output-aws-metrics')
    policy = Bag(name='test', resource_type='ec2')
    ctx = Bag(session_factory=session_factory, policy=policy)
    sink = output.metrics_outputs.select('aws', ctx)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)) which only reports "False is not true".
    self.assertIsInstance(sink, aws.MetricsOutput)
    sink.put_metric('ResourceCount', 101, 'Count')
    sink.flush()
def test_stream_override(self):
    """A ``stream=`` query parameter on the log URL sets the handler's stream."""
    session_factory = self.replay_flight_data('test_log_stream_override')
    ctx = Bag(session_factory=session_factory,
              options=Bag(account_id='001100', region='us-east-1'),
              policy=Bag(name='test', resource_type='ec2'))
    log_output = output.log_outputs.select(
        'aws://master/custodian?region=us-east-2&stream=testing', ctx)
    stream = log_output.get_handler()
    # assertEqual shows both values on failure; assertTrue(a == b) does not.
    self.assertEqual(stream.log_stream, 'testing')
def test_app_insights_logs(self):
    """An azure:// log URL selects AppInsightsLogOutput and accepts log records."""
    policy = Bag(name='test', resource_type='azure.vm', session_factory=Session)
    ctx = Bag(policy=policy, execution_id='00000000-0000-0000-0000-000000000000')
    with log_outputs.select('azure://00000000-0000-0000-0000-000000000000', ctx) as log:
        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(log, AppInsightsLogOutput)
        logging.getLogger('custodian.test').warning('test message')
def test_app_insights_metrics(self):
    """An azure:// metrics URL selects a MetricsOutput sink that can flush."""
    policy = Bag(name='test', resource_type='azure.vm', session_factory=Session)
    ctx = Bag(policy=policy, execution_id='00000000-0000-0000-0000-000000000000')
    sink = metrics_outputs.select(
        'azure://00000000-0000-0000-0000-000000000000', ctx)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(sink, MetricsOutput)
    sink.put_metric('ResourceCount', 101, 'Count')
    sink.flush()
def test_metrics_output_set_write_project_id(self):
    """StackDriverMetrics accepts a project_id in its conf distinct from the session's."""
    write_project_id = 'cloud-custodian'
    # The replayed session is bound to a different project than the one
    # passed through conf below.
    factory = self.replay_flight_data(
        'output-metrics', project_id='cloud-custodian-sub')
    ctx = Bag(session_factory=factory,
              policy=Bag(name='custodian-works', resource_type='gcp.function'))

    sink = StackDriverMetrics(ctx, Bag(project_id=write_project_id))
    sink.put_metric('ResourceCount', 43, 'Count', Scope='Policy')
    sink.flush()
def validate(options):
    """Validate policy config files and exit with status 1 if any are invalid.

    Each path in ``options.configs`` is parsed as YAML (``.yml``/``.yaml``)
    or JSON (``.json``), checked against the generated schema, checked for
    policy names duplicated across files, and -- if those pass -- each
    policy is instantiated with a null config and its ``validate()`` run.
    Validity is logged per file.

    :raises ValueError: if a config path does not exist or has an
        unsupported extension.
    """
    load_resources()
    if len(options.configs) < 1:
        log.error('no config files specified')
        sys.exit(1)

    used_policy_names = set()
    schm = schema.generate()
    found_errors = False

    for config_file in options.configs:
        config_file = os.path.expanduser(config_file)
        if not os.path.exists(config_file):
            raise ValueError("Invalid path for config %r" % config_file)

        options.dryrun = True
        fmt = config_file.rsplit('.', 1)[-1]

        with open(config_file) as fh:
            if fmt in ('yml', 'yaml'):
                data = yaml.safe_load(fh.read())
            elif fmt in ('json', ):
                data = json.load(fh)
            else:
                log.error("The config file must end in .json, .yml or .yaml.")
                raise ValueError(
                    "The config file must end in .json, .yml or .yaml.")

        # Errors are tracked per file. A shared list would make every file
        # after the first invalid one skip policy validation, be reported
        # invalid, and re-log the earlier file's errors.
        errors = list(schema.validate(data, schm))

        # An empty YAML document parses to None; the schema check above
        # reports that, so avoid crashing on ``.get`` here.
        policies = data.get('policies', ()) if isinstance(data, dict) else ()

        conf_policy_names = {p.get('name', 'unknown') for p in policies}
        dupes = conf_policy_names.intersection(used_policy_names)
        if dupes:
            errors.append(
                ValueError(
                    "Only one policy with a given name allowed, duplicates: %s" % (
                        ", ".join(sorted(dupes)))))
        used_policy_names = used_policy_names.union(conf_policy_names)

        if not errors:
            null_config = Config.empty(dryrun=True, account_id='na', region='na')
            for p in policies:
                try:
                    policy = Policy(p, null_config, Bag())
                    policy.validate()
                except Exception as e:
                    # Collect rather than abort so every bad policy is reported.
                    msg = "Policy: %s is invalid: %s" % (p.get('name', 'unknown'), e)
                    errors.append(msg)

        if not errors:
            log.info("Configuration valid: {}".format(config_file))
            continue

        found_errors = True
        log.error("Configuration invalid: {}".format(config_file))
        for e in errors:
            log.error("%s" % e)

    if found_errors:
        sys.exit(1)
def test_azure_function_resolved_uai_identity(self):
    """A user-assigned identity given by name resolves to its id and client_id."""
    session = mock.MagicMock()
    policy_data = {
        'name': 'sm',
        'resource': 'azure.vm',
        'mode': {
            'type': FUNCTION_EVENT_TRIGGER_MODE,
            'events': ['VmWrite'],
            'provision-options': {
                'identity': {'type': 'UserAssigned', 'id': 'mike'},
            },
        },
    }
    p = self.load_policy(policy_data)
    exec_mode = p.get_execution_mode()

    uai = dict(name='mike',
               id='/subscriptions/xyz/userAssignedIdentities/foo',
               client_id='bob')
    msi_client = session.client('azure.mgmt.msi.ManagedServiceIdentityClient')
    msi_client.user_assigned_identities.list_by_subscription.return_value = [Bag(uai)]

    identity = exec_mode._get_identity(session)
    expected = {
        'type': 'UserAssigned',
        'client_id': 'bob',
        'id': '/subscriptions/xyz/userAssignedIdentities/foo',
    }
    self.assertEqual(identity, expected)
def get_context(self, config=None, policy=None):
    """Build an ExecutionContext over ``Session`` for tests.

    When no config is given, a fresh temp directory is created, stored on
    ``self.context_output_dir``, and used as the output dir. A missing or
    falsy policy is replaced by a Bag named 'test-policy'.
    """
    if config is None:
        self.context_output_dir = self.get_temp_dir()
        config = Config.empty(output_dir=self.context_output_dir)
    effective_policy = policy if policy else Bag({'name': 'test-policy'})
    return ExecutionContext(Session, effective_policy, config)
def __init__(self, *args, **kw):
    """Initialise the X-Ray context with a local sampler and empty segment state."""
    super(XrayContext, self).__init__(*args, **kw)
    self.sampler = LocalSampler()
    # A plain Bag is used instead of thread-local storage: a policy
    # execution can span threads, so the state must be process-global.
    self._local = Bag()
    # No subsegment is active until one is begun.
    self._current_subsegment = None
def get_context(self, config=None, session_factory=None, policy=None):
    """Build an ExecutionContext for tests with an optional session factory.

    When no config is given, a fresh temp directory is created, stored on
    ``self.context_output_dir``, and used as the output dir. A missing or
    falsy policy is replaced by a Bag named "test-policy".
    """
    if config is None:
        self.context_output_dir = self.get_temp_dir()
        config = Config.empty(output_dir=self.context_output_dir)
    effective_policy = policy if policy else Bag({"name": "test-policy"})
    return ExecutionContext(session_factory, effective_policy, config)
def test_verify_parent_filter(self):
    """A 'parent' filter on keyvault-keys wires a KeyVault manager and ValueFilter."""
    manager = KeyVaultKeys(
        ExecutionContext(None, Bag(name="xyz", provider_name='azure'), Config.empty()),
        {
            'name': 'test-policy',
            'resource': 'azure.keyvault-keys',
            'filters': [{
                'type': 'parent',
                'filter': {
                    'type': 'value',
                    'key': 'name',
                    'op': 'glob',
                    'value': 'cctestkv*',
                },
            }],
        })

    self.assertEqual(len(manager.filters), 1)

    # Named to avoid shadowing the ``filter`` builtin; assertIsInstance
    # reports the actual type on failure.
    parent = manager.filters[0]
    self.assertIsInstance(parent, ParentFilter)
    self.assertIsInstance(parent.parent_manager, KeyVault)
    self.assertIsInstance(parent.parent_filter, ValueFilter)
def test_stream_override(self):
    """A 'stream' key in the CloudWatchLogOutput conf sets the handler's stream."""
    session_factory = self.replay_flight_data(
        'test_log_stream_override')
    conf = Bag({
        'region': 'us-east-2',
        'scheme': 'aws',
        'netloc': 'master',
        'path': 'custodian',
        'stream': "testing",
    })
    ctx = Bag(session_factory=session_factory,
              options=Bag(account_id='001100', region='us-east-1'),
              policy=Bag(name='test', resource_type='ec2'))
    # Not named ``output`` so the ``output`` module name is not shadowed.
    log_output = aws.CloudWatchLogOutput(ctx, conf)
    stream = log_output.get_handler()
    # assertEqual shows both values on failure; assertTrue(a == b) does not.
    self.assertEqual(stream.log_stream, 'testing')
def run(organization, hook_context, github_url, github_token,
        verbose, metrics=False, since=None, assume=None, region=None):
    """scan org repo status hooks"""
    # (The docstring above is kept verbatim; it may be used as CLI help text.)
    #
    # Queries GitHub GraphQL for the organization's repositories, runs
    # ``process_commit`` over each PR commit, then either prints the buffered
    # metrics as JSON (``metrics`` falsy) or flushes them.

    # Honour the verbose flag instead of always logging at DEBUG.
    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)
    since = dateparser.parse(
        since, settings={
            'RETURN_AS_TIMEZONE_AWARE': True, 'TO_TIMEZONE': 'UTC'})

    headers = {"Authorization": "token {}".format(github_token)}

    response = requests.post(
        github_url,
        headers=headers,
        json={'query': query, 'variables': {'organization': organization}})

    # Check the HTTP status before decoding: an error body (e.g. an HTML
    # 5xx page) is not JSON, and decoding first would raise an opaque
    # decode error instead of this message.
    if response.status_code != 200:
        raise Exception(
            "Query failed to run by returning code of {}. {}".format(
                response.status_code, response.content))
    result = response.json()
    if 'errors' in result:
        raise Exception(
            "Query failed to run by returning code of {}. {}".format(
                response.status_code, response.content))

    now = datetime.now(tzutc())
    stats = Counter()
    repo_metrics = RepoMetrics(
        Bag(session_factory=SessionFactory(region, assume_role=assume)),
        {'namespace': DEFAULT_NAMESPACE}
    )

    for r in result['data']['organization']['repositories']['nodes']:
        commits = jmespath.search(
            'pullRequests.edges[].node[].commits[].nodes[].commit[]', r)
        if not commits:
            continue
        log.debug("processing repo: %s prs: %d", r['name'], len(commits))
        repo_metrics.dims = {
            'Hook': hook_context,
            'Repo': '{}/{}'.format(organization, r['name'])}

        # Each commit represents a separate pr
        for c in commits:
            process_commit(c, r, repo_metrics, stats, since, now)

    repo_metrics.dims = None

    if stats['missing']:
        repo_metrics.put_metric(
            'RepoHookPending', stats['missing'], 'Count',
            Hook=hook_context)
        repo_metrics.put_metric(
            'RepoHookLatency', stats['missing_time'], 'Seconds',
            Hook=hook_context)

    if not metrics:
        print(dumps(repo_metrics.buf, indent=2))
        return

    repo_metrics.BUF_SIZE = 20
    repo_metrics.flush()
def test_tracer(self):
    """An exception raised inside an X-Ray subsegment is recorded as its cause."""
    session_factory = self.replay_flight_data('output-xray-trace')
    ctx = Bag(
        policy=Bag(name='test', resource_type='ec2'),
        session_factory=session_factory,
        options=Bag(account_id='644160558196'))
    ctx.get_metadata = lambda *args: {}
    tracer = aws.XrayTracer(ctx, Bag())

    with tracer:
        try:
            with tracer.subsegment('testing') as segment:
                raise ValueError()
        except ValueError:
            pass
        self.assertNotEqual(segment.cause, {})
def get_session(role, session_name, profile):
    """Create a session (assumed-role if ``role`` is truthy) wired to ApiStats.

    The region comes from AWS_DEFAULT_REGION, defaulting to 'eu-west-1'.
    Returns ``(stats, session)``.
    """
    region = os.environ.get('AWS_DEFAULT_REGION', 'eu-west-1')
    stats = ApiStats(Bag(), Config.empty())
    if not role:
        session = SessionFactory(region, profile)()
    else:
        session = assumed_session(role, session_name, region=region)
    stats(session)
    return stats, session
def get_log_output(request, output_url):
    """Build a StackDriverLogging output over a mocked session for ``output_url``.

    Registers ``reset_session_cache`` as a pytest finalizer.
    """
    ctx = ExecutionContext(
        lambda assume=False: mock.MagicMock(),
        Bag(name="xyz", provider_name="gcp", resource_type='gcp.function'),
        Config.empty(account_id='custodian-test'))
    log_output = StackDriverLogging(ctx, parse_url_config(output_url))
    request.addfinalizer(reset_session_cache)
    return log_output
def test_iam_permissions_validity(self):
    """Every resource, action and filter declares only known IAM actions."""
    cfg = Config.empty()
    missing = set()
    all_invalid = []

    perms = load_data('iam-actions.json')

    for resource_name, resource_class in manager.resources.items():
        p = Bag({'name': 'permcheck', 'resource': resource_name, 'provider_name': 'aws'})
        ctx = self.get_context(config=cfg, policy=p)
        mgr = resource_class(ctx, p)
        invalid = []

        # The service (or its explicit permission prefix) must be known.
        rtype = mgr.resource_type
        prefixes = (rtype.service, getattr(rtype, 'permission_prefix', None))
        if not any(prefix in perms for prefix in prefixes):
            missing.add("%s->%s" % (resource_name, rtype.service))
            continue

        invalid.extend(
            self.check_permissions(perms, mgr.get_permissions(), resource_name))

        for action_name, action_class in resource_class.action_registry.items():
            p['actions'] = [action_name]
            invalid.extend(
                self.check_permissions(
                    perms,
                    action_class({}, mgr).get_permissions(),
                    "{k}.actions.{n}".format(k=resource_name, n=action_name)))

        for filter_name, filter_class in resource_class.filter_registry.items():
            # Boolean combinators and 'missing' carry no permissions.
            if filter_name in ('or', 'and', 'not', 'missing'):
                continue
            p['filters'] = [filter_name]
            invalid.extend(
                self.check_permissions(
                    perms,
                    filter_class({}, mgr).get_permissions(),
                    "{k}.filters.{n}".format(k=resource_name, n=filter_name)))

        # elasticloadbalancing permissions are ignored for this check.
        for location, perm_set in invalid:
            kept = [i for i in perm_set if not i.startswith('elasticloadbalancing')]
            if kept:
                all_invalid.append((location, kept))

    if missing:
        raise ValueError(
            "resources missing service %s" % ('\n'.join(sorted(missing))))
    if all_invalid:
        raise ValueError(
            "invalid permissions \n %s" % ('\n'.join(sorted(map(str, all_invalid)))))
def get_azure_output(self):
    """Build an AzureStorageOutput for a fixed blob URL; its root_dir is removed on cleanup."""
    ctx = ExecutionContext(
        None,
        Bag(name="xyz"),
        Config.empty(output_dir="azure://mystorage.blob.core.windows.net/logs"),
    )
    azure_output = AzureStorageOutput(ctx)
    self.addCleanup(shutil.rmtree, azure_output.root_dir)
    return azure_output
def test_check_permissions(self):
    """Every GCP resource, action and filter declares only known IAM permissions."""
    load_resources(('gcp.*', ))
    missing = []
    invalid = []
    iam_path = os.path.join(os.path.dirname(__file__), 'data', 'iam-permissions.json')
    with open(iam_path) as fh:
        valid_perms = set(json.load(fh).get('permissions'))

    def record(label, found):
        # An empty permission list counts as missing; unknown ones as invalid.
        if not found:
            missing.append(label)
        invalid.extend((label, perm) for perm in found if perm not in valid_perms)

    cfg = Config.empty()

    for name, resource_class in resources.items():
        policy = Bag({
            'name': 'permcheck',
            'resource': 'gcp.%s' % name,
            'provider_name': 'gcp'
        })
        ctx = self.get_context(config=cfg, policy=policy)
        mgr = resource_class(ctx, policy)
        record(name, mgr.get_permissions())

        for action_name, action_class in list(resource_class.action_registry.items()):
            if action_name in ALLOWED_NOPERM:
                continue
            policy['actions'] = [action_name]
            record('%s.actions.%s' % (name, action_name),
                   action_class({}, mgr).get_permissions())

        for filter_name, filter_class in list(resource_class.filter_registry.items()):
            if filter_name in ALLOWED_NOPERM:
                continue
            policy['filters'] = [filter_name]
            record('%s.filters.%s' % (name, filter_name),
                   filter_class({}, mgr).get_permissions())

    if missing:
        self.fail('missing permissions %d on \n\t%s' % (
            len(missing),
            '\n\t'.join(sorted(missing))))

    if invalid:
        self.fail('invalid permissions %d on \n\t%s' % (
            len(invalid),
            '\n\t'.join(map(str, sorted(invalid)))))
def get_azure_output(self, custom_pyformat=None):
    """Build an AzureStorageOutput, optionally appending ``custom_pyformat`` to the URL.

    The output's root_dir is removed on test cleanup.
    """
    base_url = "azure://mystorage.blob.core.windows.net/logs"
    output_dir = (AzureStorageOutput.join(base_url, custom_pyformat)
                  if custom_pyformat else base_url)
    ctx = ExecutionContext(None,
                           Bag(name="xyz", provider_name='azure'),
                           Config.empty(output_dir=output_dir))
    storage_output = AzureStorageOutput(ctx, {'url': output_dir})
    self.addCleanup(shutil.rmtree, storage_output.root_dir)
    return storage_output
def watch(limit):
    """watch scan rates across the cluster"""
    # Loops forever. Every ``period`` seconds it takes a new db snapshot,
    # computes per-bucket and total scan rates relative to the previous
    # snapshot, and prints a table with a totals row first.
    # ``limit`` (if truthy) caps the number of bucket rows shown.
    period = 5.0
    prev = db.db()
    prev_totals = None

    while True:
        click.clear()
        time.sleep(period)
        cur = db.db()
        # Per-bucket rates are written into this dict via b.data below.
        # NOTE(review): assumes each bucket's .data references the snapshot's
        # shared data dict, so b.gkrate can read it back — confirm in db module.
        cur.data['gkrate'] = {}
        progress = []
        prev_buckets = {b.bucket_id: b for b in prev.buckets()}

        totals = {'scanned': 0, 'krate': 0, 'lrate': 0, 'bucket_id': 'totals'}

        for b in cur.buckets():
            # Buckets with nothing scanned yet are skipped entirely.
            if not b.scanned:
                continue

            totals['scanned'] += b.scanned
            totals['krate'] += b.krate
            totals['lrate'] += b.lrate

            if b.bucket_id not in prev_buckets:
                # New since the last snapshot: all scanned keys count as this period's.
                b.data['gkrate'][b.bucket_id] = b.scanned / period
            elif b.scanned == prev_buckets[b.bucket_id].scanned:
                # No progress this period: counted in totals but not listed.
                continue
            else:
                b.data['gkrate'][b.bucket_id] = (
                    b.scanned - prev_buckets[b.bucket_id].scanned) / period
            progress.append(b)

        # The first iteration has no previous totals to diff against.
        if prev_totals is None:
            totals['gkrate'] = '...'
        else:
            totals['gkrate'] = (totals['scanned'] - prev_totals['scanned']) / period
        prev = cur
        prev_totals = totals

        # Fastest-progressing buckets first.
        progress = sorted(progress, key=lambda x: x.gkrate, reverse=True)
        if limit:
            progress = progress[:limit]

        progress.insert(0, Bag(totals))
        format_plain(
            progress, None,
            explicit_only=True,
            keys=['bucket_id', 'scanned', 'gkrate', 'lrate', 'krate'])
def get_blob_output(request, output_url=None, cleanup=True):
    """Build a GCPStorageOutput over a mocked session.

    ``output_url`` defaults to "gs://cloud-custodian/policies". When
    ``cleanup`` is truthy the output's root_dir is removed by a pytest
    finalizer. ``reset_session_cache`` is always registered as a finalizer.
    """
    url = "gs://cloud-custodian/policies" if output_url is None else output_url
    ctx = ExecutionContext(
        lambda assume=False: mock.MagicMock(),
        Bag(name="xyz", provider_name="gcp"),
        Config.empty(output_dir=url, account_id='custodian-test'))
    blob_output = GCPStorageOutput(ctx, parse_url_config(url))
    if cleanup:
        request.addfinalizer(lambda: shutil.rmtree(blob_output.root_dir))  # noqa
    request.addfinalizer(reset_session_cache)
    return blob_output
def test_app_insights_metrics(self, put_mock):
    """An azure:// metrics sink forwards metrics with policy/execution dimensions."""
    policy = self.load_policy({
        'name': 'test-rg',
        'resource': 'azure.resourcegroup'
    })
    ctx = Bag(policy=policy, execution_id='00000000-0000-0000-0000-000000000000')
    sink = metrics_outputs.select('azure://00000000-0000-0000-0000-000000000000', ctx)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(sink, MetricsOutput)
    sink.put_metric('ResourceCount', 101, 'Count')
    sink.flush()
    put_mock.assert_called_once_with(
        'test-rg',
        [{
            'Name': 'ResourceCount',
            'Value': 101,
            'Dimensions': {
                'Policy': 'test-rg',
                'ResType': 'azure.resourcegroup',
                'SubscriptionId': local_session(Session).get_subscription_id(),
                'ExecutionId': '00000000-0000-0000-0000-000000000000',
                'ExecutionMode': 'pull',
                'Unit': 'Count',
            },
        }])
def test_tracer(self):
    """The subsegment that saw a ValueError ends up with a non-empty cause."""
    session_factory = self.replay_flight_data('output-xray-trace')
    test_policy = Bag(name='test', resource_type='ec2')
    account_options = Bag(account_id='644160558196')
    ctx = Bag(policy=test_policy,
              session_factory=session_factory,
              options=account_options)
    # The tracer asks the context for metadata; nothing is needed here.
    ctx.get_metadata = lambda *args: {}
    xray = aws.XrayTracer(ctx, Bag())

    with xray:
        try:
            with xray.subsegment('testing') as sub:
                raise ValueError()
        except ValueError:
            pass
        self.assertNotEqual(sub.cause, {})
class BaseTest(TestUtils, AzureVCRBaseTest):
    """ Azure base testing class.
    """
    # Previously this docstring followed ``test_context``, which made it a
    # dead string expression and left ``BaseTest.__doc__`` as None.

    # Shared execution context, built once when the class body executes.
    test_context = ExecutionContext(
        Session,
        Bag(name="xyz", provider_name='azure'),
        Config.empty()
    )

    def __init__(self, *args, **kwargs):
        super(BaseTest, self).__init__(*args, **kwargs)
        # Subclasses set this to True to keep real LRO polling.
        self._requires_polling = False

    @classmethod
    def setUpClass(cls, *args, **kwargs):
        super(BaseTest, cls).setUpClass(*args, **kwargs)
        # With a fake access token, JWT decoding is patched to yield the
        # default tenant id.
        if os.environ.get(constants.ENV_ACCESS_TOKEN) == "fake_token":
            cls._token_patch = patch(
                'c7n_azure.session.jwt.decode',
                return_value={'tid': DEFAULT_TENANT_ID})
            cls._token_patch.start()

    @classmethod
    def tearDownClass(cls, *args, **kwargs):
        super(BaseTest, cls).tearDownClass(*args, **kwargs)
        if os.environ.get(constants.ENV_ACCESS_TOKEN) == "fake_token":
            cls._token_patch.stop()

    def setUp(self):
        super(BaseTest, self).setUp()
        ThreadHelper.disable_multi_threading = True

        # We always patch the date for recordings so URLs that involve dates match up
        if self.vcr_enabled:
            self._utc_patch = patch.object(utils, 'utcnow', self._get_test_date)
            self._utc_patch.start()
            self.addCleanup(self._utc_patch.stop)

            self._now_patch = patch.object(utils, 'now', self._get_test_date)
            self._now_patch.start()
            self.addCleanup(self._now_patch.stop)

        if not self._requires_polling:
            # Patch Poller with constructor that always disables polling
            # This breaks blocking on long running operations (resource creation).
            self._lro_patch = patch.object(msrest.polling.LROPoller, '__init__', BaseTest.lro_init)
            self._lro_patch.start()
            self.addCleanup(self._lro_patch.stop)

        if self.is_playback():
            if self._requires_polling:
                # If using polling we need to monkey patch the timeout during playback
                # or we'll have long sleeps introduced into our test runs
                Session._old_client = Session.client
                Session.client = BaseTest.session_client_wrapper
                self.addCleanup(BaseTest.session_client_cleanup)

            if constants.ENV_ACCESS_TOKEN in os.environ:
                self._tenant_patch = patch('c7n_azure.session.Session.get_tenant_id',
                                           return_value=DEFAULT_TENANT_ID)
                self._tenant_patch.start()
                self.addCleanup(self._tenant_patch.stop)

                self._subscription_patch = patch('c7n_azure.session.Session.get_subscription_id',
                                                 return_value=DEFAULT_SUBSCRIPTION_ID)
                self._subscription_patch.start()
                self.addCleanup(self._subscription_patch.stop)

        self.session = local_session(Session)

    def _get_test_date(self, tz=None):
        """Return a stable 'now' for recordings.

        Uses the first recorded response's ``date`` header, pinned to
        23:59:59 (naive). Falls back to the real ``datetime.now(tz=tz)``
        when no recorded date is available.
        """
        header_date = self.cassette.responses[0]['headers'].get('date') \
            if self.cassette.responses else None

        if header_date:
            test_date = datetime.datetime(*eut.parsedate(header_date[0])[:6])
        else:
            return datetime.datetime.now(tz=tz)
        return test_date.replace(hour=23, minute=59, second=59, microsecond=0)

    def sleep_in_live_mode(self, interval=60):
        """Sleep ``interval`` seconds only when not replaying recordings."""
        if not self.is_playback():
            sleep(interval)

    @staticmethod
    def setup_account():
        # Find actual name of storage account provisioned in our test environment
        s = Session()
        client = s.client('azure.mgmt.storage.StorageManagementClient')
        accounts = list(client.storage_accounts.list())
        matching_account = [a for a in accounts if a.name.startswith("cctstorage")]
        return matching_account[0]

    @staticmethod
    def sign_out_patch():
        """Return a patch.dict that clears the environment and blanks Azure credentials."""
        return patch.dict(os.environ,
                          {
                              constants.ENV_TENANT_ID: '',
                              constants.ENV_SUB_ID: '',
                              constants.ENV_CLIENT_ID: '',
                              constants.ENV_CLIENT_SECRET: ''
                          }, clear=True)

    @staticmethod
    def lro_init(self, client, initial_response, deserialization_callback, polling_method):
        """Replacement ``LROPoller.__init__`` that always uses NoPolling.

        ``polling_method`` is accepted for signature compatibility and ignored.
        """
        self._client = client if isinstance(client, ServiceClient) else client._client
        self._response = initial_response.response if \
            isinstance(initial_response, ClientRawResponse) else \
            initial_response
        self._callbacks = []  # list of callables
        self._polling_method = msrest.polling.NoPolling()

        if isinstance(deserialization_callback, type) and \
                issubclass(deserialization_callback, Model):
            deserialization_callback = deserialization_callback.deserialize

        # Might raise a CloudError
        self._polling_method.initialize(self._client, self._response, deserialization_callback)

        self._thread = None
        self._done = None
        self._exception = None

    @staticmethod
    def session_client_cleanup():
        """Restore the original ``Session.client`` saved in setUp."""
        Session.client = Session._old_client

    @staticmethod
    def session_client_wrapper(self, client):
        """Wrap ``Session.client`` to zero the long-running-operation timeout."""
        client = Session._old_client(self, client)
        client.config.long_running_operation_timeout = 0
        return client
class UtilsTest(BaseTest):
    """Tests for c7n_azure utility helpers."""

    def test_get_subscription_id(self):
        self.assertEqual(ResourceIdParser.get_subscription_id(RESOURCE_ID), DEFAULT_SUBSCRIPTION_ID)

    def test_get_namespace(self):
        self.assertEqual(ResourceIdParser.get_namespace(RESOURCE_ID), "Microsoft.Compute")
        self.assertEqual(ResourceIdParser.get_namespace(RESOURCE_ID_CHILD), "Microsoft.Sql")

    def test_get_resource_group(self):
        self.assertEqual(ResourceIdParser.get_resource_group(RESOURCE_ID), "rgtest")

    def test_get_resource_type(self):
        self.assertEqual(ResourceIdParser.get_resource_type(RESOURCE_ID), "virtualMachines")
        self.assertEqual(ResourceIdParser.get_resource_type(RESOURCE_ID_CHILD), "servers/databases")

    def test_get_full_type(self):
        self.assertEqual(ResourceIdParser.get_full_type(RESOURCE_ID),
                         "Microsoft.Compute/virtualMachines")

    def test_resource_name(self):
        self.assertEqual(ResourceIdParser.get_resource_name(RESOURCE_ID), "nametest")

    def test_math_mean(self):
        self.assertEqual(Math.mean([4, 5, None, 3]), 4)
        self.assertEqual(Math.mean([None]), 0)
        self.assertEqual(Math.mean([3, 4]), 3.5)

    def test_math_sum(self):
        self.assertEqual(Math.sum([4, 5, None, 3]), 12)
        self.assertEqual(Math.sum([None]), 0)
        self.assertEqual(Math.sum([3.5, 4]), 7.5)

    def test_string_utils_equal(self):
        # Case insensitive matches
        self.assertTrue(StringUtils.equal("FOO", "foo"))
        self.assertTrue(StringUtils.equal("fOo", "FoO"))
        self.assertTrue(StringUtils.equal("ABCDEFGH", "abcdefgh"))
        self.assertFalse(StringUtils.equal("Foo", "Bar"))

        # Case sensitive matches
        self.assertFalse(StringUtils.equal("Foo", "foo", False))
        self.assertTrue(StringUtils.equal("foo", "foo", False))
        self.assertTrue(StringUtils.equal("fOo", "fOo", False))
        # Previously duplicated the case-insensitive check; exercise the
        # case-sensitive path as the section heading intends.
        self.assertFalse(StringUtils.equal("Foo", "Bar", False))

        # Strip whitespace matches
        self.assertTrue(StringUtils.equal(" Foo ", "foo"))
        self.assertTrue(StringUtils.equal("Foo", " foo "))
        self.assertTrue(StringUtils.equal(" Foo ", "Foo", False))
        self.assertTrue(StringUtils.equal("Foo", " Foo ", False))

        # Returns false for non string types
        self.assertFalse(StringUtils.equal(1, "foo"))
        self.assertFalse(StringUtils.equal("foo", 1))
        self.assertFalse(StringUtils.equal(True, False))

    def test_get_tag_value(self):
        resource = {'tags': {'tag1': 'value1', 'tAg2': 'VaLuE2', 'TAG3': 'VALUE3'}}
        self.assertEqual(TagHelper.get_tag_value(resource, 'tag1', True), 'value1')
        self.assertEqual(TagHelper.get_tag_value(resource, 'tag2', True), 'VaLuE2')
        self.assertEqual(TagHelper.get_tag_value(resource, 'tag3', True), 'VALUE3')

    def test_get_ports(self):
        self.assertEqual(PortsRangeHelper.get_ports_set_from_string("5, 4-5, 9"), set([4, 5, 9]))
        rule = {'properties': {'destinationPortRange': '10-12'}}
        self.assertEqual(PortsRangeHelper.get_ports_set_from_rule(rule), set([10, 11, 12]))
        rule = {'properties': {'destinationPortRanges': ['80', '10-12']}}
        self.assertEqual(PortsRangeHelper.get_ports_set_from_rule(rule), set([10, 11, 12, 80]))

    def test_validate_ports_string(self):
        self.assertEqual(PortsRangeHelper.validate_ports_string('80'), True)
        self.assertEqual(PortsRangeHelper.validate_ports_string('22-26'), True)
        self.assertEqual(PortsRangeHelper.validate_ports_string('80,22'), True)
        self.assertEqual(PortsRangeHelper.validate_ports_string('80,22-26'), True)
        self.assertEqual(PortsRangeHelper.validate_ports_string('80,22-26,30-34'), True)
        self.assertEqual(PortsRangeHelper.validate_ports_string('65537'), False)
        self.assertEqual(PortsRangeHelper.validate_ports_string('-1'), False)
        self.assertEqual(PortsRangeHelper.validate_ports_string('10-8'), False)
        self.assertEqual(PortsRangeHelper.validate_ports_string('80,30,25-65538'), False)
        self.assertEqual(PortsRangeHelper.validate_ports_string('65536-65537'), False)

    def test_get_ports_strings_from_list(self):
        self.assertEqual(PortsRangeHelper.get_ports_strings_from_list([]),
                         [])
        self.assertEqual(PortsRangeHelper.get_ports_strings_from_list([10, 11]),
                         ['10-11'])
        self.assertEqual(PortsRangeHelper.get_ports_strings_from_list([10, 12, 13, 14]),
                         ['10', '12-14'])
        self.assertEqual(PortsRangeHelper.get_ports_strings_from_list([10, 12, 13, 14, 20, 21, 22]),
                         ['10', '12-14', '20-22'])

    def test_build_ports_dict(self):
        securityRules = [
            {'properties': {'destinationPortRange': '80-84',
                            'priority': 100,
                            'direction': 'Outbound',
                            'access': 'Allow',
                            'protocol': 'TCP'}},
            {'properties': {'destinationPortRange': '85-89',
                            'priority': 110,
                            'direction': 'Outbound',
                            'access': 'Allow',
                            'protocol': 'UDP'}},
            {'properties': {'destinationPortRange': '80-84',
                            'priority': 120,
                            'direction': 'Inbound',
                            'access': 'Deny',
                            'protocol': 'TCP'}},
            {'properties': {'destinationPortRange': '85-89',
                            'priority': 130,
                            'direction': 'Inbound',
                            'access': 'Deny',
                            'protocol': 'UDP'}},
            {'properties': {'destinationPortRange': '80-89',
                            'priority': 140,
                            'direction': 'Inbound',
                            'access': 'Allow',
                            'protocol': '*'}}]
        nsg = {'properties': {'securityRules': securityRules}}

        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Inbound', 'TCP'),
                         {k: k > 84 for k in range(80, 90)})
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Inbound', 'UDP'),
                         {k: k < 85 for k in range(80, 90)})
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Inbound', '*'),
                         {k: False for k in range(80, 90)})
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Outbound', 'TCP'),
                         {k: True for k in range(80, 85)})
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Outbound', 'UDP'),
                         {k: True for k in range(85, 90)})
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Outbound', '*'),
                         {k: True for k in range(80, 90)})

    def test_snake_to_camel(self):
        self.assertEqual(StringUtils.snake_to_camel(""), "")
        self.assertEqual(StringUtils.snake_to_camel("test"), "test")
        self.assertEqual(StringUtils.snake_to_camel("test_abc"), "testAbc")
        self.assertEqual(StringUtils.snake_to_camel("test_abc_def"), "testAbcDef")

    def test_naming_hash(self):
        source = 'Lorem ipsum dolor sit amet'
        source2 = 'amet sit dolor ipsum Lorem'
        self.assertEqual(StringUtils.naming_hash(source), '16aba539')
        self.assertEqual(StringUtils.naming_hash(source, 10), '16aba5393a')
        self.assertNotEqual(StringUtils.naming_hash(source), StringUtils.naming_hash(source2))

    @patch('azure.mgmt.applicationinsights.operations.ComponentsOperations.get',
           return_value=type(str('result_data'), (), {'instrumentation_key': GUID}))
    def test_app_insights_get_instrumentation_key(self, mock_handler_run):
        self.assertEqual(AppInsightsHelper.get_instrumentation_key('azure://' + GUID), GUID)
        self.assertEqual(AppInsightsHelper.get_instrumentation_key('azure://resourceGroup/name'), GUID)
        mock_handler_run.assert_called_once_with('resourceGroup', 'name')

    @patch('c7n_azure.utils.send_logger.debug')
    def test_custodian_azure_send_override_200(self, logger):
        mock = Mock()
        mock.send = types.MethodType(custodian_azure_send_override, mock)

        response_dict = {
            'headers': {'x-ms-ratelimit-remaining-subscription-reads': '12000'},
            'status_code': 200
        }
        mock.orig_send.return_value = type(str('response'), (), response_dict)
        mock.send('')

        self.assertEqual(mock.orig_send.call_count, 1)
        self.assertEqual(logger.call_count, 2)

    # Stacked @patch decorators inject mocks bottom-up: the first argument
    # is the ``warning`` mock, the second the ``debug`` mock.
    @patch('c7n_azure.utils.send_logger.debug')
    @patch('c7n_azure.utils.send_logger.warning')
    def test_custodian_azure_send_override_429(self, logger_warning, logger_debug):
        mock = Mock()
        mock.send = types.MethodType(custodian_azure_send_override, mock)

        response_dict = {
            'headers': {'Retry-After': 0},
            'status_code': 429
        }
        mock.orig_send.return_value = type(str('response'), (), response_dict)
        mock.send('')

        self.assertEqual(mock.orig_send.call_count, 3)
        self.assertEqual(logger_debug.call_count, 3)
        self.assertEqual(logger_warning.call_count, 3)

    @patch('c7n_azure.utils.send_logger.error')
    def test_custodian_azure_send_override_429_long_retry(self, logger):
        mock = Mock()
        mock.send = types.MethodType(custodian_azure_send_override, mock)

        response_dict = {
            'headers': {'Retry-After': 60},
            'status_code': 429
        }
        mock.orig_send.return_value = type(str('response'), (), response_dict)
        mock.send('')

        self.assertEqual(mock.orig_send.call_count, 1)
        self.assertEqual(logger.call_count, 1)

    # Fixture for the management-group entity listing used below.
    managed_group_return_value = ([
        Bag({'name': '/providers/Microsoft.Management/managementGroups/cc-test-1',
             'type': '/providers/Microsoft.Management/managementGroups'}),
        Bag({'name': '/providers/Microsoft.Management/managementGroups/cc-test-2',
             'type': '/providers/Microsoft.Management/managementGroups'}),
        Bag({'name': DEFAULT_SUBSCRIPTION_ID,
             'type': '/subscriptions'}),
        Bag({'name': GUID,
             'type': '/subscriptions'}),
    ])

    @patch('azure.mgmt.managementgroups.operations.EntitiesOperations.list',
           return_value=managed_group_return_value)
    def test_managed_group_helper(self, _1):
        sub_ids = ManagedGroupHelper.get_subscriptions_list('test-group', "")
        self.assertEqual(sub_ids, [DEFAULT_SUBSCRIPTION_ID, GUID])

    @patch('msrestazure.azure_active_directory.MSIAuthentication')
    def test_get_keyvault_secret(self, _1):
        mock = Mock()
        mock.value = '{"client_id": "client", "client_secret": "secret"}'
        with patch('azure.common.credentials.ServicePrincipalCredentials.__init__',
                   return_value=None), \
                patch('azure.keyvault.v7_0.KeyVaultClient.get_secret', return_value=mock):

            reload(sys.modules['c7n_azure.utils'])

            result = get_keyvault_secret(None, 'https://testkv.vault.net/secrets/testsecret/123412')
            self.assertEqual(result, mock.value)

    # Test relies on substitute data in Azure Common, not designed for live data
    @pytest.mark.skiplive
    def test_get_service_tag_ip_space(self):
        # Get with region
        result = get_service_tag_ip_space('ApiManagement', 'WestUS')
        self.assertEqual(3, len(result))
        self.assertEqual({"13.64.39.16/32",
                          "40.112.242.148/31",
                          "40.112.243.240/28"}, set(result))

        # Get without region
        result = get_service_tag_ip_space('ApiManagement')
        self.assertEqual(5, len(result))
        self.assertEqual({"13.69.64.76/31",
                          "13.69.66.144/28",
                          "23.101.67.140/32",
                          "51.145.179.78/32",
                          "137.117.160.56/32"}, set(result))

        # Invalid tag
        result = get_service_tag_ip_space('foo')
        self.assertEqual(0, len(result))

    def test_is_resource_group_id(self):
        self.assertTrue(is_resource_group_id('/subscriptions/GUID/resourceGroups/rg'))
        self.assertTrue(is_resource_group_id('/subscriptions/GUID/resourceGroups/rg/'))
        self.assertTrue(is_resource_group_id('/Subscriptions/GUID/resourcegroups/rg'))

        self.assertFalse(is_resource_group_id('/subscriptions/GUID/rg/'))
        self.assertFalse(is_resource_group_id('subscriptions/GUID/rg/'))
        self.assertFalse(is_resource_group_id('/GUID/rg/'))
        self.assertFalse(is_resource_group_id('/subscriptions/GUID/rg/providers/vm/vm'))
        self.assertFalse(is_resource_group_id('/subscriptions/GUID/rg/providers'))
        self.assertFalse(is_resource_group_id('/subscriptions/GUID/rg/p'))

    def test_is_resource_group(self):
        self.assertTrue(is_resource_group({'type': 'resourceGroups'}))
        self.assertFalse(is_resource_group({'type': 'virtualMachines'}))
def test_default_account_id_assume(self):
    """With no account_id set, _default_account_id fills it in for an assumed role."""
    role_arn = 'arn:aws:iam::644160558196:role/custodian-mu'
    config = Bag(assume_role=role_arn, account_id=None)
    aws._default_account_id(config)
    # Expected to match the account number embedded in role_arn.
    self.assertEqual(config.account_id, '644160558196')