Example #1
def run():
    now = datetime.utcnow()
    until = (now + timedelta(minutes=5)).isoformat()

    last_evaluated_key = None

    count = 0
    # emulated do while loop
    while True:

        response = load_data(until, last_evaluated_key)

        if response['Count'] == 0:
            break
        else:
            count += response['Count']

        ids = []
        for item in response['Items']:
            ids.append(item['id'])

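        # Fan the collected ids out to the scheduler Lambda in batches of 200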
        for chunk in make_chunks(ids, 200):
            invoke_lambda(os.environ.get('SCHEDULE_FUNCTION'),
                          json.dumps(chunk).encode('utf-8'))

        if 'LastEvaluatedKey' in response:
            print('Continuing at next page')
            last_evaluated_key = response['LastEvaluatedKey']
        else:
            print('Finished loading data')
            break

    print('Batched %d entries' % count)
Example #2
def run():
    current_segment = int(
        datetime.utcnow().replace(second=0, microsecond=0).timestamp() +
        10 * 60
    )  # scan the minute that is 10 minutes from now, not the one already in progress

    count = 0

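    # Query every event key in the target one-minute segment (handling pagination)
    # and forward them to the scheduler Lambda in chunks of 200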
    for page in client.get_paginator('query').paginate(
            TableName=table_name,
            ProjectionExpression='pk,sk',
            KeyConditionExpression='pk = :s',
            ExpressionAttributeValues={':s': {
                'N': str(current_segment)
            }}):
        ids = []
        for item in page.get('Items', []):
            event = {k: deserializer.deserialize(v) for k, v in item.items()}
            ids.append({'pk': int(event['pk']), 'sk': event['sk']})

        for chunk in make_chunks(ids, 200):
            invoke_lambda(os.environ.get('SCHEDULE_FUNCTION'),
                          json.dumps(chunk).encode('utf-8'))

        count += page['Count']

    print('Batched %d entries' % count)
Example #3
def handle(events):
    received = datetime.utcnow()
    to_be_scheduled = []
    event_wrappers = []
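    # Validate each incoming event; malformed events are reported to the failure topic and skipped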
    for event in events:
        if 'date' not in event:
            publish_to_failure_topic(event, 'date is required')
            print('error.date_required %s' % (json.dumps({'event': event})))
            continue
        if 'payload' not in event:
            publish_to_failure_topic(event, 'payload is required')
            print('error.payload_required %s' % (json.dumps({'event': event})))
            continue
        if 'target' not in event:
            publish_to_failure_topic(event, 'target is required')
            print('error.target_required %s' % (json.dumps({'event': event})))
            continue

        if not isinstance(event['payload'], str):
            publish_to_failure_topic(event, 'payload must be a string')
            print('error.payload_is_not_string %s' % (json.dumps({'event': event})))
            continue

        event_wrapper = {
            'id': str(uuid4()),
            'date': event['date'],
            'payload': event['payload'],
            'target': event['target'],
            'status': 'NEW'
        }

        if 'failure_topic' in event:
            event_wrapper['failure_topic'] = event['failure_topic']

        if 'user' not in event:
            if os.environ.get('ENFORCE_USER'):
                publish_to_failure_topic(event, 'user is required')
                print('error.event_has_no_user %s' % (json.dumps({'event': event})))
                continue
        else:
            event_wrapper['user'] = event['user']

        # if the event has less than 10 minutes until execution, then fast track it
        if has_less_then_ten_minutes(event_wrapper['date']):
            to_be_scheduled.append(event_wrapper['id'])

        print('event.consumed %s' % (json.dumps({'id': event_wrapper['id'], 'timestamp': str(received)})))
        event_wrappers.append(event_wrapper)

    # we must save before delegating, because the downstream function will access the DB entity
    save_with_retry(event_wrappers)

    print('Fast track scheduling for %d entries' % len(to_be_scheduled))
    for chunk in make_chunks(to_be_scheduled, 200):
        ids = json.dumps(chunk).encode('utf-8')
        invoke_lambda(os.environ.get('SCHEDULE_FUNCTION'), ids)

    print('Processed %d entries' % len(events))
Example #4
def handle(event, context):
    total_loading = 0
    total_batching = 0
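    # Accumulated load and batch durations, in milliseconds, for the summary logs below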

    now = datetime.utcnow()
    until = (now + timedelta(minutes=5)).isoformat()

    last_evaluated_key = None

    count = 0
    # emulated do while loop
    while True:

        response, duration = load_data(until, last_evaluated_key)
        total_loading += duration

        if response['Count'] == 0:
            break
        else:
            count += response['Count']

        ids = []
        for item in response['Items']:
            ids.append(item['id'])

        start = datetime.now()
        for chunk in make_chunks(ids, 200):
            invoke_lambda(os.environ.get('SCHEDULE_FUNCTION'), json.dumps(chunk).encode('utf-8'))
        end = datetime.now()
        duration = int((end - start).total_seconds() * 1000)
        print(f"Batching {response['Count']} entries took { duration }ms.")
        total_batching += duration

        if 'LastEvaluatedKey' in response:
            print('Continuing at next page')
            last_evaluated_key = response['LastEvaluatedKey']
        else:
            print('Finished loading data')
            break

    print(f'Loading {count} entries took {total_loading}ms.')
    print(f'Batching {count} entries took {total_batching}ms.')
    print(f'Processing {count} entries took {int((datetime.utcnow() - now).total_seconds() * 1000)}ms.')
Example #5
    def send(self, msgFormat, msgType, msg=None):
        '''
        Function to make message then packet and then send it
        '''
        global stop
        global acknowledged
        global nextSeqNumber
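        #Shared flags; acknowledged and nextSeqNumber are expected to be updated elsewhere (e.g. by the thread handling incoming ACKs)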

        sendMessage = util.make_message(msgFormat, msgType, msg)
        chunks = util.make_chunks(sendMessage)

        #Start packet
        sequenceNumber = random.randint(10, 100)
        sendPacket = util.make_packet(
            "start",
            sequenceNumber,
        )
        self.sock.sendto(sendPacket.encode("utf-8"),
                         (self.server_addr, self.server_port))

        #Data packets
        #Send each data packet only after the previous packet has been acknowledged
        for chunk in chunks:
            while not acknowledged:
                pass
            sendPacket = util.make_packet("data", nextSeqNumber, chunk)
            self.sock.sendto(sendPacket.encode("utf-8"),
                             (self.server_addr, self.server_port))
            acknowledged = False

        #End packet
        #Wait until the final data packet has been acknowledged before sending the end packet
        while not acknowledged:
            pass
        sendPacket = util.make_packet(
            "end",
            nextSeqNumber,
        )
        self.sock.sendto(sendPacket.encode("utf-8"),
                         (self.server_addr, self.server_port))
        acknowledged = False
Example #6
base = {}
while len(base) < 16:
    o = oracle(bytes(16))
    base[o[16:32]] = o[32:]

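# Recover the oracle's hidden plaintext one byte at a time: brute-force the last byte
# of a crafted block and accept the value whose ciphertext block matches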
start_len = len(oracle(bytes()))
dec_so_far = bytes()
ordered_base = []
for byte_num in range(start_len):
    found = 0
    while not found:
        for b in range(1, 256):
            if byte_num < blocksize:
                prepend = bytes(blocksize - byte_num)
            else:
                prepend = bytes()
            check = oracle(prepend + dec_so_far[-16:] + bytes([b]))
            if byte_num < blocksize:
                remainder = base.pop(check[16:32], None)
                if remainder:
                    found = b
                    ordered_base.append(util.make_chunks(remainder, 16))
                    break
            else:
                if check[16:32] == ordered_base[byte_num % 16][(byte_num // 16) - 1]:
                    found = b
                    break
    dec_so_far += bytes([found])
    print(bytes([found]).decode('utf-8'), end="", flush=True)

#print(dec_so_far.decode('utf-8'))
Example #7
    def send(self, address, msgFormat, msgType, msg=None):
        '''
        Function to make message then packet and then send it and
        maintain the window
        '''
        global ACKSTORE
        global EXPECTEDACK
        windowStore = queue.Queue()
        retransmissions = 0
        packetIndex = 0
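        #windowStore holds packets that were sent but not yet acknowledged; it never exceeds self.window entries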

        sendMessage = util.make_message(msgFormat, msgType, msg)
        chunks = util.make_chunks(sendMessage)

        #Start packet functionality
        sequenceNumber = random.randint(10, 100)
        sendPacket = util.make_packet("start", sequenceNumber,)
        self.sock.sendto(sendPacket.encode("utf-8"), address)
        sequenceNumber += 1
        EXPECTEDACK[address] = sequenceNumber
        #Resend the start packet if no ack received
        while True:
            try:
                _ = ACKSTORE[address].get(timeout=util.TIME_OUT)
                break
            except queue.Empty:
                if retransmissions == util.NUM_OF_RETRANSMISSIONS:
                    break
                retransmissions += 1
                self.sock.sendto(sendPacket.encode("utf-8"), address)

        #Data packets functionality
        #Store the packets in the queue (number of packets stored <= window size)
        for i in range(self.window):
            sendPacket = util.make_packet("data", sequenceNumber, chunks[i])
            windowStore.put(sendPacket)
            sequenceNumber += 1
            #If there are fewer chunks than the window size, stop filling early
            if i == len(chunks) - 1:
                break
        packetIndex = self.window
        EXPECTEDACK[address] += 1
        self.send_window(address, windowStore)
        #Resend the data packet if no ack received and slide the window on receiving the ack
        while not windowStore.empty():
            try:
                _ = ACKSTORE[address].get(timeout=util.TIME_OUT)
                _ = windowStore.get()
                #If every chunk has already been queued, there is nothing new to send
                if packetIndex < len(chunks):
                    sendPacket = util.make_packet("data", sequenceNumber, chunks[packetIndex])
                    self.sock.sendto(sendPacket.encode("utf-8"), address)
                    windowStore.put(sendPacket)
                    packetIndex += 1
                    sequenceNumber += 1
                retransmissions = 0
                EXPECTEDACK[address] += 1
            except queue.Empty:
                if retransmissions == util.NUM_OF_RETRANSMISSIONS:
                    break
                retransmissions += 1
                self.send_window(address, windowStore)

        #End packet functionality
        sendPacket = util.make_packet("end", sequenceNumber,)
        self.sock.sendto(sendPacket.encode("utf-8"), address)
        #Resend the end packet if no ack received
        while True:
            try:
                _ = ACKSTORE[address].get(timeout=util.TIME_OUT)
                break
            except queue.Empty:
                if retransmissions == util.NUM_OF_RETRANSMISSIONS:
                    break
                retransmissions += 1
                self.sock.sendto(sendPacket.encode("utf-8"), address)
Example #8
import util
import binascii

f = open('8.txt')
texts = []
for l in f:
    texts.append(util.a2b(l.strip()))
f.close()

# break the ciphertext into 16 byte chunks, and pick the one with the most
# repeated.  We do this by putting the chunks into a set, and seeing how big
# the set is.
sizes = [(t_b, len(set(tuple(chunk) for chunk in util.make_chunks(t_b, 16))))
         for t_b in texts]

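# The ciphertext with the fewest distinct blocks (i.e. the most repetition) is the ECB-encrypted one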
winner = min(sizes, key=lambda x: x[1])[0]

print(util.b2a(winner).decode('utf-8'))
Example #9
while len(base) < 16:
    o = oracle(bytes(16))
    base[o[16:32]] = o[32:]

start_len = len(oracle(bytes()))
dec_so_far = bytes()
ordered_base = []
for byte_num in range(start_len):
    found = 0
    while not found:
        for b in range(1, 256):
            if byte_num < blocksize:
                prepend = bytes(blocksize - byte_num)
            else:
                prepend = bytes()
            check = oracle(prepend + dec_so_far[-16:] + bytes([b]))
            if byte_num < blocksize:
                remainder = base.pop(check[16:32], None)
                if remainder:
                    found = b
                    ordered_base.append(util.make_chunks(remainder, 16))
                    break
            else:
                if check[16:32] == ordered_base[byte_num %
                                                16][(byte_num // 16) - 1]:
                    found = b
                    break
    dec_so_far += bytes([found])
    print(bytes([found]).decode('utf-8'), end="", flush=True)

#print(dec_so_far.decode('utf-8'))
Example #10
    return util.aes_ecb_enc(util.pkcs7pad(profile, 16), key_b)


def profile_decrypt(enc_b):
    return cookie_parse(util.aes_ecb_dec(enc_b, key_b))


# figure out the block size
start_len = len(profile_for(bytes()))
acc_bytes = b'a'
while True:
    new_len = len(profile_for(acc_bytes))
    if new_len != start_len:
        blocksize = new_len - start_len
        break
    else:
        acc_bytes += b'a'
print('Detected block size {}'.format(blocksize))

email_length = len(acc_bytes) + len(b'user')

our_email = b'a' * email_length
start_ciphertext = profile_for(our_email)[:-blocksize]

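# Cut-and-paste attack: obtain the ciphertext block that encrypts pkcs7pad(b'admin')
# and splice it in place of the final block, which encrypts b'user' plus padding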
clear_len = len(acc_bytes) + 1
malicious_text = util.pkcs7pad(b'admin', blocksize)
out = profile_for((b'a' * clear_len) + malicious_text)
evil = util.make_chunks(out, blocksize)[1]

print(profile_decrypt(start_ciphertext + evil))
Example #11
    return iv, enc


def f2(iv, ct):
    msg = util.cbc_dec(ct, key_b, iv=iv)
    try:
        util.strip_pkcs7(msg)
        return True
    except util.PKCS7Error:
        return False


for a in range(len(options)):
    iv, ct = f1(a)

    blocks = [iv] + util.make_chunks(ct, 16)

    # figure out the padding amount
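    # Corrupting byte n of the second-to-last ciphertext block changes byte n of the
    # final plaintext block; the first n whose corruption breaks the padding marks
    # where the padding starts, so padding_amount = 16 - n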
    padding_amount = 1
    for n in range(0, 15):
        evil = util.xor_bytestring(chain(repeat(0, n), [1], repeat(0, 15 - n)),
                                   blocks[-2])
        evil_blocks = blocks[:-2] + [evil] + blocks[-1:]
        if not f2(evil_blocks[0], bytes(chain.from_iterable(evil_blocks[1:]))):
            padding_amount = 16 - n
            break

    final = []

    plain = [padding_amount] * padding_amount
    for n in range(len(blocks) - 1, 0, -1):
Example #12
def main(config):

    # CONFIGURATION
    if config.mode not in ['finetuning', 'pretraining']:
        config.print_help()
        sys.exit(1)

    config_dict = vars(config)

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.INFO)

    tokenizer = transformers.BertTokenizerFast.from_pretrained(
        config.bert_model)

    # changes of config_dict will change the config itself as well
    config_dict['logger'] = logger
    config_dict['date'] = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    config_dict['cls_token_id'] = tokenizer.encode(tokenizer.cls_token,
                                                   add_special_tokens=False)[0]
    config_dict['pad_token_id'] = tokenizer.encode(tokenizer.pad_token,
                                                   add_special_tokens=False)[0]

    logger.info("GIVEN ARGUMENTS")
    for k, v in config_dict.items():
        logger.info(f"{k}: {v}")

    model, optimizer, _, _ = instantiate_model(config, tokenizer)

    loss_fn = torch.nn.CrossEntropyLoss()
    metrics = eval.Metrics()
    logger.info(
        f'Loading and chunking articles from {config.articles_path} database...'
    )

    dev_chunks, dev_articles_ids = util.make_chunks(config.articles_path,
                                                    tokenizer,
                                                    config,
                                                    save_chunks=False)
    dev_chunks, dev_chunks_mask = util.process_chunks(dev_chunks, config)

    logger.info('Loading train/dev claims...')
    claims_dev, evidence_dev, labels_dev = util.load_claims('dev', config)

    logger.info('Tokenizing, padding and masking the claims...')
    claims_dev, claims_dev_mask = util.process_claims(claims_dev,
                                                      tokenizer,
                                                      config,
                                                      _pad_max=True)
    logger.info(f"{len(claims_dev)} dev claims prepared for finetuning.")

    ## Load embedded documents
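    # Reuse cached document embeddings from disk if present; otherwise embed the dev chunks and save the result for next time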
    doc_emb_path = f"/home/ryparmar/trained_models/doc-emb-{config.continue_training.split('/')[-1]}.npy"
    if os.path.exists(doc_emb_path):
        model.eval()
        eval_claim_embeddings = eval.encode_chunks(
            claims_dev,
            claims_dev_mask,
            model,
            batch_size=config.test_batch_size)
        eval_document_embeddings = io_util.load_np_embeddings(doc_emb_path)
    else:
        eval_claim_embeddings, \
        eval_document_embeddings = eval.evaluation_preprocessing(claims_dev, claims_dev_mask, dev_chunks, dev_chunks_mask, model, config)
        ### CTK PRETRAIN -- Save embedded documents
        io_util.save_np_embeddings(eval_document_embeddings, doc_emb_path)
        logger.info(f"Embeddings saved")
        model.to('cpu')

    # Evaluation
    for k in [10, 20]:
        precision, recall, f1 = eval.retriever_score(eval_claim_embeddings,
                                                     eval_document_embeddings,
                                                     evidence_dev,
                                                     labels_dev,
                                                     dev_articles_ids,
                                                     config,
                                                     k=k)
        # config.logger.info
        print(f"F1: {f1}\tRecall@{k}: {recall}\tPrecision@10: {precision}")
Example #13
start_len = len(oracle(bytes()))
print(start_len)
acc_bytes = b'a'
while True:
    new_len = len(oracle(acc_bytes))
    if new_len != start_len:
        blocksize = new_len - start_len
        break
    else:
        acc_bytes += b'a'
print('Detected block size {}'.format(blocksize))

# step 2: detect that the function is using ECB
test_str = util.random_bytes(blocksize)
test_enc = oracle(test_str + test_str)
p1, p2 = util.make_chunks(test_enc, blocksize)[:2]
if p1 == p2:
    print('Oracle is using ECB')

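# Byte-at-a-time ECB decryption: pad with 'a's so the next unknown byte lands at the end
# of a block, and record the target ciphertext blocks to compare candidate bytes against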
enc_block_at = {}
dec_so_far = bytes()
for byte_num in range(start_len):
    if byte_num < blocksize:
        prepend = b'a' * (blocksize - byte_num - 1)
    else:
        prepend = dec_so_far[-blocksize + 1:]
    enc_chunks = util.make_chunks(oracle(prepend), blocksize)
    if byte_num < blocksize:
        for n, chunk in enumerate(enc_chunks[1:]):
            enc_block_at[byte_num + (n * blocksize)] = chunk
        want = enc_chunks[0]
Example #14
candidates = []

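# Estimate the repeating-XOR key size: average the pairwise Hamming distances between a few
# ciphertext blocks, normalized by the candidate size; the smallest scores are the most likely key sizes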
for keysize in range(2, 40):
    a = enc_b[:keysize]
    b = enc_b[keysize:keysize*2]
    c = enc_b[keysize*2:keysize*3]
    d = enc_b[keysize*3:keysize*4]
    score = sum(util.hamming_dist(x, y) for x, y in combinations([a, b, c, d], 2))
    candidates.append((keysize, score / keysize))

possible_sizes = [x[0] for x in sorted(candidates, key=lambda x: x[1])[:3]]

possible_keys = []
for keysize in possible_sizes:
    chunks = util.make_chunks(enc_b, keysize)

    key_chars = []
    key_score = 0

    # zip_longest transposes the list of lists
    for block in zip_longest(*chunks):
        candidates = []
        for key_i in range(256):
            try:
                candidates.append((key_i, util.b2s(util.single_byte_xor(util.safe_bytes(block), key_i))))
            except UnicodeDecodeError:
                pass
        key, pt = max(candidates, key=lambda x: util.score_plaintext(x[1]))
        key_chars.append(key)
        key_score += util.score_plaintext(pt)
Example #15
import util
import binascii

f = open('8.txt')
texts = []
for l in f:
    texts.append(util.a2b(l.strip()))
f.close()

# break the ciphertext into 16 byte chunks, and pick the one with the most
# repeated.  We do this by putting the chunks into a set, and seeing how big
# the set is.
sizes = [(t_b, len(set(tuple(chunk) for chunk in util.make_chunks(t_b, 16))))
         for t_b in texts]

winner = min(sizes, key=lambda x: x[1])[0]

print(util.b2a(winner).decode('utf-8'))
Example #16
def main(config):
    config_dict = vars(config)

    wandb.login()
    wandb.init(project=f'{config.mode}_{config.task}')
    wb_config = wandb.config
    wb_config.learning_rate = config.learning_rate

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s %(levelname)s %(module)s - %(funcName)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    tokenizer = transformers.BertTokenizerFast.from_pretrained(config.bert_model)

    # changes of config_dict will change the config itself as well
    config_dict['logger'] = logger
    config_dict['date'] = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    config_dict['cls_token_id'] = tokenizer.encode(tokenizer.cls_token, add_special_tokens=False)[0]
    config_dict['pad_token_id'] = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0]
    if not config.task:
        config_dict['task'] = 'ICT+BFS'

    logger.info("GIVEN ARGUMENTS")
    for k, v in config_dict.items():
        logger.info(f"{k}: {v}")

    model, optimizer, metrics = instantiate_model(config, tokenizer)  #prepare_model(config)
    loss_fn = torch.nn.CrossEntropyLoss()

    if metrics:
        logging.info(f"Metrics initialized with pretrained model metrics:\n{metrics}")
    metrics = eval.Metrics(metrics)

    logger.info(f'Loading and chunking articles from {config.articles_path} database...')
    if config.mode == 'finetuning':  ### FINETUNING
        doc_chunks = util.make_chunks(config.articles_path, tokenizer, config, save_chunks=True)
        articles_ids = util.get_par_ids(doc_chunks) if 'CTK' in config.articles_path else list(doc_chunks.keys())

        doc_chunks, chunks_mask = util.process_chunks(doc_chunks, config)
        dev_chunks, dev_chunks_mask, dev_articles_ids = doc_chunks, chunks_mask, articles_ids

        logger.info('Loading train/dev claims...')
        claims_dev, evidence_dev, labels_dev = util.load_claims('dev', config)
        claims_train, evidence_train, labels_train = util.load_claims('train', config)
        logger.info('Removing unverifiable claims...')
        claims_train, evidence_train, labels_train = util.remove_unverifiable_claims(claims_train,
                                                                                    evidence_train,
                                                                                    labels_train, config)
        claims_dev, evidence_dev, labels_dev = util.remove_unverifiable_claims(claims_dev,
                                                                        evidence_dev,
                                                                        labels_dev, config)
        
        logger.info('Removing claims for which the evidence containing document could not be found...')        
        claims_train, \
        evidence_train, \
        labels_train = util.remove_invalid_claims(claims_train, evidence_train, labels_train, articles_ids, config)

        claims_dev, \
        evidence_dev, \
        labels_dev = util.remove_invalid_claims(claims_dev, evidence_dev, labels_dev, articles_ids, config)

        logger.info('Tokenizing, padding and masking the claims...')
        claims_train, claims_train_mask = util.process_claims(claims_train, tokenizer, config, _pad_max=True)
        claims_dev, claims_dev_mask = util.process_claims(claims_dev, tokenizer, config, _pad_max=True)
        logger.info(f"{len(claims_train)} training claims and {len(claims_dev)} dev claims prepared for finetuning.")
        # if 'CTK' in config.articles_path:
    else:  ### PRETRAINING
        # For pretraining there is no need to load claims - the claims are extracted from the chunks.
        doc_chunks = util.make_chunks(config.articles_path, tokenizer, config, save_chunks=True)
        # if config.task.upper() == 'ICT':
        #     doc_chunks = [chunk for doc, chunks in doc_chunks.items() for chunk in chunks] 

        dev_chunks = util.make_chunks("/mnt/data/factcheck/fever/data-cs/fever/fever.db", 
                                        tokenizer, config, as_eval=True, save_chunks=True)
        dev_articles_ids = list(dev_chunks.keys())
        dev_chunks, dev_chunks_mask = util.process_chunks(dev_chunks, config)

        logger.info('Loading dev claims...')
        claims_dev, evidence_dev, labels_dev = util.load_claims('dev', config,
                                                     path='/mnt/data/factcheck/fever/data-cs/fever-data/dev.jsonl')
        logger.info('Tokenizing, padding and masking the claims...')
        
        claims_dev, evidence_dev, labels_dev = util.remove_unverifiable_claims(claims_dev,
                                                                        evidence_dev,
                                                                        labels_dev, config)
        claims_dev, \
        evidence_dev, \
        labels_dev = util.remove_invalid_claims(claims_dev, evidence_dev, labels_dev, dev_articles_ids, config)
        claims_dev, claims_dev_mask = util.process_claims(claims_dev, tokenizer, config, _pad_max=True)

    id2doc = {i: doc_id for i, (doc_id, _) in enumerate(doc_chunks.items())} if isinstance(doc_chunks, dict) else []
    
    loader = (get_loader(torch.tensor([i for i in range(len(claims_train))]), config.bs) if config.mode == 'finetuning'
              else get_loader(torch.tensor([i for i in range(len(doc_chunks))]), config.bs))


    # Evaluation
    logger.info("Initial evaluation check...")
    def get_sample_keys(keys: list, sample=0.3):
        return random.sample(keys, round(len(keys) * sample))

    def get_subset(d: dict, keys: list):
        return {k: d[k] for k in keys}
    sample_keys = get_sample_keys(list(dev_chunks.keys()), 0.3)
    
    # # eval_claim_embed, eval_doc_embed = eval.evaluation_preprocessing(claims_dev, claims_dev_mask, 
    # #                                                                 dev_chunks, dev_chunks_mask, model, config)
    # eval_claim_embed, eval_doc_embed = eval.evaluation_preprocessing(claims_dev, claims_dev_mask, 
    #                                                                 get_subset(dev_chunks, sample_keys), 
    #                                                                 get_subset(dev_chunks_mask, sample_keys), 
    #                                                                 model, config)
    # logger.info("FAISS retrieval")
    # precision, recall, f1, mrr = eval.retriever_score(eval_doc_embed, dev_articles_ids, eval_claim_embed, 
    #                                                 evidence_dev, labels_dev, config, k=20)
    # logger.info(f"F1: {f1}\tPrecision@{20}: {precision}\tRecall@{20}: {recall}\tMRR@{20}: {mrr}")


    logger.info("Training...")
    epoch_num = -1
    wandb.watch(model)
    for epoch_num in range(config.epoch):
        model.train()
        batch_num = len(loader)
        num_training_examples, running_loss = 0, 0.0
        for batch in tqdm(loader, total=batch_num):
            optimizer.zero_grad()
            batch = batch[0]
            num_training_examples += batch.size(0)
            if config.mode == 'finetuning':
                query, query_mask, \
                context, context_mask = util.get_finetuning_batch(batch, claims_train, claims_train_mask, evidence_train,
                                                                doc_chunks, chunks_mask, articles_ids, config)
            else:
                query, query_mask, \
                context, context_mask = util.get_pretraining_batch(ids2docs(batch, id2doc), doc_chunks, 
                                                                            tokenizer, config)

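            # In-batch negatives: claim i should score highest against context i, so the target class for row i is i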
            query_cls_out = model(x=query, x_mask=query_mask)
            context_cls_out = model(x=context, x_mask=context_mask)
            logit = torch.matmul(query_cls_out, context_cls_out.transpose(-2, -1))
            correct_class = torch.tensor([i for i in range(len(query))]).long().to(config.device)
            loss = loss_fn(logit, correct_class)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * batch.size(0)
            epoch_avg_loss = running_loss / num_training_examples

            wandb.log({"epoch_avg_loss": epoch_avg_loss,
                       "running_loss": running_loss,
                       })

        logger.info(f"{epoch_num} epoch, train loss : {round(epoch_avg_loss, 3)}")

        # Periodic backup save of the model
        metrics.update_loss(epoch_avg_loss, epoch_num)
        if (epoch_num + 1) % 5 == 0:
            logger.info(f"Saving backup of the model after {epoch_num+1} epochs")
            io_util.save_model(model, optimizer, metrics, f'{config.model_weight}_{epoch_num+1}')

        # Evaluation
        # if (epoch_num + 1) % 5 == 0:            
        #     logger.info("Evaluation...")
        #     # eval_claim_embed, eval_doc_embed = eval.evaluation_preprocessing(claims_dev, claims_dev_mask, 
        #     #                                                                 dev_chunks, dev_chunks_mask, model, config)

        #     eval_claim_embed, eval_doc_embed = eval.evaluation_preprocessing(claims_dev, claims_dev_mask, 
        #                                                                 get_subset(dev_chunks, sample_keys), 
        #                                                                 get_subset(dev_chunks_mask, sample_keys), 
        #                                                                 model, config)

        #     precision, recall, f1, mrr = eval.retriever_score(eval_doc_embed, dev_articles_ids, eval_claim_embed, 
        #                                                     evidence_dev, labels_dev, config, k=20)

        #     logger.info(f"F1: {f1}\tPrecision@{20}: {precision}\tRecall@{20}: {recall}\tMRR@{20}: {mrr}")
            # metrics.update_metrics(f1, precision, recall, mrr)
            # metrics.update_loss(epoch_avg_loss, epoch_num)
        #     if recall > metrics.max_rec:
        #         metrics.update_best(f1, precision, recall, mrr)
        #         io_util.save_model(model, optimizer, metrics, config.model_weight + '_best')
        #         logger.info(f"Model saved. Metrics:\n{metrics}")

    metrics.save(f'metrics/{eval.get_model_name_for_metrics(config)}')
    logger.info("Done training")
Example #17
    profile = b'email=' + email_b + b'&uid=10&role=user'

    return util.aes_ecb_enc(util.pkcs7pad(profile, 16), key_b)

def profile_decrypt(enc_b):
    return cookie_parse(util.aes_ecb_dec(enc_b, key_b))

# figure out the block size
start_len = len(profile_for(bytes()))
acc_bytes = b'a'
while True:
    new_len = len(profile_for(acc_bytes))
    if new_len != start_len:
        blocksize = new_len - start_len
        break
    else:
        acc_bytes += b'a'
print('Detected block size {}'.format(blocksize))

email_length = len(acc_bytes) + len(b'user')

our_email = b'a' * email_length
start_ciphertext = profile_for(our_email)[:-blocksize]

clear_len = len(acc_bytes) + 1
malicious_text = util.pkcs7pad(b'admin', blocksize)
out = profile_for((b'a' * clear_len) + malicious_text)
evil = util.make_chunks(out, blocksize)[1]

print(profile_decrypt(start_ciphertext + evil))