Example #1
    def generate(self, max_depth):
        """Generate a random arithmetic expression tree, using just n-ary plus and binary minus
            Args:
                max_depth: integer > 0
            Returns:
                expression tree where leaves are int.
        """

        if max_depth == 1: # recursion base case
            v = random.sample(self.leaf_values, 1)[0]
            return Tree(node_type_id='num_value', value=self.NumValue(abstract_value=v))

        elif max_depth > 1:
            types = self.node_types
            node_type = random.sample(types, 1)[0]
            o = random.sample(['+', '-'], 1)[0]

            if node_type.id == 'num_value':
                return self.generate(1)

            elif node_type.id == 'op_n' and o == "+":
                n = random.randint(node_type.arity.min_value, node_type.arity.max_value)
                return Tree(node_type.id, children=[self.generate(max_depth - 1) for _ in range(n)],
                            value=OpValue(abstract_value='+'))

            elif node_type.id == 'op_n' and o == "-":
                return Tree(node_type.id, children=[
                    self.generate(max_depth - 1),
                    self.generate(max_depth - 1)],
                            value=OpValue(abstract_value='-'))
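As a point of reference, here is a minimal self-contained sketch of the same recursion, using plain tuples in place of the project's Tree/NumValue/OpValue classes; all names below are illustrative, not from the source:

import random

def generate_expr(max_depth, leaf_values=range(10)):
    # base case: return a random integer leaf
    if max_depth == 1:
        return random.choice(list(leaf_values))
    # otherwise pick an operator and recurse on the children
    op = random.choice(['+', '-'])
    arity = random.randint(2, 4) if op == '+' else 2   # n-ary plus, binary minus
    return (op, [generate_expr(max_depth - 1) for _ in range(arity)])

print(generate_expr(3))   # e.g. ('+', [('-', [4, 7]), ('+', [2, 9, 1])])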
Example #2
            def init(n: Tree, depth=0):
                nonlocal self
                n.meta['dec_batch'] = self
                n.meta['node_numb'] = self.counters['embs']
                self.counters['embs'] += 1
                self.depths['embs'].append(depth)

                if n.value is not None:
                    k = 'vals_' + n.node_type_id
                    n.meta['value_numb'] = self.counters[k]

                    if k not in self.depths.keys():
                        self.depths[k] = []
                    self.depths[k].append(depth)

                    self.counters[k] += 1
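A minimal sketch of how such an init callback is typically driven: a depth-first traversal that numbers every node. The Node class and field names below are simplified stand-ins for the project's Tree and meta structures:

class Node:
    def __init__(self, children=None):
        self.children = children or []
        self.meta = {}

def number_nodes(root):
    counter = 0
    def init(n, depth=0):
        nonlocal counter
        n.meta['node_numb'] = counter   # unique index per node
        n.meta['depth'] = depth
        counter += 1
        for c in n.children:
            init(c, depth + 1)
    init(root)
    return counter                      # total nodes visited

tree = Node([Node(), Node([Node()])])
print(number_nodes(tree))               # 4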
Example #3
def label_tree_with_real_data(xml_tree: ET.Element, final_tree: Tree,
                              tokenizer):

    value = xml_tree.attrib["value"]
    if xml_tree.tag == "node" and value != "ROOT":
        # check whether the tag is among the frequent tags from the dev set;
        # otherwise label it as "other" (last dimension)
        try:
            idx = shared_list.tags_idx.index(value)
        except ValueError:
            idx = len(shared_list.tags_idx) - 1
        final_tree.node_type_id = "POS_tag"
        final_tree.value = TagValue(
            representation=tf.one_hot(idx, len(shared_list.tags_idx)))
        final_tree.children = []
        for child in xml_tree:
            final_tree.children.append(Tree(node_type_id="fake "))

    elif xml_tree.tag == "leaf":
        # check whether the word was found in the dev set; otherwise label it as "other" (last dimension)
        idx = tokenizer.texts_to_sequences([value])
        final_tree.node_type_id = "word"
        final_tree.value = WordValue(representation=tf.one_hot(
            idx[0][0], WordValue.representation_shape))
        for child in xml_tree:
            final_tree.children.append(Tree(node_type_id="fake "))

    #RECURSION
    elif xml_tree.tag == "node" and value == "ROOT":
        label_tree_with_real_data(list(xml_tree)[0], final_tree, tokenizer)

    for child_xml, child_real in zip(xml_tree, final_tree.children):
        label_tree_with_real_data(child_xml, child_real, tokenizer)
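The function assumes the parse tree has already been loaded with xml.etree.ElementTree. A minimal sketch of that input shape and of iterating children directly follows; the sentence and attribute values are made up, only the node/leaf tags and the value attribute mirror the snippet:

import xml.etree.ElementTree as ET

xml_source = """
<node value="ROOT">
  <node value="S">
    <node value="NP"><leaf value="dogs"/></node>
    <node value="VP"><leaf value="bark"/></node>
  </node>
</node>
"""
root = ET.fromstring(xml_source)

def show(elem, depth=0):
    print("  " * depth + elem.tag, elem.attrib["value"])
    for child in elem:     # direct iteration replaces the removed getchildren()
        show(child, depth + 1)

show(root)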
Example #4
def label_tree_with_sentenceTree(dev_data, tes_data, base_path):
    """
    function that given a tree (target for NN one) without sentence "label" it also with it
    :param dev_data:
    :param tes_data:
    :param base_path:
    :return:
    """
    #read xml file first
    for data in dev_data + tes_data:
        name = data['name']
        #after got file name, read tree from xml file
        tree = read_tree_from_file(base_path + name)
        data['sentence_tree'] = tree

    #count occurrences of words
    word_occ = []
    for data in dev_data:
        count_word_tag_occ(data['sentence_tree'], word_occ)
    tokenizer, _ = extraxt_topK_words(word_occ, filters="~")
    TagValue.update_rep_shape(len(shared_list.tags_idx))

    #label tree with real data
    for data in dev_data + tes_data:
        final_tree = Tree(node_type_id="dummy root",
                          children=[],
                          value="dummy")
        label_tree_with_real_data(data['sentence_tree'], final_tree, tokenizer)
        final_tree = final_tree.children[0]
        if final_tree.value.abstract_value == "S":
            data['sentence_tree'] = final_tree
        else:
            idx = shared_list.tags_idx.index("S")
            tag = TagValue(
                representation=tf.one_hot(idx, len(shared_list.tags_idx)))
            S_node = Tree(node_type_id="POS_tag",
                          children=[final_tree],
                          value=tag)
            data['sentence_tree'] = S_node
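A short sketch of the Tokenizer / one-hot machinery this relies on, using standard tf.keras APIs; the corpus, filter string and vocabulary size below are made up for illustration:

import tensorflow as tf

tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=1000, filters="~")
tokenizer.fit_on_texts(["the dog barks", "the cat sleeps"])

idx = tokenizer.texts_to_sequences(["dog"])    # list of index lists, e.g. [[2]]
one_hot = tf.one_hot(idx[0][0], 1000)          # dense 1-of-K representation
print(idx, one_hot.shape)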
Example #5
    def generate(self, max_depth):
        """Generate a random arithmetic expression tree, using just binary plus and minus
            Args:
                max_depth: integer > 0
            Returns:
                expression tree where leaves are int.
        """

        if max_depth == 1:  # recursion base case
            v = random.sample(self.leaf_values, 1)[0]
            return Tree(node_type_id='num_value', value=self.NumValue(abstract_value=v))

        elif max_depth > 1:
            types = self.node_types
            node_type = random.sample(types, 1)[0]

            if node_type.id == 'num_value':
                return self.generate(1)

            else:
                return Tree(node_type.id, children=[
                    self.generate(max_depth - 1),
                    self.generate(max_depth - 1)], value=None)
Example #6

def reconstruct_tree(data, tmp, tree_cnn_type):
    dummy_root = Tree(
        node_type_id="dummy",
        children=[],
        meta={'dummy root'},
        value=ImageValueAlexNet(abstract_value=data) if tree_cnn_type
        == "alexnet" else ImageValueInception(abstract_value=data))
    last_node = dummy_root
    parent_node = None
    travesed_node = []

    for i in range(1, len(tmp)):
        # loop iterating through the tree nodes

        data = get_node_value(tmp[i])

        count = tmp[i - 1].count(")")

        # the token opens a new child of the current node
        if "(" in tmp[i - 1]:
            leaf_n = tmp[i - 1].split("(")
            new_node = Tree(
                node_type_id="",
                children=[],
                meta={'label': leaf_n[1]},
                value=ImageValueAlexNet(abstract_value=data) if tree_cnn_type
                == "alexnet" else ImageValueInception(abstract_value=data))
            parent_node = last_node

            # update the list of all traversed internal nodes (needed later)
            travesed_node.append(parent_node)

            parent_node.children.append(new_node)
            last_node = new_node

        elif count:
            # all children of the current node have been listed

            # drop from the list the nodes that are no longer needed
            travesed_node = travesed_node[:len(travesed_node) - count]
            parent_node = travesed_node[-1]

            leaf_n = tmp[i - 1].split("),")
            new_node = Tree(
                node_type_id="leaf",
                children=[],
                meta={'label': leaf_n[1]},
                value=ImageValueAlexNet(abstract_value=data) if tree_cnn_type
                == "alexnet" else ImageValueInception(abstract_value=data))
            parent_node.children.append(new_node)
            last_node = new_node

        else:
            # the token is a sibling of the current node
            leaf_n = tmp[i - 1].split("],")
            new_node = Tree(
                node_type_id="leaf",
                children=[],
                meta={'label': leaf_n[1]},
                value=ImageValueAlexNet(abstract_value=data) if tree_cnn_type
                == "alexnet" else ImageValueInception(abstract_value=data))
            parent_node.children.append(new_node)
            last_node = new_node

    return dummy_root.children[0]
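For reference, a compact self-contained sketch of the general idea: rebuilding a tree from a parenthesised string with an explicit stack of open parents. This is a simplified illustration and does not use the exact token format that reconstruct_tree expects:

def parse_paren_tree(s):
    # "(a(b)(c(d)))" -> nested dicts: {'label': 'a', 'children': [...]}
    root = None
    stack = []                          # currently open parent nodes
    for ch in s:
        if ch == "(":
            node = {'label': '', 'children': []}
            if stack:
                stack[-1]['children'].append(node)
            else:
                root = node
            stack.append(node)
        elif ch == ")":
            stack.pop()
        else:
            stack[-1]['label'] += ch    # accumulate the label of the open node
    return root

print(parse_paren_tree("(a(b)(c(d)))"))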
Example #7

def main(argv=None):

    #########
    # Checkpoints and Summaries
    #########

    if tf.gfile.Exists(FLAGS.model_dir):
        if FLAGS.overwrite:
            tf.logging.warn("Deleting old log directory at {}".format(
                FLAGS.model_dir))
            tf.gfile.DeleteRecursively(FLAGS.model_dir)
            tf.gfile.MakeDirs(FLAGS.model_dir)

    else:
        tf.gfile.MakeDirs(FLAGS.model_dir)

    summary_writer = tfs.create_file_writer(FLAGS.model_dir, flush_millis=1000)
    summary_writer.set_as_default()
    print("Summaries in " + FLAGS.model_dir)

    with open(os.path.join(FLAGS.model_dir, "flags.json"), 'w') as f:
        json.dump(FLAGS.flag_values_dict(), f)

    #########
    # DATA
    #########

    if FLAGS.fixed_arity:
        tree_gen = BinaryExpressionTreeGen(9)
    else:
        tree_gen = NaryExpressionTreeGen(9, FLAGS.max_arity)

    def get_batch():
        return [
            tree_gen.generate(FLAGS.max_depth) for _ in range(FLAGS.batch_size)
        ]

    #########
    # MODEL
    #########

    activation = getattr(tf.nn, FLAGS.activation)

    encoder = Encoder(
        tree_def=tree_gen.tree_def,
        embedding_size=FLAGS.embedding_size,
        cut_arity=FLAGS.cut_arity,
        max_arity=FLAGS.max_arity,
        variable_arity_strategy=FLAGS.enc_variable_arity_strategy,
        cellsbuilder=EncoderCellsBuilder(
            EncoderCellsBuilder.simple_cell_builder(
                hidden_coef=FLAGS.hidden_cell_coef,
                activation=activation,
                gate=FLAGS.encoder_gate),
            EncoderCellsBuilder.simple_dense_embedder_builder(
                activation=activation)),
        name='encoder')

    decoder = Decoder(
        tree_def=tree_gen.tree_def,
        embedding_size=FLAGS.embedding_size,
        max_node_count=FLAGS.max_node_count,
        max_depth=FLAGS.max_depth,
        max_arity=FLAGS.max_arity,
        cut_arity=FLAGS.cut_arity,
        cellsbuilder=DecoderCellsBuilder(
            distrib_builder=DecoderCellsBuilder.simple_distrib_cell_builder(
                FLAGS.hidden_cell_coef, activation=activation),
            categorical_value_inflater_builder=DecoderCellsBuilder.
            simple_1ofk_value_inflater_builder(FLAGS.hidden_cell_coef,
                                               activation=activation),
            dense_value_inflater_builder=None,  # unused
            node_inflater_builder=DecoderCellsBuilder.
            simple_node_inflater_builder(FLAGS.hidden_cell_coef,
                                         activation=activation,
                                         gate=FLAGS.decoder_gate)),
        variable_arity_strategy=FLAGS.dec_variable_arity_strategy)

    ###########
    # TRAINING
    ###########

    optimizer = tf.train.AdamOptimizer()

    with tfs.always_record_summaries():
        for i in range(FLAGS.max_iter):
            with tfe.GradientTape() as tape:
                xs = get_batch()
                batch_enc = encoder(xs)
                batch_dec = decoder(encodings=batch_enc.get_root_embeddings(),
                                    targets=xs)

                loss_struct, loss_val = batch_dec.reconstruction_loss()
                loss = loss_struct + loss_val

            variables = encoder.variables + decoder.variables
            grad = tape.gradient(loss, variables)

            gnorm = tf.global_norm(grad)
            grad, _ = tf.clip_by_global_norm(grad, 0.02, gnorm)

            tfs.scalar("norms/grad", gnorm)

            optimizer.apply_gradients(
                zip(grad, variables),
                global_step=tf.train.get_or_create_global_step())

            if i % FLAGS.check_every == 0:

                batch_unsuperv = decoder(
                    encodings=batch_enc.get_root_embeddings())

                _, _, v_avg_sup, v_acc_sup = Tree.compare_trees(
                    xs, batch_dec.decoded_trees)
                s_avg, s_acc, v_avg, v_acc = Tree.compare_trees(
                    xs, batch_unsuperv.decoded_trees)

                print("{0}:\t{1:.3f}".format(i, loss))

                tfs.scalar("loss/struct", loss_struct)
                tfs.scalar("loss/val", loss_val)
                tfs.scalar("loss/loss", loss)

                tfs.scalar("overlaps/supervised/value_avg", v_avg_sup)
                tfs.scalar("overlaps/supervised/value_acc", v_acc_sup)

                tfs.scalar("overlaps/unsupervised/struct_avg", s_avg)
                tfs.scalar("overlaps/unsupervised/struct_acc", s_acc)
                tfs.scalar("overlaps/unsupervised/value_avg", v_avg)
                tfs.scalar("overlaps/unsupervised/value_acc", v_acc)
Example #8
        def init(n: Tree, **kwargs):

            nonlocal self
            n.meta['emb_batch'] = self
            n.meta['node_numb'] = self.counters['embs']
            self.counters['embs'] += 1