def main():
    """Print statistics about predicted turn-internal sentence breaks.

    Loads the predicted and gold trees named by the module-level paths
    ``pred_tree_file``, ``gold_tree_file`` and ``id_file``, then reports the
    pauses found at predicted breaks and how often predicted breaks coincide
    with gold edit positions.
    """
    predicted, ids = trees.load_trees_with_idx(pred_tree_file, id_file)
    gold, ids = trees.load_trees_with_idx(gold_tree_file, id_file)

    pauses_at_breaks = []
    n_medial_breaks = 0
    n_breaks_pre_edit = 0
    n_breaks_post_edit = 0
    # Number of positions between words where a sentence break could go:
    # sum(len(turn) - 1) over all turns.
    n_medial_positions = 0
    n_pred_breaks = 0  # tallied, but not part of the printed report
    n_gold_edits = 0

    for position, pred_tree in enumerate(predicted):
        n_medial_positions += get_wd_len(pred_tree) - 1
        breaks = get_sent_break_idx(pred_tree)
        edits = get_edit_idxs(gold[position])

        if edits:
            n_gold_edits += len(edits[0])

        # Nothing else to count for a turn without predicted breaks.
        if not breaks:
            continue

        # The last break index is dropped before counting
        # (presumably the turn-final boundary -- confirm in get_sent_break_idx).
        n_pred_breaks += len(breaks) - 1
        # Q: what pauses happen at sent breaks?
        pauses_at_breaks.extend(
            get_break_pauses(breaks[:-1], turn2pause[ids[position]]))
        n_medial_breaks += len(breaks[:-1])

        if edits:
            before_edit, after_edit = edits[0], edits[1]
            n_breaks_pre_edit += intersection_size(before_edit, breaks)
            n_breaks_post_edit += intersection_size(after_edit, breaks)

    pause_counts = count_pauses(pauses_at_breaks)

    print(pred_tree_file.split('/')[-1])
    print('Pauses at turn-internal sentence breaks:')
    print(pause_counts)
    print(f'Turn-medial breaks: {n_medial_breaks}')
    print(f'Total edits: {n_gold_edits}')
    print(f'Predicted breaks pre-edits: {n_breaks_pre_edit}')
    print(f'Predicted breaks post-edit: {n_breaks_post_edit}')
    print(f'Total turn medial positions: {n_medial_positions}')
def get_stats(path_to_trees, path_to_sent_ids):
    """Collect token counts, speakers and scenarios for the known wav files.

    Iterates over the module-level ``wav_files`` (pairs of sentence id and
    file name) and keeps the entries whose file ends in ``.wav`` and whose
    sentence id appears in the id file.

    Args:
        path_to_trees: Path of the tree file passed to
            ``T.load_trees_with_idx``.
        path_to_sent_ids: Path of the sentence-id file aligned with the trees.

    Returns:
        A tuple ``(tokens, speakers, scenarios)``: a list with the number of
        word-bearing leaves per matched tree, the set of speakers looked up in
        ``sentence_id2speaker``, and the set of scenario codes.

    Raises:
        ValueError: If a matched file name contains neither "1.3" nor "2.3".
    """
    tokens = []
    speakers = set()
    scenarios = set()
    trees, sent_ids = T.load_trees_with_idx(path_to_trees,
                                            path_to_sent_ids,
                                            strip_top=False)

    # Map each id to its FIRST position (same result as list.index) so the
    # lookup inside the loop is O(1) instead of a list scan per wav file.
    first_index = {}
    for position, known_id in enumerate(sent_ids):
        first_index.setdefault(known_id, position)

    for sentence_id, file in wav_files:
        if not file.endswith('.wav') or sentence_id not in first_index:
            continue

        speakers.add(sentence_id2speaker[sentence_id])
        if "1.3" in file:
            scenarios.add("d")
        elif "2.3" in file:
            # 5th character = scenario
            scenarios.add(os.path.splitext(os.path.basename(file))[0][4])
        else:
            raise ValueError(
                "cannot determine scenario from file name: {}".format(file))

        tree = trees[first_index[sentence_id]]
        # Leaves without a ``word`` attribute are skipped, not counted.
        tokens.append(
            sum(1 for leaf in tree.leaves() if hasattr(leaf, "word")))

    return tokens, speakers, scenarios
id_file = os.path.join(turndir,'turn_dev_sent_ids_medium.txt') turn2part = pickle.load(open(os.path.join(turndir,'turn_dev_partition.pickle'),'rb')) turn2pitch = pickle.load(open(os.path.join(turndir,'turn_dev_pitch.pickle'),'rb')) turn2fbank = pickle.load(open(os.path.join(turndir,'turn_dev_fbank.pickle'),'rb')) turn2pause = pickle.load(open(os.path.join(turndir,'turn_dev_pause.pickle'),'rb')) turn2dur = pickle.load(open(os.path.join(turndir,'turn_dev_duration.pickle'),'rb')) turn_ids = [l.strip() for l in open(os.path.join(turndir,'turn_dev_sent_ids_medium.txt'),'r').readlines()] turn_trees = [l.strip() for l in open(os.path.join(turndir,'turn_dev_medium.trees'),'r').readlines()] sent2turn = pickle.load(open(os.path.join(datadir,'sent2turn.pickle'),'rb')) turn2sent = pickle.load(open(os.path.join(datadir,'turn2sent.pickle'),'rb')) treestrings = [l.strip() for l in open(pred_tree_file).readlines()] tree_list,ids = trees.load_trees_with_idx(pred_tree_file,id_file) turn2tree = dict(zip(ids,tree_list)) turn2treestring = dict(zip(ids,treestrings)) def get_wd_len(constituent): leaves = 0 for leaf in constituent.leaves(): leaves += 1 return leaves def get_sent_break_idx(tree): if len(tree.children) == 1: return False else:
with open(time_file) as f: lines = f.readlines() for i,line in enumerate(lines): conv,spk,sent,_,_,_ = line.split('\t') sent = sent.split('~')[-1] id_num = '_'.join([conv,spk,sent]) sent_ids.append(id_num) """ id_file = os.path.join(data_dir, f'{split}_sent_ids.txt') with open(id_file, 'r') as f: sent_ids = [l.strip() for l in f.readlines()] sent_ids = sorted(sent_ids) tree_dict = {} loaded_trees, tree_ids = trees.load_trees_with_idx(tree_file, id_file) for tree_id, tree in zip(tree_ids, loaded_trees): tree_dict[tree_id] = [(leaf.tag, leaf.word) for leaf in tree.leaves()] for sent in sent_ids: print(sent) if sent in dur_dict: dur = dur_dict[sent] part = part_dict[sent] pitch = pitch_dict[sent] fbank = fbank_dict[sent] pause = pause_dict[sent] tree = tree_dict[sent] tree_wds = len(tree) dur_wds = dur.shape[-1] if not tree_wds == dur_wds:
def get_turn_stats():
    """Write per-split turn statistics to ``turn_stats.txt``.

    For each of the train/dev/test splits the turn-level trees are loaded and,
    for every turn that appears both in ``turn2sent`` and in the split's id
    file, the number of tokens, the number of SUs and the number of "EDITED"
    occurrences in the linearized tree are collected. Means, the SU
    distribution and the share of disfluent turns are written to the output
    file. Input and output locations are hard-coded below.
    """
    in_dir = "/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/data/input_features/turn_pause_dur_fixed"
    out_dir = "/afs/inf.ed.ac.uk/user/s20/s2096077/prosody_nlp/data/input_features"

    # load turn2sent : get #SUs per turn
    # NOTE(review): pickle.load executes arbitrary code; only use trusted files.
    with open(
            os.path.join(
                "/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/data/input_features",
                "turn2sent.pickle"), "rb") as f:
        turn2sent = pickle.load(f)

    with open(os.path.join(out_dir, "turn_stats.txt"), "w") as f:
        for split in ["train", "dev", "test"]:
            path_to_trees = os.path.join(in_dir,
                                         "turn_{}_medium.trees".format(split))
            path_to_sent_ids = os.path.join(
                in_dir, "turn_{}_sent_ids_medium.txt".format(split))

            # get trees and ids of turns
            trees, sent_ids = T.load_trees_with_idx(path_to_trees,
                                                    path_to_sent_ids,
                                                    strip_top=True)

            # First position of every id (what list.index would return);
            # avoids a linear scan of sent_ids for each turn in turn2sent.
            first_index = {}
            for position, known_id in enumerate(sent_ids):
                first_index.setdefault(known_id, position)

            tokens = []
            sus = []
            disfluencies = []
            disfluent_turns = 0

            for turn_id in turn2sent:
                if turn_id not in first_index:
                    continue
                sus.append(len(turn2sent[turn_id]))
                tree = trees[first_index[turn_id]]

                disfluencies_in_tree = tree.linearize().count("EDITED")
                # Leaves without a ``word`` attribute are not counted.
                tokens.append(
                    sum(1 for leaf in tree.leaves() if hasattr(leaf, "word")))

                if disfluencies_in_tree != 0:
                    disfluent_turns += 1
                disfluencies.append(disfluencies_in_tree)

            # The percentage is relative to all ids in the split's id file,
            # as before; an empty id file yields nan instead of crashing.
            if sent_ids:
                disfluent_share = disfluent_turns / len(sent_ids)
            else:
                disfluent_share = float("nan")

            f.write("Stats for: " + split + "\n")
            f.write("Mean number of tokens per turn: " +
                    str(np.mean(np.array(tokens))) + "\n")
            f.write("Mean number of SUs per turn: " +
                    str(np.mean(np.array(sus))) + "\n")
            f.write("SUs per turn: " + str(Counter(sus)) + "\n")
            f.write("Mean number of disfluencies per turn: " +
                    str(np.mean(np.array(disfluencies))) + "\n")
            f.write("#Turns with disfluencies: " + str(disfluent_turns) +
                    "\n")
            f.write("Percentage of disfluent turns: " +
                    str(disfluent_share) + "\n")
            f.write("=================================" + "\n")
ids = load_lines(os.path.join(data, 'turn_dev.ids')) turn2sent = load_pickle(os.path.join(data, '..', '..', 'turn2sent.pickle')) ### STEP 1 ################################################# # correctly segmented turns -> add sent IDS to list # incorrectly segmented turns # -> combine sent ids and add to list # -> split sent ids and add to list ############################################################ tree_file = os.path.join(data, '..', '..', 'sentence_pause_dur_fixed', 'dev.trees') id_file = os.path.join(data, '..', '..', 'sentence_pause_dur_fixed', 'dev_sent_ids.txt') treebank, all_sents = trees.load_trees_with_idx( tree_file, id_file) # EKN trees get loaded in as trees here id2senttree = dict(zip(all_sents, treebank)) wrong_turn_ids = [] usable_sent_ids = [] for idnum, gold, pred in zip(ids, golds, preds): if gold == pred: usable_sent_ids.extend(turn2sent[idnum]) else: g_list = np.array([int(i) for i in gold.split()]) p_list = np.array([int(i) for i in pred.split()]) if np.sum(g_list) > np.sum(p_list): print(f'{idnum}: UNDERSEG') sents = turn2sent[idnum] idx = 0 curr_sent = ''
def get_turn_stats(lang):
    """Write per-split turn statistics for the Verbmobil turn data.

    For every turn present both in ``turn2sent`` and in the split's id file,
    the number of tokens and SUs is collected, and the turn's ``.par``
    annotation file is searched for false-start (``-/ ... /-``) and repetition
    (``+/ ... /+``) markup in its TR2 tier. Results are written to
    ``stats.txt`` in the language-specific output directory.

    Args:
        lang: "eng" or "ger"; selects the input-feature and output directories.

    Raises:
        ValueError: If ``lang`` is neither "eng" nor "ger".
        FileNotFoundError: If no ``.par`` file exists for a turn in either
            Verbmobil annotation directory.
    """
    path_to_vm_annotations_1 = "/group/corporapublic/verbmobil/1.3"
    path_to_vm_annotations_2 = "/group/corpora/large4/verbmobil/2.3"
    out_dir = "/afs/inf.ed.ac.uk/group/msc-projects/s2096077/vm_{}_turns/".format(
        lang)
    if lang == "eng":
        path_to_input_features = "/afs/inf.ed.ac.uk/user/s20/s2096077/prosody_nlp/data/vm/input_features"
    elif lang == "ger":
        path_to_input_features = "/afs/inf.ed.ac.uk/user/s20/s2096077/prosody_nlp/data/vm/ger/input_features"
    else:
        raise ValueError("unsupported language: {!r}".format(lang))

    # Raw strings: "\s", "\d" and "\+" are regex escapes, not string escapes.
    tr2_line_re = re.compile(r'TR2:\s\d+\s.*')
    false_start_re = re.compile(r"-/(.*?)/-")
    repetition_re = re.compile(r"\+/(.*?)/\+")

    # load turn2sent : get #SUs per turn
    # NOTE(review): pickle.load executes arbitrary code; only use trusted files.
    with open(os.path.join(path_to_input_features, "turn2sent.pickle"),
              "rb") as f:
        turn2sent = pickle.load(f)

    with open(os.path.join(out_dir, "stats.txt"), "w") as f:
        for split in ["train", "dev", "test"]:
            path_to_trees = "/afs/inf.ed.ac.uk/group/msc-projects/s2096077/vm_{}_turns/turn_{}.trees".format(
                lang, split)
            path_to_sent_ids = "/afs/inf.ed.ac.uk/group/msc-projects/s2096077/vm_{}_turns/turn_{}_sent_ids.txt".format(
                lang, split)

            # get trees and ids of turns
            trees, sent_ids = T.load_trees_with_idx(path_to_trees,
                                                    path_to_sent_ids,
                                                    strip_top=False)

            # First position of every id (what list.index would return);
            # avoids a linear scan of sent_ids for each turn in turn2sent.
            first_index = {}
            for position, known_id in enumerate(sent_ids):
                first_index.setdefault(known_id, position)

            tokens = []
            sus = []
            disfluencies = []
            disfluent_turns = 0

            for turn_id in turn2sent:
                if turn_id not in first_index:
                    continue
                sus.append(len(turn2sent[turn_id]))
                if len(turn2sent[turn_id]) == 18:
                    print(turn_id)

                tree = trees[first_index[turn_id]]
                # Leaves without a ``word`` attribute are not counted.
                tokens.append(
                    sum(1 for leaf in tree.leaves() if hasattr(leaf, "word")))

                # do something with par files to get info about disfluencies
                # Ids with more than three "_"-separated parts lose their last
                # four characters before the .par lookup.
                par_id = turn_id
                if len(par_id.split("_")) > 3:
                    par_id = par_id[:-4]

                path_to_par = None
                for annotation_root in (path_to_vm_annotations_1,
                                        path_to_vm_annotations_2):
                    candidate = os.path.join(annotation_root, par_id[0:5],
                                             par_id + ".par")
                    if os.path.isfile(candidate):
                        path_to_par = candidate
                        break
                if path_to_par is None:
                    print(par_id)
                    raise FileNotFoundError(
                        "no .par file found for turn {}".format(par_id))

                with open(path_to_par, "r", encoding="ISO-8859-1") as parfile:
                    partext = parfile.read()

                # Third tab-separated field of each TR2 line, space-joined
                # (with a trailing space, as before).
                full_transcription = "".join(
                    line.split("\t")[2] + " "
                    for line in tr2_line_re.findall(partext))

                false_starts = false_start_re.findall(full_transcription)
                repetitions = repetition_re.findall(full_transcription)
                # filled_pauses = re.findall('<((?:uhm|uh|hm|hes|"ah|"ahm))>',
                #                            full_transcription)

                n_disfluencies = len(false_starts) + len(repetitions)
                if n_disfluencies:
                    disfluent_turns += 1
                disfluencies.append(n_disfluencies)

            # The percentage is relative to all ids in the split's id file,
            # as before; an empty id file yields nan instead of crashing.
            if sent_ids:
                disfluent_share = disfluent_turns / len(sent_ids)
            else:
                disfluent_share = float("nan")

            f.write("Stats for: " + split + "\n")
            f.write("Mean number of tokens per turn: " +
                    str(np.mean(np.array(tokens))) + "\n")
            f.write("Mean number of SUs per turn: " +
                    str(np.mean(np.array(sus))) + "\n")
            f.write("SUs per turn: " + str(Counter(sus)) + "\n")
            f.write("Mean number of disfluencies per turn: " +
                    str(np.mean(np.array(disfluencies))) + "\n")
            f.write("#Turns with disfluencies: " + str(disfluent_turns) +
                    "\n")
            f.write("Percentage of disfluent turns: " +
                    str(disfluent_share) + "\n")
            f.write("=================================" + "\n")
# Gold turn-level trees and the id file shared by all three tree sets.
gold_f = os.path.join(datadir, 'turn_dev_medium.trees')
turn_ids_file = os.path.join(datadir, 'turn_dev_sent_ids.txt')

# Read the id lists inside context managers so the file handles are closed
# promptly instead of being left to the garbage collector.
with open(turn_ids_file) as _id_fh:
    turn_ids = [l.strip() for l in _id_fh.readlines()]
with open(os.path.join(datadir,
                       'singlesent_turns_dev_sent_ids.txt')) as _id_fh:
    singlesent_turn_ids = set([l.strip() for l in _id_fh.readlines()])
with open(os.path.join(datadir,
                       'multisent_turns_dev_sent_ids.txt')) as _id_fh:
    multisent_turn_ids = set([l.strip() for l in _id_fh.readlines()])

# Prosody-model predictions, non-prosody predictions and gold trees are all
# loaded against the same id file.
pros_trees, pros_ids = trees.load_trees_with_idx(pros_f,
                                                 turn_ids_file,
                                                 strip_top=True)
nonpros_trees, nonpros_ids = trees.load_trees_with_idx(nonpros_f,
                                                       turn_ids_file,
                                                       strip_top=True)
gold_trees, gold_ids = trees.load_trees_with_idx(gold_f,
                                                 turn_ids_file,
                                                 strip_top=True)

# Per-system lookup from turn id to tree.
id2prostree = dict(zip(pros_ids, pros_trees))
id2nonprostree = dict(zip(nonpros_ids, nonpros_trees))
id2goldtree = dict(zip(gold_ids, gold_trees))

tree_dict = {'pros': id2prostree, 'nonpros': id2nonprostree}
oversplit = {'pros': [], 'nonpros': []}
import os import trees outdir = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/code/self_attn_speech_parser/output/turn_pause_dur_fixed' datadir = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/data/input_features/turn_pause_dur_fixed' gold_path = os.path.join(datadir, 'turn_dev.UTT_SEG.trees') id_path = os.path.join(datadir, 'turn_dev_sent_ids.UTT_SEG.txt') sp_pred_path = os.path.join( outdir, 'turn_seg_72240_dev=99.37.pt_turn_dev_predicted.txt') nonsp_pred_path = os.path.join( outdir, 'turn_nonsp_seg_72240_dev=74.83.pt_turn_dev_predicted.txt') gold_trees, gold_ids = trees.load_trees_with_idx(gold_path, id_path) nonsp_pred_trees, gold_ids = trees.load_trees_with_idx(nonsp_pred_path, id_path) sp_pred_trees, gold_ids = trees.load_trees_with_idx(sp_pred_path, id_path) def const_len(const): const_len = 0 for leaf in const.leaves(): const_len += 1 return const_len def calc_SER(gold_ids, gold_trees, pred_trees): correct_preds = 0 incorrect_preds = 0 total_gold_bounds = 0 for idnum, pred, gold in zip(gold_ids, gold_trees, pred_trees):
def run_viz(args):
    """Parse the visualization set and pickle per-layer attention data.

    Loads the trees and sentence ids from ``args.viz_path`` /
    ``args.viz_sent_id_path``, restores the parser from
    ``args.model_path_base``, wraps ``MultiHeadAttention.forward`` so that the
    attention weights of every call are captured, parses all sentences with at
    most 40 leaves, and dumps the ``viz_attention`` output for every
    (sentence, attention call) pair to ``args.viz_out``.

    Note: the ``MultiHeadAttention.forward`` replacement is installed on the
    class and is not undone when this function returns.

    Args:
        args: Parsed command-line arguments; the attributes read here are
            ``model_path_base``, ``viz_path``, ``viz_sent_id_path``,
            ``feature_path``, ``viz_prefix``, ``eval_batch_size`` and
            ``viz_out``.
    """
    assert args.model_path_base.endswith(
        ".pt"), "Only pytorch savefiles supported"

    print("Loading test trees from {}...".format(args.viz_path))
    viz_treebank, viz_sent_ids = trees.load_trees_with_idx(
        args.viz_path, args.viz_sent_id_path)
    print("Loaded {:,} test examples.".format(len(viz_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    info = torch_load(args.model_path_base)
    assert 'hparams' in info[
        'spec'], "Only self-attentive models are supported"
    parser = parse_model.SpeechParser.from_spec(info['spec'],
                                                info['state_dict'])

    viz_feat_dict = {}
    if info['spec']['speech_features'] is not None:
        speech_features = info['spec']['speech_features']
        print("Loading speech features for test set...")
        for feat_type in speech_features:
            print("\t", feat_type)
            feat_path = os.path.join(
                args.feature_path,
                args.viz_prefix + '_' + feat_type + '.pickle')
            with open(feat_path, 'rb') as f:
                feat_data = pickle.load(f, encoding='latin1')
            viz_feat_dict[feat_type] = feat_data

    from viz import viz_attention

    # ``stowed_values`` is rebound to a fresh dict for every sub-batch below;
    # the wrapper sees the current binding through the closure.
    stowed_values = {}
    orig_multihead_forward = parse_model.MultiHeadAttention.forward

    def wrapped_multihead_forward(self, inp, batch_idxs, **kwargs):
        res, attns = orig_multihead_forward(self, inp, batch_idxs, **kwargs)
        stowed_values['attns{}'.format(
            stowed_values["stack"])] = attns.cpu().data.numpy()
        stowed_values['stack'] += 1
        return res, attns

    parse_model.MultiHeadAttention.forward = wrapped_multihead_forward

    # Select the sentences we will actually be visualizing.
    # Trees and sentence ids are filtered together: dropping only the long
    # trees would shift every later tree against its id, and the features
    # loaded per id would then belong to a different sentence.
    max_len_viz = 40
    if max_len_viz > 0:
        kept_pairs = [(tree, sent_id)
                      for tree, sent_id in zip(viz_treebank, viz_sent_ids)
                      if len(list(tree.leaves())) <= max_len_viz]
        viz_treebank = [tree for tree, sent_id in kept_pairs]
        viz_sent_ids = [sent_id for tree, sent_id in kept_pairs]
    #viz_treebank = viz_treebank[:1]

    print("Parsing viz sentences...")
    viz_data = []
    for start_index in range(0, len(viz_treebank), args.eval_batch_size):
        stop_index = start_index + args.eval_batch_size
        subbatch_trees = viz_treebank[start_index:stop_index]
        subbatch_sent_ids = viz_sent_ids[start_index:stop_index]
        subbatch_sentences = [[(leaf.tag, leaf.word)
                               for leaf in tree.leaves()]
                              for tree in subbatch_trees]
        subbatch_features = load_features(subbatch_sent_ids, viz_feat_dict)

        stowed_values = dict(stack=0)
        predicted, _ = parser.parse_batch(subbatch_sentences,
                                          subbatch_sent_ids,
                                          subbatch_features)
        del _
        predicted = [p.convert() for p in predicted]
        stowed_values['predicted'] = predicted

        for snum, sentence in enumerate(subbatch_sentences):
            sentence_words = [tokens.START] + [x[1] for x in sentence
                                               ] + [tokens.STOP]
            n_words = len(sentence_words)
            for stacknum in range(stowed_values['stack']):
                attns_padded = stowed_values['attns{}'.format(stacknum)]
                # Rows for this sentence are taken with a stride of the
                # sub-batch size, then cropped to the sentence length.
                attns = attns_padded[snum::len(subbatch_sentences),
                                     :n_words, :n_words]
                dat = viz_attention(sentence_words, attns)
                viz_data.append(dat)

    # Context manager so the file is closed even if pickling fails.
    with open(args.viz_out, 'wb') as outf:
        pickle.dump(viz_data, outf)
def run_test(args):
    """Parse the test set with a saved model, write predictions and score them.

    Two input modes exist, selected by ``args.test_lbls``:

    * labels mode (``args.test_lbls`` set): the test set is a list of
      ``(words, labels)`` pairs read from plain text files and the result is
      scored with ``evaluate.seg_fscore``;
    * tree mode: trees are loaded with ``trees.load_trees_with_idx`` and the
      result is scored with ``evaluate.evalb``.

    Predictions are written to ``args.output_path``; optional score and gold
    files are written when ``args.get_scores`` / ``args.write_gold`` are set.
    """
    print("Loading test trees from {}...".format(args.test_path))
    if args.test_lbls:
        # Labels mode: one whitespace-tokenized line per turn in each file.
        test_txt = [
            l.strip().split() for l in open(args.test_path, 'r').readlines()
        ]
        test_lbls = [
            l.strip().split() for l in open(args.test_lbls, 'r').readlines()
        ]
        test_sent_ids = [
            l.strip() for l in open(args.test_sent_id_path, 'r').readlines()
        ]
        test_treebank = [(txt, lbl) for txt, lbl in zip(test_txt, test_lbls)]
    else:
        test_treebank, test_sent_ids = trees.load_trees_with_idx(args.test_path, \
                args.test_sent_id_path, strip_top=False)

    if not args.new_set:
        test_pause_path = os.path.join(args.feature_path, args.test_prefix + \
                '_pause.pickle')
        # NOTE(review): test_pause_data is only referenced by the
        # commented-out filtering below.
        with open(test_pause_path, 'rb') as f:
            test_pause_data = pickle.load(f, encoding='latin1')

        # to_remove = set(test_sent_ids).difference(set(test_pause_data.keys()))
        # to_remove = sorted([test_sent_ids.index(i) for i in to_remove])
        # for x in to_remove[::-1]:
        #     test_treebank.pop(x)
        #     test_sent_ids.pop(x)

    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(".pt"), "Only pytorch files supported"

    info = torch_load(args.model_path_base)
    print(info.keys())
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = parse_model.SpeechParser.from_spec(info['spec'], \
            info['state_dict'])

    # Tally trainable parameters per module.
    # NOTE(review): neither `table` nor `total_params` is printed or returned
    # anywhere in this function.
    from prettytable import PrettyTable
    total_params = 0
    table = PrettyTable(["Modules", "Parameters"])
    for name, parameter in parser.named_parameters():
        if not parameter.requires_grad:
            continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params += param

    parser.eval()  # turn off dropout at evaluation time
    label_vocab = parser.label_vocab
    #print("{} ({:,}): {}".format("label", label_vocab.size, \
    #        sorted(value for value in label_vocab.values)))

    # Speech features are loaded only for the feature types recorded in the
    # saved model spec.
    test_feat_dict = {}
    if info['spec']['speech_features'] is not None:
        speech_features = info['spec']['speech_features']
        print("Loading speech features for test set...")
        for feat_type in speech_features:
            print("\t", feat_type)
            feat_path = os.path.join(args.feature_path, \
                    args.test_prefix + '_' + feat_type + '.pickle')
            with open(feat_path, 'rb') as f:
                feat_data = pickle.load(f, encoding='latin1')
            test_feat_dict[feat_type] = feat_data

    print("Parsing test sentences...")
    start_time = time.time()

    test_predicted = []
    test_scores = []
    pscores = []
    gscores = []
    with torch.no_grad():
        for start_index in range(0, len(test_treebank), args.eval_batch_size):
            subbatch_treebank = test_treebank[start_index:start_index \
                    + args.eval_batch_size]
            subbatch_sent_ids = test_sent_ids[start_index:start_index \
                    + args.eval_batch_size]
            if args.test_lbls:  # EKN using this instead of the seg flag bc it's an hparam
                subbatch_txt = [turn[0] for turn in subbatch_treebank]
                subbatch_lbl = [turn[1] for turn in subbatch_treebank]
                subbatch_sentences = [[(lbl,txt) for lbl,txt in zip(sent_lbl,sent_txt)] for \
                                      sent_lbl,sent_txt in zip(subbatch_lbl,subbatch_txt)]
            else:
                subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in \
                        tree.leaves()] for tree in subbatch_treebank]
                # NOTE(review): subbatch_trees only exists in tree mode, but
                # the get_scores branch below uses it unconditionally.
                subbatch_trees = [t.convert() for t in subbatch_treebank]
            subbatch_features = load_features(subbatch_sent_ids, test_feat_dict\
                    , args.sp_off)
            predicted, scores = parser.parse_batch(subbatch_sentences, \
                        subbatch_sent_ids, subbatch_features)
            if not args.get_scores:
                del scores
            else:
                # Second pass with the gold trees to obtain the score charts,
                # then decode each chart without (p_score) and with (g_score)
                # the first positional flag of chart_helper.decode set.
                charts = parser.parse_batch(subbatch_sentences, \
                        subbatch_sent_ids, subbatch_features, subbatch_trees, True)
                for i in range(len(charts)):
                    decoder_args = dict(sentence_len=len(subbatch_sentences[i]),\
                            label_scores_chart=charts[i],\
                            gold=subbatch_trees[i],\
                            label_vocab=parser.label_vocab, \
                            is_train=False, \
                            backoff=True)
                    p_score, _, _, _, _ = chart_helper.decode(
                        False, **decoder_args)
                    g_score, _, _, _, _ = chart_helper.decode(
                        True, **decoder_args)
                    pscores.append(p_score)
                    gscores.append(g_score)
                test_scores += scores
            if args.test_lbls:
                test_predicted.extend(predicted)
            else:
                test_predicted.extend([p.convert() for p in predicted])

    # DEBUG
    # print(test_scores)
    #print(test_score_offsets)

    # One prediction per line: space-joined labels in labels mode, linearized
    # trees otherwise.
    with open(args.output_path, 'w') as output_file:
        for tree in test_predicted:
            if args.test_lbls:
                #import pdb;pdb.set_trace()
                lbls = ' '.join(tree)
                output_file.write("{}\n".format(lbls))
            else:
                output_file.write("{}\n".format(tree.linearize()))
    print("Output written to:", args.output_path)

    if args.get_scores:
        with open(args.output_path + '.scores', 'w') as output_file:
            for score1, score2, score3 in zip(test_scores, pscores, gscores):
                output_file.write("{}\t{}\t{}\n".format(
                    score1, score2, score3))
        print("Output scores written to:", args.output_path + '.scores')

    if args.write_gold:
        with open(args.test_prefix + '_sent_ids.txt', 'w') as sid_file:
            for sent_id in test_sent_ids:
                sid_file.write("{}\n".format(sent_id))
        print("Sent ids written to:", args.test_prefix + '_sent_ids.txt')

        # NOTE(review): calls tree.linearize(), so this presumably only works
        # in tree mode (labels mode stores tuples) -- confirm.
        with open(args.test_prefix + '_gold.txt', 'w') as gold_file:
            for tree in test_treebank:
                gold_file.write("{}\n".format(tree.linearize()))
        print("Gold trees written to:", args.test_prefix + '_gold.txt')

    # The tree loader does some preprocessing to the trees (e.g. stripping TOP
    # symbols or SPMRL morphological features). We compare with the input file
    # directly to be extra careful about not corrupting the evaluation. We also
    # allow specifying a separate "raw" file for the gold trees: the inputs to
    # our parser have traces removed and may have predicted tags substituted,
    # and we may wish to compare against the raw gold trees to make sure we
    # haven't made a mistake. As far as we can tell all of these variations give
    # equivalent results.
    # The initial value below is always overwritten by one of the two branches.
    ref_gold_path = args.test_path
    if args.test_path_raw is not None:
        print("Comparing with raw trees from", args.test_path_raw)
        ref_gold_path = args.test_path_raw
    else:
        # Need this since I'm evaluating on subset
        ref_gold_path = None

    if args.test_lbls:
        test_fscore = evaluate.seg_fscore(test_treebank,
                                          test_predicted,
                                          is_train=False)
    else:
        test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, \
                test_predicted, ref_gold_path=ref_gold_path, is_train=False)

    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
def run_train(args, hparams):
    """Train a SpeechParser and keep the checkpoint with the best dev f-score.

    Steps, in order: seed numpy/torch, load the training and development
    trees, build the tag/word/label/char/pause vocabularies, load the speech
    features named in ``args.speech_features``, create (or restore from
    ``args.load_path``) the parser and its optimizer, then train in batches
    of ``args.batch_size``. The dev set is evaluated roughly
    ``args.checks_per_epoch`` times per epoch by the nested ``check_dev``,
    which saves a new ``.pt`` file whenever the dev f-score improves.

    Training stops after ``args.epochs`` epochs (if set) or when the dev
    f-score has not improved for long enough (see the end of the epoch loop).
    """
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)
        # EKN Extra assurance of deterministic behavior on GPU
        #torch.manual_seed(args.numpy_seed)
        torch.backends.cudnn.deterministic = True
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.numpy_seed)
        sys.stdout.flush()

    # Make sure that pytorch is actually being initialized randomly.
    # On my cluster I was getting highly correlated results from multiple
    # runs, but calling reset_parameters() changed that. A brief look at the
    # pytorch source code revealed that pytorch initializes its RNG by
    # calling std::random_device, which according to the C++ spec is allowed
    # to be deterministic.
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)

    print("Loading training trees from {}...".format(args.train_path))
    #EKN taking out SEG path for now
    # With seg forced to False, every `if hparams.seg:` branch in this
    # function is currently dead.
    hparams.seg = False
    if hparams.seg:
        train_txt = [
            l.strip().split() for l in open(args.train_path, 'r').readlines()
        ]
        train_lbls = [
            l.strip().split() for l in open(args.train_lbls, 'r').readlines()
        ]
        train_sent_ids = [
            l.strip() for l in open(args.train_sent_id_path, 'r').readlines()
        ]
        train_parse = [(txt, lbl) for txt, lbl in zip(train_txt, train_lbls)]
    else:
        train_treebank, train_sent_ids = trees.load_trees_with_idx(args.train_path,\
                args.train_sent_id_path, strip_top=False)  # EKN trees get loaded in as trees here
        # Note strip_top=True for SWBD

    print("Processing pause features for training...")
    pause_path = os.path.join(args.feature_path,
                              args.prefix + 'train_pause.pickle')
    with open(pause_path, 'rb') as f:
        pause_data = pickle.load(f, encoding='latin1')

    print("Processing trees for training...")
    wsj_sents = set([x for x in train_sent_ids if 'wsj' in x])
    if len(wsj_sents) > 0:
        assert args.speech_features is None

    # NOTE(review): to_keep is only referenced by the commented-out filtering
    # below.
    to_keep = set(pause_data.keys())
    to_keep = to_keep.union(wsj_sents)
    if not hparams.seg:
        train_parse = [tree.convert() for tree in train_treebank]
    # comment this back in, when using speech features:
    # Removing sentences without speech info
    # to_remove = set(train_sent_ids).difference(to_keep)
    # to_remove = sorted([train_sent_ids.index(i) for i in to_remove])
    # for x in to_remove[::-1]:
    #     train_parse.pop(x)
    #     train_sent_ids.pop(x)

    train_set = list(
        zip(train_sent_ids,
            train_parse))  # EKN train_set is a list of tuples: (sent_id, tree)
    print("Loaded {:,} training examples.".format(len(train_set)))

    # Remove sentences without prosodic features in dev set
    print("Loading development trees from {}...".format(args.dev_path))
    if hparams.seg:
        dev_txt = [
            l.strip().split() for l in open(args.dev_path, 'r').readlines()
        ]
        dev_lbls = [
            l.strip().split() for l in open(args.dev_lbls, 'r').readlines()
        ]
        dev_sent_ids = [
            l.strip() for l in open(args.dev_sent_id_path, 'r').readlines()
        ]
        dev_treebank = [(txt, lbl) for txt, lbl in zip(dev_txt, dev_lbls)]
    else:
        dev_treebank, dev_sent_ids = trees.load_trees_with_idx(args.dev_path, \
                args.dev_sent_id_path, strip_top=False)  # Note strip_top=True for SWBD

    # NOTE(review): dev_pause_path is only referenced by the commented-out
    # code below.
    dev_pause_path = os.path.join(args.feature_path, args.prefix + \
            'dev_pause.pickle')
    # comment this back in, when using speech features:
    # with open(dev_pause_path, 'rb') as f:
    #     dev_pause_data = pickle.load(f, encoding='latin1')

    # to_remove = set(dev_sent_ids).difference(set(dev_pause_data.keys()))
    # to_remove = sorted([dev_sent_ids.index(i) for i in to_remove])
    # for x in to_remove[::-1]:
    #     dev_treebank.pop(x)
    #     dev_sent_ids.pop(x)

    #if hparams.max_len_dev > 0:
    #    dev_treebank = [tree for tree in dev_treebank if \
    #            len(list(tree.leaves())) <= hparams.max_len_dev]
    print("Loaded {:,} development examples.".format(len(dev_treebank)))

    print("Constructing vocabularies...")
    sys.stdout.flush()

    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(tokens.START)
    tag_vocab.index(tokens.STOP)
    tag_vocab.index(tokens.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(tokens.START)
    word_vocab.index(tokens.STOP)
    word_vocab.index(tokens.UNK)

    pause_vocab = vocabulary.Vocabulary()
    pause_vocab.index(tokens.START)
    pause_vocab.index(tokens.STOP)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())

    char_set = set()

    # Pause vocabulary: the string form of every 'pause_aft' value in the
    # training pause features.
    for v in pause_data.values():
        pauses = v['pause_aft']
        for p in pauses:
            pause_vocab.index(str(p))

    for tree in train_parse:
        if hparams.seg:
            wds = tree[0]
            lbls = tree[1]
            for lbl, wd in zip(lbls, wds):
                tag_vocab.index(lbl)
                word_vocab.index(wd)
                char_set |= set(wd)
        else:
            # Stack-based walk over the tree: internal nodes feed the label
            # vocabulary, leaves feed the tag/word/char vocabularies.
            nodes = [tree]
            while nodes:
                node = nodes.pop()
                if isinstance(node, trees.InternalParseNode):
                    label_vocab.index(node.label)
                    nodes.extend(reversed(node.children))
                else:
                    tag_vocab.index(node.tag)
                    word_vocab.index(node.word)
                    char_set |= set(node.word)

    char_vocab = vocabulary.Vocabulary()

    # If codepoints are small (e.g. Latin alphabet), index by codepoint directly
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512

        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()
    pause_vocab.freeze()

    def print_vocabulary(name, vocab):
        # Special tokens are listed first, then the remaining values, each
        # group sorted.
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)
        print_vocabulary("Pause", pause_vocab)

    feat_dict = {}
    speech_features = None
    if args.speech_features is not None:
        speech_features = args.speech_features.split(',')
        if 'pause' in speech_features:
            hparams.use_pause = True
        print("Loading speech features for training set...")
        for feat_type in speech_features:
            print("\t", feat_type)
            feat_path = os.path.join(args.feature_path, \
                    args.prefix + 'train_' + feat_type + '.pickle')
            with open(feat_path, 'rb') as f:
                feat_data = pickle.load(f, encoding='latin1')
            feat_dict[feat_type] = feat_data

    dev_feat_dict = {}
    if args.speech_features is not None:
        speech_features = args.speech_features.split(',')
        print("Loading speech features for dev set...")
        for feat_type in speech_features:
            print("\t", feat_type)
            feat_path = os.path.join(args.feature_path, \
                    args.prefix + 'dev_' + feat_type + '.pickle')
            with open(feat_path, 'rb') as f:
                feat_data = pickle.load(f, encoding='latin1')
            dev_feat_dict[feat_type] = feat_data

    print("Hyperparameters:")
    hparams.print()

    print("Initializing model...")
    sys.stdout.flush()

    load_path = args.load_path
    if load_path is not None:
        print("Loading parameters from {}".format(load_path))
        info = torch_load(load_path)
        parser = parse_model.SpeechParser.from_spec(info['spec'], \
                info['state_dict'])
    else:
        parser = parse_model.SpeechParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            pause_vocab,
            speech_features,
            hparams,
        )

    print("Initializing optimizer...")
    trainable_parameters = [param for param in parser.parameters() \
            if param.requires_grad]
    # SHORTEN PRINTING
    #print(parser)
    #for name, param in parser.named_parameters():
    #    print(name, param.data.shape, param.requires_grad)

    if args.optimizer == 'SGD':
        trainer = torch.optim.SGD(trainable_parameters, lr=0.05)
    else:
        trainer = torch.optim.Adam(trainable_parameters, \
                lr=hparams.learning_rate, \
                betas=(0.9, 0.98), eps=1e-9)

    if load_path is not None:
        trainer.load_state_dict(info['trainer'])

    def set_lr(new_lr):
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    # Linear warm-up: the learning rate grows by warmup_coeff per batch until
    # learning_rate_warmup_steps is reached (see schedule_lr).
    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer,
        'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    # A clip_grad_norm of 0 means "no clipping" (threshold = infinity).
    grad_clip_threshold = np.inf if hparams.clip_grad_norm == 0 \
            else hparams.clip_grad_norm

    print("Training...")
    total_processed = 0
    current_processed = 0
    # Number of training examples between two dev evaluations (may be a
    # non-integer float).
    check_every = len(train_parse) / args.checks_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    best_dev_processed = 0

    start_time = time.time()

    def check_dev():
        # Evaluate on the dev set; on improvement, delete the previous best
        # checkpoint and save the current model and optimizer state.
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path
        nonlocal best_dev_processed

        dev_start_time = time.time()

        dev_predicted = []
        eval_batch_size = args.eval_batch_size
        for dev_start_index in range(0, len(dev_treebank), eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index \
                    + eval_batch_size]
            subbatch_sent_ids = dev_sent_ids[dev_start_index:dev_start_index \
                    + eval_batch_size]
            if hparams.seg:
                subbatch_txt = [tree[0] for tree in subbatch_trees]
                subbatch_lbl = [tree[1] for tree in subbatch_trees]
                subbatch_sentences = [[(lbl,txt) for lbl,txt in zip(sent_lbl,sent_txt)] \
                                      for sent_lbl,sent_txt in zip(subbatch_lbl,subbatch_txt)]
            else:
                subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in \
                        tree.leaves()] for tree in subbatch_trees]
            subbatch_features = load_features(subbatch_sent_ids, dev_feat_dict)
            predicted, _ = parser.parse_batch(subbatch_sentences, \
                        subbatch_sent_ids, subbatch_features)
            del _
            if hparams.seg:
                dev_predicted.extend(predicted)
            else:
                dev_predicted.extend([p.convert() for p in predicted])
        if hparams.seg:
            dev_fscore = evaluate.seg_fscore(dev_treebank, dev_predicted)
        else:
            dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank,
                                        dev_predicted)
        """
        with open('tmp_preds.txt','w') as f:
            for pred in dev_predicted:
                f.write(pred.linearize())
                f.write('\n')
        with open('tmp_gold.txt','w') as f:
            for gold in dev_treebank:
                f.write(gold.linearize())
                f.write('\n')
        """
        print("dev-fscore {} "
              "dev-elapsed {} "
              "total-elapsed {}".format(
                  dev_fscore,
                  format_elapsed(dev_start_time),
                  format_elapsed(start_time),
              ))
        sys.stdout.flush()

        if dev_fscore.fscore > best_dev_fscore:
            if best_dev_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_dev_model_path + ext
                    if os.path.exists(path):
                        print(
                            "Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_fscore = dev_fscore.fscore
            best_dev_model_path = "{}_dev={:.2f}".format(
                args.model_path_base, dev_fscore.fscore)
            best_dev_processed = total_processed
            print("Saving new best model to {}...".format(best_dev_model_path))
            torch.save(
                {
                    'spec': parser.spec,
                    'state_dict': parser.state_dict(),
                    'trainer': trainer.state_dict(),
                }, best_dev_model_path + ".pt")
            sys.stdout.flush()

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        np.random.shuffle(train_set)
        epoch_start_time = time.time()

        for start_index in range(0, len(train_set), args.batch_size):
            trainer.zero_grad()
            schedule_lr(total_processed // args.batch_size)

            batch_loss_value = 0.0
            batch_trees = [x[1] for x in \
                    train_set[start_index:start_index + args.batch_size]]  # EKN this is where the trees get batched
            batch_sent_ids = [x[0] for x in \
                    train_set[start_index:start_index + args.batch_size]]
            if hparams.seg:
                batch_txt = [turn[0] for turn in batch_trees]
                batch_lbl = [turn[1] for turn in batch_trees]
                batch_sentences = [[(lbl,txt) for lbl,txt in zip(sent_lbl,sent_txt)] for \
                                      sent_lbl,sent_txt in zip(batch_lbl,batch_txt)]
            else:
                batch_sentences = [[(leaf.tag, leaf.word) for leaf \
                        in tree.leaves()] for tree in batch_trees]  # EKN this is where the sentences get broken into tags and words
            batch_num_tokens = sum(len(sentence) for sentence \
                    in batch_sentences)

            # The batch is processed in sub-batches limited by
            # args.subbatch_max_tokens; gradients accumulate across them and
            # a single optimizer step follows.
            for subbatch_sentences, subbatch_trees, subbatch_sent_ids in \
                    parser.split_batch(batch_sentences, batch_trees, \
                    batch_sent_ids, args.subbatch_max_tokens):
                subbatch_features = load_features(subbatch_sent_ids, feat_dict)
                """
                subbatch_num_tokens = [len(sentence) for sentence in subbatch_sentences]
                for i,sent in enumerate(subbatch_features):
                    feat_len = len(subbatch_features[sent]['partition'])
                    tree_len = subbatch_num_tokens[i]
                    if not feat_len == tree_len:
                        print(sent)
                """
                _, loss = parser.parse_batch(subbatch_sentences, \
                        subbatch_sent_ids, subbatch_features, subbatch_trees)

                # Losses are normalized by the size of the whole batch, not
                # of the sub-batch.
                if hparams.predict_tags:
                    loss = loss[0] / len(
                        batch_trees) + loss[1] / batch_num_tokens
                else:
                    loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += len(subbatch_trees)
                current_processed += len(subbatch_trees)

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters, \
                    grad_clip_threshold)

            trainer.step()

            print("epoch {:,} "
                  "batch {:,}/{:,} "
                  "processed {:,} "
                  "batch-loss {:.4f} "
                  "grad-norm {:.4f} "
                  "epoch-elapsed {} "
                  "total-elapsed {}".format(
                      epoch,
                      start_index // args.batch_size + 1,
                      int(np.ceil(len(train_parse) / args.batch_size)),
                      total_processed,
                      batch_loss_value,
                      grad_norm,
                      format_elapsed(epoch_start_time),
                      format_elapsed(start_time),
                  ))
            sys.stdout.flush()
            # DEBUG:
            if args.debug and start_index > args.batch_size * 2:
                print("In debug mode, exiting")
                exit(0)

            if current_processed >= check_every:
                current_processed -= check_every
                check_dev()

        # adjust learning rate at the end of an epoch
        if (total_processed // args.batch_size + 1) > \
                hparams.learning_rate_warmup_steps:
            scheduler.step(best_dev_fscore)
            # Stop once more than (patience + 1) * max_consecutive_decays
            # epochs' worth of examples have passed since the best dev score.
            if (total_processed - best_dev_processed) > \
                    ((hparams.step_decay_patience + 1) \
                    * hparams.max_consecutive_decays * len(train_parse)):
                print("Terminating due to lack of improvement in dev fscore.")
                break
sent2part = pickle.load( open(os.path.join(datadir, 'train_partition.pickle'), 'rb')) sent2pitch = pickle.load( open(os.path.join(datadir, 'train_pitch.pickle'), 'rb')) sent2fbank = pickle.load( open(os.path.join(datadir, 'train_fbank.pickle'), 'rb')) sent2pause = pickle.load( open(os.path.join(datadir, 'train_pause.pickle'), 'rb')) sent2dur = pickle.load( open(os.path.join(datadir, 'train_duration.pickle'), 'rb')) sent_treestrings = [ l.strip() for l in open(os.path.join(sentdir, 'train.trees'), 'r').readlines() ] sent_trees, sent_ids = trees.load_trees_with_idx( os.path.join(sentdir, 'train.trees'), os.path.join(datadir, 'train_sent_ids.txt'), strip_top=False) sent2tree = dict(zip(sent_ids, sent_trees)) sent2treestring = dict(zip(sent_ids, sent_treestrings)) turn2part = pickle.load( open(os.path.join(turndir, 'turn_train_partition.pickle'), 'rb')) turn2pitch = pickle.load( open(os.path.join(turndir, 'turn_train_pitch.pickle'), 'rb')) turn2fbank = pickle.load( open(os.path.join(turndir, 'turn_train_fbank.pickle'), 'rb')) turn2pause = pickle.load( open(os.path.join(turndir, 'turn_train_pause.pickle'), 'rb')) turn2dur = pickle.load( open(os.path.join(turndir, 'turn_train_duration.pickle'), 'rb')) turn_ids = [
import trees
import os

# Script: count how often each `label` value occurs among the loaded dev
# trees, print the counts, then read two prediction files into memory.

treefile = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/data/input_features/sentence_pause_dur_fixed/dev.trees'
idfile = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/data/input_features/sentence_pause_dur_fixed/dev_sent_ids.txt'

# NOTE: this rebinds `trees` from the imported module to the loaded tree
# collection; the module is no longer reachable under that name afterwards.
trees, ids = trees.load_trees_with_idx(treefile, idfile)

# Tally trees per value of their `label` attribute.
lbl2count = {}
for idnum, tree in zip(ids, trees):
    lbl = tree.label
    lbl2count[lbl] = lbl2count.get(lbl, 0) + 1
print("lbl2count")
print(lbl2count)

output = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/code/self_attn_speech_parser/output/turn_pause_dur_fixed'
nonsp = os.path.join(
    output,
    "turn_medium_nonsp_glove_72240_dev=86.82.pt_turn_dev_predicted.txt")
sp = os.path.join(
    output,
    "turn_medium_sp_glove_ab_duration_72240_dev=91.32.pt_turn_dev_predicted.txt"
)

# `nonsp` / `sp` start out as paths and end up holding the file contents.
# The files are read inside `with` so the handles are closed deterministically
# instead of being left to the garbage collector.
with open(nonsp, 'r') as handle:
    nonsp = handle.read()
with open(sp, 'r') as handle:
    sp = handle.read()
import os import trees #split = 'train' #split = 'dev' split = 'test' # data_dir = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/data/input_features/turn_pause_dur_fixed' data_dir = '/afs/inf.ed.ac.uk/user/s20/s2096077/prosody_nlp/data/input_features/turn_pause_dur_fixed' turn_trees = os.path.join(data_dir, f'turn_{split}.trees') turn_ids = os.path.join(data_dir, f'turn_{split}_sent_ids.txt') trees, ids = trees.load_trees_with_idx(turn_trees, turn_ids, strip_top=False) out_trees = [] out_ids = [] ## Step 1: print stats on dataset leaf_lens = [] for tree, id_num in zip(trees, ids): leaves = tree.leaves() num_leaves = 0 for leaf in leaves: num_leaves += 1 leaf_lens.append(num_leaves) #print(f'{id_num}\t{num_leaves}') print('-' * 50) print(f'Max len: {max(leaf_lens)}') print(f'Mean len: {sum(leaf_lens)/len(leaf_lens)}') for ln in leaf_lens: if ln > 270:
import re from statsmodels.stats import weightstats as stests from scipy.stats import ttest_ind datadir = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/data/input_features/' sentdir = os.path.join(datadir,'sentence_pause_dur_fixed') turndir = os.path.join(datadir,'turn_pause_dur_fixed') sent2part = pickle.load(open(os.path.join(sentdir,'train_partition.pickle'),'rb')) sent2pitch = pickle.load(open(os.path.join(sentdir,'train_pitch.pickle'),'rb')) sent2fbank = pickle.load(open(os.path.join(sentdir,'train_fbank.pickle'),'rb')) sent2pause = pickle.load(open(os.path.join(sentdir,'train_pause.pickle'),'rb')) sent2dur = pickle.load(open(os.path.join(sentdir,'train_duration.pickle'),'rb')) sent_treestrings = [l.strip() for l in open(os.path.join(sentdir,'train.trees'),'r').readlines()] sent_trees,sent_ids = trees.load_trees_with_idx(os.path.join(sentdir,'train.trees'),os.path.join(sentdir,'train_sent_ids.txt')) sent2tree = dict(zip(sent_ids,sent_trees)) sent2treestring = dict(zip(sent_ids,sent_treestrings)) turn2part = pickle.load(open(os.path.join(turndir,'turn_train_partition.pickle'),'rb')) turn2pitch = pickle.load(open(os.path.join(turndir,'turn_train_pitch.pickle'),'rb')) turn2fbank = pickle.load(open(os.path.join(turndir,'turn_train_fbank.pickle'),'rb')) turn2pause = pickle.load(open(os.path.join(turndir,'turn_train_pause.pickle'),'rb')) turn2dur = pickle.load(open(os.path.join(turndir,'turn_train_duration.pickle'),'rb')) turn_ids = [l.strip() for l in open(os.path.join(turndir,'turn_train_sent_ids.txt'),'r').readlines()] turn_trees = [l.strip() for l in open(os.path.join(turndir,'turn_train.trees'),'r').readlines()] sent2turn = pickle.load(open(os.path.join(datadir,'sent2turn.pickle'),'rb')) turn2sent = pickle.load(open(os.path.join(datadir,'turn2sent.pickle'),'rb')) turn_medial_sents = [] for turn in turn2sent:
import numpy as np
import tempfile
from PYEVALB import scorer

# Script section: locate the gold and predicted tree files, load all three
# tree sets against the same id file, index them by id, and compute the
# quartiles of the gold-tree leaf counts.

data_dir = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/data/input_features/turn_pause_dur_fixed/'
results_dir = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/code/self_attn_speech_parser/output/turn_pause_dur_fixed'
output_dir = os.path.join(results_dir, 'length_eval')

gold_file = os.path.join(data_dir, 'turn_dev_medium.trees')
sp_file = os.path.join(
    results_dir,
    'turn_sp_correct_eval_72240_dev=90.90.pt_dev_predicted.txt',
)
nonsp_file = os.path.join(
    results_dir,
    'turn_nonsp_correct_eval_72240_dev=86.09.pt_dev_predicted.txt',
)
gold_id_file = os.path.join(data_dir, 'turn_dev_sent_ids_medium.txt')

# Every load uses the same id file, so `ids` is simply rebound each time.
gold_trees, ids = trees.load_trees_with_idx(gold_file, gold_id_file)
nonsp_trees, ids = trees.load_trees_with_idx(nonsp_file, gold_id_file)
sp_trees, ids = trees.load_trees_with_idx(sp_file, gold_id_file)

# id -> tree lookups for the gold, non-speech and speech tree sets.
id2goldtree = {key: value for key, value in zip(ids, gold_trees)}
id2nonsptree = {key: value for key, value in zip(ids, nonsp_trees)}
id2sptree = {key: value for key, value in zip(ids, sp_trees)}

# Number of leaves in each gold tree.
lens = []
for tree in gold_trees:
    lens.append(sum(1 for _ in tree.leaves()))

# Quartiles of the leaf-count distribution (25%, 50%, 75%).
lower_med, median, upper_med = (
    np.quantile(np.array(lens), q) for q in (0.25, 0.5, 0.75)
)
# Script section: load the sentence-id -> recording / speaker mappings and the
# trees, then, for every recording that is a .wav file and whose sentence id
# has a tree, collect the `word` of each leaf into `transcription`.

out_dir = "/afs/inf.ed.ac.uk/user/s20/s2096077/prosody_nlp/data/vm/ger/vm_word_times"
path_to_trees = "/afs/inf.ed.ac.uk/user/s20/s2096077/prosody_nlp/data/vm/ger/input_features/new_trees/all_clean.trees"
path_to_sent_ids = "/afs/inf.ed.ac.uk/user/s20/s2096077/prosody_nlp/data/vm/ger/input_features/all_clean_sent_ids.txt"
path_to_textgrids = "/afs/inf.ed.ac.uk/group/msc-projects/s2096077/vm_ger_textgrids_complete"

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point these at pickles produced by this project.
with open('sentence_id2recording_{}_new.pickle'.format(lang), 'rb') as handle:
    sentence_id2recording = pickle.load(handle)
with open('sentence_id2speaker_{}.pickle'.format(lang), 'rb') as handle:
    sentence_id2speaker = pickle.load(handle)

# take [:150] for sample data
wav_files = list(sentence_id2recording.items())

# NOTE: rebinds `trees` from the module to the loaded tree collection.
trees, sent_ids = trees.load_trees_with_idx(path_to_trees,
                                            path_to_sent_ids,
                                            strip_top=False)

# Map each sentence id to the position of its FIRST occurrence in `sent_ids`
# (the same position `sent_ids.index` would return). Building this once
# replaces two linear list scans per recording with constant-time lookups.
sent_id2index = {}
for position, known_id in enumerate(sent_ids):
    sent_id2index.setdefault(known_id, position)

for sentence_id, file in wav_files:
    if file.endswith('.wav') and sentence_id in sent_id2index:
        index_of_sent_id = sent_id2index[sentence_id]
        tree = trees[index_of_sent_id]
        transcription = []
        for child in tree.leaves():
            try:
                word = child.word
            except AttributeError:
                # Leaf without a `word` attribute: dump the whole tree for
                # inspection and carry on with the remaining leaves.
                print(tree)
            else:
                transcription.append(word)
# Script section: read per-turn predictions (one line per turn), load the
# matching turn trees, and start walking turns and predictions in parallel.

data = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/data/input_features/turn_pause_dur_fixed/seg'
out = '/afs/inf.ed.ac.uk/group/project/prosody/prosody_nlp/code/self_attn_speech_parser/output/turn_pause_dur_fixed/seg'


def load_lines(filename):
    """Return the lines of *filename*, each with surrounding whitespace stripped.

    The file is opened in text mode and closed before returning.
    """
    with open(filename, 'r') as infile:
        return [l.strip() for l in infile]


preds = load_lines(
    os.path.join(out,
                 'turn_nonsp_seg_72240_dev=73.31.pt_turn_dev_predicted.txt'))

tree_file = os.path.join(data, '..', 'turn_dev_medium.trees')
id_file = os.path.join(data, '..', 'turn_dev_sent_ids_medium.txt')

turn_trees, turn_ids = trees.load_trees_with_idx(
    tree_file, id_file)  # EKN trees get loaded in as trees here
id2turntree = dict(zip(turn_ids, turn_trees))

# One prediction line is expected per turn id.
assert len(turn_ids) == len(preds)

subturn2tree = {}
ordered_subturns = []

for turn, pred in zip(turn_ids, preds):
    turn_tree = id2turntree[turn]
    wds = [leaf.word for leaf in turn_tree.leaves()]
    tags = [leaf.tag for leaf in turn_tree.leaves()]
    # NOTE: this rebinds `preds` to the current turn's whitespace-split
    # prediction tokens. The zip() above already holds the original list, so
    # iteration is unaffected, but the full list is no longer reachable by
    # this name once the loop has started.
    preds = pred.split()
    # One prediction token is expected per word of the turn.
    assert len(wds) == len(preds)
    subturn_wds = []
    subturn_tags = []