Ejemplo n.º 1
0
def p_numexpressions(p):
    """numexpressions : LBRACKET numexpression RBRACKET numexpressions
                    | LBRACKET numexpression RBRACKET
    """
    # NOTE: PLY reads the docstring above as the grammar rule, so it must stay
    # verbatim; documentation for this action lives in these comments.
    # Builds an "array" Tree node around the bracketed index expression (p[2]);
    # when further bracketed indices follow, their node (p[4]) is passed as the
    # third Tree argument, chaining the dimensions.
    has_more_dims = len(p) == 5
    if has_more_dims:
        node = Tree("array", p[2], p[4])
    else:
        node = Tree("array", p[2])
    p[0] = node
    # Record the production name followed by every matched symbol value.
    global_tree.append(tuple(["numexpressions", *p[1:]]))
Ejemplo n.º 2
0
def p_unaryiter(p):
    """unaryiter : unaryop unaryexpr unaryiter
                    | epsilon
    """
    # NOTE: PLY uses the docstring above as the grammar rule; keep it verbatim.
    # For the non-empty alternative the operator (p[1]) becomes a Tree whose
    # third argument is either the operand (p[2]) or, when more operators
    # follow, the already-built tail (p[3]) with the operand stored as its
    # `left`. The epsilon alternative does not assign p[0].
    if len(p) == 4:
        op, operand, tail = p[1], p[2], p[3]
        if tail is not None:
            tail.left = operand
            p[0] = Tree(op, None, tail)
        else:
            p[0] = Tree(op, None, operand)
    global_tree.append(tuple(["unaryiter", *p[1:]]))
Ejemplo n.º 3
0
def p_signedterms(p):
    """signedterms : signal term signedterms
                    | epsilon
    """
    # NOTE: PLY uses the docstring above as the grammar rule; keep it verbatim.
    # Non-empty alternative: the sign (p[1]) labels a new Tree whose third
    # argument is the term (p[2]) itself, or the following signed-term chain
    # (p[3]) after the term has been attached as that chain's `left`.
    # The epsilon alternative does not assign p[0].
    if len(p) == 4:
        sign, term, rest = p[1:4]
        if rest is None:
            right = term
        else:
            rest.left = term
            right = rest
        p[0] = Tree(sign, None, right)
    global_tree.append(tuple(["signedterms", *p[1:]]))
Ejemplo n.º 4
0
    def __getitem__(self, idx):
        """Load sample ``idx``: its PNG image and the label tree from its XML.

        The sample name is ``str(self.partition[idx])``. The image is read from
        ``<img_dir>/<name>.png`` and the tree from ``<tree_dir>/<name>.xml``.
        Every XML element becomes a ``Tree`` node labelled with its tag, in
        document order. Every node's child list is closed with a
        ``Tree('end')`` sentinel, so leaves have exactly one ('end') child.
        ``tree_transform`` and ``img_transform`` are applied when they are
        set (not None).

        Returns:
            dict with keys ``'img'`` and ``'tree'``.
        """
        name = str(self.partition[idx])
        # NOTE(review): PIL's Image.open is lazy and keeps the file handle
        # open until the pixel data is loaded; confirm img_transform loads it.
        img = Image.open(os.path.join(self.img_dir, name + '.png'))
        et_root = ET.parse(os.path.join(self.tree_dir, name + '.xml')).getroot()

        def build_tree(ele, et_ele):
            # Mirror et_ele's children under ele, then append the 'end'
            # sentinel that terminates ele's child list.
            for et_child in et_ele:
                new_child = Tree(et_child.tag)
                ele.add_child(new_child)
                build_tree(new_child, et_child)
            ele.add_child(Tree('end'))

        root = Tree(et_root.tag)
        build_tree(root, et_root)

        # Identity comparison: a transform object may override __eq__/__ne__,
        # which would make `!= None` unreliable.
        if self.tree_transform is not None:
            root = self.tree_transform(root)
        if self.img_transform is not None:
            img = self.img_transform(img)
        return {'img': img, 'tree': root}
Ejemplo n.º 5
0
 def __init__(self, time_not_looping=0.05):
     """Set up subscription bookkeeping, configure the client from `config`,
     connect to the broker and schedule the keep-connected coroutine.

     Args:
         time_not_looping: stored as ``self._not_looping``; presumably the
             pause between network-loop iterations of ``_keep_connected``
             (that method is not part of this snippet) — verify.
     """
     self._not_looping = time_not_looping
     # Subscription registry rooted at the MQTT home topic; presumably maps
     # topics to their callback functions — verify against Tree's API.
     self._subscriptions = Tree(config.MQTT_HOME, ["Functions"])
     # Payload variants for "on"/"off" (not used in the code shown here).
     self.payload_on = ("ON", True, "True")
     self.payload_off = ("OFF", False, "False")
     # Topics whose retained message is still pending (starts empty).
     self._retained = []
     self.mqtt_home = config.MQTT_HOME
     # Initialise the base MQTT client after the attributes above are set.
     super().__init__()
     self.id = "SmartServer"
     self.enable_logger(log)
     self.username_pw_set(config.MQTT_USER, config.MQTT_PASSWORD)
     # Wire client events to this object's handlers before connecting.
     self.on_connect = self._connected
     self.on_message = self._execute_sync
     self.on_disconnect = self._on_disconnect
     # Arguments presumably (host, port, keepalive) per the paho-style API.
     self.connect(config.MQTT_HOST, 1883, 60)
     # Schedule the keep-connected coroutine on the asyncio event loop.
     asyncio.ensure_future(self._keep_connected())
Ejemplo n.º 6
0
def p_unaryexpr(p):
    """unaryexpr : signal factor
                    | factor
    """
    # NOTE: PLY uses the docstring above as the grammar rule; keep it verbatim.
    # A bare factor passes through unchanged; a signed factor becomes a Tree
    # labelled with the sign (p[1]) holding the factor (p[2]).
    p[0] = p[1] if len(p) == 2 else Tree(p[1], p[2])
    global_tree.append(tuple(["unaryexpr", *p[1:]]))
Ejemplo n.º 7
0
def p_funclist(p):
    """funclist : funcdef funclist
                | funcdef
    """
    # NOTE: PLY uses the docstring above as the grammar rule; keep it verbatim.
    # A single function definition passes through unchanged; two or more are
    # chained as nested 'funclist' nodes (this funcdef, remaining list).
    if len(p) == 3:
        p[0] = Tree('funclist', p[1], p[2])
    else:
        p[0] = p[1]
    # Log in the same flat (rule_name, sym1, sym2, ...) tuple shape that every
    # other grammar rule appends to global_tree.
    global_tree.append(tuple(['funclist'] + p[1:]))
Ejemplo n.º 8
0
    def forward(self, img, path):
        """Run the tree-LSTM along ``path`` and project each node's state.

        ``path`` lists nodes from the root downwards, each one a child of the
        previous one. For every node on the path a context list is built:

        * the parent, or for the root a dummy node whose state is
          (``self.img_to_c(self.cnn(img))``, zeros);
        * the parent's children to the left of the node;
        * the node's own children before the next path node (all of them for
          the last node on the path).

        Sibling/child context nodes get a fresh ``self.tree_lstm`` state with
        zero initial (c, h). The node's state is computed from the
        concatenated context states, and its hidden part (``state[1]``) is
        passed through ``self.h_to_word``.

        Returns:
            ``torch.cat`` of the per-node outputs along dim 0.
        """
        def zero_state(ref):
            # (1, mem_dim) zeros on ref's device/dtype, turned into a grad leaf.
            return ref.detach().new_zeros(1, self.mem_dim).requires_grad_()

        def leaf_state(node):
            # Tree-LSTM state of `node` computed with zero initial memory.
            c0 = zero_state(node.value)
            h0 = zero_state(node.value)
            return self.tree_lstm(node.value, c0, h0)

        output = []
        for i, node in enumerate(path):
            parent = node.parent
            if parent is None:
                # Root: seed the context with a dummy carrying image features.
                img_features = self.cnn(img)
                child_c = self.img_to_c(img_features)
                child_h = zero_state(path[0].value)
                dummy = Tree('')
                dummy.state = child_c, child_h
                feed_list = [dummy]
            else:
                feed_list = [parent]
                for sibling in parent.children[:parent.children.index(node)]:
                    sibling.state = leaf_state(sibling)
                    feed_list.append(sibling)

            # Children of `node` already decoded: all of them for the last
            # path node, otherwise those before the next node on the path.
            if i == len(path) - 1:
                n_done = len(node.children)
            else:
                n_done = node.children.index(path[i + 1])

            for child in node.children[:n_done]:
                child.state = leaf_state(child)
                feed_list.append(child)

            child_c, child_h = zip(*(n.state for n in feed_list))
            child_c = torch.cat(child_c, dim=0)
            child_h = torch.cat(child_h, dim=0)

            node.state = self.tree_lstm(node.value, child_c, child_h)
            output.append(self.h_to_word(node.state[1]))

        return torch.cat(output, dim=0)
Ejemplo n.º 9
0
def p_lvalue(p):
    """lvalue : IDENT
                | IDENT numexpressions
    """
    # NOTE: PLY uses the docstring above as the grammar rule; keep it verbatim.
    name = p[1]
    # Report use of an identifier not found from the current scope; whether
    # parsing continues afterwards depends on semantic_error.
    if find_var(name, Scope.actual_scope) is None:
        semantic_error("variable \'{}\' not declared.".format(name), p)

    # A plain identifier stays the bare name; an indexed one becomes
    # Tree(name, index_expressions).
    p[0] = name if len(p) == 2 else Tree(name, p[2])
    global_tree.append(tuple(['lvalue', *p[1:]]))
Ejemplo n.º 10
0
def generate_random_tree(size):
    """Grow a random tree breadth-first up to ``size`` nodes.

    The root and every added node get a random value in ``[0, MAX_VAL]``.
    Nodes are expanded in FIFO order. Each expansion attaches between 1 and
    ``min(4, size - current_size)`` children, so the tree never exceeds
    ``size``, assuming ``len(tree)`` counts its nodes. For ``size <= 1`` only
    the root is created.
    """
    tree = Tree()
    tree.add_root(random.randint(0, MAX_VAL))
    pending = deque([tree.root()])

    while len(tree) < size:
        node = pending.popleft()
        remaining = size - len(tree)
        for _ in range(random.randint(1, min(4, remaining))):
            tree.add_child(node, random.randint(0, MAX_VAL))
        pending.extend(tree.children(node))

    return tree
Ejemplo n.º 11
0
def predict_tree(model, img, word_dict, tree, max_child=7):
    """Decode a tree for ``img`` node by node, teacher-forced by ``tree``.

    Nodes are expanded breadth-first from a 'root' node. For the node being
    expanded, ``model(img, path_from_root)`` is evaluated and the last row of
    its output is the prediction for that node's next child. The predicted
    word is the ``word_dict`` key at the argmax position. It is compared with
    the child at the same position in the target ``tree``. A mismatch is
    printed and the target word is used instead, so the decoded structure
    always follows ``tree``. A node stops receiving children once an 'end'
    child has been added.

    Args:
        model: callable as ``model(img, path)`` returning a tensor whose last
            row scores the vocabulary.
        img: image input forwarded to ``model``.
        word_dict: mapping word -> numpy vector; its key order defines the
            index -> word mapping used for the argmax.
        tree: target tree whose nodes' ``.value`` are words. Presumably every
            child list ends with an 'end' node (as produced by the dataset's
            build_tree) — verify with the caller.
        max_child: unused; kept for interface compatibility.

    Returns:
        The decoded root ``Tree``; each node's word is in its ``.str``.
    """
    from collections import deque

    def str_to_tensor(word):
        # `device` is a module-level global defined elsewhere in this file.
        return torch.from_numpy(word_dict[word]).float().to(device)

    def get_path(node, root):
        # Follow parent links up to `root`, then return them root-first.
        path = [node]
        while node != root:
            node = node.parent
            path.append(node)
        path.reverse()
        return path

    # Index -> word table for the argmax lookup, built once rather than per step.
    vocab = list(word_dict.keys())

    root = Tree(str_to_tensor('root'))
    root.str = 'root'

    target_queue = deque([tree])
    target_index = 0

    queue = deque([root])
    while queue:
        parent = queue[0]
        path = get_path(parent, root)

        pred = model(img, path).detach()
        pred_node = Tree(pred[-1])
        word = vocab[torch.argmax(pred_node.value)]

        # Teacher forcing: check against the target child at this position and
        # substitute the target word on a mismatch.
        target_parent = target_queue[0]
        target_word = target_parent.children[target_index].value
        if word != target_word:
            print('wrong!!! target:{} predict:{}'.format(target_word, word))
            word = target_word

        pred_node.str = word
        parent.add_child(pred_node)

        if word == 'end':
            # This node is complete; move on to the next queued node.
            target_index = 0
            target_queue.popleft()
            queue.popleft()
        else:
            target_queue.append(target_parent.children[target_index])
            queue.append(pred_node)
            target_index += 1

    return root
Ejemplo n.º 12
0
    def _decode(self, code):
        """ Given a binary code build a tree corresponding to that code.
        @param code (int):  A 2b-bit number encoding a tree of size b.
        @return tree (Tree): A Tree object corresponding to the binary encoding.
        """
        binary_code = [int(bit) for bit in "{0:b}".format(code)]
        binary_code = binary_code[1:]

        # Assert the binary code is a valid encoding of a tree.
        zeros = len([x for x in binary_code if x == 0])
        ones = len([x for x in binary_code if x == 1])
        assert (zeros == ones)

        tree = Tree()
        elem = 0
        cursor = tree.add_root(elem)
        for bit in binary_code:
            if bit == 0:
                elem += 1
                cursor = tree.add_child(cursor, elem)
            elif bit == 1:
                cursor = tree.parent(cursor)

        return tree
Ejemplo n.º 13
0
def p_binaryoperator(p):
    """binaryoperator : numexpression relationaloperator
    """
    # NOTE: PLY uses the docstring above as the grammar rule; keep it verbatim.
    # Builds a Tree labelled with the relational operator (p[2]) that holds
    # the left-hand numeric expression (p[1]).
    left_operand, operator = p[1], p[2]
    p[0] = Tree(operator, left_operand)
    global_tree.append(tuple(['binaryoperator', *p[1:]]))
Ejemplo n.º 14
0
class MQTTHandler(MQTTClient):
    """MQTT client that routes incoming messages to per-topic callbacks.

    Callbacks are kept in ``self._subscriptions`` (a ``Tree`` rooted at
    ``config.MQTT_HOME``). Topics beginning with "." are device topics and
    ``getRealTopic`` expands them to ``<mqtt_home>/<id>/<rest>``. Topics
    listed in ``self._retained`` are waiting for their retained message.
    Messages on them are handled with ``retain=True``, and callback results
    are not published back.
    """

    def __init__(self, time_not_looping=0.05):
        """Configure and connect the client, then schedule its network loop.

        Args:
            time_not_looping: seconds ``_keep_connected`` sleeps between
                network-loop calls.
        """
        self._not_looping = time_not_looping
        self._subscriptions = Tree(config.MQTT_HOME, ["Functions"])
        self.payload_on = ("ON", True, "True")
        self.payload_off = ("OFF", False, "False")
        self._retained = []
        self.mqtt_home = config.MQTT_HOME
        super().__init__()
        self.id = "SmartServer"
        self.enable_logger(log)
        self.username_pw_set(config.MQTT_USER, config.MQTT_PASSWORD)
        self.on_connect = self._connected
        self.on_message = self._execute_sync
        self.on_disconnect = self._on_disconnect
        self.connect(config.MQTT_HOST, 1883, 60)
        # Drive the client's network loop from the asyncio event loop.
        asyncio.ensure_future(self._keep_connected())

    def _connected(self, client, userdata, flags, rc):
        # on_connect handler: log the result, then (re)subscribe every topic.
        log.info("Connection returned result: " + connack_string(rc))
        self._publishDeviceStats()
        self._subscribeTopics()

    def _on_disconnect(self, client, userdata, rc):
        # A non-zero rc means the disconnect was not requested by this client.
        if rc != 0:
            log.warn("Unexpected disconnection.")

    async def _keep_connected(self):
        """Call the client's network loop forever, yielding to asyncio between calls."""
        log.info("Keeping connected")
        while True:
            self.loop(0.05)
            await asyncio.sleep(self._not_looping)

    def _subscribeTopics(self):
        # Subscribe at the broker (qos 1) to every topic in the registry.
        for _obj, topic in self._subscriptions.__iter__(with_path=True):
            super().subscribe(topic, qos=1)

    def unsubscribe(self, topic, callback=None):
        """Remove a subscription.

        With ``callback=None`` the topic is dropped from the registry and
        unsubscribed at the broker. Otherwise only ``callback`` is removed from
        the topic's callback list; a single non-list entry removes the topic
        from the registry. In both callback cases the broker subscription is
        kept.
        """
        if self._isDeviceTopic(topic):
            topic = self.getRealTopic(topic)
        if callback is None:
            log.debug("unsubscribing topic {}".format(topic))
            self._subscriptions.removeObject(topic)
            super().unsubscribe(topic)
        else:
            try:
                cbs = self._subscriptions.getFunctions(topic)
                if type(cbs) not in (tuple, list):
                    self._subscriptions.removeObject(topic)
                    return
                try:
                    cbs = list(cbs)
                    cbs.remove(callback)
                except ValueError:
                    log.warn(
                        "Callback to topic {!s} not subscribed".format(topic),
                        local_only=True)
                    return
                self._subscriptions.setFunctions(topic, cbs)
            except ValueError:
                log.warn("Topic {!s} does not exist".format(topic))

    async def subscribe(self, topic, callback, qos=0, check_retained=True):
        """Register ``callback`` for ``topic`` and subscribe at the broker.

        With ``check_retained`` the topic is marked as awaiting its retained
        message. For a ".../set" topic the matching state topic (without
        "/set") is handled first:

        * it is subscribed;
        * its retained message is awaited (at most about a second, see
          ``_await_retained``);
        * the callback is then removed from it again.
        """
        if self._isDeviceTopic(topic):
            topic = self.getRealTopic(topic)
        log.debug("Subscribing to topic {}".format(topic))
        self._subscriptions.addObject(topic, callback)
        if check_retained:
            if topic[-4:] == "/set":
                # subscribe to topic without /set to get retained message for this topic state
                # this is done additionally to the retained topic with /set in order to recreate
                # the current state and then get new instructions in /set
                state_topic = topic[:-4]
                self._retained.append(state_topic)
                self._subscriptions.addObject(state_topic, callback)
                super().subscribe(state_topic, qos)
                await self._await_retained(state_topic, callback, True)
                # to give retained state time to process before adding /set subscription
            self._retained.append(topic)
        super().subscribe(topic, qos)
        if check_retained:
            asyncio.ensure_future(self._await_retained(topic, callback))

    def _publishDeviceStats(self):
        pass

    def getDeviceTopic(self, attrib, is_request=False):
        """Return the device topic (".<attrib>", plus "/set" for requests)."""
        if is_request:
            attrib += "/set"
        return ".{}".format(attrib)

    def _isDeviceTopic(self, topic):
        # Device topics are written relative to this device with a leading ".".
        return topic[:1] == "."

    def getRealTopic(self, device_topic):
        """Expand a ".xyz" device topic to "<mqtt_home>/<id>/xyz"."""
        if device_topic[:1] != ".":
            raise ValueError("DeviceTopic does not start with .")
        return "{}/{}/{}".format(self.mqtt_home, self.id, device_topic[1:])

    async def _await_retained(self, topic, cb=None, remove_after=False):
        """Wait up to ~0.9 s for ``topic``'s retained message, then stop waiting.

        The topic is removed from ``self._retained`` if still present. With
        ``remove_after`` the callback ``cb`` is unsubscribed from it.
        """
        st = 0
        while topic in self._retained and st <= 8:
            await asyncio.sleep(0.1)
            st += 1
        try:
            log.debug("removing retained topic {}".format(topic))
            self._retained.remove(topic)
        except ValueError:
            pass
        if remove_after:
            self.unsubscribe(topic, cb)

    def _execute_sync(self, client, userdata, msg):
        """mqtt library only handles sync callbacks so add it to async loop"""
        asyncio.ensure_future(self._execute(msg.topic, msg.payload,
                                            msg.retain))

    async def _execute(self, topic, msg, retain):
        """Decode a message and run the callbacks registered for ``topic``.

        The payload is decoded as text and, if it parses, as JSON. The message
        counts as retained in any of these cases:

        * the broker flagged it;
        * ``topic`` is in ``self._retained``;
        * ``topic`` falls under a retained "prefix/#" wildcard.

        Retained messages prefer callbacks registered on ``topic + "/set"``.

        For a non-retained message on a ".../set" topic, a callback result
        that is an ``int`` (0 included) or compares equal to ``True`` is
        published, retained, to the state topic. A non-int such as ``True``
        means "echo the received message".
        """
        log.debug("mqtt execution: {!s} {!s}".format(topic, msg))
        msg = msg.decode()
        try:
            msg = json.loads(msg)
        except ValueError:
            # Not JSON (json.JSONDecodeError is a ValueError): keep the text.
            pass
        cb = None
        if topic in self._retained:
            retain = True
        else:
            for topicR in self._retained:
                # "prefix/#" covers topics that *start with* the prefix, not
                # topics that merely contain it somewhere.
                if topicR[-1:] == "#" and topic.startswith(topicR[:-1]):
                    retain = True
                    break
        if retain:
            try:
                cb = self._subscriptions.getFunctions(topic + "/set")
            except IndexError:
                try:
                    cb = self._subscriptions.getFunctions(topic)
                except IndexError:
                    pass
        if cb is None:
            try:
                cb = self._subscriptions.getFunctions(topic)
            except IndexError:
                log.warn("No cb found for topic {!s}".format(topic))
        if cb:
            callbacks = cb if type(cb) in (list, tuple) else [cb]
            for callback in callbacks:
                try:
                    if asyncio.iscoroutinefunction(callback):
                        res = await callback(topic=topic, msg=msg, retain=retain)
                    else:
                        res = callback(topic=topic, msg=msg, retain=retain)
                    # `type(...) is int` deliberately excludes bool, so an
                    # integer 0 is still a result to send back.
                    if not retain and (type(res) is int or res == True):
                        if res == True and type(res) is not int:
                            # send original msg back
                            res = msg
                        if topic[-4:] == "/set":
                            # if a /set topic is found, send without /set
                            self.publish(topic[:-4], res, retain=True)
                except Exception as e:
                    log.error(
                        "Error executing {!s} mqtt topic {!r}: {!s}".format(
                            "retained " if retain else "", topic, e))
            if retain and topic[-2:] != "/#":
                # only remove if it is not a wildcard topic to allow other topics
                # to handle retained messages belonging to this wildcard
                try:
                    self._retained.remove(topic)
                except ValueError:
                    pass
                    # already removed by _await_retained

    def publish(self, topic, msg, retain=False, qos=0):
        """Publish ``msg`` to ``topic``, expanding device topics first.

        Payload handling:

        * dicts and lists (including subclasses) are JSON-encoded;
        * str and bytes-like payloads are sent unchanged;
        * anything else is converted with ``str``.
        """
        if isinstance(msg, (dict, list)):
            msg = json.dumps(msg)
        elif not isinstance(msg, (str, bytes, bytearray)):
            # Bytes are left alone: str(b"x") would send the literal "b'x'".
            msg = str(msg)
        if self._isDeviceTopic(topic):
            topic = self.getRealTopic(topic)
        super().publish(topic, msg, retain=retain, qos=qos)
Ejemplo n.º 15
0
 def build_tree(ele, et_ele):
     """Mirror the XML element ``et_ele`` underneath the Tree node ``ele``.

     Each XML child becomes a ``Tree`` labelled with its tag. It is attached
     in document order and then filled recursively. ``ele``'s child list is
     finally terminated by a ``Tree('end')`` sentinel.
     """
     for xml_child in iter(et_ele):
         node = Tree(xml_child.tag)
         ele.add_child(node)
         build_tree(node, xml_child)
     terminator = Tree('end')
     ele.add_child(terminator)
Ejemplo n.º 16
0
def predict_tree(img, model, device, word_dict, max_child=6):
    """Greedily decode a word tree for ``img`` with ``model``, breadth-first.

    Starting from a 'root' node, the node at the front of the queue
    repeatedly receives a new child. The child's value is the model output
    turned into a one-hot vector at its maximum (ties give several ones).
    A node is closed in one of two ways:

    * its predicted child equals the 'end' vector;
    * it reaches ``max_child`` children, at which point an 'end' child is
      appended.

    Every non-'end' prediction is queued and expanded in turn. Finally the
    values are moved to CPU numpy arrays and mapped back to words with
    ``Vec2Word``.

    Args:
        img: image input passed to ``model``.
        model: network called as ``model(img, [root])``; moved to ``device``.
        device: torch device for the model and the decoded tensors.
        word_dict: vocabulary used by ``WordEmbedding`` / ``Vec2Word``.
        max_child: child count at which a node is closed with 'end'.

    Returns:
        The decoded tree after ``Vec2Word`` conversion.
    """
    from collections import deque

    tranf = transforms.Compose([WordEmbedding(word_dict), TreeToTensor()])
    end_value = tranf(Tree('end')).value.to(device)
    model.to(device)

    root = tranf(Tree('root'))
    root.value = root.value.to(device)
    out_size = root.value.size()
    # Constant operands of the one-hot conversion, allocated once.
    ones = torch.ones(out_size).to(device)
    zeros = torch.zeros(out_size).to(device)

    queue = deque([root])
    while queue:
        parent = queue[0]
        # Use the `model` argument (previously a global image_caption_model
        # was called and this parameter was ignored).
        # NOTE(review): the model always sees [root], not the path to
        # `parent` — confirm this is intended.
        sub_tree = Tree(model(img, [root]).flatten().detach())
        max_value = torch.max(sub_tree.value)
        sub_tree.value = torch.where(sub_tree.value >= max_value, ones, zeros)
        parent.add_child(sub_tree)

        if torch.equal(end_value, sub_tree.value):
            queue.popleft()
            continue

        # Queue every real prediction, including one that reaches the cap, so
        # it is expanded and closed with its own 'end' like every other node.
        queue.append(sub_tree)
        if len(parent.children) >= max_child:
            parent.add_child(Tree(end_value.clone().detach()))
            queue.popleft()

    root.for_each_value(lambda x: x.cpu().numpy())
    vec2word = Vec2Word(word_dict)
    root = vec2word(root)
    return root