Пример #1
0
def trainTextVectors(connector, w2vEmbeddingsPath, wordIndexMapPath, wordFrequencyMapPath, wordEmbeddingsPath, contextsPath,
                     sample, minCount, windowSize, negative, strict, contextsPerText, superBatchSize, fileEmbeddingSize,
                     epochs, learningRate, fileEmbeddingsPath):
    """Train text embeddings over the stored contexts and dump them to fileEmbeddingsPath.

    Preparation: when the word index map, word frequency map, word embeddings,
    contexts file and ``pathTo.textIndexMap`` all exist on disk, the maps and
    embeddings are loaded from them. Otherwise they are rebuilt from the texts
    returned by ``extract(connector)`` and the word2vec parameters stored at
    w2vEmbeddingsPath, and every artifact is written back to its path.

    Training: the 'contexts' dataset is read in slices of superBatchSize rows;
    each slice is trained with a fresh ``traininig.Model`` and the resulting
    embeddings are copied back into the full matrix, which is dumped at the end.

    :param connector: passed to ``extract`` to obtain names and texts.
    :param w2vEmbeddingsPath: source for ``parameters.loadW2VParameters``.
    :param wordIndexMapPath: load/dump location of the word index map.
    :param wordFrequencyMapPath: load/dump location of the word frequency map.
    :param wordEmbeddingsPath: load/dump location of the word embeddings.
    :param contextsPath: HDF5 file holding the 'contexts' dataset.
    :param sample: forwarded to ``subsampleAndPrune``.
    :param minCount: forwarded to ``subsampleAndPrune``.
    :param windowSize: forwarded to ``inferContexts``; also part of the context row width.
    :param negative: forwarded to ``inferContexts`` and ``traininig.Model``; also part of the context row width.
    :param strict: forwarded to ``inferContexts``.
    :param contextsPerText: forwarded to ``inferContexts``; second axis of a training batch.
    :param superBatchSize: number of dataset rows read and trained per round.
    :param fileEmbeddingSize: number of columns of the text embedding matrix.
    :param epochs: forwarded to ``traininig.train``.
    :param learningRate: forwarded to ``traininig.train``.
    :param fileEmbeddingsPath: destination of the final ``binary.dumpTensor`` call.
    """
    cachedPaths = (wordIndexMapPath, wordFrequencyMapPath, wordEmbeddingsPath, contextsPath, pathTo.textIndexMap)

    if all(exists(path) for path in cachedPaths):
        wordIndexMap = parameters.loadMap(wordIndexMapPath)
        wordFrequencyMap = parameters.loadMap(wordFrequencyMapPath)
        wordEmbeddings = parameters.loadEmbeddings(wordEmbeddingsPath)
        textIndexMap = parameters.loadMap(pathTo.textIndexMap)
    else:
        w2vWordIndexMap, w2vWordEmbeddings = parameters.loadW2VParameters(w2vEmbeddingsPath)

        names, texts = extract(connector)
        wordIndexMap, wordFrequencyMap, wordEmbeddings = buildWordMaps(texts, w2vWordIndexMap, w2vWordEmbeddings)

        parameters.dumpWordMap(wordIndexMap, wordIndexMapPath)

        # The word2vec parameters are not used past this point; drop them
        # before the remaining dumps and context inference.
        del w2vWordIndexMap
        del w2vWordEmbeddings
        gc.collect()

        parameters.dumpWordMap(wordFrequencyMap, wordFrequencyMapPath)

        # NOTE(review): the message says "contexts", but the call below dumps
        # the word embeddings -- confirm the intended wording.
        log.progress('Dumping contexts...')
        parameters.dumpEmbeddings(wordEmbeddings, wordEmbeddingsPath)
        log.info('Dumped indices, frequencies and embeddings')

        texts = subsampleAndPrune(texts, wordFrequencyMap, sample, minCount)

        textIndexMap = inferContexts(contextsPath, names, texts, wordIndexMap, windowSize, negative, strict, contextsPerText)

        parameters.dumpWordMap(textIndexMap, pathTo.textIndexMap)

    with h5py.File(contextsPath, 'r') as contextsFile:
        contexts = contextsFile['contexts']
        contextRowsCount = len(contexts)
        log.info('Loaded {0} contexts. Shape: {1}', contextRowsCount, contexts.shape)

        contextWidth = 1 + windowSize + negative

        fileEmbeddings = numpy.random.rand(contextRowsCount, fileEmbeddingSize).astype('float32')
        trainingBatch = numpy.zeros((superBatchSize, contextsPerText, contextWidth), dtype='int32')

        # Floor division keeps the count an integer under true division too
        # (Python 3 or `from __future__ import division`).
        # NOTE(review): rows past the last full super-batch are never trained
        # and keep their random initial embeddings.
        superBatchesCount = contextRowsCount // superBatchSize

        for superBatchIndex in range(superBatchesCount):
            log.info('Text batch: {0}/{1}.', superBatchIndex + 1, superBatchesCount)

            start = superBatchIndex * superBatchSize
            stop = start + superBatchSize

            # TODO: this only works if superBatchSize == textsCount; otherwise text indices do not match
            contexts.read_direct(trainingBatch, source_sel=numpy.s_[start:stop])
            trainingBatchReshaped = trainingBatch.reshape((superBatchSize * contextsPerText, contextWidth))

            fileEmbeddingsBatch = fileEmbeddings[start:stop]

            # NOTE(review): the reason for `windowSize-2` and the meaning of the
            # literal 1 passed to train() are defined in `traininig`, which is
            # not visible here -- confirm before changing either.
            model = traininig.Model(fileEmbeddingsBatch, wordEmbeddings, contextSize=windowSize-2, negative=negative)
            traininig.train(model, textIndexMap, wordIndexMap, wordEmbeddings, trainingBatchReshaped, epochs, 1, learningRate)

            # Copy the trained values back into the full matrix.
            fileEmbeddings[start:stop] = model.fileEmbeddings.get_value()
            contextsFile.flush()

        log.progress('Dumping text embeddings...')
        binary.dumpTensor(fileEmbeddingsPath, fileEmbeddings)
        log.info('Dumping text embeddings complete')
    def dump(self, fileEmbeddingsPath, weightsPath):
        """Write the model's file embeddings and weights to two tensor files.

        The current values of ``self.fileEmbeddings`` and ``self.weights`` are
        obtained through ``get_value()`` and handed to ``binary.dumpTensor``.

        :param fileEmbeddingsPath: destination for the file embeddings tensor.
        :param weightsPath: destination for the weights tensor.
        """
        # The embeddings are fetched and written before the weights are read.
        binary.dumpTensor(fileEmbeddingsPath, self.fileEmbeddings.get_value())
        binary.dumpTensor(weightsPath, self.weights.get_value())