예제 #1
0
class Register(object):
    """
    Backend component in charge of registering users on a provider.

    A single ProviderConfig is kept around and reloaded whenever a
    registration is requested for a domain other than the loaded one.
    """

    zope.interface.implements(ILEAPComponent)

    def __init__(self, signaler=None):
        """
        Build the Register component.

        :param signaler: object used to send signals back to the frontend;
                         may be None, in which case no signals are sent.
        :type signaler: Signaler
        """
        object.__init__(self)
        self.key = "register"
        self._signaler = signaler
        self._provider_config = ProviderConfig()

    def register_user(self, domain, username, password):
        """
        Register ``username`` with ``password`` on the provider ``domain``.

        :param domain: domain of the provider to register the user on.
        :type domain: unicode
        :param username: the user name
        :type username: unicode
        :param password: the password for the username
        :type password: unicode

        :returns: a Deferred for the SRP registration running in a thread,
                  or None when the provider configuration cannot be loaded
                  (the failure is signaled and logged instead).
        :rtype: twisted.internet.defer.Deferred or None
        """
        config = self._provider_config

        # Load (or switch to) the requested provider's configuration.
        if not config.loaded() or config.get_domain() != domain:
            config.load(get_provider_path(domain))

        if not config.loaded():
            if self._signaler is not None:
                self._signaler.signal(self._signaler.srp_registration_failed)
            logger.error("Could not load provider configuration.")
            return None

        registerer = SRPRegister(signaler=self._signaler,
                                 provider_config=config)
        job = partial(registerer.register_user, username, password)
        return threads.deferToThread(job)
예제 #2
0
    def __init__(self, userid, passwd, mdir=None):
        """
        Initialize the plumber with all that's needed to authenticate
        against the provider.

        If the provider configuration cannot be loaded, a message is
        printed and ``self.exit()`` is called.

        :param userid: user identifier, foo@bar
        :type userid: basestring
        :param passwd: the soledad passphrase
        :type passwd: basestring
        :param mdir: a path to a maildir to import
        :type mdir: str or None
        :raises ValueError: if ``userid`` does not contain exactly one '@'.
        """
        self.userid = userid
        self.passwd = passwd
        user, provider = userid.split('@')
        self.user = user
        self.mdir = mdir
        self.sol = None
        self._settings = Settings()

        provider_config_path = os.path.join(get_path_prefix(),
                                            get_provider_path(provider))
        provider_config = ProviderConfig()
        loaded = provider_config.load(provider_config_path)
        if not loaded:
            print("could not load provider config!")
            # __init__ must return None: returning exit()'s result would
            # raise TypeError whenever it is not None.
            self.exit()
            return
예제 #3
0
    def __init__(self, userid, passwd, mdir=None):
        """
        Initialize the plumber with all that's needed to authenticate
        against the provider.

        If the provider configuration cannot be loaded, a message is
        printed and ``self.exit()`` is called.

        :param userid: user identifier, foo@bar
        :type userid: basestring
        :param passwd: the soledad passphrase
        :type passwd: basestring
        :param mdir: a path to a maildir to import
        :type mdir: str or None
        :raises ValueError: if ``userid`` does not contain exactly one '@'.
        """
        self.userid = userid
        self.passwd = passwd
        user, provider = userid.split('@')
        self.user = user
        self.mdir = mdir
        self.sol = None
        self._settings = Settings()

        provider_config_path = os.path.join(get_path_prefix(),
                                            get_provider_path(provider))
        provider_config = ProviderConfig()
        loaded = provider_config.load(provider_config_path)
        if not loaded:
            print("could not load provider config!")
            # __init__ must return None: returning exit()'s result would
            # raise TypeError whenever it is not None.
            self.exit()
            return
예제 #4
0
    def test_correct_http_uri(self):
        """
        Checks that registration autocorrects http uris to https ones.
        """
        port = self.https_port
        http_uri = "http://localhost:%s" % (port, )
        https_uri = "https://localhost:%s/1/users" % (port, )

        config = ProviderConfig()
        config.get_ca_cert_path = MagicMock(return_value=_get_capath())
        # we introduce a http uri in the config file...
        config.get_api_uri = MagicMock(return_value=http_uri)

        json_path = os.path.join(_here, "test_provider.json")
        if not config.load(path=json_path):
            raise ImproperlyConfiguredError(
                "Could not load test provider config")

        registerer = srpregister.SRPRegister(provider_config=config)

        # ... and we check that we're correctly taking the HTTPS protocol
        # instead
        self.assertEquals(registerer._get_registration_uri(), https_uri)

        registerer._get_registration_uri = MagicMock(return_value=https_uri)
        deferred = threads.deferToThread(registerer.register_user,
                                         "test_failhttp", "barpass")
        deferred.addCallback(self.assertTrue)
        return deferred
예제 #5
0
    @classmethod
    def setUpClass(cls):
        """
        Sets up this TestCase with a simple and faked provider instance:

        * starts the fake provider on HTTP and HTTPS ports chosen by the OS
        * loads a mocked ProviderConfig that points to the certs in the
          leap.common.testing module.
        """
        factory = fake_provider.get_provider_factory()
        # Port 0 lets the OS pick free ports; the real ones are read back
        # below, so a hard-coded port would only risk "address in use".
        http = reactor.listenTCP(0, factory)
        https = reactor.listenSSL(
            0, factory,
            fake_provider.OpenSSLServerContextFactory())

        def get_port(listening_port):
            return listening_port.getHost().port

        cls.http_port = get_port(http)
        cls.https_port = get_port(https)

        provider = ProviderConfig()
        provider.get_ca_cert_path = MagicMock()
        provider.get_ca_cert_path.return_value = _get_capath()

        provider.get_api_uri = MagicMock()
        provider.get_api_uri.return_value = cls._get_https_uri()

        loaded = provider.load(path=os.path.join(
            _here, "test_provider.json"))
        if not loaded:
            raise ImproperlyConfiguredError(
                "Could not load test provider config")
        cls.register = srpregister.SRPRegister(provider_config=provider)

        cls.auth = srpauth.SRPAuth(provider)
    def test_correct_http_uri(self):
        """
        Checks that registration autocorrects http uris to https ones.
        """
        port = self.https_port
        http_uri = "http://localhost:%s" % (port, )
        https_uri = "https://localhost:%s/1/users" % (port, )

        config = ProviderConfig()
        config.get_ca_cert_path = MagicMock(return_value=_get_capath())
        # we introduce a http uri in the config file...
        config.get_api_uri = MagicMock(return_value=http_uri)

        json_path = os.path.join(_here, "test_provider.json")
        if not config.load(path=json_path):
            raise ImproperlyConfiguredError(
                "Could not load test provider config")

        registerer = srpregister.SRPRegister(provider_config=config)

        # ... and we check that we're correctly taking the HTTPS protocol
        # instead
        self.assertEquals(registerer._get_registration_uri(), https_uri)

        registerer._get_registration_uri = MagicMock(return_value=https_uri)
        deferred = threads.deferToThread(registerer.register_user,
                                         "test_failhttp", "barpass")
        deferred.addCallback(self.assertTrue)
        return deferred
    @classmethod
    def setUpClass(cls):
        """
        Sets up this TestCase with a simple and faked provider instance:

        * starts the fake provider on HTTP and HTTPS ports chosen by the OS
        * loads a mocked ProviderConfig that points to the certs in the
          leap.common.testing module.
        """
        factory = fake_provider.get_provider_factory()
        # Port 0 lets the OS pick free ports; the real ones are read back
        # below, so a hard-coded port would only risk "address in use".
        http = reactor.listenTCP(0, factory)
        https = reactor.listenSSL(0, factory,
                                  fake_provider.OpenSSLServerContextFactory())

        def get_port(listening_port):
            return listening_port.getHost().port

        cls.http_port = get_port(http)
        cls.https_port = get_port(https)

        provider = ProviderConfig()
        provider.get_ca_cert_path = MagicMock()
        provider.get_ca_cert_path.return_value = _get_capath()

        provider.get_api_uri = MagicMock()
        provider.get_api_uri.return_value = cls._get_https_uri()

        loaded = provider.load(path=os.path.join(_here, "test_provider.json"))
        if not loaded:
            raise ImproperlyConfiguredError(
                "Could not load test provider config")
        cls.register = srpregister.SRPRegister(provider_config=provider)

        cls.auth = srpauth.SRPAuth(provider)
예제 #8
0
    def setUp(self):
        """
        Sets up this TestCase with a simple and faked provider instance:

        * starts the fake provider on HTTP and HTTPS ports chosen by the OS
        * loads a ProviderConfig whose get_ca_cert_path and get_api_uri are
          autospec'd mocks returning the certs in the leap.common.testing
          module and the local HTTPS URI
        * creates a fresh SRPAuth singleton and saves the original
          implementations of the backend methods that tests may replace
        """
        factory = fake_provider.get_provider_factory()
        # Port 0: let the OS pick free ports; the actual ones are read back
        # from the listening ports below.
        # NOTE(review): these listeners are not stopped in this method;
        # presumably tearDown does it -- confirm, or ports accumulate.
        http = reactor.listenTCP(0, factory)
        https = reactor.listenSSL(0, factory,
                                  fake_provider.OpenSSLServerContextFactory())
        get_port = lambda p: p.getHost().port
        self.http_port = get_port(http)
        self.https_port = get_port(https)

        provider = ProviderConfig()
        # create_autospec keeps the real method signatures, so calling a
        # mocked method with wrong arguments still fails.
        provider.get_ca_cert_path = mock.create_autospec(
            provider.get_ca_cert_path)
        provider.get_ca_cert_path.return_value = _get_capath()

        provider.get_api_uri = mock.create_autospec(provider.get_api_uri)
        provider.get_api_uri.return_value = self._get_https_uri()

        loaded = provider.load(path=os.path.join(_here, "test_provider.json"))
        if not loaded:
            raise ImproperlyConfiguredError(
                "Could not load test provider config")
        self.register = srpregister.SRPRegister(provider_config=provider)
        self.provider = provider
        self.TEST_USER = "******"
        self.TEST_PASS = "******"

        # Reset the singleton
        # (presumably so SRPAuth builds a new backend instance for this
        # test's provider config instead of reusing a previous test's one).
        srpauth.SRPAuth._SRPAuth__instance = None
        self.auth = srpauth.SRPAuth(self.provider)
        self.auth_backend = self.auth._SRPAuth__instance

        # Keep the original implementations so they can be restored after
        # tests monkeypatch them (presumably in tearDown -- confirm).
        self.old_post = self.auth_backend._session.post
        self.old_put = self.auth_backend._session.put
        self.old_delete = self.auth_backend._session.delete

        self.old_start_auth = self.auth_backend._start_authentication
        self.old_proc_challenge = self.auth_backend._process_challenge
        self.old_extract_data = self.auth_backend._extract_data
        self.old_verify_session = self.auth_backend._verify_session
        self.old_auth_preproc = self.auth_backend._authentication_preprocessing
        self.old_get_sid = self.auth_backend.get_session_id
        self.old_cookie_get = self.auth_backend._session.cookies.get
        self.old_auth = self.auth_backend.authenticate

        # HACK: this is needed since it seems that the backend settings path is
        # not using the right path
        mkdir_p('config/leap')
    def test_none_port(self):
        """
        An API URI without an explicit port makes SRPRegister use "443".
        """
        config = ProviderConfig()
        config.get_api_uri = MagicMock(return_value="http://localhost/")

        json_path = os.path.join(_here, "test_provider.json")
        if not config.load(path=json_path):
            raise ImproperlyConfiguredError(
                "Could not load test provider config")

        registerer = srpregister.SRPRegister(provider_config=config)
        self.assertEquals(registerer._port, "443")
예제 #10
0
    def test_none_port(self):
        """
        An API URI without an explicit port makes SRPRegister use "443".
        """
        config = ProviderConfig()
        config.get_api_uri = MagicMock(return_value="http://localhost/")

        json_path = os.path.join(_here, "test_provider.json")
        if not config.load(path=json_path):
            raise ImproperlyConfiguredError(
                "Could not load test provider config")

        registerer = srpregister.SRPRegister(provider_config=config)
        self.assertEquals(registerer._port, "443")
예제 #11
0
    def setUp(self):
        """
        Sets up this TestCase with a simple and faked provider instance:

        * starts the fake provider on HTTP and HTTPS ports chosen by the OS
        * loads a ProviderConfig whose get_ca_cert_path and get_api_uri are
          autospec'd mocks returning the certs in the leap.common.testing
          module and the local HTTPS URI
        * creates a fresh SRPAuth singleton and saves the original
          implementations of the backend methods that tests may replace
        """
        factory = fake_provider.get_provider_factory()
        # Port 0: let the OS pick free ports; the actual ones are read back
        # from the listening ports below.
        # NOTE(review): these listeners are not stopped in this method;
        # presumably tearDown does it -- confirm, or ports accumulate.
        http = reactor.listenTCP(0, factory)
        https = reactor.listenSSL(
            0, factory,
            fake_provider.OpenSSLServerContextFactory())
        get_port = lambda p: p.getHost().port
        self.http_port = get_port(http)
        self.https_port = get_port(https)

        provider = ProviderConfig()
        # create_autospec keeps the real method signatures, so calling a
        # mocked method with wrong arguments still fails.
        provider.get_ca_cert_path = mock.create_autospec(
            provider.get_ca_cert_path)
        provider.get_ca_cert_path.return_value = _get_capath()

        provider.get_api_uri = mock.create_autospec(
            provider.get_api_uri)
        provider.get_api_uri.return_value = self._get_https_uri()

        loaded = provider.load(path=os.path.join(
            _here, "test_provider.json"))
        if not loaded:
            raise ImproperlyConfiguredError(
                "Could not load test provider config")
        self.register = srpregister.SRPRegister(provider_config=provider)
        self.provider = provider
        self.TEST_USER = "******"
        self.TEST_PASS = "******"

        # Reset the singleton
        # (presumably so SRPAuth builds a new backend instance for this
        # test's provider config instead of reusing a previous test's one).
        srpauth.SRPAuth._SRPAuth__instance = None
        self.auth = srpauth.SRPAuth(self.provider)
        self.auth_backend = self.auth._SRPAuth__instance

        # Keep the original implementations so they can be restored after
        # tests monkeypatch them (presumably in tearDown -- confirm).
        self.old_post = self.auth_backend._session.post
        self.old_put = self.auth_backend._session.put
        self.old_delete = self.auth_backend._session.delete

        self.old_start_auth = self.auth_backend._start_authentication
        self.old_proc_challenge = self.auth_backend._process_challenge
        self.old_extract_data = self.auth_backend._extract_data
        self.old_verify_session = self.auth_backend._verify_session
        self.old_auth_preproc = self.auth_backend._authentication_preprocessing
        self.old_get_sid = self.auth_backend.get_session_id
        self.old_cookie_get = self.auth_backend._session.cookies.get
        self.old_auth = self.auth_backend.authenticate
예제 #12
0
    def test_wrong_cert(self):
        """
        Registration reports failure when the configured CA certificate is
        the wrong one ("wrongcert.pem").
        """
        config = ProviderConfig()
        json_path = os.path.join(_here, "test_provider.json")
        loaded = config.load(path=json_path)

        bad_cert = os.path.join(_here, "wrongcert.pem")
        config.get_ca_cert_path = MagicMock(return_value=bad_cert)
        config.get_api_uri = MagicMock(return_value=self._get_https_uri())

        if not loaded:
            raise ImproperlyConfiguredError(
                "Could not load test provider config")

        registerer = srpregister.SRPRegister(provider_config=config,
                                             register_path="users")
        self.assertFalse(
            registerer.register_user("foouser_firsttime", "barpass"))
예제 #13
0
    def _get_provider_config(self, domain):
        """
        Build a ProviderConfig from the stored provider.json of ``domain``.

        The path handed to ``ProviderConfig.load`` is the relative
        ``leap/providers/<domain>/provider.json``. NOTE(review): which base
        directory load() resolves it against is not visible here -- confirm.

        :param domain: the domain name of the provider.
        :type domain: str

        :returns: the loaded config, or None if it could not be loaded.
        :rtype: ProviderConfig or None
        """
        config = ProviderConfig()
        json_path = os.path.join("leap", "providers", domain,
                                 "provider.json")
        return config if config.load(json_path) else None
    def test_wrong_cert(self):
        """
        Registration reports failure when the configured CA certificate is
        the wrong one ("wrongcert.pem").
        """
        config = ProviderConfig()
        json_path = os.path.join(_here, "test_provider.json")
        loaded = config.load(path=json_path)

        bad_cert = os.path.join(_here, "wrongcert.pem")
        config.get_ca_cert_path = MagicMock(return_value=bad_cert)
        config.get_api_uri = MagicMock(return_value=self._get_https_uri())

        if not loaded:
            raise ImproperlyConfiguredError(
                "Could not load test provider config")

        registerer = srpregister.SRPRegister(provider_config=config,
                                             register_path="users")
        self.assertFalse(
            registerer.register_user("foouser_firsttime", "barpass"))
예제 #15
0
class Provider(object):
    """
    Interfaces with setup and bootstrapping operations for a provider
    """

    zope.interface.implements(ILEAPComponent)

    PROBLEM_SIGNAL = "prov_problem_with_provider"

    def __init__(self, signaler=None, bypass_checks=False):
        """
        Constructor for the Provider component

        :param signaler: Object in charge of handling communication
                         back to the frontend; may be None.
        :type signaler: Signaler
        :param bypass_checks: Set to true if the app should bypass
                              first round of checks for CA
                              certificates at bootstrap
        :type bypass_checks: bool
        """
        object.__init__(self)
        self.key = "provider"
        # Stored so bootstrap() can report a configuration load failure.
        self._signaler = signaler
        self._provider_bootstrapper = ProviderBootstrapper(signaler,
                                                           bypass_checks)
        self._download_provider_defer = None
        self._provider_config = ProviderConfig()

    def setup_provider(self, provider):
        """
        Initiates the setup for a provider.

        The returned defer is also remembered so cancel_setup_provider()
        can cancel it.

        :param provider: URL for the provider
        :type provider: unicode

        :returns: the defer for the operation running in a thread.
        :rtype: twisted.internet.defer.Deferred
        """
        log.msg("Setting up provider %s..." % (provider.encode("idna"),))
        pb = self._provider_bootstrapper
        d = pb.run_provider_select_checks(provider, download_if_needed=True)
        self._download_provider_defer = d
        return d

    def cancel_setup_provider(self):
        """
        Cancel the ongoing setup provider defer (if any).
        """
        d = self._download_provider_defer
        if d is not None:
            d.cancel()

    def bootstrap(self, provider):
        """
        Second stage of bootstrapping for a provider.

        :param provider: URL for the provider
        :type provider: unicode

        :returns: the defer for the setup checks; if the provider
                  configuration cannot be loaded, the failure is signaled
                  and logged and a new, unfired Deferred is returned.
        :rtype: twisted.internet.defer.Deferred
        """
        d = None

        # If there's no loaded provider or
        # we want to connect to other provider...
        if (not self._provider_config.loaded() or
                self._provider_config.get_domain() != provider):
            self._provider_config.load(get_provider_path(provider))

        if self._provider_config.loaded():
            d = self._provider_bootstrapper.run_provider_setup_checks(
                self._provider_config,
                download_if_needed=True)
        else:
            # This backend component has no GUI widgets to re-enable; the
            # frontend reacts to PROBLEM_SIGNAL instead.
            if self._signaler is not None:
                self._signaler.signal(self.PROBLEM_SIGNAL)
            logger.error("Could not load provider configuration.")

        if d is None:
            # NOTE(review): this Deferred is never fired, so callbacks the
            # caller attaches will not run; confirm callers expect that.
            d = defer.Deferred()
        return d
예제 #16
0
# EDIT THIS --------------------------------------------
# Account and server to run against: replace the placeholders before use.
user = u"USERNAME"
uuid = u"USERUUID"
_pass = u"USERPASS"
server_url = "https://soledad.server.example.org:2323"
# EDIT THIS --------------------------------------------

# Local files, named after the user's uuid.
# NOTE(review): /tmp is usually shared/world-writable; keeping secrets
# there is only acceptable for a throwaway development script.
secrets_path = "/tmp/%s.secrets" % uuid
local_db_path = "/tmp/%s.soledad" % uuid
cert_file = "/tmp/cacert.pem"
provider_config = '/tmp/cdev.json'


# NOTE(review): load()'s result (a success flag, judging by other callers)
# is ignored, so a missing or invalid provider config is not reported here.
provider = ProviderConfig()
provider.load(provider_config)

# Module-level placeholder; nothing visible in this file assigns it.
soledad = None


def printStuff(r):
    """Print ``r`` to stdout (presumably used as a Deferred callback)."""
    print(r)


def printErr(err):
    """
    Log a failure at ERROR level together with its exception information.

    ``logging.exception`` only records the exception currently being
    handled (``sys.exc_info()``); an errback does not run inside an
    ``except`` block, so that would log no exception info at all. The
    exception is therefore passed explicitly via ``exc_info``.

    :param err: the failure to report; presumably a Twisted Failure. Only
                its ``value`` and, if present, ``tb`` attributes are used.
    """
    exc = err.value
    logging.error(exc, exc_info=(type(exc), exc, getattr(err, "tb", None)))


def init_soledad(_):
    """
    Fetch the current SRP token and print it.

    The single argument is ignored (presumably this is chained as a
    Deferred callback and receives the previous result).

    NOTE(review): despite its name, this body only prints the token and
    never initializes ``soledad``; it looks truncated -- confirm.
    ``srpauth`` is not defined in the visible part of this script.
    """
    token = srpauth.get_token()
    print "token", token
class ProviderConfigTest(BaseLeapTest):
    """Tests for ProviderConfig"""

    def setUp(self):
        self._provider_config = ProviderConfig()
        json_string = json.dumps(sample_config)
        self._provider_config.load(data=json_string)

        # At certain points we are going to be replacing these methods
        # to avoid creating a file.
        # We need to save the old implementation and restore it in
        # tearDown so we are sure everything is as expected for each
        # test. If we do it inside each specific test, a failure in
        # the test will leave the implementation with the mock.
        self._old_ospath_exists = os.path.exists

    def tearDown(self):
        os.path.exists = self._old_ospath_exists

    def _load(self, config):
        """
        Serialize the ``config`` dict to JSON and load it into the
        provider config under test.
        """
        self._provider_config.load(data=json.dumps(config))

    def _load_services(self, services):
        """
        Load a copy of ``sample_config`` whose 'services' entry is
        ``services``.

        :returns: the provider config under test.
        """
        config = copy.deepcopy(sample_config)
        config['services'] = services
        self._load(config)
        return self._provider_config

    def _localize(self, lang):
        """
        Helper to change default language of the provider config.

        :returns: the config dict that was loaded.
        """
        config = copy.deepcopy(sample_config)
        config['default_language'] = lang
        self._load(config)
        return config

    def _check_default_localization(self, lang):
        """
        With ``lang`` as default language, an unknown language ('xx') or no
        language at all must fall back to the default description and name.
        """
        pc = self._provider_config
        config = self._localize(lang)

        default_language = config['default_language']
        default_description = config['description'][default_language]
        default_name = config['name'][default_language]

        self.assertEqual(pc.get_description(lang='xx'), default_description)
        self.assertEqual(pc.get_description(), default_description)

        self.assertEqual(pc.get_name(lang='xx'), default_name)
        self.assertEqual(pc.get_name(), default_name)

    def test_configs_ok(self):
        """
        Test if the configs loads ok
        """
        # TODO: this test should go to the BaseConfig tests
        pc = self._provider_config
        self.assertEqual(pc.get_api_uri(), sample_config['api_uri'])
        self.assertEqual(pc.get_api_version(), sample_config['api_version'])
        self.assertEqual(pc.get_ca_cert_fingerprint(),
                         sample_config['ca_cert_fingerprint'])
        self.assertEqual(pc.get_ca_cert_uri(), sample_config['ca_cert_uri'])
        self.assertEqual(pc.get_default_language(),
                         sample_config['default_language'])

        self.assertEqual(pc.get_domain(), sample_config['domain'])
        self.assertEqual(pc.get_enrollment_policy(),
                         sample_config['enrollment_policy'])
        self.assertEqual(pc.get_languages(), sample_config['languages'])

    def test_localizations(self):
        pc = self._provider_config

        self.assertEqual(pc.get_description(lang='en'),
                         sample_config['description']['en'])
        self.assertEqual(pc.get_description(lang='es'),
                         sample_config['description']['es'])

        self.assertEqual(pc.get_name(lang='en'), sample_config['name']['en'])
        self.assertEqual(pc.get_name(lang='es'), sample_config['name']['es'])

    def test_default_localization1(self):
        self._check_default_localization(sample_config['languages'][0])

    def test_default_localization2(self):
        self._check_default_localization(sample_config['languages'][1])

    def test_get_ca_cert_path_as_expected(self):
        pc = self._provider_config
        pc.get_path_prefix = Mock(return_value='test')

        provider_domain = sample_config['domain']
        expected_path = os.path.join('test', 'leap', 'providers',
                                     provider_domain, 'keys', 'ca',
                                     'cacert.pem')

        # mock 'os.path.exists' so we don't get an error for unexisting file
        # (restored in tearDown)
        os.path.exists = Mock(return_value=True)
        cert_path = pc.get_ca_cert_path()

        self.assertEqual(cert_path, expected_path)

    def test_get_ca_cert_path_about_to_download(self):
        pc = self._provider_config
        pc.get_path_prefix = Mock(return_value='test')

        provider_domain = sample_config['domain']
        expected_path = os.path.join('test', 'leap', 'providers',
                                     provider_domain, 'keys', 'ca',
                                     'cacert.pem')

        cert_path = pc.get_ca_cert_path(about_to_download=True)

        self.assertEqual(cert_path, expected_path)

    def test_get_ca_cert_path_fails(self):
        pc = self._provider_config
        pc.get_path_prefix = Mock(return_value='test')

        # mock 'get_domain' so we don't need to load a config
        provider_domain = 'test.provider.com'
        pc.get_domain = Mock(return_value=provider_domain)

        with self.assertRaises(MissingCACert):
            pc.get_ca_cert_path()

    def test_provides_eip(self):
        # It provides
        pc = self._load_services(['openvpn', 'test_service'])
        self.assertTrue(pc.provides_eip())

        # It does not provide
        pc = self._load_services(['test_service', 'other_service'])
        self.assertFalse(pc.provides_eip())

    def test_provides_mx(self):
        # It provides
        pc = self._load_services(['mx', 'other_service'])
        self.assertTrue(pc.provides_mx())

        # It does not provide
        pc = self._load_services(['test_service', 'other_service'])
        self.assertFalse(pc.provides_mx())

    def test_supports_unknown_service(self):
        pc = self._load_services(['unknown'])
        self.assertNotIn('unknown', get_supported(pc.get_services()))

    def test_provides_unknown_service(self):
        pc = self._load_services(['unknown'])
        self.assertIn('unknown', pc.get_services())

    def test_get_services_string(self):
        pc = self._load_services(
            ['openvpn', 'asdf', 'openvpn', 'not_supported_service'])

        self.assertEqual(pc.get_services_string(),
                         "Encrypted Internet, asdf, Encrypted Internet,"
                         " not_supported_service")
예제 #18
0
    def _download_provider_info(self, *args):
        """
        Downloads the provider.json definition.

        When a provider.json is already stored, it is re-requested from the
        provider's API URI and verified against the stored provider CA. If
        the stored file cannot be loaded or has no CA, it is requested from
        https://<domain>/provider.json with the default ``self.verify``.

        :raises UnsupportedClientVersionError: if APP_VERSION_CHECK is set
            and the provider does not support this client version.
        :raises UnsupportedProviderAPI: if API_VERSION_CHECK is set and the
            provider API version is not supported.

        Errors raised by ``res.raise_for_status()`` are propagated.
        """
        leap_assert(self._domain,
                    "Cannot download provider info without a domain")
        logger.debug("Downloading provider info for %r" % (self._domain))

        # --------------------------------------------------------------
        # TODO factor out with the download routines in services.
        # Watch out! We're handling the verify parameter differently here.

        headers = {}
        domain = self._domain.encode(sys.getfilesystemencoding())
        provider_json = os.path.join(util.get_path_prefix(),
                                     get_provider_path(domain))

        if domain in PinnedProviders.domains() and \
           not os.path.exists(provider_json):
            ca_dir = os.path.join(os.path.dirname(provider_json), "keys", "ca")
            mkdir_p(ca_dir)
            cacert = os.path.join(ca_dir, "cacert.pem")
            PinnedProviders.save_hardcoded(domain, provider_json, cacert)

        mtime = get_mtime(provider_json)

        if self._download_if_needed and mtime:
            headers['if-modified-since'] = mtime

        uri = "https://%s/%s" % (self._domain, "provider.json")
        verify = self.verify

        if mtime:  # the provider.json exists
            # So, we're getting it from the api.* and checking against
            # the provider ca. uri and verify are switched together, and
            # only once both are known: otherwise we would query the api
            # uri while verifying with the default CA.
            provider_config = ProviderConfig()
            if provider_config.load(provider_json):
                try:
                    api_uri = provider_config.get_api_uri() + '/provider.json'
                    ca_cert_path = provider_config.get_ca_cert_path()
                except MissingCACert:
                    # no ca? then download from main domain again.
                    pass
                else:
                    uri, verify = api_uri, ca_cert_path
            else:
                logger.warning("Could not load the stored provider.json, "
                               "downloading it from the main domain.")

        if verify:
            verify = verify.encode(sys.getfilesystemencoding())
        logger.debug("Requesting for provider.json... "
                     "uri: {0}, verify: {1}, headers: {2}".format(
                         uri, verify, headers))
        res = self._session.get(uri.encode('idna'),
                                verify=verify,
                                headers=headers,
                                timeout=REQUEST_TIMEOUT)
        res.raise_for_status()
        logger.debug("Request status code: {0}".format(res.status_code))

        min_client_version = res.headers.get(self.MIN_CLIENT_VERSION, '0')

        # Not modified
        if res.status_code == 304:
            logger.debug("Provider definition has not been modified")
        # --------------------------------------------------------------
        # end refactor, more or less...
        # XXX Watch out, have to check the supported api yet.
        else:
            if flags.APP_VERSION_CHECK:
                # TODO split
                if not provider.supports_client(min_client_version):
                    if self._signaler is not None:
                        self._signaler.signal(
                            self._signaler.prov_unsupported_client)
                    raise UnsupportedClientVersionError()

            provider_definition, mtime = get_content(res)

            provider_config = ProviderConfig()
            provider_config.load(data=provider_definition, mtime=mtime)
            provider_config.save(
                ["leap", "providers", domain, "provider.json"])

            if flags.API_VERSION_CHECK:
                # TODO split
                api_version = provider_config.get_api_version()
                if provider.supports_api(api_version):
                    logger.debug("Provider definition has been modified")
                else:
                    api_supported = ', '.join(provider.SUPPORTED_APIS)
                    error = ('Unsupported provider API version. '
                             'Supported versions are: {0}. '
                             'Found: {1}.').format(api_supported, api_version)

                    logger.error(error)
                    if self._signaler is not None:
                        self._signaler.signal(
                            self._signaler.prov_unsupported_api)
                    raise UnsupportedProviderAPI(error)
    def _download_provider_info(self, *args):
        """
        Downloads the provider.json definition.

        When a provider.json is already stored, it is re-requested from the
        provider's API URI and verified against the stored provider CA. If
        the stored file cannot be loaded or has no CA, it is requested from
        https://<domain>/provider.json with the default ``self.verify``.

        :raises UnsupportedClientVersionError: if APP_VERSION_CHECK is set
            and the provider does not support this client version.
        :raises UnsupportedProviderAPI: if API_VERSION_CHECK is set and the
            provider API version is not supported.

        Errors raised by ``res.raise_for_status()`` are propagated.
        """
        leap_assert(self._domain,
                    "Cannot download provider info without a domain")
        logger.debug("Downloading provider info for %r" % (self._domain))

        # --------------------------------------------------------------
        # TODO factor out with the download routines in services.
        # Watch out! We're handling the verify parameter differently here.

        headers = {}
        domain = self._domain.encode(sys.getfilesystemencoding())
        provider_json = os.path.join(util.get_path_prefix(),
                                     get_provider_path(domain))

        if domain in PinnedProviders.domains() and \
           not os.path.exists(provider_json):
            ca_dir = os.path.join(os.path.dirname(provider_json),
                                  "keys", "ca")
            mkdir_p(ca_dir)
            cacert = os.path.join(ca_dir, "cacert.pem")
            PinnedProviders.save_hardcoded(domain, provider_json, cacert)

        mtime = get_mtime(provider_json)

        if self._download_if_needed and mtime:
            headers['if-modified-since'] = mtime

        uri = "https://%s/%s" % (self._domain, "provider.json")
        verify = self.verify

        if mtime:  # the provider.json exists
            # So, we're getting it from the api.* and checking against
            # the provider ca. uri and verify are switched together, and
            # only once both are known: otherwise we would query the api
            # uri while verifying with the default CA.
            provider_config = ProviderConfig()
            if provider_config.load(provider_json):
                try:
                    api_uri = provider_config.get_api_uri() + '/provider.json'
                    ca_cert_path = provider_config.get_ca_cert_path()
                except MissingCACert:
                    # no ca? then download from main domain again.
                    pass
                else:
                    uri, verify = api_uri, ca_cert_path
            else:
                logger.warning("Could not load the stored provider.json, "
                               "downloading it from the main domain.")

        if verify:
            verify = verify.encode(sys.getfilesystemencoding())
        logger.debug("Requesting for provider.json... "
                     "uri: {0}, verify: {1}, headers: {2}".format(
                         uri, verify, headers))
        res = self._session.get(uri.encode('idna'), verify=verify,
                                headers=headers, timeout=REQUEST_TIMEOUT)
        res.raise_for_status()
        logger.debug("Request status code: {0}".format(res.status_code))

        min_client_version = res.headers.get(self.MIN_CLIENT_VERSION, '0')

        # Not modified
        if res.status_code == 304:
            logger.debug("Provider definition has not been modified")
        # --------------------------------------------------------------
        # end refactor, more or less...
        # XXX Watch out, have to check the supported api yet.
        else:
            if flags.APP_VERSION_CHECK:
                # TODO split
                if not provider.supports_client(min_client_version):
                    # The signaler is optional; only notify when present.
                    if self._signaler is not None:
                        self._signaler.signal(
                            self._signaler.prov_unsupported_client)
                    raise UnsupportedClientVersionError()

            provider_definition, mtime = get_content(res)

            provider_config = ProviderConfig()
            provider_config.load(data=provider_definition, mtime=mtime)
            provider_config.save(["leap", "providers",
                                  domain, "provider.json"])

            if flags.API_VERSION_CHECK:
                # TODO split
                api_version = provider_config.get_api_version()
                if provider.supports_api(api_version):
                    logger.debug("Provider definition has been modified")
                else:
                    api_supported = ', '.join(provider.SUPPORTED_APIS)
                    error = ('Unsupported provider API version. '
                             'Supported versions are: {0}. '
                             'Found: {1}.').format(api_supported, api_version)

                    logger.error(error)
                    if self._signaler is not None:
                        self._signaler.signal(
                            self._signaler.prov_unsupported_api)
                    raise UnsupportedProviderAPI(error)
예제 #20
0
class Wizard(QtGui.QWizard):
    """
    First run wizard to register a user and setup a provider
    """

    INTRO_PAGE = 0
    SELECT_PROVIDER_PAGE = 1
    PRESENT_PROVIDER_PAGE = 2
    SETUP_PROVIDER_PAGE = 3
    REGISTER_USER_PAGE = 4
    SERVICES_PAGE = 5

    WEAK_PASSWORDS = ("123456", "qweasd", "qwerty",
                      "password")

    # Usernames may only contain ASCII letters, digits and underscores;
    # enforced on the username field through a QRegExpValidator.
    BARE_USERNAME_REGEX = r"^[A-Za-z\d_]+$"

    def __init__(self, standalone=False, bypass_checks=False):
        """
        Constructor for the main Wizard.

        :param standalone: If True, the application is running as standalone
                           and the wizard should display some messages
                           according to this.
        :type standalone: bool
        :param bypass_checks: Set to true if the app should bypass the
                              first round of checks for CA certificates at
                              bootstrap
        :type bypass_checks: bool
        """
        QtGui.QWizard.__init__(self)

        self.standalone = standalone

        self.ui = Ui_Wizard()
        self.ui.setupUi(self)

        self.setPixmap(QtGui.QWizard.LogoPixmap,
                       QtGui.QPixmap(":/images/mask-icon.png"))

        self.QUESTION_ICON = QtGui.QPixmap(":/images/Emblem-question.png")
        self.ERROR_ICON = QtGui.QPixmap(":/images/Dialog-error.png")
        self.OK_ICON = QtGui.QPixmap(":/images/Dialog-accept.png")

        # Services ticked by the user, and services that already have a
        # checkbox (so revisiting the services page does not duplicate them).
        self._selected_services = set()
        self._shown_services = set()

        self._show_register = False

        self.ui.grpCheckProvider.setVisible(False)
        self.ui.btnCheck.clicked.connect(self._check_provider)
        self.ui.lnProvider.returnPressed.connect(self._check_provider)

        self._provider_bootstrapper = ProviderBootstrapper(bypass_checks)
        self._provider_bootstrapper.name_resolution.connect(
            self._name_resolution)
        self._provider_bootstrapper.https_connection.connect(
            self._https_connection)
        self._provider_bootstrapper.download_provider_info.connect(
            self._download_provider_info)

        self._provider_bootstrapper.download_ca_cert.connect(
            self._download_ca_cert)
        self._provider_bootstrapper.check_ca_fingerprint.connect(
            self._check_ca_fingerprint)
        self._provider_bootstrapper.check_api_certificate.connect(
            self._check_api_certificate)

        self._domain = None
        self._provider_config = ProviderConfig()

        # We will store a reference to the defers for eventual use
        # (eg, to cancel them) but not doing anything with them right now.
        self._provider_select_defer = None
        self._provider_setup_defer = None

        self.currentIdChanged.connect(self._current_id_changed)

        self.ui.lblPassword.setEchoMode(QtGui.QLineEdit.Password)
        self.ui.lblPassword2.setEchoMode(QtGui.QLineEdit.Password)

        self.ui.lnProvider.textChanged.connect(
            self._enable_check)

        self.ui.lblUser.returnPressed.connect(
            self._focus_password)
        self.ui.lblPassword.returnPressed.connect(
            self._focus_second_password)
        self.ui.lblPassword2.returnPressed.connect(
            self._register)
        self.ui.btnRegister.clicked.connect(
            self._register)

        usernameRe = QtCore.QRegExp(self.BARE_USERNAME_REGEX)
        self.ui.lblUser.setValidator(
            QtGui.QRegExpValidator(usernameRe, self))

        self.page(self.REGISTER_USER_PAGE).setCommitPage(True)

        self._username = None
        self._password = None

        self.page(self.REGISTER_USER_PAGE).setButtonText(
            QtGui.QWizard.CommitButton, self.tr("&Next >"))
        self.page(self.SERVICES_PAGE).setButtonText(
            QtGui.QWizard.FinishButton, self.tr("Connect"))

        # XXX: Temporary removal for enrollment policy
        # https://leap.se/code/issues/2922
        self.ui.label_12.setVisible(False)
        self.ui.lblProviderPolicy.setVisible(False)

    def get_domain(self):
        """
        Returns the provider domain last entered for checking, or None if
        the provider check has been reset.
        """
        return self._domain

    def get_username(self):
        """
        Returns the username of the last successful (or in-progress)
        registration, or None.
        """
        return self._username

    def get_password(self):
        """
        Returns the password of the last successful (or in-progress)
        registration, or None.
        """
        return self._password

    def get_remember(self):
        """
        Returns True if a keyring is available and the user ticked the
        "remember" checkbox.
        """
        return has_keyring() and self.ui.chkRemember.isChecked()

    def get_services(self):
        """
        Returns the set of services the user selected on the services page.
        """
        return self._selected_services

    def _enable_check(self, text):
        """
        SLOT
        TRIGGER: self.ui.lnProvider.textChanged

        Enables the check button only while the provider field is not
        empty, and resets any previous provider check.

        :param text: ignored; the current text is read from the line edit.
        """
        self.ui.btnCheck.setEnabled(len(self.ui.lnProvider.text()) != 0)
        self._reset_provider_check()

    def _focus_password(self):
        """
        Focuses at the password lineedit for the registration page
        """
        self.ui.lblPassword.setFocus()

    def _focus_second_password(self):
        """
        Focuses at the second password lineedit for the registration page
        """
        self.ui.lblPassword2.setFocus()

    def _register(self):
        """
        SLOT
        TRIGGERS:
          self.ui.lblPassword2.returnPressed
          self.ui.btnRegister.clicked

        Performs the registration based on the values provided in the form.
        The actual SRP registration runs in a thread; its result arrives
        through _registration_finished.
        """
        self.ui.btnRegister.setEnabled(False)

        username = self.ui.lblUser.text()
        password = self.ui.lblPassword.text()
        password2 = self.ui.lblPassword2.text()

        ok, msg = basic_password_checks(username, password, password2)
        if ok:
            register = SRPRegister(provider_config=self._provider_config)
            register.registration_finished.connect(
                self._registration_finished)

            threads.deferToThread(
                partial(register.register_user,
                        username.encode("utf8"),
                        password.encode("utf8")))

            self._username = username
            self._password = password
            self._set_register_status(self.tr("Starting registration..."))
        else:
            self._set_register_status(msg, error=True)
            self._focus_password()
            self.ui.btnRegister.setEnabled(True)

    def _set_registration_fields_visibility(self, visible):
        """
        This method hides the username and password labels and inputboxes.

        :param visible: sets the visibility of the widgets
            True: widgets are visible or False: are not
        :type visible: bool
        """
        # username and password inputs
        self.ui.lblUser.setVisible(visible)
        self.ui.lblPassword.setVisible(visible)
        self.ui.lblPassword2.setVisible(visible)

        # username and password labels
        self.ui.label_15.setVisible(visible)
        self.ui.label_16.setVisible(visible)
        self.ui.label_17.setVisible(visible)

        # register button
        self.ui.btnRegister.setVisible(visible)

    def _registration_finished(self, ok, req):
        """
        SLOT
        TRIGGER: SRPRegister.registration_finished

        Updates the registration page with the outcome of the registration.

        :param ok: whether the registration succeeded
        :type ok: bool
        :param req: the server response; on failure its content is parsed
                    as JSON looking for errors["login"][0]
        """
        if ok:
            user_domain = self._username + "@" + self._domain
            message = "<font color='green'><h3>"
            message += self.tr("User %s successfully registered.") % (
                user_domain, )
            message += "</h3></font>"
            self._set_register_status(message)

            self.ui.lblPassword2.clearFocus()
            self._set_registration_fields_visibility(False)

            # Allow the user to remember his password
            if has_keyring():
                self.ui.chkRemember.setVisible(True)
                self.ui.chkRemember.setEnabled(True)

            self.page(self.REGISTER_USER_PAGE).set_completed()
            self.button(QtGui.QWizard.BackButton).setEnabled(False)
        else:
            old_username = self._username
            self._username = None
            self._password = None
            error_msg = self.tr("Unknown error")
            try:
                content, _ = get_content(req)
                json_content = json.loads(content)
                error_msg = json_content.get("errors").get("login")[0]
                # Server messages like "has already been taken" read better
                # prefixed with the username they refer to.
                if not error_msg.istitle():
                    error_msg = "%s %s" % (old_username, error_msg)
            except Exception as e:
                # Any malformed response falls back to "Unknown error".
                logger.error("Unknown error: %r" % (e,))

            self._set_register_status(error_msg, error=True)
            self.ui.btnRegister.setEnabled(True)

    def _set_register_status(self, status, error=False):
        """
        Sets the status label in the registration page to status

        :param status: status message to display, can be HTML
        :type status: str
        :param error: if True, the message is rendered in bold red
        :type error: bool
        """
        if error:
            status = "<font color='red'><b>%s</b></font>" % (status,)
        self.ui.lblRegisterStatus.setText(status)

    def _reset_provider_check(self):
        """
        Resets the UI for checking a provider. Also resets the domain
        in this object.
        """
        self.ui.lblNameResolution.setPixmap(None)
        self.ui.lblHTTPS.setPixmap(None)
        self.ui.lblProviderInfo.setPixmap(None)
        self.ui.lblProviderSelectStatus.setText("")
        self._domain = None
        self.button(QtGui.QWizard.NextButton).setEnabled(False)
        self.page(self.SELECT_PROVIDER_PAGE).set_completed(False)

    def _reset_provider_setup(self):
        """
        Resets the UI for setting up a provider.
        """
        self.ui.lblDownloadCaCert.setPixmap(None)
        self.ui.lblCheckCaFpr.setPixmap(None)
        self.ui.lblCheckApiCert.setPixmap(None)

    def _check_provider(self):
        """
        SLOT
        TRIGGERS:
          self.ui.btnCheck.clicked
          self.ui.lnProvider.returnPressed

        Starts the checks for a given provider
        """
        if len(self.ui.lnProvider.text()) == 0:
            return

        self.ui.grpCheckProvider.setVisible(True)
        self.ui.btnCheck.setEnabled(False)
        self.ui.lnProvider.setEnabled(False)
        self.button(QtGui.QWizard.BackButton).clearFocus()
        self._domain = self.ui.lnProvider.text()

        self.ui.lblNameResolution.setPixmap(self.QUESTION_ICON)
        self._provider_select_defer = self._provider_bootstrapper.\
            run_provider_select_checks(self._domain)

    def _complete_task(self, data, label, complete=False, complete_page=-1):
        """
        Checks a task and completes a page if specified

        :param data: data as it comes from the bootstrapper thread for
                     a specific check
        :type data: dict
        :param label: label that displays the status icon for a
                      specific check that corresponds to the data
        :type label: QtGui.QLabel
        :param complete: if True, it completes the page specified,
                         which must be of type WizardPage
        :type complete: bool
        :param complete_page: page id to complete
        :type complete_page: int
        """
        passed = data[self._provider_bootstrapper.PASSED_KEY]
        error = data[self._provider_bootstrapper.ERROR_KEY]
        if passed:
            label.setPixmap(self.OK_ICON)
            if complete:
                self.page(complete_page).set_completed()
                self.button(QtGui.QWizard.NextButton).setFocus()
        else:
            label.setPixmap(self.ERROR_ICON)
            logger.error(error)

    def _name_resolution(self, data):
        """
        SLOT
        TRIGGER: self._provider_bootstrapper.name_resolution

        Sets the status for the name resolution check
        """
        self._complete_task(data, self.ui.lblNameResolution)
        status = ""
        passed = data[self._provider_bootstrapper.PASSED_KEY]
        if not passed:
            status = self.tr("<font color='red'><b>Non-existent "
                             "provider</b></font>")
        else:
            self.ui.lblHTTPS.setPixmap(self.QUESTION_ICON)
        self.ui.lblProviderSelectStatus.setText(status)
        # Keep the inputs locked while the check chain is still running.
        self.ui.btnCheck.setEnabled(not passed)
        self.ui.lnProvider.setEnabled(not passed)

    def _https_connection(self, data):
        """
        SLOT
        TRIGGER: self._provider_bootstrapper.https_connection

        Sets the status for the https connection check
        """
        self._complete_task(data, self.ui.lblHTTPS)
        status = ""
        passed = data[self._provider_bootstrapper.PASSED_KEY]
        if not passed:
            # Wrap the error in a 1-tuple so a tuple-valued error cannot
            # be mistaken for multiple format arguments.
            status = self.tr("<font color='red'><b>%s</b></font>") \
                % (data[self._provider_bootstrapper.ERROR_KEY],)
            self.ui.lblProviderSelectStatus.setText(status)
        else:
            self.ui.lblProviderInfo.setPixmap(self.QUESTION_ICON)
        self.ui.btnCheck.setEnabled(not passed)
        self.ui.lnProvider.setEnabled(not passed)

    def _download_provider_info(self, data):
        """
        SLOT
        TRIGGER: self._provider_bootstrapper.download_provider_info

        Sets the status for the provider information download
        check. Since this check is the last of this set, it also
        completes the page if passed
        """
        if self._provider_config.load(os.path.join("leap",
                                                   "providers",
                                                   self._domain,
                                                   "provider.json")):
            self._complete_task(data, self.ui.lblProviderInfo,
                                True, self.SELECT_PROVIDER_PAGE)
        else:
            new_data = {
                self._provider_bootstrapper.PASSED_KEY: False,
                self._provider_bootstrapper.ERROR_KEY:
                self.tr("Unable to load provider configuration")
            }
            self._complete_task(new_data, self.ui.lblProviderInfo)

        status = ""
        if not data[self._provider_bootstrapper.PASSED_KEY]:
            status = self.tr("<font color='red'><b>Not a valid provider"
                             "</b></font>")
            self.ui.lblProviderSelectStatus.setText(status)
        self.ui.btnCheck.setEnabled(True)
        self.ui.lnProvider.setEnabled(True)

    def _download_ca_cert(self, data):
        """
        SLOT
        TRIGGER: self._provider_bootstrapper.download_ca_cert

        Sets the status for the download of the CA certificate check
        """
        self._complete_task(data, self.ui.lblDownloadCaCert)
        passed = data[self._provider_bootstrapper.PASSED_KEY]
        if passed:
            self.ui.lblCheckCaFpr.setPixmap(self.QUESTION_ICON)

    def _check_ca_fingerprint(self, data):
        """
        SLOT
        TRIGGER: self._provider_bootstrapper.check_ca_fingerprint

        Sets the status for the CA fingerprint check
        """
        self._complete_task(data, self.ui.lblCheckCaFpr)
        passed = data[self._provider_bootstrapper.PASSED_KEY]
        if passed:
            self.ui.lblCheckApiCert.setPixmap(self.QUESTION_ICON)

    def _check_api_certificate(self, data):
        """
        SLOT
        TRIGGER: self._provider_bootstrapper.check_api_certificate

        Sets the status for the API certificate check. Since this is the
        last check of the provider setup chain, a pass completes the
        setup provider page.
        """
        self._complete_task(data, self.ui.lblCheckApiCert,
                            True, self.SETUP_PROVIDER_PAGE)

    def _service_selection_changed(self, service, state):
        """
        SLOT
        TRIGGER: service_checkbox.stateChanged
        Adds the service to the state if the state is checked, removes
        it otherwise

        :param service: service to handle
        :type service: str
        :param state: state of the checkbox
        :type state: int
        """
        if state == QtCore.Qt.Checked:
            self._selected_services = \
                self._selected_services.union(set([service]))
        else:
            self._selected_services = \
                self._selected_services.difference(set([service]))

    def _populate_services(self):
        """
        Loads the services that the provider provides into the UI for
        the user to enable or disable.

        Services that already have a checkbox are skipped, so calling this
        again (e.g. when the page is revisited) does not duplicate them.
        """
        self.ui.grpServices.setTitle(
            self.tr("Services by %s") %
            (self._provider_config.get_name(),))

        services = get_supported(
            self._provider_config.get_services())

        for service in services:
            try:
                if service not in self._shown_services:
                    checkbox = QtGui.QCheckBox(self)
                    service_label = get_service_display_name(
                        service, self.standalone)
                    checkbox.setText(service_label)

                    self.ui.serviceListLayout.addWidget(checkbox)
                    checkbox.stateChanged.connect(
                        partial(self._service_selection_changed, service))
                    checkbox.setChecked(True)
                    self._shown_services.add(service)
            except ValueError:
                # Translate the template first and interpolate afterwards;
                # passing an already-formatted string to tr() can never
                # match a translation catalog entry.
                logger.error(
                    self.tr("Something went wrong while trying to "
                            "load service %s") % (service,))

    def _current_id_changed(self, pageId):
        """
        SLOT
        TRIGGER: self.currentIdChanged

        Prepares the pages when they appear
        """
        if pageId == self.SELECT_PROVIDER_PAGE:
            self._reset_provider_check()
            self._enable_check("")

        if pageId == self.SETUP_PROVIDER_PAGE:
            self._reset_provider_setup()
            self.page(pageId).setSubTitle(self.tr("Gathering configuration "
                                                  "options for %s") %
                                          (self._provider_config
                                           .get_name(),))
            self.ui.lblDownloadCaCert.setPixmap(self.QUESTION_ICON)
            self._provider_setup_defer = self._provider_bootstrapper.\
                run_provider_setup_checks(self._provider_config)

        if pageId == self.PRESENT_PROVIDER_PAGE:
            self.page(pageId).setSubTitle(self.tr("Description of services "
                                                  "offered by %s") %
                                          (self._provider_config
                                           .get_name(),))

            lang = QtCore.QLocale.system().name()
            self.ui.lblProviderName.setText(
                "<b>%s</b>" %
                (self._provider_config.get_name(lang=lang),))
            self.ui.lblProviderURL.setText(
                "https://%s" % (self._provider_config.get_domain(),))
            self.ui.lblProviderDesc.setText(
                "<i>%s</i>" %
                (self._provider_config.get_description(lang=lang),))

            self.ui.lblServicesOffered.setText(self._provider_config
                                               .get_services_string())
            self.ui.lblProviderPolicy.setText(self._provider_config
                                              .get_enrollment_policy())

        if pageId == self.REGISTER_USER_PAGE:
            self.page(pageId).setSubTitle(self.tr("Register a new user with "
                                                  "%s") %
                                          (self._provider_config
                                           .get_name(),))
            self.ui.chkRemember.setVisible(False)

        if pageId == self.SERVICES_PAGE:
            self._populate_services()

    def _is_need_eip_password_warning(self):
        """
        Returns True if we need to add a warning about eip needing
        administrative permissions to start. That can be either
        because we are running in standalone mode, or because we could
        not find the needed privilege escalation mechanisms being operative.
        """
        return self.standalone or is_missing_policy_permissions()

    def nextId(self):
        """
        Sets the next page id for the wizard based on whether the user
        wants to register a new identity or uses an existing one
        """
        if self.currentPage() == self.page(self.INTRO_PAGE):
            self._show_register = self.ui.rdoRegister.isChecked()

        if self.currentPage() == self.page(self.SETUP_PROVIDER_PAGE):
            if self._show_register:
                return self.REGISTER_USER_PAGE
            else:
                return self.SERVICES_PAGE

        return QtGui.QWizard.nextId(self)
예제 #21
0
            return

        if status_code in self.STATUS_OK:
            self._signaler.signal(self._signaler.srp_registration_finished)
        elif status_code == self.STATUS_TAKEN:
            self._signaler.signal(self._signaler.srp_registration_taken)
        elif status_code == self.STATUS_FORBIDDEN:
            self._signaler.signal(self._signaler.srp_registration_disabled)
        else:
            self._signaler.signal(self._signaler.srp_registration_failed)


if __name__ == "__main__":
    # Manual smoke test: send every 'leap' log record to the console at
    # DEBUG level, then try to register two demo users against the
    # provider definition stored on disk.
    logger = logging.getLogger(name='leap')
    logger.setLevel(logging.DEBUG)
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(console)

    provider = ProviderConfig()

    if provider.load("leap/providers/bitmask.net/provider.json"):
        register = SRPRegister(provider_config=provider)
        print("Registering user...")
        for demo_user in ("test1", "test2"):
            print(register.register_user(demo_user, "sarasaaaa"))
예제 #22
0
class ProviderConfigTest(BaseLeapTest):
    """Tests for ProviderConfig"""

    def setUp(self):
        self._provider_config = ProviderConfig()
        json_string = json.dumps(sample_config)
        self._provider_config.load(data=json_string)

        # At certain points we are going to be replacing this method
        # to avoid creating a file.
        # We need to save the old implementation and restore it in
        # tearDown so we are sure everything is as expected for each
        # test. If we do it inside each specific test, a failure in
        # the test will leave the implementation with the mock.
        self._old_ospath_exists = os.path.exists

    def tearDown(self):
        os.path.exists = self._old_ospath_exists

    def _load_config(self, config):
        """
        Helper that serializes ``config`` to JSON and loads it into the
        provider config under test.
        """
        self._provider_config.load(data=json.dumps(config))

    def _load_services(self, services):
        """
        Helper that loads a copy of sample_config whose 'services' entry
        is replaced by ``services``.
        """
        config = copy.deepcopy(sample_config)
        config['services'] = services
        self._load_config(config)

    def test_configs_ok(self):
        """
        Test if the configs loads ok
        """
        # TODO: this test should go to the BaseConfig tests
        pc = self._provider_config
        self.assertEqual(pc.get_api_uri(), sample_config['api_uri'])
        self.assertEqual(pc.get_api_version(), sample_config['api_version'])
        self.assertEqual(pc.get_ca_cert_fingerprint(),
                         sample_config['ca_cert_fingerprint'])
        self.assertEqual(pc.get_ca_cert_uri(), sample_config['ca_cert_uri'])
        self.assertEqual(pc.get_default_language(),
                         sample_config['default_language'])

        self.assertEqual(pc.get_domain(), sample_config['domain'])
        self.assertEqual(pc.get_enrollment_policy(),
                         sample_config['enrollment_policy'])
        self.assertEqual(pc.get_languages(), sample_config['languages'])

    def test_localizations(self):
        pc = self._provider_config

        self.assertEqual(pc.get_description(lang='en'),
                         sample_config['description']['en'])
        self.assertEqual(pc.get_description(lang='es'),
                         sample_config['description']['es'])

        self.assertEqual(pc.get_name(lang='en'), sample_config['name']['en'])
        self.assertEqual(pc.get_name(lang='es'), sample_config['name']['es'])

    def _localize(self, lang):
        """
        Helper to change default language of the provider config.

        :returns: the config dict that was loaded
        """
        config = copy.deepcopy(sample_config)
        config['default_language'] = lang
        self._load_config(config)

        return config

    def _assert_default_localization(self, lang):
        """
        Helper: make ``lang`` the default language and check that both an
        unknown language ('xx') and no language fall back to it.
        """
        pc = self._provider_config
        config = self._localize(lang)

        default_language = config['default_language']
        default_description = config['description'][default_language]
        default_name = config['name'][default_language]

        self.assertEqual(pc.get_description(lang='xx'), default_description)
        self.assertEqual(pc.get_description(), default_description)

        self.assertEqual(pc.get_name(lang='xx'), default_name)
        self.assertEqual(pc.get_name(), default_name)

    def test_default_localization1(self):
        self._assert_default_localization(sample_config['languages'][0])

    def test_default_localization2(self):
        self._assert_default_localization(sample_config['languages'][1])

    def test_get_ca_cert_path_as_expected(self):
        pc = self._provider_config

        provider_domain = sample_config['domain']
        expected_path = os.path.join('leap', 'providers', provider_domain,
                                     'keys', 'ca', 'cacert.pem')

        # mock 'os.path.exists' so we don't get an error for unexisting file
        # (restored in tearDown)
        os.path.exists = Mock(return_value=True)
        cert_path = pc.get_ca_cert_path()

        self.assertTrue(cert_path.endswith(expected_path))

    def test_get_ca_cert_path_about_to_download(self):
        pc = self._provider_config

        provider_domain = sample_config['domain']
        expected_path = os.path.join('leap', 'providers', provider_domain,
                                     'keys', 'ca', 'cacert.pem')

        cert_path = pc.get_ca_cert_path(about_to_download=True)
        self.assertTrue(cert_path.endswith(expected_path))

    def test_get_ca_cert_path_fails(self):
        pc = self._provider_config

        # mock 'get_domain' so we don't need to load a config
        provider_domain = 'test.provider.com'
        pc.get_domain = Mock(return_value=provider_domain)

        with self.assertRaises(MissingCACert):
            pc.get_ca_cert_path()

    def test_provides_eip(self):
        pc = self._provider_config

        # It provides
        self._load_services(['openvpn', 'test_service'])
        self.assertTrue(pc.provides_eip())

        # It does not provide
        self._load_services(['test_service', 'other_service'])
        self.assertFalse(pc.provides_eip())

    def test_provides_mx(self):
        pc = self._provider_config

        # It provides
        self._load_services(['mx', 'other_service'])
        self.assertTrue(pc.provides_mx())

        # It does not provide
        self._load_services(['test_service', 'other_service'])
        self.assertFalse(pc.provides_mx())

    def test_supports_unknown_service(self):
        pc = self._provider_config

        self._load_services(['unknown'])
        self.assertNotIn('unknown', get_supported(pc.get_services()))

    def test_provides_unknown_service(self):
        pc = self._provider_config

        self._load_services(['unknown'])
        self.assertIn('unknown', pc.get_services())
예제 #23
0
# EDIT THIS --------------------------------------------
user = u"USERNAME"
uuid = u"USERUUID"
_pass = u"USERPASS"
server_url = "https://soledad.server.example.org:2323"
# EDIT THIS --------------------------------------------

secrets_path = "/tmp/%s.secrets" % uuid
local_db_path = "/tmp/%s.soledad" % uuid
cert_file = "/tmp/cacert.pem"
provider_config = '/tmp/cdev.json'


provider = ProviderConfig()
provider.load(provider_config)

soledad = None


def printStuff(r):
    print r


def printErr(err):
    logging.exception(err.value)


def init_soledad(_):
    token = srpauth.get_token()
    print "token", token
    def _download_provider_info(self, *args):
        """
        Downloads the provider.json definition.

        If a provider.json for ``self._domain`` is already cached on disk,
        the request is addressed and verified using it:

        - Its mtime is sent as ``if-modified-since``, but only when
          ``self._download_if_needed`` is set.
        - The request is verified with its CA certificate and sent to its
          API URI.

        On any response other than 304, the downloaded definition is:

        - loaded,
        - saved under leap/providers/<domain>/provider.json,
        - and checked for its API version.

        :param args: ignored. Presumably present so this can be chained as
                     a callback; TODO confirm with callers.
        :raises UnsupportedProviderAPI: if the downloaded definition declares
                                        an API version not supported by
                                        SupportedAPIs.
        """
        leap_assert(self._domain,
                    "Cannot download provider info without a domain")

        logger.debug("Downloading provider info for %s" % (self._domain))

        headers = {}

        # Location of the locally cached definition for this domain.
        provider_json = os.path.join(
            ProviderConfig().get_path_prefix(), "leap", "providers",
            self._domain, "provider.json")
        mtime = get_mtime(provider_json)

        # Only request a conditional download when a cached copy exists.
        if self._download_if_needed and mtime:
            headers['if-modified-since'] = mtime

        uri = "https://%s/%s" % (self._domain, "provider.json")
        # Starts as a bool; replaced below by the provider CA cert path when
        # one is available (presumably requests' `verify` semantics).
        verify = not self._bypass_checks

        if mtime:  # the provider.json exists
            provider_config = ProviderConfig()
            provider_config.load(provider_json)
            try:
                verify = provider_config.get_ca_cert_path()
                # Only reached when the CA cert exists, so the API URI is
                # used only together with the provider's own CA.
                uri = provider_config.get_api_uri() + '/provider.json'
            except MissingCACert:
                # get_ca_cert_path fails if the certificate does not exist.
                pass

        logger.debug("Requesting for provider.json... "
                     "uri: {0}, verify: {1}, headers: {2}".format(
                         uri, verify, headers))
        res = self._session.get(uri, verify=verify,
                                headers=headers, timeout=REQUEST_TIMEOUT)
        # Presumably raises for 4xx/5xx responses (requests-style session).
        res.raise_for_status()
        logger.debug("Request status code: {0}".format(res.status_code))

        # Not modified
        if res.status_code == 304:
            logger.debug("Provider definition has not been modified")
        else:
            provider_definition, mtime = get_content(res)

            provider_config = ProviderConfig()
            provider_config.load(data=provider_definition, mtime=mtime)
            # Note: the new definition is saved before its API version is
            # validated below.
            provider_config.save(["leap",
                                  "providers",
                                  self._domain,
                                  "provider.json"])

            api_version = provider_config.get_api_version()
            if SupportedAPIs.supports(api_version):
                logger.debug("Provider definition has been modified")
            else:
                api_supported = ', '.join(SupportedAPIs.SUPPORTED_APIS)
                error = ('Unsupported provider API version. '
                         'Supported versions are: {}. '
                         'Found: {}.').format(api_supported, api_version)

                logger.error(error)
                raise UnsupportedProviderAPI(error)