Example #1
def eval_worker_loop(temp_dir, sample):
    o_filename = sample['label'] + '-' + path.basename(
        sample['base_dir']) + '.gz'
    o_filepath = path.join(temp_dir, o_filename)
    logger.log_debug(module_name, 'Writing to ' + o_filepath)
    with gzip.open(o_filepath, 'wt') as ofile:
        if options.preprocess:
            gen_func = reader.read_preprocessed
        else:
            gen_func = reader.disasm_pt_file

        iqueue, oqueue = generator.start_generator(1, gen_func,
                                                   options.queue_size,
                                                   options.seq_len, redis_info)

        if options.preprocess:
            iqueue.put((None, sample['parsed_filepath']))
        else:
            sample_memory = reader.read_memory_file(sample['mapping_filepath'])
            if sample_memory is None:
                logger.log_warning(module_name,
                                   'Failed to parse memory file, skipping')
                generator.stop_generator(10)
                return
            iqueue.put(
                (None, sample['trace_filepath'], bin_dirpath, sample_memory))

        while True:
            try:
                res = oqueue.get(True, 5)
            except queue.Empty:
                in_service = generator.get_in_service()
                if in_service == 0:
                    break
                else:
                    logger.log_debug(
                        module_name,
                        str(in_service) + ' workers still working on jobs')
                    continue

            xs = res[1][1:]
            ys = res[1][0] % options.max_classes

            predict, conf = predict_prob(xs, ys)
            corr = int(predict == ys)
            ofile.write(
                str(corr) + ',' + str(predict) + ',' + str(conf) + ',' +
                str(ys) + "\n")

        generator.stop_generator(10)
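predict_prob is defined elsewhere in the module; a minimal sketch of the shape it plausibly has, assuming a Keras-style model global and numpy (both are assumptions, not taken from the source):

import numpy as np

def predict_prob(xs, ys):
    # Hypothetical: predict a class for one sequence and report the model's
    # confidence in it. ys mirrors the call site but is not needed to predict.
    probs = model.predict(np.array([xs]))[0]
    label = int(np.argmax(probs))
    return label, float(probs[label])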
Example #2
def test_generator():
    from sys import argv, exit
    import reader
    import tempfile

    if len(argv) < 5:
        print(argv[0], '<input_file>', '<bin_dir>', '<memory_file>',
              '<seq_len>')
        exit(0)

    logger.log_start(logging.DEBUG)

    try:
        ofile = tempfile.mkstemp(text=True)
        ofilefd = os.fdopen(ofile[0], 'w')

        filters.set_filters(['ret'])
        memory = reader.read_memory_file(argv[3])

        iqueue, oqueue = start_generator(2,
                                         reader.disasm_pt_file,
                                         seq_len=int(argv[4], 10))
        iqueue.put((None, argv[1], argv[2], memory))
        while True:
            try:
                res = oqueue.get(True, 5)
            except queue.Empty:
                count = get_in_service()
                if count == 0:
                    break
                else:
                    logger.log_debug(
                        module_name,
                        str(count) + ' workers still working on jobs')
                    continue
            ofilefd.write(str(res[0]) + ": " + str(res[1]) + "\n")

        stop_generator(10)
        ofilefd.close()
    except Exception:
        traceback.print_exc()
        ofilefd.close()
        os.remove(ofile[1])
        logger.log_stop()
        exit(1)

    logger.log_info(module_name, 'Wrote generated tuples to ' + str(ofile[1]))
    logger.log_stop()
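For reference, a hypothetical invocation of this self-test (the script name is an assumption):

python generator.py trace_0.gz ./bins mapping.txt.gz 32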
Example #3
def map_to_model(samples, f):
    """ A helper function because train_on_batch() and test_on_batch() are so similar."""
    random.shuffle(samples)
    # There's no point spinning up more worker threads than there are samples
    threads = min(options.threads, len(samples))

    if options.preprocess:
        gen_func = reader.read_preprocessed
    else:
        gen_func = reader.disasm_pt_file

    # When you gonna fire it up? When you gonna fire it up?
    iqueue, oqueue = generator.start_generator(threads, gen_func, options.queue_size, options.seq_len,
                                               options.embedding_in_dim, options.max_classes, options.batch_size)

    for sample in samples:
        if options.preprocess:
            iqueue.put((None, sample['parsed_filepath']))
        else:
            sample_memory = reader.read_memory_file(sample['mapping_filepath'])
            if sample_memory is None:
                logger.log_warning(module_name, 'Failed to parse memory file, skipping')
                continue
            iqueue.put((None, sample['trace_filepath'], options.bin_dir, sample_memory, options.timeout))

    # Get parsed sequences and feed them to the LSTM model
    batch_cnt = 0
    while True:
        try:
            res = oqueue.get(True, 5)
        except queue.Empty:
            in_service = generator.get_in_service()
            if in_service == 0:
                break
            else:
                logger.log_debug(module_name, str(in_service) + ' workers still working on jobs')
                continue

        yield f(res[1], res[2])
        batch_cnt += 1

    logger.log_info(module_name, "Processed " + str(batch_cnt) + " batches, " + str(batch_cnt * options.batch_size) + " samples")

    generator.stop_generator(10)
    # End of generator
    while True:
        yield None
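Because the function keeps yielding None after stop_generator(), callers can treat None as an end-of-data sentinel. A minimal consumption sketch, assuming f is something like a Keras model.train_on_batch and train_samples is a prepared sample list (both assumptions):

losses = []
for result in map_to_model(train_samples, model.train_on_batch):
    if result is None:
        break  # the trailing 'yield None' marks end of data
    losses.append(result)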
Example #4
def map_to_model(samples, f):
    """ A helper function because train_on_batch() and test_on_batch() are so similar."""
    global redis_info
    global oqueue

    random.shuffle(samples)
    # There's no point spinning up more worker threads than there are samples
    threads = min(options.threads, len(samples))

    if options.preprocess:
        gen_func = reader.read_preprocessed
    else:
        gen_func = reader.disasm_pt_file

    # When you gonna fire it up? When you gonna fire it up?
    iqueue, oqueue = generator.start_generator(threads, gen_func,
                                               options.queue_size,
                                               options.seq_len, redis_info)

    for sample in samples:
        if options.preprocess:
            iqueue.put((None, sample['parsed_filepath']))
        else:
            sample_memory = reader.read_memory_file(sample['mapping_filepath'])
            if sample_memory is None:
                logger.log_warning(module_name,
                                   'Failed to parse memory file, skipping')
                continue
            iqueue.put(
                (None, sample['trace_filepath'], bin_dirpath, sample_memory))

    ncpu = cpu_count()
    with Pool(ncpu) as workers:
        res = workers.map(worker_loop, [f] * ncpu)

    generator.stop_generator(10)

    return sum(res) / len(res)
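worker_loop is not shown in this example; a hypothetical sketch consistent with the surrounding code, assuming each pool worker drains the shared oqueue, applies f to each batch, and returns an average metric:

def worker_loop(f):
    # Hypothetical: consume batches until the generator reports no jobs
    # in flight, then return the mean of f over the batches seen.
    total, count = 0.0, 0
    while True:
        try:
            res = oqueue.get(True, 5)
        except queue.Empty:
            if generator.get_in_service() == 0:
                break
            continue
        total += f(res[1], res[2])
        count += 1
    return total / count if count else 0.0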
Example #5
def eval_model(eval_set):
    """ Evaluate the LSTM model."""
    random.shuffle(eval_set)
    # There's no point spinning up more worker threads than there are samples
    threads = min(options.threads, len(eval_set))

    if options.eval_dir is None:
        eval_dir = tempfile.mkdtemp(suffix='-lstm-pt')
    else:
        if not path.exists(options.eval_dir):
            mkdir(options.eval_dir)
        eval_dir = options.eval_dir
    logger.log_info(module_name, 'Evaluation results will be written to ' + eval_dir)

    if options.preprocess:
        gen_func = reader.read_preprocessed
    else:
        gen_func = reader.disasm_pt_file

    iqueue, oqueue = generator.start_generator(threads, gen_func, options.queue_size, options.seq_len,
                                               options.embedding_in_dim, options.max_classes, options.batch_size)

    for sample in eval_set:
        o_filename = sample['label'] + '-' + path.basename(sample['base_dir']) + '.gz'
        o_filepath = path.join(eval_dir, o_filename)
        EVAL_WRITE_LOCKS[o_filepath] = threading.Lock()
        if options.preprocess:
            iqueue.put((o_filepath, sample['parsed_filepath']))
        else:
            sample_memory = reader.read_memory_file(sample['mapping_filepath'])
            if sample_memory is None:
                logger.log_warning(module_name, 'Failed to parse memory file, skipping')
                continue
            iqueue.put((o_filepath, sample['trace_filepath'], options.bin_dir, sample_memory, options.timeout))

    # Use threads instead of processes to handle and write the prediction
    # results because I/O and numpy crunching do not require the GIL.
    EVAL_PRED_DONE.clear()
    wqueue = queue.Queue(options.queue_size)
    workers = list()
    for _ in range(threads):
        worker = threading.Thread(target=eval_worker, args=(wqueue,))
        worker.daemon = True
        worker.start()
        workers.append(worker)

    while True:
        try:
            res = oqueue.get(True, 5)
        except queue.Empty:
            in_service = generator.get_in_service()
            if in_service == 0:
                break
            else:
                logger.log_debug(module_name, str(in_service) + ' workers still working on jobs')
                continue

        wqueue.put([res, model.predict_on_batch(res[1])])

    EVAL_PRED_DONE.set()
    logger.log_debug(module_name, "Waiting for eval workers to terminate")
    for worker in workers:
        worker.join()
    logger.log_debug(module_name, "All eval workers are done")
    generator.stop_generator(10)
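eval_worker is likewise defined elsewhere; a minimal sketch of its likely structure, assuming each queued item is [res, predictions] where res[0] is the output path registered in EVAL_WRITE_LOCKS (assumptions, not from the source):

def eval_worker(wqueue):
    # Hypothetical: drain prediction jobs until the main thread signals
    # that no more will arrive and the queue has emptied.
    while True:
        try:
            res, predictions = wqueue.get(True, 5)
        except queue.Empty:
            if EVAL_PRED_DONE.is_set():
                break
            continue
        o_filepath = res[0]
        with EVAL_WRITE_LOCKS[o_filepath]:
            with gzip.open(o_filepath, 'at') as ofile:
                for row in predictions:
                    ofile.write(str(int(row.argmax())) + '\n')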
Example #6
def main():
    # Parse input arguments
    parser = OptionParser(
        usage='Usage: %prog [options] trace_directory bin_directory')
    parser.add_option(
        '-f',
        '--force',
        action='store_true',
        help='If a complete or partial output already exists, overwrite it.')
    parser.add_option(
        '-t',
        '--timeout',
        action='store',
        type='int',
        default=None,
        help='Max seconds to run before quitting (default: infinite).')
    parser.add_option(
        '-p',
        '--no-partial',
        action='store_true',
        help='If timeout is reached, do not save the partially parsed trace.')
    options, args = parser.parse_args()

    if len(args) < 2:
        parser.print_help()
        sys.exit(0)

    data_dir = args[0]
    bin_dir = args[1]

    logger.log_start(logging.INFO)

    # Input validation
    if not os.path.isdir(data_dir):
        logger.log_error(module_name, data_dir + ' is not a directory')
        logger.log_stop()
        sys.exit(1)

    if not os.path.isdir(bin_dir):
        logger.log_error(module_name, bin_dir + ' is not a directory')
        logger.log_stop()
        sys.exit(1)

    if options.timeout is None and options.no_partial:
        logger.log_warning(
            module_name, "Setting --no-partial without --timeout does nothing")

    # Make sure all the expected files are there
    mem_file = None
    trace_file = None

    files = os.listdir(data_dir)
    for file in files:
        if file == 'mapping.txt' or file == 'mapping.txt.gz':
            mem_file = os.path.join(data_dir, file)
        elif file == 'trace_0' or file == 'trace_0.gz':
            trace_file = os.path.join(data_dir, file)

    if mem_file is None:
        logger.log_error(
            module_name,
            'Could not find mapping.txt or mapping.txt.gz in ' + data_dir)
        logger.log_stop()
        sys.exit(1)

    if trace_file is None:
        logger.log_error(module_name,
                         'Could not find trace_0 or trace_0.gz in ' + data_dir)
        logger.log_stop()
        sys.exit(1)

    # Parse the memory file
    mem_map = reader.read_memory_file(mem_file)
    if mem_map is None:
        logger.log_error(module_name, 'Failed to parse memory mapping file')
        logger.log_stop()
        sys.exit(1)

    # We're ready to parse the trace
    o_filepath = os.path.join(data_dir, 'trace_parsed.gz')

    if os.path.isfile(o_filepath) and not options.force:
        logger.log_error(module_name, 'Preprocess file already exists')
        logger.log_stop()
        sys.exit(1)

    if os.path.isfile(o_filepath + '.part') and not options.force:
        logger.log_error(module_name, 'Partial preprocess file already exists')
        logger.log_stop()
        sys.exit(1)

    entries = 0
    with gzip.open(o_filepath + '.part', 'wb') as ofile:
        for instr in reader.disasm_pt_file(trace_file, bin_dir, mem_map,
                                           options.timeout):
            if instr is None:
                break
            ofile.write(pack_instr(instr))
            entries += 1

    if reader.DISASM_TIMEOUT.is_set() and options.no_partial:
        logger.log_info(module_name, "Deleting partial trace")
        os.remove(o_filepath + '.part')
    elif entries > 0:
        os.rename(o_filepath + '.part', o_filepath)
    else:
        logger.log_error(module_name, 'No output produced, empty file')
        os.remove(o_filepath + '.part')

    logger.log_stop()
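A hypothetical invocation, assuming this script is saved as preprocess.py:

python preprocess.py --timeout 3600 --no-partial /data/traces/sample01 /data/bins

With --timeout set, --no-partial discards the .part file instead of keeping a truncated trace.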