Example #1
def main():
    args = parse_test_args()
    train_args = pickle.load(
        open(os.path.join(args.output_dir, 'train_args'), 'rb'))
    assert train_args.output_dir == args.output_dir
    args.__dict__.update(train_args.__dict__)
    init_logging(os.path.join(args.output_dir, 'log_test.txt'))
    logger.info("args: " + str(args))

    config_class, model_class, args.tokenizer_class = model_classes[
        args.model_type]
    model_config = config_class.from_pretrained(args.model_name,
                                                num_labels=args.n_labels,
                                                hidden_dropout_prob=0,
                                                attention_probs_dropout_prob=0)
    save_model_path = os.path.join(args.output_dir,
                                   'checkpoint-{}'.format(len(args.tasks) - 1))
    model = model_class.from_pretrained(save_model_path,
                                        config=model_config).cuda()

    avg_acc = 0
    for task_id, task in enumerate(args.tasks):
        logger.info("Start testing {}...".format(task))
        test_dataset = pickle.load(
            open(
                os.path.join(args.output_dir,
                             'test_dataset-{}'.format(task_id)), 'rb'))
        task_acc = test_task(task_id, args, model, test_dataset)
        avg_acc += task_acc / len(args.tasks)
    logger.info("Average acc: {:.3f}".format(avg_acc))
Example #2
    def __init__(self,
                 interface,
                 ports=None,
                 embedded_bro=True,
                 bro_home=None,
                 idx=1,
                 start_port=None,
                 bpf_filter=""):
        Driver.__init__(self)

        if ports is None:
            ports = [80, 81, 1080, 3128, 8000, 8080, 8888, 9001]
        self.embedded_bro = embedded_bro
        self.bro_home = get_bro_home(bro_home)
        self.interface = interface
        self.bpf_filter = bpf_filter
        self.logger = settings.init_logging('bro.{}'.format(idx))
        # a non-empty ports string from the sniffer config overrides the
        # constructor default
        self.ports = configcontainer.get_config("sniffer").get_string(
            "filter.traffic.server_ports", "") or ports
        self.ports = expand_ports(self.ports)
        self.idx = idx
        self.bro_port = start_port + idx
        self.last_netstat_ts = millis_now()
        self.sub_task = None
        self.client_task = None
        self.last_update = 0
        self.filtered_clients = []
        self.encrypt_keys = []
        self.encrypt_salt = ""
        self.ep = None
        self.sub = None
        self.ss = None
        self.data_mr = None
        self.error_mr = None
        self.running = False
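
expand_ports is not shown above; here is a hypothetical implementation consistent with the call, accepting either a ready-made list or a comma-separated string with ranges such as "8000-8010" (the range syntax is an assumption, not the project's documented format):

def expand_ports(ports):
    # Pass lists through untouched; parse strings into a flat list of ints.
    if isinstance(ports, list):
        return ports
    result = []
    for part in str(ports).split(","):
        part = part.strip()
        if "-" in part:
            lo, hi = part.split("-", 1)
            result.extend(range(int(lo), int(hi) + 1))
        elif part:
            result.append(int(part))
    return result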
Example #3
def main():
    args = parse_train_args()
    pickle.dump(args, open(os.path.join(args.output_dir, 'train_args'), 'wb'))
    init_logging(os.path.join(args.output_dir, 'log_train.txt'))
    logger.info("args: " + str(args))

    logger.info("Initializing main {} model".format(args.model_name))
    config_class, model_class, args.tokenizer_class = model_classes[args.model_type]
    tokenizer = args.tokenizer_class.from_pretrained(args.model_name)

    model_config = config_class.from_pretrained(args.model_name, num_labels=args.n_labels)
    config_save_path = os.path.join(args.output_dir, 'config')
    model_config.to_json_file(config_save_path)
    model = model_class.from_pretrained(args.model_name, config=model_config).cuda()
    memory = Memory(args)

    for task_id, task in enumerate(args.tasks):
        logger.info("Start parsing {} train data...".format(task))
        train_dataset = TextClassificationDataset(task, "train", args, tokenizer)

        if args.valid_ratio > 0:
            logger.info("Start parsing {} valid data...".format(task))
            valid_dataset = TextClassificationDataset(task, "valid", args, tokenizer)
        else:
            valid_dataset = None

        logger.info("Start training {}...".format(task))
        train_task(args, model, memory, train_dataset, valid_dataset)
        model_save_path = os.path.join(args.output_dir, 'checkpoint-{}'.format(task_id))
        torch.save(model.state_dict(), model_save_path)
        pickle.dump(memory, open(os.path.join(args.output_dir, 'memory-{}'.format(task_id)), 'wb'))

    del model
    memory.build_tree()

    for task_id, task in enumerate(args.tasks):
        logger.info("Start parsing {} test data...".format(task))
        test_dataset = TextClassificationDataset(task, "test", args, tokenizer)
        pickle.dump(test_dataset, open(os.path.join(args.output_dir, 'test_dataset-{}'.format(task_id)), 'wb'))
        logger.info("Start querying {}...".format(task))
        query_neighbors(task_id, args, memory, test_dataset)
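
Because the training loop above saves raw state dicts with torch.save, restoring a checkpoint later takes the usual two-step load. A minimal sketch, assuming the same model_classes registry and the config file written above (hypothetical code, not part of the project):

# Rebuild the architecture from the saved config, then load the weights.
model_config = config_class.from_json_file(os.path.join(args.output_dir, 'config'))
model = model_class(model_config)
model.load_state_dict(torch.load(os.path.join(args.output_dir, 'checkpoint-0')))
model.cuda()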
Example #4
    def __init__(self, name, maxsize=10000):
        # internal msg queue
        self.queue = gevent.queue.Queue(maxsize=maxsize)

        # number of messages dropped by the driver
        self.dropped_msgs = 0

        # number of recent consecutive errors caused by a full queue
        # TODO: make this counter thread-safe
        self.full_error_count = 0

        self.logger = settings.init_logging("sniffer.driver.{}".format(name))
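
The two counters above suggest a producer-side enqueue that drops messages instead of blocking when the queue is full. A hedged sketch of such a method (the name put and the reset-on-success rule are assumptions, not the project's actual code):

    def put(self, msg):
        # Non-blocking enqueue: record a drop rather than stall the driver.
        try:
            self.queue.put_nowait(msg)
            self.full_error_count = 0  # a success ends the error streak
        except gevent.queue.Full:
            self.dropped_msgs += 1
            self.full_error_count += 1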
Example #5
def main():
    args = parse_test_args()
    train_args = pickle.load(
        open(os.path.join(args.output_dir, 'train_args'), 'rb'))
    train_args.model_type = "bert-class"
    train_args.__dict__.update(args.__dict__)
    args = train_args
    init_logging(
        os.path.join(
            args.output_dir,
            args.test_log_filename.split(".")[0] + args.model_type + ".txt"))
    logger.info(f"args: {args}")

    config_class, model_class, args.tokenizer_class = MODEL_CLASSES[
        args.model_type]
    model_config = config_class.from_pretrained(args.model_name,
                                                num_labels=args.n_labels,
                                                hidden_dropout_prob=0,
                                                attention_probs_dropout_prob=0)
    args.dataset_order = [
        "dbpedia_csv", "yahoo_answers_csv", "ag_news_csv",
        "amazon_review_full_csv", "yelp_review_full_csv"
    ]
    save_model_path = os.path.join(args.output_dir,
                                   f'checkpoint-{len(args.dataset_order)-1}')
    model = model_class.from_pretrained(save_model_path,
                                        config=model_config).to(args.device)

    avg_acc = 0
    for dataset_id, dataset_name in enumerate(args.dataset_order):
        logger.info(f"Start testing {dataset_name}...")
        test_dataset = pickle.load(
            open(os.path.join(args.output_dir, f'test_dataset-{dataset_id}'),
                 'rb'))
        dataset_acc = test(dataset_id, args, model, test_dataset)
        avg_acc += dataset_acc / len(args.dataset_order)
    logger.info(f"Average acc: {avg_acc:.3f}")
Example #6
    def __init__(self, id, parser, driver, cpu=None, is_process=True):
        self.parser = parser
        self.driver = driver
        self.id = id

        self._running = False
        self._rpc_task = None
        self._events_task = None
        self._health_task = None

        self.queue = gevent.queue.Queue(maxsize=10000)
        self.cpu = cpu
        self.is_process = is_process

        self.logger = settings.init_logging("main.{}".format(self.id))

        self.error_mr = MetricsRecorder("sniffer.main.error")
        self.msg_mr = MetricsRecorder("sniffer.main.msg")
        self.event_mr = MetricsRecorder("sniffer.main.event")
        self.rpc_mr = MetricsRecorder("sniffer.main.rpc")
        self.main_mr = MetricsRecorder("sniffer.main.loop")

        self.urltree = URLTree()
Example #7
    w.addstr(y, 0, zheader)
    w.addnstr(y, zwidth + x, header.rjust(width), width)
    for i in range(len(counts)):
        w.addstr(y + 1 + i, 0, '%*d' % (zwidth, i))
        w.addstr(y + 1 + i, zwidth + x, '%*d' % (width, counts[i]))

    w.refresh()


def fatal(msg):
    sys.stderr.write(msg + '\n')
    sys.exit()

if __name__ == "__main__":
    settings.init_logging()

    u.setup()

    try:
        specfile = open(sys.argv[1])
    except IndexError:
        specfile = sys.stdin

    try:
        args = parse_yaml_args(specfile)
        db_validate(args)
    except RuntimeError as e:
        fatal(str(e))

    download(args['region'], args['layers'])
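
parse_yaml_args and db_validate are defined elsewhere; here is a hypothetical parse_yaml_args consistent with how args is used above (a YAML mapping providing at least 'region' and 'layers'):

import yaml

def parse_yaml_args(specfile):
    # A malformed spec is reported through the RuntimeError path that the
    # __main__ block above already handles.
    args = yaml.safe_load(specfile)
    if not isinstance(args, dict):
        raise RuntimeError("spec file must contain a YAML mapping")
    return args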
Example #8
import os
import time
import traceback

import gevent
import json
from json import loads
from json import dumps
import requests
from requests import post
from requests import get
from requests import put
from requests import delete
from settings import init_logging

logger = init_logging('nebula.produce')


class RequestsData(object):
    def __init__(self, url, data, cookies, method='get'):
        self.data = data
        self.url = url
        self.cookies = cookies
        self.method = method

    def request(self):
        m = dict(
            get=get,
            put=put,
            post=post,
            delete=delete,
Example #9
import os
import time
import traceback

import gevent
import json
from json import loads
from json import dumps
import requests
from requests import post
from requests import get
from requests import put
from requests import delete
import settings

logger = settings.init_logging('nebula.produce')


class RequestsData(object):
    def __init__(self, url, data, cookies, method='get'):
        self.data = data
        self.url = url
        self.cookies = cookies
        self.method = method

    def request(self):
        m = dict(
            get=get,
            put=put,
            post=post,
            delete=delete,
Example #10
# -*- coding: utf-8 -*-
import re
import os
import logging
import subprocess
import json
from threathunter_common.util import millis_now
from threathunter_common.event import Event

from ..parser import Parser, extract_common_properties, extract_http_log_event
from ..parserutil import extract_value_from_body, get_md5, get_json_obj
from ..msg import HttpMsg
import time
import importlib
import settings
logger = settings.init_logging("sniffer.parser.{}".format("defaultparser"))
"""
#demo

#Login event extractor

l_passwd_pattern = re.compile("(&|^)password=(.*?)($|&)")
l_name_pattern = re.compile("(&|^)account=(.*?)($|&)")


def extract_login_log_event(httpmsg):
    if not isinstance(httpmsg, HttpMsg):
        return
    if httpmsg.method != "POST":
        return
    if "users/login" not in httpmsg.uri:
Example #11
import json
import hashlib
import urlparse
import logging
import Cookie
from collections import Mapping
from IPy import IP

from complexconfig.configcontainer import configcontainer

from .bson.objectid import ObjectId
from .befilteredexception import BeFilteredException
from .path_normalizer import normalize_path
from settings import init_logging

logger = init_logging("sniffer.httpmsg")

sniffer_config = configcontainer.get_config("sniffer")

suffix_config = sniffer_config.item(
    key="filter.static.suffixes",
    caching=60,
    default={
        "gif", "png", "ico", "css", "js", "csv", "txt", "jpeg", "jpg", "woff",
        "ttf"
    },
    cb_load=lambda raw: set(raw.lower().split(",")) if raw else set())

filtered_hosts_config = sniffer_config.item(
    key="filter.traffic.domains",
    caching=60,
Example #12
import json
import hashlib
import urlparse
import logging
import Cookie
from collections import Mapping
from IPy import IP

from complexconfig.configcontainer import configcontainer

from .bson.objectid import ObjectId
from .befilteredexception import BeFilteredException
from .path_normalizer import normalize_path
import settings

logger = settings.init_logging("sniffer.httpmsg")

sniffer_config = configcontainer.get_config("sniffer")

suffix_config = sniffer_config.item(
    key="filter.static.suffixes",
    caching=60,
    default={
        "gif", "png", "ico", "css", "js", "csv", "txt", "jpeg", "jpg", "woff",
        "ttf"
    },
    cb_load=lambda raw: set(raw.lower().split(",")) if raw else set())

filtered_hosts_config = sniffer_config.item(
    key="filter.traffic.domains",
    caching=60,
Example #13
        logger.info("score: {}".format(score_dict))
        score_dicts.append(score_dict)

    with open(os.path.join(model_dir, "metrics.json"), "w") as f:
        json.dump(score_dicts, f)


if __name__ == '__main__':
    if args.n_gpus > 1:
        raise NotImplementedError("test can be run with only one gpu currently!")
    
    if args.model_name == "gpt2":
        args.fp32 = False  # always use fp16 in testing

    if not args.debug:
        logging.getLogger("pytorch_transformers").setLevel(logging.WARNING)
        logging.getLogger("pytorch_transformers.tokenization_utils").setLevel(logging.CRITICAL)
    init_logging(os.path.join(args.model_dir_root, 'log_test.txt'))
    logger.info('args = {}'.format(args))

    if args.seq_train_type == "multitask":
        test_one_to_many('_'.join(args.tasks))
    else:
        if args.unbound:
            TASK_DICT = lll_unbound_setting(split_size=args.unbound, data_type="test",test_target="origin")
            for task_load in args.splitted_tasks:
                test_one_to_many(task_load)
        else:
            for task_load in args.tasks:
                test_one_to_many(task_load)
Example #14
def run_preprocessing():
    init_logging(get_path_to_output(), "preprocessing.log")
    log.info("log something")
    return 3
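
None of the snippets on this page show init_logging itself, and its signature clearly varies across projects (a logger name, a log path, or a directory plus filename). A minimal sketch of the two-argument variant called above, assuming it attaches a file handler to the root logger (format and level are assumptions):

import logging
import os

def init_logging(log_dir, filename):
    # Hypothetical helper: log to <log_dir>/<filename> via the root logger.
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    handler = logging.FileHandler(os.path.join(log_dir, filename))
    handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(levelname)s - %(name)s - %(message)s'))
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    root.addHandler(handler)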
Example #15
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import atexit
import time
import logging
import gevent
import threading
from threathunter_common.util import run_in_thread, run_in_subprocess
from produce import Produce
import settings

logger = settings.init_logging('nebula.sniffer')


def get_parser(parser_name, parser_module):
    from nebula_sniffer import parser
    __import__("{}.{}".format("nebula_sniffer.customparsers", parser_module),
               globals=globals())
    return parser.Parser.get_parser(parser_name)


def get_driver(config, interface, parser, idx):
    """ global c """

    from complexconfig.configcontainer import configcontainer
    name = config['driver']
    if name == "bro":
        from nebula_sniffer.drivers.brohttpdriver import BroHttpDriver
        embedded = config.get("embedded", True)
        ports = config['ports']
Example #16
# -*- coding: utf-8 -*-

import atexit
import time
import logging
import gevent
import threading
from threathunter_common.util import run_in_thread, run_in_subprocess
from produce import Produce
from settings import init_logging
from settings import Global_Conf_FN
from settings import Sniffer_Conf_FN
from settings import Logging_Datefmt
from settings import DEBUG

logger = init_logging('nebula.sniffer')


def print_debug_level():
    if DEBUG:
        print "logging debug level is 'DEBUG'"
        logger.info("logging debug level is 'DEBUG'")
    else:
        print "logging debug level is 'INFO'"
        logger.info("logging debug level is 'INFO'")


def get_parser(parser_name, parser_module):
    from nebula_sniffer import parser
    __import__("{}.{}".format("nebula_sniffer.customparsers", parser_module), globals=globals())
    return parser.Parser.get_parser(parser_name)
Example #17
    HEADER_WIDTH = 20
    STATUS_WIDTH = 45

    if status is None:
        status = '-- no status --'

    return header.ljust(HEADER_WIDTH) + ' [' + status.rjust(STATUS_WIDTH) + ']'


def println(w, str, y, x=0):
    w.addstr(y, x, str)
    w.clrtoeol()
    w.refresh()


if __name__ == "__main__":
    settings.init_logging()
    init_signal_handlers()

    parser = OptionParser()
    parser.add_option('-x',
                      '--no-tracklog',
                      dest='tracklog',
                      action='store_false',
                      default=True,
                      help='don\'t store tracklog')

    (options, args) = parser.parse_args()

    loader(options)