Example #1
0
    def deploy_ssh_pubkey(self, username, pubkey):
        """
        Install an SSH public key file for ``username``.

        ``pubkey`` is a ``(path, thumbprint, value)`` triple. When ``value``
        is present it is written to ``path`` unchanged; otherwise the
        certificate ``<thumbprint>.crt`` from the agent lib directory is used
        to produce the key via ``self.openssl_to_openssh``. The resulting
        file is chowned to ``username`` and set to mode 0644.

        Raises OSUtilError when the path is None, when ``value`` does not
        start with "ssh-", when the certificate file is missing, or when
        both ``value`` and ``thumbprint`` are None.
        """
        ssh_context = 'unconfined_u:object_r:ssh_home_t:s0'

        key_path, thumbprint, value = pubkey
        if key_path is None:
            raise OSUtilError("Public key path is None")

        crypt_util = CryptUtil(conf.get_openssl_cmd())

        key_path = self._norm_path(key_path)
        # The containing directory is created before the key data is checked.
        fileutil.mkdir(os.path.dirname(key_path), mode=0o700, owner=username)

        if value is not None:
            # A literal key was supplied: accept it only if it looks like one.
            if not value.startswith("ssh-"):
                raise OSUtilError("Bad public key: {0}".format(value))
            fileutil.write_file(key_path, value)
        elif thumbprint is None:
            raise OSUtilError("SSH public key Fingerprint and Value are None")
        else:
            # Derive the key from the certificate stored under the lib dir.
            lib_dir = conf.get_lib_dir()
            crt_path = os.path.join(lib_dir, thumbprint + '.crt')
            if not os.path.isfile(crt_path):
                raise OSUtilError("Can't find {0}.crt".format(thumbprint))
            pub_path = os.path.join(lib_dir, thumbprint + '.pub')
            extracted = crypt_util.get_pubkey_from_crt(crt_path)
            fileutil.write_file(pub_path, extracted)
            self.set_selinux_context(pub_path, ssh_context)
            self.openssl_to_openssh(pub_path, key_path)
            fileutil.chmod(pub_path, 0o600)

        self.set_selinux_context(key_path, ssh_context)
        fileutil.chowner(key_path, username)
        fileutil.chmod(key_path, 0o644)
Example #2
0
    def deploy_ssh_pubkey(self, username, pubkey):
        """
        Install an SSH public key file for ``username``.

        ``pubkey`` is a ``(path, thumbprint, value)`` triple. When ``value``
        is present it is written to ``path`` unchanged; otherwise the
        certificate ``<thumbprint>.crt`` from the agent lib directory is used
        to produce the key via ``self.openssl_to_openssh``. The resulting
        file is chowned to ``username`` and set to mode 0644.

        Raises OSUtilError when the path is None, when ``value`` does not
        start with "ssh-", when the certificate file is missing, or when
        both ``value`` and ``thumbprint`` are None.
        """
        ssh_context = 'unconfined_u:object_r:ssh_home_t:s0'

        key_path, thumbprint, value = pubkey
        if key_path is None:
            raise OSUtilError("Public key path is None")

        crypt_util = CryptUtil(conf.get_openssl_cmd())

        key_path = self._norm_path(key_path)
        # The containing directory is created before the key data is checked.
        fileutil.mkdir(os.path.dirname(key_path), mode=0o700, owner=username)

        if value is not None:
            # A literal key was supplied: accept it only if it looks like one.
            if not value.startswith("ssh-"):
                raise OSUtilError("Bad public key: {0}".format(value))
            fileutil.write_file(key_path, value)
        elif thumbprint is None:
            raise OSUtilError("SSH public key Fingerprint and Value are None")
        else:
            # Derive the key from the certificate stored under the lib dir.
            lib_dir = conf.get_lib_dir()
            crt_path = os.path.join(lib_dir, thumbprint + '.crt')
            if not os.path.isfile(crt_path):
                raise OSUtilError("Can't find {0}.crt".format(thumbprint))
            pub_path = os.path.join(lib_dir, thumbprint + '.pub')
            extracted = crypt_util.get_pubkey_from_crt(crt_path)
            fileutil.write_file(pub_path, extracted)
            self.set_selinux_context(pub_path, ssh_context)
            self.openssl_to_openssh(pub_path, key_path)
            fileutil.chmod(pub_path, 0o600)

        self.set_selinux_context(key_path, ssh_context)
        fileutil.chowner(key_path, username)
        fileutil.chmod(key_path, 0o644)
Example #3
0
 def __init__(self, protocol):
     """Store the given protocol object and initialise the handler state."""
     # Collaborators: OS utility layer, protocol, and the openssl wrapper.
     self.os_util = get_osutil()
     self.protocol = protocol
     self.cryptUtil = CryptUtil(conf.get_openssl_cmd())
     # Bookkeeping: nothing fetched yet, incarnation 0, no error text.
     self.remote_access, self.incarnation, self.error_message = None, 0, ""
Example #4
0
 def openssl_to_openssh(self, input_file, output_file):
     """
     Read ``input_file``, pass its content through ``CryptUtil.asn1_to_ssh``
     and write the result to ``output_file``.

     A CryptError raised while converting is re-raised as OSUtilError.
     """
     source_key = fileutil.read_file(input_file)
     try:
         converted = CryptUtil(conf.get_openssl_cmd()).asn1_to_ssh(source_key)
     except CryptError as err:
         raise OSUtilError(ustr(err))
     fileutil.write_file(output_file, converted)
Example #5
0
 def openssl_to_openssh(self, input_file, output_file):
     """
     Read ``input_file``, pass its content through ``CryptUtil.asn1_to_ssh``
     and append the result to ``output_file``.

     A CryptError raised while converting is re-raised as OSUtilError.
     """
     source_key = fileutil.read_file(input_file)
     try:
         converted = CryptUtil(conf.get_openssl_cmd()).asn1_to_ssh(source_key)
     except CryptError as err:
         raise OSUtilError(ustr(err))
     # Appending (not overwriting) keeps whatever the output file already holds.
     fileutil.append_file(output_file, converted)
Example #6
0
 def test_decrypt_encrypted_text(self):
     """decrypt_secret must return the known secret for the sample payload."""
     payload = load_data("wire/encrypted.enc")
     key_path = os.path.join(self.tmp_dir, "TransportPrivate.pem")
     # Materialise the sample private key in the test's temp directory.
     with open(key_path, 'w+') as key_file:
         key_file.write(load_data("wire/sample.pem"))
     expected = ']aPPEv}uNg1FPnl?'
     actual = CryptUtil(conf.get_openssl_cmd()).decrypt_secret(payload,
                                                               key_path)
     self.assertEqual(expected, actual,
                      "decrypted string does not match expected")
Example #7
0
    def detect(self):
        """
        Check the wire protocol version, have CryptUtil generate the
        transport certificate files under the lib directory, then force a
        goal state update.
        """
        self.client.check_wire_protocol_version()

        prv_path = os.path.join(conf.get_lib_dir(), TRANSPORT_PRV_FILE_NAME)
        cert_path = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME)
        CryptUtil(conf.get_openssl_cmd()).gen_transport_cert(prv_path,
                                                             cert_path)

        self.client.update_goal_state(forced=True)
Example #8
0
    def detect(self):
        """
        Check the wire protocol version, have CryptUtil generate the
        transport certificate files under the lib directory, then force a
        goal state update.
        """
        self.client.check_wire_protocol_version()

        prv_path = os.path.join(conf.get_lib_dir(), TRANSPORT_PRV_FILE_NAME)
        cert_path = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME)
        CryptUtil(conf.get_openssl_cmd()).gen_transport_cert(prv_path,
                                                             cert_path)

        self.client.update_goal_state(forced=True)
Example #9
0
 def __init__(self):
     """Create the helper objects and reset the handler state."""
     # Helper objects are built in a fixed order.
     self.os_util = get_osutil()
     self.protocol_util = get_protocol_util()
     self.cryptUtil = CryptUtil(conf.get_openssl_cmd())
     # No protocol or remote access data yet; incarnation starts at 0.
     self.protocol = self.remote_access = None
     self.incarnation = 0
Example #10
0
 def test_decrypt_encrypted_text_text_not_encrypted(self):
     """decrypt_secret must raise CryptError for text that was never encrypted."""
     plain_text = "abc@123"
     key_path = os.path.join(self.tmp_dir, "TransportPrivate.pem")
     with open(key_path, 'w+') as key_file:
         key_file.write(load_data("wire/sample.pem"))
     decrypt = CryptUtil(conf.get_openssl_cmd()).decrypt_secret
     self.assertRaises(CryptError, decrypt, plain_text, key_path)
Example #11
0
    def detect(self):
        """
        Fetch the VM info, generate the transport certificate files and
        copy both of them into the lib directory under names derived from
        the certificate thumbprint (``<thumbprint>.prv`` / ``.crt``).
        """
        self.get_vminfo()
        transport_key = os.path.join(conf.get_lib_dir(),
                                     TRANSPORT_PRV_FILE_NAME)
        transport_cert = os.path.join(conf.get_lib_dir(),
                                      TRANSPORT_CERT_FILE_NAME)
        crypt_util = CryptUtil(conf.get_openssl_cmd())
        crypt_util.gen_transport_cert(transport_key, transport_cert)

        # "Install" the cert and private key to /var/lib/waagent
        thumbprint = crypt_util.get_thumbprint_from_crt(transport_cert)
        installed_key = os.path.join(conf.get_lib_dir(),
                                     "{0}.prv".format(thumbprint))
        installed_cert = os.path.join(conf.get_lib_dir(),
                                      "{0}.crt".format(thumbprint))
        shutil.copyfile(transport_key, installed_key)
        shutil.copyfile(transport_cert, installed_cert)
Example #12
0
    def detect(self):
        """
        Fetch the VM info, generate the transport certificate files and
        copy both of them into the lib directory under names derived from
        the certificate thumbprint (``<thumbprint>.prv`` / ``.crt``).
        """
        self.get_vminfo()
        transport_key = os.path.join(conf.get_lib_dir(),
                                     TRANSPORT_PRV_FILE_NAME)
        transport_cert = os.path.join(conf.get_lib_dir(),
                                      TRANSPORT_CERT_FILE_NAME)
        crypt_util = CryptUtil(conf.get_openssl_cmd())
        crypt_util.gen_transport_cert(transport_key, transport_cert)

        # "Install" the cert and private key to /var/lib/waagent
        thumbprint = crypt_util.get_thumbprint_from_crt(transport_cert)
        installed_key = os.path.join(conf.get_lib_dir(),
                                     "{0}.prv".format(thumbprint))
        installed_cert = os.path.join(conf.get_lib_dir(),
                                      "{0}.crt".format(thumbprint))
        shutil.copyfile(transport_key, installed_key)
        shutil.copyfile(transport_cert, installed_cert)
Example #13
0
 def test_decrypt_encrypted_text_wrong_private_key(self):
     """decrypt_secret must raise CryptError when given a non-matching key."""
     payload = load_data("wire/encrypted.enc")
     wrong_key_path = os.path.join(self.tmp_dir, "wrong.pem")
     # Deliberately write a different private key than the payload expects.
     with open(wrong_key_path, 'w+') as key_file:
         key_file.write(load_data("wire/trans_prv"))
     decrypt = CryptUtil(conf.get_openssl_cmd()).decrypt_secret
     self.assertRaises(CryptError, decrypt, payload, wrong_key_path)
Example #14
0
    def openssl_to_openssh(self, input_file, output_file):
        """
        Convert an OpenSSL RSA public key file into an ``ssh-rsa`` key line.

        Runs ``openssl rsa -pubin -noout -text`` on ``input_file``, extracts
        the modulus and exponent from the textual dump, packs them into the
        length-prefixed ``ssh-rsa`` blob and writes ``ssh-rsa <base64>`` plus
        a newline to ``output_file``.

        Raises OSUtilError when the openssl command returns non-zero.
        """
        crypt_util = CryptUtil(conf.get_openssl_cmd())
        command = (conf.get_openssl_cmd() +
                   " rsa -pubin -noout -text -in '" + input_file + "'")
        status, output = shellutil.run_get_output(command)
        if status != 0:
            raise OSUtilError('openssl failed with {0}'.format(status))

        # Collect each section: its header line first, then any following
        # non-empty lines with surrounding whitespace and colons removed.
        modulus_lines = []
        exponent_lines = []
        current = None
        for text in output.split('\n'):
            if text.startswith('Modulus:'):
                current = modulus_lines
                current.append(text)
            elif text.startswith('Exponent:'):
                current = exponent_lines
                current.append(text)
            elif current and text:
                current.append(text.strip().replace(':', ''))

        def section_to_number(section):
            # A header-only section carries the value as its second token;
            # otherwise the continuation lines form one hex number.
            if len(section) == 1:
                return int(section[0].split()[1])
            return long(''.join(section[1:]), 16)

        n = section_to_number(modulus_lines)
        e = section_to_number(exponent_lines)

        # Pieces of the key blob, in wire order. Each field is preceded by
        # its big-endian 32-bit length; the modulus gets one extra leading
        # zero byte, which is why its length is bumped by one.
        pieces = [
            struct.pack('>I', len('ssh-rsa')),
            b'ssh-rsa',
            struct.pack('>I', len(crypt_util.num_to_bytes(e))),
            crypt_util.num_to_bytes(e),
            struct.pack('>I', len(crypt_util.num_to_bytes(n)) + 1),
            b'\0',
            crypt_util.num_to_bytes(n),
        ]
        keydata = bytearray()
        for piece in pieces:
            keydata.extend(piece)

        encoded = base64.b64encode(bytebuffer(keydata))
        key_line = ustr(b'ssh-rsa ' + encoded + b'\n', encoding='utf-8')
        fileutil.write_file(output_file, key_line)
Example #15
0
 def deploy_ssh_keypair(self, username, keypair):
     """
     Deploy id_rsa and id_rsa.pub

     ``keypair`` is a ``(path, thumbprint)`` pair. The private key
     ``<thumbprint>.prv`` from the agent lib directory is copied to
     ``path`` and the public key obtained from it is written to
     ``path + '.pub'``. Both files are chowned to ``username``; the
     private key ends up with mode 0600 and the public key with 0644.

     Raises OSUtilError when ``<thumbprint>.prv`` does not exist.
     """
     path, thumbprint = keypair
     path = self._norm_path(path)
     dir_path = os.path.dirname(path)
     fileutil.mkdir(dir_path, mode=0o700, owner=username)
     lib_dir = conf.get_lib_dir()
     prv_path = os.path.join(lib_dir, thumbprint + '.prv')
     if not os.path.isfile(prv_path):
         raise OSUtilError("Can't find {0}.prv".format(thumbprint))
     shutil.copyfile(prv_path, path)
     # Restrict the private key immediately after the copy. The modes used
     # to be swapped (0644 on the private key, 0600 on the public one),
     # which left the private key readable by every local user.
     os.chmod(path, 0o600)
     pub_path = path + '.pub'
     crytputil = CryptUtil(conf.get_openssl_cmd())
     pub = crytputil.get_pubkey_from_prv(prv_path)
     fileutil.write_file(pub_path, pub)
     self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0')
     self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0')
     # With the private key at 0600 the target user must own it to be able
     # to read it; this mirrors what deploy_ssh_pubkey does for its file.
     fileutil.chowner(path, username)
     fileutil.chowner(pub_path, username)
     os.chmod(pub_path, 0o644)
Example #16
0
 def deploy_ssh_keypair(self, username, keypair):
     """
     Deploy id_rsa and id_rsa.pub

     ``keypair`` is a ``(path, thumbprint)`` pair. The private key
     ``<thumbprint>.prv`` from the agent lib directory is copied to
     ``path`` and the public key obtained from it is written to
     ``path + '.pub'``. Both files are chowned to ``username``; the
     private key ends up with mode 0600 and the public key with 0644.

     Raises OSUtilError when ``<thumbprint>.prv`` does not exist.
     """
     path, thumbprint = keypair
     path = self._norm_path(path)
     dir_path = os.path.dirname(path)
     fileutil.mkdir(dir_path, mode=0o700, owner=username)
     lib_dir = conf.get_lib_dir()
     prv_path = os.path.join(lib_dir, thumbprint + '.prv')
     if not os.path.isfile(prv_path):
         raise OSUtilError("Can't find {0}.prv".format(thumbprint))
     shutil.copyfile(prv_path, path)
     # Restrict the private key immediately after the copy. The modes used
     # to be swapped (0644 on the private key, 0600 on the public one),
     # which left the private key readable by every local user.
     os.chmod(path, 0o600)
     pub_path = path + '.pub'
     crytputil = CryptUtil(conf.get_openssl_cmd())
     pub = crytputil.get_pubkey_from_prv(prv_path)
     fileutil.write_file(pub_path, pub)
     self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0')
     self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0')
     # With the private key at 0600 the target user must own it to be able
     # to read it; this mirrors what deploy_ssh_pubkey does for its file.
     fileutil.chowner(path, username)
     fileutil.chowner(pub_path, username)
     os.chmod(pub_path, 0o644)
    def openssl_to_openssh(self, input_file, output_file):
        """
        Convert an OpenSSL RSA public key file into an ``ssh-rsa`` key line.

        Runs ``openssl rsa -pubin -noout -text`` on ``input_file``, extracts
        the modulus and exponent from the textual dump, packs them into the
        length-prefixed ``ssh-rsa`` blob and writes ``ssh-rsa <base64>`` plus
        a newline to ``output_file``.

        Raises OSUtilError when the openssl command returns non-zero.
        """
        crypt_util = CryptUtil(conf.get_openssl_cmd())
        command = (conf.get_openssl_cmd() +
                   " rsa -pubin -noout -text -in '" + input_file + "'")
        status, output = shellutil.run_get_output(command)
        if status != 0:
            raise OSUtilError('openssl failed with {0}'.format(status))

        # Collect each section: its header line first, then any following
        # non-empty lines with surrounding whitespace and colons removed.
        modulus_lines = []
        exponent_lines = []
        current = None
        for text in output.split('\n'):
            if text.startswith('Modulus:'):
                current = modulus_lines
                current.append(text)
            elif text.startswith('Exponent:'):
                current = exponent_lines
                current.append(text)
            elif current and text:
                current.append(text.strip().replace(':', ''))

        def section_to_number(section):
            # A header-only section carries the value as its second token;
            # otherwise the continuation lines form one hex number.
            if len(section) == 1:
                return int(section[0].split()[1])
            return long(''.join(section[1:]), 16)

        n = section_to_number(modulus_lines)
        e = section_to_number(exponent_lines)

        # Pieces of the key blob, in wire order. Each field is preceded by
        # its big-endian 32-bit length; the modulus gets one extra leading
        # zero byte, which is why its length is bumped by one.
        pieces = [
            struct.pack('>I', len('ssh-rsa')),
            b'ssh-rsa',
            struct.pack('>I', len(crypt_util.num_to_bytes(e))),
            crypt_util.num_to_bytes(e),
            struct.pack('>I', len(crypt_util.num_to_bytes(n)) + 1),
            b'\0',
            crypt_util.num_to_bytes(n),
        ]
        keydata = bytearray()
        for piece in pieces:
            keydata.extend(piece)

        encoded = base64.b64encode(bytebuffer(keydata))
        key_line = ustr(b'ssh-rsa ' + encoded + b'\n', encoding='utf-8')
        fileutil.write_file(output_file, key_line)
Example #18
0
 def openssl_to_openssh(self, input_file, output_file):
     """Delegate the conversion of ``input_file`` into ``output_file`` to
     ``CryptUtil.crt_to_ssh``, using the configured openssl command."""
     CryptUtil(conf.get_openssl_cmd()).crt_to_ssh(input_file, output_file)
Example #19
0
    def parse(self, xml_text):
        """
        Parse multiple certificates into separate files.

        The ``Data`` element of ``xml_text`` is wrapped in a MIME envelope,
        cached as the p7m file and decrypted into a single PEM file with the
        transport key/certificate. The PEM file is then split: every
        certificate is stored as ``<thumbprint>.crt`` and every private key
        whose public key matches a certificate as ``<thumbprint>.prv``, both
        in the lib directory. One ``Cert`` per certificate is appended to
        ``self.cert_list.certificates``.

        Returns None in all cases; returns early when there is no ``Data``
        element.
        """
        xml_doc = parse_doc(xml_text)
        data = findtext(xml_doc, "Data")
        if data is None:
            return

        cryptutil = CryptUtil(conf.get_openssl_cmd())
        p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
        p7m = ("MIME-Version:1.0\n"
               "Content-Disposition: attachment; filename=\"{0}\"\n"
               "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
               "Content-Transfer-Encoding: base64\n"
               "\n"
               "{2}").format(p7m_file, p7m_file, data)

        self.client.save_cache(p7m_file, p7m)

        trans_prv_file = os.path.join(conf.get_lib_dir(),
                                      TRANSPORT_PRV_FILE_NAME)
        trans_cert_file = os.path.join(conf.get_lib_dir(),
                                       TRANSPORT_CERT_FILE_NAME)
        pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
        # decrypt certificates
        cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file,
                              pem_file)

        # The parsing process uses the public key to match prv and crt.
        buf = []           # lines of the PEM section currently being read
        prvs = {}          # public key -> temp file holding the private key
        thumbprints = {}   # public key -> certificate thumbprint
        index = 0
        v1_cert_list = []
        with open(pem_file) as pem:
            for line in pem:
                buf.append(line)
                # Only END markers trigger work; BEGIN lines and body lines
                # are simply accumulated in buf.
                if re.match(r'[-]+END.*KEY[-]+', line):
                    tmp_file = self.write_to_tmp_file(index, 'prv', buf)
                    pub = cryptutil.get_pubkey_from_prv(tmp_file)
                    prvs[pub] = tmp_file
                    buf = []
                    index += 1
                elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
                    tmp_file = self.write_to_tmp_file(index, 'crt', buf)
                    pub = cryptutil.get_pubkey_from_crt(tmp_file)
                    thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
                    thumbprints[pub] = thumbprint
                    # Rename crt with thumbprint as the file name
                    crt = "{0}.crt".format(thumbprint)
                    v1_cert_list.append({
                        "name": None,
                        "thumbprint": thumbprint
                    })
                    os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
                    buf = []
                    index += 1

        # Rename prv key with thumbprint as the file name
        for pubkey, tmp_file in prvs.items():
            # A private key without a matching certificate has no thumbprint.
            # Look it up with .get() so such a key is skipped (its temp file
            # is left untouched) instead of aborting with KeyError.
            thumbprint = thumbprints.get(pubkey)
            if thumbprint:
                prv = "{0}.prv".format(thumbprint)
                os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))

        for v1_cert in v1_cert_list:
            cert = Cert()
            set_properties("certs", cert, v1_cert)
            self.cert_list.certificates.append(cert)
Example #20
0
class RemoteAccessHandler(object):
    """
    Adds and removes local JIT (remote access) user accounts so that they
    follow the remote access user list obtained from the protocol client.
    """

    def __init__(self, protocol):
        self._os_util = get_osutil()
        self._protocol = protocol
        self._cryptUtil = CryptUtil(conf.get_openssl_cmd())
        self._remote_access = None
        self._incarnation = 0
        # True while local accounts still need to be scanned for JIT users.
        self._check_existing_jit_users = True

    def run(self):
        """
        Process remote access users when JIT is enabled and the incarnation
        changed since the previous call. Any exception is reported through
        add_event and not propagated.
        """
        try:
            if not self._os_util.jit_enabled:
                return
            latest_incarnation = self._protocol.get_incarnation()
            if self._incarnation == latest_incarnation:
                return
            # something changed. Handle remote access if any.
            self._incarnation = latest_incarnation
            self._remote_access = self._protocol.client.get_remote_access()
            self._handle_remote_access()
        except Exception as e:
            msg = u"Exception processing goal state for remote access users: {0} {1}".format(
                ustr(e), traceback.format_exc())
            add_event(AGENT_NAME,
                      version=CURRENT_VERSION,
                      op=WALAEventOperation.RemoteAccessHandling,
                      is_success=False,
                      message=msg)

    def _get_existing_jit_users(self):
        """Return the set of local user names whose comment marks them as JIT users."""
        return set(entry[0] for entry in self._os_util.get_users()
                   if self._is_jit_user(entry[4]))

    def _handle_remote_access(self):
        """
        Reconcile local JIT accounts with ``self._remote_access``.

        Without a remote access section every existing JIT account is
        removed (the scan is skipped once a previous pass succeeded).
        Otherwise unexpired users that are missing locally are added, and
        accounts that expired or are absent from the list are removed.
        Per-user failures are logged and do not stop the remaining users.
        """
        if self._remote_access is None:
            # There are no JIT users in the goal state; that may mean that they were removed or that they
            # were never added. Enumerating the users on the current vm can be very slow and this path is hit
            # on each goal state; we use self._check_existing_jit_users to avoid enumerating the users
            # every single time.
            if not self._check_existing_jit_users:
                return

            logger.info("Looking for existing remote access users.")

            all_removed = True
            for user in self._get_existing_jit_users():
                try:
                    self._remove_user(user)
                except Exception as e:
                    logger.error(
                        "Error removing remote access user '{0}' - {1}",
                        user, ustr(e))
                    all_removed = False

            # Only stop scanning once a pass finished without any failure.
            if all_removed:
                self._check_existing_jit_users = False
            return

        logger.info("Processing remote access users in goal state.")

        self._check_existing_jit_users = True

        local_jit_users = self._get_existing_jit_users()
        requested_users = set(
            entry.name for entry in self._remote_access.user_list.users)

        for acc in self._remote_access.user_list.users:
            try:
                expires_at = datetime.strptime(acc.expiration,
                                               REMOTE_USR_EXPIRATION_FORMAT)
                now = datetime.utcnow()
                already_present = acc.name in local_jit_users
                if not already_present and now < expires_at:
                    self._add_user(acc.name, acc.encrypted_password,
                                   expires_at)
                elif already_present and now > expires_at:
                    # user account expired, delete it.
                    logger.info("Remote access user '{0}' expired.",
                                acc.name)
                    self._remove_user(acc.name)
            except Exception as e:
                logger.error(
                    "Error processing remote access user '{0}' - {1}",
                    acc.name, ustr(e))

        for user in local_jit_users:
            if user in requested_users:
                continue
            try:
                # user explicitly removed
                self._remove_user(user)
            except Exception as e:
                logger.error(
                    "Error removing remote access user '{0}' - {1}", user,
                    ustr(e))

    @staticmethod
    def _is_jit_user(comment):
        """Return True when ``comment`` equals the remote access account comment."""
        return comment == REMOTE_ACCESS_ACCOUNT_COMMENT

    def _add_user(self, username, encrypted_password, account_expiration):
        """
        Create ``username`` with an expiration one day past
        ``account_expiration``, set its decrypted password and add it to
        sudoers. If a step fails after the account was created, the account
        is removed again and the exception is re-raised.
        """
        created = False

        try:
            expires_on = (account_expiration +
                          timedelta(days=1)).strftime(DATE_FORMAT)
            logger.info(
                "Adding remote access user '{0}' with expiration date {1}",
                username, expires_on)
            self._os_util.useradd(username, expires_on,
                                  REMOTE_ACCESS_ACCOUNT_COMMENT)
            created = True

            logger.info("Adding remote access user '{0}' to sudoers", username)
            key_path = os.path.join(conf.get_lib_dir(), TRANSPORT_PRIVATE_CERT)
            password = self._cryptUtil.decrypt_secret(encrypted_password,
                                                      key_path)
            self._os_util.chpasswd(username, password,
                                   conf.get_password_cryptid(),
                                   conf.get_password_crypt_salt_len())
            self._os_util.conf_sudoer(username)
        except Exception:
            # Roll back a half-configured account before propagating.
            if created:
                self._remove_user(username)
            raise

    def _remove_user(self, username):
        """Log the removal and delete the account through the OS utility."""
        logger.info("Removing remote access user '{0}'", username)
        self._os_util.del_account(username)
Example #21
0
 def __init__(self, protocol):
     """Store the given protocol object and initialise the handler state."""
     # Collaborators: OS utility layer, protocol, and the openssl wrapper.
     self._os_util = get_osutil()
     self._protocol = protocol
     self._cryptUtil = CryptUtil(conf.get_openssl_cmd())
     # Nothing fetched yet; the existing-JIT-users check starts enabled.
     self._remote_access, self._check_existing_jit_users = None, True
Example #22
0
class RemoteAccessHandler(object):
    """
    Adds and removes local JIT (remote access) user accounts so that they
    follow the remote access user list obtained from the protocol client.
    """

    def __init__(self):
        self.os_util = get_osutil()
        self.protocol_util = get_protocol_util()
        self.protocol = None
        self.cryptUtil = CryptUtil(conf.get_openssl_cmd())
        self.remote_access = None
        self.incarnation = 0
        # Per-user failures collected by the latest handle_remote_access().
        self.error_message = ""

    @property
    def err_message(self):
        """Backward-compatible alias for ``error_message``."""
        return self.error_message

    @err_message.setter
    def err_message(self, value):
        self.error_message = value

    def run(self):
        """
        Process remote access users when JIT is enabled and the incarnation
        changed since the previous call. Any exception is logged and
        reported through add_event instead of being propagated.
        """
        try:
            if self.os_util.jit_enabled:
                self.protocol = self.protocol_util.get_protocol()
                current_incarnation = self.protocol.get_incarnation()
                if self.incarnation != current_incarnation:
                    # something changed. Handle remote access if any.
                    self.incarnation = current_incarnation
                    self.remote_access = self.protocol.client.get_remote_access(
                    )
                    self.handle_remote_access()
        except Exception as e:
            msg = u"Exception processing remote access handler: {0} {1}".format(
                ustr(e), traceback.format_exc())
            logger.error(msg)
            add_event(AGENT_NAME,
                      version=CURRENT_VERSION,
                      op=WALAEventOperation.RemoteAccessHandling,
                      is_success=False,
                      message=msg)

    def handle_remote_access(self):
        """
        Reconcile local JIT accounts with ``self.remote_access``.

        Unexpired users missing locally are added; accounts that expired or
        are no longer listed are removed. When ``self.remote_access`` is
        None every existing JIT account is removed. RemoteAccessError
        failures for individual users are collected in
        ``self.error_message`` (reset on every call) and logged once at the
        end; other exceptions propagate to the caller.
        """
        # Get JIT user accounts.
        all_users = self.os_util.get_users()
        existing_jit_users = set(u[0] for u in all_users
                                 if self.validate_jit_user(u[4]))
        # Reset the attribute __init__ created; the code used to write to a
        # differently named attribute, leaving error_message stale.
        self.error_message = ""
        if self.remote_access is not None:
            goal_state_users = set(u.name
                                   for u in self.remote_access.user_list.users)
            for acc in self.remote_access.user_list.users:
                try:
                    raw_expiration = acc.expiration
                    account_expiration = datetime.strptime(
                        raw_expiration, REMOTE_USR_EXPIRATION_FORMAT)
                    now = datetime.utcnow()
                    if acc.name not in existing_jit_users and now < account_expiration:
                        self.add_user(acc.name, acc.encrypted_password,
                                      account_expiration)
                    elif acc.name in existing_jit_users and now > account_expiration:
                        # user account expired, delete it.
                        logger.info(
                            "user {0} expired from remote_access".format(
                                acc.name))
                        self.remove_user(acc.name)
                except RemoteAccessError as rae:
                    self.error_message += "Error processing user {0}. Exception: {1}"\
                        .format(acc.name, ustr(rae))
            for user in existing_jit_users:
                try:
                    if user not in goal_state_users:
                        # user explicitly removed
                        logger.info(
                            "User {0} removed from remote_access".format(user))
                        self.remove_user(user)
                except RemoteAccessError as rae:
                    self.error_message += "Error removing user {0}. Exception: {1}"\
                        .format(user, ustr(rae))
        else:
            # All users removed, remove any remaining JIT accounts.
            for user in existing_jit_users:
                try:
                    logger.info(
                        "User {0} removed from remote_access. remote_access empty"
                        .format(user))
                    self.remove_user(user)
                except RemoteAccessError as rae:
                    self.error_message += "Error removing user {0}. Exception: {1}"\
                        .format(user, ustr(rae))

        # The collected failures were previously dropped silently.
        if self.error_message:
            logger.error(self.error_message)

    def validate_jit_user(self, comment):
        """Return True when ``comment`` equals the remote access account comment."""
        return comment == REMOTE_ACCESS_ACCOUNT_COMMENT

    def add_user(self, username, encrypted_password, account_expiration):
        """
        Create ``username`` with an expiration one day past
        ``account_expiration``, set its decrypted password and add it to
        sudoers.

        Raises RemoteAccessError when the account cannot be created, or when
        configuring it fails; in the latter case the account is deleted
        again and the message states whether that cleanup worked.
        """
        try:
            expiration_date = (account_expiration +
                               timedelta(days=1)).strftime(DATE_FORMAT)
            logger.verbose("Adding user {0} with expiration date {1}".format(
                username, expiration_date))
            self.os_util.useradd(username, expiration_date,
                                 REMOTE_ACCESS_ACCOUNT_COMMENT)
        except Exception as e:
            raise RemoteAccessError("Error adding user {0}. {1}".format(
                username, ustr(e)))
        try:
            prv_key = os.path.join(conf.get_lib_dir(), TRANSPORT_PRIVATE_CERT)
            pwd = self.cryptUtil.decrypt_secret(encrypted_password, prv_key)
            self.os_util.chpasswd(username, pwd, conf.get_password_cryptid(),
                                  conf.get_password_crypt_salt_len())
            self.os_util.conf_sudoer(username)
            logger.info(
                "User '{0}' added successfully with expiration in {1}".format(
                    username, expiration_date))
        except Exception as e:
            # Report the real cleanup outcome. The message used to claim
            # "cleanup successful" even when the cleanup itself had failed.
            try:
                self.handle_failed_create(username)
                outcome = "cleanup successful"
            except RemoteAccessError as rae:
                outcome = "and error cleaning up {0}".format(ustr(rae))
            raise RemoteAccessError(
                "Error adding user {0} {1}".format(username, outcome),
                ustr(e))

    def handle_failed_create(self, username):
        """Delete a half-created account; wrap any failure in RemoteAccessError."""
        try:
            self.delete_user(username)
        except Exception as e:
            raise RemoteAccessError(
                "Failed to clean up after account creation for {0}.".format(
                    username), e)

    def remove_user(self, username):
        """Delete ``username``; wrap any failure in RemoteAccessError."""
        try:
            self.delete_user(username)
        except Exception as e:
            raise RemoteAccessError(
                "Failed to delete user {0}".format(username), e)

    def delete_user(self, username):
        """Delete the account through the OS utility and log the deletion."""
        self.os_util.del_account(username)
        logger.info("User deleted {0}".format(username))
Example #23
0
 def mock_crypt_util(self, *args, **kw):
     """
     Build a CryptUtil from the given arguments and replace only its
     ``gen_transport_cert`` with a Mock driven by ``self.mock_gen_trans_cert``.
     """
     # Partial patch: every other CryptUtil method stays the real one.
     patched = CryptUtil(*args, **kw)
     patched.gen_transport_cert = Mock(side_effect=self.mock_gen_trans_cert)
     return patched
Example #24
0
class RemoteAccessHandler(object):
    """
    Creates local accounts for the JIT users listed in the remote access
    data obtained from the protocol client.
    """

    def __init__(self):
        self.os_util = get_osutil()
        self.protocol_util = get_protocol_util()
        self.protocol = None
        self.cryptUtil = CryptUtil(conf.get_openssl_cmd())
        self.remote_access = None
        # Incarnation handled by the last run(); run() only acts on a change.
        self.incarnation = 0

    def run(self):
        """
        Process the remote access data when the incarnation has changed.

        Does nothing unless os_util.jit_enabled is set. Exceptions are logged
        and reported through add_event instead of being propagated.
        """
        try:
            if self.os_util.jit_enabled:
                self.protocol = self.protocol_util.get_protocol()
                current_incarnation = self.protocol.get_incarnation()
                if self.incarnation != current_incarnation:
                    # something changed. Handle remote access if any.
                    self.incarnation = current_incarnation
                    self.remote_access = self.protocol.client.get_remote_access()
                    if self.remote_access is not None:
                        self.handle_remote_access()
        except Exception as e:
            msg = u"Exception processing remote access handler: {0} {1}".format(ustr(e), traceback.format_exc())
            logger.error(msg)
            add_event(AGENT_NAME,
                      version=CURRENT_VERSION,
                      op=WALAEventOperation.RemoteAccessHandling,
                      is_success=False,
                      message=msg)

    def handle_remote_access(self):
        """
        Add every listed account that is not yet a JIT user on this machine
        and whose expiration lies in the future.
        """
        if self.remote_access is not None:
            # Get JIT user accounts.
            # usr[0] is used as the account name and usr[4] as its comment
            # (passwd-style entries assumed - confirm against get_users()).
            all_users = self.os_util.get_users()
            jit_users = set()
            for usr in all_users:
                if self.validate_jit_user(usr[4]):
                    jit_users.add(usr[0])
            for acc in self.remote_access.user_list.users:
                raw_expiration = acc.expiration
                account_expiration = datetime.strptime(raw_expiration, REMOTE_USR_EXPIRATION_FORMAT)
                now = datetime.utcnow()
                if acc.name not in jit_users and now < account_expiration:
                    self.add_user(acc.name, acc.encrypted_password, account_expiration)

    def validate_jit_user(self, comment):
        """Return True when the account comment marks a JIT account."""
        return comment == REMOTE_ACCESS_ACCOUNT_COMMENT

    def add_user(self, username, encrypted_password, account_expiration):
        """
        Create the account, set its decrypted password and configure sudo.

        Failures are logged, never raised. If the account was created but a
        later step fails, the account is deleted again.
        """
        try:
            # The OS-level expiry is set one day after the requested expiration.
            expiration_date = (account_expiration + timedelta(days=1)).strftime(DATE_FORMAT)
            logger.verbose("Adding user {0} with expiration date {1}"
                           .format(username, expiration_date))
            self.os_util.useradd(username, expiration_date, REMOTE_ACCESS_ACCOUNT_COMMENT)
        except OSError as oe:
            logger.error("Error adding user {0}. {1}"
                         .format(username, oe.strerror))
            return
        except Exception as e:
            logger.error("Error adding user {0}. {1}".format(username, ustr(e)))
            return
        try:
            prv_key = os.path.join(conf.get_lib_dir(), TRANSPORT_PRIVATE_CERT)
            pwd = self.cryptUtil.decrypt_secret(encrypted_password, prv_key)
            self.os_util.chpasswd(username, pwd, conf.get_password_cryptid(), conf.get_password_crypt_salt_len())
            self.os_util.conf_sudoer(username)
            logger.info("User '{0}' added successfully with expiration in {1}"
                        .format(username, expiration_date))
            return
        except OSError as oe:
            self.handle_failed_create(username, oe.strerror)
        except Exception as e:
            self.handle_failed_create(username, ustr(e))

    def handle_failed_create(self, username, error_message):
        """
        Log the creation failure and delete the partially created account.

        A failure of the cleanup itself is logged, never raised.
        """
        logger.error("Error creating user {0}. {1}"
                     .format(username, error_message))
        try:
            self.delete_user(username)
        except OSError as oe:
            # strerror is an attribute, not a method; calling it raised
            # TypeError from inside this handler.
            logger.error("Failed to clean up after account creation for {0}. {1}"
                         .format(username, oe.strerror))
        except Exception as e:
            logger.error("Failed to clean up after account creation for {0}. {1}"
                         .format(username, ustr(e)))

    def delete_user(self, username):
        """Delete the account through the OS utility and log the deletion."""
        self.os_util.del_account(username)
        logger.info("User deleted {0}".format(username))
Example #25
0
    def parse(self, xml_text):
        """
        Parse multiple certificates into separate files.

        The "Data" element of xml_text is wrapped in a MIME header, saved as
        the p7m file and handed to cryptutil.decrypt_p7m together with the
        transport private key and certificate files; the resulting PEM file
        is then split. Each certificate is stored as <thumbprint>.crt and
        each private key whose public key matches a certificate as
        <thumbprint>.prv in the lib dir. One Cert per certificate is appended
        to self.cert_list.certificates.

        Returns without doing anything when xml_text has no "Data" element.
        """
        xml_doc = parse_doc(xml_text)
        data = findtext(xml_doc, "Data")
        if data is None:
            return

        cryptutil = CryptUtil(conf.get_openssl_cmd())
        p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
        p7m = ("MIME-Version:1.0\n"
               "Content-Disposition: attachment; filename=\"{0}\"\n"
               "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
               "Content-Transfer-Encoding: base64\n"
               "\n"
               "{2}").format(p7m_file, p7m_file, data)

        self.client.save_cache(p7m_file, p7m)

        trans_prv_file = os.path.join(conf.get_lib_dir(),
                                      TRANSPORT_PRV_FILE_NAME)
        trans_cert_file = os.path.join(conf.get_lib_dir(),
                                       TRANSPORT_CERT_FILE_NAME)
        pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
        # decrypt certificates
        cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file,
                              pem_file)

        # The parsing process use public key to match prv and crt.
        buf = []
        prvs = {}
        thumbprints = {}
        index = 0
        v1_cert_list = []
        with open(pem_file) as pem:
            for line in pem:
                # Lines accumulate until an END marker closes the current
                # key or certificate.
                buf.append(line)
                if re.match(r'[-]+END.*KEY[-]+', line):
                    tmp_file = self.write_to_tmp_file(index, 'prv', buf)
                    pub = cryptutil.get_pubkey_from_prv(tmp_file)
                    prvs[pub] = tmp_file
                    buf = []
                    index += 1
                elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
                    tmp_file = self.write_to_tmp_file(index, 'crt', buf)
                    pub = cryptutil.get_pubkey_from_crt(tmp_file)
                    thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
                    thumbprints[pub] = thumbprint
                    # Rename crt with thumbprint as the file name
                    crt = "{0}.crt".format(thumbprint)
                    v1_cert_list.append({
                        "name": None,
                        "thumbprint": thumbprint
                    })
                    os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
                    buf = []
                    index += 1

        # Rename prv key with thumbprint as the file name
        for pubkey in prvs:
            # A private key may have no certificate in the bundle; indexing
            # would raise KeyError and abort before cert_list is filled.
            thumbprint = thumbprints.get(pubkey)
            if thumbprint:
                tmp_file = prvs[pubkey]
                prv = "{0}.prv".format(thumbprint)
                os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))

        for v1_cert in v1_cert_list:
            cert = Cert()
            set_properties("certs", cert, v1_cert)
            self.cert_list.certificates.append(cert)
Example #26
0
 def mock_crypt_util(self, *args, **kw):
     """
     Build a real CryptUtil whose gen_transport_cert is replaced by a Mock
     that delegates to self.mock_gen_trans_cert.
     """
     # Only this one instance method is patched; the rest stays real.
     patched = CryptUtil(*args, **kw)
     cert_generator = Mock(side_effect=self.mock_gen_trans_cert)
     patched.gen_transport_cert = cert_generator
     return patched
Example #27
0
    def test_get_pubkey_from_crt_invalid_file(self):
        """
        get_pubkey_from_prv must raise IOError for a key file that does not
        exist.
        """
        crypt_util = CryptUtil(conf.get_openssl_cmd())
        missing_key_path = os.path.join(data_dir, "wire", "trans_prv_does_not_exist")

        with self.assertRaises(IOError):
            crypt_util.get_pubkey_from_prv(missing_key_path)
Example #28
0
    def parse(self, json_text):
        """
        Parse multiple certificates into separate files.

        json_text["certificateData"] is base64-decoded into the p7b file,
        converted to text with openssl, wrapped in a MIME header, saved as
        the p7m file and handed to cryptutil.decrypt_p7m together with the
        transport private key and certificate files; the resulting PEM file
        is then split. Each certificate is stored as <thumbprint>.crt and
        each private key whose public key matches a certificate as
        <thumbprint>.prv in the lib dir. One Cert per certificate is appended
        to self.cert_list.certificates.

        Returns without doing anything when "certificateData" is None.
        """

        data = json_text["certificateData"]
        if data is None:
            logger.verbose("No data in json_text received!")
            return

        cryptutil = CryptUtil(conf.get_openssl_cmd())
        p7b_file = os.path.join(conf.get_lib_dir(), P7B_FILE_NAME)

        # Wrapping the certificate lines.
        # decode and save the result into p7b_file
        fileutil.write_file(p7b_file, base64.b64decode(data), asbin=True)

        ssl_cmd = "openssl pkcs7 -text -in {0} -inform der | grep -v '^-----' "
        # NOTE(review): the exit status in ret is not checked; on failure
        # data holds whatever the command printed.
        ret, data = shellutil.run_get_output(ssl_cmd.format(p7b_file))

        p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
        p7m = ("MIME-Version:1.0\n"
               "Content-Disposition: attachment; filename=\"{0}\"\n"
               "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
               "Content-Transfer-Encoding: base64\n"
               "\n"
               "{2}").format(p7m_file, p7m_file, data)

        self.save_cache(p7m_file, p7m)

        trans_prv_file = os.path.join(conf.get_lib_dir(),
                                      TRANSPORT_PRV_FILE_NAME)
        trans_cert_file = os.path.join(conf.get_lib_dir(),
                                       TRANSPORT_CERT_FILE_NAME)
        pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
        # decrypt certificates
        cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file,
                              pem_file)

        # The parsing process use public key to match prv and crt.
        buf = []
        prvs = {}
        thumbprints = {}
        index = 0
        v1_cert_list = []
        with open(pem_file) as pem:
            for line in pem:
                # Lines accumulate until an END marker closes the current
                # key or certificate.
                buf.append(line)
                if re.match(r'[-]+END.*KEY[-]+', line):
                    tmp_file = self.write_to_tmp_file(index, 'prv', buf)
                    pub = cryptutil.get_pubkey_from_prv(tmp_file)
                    prvs[pub] = tmp_file
                    buf = []
                    index += 1
                elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
                    tmp_file = self.write_to_tmp_file(index, 'crt', buf)
                    pub = cryptutil.get_pubkey_from_crt(tmp_file)
                    thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
                    thumbprints[pub] = thumbprint
                    # Rename crt with thumbprint as the file name
                    crt = "{0}.crt".format(thumbprint)
                    v1_cert_list.append({
                        "name": None,
                        "thumbprint": thumbprint
                    })
                    os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
                    buf = []
                    index += 1

        # Rename prv key with thumbprint as the file name
        for pubkey in prvs:
            # A private key may have no certificate in the bundle; indexing
            # would raise KeyError and abort before cert_list is filled.
            thumbprint = thumbprints.get(pubkey)
            if thumbprint:
                tmp_file = prvs[pubkey]
                prv = "{0}.prv".format(thumbprint)
                os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))

        for v1_cert in v1_cert_list:
            cert = Cert()
            set_properties("certs", cert, v1_cert)
            self.cert_list.certificates.append(cert)
Example #29
0
    def parse(self, json_text):
        """
        Parse multiple certificates into separate files.

        json_text["certificateData"] is base64-decoded into the p7b file,
        converted to text with openssl, wrapped in a MIME header, saved as
        the p7m file and handed to cryptutil.decrypt_p7m together with the
        transport private key and certificate files; the resulting PEM file
        is then split. Each certificate is stored as <thumbprint>.crt and
        each private key whose public key matches a certificate as
        <thumbprint>.prv in the lib dir. One Cert per certificate is appended
        to self.cert_list.certificates.

        Returns without doing anything when "certificateData" is None.
        Raises whatever base64.b64decode raises when the payload is not
        valid base64.
        """
        # Local import: the block must not rely on a module-level import it
        # did not need before.
        import base64

        data = json_text["certificateData"]
        if data is None:
            logger.verbose("No data in json_text received!")
            return

        cryptutil = CryptUtil(conf.get_openssl_cmd())
        p7b_file = os.path.join(conf.get_lib_dir(), P7B_FILE_NAME)

        # Wrapping the certificate lines.
        # Decode in-process: the payload used to be interpolated into
        # "echo ... | base64 -d > file", so shell metacharacters in it were
        # executed by the shell.
        with open(p7b_file, "wb") as p7b:
            p7b.write(base64.b64decode(data))
        ssl_cmd = "openssl pkcs7 -text -in {0} -inform der | grep -v '^-----' "
        # NOTE(review): the exit status in ret is not checked; on failure
        # data holds whatever the command printed.
        ret, data = shellutil.run_get_output(ssl_cmd.format(p7b_file))

        p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
        p7m = ("MIME-Version:1.0\n"
               "Content-Disposition: attachment; filename=\"{0}\"\n"
               "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
               "Content-Transfer-Encoding: base64\n"
               "\n"
               "{2}").format(p7m_file, p7m_file, data)

        self.save_cache(p7m_file, p7m)

        trans_prv_file = os.path.join(conf.get_lib_dir(),
                                      TRANSPORT_PRV_FILE_NAME)
        trans_cert_file = os.path.join(conf.get_lib_dir(),
                                       TRANSPORT_CERT_FILE_NAME)
        pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
        # decrypt certificates
        cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file,
                              pem_file)

        # The parsing process use public key to match prv and crt.
        buf = []
        prvs = {}
        thumbprints = {}
        index = 0
        v1_cert_list = []
        with open(pem_file) as pem:
            for line in pem:
                # Lines accumulate until an END marker closes the current
                # key or certificate.
                buf.append(line)
                if re.match(r'[-]+END.*KEY[-]+', line):
                    tmp_file = self.write_to_tmp_file(index, 'prv', buf)
                    pub = cryptutil.get_pubkey_from_prv(tmp_file)
                    prvs[pub] = tmp_file
                    buf = []
                    index += 1
                elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
                    tmp_file = self.write_to_tmp_file(index, 'crt', buf)
                    pub = cryptutil.get_pubkey_from_crt(tmp_file)
                    thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
                    thumbprints[pub] = thumbprint
                    # Rename crt with thumbprint as the file name
                    crt = "{0}.crt".format(thumbprint)
                    v1_cert_list.append({
                        "name": None,
                        "thumbprint": thumbprint
                    })
                    os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
                    buf = []
                    index += 1

        # Rename prv key with thumbprint as the file name
        for pubkey in prvs:
            # A private key may have no certificate in the bundle; indexing
            # would raise KeyError and abort before cert_list is filled.
            thumbprint = thumbprints.get(pubkey)
            if thumbprint:
                tmp_file = prvs[pubkey]
                prv = "{0}.prv".format(thumbprint)
                os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))

        for v1_cert in v1_cert_list:
            cert = Cert()
            set_properties("certs", cert, v1_cert)
            self.cert_list.certificates.append(cert)
Example #30
0
class RemoteAccessHandler(object):
    """
    Creates local accounts for the JIT users listed in the remote access
    data obtained from the protocol client.
    """

    def __init__(self):
        self.os_util = get_osutil()
        self.protocol_util = get_protocol_util()
        self.protocol = None
        self.cryptUtil = CryptUtil(conf.get_openssl_cmd())
        self.remote_access = None
        # Incarnation handled by the last run(); run() only acts on a change.
        self.incarnation = 0

    def run(self):
        """
        Process the remote access data when the incarnation has changed.

        Does nothing unless os_util.jit_enabled is set. Exceptions are logged
        and reported through add_event instead of being propagated.
        """
        try:
            if self.os_util.jit_enabled:
                self.protocol = self.protocol_util.get_protocol()
                current_incarnation = self.protocol.get_incarnation()
                if self.incarnation != current_incarnation:
                    # something changed. Handle remote access if any.
                    self.incarnation = current_incarnation
                    self.remote_access = self.protocol.client.get_remote_access(
                    )
                    if self.remote_access is not None:
                        self.handle_remote_access()
        except Exception as e:
            msg = u"Exception processing remote access handler: {0} {1}".format(
                ustr(e), traceback.format_exc())
            logger.error(msg)
            add_event(AGENT_NAME,
                      version=CURRENT_VERSION,
                      op=WALAEventOperation.RemoteAccessHandling,
                      is_success=False,
                      message=msg)

    def handle_remote_access(self):
        """
        Add every listed account that is not yet a JIT user on this machine
        and whose expiration lies in the future.
        """
        if self.remote_access is not None:
            # Get JIT user accounts.
            # usr[0] is used as the account name and usr[4] as its comment
            # (passwd-style entries assumed - confirm against get_users()).
            all_users = self.os_util.get_users()
            jit_users = set()
            for usr in all_users:
                if self.validate_jit_user(usr[4]):
                    jit_users.add(usr[0])
            for acc in self.remote_access.user_list.users:
                raw_expiration = acc.expiration
                account_expiration = datetime.strptime(
                    raw_expiration, REMOTE_USR_EXPIRATION_FORMAT)
                now = datetime.utcnow()
                if acc.name not in jit_users and now < account_expiration:
                    self.add_user(acc.name, acc.encrypted_password,
                                  account_expiration)

    def validate_jit_user(self, comment):
        """Return True when the account comment marks a JIT account."""
        return comment == REMOTE_ACCESS_ACCOUNT_COMMENT

    def add_user(self, username, encrypted_password, account_expiration):
        """
        Create the account, set its decrypted password and configure sudo.

        Failures are logged, never raised. If the account was created but a
        later step fails, the account is deleted again.
        """
        try:
            # The OS-level expiry is set one day after the requested expiration.
            expiration_date = (account_expiration +
                               timedelta(days=1)).strftime(DATE_FORMAT)
            logger.verbose("Adding user {0} with expiration date {1}".format(
                username, expiration_date))
            self.os_util.useradd(username, expiration_date,
                                 REMOTE_ACCESS_ACCOUNT_COMMENT)
        except OSError as oe:
            logger.error("Error adding user {0}. {1}".format(
                username, oe.strerror))
            return
        except Exception as e:
            logger.error("Error adding user {0}. {1}".format(
                username, ustr(e)))
            return
        try:
            prv_key = os.path.join(conf.get_lib_dir(), TRANSPORT_PRIVATE_CERT)
            pwd = self.cryptUtil.decrypt_secret(encrypted_password, prv_key)
            self.os_util.chpasswd(username, pwd, conf.get_password_cryptid(),
                                  conf.get_password_crypt_salt_len())
            self.os_util.conf_sudoer(username)
            logger.info(
                "User '{0}' added successfully with expiration in {1}".format(
                    username, expiration_date))
            return
        except OSError as oe:
            self.handle_failed_create(username, oe.strerror)
        except Exception as e:
            self.handle_failed_create(username, ustr(e))

    def handle_failed_create(self, username, error_message):
        """
        Log the creation failure and delete the partially created account.

        A failure of the cleanup itself is logged, never raised.
        """
        logger.error("Error creating user {0}. {1}".format(
            username, error_message))
        try:
            self.delete_user(username)
        except OSError as oe:
            # strerror is an attribute, not a method; calling it raised
            # TypeError from inside this handler.
            logger.error(
                "Failed to clean up after account creation for {0}. {1}".
                format(username, oe.strerror))
        except Exception as e:
            logger.error(
                "Failed to clean up after account creation for {0}. {1}".
                format(username, ustr(e)))

    def delete_user(self, username):
        """Delete the account through the OS utility and log the deletion."""
        self.os_util.del_account(username)
        logger.info("User deleted {0}".format(username))
Example #31
0
 def openssl_to_openssh(self, input_file, output_file):
     """
     Convert input_file to output_file with CryptUtil.crt_to_ssh, using the
     configured openssl command.
     """
     openssl_cmd = conf.get_openssl_cmd()
     CryptUtil(openssl_cmd).crt_to_ssh(input_file, output_file)
Example #32
0
 def test_decrypt_encrypted_text_missing_private_key(self):
     """
     decrypt_secret must raise CryptError when the private key path does
     not point to an existing key.
     """
     ciphertext = load_data("wire/encrypted.enc")
     key_path = os.path.join(self.tmp_dir, "TransportPrivate.pem")
     crypt_util = CryptUtil(conf.get_openssl_cmd())
     # Prefixing the path makes it point at a file that is not there.
     missing_key_path = "abc" + key_path
     with self.assertRaises(CryptError):
         crypt_util.decrypt_secret(ciphertext, missing_key_path)
Example #33
0
    def __init__(self, xml_text):
        """
        Save the certificates XML and split it into individual files.

        xml_text is written to the certs file as-is. When it has a "Data"
        element and its "Format" is absent or "Pkcs7BlobWithPfxContents", the
        data is wrapped in a MIME header, saved as the p7m file and handed to
        cryptutil.decrypt_p7m together with the transport private key and
        certificate files; the resulting PEM file is then split. Each
        certificate is stored as <thumbprint>.crt and each private key whose
        public key matches a certificate as <thumbprint>.prv in the lib dir.
        One Cert per certificate is appended to self.cert_list.certificates.
        """
        self.cert_list = CertList()

        # Save the certificates
        local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
        fileutil.write_file(local_file, xml_text)

        # Separate the certificates into individual files.
        xml_doc = parse_doc(xml_text)
        data = findtext(xml_doc, "Data")
        if data is None:
            return

        # if the certificates format is not Pkcs7BlobWithPfxContents do not parse it
        certificateFormat = findtext(xml_doc, "Format")
        if certificateFormat and certificateFormat != "Pkcs7BlobWithPfxContents":
            logger.warn("The Format is not Pkcs7BlobWithPfxContents. Format is " + certificateFormat)
            return

        cryptutil = CryptUtil(conf.get_openssl_cmd())
        p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
        p7m = ("MIME-Version:1.0\n"  # pylint: disable=W1308
               "Content-Disposition: attachment; filename=\"{0}\"\n"
               "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
               "Content-Transfer-Encoding: base64\n"
               "\n"
               "{2}").format(p7m_file, p7m_file, data)

        fileutil.write_file(p7m_file, p7m)

        trans_prv_file = os.path.join(conf.get_lib_dir(), TRANSPORT_PRV_FILE_NAME)
        trans_cert_file = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME)
        pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
        # decrypt certificates
        cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, pem_file)

        # The parsing process use public key to match prv and crt.
        buf = []
        prvs = {}
        thumbprints = {}
        index = 0
        v1_cert_list = []
        with open(pem_file) as pem:
            for line in pem:
                # Lines accumulate until an END marker closes the current
                # key or certificate.
                buf.append(line)
                if re.match(r'[-]+END.*KEY[-]+', line):
                    tmp_file = Certificates._write_to_tmp_file(index, 'prv', buf)
                    pub = cryptutil.get_pubkey_from_prv(tmp_file)
                    prvs[pub] = tmp_file
                    buf = []
                    index += 1
                elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
                    tmp_file = Certificates._write_to_tmp_file(index, 'crt', buf)
                    pub = cryptutil.get_pubkey_from_crt(tmp_file)
                    thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
                    thumbprints[pub] = thumbprint
                    # Rename crt with thumbprint as the file name
                    crt = "{0}.crt".format(thumbprint)
                    v1_cert_list.append({
                        "name": None,
                        "thumbprint": thumbprint
                    })
                    os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
                    buf = []
                    index += 1

        # Rename prv key with thumbprint as the file name
        for pubkey in prvs:
            # .get() rather than indexing: a missing entry must reach the
            # warning below instead of raising KeyError.
            thumbprint = thumbprints.get(pubkey)
            if thumbprint:
                tmp_file = prvs[pubkey]
                prv = "{0}.prv".format(thumbprint)
                os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))
                logger.info("Found private key matching thumbprint {0}".format(thumbprint))
            else:
                # Since private key has *no* matching certificate,
                # it will not be named correctly
                logger.warn("Found NO matching cert/thumbprint for private key!")

        # Log if any certificates were found without matching private keys
        # This can happen (rarely), and is useful to know for debugging
        for pubkey in thumbprints:
            if pubkey not in prvs:
                msg = "Certificate with thumbprint {0} has no matching private key."
                logger.info(msg.format(thumbprints[pubkey]))

        for v1_cert in v1_cert_list:
            cert = Cert()
            set_properties("certs", cert, v1_cert)
            self.cert_list.certificates.append(cert)
Example #34
0
class RemoteAccessHandler(object):
    """
    Keeps the local JIT accounts in line with the remote access data
    obtained from the protocol client: adds listed accounts, removes expired
    or no longer listed ones.
    """

    def __init__(self):
        self.os_util = get_osutil()
        self.protocol_util = get_protocol_util()
        self.protocol = None
        self.cryptUtil = CryptUtil(conf.get_openssl_cmd())
        self.remote_access = None
        # Incarnation handled by the last run(); run() only acts on a change.
        self.incarnation = 0
        self.error_message = ""
        # The attribute handle_remote_access() actually accumulates into;
        # initialised here so it exists before the first run.
        self.err_message = ""

    def run(self):
        """
        Process the remote access data when the incarnation has changed.

        Does nothing unless os_util.jit_enabled is set. Exceptions are logged
        and reported through add_event instead of being propagated.
        """
        try:
            if self.os_util.jit_enabled:
                self.protocol = self.protocol_util.get_protocol()
                current_incarnation = self.protocol.get_incarnation()
                if self.incarnation != current_incarnation:
                    # something changed. Handle remote access if any.
                    self.incarnation = current_incarnation
                    self.remote_access = self.protocol.client.get_remote_access()
                    self.handle_remote_access()
        except Exception as e:
            msg = u"Exception processing remote access handler: {0} {1}".format(ustr(e), traceback.format_exc())
            logger.error(msg)
            add_event(AGENT_NAME,
                      version=CURRENT_VERSION,
                      op=WALAEventOperation.RemoteAccessHandling,
                      is_success=False,
                      message=msg)

    def handle_remote_access(self):
        """
        Reconcile the existing JIT accounts with self.remote_access.

        Listed accounts that do not exist and have not expired are added;
        existing ones that have expired or are no longer listed are removed.
        When self.remote_access is None every existing JIT account is
        removed. RemoteAccessError from a single account is appended to
        self.err_message and processing continues with the next one.
        """
        # Get JIT user accounts.
        # u[0] is used as the account name and u[4] as its comment
        # (passwd-style entries assumed - confirm against get_users()).
        all_users = self.os_util.get_users()
        existing_jit_users = set(u[0] for u in all_users if self.validate_jit_user(u[4]))
        self.err_message = ""
        if self.remote_access is not None:
            goal_state_users = set(u.name for u in self.remote_access.user_list.users)
            for acc in self.remote_access.user_list.users:
                try:
                    raw_expiration = acc.expiration
                    account_expiration = datetime.strptime(raw_expiration, REMOTE_USR_EXPIRATION_FORMAT)
                    now = datetime.utcnow()
                    if acc.name not in existing_jit_users and now < account_expiration:
                        self.add_user(acc.name, acc.encrypted_password, account_expiration)
                    elif acc.name in existing_jit_users and now > account_expiration:
                        # user account expired, delete it.
                        logger.info("user {0} expired from remote_access".format(acc.name))
                        self.remove_user(acc.name)
                except RemoteAccessError as rae:
                    self.err_message = self.err_message + "Error processing user {0}. Exception: {1}"\
                        .format(acc.name, ustr(rae))
            for user in existing_jit_users:
                try:
                    if user not in goal_state_users:
                        # user explicitly removed
                        logger.info("User {0} removed from remote_access".format(user))
                        self.remove_user(user)
                except RemoteAccessError as rae:
                    self.err_message = self.err_message + "Error removing user {0}. Exception: {1}"\
                        .format(user, ustr(rae))
        else:
            # All users removed, remove any remaining JIT accounts.
            for user in existing_jit_users:
                try:
                    logger.info("User {0} removed from remote_access. remote_access empty".format(user))
                    self.remove_user(user)
                except RemoteAccessError as rae:
                    self.err_message = self.err_message + "Error removing user {0}. Exception: {1}"\
                        .format(user, ustr(rae))

    def validate_jit_user(self, comment):
        """Return True when the account comment marks a JIT account."""
        return comment == REMOTE_ACCESS_ACCOUNT_COMMENT

    def add_user(self, username, encrypted_password, account_expiration):
        """
        Create the account, set its decrypted password and configure sudo.

        Raises RemoteAccessError when the account cannot be created, or when
        a later step fails; in the latter case the account is deleted again
        and the error states whether that cleanup succeeded.
        """
        try:
            # The OS-level expiry is set one day after the requested expiration.
            expiration_date = (account_expiration + timedelta(days=1)).strftime(DATE_FORMAT)
            logger.verbose("Adding user {0} with expiration date {1}".format(username, expiration_date))
            self.os_util.useradd(username, expiration_date, REMOTE_ACCESS_ACCOUNT_COMMENT)
        except Exception as e:
            raise RemoteAccessError("Error adding user {0}. {1}".format(username, ustr(e)))
        try:
            prv_key = os.path.join(conf.get_lib_dir(), TRANSPORT_PRIVATE_CERT)
            pwd = self.cryptUtil.decrypt_secret(encrypted_password, prv_key)
            self.os_util.chpasswd(username, pwd, conf.get_password_cryptid(), conf.get_password_crypt_salt_len())
            self.os_util.conf_sudoer(username)
            logger.info("User '{0}' added successfully with expiration in {1}".format(username, expiration_date))
        except Exception as e:
            # The message used to claim "cleanup successful" unconditionally;
            # it now reflects the real cleanup outcome. The original failure
            # is still passed along as the second argument.
            error = "Error adding user {0}".format(username)
            try:
                self.handle_failed_create(username)
                error += " cleanup successful"
            except RemoteAccessError as rae:
                error += " and error cleaning up {0}".format(ustr(rae))
            raise RemoteAccessError(error, ustr(e))

    def handle_failed_create(self, username):
        """
        Delete a partially created account; deletion failures are re-raised
        as RemoteAccessError.
        """
        try:
            self.delete_user(username)
        except Exception as e:
            raise RemoteAccessError("Failed to clean up after account creation for {0}.".format(username), e)

    def remove_user(self, username):
        """
        Delete the account; deletion failures are re-raised as
        RemoteAccessError.
        """
        try:
            self.delete_user(username)
        except Exception as e:
            raise RemoteAccessError("Failed to delete user {0}".format(username), e)

    def delete_user(self, username):
        """Delete the account through the OS utility and log the deletion."""
        self.os_util.del_account(username)
        logger.info("User deleted {0}".format(username))