def testCloneModelWithDefinitionOnly(self):
  """ Cloning a model that has a definition but no checkpoint copies the
  definition only; cloning again onto the same destination is rejected.
  """
  mgr = ModelCheckpointMgr()
  srcID = uuid.uuid1().hex
  dstID = uuid.uuid1().hex

  # Source model: definition only, never checkpointed
  definition = {'a': 1, 'b': 2, 'c': 3}
  mgr.define(srcID, definition)

  mgr.clone(srcID, dstID)

  # The clone carries an identical definition...
  self.assertDictEqual(mgr.loadModelDefinition(dstID), definition)

  # ...but, like its source, has no checkpoint to load
  self.assertRaises(ModelNotFound, mgr.load, dstID)

  # The destination exists now, so a second clone onto it must fail
  self.assertRaises(ModelAlreadyExists, mgr.clone, srcID, dstID)
def testStorageDirOverrideViaEnvironmentVariable(self):
  """ A model's entry directory must be located under the patched
  (overridden) checkpoint storage root.
  """
  with ModelCheckpointStoragePatch() as patch:
    mgr = ModelCheckpointMgr()
    overrideRoot = patch.tempModelCheckpointDir
    entryDir = mgr._getModelDir(modelID="abc", mustExist=False)
    self.assertIn(overrideRoot, entryDir)
def testRemoveAll(self):
  """ Test that removeAll deletes every model entry from the checkpoint store
  """
  checkpointMgr = ModelCheckpointMgr()

  # Should be empty at first
  self.assertSequenceEqual(checkpointMgr.getModelIDs(), [])

  # Create some model entries using definitions only (no checkpoints)
  expModelIDs = sorted([uuid.uuid1().hex, uuid.uuid1().hex])
  for modelID in expModelIDs:
    checkpointMgr.define(modelID, definition={'a': 1})

  # Compare sorted lists so the check does not depend on the order returned
  # by getModelIDs(). This avoids assertItemsEqual, which only exists in
  # Python 2's unittest (Python 3 renamed it assertCountEqual); a sorted
  # comparison checks the same multiset equality on both versions.
  self.assertListEqual(sorted(checkpointMgr.getModelIDs()), expModelIDs)

  # Delete checkpoint store
  ModelCheckpointMgr.removeAll()

  self.assertSequenceEqual(checkpointMgr.getModelIDs(), [])
def testUpdateCheckpointAttributesNoModelEntry(self):
  """ updateCheckpointAttributes must raise ModelNotFound for a model ID that
  has no entry at all (never defined)
  """
  mgr = ModelCheckpointMgr()
  unknownModelID = uuid.uuid1().hex

  self.assertRaises(ModelNotFound, mgr.updateCheckpointAttributes,
                    unknownModelID, "attributes")
def testUpdateCheckpointAttributesWithNoModeCheckpointException(self):
  """ updateCheckpointAttributes must raise ModelNotFound when the model is
  defined but no checkpoint has been saved for it yet
  """
  mgr = ModelCheckpointMgr()
  definedModelID = uuid.uuid1().hex

  # Definition exists, checkpoint does not
  mgr.define(definedModelID, definition={"a": 1, "b": 2})

  self.assertRaises(ModelNotFound, mgr.updateCheckpointAttributes,
                    definedModelID, "attributes")
def testCloneModelFromNonExistentSourceRaisesModelNotFound(self):
  """ Cloning from a model ID that was never defined raises ModelNotFound,
  and the failed attempt must not leave anything behind that changes the
  outcome of a retry.
  """
  mgr = ModelCheckpointMgr()
  missingSrcID = uuid.uuid1().hex  # deliberately never defined
  dstID = uuid.uuid1().hex

  # Two identical attempts: the second one guards against the first having
  # produced an unwanted side effect
  for _attempt in range(2):
    with self.assertRaises(ModelNotFound):
      mgr.clone(missingSrcID, dstID)
def testDefineModel(self):
  """ A model definition stored with define() must round-trip through
  loadModelDefinition()
  """
  mgr = ModelCheckpointMgr()
  modelID = uuid.uuid1().hex
  definition = {'a': 1, 'b': 2, 'c': 3}

  # Before define() there is nothing to load
  with self.assertRaises(ModelNotFound):
    mgr.loadModelDefinition(modelID)

  mgr.define(modelID, definition=definition)

  # The loaded definition must equal what was stored
  self.assertDictEqual(mgr.loadModelDefinition(modelID), definition)
def __init__(self, modelID):
  """
  :param modelID: model ID; string
  """
  self._modelID = modelID

  # OPF model (from ModelFactory or a checkpoint) and the _InputRowEncoder
  # that feeds it; both are created by loadModel()
  self._model = self._inputRowEncoder = None

  self._checkpointMgr = ModelCheckpointMgr()

  # Becomes True once the model is loaded from a checkpoint or a full
  # checkpoint is written; stays False for a model built from params.
  # An incremental checkpoint is only possible when this is True.
  self._hasCheckpoint = False

  # Caches for checkpoint attributes: batch IDs of the current checkpoint,
  # and input samples accumulated since the last full checkpoint.
  # None means "not loaded yet".
  self._modelCheckpointBatchIDSetCache = None
  self._inputSamplesSinceLastFullCheckpointCache = None
def testUpdateCheckpointAttributes(self):
  """ updateCheckpointAttributes replaces the attributes stored with an
  existing checkpoint
  """
  mgr = ModelCheckpointMgr()
  modelID = uuid.uuid1().hex

  model = ModelFactory.create(self._getModelParams("variant1"))

  # Define the model, then write a checkpoint with initial attributes
  mgr.define(modelID, definition={"a": 1, "b": 2})
  mgr.save(modelID, model, attributes="attributes1")
  self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes1")

  # Overwrite with a structured value and confirm it round-trips
  replacement = {"this": [1, 2, 3], "that": "abc"}
  mgr.updateCheckpointAttributes(modelID, replacement)
  self.assertEqual(mgr.loadCheckpointAttributes(modelID), replacement)
class _ModelArchiver(object): """ Helper class for loading/creating and checkpointing model """ # Name of the attribute that is stored as an integral component of the # checkpoint. This attribute contains a list of batchIDs that were processed # since the previous checkpoint. _BATCH_IDS_CHECKPOINT_ATTR_NAME = "batchIDs" # Name of the attribute that is stored as an integral component of the # checkpoint. It contains a list of ModelInputRow objects processed since # the last full checkpoint for use in preparing the last incremental # model checkpoint for new input. The value is in pickle string format. _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME = "incrementalInputSamples" _MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS = 100 def __init__(self, modelID): """ :param modelID: model ID; string """ self._modelID = modelID # The model object from OPF ModelFactory; set up by the loadModel() method self._model = None # The _InputRowEncoder object; set up by the loadModel() method self._inputRowEncoder = None self._checkpointMgr = ModelCheckpointMgr() # True if model was loaded from existing checkpoint or after a full # checkpoint is made; False initially if model was created from # params. This informs us whether an incremental checkpoint is possible. 
self._hasCheckpoint = False self._modelCheckpointBatchIDSetCache = None # Input data samples that have accumulated since last full checkpoint self._inputSamplesSinceLastFullCheckpointCache = None @property def model(self): """ An OPF Model object or None if not loaded yet """ return self._model @property def inputRowEncoder(self): """ An _InputRowEncoder object or None if model not loaded yet """ return self._inputRowEncoder @property def modelCheckpointBatchIDSet(self): """ A sequence of input batch identifiers associated with current model checkpoint """ if self._modelCheckpointBatchIDSetCache is None: self._loadCheckpointAttributes() return self._modelCheckpointBatchIDSetCache @property def checkpointMgr(self): return self._checkpointMgr @property def _inputSamplesSinceLastFullCheckpoint(self): if self._inputSamplesSinceLastFullCheckpointCache is None: self._loadCheckpointAttributes() return self._inputSamplesSinceLastFullCheckpointCache @_inputSamplesSinceLastFullCheckpoint.setter def _inputSamplesSinceLastFullCheckpoint(self, value): self._inputSamplesSinceLastFullCheckpointCache = value @classmethod def _encodeDataSamples(cls, dataSamples): """ :param dataSamples: a sequence of data samples to be saved as the _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute :returns: a string encoding of the data samples """ return base64.standard_b64encode( pickle.dumps(dataSamples, pickle.HIGHEST_PROTOCOL)) @classmethod def _decodeDataSamples(cls, dataSamples): """ :param dataSamples: string-encoded data samples from the _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute :returns: a sequence of data samples """ return pickle.loads(base64.standard_b64decode(dataSamples)) def _loadCheckpointAttributes(self): # Load the checkpoint attributes try: checkpointAttributes = self._checkpointMgr.loadCheckpointAttributes( self._modelID) except model_checkpoint_mgr.ModelNotFound: self._modelCheckpointBatchIDSetCache = set() 
self._inputSamplesSinceLastFullCheckpoint = [] else: self._modelCheckpointBatchIDSetCache = set( checkpointAttributes[self._BATCH_IDS_CHECKPOINT_ATTR_NAME]) inputSamples = checkpointAttributes.get( self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME) if inputSamples: self._inputSamplesSinceLastFullCheckpoint = self._decodeDataSamples( inputSamples) else: self._inputSamplesSinceLastFullCheckpoint = [] def loadModel(self): """ Load the model and construct the input row encoder. On success, the loaded model may be accessed via the `model` attribute :raises: model_checkpoint_mgr.ModelNotFound """ if self._model is not None: return modelDefinition = None # Load the model try: self._model = self._checkpointMgr.load(self._modelID) self._hasCheckpoint = True except model_checkpoint_mgr.ModelNotFound: # So, we didn't have a checkpoint... try to create our model from model # definition params self._hasCheckpoint = False try: modelDefinition = self._checkpointMgr.loadModelDefinition( self._modelID) except model_checkpoint_mgr.ModelNotFound: raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL, msg="modelID=%s not found" % (self._modelID)) else: modelParams = modelDefinition["modelParams"] # TODO: when creating the model from params, do we need to call # its model.setFieldStatistics() method? And where will the # fieldStats come from, anyway? 
self._model = ModelFactory.create( modelConfig=modelParams["modelConfig"]) self._model.enableLearning() self._model.enableInference(modelParams["inferenceArgs"]) # Construct the object for converting a flat input row into a format # that is consumable by an OPF model try: if modelDefinition is None: modelDefinition = self._checkpointMgr.loadModelDefinition( self._modelID) except model_checkpoint_mgr.ModelNotFound: raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL, msg="modelID=%s not found" % (self._modelID)) else: inputSchema = modelDefinition["inputSchema"] # Convert it to a sequence of FieldMetaInfo instances # NOTE: if loadMetaInfo didn't raise, we expect "inputSchema" to be # present; it would be a logic error if it isn't. inputFieldsMeta = tuple(FieldMetaInfo(*f) for f in inputSchema) self._inputRowEncoder = _InputRowEncoder(fieldsMeta=inputFieldsMeta) # If the checkpoint was incremental, feed the cached data into the model for inputSample in self._inputSamplesSinceLastFullCheckpoint: # Convert a flat input sample into a format that is consumable by an OPF # model self._inputRowEncoder.appendRecord(inputSample) # Infer self._model.run(self._inputRowEncoder.getNextRecordDict()) def saveModel(self, currentRunBatchIDSet, currentRunInputSamples): """ :param currentRunBatchIDSet: a set of batch ids to be saved in model checkpoint attributes :param currentRunInputSamples: a sequence of model input data sample objects for incremental checkpoint; will be saved in checkpoint attributes if an incremental checkpoint is performed. 
""" if self._model is not None: self._modelCheckpointBatchIDSetCache = currentRunBatchIDSet.copy() if (not self._hasCheckpoint or (len(self._inputSamplesSinceLastFullCheckpoint) + len(currentRunInputSamples)) > self._MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS): # Perform a full checkpoint self._inputSamplesSinceLastFullCheckpointCache = [] self._checkpointMgr.save( modelID=self._modelID, model=self._model, attributes={ self._BATCH_IDS_CHECKPOINT_ATTR_NAME: list(self._modelCheckpointBatchIDSetCache) }) self._hasCheckpoint = True else: # Perform an incremental checkpoint self._inputSamplesSinceLastFullCheckpoint.extend( currentRunInputSamples) attributes = { self._BATCH_IDS_CHECKPOINT_ATTR_NAME: list(self._modelCheckpointBatchIDSetCache), self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME: self._encodeDataSamples( self._inputSamplesSinceLastFullCheckpoint) } self._checkpointMgr.updateCheckpointAttributes( self._modelID, attributes)
class _ModelArchiver(object): """ Helper class for loading/creating and checkpointing model """ # Name of the attribute that is stored as an integral component of the # checkpoint. This attribute contains a list of batchIDs that were processed # since the previous checkpoint. _BATCH_IDS_CHECKPOINT_ATTR_NAME = "batchIDs" # Name of the attribute that is stored as an integral component of the # checkpoint. It contains a list of ModelInputRow objects processed since # the last full checkpoint for use in preparing the last incremental # model checkpoint for new input. The value is in pickle string format. _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME = "incrementalInputSamples" _MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS = 100 def __init__(self, modelID): """ :param modelID: model ID; string """ self._modelID = modelID # The model object from OPF ModelFactory; set up by the loadModel() method self._model = None # The _InputRowEncoder object; set up by the loadModel() method self._inputRowEncoder = None self._checkpointMgr = ModelCheckpointMgr() # True if model was loaded from existing checkpoint or after a full # checkpoint is made; False initially if model was created from # params. This informs us whether an incremental checkpoint is possible. 
self._hasCheckpoint = False self._modelCheckpointBatchIDSetCache = None # Input data samples that have accumulated since last full checkpoint self._inputSamplesSinceLastFullCheckpointCache = None @property def model(self): """ An OPF Model object or None if not loaded yet """ return self._model @property def inputRowEncoder(self): """ An _InputRowEncoder object or None if model not loaded yet """ return self._inputRowEncoder @property def modelCheckpointBatchIDSet(self): """ A sequence of input batch identifiers associated with current model checkpoint """ if self._modelCheckpointBatchIDSetCache is None: self._loadCheckpointAttributes() return self._modelCheckpointBatchIDSetCache @property def checkpointMgr(self): return self._checkpointMgr @property def _inputSamplesSinceLastFullCheckpoint(self): if self._inputSamplesSinceLastFullCheckpointCache is None: self._loadCheckpointAttributes() return self._inputSamplesSinceLastFullCheckpointCache @_inputSamplesSinceLastFullCheckpoint.setter def _inputSamplesSinceLastFullCheckpoint(self, value): self._inputSamplesSinceLastFullCheckpointCache = value @classmethod def _encodeDataSamples(cls, dataSamples): """ :param dataSamples: a sequence of data samples to be saved as the _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute :returns: a string encoding of the data samples """ return base64.standard_b64encode(pickle.dumps(dataSamples, pickle.HIGHEST_PROTOCOL)) @classmethod def _decodeDataSamples(cls, dataSamples): """ :param dataSamples: string-encoded data samples from the _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute :returns: a sequence of data samples """ return pickle.loads(base64.standard_b64decode(dataSamples)) def _loadCheckpointAttributes(self): # Load the checkpoint attributes try: checkpointAttributes = self._checkpointMgr.loadCheckpointAttributes( self._modelID) except model_checkpoint_mgr.ModelNotFound: self._modelCheckpointBatchIDSetCache = set() 
self._inputSamplesSinceLastFullCheckpoint = [] else: self._modelCheckpointBatchIDSetCache = set( checkpointAttributes[self._BATCH_IDS_CHECKPOINT_ATTR_NAME]) inputSamples = checkpointAttributes.get( self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME) if inputSamples: self._inputSamplesSinceLastFullCheckpoint = self._decodeDataSamples( inputSamples) else: self._inputSamplesSinceLastFullCheckpoint = [] def loadModel(self): """ Load the model and construct the input row encoder. On success, the loaded model may be accessed via the `model` attribute :raises: model_checkpoint_mgr.ModelNotFound """ if self._model is not None: return modelDefinition = None # Load the model try: self._model = self._checkpointMgr.load(self._modelID) self._hasCheckpoint = True except model_checkpoint_mgr.ModelNotFound: # So, we didn't have a checkpoint... try to create our model from model # definition params self._hasCheckpoint = False try: modelDefinition = self._checkpointMgr.loadModelDefinition(self._modelID) except model_checkpoint_mgr.ModelNotFound: raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL, msg="modelID=%s not found" % (self._modelID)) else: modelParams = modelDefinition["modelParams"] # TODO: when creating the model from params, do we need to call # its model.setFieldStatistics() method? And where will the # fieldStats come from, anyway? 
self._model = ModelFactory.create( modelConfig=modelParams["modelConfig"]) self._model.enableLearning() self._model.enableInference(modelParams["inferenceArgs"]) # Construct the object for converting a flat input row into a format # that is consumable by an OPF model try: if modelDefinition is None: modelDefinition = self._checkpointMgr.loadModelDefinition( self._modelID) except model_checkpoint_mgr.ModelNotFound: raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL, msg="modelID=%s not found" % (self._modelID)) else: inputSchema = modelDefinition["inputSchema"] # Convert it to a sequence of FieldMetaInfo instances # NOTE: if loadMetaInfo didn't raise, we expect "inputSchema" to be # present; it would be a logic error if it isn't. inputFieldsMeta = tuple(FieldMetaInfo(*f) for f in inputSchema) self._inputRowEncoder = _InputRowEncoder(fieldsMeta=inputFieldsMeta) # If the checkpoint was incremental, feed the cached data into the model for inputSample in self._inputSamplesSinceLastFullCheckpoint: # Convert a flat input sample into a format that is consumable by an OPF # model self._inputRowEncoder.appendRecord(inputSample) # Infer self._model.run(self._inputRowEncoder.getNextRecordDict()) def saveModel(self, currentRunBatchIDSet, currentRunInputSamples): """ :param currentRunBatchIDSet: a set of batch ids to be saved in model checkpoint attributes :param currentRunInputSamples: a sequence of model input data sample objects for incremental checkpoint; will be saved in checkpoint attributes if an incremental checkpoint is performed. 
""" if self._model is not None: self._modelCheckpointBatchIDSetCache = currentRunBatchIDSet.copy() if (not self._hasCheckpoint or (len(self._inputSamplesSinceLastFullCheckpoint) + len(currentRunInputSamples)) > self._MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS): # Perform a full checkpoint self._inputSamplesSinceLastFullCheckpointCache = [] self._checkpointMgr.save( modelID=self._modelID, model=self._model, attributes={ self._BATCH_IDS_CHECKPOINT_ATTR_NAME: list(self._modelCheckpointBatchIDSetCache)}) self._hasCheckpoint = True else: # Perform an incremental checkpoint self._inputSamplesSinceLastFullCheckpoint.extend(currentRunInputSamples) attributes = { self._BATCH_IDS_CHECKPOINT_ATTR_NAME: list(self._modelCheckpointBatchIDSetCache), self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME: self._encodeDataSamples(self._inputSamplesSinceLastFullCheckpoint) } self._checkpointMgr.updateCheckpointAttributes(self._modelID, attributes)
def testRemoveAndGetModelIDs(self):
  """ Exercise getModelIDs() together with remove() """
  mgr = ModelCheckpointMgr()

  # Nothing has been defined yet
  self.assertListEqual(mgr.getModelIDs(), [])

  # Define two models (definition only, no checkpoints)
  remaining = sorted([uuid.uuid1().hex, uuid.uuid1().hex])
  for modelID in remaining:
    mgr.define(modelID, definition={'a': 1})
  self.assertListEqual(sorted(mgr.getModelIDs()), remaining)

  # Remove the first one; only the rest should still be listed
  mgr.remove(remaining.pop(0))
  self.assertListEqual(sorted(mgr.getModelIDs()), remaining)

  # Remove everything that is left
  for modelID in remaining:
    mgr.remove(modelID)
  self.assertListEqual(mgr.getModelIDs(), [])

  # Removing an unknown model ID must fail
  with self.assertRaises(ModelNotFound):
    mgr.remove("IDx")
def testCloneModelWithCheckpoint(self):
  """ Cloning a checkpointed model copies its definition, checkpoint and
  checkpoint attributes, and the clone outlives removal of the source.
  """
  mgr = ModelCheckpointMgr()
  srcID = uuid.uuid1().hex
  dstID = uuid.uuid1().hex

  # Define the source model...
  definition = {'a': 1, 'b': 2, 'c': 3}
  mgr.define(srcID, definition)

  # ...and give it a real checkpoint to clone
  srcModel = ModelFactory.create(self._getModelParams("variant1"))
  mgr.save(srcID, srcModel, attributes="attributes1")

  mgr.clone(srcID, dstID)

  # Drop the source so every check below is satisfied by the clone alone
  mgr.remove(srcID)

  self.assertDictEqual(mgr.loadModelDefinition(dstID), definition)
  self.assertEqual(mgr.loadCheckpointAttributes(dstID), "attributes1")

  # The cloned checkpoint must load back into an equivalent model
  clonedModel = mgr.load(dstID)
  self.assertEqual(str(clonedModel.getFieldInfo()),
                   str(srcModel.getFieldInfo()))
def testModelCheckpointSaveAndLoadSupport(self):
  """ Test saving and loading models, replacing an existing checkpoint, and
  that a failed save leaves the existing checkpoint intact
  """
  checkpointMgr = ModelCheckpointMgr()

  modelID = uuid.uuid1().hex

  # Calling load when the model doesn't exist should raise an exception
  self.assertRaises(ModelNotFound, checkpointMgr.load, modelID)

  # Create a model that we can save
  originalModel = ModelFactory.create(self._getModelParams("variant1"))

  # Define the model (definition only; no checkpoint yet)
  checkpointMgr.define(modelID, definition=dict(a=1, b=2))

  # Attempting to load a model that hasn't been checkpointed should raise
  # ModelNotFound
  self.assertRaises(ModelNotFound, checkpointMgr.load, modelID)

  # Save the checkpoint
  checkpointMgr.save(modelID, originalModel, attributes="attributes1")

  # Load the model from the saved checkpoint
  loadedModel = checkpointMgr.load(modelID)
  self.assertEqual(str(loadedModel.getFieldInfo()),
                   str(originalModel.getFieldInfo()))
  del loadedModel
  del originalModel

  self.assertEqual(checkpointMgr.loadCheckpointAttributes(modelID),
                   "attributes1")

  # Make sure we can replace an existing model
  model2 = ModelFactory.create(self._getModelParams("variant2"))
  checkpointMgr.save(modelID, model2, attributes="attributes2")
  model = checkpointMgr.load(modelID)
  self.assertEqual(str(model.getFieldInfo()), str(model2.getFieldInfo()))
  self.assertEqual(checkpointMgr.loadCheckpointAttributes(modelID),
                   "attributes2")

  model3 = ModelFactory.create(self._getModelParams("variant3"))
  checkpointMgr.save(modelID, model3, attributes="attributes3")
  model = checkpointMgr.load(modelID)
  self.assertEqual(str(model.getFieldInfo()), str(model3.getFieldInfo()))
  self.assertEqual(checkpointMgr.loadCheckpointAttributes(modelID),
                   "attributes3")

  # Simulate a failure during checkpointing and make sure it doesn't mess
  # up our already existing checkpoint. assertRaises ensures the simulated
  # failure really happens; if save() accepted the bogus model, the checks
  # below would not be exercising failure handling at all.
  with self.assertRaises(AttributeError):
    checkpointMgr.save(modelID, "InvalidModel", attributes="attributes4")

  model = checkpointMgr.load(modelID)
  self.assertEqual(str(model.getFieldInfo()), str(model3.getFieldInfo()))
  self.assertEqual(checkpointMgr.loadCheckpointAttributes(modelID),
                   "attributes3")
def testCloneModelWithDefinitionOnly(self):
  """ Cloning a model that has a definition but no checkpoint copies the
  definition only; cloning again onto the same destination is rejected.
  """
  mgr = ModelCheckpointMgr()
  srcID = uuid.uuid1().hex
  dstID = uuid.uuid1().hex

  # Source model: definition only, never checkpointed
  definition = {'a': 1, 'b': 2, 'c': 3}
  mgr.define(srcID, definition)

  mgr.clone(srcID, dstID)

  # The clone carries an identical definition...
  self.assertDictEqual(mgr.loadModelDefinition(dstID), definition)

  # ...but, like its source, has no checkpoint to load
  self.assertRaises(ModelNotFound, mgr.load, dstID)

  # The destination exists now, so a second clone onto it must fail
  self.assertRaises(ModelAlreadyExists, mgr.clone, srcID, dstID)
def testRemoveAndGetModelIDs(self):
  """ Exercise getModelIDs() together with remove() """
  mgr = ModelCheckpointMgr()

  # Nothing has been defined yet
  self.assertListEqual(mgr.getModelIDs(), [])

  # Define two models (definition only, no checkpoints)
  remaining = sorted([uuid.uuid1().hex, uuid.uuid1().hex])
  for modelID in remaining:
    mgr.define(modelID, definition={'a': 1})
  self.assertListEqual(sorted(mgr.getModelIDs()), remaining)

  # Remove the first one; only the rest should still be listed
  mgr.remove(remaining.pop(0))
  self.assertListEqual(sorted(mgr.getModelIDs()), remaining)

  # Remove everything that is left
  for modelID in remaining:
    mgr.remove(modelID)
  self.assertListEqual(mgr.getModelIDs(), [])

  # Removing an unknown model ID must fail
  with self.assertRaises(ModelNotFound):
    mgr.remove("IDx")
def testCloneModelWithCheckpoint(self):
  """ Cloning a checkpointed model copies its definition, checkpoint and
  checkpoint attributes, and the clone outlives removal of the source.
  """
  mgr = ModelCheckpointMgr()
  srcID = uuid.uuid1().hex
  dstID = uuid.uuid1().hex

  # Define the source model...
  definition = {'a': 1, 'b': 2, 'c': 3}
  mgr.define(srcID, definition)

  # ...and give it a real checkpoint to clone
  srcModel = ModelFactory.create(self._getModelParams("variant1"))
  mgr.save(srcID, srcModel, attributes="attributes1")

  mgr.clone(srcID, dstID)

  # Drop the source so every check below is satisfied by the clone alone
  mgr.remove(srcID)

  self.assertDictEqual(mgr.loadModelDefinition(dstID), definition)
  self.assertEqual(mgr.loadCheckpointAttributes(dstID), "attributes1")

  # The cloned checkpoint must load back into an equivalent model
  clonedModel = mgr.load(dstID)
  self.assertEqual(str(clonedModel.getFieldInfo()),
                   str(srcModel.getFieldInfo()))