Exemplo n.º 1
0
 def partitionIndices(self, seed=42):
     """
     Build self.partitions, the train/test index split(s) for the experiment.

     For a "k-folds" experiment, self.partitions is replaced by the result of
     KFolds(self.folds).split(...) over the sample indices; the split is
     randomized unless self.orderedSplit is set, and seed is forwarded to it.

     For any other experiment type, one (trainIndices, testIndices) two-tuple
     is appended to self.partitions per entry of self.trainSizes:
       - ordered split: the first `split` indices train, the rest test;
       - otherwise: `split` indices are sampled without replacement for
         training, and the remaining indices (ascending) are used for testing.

     Note that the random branch seeds the module-level `random` generator.

     seed -- seed for the randomized split / sampling (default 42).
     """
     if self.experimentType == "k-folds":
         self.partitions = KFolds(self.folds).split(
             range(len(self.samples)),
             randomize=(not self.orderedSplit),
             seed=seed)
     else:
         # TODO: use StandardSplit in data_split.py
         length = len(self.samples)
         if self.orderedSplit:
             for split in self.trainSizes:
                 trainIndices = range(split)
                 testIndices = range(split, length)
                 self.partitions.append((trainIndices, testIndices))
         else:
             # randomly sampled, not repeated
             random.seed(seed)
             for split in self.trainSizes:
                 trainIndices = random.sample(xrange(length), split)
                 # Membership tests against a set keep the complement
                 # computation linear in `length` instead of quadratic.
                 trainSet = set(trainIndices)
                 testIndices = [
                     i for i in xrange(length) if i not in trainSet
                 ]
                 self.partitions.append((trainIndices, testIndices))
Exemplo n.º 2
0
    def partitionIndices(self, _):
        """
        Populate self.partitions from the token data of the network data files.

        The sequence order is already fixed by the network data files; when the
        experiment generated those files, they are ordered or randomized
        according to the orderedSplit arg. The positional argument is ignored.

        For "k-folds" experiments self.partitions is assigned the result of
        KFolds(self.folds).split(...) with randomize=False. For every other
        experiment type one (train, test) two-tuple is appended per entry of
        self.trainSizes, obtained by slicing that trial's token data at the
        train size.
        """
        if self.experimentType != "k-folds":
            # One data file per trial; cut its token data at the train size.
            for trialIdx, trainSize in enumerate(self.trainSizes):
                trialFile = self.dataFiles[trialIdx]
                tokenCounts = NetworkDataGenerator.getNumberOfTokens(trialFile)
                trainPart = tokenCounts[:trainSize]
                testPart = tokenCounts[trainSize:]
                self.partitions.append((trainPart, testPart))
            return

        # NOTE(review): self.partitions is reassigned on every iteration, so
        # only the split built from the last fold's data file is kept --
        # confirm this is intended.
        for foldIdx in xrange(self.folds):
            foldFile = self.dataFiles[foldIdx]
            tokenCounts = NetworkDataGenerator.getNumberOfTokens(foldFile)
            splitter = KFolds(self.folds)
            self.partitions = splitter.split(tokenCounts, randomize=False)