Пример #1
0
def test_in_memory():
    """Check that a pickled MNIST epoch iterator does not embed the dataset.

    Advances a sequential epoch iterator over the MNIST training set by two
    batches of 256 examples and pickles it to a temporary file. It then
    asserts the file is smaller than 1MB, reloads it, and checks that
    iteration resumes at the third batch.
    """
    skip_if_not_available(datasets=['mnist'])
    # Load MNIST and get two batches
    mnist = MNIST('train')
    data_stream = DataStream(mnist,
                             iteration_scheme=SequentialScheme(
                                 num_examples=mnist.num_examples,
                                 batch_size=256))
    epoch = data_stream.get_epoch_iterator()
    next(epoch)
    features, _ = next(epoch)
    assert numpy.all(features == mnist.features[256:512])

    filename = None
    try:
        # Pickle the epoch and make sure that the data wasn't dumped.
        # delete=False so the file can be reopened by name once closed; it
        # is removed explicitly in the ``finally`` clause below.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            filename = f.name
            pickle_dump(epoch, f)
        assert os.path.getsize(filename) < 1024 * 1024  # Less than 1MB

        # Reload the epoch and make sure that the state was maintained
        del epoch
        with open(filename, 'rb') as f:
            epoch = cPickle.load(f)
    finally:
        if filename is not None:
            os.remove(filename)
    features, _ = next(epoch)
    assert numpy.all(features == mnist.features[512:768])
Пример #2
0
def secure_pickle_dump(object_, path):
    """Try pickling into a temporary file and then move.

    The object is first pickled into a temporary file created in the same
    directory as ``path``. Only after pickling succeeds is that file moved
    onto ``path``, so a failure while pickling does not clobber an existing
    file at ``path``.

    This is best-effort: any exception is logged and swallowed, and the
    partially written temporary file (if one was created) is removed.

    Parameters
    ----------
    object_ : object
        The object to pickle.
    path : str
        Destination file path.
    """
    temp_name = None
    try:
        # Keep the temporary file next to the destination so ``shutil.move``
        # stays on one filesystem. ``dirname`` is '' for a bare file name,
        # which means the current directory.
        dirname = os.path.dirname(path) or os.curdir
        with tempfile.NamedTemporaryFile(delete=False, dir=dirname) as temp:
            temp_name = temp.name
            pickle_dump(object_, temp)
        shutil.move(temp_name, path)
        temp_name = None  # the file now lives at ``path``; nothing to clean
    except Exception as e:
        if temp_name is not None:
            try:
                os.remove(temp_name)
            except OSError:
                # Cleanup is best-effort; the original error is logged below.
                pass
        logger.error(" Error {0}".format(str(e)))
Пример #3
0
def test_load_log():
    """Check that ``plot.load_log`` reads a log back from a pickle file.

    Covers both a bare pickled ``TrainingLog`` and a pickled ``MainLoop``
    that carries that log.
    """
    def assert_log_roundtrip(obj):
        # Pickle ``obj`` and close the file before ``plot.load_log`` reopens
        # it by name. Reopening a still-open NamedTemporaryFile fails on
        # Windows, hence delete=False plus explicit removal.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            with f:
                pickle_dump(obj, f)
            log2 = plot.load_log(f.name)
        finally:
            os.remove(f.name)
        assert log2[0].channel0 == 0

    log = TrainingLog()
    log[0].channel0 = 0

    # test simple TrainingLog pickles
    assert_log_roundtrip(log)

    # test MainLoop pickles
    main_loop = MainLoop(model=None, data_stream=None,
                         algorithm=None, log=log)
    assert_log_roundtrip(main_loop)
Пример #4
0
def test_text():
    """Exercise ``TextFile`` at word level and at character level.

    Word level: checks the number of sentences in an epoch, and that a
    pickled, partially consumed epoch iterator resumes where it left off.
    Character level: checks the encoded start and end of the first sentence.
    """
    filenames = []

    def write_temp(text):
        # delete=False so TextFile can reopen the file by name after it is
        # closed; every file is removed in the ``finally`` clause below.
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
            filenames.append(f.name)
            f.write(text)
        return f.name

    try:
        # Test word level and epochs.
        sentences1 = write_temp("This is a sentence\nThis another one")
        sentences2 = write_temp("More sentences\nThe last one")
        dictionary = {'<UNK>': 0, '</S>': 1, 'this': 2, 'a': 3, 'one': 4}
        text_data = TextFile(files=[sentences1, sentences2],
                             dictionary=dictionary,
                             bos_token=None,
                             preprocess=lower)
        stream = text_data.get_default_stream()
        epoch = stream.get_epoch_iterator()
        assert len(list(epoch)) == 4
        epoch = stream.get_epoch_iterator()
        # Consume three of the four sentences before pickling.
        for _ in range(3):
            next(epoch)
        buffer = BytesIO()
        pickle_dump(epoch, buffer)
        sentence = next(epoch)
        buffer.seek(0)
        epoch = cPickle.load(buffer)
        assert next(epoch) == sentence
        assert_raises(StopIteration, next, epoch)

        # Test character level.
        dictionary = {chr(ord('a') + i): i for i in range(26)}
        dictionary.update({' ': 26, '<S>': 27, '</S>': 28, '<UNK>': 29})
        text_data = TextFile(files=[sentences1, sentences2],
                             dictionary=dictionary,
                             preprocess=lower,
                             level="character")
        sentence = next(
            text_data.get_default_stream().get_epoch_iterator())[0]
        assert sentence[:3] == [27, 19, 7]
        assert sentence[-3:] == [2, 4, 28]
    finally:
        for filename in filenames:
            os.remove(filename)
Пример #5
0
def test_load_log():
    """Check that ``plot.load_log`` reads a log back from a pickle file.

    Covers both a bare pickled ``TrainingLog`` and a pickled ``MainLoop``
    that carries that log.
    """
    def assert_log_roundtrip(obj):
        # Pickle ``obj`` and close the file before ``plot.load_log`` reopens
        # it by name. Reopening a still-open NamedTemporaryFile fails on
        # Windows, hence delete=False plus explicit removal.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            with f:
                pickle_dump(obj, f)
            log2 = plot.load_log(f.name)
        finally:
            os.remove(f.name)
        assert log2[0].channel0 == 0

    log = TrainingLog()
    log[0].channel0 = 0

    # test simple TrainingLog pickles
    assert_log_roundtrip(log)

    # test MainLoop pickles
    main_loop = MainLoop(model=None, data_stream=None,
                         algorithm=None, log=log)
    assert_log_roundtrip(main_loop)
Пример #6
0
 def dump_log(self, main_loop):
     """Pickle ``main_loop.log`` into the file at ``self.path_to_log``.

     The file is opened in binary write mode, so any existing content is
     replaced.
     """
     log_path = self.path_to_log
     with open(log_path, "wb") as log_file:
         pickle_dump(main_loop.log, log_file)
Пример #7
0
 def dump_iteration_state(self, main_loop):
     """Pickle ``main_loop.iteration_state`` into
     ``self.path_to_iteration_state``.

     The file is opened in binary write mode, so any existing content is
     replaced.
     """
     state_path = self.path_to_iteration_state
     with open(state_path, "wb") as state_file:
         pickle_dump(main_loop.iteration_state, state_file)
Пример #8
0
 def dump_log(self, main_loop):
     """Serialize the main loop's log with ``pickle_dump``.

     Output goes to ``self.path_to_log``, opened in binary write mode
     (existing content is replaced).
     """
     target = self.path_to_log
     with open(target, "wb") as stream:
         pickle_dump(main_loop.log, stream)
Пример #9
0
 def dump_iteration_state(self, main_loop):
     """Serialize the main loop's iteration state with ``pickle_dump``.

     Output goes to ``self.path_to_iteration_state``, opened in binary
     write mode (existing content is replaced).
     """
     target = self.path_to_iteration_state
     with open(target, "wb") as stream:
         pickle_dump(main_loop.iteration_state, stream)