def testUpdateCheckpointAttributesNoModelEntry(self):
    """ updateCheckpointAttributes must raise ModelNotFound when no entry
    exists for the given model ID
    """
    mgr = ModelCheckpointMgr()
    unknownModelID = uuid.uuid1().hex

    # Nothing was ever defined or saved under this freshly generated ID, so
    # there is no entry for the update to act on
    self.assertRaises(ModelNotFound,
                      mgr.updateCheckpointAttributes,
                      unknownModelID,
                      "attributes")
    def testUpdateCheckpointAttributesNoModelEntry(self):
        """ updateCheckpointAttributes must raise ModelNotFound when no entry
        exists for the given model ID
        """
        mgr = ModelCheckpointMgr()
        unknownModelID = uuid.uuid1().hex

        # Nothing was ever defined or saved under this freshly generated ID,
        # so there is no entry for the update to act on
        self.assertRaises(ModelNotFound,
                          mgr.updateCheckpointAttributes,
                          unknownModelID,
                          "attributes")
  def testUpdateCheckpointAttributesWithNoModeCheckpointException(self):
    """ updateCheckpointAttributes must raise ModelNotFound when the model
    entry exists (its definition was saved) but no checkpoint was saved yet
    """
    mgr = ModelCheckpointMgr()
    modelIDWithoutCheckpoint = uuid.uuid1().hex

    # Store only the definition; save() is deliberately never called
    mgr.define(modelIDWithoutCheckpoint, definition={"a": 1, "b": 2})

    self.assertRaises(ModelNotFound,
                      mgr.updateCheckpointAttributes,
                      modelIDWithoutCheckpoint,
                      "attributes")
    def testUpdateCheckpointAttributesWithNoModeCheckpointException(self):
        """ updateCheckpointAttributes must raise ModelNotFound when the model
        entry exists (its definition was saved) but no checkpoint was saved
        yet
        """
        mgr = ModelCheckpointMgr()
        modelIDWithoutCheckpoint = uuid.uuid1().hex

        # Store only the definition; save() is deliberately never called
        mgr.define(modelIDWithoutCheckpoint, definition={"a": 1, "b": 2})

        self.assertRaises(ModelNotFound,
                          mgr.updateCheckpointAttributes,
                          modelIDWithoutCheckpoint,
                          "attributes")
  def testUpdateCheckpointAttributes(self):
    """ updateCheckpointAttributes replaces the attributes stored with an
    existing model checkpoint
    """
    mgr = ModelCheckpointMgr()
    modelID = uuid.uuid1().hex

    # A real OPF model is required so that a checkpoint can be saved
    model = ModelFactory.create(self._getModelParams("variant1"))

    mgr.define(modelID, definition={"a": 1, "b": 2})
    mgr.save(modelID, model, attributes="attributes1")

    # Initially, the attributes handed to save() are what comes back
    self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes1")

    # After updating the checkpoint attributes, the new value is returned
    updatedAttributes = {"this": [1, 2, 3], "that": "abc"}
    mgr.updateCheckpointAttributes(modelID, updatedAttributes)
    self.assertEqual(mgr.loadCheckpointAttributes(modelID), updatedAttributes)
    def testUpdateCheckpointAttributes(self):
        """ updateCheckpointAttributes replaces the attributes stored with an
        existing model checkpoint
        """
        mgr = ModelCheckpointMgr()
        modelID = uuid.uuid1().hex

        # A real OPF model is required so that a checkpoint can be saved
        model = ModelFactory.create(self._getModelParams("variant1"))

        mgr.define(modelID, definition={"a": 1, "b": 2})
        mgr.save(modelID, model, attributes="attributes1")

        # Initially, the attributes handed to save() are what comes back
        self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes1")

        # After updating the checkpoint attributes, the new value is returned
        updatedAttributes = {"this": [1, 2, 3], "that": "abc"}
        mgr.updateCheckpointAttributes(modelID, updatedAttributes)
        self.assertEqual(mgr.loadCheckpointAttributes(modelID),
                         updatedAttributes)
Ejemplo n.º 7
0
class _ModelArchiver(object):
    """ Helper class for loading/creating and checkpointing model
  """

    # Name of the attribute that is stored as an integral component of the
    # checkpoint. This attribute contains a list of batchIDs that were processed
    # since the previous checkpoint.
    _BATCH_IDS_CHECKPOINT_ATTR_NAME = "batchIDs"

    # Name of the attribute that is stored as an integral component of the
    # checkpoint. It contains a list of ModelInputRow objects processed since
    # the last full checkpoint for use in preparing the last incremental
    # model checkpoint for new input. The value is in pickle string format.
    _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME = "incrementalInputSamples"

    _MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS = 100

    def __init__(self, modelID):
        """
    :param modelID: model ID; string
    """
        self._modelID = modelID

        # The model object from OPF ModelFactory; set up by the loadModel() method
        self._model = None

        # The _InputRowEncoder object; set up by the loadModel() method
        self._inputRowEncoder = None

        self._checkpointMgr = ModelCheckpointMgr()

        # True if model was loaded from existing checkpoint or after a full
        # checkpoint is made; False initially if model was created from
        # params. This informs us whether an incremental checkpoint is possible.
        self._hasCheckpoint = False

        self._modelCheckpointBatchIDSetCache = None

        # Input data samples that have accumulated since last full checkpoint
        self._inputSamplesSinceLastFullCheckpointCache = None

    @property
    def model(self):
        """ An OPF Model object or None if not loaded yet """
        return self._model

    @property
    def inputRowEncoder(self):
        """ An _InputRowEncoder object or None if model not loaded yet """
        return self._inputRowEncoder

    @property
    def modelCheckpointBatchIDSet(self):
        """ A sequence of input batch identifiers associated with current
    model checkpoint
    """
        if self._modelCheckpointBatchIDSetCache is None:
            self._loadCheckpointAttributes()
        return self._modelCheckpointBatchIDSetCache

    @property
    def checkpointMgr(self):
        return self._checkpointMgr

    @property
    def _inputSamplesSinceLastFullCheckpoint(self):
        if self._inputSamplesSinceLastFullCheckpointCache is None:
            self._loadCheckpointAttributes()
        return self._inputSamplesSinceLastFullCheckpointCache

    @_inputSamplesSinceLastFullCheckpoint.setter
    def _inputSamplesSinceLastFullCheckpoint(self, value):
        self._inputSamplesSinceLastFullCheckpointCache = value

    @classmethod
    def _encodeDataSamples(cls, dataSamples):
        """
    :param dataSamples: a sequence of data samples to be saved as the
      _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute

    :returns: a string encoding of the data samples
    """
        return base64.standard_b64encode(
            pickle.dumps(dataSamples, pickle.HIGHEST_PROTOCOL))

    @classmethod
    def _decodeDataSamples(cls, dataSamples):
        """
    :param dataSamples: string-encoded data samples from the
      _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute

    :returns: a sequence of data samples
    """
        return pickle.loads(base64.standard_b64decode(dataSamples))

    def _loadCheckpointAttributes(self):
        # Load the checkpoint attributes
        try:
            checkpointAttributes = self._checkpointMgr.loadCheckpointAttributes(
                self._modelID)
        except model_checkpoint_mgr.ModelNotFound:
            self._modelCheckpointBatchIDSetCache = set()
            self._inputSamplesSinceLastFullCheckpoint = []
        else:
            self._modelCheckpointBatchIDSetCache = set(
                checkpointAttributes[self._BATCH_IDS_CHECKPOINT_ATTR_NAME])

            inputSamples = checkpointAttributes.get(
                self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME)
            if inputSamples:
                self._inputSamplesSinceLastFullCheckpoint = self._decodeDataSamples(
                    inputSamples)
            else:
                self._inputSamplesSinceLastFullCheckpoint = []

    def loadModel(self):
        """ Load the model and construct the input row encoder. On success,
    the loaded model may be accessed via the `model` attribute

    :raises: model_checkpoint_mgr.ModelNotFound
    """
        if self._model is not None:
            return

        modelDefinition = None

        # Load the model
        try:
            self._model = self._checkpointMgr.load(self._modelID)
            self._hasCheckpoint = True
        except model_checkpoint_mgr.ModelNotFound:
            # So, we didn't have a checkpoint... try to create our model from model
            # definition params
            self._hasCheckpoint = False
            try:
                modelDefinition = self._checkpointMgr.loadModelDefinition(
                    self._modelID)
            except model_checkpoint_mgr.ModelNotFound:
                raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL,
                                        msg="modelID=%s not found" %
                                        (self._modelID))
            else:
                modelParams = modelDefinition["modelParams"]

            # TODO: when creating the model from params, do we need to call
            #   its model.setFieldStatistics() method? And where will the
            #   fieldStats come from, anyway?
            self._model = ModelFactory.create(
                modelConfig=modelParams["modelConfig"])
            self._model.enableLearning()
            self._model.enableInference(modelParams["inferenceArgs"])

        # Construct the object for converting a flat input row into a format
        # that is consumable by an OPF model
        try:
            if modelDefinition is None:
                modelDefinition = self._checkpointMgr.loadModelDefinition(
                    self._modelID)
        except model_checkpoint_mgr.ModelNotFound:
            raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL,
                                    msg="modelID=%s not found" %
                                    (self._modelID))
        else:
            inputSchema = modelDefinition["inputSchema"]

        # Convert it to a sequence of FieldMetaInfo instances
        # NOTE: if loadMetaInfo didn't raise, we expect "inputSchema" to be
        #   present; it would be a logic error if it isn't.
        inputFieldsMeta = tuple(FieldMetaInfo(*f) for f in inputSchema)

        self._inputRowEncoder = _InputRowEncoder(fieldsMeta=inputFieldsMeta)

        # If the checkpoint was incremental, feed the cached data into the model
        for inputSample in self._inputSamplesSinceLastFullCheckpoint:
            # Convert a flat input sample into a format that is consumable by an OPF
            # model
            self._inputRowEncoder.appendRecord(inputSample)

            # Infer
            self._model.run(self._inputRowEncoder.getNextRecordDict())

    def saveModel(self, currentRunBatchIDSet, currentRunInputSamples):
        """
    :param currentRunBatchIDSet: a set of batch ids to be saved in model
      checkpoint attributes

    :param currentRunInputSamples: a sequence of model input data sample objects
      for incremental checkpoint; will be saved in checkpoint attributes if an
      incremental checkpoint is performed.
    """
        if self._model is not None:
            self._modelCheckpointBatchIDSetCache = currentRunBatchIDSet.copy()

            if (not self._hasCheckpoint
                    or (len(self._inputSamplesSinceLastFullCheckpoint) +
                        len(currentRunInputSamples)) >
                    self._MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS):
                # Perform a full checkpoint
                self._inputSamplesSinceLastFullCheckpointCache = []

                self._checkpointMgr.save(
                    modelID=self._modelID,
                    model=self._model,
                    attributes={
                        self._BATCH_IDS_CHECKPOINT_ATTR_NAME:
                        list(self._modelCheckpointBatchIDSetCache)
                    })

                self._hasCheckpoint = True
            else:
                # Perform an incremental checkpoint
                self._inputSamplesSinceLastFullCheckpoint.extend(
                    currentRunInputSamples)
                attributes = {
                    self._BATCH_IDS_CHECKPOINT_ATTR_NAME:
                    list(self._modelCheckpointBatchIDSetCache),
                    self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME:
                    self._encodeDataSamples(
                        self._inputSamplesSinceLastFullCheckpoint)
                }

                self._checkpointMgr.updateCheckpointAttributes(
                    self._modelID, attributes)
Ejemplo n.º 8
0
class _ModelArchiver(object):
  """ Helper class for loading/creating and checkpointing model
  """

  # Name of the attribute that is stored as an integral component of the
  # checkpoint. This attribute contains a list of batchIDs that were processed
  # since the previous checkpoint.
  _BATCH_IDS_CHECKPOINT_ATTR_NAME = "batchIDs"

  # Name of the attribute that is stored as an integral component of the
  # checkpoint. It contains a list of ModelInputRow objects processed since
  # the last full checkpoint for use in preparing the last incremental
  # model checkpoint for new input. The value is in pickle string format.
  _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME = "incrementalInputSamples"

  _MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS = 100


  def __init__(self, modelID):
    """
    :param modelID: model ID; string
    """
    self._modelID = modelID

    # The model object from OPF ModelFactory; set up by the loadModel() method
    self._model = None

    # The _InputRowEncoder object; set up by the loadModel() method
    self._inputRowEncoder = None

    self._checkpointMgr = ModelCheckpointMgr()

    # True if model was loaded from existing checkpoint or after a full
    # checkpoint is made; False initially if model was created from
    # params. This informs us whether an incremental checkpoint is possible.
    self._hasCheckpoint = False

    self._modelCheckpointBatchIDSetCache = None

    # Input data samples that have accumulated since last full checkpoint
    self._inputSamplesSinceLastFullCheckpointCache = None


  @property
  def model(self):
    """ An OPF Model object or None if not loaded yet """
    return self._model


  @property
  def inputRowEncoder(self):
    """ An _InputRowEncoder object or None if model not loaded yet """
    return self._inputRowEncoder


  @property
  def modelCheckpointBatchIDSet(self):
    """ A sequence of input batch identifiers associated with current
    model checkpoint
    """
    if self._modelCheckpointBatchIDSetCache is None:
      self._loadCheckpointAttributes()
    return self._modelCheckpointBatchIDSetCache


  @property
  def checkpointMgr(self):
    return self._checkpointMgr


  @property
  def _inputSamplesSinceLastFullCheckpoint(self):
    if self._inputSamplesSinceLastFullCheckpointCache is None:
      self._loadCheckpointAttributes()
    return self._inputSamplesSinceLastFullCheckpointCache


  @_inputSamplesSinceLastFullCheckpoint.setter
  def _inputSamplesSinceLastFullCheckpoint(self, value):
    self._inputSamplesSinceLastFullCheckpointCache = value


  @classmethod
  def _encodeDataSamples(cls, dataSamples):
    """
    :param dataSamples: a sequence of data samples to be saved as the
      _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute

    :returns: a string encoding of the data samples
    """
    return base64.standard_b64encode(pickle.dumps(dataSamples,
                                                  pickle.HIGHEST_PROTOCOL))


  @classmethod
  def _decodeDataSamples(cls, dataSamples):
    """
    :param dataSamples: string-encoded data samples from the
      _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute

    :returns: a sequence of data samples
    """
    return pickle.loads(base64.standard_b64decode(dataSamples))


  def _loadCheckpointAttributes(self):
    # Load the checkpoint attributes
    try:
      checkpointAttributes = self._checkpointMgr.loadCheckpointAttributes(
        self._modelID)
    except model_checkpoint_mgr.ModelNotFound:
      self._modelCheckpointBatchIDSetCache = set()
      self._inputSamplesSinceLastFullCheckpoint = []
    else:
      self._modelCheckpointBatchIDSetCache = set(
        checkpointAttributes[self._BATCH_IDS_CHECKPOINT_ATTR_NAME])

      inputSamples = checkpointAttributes.get(
        self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME)
      if inputSamples:
        self._inputSamplesSinceLastFullCheckpoint = self._decodeDataSamples(
          inputSamples)
      else:
        self._inputSamplesSinceLastFullCheckpoint = []


  def loadModel(self):
    """ Load the model and construct the input row encoder. On success,
    the loaded model may be accessed via the `model` attribute

    :raises: model_checkpoint_mgr.ModelNotFound
    """
    if self._model is not None:
      return

    modelDefinition = None

    # Load the model
    try:
      self._model = self._checkpointMgr.load(self._modelID)
      self._hasCheckpoint = True
    except model_checkpoint_mgr.ModelNotFound:
      # So, we didn't have a checkpoint... try to create our model from model
      # definition params
      self._hasCheckpoint = False
      try:
        modelDefinition = self._checkpointMgr.loadModelDefinition(self._modelID)
      except model_checkpoint_mgr.ModelNotFound:
        raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL,
                                msg="modelID=%s not found" % (self._modelID))
      else:
        modelParams = modelDefinition["modelParams"]

      # TODO: when creating the model from params, do we need to call
      #   its model.setFieldStatistics() method? And where will the
      #   fieldStats come from, anyway?
      self._model = ModelFactory.create(
        modelConfig=modelParams["modelConfig"])
      self._model.enableLearning()
      self._model.enableInference(modelParams["inferenceArgs"])


    # Construct the object for converting a flat input row into a format
    # that is consumable by an OPF model
    try:
      if modelDefinition is None:
        modelDefinition = self._checkpointMgr.loadModelDefinition(
          self._modelID)
    except model_checkpoint_mgr.ModelNotFound:
      raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL,
                              msg="modelID=%s not found" % (self._modelID))
    else:
      inputSchema = modelDefinition["inputSchema"]

    # Convert it to a sequence of FieldMetaInfo instances
    # NOTE: if loadMetaInfo didn't raise, we expect "inputSchema" to be
    #   present; it would be a logic error if it isn't.
    inputFieldsMeta = tuple(FieldMetaInfo(*f) for f in inputSchema)

    self._inputRowEncoder = _InputRowEncoder(fieldsMeta=inputFieldsMeta)

    # If the checkpoint was incremental, feed the cached data into the model
    for inputSample in self._inputSamplesSinceLastFullCheckpoint:
      # Convert a flat input sample into a format that is consumable by an OPF
      # model
      self._inputRowEncoder.appendRecord(inputSample)

      # Infer
      self._model.run(self._inputRowEncoder.getNextRecordDict())


  def saveModel(self, currentRunBatchIDSet, currentRunInputSamples):
    """
    :param currentRunBatchIDSet: a set of batch ids to be saved in model
      checkpoint attributes

    :param currentRunInputSamples: a sequence of model input data sample objects
      for incremental checkpoint; will be saved in checkpoint attributes if an
      incremental checkpoint is performed.
    """
    if self._model is not None:
      self._modelCheckpointBatchIDSetCache = currentRunBatchIDSet.copy()

      if (not self._hasCheckpoint or
          (len(self._inputSamplesSinceLastFullCheckpoint) +
           len(currentRunInputSamples)) >
          self._MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS):
        # Perform a full checkpoint
        self._inputSamplesSinceLastFullCheckpointCache = []

        self._checkpointMgr.save(
          modelID=self._modelID, model=self._model,
          attributes={
            self._BATCH_IDS_CHECKPOINT_ATTR_NAME:
              list(self._modelCheckpointBatchIDSetCache)})

        self._hasCheckpoint = True
      else:
        # Perform an incremental checkpoint
        self._inputSamplesSinceLastFullCheckpoint.extend(currentRunInputSamples)
        attributes = {
          self._BATCH_IDS_CHECKPOINT_ATTR_NAME:
            list(self._modelCheckpointBatchIDSetCache),

          self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME:
            self._encodeDataSamples(self._inputSamplesSinceLastFullCheckpoint)
        }

        self._checkpointMgr.updateCheckpointAttributes(self._modelID,
                                                       attributes)