def test_save_models_output_dir(tmpdir):
    """SaveModels writes the model file under ``output_dir`` when one is given.

    The dict key is a bare file name; after the epoch-10 callback the file
    is expected at ``<output_dir>/<file name>``.
    """
    out_dir = str(tmpdir)
    model_name = "test.hdf5"

    model = Sequential()
    model.add(Dense(3, input_dim=3))

    callback = SaveModels({model_name: model}, output_dir=out_dir, every_epoch=10)
    callback.on_epoch_end(10)

    expected_path = os.path.join(out_dir, model_name)
    assert os.path.exists(expected_path)
def test_save_models_overwrite(tmpdir):
    """With ``overwrite=False``, a second save to an existing path raises ``OSError``.

    The first callback at epoch 10 must create the file. Repeating the
    same epoch-10 callback targets the same path and is expected to fail
    with ``OSError``.
    """
    m = Sequential()
    m.add(Dense(3, input_dim=3))
    # Join with os.path.join (as test_save_models_output_dir does) rather
    # than concatenating a hard-coded "/" separator onto the directory.
    fname = os.path.join(str(tmpdir), "test.hdf5")
    cb = SaveModels({fname: m}, every_epoch=10, overwrite=False)
    cb.on_epoch_end(10)
    assert os.path.exists(fname)
    with pytest.raises(OSError):
        cb.on_epoch_end(10)
def test_save_models(tmpdir):
    """SaveModels saves only on the epochs selected by ``every_epoch``.

    With ``every_epoch=10``, the epoch-0 callback must not create the file.
    The epoch-10 callback must create it at the path given as the dict key.
    """
    m = Sequential()
    m.add(Dense(3, input_dim=3))
    # Join with os.path.join (as test_save_models_output_dir does) rather
    # than concatenating a hard-coded "/" separator onto the directory.
    fname = os.path.join(str(tmpdir), "test.hdf5")
    cb = SaveModels({fname: m}, every_epoch=10)
    cb.on_epoch_end(0)
    assert not os.path.exists(fname)
    cb.on_epoch_end(10)
    assert os.path.exists(fname)