def setup_mnist_corpus(datafiles, labels, ngram_size, shuffle=True, limit_instance_count=None, subsets=("train", )):
    """Build a Corpus over MNIST data files using a single 28x28 MnistVP viewpoint.

    Args:
        datafiles: passed unchanged to ``Corpus.load_files``.
        labels: accepted for interface compatibility but not used by this
            function.
        ngram_size: passed positionally to ``Corpus.set_to_ngram``.
        shuffle: when truthy, ``shuffle_instances()`` is called after the
            n-gram conversion.
        limit_instance_count: when not None, forwarded to
            ``Corpus.limit_instance_count`` as the final step.
        subsets: bound into ``load_mnist_files`` via ``functools.partial``
            (defaults to ``("train",)``).

    Returns:
        The configured Corpus instance.
    """
    mnist_viewpoints = (MnistVP((28, 28)),)
    # Pre-bind the subset selection so Corpus can call the loader with its own arguments.
    mnist_loader = partial(load_mnist_files, subsets=subsets)
    result = Corpus(mnist_loader, mnist_viewpoints)

    LOGGER.info('loading data files')
    result.load_files(datafiles)

    LOGGER.info('constructing ngrams')
    result.set_to_ngram(ngram_size)

    if shuffle:
        result.shuffle_instances()
    # Capping happens after shuffling, so the kept subset is not simply the first files' instances.
    if limit_instance_count is not None:
        result.limit_instance_count(limit_instance_count)
    return result
def setup_midi_corpus(datafiles, ngram_size, rebuild=False):
    """Build a Corpus over MIDI files using a pitch viewpoint restricted to 36..100.

    Args:
        datafiles: passed unchanged to ``Corpus.load_files``.
        ngram_size: forwarded as ``ngram_size=`` to ``Corpus.set_to_ngram``.
        rebuild: forwarded as ``clear=`` to ``Corpus.load_files``
            (presumably discards previously loaded/cached data; confirm
            against ``Corpus``).

    Returns:
        The configured Corpus instance.
    """
    pitch_viewpoints = (MidiPitchVP(min_pitch=36, max_pitch=100),)
    midi_corpus = Corpus(load_midi_files, pitch_viewpoints)

    LOGGER.info('loading data files')
    midi_corpus.load_files(datafiles, clear=rebuild)

    midi_corpus.set_to_ngram(ngram_size=ngram_size)
    return midi_corpus
def setup_interval_corpus(datafiles, ngram_size, rebuild=False):
    """Build a Corpus over MIDI files using a pitch-interval viewpoint.

    Args:
        datafiles: passed unchanged to ``Corpus.load_files``.
        ngram_size: forwarded as ``ngram_size=`` to ``Corpus.set_to_ngram``.
            The previous implementation ignored this argument and always
            used 1. It is now honoured, matching ``setup_midi_corpus``.
        rebuild: forwarded as ``clear=`` to ``Corpus.load_files``
            (presumably discards previously loaded/cached data; confirm
            against ``Corpus``).

    Returns:
        The configured Corpus instance.
    """
    viewpoints = (MidiPitchIntervalVP(), )
    corpus = Corpus(load_midi_files, viewpoints)
    LOGGER.info('loading data files')
    corpus.load_files(datafiles, clear=rebuild)
    # NOTE(review): callers that relied on the old hard-coded unigram
    # behaviour should pass ngram_size=1 explicitly.
    corpus.set_to_ngram(ngram_size=ngram_size)
    return corpus