import gc
from os.path import exists

import h5py
import numpy

# binary, log, parameters, training and pathTo are project-local modules;
# extract, buildWordMaps, subsampleAndPrune and inferContexts are helpers
# defined alongside this function.


def trainTextVectors(connector, w2vEmbeddingsPath, wordIndexMapPath, wordFrequencyMapPath, wordEmbeddingsPath,
                     contextsPath, sample, minCount, windowSize, negative, strict, contextsPerText,
                     superBatchSize, fileEmbeddingSize, epochs, learningRate, fileEmbeddingsPath):
    # Reuse cached maps and embeddings when every artifact already exists on disk.
    if exists(wordIndexMapPath) and exists(wordFrequencyMapPath) and exists(wordEmbeddingsPath) \
            and exists(contextsPath) and exists(pathTo.textIndexMap):
        wordIndexMap = parameters.loadMap(wordIndexMapPath)
        wordFrequencyMap = parameters.loadMap(wordFrequencyMapPath)
        wordEmbeddings = parameters.loadEmbeddings(wordEmbeddingsPath)
        textIndexMap = parameters.loadMap(pathTo.textIndexMap)
    else:
        w2vWordIndexMap, w2vWordEmbeddings = parameters.loadW2VParameters(w2vEmbeddingsPath)

        names, texts = extract(connector)
        wordIndexMap, wordFrequencyMap, wordEmbeddings = buildWordMaps(texts, w2vWordIndexMap, w2vWordEmbeddings)

        parameters.dumpWordMap(wordIndexMap, wordIndexMapPath)

        # The raw word2vec parameters are no longer needed; free them before training.
        del w2vWordIndexMap
        del w2vWordEmbeddings
        gc.collect()

        parameters.dumpWordMap(wordFrequencyMap, wordFrequencyMapPath)

        log.progress('Dumping word embeddings...')
        parameters.dumpEmbeddings(wordEmbeddings, wordEmbeddingsPath)
        log.info('Dumped indices, frequencies and embeddings')

        texts = subsampleAndPrune(texts, wordFrequencyMap, sample, minCount)

        textIndexMap = inferContexts(contextsPath, names, texts, wordIndexMap, windowSize, negative, strict,
                                     contextsPerText)

        parameters.dumpWordMap(textIndexMap, pathTo.textIndexMap)

    with h5py.File(contextsPath, 'r') as contextsFile:
        contexts = contextsFile['contexts']
        log.info('Loaded {0} contexts. Shape: {1}', len(contexts), contexts.shape)

        fileEmbeddings = numpy.random.rand(len(contexts), fileEmbeddingSize).astype('float32')
        trainingBatch = numpy.zeros((superBatchSize, contextsPerText, 1 + windowSize + negative)).astype('int32')
        superBatchesCount = len(contexts) // superBatchSize

        for superBatchIndex in xrange(0, superBatchesCount):
            log.info('Text batch: {0}/{1}.', superBatchIndex + 1, superBatchesCount)

            # TODO: this only works if superBatchSize == textsCount; otherwise text indices do not match
            contexts.read_direct(trainingBatch,
                                 source_sel=numpy.s_[superBatchIndex * superBatchSize:
                                                     (superBatchIndex + 1) * superBatchSize])
            trainingBatchReshaped = trainingBatch.reshape((superBatchSize * contextsPerText,
                                                           1 + windowSize + negative))

            fileEmbeddingsBatch = fileEmbeddings[superBatchIndex * superBatchSize:(superBatchIndex + 1) * superBatchSize]

            model = training.Model(fileEmbeddingsBatch, wordEmbeddings, contextSize=windowSize - 2, negative=negative)
            training.train(model, textIndexMap, wordIndexMap, wordEmbeddings, trainingBatchReshaped,
                           epochs, 1, learningRate)

            # Write the trained slice of text embeddings back into the full matrix.
            fileEmbeddings[superBatchIndex * superBatchSize:(superBatchIndex + 1) * superBatchSize] = \
                model.fileEmbeddings.get_value()
            contextsFile.flush()

        log.progress('Dumping text embeddings...')
        binary.dumpTensor(fileEmbeddingsPath, fileEmbeddings)
        log.info('Dumping text embeddings complete')
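
# For reference, each training context row has the layout
# [textIndex, word_1 ... word_windowSize, noise_1 ... noise_negative], and the
# super-batch loop above flattens (superBatchSize, contextsPerText, rowWidth)
# into (superBatchSize * contextsPerText, rowWidth) before training. Below is a
# minimal standalone sketch of that slicing, with hypothetical sizes and a plain
# NumPy array standing in for the HDF5 dataset.


def _superBatchLayoutSketch():
    windowSize, negative, contextsPerText = 5, 3, 4  # hypothetical sizes
    superBatchSize, textsCount = 2, 6                # hypothetical corpus

    rowWidth = 1 + windowSize + negative  # text index + word window + noise words
    contexts = numpy.random.randint(0, 100, size=(textsCount, contextsPerText, rowWidth)).astype('int32')

    for superBatchIndex in xrange(textsCount // superBatchSize):
        batch = contexts[superBatchIndex * superBatchSize:(superBatchIndex + 1) * superBatchSize]
        flat = batch.reshape((superBatchSize * contextsPerText, rowWidth))
        # Column 0 addresses a row of fileEmbeddings; columns 1..windowSize hold
        # the word window; the trailing `negative` columns hold noise samples.
        assert flat.shape == (superBatchSize * contextsPerText, rowWidth)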
import collections
import glob
import io
import os
import time

import numpy

# binary, log and parameters are project-local modules; WordContextProvider and
# generateNegativeSamples are helpers defined alongside this function.


def processData(inputDirectoryPath, w2vEmbeddingsFilePath, fileIndexMapFilePath, wordIndexMapFilePath,
                wordEmbeddingsFilePath, contextsPath, windowSize, negative, strict):
    if os.path.exists(contextsPath):
        os.remove(contextsPath)

    # One slot of each context is reserved for the file index; the rest is the word window.
    fileContextSize = 1
    wordContextSize = windowSize - fileContextSize

    fileIndexMap = {}
    wordIndexMap = collections.OrderedDict()
    wordEmbeddings = []

    # With negative sampling enabled, raw contexts go to a temporary file first.
    noNegativeSamplingPath = contextsPath
    if negative > 0:
        noNegativeSamplingPath += '.temp'
        if os.path.exists(noNegativeSamplingPath):
            os.remove(noNegativeSamplingPath)

    pathName = inputDirectoryPath + '/*.txt'
    textFilePaths = glob.glob(pathName)
    textFilePaths = sorted(textFilePaths)
    textFileCount = len(textFilePaths)

    w2vWordIndexMap, w2vEmbeddings = parameters.loadW2VParameters(w2vEmbeddingsFilePath)

    contextsCount = 0
    with open(noNegativeSamplingPath, 'wb+') as noNegativeSamplingFile:
        binary.writei(noNegativeSamplingFile, 0)  # placeholder for the contexts count, overwritten below
        binary.writei(noNegativeSamplingFile, windowSize)
        binary.writei(noNegativeSamplingFile, 0)  # this intermediate file holds no negative samples

        startTime = time.time()

        for textFileIndex, textFilePath in enumerate(textFilePaths):
            fileIndexMap[textFilePath] = textFileIndex

            contextProvider = WordContextProvider(textFilePath=textFilePath)
            for wordContext in contextProvider.iterate(wordContextSize):
                # Skip contexts containing words missing from the word2vec vocabulary.
                allWordsInWordVocabulary = [word in w2vWordIndexMap for word in wordContext]
                if not all(allWordsInWordVocabulary):
                    continue

                for word in wordContext:
                    if word not in wordIndexMap:
                        wordIndexMap[word] = len(wordIndexMap)

                        wordEmbeddingIndex = w2vWordIndexMap[word]
                        wordEmbedding = w2vEmbeddings[wordEmbeddingIndex]
                        wordEmbeddings.append(wordEmbedding)

                indexContext = [textFileIndex] + map(lambda w: wordIndexMap[w], wordContext)

                binary.writei(noNegativeSamplingFile, indexContext)
                contextsCount += 1

            currentTime = time.time()
            elapsed = currentTime - startTime
            secondsPerFile = elapsed / (textFileIndex + 1)

            log.progress('Reading contexts: {0:.3f}%. Elapsed: {1} ({2:.3f} sec/file). Words: {3}. Contexts: {4}.',
                         textFileIndex + 1, textFileCount,
                         log.delta(elapsed), secondsPerFile, len(wordIndexMap), contextsCount)

        log.lineBreak()

        # Overwrite the placeholder at the head of the file with the real count.
        noNegativeSamplingFile.seek(0, io.SEEK_SET)
        binary.writei(noNegativeSamplingFile, contextsCount)
        noNegativeSamplingFile.flush()

    if negative > 0:
        with open(contextsPath, 'wb+') as contextsFile:
            startTime = time.time()

            contextProvider = parameters.IndexContextProvider(noNegativeSamplingPath)

            binary.writei(contextsFile, contextsCount)
            binary.writei(contextsFile, windowSize)
            binary.writei(contextsFile, negative)

            batchSize = 10000
            batchesCount = (contextsCount + batchSize - 1) // batchSize  # ceiling division avoids a trailing empty batch

            wordIndices = map(lambda item: item[1], wordIndexMap.items())
            wordIndices = numpy.asarray(wordIndices)
            maxWordIndex = max(wordIndices)

            for batchIndex in xrange(0, batchesCount):
                contexts = contextProvider[batchIndex * batchSize:(batchIndex + 1) * batchSize]
                negativeSamples = generateNegativeSamples(negative, contexts, wordIndices, maxWordIndex, strict)
                contexts = numpy.concatenate([contexts, negativeSamples], axis=1)
                contexts = numpy.ravel(contexts)

                binary.writei(contextsFile, contexts)

                currentTime = time.time()
                elapsed = currentTime - startTime

                log.progress('Negative sampling: {0:.3f}%. Elapsed: {1}.', batchIndex + 1, batchesCount,
                             log.delta(elapsed))

            log.lineBreak()
            contextsFile.flush()

        os.remove(noNegativeSamplingPath)

    parameters.dumpWordMap(fileIndexMap, fileIndexMapFilePath)
    parameters.dumpWordMap(wordIndexMap, wordIndexMapFilePath)
    parameters.dumpEmbeddings(wordEmbeddings, wordEmbeddingsFilePath)
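
# generateNegativeSamples is defined elsewhere in the project; the sketch below
# only illustrates what it is assumed to do: draw `negative` noise-word indices
# per context, uniformly from the known vocabulary indices, and, when `strict`
# is set, resample any draw that collides with a word of its own context. It is
# not the project's implementation; maxWordIndex is unused here because the
# sampling comes straight from wordIndices (assumed contiguous 0..N-1).


def _generateNegativeSamplesSketch(negative, contexts, wordIndices, maxWordIndex, strict):
    # contexts: (batch, 1 + wordContextSize) int array; column 0 is the file index.
    contextsCount = len(contexts)
    samples = wordIndices[numpy.random.randint(0, len(wordIndices), size=(contextsCount, negative))]
    if strict:
        for row in xrange(contextsCount):
            contextWords = set(contexts[row, 1:])
            for col in xrange(negative):
                # Redraw until the noise word differs from every context word.
                while samples[row, col] in contextWords:
                    samples[row, col] = wordIndices[numpy.random.randint(0, len(wordIndices))]
    return samples.astype('int32')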