Example #1
def test_splitter():
    ds = give_data()
    # split with defaults
    spl1 = Splitter('chunks')
    assert_raises(NotImplementedError, spl1, ds)

    splits = list(spl1.generate(ds))
    assert_equal(len(splits), len(ds.sa['chunks'].unique))

    for split in splits:
        # it should have performed basic slicing!
        assert_true(split.samples.base is ds.samples)
        assert_equal(len(split.sa['chunks'].unique), 1)
        assert_true('lastsplit' in split.a)
    assert_true(splits[-1].a.lastsplit)

    # now again, more customized
    spl2 = Splitter('targets', attr_values=[0, 1, 1, 2, 3, 3, 3], count=4,
                    noslicing=True)
    splits = list(spl2.generate(ds))
    assert_equal(len(splits), 4)
    for split in splits:
        # it should NOT have performed basic slicing!
        assert_false(split.samples.base is ds.samples)
        assert_equal(len(split.sa['targets'].unique), 1)
        assert_equal(len(split.sa['chunks'].unique), 10)
    assert_true(splits[-1].a.lastsplit)

    # splits[1] and splits[2] should be identical (both come from attr value 1)
    assert_array_equal(splits[1].samples, splits[2].samples)

    # now go wild and split by feature attribute
    ds.fa['roi'] = np.repeat([0,1], 5)
    # splitter should auto-detect that this is a feature attribute
    spl3 = Splitter('roi')
    splits = list(spl3.generate(ds))
    assert_equal(len(splits), 2)
    for split in splits:
        assert_true(split.samples.base is ds.samples)
        assert_equal(len(split.fa['roi'].unique), 1)
        assert_equal(split.shape, (100, 5))

    # and finally test chained splitters
    cspl = ChainNode([spl2, spl3, spl1])
    splits = list(cspl.generate(ds))
    # 4 target splits x 2 roi splits x 10 chunk splits = 80
    assert_equal(len(splits), 80)
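
For reference, a minimal standalone sketch of the same Splitter usage outside the test harness; the mvpa2.suite import path and the dataset_wizard-built toy dataset are assumptions, not part of the example above. Splitter('chunks') yields one sub-dataset per unique value of the 'chunks' sample attribute.

import numpy as np
from mvpa2.suite import dataset_wizard, Splitter   # assumed import path

# toy dataset: 12 samples, 3 features, 3 chunks of 4 samples each
ds = dataset_wizard(samples=np.random.randn(12, 3),
                    targets=np.tile(['a', 'b'], 6),
                    chunks=np.repeat([0, 1, 2], 4))

spl = Splitter('chunks')            # one split per unique 'chunks' value
for split in spl.generate(ds):
    # each split holds the samples of exactly one chunk
    print(split.sa['chunks'].unique, len(split))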
Example #2
def test_sifter():
    # somewhat duplicating the doctest
    ds = Dataset(samples=np.arange(8).reshape((4,2)),
                 sa={'chunks':   [ 0 ,  1 ,  2 ,  3 ],
                     'targets':  ['c', 'c', 'p', 'p']})
    par = ChainNode([NFoldPartitioner(cvtype=2, attr='chunks'),
                     Sifter([('partitions', 2),
                             ('targets', ['c', 'p'])])
                     ])
    dss = list(par.generate(ds))
    assert_equal(len(dss), 4)
    for ds_ in dss:
        testing = ds[ds_.sa.partitions == 2]
        assert_array_equal(np.unique(testing.sa.targets), ['c', 'p'])
        # and we still have both targets present in training
        training = ds[ds_.sa.partitions == 1]
        assert_array_equal(np.unique(training.sa.targets), ['c', 'p'])
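
A hedged sketch of the same Sifter pipeline outside the test; the mvpa2.suite import path is an assumption. NFoldPartitioner(cvtype=2) proposes all 6 chunk pairs as the testing partition, and the Sifter passes through only the 4 partition sets whose testing partition contains both targets 'c' and 'p'.

import numpy as np
from mvpa2.suite import Dataset, ChainNode, NFoldPartitioner, Sifter  # assumed

ds = Dataset(samples=np.arange(8).reshape((4, 2)),
             sa={'chunks':  [0, 1, 2, 3],
                 'targets': ['c', 'c', 'p', 'p']})
par = ChainNode([NFoldPartitioner(cvtype=2, attr='chunks'),
                 Sifter([('partitions', 2),
                         ('targets', ['c', 'p'])])])
for pds in par.generate(ds):
    testing = pds[pds.sa.partitions == 2]
    # only chunk pairs that mix a 'c' and a 'p' sample survive the Sifter
    print(testing.sa.chunks, np.unique(testing.sa.targets))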
Example #3
    def test_discarded_boundaries(self):
        ds = datasets['hollow']
        # four runs
        ds.sa['chunks'] = np.repeat(np.arange(4), 10)
        # do odd even splitting for lots of boundaries in few splits
        part = ChainNode([OddEvenPartitioner(),
                          StripBoundariesSamples('chunks', 1, 2)])

        parts = [d.samples.sid for d in part.generate(ds)]

        # both datasets should have the same samples, because the boundaries
        # are identical and the same samples should be stripped
        assert_array_equal(parts[0], parts[1])

        # 3 samples (1 before, 2 after) are stripped at each of the 3 chunk boundaries
        assert_equal(len(parts[0]), len(ds) - (3 * 3))

        for i in [9, 10, 11, 19, 20, 21, 29, 30, 31]:
            assert_false(i in parts[0])
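
A rough standalone sketch of the boundary stripping above; the mvpa2.suite import path and the dataset_wizard toy data are assumptions. With 4 chunks of 10 samples there are 3 chunk boundaries, and StripBoundariesSamples('chunks', 1, 2) removes 1 sample before and 2 samples after each of them, leaving 40 - 9 = 31 samples per partition set.

import numpy as np
from mvpa2.suite import dataset_wizard, ChainNode, \
     OddEvenPartitioner, StripBoundariesSamples   # assumed import path

ds = dataset_wizard(samples=np.random.randn(40, 2),
                    targets=np.tile(['a', 'b'], 20),
                    chunks=np.repeat(np.arange(4), 10))
part = ChainNode([OddEvenPartitioner(),
                  StripBoundariesSamples('chunks', 1, 2)])
for pds in part.generate(ds):
    # 3 boundaries x 3 stripped samples = 9 samples removed
    print(len(pds))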
Example #4
def plot_feature_hist(dataset, xlim=None, noticks=True,
                      targets_attr='targets', chunks_attr=None,
                      **kwargs):
    """Plot histograms of feature values for each labels.

    Parameters
    ----------
    dataset : Dataset
    xlim : None or 2-tuple
      Common x-axis limits for all histograms.
    noticks : bool
      If True, no axis ticks will be plotted. This is useful to save
      space in large plots.
    targets_attr : string, optional
      Name of the samples attribute to be used as targets.
    chunks_attr : None or string
      If a string, a histogram will be plotted for each target and each
      chunk (as defined in the samples attribute named `chunks_attr`),
      resulting in a histogram grid (targets x chunks).
    **kwargs
      Any additional arguments are passed to matplotlib's hist().
    """
    lsplit = ChainNode([NFoldPartitioner(1, attr=targets_attr),
                        Splitter('partitions', attr_values=[2])])

    nrows = len(dataset.sa[targets_attr].unique)
    if chunks_attr is not None:
        # only set up the per-chunk splitter and the grid width when a
        # chunks attribute was actually given
        csplit = ChainNode([NFoldPartitioner(1, attr=chunks_attr),
                            Splitter('partitions', attr_values=[2])])
        ncols = len(dataset.sa[chunks_attr].unique)
    else:
        csplit = None
        ncols = 1

    def doplot(data):
        """Just a little helper which plots the histogram and removes
        ticks etc"""

        pl.hist(data, **kwargs)

        if xlim is not None:
            pl.xlim(xlim)

        if noticks:
            pl.yticks([])
            pl.xticks([])

    fig = 1

    # for all labels
    for row, ds in enumerate(lsplit.generate(dataset)):
        if chunks_attr:
            for col, d in enumerate(csplit.generate(ds)):

                pl.subplot(nrows, ncols, fig)
                doplot(d.samples.ravel())

                if row == 0:
                    pl.title('C:' + str(d.sa[chunks_attr].unique[0]))
                if col == 0:
                    pl.ylabel('L:' + str(d.sa[targets_attr].unique[0]))

                fig += 1
        else:
            pl.subplot(1, nrows, fig)
            doplot(ds.samples)

            pl.title('L:' + str(ds.sa[targets_attr].unique[0]))

            fig += 1
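
A hedged usage sketch for plot_feature_hist; the mvpa2.suite import path and the normal_feature_dataset generator are assumptions, not shown above. It plots a (targets x chunks) grid of feature-value histograms and forwards bins=20 to matplotlib's hist().

import pylab as pl
from mvpa2.suite import normal_feature_dataset, plot_feature_hist  # assumed

# synthetic dataset with 'targets' and 'chunks' sample attributes
ds = normal_feature_dataset(perlabel=20, nlabels=2, nfeatures=4, nchunks=2)
plot_feature_hist(ds, xlim=(-3, 3), chunks_attr='chunks', bins=20)
pl.show()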