def test_in_memory():
    """Check that an in-memory dataset's epoch iterator pickles compactly
    and resumes iteration from the point at which it was pickled.
    """
    skip_if_not_available(datasets=['mnist'])

    # Load MNIST and advance the iterator past the second 256-example batch.
    mnist = MNIST('train')
    data_stream = DataStream(mnist, iteration_scheme=SequentialScheme(
        num_examples=mnist.num_examples, batch_size=256))
    epoch = data_stream.get_epoch_iterator()
    for batch_index, (features, targets) in enumerate(epoch):
        if batch_index == 1:
            break
    assert numpy.all(features == mnist.features[256:512])

    # Pickle the epoch iterator; the underlying data must NOT be dumped
    # with it, so the file stays small.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        filename = f.name
        pickle_dump(epoch, f)
    assert os.path.getsize(filename) < 1024 * 1024  # Less than 1MB

    # Reload the iterator and verify its position was preserved: the next
    # batch is the third one.
    del epoch
    with open(filename, 'rb') as f:
        epoch = cPickle.load(f)
    features, targets = next(epoch)
    assert numpy.all(features == mnist.features[512:768])
def secure_pickle_dump(object_, path):
    """Try pickling into a temporary file and then move.

    The object is first pickled into a temporary file created in the
    same directory as ``path`` (so the final ``shutil.move`` is a cheap
    same-filesystem rename) and only moved into place once the dump has
    completed.  This way a failed dump never corrupts an existing file
    at ``path``.

    On failure the error is logged rather than propagated (the original
    best-effort semantics), but the partially written temporary file is
    now removed — previously the cleanup was commented out, leaking one
    stray temp file per failed dump.

    Parameters
    ----------
    object_ : object
        The object to pickle.
    path : str
        Destination path for the pickled file.

    """
    temp = None
    try:
        dirname = os.path.dirname(path)
        with tempfile.NamedTemporaryFile(delete=False, dir=dirname) as temp:
            pickle_dump(object_, temp)
        # Move only after the file is closed (required on Windows).
        shutil.move(temp.name, path)
    except Exception as e:
        # Bug fix: remove the half-written temporary file instead of
        # leaking it (this cleanup used to be commented out).
        if temp is not None and os.path.exists(temp.name):
            os.remove(temp.name)
        logger.error(" Error {0}".format(str(e)))
def test_load_log():
    """Verify ``plot.load_log`` reads channels back from pickled logs,
    whether the pickle holds a bare ``TrainingLog`` or a ``MainLoop``.
    """
    log = TrainingLog()
    log[0].channel0 = 0

    # A bare TrainingLog should round-trip through a pickle file.
    with tempfile.NamedTemporaryFile() as f:
        pickle_dump(log, f)
        f.flush()
        reloaded = plot.load_log(f.name)
        assert reloaded[0].channel0 == 0

    # So should a log embedded in a full MainLoop pickle.
    main_loop = MainLoop(model=None, data_stream=None,
                         algorithm=None, log=log)
    with tempfile.NamedTemporaryFile() as f:
        pickle_dump(main_loop, f)
        f.flush()
        reloaded = plot.load_log(f.name)
        assert reloaded[0].channel0 == 0
def test_text():
    """Exercise ``TextFile`` at word and character level, including
    pickling a half-consumed epoch iterator and resuming it.
    """
    # Write two small text files to stream over.
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
        sentences1 = f.name
        f.write("This is a sentence\n")
        f.write("This another one")
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
        sentences2 = f.name
        f.write("More sentences\n")
        f.write("The last one")

    # Word level: a tiny dictionary, no beginning-of-sentence token.
    dictionary = {'<UNK>': 0, '</S>': 1, 'this': 2, 'a': 3, 'one': 4}
    text_data = TextFile(files=[sentences1, sentences2],
                         dictionary=dictionary, bos_token=None,
                         preprocess=lower)
    stream = text_data.get_default_stream()
    epoch = stream.get_epoch_iterator()
    assert len(list(epoch)) == 4

    # Consume three sentences, pickle the iterator mid-epoch, and check
    # that the restored copy resumes exactly where the original does.
    epoch = stream.get_epoch_iterator()
    for _ in zip(range(3), epoch):
        pass
    f = BytesIO()
    pickle_dump(epoch, f)
    sentence = next(epoch)
    f.seek(0)
    epoch = cPickle.load(f)
    assert next(epoch) == sentence
    assert_raises(StopIteration, next, epoch)

    # Character level: lowercase letters, space, and special tokens.
    dictionary = dict([(chr(ord('a') + i), i) for i in range(26)] +
                      [(' ', 26)] + [('<S>', 27)] + [('</S>', 28)] +
                      [('<UNK>', 29)])
    text_data = TextFile(files=[sentences1, sentences2],
                         dictionary=dictionary, preprocess=lower,
                         level="character")
    sentence = next(text_data.get_default_stream().get_epoch_iterator())[0]
    assert sentence[:3] == [27, 19, 7]
    assert sentence[-3:] == [2, 4, 28]
def dump_log(self, main_loop):
    """Pickle ``main_loop.log`` to the configured log path."""
    with open(self.path_to_log, "wb") as sink:
        pickle_dump(main_loop.log, sink)
def dump_iteration_state(self, main_loop):
    """Pickle ``main_loop.iteration_state`` to the configured path."""
    with open(self.path_to_iteration_state, "wb") as sink:
        pickle_dump(main_loop.iteration_state, sink)