Exemplo n.º 1
0
    def __init__(self, freq, activation, input, target_idx, task_loss, surrogate_loss,
                 hyperparameter, learning_rate, batch_generator, n_batches,
                 factor=1.5, n_updates=10, activation_step_scale=100):
        """Register the 'adapt_zloss' extension and compile its Theano helpers.

        Args:
            freq: Forwarded to ``Extension.__init__``.
            activation: Symbolic activation variable that ``task_loss`` and
                ``surrogate_loss`` are expressed in terms of.
            input: Symbolic input from which ``activation`` is computed.
            target_idx: Symbolic target input of the compiled task-loss function.
            task_loss: Symbolic task loss graph depending on ``activation``.
            surrogate_loss: Symbolic surrogate loss graph depending on ``activation``.
            hyperparameter: Stored on the instance; not used in this constructor.
            learning_rate: Scales the gradient step taken on the activation.
            batch_generator: Stored on the instance; not used in this constructor.
            n_batches: Stored on the instance; not used in this constructor.
            factor: Stored on the instance (default 1.5).
            n_updates: Stored on the instance (default 10).
            activation_step_scale: Extra multiplier applied on top of
                ``learning_rate`` for the single surrogate-gradient step on the
                activation. Defaults to 100, the value previously hard-coded here.
        """
        Extension.__init__(self, 'adapt_zloss', freq)

        self.batch_generator = batch_generator
        self.n_batches = n_batches
        self.learning_rate = learning_rate
        self.hyperparameter = hyperparameter
        self.factor = factor
        self.n_updates = n_updates
        self.activation_step_scale = activation_step_scale

        # Compiled forward pass: input -> activation.
        self.fun_activation = theano.function([input], activation)

        # Rebuild the surrogate loss on a free variable so the activation can be
        # fed in directly as an input instead of being computed from `input`.
        # NOTE(review): tensor.matrix() assumes the activation is 2-D — confirm.
        activation_bis = tensor.matrix()
        surr_loss_bis = theano.clone(surrogate_loss,
                                     replace={activation: activation_bis})
        # One gradient-descent step on the activation w.r.t. the surrogate loss.
        grad = theano.grad(surr_loss_bis, activation_bis)
        new_activation = (activation_bis
                          - activation_step_scale * learning_rate * grad)

        # Task loss re-expressed at the stepped activation.
        task_loss_bis = theano.clone(task_loss,
                                     replace={activation: new_activation})

        # (activation, targets) -> [task loss after the step, stepped activation].
        self.fun_update_task_loss = theano.function(
                [activation_bis, target_idx], [task_loss_bis, new_activation])
Exemplo n.º 2
0
 def __init__(self, freq, words, embedding_matrix, knn, vocab, inv_vocab):
     """Set up the 'Nearest neighbours of words' extension.

     Args:
         freq: Forwarded to ``Extension.__init__``.
         words: Iterable of words to track; every entry must be a key of ``vocab``.
         embedding_matrix: Stored on the instance; not used in this constructor.
         knn: Stored on the instance; not used in this constructor.
         vocab: Mapping indexed by word; the looked-up values become
             ``self.word_ids``.
         inv_vocab: Stored on the instance (presumably the inverse of ``vocab``;
             verify against callers).

     Raises:
         KeyError: If an entry of ``words`` is missing from ``vocab``.
     """
     Extension.__init__(self, 'Nearest neighbours of words', freq)
     self.words = words
     self.embedding_matrix = embedding_matrix
     self.knn = knn
     self.vocab = vocab
     self.inv_vocab = inv_vocab
     # Resolve every tracked word to its vocabulary entry once, up front.
     self.word_ids = [self.vocab[word] for word in words]
Exemplo n.º 3
0
 def __init__(self, freq, words, embedding_matrix, knn, vocab, inv_vocab):
     """Set up the 'Nearest neighbours of words' extension.

     Args:
         freq: Forwarded to ``Extension.__init__``.
         words: Iterable of words to track; every entry must be a key of ``vocab``.
         embedding_matrix: Stored on the instance; not used in this constructor.
         knn: Stored on the instance; not used in this constructor.
         vocab: Mapping indexed by word; the looked-up values become
             ``self.word_ids``.
         inv_vocab: Stored on the instance (presumably the inverse of ``vocab``;
             verify against callers).

     Raises:
         KeyError: If an entry of ``words`` is missing from ``vocab``.
     """
     Extension.__init__(self, 'Nearest neighbours of words', freq)
     self.words = words
     self.embedding_matrix = embedding_matrix
     self.knn = knn
     self.vocab = vocab
     self.inv_vocab = inv_vocab
     # Resolve every tracked word to its vocabulary entry once, up front.
     self.word_ids = [self.vocab[word] for word in words]
Exemplo n.º 4
0
    def __init__(self,
                 freq,
                 activation,
                 input,
                 target_idx,
                 task_loss,
                 surrogate_loss,
                 hyperparameter,
                 learning_rate,
                 batch_generator,
                 n_batches,
                 factor=1.5,
                 n_updates=10,
                 activation_step_scale=100):
        """Register the 'adapt_zloss' extension and compile its Theano helpers.

        Args:
            freq: Forwarded to ``Extension.__init__``.
            activation: Symbolic activation variable that ``task_loss`` and
                ``surrogate_loss`` are expressed in terms of.
            input: Symbolic input from which ``activation`` is computed.
            target_idx: Symbolic target input of the compiled task-loss function.
            task_loss: Symbolic task loss graph depending on ``activation``.
            surrogate_loss: Symbolic surrogate loss graph depending on ``activation``.
            hyperparameter: Stored on the instance; not used in this constructor.
            learning_rate: Scales the gradient step taken on the activation.
            batch_generator: Stored on the instance; not used in this constructor.
            n_batches: Stored on the instance; not used in this constructor.
            factor: Stored on the instance (default 1.5).
            n_updates: Stored on the instance (default 10).
            activation_step_scale: Extra multiplier applied on top of
                ``learning_rate`` for the single surrogate-gradient step on the
                activation. Defaults to 100, the value previously hard-coded here.
        """
        Extension.__init__(self, 'adapt_zloss', freq)

        self.batch_generator = batch_generator
        self.n_batches = n_batches
        self.learning_rate = learning_rate
        self.hyperparameter = hyperparameter
        self.factor = factor
        self.n_updates = n_updates
        self.activation_step_scale = activation_step_scale

        # Compiled forward pass: input -> activation.
        self.fun_activation = theano.function([input], activation)

        # Rebuild the surrogate loss on a free variable so the activation can be
        # fed in directly as an input instead of being computed from `input`.
        # NOTE(review): tensor.matrix() assumes the activation is 2-D — confirm.
        activation_bis = tensor.matrix()
        surr_loss_bis = theano.clone(surrogate_loss,
                                     replace={activation: activation_bis})
        # One gradient-descent step on the activation w.r.t. the surrogate loss.
        grad = theano.grad(surr_loss_bis, activation_bis)
        new_activation = (activation_bis
                          - activation_step_scale * learning_rate * grad)

        # Task loss re-expressed at the stepped activation.
        task_loss_bis = theano.clone(task_loss,
                                     replace={activation: new_activation})

        # (activation, targets) -> [task loss after the step, stepped activation].
        self.fun_update_task_loss = theano.function(
            [activation_bis, target_idx], [task_loss_bis, new_activation])
Exemplo n.º 5
0
    def server_hello(self):
        """Read and parse the server's response to the client hello.

        Reads the ServerHello record. It then reads either a Certificate message
        or falls back to a certificate cached at ``./debug/<host>.crt``. Next it
        reads the following message, which, if it is a ServerKeyExchange, is
        followed by one more read (presumably ServerHelloDone). Every message
        read is appended to ``self.messages``.

        Sets ``self.server_random``, ``self.extensions``, ``self.http_version``,
        ``self.server_certificate``, ``self.cipher_suite`` and
        ``self.is_server_key_exchange``.

        Raises:
            AssertionError: If any of these hold:
                - the record is not a handshake record;
                - the response is empty;
                - the message is not a ServerHello;
                - the TLS version differs from ``self.tls_version``;
                - ALPN selected anything other than HTTP/1.1.
            ValueError: If no certificate was received and none is cached. Also
                raised if no ServerKeyExchange arrived while ``self.session_id``
                is set.
        """
        # NOTE(review): these checks use `assert`, which is stripped under -O.
        record_bytes, hello_bytes = self.read(return_record=True)
        assert record_bytes[:1] == constants.CONTENT_TYPE_HANDSHAKE, 'Server return {}'.format(
            print_hex(record_bytes[:1]))
        self.messages.append(hello_bytes)
        assert len(hello_bytes) > 0, 'No response from server'
        assert hello_bytes[:1] == b'\x02', 'Not server hello'
        # Skip the 4-byte handshake header (type + 24-bit length); next 2 bytes are the version.
        tls_version = hello_bytes[4:6]
        assert tls_version == self.tls_version, 'Not a desired tls version'

        # Parse hello bytes: 32-byte server random, then the length-prefixed session id.
        self.server_random, hello_bytes = hello_bytes[6:6 + 32], hello_bytes[6 + 32:]
        session_id_length = int.from_bytes(hello_bytes[:1], 'big')
        # The sliced session_id keeps its 1-byte length prefix; it is not stored yet.
        session_id, hello_bytes = hello_bytes[:session_id_length + 1], hello_bytes[session_id_length + 1:]
        # This session_id can be reused for the session_id in the client Hello for the next request
        # Reusing a session_id results in no certificate sent after Server Hello
        server_cipher_suite, hello_bytes = hello_bytes[:2], hello_bytes[2:]
        compression_method, hello_bytes = hello_bytes[:1], hello_bytes[1:]
        extensions_length, hello_bytes = int.from_bytes(hello_bytes[:2], 'big'), hello_bytes[2:]
        extensions = hello_bytes[:extensions_length]
        self.extensions = Extension.parse_extensions(extensions)
        alpn = [ext for ext in self.extensions if isinstance(ext, ALPN)]
        self.http_version = alpn[0].protocols[0] if alpn else constants.EXTENSION_ALPN_HTTP_1_1
        assert self.http_version == constants.EXTENSION_ALPN_HTTP_1_1, 'Not support http2 yet'

        certificate_bytes = self.read()
        cached_cert_path = r'./debug/{}.crt'.format(self.host)
        if certificate_bytes[0] == 0x0B:  # Certificate handshake message
            self.messages.append(certificate_bytes)
            # Skip the 4-byte handshake header and the 3-byte certificate list length.
            certificate_bytes = certificate_bytes[7:]
            os.makedirs(r'./debug', exist_ok=True)
            # get_certificate receives an open handle to the cache file (presumably
            # it writes the cert there); close it deterministically so the cached
            # copy is complete before any later read instead of leaking the handle.
            with open(cached_cert_path, 'wb+') as cert_cache:
                self.server_certificate = get_certificate(certificate_bytes, cert_cache,
                                                          match_hostname=self.match_hostname, host=self.host)
            next_bytes = self.read()
        elif os.path.exists(cached_cert_path):
            # No Certificate message: reuse the copy cached by an earlier handshake.
            with open(cached_cert_path, 'rb') as cert_cache:
                self.server_certificate = load(cert_cache)
            next_bytes = certificate_bytes
        else:
            raise ValueError('No certificate was received.')

        self.cipher_suite = CipherSuite.get_from_id(self.tls_version, self.client_random, self.server_random,
                                                    self.server_certificate, server_cipher_suite)

        self.messages.append(next_bytes)
        self.is_server_key_exchange = next_bytes[:1] == b'\x0c'
        if self.is_server_key_exchange:  # Server key exchange
            self.cipher_suite.parse_key_exchange_params(next_bytes)

            hello_done_bytes = self.read()
            self.messages.append(hello_done_bytes)
        elif self.session_id:  # @todo handle sessions
            raise ValueError('No server key exchange has received. # @todo')
        self.debug_print('Cipher suite negotiated', ' {}({})'.format(self.cipher_suite, print_hex(server_cipher_suite)))
        self.debug_print('TLS version', self.tls_version)
        self.debug_print('Server random', print_hex(self.server_random))
        self.debug_print('Key exchange', self.cipher_suite.key_exchange.__class__.__name__)
        self.debug_print('Server cert not before (UTC)', self.server_certificate.not_valid_before)
        self.debug_print('Server cert not after (UTC)', self.server_certificate.not_valid_after)
        self.debug_print('Server cert fingerprint (sha256)', print_hex(self.server_certificate.fingerprint(SHA256())))
        if self.is_server_key_exchange:
            public_key = self.cipher_suite.key_exchange.public_key
            self.debug_print('Key Exchange Server Public Key ({!s} bytes)'.format(len(public_key)),
                             print_hex(public_key))
Exemplo n.º 6
0
import sys
import os
from extensions import Extension
import re
import importlib

# Module-wide Extension instance; get_the_getters looks up getter methods on it by name.
EXTENSION = Extension()


def get_the_getters(given_list, getter='GET'):
    """Resolve the first getter token in ``given_list``, mutating it in place.

    Every position whose element equals ``getter`` is printed. The element at
    the first such position names a method on ``EXTENSION``. That method is
    called with the element that follows it, the following element is replaced
    by the result, and the token itself is removed. Only the first occurrence is
    handled per call; the list is returned unchanged when no token is present.

    Args:
        given_list: Mutable sequence of tokens; modified in place.
        getter: Token value marking a getter call (default ``'GET'``).

    Returns:
        ``given_list`` itself (the same object that was passed in).

    Raises:
        IndexError: If the first getter token is the last element (no argument).
    """
    positions = [i for i, token in enumerate(given_list) if token == getter]
    for pos in positions:
        print('getters: ', given_list[pos], pos, end=', ')
    if not positions:
        return given_list

    # Only the first occurrence is resolved: deleting the token shifts every
    # later index, so the remaining positions would be stale anyway.
    token_pos = positions[0]
    arg_pos = token_pos + 1
    argument = given_list[arg_pos]  # read first so a missing argument raises IndexError
    value = getattr(EXTENSION, given_list[token_pos])(argument)
    print('value: ', value)
    given_list[arg_pos] = value
    del given_list[token_pos]
    return given_list


def remove_space_or_not(part):
    while part.startswith(' '):