def main(): data_pre = data_pre_1 try: with open(os.path.join(project_base, data_pre + "_dmp"), "r") as fp: print "un-pickling..." ds = cPickle.load(fp) except IOError: print "something wrong with the pickle file path" within_tls = [] for doc in ds.docs.values(): within_tls.extend(doc.tlinks_within_gold) within_tls.extend(doc.tlinks_within_closure) cfg = Config(mallet_bin, project_temp, data_pre + "_within") train_cfg = train_on_data(within_tls, lambda x: x.ds_id, lambda x: x.type, within_tlink_feats, "MaxEnt", cfg) data_pre = data_pre_2 try: with open(os.path.join(project_base, data_pre + "_dmp"), "r") as fp: print "un-pickling..." ds = cPickle.load(fp) except IOError: print "something wrong with the pickle file path" all_a = [] all_c = [] all_u = [] all_i = [] all_m = [] all_um = [] all_uw = [] all_mi = [] all_within = [] all_candid = [] for doc in ds.docs.values(): within = doc.tlinks_within_gold + doc.tlinks_within_closure for tl in within: tl.sent.tlinks_within.append(tl) all_within.extend(within) candid = create_candid_within3(doc) update_candid_id(candid, doc.ds_id) all_candid.extend(candid) cfg = Config(mallet_bin, project_temp, data_pre + "_within_" + doc.ds_id) ti = apply_to_data(candid, lambda x: x.ds_id, within_tlink_feats, train_cfg.model_path, cfg) for c in candid: c.pred = ti.ID2pred[c.ds_id] c.probs = ti.ID2probs[c.ds_id] # for s in doc.sents: # expand(s) for s in doc.sents: # print s.ds_id """ if s.freq_tx != []: print 'freq_tx:', s.freq_tx for tl in s.freq_tl: tl.pred = 'OVERLAP' tx3s_copy = s.timex3s[:] for tx in s.freq_tx: tx3s_copy.remove(tx) a, c, u, i = get_conflict_info2(s.events + tx3s_copy, s.candids_within, lambda x: x.span[0].begin, lambda x: x.pred) """ a, c, u, i = get_conflict_info2( s.events + s.timex3s, s.candids_within, lambda x: x.span[0].begin, lambda x: x.pred ) all_a.extend(a) all_c.extend(c) all_u.extend(u) all_i.extend(i) print "document:", doc.ds_id, "\n" for s in doc.sents: # print s.ds_id # m, um, uw, mi = 
verify(s.candids_within + s.freq_tl, s.tlinks_within) m, um, uw, mi = verify(s.candids_within, s.tlinks_within) all_m.extend(m) all_um.extend(um) all_uw.extend(uw) all_mi.extend(mi) """ resolve_conflict_within2(s) print 'after resolution:', doc.ds_id, '\n' rslt = verify(s.candids_within, s.tlinks_within) print s.ds_id, rslt, '\n' pair1 = [x + y for (x, y) in zip(pair1, rslt[0])] label1 = [x + y for (x, y) in zip(label1, rslt[1])] """ all_rslt = [all_a, all_c, all_u, all_i, all_m, all_um, all_uw, all_mi, all_within, all_candid] try: with open(os.path.join(project_base, data_pre_2 + "_within_result_dmp"), "w") as fp: cPickle.dump(all_rslt, fp) except: print "something wrong when dumping"
# NOTE(review): the original text here was the tail of a commented-out
# experiment whose opening triple-quote was lost when the file was mangled
# (the stray closing ``\"\"\"`` made the file unparsable), followed by analysis
# code whose enclosing definition is also missing.  The names it references
# (change, mic, umic, uwic, all_m, all_um, all_uw, all_candid, all_within,
# get_kth_large_key, verify) are defined in the lost text, so the fragment
# cannot run; it is preserved verbatim as comments:
#
#     t3.res = get_kth_large_key(t3.probs, 2)
#     change.append(t3)
#     """
#     for tl in change:
#         tl.pred = tl.res
#         if tl in all_m:
#             mic.append(tl)
#         if tl in all_um:
#             umic.append(tl)
#         if tl in all_uw:
#             uwic.append(tl)
#     print len(change), len(mic), len(umic), len(uwic)
#     print len(set(change)), len(set(mic)), len(set(umic)), len(set(uwic))
#     m, um, uw, mi = verify(all_candid, all_within)
#     print len(m), len(um), len(uw), len(mi)


def get_uniq_tl_from_tri(tris, tls=None):
    """Return the unique tlinks contained in *tris*.

    *tris* is an iterable of tlink triples (any iterables of tlinks); they
    are flattened and de-duplicated via a set, so the result order is
    arbitrary.  If *tls* is given and non-empty, only tlinks that also occur
    in *tls* are kept.
    """
    uniq = set()
    for tri in tris:
        uniq.update(tri)
    if not tls:
        # No filter (None or empty): return every unique tlink.
        return list(uniq)
    # Deliberately keep list membership (``in tls``) rather than building a
    # set: it matches on ``==`` only, preserving the equality semantics of
    # the project's tlink objects even if their __hash__ is identity-based.
    return [tl for tl in uniq if tl in tls]