Exemple #1
0
    def dc_watcher(self):
        """Fork a child process that collects SamLogon events from the DC.

        The parent returns immediately with the read end of a pipe.

        The child:
        - loads the Domain Controller's configuration ($DC_SERVERCONFFILE);
        - listens for MSG_AUTH_LOG messages for about one second;
        - writes every SamLogon authentication message it received to the
          pipe, one per line, or the single line "None" if there were none;
        - exits.

        As other SamLogons may be in progress, the caller has to match the
        messages to its own session.

        :return: the read end of the pipe (returned in the parent only).
        """
        (r1, w1) = os.pipe()
        pid = os.fork()
        if pid != 0:
            # Parent: it never writes to the pipe. Close its copy of the
            # write end; otherwise a reader never sees EOF once the child
            # has finished.
            os.close(w1)
            return r1

        # Child: it never reads from the pipe.
        os.close(r1)

        import sys
        import traceback

        # The forked child must never return into the caller's code (e.g.
        # carry on running the test suite). Always leave via os._exit(),
        # with a non-zero status if anything went wrong.
        status = 1
        try:
            # Load the lp context for the Domain Controller, rather than the
            # member server.
            config_file = os.environ["DC_SERVERCONFFILE"]
            lp_ctx = LoadParm()
            lp_ctx.load(config_file)

            def is_sam_logon(m):
                """Return True if JSON message m is a SamLogon
                authentication event."""
                if m is None:
                    return False
                msg = json.loads(m)
                return (msg["type"] == "Authentication" and
                        msg["Authentication"]["serviceDescription"] ==
                        "SamLogon")

            def message_handler(context, msgType, src, message):
                """Record each received authentication message."""
                # Print the message to help debugging the tests.
                # As it's a JSON message it does not look like a sub-unit
                # message.
                print(message)
                self.dc_msgs.append(message)

            # Start with an empty collection before the handler can fire.
            self.dc_msgs = []

            # Set up a messaging context to listen for authentication events
            # on the domain controller.
            msg_ctx = Messaging((1, ), lp_ctx=lp_ctx)
            msg_ctx.irpc_add_name(AUTH_EVENT_NAME)
            msg_handler_and_context = (message_handler, None)
            msg_ctx.register(msg_handler_and_context, msg_type=MSG_AUTH_LOG)

            # Wait for the SamLogon message.
            # As there could be other SamLogon's in progress we need to
            # collect all the SamLogons and let the caller match them to the
            # session.
            start_time = time.time()
            while time.time() - start_time < 1:
                msg_ctx.loop_once(0.1)

            # Only interested in SamLogon messages, filter out the rest
            msgs = [m for m in self.dc_msgs if is_sam_logon(m)]
            if msgs:
                for m in msgs:
                    os.write(w1, get_bytes(m + "\n"))
            else:
                os.write(w1, get_bytes("None\n"))
            os.close(w1)

            msg_ctx.deregister(msg_handler_and_context, msg_type=MSG_AUTH_LOG)
            msg_ctx.irpc_remove_name(AUTH_EVENT_NAME)
            status = 0
        except BaseException:
            traceback.print_exc()
        finally:
            try:
                # os._exit() skips interpreter cleanup, so flush the debug
                # output printed by message_handler explicitly.
                sys.stdout.flush()
            finally:
                os._exit(status)
Exemple #2
0
    def get_secdesc(self, name):
        """Return the raw security descriptor stored for a share.

        :param name: Name of the share
        :return: whatever ``self.db.get`` yields for key ``SECDESC/<name>``
            (not unmarshalled into a security descriptor object).
        """
        # FIXME: Run ndr_pull_security_descriptor
        lookup_key = get_bytes("SECDESC/%s" % name)
        return self.db.get(lookup_key)
Exemple #3
0
    def run(self, groupname, credopts=None, sambaopts=None, versionopts=None,
            H=None, editor=None):
        """Open a group's LDIF in an editor and apply the edits to the DB.

        :param groupname: sAMAccountName of the group to edit.
        :param H: URL of the database to modify.
        :param editor: editor program; defaults to $EDITOR, then 'vi'.
        :raises CommandError: if the group cannot be found uniquely or the
            modification fails.
        :raises CalledProcessError: if the editor exits with an error.
        """
        from . import common

        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        # Escape the user-supplied name so that LDAP filter metacharacters
        # such as '*', '(' or ')' are matched literally.
        filter = ("(&(sAMAccountName=%s)(objectClass=group))" %
                  ldb.binary_encode(groupname))

        domaindn = samdb.domain_dn()

        res = samdb.search(base=domaindn,
                           expression=filter,
                           scope=ldb.SCOPE_SUBTREE)
        if len(res) == 0:
            raise CommandError('Unable to find group "%s"' % (groupname))
        if len(res) != 1:
            raise CommandError('Invalid number of results: for "%s": %d' %
                               ((groupname), len(res)))

        msg = res[0]
        result_ldif = common.get_ldif_for_editor(samdb, msg)

        if editor is None:
            editor = os.environ.get('EDITOR')
            if editor is None:
                editor = 'vi'

        with tempfile.NamedTemporaryFile(suffix=".tmp") as t_file:
            t_file.write(get_bytes(result_ldif))
            t_file.flush()
            # A failing editor raises CalledProcessError; let it propagate
            # unchanged so its return code and command are preserved.
            check_call([editor, t_file.name])
            with open(t_file.name) as edited_file:
                edited_message = edited_file.read()

        msgs_edited = samdb.parse_ldif(edited_message)
        msg_edited = next(msgs_edited)[1]

        res_msg_diff = samdb.msg_diff(msg, msg_edited)
        if len(res_msg_diff) == 0:
            self.outf.write("Nothing to do\n")
            return

        try:
            samdb.modify(res_msg_diff)
        except Exception as e:
            raise CommandError("Failed to modify group '%s': " % groupname, e)

        self.outf.write("Modified group '%s' successfully\n" % groupname)
Exemple #4
0
    def get_sid(self, xid, id_type):
        """Retrieve SID associated with a particular id and type.

        :param xid: UID or GID to retrieve SID for.
        :param id_type: Type of id specified - 'UID' or 'GID'
        :return: the stored value with trailing NUL characters removed, or
            None when ``self.db.get`` returns None.
        """
        data = self.db.get(get_bytes("%s %s\0" % (id_type, str(xid))))
        if data is None:
            return data
        # The key is looked up as bytes, so on Python 3 the database returns
        # bytes; stripping bytes with a str argument would raise TypeError.
        nul = b"\0" if isinstance(data, bytes) else "\0"
        return data.rstrip(nul)
Exemple #5
0
 def string_to_key(cls, string, salt, params):
     """Derive a key from a password with PBKDF2-HMAC-SHA1.

     ``params`` is a 4-byte big-endian iteration count. When it is empty or
     None the count is 0x1000 (4096). The PBKDF2 output of ``cls.seedsize``
     bytes is turned into a key with ``cls.random_to_key``. That key is then
     passed through ``cls.derive`` with the constant ``b'kerberos'``.
     """
     count_bytes = params or b'\x00\x00\x10\x00'
     (rounds, ) = unpack('>L', count_bytes)
     password = get_bytes(string)
     pbkdf2 = PBKDF2HMAC(algorithm=hashes.SHA1(),
                         length=cls.seedsize,
                         salt=salt,
                         iterations=rounds,
                         backend=default_backend())
     base_key = cls.random_to_key(pbkdf2.derive(password))
     return cls.derive(base_key, b'kerberos')
Exemple #6
0
    def new_xml_entity(self, name, ent_type):
        """Build an XML entity reference that stands in for ``name``.

        The result looks like ``&SAMBA__<type>__<md5-hex of name>__;``.
        <type> is the entity type's string form, centred with '_' padding
        to the width of the ENTITY_NETWORK_PATH type string.

        :raises GPGeneralizeException: if ``ent_type`` has no string form.
        """
        digest = md5(get_bytes(name)).hexdigest()

        label = entity_type_to_string(ent_type)
        if label is None:
            raise GPGeneralizeException("No such entity type")

        # For formatting reasons, pad every label to the network-path
        # label's width so the generated entities line up.
        width = len(entity_type_to_string(ENTITY_NETWORK_PATH))
        padded_label = label.center(width, '_')

        return "&SAMBA__{}__{}__;".format(padded_label, digest)
Exemple #7
0
def stage_file(path, data):
    """Replace the file at ``path`` with ``data``, keeping a backup.

    Steps:
    1. Create the parent directory if needed.
    2. Rename an existing file at ``path`` to ``<path>.bak``.
    3. Write the new contents to a temporary file in the same directory.
    4. Rename that temporary file to ``path`` and give it mode 0644.

    :param path: destination file path.
    :param data: contents to write (passed through get_bytes).
    :return: False if the parent directory could not be created, True
        otherwise.
    """
    dirname = os.path.dirname(path)
    # A bare filename has no directory part (dirname == ''). It lives in the
    # current directory, so there is nothing to create, and
    # os.makedirs('') would fail.
    if dirname and not os.path.exists(dirname):
        try:
            os.makedirs(dirname)
        except OSError as e:
            # Tolerate the directory having been created concurrently.
            if not (e.errno == errno.EEXIST and os.path.isdir(dirname)):
                return False
    if os.path.exists(path):
        os.rename(path, '%s.bak' % path)
    # Create the temporary file next to the target so the final rename
    # stays within one filesystem.
    with NamedTemporaryFile(delete=False, dir=dirname) as f:
        f.write(get_bytes(data))
    # Rename only after the file has been closed (and therefore flushed),
    # so ``path`` never names a partially written file.
    os.rename(f.name, path)
    os.chmod(path, 0o644)
    return True
Exemple #8
0
    def run_gte_tests_with_expressions(self, expressions):
        """Check VLV greaterThanOrEqual searches for each filter expression.

        For each expression and each user attribute (except dn and
        objectclass), run sorted VLV searches with a ``gte`` target. The
        before/after counts are equal and range over 0..10, and the VLV
        cookie from the previous search is reused. Each page must equal the
        matching slice of the expected order.
        """
        # Here we don't test every before/after combination.
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        for expression in expressions:
            for attr in attrs:
                gte_order, expected_order, gte_map = \
                    self.get_gte_tests_and_order(attr, expression)
                # In case there is some order dependency, disorder tests
                gte_tests = gte_order[:]
                random.seed(2)
                random.shuffle(gte_tests)

                # NOTE(review): this replaces the order returned by
                # get_gte_tests_and_order() above; kept as in the original,
                # confirm it is intended.
                expected_order = self.get_expected_order(attr, expression)
                sort_control = "server_sort:1:0:%s" % attr
                res = None
                for before in range(0, 11):
                    after = before
                    for gte in gte_tests:
                        if res is not None:
                            cookie = get_cookie(res.controls)
                        else:
                            cookie = None
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=get_bytes(gte),
                                                        cookie=cookie)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            expression=expression,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        results = [x[attr][0] for x in res]
                        offset = gte_map.get(gte, len(expected_order))

                        # here offset is 0-based
                        start = max(offset - before, 0)
                        end = offset + 1 + after

                        expected_results = expected_order[start:end]

                        # assertEquals is a deprecated alias (removed in
                        # Python 3.12).
                        self.assertEqual(expected_results, results)
Exemple #9
0
    def test_server_vlv_gte_no_cookie(self):
        """VLV greaterThanOrEqual searches without reusing a cookie.

        For each user attribute (except dn and objectclass) and several
        before/after counts, a fresh sorted VLV search is issued for every
        ``gte`` target. The returned page must equal the matching slice of
        the expected order.
        """
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        for attr in attrs:
            gte_order, expected_order, gte_map = \
                                        self.get_gte_tests_and_order(attr)
            # In case there is some order dependency, disorder tests
            gte_tests = gte_order[:]
            random.seed(1)
            random.shuffle(gte_tests)

            sort_control = "server_sort:1:0:%s" % attr
            for before in [0, 1, 3]:
                for after in [0, 4]:
                    for gte in gte_tests:
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=get_bytes(gte))

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])
                        results = [x[attr][0] for x in res]

                        # here offset is 0-based
                        offset = gte_map.get(gte, len(expected_order))
                        start = max(offset - before, 0)
                        end = offset + after + 1
                        expected_results = expected_order[start:end]
                        if expected_results != results:
                            # Diagnostic output to help debug a failure.
                            middle = expected_order[len(expected_order) // 2]
                            print(expected_results, results)
                            print(middle)
                            print(expected_order)
                            print()
                            print("\nattr %s offset %d before %d "
                                  "after %d gte %s" %
                                  (attr, offset, before, after, gte))
                        # assertEquals is a deprecated alias (removed in
                        # Python 3.12).
                        self.assertEqual(expected_results, results)
Exemple #10
0
    def test_server_vlv_gte_with_cookie(self):
        """VLV greaterThanOrEqual searches that reuse the previous cookie.

        For each user attribute (except dn and objectclass) and several
        before/after counts, sorted VLV searches are issued for every
        ``gte`` target. Each search passes the cookie from the previous
        response. The returned page must equal the matching slice of the
        expected order.
        """
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        for attr in attrs:
            gte_order, expected_order, gte_map = \
                                        self.get_gte_tests_and_order(attr)
            # In case there is some order dependency, disorder tests
            gte_tests = gte_order[:]
            random.seed(1)
            random.shuffle(gte_tests)
            res = None
            sort_control = "server_sort:1:0:%s" % attr
            for before in [0, 1, 2, 4]:
                for after in [0, 1, 3, 6]:
                    for gte in gte_tests:
                        if res is not None:
                            cookie = get_cookie(res.controls, len(self.users))
                        else:
                            cookie = None
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=get_bytes(gte),
                                                        cookie=cookie)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        results = [x[attr][0] for x in res]
                        offset = gte_map.get(gte, len(expected_order))

                        # here offset is 0-based
                        start = max(offset - before, 0)
                        end = offset + 1 + after

                        expected_results = expected_order[start:end]

                        # assertEquals is a deprecated alias (removed in
                        # Python 3.12).
                        self.assertEqual(expected_results, results)
Exemple #11
0
 def store(self, key, val):
     """Save ``val`` under ``key`` in the log; both are encoded to bytes."""
     encoded_key = get_bytes(key)
     encoded_val = get_bytes(val)
     self.log.store(encoded_key, encoded_val)
Exemple #12
0
 def get_gplog(self, user):
     """Build a gp_log for ``user`` from the log entry stored for them."""
     stored_entry = self.log.get(get_bytes(user))
     return gp_log(user, self, stored_entry)
Exemple #13
0
 def get(self, key):
     """Return the log value stored under ``key`` (key encoded to bytes)."""
     encoded_key = get_bytes(key)
     return self.log.get(encoded_key)
Exemple #14
0
 def get_int(self, key):
     """Return the log value under ``key`` converted with int().

     None is returned when the lookup or the conversion raises TypeError.
     For example, int(None) raises TypeError; presumably that is what
     happens for a missing key — confirm log.get returns None there.
     """
     try:
         raw_value = self.log.get(get_bytes(key))
         return int(raw_value)
     except TypeError:
         return None
Exemple #15
0
    def run(self,
            contactname,
            sambaopts=None,
            credopts=None,
            versionopts=None,
            H=None,
            editor=None):
        """Open a contact's LDIF in an editor and apply the edits to the DB.

        :param contactname: the contact's name, or its DN when it starts
            with "CN=".
        :param H: URL of the database to modify.
        :param editor: editor program; defaults to $EDITOR, then 'vi'.
        :raises CommandError: if the DN is invalid, the contact cannot be
            found uniquely, or the modification fails.
        :raises CalledProcessError: if the editor exits with an error.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)
        base_dn = samdb.domain_dn()
        scope = ldb.SCOPE_SUBTREE

        filter = ("(&(objectClass=contact)(name=%s))" %
                  ldb.binary_encode(contactname))

        if contactname.upper().startswith("CN="):
            # contact is specified by DN
            filter = "(objectClass=contact)"
            scope = ldb.SCOPE_BASE
            try:
                base_dn = samdb.normalize_dn_in_domain(contactname)
            except Exception as e:
                raise CommandError('Invalid dn "%s": %s' %
                                   (contactname, e))

        res = samdb.search(base=base_dn,
                           scope=scope,
                           expression=filter)
        if len(res) == 0:
            raise CommandError('Unable to find contact "%s"' % (contactname))

        if len(res) > 1:
            for msg in sorted(res, key=attrgetter('dn')):
                self.outf.write("found: %s\n" % msg.dn)
            raise CommandError("Multiple results for contact '%s'\n"
                               "Please specify the contact's full DN" %
                               contactname)

        # Exactly one contact matched.
        msg = res[0]
        result_ldif = common.get_ldif_for_editor(samdb, msg)

        if editor is None:
            editor = os.environ.get('EDITOR')
            if editor is None:
                editor = 'vi'

        with tempfile.NamedTemporaryFile(suffix=".tmp") as t_file:
            t_file.write(get_bytes(result_ldif))
            t_file.flush()
            # A failing editor raises CalledProcessError; let it propagate
            # unchanged so its return code and command are preserved.
            check_call([editor, t_file.name])
            with open(t_file.name) as edited_file:
                edited_message = edited_file.read()

        msgs_edited = samdb.parse_ldif(edited_message)
        msg_edited = next(msgs_edited)[1]

        res_msg_diff = samdb.msg_diff(msg, msg_edited)
        if len(res_msg_diff) == 0:
            self.outf.write("Nothing to do\n")
            return

        try:
            samdb.modify(res_msg_diff)
        except Exception as e:
            raise CommandError("Failed to modify contact '%s': " % contactname,
                               e)

        self.outf.write("Modified contact '%s' successfully\n" % contactname)
Exemple #16
0
 def get_machine_acc(self, host):
     """Return the value stored under ``SECRETS/$MACHINE.ACC/<host>``."""
     secret_key = get_bytes("SECRETS/$MACHINE.ACC/%s" % host)
     return self.db.get(secret_key)
Exemple #17
0
 def get_sid(self, host):
     """Return the value stored under ``SECRETS/SID/<HOST>``.

     The host name is upper-cased before building the key.
     """
     upper_host = host.upper()
     return self.db.get(get_bytes("SECRETS/SID/%s" % upper_host))
Exemple #18
0
 def get_ldap_bind_pw(self, host):
     """Return the value stored under ``SECRETS/LDAP_BIND_PW/<host>``."""
     secret_key = get_bytes("SECRETS/LDAP_BIND_PW/%s" % host)
     return self.db.get(secret_key)
Exemple #19
0
    def test_server_vlv_with_cookie_while_adding_and_deleting(self):
        """What happens if we add or remove items in the middle of the VLV?

        Nothing. The search and the sort is not repeated, and we only
        deal with the objects originally found.

        Deleted objects are replaced by None in the expected order and are
        skipped when building the expected page. Users created during the
        test are never added to the expected order.
        """
        attrs = ['cn'] + [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        user_number = 0
        for attr in attrs:
            full_results, controls, sort_control = \
                            self.get_full_list(attr, True)
            original_n = len(self.users)

            expected_order = full_results
            random.seed(1)

            for before in list(range(0, 3)) + [6, 11, 19]:
                for after in list(range(0, 3)) + [6, 11, 19]:
                    start = max(before - 1, 1)
                    end = max(start + 4, original_n - after + 2)
                    for offset in range(start, end):
                        cookie = get_cookie(controls, original_n)
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        offset=offset,
                                                        n=original_n,
                                                        cookie=cookie)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        controls = res.controls
                        results = [x[attr][0] for x in res]
                        real_offset = max(1, min(offset, len(expected_order)))

                        begin_offset = max(real_offset - before - 1, 0)
                        real_before = min(before, real_offset - 1)
                        real_after = min(after,
                                         len(expected_order) - real_offset)

                        # Build the expected page, skipping entries whose
                        # objects were deleted (set to None below).
                        expected_results = []
                        for x in expected_order[begin_offset:]:
                            if x is None:
                                continue
                            expected_results.append(get_bytes(x[0]))
                            if (len(expected_results) == real_before +
                                    real_after + 1):
                                break

                        if expected_results != results:
                            print("attr %s before %d after %d offset %d" %
                                  (attr, before, after, offset))
                        # assertEquals is a deprecated alias (removed in
                        # Python 3.12).
                        self.assertEqual(expected_results, results)

                        # Randomly add a user (a little more likely when
                        # fewer than 5 remain)...
                        n = len(self.users)
                        if random.random() < 0.1 + (n < 5) * 0.05:
                            if n == 0:
                                i = 0
                            else:
                                i = random.randrange(n)
                            self.create_user(i,
                                             n,
                                             suffix='-%s' % user_number)
                            user_number += 1
                        # ...and randomly delete one (a little more likely
                        # when more than 50 remain).
                        if random.random() < 0.1 + (n > 50) * 0.02 and n:
                            index = random.randrange(n)
                            user = self.users.pop(index)

                            self.ldb.delete(user['dn'])

                            replaced = (user[attr], user['cn'])
                            if replaced in expected_order:
                                i = expected_order.index(replaced)
                                expected_order[i] = None
Exemple #20
0
 def get_machine_sec_channel_type(self, host):
     """Return the secure channel type stored for ``host``.

     Reads key ``SECRETS/MACHINE_SEC_CHANNEL_TYPE/<host>`` from the
     database through fetch_uint32.
     """
     secret_key = get_bytes("SECRETS/MACHINE_SEC_CHANNEL_TYPE/%s" % host)
     return fetch_uint32(self.db, secret_key)
Exemple #21
0
 def delete(self, key):
     """Remove the log entry stored under ``key`` (key encoded to bytes)."""
     encoded_key = get_bytes(key)
     self.log.delete(encoded_key)
Exemple #22
0
    def test_setpassword(self):
        """Exercise `samba-tool user setpassword`, `syncpasswords` and
        `getpassword`.

        Passwords are set both over LDAP (with -H/-U) and locally. The test
        checks that syncpasswords and getpassword report the expected
        attributes, including the base64-encoded NT hash and the clear-text
        virtual attributes when virtualSambaGPG is available.
        """
        for user in self.users:
            newpasswd = self.random_password(16)
            (result, out, err) = self.runsubcmd("user", "setpassword",
                                                user["name"],
                                                "--newpassword=%s" % newpasswd,
                                                "-H", "ldap://%s" % os.environ["DC_SERVER"],
                                                "-U%s%%%s" % (os.environ["DC_USERNAME"], os.environ["DC_PASSWORD"]))
            self.assertCmdSuccess(result, out, err, "Ensure setpassword runs")
            self.assertEqual(err, "", "setpassword with url")
            self.assertMatch(out, "Changed password OK", "setpassword with url")

        attributes = "sAMAccountName,unicodePwd,supplementalCredentials,virtualClearTextUTF8,virtualClearTextUTF16,virtualSSHA,virtualSambaGPG"
        (result, out, err) = self.runsubcmd("user", "syncpasswords",
                                            "--cache-ldb-initialize",
                                            "--attributes=%s" % attributes,
                                            "--decrypt-samba-gpg")
        self.assertCmdSuccess(result, out, err, "Ensure syncpasswords --cache-ldb-initialize runs")
        self.assertEqual(err, "", "syncpasswords --cache-ldb-initialize")
        cache_attrs = {
            "objectClass": {"value": "userSyncPasswords"},
            "samdbUrl": {},
            "dirsyncFilter": {},
            "dirsyncAttribute": {},
            "dirsyncControl": {"value": "dirsync:1:0:0"},
            "passwordAttribute": {},
            "decryptSambaGPG": {},
            "currentTime": {},
        }
        for a in cache_attrs.keys():
            v = cache_attrs[a].get("value", "")
            self.assertMatch(out, "%s: %s" % (a, v),
                             "syncpasswords --cache-ldb-initialize: %s: %s out[%s]" % (a, v, out))

        (result, out, err) = self.runsubcmd("user", "syncpasswords", "--no-wait")
        self.assertCmdSuccess(result, out, err, "Ensure syncpasswords --no-wait runs")
        self.assertEqual(err, "", "syncpasswords --no-wait")
        self.assertMatch(out, "dirsync_loop(): results 0",
                         "syncpasswords --no-wait: 'dirsync_loop(): results 0': out[%s]" % (out))
        for user in self.users:
            self.assertMatch(out, "sAMAccountName: %s" % (user["name"]),
                             "syncpasswords --no-wait: 'sAMAccountName': %s out[%s]" % (user["name"], out))

        for user in self.users:
            newpasswd = self.random_password(16)
            creds = credentials.Credentials()
            creds.set_anonymous()
            creds.set_password(newpasswd)
            nthash = creds.get_nt_hash()
            unicodePwd = base64.b64encode(nthash).decode('utf8')
            virtualClearTextUTF8 = base64.b64encode(get_bytes(newpasswd)).decode('utf8')
            virtualClearTextUTF16 = base64.b64encode(get_string(newpasswd).encode('utf-16-le')).decode('utf8')

            (result, out, err) = self.runsubcmd("user", "setpassword",
                                                user["name"],
                                                "--newpassword=%s" % newpasswd)
            self.assertCmdSuccess(result, out, err, "Ensure setpassword runs")
            self.assertEqual(err, "", "setpassword without url")
            self.assertMatch(out, "Changed password OK", "setpassword without url")

            (result, out, err) = self.runsubcmd("user", "syncpasswords", "--no-wait")
            self.assertCmdSuccess(result, out, err, "Ensure syncpasswords --no-wait runs")
            self.assertEqual(err, "", "syncpasswords --no-wait")
            self.assertMatch(out, "dirsync_loop(): results 0",
                             "syncpasswords --no-wait: 'dirsync_loop(): results 0': out[%s]" % (out))
            self.assertMatch(out, "sAMAccountName: %s" % (user["name"]),
                             "syncpasswords --no-wait: 'sAMAccountName': %s out[%s]" % (user["name"], out))
            self.assertMatch(out, "# unicodePwd::: REDACTED SECRET ATTRIBUTE",
                             "getpassword '# unicodePwd::: REDACTED SECRET ATTRIBUTE': out[%s]" % out)
            self.assertMatch(out, "unicodePwd:: %s" % unicodePwd,
                             "getpassword unicodePwd: out[%s]" % out)
            self.assertMatch(out, "# supplementalCredentials::: REDACTED SECRET ATTRIBUTE",
                             "getpassword '# supplementalCredentials::: REDACTED SECRET ATTRIBUTE': out[%s]" % out)
            self.assertMatch(out, "supplementalCredentials:: ",
                             "getpassword supplementalCredentials: out[%s]" % out)
            # The clear-text virtual attributes are only checked when the
            # output contains virtualSambaGPG.
            if "virtualSambaGPG:: " in out:
                self.assertMatch(out, "virtualClearTextUTF8:: %s" % virtualClearTextUTF8,
                                 "getpassword virtualClearTextUTF8: out[%s]" % out)
                self.assertMatch(out, "virtualClearTextUTF16:: %s" % virtualClearTextUTF16,
                                 "getpassword virtualClearTextUTF16: out[%s]" % out)
                self.assertMatch(out, "virtualSSHA: ",
                                 "getpassword virtualSSHA: out[%s]" % out)

            (result, out, err) = self.runsubcmd("user", "getpassword",
                                                user["name"],
                                                "--attributes=%s" % attributes,
                                                "--decrypt-samba-gpg")
            self.assertCmdSuccess(result, out, err, "Ensure getpassword runs")
            self.assertEqual(err, "", "getpassword without url")
            self.assertMatch(out, "Got password OK", "getpassword without url")
            self.assertMatch(out, "sAMAccountName: %s" % (user["name"]),
                             "getpassword: 'sAMAccountName': %s out[%s]" % (user["name"], out))
            self.assertMatch(out, "unicodePwd:: %s" % unicodePwd,
                             "getpassword unicodePwd: out[%s]" % out)
            self.assertMatch(out, "supplementalCredentials:: ",
                             "getpassword supplementalCredentials: out[%s]" % out)
            self._verify_supplementalCredentials(out.replace("\nGot password OK\n", ""))
            if "virtualSambaGPG:: " in out:
                self.assertMatch(out, "virtualClearTextUTF8:: %s" % virtualClearTextUTF8,
                                 "getpassword virtualClearTextUTF8: out[%s]" % out)
                self.assertMatch(out, "virtualClearTextUTF16:: %s" % virtualClearTextUTF16,
                                 "getpassword virtualClearTextUTF16: out[%s]" % out)
                self.assertMatch(out, "virtualSSHA: ",
                                 "getpassword virtualSSHA: out[%s]" % out)

        for user in self.users:
            newpasswd = self.random_password(16)
            (result, out, err) = self.runsubcmd("user", "setpassword",
                                                user["name"],
                                                "--newpassword=%s" % newpasswd,
                                                "--must-change-at-next-login",
                                                "-H", "ldap://%s" % os.environ["DC_SERVER"],
                                                "-U%s%%%s" % (os.environ["DC_USERNAME"], os.environ["DC_PASSWORD"]))
            self.assertCmdSuccess(result, out, err, "Ensure setpassword runs")
            self.assertEqual(err, "", "setpassword with forced change")
            self.assertMatch(out, "Changed password OK", "setpassword with forced change")
Exemple #23
0
 def get_afs_keyfile(self, host):
     """Return the value stored under ``SECRETS/AFS_KEYFILE/<host>``."""
     secret_key = get_bytes("SECRETS/AFS_KEYFILE/%s" % host)
     return self.db.get(secret_key)
Exemple #24
0
 def get_domtrust_acc(self, host):
     """Return the value stored under ``SECRETS/$DOMTRUST.ACC/<host>``."""
     secret_key = get_bytes("SECRETS/$DOMTRUST.ACC/%s" % host)
     return self.db.get(secret_key)
Exemple #25
0
    def add_remove_group_members(self, groupname, members,
                                 add_members_operation=True,
                                 member_types=None,
                                 member_base_dn=None):
        """Adds or removes group members

        All changes are made in one transaction, which is cancelled if
        anything raises.

        :param groupname: sAMAccountName of the target group
        :param members: list of group members; each entry may be a SID
            string, a DN, or a name resolved with group_member_filter()
        :param add_members_operation: True to add the members, False to
            remove them
        :param member_types: object types a member name may refer to;
            defaults to ['user', 'group', 'computer']
        :param member_base_dn: search base used to resolve member names;
            defaults to the domain DN
        :raises Exception: if the group is not found or not unique, or a
            member name does not resolve to exactly one object
        """
        # Avoid a shared mutable default argument.
        if member_types is None:
            member_types = ['user', 'group', 'computer']

        groupfilter = "(&(sAMAccountName=%s)(objectCategory=%s,%s))" % (
            ldb.binary_encode(groupname), "CN=Group,CN=Schema,CN=Configuration", self.domain_dn())

        self.transaction_start()
        try:
            targetgroup = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
                                      expression=groupfilter, attrs=['member'])
            if len(targetgroup) == 0:
                raise Exception('Unable to find group "%s"' % groupname)
            # An explicit check rather than an assert, which python -O strips.
            if len(targetgroup) != 1:
                raise Exception('Found multiple results for group "%s"' % groupname)

            if member_base_dn is None:
                member_base_dn = self.domain_dn()

            # Current member values normalised to bytes. Comparing them with
            # get_bytes(targetmember_dn) then works on Python 2 and 3; a
            # bytes value never equals a str() of one.
            current = targetgroup[0].get('member')
            current_members = set()
            if current is not None:
                current_members = set(get_bytes(x) for x in current)

            modified = False
            ldif_parts = ["\ndn: %s\nchangetype: modify\n" % str(targetgroup[0].dn)]

            for member in members:
                targetmember_dn = None

                # 1. Is the member given as a SID?
                try:
                    membersid = security.dom_sid(member)
                    targetmember_dn = "<SID=%s>" % str(membersid)
                except (TypeError, ValueError):
                    pass

                # 2. Is it given as a DN?
                if targetmember_dn is None:
                    try:
                        member_dn = ldb.Dn(self, member)
                        if member_dn.get_linearized() == member_dn.extended_str(1):
                            full_member_dn = self.normalize_dn_in_domain(member_dn)
                        else:
                            full_member_dn = member_dn
                        targetmember_dn = full_member_dn.extended_str(1)
                    except ValueError:
                        pass

                # 3. Otherwise look it up by name.
                if targetmember_dn is None:
                    filter = self.group_member_filter(member, member_types)
                    targetmember = self.search(base=member_base_dn,
                                               scope=ldb.SCOPE_SUBTREE,
                                               expression=filter,
                                               attrs=[])

                    if len(targetmember) > 1:
                        targetmemberlist_str = "".join(
                            "%s\n" % msg.get("dn") for msg in targetmember)
                        raise Exception('Found multiple results for "%s":\n%s' %
                                        (member, targetmemberlist_str))
                    if len(targetmember) != 1:
                        raise Exception('Unable to find "%s". Operation cancelled.' % member)
                    targetmember_dn = targetmember[0].dn.extended_str(1)

                member_bytes = get_bytes(targetmember_dn)
                if add_members_operation is True and member_bytes not in current_members:
                    modified = True
                    ldif_parts.append("add: member\nmember: %s\n" % str(targetmember_dn))

                elif add_members_operation is False and member_bytes in current_members:
                    modified = True
                    ldif_parts.append("delete: member\nmember: %s\n" % str(targetmember_dn))

            if modified is True:
                self.modify_ldif("".join(ldif_parts))

        except BaseException:
            # Roll back on any failure, then propagate it unchanged.
            self.transaction_cancel()
            raise
        else:
            self.transaction_commit()
Exemple #26
0
    def generalize_xml(self, root, out_file, global_entities):
        """Replace site-specific values in *root* with entities and write it out.

        Elements flagged with ``user_id="TRUE"``, ``acl="TRUE"`` or
        ``network_path="TRUE"`` have their text replaced by an entity string
        produced by ``self.new_xml_entity``.  The same original text always
        maps to the same entity: *global_entities* (original text -> entity)
        is consulted first and updated in place with every newly created
        entity, so it can be shared across calls.

        For network paths only the server prefix (the leading backslashes
        plus the first path component) is replaced; the rest of the path is
        kept verbatim after the entity.

        After the subclass hook ``self.custom_entities`` has added its own
        pairs, the tree is serialized, the escaped form of every entity is
        turned back into the raw entity text, and the bytes are written to
        *out_file*.

        :param root: root element of the XML tree (its element texts are
            modified in place)
        :param out_file: path of the file to write the generalized XML to
        :param global_entities: mapping of original text to entity text,
            shared between calls and updated in place
        :return: list of ``(entity, original_text)`` pairs, one per
            substitution made (repeated values appear repeatedly), followed
            by whatever ``custom_entities`` returned
        """
        entities = []

        def substitute(text, ent_type):
            # Reuse the entity already assigned to this text, or create and
            # register a new one; either way record the pair for the caller.
            if text not in global_entities:
                global_entities[text] = self.new_xml_entity(text, ent_type)
            entity = global_entities[text]
            entities.append((entity, text))
            return entity

        def flagged(attribute):
            # Elements carrying attribute="TRUE", ordered by tag so the
            # entity assignment order is stable.
            found = root.findall('.//*[@%s="TRUE"]' % attribute)
            return sorted(found, key=lambda e: e.tag)

        # User ids first, then ACLs: the whole text is the generalized value.
        for attribute, ent_type in (('user_id', ENTITY_USER_ID),
                                    ('acl', ENTITY_SDDL_ACL)):
            for elem in flagged(attribute):
                if not elem.text:
                    continue
                elem.text = substitute(elem.text, ent_type)

        for elem in flagged('network_path'):
            text = elem.text
            if not text:
                continue

            # Split e.g. \\server\share\dir into the server prefix \\server
            # (which becomes the entity) and the remainder \share\dir.
            file_server = text.lstrip('\\').split('\\')[0]
            split_at = text.find(file_server) + len(file_server)
            server, remaining = text[:split_at], text[split_at:]

            elem.text = substitute(server, ENTITY_NETWORK_PATH) + remaining

        # Call any file specific customization of entities
        # (which appear in any subclasses).
        entities.extend(self.custom_entities(root, global_entities))

        output_xml = tostring(root)

        # Serialization escapes '&' as '&amp;', which would turn the entity
        # references into literal text; restore each entity's raw form.
        for ent in entities:
            entb = get_bytes(ent[0])
            output_xml = output_xml.replace(entb.replace(b'&', b'&amp;'), entb)

        with open(out_file, 'wb') as f:
            f.write(output_xml)

        return entities
Exemple #27
0
    def add_remove_group_members(self,
                                 groupname,
                                 members,
                                 add_members_operation=True):
        """Adds or removes group members

        :param groupname: Name of the target group
        :param members: list of group members
        :param add_members_operation: Defines if its an add or remove
            operation
        """

        groupfilter = "(&(sAMAccountName=%s)(objectCategory=%s,%s))" % (
            ldb.binary_encode(groupname),
            "CN=Group,CN=Schema,CN=Configuration", self.domain_dn())

        self.transaction_start()
        try:
            targetgroup = self.search(base=self.domain_dn(),
                                      scope=ldb.SCOPE_SUBTREE,
                                      expression=groupfilter,
                                      attrs=['member'])
            if len(targetgroup) == 0:
                raise Exception('Unable to find group "%s"' % groupname)
            assert (len(targetgroup) == 1)

            modified = False

            addtargettogroup = """
dn: %s
changetype: modify
""" % (str(targetgroup[0].dn))

            for member in members:
                filter = ('(&(sAMAccountName=%s)(|(objectclass=user)'
                          '(objectclass=group)))' % ldb.binary_encode(member))
                foreign_msg = None
                try:
                    membersid = security.dom_sid(member)
                except TypeError as e:
                    membersid = None

                if membersid is not None:
                    filter = '(objectSid=%s)' % str(membersid)
                    dn_str = "<SID=%s>" % str(membersid)
                    foreign_msg = ldb.Message()
                    foreign_msg.dn = ldb.Dn(self, dn_str)

                targetmember = self.search(base=self.domain_dn(),
                                           scope=ldb.SCOPE_SUBTREE,
                                           expression="%s" % filter,
                                           attrs=[])

                if len(targetmember) == 0 and foreign_msg is not None:
                    targetmember = [foreign_msg]
                if len(targetmember) != 1:
                    raise Exception(
                        'Unable to find "%s". Operation cancelled.' % member)
                targetmember_dn = targetmember[0].dn.extended_str(1)
                if add_members_operation is True and (
                        targetgroup[0].get('member') is None
                        or get_bytes(targetmember_dn)
                        not in [str(x) for x in targetgroup[0]['member']]):
                    modified = True
                    addtargettogroup += """add: member
member: %s
""" % (str(targetmember_dn))

                elif add_members_operation is False and (
                        targetgroup[0].get('member') is not None
                        and get_bytes(targetmember_dn)
                        in targetgroup[0]['member']):
                    modified = True
                    addtargettogroup += """delete: member
member: %s
""" % (str(targetmember_dn))

            if modified is True:
                self.modify_ldif(addtargettogroup)

        except:
            self.transaction_cancel()
            raise
        else:
            self.transaction_commit()
Exemple #28
0
 def get_machine_password(self, host):
     """Fetch the machine password entry stored for *host*.

     Builds the key ``SECRETS/MACHINE_PASSWORD/<host>``, converts it with
     ``get_bytes`` and returns whatever ``self.db.get`` yields for it
     (NOTE(review): presumably None when no entry exists - depends on the
     database object).
     """
     secret_key = get_bytes("SECRETS/MACHINE_PASSWORD/%s" % host)
     return self.db.get(secret_key)