class CorpusIteratorFuncHead():
    """Wrap a ``CorpusIterator`` and apply ``reverse_content_head`` to each sentence.

    Constructor arguments are forwarded unchanged to ``CorpusIterator``;
    ``permute`` and ``length`` delegate directly to the wrapped corpus.
    """

    def __init__(self, language, partition="train", storeMorph=False, splitLemmas=False):
        self.basis = CorpusIterator(language, partition=partition, storeMorph=storeMorph, splitLemmas=splitLemmas)

    def permute(self):
        """Delegate to the wrapped corpus' ``permute()``."""
        self.basis.permute()

    def length(self):
        """Return the wrapped corpus' ``length()``."""
        return self.basis.length()

    def iterator(self, rejectShortSentences=False):
        """Yield sentences from the wrapped corpus after applying ``reverse_content_head``.

        The return value of ``reverse_content_head`` is ignored and the same
        ``sentence`` object is yielded, i.e. this relies on the helper
        transforming the sentence in place (NOTE(review): confirm against its
        definition).
        """
        iterator = self.basis.iterator(rejectShortSentences=rejectShortSentences)
        for sentence in iterator:
            reverse_content_head(sentence)
            yield sentence

    def getSentence(self, index):
        """Return sentence ``index`` with ``reverse_content_head`` applied.

        Applied exactly as in ``iterator``: the sentence is transformed and then
        returned itself. The previous version returned the helper's return
        value, which is ``None`` if the helper only mutates in place.
        """
        sentence = self.basis.getSentence(index)
        reverse_content_head(sentence)
        return sentence
class CorpusIteratorFuncHeadFraction():
    """Like a ``reverse_content_head``-wrapped ``CorpusIterator``, restricted to a prefix of the data.

    The wrapped corpus is built with ``shuffleDataSeed=4``. Only the first
    ``int(fraction * len(data))`` entries of ``self.basis.data`` are kept, and
    the corpus is then permuted once.
    """

    def __init__(self, language, partition="train", fraction=1.0, storeMorph=False, splitLemmas=False):
        self.basis = CorpusIterator(language, partition=partition, storeMorph=storeMorph, splitLemmas=splitLemmas, shuffleDataSeed=4)
        # Keep a leading fraction of the data before any permutation.
        self.basis.data = self.basis.data[:int(fraction * len(self.basis.data))]
        self.permute()
        self.fraction = fraction

    def permute(self):
        """Delegate to the wrapped corpus' ``permute()``."""
        self.basis.permute()

    def length(self):
        """Return the wrapped corpus' ``length()`` (after truncation)."""
        return self.basis.length()

    def iterator(self, rejectShortSentences=False):
        """Yield sentences from the truncated corpus after applying ``reverse_content_head``.

        Prints the corpus length when iteration begins (this is a generator,
        so the print runs on the first ``next()``). The helper's return value
        is ignored and the same ``sentence`` object is yielded, i.e. this
        relies on in-place transformation (NOTE(review): confirm).
        """
        iterator = self.basis.iterator(rejectShortSentences=rejectShortSentences)
        print("Actual length", self.length())
        for sentence in iterator:
            reverse_content_head(sentence)
            yield sentence

    def getSentence(self, index):
        """Return sentence ``index`` with ``reverse_content_head`` applied.

        Applied exactly as in ``iterator``: the sentence is transformed and then
        returned itself. The previous version returned the helper's return value,
        which is ``None`` if the helper only mutates in place.
        """
        sentence = self.basis.getSentence(index)
        reverse_content_head(sentence)
        return sentence
if wordNum > 0: crossEntropy = 0.99 * crossEntropy + 0.01 * (totalDepLength / wordNum) else: assert totalDepLength == 0 numberOfWords = wordNum return (totalDepLength, numberOfWords, byType) assert batchSize == 1 depLengths = [] if True: corpus = CorpusIterator(args.language, "train") corpusIterator = corpus.iterator() if corpus.length() == 0: quit() while True: try: batch = [next(corpusIterator)] except StopIteration: break partitions = range(1) for partition in partitions: counter += 1 printHere = (counter % 200 == 0) current = batch[partition * batchSize:(partition + 1) * batchSize] if len(current) == 0: continue depLength = doForwardPass(current)