Code Example #1
File: utils.py Project: zggl/discgen
from fuel.datasets import SVHN


def create_svhn_streams(training_batch_size, monitoring_batch_size):
    """Creates SVHN data streams.

    Parameters
    ----------
    training_batch_size : int
        Batch size for training.
    monitoring_batch_size : int
        Batch size for monitoring.

    Returns
    -------
    rval : tuple of data streams
        Data streams for the main loop, the training set monitor,
        the validation set monitor and the test set monitor.

    """
    # Hold out the last 10,000 of the 73,257 'train' examples for validation.
    train_set = SVHN(2, ('train',), sources=('features',),
                     subset=slice(0, 63257))
    valid_set = SVHN(2, ('train',), sources=('features',),
                     subset=slice(63257, 73257))
    test_set = SVHN(2, ('test',), sources=('features',))

    # create_streams is a project-local helper defined elsewhere in utils.py.
    return create_streams(train_set, valid_set, test_set, training_batch_size,
                          monitoring_batch_size)
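
A minimal usage sketch (batch sizes here are hypothetical; it assumes the SVHN format-2 HDF5 file is available under Fuel's data path and that create_streams is importable alongside this function):

main_loop_stream, train_monitor, valid_monitor, test_monitor = (
    create_svhn_streams(training_batch_size=128, monitoring_batch_size=500))
# Pull one training batch; sources=('features',) yields one-element tuples.
features, = next(main_loop_stream.get_epoch_iterator())
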
Code Example #2
import h5py
import numpy
import theano
from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph
from blocks.roles import OUTPUT
from blocks.select import Selector
from fuel.converters.base import fill_hdf5_file
from fuel.datasets import SVHN
from fuel.schemes import SequentialScheme
from fuel.streams import DataStream
from theano import tensor


def preprocess_svhn(main_loop, save_path):
    h5file = h5py.File(save_path, mode='w')

    # Pull the trained ALI brick out of the main loop's model.
    ali, = Selector(main_loop.model.top_bricks).select('/ali').bricks
    x = tensor.tensor4('features')
    y = tensor.imatrix('targets')
    # The first `_nlat` encoder outputs are the latent means (`mu`).
    params = ali.encoder.apply(x)
    mu = params[:, :ali.encoder._nlat]
    # Use the latent means together with the outputs of three intermediate
    # encoder layers as features.
    acts = []
    acts += [mu]
    acts += VariableFilter(bricks=[
        ali.encoder.layers[-9], ali.encoder.layers[-6], ali.encoder.layers[-3]
    ],
                           roles=[OUTPUT])(ComputationGraph([mu]).variables)
    # Flatten and concatenate everything into one feature vector per example.
    output = tensor.concatenate([act.flatten(ndim=2) for act in acts], axis=1)
    preprocess = theano.function([x, y], [output.flatten(ndim=2), y])

    # Encode the training set in batches of 100 and stack the results.
    train_set = SVHN(2,
                     which_sets=('train', ),
                     sources=('features', 'targets'))
    train_stream = DataStream.default_stream(train_set,
                                             iteration_scheme=SequentialScheme(
                                                 train_set.num_examples, 100))
    train_features, train_targets = map(
        numpy.vstack,
        list(
            zip(*[
                preprocess(*batch)
                for batch in train_stream.get_epoch_iterator()
            ])))

    # Do the same for the test set.
    test_set = SVHN(2, which_sets=('test', ), sources=('features', 'targets'))
    test_stream = DataStream.default_stream(test_set,
                                            iteration_scheme=SequentialScheme(
                                                test_set.num_examples, 100))
    test_features, test_targets = map(
        numpy.vstack,
        list(
            zip(*[
                preprocess(*batch)
                for batch in test_stream.get_epoch_iterator()
            ])))

    # Write both splits to the new HDF5 file and label the axes.
    data = (('train', 'features', train_features), ('test', 'features',
                                                    test_features),
            ('train', 'targets', train_targets), ('test', 'targets',
                                                  test_targets))
    fill_hdf5_file(h5file, data)
    for i, label in enumerate(('batch', 'feature')):
        h5file['features'].dims[i].label = label
    for i, label in enumerate(('batch', 'index')):
        h5file['targets'].dims[i].label = label

    h5file.flush()
    h5file.close()
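
The function expects a trained ALI main loop. A minimal invocation sketch, assuming a checkpoint saved with Blocks' serialization (both file paths are hypothetical):

from blocks.serialization import load

# Load a previously trained main loop and extract features from it.
with open('ali_svhn_checkpoint.tar', 'rb') as src:
    main_loop = load(src)
preprocess_svhn(main_loop, 'svhn_ali_features.hdf5')
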
Code Example #3
File: streams.py Project: MiriamHu/ALI
from fuel.datasets import SVHN
from fuel.schemes import ShuffledScheme
from fuel.streams import DataStream


def create_svhn_data_streams(batch_size, monitoring_batch_size, rng=None):
    # Train on the large 'extra' split; use the official 'train' split for
    # validation monitoring.
    train_set = SVHN(2, ('extra', ), sources=('features', ))
    valid_set = SVHN(2, ('train', ), sources=('features', ))
    main_loop_stream = DataStream.default_stream(
        train_set,
        iteration_scheme=ShuffledScheme(train_set.num_examples,
                                        batch_size,
                                        rng=rng))
    # Monitoring streams each draw from 5,000 shuffled examples.
    train_monitor_stream = DataStream.default_stream(
        train_set,
        iteration_scheme=ShuffledScheme(5000, monitoring_batch_size, rng=rng))
    valid_monitor_stream = DataStream.default_stream(
        valid_set,
        iteration_scheme=ShuffledScheme(5000, monitoring_batch_size, rng=rng))
    return main_loop_stream, train_monitor_stream, valid_monitor_stream
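
A usage sketch (batch sizes are hypothetical); passing a seeded RandomState makes the shuffling reproducible:

import numpy

rng = numpy.random.RandomState(1234)
main_loop_stream, train_monitor, valid_monitor = create_svhn_data_streams(
    batch_size=100, monitoring_batch_size=500, rng=rng)
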
Code Example #4
def get_svhn(split, sources, load_in_memory):
    from fuel.datasets import SVHN
    if 'test' not in split:
        # Carve a validation set out of the official 'train' split: the
        # first 62,000 examples for training, the next 10,000 for validation.
        subset = slice(0, 62000) if 'train' in split else slice(62000, 72000)
        split = ('train', )
    else:
        subset = None
    return SVHN(2,
                split,
                sources=sources,
                subset=subset,
                load_in_memory=load_in_memory)
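
A sketch of how the three splits might be requested (argument values are hypothetical; any split name without 'test' in it is remapped onto a subset of the official 'train' set):

train_set = get_svhn(('train',), sources=('features', 'targets'),
                     load_in_memory=True)
valid_set = get_svhn(('valid',), sources=('features', 'targets'),
                     load_in_memory=True)
test_set = get_svhn(('test',), sources=('features', 'targets'),
                    load_in_memory=False)
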
Code Example #5
import os

import h5py
import numpy
from numpy.testing import assert_equal

from fuel import config
from fuel.datasets import H5PYDataset, SVHN


def test_svhn():
    data_path = config.data_path
    try:
        # Point Fuel's data path at the current directory and write a tiny
        # dummy file in the SVHN format-2 layout.
        config.data_path = '.'
        f = h5py.File('svhn_format_2.hdf5', 'w')
        f['features'] = numpy.arange(100, dtype='uint8').reshape((10, 10))
        f['targets'] = numpy.arange(10, dtype='uint8').reshape((10, 1))
        split_dict = {
            'train': {
                'features': (0, 8),
                'targets': (0, 8)
            },
            'test': {
                'features': (8, 10),
                'targets': (8, 10)
            }
        }
        f.attrs['split'] = H5PYDataset.create_split_array(split_dict)
        f.close()
        # Instantiating the dataset should resolve the format-2 filename.
        dataset = SVHN(which_format=2, which_sets=('train', ))
        assert_equal(dataset.filename, 'svhn_format_2.hdf5')
    finally:
        config.data_path = data_path
        os.remove('svhn_format_2.hdf5')
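
The test is written for a runner such as nose or pytest; a direct invocation sketch:

if __name__ == '__main__':
    test_svhn()
    print('test_svhn passed')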