Example #1
def clean_exit(error_code, message, kill_generator=False):
    """ Performs a clean exit, useful for when errors happen that can't be recovered from."""
    logger.log_critical(module_name, message)
    if kill_generator:
        generator.stop_generator(2)
    logger.log_stop()
    sys.exit(error_code)
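
For context, a minimal usage sketch: a caller that hits an unrecoverable failure logs it and shuts everything down. The EXIT_RUNTIME_ERROR constant and the failing load_weights() call are borrowed from Example #6 below and are assumptions here, not part of this example.

# Usage sketch (assumed caller): bail out cleanly if the weights cannot load.
try:
    model.load_weights(options.save_weights)
except Exception:
    clean_exit(EXIT_RUNTIME_ERROR,
               'Failed to load LSTM weights:\n' + str(traceback.format_exc()),
               kill_generator=True)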
Example #2
def eval_worker_loop(temp_dir, sample):
    o_filename = sample['label'] + '-' + path.basename(
        sample['base_dir']) + '.gz'
    o_filepath = path.join(temp_dir, o_filename)
    logger.log_debug(module_name, 'Writing to ' + o_filepath)
    with gzip.open(o_filepath, 'wt') as ofile:
        if options.preprocess:
            gen_func = reader.read_preprocessed
        else:
            gen_func = reader.disasm_pt_file

        iqueue, oqueue = generator.start_generator(1, gen_func,
                                                   options.queue_size,
                                                   options.seq_len, redis_info)

        if options.preprocess:
            iqueue.put((None, sample['parsed_filepath']))
        else:
            sample_memory = reader.read_memory_file(sample['mapping_filepath'])
            if sample_memory is None:
                logger.log_warning(module_name,
                                   'Failed to parse memory file, skipping')
                generator.stop_generator(10)
                return
            iqueue.put(
                (None, sample['trace_filepath'], bin_dirpath, sample_memory))

        while True:
            try:
                res = oqueue.get(True, 5)
            except queue.Empty:
                in_service = generator.get_in_service()
                if in_service == 0:
                    break
                else:
                    logger.log_debug(
                        module_name,
                        str(in_service) + ' workers still working on jobs')
                    continue

            xs = res[1][1:]
            ys = res[1][0] % options.max_classes

            predict, conf = predict_prob(xs, ys)
            corr = int(predict == ys)
            ofile.write(
                str(corr) + ',' + str(predict) + ',' + str(conf) + ',' +
                str(ys) + "\n")

        generator.stop_generator(10)
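
A small driver sketch for the function above, assuming an eval_set list of sample dicts like those used in the later examples; the scratch-directory name is also an assumption.

# Driver sketch (assumed): make a scratch directory and evaluate each sample.
import tempfile

temp_dir = tempfile.mkdtemp(suffix='-lstm-pt')   # assumed location
for sample in eval_set:                          # eval_set is assumed
    eval_worker_loop(temp_dir, sample)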
Example #3
def map_to_model(samples, f):
    """ A helper function because train_on_batch() and test_on_batch() are so similar."""
    random.shuffle(samples)
    # There's no point spinning up more worker threads than there are samples
    threads = min(options.threads, len(samples))

    if options.preprocess:
        gen_func = reader.read_preprocessed
    else:
        gen_func = reader.disasm_pt_file

    # When you gonna fire it up? When you gonna fire it up?
    iqueue, oqueue = generator.start_generator(threads, gen_func, options.queue_size, options.seq_len,
                                               options.embedding_in_dim, options.max_classes, options.batch_size)

    for sample in samples:
        if options.preprocess:
            iqueue.put((None, sample['parsed_filepath']))
        else:
            sample_memory = reader.read_memory_file(sample['mapping_filepath'])
            if sample_memory is None:
                logger.log_warning(module_name, 'Failed to parse memory file, skipping')
                continue
            iqueue.put((None, sample['trace_filepath'], options.bin_dir, sample_memory, options.timeout))

    # Get parsed sequences and feed them to the LSTM model
    batch_cnt = 0
    while True:
        try:
            res = oqueue.get(True, 5)
        except queue.Empty:
            in_service = generator.get_in_service()
            if in_service == 0:
                break
            else:
                logger.log_debug(module_name, str(in_service) + ' workers still working on jobs')
                continue

        yield f(res[1], res[2])
        batch_cnt += 1

    logger.log_info(module_name, "Processed " + str(batch_cnt) + " batches, " + str(batch_cnt * options.batch_size) + " samples")

    generator.stop_generator(10)
    # End of generator
    while True:
        yield None
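
A consumption sketch for this generator, mirroring how Example #6 uses it; training_set and a compiled Keras model are assumed.

# Consumption sketch: each yield is one batch's metrics list; None means the
# parser workers are done and the output queue has drained.
res = [0.0] * len(model.metrics_names)
batches = 0
for status in map_to_model(training_set, model.train_on_batch):
    if status is None:
        break
    res = [r + s for r, s in zip(res, status)]
    batches += 1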
Example #4
def map_to_model(samples, f):
    """ A helper function because train_on_batch() and test_on_batch() are so similar."""
    global redis_info
    global oqueue

    random.shuffle(samples)
    # There's no point spinning up more worker threads than there are samples
    threads = min(options.threads, len(samples))

    if options.preprocess:
        gen_func = reader.read_preprocessed
    else:
        gen_func = reader.disasm_pt_file

    # When you gonna fire it up? When you gonna fire it up?
    iqueue, oqueue = generator.start_generator(threads, gen_func,
                                               options.queue_size,
                                               options.seq_len, redis_info)

    for sample in samples:
        if options.preprocess:
            iqueue.put((None, sample['parsed_filepath']))
        else:
            sample_memory = reader.read_memory_file(sample['mapping_filepath'])
            if sample_memory is None:
                logger.log_warning(module_name,
                                   'Failed to parse memory file, skipping')
                continue
            iqueue.put(
                (None, sample['trace_filepath'], bin_dirpath, sample_memory))

    ncpu = cpu_count()
    workers = Pool(ncpu)
    res = workers.map(worker_loop, [f] * ncpu)

    generator.stop_generator(10)

    return sum(res) / len(res)
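
worker_loop is not shown in this example; below is a hypothetical sketch of what each pooled worker might do, under the assumption that the forked children inherit the module-level oqueue and that f returns a metrics list with the loss first (as in Example #3).

# Hypothetical worker_loop sketch: drain the shared output queue, apply f to
# every batch, and return this worker's average loss so the caller can
# average the per-worker averages.
def worker_loop(f):
    total, count = 0.0, 0
    while True:
        try:
            res = oqueue.get(True, 5)
        except queue.Empty:
            if generator.get_in_service() == 0:
                break
            continue
        total += f(res[1], res[2])[0]   # assumed: loss is the first metric
        count += 1
    return total / max(count, 1)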
Example #5
def eval_model(eval_set):
    """ Evaluate the LSTM model."""
    random.shuffle(eval_set)
    # There's no point spinning up more worker threads than there are samples
    threads = min(options.threads, len(eval_set))

    if options.eval_dir is None:
        eval_dir = tempfile.mkdtemp(suffix='-lstm-pt')
    else:
        if not path.exists(options.eval_dir):
            mkdir(options.eval_dir)
        eval_dir = options.eval_dir
    logger.log_info(module_name, 'Evaluation results will be written to ' + eval_dir)

    if options.preprocess:
        gen_func = reader.read_preprocessed
    else:
        gen_func = reader.disasm_pt_file

    iqueue, oqueue = generator.start_generator(threads, gen_func, options.queue_size, options.seq_len,
                                               options.embedding_in_dim, options.max_classes, options.batch_size)

    for sample in eval_set:
        o_filename = sample['label'] + '-' + path.basename(sample['base_dir']) + '.gz'
        o_filepath = path.join(eval_dir, o_filename)
        EVAL_WRITE_LOCKS[o_filepath] = threading.Lock()
        if options.preprocess:
            iqueue.put((o_filepath, sample['parsed_filepath']))
        else:
            sample_memory = reader.read_memory_file(sample['mapping_filepath'])
            if sample_memory is None:
                logger.log_warning(module_name, 'Failed to parse memory file, skipping')
                continue
            iqueue.put((o_filepath, sample['trace_filepath'], options.bin_dir, sample_memory, options.timeout))

    # Use threads instead of processes to handle and write the prediction
    # results because I/O and numpy number crunching release the GIL.
    EVAL_PRED_DONE.clear()
    wqueue = queue.Queue(options.queue_size)
    workers = list()
    for id in range(threads):
        worker = threading.Thread(target=eval_worker, args=(wqueue,))
        worker.daemon = True
        worker.start()
        workers.append(worker)

    while True:
        try:
            res = oqueue.get(True, 5)
        except queue.Empty:
            in_service = generator.get_in_service()
            if in_service == 0:
                break
            else:
                logger.log_debug(module_name, str(in_service) + ' workers still working on jobs')
                continue

        wqueue.put([res, model.predict_on_batch(res[1])])

    EVAL_PRED_DONE.set()
    logger.log_debug(module_name, "Waiting for eval workers to terminate")
    for worker in workers:
        worker.join()
    logger.log_debug(module_name, "All eval workers are done")
    generator.stop_generator(10)
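
eval_worker is referenced but not shown; here is a hypothetical sketch of what it might look like, assuming res[0] carries the output path queued above and that predictions are written one class index per line.

# Hypothetical eval_worker sketch: drain wqueue until EVAL_PRED_DONE is set
# and nothing is left, appending each batch's predictions to its sample's
# gzip file under the matching EVAL_WRITE_LOCKS lock.
def eval_worker(wqueue):
    while True:
        try:
            res, preds = wqueue.get(True, 5)
        except queue.Empty:
            if EVAL_PRED_DONE.is_set():
                break
            continue
        o_filepath = res[0]                      # assumed: path queued above
        with EVAL_WRITE_LOCKS[o_filepath]:
            with gzip.open(o_filepath, 'at') as ofile:
                for pred in preds:
                    ofile.write(str(pred.argmax()) + "\n")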
Example #6
def train_model(training_set):
    """ Trains the LSTM model."""
    start_time = datetime.now()
    # Checkpointing for saving model weights
    freq_c = options.checkpoint_interval * 60
    last_c = datetime.now()
    last_b = 10000
    # For reporting current metrics
    freq_s = options.status_interval * 60
    last_s = datetime.now()

    res = [0.0] * len(model.metrics_names)
    batches = 0
    num_samples = len(training_set)
    for status in map_to_model(training_set, model.train_on_batch):
        if status is None:
            break
        for stat in range(len(status)):
            res[stat] += status[stat]
        batches += 1
        # Print current metrics every minute
        if (datetime.now() - last_s).total_seconds() > freq_s:
            c_metrics = [status / batches for status in res]
            c_metrics_str = ', '.join([str(model.metrics_names[x]) + ' ' + '%.12f' % (c_metrics[x]) for x in range(len(c_metrics))])
            c_metrics_str += ', progress %.4f' % (float(generator.fin_tasks.value) / float(num_samples))
            logger.log_info(module_name, 'Status: ' + c_metrics_str)
            last_s = datetime.now()
        # Save current weights at user specified frequency
        if freq_c > 0 and (datetime.now() - last_c).total_seconds() > freq_c:
            logger.log_debug(module_name, 'Checkpointing weights')
            c_metrics = [status / batches for status in res]
            if not options.checkpoint_best or c_metrics[0] < last_b:
                try:
                    model.save_weights(options.save_weights)
                    if options.multi_gpu is not None:
                        template.save_weights(options.save_weights + '.single')
                except Exception:
                    generator.stop_generator(10)
                    clean_exit(EXIT_RUNTIME_ERROR, "Failed to save LSTM weights:\n" + str(traceback.format_exc()))
            if options.checkpoint_es and c_metrics[0] > last_b:
                logger.log_info(module_name, 'Loss did not improve between checkpoints, early stopping and restoring last weights')
                generator.stop_generator(10)
                try:
                    model.load_weights(options.save_weights)
                except Exception:
                    clean_exit(EXIT_RUNTIME_ERROR, "Failed to load LSTM weights:\n" + str(traceback.format_exc()))
                return
            last_b = c_metrics[0]
            last_c = datetime.now()

    if batches < 1:
        logger.log_warning(module_name, 'Training set did not generate a full batch of data, cannot train')
        return

    for stat in range(len(res)):
        res[stat] /= batches

    logger.log_info(module_name, 'Results: ' + ', '.join([str(model.metrics_names[x]) + ' ' + str(res[x]) for x in range(len(res))]))
    logger.log_debug(module_name, 'Training finished in ' + str(datetime.now() - start_time))

    return res[0] # Average Loss
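
A caller sketch, assuming an options.epochs setting: run train_model() once per epoch and stop when the average loss no longer improves (the function returns None when it early-stops or cannot form a batch).

# Epoch-loop sketch (assumed driver).
best_loss = None
for epoch in range(options.epochs):
    loss = train_model(training_set)
    if loss is None:
        break
    if best_loss is not None and loss >= best_loss:
        break
    best_loss = loss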