def setupNetData(
    self, generateData=True, seed=42, preprocess=False, **kwargs):
    """
    Resulting network data files created:
      - One for each bucket
      - One for each training rep, where samples are not repeated in a given
        file. Each sample is given its own category (_category = _sequenceId).

    The classification json is saved when generating the final training file.

    @param generateData (bool)  Must be True; reusing previously generated
                                network data files is not implemented.
    @param seed         (int)   Passed to ndg.randomizeData() whenever
                                self.orderedSplit is False.
    @param preprocess   (bool)  Forwarded as textPreprocess to every
                                ndg.split() call.
    @param kwargs               Forwarded to every ndg.split() call.

    @raises NotImplementedError  If generateData is False.
    """
    if not generateData:
        # TODO (only if needed)
        raise NotImplementedError("Must generate data.")

    ndg = NetworkDataGenerator()
    self.dataDict = ndg.split(
        filePath=self.dataPath, numLabels=1, textPreprocess=preprocess,
        **kwargs)

    # ``filename``/``ext`` describe self.dataPath and are used to name the
    # classification json and the training files below.
    filename, ext = os.path.splitext(self.dataPath)
    self.classificationFile = "{}_categories.json".format(filename)

    # Generate test data files: one network data file for each bucket.
    for bucketFile in bucketCSVs(self.dataPath):
        ndg.reset()
        ndg.split(
            filePath=bucketFile, numLabels=1, textPreprocess=preprocess,
            **kwargs)
        # Keep the bucket extension in its own name so it does not clobber
        # ``ext``, which the training file names rely on.
        bucketFileName, bucketExt = os.path.splitext(bucketFile)
        if not self.orderedSplit:
            # the sequences will be written to the file in random order
            ndg.randomizeData(seed)
        dataFile = "{}_network{}".format(bucketFileName, bucketExt)
        # the classification file here gets (correctly) overwritten later
        ndg.saveData(dataFile, self.classificationFile)
        self.bucketFiles.append(dataFile)

    # Generate training data file(s).
    self.trainingDicts = []
    uniqueDataDict = OrderedDict()
    # A set gives O(1) membership tests; assumes the IDs in dataEntry[2] are
    # hashable (e.g. str/int) -- TODO confirm against NetworkDataGenerator.
    included = set()
    for dataEntry in self.dataDict.values():
        uniqueID = dataEntry[2]
        if uniqueID in included:
            # skip over the samples that are repeated in multiple buckets
            continue
        included.add(uniqueID)
        # Keys are consecutive sequence IDs 0, 1, 2, ... in insertion order.
        uniqueDataDict[len(uniqueDataDict)] = dataEntry
    self.trainingDicts.append(uniqueDataDict)

    ndg.reset()
    ndg.split(
        dataDict=uniqueDataDict, numLabels=1, textPreprocess=preprocess,
        **kwargs)
    for rep in range(self.trainingReps):
        # use a different file for each training rep
        if not self.orderedSplit:
            ndg.randomizeData(seed)
        ndg.stripCategories()  # replace the categories w/ seqId
        dataFile = "{}_network_training_{}{}".format(filename, rep, ext)
        ndg.saveData(dataFile, self.classificationFile)
        self.dataFiles.append(dataFile)

    # TODO: maybe add a method (and arg) for removing all these data files

    # labels references match the classification json
    self.mapLabelRefs()