def deploy_ssh_pubkey(self, username, pubkey):
    """
    Deploy authorized_key

    ``pubkey`` is a ``(path, thumbprint, value)`` triple.  When ``value`` is
    given it is written verbatim (it must start with ``ssh-``); otherwise the
    key is extracted from ``<lib_dir>/<thumbprint>.crt`` and converted with
    ``openssl_to_openssh``.

    Raises OSUtilError when the path is None, the value is malformed, the
    certificate file is missing, or both thumbprint and value are None.
    """
    ssh_context = 'unconfined_u:object_r:ssh_home_t:s0'

    target, fingerprint, key_text = pubkey
    if target is None:
        raise OSUtilError("Public key path is None")

    crypt_util = CryptUtil(conf.get_openssl_cmd())
    target = self._norm_path(target)
    fileutil.mkdir(os.path.dirname(target), mode=0o700, owner=username)

    if key_text is None and fingerprint is None:
        raise OSUtilError("SSH public key Fingerprint and Value are None")

    if key_text is not None:
        # A literal key was supplied: sanity-check the prefix and write it.
        if not key_text.startswith("ssh-"):
            raise OSUtilError("Bad public key: {0}".format(key_text))
        fileutil.write_file(target, key_text)
    else:
        # Only a thumbprint was supplied: derive the key from the certificate.
        lib_dir = conf.get_lib_dir()
        cert_file = os.path.join(lib_dir, fingerprint + '.crt')
        if not os.path.isfile(cert_file):
            raise OSUtilError("Can't find {0}.crt".format(fingerprint))
        openssl_pub_file = os.path.join(lib_dir, fingerprint + '.pub')
        fileutil.write_file(openssl_pub_file,
                            crypt_util.get_pubkey_from_crt(cert_file))
        self.set_selinux_context(openssl_pub_file, ssh_context)
        self.openssl_to_openssh(openssl_pub_file, target)
        fileutil.chmod(openssl_pub_file, 0o600)

    self.set_selinux_context(target, ssh_context)
    fileutil.chowner(target, username)
    fileutil.chmod(target, 0o644)
def __init__(self, protocol):
    """Initialise the handler around ``protocol`` with empty remote-access state."""
    # Collaborators.
    self.os_util = get_osutil()
    self.protocol = protocol
    self.cryptUtil = CryptUtil(conf.get_openssl_cmd())

    # Mutable state; nothing has been fetched yet.
    self.error_message = ""
    self.incarnation = 0
    self.remote_access = None
def openssl_to_openssh(self, input_file, output_file):
    """
    Read the key in ``input_file``, convert it with ``CryptUtil.asn1_to_ssh``
    and write the result to ``output_file`` (replacing its content).

    A CryptError raised during conversion is re-raised as OSUtilError.
    """
    asn1_key = fileutil.read_file(input_file)
    try:
        converter = CryptUtil(conf.get_openssl_cmd())
        openssh_key = converter.asn1_to_ssh(asn1_key)
    except CryptError as err:
        raise OSUtilError(ustr(err))
    else:
        fileutil.write_file(output_file, openssh_key)
def openssl_to_openssh(self, input_file, output_file):
    """
    Read the key in ``input_file``, convert it with ``CryptUtil.asn1_to_ssh``
    and append the result to ``output_file``.

    A CryptError raised during conversion is re-raised as OSUtilError.
    """
    asn1_key = fileutil.read_file(input_file)
    try:
        converter = CryptUtil(conf.get_openssl_cmd())
        openssh_key = converter.asn1_to_ssh(asn1_key)
    except CryptError as err:
        raise OSUtilError(ustr(err))
    else:
        # Append rather than overwrite so existing entries are kept.
        fileutil.append_file(output_file, openssh_key)
def test_decrypt_encrypted_text(self):
    """Decrypting the sample payload with the sample key yields the known secret."""
    expected_secret = ']aPPEv}uNg1FPnl?'
    ciphertext = load_data("wire/encrypted.enc")

    # Materialise the sample private key in the test's temp directory.
    key_path = os.path.join(self.tmp_dir, "TransportPrivate.pem")
    with open(key_path, 'w+') as key_file:
        key_file.write(load_data("wire/sample.pem"))

    decrypted = CryptUtil(conf.get_openssl_cmd()).decrypt_secret(ciphertext, key_path)
    self.assertEqual(expected_secret, decrypted,
                     "decrypted string does not match expected")
def detect(self):
    """
    Check the wire protocol version, generate the transport certificate
    pair under the lib dir, then force a goal state refresh.
    """
    self.client.check_wire_protocol_version()

    prv_path, cert_path = [
        os.path.join(conf.get_lib_dir(), file_name)
        for file_name in (TRANSPORT_PRV_FILE_NAME, TRANSPORT_CERT_FILE_NAME)
    ]
    CryptUtil(conf.get_openssl_cmd()).gen_transport_cert(prv_path, cert_path)

    self.client.update_goal_state(forced=True)
def __init__(self):
    """Initialise the handler; the protocol is not resolved at construction time."""
    # Collaborators.
    self.os_util = get_osutil()
    self.protocol_util = get_protocol_util()
    self.cryptUtil = CryptUtil(conf.get_openssl_cmd())

    # Mutable state; nothing has been fetched yet.
    self.protocol = None
    self.remote_access = None
    self.incarnation = 0
def test_decrypt_encrypted_text_text_not_encrypted(self):
    """Passing plain (non-encrypted) text to decrypt_secret raises CryptError."""
    plain_text = "abc@123"

    key_path = os.path.join(self.tmp_dir, "TransportPrivate.pem")
    with open(key_path, 'w+') as key_file:
        key_file.write(load_data("wire/sample.pem"))

    crypt_util = CryptUtil(conf.get_openssl_cmd())
    with self.assertRaises(CryptError):
        crypt_util.decrypt_secret(plain_text, key_path)
def detect(self):
    """
    Fetch VM info, generate the transport certificate pair, and copy both
    files into the lib dir under names derived from the cert thumbprint.
    """
    self.get_vminfo()

    prv_src, cert_src = [
        os.path.join(conf.get_lib_dir(), file_name)
        for file_name in (TRANSPORT_PRV_FILE_NAME, TRANSPORT_CERT_FILE_NAME)
    ]
    crypt_util = CryptUtil(conf.get_openssl_cmd())
    crypt_util.gen_transport_cert(prv_src, cert_src)

    # "Install" the cert and private key to /var/lib/waagent
    thumbprint = crypt_util.get_thumbprint_from_crt(cert_src)
    prv_dst = os.path.join(conf.get_lib_dir(), "{0}.prv".format(thumbprint))
    cert_dst = os.path.join(conf.get_lib_dir(), "{0}.crt".format(thumbprint))
    for source, destination in ((prv_src, prv_dst), (cert_src, cert_dst)):
        shutil.copyfile(source, destination)
def test_decrypt_encrypted_text_wrong_private_key(self):
    """Decrypting the sample payload with a different private key raises CryptError."""
    ciphertext = load_data("wire/encrypted.enc")

    # Write a key that does not correspond to the encrypted sample.
    wrong_key_path = os.path.join(self.tmp_dir, "wrong.pem")
    with open(wrong_key_path, 'w+') as key_file:
        key_file.write(load_data("wire/trans_prv"))

    crypt_util = CryptUtil(conf.get_openssl_cmd())
    with self.assertRaises(CryptError):
        crypt_util.decrypt_secret(ciphertext, wrong_key_path)
def openssl_to_openssh(self, input_file, output_file):
    """
    Convert the RSA public key in ``input_file`` to an ``ssh-rsa`` line and
    write it to ``output_file``.

    The modulus and exponent are scraped from the text output of
    ``openssl rsa -pubin -noout -text``.  Raises OSUtilError when the
    openssl command exits with a non-zero status.
    """
    cryptutil = CryptUtil(conf.get_openssl_cmd())
    dump_cmd = "".join([conf.get_openssl_cmd(),
                        " rsa -pubin -noout -text -in '",
                        input_file,
                        "'"])
    status, text = shellutil.run_get_output(dump_cmd)
    if status != 0:
        raise OSUtilError('openssl failed with {0}'.format(status))

    # Collect the header line plus any continuation lines of each section.
    modulus_lines = []
    exponent_lines = []
    sections = (('Modulus:', modulus_lines), ('Exponent:', exponent_lines))
    active = None
    for row in text.split('\n'):
        for header, bucket in sections:
            if row.startswith(header):
                active = bucket
                active.append(row)
                break
        else:
            if active and row:
                active.append(row.strip().replace(':', ''))

    def parse_number(lines):
        # A lone header line carries the value inline (second token);
        # otherwise the continuation lines hold colon-stripped hex digits.
        if len(lines) == 1:
            return int(lines[0].split()[1])
        return long(''.join(lines[1:]), 16)

    n = parse_number(modulus_lines)
    e = parse_number(exponent_lines)
    e_bytes = cryptutil.num_to_bytes(e)
    n_bytes = cryptutil.num_to_bytes(n)

    # Each field is a 4-byte big-endian length followed by its payload.
    keydata = bytearray()
    keydata.extend(struct.pack('>I', len('ssh-rsa')))
    keydata.extend(b'ssh-rsa')
    keydata.extend(struct.pack('>I', len(e_bytes)))
    keydata.extend(e_bytes)
    # The modulus is emitted with one extra leading zero byte.
    keydata.extend(struct.pack('>I', len(n_bytes) + 1))
    keydata.extend(b'\0')
    keydata.extend(n_bytes)

    encoded = base64.b64encode(bytebuffer(keydata))
    key_line = ustr(b'ssh-rsa ' + encoded + b'\n', encoding='utf-8')
    fileutil.write_file(output_file, key_line)
def deploy_ssh_keypair(self, username, keypair):
    """
    Deploy id_rsa and id_rsa.pub

    ``keypair`` is a ``(path, thumbprint)`` pair.  The private key stored at
    ``<lib_dir>/<thumbprint>.prv`` is copied to ``path`` and the matching
    public key is written to ``path + '.pub'``.

    Both files are handed over to ``username``; the private key is made
    owner-only (0o600) and the public key world-readable (0o644).

    Raises OSUtilError when the ``.prv`` file for the thumbprint is missing.
    """
    ssh_context = 'unconfined_u:object_r:ssh_home_t:s0'

    path, thumbprint = keypair
    path = self._norm_path(path)
    dir_path = os.path.dirname(path)
    fileutil.mkdir(dir_path, mode=0o700, owner=username)

    lib_dir = conf.get_lib_dir()
    prv_path = os.path.join(lib_dir, thumbprint + '.prv')
    if not os.path.isfile(prv_path):
        raise OSUtilError("Can't find {0}.prv".format(thumbprint))
    shutil.copyfile(prv_path, path)

    pub_path = path + '.pub'
    cryptutil = CryptUtil(conf.get_openssl_cmd())
    pub = cryptutil.get_pubkey_from_prv(prv_path)
    fileutil.write_file(pub_path, pub)

    self.set_selinux_context(pub_path, ssh_context)
    self.set_selinux_context(path, ssh_context)

    # The key pair belongs to the user, as deploy_ssh_pubkey does for
    # authorized_keys; without this an owner-only private key would be
    # unreadable by the account it was deployed for.
    fileutil.chowner(path, username)
    fileutil.chowner(pub_path, username)

    # Private key must not be readable by group/other; the public key may be.
    # (The modes were previously swapped, leaving the private key at 0o644.)
    os.chmod(path, 0o600)
    os.chmod(pub_path, 0o644)
def openssl_to_openssh(self, input_file, output_file):
    """
    Convert the RSA public key in ``input_file`` to an ``ssh-rsa`` line and
    write it to ``output_file``.

    The modulus and exponent are scraped from the text output of
    ``openssl rsa -pubin -noout -text``.  Raises OSUtilError when the
    openssl command exits with a non-zero status.
    """
    cryptutil = CryptUtil(conf.get_openssl_cmd())
    dump_cmd = "".join([conf.get_openssl_cmd(),
                        " rsa -pubin -noout -text -in '",
                        input_file,
                        "'"])
    status, text = shellutil.run_get_output(dump_cmd)
    if status != 0:
        raise OSUtilError('openssl failed with {0}'.format(status))

    # Collect the header line plus any continuation lines of each section.
    modulus_lines = []
    exponent_lines = []
    sections = (('Modulus:', modulus_lines), ('Exponent:', exponent_lines))
    active = None
    for row in text.split('\n'):
        for header, bucket in sections:
            if row.startswith(header):
                active = bucket
                active.append(row)
                break
        else:
            if active and row:
                active.append(row.strip().replace(':', ''))

    def parse_number(lines):
        # A lone header line carries the value inline (second token);
        # otherwise the continuation lines hold colon-stripped hex digits.
        if len(lines) == 1:
            return int(lines[0].split()[1])
        return long(''.join(lines[1:]), 16)

    n = parse_number(modulus_lines)
    e = parse_number(exponent_lines)
    e_bytes = cryptutil.num_to_bytes(e)
    n_bytes = cryptutil.num_to_bytes(n)

    # Each field is a 4-byte big-endian length followed by its payload.
    keydata = bytearray()
    keydata.extend(struct.pack('>I', len('ssh-rsa')))
    keydata.extend(b'ssh-rsa')
    keydata.extend(struct.pack('>I', len(e_bytes)))
    keydata.extend(e_bytes)
    # The modulus is emitted with one extra leading zero byte.
    keydata.extend(struct.pack('>I', len(n_bytes) + 1))
    keydata.extend(b'\0')
    keydata.extend(n_bytes)

    encoded = base64.b64encode(bytebuffer(keydata))
    key_line = ustr(b'ssh-rsa ' + encoded + b'\n', encoding='utf-8')
    fileutil.write_file(output_file, key_line)
def openssl_to_openssh(self, input_file, output_file):
    """Delegate the conversion of ``input_file`` into ``output_file`` to ``CryptUtil.crt_to_ssh``."""
    CryptUtil(conf.get_openssl_cmd()).crt_to_ssh(input_file, output_file)
def parse(self, xml_text):
    """
    Parse multiple certificates into separate files.

    The ``Data`` element of ``xml_text`` is wrapped into a p7m message,
    decrypted with the transport key pair into a PEM bundle, and the bundle
    is split into one file per certificate / private key.  Certificates are
    renamed to ``<thumbprint>.crt``; private keys whose public key matches a
    certificate are renamed to ``<thumbprint>.prv``.  Every certificate found
    is appended to ``self.cert_list.certificates``.

    Returns without doing anything when the document has no ``Data`` element.
    """
    xml_doc = parse_doc(xml_text)
    data = findtext(xml_doc, "Data")
    if data is None:
        return

    cryptutil = CryptUtil(conf.get_openssl_cmd())
    p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
    p7m = ("MIME-Version:1.0\n"
           "Content-Disposition: attachment; filename=\"{0}\"\n"
           "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
           "Content-Transfer-Encoding: base64\n"
           "\n"
           "{2}").format(p7m_file, p7m_file, data)

    self.client.save_cache(p7m_file, p7m)

    trans_prv_file = os.path.join(conf.get_lib_dir(), TRANSPORT_PRV_FILE_NAME)
    trans_cert_file = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME)
    pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
    # decrypt certificates
    cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, pem_file)

    # The parsing process uses the public key to match prv and crt.
    buf = []
    prvs = {}         # public key -> temp file holding the private key
    thumbprints = {}  # public key -> certificate thumbprint
    index = 0
    v1_cert_list = []
    with open(pem_file) as pem:
        for line in pem:
            buf.append(line)
            # Only END markers trigger work; everything else just accumulates.
            if re.match(r'[-]+END.*KEY[-]+', line):
                tmp_file = self.write_to_tmp_file(index, 'prv', buf)
                pub = cryptutil.get_pubkey_from_prv(tmp_file)
                prvs[pub] = tmp_file
                buf = []
                index += 1
            elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
                tmp_file = self.write_to_tmp_file(index, 'crt', buf)
                pub = cryptutil.get_pubkey_from_crt(tmp_file)
                thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
                thumbprints[pub] = thumbprint
                # Rename crt with thumbprint as the file name
                crt = "{0}.crt".format(thumbprint)
                v1_cert_list.append({
                    "name": None,
                    "thumbprint": thumbprint
                })
                os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
                buf = []
                index += 1

    # Rename prv key with thumbprint as the file name
    for pubkey, tmp_file in prvs.items():
        # .get(): a private key without a matching certificate must be
        # skipped, not abort the whole parse with a KeyError.
        thumbprint = thumbprints.get(pubkey)
        if thumbprint:
            prv = "{0}.prv".format(thumbprint)
            os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))

    for v1_cert in v1_cert_list:
        cert = Cert()
        set_properties("certs", cert, v1_cert)
        self.cert_list.certificates.append(cert)
class RemoteAccessHandler(object):
    """
    Keeps the remote access (JIT) user accounts on the VM in line with the
    remote access section of the goal state.
    """

    def __init__(self, protocol):
        self._os_util = get_osutil()
        self._protocol = protocol
        self._cryptUtil = CryptUtil(conf.get_openssl_cmd())
        self._remote_access = None
        self._incarnation = 0
        self._check_existing_jit_users = True

    def run(self):
        """
        Re-process remote access users when JIT is enabled and the goal state
        incarnation changed.  Any failure is reported as a telemetry event
        instead of being propagated.
        """
        try:
            if not self._os_util.jit_enabled:
                return
            latest_incarnation = self._protocol.get_incarnation()
            if self._incarnation != latest_incarnation:
                # something changed. Handle remote access if any.
                self._incarnation = latest_incarnation
                self._remote_access = self._protocol.client.get_remote_access()
                self._handle_remote_access()
        except Exception as e:
            msg = u"Exception processing goal state for remote access users: {0} {1}".format(
                ustr(e), traceback.format_exc())
            add_event(AGENT_NAME,
                      version=CURRENT_VERSION,
                      op=WALAEventOperation.RemoteAccessHandling,
                      is_success=False,
                      message=msg)

    def _get_existing_jit_users(self):
        """Return the names (field 0) of the OS users whose comment (field 4) marks them as JIT users."""
        jit_names = set()
        for entry in self._os_util.get_users():
            if self._is_jit_user(entry[4]):
                jit_names.add(entry[0])
        return jit_names

    def _handle_remote_access(self):
        """Add, expire and remove JIT accounts according to ``self._remote_access``."""
        if self._remote_access is None:
            # There are no JIT users in the goal state; that may mean that they were removed or that they
            # were never added. Enumerating the users on the current vm can be very slow and this path is hit
            # on each goal state; we use self._check_existing_jit_users to avoid enumerating the users
            # every single time.
            if not self._check_existing_jit_users:
                return
            logger.info("Looking for existing remote access users.")
            had_errors = False
            for user in self._get_existing_jit_users():
                try:
                    self._remove_user(user)
                except Exception as e:
                    logger.error("Error removing remote access user '{0}' - {1}", user, ustr(e))
                    had_errors = True
            if not had_errors:
                self._check_existing_jit_users = False
            return

        logger.info("Processing remote access users in goal state.")
        self._check_existing_jit_users = True

        existing_jit_users = self._get_existing_jit_users()
        accounts = self._remote_access.user_list.users
        goal_state_users = set(account.name for account in accounts)

        for acc in accounts:
            try:
                account_expiration = datetime.strptime(acc.expiration,
                                                       REMOTE_USR_EXPIRATION_FORMAT)
                now = datetime.utcnow()
                already_present = acc.name in existing_jit_users
                if not already_present and now < account_expiration:
                    self._add_user(acc.name, acc.encrypted_password, account_expiration)
                elif already_present and now > account_expiration:
                    # user account expired, delete it.
                    logger.info("Remote access user '{0}' expired.", acc.name)
                    self._remove_user(acc.name)
            except Exception as e:
                logger.error("Error processing remote access user '{0}' - {1}", acc.name, ustr(e))

        for user in existing_jit_users:
            try:
                if user in goal_state_users:
                    continue
                # user explicitly removed
                self._remove_user(user)
            except Exception as e:
                logger.error("Error removing remote access user '{0}' - {1}", user, ustr(e))

    @staticmethod
    def _is_jit_user(comment):
        """Tell whether an account comment is the remote access marker."""
        return comment == REMOTE_ACCESS_ACCOUNT_COMMENT

    def _add_user(self, username, encrypted_password, account_expiration):
        """
        Create the account, set its decrypted password and add it to sudoers.
        If a step fails after the account was created, the account is removed
        again before the exception is re-raised.
        """
        created = False
        try:
            expiration_date = (account_expiration + timedelta(days=1)).strftime(DATE_FORMAT)
            logger.info("Adding remote access user '{0}' with expiration date {1}",
                        username, expiration_date)
            self._os_util.useradd(username, expiration_date, REMOTE_ACCESS_ACCOUNT_COMMENT)
            created = True

            logger.info("Adding remote access user '{0}' to sudoers", username)
            prv_key = os.path.join(conf.get_lib_dir(), TRANSPORT_PRIVATE_CERT)
            pwd = self._cryptUtil.decrypt_secret(encrypted_password, prv_key)
            self._os_util.chpasswd(username, pwd, conf.get_password_cryptid(),
                                   conf.get_password_crypt_salt_len())
            self._os_util.conf_sudoer(username)
        except Exception:
            if created:
                self._remove_user(username)
            raise

    def _remove_user(self, username):
        """Delete the named account through the OS utility."""
        logger.info("Removing remote access user '{0}'", username)
        self._os_util.del_account(username)
def __init__(self, protocol):
    """Initialise the handler around ``protocol``; existing JIT users are checked on first use."""
    # Collaborators.
    self._os_util = get_osutil()
    self._protocol = protocol
    self._cryptUtil = CryptUtil(conf.get_openssl_cmd())

    # Mutable state.
    self._check_existing_jit_users = True
    self._remote_access = None
class RemoteAccessHandler(object):
    """
    Keeps the remote access (JIT) user accounts on the VM in line with the
    remote access section of the goal state.

    Per-user failures met while handling a goal state are accumulated in
    ``self.error_message``.
    """

    def __init__(self):
        self.os_util = get_osutil()
        self.protocol_util = get_protocol_util()
        self.protocol = None
        self.cryptUtil = CryptUtil(conf.get_openssl_cmd())
        self.remote_access = None
        self.incarnation = 0
        self.error_message = ""

    def run(self):
        """
        Re-process remote access users when JIT is enabled and the goal state
        incarnation changed.  Any failure is logged and reported as a
        telemetry event instead of being propagated.
        """
        try:
            if self.os_util.jit_enabled:
                self.protocol = self.protocol_util.get_protocol()
                current_incarnation = self.protocol.get_incarnation()
                if self.incarnation != current_incarnation:
                    # something changed. Handle remote access if any.
                    self.incarnation = current_incarnation
                    self.remote_access = self.protocol.client.get_remote_access()
                    self.handle_remote_access()
        except Exception as e:
            msg = u"Exception processing remote access handler: {0} {1}".format(
                ustr(e), traceback.format_exc())
            logger.error(msg)
            add_event(AGENT_NAME,
                      version=CURRENT_VERSION,
                      op=WALAEventOperation.RemoteAccessHandling,
                      is_success=False,
                      message=msg)

    def handle_remote_access(self):
        """
        Add, expire and remove JIT accounts according to ``self.remote_access``.

        RemoteAccessError raised for an individual user is appended to
        ``self.error_message`` and processing continues with the next user.
        """
        # Get JIT user accounts.
        all_users = self.os_util.get_users()
        existing_jit_users = set(u[0] for u in all_users if self.validate_jit_user(u[4]))
        # Reset the same attribute __init__ defines; this method used to write
        # to a differently named attribute, leaving error_message always empty.
        self.error_message = ""
        if self.remote_access is not None:
            goal_state_users = set(u.name for u in self.remote_access.user_list.users)
            for acc in self.remote_access.user_list.users:
                try:
                    raw_expiration = acc.expiration
                    account_expiration = datetime.strptime(raw_expiration,
                                                           REMOTE_USR_EXPIRATION_FORMAT)
                    now = datetime.utcnow()
                    if acc.name not in existing_jit_users and now < account_expiration:
                        self.add_user(acc.name, acc.encrypted_password, account_expiration)
                    elif acc.name in existing_jit_users and now > account_expiration:
                        # user account expired, delete it.
                        logger.info("user {0} expired from remote_access".format(acc.name))
                        self.remove_user(acc.name)
                except RemoteAccessError as rae:
                    self.error_message += "Error processing user {0}. Exception: {1}" \
                        .format(acc.name, ustr(rae))
            for user in existing_jit_users:
                try:
                    if user not in goal_state_users:
                        # user explicitly removed
                        logger.info("User {0} removed from remote_access".format(user))
                        self.remove_user(user)
                except RemoteAccessError as rae:
                    self.error_message += "Error removing user {0}. Exception: {1}" \
                        .format(user, ustr(rae))
        else:
            # All users removed, remove any remaining JIT accounts.
            for user in existing_jit_users:
                try:
                    logger.info("User {0} removed from remote_access. remote_access empty"
                                .format(user))
                    self.remove_user(user)
                except RemoteAccessError as rae:
                    self.error_message += "Error removing user {0}. Exception: {1}" \
                        .format(user, ustr(rae))

    def validate_jit_user(self, comment):
        """Tell whether an account comment is the remote access marker."""
        return comment == REMOTE_ACCESS_ACCOUNT_COMMENT

    def add_user(self, username, encrypted_password, account_expiration):
        """
        Create the account, set its decrypted password and add it to sudoers.

        Raises RemoteAccessError when the account cannot be created, or when
        a later step fails; in the latter case the account is cleaned up
        first and the message states whether that cleanup succeeded.
        """
        try:
            expiration_date = (account_expiration + timedelta(days=1)).strftime(DATE_FORMAT)
            logger.verbose("Adding user {0} with expiration date {1}".format(
                username, expiration_date))
            self.os_util.useradd(username, expiration_date, REMOTE_ACCESS_ACCOUNT_COMMENT)
        except Exception as e:
            raise RemoteAccessError("Error adding user {0}. {1}".format(username, ustr(e)))
        try:
            prv_key = os.path.join(conf.get_lib_dir(), TRANSPORT_PRIVATE_CERT)
            pwd = self.cryptUtil.decrypt_secret(encrypted_password, prv_key)
            self.os_util.chpasswd(username, pwd, conf.get_password_cryptid(),
                                  conf.get_password_crypt_salt_len())
            self.os_util.conf_sudoer(username)
            logger.info("User '{0}' added successfully with expiration in {1}".format(
                username, expiration_date))
        except Exception as e:
            error = "Error adding user {0}. {1} ".format(username, ustr(e))
            try:
                self.handle_failed_create(username)
                error += "cleanup successful"
            except RemoteAccessError as rae:
                error += "and error cleaning up {0}".format(ustr(rae))
            # Raise the message actually built above; the previous code
            # discarded it and always claimed the cleanup had succeeded.
            raise RemoteAccessError(error, ustr(e))

    def handle_failed_create(self, username):
        """Delete a half-created account; failures surface as RemoteAccessError."""
        try:
            self.delete_user(username)
        except Exception as e:
            raise RemoteAccessError(
                "Failed to clean up after account creation for {0}.".format(username), e)

    def remove_user(self, username):
        """Delete the named account; failures surface as RemoteAccessError."""
        try:
            self.delete_user(username)
        except Exception as e:
            raise RemoteAccessError("Failed to delete user {0}".format(username), e)

    def delete_user(self, username):
        """Delete the named account through the OS utility and log it."""
        self.os_util.del_account(username)
        logger.info("User deleted {0}".format(username))
def mock_crypt_util(self, *args, **kw):
    """
    Build a real CryptUtil from the given arguments, with only
    ``gen_transport_cert`` replaced by a Mock that delegates to
    ``self.mock_gen_trans_cert``.
    """
    # Partially patch instance method of class CryptUtil
    patched = CryptUtil(*args, **kw)
    stub = Mock(side_effect=self.mock_gen_trans_cert)
    patched.gen_transport_cert = stub
    return patched
class RemoteAccessHandler(object):
    """
    Creates the remote access (JIT) user accounts listed in the goal state
    that do not exist on the VM yet and have not expired.
    """

    def __init__(self):
        self.os_util = get_osutil()
        self.protocol_util = get_protocol_util()
        self.protocol = None
        self.cryptUtil = CryptUtil(conf.get_openssl_cmd())
        self.remote_access = None
        self.incarnation = 0

    def run(self):
        """
        Re-process remote access users when JIT is enabled and the goal state
        incarnation changed.  Any failure is logged and reported as a
        telemetry event instead of being propagated.
        """
        try:
            if self.os_util.jit_enabled:
                self.protocol = self.protocol_util.get_protocol()
                current_incarnation = self.protocol.get_incarnation()
                if self.incarnation != current_incarnation:
                    # something changed. Handle remote access if any.
                    self.incarnation = current_incarnation
                    self.remote_access = self.protocol.client.get_remote_access()
                    if self.remote_access is not None:
                        self.handle_remote_access()
        except Exception as e:
            msg = u"Exception processing remote access handler: {0} {1}".format(
                ustr(e), traceback.format_exc())
            logger.error(msg)
            add_event(AGENT_NAME,
                      version=CURRENT_VERSION,
                      op=WALAEventOperation.RemoteAccessHandling,
                      is_success=False,
                      message=msg)

    def handle_remote_access(self):
        """Add every goal state user that is not yet a JIT account and has not expired."""
        if self.remote_access is not None:
            # Get JIT user accounts.
            all_users = self.os_util.get_users()
            jit_users = set()
            for usr in all_users:
                if self.validate_jit_user(usr[4]):
                    jit_users.add(usr[0])
            for acc in self.remote_access.user_list.users:
                raw_expiration = acc.expiration
                account_expiration = datetime.strptime(raw_expiration,
                                                       REMOTE_USR_EXPIRATION_FORMAT)
                now = datetime.utcnow()
                if acc.name not in jit_users and now < account_expiration:
                    self.add_user(acc.name, acc.encrypted_password, account_expiration)

    def validate_jit_user(self, comment):
        """Tell whether an account comment is the remote access marker."""
        return comment == REMOTE_ACCESS_ACCOUNT_COMMENT

    def add_user(self, username, encrypted_password, account_expiration):
        """
        Create the account, set its decrypted password and add it to sudoers.

        Errors are logged, never raised.  If a step after account creation
        fails, the account is deleted again via ``handle_failed_create``.
        """
        try:
            expiration_date = (account_expiration + timedelta(days=1)).strftime(DATE_FORMAT)
            logger.verbose("Adding user {0} with expiration date {1}"
                           .format(username, expiration_date))
            self.os_util.useradd(username, expiration_date, REMOTE_ACCESS_ACCOUNT_COMMENT)
        except OSError as oe:
            logger.error("Error adding user {0}. {1}"
                         .format(username, oe.strerror))
            return
        except Exception as e:
            logger.error("Error adding user {0}. {1}".format(username, ustr(e)))
            return
        try:
            prv_key = os.path.join(conf.get_lib_dir(), TRANSPORT_PRIVATE_CERT)
            pwd = self.cryptUtil.decrypt_secret(encrypted_password, prv_key)
            self.os_util.chpasswd(username, pwd, conf.get_password_cryptid(),
                                  conf.get_password_crypt_salt_len())
            self.os_util.conf_sudoer(username)
            logger.info("User '{0}' added successfully with expiration in {1}"
                        .format(username, expiration_date))
            return
        except OSError as oe:
            self.handle_failed_create(username, oe.strerror)
        except Exception as e:
            self.handle_failed_create(username, ustr(e))

    def handle_failed_create(self, username, error_message):
        """Log the creation failure and try to delete the half-created account."""
        logger.error("Error creating user {0}. {1}"
                     .format(username, error_message))
        try:
            self.delete_user(username)
        except OSError as oe:
            # strerror is a plain attribute; calling it raised TypeError and
            # turned a logged cleanup failure into an unhandled exception.
            logger.error("Failed to clean up after account creation for {0}. {1}"
                         .format(username, oe.strerror))
        except Exception as e:
            logger.error("Failed to clean up after account creation for {0}. {1}"
                         .format(username, str(e)))

    def delete_user(self, username):
        """Delete the named account through the OS utility and log it."""
        self.os_util.del_account(username)
        logger.info("User deleted {0}".format(username))
def mock_crypt_util(self, *args, **kw):
    """
    Build a real CryptUtil from the given arguments, with only
    ``gen_transport_cert`` replaced by a Mock that delegates to
    ``self.mock_gen_trans_cert``.
    """
    # Partially patch instance method of class CryptUtil
    patched = CryptUtil(*args, **kw)
    stub = Mock(side_effect=self.mock_gen_trans_cert)
    patched.gen_transport_cert = stub
    return patched
def test_get_pubkey_from_crt_invalid_file(self):
    """Extracting a public key from a key file that does not exist raises IOError."""
    missing_key = os.path.join(data_dir, "wire", "trans_prv_does_not_exist")
    crypt_util = CryptUtil(conf.get_openssl_cmd())
    # NOTE(review): despite the test name, the call exercised is get_pubkey_from_prv.
    with self.assertRaises(IOError):
        crypt_util.get_pubkey_from_prv(missing_key)
def parse(self, json_text):
    """
    Parse multiple certificates into separate files.

    ``json_text["certificateData"]`` is base64-decoded into a p7b file,
    re-rendered with ``openssl pkcs7``, wrapped into a p7m message and
    decrypted with the transport key pair into a PEM bundle.  The bundle is
    split into one file per certificate / private key.  Certificates are
    renamed to ``<thumbprint>.crt``; private keys whose public key matches a
    certificate are renamed to ``<thumbprint>.prv``.  Every certificate found
    is appended to ``self.cert_list.certificates``.

    Returns without doing anything when ``certificateData`` is None.
    """
    data = json_text["certificateData"]
    if data is None:
        logger.verbose("No data in json_text received!")
        return

    cryptutil = CryptUtil(conf.get_openssl_cmd())
    p7b_file = os.path.join(conf.get_lib_dir(), P7B_FILE_NAME)

    # Wrapping the certificate lines.
    # decode and save the result into p7b_file
    fileutil.write_file(p7b_file, base64.b64decode(data), asbin=True)

    ssl_cmd = "openssl pkcs7 -text -in {0} -inform der | grep -v '^-----' "
    # NOTE(review): the exit status is not checked here.
    ret, data = shellutil.run_get_output(ssl_cmd.format(p7b_file))

    p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
    p7m = ("MIME-Version:1.0\n"
           "Content-Disposition: attachment; filename=\"{0}\"\n"
           "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
           "Content-Transfer-Encoding: base64\n"
           "\n"
           "{2}").format(p7m_file, p7m_file, data)

    self.save_cache(p7m_file, p7m)

    trans_prv_file = os.path.join(conf.get_lib_dir(), TRANSPORT_PRV_FILE_NAME)
    trans_cert_file = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME)
    pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
    # decrypt certificates
    cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, pem_file)

    # The parsing process uses the public key to match prv and crt.
    buf = []
    prvs = {}         # public key -> temp file holding the private key
    thumbprints = {}  # public key -> certificate thumbprint
    index = 0
    v1_cert_list = []
    with open(pem_file) as pem:
        for line in pem:
            buf.append(line)
            # Only END markers trigger work; everything else just accumulates.
            if re.match(r'[-]+END.*KEY[-]+', line):
                tmp_file = self.write_to_tmp_file(index, 'prv', buf)
                pub = cryptutil.get_pubkey_from_prv(tmp_file)
                prvs[pub] = tmp_file
                buf = []
                index += 1
            elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
                tmp_file = self.write_to_tmp_file(index, 'crt', buf)
                pub = cryptutil.get_pubkey_from_crt(tmp_file)
                thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
                thumbprints[pub] = thumbprint
                # Rename crt with thumbprint as the file name
                crt = "{0}.crt".format(thumbprint)
                v1_cert_list.append({
                    "name": None,
                    "thumbprint": thumbprint
                })
                os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
                buf = []
                index += 1

    # Rename prv key with thumbprint as the file name
    for pubkey, tmp_file in prvs.items():
        # .get(): a private key without a matching certificate must be
        # skipped, not abort the whole parse with a KeyError.
        thumbprint = thumbprints.get(pubkey)
        if thumbprint:
            prv = "{0}.prv".format(thumbprint)
            os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))

    for v1_cert in v1_cert_list:
        cert = Cert()
        set_properties("certs", cert, v1_cert)
        self.cert_list.certificates.append(cert)
def parse(self, json_text):
    """
    Parse multiple certificates into separate files.

    ``json_text["certificateData"]`` is base64-decoded into a p7b file,
    re-rendered with ``openssl pkcs7``, wrapped into a p7m message and
    decrypted with the transport key pair into a PEM bundle.  The bundle is
    split into one file per certificate / private key.  Certificates are
    renamed to ``<thumbprint>.crt``; private keys whose public key matches a
    certificate are renamed to ``<thumbprint>.prv``.  Every certificate found
    is appended to ``self.cert_list.certificates``.

    Returns without doing anything when ``certificateData`` is None.
    """
    # Local import so this block does not depend on the module's import list.
    import base64

    data = json_text["certificateData"]
    if data is None:
        logger.verbose("No data in json_text received!")
        return

    cryptutil = CryptUtil(conf.get_openssl_cmd())
    p7b_file = os.path.join(conf.get_lib_dir(), P7B_FILE_NAME)

    # Wrapping the certificate lines.
    # Decode in-process: interpolating the payload into an
    # "echo ... | base64 -d" shell command let its content be interpreted
    # by the shell and could exceed the command-line length limit.
    with open(p7b_file, "wb") as p7b:
        p7b.write(base64.b64decode(data))

    ssl_cmd = "openssl pkcs7 -text -in {0} -inform der | grep -v '^-----' "
    # NOTE(review): the exit status is not checked here.
    ret, data = shellutil.run_get_output(ssl_cmd.format(p7b_file))

    p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
    p7m = ("MIME-Version:1.0\n"
           "Content-Disposition: attachment; filename=\"{0}\"\n"
           "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
           "Content-Transfer-Encoding: base64\n"
           "\n"
           "{2}").format(p7m_file, p7m_file, data)

    self.save_cache(p7m_file, p7m)

    trans_prv_file = os.path.join(conf.get_lib_dir(), TRANSPORT_PRV_FILE_NAME)
    trans_cert_file = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME)
    pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
    # decrypt certificates
    cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, pem_file)

    # The parsing process uses the public key to match prv and crt.
    buf = []
    prvs = {}         # public key -> temp file holding the private key
    thumbprints = {}  # public key -> certificate thumbprint
    index = 0
    v1_cert_list = []
    with open(pem_file) as pem:
        for line in pem:
            buf.append(line)
            # Only END markers trigger work; everything else just accumulates.
            if re.match(r'[-]+END.*KEY[-]+', line):
                tmp_file = self.write_to_tmp_file(index, 'prv', buf)
                pub = cryptutil.get_pubkey_from_prv(tmp_file)
                prvs[pub] = tmp_file
                buf = []
                index += 1
            elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
                tmp_file = self.write_to_tmp_file(index, 'crt', buf)
                pub = cryptutil.get_pubkey_from_crt(tmp_file)
                thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
                thumbprints[pub] = thumbprint
                # Rename crt with thumbprint as the file name
                crt = "{0}.crt".format(thumbprint)
                v1_cert_list.append({
                    "name": None,
                    "thumbprint": thumbprint
                })
                os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
                buf = []
                index += 1

    # Rename prv key with thumbprint as the file name
    for pubkey, tmp_file in prvs.items():
        # .get(): a private key without a matching certificate must be
        # skipped, not abort the whole parse with a KeyError.
        thumbprint = thumbprints.get(pubkey)
        if thumbprint:
            prv = "{0}.prv".format(thumbprint)
            os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))

    for v1_cert in v1_cert_list:
        cert = Cert()
        set_properties("certs", cert, v1_cert)
        self.cert_list.certificates.append(cert)
class RemoteAccessHandler(object):
    """
    Creates the remote access (JIT) user accounts listed in the goal state
    that do not exist on the VM yet and have not expired.
    """

    def __init__(self):
        self.os_util = get_osutil()
        self.protocol_util = get_protocol_util()
        self.protocol = None
        self.cryptUtil = CryptUtil(conf.get_openssl_cmd())
        self.remote_access = None
        self.incarnation = 0

    def run(self):
        """
        Re-process remote access users when JIT is enabled and the goal state
        incarnation changed.  Any failure is logged and reported as a
        telemetry event instead of being propagated.
        """
        try:
            if self.os_util.jit_enabled:
                self.protocol = self.protocol_util.get_protocol()
                current_incarnation = self.protocol.get_incarnation()
                if self.incarnation != current_incarnation:
                    # something changed. Handle remote access if any.
                    self.incarnation = current_incarnation
                    self.remote_access = self.protocol.client.get_remote_access()
                    if self.remote_access is not None:
                        self.handle_remote_access()
        except Exception as e:
            msg = u"Exception processing remote access handler: {0} {1}".format(
                ustr(e), traceback.format_exc())
            logger.error(msg)
            add_event(AGENT_NAME,
                      version=CURRENT_VERSION,
                      op=WALAEventOperation.RemoteAccessHandling,
                      is_success=False,
                      message=msg)

    def handle_remote_access(self):
        """Add every goal state user that is not yet a JIT account and has not expired."""
        if self.remote_access is not None:
            # Get JIT user accounts.
            all_users = self.os_util.get_users()
            jit_users = set()
            for usr in all_users:
                if self.validate_jit_user(usr[4]):
                    jit_users.add(usr[0])
            for acc in self.remote_access.user_list.users:
                raw_expiration = acc.expiration
                account_expiration = datetime.strptime(raw_expiration,
                                                       REMOTE_USR_EXPIRATION_FORMAT)
                now = datetime.utcnow()
                if acc.name not in jit_users and now < account_expiration:
                    self.add_user(acc.name, acc.encrypted_password, account_expiration)

    def validate_jit_user(self, comment):
        """Tell whether an account comment is the remote access marker."""
        return comment == REMOTE_ACCESS_ACCOUNT_COMMENT

    def add_user(self, username, encrypted_password, account_expiration):
        """
        Create the account, set its decrypted password and add it to sudoers.

        Errors are logged, never raised.  If a step after account creation
        fails, the account is deleted again via ``handle_failed_create``.
        """
        try:
            expiration_date = (account_expiration + timedelta(days=1)).strftime(DATE_FORMAT)
            logger.verbose("Adding user {0} with expiration date {1}".format(
                username, expiration_date))
            self.os_util.useradd(username, expiration_date, REMOTE_ACCESS_ACCOUNT_COMMENT)
        except OSError as oe:
            logger.error("Error adding user {0}. {1}".format(username, oe.strerror))
            return
        except Exception as e:
            logger.error("Error adding user {0}. {1}".format(username, ustr(e)))
            return
        try:
            prv_key = os.path.join(conf.get_lib_dir(), TRANSPORT_PRIVATE_CERT)
            pwd = self.cryptUtil.decrypt_secret(encrypted_password, prv_key)
            self.os_util.chpasswd(username, pwd, conf.get_password_cryptid(),
                                  conf.get_password_crypt_salt_len())
            self.os_util.conf_sudoer(username)
            logger.info("User '{0}' added successfully with expiration in {1}".format(
                username, expiration_date))
            return
        except OSError as oe:
            self.handle_failed_create(username, oe.strerror)
        except Exception as e:
            self.handle_failed_create(username, ustr(e))

    def handle_failed_create(self, username, error_message):
        """Log the creation failure and try to delete the half-created account."""
        logger.error("Error creating user {0}. {1}".format(username, error_message))
        try:
            self.delete_user(username)
        except OSError as oe:
            # strerror is a plain attribute; calling it raised TypeError and
            # turned a logged cleanup failure into an unhandled exception.
            logger.error("Failed to clean up after account creation for {0}. {1}".format(
                username, oe.strerror))
        except Exception as e:
            logger.error("Failed to clean up after account creation for {0}. {1}".format(
                username, str(e)))

    def delete_user(self, username):
        """Delete the named account through the OS utility and log it."""
        self.os_util.del_account(username)
        logger.info("User deleted {0}".format(username))
def test_decrypt_encrypted_text_missing_private_key(self):
    """Decrypting with a private key path that was never written raises CryptError."""
    ciphertext = load_data("wire/encrypted.enc")
    key_path = os.path.join(self.tmp_dir, "TransportPrivate.pem")
    # Prefixing the path guarantees it points at a non-existent file.
    bogus_key_path = "abc" + key_path

    crypt_util = CryptUtil(conf.get_openssl_cmd())
    with self.assertRaises(CryptError):
        crypt_util.decrypt_secret(ciphertext, bogus_key_path)
def __init__(self, xml_text):
    """Save the certificates XML and split its payload into per-certificate files.

    The XML is written to the lib directory, its ``Data`` payload is wrapped
    in a MIME envelope and decrypted into a PEM file, and each private key and
    certificate found in that PEM is written to its own file named after the
    certificate thumbprint (``<thumbprint>.prv`` / ``<thumbprint>.crt``).
    ``self.cert_list`` is filled with one entry per certificate found.

    Returns early (leaving ``cert_list`` empty) when the XML has no ``Data``
    element or declares a ``Format`` other than ``Pkcs7BlobWithPfxContents``.

    :param xml_text: the certificates document as text
    """
    self.cert_list = CertList()
    lib_dir = conf.get_lib_dir()

    # Save the certificates
    local_file = os.path.join(lib_dir, CERTS_FILE_NAME)
    fileutil.write_file(local_file, xml_text)

    # Separate the certificates into individual files.
    xml_doc = parse_doc(xml_text)
    data = findtext(xml_doc, "Data")
    if data is None:
        return

    # if the certificates format is not Pkcs7BlobWithPfxContents do not parse it
    certificateFormat = findtext(xml_doc, "Format")
    if certificateFormat and certificateFormat != "Pkcs7BlobWithPfxContents":
        logger.warn("The Format is not Pkcs7BlobWithPfxContents. Format is " + certificateFormat)
        return

    cryptutil = CryptUtil(conf.get_openssl_cmd())

    p7m_file = os.path.join(lib_dir, P7M_FILE_NAME)
    p7m = ("MIME-Version:1.0\n"  # pylint: disable=W1308
           "Content-Disposition: attachment; filename=\"{0}\"\n"
           "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
           "Content-Transfer-Encoding: base64\n"
           "\n"
           "{2}").format(p7m_file, p7m_file, data)
    fileutil.write_file(p7m_file, p7m)

    trans_prv_file = os.path.join(lib_dir, TRANSPORT_PRV_FILE_NAME)
    trans_cert_file = os.path.join(lib_dir, TRANSPORT_CERT_FILE_NAME)
    pem_file = os.path.join(lib_dir, PEM_FILE_NAME)
    # decrypt certificates
    cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, pem_file)

    # The parsing process use public key to match prv and crt.
    buf = []
    prvs = {}         # public key -> temporary file holding the private key
    thumbprints = {}  # public key -> thumbprint of the matching certificate
    index = 0
    v1_cert_list = []

    with open(pem_file) as pem:
        for line in pem:
            buf.append(line)
            # Only END markers trigger work: everything buffered since the
            # previous END marker is flushed as one key or one certificate.
            if re.match(r'[-]+END.*KEY[-]+', line):
                tmp_file = Certificates._write_to_tmp_file(index, 'prv', buf)
                pub = cryptutil.get_pubkey_from_prv(tmp_file)
                prvs[pub] = tmp_file
                buf = []
                index += 1
            elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
                tmp_file = Certificates._write_to_tmp_file(index, 'crt', buf)
                pub = cryptutil.get_pubkey_from_crt(tmp_file)
                thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
                thumbprints[pub] = thumbprint
                # Rename crt with thumbprint as the file name
                crt = "{0}.crt".format(thumbprint)
                v1_cert_list.append({
                    "name": None,
                    "thumbprint": thumbprint
                })
                os.rename(tmp_file, os.path.join(lib_dir, crt))
                buf = []
                index += 1

    # Rename prv key with thumbprint as the file name
    for pubkey, tmp_file in prvs.items():
        # .get() rather than indexing: a private key without a certificate has
        # no entry here, and indexing would raise KeyError before the warning
        # below could ever be logged.
        thumbprint = thumbprints.get(pubkey)
        if thumbprint:
            prv = "{0}.prv".format(thumbprint)
            os.rename(tmp_file, os.path.join(lib_dir, prv))
            logger.info("Found private key matching thumbprint {0}".format(thumbprint))
        else:
            # Since private key has *no* matching certificate,
            # it will not be named correctly
            logger.warn("Found NO matching cert/thumbprint for private key!")

    # Log if any certificates were found without matching private keys
    # This can happen (rarely), and is useful to know for debugging
    for pubkey in thumbprints:
        if pubkey not in prvs:
            msg = "Certificate with thumbprint {0} has no matching private key."
            logger.info(msg.format(thumbprints[pubkey]))

    for v1_cert in v1_cert_list:
        cert = Cert()
        set_properties("certs", cert, v1_cert)
        self.cert_list.certificates.append(cert)
class RemoteAccessHandler(object):
    """Reconcile local just-in-time (JIT) accounts with the goal state.

    When the protocol incarnation changes, the handler fetches the remote
    access description and adds, expires or removes local JIT accounts so they
    match it. Per-user failures are accumulated as text in ``err_message``
    (also readable as ``error_message``).
    """

    def __init__(self):
        self.os_util = get_osutil()
        self.protocol_util = get_protocol_util()
        self.protocol = None
        self.cryptUtil = CryptUtil(conf.get_openssl_cmd())
        self.remote_access = None
        self.incarnation = 0
        # Backing store for the accumulated per-user error text.
        self.err_message = ""

    @property
    def error_message(self):
        """Alias of ``err_message`` so both attribute names report the same text."""
        return self.err_message

    @error_message.setter
    def error_message(self, value):
        self.err_message = value

    def run(self):
        """Process remote access once; never raises.

        Any exception is logged and reported through ``add_event``.
        """
        try:
            if self.os_util.jit_enabled:
                self.protocol = self.protocol_util.get_protocol()
                current_incarnation = self.protocol.get_incarnation()
                if self.incarnation != current_incarnation:
                    # something changed. Handle remote access if any.
                    self.incarnation = current_incarnation
                    self.remote_access = self.protocol.client.get_remote_access()
                    self.handle_remote_access()
        except Exception as e:
            msg = u"Exception processing remote access handler: {0} {1}".format(
                ustr(e), traceback.format_exc())
            logger.error(msg)
            add_event(AGENT_NAME,
                      version=CURRENT_VERSION,
                      op=WALAEventOperation.RemoteAccessHandling,
                      is_success=False,
                      message=msg)

    def handle_remote_access(self):
        """Add, expire and remove JIT accounts to match ``self.remote_access``.

        * A goal-state account that is missing locally and not yet expired is added.
        * A goal-state account that exists locally and has expired is removed.
        * A local JIT account absent from the goal state is removed.
        * When ``self.remote_access`` is None, every local JIT account is removed.

        ``RemoteAccessError`` raised for one user is appended to
        ``err_message`` and processing continues with the next user.
        """
        # Get JIT user accounts.
        all_users = self.os_util.get_users()
        existing_jit_users = set(u[0] for u in all_users if self.validate_jit_user(u[4]))
        self.err_message = ""

        if self.remote_access is not None:
            goal_state_users = set(u.name for u in self.remote_access.user_list.users)

            for acc in self.remote_access.user_list.users:
                try:
                    raw_expiration = acc.expiration
                    account_expiration = datetime.strptime(raw_expiration,
                                                           REMOTE_USR_EXPIRATION_FORMAT)
                    now = datetime.utcnow()
                    if acc.name not in existing_jit_users and now < account_expiration:
                        self.add_user(acc.name, acc.encrypted_password, account_expiration)
                    elif acc.name in existing_jit_users and now > account_expiration:
                        # user account expired, delete it.
                        logger.info("user {0} expired from remote_access".format(acc.name))
                        self.remove_user(acc.name)
                except RemoteAccessError as rae:
                    self.err_message = self.err_message + \
                        "Error processing user {0}. Exception: {1}".format(acc.name, ustr(rae))

            for user in existing_jit_users:
                try:
                    if user not in goal_state_users:
                        # user explicitly removed
                        logger.info("User {0} removed from remote_access".format(user))
                        self.remove_user(user)
                except RemoteAccessError as rae:
                    self.err_message = self.err_message + \
                        "Error removing user {0}. Exception: {1}".format(user, ustr(rae))
        else:
            # All users removed, remove any remaining JIT accounts.
            for user in existing_jit_users:
                try:
                    logger.info(
                        "User {0} removed from remote_access. remote_access empty".format(user))
                    self.remove_user(user)
                except RemoteAccessError as rae:
                    self.err_message = self.err_message + \
                        "Error removing user {0}. Exception: {1}".format(user, ustr(rae))

    def validate_jit_user(self, comment):
        """Return True when ``comment`` is the marker used for accounts created by this handler."""
        return comment == REMOTE_ACCESS_ACCOUNT_COMMENT

    def add_user(self, username, encrypted_password, account_expiration):
        """Create ``username``, set its decrypted password and configure sudo.

        :param username: name of the account to create
        :param encrypted_password: password encrypted for the transport private key
        :param account_expiration: datetime after which the account is expired
        :raises RemoteAccessError: if the account cannot be created, or if a
            later step fails; in the latter case the account is rolled back
            first and the message states whether the rollback succeeded.
        """
        try:
            # The OS-level expiration is set one day after the goal-state expiration.
            expiration_date = (account_expiration + timedelta(days=1)).strftime(DATE_FORMAT)
            logger.verbose("Adding user {0} with expiration date {1}".format(
                username, expiration_date))
            self.os_util.useradd(username, expiration_date, REMOTE_ACCESS_ACCOUNT_COMMENT)
        except Exception as e:
            raise RemoteAccessError("Error adding user {0}. {1}".format(username, ustr(e)))

        try:
            prv_key = os.path.join(conf.get_lib_dir(), TRANSPORT_PRIVATE_CERT)
            pwd = self.cryptUtil.decrypt_secret(encrypted_password, prv_key)
            self.os_util.chpasswd(username, pwd,
                                  conf.get_password_cryptid(),
                                  conf.get_password_crypt_salt_len())
            self.os_util.conf_sudoer(username)
            logger.info("User '{0}' added successfully with expiration in {1}".format(
                username, expiration_date))
        except Exception as e:
            # Roll back the half-configured account, and report the real
            # outcome of that rollback instead of always claiming success.
            try:
                self.handle_failed_create(username)
                error = "Error adding user {0} cleanup successful".format(username)
            except RemoteAccessError as rae:
                error = "Error adding user {0} and error cleaning up {1}".format(
                    username, ustr(rae))
            raise RemoteAccessError(error, ustr(e))

    def handle_failed_create(self, username):
        """Delete an account whose setup failed.

        :raises RemoteAccessError: if the deletion itself fails
        """
        try:
            self.delete_user(username)
        except Exception as e:
            raise RemoteAccessError(
                "Failed to clean up after account creation for {0}.".format(username), e)

    def remove_user(self, username):
        """Delete ``username``, wrapping any failure in RemoteAccessError."""
        try:
            self.delete_user(username)
        except Exception as e:
            raise RemoteAccessError("Failed to delete user {0}".format(username), e)

    def delete_user(self, username):
        """Delete the account ``username`` and log the deletion."""
        self.os_util.del_account(username)
        logger.info("User deleted {0}".format(username))