import logging
import os
import pickle

import numpy as np
import pandas as pd
import spacy
import streamlit as st
import tensorflow as tf
from tqdm import tqdm

# RSTData, AttnSegModel, and parse_args are project-local (SegEDU) and are
# assumed to be importable from the surrounding package.


def train(args):
    logger = logging.getLogger('SegEDU')
    logger.info('Loading data...')
    if args.train_files:
        train_files = args.train_files
    else:
        preprocessed_train_dir = os.path.join(args.rst_dir, 'preprocessed/train/')
        train_files = [os.path.join(preprocessed_train_dir, filename)
                       for filename in sorted(os.listdir(preprocessed_train_dir))
                       if filename.endswith('.preprocessed')]
    if args.dev_files:
        dev_files = args.dev_files
    else:
        preprocessed_dev_dir = os.path.join(args.rst_dir, 'preprocessed/dev/')
        dev_files = [os.path.join(preprocessed_dev_dir, filename)
                     for filename in sorted(os.listdir(preprocessed_dev_dir))
                     if filename.endswith('.preprocessed')]
    if args.test_files:
        test_files = args.test_files
    else:
        preprocessed_test_dir = os.path.join(args.rst_dir, 'preprocessed/test/')
        test_files = [os.path.join(preprocessed_test_dir, filename)
                      for filename in sorted(os.listdir(preprocessed_test_dir))
                      if filename.endswith('.preprocessed')]

    rst_data = RSTData(train_files=train_files, dev_files=dev_files, test_files=test_files)
    logger.info('Loading vocab...')
    with open(args.word_vocab_path, 'rb') as fin:
        word_vocab = pickle.load(fin)
    logger.info('Word vocab size: {}'.format(word_vocab.size()))
    rst_data.word_vocab = word_vocab

    logger.info('Initializing the model...')
    model = AttnSegModel(args, word_vocab)
    logger.info('Training the model...')
    model.train(rst_data, args.epochs, args.batch_size, print_every_n_batch=20)
    logger.info('Done with model training')

def evaluate(args):
    logger = logging.getLogger('SegEDU')
    logger.info('Loading data...')
    if args.test_files:
        test_files = args.test_files
    else:
        preprocessed_test_dir = os.path.join(args.rst_dir, 'preprocessed/test/')
        test_files = [os.path.join(preprocessed_test_dir, filename)
                      for filename in os.listdir(preprocessed_test_dir)
                      if filename.endswith('.preprocessed')]
    rst_data = RSTData(test_files=test_files)
    logger.info('Loading vocab...')
    with open(args.word_vocab_path, 'rb') as fin:
        word_vocab = pickle.load(fin)
    logger.info('Word vocab size: {}'.format(word_vocab.size()))
    rst_data.word_vocab = word_vocab

    logger.info('Loading the model...')
    model = AttnSegModel(args, word_vocab)
    model.restore('best', args.model_dir)
    eval_batches = rst_data.gen_mini_batches(args.batch_size, test=True, shuffle=False)
    perf = model.evaluate(eval_batches, print_result=False)
    logger.info(perf)

def segment(args):
    """Segment raw text files into EDUs, writing one EDU per output line."""
    logger = logging.getLogger('SegEDU')
    rst_data = RSTData()
    logger.info('Loading vocab...')
    with open(args.word_vocab_path, 'rb') as fin:
        word_vocab = pickle.load(fin)
    logger.info('Word vocab size: {}'.format(word_vocab.size()))
    rst_data.word_vocab = word_vocab

    logger.info('Loading the model...')
    model = AttnSegModel(args, word_vocab)
    model.restore('best', args.model_dir)
    if model.use_ema:
        # Swap in the exponential-moving-average weights for inference.
        model.sess.run(model.ema_backup_op)
        model.sess.run(model.ema_assign_op)

    spacy_nlp = spacy.load('en', disable=['parser', 'ner', 'textcat'])
    for file in args.input_files:
        logger.info('Segmenting {}...'.format(file))
        raw_sents = []
        with open(file, 'r') as fin:
            for line in fin:
                line = line.strip()
                if line:
                    raw_sents.append(line)
        samples = []
        for sent in spacy_nlp.pipe(raw_sents, batch_size=1000, n_threads=5):
            samples.append({'words': [token.text for token in sent],
                            'edu_seg_indices': []})
        rst_data.test_samples = samples
        data_batches = rst_data.gen_mini_batches(args.batch_size, test=True, shuffle=False)

        edus = []
        for batch in data_batches:
            batch_pred_segs = model.segment(batch)
            for sample, pred_segs in zip(batch['raw_data'], batch_pred_segs):
                # Each predicted index marks the first word of a new EDU.
                one_edu_words = []
                for word_idx, word in enumerate(sample['words']):
                    if word_idx in pred_segs:
                        edus.append(' '.join(one_edu_words))
                        one_edu_words = []
                    one_edu_words.append(word)
                if one_edu_words:
                    edus.append(' '.join(one_edu_words))

        if not os.path.exists(args.result_dir):
            os.makedirs(args.result_dir)
        save_path = os.path.join(args.result_dir, os.path.basename(file))
        logger.info('Saving into {}'.format(save_path))
        with open(save_path, 'w') as fout:
            for edu in edus:
                fout.write(edu + '\n')

def load_model(args):
    """Load the trained segmentation model and word vocab for inference."""
    logger = logging.getLogger('SegEDU')
    rst_data = RSTData()
    logger.info('Loading vocab...')
    with open(args.word_vocab_path, 'rb') as fin:
        word_vocab = pickle.load(fin)
    logger.info('Word vocab size: {}'.format(word_vocab.size()))
    rst_data.word_vocab = word_vocab

    logger.info('Loading the model...')
    model = AttnSegModel(args, word_vocab)
    model.restore('best', args.model_dir)
    if model.use_ema:
        # Swap in the exponential-moving-average weights for inference.
        model.sess.run(model.ema_backup_op)
        model.sess.run(model.ema_assign_op)
    return model, rst_data, logger

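# A minimal usage sketch (not part of the original source): drive load_model()
# to segment a single raw string, mirroring the batching and boundary-grouping
# logic of segment() above. parse_args() is assumed to supply word_vocab_path,
# model_dir, and batch_size; the helper name below is hypothetical.
def _example_segment_one(text):
    example_args = parse_args()
    model, rst_data, logger = load_model(example_args)
    nlp = spacy.load('en', disable=['parser', 'ner', 'textcat'])
    doc = nlp(text)
    rst_data.test_samples = [{'words': [t.text for t in doc], 'edu_seg_indices': []}]
    batches = rst_data.gen_mini_batches(example_args.batch_size, test=True, shuffle=False)
    edus = []
    for batch in batches:
        for sample, pred_segs in zip(batch['raw_data'], model.segment(batch)):
            # Slice the word sequence at each predicted boundary index.
            words, start = sample['words'], 0
            for end in list(pred_segs) + [len(words)]:
                if words[start:end]:
                    edus.append(' '.join(words[start:end]))
                start = end
    return edus
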
def segment_data(dfs, col_names):
    """Segment the given dataframes into EDUs, attach them as an 'edus'
    column, and return the merged result."""
    args = parse_args()
    np.random.seed(args.seed)
    tf.set_random_seed(args.seed)

    # Logging
    logger = logging.getLogger('SegEDU')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # Loading
    rst_data = RSTData()
    logger.info('Loading vocab...')
    with open(args.word_vocab_path, 'rb') as fin:
        word_vocab = pickle.load(fin)
    logger.info('Word vocab size: {}'.format(word_vocab.size()))
    rst_data.word_vocab = word_vocab
    logger.info('Loading the model...')
    model = AttnSegModel(args, word_vocab)
    model.restore('best', args.model_dir)
    if model.use_ema:
        model.sess.run(model.ema_backup_op)
        model.sess.run(model.ema_assign_op)

    spacy_nlp = spacy.load('en', disable=['parser', 'ner', 'textcat'])
    for df, col_name in zip(dfs, col_names):
        edu_results = {}
        for idx, row in tqdm(df.iterrows(), total=len(df.index)):
            try:
                raw_sents = [row[col_name]]
                samples = []
                for sent in spacy_nlp.pipe(raw_sents, batch_size=1000, n_threads=5):
                    samples.append({'words': [token.text for token in sent],
                                    # Keep trailing whitespace so EDUs can be
                                    # rejoined without spacing artifacts.
                                    'words_ws': [token.text_with_ws for token in sent],
                                    'edu_seg_indices': []})
                rst_data.test_samples = samples
                data_batches = rst_data.gen_mini_batches(args.batch_size, test=True, shuffle=False)
                edus = []
                for batch in data_batches:
                    batch_pred_segs = model.segment(batch)
                    for sample, pred_segs in zip(batch['raw_data'], batch_pred_segs):
                        # Each predicted index marks the first word of a new EDU.
                        one_edu_words = []
                        for word_idx, word in enumerate(sample['words_ws']):
                            if word_idx in pred_segs:
                                edus.append(''.join(one_edu_words))
                                one_edu_words = []
                            one_edu_words.append(word)
                        if one_edu_words:
                            edus.append(''.join(one_edu_words))
                edu_results[idx] = edus
            except Exception:
                logger.exception('Crashed while segmenting {}.'.format(idx))
                edu_results[idx] = []
        df['edus'] = pd.Series(edu_results)

    # Drop rows whose text could not be segmented into any EDU.
    merged = pd.concat(dfs).reset_index(drop=True)
    merged = merged[merged['edus'].map(len) > 0]
    return merged

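# A small usage sketch (hypothetical data, not from the original source):
# segment the 'text' column of two toy DataFrames with segment_data() and get
# back one merged frame carrying the new 'edus' column. Frame contents and the
# helper name are invented for illustration.
def _example_segment_data():
    df_a = pd.DataFrame({'text': ['Although it was raining, we went out.']})
    df_b = pd.DataFrame({'text': ['The plan failed because funding ran out.']})
    merged = segment_data([df_a, df_b], ['text', 'text'])
    return merged
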
def segment(args):
    """Streamlit variant of segment(): segment raw text files into EDUs while
    displaying intermediate results in the page. Note that this redefinition
    shadows the segment() defined above."""
    logger = logging.getLogger('SegEDU')
    rst_data = RSTData()
    logger.info('Loading vocab...')
    with open(args.word_vocab_path, 'rb') as fin:
        word_vocab = pickle.load(fin)
    logger.info('Word vocab size: {}'.format(word_vocab.size()))
    rst_data.word_vocab = word_vocab
    logger.info('Loading the model...')
    model = AttnSegModel(args, word_vocab)
    model.restore('best', args.model_dir)
    if model.use_ema:
        model.sess.run(model.ema_backup_op)
        model.sess.run(model.ema_assign_op)

    spacy_nlp = spacy.load('en', disable=['parser', 'ner', 'textcat'])
    # Strip commas before tokenization by rebuilding each doc without them.
    spacy_nlp.add_pipe(
        lambda doc: spacy_nlp.make_doc(' '.join(token.text for token in doc
                                                if token.text != ',')),
        first=True)

    for f in args.input_files:
        logger.info('Segmenting {}...'.format(f))
        raw_sents = []
        with open(f, 'r') as fin:
            for line in fin:
                line = line.strip()
                if line:
                    raw_sents.append(line)

        samples = []
        sent_display = ''
        display_slot = st.empty()
        for sent in spacy_nlp.pipe(raw_sents, batch_size=1000, n_threads=5):
            samples.append({'words': [token.text for token in sent],
                            'edu_seg_indices': []})
            sent_display += str(sent) + '\n'
            display_slot.text(sent_display)

        rst_data.test_samples = samples
        data_batches = rst_data.gen_mini_batches(args.batch_size, test=True, shuffle=False)

        edus = []
        seg_display = ''
        for batch in data_batches:
            batch_pred_segs = model.segment(batch)
            for sample, pred_segs in zip(batch['raw_data'], batch_pred_segs):
                # Render the predicted spans as "[words]i" for display.
                words = sample['words']
                start_idx = 0
                for i in range(len(pred_segs) + 1):
                    end_idx = len(words) if i == len(pred_segs) else pred_segs[i]
                    seg_display += '[' + str(words[start_idx:end_idx]) + ']' + str(i) + ' '
                    start_idx = end_idx
                seg_display += '\n'
                # Each predicted index marks the first word of a new EDU.
                one_edu_words = []
                for word_idx, word in enumerate(words):
                    if word_idx in pred_segs:
                        edus.append(' '.join(one_edu_words))
                        one_edu_words = []
                    one_edu_words.append(word)
                if one_edu_words:
                    edus.append(' '.join(one_edu_words))
        display_slot.text(seg_display)

        if not os.path.exists(args.result_dir):
            os.makedirs(args.result_dir)
        save_path = os.path.join(args.result_dir, os.path.basename(f))
        logger.info('Saving into {}'.format(save_path))
        with open(save_path, 'w') as fout:
            for edu in edus:
                fout.write(edu + '\n')

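# Hypothetical entry point (the original wiring lives elsewhere in the
# project): parse the command-line args and run the Streamlit-enabled
# segment() above. Assumes parse_args() provides input_files, result_dir,
# batch_size, word_vocab_path, and model_dir.
if __name__ == '__main__':
    segment(parse_args())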