示例#1
0
    def __init__(self,
                 session,
                 role=None,
                 account=None,
                 debug=False,
                 headless=True,
                 saml_request=None):
        """Initialise the login helper from the session's scoped config.

        Args:
            session: Session whose ``get_scoped_config()`` supplies the
                ``azure_*``, ``session_duration`` and ``use_keyring`` values.
            role: Stored as ``self._role``.
            account: Stored as ``self._account``.
            debug: Stored as ``self._debug``.
            headless: Stored as ``self._headless``.
            saml_request: When truthy, stored as an instance-level
                ``_SAML_REQUEST``, shadowing any class-level value.

        Raises:
            Whatever ``int()`` raises if ``session_duration`` in the config
            is not convertible to an integer.
        """
        self._session = session
        self._role, self._account = role, account
        self._debug, self._headless = debug, headless

        cfg = self._session.get_scoped_config()
        self._config = cfg
        self._config_writer = ConfigFileWriter()

        # NOTE(review): values read from the scoped config are presumably
        # strings as written in the profile section (e.g. "false" would be
        # truthy) -- confirm how downstream code interprets azure_mfa,
        # azure_kmsi and use_keyring.
        self._azure_tenant_id = cfg.get('azure_tenant_id')
        self._azure_app_id_uri = cfg.get('azure_app_id_uri')
        self._azure_mfa = cfg.get('azure_mfa')
        self._azure_kmsi = cfg.get('azure_kmsi', False)
        self._azure_username = cfg.get('azure_username')
        # The password is not taken from the config; it starts as None.
        self._azure_password = None

        raw_duration = cfg.get('session_duration', 3600)
        self._session_duration = int(raw_duration)
        self._use_keyring = cfg.get('use_keyring')
        self.saml_response = None

        if saml_request:
            self._SAML_REQUEST = saml_request
示例#2
0
    def __init__(self,
                 session,
                 role=None,
                 account=None,
                 debug=False,
                 headless=True,
                 saml_request=None):
        """Initialise the login helper from environment variables.

        The tenant ID, app ID URI, username and password come from the
        process environment. MFA, KMSI, the session duration (3600) and
        keyring use are fixed rather than configurable.

        Raises:
            KeyError: if ``azure_tenant_id``, ``azure_app_id_uri``,
                ``azure_default_username`` or ``azure_default_password`` is
                missing from ``os.environ`` (looked up in that order).
        """
        self._session = session
        self._role, self._account = role, account
        self._debug, self._headless = debug, headless
        # The scoped config is still loaded and exposed as self._config,
        # but none of the settings below are read from it.
        self._config = session.get_scoped_config()
        self._config_writer = ConfigFileWriter()

        env = os.environ
        self._azure_tenant_id = env["azure_tenant_id"]
        self._azure_app_id_uri = env["azure_app_id_uri"]
        self._azure_mfa = self._azure_kmsi = False
        self._azure_username = env["azure_default_username"]
        self._azure_password = env["azure_default_password"]

        # Fixed values instead of per-profile configuration.
        self._session_duration = 3600
        self._use_keyring = False
        self.saml_response = None

        if saml_request:
            self._SAML_REQUEST = saml_request
class ConfigWriter(object):
    """Write EMR settings into the session's CLI config file.

    Each value is stored as a sub-key of a nested ``emr`` entry inside the
    section for the session's profile.
    """

    def __init__(self, session):
        self.session = session
        # `_get_profile_str` is defined elsewhere; presumably it yields the
        # section header for the active profile -- TODO confirm.
        self.section = _get_profile_str(session, ' ')
        self.config_file_writer = ConfigFileWriter()

    def update_config(self, key, value):
        """Set ``emr.<key> = <value>`` in this profile's config section.

        The target file is the session's ``config_file`` variable with a
        leading ``~`` expanded.
        """
        target_path = self.session.get_config_variable('config_file')
        emr_update = {
            '__section__': self.section,
            'emr': {key: value},
        }
        self.config_file_writer.update_config(
            emr_update, os.path.expanduser(target_path))
示例#4
0
    def save(self):
        """Persist this profile's Azure settings and AWS credentials.

        Two writes happen, in order:

        1. The ``awsad-*`` settings go to the config file, in section
           ``profile <name>``. When the session has no profile, no
           ``__section__`` is given and the writer's default section is
           used.
        2. The access key, secret key, session token and the
           ``awsad-aws_expiration_time`` ISO timestamp go to the credentials
           file, in section ``<name>`` (again only when a profile is set).
        """
        session = Session(profile=self.profile)

        config_writer = ConfigFileWriter()
        settings = {
            "awsad-azure_tenant_id": self.azure_tenant_id,
            "awsad-azure_app_id": self.azure_app_id,
            "awsad-azure_app_title": self.azure_app_title,
            "awsad-aws_default_role_arn": self.aws_default_role_arn,
            "awsad-aws_session_duration": self.aws_session_duration,
        }
        profile_name = session.profile
        if profile_name is not None:
            # Config-file sections for named profiles carry a prefix.
            settings["__section__"] = f"profile {profile_name}"
        config_writer.update_config(
            settings, self.config_file_path(session, expand_user=True))

        credentials_writer = ConfigFileWriter()
        expiry = self.aws_expiration_time
        # NOTE(review): with no expiration time the value is None; how the
        # writer serialises None is not visible here -- confirm.
        credentials = {
            "aws_access_key_id": self.aws_access_key_id,
            "aws_secret_access_key": self.aws_secret_access_key,
            "aws_session_token": self.aws_session_token,
            "awsad-aws_expiration_time":
                expiry.isoformat() if expiry else None,
        }
        if session.profile is not None:
            # The credentials file uses the bare profile name as section.
            credentials["__section__"] = session.profile
        credentials_writer.update_config(
            credentials, self.credentials_file_path(session, expand_user=True))
示例#5
0
def create_default_wizard_runner(session):
    """Build a ``core.Runner`` wired with the default wizard step handlers.

    A single ``core.APIInvoker`` and a single ``core.SharedConfigAPI``
    (backed by a fresh ``ConfigFileWriter``), both bound to *session*, are
    shared between the planner's and the executor's handlers. Each handler
    is registered under its class's ``NAME``.
    """
    invoker = core.APIInvoker(session=session)
    config_api = core.SharedConfigAPI(session=session,
                                      config_writer=ConfigFileWriter())

    # PromptStep and YesNoPrompt each get their own UIPrompter instance.
    planning_steps = {
        core.StaticStep.NAME: core.StaticStep(),
        core.PromptStep.NAME: core.PromptStep(ui.UIPrompter()),
        core.YesNoPrompt.NAME: core.YesNoPrompt(ui.UIPrompter()),
        core.FilePromptStep.NAME: core.FilePromptStep(
            ui.UIFilePrompter(ui.FileCompleter())),
        core.TemplateStep.NAME: core.TemplateStep(),
        core.APICallStep.NAME: core.APICallStep(api_invoker=invoker),
        core.SharedConfigStep.NAME: core.SharedConfigStep(
            config_api=config_api),
    }
    planner = core.Planner(step_handlers=planning_steps)

    execution_steps = {
        core.APICallExecutorStep.NAME: core.APICallExecutorStep(invoker),
        core.SharedConfigExecutorStep.NAME: core.SharedConfigExecutorStep(
            config_api),
        core.DefineVariableStep.NAME: core.DefineVariableStep(),
        core.MergeDictStep.NAME: core.MergeDictStep(),
    }
    return core.Runner(planner, core.Executor(step_handlers=execution_steps))
示例#6
0
 def __init__(self, session, prompter=None, config_writer=None):
     """Create the configure command.

     Args:
         session: Forwarded to the parent command's constructor.
         prompter: Stored as ``self._prompter``; a new
             ``InteractivePrompter`` is used when this is None.
         config_writer: Stored as ``self._config_writer``; a new
             ``ConfigFileWriter`` is used when this is None.
     """
     super(ConfigureCommand, self).__init__(session)
     self._prompter = (InteractivePrompter() if prompter is None
                       else prompter)
     self._config_writer = (ConfigFileWriter() if config_writer is None
                            else config_writer)
示例#7
0
    def update(self) -> None:
        """Prompt for each configurable option and save any answers given.

        ``self._config_options`` maps attribute names to prompt labels. For
        each entry the user sees ``<label> [<current>]: ``, where the
        current value is the instance attribute or, when the attribute is
        missing, ``self._optional.get(attr)``. Blank answers are skipped.
        Non-blank answers are written to ``self.config_file``, in section
        ``self.name`` unless the profile is ``default`` (then no section is
        given). Nothing is written when every answer is blank.
        """
        writer = ConfigFileWriter()
        changes = {}

        for attr, label in self._config_options.items():
            current = getattr(self, attr, self._optional.get(attr))
            answer = input("%s [%s]: " % (label, current))
            if answer:
                changes[attr] = answer

        if not changes:
            return

        if self.name != 'default':
            changes['__section__'] = self.name
        writer.update_config(changes, self.config_file)
示例#8
0
    def _run_main(self, parsed_args, parsed_globals):
        """Store AWS credentials via ``persistence`` and flag the profile.

        The access key and secret come from ``parsed_args.key`` /
        ``parsed_args.secret`` when given. Otherwise they are prompted for
        with ``getpass``, so input is not echoed. A blank answer keeps the
        value currently loaded in the session, if any.

        The pair is handed to ``persistence.set_credentials`` (presumably a
        keyring backend -- confirm), and ``keyring = true`` is written to the
        profile's section of the CLI config file (``default`` or
        ``profile <name>``).

        Returns:
            0 on completion.
        """
        import getpass

        current_key = None
        current_secret = None
        masked_current_secret = None
        if self._session._credentials:
            current_key = self._session._credentials.access_key
            current_secret = self._session._credentials.secret_key
            if current_secret is not None:
                # Reveal at most the last four characters. Secrets of four
                # characters or fewer are fully masked so they are never
                # echoed in the prompt.
                shown = current_secret[-4:] if len(current_secret) > 4 else ""
                masked_current_secret = (
                    "*" * (len(current_secret) - len(shown)) + shown)

        key = parsed_args.key
        if key is None:
            # A blank answer keeps the existing key.
            key = getpass.getpass(
                "AWS Access Key ID [%s]: " % current_key) or current_key

        secret = parsed_args.secret
        if secret is None:
            secret = getpass.getpass(
                "AWS Secret Access Key [%s]: " % masked_current_secret
            ) or current_secret

        profile = self._session.profile
        if profile is None:
            profile = "default"
            config_section = "default"
        else:
            config_section = "profile {0}".format(profile)

        persistence.set_credentials(profile, key, secret)

        config_update = {"__section__": config_section, "keyring": "true"}
        config_filename = os.path.expanduser(
            self._session.get_config_variable("config_file"))
        ConfigFileWriter().update_config(config_update, config_filename)

        return 0
示例#9
0
def add_tmp_profile(profile,
                    credentials,
                    *,
                    config_path=CONFIG,
                    credential_path=CREDENTIALS):
    """Register temporary credentials under a named AWS CLI profile.

    Args:
        profile: Profile name. The config file gets section
            ``profile <profile>`` with ``region`` set to ``DEFAULT_REGION``.
            The credentials file gets section ``<profile>``.
        credentials: Mapping with ``AccessKeyId``, ``SecretAccessKey`` and
            ``SessionToken`` keys. A missing key raises ``KeyError`` before
            anything is written. (Presumably an STS ``Credentials`` dict --
            confirm against callers.)
        config_path: Config file to update; ``~`` is expanded.
        credential_path: Credentials file to update; ``~`` is expanded.
    """
    profile_section = {
        "__section__": 'profile ' + profile,
        "region": DEFAULT_REGION,
    }
    credential_section = {
        "__section__": profile,
        "aws_access_key_id": credentials["AccessKeyId"],
        "aws_secret_access_key": credentials["SecretAccessKey"],
        "aws_session_token": credentials["SessionToken"],
    }

    writer = Cw()
    # Config first, then credentials.
    for section, path in ((profile_section, config_path),
                          (credential_section, credential_path)):
        writer.update_config(section, os.path.expanduser(path))
示例#10
0
    def _run_main(self, parsed_args, parsed_globals):
        """Store AWS credentials via ``persistence`` and flag the profile.

        The access key and secret come from ``parsed_args.key`` /
        ``parsed_args.secret`` when given. Otherwise they are prompted for
        with ``getpass``, so input is not echoed. A blank answer keeps the
        value currently loaded in the session, if any.

        The pair is handed to ``persistence.set_credentials`` (presumably a
        keyring backend -- confirm), and ``keyring = true`` is written to the
        profile's section of the CLI config file (``default`` or
        ``profile <name>``).

        Returns:
            0 on completion.
        """
        import getpass

        current_key = None
        current_secret = None
        masked_current_secret = None
        if self._session._credentials:
            current_key = self._session._credentials.access_key
            current_secret = self._session._credentials.secret_key
            if current_secret is not None:
                # Reveal at most the last four characters. Secrets of four
                # characters or fewer are fully masked so they are never
                # echoed in the prompt.
                shown = current_secret[-4:] if len(current_secret) > 4 else ""
                masked_current_secret = (
                    "*" * (len(current_secret) - len(shown)) + shown)

        key = parsed_args.key
        if key is None:
            # A blank answer keeps the existing key.
            key = getpass.getpass(
                "AWS Access Key ID [%s]: " % current_key) or current_key

        secret = parsed_args.secret
        if secret is None:
            secret = getpass.getpass(
                "AWS Secret Access Key [%s]: " % masked_current_secret
            ) or current_secret

        profile = self._session.profile
        if profile is None:
            profile = "default"
            config_section = "default"
        else:
            config_section = "profile {0}".format(profile)

        persistence.set_credentials(profile, key, secret)

        config_update = {"__section__": config_section, "keyring": "true"}
        config_filename = os.path.expanduser(
            self._session.get_config_variable("config_file"))
        ConfigFileWriter().update_config(config_update, config_filename)

        return 0
示例#11
0
    def __init__(self,
                 session,
                 prompter=None,
                 selector=None,
                 config_writer=None,
                 sso_token_cache=None):
        """Set up the SSO configure command.

        Omitted collaborators are replaced with defaults:
        ``PTKPrompt()`` for *prompter*, ``select_menu`` for *selector* and a
        new ``ConfigFileWriter`` for *config_writer*. *sso_token_cache* is
        stored as given, which may be None.

        The session's profile name is recorded in
        ``self._original_profile_name``. The scoped config is captured in
        ``self._config``; if the profile is not found (``ProfileNotFound``),
        an empty dict is used instead.
        """
        super(ConfigureSSOCommand, self).__init__(session)
        self._prompter = PTKPrompt() if prompter is None else prompter
        self._selector = select_menu if selector is None else selector
        self._config_writer = (ConfigFileWriter() if config_writer is None
                               else config_writer)
        self._sso_token_cache = sso_token_cache

        self._new_values = {}
        self._original_profile_name = self._session.profile
        try:
            scoped = self._session.get_scoped_config()
        except ProfileNotFound:
            scoped = {}
        self._config = scoped
示例#12
0
def setup_cli(name=None, key_id=None, secret=None, region=None):
    """Create an AWS CLI profile plus a shell script that selects it.

    Any argument left as None is prompted for on the terminal. The secret
    is read with ``getpass``, so it is not echoed. Answers are used
    verbatim as strings; they are never evaluated. If a profile called
    *name* already exists in ``~/.aws/config`` or ``~/.aws/credentials``,
    a message is printed and nothing is written.

    Otherwise this function:

    * writes ``[profile <name>]`` with ``output = json`` and ``region`` to
      ``~/.aws/config``;
    * writes ``[<name>]`` with the access key ID and secret to
      ``~/.aws/credentials``;
    * creates ``~/bin/<name>``, an owner-only (rwx------) bash script that
      exports ``AWS_DEFAULT_REGION``, ``AWS_PROFILE`` and
      ``AWS_DEFAULT_PROFILE``;
    * sets those same three variables in this process's ``os.environ``.
    """
    import getpass
    import shlex

    if name is None:
        name = input("Profile name: ")
    home_dir = os.path.expanduser("~")
    aws_dir = os.path.join(home_dir, ".aws")
    config_file = os.path.join(aws_dir, "config")
    credentials_file = os.path.join(aws_dir, "credentials")
    if (has_entry("profile ", name, config_file)
            or has_entry("", name, credentials_file)):
        print("Profile " + name + " already exists. Not overwriting.")
        return
    if key_id is None:
        key_id = input("Key ID: ")
    if secret is None:
        # Do not echo the secret key on the terminal.
        secret = getpass.getpass("Key secret: ")
    if region is None:
        region = input("Default region: ")

    writer = ConfigFileWriter()
    config_values = {
        "__section__": "profile " + name,
        "output": "json",
        "region": region
    }
    credentials_values = {
        "__section__": name,
        "aws_access_key_id": key_id,
        "aws_secret_access_key": secret
    }
    writer.update_config(config_values, config_file)
    writer.update_config(credentials_values, credentials_file)

    home_bin = os.path.join(home_dir, "bin")
    # exist_ok avoids the check-then-create race; a non-directory at this
    # path still raises FileExistsError.
    os.makedirs(home_bin, exist_ok=True)
    source_file = os.path.join(home_bin, name)
    # Quote the values for the shell. Ordinary names and regions are left
    # unchanged; anything with spaces or metacharacters cannot break out of
    # the export line.
    quoted_region = shlex.quote(region)
    quoted_name = shlex.quote(name)
    with open(source_file, "w") as source_script:
        source_script.write('#!/bin/bash\n\n')
        source_script.write('export AWS_DEFAULT_REGION=')
        source_script.write(quoted_region + ' AWS_PROFILE=' + quoted_name)
        source_script.write(' AWS_DEFAULT_PROFILE=' + quoted_name + "\n")
    os.chmod(source_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

    os.environ['AWS_PROFILE'] = name
    os.environ['AWS_DEFAULT_PROFILE'] = name
    os.environ['AWS_DEFAULT_REGION'] = region
示例#13
0
 def setUp(self):
     """Give each test a fresh temp directory and a ConfigFileWriter."""
     scratch_dir = tempfile.mkdtemp()
     self.dirname = scratch_dir
     # Only the path is computed here; the file itself is not created.
     self.config_filename = os.path.join(scratch_dir, 'config')
     self.writer = ConfigFileWriter()
示例#14
0
class TestConfigFileWriter(unittest.TestCase):
    """Tests for ``ConfigFileWriter.update_config``.

    Most tests go through ``assert_update_config``. That helper writes a
    starting config file into a temporary directory, applies one update
    dict, and compares the complete resulting file text. In the update
    dict, the special ``__section__`` key names the target section; when
    it is omitted the tests below expect ``[default]``. A dict value is
    rendered as a ``key =`` line followed by indented sub-keys.
    """

    def setUp(self):
        # Fresh scratch directory per test. The config path inside it is
        # only computed here, not created.
        self.dirname = tempfile.mkdtemp()
        self.config_filename = os.path.join(self.dirname, 'config')
        self.writer = ConfigFileWriter()

    def tearDown(self):
        # Remove the scratch directory and everything written into it.
        shutil.rmtree(self.dirname)

    def assert_update_config(self, original_config_contents, updated_data,
                             updated_config_contents):
        """Fail unless applying *updated_data* to the original text yields
        exactly *updated_config_contents* (full string comparison)."""
        # Given the original_config, when it's updated with update_data,
        # it should produce updated_config_contents.
        with open(self.config_filename, 'w') as f:
            f.write(original_config_contents)
        self.writer.update_config(updated_data, self.config_filename)
        with open(self.config_filename, 'r') as f:
            new_contents = f.read()
        if new_contents != updated_config_contents:
            self.fail("Config file contents do not match.\n"
                      "Expected contents:\n"
                      "%s\n\n"
                      "Actual Contents:\n"
                      "%s\n" % (updated_config_contents, new_contents))

    # An existing key is replaced in place. The rest of the file, including
    # the missing trailing newline, is left untouched.
    def test_update_single_existing_value(self):
        original = '[default]\nfoo = 1\nbar = 1'
        updated = '[default]\nfoo = newvalue\nbar = 1'
        self.assert_update_config(
            original, {'foo': 'newvalue'}, updated)

    # A replaced "foo=1" line is rewritten in "key = value" form, while
    # neighbouring lines keep their original spacing.
    def test_update_single_existing_value_no_spaces(self):
        original = '[default]\nfoo=1\nbar=1'
        updated = '[default]\nfoo = newvalue\nbar=1'
        self.assert_update_config(
            original, {'foo': 'newvalue'}, updated)

    # A new key is appended after the section's last line. A newline is
    # added first because the original lacks one, and the file ends with one.
    def test_update_single_new_values(self):
        expected = '[default]\nfoo = 1\nbar = 2\nbaz = newvalue\n'
        self.assert_update_config(
            '[default]\nfoo = 1\nbar = 2',
            {'baz': 'newvalue'}, expected)

    # Appended keys use "key = value" even when existing lines have no spaces.
    def test_handles_no_spaces(self):
        expected = '[default]\nfoo=1\nbar=2\nbaz = newvalue\n'
        self.assert_update_config(
            '[default]\nfoo=1\nbar=2',
            {'baz': 'newvalue'}, expected)

    # Only the key in the targeted section [b] changes; the same key in
    # [a] and [c] is untouched.
    def test_insert_values_in_middle_section(self):
        original_contents = (
            '[a]\n'
            'foo = bar\n'
            'baz = bar\n'
            '\n'
            '[b]\n'
            '\n'
            'foo = bar\n'
            '[c]\n'
            'foo = bar\n'
            'baz = bar\n'
        )
        expected_contents = (
            '[a]\n'
            'foo = bar\n'
            'baz = bar\n'
            '\n'
            '[b]\n'
            '\n'
            'foo = newvalue\n'
            '[c]\n'
            'foo = bar\n'
            'baz = bar\n'
        )
        self.assert_update_config(
            original_contents,
            {'foo': 'newvalue', '__section__': 'b'},
            expected_contents)

    # A new key in a middle section goes right after that section's last
    # key, before the blank line that precedes [c].
    def test_insert_new_value_in_middle_section(self):
        original_contents = (
            '[a]\n'
            'foo = bar\n'
            '\n'
            '[b]\n'
            '\n'
            'foo = bar\n'
            '\n'
            '[c]\n'
            'foo = bar\n'
        )
        expected_contents = (
            '[a]\n'
            'foo = bar\n'
            '\n'
            '[b]\n'
            '\n'
            'foo = bar\n'
            'newvalue = newvalue\n'
            '\n'
            '[c]\n'
            'foo = bar\n'
        )
        self.assert_update_config(
            original_contents,
            {'newvalue': 'newvalue', '__section__': 'b'},
            expected_contents)

    # A file holding only a newline gets a new [default] section appended.
    def test_new_config_file(self):
        self.assert_update_config(
            '\n',
            {'foo': 'value'},
            '\n[default]\nfoo = value\n')

    # A missing target section is appended at the very end of the file.
    def test_section_does_not_exist(self):
        original_contents = (
            '[notdefault]\n'
            'foo = bar\n'
            'baz = bar\n'
            '\n'
            '\n'
            '\n'
            '[other "section"]\n'
            '\n'
            'foo = bar\n'
        )
        appended_contents = (
            '[default]\n'
            'foo = value\n'
        )
        self.assert_update_config(
            original_contents,
            {'foo': 'value'},
            original_contents + appended_contents)

    # The writer creates the file when it does not exist yet.
    def test_config_file_does_not_exist(self):
        self.writer.update_config({'foo': 'value'}, self.config_filename)
        with open(self.config_filename, 'r') as f:
            new_contents = f.read()
        self.assertEqual(new_contents, '[default]\nfoo = value\n')

    # A newly created file gets mode 0o600 (owner read/write only). The
    # 0xFFF mask keeps permission plus setuid/setgid/sticky bits. The read
    # below is not needed for the assertion.
    @skip_if_windows("Test not valid on windows.")
    def test_permissions_on_new_file(self):
        self.writer.update_config({'foo': 'value'}, self.config_filename)
        with open(self.config_filename, 'r') as f:
            f.read()
        self.assertEqual(os.stat(self.config_filename).st_mode & 0xFFF, 0o600)

    # A commented-out "#foo" line is not treated as the existing key; the
    # new value is appended and the comment kept.
    def test_update_config_with_comments(self):
        original = (
            '[default]\n'
            '#foo = 1\n'
            'bar = 1\n'
        )
        self.assert_update_config(
            original, {'foo': 'newvalue'},
            '[default]\n'
            '#foo = 1\n'
            'bar = 1\n'
            'foo = newvalue\n'
        )

    # A commented-out "#[default]" header is not treated as the section.
    def test_update_config_with_commented_section(self):
        original = (
            '#[default]\n'
            '[default]\n'
            '#foo = 1\n'
            'bar = 1\n'
        )
        self.assert_update_config(
            original, {'foo': 'newvalue'},
            '#[default]\n'
            '[default]\n'
            '#foo = 1\n'
            'bar = 1\n'
            'foo = newvalue\n'
        )

    # NOTE(review): despite the name, this fixture has no extra whitespace
    # around key names ("foo = 1" is the normal form). Consider a line such
    # as "  foo  = 1" so the test covers what its name says.
    def test_spaces_around_key_names(self):
        original = (
            '[default]\n'
            'foo = 1\n'
            'bar = 1\n'
        )
        self.assert_update_config(
            original, {'foo': 'newvalue'},
            '[default]\n'
            'foo = newvalue\n'
            'bar = 1\n'
        )

    # __section__ 'profile foobar' matches an unquoted [profile foobar].
    def test_unquoted_profile_name(self):
        original = (
            '[profile foobar]\n'
            'foo = 1\n'
            'bar = 1\n'
        )
        self.assert_update_config(
            original, {'foo': 'newvalue', '__section__': 'profile foobar'},
            '[profile foobar]\n'
            'foo = newvalue\n'
            'bar = 1\n'
        )

    # The unquoted __section__ also matches [profile "foobar"], and the
    # header's quoting is preserved.
    def test_double_quoted_profile_name(self):
        original = (
            '[profile "foobar"]\n'
            'foo = 1\n'
            'bar = 1\n'
        )
        self.assert_update_config(
            original, {'foo': 'newvalue', '__section__': 'profile foobar'},
            '[profile "foobar"]\n'
            'foo = newvalue\n'
            'bar = 1\n'
        )

    # Internal runs of spaces in a quoted profile name must match exactly.
    def test_profile_with_multiple_spaces(self):
        original = (
            '[profile "two  spaces"]\n'
            'foo = 1\n'
            'bar = 1\n'
        )
        self.assert_update_config(
            original, {
                'foo': 'newvalue', '__section__': 'profile two  spaces'},
            '[profile "two  spaces"]\n'
            'foo = newvalue\n'
            'bar = 1\n'
        )

    # A dict value is rendered as "key =" plus 4-space-indented sub-keys.
    def test_nested_attributes_new_file(self):
        original = ''
        self.assert_update_config(
            original, {'__section__': 'default',
                       's3': {'signature_version': 's3v4'}},
            '[default]\n'
            's3 =\n'
            '    signature_version = s3v4\n')

    # A new sub-key joins the existing nested block, before the next
    # top-level key.
    def test_add_to_nested_with_nested_in_the_middle(self):
        original = (
            '[default]\n'
            's3 =\n'
            '    other = foo\n'
            'ec2 = bar\n'
        )
        self.assert_update_config(
            original, {'__section__': 'default',
                       's3': {'signature_version': 'newval'}},
            '[default]\n'
            's3 =\n'
            '    other = foo\n'
            '    signature_version = newval\n'
            'ec2 = bar\n')

    # Same as above when the nested block ends the file.
    def test_add_to_nested_with_nested_in_the_end(self):
        original = (
            '[default]\n'
            's3 =\n'
            '    other = foo\n'
        )
        self.assert_update_config(
            original, {'__section__': 'default',
                       's3': {'signature_version': 'newval'}},
            '[default]\n'
            's3 =\n'
            '    other = foo\n'
            '    signature_version = newval\n')

    # An existing nested sub-key is replaced in place.
    def test_update_nested_attribute(self):
        original = (
            '[default]\n'
            's3 =\n'
            '    signature_version = originalval\n'
        )
        self.assert_update_config(
            original, {'__section__': 'default',
                       's3': {'signature_version': 'newval'}},
            '[default]\n'
            's3 =\n'
            '    signature_version = newval\n')

    # Despite the name, the update targets [default]. It checks that the new
    # sub-key stays inside the nested block, before the [profile foo] header.
    def test_updated_nested_attribute_new_section(self):
        original = (
            '[default]\n'
            's3 =\n'
            '    other = foo\n'
            '[profile foo]\n'
            'foo = bar\n'
        )
        self.assert_update_config(
            original, {'__section__': 'default',
                       's3': {'signature_version': 'newval'}},
            '[default]\n'
            's3 =\n'
            '    other = foo\n'
            '    signature_version = newval\n'
            '[profile foo]\n'
            'foo = bar\n')

    # With no existing nested block, one is created at the end of
    # [default], before the next section header.
    def test_update_nested_attr_no_prior_nesting(self):
        original = (
            '[default]\n'
            'foo = bar\n'
            '[profile foo]\n'
            'foo = bar\n'
        )
        self.assert_update_config(
            original, {'__section__': 'default',
                       's3': {'signature_version': 'newval'}},
            '[default]\n'
            'foo = bar\n'
            's3 =\n'
            '    signature_version = newval\n'
            '[profile foo]\n'
            'foo = bar\n')

    # An empty section immediately followed by another header still
    # receives the new key.
    def test_can_handle_empty_section(self):
        original = (
            '[default]\n'
            '[preview]\n'
            'cloudfront = true\n'
        )
        self.assert_update_config(
            original, {'region': 'us-west-2', '__section__': 'default'},
            '[default]\n'
            'region = us-west-2\n'
            '[preview]\n'
            'cloudfront = true\n'
        )
示例#15
0
 def __init__(self, session=None):
     """Remember *session* (may be None) and create a ConfigFileWriter."""
     self._session = session
     self._config_writer = ConfigFileWriter()
示例#16
0
 def setUp(self):
     """Prepare a scratch directory, a config path in it, and a writer."""
     workdir = tempfile.mkdtemp()
     config_path = os.path.join(workdir, 'config')
     self.dirname, self.config_filename = workdir, config_path
     self.writer = ConfigFileWriter()
示例#17
0
class TestConfigFileWriter(unittest.TestCase):
    """Tests for ``ConfigFileWriter.update_config``.

    Most tests go through ``assert_update_config``. That helper writes a
    starting config file into a temporary directory, applies one update
    dict, and compares the complete resulting file text. In the update
    dict, the special ``__section__`` key names the target section; when
    it is omitted the tests below expect ``[default]``. A dict value is
    rendered as a ``key =`` line followed by indented sub-keys.
    """
    def setUp(self):
        # Fresh scratch directory per test. The config path inside it is
        # only computed here, not created.
        self.dirname = tempfile.mkdtemp()
        self.config_filename = os.path.join(self.dirname, 'config')
        self.writer = ConfigFileWriter()

    def tearDown(self):
        # Remove the scratch directory and everything written into it.
        shutil.rmtree(self.dirname)

    def assert_update_config(self, original_config_contents, updated_data,
                             updated_config_contents):
        """Fail unless applying *updated_data* to the original text yields
        exactly *updated_config_contents* (full string comparison)."""
        # Given the original_config, when it's updated with update_data,
        # it should produce updated_config_contents.
        with open(self.config_filename, 'w') as f:
            f.write(original_config_contents)
        self.writer.update_config(updated_data, self.config_filename)
        with open(self.config_filename, 'r') as f:
            new_contents = f.read()
        if new_contents != updated_config_contents:
            self.fail("Config file contents do not match.\n"
                      "Expected contents:\n"
                      "%s\n\n"
                      "Actual Contents:\n"
                      "%s\n" % (updated_config_contents, new_contents))

    # An existing key is replaced in place. The rest of the file, including
    # the missing trailing newline, is left untouched.
    def test_update_single_existing_value(self):
        original = '[default]\nfoo = 1\nbar = 1'
        updated = '[default]\nfoo = newvalue\nbar = 1'
        self.assert_update_config(original, {'foo': 'newvalue'}, updated)

    # A replaced "foo=1" line is rewritten in "key = value" form, while
    # neighbouring lines keep their original spacing.
    def test_update_single_existing_value_no_spaces(self):
        original = '[default]\nfoo=1\nbar=1'
        updated = '[default]\nfoo = newvalue\nbar=1'
        self.assert_update_config(original, {'foo': 'newvalue'}, updated)

    # A new key is appended after the section's last line. A newline is
    # added first because the original lacks one, and the file ends with one.
    def test_update_single_new_values(self):
        expected = '[default]\nfoo = 1\nbar = 2\nbaz = newvalue\n'
        self.assert_update_config('[default]\nfoo = 1\nbar = 2',
                                  {'baz': 'newvalue'}, expected)

    # Appended keys use "key = value" even when existing lines have no spaces.
    def test_handles_no_spaces(self):
        expected = '[default]\nfoo=1\nbar=2\nbaz = newvalue\n'
        self.assert_update_config('[default]\nfoo=1\nbar=2',
                                  {'baz': 'newvalue'}, expected)

    # Only the key in the targeted section [b] changes; the same key in
    # [a] and [c] is untouched.
    def test_insert_values_in_middle_section(self):
        original_contents = ('[a]\n'
                             'foo = bar\n'
                             'baz = bar\n'
                             '\n'
                             '[b]\n'
                             '\n'
                             'foo = bar\n'
                             '[c]\n'
                             'foo = bar\n'
                             'baz = bar\n')
        expected_contents = ('[a]\n'
                             'foo = bar\n'
                             'baz = bar\n'
                             '\n'
                             '[b]\n'
                             '\n'
                             'foo = newvalue\n'
                             '[c]\n'
                             'foo = bar\n'
                             'baz = bar\n')
        self.assert_update_config(original_contents, {
            'foo': 'newvalue',
            '__section__': 'b'
        }, expected_contents)

    # A new key in a middle section goes right after that section's last
    # key, before the blank line that precedes [c].
    def test_insert_new_value_in_middle_section(self):
        original_contents = ('[a]\n'
                             'foo = bar\n'
                             '\n'
                             '[b]\n'
                             '\n'
                             'foo = bar\n'
                             '\n'
                             '[c]\n'
                             'foo = bar\n')
        expected_contents = ('[a]\n'
                             'foo = bar\n'
                             '\n'
                             '[b]\n'
                             '\n'
                             'foo = bar\n'
                             'newvalue = newvalue\n'
                             '\n'
                             '[c]\n'
                             'foo = bar\n')
        self.assert_update_config(original_contents, {
            'newvalue': 'newvalue',
            '__section__': 'b'
        }, expected_contents)

    # A file holding only a newline gets a new [default] section appended.
    def test_new_config_file(self):
        self.assert_update_config('\n', {'foo': 'value'},
                                  '\n[default]\nfoo = value\n')

    # A missing target section is appended at the very end of the file.
    def test_section_does_not_exist(self):
        original_contents = ('[notdefault]\n'
                             'foo = bar\n'
                             'baz = bar\n'
                             '\n'
                             '\n'
                             '\n'
                             '[other "section"]\n'
                             '\n'
                             'foo = bar\n')
        appended_contents = ('[default]\n' 'foo = value\n')
        self.assert_update_config(original_contents, {'foo': 'value'},
                                  original_contents + appended_contents)

    # The writer creates the file when it does not exist yet.
    def test_config_file_does_not_exist(self):
        self.writer.update_config({'foo': 'value'}, self.config_filename)
        with open(self.config_filename, 'r') as f:
            new_contents = f.read()
        self.assertEqual(new_contents, '[default]\nfoo = value\n')

    # A newly created file gets mode 0o600 (owner read/write only). The
    # 0xFFF mask keeps permission plus setuid/setgid/sticky bits. The read
    # below is not needed for the assertion.
    @skip_if_windows("Test not valid on windows.")
    def test_permissions_on_new_file(self):
        self.writer.update_config({'foo': 'value'}, self.config_filename)
        with open(self.config_filename, 'r') as f:
            f.read()
        self.assertEqual(os.stat(self.config_filename).st_mode & 0xFFF, 0o600)

    # A commented-out "#foo" line is not treated as the existing key; the
    # new value is appended and the comment kept.
    def test_update_config_with_comments(self):
        original = ('[default]\n' '#foo = 1\n' 'bar = 1\n')
        self.assert_update_config(
            original, {'foo': 'newvalue'}, '[default]\n'
            '#foo = 1\n'
            'bar = 1\n'
            'foo = newvalue\n')

    # A commented-out "#[default]" header is not treated as the section.
    def test_update_config_with_commented_section(self):
        original = ('#[default]\n' '[default]\n' '#foo = 1\n' 'bar = 1\n')
        self.assert_update_config(
            original, {'foo': 'newvalue'}, '#[default]\n'
            '[default]\n'
            '#foo = 1\n'
            'bar = 1\n'
            'foo = newvalue\n')

    # NOTE(review): despite the name, this fixture has no extra whitespace
    # around key names ("foo = 1" is the normal form). Consider a line such
    # as "  foo  = 1" so the test covers what its name says.
    def test_spaces_around_key_names(self):
        original = ('[default]\n' 'foo = 1\n' 'bar = 1\n')
        self.assert_update_config(original, {'foo': 'newvalue'}, '[default]\n'
                                  'foo = newvalue\n'
                                  'bar = 1\n')

    # __section__ 'profile foobar' matches an unquoted [profile foobar].
    def test_unquoted_profile_name(self):
        original = ('[profile foobar]\n' 'foo = 1\n' 'bar = 1\n')
        self.assert_update_config(
            original, {
                'foo': 'newvalue',
                '__section__': 'profile foobar'
            }, '[profile foobar]\n'
            'foo = newvalue\n'
            'bar = 1\n')

    # The unquoted __section__ also matches [profile "foobar"], and the
    # header's quoting is preserved.
    def test_double_quoted_profile_name(self):
        original = ('[profile "foobar"]\n' 'foo = 1\n' 'bar = 1\n')
        self.assert_update_config(
            original, {
                'foo': 'newvalue',
                '__section__': 'profile foobar'
            }, '[profile "foobar"]\n'
            'foo = newvalue\n'
            'bar = 1\n')

    # Internal runs of spaces in a quoted profile name must match exactly.
    def test_profile_with_multiple_spaces(self):
        original = ('[profile "two  spaces"]\n' 'foo = 1\n' 'bar = 1\n')
        self.assert_update_config(
            original, {
                'foo': 'newvalue',
                '__section__': 'profile two  spaces'
            }, '[profile "two  spaces"]\n'
            'foo = newvalue\n'
            'bar = 1\n')

    # A dict value is rendered as "key =" plus 4-space-indented sub-keys.
    def test_nested_attributes_new_file(self):
        original = ''
        self.assert_update_config(
            original, {
                '__section__': 'default',
                's3': {
                    'signature_version': 's3v4'
                }
            }, '[default]\n'
            's3 =\n'
            '    signature_version = s3v4\n')

    # A new sub-key joins the existing nested block, before the next
    # top-level key.
    def test_add_to_nested_with_nested_in_the_middle(self):
        original = ('[default]\n' 's3 =\n' '    other = foo\n' 'ec2 = bar\n')
        self.assert_update_config(
            original, {
                '__section__': 'default',
                's3': {
                    'signature_version': 'newval'
                }
            }, '[default]\n'
            's3 =\n'
            '    other = foo\n'
            '    signature_version = newval\n'
            'ec2 = bar\n')

    # Same as above when the nested block ends the file.
    def test_add_to_nested_with_nested_in_the_end(self):
        original = ('[default]\n' 's3 =\n' '    other = foo\n')
        self.assert_update_config(
            original, {
                '__section__': 'default',
                's3': {
                    'signature_version': 'newval'
                }
            }, '[default]\n'
            's3 =\n'
            '    other = foo\n'
            '    signature_version = newval\n')

    # An existing nested sub-key is replaced in place.
    def test_update_nested_attribute(self):
        original = ('[default]\n'
                    's3 =\n'
                    '    signature_version = originalval\n')
        self.assert_update_config(
            original, {
                '__section__': 'default',
                's3': {
                    'signature_version': 'newval'
                }
            }, '[default]\n'
            's3 =\n'
            '    signature_version = newval\n')

    # Despite the name, the update targets [default]. It checks that the new
    # sub-key stays inside the nested block, before the [profile foo] header.
    def test_updated_nested_attribute_new_section(self):
        original = ('[default]\n'
                    's3 =\n'
                    '    other = foo\n'
                    '[profile foo]\n'
                    'foo = bar\n')
        self.assert_update_config(
            original, {
                '__section__': 'default',
                's3': {
                    'signature_version': 'newval'
                }
            }, '[default]\n'
            's3 =\n'
            '    other = foo\n'
            '    signature_version = newval\n'
            '[profile foo]\n'
            'foo = bar\n')

    # With no existing nested block, one is created at the end of
    # [default], before the next section header.
    def test_update_nested_attr_no_prior_nesting(self):
        original = ('[default]\n'
                    'foo = bar\n'
                    '[profile foo]\n'
                    'foo = bar\n')
        self.assert_update_config(
            original, {
                '__section__': 'default',
                's3': {
                    'signature_version': 'newval'
                }
            }, '[default]\n'
            'foo = bar\n'
            's3 =\n'
            '    signature_version = newval\n'
            '[profile foo]\n'
            'foo = bar\n')

    # An empty section immediately followed by another header still
    # receives the new key.
    def test_can_handle_empty_section(self):
        original = ('[default]\n' '[preview]\n' 'cloudfront = true\n')
        self.assert_update_config(
            original, {
                'region': 'us-west-2',
                '__section__': 'default'
            }, '[default]\n'
            'region = us-west-2\n'
            '[preview]\n'
            'cloudfront = true\n')
 def __init__(self, session):
     """Bind to *session* and prepare a writer for its profile section."""
     self.session = session
     # `_get_profile_str` is defined elsewhere; presumably it returns the
     # config section header for the session's profile -- TODO confirm.
     profile_section = _get_profile_str(session, ' ')
     self.section = profile_section
     self.config_file_writer = ConfigFileWriter()
示例#19
0
class Login:
    """Obtain temporary AWS credentials through an Azure AD SAML sign-in.

    The Azure AD login pages are driven with pyppeteer (``launch``).  The SAML
    assertion POSTed to ``https://signin.aws.amazon.com/saml`` is intercepted,
    exchanged for STS credentials with ``assume_role_with_saml``, and written
    to the AWS CLI config/credentials files via ``ConfigFileWriter``.
    """

    # SAML AuthnRequest template; ``{id}``, ``{date}`` and ``{app_id}`` are
    # substituted in ``_build_saml_login_url``.
    _SAML_REQUEST = \
        '<samlp:AuthnRequest xmlns="urn:oasis:names:tc:SAML:2.0:metadata" xml' \
        'ns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" ID="id_{id}" Version' \
        '="2.0" IsPassive="false" IssueInstant="{date}" AssertionConsumerServ' \
        'iceURL="https://signin.aws.amazon.com/saml"><Issuer xmlns="urn:oasis' \
        ':names:tc:SAML:2.0:assertion">{app_id}</Issuer><samlp:NameIDPolicy F' \
        'ormat="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"/></sa' \
        'mlp:AuthnRequest>'

    _BEGIN_AUTH_URL = '{url}/common/SAS/BeginAuth'
    _END_AUTH_URL = '{url}/common/SAS/EndAuth'
    _PROCESS_AUTH_URL = '{url}/common/SAS/ProcessAuth'
    _SAML_URL = '{url}/{tenant_id}/saml2?SAMLRequest={saml_request}'
    _REFERER = '{url}/{tenant_id}/login'

    # Keys written to the credentials file instead of the config file.
    _CREDENTIALS = [
        'aws_access_key_id', 'aws_secret_access_key', 'aws_session_token'
    ]
    _MFA_DELAY = 3
    _MFA_TIMEOUT = 60  # timeout in seconds to process MFA
    _AWAIT_TIMEOUT = 30000  # ms; timeout for the KMSI form selector
    _SLEEP_TIMEOUT = 500  # ms; pause between querySelector retries
    _EXEC_PATH = os.environ.get('CHROME_EXECUTABLE_PATH')
    _RETRIES = 5  # querySelector retries after a NetworkError
    # SAML IssueInstant format: a UTC timestamp ('Z'); '%M' is minutes.
    _ISSUE_INSTANT_FORMAT = '%Y-%m-%dT%H:%M:%SZ'
    # Principal used with an explicitly configured role when the SAML
    # response does not pair that role with a provider.
    # NOTE(review): hard-coded to a single AWS account -- confirm intent.
    _DEFAULT_PRINCIPAL_ARN = 'arn:aws:iam::592380362770:saml-provider/WAAD'
    # Seconds to wait for the corporate-network probe in ``_login``.
    _NETWORK_CHECK_TIMEOUT = 30

    def __init__(self,
                 session,
                 role=None,
                 account=None,
                 debug=False,
                 headless=True,
                 saml_request=None):
        """Read the Azure AD / AWS login settings from the scoped config.

        :param session: session providing ``get_scoped_config``,
            ``get_config_variable`` and ``profile``.
        :param role: role ARN to assume when several roles are available.
        :param account: AWS account (stored; not used by this class).
        :param debug: save a screenshot when authentication fails.
        :param headless: launch Chromium without a visible window.
        :param saml_request: optional replacement for ``_SAML_REQUEST``.
        """
        self._session = session
        self._role = role
        self._account = account
        self._debug = debug
        self._headless = headless
        self._config = self._session.get_scoped_config()
        self._config_writer = ConfigFileWriter()
        self._azure_tenant_id = self._config.get('azure_tenant_id')
        self._azure_app_id_uri = self._config.get('azure_app_id_uri')
        self._azure_mfa = self._config.get('azure_mfa')
        self._azure_kmsi = self._config.get('azure_kmsi', False)
        self._azure_username = self._config.get('azure_username')
        self._azure_password = None
        self._session_duration = int(self._config.get('session_duration',
                                                      3600))
        self._use_keyring = self._config.get('use_keyring')
        # Set by the request-interception handler in ``_render_js_form``.
        self.saml_response = None

        if saml_request:
            self._SAML_REQUEST = saml_request

    def __call__(self):
        return self._login()

    def _set_config_value(self, key, value):
        """Persist ``key = value`` for the active profile.

        Keys listed in ``_CREDENTIALS`` go to the credentials file, whose
        sections carry no ``profile `` prefix; all others go to the config
        file.
        """
        section = 'default'
        if self._session.profile is not None:
            section = 'profile {}'.format(self._session.profile)

        config_filename = os.path.expanduser(
            self._session.get_config_variable('config_file'))
        updated_config = {'__section__': section, key: value}

        if key in self._CREDENTIALS:
            config_filename = os.path.expanduser(
                self._session.get_config_variable('credentials_file'))
            section_name = updated_config['__section__']
            if section_name.startswith('profile '):
                updated_config['__section__'] = section_name[len('profile '):]
        self._config_writer.update_config(updated_config, config_filename)

    @classmethod
    def _issue_instant(cls):
        """Return the current UTC time formatted for the SAML IssueInstant."""
        # The template labels the instant as UTC ('Z'), so local time would be
        # wrong; the old format also put '%m' (month) where minutes belong.
        return datetime.utcnow().strftime(cls._ISSUE_INSTANT_FORMAT)

    def _build_saml_login_url(self):
        """Return the Azure AD SAML sign-in URL for a fresh AuthnRequest.

        The request is raw-DEFLATE encoded (zlib output without its 2-byte
        header and 4-byte Adler-32 trailer), base64-encoded, and URL-quoted.
        """
        authn_request = self._SAML_REQUEST.strip().format(
            date=self._issue_instant(),
            tenant_id=self._azure_tenant_id,
            id=uuid.uuid4(),
            app_id=self._azure_app_id_uri)
        deflated = zlib.compress(authn_request.encode('ascii'))[2:-4]
        saml_request = base64.b64encode(deflated).decode()
        return self._SAML_URL.format(url=LOGIN_URL,
                                     tenant_id=self._azure_tenant_id,
                                     saml_request=quote(saml_request))

    @classmethod
    async def _querySelector(cls, page, element, retries=0):
        """``page.querySelector`` that retries on ``NetworkError``.

        Waits ``_SLEEP_TIMEOUT`` ms between attempts and raises
        ``TimeoutError`` after more than ``_RETRIES`` retries.
        """
        if retries > cls._RETRIES:
            raise TimeoutError
        try:
            return await page.querySelector(element)
        except NetworkError:
            await page.waitFor(cls._SLEEP_TIMEOUT)
            return await cls._querySelector(page, element, retries + 1)

    async def _render_js_form(self, url, username, password, mfa=None):
        """Drive the Azure AD sign-in pages and capture the SAML response.

        Enters the username and password, then makes best-effort clicks on the
        "sign in another way", phone-app notification and "stay signed in"
        prompts.  It optionally handles MFA and KMSI, then waits up to
        ``_MFA_TIMEOUT`` seconds for ``self.saml_response`` to be populated.
        On an authentication error the process exits with status 1.  The
        browser is closed in every case.

        :param url: sign-in URL from ``_build_saml_login_url``.
        :param username: Azure AD username.
        :param password: Azure AD password.
        :param mfa: truthy to run the MFA step.
        """
        browser = await launch(executablePath=self._EXEC_PATH,
                               headless=self._headless)

        async def _saml_response(req):
            # Capture the assertion POSTed to AWS and answer it locally with
            # an empty 200 instead of forwarding it; pass through all others.
            if req.url == 'https://signin.aws.amazon.com/saml':
                self.saml_response = parse_qs(req.postData)['SAMLResponse'][0]
                await req.respond({
                    'status': 200,
                    'contentType': 'text/plain',
                    'body': ''
                })
            else:
                await req.continue_()

        # Outer try/finally so the browser is closed even if the username or
        # password steps below fail before the authentication block.
        try:
            pages = await browser.pages()
            page = pages[0]

            await page.goto(url, waitUntil='domcontentloaded')
            await page.waitForSelector(
                'input[name="loginfmt"]:not(.moveOffScreen)',
                {"visible": True})
            await page.focus('input[name="loginfmt"]')
            await page.keyboard.type(username)
            await page.authenticate({
                'username': username,
                'password': password
            })
            await page.click('input[type=submit]')

            # Wait for the page to load and then grab the saml response
            await page.waitForNavigation({"waitUntil": "load"})
            try:
                await page.waitForSelector(
                    'input[type="password"]:not(.moveOffScreen)',
                    {"visible": True})
                await page.focus('input[type="password"]')
                await page.keyboard.type(password)
                await page.click('input[type="submit"]')
            except Exception as e:
                print(f'Could not input/submit password:\n\n Error: {e}')

            # Sign in another way (optional prompt; best effort)
            try:
                await page.waitForSelector(
                    'a[id="signInAnotherWay"]:not(.moveOffScreen)',
                    {"visible": True})
                await page.click('a[id="signInAnotherWay"]')
                print('Clicked sign in another way...')
            except Exception as e:
                print(f'Could not click sign in another way:\n\n Error: {e}')

            # Phone app approval (optional prompt; best effort)
            try:
                await page.waitForSelector(
                    'div[data-value="PhoneAppNotification"]:not(.moveOffScreen)',
                    {"visible": True})
                await page.click('div[data-value="PhoneAppNotification"]')
                print('Clicked PhoneAppNotification...')
            except Exception as e:
                print(f'Could not click PhoneAppNotification:\n\n Error: {e}')

            # Stay signed in (optional prompt; best effort)
            try:
                await page.waitForSelector(
                    'input[type="submit"]:not(.moveOffScreen)',
                    {"visible": True})
                await page.click('input[type="submit"]')
                print('Clicked yes to "stay signed in"...')
            except Exception as e:
                print(
                    f'Could not submit yes to stay signed in:\n\n Error: {e}')

            # Get SAML response
            try:
                page.on('request', _saml_response)
                await page.setRequestInterception(True)
            except Exception as e:
                print(f'Could not get SAML response:\n\n Error: {e}')

            try:
                if await self._querySelector(page, '.has-error'):
                    raise FormError

                if mfa:
                    if self._azure_mfa not in MFA_WAIT_METHODS:
                        # Code-entry MFA: type the one-time code from stdin.
                        await page.waitForSelector(
                            'input[name="otc"]:not(.moveOffScreen)',
                            {"visible": True})
                        await page.focus('input[name="otc"]')
                        mfa_token = input('Azure MFA Token: ')
                        for char in mfa_token:
                            await page.keyboard.sendCharacter(char)
                        await page.click('input[type=submit]')
                    else:
                        print('Processing SAML response...')

                if self._azure_kmsi:
                    await page.waitForSelector('form[action="/kmsi"]',
                                               timeout=self._AWAIT_TIMEOUT)
                    await page.waitForSelector('#idBtn_Back')
                    await page.click('#idBtn_Back')

                if not self.saml_response:
                    # NOTE(review): the handler may already be registered
                    # above; this is a fallback registration.
                    page.on('request', _saml_response)
                    await page.setRequestInterception(True)

                # Poll for the intercepted assertion, aborting on form errors.
                wait_time = time.time() + self._MFA_TIMEOUT
                while time.time() < wait_time and not self.saml_response:
                    if await self._querySelector(page, '.has-error'):
                        raise FormError

                if not self.saml_response:
                    raise TimeoutError

            except (TimeoutError, BrowserError, FormError) as e:
                print(
                    'An error occurred while authenticating, check credentials.')
                print(e)
                if self._debug:
                    debugfile = 'aadaerror-{}.png'.format(
                        datetime.utcnow().strftime("%Y-%m-%dT%H%M%SZ"))
                    await page.screenshot({'path': debugfile})
                    print('See screenshot {} for clues.'.format(debugfile))
                exit(1)
        finally:
            await browser.close()

    @staticmethod
    def _get_aws_roles(saml_response):
        """Extract AWS role entries from a base64-encoded SAML response.

        Each entry is returned as ``"<role_arn>,<principal_arn>"``.  Entries
        whose first element is the ``saml-provider`` are swapped so that the
        role ARN always comes first.
        """
        saml_ns = '{urn:oasis:names:tc:SAML:2.0:assertion}'
        root = ET.fromstring(base64.b64decode(saml_response))

        aws_roles = []
        for attribute in root.iter(saml_ns + 'Attribute'):
            if (attribute.get('Name') ==
                    'https://aws.amazon.com/SAML/Attributes/Role'):
                for value in attribute.iter(saml_ns + 'AttributeValue'):
                    aws_roles.append(value.text)

        # Build a new list rather than inserting/removing while iterating.
        normalized = []
        for role in aws_roles:
            chunks = role.split(',')
            if 'saml-provider' in chunks[0]:
                role = chunks[1] + ',' + chunks[0]
            normalized.append(role)
        return normalized

    def _assume_role(self, role_arn, principal_arn, saml_response):
        """Exchange the SAML assertion for STS credentials."""
        return boto3.client('sts').assume_role_with_saml(
            RoleArn=role_arn,
            PrincipalArn=principal_arn,
            SAMLAssertion=saml_response,
            DurationSeconds=self._session_duration)

    def _save_credentials(self, credentials, role_arn):
        """Write the role ARN and STS credentials for the active profile."""
        self._set_config_value('aws_role_arn', role_arn)
        self._set_config_value('aws_access_key_id', credentials['AccessKeyId'])
        self._set_config_value('aws_secret_access_key',
                               credentials['SecretAccessKey'])
        self._set_config_value('aws_session_token',
                               credentials['SessionToken'])

    @staticmethod
    def _prompt_role_index(count_roles):
        """Prompt until the user enters an integer in ``1..count_roles``."""
        while True:
            try:
                selected_role = int(input('Selection: '))
            except ValueError:
                selected_role = None
            if selected_role is not None and 1 <= selected_role <= count_roles:
                return selected_role
            print('Invalid role index, please try again')

    @staticmethod
    def _choose_role(self, aws_roles):
        """Pick the ``(role_arn, principal_arn)`` pair to assume.

        This is declared as a staticmethod that receives the instance
        explicitly; ``_login`` calls it as ``self._choose_role(self, ...)``.

        A single role is returned directly.  When there are several roles and
        ``self._role`` is set, that value is used as the role ARN.  It is
        paired with the principal the SAML response lists for it, or with
        ``_DEFAULT_PRINCIPAL_ARN`` if the response does not list it.
        Otherwise the user is prompted to choose.

        :param aws_roles: non-empty list of ``"role_arn,principal_arn"``.
        :return: ``(role_arn, principal_arn)``.
        """
        pairs = [role.split(',') for role in aws_roles]
        if len(pairs) > 1:
            if self._role:
                for pair in pairs:
                    if pair[0] == self._role:
                        return self._role, pair[1]
                return self._role, Login._DEFAULT_PRINCIPAL_ARN

            for index, role in enumerate(aws_roles, start=1):
                print('[ {} ]: {}'.format(index, role.split(',')[0]))
            print('Choose the role you would like to assume:')
            selected_role = Login._prompt_role_index(len(pairs))
            chosen = pairs[selected_role - 1]
            return chosen[0], chosen[1]
        return pairs[0][0], pairs[0][1]

    @staticmethod
    def _post(session, url, data, headers):
        return json.loads(session.post(url, data=data, headers=headers).text)

    def _login(self):
        """Authenticate against Azure AD and store STS credentials.

        Probes the corporate network and obtains the password from the
        keyring (when enabled) or a prompt.  It then renders the sign-in form,
        assumes the chosen role, and saves the resulting credentials.

        :return: 0 on success; the process exits with status 1 on failure.
        """
        url = self._build_saml_login_url()
        username_input = self._azure_username
        profile = self._session.profile if self._session.profile else 'default'
        role_stored_in_config = self._role
        kr_pass = None

        # A 'BigIP' Server header from this page is treated as being outside
        # the RealPage network (VPN not connected).
        response = requests.get('http://myrealpageportal.com/',
                                timeout=self._NETWORK_CHECK_TIMEOUT)
        if response.headers.get('Server') == 'BigIP':
            print(
                'You are not connected the RealPage network. You might need to connect to the BigIP VPN. Exiting...'
            )
            exit(1)
        else:
            print('RealPage network detected...')

        print(f'\n[{color.OKGREEN}Azure AD AWS CLI Authentication{color.END}]')
        print(
            f'{color.BOLD}Profile:{color.END} {color.AQUA}{profile}{color.END}'
        )
        print(
            f'{color.BOLD}Role:{color.END} {color.AQUA}{role_stored_in_config}{color.END}'
        )
        print(
            f'{color.BOLD}Username:{color.END} {color.AQUA}{self._azure_username}{color.END}'
        )

        if KEYRING and self._use_keyring:
            try:
                kr_pass = keyring.get_password('aada', self._azure_username)
            except Exception as e:
                print('Failed getting password from Keyring {}'.format(e))

        if kr_pass is not None:
            password_input = kr_pass
        else:
            password_input = getpass.getpass(
                f'{color.BOLD}Password:{color.END} ')

        print('-------------------------------------------------------------')
        print('Logging in...')

        asyncio.get_event_loop().run_until_complete(
            self._render_js_form(url, username_input, password_input,
                                 self._azure_mfa))

        if not self.saml_response:
            print('Something went wrong! No roles found!')
            exit(1)
        aws_roles = self._get_aws_roles(self.saml_response)
        if not aws_roles:
            # Guard against an assertion without any role attribute, which
            # previously crashed with IndexError in _choose_role.
            print('Something went wrong! No roles found!')
            exit(1)
        role_arn, principal = self._choose_role(self, aws_roles)

        role_name = role_arn.split('/')[-1]

        print(f'{color.OKGREEN}Assuming role:{color.END} {role_name}')
        sts_token = self._assume_role(role_arn, principal, self.saml_response)
        credentials = sts_token['Credentials']
        self._save_credentials(credentials, role_arn)

        expiration = credentials['Expiration']
        local_expiration = expiration.replace(
            tzinfo=tz.gettz('UTC')).astimezone(tz.tzlocal())
        remaining = expiration - datetime.utcnow().replace(tzinfo=pytz.UTC)
        hours_left = int(remaining.total_seconds() / 60 / 60) + 1
        print(
            f'{color.OKGREEN}Expiration:{color.END} {local_expiration:%Y-%m-%d %H:%M:%S} ( {hours_left} hours )'
        )
        print(
            '-------------------------------------------------------------\n')
        return 0
示例#20
0
class Login:
    """Obtain temporary AWS credentials through an Azure AD SAML sign-in.

    The Azure AD login pages are driven with pyppeteer (``launch``).  The SAML
    assertion POSTed to ``https://signin.aws.amazon.com/saml`` is intercepted,
    exchanged for STS credentials with ``assume_role_with_saml``, and written
    to the AWS CLI config/credentials files via ``ConfigFileWriter``.
    """

    # SAML AuthnRequest template; ``{id}``, ``{date}`` and ``{app_id}`` are
    # substituted in ``_build_saml_login_url``.
    _SAML_REQUEST = \
        '<samlp:AuthnRequest xmlns="urn:oasis:names:tc:SAML:2.0:metadata" xml' \
        'ns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" ID="id_{id}" Version' \
        '="2.0" IsPassive="false" IssueInstant="{date}" AssertionConsumerServ' \
        'iceURL="https://signin.aws.amazon.com/saml"><Issuer xmlns="urn:oasis' \
        ':names:tc:SAML:2.0:assertion">{app_id}</Issuer><samlp:NameIDPolicy F' \
        'ormat="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"/></sa' \
        'mlp:AuthnRequest>'

    _BEGIN_AUTH_URL = '{url}/common/SAS/BeginAuth'
    _END_AUTH_URL = '{url}/common/SAS/EndAuth'
    _PROCESS_AUTH_URL = '{url}/common/SAS/ProcessAuth'
    _SAML_URL = '{url}/{tenant_id}/saml2?SAMLRequest={saml_request}'
    _REFERER = '{url}/{tenant_id}/login'

    # Keys written to the credentials file instead of the config file.
    _CREDENTIALS = [
        'aws_access_key_id', 'aws_secret_access_key', 'aws_session_token'
    ]
    _MFA_DELAY = 3
    _MFA_TIMEOUT = 60  # timeout in seconds to process MFA
    _AWAIT_TIMEOUT = 30000  # ms; timeout for the KMSI form selector
    _SLEEP_TIMEOUT = 500  # ms; pause between querySelector retries
    _EXEC_PATH = os.environ.get('CHROME_EXECUTABLE_PATH')
    _RETRIES = 5  # querySelector retries after a NetworkError
    # SAML IssueInstant format: a UTC timestamp ('Z'); '%M' is minutes.
    _ISSUE_INSTANT_FORMAT = '%Y-%m-%dT%H:%M:%SZ'

    def __init__(self,
                 session,
                 role=None,
                 account=None,
                 debug=False,
                 headless=True,
                 saml_request=None):
        """Read the Azure AD / AWS login settings from the scoped config.

        :param session: session providing ``get_scoped_config``,
            ``get_config_variable`` and ``profile``.
        :param role: role name to select when several roles are available.
        :param account: AWS account id that, together with ``role``,
            identifies the role to select.
        :param debug: save a screenshot when authentication fails.
        :param headless: launch Chromium without a visible window.
        :param saml_request: optional replacement for ``_SAML_REQUEST``.
        """
        self._session = session
        self._role = role
        self._account = account
        self._debug = debug
        self._headless = headless
        self._config = self._session.get_scoped_config()
        self._config_writer = ConfigFileWriter()
        self._azure_tenant_id = self._config.get('azure_tenant_id')
        self._azure_app_id_uri = self._config.get('azure_app_id_uri')
        self._azure_mfa = self._config.get('azure_mfa')
        self._azure_kmsi = self._config.get('azure_kmsi', False)
        self._azure_username = self._config.get('azure_username')
        self._azure_password = None
        self._session_duration = int(self._config.get('session_duration',
                                                      3600))
        self._use_keyring = self._config.get('use_keyring')
        # Set by the request-interception handler in ``_render_js_form``.
        self.saml_response = None

        if saml_request:
            self._SAML_REQUEST = saml_request

    def __call__(self):
        return self._login()

    def _set_config_value(self, key, value):
        """Persist ``key = value`` for the active profile.

        Keys listed in ``_CREDENTIALS`` go to the credentials file, whose
        sections carry no ``profile `` prefix; all others go to the config
        file.
        """
        section = 'default'
        if self._session.profile is not None:
            section = 'profile {}'.format(self._session.profile)

        config_filename = os.path.expanduser(
            self._session.get_config_variable('config_file'))
        updated_config = {'__section__': section, key: value}

        if key in self._CREDENTIALS:
            config_filename = os.path.expanduser(
                self._session.get_config_variable('credentials_file'))
            section_name = updated_config['__section__']
            if section_name.startswith('profile '):
                updated_config['__section__'] = section_name[len('profile '):]
        self._config_writer.update_config(updated_config, config_filename)

    @classmethod
    def _issue_instant(cls):
        """Return the current UTC time formatted for the SAML IssueInstant."""
        # The template labels the instant as UTC ('Z'), so local time would be
        # wrong; the old format also put '%m' (month) where minutes belong.
        return datetime.utcnow().strftime(cls._ISSUE_INSTANT_FORMAT)

    def _build_saml_login_url(self):
        """Return the Azure AD SAML sign-in URL for a fresh AuthnRequest.

        The request is raw-DEFLATE encoded (zlib output without its 2-byte
        header and 4-byte Adler-32 trailer), base64-encoded, and URL-quoted.
        """
        authn_request = self._SAML_REQUEST.strip().format(
            date=self._issue_instant(),
            tenant_id=self._azure_tenant_id,
            id=uuid.uuid4(),
            app_id=self._azure_app_id_uri)
        deflated = zlib.compress(authn_request.encode('ascii'))[2:-4]
        saml_request = base64.b64encode(deflated).decode()
        return self._SAML_URL.format(url=LOGIN_URL,
                                     tenant_id=self._azure_tenant_id,
                                     saml_request=quote(saml_request))

    @classmethod
    async def _querySelector(cls, page, element, retries=0):
        """``page.querySelector`` that retries on ``NetworkError``.

        Waits ``_SLEEP_TIMEOUT`` ms between attempts and raises
        ``TimeoutError`` after more than ``_RETRIES`` retries.
        """
        if retries > cls._RETRIES:
            raise TimeoutError
        try:
            return await page.querySelector(element)
        except NetworkError:
            await page.waitFor(cls._SLEEP_TIMEOUT)
            return await cls._querySelector(page, element, retries + 1)

    async def _render_js_form(self, url, username, password, mfa=None):
        """Drive the Azure AD sign-in pages and capture the SAML response.

        Enters the username and password, optionally handles MFA and KMSI,
        then intercepts requests until ``self.saml_response`` is populated or
        ``_MFA_TIMEOUT`` seconds pass.  On an authentication error the process
        exits with status 1.  The browser is closed in every case.

        :param url: sign-in URL from ``_build_saml_login_url``.
        :param username: Azure AD username.
        :param password: Azure AD password.
        :param mfa: truthy to run the MFA step.
        """
        browser = await launch(executablePath=self._EXEC_PATH,
                               headless=self._headless)

        async def _saml_response(req):
            # Capture the assertion POSTed to AWS and answer it locally with
            # an empty 200 instead of forwarding it; pass through all others.
            if req.url == 'https://signin.aws.amazon.com/saml':
                self.saml_response = parse_qs(req.postData)['SAMLResponse'][0]
                await req.respond({
                    'status': 200,
                    'contentType': 'text/plain',
                    'body': ''
                })
            else:
                await req.continue_()

        # Outer try/finally so the browser is closed even if the username or
        # password steps below fail before the authentication block.
        try:
            pages = await browser.pages()
            page = pages[0]

            await page.goto(url, waitUntil='domcontentloaded')
            await page.waitForSelector(
                'input[name="loginfmt"]:not(.moveOffScreen)',
                {"visible": True})
            await page.focus('input[name="loginfmt"]')
            await page.keyboard.type(username)
            await page.click('input[type=submit]')
            await page.waitForSelector(
                'input[name="passwd"]:not(.moveOffScreen)', {"visible": True})
            await page.focus('input[name="passwd"]')
            await page.keyboard.type(password)
            await page.click('input[type=submit]')

            try:
                if await self._querySelector(page, '.has-error'):
                    raise FormError

                if mfa:
                    if self._azure_mfa not in MFA_WAIT_METHODS:
                        # Code-entry MFA: type the one-time code from stdin.
                        await page.waitForSelector(
                            'input[name="otc"]:not(.moveOffScreen)',
                            {"visible": True})
                        await page.focus('input[name="otc"]')
                        mfa_token = input('Azure MFA Token: ')
                        for char in mfa_token:
                            await page.keyboard.sendCharacter(char)
                        await page.click('input[type=submit]')
                    else:
                        print('Processing MFA authentication...')

                if self._azure_kmsi:
                    await page.waitForSelector('form[action="/kmsi"]',
                                               timeout=self._AWAIT_TIMEOUT)
                    await page.waitForSelector('#idBtn_Back')
                    await page.click('#idBtn_Back')

                page.on('request', _saml_response)
                await page.setRequestInterception(True)

                # Poll for the intercepted assertion, aborting on form errors.
                wait_time = time.time() + self._MFA_TIMEOUT
                while time.time() < wait_time and not self.saml_response:
                    if await self._querySelector(page, '.has-error'):
                        raise FormError

                if not self.saml_response:
                    raise TimeoutError

            except (TimeoutError, BrowserError, FormError) as e:
                print(
                    'An error occurred while authenticating, check credentials.')
                print(e)
                if self._debug:
                    debugfile = 'aadaerror-{}.png'.format(
                        datetime.utcnow().strftime("%Y-%m-%dT%H%M%SZ"))
                    await page.screenshot({'path': debugfile})
                    print('See screenshot {} for clues.'.format(debugfile))
                exit(1)
        finally:
            await browser.close()

    @staticmethod
    def _get_aws_roles(saml_response):
        """Extract AWS role entries from a base64-encoded SAML response.

        Each entry is returned as ``"<role_arn>,<principal_arn>"``.  Entries
        whose first element is the ``saml-provider`` are swapped so that the
        role ARN always comes first.
        """
        saml_ns = '{urn:oasis:names:tc:SAML:2.0:assertion}'
        root = ET.fromstring(base64.b64decode(saml_response))

        aws_roles = []
        for attribute in root.iter(saml_ns + 'Attribute'):
            if (attribute.get('Name') ==
                    'https://aws.amazon.com/SAML/Attributes/Role'):
                for value in attribute.iter(saml_ns + 'AttributeValue'):
                    aws_roles.append(value.text)

        # Build a new list rather than inserting/removing while iterating.
        normalized = []
        for role in aws_roles:
            chunks = role.split(',')
            if 'saml-provider' in chunks[0]:
                role = chunks[1] + ',' + chunks[0]
            normalized.append(role)
        return normalized

    def _assume_role(self, role_arn, principal_arn, saml_response):
        """Exchange the SAML assertion for STS credentials."""
        return boto3.client('sts').assume_role_with_saml(
            RoleArn=role_arn,
            PrincipalArn=principal_arn,
            SAMLAssertion=saml_response,
            DurationSeconds=self._session_duration)

    def _save_credentials(self, credentials, role_arn):
        """Write the role ARN and STS credentials for the active profile."""
        self._set_config_value('aws_role_arn', role_arn)
        self._set_config_value('aws_access_key_id', credentials['AccessKeyId'])
        self._set_config_value('aws_secret_access_key',
                               credentials['SecretAccessKey'])
        self._set_config_value('aws_session_token',
                               credentials['SessionToken'])

    @staticmethod
    def _prompt_role_index(count_roles):
        """Prompt until the user enters an integer in ``1..count_roles``."""
        while True:
            try:
                selected_role = int(input('Selection: '))
            except ValueError:
                selected_role = None
            if selected_role is not None and 1 <= selected_role <= count_roles:
                return selected_role
            print('Invalid role index, please try again')

    @staticmethod
    def _choose_role(self, aws_roles):
        """Pick the ``(role_arn, principal_arn)`` pair to assume.

        This is declared as a staticmethod that receives the instance
        explicitly; ``_login`` calls it as ``self._choose_role(self, ...)``.

        A single role is returned directly.  When there are several roles and
        ``self._role`` is set, the entry is selected whose role name (last
        ``/`` component of the ARN) equals ``self._role`` and whose account
        (5th ``:`` field) equals ``self._account``.  If no entry matches, a
        warning is printed and the user is prompted instead of silently
        taking the first role.

        :param aws_roles: non-empty list of ``"role_arn,principal_arn"``.
        :return: ``(role_arn, principal_arn)``.
        """
        pairs = [role.split(',') for role in aws_roles]
        if len(pairs) > 1:
            if self._role:
                wanted_account = str(self._account)
                for pair in pairs:
                    role_arn = pair[0]
                    # '[-1]' so role ARNs with a path (role/path/Name) match.
                    role_name = role_arn.split('/')[-1]
                    account_id = role_arn.split(':')[4]
                    if role_name == self._role and account_id == wanted_account:
                        return pair[0], pair[1]
                print('Role {} in account {} was not found in the SAML '
                      'response.'.format(self._role, self._account))

            for index, role in enumerate(aws_roles, start=1):
                print('[ {} ]: {}'.format(index, role.split(',')[0]))
            print('Choose the role you would like to assume:')
            selected_role = Login._prompt_role_index(len(pairs))
            chosen = pairs[selected_role - 1]
            return chosen[0], chosen[1]
        return pairs[0][0], pairs[0][1]

    @staticmethod
    def _post(session, url, data, headers):
        return json.loads(session.post(url, data=data, headers=headers).text)

    def _login(self):
        """Authenticate against Azure AD and store STS credentials.

        Obtains the password from the keyring (when enabled) or a prompt.  It
        then renders the sign-in form, assumes the chosen role, and saves the
        resulting credentials.

        :return: 0 on success; the process exits with status 1 on failure.
        """
        url = self._build_saml_login_url()
        username_input = self._azure_username
        kr_pass = None
        print('Azure username: {}'.format(self._azure_username))

        if KEYRING and self._use_keyring:
            try:
                print('Getting password from keyring')
                kr_pass = keyring.get_password('aada', self._azure_username)
            except Exception as e:
                print('Failed getting password from Keyring {}'.format(e))

        if kr_pass is not None:
            password_input = kr_pass
        else:
            password_input = getpass.getpass('Azure password: ')

        asyncio.get_event_loop().run_until_complete(
            self._render_js_form(url, username_input, password_input,
                                 self._azure_mfa))

        if not self.saml_response:
            print('Something went wrong!')
            exit(1)
        aws_roles = self._get_aws_roles(self.saml_response)
        if not aws_roles:
            # Guard against an assertion without any role attribute, which
            # previously crashed with IndexError in _choose_role.
            print('Something went wrong!')
            exit(1)
        role_arn, principal = self._choose_role(self, aws_roles)

        print('Assuming AWS Role: {}'.format(role_arn))
        sts_token = self._assume_role(role_arn, principal, self.saml_response)
        credentials = sts_token['Credentials']
        self._save_credentials(credentials, role_arn)
        profile = self._session.profile if self._session.profile else 'default'

        print(
            '\n-------------------------------------------------------------')
        print('Your access key pair has been stored in the AWS configuration\n'
              'file under the {} profile.'.format(profile))
        print('Credentials expires at {:%Y-%m-%d %H:%M:%S}.'.format(
            credentials['Expiration']))
        print(
            '-------------------------------------------------------------\n')
        return 0
示例#21
0
 def __init__(self, session, config_writer=None):
     super(ConfigureSetCommand, self).__init__(session)
     if config_writer is None:
         config_writer = ConfigFileWriter()
     self._config_writer = config_writer