def main():
    dataset = dataloader.read_rstdt(split="train", relation_level="coarse-grained", with_root=True)
    relation_mapper = treetk.rstdt.RelationMapper()

    i = 0
    for data in dataset:
        edu_ids = data.edu_ids
        edus = data.edus
        edus_postag = data.edus_postag
        edus_head = data.edus_head
        sbnds = data.sbnds
        pbnds = data.pbnds
        nary_sexp = data.nary_sexp
        bin_sexp = data.bin_sexp
        arcs = data.arcs

        print("Data instance #%d" % i)
        print("\t Paragraph #0")
        print("\t\t Sentence #0")
        print("\t\t\t EDU #0")
        print("\t\t\t\t EDU ID:", edu_ids[0])
        print("\t\t\t\t EDU:", edus[0])
        print("\t\t\t\t EDU (POS):", edus_postag[0])
        print("\t\t\t\t EDU (head):", edus_head[0])

        p_i = 1
        s_i = 1
        e_i = 1
        for p_begin, p_end in pbnds:
            print("\t Paragraph #%d" % p_i)
            for s_begin, s_end in sbnds[p_begin:p_end + 1]:
                print("\t\t Sentence #%d" % s_i)
                for edu_id, edu, edu_postag, edu_head in zip(edu_ids[1 + s_begin:1 + s_end + 1],
                                                             edus[1 + s_begin:1 + s_end + 1],
                                                             edus_postag[1 + s_begin:1 + s_end + 1],
                                                             edus_head[1 + s_begin:1 + s_end + 1]):
                    print("\t\t\t EDU #%d" % e_i)
                    print("\t\t\t\t EDU ID:", edu_id)
                    print("\t\t\t\t EDU:", edu)
                    print("\t\t\t\t EDU (POS):", edu_postag)
                    print("\t\t\t\t EDU (head):", edu_head)
                    e_i += 1
                s_i += 1
            p_i += 1

        nary_tree = treetk.rstdt.postprocess(
            treetk.sexp2tree(nary_sexp, with_nonterminal_labels=True, with_terminal_labels=False))
        nary_tree = treetk.rstdt.map_relations(nary_tree, mode="c2a")
        bin_tree = treetk.rstdt.postprocess(
            treetk.sexp2tree(bin_sexp, with_nonterminal_labels=True, with_terminal_labels=False))
        bin_tree = treetk.rstdt.map_relations(bin_tree, mode="c2a")
        arcs = [(h, d, relation_mapper.c2a(l)) for h, d, l in arcs]
        dtree = treetk.arcs2dtree(arcs)

        treetk.pretty_print(nary_tree)
        treetk.pretty_print(bin_tree)
        treetk.pretty_print_dtree(dtree)

        i += 1
def parse(model, decoder, dataset, path_pred):
    """
    :type model: SpanBasedModel
    :type decoder: IncrementalCKYDecoder
    :type dataset: numpy.ndarray
    :type path_pred: str
    :rtype: None
    """
    with open(path_pred, "w") as f:
        for data in pyprind.prog_bar(dataset):
            edu_ids = data.edu_ids
            edus = data.edus
            edus_postag = data.edus_postag
            edus_head = data.edus_head
            sbnds = data.sbnds
            pbnds = data.pbnds

            # Feature extraction
            edu_vectors = model.forward_edus(edus, edus_postag, edus_head) # (n_edus, bilstm_dim)
            padded_edu_vectors = model.pad_edu_vectors(edu_vectors) # (n_edus+2, bilstm_dim)
            mask_bwd, mask_fwd = model.make_masks() # (1, bilstm_dim), (1, bilstm_dim)

            # Parsing (bracketing)
            span_scores = precompute_all_span_scores(
                model=model,
                edus=edus,
                edus_postag=edus_postag,
                sbnds=sbnds,
                pbnds=pbnds,
                padded_edu_vectors=padded_edu_vectors,
                mask_bwd=mask_bwd,
                mask_fwd=mask_fwd)
            unlabeled_sexp = decoder.decode(span_scores=span_scores,
                                            inputs=edu_ids,
                                            sbnds=sbnds,
                                            pbnds=pbnds,
                                            use_sbnds=True,
                                            use_pbnds=True) # list of str
            unlabeled_tree = treetk.sexp2tree(unlabeled_sexp,
                                              with_nonterminal_labels=False,
                                              with_terminal_labels=False)
            unlabeled_tree.calc_spans()
            unlabeled_spans = treetk.aggregate_spans(unlabeled_tree,
                                                     include_terminal=False,
                                                     order="pre-order") # list of (int, int)

            # Parsing (assigning majority labels to the unlabeled tree)
            span2label = {(b, e): "<ELABORATION,N/S>" for (b, e) in unlabeled_spans}
            labeled_tree = treetk.assign_labels(unlabeled_tree,
                                                span2label,
                                                with_terminal_labels=False)
            labeled_sexp = treetk.tree2sexp(labeled_tree)

            f.write("%s\n" % " ".join(labeled_sexp))
def parse(model, decoder, databatch, path_pred):
    """
    :type model: Model
    :type decoder: IncrementalCKYDecoder
    :type databatch: DataBatch
    :type path_pred: str
    :rtype: None
    """
    with open(path_pred, "w") as f:
        prog_bar = pyprind.ProgBar(len(databatch))
        for edu_ids, edus, edus_postag, edus_head, sbnds, pbnds \
                in zip(databatch.batch_edu_ids,
                       databatch.batch_edus,
                       databatch.batch_edus_postag,
                       databatch.batch_edus_head,
                       databatch.batch_sbnds,
                       databatch.batch_pbnds):
            # Feature extraction
            edu_vectors = model.forward_edus(edus, edus_postag, edus_head) # (n_edus, bilstm_dim)
            padded_edu_vectors = model.pad_edu_vectors(edu_vectors) # (n_edus+2, bilstm_dim)
            mask_bwd, mask_fwd = model.make_masks() # (1, bilstm_dim), (1, bilstm_dim)

            # Parsing (constituency)
            unlabeled_sexp = decoder.decode(model=model,
                                            sexps=edu_ids,
                                            edus=edus,
                                            edus_postag=edus_postag,
                                            sbnds=sbnds,
                                            pbnds=pbnds,
                                            padded_edu_vectors=padded_edu_vectors,
                                            mask_bwd=mask_bwd,
                                            mask_fwd=mask_fwd,
                                            use_sbnds=True,
                                            use_pbnds=True) # list of str
            unlabeled_tree = treetk.sexp2tree(unlabeled_sexp,
                                              with_nonterminal_labels=False,
                                              with_terminal_labels=False)
            unlabeled_tree.calc_spans()
            unlabeled_spans = treetk.aggregate_spans(unlabeled_tree,
                                                     include_terminal=False,
                                                     order="pre-order") # list of (int, int)

            # Assigning majority labels to the unlabeled tree
            span2label = {(b, e): "<ELABORATION,N/S>" for (b, e) in unlabeled_spans}
            labeled_tree = treetk.assign_labels(unlabeled_tree,
                                                span2label,
                                                with_terminal_labels=False)
            labeled_sexp = treetk.tree2sexp(labeled_tree)

            f.write("%s\n" % " ".join(labeled_sexp))
            prog_bar.update()
def read_trees(path):
    """
    :type path: str
    :rtype: list of NonTerminal
    """
    sexps = utils.read_lines(path, process=lambda line: line.split())
    trees = [treetk.rstdt.postprocess(
                 treetk.sexp2tree(sexp,
                                  with_nonterminal_labels=True,
                                  with_terminal_labels=False))
             for sexp in sexps]
    return trees
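# Minimal usage sketch for read_trees(); the path below is hypothetical and assumes a file
# with one labeled constituent S-expression per line (e.g., a prediction file produced by
# the parse() functions above, or a gold "*.labeled.*.ctree" file).
if __name__ == "__main__":
    trees = read_trees("./trees.labeled.bin.ctree")
    for tree in trees:
        treetk.pretty_print(tree)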
def main():
    config = utils.Config()

    utils.mkdir(os.path.join(config.getpath("data"), "rstdt-vocab"))

    filenames = []
    for filename in os.listdir(os.path.join(config.getpath("data"), "rstdt", "wsj", "train")):
        filenames.append(os.path.join(config.getpath("data"), "rstdt", "wsj", "train", filename))
    for filename in os.listdir(os.path.join(config.getpath("data"), "rstdt", "wsj", "test")):
        filenames.append(os.path.join(config.getpath("data"), "rstdt", "wsj", "test", filename))
    filenames = [n for n in filenames if n.endswith(".labeled.bin.ctree")]
    filenames.sort()

    relation_mapper = treetk.rstdt.RelationMapper()

    frelations = []
    crelations = []
    nuclearities = []
    for filename in pyprind.prog_bar(filenames):
        sexp = utils.read_lines(filename, process=lambda line: line)
        sexp = treetk.preprocess(sexp)
        tree = treetk.rstdt.postprocess(
            treetk.sexp2tree(sexp, with_nonterminal_labels=True, with_terminal_labels=False))
        nodes = treetk.traverse(tree, order="pre-order", include_terminal=False, acc=None)

        part_frelations = []
        part_crelations = []
        part_nuclearities = []
        for node in nodes:
            relations_ = node.relation_label.split("/")
            part_frelations.extend(relations_)
            part_crelations.extend([relation_mapper.f2c(r) for r in relations_])
            part_nuclearities.append(node.nuclearity_label)
        part_frelations.append("<root>")
        part_crelations.append("<root>")

        frelations.append(part_frelations)
        crelations.append(part_crelations)
        nuclearities.append(part_nuclearities)

    fcounter = utils.get_word_counter(lines=frelations)
    ccounter = utils.get_word_counter(lines=crelations)
    ncounter = utils.get_word_counter(lines=nuclearities)

    frelations = fcounter.most_common() # list of (str, int)
    crelations = ccounter.most_common() # list of (str, int)
    nuclearities = ncounter.most_common() # list of (str, int)

    utils.write_vocab(os.path.join(config.getpath("data"), "rstdt-vocab", "relations.fine.vocab.txt"),
                      frelations)
    utils.write_vocab(os.path.join(config.getpath("data"), "rstdt-vocab", "relations.coarse.vocab.txt"),
                      crelations)
    utils.write_vocab(os.path.join(config.getpath("data"), "rstdt-vocab", "nuclearities.vocab.txt"),
                      nuclearities)
def process(path):
    # NOTE: We use n-ary ctrees (i.e., *.labeled.nary.ctree) to generate dtrees.
    # Morey et al. (2018) demonstrate that scores evaluated on these dtrees are
    # superficially lower than those on right-heavy binarized trees (i.e., *.labeled.bin.ctree).
    filenames = os.listdir(path)
    filenames = [n for n in filenames if n.endswith(".labeled.nary.ctree")]
    filenames.sort()

    def func_label_rule(node, i, j):
        relations = node.relation_label.split("/")
        if len(relations) == 1:
            return relations[0] # Left-most node is head.
        else:
            if i > j:
                return relations[j]
            else:
                return relations[j - 1]

    for filename in pyprind.prog_bar(filenames):
        sexp = utils.read_lines(os.path.join(path, filename), process=lambda line: line.split())
        assert len(sexp) == 1
        sexp = sexp[0]

        # Constituency
        ctree = treetk.rstdt.postprocess(
            treetk.sexp2tree(sexp, with_nonterminal_labels=True, with_terminal_labels=False))

        # Dependency
        # Assign heads
        ctree = treetk.rstdt.assign_heads(ctree)
        # Conversion
        dtree = treetk.ctree2dtree(ctree, func_label_rule=func_label_rule)
        arcs = dtree.tolist(labeled=True)

        # Write
        with open(os.path.join(path, filename.replace(".labeled.nary.ctree", ".arcs")), "w") as f:
            f.write("%s\n" % " ".join(["%d-%d-%s" % (h, d, l) for h, d, l in arcs]))
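# A small, hypothetical helper sketch showing how the "*.arcs" files written by process()
# can be read back: each file holds a single line of space-separated "head-dependent-label"
# triples, which treetk.hyphens2arcs() (also used by the data loader further below) turns
# into (int, int, str) tuples.
def read_arcs_sketch(path_arcs):
    hyphens = utils.read_lines(path_arcs, process=lambda line: line.split())
    assert len(hyphens) == 1
    return treetk.hyphens2arcs(hyphens[0]) # list of (int, int, str)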
def parse(sampler, databatch, path_pred):
    """
    :type sampler: TreeSampler
    :type databatch: DataBatch
    :type path_pred: str
    :rtype: None
    """
    with open(path_pred, "w") as f:
        prog_bar = pyprind.ProgBar(len(databatch))
        for edu_ids, edus, edus_postag, edus_head, sbnds, pbnds \
                in zip(databatch.batch_edu_ids,
                       databatch.batch_edus,
                       databatch.batch_edus_postag,
                       databatch.batch_edus_head,
                       databatch.batch_sbnds,
                       databatch.batch_pbnds):
            # Tree sampling (constituency)
            unlabeled_sexp = sampler.sample(sexps=edu_ids,
                                            edus=edus,
                                            edus_head=edus_head,
                                            sbnds=sbnds,
                                            pbnds=pbnds)
            unlabeled_tree = treetk.sexp2tree(unlabeled_sexp,
                                              with_nonterminal_labels=False,
                                              with_terminal_labels=False)
            unlabeled_tree.calc_spans()
            unlabeled_spans = treetk.aggregate_spans(unlabeled_tree,
                                                     include_terminal=False,
                                                     order="pre-order") # list of (int, int)

            # Assigning majority labels to the unlabeled tree
            span2label = {(b, e): "<ELABORATION,N/S>" for (b, e) in unlabeled_spans}
            labeled_tree = treetk.assign_labels(unlabeled_tree,
                                                span2label,
                                                with_terminal_labels=False)
            labeled_sexp = treetk.tree2sexp(labeled_tree)

            f.write("%s\n" % " ".join(labeled_sexp))
            prog_bar.update()
def parse(sampler, dataset, path_pred):
    """
    :type sampler: TreeSampler
    :type dataset: numpy.ndarray
    :type path_pred: str
    :rtype: None
    """
    with open(path_pred, "w") as f:
        for data in pyprind.prog_bar(dataset):
            edu_ids = data.edu_ids
            edus = data.edus
            edus_head = data.edus_head
            sbnds = data.sbnds
            pbnds = data.pbnds

            # Tree sampling (constituency)
            unlabeled_sexp = sampler.sample(sexps=edu_ids,
                                            edus=edus,
                                            edus_head=edus_head,
                                            sbnds=sbnds,
                                            pbnds=pbnds)
            unlabeled_tree = treetk.sexp2tree(unlabeled_sexp,
                                              with_nonterminal_labels=False,
                                              with_terminal_labels=False)
            unlabeled_tree.calc_spans()
            unlabeled_spans = treetk.aggregate_spans(unlabeled_tree,
                                                     include_terminal=False,
                                                     order="pre-order") # list of (int, int)

            # Assigning majority labels to the unlabeled tree
            span2label = {(b, e): "<ELABORATION,N/S>" for (b, e) in unlabeled_spans}
            labeled_tree = treetk.assign_labels(unlabeled_tree,
                                                span2label,
                                                with_terminal_labels=False)
            labeled_sexp = treetk.tree2sexp(labeled_tree)

            f.write("%s\n" % " ".join(labeled_sexp))
def train(model, decoder, sampler, max_epoch, n_init_epochs, negative_size,
          batch_size, weight_decay, gradient_clipping, optimizer_name,
          train_dataset, dev_dataset,
          path_train, path_valid, path_snapshot, path_pred, path_gold):
    """
    :type model: SpanBasedModel
    :type decoder: IncrementalCKYDecoder
    :type sampler: TreeSampler
    :type max_epoch: int
    :type n_init_epochs: int
    :type negative_size: int
    :type batch_size: int
    :type weight_decay: float
    :type gradient_clipping: float
    :type optimizer_name: str
    :type train_dataset: numpy.ndarray
    :type dev_dataset: numpy.ndarray
    :type path_train: str
    :type path_valid: str
    :type path_snapshot: str
    :type path_pred: str
    :type path_gold: str
    :rtype: None
    """
    writer_train = jsonlines.Writer(open(path_train, "w"), flush=True)
    if dev_dataset is not None:
        writer_valid = jsonlines.Writer(open(path_valid, "w"), flush=True)

    boundary_flags = [(True, False)]
    assert negative_size >= len(boundary_flags)
    negative_tree_sampler = treesamplers.NegativeTreeSampler()

    # Optimizer preparation
    if optimizer_name == "adam":
        opt = optimizers.Adam()
    else:
        raise ValueError("Invalid optimizer_name=%s" % optimizer_name)
    opt.setup(model)
    if weight_decay > 0.0:
        opt.add_hook(chainer.optimizer.WeightDecay(weight_decay))
    if gradient_clipping:
        opt.add_hook(chainer.optimizer.GradientClipping(gradient_clipping))

    n_train = len(train_dataset)
    it = 0
    bestscore_holder = utils.BestScoreHolder(scale=100.0)
    bestscore_holder.init()

    if dev_dataset is not None:
        # Initial validation
        with chainer.using_config("train", False), chainer.no_backprop_mode():
            parse(model=model,
                  decoder=decoder,
                  dataset=dev_dataset,
                  path_pred=path_pred)
            scores = metrics.rst_parseval(pred_path=path_pred, gold_path=path_gold)
            old_scores = metrics.old_rst_parseval(pred_path=path_pred, gold_path=path_gold)
            out = {
                "epoch": 0,
                "Morey2018": {
                    "Unlabeled Precision": scores["S"]["Precision"] * 100.0,
                    "Precision_info": scores["S"]["Precision_info"],
                    "Unlabeled Recall": scores["S"]["Recall"] * 100.0,
                    "Recall_info": scores["S"]["Recall_info"],
                    "Micro F1": scores["S"]["Micro F1"] * 100.0
                },
                "Marcu2000": {
                    "Unlabeled Precision": old_scores["S"]["Precision"] * 100.0,
                    "Precision_info": old_scores["S"]["Precision_info"],
                    "Unlabeled Recall": old_scores["S"]["Recall"] * 100.0,
                    "Recall_info": old_scores["S"]["Recall_info"],
                    "Micro F1": old_scores["S"]["Micro F1"] * 100.0
                }
            }
            writer_valid.write(out)
            utils.writelog(utils.pretty_format_dict(out))
        # Saving
        bestscore_holder.compare_scores(scores["S"]["Micro F1"], step=0)
        serializers.save_npz(path_snapshot, model)
        utils.writelog("Saved the model to %s" % path_snapshot)
    else:
        # Saving
        serializers.save_npz(path_snapshot, model)
        utils.writelog("Saved the model to %s" % path_snapshot)

    for epoch in range(1, max_epoch + 1):

        perm = np.random.permutation(n_train)

        ########## E-Step (BEGIN) ##########
        utils.writelog("E step ===>")
        prog_bar = pyprind.ProgBar(n_train)
        for inst_i in range(0, n_train, batch_size):
            ### Mini batch
            for data in train_dataset[inst_i:inst_i + batch_size]:
                ### One data instance
                edu_ids = data.edu_ids
                edus = data.edus
                edus_postag = data.edus_postag
                edus_head = data.edus_head
                sbnds = data.sbnds
                pbnds = data.pbnds

                with chainer.using_config("train", False), chainer.no_backprop_mode():
                    # Feature extraction
                    edu_vectors = model.forward_edus(edus, edus_postag, edus_head) # (n_edus, bilstm_dim)
                    padded_edu_vectors = model.pad_edu_vectors(edu_vectors) # (n_edus+2, bilstm_dim)
                    mask_bwd, mask_fwd = model.make_masks() # (1, bilstm_dim), (1, bilstm_dim)

                    # Positive tree
                    if epoch <= n_init_epochs:
                        pos_sexp = sampler.sample(inputs=edu_ids,
                                                  edus=edus,
                                                  edus_head=edus_head,
                                                  sbnds=sbnds,
                                                  pbnds=pbnds)
                    else:
                        span_scores = precompute_all_span_scores(
                            model=model,
                            edus=edus,
                            edus_postag=edus_postag,
                            sbnds=sbnds,
                            pbnds=pbnds,
                            padded_edu_vectors=padded_edu_vectors,
                            mask_bwd=mask_bwd,
                            mask_fwd=mask_fwd)
                        pos_sexp = decoder.decode(span_scores=span_scores,
                                                  inputs=edu_ids,
                                                  sbnds=sbnds,
                                                  pbnds=pbnds,
                                                  use_sbnds=True,
                                                  use_pbnds=True)
                    pos_tree = treetk.sexp2tree(pos_sexp,
                                                with_nonterminal_labels=False,
                                                with_terminal_labels=False)
                    pos_tree.calc_spans()
                    pos_spans = treetk.aggregate_spans(pos_tree,
                                                       include_terminal=False,
                                                       order="post-order") # list of (int, int)
                    data.pos_spans = pos_spans # NOTE
                prog_bar.update()
        ########## E-Step (END) ##########

        ########## M-Step (BEGIN) ##########
        utils.writelog("M step ===>")
        for inst_i in range(0, n_train, batch_size):
            # Processing one mini-batch

            # Init
            loss_bracketing, acc_bracketing = 0.0, 0.0
            actual_batchsize = 0

            for data in train_dataset[perm[inst_i:inst_i + batch_size]]:
                # Processing one instance
                edu_ids = data.edu_ids
                edus = data.edus
                edus_postag = data.edus_postag
                edus_head = data.edus_head
                sbnds = data.sbnds
                pbnds = data.pbnds
                pos_spans = data.pos_spans # NOTE

                # Feature extraction
                edu_vectors = model.forward_edus(edus, edus_postag, edus_head) # (n_edus, bilstm_dim)
                padded_edu_vectors = model.pad_edu_vectors(edu_vectors) # (n_edus+2, bilstm_dim)
                mask_bwd, mask_fwd = model.make_masks() # (1, bilstm_dim), (1, bilstm_dim)

                # Negative trees
                pos_neg_spans = []
                margins = []
                pos_neg_spans.append(pos_spans)
                with chainer.using_config("train", False), chainer.no_backprop_mode():
                    for use_sbnds, use_pbnds in boundary_flags:
                        span_scores = precompute_all_span_scores(
                            model=model,
                            edus=edus,
                            edus_postag=edus_postag,
                            sbnds=sbnds,
                            pbnds=pbnds,
                            padded_edu_vectors=padded_edu_vectors,
                            mask_bwd=mask_bwd,
                            mask_fwd=mask_fwd)
                        neg_bin_sexp = decoder.decode(span_scores=span_scores,
                                                      inputs=edu_ids,
                                                      sbnds=sbnds,
                                                      pbnds=pbnds,
                                                      use_sbnds=use_sbnds,
                                                      use_pbnds=use_pbnds,
                                                      gold_spans=pos_spans) # list of str
                        neg_tree = treetk.sexp2tree(neg_bin_sexp,
                                                    with_nonterminal_labels=False,
                                                    with_terminal_labels=False)
                        neg_tree.calc_spans()
                        neg_spans = treetk.aggregate_spans(neg_tree,
                                                           include_terminal=False,
                                                           order="pre-order") # list of (int, int)
                        margin = compute_tree_distance(pos_spans, neg_spans, coef=1.0)
                        pos_neg_spans.append(neg_spans)
                        margins.append(margin)
                for _ in range(negative_size - len(boundary_flags)):
                    neg_bin_sexp = negative_tree_sampler.sample(inputs=edu_ids, sbnds=sbnds, pbnds=pbnds)
                    neg_tree = treetk.sexp2tree(neg_bin_sexp,
                                                with_nonterminal_labels=False,
                                                with_terminal_labels=False)
                    neg_tree.calc_spans()
                    neg_spans = treetk.aggregate_spans(neg_tree,
                                                       include_terminal=False,
                                                       order="pre-order") # list of (int, int)
                    margin = compute_tree_distance(pos_spans, neg_spans, coef=1.0)
                    pos_neg_spans.append(neg_spans)
                    margins.append(margin)

                # Scoring
                pred_scores = model.forward_spans_for_bracketing(
                    edus=edus,
                    edus_postag=edus_postag,
                    sbnds=sbnds,
                    pbnds=pbnds,
                    padded_edu_vectors=padded_edu_vectors,
                    mask_bwd=mask_bwd,
                    mask_fwd=mask_fwd,
                    batch_spans=pos_neg_spans,
                    aggregate=True) # (1+negative_size, 1)

                # Bracketing Loss
                for neg_i in range(negative_size):
                    loss_bracketing += F.clip(
                        pred_scores[1 + neg_i] + margins[neg_i] - pred_scores[0],
                        0.0, 10000000.0)

                # Ranking Accuracy
                pred_scores = F.reshape(pred_scores, (1, 1 + negative_size)) # (1, 1+negative_size)
                gold_scores = np.zeros((1,), dtype=np.int32) # (1,)
                gold_scores = utils.convert_ndarray_to_variable(gold_scores, seq=False) # (1,)
                acc_bracketing += F.accuracy(pred_scores, gold_scores)

                actual_batchsize += 1

            # Backward & Update
            actual_batchsize = float(actual_batchsize)
            loss_bracketing = loss_bracketing / actual_batchsize
            acc_bracketing = acc_bracketing / actual_batchsize
            loss = loss_bracketing
            model.zerograds()
            loss.backward()
            opt.update()
            it += 1

            # Write log
            loss_bracketing_data = float(cuda.to_cpu(loss_bracketing.data))
            acc_bracketing_data = float(cuda.to_cpu(acc_bracketing.data))
            out = {
                "iter": it,
                "epoch": epoch,
                "progress": "%d/%d" % (inst_i + actual_batchsize, n_train),
                "progress_ratio": float(inst_i + actual_batchsize) / n_train * 100.0,
                "Bracketing Loss": loss_bracketing_data,
                "Ranking Accuracy": acc_bracketing_data * 100.0
            }
            writer_train.write(out)
            utils.writelog(utils.pretty_format_dict(out))
        ########## M-Step (END) ##########

        if dev_dataset is not None:
            # Validation
            with chainer.using_config("train", False), chainer.no_backprop_mode():
                parse(model=model,
                      decoder=decoder,
                      dataset=dev_dataset,
                      path_pred=path_pred)
                scores = metrics.rst_parseval(pred_path=path_pred, gold_path=path_gold)
                old_scores = metrics.old_rst_parseval(pred_path=path_pred, gold_path=path_gold)
                out = {
                    "epoch": epoch,
                    "Morey2018": {
                        "Unlabeled Precision": scores["S"]["Precision"] * 100.0,
                        "Precision_info": scores["S"]["Precision_info"],
                        "Unlabeled Recall": scores["S"]["Recall"] * 100.0,
                        "Recall_info": scores["S"]["Recall_info"],
                        "Micro F1": scores["S"]["Micro F1"] * 100.0
                    },
                    "Marcu2000": {
                        "Unlabeled Precision": old_scores["S"]["Precision"] * 100.0,
                        "Precision_info": old_scores["S"]["Precision_info"],
                        "Unlabeled Recall": old_scores["S"]["Recall"] * 100.0,
                        "Recall_info": old_scores["S"]["Recall_info"],
                        "Micro F1": old_scores["S"]["Micro F1"] * 100.0
                    }
                }
                writer_valid.write(out)
                utils.writelog(utils.pretty_format_dict(out))

            # Saving
            did_update = bestscore_holder.compare_scores(scores["S"]["Micro F1"], epoch)
            if did_update:
                serializers.save_npz(path_snapshot, model)
                utils.writelog("Saved the model to %s" % path_snapshot)

            # Finished?
            if bestscore_holder.ask_finishing(max_patience=10):
                utils.writelog(
                    "Patience %d is over. Training finished successfully."
                    % bestscore_holder.patience)
                writer_train.close()
                if dev_dataset is not None:
                    writer_valid.close()
                return
        else:
            # No validation
            # Saving
            serializers.save_npz(path_snapshot, model)
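# compute_tree_distance() and precompute_all_span_scores() are referenced above but defined
# elsewhere in the codebase. Below is only a hedged sketch of a plausible structured margin:
# the number of spans of the negative tree that are absent from the positive tree, scaled by
# coef. The actual implementation used by train() may differ.
def compute_tree_distance_sketch(pos_spans, neg_spans, coef=1.0):
    pos_set = set(pos_spans)
    n_mistakes = sum(1 for span in neg_spans if span not in pos_set)
    return coef * float(n_mistakes)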
def build_ngrams(self, batch_edus, batch_nary_sexp, threshold):
    """
    :type batch_edus: list of list of list of str
    :type batch_nary_sexp: list of list of str
    :type threshold: int
    :rtype: list of (str, int), list of (str, int), list of (str, int),
            list of (str, int), list of (str, int), list of (str, int)
    """
    # Counting
    counter_span_begin = Counter()
    counter_span_end = Counter()
    counter_lc_begin = Counter()
    counter_lc_end = Counter()
    counter_rc_begin = Counter()
    counter_rc_end = Counter()
    prog_bar = pyprind.ProgBar(len(batch_edus))
    for edus, sexp in zip(batch_edus, batch_nary_sexp):
        ngrams_span_begin = []
        ngrams_span_end = []
        ngrams_lc_begin = []
        ngrams_lc_end = []
        ngrams_rc_begin = []
        ngrams_rc_end = []

        # Aggregate spans from the constituent tree
        tree = treetk.rstdt.postprocess(
            treetk.sexp2tree(sexp, with_nonterminal_labels=True, with_terminal_labels=False))
        tree.calc_spans()
        spans = treetk.aggregate_spans(tree, include_terminal=False, order="pre-order") # list of (int, int, str)

        n_edus = len(edus)
        for span in spans:
            b, e, l = span
            # Extract N-grams from a span
            part_ngrams_span_begin = self.extract_ngrams(edus[b], position="begin") # list of str
            part_ngrams_span_end = self.extract_ngrams(edus[e], position="end")
            ngrams_span_begin.extend(part_ngrams_span_begin)
            ngrams_span_end.extend(part_ngrams_span_end)
            # Extract N-grams from the left-context EDU
            if b > 0:
                part_ngrams_lc_begin = self.extract_ngrams(edus[b - 1], position="begin")
                part_ngrams_lc_end = self.extract_ngrams(edus[b - 1], position="end")
                ngrams_lc_begin.extend(part_ngrams_lc_begin)
                ngrams_lc_end.extend(part_ngrams_lc_end)
            # Extract N-grams from the right-context EDU
            if e < n_edus - 1:
                part_ngrams_rc_begin = self.extract_ngrams(edus[e + 1], position="begin")
                part_ngrams_rc_end = self.extract_ngrams(edus[e + 1], position="end")
                ngrams_rc_begin.extend(part_ngrams_rc_begin)
                ngrams_rc_end.extend(part_ngrams_rc_end)

        counter_span_begin.update(ngrams_span_begin)
        counter_span_end.update(ngrams_span_end)
        counter_lc_begin.update(ngrams_lc_begin)
        counter_lc_end.update(ngrams_lc_end)
        counter_rc_begin.update(ngrams_rc_begin)
        counter_rc_end.update(ngrams_rc_end)
        prog_bar.update()

    # Filtering
    counter_span_begin = [(ngram, cnt) for ngram, cnt in counter_span_begin.most_common() if cnt >= threshold]
    counter_span_end = [(ngram, cnt) for ngram, cnt in counter_span_end.most_common() if cnt >= threshold]
    counter_lc_begin = [(ngram, cnt) for ngram, cnt in counter_lc_begin.most_common() if cnt >= threshold]
    counter_lc_end = [(ngram, cnt) for ngram, cnt in counter_lc_end.most_common() if cnt >= threshold]
    counter_rc_begin = [(ngram, cnt) for ngram, cnt in counter_rc_begin.most_common() if cnt >= threshold]
    counter_rc_end = [(ngram, cnt) for ngram, cnt in counter_rc_end.most_common() if cnt >= threshold]

    return counter_span_begin, counter_span_end, \
           counter_lc_begin, counter_lc_end, \
           counter_rc_begin, counter_rc_end
def read_rstdt(split, relation_level, with_root=False):
    """
    :type split: str
    :type relation_level: str
    :type with_root: bool
    :rtype: numpy.ndarray(shape=(dataset_size), dtype="O")
    """
    if not relation_level in ["coarse-grained", "fine-grained"]:
        raise ValueError("relation_level must be 'coarse-grained' or 'fine-grained'")

    config = utils.Config()

    path_root = os.path.join(config.getpath("data"), "rstdt", "wsj", split)

    if relation_level == "coarse-grained":
        relation_mapper = treetk.rstdt.RelationMapper()

    # Reading
    dataset = []

    filenames = os.listdir(path_root)
    filenames = [n for n in filenames if n.endswith(".edus.tokens")]
    filenames.sort()

    for filename in filenames:
        # Path
        path_edus = os.path.join(path_root, filename + ".preprocessed")
        path_edus_postag = os.path.join(path_root, filename.replace(".edus.tokens", ".edus.postags"))
        path_edus_head = os.path.join(path_root, filename.replace(".edus.tokens", ".edus.heads"))
        path_sbnds = os.path.join(path_root, filename.replace(".edus.tokens", ".sbnds"))
        path_pbnds = os.path.join(path_root, filename.replace(".edus.tokens", ".pbnds"))
        path_nary_sexp = os.path.join(path_root, filename.replace(".edus.tokens", ".labeled.nary.ctree"))
        path_bin_sexp = os.path.join(path_root, filename.replace(".edus.tokens", ".labeled.bin.ctree"))
        path_arcs = os.path.join(path_root, filename.replace(".edus.tokens", ".arcs"))

        kargs = OrderedDict()

        # EDUs
        edus = utils.read_lines(path_edus, process=lambda line: line.split())
        if with_root:
            edus = [["<root>"]] + edus
        kargs["edus"] = edus

        # EDU IDs
        edu_ids = np.arange(len(edus)).tolist()
        kargs["edu_ids"] = edu_ids

        # EDUs (POS tags)
        edus_postag = utils.read_lines(path_edus_postag, process=lambda line: line.split())
        if with_root:
            edus_postag = [["<root>"]] + edus_postag
        kargs["edus_postag"] = edus_postag

        # EDUs (head)
        edus_head = utils.read_lines(path_edus_head, process=lambda line: tuple(line.split()))
        if with_root:
            edus_head = [("<root>", "<root>", "<root>")] + edus_head
        kargs["edus_head"] = edus_head

        # Sentence boundaries
        sbnds = utils.read_lines(path_sbnds, process=lambda line: tuple([int(x) for x in line.split()]))
        kargs["sbnds"] = sbnds

        # Paragraph boundaries
        pbnds = utils.read_lines(path_pbnds, process=lambda line: tuple([int(x) for x in line.split()]))
        kargs["pbnds"] = pbnds

        # Constituent tree
        nary_sexp = utils.read_lines(path_nary_sexp, process=lambda line: line.split())[0]
        bin_sexp = utils.read_lines(path_bin_sexp, process=lambda line: line.split())[0]
        if relation_level == "coarse-grained":
            nary_tree = treetk.rstdt.postprocess(
                treetk.sexp2tree(nary_sexp, with_nonterminal_labels=True, with_terminal_labels=False))
            bin_tree = treetk.rstdt.postprocess(
                treetk.sexp2tree(bin_sexp, with_nonterminal_labels=True, with_terminal_labels=False))
            nary_tree = treetk.rstdt.map_relations(nary_tree, mode="f2c")
            bin_tree = treetk.rstdt.map_relations(bin_tree, mode="f2c")
            nary_sexp = treetk.tree2sexp(nary_tree)
            bin_sexp = treetk.tree2sexp(bin_tree)
        kargs["nary_sexp"] = nary_sexp
        kargs["bin_sexp"] = bin_sexp

        # Dependency tree
        hyphens = utils.read_lines(path_arcs, process=lambda line: line.split())
        assert len(hyphens) == 1
        hyphens = hyphens[0] # list of str
        arcs = treetk.hyphens2arcs(hyphens) # list of (int, int, str)
        if relation_level == "coarse-grained":
            arcs = [(h, d, relation_mapper.f2c(l)) for h, d, l in arcs]
        kargs["arcs"] = arcs

        # DataInstance
        # data = utils.DataInstance(
        #             edus=edus,
        #             edu_ids=edu_ids,
        #             edus_postag=edus_postag,
        #             edus_head=edus_head,
        #             sbnds=sbnds,
        #             pbnds=pbnds,
        #             nary_sexp=nary_sexp,
        #             bin_sexp=bin_sexp,
        #             arcs=arcs)
        data = utils.DataInstance(**kargs)
        dataset.append(data)

    # NOTE that sentence/paragraph boundaries do NOT consider ROOTs even if with_root=True.

    dataset = np.asarray(dataset, dtype="O")

    n_docs = len(dataset)

    n_paras = 0
    for data in dataset:
        n_paras += len(data.pbnds)

    n_sents = 0
    for data in dataset:
        n_sents += len(data.sbnds)

    n_edus = 0
    for data in dataset:
        if with_root:
            n_edus += len(data.edus[1:]) # Exclude the ROOT
        else:
            n_edus += len(data.edus)

    utils.writelog("split=%s" % split)
    utils.writelog("# of documents=%d" % n_docs)
    utils.writelog("# of paragraphs=%d" % n_paras)
    utils.writelog("# of sentences=%d" % n_sents)
    utils.writelog("# of EDUs (w/o ROOTs)=%d" % n_edus)

    return dataset
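# Example of loading a split and inspecting one document (a minimal sketch; it assumes the
# preprocessed RST-DT files referenced above already exist under the configured "data" path).
if __name__ == "__main__":
    train_dataset = read_rstdt(split="train", relation_level="coarse-grained", with_root=True)
    first = train_dataset[0]
    print(first.edu_ids[:5]) # index 0 is the ROOT pseudo-EDU when with_root=True
    print(first.sbnds[:3])   # sentence boundaries as (begin, end) EDU-index pairs (ROOT excluded)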
# NOTE: `vectors` is assumed to be defined earlier in this test script
# (a 2D array of feature vectors, one row per item).

# Test Affinity Propagation
model = clustering.clustering(vectors=vectors,
                              method="affinitypropagation",
                              params={"preference": -50})
cluster_ids = model.get_cluster_assignments()
cluster_centers = model.get_cluster_centers()
predicted_cluster_ids = model.predict_clusters(vectors)

# Test Mean Shift
model = clustering.clustering(vectors=vectors, method="meanshift", params={})
cluster_ids = model.get_cluster_assignments()
cluster_centers = model.get_cluster_centers()
predicted_cluster_ids = model.predict_clusters(vectors)

# Test Agglomerative Clustering
model = clustering.clustering(vectors=vectors,
                              method="agglomerative",
                              params={
                                  "n_clusters": 4,
                                  "linkage": "ward",
                                  "n_neighbors": 10
                              })
cluster_ids = model.get_cluster_assignments()
sexp = model.get_tree_sexp()
tree = treetk.sexp2tree(treetk.preprocess(sexp),
                        with_nonterminal_labels=False,
                        with_terminal_labels=False)
# treetk.pretty_print(tree)

print("OK")
def main(args):
    config = utils.Config()

    utils.mkdir(os.path.join(config.getpath("data"), "rstdt-vocab"))

    filenames = os.listdir(os.path.join(config.getpath("data"), "rstdt", "renamed"))
    filenames = [n for n in filenames if n.endswith(".edus")]
    filenames.sort()

    # Concat
    filepaths = [
        os.path.join(config.getpath("data"), "rstdt", "tmp.preprocessing",
                     filename + ".tokenized.lowercased.replace_digits")
        for filename in filenames
    ]
    textpreprocessor.concat.run(
        filepaths,
        os.path.join(config.getpath("data"), "rstdt", "tmp.preprocessing",
                     "concat.tokenized.lowercased.replace_digits"))

    # Build vocabulary
    if args.with_root:
        special_words = ["<root>"]
    else:
        special_words = []
    textpreprocessor.create_vocabulary.run(
        os.path.join(config.getpath("data"), "rstdt", "tmp.preprocessing",
                     "concat.tokenized.lowercased.replace_digits"),
        os.path.join(config.getpath("data"), "rstdt-vocab", "words.vocab.txt"),
        prune_at=50000,
        min_count=-1,
        special_words=special_words,
        with_unk=True)

    # Build vocabulary for fine-grained/coarse-grained relations
    relation_mapper = treetk.rstdt.RelationMapper()
    frelations = []
    crelations = []
    nuclearities = []
    for filename in filenames:
        sexp = utils.read_lines(
            os.path.join(config.getpath("data"), "rstdt", "renamed",
                         filename.replace(".edus", ".labeled.bin.ctree")),
            process=lambda line: line)
        sexp = treetk.preprocess(sexp)
        tree = treetk.rstdt.postprocess(
            treetk.sexp2tree(sexp, with_nonterminal_labels=True, with_terminal_labels=False))
        nodes = treetk.traverse(tree, order="pre-order", include_terminal=False, acc=None)

        part_frelations = []
        part_crelations = []
        part_nuclearities = []
        for node in nodes:
            relations_ = node.relation_label.split("/")
            part_frelations.extend(relations_)
            part_crelations.extend([relation_mapper.f2c(r) for r in relations_])
            part_nuclearities.append(node.nuclearity_label)
        if args.with_root:
            part_frelations.append("<root>")
            part_crelations.append("<root>")

        frelations.append(part_frelations)
        crelations.append(part_crelations)
        nuclearities.append(part_nuclearities)

    fcounter = utils.get_word_counter(lines=frelations)
    ccounter = utils.get_word_counter(lines=crelations)
    ncounter = utils.get_word_counter(lines=nuclearities)

    frelations = fcounter.most_common() # list of (str, int)
    crelations = ccounter.most_common() # list of (str, int)
    nuclearities = ncounter.most_common() # list of (str, int)

    utils.write_vocab(
        os.path.join(config.getpath("data"), "rstdt-vocab", "relations.fine.vocab.txt"),
        frelations)
    utils.write_vocab(
        os.path.join(config.getpath("data"), "rstdt-vocab", "relations.coarse.vocab.txt"),
        crelations)
    utils.write_vocab(
        os.path.join(config.getpath("data"), "rstdt-vocab", "nuclearities.vocab.txt"),
        nuclearities)