Example #1
0
    def test_unauthenticated(self):
        """Expect ``main`` to return 42 when the provider's ``authenticate`` returns False."""
        self.mocked_provider.provider_code = "aws"
        # Force the authentication step to report failure
        self.mocked_provider.authenticate = MagicMock(return_value=False)

        parsed_args = ScoutSuiteArgumentParser().parse_args(["aws"])
        exit_code = main(parsed_args)

        assert exit_code == 42
Example #2
0
    def test_empty(self):
        """Expect a ``SystemExit`` when no arguments are supplied, leaving the exit code unset."""
        exit_code = None

        # sys.stderr is swapped for a mock while the parser runs
        # (presumably to keep the parser's error output out of the test log)
        with patch("sys.stderr", return_value=MagicMock()), self.assertRaises(SystemExit):
            parsed_args = ScoutSuiteArgumentParser().parse_args(None)
            exit_code = main(parsed_args)

        # The assignment above must never have been reached
        assert exit_code is None
Example #3
0
    def test_azure_provider(self):
        """Expect an Azure CLI run to return 0 and to build the report with Azure defaults."""
        self.mocked_provider.provider_code = "azure"

        parsed_args = ScoutSuiteArgumentParser().parse_args(["azure", "--cli"])
        exit_code = main(parsed_args)
        assert exit_code == 0

        # Positional arguments of the first call to the mocked report constructor
        positional = self.constructor[self.mocked_report].call_args_list[0][0]
        expected = (
            "azure",              # provider
            "azure",              # report_file_name
            "scoutsuite-report",  # report_dir
        )
        for position, value in enumerate(expected):
            assert positional[position] == value
Example #4
0
    def test_gcp_provider(self):
        """Expect a GCP service-account run to return 0 and to build the report with GCP defaults."""
        self.mocked_provider.provider_code = "gcp"

        cli_arguments = ["gcp", "--service-account", "fakecredentials"]
        parsed_args = ScoutSuiteArgumentParser().parse_args(cli_arguments)
        exit_code = main(parsed_args)
        assert exit_code == 0

        # Positional arguments of the first call to the mocked report constructor
        positional = self.constructor[self.mocked_report].call_args_list[0][0]
        expected = (
            "gcp",                # provider
            "gcp",                # report_file_name
            "scoutsuite-report",  # report_dir
        )
        for position, value in enumerate(expected):
            assert positional[position] == value
Example #5
0
    def test_keyboardinterrupted(self):
        """Expect ``main`` to return 130 when the provider's ``fetch`` raises ``KeyboardInterrupt``."""
        args = ["aws"]
        self.mocked_provider.provider_code = "aws"

        # Hand the exception class to the mock so it is raised when fetch() is
        # *called*. Calling a raising helper here instead would raise
        # KeyboardInterrupt immediately, during test setup, before main() runs.
        self.mocked_provider.fetch = MagicMock(side_effect=KeyboardInterrupt)

        args = ScoutSuiteArgumentParser().parse_args(args)
        code = main(args)

        keyboardinterrupted_code = 130
        assert (code == keyboardinterrupted_code)
def _report_file_name(provider_code, args):
    """
    Build the report file name for a provider from the parsed arguments

    :param provider_code: provider code of the cloud provider object ('aws', 'gcp', 'azure', ...)
    :param args: parsed arguments; profile, project_id, organization_id and folder_id are read
    :return: the report file name
    """
    if provider_code == 'aws':
        if args.profile:
            return 'aws-%s' % args.profile[0]
        return 'aws'
    if provider_code == 'gcp':
        # The first identifier that is set wins: project, then organization, then folder
        for scope_id in (args.project_id, args.organization_id, args.folder_id):
            if scope_id:
                return 'gcp-%s' % scope_id
        return 'gcp'
    if provider_code == 'azure':
        return 'azure'
    # Any other provider: use its code as-is rather than leaving the name undefined
    return provider_code


def main(passed_args=None):
    """
    Main method that runs a scan

    :param passed_args: optional arguments handed to the parser; when falsy, the parser is called without arguments
    :return: 0 on success, 42 if authentication fails, 130 if fetching is interrupted by the user
    """

    # FIXME check that all requirements are installed
    # # Check version of opinel
    # requirements_file_path = '%s/requirements.txt' % os.path.dirname(sys.modules['__main__'].__file__)
    # if not check_requirements(requirements_file_path):
    #     return 42

    # Parse arguments
    parser = ScoutSuiteArgumentParser()

    if passed_args:
        args = parser.parse_args(passed_args)
    else:
        args = parser.parse_args()

    # Configure the debug level
    configPrintException(args.debug)

    # The first profile is used throughout the run
    profile_name = args.profile[0]

    # Create a cloud provider object
    cloud_provider = get_provider(provider=args.provider,
                                  profile=profile_name,
                                  project_id=args.project_id,
                                  folder_id=args.folder_id,
                                  organization_id=args.organization_id,
                                  report_dir=args.report_dir,
                                  timestamp=args.timestamp,
                                  services=args.services,
                                  skipped_services=args.skipped_services,
                                  thread_config=args.thread_config)

    report_file_name = _report_file_name(cloud_provider.provider_code, args)

    # Create a new report
    report = Scout2Report(args.provider, report_file_name, args.report_dir, args.timestamp)

    # Complete run, including pulling data from provider
    if not args.fetch_local:
        # Authenticate to the cloud provider
        authenticated = cloud_provider.authenticate(profile=profile_name,
                                                    csv_credentials=args.csv_credentials,
                                                    mfa_serial=args.mfa_serial,
                                                    mfa_code=args.mfa_code,
                                                    key_file=args.key_file,
                                                    user_account=args.user_account,
                                                    service_account=args.service_account,
                                                    azure_cli=args.azure_cli,
                                                    azure_msi=args.azure_msi,
                                                    azure_service_principal=args.azure_service_principal,
                                                    azure_file_auth=args.azure_file_auth,
                                                    azure_user_credentials=args.azure_user_credentials)

        if not authenticated:
            return 42

        # Fetch data from provider APIs
        try:
            cloud_provider.fetch(regions=args.regions)
        except KeyboardInterrupt:
            printInfo('\nCancelled by user')
            return 130

        # Update means we reload the whole config and overwrite part of it
        if args.update:
            current_run_services = copy.deepcopy(cloud_provider.services)
            last_run_dict = report.jsrw.load_from_file(AWSCONFIG)
            cloud_provider.services = last_run_dict['services']
            for service in cloud_provider.service_list:
                cloud_provider.services[service] = current_run_services[service]

    # Partial run, using pre-pulled data
    else:
        # Reload to flatten everything into a python dictionary
        last_run_dict = report.jsrw.load_from_file(AWSCONFIG)
        for key in last_run_dict:
            setattr(cloud_provider, key, last_run_dict[key])

    # Pre processing
    cloud_provider.preprocessing(args.ip_ranges, args.ip_ranges_name_key)

    # Analyze config
    finding_rules = Ruleset(environment_name=profile_name,
                            cloud_provider=args.provider,
                            filename=args.ruleset,
                            ip_ranges=args.ip_ranges,
                            aws_account_id=cloud_provider.aws_account_id)
    processing_engine = ProcessingEngine(finding_rules)
    processing_engine.run(cloud_provider)

    # Create display filters
    filter_rules = Ruleset(cloud_provider=args.provider,
                           filename='filters.json',
                           rule_type='filters',
                           aws_account_id=cloud_provider.aws_account_id)
    processing_engine = ProcessingEngine(filter_rules)
    processing_engine.run(cloud_provider)

    # Handle exceptions
    try:
        exceptions = RuleExceptions(profile_name, args.exceptions[0])
        exceptions.process(cloud_provider)
        exceptions = exceptions.exceptions
    except Exception:
        printDebug('Warning, failed to load exceptions. The file may not exist or may have an invalid format.')
        exceptions = {}

    # Finalize
    cloud_provider.postprocessing(report.current_time, finding_rules)

    # TODO this is AWS-specific - move to postprocessing?
    # Get organization data if it exists
    try:
        profile = AWSProfiles.get(profile_name)[0]
        if 'source_profile' in profile.attributes:
            organization_info_file = os.path.expanduser('~/.aws/recipes/%s/organization.json' %
                                                        profile.attributes['source_profile'])
            if os.path.isfile(organization_info_file):
                with open(organization_info_file, 'rt') as f:
                    accounts = json.load(f)
                # Index the accounts by their 'Id' field
                org = {}
                for account in accounts:
                    account_id = account.pop('Id')
                    org[account_id] = account
                setattr(cloud_provider, 'organization', org)
    except Exception as e:
        # Best effort: the scan continues without organization data, but the reason is no longer hidden
        printDebug('Warning, failed to load organization data: %s' % e)

    # Save config and create HTML report
    html_report_path = report.save(cloud_provider, exceptions, args.force_write, args.debug)

    # Open the report by default
    if not args.no_browser:
        printInfo('Opening the HTML report...')
        url = 'file://%s' % os.path.abspath(html_report_path)
        webbrowser.open(url, new=2)

    return 0
Example #7
0
def main(args=None):
    """
    Main method that runs a scan

    :param args: parsed arguments object; when falsy, arguments are parsed by ScoutSuiteArgumentParser
    :return: 0 on success, 42 if authentication fails, 130 if fetching is interrupted by the user
    """
    if not args:
        parser = ScoutSuiteArgumentParser()
        args = parser.parse_args()

    # Work on the attribute dictionary so that a missing option yields None instead of a crash
    args = vars(args)

    # Configure the debug level
    configPrintException(args.get('debug'))

    provider_name = args.get('provider')
    profile_name = args.get('profile')

    # Create a cloud provider object
    cloud_provider = get_provider(
        provider=provider_name,
        profile=profile_name,
        project_id=args.get('project_id'),
        folder_id=args.get('folder_id'),
        organization_id=args.get('organization_id'),
        all_projects=args.get('all_projects'),
        report_dir=args.get('report_dir'),
        timestamp=args.get('timestamp'),
        services=args.get('services'),
        skipped_services=args.get('skipped_services'),
        thread_config=args.get('thread_config'),
    )

    report_file_name = generate_report_name(cloud_provider.provider_code, args)

    # TODO move this to after authentication, so that the report can be more specific to what's being scanned.
    # For example if scanning with a GCP service account, the SA email can only be known after authenticating...
    # Create a new report
    report = Scout2Report(provider_name, report_file_name,
                          args.get('report_dir'), args.get('timestamp'))

    # Complete run, including pulling data from provider
    if not args.get('fetch_local'):
        # Authenticate to the cloud provider
        authenticated = cloud_provider.authenticate(
            profile=profile_name,
            csv_credentials=args.get('csv_credentials'),
            mfa_serial=args.get('mfa_serial'),
            mfa_code=args.get('mfa_code'),
            user_account=args.get('user_account'),
            service_account=args.get('service_account'),
            cli=args.get('cli'),
            msi=args.get('msi'),
            service_principal=args.get('service_principal'),
            file_auth=args.get('file_auth'),
            tenant_id=args.get('tenant_id'),
            subscription_id=args.get('subscription_id'),
            client_id=args.get('client_id'),
            client_secret=args.get('client_secret'),
            username=args.get('username'),
            password=args.get('password'))

        if not authenticated:
            return 42

        # Fetch data from provider APIs
        try:
            cloud_provider.fetch(regions=args.get('regions'))
        except KeyboardInterrupt:
            printInfo('\nCancelled by user')
            return 130

        # Update means we reload the whole config and overwrite part of it
        if args.get('update'):
            current_run_services = copy.deepcopy(cloud_provider.services)
            last_run_dict = report.jsrw.load_from_file(AWSCONFIG)
            cloud_provider.services = last_run_dict['services']
            for service in cloud_provider.service_list:
                cloud_provider.services[service] = current_run_services[service]

    # Partial run, using pre-pulled data
    else:
        # Reload to flatten everything into a python dictionary
        last_run_dict = report.jsrw.load_from_file(AWSCONFIG)
        for key in last_run_dict:
            setattr(cloud_provider, key, last_run_dict[key])

    # Pre processing
    cloud_provider.preprocessing(args.get('ip_ranges'),
                                 args.get('ip_ranges_name_key'))

    # Analyze config
    finding_rules = Ruleset(environment_name=profile_name,
                            cloud_provider=provider_name,
                            filename=args.get('ruleset'),
                            ip_ranges=args.get('ip_ranges'),
                            aws_account_id=cloud_provider.aws_account_id)
    processing_engine = ProcessingEngine(finding_rules)
    processing_engine.run(cloud_provider)

    # Create display filters
    filter_rules = Ruleset(cloud_provider=provider_name,
                           filename='filters.json',
                           rule_type='filters',
                           aws_account_id=cloud_provider.aws_account_id)
    processing_engine = ProcessingEngine(filter_rules)
    processing_engine.run(cloud_provider)

    # Handle exceptions
    try:
        # Indexing stays inside the try: a missing 'exceptions' option falls back to no exceptions
        exceptions = RuleExceptions(profile_name,
                                    args.get('exceptions')[0])
        exceptions.process(cloud_provider)
        exceptions = exceptions.exceptions
    except Exception:
        printDebug(
            'Warning, failed to load exceptions. The file may not exist or may have an invalid format.'
        )
        exceptions = {}

    # Finalize
    cloud_provider.postprocessing(report.current_time, finding_rules)

    # TODO this is AWS-specific - move to postprocessing?
    # Get organization data if it exists
    try:
        profile = AWSProfiles.get(profile_name)[0]
        if 'source_profile' in profile.attributes:
            organization_info_file = os.path.expanduser(
                '~/.aws/recipes/%s/organization.json' %
                profile.attributes['source_profile'])
            if os.path.isfile(organization_info_file):
                with open(organization_info_file, 'rt') as f:
                    accounts = json.load(f)
                # Index the accounts by their 'Id' field
                org = {}
                for account in accounts:
                    account_id = account.pop('Id')
                    org[account_id] = account
                setattr(cloud_provider, 'organization', org)
    except Exception as e:
        # Best effort: the scan continues without organization data, but the reason is no longer hidden
        printDebug('Warning, failed to load organization data: %s' % e)

    # Save config and create HTML report
    html_report_path = report.save(cloud_provider, exceptions,
                                   args.get('force_write'), args.get('debug'))

    # Open the report by default
    if not args.get('no_browser'):
        printInfo('Opening the HTML report...')
        url = 'file://%s' % os.path.abspath(html_report_path)
        webbrowser.open(url, new=2)

    return 0