Code example #1
# Imports assumed by this snippet: "dataloader" is the project-local loader
# (see Code example #11 for read_rstdt) and "treetk" is the tree-processing toolkit.
import dataloader
import treetk

def main():
    dataset = dataloader.read_rstdt(split="train", relation_level="coarse-grained", with_root=True)

    relation_mapper = treetk.rstdt.RelationMapper()
    i = 0
    for data in dataset:
        edu_ids = data.edu_ids
        edus = data.edus
        edus_postag = data.edus_postag
        edus_head = data.edus_head
        sbnds = data.sbnds
        pbnds = data.pbnds
        nary_sexp = data.nary_sexp
        bin_sexp = data.bin_sexp
        arcs = data.arcs

        print("Data instance #%d" % i)
        print("\t Paragraph #0")
        print("\t\t Sentence #0")
        print("\t\t\t EDU #0")
        print("\t\t\t\t EDU ID:", edu_ids[0])
        print("\t\t\t\t EDU:", edus[0])
        print("\t\t\t\t EDU (POS):", edus_postag[0])
        print("\t\t\t\t EDU (head):", edus_head[0])
        p_i = 1
        s_i = 1
        e_i = 1
        for p_begin, p_end in pbnds:
            print("\t Paragraph #%d" % p_i)
            for s_begin, s_end in sbnds[p_begin:p_end+1]:
                print("\t\t Sentence #%d" % s_i)
                for edu_id, edu, edu_postag, edu_head in zip(edu_ids[1+s_begin:1+s_end+1],
                                                             edus[1+s_begin:1+s_end+1],
                                                             edus_postag[1+s_begin:1+s_end+1],
                                                             edus_head[1+s_begin:1+s_end+1]):
                    print("\t\t\t EDU #%d" % e_i)
                    print("\t\t\t\t EDU ID:", edu_id)
                    print("\t\t\t\t EDU:", edu)
                    print("\t\t\t\t EDU (POS):", edu_postag)
                    print("\t\t\t\t EDU (head):", edu_head)
                    e_i += 1
                s_i += 1
            p_i += 1
        nary_tree = treetk.rstdt.postprocess(treetk.sexp2tree(nary_sexp, with_nonterminal_labels=True, with_terminal_labels=False))
        nary_tree = treetk.rstdt.map_relations(nary_tree, mode="c2a")
        bin_tree = treetk.rstdt.postprocess(treetk.sexp2tree(bin_sexp, with_nonterminal_labels=True, with_terminal_labels=False))
        bin_tree = treetk.rstdt.map_relations(bin_tree, mode="c2a")
        arcs = [(h,d,relation_mapper.c2a(l)) for h,d,l in arcs]
        dtree = treetk.arcs2dtree(arcs)
        treetk.pretty_print(nary_tree)
        treetk.pretty_print(bin_tree)
        treetk.pretty_print_dtree(dtree)
        i += 1
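The nested loops above assume that sbnds and pbnds are inclusive (begin, end) index pairs (sentence boundaries over EDUs, paragraph boundaries over sentences) and that with_root=True shifts the real EDUs by one position. A toy sketch of that indexing, using made-up boundary values:

# Hypothetical toy values that only illustrate the indexing convention used above:
# five EDUs plus a ROOT at position 0, grouped into two sentences and one paragraph.
edu_ids = [0, 1, 2, 3, 4, 5]   # position 0 is the ROOT EDU
sbnds = [(0, 2), (3, 4)]       # inclusive EDU spans per sentence (ROOT excluded)
pbnds = [(0, 1)]               # inclusive sentence spans per paragraph

for p_begin, p_end in pbnds:
    for s_begin, s_end in sbnds[p_begin:p_end + 1]:
        print(edu_ids[1 + s_begin:1 + s_end + 1])  # prints [1, 2, 3], then [4, 5]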
Code example #2
def parse(model, decoder, dataset, path_pred):
    """
    :type model: SpanBasedModel
    :type decoder: IncrementalCKYDecoder
    :type dataset: numpy.ndarray
    :type path_pred: str
    :rtype: None
    """
    with open(path_pred, "w") as f:

        for data in pyprind.prog_bar(dataset):
            edu_ids = data.edu_ids
            edus = data.edus
            edus_postag = data.edus_postag
            edus_head = data.edus_head
            sbnds = data.sbnds
            pbnds = data.pbnds

            # Feature extraction
            edu_vectors = model.forward_edus(edus, edus_postag,
                                             edus_head)  # (n_edus, bilstm_dim)
            padded_edu_vectors = model.pad_edu_vectors(
                edu_vectors)  # (n_edus+2, bilstm_dim)
            mask_bwd, mask_fwd = model.make_masks(
            )  # (1, bilstm_dim), (1, bilstm_dim)

            # Parsing (bracketing)
            span_scores = precompute_all_span_scores(
                model=model,
                edus=edus,
                edus_postag=edus_postag,
                sbnds=sbnds,
                pbnds=pbnds,
                padded_edu_vectors=padded_edu_vectors,
                mask_bwd=mask_bwd,
                mask_fwd=mask_fwd)
            unlabeled_sexp = decoder.decode(span_scores=span_scores,
                                            inputs=edu_ids,
                                            sbnds=sbnds,
                                            pbnds=pbnds,
                                            use_sbnds=True,
                                            use_pbnds=True)  # list of str
            unlabeled_tree = treetk.sexp2tree(unlabeled_sexp,
                                              with_nonterminal_labels=False,
                                              with_terminal_labels=False)
            unlabeled_tree.calc_spans()
            unlabeled_spans = treetk.aggregate_spans(
                unlabeled_tree, include_terminal=False,
                order="pre-order")  # list of (int, int)

            # Parsing (assigning majority labels to the unlabeled tree)
            span2label = {(b, e): "<ELABORATION,N/S>"
                          for (b, e) in unlabeled_spans}
            labeled_tree = treetk.assign_labels(unlabeled_tree,
                                                span2label,
                                                with_terminal_labels=False)
            labeled_sexp = treetk.tree2sexp(labeled_tree)

            f.write("%s\n" % " ".join(labeled_sexp))
Code example #3
def parse(model, decoder, databatch, path_pred):
    """
    :type model: Model
    :type decoder: IncrementalCKYDecoder
    :type databatch: DataBatch
    :type path_pred: str
    :rtype: None
    """
    with open(path_pred, "w") as f:
        prog_bar = pyprind.ProgBar(len(databatch))

        for edu_ids, edus, edus_postag, edus_head, sbnds, pbnds \
                in zip(databatch.batch_edu_ids,
                       databatch.batch_edus,
                       databatch.batch_edus_postag,
                       databatch.batch_edus_head,
                       databatch.batch_sbnds,
                       databatch.batch_pbnds):

            # Feature extraction
            edu_vectors = model.forward_edus(edus, edus_postag,
                                             edus_head)  # (n_edus, bilstm_dim)
            padded_edu_vectors = model.pad_edu_vectors(
                edu_vectors)  # (n_edus+2, bilstm_dim)
            mask_bwd, mask_fwd = model.make_masks(
            )  # (1, bilstm_dim), (1, bilstm_dim)

            # Parsing (constituency)
            unlabeled_sexp = decoder.decode(
                model=model,
                sexps=edu_ids,
                edus=edus,
                edus_postag=edus_postag,
                sbnds=sbnds,
                pbnds=pbnds,
                padded_edu_vectors=padded_edu_vectors,
                mask_bwd=mask_bwd,
                mask_fwd=mask_fwd,
                use_sbnds=True,
                use_pbnds=True)  # list of str
            unlabeled_tree = treetk.sexp2tree(unlabeled_sexp,
                                              with_nonterminal_labels=False,
                                              with_terminal_labels=False)
            unlabeled_tree.calc_spans()
            unlabeled_spans = treetk.aggregate_spans(
                unlabeled_tree, include_terminal=False,
                order="pre-order")  # list of (int, int)

            # Assigning majority labels to the unlabeled tree
            span2label = {(b, e): "<ELABORATION,N/S>"
                          for (b, e) in unlabeled_spans}
            labeled_tree = treetk.assign_labels(unlabeled_tree,
                                                span2label,
                                                with_terminal_labels=False)
            labeled_sexp = treetk.tree2sexp(labeled_tree)

            f.write("%s\n" % " ".join(labeled_sexp))

            prog_bar.update()
Code example #4
# Imports assumed by this snippet: "utils" is the project-local utility module.
import treetk
import utils

def read_trees(path):
    """
    :type path: str
    :rtype: list of NonTerminal
    """
    sexps = utils.read_lines(path, process=lambda line: line.split())
    trees = [treetk.rstdt.postprocess(treetk.sexp2tree(sexp, with_nonterminal_labels=True, with_terminal_labels=False)) for sexp in sexps]
    return trees
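read_trees expects one labeled S-expression per line, so prediction and gold files can be loaded in parallel for evaluation. A usage sketch with hypothetical paths:

# Hypothetical file paths; each file holds one bracketed tree per line.
pred_trees = read_trees("experiment/prediction.labeled.bin.ctree")
gold_trees = read_trees("experiment/gold.labeled.bin.ctree")
assert len(pred_trees) == len(gold_trees)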
Code example #5
# Imports assumed by this snippet: "utils" is the project-local utility/config module.
import os

import pyprind
import treetk
import utils

def main():
    config = utils.Config()

    utils.mkdir(os.path.join(config.getpath("data"), "rstdt-vocab"))

    filenames = []
    for filename in os.listdir(os.path.join(config.getpath("data"), "rstdt", "wsj", "train")):
        filenames.append(os.path.join(config.getpath("data"), "rstdt", "wsj", "train", filename))
    for filename in os.listdir(os.path.join(config.getpath("data"), "rstdt", "wsj", "test")):
        filenames.append(os.path.join(config.getpath("data"), "rstdt", "wsj", "test", filename))
    filenames = [n for n in filenames if n.endswith(".labeled.bin.ctree")]
    filenames.sort()

    relation_mapper = treetk.rstdt.RelationMapper()

    frelations = []
    crelations = []
    nuclearities = []

    for filename in pyprind.prog_bar(filenames):
        sexp = utils.read_lines(filename, process=lambda line: line)
        sexp = treetk.preprocess(sexp)
        tree = treetk.rstdt.postprocess(treetk.sexp2tree(sexp, with_nonterminal_labels=True, with_terminal_labels=False))

        nodes = treetk.traverse(tree, order="pre-order", include_terminal=False, acc=None)

        part_frelations = []
        part_crelations = []
        part_nuclearities = []
        for node in nodes:
            relations_ = node.relation_label.split("/")
            part_frelations.extend(relations_)
            part_crelations.extend([relation_mapper.f2c(r) for r in relations_])
            part_nuclearities.append(node.nuclearity_label)

        part_frelations.append("<root>")
        part_crelations.append("<root>")

        frelations.append(part_frelations)
        crelations.append(part_crelations)
        nuclearities.append(part_nuclearities)

    fcounter = utils.get_word_counter(lines=frelations)
    ccounter = utils.get_word_counter(lines=crelations)
    ncounter = utils.get_word_counter(lines=nuclearities)

    frelations = fcounter.most_common() # list of (str, int)
    crelations = ccounter.most_common() # list of (str, int)
    nuclearities = ncounter.most_common() # list of (str, int)

    utils.write_vocab(os.path.join(config.getpath("data"), "rstdt-vocab", "relations.fine.vocab.txt"),
                      frelations)
    utils.write_vocab(os.path.join(config.getpath("data"), "rstdt-vocab", "relations.coarse.vocab.txt"),
                      crelations)
    utils.write_vocab(os.path.join(config.getpath("data"), "rstdt-vocab", "nuclearities.vocab.txt"),
                      nuclearities)
Code example #6
def process(path):

    # NOTE: We use n-ary ctrees (i.e., *.labeled.nary.ctree) to generate dtrees.
    #       Morey et al. (2018) demonstrate that scores evaluated on these dtrees are
    #       superficially lower than those on right-heavy binarized trees (i.e., *.labeled.bin.ctree).

    filenames = os.listdir(path)
    filenames = [n for n in filenames if n.endswith(".labeled.nary.ctree")]
    filenames.sort()

    def func_label_rule(node, i, j):
        relations = node.relation_label.split("/")
        if len(relations) == 1:
            return relations[0]  # Left-most node is head.
        else:
            if i > j:
                return relations[j]
            else:
                return relations[j - 1]

    for filename in pyprind.prog_bar(filenames):
        sexp = utils.read_lines(os.path.join(path, filename),
                                process=lambda line: line.split())
        assert len(sexp) == 1
        sexp = sexp[0]

        # Constituency
        ctree = treetk.rstdt.postprocess(
            treetk.sexp2tree(sexp,
                             with_nonterminal_labels=True,
                             with_terminal_labels=False))

        # Dependency
        # Assign heads
        ctree = treetk.rstdt.assign_heads(ctree)
        # Conversion
        dtree = treetk.ctree2dtree(ctree, func_label_rule=func_label_rule)
        arcs = dtree.tolist(labeled=True)

        # Write
        with open(
                os.path.join(path,
                             filename.replace(".labeled.nary.ctree", ".arcs")),
                "w") as f:
            f.write("%s\n" %
                    " ".join(["%d-%d-%s" % (h, d, l) for h, d, l in arcs]))
Code example #7
def parse(sampler, databatch, path_pred):
    """
    :type sampler: TreeSampler
    :type databatch: DataBatch
    :type path_pred: str
    :rtype: None
    """
    with open(path_pred, "w") as f:
        prog_bar = pyprind.ProgBar(len(databatch))

        for edu_ids, edus, edus_postag, edus_head, sbnds, pbnds \
                in zip(databatch.batch_edu_ids,
                       databatch.batch_edus,
                       databatch.batch_edus_postag,
                       databatch.batch_edus_head,
                       databatch.batch_sbnds,
                       databatch.batch_pbnds):

            # Tree sampling (constituency)
            unlabeled_sexp = sampler.sample(sexps=edu_ids,
                                            edus=edus,
                                            edus_head=edus_head,
                                            sbnds=sbnds,
                                            pbnds=pbnds)
            unlabeled_tree = treetk.sexp2tree(unlabeled_sexp,
                                              with_nonterminal_labels=False,
                                              with_terminal_labels=False)
            unlabeled_tree.calc_spans()
            unlabeled_spans = treetk.aggregate_spans(
                unlabeled_tree, include_terminal=False,
                order="pre-order")  # list of (int, int)

            # Assigning majority labels to the unlabeled tree
            span2label = {(b, e): "<ELABORATION,N/S>"
                          for (b, e) in unlabeled_spans}
            labeled_tree = treetk.assign_labels(unlabeled_tree,
                                                span2label,
                                                with_terminal_labels=False)
            labeled_sexp = treetk.tree2sexp(labeled_tree)

            f.write("%s\n" % " ".join(labeled_sexp))

            prog_bar.update()
Code example #8
def parse(sampler, dataset, path_pred):
    """
    :type sampler: TreeSampler
    :type dataset: numpy.ndarray
    :type path_pred: str
    :rtype: None
    """
    with open(path_pred, "w") as f:

        for data in pyprind.prog_bar(dataset):
            edu_ids = data.edu_ids
            edus = data.edus
            edus_head = data.edus_head
            sbnds = data.sbnds
            pbnds = data.pbnds

            # Tree sampling (constituency)
            unlabeled_sexp = sampler.sample(sexps=edu_ids,
                                            edus=edus,
                                            edus_head=edus_head,
                                            sbnds=sbnds,
                                            pbnds=pbnds)
            unlabeled_tree = treetk.sexp2tree(unlabeled_sexp,
                                              with_nonterminal_labels=False,
                                              with_terminal_labels=False)
            unlabeled_tree.calc_spans()
            unlabeled_spans = treetk.aggregate_spans(
                unlabeled_tree, include_terminal=False,
                order="pre-order")  # list of (int, int)

            # Assigning majority labels to the unlabeled tree
            span2label = {(b, e): "<ELABORATION,N/S>"
                          for (b, e) in unlabeled_spans}
            labeled_tree = treetk.assign_labels(unlabeled_tree,
                                                span2label,
                                                with_terminal_labels=False)
            labeled_sexp = treetk.tree2sexp(labeled_tree)

            f.write("%s\n" % " ".join(labeled_sexp))
Code example #9
def train(model, decoder, sampler, max_epoch, n_init_epochs, negative_size,
          batch_size, weight_decay, gradient_clipping, optimizer_name,
          train_dataset, dev_dataset, path_train, path_valid, path_snapshot,
          path_pred, path_gold):
    """
    :type model: SpanBasedModel
    :type decoder: IncrementalCKYDecoder
    :type sampler: TreeSampler
    :type max_epoch: int
    :type n_init_epochs: int
    :type negative_size: int
    :type batch_size: int
    :type weight_decay: float
    :type gradient_clipping: float
    :type optimizer_name: str
    :type train_dataset: numpy.ndarray
    :type dev_dataset: numpy.ndarray
    :type path_train: str
    :type path_valid: str
    :type path_snapshot: str
    :type path_pred: str
    :type path_gold: str
    :rtype: None
    """
    writer_train = jsonlines.Writer(open(path_train, "w"), flush=True)
    if dev_dataset is not None:
        writer_valid = jsonlines.Writer(open(path_valid, "w"), flush=True)

    boundary_flags = [(True, False)]
    assert negative_size >= len(boundary_flags)
    negative_tree_sampler = treesamplers.NegativeTreeSampler()

    # Optimizer preparation
    if optimizer_name == "adam":
        opt = optimizers.Adam()
    else:
        raise ValueError("Invalid optimizer_name=%s" % optimizer_name)

    opt.setup(model)

    if weight_decay > 0.0:
        opt.add_hook(chainer.optimizer.WeightDecay(weight_decay))
    if gradient_clipping:
        opt.add_hook(chainer.optimizer.GradientClipping(gradient_clipping))

    n_train = len(train_dataset)
    it = 0
    bestscore_holder = utils.BestScoreHolder(scale=100.0)
    bestscore_holder.init()

    if dev_dataset is not None:
        # Initial validation
        with chainer.using_config("train", False), chainer.no_backprop_mode():
            parse(model=model,
                  decoder=decoder,
                  dataset=dev_dataset,
                  path_pred=path_pred)
            scores = metrics.rst_parseval(pred_path=path_pred,
                                          gold_path=path_gold)
            old_scores = metrics.old_rst_parseval(pred_path=path_pred,
                                                  gold_path=path_gold)
            out = {
                "epoch": 0,
                "Morey2018": {
                    "Unlabeled Precision": scores["S"]["Precision"] * 100.0,
                    "Precision_info": scores["S"]["Precision_info"],
                    "Unlabeled Recall": scores["S"]["Recall"] * 100.0,
                    "Recall_info": scores["S"]["Recall_info"],
                    "Micro F1": scores["S"]["Micro F1"] * 100.0
                },
                "Marcu2000": {
                    "Unlabeled Precision":
                    old_scores["S"]["Precision"] * 100.0,
                    "Precision_info": old_scores["S"]["Precision_info"],
                    "Unlabeled Recall": old_scores["S"]["Recall"] * 100.0,
                    "Recall_info": old_scores["S"]["Recall_info"],
                    "Micro F1": old_scores["S"]["Micro F1"] * 100.0
                }
            }
            writer_valid.write(out)
            utils.writelog(utils.pretty_format_dict(out))
        # Saving
        bestscore_holder.compare_scores(scores["S"]["Micro F1"], step=0)
        serializers.save_npz(path_snapshot, model)
        utils.writelog("Saved the model to %s" % path_snapshot)
    else:
        # Saving
        serializers.save_npz(path_snapshot, model)
        utils.writelog("Saved the model to %s" % path_snapshot)

    for epoch in range(1, max_epoch + 1):

        perm = np.random.permutation(n_train)

        ########## E-Step (BEGIN) ##########
        utils.writelog("E step ===>")

        prog_bar = pyprind.ProgBar(n_train)

        for inst_i in range(0, n_train, batch_size):

            ### Mini batch

            for data in train_dataset[inst_i:inst_i + batch_size]:

                ### One data instance

                edu_ids = data.edu_ids
                edus = data.edus
                edus_postag = data.edus_postag
                edus_head = data.edus_head
                sbnds = data.sbnds
                pbnds = data.pbnds

                with chainer.using_config("train",
                                          False), chainer.no_backprop_mode():

                    # Feature extraction
                    edu_vectors = model.forward_edus(
                        edus, edus_postag, edus_head)  # (n_edus, bilstm_dim)
                    padded_edu_vectors = model.pad_edu_vectors(
                        edu_vectors)  # (n_edus+2, bilstm_dim)
                    mask_bwd, mask_fwd = model.make_masks(
                    )  # (1, bilstm_dim), (1, bilstm_dim)

                    # Positive tree
                    if epoch <= n_init_epochs:
                        pos_sexp = sampler.sample(inputs=edu_ids,
                                                  edus=edus,
                                                  edus_head=edus_head,
                                                  sbnds=sbnds,
                                                  pbnds=pbnds)
                    else:
                        span_scores = precompute_all_span_scores(
                            model=model,
                            edus=edus,
                            edus_postag=edus_postag,
                            sbnds=sbnds,
                            pbnds=pbnds,
                            padded_edu_vectors=padded_edu_vectors,
                            mask_bwd=mask_bwd,
                            mask_fwd=mask_fwd)
                        pos_sexp = decoder.decode(span_scores=span_scores,
                                                  inputs=edu_ids,
                                                  sbnds=sbnds,
                                                  pbnds=pbnds,
                                                  use_sbnds=True,
                                                  use_pbnds=True)
                    pos_tree = treetk.sexp2tree(pos_sexp,
                                                with_nonterminal_labels=False,
                                                with_terminal_labels=False)
                    pos_tree.calc_spans()
                    pos_spans = treetk.aggregate_spans(
                        pos_tree, include_terminal=False,
                        order="post-order")  # list of (int, int)
                    data.pos_spans = pos_spans  #NOTE
                    prog_bar.update()
        ########## E-Step (END) ##########

        ########## M-Step (BEGIN) ##########
        utils.writelog("M step ===>")

        for inst_i in range(0, n_train, batch_size):

            # Processing one mini-batch

            # Init
            loss_bracketing, acc_bracketing = 0.0, 0.0
            actual_batchsize = 0

            for data in train_dataset[perm[inst_i:inst_i + batch_size]]:

                # Processing one instance

                edu_ids = data.edu_ids
                edus = data.edus
                edus_postag = data.edus_postag
                edus_head = data.edus_head
                sbnds = data.sbnds
                pbnds = data.pbnds
                pos_spans = data.pos_spans  # NOTE

                # Feature extraction
                edu_vectors = model.forward_edus(
                    edus, edus_postag, edus_head)  # (n_edus, bilstm_dim)
                padded_edu_vectors = model.pad_edu_vectors(
                    edu_vectors)  # (n_edus+2, bilstm_dim)
                mask_bwd, mask_fwd = model.make_masks(
                )  # (1, bilstm_dim), (1, bilstm_dim)

                # Negative trees
                pos_neg_spans = []
                margins = []
                pos_neg_spans.append(pos_spans)
                with chainer.using_config("train",
                                          False), chainer.no_backprop_mode():
                    for use_sbnds, use_pbnds in boundary_flags:
                        span_scores = precompute_all_span_scores(
                            model=model,
                            edus=edus,
                            edus_postag=edus_postag,
                            sbnds=sbnds,
                            pbnds=pbnds,
                            padded_edu_vectors=padded_edu_vectors,
                            mask_bwd=mask_bwd,
                            mask_fwd=mask_fwd)
                        neg_bin_sexp = decoder.decode(
                            span_scores=span_scores,
                            inputs=edu_ids,
                            sbnds=sbnds,
                            pbnds=pbnds,
                            use_sbnds=use_sbnds,
                            use_pbnds=use_pbnds,
                            gold_spans=pos_spans)  # list of str
                        neg_tree = treetk.sexp2tree(
                            neg_bin_sexp,
                            with_nonterminal_labels=False,
                            with_terminal_labels=False)
                        neg_tree.calc_spans()
                        neg_spans = treetk.aggregate_spans(
                            neg_tree,
                            include_terminal=False,
                            order="pre-order")  # list of (int, int)
                        margin = compute_tree_distance(pos_spans,
                                                       neg_spans,
                                                       coef=1.0)
                        pos_neg_spans.append(neg_spans)
                        margins.append(margin)
                for _ in range(negative_size - len(boundary_flags)):
                    neg_bin_sexp = negative_tree_sampler.sample(inputs=edu_ids,
                                                                sbnds=sbnds,
                                                                pbnds=pbnds)
                    neg_tree = treetk.sexp2tree(neg_bin_sexp,
                                                with_nonterminal_labels=False,
                                                with_terminal_labels=False)
                    neg_tree.calc_spans()
                    neg_spans = treetk.aggregate_spans(
                        neg_tree, include_terminal=False,
                        order="pre-order")  # list of (int, int)
                    margin = compute_tree_distance(pos_spans,
                                                   neg_spans,
                                                   coef=1.0)
                    pos_neg_spans.append(neg_spans)
                    margins.append(margin)

                # Scoring
                pred_scores = model.forward_spans_for_bracketing(
                    edus=edus,
                    edus_postag=edus_postag,
                    sbnds=sbnds,
                    pbnds=pbnds,
                    padded_edu_vectors=padded_edu_vectors,
                    mask_bwd=mask_bwd,
                    mask_fwd=mask_fwd,
                    batch_spans=pos_neg_spans,
                    aggregate=True)  # (1+negative_size, 1)

                # Bracketing Loss
                for neg_i in range(negative_size):
                    loss_bracketing += F.clip(
                        pred_scores[1 + neg_i] + margins[neg_i] -
                        pred_scores[0], 0.0, 10000000.0)

                # Ranking Accuracy
                pred_scores = F.reshape(
                    pred_scores,
                    (1, 1 + negative_size))  # (1, 1+negative_size)
                gold_scores = np.zeros((1, ), dtype=np.int32)  # (1,)
                gold_scores = utils.convert_ndarray_to_variable(
                    gold_scores, seq=False)  # (1,)
                acc_bracketing += F.accuracy(pred_scores, gold_scores)

                actual_batchsize += 1

            # Backward & Update
            actual_batchsize = float(actual_batchsize)
            loss_bracketing = loss_bracketing / actual_batchsize
            acc_bracketing = acc_bracketing / actual_batchsize
            loss = loss_bracketing
            model.zerograds()
            loss.backward()
            opt.update()
            it += 1

            # Write log
            loss_bracketing_data = float(cuda.to_cpu(loss_bracketing.data))
            acc_bracketing_data = float(cuda.to_cpu(acc_bracketing.data))
            out = {
                "iter": it,
                "epoch": epoch,
                "progress": "%d/%d" % (inst_i + actual_batchsize, n_train),
                "progress_ratio":
                float(inst_i + actual_batchsize) / n_train * 100.0,
                "Bracketing Loss": loss_bracketing_data,
                "Ranking Accuracy": acc_bracketing_data * 100.0
            }
            writer_train.write(out)
            utils.writelog(utils.pretty_format_dict(out))
        ########## M-Step (END) ##########

        if dev_dataset is not None:
            # Validation
            with chainer.using_config("train",
                                      False), chainer.no_backprop_mode():
                parse(model=model,
                      decoder=decoder,
                      dataset=dev_dataset,
                      path_pred=path_pred)
                scores = metrics.rst_parseval(pred_path=path_pred,
                                              gold_path=path_gold)
                old_scores = metrics.old_rst_parseval(pred_path=path_pred,
                                                      gold_path=path_gold)
                out = {
                    "epoch": epoch,
                    "Morey2018": {
                        "Unlabeled Precision":
                        scores["S"]["Precision"] * 100.0,
                        "Precision_info": scores["S"]["Precision_info"],
                        "Unlabeled Recall": scores["S"]["Recall"] * 100.0,
                        "Recall_info": scores["S"]["Recall_info"],
                        "Micro F1": scores["S"]["Micro F1"] * 100.0
                    },
                    "Marcu2000": {
                        "Unlabeled Precision":
                        old_scores["S"]["Precision"] * 100.0,
                        "Precision_info": old_scores["S"]["Precision_info"],
                        "Unlabeled Recall": old_scores["S"]["Recall"] * 100.0,
                        "Recall_info": old_scores["S"]["Recall_info"],
                        "Micro F1": old_scores["S"]["Micro F1"] * 100.0
                    }
                }
                writer_valid.write(out)
                utils.writelog(utils.pretty_format_dict(out))
            # Saving
            did_update = bestscore_holder.compare_scores(
                scores["S"]["Micro F1"], epoch)
            if did_update:
                serializers.save_npz(path_snapshot, model)
                utils.writelog("Saved the model to %s" % path_snapshot)
            # Finished?
            if bestscore_holder.ask_finishing(max_patience=10):
                utils.writelog(
                    "Patience %d is over. Training finished successfully." %
                    bestscore_holder.patience)
                writer_train.close()
                if dev_dataset is not None:
                    writer_valid.close()
                return
        else:
            # No validation
            # Saving
            serializers.save_npz(path_snapshot, model)
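The bracketing loss accumulated in the M-step above is a max-margin ranking loss over one positive (E-step) tree and negative_size negative trees. A compact restatement over plain floats (this helper is mine, not part of the repository):

def bracketing_loss(pos_score, neg_scores, margins):
    # Mirrors the F.clip(...) sum above: each negative tree must score below the
    # positive tree by at least its tree-distance margin, otherwise it adds to the loss.
    return sum(max(0.0, neg + margin - pos_score)
               for neg, margin in zip(neg_scores, margins))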
Code example #10
    def build_ngrams(self, batch_edus, batch_nary_sexp, threshold):
        """
        :type batch_edus: list of list of list of str
        :type batch_nary_sexp: list of list of str
        :type threshold: int
        :rtype: list of (str, int), list of (str, int), list of (str, int), list of (str, int), list of (str, int), list of (str, int)
        """
        # Counting
        counter_span_begin = Counter()
        counter_span_end = Counter()
        counter_lc_begin = Counter()
        counter_lc_end = Counter()
        counter_rc_begin = Counter()
        counter_rc_end = Counter()
        prog_bar = pyprind.ProgBar(len(batch_edus))
        for edus, sexp in zip(batch_edus, batch_nary_sexp):
            ngrams_span_begin = []
            ngrams_span_end = []
            ngrams_lc_begin = []
            ngrams_lc_end = []
            ngrams_rc_begin = []
            ngrams_rc_end = []
            # Aggregate spans from the constituent tree
            tree = treetk.rstdt.postprocess(treetk.sexp2tree(sexp, with_nonterminal_labels=True, with_terminal_labels=False))
            tree.calc_spans()
            spans = treetk.aggregate_spans(tree, include_terminal=False, order="pre-order") # list of (int, int, str)
            n_edus = len(edus)
            for span in spans:
                b, e, l = span
                # Extract N-grams from a span
                part_ngrams_span_begin = self.extract_ngrams(edus[b], position="begin") # list of str
                part_ngrams_span_end = self.extract_ngrams(edus[e], position="end")
                ngrams_span_begin.extend(part_ngrams_span_begin)
                ngrams_span_end.extend(part_ngrams_span_end)
                # Extract N-grams from the left-context EDU
                if b > 0:
                    part_ngrams_lc_begin = self.extract_ngrams(edus[b-1], position="begin")
                    part_ngrams_lc_end = self.extract_ngrams(edus[b-1], position="end")
                    ngrams_lc_begin.extend(part_ngrams_lc_begin)
                    ngrams_lc_end.extend(part_ngrams_lc_end)
                # Extract N-grams from the right-context EDU
                if e < n_edus-1:
                    part_ngrams_rc_begin = self.extract_ngrams(edus[e+1], position="begin")
                    part_ngrams_rc_end = self.extract_ngrams(edus[e+1], position="end")
                    ngrams_rc_begin.extend(part_ngrams_rc_begin)
                    ngrams_rc_end.extend(part_ngrams_rc_end)
            counter_span_begin.update(ngrams_span_begin)
            counter_span_end.update(ngrams_span_end)
            counter_lc_begin.update(ngrams_lc_begin)
            counter_lc_end.update(ngrams_lc_end)
            counter_rc_begin.update(ngrams_rc_begin)
            counter_rc_end.update(ngrams_rc_end)
            prog_bar.update()

        # Filtering
        counter_span_begin = [(ngram,cnt) for ngram,cnt in counter_span_begin.most_common() if cnt >= threshold]
        counter_span_end = [(ngram,cnt) for ngram,cnt in counter_span_end.most_common() if cnt >= threshold]
        counter_lc_begin = [(ngram,cnt) for ngram,cnt in counter_lc_begin.most_common() if cnt >= threshold]
        counter_lc_end = [(ngram,cnt) for ngram,cnt in counter_lc_end.most_common() if cnt >= threshold]
        counter_rc_begin = [(ngram,cnt) for ngram,cnt in counter_rc_begin.most_common() if cnt >= threshold]
        counter_rc_end = [(ngram,cnt) for ngram,cnt in counter_rc_end.most_common() if cnt >= threshold]

        return counter_span_begin, counter_span_end,\
                counter_lc_begin, counter_lc_end,\
                counter_rc_begin, counter_rc_end
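build_ngrams relies on self.extract_ngrams, which is not shown here. A plausible sketch, assuming it returns n-grams anchored at the beginning or end of an EDU (the n-gram orders and the method body are assumptions, not the repository's implementation):

    def extract_ngrams(self, edu, position, max_n=2):
        """
        Hypothetical sketch: return the 1..max_n-grams taken from the start
        ("begin") or the end ("end") of an EDU given as a list of tokens.
        """
        ngrams = []
        for n in range(1, max_n + 1):
            if len(edu) < n:
                break
            tokens = edu[:n] if position == "begin" else edu[-n:]
            ngrams.append(" ".join(tokens))
        return ngrams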
Code example #11
# Imports assumed by this snippet: "utils" is the project-local utility/config module.
import os
from collections import OrderedDict

import numpy as np
import treetk
import utils

def read_rstdt(split, relation_level, with_root=False):
    """
    :type split: str
    :type relation_level: str
    :type with_root: bool
    :rtype: numpy.ndarray(shape=(dataset_size), dtype="O")
    """
    if relation_level not in ["coarse-grained", "fine-grained"]:
        raise ValueError(
            "relation_level must be 'coarse-grained' or 'fine-grained'")

    config = utils.Config()

    path_root = os.path.join(config.getpath("data"), "rstdt", "wsj", split)

    if relation_level == "coarse-grained":
        relation_mapper = treetk.rstdt.RelationMapper()

    # Reading
    dataset = []

    filenames = os.listdir(path_root)
    filenames = [n for n in filenames if n.endswith(".edus.tokens")]
    filenames.sort()

    for filename in filenames:
        # Path
        path_edus = os.path.join(path_root, filename + ".preprocessed")
        path_edus_postag = os.path.join(
            path_root, filename.replace(".edus.tokens", ".edus.postags"))
        path_edus_head = os.path.join(
            path_root, filename.replace(".edus.tokens", ".edus.heads"))
        path_sbnds = os.path.join(path_root,
                                  filename.replace(".edus.tokens", ".sbnds"))
        path_pbnds = os.path.join(path_root,
                                  filename.replace(".edus.tokens", ".pbnds"))
        path_nary_sexp = os.path.join(
            path_root, filename.replace(".edus.tokens", ".labeled.nary.ctree"))
        path_bin_sexp = os.path.join(
            path_root, filename.replace(".edus.tokens", ".labeled.bin.ctree"))
        path_arcs = os.path.join(path_root,
                                 filename.replace(".edus.tokens", ".arcs"))

        kargs = OrderedDict()

        # EDUs
        edus = utils.read_lines(path_edus, process=lambda line: line.split())
        if with_root:
            edus = [["<root>"]] + edus
        kargs["edus"] = edus

        # EDU IDs
        edu_ids = np.arange(len(edus)).tolist()
        kargs["edu_ids"] = edu_ids

        # EDUs (POS tags)
        edus_postag = utils.read_lines(path_edus_postag,
                                       process=lambda line: line.split())
        if with_root:
            edus_postag = [["<root>"]] + edus_postag
        kargs["edus_postag"] = edus_postag

        # EDUs (head)
        edus_head = utils.read_lines(path_edus_head,
                                     process=lambda line: tuple(line.split()))
        if with_root:
            edus_head = [("<root>", "<root>", "<root>")] + edus_head
        kargs["edus_head"] = edus_head

        # Sentence boundaries
        sbnds = utils.read_lines(
            path_sbnds,
            process=lambda line: tuple([int(x) for x in line.split()]))
        kargs["sbnds"] = sbnds

        # Paragraph boundaries
        pbnds = utils.read_lines(
            path_pbnds,
            process=lambda line: tuple([int(x) for x in line.split()]))
        kargs["pbnds"] = pbnds

        # Constituent tree
        nary_sexp = utils.read_lines(path_nary_sexp,
                                     process=lambda line: line.split())[0]
        bin_sexp = utils.read_lines(path_bin_sexp,
                                    process=lambda line: line.split())[0]
        if relation_level == "coarse-grained":
            nary_tree = treetk.rstdt.postprocess(
                treetk.sexp2tree(nary_sexp,
                                 with_nonterminal_labels=True,
                                 with_terminal_labels=False))
            bin_tree = treetk.rstdt.postprocess(
                treetk.sexp2tree(bin_sexp,
                                 with_nonterminal_labels=True,
                                 with_terminal_labels=False))
            nary_tree = treetk.rstdt.map_relations(nary_tree, mode="f2c")
            bin_tree = treetk.rstdt.map_relations(bin_tree, mode="f2c")
            nary_sexp = treetk.tree2sexp(nary_tree)
            bin_sexp = treetk.tree2sexp(bin_tree)
        kargs["nary_sexp"] = nary_sexp
        kargs["bin_sexp"] = bin_sexp

        # Dependency tree
        hyphens = utils.read_lines(path_arcs,
                                   process=lambda line: line.split())
        assert len(hyphens) == 1
        hyphens = hyphens[0]  # list of str
        arcs = treetk.hyphens2arcs(hyphens)  # list of (int, int, str)
        if relation_level == "coarse-grained":
            arcs = [(h, d, relation_mapper.f2c(l)) for h, d, l in arcs]
        kargs["arcs"] = arcs

        # DataInstance
        # data = utils.DataInstance(
        #                 edus=edus,
        #                 edu_ids=edu_ids,
        #                 edus_postag=edus_postag,
        #                 edus_head=edus_head,
        #                 sbnds=sbnds,
        #                 pbnds=pbnds,
        #                 nary_sexp=nary_sexp,
        #                 bin_sexp=bin_sexp,
        #                 arcs=arcs)
        data = utils.DataInstance(**kargs)
        dataset.append(data)

    # NOTE that sentence/paragraph boundaries do NOT consider ROOTs even if with_root=True.

    dataset = np.asarray(dataset, dtype="O")

    n_docs = len(dataset)

    n_paras = 0
    for data in dataset:
        n_paras += len(data.pbnds)

    n_sents = 0
    for data in dataset:
        n_sents += len(data.sbnds)

    n_edus = 0
    for data in dataset:
        if with_root:
            n_edus += len(data.edus[1:])  # Exclude the ROOT
        else:
            n_edus += len(data.edus)

    utils.writelog("split=%s" % split)
    utils.writelog("# of documents=%d" % n_docs)
    utils.writelog("# of paragraphs=%d" % n_paras)
    utils.writelog("# of sentences=%d" % n_sents)
    utils.writelog("# of EDUs (w/o ROOTs)=%d" % n_edus)
    return dataset
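A usage sketch for read_rstdt, mirroring the call in Code example #1; it assumes the preprocessed RST-DT files are already in place under the configured data directory:

dataset = read_rstdt(split="train", relation_level="coarse-grained", with_root=True)
first = dataset[0]
print(len(dataset))     # number of documents in the split
print(first.edus[0])    # ["<root>"] because with_root=True
print(first.sbnds[:3])  # sentence boundaries as inclusive (begin, end) EDU pairs
print(first.arcs[:3])   # dependency arcs as (head, dependent, label) triples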
Code example #12
# Test Affinity Propagation
model = clustering.clustering(vectors=vectors,
                              method="affinitypropagation",
                              params={"preference": -50})
cluster_ids = model.get_cluster_assignments()
cluster_centers = model.get_cluster_centers()
predicted_cluster_ids = model.predict_clusters(vectors)

# Test Mean Shift
model = clustering.clustering(vectors=vectors, method="meanshift", params={})
cluster_ids = model.get_cluster_assignments()
cluster_centers = model.get_cluster_centers()
predicted_cluster_ids = model.predict_clusters(vectors)

# Test Agglomerative Clustering
model = clustering.clustering(vectors=vectors,
                              method="agglomerative",
                              params={
                                  "n_clusters": 4,
                                  "linkage": "ward",
                                  "n_neighbors": 10
                              })
cluster_ids = model.get_cluster_assignments()
sexp = model.get_tree_sexp()
tree = treetk.sexp2tree(treetk.preprocess(sexp),
                        with_nonterminal_labels=False,
                        with_terminal_labels=False)
# treetk.pretty_print(tree)

print("OK")
Code example #13
def main(args):
    config = utils.Config()

    utils.mkdir(os.path.join(config.getpath("data"), "rstdt-vocab"))

    filenames = os.listdir(
        os.path.join(config.getpath("data"), "rstdt", "renamed"))
    filenames = [n for n in filenames if n.endswith(".edus")]
    filenames.sort()

    # Concat
    filepaths = [
        os.path.join(config.getpath("data"), "rstdt", "tmp.preprocessing",
                     filename + ".tokenized.lowercased.replace_digits")
        for filename in filenames
    ]
    textpreprocessor.concat.run(
        filepaths,
        os.path.join(config.getpath("data"), "rstdt", "tmp.preprocessing",
                     "concat.tokenized.lowercased.replace_digits"))

    # Build vocabulary
    if args.with_root:
        special_words = ["<root>"]
    else:
        special_words = []
    textpreprocessor.create_vocabulary.run(
        os.path.join(config.getpath("data"), "rstdt", "tmp.preprocessing",
                     "concat.tokenized.lowercased.replace_digits"),
        os.path.join(config.getpath("data"), "rstdt-vocab", "words.vocab.txt"),
        prune_at=50000,
        min_count=-1,
        special_words=special_words,
        with_unk=True)

    # Build vocabulary for fine-grained/coarse-grained relations
    relation_mapper = treetk.rstdt.RelationMapper()
    frelations = []
    crelations = []
    nuclearities = []
    for filename in filenames:
        sexp = utils.read_lines(os.path.join(
            config.getpath("data"), "rstdt", "renamed",
            filename.replace(".edus", ".labeled.bin.ctree")),
                                process=lambda line: line)
        sexp = treetk.preprocess(sexp)
        tree = treetk.rstdt.postprocess(
            treetk.sexp2tree(sexp,
                             with_nonterminal_labels=True,
                             with_terminal_labels=False))
        nodes = treetk.traverse(tree,
                                order="pre-order",
                                include_terminal=False,
                                acc=None)
        part_frelations = []
        part_crelations = []
        part_nuclearities = []
        for node in nodes:
            relations_ = node.relation_label.split("/")
            part_frelations.extend(relations_)
            part_crelations.extend(
                [relation_mapper.f2c(r) for r in relations_])
            part_nuclearities.append(node.nuclearity_label)
        if args.with_root:
            part_frelations.append("<root>")
            part_crelations.append("<root>")
        frelations.append(part_frelations)
        crelations.append(part_crelations)
        nuclearities.append(part_nuclearities)

    fcounter = utils.get_word_counter(lines=frelations)
    ccounter = utils.get_word_counter(lines=crelations)
    ncounter = utils.get_word_counter(lines=nuclearities)
    frelations = fcounter.most_common()  # list of (str, int)
    crelations = ccounter.most_common()  # list of (str, int)
    nuclearities = ncounter.most_common()  # list of (str, int)
    utils.write_vocab(
        os.path.join(config.getpath("data"), "rstdt-vocab",
                     "relations.fine.vocab.txt"), frelations)
    utils.write_vocab(
        os.path.join(config.getpath("data"), "rstdt-vocab",
                     "relations.coarse.vocab.txt"), crelations)
    utils.write_vocab(
        os.path.join(config.getpath("data"), "rstdt-vocab",
                     "nuclearities.vocab.txt"), nuclearities)