import os  # standard library; `parameters`, `log`, `rnd2`, `Model` and `train` come from the surrounding project

def launch(pathTo, hyper):
    fileIndexMap = parameters.loadMap(pathTo.textIndexMap)
    filesCount = len(fileIndexMap)
    fileEmbeddingSize = hyper.fileEmbeddingSize
    wordIndexMap = parameters.loadMap(pathTo.wordIndexMap)
    wordEmbeddings = parameters.loadEmbeddings(pathTo.wordEmbeddings)
    metricsPath = pathTo.metrics('history.csv')

    if os.path.exists(metricsPath):
        os.remove(metricsPath)

    contextProvider = parameters.IndexContextProvider(pathTo.contexts)
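    # Assumed row layout, inferred from the size arithmetic below: each context row is
    # [file index, window words..., negative samples]. windowSize drops the file index;
    # contextSize additionally drops the predicted word.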
    windowSize = contextProvider.windowSize - 1
    contextSize = windowSize - 1
    negative = contextProvider.negative
    contexts = contextProvider[:]

    log.info('Context loading complete: {0} contexts loaded, {1} words and {2} negative samples each.',
             len(contexts),
             contextProvider.windowSize,
             contextProvider.negative)

    fileEmbeddings = rnd2(filesCount, fileEmbeddingSize)  # project helper: random (filesCount, fileEmbeddingSize) matrix
    model = Model(fileEmbeddings, wordEmbeddings, contextSize=contextSize, negative=negative)
    # model = Model.load(pathTo.fileEmbeddings, pathTo.wordEmbeddings, pathTo.weights)

    train(model, fileIndexMap, wordIndexMap, wordEmbeddings, contexts,
          epochs=hyper.epochs,
          batchSize=hyper.batchSize,
          learningRate=hyper.learningRate,
          metricsPath=metricsPath)

    model.dump(pathTo.fileEmbeddings, pathTo.weights)
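
A minimal sketch of how `launch` might be driven, assuming hypothetical `PathTo` and `Hyper` holders (all names and values below are invented for illustration; the project supplies its own configuration):

class PathTo:
    # Hypothetical path registry mirroring the attributes launch() reads.
    textIndexMap = 'data/textIndexMap.bin'
    wordIndexMap = 'data/wordIndexMap.bin'
    wordEmbeddings = 'data/wordEmbeddings.bin'
    contexts = 'data/contexts.h5'
    fileEmbeddings = 'data/fileEmbeddings.bin'
    weights = 'data/weights.bin'

    @staticmethod
    def metrics(fileName):
        return 'metrics/' + fileName

class Hyper:
    # Hypothetical hyperparameter holder; only the fields launch() uses.
    fileEmbeddingSize = 300
    epochs = 10
    batchSize = 64
    learningRate = 0.025

launch(PathTo, Hyper)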
Example #2
import gc
from os.path import exists

import h5py
import numpy

def trainTextVectors(connector, w2vEmbeddingsPath, wordIndexMapPath, wordFrequencyMapPath, wordEmbeddingsPath, contextsPath,
                     sample, minCount, windowSize, negative, strict, contextsPerText, superBatchSize, fileEmbeddingSize,
                     epochs, learningRate, fileEmbeddingsPath, textIndexMapPath):
    # textIndexMapPath is passed explicitly, consistent with the other path parameters
    # (the original excerpt referenced an undefined `pathTo.textIndexMap`).
    if exists(wordIndexMapPath) and exists(wordFrequencyMapPath) and exists(wordEmbeddingsPath) \
            and exists(contextsPath) and exists(textIndexMapPath):
        wordIndexMap = parameters.loadMap(wordIndexMapPath)
        wordFrequencyMap = parameters.loadMap(wordFrequencyMapPath)
        wordEmbeddings = parameters.loadEmbeddings(wordEmbeddingsPath)
        textIndexMap = parameters.loadMap(textIndexMapPath)
    else:
        w2vWordIndexMap, w2vWordEmbeddings = parameters.loadW2VParameters(w2vEmbeddingsPath)

        names, texts = extract(connector)
        wordIndexMap, wordFrequencyMap, wordEmbeddings = buildWordMaps(texts, w2vWordIndexMap, w2vWordEmbeddings)

        parameters.dumpWordMap(wordIndexMap, wordIndexMapPath)
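        # The full w2v vocabulary is no longer needed once the corpus-specific maps
        # are built; free it before the memory-heavy context inference below.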
        del w2vWordIndexMap
        del w2vWordEmbeddings
        gc.collect()

        parameters.dumpWordMap(wordFrequencyMap, wordFrequencyMapPath)

        log.progress('Dumping word embeddings...')
        parameters.dumpEmbeddings(wordEmbeddings, wordEmbeddingsPath)
        log.info('Dumped indices, frequencies and embeddings')

        texts = subsampleAndPrune(texts, wordFrequencyMap, sample, minCount)

        textIndexMap = inferContexts(contextsPath, names, texts, wordIndexMap, windowSize, negative, strict, contextsPerText)

        parameters.dumpWordMap(textIndexMap, textIndexMapPath)

    with h5py.File(contextsPath, 'r') as contextsFile:
        contexts = contextsFile['contexts']
        log.info('Loaded {0} contexts. Shape: {1}', len(contexts), contexts.shape)

        # One random embedding row per text; each context row is [text index, window words..., negative samples].
        fileEmbeddings = numpy.random.rand(len(contexts), fileEmbeddingSize).astype('float32')
        trainingBatch = numpy.zeros((superBatchSize, contextsPerText, 1 + windowSize + negative)).astype('int32')
        superBatchesCount = len(contexts) // superBatchSize  # floor division: a trailing partial batch is dropped

        for superBatchIndex in xrange(0, superBatchesCount):
            log.info('Text batch: {0}/{1}.', superBatchIndex + 1, superBatchesCount)

            # TODO: this only works if superBatchSize == textsCount: context rows carry global
            # text indices, while fileEmbeddingsBatch is indexed from 0, so the two disagree
            # for every super batch after the first.
            # read_direct copies the HDF5 slice straight into the preallocated buffer.
            contexts.read_direct(trainingBatch, source_sel=numpy.s_[superBatchIndex*superBatchSize:(superBatchIndex+1)*superBatchSize])
            # Flatten (texts, contextsPerText, row) into one context row per training sample.
            trainingBatchReshaped = trainingBatch.reshape((superBatchSize*contextsPerText, 1+windowSize+negative))

            fileEmbeddingsBatch = fileEmbeddings[superBatchIndex*superBatchSize:(superBatchIndex+1)*superBatchSize]

            # A fresh model per super batch: the per-text embeddings are batch-local, the word embeddings are shared.
            model = traininig.Model(fileEmbeddingsBatch, wordEmbeddings, contextSize=windowSize-2, negative=negative)
            traininig.train(model, textIndexMap, wordIndexMap, wordEmbeddings, trainingBatchReshaped, epochs, 1, learningRate)

            fileEmbeddings[superBatchIndex*superBatchSize:(superBatchIndex+1)*superBatchSize] = model.fileEmbeddings.get_value()
            contextsFile.flush()  # no-op here: the file is opened read-only and nothing is written to it

        log.progress('Dumping text embeddings...')
        binary.dumpTensor(fileEmbeddingsPath, fileEmbeddings)
        log.info('Dumping text embeddings complete')
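
One hedged way to address the TODO above, sketched under the assumption that column 0 of each context row holds the global text index (as the `1 + windowSize + negative` layout suggests): re-base that column per super batch so it indexes into the batch-local `fileEmbeddingsBatch`. The helper name is invented for illustration:

def rebaseTextIndices(batch, superBatchIndex, superBatchSize):
    # Shift the global text-index column down by the batch offset so it addresses
    # rows of the per-batch embedding slice. Sketch only; copies the batch to avoid
    # mutating the shared training buffer.
    local = batch.copy()
    local[:, 0] -= superBatchIndex * superBatchSize
    return local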