Exemple #1
0
    def openvpn_start(self):
        """Record this user on the server link and launch the link's openvpn.

        Sets ``links.$.user_id`` on the linked server's entry for this server,
        builds the client configuration via ``generate_client_conf()`` and
        starts ``openvpn`` on it with stdout/stderr piped.

        Raises:
            ServerLinkError: if no existing link entry was updated.
            OSError: when the openvpn process cannot be spawned; the traceback
                is pushed to the link output and logged first.
        """
        link_spec = {
            '_id': self.linked_server.id,
            'links.server_id': self.server.id,
        }
        link_change = {'$set': {'links.$.user_id': self.user.id}}
        result = self.collection.update(link_spec, link_change)

        if not result['updatedExisting']:
            raise ServerLinkError('Failed to update server links')

        conf_path = self.generate_client_conf()
        openvpn_args = ['openvpn', conf_path]

        try:
            self.process = subprocess.Popen(
                openvpn_args,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        except OSError:
            # Show the spawn failure in the link output, then log and
            # propagate it. NOTE(review): the user_id set above is not
            # reverted here.
            self.server.output_link.push_output(
                traceback.format_exc(),
                label=self.output_label,
                link_server_id=self.linked_server.id,
            )
            logger.exception('Failed to start link ovpn process. %r' % {
                'server_id': self.server.id,
            })
            raise
Exemple #2
0
def get_mem_usage():
    """Return the fraction of physical memory in use, excluding buffers/cache.

    Parses the output of the ``free`` command, locating rows instead of
    using fixed token positions so both output layouts work:

    * older procps prints a ``-/+ buffers/cache:  used  free`` row whose
      *used* value already excludes buffers/cache;
    * newer procps-ng drops that row, and its ``Mem:`` *used* column
      excludes buffers/cache itself.

    The previous fixed index (token 15) picked the swap *used* value on
    the newer layout.

    Returns:
        float: used / total memory, or 0 if ``free`` fails or its output
        cannot be parsed (the error is logged).
    """
    try:
        output = subprocess.check_output(['free'])
        if isinstance(output, bytes):
            output = output.decode('utf-8', 'replace')
        rows = [line.split() for line in output.splitlines() if line.strip()]

        # rows[0] is the column header and rows[1] is the physical memory
        # row ("Mem:  total  used  free ..."). It is taken by position, so a
        # localized row label does not matter.
        mem_row = rows[1]
        total = float(mem_row[1])
        used = float(mem_row[2])

        # Older layout: prefer the cache-adjusted "used" figure.
        for row in rows[2:]:
            if row[0] == '-/+':
                used = float(row[2])
                break

        return used / total
    except Exception:
        logger.exception('Failed to get memory usage')
    return 0
Exemple #3
0
    def task(self):
        """Publish route advertisements for routes with a stale timestamp.

        Generator. Selects route documents whose ``timestamp`` is older than
        ``settings.vpn.route_ping_ttl`` seconds, yields once, then publishes
        ``['route_advertisement', server_id, vpc_region, vpc_id, network]``
        on the ``instance`` channel for each one. Errors other than generator
        shutdown are logged. Finally yields
        ``interrupter_sleep(settings.vpn.server_ping)``.
        """
        try:
            stale_before = utils.now() - datetime.timedelta(
                seconds=settings.vpn.route_ping_ttl)
            stale_routes = self.routes_collection.find({
                'timestamp': {'$lt': stale_before},
            })

            yield

            for route in stale_routes:
                # All four fields are read before publishing, in this order.
                message = ['route_advertisement']
                message.extend(route[field] for field in (
                    'server_id', 'vpc_region', 'vpc_id', 'network'))
                messenger.publish('instance', message)
        except GeneratorExit:
            raise
        except:
            logger.exception('Error checking route states', 'tasks')

        yield interrupter_sleep(settings.vpn.server_ping)
Exemple #4
0
    def task(self):
        """Renew the ACME certificate once its renewal interval has elapsed.

        Does nothing when ``settings.app.acme_domain`` is empty. Logs an
        error and returns when the ACME timestamp or account key is not set.
        Otherwise, once at least ``settings.app.acme_renew`` has passed since
        ``settings.app.acme_timestamp`` (compared with ``utils.time_now()``),
        calls ``acme.update_acme_cert()`` and then ``app.update_server()``.
        """
        acme_domain = settings.app.acme_domain

        if not acme_domain:
            return

        # These are configuration problems, not caught exceptions. Use
        # logger.error: logger.exception outside an ``except`` block has no
        # active exception to report.
        if not settings.app.acme_timestamp:
            logger.error(
                'Failed to update acme certificate. Timestamp not set',
                'tasks',
                acme_domain=acme_domain,
            )
            return

        if not settings.app.acme_key:
            logger.error(
                'Failed to update acme certificate. Account key not set',
                'tasks',
                acme_domain=acme_domain,
            )
            return

        if utils.time_now() - settings.app.acme_timestamp < \
                settings.app.acme_renew:
            return

        logger.info(
            'Updating acme certificate', 'tasks',
            acme_domain=acme_domain,
        )

        acme.update_acme_cert()
        app.update_server()
Exemple #5
0
    def _keep_alive_thread(self):
        """Refresh this instance's ping timestamp until interrupted (generator).

        While ``self.interrupt`` is false, sets ``instances.$.ping_timestamp``
        on the server document's entry for this instance id. If no document
        matched, calls ``self.stop_process()``: a truthy result ends the
        loop, otherwise it retries after 0.1 s. Other errors are logged.
        Yields ``interrupter_sleep(settings.vpn.server_ping)`` between
        iterations.
        """
        while not self.interrupt:
            try:
                doc = self.collection.find_and_modify({
                    '_id': self.server.id,
                    'instances.instance_id': self.id,
                }, {'$set': {
                    'instances.$.ping_timestamp': utils.now(),
                }}, fields={
                    '_id': False,
                    'instances': True,
                }, new=True)

                yield

                if not doc:
                    if self.stop_process():
                        break
                    else:
                        time.sleep(0.1)
                        continue

            except GeneratorExit:
                # Let generator shutdown propagate. Swallowing it here would
                # make the yield below raise RuntimeError in close().
                raise
            except:
                logger.exception('Failed to update server ping', 'server',
                    server_id=self.server.id,
                )
            yield interrupter_sleep(settings.vpn.server_ping)
Exemple #6
0
    def add_route(self, virt_address, virt_address6,
            host_address, host_address6):
        """Route a client's virtual address through a host.

        Strips any prefix length from ``virt_address``. If a route for it
        is already tracked, the route is dropped first. A new route via
        ``host_address`` is added only when that address is non-empty and
        differs from this host's local address. Failures are logged, not
        raised. ``virt_address6`` and ``host_address6`` are only included
        in the error log.
        """
        route_addr = virt_address.split('/')[0]

        try:
            if route_addr in self.client_routes:
                try:
                    self.client_routes.remove(route_addr)
                except KeyError:
                    # Presumably removed concurrently since the check.
                    pass
                utils.del_route(route_addr)

            if host_address and host_address != \
                    settings.local.host.local_addr:
                self.client_routes.add(route_addr)
                utils.add_route(route_addr, host_address)
        except:
            logger.exception('Failed to add route', 'clients',
                virt_address=route_addr,
                virt_address6=virt_address6,
                host_address=host_address,
                host_address6=host_address6,
            )
Exemple #7
0
def setup_all():
    """Run the full startup sequence in order; log and re-raise any failure.

    ``setup_local()`` and ``setup_logger()`` run before the guarded
    section, so a failure in either propagates without being logged here.
    After the remaining steps succeed, a warning is logged if the soft or
    hard open-file limit (``RLIMIT_NOFILE``) is below 25000.
    """
    from pritunl import logger

    setup_local()
    setup_logger()

    try:
        setup_temp_path()
        setup_app()
        setup_signal_handler()
        setup_server()
        setup_mongo()
        setup_cache()
        setup_public_ip()
        setup_host()
        setup_server_listeners()
        setup_dns()
        setup_monitoring()
        setup_poolers()
        setup_host_fix()
        setup_subscription()
        setup_runners()
        setup_handlers()
        setup_check()

        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        if soft < 25000 or hard < 25000:
            # Fix the "then" -> "than" typo and include the actual limits so
            # the warning is actionable.
            logger.warning(
                'Open file ulimit is lower than recommended',
                'setup',
                soft=soft,
                hard=hard,
            )
    except:
        logger.exception('Pritunl setup failed', 'setup')
        raise
Exemple #8
0
def setup_all():
    """Run the startup sequence in order; log and re-raise any failure.

    ``setup_local()`` and ``setup_logger()`` run before the guarded section,
    so a failure in either propagates without being logged here. The other
    steps run inside it in a fixed order; ``setup_server_cert()`` runs only
    when ``settings.conf.ssl`` is truthy.
    """
    setup_local()
    setup_logger()

    try:
        setup_app()
        setup_signal_handler()
        setup_server()
        setup_mongo()
        setup_temp_path()

        if settings.conf.ssl:
            setup_server_cert()

        setup_public_ip()
        setup_host()
        setup_poolers()
        setup_host_fix()
        setup_runners()
        setup_handlers()
        setup_check()
    except:
        # Imported only on failure. Presumably the logger module is meant to
        # be used after setup_logger() has run — TODO confirm.
        from pritunl import logger
        logger.exception('Pritunl setup failed', 'setup')
        raise
Exemple #9
0
def send_email(to_addr, subject, text_body, html_body):
    """Send a multipart email with plain-text and HTML alternatives.

    Connects to ``settings.app.email_server`` with ``smtplib.SMTP_SSL``,
    logs in with the configured username/password and sends from
    ``settings.app.email_from``.

    Args:
        to_addr: Recipient address.
        subject: Message subject.
        text_body: Plain-text alternative body.
        html_body: HTML alternative body.

    Raises:
        EmailNotConfiguredError: if the server, from address, username or
            password setting is empty.
        EmailAuthInvalid: if the server rejects the credentials or the from
            address.
        Exception: any other error is logged and re-raised unchanged.
    """
    from pritunl import logger

    email_server = settings.app.email_server
    email_from = settings.app.email_from
    email_username = settings.app.email_username
    email_password = settings.app.email_password

    if not email_server or not email_from or not \
            email_username or not email_password:
        raise EmailNotConfiguredError('Email not configured')

    msg = email.mime.multipart.MIMEMultipart('alternative')
    msg['Subject'] = subject
    msg['From'] = email_from
    msg['To'] = to_addr

    msg.attach(email.mime.text.MIMEText(text_body, 'plain'))
    msg.attach(email.mime.text.MIMEText(html_body, 'html'))

    try:
        smtp_conn = smtplib.SMTP_SSL(email_server)
        try:
            smtp_conn.login(email_username, email_password)
            smtp_conn.sendmail(email_from, to_addr, msg.as_string())
            smtp_conn.quit()
        finally:
            # Always release the socket. Previously a failed login/sendmail
            # left the SSL connection open. close() is safe to call again
            # after a successful quit().
            smtp_conn.close()
    except smtplib.SMTPAuthenticationError:
        raise EmailAuthInvalid('Email auth is invalid')
    except smtplib.SMTPSenderRefused:
        raise EmailAuthInvalid('Email from address refused')
    except:
        logger.exception('Unknown smtp error', 'utils',
            from_addr=email_from,
            to_addr=to_addr,
        )
        raise
Exemple #10
0
    def reserve_route_advertisement(self, vpc_region, vpc_id, network):
        """Claim the advertisement of a VPC route and install it.

        The reservation document id is ``<server_id>_<vpc_id>_<network>``.
        The upsert only matches a document whose ``timestamp`` is older than
        ``settings.vpn.route_ping_ttl`` seconds (or none at all). When the
        reservation is taken, this instance's details and a fresh timestamp
        are stored, the VPC route is added via ``utils.add_vpc_route`` and the
        id is remembered in ``self.route_advertisements``.

        Returns None in all cases; other failures are logged, not raised.
        """
        ra_id = '%s_%s_%s' % (self.server.id, vpc_id, network)
        timestamp_spec = utils.now() - datetime.timedelta(
            seconds=settings.vpn.route_ping_ttl)

        try:
            # If a document with this _id exists with a fresh timestamp, the
            # filter does not match, so the upsert tries to insert a second
            # document with the same _id and raises DuplicateKeyError (the
            # reservation is presumably held by another instance).
            self.routes_collection.update_one({
                '_id': ra_id,
                'timestamp': {'$lt': timestamp_spec},
            }, {'$set': {
                'instance_id': self.id,
                'server_id': self.server.id,
                'vpc_region': vpc_region,
                'vpc_id': vpc_id,
                'network': network,
                'timestamp': utils.now(),
            }}, upsert=True)

            utils.add_vpc_route(vpc_region, vpc_id, network,
                settings.local.host.aws_id)

            self.route_advertisements.add(ra_id)
        except pymongo.errors.DuplicateKeyError:
            # Reservation not acquired; nothing to do.
            return
        except:
            logger.exception('Failed to add vpc route', 'server',
                server_id=self.server.id,
                instance_id=self.id,
                vpc_region=vpc_region,
                vpc_id=vpc_id,
                network=network,
            )
Exemple #11
0
    def _sub_thread(self, cursor_id):
        """Handle control messages for this server instance (generator).

        Iterates ``self.subscribe(cursor_id=cursor_id)``, yielding once per
        message, and returns when ``self.interrupt`` is set. Messages:

        * ``'stop'``: calls ``self.stop_process()``; sets ``clean_exit`` if it
          returned truthy.
        * ``'force_stop'``: sets ``clean_exit`` and sends SIGKILL to
          ``self.process`` ten times, 10 ms apart.

        OSError while handling a message is ignored. Any other error is
        logged and ``self.stop_process()`` is called.
        """
        try:
            for msg in self.subscribe(cursor_id=cursor_id):
                # NOTE(review): if this generator is closed while suspended
                # here, the GeneratorExit is caught by the bare except below,
                # which logs it and calls stop_process(). Confirm that is
                # intended.
                yield

                if self.interrupt:
                    return
                message = msg['message']

                try:
                    if message == 'stop':
                        if self.stop_process():
                            self.clean_exit = True
                    elif message == 'force_stop':
                        self.clean_exit = True
                        # Repeated SIGKILLs. OSError (presumably the process
                        # is already gone) is swallowed below.
                        for _ in xrange(10):
                            self.process.send_signal(signal.SIGKILL)
                            time.sleep(0.01)
                except OSError:
                    pass
        except:
            logger.exception('Exception in messaging thread', 'server',
                server_id=self.server.id,
            )
            self.stop_process()
Exemple #12
0
    def _tail_auth_log(self):
        """Stream new auth log lines into the server output (generator).

        Starts ``tail -f`` on ``self.auth_log_path`` and pushes each line it
        prints to ``self.server.output``, yielding before and after each push.
        Ends once tail's stdout reaches EOF and the process has exited. If
        tail cannot be started, the traceback is pushed to the output, the
        error is logged, and the generator ends without reading anything.
        """
        try:
            self.auth_log_process = subprocess.Popen(
                ['tail', '-f', self.auth_log_path],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        except OSError:
            self.server.output.push_output(traceback.format_exc())
            logger.exception('Failed to start tail auth log process. %r' % {
                'server_id': self.server.id,
            })
            # No process to read from. Previously execution fell through to
            # the loop and used a missing or stale process handle.
            return

        while True:
            line = self.auth_log_process.stdout.readline()
            if not line:
                # Empty read: stop once tail has exited, otherwise poll again.
                if self.auth_log_process.poll() is not None:
                    break
                else:
                    time.sleep(0.05)
                    continue

            yield

            try:
                self.server.output.push_output(line)
            except:
                # Same %-formatting style as the start-failure log above.
                logger.exception('Failed to push auth log output. %r' % {
                    'server_id': self.server.id,
                })

            yield
Exemple #13
0
    def _connected(self, client_id):
        """Register a client that has finished connecting.

        A client id unknown to ``self.clients`` is reported and killed.
        Otherwise this:

        * applies the client's iptables/ip6tables rules;
        * inserts a connection document describing the client into
          ``self.collection`` (on failure: logs, kills the client, returns);
        * stores the new ``doc_id`` and a timestamp on the client entry;
        * queues the id on ``self.clients_queue``;
        * reports the connection and sends an event.
        """
        client = self.clients.find_id(client_id)
        if not client:
            self.instance_com.push_output(
                'ERROR Unknown client connected client_id=%s' % client_id)
            self.instance_com.client_kill(client_id)
            return

        self.set_iptables_rules(
            client['iptables_rules'],
            client['ip6tables_rules'],
        )

        timestamp = utils.now()
        doc = {
            'user_id': client['user_id'],
            'server_id': self.server.id,
            'host_id': settings.local.host_id,
            'timestamp': timestamp,
            'platform': client['platform'],
            'type': client['user_type'],
            'device_name': client['device_name'],
            'mac_addr': client['mac_addr'],
            'network': self.server.network,
            'real_address': client['real_address'],
            'virt_address': client['virt_address'],
            'virt_address6': client['virt_address6'],
            'dns_servers': client['dns_servers'],
            'dns_suffix': client['dns_suffix'],
            # NOTE(review): strftime('%s') is a glibc extension and reads the
            # datetime as local time. If utils.now() returns naive UTC, this
            # is off by the host's UTC offset — confirm.
            # (calendar.timegm(timestamp.utctimetuple()) is the portable UTC
            # form.)
            'connected_since': int(timestamp.strftime('%s')),
        }

        # Enterprise subscriptions only: store an MD5 of
        # "<local part of user_name>.<org_name>" (lowercased) as a BSON MD5
        # binary.
        if settings.local.sub_active and \
                settings.local.sub_plan == 'enterprise':
            domain_hash = hashlib.md5()
            domain_hash.update((client['user_name'].split('@')[0] +
                                '.' + client['org_name']).lower())
            domain_hash = bson.binary.Binary(domain_hash.digest(),
                subtype=bson.binary.MD5_SUBTYPE)
            doc['domain'] = domain_hash

        try:
            doc_id = self.collection.insert(doc)
        except:
            logger.exception('Error adding client', 'server',
                server_id=self.server.id,
            )
            self.instance_com.client_kill(client_id)
            return

        self.clients.update_id(client_id, {
            'doc_id': doc_id,
            'timestamp': time.time(),
        })

        self.clients_queue.append(client_id)

        self.instance_com.push_output(
            'User connected user_id=%s' % client['user_id'])
        self.send_event()
Exemple #14
0
def run_thread():
    """Run on-start tasks once, then scheduled tasks each new second.

    Generator. Every pass (about every 0.5 s) checks whether the current
    epoch second differs from the last one handled. If it does, it runs
    every task class registered in ``task.tasks[hour][minute][second]``
    for each combination of the literal key ``'all'`` and the current
    hour/minute/second. Errors are logged and the loop continues. Yields
    once per pass.
    """
    try:
        for startup_cls in task.tasks_on_start:
            run_task(startup_cls())
    except:
        logger.exception('Error running on start tasks', 'runners')

    handled_second = None
    while True:
        try:
            now = utils.now()
            epoch_second = int(time.mktime(now.timetuple()))

            if epoch_second != handled_second:
                handled_second = epoch_second
                for hour_key in ('all', now.hour):
                    for minute_key in ('all', now.minute):
                        for second_key in ('all', now.second):
                            scheduled = task.tasks[hour_key][minute_key][
                                second_key]
                            for task_cls in scheduled:
                                run_task(task_cls())
        except:
            logger.exception('Error in tasks run thread', 'runners')

        time.sleep(0.5)
        yield
Exemple #15
0
    def _stress_thread(self):
        """Connect one synthetic client per certificate user (stress test).

        For every user of type ``CERT_CLIENT`` in every org of the server,
        passes ``self.clients.connect`` a client with a sequential
        ``client_id`` starting at 1, a random MAC address, a random remote IP,
        a new device id and a random device name. Errors are logged.
        """
        try:
            client_num = 0

            for org in self.server.iter_orgs():
                for user in org.iter_users():
                    if user.type == CERT_CLIENT:
                        client_num += 1
                        # Keyword arguments are evaluated in the same order
                        # as the original dict literal.
                        self.clients.connect(dict(
                            client_id=client_num,
                            key_id=1,
                            org_id=org.id,
                            user_id=user.id,
                            mac_addr=utils.rand_str(16),
                            remote_ip=str(ipaddress.IPAddress(
                                100000000 + random.randint(0, 1000000000))),
                            platform='linux',
                            device_id=str(bson.ObjectId()),
                            device_name=utils.random_name(),
                        ))
        except:
            logger.exception('Error in stress thread', 'server',
                server_id=self.server.id,
                instance_id=self.instance.id,
                socket_path=self.socket_path,
            )
Exemple #16
0
def auth_onelogin(username):
    """Check that ``username`` exists and is enabled in OneLogin.

    Queries the OneLogin v3 users API, authenticating with
    ``settings.app.sso_onelogin_key``.

    Returns:
        bool: True only for an HTTP 200 response. False when the user is not
        found (404) or disabled (406), for any other status, or when the
        request fails; each case is logged.
    """
    try:
        response = requests.get(
            # safe='' percent-encodes '/' too, so the username cannot change
            # the request path.
            ONELOGIN_URL + '/api/v3/users/username/%s' % (
                urllib.quote(username, safe='')),
            auth=(settings.app.sso_onelogin_key, 'x'),
        )
    except (httplib.HTTPException, requests.exceptions.RequestException):
        # requests reports connection errors and timeouts as
        # RequestException subclasses, not httplib.HTTPException.
        logger.exception('OneLogin api error', 'sso',
            username=username,
        )
        return False

    if response.status_code == 200:
        return True
    elif response.status_code == 404:
        logger.error('OneLogin user not found', 'sso',
            username=username,
        )
    elif response.status_code == 406:
        logger.error('OneLogin user disabled', 'sso',
            username=username,
        )
    else:
        logger.error('OneLogin api error', 'sso',
            username=username,
            status_code=response.status_code,
            response=response.content,
        )
    return False
Exemple #17
0
    def _watch_thread(self):
        """Expire per-client byte counters and flush bandwidth totals.

        Generator loop. Each pass:

        * drops ``client_bytes`` entries whose timestamp is more than 180 s
          older than ``self.cur_timestamp``;
        * takes and resets the ``bytes_recv``/``bytes_sent`` counters while
          holding ``bytes_lock``;
        * records non-zero totals via ``self.server.bandwidth.add_data``.

        It then yields ``interrupter_sleep(self.bandwidth_rate)`` and returns
        once ``self.instance.sock_interrupt`` is set. On an unexpected error
        it reports to the output, logs, and stops the instance process.
        """
        try:
            while True:
                self.cur_timestamp = utils.now()
                timestamp_ttl = self.cur_timestamp - datetime.timedelta(
                    seconds=180)

                # Iterate over a snapshot because entries are popped during
                # the scan. On Python 3 items() is a live view and popping
                # would raise RuntimeError.
                for client_id, (timestamp, _, _) in list(
                        self.client_bytes.items()):
                    if timestamp < timestamp_ttl:
                        self.client_bytes.pop(client_id, None)

                # The with-statement guarantees the lock is released even if
                # something between acquire and release raises.
                with self.bytes_lock:
                    bytes_recv = self.bytes_recv
                    bytes_sent = self.bytes_sent
                    self.bytes_recv = 0
                    self.bytes_sent = 0

                if bytes_recv != 0 or bytes_sent != 0:
                    self.server.bandwidth.add_data(
                        utils.now(), bytes_recv, bytes_sent)

                yield interrupter_sleep(self.bandwidth_rate)
                if self.instance.sock_interrupt:
                    return
        except GeneratorExit:
            raise
        except:
            self.push_output('ERROR Management thread error')
            logger.exception('Error in management watch thread', 'server',
                server_id=self.server.id,
                instance_id=self.instance.id,
            )
            self.instance.stop_process()
Exemple #18
0
    def disconnected(self, client_id):
        """Clean up after a client disconnects.

        Unknown client ids are ignored. Otherwise this:

        * removes the client from ``self.clients``;
        * for a dynamically addressed client, clears its ``virt_address`` and,
          if that update reports success, returns the address (prefix length
          stripped) to ``self.ip_pool``;
        * deletes the stored connection document when a ``doc_id`` is present
          (errors logged);
        * reports the disconnect and sends an event.
        """
        client = self.clients.find_id(client_id)
        if not client:
            return
        self.clients.remove_id(client_id)

        address = client['virt_address']
        # NOTE(review): this update targets the id removed just above.
        # Confirm that remove_id leaves the record matchable here; otherwise
        # the address is never returned to the pool.
        if client['address_dynamic'] and self.clients.update(
                {'id': client_id, 'virt_address': address},
                {'virt_address': None}):
            self.ip_pool.append(address.split('/')[0])

        stored_id = client.get('doc_id')
        if stored_id:
            try:
                self.collection.remove({'_id': stored_id})
            except:
                logger.exception('Error removing client', 'server',
                    server_id=self.server.id,
                )

        self.instance_com.push_output(
            'User disconnected user_id=%s' % client['user_id'])
        self.send_event()
Exemple #19
0
    def rollback_actions(self):
        """Roll back this transaction's actions if it is in ROLLBACK state.

        Logs a warning listing the action sets. It then pushes the
        transaction's ``ttl_timestamp`` to now + ``self.ttl`` seconds, but
        only if the document is in the ``ROLLBACK`` state; if nothing
        matched, it returns. Otherwise it runs ``self._rollback_actions()``
        (errors are logged and re-raised) and removes the transaction
        document.
        """
        logger.warning('Transaction failed rolling back...', 'transaction',
            actions=self.action_sets,
        )

        new_ttl = utils.now() + datetime.timedelta(seconds=self.ttl)
        result = self.transaction_collection.update(
            {'_id': self.id, 'state': ROLLBACK},
            {'$set': {'ttl_timestamp': new_ttl}},
        )

        if result['updatedExisting']:
            try:
                self._rollback_actions()
            except:
                logger.exception(
                    'Error occurred rolling back transaction actions',
                    'transaction',
                    transaction_id=self.id,
                )
                raise

            self.transaction_collection.remove(self.id)
Exemple #20
0
def init():
    """Initialize and persist this host's record at startup.

    Loads the existing host document if there is one (a missing document
    is not an error), marks the host ``ONLINE`` and resets its runtime
    fields. It then records the public addresses, hostname, local IPv4/IPv6
    addresses, instance id and local networks. Local addresses are detected
    automatically or from ``settings.conf.local_address_interface``.
    Detection failures are logged and the affected ``auto_*`` field is set
    to None. Finally commits the host and emits a ``HOSTS_UPDATED`` event.
    """
    settings.local.host = Host()

    try:
        settings.local.host.load()
    except NotFound:
        pass

    settings.local.host.status = ONLINE
    settings.local.host.users_online = 0
    settings.local.host.start_timestamp = utils.now()
    settings.local.host.ping_timestamp = utils.now()
    if settings.local.public_ip:
        settings.local.host.auto_public_address = settings.local.public_ip
    if settings.local.public_ip6:
        settings.local.host.auto_public_address6 = settings.local.public_ip6

    try:
        settings.local.host.hostname = socket.gethostname()
    except:
        logger.exception('Failed to get hostname', 'host')
        settings.local.host.hostname = None

    if settings.conf.local_address_interface == 'auto':
        try:
            settings.local.host.auto_local_address = utils.get_local_address()
        except:
            logger.exception('Failed to get auto_local_address', 'host')
            # Clear the auto-detected field (as the interface branch does),
            # not the separately configured local_address.
            settings.local.host.auto_local_address = None

        try:
            settings.local.host.auto_local_address6 = \
                utils.get_local_address6()
        except:
            logger.exception('Failed to get auto_local_address6', 'host')
            settings.local.host.auto_local_address6 = None
    else:
        try:
            settings.local.host.auto_local_address = \
                utils.get_interface_address(
                    str(settings.conf.local_address_interface))
        except:
            logger.exception('Failed to get auto_local_address', 'host',
                interface=settings.conf.local_address_interface)
            settings.local.host.auto_local_address = None

        try:
            settings.local.host.auto_local_address6 = \
                utils.get_interface_address6(
                    str(settings.conf.local_address_interface))
        except:
            logger.exception('Failed to get auto_local_address6', 'host',
                interface=settings.conf.local_address_interface)
            settings.local.host.auto_local_address6 = None

    settings.local.host.auto_instance_id = utils.get_instance_id()
    settings.local.host.local_networks = utils.get_local_networks()

    settings.local.host.commit()
    event.Event(type=HOSTS_UPDATED)
Exemple #21
0
    def _keep_alive_thread(self, semaphore, process):
        """Keep this server's ping timestamp fresh while ``process`` runs.

        Releases ``semaphore`` on entry, then loops until ``self._interrupt``
        is set. Each pass reloads the server. If ``self.instance_id`` no
        longer equals ``self._instance_id``, it signals ``process`` and
        re-checks after 0.5 s: SIGINT for the first three attempts, SIGKILL
        after that. Otherwise it commits a new UTC ``ping_timestamp`` (commit
        failures are logged) and sleeps ``settings.vpn.server_ping`` seconds.

        Args:
            semaphore: released once at start.
            process: the process to signal when this instance was replaced.
        """
        semaphore.release()
        exit_attempts = 0
        while not self._interrupt:
            self.load()
            if self._instance_id != self.instance_id:
                logger.info('Server instance removed, stopping server. %r' % {
                    'server_id': self.id,
                    'instance_id': self._instance_id,
                })
                # Escalate from SIGINT to SIGKILL after three attempts.
                if exit_attempts > 2:
                    process.send_signal(signal.SIGKILL)
                else:
                    process.send_signal(signal.SIGINT)
                exit_attempts += 1
                time.sleep(0.5)
                continue

            self.ping_timestamp = datetime.datetime.utcnow()
            try:
                self.commit('ping_timestamp')
            except:
                logger.exception('Failed to update server ping. %r' % {
                    'server_id': self.id,
                })
            time.sleep(settings.vpn.server_ping)
Exemple #22
0
def check_thread():
    """Re-run tasks whose TTL has expired (generator).

    Each pass finds tasks with ``ttl_timestamp`` before now and ``state``
    other than ``COMPLETE``. For each one it calls ``random_sleep()``, then
    unsets ``runner_id`` only if the task still matches both conditions,
    and runs the task when that update matched. Errors are logged. Yields
    ``interrupter_sleep(settings.mongo.task_ttl)`` between passes.
    """
    while True:
        try:
            now = utils.now()
            expired_tasks = task.iter_tasks({
                'ttl_timestamp': {'$lt': now},
                'state': {'$ne': COMPLETE},
            })

            for task_item in expired_tasks:
                random_sleep()

                # Conditional update: only one runner sees updatedExisting
                # for a given expired task.
                claim = task.Task.collection.update(
                    {
                        '_id': task_item.id,
                        'state': {'$ne': COMPLETE},
                        'ttl_timestamp': {'$lt': now},
                    },
                    {'$unset': {'runner_id': ''}},
                )
                if claim['updatedExisting']:
                    run_task(task_item)
        except:
            logger.exception('Error in task check thread', 'runners')

        yield interrupter_sleep(settings.mongo.task_ttl)
Exemple #23
0
def sync_time():
    """Record a local/MongoDB clock pair in ``settings.local.mongo_time``.

    Inserts a marker document tagged with a fresh ObjectId ``nounce`` and
    reads it back. It stores ``(local utcnow taken just before the read,
    generation time of the read-back document's _id)`` and then removes the
    marker. With ``manipulate=False`` the ``_id`` is presumably generated
    server-side, which is what makes it a server clock sample — TODO
    confirm.

    Raises:
        Any error from the round trip, re-raised after logging.
    """
    nounce = None
    doc = {}

    try:
        collection = mongo.get_collection('time_sync')

        nounce = ObjectId()
        collection.insert({
            'nounce': nounce,
        }, manipulate=False)

        mongo_time_start = datetime.datetime.utcnow()

        doc = collection.find_one({
            'nounce': nounce,
        })
        mongo_time = doc['_id'].generation_time.replace(tzinfo=None)

        settings.local.mongo_time = (mongo_time_start, mongo_time)

        collection.remove(doc['_id'])
    except:
        from pritunl import logger

        logger.exception('Failed to sync time',
            nounce=nounce,
            # The key is '_id' (not 'id'). find_one() may have returned None,
            # so guard against it instead of raising AttributeError here.
            doc_id=doc.get('_id') if doc else None,
        )
        raise
Exemple #24
0
def _on_msg(msg):
    """Start a server on this host in response to a ``'start'`` message.

    Ignores other messages and servers whose ``hosts`` do not include this
    host. When the message has no ``prefered_host``, one of the server's
    hosts is chosen from a PRNG seeded by the current epoch second. Hosts
    other than the preferred one wait 0.1 s (presumably to give the
    preferred host a head start) before calling ``svr.run``. Errors are
    logged.
    """
    if msg['message'] != 'start':
        return

    try:
        svr = server.get_server(msg['server_id'])
        if settings.local.host_id not in svr.hosts:
            return

        prefered_host = msg.get('prefered_host')

        # When the start message comes from check_thread, several servers
        # may send it. Choose a pseudo-random host seeded by the current
        # time in seconds, so that every host handling the message within
        # the same one-second window picks the same host. Keep this
        # algorithm identical across hosts.
        if not prefered_host:
            rand_hash = hashlib.sha256(str(int(time.time()))).digest()
            rand_gen = random.Random(rand_hash)
            prefered_host = svr.hosts[rand_gen.randint(0, len(svr.hosts) - 1)]

        if settings.local.host_id != prefered_host:
            time.sleep(0.1)

        svr.run(send_events=msg.get('send_events'))
    except:
        logger.exception('Failed to run server.')
Exemple #25
0
def _check_thread():
    """Retry transactions whose TTL has expired (generator).

    Each pass runs every ``transaction`` document whose ``ttl_timestamp``
    is before now, in ascending ``priority`` order. Failures of individual
    transactions are logged and do not stop the pass. After a pass it
    yields ``interrupter_sleep(settings.mongo.tran_ttl)``. On any other
    error it logs, sleeps 0.5 s and starts a new pass without yielding.
    """
    collection = mongo.get_collection('transaction')

    while True:
        try:
            spec = {
                'ttl_timestamp': {'$lt': utils.now()},
            }

            for doc in collection.find(spec).sort('priority'):
                logger.info('Transaction timeout retrying...', 'runners',
                    doc=doc,
                )

                try:
                    tran = transaction.Transaction(doc=doc)
                    tran.run()
                except:
                    logger.exception('Failed to run transaction', 'runners',
                        transaction_id=doc['_id'],
                    )

            yield interrupter_sleep(settings.mongo.tran_ttl)
        except GeneratorExit:
            # Let close() shut the generator down cleanly.
            raise
        except:
            logger.exception('Error in transaction runner thread', 'runners')
            time.sleep(0.5)
Exemple #26
0
def get_user_id(username):
    """Return the Okta id of an active user, looked up by login name.

    Args:
        username: Okta login; percent-encoded into the request path.

    Returns:
        The user's ``id`` from the Okta users API, or None when:

        * the request fails;
        * Okta answers with a non-200 status or a non-JSON body;
        * the response has no ``id``;
        * the user's ``status`` is not ``active``.

        Every None case is logged.
    """
    try:
        response = requests.get(
            # safe="" also encodes "/", so the username cannot alter the path.
            _getokta_url() + "/api/v1/users/%s" % urllib.quote(username, safe=""),
            headers={"Accept": "application/json", "Authorization": "SSWS %s" % settings.app.sso_okta_token},
        )
    except (httplib.HTTPException, requests.exceptions.RequestException):
        # requests reports connection errors/timeouts as RequestException.
        logger.exception("Okta api error", "sso", username=username)
        return None

    if response.status_code != 200:
        logger.error(
            "Okta api error", "sso", username=username, status_code=response.status_code, response=response.content
        )
        return None

    try:
        data = response.json()
    except ValueError:
        logger.exception("Okta api error", "sso", username=username, response=response.content)
        return None

    if "id" in data:
        # A missing or null status counts as inactive instead of raising.
        if (data.get("status") or "").lower() != "active":
            logger.error("Okta user is not active", "sso", username=username)
            return None
        return data["id"]

    logger.error(
        "Okta username not found", "sso", username=username, status_code=response.status_code, response=response.content
    )

    return None
Exemple #27
0
def get_mem_usage():
    """Return the fraction of physical memory in use, excluding buffers/cache.

    Parses ``free`` output (run via ``utils.check_output_logged``),
    locating rows instead of using fixed token positions so both layouts
    work:

    * older procps prints a ``-/+ buffers/cache:  used  free`` row whose
      *used* value already excludes buffers/cache;
    * newer procps-ng drops that row, and its ``Mem:`` *used* column
      excludes buffers/cache itself.

    The previous fixed index (token 15) picked the swap *used* value on
    the newer layout.

    Returns:
        float: used / total memory, or 0 if the command fails or its output
        cannot be parsed (the error is logged).
    """
    try:
        output = utils.check_output_logged(['free'])
        if isinstance(output, bytes):
            output = output.decode('utf-8', 'replace')
        rows = [line.split() for line in output.splitlines() if line.strip()]

        # rows[0] is the column header and rows[1] is the physical memory
        # row, taken by position so a localized label does not matter.
        mem_row = rows[1]
        total = float(mem_row[1])
        used = float(mem_row[2])

        # Older layout: prefer the cache-adjusted "used" figure.
        for row in rows[2:]:
            if row[0] == '-/+':
                used = float(row[2])
                break

        return used / total
    except Exception:
        logger.exception('Failed to get memory usage', 'host')
    return 0
Exemple #28
0
    def _stress_thread(self):
        """Connect one synthetic client per certificate user (stress test).

        Walks every org of the server and every ``CERT_CLIENT`` user in it,
        numbering the users from 1. For each one it passes
        ``self.clients.connect`` a client with a random MAC, a random remote
        IP, a new device id and a random device name. Errors are logged.
        """
        try:
            cert_users = (
                (org, user)
                for org in self.server.iter_orgs()
                for user in org.iter_users()
                if user.type == CERT_CLIENT
            )

            for number, (org, user) in enumerate(cert_users, 1):
                random_ip = ipaddress.IPAddress(
                    100000000 + random.randint(0, 1000000000)) \
                    if False else None
                del random_ip
                client = {
                    "client_id": number,
                    "key_id": 1,
                    "org_id": org.id,
                    "user_id": user.id,
                    "mac_addr": utils.rand_str(16),
                    "remote_ip": str(ipaddress.IPAddress(
                        100000000 + random.randint(0, 1000000000))),
                    "platform": "linux",
                    "device_id": str(bson.ObjectId()),
                    "device_name": utils.random_name(),
                }
                self.clients.connect(client)
        except:
            logger.exception(
                "Error in stress thread",
                "server",
                server_id=self.server.id,
                instance_id=self.instance.id,
                socket_path=self.socket_path,
            )
Exemple #29
0
def _host_check_thread():
    """Mark hosts offline when their last ping is older than ``host_ttl``.

    Generator loop. Each pass selects hosts whose ``ping_timestamp`` is
    more than ``settings.app.host_ttl`` seconds old. For each one it sets
    ``status`` to ``OFFLINE`` and clears ``ping_timestamp``, but only if
    the host still matches that condition, and emits a ``HOSTS_UPDATED``
    event when the update applied. Errors are logged. Yields
    ``interrupter_sleep(settings.app.host_ttl)`` between passes.
    """
    collection = mongo.get_collection('hosts')

    while True:
        try:
            cutoff = utils.now() - datetime.timedelta(
                seconds=settings.app.host_ttl)
            stale_ping = {'$lt': cutoff}

            stale_hosts = collection.find(
                {'ping_timestamp': stale_ping},
                {'_id': True},
            )
            for host_doc in stale_hosts:
                # Re-check staleness in the update so a host that pinged in
                # the meantime is left alone.
                result = collection.update(
                    {'_id': host_doc['_id'], 'ping_timestamp': stale_ping},
                    {'$set': {'status': OFFLINE, 'ping_timestamp': None}},
                )
                if result['updatedExisting']:
                    event.Event(type=HOSTS_UPDATED)
        except GeneratorExit:
            raise
        except:
            logger.exception('Error checking host status', 'runners')

        yield interrupter_sleep(settings.app.host_ttl)
Exemple #30
0
    def _sub_thread(self, cursor_id):
        """Handle control messages for this server instance (generator).

        Iterates ``self.subscribe(cursor_id=cursor_id)``, yielding once per
        message, and returns when ``self.interrupt`` is set. Messages:

        * ``'stop'``: calls ``self.stop_process()``; sets ``clean_exit`` if it
          returned truthy.
        * ``'rebalance'``: does the same, but only when this host's
          ``availability_group`` differs from the message's.
        * ``'force_stop'``: stops every link in ``self.server_links``, sets
          ``clean_exit``, then sends SIGKILL to ``self.process`` ten times,
          10 ms apart.

        OSError while handling a message is ignored. Any other error is
        logged and ``self.stop_process()`` is called.
        """
        try:
            for msg in self.subscribe(cursor_id=cursor_id):
                # NOTE(review): if this generator is closed while suspended
                # here, the GeneratorExit is caught by the bare except below,
                # which logs it and calls stop_process(). Confirm that is
                # intended.
                yield

                if self.interrupt:
                    return
                message = msg['message']

                try:
                    if message == 'stop':
                        if self.stop_process():
                            self.clean_exit = True
                    elif message == 'rebalance':
                        # Only instances outside the requested availability
                        # group stop.
                        if settings.local.host.availability_group != \
                                msg['availability_group']:
                            if self.stop_process():
                                self.clean_exit = True
                    elif message == 'force_stop':
                        for instance_link in self.server_links:
                            instance_link.stop()

                        self.clean_exit = True
                        # Repeated SIGKILLs. OSError (presumably the process
                        # is already gone) is swallowed below.
                        for _ in xrange(10):
                            self.process.send_signal(signal.SIGKILL)
                            time.sleep(0.01)
                except OSError:
                    pass
        except:
            logger.exception('Exception in messaging thread', 'server',
                server_id=self.server.id,
            )
            self.stop_process()
Exemple #31
0
def _run_server(restart):
    """Run the internal WSGI server behind a ``pritunl-web`` front process.

    Creates a ``CherryPyWSGIServerLogged`` on
    ``localhost:settings.app.server_internal_port`` (stored in the global
    ``app_server``). Launches ``pritunl-web`` and passes it its
    bind/proxy/TLS settings through environment variables. With SSL
    enabled, secure session cookies are turned on and the server
    certificate is written out, and its paths are passed to pritunl-web.

    A daemon thread watches pritunl-web. If it exits with a non-zero status
    while this server is still running and no global interrupt is set, the
    thread logs the process output and calls ``restart_server(1)``.

    Blocks in ``app_server.start()``. pritunl-web is killed when that call
    returns or raises.

    Args:
        restart: when false (first start), sets ``server_ready`` and waits
            for ``server_start`` before launching pritunl-web.

    Raises:
        ServerRestart: propagated unchanged.
        Any other server error except KeyboardInterrupt/SystemExit (which
        make this return), after logging it.
    """
    global app_server

    logger.info('Starting server', 'app')

    app_server = CherryPyWSGIServerLogged(
        ('localhost', settings.app.server_internal_port),
        app,
        request_queue_size=settings.app.request_queue_size,
        numthreads=settings.app.request_thread_count,
        shutdown_timeout=3,
    )
    app_server.server_name = ''

    server_cert_path = None
    server_key_path = None
    redirect_server = 'true' if settings.app.redirect_server else 'false'
    internal_addr = 'localhost:%s' % settings.app.server_internal_port

    if settings.app.server_ssl:
        app.config.update(SESSION_COOKIE_SECURE=True, )

        setup_server_cert()

        server_cert_path, server_key_path = utils.write_server_cert(
            settings.app.server_cert,
            settings.app.server_key,
            settings.app.acme_domain,
        )

    # First start only: signal readiness and wait to be told to start.
    if not restart:
        settings.local.server_ready.set()
        settings.local.server_start.wait()

    # Read by poll_thread through the closure. It is cleared in the finally
    # block below, so an exit caused by our own kill() is not treated as a
    # crash.
    process_state = True
    process = subprocess.Popen(
        ['pritunl-web'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=dict(os.environ, **{
            'REVERSE_PROXY_HEADER': settings.app.reverse_proxy_header if \
                settings.app.reverse_proxy else '',
            'REDIRECT_SERVER': redirect_server,
            'BIND_HOST': settings.conf.bind_addr,
            'BIND_PORT': str(settings.app.server_port),
            'INTERNAL_ADDRESS': internal_addr,
            'CERT_PATH': server_cert_path or '',
            'KEY_PATH': server_key_path or '',
        }),
    )

    def poll_thread():
        # NOTE(review): wait() with stdout/stderr=PIPE can deadlock if
        # pritunl-web fills a pipe buffer before exiting (see the subprocess
        # docs). _communicate is also a private Popen API.
        time.sleep(0.5)
        if process.wait() and process_state:
            time.sleep(0.25)
            if not check_global_interrupt():
                stdout, stderr = process._communicate(None)
                logger.error(
                    'Web server process exited unexpectedly',
                    'app',
                    stdout=stdout,
                    stderr=stderr,
                )
                time.sleep(1)
                restart_server(1)

    thread = threading.Thread(target=poll_thread)
    thread.daemon = True
    thread.start()

    _watch_event.set()

    try:
        app_server.start()
    except (KeyboardInterrupt, SystemExit):
        return
    except ServerRestart:
        raise
    except:
        logger.exception('Server error occurred', 'app')
        raise
    finally:
        process_state = False
        # Best effort: the process may already have exited.
        try:
            process.kill()
        except:
            pass
Exemple #32
0
def settings_put():
    """Apply a settings update submitted by the current administrator.

    Only keys present in the request's JSON body are processed:

    * ``username``/``password``/``token``/``secret`` update the current
      administrator (``flask.g.administrator``), saved via ``admin.commit``.
    * Web server, ACME, auditing, monitoring, SMTP, pin mode, SSO, theme,
      reverse proxy, cloud provider and AWS credential keys update
      ``settings.app`` / ``settings.user``. They are saved by a single
      ``settings.commit()`` after all fields have been processed.
    * ``public_address``/``public_address6``/``routed_subnet6`` update
      ``settings.local.host`` and are committed immediately.

    Every recorded change kind is passed to ``admin.audit_event``, and a
    ``SETTINGS_UPDATED`` event is emitted. The web server is updated, and its
    certificate regenerated or requested from ACME, when relevant fields
    changed.

    NOTE(review): validation failures return early. ``settings.app``
    attributes assigned before that point are left modified in memory but not
    committed by this call. Host fields processed before ``routed_subnet6``
    were already committed.

    Returns:
        ``utils.demo_blocked()`` in demo mode. Otherwise, a JSON error with
        status 400 when the port is invalid or reserved, auditing is changed
        by a non-super-user, the routed IPv6 subnet is invalid, larger than
        /64 prefix length, or cannot change while IPv6 clients are online,
        or the LetsEncrypt certificate request fails. On success, the
        administrator dict merged with ``_dict()`` as JSON.
    """
    if settings.app.demo_mode:
        return utils.demo_blocked()

    data = flask.request.json
    admin = flask.g.administrator
    changes = set()

    org_event = False
    settings_commit = False
    update_server = False
    update_acme = False
    update_cert = False

    # Administrator account fields; empty values are ignored.
    if 'username' in data and data['username']:
        username = utils.filter_str(data['username']).lower()
        if username != admin.username:
            changes.add('username')
        admin.username = username

    if 'password' in data and data['password']:
        changes.add('password')
        admin.password = data['password']

    if 'token' in data and data['token']:
        admin.generate_token()
        changes.add('token')

    if 'secret' in data and data['secret']:
        admin.generate_secret()
        # NOTE(review): a secret regeneration is audited under 'token';
        # confirm whether _changes_audit_text has a dedicated 'secret' entry.
        changes.add('token')

    # Web server certificate, key and port; a changed value flags a server
    # update at the end of the request.
    if 'server_cert' in data:
        settings_commit = True
        server_cert = data['server_cert']
        server_cert = server_cert.strip() if server_cert else None
        if server_cert != settings.app.server_cert:
            update_server = True
        settings.app.server_cert = server_cert

    if 'server_key' in data:
        settings_commit = True
        server_key = data['server_key']
        server_key = server_key.strip() if server_key else None
        if server_key != settings.app.server_key:
            update_server = True
        settings.app.server_key = server_key

    if 'server_port' in data:
        settings_commit = True

        # An empty value falls back to port 443.
        server_port = data['server_port'] or 443
        try:
            server_port = int(server_port)
            if server_port < 1 or server_port > 65535:
                raise ValueError('Port invalid')
        except (TypeError, ValueError):
            # TypeError covers non-scalar JSON values (lists, objects) that
            # int() cannot convert; they are rejected like any bad port.
            return utils.jsonify({
                'error': PORT_INVALID,
                'error_msg': PORT_INVALID_MSG,
            }, 400)

        # Port 80 is reserved while the redirect server is enabled.
        if settings.app.redirect_server and server_port == 80:
            return utils.jsonify({
                'error': PORT_RESERVED,
                'error_msg': PORT_RESERVED_MSG,
            }, 400)

        if server_port != settings.app.server_port:
            update_server = True
        settings.app.server_port = server_port

    if 'acme_domain' in data:
        settings_commit = True

        acme_domain = utils.filter_str(data['acme_domain'] or None)
        if acme_domain:
            # Accept a pasted URL by stripping the scheme and any slashes.
            acme_domain = acme_domain.replace('https://', '')
            acme_domain = acme_domain.replace('http://', '')
            acme_domain = acme_domain.replace('/', '')

        if acme_domain != settings.app.acme_domain:
            if not acme_domain:
                # Clearing the domain discards the ACME state and the current
                # server key/cert; a new certificate is created further down.
                settings.app.acme_key = None
                settings.app.acme_timestamp = None
                settings.app.server_key = None
                settings.app.server_cert = None
                update_server = True
                update_cert = True
            else:
                update_acme = True
        settings.app.acme_domain = acme_domain

    if 'auditing' in data:
        settings_commit = True
        auditing = data['auditing'] or None

        if settings.app.auditing != auditing:
            if not admin.super_user:
                return utils.jsonify({
                    'error': REQUIRES_SUPER_USER,
                    'error_msg': REQUIRES_SUPER_USER_MSG,
                }, 400)
            org_event = True

        settings.app.auditing = auditing

    if 'monitoring' in data:
        settings_commit = True
        settings.app.monitoring = data['monitoring'] or None

    if 'influxdb_uri' in data:
        settings_commit = True
        settings.app.influxdb_uri = data['influxdb_uri'] or None

    # SMTP settings; any change is audited as 'smtp'.
    if 'email_from' in data:
        settings_commit = True
        email_from = data['email_from'] or None
        if email_from != settings.app.email_from:
            changes.add('smtp')
        settings.app.email_from = email_from

    if 'email_server' in data:
        settings_commit = True
        email_server = data['email_server'] or None
        if email_server != settings.app.email_server:
            changes.add('smtp')
        settings.app.email_server = email_server

    if 'email_username' in data:
        settings_commit = True
        email_username = data['email_username'] or None
        if email_username != settings.app.email_username:
            changes.add('smtp')
        settings.app.email_username = email_username

    if 'email_password' in data:
        settings_commit = True
        email_password = data['email_password'] or None
        if email_password != settings.app.email_password:
            changes.add('smtp')
        settings.app.email_password = email_password

    if 'pin_mode' in data:
        settings_commit = True
        pin_mode = data['pin_mode'] or None
        if pin_mode != settings.user.pin_mode:
            changes.add('pin_mode')
        settings.user.pin_mode = pin_mode

    # Single sign-on settings; any change is audited as 'sso'.
    if 'sso' in data:
        org_event = True
        settings_commit = True
        sso = data['sso'] or None
        if sso != settings.app.sso:
            changes.add('sso')
        settings.app.sso = sso

    if 'sso_match' in data:
        settings_commit = True
        sso_match = data['sso_match'] or None

        if sso_match != settings.app.sso_match:
            changes.add('sso')

        # Only a list is stored; any other value clears the setting.
        if isinstance(sso_match, list):
            settings.app.sso_match = sso_match
        else:
            settings.app.sso_match = None

    if 'sso_token' in data:
        settings_commit = True
        sso_token = data['sso_token'] or None
        if sso_token != settings.app.sso_token:
            changes.add('sso')
        settings.app.sso_token = sso_token

    if 'sso_secret' in data:
        settings_commit = True
        sso_secret = data['sso_secret'] or None
        if sso_secret != settings.app.sso_secret:
            changes.add('sso')
        settings.app.sso_secret = sso_secret

    if 'sso_host' in data:
        settings_commit = True
        sso_host = data['sso_host'] or None
        if sso_host != settings.app.sso_host:
            changes.add('sso')
        settings.app.sso_host = sso_host

    if 'sso_org' in data:
        settings_commit = True
        sso_org = data['sso_org']
        sso_org = utils.ObjectId(sso_org) if sso_org else None

        if sso_org != settings.app.sso_org:
            changes.add('sso')

        settings.app.sso_org = sso_org

    if 'sso_saml_url' in data:
        settings_commit = True
        sso_saml_url = data['sso_saml_url'] or None
        if sso_saml_url != settings.app.sso_saml_url:
            changes.add('sso')
        settings.app.sso_saml_url = sso_saml_url

    if 'sso_saml_issuer_url' in data:
        settings_commit = True
        sso_saml_issuer_url = data['sso_saml_issuer_url'] or None
        if sso_saml_issuer_url != settings.app.sso_saml_issuer_url:
            changes.add('sso')
        settings.app.sso_saml_issuer_url = sso_saml_issuer_url

    if 'sso_saml_cert' in data:
        settings_commit = True
        sso_saml_cert = data['sso_saml_cert'] or None
        if sso_saml_cert != settings.app.sso_saml_cert:
            changes.add('sso')
        settings.app.sso_saml_cert = sso_saml_cert

    if 'sso_okta_token' in data:
        settings_commit = True
        sso_okta_token = data['sso_okta_token'] or None
        if sso_okta_token != settings.app.sso_okta_token:
            changes.add('sso')
        settings.app.sso_okta_token = sso_okta_token

    if 'sso_onelogin_key' in data:
        settings_commit = True
        sso_onelogin_key = data['sso_onelogin_key'] or None
        if sso_onelogin_key != settings.app.sso_onelogin_key:
            changes.add('sso')
        settings.app.sso_onelogin_key = sso_onelogin_key

    if 'theme' in data:
        settings_commit = True
        # Anything other than 'dark' is treated as 'light'.
        theme = 'dark' if data['theme'] == 'dark' else 'light'

        if theme != settings.app.theme:
            if theme == 'dark':
                event.Event(type=THEME_DARK)
            else:
                event.Event(type=THEME_LIGHT)

        settings.app.theme = theme

    # Per-host fields are committed to the host document right away.
    if 'public_address' in data:
        public_address = data['public_address']
        settings.local.host.public_address = public_address
        settings.local.host.commit('public_address')

    if 'public_address6' in data:
        public_address6 = data['public_address6']
        settings.local.host.public_address6 = public_address6
        settings.local.host.commit('public_address6')

    if 'routed_subnet6' in data:
        routed_subnet6 = data['routed_subnet6']
        if routed_subnet6:
            try:
                routed_subnet6 = ipaddress.IPv6Network(routed_subnet6)
            except (ipaddress.AddressValueError, ValueError):
                return utils.jsonify({
                    'error': IPV6_SUBNET_INVALID,
                    'error_msg': IPV6_SUBNET_INVALID_MSG,
                }, 400)

            # The routed subnet must be a /64 or larger.
            if routed_subnet6.prefixlen > 64:
                return utils.jsonify({
                    'error': IPV6_SUBNET_SIZE_INVALID,
                    'error_msg': IPV6_SUBNET_SIZE_INVALID_MSG,
                }, 400)

            routed_subnet6 = str(routed_subnet6)
        else:
            routed_subnet6 = None

        if settings.local.host.routed_subnet6 != routed_subnet6:
            # Refuse to change the subnet while IPv6 clients are online.
            if server.get_online_ipv6_count():
                return utils.jsonify({
                    'error': IPV6_SUBNET_ONLINE,
                    'error_msg': IPV6_SUBNET_ONLINE_MSG,
                }, 400)
            settings.local.host.routed_subnet6 = routed_subnet6
            settings.local.host.commit('routed_subnet6')

    if 'reverse_proxy' in data:
        settings_commit = True
        settings.app.reverse_proxy = bool(data['reverse_proxy'])

    if 'cloud_provider' in data:
        settings_commit = True
        settings.app.cloud_provider = data['cloud_provider'] or None

    for aws_key in (
        'us_east_1_access_key',
        'us_east_1_secret_key',
        'us_west_1_access_key',
        'us_west_1_secret_key',
        'us_west_2_access_key',
        'us_west_2_secret_key',
        'eu_west_1_access_key',
        'eu_west_1_secret_key',
        'eu_central_1_access_key',
        'eu_central_1_secret_key',
        'ap_northeast_1_access_key',
        'ap_northeast_1_secret_key',
        'ap_northeast_2_access_key',
        'ap_northeast_2_secret_key',
        'ap_southeast_1_access_key',
        'ap_southeast_1_secret_key',
        'ap_southeast_2_access_key',
        'ap_southeast_2_secret_key',
        'sa_east_1_access_key',
        'sa_east_1_secret_key',
    ):
        if aws_key in data:
            settings_commit = True
            aws_value = data[aws_key]
            setattr(settings.app, aws_key,
                    utils.filter_str(aws_value) if aws_value else None)

    # With SSO disabled, drop every provider-specific SSO setting.
    if not settings.app.sso:
        settings.app.sso_match = None
        settings.app.sso_token = None
        settings.app.sso_secret = None
        settings.app.sso_host = None
        settings.app.sso_org = None
        settings.app.sso_saml_url = None
        settings.app.sso_saml_issuer_url = None
        settings.app.sso_saml_cert = None
        settings.app.sso_okta_token = None
        settings.app.sso_onelogin_key = None

    for change in changes:
        admin.audit_event(
            'admin_settings',
            _changes_audit_text[change],
            remote_addr=utils.get_remote_addr(),
        )

    if settings_commit:
        settings.commit()

    admin.commit(admin.changed)

    if org_event:
        # ('_id',) with the trailing comma is a one-element tuple; the
        # previous ('_id') was just the string '_id'.
        for org in organization.iter_orgs(fields=('_id',)):
            event.Event(type=USERS_UPDATED, resource_id=org.id)

    event.Event(type=SETTINGS_UPDATED)

    if update_acme:
        try:
            acme.update_acme_cert()
            app.update_server(0.5)
        except Exception:
            logger.exception('Failed to get LetsEncrypt cert', 'handler',
                acme_domain=settings.app.acme_domain,
            )
            # Roll back the ACME domain so the failed request is not retried
            # with stale state.
            settings.app.acme_domain = None
            settings.app.acme_key = None
            settings.commit()
            return utils.jsonify({
                'error': ACME_ERROR,
                'error_msg': ACME_ERROR_MSG,
            }, 400)
    elif update_cert:
        utils.create_server_cert()
        app.update_server(0.5)
    elif update_server:
        app.update_server(0.5)

    response = admin.dict()
    response.update(_dict())
    return utils.jsonify(response)
Exemple #33
0
# Collections dropped by the ``destroy-secondary`` command (and, as the first
# part of a longer list, by ``repair-database``).
_SECONDARY_COLLECTIONS = (
    'clients',
    'clients_pool',
    'transaction',
    'queue',
    'tasks',
    'messages',
    'users_key_link',
    'auth_sessions',
    'auth_csrf_tokens',
    'auth_limiter',
    'otp',
    'otp_cache',
    'sso_tokens',
    'sso_push_cache',
    'sso_client_cache',
    'sso_passcode_cache',
)

# Collections dropped by ``repair-database``: the secondary collections plus
# ACME challenges, logs and the server IP pools.
_REPAIR_COLLECTIONS = _SECONDARY_COLLECTIONS + (
    'acme_challenges',
    'logs',
    'log_entries',
    'servers_ip_pool',
)


def _drop_collections(mongo, names):
    """Drop every collection in *names*, in order, via ``mongo.get_collection``."""
    for name in names:
        mongo.get_collection(name).drop()


def _mark_servers_offline(server_coll):
    """Set every server document offline with no instances and no lock."""
    server_coll.update_many({}, {
        '$set': {
            'status': 'offline',
            'instances': [],
            'instances_count': 0,
        },
        '$unset': {
            'network_lock': '',
            'network_lock_ttl': '',
        },
    })


def main(default_conf=None):
    """Entry point for the pritunl command line.

    ``sys.argv[1]`` selects the command; it defaults to ``start``. Every
    command other than ``start`` performs its maintenance task and ends the
    process with ``sys.exit(0)``. ``start`` optionally daemonizes: with
    ``--daemon`` the parent forks, optionally writes the child PID to
    ``--pidfile``, and exits. Otherwise the banner is printed unless
    ``--quiet`` is given. Execution then continues into
    ``pritunl.init_server()``.

    Args:
        default_conf: Configuration file path used when ``--conf`` is not
            given.

    Raises:
        ValueError: For an unknown command or a wrong number of arguments.
    """
    cmd = sys.argv[1] if len(sys.argv) > 1 else 'start'

    parser = optparse.OptionParser(usage=USAGE)

    if cmd == 'start':
        parser.add_option('-d', '--daemon', action='store_true',
            help='Daemonize process')
        parser.add_option('-p', '--pidfile', type='string',
            help='Path to create pid file')
        parser.add_option('-c', '--conf', type='string',
            help='Path to configuration file')
        parser.add_option('-q', '--quiet', action='store_true',
            help='Suppress logging output')
    elif cmd == 'logs':
        parser.add_option('--archive', action='store_true',
            help='Archive log file')
        parser.add_option('--tail', action='store_true',
            help='Tail log file')
        parser.add_option('--limit', type='int',
            help='Limit log lines')
        parser.add_option('--natural', action='store_true',
            help='Natural log sort')
    elif cmd == 'set':
        # Stop option parsing at the first positional argument so values
        # that start with '-' are taken literally.
        parser.disable_interspersed_args()

    (options, args) = parser.parse_args()

    # Only 'start' defines --conf; other commands use the default path.
    pritunl.set_conf_path(getattr(options, 'conf', None) or default_conf)

    if cmd == 'version':
        print('%s v%s' % (pritunl.__title__, pritunl.__version__))
        sys.exit(0)
    elif cmd == 'setup-key':
        from pritunl import setup
        from pritunl import settings

        setup.setup_loc()
        print(settings.local.setup_key)

        sys.exit(0)
    elif cmd == 'reset-version':
        from pritunl.constants import MIN_DATABASE_VER
        from pritunl import setup
        from pritunl import utils

        setup.setup_db()
        utils.set_db_ver(pritunl.__version__, MIN_DATABASE_VER)

        time.sleep(.2)
        print('Database version reset to %s' % pritunl.__version__)

        sys.exit(0)
    elif cmd == 'reset-password':
        from pritunl import setup
        from pritunl import auth

        setup.setup_db()
        username, password = auth.reset_password()

        print('Administrator password successfully reset:\n' +
            '  username: "%s"\n  password: "%s"' % (username, password))

        sys.exit(0)
    elif cmd == 'default-password':
        from pritunl import setup
        from pritunl import auth

        setup.setup_db()
        username, password = auth.get_default_password()

        if not password:
            print('No default password available, use reset-password')
        else:
            print('Administrator default password:\n' +
                '  username: "%s"\n  password: "%s"' % (username, password))

        sys.exit(0)
    elif cmd == 'reconfigure':
        from pritunl import setup
        from pritunl import settings
        setup.setup_loc()

        settings.conf.mongodb_uri = None
        settings.conf.commit()

        time.sleep(.2)
        print('Database configuration successfully reset')

        sys.exit(0)
    elif cmd == 'get':
        from pritunl import setup
        from pritunl import settings
        setup.setup_db_host()

        if len(args) != 2:
            raise ValueError('Invalid arguments')

        # Argument is "group" or "group.key".
        split = args[1].split('.')
        group_str = split[0]
        key_str = split[1] if len(split) > 1 else None

        if group_str == 'host':
            group = settings.local.host
        else:
            group = getattr(settings, group_str)

        if key_str:
            val = getattr(group, key_str)
            print('%s.%s = %s' % (group_str, key_str,
                json.dumps(val, default=str)))
        else:
            for field in group.fields:
                val = getattr(group, field)
                print('%s.%s = %s' % (group_str, field,
                    json.dumps(val, default=str)))

        sys.exit(0)
    elif cmd == 'set':
        from pritunl.constants import HOSTS_UPDATED
        from pritunl import setup
        from pritunl import settings
        from pritunl import event
        from pritunl import messenger
        setup.setup_db_host()

        if len(args) != 3:
            raise ValueError('Invalid arguments')

        group_str, key_str = args[1].split('.')

        if group_str == 'host':
            group = settings.local.host
        else:
            group = getattr(settings, group_str)

        val_str = args[2]
        try:
            val = json.loads(val_str)
        except ValueError:
            # Not valid JSON: store the raw argument as a plain string.
            val = val_str

        setattr(group, key_str, val)

        if group_str == 'host':
            settings.local.host.commit()

            event.Event(type=HOSTS_UPDATED)
            messenger.publish('hosts', 'updated')
        else:
            settings.commit()

        time.sleep(.2)

        print('%s.%s = %s' % (group_str, key_str,
            json.dumps(getattr(group, key_str), default=str)))
        print('Successfully updated configuration. This change is '
            'stored in the database and has been applied to all hosts '
            'in the cluster.')

        sys.exit(0)
    elif cmd == 'unset':
        from pritunl import setup
        from pritunl import settings
        setup.setup_db()

        if len(args) != 2:
            raise ValueError('Invalid arguments')

        group_str, key_str = args[1].split('.')

        group = getattr(settings, group_str)

        group.unset(key_str)

        settings.commit()

        time.sleep(.2)

        print('%s.%s = %s' % (group_str, key_str,
            json.dumps(getattr(group, key_str), default=str)))
        print('Successfully updated configuration. This change is '
            'stored in the database and has been applied to all hosts '
            'in the cluster.')

        sys.exit(0)
    elif cmd == 'get-mongodb':
        from pritunl import setup
        from pritunl import settings
        setup.setup_loc()

        print(settings.conf.mongodb_uri)

        sys.exit(0)
    elif cmd == 'set-mongodb':
        from pritunl import setup
        from pritunl import settings
        setup.setup_loc()

        mongodb_uri = args[1] if len(args) > 1 else None

        settings.conf.mongodb_uri = mongodb_uri
        settings.conf.commit()

        time.sleep(.2)
        print('Database configuration successfully set')

        sys.exit(0)
    elif cmd == 'get-host-id':
        from pritunl import setup
        from pritunl import settings
        setup.setup_loc()

        print(settings.local.host_id)

        sys.exit(0)
    elif cmd == 'set-host-id':
        from pritunl import setup
        from pritunl import settings
        setup.setup_loc()

        host_id = args[1].strip() if len(args) > 1 else ''
        if not host_id:
            # Validate before opening: mode 'w' truncates the existing file,
            # and a missing argument previously crashed after truncation.
            raise ValueError('Invalid arguments')

        with open(settings.conf.uuid_path, 'w') as uuid_file:
            uuid_file.write(host_id)

        time.sleep(.2)
        print('Host ID successfully set')

        sys.exit(0)
    elif cmd == 'reset-ssl-cert':
        from pritunl import setup
        from pritunl import settings
        setup.setup_db()

        settings.app.server_cert = None
        settings.app.server_key = None
        settings.app.acme_timestamp = None
        settings.app.acme_key = None
        settings.app.acme_domain = None
        settings.commit()

        time.sleep(.2)
        print('Server ssl certificate successfully reset')

        sys.exit(0)
    elif cmd == 'destroy-secondary':
        from pritunl import setup
        from pritunl import logger
        from pritunl import mongo

        setup.setup_db()

        print('Destroying secondary database...')

        _drop_collections(mongo, _SECONDARY_COLLECTIONS)

        setup.upsert_indexes()

        _mark_servers_offline(mongo.get_collection('servers'))

        print('Secondary database destroyed')

        sys.exit(0)
    elif cmd == 'repair-database':
        from pritunl import setup
        from pritunl import logger
        from pritunl import mongo

        setup.setup_db()

        print('Repairing database...')

        _drop_collections(mongo, _REPAIR_COLLECTIONS)

        setup.upsert_indexes()

        server_coll = mongo.get_collection('servers')
        _mark_servers_offline(server_coll)

        from pritunl import server

        for svr in server.iter_servers():
            try:
                svr.ip_pool.sync_ip_pool()
            except Exception:
                # Best effort: log and continue with the remaining servers.
                logger.exception('Failed to sync server IP pool', 'tasks',
                    server_id=svr.id,
                )

        # Reset the server documents again after the IP pool sync.
        _mark_servers_offline(server_coll)

        print('Database repair complete')

        sys.exit(0)
    elif cmd == 'logs':
        from pritunl import setup
        from pritunl import logger
        setup.setup_db()

        log_view = logger.LogView()

        if options.archive:
            archive_path = args[1] if len(args) > 1 else './'
            print('Log archived to: ' + log_view.archive_log(archive_path,
                options.natural, options.limit))
        elif options.tail:
            for msg in log_view.tail_log_lines():
                print(msg)
        else:
            print(log_view.get_log_lines(
                natural=options.natural,
                limit=options.limit,
            ))

        sys.exit(0)
    elif cmd == 'clear-auth-limit':
        from pritunl import setup
        from pritunl import logger
        from pritunl import mongo
        from pritunl import settings

        setup.setup_db()

        mongo.get_collection('auth_limiter').delete_many({})

        print('Auth limiter cleared')

        sys.exit(0)
    elif cmd == 'clear-logs':
        from pritunl import setup
        from pritunl import logger
        from pritunl import mongo
        from pritunl import settings

        setup.setup_db()

        mongo.get_collection('logs').drop()
        mongo.get_collection('log_entries').drop()

        prefix = settings.conf.mongodb_collection_prefix or ''

        # Recreate both log collections as capped collections.
        log_limit = settings.app.log_limit
        mongo.database.create_collection(prefix + 'logs', capped=True,
            size=log_limit * 1024, max=log_limit)

        log_entry_limit = settings.app.log_entry_limit
        mongo.database.create_collection(prefix + 'log_entries', capped=True,
            size=log_entry_limit * 512, max=log_entry_limit)

        print('Log entries cleared')

        sys.exit(0)
    elif cmd != 'start':
        raise ValueError('Invalid command')

    from pritunl import settings

    if options.quiet:
        settings.local.quiet = True

    if options.daemon:
        pid = os.fork()
        if pid > 0:
            # Parent process: record the child's PID if requested, then exit.
            if options.pidfile:
                with open(options.pidfile, 'w') as pid_file:
                    pid_file.write('%s' % pid)
            sys.exit(0)
    elif not options.quiet:
        # Raw strings: the art contains backslashes that are not escapes.
        print(r'##############################################################')
        print(r'#                                                            #')
        print(r'#                      /$$   /$$                         /$$ #')
        print(r'#                     |__/  | $$                        | $$ #')
        print(r'#   /$$$$$$   /$$$$$$  /$$ /$$$$$$   /$$   /$$ /$$$$$$$ | $$ #')
        print(r'#  /$$__  $$ /$$__  $$| $$|_  $$_/  | $$  | $$| $$__  $$| $$ #')
        print(r'# | $$  \ $$| $$  \__/| $$  | $$    | $$  | $$| $$  \ $$| $$ #')
        print(r'# | $$  | $$| $$      | $$  | $$ /$$| $$  | $$| $$  | $$| $$ #')
        print(r'# | $$$$$$$/| $$      | $$  |  $$$$/|  $$$$$$/| $$  | $$| $$ #')
        print(r'# | $$____/ |__/      |__/   \____/  \______/ |__/  |__/|__/ #')
        print(r'# | $$                                                       #')
        print(r'# | $$                                                       #')
        print(r'# |__/                                                       #')
        print(r'#                                                            #')
        print(r'##############################################################')

    pritunl.init_server()
Exemple #34
0
def upsert_indexes():
    """Create required capped collections and upsert every index.

    Missing capped collections are created first: ``logs`` and
    ``log_entries`` on the primary database, ``messages`` on the secondary
    database. Then all regular and TTL indexes are upserted through
    ``upsert_index``. Finally ``clean_indexes`` is called; any exception it
    raises is logged and not propagated.

    ``clients_pool.timestamp`` is indexed exactly once per run. It is a TTL
    index normally and a plain index in demo mode. MongoDB does not allow
    two indexes with the same key that differ only in options, so the plain
    index must not also be requested outside demo mode.
    """
    prefix = settings.conf.mongodb_collection_prefix or ''
    mongo.prefix = prefix

    cur_collections = mongo.database.collection_names()
    if prefix + 'logs' not in cur_collections:
        log_limit = settings.app.log_limit
        mongo.database.create_collection(prefix + 'logs',
                                         capped=True,
                                         size=log_limit * 1024,
                                         max=log_limit)

    if prefix + 'log_entries' not in cur_collections:
        log_entry_limit = settings.app.log_entry_limit
        mongo.database.create_collection(prefix + 'log_entries',
                                         capped=True,
                                         size=log_entry_limit * 512,
                                         max=log_entry_limit)

    cur_collections = mongo.secondary_database.collection_names()
    if prefix + 'messages' not in cur_collections:
        mongo.secondary_database.create_collection(prefix + 'messages',
                                                   capped=True,
                                                   size=5000192,
                                                   max=1000)

    upsert_index('logs', 'timestamp', background=True)
    upsert_index('transaction', 'lock_id', background=True, unique=True)
    upsert_index('transaction', [
        ('ttl_timestamp', pymongo.ASCENDING),
        ('state', pymongo.ASCENDING),
        ('priority', pymongo.DESCENDING),
    ], background=True)
    upsert_index('queue', 'runner_id', background=True)
    upsert_index('queue', 'ttl_timestamp', background=True)
    upsert_index('queue', [
        ('priority', pymongo.ASCENDING),
        ('ttl_timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('tasks', [
        ('ttl_timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('tasks', [
        ('ttl_timestamp', pymongo.ASCENDING),
        ('state', pymongo.ASCENDING),
    ], background=True)
    upsert_index('log_entries', [
        ('timestamp', pymongo.DESCENDING),
    ], background=True)
    upsert_index('messages', 'channel', background=True)
    upsert_index('administrators', 'username', background=True, unique=True)
    upsert_index('users', 'resource_id', background=True)
    upsert_index('users', [
        ('type', pymongo.ASCENDING),
        ('org_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('users', [
        ('org_id', pymongo.ASCENDING),
        ('name', pymongo.ASCENDING),
    ], background=True)
    upsert_index('users', [
        ('name', pymongo.ASCENDING),
        ('auth_type', pymongo.ASCENDING),
    ], background=True)
    upsert_index('users_audit', [
        ('org_id', pymongo.ASCENDING),
        ('user_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('users_audit', [
        ('timestamp', pymongo.DESCENDING),
    ], background=True)
    upsert_index('users_key_link', [
        ('org_id', pymongo.ASCENDING),
        ('user_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('users_key_link', 'key_id', background=True)
    upsert_index('users_key_link', 'short_id', background=True, unique=True)
    upsert_index('users_net_link', 'user_id', background=True)
    upsert_index('users_net_link', 'org_id', background=True)
    upsert_index('users_net_link', 'network', background=True)
    upsert_index('clients', 'user_id', background=True)
    upsert_index('clients', 'domain', background=True)
    upsert_index('clients', 'virt_address_num', background=True)
    upsert_index('clients', [
        ('server_id', pymongo.ASCENDING),
        ('type', pymongo.ASCENDING),
    ], background=True)
    upsert_index('clients', [
        ('host_id', pymongo.ASCENDING),
        ('type', pymongo.ASCENDING),
    ], background=True)
    upsert_index('clients_pool', 'client_id', background=True)
    upsert_index('clients_pool', [
        ('server_id', pymongo.ASCENDING),
        ('user_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('organizations', 'type', background=True)
    upsert_index('organizations', 'auth_token', background=True)
    upsert_index('hosts', 'name', background=True)
    upsert_index('hosts_usage', [
        ('host_id', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('servers', 'name', background=True)
    upsert_index('servers', 'ping_timestamp', background=True)
    upsert_index('servers_output', [
        ('server_id', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('servers_output_link', [
        ('server_id', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('servers_bandwidth', [
        ('server_id', pymongo.ASCENDING),
        ('period', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('servers_ip_pool', [
        ('server_id', pymongo.ASCENDING),
        ('user_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('servers_ip_pool', [
        ('server_id', pymongo.ASCENDING),
        ('_id', pymongo.DESCENDING),
    ], background=True)
    upsert_index('servers_ip_pool', 'user_id', background=True)
    upsert_index('links_hosts', 'link_id', background=True)
    upsert_index('links_hosts', [
        ('location_id', pymongo.ASCENDING),
        ('status', pymongo.ASCENDING),
        ('active', pymongo.ASCENDING),
        ('priority', pymongo.DESCENDING),
    ], background=True)
    upsert_index('links_hosts', [
        ('location_id', pymongo.ASCENDING),
        ('static', pymongo.ASCENDING),
    ], background=True)
    upsert_index('links_hosts', [
        ('location_id', pymongo.ASCENDING),
        ('name', pymongo.ASCENDING),
    ], background=True)
    upsert_index('links_hosts', 'ping_timestamp_ttl', background=True)
    upsert_index('links_locations', 'link_id', background=True)
    upsert_index('routes_reserve', 'timestamp', background=True)
    upsert_index('dh_params', 'dh_param_bits', background=True)
    upsert_index('auth_nonces', [
        ('token', pymongo.ASCENDING),
        ('nonce', pymongo.ASCENDING),
    ], background=True, unique=True)
    upsert_index('otp_cache', [
        ('user_id', pymongo.ASCENDING),
        ('server_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('sso_push_cache', [
        ('user_id', pymongo.ASCENDING),
        ('server_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('sso_client_cache', [
        ('user_id', pymongo.ASCENDING),
        ('server_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('sso_passcode_cache', [
        ('user_id', pymongo.ASCENDING),
        ('server_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('vxlans', 'server_id', background=True, unique=True)

    # TTL (expiring) indexes.
    upsert_index('tasks', 'timestamp', background=True, expireAfterSeconds=300)
    if settings.app.demo_mode:
        # Demo mode: clients never expire, and clients_pool keeps only a
        # plain (non-expiring) timestamp index.
        drop_index(mongo.get_collection('clients'),
                   'timestamp',
                   background=True)
        upsert_index('clients_pool', 'timestamp', background=True)
    else:
        upsert_index('clients',
                     'timestamp',
                     background=True,
                     expireAfterSeconds=settings.vpn.client_ttl)
        # Only the TTL variant is requested here; also requesting a plain
        # index on the same key would conflict with it on every startup.
        upsert_index('clients_pool',
                     'timestamp',
                     background=True,
                     expireAfterSeconds=settings.vpn.client_ttl)
    upsert_index('users_key_link',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=settings.app.key_link_timeout)
    upsert_index('acme_challenges',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=180)
    upsert_index('auth_sessions',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=settings.app.session_timeout)
    upsert_index('auth_nonces',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=max(
                     settings.app.auth_time_window * 2,
                     settings.app.auth_expire_window,
                 ))
    upsert_index('auth_csrf_tokens',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=604800)
    upsert_index('auth_limiter',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=settings.app.auth_limiter_ttl)
    upsert_index('wg_keys',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=settings.app.wg_public_key_ttl)
    upsert_index('otp', 'timestamp', background=True, expireAfterSeconds=120)
    upsert_index('otp_cache',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=settings.app.sso_cache_timeout)
    upsert_index('yubikey',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=86400)
    upsert_index('sso_tokens',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=600)
    upsert_index('sso_push_cache',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=settings.app.sso_cache_timeout)
    upsert_index('sso_client_cache',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=settings.app.sso_client_cache_timeout +
                 settings.app.sso_client_cache_window)
    upsert_index('sso_passcode_cache',
                 'timestamp',
                 background=True,
                 expireAfterSeconds=settings.app.sso_cache_timeout)

    try:
        clean_indexes()
    except:
        # Best effort: a failure here must not abort startup.
        logger.exception('Failed to clean indexes', 'setup')
Exemple #35
0
def _connect_mongo_client(mongodb_uri, read_pref, max_pool, error_msg):
    """Return a ``pymongo.MongoClient`` for ``mongodb_uri`` once it responds.

    Since PyMongo 3 the ``MongoClient`` constructor connects in the
    background and does not raise ``ConnectionFailure`` for an unreachable
    server. A ``ping`` round trip is therefore issued so that connection
    problems are raised (and retried) here. The ping uses the default
    (primary) read preference, and the setup code writes to the database
    anyway. A client whose ping fails is closed before the next attempt.

    :param mongodb_uri: MongoDB connection string.
    :param read_pref: read preference passed to the client when truthy.
    :param max_pool: ``maxPoolSize`` value (``None`` for the driver default).
    :param error_msg: message logged (with traceback) while retrying.
    """
    client_kwargs = {
        'connectTimeoutMS': MONGO_CONNECT_TIMEOUT,
        'socketTimeoutMS': MONGO_SOCKET_TIMEOUT,
        'serverSelectionTimeoutMS': MONGO_SOCKET_TIMEOUT,
        'maxPoolSize': max_pool,
    }
    if read_pref:
        client_kwargs['read_preference'] = read_pref

    # First error is logged after ~6 seconds, then at most every 30 seconds.
    last_error = time.time() - 24

    while True:
        client = None
        try:
            client = pymongo.MongoClient(mongodb_uri, **client_kwargs)
            client.admin.command('ping')
            return client
        except pymongo.errors.ConnectionFailure:
            time.sleep(0.5)
            if time.time() - last_error > 30:
                last_error = time.time()
                logger.exception(error_msg)
            if client is not None:
                client.close()


def setup_mongo():
    """Connect to MongoDB and prepare the databases for use.

    Steps:

    - Connect to ``settings.conf.mongodb_uri``, retrying until it responds.
    - If the ``app`` settings document defines ``secondary_mongodb_uri``,
      connect to it as the secondary database. Otherwise reuse the primary.
    - Refuse to run against a Pritunl Zero database (``TypeError``).
    - Register the collection/database mapping on ``mongo``.
    - Ensure ``messages`` is a capped collection.
    - Sync time, load settings and upsert indexes.
    - Create a default administrator if none exists.
    - Generate missing cookie secrets.
    """
    prefix = settings.conf.mongodb_collection_prefix or ''
    read_pref = _get_read_pref(settings.conf.mongodb_read_preference)
    max_pool = settings.conf.mongodb_max_pool_size or None

    client = _connect_mongo_client(
        settings.conf.mongodb_uri,
        read_pref,
        max_pool,
        'Error connecting to mongodb server',
    )
    database = client.get_default_database()

    settings_col = getattr(database, prefix + 'settings')
    app_settings = settings_col.find_one({'_id': 'app'})
    if app_settings:
        secondary_mongodb_uri = app_settings.get('secondary_mongodb_uri')
    else:
        secondary_mongodb_uri = None

    if secondary_mongodb_uri:
        secondary_client = _connect_mongo_client(
            secondary_mongodb_uri,
            read_pref,
            max_pool,
            'Error connecting to secondary mongodb server',
        )
        secondary_database = secondary_client.get_default_database()
    else:
        secondary_database = database

    mongo.database = database
    mongo.secondary_database = secondary_database

    cur_collections = database.collection_names()
    cur_sec_collections = secondary_database.collection_names()
    if 'authorities' in cur_collections or \
            'authorities' in cur_sec_collections:
        raise TypeError('Cannot connect to a Pritunl Zero database')

    # 1 = primary database, 2 = secondary database.
    mongo.collection_types = {
        'transaction': 1,
        'queue': 1,
        'tasks': 1,
        'settings': 1,
        'messages': 2,
        'administrators': 1,
        'users': 1,
        'users_audit': 1,
        'users_key_link': 2,
        'users_net_link': 1,
        'clients': 1,
        'clients_pool': 1,
        'organizations': 1,
        'hosts': 1,
        'hosts_usage': 1,
        'servers': 1,
        'servers_output': 1,
        'servers_output_link': 1,
        'servers_bandwidth': 1,
        'servers_ip_pool': 1,
        'links': 1,
        'links_locations': 1,
        'links_hosts': 1,
        'routes_reserve': 1,
        'dh_params': 1,
        'acme_challenges': 1,
        'auth_sessions': 2,
        'auth_csrf_tokens': 2,
        'auth_nonces': 2,
        'auth_limiter': 2,
        'wg_keys': 2,
        'otp': 2,
        'otp_cache': 2,
        'yubikey': 2,
        'sso_tokens': 2,
        'sso_push_cache': 2,
        'sso_client_cache': 2,
        'sso_passcode_cache': 2,
        'vxlans': 1,
        'logs': 1,
        'log_entries': 1,
    }

    # The messages collection must be capped; recreate it if it is missing
    # or exists uncapped.
    cur_collections = mongo.secondary_database.collection_names()
    create_messages = prefix + 'messages' not in cur_collections
    if not create_messages and \
            not mongo.get_collection('messages').options().get('capped'):
        mongo.get_collection('messages').drop()
        create_messages = True
    if create_messages:
        mongo.secondary_database.create_collection(prefix + 'messages',
                                                   capped=True,
                                                   size=5000192,
                                                   max=1000)

    settings.local.mongo_time = None

    while True:
        try:
            utils.sync_time()
            break
        except:
            logger.exception('Failed to sync time', 'setup')
            time.sleep(30)

    settings.init()

    upsert_indexes()

    if not auth.Administrator.collection.find_one():
        default_admin = auth.Administrator(username=DEFAULT_USERNAME)
        default_admin.generate_default_password()
        default_admin.commit()

    secret_key = settings.app.cookie_secret
    secret_key2 = settings.app.cookie_secret2
    settings_commit = False

    if not secret_key:
        settings_commit = True
        secret_key = utils.rand_str(64)
        settings.app.cookie_secret = secret_key

    if not secret_key2:
        settings_commit = True
        settings.app.cookie_secret2 = utils.rand_str(64)

    if settings_commit:
        settings.commit()

    app.app.secret_key = secret_key.encode()
Exemple #36
0
    def _run_thread(self, send_events):
        """Run this server instance from setup through teardown.

        Sets up networking (ip forwarding, bridge, ovpn config, iptables,
        route advertisements), starts openvpn and the helper threads, then
        blocks in ``openvpn_watch()``. Afterwards it tears down the bridge,
        iptables rules and acquired resources.

        The same teardown also runs when openvpn fails to start and when an
        exception occurs. In every case the ``finally`` clause stops the
        threads, removes this instance from the server document and deletes
        the temp directory.

        :param send_events: when true, emit server/host/user update events
            once the instance has started.
        """
        from pritunl.server.utils import get_by_id

        logger.debug(
            'Starting ovpn process',
            'server',
            server_id=self.server.id,
        )

        self.resources_acquire()
        try:
            cursor_id = self.get_cursor_id()

            os.makedirs(self._temp_path)

            self.enable_ip_forwarding()
            self.bridge_start()
            self.generate_ovpn_conf()

            self.generate_iptables_rules()
            self.iptables.upsert_rules()

            self.init_route_advertisements()

            self.process = self.openvpn_start()
            if not self.process:
                # openvpn did not start: undo the setup above so the bridge,
                # iptables rules and acquired resources are not left behind.
                # This mirrors the normal-exit teardown further down.
                self.interrupt = True
                self.bridge_stop()
                self.iptables.clear_rules()
                self.resources_release()
                return

            self.start_threads(cursor_id)

            self.instance_com = ServerInstanceCom(self.server, self)
            self.instance_com.start()

            self.publish('started')

            if send_events:
                event.Event(type=SERVERS_UPDATED)
                event.Event(type=SERVER_HOSTS_UPDATED,
                            resource_id=self.server.id)
                for org_id in self.server.organizations:
                    event.Event(type=USERS_UPDATED, resource_id=org_id)

            for link_doc in self.server.links:
                # Only start the link when this server's id is greater;
                # presumably so that just one side of each link initiates it.
                if self.server.id > link_doc['server_id']:
                    instance_link = ServerInstanceLink(
                        server=self.server,
                        linked_server=get_by_id(link_doc['server_id']),
                    )
                    self.server_links.append(instance_link)
                    instance_link.start()

            # Blocks until the openvpn process ends.
            self.openvpn_watch()

            self.interrupt = True
            self.bridge_stop()
            self.iptables.clear_rules()
            self.resources_release()

            if not self.clean_exit:
                event.Event(type=SERVERS_UPDATED)
                self.server.send_link_events()
                logger.LogEntry(message='Server stopped unexpectedly "%s".' %
                                (self.server.name))
        except:
            self.interrupt = True
            self.stop_process()
            if self.resource_lock:
                self.iptables.clear_rules()
                self.bridge_stop()
            self.resources_release()

            logger.exception(
                'Server error occurred while running',
                'server',
                server_id=self.server.id,
            )
        finally:
            self.stop_threads()
            self.collection.update(
                {
                    '_id': self.server.id,
                    'instances.instance_id': self.id,
                }, {
                    '$pull': {
                        'instances': {
                            'instance_id': self.id,
                        },
                    },
                    '$inc': {
                        'instances_count': -1,
                    },
                })
            utils.rmtree(self._temp_path)
Exemple #37
0
    def generate_iptables_rules(self):
        """Configure ``self.iptables`` for this server and generate the rules.

        Copies the server's addressing (gateways, virtual networks, IPv6
        firewall flag, inter-client setting) onto ``self.iptables``. It then
        reads the host routing table (``route -n``, and ``route -n -A inet6``
        when IPv6 is enabled) to choose the egress interface for each server
        route. Routes with no matching entry fall back to the default-route
        interface. Finally it calls ``self.iptables.generate()``.

        :raises subprocess.CalledProcessError: if reading a routing table
            fails (logged first).
        :raises IptablesError: if no default (IPv4, or IPv6 when enabled)
            interface is found.
        """
        server_addr = utils.get_network_gateway(self.server.network)
        server_addr6 = utils.get_network_gateway(self.server.network6)
        ipv6_firewall = self.server.ipv6_firewall and \
            settings.local.host.routed_subnet6

        self.iptables.id = self.server.id
        self.iptables.server_addr = server_addr
        self.iptables.server_addr6 = server_addr6
        self.iptables.virt_interface = self.interface
        self.iptables.virt_network = self.server.network
        self.iptables.virt_network6 = self.server.network6
        self.iptables.ipv6_firewall = ipv6_firewall
        self.iptables.inter_client = self.server.inter_client

        try:
            routes_output = utils.check_output_logged(['route', '-n'])
        except subprocess.CalledProcessError:
            logger.exception(
                'Failed to get IP routes',
                'server',
                server_id=self.server.id,
            )
            raise

        # route -n columns: Destination Gateway Genmask Flags Metric Ref
        # Use Iface. Keep only the first listed entry per destination, so
        # the first 0.0.0.0 entry supplies the default interface.
        routes = []
        seen_destinations = set()
        default_interface = None
        for line in routes_output.splitlines():
            line_split = line.split()
            if len(line_split) < 8 or not re.match(IP_REGEX, line_split[0]):
                continue

            destination = line_split[0]
            if destination in seen_destinations:
                continue
            seen_destinations.add(destination)

            if destination == '0.0.0.0':
                default_interface = line_split[7]

            routes.append((
                ipaddress.IPNetwork('%s/%s' % (
                    destination, utils.subnet_to_cidr(line_split[2]))),
                line_split[7],
            ))
        # Later route -n entries are tried first when matching below.
        routes.reverse()

        if not default_interface:
            raise IptablesError('Failed to find default network interface')

        routes6 = []
        default_interface6 = None
        if self.server.ipv6:
            try:
                routes_output = utils.check_output_logged(
                    ['route', '-n', '-A', 'inet6'])
            except subprocess.CalledProcessError:
                logger.exception(
                    'Failed to get IPv6 routes',
                    'server',
                    server_id=self.server.id,
                )
                raise

            for line in routes_output.splitlines():
                line_split = line.split()

                if len(line_split) < 7:
                    continue

                try:
                    route_network = ipaddress.IPv6Network(line_split[0])
                except (ipaddress.AddressValueError, ValueError):
                    continue

                if not default_interface6 and line_split[0] == '::/0':
                    default_interface6 = line_split[6]

                routes6.append((route_network, line_split[6]))

            if not default_interface6:
                raise IptablesError(
                    'Failed to find default IPv6 network interface')

            if default_interface6 == 'lo':
                logger.error(
                    'Failed to find default IPv6 interface',
                    'server',
                    server_id=self.server.id,
                )
        routes6.reverse()

        interfaces = set()
        interfaces6 = set()

        for route in self.server.get_routes(
                include_hidden=True,
                include_server_links=True,
                include_default=True,
        ):
            if route['virtual_network'] or route['link_virtual_network']:
                self.iptables.add_nat_network(route['network'])

            if route['virtual_network']:
                continue

            network = route['network']
            is6 = ':' in network
            network_obj = ipaddress.IPNetwork(network)

            interface = None
            if is6:
                for route_net, route_intf in routes6:
                    if network_obj in route_net:
                        interface = route_intf
                        break

                if not interface:
                    logger.info(
                        'Failed to find interface for local '
                        'IPv6 network route, using default route',
                        'server',
                        server_id=self.server.id,
                        network=network,
                    )
                    interface = default_interface6
                interfaces6.add(interface)
            else:
                for route_net, route_intf in routes:
                    if network_obj in route_net:
                        interface = route_intf
                        break

                if not interface:
                    logger.info(
                        'Failed to find interface for local '
                        'network route, using default route',
                        'server',
                        server_id=self.server.id,
                        network=network,
                    )
                    interface = default_interface
                interfaces.add(interface)

            self.iptables.add_route(
                network,
                nat=route['nat'],
                nat_interface=interface,
            )

        self.iptables.generate()
Exemple #38
0
    def task(self):
        """Generator task that reconciles server instance state.

        Work is split by bare ``yield`` points:

        1. Every instance whose ``ping_timestamp`` is older than
           ``settings.vpn.server_ping_ttl`` seconds is pulled from its server
           document, and ``instances_count`` is decremented.
        2. Every ONLINE host is mapped to its availability group (``DEFAULT``
           when unset).
        3. ONLINE servers started before the cutoff that have fewer instances
           than ``replica_count`` are examined.
           - The availability group holding most of the server's hosts is
             chosen; ties favour the current group.
           - If it differs from the current group, the server is rebalanced
             to it.
           - At most three servers per run are asked to start on preferred
             hosts that do not already run an instance.

        Does nothing in demo mode. ``GeneratorExit`` propagates; any other
        error is logged.
        """
        if settings.app.demo_mode:
            return

        try:
            now = utils.now()
            ping_cutoff = now - datetime.timedelta(
                seconds=settings.vpn.server_ping_ttl)

            stale_servers = self.server_collection.find({
                'instances.ping_timestamp': {'$lt': ping_cutoff},
            }, {
                '_id': True,
                'instances': True,
            })

            yield

            for server_doc in stale_servers:
                for inst in server_doc['instances']:
                    inst_ping = inst['ping_timestamp']
                    if not (inst_ping < ping_cutoff):
                        continue

                    logger.warning('Removing instance doc', 'server',
                        server_id=server_doc['_id'],
                        instance_id=inst['instance_id'],
                        cur_timestamp=now,
                        ttl_timestamp=ping_cutoff,
                        ping_timestamp=inst_ping,
                    )

                    self.server_collection.update({
                        '_id': server_doc['_id'],
                        'instances.instance_id': inst['instance_id'],
                    }, {
                        '$pull': {
                            'instances': {'instance_id': inst['instance_id']},
                        },
                        '$inc': {'instances_count': -1},
                    })

            yield

            online_hosts = self.host_collection.find({
                'status': ONLINE,
            }, {
                '_id': True,
                'availability_group': True,
            })

            yield

            host_zones = {
                host_doc['_id']: host_doc.get('availability_group', DEFAULT)
                for host_doc in online_hosts
            }

            yield

            match_online = {'$match': {
                'status': ONLINE,
                'start_timestamp': {'$lt': ping_cutoff},
            }}
            project_missing = {'$project': {
                '_id': True,
                'hosts': True,
                'instances': True,
                'replica_count': True,
                'availability_group': True,
                'offline_instances_count': {
                    '$subtract': [
                        '$replica_count',
                        '$instances_count',
                    ],
                },
            }}
            match_missing = {'$match': {
                'offline_instances_count': {'$gt': 0},
            }}
            degraded_servers = self.server_collection.aggregate(
                [match_online, project_missing, match_missing])

            yield

            started = 0

            for server_doc in degraded_servers:
                current_zone = server_doc.get('availability_group', DEFAULT)

                zone_hosts = collections.defaultdict(set)
                best_zone = None
                best_size = 0

                for host_id in set(server_doc['hosts']):
                    zone = host_zones.get(host_id)
                    if not zone:
                        continue

                    members = zone_hosts[zone]
                    members.add(host_id)
                    size = len(members)

                    if size > best_size:
                        best_size = size
                        best_zone = zone
                    elif size == best_size and zone == current_zone:
                        best_zone = zone

                if best_zone and best_zone != current_zone:
                    logger.info(
                        'Rebalancing server availability group',
                        'server',
                        server_id=server_doc['_id'],
                        current_availability_group=current_zone,
                        new_availability_group=best_zone,
                    )

                    self.server_collection.update({
                        '_id': server_doc['_id'],
                        'status': ONLINE,
                    }, {'$set': {
                        'instances': [],
                        'instances_count': 0,
                        'availability_group': best_zone,
                    }})

                    messenger.publish('servers', 'rebalance', extra={
                        'server_id': server_doc['_id'],
                        'availability_group': best_zone,
                    })

                    target_zone = best_zone
                else:
                    target_zone = current_zone

                running = {inst['host_id'] for inst in server_doc['instances']}
                candidates = list(zone_hosts[target_zone] - running)
                if not candidates or started >= 3:
                    continue
                started += 1

                logger.info('Recovering server state', 'server',
                    server_id=server_doc['_id'],
                    prefered_hosts=candidates,
                )

                messenger.publish('servers', 'start', extra={
                    'server_id': server_doc['_id'],
                    'send_events': True,
                    'prefered_hosts': host.get_prefered_hosts(
                        candidates, server_doc['replica_count'])
                })
        except GeneratorExit:
            raise
        except:
            logger.exception('Error checking server states', 'tasks')
Exemple #39
0
    def _keep_alive_thread(self):
        """Generator that keeps this instance's ``ping_timestamp`` fresh.

        Each pass updates ``instances.$.ping_timestamp`` on the server
        document (matched on server id, this host's availability group and
        this instance id), then sleeps for ``settings.vpn.server_ping``.

        - Instance entry not found: ``stop_process()`` is called. The loop
          ends if that returns truthy; otherwise it retries after 0.1 s.
        - Update errors: once ``error_count`` reaches 2 without a successful
          update in between, ``stop_process()`` is tried and, if it returns
          truthy, the loop ends.
        - Generator closed (``GeneratorExit``): the process is stopped once,
          in the outer handler.
        """
        try:
            error_count = 0

            while not self.interrupt:
                try:
                    doc = self.collection.find_and_modify({
                        '_id': self.server.id,
                        'availability_group': \
                            settings.local.host.availability_group,
                        'instances.instance_id': self.id,
                    }, {'$set': {
                        'instances.$.ping_timestamp': utils.now(),
                    }}, fields={
                        '_id': False,
                        'instances': True,
                    }, new=True)

                    yield

                    if not doc:
                        doc = self.collection.find_one({
                            '_id': self.server.id,
                        })

                        # Only log when this host is still listed on the
                        # server; the process is stopped either way.
                        doc_hosts = ((doc or {}).get('hosts') or [])
                        if settings.local.host_id in doc_hosts:
                            logger.error(
                                'Instance doc lost, stopping server',
                                'server',
                                server_id=self.server.id,
                                instance_id=self.id,
                                cur_timestamp=utils.now(),
                            )

                        if self.stop_process():
                            break
                        else:
                            time.sleep(0.1)
                            continue
                    else:
                        error_count = 0

                    yield
                except GeneratorExit:
                    # Must propagate: if swallowed, execution would reach
                    # the next ``yield`` and close() would raise
                    # RuntimeError. The outer handler stops the process.
                    raise
                except:
                    error_count += 1
                    if error_count >= 2 and self.stop_process():
                        logger.exception(
                            'Failed to update server ping, stopping server',
                            'server',
                            server_id=self.server.id,
                        )
                        break

                    logger.exception(
                        'Failed to update server ping',
                        'server',
                        server_id=self.server.id,
                    )
                    time.sleep(2)

                yield interrupter_sleep(settings.vpn.server_ping)
        except GeneratorExit:
            self.stop_process()
Exemple #40
0
        if self.state == state:
            if not password and self.user.has_pin():
                self.state = None
            else:
                return password

    def _check_call(self, func):
        try:
            func()
        except AuthError, err:
            self._callback(False, str(err))
            raise
        except AuthForked:
            raise
        except:
            logger.exception('Exception in user authorize', 'authorize')
            self._callback(False, 'Unknown error occurred')
            raise

    def _callback(self, allow, reason=None):
        if allow:
            try:
                self._check_call(self._update_token)
            except:
                return

        self.callback(allow, reason)

    def _check_token(self):
        if settings.app.sso_client_cache and self.auth_token:
            doc = self.sso_client_cache_collection.find_one({
Exemple #41
0
def setup_clean():
    """Best-effort removal of state left behind by a previous run.

    Steps:
      1. Send the default signal via ``killall`` to openvpn, openssl,
         pritunl-dns and pritunl-web, wait 2 seconds, then send SIGKILL.
      2. Bring down and delete every network interface named ``pxlan*``.
      3. Delete every iptables rule whose comment starts with ``pritunl_``.

    A failing individual command (non-zero exit) is ignored. Any other
    error aborts the remaining steps and is logged, never raised.
    """
    process_names = ('openvpn', 'openssl', 'pritunl-dns', 'pritunl-web')

    try:
        for name in process_names:
            _setup_clean_call(['killall', name])

        # Give the processes a chance to exit gracefully before forcing it.
        time.sleep(2)

        for name in process_names:
            _setup_clean_call(['killall', '-s9', name])

        _setup_clean_interfaces()
        _setup_clean_iptables()
    except:
        logger.exception('Server clean failed', 'setup')


def _setup_clean_call(args, **kwargs):
    """Run ``utils.check_call_silent``, ignoring a non-zero exit status."""
    try:
        utils.check_call_silent(args, **kwargs)
    except subprocess.CalledProcessError:
        pass


def _setup_clean_interfaces():
    """Bring down and delete every network interface named ``pxlan*``."""
    output = subprocess.check_output([
        'ip',
        '-o',
        'link',
        'show',
    ])

    for line in output.splitlines():
        fields = line.split(':')
        if len(fields) < 2:
            continue
        # `ip -o link show` prints "<index>: <name>[@<link>]: ...". Drop an
        # optional "@<link>" suffix so `ip link` receives the bare name.
        iface_name = fields[1].strip().split('@', 1)[0]

        if not iface_name.startswith('pxlan'):
            continue

        _setup_clean_call(['ip', 'link', 'set', 'down', iface_name])
        _setup_clean_call(['ip', 'link', 'del', iface_name])


def _setup_clean_iptables():
    """Delete every iptables rule tagged with a ``--comment pritunl_...``."""
    output = subprocess.check_output([
        'iptables-save',
    ])

    table = None
    for line in output.splitlines():
        line = line.strip()

        # Track every "*<table>" header (not only nat/filter). Otherwise a
        # rule under e.g. *mangle would be deleted against the previously
        # seen table, or against "-t None".
        if line.startswith('*'):
            table = line[1:]
            continue

        # Only "-A <chain> <spec>" lines can be turned into a "-D" command.
        if table is None or not line.startswith('-A '):
            continue

        if '--comment pritunl_' not in line:
            continue

        # shell=True lets the shell honour the quoting iptables-save emits.
        _setup_clean_call([
            'iptables -t %s -D %s' % (table, line[3:]),
        ], shell=True)
Exemple #42
0
    def _run_thread(self, send_events):
        """Start this VPN server instance and block until openvpn exits.

        Startup stages run in order, each recording its name in
        ``self.state``. The method returns early whenever
        ``self.is_interrupted()`` is true. A watchdog timer calls
        ``self.stop_process()`` if startup has not finished within
        ``settings.vpn.op_timeout + 3`` seconds.

        Once running, the ``server_start`` plugin hook fires, the openvpn
        process is polled until it exits (stopping it on interrupt), and the
        ``server_stop`` hook fires.

        Errors are logged and the process is stopped. In all cases the
        bridge, iptables rules, vxlan, the instance entry in the server
        document, the temp path and acquired resources are cleaned up.

        :param send_events: when true, publish server/host/user update
            events after the instance has started.
        """
        from pritunl.server.utils import get_by_id

        logger.info(
            'Starting vpn server',
            'server',
            server_id=self.server.id,
            instance_id=self.id,
            network=self.server.network,
            network6=self.server.network6,
            host_address=settings.local.host.local_addr,
            host_address6=settings.local.host.local_addr6,
            host_networks=settings.local.host.local_networks,
            cur_timestamp=utils.now(),
        )

        def timeout():
            logger.error(
                'Server startup timed out, stopping...',
                'server',
                server_id=self.server.id,
                instance_id=self.id,
                state=self.state,
            )
            self.stop_process()

        def plugin_kwargs():
            # Arguments shared by the server_start and server_stop hooks.
            # Built on each call so every hook sees current attribute values.
            return {
                'host_id': settings.local.host_id,
                'host_name': settings.local.host.name,
                'server_id': self.server.id,
                'server_name': self.server.name,
                'port': self.server.port,
                'protocol': self.server.protocol,
                'ipv6': self.server.ipv6,
                'ipv6_firewall': self.server.ipv6_firewall,
                'network': self.server.network,
                'network6': self.server.network6,
                'network_mode': self.server.network_mode,
                'network_start': self.server.network_start,
                'network_stop': self.server.network_end,
                'restrict_routes': self.server.restrict_routes,
                'bind_address': self.server.bind_address,
                'onc_hostname': self.server.onc_hostname,
                'dh_param_bits': self.server.dh_param_bits,
                'multi_device': self.server.multi_device,
                'dns_servers': self.server.dns_servers,
                'search_domain': self.server.search_domain,
                'otp_auth': self.server.otp_auth,
                'cipher': self.server.cipher,
                'hash': self.server.hash,
                'inter_client': self.server.inter_client,
                'ping_interval': self.server.ping_interval,
                'ping_timeout': self.server.ping_timeout,
                'link_ping_interval': self.server.link_ping_interval,
                'link_ping_timeout': self.server.link_ping_timeout,
                'allowed_devices': self.server.allowed_devices,
                'max_clients': self.server.max_clients,
                'replica_count': self.server.replica_count,
                'dns_mapping': self.server.dns_mapping,
                'debug': self.server.debug,
                'interface': self.interface,
                'bridge_interface': self.bridge_interface,
                'vxlan': self.vxlan,
            }

        self.state = 'init'
        timer = threading.Timer(settings.vpn.op_timeout + 3, timeout)
        timer.start()

        try:
            self.resources_acquire()

            cursor_id = self.get_cursor_id()

            if self.is_interrupted():
                return

            self.state = 'temp_path'
            os.makedirs(self._temp_path)

            if self.is_interrupted():
                return

            self.state = 'ip_forwarding'
            self.enable_ip_forwarding()

            if self.is_interrupted():
                return

            self.state = 'bridge_start'
            self.bridge_start()

            if self.is_interrupted():
                return

            if self.server.replicating and self.server.vxlan:
                # A vxlan failure is logged but does not abort startup.
                try:
                    self.state = 'get_vxlan'
                    self.vxlan = vxlan.get_vxlan(self.server.id,
                                                 self.server.ipv6)

                    if self.is_interrupted():
                        return

                    self.state = 'start_vxlan'
                    self.vxlan.start()

                    if self.is_interrupted():
                        return
                except:
                    logger.exception(
                        'Failed to setup server vxlan',
                        'vxlan',
                        server_id=self.server.id,
                        instance_id=self.id,
                    )

            self.state = 'generate_ovpn_conf'
            self.generate_ovpn_conf()

            if self.is_interrupted():
                return

            self.state = 'generate_iptables_rules'
            self.generate_iptables_rules()

            if self.is_interrupted():
                return

            self.state = 'upsert_iptables_rules'
            self.iptables.upsert_rules()

            if self.is_interrupted():
                return

            self.state = 'init_route_advertisements'
            self.init_route_advertisements()

            if self.is_interrupted():
                return

            self.state = 'openvpn_start'
            self.process = self.openvpn_start()
            self.start_threads(cursor_id)

            if self.is_interrupted():
                return

            self.state = 'instance_com_start'
            self.instance_com = ServerInstanceCom(self.server, self)
            self.instance_com.start()

            if self.is_interrupted():
                return

            self.state = 'starting'
            self.publish('started')

            if self.is_interrupted():
                return

            if send_events:
                self.state = 'events'
                event.Event(type=SERVERS_UPDATED)
                event.Event(type=SERVER_HOSTS_UPDATED,
                            resource_id=self.server.id)
                for org_id in self.server.organizations:
                    event.Event(type=USERS_UPDATED, resource_id=org_id)

                    if self.is_interrupted():
                        return

            # Only the server with the larger id initiates each link.
            for link_doc in self.server.links:
                if self.server.id > link_doc['server_id']:
                    self.state = 'instance_link'
                    instance_link = ServerInstanceLink(
                        server=self.server,
                        linked_server=get_by_id(link_doc['server_id']),
                    )
                    self.server_links.append(instance_link)
                    instance_link.start()

                    if self.is_interrupted():
                        return

            self.state = 'running'
            self.openvpn_output()

            if self.is_interrupted():
                return

            # Startup finished; disarm the startup watchdog.
            timer.cancel()

            plugins.caller('server_start', **plugin_kwargs())
            try:
                while True:
                    if self.process.poll() is not None:
                        break
                    if self.is_interrupted():
                        self.stop_process()
                    time.sleep(0.05)
            finally:
                plugins.caller('server_stop', **plugin_kwargs())

            if not self.clean_exit:
                event.Event(type=SERVERS_UPDATED)
                self.server.send_link_events()
                logger.LogEntry(message='Server stopped unexpectedly "%s".' %
                                (self.server.name))
        except:
            # Log the original failure BEFORE calling stop_process(). On
            # Python 2, an exception handled by the nested try/except below
            # replaces sys.exc_info() for this frame, so logging afterwards
            # could report the stop error's traceback instead of this one.
            logger.exception(
                'Server error occurred while running',
                'server',
                server_id=self.server.id,
                instance_id=self.id,
            )

            try:
                self.stop_process()
            except:
                logger.exception(
                    'Server stop error',
                    'server',
                    server_id=self.server.id,
                    instance_id=self.id,
                )
        finally:
            timer.cancel()

            self.interrupt = True
            self.sock_interrupt = True

            try:
                self.bridge_stop()
            except:
                logger.exception(
                    'Failed to remove server bridge',
                    'server',
                    server_id=self.server.id,
                    instance_id=self.id,
                )

            try:
                self.iptables.clear_rules()
            except:
                logger.exception(
                    'Server iptables clean up error',
                    'server',
                    server_id=self.server.id,
                    instance_id=self.id,
                )

            if self.vxlan:
                try:
                    self.vxlan.stop()
                except:
                    logger.exception(
                        'Failed to stop server vxlan',
                        'server',
                        server_id=self.server.id,
                        instance_id=self.id,
                    )

            try:
                self.collection.update(
                    {
                        '_id': self.server.id,
                        'instances.instance_id': self.id,
                    }, {
                        '$pull': {
                            'instances': {
                                'instance_id': self.id,
                            },
                        },
                        '$inc': {
                            'instances_count': -1,
                        },
                    })
                utils.rmtree(self._temp_path)
            except:
                logger.exception(
                    'Server clean up error',
                    'server',
                    server_id=self.server.id,
                    instance_id=self.id,
                )

            try:
                self.resources_release()
            except:
                logger.exception(
                    'Failed to release resources',
                    'server',
                    server_id=self.server.id,
                    instance_id=self.id,
                )
Exemple #43
0
    def add_route(self, virt_address, virt_address6, host_address,
                  host_address6):
        """Route a client's virtual address via the host serving it.

        Any route previously recorded in ``self.client_routes`` for the
        address is removed first. No new route is added when
        ``host_address`` is empty or equals this host's local address.

        Otherwise ``ip route add <virt_address> via <host_address>`` is tried
        up to three times, deleting a conflicting route after the first
        failure. On success the address is recorded in
        ``self.client_routes``, so a later call can replace the route.

        All failures are logged, never raised. ``virt_address6`` and
        ``host_address6`` are currently only included in the error log.
        """
        # Strip any "/prefix" suffix; routes are added for the bare address.
        virt_address = virt_address.split('/')[0]

        _route_lock.acquire()
        try:
            if virt_address in self.client_routes:
                # client_routes is a set (its remove() raises KeyError).
                self.client_routes.discard(virt_address)
                try:
                    utils.check_call_silent([
                        'ip',
                        'route',
                        'del',
                        virt_address,
                    ])
                except subprocess.CalledProcessError:
                    pass

            if not host_address or host_address == \
                    settings.local.host.local_addr:
                return

            for i in xrange(3):
                try:
                    utils.check_output_logged([
                        'ip',
                        'route',
                        'add',
                        virt_address,
                        'via',
                        host_address,
                    ])
                    # Track the route so the next call for this address
                    # deletes it before adding a new one.
                    self.client_routes.add(virt_address)
                    break
                except subprocess.CalledProcessError:
                    if i == 0:
                        # A stale route may already exist; remove it and retry.
                        try:
                            utils.check_call_silent([
                                'ip',
                                'route',
                                'del',
                                virt_address,
                            ])
                        except subprocess.CalledProcessError:
                            pass
                    elif i == 2:
                        raise
                    time.sleep(0.2)
        except:
            logger.exception(
                'Failed to add route',
                'clients',
                virt_address=virt_address,
                virt_address6=virt_address6,
                host_address=host_address,
                host_address6=host_address6,
            )
        finally:
            _route_lock.release()
Exemple #44
0
def setup_mongo():
    """Connect to MongoDB and prepare every collection and index used.

    Steps:
      * Block until a client for ``settings.conf.mongodb_uri`` is created.
        Also connect to a secondary database when the ``app`` settings
        document sets ``secondary_mongodb_uri``; otherwise the primary is
        reused.
      * Register all collections in ``mongo.collections`` and create the
        capped ``messages``, ``logs`` and ``log_entries`` collections if
        missing.
      * Sync time and load settings.
      * Upsert indexes, including TTL indexes.
      * Create the default administrator if none exists.
      * Ensure both cookie secrets are set, then set ``app.app.secret_key``.

    Raises:
        TypeError: if either database contains an ``authorities``
            collection (i.e. it is a Pritunl Zero database).
    """
    prefix = settings.conf.mongodb_collection_prefix or ''
    # Starting 24s in the past makes the first connection error log after
    # ~6s of failed retries; subsequent logs are at most 30s apart.
    last_error = time.time() - 24

    def connect(uri, error_msg, error_time):
        """Return ``(client, error_time)`` once a MongoClient is created.

        Retries every 0.5s on ``ConnectionFailure``, logging ``error_msg``
        at most once per 30 seconds (tracked through ``error_time``).
        """
        while True:
            try:
                client_kwargs = {
                    'connectTimeoutMS': MONGO_CONNECT_TIMEOUT,
                    'socketTimeoutMS': MONGO_SOCKET_TIMEOUT,
                    'serverSelectionTimeoutMS': MONGO_SOCKET_TIMEOUT,
                }
                read_pref = _get_read_pref(
                    settings.conf.mongodb_read_preference)
                if read_pref:
                    client_kwargs['read_preference'] = read_pref

                return pymongo.MongoClient(uri, **client_kwargs), error_time
            except pymongo.errors.ConnectionFailure:
                time.sleep(0.5)
                if time.time() - error_time > 30:
                    error_time = time.time()
                    logger.exception(error_msg)

    client, last_error = connect(
        settings.conf.mongodb_uri,
        'Error connecting to mongodb server',
        last_error,
    )
    database = client.get_default_database()

    settings_col = getattr(database, prefix + 'settings')
    app_settings = settings_col.find_one({'_id': 'app'})
    if app_settings:
        secondary_mongodb_uri = app_settings.get('secondary_mongodb_uri')
    else:
        secondary_mongodb_uri = None

    if secondary_mongodb_uri:
        secondary_client, _ = connect(
            secondary_mongodb_uri,
            'Error connecting to secondary mongodb server',
            last_error,
        )
        secondary_database = secondary_client.get_default_database()
    else:
        secondary_database = database

    mongo.database = database
    mongo.secondary_database = secondary_database

    db_collections = database.collection_names()
    cur_collections = secondary_database.collection_names()

    if 'authorities' in db_collections or 'authorities' in cur_collections:
        raise TypeError('Cannot connect to a Pritunl Zero database')

    if prefix + 'messages' not in cur_collections:
        secondary_database.create_collection(prefix + 'messages', capped=True,
            size=5000192, max=1000)

    mongo.collections.update({
        'transaction': getattr(database, prefix + 'transaction'),
        'queue': getattr(database, prefix + 'queue'),
        'tasks': getattr(database, prefix + 'tasks'),
        'settings': getattr(database, prefix + 'settings'),
        'messages': getattr(secondary_database, prefix + 'messages'),
        'administrators': getattr(database, prefix + 'administrators'),
        'users': getattr(database, prefix + 'users'),
        'users_audit': getattr(database, prefix + 'users_audit'),
        'users_key_link': getattr(secondary_database,
            prefix + 'users_key_link'),
        'users_net_link': getattr(database, prefix + 'users_net_link'),
        'clients': getattr(database, prefix + 'clients'),
        'clients_pool': getattr(database, prefix + 'clients_pool'),
        'organizations': getattr(database, prefix + 'organizations'),
        'hosts': getattr(database, prefix + 'hosts'),
        'hosts_usage': getattr(database, prefix + 'hosts_usage'),
        'servers': getattr(database, prefix + 'servers'),
        'servers_output': getattr(database, prefix + 'servers_output'),
        'servers_output_link': getattr(database,
            prefix + 'servers_output_link'),
        'servers_bandwidth': getattr(database, prefix + 'servers_bandwidth'),
        'servers_ip_pool': getattr(database, prefix + 'servers_ip_pool'),
        'links': getattr(database, prefix + 'links'),
        'links_locations': getattr(database, prefix + 'links_locations'),
        'links_hosts': getattr(database, prefix + 'links_hosts'),
        'routes_reserve': getattr(database, prefix + 'routes_reserve'),
        'dh_params': getattr(database, prefix + 'dh_params'),
        'acme_challenges': getattr(database, prefix + 'acme_challenges'),
        'auth_sessions': getattr(secondary_database,
            prefix + 'auth_sessions'),
        'auth_csrf_tokens': getattr(secondary_database,
            prefix + 'auth_csrf_tokens'),
        'auth_nonces': getattr(secondary_database, prefix + 'auth_nonces'),
        'auth_limiter': getattr(secondary_database, prefix + 'auth_limiter'),
        'otp': getattr(secondary_database, prefix + 'otp'),
        'otp_cache': getattr(secondary_database, prefix + 'otp_cache'),
        'yubikey': getattr(secondary_database, prefix + 'yubikey'),
        'sso_tokens': getattr(secondary_database, prefix + 'sso_tokens'),
        'sso_push_cache': getattr(secondary_database,
            prefix + 'sso_push_cache'),
        'sso_client_cache': getattr(secondary_database,
            prefix + 'sso_client_cache'),
        'sso_passcode_cache': getattr(secondary_database,
            prefix + 'sso_passcode_cache'),
        'vxlans': getattr(database, prefix + 'vxlans'),
    })

    for collection_name, collection in mongo.collections.items():
        collection.name_str = collection_name

    settings.local.mongo_time = None

    # Retry every 30s until the time sync succeeds.
    while True:
        try:
            utils.sync_time()
            break
        except:
            logger.exception('Failed to sync time', 'setup')
            time.sleep(30)

    settings.init()

    # The log collections are sized from settings, so they can only be
    # created after settings.init().
    cur_collections = database.collection_names()
    if prefix + 'logs' not in cur_collections:
        log_limit = settings.app.log_limit
        database.create_collection(prefix + 'logs', capped=True,
            size=log_limit * 1024, max=log_limit)

    if prefix + 'log_entries' not in cur_collections:
        log_entry_limit = settings.app.log_entry_limit
        database.create_collection(prefix + 'log_entries', capped=True,
            size=log_entry_limit * 512, max=log_entry_limit)

    mongo.collections.update({
        'logs': getattr(database, prefix + 'logs'),
        'log_entries': getattr(database, prefix + 'log_entries'),
    })
    mongo.collections['logs'].name_str = 'logs'
    mongo.collections['log_entries'].name_str = 'log_entries'

    upsert_index('logs', 'timestamp', background=True)
    upsert_index('transaction', 'lock_id',
        background=True, unique=True)
    upsert_index('transaction', [
        ('ttl_timestamp', pymongo.ASCENDING),
        ('state', pymongo.ASCENDING),
        ('priority', pymongo.DESCENDING),
    ], background=True)
    upsert_index('queue', 'runner_id', background=True)
    upsert_index('queue', 'ttl_timestamp', background=True)
    upsert_index('queue', [
        ('priority', pymongo.ASCENDING),
        ('ttl_timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('tasks', [
        ('ttl_timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('tasks', [
        ('ttl_timestamp', pymongo.ASCENDING),
        ('state', pymongo.ASCENDING),
    ], background=True)
    upsert_index('log_entries', [
        ('timestamp', pymongo.DESCENDING),
    ], background=True)
    upsert_index('messages', 'channel', background=True)
    upsert_index('administrators', 'username',
        background=True, unique=True)
    upsert_index('users', 'resource_id', background=True)
    upsert_index('users', [
        ('type', pymongo.ASCENDING),
        ('org_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('users', [
        ('org_id', pymongo.ASCENDING),
        ('name', pymongo.ASCENDING),
    ], background=True)
    upsert_index('users', [
        ('name', pymongo.ASCENDING),
        ('auth_type', pymongo.ASCENDING),
    ], background=True)
    upsert_index('users_audit', [
        ('org_id', pymongo.ASCENDING),
        ('user_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('users_audit', [
        ('timestamp', pymongo.DESCENDING),
    ], background=True)
    upsert_index('users_key_link', 'key_id',
        background=True)
    upsert_index('users_key_link', 'short_id',
        background=True, unique=True)
    upsert_index('users_net_link', 'user_id',
        background=True)
    upsert_index('users_net_link', 'org_id',
        background=True)
    upsert_index('users_net_link', 'network',
        background=True)
    upsert_index('clients', 'user_id', background=True)
    upsert_index('clients', 'domain', background=True)
    upsert_index('clients', 'virt_address_num',
        background=True)
    upsert_index('clients', [
        ('server_id', pymongo.ASCENDING),
        ('type', pymongo.ASCENDING),
    ], background=True)
    upsert_index('clients', [
        ('host_id', pymongo.ASCENDING),
        ('type', pymongo.ASCENDING),
    ], background=True)
    upsert_index('clients_pool',
        'client_id', background=True)
    upsert_index('clients_pool',
        'timestamp', background=True)
    upsert_index('clients_pool', [
        ('server_id', pymongo.ASCENDING),
        ('user_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('organizations', 'type', background=True)
    upsert_index('organizations',
        'auth_token', background=True)
    upsert_index('hosts', 'name', background=True)
    upsert_index('hosts_usage', [
        ('host_id', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('servers', 'name', background=True)
    upsert_index('servers', 'ping_timestamp',
        background=True)
    upsert_index('servers_output', [
        ('server_id', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('servers_output_link', [
        ('server_id', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('servers_bandwidth', [
        ('server_id', pymongo.ASCENDING),
        ('period', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ], background=True)
    upsert_index('servers_ip_pool', [
        ('server_id', pymongo.ASCENDING),
        ('user_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('servers_ip_pool', [
        ('server_id', pymongo.ASCENDING),
        ('_id', pymongo.DESCENDING),
    ], background=True)
    upsert_index('servers_ip_pool', 'user_id',
        background=True)
    upsert_index('links_hosts', 'link_id',
        background=True)
    upsert_index('links_hosts', [
        ('location_id', pymongo.ASCENDING),
        ('status', pymongo.ASCENDING),
        ('active', pymongo.ASCENDING),
        ('priority', pymongo.DESCENDING),
    ], background=True)
    upsert_index('links_hosts', [
        ('location_id', pymongo.ASCENDING),
        ('static', pymongo.ASCENDING),
    ], background=True)
    upsert_index('links_hosts', [
        ('location_id', pymongo.ASCENDING),
        ('name', pymongo.ASCENDING),
    ], background=True)
    upsert_index('links_hosts', 'ping_timestamp_ttl',
        background=True)
    upsert_index('links_locations', 'link_id',
        background=True)
    upsert_index('routes_reserve', 'timestamp',
        background=True)
    upsert_index('dh_params', 'dh_param_bits',
        background=True)
    upsert_index('auth_nonces', [
        ('token', pymongo.ASCENDING),
        ('nonce', pymongo.ASCENDING),
    ], background=True, unique=True)
    upsert_index('otp_cache', [
        ('user_id', pymongo.ASCENDING),
        ('server_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('sso_push_cache', [
        ('user_id', pymongo.ASCENDING),
        ('server_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('sso_client_cache', [
        ('user_id', pymongo.ASCENDING),
        ('server_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('sso_passcode_cache', [
        ('user_id', pymongo.ASCENDING),
        ('server_id', pymongo.ASCENDING),
    ], background=True)
    upsert_index('vxlans', 'server_id',
        background=True, unique=True)

    # TTL (expireAfterSeconds) indexes.
    upsert_index('tasks', 'timestamp',
        background=True, expireAfterSeconds=300)
    if settings.app.demo_mode:
        drop_index(mongo.collections['clients'], 'timestamp', background=True)
    else:
        upsert_index('clients', 'timestamp',
            background=True, expireAfterSeconds=settings.vpn.client_ttl)
        upsert_index('clients_pool', 'timestamp',
            background=True, expireAfterSeconds=settings.vpn.client_ttl)
    upsert_index('users_key_link', 'timestamp',
        background=True, expireAfterSeconds=settings.app.key_link_timeout)
    upsert_index('acme_challenges', 'timestamp',
        background=True, expireAfterSeconds=180)
    upsert_index('auth_sessions', 'timestamp',
        background=True, expireAfterSeconds=settings.app.session_timeout)
    upsert_index('auth_nonces', 'timestamp',
        background=True,
        expireAfterSeconds=settings.app.auth_time_window * 2.1)
    upsert_index('auth_csrf_tokens', 'timestamp',
        background=True, expireAfterSeconds=604800)
    upsert_index('auth_limiter', 'timestamp',
        background=True, expireAfterSeconds=settings.app.auth_limiter_ttl)
    upsert_index('otp', 'timestamp', background=True,
        expireAfterSeconds=120)
    upsert_index('otp_cache', 'timestamp',
        background=True, expireAfterSeconds=settings.vpn.otp_cache_timeout)
    upsert_index('yubikey', 'timestamp',
        background=True, expireAfterSeconds=86400)
    upsert_index('sso_tokens', 'timestamp',
        background=True, expireAfterSeconds=600)
    upsert_index('sso_push_cache', 'timestamp',
        background=True, expireAfterSeconds=settings.app.sso_cache_timeout)
    upsert_index('sso_client_cache', 'timestamp',
        background=True,
        expireAfterSeconds=settings.app.sso_client_cache_timeout +
            settings.app.sso_client_cache_window)
    upsert_index('sso_passcode_cache', 'timestamp',
        background=True, expireAfterSeconds=settings.app.sso_cache_timeout)

    try:
        clean_indexes()
    except:
        logger.exception('Failed to clean indexes', 'setup')

    if not auth.Administrator.collection.find_one():
        auth.Administrator(
            username=DEFAULT_USERNAME,
            password=DEFAULT_PASSWORD,
            default=True,
        ).commit()

    # Generate any missing cookie secret and persist them in one commit.
    secret_key = settings.app.cookie_secret
    secret_key2 = settings.app.cookie_secret2
    settings_commit = False

    if not secret_key:
        settings_commit = True
        secret_key = utils.rand_str(64)
        settings.app.cookie_secret = secret_key

    if not secret_key2:
        settings_commit = True
        settings.app.cookie_secret2 = utils.rand_str(64)

    if settings_commit:
        settings.commit()

    app.app.secret_key = secret_key.encode()
Exemple #45
0
    def _connected(self, client_id):
        """Record a client that has finished connecting to this server.

        Looks the client up in ``self.clients``; an unknown ``client_id`` is
        reported on the instance output and killed. Otherwise the client's
        iptables rules are applied and a connection document is inserted
        into ``self.collection``. When ``self.route_clients`` is set, a
        ``'client'`` message with the client's addresses is published. The
        inserted document id is stored back on the client, and the client is
        appended to ``self.clients_queue`` so the ping loop keeps its
        document fresh.

        If inserting or publishing fails, the error is logged and the client
        is killed.

        Args:
            client_id: Identifier of the client in ``self.clients``.
        """
        client = self.clients.find_id(client_id)
        if not client:
            self.instance_com.push_output(
                'ERROR Unknown client connected client_id=%s' % client_id)
            self.instance_com.client_kill(client_id)
            return

        self.set_iptables_rules(
            client['iptables_rules'],
            client['ip6tables_rules'],
        )

        timestamp = utils.now()
        doc = {
            'user_id': client['user_id'],
            'server_id': self.server.id,
            'host_id': settings.local.host_id,
            'timestamp': timestamp,
            'platform': client['platform'],
            'type': client['user_type'],
            'device_name': client['device_name'],
            'mac_addr': client['mac_addr'],
            'network': self.server.network,
            'real_address': client['real_address'],
            'virt_address': client['virt_address'],
            'virt_address6': client['virt_address6'],
            'host_address': settings.local.host.local_addr,
            'host_address6': settings.local.host.local_addr6,
            'dns_servers': client['dns_servers'],
            'dns_suffix': client['dns_suffix'],
            # NOTE(review): '%s' is a platform strftime extension (not
            # portable) and interprets a naive datetime as local time;
            # confirm what utils.now() returns before relying on this value.
            'connected_since': int(timestamp.strftime('%s')),
        }

        if settings.local.sub_active and \
                settings.local.sub_plan == 'enterprise':
            domain = (client['user_name'].split('@')[0] + '.' +
                      client['org_name']).lower()
            # hashlib only accepts bytes: a text domain raises TypeError on
            # Python 3 (and UnicodeEncodeError for non-ASCII on Python 2).
            # Byte strings pass through unchanged, and ASCII text hashes to
            # the same digest as before.
            if not isinstance(domain, bytes):
                domain = domain.encode('utf-8')
            doc['domain'] = bson.binary.Binary(
                hashlib.md5(domain).digest(),
                subtype=bson.binary.MD5_SUBTYPE,
            )

        try:
            doc_id = self.collection.insert(doc)
            if self.route_clients:
                messenger.publish(
                    'client', {
                        'state': True,
                        'virt_address': client['virt_address'],
                        'virt_address6': client['virt_address6'],
                        'host_address': settings.local.host.local_addr,
                        'host_address6': settings.local.host.local_addr6,
                    })
        except:
            # NOTE(review): if publish fails after a successful insert, the
            # inserted document is not removed here.
            logger.exception(
                'Error adding client',
                'server',
                server_id=self.server.id,
            )
            self.instance_com.client_kill(client_id)
            return

        self.clients.update_id(client_id, {
            'doc_id': doc_id,
            'timestamp': time.time(),
        })

        self.clients_queue.append(client_id)

        self.instance_com.push_output('User connected user_id=%s' %
                                      client['user_id'])
        self.send_event()
Exemple #46
0
    def client_connect(self, client):
        """Authorize a client connect request and assign it a VPN address.

        ``client`` is the connect request. ``client_id``, ``org_id`` and
        ``user_id`` are read by key; device id/name, platform, MAC address,
        OTP code and remote IP are read with ``get``.

        The request is denied (``send_client_deny``) in these cases:
        - the remote IP is rate limited;
        - the organization or user cannot be found;
        - the user is disabled;
        - OTP verification fails for a certificate client.

        Otherwise:
        1. ``iroute`` lines are built for a server linked through this user.
        2. An address is chosen.
        3. ``send_client_auth`` is sent with an ``ifconfig-push`` line.

        Any unexpected error is logged and answered with a deny.

        Args:
            client: Mapping describing the connect request.
        """
        from pritunl.server.utils import get_by_id

        try:
            client_id = client['client_id']
            org_id = bson.ObjectId(client['org_id'])
            user_id = bson.ObjectId(client['user_id'])
            device_id = client.get('device_id')
            device_name = client.get('device_name')
            platform = client.get('platform')
            mac_addr = client.get('mac_addr')
            otp_code = client.get('otp_code')
            remote_ip = client.get('remote_ip')
            devices = self.client_devices[user_id]

            def evict(key, value):
                # Kill every device of this user whose ``key`` equals
                # ``value`` and free its address. ``devices`` is rebuilt in
                # place (slice assignment keeps the list stored in
                # self.client_devices) instead of deleting while iterating,
                # which would skip the entry after each deletion.
                evicted = [x for x in devices if x[key] == value]
                for device in evicted:
                    self.client_kill(device)
                    if device['virt_address'] in self.client_ips:
                        self.client_ips.remove(device['virt_address'])
                devices[:] = [x for x in devices if x[key] != value]
                return evicted

            if not _limiter.validate(remote_ip):
                self.send_client_deny(client, 'Too many connect requests')
                return

            org = self.server.get_org(org_id, fields=['_id'])
            if not org:
                self.send_client_deny(client, 'Organization is not valid')
                return

            user = org.get_user(user_id, fields=[
                '_id', 'name', 'type', 'disabled', 'otp_secret'])
            if not user:
                self.send_client_deny(client, 'User is not valid')
                return

            # 'disabled' is fetched above; without this check a disabled
            # user could still connect.
            if user.disabled:
                logger.LogEntry(message='User failed authentication, ' +
                    'disabled user "%s".' % user.name)
                self.send_client_deny(client, 'User is disabled')
                return

            if self.server.otp_auth and user.type == CERT_CLIENT and \
                    not user.verify_otp_code(otp_code, remote_ip):
                logger.LogEntry(message='User failed two-step ' +
                    'authentication "%s".' % user.name)
                self.send_client_deny(client, 'Invalid OTP code')
                return

            client_conf = ''

            # If this user is the link user of another server, route that
            # server's network and local networks through this client.
            link_svr_id = None
            for link_doc in self.server.links:
                if link_doc['user_id'] == user.id:
                    link_svr_id = link_doc['server_id']
                    break

            if link_svr_id:
                link_svr = get_by_id(link_svr_id,
                    fields=['_id', 'network', 'local_networks'])
                client_conf += 'iroute %s %s\n' % utils.parse_network(
                    link_svr.network)
                for local_network in link_svr.local_networks:
                    client_conf += 'iroute %s %s\n' % utils.parse_network(
                        local_network)

            if not self.server.multi_device:
                # One device per user: use the server-assigned address and
                # disconnect any of this user's devices already holding it.
                virt_address = self.server.get_ip_addr(org.id, user_id)

                if virt_address and virt_address in self.client_ips:
                    evict('virt_address', virt_address)
            else:
                # Multi-device: a reconnecting device takes over its old
                # address; otherwise use the server-assigned address unless
                # it is already in use.
                virt_address = None
                if devices and device_id:
                    evicted = evict('device_id', device_id)
                    if evicted:
                        virt_address = evicted[-1]['virt_address']

                if not virt_address:
                    virt_address = self.server.get_ip_addr(org.id, user_id)

                if virt_address and virt_address in self.client_ips:
                    virt_address = None

            if not virt_address:
                # Draw from the dynamic pool until an unused address is
                # found; popped addresses that are in use are discarded.
                while True:
                    try:
                        ip_addr = self.ip_pool.pop()
                    except IndexError:
                        break

                    ip_addr = '%s/%s' % (
                        ip_addr, self.ip_network.prefixlen)

                    if ip_addr not in self.client_ips:
                        virt_address = ip_addr
                        self.client_dyn_ips.add(virt_address)
                        break

            if virt_address:
                self.client_ips.add(virt_address)
                devices.append({
                    'user_id': user_id,
                    'org_id': org_id,
                    'client_id': client_id,
                    'device_id': device_id,
                    'device_name': device_name,
                    'type': user.type,
                    'platform': platform,
                    'mac_addr': mac_addr,
                    'otp_code': otp_code,
                    'virt_address': virt_address,
                    'real_address': remote_ip,
                })
                client_conf += 'ifconfig-push %s %s\n' % utils.parse_network(
                    virt_address)
                self.send_client_auth(client, client_conf)
            else:
                self.send_client_deny(client, 'Unable to assign ip address')
        except:
            logger.exception('Error parsing client connect', 'server',
                server_id=self.server.id,
                instance_id=self.instance.id,
            )
            self.send_client_deny(client, 'Error parsing client connect')
Exemple #47
0
    def sso_auth_check(self, password, remote_ip):
        """Re-validate this user against the SSO provider they use.

        The branch is chosen by the provider present in both
        ``self.auth_type`` and ``settings.app.sso``:

        - Google, Azure, Auth0 and Slack users are checked with a request
          to the auth server (``settings.app.dedicated`` if set, else
          ``AUTH_SERVER``).
        - Google, Azure and Auth0 are then verified through
          ``sso.verify_*``. When that provider's mode is ``'groups'``, the
          user's groups are synchronized.
        - OneLogin and Okta use ``sso.auth_onelogin`` / ``sso.auth_okta``.
        - RADIUS re-checks ``password``.
        - Plugin users go through ``sso.plugin_login_authenticate``.

        All checks except RADIUS and plugin are skipped (returning True)
        when ``settings.user.skip_remote_sso_check`` is set.

        Args:
            password: Password supplied by the client (RADIUS/plugin only).
            remote_ip: Client address passed to the plugin hook.

        Returns:
            True when the check passes or no SSO branch applies, otherwise
            False. Errors are logged, not raised.

        Raises:
            TypeError: a Slack check is needed and
                ``settings.app.sso_match`` is not a list.
        """
        sso_mode = settings.app.sso or ''
        auth_server = AUTH_SERVER
        if settings.app.dedicated:
            auth_server = settings.app.dedicated

        # requests never times out by default, so an unresponsive auth
        # server would block this check forever. A timeout raises, which
        # the handlers below log and turn into False.
        request_timeout = 30

        def sync_groups(sso_groups):
            # Replace the stored groups only when the provider's set differs.
            cur_groups = set(self.groups)
            new_groups = set(sso_groups)
            if cur_groups != new_groups:
                self.groups = list(new_groups)
                self.commit('groups')

        if GOOGLE_AUTH in self.auth_type and GOOGLE_AUTH in sso_mode:
            if settings.user.skip_remote_sso_check:
                return True

            try:
                resp = requests.get(
                    auth_server + '/update/google?user=%s&license=%s' % (
                        urllib.parse.quote(self.email),
                        settings.app.license,
                    ),
                    timeout=request_timeout,
                )

                if resp.status_code != 200:
                    logger.error(
                        'Google auth check request error',
                        'user',
                        user_id=self.id,
                        user_name=self.name,
                        status_code=resp.status_code,
                        content=resp.content,
                    )
                    return False

                valid, google_groups = sso.verify_google(self.email)
                if not valid:
                    logger.error(
                        'Google auth check failed',
                        'user',
                        user_id=self.id,
                        user_name=self.name,
                    )
                    return False

                if settings.app.sso_google_mode == 'groups':
                    sync_groups(google_groups)

                return True
            except:
                logger.exception(
                    'Google auth check error',
                    'user',
                    user_id=self.id,
                    user_name=self.name,
                )
            return False
        elif AZURE_AUTH in self.auth_type and AZURE_AUTH in sso_mode:
            if settings.user.skip_remote_sso_check:
                return True

            try:
                resp = requests.get(
                    auth_server + ('/update/azure?user=%s&license=%s&' +
                                   'directory_id=%s&app_id=%s&app_secret=%s') %
                    (
                        urllib.parse.quote(self.name),
                        settings.app.license,
                        urllib.parse.quote(
                            settings.app.sso_azure_directory_id),
                        urllib.parse.quote(settings.app.sso_azure_app_id),
                        urllib.parse.quote(settings.app.sso_azure_app_secret),
                    ),
                    timeout=request_timeout,
                )

                if resp.status_code != 200:
                    logger.error(
                        'Azure auth check request error',
                        'user',
                        user_id=self.id,
                        user_name=self.name,
                        status_code=resp.status_code,
                        content=resp.content,
                    )
                    return False

                valid, azure_groups = sso.verify_azure(self.name)
                if not valid:
                    logger.error(
                        'Azure auth check failed',
                        'user',
                        user_id=self.id,
                        user_name=self.name,
                    )
                    return False

                if settings.app.sso_azure_mode == 'groups':
                    sync_groups(azure_groups)

                return True
            except:
                logger.exception(
                    'Azure auth check error',
                    'user',
                    user_id=self.id,
                    user_name=self.name,
                )
            return False
        elif AUTHZERO_AUTH in self.auth_type and AUTHZERO_AUTH in sso_mode:
            if settings.user.skip_remote_sso_check:
                return True

            try:
                resp = requests.get(
                    auth_server + ('/update/authzero?user=%s&license=%s&' +
                                   'app_domain=%s&app_id=%s&app_secret=%s') %
                    (
                        urllib.parse.quote(self.name),
                        settings.app.license,
                        urllib.parse.quote(settings.app.sso_authzero_domain),
                        urllib.parse.quote(settings.app.sso_authzero_app_id),
                        urllib.parse.quote(
                            settings.app.sso_authzero_app_secret),
                    ),
                    timeout=request_timeout,
                )

                if resp.status_code != 200:
                    logger.error(
                        'Auth0 auth check request error',
                        'user',
                        user_id=self.id,
                        user_name=self.name,
                        status_code=resp.status_code,
                        content=resp.content,
                    )
                    return False

                valid, authzero_groups = sso.verify_authzero(self.name)
                if not valid:
                    logger.error(
                        'Auth0 auth check failed',
                        'user',
                        user_id=self.id,
                        user_name=self.name,
                    )
                    return False

                if settings.app.sso_authzero_mode == 'groups':
                    sync_groups(authzero_groups)

                return True
            except:
                logger.exception(
                    'Auth0 auth check error',
                    'user',
                    user_id=self.id,
                    user_name=self.name,
                )
            return False
        elif SLACK_AUTH in self.auth_type and SLACK_AUTH in sso_mode:
            if settings.user.skip_remote_sso_check:
                return True

            if not isinstance(settings.app.sso_match, list):
                raise TypeError('Invalid sso match')

            try:
                resp = requests.get(
                    auth_server + '/update/slack?user=%s&team=%s&license=%s' %
                    (
                        urllib.parse.quote(self.name),
                        urllib.parse.quote(settings.app.sso_match[0]),
                        settings.app.license,
                    ),
                    timeout=request_timeout,
                )

                if resp.status_code != 200:
                    logger.error(
                        'Slack auth check request error',
                        'user',
                        user_id=self.id,
                        user_name=self.name,
                        status_code=resp.status_code,
                        content=resp.content,
                    )
                    return False

                return True
            except:
                logger.exception(
                    'Slack auth check error',
                    'user',
                    user_id=self.id,
                    user_name=self.name,
                )
            return False
        elif SAML_ONELOGIN_AUTH in self.auth_type and \
                SAML_ONELOGIN_AUTH in sso_mode:
            if settings.user.skip_remote_sso_check:
                return True

            try:
                return sso.auth_onelogin(self.name)
            except:
                logger.exception(
                    'OneLogin auth check error',
                    'user',
                    user_id=self.id,
                    user_name=self.name,
                )
            return False
        elif SAML_OKTA_AUTH in self.auth_type and \
                SAML_OKTA_AUTH in sso_mode:
            if settings.user.skip_remote_sso_check:
                return True

            try:
                return sso.auth_okta(self.name)
            except:
                logger.exception(
                    'Okta auth check error',
                    'user',
                    user_id=self.id,
                    user_name=self.name,
                )
            return False
        elif RADIUS_AUTH in self.auth_type and RADIUS_AUTH in sso_mode:
            try:
                return sso.verify_radius(self.name, password)[0]
            except:
                logger.exception(
                    'Radius auth check error',
                    'user',
                    user_id=self.id,
                    user_name=self.name,
                )
            return False
        elif PLUGIN_AUTH in self.auth_type:
            try:
                return sso.plugin_login_authenticate(
                    user_name=self.name,
                    password=password,
                    remote_ip=remote_ip,
                )[1]
            except:
                logger.exception(
                    'Plugin auth check error',
                    'user',
                    user_id=self.id,
                    user_name=self.name,
                )
            return False

        return True
Exemple #48
0
    def ping_thread(self):
        """Keep the database documents of connected clients fresh.

        This is a generator: it yields between steps (a bare ``yield`` and
        ``yield interrupter_sleep(...)``). It is presumably driven by an
        external scheduler or decorator that decides when to resume it;
        confirm at the call site.

        Clients are processed round-robin from ``self.clients_queue``. Each
        one is popped, its in-memory ``timestamp`` and its document in
        ``self.collection`` are refreshed, and it is appended back.
        Special cases:
        - Clients no longer found or updated in ``self.clients`` are dropped
          from the queue.
        - A client whose document is missing is killed.

        On exit (return, close or an escaping error), the ``finally`` block
        removes every document referenced by ``self.clients`` from
        ``self.collection``.
        """
        try:
            while True:
                try:
                    try:
                        client_id = self.clients_queue.popleft()
                    except IndexError:
                        # Queue empty: wait, then retry. A truthy result
                        # from self.interrupter_sleep ends the generator.
                        if self.interrupter_sleep(10):
                            return
                        continue

                    client = self.clients.find_id(client_id)
                    if not client:
                        # Client is gone; it is not re-queued.
                        continue

                    # Seconds until this client's refresh is due: the
                    # refresh is scheduled 150 s before client_ttl has
                    # elapsed since the last one (presumably to stay ahead
                    # of the TTL index on the clients collection).
                    diff = settings.vpn.client_ttl - 150 - \
                           (time.time() - client['timestamp'])

                    if diff > settings.vpn.client_ttl:
                        # Only reachable when the stored timestamp lies more
                        # than 150 s in the future (e.g. the clock moved
                        # back); log and wait briefly instead of sleeping
                        # for an implausibly long time.
                        logger.error(
                            'Client ping time diff out of range',
                            'server',
                            time_diff=diff,
                            server_id=self.server.id,
                            instance_id=self.instance.id,
                        )
                        if self.interrupter_sleep(10):
                            return
                    elif diff > 1:
                        if self.interrupter_sleep(diff):
                            return

                    if self.instance.sock_interrupt:
                        return

                    try:
                        updated = self.clients.update_id(
                            client_id, {
                                'timestamp': time.time(),
                            })
                        if not updated:
                            # Client vanished while we waited; drop it.
                            continue

                        response = self.collection.update(
                            {
                                '_id': client['doc_id'],
                            }, {
                                '$set': {
                                    'timestamp': utils.now(),
                                },
                            })
                        if not response['updatedExisting']:
                            # The client's document is missing (e.g. removed
                            # elsewhere); disconnect the client.
                            logger.error(
                                'Client lost unexpectedly',
                                'server',
                                server_id=self.server.id,
                                instance_id=self.instance.id,
                            )
                            self.instance_com.client_kill(client_id)
                            continue
                    except:
                        # Keep the client in rotation and retry later.
                        # NOTE(review): this uses the module-level
                        # interrupter_sleep (its result is yielded, not
                        # checked), unlike self.interrupter_sleep above;
                        # confirm this is intended.
                        self.clients_queue.append(client_id)
                        logger.exception(
                            'Failed to update client',
                            'server',
                            server_id=self.server.id,
                            instance_id=self.instance.id,
                        )
                        yield interrupter_sleep(1)
                        continue

                    # Refreshed successfully: back of the queue.
                    self.clients_queue.append(client_id)

                    yield
                    if self.instance.sock_interrupt:
                        return
                except GeneratorExit:
                    # Let generator close() propagate so ``finally`` runs.
                    raise
                except:
                    logger.exception(
                        'Error in client thread',
                        'server',
                        server_id=self.server.id,
                        instance_id=self.instance.id,
                    )
                    yield interrupter_sleep(3)
                    if self.instance.sock_interrupt:
                        return
        finally:
            # Remove the documents of all still-known clients.
            doc_ids = []
            for client in self.clients.find_all():
                doc_id = client.get('doc_id')
                if doc_id:
                    doc_ids.append(doc_id)

            try:
                self.collection.remove({
                    '_id': {
                        '$in': doc_ids
                    },
                })
            except:
                logger.exception(
                    'Error removing client',
                    'server',
                    server_id=self.server.id,
                )
Exemple #49
0
    def set_iptables_rules(self):
        """Ensure every generated iptables/ip6tables rule is installed.

        Rules come from ``self.generate_iptables_rules()`` and are stored on
        ``self.iptables_rules`` / ``self.ip6tables_rules``. For each rule,
        the process returned by ``exists_iptables_rules`` /
        ``exists_ip6tables_rules`` is awaited. A non-zero exit from that
        check launches ``iptables -I`` / ``ip6tables -I`` for the rule. All
        processes run concurrently and are multiplexed with epoll on the
        hang-up of their stdout pipe.

        Raises:
            subprocess.CalledProcessError: an insert command exited non-zero
                (logged, then re-raised).
        """
        logger.debug(
            'Setting iptables rules',
            'server',
            server_id=self.server.id,
        )

        # stdout fd -> (command, process, follow-up command or None).
        processes = {}
        self.iptables_rules, self.ip6tables_rules = \
            self.generate_iptables_rules()

        # The epoll object owns a file descriptor; close it on every exit
        # path so repeated calls do not leak descriptors.
        poller = select.epoll()
        try:
            for rule in self.iptables_rules:
                cmd, process = self.exists_iptables_rules(rule)
                fileno = process.stdout.fileno()

                processes[fileno] = (cmd, process, ['iptables', '-I'] + rule)
                poller.register(fileno, select.EPOLLHUP)

            for rule in self.ip6tables_rules:
                cmd, process = self.exists_ip6tables_rules(rule)
                fileno = process.stdout.fileno()

                processes[fileno] = (cmd, process, ['ip6tables', '-I'] + rule)
                poller.register(fileno, select.EPOLLHUP)

            try:
                # Testing the dict at the top also returns immediately when
                # there are no rules, instead of polling forever.
                while processes:
                    for fd, _event in poller.poll(timeout=8):
                        cmd, process, next_cmd = processes.pop(fd)
                        poller.unregister(fd)

                        # HUP only means the child closed its stdout; it may
                        # not be reaped yet, in which case poll() returns
                        # None and a failure would look like success. wait()
                        # collects the real exit status.
                        retcode = process.wait()

                        if next_cmd:
                            if retcode:
                                process = subprocess.Popen(
                                    next_cmd,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                )
                                fileno = process.stdout.fileno()

                                processes[fileno] = (next_cmd, process, None)
                                poller.register(fileno, select.EPOLLHUP)
                        elif retcode:
                            std_out, err_out = process.communicate()
                            raise subprocess.CalledProcessError(
                                retcode, cmd, output=err_out)
            except subprocess.CalledProcessError as error:
                logger.exception('Failed to apply iptables routing rule',
                    'server',
                    server_id=self.server.id,
                    output=error.output,
                )
                raise
        finally:
            poller.close()
Exemple #50
0
    def _connect(self, client_data, reauth):
        """Begin authorization of a client connect or re-auth request.

        The caller is denied when:
        - it is rate limited (skipped when ``settings.vpn.stress_test`` is
          set);
        - its organization cannot be found;
        - its user cannot be found.

        Otherwise the user is handed to ``authorizer.Authorizer``, whose
        callback either allows the client (also marking it connected under
        stress testing) or denies it with the given reason. Any exception
        raised while setting this up is logged and answered with a deny.

        Args:
            client_data: Request mapping with ``client_id``, ``key_id``,
                ``org_id`` and ``user_id``, plus optional ``remote_ip``,
                ``platform``, ``device_name`` and ``password``.
            reauth: Passed through to the authorizer and ``allow_client``.
        """
        client_id = client_data['client_id']
        key_id = client_data['key_id']
        org_id = client_data['org_id']
        user_id = client_data['user_id']
        remote_ip = client_data.get('remote_ip')
        platform = client_data.get('platform')
        device_name = client_data.get('device_name')
        password = client_data.get('password')

        def deny(reason):
            self.instance_com.send_client_deny(client_id, key_id, reason)

        try:
            if not (settings.vpn.stress_test or
                    _limiter.validate(remote_ip)):
                deny('Too many connect requests')
                return

            org = self.get_org(org_id)
            if not org:
                deny('Organization is not valid')
                return

            wanted_fields = (
                '_id', 'name', 'email', 'pin', 'type', 'auth_type',
                'disabled', 'otp_secret', 'link_server_id',
                'bypass_secondary', 'dns_servers', 'dns_suffix',
                'port_forwarding',
            )
            user = org.get_user(user_id, fields=wanted_fields)
            if not user:
                deny('User is not valid')
                return

            def callback(allow, reason=None):
                try:
                    if not allow:
                        deny(reason)
                        return
                    self.allow_client(client_data, org, user, reauth)
                    if settings.vpn.stress_test:
                        self._connected(client_id)
                except:
                    logger.exception(
                        'Error in authorizer callback',
                        'server',
                        server_id=self.server.id,
                        instance_id=self.instance.id,
                    )

            auth = authorizer.Authorizer(
                self.server,
                user,
                remote_ip,
                platform,
                device_name,
                password,
                reauth,
                callback,
            )
            auth.authenticate()
        except:
            logger.exception(
                'Error parsing client connect',
                'server',
                server_id=self.server.id,
            )
            deny('Error parsing client connect')
Exemple #51
0
    def sso_auth_check(self, password, remote_ip):
        """Re-validate this user against the SSO provider they use.

        The branch is chosen by the provider present in both
        ``self.auth_type`` and ``settings.app.sso``:

        - Google and Slack users are checked with a request to
          ``AUTH_SERVER``; an HTTP 200 response passes.
        - OneLogin and Okta use ``sso.auth_onelogin`` / ``sso.auth_okta``.
        - RADIUS re-checks ``password``.
        - Plugin users go through ``sso.plugin_login_authenticate``.

        All checks except RADIUS and plugin are skipped (returning True)
        when ``settings.user.skip_remote_sso_check`` is set.

        Args:
            password: Password supplied by the client (RADIUS/plugin only).
            remote_ip: Client address passed to the plugin hook.

        Returns:
            True when the check passes or no SSO branch applies, otherwise
            False. Errors are logged, not raised.
        """
        # settings.app.sso may be unset; ``X in None`` raises TypeError, so
        # treat a falsy value as "no SSO providers enabled".
        sso_mode = settings.app.sso or ''

        # requests never times out by default, so an unresponsive auth
        # server would block this check forever. A timeout raises, which
        # the handlers below log and turn into False.
        request_timeout = 30

        if GOOGLE_AUTH in self.auth_type and GOOGLE_AUTH in sso_mode:
            if settings.user.skip_remote_sso_check:
                return True

            try:
                resp = requests.get(
                    AUTH_SERVER + '/update/google?user=%s&license=%s' % (
                        urllib.quote(self.email),
                        settings.app.license,
                    ),
                    timeout=request_timeout,
                )

                if resp.status_code == 200:
                    return True
            except:
                logger.exception(
                    'Google auth check error',
                    'user',
                    user_id=self.id,
                )
            return False
        elif SLACK_AUTH in self.auth_type and SLACK_AUTH in sso_mode:
            if settings.user.skip_remote_sso_check:
                return True

            try:
                resp = requests.get(
                    AUTH_SERVER + '/update/slack?user=%s&team=%s&license=%s' %
                    (
                        urllib.quote(self.name),
                        urllib.quote(settings.app.sso_match[0]),
                        settings.app.license,
                    ),
                    timeout=request_timeout,
                )

                if resp.status_code == 200:
                    return True
            except:
                logger.exception(
                    'Slack auth check error',
                    'user',
                    user_id=self.id,
                )
            return False
        elif SAML_ONELOGIN_AUTH in self.auth_type and \
                SAML_ONELOGIN_AUTH in sso_mode:
            if settings.user.skip_remote_sso_check:
                return True

            try:
                return sso.auth_onelogin(self.name)
            except:
                logger.exception(
                    'OneLogin auth check error',
                    'user',
                    user_id=self.id,
                )
            return False
        elif SAML_OKTA_AUTH in self.auth_type and \
                SAML_OKTA_AUTH in sso_mode:
            if settings.user.skip_remote_sso_check:
                return True

            try:
                return sso.auth_okta(self.name)
            except:
                logger.exception(
                    'Okta auth check error',
                    'user',
                    user_id=self.id,
                )
            return False
        elif RADIUS_AUTH in self.auth_type and RADIUS_AUTH in sso_mode:
            try:
                return sso.verify_radius(self.name, password)[0]
            except:
                logger.exception(
                    'Radius auth check error',
                    'user',
                    user_id=self.id,
                )
            return False
        elif PLUGIN_AUTH in self.auth_type:
            try:
                return sso.plugin_login_authenticate(
                    user_name=self.name,
                    password=password,
                    remote_ip=remote_ip,
                )[0]
            except:
                logger.exception(
                    'Plugin auth check error',
                    'user',
                    user_id=self.id,
                )
            return False

        return True
Exemple #52
0
    def generate_iptables_rules(self):
        """Build this server's iptables rules as argument lists.

        ``route -n`` is parsed to map route destinations to interfaces.
        The returned rules:

        - accept traffic arriving on ``self.interface`` (INPUT and FORWARD);
        - MASQUERADE traffic from the server network and from each linked
          server's network toward every local network, using
          ``0.0.0.0/0`` for an all-traffic, non-bridged server;
        - allow ESTABLISHED,RELATED forwarding in both directions between
          ``self.interface`` and each outbound interface used.

        Every rule gets a ``pritunl_<server id>`` comment match, plus
        ``--wait`` when ``settings.local.iptables_wait`` is set.

        Raises:
            subprocess.CalledProcessError: ``route -n`` failed (logged first).
            IptablesError: no ``0.0.0.0`` (default) route was found.
        """
        try:
            route_table = utils.check_output_logged(['route', '-n'])
        except subprocess.CalledProcessError:
            logger.exception(
                'Failed to get IP routes',
                'server',
                server_id=self.server.id,
            )
            raise

        # Route destination -> interface (8th column); non-IP rows such as
        # the header lines are ignored.
        iface_by_dest = {}
        for row in route_table.splitlines():
            cols = row.split()
            if len(cols) >= 8 and re.match(IP_REGEX, cols[0]):
                iface_by_dest[cols[0]] = cols[7]

        if '0.0.0.0' not in iface_by_dest:
            raise IptablesError('Failed to find default network interface', {
                'server_id': self.server.id,
            })
        default_iface = iface_by_dest['0.0.0.0']

        rules = [
            ['INPUT', '-i', self.interface, '-j', 'ACCEPT'],
            ['FORWARD', '-i', self.interface, '-j', 'ACCEPT'],
        ]

        nat_everything = self.server.mode == ALL_TRAFFIC and \
            self.server.network_mode != BRIDGE
        fallback_networks = ['0.0.0.0/0'] if nat_everything else []

        linked_networks = [
            linked.network
            for linked in self.server.iter_links(fields=('_id', 'network'))
        ]

        out_ifaces = set()
        for network_address in self.server.local_networks or fallback_networks:
            network = utils.parse_network(network_address)[0]

            if network in iface_by_dest:
                out_iface = iface_by_dest[network]
            else:
                logger.info('Failed to find interface for local '
                            'network route, using default route', 'server',
                            server_id=self.server.id,
                            )
                out_iface = default_iface
            out_ifaces.add(out_iface)

            masquerade = ['POSTROUTING', '-t', 'nat']
            if network != '0.0.0.0':
                masquerade += ['-d', network_address]
            masquerade += ['-o', out_iface, '-j', 'MASQUERADE']

            for source in [self.server.network] + linked_networks:
                rules.append(masquerade + ['-s', source])

        for out_iface in out_ifaces:
            for in_dev, out_dev in ((out_iface, self.interface),
                                    (self.interface, out_iface)):
                rules.append([
                    'FORWARD', '-i', in_dev, '-o', out_dev,
                    '-m', 'state', '--state', 'ESTABLISHED,RELATED',
                    '-j', 'ACCEPT',
                ])

        tag = ['-m', 'comment', '--comment', 'pritunl_%s' % self.server.id]
        if settings.local.iptables_wait:
            tag.append('--wait')

        return [rule + tag for rule in rules]
Exemple #53
0
def setup_mongo():
    """Connect to MongoDB, register collections and indexes, seed defaults.

    Steps, in order:

    * Retry ``pymongo.MongoClient`` every 0.5s until it is created without
      ``ConnectionFailure``, logging the failure at most every 30 seconds.
    * Create the capped ``messages`` collection if missing, register all
      core collections (with the configured name prefix) in
      ``mongo.collections`` and tag each with ``name_str``.
    * Sync time and initialise settings, then create the capped ``logs`` and
      ``log_entries`` collections (sized from settings) and register them.
    * Ensure all indexes, including TTL indexes driven by settings.
    * Create the default administrator if none exists, and generate the
      cookie secret and server API key when they are unset.
    """
    prefix = settings.conf.mongodb_collection_prefix or ''

    # last_error starts 24s in the past so the first connection failure is
    # logged after roughly 6s of retrying, then at most once every 30s.
    last_error = time.time() - 24
    while True:
        try:
            client = pymongo.MongoClient(
                settings.conf.mongodb_uri,
                connectTimeoutMS=MONGO_CONNECT_TIMEOUT)
            break
        except pymongo.errors.ConnectionFailure:
            time.sleep(0.5)
            if time.time() - last_error > 30:
                last_error = time.time()
                logger.exception('Error connecting to mongodb server')

    database = client.get_default_database()
    cur_collections = database.collection_names()

    if prefix + 'messages' not in cur_collections:
        database.create_collection(prefix + 'messages',
                                   capped=True,
                                   size=100000,
                                   max=1024)

    # Logical name -> prefixed collection. 'logs' and 'log_entries' are
    # registered further down because their capped size depends on settings,
    # which are only available after settings.init().
    collection_names = (
        'time_sync',
        'transaction',
        'queue',
        'task',
        'settings',
        'messages',
        'administrators',
        'users',
        'users_key_link',
        'organizations',
        'hosts',
        'hosts_usage',
        'servers',
        'servers_output',
        'servers_output_link',
        'servers_bandwidth',
        'servers_ip_pool',
        'dh_params',
        'auth_nonces',
        'auth_limiter',
        'otp',
        'otp_cache',
    )
    mongo.collections.update({
        name: getattr(database, prefix + name) for name in collection_names
    })

    for collection_name, collection in mongo.collections.items():
        collection.name_str = collection_name

    utils.sync_time()

    settings.init()

    if prefix + 'logs' not in cur_collections:
        log_limit = settings.app.log_limit
        database.create_collection(prefix + 'logs',
                                   capped=True,
                                   size=log_limit * 1024,
                                   max=log_limit)

    if prefix + 'log_entries' not in cur_collections:
        log_entry_limit = settings.app.log_entry_limit
        database.create_collection(prefix + 'log_entries',
                                   capped=True,
                                   size=log_entry_limit * 512,
                                   max=log_entry_limit)

    mongo.collections.update({
        'logs': getattr(database, prefix + 'logs'),
        'log_entries': getattr(database, prefix + 'log_entries'),
    })
    mongo.collections['logs'].name_str = 'logs'
    mongo.collections['log_entries'].name_str = 'log_entries'

    mongo.collections['logs'].ensure_index('timestamp')
    mongo.collections['transaction'].ensure_index('lock_id', unique=True)
    mongo.collections['transaction'].ensure_index([
        ('ttl_timestamp', pymongo.ASCENDING),
        ('state', pymongo.ASCENDING),
        ('priority', pymongo.DESCENDING),
    ])
    mongo.collections['queue'].ensure_index('runner_id')
    mongo.collections['queue'].ensure_index('ttl_timestamp')
    mongo.collections['task'].ensure_index('type', unique=True)
    mongo.collections['task'].ensure_index('ttl_timestamp')
    mongo.collections['log_entries'].ensure_index([
        ('timestamp', pymongo.DESCENDING),
    ])
    mongo.collections['messages'].ensure_index('channel')
    mongo.collections['administrators'].ensure_index('username', unique=True)
    mongo.collections['users'].ensure_index('resource_id')
    mongo.collections['users'].ensure_index([
        ('type', pymongo.ASCENDING),
        ('org_id', pymongo.ASCENDING),
    ])
    mongo.collections['users'].ensure_index([
        ('org_id', pymongo.ASCENDING),
        ('name', pymongo.ASCENDING),
    ])
    mongo.collections['users_key_link'].ensure_index('key_id')
    mongo.collections['users_key_link'].ensure_index('short_id', unique=True)
    mongo.collections['organizations'].ensure_index('type')
    mongo.collections['hosts'].ensure_index('name')
    mongo.collections['hosts_usage'].ensure_index([
        ('host_id', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ])
    mongo.collections['servers'].ensure_index('name')
    mongo.collections['servers'].ensure_index('ping_timestamp')
    mongo.collections['servers_output'].ensure_index([
        ('server_id', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ])
    mongo.collections['servers_output_link'].ensure_index([
        ('server_id', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ])
    mongo.collections['servers_bandwidth'].ensure_index([
        ('server_id', pymongo.ASCENDING),
        ('period', pymongo.ASCENDING),
        ('timestamp', pymongo.ASCENDING),
    ])
    mongo.collections['servers_ip_pool'].ensure_index([
        ('server_id', pymongo.ASCENDING),
        ('user_id', pymongo.ASCENDING),
    ])
    mongo.collections['servers_ip_pool'].ensure_index('user_id')
    mongo.collections['dh_params'].ensure_index('dh_param_bits')
    mongo.collections['auth_nonces'].ensure_index([
        ('token', pymongo.ASCENDING),
        ('nonce', pymongo.ASCENDING),
    ], unique=True)

    # TTL indexes. TODO check and remove current index when changed
    # (NOTE(review): an existing TTL index is presumably not updated in
    # place when these settings change -- confirm).
    mongo.collections['users_key_link'].ensure_index(
        'timestamp', expireAfterSeconds=settings.app.key_link_timeout)
    mongo.collections['auth_nonces'].ensure_index(
        'timestamp', expireAfterSeconds=settings.app.auth_time_window * 2.1)
    mongo.collections['auth_limiter'].ensure_index(
        'timestamp', expireAfterSeconds=settings.app.auth_limiter_ttl)
    mongo.collections['otp'].ensure_index('timestamp', expireAfterSeconds=120)
    mongo.collections['otp_cache'].ensure_index(
        'timestamp', expireAfterSeconds=settings.user.otp_cache_ttl)

    if not auth.Administrator.collection.find_one():
        auth.Administrator(
            username=DEFAULT_USERNAME,
            password=DEFAULT_PASSWORD,
            default=True,
        ).commit()

    def random_key():
        # Up to 64 alphanumeric characters derived from 128 random bytes.
        # b64encode returns bytes on Python 3; decode first, because re.sub
        # with a str pattern raises TypeError on bytes input.
        encoded = base64.b64encode(os.urandom(128)).decode('ascii')
        return re.sub(r'[\W_]+', '', encoded)[:64]

    secret_key = settings.app.cookie_secret
    if not secret_key:
        secret_key = random_key()
        settings.app.cookie_secret = secret_key
        settings.commit()
    app.app.secret_key = secret_key.encode()

    server_api_key = settings.app.server_api_key
    if not server_api_key:
        server_api_key = random_key()
        settings.app.server_api_key = server_api_key
        settings.commit()
Exemple #54
0
    def generate_iptables_rules(self):
        """Build the iptables and ip6tables rules for this server.

        The host routing table (``route -n``, plus ``route -n -A inet6``
        when ``self.server.ipv6`` is set) is used to find the outgoing
        interface for each local network. When a network has no exact route,
        the default route's interface is used. The rules produced are:

        * ACCEPT for INPUT/FORWARD traffic arriving on the VPN interface;
        * IPv6 INPUT/FORWARD rules for the server's IPv6 network. When
          ``ipv6_firewall`` is set and the host has a routed IPv6 subnet,
          these are DROP rules plus conntrack and ICMPv6 echo-request
          ACCEPT rules; otherwise they are plain ACCEPT rules;
        * NAT ``MASQUERADE`` rules for each local network, or for all
          traffic in ALL_TRAFFIC mode. These cover the server's own network
          and linked servers' networks, each kept in the table of its own
          address family;
        * stateful ESTABLISHED,RELATED FORWARD rules in both directions
          between the VPN interface and every egress interface used.

        Each rule is suffixed with a ``pritunl_<server_id>`` comment match,
        and with ``--wait`` when ``settings.local.iptables_wait`` is set.

        :returns: ``(rules, rules6)``, lists of argument lists for iptables
            and ip6tables respectively.
        :raises subprocess.CalledProcessError: if a routing table cannot be
            read.
        :raises IptablesError: if no default IPv4 route is found, or no
            default IPv6 route is found while IPv6 is enabled.
        """
        rules = []
        rules6 = []

        def established(iface_in, iface_out):
            # FORWARD rule admitting return traffic of existing connections
            # travelling from iface_in to iface_out.
            return [
                'FORWARD',
                '-i',
                iface_in,
                '-o',
                iface_out,
                '-m',
                'state',
                '--state',
                'ESTABLISHED,RELATED',
                '-j',
                'ACCEPT',
            ]

        try:
            routes_output = utils.check_output_logged(['route', '-n'])
        except subprocess.CalledProcessError:
            logger.exception(
                'Failed to get IP routes',
                'server',
                server_id=self.server.id,
            )
            raise

        # Destination -> interface (8th column of `route -n`). Only lines
        # whose first column matches IP_REGEX are considered, and the first
        # route listed for a destination wins.
        routes = {}
        for line in routes_output.splitlines():
            line_split = line.split()
            if len(line_split) < 8 or not re.match(IP_REGEX, line_split[0]):
                continue
            routes.setdefault(line_split[0], line_split[7])

        if '0.0.0.0' not in routes:
            raise IptablesError('Failed to find default network interface')
        default_interface = routes['0.0.0.0']

        routes6 = {}
        default_interface6 = None
        if self.server.ipv6:
            try:
                routes_output = utils.check_output_logged(
                    ['route', '-n', '-A', 'inet6'])
            except subprocess.CalledProcessError:
                logger.exception(
                    'Failed to get IPv6 routes',
                    'server',
                    server_id=self.server.id,
                )
                raise

            # Destination -> interface (7th column of the IPv6 table). Header
            # lines and anything that is not a valid IPv6 network are
            # skipped. IPv6Network raises plain ValueError for some malformed
            # input (e.g. host bits set), so catch the base class rather than
            # only AddressValueError.
            for line in routes_output.splitlines():
                line_split = line.split()

                if len(line_split) < 7:
                    continue

                try:
                    ipaddress.IPv6Network(line_split[0])
                except ValueError:
                    continue

                routes6.setdefault(line_split[0], line_split[6])

            if '::/0' not in routes6:
                raise IptablesError(
                    'Failed to find default IPv6 network interface')
            default_interface6 = routes6['::/0']

            if default_interface6 == 'lo':
                logger.error(
                    'Failed to find default IPv6 interface',
                    'server',
                    server_id=self.server.id,
                )

        rules.append(['INPUT', '-i', self.interface, '-j', 'ACCEPT'])
        rules.append(['FORWARD', '-i', self.interface, '-j', 'ACCEPT'])
        if self.server.ipv6:
            if self.server.ipv6_firewall and \
                    settings.local.host.routed_subnet6:
                for chain in ('INPUT', 'FORWARD'):
                    rules6.append(
                        [chain, '-d', self.server.network6, '-j', 'DROP'])
                    rules6.append([
                        chain, '-d', self.server.network6, '-m', 'conntrack',
                        '--ctstate', 'RELATED,ESTABLISHED', '-j', 'ACCEPT'
                    ])
                    # ICMPv6 type 128 is echo request.
                    rules6.append([
                        chain, '-d', self.server.network6, '-p', 'icmpv6',
                        '--icmpv6-type', '128', '-m', 'conntrack',
                        '--ctstate', 'NEW', '-j', 'ACCEPT'
                    ])
            else:
                for chain in ('INPUT', 'FORWARD'):
                    rules6.append(
                        [chain, '-d', self.server.network6, '-j', 'ACCEPT'])

        interfaces = set()
        interfaces6 = set()
        other_networks = []
        if self.server.mode == ALL_TRAFFIC and \
                self.server.network_mode != BRIDGE:
            other_networks = ['0.0.0.0/0']
            if self.server.ipv6 and not settings.local.host.routed_subnet6:
                other_networks.append('::/0')

        link_svr_networks = [
            link_svr.network
            for link_svr in self.server.iter_links(fields=('_id', 'network',
                                                           'network_start',
                                                           'network_end'))
        ]

        for network_address in self.server.local_networks or other_networks:
            args_base = ['POSTROUTING', '-t', 'nat']
            is6 = ':' in network_address
            if is6:
                network = network_address
            else:
                network = utils.parse_network(network_address)[0]

            if is6:
                if network not in routes6:
                    logger.info('Failed to find interface for local ' + \
                        'IPv6 network route, using default route', 'server',
                        server_id=self.server.id,
                        network=network,
                    )
                    interface = default_interface6
                else:
                    interface = routes6[network]
                interfaces6.add(interface)
            else:
                if network not in routes:
                    logger.info('Failed to find interface for local ' + \
                        'network route, using default route', 'server',
                        server_id=self.server.id,
                        network=network,
                    )
                    interface = default_interface
                else:
                    interface = routes[network]
                interfaces.add(interface)

            # A default route needs no destination match.
            if network != '0.0.0.0' and network != '::/0':
                args_base += ['-d', network_address]

            args_base += [
                '-o',
                interface,
                '-j',
                'MASQUERADE',
            ]

            if is6:
                target, source = rules6, self.server.network6
            else:
                target, source = rules, self.server.network
            target.append(args_base + ['-s', source])

            # Masquerade linked servers' traffic as well, but never mix
            # address families: an IPv6 destination/egress interface must not
            # end up in an iptables (IPv4) rule, and vice versa.
            for link_svr_net in link_svr_networks:
                if (':' in link_svr_net) == is6:
                    target.append(args_base + ['-s', link_svr_net])

        for interface in interfaces:
            rules.append(established(interface, self.interface))
            rules.append(established(self.interface, interface))

        for interface in interfaces6:
            # Skip the default IPv6 interface when the IPv6 firewall is active
            # on a routed subnet (NOTE(review): presumably the firewall rules
            # above are meant to govern that path -- confirm).
            if self.server.ipv6 and self.server.ipv6_firewall and \
                    settings.local.host.routed_subnet6 and \
                    interface == default_interface6:
                continue

            rules6.append(established(interface, self.interface))
            rules6.append(established(self.interface, interface))

        extra_args = [
            '-m',
            'comment',
            '--comment',
            'pritunl_%s' % self.server.id,
        ]

        if settings.local.iptables_wait:
            extra_args.append('--wait')

        rules = [x + extra_args for x in rules]
        rules6 = [x + extra_args for x in rules6]

        return rules, rules6
Exemple #55
0
def _dns_thread():
    """Supervise the external ``pritunl-dns`` mapping service.

    This is a cooperative generator task. It yields between steps, sometimes
    yielding ``interrupter_sleep(...)`` values, and loops until closed.

    * While ``host.dns_mapping_servers`` is empty, no service runs and the
      task re-checks every 3 seconds.
    * Otherwise ``pritunl-dns`` is started. The MongoDB URI and collection
      prefix are passed in the ``DB`` / ``DB_PREFIX`` environment variables,
      and the process is polled roughly every 0.5 seconds.
    * If the mapping servers disappear, the service is terminated, given
      3 seconds, then killed and reaped.
    * If the service exits on its own, the task returns when a global
      interrupt is in progress. Otherwise it logs an error and restarts the
      service after 1 second.
    * Closing the generator (``GeneratorExit``) stops the service and ends
      the task. Any other error is logged, the service is stopped (so the
      next iteration does not start a duplicate) and the loop retries after
      1 second.
    """
    from pritunl import host

    def stop(proc):
        # SIGTERM, 1s grace period, SIGKILL, then reap so no zombie is left.
        # Signalling may raise OSError if the child is already gone.
        try:
            proc.terminate()
            time.sleep(1)
            proc.kill()
        except OSError:
            pass
        proc.wait()

    while True:
        process = None

        try:
            if not host.dns_mapping_servers:
                yield interrupter_sleep(3)
                continue

            yield

            process = subprocess.Popen(
                ['pritunl-dns'],
                env=dict(
                    os.environ, **{
                        'DB': settings.conf.mongodb_uri,
                        'DB_PREFIX': settings.conf.mongodb_collection_prefix
                        or '',
                    }),
            )

            while True:
                if not host.dns_mapping_servers:
                    process.terminate()
                    yield interrupter_sleep(3)
                    process.kill()
                    # Reap the child so it does not linger as a zombie.
                    process.wait()
                    process = None
                    break
                elif process.poll() is not None:
                    if check_global_interrupt():
                        return

                    # stdout/stderr are inherited, not piped, so there is no
                    # captured output to attach; report the exit status.
                    logger.error(
                        'DNS mapping service stopped unexpectedly',
                        'setup',
                        returncode=process.returncode,
                    )
                    process = None

                    yield interrupter_sleep(1)

                    break

                time.sleep(0.5)
                yield
        except GeneratorExit:
            if process:
                stop(process)
            return
        except:
            logger.exception('Error in dns service', 'setup')
            if process:
                # Without this the next iteration would reset `process` and
                # start a second pritunl-dns while this one keeps running.
                stop(process)

        yield interrupter_sleep(1)
Exemple #56
0
def settings_put():
    if settings.app.demo_mode:
        return utils.demo_blocked()

    org_event = False
    admin_event = False
    admin = flask.g.administrator
    changes = set()

    settings_commit = False
    update_server = False
    update_acme = False
    update_cert = False

    if 'username' in flask.request.json and flask.request.json['username']:
        username = utils.filter_str(flask.request.json['username']).lower()
        if username != admin.username:
            changes.add('username')
        admin.username = username

    if 'password' in flask.request.json and flask.request.json['password']:
        password = flask.request.json['password']
        changes.add('password')
        admin.password = password

    if 'server_cert' in flask.request.json:
        settings_commit = True
        server_cert = flask.request.json['server_cert']
        if server_cert:
            server_cert = server_cert.strip()
        else:
            server_cert = None

        if server_cert != settings.app.server_cert:
            update_server = True

        settings.app.server_cert = server_cert

    if 'server_key' in flask.request.json:
        settings_commit = True
        server_key = flask.request.json['server_key']
        if server_key:
            server_key = server_key.strip()
        else:
            server_key = None

        if server_key != settings.app.server_key:
            update_server = True

        settings.app.server_key = server_key

    if 'server_port' in flask.request.json:
        settings_commit = True

        server_port = flask.request.json['server_port']
        if not server_port:
            server_port = 443

        try:
            server_port = int(server_port)
            if server_port < 1 or server_port > 65535:
                raise ValueError('Port invalid')
        except ValueError:
            return utils.jsonify(
                {
                    'error': PORT_INVALID,
                    'error_msg': PORT_INVALID_MSG,
                }, 400)

        if server_port != settings.app.server_port:
            update_server = True

        settings.app.server_port = server_port

    if 'acme_domain' in flask.request.json:
        settings_commit = True

        acme_domain = utils.filter_str(flask.request.json['acme_domain']
                                       or None)
        if acme_domain:
            acme_domain = acme_domain.replace('https://', '')
            acme_domain = acme_domain.replace('http://', '')
            acme_domain = acme_domain.replace('/', '')

        if acme_domain != settings.app.acme_domain:
            if not acme_domain:
                settings.app.acme_key = None
                settings.app.acme_timestamp = None
                settings.app.server_key = None
                settings.app.server_cert = None
                update_server = True
                update_cert = True
            else:
                update_acme = True
        settings.app.acme_domain = acme_domain

    if 'auditing' in flask.request.json:
        settings_commit = True
        auditing = flask.request.json['auditing'] or None

        if settings.app.auditing == ALL and auditing != ALL:
            return utils.jsonify(
                {
                    'error': CANNOT_DISABLE_AUTIDING,
                    'error_msg': CANNOT_DISABLE_AUTIDING_MSG,
                }, 400)

        if settings.app.auditing != auditing:
            if not flask.g.administrator.super_user:
                return utils.jsonify(
                    {
                        'error': REQUIRES_SUPER_USER,
                        'error_msg': REQUIRES_SUPER_USER_MSG,
                    }, 400)
            admin_event = True
            org_event = True

        settings.app.auditing = auditing

    if 'monitoring' in flask.request.json:
        settings_commit = True
        monitoring = flask.request.json['monitoring'] or None
        settings.app.monitoring = monitoring

    if 'influxdb_uri' in flask.request.json:
        settings_commit = True
        influxdb_uri = flask.request.json['influxdb_uri'] or None
        settings.app.influxdb_uri = influxdb_uri

    if 'email_from' in flask.request.json:
        settings_commit = True
        email_from = flask.request.json['email_from'] or None
        if email_from != settings.app.email_from:
            changes.add('smtp')
        settings.app.email_from = email_from

    if 'email_server' in flask.request.json:
        settings_commit = True
        email_server = flask.request.json['email_server'] or None
        if email_server != settings.app.email_server:
            changes.add('smtp')
        settings.app.email_server = email_server

    if 'email_username' in flask.request.json:
        settings_commit = True
        email_username = flask.request.json['email_username'] or None
        if email_username != settings.app.email_username:
            changes.add('smtp')
        settings.app.email_username = email_username

    if 'email_password' in flask.request.json:
        settings_commit = True
        email_password = flask.request.json['email_password'] or None
        if email_password != settings.app.email_password:
            changes.add('smtp')
        settings.app.email_password = email_password

    if 'pin_mode' in flask.request.json:
        settings_commit = True
        pin_mode = flask.request.json['pin_mode'] or None
        if pin_mode != settings.user.pin_mode:
            changes.add('pin_mode')
        settings.user.pin_mode = pin_mode

    if 'sso' in flask.request.json:
        org_event = True
        settings_commit = True
        sso = flask.request.json['sso'] or None
        if sso != settings.app.sso:
            changes.add('sso')
        settings.app.sso = sso

    if 'sso_match' in flask.request.json:
        settings_commit = True
        sso_match = flask.request.json['sso_match'] or None

        if sso_match != settings.app.sso_match:
            changes.add('sso')

        if isinstance(sso_match, list):
            settings.app.sso_match = sso_match
        else:
            settings.app.sso_match = None

    if 'sso_azure_directory_id' in flask.request.json:
        settings_commit = True
        sso_azure_directory_id = flask.request.json[
            'sso_azure_directory_id'] or None
        if sso_azure_directory_id != settings.app.sso_azure_directory_id:
            changes.add('sso')
        settings.app.sso_azure_directory_id = sso_azure_directory_id

    if 'sso_azure_app_id' in flask.request.json:
        settings_commit = True
        sso_azure_app_id = flask.request.json['sso_azure_app_id'] or None
        if sso_azure_app_id != settings.app.sso_azure_app_id:
            changes.add('sso')
        settings.app.sso_azure_app_id = sso_azure_app_id

    if 'sso_azure_app_secret' in flask.request.json:
        settings_commit = True
        sso_azure_app_secret = flask.request.json[
            'sso_azure_app_secret'] or None
        if sso_azure_app_secret != settings.app.sso_azure_app_secret:
            changes.add('sso')
        settings.app.sso_azure_app_secret = sso_azure_app_secret

    if 'sso_authzero_domain' in flask.request.json:
        settings_commit = True
        sso_authzero_domain = flask.request.json['sso_authzero_domain'] or None
        if sso_authzero_domain != settings.app.sso_authzero_domain:
            changes.add('sso')
        settings.app.sso_authzero_domain = sso_authzero_domain

    if 'sso_authzero_app_id' in flask.request.json:
        settings_commit = True
        sso_authzero_app_id = flask.request.json['sso_authzero_app_id'] or None
        if sso_authzero_app_id != settings.app.sso_authzero_app_id:
            changes.add('sso')
        settings.app.sso_authzero_app_id = sso_authzero_app_id

    if 'sso_authzero_app_secret' in flask.request.json:
        settings_commit = True
        sso_authzero_app_secret = flask.request.json[
            'sso_authzero_app_secret'] or None
        if sso_authzero_app_secret != settings.app.sso_authzero_app_secret:
            changes.add('sso')
        settings.app.sso_authzero_app_secret = sso_authzero_app_secret

    if 'sso_google_key' in flask.request.json:
        settings_commit = True
        sso_google_key = flask.request.json['sso_google_key'] or None
        if sso_google_key != settings.app.sso_google_key:
            changes.add('sso')
        settings.app.sso_google_key = sso_google_key

    if 'sso_google_email' in flask.request.json:
        settings_commit = True
        sso_google_email = flask.request.json['sso_google_email'] or None
        if sso_google_email != settings.app.sso_google_email:
            changes.add('sso')
        settings.app.sso_google_email = sso_google_email

    if 'sso_duo_token' in flask.request.json:
        settings_commit = True
        sso_duo_token = flask.request.json['sso_duo_token'] or None
        if sso_duo_token != settings.app.sso_duo_token:
            changes.add('sso')
        settings.app.sso_duo_token = sso_duo_token

    if 'sso_duo_secret' in flask.request.json:
        settings_commit = True
        sso_duo_secret = flask.request.json['sso_duo_secret'] or None
        if sso_duo_secret != settings.app.sso_duo_secret:
            changes.add('sso')
        settings.app.sso_duo_secret = sso_duo_secret

    if 'sso_duo_host' in flask.request.json:
        settings_commit = True
        sso_duo_host = flask.request.json['sso_duo_host'] or None
        if sso_duo_host != settings.app.sso_duo_host:
            changes.add('sso')
        settings.app.sso_duo_host = sso_duo_host

    if 'sso_duo_mode' in flask.request.json:
        settings_commit = True
        sso_duo_mode = flask.request.json['sso_duo_mode'] or None
        if sso_duo_mode != settings.app.sso_duo_mode:
            changes.add('sso')
        settings.app.sso_duo_mode = sso_duo_mode

    if 'sso_radius_secret' in flask.request.json:
        settings_commit = True
        sso_radius_secret = flask.request.json['sso_radius_secret'] or None
        if sso_radius_secret != settings.app.sso_radius_secret:
            changes.add('sso')
        settings.app.sso_radius_secret = sso_radius_secret

    if 'sso_radius_host' in flask.request.json:
        settings_commit = True
        sso_radius_host = flask.request.json['sso_radius_host'] or None
        if sso_radius_host != settings.app.sso_radius_host:
            changes.add('sso')
        settings.app.sso_radius_host = sso_radius_host

    if 'sso_org' in flask.request.json:
        settings_commit = True
        sso_org = flask.request.json['sso_org'] or None

        if sso_org:
            sso_org = utils.ObjectId(sso_org)
        else:
            sso_org = None

        if sso_org != settings.app.sso_org:
            changes.add('sso')

        if settings.app.sso and not sso_org:
            return utils.jsonify(
                {
                    'error': SSO_ORG_NULL,
                    'error_msg': SSO_ORG_NULL_MSG,
                }, 400)

        settings.app.sso_org = sso_org

    if 'sso_saml_url' in flask.request.json:
        settings_commit = True
        sso_saml_url = flask.request.json['sso_saml_url'] or None
        if sso_saml_url != settings.app.sso_saml_url:
            changes.add('sso')
        settings.app.sso_saml_url = sso_saml_url

    if 'sso_saml_issuer_url' in flask.request.json:
        settings_commit = True
        sso_saml_issuer_url = flask.request.json['sso_saml_issuer_url'] or \
            None
        if sso_saml_issuer_url != settings.app.sso_saml_issuer_url:
            changes.add('sso')
        settings.app.sso_saml_issuer_url = sso_saml_issuer_url

    if 'sso_saml_cert' in flask.request.json:
        settings_commit = True
        sso_saml_cert = flask.request.json['sso_saml_cert'] or None
        if sso_saml_cert != settings.app.sso_saml_cert:
            changes.add('sso')
        settings.app.sso_saml_cert = sso_saml_cert

    if 'sso_okta_app_id' in flask.request.json:
        settings_commit = True
        sso_okta_app_id = flask.request.json['sso_okta_app_id'] or None
        if sso_okta_app_id != settings.app.sso_okta_app_id:
            changes.add('sso')
        settings.app.sso_okta_app_id = sso_okta_app_id

    if 'sso_okta_token' in flask.request.json:
        settings_commit = True
        sso_okta_token = flask.request.json['sso_okta_token'] or None
        if sso_okta_token != settings.app.sso_okta_token:
            changes.add('sso')
        settings.app.sso_okta_token = sso_okta_token

    if 'sso_okta_mode' in flask.request.json:
        sso_mode = settings.app.sso
        if sso_mode and sso_mode == SAML_OKTA_AUTH:
            settings_commit = True
            sso_okta_mode = flask.request.json['sso_okta_mode']
            settings.app.sso_okta_mode = sso_okta_mode

    if 'sso_onelogin_app_id' in flask.request.json:
        settings_commit = True
        sso_onelogin_app_id = flask.request.json['sso_onelogin_app_id'] or \
            None
        if sso_onelogin_app_id != settings.app.sso_onelogin_app_id:
            changes.add('sso')
        settings.app.sso_onelogin_app_id = sso_onelogin_app_id

    if 'sso_onelogin_id' in flask.request.json:
        settings_commit = True
        sso_onelogin_id = flask.request.json['sso_onelogin_id'] or None
        if sso_onelogin_id != settings.app.sso_onelogin_id:
            changes.add('sso')
        settings.app.sso_onelogin_id = sso_onelogin_id

    if 'sso_onelogin_secret' in flask.request.json:
        settings_commit = True
        sso_onelogin_secret = \
            flask.request.json['sso_onelogin_secret'] or None
        if sso_onelogin_secret != settings.app.sso_onelogin_secret:
            changes.add('sso')
        settings.app.sso_onelogin_secret = sso_onelogin_secret

    if 'sso_onelogin_mode' in flask.request.json:
        sso_mode = settings.app.sso
        if sso_mode and sso_mode == SAML_ONELOGIN_AUTH:
            settings_commit = True
            sso_onelogin_mode = flask.request.json['sso_onelogin_mode']
            settings.app.sso_onelogin_mode = sso_onelogin_mode

    if 'sso_cache' in flask.request.json:
        settings_commit = True
        sso_cache = True if \
            flask.request.json['sso_cache'] else False
        if sso_cache != settings.app.sso_cache:
            changes.add('sso')
        settings.app.sso_cache = sso_cache

    if 'sso_client_cache' in flask.request.json:
        settings_commit = True
        sso_client_cache = True if \
            flask.request.json['sso_client_cache'] else False
        if sso_client_cache != settings.app.sso_client_cache:
            changes.add('sso')
        settings.app.sso_client_cache = sso_client_cache

    if 'restrict_import' in flask.request.json:
        settings_commit = True
        restrict_import = True if \
            flask.request.json['restrict_import'] else False
        if restrict_import != settings.user.restrict_import:
            changes.add('restrict_import')
        settings.user.restrict_import = restrict_import

    if 'client_reconnect' in flask.request.json:
        settings_commit = True
        client_reconnect = True if \
            flask.request.json['client_reconnect'] else False
        settings.user.reconnect = client_reconnect

    if 'sso_yubico_client' in flask.request.json:
        settings_commit = True
        sso_yubico_client = \
            flask.request.json['sso_yubico_client'] or None
        if sso_yubico_client != settings.app.sso_yubico_client:
            changes.add('sso')
        settings.app.sso_yubico_client = sso_yubico_client

    if 'sso_yubico_secret' in flask.request.json:
        settings_commit = True
        sso_yubico_secret = \
            flask.request.json['sso_yubico_secret'] or None
        if sso_yubico_secret != settings.app.sso_yubico_secret:
            changes.add('sso')
        settings.app.sso_yubico_secret = sso_yubico_secret

    if flask.request.json.get('theme'):
        settings_commit = True
        theme = 'light' if flask.request.json['theme'] == 'light' else 'dark'

        if theme != settings.app.theme:
            if theme == 'dark':
                event.Event(type=THEME_DARK)
            else:
                event.Event(type=THEME_LIGHT)

        settings.app.theme = theme

    if 'public_address' in flask.request.json:
        public_address = flask.request.json['public_address'] or None

        if public_address != settings.local.host.public_addr:
            settings.local.host.public_address = public_address
            settings.local.host.commit('public_address')

    if 'public_address6' in flask.request.json:
        public_address6 = flask.request.json['public_address6'] or None

        if public_address6 != settings.local.host.public_addr6:
            settings.local.host.public_address6 = public_address6
            settings.local.host.commit('public_address6')

    if 'routed_subnet6' in flask.request.json:
        routed_subnet6 = flask.request.json['routed_subnet6']
        if routed_subnet6:
            try:
                routed_subnet6 = ipaddress.IPv6Network(
                    flask.request.json['routed_subnet6'])
            except (ipaddress.AddressValueError, ValueError):
                return utils.jsonify(
                    {
                        'error': IPV6_SUBNET_INVALID,
                        'error_msg': IPV6_SUBNET_INVALID_MSG,
                    }, 400)

            if routed_subnet6.prefixlen > 64:
                return utils.jsonify(
                    {
                        'error': IPV6_SUBNET_SIZE_INVALID,
                        'error_msg': IPV6_SUBNET_SIZE_INVALID_MSG,
                    }, 400)

            routed_subnet6 = str(routed_subnet6)
        else:
            routed_subnet6 = None

        if settings.local.host.routed_subnet6 != routed_subnet6:
            if server.get_online_ipv6_count():
                return utils.jsonify(
                    {
                        'error': IPV6_SUBNET_ONLINE,
                        'error_msg': IPV6_SUBNET_ONLINE_MSG,
                    }, 400)
            settings.local.host.routed_subnet6 = routed_subnet6
            settings.local.host.commit('routed_subnet6')

    if 'routed_subnet6_wg' in flask.request.json:
        routed_subnet6_wg = flask.request.json['routed_subnet6_wg']
        if routed_subnet6_wg:
            try:
                routed_subnet6_wg = ipaddress.IPv6Network(
                    flask.request.json['routed_subnet6_wg'])
            except (ipaddress.AddressValueError, ValueError):
                return utils.jsonify(
                    {
                        'error': IPV6_SUBNET_WG_INVALID,
                        'error_msg': IPV6_SUBNET_WG_INVALID_MSG,
                    }, 400)

            if routed_subnet6_wg.prefixlen > 64:
                return utils.jsonify(
                    {
                        'error': IPV6_SUBNET_WG_SIZE_INVALID,
                        'error_msg': IPV6_SUBNET_WG_SIZE_INVALID_MSG,
                    }, 400)

            routed_subnet6_wg = str(routed_subnet6_wg)
        else:
            routed_subnet6_wg = None

        if settings.local.host.routed_subnet6_wg != routed_subnet6_wg:
            if server.get_online_ipv6_count():
                return utils.jsonify(
                    {
                        'error': IPV6_SUBNET_WG_ONLINE,
                        'error_msg': IPV6_SUBNET_WG_ONLINE_MSG,
                    }, 400)
            settings.local.host.routed_subnet6_wg = routed_subnet6_wg
            settings.local.host.commit('routed_subnet6_wg')

    if 'reverse_proxy' in flask.request.json:
        settings_commit = True
        reverse_proxy = flask.request.json['reverse_proxy']
        settings.app.reverse_proxy = True if reverse_proxy else False

    if 'cloud_provider' in flask.request.json:
        settings_commit = True
        cloud_provider = flask.request.json['cloud_provider'] or None
        settings.app.cloud_provider = cloud_provider

    if 'route53_region' in flask.request.json:
        settings_commit = True
        settings.app.route53_region = utils.filter_str(
            flask.request.json['route53_region']) or None

    if 'route53_zone' in flask.request.json:
        settings_commit = True
        settings.app.route53_zone = utils.filter_str(
            flask.request.json['route53_zone']) or None

    if settings.app.cloud_provider == 'oracle':
        if 'oracle_user_ocid' in flask.request.json:
            settings_commit = True
            settings.app.oracle_user_ocid = utils.filter_str(
                flask.request.json['oracle_user_ocid']) or None
    elif settings.app.oracle_user_ocid:
        settings_commit = True
        settings.app.oracle_user_ocid = None

    if 'oracle_public_key' in flask.request.json:
        if flask.request.json['oracle_public_key'] == 'reset':
            settings_commit = True
            private_key, public_key = utils.generate_rsa_key()
            settings.app.oracle_private_key = private_key
            settings.app.oracle_public_key = public_key

    for aws_key in (
            'us_east_1_access_key',
            'us_east_1_secret_key',
            'us_east_2_access_key',
            'us_east_2_secret_key',
            'us_west_1_access_key',
            'us_west_1_secret_key',
            'us_west_2_access_key',
            'us_west_2_secret_key',
            'us_gov_east_1_access_key',
            'us_gov_east_1_secret_key',
            'us_gov_west_1_access_key',
            'us_gov_west_1_secret_key',
            'eu_north_1_access_key',
            'eu_north_1_secret_key',
            'eu_west_1_access_key',
            'eu_west_1_secret_key',
            'eu_west_2_access_key',
            'eu_west_2_secret_key',
            'eu_west_3_access_key',
            'eu_west_3_secret_key',
            'eu_central_1_access_key',
            'eu_central_1_secret_key',
            'ca_central_1_access_key',
            'ca_central_1_secret_key',
            'cn_north_1_access_key',
            'cn_north_1_secret_key',
            'cn_northwest_1_access_key',
            'cn_northwest_1_secret_key',
            'ap_northeast_1_access_key',
            'ap_northeast_1_secret_key',
            'ap_northeast_2_access_key',
            'ap_northeast_2_secret_key',
            'ap_southeast_1_access_key',
            'ap_southeast_1_secret_key',
            'ap_southeast_2_access_key',
            'ap_southeast_2_secret_key',
            'ap_east_1_access_key',
            'ap_east_1_secret_key',
            'ap_south_1_access_key',
            'ap_south_1_secret_key',
            'sa_east_1_access_key',
            'sa_east_1_secret_key',
    ):
        if settings.app.cloud_provider != 'aws':
            settings_commit = True
            setattr(settings.app, aws_key, None)
        elif aws_key in flask.request.json:
            settings_commit = True
            aws_value = flask.request.json[aws_key]

            if aws_value:
                setattr(settings.app, aws_key, utils.filter_str(aws_value))
            else:
                setattr(settings.app, aws_key, None)

    if not settings.app.sso:
        settings.app.sso_match = None
        settings.app.sso_azure_directory_id = None
        settings.app.sso_azure_app_id = None
        settings.app.sso_azure_app_secret = None
        settings.app.sso_authzero_directory_id = None
        settings.app.sso_authzero_app_id = None
        settings.app.sso_authzero_app_secret = None
        settings.app.sso_google_key = None
        settings.app.sso_google_email = None
        settings.app.sso_duo_token = None
        settings.app.sso_duo_secret = None
        settings.app.sso_duo_host = None
        settings.app.sso_org = None
        settings.app.sso_saml_url = None
        settings.app.sso_saml_issuer_url = None
        settings.app.sso_saml_cert = None
        settings.app.sso_okta_app_id = None
        settings.app.sso_okta_token = None
        settings.app.sso_onelogin_key = None
        settings.app.sso_onelogin_app_id = None
        settings.app.sso_onelogin_id = None
        settings.app.sso_onelogin_secret = None
        settings.app.sso_radius_secret = None
        settings.app.sso_radius_host = None
    else:
        if RADIUS_AUTH in settings.app.sso and \
                settings.app.sso_duo_mode == 'passcode':
            return utils.jsonify(
                {
                    'error': RADIUS_DUO_PASSCODE,
                    'error_msg': RADIUS_DUO_PASSCODE_MSG,
                }, 400)

        if settings.app.sso == DUO_AUTH and \
                settings.app.sso_duo_mode == 'passcode':
            return utils.jsonify(
                {
                    'error': DUO_PASSCODE,
                    'error_msg': DUO_PASSCODE_MSG,
                }, 400)

    for change in changes:
        remote_addr = utils.get_remote_addr()
        flask.g.administrator.audit_event(
            'admin_settings',
            _changes_audit_text[change],
            remote_addr=remote_addr,
        )
        journal.entry(
            journal.SETTINGS_UPDATE,
            remote_address=remote_addr,
            event_long='Settings updated',
            changed=_changes_audit_text[change],
        )

    if settings_commit:
        settings.commit()

    admin.commit(admin.changed)

    if admin_event:
        event.Event(type=ADMINS_UPDATED)

    if org_event:
        for org in organization.iter_orgs(fields=('_id')):
            event.Event(type=USERS_UPDATED, resource_id=org.id)

    event.Event(type=SETTINGS_UPDATED)

    if update_acme:
        try:
            acme.update_acme_cert()
            app.update_server(0.5)
        except:
            logger.exception(
                'Failed to get LetsEncrypt cert',
                'handler',
                acme_domain=settings.app.acme_domain,
            )
            settings.app.acme_domain = None
            settings.app.acme_key = None
            settings.app.acme_timestamp = None
            settings.commit()
            return utils.jsonify(
                {
                    'error': ACME_ERROR,
                    'error_msg': ACME_ERROR_MSG,
                }, 400)
    elif update_cert:
        logger.info('Regenerating server certificate...', 'handler')
        utils.create_server_cert()
        app.update_server(0.5)
    elif update_server:
        app.update_server(0.5)

    response = flask.g.administrator.dict()
    response.update(_dict())
    return utils.jsonify(response)
Exemple #57
0
def _clear_subscription_state():
    """Reset every cached subscription field on ``settings.local``.

    Used both when no license is configured and when the subscription
    server could not be queried successfully.
    """
    settings.local.sub_active = False
    settings.local.sub_status = None
    settings.local.sub_plan = None
    settings.local.sub_quantity = None
    settings.local.sub_amount = None
    settings.local.sub_period_end = None
    settings.local.sub_trial_end = None
    settings.local.sub_cancel_at_period_end = None
    settings.local.sub_balance = None
    settings.local.sub_url_key = None


def update():
    """Refresh the cached subscription state from the subscription server.

    Generates and commits ``settings.app.id`` on first use. When no license
    is configured the cached ``settings.local.sub_*`` fields are cleared.
    Otherwise the subscription endpoint (``settings.app.dedicated`` when set,
    else ``https://app.pritunl.com``) is queried. A failed attempt is retried
    once after a one second pause. If both attempts fail, the cached fields
    are cleared.

    The resulting ``active``/``plan`` pair is written to the
    ``subscription`` document of the ``settings`` collection. A subscription
    event is emitted only when that document actually changed.

    Returns:
        Always ``True``.
    """
    license_key = settings.app.license
    collection = mongo.get_collection('settings')

    if not settings.app.id:
        settings.app.id = utils.random_name()
        settings.commit()

    if not license_key:
        _clear_subscription_state()
    else:
        attempts = 2
        for attempt in range(attempts):
            try:
                url = 'https://app.pritunl.com/subscription'
                if settings.app.dedicated:
                    url = settings.app.dedicated + '/subscription'

                response = requests.get(
                    url,
                    json={
                        'id': settings.app.id,
                        'license': license_key,
                        'version': settings.local.version_int,
                    },
                    timeout=max(settings.app.http_request_timeout, 10),
                )

                # License key invalid
                if response.status_code == 470:
                    raise ValueError('License key is invalid')

                if response.status_code == 473:
                    raise ValueError(('Version %r not recognized by ' +
                        'subscription server') % settings.local.version_int)

                data = response.json()

                settings.local.sub_active = data['active']
                settings.local.sub_status = data['status']
                settings.local.sub_plan = data['plan']
                settings.local.sub_quantity = data['quantity']
                settings.local.sub_amount = data['amount']
                settings.local.sub_period_end = data['period_end']
                settings.local.sub_trial_end = data['trial_end']
                settings.local.sub_cancel_at_period_end = data[
                    'cancel_at_period_end']
                settings.local.sub_balance = data.get('balance')
                settings.local.sub_url_key = data.get('url_key')
                settings.local.sub_styles[data['plan']] = data['styles']
            except Exception:
                if attempt < attempts - 1:
                    # Second argument is the log type; the retry note
                    # belongs in the message, not the type.
                    logger.exception(
                        'Failed to check subscription status, retrying...',
                        'subscription')
                    time.sleep(1)
                    continue
                logger.exception('Failed to check subscription status',
                    'subscription')
                # Attempts may have partially populated the fields; wipe them.
                _clear_subscription_state()
            break

    if settings.app.license_plan != settings.local.sub_plan and \
            settings.local.sub_plan:
        settings.app.license_plan = settings.local.sub_plan
        settings.commit()

    # Only matches (and so only reports updatedExisting) when active or
    # plan differ from what is stored, so events fire on real changes only.
    response = collection.update({
        '_id': 'subscription',
        '$or': [
            {'active': {'$ne': settings.local.sub_active}},
            {'plan': {'$ne': settings.local.sub_plan}},
        ],
    }, {'$set': {
        'active': settings.local.sub_active,
        'plan': settings.local.sub_plan,
    }})
    if response['updatedExisting']:
        if settings.local.sub_active:
            if settings.local.sub_plan == 'premium':
                event.Event(type=SUBSCRIPTION_PREMIUM_ACTIVE)
            elif settings.local.sub_plan == 'enterprise':
                event.Event(type=SUBSCRIPTION_ENTERPRISE_ACTIVE)
            elif settings.local.sub_plan == 'enterprise_plus':
                event.Event(type=SUBSCRIPTION_ENTERPRISE_PLUS_ACTIVE)
            else:
                # NOTE(review): an active subscription with an unrecognized
                # plan emits the same event as the inactive case; confirm
                # this is intended.
                event.Event(type=SUBSCRIPTION_NONE_INACTIVE)
        else:
            if settings.local.sub_plan == 'premium':
                event.Event(type=SUBSCRIPTION_PREMIUM_INACTIVE)
            elif settings.local.sub_plan == 'enterprise':
                event.Event(type=SUBSCRIPTION_ENTERPRISE_INACTIVE)
            elif settings.local.sub_plan == 'enterprise_plus':
                event.Event(type=SUBSCRIPTION_ENTERPRISE_PLUS_INACTIVE)
            else:
                event.Event(type=SUBSCRIPTION_NONE_INACTIVE)

    return True
Exemple #58
0
def _log_onelogin_api_error(username, response):
    """Log an unexpected HTTP response from the OneLogin API."""
    logger.error(
        'OneLogin api error',
        'sso',
        username=username,
        status_code=response.status_code,
        response=response.content,
    )


def _auth_onelogin_legacy(username):
    """Authorize *username* using the API-key based OneLogin v3 XML API.

    Returns:
        True when the user exists and its ``status`` element is ``'1'``,
        otherwise False. Request failures are logged and yield False.
    """
    try:
        response = requests.get(
            ONELOGIN_URL + '/api/v3/users/username/%s' % (
                # safe='' so a '/' in the username cannot change the path.
                urllib.parse.quote(username, safe='')),
            auth=(settings.app.sso_onelogin_key, 'x'),
        )
    except (http.client.HTTPException,
            requests.exceptions.RequestException):
        logger.exception(
            'OneLogin api error',
            'sso',
            username=username,
        )
        return False

    if response.status_code == 200:
        data = xml.etree.ElementTree.fromstring(response.content)
        status = data.find('status')
        if status is not None and status.text == '1':
            return True

        logger.warning(
            'OneLogin user disabled',
            'sso',
            username=username,
        )
    elif response.status_code == 404:
        logger.error(
            'OneLogin user not found',
            'sso',
            username=username,
        )
    elif response.status_code == 406:
        logger.warning(
            'OneLogin user disabled',
            'sso',
            username=username,
        )
    else:
        _log_onelogin_api_error(username, response)
    return False


def auth_onelogin(username):
    """Return True if *username* may authenticate through OneLogin.

    When ``sso_onelogin_id``/``sso_onelogin_secret`` are not both configured,
    the legacy API-key v3 endpoint is used. Otherwise the v1 API is queried
    with an access token. The user must exist and have ``status == 1``. When
    ``sso_onelogin_app_id`` is set, the user must also be assigned to that
    application.

    Any request failure or unexpected response is logged and yields False.
    """
    if not settings.app.sso_onelogin_id or \
            not settings.app.sso_onelogin_secret:
        return _auth_onelogin_legacy(username)

    access_token = _get_access_token()
    if not access_token:
        return False

    try:
        response = requests.get(
            _get_base_url() + '/api/1/users',
            headers={
                'Authorization': 'bearer:%s' % access_token,
                'Content-Type': 'application/json',
            },
            params={
                'username': username,
            },
        )
    except requests.exceptions.RequestException:
        logger.exception(
            'OneLogin api error',
            'sso',
            username=username,
        )
        return False

    if response.status_code != 200:
        _log_onelogin_api_error(username, response)
        return False

    users = response.json()['data']
    if not users:
        logger.error(
            'OneLogin user not found',
            'sso',
            username=username,
        )
        return False

    user = users[0]
    if user['status'] != 1:
        logger.warning(
            'OneLogin user disabled',
            'sso',
            username=username,
        )
        return False

    onelogin_app_id = settings.app.sso_onelogin_app_id
    if not onelogin_app_id:
        return True

    # The API returns numeric app ids; a non-numeric setting is compared
    # as-is (and so will not match numeric ids).
    try:
        onelogin_app_id = int(onelogin_app_id)
    except ValueError:
        pass

    user_id = user['id']

    try:
        response = requests.get(
            _get_base_url() + '/api/1/users/%d/apps' % user_id,
            headers={
                'Authorization': 'bearer:%s' % access_token,
            },
        )
    except requests.exceptions.RequestException:
        logger.exception(
            'OneLogin api error',
            'sso',
            username=username,
        )
        return False

    if response.status_code != 200:
        _log_onelogin_api_error(username, response)
        return False

    applications = response.json()['data']
    if not applications:
        logger.error(
            'OneLogin user apps not found',
            'sso',
            username=username,
        )
        return False

    if any(application['id'] == onelogin_app_id
           for application in applications):
        return True

    logger.warning(
        'OneLogin user is not assigned to application',
        'sso',
        username=username,
        onelogin_app_id=onelogin_app_id,
    )

    return False
Exemple #59
0
    def initialize(self):
        """Generate this certificate's private key and signed certificate.

        This method:

        1. Builds a throwaway OpenSSL CA working directory under a temp path:
           index, index.attr, serial and an OpenSSL config rendered from
           ``CERT_CONF``.
        2. For non-CA certificates, writes the organization's CA
           certificate and key there and generates an OTP secret.
        3. Creates a request and key with ``openssl req``.
        4. Signs the request with ``openssl ca``, using ``-selfsign`` when
           ``self.type`` is ``CERT_CA``.
        5. Loads the key and certificate through ``self.read_file``.

        The temp directory is always removed afterwards. Finally an IP
        address is assigned on a best-effort basis.

        Raises:
            OSError, ValueError: if either openssl invocation fails. The
                failure is logged first, then re-raised.
        """
        temp_path = utils.get_temp_path()
        index_path = os.path.join(temp_path, INDEX_NAME)
        index_attr_path = os.path.join(temp_path, INDEX_ATTR_NAME)
        serial_path = os.path.join(temp_path, SERIAL_NAME)
        ssl_conf_path = os.path.join(temp_path, OPENSSL_NAME)
        reqs_path = os.path.join(temp_path, '%s.csr' % self.id)
        key_path = os.path.join(temp_path, '%s.key' % self.id)
        cert_path = os.path.join(temp_path, '%s.crt' % self.id)
        # A CA signs itself, so its own key/cert double as the CA files.
        ca_name = self.id if self.type == CERT_CA else 'ca'
        ca_cert_path = os.path.join(temp_path, '%s.crt' % ca_name)
        ca_key_path = os.path.join(temp_path, '%s.key' % ca_name)

        self.org.queue_com.wait_status()

        try:
            os.makedirs(temp_path)

            # openssl ca requires these files to exist (touch semantics).
            with open(index_path, 'a'):
                os.utime(index_path, None)

            with open(index_attr_path, 'a'):
                os.utime(index_attr_path, None)

            with open(serial_path, 'w') as serial_file:
                serial_file.write('01\n')

            with open(ssl_conf_path, 'w') as conf_file:
                conf_file.write(CERT_CONF % (
                    settings.user.cert_key_bits,
                    settings.user.cert_message_digest,
                    self.org.id,
                    self.id,
                    index_path,
                    serial_path,
                    temp_path,
                    ca_cert_path,
                    ca_key_path,
                    settings.user.cert_message_digest,
                ))

            self.org.queue_com.wait_status()

            if self.type != CERT_CA:
                # 0o600 is valid on Python 2.6+ and 3; bare 0600 is a
                # SyntaxError on Python 3.
                self.org.write_file('ca_certificate', ca_cert_path,
                    chmod=0o600)
                self.org.write_file('ca_private_key', ca_key_path,
                    chmod=0o600)
                self.generate_otp_secret()

            try:
                args = [
                    'openssl',
                    'req',
                    '-new',
                    '-batch',
                    '-config',
                    ssl_conf_path,
                    '-out',
                    reqs_path,
                    '-keyout',
                    key_path,
                    '-reqexts',
                    '%s_req_ext' % self.type.replace('_pool', ''),
                ]
                self.org.queue_com.popen(args)
            except (OSError, ValueError):
                logger.exception(
                    'Failed to create user cert requests',
                    'user',
                    org_id=self.org.id,
                    user_id=self.id,
                )
                raise
            self.read_file('private_key', key_path)

            try:
                args = ['openssl', 'ca', '-batch']

                if self.type == CERT_CA:
                    args += ['-selfsign']

                args += [
                    '-config',
                    ssl_conf_path,
                    '-in',
                    reqs_path,
                    '-out',
                    cert_path,
                    '-extensions',
                    '%s_ext' % self.type.replace('_pool', ''),
                ]

                self.org.queue_com.popen(args)
            except (OSError, ValueError):
                logger.exception(
                    'Failed to create user cert',
                    'user',
                    org_id=self.org.id,
                    user_id=self.id,
                )
                raise
            self.read_file('certificate', cert_path)
        finally:
            try:
                utils.rmtree(temp_path)
            except subprocess.CalledProcessError:
                # Best-effort cleanup: never mask the primary outcome, but
                # leave a trace so leaked temp directories can be found.
                logger.warning(
                    'Failed to remove temporary cert directory',
                    'user',
                    org_id=self.org.id,
                    user_id=self.id,
                    temp_path=temp_path,
                )

        self.org.queue_com.wait_status()

        # If assign ip addr fails it will be corrected in ip sync task
        try:
            self.assign_ip_addr()
        except Exception:
            logger.exception(
                'Failed to assign users ip address',
                'user',
                org_id=self.org.id,
                user_id=self.id,
            )
Exemple #60
0
def _check():
    """Reload settings from MongoDB, then schedule the next check.

    A failure to load is logged and does not stop the next check from
    being scheduled via ``_start_check_timer()``. ``SystemExit`` and
    ``KeyboardInterrupt`` are no longer swallowed.
    """
    try:
        settings.load_mongo()
    except Exception:
        logger.exception('Auto settings check failed')
    _start_check_timer()