def setUp(self):
        """
        Stores a new ProviderBootstrapper in self.pb and keeps
        references to the current ProviderConfig.get_path_prefix,
        load, save and get_api_version implementations.
        """
        self.pb = ProviderBootstrapper()

        # At certain points we are going to be replacing these methods
        # directly in ProviderConfig to be able to catch calls from
        # new ProviderConfig objects inside the methods tested. We
        # need to save the old implementation and restore it in
        # tearDown so we are sure everything is as expected for each
        # test. If we do it inside each specific test, a failure in
        # the test will leave the implementation with the mock.
        self.old_gpp = ProviderConfig.get_path_prefix
        self.old_load = ProviderConfig.load
        self.old_save = ProviderConfig.save
        self.old_api_version = ProviderConfig.get_api_version
 def setUp(self):
     """Stores a new ProviderBootstrapper in self.pb."""
     self.pb = ProviderBootstrapper()
class ProviderBootstrapperTest(BaseLeapTest):
    """
    Tests for ProviderBootstrapper where the provider config, the
    HTTP session and the individual checks are replaced with mocks.
    """

    def setUp(self):
        """
        Creates a fresh ProviderBootstrapper for every test.
        """
        self.pb = ProviderBootstrapper()

    def tearDown(self):
        """
        Nothing to undo: temporary files are removed through
        addCleanup (see _make_temp_cert_path).
        """
        # NOTE(review): presumably kept as an explicit override because
        # the base class expects one -- confirm before removing.
        pass

    def _make_temp_cert_path(self):
        """
        Creates a temporary directory and returns a certificate path
        inside of it. The whole directory is removed when the test
        ends, no matter whether it passed or failed.

        :returns: path where a certificate can be written
        :rtype: str
        """
        # Local import, this helper is the only user of shutil here.
        import shutil

        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        return os.path.join(temp_dir, "mynewcert.pem")

    def _write_known_good_cert(self):
        """
        Writes KNOWN_GOOD_CERT to a temporary file.

        :returns: path of the certificate written
        :rtype: str
        """
        cert_path = self._make_temp_cert_path()
        with open(cert_path, "w") as c:
            c.write(self.KNOWN_GOOD_CERT)
        return cert_path

    def test_name_resolution_check(self):
        # Something highly likely to succeed
        self.pb._domain = "google.com"
        self.pb._check_name_resolution()
        # Something highly likely to fail
        self.pb._domain = "uquhqweuihowquie.abc.def"

        # In python 2.7.4 raises socket.error
        # In python 2.7.5 raises socket.gaierror
        with self.assertRaises((socket.gaierror, socket.error)):
            self.pb._check_name_resolution()

    @deferred()
    def test_run_provider_select_checks(self):
        self.pb._check_name_resolution = mock.MagicMock()
        self.pb._check_https = mock.MagicMock()
        self.pb._download_provider_info = mock.MagicMock()

        d = self.pb.run_provider_select_checks("somedomain")

        def check(*args):
            # Every step has to run exactly once; the later steps
            # receive the (None) result of the previous mock.
            self.pb._check_name_resolution.assert_called_once_with()
            self.pb._check_https.assert_called_once_with(None)
            self.pb._download_provider_info.assert_called_once_with(None)
        d.addCallback(check)
        return d

    @deferred()
    def test_run_provider_setup_checks(self):
        self.pb._download_ca_cert = mock.MagicMock()
        self.pb._check_ca_fingerprint = mock.MagicMock()
        self.pb._check_api_certificate = mock.MagicMock()

        d = self.pb.run_provider_setup_checks(ProviderConfig())

        def check(*args):
            # Every step has to run exactly once; the later steps
            # receive the (None) result of the previous mock.
            self.pb._download_ca_cert.assert_called_once_with()
            self.pb._check_ca_fingerprint.assert_called_once_with(None)
            self.pb._check_api_certificate.assert_called_once_with(None)
        d.addCallback(check)
        return d

    def test_should_proceed_cert(self):
        self.pb._provider_config = mock.Mock()
        self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
            return_value=where("cacert.pem"))

        self.pb._download_if_needed = False
        self.assertTrue(self.pb._should_proceed_cert())

        self.pb._download_if_needed = True
        self.assertFalse(self.pb._should_proceed_cert())

        self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
            return_value=where("somefilethatdoesntexist.pem"))
        self.assertTrue(self.pb._should_proceed_cert())

    def _check_download_ca_cert(self, should_proceed):
        """
        Helper to check different paths easily for the download ca
        cert check

        :param should_proceed: sets the _should_proceed_cert in the
                               provider bootstrapper being tested
        :type should_proceed: bool

        :returns: The contents of the certificate, the expected
                  content depending on should_proceed, and the mode of
                  the file to be checked by the caller
        :rtype: tuple of str, str, int
        """
        old_content = "NOT THE NEW CERT"
        new_content = "NEW CERT"
        new_cert_path = self._make_temp_cert_path()

        with open(new_cert_path, "w") as c:
            c.write(old_content)

        self.pb._provider_config = mock.Mock()
        self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
            return_value=new_cert_path)
        self.pb._domain = "somedomain"

        self.pb._should_proceed_cert = mock.MagicMock(
            return_value=should_proceed)

        # Response.content is a property, so it has to be patched on
        # the class with a PropertyMock.
        with mock.patch('requests.models.Response.content',
                        new_callable=mock.PropertyMock) as \
                content:
            content.return_value = new_content
            response_obj = Response()
            response_obj.raise_for_status = mock.MagicMock()

            self.pb._session.get = mock.MagicMock(return_value=response_obj)
            self.pb._download_ca_cert()
            with open(new_cert_path, "r") as nc:
                read = nc.read()
            mode = stat.S_IMODE(os.stat(new_cert_path).st_mode)

        # The file is only expected to be overwritten when the
        # bootstrapper was told to proceed.
        content_to_check = new_content if should_proceed else old_content
        return read, content_to_check, mode

    def test_download_ca_cert_no_saving(self):
        read, expected_read, mode = self._check_download_ca_cert(False)
        self.assertEqual(read, expected_read)
        self.assertEqual(mode, int("600", 8))

    def test_download_ca_cert_saving(self):
        read, expected_read, mode = self._check_download_ca_cert(True)
        self.assertEqual(read, expected_read)
        self.assertEqual(mode, int("600", 8))

    def test_check_ca_fingerprint_skips(self):
        self.pb._provider_config = mock.Mock()
        self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock(
            return_value="")
        self.pb._domain = "somedomain"

        self.pb._should_proceed_cert = mock.MagicMock(return_value=False)

        self.pb._check_ca_fingerprint()
        self.assertFalse(self.pb._provider_config.
                         get_ca_cert_fingerprint.called)

    def test_check_ca_cert_fingerprint_raises_bad_format(self):
        self.pb._provider_config = mock.Mock()
        self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock(
            return_value="wrongfprformat!!")
        self.pb._domain = "somedomain"

        self.pb._should_proceed_cert = mock.MagicMock(return_value=True)

        with self.assertRaises(WrongFingerprint):
            self.pb._check_ca_fingerprint()

    # These two hashes differ only in the last byte, but that's good
    # enough for the tests
    KNOWN_BAD_HASH = "SHA256: 0f17c033115f6b76ff67871872303ff65034efe" \
                     "7dd1b910062ca323eb4da5c7f"
    KNOWN_GOOD_HASH = "SHA256: 0f17c033115f6b76ff67871872303ff65034ef" \
                      "e7dd1b910062ca323eb4da5c7e"
    KNOWN_GOOD_CERT = """
-----BEGIN CERTIFICATE-----
MIIFbzCCA1egAwIBAgIBATANBgkqhkiG9w0BAQ0FADBKMRgwFgYDVQQDDA9CaXRt
YXNrIFJvb3QgQ0ExEDAOBgNVBAoMB0JpdG1hc2sxHDAaBgNVBAsME2h0dHBzOi8v
Yml0bWFzay5uZXQwHhcNMTIxMTA2MDAwMDAwWhcNMjIxMTA2MDAwMDAwWjBKMRgw
FgYDVQQDDA9CaXRtYXNrIFJvb3QgQ0ExEDAOBgNVBAoMB0JpdG1hc2sxHDAaBgNV
BAsME2h0dHBzOi8vYml0bWFzay5uZXQwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAw
ggIKAoICAQC1eV4YvayaU+maJbWrD4OHo3d7S1BtDlcvkIRS1Fw3iYDjsyDkZxai
dHp4EUasfNQ+EVtXUvtk6170EmLco6Elg8SJBQ27trE6nielPRPCfX3fQzETRfvB
7tNvGw4Jn2YKiYoMD79kkjgyZjkJ2r/bEHUSevmR09BRp86syHZerdNGpXYhcQ84
CA1+V+603GFIHnrP+uQDdssW93rgDNYu+exT+Wj6STfnUkugyjmPRPjL7wh0tzy+
znCeLl4xiV3g9sjPnc7r2EQKd5uaTe3j71sDPF92KRk0SSUndREz+B1+Dbe/RGk4
MEqGFuOzrtsgEhPIX0hplhb0Tgz/rtug+yTT7oJjBa3u20AAOQ38/M99EfdeJvc4
lPFF1XBBLh6X9UKF72an2NuANiX6XPySnJgZ7nZ09RiYZqVwu/qt3DfvLfhboq+0
bQvLUPXrVDr70onv5UDjpmEA/cLmaIqqrduuTkFZOym65/PfAPvpGnt7crQj/Ibl
DEDYZQmP7AS+6zBjoOzNjUGE5r40zWAR1RSi7zliXTu+yfsjXUIhUAWmYR6J3KxB
lfsiHBQ+8dn9kC3YrUexWoOqBiqJOAJzZh5Y1tqgzfh+2nmHSB2dsQRs7rDRRlyy
YMbkpzL9ZsOUO2eTP1mmar6YjCN+rggYjRrX71K2SpBG6b1zZxOG+wIDAQABo2Aw
XjAdBgNVHQ4EFgQUuYGDLL2sswnYpHHvProt1JU+D48wDgYDVR0PAQH/BAQDAgIE
MAwGA1UdEwQFMAMBAf8wHwYDVR0jBBgwFoAUuYGDLL2sswnYpHHvProt1JU+D48w
DQYJKoZIhvcNAQENBQADggIBADeG67vaFcbITGpi51264kHPYPEWaXUa5XYbtmBl
cXYyB6hY5hv/YNuVGJ1gWsDmdeXEyj0j2icGQjYdHRfwhrbEri+h1EZOm1cSBDuY
k/P5+ctHyOXx8IE79DBsZ6IL61UKIaKhqZBfLGYcWu17DVV6+LT+AKtHhOrv3TSj
RnAcKnCbKqXLhUPXpK0eTjPYS2zQGQGIhIy9sQXVXJJJsGrPgMxna1Xw2JikBOCG
htD/JKwt6xBmNwktH0GI/LVtVgSp82Clbn9C4eZN9E5YbVYjLkIEDhpByeC71QhX
EIQ0ZR56bFuJA/CwValBqV/G9gscTPQqd+iETp8yrFpAVHOW+YzSFbxjTEkBte1J
aF0vmbqdMAWLk+LEFPQRptZh0B88igtx6tV5oVd+p5IVRM49poLhuPNJGPvMj99l
mlZ4+AeRUnbOOeAEuvpLJbel4rhwFzmUiGoeTVoPZyMevWcVFq6BMkS+jRR2w0jK
G6b0v5XDHlcFYPOgUrtsOBFJVwbutLvxdk6q37kIFnWCd8L3kmES5q4wjyFK47Co
Ja8zlx64jmMZPg/t3wWqkZgXZ14qnbyG5/lGsj5CwVtfDljrhN0oCWK1FZaUmW3d
69db12/g4f6phldhxiWuGC/W6fCW5kre7nmhshcltqAJJuU47iX+DarBFiIj816e
yV8e
-----END CERTIFICATE-----
"""

    def _prepare_provider_config_with(self, cert_path, cert_hash):
        """
        Mocks the provider config to give the cert_path and cert_hash
        specified

        :param cert_path: path for the certificate
        :type cert_path: str
        :param cert_hash: hash for the certificate as it would appear
                          in the provider config json
        :type cert_hash: str
        """
        self.pb._provider_config = mock.Mock()
        self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock(
            return_value=cert_hash)
        self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
            return_value=cert_path)
        self.pb._domain = "somedomain"

    def test_check_ca_fingerprint_checksout(self):
        cert_path = self._write_known_good_cert()

        self._prepare_provider_config_with(cert_path, self.KNOWN_GOOD_HASH)

        self.pb._should_proceed_cert = mock.MagicMock(return_value=True)

        # Passing means not raising WrongFingerprint.
        self.pb._check_ca_fingerprint()

    def test_check_ca_fingerprint_fails(self):
        cert_path = self._write_known_good_cert()

        self._prepare_provider_config_with(cert_path, self.KNOWN_BAD_HASH)

        self.pb._should_proceed_cert = mock.MagicMock(return_value=True)

        with self.assertRaises(WrongFingerprint):
            self.pb._check_ca_fingerprint()
class ProviderBootstrapperActiveTest(unittest.TestCase):
    """
    Tests for ProviderBootstrapper that talk to a fake provider
    listening on localhost over HTTP and HTTPS.
    """

    @classmethod
    def setUpClass(cls):
        """
        Starts the fake provider and records the ports it listens on
        in cls.http_port and cls.https_port.
        """
        factory = fake_provider.get_provider_factory()
        # NOTE(review): the HTTP port is fixed while the HTTPS one is
        # chosen by the OS (port 0); a busy 8002 will break the whole
        # class. Nothing here stops the listeners afterwards either.
        http = reactor.listenTCP(8002, factory)
        https = reactor.listenSSL(
            0, factory,
            fake_provider.OpenSSLServerContextFactory())
        cls.http_port = http.getHost().port
        cls.https_port = https.getHost().port

    def setUp(self):
        self.pb = ProviderBootstrapper()

        # At certain points we are going to be replacing these methods
        # directly in ProviderConfig to be able to catch calls from
        # new ProviderConfig objects inside the methods tested. We
        # need to save the old implementation and restore it in
        # tearDown so we are sure everything is as expected for each
        # test. If we do it inside each specific test, a failure in
        # the test will leave the implementation with the mock.
        self.old_gpp = ProviderConfig.get_path_prefix
        self.old_load = ProviderConfig.load
        self.old_save = ProviderConfig.save
        self.old_api_version = ProviderConfig.get_api_version

    def tearDown(self):
        """
        Puts back the ProviderConfig methods saved in setUp.
        """
        ProviderConfig.get_path_prefix = self.old_gpp
        ProviderConfig.load = self.old_load
        ProviderConfig.save = self.old_save
        ProviderConfig.get_api_version = self.old_api_version

    def _local_https_uri(self):
        """
        :returns: the https URI of the fake provider
        :rtype: str
        """
        return "https://localhost:%s" % (self.https_port,)

    def _make_path_prefix(self):
        """
        Creates a temporary directory to be used as path prefix. The
        directory is removed when the test ends, no matter whether it
        passed or failed.

        :returns: path of the temporary directory
        :rtype: str
        """
        # Local import, this helper is the only user of shutil here.
        import shutil

        path_prefix = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, path_prefix, ignore_errors=True)
        return path_prefix

    def _download_provider_info_from_local_api(self):
        """
        Runs _download_provider_info with ProviderConfig.get_api_uri
        patched to point to the fake provider.
        """
        with mock.patch.object(
                ProviderConfig, 'get_api_uri',
                return_value=self._local_https_uri()):
            self.pb._download_provider_info()

    def test_check_https_succeeds(self):
        # XXX: Need a proper CA signed cert to test this
        pass

    @deferred()
    def test_check_https_fails(self):
        self.pb._domain = "localhost:%s" % (self.https_port,)

        def check(*args):
            with self.assertRaises(requests.exceptions.SSLError):
                self.pb._check_https()
        return threads.deferToThread(check)

    @deferred()
    def test_second_check_https_fails(self):
        self.pb._domain = "localhost:1234"

        def check(*args):
            with self.assertRaises(Exception):
                self.pb._check_https()
        return threads.deferToThread(check)

    @deferred()
    def test_check_https_succeeds_if_danger(self):
        self.pb._domain = "localhost:%s" % (self.https_port,)
        self.pb._bypass_checks = True

        def check(*args):
            # Passing means not raising.
            self.pb._check_https()

        return threads.deferToThread(check)

    def _setup_provider_config_with(self, api, path_prefix):
        """
        Sets up the ProviderConfig with mocks for the path prefix, the
        api returned and load/save methods.
        It modifies ProviderConfig directly instead of an object
        because the object used is created in the method itself and we
        cannot control that.

        :param api: API to return
        :type api: str
        :param path_prefix: path prefix to be used when calculating
                            paths
        :type path_prefix: str
        """
        ProviderConfig.get_path_prefix = mock.MagicMock(
            return_value=path_prefix)
        ProviderConfig.get_api_version = mock.MagicMock(
            return_value=api)
        ProviderConfig.load = mock.MagicMock()
        ProviderConfig.save = mock.MagicMock()

    def _setup_providerbootstrapper(self, ifneeded):
        """
        Sets the provider bootstrapper's domain to
        localhost:https_port, sets it to bypass https checks and sets
        the download if needed based on the ifneeded value.

        :param ifneeded: Value for _download_if_needed
        :type ifneeded: bool
        """
        self.pb._domain = "localhost:%s" % (self.https_port,)
        self.pb._bypass_checks = True
        self.pb._download_if_needed = ifneeded

    def _produce_dummy_provider_json(self):
        """
        Creates a dummy provider json on disk in order to test
        behaviour around it (download if newer online, etc)

        :returns: the provider.json path used
        :rtype: str
        """
        provider_dir = os.path.join(ProviderConfig()
                                    .get_path_prefix(),
                                    "leap",
                                    "providers",
                                    self.pb._domain)
        mkdir_p(provider_dir)
        provider_path = os.path.join(provider_dir,
                                     "provider.json")

        with open(provider_path, "w") as p:
            p.write("A")
        return provider_path

    def test_download_provider_info_new_provider(self):
        self._setup_provider_config_with("1", self._make_path_prefix())
        self._setup_providerbootstrapper(True)

        self.pb._download_provider_info()
        self.assertTrue(ProviderConfig.save.called)

    @mock.patch(
        'leap.bitmask.config.providerconfig.ProviderConfig.get_ca_cert_path',
        lambda x: where('cacert.pem'))
    def test_download_provider_info_not_modified(self):
        self._setup_provider_config_with("1", self._make_path_prefix())
        self._setup_providerbootstrapper(True)
        provider_path = self._produce_dummy_provider_json()

        # set mtime to something really new
        os.utime(provider_path, (-1, time.time()))

        self._download_provider_info_from_local_api()
        # we check that it doesn't save the provider
        # config, because it's new enough
        self.assertFalse(ProviderConfig.save.called)

    @mock.patch(
        'leap.bitmask.config.providerconfig.ProviderConfig.get_domain',
        lambda x: where('testdomain.com'))
    def test_download_provider_info_not_modified_and_no_cacert(self):
        self._setup_provider_config_with("1", self._make_path_prefix())
        self._setup_providerbootstrapper(True)
        provider_path = self._produce_dummy_provider_json()

        # set mtime to something really new
        os.utime(provider_path, (-1, time.time()))

        self._download_provider_info_from_local_api()
        # we check that it doesn't save the provider
        # config, because it's new enough
        self.assertFalse(ProviderConfig.save.called)

    @mock.patch(
        'leap.bitmask.config.providerconfig.ProviderConfig.get_ca_cert_path',
        lambda x: where('cacert.pem'))
    def test_download_provider_info_modified(self):
        self._setup_provider_config_with("1", self._make_path_prefix())
        self._setup_providerbootstrapper(True)
        provider_path = self._produce_dummy_provider_json()

        # set mtime to something really old
        os.utime(provider_path, (-1, 100))

        self._download_provider_info_from_local_api()
        self.assertTrue(ProviderConfig.load.called)
        self.assertTrue(ProviderConfig.save.called)

    @mock.patch(
        'leap.bitmask.config.providerconfig.ProviderConfig.get_ca_cert_path',
        lambda x: where('cacert.pem'))
    def test_download_provider_info_unsupported_api_raises(self):
        self._setup_provider_config_with("9999999", self._make_path_prefix())
        self._setup_providerbootstrapper(False)
        self._produce_dummy_provider_json()

        with self.assertRaises(UnsupportedProviderAPI):
            self._download_provider_info_from_local_api()

    @mock.patch(
        'leap.bitmask.config.providerconfig.ProviderConfig.get_ca_cert_path',
        lambda x: where('cacert.pem'))
    def test_download_provider_info_unsupported_api(self):
        # Despite its name, this test uses a *supported* API version
        # and passes when nothing is raised.
        self._setup_provider_config_with(SupportedAPIs.SUPPORTED_APIS[0],
                                         self._make_path_prefix())
        self._setup_providerbootstrapper(False)
        self._produce_dummy_provider_json()

        self._download_provider_info_from_local_api()

    @mock.patch(
        'leap.bitmask.config.providerconfig.ProviderConfig.get_api_uri',
        lambda x: 'api.uri')
    @mock.patch(
        'leap.bitmask.config.providerconfig.ProviderConfig.get_ca_cert_path',
        lambda x: '/cert/path')
    def test_check_api_certificate_skips(self):
        self.pb._provider_config = ProviderConfig()
        self.pb._session.get = mock.MagicMock(return_value=Response())

        self.pb._should_proceed_cert = mock.MagicMock(return_value=False)
        self.pb._check_api_certificate()
        self.assertFalse(self.pb._session.get.called)

    @deferred()
    def test_check_api_certificate_fails(self):
        self.pb._provider_config = ProviderConfig()
        self.pb._provider_config.get_api_uri = mock.MagicMock(
            return_value=self._local_https_uri())
        self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
            return_value=os.path.join(
                os.path.split(__file__)[0],
                "wrongcert.pem"))
        self.pb._provider_config.get_api_version = mock.MagicMock(
            return_value="1")

        self.pb._should_proceed_cert = mock.MagicMock(return_value=True)

        def check(*args):
            with self.assertRaises(requests.exceptions.SSLError):
                self.pb._check_api_certificate()
        d = threads.deferToThread(check)
        return d

    @deferred()
    def test_check_api_certificate_succeeds(self):
        self.pb._provider_config = ProviderConfig()
        self.pb._provider_config.get_api_uri = mock.MagicMock(
            return_value=self._local_https_uri())
        self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
            return_value=where('cacert.pem'))
        self.pb._provider_config.get_api_version = mock.MagicMock(
            return_value="1")

        self.pb._should_proceed_cert = mock.MagicMock(return_value=True)

        def check(*args):
            # Passing means not raising.
            self.pb._check_api_certificate()
        d = threads.deferToThread(check)
        return d