def __init__(self, net, netfile_name, cfgfile=None):
        self.net = net
        self.netfile_name = netfile_name
        if cfgfile is not None:
            self.cfginit(cfgfile)
        utils.makedir(SAVE_DIR)
        parser = argparse.ArgumentParser(description="base class for network training")
        self.args = self.argparser(parser)

        net_savefile = "{0}.{1}".format(self.netfile_name, NETFILE_EXTENTION)
        self.save_dir = os.path.join(SAVE_DIR, "nets")
        utils.makedir(self.save_dir)
        self.save_path = os.path.join(self.save_dir, net_savefile)
        self.savepath_epoch = os.path.join(SAVEDIR_EPOCH, net_savefile)

        if os.path.exists(self.save_path) and CONTINUETRAIN:
            try:
                # Preferred path: the checkpoint holds a state_dict.
                self.net.load_state_dict(torch.load(self.save_path))
                print("net parameters loaded successfully")
            except Exception:
                # Fall back to a checkpoint that pickled the whole module.
                self.net = torch.load(self.save_path)
                print("net loaded successfully")
        else:
            self.net.paraminit()
            print("parameter initialization complete")

        if ISCUDA:
            self.net = self.net.to(DEVICE)

        if NEEDTEST:
            self.detecter = Detector()

        self.logdir = os.path.join(SAVE_DIR, "log")

        utils.makedir(self.logdir)
        self.logfile = os.path.join(self.logdir, "{0}.txt".format(self.netfile_name))
        if not os.path.exists(self.logfile):
            with open(self.logfile, 'w') as f:
                # Seed the log with an initial zeroed record.
                print("%.2f %d    " % (0.00, 0), end='\r', file=f)
            print("logfile created")

        self.optimizer = optim.Adam(self.net.parameters())

        # Loss functions
        self.conf_loss_fn = nn.BCEWithLogitsLoss()    # confidence (objectness) loss
        self.center_loss_fn = nn.BCEWithLogitsLoss()  # box-center loss
        self.wh_loss_fn = nn.MSELoss()                # width/height loss
        self.cls_loss_fn = nn.CrossEntropyLoss()      # classification loss

        print("initial complete")
 def test_has_uppercase_elements(self):
     target = Target({})
     actual = target.preparation("Red HEEL")
     self.assertDictEqual(
         actual,
         {
             'used_query': 'Red HEEL',
             'original_query': 'Red HEEL',
             'tokens': [
                 {
                     'start': 0,
                     'end': 3,
                     'pos': 'NNP',
                     'stem': u'red',
                     'stop_word': False,
                     'skip_word': False,
                     'use': True,
                     'value': 'red'
                 },
                 {
                     'start': 4,
                     'end': 8,
                     'pos': 'NNP',
                     'stem': u'heel',
                     'stop_word': False,
                     'skip_word': False,
                     'use': True,
                     'value': 'heel'
                 }
             ]
         }
     )
Example #3
def run_camera(args, ctx):
    assert args.batch_size == 1, "only batch size of 1 is supported"
    logging.info("Detection threshold is {}".format(args.thresh))
    frames = CameraIterator(frame_resize=parse_frame_resize(args.frame_resize))
    class_names = parse_class_names(args.class_names)
    mean_pixels = (args.mean_r, args.mean_g, args.mean_b)
    data_shape = int(args.data_shape)
    batch_size = int(args.batch_size)
    detector = Detector(
        get_symbol(args.network, data_shape, num_classes=len(class_names)),
        network_path(args.prefix, args.network, data_shape), args.epoch,
        data_shape, mean_pixels, batch_size, ctx)
    for frame in frames:
        logging.info("Frame info: shape %s type %s", frame.shape, frame.dtype)
        logging.info("Generating batch")
        data_batch = detector.create_batch(frame)
        logging.info("Detecting objects")
        detections_batch = detector.detect_batch(data_batch)
        #detections = [mx.nd.array((1,1,0.2,0.2,0.4,0.4))]
        detections = detections_batch[0]
        logging.info("%d detections", len(detections))
        for det in detections:
            obj = det.asnumpy()
            (klass, score, x0, y0, x1, y1) = obj
            if score > args.thresh:
                draw_detection(frame, obj, class_names)
        cv2.imshow('frame', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):  # imshow needs a waitKey pump; 'q' exits
            break
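
A hedged driver sketch for run_camera. The flag names are inferred from the
attribute accesses above (args.batch_size, args.thresh, and so on); the real
argument parser lives elsewhere in the original source:

# Hypothetical CLI wiring; defaults are placeholders.
import argparse
import mxnet as mx

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--thresh', type=float, default=0.5)
parser.add_argument('--network', default='vgg16_reduced')
parser.add_argument('--prefix', default='model/ssd')
parser.add_argument('--epoch', type=int, default=0)
parser.add_argument('--class-names', default='person,car')
parser.add_argument('--data-shape', type=int, default=300)
parser.add_argument('--mean-r', type=float, default=123)
parser.add_argument('--mean-g', type=float, default=117)
parser.add_argument('--mean-b', type=float, default=104)
parser.add_argument('--frame-resize', default=None)
run_camera(parser.parse_args(), mx.gpu(0))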
Example #4
    def test_mix_spelling_alias(self):
        target = Target({})
        actual = target.unique_matches([{
            'found_item': [{
                'display_name': 'citrus',
                'key': 'citrus',
                'match_type': 'alias',
                'source': 'content',
                'type': 'color'
            }],
            'start': 0,
            'end': 6,
            'term': 'citrus',
            'tokens': ['citrus']
        }, {
            'found_item': [{
                'display_name': 'citrus',
                'key': 'citrus',
                'match_type': 'spelling',
                'source': 'content',
                'type': 'color'
            }],
            'start': 10,
            'end': 16,
            'term': 'citrus',
            'tokens': ['citrus']
        }, {
            'found_item': [{
                'display_name': 'high heels',
                'key': 'high heels',
                'match_type': 'alias',
                'source': 'content',
                'type': 'style'
            }],
            'start': 10,
            'end': 16,
            'term': 'high heels',
            'tokens': ['high heels']
        }])

        self.assertEqual(len(actual), 2)
        self.assertTrue({
            'source': 'content',
            'type': 'color',
            'key': 'citrus'
        } in actual)
        self.assertTrue({
            'source': 'content',
            'type': 'style',
            'key': 'high heels'
        } in actual)
Example #5
    def test_duplicates(self):
        target = Target({})
        actual = target.unique_non_detections([{
            'value': 'citrus',
            'start': 0,
            'skip_word': False,
            'stop_word': False,
            'stem': 'citrus',
            'end': 5,
            'use': True,
            'pos': 'NN'
        }, {
            'value': 'heel',
            'start': 5,
            'skip_word': False,
            'stop_word': False,
            'stem': 'heel',
            'end': 10,
            'use': True,
            'pos': 'NN'
        }, {
            'value': 'citrus',
            'start': 15,
            'skip_word': False,
            'stop_word': False,
            'stem': 'citrus',
            'end': 20,
            'use': True,
            'pos': 'NN'
        }])

        self.assertSetEqual(set(actual), set(['heel', 'citrus']))
Example #6
 def test_Single(self):
     target = Target({})
     actual = target.unique_matches(
         [
             {
                 'found_item': [
                     {
                         'display_name': 'high heels',
                         'key': 'high heels',
                         'match_type': 'alias',
                         'source': 'content',
                         'type': 'style'
                     }
                 ],
                 'start': 10,
                 'end': 16,
                 'term': 'high heels',
                 'tokens': ['high heels']
             }
         ]
     )
     self.assertListEqual(
         actual,
         [
             {
                 'source': 'content',
                 'type': 'style',
                 'key': 'high heels'
             }
         ]
     )
Example #7
    def test_non_duplicates(self):
        target = Target({})
        actual = target.unique_non_detections(
            [
                {
                    'value': 'citrus',
                    'start': 0,
                    'skip_word': False,
                    'stop_word': False,
                    'stem': 'citrus',
                    'end': 5,
                    'use': True,
                    'pos': 'NN'
                },
                {
                    'value': 'heel',
                    'start': 5,
                    'skip_word': False,
                    'stop_word': False,
                    'stem': 'heel',
                    'end': 10,
                    'use': True,
                    'pos': 'NN'
                }
            ]
        )

        self.assertSetEqual(
            set(actual),
            set(['heel', 'citrus'])
        )
Example #8
 def test_has_uppercase_elements(self):
     target = Target({})
     actual = target.preparation("Red HEEL")
     self.assertDictEqual(
         actual, {
              'used_query': 'Red HEEL',
              'original_query': 'Red HEEL',
             'tokens': [{
                 'start': 0,
                 'end': 3,
                 'pos': 'NNP',
                 'stem': u'red',
                 'stop_word': False,
                 'skip_word': False,
                 'use': True,
                 'value': 'red'
             }, {
                 'start': 4,
                 'end': 8,
                 'pos': 'NNP',
                 'stem': u'heel',
                 'stop_word': False,
                 'skip_word': False,
                 'use': True,
                 'value': 'heel'
             }]
         })
Example #9
    def test_single_mistake(self):
        target = Target({})
        actual = target.autocorrect_query(
            # 0123456789
            "citru",
            [
                {
                    'found_item': [
                        {
                            'display_name': 'citrus',
                            'key': 'citrus',
                            'match_type': 'spelling',
                            'source': 'content',
                            'type': 'color'
                        }
                    ],
                    'start': 0,
                    'end': 5,
                    'term': 'citru',
                    'tokens': ['citru']
                }
            ]
        )

        self.assertEqual(
            actual,
            "citrus"
        )
Example #10
    def test_remove_if_found_in_larger_string(self):
        all_found = [
            {'term': 'blue', 'start': 0, 'tokens': ['blue'], 'position': '0_4', 'end': 4, 'found_item': [
                {'type': 'color', 'match_type': 'alias', 'key': 'blue', 'display_name': 'blue', 'source': 'content'}]},
            {'term': 'heels', 'start': 10, 'tokens': ['heels'], 'position': '10_15', 'end': 15, 'found_item': [
                {'type': 'style', 'match_type': 'alias', 'key': 'heels', 'display_name': 'display_name',
                 'source': 'content'}]},
            {'term': 'high heels', 'start': 5, 'tokens': ['high', 'heels'], 'position': '5_15', 'end': 15,
             'found_item': [
                 {'type': 'style', 'match_type': 'alias', 'key': 'high heels', 'display_name': 'display_name',
                  'source': 'content'}]}
        ]
        target = Target({})
        actual = target.format_found_entities(all_found)

        self.assertListEqual(
            [
                {
                    'term': 'blue', 'start': 0, 'tokens': ['blue'], 'position': '0_4', 'end': 4, 'found_item': [
                    {'type': 'color', 'match_type': 'alias', 'key': 'blue', 'display_name': 'blue',
                     'source': 'content'}
                ]
                },
                {
                    'term': 'high heels', 'start': 5, 'tokens': ['high', 'heels'], 'position': '5_15', 'end': 15,
                    'found_item': [
                        {'type': 'style', 'match_type': 'alias', 'key': 'high heels', 'display_name': 'display_name',
                         'source': 'content'}
                    ]
                }
            ],
            actual
        )
Example #11
def run_camera(args,ctx):
    assert args.batch_size == 1, "only batch size of 1 is supported"
    logging.info("Detection threshold is {}".format(args.thresh))
    frames = CameraIterator()
    class_names = parse_class_names(args.class_names)
    mean_pixels = (args.mean_r, args.mean_g, args.mean_b)
    data_shape = int(args.data_shape)
    batch_size = int(args.batch_size)
    detector = Detector(
        get_symbol(args.network, data_shape, num_classes=len(class_names)),
        network_path(args.prefix, args.network, data_shape),
        args.epoch,
        data_shape,
        mean_pixels,
        batch_size,
        ctx
    )
    for frame in frames:
        logging.info("Frame info: shape %s type %s", frame.shape, frame.dtype)
        logging.info("Generating batch")
        data_batch = detector.create_batch(frame)
        logging.info("Detecting objects")
        detections_batch = detector.detect_batch(data_batch)
        #detections = [mx.nd.array((1,1,0.2,0.2,0.4,0.4))]
        detections = detections_batch[0]
        logging.info("%d detections", len(detections))
        for det in detections:
            obj = det.asnumpy()
            (klass, score, x0, y0, x1, y1) = obj
            if score > args.thresh:
                draw_detection(frame, obj, class_names)
        cv2.imshow('frame', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):  # imshow needs a waitKey pump; 'q' exits
            break
Example #12
    def test_multiple_term(self):
        target = Target({})
        actual = target.find_matches(
            3, [{
                'value': 'red',
                'pos': 'VBD',
                'stem': 'red',
                'stop_word': False,
                'use': True,
                'start': 0,
                'skip_word': False,
                'end': 3
            }, {
                'value': 'valentino',
                'pos': 'VBN',
                'stem': 'valentino',
                'stop_word': False,
                'use': True,
                'start': 4,
                'skip_word': False,
                'end': 13
            }], {
                "en": {
                    "red": [{
                        "key": "red",
                        "type": "color",
                        "source": "content",
                        'match_type': 'alias'
                    }],
                    "red valentino": [{
                        "key": "red valentino",
                        "type": "brand",
                        "source": "content",
                        'match_type': 'alias'
                    }]
                }
            })

        self.assertDictEqual(
            actual, {
                "can_not_match": [],
                "found": [{
                    'found_item': [{
                        'key': 'red valentino',
                        'match_type': 'alias',
                        'source': 'content',
                        'type': 'brand'
                    }],
                    'term': 'red valentino',
                    'end': 13,
                    'start': 0,
                    'position': '0_13',
                    'tokens': ['red', 'valentino']
                }]
            })
Example #13
    def test_empty_can_not_match(self):
        target = Target({})
        actual = target.unique_non_detections([])

        self.assertListEqual(
            actual,
            []
        )
Example #14
 def test_empty_string(self):
     target = Target({})
     actual = target.preparation("")
     self.assertDictEqual(actual, {
         'used_query': '',
         'tokens': [],
         'original_query': ''
     })
Example #15
 def test_has_stopwords(self):
     target = Target({})
     actual = target.preparation("Shoes with red and white")
     self.assertDictEqual(
         actual,
         {
             'original_query': 'Shoes with red and white',
             'used_query': 'Shoes with red and white',
             'tokens': [
                 {
                     'start': 0,
                     'end': 5,
                     'pos': 'VBZ',
                     'stem': u'shoe',
                     'stop_word': False,
                     'skip_word': False,
                     'use': True,
                     'value': 'shoes'},
                 {
                     'start': 6,
                     'end': 10,
                     'pos': 'IN',
                     'stem': u'with',
                     'stop_word': True,
                     'skip_word': False,
                     'use': False,
                     'value': 'with'},
                 {
                     'start': 11,
                     'end': 14,
                     'pos': 'JJ',
                     'stem': u'red',
                     'stop_word': False,
                     'skip_word': False,
                     'use': True,
                     'value': 'red'},
                 {
                     'start': 15,
                     'end': 18,
                     'pos': 'CC',
                     'stem': u'and',
                     'stop_word': True,
                     'skip_word': False,
                     'use': False,
                     'value': 'and'},
                 {
                     'start': 19,
                     'end': 24,
                     'pos': 'JJ',
                     'stem': u'white',
                     'stop_word': False,
                     'skip_word': False,
                     'use': True,
                     'value': 'white'
                 }
             ]
         }
     )
Example #16
    def test_multiple_mistakes_with_same_entity_no_mistake(self):
        target = Target({})
        actual = target.autocorrect_query(
            # 012345678901234567890123456789
            "I want citru high hells citrus thanks",
            [
                {
                    'found_item': [
                        {
                            'display_name': 'citrus',
                            'key': 'citrus',
                            'match_type': 'spelling',
                            'source': 'content',
                            'type': 'color'
                        }
                    ],
                    'start': 7,
                    'end': 12,
                    'term': 'citru',
                    'tokens': ['citru']
                },
                {
                    'found_item': [
                        {
                            'display_name': 'high heels',
                            'key': 'high heels',
                            'match_type': 'spelling',
                            'source': 'content',
                            'type': 'color'
                        }
                    ],
                    'start': 13,
                    'end': 23,
                    'term': 'high hells',
                    'tokens': ['high', 'hells']
                },
                {
                    'found_item': [
                        {
                            'display_name': 'citrus',
                            'key': 'citrus',
                            'match_type': 'alias',
                            'source': 'content',
                            'type': 'color'
                        }
                    ],
                    'start': 24,
                    'end': 30,
                    'term': 'citrus',
                    'tokens': ['citrus']
                }
            ]
        )

        self.assertEqual(
            actual,
            "I want citrus high heels citrus thanks"
        )
Example #17
 def initialize(self, alias_data):
     from detect.data.response import Response
     self.data_response = Response()
     self.data_response.open_connection()
     self.alias_data = alias_data
     self.param_extractor = ParamExtractor(self)
     self.path_extractor = PathExtractor(self)
     self.entity_factory = EntityFactory(self.alias_data)
     self.brute_detector = Detector(self.alias_data)
Example #18
 def test_has_stopwords(self):
     target = Target({})
     actual = target.preparation("Shoes with red and white")
     self.assertDictEqual(
         actual, {
              'original_query': 'Shoes with red and white',
              'used_query': 'Shoes with red and white',
             'tokens': [{
                 'start': 0,
                 'end': 5,
                 'pos': 'VBZ',
                 'stem': u'shoe',
                 'stop_word': False,
                 'skip_word': False,
                 'use': True,
                 'value': 'shoes'
             }, {
                 'start': 6,
                 'end': 10,
                 'pos': 'IN',
                 'stem': u'with',
                 'stop_word': True,
                 'skip_word': False,
                 'use': False,
                 'value': 'with'
             }, {
                 'start': 11,
                 'end': 14,
                 'pos': 'JJ',
                 'stem': u'red',
                 'stop_word': False,
                 'skip_word': False,
                 'use': True,
                 'value': 'red'
             }, {
                 'start': 15,
                 'end': 18,
                 'pos': 'CC',
                 'stem': u'and',
                 'stop_word': True,
                 'skip_word': False,
                 'use': False,
                 'value': 'and'
             }, {
                 'start': 19,
                 'end': 24,
                 'pos': 'JJ',
                 'stem': u'white',
                 'stop_word': False,
                 'skip_word': False,
                 'use': True,
                 'value': 'white'
             }]
         })
Example #19
def peopleDetect():
    # Relies on module-level imports and objects from the original source:
    # cv2, os, mx (mxnet), np (numpy), timer, Detector, get_batch, plus the
    # globals numPeople and isNotQuit shared with the rest of the program.
    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')
    cap = cv2.VideoCapture(1)
    net = None  # the symbol graph is loaded from the checkpoint prefix below
    prefix = os.path.join(os.getcwd(), 'model', 'yolo2_darknet19_416')
    epoch = 0

    mean_pixels = (123, 117, 104)
    ctx = mx.gpu(0)
    global numPeople
    global isNotQuit
    count = 0
    # Assumed values: `data_shape` and `batch` are module-level settings in the
    # original source (not shown); 416 matches the darknet19_416 checkpoint.
    data_shape = 416
    batch = 1

    ret1, frame1 = cap.read()
    detector = Detector(net,
                        prefix,
                        epoch,
                        data_shape,
                        mean_pixels,
                        ctx=ctx,
                        batch_size=batch)
    while isNotQuit:
        count += 1
        ret, frame = cap.read()
        ims = [
            cv2.resize(frame, (data_shape, data_shape)) for i in range(batch)
        ]

        data = None
        data = get_batch(ims)

        start = timer()

        det_batch = mx.io.DataBatch(data, [])
        detector.mod.forward(det_batch, is_train=False)
        detections = detector.mod.get_outputs()[0].asnumpy()
        result = []

        for i in range(detections.shape[0]):
            det = detections[i, :, :]
            res = det[np.where(det[:, 0] >= 0)[0]]
            result.append(res)
        time_elapsed = timer() - start
        # print("Detection time for {} images: {:.4f} sec , fps : {:.4f}".format(batch*1, time_elapsed , (batch*1/time_elapsed)))
        # NOTE: displays detections from `det`, i.e. the last batch entry only.
        numPeople, numChair = detector.show_result(frame, det, CLASSES, 0.5,
                                                   batch * 1 / time_elapsed)
    # if count>40:
    #	isNotQuit = False
    #break
    cap.release()
    cv2.destroyAllWindows()
Example #20
    def test_no_duplicates(self):
        target = Target({})
        actual = target.unique_matches(
            [
                {
                    'found_item': [
                        {
                            'display_name': 'citrus',
                            'key': 'citrus',
                            'match_type': 'alias',
                            'source': 'content',
                            'type': 'color'
                        }
                    ],
                    'start': 0,
                    'end': 6,
                    'term': 'citrus',
                    'tokens': ['citrus']
                },
                {
                    'found_item': [
                        {
                            'display_name': 'high heels',
                            'key': 'high heels',
                            'match_type': 'alias',
                            'source': 'content',
                            'type': 'style'
                        }
                    ],
                    'start': 10,
                    'end': 16,
                    'term': 'high heels',
                    'tokens': ['high heels']
                }
            ]
        )

        self.assertEqual(
            len(actual),
            2
        )
        self.assertTrue(
            {
                'source': 'content',
                'type': 'color',
                'key': 'citrus'
            } in actual
        )
        self.assertTrue(
            {
                'source': 'content',
                'type': 'style',
                'key': 'high heels'
            } in actual
        )
Example #21
def evaluate_net(net, dataset, devkit_path, mean_pixels, data_shape,
                 model_prefix, epoch, ctx, year=None, sets='test',
                 batch_size=1, nms_thresh=0.5, force_nms=False):
    """
    Evaluate an entire dataset; basically a simple wrapper around detection.

    Parameters:
    ----------
    dataset : str
        name of dataset to evaluate
    devkit_path : str
        root directory of dataset
    mean_pixels : tuple of float
        (R, G, B) mean pixel values
    data_shape : int
        resize input data shape
    model_prefix : str
        load model prefix
    epoch : int
        load model epoch
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(0)...
    year : str or None
        evaluate on which year's data
    sets : str
        evaluation set
    batch_size : int
        using batch_size for evaluation
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : bool
        force suppress different categories
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if dataset == "pascal":
        if not year:
            year = '2007'
        imdb = PascalVoc(sets, year, devkit_path, shuffle=False, is_train=False)
        data_iter = DetIter(imdb, batch_size, data_shape, mean_pixels,
            rand_samplers=[], rand_mirror=False, is_train=False, shuffle=False)
        sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
        net = importlib.import_module("symbol_" + net) \
            .get_symbol(imdb.num_classes, nms_thresh, force_nms)
        model_prefix += "_" + str(data_shape)
        detector = Detector(net, model_prefix, epoch, data_shape, mean_pixels, batch_size, ctx)
        logger.info("Start evaluation with {} images, be patient...".format(imdb.num_images))
        detections = detector.detect(data_iter)
        imdb.evaluate_detections(detections)
    else:
        raise NotImplementedError("No support for dataset: " + dataset)
Example #22
 def test_empty_string(self):
     target = Target({})
     actual = target.preparation("")
     self.assertDictEqual(
         actual,
         {
             'used_query': '',
             'tokens': [],
             'original_query': ''
         }
     )
Example #23
def evaluate_net(net, dataset, devkit_path, mean_pixels, data_shape,
                 model_prefix, epoch, ctx, year=None, sets='test',
                 batch_size=1, nms_thresh=0.5, force_nms=False):
    """
    Evaluate an entire dataset; basically a simple wrapper around detection.

    Parameters:
    ----------
    dataset : str
        name of dataset to evaluate
    devkit_path : str
        root directory of dataset
    mean_pixels : tuple of float
        (R, G, B) mean pixel values
    data_shape : int
        resize input data shape
    model_prefix : str
        load model prefix
    epoch : int
        load model epoch
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(0)...
    year : str or None
        evaluate on which year's data
    sets : str
        evaluation set
    batch_size : int
        using batch_size for evaluation
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : bool
        force suppress different categories
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if dataset == "pascal":
        if not year:
            year = '2007'
        imdb = PascalVoc(sets, year, devkit_path, shuffle=False, is_train=False)
        data_iter = DetIter(imdb, batch_size, data_shape, mean_pixels,
            rand_samplers=[], rand_mirror=False, is_train=False, shuffle=False)
        sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
        net = importlib.import_module("symbol_" + net) \
            .get_symbol(imdb.num_classes, nms_thresh, force_nms)
        model_prefix += "_" + str(data_shape)
        detector = Detector(net, model_prefix, epoch, data_shape, mean_pixels, batch_size, ctx)
        logger.info("Start evaluation with {} images, be patient...".format(imdb.num_images))
        detections = detector.detect(data_iter)
        imdb.evaluate_detections(detections)
    else:
        raise NotImplementedError("No support for dataset: " + dataset)
Example #24
    def test_miss_spelling(self):
        target = Target({})
        actual = target.find_matches(
            3,
            [
                {
                    'value': 'citru',
                    'start': 0,
                    'skip_word': False,
                    'stop_word': False,
                    'stem': 'citru',
                    'end': 5,
                    'use': True,
                    'pos': 'NN'
                }
            ],
            {
                'en': {
                    'citru': [{'type': 'color', 'key': 'citrus', 'source': 'content', 'display_name': 'citrus',
                               'match_type': 'spelling'}],
                    'citrus': [{'type': 'color', 'key': 'citrus', 'source': 'content', 'display_name': 'citrus',
                                'match_type': 'alias'}],
                    'red': [{'type': 'color', 'key': 'red', 'source': 'content', 'display_name': 'red',
                             'match_type': 'alias'}],
                    'heel': [{'type': 'style', 'key': 'heel', 'source': 'content', 'display_name': 'heel',
                              'match_type': 'alias'}]
                }
            }
        )

        self.assertDictEqual(
            actual,
            {
                'can_not_match': [],
                "found": [
                    {
                        'found_item': [
                            {
                                'display_name': 'citrus',
                                'key': 'citrus',
                                'match_type': 'spelling',
                                'source': 'content',
                                'type': 'color'
                            }
                        ],
                        'position': '0_5',
                        'start': 0,
                        'end': 5,
                        'term': 'citru',
                        'tokens': ['citru']
                    }
                ]
            }
        )
Example #25
    def test_multiple_mistakes_with_same_entity_no_mistake(self):
        target = Target({})
        actual = target.autocorrect_query(
            # 012345678901234567890123456789
            "I want citru high hells citrus thanks",
            [{
                'found_item': [{
                    'display_name': 'citrus',
                    'key': 'citrus',
                    'match_type': 'spelling',
                    'source': 'content',
                    'type': 'color'
                }],
                'start': 7,
                'end': 12,
                'term': 'citru',
                'tokens': ['citru']
            }, {
                'found_item': [{
                    'display_name': 'high heels',
                    'key': 'high heels',
                    'match_type': 'spelling',
                    'source': 'content',
                    'type': 'color'
                }],
                'start': 13,
                'end': 23,
                'term': 'high hells',
                'tokens': ['high', 'hells']
            }, {
                'found_item': [{
                    'display_name': 'citrus',
                    'key': 'citrus',
                    'match_type': 'alias',
                    'source': 'content',
                    'type': 'color'
                }],
                'start': 24,
                'end': 30,
                'term': 'citrus',
                'tokens': ['citrus']
            }])

        self.assertEqual(actual, "I want citrus high heels citrus thanks")
Example #26
    def test_regular(self):
        target = Target({})
        actual = target.create_found_doc("terms_value", "tokens_value",
                                         "found_item_value", "start_value",
                                         "end_value")

        self.assertDictEqual(
            actual, {
                "term": "terms_value",
                "tokens": "tokens_value",
                "found_item": "found_item_value",
                "start": "start_value",
                'position': 'start_value_end_value',
                "end": "end_value"
            })
Example #27
def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, num_class,
                 nms_thresh=0.5, force_nms=True, nms_topk=400):
    """
    wrapper to initialize a detector

    Parameters:
    ----------
    net : str
        test network name
    prefix : str
        load model prefix
    epoch : int
        load model epoch
    data_shape : int
        resize image shape
    mean_pixels : tuple (float, float, float)
        mean pixel values (R, G, B)
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(?)
    num_class : int
        number of classes
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : bool
        force suppress different categories
    """
    if net is not None:
        net = get_symbol(net, data_shape, num_classes=num_class, nms_thresh=nms_thresh,
            force_nms=force_nms, nms_topk=nms_topk)
    detector = Detector(net, prefix, epoch, data_shape, mean_pixels, ctx=ctx)
    return detector
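
A hedged construction example for the wrapper above; the network name, model
prefix, and class count are placeholders. Once built, the object supports
detector.detect(...) as the evaluate_net snippets show:

import mxnet as mx

detector = get_detector("vgg16_reduced", "model/ssd_300", epoch=0,
                        data_shape=300, mean_pixels=(123, 117, 104),
                        ctx=mx.cpu(), num_class=20)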
Example #28
def get_detector(net,
                 prefix,
                 epoch,
                 data_shape,
                 mean_pixels,
                 ctx,
                 nms_thresh=0.5,
                 force_nms=True):
    """
    wrapper to initialize a detector

    Parameters:
    ----------
    net : str
        test network name
    prefix : str
        load model prefix
    epoch : int
        load model epoch
    data_shape : int
        resize image shape
    mean_pixels : tuple (float, float, float)
        mean pixel values (R, G, B)
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(?)
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : bool
        force suppress different categories
    """
    sys.path.append(os.path.join(os.getcwd(), 'symbol'))
    if net is not None:
        net = importlib.import_module("symbol_" + net) \
            .get_symbol(len(CLASSES), nms_thresh, force_nms)
    detector = Detector(net, prefix + "_" + str(data_shape), epoch,
                        data_shape, mean_pixels, ctx=ctx)
    return detector
Example #29
    def initialize(self, alias_data):
        from detect.data.response import Response

        self.data_response = Response()
        self.data_response.open_connection()
        self.alias_data = alias_data
        self.param_extractor = ParamExtractor(self)
        self.path_extractor = PathExtractor(self)
        self.entity_factory = EntityFactory(self.alias_data)
        self.brute_detector = Detector(self.alias_data)
Example #30
    def test_regular(self):
        target = Target({})
        actual = target.create_found_doc(
            "terms_value",
            "tokens_value",
            "found_item_value",
            "start_value",
            "end_value"
        )

        self.assertDictEqual(
            actual,
            {
                "term": "terms_value",
                "tokens": "tokens_value",
                "found_item": "found_item_value",
                "start": "start_value",
                'position': 'start_value_end_value',
                "end": "end_value"
            }
        )
Example #31
 def test_has_skipwords(self):
     target = Target({})
     actual = target.preparation("Show me anything")
     self.assertDictEqual(
         actual,
         {
             'original_query': 'Show me anything',
             'used_query': 'Show me anything',
             'tokens': [
                 {
                     'start': 0,
                     'end': 4,
                     'pos': 'NNP',
                     'stem': u'show',
                     'stop_word': True,
                     'skip_word': True,
                     'use': False,
                     'value': 'show'},
                 {
                     'start': 5,
                     'end': 7,
                     'pos': 'PRP',
                     'stem': u'me',
                     'stop_word': True,
                     'skip_word': False,
                     'use': False,
                     'value': 'me'},
                 {
                     'start': 8,
                     'end': 16,
                     'pos': 'NN',
                     'stem': u'anyth',
                     'stop_word': False,
                     'skip_word': True,
                     'use': False,
                     'value': 'anything'
                 }
             ]
         }
     )
Example #32
    def test_no_mistakes(self):
        target = Target({})
        actual = target.autocorrect_query(
            # 012345678901234567890123456789
            "I want citrus high heels",
            [
                {
                    'found_item': [
                        {
                            'display_name': 'citrus',
                            'key': 'citrus',
                            'match_type': 'alias',
                            'source': 'content',
                            'type': 'color'
                        }
                    ],
                    'start': 7,
                    'end': 13,
                    'term': 'citrus',
                    'tokens': ['citrus']
                },
                {
                    'found_item': [
                        {
                            'display_name': 'high heels',
                            'key': 'high heels',
                            'match_type': 'alias',
                            'source': 'content',
                            'type': 'color'
                        }
                    ],
                    'start': 14,
                    'end': 24,
                    'term': 'high heels',
                    'tokens': ['high', 'heels']
                }
            ]
        )

        self.assertIsNone(actual)
Example #33
 def test_has_skipwords(self):
     target = Target({})
     actual = target.preparation("Show me anything")
     self.assertDictEqual(
         actual, {
              'original_query': 'Show me anything',
              'used_query': 'Show me anything',
             'tokens': [{
                 'start': 0,
                 'end': 4,
                 'pos': 'NNP',
                 'stem': u'show',
                 'stop_word': True,
                 'skip_word': True,
                 'use': False,
                 'value': 'show'
             }, {
                 'start': 5,
                 'end': 7,
                 'pos': 'PRP',
                 'stem': u'me',
                 'stop_word': True,
                 'skip_word': False,
                 'use': False,
                 'value': 'me'
             }, {
                 'start': 8,
                 'end': 16,
                 'pos': 'NN',
                 'stem': u'anyth',
                 'stop_word': False,
                 'skip_word': True,
                 'use': False,
                 'value': 'anything'
             }]
         })
Example #34
    def test_no_mistakes(self):
        target = Target({})
        actual = target.autocorrect_query(
            # 012345678901234567890123456789
            "I want citrus high heels",
            [{
                'found_item': [{
                    'display_name': 'citrus',
                    'key': 'citrus',
                    'match_type': 'alias',
                    'source': 'content',
                    'type': 'color'
                }],
                'start': 7,
                'end': 13,
                'term': 'citrus',
                'tokens': ['citrus']
            }, {
                'found_item': [{
                    'display_name': 'high heels',
                    'key': 'high heels',
                    'match_type': 'alias',
                    'source': 'content',
                    'type': 'color'
                }],
                'start': 14,
                'end': 24,
                'term': 'high heels',
                'tokens': ['high', 'heels']
            }])

        self.assertIsNone(actual)
Example #35
 def test_Single(self):
     target = Target({})
     actual = target.unique_matches([{
         'found_item': [{
             'display_name': 'high heels',
             'key': 'high heels',
             'match_type': 'alias',
             'source': 'content',
             'type': 'style'
         }],
          'start': 10,
          'end': 16,
          'term': 'high heels',
         'tokens': ['high heels']
     }])
     self.assertListEqual(actual, [{
         'source': 'content',
         'type': 'style',
         'key': 'high heels'
     }])
Example #36
    def test_single_mistake(self):
        target = Target({})
        actual = target.autocorrect_query(
            # 0123456789
            "citru",
            [{
                'found_item': [{
                    'display_name': 'citrus',
                    'key': 'citrus',
                    'match_type': 'spelling',
                    'source': 'content',
                    'type': 'color'
                }],
                'start': 0,
                'end': 5,
                'term': 'citru',
                'tokens': ['citru']
            }])

        self.assertEqual(actual, "citrus")
Example #37
def get_detector(net,
                 prefix,
                 epoch,
                 data_shape,
                 mean_pixels,
                 ctx,
                 nms_thresh=0.5,
                 force_nms=True):
    sys.path.append(os.path.join(os.getcwd(), 'symbol'))
    net = importlib.import_module("symbol_"+net)\
            .get_symbol(len(CLASSES), nms_thresh, force_nms)
    detector = Detector(net, prefix + "_"+ str(data_shape), epoch, \
                       data_shape, mean_pixels, ctx = ctx)
    return detector
Example #38
def get_mxnet_detector(net,
                       prefix,
                       epoch,
                       data_shape,
                       mean_pixels,
                       ctx,
                       batch_size=1):
    detector = Detector(net,
                        prefix,
                        epoch,
                        data_shape,
                        mean_pixels,
                        ctx=ctx,
                        batch_size=batch_size)  # honor the batch_size parameter
    return detector
Example #39
    def test_no_autocorrection(self):
        target = Target({})
        target.find_matches = Mock()
        target.find_matches.return_value = {
            "found": "find_matches_found",
            "can_not_match": "find_matches_can_not_match"
        }
        target.autocorrect_query = Mock()
        target.autocorrect_query.return_value = None
        target.key_matches = Mock(return_value="key_matches:return_value")
        target.unique_non_detections = Mock()
        target.unique_non_detections.return_value = "unique_non_detections:return_value"
        target.format_found_entities = Mock()
        target.format_found_entities.return_value = "formated_found_entities"

        actual = target.detect_entities(
            "vocab",
            {
                "tokens": "preperation_result:tokens",
                "used_query": "preperation_result:used_query"
            }
        )
        self.assertEqual(1, target.find_matches.call_count)
        self.assertEqual(3, target.find_matches.call_args_list[0][0][0])
        self.assertEqual("preperation_result:tokens", target.find_matches.call_args_list[0][0][1])
        self.assertEqual(
            target.find_matches.call_args_list[0][0][2],
            'vocab'
        )

        self.assertEqual(1, target.autocorrect_query.call_count)
        self.assertEqual('preperation_result:used_query', target.autocorrect_query.call_args_list[0][0][0])
        self.assertEqual('formated_found_entities', target.autocorrect_query.call_args_list[0][0][1])

        self.assertEqual(1, target.key_matches.call_count)

        self.assertEqual('formated_found_entities', target.key_matches.call_args_list[0][0][0])

        self.assertEqual(1, target.unique_non_detections.call_count)
        self.assertEqual('find_matches_can_not_match', target.unique_non_detections.call_args_list[0][0][0])

        self.assertDictEqual(
            {
                'detections': 'key_matches:return_value',
                'non_detections': 'unique_non_detections:return_value'
            },
            actual
        )
Example #40
    def test_autocorrection(self):
        target = Target({})
        target.find_matches = Mock()
        target.find_matches.return_value = {
            "found": "find_matches_found",
            "can_not_match": "find_matches_can_not_match"
        }
        target.autocorrect_query = Mock()
        target.autocorrect_query.return_value = "autocorrected_query_new_value"
        target.key_matches = Mock(return_value="key_matches:return_value")
        target.unique_non_detections = Mock()
        target.unique_non_detections.return_value = "unique_non_detections:return_value"
        target.format_found_entities = Mock()
        target.format_found_entities.return_value = "formated_found_entities"

        actual = target.detect_entities(
            "vocab", {
                "tokens": "preperation_result:tokens",
                "used_query": "preperation_result:used_query"
            })
        self.assertEqual(1, target.find_matches.call_count)
        self.assertEqual(3, target.find_matches.call_args_list[0][0][0])
        self.assertEqual("preperation_result:tokens",
                         target.find_matches.call_args_list[0][0][1])
        self.assertEqual('vocab', target.find_matches.call_args_list[0][0][2])

        self.assertEqual(target.autocorrect_query.call_count, 1)
        self.assertEqual(target.autocorrect_query.call_args_list[0][0][0],
                         'preperation_result:used_query')
        self.assertEqual(target.autocorrect_query.call_args_list[0][0][1],
                         'formated_found_entities')

        self.assertEqual(target.key_matches.call_count, 1)

        self.assertEqual(target.key_matches.call_args_list[0][0][0],
                         'formated_found_entities')

        self.assertEqual(target.unique_non_detections.call_count, 1)
        self.assertEqual(target.unique_non_detections.call_args_list[0][0][0],
                         'find_matches_can_not_match')

        self.assertDictEqual(
            actual, {
                'autocorrected_query': 'autocorrected_query_new_value',
                'detections': 'key_matches:return_value',
                'non_detections': 'unique_non_detections:return_value'
            })
Example #41
min_face_size = 24
stride = 2
slide_window = False
shuffle = False
detectors = []
prefix = [
    'detect/MTCNN_model/PNet_landmark/PNet',
    'detect/MTCNN_model/RNet_landmark/RNet',
    'detect/MTCNN_model/ONet_landmark/ONet'
]
epoch = [18, 14, 16]
batch_size = [2048, 256, 16]
thresh = [0.6, 0.7, 0.7]  # assumed: per-stage thresholds (`thresh` is defined elsewhere in the original)
model_path = ['%s-%s' % (x, y) for x, y in zip(prefix, epoch)]

detectors.append(FcnDetector(P_Net, model_path[0]))
detectors.append(Detector(R_Net, 24, batch_size[1], model_path[1]))
detectors.append(Detector(O_Net, 48, batch_size[2], model_path[2]))

mtcnn_detector = MtcnnDetector(detectors=detectors,
                               min_face_size=min_face_size,
                               stride=stride,
                               threshold=thresh,
                               slide_window=slide_window)
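
A hedged usage sketch for the cascade built above. The detect() call returning
bounding boxes and landmarks follows the common MTCNN-TensorFlow API but is an
assumption here, and the image path is a placeholder:

import cv2

img = cv2.imread('face.jpg')                   # placeholder path
boxes, landmarks = mtcnn_detector.detect(img)  # assumed MtcnnDetector.detect API
print('found %d faces' % len(boxes))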

# Init another version of MtcnnDetector
print('Creating networks and loading parameters')
minsize = 20  # minimum face size to detect
threshold = [0.6, 0.7, 0.7]  # per-stage thresholds for the three networks
factor = 0.709  # image pyramid scale factor
margin = 44
with tf.Graph().as_default():
Example #42
class Detect(RequestHandler):
    brute_detector = None
    alias_data = None
    data_response = None
    param_extractor = None
    path_extractor = None
    entity_factory = None

    def data_received(self, chunk):
        pass

    def initialize(self, alias_data):
        from detect.data.response import Response
        self.data_response = Response()
        self.data_response.open_connection()
        self.alias_data = alias_data
        self.param_extractor = ParamExtractor(self)
        self.path_extractor = PathExtractor(self)
        self.entity_factory = EntityFactory(self.alias_data)
        self.brute_detector = Detector(self.alias_data)

    def on_finish(self):
        pass

    @asynchronous
    def post(self, *args, **kwargs):
        self.set_header('Content-Type', 'application/json')

        detection_id = ObjectId()

        app_log.info(
            "app=detection,function=detect,detection_id=%s,application_id=%s,session_id=%s,q=%s",
            detection_id, self.param_extractor.application_id(),
            self.param_extractor.session_id(), self.param_extractor.query())

        if False:  # flip to True to route detection through the external Wit.ai service
            url = "%smessage?v=%s&q=%s&msg_id=%s" % (
                WIT_URL, WIT_URL_VERSION,
                url_escape(self.param_extractor.query()), str(detection_id))
            r = HTTPRequest(url,
                            headers={"Authorization": "Bearer %s" % WIT_TOKEN})
            client = AsyncHTTPClient()
            client.fetch(r, callback=self.wit_call_back)
        else:
            date = datetime.now()
            outcomes = self.brute_detector.detect(self.param_extractor.query())
            self.data_response.insert(self.param_extractor.user_id(),
                                      self.param_extractor.application_id(),
                                      self.param_extractor.session_id(),
                                      detection_id,
                                      "brute",
                                      date,
                                      self.param_extractor.query(),
                                      outcomes=outcomes)

            self.set_status(202)
            self.set_header("Location", "/%s" % str(detection_id))
            self.set_header("_id", str(detection_id))
            self.finish()

            Worker(self.param_extractor.user_id(),
                   self.param_extractor.application_id(),
                   self.param_extractor.session_id(),
                   detection_id,
                   date,
                   self.param_extractor.query(),
                   self.param_extractor.skip_slack_log(),
                   detection_type="wit",
                   outcomes=outcomes).start()

    @asynchronous
    def get(self, detection_id, *args, **kwargs):
        data = self.data_response.get(
            self.path_extractor.detection_id(detection_id))
        if data is not None:
            self.set_header('Content-Type', 'application/json')
            self.set_status(200)
            self.finish(
                dumps({
                    "type": data["type"],
                    "q": data["q"],
                    "outcomes": data["outcomes"],
                    "_id": data["_id"],
                    "version": data["version"],
                    "timestamp": data["timestamp"]
                }))
        else:
            self.set_status(404)
            self.finish()

    def wit_call_back(self, response):
        data = json_decode(response.body)
        outcomes = []
        date = datetime.now()
        for outcome in data["outcomes"]:
            entities = []
            for _type in outcome["entities"].keys():
                if _type not in ["polite"]:
                    for value in outcome["entities"][_type]:
                        suggested = value.get("suggested", False)
                        key = value["value"]["value"] if isinstance(
                            value["value"], dict) else value["value"]
                        entity = self.entity_factory.create(
                            _type, key, suggested)

                        # TODO: move this elsewhere, preferably a separate service call
                        entities.append(entity)

            outcomes.append({
                "confidence": outcome["confidence"] * 100,
                "intent": outcome["intent"],
                "entities": entities
            })

        self.data_response.insert(self.param_extractor.user_id(),
                                  self.param_extractor.application_id(),
                                  self.param_extractor.session_id(),
                                  ObjectId(data["msg_id"]),
                                  "wit",
                                  date,
                                  self.param_extractor.query(),
                                  outcomes=outcomes)

        self.set_status(202)
        self.set_header("Location", "/%s" % data["msg_id"])
        self.set_header("_id", data["msg_id"])
        self.finish()

        Worker(self.param_extractor.user_id(),
               self.param_extractor.application_id(),
               self.param_extractor.session_id(),
               ObjectId(data["msg_id"]),
               date,
               self.param_extractor.query(),
               self.param_extractor.skip_slack_log(),
               detection_type="wit",
               outcomes=outcomes).start()
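
A hypothetical Tornado wiring for the handler above. The route pattern, port,
and load_alias_data helper are assumptions; passing alias_data through the
Application's handler kwargs matches the initialize() signature:

from tornado.ioloop import IOLoop
from tornado.web import Application

application = Application([
    # load_alias_data() is illustrative; supply the real alias data here.
    (r"/detect/?(.*)", Detect, dict(alias_data=load_alias_data())),
])
application.listen(8888)
IOLoop.current().start()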
Example #43
    def test_miss_spelling_ngram(self):
        target = Target({})
        actual = target.find_matches(
            3,
            [
                {'skip_word': False, 'stem': 'white', 'start': 0, 'pos': 'JJ', 'value': 'white', 'stop_word': False,
                 'use': True, 'end': 5},
                {'skip_word': False, 'stem': 'and', 'start': 7, 'pos': 'CC', 'value': 'and', 'stop_word': True,
                 'use': False, 'end': 10},
                {'skip_word': False, 'stem': 'blue', 'start': 11, 'pos': 'JJ', 'value': 'blue', 'stop_word': False,
                 'use': True, 'end': 15},
                {'skip_word': False, 'stem': 'high', 'start': 16, 'pos': 'NN', 'value': 'high', 'stop_word': False,
                 'use': True, 'end': 20},
                {'skip_word': False, 'stem': 'heal', 'start': 21, 'pos': 'NNS', 'value': 'heals', 'stop_word': False,
                 'use': True, 'end': 26}
            ],
            {
                'en': {
                    'high heals': [
                        {'type': 'color', 'key': 'high heels', 'source': 'content', 'display_name': 'high heels',
                         'match_type': 'spelling'}],
                    'white': [{'type': 'color', 'key': 'white', 'source': 'content', 'display_name': 'white',
                               'match_type': 'alias'}],
                    'blue': [{'type': 'color', 'key': 'blue', 'source': 'content', 'display_name': 'blue',
                              'match_type': 'alias'}]
                }
            }
        )

        self.assertDictEqual(
            {
                'found': [
                    {
                        'position': '0_5',
                        'start': 0,
                        'end': 5,
                        'found_item': [
                            {'source': 'content', 'type': 'color', 'match_type': 'alias', 'key': 'white',
                             'display_name': 'white'}
                        ],
                        'term': 'white',
                        'tokens': ['white']
                    },
                    {
                        'position': '11_15',
                        'start': 11, 'end': 15,
                        'found_item': [
                            {'source': 'content', 'type': 'color', 'match_type': 'alias', 'key': 'blue',
                             'display_name': 'blue'}
                        ],
                        'term': 'blue', 'tokens': ['blue']
                    },
                    {
                        'position': '16_26',
                        'start': 16, 'end': 26,
                        'found_item': [
                            {'source': 'content', 'type': 'color', 'match_type': 'spelling', 'key': 'high heels',
                             'display_name': 'high heels'}
                        ],
                        'term': 'high heals', 'tokens': ['high', 'heals']
                    }
                ],
                'can_not_match': []
            },
            actual
        )
Example #44
    def test_single_term(self):
        target = Target({})

        actual = target.find_matches(
            3, [{
                'stem': 'red',
                'value': 'red',
                'start': 0,
                'pos': 'NNP',
                'end': 3,
                'skip_word': False,
                'stop_word': False,
                'use': True
            }, {
                'stem': 'heel',
                'value': 'heel',
                'start': 4,
                'pos': 'NN',
                'end': 8,
                'skip_word': False,
                'stop_word': False,
                'use': True
            }], {
                'en': {
                    'citru': [{
                        'type': 'color',
                        'key': 'citrus',
                        'source': 'content',
                        'display_name': 'citrus',
                        'match_type': 'spelling'
                    }],
                    'citrus': [{
                        'type': 'color',
                        'key': 'citrus',
                        'source': 'content',
                        'display_name': 'citrus',
                        'match_type': 'alias'
                    }],
                    'red': [{
                        'type': 'color',
                        'key': 'red',
                        'source': 'content',
                        'display_name': 'red',
                        'match_type': 'alias'
                    }],
                    'heel': [{
                        'type': 'style',
                        'key': 'heel',
                        'source': 'content',
                        'display_name': 'heel',
                        'match_type': 'alias'
                    }]
                }
            })

        self.assertDictEqual(
            actual, {
                "can_not_match": [],
                "found": [{
                    'found_item': [{
                        'display_name': 'red',
                        'key': 'red',
                        'match_type': 'alias',
                        'source': 'content',
                        'type': 'color'
                    }],
                    'position': '0_3',
                    'start': 0,
                    'end': 3,
                    'term': 'red',
                    'tokens': ['red']
                }, {
                    'found_item': [{
                        'display_name': 'heel',
                        'key': 'heel',
                        'match_type': 'alias',
                        'source': 'content',
                        'type': 'style'
                    }],
                    'position': '4_8',
                    'start': 4,
                    'end': 8,
                    'term': 'heel',
                    'tokens': ['heel']
                }]
            })
Example #45
import importlib
import os
import sys

import cv2
import mxnet as mx
import numpy as np

# CLASSES and Detector are assumed to come from the surrounding project
# (the dataset config and detect/detector.py of the MXNet YOLO demo).

img = './data/demo/dog.jpg'
net = 'darknet19_yolo'
sys.path.append(os.path.join(os.getcwd(), 'symbol'))
net = importlib.import_module("symbol_" + net) \
            .get_symbol(len(CLASSES), nms_thresh=0.5, force_nms=True)
prefix = os.path.join(os.getcwd(), 'model', 'yolo2_darknet19_416')
epoch = 0
data_shape = 608
mean_pixels = (123, 117, 104)
ctx = mx.gpu(0)
batch = 3

detector = Detector(net,
                    prefix,
                    epoch,
                    data_shape,
                    mean_pixels,
                    ctx=ctx,
                    batch_size=batch)

ims = [
    cv2.resize(cv2.imread(img), (data_shape, data_shape)) for i in range(batch)
]


def get_batch(imgs):
    img_len = len(imgs)
    l = []
    for i in range(batch):
        if i < img_len:
            img = np.swapaxes(imgs[i], 0, 2)
            # assumed completion (the source snippet is truncated here):
            # finish the HWC -> CHW transpose and collect a padded NCHW batch
            img = np.swapaxes(img, 1, 2)
            l.append(img[np.newaxis, :])
        else:
            l.append(np.zeros((1, 3, data_shape, data_shape)))
    return np.concatenate(l, axis=0)
Example #46
import os

import redis
from multiprocessing import Process
from urllib.parse import urlparse
from tld import get_fld

from capture.har import Har
from capture.chrome import Chrome
from database.observer import Observer
from detect.detector import Detector

REDIS_SERVER = os.environ['REDIS_SERVER']
REDIS_PASSWORD = os.environ['REDIS_PASSWORD']
REDIS_TOPIC_OBSERVER_URLS = os.environ['REDIS_TOPIC_OBSERVER_URLS']
redis = redis.StrictRedis(host=REDIS_SERVER, password=REDIS_PASSWORD)

har = Har()
observer = Observer()
chrome = Chrome()
detector = Detector(observer)
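
# Inferred pipeline, not stated in the source: Chrome drives the page load,
# Har captures the resulting network traffic, and Detector matches it against
# the rules held by the database Observer.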

def get_origin(url):
  if url.startswith('http'):
    origin = get_fld(url)
  else:
    origin = get_fld('http://' + url)
  return origin
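
# Usage note (assuming tld's usual behavior): get_fld returns the registered
# "first level domain", so get_origin('sub.example.co.uk/x') yields
# 'example.co.uk' whether or not the scheme is present.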

# def do_observe(id, observer_url, language):
#
#   origin = get_origin(observer_url)
#
#   data = har.capture(observer_url)
#   # print(data)
#
Example #47
    def test_has_non_matches(self):
        target = Target({})
        actual = target.find_matches(
            3,
            [
                {
                    'value': 'red',
                    'pos': 'VBD',
                    'stem': 'red',
                    'stop_word': False,
                    'use': True,
                    'start': 0,
                    'skip_word': False,
                    'end': 3
                },
                {
                    'value': 'valentino',
                    'pos': 'VBN',
                    'stem': 'valentino',
                    'stop_word': False,
                    'use': True,
                    'start': 4,
                    'skip_word': False,
                    'end': 13
                }
            ],
            {
                "en": {
                    "red": [
                        {"key": "red", "type": "color", "source": "content", 'match_type': 'alias'}
                    ]
                }
            }
        )

        self.assertDictEqual(
            {
                "found": [
                    {
                        'end': 3,
                        'position': '0_3',
                        'tokens': ['red'],
                        'found_item': [
                            {
                                'key': 'red',
                                'match_type': 'alias',
                                'source': 'content',
                                'type': 'color'
                            }
                        ],
                        'term': 'red',
                        'start': 0
                    }
                ],
                'can_not_match': [
                    {
                        'value': 'valentino',
                        'pos': 'VBN',
                        'stem': 'valentino',
                        'stop_word': False,
                        'use': True,
                        'start': 4,
                        'skip_word': False,
                        'end': 13
                    }
                ]
            },
            actual
        )
Example #48
import torch
from PIL import Image
from matplotlib import pyplot
from detect.detector import Detector
if __name__ == '__main__':

    image_path = r'E:\PyCharmProject\mtcnn\src\images\2.jpg'

    p_net_param = r'E:\PyCharmProject\mtcnn\config\p.pt'
    r_net_param = r'E:\PyCharmProject\mtcnn\config\r.pt'
    o_net_param = r'E:\PyCharmProject\mtcnn\config\o.pt'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    detector = Detector(p_net_param, r_net_param, o_net_param, device)

    with Image.open(image_path) as img:
        print(img.size)
        boxes = detector.detect(img)
        print(boxes)
        for box in boxes:
            x1 = int(box[0])
            y1 = int(box[1])
            x2 = int(box[2])
            y2 = int(box[3])

            pyplot.gca().add_patch(
                pyplot.Rectangle((x1, y1),
                                 width=x2 - x1,
                                 height=y2 - y1,
                                 # the styling and the display calls below are
                                 # an assumed completion; the source truncates here
                                 fill=False,
                                 edgecolor='red'))

        pyplot.imshow(img)
        pyplot.show()
Example #49
    def test_multiple_term(self):
        target = Target({})
        actual = target.find_matches(
            3,
            [
                {
                    'value': 'red',
                    'pos': 'VBD',
                    'stem': 'red',
                    'stop_word': False,
                    'use': True,
                    'start': 0,
                    'skip_word': False,
                    'end': 3
                },
                {
                    'value': 'valentino',
                    'pos': 'VBN',
                    'stem': 'valentino',
                    'stop_word': False,
                    'use': True,
                    'start': 4,
                    'skip_word': False,
                    'end': 13
                }
            ],
            {
                "en": {
                    "red": [
                        {"key": "red", "type": "color", "source": "content", 'match_type': 'alias'}
                    ],
                    "red valentino": [
                        {"key": "red valentino", "type": "brand", "source": "content", 'match_type': 'alias'}
                    ]
                }
            }
        )

        self.assertDictEqual(
            actual,
            {
                "can_not_match": [],
                "found": [
                    {
                        'found_item': [
                            {
                                'key': 'red valentino',
                                'match_type': 'alias',
                                'source': 'content',
                                'type': 'brand'
                            }
                        ],
                        'term': 'red valentino',
                        'end': 13,
                        'start': 0,
                        'position': '0_13',
                        'tokens': ['red', 'valentino']
                    }
                ]
            }
        )
Example #50
class Yolov3Trainer:
    def __init__(self, net, netfile_name, cfgfile=None):
        self.net = net
        self.netfile_name = netfile_name
        print(cfgfile)
        if cfgfile is not None:
            self.cfginit(cfgfile)
        print(SAVE_DIR)
        utils.makedir(SAVE_DIR)
        parser = argparse.ArgumentParser(description="base class for network training")
        self.args = self.argparser(parser)

        net_savefile = "{0}.{1}".format(self.netfile_name, NETFILE_EXTENTION)
        self.save_dir = os.path.join(SAVE_DIR, "nets")
        utils.makedir(self.save_dir)
        self.save_path = os.path.join(self.save_dir, net_savefile)
        self.savepath_epoch = os.path.join(SAVEDIR_EPOCH, net_savefile)

        if os.path.exists(self.save_path) and CONTINUETRAIN:

            try:
                self.net.load_state_dict(torch.load(self.save_path))
                print("net param load successful")

            except Exception:
                # state-dict load failed; fall back to loading the whole pickled model
                self.net = torch.load(self.save_path)
                print("net load successful")


        else:
            self.net.paraminit()
            print("param initial complete")

        if ISCUDA:
            self.net = self.net.to(DEVICE)

        if NEEDTEST:
            self.detecter = Detector()

        self.logdir = os.path.join(SAVE_DIR, "log")

        utils.makedir(self.logdir)
        self.logfile = os.path.join(self.logdir, "{0}.txt".format(self.netfile_name))
        if not os.path.exists(self.logfile):
            with open(self.logfile, 'w') as f:
                print("%.2f %d    " % (0.00, 0), end='\r', file=f)
                print("logfile created")

        self.optimizer = optim.Adam(self.net.parameters())

        # loss function definitions
        self.conf_loss_fn = nn.BCEWithLogitsLoss()  # confidence loss
        self.center_loss_fn = nn.BCEWithLogitsLoss()  # center-point loss
        self.wh_loss_fn = nn.MSELoss()  # width/height loss
        self.cls_loss_fn = nn.CrossEntropyLoss()  # classification (cross-entropy) loss

        self.detecter = Detector()  # note: duplicates the NEEDTEST branch above

        print("initial complete")

    def cfginit(self, cfgfile):
        config = configparser.ConfigParser()
        config.read(cfgfile)
        items_ = config.items(self.netfile_name)

        for key, _value in items_:
            key_u = key.upper()
            if key_u in globals():
                # try progressively looser types: int -> float -> bool -> raw string
                for getter in (config.getint, config.getfloat,
                               config.getboolean, config.get):
                    try:
                        globals()[key_u] = getter(self.netfile_name, key)
                        break
                    except ValueError:
                        continue


    def argparser(self, parser):
        """default argparse, please customize it by yourself. """

        parser.add_argument("-e", "--epoch", type=int, default=EPOCH, help="number of epochs")
        parser.add_argument("-b", "--batch_size", type=int, default=BATCHSIZE, help="mini-batch size")
        parser.add_argument("-n", "--num_workers", type=int, default=NUMWORKERS,
                            help="number of threads used during batch generation")
        parser.add_argument("-l", "--lr", type=float, default=LR, help="learning rate for gradient descent")
        parser.add_argument("-r", "--record_point", type=int, default=RECORDPOINT, help="print frequency")
        parser.add_argument("-t", "--test_point", type=int, default=TESTPOINT,
                            help="interval between evaluations on validation set")
        parser.add_argument("-a", "--alpha", type=float, default=ALPHA, help="ratio of conf and offset loss")
        parser.add_argument("-d", "--threshold", type=float, default=THREHOLD, help="threhold")

        return parser.parse_args()
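
    # Note: the argparse defaults above are module-level globals that cfginit()
    # may already have overridden, so the config file effectively sets the
    # defaults and the command line can still override them per run.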

    def _loss_fn(self, output, target, alpha):
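        # Assumed tensor layout: `output` arrives as (N, 3*(5+C), H, W) from the
        # network and `target` as (N, H, W, 3, 5+C), where channel 0 is the
        # confidence, 1:3 the center offsets, 3:5 the width/height, and 5 the
        # class index.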

        output = output.permute(0, 2, 3, 1)
        output = output.reshape(output.size(0), output.size(1), output.size(2), 3, -1)

        mask_obj = target[..., 0] > 0
        mask_noobj = target[..., 0] == 0

        output_obj, target_obj = output[mask_obj], target[mask_obj]

        loss_obj_conf = self.conf_loss_fn(output_obj[:, 0], target_obj[:, 0])
        loss_obj_center = self.center_loss_fn(output_obj[:, 1:3], target_obj[:, 1:3])
        loss_obj_wh = self.wh_loss_fn(output_obj[:, 3:5], target_obj[:, 3:5])
        loss_obj_cls = self.cls_loss_fn(output_obj[:, 5:], target_obj[:, 5].long())
        loss_obj = loss_obj_conf + loss_obj_center + loss_obj_wh + loss_obj_cls

        output_noobj, target_noobj = output[mask_noobj], target[mask_noobj]
        loss_noobj = self.conf_loss_fn(output_noobj[:, 0], target_noobj[:, 0])

        loss = alpha * loss_obj + (1 - alpha) * loss_noobj
        return loss

    def logging(self, result, dataloader_len, record_point):

        with open(self.logfile, "r+") as f:

            if f.readline() == "":
                batchcount = 0
                f.seek(0, 0)
                print("%.2f %d        " % (0.00, 0), end='\r', file=f)

            else:
                f.seek(0, 0)
                batchcount = int(f.readline().split()[-1].strip()) + record_point


            f.seek(0, 0)
            print("%.2f %d " % (batchcount / dataloader_len, batchcount), end='', file=f)

            f.seek(0, 2)
            print(result, file=f)
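
    # Log file layout written above: the first line holds
    # "<progress> <batch_count>", and each later line is a result dict literal
    # that getstatistics() below parses back with eval().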

    def getstatistics(self):
        datalist = []
        with open(self.logfile) as f:
            for line in f.readlines():
                if not line[0].isdigit():
                    datalist.append(eval(line))
        return datalist

    def scalarplotting(self, datalist, key):
        save_dir = os.path.join(SAVE_DIR, key)
        utils.makedir(save_dir)
        save_name = "{0}.jpg".format(key)

        save_file = os.path.join(save_dir, save_name)
        values = []
        for data_dict in datalist:
            if data_dict:
                values.append(data_dict[key])
        if len(values) != 0:
            plt.plot(values)
            plt.savefig(save_file)
            plt.show()

    def FDplotting(self, net):
        save_dir = os.path.join(SAVE_DIR, "params")
        utils.makedir(save_dir)
        save_name = "{0}_param.jpg".format(self.netfile_name)
        save_file = os.path.join(SAVE_DIR, save_name)
        params = []
        for param in net.parameters():
            params.extend(param.view(-1).cpu().detach().numpy())
        params = np.array(params)
        histo = np.histogram(params, 10, range=(np.min(params), np.max(params)))
        plt.plot(histo[1][1:], histo[0])
        plt.savefig(save_file)
        plt.show()

    def train(self):
        dataset = YoloDataset(LABEL_PATH, PIC_DIR)
        train_loader = data.DataLoader(dataset, batch_size=self.args.batch_size, shuffle=True,
                                       num_workers=self.args.num_workers,
                                       drop_last=True)
        dataloader_len = len(train_loader)

        start_time = time.time()

        if os.path.exists(self.logfile):
            with open(self.logfile) as f:
                if f.readline() != "":
                    f.seek(0, 0)
                    batch_count = int(float(f.readline().split()[1]))

        for i in range(self.args.epoch):

            for j, (target13, target26, target52, img_data) in enumerate(train_loader):

                self.net.train()
                if ISCUDA:
                    target13 = target13.to(DEVICE)
                    target26 = target26.to(DEVICE)
                    target52 = target52.to(DEVICE)
                    img_data = img_data.to(DEVICE)

                output_13, output_26, output_52 = self.net(img_data)

                loss_13 = self._loss_fn(output_13, target13, alpha=ALPHA)
                loss_26 = self._loss_fn(output_26, target26, alpha=ALPHA)
                loss_52 = self._loss_fn(output_52, target52, alpha=ALPHA)
                loss = loss_13 + loss_26 + loss_52
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()


                if j % self.args.record_point == 0:

                    checktime = time.time() - start_time

                    result = "{'epoch':%d,'batch':%d,'loss':%.5f,'loss_13':%.5f,'loss_26':%.5f,'loss_52':%f',total_time':%.2f,'time':%s}" % (
                        i, j, loss, loss_13, loss_26, loss_52, checktime,
                        time.strftime("%Y%m%d%H%M%S", time.localtime()))
                    print(result)

                    # self.logging(result, dataloader_len, self.args.record_point)
                    if NEEDSAVE:
                        # torch.save(self.net.state_dict(), self.save_path)
                        torch.save(self.net, self.save_path)
                        print("net save successful")

                if NEEDTEST and j % self.args.test_point == 0:
                    self.net.eval()

                    batch_count = i
                    self.test(batch_count, j)
            if NEEDSAVE:
                torch.save(self.net.state_dict(), self.savepath_epoch)
                # torch.save(self.net, self.savepath_epoch)
                # print("an epoch save successful")

    def test(self, batch_count, j):
        with torch.no_grad():
            self.net.eval()

            img = Image.open(TEST_IMG)
            # img_ = cv2.imread(TEST_IMG)

            last_boxes = self.detecter.detect(img, self.args.threshold, net=self.net)

            draw = ImageDraw.Draw(img)
            font = ImageFont.truetype(font="arial.ttf", size=10, encoding="utf-8")

            if np.any(last_boxes):
                for box in last_boxes:
                    xybox = box[:4].astype("i4")
                    text_x, text_y = box[0], box[1] - 10
                    text_conf = box[0] + 30
                    draw.text((text_x, text_y), cfg.COCO_DICT[int(box[5])], fill=(255, 0, 0), font=font)
                    draw.text((text_conf, text_y), "%.2f" % box[4], fill=(255, 0, 0), font=font)
                    draw.rectangle(list(xybox), outline="green", width=2)

            # img.show()
            if NEEDSAVE:
                testpic_savedir = os.path.join(SAVE_DIR, "testpic", self.netfile_name)
                utils.makedir(testpic_savedir)
                testpic_savefile = os.path.join(testpic_savedir, "{0}_{1}.jpg".format(batch_count, j))
                img.save(testpic_savefile)

            if NEEDSHOW:
                plt.clf()
                plt.axis("off")
                plt.imshow(img)
                plt.pause(0.1)
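
# A minimal usage sketch (hypothetical names: a Yolov3Net module from this
# project and a [yolov3_coco] section in train.cfg are assumed):
#
#   trainer = Yolov3Trainer(Yolov3Net(), "yolov3_coco", cfgfile="train.cfg")
#   trainer.train()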
Example #51
    """
    #frame_resize = mx.nd.array(cv2.resize(frame, (self.data_shape[0], self.data_shape[1])))
    
    #frame_resize = mx.img.imresize(frame, self.data_shape[0], self.data_shape[1], cv2.INTER_LINEAR)
    
    
    # Change dimensions from (w,h,channels) to (channels, w, h)
    
    #frame_t = mx.nd.transpose(frame_resize, axes=(2,0,1))
    
    #frame_norm = frame_t - self.mean_pixels_nd
    
    print(y_gen[0].asnumpy().shape)
    """

    print(y_gen[0].asnumpy().flatten()[:20])

    result = Detector.filter_positive_detections(y_gen[0].asnumpy())

    for k, det in enumerate(result):

        #img = cv2.imread(im_list[k])

        #img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        visualize_detection(frame, det, classes, 0.6)

    end = time.time()

    #print(time.process_time() - start)
Example #52
class Detect(RequestHandler):
    brute_detector = None
    alias_data = None
    data_response = None
    param_extractor = None
    path_extractor = None
    entity_factory = None

    def data_received(self, chunk):
        pass

    def initialize(self, alias_data):
        from detect.data.response import Response

        self.data_response = Response()
        self.data_response.open_connection()
        self.alias_data = alias_data
        self.param_extractor = ParamExtractor(self)
        self.path_extractor = PathExtractor(self)
        self.entity_factory = EntityFactory(self.alias_data)
        self.brute_detector = Detector(self.alias_data)

    def on_finish(self):
        pass

    @asynchronous
    def post(self, *args, **kwargs):
        self.set_header("Content-Type", "application/json")

        detection_id = ObjectId()

        app_log.info(
            "app=detection,function=detect,detection_id=%s,application_id=%s,session_id=%s,q=%s",
            detection_id,
            self.param_extractor.application_id(),
            self.param_extractor.session_id(),
            self.param_extractor.query(),
        )

        # NOTE: this branch is hard-disabled, so every request takes the local
        # brute-force detector path below.
        if False:
            url = "%smessage?v=%s&q=%s&msg_id=%s" % (
                WIT_URL,
                WIT_URL_VERSION,
                url_escape(self.param_extractor.query()),
                str(detection_id),
            )
            r = HTTPRequest(url, headers={"Authorization": "Bearer %s" % WIT_TOKEN})
            client = AsyncHTTPClient()
            client.fetch(r, callback=self.wit_call_back)
        else:
            date = datetime.now()
            outcomes = self.brute_detector.detect(self.param_extractor.query())
            self.data_response.insert(
                self.param_extractor.user_id(),
                self.param_extractor.application_id(),
                self.param_extractor.session_id(),
                detection_id,
                "brute",
                date,
                self.param_extractor.query(),
                outcomes=outcomes,
            )

            self.set_status(202)
            self.set_header("Location", "/%s" % str(detection_id))
            self.set_header("_id", str(detection_id))
            self.finish()

            Worker(
                self.param_extractor.user_id(),
                self.param_extractor.application_id(),
                self.param_extractor.session_id(),
                detection_id,
                date,
                self.param_extractor.query(),
                self.param_extractor.skip_slack_log(),
                detection_type="wit",
                outcomes=outcomes,
            ).start()

    @asynchronous
    def get(self, detection_id, *args, **kwargs):
        data = self.data_response.get(self.path_extractor.detection_id(detection_id))
        if data is not None:
            self.set_header("Content-Type", "application/json")
            self.set_status(200)
            self.finish(
                dumps(
                    {
                        "type": data["type"],
                        "q": data["q"],
                        "outcomes": data["outcomes"],
                        "_id": data["_id"],
                        "version": data["version"],
                        "timestamp": data["timestamp"],
                    }
                )
            )
        else:
            self.set_status(404)
            self.finish()

    def wit_call_back(self, response):
        data = json_decode(response.body)
        outcomes = []
        date = datetime.now()
        for outcome in data["outcomes"]:
            entities = []
            for _type in outcome["entities"].keys():
                if _type not in ["polite"]:
                    for value in outcome["entities"][_type]:
                        suggested = value["suggested"] if "suggested" in value else False
                        key = value["value"]["value"] if type(value["value"]) is dict else value["value"]
                        entity = self.entity_factory.create(_type, key, suggested)

                        # TODO: this needs to be moved somewhere else, preferably a separate service call
                        entities.append(entity)

            outcomes.append(
                {"confidence": outcome["confidence"] * 100, "intent": outcome["intent"], "entities": entities}
            )

        self.data_response.insert(
            self.param_extractor.user_id(),
            self.param_extractor.application_id(),
            self.param_extractor.session_id(),
            ObjectId(data["msg_id"]),
            "wit",
            date,
            self.param_extractor.query(),
            outcomes=outcomes,
        )

        self.set_status(202)
        self.set_header("Location", "/%s" % data["msg_id"])
        self.set_header("_id", data["msg_id"])
        self.finish()

        Worker(
            self.param_extractor.user_id(),
            self.param_extractor.application_id(),
            self.param_extractor.session_id(),
            ObjectId(data["msg_id"]),
            date,
            self.param_extractor.query(),
            self.param_extractor.skip_slack_log(),
            detection_type="wit",
            outcomes=outcomes,
        ).start()