Example #1
0
    def test_default_log_group(self):
        """Check how log output URLs are mapped to a log group name.

        Each case selects a log output for a URL and compares the resulting
        ``log_group``; cases that carry an expected stream name additionally
        check ``construct_stream_name()``.
        """
        context = Bag(
            session_factory=None,
            options=Bag(account_id='001100', region='us-east-1'),
            policy=Bag(name='test', resource_type='ec2'))

        # (output url, expected log group, expected stream name or None)
        cases = (
            ('custodian/xyz', 'custodian/xyz', 'test'),
            ('/custodian/xyz/', 'custodian/xyz', None),
            ('aws://somewhere/out/there', 'somewhere/out/there', None),
            ('aws:///somewhere/out', 'somewhere/out', None),
            ('aws://somewhere', 'somewhere', None),
            ("aws:///somewhere/out?stream={region}/{policy}",
             'somewhere/out', 'us-east-1/test'),
        )

        for url, expected_group, expected_stream in cases:
            selected = output.log_outputs.select(url, context)
            self.assertEqual(selected.log_group, expected_group)
            if expected_stream is not None:
                self.assertEqual(
                    selected.construct_stream_name(), expected_stream)
Example #2
0
def eperm(provider, el, r=None):
    """Return the permissions associated with a policy element.

    If ``el.permissions`` is truthy it is returned as-is. Otherwise a
    throwaway policy containing only this element is loaded (without
    validation) and the permissions are read from the element instance
    bound to that policy.

    Args:
        provider: Provider prefix used to build the policy's ``resource``
            value, e.g. ``'aws'``, ``'gcp'`` or ``'azure'``.
        el: The element; must expose ``permissions`` and ``type``.
        r: Optional resource object with a ``type`` attribute. When it is
            ``None`` or has no type, a placeholder type is chosen for the
            three providers listed above.

    Returns:
        ``el.permissions`` when set, the loaded element's
        ``get_permissions()`` result otherwise, or ``[]`` if loading the
        throwaway policy fails (the error is printed).
    """
    if el.permissions:
        return el.permissions
    element_type = get_element_type(el)

    if r is None or r.type is None:
        # dummy resource type for policy
        placeholder_types = {
            'aws': 'kinesis',
            'gcp': 'instance',
            'azure': 'vm',
        }
        if provider in placeholder_types:
            r = Bag({'type': placeholder_types[provider]})

    loader = PolicyLoader(Config.empty())
    pdata = {
        'name': f'permissions-{r.type}',
        'resource': f'{provider}.{r.type}',
        element_type: get_element_data(element_type, el),
    }

    try:
        pset = loader.load_data({'policies': [pdata]}, ':mem:', validate=False)
    except Exception as e:
        print(f'error loading {el} as {element_type}:{el.type} error: {e} \n {pdata}')
        return []

    first_policy = list(pset)[0]
    return get_policy_element(el, first_policy).get_permissions()
Example #3
0
    def test_metrics_output(self):
        """Write one metric through StackDriverMetrics and read it back.

        A ``ResourceCount`` value of 43 is put and flushed, then the
        monitoring ``projects.timeSeries`` client lists the matching custom
        metric for the last five minutes and the returned series is compared
        against the expected payload.
        """
        project_id = 'cloud-custodian'
        factory = self.replay_flight_data(
            'output-metrics', project_id=project_id)

        policy = Bag(name='custodian-works', resource_type='gcp.function')
        sink = StackDriverMetrics(
            Bag(session_factory=factory, policy=policy), Bag())
        sink.put_metric('ResourceCount', 43, 'Count', Scope='Policy')
        sink.flush()

        # Only pauses while recording; presumably lets the freshly written
        # point become queryable before the read below -- TODO confirm.
        if self.recording:
            time.sleep(42)

        client = factory().client('monitoring', 'v3', 'projects.timeSeries')

        window_start = (
            datetime.datetime.utcnow() - datetime.timedelta(minutes=5)
        ).isoformat('T') + 'Z'
        window_end = datetime.datetime.utcnow().isoformat('T') + 'Z'
        query = {
            'name': 'projects/{}'.format(project_id),
            'filter':
                'metric.type="custom.googleapis.com/custodian/policy/resourcecount"',
            'pageSize': 3,
            'interval_startTime': window_start,
            'interval_endTime': window_end,
        }
        results = client.execute_command('list', query)

        expected_series = {
            'metric': {
                'labels': {
                    'policy': 'custodian-works',
                    'project_id': 'cloud-custodian',
                },
                'type': 'custom.googleapis.com/custodian/policy/resourcecount',
            },
            'metricKind': 'GAUGE',
            'points': [{
                'interval': {
                    'endTime': '2018-08-12T22:30:53.524505Z',
                    'startTime': '2018-08-12T22:30:53.524505Z',
                },
                'value': {'int64Value': '43'},
            }],
            'resource': {
                'labels': {'project_id': 'cloud-custodian'},
                'type': 'global',
            },
            'valueType': 'INT64',
        }
        self.assertEqual(results['timeSeries'], [expected_series])
Example #4
0
    def test_metrics_destination_dims(self):
        """Check the dimensions attached to a metric sent to another region.

        The output is configured for ``us-east-2`` while the execution
        options say ``us-east-1``; the captured metric is expected to carry
        ``Region`` and ``Account`` dimensions taken from the options.
        """
        captured = []

        class Metrics(aws.MetricsOutput):
            # Collect the metrics locally instead of publishing them.

            def _put_metrics(self, ns, metrics):
                captured.extend(metrics)

        policy = Bag(name='test', resource_type='ec2')
        options = Bag(account_id='001100', region='us-east-1')
        ctx = Bag(session_factory=None, options=options, policy=policy)
        conf = Bag({'region': 'us-east-2', 'scheme': 'aws', 'netloc': 'master'})

        sink = Metrics(ctx, conf)
        sink.put_metric('Calories', 400, 'Count', Scope='Policy', Food='Pizza')
        sink.flush()

        # The timestamp varies between runs, so drop it before comparing.
        del captured[0]['Timestamp']

        expected_dimensions = [
            {'Name': 'Policy', 'Value': 'test'},
            {'Name': 'ResType', 'Value': 'ec2'},
            {'Name': 'Food', 'Value': 'Pizza'},
            {'Name': 'Region', 'Value': 'us-east-1'},
            {'Name': 'Account', 'Value': '001100'},
        ]
        expected_metric = {
            'Dimensions': expected_dimensions,
            'MetricName': 'Calories',
            'Unit': 'Count',
            'Value': 400,
        }
        self.assertEqual(captured, [expected_metric])
Example #5
0
 def test_metrics(self):
     """Selecting the ``aws`` metrics output yields an ``aws.MetricsOutput``.

     The selected sink is then exercised by putting one ``ResourceCount``
     metric and flushing it against the ``output-aws-metrics`` flight data.
     """
     session_factory = self.replay_flight_data('output-aws-metrics')
     policy = Bag(name='test', resource_type='ec2')
     ctx = Bag(session_factory=session_factory, policy=policy)
     sink = output.metrics_outputs.select('aws', ctx)
     # assertIsInstance reports the actual type on failure, unlike
     # assertTrue(isinstance(...)) which only says "False is not true".
     self.assertIsInstance(sink, aws.MetricsOutput)
     sink.put_metric('ResourceCount', 101, 'Count')
     sink.flush()
Example #6
0
 def test_stream_override(self):
     """A ``stream`` query parameter in the output URL names the log stream.

     The log output is selected from a URL carrying ``stream=testing`` and
     the handler it returns must report that value as ``log_stream``.
     """
     session_factory = self.replay_flight_data('test_log_stream_override')
     ctx = Bag(session_factory=session_factory,
               options=Bag(account_id='001100', region='us-east-1'),
               policy=Bag(name='test', resource_type='ec2'))
     log_output = output.log_outputs.select(
         'aws://master/custodian?region=us-east-2&stream=testing', ctx)
     stream = log_output.get_handler()
     # assertEqual shows both values on failure; assertTrue(a == b) does not.
     self.assertEqual(stream.log_stream, 'testing')
Example #7
0
 def test_app_insights_logs(self):
     """An ``azure://`` log URL selects an ``AppInsightsLogOutput``.

     While the output is active as a context manager, a warning is emitted
     on the ``custodian.test`` logger.
     """
     policy = Bag(name='test',
                  resource_type='azure.vm',
                  session_factory=Session)
     ctx = Bag(policy=policy,
               execution_id='00000000-0000-0000-0000-000000000000')
     with log_outputs.select('azure://00000000-0000-0000-0000-000000000000',
                             ctx) as log:
         # assertIsInstance reports the offending type when it fails.
         self.assertIsInstance(log, AppInsightsLogOutput)
         logging.getLogger('custodian.test').warning('test message')
Example #8
0
 def test_app_insights_metrics(self):
     """An ``azure://`` metrics URL selects a ``MetricsOutput`` sink.

     The sink is then exercised by putting one ``ResourceCount`` metric and
     flushing it.
     """
     policy = Bag(name='test',
                  resource_type='azure.vm',
                  session_factory=Session)
     ctx = Bag(policy=policy,
               execution_id='00000000-0000-0000-0000-000000000000')
     sink = metrics_outputs.select(
         'azure://00000000-0000-0000-0000-000000000000', ctx)
     # assertIsInstance reports the offending type when it fails.
     self.assertIsInstance(sink, MetricsOutput)
     sink.put_metric('ResourceCount', 101, 'Count')
     sink.flush()
Example #9
0
 def test_metrics_output_set_write_project_id(self):
     """Put and flush a metric when the output config names its own project.

     The session factory is created for ``cloud-custodian-sub`` while the
     output configuration carries ``project_id='cloud-custodian'``. The test
     has no explicit assertions; it passes when ``put_metric`` and ``flush``
     complete without raising.
     """
     session_project = 'cloud-custodian-sub'
     factory = self.replay_flight_data(
         'output-metrics', project_id=session_project)

     policy = Bag(name='custodian-works', resource_type='gcp.function')
     ctx = Bag(session_factory=factory, policy=policy)
     # The project id handed to the metrics output differs from the session's.
     conf = Bag(project_id='cloud-custodian')

     sink = StackDriverMetrics(ctx, conf)
     sink.put_metric('ResourceCount', 43, 'Count', Scope='Policy')
     sink.flush()
Example #10
0
def validate(options):
    """Validate every policy file listed in ``options.configs``.

    Each file is parsed (YAML or JSON, chosen by extension), checked against
    the generated schema, checked for policy names already used by an earlier
    file, and, if still clean, each policy is instantiated and validated.

    Errors are tracked per file, so an invalid file does not cause later,
    valid files to be reported as invalid or to skip policy validation, and
    each error is logged once, under the file it belongs to.

    Args:
        options: Parsed command options; ``configs`` is read and ``dryrun``
            is set to ``True``.

    Raises:
        ValueError: If a config path does not exist or has an unsupported
            extension.
        SystemExit: With status 1 when no config files were given or when at
            least one file is invalid.
    """
    load_resources()
    if len(options.configs) < 1:
        log.error('no config files specified')
        sys.exit(1)

    used_policy_names = set()
    schm = schema.generate()
    any_invalid = False

    for config_file in options.configs:
        config_file = os.path.expanduser(config_file)
        if not os.path.exists(config_file):
            raise ValueError("Invalid path for config %r" % config_file)

        options.dryrun = True
        fmt = config_file.rsplit('.', 1)[-1]
        with open(config_file) as fh:
            if fmt in ('yml', 'yaml'):
                data = yaml.safe_load(fh.read())
            elif fmt in ('json', ):
                data = json.load(fh)
            else:
                log.error("The config file must end in .json, .yml or .yaml.")
                raise ValueError(
                    "The config file must end in .json, .yml or .yaml.")

        # Start from a fresh list for every file; sharing one list across
        # files made every file after the first failure look invalid.
        errors = list(schema.validate(data, schm))

        conf_policy_names = {
            p.get('name', 'unknown')
            for p in data.get('policies', ())
        }
        dupes = conf_policy_names.intersection(used_policy_names)
        if dupes:
            # Sorted so the message does not depend on set iteration order.
            errors.append(
                ValueError(
                    "Only one policy with a given name allowed, duplicates: %s"
                    % (", ".join(sorted(dupes)))))
        used_policy_names = used_policy_names.union(conf_policy_names)

        if not errors:
            null_config = Config.empty(dryrun=True,
                                       account_id='na',
                                       region='na')
            for p in data.get('policies', ()):
                try:
                    policy = Policy(p, null_config, Bag())
                    policy.validate()
                except Exception as e:
                    msg = "Policy: %s is invalid: %s" % (p.get(
                        'name', 'unknown'), e)
                    errors.append(msg)

        if not errors:
            log.info("Configuration valid: {}".format(config_file))
            continue

        any_invalid = True
        log.error("Configuration invalid: {}".format(config_file))
        for e in errors:
            log.error("%s" % e)

    if any_invalid:
        sys.exit(1)
Example #11
0
 def test_azure_function_resolved_uai_identity(self):
     """A user-assigned identity given by name is resolved to id/client id.

     The policy's provision options reference the identity ``mike``; the
     mocked MSI client lists a matching identity, and ``_get_identity`` is
     expected to return its full id and client id.
     """
     policy_data = {
         'name': 'sm',
         'resource': 'azure.vm',
         'mode': {
             'type': FUNCTION_EVENT_TRIGGER_MODE,
             'events': ['VmWrite'],
             'provision-options': {
                 'identity': {'type': 'UserAssigned', 'id': 'mike'},
             },
         },
     }
     session = mock.MagicMock()
     exec_mode = self.load_policy(policy_data).get_execution_mode()

     listed_identity = Bag({
         'name': 'mike',
         'id': '/subscriptions/xyz/userAssignedIdentities/foo',
         'client_id': 'bob',
     })
     msi_client = session.client(
         'azure.mgmt.msi.ManagedServiceIdentityClient')
     msi_client.user_assigned_identities.list_by_subscription.return_value = [
         listed_identity
     ]

     expected = {
         'type': 'UserAssigned',
         'client_id': 'bob',
         'id': '/subscriptions/xyz/userAssignedIdentities/foo',
     }
     self.assertEqual(exec_mode._get_identity(session), expected)
Example #12
0
 def get_context(self, config=None, policy=None):
     """Build an ``ExecutionContext`` for a test.

     Args:
         config: Configuration to use. When ``None``, a temp dir is created,
             remembered as ``self.context_output_dir`` and used as the
             ``output_dir`` of an otherwise empty config.
         policy: Policy object to use; a falsy value is replaced by a
             minimal ``Bag`` named ``test-policy``.

     Returns:
         The context, constructed with ``Session`` as its session factory.
     """
     if config is None:
         self.context_output_dir = self.get_temp_dir()
         config = Config.empty(output_dir=self.context_output_dir)
     if not policy:
         policy = Bag({'name': 'test-policy'})
     return ExecutionContext(Session, policy, config)
Example #13
0
 def __init__(self, *args, **kw):
     """Initialise the base context, then set this class's own state.

     All positional and keyword arguments are forwarded unchanged to the
     parent initialiser.
     """
     super(XrayContext, self).__init__(*args, **kw)
     # Policy execution can span threads, so the context storage needs
     # process-global semantics; a plain Bag is assigned after the parent
     # initialiser has run.
     self._local = Bag()
     self._current_subsegment = None
     self.sampler = LocalSampler()
Example #14
0
 def get_context(self, config=None, session_factory=None, policy=None):
     """Build an ``ExecutionContext`` for a test.

     Args:
         config: Configuration to use. When ``None``, a temp dir is created,
             remembered as ``self.context_output_dir`` and used as the
             ``output_dir`` of an otherwise empty config.
         session_factory: Passed through as the context's session factory.
         policy: Policy object to use; a falsy value is replaced by a
             minimal ``Bag`` named ``test-policy``.

     Returns:
         The constructed execution context.
     """
     if config is None:
         self.context_output_dir = self.get_temp_dir()
         config = Config.empty(output_dir=self.context_output_dir)
     if not policy:
         policy = Bag({"name": "test-policy"})
     return ExecutionContext(session_factory, policy, config)
    def test_verify_parent_filter(self):
        """A ``parent`` filter on keyvault-keys wires up the expected objects.

        The resource manager is built from policy data holding one ``parent``
        filter that wraps a ``value`` filter. The single resulting filter
        must be a ``ParentFilter`` whose parent manager is a ``KeyVault`` and
        whose inner filter is a ``ValueFilter``.
        """
        ctx = ExecutionContext(
            None, Bag(name="xyz", provider_name='azure'), Config.empty())
        policy_data = {
            'name': 'test-policy',
            'resource': 'azure.keyvault-keys',
            'filters': [{
                'type': 'parent',
                'filter': {
                    'type': 'value',
                    'key': 'name',
                    'op': 'glob',
                    'value': 'cctestkv*'
                }
            }]
        }
        manager = KeyVaultKeys(ctx, policy_data)

        self.assertEqual(len(manager.filters), 1)

        # Named parent_filter rather than ``filter`` to avoid shadowing the
        # builtin; assertIsInstance reports the actual type on failure.
        parent_filter = manager.filters[0]
        self.assertIsInstance(parent_filter, ParentFilter)
        self.assertIsInstance(parent_filter.parent_manager, KeyVault)
        self.assertIsInstance(parent_filter.parent_filter, ValueFilter)
Example #16
0
 def test_stream_override(self):
     """A ``stream`` entry in the output config names the log stream.

     ``aws.CloudWatchLogOutput`` is built directly from a config carrying
     ``stream='testing'`` and the handler it returns must report that value
     as ``log_stream``.
     """
     session_factory = self.replay_flight_data(
         'test_log_stream_override')
     conf = Bag({
         'region': 'us-east-2',
         'scheme': 'aws',
         'netloc': 'master',
         'path': 'custodian',
         'stream': "testing"
     })
     ctx = Bag(session_factory=session_factory,
               options=Bag(account_id='001100', region='us-east-1'),
               policy=Bag(name='test', resource_type='ec2'))
     log_output = aws.CloudWatchLogOutput(ctx, conf)
     stream = log_output.get_handler()
     # assertEqual shows both values on failure; assertTrue(a == b) does not.
     self.assertEqual(stream.log_stream, 'testing')
Example #17
0
def run(organization, hook_context, github_url, github_token,
        verbose, metrics=False, since=None, assume=None, region=None):
    """scan org repo status hooks"""
    # NOTE(review): the docstring above is kept byte-identical because it may
    # be surfaced as command help text -- confirm before expanding it.
    #
    # Flow: POST the module-level GraphQL ``query`` for ``organization`` to
    # ``github_url``, walk the returned repositories' pull-request commits
    # through ``process_commit``, then either print the buffered metrics
    # (``metrics`` falsy) or flush them through ``RepoMetrics``.
    #
    # ``verbose`` is accepted but not read here; logging is always set to
    # DEBUG. ``since`` is passed straight to ``dateparser.parse``.
    logging.basicConfig(level=logging.DEBUG)

    since = dateparser.parse(
        since, settings={
            'RETURN_AS_TIMEZONE_AWARE': True, 'TO_TIMEZONE': 'UTC'})

    headers = {"Authorization": "token {}".format(github_token)}

    response = requests.post(
        github_url, headers=headers,
        json={'query': query, 'variables': {'organization': organization}})

    failure = "Query failed to run by returning code of {}. {}".format(
        response.status_code, response.content)

    # Check the status before decoding: an error response is not guaranteed
    # to carry a JSON body, and a decode error would hide the real failure.
    if response.status_code != 200:
        raise Exception(failure)

    result = response.json()
    if 'errors' in result:
        raise Exception(failure)

    now = datetime.now(tzutc())
    stats = Counter()
    repo_metrics = RepoMetrics(
        Bag(session_factory=SessionFactory(region, assume_role=assume)),
        {'namespace': DEFAULT_NAMESPACE}
    )

    for r in result['data']['organization']['repositories']['nodes']:
        commits = jmespath.search(
            'pullRequests.edges[].node[].commits[].nodes[].commit[]', r)
        if not commits:
            continue
        log.debug("processing repo: %s prs: %d", r['name'], len(commits))
        # Dimensions applied to the metrics recorded for this repository.
        repo_metrics.dims = {
            'Hook': hook_context,
            'Repo': '{}/{}'.format(organization, r['name'])}

        # Each commit represents a separate pr
        for c in commits:
            process_commit(c, r, repo_metrics, stats, since, now)

    # The summary metrics below carry only the Hook dimension.
    repo_metrics.dims = None

    if stats['missing']:
        repo_metrics.put_metric(
            'RepoHookPending', stats['missing'], 'Count',
            Hook=hook_context)
        repo_metrics.put_metric(
            'RepoHookLatency', stats['missing_time'], 'Seconds',
            Hook=hook_context)

    if not metrics:
        # Dry mode: show what would have been published.
        print(dumps(repo_metrics.buf, indent=2))
        return

    repo_metrics.BUF_SIZE = 20
    repo_metrics.flush()
Example #18
0
    def test_tracer(self):
        """An exception raised inside a subsegment is recorded as its cause.

        A ``ValueError`` is raised within ``tracer.subsegment`` and caught
        outside it; afterwards the subsegment's ``cause`` must not be empty.
        """
        factory = self.replay_flight_data('output-xray-trace')
        ctx = Bag(
            policy=Bag(name='test', resource_type='ec2'),
            session_factory=factory,
            options=Bag(account_id='644160558196'))
        # The tracer context needs a get_metadata callable; return nothing.
        ctx.get_metadata = lambda *args: {}
        tracer = aws.XrayTracer(ctx, Bag())

        with tracer:
            try:
                with tracer.subsegment('testing') as segment:
                    raise ValueError()
            except ValueError:
                pass
            self.assertNotEqual(segment.cause, {})
Example #19
0
def get_session(role, session_name, profile):
    """Create a session and attach API statistics collection to it.

    The region comes from ``AWS_DEFAULT_REGION`` and falls back to
    ``eu-west-1``.

    Args:
        role: Role to assume; when falsy a profile-based session is used.
        session_name: Session name passed to ``assumed_session``.
        profile: Profile handed to ``SessionFactory`` when no role is given.

    Returns:
        A ``(stats, session)`` tuple, where ``stats`` is the ``ApiStats``
        instance that was called with the session.
    """
    region = os.environ.get('AWS_DEFAULT_REGION', 'eu-west-1')
    stats = ApiStats(Bag(), Config.empty())
    session = (
        assumed_session(role, session_name, region=region)
        if role else SessionFactory(region, profile)()
    )
    stats(session)
    return stats, session
Example #20
0
def get_log_output(request, output_url):
    """Build a ``StackDriverLogging`` output for a test.

    Args:
        request: Fixture request object; ``reset_session_cache`` is
            registered on it as a finalizer.
        output_url: Output URL, parsed with ``parse_url_config`` to form
            the output's configuration.

    Returns:
        The logging output, backed by an execution context whose session
        factory hands out ``MagicMock`` objects.
    """
    policy = Bag(name="xyz",
                 provider_name="gcp",
                 resource_type='gcp.function')
    ctx = ExecutionContext(
        lambda assume=False: mock.MagicMock(),
        policy,
        Config.empty(account_id='custodian-test'))
    log_output = StackDriverLogging(ctx, parse_url_config(output_url))
    request.addfinalizer(reset_session_cache)
    return log_output
Example #21
0
    def test_iam_permissions_validity(self):
        """Check declared permissions of every resource, action and filter.

        For each registered resource the manager's permissions, then those
        of each action and each filter (except ``or``/``and``/``not``/
        ``missing``), are run through ``self.check_permissions`` against the
        ``iam-actions.json`` data. Resources whose service and permission
        prefix are both absent from that data are reported as missing.
        Findings starting with ``elasticloadbalancing`` are ignored.

        Raises:
            ValueError: If any resource's service is missing, or if any
                invalid permissions remain.
        """
        cfg = Config.empty()
        perms = load_data('iam-actions.json')
        missing = set()
        all_invalid = []

        for resource_name, resource_class in manager.resources.items():
            policy = Bag({
                'name': 'permcheck',
                'resource': resource_name,
                'provider_name': 'aws'})
            ctx = self.get_context(config=cfg, policy=policy)
            mgr = resource_class(ctx, policy)

            prefixes = (
                mgr.resource_type.service,
                getattr(mgr.resource_type, 'permission_prefix', None))
            if not any(prefix in perms for prefix in prefixes):
                missing.add(
                    "%s->%s" % (resource_name, mgr.resource_type.service))
                continue

            invalid = list(self.check_permissions(
                perms, mgr.get_permissions(), resource_name))

            for action_name, action_class in resource_class.action_registry.items():
                policy['actions'] = [action_name]
                invalid.extend(
                    self.check_permissions(
                        perms, action_class({}, mgr).get_permissions(),
                        "{k}.actions.{n}".format(
                            k=resource_name, n=action_name)))

            for filter_name, filter_class in resource_class.filter_registry.items():
                if filter_name in ('or', 'and', 'not', 'missing'):
                    continue
                policy['filters'] = [filter_name]
                invalid.extend(
                    self.check_permissions(
                        perms, filter_class({}, mgr).get_permissions(),
                        "{k}.filters.{n}".format(
                            k=resource_name, n=filter_name)))

            for label, perm_set in invalid:
                remaining = [
                    perm for perm in perm_set
                    if not perm.startswith('elasticloadbalancing')]
                if remaining:
                    all_invalid.append((label, remaining))

        if missing:
            raise ValueError(
                "resources missing service %s" % ('\n'.join(sorted(missing))))

        if all_invalid:
            raise ValueError(
                "invalid permissions \n %s" % ('\n'.join(sorted(map(str, all_invalid)))))
Example #22
0
    def get_azure_output(self):
        """Create an ``AzureStorageOutput`` aimed at a fixed blob URL.

        The output's ``root_dir`` is registered for removal via
        ``shutil.rmtree`` when the test finishes.

        Returns:
            The constructed output.
        """
        config = Config.empty(
            output_dir="azure://mystorage.blob.core.windows.net/logs")
        ctx = ExecutionContext(None, Bag(name="xyz"), config)
        output = AzureStorageOutput(ctx)
        self.addCleanup(shutil.rmtree, output.root_dir)
        return output
Example #23
0
    def test_check_permissions(self):
        """Every gcp resource, action and filter declares valid permissions.

        Permissions are read from each resource manager and from each of its
        actions and filters (names in ``ALLOWED_NOPERM`` are skipped). An
        empty permission set is recorded as missing; a permission absent
        from ``data/iam-permissions.json`` is recorded as invalid. The test
        fails if either list is non-empty.
        """
        load_resources(('gcp.*', ))
        missing = []
        invalid = []

        iam_path = os.path.join(
            os.path.dirname(__file__), 'data', 'iam-permissions.json')
        with open(iam_path) as fh:
            valid_perms = set(json.load(fh).get('permissions'))
        cfg = Config.empty()

        def audit(label, declared):
            # Record an empty declaration and any unknown permissions.
            if not declared:
                missing.append(label)
            for perm in declared:
                if perm not in valid_perms:
                    invalid.append((label, perm))

        for k, v in resources.items():
            policy = Bag({
                'name': 'permcheck',
                'resource': 'gcp.%s' % k,
                'provider_name': 'gcp'
            })
            mgr = v(self.get_context(config=cfg, policy=policy), policy)
            audit(k, mgr.get_permissions())

            for n, action_class in list(v.action_registry.items()):
                if n in ALLOWED_NOPERM:
                    continue
                policy['actions'] = [n]
                audit('%s.actions.%s' % (k, n),
                      action_class({}, mgr).get_permissions())

            for n, filter_class in list(v.filter_registry.items()):
                if n in ALLOWED_NOPERM:
                    continue
                policy['filters'] = [n]
                audit('%s.filters.%s' % (k, n),
                      filter_class({}, mgr).get_permissions())

        if missing:
            self.fail('missing permissions %d on \n\t%s' %
                      (len(missing), '\n\t'.join(sorted(missing))))

        if invalid:
            self.fail('invalid permissions %d on \n\t%s' %
                      (len(invalid), '\n\t'.join(map(str, sorted(invalid)))))
Example #24
0
    def get_azure_output(self, custom_pyformat=None):
        """Create an ``AzureStorageOutput`` for a fixed blob URL.

        Args:
            custom_pyformat: Optional suffix; when truthy it is appended to
                the base URL with ``AzureStorageOutput.join``.

        Returns:
            The output. Its ``root_dir`` is registered for removal via
            ``shutil.rmtree`` when the test finishes.
        """
        base_url = "azure://mystorage.blob.core.windows.net/logs"
        output_dir = (
            AzureStorageOutput.join(base_url, custom_pyformat)
            if custom_pyformat else base_url
        )

        ctx = ExecutionContext(
            None,
            Bag(name="xyz", provider_name='azure'),
            Config.empty(output_dir=output_dir))
        output = AzureStorageOutput(ctx, {'url': output_dir})
        self.addCleanup(shutil.rmtree, output.root_dir)

        return output
Example #25
0
def watch(limit):
    """watch scan rates across the cluster"""
    # NOTE(review): the docstring above is kept byte-identical because it may
    # be surfaced as command help text -- confirm before expanding it.
    #
    # Loops forever: every ``interval`` seconds a fresh snapshot is read,
    # each bucket's scan rate since the previous snapshot is stored under
    # ``gkrate``, and a table headed by a totals row is printed. ``limit``,
    # when truthy, caps the number of bucket rows shown.
    interval = 5.0
    previous = db.db()
    previous_totals = None

    while True:
        click.clear()
        time.sleep(interval)

        current = db.db()
        current.data['gkrate'] = {}
        seen_before = {
            bucket.bucket_id: bucket for bucket in previous.buckets()}
        totals = {'scanned': 0, 'krate': 0, 'lrate': 0, 'bucket_id': 'totals'}
        active = []

        for bucket in current.buckets():
            if not bucket.scanned:
                continue

            totals['scanned'] += bucket.scanned
            totals['krate'] += bucket.krate
            totals['lrate'] += bucket.lrate

            if bucket.bucket_id in seen_before:
                baseline = seen_before[bucket.bucket_id].scanned
                # No change since the last snapshot: counted in the totals
                # above but left out of the per-bucket rows.
                if bucket.scanned == baseline:
                    continue
                delta = bucket.scanned - baseline
            else:
                # First sighting: everything scanned so far is the delta.
                delta = bucket.scanned
            bucket.data['gkrate'][bucket.bucket_id] = delta / interval
            active.append(bucket)

        if previous_totals is None:
            # No earlier totals to diff against on the first pass.
            totals['gkrate'] = '...'
        else:
            totals['gkrate'] = (
                totals['scanned'] - previous_totals['scanned']) / interval
        previous, previous_totals = current, totals

        active.sort(key=lambda item: item.gkrate, reverse=True)
        if limit:
            active = active[:limit]

        rows = [Bag(totals)] + active
        format_plain(rows,
                     None,
                     explicit_only=True,
                     keys=['bucket_id', 'scanned', 'gkrate', 'lrate', 'krate'])
Example #26
0
def get_blob_output(request, output_url=None, cleanup=True):
    """Build a ``GCPStorageOutput`` for a test.

    Args:
        request: Fixture request object used to register finalizers.
        output_url: Output URL; ``gs://cloud-custodian/policies`` when
            ``None``. It is used both as the config's ``output_dir`` and,
            parsed by ``parse_url_config``, as the output configuration.
        cleanup: When truthy, removal of the output's ``root_dir`` is
            registered as a finalizer.

    Returns:
        The storage output. ``reset_session_cache`` is always registered as
        a finalizer.
    """
    if output_url is None:
        output_url = "gs://cloud-custodian/policies"

    ctx = ExecutionContext(
        lambda assume=False: mock.MagicMock(),
        Bag(name="xyz", provider_name="gcp"),
        Config.empty(output_dir=output_url, account_id='custodian-test'))
    output = GCPStorageOutput(ctx, parse_url_config(output_url))

    if cleanup:
        def _remove_root():
            # root_dir is looked up when the finalizer runs, not now.
            shutil.rmtree(output.root_dir)

        request.addfinalizer(_remove_root)
    request.addfinalizer(reset_session_cache)
    return output
Example #27
0
 def test_app_insights_metrics(self, put_mock):
     """Flushing the azure metrics sink forwards the metric with dimensions.

     Args:
         put_mock: Mock standing in for the call that publishes metrics;
             it must be invoked exactly once with the policy name and a
             single ``ResourceCount`` entry carrying the expected
             dimensions.
     """
     policy = self.load_policy({
         'name': 'test-rg',
         'resource': 'azure.resourcegroup'
     })
     ctx = Bag(policy=policy, execution_id='00000000-0000-0000-0000-000000000000')
     sink = metrics_outputs.select('azure://00000000-0000-0000-0000-000000000000', ctx)
     # assertIsInstance reports the offending type when it fails.
     self.assertIsInstance(sink, MetricsOutput)
     sink.put_metric('ResourceCount', 101, 'Count')
     sink.flush()
     put_mock.assert_called_once_with(
         'test-rg',
         [{
             'Name': 'ResourceCount',
             'Value': 101,
             'Dimensions':
                 {'Policy': 'test-rg',
                  'ResType': 'azure.resourcegroup',
                  'SubscriptionId': local_session(Session).get_subscription_id(),
                  'ExecutionId': '00000000-0000-0000-0000-000000000000',
                  'ExecutionMode': 'pull',
                  'Unit': 'Count'}}])
Example #28
0
    def test_tracer(self):
        """A failing subsegment keeps a non-empty ``cause``.

        Inside an active tracer, a ``ValueError`` is raised within a
        subsegment named ``testing`` and swallowed by the test; the
        subsegment's ``cause`` is then required to differ from ``{}``.
        """
        def no_metadata(*args):
            # Stand-in for the context's metadata lookup.
            return {}

        ctx = Bag(policy=Bag(name='test', resource_type='ec2'),
                  session_factory=self.replay_flight_data('output-xray-trace'),
                  options=Bag(account_id='644160558196'))
        ctx.get_metadata = no_metadata
        tracer = aws.XrayTracer(ctx, Bag())

        with tracer:
            try:
                with tracer.subsegment('testing') as failed_segment:
                    raise ValueError()
            except ValueError:
                pass
            self.assertNotEqual(failed_segment.cause, {})
Example #29
0
class BaseTest(TestUtils, AzureVCRBaseTest):
    """ Azure base testing class.

    Installs the patches the Azure tests rely on: a fixed "current" date
    while VCR is enabled, a long-running-operation poller that does not
    poll (unless a test sets ``_requires_polling``), and fixed tenant /
    subscription ids during playback.
    """

    # Shared execution context for tests that need one.
    test_context = ExecutionContext(
        Session,
        Bag(name="xyz", provider_name='azure'),
        Config.empty()
    )

    def __init__(self, *args, **kwargs):
        super(BaseTest, self).__init__(*args, **kwargs)
        # Tests that need real polling on long running operations set this.
        self._requires_polling = False

    @classmethod
    def setUpClass(cls, *args, **kwargs):
        """Patch JWT decoding when the access token env var is the fake one."""
        super(BaseTest, cls).setUpClass(*args, **kwargs)
        if os.environ.get(constants.ENV_ACCESS_TOKEN) == "fake_token":
            cls._token_patch = patch(
                'c7n_azure.session.jwt.decode',
                return_value={'tid': DEFAULT_TENANT_ID})
            cls._token_patch.start()

    @classmethod
    def tearDownClass(cls, *args, **kwargs):
        """Stop the JWT patch under the same condition it was started."""
        super(BaseTest, cls).tearDownClass(*args, **kwargs)
        if os.environ.get(constants.ENV_ACCESS_TOKEN) == "fake_token":
            cls._token_patch.stop()

    def setUp(self):
        """Install the per-test patches and create ``self.session``.

        Every patch started here is registered with ``addCleanup`` so it is
        stopped when the test ends.
        """
        super(BaseTest, self).setUp()
        ThreadHelper.disable_multi_threading = True

        # We always patch the date for recordings so URLs that involve dates match up
        if self.vcr_enabled:
            self._utc_patch = patch.object(utils, 'utcnow', self._get_test_date)
            self._utc_patch.start()
            self.addCleanup(self._utc_patch.stop)

            self._now_patch = patch.object(utils, 'now', self._get_test_date)
            self._now_patch.start()
            self.addCleanup(self._now_patch.stop)

        if not self._requires_polling:
            # Patch Poller with constructor that always disables polling
            # This breaks blocking on long running operations (resource creation).
            self._lro_patch = patch.object(msrest.polling.LROPoller, '__init__', BaseTest.lro_init)
            self._lro_patch.start()
            self.addCleanup(self._lro_patch.stop)

        if self.is_playback():
            if self._requires_polling:
                # If using polling we need to monkey patch the timeout during playback
                # or we'll have long sleeps introduced into our test runs
                Session._old_client = Session.client
                Session.client = BaseTest.session_client_wrapper
                self.addCleanup(BaseTest.session_client_cleanup)

            if constants.ENV_ACCESS_TOKEN in os.environ:
                self._tenant_patch = patch('c7n_azure.session.Session.get_tenant_id',
                                           return_value=DEFAULT_TENANT_ID)
                self._tenant_patch.start()
                self.addCleanup(self._tenant_patch.stop)

            self._subscription_patch = patch('c7n_azure.session.Session.get_subscription_id',
                                             return_value=DEFAULT_SUBSCRIPTION_ID)
            self._subscription_patch.start()
            self.addCleanup(self._subscription_patch.stop)

        self.session = local_session(Session)

    def _get_test_date(self, tz=None):
        """Return the date to use as "now" for the current test.

        When the cassette has responses and the first one carries a ``date``
        header, that date is returned with the time set to 23:59:59 (``tz``
        is not applied in this case). Otherwise the real current time in
        ``tz`` is returned.
        """
        header_date = self.cassette.responses[0]['headers'].get('date') \
            if self.cassette.responses else None

        if header_date:
            # The header value is indexed with [0], i.e. treated as a list.
            test_date = datetime.datetime(*eut.parsedate(header_date[0])[:6])
        else:
            return datetime.datetime.now(tz=tz)
        return test_date.replace(hour=23, minute=59, second=59, microsecond=0)

    def sleep_in_live_mode(self, interval=60):
        """Sleep for ``interval`` seconds unless the test is in playback."""
        if not self.is_playback():
            sleep(interval)

    @staticmethod
    def setup_account():
        """Return the first storage account whose name starts with ``cctstorage``.

        Raises:
            IndexError: If no such account exists.
        """
        # Find actual name of storage account provisioned in our test environment
        s = Session()
        client = s.client('azure.mgmt.storage.StorageManagementClient')
        accounts = list(client.storage_accounts.list())
        matching_account = [a for a in accounts if a.name.startswith("cctstorage")]
        return matching_account[0]

    @staticmethod
    def sign_out_patch():
        """Return a patch that replaces ``os.environ`` with blank credentials.

        ``clear=True`` removes every other variable while the patch is
        active.
        """
        return patch.dict(os.environ,
                          {
                              constants.ENV_TENANT_ID: '',
                              constants.ENV_SUB_ID: '',
                              constants.ENV_CLIENT_ID: '',
                              constants.ENV_CLIENT_SECRET: ''
                          }, clear=True)

    @staticmethod
    def lro_init(self, client, initial_response, deserialization_callback, polling_method):
        """Replacement for ``LROPoller.__init__`` that never polls.

        Declared as a staticmethod yet takes ``self`` because it is patched
        in as the poller's ``__init__``. The ``polling_method`` argument is
        ignored and ``msrest.polling.NoPolling()`` is used instead.
        """
        self._client = client if isinstance(client, ServiceClient) else client._client
        self._response = initial_response.response if \
            isinstance(initial_response, ClientRawResponse) else \
            initial_response
        self._callbacks = []  # type List[Callable]
        self._polling_method = msrest.polling.NoPolling()

        if isinstance(deserialization_callback, type) and \
                issubclass(deserialization_callback, Model):
            deserialization_callback = deserialization_callback.deserialize

        # Might raise a CloudError
        self._polling_method.initialize(self._client, self._response, deserialization_callback)

        self._thread = None
        self._done = None
        self._exception = None

    @staticmethod
    def session_client_cleanup():
        """Restore the ``Session.client`` saved in ``Session._old_client``."""
        Session.client = Session._old_client

    @staticmethod
    def session_client_wrapper(self, client):
        """Create the client as usual, then zero its long-running-op timeout.

        Installed as ``Session.client``, hence the explicit ``self``.
        """
        client = Session._old_client(self, client)
        client.config.long_running_operation_timeout = 0
        return client
Example #30
0
class UtilsTest(BaseTest):
    """Unit tests for the Azure utility helpers.

    Covers resource id parsing, math and string helpers, tag lookup, port
    range handling, the request ``send`` override, management group and
    key vault helpers, service tag lookup and resource group detection.
    """

    def setUp(self):
        super(UtilsTest, self).setUp()

    def test_get_subscription_id(self):
        self.assertEqual(ResourceIdParser.get_subscription_id(RESOURCE_ID), DEFAULT_SUBSCRIPTION_ID)

    def test_get_namespace(self):
        self.assertEqual(ResourceIdParser.get_namespace(RESOURCE_ID), "Microsoft.Compute")
        self.assertEqual(ResourceIdParser.get_namespace(RESOURCE_ID_CHILD), "Microsoft.Sql")

    def test_get_resource_group(self):
        self.assertEqual(ResourceIdParser.get_resource_group(RESOURCE_ID), "rgtest")

    def test_get_resource_type(self):
        # Child resources are expected to report the full nested type path.
        self.assertEqual(ResourceIdParser.get_resource_type(RESOURCE_ID), "virtualMachines")
        self.assertEqual(ResourceIdParser.get_resource_type(RESOURCE_ID_CHILD), "servers/databases")

    def test_get_full_type(self):
        self.assertEqual(ResourceIdParser.get_full_type(RESOURCE_ID),
                         "Microsoft.Compute/virtualMachines")

    def test_resource_name(self):
        self.assertEqual(ResourceIdParser.get_resource_name(RESOURCE_ID), "nametest")

    def test_math_mean(self):
        # None entries are expected to be ignored; an all-None input yields 0.
        self.assertEqual(Math.mean([4, 5, None, 3]), 4)
        self.assertEqual(Math.mean([None]), 0)
        self.assertEqual(Math.mean([3, 4]), 3.5)

    def test_math_sum(self):
        # None entries are expected to be ignored; an all-None input yields 0.
        self.assertEqual(Math.sum([4, 5, None, 3]), 12)
        self.assertEqual(Math.sum([None]), 0)
        self.assertEqual(Math.sum([3.5, 4]), 7.5)

    def test_string_utils_equal(self):
        # Case insensitive matches
        self.assertTrue(StringUtils.equal("FOO", "foo"))
        self.assertTrue(StringUtils.equal("fOo", "FoO"))
        self.assertTrue(StringUtils.equal("ABCDEFGH", "abcdefgh"))
        self.assertFalse(StringUtils.equal("Foo", "Bar"))

        # Case sensitive matches
        self.assertFalse(StringUtils.equal("Foo", "foo", False))
        self.assertTrue(StringUtils.equal("foo", "foo", False))
        self.assertTrue(StringUtils.equal("fOo", "fOo", False))
        # Pass the flag explicitly so this exercises the case-sensitive path
        # rather than repeating the case-insensitive check above.
        self.assertFalse(StringUtils.equal("Foo", "Bar", False))

        # Strip whitespace matches
        self.assertTrue(StringUtils.equal(" Foo ", "foo"))
        self.assertTrue(StringUtils.equal("Foo", " foo "))
        self.assertTrue(StringUtils.equal(" Foo ", "Foo", False))
        self.assertTrue(StringUtils.equal("Foo", " Foo ", False))

        # Returns false for non string types
        self.assertFalse(StringUtils.equal(1, "foo"))
        self.assertFalse(StringUtils.equal("foo", 1))
        self.assertFalse(StringUtils.equal(True, False))

    def test_get_tag_value(self):
        # Lower-case lookups must find tags stored with mixed-case keys, and
        # the stored value must come back with its original casing.
        resource = {'tags': {'tag1': 'value1', 'tAg2': 'VaLuE2', 'TAG3': 'VALUE3'}}

        self.assertEqual(TagHelper.get_tag_value(resource, 'tag1', True), 'value1')
        self.assertEqual(TagHelper.get_tag_value(resource, 'tag2', True), 'VaLuE2')
        self.assertEqual(TagHelper.get_tag_value(resource, 'tag3', True), 'VALUE3')

    def test_get_ports(self):
        # Overlapping entries ("5" and "4-5") collapse into a single set.
        self.assertEqual(PortsRangeHelper.get_ports_set_from_string("5, 4-5, 9"), {4, 5, 9})
        rule = {'properties': {'destinationPortRange': '10-12'}}
        self.assertEqual(PortsRangeHelper.get_ports_set_from_rule(rule), {10, 11, 12})
        rule = {'properties': {'destinationPortRanges': ['80', '10-12']}}
        self.assertEqual(PortsRangeHelper.get_ports_set_from_rule(rule), {10, 11, 12, 80})

    def test_validate_ports_string(self):
        # Well-formed single ports, ranges and comma-separated mixes.
        self.assertEqual(PortsRangeHelper.validate_ports_string('80'), True)
        self.assertEqual(PortsRangeHelper.validate_ports_string('22-26'), True)
        self.assertEqual(PortsRangeHelper.validate_ports_string('80,22'), True)
        self.assertEqual(PortsRangeHelper.validate_ports_string('80,22-26'), True)
        self.assertEqual(PortsRangeHelper.validate_ports_string('80,22-26,30-34'), True)
        # Out-of-range values, negative values and reversed ranges.
        self.assertEqual(PortsRangeHelper.validate_ports_string('65537'), False)
        self.assertEqual(PortsRangeHelper.validate_ports_string('-1'), False)
        self.assertEqual(PortsRangeHelper.validate_ports_string('10-8'), False)
        self.assertEqual(PortsRangeHelper.validate_ports_string('80,30,25-65538'), False)
        self.assertEqual(PortsRangeHelper.validate_ports_string('65536-65537'), False)

    def test_get_ports_strings_from_list(self):
        # Consecutive ports are expected to be collapsed into "start-end" strings.
        self.assertEqual(PortsRangeHelper.get_ports_strings_from_list([]),
                         [])
        self.assertEqual(PortsRangeHelper.get_ports_strings_from_list([10, 11]),
                         ['10-11'])
        self.assertEqual(PortsRangeHelper.get_ports_strings_from_list([10, 12, 13, 14]),
                         ['10', '12-14'])
        self.assertEqual(PortsRangeHelper.get_ports_strings_from_list([10, 12, 13, 14, 20, 21, 22]),
                         ['10', '12-14', '20-22'])

    def test_build_ports_dict(self):
        # Two outbound allow rules (TCP 80-84, UDP 85-89), two inbound deny
        # rules for the same ranges, and a final inbound allow rule for any
        # protocol on 80-89 with the highest priority number.
        securityRules = [
            {'properties': {'destinationPortRange': '80-84',
                            'priority': 100,
                            'direction': 'Outbound',
                            'access': 'Allow',
                            'protocol': 'TCP'}},
            {'properties': {'destinationPortRange': '85-89',
                            'priority': 110,
                            'direction': 'Outbound',
                            'access': 'Allow',
                            'protocol': 'UDP'}},
            {'properties': {'destinationPortRange': '80-84',
                            'priority': 120,
                            'direction': 'Inbound',
                            'access': 'Deny',
                            'protocol': 'TCP'}},
            {'properties': {'destinationPortRange': '85-89',
                            'priority': 130,
                            'direction': 'Inbound',
                            'access': 'Deny',
                            'protocol': 'UDP'}},
            {'properties': {'destinationPortRange': '80-89',
                            'priority': 140,
                            'direction': 'Inbound',
                            'access': 'Allow',
                            'protocol': '*'}}]
        nsg = {'properties': {'securityRules': securityRules}}

        # Each expected dict maps port -> allowed flag. The inbound
        # expectations imply that the deny rules (120/130) win over the
        # catch-all allow rule (140) for the ports they cover.
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Inbound', 'TCP'),
                         {k: k > 84 for k in range(80, 90)})
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Inbound', 'UDP'),
                         {k: k < 85 for k in range(80, 90)})
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Inbound', '*'),
                         {k: False for k in range(80, 90)})
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Outbound', 'TCP'),
                         {k: True for k in range(80, 85)})
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Outbound', 'UDP'),
                         {k: True for k in range(85, 90)})
        self.assertEqual(PortsRangeHelper.build_ports_dict(nsg, 'Outbound', '*'),
                         {k: True for k in range(80, 90)})

    def test_snake_to_camel(self):
        self.assertEqual(StringUtils.snake_to_camel(""), "")
        self.assertEqual(StringUtils.snake_to_camel("test"), "test")
        self.assertEqual(StringUtils.snake_to_camel("test_abc"), "testAbc")
        self.assertEqual(StringUtils.snake_to_camel("test_abc_def"), "testAbcDef")

    def test_naming_hash(self):
        # The default hash is 8 characters; the longer form extends the same prefix.
        source = 'Lorem ipsum dolor sit amet'
        source2 = 'amet sit dolor ipsum Lorem'
        self.assertEqual(StringUtils.naming_hash(source), '16aba539')
        self.assertEqual(StringUtils.naming_hash(source, 10), '16aba5393a')
        self.assertNotEqual(StringUtils.naming_hash(source), StringUtils.naming_hash(source2))

    @patch('azure.mgmt.applicationinsights.operations.ComponentsOperations.get',
           return_value=type(str('result_data'), (), {'instrumentation_key': GUID}))
    def test_app_insights_get_instrumentation_key(self, mock_handler_run):
        # A GUID in the url is returned directly; a resourceGroup/name url
        # triggers exactly one lookup through the patched ``get``.
        self.assertEqual(AppInsightsHelper.get_instrumentation_key('azure://' + GUID), GUID)
        self.assertEqual(AppInsightsHelper.get_instrumentation_key('azure://resourceGroup/name'),
                         GUID)
        mock_handler_run.assert_called_once_with('resourceGroup', 'name')

    @patch('c7n_azure.utils.send_logger.debug')
    def test_custodian_azure_send_override_200(self, logger):
        # Bind the override as ``send`` on a Mock so ``orig_send`` is a mock too.
        mock = Mock()
        mock.send = types.MethodType(custodian_azure_send_override, mock)

        response_dict = {
            'headers': {'x-ms-ratelimit-remaining-subscription-reads': '12000'},
            'status_code': 200
        }
        mock.orig_send.return_value = type(str('response'), (), response_dict)
        mock.send('')

        self.assertEqual(mock.orig_send.call_count, 1)
        self.assertEqual(logger.call_count, 2)

    @patch('c7n_azure.utils.send_logger.debug')
    @patch('c7n_azure.utils.send_logger.warning')
    def test_custodian_azure_send_override_429(self, logger_warning, logger_debug):
        # Stacked ``patch`` decorators inject mocks bottom-up, so the first
        # argument is the ``warning`` mock and the second the ``debug`` mock.
        mock = Mock()
        mock.send = types.MethodType(custodian_azure_send_override, mock)

        response_dict = {
            'headers': {'Retry-After': 0},
            'status_code': 429
        }
        mock.orig_send.return_value = type(str('response'), (), response_dict)
        mock.send('')

        # A throttled response with a zero Retry-After is sent three times.
        self.assertEqual(mock.orig_send.call_count, 3)
        self.assertEqual(logger_debug.call_count, 3)
        self.assertEqual(logger_warning.call_count, 3)

    @patch('c7n_azure.utils.send_logger.error')
    def test_custodian_azure_send_override_429_long_retry(self, logger):
        mock = Mock()
        mock.send = types.MethodType(custodian_azure_send_override, mock)

        response_dict = {
            'headers': {'Retry-After': 60},
            'status_code': 429
        }
        mock.orig_send.return_value = type(str('response'), (), response_dict)
        mock.send('')

        # With a 60 second Retry-After the request is sent once and one error is logged.
        self.assertEqual(mock.orig_send.call_count, 1)
        self.assertEqual(logger.call_count, 1)

    # Canned entity listing: two management groups followed by two
    # subscriptions. Only the subscription names are expected back.
    managed_group_return_value = ([
        Bag({'name': '/providers/Microsoft.Management/managementGroups/cc-test-1',
             'type': '/providers/Microsoft.Management/managementGroups'}),
        Bag({'name': '/providers/Microsoft.Management/managementGroups/cc-test-2',
             'type': '/providers/Microsoft.Management/managementGroups'}),
        Bag({'name': DEFAULT_SUBSCRIPTION_ID,
             'type': '/subscriptions'}),
        Bag({'name': GUID,
             'type': '/subscriptions'}),
    ])

    @patch('azure.mgmt.managementgroups.operations.EntitiesOperations.list',
           return_value=managed_group_return_value)
    def test_managed_group_helper(self, _1):
        sub_ids = ManagedGroupHelper.get_subscriptions_list('test-group', "")
        self.assertEqual(sub_ids, [DEFAULT_SUBSCRIPTION_ID, GUID])

    @patch('msrestazure.azure_active_directory.MSIAuthentication')
    def test_get_keyvault_secret(self, _1):
        mock = Mock()
        mock.value = '{"client_id": "client", "client_secret": "secret"}'
        with patch('azure.common.credentials.ServicePrincipalCredentials.__init__',
                   return_value=None), \
                patch('azure.keyvault.v7_0.KeyVaultClient.get_secret', return_value=mock):

            # NOTE(review): presumably reloaded so the module binds the
            # patched classes while the patches are active -- confirm.
            reload(sys.modules['c7n_azure.utils'])

            result = get_keyvault_secret(None, 'https://testkv.vault.net/secrets/testsecret/123412')
            self.assertEqual(result, mock.value)

    # Test relies on substitute data in Azure Common, not designed for live data
    @pytest.mark.skiplive
    def test_get_service_tag_ip_space(self):
        # Get with region
        result = get_service_tag_ip_space('ApiManagement', 'WestUS')
        self.assertEqual(3, len(result))
        self.assertEqual({"13.64.39.16/32",
                          "40.112.242.148/31",
                          "40.112.243.240/28"}, set(result))

        # Get without region
        result = get_service_tag_ip_space('ApiManagement')
        self.assertEqual(5, len(result))
        self.assertEqual({"13.69.64.76/31",
                          "13.69.66.144/28",
                          "23.101.67.140/32",
                          "51.145.179.78/32",
                          "137.117.160.56/32"}, set(result))

        # Invalid tag
        result = get_service_tag_ip_space('foo')
        self.assertEqual(0, len(result))

    def test_is_resource_group_id(self):
        # Accepted: trailing slash and different casing of the fixed segments.
        self.assertTrue(is_resource_group_id('/subscriptions/GUID/resourceGroups/rg'))
        self.assertTrue(is_resource_group_id('/subscriptions/GUID/resourceGroups/rg/'))
        self.assertTrue(is_resource_group_id('/Subscriptions/GUID/resourcegroups/rg'))

        # Rejected: missing segments, missing leading slash, or extra path parts.
        self.assertFalse(is_resource_group_id('/subscriptions/GUID/rg/'))
        self.assertFalse(is_resource_group_id('subscriptions/GUID/rg/'))
        self.assertFalse(is_resource_group_id('/GUID/rg/'))
        self.assertFalse(is_resource_group_id('/subscriptions/GUID/rg/providers/vm/vm'))
        self.assertFalse(is_resource_group_id('/subscriptions/GUID/rg/providers'))
        self.assertFalse(is_resource_group_id('/subscriptions/GUID/rg/p'))

    def test_is_resource_group(self):
        self.assertTrue(is_resource_group({'type': 'resourceGroups'}))
        self.assertFalse(is_resource_group({'type': 'virtualMachines'}))
Example #31
0
 def test_default_account_id_assume(self):
     # With no explicit account id, the value is expected to be taken from
     # the account segment of the assume_role ARN.
     role_arn = 'arn:aws:iam::644160558196:role/custodian-mu'
     config = Bag(assume_role=role_arn, account_id=None)
     aws._default_account_id(config)
     self.assertEqual(config.account_id, '644160558196')