Example 1

import re
from collections import defaultdict

# UNSEEN_CATEGORIES, delexicalisation, buildGraph and buildGraphWithNE are
# assumed to be defined elsewhere in the surrounding module

def create_source_target(b,
                         options,
                         dataset,
                         delex=True,
                         relex=False,
                         doCategory=[],
                         negraph=False,
                         lowercased=True):
    """
    Write target and source files, and reference files for BLEU.
    :param b: instance of Benchmark class
    :param options: string "delex" or "notdelex" to label files
    :param dataset: dataset part: train, dev, test
    :param delex: boolean; perform delexicalisation or not
    TODO:update parapms
    :return: if delex True, return list of replacement dictionaries for each example
    """
    source_out = []
    source_nodes_out = []
    source_edges_out_labels = []
    source_edges_out_node1 = []
    source_edges_out_node2 = []
    target_out = []
    rplc_list = []  # store the dict of replacements for each example
    for entr in b.entries:
        tripleset = entr.modifiedtripleset
        lexics = entr.lexs
        category = entr.category
        if doCategory and category not in doCategory:
            continue
        for lex in lexics:
            triples = ''
            properties_objects = {}
            tripleSep = ""
            for triple in tripleset.triples:
                triples += tripleSep + triple.s + '|' + triple.p + '|' + triple.o + ' '
                tripleSep = "<TSP>"

                properties_objects[triple.p] = triple.o
            triples = triples.replace('_', ' ').replace('"', '')
            # separate punct signs from text
            out_src = ' '.join(re.split(r'(\W)', triples))
            out_trg = ' '.join(re.split(r'(\W)', lex.lex))
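            # e.g. the (hypothetical) triple John_Doe|birthPlace|"London"
            # serialises to 'John Doe | birthPlace | London' once the
            # whitespace is normalised below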
            if delex:
                out_src, out_trg, rplc_dict = delexicalisation(
                    out_src, out_trg, category, properties_objects)
                rplc_list.append(rplc_dict)

            if negraph:
                source_nodes, source_edges = buildGraphWithNE(out_src)
            else:
                source_nodes, source_edges = buildGraph(out_src)
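            # both graph builders return the node sequence and a 3-tuple of
            # parallel edge lists: (labels, node1 ids, node2 ids)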
            source_nodes_out.append(source_nodes)
            source_edges_out_labels.append(source_edges[0])
            source_edges_out_node1.append(source_edges[1])
            source_edges_out_node2.append(source_edges[2])
            source_out.append(' '.join(out_src.split()))
            target_out.append(' '.join(out_trg.split()))

    #TODO: we could add a '-src-features.txt' if we want to attach features to nodes
    if not relex:
        if doCategory == UNSEEN_CATEGORIES:
            options = options + '-unseen'
        # GCN input files do not need to be regenerated when doing relexicalisation
        with open(dataset + '-webnlg-' + options + '-src-nodes.txt',
                  'w+',
                  encoding='utf8') as f:
            f.write('\n'.join(source_nodes_out).lower() if (
                lowercased and not delex) else '\n'.join(source_nodes_out))
        with open(dataset + '-webnlg-' + options + '-src-labels.txt',
                  'w+',
                  encoding='utf8') as f:
            f.write('\n'.join(source_edges_out_labels))
        with open(dataset + '-webnlg-' + options + '-src-node1.txt',
                  'w+',
                  encoding='utf8') as f:
            f.write('\n'.join(source_edges_out_node1))
        with open(dataset + '-webnlg-' + options + '-src-node2.txt',
                  'w+',
                  encoding='utf8') as f:
            f.write('\n'.join(source_edges_out_node2))
        with open(dataset + '-webnlg-' + options + '-tgt.txt',
                  'w+',
                  encoding='utf8') as f:
            f.write('\n'.join(target_out).lower() if (
                lowercased and not delex) else '\n'.join(target_out))

    with open(dataset + '-webnlg-' + options + '.triple',
              'w+',
              encoding='utf8') as f:
        f.write('\n'.join(source_out))
    with open(dataset + '-webnlg-' + options + '.lex', 'w+',
              encoding='utf8') as f:
        f.write('\n'.join(target_out).lower() if (
            lowercased and not delex) else '\n'.join(target_out))

    # create separate files with references for multi-bleu.pl for dev set
    scr_refs = defaultdict(list)
    if (dataset == 'dev' or dataset.startswith('test')) and not delex:
        # TODO: taking only the nodes part may be enough for the BLEU scripts;
        # check whether the whole graph is really needed in the src part
        for src, trg in zip(source_out, target_out):
            scr_refs[src].append(trg)
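        # scr_refs now maps each source triple string to all of its references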
        # the reference list with the most elements
        max_refs = sorted(scr_refs.values(), key=len)[-1]
        sorted_items = sorted(scr_refs.items())
        keys = [key for key, value in sorted_items]
        values = [value for key, value in sorted_items]
        # write the source file not delex
        with open(dataset + "-" + options + '-source.triple',
                  'w+',
                  encoding='utf8') as f:
            f.write('\n'.join(keys))
        # write references files
        for j in range(len(max_refs)):
            with open(dataset + "-" + options + '-reference' + str(j) + '.lex',
                      'w+',
                      encoding='utf8') as f:
                out = ''
                for ref in values:
                    try:
                        out += ref[j].lower() + '\n' if (
                            lowercased and not delex) else ref[j] + '\n'
                    except IndexError:
                        # this source has fewer than j + 1 references
                        out += '\n'
                f.write(out)

        # write reference files for the E2E evaluation metrics
        with open(dataset + "-" + options + '-conc.txt', 'w+',
                  encoding='utf8') as f:
            for ref in values:
                for j in range(len(ref)):
                    f.write(ref[j].lower() + '\n' if (
                        lowercased and not delex) else ref[j] + '\n')
                f.write("\n")

    return rplc_list
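
A minimal usage sketch (hedged: it assumes `b` is a Benchmark instance already
populated by the module's corpus reader, which is not shown here):

rplc_dicts = create_source_target(b, 'delex', 'train', delex=True)
# each entry of rplc_dicts maps the delexicalised placeholders of one example
# back to the original strings, for use during relexicalisation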
Example 2

import re
from string import printable

# delexicalisation is assumed to be defined elsewhere in the surrounding module

def create_test_data(b,
                     options,
                     dataset='test',
                     delex=True,
                     relex=False,
                     doCategory=[],
                     negraph=False,
                     lowercased=True):
    nodes = []  # [batch, node_num,]
    in_neigh_indices = []  # [batch, node_num, neighbor_num,]
    in_neigh_edges = []
    out_neigh_indices = []  # [batch, node_num, neighbor_num,]
    out_neigh_edges = []
    sentences = []  # [batch, sent_length,]
    ids = []
    types = []  # graph type label per example (avoid shadowing builtin type)
    max_in_neigh = 0
    max_out_neigh = 0
    max_node = 0
    max_sent = 0
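    # the max_* counters track the largest graph and sentence sizes seen,
    # presumably for padding when batching downstream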
    rplc_list = []  # store the dict of replacements for each example
    for entr in b.entries:
        tripleset = entr.modifiedtripleset
        lexics = entr.lexs
        entry_id = entr.id  # avoid shadowing the builtin id
        category = entr.category
        if doCategory and category not in doCategory:
            continue
        triples = ''
        properties_objects = {}
        tripleSep = ""
        for triple in tripleset.triples:
            triples += tripleSep + triple.s + '|' + triple.p + '|' + triple.o + ' '
            tripleSep = "<tsp>"
            properties_objects[triple.p] = triple.o
        triples = triples.replace('_', ' ').replace('"', '')
        # separate punct signs from text
        out_src = ' '.join(re.split(r'(\W)', triples)).lower()
        # keep only printable characters; in Python 3 filter() returns an
        # iterator, so join the result back into a string
        out_src = ''.join(filter(lambda x: x in printable, out_src))
        out_trg = out_src
        if delex:
            out_src, out_trg, rplc_dict = delexicalisation(
                out_src, out_trg, category, properties_objects)
            rplc_list.append(rplc_dict)
        # special arcs for multi-word named entities are added with the -e argument
        out_trg = out_trg.strip().split()
        # build graph
        rdf_node = []
        rdf_edge = []
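        # NB: the '(\W)' split above spaced out punctuation, so the '<tsp>'
        # separator now appears as '< tsp >' inside out_src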
        for t in out_src.split("< tsp >"):
            t = t.strip().split(" | ")
            subjectList = t[0].strip().split()
            for index, item in enumerate(subjectList):
                if item not in rdf_node:
                    rdf_node.append(item)
                if index != 0 and (rdf_node.index(subjectList[index - 1]),
                                   rdf_node.index(subjectList[index]),
                                   "NE") not in rdf_edge:
                    rdf_edge.append((rdf_node.index(subjectList[index - 1]),
                                     rdf_node.index(subjectList[index]), "NE"))
            subject = subjectList[-1]

            objectList = t[2].strip().split()
            for index, item in enumerate(objectList):
                if item not in rdf_node:
                    rdf_node.append(item)
                if index != 0 and (rdf_node.index(objectList[index - 1]),
                                   rdf_node.index(objectList[index]),
                                   "NE") not in rdf_edge:
                    rdf_edge.append((rdf_node.index(objectList[index - 1]),
                                     rdf_node.index(objectList[index]), "NE"))
            obj = objectList[0]  # avoid shadowing the builtin object

            relationList = t[1].strip().split()
            for index, item in enumerate(relationList):
                if item not in rdf_node:
                    rdf_node.append(item)
                if index != 0 and (rdf_node.index(relationList[index - 1]),
                                   rdf_node.index(relationList[index]),
                                   "NE") not in rdf_edge:
                    rdf_edge.append(
                        (rdf_node.index(relationList[index - 1]),
                         rdf_node.index(relationList[index]), "NE"))
            # connect the last subject token to the first relation token (A0)
            # and the first object token to the last relation token (A1)
            rdf_edge.append((rdf_node.index(subject),
                             rdf_node.index(relationList[0]), "A0"))
            rdf_edge.append((rdf_node.index(obj),
                             rdf_node.index(relationList[-1]), "A1"))
        nodes.append(rdf_node)

        # 2. & 3. build the in-/out-neighbour adjacency lists; every node
        # starts with a self-loop edge labelled ':self'
        in_indices = [[i] for i in range(len(rdf_node))]
        in_edges = [[':self'] for _ in rdf_node]
        out_indices = [[i] for i in range(len(rdf_node))]
        out_edges = [[':self'] for _ in rdf_node]
        for (i, j, lb) in rdf_edge:
            in_indices[j].append(i)
            in_edges[j].append(lb)
            out_indices[i].append(j)
            out_edges[i].append(lb)
        in_neigh_indices.append(in_indices)
        in_neigh_edges.append(in_edges)
        out_neigh_indices.append(out_indices)
        out_neigh_edges.append(out_edges)
        # 4. store the tokenised target sentence and the example id
        sentences.append(out_trg)
        ids.append(entry_id)
        # update lengths
        max_in_neigh = max(max_in_neigh, max(len(x) for x in in_indices))
        max_out_neigh = max(max_out_neigh, max(len(x) for x in out_indices))
        max_node = max(max_node, len(rdf_node))
        max_sent = max(max_sent, len(out_trg))
        types.append('rdf')
    return zip(nodes, in_neigh_indices, in_neigh_edges, out_neigh_indices, out_neigh_edges, sentences, ids, types), \
               max_node, max_in_neigh, max_out_neigh, max_sent
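
A usage sketch for this function as well (same hedged assumption of a
populated Benchmark instance `b`):

instances, max_node, max_in_neigh, max_out_neigh, max_sent = \
    create_test_data(b, 'notdelex', dataset='test', delex=False)
for nodes_i, in_idx, in_lab, out_idx, out_lab, sent, ex_id, typ in instances:
    pass  # each item is one graph instance ready for GCN batching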