def __generate_word_matrix(self, index_lookup):
    """
    Generate a BOW matrix with rows, columns corresponding to documents, words respectively.

    One matrix is built for every batch listed in the serialized batches file and is
    serialized (as nested lists) to that batch's BOW matrix path. Tokens that are
    missing from index_lookup are ignored.

    @param index_lookup: A dictionary with keys for the attributes. In order to know which
    column should be incremented in word_matrix.
    """
    # Context managers close every file deterministically; the handles used to be
    # left open until the garbage collector got around to them.
    with open(env_paths.get_batches_path(self.training), "rb") as batches_file:
        batches = s.load(batches_file)
    length = len(batches)

    for processed, batch in enumerate(batches, 1):
        with open(env_paths.get_doc_list_path(self.training, batch), "rb") as docs_file:
            docs_list = s.load(docs_file)

        bag_of_words_matrix = zeros([len(docs_list), len(index_lookup)])
        for row, doc in enumerate(docs_list):
            for token in doc:
                try:
                    col = index_lookup[token]
                except KeyError:
                    # The word is not found in the dictionary.
                    continue
                bag_of_words_matrix[row, col] += 1

        # Serialize bag of words. The conversion happens before the file is opened so
        # that a failing conversion does not leave an empty file behind.
        serialized_matrix = bag_of_words_matrix.tolist()
        with open(env_paths.get_bow_matrix_path(self.training, batch), "wb") as matrix_file:
            s.dump(serialized_matrix, matrix_file)
        print("Processed " + str(processed) + " of " + str(length) + " batches")
def __generate_output_data(self):
    """
    Generate the output data of the DBN so that it can be visualised.

    Does nothing when self.output_data is already populated. Otherwise the cached
    output data and class indices are loaded from the 'output' folder. If that fails
    they are regenerated (for the test set when self.testing is true, for the
    training set otherwise), filtered by self.classes_to_visualise when given, and
    written back to the cache files. Finally self.legend is set from the distinct
    class indices.
    """
    if len(self.output_data) != 0:
        return

    try:
        with open('output/output_data.p', 'rb') as data_file:
            self.output_data = s.load(data_file)
        with open('output/class_indices.p', 'rb') as indices_file:
            self.class_indices = s.load(indices_file)
        if self.classes_to_visualise is not None:
            self.__filter_output_data(self.classes_to_visualise)
    except Exception:
        # Any failure while reading the cache falls back to regenerating the data.
        # 'Exception' (instead of a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        if self.testing:
            self.output_data = generate_output_for_test_data(
                image_data=self.image_data, binary_output=self.binary_output)
            self.class_indices = get_all_class_indices(training=False)
        else:
            self.output_data = generate_output_for_train_data(
                image_data=self.image_data, binary_output=self.binary_output)
            self.class_indices = get_all_class_indices()
        if self.classes_to_visualise is not None:
            self.__filter_output_data(self.classes_to_visualise)

        serialized_output = [out.tolist() for out in self.output_data]
        with open('output/output_data.p', 'wb') as data_file:
            s.dump(serialized_output, data_file)
        with open('output/class_indices.p', 'wb') as indices_file:
            s.dump(self.class_indices, indices_file)

    # Sort after de-duplicating: sorting before building the set had no effect on
    # the resulting order.
    self.legend = get_class_names_for_class_indices(sorted(set(self.class_indices)))
def __init__(self, testing=True, binary_output=False):
    """
    Load the cached DBN output data and class indices, or generate and cache them.

    @param testing: Should be True if test data is to be plotted. Otherwise False.
    @param binary_output: If the output of the DBN must be binary.
    """
    # The check has to be called: the bare function object is always truthy, so
    # the guard could never trigger.
    # NOTE(review): assumes check_for_data takes no arguments - confirm.
    if not check_for_data():
        print('No DBN data or testing data.')
        return

    self.status = -1
    self.output = []
    self.testing = testing
    self.binary_output = binary_output

    try:
        with open('output/output_data.p', 'rb') as data_file:
            self.output_data = s.load(data_file)
        with open('output/class_indices.p', 'rb') as indices_file:
            self.class_indices = s.load(indices_file)
    except Exception:
        # Any failure while reading the cache falls back to regenerating the data.
        # 'Exception' (instead of a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        if testing:
            self.output_data = generate_output_for_test_data(binary_output=self.binary_output)
            self.class_indices = get_all_class_indices(training=False)
        else:
            self.output_data = generate_output_for_train_data(binary_output=self.binary_output)
            self.class_indices = get_all_class_indices()

        serialized_output = [out.tolist() for out in self.output_data]
        with open('output/output_data.p', 'wb') as data_file:
            s.dump(serialized_output, data_file)
        with open('output/class_indices.p', 'wb') as indices_file:
            s.dump(self.class_indices, indices_file)

    self.output_data = np.array(self.output_data)
def __generate_input_data(self):
    """
    Generate the input data for the DBN so that it can be visualized.

    Does nothing when self.input_data is already populated. Otherwise the cached
    input data and class indices are loaded from the 'output' folder. If that fails
    they are regenerated (for the test set when self.testing is true, for the
    training set otherwise), filtered by self.classes_to_visualise when given, and
    written back to the cache files. Finally self.legend is set from the distinct
    class indices.
    """
    if len(self.input_data) != 0:
        return

    try:
        with open('output/input_data.p', 'rb') as data_file:
            self.input_data = s.load(data_file)
        with open('output/class_indices.p', 'rb') as indices_file:
            self.class_indices = s.load(indices_file)
        if self.classes_to_visualise is not None:
            self.__filter_input_data(self.classes_to_visualise)
    except Exception:
        # Any failure while reading the cache falls back to regenerating the data.
        # 'Exception' (instead of a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        if self.testing:
            self.input_data = generate_input_data_list(training=False)
            self.class_indices = get_all_class_indices(training=False)
        else:
            self.input_data = generate_input_data_list()
            self.class_indices = get_all_class_indices()
        if self.classes_to_visualise is not None:
            self.__filter_input_data(self.classes_to_visualise)

        # 'sample' avoids shadowing the builtin input().
        serialized_input = [sample.tolist() for sample in self.input_data]
        with open('output/input_data.p', 'wb') as data_file:
            s.dump(serialized_input, data_file)
        with open('output/class_indices.p', 'wb') as indices_file:
            s.dump(self.class_indices, indices_file)

    # Sort after de-duplicating: sorting before building the set had no effect on
    # the resulting order.
    self.legend = get_class_names_for_class_indices(sorted(set(self.class_indices)))
def __generate_word_matrix(self, index_lookup):
    """
    Generate a BOW matrix with rows, columns corresponding to documents, words respectively.

    One matrix is built for every batch listed in the serialized batches file and is
    serialized (as nested lists) to that batch's BOW matrix path. Tokens that are
    missing from index_lookup are ignored.

    @param index_lookup: A dictionary with keys for the attributes. In order to know which
    column should be incremented in word_matrix.
    """
    # Context managers close every file deterministically; the handles used to be
    # left open until the garbage collector got around to them.
    with open(env_paths.get_batches_path(self.training), "rb") as batches_file:
        batches = s.load(batches_file)
    length = len(batches)

    for processed, batch in enumerate(batches, 1):
        with open(env_paths.get_doc_list_path(self.training, batch), "rb") as docs_file:
            docs_list = s.load(docs_file)

        bag_of_words_matrix = zeros([len(docs_list), len(index_lookup)])
        for row, doc in enumerate(docs_list):
            for token in doc:
                try:
                    col = index_lookup[token]
                except KeyError:
                    # The word is not found in the dictionary.
                    continue
                bag_of_words_matrix[row, col] += 1

        # Serialize bag of words. The conversion happens before the file is opened so
        # that a failing conversion does not leave an empty file behind.
        serialized_matrix = bag_of_words_matrix.tolist()
        with open(env_paths.get_bow_matrix_path(self.training, batch), "wb") as matrix_file:
            s.dump(serialized_matrix, matrix_file)
        print('Processed ' + str(processed) + ' of ' + str(length) + ' batches')
def __save_output__(self, batch_index, outputs):
    """
    Serialize the output of the rbm.

    The target path is derived from self.num_hid, the batch index and self.layer_index.

    @param batch_index: Index of the batch.
    @param outputs: The output probabilities of the rbm. Must provide tolist().
    """
    # Convert first so that a failing conversion does not leave an empty file behind.
    serialized_outputs = outputs.tolist()
    output_path = env_paths.get_rbm_output_path(str(self.num_hid), batch_index, self.layer_index)
    # The context manager flushes and closes the file even if dumping fails.
    with open(output_path, "wb") as output_file:
        s.dump(serialized_outputs, output_file)
def __save_output__(self, batch_index, outputs):
    """
    Serialize the output of the rbm.

    The target path is derived from self.num_hid, the batch index and self.layer_index.

    @param batch_index: Index of the batch.
    @param outputs: The output probabilities of the rbm. Must provide tolist().
    """
    # Convert first so that a failing conversion does not leave an empty file behind.
    serialized_outputs = outputs.tolist()
    output_path = env_paths.get_rbm_output_path(str(self.num_hid), batch_index, self.layer_index)
    # The context manager flushes and closes the file even if dumping fails.
    with open(output_path, "wb") as output_file:
        s.dump(serialized_outputs, output_file)
def __set_attributes(self):
    """
    Set the attributes containing of a list of words of all attributes in the bag of words matrix.

    The distinct words of all batches are collected (restricted to self.acceptance_lst
    when it is given), their occurrences are counted over all batches and the
    self.max_words_matrix most frequent words are kept, ordered by ascending count.
    The resulting list is serialized to the attributes path.

    @return: The generated list of words acting as attributes for the BOWs.
    """
    with open(env_paths.get_batches_path(self.training), "rb") as batches_file:
        batches = s.load(batches_file)
    length = len(batches)

    # Collect the distinct words in a set instead of re-sorting and de-duplicating
    # an ever growing list for every batch.
    vocabulary = set()
    for processed, batch in enumerate(batches, 1):
        with open(env_paths.get_doc_list_path(self.training, batch), "rb") as docs_file:
            docs_list = s.load(docs_file)
        vocabulary.update(chain.from_iterable(docs_list))
        if self.acceptance_lst is not None:
            # Only consider words in the acceptance list.
            vocabulary.intersection_update(self.acceptance_lst)
        print("Processed attribute " + str(processed) + " of " + str(length) + " batches")

    # Find attributes of the most common words.
    counts = dict.fromkeys(vocabulary, 0)
    for processed, batch in enumerate(batches, 1):
        with open(env_paths.get_doc_list_path(self.training, batch), "rb") as docs_file:
            docs_list = s.load(docs_file)
        for word in chain.from_iterable(docs_list):
            if word in counts:
                counts[word] += 1
        print("Processed summing " + str(processed) + " of " + str(length) + " batches")

    sorted_att = sorted(counts.items(), key=lambda item: item[1])
    # Clamp the start index at zero: with fewer distinct words than max_words_matrix
    # the difference is negative and the slice would silently drop words.
    first_kept = max(len(sorted_att) - self.max_words_matrix, 0)
    attributes = [word for word, _ in sorted_att[first_kept:]]

    # Serialize attributes
    with open(env_paths.get_attributes_path(self.training), "wb") as attributes_file:
        s.dump(attributes, attributes_file)
    return attributes
def __stem_doc(doc_details):
    """
    Stem one text document and serialize the resulting list of tokens.

    Only paths ending in '.txt' are processed. The token list is written next to the
    document, with the '.txt' extension replaced by '.p'.

    @param doc_details: A tuple (idx, doc) holding the running index of the document
    and the path of the document.
    """
    # Import nltk tools
    from nltk.tokenize import wordpunct_tokenize as wordpunct_tokenize
    # from nltk.stem.snowball import EnglishStemmer
    from nltk.stem.porter import PorterStemmer as EnglishStemmer

    idx, doc = doc_details
    if idx % 100 == 0:
        print("Processed doc " + str(idx))

    if not doc.endswith('.txt'):
        return

    with open(doc) as doc_file:
        d = doc_file.read()

    stemmer = EnglishStemmer()  # This method only works for english documents.
    # Fetch the stop words once instead of once per token.
    stop_words = stopwords.get_stopwords()
    # NOTE(review): str.decode relies on Python 2 byte strings.
    text = re.sub('[%s]' % re.escape(string.punctuation), '', d.decode(encoding='UTF-8', errors='ignore'))
    # Stem, lowercase, substitute all punctuations, remove stopwords.
    attribute_names = [stemmer.stem(token.lower()) for token in wordpunct_tokenize(text)
                       if token.lower() not in stop_words]

    # Swap only the extension: str.replace would also rewrite a '.txt' that occurs
    # elsewhere in the path (e.g. in a folder name).
    target = doc[:-len(".txt")] + ".p"
    with open(target, "wb") as target_file:
        s.dump(attribute_names, target_file)
def save_metadata(metadata, save_dir, _log, _run):
    """Write the dumped ``metadata`` to the metadata file inside ``save_dir``.

    The written file is also added to ``_run`` as an artifact when
    ``SACRED_OBSERVE_FILES`` is truthy.
    """
    target = os.path.join(save_dir, METADATA_FILENAME)
    _log.info('Saving metadata to %s', target)

    with open(target, 'w') as sink:
        print(dump(metadata), file=sink)

    if not SACRED_OBSERVE_FILES:
        return
    _run.add_artifact(target)
def __set_attributes(self):
    """
    Set the attributes containing of a list of words of all attributes in the bag of words matrix.

    The distinct words of all batches are collected (restricted to self.acceptance_lst
    when it is given), their occurrences are counted over all batches and the
    self.max_words_matrix most frequent words are kept, ordered by ascending count.
    The resulting list is serialized to the attributes path.

    @return: The generated list of words acting as attributes for the BOWs.
    """
    with open(env_paths.get_batches_path(self.training), "rb") as batches_file:
        batches = s.load(batches_file)
    length = len(batches)

    # Collect the distinct words in a set instead of re-sorting and de-duplicating
    # an ever growing list for every batch.
    vocabulary = set()
    for processed, batch in enumerate(batches, 1):
        with open(env_paths.get_doc_list_path(self.training, batch), "rb") as docs_file:
            docs_list = s.load(docs_file)
        vocabulary.update(chain.from_iterable(docs_list))
        if self.acceptance_lst is not None:
            # Only consider words in the acceptance list.
            vocabulary.intersection_update(self.acceptance_lst)
        print('Processed attribute ' + str(processed) + ' of ' + str(length) + ' batches')

    # Find attributes of the most common words.
    counts = dict.fromkeys(vocabulary, 0)
    for processed, batch in enumerate(batches, 1):
        with open(env_paths.get_doc_list_path(self.training, batch), "rb") as docs_file:
            docs_list = s.load(docs_file)
        for word in chain.from_iterable(docs_list):
            if word in counts:
                counts[word] += 1
        print('Processed summing ' + str(processed) + ' of ' + str(length) + ' batches')

    sorted_att = sorted(counts.items(), key=lambda item: item[1])
    # Clamp the start index at zero: with fewer distinct words than max_words_matrix
    # the difference is negative and the slice would silently drop words.
    first_kept = max(len(sorted_att) - self.max_words_matrix, 0)
    attributes = [word for word, _ in sorted_att[first_kept:]]

    # Serialize attributes
    with open(env_paths.get_attributes_path(self.training), "wb") as attributes_file:
        s.dump(attributes, attributes_file)
    return attributes
def train(model_path, _log, _run):
    """Fit a majority vote model and write it to ``model_path``.

    The most frequent tag among the tagged words of the training corpus is
    dumped under the ``majority_tag`` key. The model file is also added to
    ``_run`` as an artifact when ``SACRED_OBSERVE_FILES`` is truthy.
    """
    corpus = read_train_corpus()

    tag_counts = Counter()
    for _, tag in corpus.tagged_words():
        tag_counts[tag] += 1
    top_entry = tag_counts.most_common(1)[0]
    majority_tag = top_entry[0]

    _log.info('Saving model to %s', model_path)
    with open(model_path, 'w') as sink:
        print(dump({'majority_tag': majority_tag}), file=sink)

    if not SACRED_OBSERVE_FILES:
        return
    _run.add_artifact(model_path)
def __generate_input_data(self):
    """
    Generate the input data for the DBN so that it can be visualized.

    Does nothing when self.input_data is already populated. Otherwise the cached
    input data and class indices are loaded from the 'output' folder. If that fails
    they are regenerated (for the test set when self.testing is true, for the
    training set otherwise), filtered by self.classes_to_visualise when given, and
    written back to the cache files. Finally self.legend is set from the distinct
    class indices.
    """
    if len(self.input_data) != 0:
        return

    try:
        with open('output/input_data.p', 'rb') as data_file:
            self.input_data = s.load(data_file)
        with open('output/class_indices.p', 'rb') as indices_file:
            self.class_indices = s.load(indices_file)
        if self.classes_to_visualise is not None:
            self.__filter_input_data(self.classes_to_visualise)
    except Exception:
        # Any failure while reading the cache falls back to regenerating the data.
        # 'Exception' (instead of a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        if self.testing:
            self.input_data = generate_input_data_list(training=False)
            self.class_indices = get_all_class_indices(training=False)
        else:
            self.input_data = generate_input_data_list()
            self.class_indices = get_all_class_indices()
        if self.classes_to_visualise is not None:
            self.__filter_input_data(self.classes_to_visualise)

        # 'sample' avoids shadowing the builtin input().
        serialized_input = [sample.tolist() for sample in self.input_data]
        with open('output/input_data.p', 'wb') as data_file:
            s.dump(serialized_input, data_file)
        with open('output/class_indices.p', 'wb') as indices_file:
            s.dump(self.class_indices, indices_file)

    # Sort after de-duplicating: sorting before building the set had no effect on
    # the resulting order.
    self.legend = get_class_names_for_class_indices(sorted(set(self.class_indices)))
def train(model_path, _log, _run, window=2):
    """Fit a memorization model and write it to ``model_path``.

    The tagged sentences of the training corpus are split into sentences and
    tags, the sentences are preprocessed, and a ``MemorizationTagger`` is
    trained on them with the given ``window``. The model file is also added to
    ``_run`` as an artifact when ``SACRED_OBSERVE_FILES`` is truthy.
    """
    corpus = read_train_corpus()
    tagged_sents = corpus.tagged_sents()
    raw_sents, tags = separate_tagged_sents(tagged_sents)
    sents = preprocess(raw_sents)

    _log.info('Start training model')
    model = MemorizationTagger.train(sents, tags, window=window)

    _log.info('Saving model to %s', model_path)
    with open(model_path, 'w') as sink:
        print(dump(model), file=sink)

    if not SACRED_OBSERVE_FILES:
        return
    _run.add_artifact(model_path)
def train(model_path, _log, _run, cutoff=0.1, idf_path=None):
    """Fit a naive Bayes summarizer and write it to ``model_path``.

    An IDF table is read via ``read_idf()`` only when ``idf_path`` is given;
    otherwise ``None`` is passed to the trainer. The model file is also added
    to ``_run`` as an artifact when ``SAVE_FILES`` is truthy.
    """
    train_docs = list(read_train_jsonl())

    if idf_path is None:
        idf_table = None
    else:
        idf_table = read_idf()

    model = NaiveBayesSummarizer.train(train_docs, cutoff=cutoff, idf_table=idf_table)

    _log.info('Saving model to %s', model_path)
    with open(model_path, 'w') as sink:
        print(dump(model), file=sink)

    if not SAVE_FILES:
        return
    _run.add_artifact(model_path)
def __stem_doc(doc_details):
    """
    Stem one text document and serialize the resulting list of tokens.

    Only paths ending in '.txt' are processed. The token list is written next to the
    document, with the '.txt' extension replaced by '.p'.

    @param doc_details: A tuple (idx, doc) holding the running index of the document
    and the path of the document.
    """
    # Import nltk tools
    from nltk.tokenize import wordpunct_tokenize as wordpunct_tokenize
    # from nltk.stem.snowball import EnglishStemmer
    from nltk.stem.porter import PorterStemmer as EnglishStemmer

    idx, doc = doc_details
    if idx % 100 == 0:
        print("Processed doc " + str(idx))

    if not doc.endswith(".txt"):
        return

    with open(doc) as doc_file:
        d = doc_file.read()

    stemmer = EnglishStemmer()  # This method only works for english documents.
    # Fetch the stop words once instead of once per token.
    stop_words = stopwords.get_stopwords()
    # NOTE(review): str.decode relies on Python 2 byte strings.
    text = re.sub("[%s]" % re.escape(string.punctuation), "", d.decode(encoding="UTF-8", errors="ignore"))
    # Stem, lowercase, substitute all punctuations, remove stopwords.
    attribute_names = [stemmer.stem(token.lower()) for token in wordpunct_tokenize(text)
                       if token.lower() not in stop_words]

    # Swap only the extension: str.replace would also rewrite a '.txt' that occurs
    # elsewhere in the path (e.g. in a folder name).
    target = doc[:-len(".txt")] + ".p"
    with open(target, "wb") as target_file:
        s.dump(attribute_names, target_file)
def __generate_output_data(self):
    """
    Generate the output data of the DBN so that it can be visualised.

    Does nothing when self.output_data is already populated. Otherwise the cached
    output data and class indices are loaded from the 'output' folder. If that fails
    they are regenerated (for the test set when self.testing is true, for the
    training set otherwise), filtered by self.classes_to_visualise when given, and
    written back to the cache files. Finally self.legend is set from the distinct
    class indices.
    """
    if len(self.output_data) != 0:
        return

    try:
        with open('output/output_data.p', 'rb') as data_file:
            self.output_data = s.load(data_file)
        with open('output/class_indices.p', 'rb') as indices_file:
            self.class_indices = s.load(indices_file)
        if self.classes_to_visualise is not None:
            self.__filter_output_data(self.classes_to_visualise)
    except Exception:
        # Any failure while reading the cache falls back to regenerating the data.
        # 'Exception' (instead of a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        if self.testing:
            self.output_data = generate_output_for_test_data(
                image_data=self.image_data, binary_output=self.binary_output)
            self.class_indices = get_all_class_indices(training=False)
        else:
            self.output_data = generate_output_for_train_data(
                image_data=self.image_data, binary_output=self.binary_output)
            self.class_indices = get_all_class_indices()
        if self.classes_to_visualise is not None:
            self.__filter_output_data(self.classes_to_visualise)

        serialized_output = [out.tolist() for out in self.output_data]
        with open('output/output_data.p', 'wb') as data_file:
            s.dump(serialized_output, data_file)
        with open('output/class_indices.p', 'wb') as indices_file:
            s.dump(self.class_indices, indices_file)

    # Sort after de-duplicating: sorting before building the set had no effect on
    # the resulting order.
    self.legend = get_class_names_for_class_indices(sorted(set(self.class_indices)))
def train(model_path, _log, _run, gamma_word=0.1, gamma_init=0.1, gamma_trans=0.1, tf_path=None):
    """Fit an HMM summarizer and write it to ``model_path``.

    A TF table is read via ``read_tf()`` only when ``tf_path`` is given;
    otherwise ``None`` is passed to the trainer. The three ``gamma_*`` values
    are forwarded unchanged to ``HMMSummarizer.train``. The model file is also
    added to ``_run`` as an artifact when ``SAVE_FILES`` is truthy.
    """
    train_docs = list(read_train_jsonl())

    if tf_path is None:
        tf_table = None
    else:
        tf_table = read_tf()

    hyperparams = dict(
        gamma_word=gamma_word,
        gamma_init=gamma_init,
        gamma_trans=gamma_trans,
        tf_table=tf_table,
    )
    model = HMMSummarizer.train(train_docs, **hyperparams)

    _log.info('Saving model to %s', model_path)
    with open(model_path, 'w') as sink:
        print(dump(model), file=sink)

    if not SAVE_FILES:
        return
    _run.add_artifact(model_path)
def save_rbm_weights(weight_matrices, hidden_biases, visible_biases):
    """
    Save the weight matrices and the biases from the rbm pretraining.

    Each collection is converted to nested lists and serialized to its own file.

    @param weight_matrices: the weight matrices of the rbm pretraining.
    @param hidden_biases: the hidden biases of the rbm pretraining.
    @param visible_biases: the visible biases of the rbm pretraining.
    """
    # Each collection is converted before its file is opened, and every file is
    # closed by its context manager instead of being left to the garbage collector.
    weights = [w.tolist() for w in weight_matrices]
    with open(env_paths.get_rbm_weights_path(), "wb") as weights_file:
        s.dump(weights, weights_file)

    hidden = [b.tolist() for b in hidden_biases]
    with open(env_paths.get_rbm_hidden_biases_path(), "wb") as hidden_file:
        s.dump(hidden, hidden_file)

    visible = [b.tolist() for b in visible_biases]
    with open(env_paths.get_rbm_visible_biases_path(), "wb") as visible_file:
        s.dump(visible, visible_file)
def generate_bows(self):
    """
    Run through all steps of the dataprocessing to generate the BOWs for a training set and/or a testset.
    1. Take all serialized stemmed documents and assign them into batches. Each batch should represent an
    equal number of docs from a category, except the last batch.
    2. Calculate the number of words to extract an attributes list corresponding to the X (word_count) most
    used words.
    3. Generate the BOWs for all batches.
    The BOWs will be saved in an output folder of the project root.

    Progress and the time spent on each step are printed. The method returns early
    (without a value) when reading the documents fails.
    """
    print("Data Processing Started")
    timer = time()
    completed = self.__read_docs_from_filesystem()
    if not completed:
        print("Dataprocessing ended with an error.")
        return
    # "%s %s" reproduces the output of the former two-item print statement.
    print("%s %s" % ("Time ", time() - timer))

    print("Filtering Words")
    timer = time()
    # Add all text of docs as a tokenized list
    if self.trainingset_attributes is None:
        attributes = self.__set_attributes()
    else:
        attributes = self.trainingset_attributes
        # The context manager flushes and closes the file deterministically.
        with open(env_paths.get_attributes_path(self.training), "wb") as attributes_file:
            s.dump(attributes, attributes_file)
    print("%s %s" % ("Time ", time() - timer))

    print("Generate bag of words matrix")
    timer = time()
    # Generate a dictionary for lookup of the words
    index_lookup = dict((word, column) for column, word in enumerate(attributes))
    # Generate word matrix
    self.__generate_word_matrix(index_lookup)
    print("%s %s" % ("Time ", time() - timer))
def generate_bows(self):
    """
    Run through all steps of the dataprocessing to generate the BOWs for a training set and/or a testset.
    1. Take all serialized stemmed documents and assign them into batches. Each batch should represent an
    equal number of docs from a category, except the last batch.
    2. Calculate the number of words to extract an attributes list corresponding to the X (word_count) most
    used words.
    3. Generate the BOWs for all batches.
    The BOWs will be saved in an output folder of the project root.

    Progress and the time spent on each step are printed. The method returns early
    (without a value) when reading the documents fails.
    """
    print('Data Processing Started')
    timer = time()
    completed = self.__read_docs_from_filesystem()
    if not completed:
        print('Dataprocessing ended with an error.')
        return
    # '%s %s' reproduces the output of the former two-item print statement.
    print('%s %s' % ('Time ', time() - timer))

    print('Filtering Words')
    timer = time()
    # Add all text of docs as a tokenized list
    if self.trainingset_attributes is None:
        attributes = self.__set_attributes()
    else:
        attributes = self.trainingset_attributes
        # The context manager flushes and closes the file deterministically.
        with open(env_paths.get_attributes_path(self.training), "wb") as attributes_file:
            s.dump(attributes, attributes_file)
    print('%s %s' % ('Time ', time() - timer))

    print('Generate bag of words matrix')
    timer = time()
    # Generate a dictionary for lookup of the words
    index_lookup = dict((word, column) for column, word in enumerate(attributes))
    # Generate word matrix
    self.__generate_word_matrix(index_lookup)
    print('%s %s' % ('Time ', time() - timer))
def save_rbm_weights(weight_matrices, hidden_biases, visible_biases):
    """
    Save the weight matrices and the biases from the rbm pretraining.

    Each collection is converted to nested lists and serialized to its own file.

    @param weight_matrices: the weight matrices of the rbm pretraining.
    @param hidden_biases: the hidden biases of the rbm pretraining.
    @param visible_biases: the visible biases of the rbm pretraining.
    """
    # Each collection is converted before its file is opened, and every file is
    # closed by its context manager instead of being left to the garbage collector.
    weights = [w.tolist() for w in weight_matrices]
    with open(env_paths.get_rbm_weights_path(), "wb") as weights_file:
        s.dump(weights, weights_file)

    hidden = [b.tolist() for b in hidden_biases]
    with open(env_paths.get_rbm_hidden_biases_path(), "wb") as hidden_file:
        s.dump(hidden, hidden_file)

    visible = [b.tolist() for b in visible_biases]
    with open(env_paths.get_rbm_visible_biases_path(), "wb") as visible_file:
        s.dump(visible, visible_file)
def __save_batch_loading_docs(self, batch_number, docs_list, docs_names, class_indices):
    """
    Save batches for the document loading process in the initialization phase. This is done due to vast sizes
    of data - lack of memory.

    @param batch_number: Representing the number of documents in the batch.
    @param docs_list: List containing a string for each document in the batch.
    @param docs_names: List containing the names of each document in the same order as the docs_list.
    @param class_indices: List containing which class/folder each document belongs to.
    """
    # Serialize all relevant variables. Each file is closed by its context manager
    # instead of being left to the garbage collector.
    with open(env_paths.get_doc_list_path(self.training, batch_number), "wb") as docs_file:
        s.dump(docs_list, docs_file)
    with open(env_paths.get_doc_names_path(self.training, batch_number), "wb") as names_file:
        s.dump(docs_names, names_file)
    with open(env_paths.get_class_indices_path(self.training, batch_number), "wb") as indices_file:
        s.dump(class_indices, indices_file)
def save_dbn(weight_matrices, fine_tuning_error_train, fine_tuning_error_test, output=None):
    """
    Save the deep belief network into serialized files.

    @param weight_matrices: the weight matrices of the deep belief network.
    @param fine_tuning_error_train: the fine tuning error on the training set, serialized as is.
    @param fine_tuning_error_test: the fine tuning error on the test set, serialized as is.
    @param output: optional iterable of strings; each one is written as a line of the
    DBN output text file. Nothing is written when it is None.
    """
    # Every file is closed by its context manager, also when a write fails.
    weights = [w.tolist() for w in weight_matrices]
    with open(env_paths.get_dbn_weight_path(), "wb") as weights_file:
        s.dump(weights, weights_file)
    with open(env_paths.get_dbn_training_error_path(), "wb") as train_error_file:
        s.dump(fine_tuning_error_train, train_error_file)
    with open(env_paths.get_dbn_test_error_path(), "wb") as test_error_file:
        s.dump(fine_tuning_error_test, test_error_file)

    if output is not None:
        with open(env_paths.get_dbn_output_txt_path(), "w") as out:
            for elem in output:
                out.write(elem + "\n")
def train(model_path, _log, _run, stopwords_path=None, train_algo='iis', cutoff=4, sigma=0., trim_length=10):
    """Fit a maximum entropy summarizer and write it to ``model_path``.

    Stopwords are read via ``read_stopwords()`` only when ``stopwords_path``
    is given; otherwise ``None`` is passed to the trainer. ``train_algo`` is
    forwarded as the ``algorithm`` argument. The model file is also added to
    ``_run`` as an artifact when ``SAVE_FILES`` is truthy.
    """
    train_docs = list(read_train_jsonl())

    if stopwords_path is None:
        stopword_list = None
    else:
        stopword_list = read_stopwords()

    hyperparams = dict(
        stopwords=stopword_list,
        algorithm=train_algo,
        cutoff=cutoff,
        sigma=sigma,
        trim_length=trim_length,
    )
    model = MaxentSummarizer.train(train_docs, **hyperparams)

    _log.info('Saving model to %s', model_path)
    with open(model_path, 'w') as sink:
        print(dump(model), file=sink)

    if not SAVE_FILES:
        return
    _run.add_artifact(model_path)
def save_dbn(weight_matrices, fine_tuning_error_train, fine_tuning_error_test, output=None):
    """
    Save the deep belief network into serialized files.

    @param weight_matrices: the weight matrices of the deep belief network.
    @param fine_tuning_error_train: the fine tuning error on the training set, serialized as is.
    @param fine_tuning_error_test: the fine tuning error on the test set, serialized as is.
    @param output: optional iterable of strings; each one is written as a line of the
    DBN output text file. Nothing is written when it is None.
    """
    # Every file is closed by its context manager, also when a write fails.
    weights = [w.tolist() for w in weight_matrices]
    with open(env_paths.get_dbn_weight_path(), "wb") as weights_file:
        s.dump(weights, weights_file)
    with open(env_paths.get_dbn_training_error_path(), "wb") as train_error_file:
        s.dump(fine_tuning_error_train, train_error_file)
    with open(env_paths.get_dbn_test_error_path(), "wb") as test_error_file:
        s.dump(fine_tuning_error_test, test_error_file)

    if output is not None:
        with open(env_paths.get_dbn_output_txt_path(), "w") as out:
            for elem in output:
                out.write(elem + "\n")
def finetune(
    corpus,
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    load_samples_from=None,
    overwrite=False,
    load_src=None,
    src_key_as_lang=False,
    main_src=None,
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    save_samples=False,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPTX.

    The function (1) reads or loads the samples, (2) runs every source model
    over the train and dev samples to compute an ambiguous-arcs mask
    (``pptx_mask``) per sample, OR-ing the masks of all sources together,
    (3) saves the vocabulary and model metadata of the main source, and
    (4) finetunes the main source model on the masked training samples,
    evaluating on dev after every epoch and on test after the last one.

    Args:
        corpus: Mapping read for its ``"lang"`` entry (only when
            ``src_key_as_lang`` is true).
        max_length: Mapping from split name to the ``max_length`` passed to
            ``read_samples``; missing splits get ``None``.
        artifacts_dir: Directory to create and save artifacts under.
        load_samples_from: Path to pickled samples. When given, mask
            computation is skipped and the loaded train/dev samples must
            already carry a ``pptx_mask``.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        load_src: Mapping from source name to a
            ``(directory, parameters filename)`` pair. Defaults to a single
            source named ``"src"``.
        src_key_as_lang: If true, a non-main source whose name equals
            ``corpus["lang"]`` is removed from the sources.
        main_src: Key of ``load_src`` naming the source that is finetuned;
            it is always processed last.
        device: Device the model and batches are moved to.
        word_emb_path: Path to word embeddings in word2vec format.
        freeze: If true, the word and tag embeddings do not require grads.
        thresh, projective, multiroot: Forwarded to the mask, statistics and
            loss computations.
        batch_size: Batch size for mask computation and finetuning.
        save_samples: If true, pickle the samples under ``artifacts_dir``.
        lr: Learning rate of the Adam optimizer.
        l2_coef: Weight of the L2 loss term in the total loss.
        max_epoch: Number of epochs to finetune for.

    Returns:
        ``finetuner.state["dev_accs"]["las_nopunct"]`` when finetuning
        finishes; ``None`` when it is interrupted with ``KeyboardInterrupt``.

    Raises:
        ValueError: If ``load_src`` is given and ``main_src`` is not one of
            its keys.
    """
    if max_length is None:
        max_length = {}
    if load_src is None:
        load_src = {"src": ("artifacts", "model.pth")}
        main_src = "src"
    elif main_src not in load_src:
        raise ValueError(f"{main_src} not found in load_src")

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # With overwrite false this raises if the directory already exists.
    artifacts_dir.mkdir(exist_ok=overwrite)

    if load_samples_from:
        _log.info("Loading samples from %s", load_samples_from)
        with open(load_samples_from, "rb") as f:
            samples = pickle.load(f)
    else:
        samples = {
            wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
            for wh in ["train", "dev", "test"]
        }

    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    kv = KeyedVectors.load_word2vec_format(word_emb_path)

    if load_samples_from:
        _log.info(
            "Skipping non-main src because samples are processed and loaded")
        srcs = []
    else:
        srcs = [src for src in load_src if src != main_src]
        if src_key_as_lang and corpus["lang"] in srcs:
            _log.info("Removing %s from src parsers because it's the tgt", corpus["lang"])
            srcs.remove(corpus["lang"])
    # The main source goes last so that `vocab` and `model` still refer to it
    # once the loop below has finished.
    srcs.append(main_src)

    for src_i, src in enumerate(srcs):
        _log.info("Processing src %s [%d/%d]", src, src_i + 1, len(srcs))
        load_from, load_params = load_src[src]
        path = Path(load_from) / "vocab.yml"
        _log.info("Loading %s vocabulary from %s", src, path)
        vocab = load(path.read_text(encoding="utf8"))
        for name in vocab:
            _log.info("Found %d %s", len(vocab[name]), name)

        _log.info("Extending %s vocabulary with target words", src)
        vocab.extend(chain(*samples.values()), ["words"])
        _log.info("Found %d words now", len(vocab["words"]))

        # Samples converted with this source's vocabulary; `samples` itself is
        # left untouched here.
        samples_ = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

        path = Path(load_from) / "model.yml"
        _log.info("Loading %s model from metadata %s", src, path)
        model = load(path.read_text(encoding="utf8"))

        path = Path(load_from) / load_params
        _log.info("Loading %s model parameters from %s", src, path)
        model.load_state_dict(torch.load(path, "cpu"))

        _log.info("Creating %s extended word embedding layer", src)
        assert model.word_emb.embedding_dim == kv.vector_size
        with torch.no_grad():
            model.word_emb = torch.nn.Embedding.from_pretrained(
                extend_word_embedding(model.word_emb.weight, vocab["words"], kv))
        model.to(device)

        for wh in ["train", "dev"]:
            if load_samples_from:
                # Loaded samples must already carry their masks.
                assert all("pptx_mask" in s for s in samples[wh])
                continue

            # Remember each sample's position so that the masks, which are
            # collected batch by batch, can be assigned back to their samples.
            for i, s in enumerate(samples_[wh]):
                s["_id"] = i

            runner = Runner()
            runner.state.update({"pptx_masks": [], "_ids": []})
            runner.on(
                Event.BATCH,
                [
                    batch2tensors(device, vocab),
                    set_train_mode(model, training=False),
                    compute_total_arc_type_scores(model, vocab),
                ],
            )

            @runner.on(Event.BATCH)
            def compute_pptx_ambiguous_arcs_mask(state):
                # Collect the mask and sample id of every item in the batch.
                assert state["batch"]["mask"].all()
                scores = state["total_arc_type_scores"]
                pptx_mask = compute_ambiguous_arcs_mask(
                    scores, thresh, projective, multiroot)
                state["pptx_masks"].extend(pptx_mask)
                state["_ids"].extend(state["batch"]["_id"].tolist())
                state["n_items"] = state["batch"]["words"].numel()

            n_toks = sum(len(s["words"]) for s in samples_[wh])
            ProgressBar(total=n_toks, unit="tok").attach_on(runner)

            _log.info(
                "Computing PPTX ambiguous arcs mask for %s set with source %s", wh, src)
            with torch.no_grad():
                runner.run(
                    BucketIterator(samples_[wh], lambda s: len(s["words"]), batch_size))

            assert len(runner.state["pptx_masks"]) == len(samples_[wh])
            assert len(runner.state["_ids"]) == len(samples_[wh])
            for i, pptx_mask in zip(runner.state["_ids"], runner.state["pptx_masks"]):
                samples_[wh][i]["pptx_mask"] = pptx_mask.tolist()

            _log.info("Computing (log) number of trees stats on %s set", wh)
            report_log_ntrees_stats(samples_[wh], "pptx_mask", batch_size, projective, multiroot)

            _log.info("Combining the ambiguous arcs mask")
            assert len(samples_[wh]) == len(samples[wh])
            for i in range(len(samples_[wh])):
                pptx_mask = torch.tensor(samples_[wh][i]["pptx_mask"])
                assert pptx_mask.dim() == 3
                if "pptx_mask" in samples[wh][i]:
                    old_mask = torch.tensor(samples[wh][i]["pptx_mask"])
                else:
                    # First source for this sample: an all-False mask that
                    # broadcasts against the new one.
                    old_mask = torch.zeros(1, 1, 1).bool()
                # Element-wise OR of the masks of all sources processed so far.
                samples[wh][i]["pptx_mask"] = (old_mask | pptx_mask).tolist()

    # `src`, `vocab` and `model` are those of the last processed source.
    assert src == main_src
    _log.info("Main source is %s", src)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    if save_samples:
        path = artifacts_dir / "samples.pkl"
        _log.info("Saving samples to %s", path)
        with open(path, "wb") as f:
            pickle.dump(samples, f)

    # From here on the samples are converted with the main source's vocabulary.
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    for wh in ["train", "dev"]:
        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "pptx_mask", batch_size, projective, multiroot)

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    finetuner = Runner()

    # Detached copies of the parameters before finetuning, handed to
    # compute_l2_loss below.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        # Total loss = PPTX loss divided by mask.size(0), plus the weighted L2 loss.
        mask = state["batch"]["mask"]
        pptx_mask = state["batch"]["pptx_mask"].bool()
        scores = state["total_arc_type_scores"]

        pptx_loss = compute_aatrn_loss(scores, pptx_mask, mask, projective, multiroot)
        pptx_loss /= mask.size(0)
        loss = pptx_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "pptx_loss": pptx_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH, [update_params(opt), log_grads(_run, model), log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        pptx_loss = eval_state["mean_pptx_loss"]
        _log.info("dev_pptx_loss: %.4f", pptx_loss)
        _run.log_scalar("dev_pptx_loss", pptx_loss, step=state["n_iters"])

        # Kept in the state because the function's return value is read from it.
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # The test set is only evaluated after the last epoch.
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"], compute_loss=False)
        print_accs(eval_state["counts"].accs, on="test", run=_run, step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED, save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Sentences of 1-10 words share a bucket, 11-20 the next one, and so on.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        # Nothing is returned in this case.
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
def save_large_batch(batch, data):
    """
    Serialize the data of a large batch.

    @param batch: The batch identifier used to look up the target path.
    @param data: The batch data. Must provide tolist().
    """
    # Convert first so that a failing conversion does not leave an empty file behind.
    serialized_data = data.tolist()
    # The context manager flushes and closes the file even if dumping fails.
    with open(env_paths.get_dbn_large_batch_data_path(batch), 'wb') as batch_file:
        s.dump(serialized_data, batch_file)
def save_large_batches_lst(lst):
    """
    Serialize the list of large batches.

    @param lst: The object to serialize to the DBN batches list path.
    """
    # The context manager flushes and closes the file even if dumping fails.
    with open(env_paths.get_dbn_batches_lst_path(), 'wb') as batches_file:
        s.dump(lst, batches_file)
def __read_docs_from_filesystem(self):
    """
    Read all docs and assign them to batches, so that each doc category is represented equally across batches.

    The serialized ('.p') documents of every folder in self.paths are listed,
    distributed over the batches, shuffled within each batch and finally loaded and
    saved batch by batch together with their names and class indices. The class
    names and the list of batch ids are serialized as well.

    @return: 1 on success, 0 when there are no folders to read or fewer documents than
    self.batchsize.
    """
    docs_names = []
    docs_names_split = []
    class_indices = []
    class_indices_split = []
    class_names = []
    batches = []

    print('Generating class indices and docs names list.')
    doc_count = 0
    for folder in self.paths:
        docs_names_split.append([])
        class_indices_split.append([])
        class_names.append(folder.split('/')[-1])

        # List the folder once so that the slice below is taken from a single snapshot.
        entries = os.listdir(folder)
        if self.trainingset_size is None:
            # If data processing should be done on all data in the specified folders.
            docs = entries
        else:
            split_at = int(len(entries) * self.trainingset_size)
            if self.trainingset_attributes is None:
                # If data processing should be done on parts of the docs in the specified
                # folders - for training and testing purposes.
                docs = entries[:split_at]
            else:
                # If data processing should be done on a test set.
                docs = entries[split_at:]

        for doc in docs:
            if doc.endswith('.p'):
                # Append the name of the document to the list containing document names.
                docs_names_split[-1].append(folder + '/' + doc)
                class_indices_split[-1].append(len(class_names) - 1)
                doc_count += 1

    if len(docs_names_split) == 0:  # Check if docs have been stemmed.
        print('Documents have not been stemmed. Please stem documents in order to create bag of words matrices.')
        return 0

    # Ensure that batches contain an even amount of docs from each category.
    print('Arranging the documents.')
    if doc_count < self.batchsize:
        print('Number of documents must be greater than batchsize. Please revise the batchsize.')
        return 0
    number_of_batches = doc_count // self.batchsize
    number_of_classes = len(self.paths)
    batches_collected_class_indices = []
    batches_collected_docs_names = []

    # Calculate fraction of category in each batch. A list indexed by class keeps the
    # quotas aligned with class_indices_split without relying on dict ordering.
    per_batch_quota = [int(float(len(indices)) / number_of_batches) for indices in class_indices_split]

    count = 0
    for i in range(number_of_batches):
        batch_class_indices = []
        batch_docs_names = []
        d_tmp = array(per_batch_quota)
        while True:
            if (len(batch_class_indices) == self.batchsize) and (not doc_count - count < self.batchsize) or (
                    count == doc_count):
                break
            if len(d_tmp[d_tmp > 0]) == 0:
                break
            for j in range(number_of_classes):
                if (len(batch_class_indices) == self.batchsize) and (not doc_count - count < self.batchsize) or (
                        count == doc_count):
                    break
                if len(class_indices_split[j]) > 0 and d_tmp[j] != 0:
                    batch_class_indices.append(class_indices_split[j].pop(0))
                    batch_docs_names.append(docs_names_split[j].pop(0))
                    d_tmp[j] -= 1
                    count += 1
        batches_collected_class_indices.append(batch_class_indices)
        batches_collected_docs_names.append(batch_docs_names)

    for i in range(number_of_batches):
        # The last batch also takes the documents that do not fill a whole batch.
        bsize = self.batchsize if i < number_of_batches - 1 else self.batchsize + (doc_count % self.batchsize)
        batch_class_indices = batches_collected_class_indices[i]
        batch_docs_names = batches_collected_docs_names[i]
        if len(batch_class_indices) < bsize:
            # Top the batch up with the documents that are still left over.
            while True:
                if len(batch_class_indices) == bsize:
                    break
                for j in range(number_of_classes):
                    if len(batch_class_indices) == bsize:
                        break
                    if len(class_indices_split[j]) > 0:
                        batch_class_indices.append(class_indices_split[j].pop(0))
                        batch_docs_names.append(docs_names_split[j].pop(0))

        # Shuffle the batch
        batch_class_indices_shuf = []
        batch_docs_names_shuf = []
        index_shuf = list(range(len(batch_class_indices)))
        shuffle(index_shuf)
        for k in index_shuf:
            batch_class_indices_shuf.append(batch_class_indices[k])
            batch_docs_names_shuf.append(batch_docs_names[k])

        # Append batch to full lists
        class_indices += batch_class_indices_shuf
        docs_names += batch_docs_names_shuf

    print('Reading and saving docs from file system')
    count = 0
    class_indices_batch = []
    docs_names_batch = []
    docs_list = []
    for doc_name, class_index in zip(docs_names, class_indices):
        if not count == 0 and (count % self.batchsize) == 0:
            # Save the batch if batchsize is reached, unless the remaining documents
            # do not fill another batch: they are added to the current one.
            if not (len(class_indices) - count) < self.batchsize:
                # '%s %s %s %s' reproduces the output of the former four-item print statement.
                print('%s %s %s %s' % ('Read ', str(count), ' of ', len(class_indices)))
                self.__save_batch_loading_docs(count, docs_list, docs_names_batch, class_indices_batch)
                batches.append(count)
                # Reset the lists
                docs_list = []
                docs_names_batch = []
                class_indices_batch = []

        # Close every document file right away; one handle per document used to be
        # left open until garbage collection.
        with open(doc_name, 'rb') as doc_file:
            docs_list.append(s.load(doc_file))
        docs_names_batch.append(doc_name)
        class_indices_batch.append(class_index)
        count += 1

    # Save the remaining docs
    if len(docs_list) > 0:
        print('%s %s %s %s' % ('Read ', str(count), ' of ', len(class_indices)))
        self.__save_batch_loading_docs(count, docs_list, docs_names_batch, class_indices_batch)
        batches.append(count)

    with open(env_paths.get_class_names_path(self.training), "wb") as class_names_file:
        s.dump(class_names, class_names_file)
    with open(env_paths.get_batches_path(self.training), "wb") as batches_file:
        s.dump(batches, batches_file)
    return 1
def make_model(
    vocab,
    _log,
    word_emb_path="wiki.en.vec",
    artifacts_dir="artifacts",
    tag_size=50,
    n_heads=10,
    n_layers=6,
    ff_size=2048,
    kv_size=64,
    p_word=0.5,
    p_out=0.5,
    arc_size=128,
    type_size=128,
):
    """Build a SelfAttGraph model with a frozen, pre-trained word embedding.

    Word vectors are read from ``word_emb_path`` in word2vec format. Each
    vocabulary word found in the vectors (verbatim first, then lowercased)
    gets its pre-trained vector; every other row keeps its random normal
    initialisation. The model metadata is written to
    ``<artifacts_dir>/model.yml`` and the model is returned.

    Args:
        vocab: Mapping providing the "words", "types" and "tags" entries.
        _log: Logger used for progress messages (only ``.info`` is called).
        word_emb_path: Path of the word2vec-format vectors file.
        artifacts_dir: Directory receiving ``model.yml``.
        tag_size, n_heads, n_layers, ff_size, kv_size, arc_size, type_size:
            Forwarded unchanged to ``SelfAttGraph``.
        p_word: Forwarded to ``SelfAttGraph`` as ``word_dropout``.
        p_out: Forwarded to ``SelfAttGraph`` as ``outdim_dropout``.

    Returns:
        The constructed model.
    """
    pretrained = KeyedVectors.load_word2vec_format(word_emb_path)
    words = vocab["words"]
    emb_dim = pretrained.vector_size

    _log.info("Creating model")
    hyperparams = dict(
        word_size=emb_dim,
        tag_size=tag_size,
        n_heads=n_heads,
        n_layers=n_layers,
        ff_size=ff_size,
        kv_size=kv_size,
        word_dropout=p_word,
        outdim_dropout=p_out,
        arc_size=arc_size,
        type_size=type_size,
    )
    model = SelfAttGraph(len(words), len(vocab["types"]), len(vocab["tags"]), **hyperparams)
    n_params = 0
    for param in model.parameters():
        n_params += param.numel()
    _log.info("Model created with %d parameters", n_params)

    # Start from random rows, then overwrite the rows of words that have a vector.
    emb_weight = torch.randn(len(words), emb_dim)
    n_pretrained = 0
    n_unknown = 0
    for word in words:
        row = words.index(word)
        if word in pretrained:
            key = word
        else:
            lowered = word.lower()
            key = lowered if lowered in pretrained else None

        if key is None:
            n_unknown += 1
            continue
        emb_weight[row] = torch.from_numpy(pretrained[key])
        n_pretrained += 1

    # freeze embedding to preserve alignment
    with torch.no_grad():
        frozen_emb = torch.nn.Embedding.from_pretrained(emb_weight, freeze=True)
        model.word_emb = frozen_emb
    _log.info("Initialized %d words with pre-trained embedding", n_pretrained)
    _log.info("Found %d unknown words", n_unknown)

    meta_path = Path(artifacts_dir).joinpath("model.yml")
    _log.info("Saving model metadata to %s", meta_path)
    meta_path.write_text(dump(model), encoding="utf8")

    return model
def __read_docs_from_filesystem(self): """ Read all docs and assign them to batches, so that each doc category is represented equally across batches. """ docs_names = [] docs_names_split = [] class_indices = [] class_indices_split = [] class_names = [] batches = [] print "Generating class indices and docs names list." doc_count = 0 for folder in self.paths: docs_names_split.append([]) class_indices_split.append([]) class_names.append(folder.split("/")[len(folder.split("/")) - 1]) if self.trainingset_size == None: # If data processing should be done on all data in the specified folders. docs = os.listdir(folder) elif ( not self.trainingset_size == None and self.trainingset_attributes == None ): # If data processing should be done on parts of the docs in the specified folders - for training and testing purposes. docs = os.listdir(folder)[: int(len(os.listdir(folder)) * self.trainingset_size)] else: # If data processing should be done on a test set. docs = os.listdir(folder)[int(len(os.listdir(folder)) * self.trainingset_size) :] for doc in docs: if doc.endswith(".p"): # Append the name of the document to the list containing document names. docs_names_split[-1].append(folder + "/" + doc) class_indices_split[-1].append(len(class_names) - 1) doc_count += 1 if len(docs_names_split) == 0: # Check if docs have been stemmed. print "Documents have not been stemmed. Please stem documents in order to create bag of words matrices." return 0 # Ensure that batches contain an even amount of docs from each category. print "Arranging the documents." if doc_count < self.batchsize: print "Number of documents must be greater than batchsize. Please revise the batchsize." return 0 number_of_batches = doc_count / self.batchsize number_of_classes = len(self.paths) batches_collected_class_indices = [] batches_collected_docs_names = [] # Calculate fraction of category in each batch. 
d = {} for i in range(len(class_indices_split)): d[i] = float(len(class_indices_split[i])) / number_of_batches count = 0 for i in range(number_of_batches): batch_class_indices = [] batch_docs_names = [] d_tmp = array([int(v) for v in d.values()]) while True: if ( (len(batch_class_indices) == self.batchsize) and (not doc_count - count < self.batchsize) or (count == doc_count) ): break if len(d_tmp[d_tmp > 0]) == 0: break for j in range(number_of_classes): if ( (len(batch_class_indices) == self.batchsize) and (not doc_count - count < self.batchsize) or (count == doc_count) ): break if len(class_indices_split[j]) > 0 and d_tmp[j] != 0: batch_class_indices.append(class_indices_split[j].pop(0)) batch_docs_names.append(docs_names_split[j].pop(0)) d_tmp[j] -= 1 count += 1 batches_collected_class_indices.append(batch_class_indices) batches_collected_docs_names.append(batch_docs_names) for i in range(number_of_batches): bsize = self.batchsize if i < number_of_batches - 1 else self.batchsize + (doc_count % self.batchsize) batch_class_indices = batches_collected_class_indices[i] batch_docs_names = batches_collected_docs_names[i] if len(batch_class_indices) < bsize: while True: if len(batch_class_indices) == bsize: break for j in range(number_of_classes): if len(batch_class_indices) == bsize: break if len(class_indices_split[j]) > 0: batch_class_indices.append(class_indices_split[j].pop(0)) batch_docs_names.append(docs_names_split[j].pop(0)) # Shuffle the batch batch_class_indices_shuf = [] batch_docs_names_shuf = [] index_shuf = range(len(batch_class_indices)) shuffle(index_shuf) for k in index_shuf: batch_class_indices_shuf.append(batch_class_indices[k]) batch_docs_names_shuf.append(batch_docs_names[k]) # Append batch to full lists class_indices += batch_class_indices_shuf docs_names += batch_docs_names_shuf print "Reading and saving docs from file system" count = 0 class_indices_batch = [] docs_names_batch = [] docs_list = [] for i in xrange(len(class_indices)): if ( not 
count == 0 and (count % self.batchsize) == 0 ): # Save the batch if batchsize is reached or if the last document has been read. if not (len(class_indices) - count) < self.batchsize: print "Read ", str(count), " of ", len(class_indices) self.__save_batch_loading_docs(count, docs_list, docs_names_batch, class_indices_batch) batches.append(count) # Reset the lists docs_list = [] docs_names_batch = [] class_indices_batch = [] d = s.load(open(docs_names[i], "rb")) docs_list.append(d) docs_names_batch.append(docs_names[i]) class_indices_batch.append(class_indices[i]) count += 1 # Save the remaining docs if len(docs_list) > 0: print "Read ", str(count), " of ", len(class_indices) self.__save_batch_loading_docs(count, docs_list, docs_names_batch, class_indices_batch) batches.append(count) s.dump(class_names, open(env_paths.get_class_names_path(self.training), "wb")) s.dump(batches, open(env_paths.get_batches_path(self.training), "wb")) return 1
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with self-training.

    The function loads the vocabulary, model metadata and parameters from
    ``load_from``, extends the word embedding with the words of the new
    samples, predicts heads/types for the training samples with the loaded
    model, and then trains on those predictions plus an L2 term.

    Args:
        _log: Logger used for progress messages (only ``.info`` is called).
        _run: Passed to ``print_accs``, ``log_grads`` and ``log_stats``.
        _rnd: Passed as ``rng`` to the training iterators.
        max_length: Optional mapping from split name ("train"/"dev"/"test")
            to the ``max_length`` given to ``read_samples``.
        artifacts_dir: Directory created for the vocabulary, model metadata
            and saved parameters.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        load_from: Directory holding ``vocab.yml``, ``model.yml`` and
            ``load_params``.
        load_params: File name of the parameters inside ``load_from``.
        device: Device the model and the batches are moved to.
        word_emb_path: word2vec-format vectors used to extend the embedding.
        freeze: When true, ``word_emb`` and ``tag_emb`` do not require grad.
        projective, multiroot: Passed to ``predict_batch``.
        batch_size: Batch size of both the prediction and the training run.
        lr: Learning rate of the Adam optimizer.
        l2_coef: Weight of ``state["l2_loss"]`` in the training loss.
        max_epoch: Number of epochs passed to the finetuning run; the test
            evaluation only runs when ``state["epoch"]`` equals this value.

    Returns:
        ``finetuner.state["dev_accs"]["las_nopunct"]`` when the run
        completes; ``None`` (implicitly) after a ``KeyboardInterrupt``.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # With overwrite=False this raises if the directory already exists.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    # Only the "words" field is extended; the other vocabulary fields are left as loaded.
    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    # presumably converts string fields to vocabulary indices — TODO confirm against Vocab.stoi
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    # Parameters are loaded onto CPU first; the model is moved to `device` further below.
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    # Metadata is saved after the embedding replacement so it describes the extended model.
    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    # Predict self-training (ST) heads and types for the training samples.
    for wh in ["train"]:
        # Tag every sample with its position so predictions can be mapped back after batching.
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"st_heads": [], "st_types": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
                predict_batch(projective, multiroot),
            ],
        )

        # Registered after the handlers above; presumably runs after them for each batch — TODO confirm.
        @runner.on(Event.BATCH)
        def save_st_trees(state):
            state["st_heads"].extend(state["pred_heads"].tolist())
            state["st_types"].extend(state["pred_types"].tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing ST trees for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]), batch_size))

        assert len(runner.state["st_heads"]) == len(samples[wh])
        assert len(runner.state["st_types"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        # Store each prediction on the sample it came from, using the recorded "_id".
        for i, st_heads, st_types in zip(runner.state["_ids"], runner.state["st_heads"], runner.state["st_types"]):
            assert len(samples[wh][i]["words"]) == len(st_heads)
            assert len(samples[wh][i]["words"]) == len(st_types)
            samples[wh][i]["st_heads"] = st_heads
            samples[wh][i]["st_types"] = st_types

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    finetuner = Runner()
    # Detached copies of the parameters before finetuning, handed to compute_l2_loss.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        # Training targets are the self-training predictions stored on the samples.
        bat = state["batch"]
        words, tags, heads, types = bat["words"], bat["tags"], bat["st_heads"], bat["st_types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), -1e9)  # mask padding heads
        # Give the PAD type a very low score.
        type_scores[..., vocab["types"].index(vocab.PAD_TOKEN)] = -1e9

        # remove root (index 0 along the sliced axis)
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        # One row per (sentence, slen2 position) with scores over slen1, matching the flattened heads.
        arc_scores = rearrange(arc_scores, "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores, heads, reduction="none")

        type_scores = rearrange(type_scores, "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores, types, reduction="none")

        # Average only over positions where the mask is true.
        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "arc_ppl": arc_loss.exp().item(),
            "type_ppl": type_loss.exp().item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {
            "arc_loss": arc_loss.item(),
            "type_loss": type_loss.item()
        }

    # Registered after compute_loss, which sets state["loss"] — presumably consumed by update_params (TODO confirm).
    finetuner.on(
        Event.BATCH,
        [
            get_n_items(),
            update_params(opt),
            log_grads(_run, model),
            log_stats(_run)
        ],
    )

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        # Kept in the state; the function's return value is read from here.
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # Evaluate on test only when state["epoch"] equals max_epoch.
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        print_accs(eval_state["counts"].accs, on="test", run=_run, step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED, save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Bucket key groups samples by length in steps of 10 words.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        # Falls through and returns None.
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPT.

    The function loads the vocabulary, model metadata and parameters from
    ``load_from``, extends the word embedding with the words of the new
    samples, computes an ambiguous-arcs mask ("ppt_mask") for the train and
    dev samples with the loaded model, and then trains on the loss returned
    by ``compute_aatrn_loss`` plus an L2 term.

    Args:
        _log: Logger used for progress messages (only ``.info`` is called).
        _run: Passed to ``print_accs``, ``log_grads`` and ``log_stats``;
            its ``log_scalar`` records the dev PPT loss.
        _rnd: Passed as ``rng`` to the training iterators.
        max_length: Optional mapping from split name ("train"/"dev"/"test")
            to the ``max_length`` given to ``read_samples``.
        artifacts_dir: Directory created for the vocabulary, model metadata
            and saved parameters.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        load_from: Directory holding ``vocab.yml``, ``model.yml`` and
            ``load_params``.
        load_params: File name of the parameters inside ``load_from``.
        device: Device the model and the batches are moved to.
        word_emb_path: word2vec-format vectors used to extend the embedding.
        freeze: When true, ``word_emb`` and ``tag_emb`` do not require grad.
        thresh: Passed to ``compute_ambiguous_arcs_mask``.
        projective, multiroot: Passed to ``compute_ambiguous_arcs_mask``,
            ``report_log_ntrees_stats`` and ``compute_aatrn_loss``.
        batch_size: Batch size of the mask computation and the training run.
        lr: Learning rate of the Adam optimizer.
        l2_coef: Weight of ``state["l2_loss"]`` in the training loss.
        max_epoch: Number of epochs passed to the finetuning run; the test
            evaluation only runs when ``state["epoch"]`` equals this value.

    Returns:
        ``finetuner.state["dev_accs"]["las_nopunct"]`` when the run
        completes; ``None`` (implicitly) after a ``KeyboardInterrupt``.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # With overwrite=False this raises if the directory already exists.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    # Only the "words" field is extended; the other vocabulary fields are left as loaded.
    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    # presumably converts string fields to vocabulary indices — TODO confirm against Vocab.stoi
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    # Parameters are loaded onto CPU first; the model is moved to `device` further below.
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    # Metadata is saved after the embedding replacement so it describes the extended model.
    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    # Compute the PPT ambiguous-arcs mask for both the train and the dev samples.
    for wh in ["train", "dev"]:
        # Tag every sample with its position so masks can be mapped back after batching.
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"ppt_masks": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
            ],
        )

        # Registered after the handlers above; presumably runs after them for each batch — TODO confirm.
        @runner.on(Event.BATCH)
        def compute_ppt_ambiguous_arcs_mask(state):
            # Requires every mask entry to be true; presumably holds because the iterator below
            # buckets samples by exact length, so batches need no padding — TODO confirm.
            assert state["batch"]["mask"].all()
            scores = state["total_arc_type_scores"]
            ppt_mask = compute_ambiguous_arcs_mask(scores, thresh, projective, multiroot)
            state["ppt_masks"].extend(ppt_mask.tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing PPT ambiguous arcs mask for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]), batch_size))

        assert len(runner.state["ppt_masks"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        # Store each mask on the sample it came from, using the recorded "_id".
        for i, ppt_mask in zip(runner.state["_ids"], runner.state["ppt_masks"]):
            samples[wh][i]["ppt_mask"] = ppt_mask

        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "ppt_mask", batch_size, projective, multiroot)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    finetuner = Runner()
    # Detached copies of the parameters before finetuning, handed to compute_l2_loss.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        mask = state["batch"]["mask"]
        ppt_mask = state["batch"]["ppt_mask"].bool()
        scores = state["total_arc_type_scores"]

        ppt_loss = compute_aatrn_loss(scores, ppt_mask, mask, projective, multiroot)
        # Divide by the size of the mask's first dimension (presumably the batch size — TODO confirm).
        ppt_loss /= mask.size(0)
        loss = ppt_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "ppt_loss": ppt_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        # Number of true entries in the mask.
        state["n_items"] = mask.long().sum().item()

    # Registered after compute_loss, which sets state["loss"] — presumably consumed by update_params (TODO confirm).
    finetuner.on(Event.BATCH, [update_params(opt), log_grads(_run, model), log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        ppt_loss = eval_state["mean_ppt_loss"]
        _log.info("dev_ppt_loss: %.4f", ppt_loss)
        _run.log_scalar("dev_ppt_loss", ppt_loss, step=state["n_iters"])

        # Kept in the state; the function's return value is read from here.
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # Evaluate on test only when state["epoch"] equals max_epoch.
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        # The test samples carry no "ppt_mask" (only train and dev get one above), hence compute_loss=False.
        eval_state = run_eval(model, vocab, samples["test"], compute_loss=False)
        print_accs(eval_state["counts"].accs, on="test", run=_run, step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED, save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Bucket key groups samples by length in steps of 10 words.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        # Falls through and returns None.
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
def train(
    _log,
    _run,
    _rnd,
    artifacts_dir="artifacts",
    overwrite=False,
    max_length=None,
    load_types_vocab_from=None,
    batch_size=16,
    device="cpu",
    lr=0.001,
    patience=5,
    max_epoch=1000,
):
    """Train a self-attention graph-based parser.

    The function reads the train/dev/test samples, builds and saves the
    vocabulary, creates the model with ``make_model``, and trains it with
    Adam. After each epoch it evaluates on dev, and evaluates on test and
    saves the parameters only when the dev result improved.

    Args:
        _log: Logger used for progress messages (only ``.info`` is called).
        _run: Passed to ``print_accs``, ``log_grads`` and ``log_stats``.
        _rnd: Passed as ``rng`` to the training iterators.
        artifacts_dir: Directory created for the vocabulary and the saved
            parameters.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        max_length: Optional mapping from split name ("train"/"dev"/"test")
            to the ``max_length`` given to ``read_samples``.
        load_types_vocab_from: Optional path of a saved vocabulary whose
            "types" entry replaces the one built from the samples.
        batch_size: Batch size of the training iterator.
        device: Device the model and the batches are moved to.
        lr: Learning rate of the Adam optimizer.
        patience: Passed to ``maybe_stop_early``.
        max_epoch: Number of epochs passed to the training run.

    Returns:
        ``trainer.state["dev_accs"]["las_nopunct"]`` (the accuracies stored
        at the best dev epoch) when the run completes; ``None`` (implicitly)
        after a ``KeyboardInterrupt``.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # With overwrite=False this raises if the directory already exists.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    # The vocabulary is built from all three splits, not only from train.
    _log.info("Creating vocabulary")
    vocab = Vocab.from_samples(chain(*samples.values()))
    if load_types_vocab_from:
        path = Path(load_types_vocab_from)
        _log.info("Loading types vocab from %s", path)
        vocab["types"] = load(path.read_text(encoding="utf8"))["types"]

    _log.info("Vocabulary created")
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    # presumably converts string fields to vocabulary indices — TODO confirm against Vocab.stoi
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    # Only `vocab` is passed; the remaining make_model arguments are presumably injected by the
    # experiment framework — TODO confirm.
    model = make_model(vocab)
    model.to(device)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    # mode="max": the scheduler is stepped with a dev accuracy, so higher is better.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5)

    trainer = Runner()
    # -1 makes the first dev evaluation count as an improvement for non-negative counts.
    trainer.state.update({"dev_larcs_nopunct": -1, "dev_uarcs_nopunct": -1})
    trainer.on(Event.BATCH, [batch2tensors(device, vocab), set_train_mode(model)])

    @trainer.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags, heads, types = bat["words"], bat["tags"], bat["heads"], bat["types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), -1e9)  # mask padding heads
        # Give the PAD type a very low score.
        type_scores[..., vocab["types"].index(Vocab.PAD_TOKEN)] = -1e9

        # remove root (index 0 along the sliced axis)
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        # One row per (sentence, slen2 position) with scores over slen1, matching the flattened heads.
        arc_scores = rearrange(arc_scores, "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores, heads, reduction="none")

        type_scores = rearrange(type_scores, "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores, types, reduction="none")

        # Average only over positions where the mask is true.
        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss

        state["loss"] = loss
        # From here on arc_loss/type_loss are plain Python floats.
        arc_loss, type_loss = arc_loss.item(), type_loss.item()
        state["stats"] = {
            "arc_ppl": math.exp(arc_loss),
            "type_ppl": math.exp(type_loss),
        }
        state["extra_stats"] = {"arc_loss": arc_loss, "type_loss": type_loss}
        # Uses the unsliced batch mask (position 0 included).
        state["n_items"] = bat["mask"].long().sum().item()

    # Registered after compute_loss, which sets state["loss"] — presumably consumed by update_params (TODO confirm).
    trainer.on(Event.BATCH, [update_params(opt), log_grads(_run, model), log_stats(_run)])

    @trainer.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        scheduler.step(accs["las_nopunct"])

        # Better means more larcs_nopunct than the best so far; ties are broken by uarcs_nopunct.
        if eval_state["counts"].larcs_nopunct > state["dev_larcs_nopunct"]:
            state["better"] = True
        elif eval_state["counts"].larcs_nopunct < state["dev_larcs_nopunct"]:
            state["better"] = False
        elif eval_state["counts"].uarcs_nopunct > state["dev_uarcs_nopunct"]:
            state["better"] = True
        else:
            state["better"] = False

        if state["better"]:
            _log.info("Found new best result on dev!")
            state["dev_larcs_nopunct"] = eval_state["counts"].larcs_nopunct
            state["dev_uarcs_nopunct"] = eval_state["counts"].uarcs_nopunct
            state["dev_accs"] = accs
            state["dev_epoch"] = state["epoch"]
        else:
            # Relies on an earlier epoch having been "better", which is when
            # "dev_epoch", "dev_accs" and "test_accs" get set.
            _log.info("Not better, the best so far is epoch %d:", state["dev_epoch"])
            print_accs(state["dev_accs"])
            print_accs(state["test_accs"], on="test")

    # Reads state["better"], which eval_on_dev sets; presumably runs after it because it is
    # registered later — TODO confirm.
    @trainer.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if not state["better"]:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        state["test_accs"] = eval_state["counts"].accs
        print_accs(state["test_accs"], on="test", run=_run, step=state["n_iters"])

    trainer.on(
        Event.EPOCH_FINISHED,
        [
            maybe_stop_early(patience=patience),
            save_state_dict("model", model, under=artifacts_dir, when="better"),
        ],
    )

    EpochTimer().attach_on(trainer)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(trainer)

    # Bucket key groups samples by length in steps of 10 words.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting training")
    try:
        trainer.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        # Falls through and returns None.
        _log.info("Interrupt detected, training will abort")
    else:
        return trainer.state["dev_accs"]["las_nopunct"]