import gc
import os
from os.path import exists

import h5py
import numpy

# Repo-local modules; their names are assumed from the calls below.
import binary
import log
import parameters
import training
from training import Model, train


def launch(pathTo, hyper):
    fileIndexMap = parameters.loadMap(pathTo.textIndexMap)
    filesCount = len(fileIndexMap)
    fileEmbeddingSize = hyper.fileEmbeddingSize
    wordIndexMap = parameters.loadMap(pathTo.wordIndexMap)
    wordEmbeddings = parameters.loadEmbeddings(pathTo.wordEmbeddings)

    # Start each run with a fresh metrics history.
    metricsPath = pathTo.metrics('history.csv')
    if os.path.exists(metricsPath):
        os.remove(metricsPath)

    contextProvider = parameters.IndexContextProvider(pathTo.contexts)
    windowSize = contextProvider.windowSize - 1  # drop the leading file index
    contextSize = windowSize - 1  # drop the target word
    negative = contextProvider.negative
    contexts = contextProvider[:]

    log.info('Contexts loading complete. {0} contexts loaded, {1} words and {2} negative samples each.',
             len(contexts), contextProvider.windowSize, contextProvider.negative)

    fileEmbeddings = rnd2(filesCount, fileEmbeddingSize)
    model = Model(fileEmbeddings, wordEmbeddings, contextSize=contextSize, negative=negative)
    # model = Model.load(pathTo.fileEmbeddings, pathTo.wordEmbeddings, pathTo.weights)

    train(model, fileIndexMap, wordIndexMap, wordEmbeddings, contexts,
          epochs=hyper.epochs,
          batchSize=hyper.batchSize,
          learningRate=hyper.learningRate,
          metricsPath=metricsPath)

    model.dump(pathTo.fileEmbeddings, pathTo.weights)
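
# `rnd2` is not defined in this section. A minimal sketch, assuming it is a
# uniform float32 initialiser mirroring the numpy.random.rand call used in
# trainTextVectors below:
def rnd2(rows, columns):
    # Uniform noise in [0, 1), cast to float32 to match the embedding dtype.
    return numpy.random.rand(rows, columns).astype('float32')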
def trainTextVectors(connector, w2vEmbeddingsPath, wordIndexMapPath, wordFrequencyMapPath,
                     wordEmbeddingsPath, contextsPath, textIndexMapPath, sample, minCount,
                     windowSize, negative, strict, contextsPerText, superBatchSize,
                     fileEmbeddingSize, epochs, learningRate, fileEmbeddingsPath):
    if exists(wordIndexMapPath) and exists(wordFrequencyMapPath) and exists(wordEmbeddingsPath) \
            and exists(contextsPath) and exists(textIndexMapPath):
        wordIndexMap = parameters.loadMap(wordIndexMapPath)
        wordFrequencyMap = parameters.loadMap(wordFrequencyMapPath)
        wordEmbeddings = parameters.loadEmbeddings(wordEmbeddingsPath)
        textIndexMap = parameters.loadMap(textIndexMapPath)
    else:
        w2vWordIndexMap, w2vWordEmbeddings = parameters.loadW2VParameters(w2vEmbeddingsPath)

        names, texts = extract(connector)
        wordIndexMap, wordFrequencyMap, wordEmbeddings = buildWordMaps(texts, w2vWordIndexMap, w2vWordEmbeddings)

        parameters.dumpWordMap(wordIndexMap, wordIndexMapPath)

        # Free the raw word2vec parameters before the memory-heavy steps below.
        del w2vWordIndexMap
        del w2vWordEmbeddings
        gc.collect()

        parameters.dumpWordMap(wordFrequencyMap, wordFrequencyMapPath)

        log.progress('Dumping contexts...')
        parameters.dumpEmbeddings(wordEmbeddings, wordEmbeddingsPath)
        log.info('Dumped indices, frequencies and embeddings')

        texts = subsampleAndPrune(texts, wordFrequencyMap, sample, minCount)

        textIndexMap = inferContexts(contextsPath, names, texts, wordIndexMap,
                                     windowSize, negative, strict, contextsPerText)

        parameters.dumpWordMap(textIndexMap, textIndexMapPath)

    with h5py.File(contextsPath, 'r') as contextsFile:
        contexts = contextsFile['contexts']
        log.info('Loaded {0} contexts. Shape: {1}', len(contexts), contexts.shape)

        fileEmbeddings = numpy.random.rand(len(contexts), fileEmbeddingSize).astype('float32')
        trainingBatch = numpy.zeros((superBatchSize, contextsPerText, 1 + windowSize + negative)).astype('int32')
        superBatchesCount = len(contexts) // superBatchSize

        for superBatchIndex in xrange(0, superBatchesCount):
            log.info('Text batch: {0}/{1}.', superBatchIndex + 1, superBatchesCount)

            # TODO: this only works if superBatchSize == textsCount; otherwise text indices do not match
            contexts.read_direct(
                trainingBatch,
                source_sel=numpy.s_[superBatchIndex * superBatchSize:(superBatchIndex + 1) * superBatchSize])
            trainingBatchReshaped = trainingBatch.reshape(
                (superBatchSize * contextsPerText, 1 + windowSize + negative))

            fileEmbeddingsBatch = fileEmbeddings[superBatchIndex * superBatchSize:(superBatchIndex + 1) * superBatchSize]

            model = training.Model(fileEmbeddingsBatch, wordEmbeddings, contextSize=windowSize - 2, negative=negative)
            training.train(model, textIndexMap, wordIndexMap, wordEmbeddings, trainingBatchReshaped,
                           epochs, 1, learningRate)

            # Write the trained per-text embeddings back into the full matrix.
            fileEmbeddings[superBatchIndex * superBatchSize:(superBatchIndex + 1) * superBatchSize] = \
                model.fileEmbeddings.get_value()

    log.progress('Dumping text embeddings...')
    binary.dumpTensor(fileEmbeddingsPath, fileEmbeddings)
    log.info('Dumping text embeddings complete')
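
# Both functions above rely on context rows that are 1 + windowSize + negative
# integers wide. An illustrative helper (not part of the original code; the
# exact split between context words and the target word is an assumption
# inferred from that row width) that unpacks one row:
def unpackContextRow(row, windowSize, negative):
    fileIndex = row[0]                      # index of the source text
    words = row[1:1 + windowSize]           # window of word indices (context plus target)
    negativeSamples = row[1 + windowSize:]  # sampled negative word indices
    assert len(negativeSamples) == negative
    return fileIndex, words, negativeSamples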