Example #1
import gc
from os.path import exists

import h5py
import numpy

# Project-local helpers assumed by this snippet: parameters (map/embedding
# (de)serialization), binary (tensor dumps), log (progress logging), pathTo
# (path configuration), the training module referenced as 'traininig' below,
# and the functions extract, buildWordMaps, subsampleAndPrune and
# inferContexts defined elsewhere in the project.


def trainTextVectors(connector, w2vEmbeddingsPath, wordIndexMapPath, wordFrequencyMapPath, wordEmbeddingsPath, contextsPath,
                     sample, minCount, windowSize, negative, strict, contextsPerText, superBatchSize, fileEmbeddingSize,
                     epochs, learningRate, fileEmbeddingsPath):
    if exists(wordIndexMapPath) and exists(wordFrequencyMapPath) and exists(wordEmbeddingsPath) \
            and exists(contextsPath) and exists(pathTo.textIndexMap):
        wordIndexMap = parameters.loadMap(wordIndexMapPath)
        wordFrequencyMap = parameters.loadMap(wordFrequencyMapPath)
        wordEmbeddings = parameters.loadEmbeddings(wordEmbeddingsPath)
        textIndexMap = parameters.loadMap(pathTo.textIndexMap)
    else:
        w2vWordIndexMap, w2vWordEmbeddings = parameters.loadW2VParameters(w2vEmbeddingsPath)

        names, texts = extract(connector)
        wordIndexMap, wordFrequencyMap, wordEmbeddings = buildWordMaps(texts, w2vWordIndexMap, w2vWordEmbeddings)

        parameters.dumpWordMap(wordIndexMap, wordIndexMapPath)
        del w2vWordIndexMap
        del w2vWordEmbeddings
        gc.collect()

        parameters.dumpWordMap(wordFrequencyMap, wordFrequencyMapPath)

        log.progress('Dumping embeddings...')
        parameters.dumpEmbeddings(wordEmbeddings, wordEmbeddingsPath)
        log.info('Dumped indices, frequencies and embeddings')

        texts = subsampleAndPrune(texts, wordFrequencyMap, sample, minCount)

        textIndexMap = inferContexts(contextsPath, names, texts, wordIndexMap, windowSize, negative, strict, contextsPerText)

        parameters.dumpWordMap(textIndexMap, pathTo.textIndexMap)

    with h5py.File(contextsPath, 'r') as contextsFile:
        contexts = contextsFile['contexts']
        log.info('Loaded {0} contexts. Shape: {1}', len(contexts), contexts.shape)

        # One random initial embedding per text; any texts that do not fill a
        # whole super-batch are skipped by the integer division below (Python 2).
        fileEmbeddings = numpy.random.rand(len(contexts), fileEmbeddingSize).astype('float32')
        trainingBatch = numpy.zeros((superBatchSize, contextsPerText, 1+windowSize+negative)).astype('int32')
        superBatchesCount = len(contexts) / superBatchSize

        for superBatchIndex in xrange(0, superBatchesCount):
            log.info('Text batch: {0}/{1}.', superBatchIndex + 1, superBatchesCount)

            # TODO: this only works if superBatchSize == textsCount; otherwise text indices do not match
            contexts.read_direct(trainingBatch, source_sel=numpy.s_[superBatchIndex*superBatchSize:(superBatchIndex+1)*superBatchSize])
            trainingBatchReshaped = trainingBatch.reshape((superBatchSize*contextsPerText, 1+windowSize+negative))

            fileEmbeddingsBatch = fileEmbeddings[superBatchIndex*superBatchSize:(superBatchIndex+1)*superBatchSize]

            model = traininig.Model(fileEmbeddingsBatch, wordEmbeddings, contextSize=windowSize-2, negative=negative)
            traininig.train(model, textIndexMap, wordIndexMap, wordEmbeddings, trainingBatchReshaped, epochs, 1, learningRate)

            fileEmbeddings[superBatchIndex*superBatchSize:(superBatchIndex+1)*superBatchSize] = model.fileEmbeddings.get_value()
            contextsFile.flush()

        log.progress('Dumping text embeddings...')
        binary.dumpTensor(fileEmbeddingsPath, fileEmbeddings)
        log.info('Dumping text embeddings complete')
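A minimal invocation sketch for the function above. Every path and hyperparameter below is a hypothetical placeholder, and connector stands for whatever data-source object the project's extract helper consumes:

trainTextVectors(connector,
                 'data/w2v.bin',                  # hypothetical path to pre-trained word2vec embeddings
                 'data/word_index_map.bin',
                 'data/word_frequency_map.bin',
                 'data/word_embeddings.bin',
                 'data/contexts.h5',
                 sample=1e-3, minCount=5, windowSize=8, negative=10, strict=False,
                 contextsPerText=600, superBatchSize=1000, fileEmbeddingSize=800,
                 epochs=10, learningRate=0.025,
                 fileEmbeddingsPath='data/text_embeddings.bin')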
Example #2
def dump(self, parametersPath, embeddingsPath):
    # parametersPath is accepted but unused here; only the embeddings are written.
    embeddings = self.wordEmbeddings.get_value()
    parameters.dumpEmbeddings(embeddings, embeddingsPath)
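The get_value() call suggests self.wordEmbeddings is a Theano shared variable. A hypothetical call, where model is whatever object defines dump() and both paths are placeholders:

model.dump('out/parameters.bin', 'out/word_embeddings.bin')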
Example #3
import collections
import glob
import io
import os
import time

import numpy

# Project-local helpers assumed by this snippet: parameters, binary, log,
# WordContextProvider and generateNegativeSamples.


def processData(inputDirectoryPath, w2vEmbeddingsFilePath, fileIndexMapFilePath,
                wordIndexMapFilePath, wordEmbeddingsFilePath, contextsPath, windowSize, negative, strict):
    if os.path.exists(contextsPath):
        os.remove(contextsPath)

    fileContextSize = 1
    wordContextSize = windowSize - fileContextSize

    fileIndexMap = {}
    wordIndexMap = collections.OrderedDict()
    wordEmbeddings = []

    noNegativeSamplingPath = contextsPath
    if negative > 0:
        noNegativeSamplingPath += '.temp'

    if os.path.exists(noNegativeSamplingPath):
        os.remove(noNegativeSamplingPath)

    pathName = inputDirectoryPath + '/*.txt'
    textFilePaths = glob.glob(pathName)
    textFilePaths = sorted(textFilePaths)
    textFileCount = len(textFilePaths)

    w2vWordIndexMap, w2vEmbeddings = parameters.loadW2VParameters(w2vEmbeddingsFilePath)

    contextsCount = 0
    with open(noNegativeSamplingPath, 'wb+') as noNegativeSamplingFile:
        binary.writei(noNegativeSamplingFile, 0) # placeholder for contexts count, back-filled below
        binary.writei(noNegativeSamplingFile, windowSize)
        binary.writei(noNegativeSamplingFile, 0) # negative sample count (none at this stage)

        startTime = time.time()

        for textFileIndex, textFilePath in enumerate(textFilePaths):
            fileIndexMap[textFilePath] = textFileIndex

            contextProvider = WordContextProvider(textFilePath=textFilePath)
            for wordContext in contextProvider.iterate(wordContextSize):
                # Skip contexts containing any word missing from the word2vec vocabulary.
                allWordsInWordVocabulary = [word in w2vWordIndexMap for word in wordContext]

                if not all(allWordsInWordVocabulary):
                    continue

                for word in wordContext:
                    if word not in wordIndexMap:
                        wordIndexMap[word] = len(wordIndexMap)
                        wordEmbeddingIndex = w2vWordIndexMap[word]
                        wordEmbedding = w2vEmbeddings[wordEmbeddingIndex]
                        wordEmbeddings.append(wordEmbedding)

                indexContext = [textFileIndex] + [wordIndexMap[word] for word in wordContext]

                binary.writei(noNegativeSamplingFile, indexContext)
                contextsCount += 1

            currentTime = time.time()
            elapsed = currentTime - startTime
            secondsPerFile = elapsed / (textFileIndex + 1)

            log.progress('Reading contexts: {0:.3f}%. Elapsed: {1} ({2:.3f} sec/file). Words: {3}. Contexts: {4}.',
                         textFileIndex + 1,
                         textFileCount,
                         log.delta(elapsed),
                         secondsPerFile,
                         len(wordIndexMap),
                         contextsCount)

        log.lineBreak()

        # Back-fill the contexts count placeholder at the start of the file.
        noNegativeSamplingFile.seek(0, io.SEEK_SET)
        binary.writei(noNegativeSamplingFile, contextsCount)
        noNegativeSamplingFile.flush()

    if negative > 0:
        with open(contextsPath, 'wb+') as contextsFile:
            startTime = time.time()

            contextProvider = parameters.IndexContextProvider(noNegativeSamplingPath)

            binary.writei(contextsFile, contextsCount)
            binary.writei(contextsFile, windowSize)
            binary.writei(contextsFile, negative)

            batchSize = 10000
            batchesCount = (contextsCount + batchSize - 1) / batchSize # ceiling division; avoids an empty final batch

            wordIndices = numpy.asarray(wordIndexMap.values())
            maxWordIndex = max(wordIndices)

            for batchIndex in xrange(0, batchesCount):
                contexts = contextProvider[batchIndex * batchSize : (batchIndex + 1) * batchSize]
                negativeSamples = generateNegativeSamples(negative, contexts, wordIndices, maxWordIndex, strict)
                contexts = numpy.concatenate([contexts, negativeSamples], axis=1)
                contexts = numpy.ravel(contexts)

                binary.writei(contextsFile, contexts)

                currentTime = time.time()
                elapsed = currentTime - startTime

                log.progress('Negative sampling: {0:.3f}%. Elapsed: {1}.',
                             batchIndex + 1,
                             batchesCount,
                             log.delta(elapsed))

            log.lineBreak()
            contextsFile.flush()

            os.remove(noNegativeSamplingPath)

    parameters.dumpWordMap(fileIndexMap, fileIndexMapFilePath)
    parameters.dumpWordMap(wordIndexMap, wordIndexMapFilePath)
    parameters.dumpEmbeddings(wordEmbeddings, wordEmbeddingsFilePath)
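A hypothetical invocation of processData; all paths and hyperparameters are placeholders:

processData('data/texts',                      # directory containing the input *.txt files
            'data/w2v.bin',
            'data/file_index_map.bin',
            'data/word_index_map.bin',
            'data/word_embeddings.bin',
            'data/contexts.bin',
            windowSize=8, negative=10, strict=False)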