def __init__(self, args, loop,
             tasks: List[List[Union[AnyStr, Dict]]]=None):
    """Initialise client state, the VAPID signer and an optional TLS connector.

    :param args: settings object; ``vapid_key`` and ``partner_endpoint_cert``
        are read here and the whole object is kept as ``self.config``.
    :param loop: event loop, stored as ``self.loop``.
    :param tasks: optional list of tasks; an empty list is used when falsy.
    """
    self.config = args
    self.loop = loop
    self.connection = None
    self.pushEndpoint = None
    self.channelID = None
    self.notifications = []
    self.uaid = None
    self.recv = []
    self.tasks = tasks or []
    self.output = None
    self.vapid_cache = {}

    if args.vapid_key:
        # from_file is a classmethod returning a new Vapid; calling it on the
        # class avoids building a throwaway instance first.
        self.vapid = Vapid.from_file(args.vapid_key)
    else:
        self.vapid = Vapid()
        self.vapid.generate_keys()

    self.tls_conn = None
    if args.partner_endpoint_cert:
        # The option may be either a path to a CA file or the PEM data itself.
        if os.path.isfile(args.partner_endpoint_cert):
            context = ssl.create_default_context(
                cafile=args.partner_endpoint_cert)
        else:
            context = ssl.create_default_context(
                cadata=args.partner_endpoint_cert)
        context.verify_mode = ssl.CERT_REQUIRED
        # NOTE(review): hostname verification is disabled, so any certificate
        # chaining to the partner CA is accepted for any host -- confirm this
        # is intended.
        context.check_hostname = False
        self.tls_conn = aiohttp.TCPConnector(ssl_context=context)
def __init__(self, db_name="subscriptors.db", verbose=False):
    """
    Build a pusher, creating the VAPID key files on disk when absent.

    :param db_name: Optional file name ("subscriptors.db" by default) of the
        SQLite database that holds permanent subscriptions. It is only used
        by methods such as ``newSubscription``.
    :type db_name: str
    :param verbose: Optional flag (False by default) that turns the
        "verbose mode" on or off.
    :type verbose: bool
    """
    self.__verbose__ = verbose
    self.__db_name__ = db_name

    private_pem = 'private_key.pem'
    public_pem = 'public_key.pem'

    # Make sure a private key exists before loading it.
    if not os.path.exists(private_pem):
        self.__print__("No private_key.pem file found")
        Vapid().save_key(private_pem)
        self.__print__("private_key.pem file created")

    self.__vapid__ = Vapid(private_pem)

    # Derive the public key file from the loaded private key when missing.
    if not os.path.exists(public_pem):
        self.__print__("No public_key.pem file found")
        self.__vapid__.save_public_key(public_pem)
        self.__print__("public_key.pem file created")

    if verbose:
        self.__print__("PublicKey: %s" % self.getB64PublicKey())
def __init__(self, load_runner, websocket_url, statsd_client, scenario,
             endpoint=None, endpoint_ssl_cert=None, endpoint_ssl_key=None,
             *scenario_args, **scenario_kw):
    """Configure the websocket factory, VAPID signer and optional HTTPS agent.

    :param load_runner: object used later to spawn test plans.
    :param websocket_url: URL for the websocket factory; ``wss`` URLs get a
        TLS client context factory.
    :param statsd_client: metrics client stored for timers/counters.
    :param scenario: the scenario callable run by processors.
    :param endpoint: optional endpoint URL, stored parsed.
    :param endpoint_ssl_cert: optional client cert; when set an ``Agent`` is
        built with ``UnverifiedHTTPS`` and file-like cert/key are rewound.
    :param endpoint_ssl_key: optional client key used with the cert.
    :param scenario_kw: may contain ``vapid_private_key`` and
        ``vapid_claims``.
    """
    logging.debug("Connecting to {}".format(websocket_url))
    self._factory = WebSocketClientFactory(
        websocket_url,
        headers={"Origin": "http://localhost:9000"})
    self._factory.protocol = WSClientProtocol
    self._factory.harness = self
    if websocket_url.startswith("wss"):
        self._factory_context = ssl.ClientContextFactory()
    else:
        self._factory_context = None

    # somewhat bogus encryption headers
    self._crypto_key = "keyid=p256dh;dh=c2VuZGVy"
    self._encryption = "keyid=p256dh;salt=XZwpw6o37R-6qoZjw6KwAw"

    # Processor and Websocket client vars
    self._scenario = scenario
    self._scenario_args = scenario_args
    self._scenario_kw = scenario_kw
    self._processors = 0
    self._ws_clients = {}
    self._connect_waiters = deque()
    self._load_runner = load_runner
    self._stat_client = statsd_client

    # Only build the key-less Vapid when we actually need to generate keys.
    if "vapid_private_key" in self._scenario_kw:
        self._vapid = Vapid(
            private_key=self._scenario_kw.get("vapid_private_key"))
    else:
        self._vapid = Vapid()
        self._vapid.generate_keys()

    self._claims = ()
    if "vapid_claims" in self._scenario_kw:
        self._claims = self._scenario_kw.get("vapid_claims")

    self._endpoint = urlparse.urlparse(endpoint) if endpoint else None
    self._agent = None
    if endpoint_ssl_cert:
        self._agent = Agent(
            reactor,
            contextFactory=UnverifiedHTTPS(
                endpoint_ssl_cert,
                endpoint_ssl_key))
        # Rewind file-like cert/key so they can be read again later.
        if hasattr(endpoint_ssl_cert, 'seek'):
            endpoint_ssl_cert.seek(0)
        if endpoint_ssl_key and hasattr(endpoint_ssl_key, 'seek'):
            endpoint_ssl_key.seek(0)
def gen_application_server_keys():
    """
    Create a brand-new VAPID key pair and persist it.

    The private half is written to ``settings.VAPID_PRIVATE_KEY`` and the
    public half to ``settings.VAPID_PUBLIC_KEY``.
    """
    key_pair = Vapid()
    key_pair.generate_keys()
    key_pair.save_key(settings.VAPID_PRIVATE_KEY)
    key_pair.save_public_key(settings.VAPID_PUBLIC_KEY)
def __init__(self, name, sygnal, config):
    """Validate the webpush pushkin configuration and build its HTTP agent.

    Raises ``PushkinSetupException`` when ``allowed_endpoints`` is not a
    list, when ``vapid_private_key`` is unset, missing on disk or invalid,
    when ``vapid_contact_email`` is unset, or when ``ttl`` is not an int.
    """
    super(WebpushPushkin, self).__init__(name, sygnal, config)

    unknown_fields = self.cfg.keys() - self.UNDERSTOOD_CONFIG_FIELDS
    if unknown_fields:
        logger.warning(
            "The following configuration fields are not understood: %s",
            unknown_fields,
        )

    # Connection pooling, bounded by max_connections.
    self.http_pool = HTTPConnectionPool(reactor=sygnal.reactor)
    self.max_connections = self.get_config(
        "max_connections", DEFAULT_MAX_CONNECTIONS)
    self.connection_semaphore = DeferredSemaphore(self.max_connections)
    self.http_pool.maxPersistentPerHost = self.max_connections

    # use the Sygnal global proxy configuration
    self.http_agent = ProxyAgent(
        reactor=sygnal.reactor,
        pool=self.http_pool,
        contextFactory=ClientTLSOptionsFactory(),
        proxy_url_str=sygnal.config.get("proxy"),
    )
    self.http_agent_wrapper = HttpAgentWrapper(self.http_agent)

    self.allowed_endpoints = None  # type: Optional[List[Pattern]]
    endpoint_globs = self.get_config("allowed_endpoints")
    if endpoint_globs:
        if not isinstance(endpoint_globs, list):
            raise PushkinSetupException(
                "'allowed_endpoints' should be a list or not set")
        self.allowed_endpoints = [glob_to_regex(pattern)
                                  for pattern in endpoint_globs]

    key_path = self.get_config("vapid_private_key")
    if not key_path:
        raise PushkinSetupException(
            "'vapid_private_key' not set in config")
    if not os.path.exists(key_path):
        raise PushkinSetupException(
            "path in 'vapid_private_key' does not exist")
    try:
        self.vapid_private_key = Vapid.from_file(private_key_file=key_path)
    except VapidException as e:
        raise PushkinSetupException(
            "invalid 'vapid_private_key' file") from e

    self.vapid_contact_email = self.get_config("vapid_contact_email")
    if not self.vapid_contact_email:
        raise PushkinSetupException(
            "'vapid_contact_email' not set in config")

    self.ttl = self.get_config("ttl", DEFAULT_TTL)
    if not isinstance(self.ttl, int):
        raise PushkinSetupException("'ttl' must be an int if set")
def create_vapid_headers(vapid_email, subscription_info, vapid_private_key):
    """Create encrypted headers to send to WebPusher.

    Returns the headers produced by signing ``sub``/``aud`` claims with the
    given private key, or None if the email, the key, or the subscription's
    endpoint is missing.
    """
    if not (vapid_email and vapid_private_key
            and ATTR_ENDPOINT in subscription_info):
        return None

    endpoint = urlparse(subscription_info.get(ATTR_ENDPOINT))
    claims = {
        "sub": f"mailto:{vapid_email}",
        "aud": f"{endpoint.scheme}://{endpoint.netloc}",
    }
    return Vapid.from_string(private_key=vapid_private_key).sign(claims)
def create_vapid_headers(vapid_email, subscription_info, vapid_private_key):
    """Create encrypted headers to send to WebPusher.

    Gives back the VAPID headers signed with ``vapid_private_key``, or None
    when the email, the key, or the endpoint entry is absent.
    """
    from py_vapid import Vapid

    ready = (vapid_email and vapid_private_key
             and ATTR_ENDPOINT in subscription_info)
    if not ready:
        return None

    endpoint_url = urlparse(subscription_info.get(ATTR_ENDPOINT))
    subject = "mailto:{}".format(vapid_email)
    audience = "{}://{}".format(endpoint_url.scheme, endpoint_url.netloc)
    signer = Vapid.from_string(private_key=vapid_private_key)
    return signer.sign({"sub": subject, "aud": audience})
def test_init(self):
    """PushClient must load the VAPID key file named by ``args.vapid_key``."""
    args = TrialSettings()
    args.vapid_key = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)
    vapid = Vapid()
    vapid.generate_keys()
    vapid.save_key(args.vapid_key)
    try:
        client = PushClient(loop=self.loop, args=args)
        # Compare the public-number values themselves. Comparing the bound
        # ``encode_point`` methods (as before) checks ``__self__`` by identity
        # on Python 3.8+, so two distinct (but equal) objects never match.
        assert (client.vapid.public_key.public_numbers() ==
                vapid.public_key.public_numbers())
    finally:
        # Remove the temporary key even if the assertion fails.
        os.unlink(args.vapid_key)
def create_vapid_headers(vapid_email, subscription_info, vapid_private_key):
    """Create encrypted headers to send to WebPusher.

    Produces signed VAPID headers for the subscription's endpoint origin, or
    None when any of the email, key or endpoint is unavailable.
    """
    from py_vapid import Vapid
    try:
        from urllib.parse import urlparse
    except ImportError:  # pragma: no cover
        from urlparse import urlparse

    have_everything = (vapid_email and vapid_private_key
                       and ATTR_ENDPOINT in subscription_info)
    if not have_everything:
        return None

    parsed = urlparse(subscription_info.get(ATTR_ENDPOINT))
    claims = {}
    claims['sub'] = 'mailto:{}'.format(vapid_email)
    claims['aud'] = "{}://{}".format(parsed.scheme, parsed.netloc)
    signer = Vapid.from_string(private_key=vapid_private_key)
    return signer.sign(claims)
def generate_vapid_headers(private_key_data, endpoint):
    """
    Build the VAPID headers for a single web push request.

    :param private_key_data: PEM private key as a text string
    :param endpoint: push endpoint URL taken from the subscription info
    :return: headers returned by signing the claims (including Authorization)
    """
    parsed = urlparse(endpoint)
    claims = {
        "aud": "{}://{}".format(parsed.scheme, parsed.netloc),
        # valid for one day from now
        "exp": int(time.time()) + 86400,
        "sub": "mailto:[email protected]",
    }
    signer = Vapid.from_pem(private_key=private_key_data.encode())
    return signer.sign(claims)
class RunnerHarness(object):
    """Runs multiple instances of a single scenario

    Running an instance of the scenario is triggered with :meth:`run`. It
    will run to completion or possibly forever.

    """
    def __init__(self, load_runner, websocket_url, statsd_client, scenario,
                 endpoint=None, endpoint_ssl_cert=None, endpoint_ssl_key=None,
                 *scenario_args, **scenario_kw):
        logging.debug("Connecting to {}".format(websocket_url))
        self._factory = WebSocketClientFactory(
            websocket_url,
            headers={"Origin": "http://localhost:9000"})
        self._factory.protocol = WSClientProtocol
        self._factory.harness = self
        if websocket_url.startswith("wss"):
            self._factory_context = ssl.ClientContextFactory()
        else:
            self._factory_context = None

        # somewhat bogus encryption headers
        self._crypto_key = "keyid=p256dh;dh=c2VuZGVy"
        self._encryption = "keyid=p256dh;salt=XZwpw6o37R-6qoZjw6KwAw"

        # Processor and Websocket client vars
        self._scenario = scenario
        self._scenario_args = scenario_args
        self._scenario_kw = scenario_kw
        self._processors = 0
        self._ws_clients = {}
        self._connect_waiters = deque()
        self._load_runner = load_runner
        self._stat_client = statsd_client

        # Only build the key-less Vapid when keys must be generated.
        if "vapid_private_key" in self._scenario_kw:
            self._vapid = Vapid(
                private_key=self._scenario_kw.get("vapid_private_key"))
        else:
            self._vapid = Vapid()
            self._vapid.generate_keys()

        self._claims = ()
        if "vapid_claims" in self._scenario_kw:
            self._claims = self._scenario_kw.get("vapid_claims")

        self._endpoint = urlparse.urlparse(endpoint) if endpoint else None
        self._agent = None
        if endpoint_ssl_cert:
            self._agent = Agent(
                reactor,
                contextFactory=UnverifiedHTTPS(
                    endpoint_ssl_cert,
                    endpoint_ssl_key))
            # Rewind file-like cert/key so they can be read again later.
            if hasattr(endpoint_ssl_cert, 'seek'):
                endpoint_ssl_cert.seek(0)
            if endpoint_ssl_key and hasattr(endpoint_ssl_key, 'seek'):
                endpoint_ssl_key.seek(0)

    def run(self):
        """Start registered scenario"""
        # Create the processor and start it
        processor = CommandProcessor(self._scenario, self._scenario_args,
                                     self._scenario_kw, self)
        processor.run()
        self._processors += 1

    def spawn(self, test_plan):
        """Spawn a new test plan"""
        self._load_runner.spawn(test_plan)

    def connect(self, processor):
        """Start a connection for a processor and queue it for when the
        connection is available"""
        self._connect_waiters.append(processor)
        connectWS(self._factory, contextFactory=self._factory_context)

    def send_notification(self, processor, url, data, headers=None,
                          claims=None):
        """Send out a notification to a url for a processor

        This uses the older `aesgcm` format.

        ``headers`` and ``claims`` are copied before use, so neither the
        caller's objects nor the harness-wide default claims are modified
        (previously an ``aud`` derived from the first URL stuck to the shared
        claims dict for every later notification).

        """
        headers = dict(headers) if headers else {}
        url = url.encode("utf-8")
        if "TTL" not in headers:
            headers["TTL"] = "0"
        crypto_key = self._crypto_key
        claims = claims or self._claims
        if self._vapid and claims:
            if isinstance(claims, str):
                claims = json.loads(claims)
            else:
                # private copy: 'aud' is added below and the signer may add
                # an 'exp' claim of its own
                claims = dict(claims)
            if "aud" not in claims:
                # Construct a valid `aud` from the known endpoint
                parsed = urlparse.urlparse(url)
                claims["aud"] = "{scheme}://{netloc}".format(
                    scheme=parsed.scheme,
                    netloc=parsed.netloc)
                log.msg("Setting VAPID 'aud' to {}".format(claims["aud"]))
            headers.update(self._vapid.sign(claims, self._crypto_key))
        if data:
            # NOTE(review): "Crypto-key" differs in case from any "Crypto-Key"
            # returned by the signer above, so both may be sent -- confirm.
            headers.update({
                "Content-Type": "application/octet-stream",
                "Content-Encoding": "aesgcm",
                "Crypto-key": crypto_key,
                "Encryption": self._encryption,
            })

        d = treq.post(url,
                      data=data,
                      headers=headers,
                      allow_redirects=False,
                      agent=self._agent)
        d.addCallback(self._sent_notification, processor)
        d.addErrback(self._error_notif, processor)

    def _sent_notification(self, result, processor):
        """Read the response body, then hand everything to the processor."""
        d = result.content()
        d.addCallback(self._finished_notification, result, processor)
        # _error_notif takes (failure, processor); passing the response too
        # made the errback itself raise TypeError.
        d.addErrback(self._error_notif, processor)

    def _finished_notification(self, result, response, processor):
        # Give the fully read content and response to the processor
        processor._send_command_result((response, result))

    def _error_notif(self, failure, processor):
        # Send the failure back
        processor._send_command_result((None, failure))

    def add_client(self, ws_client):
        """Register a new websocket connection and return a waiting
        processor"""
        try:
            processor = self._connect_waiters.popleft()
        except IndexError:
            log.msg("No waiting processors for new client connection.")
            ws_client.sendClose()
        else:
            self._ws_clients[ws_client] = processor
            return processor

    def remove_client(self, ws_client):
        """Remove a websocket connection from the client registry"""
        processor = self._ws_clients.pop(ws_client, None)
        if not processor:
            # Possible failed connection, if we have waiting processors still
            # then try a new connection
            if len(self._connect_waiters):
                connectWS(self._factory, contextFactory=self._factory_context)
            return

    def remove_processor(self):
        """Remove a completed processor"""
        self._processors -= 1

    def timer(self, name, duration):
        """Record a metric timer if we have a statsd client"""
        if self._stat_client is not None:
            self._stat_client.timing(name, duration)

    def counter(self, name, count=1):
        """Record a counter if we have a statsd client"""
        if self._stat_client is not None:
            self._stat_client.increment(name, count)
def webpush(subscription_info,
            data=None,
            vapid_private_key=None,
            vapid_claims=None,
            content_encoding="aesgcm",
            curl=False):
    """
        One call solution to encode and send `data` to the endpoint
        contained in `subscription_info` using optional VAPID auth headers.

        in example:

        .. code-block:: python

        from pywebpush import webpush

        webpush(
            subscription_info={
                "endpoint": "https://push.example.com/v1/abcd",
                "keys": {"p256dh": "0123abcd...",
                         "auth": "001122..."}
                 },
            data="Mary had a little lamb, with a nice mint jelly",
            vapid_private_key="path/to/key.pem",
            vapid_claims={"sub": "*****@*****.**"}
            )

        No additional method call is required. Any non-success will throw a
        `WebPushException`.

    :param subscription_info: Provided by the client call
    :type subscription_info: dict
    :param data: Serialized data to send
    :type data: str
    :param vapid_private_key: Path to vapid private key PEM or encoded str
    :type vapid_private_key: str
    :param vapid_claims: Dictionary of claims ('sub' required). It is not
        modified; a missing 'aud' is derived from the endpoint on a copy.
    :type vapid_claims: dict
    :param content_encoding: Optional content type string
    :type content_encoding: str
    :param curl: Return as "curl" string instead of sending
    :type curl: bool
    :return requests.Response or string
    :raises WebPushException: when claims are given without a private key,
        or when the push service answers with a status above 202.

    """
    vapid_headers = None
    if vapid_claims:
        # Work on a copy: adding 'aud' here (and the signer's 'exp') to the
        # caller's dict would leak a stale audience/expiry into later calls
        # that reuse the same claims for other endpoints.
        vapid_claims = dict(vapid_claims)
        if not vapid_claims.get('aud'):
            url = urlparse(subscription_info.get('endpoint'))
            aud = "{}://{}".format(url.scheme, url.netloc)
            vapid_claims['aud'] = aud
        if not vapid_private_key:
            raise WebPushException("VAPID dict missing 'private_key'")
        if os.path.isfile(vapid_private_key):
            # Presume that key from file is handled correctly by
            # py_vapid.
            vv = Vapid.from_file(
                private_key_file=vapid_private_key)  # pragma no cover
        else:
            vv = Vapid.from_raw(private_raw=vapid_private_key.encode())
        vapid_headers = vv.sign(vapid_claims)
    result = WebPusher(subscription_info).send(
        data,
        vapid_headers,
        content_encoding=content_encoding,
        curl=curl,
    )
    if not curl and result.status_code > 202:
        raise WebPushException("Push failed: {}: {}".format(
            result, result.text))
    return result
def _ask_yes_no(prompt):
    """Prompt until the reply starts with 'y' or 'n' (an empty reply is 'y')."""
    try:
        read_line = raw_input  # Python 2: plain input() would eval the reply
    except NameError:
        read_line = input  # Python 3
    answer = None
    while answer not in ('y', 'n'):
        answer = (read_line(prompt) or 'y').lower()[0]
    return answer


def main():
    """Interactive VAPID helper.

    Makes sure ``private_key.pem`` (and optionally ``public_key.pem``) exist
    in the current directory, then, when ``--sign`` names a JSON claims
    file, prints the headers obtained by signing those claims.
    """
    parser = argparse.ArgumentParser(description="VAPID tool")
    parser.add_argument('--sign', '-s', help='claims file to sign')
    parser.add_argument('--validate', '-v', help='dashboard token to validate')
    args = parser.parse_args()

    if not os.path.exists('private_key.pem'):
        print("No private_key.pem file found.")
        if _ask_yes_no("Do you want me to create one for you? (Y/n)") == 'n':
            print("Sorry, can't do much for you then.")
            # Previously a bare ``exit`` (a no-op) fell through and created
            # the key anyway.
            return
        Vapid().save_key('private_key.pem')
    vapid = Vapid('private_key.pem')

    if not os.path.exists('public_key.pem'):
        print("No public_key.pem file found. You'll need this to access ")
        print("the developer dashboard.")
        if _ask_yes_no("Do you want me to create one for you? (Y/n)") == 'y':
            vapid.save_public_key('public_key.pem')

    claim_file = args.sign
    if not claim_file:
        return
    if not os.path.exists(claim_file):
        print("No %s file found." % claim_file)
        print("""
The claims file should be a JSON formatted file that holds the
information that describes you. There are three elements in the claims
file you'll need:

    "sub" This is your site's admin email address
          (e.g. "mailto:[email protected]")
    "exp" This is the expiration time for the claim in seconds. If you don't
          have one, I'll add one that expires in 24 hours.

You're also welcome to add additional fields to the claims which could be
helpful for the Push Service operations team to pass along to your
operations team (e.g. "ami-id": "e-123456", "cust-id": "a3sfa10987").
Remember to keep these values short to prevent some servers from rejecting
the transaction due to overly large headers. See https://jwt.io/introduction/
for details.

For example, a claims.json file could contain:

{"sub": "mailto:[email protected]"}
""")
        # Nothing to sign; previously fell through and failed on open().
        return
    try:
        with open(claim_file) as claims_fp:
            claims = json.load(claims_fp)
        result = vapid.sign(claims)
    except Exception as exc:
        print("Crap, something went wrong: %s" % repr(exc))
        raise  # bare raise keeps the original traceback
    print("Include the following headers in your request:\n")
    for key, value in result.items():
        print("%s: %s" % (key, value))
    print("\n")
# Push ``message`` to every subscription stored in /data/subs.db.
# NOTE(review): relies on ``config`` (keys 'sub_email', 'vapid_key') and
# ``message`` being defined earlier in the script -- confirm.
conn = sqlite3.Connection('/data/subs.db')
try:
    # Loop-invariant: parse the signing key and build the subject once
    # instead of once per subscription row.
    vv = Vapid.from_string(config['vapid_key'])
    sub_claim = f'mailto:{config["sub_email"]}'
    ttl_seconds = 12 * 60 * 60

    c = conn.cursor()
    for i, row in enumerate(c.execute('SELECT * FROM subs')):
        try:
            # Manually recreate the push facade from the pywebpush API to be
            # able to specify both TTL and urgency
            subscription_info = json.loads(row[0])
            pusher = WebPusher(subscription_info)
            url = urlparse(subscription_info['endpoint'])
            vapid_claims = {
                'sub': sub_claim,
                'aud': "{}://{}".format(url.scheme, url.netloc),
                'exp': int(time.time()) + ttl_seconds,
            }
            headers = vv.sign(vapid_claims)
            # Define the urgency to be "normal", corresponding to messages
            # being delivered while the device is "On neither power nor wifi".
            # https://tools.ietf.org/html/draft-ietf-webpush-protocol-12#section-5.3
            headers['Urgency'] = 'normal'
            resp = pusher.send(message, headers, ttl=ttl_seconds)
            # TODO: Handle cases where response status code is not 201.
            logging.debug('%s (%s: %s): %s',
                          i, resp.status_code, resp.text, subscription_info)
        except Exception as e:
            # Best effort: one bad subscription must not stop the others.
            logging.warning('%s (failed): %s', i, e)
finally:
    # The connection object is always set here, so close unconditionally.
    conn.close()
"p256dh": "BK7h-R0UgDeT89jhWi76-FlTtlEr3DbVBnrr34qmK91Husli_Fazu7vo7kW1mg9F_qhNzrs2glbrc6wfqGFsXks=", "auth": "CyOHiGNXPcT5Slo9UMx2uA==" } } data = 'aaaaaa' vapid_private_key = PRIVATE_KEY = ''' MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg2xyYpqQhOaIdSbqH UwM3+ySvF47MoJyAFUNaHM7g/zOhRANCAAT554ztzCpjiIFOxNfEIicSzNPOZTIB Y1+CGl+LDfM5RlUNERFdfZYRqMmwvX7ydq7UiASkspWqdVVKZnLCzPD3 '''.strip() vapid_claims = {"sub": "mailto:[email protected]"} vapid_headers = None if vapid_claims: if not vapid_claims.get('aud'): url = urlparse(subscription_info.get('endpoint')) aud = "{}://{}".format(url.scheme, url.netloc) vapid_claims['aud'] = aud if os.path.isfile(vapid_private_key): vv = Vapid.from_file( private_key_file=vapid_private_key) # pragma no cover else: vv = Vapid.from_string(private_key=vapid_private_key) vapid_headers = vv.sign(vapid_claims) result = WebPusher(subscription_info, requests_session).send(data, vapid_headers) print result.text
def webpush(subscription_info,
            data=None,
            vapid_private_key=None,
            vapid_claims=None,
            content_encoding="aesgcm",
            curl=False,
            timeout=None):
    """
        One call solution to encode and send `data` to the endpoint
        contained in `subscription_info` using optional VAPID auth headers.

        in example:

        .. code-block:: python

        from pywebpush import webpush

        webpush(
            subscription_info={
                "endpoint": "https://push.example.com/v1/abcd",
                "keys": {"p256dh": "0123abcd...",
                         "auth": "001122..."}
                 },
            data="Mary had a little lamb, with a nice mint jelly",
            vapid_private_key="path/to/key.pem",
            vapid_claims={"sub": "*****@*****.**"}
            )

        No additional method call is required. Any non-success will throw a
        `WebPushException`.

    :param subscription_info: Provided by the client call
    :type subscription_info: dict
    :param data: Serialized data to send
    :type data: str
    :param vapid_private_key: Path to vapid private key PEM or encoded str
    :type vapid_private_key: str
    :param vapid_claims: Dictionary of claims ('sub' required). It is not
        modified; a missing 'aud' is derived from the endpoint on a copy.
    :type vapid_claims: dict
    :param content_encoding: Optional content type string
    :type content_encoding: str
    :param curl: Return as "curl" string instead of sending
    :type curl: bool
    :param timeout: POST requests timeout
    :type timeout: float or tuple
    :return requests.Response or string
    :raises WebPushException: when claims are given without a private key,
        or when the push service answers with a status above 202.

    """
    vapid_headers = None
    if vapid_claims:
        # Work on a copy: adding 'aud' here (and the signer's 'exp') to the
        # caller's dict would leak a stale audience/expiry into later calls
        # that reuse the same claims for other endpoints.
        vapid_claims = dict(vapid_claims)
        if not vapid_claims.get('aud'):
            url = urlparse(subscription_info.get('endpoint'))
            aud = "{}://{}".format(url.scheme, url.netloc)
            vapid_claims['aud'] = aud
        if not vapid_private_key:
            raise WebPushException("VAPID dict missing 'private_key'")
        if os.path.isfile(vapid_private_key):
            # Presume that key from file is handled correctly by
            # py_vapid.
            vv = Vapid.from_file(
                private_key_file=vapid_private_key)  # pragma no cover
        else:
            vv = Vapid.from_raw(private_raw=vapid_private_key.encode())
        vapid_headers = vv.sign(vapid_claims)
    result = WebPusher(subscription_info).send(
        data,
        vapid_headers,
        content_encoding=content_encoding,
        curl=curl,
        timeout=timeout,
    )
    if not curl and result.status_code > 202:
        raise WebPushException("Push failed: {}: {}".format(
            result, result.text))
    return result
class Pusher:
    """
    Pusher objects allow you to integrate Web Push Notifications into
    your project.

    Instantiate this class to integrate Web Push Notifications into your
    server. Objects of this class will create your public and private key,
    track your subscriptions, notify your clients, and do all the required
    work for you.

    e.g.

    >>> from solidwebpush import Pusher
    >>>
    >>> pusher = Pusher()
    >>>
    >>> #what's my base64-encoded public key?
    >>> print(pusher.getB64PublicKey())
    >>>
    >>> subscription = "{Alice's serviceWorker subscription object}"
    >>>
    >>> #notify Alice
    >>> pusher.sendNotification(subscription, "Hello World!")
    >>>
    >>> #or
    >>> #permanently subscribe Alice
    >>> pusher.newSubscription(alice_session_id, subscription)
    >>>
    >>> #so that, from now on we can notify her by
    >>> pusher.notify(alice_session_id, "Hello World")
    >>>
    >>> #or notify all the permanently subscribed clients
    >>> pusher.notifyAll("Hello World")

    (for more "toy" examples visit
    https://github.com/sergioburdisso/solidwebpush/tree/master/examples)
    """

    __vapid__ = None
    __verbose__ = False
    __db_name__ = None
    __db_conn__ = None
    __db__ = None
    __pool__ = None
    # Captures the scheme://host[:port] part of an endpoint URL; it is used
    # as the VAPID "aud" claim in __send__.
    __RE_URL__ = r"(https?://(?:[\w-]+\.)*[\w-]+(?::\d+)?)(?:/.*)?"

    def __init__(self, db_name="subscriptors.db", verbose=False):
        """
        Class constructor.

        :param db_name: The [optional] name ("subscriptors.db" by default) of
                        the file in which subscriptions will be stored in.
                        This is only required if methods like
                        ``newSubscription`` will be used.
        :type db_name: str
        :param verbose: An optional value, to enable or disable the "verbose
                        mode" (False by default)
        :type verbose: bool
        """
        self.__verbose__ = verbose
        self.__db_name__ = db_name

        if not os.path.exists('private_key.pem'):
            self.__print__("No private_key.pem file found")
            Vapid().save_key('private_key.pem')
            self.__print__("private_key.pem file created")

        self.__vapid__ = Vapid('private_key.pem')
        if not os.path.exists('public_key.pem'):
            self.__print__("No public_key.pem file found")
            self.__vapid__.save_public_key('public_key.pem')
            self.__print__("public_key.pem file created")

        if verbose:
            self.__print__("PublicKey: %s" % self.getB64PublicKey())

    def __getstate__(self):
        """Class state getter.

        Drops the worker pool and the SQLite connection/cursor, which cannot
        be pickled, so the instance can be sent to worker processes.
        """
        self_dict = self.__dict__.copy()
        # Remove each key independently: a missing one must not stop the
        # remaining handles from being dropped.
        for key in ('__pool__', '__db_conn__', '__db__'):
            self_dict.pop(key, None)
        return self_dict

    def __call__(self, subscription, data):
        """Class instances callable."""
        self.__send__(subscription, data)

    def __print__(self, msg):
        """Verbose print wrapper."""
        print("[ SolidWebPush ] %s" % msg)

    def __b64rpad__(self, b64str):
        """Base64 right (=)padding.

        Appends exactly enough ``=`` characters to make the length a
        multiple of 4 (the old slice added three when one was needed).
        """
        return b64str + b"=" * (-len(b64str) % 4)

    def __encrypt__(self, user_publickey, user_auth, payload):
        """Encrypt the given payload."""
        user_publickey = user_publickey.encode("utf8")
        raw_user_publickey = base64.urlsafe_b64decode(
            self.__b64rpad__(user_publickey)
        )

        user_auth = user_auth.encode("utf8")
        raw_user_auth = base64.urlsafe_b64decode(self.__b64rpad__(user_auth))

        salt = os.urandom(16)

        curve = pyelliptic.ECC(curve="prime256v1")
        curve_id = base64.urlsafe_b64encode(curve.get_pubkey()[1:])

        http_ece.keys[curve_id] = curve
        http_ece.labels[curve_id] = "P-256"

        encrypted = http_ece.encrypt(
            payload.encode('utf8'),
            keyid=curve_id,
            dh=raw_user_publickey,
            salt=salt,
            authSecret=raw_user_auth,
            version="aesgcm"
        )

        return {
            'dh': base64.urlsafe_b64encode(
                curve.get_pubkey()
            ).strip(b'=').decode("utf-8"),
            'salt': base64.urlsafe_b64encode(
                salt
            ).strip(b'=').decode("utf-8"),
            'body': encrypted
        }

    def __send__(self, subscription, data):
        """Encrypt and send the data to the Message Server."""
        if __is_valid_json__(subscription):
            subscription = json.loads(subscription)
        else:
            raise SubscriptionError()

        if isinstance(data, dict):
            data = json.dumps(data)

        base_url = re.search(
            self.__RE_URL__,
            subscription["endpoint"]
        ).group(1)

        encrypted = self.__encrypt__(
            subscription["keys"]["p256dh"],
            subscription["keys"]["auth"],
            data
        )

        jwt_payload = {
            "aud": base_url,
            # JWT "exp" must be a NumericDate (a number), not a string.
            "exp": int(time.time()) + 60 * 60 * 12,
            "sub": "mailto:[email protected]"
        }

        headers = self.__vapid__.sign(jwt_payload)
        headers["TTL"] = str(43200)
        headers["Content-Type"] = "application/octet-stream"
        headers['Content-Encoding'] = "aesgcm"
        headers['Encryption'] = "salt=%s" % encrypted["salt"]
        headers["Crypto-Key"] = "dh=%s;p256ecdsa=%s" % (
            encrypted["dh"],
            self.getUrlB64PublicKey()
        )

        r = requests.post(
            subscription["endpoint"],
            data=encrypted["body"],
            headers=headers
        )

        if self.__verbose__:
            self.__print__(
                "Message Server response was: \nStatus: %d\nBody: %s"
                % (r.status_code, r.text)
            )

    def setVerbose(self, value):
        """
        Verbose mode.

        Enable and disable the verbose mode (disabled by default).
        When verbose mode is active, some internal messages are going to be
        displayed, as well as the responses from the Message Server.

        :param value: True to enable or False to disable
        :type value: bool
        """
        self.__verbose__ = value

    def getPublicKey(self):
        """
        Raw public key getter.

        :returns: the raw public key
        :rtype: str
        """
        return b"\x04" + self.__vapid__.public_key.to_string()

    def getPrivateKey(self):
        """
        Raw private key getter.

        (probably you won't care about private key at all)

        :returns: the raw private key
        :rtype: str
        """
        return self.__vapid__.private_key.to_string()

    def getB64PublicKey(self):
        """
        Base64 public key getter.

        Returns the string you're going to use when subscribing your
        serviceWorker.
        (as long as you're planning to decode it using JavaScript's
        ``atob`` function)

        :returns: Base64-encoded version of the public key
        :rtype: str
        """
        return base64.b64encode(self.getPublicKey()).decode("utf-8")

    def getB64PrivateKey(self):
        """
        Base64 private key getter.

        (probably you won't care about private key at all)

        :returns: Base64-encoded version of the private key
        :rtype: str
        """
        return base64.b64encode(self.getPrivateKey()).decode("utf-8")

    def getUrlB64PublicKey(self):
        """
        Url-Safe Base64 public key getter.

        This is the string you're going to use when subscribing your
        serviceWorker.
        (so long as you're planning to decode it using a function like
        ``urlB64ToUint8Array`` from
        https://developers.google.com/web/fundamentals/getting-started/codelabs/push-notifications/)

        :returns: URLSafe-Base64-encoded version of the public key
        :rtype: str
        """
        return base64.urlsafe_b64encode(
            self.getPublicKey()
        ).strip(b"=").decode("utf-8")

    def getUrlB64PrivateKey(self):
        """
        Url-Safe Base64 private key getter.

        (probably you won't care about private key at all)

        :returns: URLSafe-Base64-encoded version of the private key
        :rtype: str
        """
        return base64.urlsafe_b64encode(
            self.getPrivateKey()
        ).strip(b"=").decode("utf-8")

    def sendNotification(self, subscription, data, nonblocking=False):
        """
        Send the data to the Message Server.

        Pushes a notification carrying ``data`` to
        the client associated with the ``subscription`` object.
        If ``nonblocking`` is True, the program won't block waiting
        for the message to be completely sent. The ``wait()`` method
        should be used instead. (see ``wait()`` for more details)

        :param subscription: the client's subscription JSON object
        :type subscription: str
        :param data: A string or a dict object to be sent.
                     The dict will be automatically converted into a JSON
                     string before being sent.
                     An example of a dict object would be:
                     ``{"title": "hey Bob!", "body": "you rock"}``
        :type data: str or dict
        :param nonblocking: Whether to block the caller until this method
                            finishes running or not.
        :type nonblocking: bool
        """
        self.sendNotificationToAll(
            [subscription], data, nonblocking=nonblocking, processes=1
        )

    def sendNotificationToAll(
        self, subscriptions, data, nonblocking=False, processes=None
    ):
        """
        Send the data to the Message Server.

        Pushes a notification carrying ``data`` to
        each of the clients associated with the list of ``subscriptions``.
        If ``nonblocking`` is True, the program won't block waiting
        for all the messages to be completely sent. The ``wait()`` method
        should be used instead. (see ``wait()`` for more details)

        :param subscriptions: The list of client's subscription JSON object
        :type subscriptions: list
        :param data: A string or a dict object to be sent.
                     The dict will be automatically converted into a JSON
                     string before being sent.
                     An example of a dict object would be:
                     ``{"title": "hey Bob!", "body": "you rock"}``
        :type data: str or dict
        :param processes: The [optional] number of worker processes to use.
                          If processes is not given then the number returned
                          by os.cpu_count() is used.
        :type processes: int
        :param nonblocking: Whether to block the caller until this method
                            finishes running or not.
        :type nonblocking: bool
        """
        if not self.__pool__:
            self.__pool__ = Pool(processes)

        if nonblocking:
            pool_apply = self.__pool__.apply_async
        else:
            pool_apply = self.__pool__.apply

        # ``self`` is callable (see __call__) and is pickled into the worker.
        for subscription in subscriptions:
            pool_apply(self, args=(subscription, data))

        if not nonblocking:
            self.__pool__.close()
            self.__pool__.join()
            self.__pool__ = None

    def wait(self):
        """
        Wait for all the messages to be completely sent.

        Block the program and wait for all the notifications to be sent,
        before continuing.
        This only has an effect if there was a previous call to a method
        with the ``nonblocking`` parameter set to ``True``, as shown in the
        following example; otherwise it returns immediately:

        >>> pusher.sendNotificationToAll(
                listOfsubscriptions,
                "Hello World",
                nonblocking=True
            )
        >>> # Maybe some other useful computation here
        >>> pusher.wait()
        """
        if self.__pool__ is None:
            # Nothing is pending; previously this raised AttributeError.
            return
        self.__pool__.close()
        self.__pool__.join()
        self.__pool__ = None

    @__database__
    def newSubscription(self, session_id, subscription, group_id=0):
        """
        newSubscription(session_id, subscription, group_id=0)

        Permanently subscribe a client.

        Subscribes the client by permanently storing its ``subscription`` and
        group id (``group_id``).
        This will allow you to push notifications using the client id
        (``session_id``) instead of its ``subscription`` object.

        Groups help you organize subscribers. For instance, suppose you
        want to notify Bob by sending a notification to all of his devices.
        If you previously subscribed each one of his devices to the same
        group let's say 13, then calling notifyAll with 13 will push
        notifications to all of them:

        >>> BobsGroup = 13
        >>> ...
        >>> pusher.newSubscription(
                BobsTabletSessionId,
                subscription0,
                BobsGroup
            )
        >>> ...
        >>> pusher.newSubscription(
                BobsLaptopSessionId,
                subscription1,
                BobsGroup
            )
        >>> ...
        >>> pusher.newSubscription(
                BobsMobileSessionId,
                subscription2,
                BobsGroup
            )
        >>> ...
        >>> pusher.notifyAll(BobsGroup)

        :param session_id: The client's identification (e.g. a cookie or
                           other session token)
        :type session_id: str
        :param subscription: The client's subscription JSON object
        :type subscription: str
        :param group_id: an optional Group ID value (0 by default)
        :type group_id: int
        """
        if not __is_valid_json__(subscription):
            raise SubscriptionError()

        if not session_id and session_id != 0:
            raise SesionIDError("session_id cannot be empty")

        if not self.getSubscription(session_id):
            # The same subscription stored under another session id is
            # replaced rather than duplicated.
            old_session_id = self.getIdSession(subscription)
            if old_session_id:
                self.removeSubscription(old_session_id)

            self.__db__.execute(
                "INSERT INTO subscriptors (session_id, subscription, group_id)"
                " VALUES (?,?,?)",
                (session_id, subscription, group_id)
            )
        else:
            self.__db__.execute(
                "UPDATE subscriptors SET subscription=?, group_id=? WHERE"
                " session_id=?",
                (subscription, group_id, session_id,)
            )
        self.__db_conn__.commit()

    @__database__
    def removeSubscription(self, session_id):
        """
        removeSubscription(session_id)

        Permanently unsubscribes a client.

        Unsubscribes the client by permanently removing its ``subscription``
        and group id.

        :param session_id: The client's identification (e.g. a cookie or
                           other session token)
        :type session_id: str
        """
        self.__db__.execute(
            "DELETE FROM subscriptors WHERE session_id = ?",
            (session_id,)
        )
        self.__db_conn__.commit()

    @__database__
    def notify(self, session_id, data, nonblocking=False):
        """
        notify(session_id, data, nonblocking=False)

        Notify a given client.

        Pushes a notification carrying ``data`` to
        the client associated with the ``session_id``.
        ``session_id`` is the value passed to the
        ``newSubscription`` method when storing the client's
        subscription object.

        :param session_id: The client's identification (e.g. a cookie or
                           other session token)
        :type session_id: str
        :param data: A string or a dict object to be sent.
                     The dict will be automatically converted into a JSON
                     string before being sent.
                     An example of a dict object would be:
                     ``{"title": "hey Bob!", "body": "you rock"}``
        :type data: str or dict
        :param nonblocking: Whether to block the caller until this method
                            finishes running or not.
        :type nonblocking: bool
        :raises SesionIDError: if ``session_id`` has no stored subscription.
        """
        # Look the subscription up once instead of querying twice.
        subscription = self.getSubscription(session_id)
        if not subscription:
            raise SesionIDError(
                "the given session_id '%s' does not exist "
                "(it has not been subscribed yet)." % session_id
            )
        self.sendNotification(subscription, data, nonblocking=nonblocking)

    @__database__
    def notifyAll(self, data, group_id=None, exceptions=None,
                  nonblocking=False):
        """
        notifyAll(data, group_id=None, exceptions=None, nonblocking=False)

        Notify a group of clients.

        When no ``group_id`` is given, notify all subscribers (except for
        those in ``exceptions``). Otherwise, it only notifies all members of
        the ``group_id`` group (except for those in ``exceptions``).

        :param data: A string or a dict object to be sent.
                     The dict will be automatically converted into a JSON
                     string before being sent.
                     An example of a dict object would be:
                     ``{"title": "hey Bob!", "body": "you rock"}``
        :type data: str or dict
        :param group_id: an optional Group ID value (None, meaning all
                         groups, by default)
        :type group_id: int
        :param exceptions: The [optional] list of session ids to be excluded.
        :type exceptions: list
        :param nonblocking: Whether to block the caller until this method
                            finishes running or not.
        :type nonblocking: bool
        """
        excluded = set(exceptions or ())

        # Bind group_id as a parameter: concatenating it into the SQL text
        # raised TypeError for int ids and allowed SQL injection.
        query = "SELECT * FROM subscriptors"
        params = ()
        if group_id is not None:
            query += " WHERE group_id=?"
            params = (group_id,)

        rows = self.__db__.execute(query, params).fetchall()
        self.sendNotificationToAll(
            [
                row["subscription"]
                for row in rows
                if row["session_id"] not in excluded
            ],
            data,
            nonblocking=nonblocking
        )

    @__database__
    def getIdSession(self, subscription):
        """
        getIdSession(subscription)

        Given a subscription object returns the session id associated with
        it.

        :param subscription: The client's subscription JSON object
        :type subscription: str
        :returns: the session id associated with subscription, or None
        :rtype: str
        """
        res = self.__db__.execute(
            "SELECT session_id FROM subscriptors WHERE subscription=?",
            (subscription,)
        ).fetchone()

        return list(res.values())[0] if res else None

    @__database__
    def getSubscription(self, session_id):
        """
        getSubscription(session_id)

        Given a session id returns the subscription object associated with
        it.

        :param session_id: A session id
        :type session_id: str
        :returns: The client's subscription JSON object associated with
                  the session id, or None.
        :rtype: str
        """
        res = self.__db__.execute(
            "SELECT subscription FROM subscriptors WHERE session_id=?",
            (session_id,)
        ).fetchone()

        return list(res.values())[0] if res else None

    @__database__
    def getGroupId(self, session_id):
        """
        getGroupId(session_id)

        Given a session id returns the group id it belongs to.

        :param session_id: A session id
        :type session_id: str
        :returns: a group id value, or None
        :rtype: int
        """
        res = self.__db__.execute(
            "SELECT group_id FROM subscriptors WHERE session_id=?",
            (session_id,)
        ).fetchone()

        return list(res.values())[0] if res else None
def webpush(subscription_info,
            data=None,
            vapid_private_key=None,
            vapid_claims=None,
            content_encoding="aes128gcm",
            curl=False,
            timeout=None,
            ttl=0,
            verbose=False,
            headers=None):
    """
    One call solution to encode and send `data` to the endpoint
    contained in `subscription_info` using optional VAPID auth headers.

    in example:

    .. code-block:: python

        from pywebpush import webpush

        webpush(
            subscription_info={
                "endpoint": "https://push.example.com/v1/abcd",
                "keys": {"p256dh": "0123abcd...",
                         "auth": "001122..."}
                 },
            data="Mary had a little lamb, with a nice mint jelly",
            vapid_private_key="path/to/key.pem",
            vapid_claims={"sub": "*****@*****.**"}
            )

    No additional method call is required. Any non-success will throw a
    `WebPushException`.

    Neither `headers` nor `vapid_claims` is modified; the derived `aud`
    and `exp` claims are only applied to a private copy.

    :param subscription_info: Provided by the client call
    :type subscription_info: dict
    :param data: Serialized data to send
    :type data: str
    :param vapid_private_key: Vapid instance or path to vapid private key PEM \
                              or encoded str
    :type vapid_private_key: Union[Vapid, str]
    :param vapid_claims: Dictionary of claims ('sub' required)
    :type vapid_claims: dict
    :param content_encoding: Optional content type string
    :type content_encoding: str
    :param curl: Return as "curl" string instead of sending
    :type curl: bool
    :param timeout: POST requests timeout
    :type timeout: float or tuple
    :param ttl: Time To Live
    :type ttl: int
    :param verbose: Provide verbose feedback
    :type verbose: bool
    :param headers: Dictionary of extra HTTP headers to include
    :type headers: dict
    :return: the push service response, or a "curl" command string when
        `curl` is True
    :rtype: requests.Response or str
    :raises WebPushException: if `vapid_claims` are given without a
        `vapid_private_key`, or if the push service answers with a status
        code above 202
    """
    if headers is None:
        headers = dict()
    else:
        # Ensure we don't leak VAPID headers by mutating the passed in dict.
        headers = headers.copy()

    if vapid_claims:
        # Work on a copy, for the same reason as `headers`. Filling in
        # `aud`/`exp` on the caller's dict made a reused claims dict carry
        # the first endpoint's audience to later pushes for other push
        # services, which then reject the token.
        vapid_claims = dict(vapid_claims)
        if verbose:
            print("Generating VAPID headers...")
        if not vapid_claims.get('aud'):
            url = urlparse(subscription_info.get('endpoint'))
            aud = "{}://{}".format(url.scheme, url.netloc)
            vapid_claims['aud'] = aud
        # A caller-supplied `exp` may already have passed; refresh it.
        if (not vapid_claims.get('exp')
                or vapid_claims.get('exp') < int(time.time())):
            # encryption lives for 12 hours
            vapid_claims['exp'] = int(time.time()) + (12 * 60 * 60)
            if verbose:
                print("Setting VAPID expry to {}...".format(
                    vapid_claims['exp']))
        if not vapid_private_key:
            raise WebPushException("VAPID dict missing 'private_key'")
        if isinstance(vapid_private_key, Vapid):
            vv = vapid_private_key
        elif os.path.isfile(vapid_private_key):
            # Presume that key from file is handled correctly by
            # py_vapid.
            vv = Vapid.from_file(
                private_key_file=vapid_private_key)  # pragma no cover
        else:
            vv = Vapid.from_string(private_key=vapid_private_key)
        if verbose:
            print("\t claims: {}".format(vapid_claims))
        vapid_headers = vv.sign(vapid_claims)
        if verbose:
            print("\t headers: {}".format(vapid_headers))
        headers.update(vapid_headers)

    response = WebPusher(subscription_info, verbose=verbose).send(
        data,
        headers,
        ttl=ttl,
        content_encoding=content_encoding,
        curl=curl,
        timeout=timeout,
    )
    if not curl and response.status_code > 202:
        raise WebPushException("Push failed: {} {}\nResponse body:{}".format(
            response.status_code, response.reason, response.text),
            response=response)
    return response
class PushClient(object):
    """Smoke Test the Autopush push server"""

    def __init__(self, args, loop,
                 tasks: List[List[Union[AnyStr, Dict]]]=None):
        """Set up client state, the VAPID signer and optional partner TLS.

        :param args: parsed options; this constructor reads ``vapid_key``
            and ``partner_endpoint_cert`` (other methods read ``server`` and
            ``partner_endpoint``)
        :param loop: event loop handed to the aiohttp sessions
        :param tasks: list of ``[command, kwargs]`` pairs to execute
        """
        self.config = args
        self.loop = loop
        self.connection = None
        self.pushEndpoint = None
        self.channelID = None
        self.notifications = []
        self.uaid = None
        self.recv = []
        self.tasks = tasks or []
        self.output = None
        self.vapid_cache = {}
        if args.vapid_key:
            self.vapid = Vapid().from_file(args.vapid_key)
        else:
            self.vapid = Vapid()
            self.vapid.generate_keys()
        self.tls_conn = None
        if args.partner_endpoint_cert:
            # The option is either a path to a CA file or the CA data itself.
            if os.path.isfile(args.partner_endpoint_cert):
                context = ssl.create_default_context(
                    cafile=args.partner_endpoint_cert)
            else:
                context = ssl.create_default_context(
                    cadata=args.partner_endpoint_cert)
            context.verify_mode = ssl.CERT_REQUIRED
            context.check_hostname = False
            # NOTE(review): this single connector is passed to every
            # ClientSession created in `_post_session`; an aiohttp session
            # normally closes a connector it owns when it exits, so a second
            # push may find it closed — confirm for the aiohttp version used.
            self.tls_conn = aiohttp.TCPConnector(ssl_context=context)

    def _fix_endpoint(self, endpoint: str) -> str:
        """Adjust the endpoint if needed

        When ``partner_endpoint`` is configured, keep the path of the
        endpoint the server returned but use the partner's scheme and host.
        """
        if self.config.partner_endpoint:
            orig_path = urlparse(endpoint).path
            partner = urlparse(self.config.partner_endpoint)
            return "{scheme}://{host}{path}".format(
                scheme=partner.scheme,
                host=partner.netloc,
                path=orig_path)
        return endpoint

    def _cache_sign(self, claims: Dict[str, str]) -> Dict[str, str]:
        """Pull a VAPID header from the cache or sign the new header

        :param claims: list of VAPID claims.
        :returns: dictionary of VAPID headers.

        """
        # Sort keys so equal claim sets share one cache entry regardless of
        # the order their keys were inserted in.
        vhash = hashlib.sha1()
        vhash.update(json.dumps(claims, sort_keys=True).encode())
        key = vhash.digest()
        if key not in self.vapid_cache:
            # NOTE(review): entries never expire; a token carrying an `exp`
            # claim can outlive its validity in a long-running session.
            self.vapid_cache[key] = self.vapid.sign(claims)
        return self.vapid_cache[key]

    async def _next_task(self):
        """Tasks are shared between active "cmd_*" commands and async
        "recv_*" events. Since both are reading off the same stack, we
        centralize that here.

        :returns: True if a command ran, False once the task list is empty
            (after running ``cmd_done``)
        :raises PushException: for an unknown command name

        """
        try:
            task = self.tasks.pop(0)
        except IndexError:
            # Nothing left to run: shut everything down.
            await self.cmd_done()
            return False
        name = task[0]
        logging.debug(">>> cmd_{}".format(name))
        # Resolve the handler before calling it. An AttributeError raised
        # *inside* a command must not be reported as "Invalid command".
        handler = getattr(self, "cmd_" + name, None)
        if handler is None:
            raise PushException("Invalid command: {}".format(name))
        try:
            await handler(**(task[1]))
        except PushException:
            raise
        except Exception:  # pragma nocover
            traceback.print_exc()
            raise
        return True

    async def run(self,
                  server: str="wss://push.services.mozilla.com",
                  tasks: List[List[Union[AnyStr, Dict]]]=None):
        """Connect to a remote server and execute the tasks

        :param server: URL to the Push Server
        :param tasks: List of tasks and arguments to run

        """
        if tasks:
            self.tasks = tasks
        if not self.connection:
            await self.cmd_connect(server)
        while await self._next_task():
            pass

    async def process(self, message: Dict[str, Any]):
        """Process an incoming websocket message

        Dispatches to the ``recv_<messagetype>`` handler.

        :param message: JSON message content
        :raises PushException: if there is no handler for the message type

        """
        # A missing messageType maps to "recv_", which has no handler.
        mtype = "recv_" + (message.get('messageType') or '').lower()
        handler = getattr(self, mtype, None)
        if handler is None:
            raise PushException(
                "Unknown messageType: {}".format(mtype))
        await handler(**message)

    async def receiver(self):
        """Receiver handler for websocket messages

        """
        try:
            while self.connection:
                message = await self.connection.recv()
                print("<<< {}".format(message))
                await self.process(json.loads(message))
        except websockets.exceptions.ConnectionClosed:
            output_msg(out=self.output, status="Websocket Connection closed")

    # Commands:::

    async def _send(self, no_recv: bool=False, **msg):
        """Send a message out the websocket connection

        :param no_recv: Flag to indicate if response is expected
        :param msg: message content
        :return:

        """
        output_msg(out=self.output, flow="output", msg=msg)
        try:
            await self.connection.send(json.dumps(msg))
            if no_recv:
                return
            message = await self.connection.recv()
            await self.process(json.loads(message))
        except websockets.exceptions.ConnectionClosed:
            pass

    async def cmd_connect(self, server: str=None, **kwargs):
        """Connect to a remote websocket server

        The configured ``server`` option takes precedence over ``server``.

        :param server: Websocket url
        :param kwargs: ignored

        """
        srv = self.config.server or server
        output_msg(out=self.output, status="Connecting to {}".format(srv))
        self.connection = await websockets.connect(srv)
        self.recv.append(asyncio.ensure_future(self.receiver()))

    async def cmd_close(self, **kwargs):
        """Close the websocket connection (if needed)

        :param kwargs: ignored

        """
        output_msg(out=self.output, status="Closing socket connection")
        if self.connection and self.connection.state == 1:
            try:
                for recv in self.recv:
                    recv.cancel()
                self.recv = []
                await self.connection.close()
            except (websockets.exceptions.ConnectionClosed,
                    futures.CancelledError):
                pass

    async def cmd_sleep(self, period: int=5, **kwargs):
        """Pause for ``period`` seconds.

        :param period: seconds to sleep
        :param kwargs: ignored
        """
        output_msg(out=self.output, status="Sleeping...")
        await asyncio.sleep(period)

    async def cmd_hello(self, uaid: str=None, **kwargs):
        """Send a websocket "hello" message

        :param uaid: User Agent ID (if reconnecting)

        """
        if not self.connection or self.connection.state != 1:
            await self.cmd_connect()
        output_msg(out=self.output, status="Sending Hello")
        args = dict(messageType="hello", use_webpush=1, **kwargs)
        if uaid:
            args['uaid'] = uaid
        elif self.uaid:
            args['uaid'] = self.uaid
        await self._send(**args)

    async def cmd_ack(self, channelID: str=None, version: str=None,
                      timeout: int=60, **kwargs):
        """Acknowledge a previous message

        :param channelID: Channel to acknowledge
        :param version: Version string for message to acknowledge
        :param kwargs: Additional optional arguments
        :param timeout: Time to wait for notifications (used by testing)
        :raises PushException: if no notification arrives within ``timeout``

        """
        # The loop sleeps 0.5s per pass, so allow twice `timeout` passes.
        timeout = timeout * 2
        while not self.notifications:
            output_msg(
                out=self.output,
                status="No notifications recv'd, Sleeping...")
            await asyncio.sleep(0.5)
            timeout -= 1
            if timeout < 1:
                raise PushException("Timeout waiting for messages")

        self.notifications.reverse()
        for notif in self.notifications:
            output_msg(
                out=self.output,
                status="Sending ACK",
                channelID=channelID or notif['channelID'],
                version=version or notif['version'])
            await self._send(messageType="ack",
                             channelID=channelID or notif['channelID'],
                             version=version or notif['version'],
                             no_recv=True)
        self.notifications = []

    async def cmd_register(self, channelID: str=None, key: str=None,
                           **kwargs):
        """Register a new ChannelID

        :param channelID: UUID for the channel to register
        :param key: applicationServerKey for a restricted access channel
        :param kwargs: additional optional arguments
        :return:

        """
        output_msg(
            out=self.output,
            status="Sending new channel registration")
        channelID = channelID or self.channelID or str(uuid.uuid4())
        args = dict(messageType='register',
                    channelID=channelID)
        if key:
            # Send the key under the "key" field; the old code used the key
            # value itself as the field name.
            args['key'] = key
        args.update(kwargs)
        await self._send(**args)

    async def cmd_done(self, **kwargs):
        """Close all connections and mark as done

        :param kwargs: ignored
        :return:

        """
        output_msg(
            out=self.output,
            status="done")
        await self.cmd_close()
        await self.connection.close_connection()

    # recv_* commands handle incoming responses. Since they are asynchronous
    # and need to trigger follow-up tasks, they each will need to pull and
    # process the next task.

    async def recv_hello(self, **msg: Dict[str, Any]):
        """Process a received "hello"

        :param msg: body of response
        :raises PushException: if the response carries no ``uaid``

        """
        assert(msg['status'] == 200)
        # Only wrap the missing-uaid case. A KeyError raised by the next
        # command is that command's own error.
        try:
            self.uaid = msg['uaid']
        except KeyError as ex:
            raise PushException("hello response missing 'uaid'") from ex
        await self._next_task()

    async def recv_register(self, **msg):
        """Process a received registration message

        :param msg: body of response
        :return:

        """
        assert(msg['status'] == 200)
        self.pushEndpoint = self._fix_endpoint(msg['pushEndpoint'])
        self.channelID = msg['channelID']
        output_msg(
            out=self.output,
            flow="input",
            msg=dict(
                message="register",
                channelID=self.channelID,
                pushEndpoint=self.pushEndpoint))
        await self._next_task()

    @staticmethod
    def _repad(data: str) -> str:
        """Append the '=' padding that base64url strings commonly omit."""
        return data + '=' * (-len(data) % 4)

    async def recv_notification(self, **msg):
        """Process a received notification message.
        This event does NOT trigger the next command in the stack.

        A ``data`` payload, when present, is base64url-decoded into
        ``msg['_decoded_data']``. The message is then recorded and
        acknowledged.

        :param msg: body of response

        """
        # Notifications pushed without a body carry no "data" field.
        data = msg.get('data')
        if data is not None:
            msg['_decoded_data'] = base64.urlsafe_b64decode(
                self._repad(data)).decode()
        output_msg(
            out=self.output,
            flow="input",
            message="notification",
            msg=msg)
        self.notifications.append(msg)
        await self.cmd_ack()

    async def _post(self, session, url: str, data: bytes):
        """Post a message to the endpoint

        :param session: async session object
        :param url: pushEndpoint
        :param data: data to send
        :return:

        """
        with aiohttp.Timeout(10, loop=session.loop):
            return await session.post(url=url,
                                      data=data)

    async def _post_session(self, url: str,
                            headers: Dict[str, str],
                            data: bytes):
        """create a session to send the post message to the endpoint

        :param url: pushEndpoint
        :param headers: dictionary of headers
        :param data: body of the content to send

        """
        async with aiohttp.ClientSession(
            loop=self.loop,
            headers=headers,
            read_timeout=30,
            connector=self.tls_conn,
        ) as session:
            reply = await self._post(session, url, data)
            return reply

    async def cmd_push(self, data: bytes=None, headers: Dict[str, str]=None,
                       claims: Dict[str, str]=None):
        """Push data to the pushEndpoint

        :param data: message content
        :param headers: dictionary of headers (not modified)
        :param claims: VAPID claims
        :return:
        :raises PushException: if no endpoint has been registered yet

        """
        if not self.pushEndpoint:
            raise PushException("No Endpoint, no registration?")
        # Copy so the VAPID and encoding headers added below do not leak
        # into the caller's (often the task list's) dict.
        headers = dict(headers) if headers else {}
        if claims:
            headers.update(self._cache_sign(claims))
        output_msg(
            out=self.output,
            status="Pushing message",
            msg=repr(data))
        if data and 'content-encoding' not in headers:
            headers.update({
                "content-encoding": "aesgcm128",
                "encryption": "salt=test",
                "encryption-key": "dh=test",
            })
        result = await self._post_session(self.pushEndpoint, headers, data)
        body = await result.text()
        output_msg(
            out=self.output,
            flow="http-out",
            pushEndpoint=self.pushEndpoint,
            headers=headers,
            data=repr(data),
            result="{}: {}".format(result.status, body))